diff --git a/.github/workflows/enflame-build-and-test.yml b/.github/workflows/enflame-build-and-test.yml new file mode 100644 index 000000000..c94855de9 --- /dev/null +++ b/.github/workflows/enflame-build-and-test.yml @@ -0,0 +1,69 @@ +name: Enflame-Build-And-Test + +on: + push: + branches: [ "triton_v3.3.x" ] + pull_request: + branches: [ "triton_v3.3.x" ] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + enflame-build-and-test: + runs-on: enflame + if: ${{ github.repository == 'FlagTree/flagtree' || github.repository == 'flagos-ai/flagtree' }} + steps: + - name: Setup environment + shell: bash + run: | + source ~/env.sh + env | grep -E '^(http_proxy|https_proxy|all_proxy|no_proxy)=' >> $GITHUB_ENV || true + + - name: Checkout code (attempt 1) + id: checkout1 + uses: actions/checkout@v5 + continue-on-error: true + + - name: Sleep before checkout2 + if: steps.checkout1.outcome == 'failure' + run: | + echo "First checkout attempt failed. Sleeping for 120 seconds before retry..." + sleep 120 + + - name: Checkout code (attempt 2) + id: checkout2 + if: steps.checkout1.outcome == 'failure' + uses: actions/checkout@v5 + continue-on-error: true + + - name: Sleep before final checkout + if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure' + run: | + echo "Second checkout attempt failed. Sleeping for 180 seconds before final retry..." + sleep 180 + + - name: Checkout code (final attempt) + if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure' + uses: actions/checkout@v5 + + - name: Verify checkout success + if: success() + run: echo "Checkout completed successfully" + + - name: FlagTree Build on Enflame + shell: bash + run: | + set -x + pip uninstall -y triton + pip uninstall -y triton_gcu + export FLAGTREE_BACKEND=enflame + cd python + MAX_JOBS=32 python3 -m pip install . --no-build-isolation + + - name: FlagTree Test on Enflame + shell: bash + run: | + set -x + python3 -m pytest -s third_party/enflame/python/test/unit diff --git a/CMakeLists.txt b/CMakeLists.txt index 16431d047..8f5c3fb3d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -213,7 +213,7 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party) include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files # link_directories(${LLVM_LIBRARY_DIR}) -if (FLAGTREE_BACKEND MATCHES "^(cambricon|aipu|tsingmicro)$") +if (FLAGTREE_BACKEND MATCHES "^(cambricon|aipu|tsingmicro|enflame)$") include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files add_subdirectory(include) diff --git a/README.md b/README.md index f14843930..0b51c4315 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,16 @@ cd ${YOUR_CODE_DIR}/flagtree/python export FLAGTREE_BACKEND=hcu python3 -m pip install . --no-build-isolation -v ``` - +[enflame](https://github.com/FlagTree/flagtree/triton_v3.3.x/main/third_party/enflame/) +```shell +# 推荐使用镜像: https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-flagtree-0.3.1.tar.gz +mkdir -p ~/.flagtree/enflame; cd ~/.flagtree/enflame +wget baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-llvm21-d752c5b-gcc9-x64_v0.3.0.tar.gz +tar zxvf enflame-llvm21-d752c5b-gcc9-x64_v0.3.0.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/python +export FLAGTREE_BACKEND=enflame +python3 -m pip install . --no-build-isolation -v +``` [nvidia](/third_party/nvidia/) To build with default backends nvidia, amd, triton_shared cpu: ```shell diff --git a/README_cn.md b/README_cn.md index fdd39b23e..61b0cbff6 100644 --- a/README_cn.md +++ b/README_cn.md @@ -114,7 +114,16 @@ cd ${YOUR_CODE_DIR}/flagtree/python export FLAGTREE_BACKEND=hcu python3 -m pip install . --no-build-isolation -v ``` - +[enflame](https://github.com/FlagTree/flagtree/triton_v3.3.x/main/third_party/enflame/) +```shell +# 推荐使用镜像: https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-flagtree-0.3.1.tar.gz +mkdir -p ~/.flagtree/enflame; cd ~/.flagtree/enflame +wget baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-llvm21-d752c5b-gcc9-x64_v0.3.0.tar.gz +tar zxvf enflame-llvm21-d752c5b-gcc9-x64_v0.3.0.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/python +export FLAGTREE_BACKEND=enflame +python3 -m pip install . --no-build-isolation -v +``` [nvidia](/third_party/nvidia/) 使用默认的构建命令,可以构建安装 nvidia、amd、triton_shared cpu 后端: ```shell diff --git a/python/setup.py b/python/setup.py index eefe8c965..d2c2f2f0e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -596,7 +596,7 @@ def build_extension(self, ext): ) if helper.flagtree_backend: - if helper.flagtree_backend in ("aipu", "tsingmicro"): + if helper.flagtree_backend in ("aipu", "tsingmicro", "enflame"): backends = [ *BackendInstaller.copy(helper.default_backends + helper.extend_backends), *BackendInstaller.copy_externals(), diff --git a/python/setup_tools/setup_helper.py b/python/setup_tools/setup_helper.py index db1fe5923..dbde2bc84 100644 --- a/python/setup_tools/setup_helper.py +++ b/python/setup_tools/setup_helper.py @@ -9,7 +9,7 @@ extend_backends = [] default_backends = ["nvidia", "amd"] -plugin_backends = ["cambricon", "ascend", "aipu", "tsingmicro"] +plugin_backends = ["cambricon", "ascend", "aipu", "tsingmicro", "enflame"] ext_sourcedir = "triton/_C/" flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() @@ -380,6 +380,21 @@ def check_env(env_val): post_hock=set_llvm_env, ) +# enflame +cache.store( + file="llvm-d752c5b-gcc9-x64", + condition=("enflame" == flagtree_backend), + ## TODO upload enflame llvm to blob storage + url = f"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-llvm21-d752c5b-gcc9-x64_v0.3.0.tar.gz", + pre_hock=lambda: check_env('KURAMA_LLVM_DIR_GCU300'), + post_hock=lambda path: set_env({ + 'KURAMA_LLVM_DIR_GCU300': path, + 'LLVM_INCLUDE_DIRS': Path(path) / "include", + 'LLVM_LIBRARY_DIR': Path(path) / "lib", + 'LLVM_SYSPATH': path, + }), +) + # tsingmicro cache.store( file="tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-python3.11-x64", diff --git a/python/setup_tools/utils/enflame.py b/python/setup_tools/utils/enflame.py new file mode 100644 index 000000000..f02ba668c --- /dev/null +++ b/python/setup_tools/utils/enflame.py @@ -0,0 +1,27 @@ +import os +import shutil +from pathlib import Path +from build_helpers import get_cmake_dir + +def install_extension(*args, **kargs): + cmake_dir = get_cmake_dir() + binary_dir = cmake_dir / "bin" + python_root_dir = Path(__file__).parent.parent.parent + src_root_dir = python_root_dir.parent + + drvfile = src_root_dir / 'third_party' / 'nvidia' / 'backend' / 'driver.py' + with open(drvfile, 'r') as f: + lines = f.readlines() + for i, line in enumerate(lines): + if 'def is_active():' in line: + if not 'return False' in lines[i+1]: + lines.insert(i+1, ' return False\n') + break + with open(drvfile, 'w') as f: + f.writelines(lines) + + dst_dir = python_root_dir / "triton" / "backends" / "enflame" + for target in ["triton-gcu300-opt"]: + src_path = binary_dir / target + dst_path = dst_dir / target + shutil.copy(src_path, dst_path) \ No newline at end of file diff --git a/third_party/enflame/CMakeLists.txt b/third_party/enflame/CMakeLists.txt new file mode 100644 index 000000000..671744c90 --- /dev/null +++ b/third_party/enflame/CMakeLists.txt @@ -0,0 +1,74 @@ +#set(CMAKE_INCLUDE_DIRECTORIES_BEFORE ON) + + +set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/install) +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +if (DEFINED ENV{LLVM_SYSPATH}) +else() +endif() + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) + +# 添加 cmake 模块路径,以便 include() 可以找到相应的文件 +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/2nd") + +# 添加 triton_enflame 插件的构建 +add_triton_plugin(TritonEnflame + ${CMAKE_CURRENT_SOURCE_DIR}/triton_enflame.cc + + DEPENDS + TritonTableGen +) + +# 设置包含目录 +target_include_directories(TritonEnflame PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/triton_gcu/include +) + +# 链接必要的库 +target_link_libraries(TritonEnflame PRIVATE + # MLIR 核心库 + MLIRIR + MLIRPass + MLIRTransforms + MLIROptLib + + # Python 绑定库 + Python3::Module + pybind11::headers + + # 如果需要链接 GCU 相关的库,可以在这里添加 + # GCUIRgcu300 + # TritonGCUIR_gcu300 + # MLIRTritonGCUCommon_gcu300 + # MLIRTritonToGCU_gcu300 + # MLIRTritonGCUTransforms_gcu300 +) + +# 设置编译选项 +target_compile_definitions(TritonEnflame PRIVATE + -DTRITON_ENFLAME_VERSION="1.0.0" +) + + +set(HACK_INCLUDE_DIRS true) +if(HACK_INCLUDE_DIRS) + + get_property(inc_dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES) + + list(REMOVE_ITEM inc_dirs "${CMAKE_SOURCE_DIR}/include") + list(REMOVE_ITEM inc_dirs "${CMAKE_BINARY_DIR}/include") + list(REMOVE_ITEM inc_dirs "${CMAKE_SOURCE_DIR}/third_party") + list(REMOVE_ITEM inc_dirs "${CMAKE_BINARY_DIR}/third_party") + list(REMOVE_ITEM inc_dirs "${MLIR_INCLUDE_DIRS}") + list(REMOVE_ITEM inc_dirs "${LLVM_INCLUDE_DIRS}") + + set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES "${include_dirs}") +endif() + +# 添加 triton_gcu 子目录构建 +add_subdirectory(triton_gcu) +add_subdirectory(include) diff --git a/third_party/enflame/README.md b/third_party/enflame/README.md new file mode 100644 index 000000000..2206bbe0d --- /dev/null +++ b/third_party/enflame/README.md @@ -0,0 +1,138 @@ +# Flagtree Framework - Enflame Accelerator Support + +## Overview + +Flagtree is a high-performance computing framework optimized for Enflame accelerators. This repository provides core component backend bindings and test suites for developing and deploying applications on Enflame hardware platforms. + +## Prerequisites + +- Linux host system with Docker support +- Enflame 3rd Generation Accelerator Card (S60) +- Minimum 16GB RAM (32GB recommended) +- 100GB available disk space + +## Environment Preparation + +### 1. Pull Source Code + +```bash +# Pull code and switch to triton_v3.3.x branch +cd ~ +git clone https://github.com/flagos-ai/flagtree.git +cd flagtree +git checkout triton_v3.3.x +``` + +### 2. Prepare Docker Image + +```bash +# Load pre-built container image +curl -sL https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-flagtree-0.3.1.tar.gz | docker load + +# Or manually download and load +wget https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-flagtree-0.3.1.tar.gz +docker load -i enflame-flagtree-0.3.1.tar.gz +``` + +### 3. Start Docker Container + +```bash +# To re-run container, remove the existing one +# docker rm -f enflame-flagtree + +# Assuming flagtree source code is located at ~/flagtree +docker run -itd \ + --privileged \ + --name enflame-flagtree \ + -v ~/flagtree:/root/flagtree \ + enflame/flagtree:0.3.1 bash +``` + +### 4. Install Driver + +```bash +# Extract and install Enflame driver +docker cp enflame-flagtree:/enflame enflame + +sudo bash enflame/driver/enflame-x86_64-gcc-1.6.3.12-20251115104629.run +# Use other arguments if prompt, e.g. +# sudo bash enflame/driver/enflame-x86_64-gcc-1.6.3.12-20251115104629.run --virt-host + +efsmi +``` + +Check driver status with efsmi. Example output: + +``` +------------------------------------------------------------------------------- +--------------------- Enflame System Management Interface --------------------- +--------- Enflame Tech, All Rights Reserved. 2024-2025 Copyright (C) ---------- +------------------------------------------------------------------------------- + ++2025-11-28, 10:50:14 CST-----------------------------------------------------+ +| EFSMI: 1.6.3.12 Driver Ver: 1.6.3.12 | ++-----------------------------+-------------------+---------------------------+ +| DEV NAME | FW VER | BUS-ID ECC | +| TEMP Lpm Pwr(Usage/Cap) | Mem GCU Virt | DUsed SN | +|=============================================================================| +| 0 Enflame S60G | 31.5.3 | 00:2e:00.0 Disable | +| 34℃ LP0 N/A | 23552MiB SRIOV | 0% A018K30520031 | ++-----------------------------+-------------------+---------------------------+ +| 1 Enflame S60G | 31.5.3 | 00:2f:00.0 Disable | +| 34℃ LP0 N/A | 23552MiB SRIOV | 0% A018K30520031 | ++-----------------------------+-------------------+---------------------------+ +``` + +### 5. Enter Docker Container + +```bash +# Execute docker +docker exec -it enflame-flagtree bash +``` + +> Note: All subsequent commands should be executed within the container. + +## Build and Install + +### 1. Prepare Toolchain + +``` +mkdir -p ~/.flagtree/enflame +cd ~/.flagtree/enflame +wget baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-llvm21-d752c5b-gcc9-x64_v0.3.0.tar.gz +tar -xzf enflame-llvm21-d752c5b-gcc9-x64_v0.3.0.tar.gz +``` + +### 2. Configure Build Environment + +```bash +export FLAGTREE_BACKEND=enflame +git config --global --add safe.directory ~/flagtree +``` + +### 3. Install Python Dependencies + +```bash +cd ~/flagtree/python +pip3 install -r requirements.txt +``` + +### 4. Build and Install Package + +```bash +cd ~/flagtree/python + +# Initial build +pip3 install . --no-build-isolation -v + +# Rebuild after code modification +pip3 install . --no-build-isolation --force-reinstall -v +``` + +## Test Validation + +```bash +# Run unit tests +cd ~/flagtree +pytest third_party/enflame/python/test/unit +``` diff --git a/third_party/enflame/README_cn.md b/third_party/enflame/README_cn.md new file mode 100644 index 000000000..b4ccc6ffd --- /dev/null +++ b/third_party/enflame/README_cn.md @@ -0,0 +1,138 @@ +# Flagtree 框架 - 燧原加速器支持 + +## 概述 + +Flagtree 是针对燧原加速器优化的高性能计算框架。本代码仓库提供核心组件后端绑定和测试套件,用于在燧原硬件平台上开发和部署应用程序。 + +## 前提条件 + +- 支持 Docker 的 Linux 主机系统 +- 燧原第三代加速卡(S60) +- 最小 16GB 内存(推荐 32GB) +- 100GB 可用磁盘空间 + +## 环境准备 + +### 1. 拉取源代码 + +```bash +# 拉取代码并切换到triton_v3.3.x分支 +cd ~ +git clone https://github.com/flagos-ai/flagtree.git +cd flagtree +git checkout triton_v3.3.x +``` + +### 2. 准备 Docker 镜像 + +```bash +# 加载预构建的容器镜像 +curl -sL https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-flagtree-0.3.1.tar.gz | docker load + +# 或手动下载后加载 +wget https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-flagtree-0.3.1.tar.gz +docker load -i enflame-flagtree-0.3.1.tar.gz +``` + +### 3. 启动Docker容器 + +```bash +# 如果需要重建容器,请先删除 +# docker rm -f enflame-flagtree + +# 假设 flagtree 源码位于 ~/flagtree +docker run -itd \ + --privileged \ + --name enflame-flagtree \ + -v ~/flagtree:/root/flagtree \ + enflame/flagtree:0.3.1 bash +``` + +### 4. 安装驱动 + +```bash +# 提取并安装燧原驱动程序 +docker cp enflame-flagtree:/enflame enflame + +sudo bash enflame/driver/enflame-x86_64-gcc-1.6.3.12-20251115104629.run +# 如果上面的命令提示你使用其它参数,请按照提示操作,比如 +# sudo bash enflame/driver/enflame-x86_64-gcc-1.6.3.12-20251115104629.run --virt-host + +efsmi +``` + +用 efsmi 检查驱动是否正常安装,正常输出示意: + +``` +------------------------------------------------------------------------------- +--------------------- Enflame System Management Interface --------------------- +--------- Enflame Tech, All Rights Reserved. 2024-2025 Copyright (C) ---------- +------------------------------------------------------------------------------- + ++2025-11-28, 10:50:14 CST-----------------------------------------------------+ +| EFSMI: 1.6.3.12 Driver Ver: 1.6.3.12 | ++-----------------------------+-------------------+---------------------------+ +| DEV NAME | FW VER | BUS-ID ECC | +| TEMP Lpm Pwr(Usage/Cap) | Mem GCU Virt | DUsed SN | +|=============================================================================| +| 0 Enflame S60G | 31.5.3 | 00:2e:00.0 Disable | +| 34℃ LP0 N/A | 23552MiB SRIOV | 0% A018K30520031 | ++-----------------------------+-------------------+---------------------------+ +| 1 Enflame S60G | 31.5.3 | 00:2f:00.0 Disable | +| 34℃ LP0 N/A | 23552MiB SRIOV | 0% A018K30520031 | ++-----------------------------+-------------------+---------------------------+ +``` + +### 5. 进入Docker容器 + +```bash +# 执行docker +docker exec -it enflame-flagtree bash +``` + +> 注意,后续所有命令都在容器内进行。 + +## 编译构建 + +### 1. 准备工具链 + +``` +mkdir -p ~/.flagtree/enflame +cd ~/.flagtree/enflame +wget baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-llvm21-d752c5b-gcc9-x64_v0.3.0.tar.gz +tar -xzf enflame-llvm21-d752c5b-gcc9-x64_v0.3.0.tar.gz +``` + +### 2. 配置构建环境 + +```bash +export FLAGTREE_BACKEND=enflame +git config --global --add safe.directory ~/flagtree +``` + +### 3. 安装 Python 依赖 + +```bash +cd ~/flagtree/python +pip3 install -r requirements.txt +``` + +### 4. 构建和安装包 + +```bash +cd ~/flagtree/python + +# 初始构建 +pip3 install . --no-build-isolation -v + +# 代码修改后重新构建 +pip3 install . --no-build-isolation --force-reinstall -v +``` + +## 测试验证 + +```bash +# 运行单元测试 +cd ~/flagtree +pytest third_party/enflame/python/test/unit +``` diff --git a/third_party/enflame/backend/.owners b/third_party/enflame/backend/.owners new file mode 100644 index 000000000..21b294f89 --- /dev/null +++ b/third_party/enflame/backend/.owners @@ -0,0 +1,3 @@ +xiaojuan.zhai +baoqi.liu +chongzhou.yang diff --git a/third_party/enflame/backend/__init__.py b/third_party/enflame/backend/__init__.py new file mode 100644 index 000000000..a06ea1450 --- /dev/null +++ b/third_party/enflame/backend/__init__.py @@ -0,0 +1,27 @@ +# +# Copyright 2024 Enflame. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +#import triton_gcu.triton.libdevice +try: + import torch_gcu +except ImportError: + pass + +# append gcu backend and driver +#from triton.backends import Backend, backends +#from triton_gcu.triton.compiler import _GCUBackend +#from triton_gcu.triton.driver import _GCUDriver +#backends.clear() +#backends["gcu"] = Backend(_GCUBackend, _GCUDriver) diff --git a/third_party/enflame/backend/autotuner.py b/third_party/enflame/backend/autotuner.py new file mode 100644 index 000000000..a2a517aee --- /dev/null +++ b/third_party/enflame/backend/autotuner.py @@ -0,0 +1,50 @@ +from triton.testing import do_bench, do_bench_cudagraph +from triton.runtime.autotuner import Autotuner +from triton.runtime.errors import OutOfResources + + +class TritonGCUAutotuner(Autotuner): + + def _bench(self, *args, config, **meta): + from triton.compiler.errors import CompileTimeAssertionFailure + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(args) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(args, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(args, exception=None) + + try: + if self.use_cuda_graph: + import torch + with torch.cuda.stream(torch.cuda.Stream()): + bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median") + return bench_res + return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure): + return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")] + + +Autotuner._bench = TritonGCUAutotuner._bench diff --git a/third_party/enflame/backend/backend.py b/third_party/enflame/backend/backend.py new file mode 100644 index 000000000..6ced9704d --- /dev/null +++ b/third_party/enflame/backend/backend.py @@ -0,0 +1,542 @@ +# +# Copyright 2024 Enflame. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import sysconfig +import functools +import tempfile +from pathlib import Path +import hashlib + +import setuptools +from setuptools import Extension + +import torch +from triton.backends.enflame.toolkit import * +from triton.backends.enflame.filecache import get_cache_manager + +import importlib.metadata +import site + +from typing import Dict +from types import ModuleType + + +@functools.lru_cache() +def _version_key(): + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + return '-'.join(contents) + + +def _make_so_cache_key(version_hash, signature, constants, **kwargs): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}-{constants}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" + key = hashlib.md5(key.encode("utf-8")).hexdigest() + return key + + +################################################################################# +# below for gcu + +# gcu kernel translation + + +def _get_topscc_root(): + return os.getenv("CAPS_PATH", "/opt/tops") + + +def _kernel_to_fatbin(kernel: str, arch: int, enable_transform: bool): + print(kernel) + with tempfile.TemporaryDirectory() as tmpdir: + bin = os.path.join(tmpdir, "kernel.fatbin") + toolkit.compile(kernel, "--device-only", f"--arch=gcu{arch}", f"--output={bin}", + "--enable-transform" if enable_transform else "") + with open(bin, "rb") as f: + return f.read() + + +def build_gcu_ext(name, src, srcdir, extra_objects=[], extra_libraries=[]): + + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + local_lib_path = os.path.join(datadir, 'lib') + + # fallback on setuptools + extra_compile_args = ['-w'] + library_dirs = [os.path.join(_get_topscc_root(), "lib"), local_lib_path] + include_dirs = [os.path.join(_get_topscc_root(), "include")] + link_args = [f"-Wl,-rpath={local_lib_path}"] + libraries = ["topsrt"] + extra_libraries + define_macros = [] + + # create extension module + ext = Extension(name, [src], extra_objects=extra_objects, extra_compile_args=extra_compile_args, + include_dirs=include_dirs, library_dirs=library_dirs, libraries=libraries, + define_macros=define_macros, extra_link_args=link_args) + + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + setuptools.setup(**args) + return so + + +# +# GCU +# +class GCUUtils(object): + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(GCUUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + utilsdir = os.path.join(datadir, "utils") + src = Path(os.path.join(utilsdir, "gcu.cpp")).read_text() + key = hashlib.md5(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + fname = "gcu_utils.so" + cache_path = cache.get_file(fname) + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.cpp") + with open(src_path, "w") as f: + f.write(src) + so = build_gcu_ext("gcu_utils", src_path, tmpdir) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), fname, binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location("gcu_utils", cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "u1": "uint32_t", + "i8": "int8_t", + "u8": "uint8_t", + "i16": "int16_t", + "u16": "uint16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "f16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + "index": "int64_t", + }[ty] + + +def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + if ty[0] in ("constexpr"): + return "PyObject*" + return ty_to_cpp(ty) + + +def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + +def generate_launcher(constants, signature, arch='gcu300', no_constant_args=False): + start_desc = len(signature) + #signature = generate_cu_signature(constants, signature, ids) + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr") + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiKKOOOO" + args_format + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + params = range(len(signature)) + params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"] + + nonconst_args_size = 0 + for i, ty in signature.items(): + if ty != "constexpr": + nonconst_args_size += 1 + + # generate glue code + launch_str = '' + if 'gcu400' == arch or 'gcu410' == arch: + launch_str += f"""topsLaunchConfig_t config; + memset(&config, 0x0, sizeof(config)); + config.gridDim = dim3(gridX, gridY, gridZ); + config.blockDim = dim3(1, 1, 1); + auto att = (struct topsLaunchAttribute *)malloc( + sizeof(struct topsLaunchAttribute)); + att->id = topsLaunchAttributeThreadDimension; + att->val.ThreadDim.x = num_warps; + att->val.ThreadDim.y = 1; + att->val.ThreadDim.z = 1; + config.attrs = att; + config.numAttrs = 1; + config.stream = stream; + TOPS_CHECK(topsModuleLaunchKernelEx(&config, function, params, NULL));""" + else: + launch_str += 'TOPS_CHECK(topsModuleLaunchKernel(function, gridX, gridY, gridZ, num_warps, 1, 1, shared_memory, stream, params, 0));' + src = f""" +#include +#include +#include +#include + +static inline void gcuAssert(topsError_t code, const char *file, int line) +{{ + if (code != TOPS_SUCCESS) + {{ + const char* prefix = "Kurama Error [TOPS]: "; + const char* str = topsGetErrorString(code); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + }} +}} + +#define TOPS_CHECK(ans) {{ gcuAssert((ans), __FILE__, __LINE__); }} + +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, topsStream_t stream, topsFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + void *params[] = {{ {', '.join(params)} }}; + if (gridX*gridY*gridZ > 0) {{ + //printf("xxx %d %d %d\\n", gridX, gridY, gridZ); + {launch_str} + }} +}} + +typedef struct _DevicePtrInfo {{ + topsDeviceptr_t dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (topsDeviceptr_t)PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = (topsDeviceptr_t)PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) + return ptr_info; + uint64_t dev_ptr; + int status = topsPointerGetAttribute(&dev_ptr, TOPS_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == TOPS_ERROR_INVALID_VALUE) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Kurama (cpu tensor?)", idx); + ptr_info.valid = false; + }} + // ptr_info.dev_ptr = dev_ptr; + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {args_list})) {{ + return NULL; + }} + + // extract kernel metadata + int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + return NULL; + }} + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items() if ty != "constexpr"])}; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (topsStream_t)_stream, (topsFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items() if ty != "constexpr") if nonconst_args_size > 0 else ''}); + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + if(PyErr_Occurred()) {{ + return NULL; + }} + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src + + +def compile_module_from_src(src, name): + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.cpp") + with open(src_path, "w") as f: + f.write(src) + so = build_gcu_ext(name, src_path, tmpdir) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +class GcuLauncher(object): + + def __init__(self, src, metadata): + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + src = generate_launcher(constants, signature, metadata.arch) + mod = compile_module_from_src(src, "__triton_launcher") + self.launch = mod.launch + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class GCUDriver(object): + + def __init__(self): + self.utils = GCUUtils() + self.get_current_stream = lambda idx: torch.gcu.current_stream(idx).gcu_stream + self.get_current_device = lambda: torch.device(f"{device_name}:{torch.gcu.current_device()}").index + self.launcher_cls = GcuLauncher + + def get_device_properties(self, device): + props = self.utils.get_device_properties(device) + props["version"] = int(props["arch_name"].split('-')[-1][3:]) + del props["arch_name"] + return props + + def get_stream(self, idx=None): + if idx is None: + idx = self.get_current_device() + try: + return torch.gcu.current_stream(idx).gcu_stream + except: + return 0 + + def get_arch(self): + device = self.get_current_device() + device_properties = self.utils.get_device_properties(device) + arch = device_properties['arch_name'] + return arch + + def get_warp_size(self): + device = self.get_current_device() + device_properties = self.utils.get_device_properties(device) + warp_size = device_properties['max_threads_per_block'] + return warp_size + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench + + +class GCUBackend(object): + + def __init__(self) -> None: + self.driver = GCUDriver() + + def load_dialects(self, ctx): + pass + + @functools.lru_cache() + def hash(self): + return f'{self.get_architecture_descriptor()}' + + def get_architecture_descriptor(self, **kwargs): + device = self.driver.get_current_device() + device_properties = self.driver.get_device_properties(device) + capability = { + "max_threads_per_block": device_properties["max_threads_per_block"], "multiprocessor_count": + device_properties["multiprocessor_count"], "version": device_properties["version"], "max_shared_mem": + device_properties["max_shared_mem"] + } + return capability + + def compile_kernel(self, name, kernel, enable_transform, signature, constants): + arch = self.get_architecture_descriptor() + kernel_key = f'{name}-{hashlib.md5(str(kernel).encode("utf-8")).hexdigest()}' + cache_key = f"kernel-{arch['version']}-{_version_key()}" + cache_manager = get_cache_manager(cache_key) + cache_path = cache_manager.get_file(kernel_key) + if cache_path is None: + bin = _kernel_to_fatbin(kernel, arch["version"], enable_transform) + return cache_manager.put(bin, kernel_key, binary=True) + else: + return cache_path + + def get_arch_name(self): + return f"gcu{self.get_architecture_descriptor()['version']}" + + def get_num_clusters(self): + return self.get_architecture_descriptor()['multiprocessor_count'] + + def get_num_processors(self): + return self.get_architecture_descriptor()['max_threads_per_block'] + + def compile(self, name, kernel, enable_transform=False, signature={}, constants=[]): + kernel_path = self.compile_kernel(name, kernel, enable_transform, signature, constants) + with open(kernel_path, "rb") as binary: + bin = binary.read() + m, func, _, _ = self.get_load_binary_fn()(name, bin, 0, self.get_current_device()) + assert func != 0, "cannot find kenrel function" + launcher_path = self.make_launcher_stub(name, signature, constants, True) + import importlib.util + spec = importlib.util.spec_from_file_location("__triton_launcher", launcher_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return Kernel(name, m, func, mod.launch, constants) + + def get_version_key(self): + return _version_key() + + def get_module_map(self) -> Dict[str, ModuleType]: + from triton.language.extra.gcu import libdevice + return {"triton.language.extra.libdevice": libdevice} + + +driver = GCUBackend() + + +class Kernel(object): + + def __init__(self, name, mod, func, launcher, constants): + self.name = name + self.mod = mod + self.func = func + self.launcher = launcher + self.constants = constants + + def __call__(self, *args, gridX=1, gridY=1, gridZ=1, blockX=1): + arch = driver.get_architecture_descriptor() + self.launcher(gridX, gridY, gridZ, blockX, + #0, 0, 0, 0, + arch["max_shared_mem"], driver.get_stream(), self.func, None, None, None, *args) + + def __getitem__(self, dims): + blockX, *gridXYZ = dims + gridX = gridXYZ[0] if len(gridXYZ) >= 1 else 1 + gridY = gridXYZ[1] if len(gridXYZ) >= 2 else 1 + gridZ = gridXYZ[2] if len(gridXYZ) >= 3 else 1 + + def launcher(*args): + self.__call__(*args, gridX=gridX, gridY=gridY, gridZ=gridZ, blockX=blockX) + + return launcher diff --git a/third_party/enflame/backend/compiler.py b/third_party/enflame/backend/compiler.py new file mode 100644 index 000000000..8e2405df0 --- /dev/null +++ b/third_party/enflame/backend/compiler.py @@ -0,0 +1,272 @@ +# +# Copyright 2024 Enflame. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import re +import os +import tempfile +from pathlib import Path +from triton.backends.compiler import BaseBackend, GPUTarget +from triton.backends.enflame.backend import GCUBackend +from triton.backends.enflame import toolkit +from triton.backends.enflame.toolkit import * +from dataclasses import dataclass +import functools +from typing import Any, Tuple +import hashlib +from triton._C.libtriton import ir, passes, llvm +from typing import Dict +from types import ModuleType + + +def _patch_kernel(kernel): + # add gpu module + kernel = re.sub('module ([^\n]+)\n', 'module \\1\ngpu.module @triton {\n', kernel) + pattern = r'#loc\d* = loc\(.*?\)\n' + loc_lines = re.findall(pattern, kernel) + kernel = re.sub(pattern, '', kernel) + kernel = ''.join(loc_lines) + kernel.replace(pattern, '') + kernel += '}\n' + return kernel + + +def make_ttir(mod, metadata, options): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + #passes.ttir.add_rewrite_tensor_pointer(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + return mod + + +def make_ttgir(mod, metadata, options): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttir.add_convert_to_ttgpuir(pm, f"gcu:{options.arch}", options.num_warps, options.warp_size, + options.num_ctas) + # passes.ttgpuir.add_coalesce(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_thread_locality(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) + pm.run(mod) + return mod + + +def make_gcuir(mod, metadata, options): + patched_mod = _patch_kernel(str(mod)) + metadata['name'] = re.search('tt.func public @(\\w+)\\(', patched_mod).group(1).strip() + passes = [] + if toolkit.get_bool_env("MLIR_ENABLE_DUMP"): + passes.append('-mlir-print-ir-after-all') + if toolkit.get_bool_env("MLIR_ENABLE_TIMING"): + passes.append('--mlir-timing') + passes.append('--mlir-timing-display=list') + if options.arch == "gcu300": + passes += [ + # '-mlir-disable-threading', + # '-mlir-print-ir-module-scope', + '-triton-gpu-to-triton-gcu', '-gcu64-type-verifier', '-convert-tensor-pointer', + '-triton-gcu-dot-layout-optimize', '-tritongpu-remove-layout-conversions', + '-convert-triton-load-store-to-gcu-dma', '-canonicalize', '-loop-invariant-code-motion', + '-gcu-triton-fusion', '-triton-gcu-data-layout-optimize', '-canonicalize', + '-triton-gcu-pingpong=' + 'num_stages=' + str(options.num_stages), '-flatten-triton-func', + '-convert-triton-to-gcu=' + 'vector-length=' + str(options.vector_length), '-cse', '-canonicalize' + ] + elif options.arch == "gcu400" or options.arch == "gcu410": + if toolkit.get_bool_env("ENABLE_I64_CHECK"): + passes.append('-gcu64-type-verifier') + LOAD_STORE_TO_DMA_PASS = '-convert-triton-load-store-to-gcu-dma' + if toolkit.get_bool_env("TRITON_GCU_ENABLE_STRIDE_BROADCAST"): + LOAD_STORE_TO_DMA_PASS += '=support_stride0=true' + passes += [ + # '-mlir-disable-threading', + # '-mlir-print-ir-module-scope', + '-triton-gpu-to-triton-gcu', '-convert-tensor-pointer', LOAD_STORE_TO_DMA_PASS, '-canonicalize', + '-loop-invariant-code-motion', '-gcu-combine-ops', '-gcu-triton-fusion=arch=' + options.arch, + '-canonicalize', '-flatten-triton-func', '-convert-triton-to-gcu', '-cse', '-canonicalize' + ] + return toolkit.triton_gcu_opt(patched_mod, *passes, arch=options.arch) + + +def make_llir(mod, metadata, options): + passes = [] + if toolkit.get_bool_env("MLIR_ENABLE_DUMP"): + passes.append('-mlir-print-ir-after-all') + if not toolkit.get_bool_env("TRITON_DISABLE_LINE_INFO", True): + passes.append('--ensure-debug-info-scope-on-llvm-func') + if toolkit.get_bool_env("MLIR_ENABLE_TIMING"): + passes.append('--mlir-timing') + passes.append('--mlir-timing-display=list') + passes += [ + '-insert-local-fence=arch=' + options.arch, '--convert-vector-to-scf=target-rank=1', '-lower-affine', + '-convert-vector-to-gcu=vector-bit-width=' + str(options.vector_length * 8), '-canonicalize', + '-convert-memref-to-gcu', '-kernel-memory-alloc=arch=' + options.arch + ' num-warps=' + str(options.num_warps), + '-loop-invariant-code-motion', '-convert-scf-to-cf', '-canonicalize', '-cse', '--symbol-dce', + '-gcu-remove-transform-ir', '-convert-vector-to-gcu=vector-bit-width=' + str(options.vector_length * 8), + '-canonicalize', + '--convert-gpu-to-gcu=chipset=' + options.arch + ' vector-bit-width=' + str(options.vector_length * 8), + '--gcu-attach-target=arch=' + options.arch, '-convert-index-to-llvm', '-gpu-to-llvm', '-convert-llvm-to-gcu', + '-alloca-to-entry', '-canonicalize' + ] + + ## Do nothing, until we figure out how to link .bc into triton_gcu. + #if options.extern_libs: + # paths = [path for (name, path) in options.extern_libs] + # llvm.link_extern_libs(llvm_mod, paths) + + return toolkit.gcu_compiler_opt(mod, *passes) + + +def make_fatbin(mod, metadata, options): + metadata['shared'] = int(re.search('gcu.shared_memory_size = (\\d+)', str(mod)).group(1).strip()) + with tempfile.TemporaryDirectory() as tmpdir: + bin = os.path.join(tmpdir, "kernel.fatbin") + toolkit.compile(mod, "--device-only", "--is-triton-backend", f"--arch={options.arch}", + f"--toolkit-path={datadir}", f"--output={bin}") + with open(bin, "rb") as f: + return f.read() + + +@functools.lru_cache(None) +def file_hash(path): + with open(path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +def min_dot_size(target: GPUTarget): + return lambda lhsType, rhsType: (1, 1, 1) + + +@dataclass() +class GCUOptions: + num_warps: int = 4 + warp_size: int = 1 + num_ctas: int = 1 + num_stages: int = 3 + arch: str = "gcu300" + vector_length: int = 512 + debug: bool = False + cluster_dims: tuple = (1, 1, 1) + allow_fp8e4nv: bool = False + allow_fp8e4b15: bool = False + supported_fp8_dtypes: Tuple[str] = () + deprecated_fp8_dtypes: Tuple[str] = () + default_dot_input_precision: str = "ieee" + allowed_dot_input_precisions: Tuple[str] = ("ieee", ) + backend_name: str = 'gcu' + max_num_imprecise_acc_default: int = 0 + enable_fp_fusion: bool = True + launch_cooperative_grid: bool = False + extern_libs: dict = None + sanitize_overflow: bool = False + num_buffers_warp_spec: int = 0 + num_consumer_groups: int = 0 + reg_dec_producer: int = 0 + reg_inc_consumer: int = 0 + arch: str = None + + def __post_init__(self): + architecture = GCUBackend().get_architecture_descriptor() + self.arch = "gcu" + str(architecture['version']) + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" + if self.arch == "gcu400" or self.arch == "gcu410": + assert self.num_warps <= 4, "num_warps must not exceed 4" + self.vector_length = 2048 + self.allow_fp8e4nv = True + self.allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + self.max_num_imprecise_acc_default = 2**30 + self.supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5") + self.deprecated_fp8_dtypes: Tuple[str] = () + self.sanitize_overflow: bool = True + + ## register the libdevice + default_libdir = Path(__file__).parent / 'lib' + extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) + if not extern_libs.get('libdevice', None): + extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH", str(default_libdir / 'libdevice.bc')) + object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) + + pass + + def hash(self): + ## Restore the code below, when we have libdevice.bc + #hash_dict = dict(self.__dict__) + #hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"])) + + key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class _GCUBackend(BaseBackend): + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + self._backend = GCUBackend() + self.binary_ext = "fatbin" + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == 'gcu' + + def parse_options(self, opts) -> Any: + args = {k: opts[k] for k in GCUOptions.__dataclass_fields__.keys() if k in opts} + + if "enable_fp_fusion" not in opts: + args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1" + + args.update({k: opts[k] for k in GCUOptions.__dataclass_fields__.keys() if k in opts}) + return GCUOptions(**args) + + def load_dialects(self, ctx): + self._backend.load_dialects(ctx) + + @functools.lru_cache() + def hash(self): + return self._backend.hash() + + def get_architecture_descriptor(self, **kwargs): + return self._backend.get_architecture_descriptor(**kwargs) + + def get_codegen_implementation(self, options): + codegen_fns = {"min_dot_size": min_dot_size(self.target)} + return codegen_fns + + def pack_metadata(self, metadata): + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + ) + + def add_stages(self, stages, options): + stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options) + stages["ttgir"] = lambda src, metadata: make_ttgir(src, metadata, options) + stages["gcuir"] = lambda src, metadata: make_gcuir(src, metadata, options) + stages["llir"] = lambda src, metadata: make_llir(src, metadata, options) + stages["fatbin"] = lambda src, metadata: make_fatbin(src, metadata, options) + + def get_module_map(self) -> Dict[str, ModuleType]: + return self._backend.get_module_map() diff --git a/third_party/enflame/backend/driver.py b/third_party/enflame/backend/driver.py new file mode 100644 index 000000000..78d64d94c --- /dev/null +++ b/third_party/enflame/backend/driver.py @@ -0,0 +1,72 @@ +# +# Copyright 2024 Enflame. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from triton.backends.compiler import GPUTarget +from triton.backends.driver import DriverBase +from triton.backends.enflame.backend import GCUBackend, GCUDriver + + +class _GCUDriver(DriverBase): + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(_GCUDriver, cls).__new__(cls) + return cls.instance + + def __init__(self): + self._driver = GCUDriver() + self.utils = self._driver.utils + self.backend = "gcu" + self.get_current_stream = self._driver.get_current_stream + self.get_current_device = self._driver.get_current_device + self.launcher_cls = self._driver.launcher_cls + + def get_active_torch_device(self): + import torch + return torch.device("gcu", self.get_current_device()) + + def get_device_properties(self, device): + return self._driver.get_device_properties(device) + + def get_stream(self, idx=None): + return self._driver.get_stream(id) + + def get_arch(self): + return self._driver.get_arch() + + def get_current_target(self): + arch = self._driver.get_arch() + warp_size = self._driver.get_warp_size() + return GPUTarget(self.backend, arch.split(':')[0], warp_size) + + @staticmethod + def is_active(): + return True + + def get_benchmarker(self): + return self._driver.get_benchmarker() + + def get_device_interface(self): + import torch + return torch.gcu + + def get_empty_cache_for_benchmark(self): + import torch + # It's the same as the Nvidia backend. + cache_size = 256 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='gcu') + + def clear_cache(self, cache): + cache.zero_() diff --git a/third_party/enflame/backend/filecache.py b/third_party/enflame/backend/filecache.py new file mode 100644 index 000000000..fc37c4e77 --- /dev/null +++ b/third_party/enflame/backend/filecache.py @@ -0,0 +1,170 @@ +# +# Copyright 2024 Enflame. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import json +import random +from pathlib import Path +import hashlib +from abc import ABC, abstractmethod +from typing import Dict, Optional + +############################################################# +# Cache manager is a modified version from triton + + +def default_cache_dir(): + return os.path.join(Path.home(), ".kurama", "cache") + + +def default_override_dir(): + return os.path.join(Path.home(), ".kurama", "override") + + +def default_dump_dir(): + return os.path.join(Path.home(), ".kurama", "dump") + + +class CacheManager(ABC): + + def __init__(self, key): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def has_file(self, filename) -> bool: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + + @abstractmethod + def put_group(self, filename: str, group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("KURAMA_CACHE_DIR", "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + def _make_path(self, filename) -> str: + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename) -> bool: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + return os.path.exists(self._make_path(filename)) + + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c, p in child_paths.items(): + if os.path.exists(p): + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + grp_contents = json.dumps({"child_paths": group}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = random.randint(0, 1000000) + # we use the PID in case a bunch of these around so we can see what PID made it + pid = os.getpid() + # use tempfile to be robust against program interruptions + temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}" + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + return filepath + + +__cache_cls = FileCacheManager +__cache_cls_nme = "DEFAULT" + + +def get_cache_manager(key) -> CacheManager: + import os + + user_cache_manager = os.environ.get("KURAMA_CACHE_MANAGER", None) + global __cache_cls + global __cache_cls_nme + + if user_cache_manager is not None and user_cache_manager != __cache_cls_nme: + import importlib + + module_path, clz_nme = user_cache_manager.split(":") + module = importlib.import_module(module_path) + __cache_cls = getattr(module, clz_nme) + __cache_cls_nme = user_cache_manager + + return __cache_cls(key) diff --git a/third_party/enflame/backend/toolkit.py b/third_party/enflame/backend/toolkit.py new file mode 100644 index 000000000..3a39558a4 --- /dev/null +++ b/third_party/enflame/backend/toolkit.py @@ -0,0 +1,86 @@ +# +# Copyright 2024 Enflame. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from pathlib import Path +import subprocess + +device_name = "gcu" +datadir = "/opt/triton_gcu" +if not os.path.exists(datadir): + raise Exception("Cannot find data directory in " + datadir) + +TOOLKIT_PATH = os.path.join(datadir, "bin") +RUNTIME_PATH = os.path.join(datadir, "lib") + +PY_TOOLS_PATH = Path(__file__).parent + + +# toolkit +def _run_command(cmd, content, *args): + if not isinstance(content, str): + content = str(content) + result = subprocess.run([os.path.join(TOOLKIT_PATH, cmd)] + list(args), input=content, capture_output=True, + text=True, encoding="utf-8") + if result.returncode != 0: + raise Exception(result.stderr) + # print(__file__, "run command: \n", [os.path.join(TOOLKIT_PATH, cmd)] + list(args)) + # print subprocess std::cerr << log to terminator + print(result.stderr) + return result.stdout + + +def _run_command2(cmd, content, *args): + if not isinstance(content, str): + content = str(content) + result = subprocess.run([PY_TOOLS_PATH / cmd] + list(args), input=content, capture_output=True, text=True, + encoding="utf-8") + if result.returncode != 0: + raise Exception(result.stderr) + # print(__file__, "run command: \n", [os.path.join(PY_TOOLS_PATH, cmd)] + list(args)) + # print(subprocess.stderr) + print(result.stderr) + return result.stdout + + +def triton_gcu_opt(content, *args, arch): + passes = ["-mlir-print-op-generic"] + list(args) + if arch == "gcu410": + arch = "gcu400" + return _run_command2(f"triton-{arch}-opt", content, *passes) + + +def gcu_compiler_opt(content, *args): + passes = ["-mlir-print-op-generic"] + list(args) + return _run_command("gcu-compiler-opt", content, *passes) + + +def compile(content, *args): + return _run_command("gcu-compiler-compile", content, *args) + + +# Return the boolean value of an environment variable. +# +# Helpful environment variables: +# +# - "MLIR_ENABLE_DUMP=1` dumps the IR before every MLIR pass Triton runs and +# the IR after every MLIR pass GCU runs. +def get_bool_env(env, defaultValue=False): + s = os.getenv(env, "").lower() + if (s == "1" or s == "true" or s == "on"): + return True + if (s == "0" or s == "false" or s == "off"): + return False + return defaultValue diff --git a/third_party/enflame/cmake/triton_gcu.cmake b/third_party/enflame/cmake/triton_gcu.cmake new file mode 100644 index 000000000..86aa743a5 --- /dev/null +++ b/third_party/enflame/cmake/triton_gcu.cmake @@ -0,0 +1,163 @@ +# For Triton + +# ###################################################### +# Get LLVM for triton +include(triton_gcu_llvm) +include(triton_gcu_llvm_config) + +# Disable warnings that show up in external code (gtest;pybind11) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Werror -Wno-unused-parameter -Wno-unused-but-set-parameter") +include_directories(SYSTEM ${MLIR_INCLUDE_DIRS}) +include_directories(SYSTEM ${LLVM_INCLUDE_DIRS}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/lib) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) # Tablegen'd files + +# 使用本地的 triton 文件,不需要下载 +set(third_party_triton_${arch}_fetch_src "${CMAKE_CURRENT_LIST_DIR}/../include/triton") +set(third_party_triton_${arch}_fetch_bin "${CMAKE_CURRENT_BINARY_DIR}/third_party_triton_${arch}_bin") +file(GLOB_RECURSE third_party_triton_${arch}_src "${CMAKE_CURRENT_LIST_DIR}/../include/triton/*") + +set(TRITON_SOURCE_DIR ${third_party_triton_${arch}_fetch_src}) +message(STATUS "TRITON_SOURCE_DIR: ${TRITON_SOURCE_DIR}") +set(TRITON_VERSION_FILE ${third_party_triton_${arch}_fetch_src}/python/triton/__init__.py) + +# 提取版本号 +execute_process( + COMMAND grep "__version__ = '" ${TRITON_VERSION_FILE} + COMMAND sed "s/.*__version__ = '\\([0-9]*\\.[0-9]*\\.[0-9]*\\)'.*/\\1/" + OUTPUT_VARIABLE TRITON_ORIG_VERSION_TEMP + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + RESULT_VARIABLE VERSION_EXTRACT_RESULT +) + +if(TRITON_ORIG_VERSION_TEMP AND VERSION_EXTRACT_RESULT EQUAL 0) + message(STATUS "Successfully extracted Triton version: ${TRITON_ORIG_VERSION_TEMP}") + set(TRITON_ORIG_VERSION ${TRITON_ORIG_VERSION_TEMP} CACHE STRING "Triton original version" FORCE) +else() + message(WARNING "Could not extract version from ${TRITON_VERSION_FILE}") + set(TRITON_ORIG_VERSION "unknown" CACHE STRING "Triton original version" FORCE) +endif() + +include(${CMAKE_CURRENT_LIST_DIR}/triton_${arch}.cmake) + +file(MAKE_DIRECTORY ${third_party_triton_${arch}_fetch_bin}) + +list(APPEND triton_cmake_args -DMLIR_DIR=${MLIR_DIR}) +list(APPEND triton_cmake_args -DLLVM_LIBRARY_DIR=${LLVM_LIBRARY_DIR}) +list(APPEND triton_cmake_args -DTRITON_BUILD_UT=OFF) +list(APPEND triton_cmake_args -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}) +list(APPEND triton_cmake_args -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}) +list(APPEND triton_cmake_args -DCMAKE_BUILD_TYPE=Release) + +add_custom_command( + OUTPUT ${triton_${arch}_objs} + COMMAND sed -i "/add_subdirectory\\(test\\)/d" ${third_party_triton_${arch}_fetch_src}/CMakeLists.txt + COMMAND sed -i "/add_subdirectory\\(bin\\)/d" ${third_party_triton_${arch}_fetch_src}/CMakeLists.txt + COMMAND cmake -S ${third_party_triton_${arch}_fetch_src} -B ${third_party_triton_${arch}_fetch_bin} ${triton_cmake_args} -DCMAKE_CXX_FLAGS="-Wno-reorder" -G Ninja + COMMAND cmake --build ${third_party_triton_${arch}_fetch_bin} --target all ${JOB_SETTING} + DEPENDS ${third_party_triton_${arch}_src} +) + +add_custom_target(third_party_triton_${arch}_fetch_build ALL DEPENDS ${triton_${arch}_objs}) + +add_library(triton_${arch} INTERFACE) +add_dependencies(triton_${arch} third_party_triton_${arch}_fetch_build) + +message(STATUS "third_party_triton_${arch}_fetch_bin is ${third_party_triton_${arch}_fetch_bin}") + + +include_directories(${third_party_triton_${arch}_fetch_src}/include) +include_directories(${third_party_triton_${arch}_fetch_bin}/include) # Tablegen'd files + +set(MLIR_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) + +#add_subdirectory(${third_party_triton_${arch}_fetch_src}/include ${third_party_triton_${arch}_fetch_bin}/include) +#add_subdirectory(${third_party_triton_${arch}_fetch_src}/third_party/f2reduce ${third_party_triton_${arch}_fetch_bin}/third_party/f2reduce) + +include_directories(${third_party_triton_${arch}_fetch_src}) +include_directories(${third_party_triton_${arch}_fetch_bin}/lib/Dialect/Triton/Transforms) # TritonCombine.inc +#add_subdirectory(${third_party_triton_${arch}_fetch_src}/lib ${third_party_triton_${arch}_fetch_bin}/lib) +# include_directories(${CMAKE_CURRENT_BINARY_DIR}/kernels) + +add_subdirectory(include) +add_subdirectory(lib) +add_subdirectory(test) + +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) + + +add_llvm_executable(triton-${arch}-opt triton-${arch}-opt.cpp PARTIAL_SOURCES_INTENDED) +set_target_properties(triton-${arch}-opt PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +llvm_update_compile_flags(triton-${arch}-opt) + +target_link_libraries(triton-${arch}-opt PRIVATE +GCUIR${arch} +MemrefExtIR${arch} +MathExtIR${arch} +TritonGCUIR_${arch} +TritonGCUTestAnalysis_${arch} +MLIRTritonToGCU_${arch} +MLIRTritonGCUTransforms_${arch} +${dialect_libs} +${conversion_libs} +${translation_libs} +${extension_libs} +# MLIR core +MLIROptLib +MLIRPass +MLIRTransforms +${triton_${arch}_objs} +) +add_dependencies(triton-${arch}-opt triton_${arch}) + +mlir_check_all_link_libraries(triton-${arch}-opt) + +# target_compile_options(obj.TritonGCUAnalysis_${arch} PUBLIC $<$:-Wno-sign-compare>) +# target_compile_options(obj.MLIRTritonToGCU_${arch} PUBLIC $<$:-Wno-sign-compare -Wno-unused-variable>) +# target_compile_options(obj.TritonGCUIR_${arch} PUBLIC $<$:-Wno-sign-compare -Wno-unused-variable>) +# target_compile_options(obj.MLIRMemRefToGCU PUBLIC $<$:-Wno-sign-compare -Wno-unused-variable>) +# target_compile_options(triton-${arch}-opt PUBLIC $<$:-Wno-sign-compare -Wno-unused-variable>) +### + +# target_compile_options(TritonGPUToLLVM PUBLIC $<$:-Wno-maybe-uninitialized -Wno-extra -Wno-unused-variable>) +# target_compile_options(TritonGPUTransforms PUBLIC $<$:-Wno-maybe-uninitialized -Wno-extra>) +# target_compile_options(TritonIR PUBLIC $<$:-Wno-sign-compare -Wno-unused-but-set-variable -Wno-unused-variable>) +# target_compile_options(TritonGPUIR PUBLIC $<$:-Wno-maybe-uninitialized -Wno-extra -Wno-reorder -Wno-parentheses>) +# target_compile_options(TritonTransforms PUBLIC $<$:-Wno-maybe-uninitialized -Wno-extra>) +# target_compile_options(TritonTools PUBLIC $<$:-Wno-sign-compare -Wno-unused-but-set-variable -Wno-unused-function -Wno-unused-variable -Wno-parentheses>) +# target_compile_options(TritonNvidiaGPUIR PUBLIC $<$:-Wno-comment -Wno-reorder>) +# target_compile_options(PrintLoadStoreMemSpaces PUBLIC $<$:-Wno-unused-variable>) +# target_compile_options(TritonLLVMIR PUBLIC $<$:-Wno-unused-variable -Wno-unused-but-set-variable>) +target_compile_options(obj.TritonGCUAnalysis_${arch} PUBLIC $<$:-Wno-sign-compare -Wno-deprecated-declarations -Wno-unused-variable -Wno-parentheses -Wno-comment -Wno-maybe-uninitialized>) +target_compile_options(obj.MLIRTritonToGCU_${arch} PUBLIC $<$:-Wno-sign-compare -Wno-unused-variable -Wno-deprecated-declarations -Wno-reorder -Wno-unused-but-set-variable -Wno-comment -Wno-maybe-uninitialized>) +target_compile_options(obj.TritonGCUIR_${arch} PUBLIC $<$:-Wno-sign-compare -Wno-unused-variable -Wno-comment -Wno-maybe-uninitialized>) + +target_compile_options(triton-${arch}-opt PUBLIC $<$:-Wno-sign-compare -Wno-unused-variable -Wno-reorder -Wno-maybe-uninitialized>) + +# target_compile_options(TritonIR PUBLIC $<$:-Wno-sign-compare -Wno-unused-but-set-variable -Wno-unused-variable>) +# target_compile_options(TritonGPUIR PUBLIC $<$:-Wno-extra -Wno-reorder -Wno-parentheses -Wno-unused-function>) +# target_compile_options(TritonAnalysis PUBLIC $<$:-Wno-reorder-ctor>) +target_compile_options(obj.TritonGCUAnalysis_${arch} PUBLIC $<$:-Wno-sign-compare -Wno-deprecated-declarations -Wno-unused-variable -Wno-parentheses -Wno-comment -Wno-maybe-uninitialized>) +# target_compile_options(TritonTransforms PUBLIC $<$:-Wno-deprecated-copy>) +# target_compile_options(TritonTools PUBLIC $<$:-Wno-sign-compare -Wno-unused-function -Wno-unused-variable -Wno-parentheses>) + +set(KURAMA_TOOLS_TARGET + triton-${arch}-opt +) + +add_custom_target(triton-${arch}-tools ALL DEPENDS + ${KURAMA_TOOLS_TARGET} +) + +# 将TRITON_ORIG_VERSION变量提升到所有上级作用域 +set(TRITON_ORIG_VERSION ${TRITON_ORIG_VERSION} PARENT_SCOPE) +set(TRITON_ORIG_VERSION ${TRITON_ORIG_VERSION} CACHE STRING "Triton original version" FORCE) +set_property(GLOBAL PROPERTY TRITON_ORIG_VERSION ${TRITON_ORIG_VERSION}) + +message(STATUS "TRITON_ORIG_VERSION (${TRITON_ORIG_VERSION}) is now available in all parent scopes") diff --git a/third_party/enflame/cmake/triton_gcu300.cmake b/third_party/enflame/cmake/triton_gcu300.cmake new file mode 100644 index 000000000..7fe591fd6 --- /dev/null +++ b/third_party/enflame/cmake/triton_gcu300.cmake @@ -0,0 +1,108 @@ + +set(triton_${arch}_objs +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/AllocateSharedMemory.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/AllocateWarpGroups.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/AssertOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/ControlFlowOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/ConvertLayoutOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/DecomposeUnsupportedConversions.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/ElementwiseOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/FuncOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/GatherOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/GlobalScratchMemoryAllocation.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/HistogramOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/MakeRangeOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/MemoryOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/PrintOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/ReduceOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/SPMDOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/ScanOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/TypeConverter.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/Utility.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/ViewOpToLLVM.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/DotOpToLLVM/FMA.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonGPUToLLVM/CMakeFiles/TritonGPUToLLVM.dir/DotOpToLLVM/FMADotUtility.cpp.o + +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonToTritonGPU/CMakeFiles/TritonToTritonGPU.dir/TritonGPUConversion.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Conversion/TritonToTritonGPU/CMakeFiles/TritonToTritonGPU.dir/TritonToTritonGPUPass.cpp.o + +${third_party_triton_${arch}_fetch_bin}/lib/Analysis/CMakeFiles/TritonAnalysis.dir/Allocation.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Analysis/CMakeFiles/TritonAnalysis.dir/AxisInfo.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Analysis/CMakeFiles/TritonAnalysis.dir/Alias.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Analysis/CMakeFiles/TritonAnalysis.dir/Membar.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Analysis/CMakeFiles/TritonAnalysis.dir/Utility.cpp.o + +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/AccelerateMatmul.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Coalesce.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/CoalesceAsyncCopy.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/CombineTensorSelectAndIf.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/DecomposeScaledBlocked.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/F32DotTC.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/FuseNestedLoops.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/OptimizeAccumulatorInit.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/OptimizeDotOperands.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/OptimizeThreadLocality.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/PingPong.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Prefetch.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/ReduceDataDuplication.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/RemoveLayoutConversions.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/ReorderInstructions.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/TaskIdPropagate.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Utility.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/WSCanonicalization.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/WSCodePartition.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/WSDataPartition.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/WSLowering.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/WSTaskPartition.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/AssignLatencies.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/LowerLoops.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/ModifiedAccMMAPipeline.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/PipelineExpander.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/PipeliningUtility.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/Schedule.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/ScheduleLoops.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/SoftwarePipeliner.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/TC05MMAPipeline.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/TMAStoresPipeline.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/TestPipelineAssignLatencies.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/TestPipelineLowerLoop.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/TestPipelineScheduleLoop.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/WGMMAPipeline.cpp.o + +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/IR/CMakeFiles/TritonGPUIR.dir/Ops.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/IR/CMakeFiles/TritonGPUIR.dir/Dialect.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/IR/CMakeFiles/TritonGPUIR.dir/LinearLayoutConversions.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonGPU/IR/CMakeFiles/TritonGPUIR.dir/Types.cpp.o + +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeFiles/TritonNvidiaGPUTransforms.dir/FenceInsertion.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeFiles/TritonNvidiaGPUTransforms.dir/KeepAccInTMem.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeFiles/TritonNvidiaGPUTransforms.dir/MMALowering.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeFiles/TritonNvidiaGPUTransforms.dir/PlanCTA.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeFiles/TritonNvidiaGPUTransforms.dir/PromoteLHSToTMem.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeFiles/TritonNvidiaGPUTransforms.dir/TMALowering.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeFiles/TritonNvidiaGPUTransforms.dir/TensorMemoryAllocation.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeFiles/TritonNvidiaGPUTransforms.dir/Utility.cpp.o + +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonNvidiaGPU/IR/CMakeFiles/TritonNvidiaGPUIR.dir/Ops.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonNvidiaGPU/IR/CMakeFiles/TritonNvidiaGPUIR.dir/Dialect.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/TritonNvidiaGPU/IR/CMakeFiles/TritonNvidiaGPUIR.dir/Types.cpp.o + +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/Triton/Transforms/CMakeFiles/TritonTransforms.dir/LoopUnroll.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/Triton/Transforms/CMakeFiles/TritonTransforms.dir/ReorderBroadcast.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/Triton/Transforms/CMakeFiles/TritonTransforms.dir/RewriteTensorPointer.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/Triton/Transforms/CMakeFiles/TritonTransforms.dir/Combine.cpp.o + +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/Triton/IR/CMakeFiles/TritonIR.dir/Ops.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/Triton/IR/CMakeFiles/TritonIR.dir/Dialect.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/Triton/IR/CMakeFiles/TritonIR.dir/Types.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/Triton/IR/CMakeFiles/TritonIR.dir/Traits.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Dialect/Triton/IR/CMakeFiles/TritonIR.dir/OpInterfaces.cpp.o + +${third_party_triton_${arch}_fetch_bin}/lib/Target/LLVMIR/CMakeFiles/TritonLLVMIR.dir/LLVMDIScope.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Target/LLVMIR/CMakeFiles/TritonLLVMIR.dir/LLVMIRBreakPhiStruct.cpp.o + +${third_party_triton_${arch}_fetch_bin}/lib/Tools/CMakeFiles/TritonTools.dir/LinearLayout.cpp.o +${third_party_triton_${arch}_fetch_bin}/lib/Tools/CMakeFiles/TritonTools.dir/LayoutUtils.cpp.o + +${third_party_triton_${arch}_fetch_bin}/third_party/f2reduce/CMakeFiles/f2reduce.dir/f2reduce.cpp.o +) diff --git a/third_party/enflame/cmake/triton_gcu_llvm.cmake b/third_party/enflame/cmake/triton_gcu_llvm.cmake new file mode 100644 index 000000000..c2ada4a9a --- /dev/null +++ b/third_party/enflame/cmake/triton_gcu_llvm.cmake @@ -0,0 +1,10 @@ +# 使用预编译的 LLVM +# 检查环境变量中指定的 LLVM 路径 +if(DEFINED ENV{KURAMA_LLVM_DIR_${ARCH}}) + message(STATUS ": using user provide llvm path $ENV{KURAMA_LLVM_DIR_${ARCH}}") + set(KURAMA_LLVM_DIR_${ARCH} "$ENV{KURAMA_LLVM_DIR_${ARCH}}") +elseif(KURAMA_LLVM_DIR_${ARCH} AND EXISTS ${KURAMA_LLVM_DIR_${ARCH}}/lib/cmake) + message(STATUS ": using previous exists llvm") +else() + message(FATAL_ERROR "KURAMA_LLVM_DIR_${ARCH} environment variable is not set or LLVM not found at specified path") +endif() diff --git a/third_party/enflame/cmake/triton_gcu_llvm_config.cmake b/third_party/enflame/cmake/triton_gcu_llvm_config.cmake new file mode 100644 index 000000000..bbee33d51 --- /dev/null +++ b/third_party/enflame/cmake/triton_gcu_llvm_config.cmake @@ -0,0 +1,47 @@ +message(STATUS ": set llvm path ${KURAMA_LLVM_DIR_${ARCH}}") +message(STATUS ": Use the LLVM version with revision to re-config for the kurama") + +set(LLVM_ROOT_DIR ${KURAMA_LLVM_DIR_${ARCH}}) + +if(DEFINED BUILD_CAPS_PATH) + set(VAR_BUILD_CAPS_PATH "${BUILD_CAPS_PATH}" CACHE PATH "Fallback path to search for BUILD CAPS installs") +elseif(DEFINED ENV{BUILD_CAPS_PATH}) + set(VAR_BUILD_CAPS_PATH "$ENV{BUILD_CAPS_PATH}" CACHE PATH "Fallback path to search for BUILD CAPS installs") +elseif(EXISTS ${CMAKE_BINARY_DIR}/opt/tops) + set(VAR_BUILD_CAPS_PATH "${CMAKE_BINARY_DIR}/opt/tops" CACHE PATH "search build dir for BUILD CAPS installs") +else() + set(VAR_BUILD_CAPS_PATH "/opt/tops" CACHE PATH "Fallback path to search for BUILD CAPS installs") +endif() + +message(STATUS ": MLIR Default BUILD CAPS toolkit path: ${VAR_BUILD_CAPS_PATH}") + +if(DEFINED EXECUTION_CAPS_PATH) + set(VAR_EXECUTION_CAPS_PATH "${EXECUTION_CAPS_PATH}" CACHE PATH "Fallback path to search for TEST CAPS installs") +elseif(DEFINED ENV{EXECUTION_CAPS_PATH}) + set(VAR_EXECUTION_CAPS_PATH "$ENV{EXECUTION_CAPS_PATH}" CACHE PATH "Fallback path to search for TEST CAPS installs") +else() + set(VAR_EXECUTION_CAPS_PATH "/opt/tops" CACHE PATH "Fallback path to search for TEST CAPS installs") +endif() + +message(STATUS ": MLIR Default TEST CAPS toolkit path: ${VAR_EXECUTION_CAPS_PATH}") + +set(LLVM_EXTERNAL_LIT ${VAR_BUILD_CAPS_PATH}/bin/llvm-lit) + +set(LLVM_LIBRARY_DIR ${LLVM_ROOT_DIR}/lib) + +# LLVM +set(LLVM_DIR ${LLVM_LIBRARY_DIR}/cmake/llvm) +message(STATUS ": llvm found in ${LLVM_DIR}") +find_package(LLVM REQUIRED HINTS ${LLVM_DIR}) + +# MLIR +set(MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir) +message(STATUS ": mlir found in ${MLIR_DIR}") +find_package(MLIR REQUIRED CONFIG HINTS ${MLIR_DIR}) + +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") + +include(${LLVM_CMAKE_DIR}/TableGen.cmake) # required by AddMLIR +include(${LLVM_CMAKE_DIR}/AddLLVM.cmake) +include(${MLIR_CMAKE_DIR}/AddMLIR.cmake) diff --git a/third_party/enflame/include/CMakeLists.txt b/third_party/enflame/include/CMakeLists.txt new file mode 100644 index 000000000..bd4e0c996 --- /dev/null +++ b/third_party/enflame/include/CMakeLists.txt @@ -0,0 +1 @@ +configure_file(version.h.in version.h) diff --git a/third_party/enflame/include/triton/CMakeLists.txt b/third_party/enflame/include/triton/CMakeLists.txt new file mode 100644 index 000000000..aa611df7a --- /dev/null +++ b/third_party/enflame/include/triton/CMakeLists.txt @@ -0,0 +1,335 @@ +cmake_minimum_required(VERSION 3.18) + +if(POLICY CMP0116) +# Introduced in cmake 3.20 +# https://cmake.org/cmake/help/latest/policy/CMP0116.html + cmake_policy(SET CMP0116 OLD) +endif() + +include(ExternalProject) + +set(CMAKE_CXX_STANDARD 17) + +set(CMAKE_INCLUDE_CURRENT_DIR ON) + +project(triton CXX C) +include(CTest) + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + +# Options +option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON) +option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) +option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON) +option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON) +option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON) +set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") + +if(TRITON_BUILD_WITH_CCACHE) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" + CACHE STRING "C compiler launcher") + set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" + CACHE STRING "CXX compiler launcher") + else() + message( + STATUS + "Could not find ccache. Consider installing ccache to speed up compilation." + ) + endif() +endif() + +set(TRITON_PARALLEL_LINK_JOBS "" CACHE STRING + "Define the maximum number of concurrent link jobs (Ninja only).") +if (TRITON_PARALLEL_LINK_JOBS) + set_property(GLOBAL APPEND PROPERTY JOB_POOLS link_job_pool=${TRITON_PARALLEL_LINK_JOBS}) + set(CMAKE_JOB_POOL_LINK link_job_pool) +endif() + + +# Ensure Python3 vars are set correctly +# used conditionally in this file and by lit tests + +# Customized release build type with assertions: TritonRelBuildWithAsserts +if(NOT MSVC) + set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") + set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") + set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1") + set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1") +else() + set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /RTC1 /bigobj /Zc:preprocessor /permissive-") + set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /RTC1 /bigobj /Zc:preprocessor /permissive-") + set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") +endif() + +# Default build type +if(NOT CMAKE_BUILD_TYPE) + message(STATUS "Default build type: Release") + set(CMAKE_BUILD_TYPE "Release") +endif() + +if(NOT WIN32) + find_library(TERMINFO_LIBRARY tinfo) +endif() + +if(TRITON_BUILD_UT) + # This is an aggregate target for all unit tests. + add_custom_target(TritonUnitTests) + set_target_properties(TritonUnitTests PROPERTIES FOLDER "Triton/Tests") + include(AddTritonUnitTest) +endif() + +# Compiler flags +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +if(NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS") +endif() + + +# ######### +# LLVM +# ######### +if(NOT MLIR_DIR) + set(MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir) +endif() + +# MLIR +find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR}) + +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") + +include(TableGen) # required by AddMLIR +include(AddLLVM) +include(AddMLIR) + +# Utilities +function(add_triton_object name) + cmake_parse_arguments(ARG "" "" "DEPENDS;LINK_LIBS" ${ARGN}) + add_library(${name} OBJECT) + target_sources(${name} + PRIVATE ${ARG_UNPARSED_ARGUMENTS} + INTERFACE $ + ) + + + # add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS}) + if(ARG_DEPENDS) + add_dependencies(${name} ${ARG_DEPENDS}) + endif() + if(ARG_LINK_LIBS) + target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS}) + endif() +endfunction(add_triton_object) + +set_property(GLOBAL PROPERTY TRITON_LIBS "") +function(add_triton_library name) + set_property(GLOBAL APPEND PROPERTY TRITON_LIBS ${name}) + add_triton_object(${name} ${ARGN}) + llvm_update_compile_flags(${name}) +endfunction() + +set_property(GLOBAL PROPERTY TRITON_PLUGINS "") +function(add_triton_plugin name) + set_property(GLOBAL APPEND PROPERTY TRITON_PLUGINS ${name}) + add_triton_object(${name} ${ARGN}) +endfunction() + + +# Disable warnings that show up in external code (gtest;pybind11) +if(NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fvisibility=hidden") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4244 /wd4624 /wd4715 /wd4530") +endif() + +include_directories(".") +include_directories(${MLIR_INCLUDE_DIRS}) +include_directories(${LLVM_INCLUDE_DIRS}) +include_directories(${PROJECT_SOURCE_DIR}/include) +include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files +include_directories(${PROJECT_SOURCE_DIR}/third_party) +include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files + +# link_directories(${LLVM_LIBRARY_DIR}) +add_subdirectory(include) +add_subdirectory(lib) + +# TODO: Figure out which target is sufficient to fix errors; triton is +# apparently not enough. Currently set linking libstdc++fs for all targets +# to support some old version GCC compilers like 8.3.0. +if (NOT WIN32 AND NOT APPLE) + link_libraries(stdc++fs) +endif() + + +# ----- + +# ------ +if(TRITON_BUILD_PYTHON_MODULE) + message(STATUS "Adding Python module") + set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) + include_directories(${PYTHON_SRC_PATH}) + + # Python Interpreter is used to run lit tests + find_package(Python3 REQUIRED COMPONENTS Development.Module Interpreter) + find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") + + if (DEFINED TRITON_PLUGIN_DIRS) + foreach(PLUGIN_DIR ${TRITON_PLUGIN_DIRS}) + # Read the plugin name under dir/backend/name.conf + cmake_path(APPEND PLUGIN_DIR "backend" "name.conf" OUTPUT_VARIABLE PLUGIN_NAME_PATH) + file(READ ${PLUGIN_NAME_PATH} PLUGIN_NAME) + string(STRIP ${PLUGIN_NAME} PLUGIN_NAME) + + list(APPEND TRITON_PLUGIN_NAMES ${PLUGIN_NAME}) + + # Include the plugin as part of the build, placing the build output under + # ${TRITON_BINARY_DIR}/third_party/${PLUGIN_NAME} + cmake_path(APPEND TRITON_BINARY_DIR "third_party" ${PLUGIN_NAME} OUTPUT_VARIABLE PLUGIN_DIR_BUILD_OUTPUT) + message(STATUS "Building plugin '${PLUGIN_NAME}' from ${PLUGIN_DIR} with output ${PLUGIN_DIR_BUILD_OUTPUT}") + add_subdirectory(${PLUGIN_DIR} ${PLUGIN_DIR_BUILD_OUTPUT}) + endforeach() + endif() + + foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) + add_subdirectory(third_party/${CODEGEN_BACKEND}) + endforeach() + + if (TRITON_BUILD_PROTON) + add_subdirectory(third_party/proton) + endif() + # We always build proton dialect + list(APPEND TRITON_PLUGIN_NAMES "proton") + add_subdirectory(third_party/proton/dialect) + + get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) + set(TRITON_LIBRARIES + ${triton_libs} + ${triton_plugins} + + # mlir + MLIRAMDGPUDialect + MLIRNVVMDialect + MLIRNVVMToLLVMIRTranslation + MLIRGPUToNVVMTransforms + MLIRGPUToGPURuntimeTransforms + MLIRGPUTransforms + MLIRIR + MLIRControlFlowToLLVM + MLIRBytecodeWriter + MLIRPass + MLIRTransforms + MLIRLLVMDialect + MLIRSupport + MLIRTargetLLVMIRExport + MLIRMathToLLVM + MLIRROCDLToLLVMIRTranslation + MLIRGPUDialect + MLIRSCFToControlFlow + MLIRIndexToLLVM + MLIRGPUToROCDLTransforms + MLIRUBToLLVM + + # LLVM + LLVMPasses + LLVMNVPTXCodeGen + # LLVMNVPTXAsmPrinter + LLVMAMDGPUCodeGen + LLVMAMDGPUAsmParser + + Python3::Module + pybind11::headers + + ) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR # Linux arm64 + CMAKE_SYSTEM_PROCESSOR MATCHES "arm64" OR # macOS arm64 + CMAKE_OSX_ARCHITECTURES MATCHES "arm64") # also macOS arm64 + list(APPEND TRITON_LIBRARIES + LLVMAArch64CodeGen + LLVMAArch64AsmParser + ) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64") + list(APPEND TRITON_LIBRARIES + LLVMX86CodeGen + LLVMX86AsmParser + ) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "ppc64le") + list(APPEND TRITON_LIBRARIES + LLVMPowerPCAsmParser + LLVMPowerPCCodeGen + ) + else() + message(FATAL_ERROR "LLVM codegen/ASM parser libs: This HW architecture (${CMAKE_SYSTEM_PROCESSOR}) is not configured in cmake lib dependencies.") + endif() + + # Define triton library + string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_CODEGEN_BACKENDS}) + + if (DEFINED TRITON_PLUGIN_NAMES) + string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_BACKENDS_TUPLE} ${TRITON_PLUGIN_NAMES}) + endif() + + message(STATUS "Triton backends tuple: ${TRITON_BACKENDS_TUPLE}") + + set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})") + add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE}) + add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc + ${PYTHON_SRC_PATH}/ir.cc + ${PYTHON_SRC_PATH}/passes.cc + ${PYTHON_SRC_PATH}/interpreter.cc + ${PYTHON_SRC_PATH}/llvm.cc) + + # Link triton with its dependencies + target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES}) + if(WIN32) + target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS}) + set_target_properties(triton PROPERTIES SUFFIX ".pyd") + set_target_properties(triton PROPERTIES PREFIX "lib") + else() + target_link_libraries(triton PRIVATE z) + endif() + target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) +endif() + +if (UNIX AND NOT APPLE) + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL") +endif() + +if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) + set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") + + # Check if the platform is MacOS + if(APPLE) + set(PYTHON_LDFLAGS "-undefined dynamic_lookup") + endif() + + target_link_options(triton PRIVATE ${PYTHON_LDFLAGS}) +endif() + +if(NOT TRITON_BUILD_PYTHON_MODULE) + foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) + add_subdirectory(third_party/${CODEGEN_BACKEND}) + endforeach() + add_subdirectory(third_party/proton/dialect) +endif() + +find_package(Threads REQUIRED) + +add_subdirectory(third_party/f2reduce) + +if(TRITON_BUILD_UT) + add_subdirectory(unittest) + # This target runs all the unit tests. + add_custom_target(check-triton-unit-tests + COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure + DEPENDS TritonUnitTests + USES_TERMINAL + ) +endif() diff --git a/third_party/enflame/include/triton/CONTRIBUTING.md b/third_party/enflame/include/triton/CONTRIBUTING.md new file mode 100644 index 000000000..2122eaf6c --- /dev/null +++ b/third_party/enflame/include/triton/CONTRIBUTING.md @@ -0,0 +1,70 @@ +# Governance Structure + +Triton adopts the following hierarchical technical governance structure: +* A community of **contributors** who file issues and submit pull requests +* A group of **module maintainers** who own parts of Triton and drive their development +* A body of **core maintainers** who own Triton overall and drive its development +* A **lead core maintainer** who is the catch-all decision maker when consensus cannot be reached by core maintainers + +All contributions are expected to follow Triton’s design principles, as enforced by module and core maintainers. While high-quality pull requests are appreciated and encouraged, all maintainers reserve the right to prioritize their own work over code reviews at-will, hence contributors should not expect their work to be reviewed promptly. + +Contributors can maximize the chances of their work being accepted by maintainers by meeting a high quality bar before sending a PR to maintainers. We encourage maintainers who contribute to Triton on behalf of a company to get reviews from senior developers within their company before sending to maintainers. +Module maintainers +We aim to make the Triton codebase as modular as possible, such that different components (e.g., subdirectories) can be improved in parallel under the supervision of different module maintainers. + +What constitutes (or not) a module is up to the core maintainers. Core maintainers also reserve the right to decide whether the development of a module should happen – or keep happening – in-tree or not. + +**List of in-tree modules (as of 05/12/2024, alphabetical order):** +* AMD backend (Lei Zhang) +* Interpreter (Keren Zhou) +* Profiler (Keren Zhou) + +Note: Parts of Triton that are not listed above (e.g., Nvidia backend) are assumed to be owned by core maintainers. + +Note: Some important parts of the Triton eco-system (e.g., Intel XPU backend) may be maintained out-of-tree and advertised in our repository. The governance rules described in this document do not carry over to these modules. + +__List of out-of-tree modules (as of 05/12/2024, alphabetical order):__ +* CPU backend (Bert Maher, Ilya Enkovich) +* Intel backend (Ettore Tiotto, Whitney Tsang) + + +## Core maintainers +The core maintainers drive the development of Triton at large and set the roadmap for the project. As such, they have the following responsibilities: +* Proposing, implementing and reviewing profound changes to user-facing APIs, IR specifications and/or pass infrastructures +* Enforcing code quality standards and adherence to core design principles +* Drawing module boundaries and resolving disputes between module maintainers + + +The core maintainers as a group have the power to veto any decision made at a Module maintainer level. + +The core maintainers should publicly articulate their decision-making, and share the reasoning behind their decisions, vetoes, and dispute resolution. + +__List of core maintainers (as of 01/30/2025, alphabetical order):__ +* Jeff Niu +* Keren Zhou +* Mario Lezcano-Casado +* Pawel Szczerbuk +* Peter Bell +* Phil Tillet +* Thomas Raoux +* Zahi Moudallal + +## Lead core maintainer +When core maintainers cannot come to a consensus, a publicly declared lead maintainer is expected to settle the debate and make executive decisions. + +The Lead Core Maintainer should publicly articulate their decision-making, and give a clear reasoning for their decisions. + +The Lead Core Maintainer is also responsible for confirming or removing core maintainers. + +**Lead maintainer (as of 05/12/2024)** +* Phil Tillet + +# Decision Making + +## Uncontroversial Changes + +We are committed to accepting functional bug fixes that meet our quality standards – and include minimized unit tests to avoid future regressions. Performance improvements generally fall under the same category, with the caveat that they may be rejected if the trade-off between usefulness and complexity is deemed unfavorable by core maintainers (e.g., complex swizzling logic to improve the performance of non-tensor-cores matrix multiplications). Design changes that neither fix known functional nor performance issues are automatically considered controversial. + +## Controversial Changes + +More controversial design changes (e.g., changes in our IRs/APIs/Passes) are evaluated on a case-by-case basis under the subjective judgment of core maintainers. While it is possible for contributors to propose and land deep design changes upstream (see https://github.com/triton-lang/triton/pull/1305), the community should expect such occurrences to be relatively rare. diff --git a/third_party/enflame/include/triton/LICENSE b/third_party/enflame/include/triton/LICENSE new file mode 100644 index 000000000..1d0238e86 --- /dev/null +++ b/third_party/enflame/include/triton/LICENSE @@ -0,0 +1,23 @@ +/* +* Copyright 2018-2020 Philippe Tillet +* Copyright 2020-2022 OpenAI +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files +* (the "Software"), to deal in the Software without restriction, +* including without limitation the rights to use, copy, modify, merge, +* publish, distribute, sublicense, and/or sell copies of the Software, +* and to permit persons to whom the Software is furnished to do so, +* subject to the following conditions: +* +* The above copyright notice and this permission notice shall be +* included in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ diff --git a/third_party/enflame/include/triton/Makefile b/third_party/enflame/include/triton/Makefile new file mode 100644 index 000000000..f984decf3 --- /dev/null +++ b/third_party/enflame/include/triton/Makefile @@ -0,0 +1,97 @@ +# This is not the build system, just a helper to run common development commands. +# Make sure to first initialize the build system with: +# make dev-install + +PYTHON ?= python +BUILD_DIR := $(shell cd python; $(PYTHON) -c 'from build_helpers import get_cmake_dir; print(get_cmake_dir())') +TRITON_OPT := $(BUILD_DIR)/bin/triton-opt +PYTEST := $(PYTHON) -m pytest + +# Incremental builds + +.PHONY: all +all: + ninja -C $(BUILD_DIR) + +.PHONY: triton-opt +triton-opt: + ninja -C $(BUILD_DIR) triton-opt + +# Testing + +.PHONY: test-lit +test-lit: + ninja -C $(BUILD_DIR) check-triton-lit-tests + +.PHONY: test-cpp +test-cpp: + ninja -C $(BUILD_DIR) check-triton-unit-tests + +.PHONY: test-python +test-unit: all + cd python/test/unit && $(PYTEST) -s -n 8 --ignore=cuda/test_flashattention.py \ + --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py + $(PYTEST) -s -n 8 python/test/unit/language/test_subprocess.py + $(PYTEST) -s -n 8 python/test/unit/test_debug.py --forked + TRITON_DISABLE_LINE_INFO=0 $(PYTEST) -s python/test/unit/language/test_line_info.py + # Run cuda/test_flashattention.py separately to avoid out of gpu memory + $(PYTEST) -s python/test/unit/cuda/test_flashattention.py + TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \ + $(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py + +.PHONY: test-regression +test-regression: all + $(PYTEST) -s -n 8 python/test/regression + +.PHONY: test-interpret +test-interpret: all + cd python/test/unit && TRITON_INTERPRET=1 $(PYTEST) -s -n 16 -m interpreter cuda language/test_core.py language/test_standard.py \ + language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \ + language/test_tuple.py runtime/test_autotuner.py::test_kwargs[False] \ + ../../tutorials/06-fused-attention.py::test_op --device=cpu + +.PHONY: test-proton +test-proton: all + $(PYTEST) -s -n 8 third_party/proton/test + +.PHONY: test-python +test-python: test-unit test-regression test-interpret test-proton + +.PHONY: test-nogpu +test-nogpu: test-lit test-cpp + +.PHONY: test +test: test-lit test-cpp test-python + +# pip install-ing + +.PHONY: dev-install-requires +dev-install-requires: + $(PYTHON) -m pip install -r python/requirements.txt + $(PYTHON) -m pip install -r python/test-requirements.txt + + +.PHONY: dev-install-torch +dev-install-torch: + # install torch but ensure pytorch-triton isn't installed + $(PYTHON) -m pip install torch + $(PYTHON) -m pip uninstall triton pytorch-triton -y + +.PHONY: dev-install-triton +dev-install-triton: + $(PYTHON) -m pip install -e python --no-build-isolation -v + +.PHONY: dev-install +.NOPARALLEL: dev-install +dev-install: dev-install-requires dev-install-triton + +# Updating lit tests + +.PHONY: golden-samples +golden-samples: triton-opt + $(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-pipeline -canonicalize | \ + $(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \ + -o test/TritonGPU/samples/simulated-grouped-gemm.mlir + $(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-pipeline -canonicalize | \ + $(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \ + -o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir diff --git a/third_party/enflame/include/triton/README.md b/third_party/enflame/include/triton/README.md new file mode 100644 index 000000000..279ae7dcd --- /dev/null +++ b/third_party/enflame/include/triton/README.md @@ -0,0 +1,462 @@ +
+ Triton logo +
+ +| **`Documentation`** | **`Nightly Wheels`** | +|-------------------- | -------------------- | +| [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) | + +# Triton + +This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs. + +The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing this work if you use Triton! + +The [official documentation](https://triton-lang.org) contains installation instructions and tutorials. See also these third-party [Triton puzzles](https://github.com/srush/Triton-Puzzles), which can all be run using the Triton interpreter -- no GPU required. + +# Quick Installation + +You can install the latest stable release of Triton from pip: + +```shell +pip install triton +``` + +Binary wheels are available for CPython 3.9-3.13. + +# Enabling Blackwell Support + +The main branch now features support for NVIDIA Blackwell GPUs using 5th +generation tensor cores. To enable this, you will need two additional steps: + +1. Build a pre-release PyTorch from source with CUDA 12.8 +2. Build triton from the latest source + + +First, to build pytorch you need to have CUDA 12.8 installed locally. If not, +follow the [instructions for your platform](https://developer.nvidia.com/cuda-downloads) +```bash +# Clone and checkout pytorch 2.6 release candidate +git clone https://github.com/pytorch/pytorch +cd pytorch +git checkout v2.6.0-rc9 +git submodule sync +git submodule update --init --recursive -j 8 + +# Install build dependencies (assumes you already have a system compiler) +pip install -r requirements.txt +pip install mkl-static mkl-include wheel + +# Build PyTorch (will take a long time) +export CUDA_HOME=/usr/local/cuda-12.8 +export CUDA_PATH=$CUDA_HOME +export TORCH_CUDA_ARCH_LIST=Blackwell +python setup.py develop + +# Optional, package build into a wheel to install on other machines. +python setup.py bdist_wheel +ls dist # Wheel should be output in this directory +``` + +Note that if you use the domain libraries (`torchvision`, `torchtext`, +`torchaudio`, etc.) these will need to be built from source as well, otherwise +their custom PyTorch extensions will not work. + +Finally, follow the instructions below to install triton from source. + +# Install from source + +```shell +git clone https://github.com/triton-lang/triton.git +cd triton + +pip install -r python/requirements.txt # build-time dependencies +pip install -e python +``` + +Or with a virtualenv: + +```shell +git clone https://github.com/triton-lang/triton.git +cd triton + +python -m venv .venv --prompt triton +source .venv/bin/activate + +pip install -r python/requirements.txt # build-time dependencies +pip install -e python +``` + +# Building with a custom LLVM + +Triton uses LLVM to generate code for GPUs and CPUs. Normally, the Triton build +downloads a prebuilt LLVM, but you can also build LLVM from source and use that. + +LLVM does not have a stable API, so the Triton build will not work at an +arbitrary LLVM version. + +1. Find the version of LLVM that Triton builds against. Check +`cmake/llvm-hash.txt` to see the current version. For example, if it says: + 49af6502c6dcb4a7f7520178bd14df396f78240c + + This means that the version of Triton you have builds against + [LLVM](https://github.com/llvm/llvm-project) 49af6502. + +2. `git checkout` LLVM at this revision. Optionally, make additional + modifications to LLVM. + +3. [Build LLVM](https://llvm.org/docs/CMake.html). For example, you might run + + $ cd $HOME/llvm-project # your clone of LLVM. + $ mkdir build + $ cd build + $ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm;lld" -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" + $ ninja + +4. Grab a snack, this will take a while. + +5. Build Triton as above, but set the following environment variables. + + # Modify as appropriate to point to your LLVM build. + $ export LLVM_BUILD_DIR=$HOME/llvm-project/build + + $ cd + $ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \ + LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \ + LLVM_SYSPATH=$LLVM_BUILD_DIR \ + pip install -e python + +# Tips for building + +- Set `TRITON_BUILD_WITH_CLANG_LLD=true` as an environment variable to use clang + and lld. lld in particular results in faster builds. + +- Set `TRITON_BUILD_WITH_CCACHE=true` to build with ccache. + +- Set `TRITON_HOME=/some/path` to change the location of the `.triton` + directory where Triton's cache is located and downloads are stored + during the build. By default, this is the user's home directory. It + can be changed anytime. + +- Pass `--no-build-isolation` to `pip install` to make nop builds faster. + Without this, every invocation of `pip install` uses a different symlink to + cmake, and this forces ninja to rebuild most of the `.a` files. + +- vscode intellisense has some difficulty figuring out how to build Triton's C++ + (probably because, in our build, users don't invoke cmake directly, but + instead use setup.py). Teach vscode how to compile Triton as follows. + + - Do a local build. Run command `pip install -e python` + - Get the full path to the `compile_commands.json` file produced by the build: + `find python/build -name 'compile_commands.json' | xargs readlink -f`. + You might get a full path similar to `/Users/{username}/triton/python/build/cmake.macosx-11.1-arm64-cpython-3.12/compile_commands.json` + - In vscode, install the + [C/C++ + extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode.cpptools), + then open the command palette (`Shift + Command + P` on Mac, or `Shift + + Ctrl + P` on Windows/Linux) and open `C/C++: Edit Configurations (UI)`. + - Open "Advanced Settings" and paste the full path to + `compile_commands.json` into the "Compile Commands" textbox. + +# Running tests + +There currently isn't a turnkey way to run all the Triton tests, but you can +follow the following recipe. + +```shell +# One-time setup. Note this will reinstall local Triton because torch +# overwrites it with the public version. +$ make dev-install + +# To run all tests (requires a GPU) +$ make test + +# Or, to run tests without a gpu +$ make test-nogpu +``` + +# Tips for hacking + +For detailed instructions on how to debug Triton's frontend, please refer to this [tutorial](https://triton-lang.org/main/programming-guide/chapter-3/debugging.html). The following includes additional tips for hacking on Triton's backend. + +**Helpful environment variables** + +- `MLIR_ENABLE_DUMP=1` dumps the IR before every MLIR pass Triton runs, for all + kernels. Use `MLIR_ENABLE_DUMP=kernelName` to dump for a specific kernel only. + - Triton cache can interfere with the dump. In cases where `MLIR_ENABLE_DUMP=1` does not work, try cleaning your triton cache: `rm -r ~/.triton/cache/*` +- `MLIR_DUMP_PATH` specifies where `MLIR_ENABLE_DUMP` will dump to. If unset will dump to stderr. +- `LLVM_IR_ENABLE_DUMP=1` dumps the IR before every pass run over the LLVM IR. +- `TRITON_REPRODUCER_PATH=` will generate an MLIR reproducer file + at `` before each MLIR compiler stage. If any of the stages fail, + `` will be a local MLIR reproducer captured right before the failing pass. +- `TRITON_INTERPRET=1` uses the Triton interpreter instead of running on the + GPU. You can insert Python breakpoints in your kernel code! +- `TRITON_ENABLE_LLVM_DEBUG=1` passes `-debug` to LLVM, printing a lot of + debugging information to stdout. If this is too noisy, run with just + `TRITON_LLVM_DEBUG_ONLY` instead to limit the output. + + An alternative way to reduce output noisiness is running with + `LLVM_IR_ENABLE_DUMP=1`, extract the IR before the LLVM pass of interest, and + then run LLVM's `opt` standalone, perhaps passing `-debug-only=foo` on the + command line. +- `TRITON_LLVM_DEBUG_ONLY=` is the equivalent of LLVM's + `-debug-only` command-line option. This limits the LLVM debug output to + specific pass or component names (which are specified using `#define + DEBUG_TYPE` throughout LLVM and Triton) in order to allow the debug output to + be less noisy. `TRITON_LLVM_DEBUG_ONLY` allows for one or more comma + separated values to be specified (eg + `TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions"` or + `TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions,regalloc"`). +- `TRITON_ENABLE_ASAN=1` invokes the LLVM address sanitizer for + memory leak and out of bounds access detection. Currently only supported on the AMD + backend. This must be run using the ASAN libraries documented [here](https://rocm.docs.amd.com/projects/llvm-project/en/latest/conceptual/using-gpu-sanitizer.html). + + When enabling the address sanitizer it is recommended to disable various memory caching strategies + both within the ROCm stack and PyTorch. This will give the address sanitizer the best chance at finding the + memory fault where it originates. See this [test](https://github.com/triton-lang/triton/blob/main/third_party/amd/python/test/test_address_sanitizer.py) for more details. + +- `USE_IR_LOC={ttir,ttgir}` reparses the IR such that the location information + will be the line number of the IR file with that particular extension, + instead of line number of the python file. This can provide a direct mapping + from the IR to llir/ptx. When used with performance tools, it can provide a + breakdown on IR instructions. +- `TRITON_PRINT_AUTOTUNING=1` prints out the best autotuning config and total time + spent for each kernel after autotuning is complete. +- `DISABLE_LLVM_OPT` will disable llvm optimizations for make_llir and make_ptx + if its value is true when parsing as Bool. Otherwise, it will be parsed as a list + of flags to disable llvm optimizations. One usage case is + `DISABLE_LLVM_OPT="disable-lsr"` + Loop strength reduction is known to cause up to 10% performance changes for + certain kernels with register pressure. +- `TRITON_ALWAYS_COMPILE=1` forces to compile kernels regardless of cache hit. +- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass. +- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass. +- `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma). +- `MLIR_ENABLE_DIAGNOSTICS=` controls diagnostic emission in MLIR. + Options are: `warnings`, `remarks`, `stacktraces`, `operations`. + Use comma-separated values to customize output. For example, + `MLIR_ENABLE_DIAGNOSTICS=remarks,operations` enables remarks and IR operations, + while `MLIR_ENABLE_DIAGNOSTICS=warnings,stacktraces` enables warnings with + stacktraces. By default, only errors are shown. Setting `warnings` includes + errors and warnings; `remarks` includes errors, warnings, and remarks. +- `MLIR_ENABLE_REMARK` is deprecated. Please use `MLIR_ENABLE_DIAGNOSTICS=remarks`. +- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn. +- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1. +- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage. +- `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx/amdgcn files when `TRITON_KERNEL_OVERRIDE` is set to 1. +- `TRITON_F32_DEFAULT` sets the default input precision of `tl.dot` when using 32-bit floats, which can be either `ieee`, `tf32`, or `tf32x3`. + +**Kernel Override Steps** + +```bash +export TRITON_ALWAYS_COMPILE=1 +export TRITON_KERNEL_DUMP=1 +export TRITON_DUMP_DIR= +export TRITON_KERNEL_OVERRIDE=1 +export TRITON_OVERRIDE_DIR= +# Step 1: Run the kernel once to dump kernel's IRs and ptx/amdgcn in $TRITON_DUMP_DIR +# Step 2: Copy $TRITON_DUMP_DIR/ to $TRITON_OVERRIDE_DIR +# Step 3: Delete the stages that you do not want to override and modify the stage you do want to override +# Step 4: Run the kernel again to see the overridden result +``` + + +# Changelog + +Version 2.0 is out! New features include: + +- Many, many bug fixes +- Performance improvements +- Backend rewritten to use MLIR +- Support for kernels that contain back-to-back matmuls (e.g., flash attention) + +# Contributing + +Community contributions are more than welcome, whether it be to fix bugs or to add new features at [github](https://github.com/triton-lang/triton/). For more detailed instructions, please visit our [contributor's guide](CONTRIBUTING.md). + +# Compatibility + +Supported Platforms: + +- Linux + +Supported Hardware: + +- NVIDIA GPUs (Compute Capability 8.0+) +- AMD GPUs (ROCm 6.2+) +- Under development: CPUs + +# Development Container (Dev Container) + +**Dev Containers** for the Triton project are available from +the [triton-dev-containers repository](https://github.com/redhat-et/triton-dev-containers) + +### Key Benefits: +- **Consistency**: All developers can work with the same development + environment, ensuring uniform behavior across different systems. +- **Isolation**: The container prevents potential conflicts with software + installed on your local machine. +- **Portability**: Easily share the development environment with team members, + minimizing onboarding time and setup issues. + +### How to Use the Dev Container: + +For detailed instructions on how to use the dev containers please see +the [dev container user guide](https://github.com/redhat-et/triton-dev-containers/blob/main/.devcontainer/devcontainer.md) + +# Warp Specialization Support + + +Warp specialization enhances kernel performance by utilizing an asynchronous execution model, where different parts of the kernel are handled by separate hardware units. The data communication between these units, via shared memory on the H100, operates with high efficiency. With this in mind, we’ve developed an automatic warp specialization optimization that partitions a user kernel into asynchronous tasks (which map to warp groups on NVIDIA GPU), which naturally execute concurrently, leveraging the hardware’s multitasking warp scheduler. The following sections provide a breakdown of the compiler features developed to enable warp specialization. + + +## Asynchronous Tasks + +Warp specialization is built on top of the concept of partitioning the user’s program into asynchronous tasks (referred to as "async tasks" or “tasks” in the following sections). Each async task will be executed by a standalone warp group on the supported hardware, to achieve instruction level parallelism. While optimally and automatically partitioning asynchronous tasks remains a challenge for compilers, our approach to automatic task partitioning has proven effective for kernels similar to typical examples like GEMM and Flash Attention. + +To enable warp specialization, user just needs to specify certain autotune flags, i.e., `num_consumer_groups` and `num_buffers_warp_spec`. For example, a warp-specialized GEMM implementation might look like below. You can find a complete example in 09-persistent-matmul.py. + +```python +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=4, + num_consumer_groups=2, + num_buffers_warp_spec=3, + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def matmul_persistent_ws_kernel( + a_ptr, b_ptr, c_ptr, M, N, K, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_pid_m + pid_n = pid % num_pid_n + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + acc += tl.dot(a, b) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + c = acc.to(tl.float16) + c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :] + tl.store(c_ptrs, c) +``` + +The compiler automatically determines how to utilize one producer warp group and two consumer warp groups to execute the kernel. It begins by assigning task IDs to certain anchor operations, which influence the task assignments for the remaining operations. Once the anchor tasks are annotated, the compiler assigns the non-anchor operations to tasks as follows: + +- Control dependencies exclusive to an anchor operation are included in the same task as the anchor operation. +- Data dependencies exclusive to an anchor operation are included in the same task as the anchor operation, unless they are another anchor operation. +- Control or data dependencies shared between tasks are included in all those tasks. + +For the GEMM example above, the compiler computes a task scheme and annotates it in the IR using MLIR attributes. To illustrate this more clearly, let's use source code annotations. After task propagation: + + +```python +@triton.jit +def matmul_persistent_ws_kernel( + a_ptr, b_ptr, c_ptr, M, N, K, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) # async_task 0, 1 + num_pid_m = tl.cdiv(M, BLOCK_M) # async_task 0, 1 + num_pid_n = tl.cdiv(N, BLOCK_N) # async_task 0, 1 + pid_m = pid // num_pid_m # async_task 0, 1 + pid_n = pid % num_pid_n # async_task 0, 1 + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # async_task 0, 1 + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # async_task 0, 1 + offs_k = tl.arange(0, BLOCK_K) # async_task 0 + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) # async_task 0 + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) # async_task 0 + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) # async_task 1 + for k in range(0, tl.cdiv(K, BLOCK_K)): # async_task 0, 1 + a = tl.load(a_ptrs) # async_task 0 + b = tl.load(b_ptrs) # async_task 0 + acc += tl.dot(a, b) # async_task 1 + a_ptrs += BLOCK_K * stride_ak # async_task 0 + b_ptrs += BLOCK_K * stride_bk # async_task 0 + c = acc.to(tl.float16) # async_task 1 + c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :] # async_task 1 + tl.store(c_ptrs, c) # async_task 1 +``` + + +## Data Partitioning + +To further improve performance, the compiler will split the same workload across two async tasks This way, when one task is blocked on a heavy computation (e.g., the dot operation), the other group can execute other operations in parallel. The compiler determines how to divide the work between the two tasks to maximize performance. On the H100 GPU, the compiler will, by default, attempt to split the input tensor A along the M dimension so that each consumer computes half of the output tensor independently. This approach is known as cooperative partitioning. If this split is not advantageous—for instance, if it results in a smaller-than-native `wgmma` instruction—the compiler will instead attempt to split along the N dimension. + +The transformed code for the above GEMM kernel with a configured tile size [128, 256, 64] will look like below (using source annotations instead of IR for illustration). + + +```python +@triton.jit +def matmul_persistent_ws_kernel( + a_ptr, b_ptr, c_ptr, M, N, K, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) # async_task 0, 1, 2 + num_pid_m = tl.cdiv(M, BLOCK_M) # async_task 0, 1, 2 + num_pid_n = tl.cdiv(N, BLOCK_N) # async_task 0, 1, 2 + pid_m = pid // num_pid_m # async_task 0, 1, 2 + pid_n = pid % num_pid_n # async_task 0, 1, 2 + offs_m_1 = pid_m * BLOCK_M + tl.arange(0, BLOCK_M // 2) # async_task 0, 1, 2 + offs_m_2 = pid_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M) # async_task 0, 1, 2 + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_N) # async_task 0, 1, 2 + offs_k = tl.arange(0, BLOCK_K) # async_task 0 + a_ptrs_1 = a_ptr + (offs_m_1[:, None] * stride_am + offs_k[None, :] * stride_ak) # async_task 0 + a_ptrs_2 = a_ptr + (offs_m_2[:, None] * stride_am + offs_k[None, :] * stride_ak) # async_task 0 + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) # async_task 0 + acc_1 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32) # async_task 1 + acc_1 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32) # async_task 2 + for k in range(0, tl.cdiv(K, BLOCK_K)): # async_task 0, 1, 2 + a_1 = tl.load(a_ptrs_1) # async_task 0 + a_2 = tl.load(a_ptrs_2) # async_task 0 + b = tl.load(b_ptrs) # async_task 0 + acc_1 += tl.dot(a_1, b) # async_task 1 + acc_2 += tl.dot(a_2, b) # async_task 2 + a_ptrs_1 += BLOCK_K * stride_ak # async_task 0 + a_ptrs_2 += BLOCK_K * stride_ak # async_task 0 + b_ptrs += BLOCK_K * stride_bk # async_task 0 + c_1 = acc_1.to(tl.float16) # async_task 1 + c_2 = acc_2.to(tl.float16) # async_task 2 + c_ptrs_1 = c_ptr_1 + stride_cm * offs_m_1[:, None] + stride_cn * offs_n[None, :] # async_task 1 + c_ptrs_2 = c_ptr_2 + stride_cm * offs_m_2[:, None] + stride_cn * offs_n[None, :] # async_task 2 + tl.store(c_ptrs_1, c_1) # async_task 1 + tl.store(c_ptrs_2, c_2) # async_task 2 +``` + + +## Code Partitioning + +We assume all operations are already marked with a list of taskIds. We first find all communications required between warp groups. Each communication starts from a load operation with a single taskId, and ends at a direct user of the load which belongs to a different taskId. For `ForOps` containing a communication channel, we add additional arguments: `phase` and `bufferIndex`. + +We introduce a tuning configuration: `num_buffers_warp_spec`. For each communication channel, if it is within a `forOp`, we use an array of buffers in SMEM to save the results, and size of the array is determined by `num_buffers_warp_spec`. We also use an array of barriers for each communication channel that is inside a `ForOp`. At this pass, four new operations are introduced to correctly synchronize between the producer and the consumer: `ProducerAcquireOp`, `ProducerCommitOp`, `ConsumerWaitOp`, and `ConsumerReleaseOp`. Each of the four new ops take a token, a buffer Index. `ProducerAcquire` and `ConsumerWait` take an additional phase operand. + + +For `ForOps` with multiple task Ids, we clone one copy for each taskId, each copy contains the operations with the specific taskId. In the end, we create multiple `IfOps`, one for each possible taskId. We go through the body of the function, clone the op for each attached task Id and put the cloned op in the right `IfOp`. + +To adjust register usage, we introduce two new ops: `RegAllocOp` and `RegDeallocOp`, both taking an integer operand. For each warp group, we decide to insert either `RegAllocOp` or `RegDeallocOp`. The current heuristic is simple: if the task Id is 0, we add `RegDeallocOp`, otherwise we use `RegAllocOp`. The amount of register adjustment can be tuned via `reg_dec_producer` and `reg_inc_consumer`. + +This pass also lowers `loadOp`s to `AsyncTMACopyGlobalToLocalOp` or `AsyncCopyGlobalToLocalOp`, so the communication can be expressed via SMEM. For TMA, the producer will become +`ProducerAcquire` -> `barrier_expect` -> `AsyncTMACopyGlobalToLocalOp`, and the consumer will contain `wait_barrier` -> ops -> `ConsumerRelease`. For non-TMA loads, the producer will become `ProducerAcquire` -> `AsyncCopyGlobalToLocalOp` -> `ProducerCommitOp`, and the consumer will contain `ConsumerWaitOp` -> ops -> `ConsumerRelease`. diff --git a/third_party/enflame/include/triton/RELEASE.md b/third_party/enflame/include/triton/RELEASE.md new file mode 100644 index 000000000..edd1e4201 --- /dev/null +++ b/third_party/enflame/include/triton/RELEASE.md @@ -0,0 +1,48 @@ +# Releasing Triton + +Triton releases provide a stable snapshot of the code base encapsulated into a binary that can easily be consumed through PyPI. Additionally, releases represent points in time when we, as the development team, can signal to the community that certain new features are available, what improvements have been made, and any changes that are coming that may impact them (i.e. breaking changes). + +## Release Compatibility Matrix + +Following is the Release Compatibility Matrix for Triton releases: + +| Triton version | Python version | Manylinux version | +| --- | --- | --- | +| 3.2.0 | >=3.9, <=3.13 | glibc 2.17+ x86-64 | +| 3.1.0 | >=3.8, <=3.12 | glibc 2.17+ x86-64 | +| 3.0.0 | >=3.8, <=3.12 | glibc 2.17+ x86-64 | +| 2.3.1 | >=3.7, <=3.12 | glibc 2.17+ x86-64 | +| 2.3.0 | >=3.7, <=3.12 | glibc 2.17+ x86-64 | +| 2.2.0 | >=3.7, <=3.12 | glibc 2.17+ x86-64 | +| 2.1.0 | >=3.7, <=3.11 | glibc 2.17+ x86-64 | +| 2.0.0 | >=3.6, <=3.11 | glibc 2.17+ x86-64 | +| 1.1.1 | >=3.6, <=3.9 | glibc 2.17+ x86-64 | +| 1.1.0 | >=3.6, <=3.9 | glibc 2.17+ x86-64 | +| 1.0.0 | >=3.6, <=3.9 | glibc 2.17+ x86-64 | + +## Release Cadence + +Following is the release cadence for year 2024/2025. All future release dates below are tentative. Please note: Patch Releases are optional. + +| Minor Version | Release branch cut | Release date | Patch Release date | +| --- | --- | --- | --- | +| 3.5.0 | Sep 2025 | Oct 2025 | --- | +| 3.4.0 | Jun 2025 | Jul 2025 | --- | +| 3.3.0 | Feb/Mar 2025 | Apr 2025 | --- | +| 3.2.0 | Dec 2024 | Jan 2025 | --- | +| 3.1.0 | Jun 2024 | Oct 2024 | --- | +| 3.0.0 | Jun 2024 | Jul 2024 | --- | +| 2.3.0 | Dec 2023 | Apr 2024 | May 2024 | +| 2.2.0 | Dec 2023 | Jan 2024 | --- | + +## Release Cherry-Pick Criteria + +After branch cut, we approach finalizing the release branch with clear criteria on what cherry picks are allowed in. Note: a cherry pick is a process to land a PR in the release branch after branch cut. These are typically limited to ensure that the team has sufficient time to complete a thorough round of testing on a stable code base. + +* Regression fixes - that address functional/performance regression against the most recent release (e.g. 3.2 for 3.3 release) +* Critical fixes - critical fixes for severe issue such as silent incorrectness, backwards compatibility, crashes, deadlocks, (large) memory leaks +* Fixes to new features introduced in the most recent release (e.g. 3.2 for 3.3 release) +* Documentation improvements +* Release branch specific changes (e.g. change version identifiers or CI fixes) + +Please note: **No feature work allowed for cherry picks**. All PRs that are considered for cherry-picks need to be merged on trunk, the only exception are Release branch specific changes. An issue is for tracking cherry-picks to the release branch is created after the branch cut. **Only issues that have ‘cherry-picks’ in the issue tracker will be considered for the release.** diff --git a/third_party/enflame/include/triton/TritonNvidiaGPUAttrDefs.td b/third_party/enflame/include/triton/TritonNvidiaGPUAttrDefs.td new file mode 100644 index 000000000..d823e617b --- /dev/null +++ b/third_party/enflame/include/triton/TritonNvidiaGPUAttrDefs.td @@ -0,0 +1,71 @@ +#ifndef TRITONNVIDIAGPU_ATTRDEFS +#define TRITONNVIDIAGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" + +def TTG_TensorMemorySpace : AttrDef { + let mnemonic = "tensor_memory"; + let description = [{ + Attribute to indicate that the memory descriptor points to tensor memory. + The memory is laid out in blocks of size blockM x blockN. Each block is distributed + across TMEM 128 rows. + + Blocks are distributed along M dimension first and then N dimension. This is an arbitrary + convention that need to be followed operations reading/writing to TMEM. + + a tensor <128x128xf32> with blockM = 64 and blockN = 64 will be distributed as follows: + + \ col 0 1 31 32 64 96 127 + rows: 0 ( 0, 0) ( 0, 1) ... ( 0, 31) (64, 0) ... (0, 64) ... (64, 64) ... (64, 96) + 1 + ... + 15 (15, 0) (15, 1) ... (15, 31) (79, 0) ... (15, 64) ... (79, 64) ... (79, 96) + 16 ( 0, 32) ( 0, 33) ... ( 0, 63) (64, 32) ... ( 0, 96) ... (64, 96) ... (64, 127) + ... + 31 (15, 32) (15, 33) ... (15, 63) (79, 32) ... (15, 96) ... (79, 96) ... (79, 127) + 32 (16, 0) (16, 1) ... (16, 31) (80, 0) ... (16, 64) ... (80, 64) ... (80, 96) + ... + 127 (63, 32) (63, 33) ... (63, 63) (127, 32) ... (63, 96) ... (127, 96)... (127, 127) + }]; +} + +def TTG_TensorMemoryEncodingAttr : AttrDef { + let mnemonic = "tensor_memory_encoding"; + let attrName = "triton.gpu.tensor_memory_encoding"; + let description = [{ + An encoding to represent the different way the tensor memory is laid out. + `unpacked` attributes indicates whether types smaller than 32bits are unpacked (take full 32bits) + or are packed (N elements are stored within one 32bits row). + }]; + let parameters = ( + ins + "unsigned":$blockM, + "unsigned":$blockN, + "bool":$unpacked, + DefaultValuedParameter<"unsigned", "1">:$CTASplitM, + DefaultValuedParameter<"unsigned", "1">:$CTASplitN + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def TTG_TensorMemoryScalesEncodingAttr : AttrDef { + let mnemonic = "tensor_memory_scales_encoding"; + let attrName = "triton.gpu.tensor_memory_scales_encoding"; + let description = [{ + An encoding to represent the layout of tensor memory scales. + As described in the PTX doc, blocked scales in TMEM must be in a special layout. They are organized + as a multiple copies of "chunk", each of which having the size 32x4x4B. Moreover, such chunks are duplicated + over 4 warps to fill entire 128 rows of TMEM. This encoding indicates that a tensor in TMEM is in such a special + layout. + }]; + let parameters = ( + ins + DefaultValuedParameter<"unsigned", "1">:$CTASplitM, + DefaultValuedParameter<"unsigned", "1">:$CTASplitN + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +#endif diff --git a/third_party/enflame/include/triton/bin/CMakeLists.txt b/third_party/enflame/include/triton/bin/CMakeLists.txt new file mode 100644 index 000000000..7ed6d6270 --- /dev/null +++ b/third_party/enflame/include/triton/bin/CMakeLists.txt @@ -0,0 +1,93 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + +add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED) + +# TODO: what's this? +llvm_update_compile_flags(triton-opt) +target_link_libraries(triton-opt PRIVATE + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # tests + TritonTestAnalysis + TritonTestDialectTritonGPU + TritonAMDGPUTestAnalysis + # MLIR core + MLIROptLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-opt) + +add_llvm_executable(triton-reduce triton-reduce.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(triton-reduce) + +llvm_update_compile_flags(triton-reduce) +target_link_libraries(triton-reduce PRIVATE + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # tests + TritonTestAnalysis + TritonTestDialectTritonGPU + TritonAMDGPUTestAnalysis + # MLIR core + MLIRReduceLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-reduce) + +add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED) + +llvm_update_compile_flags(triton-lsp) +target_link_libraries(triton-lsp PRIVATE + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # tests + TritonTestAnalysis + TritonTestDialectTritonGPU + TritonAMDGPUTestAnalysis + # MLIR core + MLIRLspServerLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-lsp) + + +add_llvm_executable(triton-llvm-opt + triton-llvm-opt.cpp + + PARTIAL_SOURCES_INTENDED + DEPENDS + intrinsics_gen + SUPPORT_PLUGINS + ) +target_link_libraries(triton-llvm-opt PRIVATE + TritonLLVMIR + + LLVMAnalysis + LLVMCore + LLVMSupport + LLVMOption + LLVMCodeGen + ) +export_executable_symbols_for_plugins(triton-llvm-opt) + + +add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED) +target_link_libraries(triton-tensor-layout PRIVATE + ${triton_libs} + ${conversion_libs} + ${dialect_libs} + TritonTestAnalysis + TritonTestDialectTritonGPU + TritonAMDGPUTestAnalysis + ) diff --git a/third_party/enflame/include/triton/bin/RegisterTritonDialects.h b/third_party/enflame/include/triton/bin/RegisterTritonDialects.h new file mode 100644 index 000000000..cd96bfbfb --- /dev/null +++ b/third_party/enflame/include/triton/bin/RegisterTritonDialects.h @@ -0,0 +1,85 @@ +#pragma once +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "amd/include/TritonAMDGPUTransforms/Passes.h" +#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +// Below headers will allow registration to ROCm passes +#include "TritonAMDGPUToLLVM/Passes.h" +#include "TritonAMDGPUTransforms/Passes.h" +#include "TritonAMDGPUTransforms/TritonGPUConversion.h" + +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#include "nvidia/include/NVGPUToLLVM/Passes.h" +#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" + +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/InitAllPasses.h" + +namespace mlir { +namespace test { +void registerTestAliasPass(); +void registerTestAlignmentPass(); +void registerTestAllocationPass(); +void registerTestMembarPass(); +void registerTestTritonAMDGPURangeAnalysis(); +} // namespace test +} // namespace mlir + +inline void registerTritonDialects(mlir::DialectRegistry ®istry) { + mlir::registerAllPasses(); + mlir::registerTritonPasses(); + mlir::triton::gpu::registerTritonGPUPasses(); + mlir::registerTritonNvidiaGPUPasses(); + mlir::test::registerTestAliasPass(); + mlir::test::registerTestAlignmentPass(); + mlir::test::registerTestAllocationPass(); + mlir::test::registerTestMembarPass(); + mlir::test::registerTestTritonAMDGPURangeAnalysis(); + mlir::triton::registerConvertTritonToTritonGPUPass(); + mlir::triton::gpu::registerAllocateSharedMemoryPass(); + mlir::triton::gpu::registerTritonGPUAllocateWarpGroups(); + mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass(); + mlir::triton::registerConvertWarpSpecializeToLLVM(); + mlir::triton::registerConvertTritonGPUToLLVMPass(); + mlir::triton::registerConvertNVGPUToLLVMPass(); + mlir::registerLLVMDIScope(); + + // TritonAMDGPUToLLVM passes + mlir::triton::registerConvertTritonAMDGPUToLLVM(); + mlir::triton::registerConvertBuiltinFuncToLLVM(); + mlir::triton::registerDecomposeUnsupportedAMDConversions(); + mlir::triton::registerOptimizeAMDLDSUsage(); + + // TritonAMDGPUTransforms passes + mlir::registerTritonAMDGPUAccelerateMatmul(); + mlir::registerTritonAMDGPUOptimizeEpilogue(); + mlir::registerTritonAMDGPUHoistLayoutConversions(); + mlir::registerTritonAMDGPUReorderInstructions(); + mlir::registerTritonAMDGPUBlockPingpong(); + mlir::registerTritonAMDGPUStreamPipeline(); + mlir::registerTritonAMDGPUCanonicalizePointers(); + mlir::registerTritonAMDGPUConvertToBufferOps(); + mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); + mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); + + registry + .insert(); +} diff --git a/third_party/enflame/include/triton/bin/triton-llvm-opt.cpp b/third_party/enflame/include/triton/bin/triton-llvm-opt.cpp new file mode 100644 index 000000000..1ec804cb5 --- /dev/null +++ b/third_party/enflame/include/triton/bin/triton-llvm-opt.cpp @@ -0,0 +1,121 @@ +/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir +/// passes. +#include "lib/Target/LLVMIR/LLVMPasses.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/SystemUtils.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/TargetParser/Triple.h" +#include + +using namespace llvm; + +static cl::opt InputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +static cl::opt OutputFilename("o", + cl::desc("Override output filename"), + cl::value_desc("filename")); + +static cl::opt ClDataLayout("data-layout", + cl::desc("data layout string to use"), + cl::value_desc("layout-string"), + cl::init("")); +static cl::opt + TargetTriple("mtriple", cl::desc("Override target triple for module")); + +static cl::opt + BreakStructPhiNodes("break-struct-phi-nodes", + llvm::cl::desc("run pass to break phi struct"), + cl::init(false)); + +namespace { +static std::function makeOptimizingPipeline() { + return [](Module *m) -> Error { + PipelineTuningOptions tuningOptions; + PassBuilder pb(nullptr, tuningOptions); + + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + if (BreakStructPhiNodes) + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions( + argc, argv, "llvm .bc -> .bc modular optimizer and analysis printer\n"); + + LLVMContext Context; + SMDiagnostic Err; + + // Load the input module... + auto SetDataLayout = [](StringRef, StringRef) -> std::optional { + if (ClDataLayout.empty()) + return std::nullopt; + return ClDataLayout; + }; + std::unique_ptr M; + M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout)); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + // If we are supposed to override the target triple or data layout, do so now. + if (!TargetTriple.empty()) + M->setTargetTriple(Triple::normalize(TargetTriple)); + auto optPipeline = makeOptimizingPipeline(); + if (auto err = optPipeline(M.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + } + + if (verifyModule(*M, &errs())) { + errs() << argv[0] << ": " << InputFilename + << ": error: input module is broken!\n"; + return 1; + } + + // Write to standard output. + std::unique_ptr Out; + // Default to standard output. + if (OutputFilename.empty()) + OutputFilename = "-"; + std::error_code EC; + sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF; + Out.reset(new ToolOutputFile(OutputFilename, EC, Flags)); + if (EC) { + errs() << EC.message() << '\n'; + return 1; + } + Out->os() << *M << "\n"; + Out->keep(); + return 0; +} diff --git a/third_party/enflame/include/triton/bin/triton-lsp.cpp b/third_party/enflame/include/triton/bin/triton-lsp.cpp new file mode 100644 index 000000000..f95036dc6 --- /dev/null +++ b/third_party/enflame/include/triton/bin/triton-lsp.cpp @@ -0,0 +1,10 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); +} diff --git a/third_party/enflame/include/triton/bin/triton-opt.cpp b/third_party/enflame/include/triton/bin/triton-opt.cpp new file mode 100644 index 000000000..2d2570771 --- /dev/null +++ b/third_party/enflame/include/triton/bin/triton-opt.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "Triton (GPU) optimizer driver\n", registry)); +} diff --git a/third_party/enflame/include/triton/bin/triton-reduce.cpp b/third_party/enflame/include/triton/bin/triton-reduce.cpp new file mode 100644 index 000000000..8235f8fc8 --- /dev/null +++ b/third_party/enflame/include/triton/bin/triton-reduce.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-reduce/MlirReduceMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + mlir::MLIRContext context(registry); + return mlir::failed(mlir::mlirReduceMain(argc, argv, context)); +} diff --git a/third_party/enflame/include/triton/bin/triton-tensor-layout.cpp b/third_party/enflame/include/triton/bin/triton-tensor-layout.cpp new file mode 100644 index 000000000..cc121b3e1 --- /dev/null +++ b/third_party/enflame/include/triton/bin/triton-tensor-layout.cpp @@ -0,0 +1,232 @@ +#include "RegisterTritonDialects.h" + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/AsmParser/AsmParserState.h" +#include "mlir/IR/MLIRContext.h" + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace mlir; + +// A CLI tool to print the layout of a tensor. +// +// clang-format off +// Example usage: +// +// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>" +// +// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt +// +// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view +// +// An input file usually looks like: +// ''' +// #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> +// #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> +// ''' +// clang-format on + +//===--------------------------------------------------------------------===// +// CLI options +//===--------------------------------------------------------------------===// + +cl::OptionCategory PrinterCategory("Available Print Options", + "Options for the tensor layout printing."); + +static cl::opt InputFile( + "i", cl::desc("File that contains the tensor data layout attributes"), + cl::init(""), cl::value_desc("filename"), cl::cat(PrinterCategory)); + +static cl::opt + OutputFile("o", cl::desc("Output file to write the layout into"), + cl::init(""), cl::value_desc("filename"), + cl::cat(PrinterCategory)); + +static cl::opt + DataLayoutStr("l", cl::desc("Tensor data layout attribute in string"), + cl::value_desc("layout-string"), cl::init(""), + cl::cat(PrinterCategory)); + +static cl::list + AliasName("alias-names", + cl::desc("A list of alias names (separated by comma) of the " + "layout attributes in the input file"), + cl::value_desc("name1,name2,name3,..."), cl::CommaSeparated, + cl::ZeroOrMore, cl::cat(PrinterCategory)); + +static cl::opt UseHWPointOfView( + "use-hw-view", + llvm::cl::desc( + "Print the layout in hardware point of view. This means the output is " + "from the warp's perspective. Otherwise, the output is from the " + "tensor's perspective (e.g., each element maps to xxx thread)."), + cl::init(false), cl::cat(PrinterCategory)); + +static cl::opt TensorStr( + "t", cl::desc("Tensor shape and element type (e.g., tensor<2x2xf32>)"), + cl::init(""), cl::value_desc("tensor-type"), cl::cat(PrinterCategory)); + +//===--------------------------------------------------------------------===// +// Helper functions +//===--------------------------------------------------------------------===// + +LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { + // DistributedEncodingTrait and SharedEncodingTrait implements the + // toLinearLayout interface. + mlir::Attribute layout = tensorType.getEncoding(); + if (isa(layout)) { + os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); + return success(); + } + + llvm::errs() << "Unsupported tensor layout attribute: " + << tensorType.getEncoding() << "\n"; + return failure(); +} + +LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename, + ArrayRef names, + TensorType tensorTy, raw_string_ostream &ss) { + if (filename.empty()) + return success(); + + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return failure(); + } + + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + ParserConfig config(context); + auto asmState = AsmParserState(); + + Block parsedIR; + if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) { + llvm::errs() << "Fail to parse the input file: " << filename << "\n"; + return failure(); + } + + auto printLambda = [&](StringRef name, mlir::Attribute attr) { + ss << "Print layout attribute: #" << name << " = " << attr << "\n"; + + auto rankedTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), attr); + + return layoutPrint(rankedTensorTy, ss); + }; + + if (names.empty()) + // If no alias name is given, we print all layout attributes in the file. + for (const auto &def : asmState.getAttributeAliasDefs()) { + if (failed(printLambda(def.name, def.value))) + return failure(); + } + else { + // Print the layout attributes with the given alias names. + for (const auto &alias : names) { + auto def = asmState.getAttributeAliasDef(alias); + if (!def) { + llvm::errs() << "Can't find the layout attribute: " << alias << "\n"; + return failure(); + } + + if (failed(printLambda(alias, def->value))) + return failure(); + + ss << "\n"; + } + } + + return success(); +} + +LogicalResult printLayoutFromString(MLIRContext *context, + StringRef layoutAttrStr, + TensorType tensorTy, + raw_string_ostream &ss) { + if (layoutAttrStr.empty()) + return success(); + + mlir::Attribute layout = parseAttribute(layoutAttrStr, context); + if (!layout) { + llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n"; + return failure(); + } + + auto rankedTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), layout); + + ss << "Print layout attribute: " << layout << "\n"; + + return layoutPrint(rankedTensorTy, ss); +} + +//===--------------------------------------------------------------------===// +// Main entry point +//===--------------------------------------------------------------------===// + +int main(int argc, char **argv) { + cl::HideUnrelatedOptions(PrinterCategory); + cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n"); + + DialectRegistry registry; + registerTritonDialects(registry); + + MLIRContext ctx(registry); + ctx.loadAllAvailableDialects(); + + if (TensorStr.empty()) { + llvm::errs() << "Must specify the tensor type argument\n"; + return 1; + } + + mlir::Type parsedTy = parseType(TensorStr, &ctx); + if (!parsedTy) { + llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr + << "\n"; + return 1; + } + + TensorType tensorType = dyn_cast(parsedTy); + if (!tensorType) { + llvm::errs() << "Invalid tensor type argument: " << TensorStr << "\n"; + return 1; + } + + std::string storage; + raw_string_ostream ss(storage); + + if (failed(printLayoutFromFile(&ctx, InputFile, AliasName, tensorType, ss))) + return 1; + + if (failed(printLayoutFromString(&ctx, DataLayoutStr, tensorType, ss))) + return 1; + + if (OutputFile.empty()) { + llvm::outs() << ss.str(); + } else { + std::error_code ec; + llvm::raw_fd_ostream outFs(OutputFile, ec, llvm::sys::fs::OF_Text); + if (ec) { + llvm::errs() << "Error: " << ec.message() << " : unable to open " + << OutputFile << " for output\n"; + return 1; + } + outFs << ss.str(); + outFs.close(); + } + + return 0; +} diff --git a/third_party/enflame/include/triton/cmake/AddTritonUnitTest.cmake b/third_party/enflame/include/triton/cmake/AddTritonUnitTest.cmake new file mode 100644 index 000000000..939c1e4ad --- /dev/null +++ b/third_party/enflame/include/triton/cmake/AddTritonUnitTest.cmake @@ -0,0 +1,44 @@ +include(${PROJECT_SOURCE_DIR}/unittest/googletest.cmake) + +include(GoogleTest) +enable_testing() + +function(add_triton_ut) + set(options) + set(oneValueArgs NAME) + set(multiValueArgs SRCS LIBS DEFS) + cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + + add_test(NAME ${__NAME} + COMMAND ${__NAME}) + add_executable( + ${__NAME} + ${__SRCS}) + target_link_libraries( + ${__NAME} + PRIVATE + GTest::gtest_main + ${triton_libs} + ${dialect_libs} + ${conversion_libs} + gmock + ${__LIBS}) + + if(NOT MSVC) + target_compile_options(${__NAME} PRIVATE -fno-rtti) + endif() + + target_compile_definitions(${__NAME} PRIVATE ${__DEFS}) + + # Without the TEST_DISCOVERY_TIMEOUT, the tests randomly time out on my mac + # laptop. I think the issue may be that the very first time you run a program + # it's a bit slow. + gtest_discover_tests(${__NAME} DISCOVERY_TIMEOUT 60) + + # Add the unit test to the top-level unit test target. + add_dependencies(TritonUnitTests ${__NAME}) +endfunction() diff --git a/third_party/enflame/include/triton/cmake/FindLLVM.cmake b/third_party/enflame/include/triton/cmake/FindLLVM.cmake new file mode 100644 index 000000000..47d068fc6 --- /dev/null +++ b/third_party/enflame/include/triton/cmake/FindLLVM.cmake @@ -0,0 +1,196 @@ +# - Find LLVM headers and libraries. +# This module locates LLVM and adapts the llvm-config output for use with +# CMake. +# +# A given list of COMPONENTS is passed to llvm-config. +# +# The following variables are defined: +# LLVM_FOUND - true if LLVM was found +# LLVM_CXXFLAGS - C++ compiler flags for files that include LLVM headers. +# LLVM_ENABLE_ASSERTIONS - Whether LLVM was built with enabled assertions (ON/OFF). +# LLVM_INCLUDE_DIRS - Directory containing LLVM include files. +# LLVM_IS_SHARED - Whether LLVM is going to be linked dynamically (ON) or statically (OFF). +# LLVM_LDFLAGS - Linker flags to add when linking against LLVM +# (includes -LLLVM_LIBRARY_DIRS). +# LLVM_LIBRARIES - Full paths to the library files to link against. +# LLVM_LIBRARY_DIRS - Directory containing LLVM libraries. +# LLVM_NATIVE_ARCH - Backend corresponding to LLVM_HOST_TARGET, e.g., +# X86 for x86_64 and i686 hosts. +# LLVM_ROOT_DIR - The root directory of the LLVM installation. +# llvm-config is searched for in ${LLVM_ROOT_DIR}/bin. +# LLVM_TARGETS_TO_BUILD - List of built LLVM targets. +# LLVM_VERSION_MAJOR - Major version of LLVM. +# LLVM_VERSION_MINOR - Minor version of LLVM. +# LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn). +# LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0). +# +# Note: The variable names were chosen in conformance with the official CMake +# guidelines, see ${CMAKE_ROOT}/Modules/readme.txt. + +# Try suffixed versions to pick up the newest LLVM install available on Debian +# derivatives. +# We also want an user-specified LLVM_ROOT_DIR to take precedence over the +# system default locations such as /usr/local/bin. Executing find_program() +# multiples times is the approach recommended in the docs. +set(llvm_config_names llvm-config-6.0 llvm-config60 + llvm-config) +foreach(v RANGE 7 17) + # names like llvm-config-7.0 llvm-config70 llvm-config-7 llvm-config-7-64 + list(PREPEND llvm_config_names llvm-config-${v}.0 llvm-config${v}0 llvm-config-${v} llvm-config-${v}-64) +endforeach() +find_program(LLVM_CONFIG + NAMES ${llvm_config_names} + PATHS ${LLVM_ROOT_DIR}/bin NO_DEFAULT_PATH + DOC "Path to llvm-config tool.") +find_program(LLVM_CONFIG NAMES ${llvm_config_names}) +if(APPLE) + # extra fallbacks for MacPorts & Homebrew + find_program(LLVM_CONFIG + NAMES ${llvm_config_names} + PATHS /opt/local/libexec/llvm-11/bin /opt/local/libexec/llvm-10/bin /opt/local/libexec/llvm-9.0/bin + /opt/local/libexec/llvm-8.0/bin /opt/local/libexec/llvm-7.0/bin /opt/local/libexec/llvm-6.0/bin + /opt/local/libexec/llvm/bin + /usr/local/opt/llvm@11/bin /usr/local/opt/llvm@10/bin /usr/local/opt/llvm@9/bin + /usr/local/opt/llvm@8/bin /usr/local/opt/llvm@7/bin /usr/local/opt/llvm@6/bin + /usr/local/opt/llvm/bin + NO_DEFAULT_PATH) +endif() + +# Prints a warning/failure message depending on the required/quiet flags. Copied +# from FindPackageHandleStandardArgs.cmake because it doesn't seem to be exposed. +macro(_LLVM_FAIL _msg) + if(LLVM_FIND_REQUIRED) + message(FATAL_ERROR "${_msg}") + else() + if(NOT LLVM_FIND_QUIETLY) + message(WARNING "${_msg}") + endif() + endif() +endmacro() + + +if(NOT LLVM_CONFIG) + if(NOT LLVM_FIND_QUIETLY) + _LLVM_FAIL("No LLVM installation (>= ${LLVM_FIND_VERSION}) found. Try manually setting the 'LLVM_ROOT_DIR' or 'LLVM_CONFIG' variables.") + endif() +else() + macro(llvm_set var flag) + if(LLVM_FIND_QUIETLY) + set(_quiet_arg ERROR_QUIET) + endif() + set(result_code) + execute_process( + COMMAND ${LLVM_CONFIG} --link-static --${flag} + RESULT_VARIABLE result_code + OUTPUT_VARIABLE LLVM_${var} + OUTPUT_STRIP_TRAILING_WHITESPACE + ${_quiet_arg} + ) + if(result_code) + _LLVM_FAIL("Failed to execute llvm-config ('${LLVM_CONFIG}', result code: '${result_code})'") + else() + if(${ARGV2}) + file(TO_CMAKE_PATH "${LLVM_${var}}" LLVM_${var}) + endif() + endif() + endmacro() + macro(llvm_set_libs var flag components) + if(LLVM_FIND_QUIETLY) + set(_quiet_arg ERROR_QUIET) + endif() + set(result_code) + execute_process( + COMMAND ${LLVM_CONFIG} --link-static --${flag} ${components} + RESULT_VARIABLE result_code + OUTPUT_VARIABLE tmplibs + OUTPUT_STRIP_TRAILING_WHITESPACE + ${_quiet_arg} + ) + if(result_code) + _LLVM_FAIL("Failed to execute llvm-config ('${LLVM_CONFIG}', result code: '${result_code})'") + else() + file(TO_CMAKE_PATH "${tmplibs}" tmplibs) + string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_${var} ${tmplibs}) + endif() + endmacro() + + llvm_set(VERSION_STRING version) + llvm_set(CXXFLAGS cxxflags) + llvm_set(INCLUDE_DIRS includedir true) + llvm_set(ROOT_DIR prefix true) + llvm_set(ENABLE_ASSERTIONS assertion-mode) + + # The LLVM version string _may_ contain a git/svn suffix, so match only the x.y.z part + string(REGEX MATCH "^[0-9]+[.][0-9]+[.][0-9]+" LLVM_VERSION_BASE_STRING "${LLVM_VERSION_STRING}") + + llvm_set(SHARED_MODE shared-mode) + if(LLVM_SHARED_MODE STREQUAL "shared") + set(LLVM_IS_SHARED ON) + else() + set(LLVM_IS_SHARED OFF) + endif() + + llvm_set(LDFLAGS ldflags) + llvm_set(SYSTEM_LIBS system-libs) + string(REPLACE "\n" " " LLVM_LDFLAGS "${LLVM_LDFLAGS} ${LLVM_SYSTEM_LIBS}") + if(APPLE) # unclear why/how this happens + string(REPLACE "-llibxml2.tbd" "-lxml2" LLVM_LDFLAGS ${LLVM_LDFLAGS}) + endif() + + llvm_set(LIBRARY_DIRS libdir true) + llvm_set_libs(LIBRARIES libfiles "${LLVM_FIND_COMPONENTS}") + # LLVM bug: llvm-config --libs tablegen returns -lLLVM-3.8.0 + # but code for it is not in shared library + if("${LLVM_FIND_COMPONENTS}" MATCHES "tablegen") + if (NOT "${LLVM_LIBRARIES}" MATCHES "LLVMTableGen") + set(LLVM_LIBRARIES "${LLVM_LIBRARIES};-lLLVMTableGen") + endif() + endif() + + llvm_set(CMAKEDIR cmakedir) + llvm_set(TARGETS_TO_BUILD targets-built) + string(REGEX MATCHALL "${pattern}[^ ]+" LLVM_TARGETS_TO_BUILD ${LLVM_TARGETS_TO_BUILD}) + + # Parse LLVM_NATIVE_ARCH manually from LLVMConfig.cmake; including it leads to issues like + # https://github.com/ldc-developers/ldc/issues/3079. + file(STRINGS "${LLVM_CMAKEDIR}/LLVMConfig.cmake" LLVM_NATIVE_ARCH LIMIT_COUNT 1 REGEX "^set\\(LLVM_NATIVE_ARCH (.+)\\)$") + string(REGEX MATCH "set\\(LLVM_NATIVE_ARCH (.+)\\)" LLVM_NATIVE_ARCH "${LLVM_NATIVE_ARCH}") + set(LLVM_NATIVE_ARCH ${CMAKE_MATCH_1}) + message(STATUS "LLVM_NATIVE_ARCH: ${LLVM_NATIVE_ARCH}") + + # On CMake builds of LLVM, the output of llvm-config --cxxflags does not + # include -fno-rtti, leading to linker errors. Be sure to add it. + if(NOT MSVC AND (CMAKE_COMPILER_IS_GNUCXX OR (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang"))) + if(NOT ${LLVM_CXXFLAGS} MATCHES "-fno-rtti") + set(LLVM_CXXFLAGS "${LLVM_CXXFLAGS} -fno-rtti") + endif() + endif() + + # Remove some clang-specific flags for gcc. + if(CMAKE_COMPILER_IS_GNUCXX) + string(REPLACE "-Wcovered-switch-default " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) + string(REPLACE "-Wstring-conversion " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) + string(REPLACE "-fcolor-diagnostics " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) + # this requires more recent gcc versions (not supported by 4.9) + string(REPLACE "-Werror=unguarded-availability-new " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) + endif() + + # Remove gcc-specific flags for clang. + if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + string(REPLACE "-Wno-maybe-uninitialized " "" LLVM_CXXFLAGS ${LLVM_CXXFLAGS}) + endif() + + string(REGEX REPLACE "([0-9]+).*" "\\1" LLVM_VERSION_MAJOR "${LLVM_VERSION_STRING}" ) + string(REGEX REPLACE "[0-9]+\\.([0-9]+).*[A-Za-z]*" "\\1" LLVM_VERSION_MINOR "${LLVM_VERSION_STRING}" ) + + if (${LLVM_VERSION_STRING} VERSION_LESS ${LLVM_FIND_VERSION}) + _LLVM_FAIL("Unsupported LLVM version ${LLVM_VERSION_STRING} found (${LLVM_CONFIG}). At least version ${LLVM_FIND_VERSION} is required. You can also set variables 'LLVM_ROOT_DIR' or 'LLVM_CONFIG' to use a different LLVM installation.") + endif() +endif() + +# Use the default CMake facilities for handling QUIET/REQUIRED. +include(FindPackageHandleStandardArgs) + +find_package_handle_standard_args(LLVM + REQUIRED_VARS LLVM_ROOT_DIR + VERSION_VAR LLVM_VERSION_STRING) diff --git a/third_party/enflame/include/triton/cmake/json-version.txt b/third_party/enflame/include/triton/cmake/json-version.txt new file mode 100644 index 000000000..c294f65bf --- /dev/null +++ b/third_party/enflame/include/triton/cmake/json-version.txt @@ -0,0 +1 @@ +v3.11.3 diff --git a/third_party/enflame/include/triton/cmake/llvm-hash.txt b/third_party/enflame/include/triton/cmake/llvm-hash.txt new file mode 100644 index 000000000..28e49cc1f --- /dev/null +++ b/third_party/enflame/include/triton/cmake/llvm-hash.txt @@ -0,0 +1 @@ +a66376b0dc3b2ea8a84fda26faca287980986f78 diff --git a/third_party/enflame/include/triton/cmake/nvidia-toolchain-version.json b/third_party/enflame/include/triton/cmake/nvidia-toolchain-version.json new file mode 100644 index 000000000..0436299a0 --- /dev/null +++ b/third_party/enflame/include/triton/cmake/nvidia-toolchain-version.json @@ -0,0 +1,9 @@ +{ + "ptxas-blackwell": "12.8.61", + "ptxas": "12.4.99", + "cuobjdump": "12.8.55", + "nvdisasm": "12.8.55", + "cudacrt": "12.8.61", + "cudart": "12.8.57", + "cupti": "12.8.57" +} diff --git a/third_party/enflame/include/triton/docs/Makefile b/third_party/enflame/include/triton/docs/Makefile new file mode 100644 index 000000000..6a9b50be3 --- /dev/null +++ b/third_party/enflame/include/triton/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SPHINXPROJ = Triton +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/third_party/enflame/include/triton/docs/_templates/versions.html b/third_party/enflame/include/triton/docs/_templates/versions.html new file mode 100644 index 000000000..e3bbb1dd9 --- /dev/null +++ b/third_party/enflame/include/triton/docs/_templates/versions.html @@ -0,0 +1,27 @@ +{%- if current_version %} +
+ + Other Versions + v: {{ current_version.name }} + + +
+ {%- if versions.tags %} +
+
Tags
+ {%- for item in versions.tags %} +
{{ item.name }}
+ {%- endfor %} +
+ {%- endif %} + {%- if versions.branches %} +
+
Branches
+ {%- for item in versions.branches %} +
{{ item.name }}
+ {%- endfor %} +
+ {%- endif %} +
+
+{%- endif %} diff --git a/third_party/enflame/include/triton/docs/backend/ldmatrixOperand0.svg b/third_party/enflame/include/triton/docs/backend/ldmatrixOperand0.svg new file mode 100644 index 000000000..745f411db --- /dev/null +++ b/third_party/enflame/include/triton/docs/backend/ldmatrixOperand0.svg @@ -0,0 +1,16 @@ + + + eyJ2ZXJzaW9uIjoiMSIsImVuY29kaW5nIjoiYnN0cmluZyIsImNvbXByZXNzZWQiOnRydWUsImVuY29kZWQiOiJ4nO1dWXPiyLJ+n1/R4fM66NS+TMR9wDbgXHUwMDE1sME2cGKCXHUwMDEwO2ZcdTAwMTFcdTAwMDZcdTAwMDGNJ+a/3yzaXHUwMDA2XHUwMDE5kFx1MDAwMVx1MDAxYrBwQ6+WQCpU+WXml5WV+c9cdTAwMWY/flx1MDAxY7mjTvnor1x1MDAxZkfln0W7WS917eHRn+b4oNzt1Z02nFwi4597Tr9bXHUwMDFjv7Pmup3eX//97/RcdTAwMTNW0Wn9+lS5WW6V225cdTAwMGbe9z/4+cePf8Z/e+7TLVx1MDAxN127XW2Wx1x1MDAxZlx1MDAxOJ+a3oopNXs07rTHt+VScqWJXHUwMDEwkzfUe6dwO7dcXIKzXHUwMDE1u9krT8+YQ0fts4vicyaETntcdTAwMDJdP6ZcdTAwMWGaxFxuenrXSr3ZTLmj5q/vZFx1MDAxN2v9rmdMPbfrNMpcdTAwMGb1kluD83jm+ORzPVx1MDAwN57A9FNdp1+ttcu93pvPOFx1MDAxZLtYd0fmXHUwMDE4QpOjv1x1MDAxZcJfP6ZHfsJPIUqYJbHgSFwiheCl2OT8+FxuWDBpcYGFoExjXCLpzNBOnKbTNUP7XHUwMDBmLptf08FcdTAwMTXsYqNcbiNslybvcbt2u9exuzBl0/dccl+/tJhcdTAwMGWtVq5Xa+7MwV55/Oy11FopyeXkhLlL57w0loK/vVx1MDAwZqddenk47X6zOVx1MDAxZJg5XHUwMDEx8UjO9DP9Tsn+NcNYKIFcdTAwMTBnSjA0fSTNersxe7mmU2xMhWJ89N8/PyCMxPONZoSRYjiJXHUwMDExWVlcdTAwMTafeUdF++FiKnxxc96oVNmoXuxcdTAwMDReXHUwMDE2XHRTllx1MDAxNCCLhHGBXHUwMDA0kmROXHUwMDE2lYVcdTAwMTEmXHUwMDE45kVRprYmi2yBKLJ5SeRaU61cdTAwMTj7dpKo/SRcdTAwMTE0gEJMS7a6Xrzvpt1o86adIc+hYmSAn6q1XHUwMDA0+1x1MDAwNrKoNieL/ymQXG4pXHUwMDE0PiqHXHUwMDE4S6o5w0h9N0FcdTAwMTTUV1x1MDAxMLVcdTAwMTScUjBbK1x1MDAwYmK0/Zx9XHUwMDFllnN5XHUwMDE0iuNwI1q8umtcXO6/IEqyQUG0eUlVKlx1MDAxZlx1MDAxNUSmKNgp7flS30VcdTAwMGV9TbNcIoxcIiXI6vqwkL7C0ZPQaJTMXHUwMDE0k6mbSCvciEe+gVx1MDAxOIpccophpVIuav1hfci4ZpIxKvZPXHUwMDBl3fJPd6FcYkpfqoJcdTAwMDVhgnO+hn/YS1x1MDAxZodEIalcdTAwMWZ7l91UMn9/g4poXHUwMDBm/ENcIlx1MDAwMuJcdTAwMWZiYoFOQEQppLimks+LJYG3KClcdKGKY4xcdTAwMTgnc2KKKIBcdTAwMDbmLZhmeypcck7bTdWfx6ImLUVcdTAwMTGWUlx1MDAwM1uER+yxz+ZdUbtVb47eTPBYnuFxoqM3h8LNetWI9VGzXFx5K+9uXHUwMDFkaP7ktOt4xLJcYrew6+1y97w0O3SnW6/W23YzPX87+Kbls1x0n7RcYvfMeq9szo5cdTAwMWbSx1Cp8OzRqYNCpZLmr9VcdTAwMWSU+0j1pjPKh+/j9ehj1+43htF88FHJlqFyg5ZhXHUwMDA3qCSgSjVccqhcdTAwMGZzXHUwMDAw5Sqg9GdccphyQlx0Qau7a+mzYuo027jr313V71x1MDAxM5XjXFynXHUwMDEy2lx1MDAwM/q61FTuXHUwMDE1KFx1MDAwNdJcdTAwMWNcdTAwMDEsd+fQca+oXHUwMDFlQLlcdTAwMDFQstmjXHUwMDEzUDJOudBcbq9O5Vx1MDAxM6dccpLouNlonNdcdTAwMWFcdTAwMWRUuY4l6F3wQbnUUm4wpvR+fFx1MDAxM4RcdTAwMTNrXHUwMDAyXHUwMDE0STEpuCBcdTAwMWYzlVx1MDAxY1x0Slx1MDAxNFe7s5VfXHUwMDA3S7xbWOJdwVLMXHUwMDFlnayAUZheydYwlZHiVfsh25S1mo1k7zEnu8XjcPBRudRU7lx1MDAxNyo18MoxZFx1MDAwZaCcXG72noGSz1x1MDAxZX1cdTAwMDUlzLcmiNPVTWWe4cJ5pI+cXHUwMDA0i9baXHUwMDE37YfM+ShcdTAwMTF8UC4zlVx1MDAxMu1cdTAwMTcoOXwjscuI5Fx1MDAwMZQ7s5Qgg5RcbqH16qjMXHUwMDE1q1x1MDAxN4n2qFx1MDAxMrrrdk7i4evm5UUqXHUwMDFkfFQuM5V7hkrAi5RcdTAwMWPOXHUwMDFmbOW3hCUmkmFOVl9cdTAwMTe5L0a6bHB6XHUwMDEyPS/fRYrk+Ge2werBh+VSY7nBJeJlwVx1MDAxZaGZZoIojT5cZkrKKVVK7S6t5utASXZcdTAwMGJKsitQ+q6XXHUwMDBiXHUwMDBlrlx1MDAxMOd49VWRdu7iXHUwMDExq24yKZ6qTid3nz1uoT2I9Sw1lfuESZPlJlx1MDAxMFjaXHUwMDAzJKdi/V0giVx1MDAwNVdCYbZGXHUwMDBly2WFjlx1MDAxZaJcdNuuZduql7xcdTAwMGI9ZVLN4GNyqZ3cYf7ApzHJkFx1MDAwNiVcItnv4Lv+fpjkUlx1MDAxMVx1MDAwMUK6uu9cdTAwMWGJ187il/gmkmo+t6utTKKE73DwMbnUTu5cdTAwMTMmsdBASrHGXHUwMDA3Q7m/oPRNtKMw64pQPiUmyzB5XHUwMDFlzTjPvbZcdTAwMWJ7qIRy1zTf1K3+IPiYXFxqJ3eYPFx1MDAwMH5cdCHSuMtcdTAwMWNzjzpcXFx1MDAwM5VcdTAwMDRRxqRkO9yo8XWopLtFJd1cdTAwMTEqPbHVWVNcdFx1MDAxMqilIGvsXHUwMDA0uFx1MDAxY0Uj3frIuXDOO9l0uow6PJxcbj4sl5rK/YIl5lRRTJDYXf7rXHUwMDAxlruDJVx1MDAwMj9KUbmGXHUwMDA3e/0sRvdnuXAsWWaF6HVEs/BzKfiwXFxqLXeVP7AhWFx1MDAwMmCAfvBcdTAwMWRmwFx1MDAxZWC5M1hqwpVCeo1cXDs2TLpcdTAwMGVqpFx1MDAxZYe5tN2Tt/F2fNRcdTAwMGY+KpdcdTAwMWHLfUMlYojK32Opcr9ROX7XXHUwMDAyVDLin1x1MDAwMqsxJ+YprW4sXHUwMDBikWpGnLFu5qKTzz/Uonb3XCJcdTAwMWR8aolcdTAwMDF1mmHNXHUwMDE0/OZcdTAwMTRNXHUwMDAzYK+rXCLSXCJUS8U5p9KzcrvJXYSrbe5cdTAwMDfRJUBcdTAwMWFcdTAwMTc7py/H4GiyZVx1MDAwZsvhzujx7jKn9VnpPH5zNzh6OVx1MDAxZlx1MDAxOFj2XFy761x1MDAxZdfbpXq7OvuRcrvkc6Zp99xcdTAwMTOn1aq7MIykU2+7s+9cdTAwMThfN9ztOsNa2Z5cdTAwMDNcdTAwMWFcXNn3XFzHXFzu7bOc/u/HdI7GP0z+//efy9/NPG//w/vv+oD1j89qxaWgbFxy0plHhVx1MDAxNqnl6mLwM2azVDv0VG1cdTAwMDQ/t1x1MDAwMGuwo2aTXCJCUlKp8Wx9XHUwMDE4XHSGVlx1MDAxMICKZlx1MDAwMqzpp+rD+FwiXHUwMDE2bKikmGJcZnRcdTAwMTHGXCLZ9DZcdTAwMTNcYiNLMo1cdTAwMDVcZkJcbqw1wp6n9LJiQjGT0m+/5Vx1MDAwMdKe6349pP2n3LxC87O9KdBz3+RbXGY4UFx1MDAxNKs18vwqZ7e5XHUwMDBiJyfVU9+9jJepuuWktVx1MDAxN6A3gsmk0lx1MDAxYchcdTAwMDSdXHUwMDAzvdBcdTAwMTZcdTAwMWLXhuKYaI+aXGYg6JGWXHUwMDAwep9l0lx1MDAwM+g91/3+oPcjzNLf0nOFwP6pNVxmfXo4ynSGjV7nuZpcdTAwMWLc1o5j+jopgod5U/eNXHUwMDEyxjA8cSkxntZ8+kWgYS64XHUwMDEw4/NUYD5v+Fx1MDAxObbg88DujFSr2aFuRlx1MDAwN0hiKS2YYlx1MDAwYjeOYmzUXHUwMDE0MFx1MDAwNaakgK/hseyTpVx1MDAxZrNxlHBcdTAwMWY3PjC49tBmbXE59qeQoJp5q8K9Yc1TZfjKmod2t3Ntu4lKpVd2d8ugfW49y6Y9XCIwIdP6Q1baa1xm5qy0MDZarlF4gURoJa9Q+p7EOr1erVLIXHUwMDBm+Fx1MDAxZVSGQshcIqBcdTAwMGY1o1Qqxub3rWlcdTAwMDBcdTAwMGZcdTAwMTdcdTAwMDSYNuJb8sxcdTAwMTfWbLRcZrdn0uyIIEKQ+WCWJIxyLnxcdTAwMTZkp0aZRjDnxftY/SnfiyZT7ZNBk1x1MDAxY4zy5MpfQK7nJ3dDfre3pODcXHUwMDA2cfC8Ned6dUTfVvqDZ+e4L5Pdu97orpejXHUwMDAyXHUwMDA1v1x1MDAxOCtBXHUwMDE0ni6WII/waLXypFx1MDAxMU0hXHJSKsFcdTAwMDFcdTAwMDfvSG3H7+agVSRChFx1MDAwMK1cdTAwMTfaY1M9XjdFSoJ3YKLr1NQwmkM42DAtwWgvToQ6INxz3a9HuN+Mm1dofrI3hXjpXHUwMDFiXHUwMDBmJ1hcIlwiyVx1MDAxYeFwJz2q3MZOn5Fg55yeI11cYlxy96B2XHUwMDEyXHUwMDAwnnLQp1x1MDAwMoSTaM2nRvp1Q520gOwgmFx1MDAxZKoo3Y5cct9cdTAwMDTgwVx1MDAxMVFCXHUwMDBisiy2dsD7t8e7XHUwMDFmycbghPpcdTAwMDFeXHUwMDE46THxppVcdTAwMDF/3LsptFwikYtcdTAwMDSPsjuJUVJcXDqPwVx1MDAwM/wylk2ppc26LrjElFx1MDAwYi+tea3TJC3w7Vx1MDAxNWZcYiH9ufLrvlxuQDFcdTAwMGJjXHUwMDE4psKL0p6Xs2xcZlx1MDAxNMNsgN6jxelcdTAwMGaz7Hr74et4tu/NP8+0/VDLkW8xcMZcdTAwMTHjpnLTyqBcclx1MDAwYvdp0Gd65LTSlWbyisnTcyd4oJ3tkYCRZSp1SkxcdTAwMTTVks1GwpQgXHUwMDE2+CtCSvgj0Lai4dhSXHUwMDBig2DamiPYQmEqwbnYo21cYuLN0XdcdTAwMDDodEvl7o//+/E//Cf6e7fw87n1KuDD/EPo8zbcmDWZkiOw33x19PWr1ZNcdTAwMDJ6anDcXHUwMDFjUPQkXHUwMDFmMr1eLvjo42BcdTAwMDFcdTAwMTGXinJqfJY5+ElcZkZcdTAwMTZgyc26XHUwMDAwUp/K5PKFn7ZcdTAwMTZHoFx1MDAxN4BcdTAwMGbMitJcXLF9ymteXHUwMDE5fde7Rdz1TlBGkW/sXHTcX4bW6Fx1MDAwMvSEdSleJbVcdTAwMGKV5vQxcnxcXFx1MDAwZvFi4DFGuFx1MDAwMFxiXHQsJFx1MDAwMT+Pq9m4k6LE0lowrLCmniqcX4QwbVx1MDAwMvxqlzV6d1x1MDAwN7DL3Vx1MDAwMuxyo1x1MDAwMLNccn1d6EUy31hcdTAwMGZcdTAwMDe5MvX6V19gfYjf50JXPD9cdTAwMTioq5M+PSvfXHUwMDFl12qBx1x1MDAxONXY0uAjamFiJVx1MDAxNJFZkHGhLFxubFxcmk5bcjsgw5xbSDHMpdaIMVx1MDAwZu2ZXHUwMDAyzlx1MDAwMqeecUU0wlxiK6/P+Vx1MDAwMj9FXHUwMDE1MFkgi8GG30xcdTAwMTDnVVh/TNrGjdFx1HpcdTAwMWGmnFS5UspWU/XE8elj6bp1MXlyY8BcdTAwMTb7ZpQhbJlGO1SDXHUwMDEy5EBcdTAwMTnBI/e8q2p3xkpMXG6g0ZKARlUw1/LlXHL/Tka1s/DRLFx1MDAxODebqOEnR+ZcdTAwMTWaXHUwMDE3oen1/vD+u7ZcdTAwMWVh2tdcdTAwMWbmXHUwMDE0XHUwMDBiSsVcdTAwMWEhpHNAjOvW+JNO3kc7jqg8ZvRoXHUwMDBm9Fxis1x1MDAxMLA7yoRGiLLpZV7VXGIzKdbgMUvxJoS+7WVf8NOlwVx1MDAwN1x1MDAwN4+cU1x1MDAwZvecdFx1MDAwN1x1MDAxMqA4XGJcbvgu3Fx1MDAxNdVGJ3ImnPxtXCIyUG6vd1lcdTAwMWOe23W8SG0g8Jy00EBcbkyvIMGmztVEa1xiXHUwMDBiwVxcXHUwMDAyWPS4YVx1MDAwNNWvePluauPtUvKswKypJPy8eSZ8Mzi5McCYrq5cIt43XHUwMDBlgVFcdTAwMTEmykyoXHUwMDAyp1x1MDAxOFEgnjNBZkaZXHUwMDA1elGAaGkutPRcdTAwMDR5X4LMhFx1MDAwMcWmIINcdTAwMTJLxrbkepiEstX8e3AuxvZ1g1xmelwiW/94JHAlf/KNzP1cdTAwMDLH5My/r4JcdTAwMWFcdTAwMTj6kHK79VK59CP8s97bLZNYfOdccpCK91v4Yl9iobQ2OUZr1Erm9HSYXHUwMDFkZZxyXCL3XFzKjKrDarhcdTAwMTdcclx1MDAxZdrnNiBryyxjMDLuN6TFXFxcIlx1MDAxODNcdTAwMTE0cNSAhXHv0u2H0F3UxUXoVpagQlBwe4Vk0uOVTDc5XCLQQoxISbQpsTOfqi2ZXHUwMDAyr1x1MDAwNlx1MDAwNbRMx4dsXHUwMDEx9fdXjbpcdTAwMTVar8F7nzulXHUwMDFlzvbtylx1MDAwMDux1EUmnnjO2nsgnsKYeVx1MDAwMp5cdTAwMWbRXHUwMDFjeWNt065tSHCMMdJC8E9Zn1x1MDAxZPSH0mDIJJc77Np26EXzsY2471x1MDAxYlx1MDAwZU+DhLnNfeCwXHUwMDBicEpXh2Y27PaSjcK1LvZzIVx1MDAwN+OLZD+VXHI8NPEqzd/Be1x1MDAwNJHhRCr82aWVxaZDLSCTXG5ZyPNcdTAwMTLTgb2ucGKgvmDrvpGtXHUwMDAw9ufLWyRcdTAwMDdLgtfoxTI8e4zk2s3n6lP4vi+StexIVIK/0o6pNlRcdTAwMTgojDRdLsQ8U1HY4oopTSnSTHyWqSxcdTAwMTbIddb6XHUwMDEwxzBcZipcdTAwMDNcdTAwMWVcdTAwMGL9XHUwMDEwl1C71f5qo6zBXHUwMDE3Zf71/UD6wCogzFZPO3WPk6n7m87xYNh4UPfXP0+d03zw0041scBcdTAwMWTQXHUwMDAwMMqJwni2XG6DIOCRXHUwMDExXGb+XHUwMDFhXHUwMDAx51xmbUftr4EyrrTxuPappdHvXHIySv1BxiU1S+prVJuOPV2JzFO+416oQXyIYvFcdTAwMGW7I8FcdTAwMDdcdTAwMTmyxmFf002WUTq33Cck8FwiRM1KmlaUfypOv33aQ8DcSqzFb1FcdTAwMDJ+v2mPb1xcnPuCklx1MDAwMC1nXHUwMDE4c7p6adtcdTAwMTSr3UVtXHUwMDFjP4mS1nNOnkd55PkseKBcXFx1MDAxMlx1MDAxOWeWXHUwMDAwjCpcdTAwMDCjxsxbXHUwMDE5ZIpRPF7pJVx1MDAxYSH22bU0XHUwMDFmQygsrsFcdTAwMWOLxemd1MKaKilgelx1MDAxMFxmgsyvqzGpTNRtj5Bp1sTAiYcxXHUwMDEzU2nQs0N0ibXs/VxuMl/bbqpmd8q7xanvzVeypPhDoOXcl1x1MDAxNGqQatNUbnVvNZGu2HeMX2edVJrK8u3NqHDWXHJcdTAwMWVm56JcdTAwMTRcdTAwMDBcdTAwMTBEmOacmaiMntskxYE1KrBezGyl0NtxV1x1MDAwMX6WxlxiL+41tlx1MDAxY6UmX4Yy7Fd96HvB1ICqXu07/d7XIPW9+29cdTAwMDCsfjtcdTAwMWGF9t0rYchcdTAwMTRljK9ROyhcdTAwMWaLRirVRkXZKeeckEimdXlcdTAwMWPAXHUwMDFkTthcdTAwMDJqXHUwMDA21lx1MDAwYjxcXGmkxaOwfsX/tSlTXHUwMDAw9lVwqkEs8WxMh4NcdTAwMDE0nFMoxYjcUjEhYcHouKJSXCJcdTAwMTgh3LNcdTAwMWPyrFx1MDAxY05wLJFFOUem2Vx1MDAxMVx1MDAwNaaC5sP/XHUwMDAwXHUwMDA3cJD9wv8vx+DoQylvn7f0qFmisVx1MDAwMrvqh1MoOzpsdHy98m42OobenXbzmpvw6SX/8P679uZm5LvXXHUwMDExm1x1MDAwZXdcZmG8urP9XHUwMDE0UonujdNNJVx1MDAwNuJcIpzhtXv+XHUwMDE0wH5LS1WBwFx1MDAxNsWUKSSQYt6KXHUwMDBmL+42xWBxsFx1MDAwMoGWpubQVjRcdTAwMDHQYZNrJc0kmHQgvsCcI0OHMUXERJmBXHUwMDA1wO85r9ssZ75Jhj/ogcDqXHUwMDAx/zk3r9CC6d6UXHUwMDFh8E9PIWBGtDa7uFZWXHUwMDAzXHUwMDE3z41q+fS6dc5cdTAwMWPWzPVcdTAwMDY5J1x1MDAxM1xuYC/vZWqAXHUwMDAyzDGjIK+CXHUwMDEwosEvXqRcdTAwMDdcdTAwMTBBXG7Oacm3VlN0XHUwMDEzelx1MDAwMFx1MDAxM07hqyxe/zmoXHUwMDAxz3V/YzUg/OtcdTAwMGKCXHUwMDFh0Fx1MDAxNDO8Rppau4Hr0btirHo/inXL0aH9WL053j81oIRFXHUwMDE55aZEXHUwMDFmeNVcXM/ygnEzcXhxboojeFNcdTAwMDC2m8guqcWAxFx1MDAwMyPBXGY09Pz2XHUwMDE3XHUwMDE4XHUwMDBm42op5I/z3fSjXHUwMDFkO39uqIdaeXSKqjdx51x1MDAwMPnXK39B8bLZmd2UkUf+m7RcdTAwMTFcdTAwMTamjePK2E7epNONc5c2Q4nrjFvqkPbPenv/sM2ZJeV4vUlSJqmaxkVei4dSeFx1MDAwM5NcdTAwMWGugpHckqcvrPFkXHUwMDEzXHUwMDA1fjpcdTAwMTXeXHJr3spGXFxcdTAwMDLlMJRcdTAwMWVcdTAwMGJNkJjH+7g0jV5ayuyA91x1MDAwMODdf87NK7RgujemXHUwMDA0/Fx1MDAwM/WIUSbWsfDF4nEjo93cdbVw2Wb5XFymXWxcdTAwMDWwhPAyLVx1MDAwMF/Zklx1MDAxY1g/14RcdTAwMGKO0JyJZ9xcdTAwMTKKcZN4Y1IuXHUwMDAzrVx1MDAwNjSlZpDLXG6JXHUwMDFm1MDvrFx1MDAwNoTwX1x1MDAwMoA7aqLoXHUwMDFhi+z6odHs5q8yjdFxol1P93UnW0DB01x1MDAwM1x1MDAxNOi8wbgwtSM0fM2pP/TSjMvS47RcdTAwMWbMQVxyXHUwMDEyT1x1MDAxObhcdTAwMTe+z7Flmu2g8TaWLVx1MDAxNVD6gKevKWGCcZ/kzinkI444d+LDk5tG4T594iYuo5Gr7lx1MDAwMfKvV/5Gnr70R7dcIlx1MDAwMlgqXWM/T/VGUzZcdTAwMWHE+mSYLZw0XHUwMDFlO7H0aVx1MDAwMFx1MDAxN/iWoVx1MDAxYow8XHUwMDE1Jn1cdTAwMDWeNmeEzqLb9PSimFx1MDAxMK2khoe0XHUwMDFkXHUwMDFlvyEjj6UkSnLlUzntXHUwMDAwec91v1x1MDAxZfJf5+wr/zJcdTAwMTTmNuu0XHUwMDBiSYZcdTAwMWaqXCJ8XCIj0bCrXHUwMDEzLaqzXHUwMDE3jec91Fx1MDAwMtiiSlJJwc9XZJGNh+kgJulcdTAwMWObbbmBdvXh+5tujH5ccuNcdTAwMGZawHPd30BcdTAwMGL4b+59p2VcdTAwMDFcZomZ8jcr64E2wpFTXHUwMDAyt0KOO0qoaPdM1uLB01x1MDAwM3PVXHUwMDE5XHUwMDE14Fx1MDAxZSMy3qBG51x1MDAxMmgxePYwI0TBc5eftv1+qXkgXHUwMDAxPol5Zmu8XHUwMDE2nDDFwFx1MDAxM6FqbiPhOP9cYjRcdTAwMWLdyVx1MDAwZS5cdTAwMGU49rSB/VhenrJcYlx1MDAwN3E2ulx1MDAxNlx1MDAxZSlTK9cunqbFpZr1YvmLilx1MDAxOC9cdTAwMWbFKil66l3Ivrv1l1x1MDAxMN/cXHUwMDFjpVx1MDAxOVx1MDAwNydVrVx1MDAxZbAnXHUwMDExme2epa+eMlf96DNrj2KC+nX4K3adXi9Us91iLVxi0CXM4pyYiFx1MDAwNDFNk/jcNjAmLKZcdTAwMDUxXHUwMDBiplx1MDAxMqnPXHUwMDA16X9t5J1cdTAwMDcvMDW4vpZKMlNiXFwsaPKHTdpcdTAwMTBcdTAwMTfUaG+Qdm9cdTAwMGbSSVxuPDHp+z7dXHUwMDA3vtpcdTAwMTB/zMX0KKS5LVTctIfhZPXMb1x1MDAxZCudx3ODRJhWXHUwMDFh92fd4UXjtlx1MDAxNsB+OEtcdTAwMDPKQlqaXHUwMDExUyxCXHUwMDAzUsl8XHJFLS2OqGZcZoRXb6n42YYyRzhDZvVrN4bn4D1cdTAwMDYtJ0Syd1wixXRcXF9TrLFcdTAwMTG5VCbZuFx1MDAxYmnf5s9zg+Z1/H5QXHUwMDBlIItcXCU1XGZcdTAwMDGLNG0lwSBcdDybXHUwMDFhZlx1MDAwMFx1MDAwZXOBNEXSlFFccjLANaFCMFx1MDAxNVCrdFx1MDAwMLh5bVx1MDAwMOC+9FD6u5nYVEKia9QqzVx1MDAxZacq6LRYdE6KXHUwMDBm9nA4QumL0+BcdTAwMTdcdTAwMWEgXHUwMDEy7LWpkaHAdSPIW1x1MDAxNPRcdTAwMDXOipq6YcZcdTAwMDWFP57o+kZcdTAwMGLMXHUwMDEwXHUwMDBiSynp4iayS1x1MDAxOVwiXHUwMDAxv1x1MDAxMlQx26PyXHUwMDAzXHUwMDFmZogvW1x1MDAxY1OtcusrqOE7t/88J/QzxFj44lx1MDAxNMyNXHUwMDAyxUHWXHUwMDAw6k+nkozHT2O8clx1MDAxNkftQTrefXL9PO2g8UFqackxqEKAXHUwMDAzVWo2W8Ns1JJcXGiTt1x1MDAwZVRtO3xcdTAwMTBcdTAwMGLQyJpq8H4450J4lOjU9s5CXHUwMDE0c1x1MDAxOJfJwz6Y2q+NxPrN3czHP+c4+y++YE206Z+6Rth1II/laazjNGotUXk4LtxmunpcdTAwMGaTqcHIWkphY0SFKXyqZiOxYFx1MDAwMC2EXHUwMDE4N42h3mB7o+3ax1ukXHUwMDE1XGaEguZcdTAwMDSDu3iXJVx1MDAxOH1cblgyez84jGnOcVx1MDAwNi9NaO1XKXxcIkZHXHUwMDE5PMzb+eYolkup7OnD481VMzk8LL+8XnlHeyzfnXTzmp3uXHLpXHUwMDAx9U47SaWV6ae4elx1MDAxODdD0lrW4+Ik1c6HSvHHIbuxK/unXHUwMDA2sNCmS5w02ShSMqRmY7qgXCLAN+Smklx1MDAxZVx1MDAxMVpsq1fdZrZWgdNtdsbgZclcdTAwMThcdTAwMDc9XHUwMDEwXHUwMDAwPfBlgTTl373rV79cdTAwMTaC2Vx1MDAxYes56KSAyGn8xlx1MDAwZdXq7lXlvn19XHUwMDE13j9FQDhcdTAwMTBrwbSEXHUwMDA3zUzlibmmlVhbXHUwMDE0UaVcdTAwMTHS3lxugFx1MDAwMdRcdTAwMDOSXCKMuE+p14NcdTAwMTbwXFz3N9BcdTAwMDL+fWp9WYGpboY0Xaeec07G083IRST98+zutHVcdTAwMTmWXHUwMDE3J05cdTAwMDC3Yc2G26g23eKlMpmob2r1vWCecEtcdTAwMTJcIk27NXjvlrIxXHUwMDEw6CZMYW5cdTAwMTfWM1uekCGx2TNcIlXAufxmXHUwMDEzMtL1Zjneb13b7o7bdixcdTAwMTnC9sJuXHUwMDEy+7vvprC3VnL1jVx1MDAxMj/zvKNKTdJH+fthe9Sxr7LJu+DhdanVpsQyO4tcdTAwMTExXHUwMDE5XHJIqLnKXGJcdTAwMThbJo+amlxuqlxmb2d923hcdTAwMGVcYoZhXHUwMDFhgVx1MDAwMlKVXHUwMDBmi6fK4lxcXHUwMDE490pKwoB0zKdlgFx1MDAxOaCaL/Xf4y3WrcQ7PTfUL+tyzi7kzyPVg+V+vfKOePy7025e81x1MDAxM74pXHUwMDBmnvj2YiDgLUhJ1+hyXflcdTAwMTmOdlx1MDAxZvvPTi17e376XHUwMDE0a1x1MDAxNJxEfVx1MDAwZlVcdTAwMDE2XHUwMDBi3VJhaYqm0bmFcPigZVx1MDAxYU1TwsDbXG50oouJ5Jlm5ctKplx1MDAxZNRAXHUwMDAw1MBcdTAwMTfmw/iWJ1x1MDAxNlhzgldXXHUwMDAxI6dcdTAwMTGPSVx1MDAxOc7ooZPuN1x1MDAxMlx1MDAxN/FhbbCHKoCbMvzwlM2bhLeGjKdMXHUwMDEyqFx1MDAwNuD6SlGynZD+hnRcdTAwMDC4NFx1MDAxNJw6Qlx1MDAxNi+iXHUwMDFmlIDnur+BXHUwMDEy8GPxXG77VlFcdTAwMDBiJUwhoDX6hDdGyXzpKinS0Zx9O2pcXHSvXHUwMDFlesFTXHUwMDAzcyReWubZM1x1MDAwMnBcdTAwMTFcdTAwMWHPlk5cdTAwMTFcdTAwMThZMC9UmlxmV4S301x1MDAwMlx1MDAwN//atyFcZs34XHUwMDEwideMXHUwMDBirfVu+uJwUJCeIPCXkfgrxy59+aZcbp9BfJ7I+0be/HtcdTAwMWOCjIDdomuUOHy/rXJQMTvuyc2VZkgrzlx1MDAxMOFzeW5UW2LcP1x1MDAwM2BB2Zb2QDPQXGZcdTAwMGKrky/sXFxcdTAwMDW62zQ72lxc24CJXHJY0GP3/V7rb+Tsgz12P4v/1Vv2nEyw9sPeeZtd35uv087nj5fHeWR3OilcdTAwMTdcdTAwMWXm0Wuj86NBvTw8Xih05mWmZaxcdTAwMWZcZlx1MDAxMstm1v/5949//1x1MDAxZrZywsoifQ== + + + + 000111122223333warpMatOffsetinWarpMatOffsetorder = [1,0]MKStrided Axis0880stridedMatShapecontiguousMatShapecontiguousSliceMatOffsetstridedSmemOffsetcontiguousTileNumMatscontiguousLoadMatOffsetContiguous axis diff --git a/third_party/enflame/include/triton/docs/backend/ldmatrixOperand1.svg b/third_party/enflame/include/triton/docs/backend/ldmatrixOperand1.svg new file mode 100644 index 000000000..2c9891e1e --- /dev/null +++ b/third_party/enflame/include/triton/docs/backend/ldmatrixOperand1.svg @@ -0,0 +1,16 @@ + + + eyJ2ZXJzaW9uIjoiMSIsImVuY29kaW5nIjoiYnN0cmluZyIsImNvbXByZXNzZWQiOnRydWUsImVuY29kZWQiOiJ4nO1daVPq2Nb+3r/COvdrk7vnvVdXvVx1MDAxZlx1MDAxY0BcdTAwMDRUUHG61WVcdTAwMDVcYlx1MDAxMGSSwYGu/u/v2lx1MDAxY49EIFx1MDAwMlwiXHUwMDE4jtJdR01Cpr2eNVx1MDAwZv/8sbX1o/fU9n78tfXDeyy6db/UcVx1MDAxZn78abffe52u32riLjb8u9vqd4rDI6u9Xrv713//O/qGU2w1fn7Lq3tccq/Z6+Jx/8O/t7b+XHUwMDE5/lx1MDAxYrhOxyv23Gal7lxyvzDcNbqU4Gp861GrObwsNYJKoShcdTAwMWJcdTAwMWThd/fwej2vhLvLbr3rjfbYTT92d3vJh1x1MDAxNL+vypy4fsjc5/PJ3f3RZct+vX7ae6r/fCi3WO13XHUwMDAyN9XtdVq33oVf6lXt1ce2v3yv28JXMPpWp9WvVJtet/vqO622W/R7T3ZcdTAwMWIhL1t/voW/tkZbXHUwMDFl8a9cdTAwMThnzFFGcEmVpkxcdTAwMTMtX/ZcdTAwMGbPwIhijmFSXHUwMDEzoaRcdTAwMDaQY7e226q3OvbW/kM9+9/o5lxubvG2gnfYLL1cdTAwMWPT67jNbtvt4JqNjnv49dBqdGtVz69Ue2NcdTAwMWK73vDdM0KFXHUwMDEwXG7MaI+9TPugNKSDv4Nvp1l6fjvNfr0+ujO7I1x1MDAxZaCd0Xf67ZL7c4mpMopcdTAwMTDJpUA6eNlf95u346ert4q3I6pcdTAwMThu/ffPd5Ajl1x1MDAxMEaOnFx1MDAwYmDMcDE3Nd5cdTAwMGUqO4mrZpxmzzJcdTAwMDfn3mkqJi4uN4BcdTAwMWGJI1x1MDAwNKPCUGKUwl/GqFx1MDAxMZfeUdxoxlx1MDAxNTXUrIxcdTAwMTjFXHUwMDE0Wlx1MDAxNFx1MDAxM6SotTSKSmI2j1x1MDAxMnveY28qXHUwMDExXHUwMDFhXHUwMDFlRoSGcEa10WRuXCI8U7X7zFV5//E21zvNXHUwMDFlpJPF5kM3+kRIzVxmXCJU6yFCylx1MDAxY6Y0YcZcdTAwMTAjgVx1MDAwN1jzXHUwMDBiTVwi92ZGa8a4kZRcdTAwMTIh2TiNSoVLJmhEueWIXHUwMDE2Ws3eqT9cdTAwMTjKXu1cdTAwMThOqNZA8MFcclf81VFcdLfh159eLe+QmPFtklx1MDAxZq82bdf9iqXpXHUwMDFmda/8mth7PipcdTAwMTMvu3ut9mhvXHUwMDExL+H6Ta9zUFx1MDAxYb/1Vsev+E23fjZ5OXxSL/lcIrRcdTAwMWMmXHUwMDAzi9717N7hS3onJMn41lx1MDAxNzVF4FtHUOpcdTAwMTFcdTAwMTHOwqToVL2yq/frbtF9ql30XHUwMDA3941cdTAwMTKNPCZcdTAwMTlcYsdcdTAwMTJcdTAwMDRDOGqJXG7LhFxc4I5cdTAwMTSUIKXj3lx1MDAwMFx1MDAwZvtwwYC0SYFcdMGN0Eoq9i5QakVcdTAwMThcdTAwMDBEVHB8LCjpekFJ11x1MDAwNUo6vvVcdTAwMTcoJSDHllrPr6zdiVi6lUz3suVsvHC3p+XRcfZ28zGpNlxuk4xTq9/w0Y5vTG5cdTAwMTgmdbigNIxRwiWZ355XxUIjmS7w++JVie1fn1RpOXdcdTAwMWF9UGrlXGJmmNXV0TTR01x1MDAwNKVSgDDRiEy1OlCi8qpAgFDMXHUwMDAweSckkapcdTAwMTXeJPtcbrorWy8k2bogXHUwMDE5KiaZJlx1MDAxMilcdTAwMTHmNydcdTAwMGKHrTZ52s5C8ynxdH3ZO+Cxi99cdTAwMDCRaoNcdTAwMTBcdFx1MDAxNFx1MDAxNN7mV9Bbf1NAQrjPW1x1MDAwMcVVN2S06rNcdTAwMTD5cH2IeFx1MDAxOeQzyUqN9XpnRd1umOgjUlx1MDAxYcdIhlCgWihcclx1MDAwMVwiXHUwMDFiXHRJXHUwMDAwhIHUSIGBXHUwMDE4wFxuIElcdTAwMDVjTFPNXHUwMDEwcqi6vlx1MDAwN5SC45dccnxcdTAwMDXFla9cdTAwMTeUfF2gNONbXHUwMDAzXHUwMDFlXHUwMDFlSkCipJxcdTAwMWKUnXIy2fWzkKyT6ulh5367XHUwMDE2N/nfXHUwMDAwlGqjQCmFQMvXXHUwMDA0rOJvVG5cdTAwMTYqjVxujcdcdTAwMTl864DrP7/u6nF5ZKBbXHUwMDE41Fxu9+VMvlx1MDAxYj/tXGbOXCJcdTAwMGZKLiTqjMRQhVZcdTAwMTiVbERcdTAwMTE/MYlKrcMoXHUwMDA1rlxik1xc6WUw+Z9y2StcdTAwMDJM4lEzx4BcdTAwMTJGTI2AUOpcYqOk5MJopXlQN31GoqVogyDZoFx1MDAxMFxiOFIzylx1MDAxMImKg2BSvzroXHUwMDA1ifxcdTAwMTXshlxumdtpXHUwMDFmur3jcrnr9daLypBLjyN0XHUwMDFhQOFd+KRUviE2XHUwMDA1XHUwMDA1XHUwMDAyuFx1MDAxNHMjdO+i5PZcdTAwMGZcblx1MDAxNzvdXHUwMDBlPz2p5e92XHUwMDFlmFx1MDAxZnmEMilcdTAwMWSGooZcdTAwMTOhUU9QZFx1MDAxMqLSXHUwMDAxQ1x004JQVPBXXHUwMDAyUSNcdTAwMWM0XHUwMDFkUFxypdPMytlcdTAwMTBVnKPla8TmXGLLdyPUb158XHUwMDFlRkMvvjqUXG5cdTAwMTmaUMCp0dzoXHUwMDA1fLLmwk1cdTAwMTTy5zHYq+ar243KYN9rxiOPUU7AIai1ojSlUlx1MDAxYlxy4zlWTFx1MDAxYmeIXHUwMDFl6ycyZqmMglCMolx1MDAwNu2YqVx1MDAxMlx1MDAxNJxcdO1cdTAwMTV1Wo73StnmIJKqV1vfgGCrU/I6W/+39T/6J/l7vVx1MDAwMFxmufQ88KPyXfjjXHUwMDAxjjwuJJniaMGQXHUwMDA1hCS96KaLx7nrrpvIqv37ejJ7TeqRXHUwMDA3IFx1MDAxM8pcdTAwMTkm61xiplx1MDAxODLscdOSSeKggNJWyeVoga5cdTAwMDR/4EzXX6ehj1xiIYcq72+IvqP1XCLuaD0oXHUwMDBiJImNo4xcdTAwMWLg3CDS5kZZLHNy2Es/7O/7e6mByVxcqKtrXHUwMDEz/XxcdTAwMDCkI0dJkEBcdTAwMTWg5CAj6/nZgcO5g+qfJlQj0S2XN/chKENpqFx1MDAxOKrMvyPK0utFWfpDUeZ2Oq2H6eHEUGGmhDKG8Fx1MDAwNVKka1x1MDAxNXeQlOn7fLt6l3FPXHUwMDEyve5Nelx1MDAwM7JTXHUwMDAxrSlgSFGSoeHBR6dcdTAwMTmegIJBlCGpUWqIoHo19lx1MDAxZUWrk9hcdTAwMTJcdFxyYGXVyPpcdTAwMTkhzlx1MDAxMVx1MDAxMkneMCCUUFx1MDAxM9Q8X/wyoFx1MDAxMIFcdTAwMTCISK1cdTAwMTJ/mslcdTAwMDBcdTAwMGJeXHUwMDAwf92e2+nt+M2S36zgzl/UuvVSfzKEx4/DSvfqeJ8mXHUwMDA0XGbIY1xcpnKqQ09eXt1cdTAwMTCxxf5w4Vx1MDAxZFRCODJKhVx1MDAxNiSqXHUwMDFjXHUwMDEwOKbitvFcYo18lFx1MDAxYjSKjVx1MDAwMCpccojnXHUwMDAz/n25J69ZXHUwMDFh3dHrh3C7vd1Wo+H38PGzLb/ZXHUwMDFiP2L4PNtcdTAwMTZiVc+dQDaeObhvXHUwMDFji217xlHBjf2MfttcdTAwMWFcdTAwMTHr8I+X3//+c+rRoWRkP7FJXG5cdTAwMWGd74/gz8XZiFxizUqgSilOiYb5M2o97vdbd8lOJ759vttP7JS8hop+Ri1cdTAwMDfmcKYoaGDU5s1N8Fx1MDAxMYKiXHUwMDE0f1xubVBILpWWXHUwMDEwykemVfxI7WjNgUtkXpJcdTAwMDd09V+BXHUwMDE1IJpcdTAwMWFG11x1MDAxM+6MXGbXIFx1MDAwZSBTYMhcdTAwMTaAU1SiRlx1MDAwMcFcdTAwMTe2IXDB0GpA+4ZzS8G/K9t4dfRcdTAwMDS9LMgkwmOyocmE1u+HS6FcdTAwMTYoXHUwMDBlhOJVKkPTXHSVbKVcdTAwMTPZRIvc07tM9HiEcFBx58Yg+FH2XHUwMDA04l9DliFQgedMXHUwMDE4yZVcIoDQXHUwMDE4U/DxO1x1MDAwZbfpRciyXHUwMDExN6sxoyllzpxcbj7eXG5cYio2KFlpbvX+tNfxS15pa/vR765X059+5Vx1MDAwZlD636yNlDQ8bYkqwbVYJJHQ9/xrkfOfxHbhlvBsLH91yDdA82fUkVx1MDAwNlx1MDAwNTNcdTAwMDKQ2pDOtOJILrSgNkWCXHUwMDFiNnZri+KvXGLFafgzXHUwMDBlXG5cdTAwMTjFJdK9tlx1MDAxN5vEXCIjaOdcdTAwMGKmNVx1MDAwM5tqXHUwMDE4jDk9a/54XHUwMDA2Se1cdTAwMDGRhOa75IWAcD+rXCKoUqpFynfT9cvHy1xcOd9MdY87tWteOlx1MDAxZlx1MDAxY4noU+iwctJcdTAwMDBaNaigkWD0NVC+S6zWyZhcdTAwMTJmqXyB1ZdO4pJcdTAwMWFJeESpdKpcdTAwMDD5oqWTb7d5IOG+WWpcdTAwMGJ9kZnO75utkZujdLJau9xcdTAwMTP9zqW+yjbvXHUwMDBloy87UKtzkMiERk5EXHJcdTAwMTV8aptcdTAwMDclQCNcdTAwMWTyZZ1G0yWHmWLtoZFJXHUwMDAyXHUwMDFmXHUwMDE10Fx1MDAxOZ8zXHUwMDAzbERGS7M+JU5cdTAwMDcpc1x1MDAxNZKCvuF9oEwoQZB3zk2Pg+2LVuWqe9nc3j8x3e1U5Tyuo1x1MDAxZitAXHUwMDBi1uG4qlx1MDAwNEAyjmQ3XHUwMDFlkiOgXHUwMDFkQpH7XHUwMDEytDi0WDZYMJ1cIlx1MDAxN1xuXHUwMDE22PRcdTAwMTXQa3JWLkqIy1lcdTAwMTNmvczffKjdXHUwMDEwXG6zQCXnZODbXGJcdTAwMGWUzu/k47lHLo6usu6T6bVcdTAwMDbCj1x1MDAxN1x1MDAwN6lcXORhXHUwMDA2xOFEgaUqTSBo0D+jTHBHoPaj0GgnxMDYnX1cdTAwMDLKXGZcYs1Rb/tG2WagjJlwlFx1MDAwMTVcdTAwMWMtXHUwMDAxOb8wy1BRUYXrJE1Wj7o1L+VcdTAwMWU+uNnIo8wgJUurb1x1MDAxYiZRfk/6xYA51Fx1MDAwNnPQMuKrLfD7XHUwMDAwo1x1MDAwN1x1MDAxOaNFxvqSpZdG4Fx1MDAxNzV6wltThPrKOFO4siDmN3dah36jUJOJfLLEitvXXHLoPbZ70UPkLMe1xZ9cdTAwMTZcXFFuXHUwMDE4sLGYOVx1MDAwMpRcdTAwMGVDJ1x1MDAxMlx1MDAxNGHBhNVcdTAwMGZcdTAwMTWDypFgY97TkzDxXHUwMDE2gVx1MDAxYq0oKr1AXGKbiHuh0mJbVKyzQcXSyCRcdTAwMGVcdTAwMDByPVs9XCK1JIFI2lxmWdn96WU+dHunVbftrVx1MDAxN6ehXHUwMDE3n0uO0neBNpjrN1x1MDAxZW5cItqmS6lcdTAwMDWMwmSlceb7zXji+PqE5rO78X6zV4xcdTAwMWVqJ5xcdTAwMTRcYmOgQqKE0oBcdTAwMDbXSFx1MDAwYnx2XHUwMDFmauFIqlx1MDAxMcZAbVXSSnCKXHUwMDAwdIBcdTAwMTI6vY/MPDhcdTAwMTW2J846O659XHUwMDFlTi2q/Eq/1e9+XHUwMDBlVN+6/lx1MDAwN6B1eNRcdTAwMTS0UsJDa1x1MDAwM22miqBULFDWsFt97Gbqrr9/PbhcdTAwMWU0r65y/eR5XHUwMDA0y+ipIzVBXHUwMDAxhlqEtuRcdTAwMTJI+/qZZo3g0Kjwg+1cImCIXHUwMDFj9zHadFx1MDAxZkW17XWhIPhcdTAwMDY/tFrQMca2spXWcYQsRHuxQG7uyPHIXHUwMDFkNIylJMJoXG5cItB69SVcdTAwMDdbo1x1MDAxNlx1MDAxZFx1MDAxYT5+3oZbm3uN+lEpee+XMy4hXHUwMDE3JJvd6179eN5cdTAwMWZcdTAwMTm0j6WjjOdurCcr5NW+XHUwMDBmTVx0ib257PYzseCjU/5cdTAwMTH8uTgr0KFcdTAwMDVOTDBGXHUwMDE3ylx1MDAxMzlR7Wr7OJ/bvy7Uk5mb7ElSXHUwMDEytXmcgFq/Pm5cdTAwMDY0Ni0/NFx1MDAxM6xcdTAwMDCM9Uwxhlx1MDAxYTtDYblcdTAwMTJWgPYwoIzWXG5hg7JNyynynFh7mHLCOJpFtlx1MDAwN3awiPm5WSNDfqWkmp4m/s1cYlx1MDAwMuf9fEZcdTAwMTC+5vZcdTAwMTObstxcdTAwMWbGXHUwMDA3WGhYXHUwMDA3hECpXGJcdTAwMGLE/9up087ZXHUwMDA1KT3EWaJ245uH5KBxtXl8XHUwMDAwke1cdTAwMTCKdjdIm7Yoxz1jXHUwMDE0tKNR35eGc9twP8psXHUwMDAwXHJcdTAwMDJ8WFx1MDAwZSGesW8+XHUwMDEwOO9cdTAwMTfmXHUwMDAzMrzj3c9uzUwskG3QvfNUJeZ2szxbXHUwMDE58HSxkDzUXHUwMDExdIjP4lx1MDAwM8Y2TMdcdTAwMDdcdTAwMTc2XCKAb3xcIi+IXHJbxVrXXHUwMDFjmtZiNX1EpqRcdTAwMWZoXHUwMDFi/lwiQG2Td/FquMWveJQwXHUwMDFjkDfNMlx1MDAwMqrJvKYy8XRkXHUwMDBl4v1Y4vbsuEVcbt+Y/3Xm9WD+1dHjS/tB8Fx1MDAwZVx1MDAxMud4XHUwMDA1XHUwMDFhw/UxUs5cdTAwMWZUPvN2+peVq9hcdTAwMTlpcFaGdCO571U3XHUwMDBm3dKaVprYqjSJbJVPXHUwMDE0fkpwXHUwMDE47uaCXHUwMDE45Mcram+gnOFqo0WBiFUmUFx1MDAxZlx1MDAxMpDyXFxqsIVxmlCFt6smXHUwMDEwXHUwMDBmXHUwMDE0+ZOGkHyjb8BcdTAwMDfO+/mAXHUwMDBmX3P7iU1Z7o/iXHUwMDAyMlxcyEtmy4rNXHUwMDAyXHUwMDA1ZKXHXHUwMDFkQprs9rh9UW5UTzuF8uPZweaxXHUwMDAxXHUwMDE031x1MDAwZVx1MDAxOVx1MDAxNvPZLrc2Q3ycXHUwMDBmoDVAUVx00razQTBcYlx1MDAxOUE+MMyNMTys1e03I1xinPcrM1x1MDAwMlx1MDAxYer8XHUwMDAzW/lAXHUwMDE1nz9qd1xiXHUwMDA1f+/GVdl69byW0GfH11x1MDAxN7lcYkbtuENcdTAwMTHlwFx1MDAxNNKjzYhcZryC5zCAw1x1MDAxMP7U9nZcdTAwMDI0eSbSYZRypNRcdTAwMDZ3aVSYoqPso33PheYhXHUwMDAx9lx1MDAxMeTx1Vx1MDAxY2VcdTAwMTJcdTAwMDe15KBHr3QxIVx1MDAxYb7i35D/debfR9nXLFx1MDAxNN1UMFx1MDAwNLhcdTAwMTRcdTAwMGJUnYlO/Py6aiq3542b/qHYzlx1MDAxZu6T0lx1MDAwNsLbOGrozeMgOFo748k0+GZcdTAwMWMtXHUwMDE47jPUUm+kpTyqaVxcM1x1MDAxNpJv+lxy+cB5P1x1MDAxZvKfJuU1XHUwMDBiN/pBgkZyn9+ld7s9qMRcdTAwMWLyoLpdOa+Vzlx1MDAwNzHe3N/bQDZcdTAwMDBcdTAwMGUq8lx1MDAxY5fBMMWRbseTdfD1O1pcdTAwMWFcdTAwMDK2nCXY2jOCfIBcdTAwMTHrLFx1MDAxNmEtgr/5QOC8X4BcdTAwMGa8PaaXhtr+hivFbKLp3MwgP6hdi7vDqjxkOVUuXHUwMDE3/PrV0XZcYjModlrdbqzq9orVz2BcYlNcdTAwMDdHXHUwMDEzZcegckEkm6wopFx1MDAwZbJcdTAwMGYuhJJs2d7gP8tcdTAwMDOn4F+iNlx1MDAwMvjWiaBMikBcdTAwMDXLKNOHOlx1MDAxNFx1MDAxONjsYFx1MDAwNmpaQ1x1MDAxOdufz9r7U7H/2aB+l8BcdTAwMDL6RtNPvFx1MDAxMyM1XYBIq2c53bk48m7i/OxGPPGLw8tcYnbdnemfUtqxbVtcdTAwMDCf3pLCOL3aWihcdTAwMTR31Pazx/9XI68+KCWF4lx1MDAwM6Jcclx1MDAxZPFU8N9dXHUwMDBlfVaMXHUwMDE53iht1Fx1MDAwNG9IL1BzdZdN1SrluEn72+2DuK5X9EVtXHUwMDAzU004J45SXFyDZIZcdFxu00pcdTAwMWSpZLbGw1b8r8gs/Vx1MDAxOHgzYpVcYrPO/r/f+J48euX4XHUwMDBlK1x1MDAwNpE81N60xYTMVirNL769biO906yftuOP/Vb6JqHu281cYrqVJ8YmorSmdkQk0Vx1MDAxNFX5iZotNEA1N4QpgkyBLGtehrSsYFx1MDAwZdVa8+mjZ8CSXHUwMDA3Krc2r8V6uyeaV1Bi7WeIaveKXHUwMDExXHSMSkHMsJRcdTAwMTWfl9swXbBFz3xcdTAwMTVbp1xyr/FcdTAwMTmTLd64/Dx1IOZdcpiL8M4yZGhzyFx1MDAwNXI+i17VyFx1MDAxY82dl1x1MDAxYuYpXHUwMDAxnkzJXHUwMDA3uSG2IHckas5cdTAwMDK0XHUwMDEwxLBJ1Vx1MDAxYW1FlNqaSkFlIC78oaagQoZcZlx1MDAxYyhCzs7LXHUwMDBicIyR7Fx1MDAxZEcomqegXGZf45iLb0E7eXT40o19fTmtOTxB21x1MDAwZV5RlFx1MDAwNYdcdTAwMGXOQmsjNihf7HiaJ7rHeVZh97W9x3L0pOrMXGZtIVx1MDAxZK1cYkObeDjqkI65cdBcdTAwMTBcdTAwMDbr57HY5lxcrcosRlEvmebUtlGwNdLUi1x1MDAwNVx1MDAxY86B5sB4MzahnIB18Wk9MUlq2KaHhvX0f6GkXHUwMDFmp6nH+9jT8Zl/tVx1MDAxYiOF+C3AVWfw7cj9deY1VW29ve72M7niXHUwMDFmxFxyKKGhSjbejUKNboHY7pk4SNfKbqu4r6/P82VVSZ/VXCLYXHUwMDAzfKaLTFx1MDAxM4dcYrA9pvFcdTAwMTVoNVx1MDAxMdOhYFM3mK2LVGrprlxcq63XQE2EM1xi65fwzVxuXHUwMDAy5/18VvBpvjRcdTAwMTM+XHUwMDEwnVx1MDAxMUkoQTDMz1x1MDAwN9ilVy2cdp5cdTAwMWVukqe0bZp+4fzqYfP4XHUwMDAw59QxjFx1MDAxYtTnOdhcdTAwMTlcZuN8gOGCSFx1MDAwMMDduCiRdqbZqlx1MDAxMi2M0lx1MDAxMe/c91xy8Fx1MDAxNTnTdKizXFxISYk2XHUwMDBizL+TnXwv/aBiyVOdT56WLnbSXjZcdTAwMTk9fI/70ox2QCqhqJLKTExWJ4Q7llx1MDAxOVx1MDAwMOO4XHUwMDA0etm87JC2KsQqe1x1MDAxY1d2av+jmb40vG0uZFhhRmRQ/Fx1MDAxMa60UVOTM7/uXHUwMDFk9Vx1MDAxYodub81d/mfcwupcXGpcdTAwMTCeUE2BXHUwMDEwMJIs0L3MkEL5rrZ7fnBeyFx1MDAxN9tcdTAwMDdN2n7Yi+C4StvXXykke2FsSjRcdTAwMGKorM991amjkSVcbi2sQzlcdTAwMThcdTAwMTd6zrk0yiFcbrVcdTAwMTVQRIJeUa5cdTAwMTXyXHUwMDEwyVx1MDAxONPS6tXIyqd3VkHVQILAu9HMXHUwMDE2ULxqePgr0ZrZPn0kRFwijzTzbur6svCUzFxicG9cdTAwMGJ7XHUwMDA3mcpjKeN/a+a/zrw2I/2NdbefyVx1MDAxNf8g3Vx1MDAxY3TogHehmbJNXHUwMDAz5udcdTAwMDWHj4lEPH+SIzF5VL1J9bk5PnqMJC94WzUnttVcclVcdTAwMDT1csGEXHUwMDFln/duWypcYk4sT1x1MDAxMIQsmXa1atVcdTAwMWNQ/1x1MDAwMlx1MDAxMtYq7ZtcdTAwMTFcdTAwMDTO+/mM4PNMdFx1MDAxZD6KiyC/YZIvUGXVrlxc5Z46XHUwMDAzt3dx/1Q/yHZuc4NEalx1MDAwM/lcdTAwMDBD/VIrQNGvXHUwMDAxZEC7XHUwMDFj8Vx1MDAwMWJs6iUqnyranVXoMC44u7PKN1x1MDAxYvhcbmwgzJBcdTAwMDdcdTAwMTLec1HxhXLeWHaPnZWOXHUwMDFlXHUwMDFmXHUwMDEyqp6/LrCrdCpcdTAwMDHR41x1MDAwMVx1MDAxM+OVpGNQYCrCte3WP55xzbhCM1x1MDAwMe18Y+eIsNW0MUawOlx1MDAxNPVcdTAwMGbU96b1R52dXHUwMDEzw9GGsVx011x1MDAxMXfHfawhn2m5tp/wZ2TGzLyJ5Y35t0f3kVAlniowVGhYoHoqk83HXHUwMDFhddM6badau/JiX915jVxitkqdcMAxR3KgwFxiQ6nOJoS1XHUwMDEwXHUwMDBlYsK2htJCLNdcdTAwMDftzVx1MDAxMVx1MDAwMVx1MDAwMVx1MDAxMftcdTAwMDJZMZFcdTAwMTXDkZVcdTAwMTC9zvF8XHUwMDFmVlx1MDAxMfH2JDD2Rtde61xcQU1zflK8z6Xi8uhg9+akXUlcdTAwMTDi32Zo7iz6pCiYg+ok03ZcdTAwMDJcdTAwMWZccsrUXHUwMDExKWpGXGJcdTAwMTIrroxeKrTzXHUwMDAxpEi5QS1cdTAwMTjC+nBEmlx1MDAxNkN7vcvwjrFEKlx1MDAwZVxcL5CI1HroJkspN3d9mSNcdTAwMGZcdTAwMDfl27vG02lcdTAwMDRHXHUwMDBiT5AheZtcZtWayPCjJkVcbm1HhUeSSKfpNN9DU7bGMFx1MDAxOepjsCuOXHUwMDBir9T8ouG4pt3cXHTt1Vx1MDAxZlnWfdKZK5ftRTA5cFx1MDAxY5NUOYYpO3VPo/Wm2cT8XHUwMDA1wVx1MDAxZMXY0O8o6FLZP2+LXHUwMDA2pE3UlYRA+0Gr6YNcdTAwMThmg5JZ36LQNKJqzMeikq5cdTAwMTeVdF2oXGZP2bX5J9o2XGadXHUwMDFilXu7/n7b7/GY2Cnmk/nDnOr3NiB4P1x1MDAxM5Vqk1CpiTTKXHUwMDEwvjnm/zcot16DkoVcdTAwMTeXS81cdTAwMDRHXHUwMDBid35X3F7rtkh7J2fXXHUwMDAzP73b8qredl1HvzqNgnKITYPlikihgiH3kVx1MDAxOcWEbVx1MDAwMKeN9YauUn9VIEDYlpPTXXKzMUlttZI065x0/nmgZOtcdTAwMDUlW1x1MDAxNyjDh/7ZKVUoXHUwMDE5XHUwMDE2cI/vXd70d1x1MDAxZVx1MDAxYd75XHK/yWVu2VMsdv9cdTAwMWJgUm1cdTAwMTQmcc3Q/JVfQVD+pphcZtdeXHLCkWi9QK/w5oGCzn37SVdcdTAwMGJcdTAwMDeJRLp7WUw9xqKPSW1cdTAwMWMhQCOFK1x1MDAwZUDNeDNQKyelbVxurlxmQ1x1MDAxYnupQUCzMEmFzWJCe9BcZkvd3oNKOyrUZtV8XHUwMDA1UPL1gpKvXHUwMDBilOHOV5RcdTAwMWFELDKu6+H8/Oa46rfg8mKvQJM7jeP87lx1MDAwNiivM0GpNlxulLZLXHUwMDA0KCO+hPv190SlVOFVWJRcdTAwMWGh9Fwi1ZiHZ91eOnbT0eVC8+ypJPpcdTAwMTeuieDUnGkhXHUwMDExQFx1MDAwZWSH5CjOJzI7UVQqw8FoiTa2Wan6unxIhGqFaFHsS4jKzY6JvJ2+ocJcdTAwMGIyUDxwm1x1MDAwNzS/wExcXNztXHUwMDFjpMieip+nYtXOoHk8ONXRRyZBOcWVIGg9Kkvxk1qsdCSiwM74XHUwMDEzKFFXXHUwMDA2zfmC5oxIXHUwMDFiWlx1MDAxNVx1MDAxME2B+CYxhtfphtMhRWaFrGpcdTAwMDE67O/FbooxX948Xp7XPCDuWaZRi1x1MDAxZVx1MDAxZM7s3mGIg9JcdTAwMTGVNKQ8yfhEL3ZjXHUwMDFjsMNcdTAwMTUpcMLJco13wicu23Hpxth272C44KFcdTAwMTOXueVcdTAwMThcZqgkREwpXHUwMDA2INx20lazerFcdTAwMTe65Zvtu/JNJlXLxY/263e756Xvev2XM69t4PJcdTAwMWKrbj/j6z0641x1MDAxZsGfXHUwMDBi11x1MDAwNKnwkKCxLWokWyCHq52/oV7sjlxcl9r1beNeXHUwMDFkljJqXHUwMDAzW9syVFxyqXVqKuT5iCAy7vpknDjIXCKRSVx1MDAxMFQsYdlcdTAwMTLfXHUwMDE1N+7AR1x1MDAwMD178vo3I4hcdTAwMDAj+LwmuDTUZESmoznajfNbjLknPchl2rLXKruFVK1+ma+nnzaPXHUwMDBmcIZsWWlcdTAwMGWSXHUwMDFisOX+XHUwMDEzbMCaXGJaXHUwMDEzw0DL5fTUVbNcdTAwMDH8ttBcdTAwMDZmXHUwMDBlYPvmXHUwMDAyX4FcdTAwMGKEunN5KFx1MDAxN1x1MDAxMMZOXHUwMDE5XHUwMDE0XHUwMDBiJFxiXHUwMDFkVrpXx/s0IWBAXHUwMDFl4zKVU1x1MDAxZHpcdTAwMTI9LjBeXHUwMDE2XHUwMDA0wtFcbl9cdTAwMDQqP7a2YFxm9Fx1MDAxNKRy7ExGW5JcdTAwMDNsVUPWXHUwMDA1caZGOik4k74hw1xigptcdTAwMDVyLpe1T1/I6Z9cdTAwMDDdPS9rreJcdTAwMGWSMn2fb1fvMu5Jote9SY+aWbxcIjPXouXHy55//3zrvFx1MDAxZff7rbtkp1x1MDAxM98+3+0ndkpeQ9H5zjvBniZZjU0nXHUwMDFlvbl3erTUq61vlCrtvlRcdG25j/6au42EXnye0iT6c7qLfaXD1/nDbbdPe/gycd/PRftx73tcdTAwMGY7U4nZfuyyXGZcdTAwMTfaXHUwMDAy3LPU9M+/f/z7/3O2XHJLIn0= + + + + 0112233warpMatOffsetinWarpMatOffsetorder = [1,0]NKStrided Axis0880stridedMatShapecontiguousMatShapestridedSmemOffsetcontiguousTileNumMatscontiguousLoadMatOffset01122330Contiguous axis diff --git a/third_party/enflame/include/triton/docs/conf.py b/third_party/enflame/include/triton/docs/conf.py new file mode 100644 index 000000000..6f74ffa8e --- /dev/null +++ b/third_party/enflame/include/triton/docs/conf.py @@ -0,0 +1,304 @@ +# -*- coding: utf-8 -*- +# +# Triton documentation build configuration file, created by +# sphinx-quickstart on Mon Feb 10 01:19:09 2020. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + +# -- General configuration ------------------------------------------------ + +import os +import platform +import shutil +import sys +import sysconfig +from pathlib import Path + +import sphinx_rtd_theme +from sphinx_gallery.sorting import FileNameSortKey + + +def process_sig(app, what, name, obj, options, signature, return_annotation): + if signature and '_builder' in signature: + signature = signature.split('_builder')[0] + ")" + return (signature, return_annotation) + + +def get_cmake_dir(): + plat_name = sysconfig.get_platform() + python_version = sysconfig.get_python_version() + dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}" + cmake_dir = Path("../python") / "build" / dir_name + return cmake_dir + + +def setup_generated_mlir_docs(): + dst_path = Path("dialects") + os.makedirs(dst_path, exist_ok=True) + + cmake_dir = get_cmake_dir() + src_dir = cmake_dir / "docs" / "dialects" + assert os.path.isdir(src_dir) + + shutil.copytree(src_dir, dst_path, dirs_exist_ok=True) + + files = os.listdir(dst_path) + + dialects = "\n ".join(["./" + f for f in files if "Dialect" in f]) + ops = [f for f in files if "Ops" in f] + + # Add titles + for op in ops: + with open(dst_path / op, 'r+') as f: + lines = f.readlines() + lines.insert(0, "# " + op.split(".md")[0]) + f.seek(0) + f.writelines(lines) + ops = "\n ".join(["./" + op for op in ops]) + + rst_string = f""" +Triton MLIR Dialects and Ops +===================== + +.. toctree:: + :maxdepth: 1 + :caption: Dialects + + {dialects} + +.. toctree:: + :maxdepth: 1 + :caption: Dialect Ops + + {ops} +""" + with open(dst_path / "dialects.rst", "w+") as f: + f.write(rst_string) + + +def setup(app): + """Customize function args retrieving to get args under decorator.""" + import subprocess + + import sphinx + + app.connect("autodoc-process-signature", process_sig) + max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count())) + print(f"Installing Triton Python package using {max_jobs} threads") + subprocess.run("pip install -e ../python", shell=True, env=os.environ.copy()) + + setup_generated_mlir_docs() + + def forward_jit_fn(func): + old = func + + def wrapped(obj, **kwargs): + import triton + if isinstance(obj, triton.runtime.JITFunction): + obj = obj.fn + return old(obj) + + return wrapped + + old_documenter = sphinx.ext.autosummary.get_documenter + + def documenter(app, obj, parent): + import triton + if isinstance(obj, triton.runtime.JITFunction): + obj = obj.fn + return old_documenter(app, obj, parent) + + sphinx.ext.autosummary.get_documenter = documenter + sphinx.util.inspect.unwrap_all = forward_jit_fn(sphinx.util.inspect.unwrap_all) + sphinx.util.inspect.signature = forward_jit_fn(sphinx.util.inspect.signature) + sphinx.util.inspect.object_description = forward_jit_fn(sphinx.util.inspect.object_description) + + +# Auto Doc + +sys.path.insert(0, os.path.abspath('../python/')) +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.intersphinx', + 'sphinx.ext.autosummary', + 'sphinx.ext.coverage', + 'sphinx.ext.napoleon', + 'sphinx_multiversion', + 'sphinx.ext.autosectionlabel', + 'myst_parser', +] +autosummary_generate = True + +# versioning config +smv_tag_whitelist = r'^(v3.3.1)$' +smv_branch_whitelist = r'^main$' +smv_remote_whitelist = None +smv_released_pattern = r'^tags/.*$' +smv_outputdir_format = '{ref.name}' +smv_prefer_remote_refs = False + +# Sphinx gallery +extensions += ['sphinx_gallery.gen_gallery'] + +sphinx_gallery_conf = { + 'examples_dirs': '../python/tutorials/', + 'gallery_dirs': 'getting-started/tutorials', + 'filename_pattern': '', + 'ignore_pattern': r'(__init__\.py|11.*.py)', + 'within_subsection_order': FileNameSortKey, + 'reference_url': { + 'sphinx_gallery': None, + }, + # Examples don't work on non-Linux platforms, because they actually run + # Triton. But it's nice to be able to run the rest of the docs build. + 'abort_on_example_error': platform.system() == 'Linux', +} + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] +html_sidebars = { + '**': [ + '_templates/versions.html', + ], +} + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = 'Triton' +copyright = '2020, Philippe Tillet' +author = 'Philippe Tillet' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = '' +# The full version, including alpha/beta/rc tags. +release = '' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = 'en' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# + +html_theme = 'sphinx_rtd_theme' +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] +html_css_files = [ + 'css/custom.css', +] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# This is required for the alabaster theme +# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars +html_sidebars = { + '**': [ + 'relations.html', # needs 'show_related': True theme option to display + 'searchbox.html', + ] +} + +html_logo = "https://cdn.openai.com/triton/assets/triton-logo.png" + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'Tritondoc' + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'Triton.tex', 'Triton Documentation', 'Philippe Tillet', 'manual'), +] + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [(master_doc, 'triton', 'Triton Documentation', [author], 1)] + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'Triton', 'Triton Documentation', author, 'Triton', 'One line description of project.', + 'Miscellaneous'), +] diff --git a/third_party/enflame/include/triton/docs/getting-started/installation.rst b/third_party/enflame/include/triton/docs/getting-started/installation.rst new file mode 100644 index 000000000..06c135ddc --- /dev/null +++ b/third_party/enflame/include/triton/docs/getting-started/installation.rst @@ -0,0 +1,59 @@ +============ +Installation +============ + +For supported platform/OS and supported hardware, review the `Compatibility `_ section on Github. + +-------------------- +Binary Distributions +-------------------- + +You can install the latest stable release of Triton from pip: + +.. code-block:: bash + + pip install triton + +Binary wheels are available for CPython 3.8-3.12 and PyPy 3.8-3.9. + +And the latest nightly release: + +.. code-block:: bash + + pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly + + +----------- +From Source +----------- + +++++++++++++++ +Python Package +++++++++++++++ + +You can install the Python package from source by running the following commands: + +.. code-block:: bash + + git clone https://github.com/triton-lang/triton.git; + cd triton/python; + pip install ninja cmake wheel; # build-time dependencies + pip install -e . + +Note that, if llvm is not present on your system, the setup.py script will download the official LLVM static libraries and link against that. + +For building with a custom LLVM, review the `Building with a custom LLVM `_ section on Github. + +You can then test your installation by running the unit tests: + +.. code-block:: bash + + pip install -e '.[tests]' + pytest -vs test/unit/ + +and the benchmarks + +.. code-block:: bash + + cd bench + python -m run --with-plots --result-dir /tmp/triton-bench diff --git a/third_party/enflame/include/triton/docs/getting-started/tutorials/grouped_vs_row_major_ordering.png b/third_party/enflame/include/triton/docs/getting-started/tutorials/grouped_vs_row_major_ordering.png new file mode 100644 index 000000000..46a356de7 Binary files /dev/null and b/third_party/enflame/include/triton/docs/getting-started/tutorials/grouped_vs_row_major_ordering.png differ diff --git a/third_party/enflame/include/triton/docs/getting-started/tutorials/parallel_reduction.png b/third_party/enflame/include/triton/docs/getting-started/tutorials/parallel_reduction.png new file mode 100644 index 000000000..8105b2883 Binary files /dev/null and b/third_party/enflame/include/triton/docs/getting-started/tutorials/parallel_reduction.png differ diff --git a/third_party/enflame/include/triton/docs/getting-started/tutorials/random_bits.png b/third_party/enflame/include/triton/docs/getting-started/tutorials/random_bits.png new file mode 100644 index 000000000..198f90a5e Binary files /dev/null and b/third_party/enflame/include/triton/docs/getting-started/tutorials/random_bits.png differ diff --git a/third_party/enflame/include/triton/docs/index.rst b/third_party/enflame/include/triton/docs/index.rst new file mode 100644 index 000000000..e9cf1e79f --- /dev/null +++ b/third_party/enflame/include/triton/docs/index.rst @@ -0,0 +1,72 @@ +Welcome to Triton's documentation! +================================== + +Triton_ is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware. + + +Getting Started +--------------- + +- Follow the :doc:`installation instructions ` for your platform of choice. +- Take a look at the :doc:`tutorials ` to learn how to write your first Triton program. + +.. toctree:: + :maxdepth: 1 + :caption: Getting Started + :hidden: + + getting-started/installation + getting-started/tutorials/index + + +Python API +---------- + +- :doc:`triton ` +- :doc:`triton.language ` +- :doc:`triton.testing ` +- :doc:`Triton semantics ` + + +.. toctree:: + :maxdepth: 1 + :caption: Python API + :hidden: + + python-api/triton + python-api/triton.language + python-api/triton.testing + python-api/triton-semantics + + +Triton MLIR Dialects and Ops +-------------------- + +- :doc:`Triton MLIR Dialects and Ops ` + +.. toctree:: + :maxdepth: 1 + :caption: Triton MLIR Dialects + :hidden: + + dialects/dialects + +Going Further +------------- + +Check out the following documents to learn more about Triton and how it compares against other DSLs for DNNs: + +- Chapter 1: :doc:`Introduction ` +- Chapter 2: :doc:`Related Work ` +- Chapter 3: :doc:`Debugging ` + +.. toctree:: + :maxdepth: 1 + :caption: Programming Guide + :hidden: + + programming-guide/chapter-1/introduction + programming-guide/chapter-2/related-work + programming-guide/chapter-3/debugging + +.. _Triton: https://github.com/triton-lang/triton diff --git a/third_party/enflame/include/triton/docs/meetups/01-24-2024/notes.md b/third_party/enflame/include/triton/docs/meetups/01-24-2024/notes.md new file mode 100644 index 000000000..524950fb3 --- /dev/null +++ b/third_party/enflame/include/triton/docs/meetups/01-24-2024/notes.md @@ -0,0 +1,25 @@ +#### Agenda: + +##### Items: +1. 3rd party refactoring backend update. +2. AMD update about experience with refactored backend and new process. +3. Plan to restore the Intel XPU backend as third-party module. +4. Open discussion. + +##### Minutes: +Recording link [here](https://youtu.be/uRlqolhNbRk) + +1. 3rd party refactoring backend update. + - Backends are passes and IRs are shared by the backends to avoid divergence and duplications so that developers do not have to change the Triton source code + - To discover backend forks in directories, put environment vars in setup.py. + - Backends can link whatever library they want, they don’t need to copy paste Nvidia code. + - Nvidia uses the same API as other backends, (refactoring of the C++ code is still remaining). No special casing for Nvidia code. + - If Triton dependency is on top of the main branch then it will work for forks/branches. + - Still remaining: LLVM IR conversion – reusuable pattern rewriters update; Reduce complexity in statefulness in Triton GPU - inherit from base pattern +2. AMD update about experience with refactored backend and new process. + - Skipped due to lack of time. Will be covered in February meetup +3. Plan to restore the Intel XPU backend as third-party module. + - Prereqs to upstream – Will take into account the system HW and SW, with perf to be ~80% of Nvidia, to allow upstreaming. + - Consider how useful it is for AI research to allow upstreaming – as it impacts maintenance cost of the backends. + - Don’t have plans to upstream mobile backends + - Intel will hold offline discussion with Open AI for being in-tree. diff --git a/third_party/enflame/include/triton/docs/meetups/02-20-2024/Proton.pdf b/third_party/enflame/include/triton/docs/meetups/02-20-2024/Proton.pdf new file mode 100644 index 000000000..635b90362 Binary files /dev/null and b/third_party/enflame/include/triton/docs/meetups/02-20-2024/Proton.pdf differ diff --git a/third_party/enflame/include/triton/docs/meetups/02-20-2024/notes.md b/third_party/enflame/include/triton/docs/meetups/02-20-2024/notes.md new file mode 100644 index 000000000..a45df659a --- /dev/null +++ b/third_party/enflame/include/triton/docs/meetups/02-20-2024/notes.md @@ -0,0 +1,23 @@ +#### Agenda: + +##### Items: +1. Intel update +2. AMD update +3. Profiler update +4. We are in the process of transitioning to a pro slack plan, so everybody will be able to see history. Expect this to take a few more weeks. +5. We are still working on finalizing a document about our technical governance structure. Expect this to take a few more weeks too.4. Open discussion. + +##### Minutes: +Recording link [here](https://youtu.be/JDQCdj18Snc) + +1. Intel GPU integration with Triton and Pytorch: + - No strong requirement from PyTorch for specific backends to be part of Triton official release. + - Can use a separate branch/fork for CI/CD and testing. + - Intel team will work with Pytorch offline to close. +2. AMD GPU backend update: + - AMD team shared the refactored design for AMD backend. + - The new design is modularized and reduces clutter and duplication in upstream Triton. + - Further work needed for regression testing and secure runners. +3. Proton profiler update: + - Keren from the OpenAI team presented a new profiler tool for Triton kernels, which supports multiple vendors, metrics, and formats. + - Outlined the plan for open-sourcing, integrating, and extending the tool. diff --git a/third_party/enflame/include/triton/docs/meetups/04-02-2024/notes.md b/third_party/enflame/include/triton/docs/meetups/04-02-2024/notes.md new file mode 100644 index 000000000..5132a8bb7 --- /dev/null +++ b/third_party/enflame/include/triton/docs/meetups/04-02-2024/notes.md @@ -0,0 +1,19 @@ +#### Agenda: + +##### Items: +1. Interpreter update +2. Experience with TMA support and future plans for it +3. CGO trip report +4. Triton upstream CI and unit test status from AMD +5. Open discussion + +##### Minutes: +Recording link [here](https://youtu.be/VTcFe2XxZZc) + +Presentations repo [here](https://drive.google.com/drive/folders/1bKpvz1NiBL_fHrGhMoZPvQfXCeetV2iY?usp=sharing) + +1. Triton interpreter mode: The Open AI presented the interpreter mode for Triton code, which allows users to debug and inspect individual GPU programs using native Python print or PDB. It is currently being turned on using an environment variables, code decorators for individual functions being interpreted are still TBD. It can also run on CPU without GPU. For more details about the presentation please refer slides. +2. Tensor Memory Access (TMA) discussion: The current implementation of TMA in Triton has some limitations, so has been removed for now. The plan is to rethink how to do it better in the future. The goal is to support TMA implicitly, but the challenge is to handle the different memory layouts for different backends. There is a pull request to improve the launch overhead of kernels, which is related to TMA, but it would require extensive review and testing. +3. CGO trip report: Ian Bearman from Microsoft shared his experience of attending CGO and the Compilers for Machine Learning workshop. He and Javed Absar from Qualcomm gave talks about Triton shared and answered questions about Triton. There was a lot of interest in Triton as a cross-platform kernel language and questions were around the PyTorch integration, the performance portability, and the codegen bugs. It will be good to make the Triton-Pytorch connection more visible. There was also another project called Turbine that was similar to Triton. Please refer to the slides for more details. +4. AMD upstream CI and unit tests status: The AMD team discussed CI and enabling tests for MI 210 and MI 300. Work is in progress for performance gaps, compilation errors and fixes for FP8IN and flash attention kernels. The plan is to upstream these changes soon. Please refer to the slides for more details. +5. Third party CPU backend: The Intel team is driving discussions for community collaboration on a proof of concept for a CPU backend for Triton, using MLIR and OpenMP. There will be a follow-up meeting to discuss the logistics and design. Please refer to the third-party channel in slack for more details. diff --git a/third_party/enflame/include/triton/docs/meetups/05-07-2024/notes.md b/third_party/enflame/include/triton/docs/meetups/05-07-2024/notes.md new file mode 100644 index 000000000..eb4d9d253 --- /dev/null +++ b/third_party/enflame/include/triton/docs/meetups/05-07-2024/notes.md @@ -0,0 +1,21 @@ +#### Agenda: +1. Triton CPU summary +2. Triton introduced a new Triton layout redesign (linear layout PR3794 ). Does this layout try to cover Triton CPU backend for SIMD instructions. +3. Triton Stream-k on AMD GPUs + +##### Items: +Meeting notes: +1. Triton CPU backend: The Meta team presented their motivation, design, and progress on developing a CPU backend for Triton. + There is a demand for heterogeneity and portability across different CPU architectures, especially for small batch sizes and inference workloads. + They proposed to use MLIR and vector dialect to lower Triton IR to LLVM IR, and to leverage existing dialects and transformations for GPU backends. + There maybe a possible refactoring of the CPU backend to make it more general and modular. + Currently they have done initial work on plumbing the CPU backend and implementing a basic vector load operation using transfer read. + Repo and other details are in the slides below. + Open questions: How to handle different vector widths and operations, how to support ARM Neon, how to set performance goals and criteria, and how to coordinate with other Triton developers and contributors. +2. Stream-k for AMD: The AMD team presented their implementation and evaluation of Stream-k, a load-balanced scheme for matrix multiplication that can handle different tile sizes and split K dimensions. + They compared it with PyTorch Matmul and Triton Matmul. Other details are in the slides below. + +##### Minutes: +Recording link [here](https://youtu.be/hgINpebZ7n0) + +Presentations repo [here](https://drive.google.com/drive/folders/1xPnRO5P59aMVJnXz_o9ASTUgTXK1lhHW?usp=drive_link) diff --git a/third_party/enflame/include/triton/docs/meetups/07-18-2023/notes.md b/third_party/enflame/include/triton/docs/meetups/07-18-2023/notes.md new file mode 100644 index 000000000..d3cb875c5 --- /dev/null +++ b/third_party/enflame/include/triton/docs/meetups/07-18-2023/notes.md @@ -0,0 +1,44 @@ +#### Agenda: + +##### Announcements: +1. Triton conference planned mid September in the Microsoft Silicon Valley Campus. + +##### Items: +1. Alternative backend development approach (e.g. AMD, Intel) +2. State of the documentation, is there a planned effort? If yes, what do you think is the priority? +3. Mechanisms for smaller technical discussions: Slack channel per topic? Dedicated meetings for some topics? +4. Stability, testing, regressions: Improving CI and conformance/testing for validating new back-ends. +5. Language improvements/pain points +6. Windows Support +7. Discussion of known/anticipated design changes for H100 +8. Some specific more tactical areas: + - int8. + - A low hanging fruit is to let tl.dot take int8 and leverage mma. + - Sm75. + - device functions. How hard is this to support while Triton frontend traverses AST? + - remove torch dependencies from the frontend. (it sounds like there is already progress on this but could be worth discussing) + +##### Minutes +Recording link [here](https://drive.google.com/file/d/1uMlIvih_E5FITwPnNHwTYzo-UKqtey2c/view) + +1. Backend plans/broader roadmap: + - Plan is for major updates to come in the Triton development meetup which will happen mid-September. For major design changes, currently the plan is to not upstream them directly but have a staging state and different backends can be integrated through a plugin mechanism where Triton provides a layer at the Triton IR layer that is generic and other backends can plug into that. + - Short term roadmap plans are very focused on things like improving all FP8 things on Ampere and Hopper support (end of August). After Hopper support lands, priorities will include refactoring codebase to increase maintainability. + - Linalg – upstreaming on hold due to limited dev bandwidth. Want to build an ecosystem where others can leverage Linalg like passes developed in their backend. + - For now, peak performance on Nvidia GPUs needs Nvidia specific things, but the convergence of programming models for different backends will allow convergence of hardware backend support in Triton. +2. Documentation: + - OpenAI has included comments in the backend code. + - Seek community involvement to improve tutorials, based on new users knowing what is missing. + - Seek community involvement for signature changes and doc updates. + - Thread created in slack for suggestions on areas needing doc updates. Ian Bearman and his team may have bandwidth to update certain documentation. +3. Discussion channels: + - Preferred #dev channel in slack for technical discussions. + - Between GitHub and Slack it would be good to post links into places so folks know discussions are happening elsewhere +4. CI/testing: + - Pretty liberal in terms of accepting regression tests and integration tests for Nvidia. + - Plugin interface tested like everything else, and regressions there would block merges into main. + - Correctness/Performance of external backends are tested nightly, but regressions do not prevent wheels from being built. +5. Language improvements: + - Have added location information support into Triton codegen. + - Feel free to bring up pain points in slack. +7. Windows Support: Technically not difficult to get a preliminary version. Most of the maintenance burden would come from having to support it when it breaks. diff --git a/third_party/enflame/include/triton/docs/meetups/08-06-2024/notes.md b/third_party/enflame/include/triton/docs/meetups/08-06-2024/notes.md new file mode 100644 index 000000000..48762c62a --- /dev/null +++ b/third_party/enflame/include/triton/docs/meetups/08-06-2024/notes.md @@ -0,0 +1,13 @@ +#### Agenda: +1. Triton-CPU Update +2. Intel GPU backend update + +##### Items: +Meeting notes: +1. Triton-CPU Update: Intel and Meta jointly presented the work on Triton-CPU, highlighting good progress on coverage and performance improvements. They also covered some of the optimizations they leveraged to get performance comparable to torch-native and torch-inductor. More details are in their slides. +2. Intel GPU Backend: Intel GPU backend shows good performance close to expert-tuned kernels and the use of block pointers for performance gains. There were questions around the future of block pointers and their importance for performance gains. With block-pointer deprecation there is a need for a more generic interface to support various backends including Intel GPU. +3. The 2024 Triton conference is on September 17th 2024 in Fremont California! Please register [here](README.md). +##### Minutes: +Recording link [here](https://youtu.be/dfL3L4_3ujg) + +Presentations repo [here](https://drive.google.com/drive/folders/1fQ3zVrM7DT8W8FGJWKx1wNr2X53tYbeT?usp=sharing) diff --git a/third_party/enflame/include/triton/docs/meetups/08-22-2023/amd-update.pdf b/third_party/enflame/include/triton/docs/meetups/08-22-2023/amd-update.pdf new file mode 100644 index 000000000..e0178355c Binary files /dev/null and b/third_party/enflame/include/triton/docs/meetups/08-22-2023/amd-update.pdf differ diff --git a/third_party/enflame/include/triton/docs/meetups/08-22-2023/intel-xpu-update.pptx b/third_party/enflame/include/triton/docs/meetups/08-22-2023/intel-xpu-update.pptx new file mode 100644 index 000000000..d9c61dfaa Binary files /dev/null and b/third_party/enflame/include/triton/docs/meetups/08-22-2023/intel-xpu-update.pptx differ diff --git a/third_party/enflame/include/triton/docs/meetups/08-22-2023/notes.md b/third_party/enflame/include/triton/docs/meetups/08-22-2023/notes.md new file mode 100644 index 000000000..ef5f578d8 --- /dev/null +++ b/third_party/enflame/include/triton/docs/meetups/08-22-2023/notes.md @@ -0,0 +1,41 @@ +#### Agenda: + +##### Announcements: +1. Triton conference registration opening soon. Conference on 20th September at the Microsoft Silicon Valley Campus. + +##### Items: +1. H100 updates +2. Triton release plan update +3. Linalg updates +4. Intel GPU Backend status update. +5. Intel working on the CPU backend for Triton. +6. AMD updates +7. Open discussion + +##### Minutes: +Recording link [here](https://drive.google.com/file/d/19Nnc0i7zUyn-ni2RSFHbPHHiPkYU96Mz/view) + +1. H100 updates: + - Preliminary support is merged, disabled by default, can be enabled with env variables + - Supports latest tensor cores, FP8s. Support for Flash Attention on the main branch coming soon. + - Performance is very good on Matmuls, 80-90% of cublas on large Matmuls right now, will eventually reach parity with cublas. Above 600 teraflops on fp16 on xxm card, cublas is 670 on random input data. FP8 is twice that, around 1.2 petaflops. + - Hopper support includes the full FP8 support for compute. +2. Triton release plan update + - No specific dates for now, plan is to release before end of 2023. + - Will move to 3.0 release due to minor backward compatibility breaking changes. For eg. Will move compiler options in the indexing operators as hardcoded operators in the kernel, will bump the major version. + - Functionally the main goal will be to have 3rd party plugins for Intel and AMD gpus. + - May synchronise with a PyTorch release so that PyTorch can benefit from the latest features, however continuous integration workflow is the default release cadence expected. + - Will switch the default behavior to optimized mode for the release, needs more discussion with Nvidia. + - Will expose flags for a user to enable kernel selection themselves. + - Open question: Pytorch hasn’t rebased to latest triton, it is close to PyTorch code freeze – will PyTorch still sync with Triton 2.0? Will we have another release to support triton 2.0? + - Community can start with the latest stable branch and rebase 3rd party plugin on top of that. OAI has no resources to commit to, but community can contribute. +3. Linalg updates + - Discussion on Github for Linalg as a middle layer between the language and target hardware. Includes support for block pointers and modulo operators. + - Please join the conversation [here](https://github.com/triton-lang/triton/discussions/1842) + - Branch pushed is behind the tip, will work on getting it caught up on the tip. +4. Intel GPU Backend status update. + - Please refer to slides [here](https://github.com/triton-lang/triton/blob/main/docs/meetups/Intel%20XPU%20Backend%20for%20Triton%20-%20Update%20-%200823.pptx) +5. Intel working on the CPU backend for Triton. + - Please refer to slides [here](https://github.com/triton-lang/triton/blob/main/docs/meetups/Intel%20XPU%20Backend%20for%20Triton%20-%20Update%20-%200823.pptx) +6. AMD updates + - Please refer to slides [here](https://github.com/triton-lang/triton/blob/main/docs/meetups/Triton_AMD_update_0823.pdf). diff --git a/third_party/enflame/include/triton/docs/meetups/10-25-2023/intel-xpu-update.pdf b/third_party/enflame/include/triton/docs/meetups/10-25-2023/intel-xpu-update.pdf new file mode 100644 index 000000000..defc4b719 Binary files /dev/null and b/third_party/enflame/include/triton/docs/meetups/10-25-2023/intel-xpu-update.pdf differ diff --git a/third_party/enflame/include/triton/docs/meetups/10-25-2023/notes.md b/third_party/enflame/include/triton/docs/meetups/10-25-2023/notes.md new file mode 100644 index 000000000..04777eb3f --- /dev/null +++ b/third_party/enflame/include/triton/docs/meetups/10-25-2023/notes.md @@ -0,0 +1,24 @@ +#### Agenda: + +##### Items: +1. H100 updates +2. Triton-Shared layer updates +3. Intel update +4. Open discussion + +##### Minutes: +Recording link [here](https://youtu.be/KZAzpKx1ebI) + +1. H100 updates + - Enabled WGMMA by default, now any matmul can reuse it. + - fp8 formats enabled – 1.3 Petaflops on dense matmul on H100 (gemm performance) + - Enabled Flash Attention using wgmma, resulting in 450 teraflop on fwd pass and 250 on backward pass – still working on perf for flash attention + - fp8 numbers with flash attention running in fp8 with matmul is tricky, because the fp8 layout is significantly different than what is returned by wgmma, still wip + +2. Triton-Shared layer + - Please refer to slides for more details + - Created a repo where you can find the middle layer + - Available as a plugin into triton + +3. Intel Update + - Please refer to slides for more details diff --git a/third_party/enflame/include/triton/docs/meetups/10-25-2023/triton-shared.pptx b/third_party/enflame/include/triton/docs/meetups/10-25-2023/triton-shared.pptx new file mode 100644 index 000000000..ea2f0ea41 Binary files /dev/null and b/third_party/enflame/include/triton/docs/meetups/10-25-2023/triton-shared.pptx differ diff --git a/third_party/enflame/include/triton/docs/meetups/12-13-2023/notes.md b/third_party/enflame/include/triton/docs/meetups/12-13-2023/notes.md new file mode 100644 index 000000000..d7c28a061 --- /dev/null +++ b/third_party/enflame/include/triton/docs/meetups/12-13-2023/notes.md @@ -0,0 +1,16 @@ +#### Agenda: + +##### Items: +1. Refactoring plan for 3rd party backends +2. Front end refactoring (AMD) +3. Things like block pointers, ptr_analysis, mask_analysis can be used for GPUs, is there a plan to incrementally include components from Triton shared for GPU development. + +##### Minutes: +Recording link [here](https://youtu.be/Lo43DQYkOWM) + +1. Refactoring plan for 3rd party backends + - Refactoring to be completed by end of the year so that all GPU backends can be individual passes on Triton GPU IR instead of being completely out of tree. The goal is for users to get other GPUs besides Cuda when they install Triton. Non-GPU Triton IR expected to stay as is. +3. Front end refactoring (AMD) + - Will work with Phil for AMD related refactoring. Will share more details in next meetup about where AMD has diverged from Triton GPU IR and in the codeflow. +4. Things like block pointers, ptr_analysis, mask_analysis can be used for GPUs, is there a plan to incrementally include components from Triton shared for GPU development. + - Can look at it on a case by case basis. diff --git a/third_party/enflame/include/triton/docs/meetups/dev-meetup-2023.md b/third_party/enflame/include/triton/docs/meetups/dev-meetup-2023.md new file mode 100644 index 000000000..27719b107 --- /dev/null +++ b/third_party/enflame/include/triton/docs/meetups/dev-meetup-2023.md @@ -0,0 +1,27 @@ +The conference slides are available [here](https://drive.google.com/drive/folders/1yDFc4ElNN_GGhWDdMlM4wcm5uFEFFVQk?usp=sharing) + +The conference videos will be available [here](https://youtube.com/playlist?list=PLc_vA1r0qoiRZfUC3o4_yjj0FtWvodKAz&feature=shared) when ready. + +# Triton Developer Conference +The Triton Developer Conference was held in a hybrid mode at the Microsoft Silicon Valley Campus in Mountain View, California. The conference was held on September 20th from 10am to 4pm, followed by a reception till 5:30 pm. + +Agenda for the conference: + +|Time |Title |Speaker +|--------|-------|-------| +|10:00 AM|Welcome|Kevin Scott (Microsoft)| +|10:20 AM|The Triton Compiler: Past, Present and Future|Phil Tillet (OpenAI)| +|11:00 AM|**Break**|| +|11:20 AM|Hopper support in Triton|Gustav Zhu (Nvidia)| +|11:40 AM|Bringing Triton to AMD GPUs|Jason Furmanek, Lixun Zhang (AMD)| +|12:00 PM|Intel XPU Backend for Triton|Eikan Wang (Intel)| +|12:20 PM|Vectorization of Triton Kernels for Qualcomm Hexagon Backend|Javed Absar (Qualcomm)| +|12:30 PM|**Lunch**|| +|1:40 PM |Triton for MTIA|Roman Levenstein et al, (Meta)| +|2:00 PM |Using Triton IR for high-performance fusions in XLA|George Karpenkov (Google)| +|2:20 PM |Triton for All: Triton as a device-independent language|Ian Bearman (Microsoft)| +|2:40 PM|**Break**|| +|3:00 PM|PyTorch 2.0 and TorchInductor|Jason Ansel, Horace He (Meta)| +|3:20 PM|Pallas: A JAX Kernel Language|Sharad Vikram (Google)| +|3:40 PM|Writing Grouped GEMMs in Triton|Vinod Grover (Nvidia)| +|4:00 PM|**Reception**|| diff --git a/third_party/enflame/include/triton/docs/meetups/dev_conference_2024.md b/third_party/enflame/include/triton/docs/meetups/dev_conference_2024.md new file mode 100644 index 000000000..6816b4c59 --- /dev/null +++ b/third_party/enflame/include/triton/docs/meetups/dev_conference_2024.md @@ -0,0 +1,3 @@ +The conference slides are available [here](https://drive.google.com/drive/folders/1osK9hwcX_lC1EjdZGB-v4w5oKx23UnU2?usp=drive_link) + +The conference videos are available [here](https://www.youtube.com/playlist?list=PLc_vA1r0qoiTjlrINKUuFrI8Ptoopm8Vz). diff --git a/third_party/enflame/include/triton/docs/programming-guide/chapter-1/cuda-parallel-matmul.png b/third_party/enflame/include/triton/docs/programming-guide/chapter-1/cuda-parallel-matmul.png new file mode 100644 index 000000000..8050ad150 Binary files /dev/null and b/third_party/enflame/include/triton/docs/programming-guide/chapter-1/cuda-parallel-matmul.png differ diff --git a/third_party/enflame/include/triton/docs/programming-guide/chapter-1/introduction.rst b/third_party/enflame/include/triton/docs/programming-guide/chapter-1/introduction.rst new file mode 100644 index 000000000..2a843fd37 --- /dev/null +++ b/third_party/enflame/include/triton/docs/programming-guide/chapter-1/introduction.rst @@ -0,0 +1,71 @@ +============ +Introduction +============ + +----------- +Motivations +----------- + +Over the past decade, Deep Neural Networks (DNNs) have emerged as an important class of Machine Learning (ML) models, capable of achieving state-of-the-art performance across many domains ranging from natural language processing [SUTSKEVER2014]_ to computer vision [REDMON2016]_ to computational neuroscience [LEE2017]_. The strength of these models lies in their hierarchical structure, composed of a sequence of parametric (e.g., convolutional) and non-parametric (e.g., rectified linearity) *layers*. This pattern, though notoriously computationally expensive, also generates a large amount of highly parallelizable work particularly well suited for multi- and many- core processors. + +As a consequence, Graphics Processing Units (GPUs) have become a cheap and accessible resource for exploring and/or deploying novel research ideas in the field. This trend has been accelerated by the release of several frameworks for General-Purpose GPU (GPGPU) computing, such as CUDA and OpenCL, which have made the development of high-performance programs easier. Yet, GPUs remain incredibly challenging to optimize for locality and parallelism, especially for computations that cannot be efficiently implemented using a combination of pre-existing optimized primitives. To make matters worse, GPU architectures are also rapidly evolving and specializing, as evidenced by the addition of tensor cores to NVIDIA (and more recently AMD) micro-architectures. + +This tension between the computational opportunities offered by DNNs and the practical difficulty of GPU programming has created substantial academic and industrial interest for Domain-Specific Languages (DSLs) and compilers. Regrettably, these systems -- whether they be based on polyhedral machinery (e.g., Tiramisu [BAGHDADI2021]_, Tensor Comprehensions [VASILACHE2018]_) or scheduling languages (e.g., Halide [JRK2013]_, TVM [CHEN2018]_) -- remain less flexible and (for the same algorithm) markedly slower than the best handwritten compute kernels available in libraries like `cuBLAS `_, `cuDNN `_ or `TensorRT `_. + +The main premise of this project is the following: programming paradigms based on blocked algorithms [LAM1991]_ can facilitate the construction of high-performance compute kernels for neural networks. We specifically revisit traditional "Single Program, Multiple Data" (SPMD [AUGUIN1983]_) execution models for GPUs, and propose a variant in which programs -- rather than threads -- are blocked. For example, in the case of matrix multiplication, CUDA and Triton differ as follows: + +.. table:: + :widths: 50 50 + + +-----------------------------------------------------+-----------------------------------------------------+ + | CUDA Programming Model | Triton Programming Model | + | | | + | (Scalar Program, Blocked Threads) | (Blocked Program, Scalar Threads) | + +=====================================================+=====================================================+ + | | | + |.. code-block:: C |.. code-block:: C | + | | :force: | + | | | + | #pragma parallel | #pragma parallel | + | for(int m = 0; m < M; m++) | for(int m = 0; m < M; m += MB) | + | #pragma parallel | #pragma parallel | + | for(int n = 0; n < N; n++){ | for(int n = 0; n < N; n += NB){ | + | float acc = 0; | float acc[MB, NB] = 0; | + | for(int k = 0; k < K; k++) | for(int k = 0; k < K; k += KB) | + | acc += A[m, k] * B[k, n]; | acc += A[m:m+MB, k:k+KB] | + | | @ B[k:k+KB, n:n+NB]; | + | C[m, n] = acc; | C[m:m+MB, n:n+NB] = acc; | + | } | } | + | | | + +-----------------------------------------------------+-----------------------------------------------------+ + | |pic1| | |pic2| | + +-----------------------------------------------------+-----------------------------------------------------+ + + +.. |pic1| image:: cuda-parallel-matmul.png + +.. |pic2| image:: triton-parallel-matmul.png + +A key benefit of this approach is that it leads to block-structured iteration spaces that offer programmers more flexibility than existing DSLs when implementing sparse operations, all while allowing compilers to aggressively optimize programs for data locality and parallelism. + + +---------- +Challenges +---------- + +The main challenge posed by our proposed paradigm is that of work scheduling, i.e., how the work done by each program instance should be partitioned for efficient execution on modern GPUs. To address this issue, the Triton compiler makes heavy use of *block-level data-flow analysis*, a technique for scheduling iteration blocks statically based on the control- and data-flow structure of the target program. The resulting system actually works surprisingly well: our compiler manages to apply a broad range of interesting optimization automatically (e.g., automatic coalescing, thread swizzling, pre-fetching, automatic vectorization, tensor core-aware instruction selection, shared memory allocation/synchronization, asynchronous copy scheduling). Of course doing all this is not trivial; one of the purposes of this guide is to give you a sense of how it works. + + +---------- +References +---------- + +.. [SUTSKEVER2014] I. Sutskever et al., "Sequence to Sequence Learning with Neural Networks", NIPS 2014 +.. [REDMON2016] J. Redmon et al., "You Only Look Once: Unified, Real-Time Object Detection", CVPR 2016 +.. [LEE2017] K. Lee et al., "Superhuman Accuracy on the SNEMI3D Connectomics Challenge", ArXiV 2017 +.. [BAGHDADI2021] R. Baghdadi et al., "Tiramisu: A Polyhedral Compiler for Expressing Fast and Portable Code", CGO 2021 +.. [VASILACHE2018] N. Vasilache et al., "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions", ArXiV 2018 +.. [JRK2013] J. Ragan-Kelley et al., "Halide: A Language and Compiler for Optimizing Parallelism, Locality, and Recomputation in Image Processing Pipelines", PLDI 2013 +.. [CHEN2018] T. Chen et al., "TVM: An Automated End-to-End Optimizing Compiler for Deep Learning", OSDI 2018 +.. [LAM1991] M. Lam et al., "The Cache Performance and Optimizations of Blocked Algorithms", ASPLOS 1991 +.. [AUGUIN1983] M. Auguin et al., "Opsila: an advanced SIMD for numerical analysis and signal processing", EUROMICRO 1983 diff --git a/third_party/enflame/include/triton/docs/programming-guide/chapter-1/triton-parallel-matmul.png b/third_party/enflame/include/triton/docs/programming-guide/chapter-1/triton-parallel-matmul.png new file mode 100644 index 000000000..7b11ba2af Binary files /dev/null and b/third_party/enflame/include/triton/docs/programming-guide/chapter-1/triton-parallel-matmul.png differ diff --git a/third_party/enflame/include/triton/docs/programming-guide/chapter-2/halide-iteration.png b/third_party/enflame/include/triton/docs/programming-guide/chapter-2/halide-iteration.png new file mode 100644 index 000000000..073634677 Binary files /dev/null and b/third_party/enflame/include/triton/docs/programming-guide/chapter-2/halide-iteration.png differ diff --git a/third_party/enflame/include/triton/docs/programming-guide/chapter-2/polyhedral-iteration.png b/third_party/enflame/include/triton/docs/programming-guide/chapter-2/polyhedral-iteration.png new file mode 100644 index 000000000..02f9c2593 Binary files /dev/null and b/third_party/enflame/include/triton/docs/programming-guide/chapter-2/polyhedral-iteration.png differ diff --git a/third_party/enflame/include/triton/docs/programming-guide/chapter-2/related-work.rst b/third_party/enflame/include/triton/docs/programming-guide/chapter-2/related-work.rst new file mode 100644 index 000000000..fc2e2f1df --- /dev/null +++ b/third_party/enflame/include/triton/docs/programming-guide/chapter-2/related-work.rst @@ -0,0 +1,213 @@ +============ +Related Work +============ + +At first sight, Triton may seem like just yet another DSL for DNNs. The purpose of this section is to contextualize Triton and highlight its differences with the two leading approaches in this domain: polyhedral compilation and scheduling languages. + + +---------------------- +Polyhedral Compilation +---------------------- + +Traditional compilers typically rely on intermediate representations, such as LLVM-IR [LATTNER2004]_, that encode control flow information using (un)conditional branches. This relatively low-level format makes it difficult to statically analyze the runtime behavior (e.g., cache misses) of input programs, and to automatically optimize loops accordingly through the use of tiling [WOLFE1989]_, fusion [DARTE1999]_ and interchange [ALLEN1984]_. To solve this issue, polyhedral compilers [ANCOURT1991]_ rely on program representations that have statically predictable control flow, thereby enabling aggressive compile-time program transformations for data locality and parallelism. Though this strategy has been adopted by many languages and compilers for DNNs such as Tiramisu [BAGHDADI2021]_, Tensor Comprehensions [VASILACHE2018]_, Diesel [ELANGO2018]_ and the Affine dialect in MLIR [LATTNER2019]_, it also comes with a number of limitations that will be described later in this section. + +++++++++++++++++++++++ +Program Representation +++++++++++++++++++++++ + +Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample literature on linear and integer programming. + +.. table:: + :widths: 50 50 + + +-----------------------------------------------------+-----------------------------------------------------+ + | | | + |.. code-block:: C | |pic1| | + | | | + | for(int i = 0; i < 3; i++) | | + | for(int j = i; j < 5; j++) | | + | A[i][j] = 0; | | + +-----------------------------------------------------+-----------------------------------------------------+ + +.. |pic1| image:: polyhedral-iteration.png + :width: 300 + +Polyhedral compilers focus on a class of programs commonly known as **Static Control Parts** (SCoP), *i.e.*, maximal sets of consecutive statements in which conditionals and loop bounds are affine functions of surrounding loop indices and global invariant parameters. As shown above, programs in this format always lead to iteration domains that are bounded by affine inequalities, i.e., polyhedral. These polyhedra can also be defined algebraically; for the above example: + +.. math:: + + \mathcal{P} = \{ i, j \in \mathbb{Z}^2 + ~|~ + \begin{pmatrix} + 1 & 0 \\ + -1 & 0 \\ + -1 & 1 \\ + 0 & -1 \\ + \end{pmatrix} + \begin{pmatrix} + i \\ + j + \end{pmatrix} + + + \begin{pmatrix} + 0 \\ + 2 \\ + 0 \\ + 4 + \end{pmatrix} + \geq + 0 + \} + + +Each point :math:`(i, j)` in :math:`\mathcal{P}` represents a *polyhedral statement*, that is a program statement which (1) does not induce control-flow side effects (e.g., :code:`for`, :code:`if`, :code:`break`) and (2) contains only affine functions of loop indices and global parameters in array accesses. To facilitate alias analysis, array accesses are also mathematically abstracted, using so-called *access function*. In other words, :code:`A[i][j]` is simply :code:`A[f(i,j)]` where the access function :math:`f` is defined by: + +.. math:: + + f(i, j) = \begin{pmatrix} + 1 & 0\\ + 0 & 1\\ + \end{pmatrix} + \begin{pmatrix} + i\\ + j + \end{pmatrix} + = + (i, j) + + +Note that the iteration domains of an SCoP does not specify the order in which its statements shall execute. In fact, this iteration domain may be traversed in many different possible legal orders, i.e. *schedules*. Formally, a schedule is defined as a p-dimensional affine transformation :math:`\Theta` of loop indices :math:`\mathbf{x}` and global invariant parameters :math:`\mathbf{g}`: + +.. math:: + \Theta_S(\mathbf{x}) = T_S \begin{pmatrix} + \vec{x}\\ + \vec{g}\\ + 1 + \end{pmatrix} + \qquad + T_S \in \mathbb{Z} ^{p \times (\text{dim}(\mathbf{x}) + \text{dim}(\mathbf{g}) + 1)} + + +Where :math:`\Theta_S(\mathbf{x})` is a p-dimensional vector representing the slowest to fastest growing indices (from left to right) when traversing the loop nest surrounding :math:`S`. For the code shown above, the original schedule defined by the loop nest in C can be retrieved by using: + +.. math:: + \Theta_S(\mathbf{x}) = \begin{pmatrix} + 1 & 0 \\ + 0 & 1 \\ + \end{pmatrix} + \begin{pmatrix} + i & j + \end{pmatrix}^T + = + \begin{pmatrix} + i & j + \end{pmatrix}^T + + +where :math:`i` and :math:`j` are respectively the slowest and fastest growing loop indices in the nest. If :math:`T_S` is a vector (resp. tensor), then :math:`\Theta_S` is a said to be one-dimensional (resp. multi-dimensional). + +++++++++++ +Advantages +++++++++++ + +Programs amenable to polyhedral compilation can be aggressively transformed and optimized. Most of these transformations actually boil down to the production of schedules and iteration domains that enable loop transformations promoting parallelism and spatial/temporal data locality (e.g., fusion, interchange, tiling, parallelization). + +Polyhedral compilers can also automatically go through complex verification processes to ensure that the semantics of their input program is preserved throughout this optimization phase. Note that polyhedral optimizers are not incompatible with more standard optimization techniques. In fact, it is not uncommon for these systems to be implemented as a set of LLVM passes that can be run ahead of more traditional compilation techniques [GROSSER2012]_. + +All in all, polyhedral machinery is extremely powerful, when applicable. It has been shown to support most common loop transformations, and has indeed achieved performance comparable to state-of-the-art GPU libraries for dense matrix multiplication [ELANGO2018]_. Additionally, it is also fully automatic and doesn't require any hint from programmers apart from source-code in a C-like format. + ++++++++++++ +Limitations ++++++++++++ + +Unfortunately, polyhedral compilers suffer from two major limitations that have prevented its adoption as a universal method for code generation in neural networks. + +First, the set of possible program transformations :math:`\Omega = \{ \Theta_S ~|~ S \in \text{program} \}` is large, and grows with the number of statements in the program as well as with the size of their iteration domain. Verifying the legality of each transformation can also require the resolution of complex integer linear programs, making polyhedral compilation very computationally expensive. To make matters worse, hardware properties (e.g., cache size, number of SMs) and contextual characteristics (e.g., input tensor shapes) also have to be taken into account by this framework, leading to expensive auto-tuning procedures [SATO2019]_. + +Second, the polyhedral framework is not very generally applicable; SCoPs are relatively common [GIRBAL2006]_ but require loop bounds and array subscripts to be affine functions of loop indices, which typically only occurs in regular, dense computations. For this reason, this framework still has to be successfully applied to sparse -- or even structured-sparse -- neural networks, whose importance has been rapidly rising over the past few years. + +On the other hand, blocked program representations advocated by this dissertation are less restricted in scope and can achieve close to peak performance using standard dataflow analysis. + + +-------------------- +Scheduling Languages +-------------------- + +Separation of concerns [DIJKSTRA82]_ is a well-known design principle in computer science: programs should be decomposed into modular layers of abstraction that separate the semantics of their algorithms from the details of their implementation. Systems like Halide and TVM push this philosophy one step further, and enforce this separation at the grammatical level through the use of a **scheduling language**. The benefits of this methodology are particularly visible in the case of matrix multiplication, where, as one can see below, the definition of the algorithm (Line 1-7) is completely disjoint from its implementation (Line 8-16), meaning that both can be maintained, optimized and distributed independently. + +.. code-block:: python + :linenos: + + // algorithm + Var x("x"), y("y"); + Func matmul("matmul"); + RDom k(0, matrix_size); + RVar ki; + matmul(x, y) = 0.0f; + matmul(x, y) += A(k, y) * B(x, k); + // schedule + Var xi("xi"), xo("xo"), yo("yo"), yi("yo"), yii("yii"), xii("xii"); + matmul.vectorize(x, 8); + matmul.update(0) + .split(x, x, xi, block_size).split(xi, xi, xii, 8) + .split(y, y, yi, block_size).split(yi, yi, yii, 4) + .split(k, k, ki, block_size) + .reorder(xii, yii, xi, ki, yi, k, x, y) + .parallel(y).vectorize(xii).unroll(xi).unroll(yii); + + +The resulting code may however not be completely portable, as schedules can sometimes rely on execution models (e.g., SPMD) or hardware intrinsics (e.g., matrix-multiply-accumulate) that are not widely available. This issue can be mitigated by auto-scheduling mechanisms [MULLAPUDI2016]_. + +++++++++++ +Advantages +++++++++++ + +The main advantage of this approach is that it allows programmers to write an algorithm *only once*, and focus on performance optimization separately. It makes it possible to manually specify optimizations that a polyhedral compiler wouldn't be able to figure out automatically using static data-flow analysis. + +Scheduling languages are, without a doubt, one of the most popular approaches for neural network code generation. The most popular system for this purpose is probably TVM, which provides good performance across a wide range of platforms as well as built-in automatic scheduling mechanisms. + ++++++++++++ +Limitations ++++++++++++ + +This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indices without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular. + +.. table:: + :widths: 50 50 + + +-----------------------------------------------------+-----------------------------------------------------+ + | | | + |.. code-block:: C | |pic2| | + | | | + | for(int i = 0; i < 4; i++) | | + | for(int j = 0; j < 4; j++) | | + | float acc = 0; | | + | for(int k = 0; k < K[i]; k++) | | + | acc += A[i][col[i, k]] * B[k][j] | | + | C[i][j] = acc; | | + +-----------------------------------------------------+-----------------------------------------------------+ + +.. |pic2| image:: halide-iteration.png + :width: 300 + +On the other hand, the block-based program representation that we advocate for through this work allows for block-structured iteration spaces and allows programmers to manually handle load-balancing as they wish. + + +---------- +References +---------- + +.. [LATTNER2004] C. Lattner et al., "LLVM: a compilation framework for lifelong program analysis transformation", CGO 2004 +.. [WOLFE1989] M. Wolfe, "More Iteration Space Tiling", SC 1989 +.. [DARTE1999] A. Darte, "On the Complexity of Loop Fusion", PACT 1999 +.. [ALLEN1984] J. Allen et al., "Automatic Loop Interchange", SIGPLAN Notices 1984 +.. [ANCOURT1991] C. Ancourt et al., "Scanning Polyhedra with DO Loops", PPoPP 1991 +.. [BAGHDADI2021] R. Baghdadi et al., "Tiramisu: A Polyhedral Compiler for Expressing Fast and Portable Code", CGO 2021 +.. [VASILACHE2018] N. Vasilache et al., "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions", ArXiV 2018 +.. [ELANGO2018] V. Elango et al. "Diesel: DSL for Linear Algebra and Neural Net Computations on GPUs", MAPL 2018 +.. [LATTNER2019] C. Lattner et al., "MLIR Primer: A Compiler Infrastructure for the End of Moore’s Law", Arxiv 2019 +.. [GROSSER2012] T. Grosser et al., "Polly - Performing Polyhedral Optimizations on a Low-Level Intermediate Representation", Parallel Processing Letters 2012 +.. [SATO2019] Y. Sato et al., "An Autotuning Framework for Scalable Execution of Tiled Code via Iterative Polyhedral Compilation", TACO 2019 +.. [GIRBAL2006] S. Girbal et al., "Semi-Automatic Composition of Loop Transformations for Deep Parallelism and Memory Hierarchies", International Journal of Parallel Programming 2006 +.. [DIJKSTRA82] E. W. Dijkstra et al., "On the role of scientific thought", Selected writings on computing: a personal perspective 1982 +.. [MULLAPUDI2016] R. Mullapudi et al., "Automatically scheduling halide image processing pipelines", TOG 2016 diff --git a/third_party/enflame/include/triton/docs/programming-guide/chapter-3/debugging.rst b/third_party/enflame/include/triton/docs/programming-guide/chapter-3/debugging.rst new file mode 100644 index 000000000..c470363c6 --- /dev/null +++ b/third_party/enflame/include/triton/docs/programming-guide/chapter-3/debugging.rst @@ -0,0 +1,82 @@ +================ +Debugging Triton +================ + +This tutorial provides guidance for debugging Triton programs. +It is mostly documented for Triton users. +Developers interested in exploring Triton's backend, including MLIR code transformation and LLVM code generation, +can refer to this `section `_ to explore debugging options. + +------------------------------------ +Using Triton's Debugging Operations +------------------------------------ + +Triton includes four debugging operators that allow users to check and inspect tensor values: + +- :code:`static_print` and :code:`static_assert` are intended for compile-time debugging. +- :code:`device_print` and :code:`device_assert` are used for runtime debugging. + +:code:`device_assert` executes only when :code:`TRITON_DEBUG` is set to :code:`1`. +Other debugging operators execute regardless of the value of :code:`TRITON_DEBUG`. + +---------------------------- +Using the Interpreter +---------------------------- + +The interpreter is a straightforward and helpful tool for debugging Triton programs. +It allows Triton users to run Triton programs on the CPU and inspect the intermediate results of each operation. +To enable the interpreter mode, set the environment variable :code:`TRITON_INTERPRET` to :code:`1`. +This setting causes all Triton kernels to bypass compilation and be simulated by the interpreter using numpy equivalents of Triton operations. +The interpreter processes each Triton program instance sequentially, executing operations one at a time. + +There are three primary ways to use the interpreter: + +- Print the intermediate results of each operation using the Python :code:`print` function. To inspect an entire tensor, use :code:`print(tensor)`. To examine individual tensor values at :code:`idx`, use :code:`print(tensor.handle.data[idx])`. + +- Attach :code:`pdb` for step-by-step debugging of the Triton program: + + .. code-block:: bash + + TRITON_INTERPRET=1 pdb main.py + b main.py: + r + +- Import the :code:`pdb` package and set breakpoints in the Triton program: + + .. code-block:: python + + import triton + import triton.language as tl + import pdb + + @triton.jit + def kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr): + pdb.set_trace() + offs = tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs) + tl.store(y_ptr + offs, x) + +++++++++++++++++++ +Limitations +++++++++++++++++++ + +The interpreter has several known limitations: + +- It does not support operations on :code:`bfloat16` numeric types. To perform operations on :code:`bfloat16` tensors, use :code:`tl.cast(tensor)` to convert the tensor to :code:`float32`. +- It does not support indirect memory access patterns such as: + + .. code-block:: python + + ptr = tl.load(ptr) + x = tl.load(ptr) + +---------------------------- +Using Third-party Tools +---------------------------- + +For debugging on NVIDIA GPUs, `compute-sanitizer `_ is an effective tool for checking data races and memory access issues. +To use it, prepend :code:`compute-sanitizer` to your command to run the Triton program. + +For debugging on AMD GPUs, you may want to try the LLVM `AddressSanitizer `_ for ROCm. + +For detailed visualization of memory access in Triton programs, consider using the `triton-viz `_ tool, which is agnostic to the underlying GPUs. diff --git a/third_party/enflame/include/triton/docs/python-api/triton-semantics.rst b/third_party/enflame/include/triton/docs/python-api/triton-semantics.rst new file mode 100644 index 000000000..bdf254111 --- /dev/null +++ b/third_party/enflame/include/triton/docs/python-api/triton-semantics.rst @@ -0,0 +1,45 @@ +Triton Semantics +================ + +Triton mostly follows the semantics of NumPy with minor exceptions. In this document, we go over some of the array computing features supported in Triton, and we cover the exceptions where Triton's semantics deviate from that NumPy. + +Type Promotion +-------------- + +**Type Promotion** occurs when tensors of different data types are used in an operation. For binary operations associated to `dunder methods `_ and the ternary function ``tl.where`` on its last two arguments, Triton automatically converts the input tensors to a common data type following a hierarchy of kinds (sets of dtypes): ``{bool} < {integral dypes} < {floating point dtypes}``. + +The algorithm is as follows: + +1. **Kind** If one tensor is of a dtype of a higher kind, the other tensor is promoted to this dtype: ``(int32, bfloat16) -> bfloat16`` + +2. **Width** If both tensors are of dtypes of the same kind, and one of them is of a higher width, the other one is promoted to this dtype: ``(float32, float16) -> float32`` + +3. **Prefer float16** If both tensors are of the same width and signedness but different dtypes (``float16`` and ``bfloat16`` or different ``fp8`` types), they are both promoted to ``float16``. ``(float16, bfloat16) -> float16`` + +4. **Prefer unsigned** Otherwise (same width, different signedness), they are promoted to the unsigned dtype: ``(int32, uint32) -> uint32`` + +The rules are a bit different when they involve a scalar. By scalar here we mean a numeric literal, a variable marked with `tl.constexpr` or a combination of these. These are represented by NumPy scalars and have types ``bool``, ``int`` and ``float``. + +When an operation involves a tensor and a scalar: + +1. If the scalar is of a kind lower or equal to the tensor, it will not participate in the promotion: ``(uint8, int) -> uint8`` + +2. If the scalar is of a higher kind, we choose the lowest dtype in which it fits among ``int32`` < ``uint32`` < ``int64`` < ``uint64`` for ints and ``float32`` < ``float64`` for floats. Then, both the tensor and the scalar are promoted to this dtype: ``(int16, 4.0) -> float32`` + + +Broadcasting +------------ + +**Broadcasting** allows operations on tensors of different shapes by automatically expanding their shapes to a compatible size without copying the data. This follows the following rules: + +1. If one of the tensor shapes is shorter, pad it on the left with ones until both tensors have the same number of dimensions: ``((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))`` + +2. Two dimensions are compatible if they are equal, or if one of them is 1. A dimension of 1 will be expanded to match the dimension of the other tensor. ``((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))`` + + +Differences with NumPy +---------------------- + +**C rounding in integer division** Operators in Triton follow C semantics rather than Python semantics for efficiency. As such, ``int // int`` implements `rounding towards zero as in C `_ for integers of mixed signs, rather than rounding towards minus infinity as in Python. For the same reason, the modulus operator ``int % int`` (which is defined as ``a % b = a - b * (a // b)``) also follows C semantics rather than Python semantics. + +Perhaps confusingly, integer division and modulus follow Python semantics for computations where all the inputs are scalars. diff --git a/third_party/enflame/include/triton/docs/python-api/triton.language.rst b/third_party/enflame/include/triton/docs/python-api/triton.language.rst new file mode 100644 index 000000000..879f13d1e --- /dev/null +++ b/third_party/enflame/include/triton/docs/python-api/triton.language.rst @@ -0,0 +1,225 @@ +triton.language +=============== + +.. currentmodule:: triton.language + + +Programming Model +----------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + tensor + program_id + num_programs + + +Creation Ops +------------ + +.. autosummary:: + :toctree: generated + :nosignatures: + + arange + cat + full + zeros + zeros_like + cast + + +Shape Manipulation Ops +---------------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + broadcast + broadcast_to + expand_dims + interleave + join + permute + ravel + reshape + split + trans + view + + +Linear Algebra Ops +------------------ + +.. autosummary:: + :toctree: generated + :nosignatures: + + dot + dot_scaled + + +Memory/Pointer Ops +---------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + load + store + make_block_ptr + advance + + +Indexing Ops +------------ + +.. autosummary:: + :toctree: generated + :nosignatures: + + flip + where + swizzle2d + + +Math Ops +-------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + abs + cdiv + ceil + clamp + cos + div_rn + erf + exp + exp2 + fdiv + floor + fma + log + log2 + maximum + minimum + rsqrt + sigmoid + sin + softmax + sqrt + sqrt_rn + umulhi + + +Reduction Ops +------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + argmax + argmin + max + min + reduce + sum + xor_sum + +Scan/Sort Ops +------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + associative_scan + cumprod + cumsum + histogram + sort + gather + +Atomic Ops +---------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + atomic_add + atomic_and + atomic_cas + atomic_max + atomic_min + atomic_or + atomic_xchg + atomic_xor + +Random Number Generation +------------------------ + +.. autosummary:: + :toctree: generated + :nosignatures: + + randint4x + randint + rand + randn + + +Iterators +----------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + range + static_range + + +Inline Assembly +----------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + inline_asm_elementwise + + +Compiler Hint Ops +----------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + assume + debug_barrier + max_constancy + max_contiguous + multiple_of + + +Debug Ops +----------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + static_print + static_assert + device_print + device_assert diff --git a/third_party/enflame/include/triton/docs/python-api/triton.rst b/third_party/enflame/include/triton/docs/python-api/triton.rst new file mode 100644 index 000000000..7622692cc --- /dev/null +++ b/third_party/enflame/include/triton/docs/python-api/triton.rst @@ -0,0 +1,13 @@ +triton +====== + +.. currentmodule:: triton + +.. autosummary:: + :toctree: generated + :nosignatures: + + jit + autotune + heuristics + Config diff --git a/third_party/enflame/include/triton/docs/python-api/triton.testing.rst b/third_party/enflame/include/triton/docs/python-api/triton.testing.rst new file mode 100644 index 000000000..c89b0ba42 --- /dev/null +++ b/third_party/enflame/include/triton/docs/python-api/triton.testing.rst @@ -0,0 +1,14 @@ +triton.testing +============== + +.. currentmodule:: triton.testing + +.. autosummary:: + :toctree: generated + :nosignatures: + + Benchmark + do_bench + do_bench_cudagraph + perf_report + assert_close diff --git a/third_party/enflame/include/triton/include/CMakeLists.txt b/third_party/enflame/include/triton/include/CMakeLists.txt new file mode 100644 index 000000000..109c292fe --- /dev/null +++ b/third_party/enflame/include/triton/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(triton) diff --git a/third_party/enflame/include/triton/include/triton/Analysis/Alias.h b/third_party/enflame/include/triton/include/triton/Analysis/Alias.h new file mode 100644 index 000000000..199238bea --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Analysis/Alias.h @@ -0,0 +1,96 @@ +#ifndef TRITON_ANALYSIS_ALIAS_H +#define TRITON_ANALYSIS_ALIAS_H + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir { + +class AliasInfo { +public: + AliasInfo() = default; + AliasInfo(Value value) { insert(value); } + + void insert(Value value) { allocs.insert(value); } + + const DenseSet &getAllocs() const { return allocs; } + + bool operator==(const AliasInfo &other) const { + return allocs == other.allocs; + } + + /// The pessimistic value state of a value without alias + static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) { + return AliasInfo(); + } + static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); } + + /// The union of both arguments + static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs); + + void print(raw_ostream &os) const { + llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); }); + } + +private: + /// The set of allocated values that are aliased by this lattice. + /// For now, we only consider aliased value produced by the following + /// situations: + /// 1. values returned by scf.yield + /// 2. block arguments in scf.for + /// Example: + /// alloc v1 alloc v2 + /// | | + /// |--------------| |------------| + /// scf.for v3 scf.for v4 scf.for v5 + /// | + /// scf.yield v6 + /// + /// v1's alloc [v1] + /// v2's alloc [v2] + /// v3's alloc [v1] + /// v4's alloc [v1, v2] + /// v5's alloc [v2] + /// v6's alloc [v1] + /// + /// Therefore, v1's liveness range is the union of v3, v4, and v6 + /// v2's liveness range is the union of v4 and v5. + DenseSet allocs; +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Alias Analysis +//===----------------------------------------------------------------------===// +class SharedMemoryAliasAnalysis + : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +public: + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::SparseForwardDataFlowAnalysis; + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + + /// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use. + /// Given two values, returns their aliasing behavior. + AliasResult alias(Value lhs, Value rhs); + + /// Returns the modify-reference behavior of `op` on `location`. + ModRefResult getModRef(Operation *op, Value location); + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged(lattice, + lattice->join(AliasInfo::getPessimisticValueState( + lattice->getAnchor()))); + } + + /// Computes if the alloc set of the results are changed. + LogicalResult + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_ALIAS_H diff --git a/third_party/enflame/include/triton/include/triton/Analysis/Allocation.h b/third_party/enflame/include/triton/include/triton/Analysis/Allocation.h new file mode 100644 index 000000000..40352e8ae --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Analysis/Allocation.h @@ -0,0 +1,311 @@ +#ifndef TRITON_ANALYSIS_ALLOCATION_H +#define TRITON_ANALYSIS_ALLOCATION_H + +#include "triton/Analysis/Utility.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/raw_ostream.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include +#include + +namespace mlir { + +namespace triton { +class AllocationAnalysis; + +/// Callback to allow backends to specify target-specific scratch sizes for +/// some operations. +using AllocationAnalysisScratchSizeFn = std::function; + +unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op); + +// To convert a tensor from one layout to another, we need to allocate a +// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may +// require multiple iterations, with each iteration involving multiple +// vectorized loads/stores. The scratch buffer has a shape (`repShape`) that +// represents the maximum size accessed in each dimension during each iteration. +// It is padded (`paddedRepShape`) to avoid bank conflicts and is accessed in a +// specific `order`. +struct ScratchConfig { + SmallVector repShape; + SmallVector paddedRepShape; + SmallVector order; + unsigned inVec; + unsigned outVec; + + ScratchConfig(SmallVector repShape, + SmallVector paddedRepShape, unsigned inVec = 1, + unsigned outVec = 1) + : repShape(repShape), paddedRepShape(paddedRepShape), inVec(inVec), + outVec(outVec) {} + + void print(llvm::raw_ostream &os) const { + os << "repShape: ["; + llvm::interleaveComma(repShape, os); + os << "]"; + os << ", paddedRepShape: ["; + llvm::interleaveComma(paddedRepShape, os); + os << "]"; + os << ", order: ["; + llvm::interleaveComma(order, os); + os << "]"; + os << ", inVec: " << inVec << ", outVec: " << outVec << "\n"; + } +}; + +// For a layout conversion between `srcTy` and `dstTy`, return the vector length +// that can be used for the stores to and loads from shared memory, +// respectively. +std::pair +getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy); + +ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, + RankedTensorType dstTy); + +} // namespace triton + +/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h +/// A class that represents an interval, specified using a start and an end +/// values: [Start, End). +template class Interval { +public: + Interval() {} + Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); } + T start() const { return Start; } + T end() const { return End; } + T size() const { return End - Start; } + bool contains(T Addr) const { return Start <= Addr && Addr < End; } + bool intersects(const Interval &R) const { + return Start < R.End && R.Start < End; + } + bool operator==(const Interval &R) const { + return Start == R.Start && End == R.End; + } + bool operator!=(const Interval &R) const { return !(*this == R); } + bool operator<(const Interval &R) const { + return std::make_pair(Start, End) < std::make_pair(R.Start, R.End); + } + +private: + T Start = std::numeric_limits::min(); + T End = std::numeric_limits::max(); +}; + +template Interval(T, T) -> Interval; + +class Allocation { +public: + /// A unique identifier for shared memory buffers + using BufferId = size_t; + using BufferIdSetT = DenseSet; + using FuncAllocMapT = CallGraph::FuncDataMapT; + + static constexpr BufferId InvalidBufferId = + std::numeric_limits::max(); + + Allocation() = default; + /// Creates a new Allocation analysis that computes the shared memory + /// information for all associated shared memory values. + explicit Allocation(Operation *operation) : operation(operation) {} + + /// Runs allocation analysis on the given top-level operation. + void run(FuncAllocMapT &funcAllocMap, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter); + + /// Returns the operation this analysis was constructed from. + Operation *getOperation() const { return operation; } + + /// Returns the offset of the given buffer in the shared memory. + size_t getOffset(BufferId bufferId) const { + return bufferSet.at(bufferId).offset; + } + + /// Returns the size of the given buffer in the shared memory. + size_t getAllocatedSize(BufferId bufferId) const { + return bufferSet.at(bufferId).size; + } + + /// Returns the allocated interval of the given buffer. + Interval getAllocatedInterval(BufferId bufferId) const { + auto &buffer = bufferSet.at(bufferId); + return Interval(buffer.offset, buffer.offset + buffer.size); + } + + /// Returns the buffer id of the given value. + /// This interface only returns the allocated buffer id. + /// If you want to get all the buffer ids that are associated with the given + /// value, including alias buffers, use getBufferIds. + BufferId getBufferId(Value value) const { + if (valueBuffer.count(value)) { + return valueBuffer.lookup(value)->id; + } else { + return InvalidBufferId; + } + } + + /// Returns all the buffer ids of the given value, including alias buffers. + BufferIdSetT getBufferIds(Value value) const { + BufferIdSetT bufferIds; + auto allocBufferId = getBufferId(value); + if (allocBufferId != InvalidBufferId) + bufferIds.insert(allocBufferId); + for (auto *buffer : aliasBuffer.lookup(value)) { + if (buffer->id != InvalidBufferId) + bufferIds.insert(buffer->id); + } + return bufferIds; + } + + /// Returns the scratch buffer id of the given value. + BufferId getBufferId(Operation *operation) const { + if (opScratch.count(operation)) { + return opScratch.lookup(operation)->id; + } else if (opVirtual.count(operation)) { + return opVirtual.lookup(operation)->id; + } else { + return InvalidBufferId; + } + } + + /// Returns if the given buffer is a virtual buffer. + bool isVirtualBuffer(BufferId bufferId) const { + return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual; + } + + /// Returns the size of total shared memory allocated + size_t getSharedMemorySize() const { return sharedMemorySize; } + + /// Returns mapping from operation to list of live LDS buffers + std::map> getLiveBuffers(); + +private: + /// A class that represents a shared memory buffer + struct BufferT { + /// Explicit: ttg.local_alloc + /// Scratch: ttg.convert_layout + /// Virtual: triton.call + enum class BufferKind { Explicit, Scratch, Virtual }; + + BufferKind kind; + BufferId id; + Operation *owner; + size_t size; + size_t alignment; + size_t offset; + SetVector regionIds; + int sharingGroup; // -1 means not shared + + bool operator==(const BufferT &other) const { return id == other.id; } + bool operator<(const BufferT &other) const { return id < other.id; } + + BufferT(BufferKind kind, BufferId id, Operation *owner, size_t size, + size_t alignment = 4, size_t offset = 0, int sharingGroup = -1) + : kind(kind), id(id), owner(owner), size(size), alignment(alignment), + offset(offset), sharingGroup(sharingGroup) {} + + size_t setOffsetAligned(size_t newOffset) { + return offset = llvm::alignTo(newOffset, alignment); + } + }; + + /// Op -> Scratch Buffer + using OpScratchMapT = llvm::MapVector; + /// Value -> Explicit Buffer + using ValueBufferMapT = llvm::MapVector; + /// Value -> Alias Buffer + using AliasBufferMapT = llvm::MapVector>; + /// BufferId -> Buffer + using BufferSetT = std::map; + +private: + template + void addBuffer(KeyType &key, Args &&...args) { + BufferId nextId = bufferIdCounter++; + auto [it, inserted] = bufferSet.insert_or_assign( + nextId, BufferT(Kind, nextId, key, std::forward(args)...)); + BufferT *buffer = &it->second; + if constexpr (Kind == BufferT::BufferKind::Explicit) { + valueBuffer[key] = buffer; + } else if constexpr (Kind == BufferT::BufferKind::Virtual) { + opVirtual[key] = buffer; + } else { + opScratch[key] = buffer; + } + } + + void addAlias(Value value, Value alloc) { + aliasBuffer[value].insert(valueBuffer[alloc]); + } + +private: + Operation *operation = nullptr; + OpScratchMapT opScratch; + OpScratchMapT opVirtual; + ValueBufferMapT valueBuffer; + AliasBufferMapT aliasBuffer; + BufferSetT bufferSet; + size_t sharedMemorySize = 0; + + size_t bufferIdCounter = 0; + + friend class triton::AllocationAnalysis; +}; + +/// Static analysis that computes the allocation of shared memory buffers +/// of the entire call graph. +/// The allocation is performed in a post-order walk of the call graph. +/// Each call op is treated like convert_layout that allocates a scratch buffer. +/// At each call, we compute the start offset of the scratch buffer and pass it +/// as an argument to the callee. +class ModuleAllocation : public CallGraph { +public: + using FuncOffsetMapT = DenseMap; + + ModuleAllocation(ModuleOp moduleOp, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter = + triton::defaultAllocationAnalysisScratchSizeFn) + : CallGraph(moduleOp) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp); + if (inserted) + iter->second.run(funcMap, scratchSizeGetter); + }); + } + + size_t getSharedMemorySize() { + size_t size = 0; + for (auto funcOp : getRoots()) { + auto *alloc = getFuncData(funcOp); + size = std::max(size, alloc->getSharedMemorySize()); + } + return size; + } + + size_t getSharedMemorySize(FunctionOpInterface funcOp) { + return getFuncData(funcOp)->getSharedMemorySize(); + } + + void setFunctionSharedMemoryValue(FunctionOpInterface funcOp, Value value) { + sharedMemoryValue[funcOp] = value; + } + + Value getFunctionSharedMemoryBase(FunctionOpInterface funcOp) { + return sharedMemoryValue[funcOp]; + } + +private: + FuncOffsetMapT sharedMemoryValue; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_ALLOCATION_H diff --git a/third_party/enflame/include/triton/include/triton/Analysis/AxisInfo.h b/third_party/enflame/include/triton/include/triton/Analysis/AxisInfo.h new file mode 100644 index 000000000..7aff22c74 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Analysis/AxisInfo.h @@ -0,0 +1,230 @@ +#ifndef TRITON_ANALYSIS_AXISINFO_H +#define TRITON_ANALYSIS_AXISINFO_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include +#include + +namespace mlir::triton { + +//===----------------------------------------------------------------------===// +// AxisInfo +//===----------------------------------------------------------------------===// + +/// This lattice value represents known information on the axes of a lattice. +class AxisInfo { +public: + typedef SmallVector DimVectorT; + +public: + AxisInfo() : AxisInfo({}, {}, {}) {} + + AxisInfo(ArrayRef contiguity, ArrayRef divisibility, + ArrayRef constancy) + : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} + + AxisInfo(ArrayRef contiguity, ArrayRef divisibility, + ArrayRef constancy, std::optional constantValue) + : contiguity(contiguity), divisibility(divisibility), + constancy(constancy), constantValue(constantValue) { + assert(divisibility.size() == contiguity.size()); + assert(constancy.size() == contiguity.size()); + } + + // contiguity[d] is the length of the shortest sequence of contiguous integers + // along dimension d. + // + // If we have an array of N elements with a contiguity value C, then the array + // can be divided into a list of N/C sequences of C contiguous elements. + // Since we have N = 2^k, C must be a power of two. + // + // For example, the 2D array + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has contiguity [1, 4], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27], + // [18, 22, 26, 30], + // [19, 23, 27, 31]] + // + // has contiguity [2, 1]. + int64_t getContiguity(size_t dim) const { return contiguity[dim]; } + const DimVectorT &getContiguity() const { return contiguity; } + + // divisibility[d] is the largest power of two that divides the first element + // of all groups of length contiguity[d] along dimension d. + // + // For example, + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has divisibility [1, 2], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27]] + // + // has divisibility [4, 1]. + // + // On the other hand, + // + // [0, 1, 2, 0, 4, 5, 6, 7] + // + // has divisibility 1 because its contiguity is 1. + int64_t getDivisibility(size_t dim) const { return divisibility[dim]; } + const DimVectorT &getDivisibility() const { return divisibility; } + + // constancy[d] is the length of the shortest sequence of repeating integers + // along dimension d. + // + // This is particularly useful to infer the contiguity of operations (e.g. + // add) involving a constant. + // + // If we have an array of N elements, with a constancy value C, then the array + // can be divided into a list of N/C sequences of C elements with the same + // value. Since we have N = 2^k, C must be a power of two. + // + // For example + // + // [[8, 8, 8, 8, 12, 12, 12, 12], + // [16, 16, 16, 16, 20, 20, 20, 20]] + // + // has constancy [1, 4]. + int64_t getConstancy(size_t dim) const { return constancy[dim]; } + const DimVectorT &getConstancy() const { return constancy; } + + int getRank() const { return contiguity.size(); } + + std::optional getConstantValue() const { return constantValue; } + + template + static void + initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity, + DimVectorT *divisibility, DimVectorT *constancy); + + bool operator==(const AxisInfo &other) const { + return contiguity == other.contiguity && + divisibility == other.divisibility && constancy == other.constancy && + constantValue == other.constantValue; + } + + static AxisInfo getPessimisticValueState(Value value); + + // The gcd of both arguments for each dimension + static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); + + void print(raw_ostream &os) const { + auto print = [&](StringRef name, DimVectorT vec) { + os << name << " = ["; + llvm::interleaveComma(vec, os); + os << "]"; + }; + print("contiguity", contiguity); + print(", divisibility", divisibility); + print(", constancy", constancy); + os << ", constant_value = "; + if (constantValue) + os << *constantValue; + else + os << ""; + } + +private: + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + + // The constant value of the lattice if we can infer it. + std::optional constantValue; +}; + +// Module level axis info analysis based on the call graph, assuming that we do +// not have recursive functions. +// +// Since each function will be called multiple times, we need to calculate the +// axis info based on the axis info of all the callers. In the future, we can +// perform optimization using function cloning so that each call site will have +// unique axis info. +using AxisInfoMapT = DenseMap; +class ModuleAxisInfoAnalysis : public CallGraph { +public: + explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp) + : CallGraph(moduleOp) { + SmallVector funcs; + for (auto root : getRoots()) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + funcs.push_back(funcOp); + funcMap.try_emplace(funcOp, AxisInfoMapT{}); + }); + } + SetVector sortedFuncs(funcs.begin(), funcs.end()); + SymbolTableCollection symbolTable; + for (auto funcOp : llvm::reverse(sortedFuncs)) { + initialize(funcOp); + funcOp.walk([&](CallOpInterface callOp) { + auto callee = dyn_cast( + callOp.resolveCallableInTable(&symbolTable)); + update(callOp, callee); + }); + } + } + + AxisInfo *getAxisInfo(Value value) { + auto funcOp = + value.getParentRegion()->getParentOfType(); + auto *axisInfoMap = getFuncData(funcOp); + if (!axisInfoMap) { + return nullptr; + } + auto it = axisInfoMap->find(value); + if (it == axisInfoMap->end()) { + return nullptr; + } + return &(it->second); + } + + unsigned getContiguity(Value value); + unsigned getAlignment(Value value); + + // Overloads of the above methods but have separated elementBitWidth to + // calculate the contiguity. These are useful for computing axis info when + // lowering to hardware intrinsics that require a scalar/warp-uniform base ptr + // with separate per lane offsets like AMD buffer operations. + // + // As a concrete example, instead of a single tensor<128x64x!tt.ptr> + // value, now we have two separate values: !tt.ptr for the base pointer + // and tensor<128x64xi32> for the offset. For such cases, we want to compute + // the contiguity on the offsets but use the pointee element type bit width + // instead of the offset element type bit width for alignment + unsigned getContiguity(Value offsetsValue, unsigned elementBitWidth); + unsigned getAlignment(Value offsetsValue, unsigned elementBitWidth); + + unsigned getMaskAlignment(Value mask); + +private: + void initialize(FunctionOpInterface funcOp); + void update(CallOpInterface callOp, FunctionOpInterface funcOp); +}; + +} // namespace mlir::triton + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Analysis/Membar.h b/third_party/enflame/include/triton/include/triton/Analysis/Membar.h new file mode 100644 index 000000000..f06c4d996 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Analysis/Membar.h @@ -0,0 +1,192 @@ +#ifndef TRITON_ANALYSIS_MEMBAR_H +#define TRITON_ANALYSIS_MEMBAR_H + +#include "Allocation.h" +#include "llvm/ADT/SmallPtrSet.h" + +#include + +namespace mlir { + +class OpBuilder; + +/// Callback to allow backend to provide more information on whether a barrier +/// is needed between two operations. Even though two operations access the same +/// shared memory they may not require a barrier in between them. +using MembarFilterFn = std::function; + +struct BlockInfo { + using IntervalMapT = std::map, std::set>; + + IntervalMapT syncReadIntervals; + IntervalMapT syncWriteIntervals; + + BlockInfo() = default; + + /// Unions two BlockInfo objects. + BlockInfo &join(const BlockInfo &other) { + for (auto &interval : other.syncReadIntervals) + syncReadIntervals[interval.first].insert(interval.second.begin(), + interval.second.end()); + for (auto &interval : other.syncWriteIntervals) + syncWriteIntervals[interval.first].insert(interval.second.begin(), + interval.second.end()); + return *this; + } + + void dump() { + auto &err = llvm::errs(); + err << "Block Interval:\n"; + err << " Read Intervals:\n"; + for (auto &[interval, ops] : syncReadIntervals) { + err << " [" << interval.start() << ", " << interval.end() << "] "; + for (auto &op : ops) + err << op->getName() << " "; + err << "\n"; + } + err << " Write Intervals:\n"; + for (auto &[interval, ops] : syncWriteIntervals) { + err << " [" << interval.start() << ", " << interval.end() << "] "; + for (auto &op : ops) + err << op->getName() << " "; + err << "\n"; + } + } + + /// Returns true if intervals in two BlockInfo objects are intersected. + bool isIntersected(const BlockInfo &other, MembarFilterFn filter) const { + return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals, + filter) || + /*WAR*/ + isIntersected(syncReadIntervals, other.syncWriteIntervals, filter) || + /*WAW*/ + isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter); + } + + /// Clears the intervals because a barrier is inserted. + void sync() { + syncReadIntervals.clear(); + syncWriteIntervals.clear(); + } + + /// Compares two BlockInfo objects. + bool operator==(const BlockInfo &other) const { + return syncReadIntervals == other.syncReadIntervals && + syncWriteIntervals == other.syncWriteIntervals; + } + + bool operator!=(const BlockInfo &other) const { return !(*this == other); } + +private: + bool isIntersected(const IntervalMapT &lhsIntervalSet, + const IntervalMapT &rhsIntervalSet, + MembarFilterFn filter) const { + for (auto &lhs : lhsIntervalSet) + for (auto &rhs : rhsIntervalSet) + if (lhs.first.intersects(rhs.first)) + for (auto lhsOp : lhs.second) + for (auto rhsOp : rhs.second) + if (!filter || !filter(lhsOp, rhsOp)) + return true; + + return false; + } +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Barrier Analysis +//===----------------------------------------------------------------------===// +class MembarAnalysis { + using VirtualBlock = std::pair; + +public: + using FuncBlockInfoMapT = CallGraph::FuncDataMapT; + /// Creates a new Membar analysis that generates the shared memory barrier + /// in the following circumstances: + /// - RAW: If a shared memory write is followed by a shared memory read, and + /// their addresses are intersected, a barrier is inserted. + /// - WAR: If a shared memory read is followed by a shared memory write, and + /// their addresses are intersected, a barrier is inserted. + /// The following circumstances do not require a barrier: + /// - WAW: not possible because overlapped memory allocation is not allowed. + /// - RAR: no write is performed. + /// Temporary storage of operations such as Reduce are considered as both + /// a shared memory read. If the temporary storage is written but not read, + /// it is considered as the problem of the operation itself but not the membar + /// analysis. + MembarAnalysis() = default; + explicit MembarAnalysis(Allocation *allocation, MembarFilterFn filter) + : allocation(allocation), filter(filter) {} + + /// Runs the membar analysis to the given operation, inserts a barrier if + /// necessary. + void run(FuncBlockInfoMapT &funcBlockInfoMap); + +private: + /// Applies the barrier analysis based on the SCF dialect, in which each + /// region has a single basic block only. + /// Example: + /// region1 + /// op1 + /// op2 (scf.if) + /// region2 + /// op3 + /// op4 + /// region3 + /// op5 + /// op6 + /// op7 + /// TODO: Explain why we don't use ForwardAnalysis: + void resolve(FunctionOpInterface funcOp, FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder); + + /// Updates the BlockInfo operation based on the operation. + void update(Operation *operation, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder); + + /// Collects the successors of the terminator + void visitTerminator(Operation *operation, + SmallVector &successors); + + void insertBarrier(Operation *operation, OpBuilder *builder); + +private: + Allocation *allocation = nullptr; + MembarFilterFn filter = nullptr; +}; + +/// Postorder traversal on the callgraph to insert membar instructions +/// of each function. +/// Each function maintains a BlockInfo map that includes all potential buffers +/// after returning. This way users do not have to explicitly insert membars +/// before and after function calls, but might be a bit conservative. +class ModuleMembarAnalysis : public CallGraph { +public: + ModuleMembarAnalysis(ModuleAllocation *moduleAllocation, + MembarFilterFn filter = nullptr) + : CallGraph(moduleAllocation->getModuleOp()), + moduleAllocation(moduleAllocation), filter(filter) {} + + void run() { + walk( + // Pre-order walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order walk callback + [&](FunctionOpInterface funcOp) { + auto *allocation = moduleAllocation->getFuncData(funcOp); + auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo()); + if (inserted) { + MembarAnalysis analysis(allocation, filter); + analysis.run(funcMap); + } + }); + } + +private: + ModuleAllocation *moduleAllocation; + MembarFilterFn filter; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_MEMBAR_H diff --git a/third_party/enflame/include/triton/include/triton/Analysis/Utility.h b/third_party/enflame/include/triton/include/triton/Analysis/Utility.h new file mode 100644 index 000000000..ba5b7fe2e --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Analysis/Utility.h @@ -0,0 +1,419 @@ +#ifndef TRITON_ANALYSIS_UTILITY_H +#define TRITON_ANALYSIS_UTILITY_H + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/LinearLayout.h" + +namespace mlir { + +inline bool isZeroConst(Value v) { + auto constantOp = v.getDefiningOp(); + if (!constantOp) + return false; + if (auto denseAttr = dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + return false; +} + +class ReduceOpHelper { +public: + explicit ReduceOpHelper(triton::ReduceOp op) + : op(op.getOperation()), axis(op.getAxis()) { + auto firstTy = cast(op.getOperands()[0].getType()); + srcShape = firstTy.getShape(); + srcEncoding = firstTy.getEncoding(); + srcElementTypes = op.getElementTypes(); + + for (const auto &t : op.getInputTypes()) { + if (t.getShape() != srcShape) { + op.emitError() << "shape mismatch"; + } + if (t.getEncoding() != srcEncoding) { + op.emitError() << "encoding mismatch"; + } + } + } + + ArrayRef getSrcShape() { return srcShape; } + + Attribute getSrcLayout() { return srcEncoding; } + + triton::ReduceOp getOperation() { return op; } + + unsigned getThreadOffsetOnReductionAxis(); + + bool isWarpSynchronous(); + + unsigned getInterWarpSizeWithUniqueData(); + + unsigned getIntraWarpSizeWithUniqueData(); + + // The shape of the shared memory space needed for the reduction. + SmallVector getScratchRepShape(); + + SmallVector getOrderWithAxisAtBeginning(); + + unsigned getScratchSizeInBytes(); + + bool isReduceWithinCTA(); + +private: + triton::ReduceOp op; + ArrayRef srcShape; + Attribute srcEncoding; + SmallVector srcElementTypes; + int axis; +}; + +class ScanLoweringHelper { +public: + explicit ScanLoweringHelper(triton::ScanOp op) : scanOp(op) { + auto firstTy = cast(op.getOperands()[0].getType()); + srcShape = firstTy.getShape(); + legacyEncoding = firstTy.getEncoding(); + srcEncoding = triton::gpu::toLinearEncoding(legacyEncoding, srcShape); + srcElementTypes = op.getElementTypes(); + // The codegen does not support different element/thread/warp order so + // we choose one a priori. We choose that of the blocked encoding. + // When we generalise this code to other layouts we'll probably need to + // get rid of all this logic and the *Stride auxiliary methods + // and replace them by transposes and reshapes on the LinearLayout + if (auto blockedEncoding = + dyn_cast(legacyEncoding)) { + order = llvm::to_vector(blockedEncoding.getOrder()); + } else { + order = srcEncoding.getOrder(); + } + + for (const auto &t : op.getInputTypes()) { + if (t.getShape() != srcShape) { + op.emitError() << "shape mismatch"; + } + if (t.getEncoding() != legacyEncoding) { + op.emitError() << "encoding mismatch"; + } + } + } + // Return true if the lowering of the scan op is supported. + bool isSupported(); + // Return the number of elements per thread along axis dim. + unsigned getAxisNumElementsPerThread(); + // Return the number of elements per thread along non-axis dims. + unsigned getNonAxisNumElementsPerThread(); + // Return the number of threads per warp along non-axis dims. + unsigned getNonAxisNumThreadsPerWarp(); + // Return the flat numbers of threads computing independent scan results. + unsigned getNonAxisNumThreadsPerCTA(); + // Return the number of warps per CTA along axis dim with unique data. + unsigned getAxisNumWarpsWithUniqueData(); + // Return the number of threads per warp along axis dim with unique data. + unsigned getAxisNumThreadsPerWarpWithUniqueData(); + // Return the number of blocks along axis dim. + unsigned getAxisNumBlocks(); + // Return the number of blocks along non axis dim. + unsigned getNonAxisNumBlocks(); + // Return the size of the scratch space needed for scan lowering. + unsigned getScratchSizeInBytes(); + // Return the number of elements of the scratch space needed for scan + // lowering. + unsigned getScratchSizeInElems(); + + // Stride between contiguous element along axis dim. + unsigned getAxisElementStride(); + // Stride between contiguous threads along axis dim. + unsigned getAxisThreadStride(); + // Stride between contiguous blocks along axis dim. + unsigned getAxisBlockStride(); + + Location getLoc() { return scanOp.getLoc(); } + unsigned getAxis() { return scanOp.getAxis(); } + bool getReverse() { return scanOp.getReverse(); } + triton::gpu::LinearEncodingAttr getEncoding() { return srcEncoding; } + llvm::ArrayRef getShape() { return srcShape; } + unsigned getNumOperands() { return scanOp.getNumOperands(); } + SmallVector getElementTypes() { return srcElementTypes; } + SmallVector getOrder() { return order; } + Region &getCombineOp(); + +private: + triton::ScanOp scanOp; + triton::gpu::LinearEncodingAttr srcEncoding; + Attribute legacyEncoding; + llvm::ArrayRef srcShape; + SmallVector srcElementTypes; + SmallVector order; +}; + +// Helper class for lowering `tt.gather` operations. This class shares lowering +// logic between shared memory allocation and LLVM codegen. +class GatherLoweringHelper { +public: + GatherLoweringHelper(triton::GatherOp gatherOp); + + // Get the shared memory scratch size required by this op. + unsigned getScratchSizeInBytes(); + // Determine if the gather can be performed completely within a warp. + bool isWarpLocal(); + +private: + triton::GatherOp gatherOp; +}; + +// This struct represents a decomposed layout conversion within a warp into +// three transformations: P1 and P2 represent lane-dependent register shuffles +// and W represents a warp shuffle. P2^-1 is returned because it represents the +// (reg, lane) -> (reg) mapping from the perspective of the destination element. +// +// Nearly all layout conversions that only require data movement within a warp +// can be implemented this way. +struct DecomposedWarpConversion { + triton::LinearLayout P1, W, P2inv; + triton::LinearLayout reducedP1, reducedP2inv; +}; + +// Given the source and destination tensor types where a layout conversion only +// involves data movement within warps, attempt to find a decomposition for a +// warp layout conversion. +std::optional +getWarpLayoutConvertDecomposition(RankedTensorType srcTy, + RankedTensorType dstTy); + +// Decomposes a reshape into simpler pieces. +// +// As an example, suppose we have a reshape from [4,4,4] to [2,2,8,2]. +// You might explain what this does as follows. +// +// - Split the first input dimension into [2,2]. +// - Take the remaining two input dimensions, merge them into a single [16] +// dim, and then split that into [8,2]. +// +// In general, a reshape can be described a sequence of smushing one or more +// input dimensions together and then breaking them apart into one or more +// output dimensions. So we could represent the example above as follows. +// +// [ +// ([0], [0, 1]), # input dim [0] -> output dims [0, 1] +// ([1, 2], [2, 3]), # input dims [1, 2] -> output dims [2, 3] +// ] +// +// Notice that the input dims (first tuple elems) appear in sequential order if +// you read left-to-right-top-to-bottom, and so do the output dims. +// +// This function returns the above decomposition. +SmallVector, SmallVector>> +getReshapeDecomposition(ArrayRef srcShape, ArrayRef dstShape); + +// Returns the number of elements in the scratch space needed. +// If shape is empty, it means no shared memory is needed. +unsigned getNumScratchElements(ArrayRef shape); + +bool supportWMMA(triton::DotOp op); + +bool supportMMA(triton::DotOp op, int version); + +bool supportMMA(Value value, int version); + +// Conversion from `srcTy` to `dstTy` involving the minimum amount of data +// transfer provided that both types can be converted to LL (if it can't it'll +// return nullopt). The output will be such that layout.getInDimNames() == +// layout.getOutDimNames() and the conversion will not include kBlock (resp. +// kWarp or kLane) if it can be avoided +triton::LinearLayout minimalCvtLayout(RankedTensorType srcTy, + RankedTensorType dstTy); + +// Conversion from `srcTy` to `dstTy` only involves reordering of registers. +// There is no need for data exchange across threads, warps, or blocks. +bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy); + +// Conversion from `srcTy` to `dstTy` involves data exchange across threads +// within a warp. No data exchange across warps or blocks is needed. +bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy); + +// Conversion from `srcTy` to `dstTy` involves data exchange across threads, +// warps, and possibly blocks. +bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy); + +bool atomicNeedsSharedMemory(Value result); + +// Return true if the src and dst layout match. +bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy); + +// Check if MFMA layout can be converted to the dot operand +// layout using warp shuffle. +bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, + RankedTensorType dstTy); + +// TODO: Move utility functions that belong to ConvertLayoutOp to class +// ConvertLayoutOpHelper in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); + +/// Multi-root DAG topological sort. +/// Performs a topological sort of the Operation in the `toSort` SetVector. +/// Returns a topologically sorted SetVector. +/// It is faster than mlir::topologicalSort because it prunes nodes that have +/// been visited before. +SetVector +multiRootTopologicalSort(const SetVector &toSort); + +/// This uses the toplogicalSort above +SetVector +multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr, + TransitiveFilter forwardFilter = nullptr); + +/// Create a basic DataFlowSolver with constant and dead code analysis included. +std::unique_ptr createDataFlowSolver(); + +/// This class represents a call graph for a given ModuleOp and holds +/// data of type T associated with each FunctionOpInterface. +template class CallGraph { +public: + using FuncDataMapT = DenseMap; + + /// Constructor that builds the call graph for the given moduleOp. + explicit CallGraph(ModuleOp moduleOp) : moduleOp(moduleOp) { build(); } + + /// Walks the call graph and applies the provided update functions + /// to the edges and nodes. + template + void walk(UpdateEdgeFn updateEdgeFn, UpdateNodeFn updateNodeFn) { + DenseSet visited; + for (auto root : roots) { + doWalk(root, visited, updateEdgeFn, + updateNodeFn); + } + } + + /// Retrieves the data associated with a function + T *getFuncData(FunctionOpInterface funcOp) { + if (funcMap.count(funcOp)) { + return &funcMap[funcOp]; + } + return nullptr; + } + + /// Getters + ModuleOp getModuleOp() const { return moduleOp; } + SmallVector getRoots() const { return roots; } + size_t getNumFunctions() const { return funcMap.size(); } + + /// Returns true if the given function is a root. + bool isRoot(FunctionOpInterface funcOp) const { + return llvm::is_contained(roots, funcOp); + } + + /// Maps the data and the graph nodes associated with a funcOp to a + /// targetFuncOp. + template + void mapFuncOp(FROM funcOp, TO targetFuncOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.second == funcOp) { + edge.second = targetFuncOp; + } + } + } + graph[targetFuncOp] = graph[funcOp]; + // Replace in roots + for (auto it = roots.begin(); it != roots.end(); ++it) { + if (*it == funcOp) { + *it = targetFuncOp; + break; + } + } + // Replace in funcMap + funcMap[targetFuncOp] = funcMap[funcOp]; + } + + /// Maps the graph edges associated with a callOp to a targetCallOp. + template + void mapCallOp(FROM callOp, TO targetCallOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.first == callOp) { + edge.first = targetCallOp; + } + } + } + } + +private: + void build() { + SymbolTableCollection symbolTable; + DenseSet visited; + // Build graph + moduleOp.walk([&](Operation *op) { + auto caller = op->getParentOfType(); + if (auto callOp = dyn_cast(op)) { + auto *callee = callOp.resolveCallableInTable(&symbolTable); + auto funcOp = dyn_cast_or_null(callee); + if (funcOp) { + graph[caller].emplace_back( + std::pair(callOp, funcOp)); + visited.insert(funcOp); + } + } + }); + // Find roots + moduleOp.walk([&](FunctionOpInterface funcOp) { + if (!visited.count(funcOp)) { + roots.push_back(funcOp); + } + }); + } + + template + void doWalk(FunctionOpInterface funcOp, + DenseSet &visited, UpdateEdgeFn updateEdgeFn, + UpdateNodeFn updateNodeFn) { + if (visited.count(funcOp)) { + llvm::report_fatal_error("Cycle detected in call graph"); + } + if constexpr (UpdateNodeOrder == WalkOrder::PreOrder) { + updateNodeFn(funcOp); + } + for (auto [callOp, callee] : graph[funcOp]) { + if constexpr (UpdateEdgeOrder == WalkOrder::PreOrder) { + updateEdgeFn(callOp, callee); + } + doWalk(callee, visited, updateEdgeFn, + updateNodeFn); + if constexpr (UpdateEdgeOrder == WalkOrder::PostOrder) { + updateEdgeFn(callOp, callee); + } + } + if constexpr (UpdateNodeOrder == WalkOrder::PostOrder) { + updateNodeFn(funcOp); + } + visited.erase(funcOp); + } + +protected: + ModuleOp moduleOp; + DenseMap>> + graph; + FuncDataMapT funcMap; + SmallVector roots; +}; +// Create a basic DataFlowSolver with constant and dead code analysis included. +std::unique_ptr createDataFlowSolver(); + +triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v); + +} // namespace mlir + +#endif // TRITON_ANALYSIS_UTILITY_H diff --git a/third_party/enflame/include/triton/include/triton/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/CMakeLists.txt new file mode 100644 index 000000000..27c703b3c --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(Target) diff --git a/third_party/enflame/include/triton/include/triton/Conversion/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Conversion/CMakeLists.txt new file mode 100644 index 000000000..730f5cadd --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonGPUToLLVM) +add_subdirectory(TritonToTritonGPU) diff --git a/third_party/enflame/include/triton/include/triton/Conversion/MLIRTypes.h b/third_party/enflame/include/triton/include/triton/Conversion/MLIRTypes.h new file mode 100644 index 000000000..dd8d4be4c --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/MLIRTypes.h @@ -0,0 +1,46 @@ +#ifndef TRITON_CONVERSION_MLIR_TYPES_H +#define TRITON_CONVERSION_MLIR_TYPES_H + +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// This file redefines some common MLIR types for easy usage. +namespace mlir { +namespace triton { +namespace type { + +// Integer types +inline Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); } +inline Type i16Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 16); } +inline Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); } +inline Type u32Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 32, IntegerType::Unsigned); +} +inline Type u1Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 1, IntegerType::Unsigned); +} + +// Float types +inline Type f16Ty(MLIRContext *ctx) { return Float16Type::get(ctx); } +inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); } +inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); } +inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); } + +inline bool isFloat8(Type type) { + return isa(type); +} + +inline bool isFloat(Type type) { + return type.isF32() || type.isF64() || type.isF16() || type.isF128() || + type.isBF16() || llvm::isa(type) || + isFloat8(type); +} + +inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); } + +} // namespace type +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_MLIR_TYPES_H diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h new file mode 100644 index 000000000..00ec88089 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h @@ -0,0 +1,27 @@ +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ + +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir { +class ConversionPatternRewriter; +class Location; + +namespace triton { +using llvm::StringRef; + +inline std::string strJoin(llvm::ArrayRef strs, + llvm::StringRef delimiter) { + return llvm::join(strs.begin(), strs.end(), delimiter); +} + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..93f8374e5 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonGPUToLLVM) +add_public_tablegen_target(TritonGPUConversionPassIncGen) diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h new file mode 100644 index 000000000..656f2bfca --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -0,0 +1,235 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H +#define TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton { + +namespace gpu { + +Type getElementType(Value value); + +class MultipleOperandsRange + : public iterator_range>::iterator> { + using ContainerT = SmallVector>; + +public: + using iterator_range::iterator_range; + ContainerT::reference operator[](ContainerT::size_type idx) { + return begin()[idx]; + } + ContainerT::const_reference operator[](ContainerT::size_type idx) const { + return begin()[idx]; + } + ContainerT::size_type size() const { return end() - begin(); } +}; + +// Base pattern for elementwise conversion using ConcreteT. Unpacks individual +// elements from a `!llvm.struct` via `llvm.extactvalue`, calls +// ConcreteT::createDestOps on each element, and packs them back into an +// `!llvm.struct` using `llvm.insertvalue`. +// +// Also supports processing the inputs in a vectorized form by consuming and +// producing multiple operand sets in ConcreteT::createDestOps. +template +class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { +public: + using OpAdaptor = typename SourceOp::Adaptor; + + explicit ElementwiseOpConversionBase( + LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit), + axisAnalysisPass(axisAnalysisPass) {} + + // Try to deduplicate the resultVals based on the + // constancy properties of the result discovered by + // the axis analysis pass. If possible, redundant + // computation is eliminated. + SmallVector maybeDeduplicate(SourceOp op, + SmallVector resultVals) const { + if (!isMemoryEffectFree(op)) + // the op has side effects: can't dedup + return resultVals; + SmallVector results = op->getResults(); + if (results.size() == 0 || results.size() > 1) + // there must be exactly 1 result + return resultVals; + Value result = results[0]; + Type type = result.getType(); + if (!type) + return resultVals; + RankedTensorType rtType = dyn_cast(type); + if (!rtType) + // the result must be a tensor + return resultVals; + Attribute encoding = rtType.getEncoding(); + if (!encoding) + // encoding not available + return resultVals; + Attribute baseEncoding = encoding; + if (isa(baseEncoding) || + isa(baseEncoding)) + // TODO: this logic seems incorrect for mfma and wmma layout. Skip for + // now. We saw mismatches for some flash-attention and dot tests on AMD + // backend. Note that this logic works for sliced layout whose parent is + // mfma layout. Therefore, this is not combined with the following check. + return resultVals; + while (auto sliced = dyn_cast(baseEncoding)) + baseEncoding = sliced.getParent(); + if (isa(baseEncoding)) { + // TODO: this logic seems incorrect for mma layout. Skip for now. + // The following test crashes and some other miscompile: + // test_core::test_fp8_dot_acc + return resultVals; + } + + SmallVector elemsPerThread = getElemsPerThread(rtType); + int rank = elemsPerThread.size(); + if (product(elemsPerThread) != resultVals.size()) + return resultVals; + AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result); + if (!axisInfo) + // axis info (e.g., constancy) not available + return resultVals; + SmallVector contigPerThread = getContigPerThread(rtType); + if (rank != contigPerThread.size()) + return resultVals; + + SmallVector constancy = axisInfo->getConstancy(); + if (rank != constancy.size()) + return resultVals; + bool hasConstancy = false; + for (int i = 0; i < rank; ++i) { + if (constancy[i] > contigPerThread[i]) { + if (constancy[i] % contigPerThread[i] != 0) + // constancy is not evenly covered by contigPerThread + return resultVals; + // can't move the values across different + // "contigPerThread"-sized blocks + constancy[i] = contigPerThread[i]; + } + if (elemsPerThread[i] < 1 || constancy[i] < 1) + return resultVals; + if (!(elemsPerThread[i] % constancy[i] == 0 || + constancy[i] % elemsPerThread[i] == 0)) + // either the constancy along each dimension must fit + // into the elemsPerThread or the other way around + return resultVals; + if (constancy[i] > 1) + hasConstancy = true; + } + if (!hasConstancy) + // nothing to deduplicate + return resultVals; + + if (rank > 1) { + // reorder the shape and constancy vectors by the axis order: + // from the fastest-changing to the smallest-changing axis + SmallVector order = getOrder(rtType); + if (rank != order.size()) + return resultVals; + elemsPerThread = applyPermutation(elemsPerThread, order); + constancy = applyPermutation(constancy, order); + } + + SmallVector strides(rank, 1); + for (int i = 1; i < rank; ++i) { + strides[i] = strides[i - 1] * elemsPerThread[i - 1]; + } + SmallVector dedupResultVals; + dedupResultVals.reserve(resultVals.size()); + for (int i = 0; i < resultVals.size(); ++i) { + // each coordinate of the orig_idx is "coarsened" using the + // constancy along this dimension: the resulting dedup_idx + // points to the reused value in the original resultsVal + int orig_idx = i; + int dedup_idx = 0; + for (int j = 0; j < rank; ++j) { + int coord_j = orig_idx % elemsPerThread[j]; + dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j]; + orig_idx /= elemsPerThread[j]; + } + dedupResultVals.push_back(resultVals[dedup_idx]); + } + + return dedupResultVals; + } + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resultTy = op.getType(); + Location loc = op->getLoc(); + // element type + auto resultElementTy = getElementTypeOrSelf(resultTy); + Type elemTy = this->getTypeConverter()->convertType(resultElementTy); + SmallVector> allOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto subOperands = unpackLLElements(loc, operand, rewriter); + allOperands.resize(subOperands.size()); + for (auto v : llvm::enumerate(subOperands)) + allOperands[v.index()].push_back(v.value()); + } + if (allOperands.size() == 0) + allOperands.push_back({}); + + SmallVector resultVals; + for (auto it = allOperands.begin(), end = allOperands.end(); it != end;) { + auto curr = static_cast(this)->createDestOps( + op, adaptor, rewriter, elemTy, MultipleOperandsRange(it, end), loc); + if (curr.size() == 0) + return failure(); + for (auto v : curr) { + if (!static_cast(v)) + return failure(); + resultVals.push_back(v); + } + it += curr.size(); + } + resultVals = maybeDeduplicate(op, resultVals); + Value view = packLLElements(loc, this->getTypeConverter(), resultVals, + rewriter, resultTy); + rewriter.replaceOp(op, view); + + return success(); + } + +protected: + ModuleAxisInfoAnalysis &axisAnalysisPass; +}; + +// Trivial case where we map elementwise to an existing LLVM operator +template +struct ElementwiseOpConversion + : public ElementwiseOpConversionBase< + SourceOp, ElementwiseOpConversion> { + using Base = + ElementwiseOpConversionBase>; + using Base::Base; + using OpAdaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0], + adaptor.getAttributes().getValue())}; + } +}; + +} // namespace gpu + +} // namespace mlir::triton +#endif diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h new file mode 100644 index 000000000..907d36ed4 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h @@ -0,0 +1,35 @@ +#ifndef TRITON_CONVERSION_FMA_DOT_UTILITY_H +#define TRITON_CONVERSION_FMA_DOT_UTILITY_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton::gpu { + +/// Abstract interface for scalar multiplication of Value vectors. +/// +/// Enable generation of hardware specific code in different backends. +class FMAVectorMultiplier { +public: + /// \returns scalar product of two arrays, plus c: a·b + c + virtual Value multiplyVectors(ArrayRef a, ArrayRef b, + Value c) = 0; + + virtual ~FMAVectorMultiplier() = default; +}; + +/// Implements a framework for FMA dot conversion to llvm. +/// +/// This function implements architecture independent part of FMA dot +/// conversion and calls "multiplier" object, which is defined by caller +/// and implements architecture dependant part of conversion. +LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FMAVectorMultiplier &multiplier); + +} // namespace mlir::triton::gpu + +#endif // TRITON_CONVERSION_FMA_DOT_UTILITY_H diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Passes.h b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Passes.h new file mode 100644 index 000000000..2a3a67a59 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Passes.h @@ -0,0 +1,25 @@ +#ifndef TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H +#define TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H + +#include "mlir/Pass/Pass.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton::gpu { + +#define GEN_PASS_DECL +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +} // namespace triton::gpu + +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Passes.td b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Passes.td new file mode 100644 index 000000000..fa3cc63c7 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Passes.td @@ -0,0 +1,45 @@ +#ifndef TRITONCOMMONGPU_CONVERSION_PASSES +#define TRITONCOMMONGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> { + let summary = "Add metadata for shared memory allocation"; + + let description = [{ + This pass uses the `ModuleAllocation` analysis to: + - Annotate modules with an attribute with the amount of shared/local + memory used. + - Annotate operations with an offset into the total shared/local memory. + }]; +} + +def TritonGPUGlobalScratchAllocationPass : Pass<"tritongpu-global-scratch-memory-allocation", "mlir::ModuleOp"> { + let summary = "Assign global scratch memory allocation"; + + let description = [{ + Decide on global scratch space memory allocation and assign attributes to each allocation. + }]; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect" + ]; +} + +def TritonGPUAllocateWarpGroups : Pass<"tritongpu-allocate-warp-groups", "mlir::ModuleOp"> { + let summary = "Allocate warp groups"; + + let description = [{ + The `tritongpu-allocate-warp-groups` pass performs warpgroup allocation for + a GPU program. When a GPU program contains warp specialization, additional + warps are launched in addition to the "default" warp group. The "default" + warpgroup executes top-level code in a `tt.func` and its size is specified + by the user via the `num_warps` argument. + + This pass analyzes `ttg.warp_specialize` ops in the program and determines + the total number of needed warps, then attaches the range of warp IDs to + each warpgroup function. + }]; +} + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h new file mode 100644 index 000000000..43a1ac6a1 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -0,0 +1,112 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H + +#include "TargetInfoBase.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/AxisInfo.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::BlockedEncodingAttr; +LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); +namespace mlir { +namespace triton { + +constexpr int patternBenefitDefault = 1; +constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; +constexpr int patternBenefitClampOptimizedPattern = 20; +constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; +constexpr int patternBenefitNvidiaTensorCoreSubviewPattern = 20; + +void populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +// The given callback is invoked at the end of a successful rewrite. The +// callback receives 1) the current source op, 2) the number of issued LLVM +// instructions and 3) their input types. Each MLIR backend can provide a +// callback and, thus, handle backend-specific behaviors. +void populateMemoryOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateMakeRangeOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateViewOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateMinMaxFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool hwNanPropagationSupported, + PatternBenefit benefit); +void populateClampFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateScanOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateRegReallocOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Patterns.h b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Patterns.h new file mode 100644 index 000000000..1e45b2082 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Patterns.h @@ -0,0 +1,27 @@ +#ifndef TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PATTERNS_H +#define TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PATTERNS_H + +#include + +namespace mlir { +class ModuleOp; +class RankedTensorType; + +namespace triton::gpu { + +/// Replaces `blocked -> dot_op` with `blocked -> shared -> dot_op` in the given +/// |module| op because the codegen doesn't handle `blocked -> dot_op` directly. +void decomposeBlockedToDotLayoutConversion(ModuleOp module); + +/// Replaces `mfma -> dot_op` with `mfma -> blocked -> dot_op` in the +/// given |module| op, but bypass the decomposition if |shortcutFn| returns +/// true. +using ShortcutFn = std::function; +void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, + ShortcutFn shortcutFn); + +} // namespace triton::gpu + +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h new file mode 100644 index 000000000..f1fb3cf72 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -0,0 +1,104 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H + +#include "triton/Conversion/MLIRTypes.h" + +namespace mlir::triton { + +class TargetInfoBase { +public: + virtual bool supportMaximumMinimum() const = 0; + + virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0; + + virtual Value ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const = 0; + + // Store/load a value from shared memory, either in the same CTA or, if + // `ctaId` is non-nullopt, in another CTA in the same group. + // + // A target that does not support cross-CTA transfers will assert if ctaId is + // non-nullopt. + // + // Assumes the address is aligned to the width of `val`. + virtual void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const = 0; + virtual Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const = 0; + + void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred) const { + storeDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, val, pred); + } + Value loadShared(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred) const { + return loadDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, elemTy, + pred); + } + + virtual bool canUseStMatrix(RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const = 0; + + virtual void storeMatrixShared(RewriterBase &rewriter, Location loc, + Value ptr, Value val) const = 0; + + virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const = 0; + + virtual Value programId(RewriterBase &rewriter, Location loc, + ModuleOp moduleOp, int axis) const = 0; + + virtual bool warpReduce(RewriterBase &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce, + unsigned interleave) const = 0; + + virtual std::string getMulhiFuncName(Type resultElementTy) const = 0; + // Emits LLVM code with |rewriter| to print a message following the given + // format from the device. |formatStrStart| is the pointer to the start of + // the format string global variable; |args| are the arguments to fill + // placeholders in the format string. + virtual void printf(RewriterBase &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const = 0; + + // Emits LLVM code with |rewriter| to print a message, particularly useful for + // backend debug. |msg| is the message to print, |args| are the arguments to + // fill placeholders in the |msg|. + // NOTE: This function is used for backend debug. DO NOT DELETE. + // Example use: targetInfo.printf(rewriter,"index: %d, value: %f", {index, + // value}); + virtual void printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const = 0; + + // Emits LLVM code with |rewriter| to perform assertion failure with the given + // |message| from the given |func| in |file|. + virtual void assertFail(RewriterBase &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const = 0; + + virtual int getSharedAddressSpace() const = 0; + + virtual int getAddressSpace(Attribute addressSpace) const = 0; + + virtual bool supportVectorizedAtomics() const = 0; + + // Helper used by targets to annotate store operations during lowering to + // llvm. + virtual void storeOpAnnotation(triton::gpu::LocalStoreOp op, + size_t localStoreOpCount, Type type) const {} + + virtual ~TargetInfoBase() {} +}; +} // namespace mlir::triton +#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h new file mode 100644 index 000000000..c8316d02e --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h @@ -0,0 +1,48 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonGPUToLLVMTypeConverter(MLIRContext *ctx, + const LowerToLLVMOptions &option, + const TargetInfoBase &targetInfo, + const DataLayoutAnalysis *analysis = nullptr); + TritonGPUToLLVMTypeConverter(MLIRContext *ctx, + const TargetInfoBase &targetInfo, + const DataLayoutAnalysis *analysis = nullptr); + + Type convertTritonTensorType(RankedTensorType type, + const TargetInfoBase &targetInfo); + Type convertMemDescType(triton::gpu::MemDescType type, + const TargetInfoBase &targetInfo); + Type convertAsyncTokenType(triton::gpu::AsyncTokenType type); + + template + void convertFP8Type() { + addConversion([&](T1 type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }), + addConversion([&](T2 type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }), + addConversion([&](T3 type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }), + addConversion([&](T4 type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); + } +}; + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Utility.h new file mode 100644 index 000000000..0181d183c --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -0,0 +1,898 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H + +#include + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "ttgpu_to_llvm" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::LLVM { +using namespace mlir::triton; + +Value createConstantI1(Location loc, OpBuilder &rewriter, bool v); +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v); +Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v); +Value createConstantF16(Location loc, OpBuilder &rewriter, float v); +Value createConstantBF16(Location loc, OpBuilder &rewriter, float v); +Value createConstantF32(Location loc, OpBuilder &rewriter, float v); +Value createConstantF64(Location loc, OpBuilder &rewriter, double v); +Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type); +Value createIndexConstant(OpBuilder &builder, Location loc, + const TypeConverter *converter, int64_t value); +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value); + +LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc, + LLVMFuncOp funcOp, ValueRange args); +LLVM::CallIntrinsicOp +createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic, + TypeRange types, ValueRange args); +} // namespace mlir::LLVM + +// Is v an integer or floating-point scalar constant equal to 0? +bool isConstantZero(Value v); + +namespace mlir::triton { + +struct TritonLLVMOpBuilder { + TritonLLVMOpBuilder(Location loc, OpBuilder &builder) + : loc(loc), builder(&builder) {} + + // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive + // Operators + template LLVM::SIToFPOp inttofloat(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::IntToPtrOp inttoptr(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::PtrToIntOp ptrtoint(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::ZExtOp zext(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::SExtOp sext(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::FPExtOp fpext(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::FPTruncOp fptrunc(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::TruncOp trunc(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::UDivOp udiv(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::SDivOp sdiv(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::URemOp urem(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::AddOp add(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::SubOp sub(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::FAddOp fadd(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::MulOp mul(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::FMulOp fmul(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::FMAOp fma(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::FNegOp neg(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::SMaxOp smax(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::UMaxOp umax(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::MaxNumOp fmax(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::SMinOp smin(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::UMinOp umin(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::MinNumOp fmin(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::ShlOp shl(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::LShrOp lshr(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::AShrOp ashr(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::AndOp and_(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::XOrOp xor_(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::OrOp or_(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + LLVM::BitcastOp bitcast(Value val, Type type) { + return builder->create(loc, type, val); + } + template + LLVM::AddrSpaceCastOp addrspacecast(Args &&...args) { + return builder->create(loc, + std::forward(args)...); + } + template LLVM::GEPOp gep(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::InsertValueOp insert_val(Args &&...args) { + return builder->create(loc, + std::forward(args)...); + } + template LLVM::ExtractValueOp extract_val(Args &&...args) { + return builder->create(loc, + std::forward(args)...); + } + template + LLVM::InsertElementOp insert_element(Args &&...args) { + return builder->create(loc, + std::forward(args)...); + } + template + LLVM::ExtractElementOp extract_element(Args &&...args) { + return builder->create(loc, + std::forward(args)...); + } + template LLVM::LoadOp load(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::StoreOp store(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + LLVM::FCmpOp fcmp_ogt(Value lhs, Value rhs) { + return builder->create(loc, builder->getI1Type(), + LLVM::FCmpPredicate::ogt, lhs, rhs); + } + LLVM::FCmpOp fcmp_olt(Value lhs, Value rhs) { + return builder->create(loc, builder->getI1Type(), + LLVM::FCmpPredicate::olt, lhs, rhs); + } + LLVM::FCmpOp fcmp_eq(Value lhs, Value rhs) { + return builder->create(loc, builder->getI1Type(), + LLVM::FCmpPredicate::oeq, lhs, rhs); + } + template LLVM::ICmpOp icmp_eq(Args &&...args) { + return builder->create(loc, LLVM::ICmpPredicate::eq, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_ne(Args &&...args) { + return builder->create(loc, LLVM::ICmpPredicate::ne, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_slt(Args &&...args) { + return builder->create(loc, LLVM::ICmpPredicate::slt, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_sle(Args &&...args) { + return builder->create(loc, LLVM::ICmpPredicate::sle, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_sgt(Args &&...args) { + return builder->create(loc, LLVM::ICmpPredicate::sgt, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_sge(Args &&...args) { + return builder->create(loc, LLVM::ICmpPredicate::sge, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_ult(Args &&...args) { + return builder->create(loc, LLVM::ICmpPredicate::ult, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_ule(Args &&...args) { + return builder->create(loc, LLVM::ICmpPredicate::ule, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_ugt(Args &&...args) { + return builder->create(loc, LLVM::ICmpPredicate::ugt, + std::forward(args)...); + } + template LLVM::ICmpOp icmp_uge(Args &&...args) { + return builder->create(loc, LLVM::ICmpPredicate::uge, + std::forward(args)...); + } + template LLVM::SelectOp select(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::AddressOfOp address_of(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + mlir::gpu::BarrierOp barrier() { + return builder->create(loc); + } + template LLVM::UndefOp undef(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::ZeroOp null(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + template LLVM::CallOp call(Args &&...args) { + return builder->create(loc, std::forward(args)...); + } + // Constants + Value int_val(short bitwidth, int64_t val) { + Type ty = builder->getIntegerType(bitwidth); + return builder->create(loc, ty, + builder->getIntegerAttr(ty, val)); + } + Value i1_val(int64_t val) { return int_val(1, val); } + Value true_val() { return int_val(1, true); } + Value false_val() { return int_val(1, false); } + Value f16_val(float v) { return LLVM::createConstantF16(loc, *builder, v); } + Value bf16_val(float v) { return LLVM::createConstantBF16(loc, *builder, v); } + Value f32_val(float v) { return LLVM::createConstantF32(loc, *builder, v); } + Value f64_val(double v) { return LLVM::createConstantF64(loc, *builder, v); } + Value i8_val(int64_t val) { return int_val(8, val); } + Value i16_val(int64_t val) { return int_val(16, val); } + Value i32_val(int64_t val) { return int_val(32, val); } + Value i64_val(int64_t val) { return int_val(64, val); } + + Location loc; + OpBuilder *builder; +}; + +// This builder combines an IRRewriter and a TritonLLVMOpBuilder into one, +// making it easy to create operations with an implicit location and create LLVM +// operations with shorthands. +class TritonLLVMIRRewriter : public IRRewriter, public TritonLLVMOpBuilder { +public: + // Create a builder with an implicit location. Arguments are forwarded to + // IRRewriter's constructor. + template + TritonLLVMIRRewriter(Location loc, Args &&...args) + : IRRewriter(std::forward(args)...), + TritonLLVMOpBuilder(loc, *this) {} + + // Get the implicit location. + Location getLoc() const { return loc; } + // Set the implicit location used to build ops. + void setLoc(Location loc) { this->loc = loc; } + + // Wrapper for op creation that passes an implicit location. + template OpTy create(Args &&...args) { + return OpBuilder::create(loc, std::forward(args)...); + } +}; +} // namespace mlir::triton + +// Types +#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__) +#define int_ty(width) rewriter.getIntegerType(width) +#define i64_ty rewriter.getIntegerType(64) +#define i32_ty rewriter.getIntegerType(32) +#define i16_ty rewriter.getIntegerType(16) +#define i32_ty rewriter.getIntegerType(32) +#define i64_ty rewriter.getIntegerType(64) +#define ui32_ty rewriter.getIntegerType(32, false) +#define ui64_ty rewriter.getIntegerType(64, false) +#define f16_ty rewriter.getF16Type() +#define bf16_ty rewriter.getBF16Type() +#define i8_ty rewriter.getIntegerType(8) +#define i1_ty rewriter.getI1Type() +#define f32_ty rewriter.getF32Type() +#define f64_ty rewriter.getF64Type() +#define vec_ty(type, num) VectorType::get(num, type) +#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) +#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__) +#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count) + +// Attributes +#define i32_arr_attr(...) rewriter.getI32ArrayAttr({__VA_ARGS__}) +#define i64_arr_attr(...) rewriter.getI64ArrayAttr({__VA_ARGS__}) +#define str_attr(str) ::mlir::StringAttr::get(ctx, (str)) + +namespace mlir { +namespace triton { + +static inline void insertBarrier(OpBuilder &builder, Operation *op) { + auto barrierOp = builder.create(op->getLoc()); + auto asyncTaskIds = getAsyncTaskIds(op); + assert(asyncTaskIds.size() <= 1); + if (asyncTaskIds.size() == 1) { + int asyncTaskId = asyncTaskIds[0]; + int barId = asyncTaskId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + auto mod = op->getParentOfType(); + int numWarps = mlir::triton::gpu::lookupNumWarps(op); + int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + int numThreads = numWarps * warpSize; + barrierOp->setAttr("bar_id", builder.getI64IntegerAttr(barId)); + barrierOp->setAttr("num_threads", builder.getI64IntegerAttr(numThreads)); + } +} + +// Delinearize supposing order is [0, 1, .. , n] +template +llvm::SmallVector getMultiDimIndexImpl(T linearIndex, + llvm::ArrayRef shape) { + // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c} + size_t rank = shape.size(); + T accMul = product(shape.drop_back()); + T linearRemain = linearIndex; + llvm::SmallVector multiDimIndex(rank); + for (int i = rank - 1; i >= 0; --i) { + multiDimIndex[i] = linearRemain / accMul; + linearRemain = linearRemain % accMul; + if (i != 0) { + accMul = accMul / shape[i - 1]; + } + } + return multiDimIndex; +} + +template +llvm::SmallVector getMultiDimIndex(T linearIndex, llvm::ArrayRef shape, + llvm::ArrayRef order) { + size_t rank = shape.size(); + assert(rank == order.size()); + auto reordered = applyPermutation(shape, order); + auto reorderedMultiDim = getMultiDimIndexImpl(linearIndex, reordered); + llvm::SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +// Linearize supposing order is [0, 1, .. , n] +template +T getLinearIndexImpl(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape) { + assert(multiDimIndex.size() == shape.size()); + // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c} + size_t rank = shape.size(); + T accMul = product(shape.drop_back()); + T linearIndex = 0; + for (int i = rank - 1; i >= 0; --i) { + linearIndex += multiDimIndex[i] * accMul; + if (i != 0) { + accMul = accMul / shape[i - 1]; + } + } + return linearIndex; +} + +template +T getLinearIndex(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape, + llvm::ArrayRef order) { + assert(shape.size() == order.size()); + return getLinearIndexImpl(applyPermutation(multiDimIndex, order), + applyPermutation(shape, order)); +} + +namespace gpu { +Type getFunctionType(Type resultType, ValueRange operands); + +LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type funcType, + StringRef libname = "", + StringRef libpath = ""); +} // namespace gpu + +} // namespace triton + +namespace LLVM { +using namespace mlir::triton; + +// Is v an integer or floating-point scalar constant equal to 0? +bool isConstantZero(Value v); + +class SharedMemoryObject { +public: + SharedMemoryObject(Value base, Type baseElemType, ArrayRef offsets) + : base(base), baseElemType(baseElemType), + offsets(offsets.begin(), offsets.end()) {} + + SharedMemoryObject(Value base, Type baseElemType, int64_t rank, Location loc, + RewriterBase &rewriter) + : base(base), baseElemType(baseElemType) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + offsets.append(rank, b.i32_val(0)); + } + + SmallVector getOffsets() const { return offsets; } + Value getBase() const { return base; } + Type getBaseElemType() const { return baseElemType; } + + SmallVector getElems() const { + SmallVector elems; + elems.push_back(base); + elems.append(offsets.begin(), offsets.end()); + return elems; + } + + SmallVector getTypes() const { + SmallVector types; + types.push_back(base.getType()); + types.append(offsets.size(), IntegerType::get(base.getContext(), 32)); + return types; + } + + SmallVector getStrides(triton::gpu::MemDescType memDesc, Location loc, + RewriterBase &rewriter) const { + auto allocShape = memDesc.getAllocShape(); + auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA( + memDesc.getEncoding(), allocShape); + auto layoutOrder = triton::gpu::getOrder(memDesc); + auto allocStrides = SharedMemoryObject::getStridesForShape( + allocShapePerCTA, layoutOrder, loc, rewriter); + return SmallVector(allocStrides.end() - offsets.size(), + allocStrides.end()); + } + + // TODO(Keren): deprecate the method once AMD backend has cleaned up + Value getCSwizzleOffset(int dim) const { + assert(dim >= 0 && dim < offsets.size()); + return offsets[dim]; + } + + // TODO(Keren): deprecate the method once AMD backend has cleaned up + Value getBaseBeforeSlice(int dim, Location loc, + RewriterBase &rewriter) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value cSwizzleOffset = getCSwizzleOffset(dim); + Value offset = b.sub(b.i32_val(0), cSwizzleOffset); + Type type = base.getType(); + return b.gep(type, baseElemType, base, offset); + } + +private: + static SmallVector + getOrderForShape(ArrayRef shape, ArrayRef layoutOrder) { + SmallVector order(shape.size()); + // Default minor-to-major order + std::iota(order.rbegin(), order.rend(), 0); + if (layoutOrder.size() > 0) { + // If a layout order is provided, we assume it specifies the order in + // which the dimensions are first accessed, and unspecified dimensions + // retain the minor-to-major order. For example, if order = [2, 1, 0] and + // layoutOrder = [0, 1], we need to shift `layoutOrder` + // by -1 (move them right). The resulting order will then be [1, 2, 0]. + int rankDiff = layoutOrder.size() - shape.size(); + auto minRank = std::min(shape.size(), layoutOrder.size()); + for (size_t i = 0; i < minRank; ++i) + order[i] = layoutOrder[i] - rankDiff; + } + assert(isPermutationOfIota(order) && "Invalid order"); + return order; + } + + static SmallVector getStridesForShape(ArrayRef shape, + ArrayRef layoutOrder, + Location loc, + RewriterBase &rewriter) { + SmallVector strides(shape.size()); + auto order = SharedMemoryObject::getOrderForShape(shape, layoutOrder); + int64_t stride = 1; + auto b = TritonLLVMOpBuilder(loc, rewriter); + for (auto idx : order) { + strides[idx] = b.i32_val(stride); + stride *= shape[idx]; + } + return strides; + } + + Value base; // i32 ptr. The start address of the shared memory object. + Type baseElemType; + SmallVector + offsets; // i32 int. The offsets are zero at the initial allocation. +}; + +Value getStructFromSharedMemoryObject(Location loc, + const SharedMemoryObject &smemObj, + RewriterBase &rewriter); + +SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, + Value llvmStruct, + Type elemTy, + RewriterBase &rewriter); + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape, + ArrayRef order); + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + unsigned linear, ArrayRef shape); + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape); + +SmallVector delinearize(unsigned linear, ArrayRef shape, + ArrayRef order); + +// Returns a tuple with the delinearized coordinates and a boolean which is true +// iff the Value is not broadcasted (equivalently, if the value is the "first" +// lane/thread/etc. that holds the given value). In mathy terms, the boolean is +// true if the element is the canonical representative of the class. +std::tuple, Value> +delinearize(RewriterBase &rewriter, Location loc, + triton::gpu::DistributedEncodingTrait layout, + ArrayRef shape, StringAttr dimName, Value linear); + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order); + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape); + +size_t linearize(ArrayRef multiDim, ArrayRef shape, + ArrayRef order); + +Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, + StringRef content); + +inline bool isKernel(FunctionOpInterface funcOp) { + return funcOp.getVisibility() == SymbolTable::Visibility::Public; +} + +inline Value getStackPointer(RewriterBase &rewriter, + FunctionOpInterface funcOp) { + // See NOTE: [Additional Function Arguments] + if (!isKernel(funcOp)) { + return funcOp.getArgument(funcOp.getNumArguments() - 2); + } + + auto mod = funcOp->getParentOfType(); + auto globalBase = dyn_cast(mod.lookupSymbol("global_smem")); + assert(globalBase); + return rewriter.create(funcOp.getLoc(), globalBase); +} + +inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter, + FunctionOpInterface funcOp, + Value allocOffset = {}) { + // See NOTE: [Additional Function Arguments] + if (!isKernel(funcOp)) { + // Base for this function + auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1); + if (!allocOffset) { + return gmemBase; + } + + auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1); + auto b = TritonLLVMOpBuilder(loc, rewriter); + return b.gep(ptrTy, i8_ty, gmemBase, allocOffset); + } + + // Base for entire kernel + auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1); + + ModuleOp mod = funcOp.getOperation()->getParentOfType(); + auto allocSizeAttr = mod.getOperation()->getAttrOfType( + "ttg.global_scratch_memory_size"); + if (!allocSizeAttr) { + return gmemBase; + } + + Value gridIdx[3]; + Value gridDim[2]; + for (int k = 0; k < 3; ++k) { + gridIdx[k] = rewriter.create(loc, k); + } + for (int k = 0; k < 2; ++k) { + gridDim[k] = rewriter.create(loc, k); + } + + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value linearId = gridIdx[2]; + for (int k = 0; k < 2; ++k) { + linearId = b.add(gridIdx[1 - k], b.mul(linearId, gridDim[1 - k])); + } + + auto allocSize = allocSizeAttr.getValue().getZExtValue(); + + Value offset = b.mul(linearId, b.i32_val(allocSize)); + if (allocOffset) { + offset = b.add(offset, allocOffset); + } + + auto *ctx = rewriter.getContext(); + auto res = + b.gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset); + return res; +} + +inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Operation *op) { + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), + target.getSharedAddressSpace()); + auto func = op->template getParentOfType(); + if (!func) + func = cast(op); + + assert(op->hasAttr("allocation.offset")); + size_t offset = cast(op->getAttr("allocation.offset")) + .getValue() + .getZExtValue(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value offVal = b.i32_val(offset); + Value base = + b.gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); + return base; +} + +// ----------------------------------------------------------------------- +// MXFP utilities +// ----------------------------------------------------------------------- + +// Scale a mxfp4 value by a given scale. +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale, + bool fastMath); + +} // namespace LLVM + +// ----------------------------------------------------------------------- +// Hardware Indices +// ----------------------------------------------------------------------- + +// If an operation is contained within a warp specialize region, this returns +// the thread ID offset of that warpgroup. +std::optional getWarpGroupStartThreadId(Block *block); + +// Returns CTA level thread ID. +Value getThreadId(OpBuilder &rewriter, Location loc); + +// Get the lane ID, which is index of the thread within its warp. +Value getLaneId(OpBuilder &rewriter, Location loc); + +// Get the lane ID and warp ID. +std::pair getLaneAndWarpId(OpBuilder &rewriter, Location loc); + +// ----------------------------------------------------------------------- +// Shared memory utilities +// ----------------------------------------------------------------------- +using LLVM::SharedMemoryObject; +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::SharedMemoryObject; +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::CTALayoutAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, + ArrayRef strides) { + assert(offsets.size() == strides.size()); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value ret = b.i32_val(0); + for (auto [offset, stride] : llvm::zip(offsets, strides)) { + ret = b.add(ret, b.mul(offset, stride)); + } + return ret; +} + +/// Extend 2d shared object to 3d. +/// +/// If tensor has 3 dimensions, returns original shared object. +/// If tensor shape is [M, N], return shared object describing shape [1, M, N] +/// +/// This Function is used to simplify processing of 2d and 3d dot operands, +/// particularly in the conversion of local_load operation. +/// +/// \param rewriter +/// \param loc +/// \param smemObj +/// \param shape shape of a tensor represented by smemObj +/// \returns shared object describing 3d tensor +SharedMemoryObject +getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, + SharedMemoryObject smemObj, + ArrayRef shape); + +// "Applies" the given layout by computing layout(indices) and returning the +// resulting Values. +// +// In other words, this generates LLVM-dialect MLIR code to "run" the layout +// function. +SmallVector> +applyLinearLayout(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices); + +SmallVector> emitOffsetForLayout(Attribute layout, + RankedTensorType type); + +// Emit indices calculation within each ConversionPattern, and returns a +// [elemsPerThread X rank] index matrix. +// +// For example, for a thread a owns `elemsPerThread` elements of a tensor with +// type `type` and layout `layout`, the result will contain `elemsPerThread` +// vectors. Each vector contains the SSA values of the indices required to +// access the corresponding element, starting from the inner dimension. +SmallVector> +emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + Attribute layout, RankedTensorType type, bool withCTAOffset); + +// Emits IR to load data from shared memory into registers, or to store data +// from registers into shared memory. +// +// You supply perVectorCallback, which is called once per group of register +// elements to transfer. You can use this callback to emit IR to load or store +// data from or to shared memory. +// +// elemLlvmTy should be dstTy's element type converted to an LLVM-dialect type. +// +// If maxVecElems is provided, we won't vectorize more than this many elements. +// +// Returns true on success. +[[nodiscard]] bool emitTransferBetweenRegistersAndShared( + RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, + Type elemLlvmTy, std::optional maxVecElems, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, + std::function perVectorCallback); + +[[nodiscard]] bool emitTransferBetweenRegistersAndShared( + LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, + std::optional maxVecElems, const SharedMemoryObject &smemObj, + Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + std::function perVectorCallback); + +SmallVector loadSharedToDistributed(RankedTensorType dstTy, + triton::gpu::MemDescType srcTy, + Type elemLlvmTy, + const SharedMemoryObject &smemObj, + Location loc, RewriterBase &rewriter, + const TargetInfoBase &target); + +void storeDistributedToShared( + triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, + ArrayRef srcVals, const SharedMemoryObject &smemObj, Location loc, + RewriterBase &rewriter, const TargetInfoBase &target, + std::pair *const llvmOpCount = nullptr); + +inline SmallVector unpackLLElements(Location loc, Value llvmStruct, + RewriterBase &rewriter) { + assert(bool(llvmStruct) && "can not unpack null values"); + if (llvmStruct.getType().isIntOrIndexOrFloat() || + isa(llvmStruct.getType()) || + isa(llvmStruct.getType())) + return {llvmStruct}; + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector results(types.size()); + auto b = TritonLLVMOpBuilder(loc, rewriter); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + results[i] = b.extract_val(type, llvmStruct, i); + } + return results; +} + +inline Value packLLElements(Location loc, + const LLVMTypeConverter *typeConverter, + ValueRange resultVals, RewriterBase &rewriter, + Type type) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (!structType) { + assert(resultVals.size() == 1); + return *resultVals.begin(); + } + + auto elementTypes = structType.getBody(); + if (elementTypes.size() != resultVals.size()) { + emitError(loc) << " size mismatch when packing elements for LLVM struct" + << " expected " << elementTypes.size() << " but got " + << resultVals.size(); + } + Value llvmStruct = rewriter.create(loc, structType); + auto b = TritonLLVMOpBuilder(loc, rewriter); + for (const auto &v : llvm::enumerate(resultVals)) { + if (!v.value()) { + emitError(loc) + << "cannot insert null values into struct, but tried to insert" + << v.value(); + } + if (v.value().getType() != elementTypes[v.index()]) { + LDBG("type " << type << " structType " << structType); + LDBG("value " << v.value()); + emitError(loc) << "invalid element type in packLLElements. Expected " + << elementTypes[v.index()] << " but got " + << v.value().getType(); + } + llvmStruct = b.insert_val(structType, llvmStruct, v.value(), v.index()); + } + return llvmStruct; +} + +inline SmallVector unpackLLVector(Location loc, Value llvmVec, + RewriterBase &rewriter) { + assert(bool(llvmVec) && "cannot unpack null value"); + if (llvmVec.getType().isIntOrIndexOrFloat() || + isa(llvmVec.getType()) || + isa(llvmVec.getType())) + return {llvmVec}; + + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector results; + for (int i = 0; i < cast(llvmVec.getType()).getNumElements(); + i++) { + results.push_back(b.extract_element(llvmVec, b.i32_val(i))); + } + return results; +} + +inline Value packLLVector(Location loc, ValueRange vals, + RewriterBase &rewriter) { + assert(vals.size() > 0); + auto vecType = vec_ty(vals[0].getType(), vals.size()); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value vec = b.undef(vecType); + for (int i = 0; i < vals.size(); i++) { + vec = b.insert_element(vec, vals[i], b.i32_val(i)); + } + return vec; +} + +inline bool +isSimpleSharedMemoryAccess(ArrayRef shape, + ArrayRef allocShape, + triton::gpu::SharedEncodingTrait sharedEnc) { + auto rank = shape.size(); + auto swizzledLayout = + dyn_cast(sharedEnc); + bool noSwizzling = swizzledLayout && swizzledLayout.getMaxPhase() == 1; + return /*no swizzling*/ noSwizzling || + /*swizzling but same shape*/ shape == allocShape || + /*swizzling and rank-reduced and rank >= 2*/ + (shape == allocShape.take_back(rank) && rank >= 2); +} + +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 000000000..99d90c4d7 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonGPU) +add_public_tablegen_target(TritonConversionPassIncGen) diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/Passes.h b/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/Passes.h new file mode 100644 index 000000000..e159406b3 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_CONVERSION_PASSES_H +#define TRITON_CONVERSION_PASSES_H + +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/Passes.td b/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/Passes.td new file mode 100644 index 000000000..f20c36040 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/Passes.td @@ -0,0 +1,43 @@ +#ifndef TRITON_CONVERSION_PASSES +#define TRITON_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> { + let summary = "Convert Triton to TritonGPU"; + let description = [{ + This pass converts the Triton Dialect into the TritonGPU Dialect. + This is a partial conversion that also affects other dialects + (namely `Arith`, `Math`, `SCF` and `CF`). + For these dialects, and many Triton dialect operations the conversions + mainly consists of enhancing the tensor type and the `tt.ptr>` + type with an appropriate layout encoding (these encodings generally + include information on `numWarps`, `threadsPerWarp` and `numCTAs`). + }]; + let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + // TODO: Does this pass depend on SCF? + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect"]; + + let options = [ + Option<"numWarps", "num-warps", + "int32_t", /*default*/"4", + "number of warps">, + + Option<"threadsPerWarp", "threads-per-warp", + "int32_t", /*default*/"32", + "number of threads per warp">, + Option<"numCTAs", "num-ctas", + "int32_t", /*default*/"1", + "number of ctas in a cga">, + Option<"target", "target", + "std::string", /*default*/"\"\"", + "the GPU target, e.g., cuda:80, hip:gfx942"> + ]; +} + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h b/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h new file mode 100644 index 000000000..c7621b9ae --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h @@ -0,0 +1,26 @@ +#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H +#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H + +#include +#include +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +// Create the pass with numWarps passed from cl::opt. +std::unique_ptr> createConvertTritonToTritonGPUPass(); + +// Create the pass with numWarps set explicitly. +std::unique_ptr> +createConvertTritonToTritonGPUPass(const std::string &target, int numWarps, + int threadsPerWarp = 32, int numCTAs = 1); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Dialect/CMakeLists.txt new file mode 100644 index 000000000..6ef40db00 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Triton) +add_subdirectory(TritonGPU) +add_subdirectory(TritonNvidiaGPU) diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..fecd5adf6 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,27 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS TritonDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS TritonTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) + +set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td) +mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs) + +add_public_tablegen_target(TritonTableGen) diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Dialect.h b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Dialect.h new file mode 100644 index 000000000..fe8af8499 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Dialect.h @@ -0,0 +1,108 @@ +#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITON_IR_DIALECT_H_ + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h.inc" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/Triton/IR/OpsEnums.h.inc" +#include "triton/Dialect/Triton/IR/Traits.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.h.inc" + +namespace mlir { +namespace triton { + +struct GlobalMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +class DialectInferLayoutInterface + : public DialectInterface::Base { +public: + DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {} + + virtual LogicalResult + inferTransOpEncoding(Attribute operandEncoding, ArrayRef shape, + ArrayRef order, + Attribute &resultEncoding) const = 0; + + virtual LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding) const = 0; + + virtual LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const = 0; + + // Note: This function only verifies the operand encoding. It doesn't infer + // the result encoding. + virtual LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional location) const = 0; + + // Tries to compute the encoding for the result of a reshape operation that + // makes the reshape a "nop", i.e. the same GPU threads contain the same + // elements as before the reshape using legacy layouts. This is not always + // possible (in which case we fallback to using LinearLayouts) + // In the future we'll always use LinearLayouts + virtual LogicalResult + inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const = 0; + + // Check if two layouts are structurally the same, even if their names are + // different + virtual LogicalResult + verifyLayoutsAreEqual(ArrayRef shape, Attribute expected, + Attribute got, std::optional loc) const = 0; + + virtual LogicalResult + inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const = 0; + + virtual LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const = 0; + + // Verify that the encoding are compatible to be used together in a dot + // operation + virtual LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const = 0; + + virtual LogicalResult + inferFp4ToFpOpEncoding(ArrayRef shape, int axis, Attribute inEnc, + Attribute &outEnc, bool fwdInference, + std::optional loc) const = 0; +}; + +class DialectVerifyTensorLayoutInterface + : public DialectInterface::Base { +public: + DialectVerifyTensorLayoutInterface(Dialect *dialect) : Base(dialect) {} + + virtual LogicalResult + verifyTensorLayout(Attribute layout, RankedTensorType type, Operation *op, + function_ref emitError) const = 0; +}; + +} // namespace triton +} // namespace mlir + +#endif // TRITON_IR_DIALECT_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Interfaces.h b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Interfaces.h new file mode 100644 index 000000000..f8f3a6f74 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Interfaces.h @@ -0,0 +1,9 @@ +#ifndef TRITON_IR_INTERFACES_H_ +#define TRITON_IR_INTERFACES_H_ + +#include "mlir/IR/OpDefinition.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/OpInterfaces.h b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/OpInterfaces.h new file mode 100644 index 000000000..d9392d89e --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/OpInterfaces.h @@ -0,0 +1,23 @@ +#ifndef TRITON_IR_OP_INTERFACES_H_ +#define TRITON_IR_OP_INTERFACES_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { + +namespace triton { + +namespace impl { + +LogicalResult verifyTransposeOpInterface(Operation *op); + +LogicalResult verifyDotOpInterface(Operation *op); + +} // namespace impl + +} // namespace triton +} // namespace mlir + +#include "triton/Dialect/Triton/IR/OpInterfaces.h.inc" + +#endif // TRITON_IR_OP_INTERFACES_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Traits.h b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Traits.h new file mode 100644 index 000000000..b17dbce63 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Traits.h @@ -0,0 +1,125 @@ +#ifndef TRITON_IR_TRAITS_H_ +#define TRITON_IR_TRAITS_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir { +namespace OpTrait { + +// These functions are out-of-line implementations of the methods in the +// corresponding trait classes. This avoids them being template +// instantiated/duplicated. +namespace impl { +// The rationale for this trait is to prevent users from creating programs +// that would have catastrophic register pressure and cause the compiler to +// hang. +// Since H100 has 256KB registers, we should allow users to create tensors +// of size up to 256K elements. It will spill for datatypes wider than 1B, +// but we probably should limit number of elements (rather than bytes) to +// keep specs simple +int constexpr maxTensorNumElements = 1048576; + +LogicalResult verifyTensorSize(Operation *op); +LogicalResult verifyTensorLayouts(Operation *op); + +LogicalResult verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType = false); +LogicalResult verifyEquivalentType(Type typeA, Type typeB); +LogicalResult +verifySameOperandsAndResultEncoding(Operation *op, + bool allowTensorPointerType = false); + +LogicalResult verifySameLoadStoreOperandsShape(Operation *op); + +LogicalResult verifySameLoadStoreOperandsAndResultShape(Operation *op); + +} // namespace impl + +template +class TensorSizeTrait : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorSize(op); + } +}; + +// Trait applied to all Triton MLIR ops. Checks that the layouts of tensors are +// valid. +template +class VerifyTensorLayoutsTrait + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorLayouts(op); + } +}; + +template +class SameOperandsAndResultEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultEncoding(op); + } +}; + +template +class SameOperandsEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsEncoding(op); + } +}; + +template +class SameLoadStoreOperandsShape + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameLoadStoreOperandsShape(op); + } +}; + +template +class SameLoadStoreOperandsAndResultShape + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameLoadStoreOperandsAndResultShape(op); + } +}; + +template +class SameLoadStoreOperandsEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsEncoding(op, + /*allowTensorPointerType=*/true); + } +}; + +template +class SameLoadStoreOperandsAndResultEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultEncoding( + op, /*allowTensorPointerType=*/true); + } +}; + +// This trait indicates that regions in the op may execute concurrently with +// each other. +template +struct AsyncRegions : public TraitBase {}; + +} // namespace OpTrait +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonAttrDefs.td new file mode 100644 index 000000000..1e7e663ad --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -0,0 +1,137 @@ +#ifndef TRITON_ATTR_DEFS +#define TRITON_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" + +// Attributes for LoadOp and StoreOp +def TT_CacheModifierAttr : I32EnumAttr< + "CacheModifier", "", + [ + I32EnumAttrCase<"NONE", 1, "none">, + I32EnumAttrCase<"CA", 2, "ca">, + I32EnumAttrCase<"CG", 3, "cg">, + I32EnumAttrCase<"WB", 4, "wb">, + I32EnumAttrCase<"CS", 5, "cs">, + I32EnumAttrCase<"WT", 6, "wt">, + I32EnumAttrCase<"CV", 7, "cv">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSemanticAttr : I32EnumAttr< + "MemSemantic", "", + [ + I32EnumAttrCase<"RELAXED", 1, "relaxed">, + I32EnumAttrCase<"ACQUIRE", 2, "acquire">, + I32EnumAttrCase<"RELEASE", 3, "release">, + I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_EvictionPolicyAttr : I32EnumAttr< + "EvictionPolicy", "", + [ + I32EnumAttrCase<"NORMAL", 1, "evict_normal">, + I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">, + I32EnumAttrCase<"EVICT_LAST", 3, "evict_last"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_PaddingOptionAttr : I32EnumAttr< + "PaddingOption", "", + [ + I32EnumAttrCase<"PAD_ZERO", 1, "zero">, + // We can not set the string value to "NAN" because it is a keyword in C++ + I32EnumAttrCase<"PAD_NAN", 2, "nan"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +// atomic +def TT_AtomicRMWAttr : I32EnumAttr< + "RMWOp", "", + [ + I32EnumAttrCase<"AND", 1, "and">, + I32EnumAttrCase<"OR", 2, "or">, + I32EnumAttrCase<"XOR", 3, "xor">, + I32EnumAttrCase<"ADD", 4, "add">, + I32EnumAttrCase<"FADD", 5, "fadd">, + I32EnumAttrCase<"MAX", 6, "max">, + I32EnumAttrCase<"MIN", 7, "min">, + I32EnumAttrCase<"UMAX", 8, "umax">, + I32EnumAttrCase<"UMIN", 9, "umin">, + I32EnumAttrCase<"XCHG", 10, "exch"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSyncScopeAttr : I32EnumAttr< + "MemSyncScope", "", + [ + I32EnumAttrCase<"GPU", 1, "gpu">, + I32EnumAttrCase<"CTA", 2, "cta">, + I32EnumAttrCase<"SYSTEM", 3, "sys">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Program ID dimensions. +def TT_ProgramDim : I32EnumAttr< + "ProgramIDDim", "", + [ + I32EnumAttrCase<"X", 0, "x">, + I32EnumAttrCase<"Y", 1, "y">, + I32EnumAttrCase<"Z", 2, "z">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Rounding mode. +def TT_RoundingModeAttr : I32EnumAttr< + "RoundingMode", "", + [ + I32EnumAttrCase<"RTZ", 0, "rtz">, + I32EnumAttrCase<"RTNE", 1, "rtne">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// PropagateNan. +def TT_PropagateNanAttr : I32EnumAttr< + "PropagateNan", "", + [ + I32EnumAttrCase<"NONE", 0, "none">, + I32EnumAttrCase<"ALL", 0xFFFF, "all">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// InputPrecision +def TT_InputPrecisionAttr : I32EnumAttr< + "InputPrecision", "", + [ + I32EnumAttrCase<"TF32", 0, "tf32">, + I32EnumAttrCase<"TF32x3", 1, "tf32x3">, + I32EnumAttrCase<"IEEE", 2, "ieee"> + ]>{ + let cppNamespace = "::mlir::triton"; +} + +// Type for ScaleDotElemType kind of floats. +def TT_ScaleDotElemTypeAttr : I32EnumAttr< + "ScaleDotElemType", "", + [ + I32EnumAttrCase<"E4M3", 0, "e4m3">, + I32EnumAttrCase<"E5M2", 1, "e5m2">, + I32EnumAttrCase<"E2M3", 2, "e2m3">, + I32EnumAttrCase<"E3M2", 3, "e3m2">, + I32EnumAttrCase<"E2M1", 4, "e2m1">, + I32EnumAttrCase<"BF16", 5, "bf16">, + I32EnumAttrCase<"FP16", 6, "fp16"> + ]>{ + let cppNamespace = "::mlir::triton"; +} + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonDialect.td b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonDialect.td new file mode 100644 index 000000000..a91b7951a --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonDialect.td @@ -0,0 +1,47 @@ +#ifndef TRITON_DIALECT +#define TRITON_DIALECT + +include "mlir/IR/OpBase.td" + +def Triton_Dialect : Dialect { + let name = "tt"; + + let cppNamespace = "::mlir::triton"; + + let summary = "The Triton IR in MLIR"; + + let description = [{ + Triton Dialect. + + Dependent Dialects: + * Arith: + * addf, addi, andi, cmpf, cmpi, divf, fptosi, ... + * Math: + * exp, sin, cos, log, ... + * StructuredControlFlow: + * for, if, while, yield, condition + * ControlFlow: + * br, cond_br + }]; + + let dependentDialects = [ + "arith::ArithDialect", + "math::MathDialect", + "scf::SCFDialect", + "cf::ControlFlowDialect", + "ub::UBDialect" + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; + + let hasConstantMaterializer = 1; + let useDefaultTypePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +include "triton/Dialect/Triton/IR/TritonTypes.td" + + +#endif // TRITON_DIALECT diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonInterfaces.td b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonInterfaces.td new file mode 100644 index 000000000..3d6d2aee9 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonInterfaces.td @@ -0,0 +1,30 @@ +#ifndef TRITON_INTERFACES +#define TRITON_INTERFACES + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" + +def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; +def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">; +def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">; +def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">; +def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">; +def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAndResultShape">; +def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">; +def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">; +def AsyncRegions : NativeOpTrait<"AsyncRegions">; + +// A trait equivalent to InferTypeOpAdaptor, but that checks for structural +// equivalence of the layouts of the result rather than just layout equality. +def InferTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{ + static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) { + if (lhs.size() != rhs.size()) + return false; + return llvm::all_of(llvm::zip(lhs, rhs), [](auto tup) { + auto [lhs, rhs] = tup; + return succeeded(OpTrait::impl::verifyEquivalentType(lhs, rhs)); + }); + } +}]>; + +#endif // TRITON_INTERFACES diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td new file mode 100644 index 000000000..3e01f53e9 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td @@ -0,0 +1,63 @@ +#ifndef TRITON_OP_INTERFACES +#define TRITON_OP_INTERFACES + +include "mlir/IR/OpBase.td" + + +def TransposeOpInterface : OpInterface<"TransposeOpInterface"> { + let description = [{ + This interface is implemented by operations that perform a transpose. + It provides methods to access common properties such as the order attribute + and the source operand. + }]; + + let cppNamespace = "::mlir::triton"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Get the source operand of the transposition.", + /*retType=*/"::mlir::Value", + /*methodName=*/"getSrc", + /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/"Get the order of the transposition.", + /*retType=*/"::mlir::ArrayRef", + /*methodName=*/"getOrder", + /*args=*/(ins)> + ]; + + let verify = [{ + return ::mlir::triton::impl::verifyTransposeOpInterface($_op); + }]; +} + +def DotOpInterface : OpInterface<"DotOpInterface"> { + let description = [{ + This interface is implemented by operations that perform a dot product. + }]; + + let cppNamespace = "::mlir::triton"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Get the LHS A tensor", + /*retType=*/"::mlir::Value", + /*methodName=*/"getA", + /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/"Get the RHS B tensor", + /*retType=*/"::mlir::Value", + /*methodName=*/"getB", + /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/"Verify the dimensions of the A and B DotOp operands.", + /*retType=*/"bool", + /*methodName=*/"verifyDims", + /*args=*/(ins)> + ]; + + let verify = [{ return ::mlir::triton::impl::verifyDotOpInterface($_op); }]; +} + + +#endif // TRITON_OP_INTERFACES diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonOps.td new file mode 100644 index 000000000..9b033fac0 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonOps.td @@ -0,0 +1,1414 @@ +#ifndef TRITON_OPS +#define TRITON_OPS + +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface +include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface +include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" + + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +// +// Op Base +// +class TT_Op traits = []> : + Op { +} + +// +// Cast Ops +// +// Use cast ops in arith: +// bitcast +// fptoui, fptosi, uitofp, sitofp, +// extf, tructf, +// extui, extsi, tructi +def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure]> { + let summary = "Cast int64 to pointer"; + + let arguments = (ins TT_I64Like:$src); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure]> { + let summary = "Cast pointer to int64"; + + let arguments = (ins TT_PtrLike:$src); + + let results = (outs TT_I64Like:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +// arith.bitcast doesn't support pointers +def TT_BitcastOp : TT_Op<"bitcast", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure]> { + let summary = "Cast between types of the same bitwidth"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + // TODO: Add verifier +} + +def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure]> { + let summary = "Floating point casting for custom types"; + + let description = [{ + Floating point casting for custom types (F8), and non-default rounding modes. + + F8 <-> FP16, BF16, FP32, FP64 + }]; + + let arguments = ( + ins TT_FloatLike:$src, + OptionalAttr:$rounding + ); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)"; + + let hasVerifier = 1; + + let hasFolder = 1; +} + +// +// Arithmetic Ops +// + +def TT_ClampFOp : TT_Op<"clampf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Clamp operation for floating point types"; + + let description = [{ + Clamp operation for floating point types. + + The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max]. + }]; + + let arguments = ( + ins + TT_FloatLike:$x, + TT_FloatLike:$min, + TT_FloatLike:$max, + TT_PropagateNanAttr:$propagateNan + ); + + let results = (outs TT_FloatLike:$result); + + // List $propagateNan explicitly rather than relying on attr-dict to pick it + // up, because if it's inside attr-dict, its value will be printed as a + // number rather than as a meaningful string. + let assemblyFormat = "$x `,` $min `,` $max `,` `propagateNan` `=` $propagateNan attr-dict `:` type($result)"; +} + +// +// Math Ops +// + +def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise sqrt for floating point types"; + + let description = [{ + Precise sqrt for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x attr-dict `:` type($x)"; +} + +def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise div for floating point types"; + + let description = [{ + Precise div for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x, TT_FloatLike:$y); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Most significant N bits of the 2N-bit product of two integers"; + + let description = [{ + Most significant N bits of the 2N-bit product of two integers. + }]; + + let arguments = (ins TT_IntLike:$x, TT_IntLike:$y); + + let results = (outs TT_IntLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +// +// Pointer Arith Ops +// +def TT_AddPtrOp : TT_Op<"addptr", + [Pure, + Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)"; + let hasFolder = 1; +} + +def TT_AdvanceOp : TT_Op<"advance", + [Pure, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let summary = "Advance a tensor pointer by offsets"; + + let arguments = (ins TT_TensorPtr:$ptr, Variadic:$offsets); + + let results = (outs TT_TensorPtr:$result); + + let assemblyFormat = "$ptr `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let hasFolder = 1; +} + +// +// Load/Store Ops +// +def TT_LoadOp : TT_Op<"load", [ + SameLoadStoreOperandsAndResultShape, + SameLoadStoreOperandsAndResultEncoding, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 1) || std::equal_to<>()">, + TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Load from a tensor of pointers or from a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + Optional:$mask, + Optional:$other, + + DefaultValuedAttr{}">:$boundaryCheck, + OptionalAttr:$padding, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let results = (outs TT_Type:$result); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor pointer with boundary check and padding + OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask and other + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A utility function to build the operation with all attributes + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, + "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)> + ]; + + // Specify `cacheModifier` and `evictionPolicy` explicitly in the + // assemblyFormat instead of as part of attr-dict so that they get printed + // as strings rather than opaque integers. + // + // Note there's no comma between `other` and `cacheModifier` and between + // `cacheModifier` and `evictionPolicy`. This is due to an apparent + // limitation in the MLIR custom-format parser. In oilist, the initial + // keywords of each clause have to be unique, so they can't be `,`. + // + // Even if we gave up on order-independence and used vanilla optional + // clauses, the format (`,` `foo` `=` $foo^)? (`,` `bar` `=` $bar^)? will + // not match the string ", bar = 0" because after the initial comma (first + // token of the first optional clause) we expect to see "foo". + let assemblyFormat = [{ + $ptr (`,` $mask^)? (`,` $other^)? + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +def TT_StoreOp : TT_Op<"store", [ + SameLoadStoreOperandsShape, + SameLoadStoreOperandsEncoding, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"value type matches ptr type", "ptr", "value", + "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", + "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Store by a tensor of pointers or by a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + TT_Type:$value, + Optional:$mask, + DefaultValuedAttr{}">:$boundaryCheck, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)>, + // A tensor pointer with boundary check + OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef":$boundaryCheck, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)> + ]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between mask, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + let assemblyFormat = [{ + $ptr `,` $value (`,` $mask^)? + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +// +// Atomic Ops +// +def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [ + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"ptr type matches value type", "val", "ptr", + "getPointerTypeSameShape($_self)">, + TypesMatchWith<"mask type matches value type", + "val", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "atomic rmw"; + + let description = [{ + load data at $ptr, do $rmw_op with $val, and store result to $ptr. + + return old value at $ptr + }]; + + let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr, + TT_Type:$val, Optional:$mask, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $atomic_rmw_op, $sem, and $scope rather than relying on + // attr-dict so they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)? attr-dict `:` + functional-type(operands, $result) + }]; +} + +def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding]> { + let summary = "atomic cas"; + + let description = [{ + compare $cmp with data $old at location $ptr, + + if $old == $cmp, store $val to $ptr, + + else store $old to $ptr, + + return $old + }]; + + let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $sem and $scope rather than relying on attr-dict so + // they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $sem `,` $scope `,` $ptr `,` $cmp `,` $val attr-dict `:` + functional-type(operands, $result) + }]; +} + +// +// Shape Manipulation Ops +// +def TT_SplatOp : TT_Op<"splat", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "splat"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; +} + +def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + let summary = "expand_dims"; + + let arguments = (ins TT_Tensor:$src, I32Attr:$axis); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; +} + +def TT_ReshapeOp : TT_Op<"reshape", [Pure, + SameOperandsAndResultElementType]> { + let summary = "reinterpret a tensor to a different shape. It may change elements order if the attribute is set."; + let description = [{ + reinterpret a tensor to a different shape. + + If allow_reorder is set the compiler is free to change the order of + elements to generate more efficient code. + + If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. + The compiler is still free to change it for better performance. + }]; + let builders = [ + OpBuilder<(ins "ArrayRef":$shape, "TypedValue":$src)> + ]; + + let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)"; + let hasCanonicalizeMethod = 1; + let hasFolder = 1; + let hasVerifier = 1; +} + +def TT_BroadcastOp : TT_Op<"broadcast", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "broadcast a tensor"; + + let description = [{ + For a given tensor, broadcast changes one or more dimensions with size 1 + to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot + change the size of a non-1 dimension. + }]; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizer = 1; + let hasFolder = 1; + let hasVerifier = 1; +} + +// cat is not `pure` because it may reorder elements +def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, + SameTypeOperands, + SameOperandsAndResultElementType]> { + let summary = "concatenate 2 tensors"; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_JoinOp : TT_Op<"join", [ + NoMemoryEffect, SameTypeOperands, + InferTypeOpWithLayoutEquivalence, +]> { + let summary = "join two tensors along a new, minor dimension"; + let description = [{ + For example, if the two input tensors are 4x8xf32, returns a tensor of + shape 4x8x2xf32. + + Because Triton tensors always have a power-of-two number of elements, + the two input tensors must have the same shape. + }]; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_SplitOp : TT_Op<"split", [ + NoMemoryEffect, + InferTypeOpWithLayoutEquivalence, + TypesMatchWith<"outLHS and outRHS types match", + "outLHS", "outRHS", "$_self">, +]> { + let summary = "splits a tensor into two, along its last dimension"; + let description = [{ + The input must be a tensor whose last dimension has size 2. Returns two + tensors, src[..., 0] and src[..., 1]. + + For example, if the input shape is 4x8x2xf32, returns two tensors of + shape 4x8xf32. + }]; + + let arguments = (ins TT_Tensor:$src); + let results = (outs TT_Tensor:$outLHS, TT_Tensor:$outRHS); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($outLHS)"; +} + +def TT_TransOp : TT_Op<"trans", [Pure, + TransposeOpInterface, + InferTypeOpWithLayoutEquivalence, + SameOperandsAndResultElementType]> { + + let summary = "rearrange the dimensions of a tensor"; + let description = [{ + For example, given a tensor x with shape [1,2,4], transpose(x) with + order=[2,0,1] rearranges the tensor to have shape [4,1,2]. + + Although this op is called "trans", it implements both tl.trans() and + tl.permute(). ("permute" might be a better name, but it's called "trans" + because originally it only supported 2D tensors.) + + ## Implementation note on encodings: + + In the TritonGPU dialect (and probably others), an encoding is chosen for + this op's output so it's a nop from the perspective of code generation. + + For example, suppose tensor x has an encoding such that GPU thread [i,j,k] + has a register containing element [i,j,k] of the tensor. Now we transpose + x with order [2,1,0], i.e. we reverse the order of its dimensions. In + TritonGPU, we will choose a layout for the output of the transpose so that + GPU thread [i,j,k] has element [k,j,i] of transpose(x). But this is the + same element it had before! All we've done is "rename" the element that + thread [i,j,k] has. + + The "real" transpose -- i.e. moving data between GPU threads -- occurs in + convertLayout ops that appear before and/or after the operation. + + We do this so that you can chain multiple data-movement ops (e.g. + transpose+reshape+concat) without going to shared memory after each one. + }]; + + let arguments = ( + ins TT_Tensor:$src, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// SPMD Ops +// +def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + + let builders = [ + OpBuilder<(ins "int":$axis), [{ + build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); + }]> + ]; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + let builders = [ + OpBuilder<(ins "int":$axis), [{ + build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); + }]> + ]; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +// +// Dot Op +// +def TT_DotOp : TT_Op<"dot", [Pure, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC + when the inputs are f32. It can be one of: tf32, tf32x3, ieee. + tf32: use TC with tf32 ops. + tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp + ieee: don't use TC, implement dot in software. + If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. + }]; + + let arguments = ( + ins + TT_FpIntTensor:$a, + TT_FpIntTensor:$b, + TT_FpIntTensor:$c, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc + ); + + let results = (outs TT_FpIntTensor:$d); + + // attr-dict prints enums as integers. To get inputPrecision printed as a + // string, we need to specify it explicitly. + let assemblyFormat = [{ + $a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? attr-dict `:` + type($a) `*` type($b) `->` type($d) + }]; + let hasVerifier = 1; +} + + +// +// DotScaled Op +// +def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot_scaled"; + + let description = [{ + $d = matrix_multiply(scale($a, $a_scale), scale($b, $b_scale)) + $c. + Where scale(x, s) is a function that applies the scale per block following microscaling spec. + }]; + + let arguments = ( + ins + // inputs are floats if we have a type for them, otherwise (fp4), + // they are packed in pairs in an I8Tensor + RankedTensorOf<[TT_Float,I8]>:$a, + RankedTensorOf<[TT_Float,I8]>:$b, + TT_FloatTensor:$c, + Optional>:$a_scale, + Optional>:$b_scale, + TT_ScaleDotElemTypeAttr:$a_elem_type, + TT_ScaleDotElemTypeAttr:$b_elem_type, + BoolAttr:$fastMath + ); + + let results = (outs TT_FloatTensor:$d); + + let assemblyFormat = [{ + $a (`scale` $a_scale^)? `,` $b (`scale` $b_scale^)? `,` $c + `lhs` `=` $a_elem_type `rhs` `=` $b_elem_type attr-dict + `:` type($a) (`,` type($a_scale)^)? `*` type($b) (`,` type($b_scale)^)? `->` type($d) + }]; +} + +// +// Reduce Op +// +def TT_ReduceOp: TT_Op<"reduce", + [Pure, + SameOperandsShape, + SameOperandsEncoding, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Reduction using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + + // Returns the CombineOp iff this ReduceOp's region contains only + // one CombineOp other than the return, or nullptr if not applicable. + ::mlir::Operation *getSingleCombiner(); + }]; +} + +def TT_ReduceReturnOp: TT_Op<"reduce.return", + [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for reduce operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + +// +// Scan Op +// +def TT_ScanOp: TT_Op<"scan", + [Pure, + SameOperandsAndResultEncoding, + SameOperandsAndResultShape, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Associative scan using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis, BoolAttr:$reverse); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis, "bool":$reverse)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ScanReturnOp: TT_Op<"scan.return", + [HasParent<"ScanOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for scan operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + + +// +// External Elementwise op +// +def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods, + ConditionallySpeculatable]> { + + let description = [{ + call an external function $symbol implemented in $libpath/$libname with $args + return $libpath/$libname:$symbol($args...) + }]; + + let arguments = (ins Variadic:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; + + let extraClassDeclaration = [{ + // Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); + }]; + +} + +// +// Make Range Op +// +def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { + let summary = "make range"; + + let description = [{ + Returns an 1D int32 tensor. + + Values span from $start to $end (exclusive), with step = 1 + }]; + + // WARNING: MLIR generates getStart()/getEnd() functions which return + // uint32_t, even though these arguments are to be interpreted as *signed* + // int32 values. If this matters, use get{Start,End}Attr().getInt(), which + // return int64_t. + let arguments = (ins I32Attr:$start, I32Attr:$end); + + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = "attr-dict `:` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// ElementwiseInlineAsm Op +// +def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [ + Elementwise, + SameOperandsAndResultEncoding, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods +]> { + let summary = "inline assembly applying an elementwise operation to a group of packed elements."; + let description = [{ + Runs an inline asm block to generate one or more tensors. + + The asm block is given `packed_element` elements at a time. Exactly which + elems it receives is unspecified. + }]; + + let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic>:$args); + let results = (outs Variadic:$result); + + let assemblyFormat = [{ + $asm_string attr-dict ($args^ `:` type($args))? `->` type($result) + }]; + + let hasVerifier = 1; +} + +// +// Histogram Op +// +def TT_HistogramOp : TT_Op<"histogram", [Pure]> { + let summary = "return a histogram of the inputs."; + let description = [{ + Return the histogram of the input tensor. The number of bins is equal to + the dimension of the output tensor. Each bins has a width of 1 and bins + start at 0. + }]; + + let arguments = (ins TT_IntTensor:$src); + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = [{ + $src attr-dict `:` type($src) `->` type($result) + }]; +} + +// +// Gather Op +// +def TT_GatherOp : TT_Op<"gather", [Pure, + DeclareOpInterfaceMethods]> { + let summary = "local gather operation"; + let description = [{ + Gather elements from the input tensor using the indices tensor along a + single specified axis. The output tensor has the same shape as the indices + tensor. The input and indices tensors must have the same number of + dimension, and each dimension of the indices tensor that is not the gather + dimension cannot be greater than the corresponding dimension in the input + tensor. + + The `efficient_layout` attribute is set when the compiler has determined an + optimized layout for the operation, indicating that it should not be + changed. + }]; + + let arguments = (ins + TT_Tensor:$src, + TT_IntTensor:$indices, + I32Attr:$axis, + UnitAttr:$efficient_layout + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $src `[` $indices `]` attr-dict `:` + functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +// +// Print Op +// +def TT_PrintOp : TT_Op<"print", [SameVariadicOperandSize, MemoryEffects<[MemWrite]>]> { + let arguments = ( + ins + StrAttr:$prefix, + BoolAttr:$hex, + Variadic>:$args, + DenseI32ArrayAttr:$isSigned + ); + let summary = "Device-side print, as in CUDA for debugging"; + let description = [{ + `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. + format are generated automatically from the arguments. + }]; + let assemblyFormat = [{ + $prefix attr-dict (`:` $args^ `:` type($args))? + }]; +} + +// +// Assert Op +// +def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> { + let summary = "Device-side assert, as in CUDA for correctness checking"; + let description = [{ + `tt.assert` takes a condition tensor and a message string. + If the condition is false, the message is printed, and the program is aborted. + }]; + let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message); + let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; +} + +// +// Make Tensor Pointer Op +// +def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", + [Pure, + SameVariadicOperandSize, + TypesMatchWith<"infer pointer type from the result type", + "result", "base", + "getPointerType(getElementTypeOfTensorPointerType($_self), getAddressSpace($_self))">]> { + let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified"; + + let description = [{ + `tt.make_tensor_ptr` takes both meta information of the parent tensor and the block tensor, then it returns a + pointer to the block tensor, e.g. returns a type of `tt.ptr>`. + }]; + + // TODO(Chenggang): unify the integer types. Currently we cannot do that due to hardware constraints. + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides, + Variadic:$offsets, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorPtr:$result); + + // TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly + // Add additional `[]` to increase readability and split variadic lists + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let builders = [ + OpBuilder<(ins + "Value":$base, + "ValueRange":$shape, + "ValueRange":$strides, + "ValueRange":$offsets, + "ArrayRef":$tensorShape, + "ArrayRef":$order + )> + ]; +} + +// +// Make Tensor Descriptor Op +// +def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [ + Pure, + SameVariadicOperandSize, +]> { + let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size"; + + let description = [{ + `tt.make_tensor_descriptor` takes both meta information of the parent tensor and the block size, + and returns a descriptor object which can be used to load/store from the tensor in global memory. + }]; + + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides + ); + + let results = (outs TT_TensorDescType:$result); + + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)"; + + let builders = [ + OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef":$blockShape)> + ]; + + let extraClassDeclaration = [{ + ArrayRef getTensorShape() { + return getType().getBlockType().getShape(); + } + }]; +} + +def ReinterpretTensorDescOp : TT_Op<"reinterpret_tensor_descriptor", [Pure]> { + let summary = "Reinterpret a pointer as a tensor descriptor"; + + let description = [{ + This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects. + Ideally, we can remove this once the APIs are fully fleshed out. + }]; + + let arguments = (ins TT_Ptr:$rawDesc); + let results = (outs TT_TensorDescType:$result); + + let assemblyFormat = [{ + $rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result)) + }]; +} + +// The following ops, including `call`, `func`, and `return` are copied and modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +// We could revert it back once MLIR has a better inliner interface. +// +// Function Ops +// +def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + The `tt.call` operation represents a direct call to a function that is + within the same symbol scope as the call. The operands and result types of + the call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". + + Example: + + ```mlir + %2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32 + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$callee, + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); + $_state.addTypes(callee.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); + }]>, + OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getCalleeType() { + return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); + } + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", cast(callee)); + } + + // Required by CallOpInterface. + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def FuncOp : TT_Op<"func", [ + AffineScope, AutomaticAllocationScope, CallableOpInterface, + FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface, + // HasParent<"ModuleOp"> +]> { + let summary = "An operation with a name containing a single `SSACFG` region"; + let description = [{ + Operations within the function cannot implicitly capture values defined + outside of the function, i.e. Functions are `IsolatedFromAbove`. All + external references must use function arguments or attributes that establish + a symbolic connection (e.g. symbols referenced by name via a string + attribute like SymbolRefAttr). An external function declaration (used when + referring to a function declared in some other module) has no body. While + the MLIR textual form provides a nice inline syntax for function arguments, + they are internally represented as “block arguments” to the first block in + the region. + + Only dialect attribute names may be specified in the attribute dictionaries + for function arguments, results, or the function itself. + + Example: + + ```mlir + // External function definitions. + tt.func @abort() + tt.func @scribble(i32, i64, memref) -> f64 + + // A function that returns its argument twice: + tt.func @count(%x: i64) -> (i64, i64) + attributes {fruit: "banana"} { + return %x, %x: i64, i64 + } + + // A function with an argument attribute + tt.func @example_fn_arg(%x: i32 {swift.self = unit}) + + // A function with a result attribute + tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) + + // A function with an attribute + tt.func @example_fn_attr() attributes {dialectName.attrName = false} + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the results types that the callable region produces when + /// executed. + ArrayRef getCallableResults() { return getFunctionType().getResults(); } + + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } + }]; + let hasCustomAssemblyFormat = 1; +} + +def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable, */ReturnLike, Terminator]> { + let summary = "Function return operation"; + let description = [{ + The `tt.return` operation represents a return operation within a function. + The operation takes variable number of operands and produces no results. + The operand number and types must match the signature of the function + that contains the operation. + + Example: + + ```mlir + tt.func @foo() : (i32, f8) { + ... + tt.return %0, %1 : i32, f8 + } + ``` + }]; + + let arguments = (ins Variadic:$srcs); + + let builders = [OpBuilder<(ins), [{ + build($_builder, $_state, std::nullopt); + }]>]; + + let assemblyFormat = "attr-dict ($srcs^ `:` type($srcs))?"; + let hasVerifier = 1; +} + + +def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [MemoryEffects<[MemRead]>]> { + let summary = "Load from descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA load operation on targets supporting it. + `desc` is a tensor descriptor object. + The destination tensor type and shape must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. + }]; + let arguments = (ins + TT_TensorDescType:$desc, + Variadic:$indices, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $desc `[` $indices `]` + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` qualified(type($desc)) `->` type($result) + }]; + + let hasVerifier = 1; +} + +def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ + MemoryEffects<[MemRead, MemWrite]>, +]> { + let summary = "store value based on descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA store operation on targets supporting it. + `desc` is a tensor descriptor object. + The shape and types of `src` must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. + }]; + let arguments = (ins + TT_TensorDescType:$desc, + TT_Tensor:$src, + Variadic:$indices + ); + + let assemblyFormat = [{ + $desc `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc)) `,` type($src) + }]; + + let hasVerifier = 1; +} + +def TT_ExperimentalDescriptorGatherOp : TT_Op<"experimental_descriptor_gather", [MemoryEffects<[MemRead]>]> { + let summary = "gather multiple rows from a descriptor into a single tensor"; + let description = [{ + The `tt.experimental_desciptor_gather` op will be lowered to NVIDIA TMA + load operations on targets that support it. + + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The descriptor block must have 1 row and the indices must be a 1D tensor. + Accordingly, the result is a 2D tensor multiple rows. + + This is an escape hatch and is only there for testing/experimenting. This + op will be removed in the future. + }]; + + let arguments = (ins + TT_TensorDescType:$desc, + RankedTensorOf<[I32]>:$x_offsets, + I32:$y_offset + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $desc `[` $x_offsets `,` $y_offset `]` + attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + // TMA gathers have resstrictions on the minimum size of the gather result. + // This function verifies the result type. + static LogicalResult verifyResultType(Operation *op, mlir::ShapedType type); + }]; +} + +def TT_ExperimentalDescriptorScatterOp : TT_Op<"experimental_descriptor_scatter", [ + MemoryEffects<[MemRead, MemWrite]>, +]> { + let arguments = (ins + TT_TensorDescType:$desc, + RankedTensorOf<[I32]>:$x_offsets, + I32:$y_offset, + TT_Tensor:$src + ); + + let assemblyFormat = [{ + $desc `[` $x_offsets `,` $y_offset `]` `,` $src + attr-dict `:` type(operands) + }]; +} + +def TT_ExperimentalTensormapCreateOp: TT_Op< + "experimental_tensormap_create", + [ + MemoryEffects<[MemRead, MemWrite]>, + AttrSizedOperandSegments, + ] +> { + let summary = "Create a new TMA descriptor on device"; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + TT_PtrType:$global_address, + Variadic:$box_dim, + Variadic:$global_dim, + Variadic:$global_stride, + Variadic:$element_stride, + ConfinedAttr]>:$elem_type, + ConfinedAttr]>:$interleave_layout, + ConfinedAttr]>:$swizzle_mode, + ConfinedAttr]>:$fill_mode + ); + let extraClassDeclaration = [{ + int32_t getRank() { + return getBoxDim().size(); + } + }]; + let assemblyFormat = [{ + $desc_ptr `,` $global_address `,` + `[` $box_dim `]` `,` + `[` $global_dim `]` `,` + `[` $global_stride `]` `,` + `[` $element_stride `]` + attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op< + "experimental_tensormap_fenceproxy_acquire", + [MemoryEffects<[MemWrite]>] +> { + let summary = "Acquire fence on a tensormap object"; + let arguments = (ins TT_PtrType:$desc_ptr); + let assemblyFormat = [{ + $desc_ptr attr-dict `:` qualified(type($desc_ptr)) + }]; +} + + +#endif // Triton_OPS diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonTypes.td b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonTypes.td new file mode 100644 index 000000000..a70b97dbc --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -0,0 +1,107 @@ +#ifndef TRITON_TYPES +#define TRITON_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "triton/Dialect/Triton/IR/TritonDialect.td" + +// +// Types +// +class TritonTypeDef traits = []> + : TypeDef { + // Used by printer/parser + let mnemonic = _mnemonic; +} + +// Floating-point Type +def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def TT_FloatTensor : RankedTensorOf<[TT_Float]>; +def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; + +// Boolean Type +// TT_Bool -> I1 +def TT_BoolTensor : RankedTensorOf<[I1]>; +def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>; + +// Integer Type +def I4 : I<4>; +def TT_Int : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">; +def TT_IntTensor : RankedTensorOf<[TT_Int]>; +def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>; + +// I32 Type +// TT_I32 -> I32 +// TT_I32Tensor -> I32Tensor +def TT_I32Like : AnyTypeOf<[I32, I32Tensor]>; + +// I64 Type +// TT_I64 -> I64 +// TT_I64Tensor -> I64Tensor +def TT_I64Like : AnyTypeOf<[I64, I64Tensor]>; + +// Pointer Type in TableGen +class TT_PtrOf pointeeTypes> : + DialectType($_self)">, + Concat<"[](::mlir::Type pointeeType) { return ", + SubstLeaves<"$_self", "pointeeType", AnyTypeOf.predicate>, + "; }(::mlir::cast<::mlir::triton::PointerType>($_self).getPointeeType())">]>, + "ptr", "::mlir::triton::PointerType">; + +// Pointer Type in C++ (corresponding to `TT_PtrOf`) +def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> { + let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system"; + + let description = [{ + Pointer type in Triton IR type system, which could be pointing to scalars or tensors. + }]; + + let parameters = (ins "Type":$pointeeType, "int":$addressSpace); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Type":$pointeeType, + "int":$addressSpace + ), [{ + return $_get(pointeeType.getContext(), pointeeType, addressSpace); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +// Scalar Pointer Type: `ptr<>` +def TT_Ptr : TT_PtrOf<[AnyType]>; + +// Tensor of Pointer Type: `tensor>` +def TT_PtrTensor : RankedTensorOf<[TT_Ptr]>; + +// Tensor of Pointer Type or Pointer type: `tensor>` or `ptr<>` +def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>; + +// Tensor Type +def TT_FpIntTensor : RankedTensorOf<[TT_Float, TT_Int]>; +def TT_Tensor : RankedTensorOf<[TT_Float, TT_Int, TT_Ptr]>; + +// Pointer Type to Tensor Type: `ptr>` +def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>; + +// Any Type in Triton IR +def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>; + +// Result type of ExperimentalMakeTensorDescriptor +def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> { + let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system"; + + let description = [{ + A portable abstraction for nvidia-TMA descriptors. + }]; + + let parameters = (ins "RankedTensorType":$blockType); + let assemblyFormat = "`<` $blockType `>`"; +} + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Types.h b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Types.h new file mode 100644 index 000000000..6bcac9522 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Types.h @@ -0,0 +1,41 @@ +#ifndef TRITON_IR_TYPES_H_ +#define TRITON_IR_TYPES_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.h.inc" + +namespace mlir { + +namespace triton { + +bool isTensorPointerType(Type type); + +bool isTensorOrTensorPointerType(Type type); + +unsigned getPointeeBitWidth(Type type); + +Type getPointeeType(Type type); + +Type getPointerType(Type type, int addressSpace = 1); + +int getAddressSpace(Type type); + +Type getElementTypeOfTensorPointerType(Type type); + +Type getI1SameShape(Type type); + +Type getI32SameShape(Type type); + +Type getPointerTypeSameShape(Type type); + +Type getPointerTypeToElement(Type type); + +} // namespace triton + +} // namespace mlir + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Utility.h b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Utility.h new file mode 100644 index 000000000..48a3e66d9 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/IR/Utility.h @@ -0,0 +1,179 @@ +#ifndef TRITON_IR_UTILITY_H_ +#define TRITON_IR_UTILITY_H_ + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include +#include + +namespace mlir { + +template SmallVector convertType(ArrayRef in) { + SmallVector out; + for (const auto &i : in) + out.push_back(T(i)); + return out; +} + +template +SmallVector convertType(const VecU &in) { + return convertType(ArrayRef(in)); +} + +template Int product(llvm::ArrayRef arr) { + return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{}); +} +template auto product(const VecT &vec) { + return product(llvm::ArrayRef(vec)); +} + +// TODO(jlebar): Rename to ceilOfRatio. +template Int ceil(Int m, Int n) { return (m + n - 1) / n; } + +/// Get the highest power of 2 divisor of an integer. +template T highestPowOf2Divisor(T n) { + // When n is 0 or min, return the highest power of 2. The min case is handled + // separately to avoid underflow when T is a signed integer. Technically + // in that case the correct divisor is -n, but this value is outside the + // range of possible values, so we take the next best alternative. + if (n == 0 || n == std::numeric_limits::min()) { + return (static_cast(1) << (sizeof(T) * 8 - 2)); + } + return (n & (~(n - 1))); +} + +/// Get the next power of 2 for an integer (or the integer itself if it is a +/// power of 2). +template T nextPowOf2(T n) { + if (n == 0) { + return 1; + } + n--; + for (unsigned i = 1; i < sizeof(T) * 8; i <<= 1) { + n |= n >> i; + } + return n + 1; +} + +namespace triton { + +// Many functions here have two overloads, fn(ArrayRef) and fn(const VecT&). +// This is helpful because C++ won't both convert a vector to ArrayRef *and* +// infer the proper type T in one step. So without the second overload, we +// would have to explicitly convert most arguments to ArrayRef at the callsite. + +template +SmallVector applyPermutation(ArrayRef vec, ArrayRef permutation) { + static_assert(std::is_integral_v); + assert(vec.size() == permutation.size()); + + // Check that `permutation` is actually a permutation. +#ifndef NDEBUG + SmallVector sortedPerm(permutation); + llvm::sort(sortedPerm); + for (U i = 0; i < static_cast(sortedPerm.size()); i++) { + assert(sortedPerm[i] == i); + } +#endif + + SmallVector ret; + ret.reserve(vec.size()); + for (const U &i : permutation) { + ret.push_back(vec[i]); + } + return ret; +} + +template +auto applyPermutation(const VecT &vec, const PermT &permutation) { + return applyPermutation(ArrayRef(vec), ArrayRef(permutation)); +} + +template +[[nodiscard]] SmallVector inversePermutation(ArrayRef permutation) { + // Check that `permutation` is actually a permutation. +#ifndef NDEBUG + SmallVector sortedPerm(permutation); + llvm::sort(sortedPerm); + for (int i = 0; i < sortedPerm.size(); ++i) { + assert(sortedPerm[i] == i); + } +#endif + + SmallVector ret(permutation.size()); + for (int i = 0; i < permutation.size(); ++i) { + ret[permutation[i]] = i; + } + return ret; +} + +template +[[nodiscard]] auto inversePermutation(const VecT &permutation) { + return inversePermutation(ArrayRef(permutation)); +} + +template +[[nodiscard]] SmallVector gather(ArrayRef elems, ArrayRef indices) { + SmallVector ret; + ret.reserve(indices.size()); + for (const U &i : indices) { + ret.push_back(elems[i]); + } + return ret; +} + +template +[[nodiscard]] auto gather(const VecT &elems, const IdxT &indices) { + return gather(ArrayRef(elems), ArrayRef(indices)); +} + +// Is `vec` [0, 1, ..., n]? Returns true on empty list. +template bool isIota(ArrayRef vec) { + static_assert(std::is_integral_v); + for (T i = 0; i < vec.size(); ++i) { + if (vec[i] != i) { + return false; + } + } + return true; +} + +template bool isIota(const VecT &vec) { + return isIota(ArrayRef(vec)); +} + +// Is `vals` some permutation of the numbers 0..(vals.size()-1)? +template bool isPermutationOfIota(ArrayRef vals) { + SmallVector sorted(vals); + llvm::sort(sorted); + return isIota(sorted); +} + +template bool isPermutationOfIota(const VecT &vec) { + return isPermutationOfIota(ArrayRef(vec)); +} + +// Is `vec` [i, i+1, ..., i+n]? Returns true on empty list. +template bool isConsecutive(ArrayRef vec) { + static_assert(std::is_integral_v); + for (int i = 1; i < vec.size(); i++) { + if (vec[i] != vec[i - 1] + 1) { + return false; + } + } + return true; +} + +template bool isConsecutive(const VecT &vec) { + return isConsecutive(ArrayRef(vec)); +} + +template auto seq(T start, T end, T step) { + auto len = ceil(end - start, step); + return llvm::map_range(llvm::seq(0, len), + [=](T i) { return start + i * step; }); +} + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 000000000..372a9ec11 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton) +add_public_tablegen_target(TritonTransformsIncGen) diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/Transforms/Passes.h b/third_party/enflame/include/triton/include/triton/Dialect/Triton/Transforms/Passes.h new file mode 100644 index 000000000..29e88fb6d --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/Transforms/Passes.h @@ -0,0 +1,22 @@ +#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { + +std::unique_ptr createCombineOpsPass(); + +std::unique_ptr createReorderBroadcastPass(); +std::unique_ptr createRewriteTensorPointerPass(); +std::unique_ptr createLoopUnrollPass(); + +} // namespace triton + +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/Triton/Transforms/Passes.td b/third_party/enflame/include/triton/include/triton/Dialect/Triton/Transforms/Passes.td new file mode 100644 index 000000000..0433204b5 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/Triton/Transforms/Passes.td @@ -0,0 +1,63 @@ +#ifndef TRITON_PASSES +#define TRITON_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonCombineOps : Pass { + let summary = "combine ops"; + let description = [{ + This pass aims to optimize the five following patterns: + - `dot(a, b, 0) + c => dot(a, b, c)` + + - `addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))` + + - `select(cond, load(ptrs, broadcast(cond), ???), other) => + load(ptrs, broadcast(cond), other)` + + - `broadcast(constant) => reshaped_constant` + - `torch.sum(x[:,:,None].expand(-1,-1,n) * y[None,:,:].expand(m,-1,-1),1) + => dot(x,y,splat(0))` + }]; + + let constructor = "mlir::triton::createCombineOpsPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect"]; +} + +def TritonReorderBroadcast : Pass { + let summary = "Moves broadcast and splat after elementwise operations"; + let description = [{ + The purpose of this pass is to transform: + - `elementwise(broadcast(a)) => broadcast(elementwise(a))` + - `elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))` + In the event of a match, the broadcast (or splat) operation is delayed + and performed after the ElementWise operation. + }]; + let constructor = "mlir::triton::createReorderBroadcastPass()"; + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +def TritonRewriteTensorPointer : Pass { + let summary = "Rewrite load/stores with tensor pointers into legacy load/stores"; + let description = [{ + This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy + semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute + the pointer/mask/other for each load/store. + }]; + + let constructor = "mlir::triton::createRewriteTensorPointerPass()"; + + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +def TritonLoopUnroll : Pass { + let summary = "Loop unroller"; + let description = [{ + The pass unrolls a scf loop with tt.loop_unroll_factor attribute. The attribute specialises how many iterations + the loop should be unrolled. + }]; + let constructor = "mlir::triton::createLoopUnrollPass()"; + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/Attributes.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/Attributes.h new file mode 100644 index 000000000..1f93b3d93 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/Attributes.h @@ -0,0 +1,10 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ +#define TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ + +#include "mlir/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/AttrDefs.h.inc" + +#endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..a211c7bc8 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,26 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttg) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttg) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttg) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttg) +add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonGPUAttrDefsIncGen) + +set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td) +mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(TritonGPUTypeInterfacesIncGen) diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/Dialect.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/Dialect.h new file mode 100644 index 000000000..1b5dd0546 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -0,0 +1,243 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// TritonGPU depends on Triton +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +#include + +// LinearLayoutCache Utils +using CacheKey = std::tuple, mlir::Attribute>; + +namespace llvm { +template size_t hash_value(const std::vector &vec) { + return hash_combine_range(vec.begin(), vec.end()); +} +} // namespace llvm + +namespace std { +template <> struct hash { + size_t operator()(const CacheKey &key) const noexcept { + using llvm::hash_value; + size_t seed = 0; + std::apply( + [&seed](const auto &...elems) { + ((seed = llvm::hash_combine(seed, hash_value(elems))), ...); + }, + key); + return seed; + } +}; +} // namespace std + +namespace mlir::triton::gpu { + +constexpr static char AttrNumWarpsName[] = "ttg.num-warps"; +constexpr static char AttrNumCTAsName[] = "ttg.num-ctas"; +constexpr static char AttrTargetName[] = "ttg.target"; +constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp"; + +// Find the contextual number of warps on which this operation is executed. +int lookupNumWarps(Operation *op); +// Try to find the contextual number of warps on which this operation is +// executed. Returns nullopt if a warp size cannot be find. This is used for +// verifiers. +std::optional maybeLookupNumWarps(Operation *op); + +class LinearLayoutCache { +public: + std::optional get(const CacheKey &key) { + std::shared_lock lock(mutex); + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + return std::nullopt; + } + + void set(CacheKey key, LinearLayout result) { + std::scoped_lock lock(mutex); + cache.emplace(std::move(key), std::move(result)); + } + +private: + std::unordered_map cache; + llvm::sys::SmartRWMutex mutex; +}; +} // namespace mlir::triton::gpu + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" +#include "triton/Dialect/TritonGPU/IR/Ops.h.inc" + +namespace mlir::triton::gpu { +struct SharedMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +// Convert a distributed layout to a linear encoding +LinearEncodingAttr toLinearEncoding(RankedTensorType type); +LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef shape); + +unsigned getTotalElemsPerThread(Type type); + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape); + +SmallVector getElemsPerThread(Type type); + +// Returns the number of threads per warp that may have access to replicated +// elements. If you want non-replicated threads, use +// getThreadsPerWarpWithUniqueData. +SmallVector getThreadsPerWarp(Attribute layout); + +unsigned getWarpSize(Attribute layout); + +// Returns the number of warps per CTA that may have access to replicated +// elements. If you want non-replicated warps, use getWarpsPerCTAWithUniqueData. +SmallVector getWarpsPerCTA(Attribute layout); + +// Returns the number of contiguous elements of the logical tensor that each +// thread has access to, on each dimension of the tensor. For a blocked layout +// with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements +// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1, +// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be +// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4]. +SmallVector getContigPerThread(RankedTensorType tensorType); + +// Returns the number of threads per warp that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17 +// have access to the full tensor, whereas the other threads have access to +// replicated elements, so this function returns [2, 2]. +SmallVector +getThreadsPerWarpWithUniqueData(Attribute layout, + ArrayRef tensorShape); + +// Returns the number of warps per CTA that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2], +// returns [1, 1], since the first warp has access to the full tensor, whereas +// the other warps have access to replicated elements. +SmallVector +getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape); + +// Returns the dimensions of the tensor from minor (fast-varying) to +// major (slow-varying). For distributed layouts, this represents +// the order of the elements within a thread. +// For shared Layout, the order refers to which dimension of the original tensor +// is contiguous in shared memory. +SmallVector getOrder(DistributedEncodingTrait layout, + ArrayRef shape); +SmallVector getOrder(RankedTensorType type); + +SmallVector getOrder(SharedEncodingTrait layout, + ArrayRef shape); +SmallVector getOrder(MemDescType type); +SmallVector getOrder(TensorOrMemDesc type); + +// Order of the elements in the shared memory as defined at layout creation +// If this layout is associated with a MemDesc with a different shape +// it may return a different order than the actual order of the elements +SmallVector getDefaultOrder(SharedEncodingTrait layout); + +// Returns the dimensions along which warpId's are distributed. +// warpsPerCTA only tells the warp layout in the CTA, e.g. warpsPerCTA = [2, 4] +// tells there are 2 warps along dim0 and 4 warps along dim1. +// warpOrder tells the specific order when distributing warp IDs. +// E.g. warpOrder = [0, 1] means the warp IDs are distributed as follows +// [warp0 warp2 warp4 warp6] +// [warp1 warp3 warp5 warp7] +SmallVector getWarpOrder(DistributedEncodingTrait layout, + ArrayRef shape); +SmallVector getWarpOrder(RankedTensorType type); + +// Returns the dimensions along which threadId's are distributed. +// Similar to warpOrder, threadOrder is necessary to tell the specific thread +// distribution in the warp. +SmallVector getThreadOrder(DistributedEncodingTrait layout, + ArrayRef shape); +SmallVector getThreadOrder(RankedTensorType type); + +CTALayoutAttr getCTALayout(Attribute layout); + +SmallVector getCTAsPerCGA(Attribute layout); + +SmallVector getCTASplitNum(Attribute layout); + +SmallVector getCTAOrder(Attribute layout); + +/* The difference between ShapePerCTATile and ShapePerCTA: + * (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp * + * WarpsPerCTA in each dimension and is independent from the tensor shape. + * (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension. + * (3) In the implementation of emitIndices, ShapePerCTATile will + * be replicated or wrapped to fit ShapePerCTA. + */ +// [FIXME LL] Kill this function +SmallVector getShapePerCTATile(RankedTensorType layout); + +// Returns the "logical" shape per CTA +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape); +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape); +SmallVector getShapePerCTA(Type type); + +// Returns the shape per CTA, which is "physically" allocated +// Such shapes may be bigger than the logical one due to, for example, padding +// in shared memory. +SmallVector getAllocationShapePerCTA(Attribute layout, + ArrayRef shape); +SmallVector getAllocationShapePerCTA(Type type); + +unsigned getNumWarpsPerCTA(Attribute layout); + +unsigned getNumCTAs(Attribute layout); + +// Return the order that represents that the batch is in row-major or +// column-major order for a batch of matrices of shape [*, m, n] with +// len(shape) == rank. +SmallVector getMatrixOrder(unsigned rank, bool rowMajor); + +// Return the order that represents that the dot operand is in kContig +// (contiguous in the inner dimension) or it's contiguous on the outer +// dimension. +SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, + bool kContig); + +bool isExpensiveCat(CatOp cat, Attribute targetEncoding); + +// Return true if a view between the two types cannot be implemented as a no-op. +bool isExpensiveView(Type srcType, Type dstType); + +// Return a blocked encoding where the shape is distributed contiguously amongst +// the threads, warps, CTAs with 1 element per threads. +triton::gpu::BlockedEncodingAttr +getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, + int numWarps, int threadsPerWarp, int numCTAs); + +// Dump information about which threads/registers contain each of the tensor +// elements. +void dumpLayout(RankedTensorType tensorType); + +// Dump the layout from HW point of view and prints what tensor element is held +// by each thread and register. +void dumpHWLayout(RankedTensorType tensorType); + +// Return a string representation of the layout of the tensor. +std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView); + +template +llvm::SmallVector expandMatrixShapeWithBatch(llvm::ArrayRef s); + +llvm::SmallVector +expandMatrixOrderWithBatch(llvm::ArrayRef o); +} // namespace mlir::triton::gpu + +#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h new file mode 100644 index 000000000..ab689723f --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -0,0 +1,288 @@ +// Conversions from TritonGPU layouts (e.g. BlockedEncodingAttr) to +// LinearLayout. + +#ifndef TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H +#define TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H + +#include + +#include "triton/Tools/LinearLayout.h" + +namespace mlir::triton { +enum class ScaleDotElemType : uint32_t; +} // namespace mlir::triton + +namespace mlir::triton::gpu { +class SwizzledSharedEncodingAttr; +class NVMMASharedEncodingAttr; +class AMDMfmaEncodingAttr; + +// - BlockedEncodingAttrs have the following input dimensions. +// +// "register": elements in one thread +// "lane": threads in a warp +// "warp": warps in a block/CTA +// "block": blocks in a cluster +// +// - An n-dimensional SwizzledSharedEncodingAttr has the following input +// dimensions. +// +// "offset": the n'th element in the allocation, within a particular thread +// block (i.e. within a CTA). The offset is measured in elements, not +// bytes. +// "block": blocks in a cluster +// +// All layouts have the following output dimensions. +// +// "dimi" for i in 0..n-1: the location in the n'th logical dimension of the +// output tensor. These also are not reordered according to the layout's +// `order`. +// +// You can flatten the input or output dimensions into a single dimension using +// LinearLayout::flattenIns/Outs(). +// +// elemBitWidth is the bit width of one element in the layout. This is required +// to compute the linear layout for MMAv3 (i.e. Hopper) shared layouts (i.e. +// shared layouts with nvmma_shared layout) but is otherwise unused. +// +// Returns std::nullopt if the given layout can't be converted to an LL. +LinearLayout toLinearLayout(ArrayRef shape, Attribute layout); + +// Convert the shared encoding of a tensor with `nvmma_shared` layout to a +// LinearLayout that maps from a linear shared memory offset to tensor index. +// +// If `disableSwizzle` is set, then the resulting layout does not include +// swizzling. +LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, + NVMMASharedEncodingAttr shared, + bool disableSwizzle = false); + +// Given a linear layout where the input dimensions contain a "block" dimension, +// this method sets the "block" dimension to 0 and removes the corresponding +// output dimensions. +// +// Note that this behavior differs from calling +// `LinearLayout::sublayout(inDimNames, outDimNames)` when "block" is not in +// `inDimNames`. The latter does not modify the output sizes. +LinearLayout getLayoutWithinBlock(const LinearLayout &layout); + +// In this function, we construct a linear layout representing the +// -> mapping +// for entire `src` and `dst` tensors. We determine the shape of the +// intermediate shared memory buffer needed for a register-to-register +// conversion using the maximum size accessed in each dimension from `src`'s +// layout and `dst`'s layout. See the getRepShapeForCvt function in +// Allocation.cpp for details. Note that the buffer might be smaller than the +// tensor being converted, so we need multiple "iterations" to move a subregion +// of the `src` tensor to the corresponding subregion of the `dst` tensor. The +// pesudo code of layout conversion is as follows: +// +// for iter in 0..numIterations: +// sync threads +// for vecIdx in [0..numRegisters/storeVec]: +// registers <- get registers used in iter +// offsets <- get offsets using the intermediate linear layout +// store registers[vecIdx * storeVec, (vecIdx + 1) * storeVec)] to shared +// memory +// sync threads +// for vecIdx in [0..numRegisters/loadVec]: +// registers <- get registers used in iter +// offsets <- get offsets using the intermediate linear layout +// load registers[vecIdx * loadVec, (vecIdx + 1) * loadVec)] from shared +// memory +LinearLayout chooseShemLayoutForRegToRegConversion( + MLIRContext *ctx, ArrayRef tensorShape, + ArrayRef repShape, ArrayRef order); + +// This function constructs a linear layout that maps +// to . +// The primary goal is to efficiently store 2D tiles of a tensor into shared +// memory using the `stmatrix` instruction, with each thread responsible for +// storing `N` elements. If `stmatrix` cannot be used for the given tensor +// encoding, this function returns `std::nullopt`. +// +// Unlike standard vectorized stores, such as `st.shared.v4 [%offset], +// %vec_reg`, where `%vec_reg` contains four consecutive data elements, the +// `stmatrix` instruction allows `N` registers to point to non-contiguous +// locations within a tensor tile. +// +// For instance, the `stmatrix [%offset], %mat_reg` instruction on NVIDIA GPUs +// enables `%mat_reg` to store `N` elements that do not need to be consecutive. +// However, it is crucial that the address (`%offset`) of each row in a tensor +// tile should be aligned to `N` * `elemBitWidth`. The `%offset` of each thread +// is calculated based on the provided tensor encoding. +// +// Currently, we support only the NVIDIA MMAv3 encoding and the `stmatrix.x4` +// instruction. Each `stmatrix.x4` instruction stores eight 16-bit elements per +// thread, resulting in a total of 8 * 32 = 256 elements per warp, or 16 * 16 +// elements per warp when distributed across four 8x8 tiles. Each thread's +// `%offset` points to an address aligned with 8 * 16 bits, denoting a row in +// the 8x8 tile. The values in `%mat_reg` are non-consecutive elements, +// composed of 4 pairs of consecutive elements. These matrix addresses are +// distributed as follows: +// +// col[0-7] col[8-15] +// row[0-7] lane[0-7] lane[16-23] +// row[8-15] lane[8-15] lane[24-31] +// +// The matrix elements of thread 0 are distributed in the following pattern: +// +// col0 col8 +// row0 reg[0-1] reg[4-5] +// row8 reg[2-3] reg[6-7] +// +// When `swizzleByteSize` is non-zero, the layout is constructed +// differently due to leading dimension offset and swizzling. +// There are two key concepts to understand: +// +// 1. Chunks: The leading dimension (i.e., the column dimension) is divided +// into chunks, where each chunk's size is determined by `swizzleByteSize`. +// 2. Swizzling within tiles: Each tile applies a swizzling pattern to its +// rows to optimize memory access. +// +// - Concept 1: Chunks +// +// In the swizzled layout, the leading dimension is strided by +// `swizzleByteSize`. This introduces the concept of a "chunk", where each chunk +// spans a certain number of columns. +// +// For a tile size of `stmatrix.x4` (16x16 elements), with each element being 16 +// bits (2 bytes), each tile occupies 16 rows and 32 bytes per row (since 16 +// elements * 2 bytes per element = 32 bytes per row). +// +// Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk can be +// calculated as: +// +// Number of tiles per chunk = swizzleByteSize / (bytes per row) = 128 bytes / +// 32 bytes = 4 tiles +// +// Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns +// (since each tile is 16 columns): +// +// col0-15 col16-31 col32-47 col48-63 +// row0-15 tile0 tile1 tile2 tile3 +// +// For a tensor of size 128x128 elements (#rows x #columns), and each element +// being 16 bits, the tensor can be divided into multiple chunks both +// horizontally and vertically. Chunks are stored in memory in a "column-major" +// order based on chunks, meaning chunk1's address follows chunk0's. +// +// Assuming we have 8 warps, and we assign each warp to process a chunk of 16 +// rows (rows per tile) and 128 columns (the width of two chunks). This results +// in each warp handling one horizontal slice of the tensor. +// +// The overall layout can be visualized as: +// +// |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->| +// columns 0-63 columns 64-127 +// warp0 | rows 0-15 chunk0 chunk8 +// warp1 | rows 16-31 chunk1 chunk9 +// warp2 | rows 32-47 chunk2 chunk10 +// warp3 | rows 48-63 chunk3 chunk11 +// warp4 | rows 64-79 chunk4 chunk12 +// warp5 | rows 80-95 chunk5 chunk13 +// warp6 | rows 96-111 chunk6 chunk14 +// warp7 | rows 112-127 chunk7 chunk15 +// +// - Concept 2: Swizzling within tiles +// +// Within each 16x16 tile, rows are swizzled to optimize memory access patterns. +// This swizzling is similar to what's defined in `TritonGPUAttrDefs.td`. at the +// level of each 16x16 tile rather than the entire tensor. +// +// Key parameters for swizzling: +// +// - `perPhase`: The number of rows over which to apply a XOR operation at +// each phase. +// - `maxPhase`: The total number of phases. +// - `vectorWidth`: The number of elements per vector, which is 8 in this case +// because `stmatrix` stores 8 contiguous elements per thread. +// +// The offset of each element within a tile is calculated using the formula: +// +// offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) % +// maxPhase)) * elementSize +// +// where `elementSize` is the size of each element in bytes (2 bytes for 16-bit +// elements). +// +// For example, consider the element at index `(row=1, col=0)` in chunk0: +// +// Without swizzling: +// +// offset = row * swizzleByteSize + col * elementSize +// = 1 * 128 bytes + 0 * 2 bytes +// = 128 bytes +// +// With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`): +// +// offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) % +// maxPhase)) * elementSize +// = 1 * 128 bytes + (8 * ((1 / 1) % 8)) * 2 bytes +// = 128 bytes + (8 * (1 % 8)) * 2 bytes +// = 128 bytes + 8 * 2 bytes +// = 128 bytes + 16 bytes +// = 144 bytes +// +// This swizzling ensures that elements are stored in a way that optimizes for +// memory bandwidth and reduces bank conflicts. +// +// - Verification through Linear Layout +// +// We can verify the offsets with the following outputs of the corresponding +// linear layout, where each element is 16 bits (2 bytes): +// +// - register=1 -> offset=1 +// register=2 -> offset=2 +// register=4 -> offset=4 +// register=8 -> offset=16 +// register=16 -> offset=32 +// register=32 -> offset=8192 +// - lane=1 -> offset=72 +// lane=2 -> offset=144 +// lane=4 -> offset=288 +// lane=8 -> offset=512 +// lane=16 -> offset=8 +// - warp=1 -> offset=1024 +// warp=2 -> offset=2048 +// warp=4 -> offset=4096 +// +// For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in +// `warp=0`, the offset is calculated as 72 * 2 bytes = 144 bytes. The result +// matches our earlier calculation. +// +// TODO(Keren): We should replace tensorTy with a LinearLayout and the element +// bit width of the tensor in the future to support more flexible tensor +// encodings +LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, + int swizzleByteSize); + +// The primary goal of this function is to efficiently store 2D tiles of a +// tensor into shared memory using the `ldmatrix` instruction. +LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef shape, + bool needTrans, int32_t elemBitWidth); + +// The primary goal of this function is to efficiently load 2D tiles of a +// tensor from shared memory using the `ds_read_tr` instruction for AMD GPUs. +LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef shape, + int32_t elemBitWidth); + +// Create LinearLayout for mxfp4 and mxfp8 operand in scaled mfma. +// For mxfp4, we use dot layout directly. Mxfp8 is not covered by dot +// layout, so we need to manually create linear layout for it. +LinearLayout +chooseScaledMfmaOperandLayout(AMDMfmaEncodingAttr mfmaEnc, int kWidth, + int dotOperandIdx, ScaleDotElemType elemType, + llvm::ArrayRef dotOperandShape); + +LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType, + int numWarps); + +// Create LinearLayout for scale in scaled mfma. +LinearLayout chooseScaledMfmaScaleLayout( + MLIRContext *ctx, int dotOperandIdx, + const std::vector> &dotOperandWarpBasis, + ArrayRef dotOperandShape, unsigned mfmaMDim); +} // namespace mlir::triton::gpu + +#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td new file mode 100644 index 000000000..7e3e011e7 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -0,0 +1,1324 @@ +#ifndef TRITONGPU_ATTRDEFS +#define TRITONGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" + +//===----------------------------------------------------------------------===// +// TritonGPU Attribute Definitions +//===----------------------------------------------------------------------===// +def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let methods = [ + ]; +} + +class TritonGPU_Attr traits = [], + Dialect dialect = TritonGPU_Dialect, + string baseCppClass = "::mlir::Attribute"> + : AttrDef { + + let description = [{ +TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines +how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function +\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding +to the indices of the CUDA threads allowed to access some data at index $i$. + +For example, let us consider the layout function: +\mathcal{L}(0, 0) = {0, 4} +\mathcal{L}(0, 1) = {1, 5} +\mathcal{L}(1, 0) = {2, 6} +\mathcal{L}(1, 1) = {3, 7} + +Then, attaching $\mathcal{L} to a tensor $T$ would mean that: +- T[0,0] is owned by both cuda thread 0 and 4 +- T[0,1] is owned by both cuda thread 1 and 5 +- T[1,0] is owned by both cuda thread 2 and 6 +- T[1,1] is owned by both cuda thread 3 and 7 + +Right now, Triton implements two main classes of layouts: shared, and distributed. + }]; + let attrName = "triton.gpu." # attrMnemonic; + + code extraBaseClassDeclaration = [{ + }]; +} + +//===----------------------------------------------------------------------===// +// CTA Layout +//===----------------------------------------------------------------------===// + +def CTALayoutAttr : TritonGPU_Attr<"CTALayout", "cta_layout"> { + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$CTAsPerCGA, + ArrayRefParameter<"unsigned">:$CTASplitNum, + ArrayRefParameter<"unsigned">:$CTAOrder + ); + + let description = [{ +Describes how blocks are distributed among the cooperate thread arrays (aka +CTAs, aka thread blocks) in a cooperate thread group (aka CTG, aka thread group +cluster). CGAs were introduced in Hopper (sm90). + +The tensor is divided up into CTASplitNum pieces, which are distributed among +the CTAsPerCGA thread blocks. Each CTA processes a subtensor of shape +`tensor_shape / CTASplitNum`. + +Example 0: The tensor shape is [64, 128] and, there are two CTAs, each +processing half the tensor [64, 64]. Then CTAsPerCGA = [1, 2] and +CTASplitNum = [1, 2]. + +Example 1: The tensor shape is [64, 128] and, there are two CTAs, both +processing the complete tensor [64, 128]. This happens when multicast is +enabled. In this case, CTAsPerCTA = [1, 2] but CTASplitNum = [1, 1]. + +Example 2: Consider a matmul AxB=C, where A=[M,K], B=[K,N], C=[M,N]. The +CTAsPerCGA for A, B, C are the same, [SplitM, SplitN], but the CTASplitNum are +different. CTASplitNum_A = [SplitM, 1], which means multicast on dim1, +CTASplitNum_B = [1, SplitN], which means multicast on dim0, CTASplitNum_C = +[SplitM, SplitN] which means no multicast. + +Currently programs with multiple CTAs per CGA are an experimental feature in +Triton, not enabled by default. + +You can leave off the CTALayout properties in the textual IR and Triton will +fill in the "default" CTALayout of CTAsPerCGA = CTASplitNum = [1...1]. In +addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to +[n-1,...,0] (it doesn't matter in this case). + }]; + + // CTALayout::get canonicalizes CTAOrder to [n,n-1,...,0] if CTAsPerCGA is + // [1...1]. The CTAOrder doesn't matter in this case. + // + // This is a little weird because if you write textual IR with a one order and + // then print it back out, you might get a different order. But it seems this + // is the best way to canonicalize an attribute in MLIR. + let builders = [ + AttrBuilder<(ins "ArrayRef":$CTAsPerCGA, + "ArrayRef":$CTASplitNum, + "ArrayRef":$CTAOrder), [{ + if (llvm::all_of(CTAsPerCGA, [](unsigned x) { return x == 1; })) { + SmallVector order; + for (int i = CTAsPerCGA.size() - 1; i >= 0; --i) + order.push_back(i); + return $_get(context, CTAsPerCGA, CTASplitNum, order); + } + return $_get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + }]>, + ]; + + let extraClassDeclaration = [{ + static CTALayoutAttr getDefault(MLIRContext *context, int rank) { + SmallVector CTAsPerCGA(rank, 1); + SmallVector CTASplitNum(rank, 1); + SmallVector CTAOrder; + for (int i = rank - 1; i >= 0; --i) + CTAOrder.push_back(i); + return get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + } + }]; + + let genVerifyDecl = 1; + let skipDefaultBuilders = 1; +} + + +def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + let description = [{ + Common trait for all TTGIR layouts. + }]; + let methods = [ + InterfaceMethod<"Get the shape of the CTAs per CGA.", + "SmallVector", + "getCTAsPerCGA">, + InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first", + "SmallVector", + "getCTAOrder">, + InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.", + "SmallVector", + "getCTASplitNum">, + ]; +} + +//===----------------------------------------------------------------------===// +// Shared Layout Encoding +//===----------------------------------------------------------------------===// + +def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let description = [{ + Common trait describing shared memory. + }]; + let methods = [ + InterfaceMethod<"Return the default alignment for the layout.", + "int32_t", + "getAlignment">, + ]; +} + +def SwizzledSharedEncodingAttr : + TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> { + let mnemonic = "swizzled_shared"; + + let description = [{ +An encoding for tensors whose elements may be simultaneously accessed by +different cuda threads in the programs, via shared memory. In other words, +for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}. + +In order to avoid shared memory bank conflicts, elements may be swizzled. +Here are some examples. In all cases, the input tensor is [0, 1, ..., n-1]. + +1. Basic swizzling + + #shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // xor with 0 + [ 5, 4, 7, 6], // xor with 1 + [10, 11, 8, 9], // xor with 2 + [15, 14, 13, 12] // xor with 3 + +Here elements of row r are xor'ed with r (or more properly, in[r][c] -> +out[r][c^r]). + +2. Multiple rows per phase + + #shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14] + +Elements of row r are xor'ed with r/2. In other words, perPhase=2 +means that pairs of 2 rows get the same swizzling. + +3. Max-phase applied + + $shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 5, 4, 7, 6], // phase 1 (xor with 1) + [ 8, 9, 10, 11], // phase 0 + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // ... + [21, 20, 23, 22], + [24, 25, 26, 27], + [29, 28, 31, 30] + +Elements of row r are xor'ed with (r/2) % 2. In other words, maxPhase=m has the +effect of limiting the maximum value of the xor to m-1. + +4. Max-phase and per-phase + + #shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], // phase 0 + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // phase 0 + [20, 21, 22, 23], // phase 0 + [25, 24, 27, 26], // phase 1 + [29, 28, 31, 30]] // phase 1 + +Here the xor value (the "phase", I guess?) changes every perPhase rows, up to a +maximum value of maxPhase-1. In other words, elements of row r are xor'ed with +(r/2) % 2. + +5. Adding vec + + #shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3, 4, 5, 6, 7], + [10, 11, 8, 9, 14, 15, 12, 13], + [20, 21, 22, 23, 16, 17, 18, 19], + [30, 31, 28, 29, 26, 27, 24, 25] + +When vec=2, elements are swizzled in pairs of 2. In other words, the element at +(r,c) has value + + ((c / 2) ^ r) * 2 + (c % 2). + }]; + + // swizzle info: vec, perPhase, maxPhase + // order: the fastest-changing axis first + let parameters = ( + ins + "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + ArrayRefParameter<"unsigned">:$order, + "CTALayoutAttr":$CTALayout + ); + + let builders = [ + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit), [{ + bool needTrans = false; // default value + return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans); + }]>, + + // TODO(jlebar): This should not be an overload of + // SwizzledSharedEncodingAttr::get(). It's misleading, because it does a bunch of + // nontrivial work based on the given dotOpEnc. + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit, + "bool":$needTrans), [{ + + // ---- begin MFMA ---- + if (auto mfmaEnc = mlir::dyn_cast(dotOpEnc.getParent())) { + int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0; + if (needTrans) + kDimNum = 1 - kDimNum; + bool isKDimInner = (order[0] == kDimNum); + + // GFX950 supports LDS transpose load instructions, so we need + // swizzling even when K dimension is not innermost. + const int versionMajor = mfmaEnc.getVersionMajor(); + bool isGFX950 = versionMajor == 4; + bool swizzleNonKDimInner = isGFX950 && (typeWidthInBit == 8 || typeWidthInBit == 16); + if (isKDimInner || swizzleNonKDimInner) { + const int numBanks = isGFX950 ? 64 : 32; + const int bankBitWidth = 32; + const int SIMDWidth = 16; + + // number of inner dimension rows per one pattern repeat + int innerDimLength = shape[order[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + // vecSize is set to kWidth of the dotop layout + int vecSize = dotOpEnc.getKWidth(); + int maxPhase = std::max(std::min(SIMDWidth / perPhase, innerDimLength / vecSize), 1); + + // TODO (zhanglx): figure out better parameters for mfma4 + if (mfmaEnc.getMDim() == 4) + maxPhase = 4; + + return get(context, vecSize, perPhase, maxPhase, order, CTALayout); + } else { + // Do not swizzle in case k dimension is not innermost. + // In this case accesses will go in different banks even without swizzling. + return get(context, 1, 1, 1, order, CTALayout); + } + } + + // ---- begin WMMA ---- + if (mlir::isa(dotOpEnc.getParent())) { + if (dotOpEnc.getOpIdx() == 0) { + const int numBanks = 32; + const int bankBitWidth = 32; + + // number of inner dimension rows per one pattern repeat + int innerDimLength = shape[order[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit; + int maxPhase = 16 / perPhase; + + return get(context, vecSize, perPhase, maxPhase, order, CTALayout); + } else { + // Do not swizzle in case k dimension is not innermost. + // In this case accesses will go in different banks even without swizzling. + return get(context, 1, 1, 1, order, CTALayout); + } + } + + + auto mmaEnc = mlir::dyn_cast(dotOpEnc.getParent()); + + if(!mmaEnc) + return get(context, 1, 1, 1, order, CTALayout); + + int opIdx = dotOpEnc.getOpIdx(); + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + // number of rows per phase + + // index of the inner dimension in `order` + unsigned inner = (opIdx == 0) ? 0 : 1; + + // ---- begin Ampere & Hopper ---- + if (mmaEnc.isAmpere() || mmaEnc.isHopper()) { + int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()); + perPhase = std::max(perPhase, 1); + std::vector matShape = {8, 8, 4 * dotOpEnc.getKWidth()}; + int vecWidth = 32 / typeWidthInBit; + if (vecWidth != dotOpEnc.getKWidth() && order[0] == inner) { + perPhase = std::max(perPhase, 2 * vecWidth); + } + int rank = order.size(); + // --- handle A operand --- + if (opIdx == 0) { // compute swizzling for A operand + int m = (needTrans) ? matShape[2] : matShape[0]; + int k = (needTrans) ? matShape[0] : matShape[2]; + int vec = (order[0] == rank-1) ? k : m; + int mmaStride = (order[0] == rank-1) ? m : k; + int maxPhase = std::max(mmaStride / perPhase, 1); + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + // --- handle B operand --- + if (opIdx == 1) { + // we compute vec and maxPhase m, n and k size of the mma + // instruction. when matmul operands is transposed, we should + // consider that to get m, n and k. + int n = needTrans ? matShape[2] : matShape[1]; + int k = needTrans ? matShape[1] : matShape[2]; + int vec = (order[0] == rank-1) ? n : k; + int mmaStride = (order[0] == rank-1) ? k : n; + int maxPhase = std::max(mmaStride / perPhase, 1); + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + llvm_unreachable("invalid operand index"); + } + + // ---- not implemented ---- + llvm_unreachable("unsupported swizzling for provided MMA version"); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy, + "bool":$needTrans), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans); + }]>, + ]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + int32_t getAlignment() const; + SmallVector getCTAsPerCGA() const; + SmallVector getCTAOrder() const; + SmallVector getCTASplitNum() const; + }]; + let hasCustomAssemblyFormat = 1; +} + +def NVMMASharedEncodingAttr : + TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> { + let mnemonic = "nvmma_shared"; + + let description = [{ + Represent blocked shared memory matching MMAv3/MMAv5 shared memory input. + This is meant to represent 2d tiled blocked layout. + The full layout representation is described here: + https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-shared-memory-layout + }]; + + + // fp4Padded: Indicates that this encoding represents a mixed-precision fp4 operand in MMAv5 scaled dot, which needs + // to be in the special padded layout as described in https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory + let parameters = ( + ins + "unsigned":$swizzlingByteWidth, + "bool":$transposed, + "unsigned":$elementBitWidth, + "bool":$fp4Padded, + "CTALayoutAttr":$CTALayout + ); + + let builders = [ + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy, + "bool": $fp4Padded), [{ + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + int32_t swizzlingByteWidth = 0; + unsigned eleBitWidth = eltTy.getIntOrFloatBitWidth(); + int packingFactor = fp4Padded ? 2 : 1; + + // get proper shared memory swizzling mode from the contiguous dimension + // size of the origin blocked layout. + auto contigDimSizeInByte = shapePerCTA[order[0]] * packingFactor * eleBitWidth / 8; + if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) { + swizzlingByteWidth = 128; + } else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) { + swizzlingByteWidth = 64; + } else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) { + swizzlingByteWidth = 32; + } else { + llvm_unreachable("unsupported shared memory layout for MMAv3"); + } + bool transposed = order[0] == 0; + return $_get(context, swizzlingByteWidth, transposed, eleBitWidth, fp4Padded, CTALayout); + }]> + ]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + int32_t getAlignment() const; + SmallVector getCTAsPerCGA() const; + SmallVector getCTAOrder() const; + SmallVector getCTASplitNum() const; + SmallVector getOrder() const { + return getTransposed() ? SmallVector({0, 1}) : SmallVector({1, 0}); + } + }]; + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// Distributed Layout Encoding +//===----------------------------------------------------------------------===// +def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let description = [{ +The Distributed encoding describes the layout L with the 4-level compute hierarchy on GPU. +It is abstracted from the top to the bottom as CTAs Per CGA->Warps Per CTA->Threads Per Warp->Values Per Thread. + +For CTAs Per CGA and Warps Per CTA level, the linear id is distributed contiguously with the shape and order. +For example, for a shape/order pair defines a distribution layout +shape = [4, 4] +order = [0, 1] // The fastest-changing axis first +-> +layout = [0 4 8 12] + [1 5 9 13] + [2 6 10 14] + [3 7 11 15] + +For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding. + +If the layout does not completely cover the tensor, we tile it until we cover the entire tensor. +We call each individual tile "rep". + }]; + + let methods = [ + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrder">, + InterfaceMethod<"Return total element size per thread.", + "unsigned", + "getTotalElemsPerThread", + (ins "ArrayRef":$shape), + /*defaultImplementation=*/[{ + return toLinearEncoding($_self, shape).getTotalElemsPerThread(shape); + }]>, + InterfaceMethod<"Return element size per thread in each dimension.", + "SmallVector", + "getElemsPerThread", + (ins "ArrayRef":$shape), + /*defaultImplementation=*/[{ + return toLinearEncoding($_self, shape).getElemsPerThread(shape); + }]>, + // Interface for the meta information about the multiple thread hierarchy. + InterfaceMethod<"Get the shape of the warps per CTA.", + "SmallVector", + "getWarpsPerCTA">, + + + InterfaceMethod<"Get the shape of the threads per warp", + "SmallVector", + "getThreadsPerWarp">, + InterfaceMethod<"Convert to LinearLayout.", + "LinearLayout", + "toLinearLayout", + (ins "ArrayRef":$shape)>, + + // Legacy methods: They do not take into account the shape of the tensor + // that is, the fact that we use them to tile the tensor. + InterfaceMethod<"Get the default order of the registers per warp. The fastest-changing axis first", + "SmallVector", + "getDefaultOrder">, + + InterfaceMethod<"Get the default order of the threads per warp. The fastest-changing axis first", + "SmallVector", + "getDefaultThreadOrder">, + + InterfaceMethod<"Get the default order of the warps per CTA. The fastest-changing axis first", + "SmallVector", + "getDefaultWarpOrder"> + + ]; +} + +class DistributedEncoding traits = [], + Dialect dialect = TritonGPU_Dialect> + : TritonGPU_Attr { + + let description = [{ +Distributed encodings have a layout function L that is entirely characterized +by a d-dimensional tensor T. Note that L doesn't need to have the same shape +(or even the same rank) as the tensor it is encoding. + +The layout function \mathcal{L} of this layout is then defined, for an +index `i` \in Z^d, as follows: + +\mathcal{L}(T)[i_d] = L[(i_d + k_d*T.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*T.shape[d] < L.shape[d] + +Intuitively, when the tensor dim size T.shape[d] is larger than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "wrapped around" manner, with +each thread owning multiple values. + +OTOH, when the tensor dim size T.shape[d] is smaller than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "broadcasted" manner, with +each value owned by multiple threads. + +For example, for a tensor/layout pair +T = [x x x x x x x x] + [x x x x x x x x] +L = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] + +Then the data of T would be distributed as follow between the 16 CUDA threads: +L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, + {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ] + }]; + + code extraDistributedDeclaration = extraBaseClassDeclaration # [{ + // Implemented in subclasses + SmallVector getRepOrder() const; + SmallVector getCTAsPerCGA() const; + SmallVector getCTAOrder() const; + SmallVector getCTASplitNum() const; + SmallVector getWarpsPerCTA() const; + SmallVector getThreadsPerWarp() const; + + LinearLayout toLinearLayout(ArrayRef shape) const; + + // Legacy methods: They do not take into account the shape of the tensor + SmallVector getDefaultWarpOrder() const; + SmallVector getDefaultThreadOrder() const; + SmallVector getDefaultOrder() const; + }]; +} + +//===----------------------------------------------------------------------===// +// Linear Layout Encoding +//===----------------------------------------------------------------------===// + +def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout", + "linear layout"> { + let cppAccessorType = "const LinearLayout &"; +} + +def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> { + let mnemonic = "linear"; + + let description = [{ + See the docs in LinearLayout.h for the definition of linear layouts. + }]; + + let parameters = (ins LinearLayoutParam:$linearLayout); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + // Generic distributed encoding methods + unsigned getTotalElemsPerThread(ArrayRef shape) const; + SmallVector getElemsPerThread(ArrayRef shape) const; + + SmallVector getContig(const char *, SmallVector) const; + SmallVector getContigPerThread() const; + SmallVector getContigPerWarp() const; + SmallVector getOrder() const; + SmallVector getWarpOrder() const; + SmallVector getThreadOrder() const; + + // Generalizes get{Warp,Thread,CTA}Order to linear layouts. + // Returns the order of the dimensions `dimName` of the layout. + // If more than dimension is of size one, it uses defaultOrder to determine + // the order of the dimensions of size one. + SmallVector orderPerDim(StringAttr dimName, + ArrayRef defaultOrder) const; + + // Generalizes getThreadsPerWarp, getWarpsPerCTA, getCTAsPerCGA to linear layouts. + // Returns the bases of the dimensions `dimName` of the layout. + // If skipBroadcast is false, we count a base zero + SmallVector basesPerDim(StringAttr dimName, + bool skipBroadcast = true) const; + + // [FIXME LL] Supports legacy behaviour. We should remove these functions + SmallVector getShapePerCTATile() const; + SmallVector getSizePerThread() const; + }]; + + let genVerifyDecl = 1; + // Example of assembly format: + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + let hasCustomAssemblyFormat = 1; +} + + +//===----------------------------------------------------------------------===// +// Blocked Layout Encoding +//===----------------------------------------------------------------------===// + +def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding", "blocked_encoding"> { + let mnemonic = "blocked"; + + let description = [{ +An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout +used to promote memory coalescing in LoadInst and StoreInst. +It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which +specify the amount of elements owned by each CUDA thread, warp and CTA respectively. + +Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +for + +#ttg.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} + CTASplitNum = {1, 1} +}> + +Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#ttg.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} + CTASplitNum = {1, 1} +}> + +Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and +4 CTAs (taking 2x2 for example) as follows: + +CTA [0,0] CTA [0,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +CTA [1,0] CTA [1,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#ttg.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {2, 2} + CTASplitNum = {2, 2} +}> +}]; + + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$sizePerThread, + ArrayRefParameter<"unsigned">:$threadsPerWarp__, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first + + // CTALayout is optional in the textual IR. If omitted, we infer it to be a + // single CTA (so CTAsPerCGA = [1,...,1], CTASplitNum = [1,...,1], + // CTAOrder=[n,n-1,...,0]). + "CTALayoutAttr":$CTALayout + ); + let genVerifyDecl = 1; + + let builders = [ + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "CTALayoutAttr":$CTALayout), [{ + unsigned rank = sizePerThread.size(); + SmallVector threadsPerWarp(rank); + SmallVector warpsPerCTA(rank); + SmallVector shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + unsigned remainingLanes = numThreadsPerWarp; + unsigned remainingThreads = numWarps * numThreadsPerWarp; + unsigned remainingWarps = numWarps; + unsigned prevLanes = 1; + unsigned prevWarps = 1; + + // starting from the contiguous dimension + for (unsigned d = 0; d < rank - 1; ++d) { + unsigned i = order[d]; + unsigned threadsPerCTA = std::clamp(remainingThreads, 1, std::max(1, shapePerCTA[i] / sizePerThread[i])); + threadsPerWarp[i] = std::clamp(threadsPerCTA, 1, remainingLanes); + warpsPerCTA[i] = std::clamp(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps); + remainingWarps /= warpsPerCTA[i]; + remainingLanes /= threadsPerWarp[i]; + remainingThreads /= threadsPerCTA; + prevLanes *= threadsPerWarp[i]; + prevWarps *= warpsPerCTA[i]; + } + + // Expand the last dimension to fill the remaining lanes and warps + threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes; + warpsPerCTA[order[rank - 1]] = numWarps / prevWarps; + + return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout); + }]>, + + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "unsigned":$numCTAs), [{ + unsigned rank = sizePerThread.size(); + SmallVector CTAsPerCGA(rank); + SmallVector CTASplitNum(rank); + ArrayRef CTAOrder = order; + + unsigned remainingCTAs = numCTAs; + + // starting from the most strided dimension + for (int d = rank - 1; d >= 0; --d) { + unsigned i = order[d]; + CTAsPerCGA[i] = std::clamp(remainingCTAs, 1, std::max(1, shape[i] / sizePerThread[i])); + CTASplitNum[i] = CTAsPerCGA[i]; + remainingCTAs /= CTAsPerCGA[i]; + } + + CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level + + CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout); + }]> + ]; + + let extraClassDeclaration = extraDistributedDeclaration; + + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// MMA Layout Encoding +//===----------------------------------------------------------------------===// + +def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + let methods = [ + InterfaceMethod<"Return the number of threads per warp for dot operands.", + "SmallVector", + "getThreadsPerWarpForOperand", + (ins "int":$opIdx)>, + + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrderForOperand", + (ins "int":$opIdx)>, + ]; +} + +def AMDMfmaEncodingAttr : DistributedEncoding<"AMDMfmaEncoding", "amd_mfma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_mfma"; + + let description = [{ +An encoding for tensors that have been produced by MFMA matrix core instructions, +available on AMD Instinct GPUs of CDNA architectures. + +It is characterized by the following parameters: +- `versionMajor` and `versionMinor` indicates the GPU architecture: + - 1.0: gfx908, i.e. MI100 + - 2.0: gfx90a: i.e. MI200, MI210, MI250 + - 3.0: gfx940, gfx941, gfx942: MI300 + - 4.0: gfx950: MI350 +- `warpsPerCTA` indicates the warp layout in the block. +- `MDim` and `NDim` indicate the dimension of the output of the mfma instruction. +- `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout +without going to shared memory. This is used in the case of chained dot (E.g. Flash-Attention kernel). + +Example 1: +Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32. +The data will be distributed between threads as follows: + + warp 0 warp 1 +-----------------/\-------------- -----------------/\-------------- +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] + +Example 2: +Suppose we have a tensor with a shape of [16, 32], warpsPerCTA set to [1, 2] and MDim=NDim=16. +The data will be distributed between threads as follows: + + warp 0 warp 1 +-----------------/\------------- ------------------/\--------------- +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] + +Example 3: +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and nonKDim set to 4. +The data will be distributed between threads as follows(note that each element is duplicated in 16 threads): +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and MDim=NDim=4. +The data will be distributed between threads as follows(note that each element is duplicated in 16 threads): + +M N -> warp 0 warp 2 +| --------------------------/\-------------------------- ------------------------------/\------------------------------ +V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + warp 1 warp 3 + --------------------------/\-------------------------- ------------------------------/\------------------------------ + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] +}]; + + let parameters = ( + ins + "unsigned": $versionMajor, + "unsigned": $versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "unsigned":$MDim, + "unsigned":$NDim, + "bool":$isTransposed, + "CTALayoutAttr":$CTALayout + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getInstrShapeForOperand(int kWidth, int opIdx) const; + SmallVector getRepForOperand(ArrayRef operandShape, int kWidth, int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; + SmallVector getThreadsPerWarpForOperand(int opIdx) const; + }]; + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; +} + +def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_wmma"; + + let description = [{ +An encoding for tensors that have been produced by WMMA matrix core instructions, +available on AMD Radeon GPUs of RDNA architectures. +- A `version` parameter specifies instruction version to lower in. The data + distribution within one warp is also depends on it. Following architectures are + supported: + - 1: gfx11 + - 2: gfx12 +- A `warpsPerCTA` parameter characterizes data distribution between warps. + An important limitation of WMMA for layout is a shape for tiles processed + by a single warp. It is [16, 16]. + This encoding assumes specific access to matrix elements by threads. + +Example: +Suppose we have a tensor with shape [32, 64], `warpsPerCTA` set to [2, 2]. +Matrix elements represent which lane owns the element. Currently only wave32 mode +is supported. + +// ----------------------------------- version = 1 ----------------------------------- // + +Row | warp 0 warp 1 + |/-------------------^-------------------\ /-------------------^-------------------\ +0 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +1 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +2 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +3 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + | ... ... ... ... +14 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +15 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + + | warp 2 warp 3 +16 |/-------------------^-------------------\ /-------------------^-------------------\ +17 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +18 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +19 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +20 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + | ... ... ... ... +30 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +31 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + +// ------------------------ version = 2, isTransposed = false ------------------------ // + +Row | warp 0 warp 1 + |/--------^---------\ /---------^--------\ +0 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +1 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +.. | ... ... +6 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +7 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +8 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +9 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +.. | ... ... +14 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +15 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] + | + | warp 2 warp 3 + |/--------^---------\ /---------^--------\ +16 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +17 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +.. | ... ... +22 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +23 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +24 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +25 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +.. | ... ... +30 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +31 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] + +// ------------------------ version = 2, isTransposed = true ------------------------ // + + | warp 0 warp 1 + |/----------------^----------------\ /-------^-------\ +Col>| 0 1 2 3 4 5 6 7 8 ... 15 16 17 18 ... 32 +Row | +0 |[0 0 0 0 0 0 0 0 16 ... 16] [0 0 0 ... 16] +1 |[1 1 1 1 1 1 1 1 17 ... 17] [1 1 1 ... 17] +.. | ... ... +14 |[14 14 14 14 14 14 14 14 30 ... 30] [14 14 14 ... 30] +15 |[15 15 15 15 15 15 15 15 31 ... 31] [15 15 15 ... 31] + | + | warp 2 warp 3 + |/----------------^----------------\ /-------^-------\ +16 |[0 0 0 0 0 0 0 0 16 ... 16] [0 0 0 ... 16] +17 |[1 1 1 1 1 1 1 1 17 ... 17] [1 1 1 ... 17] +.. | ... ... +30 |[14 14 14 14 14 14 14 14 30 ... 30] [14 14 14 ... 30] +31 |[15 15 15 15 15 15 15 15 31 ... 31] [15 15 15 ... 31] + }]; + + let parameters = ( + ins + "unsigned": $version, + "bool":$isTransposed, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "CTALayoutAttr":$CTALayout + ); + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getElemsPerInstrForOperands() const; + SmallVector getRepForOperand(ArrayRef operandShape, + Type elemType, int kWidth, int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; + SmallVector getThreadsPerWarpForOperand(int opIdx) const; + unsigned getKWidthForOperands() const; + static SmallVector getMNKDimPerInstr(); + }]; +} + +def NvidiaMmaEncodingAttr : DistributedEncoding<"NvidiaMmaEncoding", "nvidia_mma_encoding", [MmaEncodingTrait]> { + let mnemonic = "nvidia_mma"; + + let description = [{ +An encoding for tensors that have been produced by tensor cores. + +It is characterized by two parameters: +- A 'versionMajor' which specifies the generation the tensor cores + whose output is being partitioned: + - 1 for first-gen tensor cores (Volta), and + - 2 for second-gen tensor cores (Turing/Ampere). +- A 'versionMinor' which indicates the specific layout of a tensor core + generation, e.g. for Volta, there might be multiple kinds of layouts + annotated by 0,1,2 and so on. +- A `blockTileSize` to indicate how data should be partitioned between warps. + +// -------------------------------- version = 1 --------------------------- // + +For first-gen tensor cores, the implicit warpTileSize is [16, 16]. +Note: the layout is different from the recommended in PTX ISA +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.884 section, FP32 accumulator). + +For example, when versionMinor=1, the matrix L corresponding to +blockTileSize=[32,16] is: + + warp 0 +--------------------------------/\------------------------------- +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] + + warp 1 = warp0 + 32 +--------------------------------/\------------------------------- +[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ] +[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ] +[ ............................................................... ] + + +// -------------------------------- version = 2 --------------------------- // + +For second-gen tensor cores, the implicit warpTileSize is [16, 8]. +Information about this layout can be found in the official PTX documentation +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.16816 section, FP32 accumulator). + +For example, the matrix L corresponding to blockTileSize=[32,16] is: + warp 0 warp 2 +-----------------/\------------- ----------------/\------------- +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 + + warp 1 warp 3 +----------------/\------------- ----------------/\------------- +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 + +}]; + + let parameters = ( + ins + "unsigned":$versionMajor, + "unsigned":$versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "CTALayoutAttr":$CTALayout, + ArrayRefParameter<"unsigned">:$instrShape + ); + + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool isVolta() const; + bool isTuring() const; + bool isAmpere() const; + bool isHopper() const; + + SmallVector getRepForOperand(ArrayRef shape, + int bitwidth, int kWidth, + int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; + SmallVector getThreadsPerWarpForOperand(int opIdx) const; + }]; + + let hasCustomAssemblyFormat = 1; +} + +def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> { + let mnemonic = "slice"; + + let description = [{ + Given a `parent` layout and a `dim`, squeezes the given `dim` in the `parent` + layout and distributes values in a tensor T according to the new layout. + + For example, given + + T = [x x x x x x x x] + L_parent = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] (with 16 CUDA threads) + + With dim = 0, squeezing out dim 0, we have + L = [{0,4,8,12}, {1,5,9,13}, {2,6,10,14}, {3,7,11,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L(T) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ] + + With dim = 1, squeezing out dim 1, we have + L = [ {0,1,2,3}, {4,5,6,7}, {8,9,10,11}, {12,13,14,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L = [ {0,1,2,3}, {4,5,6,7}, ..., {12,13,14,15}, {0,1,2,3}, ..., {12,13,14,15} ] + + This is useful for constructing the inverse layout of an expand_dims operation + during some optimization passes. + }]; + + let parameters = ( + ins + "unsigned":$dim, + "DistributedEncodingTrait":$parent + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + template + SmallVector paddedShape(ArrayRef shape) const; + }]; + + let hasCustomAssemblyFormat = 1; +} + +def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> { + let mnemonic = "dot_op"; + + let description = [{ +In the TritonGPU dialect, given `d = tt.dot a, b, c` tt.dot's operands a and b +must be of DotOperandEncodingAttr layout, if the dot is MMA v1 or v2 (i.e. +pre-Hopper). For MMA v3, the operands are *almost always* in a regular shared +encoding, but sometimes the LHS is also a dot-operand encoding. + +a's opIdx is 0, b's opIdx is 1. + +The parent field is the layout of d. + +kWidth defines number of consecutive elements stored by one thread along k dimension. +Some layouts do not use this parameter, either because they have a fixed number of +elements along the K dim, or they use all elements of the tensor along the K dim. + +# WGMMA Notes +We require kWidth to be provided for Hopper because the dtype at loading might be +different from the dtype at WGMMA, due to casting. The kWidth is determined by the +dtype at WGMMA. + +The encoded tensor consists of operand A for possibly multiple wgmma instructions. +For each wgmma, each warp in a warp group feeds a single "warp matrix" +Each warp matrix consists of 2x2 "quads". +Each thread holds several elements in each quad. Right before a wgmma, +the sum of bitwidth of +the elements in each quad should add up to 32. + +These values are stored unrolled in `elements`. +The ordering of dimensions is as follows by convention: +batch (only 1 batch for Hopper currently) +matM (m-index of the "warp matrix") +matK (k-index of the "warp matrix") +quadK (k-index of the "quad" in the core matrix) +quadM (m-index of the "quad" in the core matrix) +vecIdx (index of the element in the quad; this is always along the k-dim) + }]; + + let parameters = ( + ins + "unsigned":$opIdx, + "Attribute":$parent, + DefaultValuedParameter<"unsigned", "0">:$kWidth + ); + + let builders = [ + AttrBuilder<(ins "unsigned":$opIdx, + "Attribute":$parent, + "Type":$eltTy), [{ + NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); + if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper())) + return $_get(context, opIdx, parent, 0); + // For MMAV2 and V3 + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + unsigned kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, kWidth); + }]> + ]; + + let assemblyFormat = "`<` `{` struct(params) `}` `>`"; + let genVerifyDecl = 1; + let extraClassDeclaration = extraDistributedDeclaration; +} + +def TTG_SharedMemorySpace : AttrDef { + let mnemonic = "shared_memory"; + let description = [{ + Attribute to indicate that the memory descriptor points to shared memory. + }]; +} +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td new file mode 100644 index 000000000..b75b586e6 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -0,0 +1,39 @@ +#ifndef TRITONGPU_DIALECT +#define TRITONGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonGPU_Dialect : Dialect { + let name = "ttg"; + + let cppNamespace = "::mlir::triton::gpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + Triton GPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "mlir::gpu::GPUDialect", + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + + LinearLayout toLinearLayout(ArrayRef shape, Attribute layout); + + static int getNumCTAs(ModuleOp mod); + static int getThreadsPerWarp(ModuleOp mod); + + private: + LinearLayoutCache llCache; + }]; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h new file mode 100644 index 000000000..1e76237da --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h @@ -0,0 +1,9 @@ +#ifndef TRITON_GPU_DIALECT_INTERFACES_H +#define TRITON_GPU_DIALECT_INTERFACES_H + +// clang-format off +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc" +// clang-format on + +#endif // TRITON_GPU_DIALECT_INTERFACES_H diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td new file mode 100644 index 000000000..36cd929e9 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -0,0 +1,450 @@ +#ifndef TRITONGPU_OPS +#define TRITONGPU_OPS + +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" // RegionBranchOpInterface +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ViewLikeInterface.td" + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; + +class TTG_Op traits = []> : + Op { +} + +def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", + [SameOperandsAndResultShape, + SameOperandsAndResultElementType, + Pure]> { + let summary = "convert layout"; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { + let summary = "async wait"; + + let arguments = (ins Variadic:$asyncToken, I32Attr:$num); + + let results = (outs TTG_AsyncToken:$retToken); + + let assemblyFormat = "$asyncToken attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 80; + } + }]; +} + +def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> { + let summary = "async commit group"; + + let results = (outs TTG_AsyncToken:$asyncToken); + let arguments = (ins Variadic:$inputTokens); + + let assemblyFormat = [{ + $inputTokens attr-dict + }]; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 80; + } + }]; +} + +def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + TypesMatchWith<"infer mask type from src type", + "src", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, + TypesMatchWith<"infer other type from src type", + "src", "other", "getPointeeType($_self)", + "($_op.getOperands().size() <= 4) || std::equal_to<>()"> +]> { + let summary = "copy data from global memory to local memory asynchronously"; + + let hasVerifier = 1; + let description = [{ + This operation copies data from global memory to local memory asynchronously. + This is analogue to tt.load except the data are copied to local memory pointed + by by the memory descriptor instead of a distributed tensor. The rest of the + operands are the same as tt.load. + }]; + + let arguments = ( + ins TT_PtrTensor:$src, + TTG_MemDescType:$result, + Optional:$mask, + Optional:$other, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let builders = [ + OpBuilder<(ins "Value":$src, "Value":$result, + "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + ]; + + let results = (outs TTG_AsyncToken:$token); + + let extraClassDeclaration = [{ + static DenseSet getEligibleLoadByteWidth(int computeCapability) { + DenseSet validLoadBytes; + if (computeCapability >= 80) { + validLoadBytes = {4, 8, 16}; + } + return validLoadBytes; + } + }]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between other, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + let assemblyFormat = [{ + $src `,` $result (`mask` $mask^)? (`other` $other^)? + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($src) `->` type($result) + }]; +} + + +// Allocate shared memory +def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods]> { + let summary = "allocate tensor"; + let description = [{ + This operation allocates buffer in shared memory and return a descriptor + containing the address and a view of the buffer. + + Explicitly deallocating a buffer is optional; see local_dealloc. + }]; + let arguments = ( + ins + Optional:$src, + OptionalAttr:$alignment + ); + + let builders = [ + OpBuilder<(ins "Type":$result), + [{ build($_builder, $_state, result, Value(), IntegerAttr()); }]>, + OpBuilder<(ins "Type":$result, "Value":$src), + [{ build($_builder, $_state, result, src, IntegerAttr()); }]>, + OpBuilder<(ins "Type":$result, "Value":$src, "int32_t":$alignment), + [{ build($_builder, $_state, result, src, $_builder.getI32IntegerAttr(alignment)); }]> + ]; + + let extraClassDeclaration = [{ + bool isSharedMemoryAlloc() { + return getType().getMemorySpace() && + isa(getType().getMemorySpace()); + } + int32_t getAlignmentOrDefault(); + }]; + let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}]; + + let results = (outs TTG_MemDescType:$result); + let hasFolder = 1; + let hasVerifier = 1; +} + +// Deallocate shared memory +def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree]>]> { + let summary = "dealloc buffer"; + + let description = [{ + This operation deallocates a buffer explicitly. Using the buffer after this + operation is undefined. + + This operation is optional. If you don't explicitly dealloc a buffer, the + compiler assumes it's deallocated at the first point that post-dominates all + uses of the alloc. + + Because we assume a memdesc is dead at the first point that post-dominates + its uses, ops that wait for an async operation on a memdesc to complete + (such as ttng.warp_group_dot_wait) should also take the memdesc as an + operand. + }]; + + let arguments = (ins TTG_MemDescType:$src); + + // Use qualified() otherwise "!ttg.memdesc" is printed as "". + let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}]; +} + +def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> { + let summary = "take a subview of the descriptor."; + + let description = [{ + This operation returns a new descriptor representing a subview of the buffer. + It doesn't affect the underlying memory. The subview can be rank-reduced. + + For example, suppose that + - the input shape is 2x4x16xf16, + - the output shape is 4x4xf16, and + - offsets = [1, 0, 4]. + + Then in Python syntax, the subview covers input[1][0:4][4:8]. + }]; + let arguments = ( + ins TTG_MemDescType:$src, Variadic:$offsets); + + // Use qualified() otherwise "!ttg.memdesc" is printed as "". + let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}]; + + let results = (outs TTG_MemDescType:$result); + + let hasVerifier = 1; +} + +def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure, + TransposeOpInterface, + InferTypeOpWithLayoutEquivalence, + SameOperandsAndResultElementType]> { + let summary = "transpose the descriptor"; + + let description = [{ + This operation returns a new descriptor + representing a transposed view of the buffer. + }]; + + let arguments = (ins TTG_MemDescType:$src, Variadic:$order); + + let arguments = ( + ins TTG_MemDescType:$src, + DenseI32ArrayAttr:$order + ); + + let results = (outs TTG_MemDescType:$result); + + let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))"; + + let hasFolder = 1; +} + +def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods]> { + let summary = "Load a buffer from local memory into a distributed tensor"; + + let description = [{ + Load a tensor from the local memory descriptor into a distributed tensor. + }]; + let arguments = (ins TTG_MemDescType:$src, Optional :$token); + + let builders = [ + OpBuilder<(ins "Type":$retType, "Value":$src), + [{ + build($_builder, $_state, retType, src, /*token=*/static_cast(nullptr)); + }]>]; + + // Use qualified() otherwise "!ttg.memdesc" is printed as "". + let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}]; + + let results = (outs TT_Tensor:$result); +} + +def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods]> { + let summary = "Store a distributed tensor into a buffer in local memory"; + + let description = [{ + Store a distributed tensor into a buffer in local memory. + }]; + let arguments = (ins TT_Tensor:$src, TTG_MemDescType:$dst); + + let hasVerifier = 1; + // Use qualified() otherwise "!ttg.memdesc" is printed as "". + let assemblyFormat = [{ + $src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst)) + }]; +} + +def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> { + let summary = "Upcast fp4 (e2m1) to fp"; + + let hasVerifier = 1; + + let description = [{ + Upcast fp4 (e2m1) represented packed as i8s to fp. + + The lower 4 bits of the i8s represent the first fp4 element, and the upper 4 bits + the second fp4 element. + + The `axis` attribute specifies the axis along which the fp4 elements are packed. + }]; + + let builders = [ + OpBuilder<(ins "TypedValue":$src, "Type":$elemType, "int32_t":$axis)> + ]; + + let arguments = (ins RankedTensorOf<[I8]>:$src, I32Attr:$axis); + let results = (outs TT_FloatTensor:$result); + + let assemblyFormat = [{ + $src attr-dict `:` type($src) `->` type($result) + }]; +} + +// Allocate global memory +def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[MemAlloc]>]> { + let summary = "allocate a global memory buffer"; + let description = [{ + This operation allocates a buffer in global memory that is private to the current program. + }]; + let arguments = ( + ins + I32Attr:$nbytes, + I32Attr:$alignment + ); + let results = (outs TT_Ptr:$result); + + let builders = [ + OpBuilder<(ins "Type":$result, "int32_t":$nbytes, "int32_t":$alignment), + [{ build($_builder, $_state, result, + $_builder.getI32IntegerAttr(nbytes), $_builder.getI32IntegerAttr(alignment)); }]> + ]; + + let assemblyFormat = [{attr-dict `:` qualified(type($result))}]; +} + +def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [ + RecursiveMemoryEffects, RecursivelySpeculatable, AsyncRegions, + DeclareOpInterfaceMethods +]> { + let summary = "asynchronously execute code on multiple warpgroups"; + let description = [{ + The `ttg.warp_specialize` op represents executing different code + simultaneously on different warp groups. A warp group is a group of + power-of-2 warps, which can be a different number of warps than in the + enclosing region. + + The "default" region of the op represents the code executed by the currently + executing warp group. This region is allowed to implicitly capture. The op + contains a number of "partition" regions that are isolated from above. They + must be isolated because these regions represent different layout domains, + as the number of warps is different. + + Semantically, execution of each region starts simultaneously for each warp + group, and all warp groups are joined at the end of the op. + + Example: + + ```mlir + %0 = ttg.warp_specialize(%a, %b) + default { + %out = some_operation(%a) // implicit capture of `%a` + ttg.warp_yield %out : i32 + } + partition0(%arg0: i32, %arg1: i32) num_warps(8) { + some_async_dispatch(%arg0, %arg1) + ttg.warp_return + } + partition1(%arg0: i32, %arg1: i32) num_warps(1) { + some_async_dispatch(%arg0, %arg1) + ttg.warp_return + } : (i32, i32) -> i32 + ``` + }]; + + let arguments = (ins + Variadic:$explicitCaptures, + DenseI32ArrayAttr:$partitionNumWarps, + OptionalAttr:$warpGroupStartIds + ); + let results = (outs Variadic:$defaultPassthrough); + + let regions = (region + MinSizedRegion<1>:$defaultRegion, + SizedRegion<1>:$partitionOpHolder + ); + + let extraClassDeclaration = [{ + RegionRange getPartitionRegions(); + + // Get the size and alignment of the capture list. + std::pair getCaptureSizeAlign(); + // Get the total number of extra warps required. + unsigned getTotalPartitionWarps(); + }]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def TTG_WarpSpecializePartitionsOp : TTG_Op<"warp_specialize.partitions", [ + IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatable, + Terminator, HasParent<"WarpSpecializeOp"> +]> { + let summary = "container op for `ttg.warp_specialize`"; + let description = [{ + Because MLIR requires entire operations be isolated from above, this op + contains the actual isolated from above regions of `ttg.warp_specialize`. + }]; + + let regions = (region VariadicRegion>:$partitionRegions); +} + +def TTG_WarpYieldOp : TTG_Op<"warp_yield", [ + Pure, Terminator, ReturnLike, HasParent<"WarpSpecializeOp">, + DeclareOpInterfaceMethods +]> { + let summary = "yield from the default region of `ttg.warp_specialize`"; + let description = [{ + The `ttg.warp_yield` operation is the terminator for the "default" region of + a `ttg.warp_specialize` operation. The operands are passed transparently as + the SSA results of the `ttg.warp_specialize` operation. + + Example: + + ```mlir + ttg.warp_yield %a, %b : i32, tensor<32xbf16, #blocked> + ``` + }]; + + let arguments = (ins Variadic:$values); + + let assemblyFormat = "($values^)? attr-dict (`:` type($values)^)?"; + let hasVerifier = 1; +} + +def TTG_WarpReturnOp : TTG_Op<"warp_return", [ + Pure, Terminator, ReturnLike, HasParent<"WarpSpecializePartitionsOp"> +]> { + let summary = "implicit terminator from partition regions"; + let description = [{ + The `ttg.warp_return` operation is the implicit terminator that ends the + partition regions of a `ttg.warp_specialize` op. It has no operands as these + regions cannot return anything. + + TODO: Support returning uniform values from partition regions. + }]; + + let assemblyFormat = "attr-dict"; +} + +#endif // TRITONGPU_OPS diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td new file mode 100644 index 000000000..a0415b62c --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td @@ -0,0 +1,23 @@ +#ifndef TRITON_GPU_TYPE_INTERFACES +#define TRITON_GPU_TYPE_INTERFACES + +include "mlir/IR/OpBase.td" + +// Interface dynamically attached to RankedTensorType and MemDescType. +def TTG_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> { + let cppNamespace = "::mlir::triton::gpu"; + let methods = [ + InterfaceMethod<"Returns the encoding of the tensor or memory descriptor", + "mlir::Attribute", "getEncoding", (ins)>, + InterfaceMethod<"Returns element type", + "mlir::Type", "getElementType", (ins)>, + InterfaceMethod<"Returns the type shape", + "llvm::ArrayRef", "getShape", (ins)>, + InterfaceMethod<"Returns the tensor or buffer rank", + "int64_t", "getRank", (ins)>, + InterfaceMethod<"Returns the element type bit width", + "int64_t", "getElementTypeBitWidth", (ins)>, + ]; +} + +#endif // TRITON_GPU_TYPE_INTERFACES diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td new file mode 100644 index 000000000..8061a9879 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td @@ -0,0 +1,101 @@ +#ifndef TRITONGPU_TYPES +#define TRITONGPU_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" + +class TTG_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +def TTG_TokenType : TTG_TypeDef<"Token", "token"> { + let parameters = (ins "int32_t":$type); + + let builders = [ + TypeBuilder<(ins "unsigned":$type), [{ + return $_get($_ctxt, type); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +def TTG_AsyncToken : TTG_TypeDef<"AsyncToken", "async.token", []> { + let summary = "async token type"; + let description = [{ + `ttg.async.token` is a type returned by an asynchronous operation. + It is used to establish an SSA-based link between async operations + and operations that group or synchronize the async operations. + }]; +} + +// Memory descriptor type. +def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> { + let summary = "memory descriptor type (`::mlir::triton::gpu::MemDescType`) in Triton IR type system"; + + let description = [{ + Memory descriptor contains a base pointer (scalar) and a descriptor of the memory. + If mutable memory is false that means the memory is constant and can only be allocated and stored once. + A constant memory allocation is different than a tensor as it can have multiple views and the descriptor + can be changed without changing the underlying memory. + }]; + + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace, + "bool":$mutableMemory, + ArrayRefParameter<"int64_t">:$allocShape + ); + + let extraClassDeclaration = [{ + MemDescType cloneWith(std::optional> shape, + Type elementType) const { + return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory(), getAllocShape()); + } + + bool hasRank() const { return true; } + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false, /*allocShape=*/shape); + }]>, + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace, + "bool":$mutableMemory + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, /*allocShape=*/shape); + }]>, + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace, + "bool":$mutableMemory, + "llvm::ArrayRef":$allocShape + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, allocShape); + }]> + + ]; + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/Types.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/Types.h new file mode 100644 index 000000000..82ab3ae45 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/IR/Types.h @@ -0,0 +1,13 @@ +#ifndef TRITONGPU_IR_TYPES_H_ +#define TRITONGPU_IR_TYPES_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.h.inc" + +#include "triton/Dialect/TritonGPU/IR/TypeInterfaces.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..6be94d1a8 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU) +add_public_tablegen_target(TritonGPUTransformsIncGen) diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h new file mode 100644 index 000000000..c03d2dac6 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h @@ -0,0 +1,8 @@ +#include "mlir/IR/PatternMatch.h" + +namespace mlir::triton::gpu { + +void populateDecomposeScaledBlockedPatterns(mlir::RewritePatternSet &patterns, + int benefit); + +} // namespace mlir::triton::gpu diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Passes.h new file mode 100644 index 000000000..c50d24a08 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -0,0 +1,22 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace gpu { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +} // namespace gpu +} // namespace triton +} // namespace mlir +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Passes.td new file mode 100644 index 000000000..0341c8599 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -0,0 +1,399 @@ +#ifndef TRITONGPU_PASSES +#define TRITONGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { + let summary = "pipeline"; + + let description = [{ + Applies software pipelining to loops in the module based on number of stages. + This may convert some load into asynchronous loads, and multi-buffer the data. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; + + let options = [ + Option<"numStages", "num-stages", + "int32_t", /*default*/"3", + "number of pipeline stages">, + Option<"dumpIntermediateSteps", "dump-intermediate-steps", + "bool", /*default*/"false", + "Dump intermediate steps"> + ]; +} + +// def TritonGPUTC05MMAPipeline : Pass<"tritongpu-tc05mma-pipeline", "mlir::ModuleOp"> { +// let summary = "Test pass calling TC05MMA pipeline"; + +// let description = [{ +// This pass is used to test the TC05MMA pipelining under LIT. Internally it calls +// `getTC05MMASchedule` to get the schedule and then applies the pipelining. +// }]; + +// let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", +// "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", +// "mlir::scf::SCFDialect", +// "mlir::arith::ArithDialect"]; + +// let options = [ +// Option<"disableExpander", "disable-expander", "bool", /*default*/"false", "Run only loop pre-process"> +// ]; +// } + +def TritonGPUTestPipelineAssignLatencies : Pass<"tritongpu-test-pipeline-assign-latencies", "mlir::ModuleOp"> { + let summary = "test assigning latencies to interesting ops ahead of pipelining"; + + let description = [{ + This is a test pass that tests `assignLatencies` method of `TritonGPUPipeline`. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; + + let options = [ + Option<"numStages", "num-stages", + "int32_t", /*default*/"3", + "number of pipeline stages"> + ]; +} + +def TritonGPUTestPipelineScheduleLoop : Pass<"tritongpu-test-pipeline-schedule-loop", "mlir::ModuleOp"> { + let summary = "test scheduling a loop for software pipelining"; + + let description = [{ + This is a test pass that tests `scheduleLoop` method of `TritonGPUPipeline`. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; +} + +def TritonGPUTestPipelineLowerLoop : Pass<"tritongpu-test-pipeline-lower-loop", "mlir::ModuleOp"> { + let summary = "test lowering a loop for software pipelining"; + + let description = [{ + This is a test pass that tests `lowerLoop` method of `TritonGPUPipeline`. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; +} + +def TritonGPUFuseNestedLoops : Pass<"tritongpu-fuse-nested-loops", "mlir::ModuleOp"> { + let summary = "fuse nested loops for pipelining"; + + let description = [{ + The `tritongpu-fuse-nested-loops` pass will analyze loop nests in the module + that need to be pipelined and fuse them into a single loop. This composes + with the pipeliner to pipeline loop nests. + }]; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::arith::ArithDialect", + "mlir::ub::UBDialect", + ]; +} + +def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> { + let summary = "3xTF32 trick"; + + let description = [{ + Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s + to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385 + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; +} + +def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> { + let summary = "prefetch"; + + let description = [{ + This pass attempts to prefetch from shared memory the operands (A and B) + of a `tt.dot`, when this operation is located in a loop. + Decompose `DotOp` instructions in loops into several finer-grained `DotOp` + that may have their operands constructed at the end of the previous + iteration. + Transformations are performed in five different places: + 1. The pass emits a prologue to the loop where the data for the first + loop iteration are prefetched. + 2. The loop arguments are extended with the new prefetched values. + 3. The dotOp parameters is updated with the new args. + 4. The prefetch operations for the next iteration are added to the loop. + 5. The yieldOp is updated by adding the prefetched values for the next + iteration. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; +} + +def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> { + let summary = "accelerate matmul"; + + let description = [{ + Optimize the input/output layout of `dot` instruction to make them compatible hardware accelerators + (e.g., Nvidia tensor cores) + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> { + let summary = "fuse transpositions"; + + let description = [{ + Re-arranged layouts of tensors used as matrix multiplication operands so as to promote the use of + hardware-accelerated transpositions. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; + + let options = [ + Option<"hoistLayoutConversion", "hoist-layout-conversion", + "bool", /*default*/"true", + "whether to move conver to dot operand earlier pass elementwise ops"> + ]; +} + +def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> { + let summary = "coalesce"; + + let description = [{ + The pass analyses loads/stores with type `tensor>` or + `tt.ptr>` and replaces the layouts of these operations with + coalesced layouts, i.e. cache friendly access patterns. + Layout conversions are inserted before and after the load/store op + to maintain consistency with the rest of the program. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; +} + + +def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> { + let summary = "remove superfluous layout conversions"; + + let description = [{ + The purpose of this pass is to rewrite the `ConvertLayoutOps` to reduce + the number of operations and to prefer favorable layouts like + `BlockedEncodingAttr` layout for "expensive" loads and stores + (good for coalescing) and `NvidiaMmaEncodingAttr` otherwise + (good for tensor ops). + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + +} + +def TritonGPUOptimizeThreadLocality : Pass<"tritongpu-optimize-thread-locality", "mlir::ModuleOp"> { + let summary = "Reduce the cost of synchronization between threads in an SM"; + + let description = [{ + The aim of this pass is to reduce cross-thread communication for certain + operations, like reductions, reshapes, and gathers. + + For reduction operations, this pass attempts to adjust the reduction size + (or layout) to avoid splitting the reduction operation between multiple + threads. Currently, this pass only optimizes reduction yielded by loop to be + thread-local until after the loop completes. + + For gathers, this pass will attempt to pick an optimized layout for gather + operations in the module. This is determined based on the shapes of the + gather operands as well as their existing layouts. The pass applies + heuristics to determine when it is appropriate to assign specific layouts + and trigger their respective codegen paths. For now, the pass only attempts + to apply layouts that result in warp-synchronous gathers. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> { + let summary = "Reorder instructions"; + + let description = "This pass reorder instructions so as to (1) decrease register pressure (e.g., by moving " + "conversions from shared memory before their first use) and (2) promote LLVM instruction " + "order more friendly to `ptxas`."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUReduceDataDuplication: Pass<"tritongpu-reduce-data-duplication", "mlir::ModuleOp"> { + let summary = "Reduce data duplication in register by decomposing convert[distributed -> dotOperand] " + "into convert[distributed -> shared -> dotOperand]"; + + let description = "Decomposing conversions this way makes it possible to use CSE and reuse #shared tensors"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUCombineTensorSelectAndIf: Pass<"tritongpu-combine-tensor-select-and-if", "mlir::ModuleOp"> { + let summary = "Combine tensor select and if"; + + let description = "For select instruction that uses the same condition as the if instruction in the same block " + "this pass combines the select into the if instruction, making the select operands returned by the " + "then/else yields."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init", "mlir::ModuleOp"> { + let summary = "Replace accumulator zero-initialization with the flag indicating first use of the accumulator"; + + let description = "For the dot operations that support accumulator-use flag this pass replaces the zero-initialization " + "of the accumulator with the flag indicating the first use of the accumulator."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::ModuleOp"> { + let summary = "Improve coalescing for async global to local copies"; + + let description = "For AsyncCopyGlobalToLocal ops where the shared encoding's vec is less than " + "the blocked encoding's sizePerThread, this pass improves coalescing by clipping the " + "sizePerThread value"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUWSTaskPartition : Pass<"tritongpu-warp-spec-task-partition", "mlir::ModuleOp"> { + let summary = "Warp specialization task partition"; + + let description = "This pass computes a warp schedule partition by annoating anchor operations with async task ids"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} + +def TritonGPUTaskIdPropagate : Pass<"triton-gpu-taskid-propagate", "mlir::ModuleOp"> { + let summary = "Propagate async_task_id annotations based on dependencies"; + + let description = [{ + This pass propagates the `async_task_id` annotation to the dependencies + of any op that has it set. This has the functional effect of partitioning + the graph into multiple async tasks, based on the initial annotation. + }]; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; + + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} + +def TritonGPUWSCodePartition: Pass<"tritongpu-warp-spec-code-partition", "mlir::ModuleOp"> { + let summary = "TritonGPU warp specialization code partition"; + + let description = "This pass generates warp specialized code baed on task id attributes."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numBuffers", "num-buffers", + "int32_t", /*default*/"0", + "number of buffering for producer-consumer">, + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization">, + Option<"regDecProducer", "producer-reg-dec", + "int32_t", /*default*/"40", + "register decrement for producer warp group">, + Option<"regIncConsumer", "consumer-reg-inc", + "int32_t", /*default*/"232", + "register indrement for consumer warp group"> + ]; +} + +def TritonGPUWSDataPartition : Pass<"tritongpu-warp-spec-data-partition", "mlir::ModuleOp"> { + let summary = "Warp specialization data partition"; + + let description = "This pass partitions operations into multiple suboperations which operate on smaller data shapes"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} + +def TritonGPUWSLowering : Pass<"tritongpu-warp-spec-lowering", "mlir::ModuleOp"> { + let summary = "Warp specialization lowering"; + + let description = "This pass lowers warp specializtion related operations."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} + +def TritonGPUWSCanonicalization : Pass<"tritongpu-warp-spec-canonicalization", "mlir::ModuleOp"> { + let summary = "Warp specialization canonicalization"; + + let description = "This pass fixes up async task id for each op wrapped in a WS region."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} + +def TritonGPUPingPongSync: Pass<"tritongpu-ping-pong-sync", "mlir::ModuleOp"> { + let summary = "TritonGPU experiemental ping pong schedule"; + + let description = "This pass inserts barriers to enforce critical section for gemms."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"1", + "number of consumer warp groups for warp specialization">, + ]; +} + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h new file mode 100644 index 000000000..0a3d736c6 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h @@ -0,0 +1,101 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ + +// This is a fork of upstream pipeline transformation. This will be merged back +// upstream once we have a stable solution. + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +class RewriterBase; +class Operation; +class Value; + +namespace scf { +class ForOp; +} + +namespace triton { + +/// Options to dictate how loops should be pipelined. +struct PipeliningOption { + /// Lambda returning all the operation in the forOp, with their stage, in the + /// order picked for the pipelined loop. + using GetScheduleFnType = std::function> &)>; + GetScheduleFnType getScheduleFn = nullptr; + enum class PipelinerPart { + Prologue, + Kernel, + Epilogue, + }; + /// Lambda called by the pipeliner to allow the user to annotate the IR while + /// it is generated. + /// The callback passes the operation created along with the part of the + /// pipeline and the iteration index. The iteration index is always 0 for the + /// kernel. For the prologue and epilogue, it corresponds to the iteration + /// peeled out of the loop in the range [0, maxStage[. + using AnnotationlFnType = + std::function; + AnnotationlFnType annotateFn = nullptr; + + /// Control whether the epilogue should be peeled out of the loop or + /// operations should be predicated to skip the early stages in the last loop + /// iterations. If the epilogue is predicated; the user needs to provide a + /// lambda to generate the predicated version of operations. + bool peelEpilogue = true; + + /// Control whether the transformation checks that the number of iterations is + /// greater or equal to the number of stages and skip the transformation if + /// this is not the case. If the loop is dynamic and this is set to true the + /// pipeliner will have to predicate operations in the the prologue/epilogue. + bool supportDynamicLoops = false; + + // Callback to predicate operations when the prologue or epilogue are not + // peeled. This takes the original operation, an i1 predicate value and the + // pattern rewriter. It is expected to replace the given operation with + // the predicated equivalent and return it, or return nullptr if the + // predication is impossible. In the latter case, pipelining will fail and + // may leave IR in a partially transformed state. + using PredicateOpFnType = + std::function; + PredicateOpFnType predicateFn = nullptr; + + // TODO: add option to decide if the prologue should be peeled. +}; + +/// Generate a pipelined version of the scf.for loop based on the schedule given +/// as option. This applies the mechanical transformation of changing the loop +/// and generating the prologue/epilogue for the pipelining and doesn't make any +/// decision regarding the schedule. +/// Based on the options the loop is split into several stages. +/// The transformation assumes that the scheduling given by user is valid. +/// For example if we break a loop into 3 stages named S0, S1, S2 we would +/// generate the following code with the number in parenthesis as the iteration +/// index: +/// +/// S0(0) // Prologue +/// S0(1) S1(0) // Prologue +/// scf.for %I = %C0 to %N - 2 { +/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel +/// } +/// S1(N) S2(N-1) // Epilogue +/// S2(N) // Epilogue +/// +/// If `modifiedIR` is provided, it will be set to a value that indicates +/// whether pipelining modified the IR before failing, signaling to the caller +/// whether they can proceed with different transformations. +FailureOr pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp, + const PipeliningOption &options, + bool *modifiedIR = nullptr); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h new file mode 100644 index 000000000..a005e790a --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h @@ -0,0 +1,67 @@ +#ifndef TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ +#define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include +#include +#include + +namespace mlir { +namespace triton { + +static const char *kNumStagesAttrName = "tt.num_stages"; +static const char *kDisallowAccMultiBufferAttrName = + "tt.disallow_acc_multi_buffer"; +static const char *kLoopStageAttrName = "loop.stage"; +static const char *kLoopClusterAttrName = "loop.cluster"; +static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage"; +static const char *kLatencyAttrName = "tt.latency"; + +bool loopHasDistGreaterThanOne(scf::ForOp forOp); +bool isOuterLoop(scf::ForOp forOp); + +/// Function to mask operations during scheduling. +Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred); + +/// Replace all uses of `oldUse` with `val` and propagate the type if needed. +/// This is useful when we need to change a memory descriptor from immutable to +/// mutable. +void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse, + Value val); + +// Return true if the given ForOp has the attribute +// `tt.disallow_acc_multi_buffer` set to true. +bool getDisallowAccMultiBuffer(scf::ForOp forOp); + +/// Visit the operands of `op` and the operands of any nested ops defined +/// outside of `op`. +void visitNestedOperands(Operation *op, function_ref visitor); +/// Get the operands of `op` and the operands of any nested ops defined outside +/// of `op`. +SetVector getNestedOperands(Operation *op); + +// Return maxumum length of the vectorized copy between registers and shared +// memory for the given tensor type and shared encoding. +int getCopyVecBytes(RankedTensorType registerTy, + gpu::SharedEncodingTrait sharedEnc); + +// Serialize the latencies of the operations in the loops into the latency +// attribute. +void serializeLatencies(ModuleOp module, DenseMap &opLatency); + +// Deserialize the latencies of the operations in the loops from the attribute. +DenseMap deserializeLatencies(ModuleOp module); + +// Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a +// single buffer slice (leading dimension equal to 1), at the given index. +Value createSingleBufferView(OpBuilder &builder, Value alloc, Value idx); +Value createSingleBufferView(OpBuilder &builder, Value alloc, int idx); + +// Create an allocation and init the mbarriers. +Value createBarrierAlloc(scf::ForOp forOp, int numBarriers); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Schedule.h new file mode 100644 index 000000000..21f877b8e --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Schedule.h @@ -0,0 +1,169 @@ +#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ +#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include + +namespace mlir { +namespace triton { + +namespace gpu { + +/// Discover operations that should become async and assign latencies to them +/// based on the numStages value provided by the user. +void assignLatencies(ModuleOp moduleOp, int numStages); + +/// Schedule the loops based on the latencies assigned to the operations. +void scheduleLoops(ModuleOp moduleOp); + +/// Lower the loops to prepare them for pipeline expansion. +void lowerLoops(ModuleOp moduleOp); + +}; // namespace gpu + +/// This fill out the pipelining options including schedule and annotations +/// for wait ops. This also does pre-processing by converting some of the +/// loads into async loads so that the IR is ready to be pipelined. +bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::PipeliningOption &options); + +/// Fills out pipelining options for an outer loop pipelining case. This +/// schedules async copies to overlap with the epilogue of a loop. +bool getOuterLoopSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::PipeliningOption &options); + +/// Pipeline the Tensor Core Gen 05 MMA ops in the module with `numStages` +/// stages. This will pre-process the loops, lowering the ops related to TG Gen5 +/// MMA, and then pipeline the loops using expander. +void pipelineTC05MMALoops(ModuleOp module, int numStages, + bool disableExpander = false); + +/// Pipeline the TMA stores in the loop. +bool pipelineTMAStores(scf::ForOp forOp); + +/// Simple pipelining for the MMA ops which accumulator is modified in the loop. +scf::ForOp pipelineMMAWithScaledAcc(scf::ForOp forOp); + +/// This does post-processing on the pipelined loop to try to pipeline wgmma +/// ops. +// TODO: this should be included as part of the pipeline but currently the wgmma +// wait modeling is problematic. +void asyncLaunchDots(scf::ForOp forOp); + +/// Post process the pipelined loop by updating the wait ops with the right +/// number of groups in flight. +void updateWaits(ModuleOp module); + +class CoarseSchedule { +public: + class ClusterList { + std::list orderClusters; + + public: + using iterator = decltype(orderClusters)::iterator; + using const_iterator = decltype(orderClusters)::const_iterator; + ClusterList() = default; + iterator begin() { return orderClusters.begin(); } + const_iterator begin() const { return orderClusters.begin(); } + iterator end() { return orderClusters.end(); } + const_iterator end() const { return orderClusters.end(); } + size_t size() { return orderClusters.size(); } + iterator newAtBack() { + orderClusters.push_back(orderClusters.size()); + return std::prev(orderClusters.end()); + } + iterator newAtFront() { + orderClusters.push_front(-1); + for (auto &clusterId : orderClusters) { + clusterId++; + } + return orderClusters.begin(); + } + iterator newBefore(iterator cluster) { + auto ret = orderClusters.insert(cluster, *cluster); + for (auto &clusterId : llvm::make_range(cluster, orderClusters.end())) { + clusterId++; + } + return ret; + } + + bool isBefore(iterator a, iterator b) const { + for (auto it = begin(); it != end(); ++it) { + if (it == a) + return true; + if (it == b) + return false; + } + llvm::report_fatal_error( + "One or both clusters not found in clusters list!"); + } + }; + + CoarseSchedule() = default; + CoarseSchedule(int numStages) : numStages(numStages) {} + ClusterList clusters; + using Cluster = decltype(clusters)::iterator; + + DenseMap> opToStageAndCluster; + + void setNumStages(int numStages) { this->numStages = numStages; } + int getNumStages() { return numStages; } + + void insert(Operation *op, int stage, Cluster cluster) { + assert(stage < numStages && "Invalid stage"); + opToStageAndCluster[op] = {stage, cluster}; + } + + bool insertIfAbsent(Operation *op, int stage, Cluster cluster) { + if (opToStageAndCluster.count(op)) + return false; + insert(op, stage, cluster); + return true; + } + + bool insertMinimum(Operation *op, int stage, Cluster cluster); + + bool insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, + bool includeArg, bool insertIfEarlier = false); + + void erase(Operation *op) { opToStageAndCluster.erase(op); } + + int count(Operation *op) { return opToStageAndCluster.count(op); } + + std::pair operator[](Operation *op) { + return opToStageAndCluster[op]; + } + + auto find(Operation *op) const { return opToStageAndCluster.find(op); } + + SmallVector> + getOpsInOrder(scf::ForOp forOp); + std::vector> + createFinalSchedule(scf::ForOp forOp); + + bool empty() const { return opToStageAndCluster.size() == 0; } + auto end() const { return opToStageAndCluster.end(); } + auto begin() const { return opToStageAndCluster.begin(); } + + // Set based on CoarseSchedule. + void serialize(scf::ForOp &forOp); + // Create a CoarseSchedule based on forOp's . + LogicalResult deSerialize(scf::ForOp &forOp); + + LLVM_DUMP_METHOD void dump(); + +private: + int numStages = 0; +}; + +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule); + +} // namespace triton +} // namespace mlir +#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h new file mode 100644 index 000000000..fbfa235fc --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// +// Defines utilities to use while converting to the TritonGPU dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +class TritonGPUTypeConverter : public TypeConverter { +public: + TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp, + int numCTAs); + int getNumWarps() const { return numWarps; } + int getThreadsPerWarp() const { return threadsPerWarp; } + int getNumCTAs() const { return numCTAs; } + +private: + MLIRContext *context; + int numWarps; + int threadsPerWarp; + int numCTAs; +}; + +class TritonGPUConversionTarget : public ConversionTarget { + +public: + explicit TritonGPUConversionTarget(MLIRContext &ctx, + TritonGPUTypeConverter &typeConverter); +}; + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Utility.h new file mode 100644 index 000000000..642e29ef7 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -0,0 +1,320 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include +#include + +namespace mlir { + +namespace triton { +class ModuleAxisInfoAnalysis; +class LoadOp; +class StoreOp; +class FuncOp; +namespace gpu { +class SwizzledSharedEncodingAttr; +} +} // namespace triton + +// Return a tuple of two or three entries representing the shape of the +// instruction used to perform a matrix multiplication operation. +// Version = 1: +// Version = 2: <1, m, n> +// Version = 3: +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + Type type, int numWarps); + +// Return true if the Load uses block pointer. +bool isLoadFromTensorPtr(triton::LoadOp op); + +// Return an array of indices enumerating the elements of 'arr' in descending +// order (so that result[i] is the index of the i-th largest element of 'arr') +SmallVector argSort(const SmallVector &arr); + +// Return the operand used to access the memory in the operation +Value getMemAccessPtr(Operation *op); + +// Return bitwidth of tensor element +unsigned getElementBitWidth(RankedTensorType type); + +// Calculate the optimal number of elements per thread for a given operation +// along an axis with greatest continuity. +unsigned +getNumElementsPerThread(Operation *op, SmallVector order, + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis); + +// Returns whether the op is a "view op", i.e. doesn't move any data +bool isView(Operation *op); + +/* Dump Triton IR in graphviz dot format. + * + * You can override `onValue` and `onOperation` in a subclass to mark + * specific Values and Operations. The below subclass + * GraphLayoutMarker is an example. + * + * Default NodeInfo for Value nodes: + * {{"shape": "box"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", shapeStr}} + * + * Default NodeInfo for Operation nodes: + * {{"shape": "ellipse"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", operationName}} + * + * If the key "label" is not set by `onValue` or `onOperation`, default labels + * will be generated. For Value node, the default label is the shape string and + * for Operation node, it is the operation name. + * + * Reference: + * https://graphviz.org/doc/info/shapes.html + * https://graphviz.org/doc/info/colors.html + * + * Usage: + * C++: GraphDumper().dumpToFile(func, "func.dot"); + * Shell: dot -Tjpg func.dot -o func.jpg + */ +class GraphDumper { +public: + using NodeInfo = std::map; + + // Override this function to mark specific Values + virtual NodeInfo onValue(Value value) const; + // Override this function to mark specific Operations + virtual NodeInfo onOperation(Operation *op) const; + + std::string dump(triton::FuncOp func) const; + void dumpToFile(triton::FuncOp func, const std::string &filename) const; + +protected: + std::string getShapeStr(const Type &type) const; + + std::string getUniqueId(Value value) const; + std::string getUniqueId(Operation *op) const; + + std::string emitNode(const std::string &id, const NodeInfo style) const; + std::string emitEdge(const std::string &srcId, + const std::string &destId) const; + + std::string emitValueNode(Value value) const; + std::string emitOperationNode(Operation *op) const; +}; + +/* A subclass of GraphDumper that marks different layout kinds in different + * colors.*/ +class GraphLayoutMarker : public GraphDumper { +public: + NodeInfo onValue(Value value) const override; + +protected: + std::string getColor(const Type &type) const; +}; + +// Infers the encoding of the result of op given the source encoding. +Attribute inferDstEncoding(Operation *op, Attribute encoding); + +// Infers the encoding of the source of op given the result encoding. +Attribute inferSrcEncoding(Operation *op, Attribute encoding); + +bool isExpensiveLoadOrStore(Operation *op); + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding); + +// Replace ForOp with a new ForOp with extra operands. The YieldOp is not +// updated and needs to be updated separately for the loop to be correct. +scf::ForOp replaceForOpWithNewSignature( + RewriterBase &rewriter, scf::ForOp loop, ValueRange newIterOperands, + SmallVectorImpl> &replacements); +scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, + ValueRange newIterOperands); + +// Replace WhileOp with a new WhileOp with extra operands. The YieldOp is not +// updated and needs to be updated separately for the loop to be correct. +scf::WhileOp replaceWhileOpWithNewSignature( + RewriterBase &rewriter, scf::WhileOp loop, ValueRange newIterOperands, + TypeRange newResultTypes, + SmallVectorImpl> &replacements); +scf::WhileOp replaceWhileOpWithNewSignature(RewriterBase &rewriter, + scf::WhileOp loop, + ValueRange newIterOperands, + TypeRange newResultTypes); + +// Replace IfOp with a new IfOp with extra results operands. The YieldOp is not +// updated and needs to be updated separately for the bodies to be correct. +scf::IfOp replaceIfOpWithNewSignature( + RewriterBase &rewriter, scf::IfOp loop, TypeRange newResultTypes, + SmallVectorImpl> &replacements); +scf::IfOp replaceIfOpWithNewSignature(RewriterBase &rewriter, scf::IfOp ifOp, + TypeRange newResultTypes); + +// Append the given |newOperands| to the |forOp|'s yield op. +void appendToForOpYield(scf::ForOp forOp, ArrayRef newOperands); + +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, + IRMapping &mapping); + +// Get backward slice of tensor values starting from the root node along with +// encoding propagation. +LogicalResult getConvertBackwardSlice( + OpOperand &root, SetVector &slice, Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation = nullptr, + std::function getExistingConversion = + nullptr); + +// Populate pattern to remove dead cycles in ForOp. +// opsCanBeTriviallyDead specifies the operations of which the side effect can +// be ignored. +void populateForOpDeadArgumentElimination( + RewritePatternSet &patterns, DenseSet &opsCanBeTriviallyDead); + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order); + +SmallVector delinearize(OpBuilder &b, Location loc, unsigned linear, + ArrayRef shape); + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape); +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order); + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape); + +// Return true if the op is a pure elementwise_inline_asm op with a single +// operand and single result. +bool isPureUnaryInlineAsm(Operation *op); + +// read the compute capability from the module attributes +int getNVIDIAComputeCapability(Operation *module); + +// Read the amd target from the module attributes +StringRef getAMDArch(Operation *module); + +std::optional +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible); + +// Convert \param op operands and results to layout \param encoding. +void convertOpEncoding(Attribute encoding, Operation *op); + +// Returns the original memory allocation for a memdesc value +triton::gpu::LocalAllocOp findShmemAlloc(Value operand); + +// Returns MMAs inside a for loop that are multi-buffered for pipeline analysis +SmallVector +getMMAsWithMultiBufferredOperands(scf::ForOp forOp, + SmallVector &mmaOps); + +// 0 is reserved for default sync. +// TODO: comprehensive mechanism to globally manage namedbarrier. +static int const nameBarrierIdBegin = 1; +static int nameBarrierIdEnd = 16; + +/// Helper functions for async task +typedef int AsyncTaskId; +SmallVector getAsyncTaskIds(Operation *op); +bool hasAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId); +void setAsyncTaskIds(Operation *op, ArrayRef asyncTaskIds); +SmallVector getNestedAsyncTaskIds(Operation *op); +void addAsyncTaskIds(Operation *op, ArrayRef asyncTasks); +void removeAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId); +void removeAsyncTaskIds(Operation *op); + +class OpBuilderWithAsyncTaskIds : public OpBuilder { +public: + OpBuilderWithAsyncTaskIds(MLIRContext *context) : OpBuilder(context) {} + + explicit OpBuilderWithAsyncTaskIds(Operation *op) : OpBuilder(op) { + setAsyncTaskIdsFromOp(op); + } + + void setAsynTaskIdsFromArray(ArrayRef newAsyncTaskIds) { + asyncTaskIds = SmallVector(newAsyncTaskIds.begin(), + newAsyncTaskIds.end()); + } + + void setAsyncTaskIdsFromOp(Operation *op) { + setAsynTaskIdsFromArray(getAsyncTaskIds(op)); + } + + void setAsyncTaskIdsFromValueUsers(Value value) { + SetVector asyncTaskIdSet; + for (Operation *user : value.getUsers()) + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(user)) + asyncTaskIdSet.insert(asyncTaskId); + setAsynTaskIdsFromArray(asyncTaskIdSet.getArrayRef()); + } + + template + OpTy createWithAsyncTaskIds(Args &&...args) { + OpTy op = create(std::forward(args)...); + if (!asyncTaskIds.empty()) + setAsyncTaskIds(op, asyncTaskIds); + return op; + } + +private: + SmallVector asyncTaskIds; +}; + +class PatternRewriterWithAsyncTaskIds { +public: + PatternRewriterWithAsyncTaskIds(PatternRewriter &rewriter, Operation *op) + : rewriter(&rewriter) { + setAsyncTaskIdsFromOp(op); + } + + void setAsynTaskIdsFromArray(ArrayRef newAsyncTaskIds) { + asyncTaskIds = SmallVector(newAsyncTaskIds.begin(), + newAsyncTaskIds.end()); + } + + void setAsyncTaskIdsFromOp(Operation *op) { + setAsynTaskIdsFromArray(getAsyncTaskIds(op)); + } + + void setAsyncTaskIdsFromValueUsers(Value value) { + SetVector asyncTaskIdSet; + for (Operation *user : value.getUsers()) + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(user)) + asyncTaskIdSet.insert(asyncTaskId); + setAsynTaskIdsFromArray(asyncTaskIdSet.getArrayRef()); + } + + template + OpTy create(Location location, Args &&...args) { + OpTy op = rewriter->create(location, std::forward(args)...); + if (!asyncTaskIds.empty()) + setAsyncTaskIds(op, asyncTaskIds); + return op; + } + + template + OpTy replaceOpWithNewOp(Operation *op, Args &&...args) { + auto newOp = + rewriter->replaceOpWithNewOp(op, std::forward(args)...); + if (!asyncTaskIds.empty()) + setAsyncTaskIds(newOp, asyncTaskIds); + return newOp; + } + +private: + PatternRewriter *rewriter; + SmallVector asyncTaskIds; +}; + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..3f974d2c0 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -0,0 +1,24 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttng) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttng) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttng) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttng) +add_mlir_doc(TritonNvidiaGPUDialect TritonNvidiaGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonNvidiaGPUOps TritonNvidiaGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonNvidiaGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUAttrDefs.td) +mlir_tablegen(TritonNvidiaGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonNvidiaGPUAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonNvidiaGPUAttrDefsIncGen) + +set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOpInterfaces.td) +mlir_tablegen(TritonNvidiaGPUOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(TritonNvidiaGPUOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(TritonNvidiaGPUOpInterfacesIncGen) diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h new file mode 100644 index 000000000..a807cca0b --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// TritonNvidiaGPU depends on Triton +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc" +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc" +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc" + +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.h.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc" + +namespace mlir::triton::nvidia_gpu { + +struct TensorMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +struct TMemAllocation { + TMemAllocation(int numCols, int numRows) + : numRows(numRows), numCols(numCols) {} + int numRows; + int numCols; +}; + +TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType); + +Attribute getTmemCompatibleLayout(unsigned M, unsigned N, + ArrayRef shape, unsigned numWarps, + triton::gpu::CTALayoutAttr ctaLayout); + +bool isDistributedLayoutTMemCompatible(Operation *op, + RankedTensorType tensorType, + gpu::MemDescType memType); + +} // namespace mlir::triton::nvidia_gpu + +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td new file mode 100644 index 000000000..0d1b45e7f --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td @@ -0,0 +1,84 @@ +#ifndef TRITONNVIDIAGPU_ATTRDEFS +#define TRITONNVIDIAGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" + +// Type for synchronization tokens. +def TT_TokenLoadTypeAttr : I32EnumAttr< + "TokenLoadType", "", + [ + I32EnumAttrCase<"None", 0, "none">, + I32EnumAttrCase<"AsyncLoadOp", 1, "asyncLoadOp">, + I32EnumAttrCase<"TMALoadOp", 2, "tmaLoadOp">, + I32EnumAttrCase<"LocalStoreOp", 3, "localStoreOp">, + ]>{ + let cppNamespace = "::mlir::triton::nvidia_gpu"; +} + +def TTG_TensorMemorySpace : AttrDef { + let mnemonic = "tensor_memory"; + let description = [{ + Attribute to indicate that the memory descriptor points to tensor memory. + The memory is laid out in blocks of size blockM x blockN. Each block is distributed + across TMEM 128 rows. + + Blocks are distributed along M dimension first and then N dimension. This is an arbitrary + convention that need to be followed operations reading/writing to TMEM. + + a tensor <128x128xf32> with blockM = 64 and blockN = 64 will be distributed as follows: + + \ col 0 1 31 32 64 96 127 + rows: 0 ( 0, 0) ( 0, 1) ... ( 0, 31) (64, 0) ... (0, 64) ... (64, 64) ... (64, 96) + 1 + ... + 15 (15, 0) (15, 1) ... (15, 31) (79, 0) ... (15, 64) ... (79, 64) ... (79, 96) + 16 ( 0, 32) ( 0, 33) ... ( 0, 63) (64, 32) ... ( 0, 96) ... (64, 96) ... (64, 127) + ... + 31 (15, 32) (15, 33) ... (15, 63) (79, 32) ... (15, 96) ... (79, 96) ... (79, 127) + 32 (16, 0) (16, 1) ... (16, 31) (80, 0) ... (16, 64) ... (80, 64) ... (80, 96) + ... + 127 (63, 32) (63, 33) ... (63, 63) (127, 32) ... (63, 96) ... (127, 96)... (127, 127) + }]; +} + +def TTG_TensorMemoryEncodingAttr : AttrDef { + let mnemonic = "tensor_memory_encoding"; + let attrName = "triton.gpu.tensor_memory_encoding"; + let description = [{ + An encoding to represent the different way the tensor memory is laid out. + `unpacked` attributes indicates whether types smaller than 32bits are unpacked (take full 32bits) + or are packed (N elements are stored within one 32bits row). + }]; + let parameters = ( + ins + "unsigned":$blockM, + "unsigned":$blockN, + "bool":$unpacked, + DefaultValuedParameter<"unsigned", "1">:$CTASplitM, + DefaultValuedParameter<"unsigned", "1">:$CTASplitN + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def TTG_TensorMemoryScalesEncodingAttr : AttrDef { + let mnemonic = "tensor_memory_scales_encoding"; + let attrName = "triton.gpu.tensor_memory_scales_encoding"; + let description = [{ + An encoding to represent the layout of tensor memory scales. + As described in the PTX doc, blocked scales in TMEM must be in a special layout. They are organized + as a multiple copies of "chunk", each of which having the size 32x4x4B. Moreover, such chunks are duplicated + over 4 warps to fill entire 128 rows of TMEM. This encoding indicates that a tensor in TMEM is in such a special + layout. + }]; + let parameters = ( + ins + DefaultValuedParameter<"unsigned", "1">:$CTASplitM, + DefaultValuedParameter<"unsigned", "1">:$CTASplitN + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td new file mode 100644 index 000000000..b7c9cee0a --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td @@ -0,0 +1,67 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_DIALECT +#define TRITONNVIDIAGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonNvidiaGPU_Dialect : Dialect { + let name = "ttng"; + + let cppNamespace = "::mlir::triton::nvidia_gpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + Triton Nvidia GPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "triton::gpu::TritonGPUDialect", + "mlir::gpu::GPUDialect", + ]; + + let extraClassDeclaration = [{ + static std::string getNumWarpsAttrName() { return "ttg.num-warps"; } + static int getNumWarps(ModuleOp mod) { + if(!mod->hasAttr("ttg.num-warps")) + llvm::report_fatal_error( + "TritonGPU module should contain a ttg.num-warps attribute"); + return cast(mod->getAttr("ttg.num-warps")).getInt(); + } + static int getNumCTAs(ModuleOp mod) { + if(!mod->hasAttr("ttg.num-ctas")) + llvm::report_fatal_error( + "TritonGPU module should contain a ttg.num-ctas attribute"); + return cast(mod->getAttr("ttg.num-ctas")).getInt(); + } + void registerTypes(); + }]; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; + let useDefaultTypePrinterParser = 1; +} + +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td" + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td new file mode 100644 index 000000000..2943ba59d --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td @@ -0,0 +1,42 @@ +#ifndef TRITON_NVIDIAGPU_OP_INTERFACES +#define TRITON_NVIDIAGPU_OP_INTERFACES + +include "mlir/IR/OpBase.td" + +def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> { + let description = [{ + This interface is implemented by MMAv5 dot and dot scaled ops. + }]; + + let cppNamespace = "::mlir::triton::nvidia_gpu"; + + // We can add more methods as needed. + let methods = [ + InterfaceMethod<"Return the accumulator init flag.", + "::mlir::Value", + "useAccumulator">, + InterfaceMethod<"Set the accumulator init flag.", + "void", + "setUseAccumulator", + (ins "::mlir::Value":$flag)>, + InterfaceMethod<"Associate a new barrier to this MMAv5 op.", + "void", + "setBarrier", + (ins "::mlir::Value":$barrier)>, + InterfaceMethod<"Return the accumulator.", + "::mlir::Value", + "getAccumulator">, + InterfaceMethod<"Set the accumulator.", + "void", + "setAccumulator", + (ins "::mlir::Value":$accum)>, + InterfaceMethod<"Return the predicate of this op.", + "::mlir::Value", + "getPredicate">, + InterfaceMethod<"Set the predicate of this op.", + "void", + "setPredicate", + (ins "::mlir::Value":$pred)>, + ]; +} +#endif // TRITON_NVIDIAGPU_OP_INTERFACES diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td new file mode 100644 index 000000000..1ae05a7f3 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -0,0 +1,607 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_OPS +#define TRITONNVIDIAGPU_OPS + +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ViewLikeInterface.td" + +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; + +class TTNG_Op traits = []> : + Op { +} + +def TTNG_MBarrierArriveOp : TTNG_Op<"mbarrier_arrive", [AttrSizedOperandSegments, + MemoryEffects<[MemWrite]>]> { + let summary = "mbarrier arrive"; + + let description = [{ + This operation defining the arriving action for a mbarrier. + txCount: + An optional attribute that set tx-count. This Op will be lowered into + mbarrier.arrive.expect_tx if the optional attribute exist. + trackAsyncOp: + If true, this op will be lowered into cp.async.mbarrier.arrive.noinc. + pred: + Only perform arrive action when pred is true. + remoteCtaId: + if set, perform an remote arrive action. + + Example: + + triton_nvidia_gpu.mbarrier_arrive %0 {trackAsyncOp = false} : !tt.ptr + + }]; + + let arguments = (ins TTG_MemDescType:$mbarrier, + Optional:$pred, + Optional:$remoteCtaId, + I1Attr: $trackAsyncOp, + DefaultValuedAttr: $txCount + ); + + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> { + let arguments = (ins BoolAttr:$bCluster); + + let summary = "fence proxy async"; + + let assemblyFormat = "attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 90; + } + }]; +} + +def TTNG_GetCanonicalWarpIdOp : TTNG_Op<"get_canonical_warp_id", [Pure]> { + let description = [{ + Returns the one dimensional warpId when it's used for producing warp uniform values. + }]; + + let results = (outs I32:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> { + let summary = "named barrier arrive"; + + let arguments = (ins I32:$bar, I32: $numThreads); + + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def TTNG_NamedBarrierWaitOp : TTNG_Op<"bar_wait", []> { + let summary = "named barrier wait"; + + let arguments = (ins I32:$bar, I32: $numThreads); + + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> { + let arguments = (ins I1Attr:$relaxed); + let assemblyFormat = "attr-dict"; +} + +def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> { + let assemblyFormat = "attr-dict"; +} + +// +// WarpGroupDot Op +// +def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "warp group dot"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp + }]; + + let arguments = (ins TTG_TensorOrMemDesc:$a, + TTG_TensorOrMemDesc:$b, + TT_FpIntTensor:$c, + Optional:$useC, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc, + DefaultValuedAttr:$isAsync); + + let results = (outs TT_FpIntTensor:$d); + + let assemblyFormat = "$a`,` $b`,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)"; + + let extraClassDeclaration = [{ + bool needsPartialAccumulator(); + }]; +} + +def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods, + AllTypesMatch<["inputs", "outputs"]>]> { + let summary = "warp group dot wait"; + let arguments = (ins Variadic:$inputs, I32Attr:$pendings); + let results = (outs Variadic:$outputs); + let description = [{ + Waits until there are $pendings or fewer outstanding async dot operations. + + $inputs must be the tensors corresponding to the async dot ops that we're + waiting on. For example, if there are N pending async dot ops and we call + `warp_group_dot_wait 1`, then $inputs must be the result of the first dot op. + }]; + + let assemblyFormat = "$inputs attr-dict `:` type($inputs)"; +} + +def TTNG_InitBarrierOp : TTNG_Op<"init_barrier", [DeclareOpInterfaceMethods]> { + let summary = "Initialize a barrier in the given shared memory allocation."; + + let description = [{ + Initializes a shared memory allocation with mbarrier information. + `alloc` is a descriptor to the shared memory allocation. `count` is the + number of arrives expected by the barrier. + + This lowers to PTX mbarrier.init.shared::cta.b64. + }]; + + let hasVerifier = 1; + let arguments = (ins TTG_MemDescType:$alloc, + I32Attr:$count); + let assemblyFormat = "$alloc `,` $count attr-dict `:` qualified(type($alloc))"; +} + +def TTNG_InvalBarrierOp : TTNG_Op<"inval_barrier", [DeclareOpInterfaceMethods]> { + let summary = "Invalidate a barrier allocation."; + + let description = [{ + Invalidate a barrier allocation so that it can be re-used. According to PTX + spec this has to be done before any reuse of the memory used by mbarrier. + + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval + }]; + + let hasVerifier = 1; + let arguments = (ins TTG_MemDescType:$alloc); + let assemblyFormat = "$alloc attr-dict `:` qualified(type($alloc))"; +} + +def TTNG_BarrierExpectOp : TTNG_Op<"barrier_expect", [DeclareOpInterfaceMethods]> { + let summary = "Signal a barrier of an expected number of bytes to be copied."; + + let description = [{ + This signal the barrier that `size` bytes are expected to be copied. The + associated barrier wait will block until the expected number of bytes are copied. + }]; + + let hasVerifier = 1; + let arguments = ( + ins TTG_MemDescType:$alloc, + I32Attr:$size, + I1:$pred + ); + + let assemblyFormat = [{ + $alloc `,` $size attr-dict `,` $pred `:` qualified(type($alloc)) + }]; +} + +def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", [ + DeclareOpInterfaceMethods, + AttrSizedOperandSegments]> { + let summary = "wait until the mbarrier phase completes."; + + let description = [{ + Blocks the program progress until the mbarrier object in `alloc` completes + its current phase. + + This lowers a waitloop using PTX instruction + mbarrier.try_wait.parity.shared.b64. + + Accepts optional list of memory. If present, it is assumed that any of the + dependencies may be accessed until the barrier completes. + + The barrier behavior is described here: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-asynchronous-copy-completion-mechanisms + }]; + + let hasVerifier = 1; + let arguments = (ins TTG_MemDescType:$alloc, + I32:$phase, + Optional:$pred, + Variadic:$deps); + let builders = [ + OpBuilder<(ins "Value":$alloc, "Value":$phase), + [{ + build($_builder, $_state, alloc, phase, /*pred=*/static_cast(nullptr), /*deps=*/{}); + }]>, + OpBuilder<(ins "Value":$alloc, "Value":$phase, "Value":$pred), + [{ + build($_builder, $_state, alloc, phase, pred, /*deps=*/{}); + }]>, + OpBuilder<(ins "Value":$alloc, "Value":$phase, "ValueRange":$deps), + [{ + build($_builder, $_state, alloc, phase, /*pred=*/static_cast(nullptr), deps); + }]>, + ]; + let assemblyFormat = "$alloc `,` $phase attr-dict (`,` $pred^)? (`deps` $deps^)? `:` qualified(type($alloc)) (`,` type($deps)^)?"; +} + +def TTNG_TensorDescToTMAPtrOp : TTNG_Op<"tensor_desc_to_tma_ptr", [Pure]> { + let summary = "Convert tensor descriptor to pointer to tma descriptor"; + + let arguments = (ins TT_TensorDescType:$desc); + let results = (outs TT_Ptr:$ptr); + + let assemblyFormat = [{ + $desc attr-dict `:` qualified(type($desc)) `to` qualified(type($ptr)) + }]; + + let builders = [ + OpBuilder<(ins "Value":$desc), [{ + auto ptrTy = triton::PointerType::get($_builder.getI8Type(), 1); + build($_builder, $_state, ptrTy, desc); + }]> + ]; + + let hasCanonicalizeMethod = 1; +} + + +def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [DeclareOpInterfaceMethods]> { + let summary = "copy data based on descriptor from global memory to local memory asynchronously"; + + let description = [{ + This operation copies data from global memory to local memory + asynchronously. This is analogue to tt.load except the data are copied to + local memory pointed by the memory descriptor instead of a distributed + tensor. The data copied depends on the global memory descriptor pointed to + by `desc_ptr`. + }]; + + let hasVerifier = 1; + let arguments = ( + ins TT_PtrType:$desc_ptr, + Variadic:$coord, + TTG_MemDescType:$barrier, + TTG_MemDescType:$result, + I1:$pred, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let assemblyFormat = [{ + $desc_ptr `[` $coord `]` $result `,` $barrier `,` $pred + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` qualified(type($desc_ptr)) `,` qualified(type($barrier)) `->` qualified(type($result)) + }]; +} + +def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global", [DeclareOpInterfaceMethods]> { + let summary = "copy data based on descriptor from local memory to global memory asynchronously"; + + let description = [{ + This operation copies data from local memory to global memory + asynchronously. This is analogue to tt.store except the data are copied from + local memory pointed by the memory descriptor instead of a distributed + tensor. The data copied depends on the global memory descriptor pointed to + by `desc_ptr`. + }]; + + let arguments = ( + ins TT_PtrType:$desc_ptr, + Variadic:$coord, + TTG_MemDescType:$src); + + let assemblyFormat = [{ + $desc_ptr `[` $coord `]` $src + attr-dict `:` qualified(type($desc_ptr)) `,` qualified(type($src)) + }]; +} + +def TTNG_AsyncTMAGatherOp : TTNG_Op<"async_tma_gather", [DeclareOpInterfaceMethods]> { + let summary = "gather data based on descriptor from global memory to local memory asynchronously"; + + let description = [{ + This operation gathers multiple rows of data from global memory matrix to + local memory asynchronously. This is similar to + async_tma_copy_global_to_local except that each row is indexed independently. + }]; + + let arguments = (ins + TT_PtrType:$desc_ptr, + RankedTensorOf<[I32]>:$x_offsets, + I32:$y_offset, + TTG_MemDescType:$barrier, + TTG_MemDescType:$result, + I1:$pred + ); + + let assemblyFormat = [{ + $desc_ptr `[` $x_offsets `,` $y_offset `]` $result `,` $barrier `,` $pred + attr-dict `:` type(operands) + }]; + + let hasVerifier = 1; +} + +def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter", [DeclareOpInterfaceMethods]> { + let summary = "scatter data from local memory into global memory based on a descriptor asynchronously"; + + let description = [{ + The `ttng.async_tma_scatter` operation scatters multiple separately-indexed + rows of data from local memory into global memory asynchronously. The + operation scatters a 2D tensor in shared memory, laid out by core tensor + tiles nvmma_shared layout into separately indexed rows in global + memory at a given `y` offset. + }]; + + let arguments = (ins + TT_PtrType:$desc_ptr, + RankedTensorOf<[I32]>:$x_offsets, + I32:$y_offset, + TTG_MemDescType:$src + ); + + let assemblyFormat = [{ + $desc_ptr `[` $x_offsets `,` $y_offset `]` $src + attr-dict `:` type(operands) + }]; +} + +def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> { + let summary = "wait until all the inputs are read."; + let arguments = (ins I32Attr:$pendings); + let description = [{ + Wait until all the read operations are done from the associated store operations. + This is needed before the shared memory can be written to. + }]; + + let assemblyFormat = "attr-dict"; +} + +def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "block level op mapping to tensorcore gen5 mma"; + + let description = [{ + $d += matrix_multiply($a, $b). + If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier. + If there is a barrier the result will be safe to read after a barrier wait. + If $two_ctas is set the op will execute a matmul across two contiguous CTAs, it will read the data distributed across the two CTAs. + and syncronize both CTAs if the op is synchronous. + }]; + + let arguments = (ins TTG_MemDescType:$a, + TTG_MemDescType:$b, + TTG_MemDescType:$d, + I1:$useD, + I1:$pred, + Optional:$barrier, + OptionalAttr:$two_ctas); + + // TODO: improve printing format. + let assemblyFormat = "$a`,` $b`,` $d`,` $useD`,` $pred (`,` $barrier^)? attr-dict `:` functional-type(operands, results)"; +} + +def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "block level op mapping to tensorcore gen5 mma"; + + let description = [{ + $d += matrix_multiply(scale($lhs, $lhs_scale), scale(rlhs, $rhs_scale)) + If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier. + If there is a barrier the result will be safe to read after a barrier wait. + }]; + + let arguments = (ins TTG_MemDescType:$a, + TTG_MemDescType:$b, + TTG_MemDescType:$d, + TTG_MemDescType:$a_scale, + TTG_MemDescType:$b_scale, + TT_ScaleDotElemTypeAttr:$a_type, + TT_ScaleDotElemTypeAttr:$b_type, + I1:$useD, + I1:$pred, + Optional:$barrier); + + // TODO: improve printing format. + let assemblyFormat = "$a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred `lhs` `=` $a_type `rhs` `=` $b_type (`,` $barrier^)? attr-dict `:` functional-type(operands, results)"; +} + +def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load", [MemoryEffects<[MemRead]>]> { + let summary = "Load a buffer from tensor memory into a distributed tensor"; + + let description = [{ + This is similar to ttg.local_load except the result layout is restricted to only few possibility. + Therefore we cannot combine this op with any convert layout like local_load. + }]; + let arguments = (ins TTG_MemDescType:$src); + + let assemblyFormat = [{$src attr-dict `:` qualified(type($src)) `->` type($result)}]; + let results = (outs TT_Tensor:$result); + let hasVerifier = 1; +} + +def TTNG_TMEMStoreOp : TTNG_Op<"tmem_store", [MemoryEffects<[MemWrite]>]> { + let summary = "Store a distributed tensor into a buffer in tensor memory"; + + let description = [{ + This is similar to ttg.local_local except the source layout is restricted to only few possibility. + }]; + let arguments = (ins TTG_MemDescType:$dst, TT_Tensor:$src, I1:$pred); + + let assemblyFormat = [{$src `,` $dst `,` $pred attr-dict `:` type($src) `->` qualified(type($dst))}]; + let hasVerifier = 1; +} + +def TTNG_TMEMAllocOp : TTNG_Op<"tmem_alloc", [DeclareOpInterfaceMethods]> { + let summary = "allocate tensor memory"; + let description = [{ + This operation allocates buffer in tensor memory and return a descriptor + containing the address and a view of the buffer. + This is similar to ttg.local_alloc except the buffer is allocated in tensor memory. + + Explicitly deallocating a buffer is optional; see local_dealloc. + }]; + let arguments = (ins Optional:$src); + + let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}]; + + let results = (outs TTG_MemDescType:$result); + let hasVerifier = 1; +} + +def TTNG_TMEMCopyOp : TTNG_Op<"tmem_copy", [DeclareOpInterfaceMethods]> { + let summary = "Initiate an asynchronous copy operation from shared memory to the Tensor Memory."; + + let description = [{ + 2D blocks stored contiguously in SMEM are copied into TMEM as specified by the destination address. + The completion of the copy can be observed by waiting on the optional barrier. If this op is used + together with an MMA op, one barrier can be used to wait for both copy and MMA. We do not need to wait + for the completion of the copy before MMA, since tcgen05.cp followed by tcgen05.mma is guaranteed to + execute in that order. + + This op lowers to the PTX instruction tcgen05.cp. Right now, we only support 1CTA and the warpx4.32x128b + variant of the instruction. Each 32x128b block in SMEM is duplicated over 4 warps and stored into 128 rows + and 4 columns of TMEM. The primary use case of this op is to copy blocked scales from SMEM to TMEM. + + The shape of the input SMEM can be flexibily chosen depending on use cases. In the simplest case (e.g. unit test), + the source SMEM can be of shape (32 x num_blocks, 16), and the destination TMEM should be of shape (128, 16 x num_blocks), + for copying 8 bit values. For scaled GEMM, rep_m x rep_k copies of a 32x128b block need to be stored in SMEM, where + rep_m = BLOCK_M / 128, rep_k = BLOCK_K / scale_vec_size / 4, and scale_vec_size = 32 for MXFP. + Conceptually, the SMEM is organized in a high-dimensional layout, (rep_m, rep_k, 32, 4, 4B). + Some of axes can be flattened into one, to reduce the rank of the load. For example, the following patterns are supported: + * (rep_m, rep_k * 32 x 4 x 4B), 2D scale load with cp.async + * (rep_m, rep_k, 32, 16B), 4D scale load with TMA + * (rep_m, rep_k, 32, 4, 4B), 5D scale load with cp.async + Since rep_m blocks are not contiguous in SMEM, this axis cannot be flattened into inner ones. + + In Triton, the TMEM memdesc for blocked scales must be of the following form: + * Its shape must be (BLOCK_MN, BLOCK_K / scale_vec_size), representing the logical shape of blocked scales. + * It must be attached with `tensor_memory_scales_encoding` to indicate the chunk-based layout and its duplication over 4 warps. + + In contrast, the src SMEM must be in the explicit chunk-based layout as described above. So the IR might look like this: + + %0 = ttng.tmem_alloc : () -> !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory> + ttng.tmem_copy %1, %0 : (!ttg.memdesc<1x1x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>) -> () + + We interpret the semantics of this copy operation as follows. The chunk-based layout in SMEM implies that + the logical shape (BLOCK_MN, BLOCK_K / scale_vec_size) in TMEM is the result of certain reshape and transpose operations. + In practice, to take an advantage of the native scale layout and the TMEM copy op, users need to do + `scales5D.trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // scale_vec_size)` before feeding scales into dot_scaled. + When we use tmem_copy in the IR, such reshape and transpose operations are removed. But the change in the logical shape they have caused on + registers is now understood to be incorporated into tmem_copy itself. Ideally, we would lift reshape / transpose done on registers onto + the SMEM memdesc, making tmem_copy a straightforward 2D copy operation: (BLOCK_MN, BLOCK_K / scale_vec_size) -> (BLOCK_MN, BLOCK_K / scale_vec_size). + In the absence of such operations on memdesc, we resort to implicitly encoding the reshape/transpose semantics in tmem_copy. + + }]; + let arguments = (ins TTG_MemDescType:$src, TTG_MemDescType:$dst, Optional:$barrier); + + let assemblyFormat = [{$src `,` $dst `,` $barrier attr-dict `:` functional-type(operands, results)}]; + let hasVerifier = 1; +} + +def TTNG_GetAsyncTaskIdOp : TTNG_Op<"get_async_task_id", [Pure]> { + let results = (outs I32:$result); + + let builders = [OpBuilder<(ins)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +// +// Token +// + +def TTNG_CreateTokenOp : TTNG_Op<"create_token"> { + let results = (outs TensorOf<[TTNG_TokenType]>:$result); + + let arguments = (ins I32Attr:$num, TT_TokenLoadTypeAttr:$loadType); + + let builders = [OpBuilder<(ins "uint32_t":$num, "triton::nvidia_gpu::TokenLoadType":$loadType)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_ProducerAcquireOp : TTNG_Op<"producer_acquire"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx, I1:$phase); + + let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)"; +} + +def TTNG_ProducerCommitOp : TTNG_Op<"producer_commit"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def TTNG_ConsumerWaitOp : TTNG_Op<"consumer_wait"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx, I1: $phase); + + let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)"; +} + +def TTNG_ConsumerReleaseOp : TTNG_Op<"consumer_release"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> { + let summary = "register allocation"; + + let arguments = (ins I32Attr: $regCount); + + let assemblyFormat = "$regCount attr-dict"; +} + +def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> { + let summary = "register deallocation"; + + let arguments = (ins I32Attr: $regCount); + + let assemblyFormat = "$regCount attr-dict"; +} + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td new file mode 100644 index 000000000..d3126f8a0 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td @@ -0,0 +1,37 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_TYPES +#define TRITONNVIDIAGPU_TYPES + +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class TTNG_TypeDef + : TypeDef { + let mnemonic = _mnemonic; +} + +def TTNG_TokenType : TTNG_TypeDef<"Token", "token">; + +def TTNG_MutexType : TTNG_TypeDef<"Mutex", "mutex">; + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/Types.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/Types.h new file mode 100644 index 000000000..63c7a091a --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/IR/Types.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITONNVIDIAGPU_IR_TYPES_H_ +#define TRITONNVIDIAGPU_IR_TYPES_H_ + +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..d4b5c097f --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonNvidiaGPU) +add_public_tablegen_target(TritonNvidiaGPUTransformsIncGen) diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h new file mode 100644 index 000000000..4ddf2cebb --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +// Used by Triton runtime +struct ClusterInfo { + ClusterInfo() : clusterDimX(1), clusterDimY(1), clusterDimZ(1) {} + int clusterDimX; + int clusterDimY; + int clusterDimZ; +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir + +namespace mlir { + +std::unique_ptr createTritonNvidiaGPUPlanCTAPass( + mlir::triton::nvidia_gpu::ClusterInfo *clusterInfo = nullptr); + +std::unique_ptr +createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90); + +std::unique_ptr createTritonNvidiaGPUTMALoweringPass(); + +std::unique_ptr createTensorMemoryAllocationPass(); + +std::unique_ptr createTritonNvidiaGPUMMALoweringPass(); + +std::unique_ptr createTritonNvidiaGPUKeepAccInTMemPass(); + +std::unique_ptr createTritonNvidiaGPUPromoteLHSToTMemPass(); + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_TRITONNVIDIAGPULEGALIZETMALAYOUTS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +} // namespace mlir +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td new file mode 100644 index 000000000..40278d2b8 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td @@ -0,0 +1,137 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_PASSES +#define TRITONNVIDIAGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp"> { + let summary = "plan CTA"; + + let description = [{ + This pass computes and applies "optimized" CTA tilings to DotOp, ReduceOp + and StoreLikeOps operations. + }]; + + let constructor = "mlir::createTritonNvidiaGPUPlanCTAPass()"; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::ModuleOp"> { + let summary = "Insert fences across generic and async proxy"; + + let description = [{ + This pass is to insert memory fences to ensure that memory operations are + properly ordered across generic and async operations. + }]; + + let constructor = "mlir::createTritonNvidiaGPUFenceInsertionPass()"; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"90", + "device compute capability"> + ]; +} + +def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::ModuleOp"> { + let summary = "lower to TMA load/store operations"; + + let description = [{ + Lower Triton experimental descriptor load to TMA load/store operations in TritonNvidiaGPUDialect. + }]; + + let constructor = "mlir::createTritonNvidiaGPUTMALoweringPass()"; + + let dependentDialects = [ + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def TritionTensorMemoryAllocationPass : Pass<"triton-tensor-memory-allocation", "mlir::ModuleOp"> { + let summary = "Assign tensor memory allocation"; + + let description = [{ + Decide on tensor memory allocation and assign attributes to each allocation. + }]; + + let constructor = "mlir::createTensorMemoryAllocationPass()"; + + let dependentDialects = [ + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def TritonNvidiaGPUMMALoweringPass : Pass<"triton-nvidia-mma-lowering", "mlir::ModuleOp"> { + let summary = "lower mma operations if needed"; + + let description = [{ + Lower MMA ops to prepare for conversion to LLVM. + }]; + + let constructor = "mlir::createTritonNvidiaGPUMMALoweringPass()"; + + let dependentDialects = [ + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def TritonNvidiaGPUKeepAccInTMemPass : Pass<"tritongpu-keep-acc-in-tmem", "mlir::ModuleOp"> { + let summary = "Keep accumulator in Tensor Memory"; + + let description = [{ + For Tensor Core Gen 05 Dot operations called in the loop, where the accumulator is reused + in the next iteration, we want to keep the accumulator in Tensor Memory, so that we can + avoid the cost of loading the accumulator from registers to Tensor Memory. + }]; + + let constructor = "mlir::createTritonNvidiaGPUKeepAccInTMemPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonNvidiaGPUPromoteLHSToTMemPass : Pass<"tritongpu-promote-lhs-to-tmem", "mlir::ModuleOp"> { + let summary = "Promote LHS operand of MMAv5 op to Tensor Memory"; + + let description = [{ + Promote LHS operand of MMAv5 op to Tensor Memory. + }]; + + let constructor = "mlir::createTritonNvidiaGPUPromoteLHSToTMemPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; +} + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h new file mode 100644 index 000000000..8cb76d14d --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h @@ -0,0 +1,103 @@ +#pragma once +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton::nvidia_gpu { + +constexpr inline int TMA_SIZE_BYTES = 128; +constexpr inline int TMA_ALIGN = 128; + +template +mlir::LogicalResult createTMADesc(mlir::Value tmaPtr, + mlir::triton::MakeTensorDescOp op, + BuilderT &builder) { + using namespace mlir; + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + auto mkI32Constant = [&](int32_t val) { + return builder.template create( + loc, builder.getI32Type(), builder.getI32IntegerAttr(val)); + }; + + auto elemType = op.getBase().getType().getPointeeType(); + auto elemSize = elemType.getIntOrFloatBitWidth() / 8; + + int32_t contig_dim_size = op.getTensorShape().back(); + int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize; + if (contig_dim_size_in_bytes > 128) { + contig_dim_size = 128 / elemSize; + } + llvm::SmallVector boxDim; + boxDim.push_back(mkI32Constant(contig_dim_size)); + for (int k = op.getTensorShape().size() - 2; k >= 0; --k) { + boxDim.push_back(mkI32Constant(op.getTensorShape()[k])); + } + + int32_t swizzle_mode; + if (contig_dim_size_in_bytes >= 128) { + swizzle_mode = 3; + } else if (contig_dim_size_in_bytes == 64) { + swizzle_mode = 2; + } else if (contig_dim_size_in_bytes == 32) { + swizzle_mode = 1; + } else { + op->emitError() + << "contiguous box dimension must be at least 32 bytes but got " + << contig_dim_size_in_bytes; + return failure(); + } + + Value elemSizeVal = builder.template create( + loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize)); + + SmallVector globalDim(llvm::reverse(op.getShape())); + SmallVector globalStride; + for (int k = op.getStrides().size() - 2; k >= 0; --k) { + globalStride.push_back(op.getStrides()[k]); + } + + SmallVector elementStride(globalDim.size(), mkI32Constant(1)); + + for (int i = 0; i < globalStride.size(); ++i) + globalStride[i] = builder.template create( + loc, globalStride[i], elemSizeVal); + + int elemTypeEnum; + switch (elemSize) { + case 1: { + elemTypeEnum = 0; + break; + } + case 2: { + elemTypeEnum = 1; + break; + } + case 4: { + elemTypeEnum = 2; + break; + } + default: { + op->emitError() + << "Tensor descriptor element type must have size 1, 2, or 4 but got " + << elemSize; + return failure(); + } + } + + builder.template create( + loc, + /*desc_ptr=*/tmaPtr, + /*global_address=*/op.getBase(), + /*box_dim=*/boxDim, + /*global_dim=*/globalDim, + /*global_stride=*/globalStride, + /*element_strides=*/elementStride, + /*elem_type*/ builder.getI32IntegerAttr(elemTypeEnum), + /*interleave_layout*/ builder.getI32IntegerAttr(0), + /*swizzle_mode=*/builder.getI32IntegerAttr(swizzle_mode), + /*fill_mode=*/builder.getI32IntegerAttr(0)); + return success(); +} + +} // namespace mlir::triton::nvidia_gpu diff --git a/third_party/enflame/include/triton/include/triton/Target/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Target/CMakeLists.txt new file mode 100644 index 000000000..39d31dc9b --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/enflame/include/triton/include/triton/Target/LLVMIR/CMakeLists.txt b/third_party/enflame/include/triton/include/triton/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 000000000..1f6c1b351 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name LLVMIR) +add_public_tablegen_target(LLVMIRIncGen) diff --git a/third_party/enflame/include/triton/include/triton/Target/LLVMIR/Passes.h b/third_party/enflame/include/triton/include/triton/Target/LLVMIR/Passes.h new file mode 100644 index 000000000..27ecb5c3d --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Target/LLVMIR/Passes.h @@ -0,0 +1,17 @@ +#ifndef TRITON_TARGET_LLVM_IR_PASSES_H +#define TRITON_TARGET_LLVM_IR_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Create a pass to add DIScope +std::unique_ptr createLLVMDIScopePass(); + +/// Generate the code for registering conversion passes. +#define GEN_PASS_REGISTRATION +#include "triton/Target/LLVMIR/Passes.h.inc" + +} // namespace mlir + +#endif // TRITON_TARGET_LLVM_IR_PASSES_H diff --git a/third_party/enflame/include/triton/include/triton/Target/LLVMIR/Passes.td b/third_party/enflame/include/triton/include/triton/Target/LLVMIR/Passes.td new file mode 100644 index 000000000..999b0b889 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Target/LLVMIR/Passes.td @@ -0,0 +1,15 @@ +#ifndef TRITON_TARGET_LLVMIR_PASSES +#define TRITON_TARGET_LLVMIR_PASSES + +include "mlir/Pass/PassBase.td" + +def LLVMDIScope: Pass<"enable-line-info", "mlir::ModuleOp"> { + let summary = "Materialize LLVM line info"; + let description = [{ + This pass materializes line mapping information for LLVM IR dialect operations. + }]; + + let constructor = "mlir::createLLVMDIScopePass()"; +} + +#endif diff --git a/third_party/enflame/include/triton/include/triton/Tools/LayoutUtils.h b/third_party/enflame/include/triton/include/triton/Tools/LayoutUtils.h new file mode 100644 index 000000000..1a5167edb --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Tools/LayoutUtils.h @@ -0,0 +1,109 @@ +#ifndef TRITON_TOOLS_LAYOUTUTILS_H +#define TRITON_TOOLS_LAYOUTUTILS_H + +#include "triton/Tools/LinearLayout.h" + +namespace mlir::triton { +// Is the sublayout defined from dimNames to dimNames the identity? +// In particular, is the input and output size in these dimensions +// the same, and are the bases the identity? +bool squareSublayoutIsIdentity(const LinearLayout &ll, + ArrayRef dimNames); + +// Is the sublayout defined from dimNames to dimNames a subpermutation matrix? +// I.e. the layout matrix is formed by selecting unique columns from the +// identity matrix and adding zero columns. A zero column in the layout means +// that changing a bit in the inputs does not change the bits of the outputs +// (broadcasting). +bool squareSublayoutIsPermutation(const LinearLayout &ll, + ArrayRef dimNames); + +// For each output dimension d, ensure that the layout's output size (i.e., its +// codomain) does not exceed shape[d]. Do this without changing the size of the +// layout's inputs (i.e., leave its domain unchanged). +// +// This function is invariant to the order of the layout's input and output +// dimensions. +// +// We achieve this by setting the largest value in each output dimension d to 0 +// because bases that map to a location larger than shape[d] +// effectively duplicate along that dimension. For example, consider a layout +// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to +// shrink the output dimension size to 8: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 16 +// +// In the first step, we shrink the output dimension size to 16 by setting +// L(lane=2) to 0: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 +// +// This means that lane=2 has the same data as lane=0. +// +// Now the output dimension of this layout has a size of 16, which is still +// larger than 8. We find the current largest value in the output dimension, +// which is L(register=1) = 8, and we set L(register=1) to 0: +// +// L(register=1) = 0 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 +// +// Now the output dimension of this layout has a size of 8, which is the desired +// size. Note that this method works only because the bases are powers of two, +// which is the case for DistributedLayouts If broadcastRegisters is false, we +// remove any register that's larger than the desired shape. In the example +// above we would have +// L(register=1) = 4 +// L(register=2) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 +LinearLayout +ensureLayoutNotLargerThan(const LinearLayout &layout, + const llvm::SmallDenseMap &shape, + bool broadcastRegisters = true); + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// smaller than shape[d]. Do this by increasing the size of the layout's inputs +// along its most-minor dimension ("register" for register layouts, "offset" for +// shared layouts). +// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape); + +// Return a vector of the standard out dimension names for tensor layouts. These +// are "dim0", "dim1", etc. +SmallVector standardOutDimNames(MLIRContext *ctx, int rank); + +// Return a vector of the standard out dimension name/value pairs, i.e. +// ("dim0", dstShape[0]), ("dim1", dstShape[1]), etc. +SmallVector> +standardOutDimPairs(MLIRContext *ctx, ArrayRef dstShape); + +// Return an identity mapping from `inDimName` to the standard out dimensions, +// with the dimensions sized according to the shape. The bases are sorted +// according to `order`, with the most minor dimension first. +LinearLayout identityStandardND(StringAttr inDimName, ArrayRef shape, + ArrayRef order); + +// Compute the supremum of two lists. +// Error out if the supremum does not exist (e.g. [a, b] and [b, a]). +// If the supremum is not unique, we return the first list first +// (e.g. [a, b], [a, c] -> [a, b, c]). +SmallVector supremum(const SmallVector &x, + const SmallVector &y); +} // namespace mlir::triton + +#endif // TRITON_TOOLS_LAYOUTUTILS_H diff --git a/third_party/enflame/include/triton/include/triton/Tools/LinearLayout.h b/third_party/enflame/include/triton/include/triton/Tools/LinearLayout.h new file mode 100644 index 000000000..d4bb11ed9 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Tools/LinearLayout.h @@ -0,0 +1,761 @@ +#ifndef TRITON_TOOLS_LINEARLAYOUT_H +#define TRITON_TOOLS_LINEARLAYOUT_H + +#include +#include +#include +#include +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::triton { + +// # High-level overview of linear layouts +// +// The idea for linear layouts is due to Adam P. Goucher. +// +// In Triton, a linear layout (LL) is a function that maps from a "hardware +// location" to a "logical tensor index". +// +// For example, suppose we have a 2D tensor T stored in GPU registers. T's +// layout (i.e., L) is the function that, given a "hardware location" tuple of +// (thread-id, warp-id), returns an index (x,y) into T. In other words, if +// L(t,w) = (x,y) is our linear layout func, then a register in thread t in warp +// w contains the value T[x,y]. +// +// The key fact about LLs is, the mapping from (t,w) to (x,y) is not arbitrary. +// We only need to specify the value of L(t,w) at certain special points +// (namely, the values L(t,0) and L(0,w) where t and w are powers of 2), and +// from those we can compute all the other values of L. +// +// Here's an example LL where we have 4 warps and 4 threads per warp, and the +// tensor T has shape 4x4. We define the function L by choosing the values of +// L(0,1), L(0,2), L(1,0), and L(2,0). Our choices are shown below. +// +// t/w 0 1 2 3 +// 0 ? (0,1) (0,2) ? +// L(t,w) = 1 (1,1) ? ? ? +// 2 (2,2) ? ? ? +// 3 ? ? ? ? +// +// You only need to specify these four values to define the whole linear layout. +// These special values are called the "basis vectors" or "bases" of the layout. +// We complete the table by xor'ing together the bases, according to the +// following rule. (I write "⊕" for xor.) +// +// L(t1 ⊕ t2, w1 ⊕ w2) = L(t1, w1) ⊕ L(t2, w2) (linearity rule). +// +// The linearity rule plus our four choices allows us to fill in the whole +// table. Here's how we might compute some of the values. +// +// L(0,0) = L(1 ⊕ 1, 0 ⊕ 0) = L(1,0) ⊕ L(1,0) = (1,1) ⊕ (1,1) = (0,0) +// L(0,3) = L(0 ⊕ 0, 2 ⊕ 1) = L(0,2) ⊕ L(0,1) = (0,2) ⊕ (0,1) = (0,3) +// L(3,0) = L(2 ⊕ 1, 0 ⊕ 0) = L(2,0) ⊕ L(1,0) = (2,2) ⊕ (1,1) = (3,3) +// L(3,3) = L(3 ⊕ 0, 0 ⊕ 3) = L(3,0) ⊕ L(0,3) = (3,3) ⊕ (0,3) = (3,0). +// +// (Notice it's a consequence of the linearity rule that L(0,0) = (0,0), no +// matter what values we chose for the table.) +// +// The whole table looks like this. +// +// t/w 0 1 2 3 +// 0 (0,0) (0,1) (0,2) (0,3) +// L(t,w) = 1 (1,1) (1,0) (1,3) (1,2) +// 2 (2,2) (2,3) (2,0) (2,1) +// 3 (3,3) (3,2) (3,1) (3,0). +// +// Careful readers will recognize this as a classic "swizzled" layout where +// (t, w) -> (t, w ⊕ t). To go from this formula to an LL, you only need to +// compute the results at input points (0,1), (0,2), (1,0), and (2,0). + +// Indeed the whole point of LLs is that they allow us to specify transposed and +// swizzled layouts as a "general case". Instead of a layout class for +// registers in a thread, and another layout for registers in a thread but in +// MMAv2 order, and so on, all of these can be represented by different LLs. +// This gets rid of special cases and lets us write more general code. +// +// In this example, L was a 2D -> 2D function, but LLs are general MD -> ND +// functions. In practice, a GPU register layout usually has input dims (reg, +// thread-id, warp-id, block-id), where reg represents the fact that one thread +// may store values for the tensor in multiple registers. +// +// To summarize, a linear layout is a function from tuples of integers to tuples +// of integers. We specify some key values of the function, and then we can +// compute all the other values using the linearity rule. +// +// Here are the key things you can do with linear layout objects. +// +// 1. Given an LL, construct a new LL by modifying it or combining it with +// another LL. +// +// 2. "Apply" an LL, i.e. use it to map an input index to an output index. +// A function for this that uses LLVM-dialect MLIR as its input and output +// lives in TritonGPUToLLVM.h. +// +// 3. Convert an existing Triton layout (e.g. BlockedLayoutAttr) to an LL. +// These functions live in TritonGPU/LinearLayoutConversions.h. During +// TTGIR -> LLVM codegen, we convert Triton layouts to linear layouts and +// then apply them. In the future, we intend to remove the Triton layouts +// entirely. +// +// # Examples of linear layouts +// +// 1. The 1D identity layout. This maps L(x) = x. +// +// Recall that our bases are the values of L(x) where x is a power of two. +// So for e.g. an 8-element layout, we have L(1) = 1, L(2) = 2, L(4) = 4, and +// therefore our bases are [1, 2, 4]. +// +// 2. The 1D zeros layout. This maps L(x) = 0. +// +// For an 8-element layout, we have L(1) = L(2) = L(4) = 0, so our bases are +// [0, 0, 0]. +// +// 3. A 2D -> 2D identity layout. Our basis vectors are the values of L(x,0) +// and L(0,y) where x and y are powers of two. The bases are +// +// - L(0,1) = (0,1) +// - L(0,2) = (0,2) +// - L(1,0) = (1,0) +// - L(2,0) = (2,0). +// +// 4. A 2D -> 2D transpose layout. For a 4x4 layout, we have: +// +// - L(0,1) = (1,0) +// - L(0,2) = (2,0) +// - L(1,0) = (0,1) +// - L(2,0) = (0,2). +// +// 5. A 1D -> 1D "transpose" layout. Consider the 16-element layout that maps +// +// x = 0 1 2 3 4 5 6 7 8 9 A B C D E F +// L(x) = 0 4 8 C 1 5 9 D 2 6 A E 3 7 B F. +// +// The bases are [L(1), L(2), L(4), L(8)] = [4, 8, 1, 2]. You can also think +// of this as a rearrangement of the 1D identity layout [1, 2, 4, 8]. +// +// 6. A 2D -> 1D broadcasted layout. L(x,y) = x. For a 4x4 -> 4 layout, our +// bases are +// +// - L(0,1) = 0 +// - L(0,2) = 0 +// - L(1,0) = 1 +// - L(2,0) = 2. +// +// # Implementation notes +// +// ## Dimension order +// +// An LL's input and output dimensions have an order. This order only affects +// the reshapeIns/Outs and similar operations, where the layout is logically +// flattened according to the dimension order and then chopped up again. +// +// ## Surjectivity and injectivity +// +// Most LLs are surjective, i.e. all output values are covered by some input +// value. But occasionally you might create a non-surjective layout, usually +// via invertAndCompose. We aggressively assert that LLs are surjective unless +// you explicitly create one that's not. +// +// LLs are not, in general, injective. There might exist multiple input values +// that map to the same output value. This represents the idea that the same +// logical tensor elements can be stored in multiple places in the hardware. +// +// ## Why map hardware loc -> tensor index and not the other way around? +// +// In Triton, a linear layout usually tells us which logical tensor value is +// stored at a particular place in the hardware. For example, an LL might map +// the tuple (thread-id, warp-id, block-id) to a 2D index into a tensor, (x,y), +// meaning that the register at (t,w,b) has value tensor[x,y]. Or it might map +// from a shared memory (offset, block) to a tensor index. +// +// It might seem more natural to go the other way around, from tensor index to +// place in the hardware. But a particular tensor[x,y] value might be stored in +// more than one place in the hardware, so if we went in this direction, the +// layout would no longer be a proper function. This would complicate +// everything else. +// +// # Optional mathematical background: Linear functions over GF(2) +// +// (You shouldn't need to understand this math to use linear layouts, but it +// helps with the implementation.) +// +// One way to define a linear function is to say it's any function F that can be +// written as +// +// L(a) = a1 * B1 + a2 * B2 + ... + aM * BM, +// +// where +// +// - a is a vector [a1...aM], and ai is a scalar in some field 𝔽 (for +// example, ai might be a real number), and +// - each Bj is a vector [b1j, b1j, ..., bNj] of N scalars in 𝔽. +// +// We can also write this as a matrix-vector product Ba, where +// +// - a is the column vector [a1, ..., aM] and +// +// - B is the matrix formed by concatenating the column vectors B1, ..., BM: +// +// | ↑ ↑ ↑ | +// B = | B1, B2, ..., BM| +// | ↓ ↓ ↓ | +// +// |b11, b12, ..., b1M| +// |b21, b22, ..., b2M| +// = | ↓ ↓ ↓ | +// |bN1, bN2, ..., bNM|. +// +// Usually when we do linear algebra, the field 𝔽 from which `ai` and `bij` are +// drawn is the real or complex numbers. But in linear layouts, we let 𝔽 be a +// different field: GF(2). +// +// GF(2) is the two-element field of bits. To define a field, I need to give +// you the set of elements and also addition and multiplication operations. For +// GF(2) the elements are simply {0,1}. We define addition as xor, and +// multiplication as binary `and`. +// +// Here's an example of a 4x4 matrix-vector multiply where the elements are in +// GF(2). I'm using ⊕ to represent GF(2)'s addition operation (i.e xor) and × +// to represent multiplication (i.e. binary `and`). +// +// | 1 0 0 0 | | 0 | | 1 | | 0 | | 0 | | 0 | +// | 0 1 1 0 | | 1 | = | 0 | × 0 ⊕ | 1 | × 1 ⊕ | 1 | × 1 ⊕ | 0 | × 0 +// | 0 0 1 1 | | 1 | | 0 | | 0 | | 1 | | 1 | +// | 0 0 1 1 | | 0 | | 0 | | 0 | | 1 | | 1 | +// +// | 0 | | 0 | +// = | 1 | ⊕ | 1 | +// | 0 | | 1 | +// | 0 | | 1 | +// +// | 0 | +// = | 0 |. +// | 1 | +// | 1 | +// +// This works, but it's cumbersome. It's more compact to think of the vector +// `a` as an M-bit integer, and each column Bi of the matrix B as an N-bit +// integer. Here's the same matrix-vector product written this way. +// +// = | 1 2 14 12 | × 6 +// = | 1 2 14 12 | × 0b0110 +// = (1 × 0) ⊕ (2 × 1) ⊕ (14 × 1) ⊕ (12 × 0) +// = 2 ⊕ 14 +// = 12. +// +// And we confirm that our answer of 12 is equal to the binary value 0b1100 we +// got before. +// +// Notice that the function F(a) is fully specified by the matrix B, and that +// the four columns of B tell us the values of F at power-of-two values for `a`, +// namely F(1), F(2), F(4), and F(8). In other words, we specify four results +// of F(x) (we call these the function's "basis vectors" or its "bases") and we +// can then compute any other value by xor'ing together subsets of the bases. +// +// In the case of a 1D -> 1D layout, the implementation of an LL is +// straightforward from the mathematical description. If the LL is +// higher-dimensional, we can "stack" the bit vectors to create 1D vectors. +// For example, if we have a 2D LL and we're given input tuple (0b0011, 0b1100), +// we can treat this like a 1D input 0b0011'1100 and then do the regular 1D LL +// computation. Similarly we can "unstack" the output from 1D to ND. +// +// The linearity rule presented earlier is perhaps misleading at this point. In +// the 1D view of things, we really only need +// +// L(x ⊕ y) = L(x) ⊕ L(y) (1D linearity rule), +// +// which is part of the definition of L being a linear function. The new 1D +// linearity rule plus stacking/unstacking is equivalent to the earlier +// N-dimensional linearity rule. +// +// That's all we need in order to define linear layouts mathematically! +// +// # Comparison to Nvidia CuTe +// +// (Note, I'm not an expert on CuTe; this is my best understanding.) +// +// CuTe is a programmatic layout system that's part of Nvidia CUTLASS; see +// https://github.com/NVIDIA/cutlass/blob/629f465/media/docs/cute/00_quickstart.md +// +// LLs and CuTe solve similar problems. Before CuTe, CUTLASS v2 had many +// handcrafted layouts, "RowMajor", "VoltaTensorOpMultiplicandCongruous", etc, +// see https://www.youtube.com/watch?v=QLdUML5MCfE&t=574s. Each of these was a +// special case. CUTLASS v3 introduced CuTe layouts, which are programmable and +// subsume all of these special cases. The CUTLASS folks say this simplified +// CUTLASS, in the same way that we hope LLs will simplify Triton. +// +// Like CuTe layouts, LLs are also programmable and composable. But there are +// also some differences. +// +// - Dimensions in LLs are named; CuTe dimensions are numbered. +// - CuTe layouts can be nested; LLs cannot be. (Nesting doesn't give CuTe +// layouts additional power; any nested layout can be flattened.) +// - CuTe layouts support non-power-of-two shapes; LLs do not. In particular +// this means that LLs cannot represent padded layouts. +// - In CuTe, swizzling is a separate step applied after specifying a layout. +// In LLs, swizzling is part of the layout itself. +// - The structure of LLs allows us to programmatically search for layouts that +// satisfy certain requirements, for example a shared layout that doesn't +// have bank conflicts when read into a particular register layout. CuTe +// expects a human to choose the layout using their brain. +// - CuTe emits code that is in the critical path of your CPU and GPU programs, +// therefore it needs to be fast. It uses C++ template magic to specialize +// on known-sized dimensions, and so on. LLs themselves do not need to be +// fast; only the emitted `apply` code is on the critical path. +// - CuTe requires a CUDA compiler such as nvcc; LLs do not. +// +class LinearLayout { +private: + // bases[inDim][i] = L(0, ..., inDim=2^i, ..., 0). All other values of L are + // computed by xor'ing bases together, using the linearity rule. In addition: + // + // - Each inDim has the same set of outDims, in the same order. + // - The order of dims is minor-to-major, although this only affects reshape. + llvm::MapVector /*size=getNumOutDims()*/> + /*size=getInDimSizeLog2(inDim)*/> + bases; + + llvm::MapVector outDims; + bool surjective; + +public: + using BasesT = decltype(bases); + + // The 0-dimensional layout that maps everything to 0. This is useful as a + // starting point when doing something like + // + // LinearLayout ret = LinearLayout::empty(); + // for (...) ret *= ...; + // return ret; + static LinearLayout empty() { return LinearLayout(BasesT{}, {}); } + + // Creates a 1D -> 1D layout that's the identity function, i.e. L(x) = x + // for x in [0, size). + static LinearLayout identity1D(int32_t size, StringAttr inDim, + StringAttr outDim); + + // Creates a 1D -> 1D layout that maps every input value to 0, i.e. L(x) = 0 + // for x in [0, size). By default this creates a surjective layout where + // `outDim` has size 1 (the only element is 0). If `outDimSize` is specified + // to be greater than 1, then this creates a non-surjective layout with a + // specific size for `outDim`. + static LinearLayout zeros1D(int32_t size, StringAttr inDim, StringAttr outDim, + int32_t outDimSize = 1); + + // Creates a LinearLayout from a list of bases. These are interpreted + // according to the rules written for the member variable `bases`. + // + // Calculates the out-dim sizes according to the bases. Consider the + // following example. + // + // L(in1=1) = (out1=1, out2=0) + // L(in1=2) = (out1=5, out2=1) + // L(in1=4) = (out1=2, out2=2) + // + // To calculate the out-dim sizes, we first find the largest values for out1 + // and out2, namely 5 and 2, then round these up to the next power of 2, + // namely 8 and 4. These are the out-dim sizes. + // + // Assert-fails if the layout is not surjective given these out-dim sizes. + // That is, every possible out-dim in range [0, size) must be produced by + // xor'ing some combination of bases. + explicit LinearLayout(BasesT bases, ArrayRef outDimNames); + + // Creates a LinearLayout given a list of bases and the explicit out-dimension + // sizes. Allows the layout to be non-surjective. + // + // To see why we need to explicitly pass out-dim sizes when creating a + // non-surjective layout, consider the following example. + // + // L(in1=1) = 1 + // L(in1=2) = 4 + // + // If we naively infer the out-dim sizes from these bases, we'd infer a size + // of nextPow2(4) = 8. But given that the layout is non-surjective, who is to + // say that the codomain is not (say) [0,32)? We can't tell, thus we need to + // be explicit about the sizes. + explicit LinearLayout(BasesT bases, + ArrayRef> outDims, + bool requireSurjective); + + // Construct a LinearLayout from an explicit list of bases. (This constructor + // is needed because llvm::MapVector does not have a constructor that accepts + // an initializer_list.) + // + // For example, given these bases + // + // L(in1=1, in2=0) = (out1=0, out2=1) + // L(in1=2, in2=0) = (out1=0, out2=2) + // L(in1=0, in2=1) = (out1=0, out2=4) + // L(in1=0, in2=2) = (out1=0, out2=8) + // L(in1=0, in2=4) = (out1=1, out2=1) + // + // we can use this constructor to build an equivalent LL: + // + // LinearLayout({ + // {"in1", {/*L(in1=1)=*/{0,1}, /*L(in1=2)=*/{0,2}}}, + // {"in2", {/*L(in2=1)=*/{0,4}, /*L(in2=2)=*/{0,8}, /*L(in2=4)=*/{1,1}}}, + // }, + // {"out1", "out2"}) + // + // The overload that infers out-dim sizes assert-fails if the layout is not + // surjective. + explicit LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames); + explicit LinearLayout( + ArrayRef>>> bases, + ArrayRef> outDims, bool requireSurjective); + + bool isSurjective() const { return surjective; } + + bool isInvertible() const { + return surjective && getTotalInDimSize() == getTotalOutDimSize(); + } + + const BasesT &getBases() const { return bases; } + + // Get the pos'th basis vector for the inDim -> outDim mapping. + // getBasis(inDim, pos) = L(0, ..., inDim = 2^pos, ..., 0). + ArrayRef getBasis(StringAttr inDim, int32_t pos) const { + auto it = bases.find(inDim); + assert(it != bases.end()); + assert(pos < it->second.size()); + return it->second[pos]; + } + + int32_t getBasis(StringAttr inDim, int32_t pos, StringAttr outDim) const { + return getBasis(inDim, pos)[getOutDimIndex(outDim)]; + } + + // These are in minor-to-major order, although if you don't flatten the dims + // (e.g. by reshaping) then the order doesn't really affect anything. + auto getInDimNames() const { return llvm::make_first_range(bases); } + auto getOutDimNames() const { return llvm::make_first_range(outDims); } + auto getOutDimSizes() const { return llvm::make_second_range(outDims); } + + // Gets the position that this outDim occupies in getOutDimNames(). Asserts + // if the dim is not present. + int32_t getOutDimIndex(StringAttr outDim) const; + + bool hasInDim(StringAttr inDim) const { return bases.contains(inDim); } + bool hasOutDim(StringAttr outDim) const { return outDims.contains(outDim); } + + int32_t getNumInDims() const { return bases.size(); } + int32_t getNumOutDims() const { return outDims.size(); } + + // Asserts if the dimension is not present. + int32_t getInDimSizeLog2(StringAttr inDim) const; + int32_t getInDimSize(StringAttr inDim) const { + return 1 << getInDimSizeLog2(inDim); + } + + int32_t getTotalInDimSizeLog2() const; + int32_t getTotalInDimSize() const { return 1 << getTotalInDimSizeLog2(); } + + // getOutDimSize(dim) == s means that there exists an input value that will + // produce each output value in [0,s) (if the layout is surjective). + // + // For example, if our bases are + // + // L(in0=1) = 1 + // L(in0=2) = 4 + // L(in1=1) = 2 + // L(in1=2) = 8 + // + // then the largest value we can produce is L(3,3) = 1 ⊕ 4 ⊕ 2 ⊕ 8 = 15 (and + // indeed we can produce all values in [0,16) by xor'ing subsets of the bases + // 1,2,4,8), so getOutDimSize(out_dim0) == 16. + // + // Asserts if the dimension is not present. + int32_t getOutDimSizeLog2(StringAttr outDim) const; + int32_t getOutDimSize(StringAttr outDim) const { + return 1 << getOutDimSizeLog2(outDim); + } + + int32_t getTotalOutDimSizeLog2() const; + int32_t getTotalOutDimSize() const { return 1 << getTotalOutDimSizeLog2(); } + + // Finds the number of consecutive input elements in the first input dimension + // that map to consecutive output elements in the first output dimension. + // + // Mathematically, finds the maximum value V such that for any a, b, c, and + // for all v in [0,V), + // + // L(a*V + v, b, c, ...) = L(a*V, b, c, ...) + (v, 0, ..., 0) + // + // Note that's +, not ⊕, in the RHS. (Equivalently, we could use binary-or + // instead of +. In other words, we require that L(a*V, b, c, ...) have no + // bits that overlap with v.) + // + // For example, if L maps (register, lane) to (dim1, dim0), then this tells + // you how many consecutive registers map to consecutive elements of dim1. + // + // This only works across the first (i.e. the most-minor) dimension of in/out. + // If you want it to work across more dimensions, flatten the layout. + // + // TODO(jlebar): Replace with divideLeft. + int32_t getNumConsecutiveInOut() const; + + // Reorders the in/out dimensions of the layout. This is mostly cosmetic + // (affecting e.g. the order of getIn/OutDimNames), but it also affects the + // behavior of reshape. + [[nodiscard]] LinearLayout + transposeIns(ArrayRef newInDimOrder) const; + [[nodiscard]] LinearLayout + transposeOuts(ArrayRef newOutDimOrder) const; + + [[nodiscard]] LinearLayout reshapeIns( + ArrayRef> newInDims) + const; + + // Reshapes to a single input dim (named whatever our first in-dim is named). + [[nodiscard]] LinearLayout flattenIns() const { + if (getNumInDims() == 0) { + return reshapeIns({}); + } + return reshapeIns({{*getInDimNames().begin(), getTotalInDimSize()}}); + } + + [[nodiscard]] LinearLayout + reshapeOuts(ArrayRef> + newOutDims) const; + + // Reshapes to a single out dim (named whatever our first out-dim is named). + [[nodiscard]] LinearLayout flattenOuts() const { + if (getNumOutDims() == 0) { + return reshapeOuts({}); + } + return reshapeOuts({{*getOutDimNames().begin(), getTotalOutDimSize()}}); + } + + // Concatenates two layouts by their input dimensions. The layouts must have + // the same output dimensions and sizes and different input dimensions. The + // input dimensions of this layout are placed before those of 'other'. This + // can be thought of as the opposite of `sublayout`, which slices a layout + // from a larger one. + [[nodiscard]] LinearLayout concatIns(const LinearLayout &other) const; + // Concatenates two layouts by their output dimensions. The layouts must have + // the same input dimensions and sizes and different output dimensions. The + // output dimensions of this layout are placed before those of 'other'. This + // can be thought of as the opposite of `sublayout`, which slices a layout + // from a larger one. + [[nodiscard]] LinearLayout concatOuts(const LinearLayout &other) const; + + // Computes the direct sum of two layouts. + // https://en.wikipedia.org/wiki/Direct_sum#Direct_sum_of_matrices + // + // Roughly speaking, the first layout acts on the first part of the input + // dimensions, and the second layout acts on the second part. + // In other words, it's the generalisation of concatenation of the inputs + // to linear maps. + // + // Examples: + // + // - empty() is the multiplicative identity: + // + // L * empty() == empty() * L == L. + // + // - Multiplying two identity1D layouts with disjoint in/out dimensions gives + // a 2D identity layout: + // + // identity1D(4, "i1", "o1") * identity1D(8, "i2", "o2") => + // L(i1,i2) = (i1,i2), + // + // with in-dims ("i1", "i2") and out-dims ("o1", "o2"), in that order. + // + // - If out-dims overlap, they are combined, as in the following examples. + // + // - identity1D(4, "i", "o") * identity1D(2, "i", "o") == + // identity1D(8, "i", "o") + // The output matrix is [[1, 0, 0], [0, 1, 0], [0, 0, 1]] + // + // - identity1D(4, "i", "o") * zeros1D(2, "i", "o") => L(x) = x % 4 + // for x in [0,8). + // The output matrix is [[1, 0, 0], [0, 1, 0], [0, 0, 0]] + // + // - zeros1D(2, "i", "o") * identity1D(4, "i", "o") => L(x) = x / 2 + // for x in [0,8). + // The output matrix is [[0, 0, 0], [0, 1, 0], [0, 0, 1]] + + // - identity1D(4, "i", "o1") * identity1D(8, "i", "o2") => + // L(x) = (x % 4, x / 4) for x in [0,32). + // The output dims are ("o1", "o2") in that order. + // + // If the input (or output) dims of the layouts are not the same, we take + // the supremum of the two ordered lists with the inclusion, respecting the + // order. If multiple suprema exist, we bias towards the first list. + // e.g. sup([a, b], [a, c]) = [a, b, c], sup([a, b], [b, c]) = [a, b, c] + // sup([a, b], [b, a]) = error! Supremum does not exist. + // + // Notice that this operation is not commutative, but it is associative. + // + // Requires: Any in/out dimensions which are in both outer and inner appear in + // the same relative order. + // + // Postcondition: If both inner and outer are surjective, the result is + // surjective. + friend LinearLayout operator*(LinearLayout inner, LinearLayout outer); + LinearLayout &operator*=(LinearLayout outer) { + *this = *this * outer; + return *this; + } + + // Returns true if this layout acts trivially (as the identity) on the given + // dimensions. This means that it's the identity on those dimensions, and it + // does not map other dimensions onto those or these onto other dimensions. + bool isTrivialOver(ArrayRef dimNames) const; + + // For an endomorphism on dimNames (linear map that maps dimNames to dimNames) + // checks whether it is the identity map on these dimensions (i.e + // LinearLayouts::isTrivialOver) and if so, returns the sublayout of the + // remaining dimensions. + // nb. The isTrivialOver condition is more restrictive than the usual + // "leaves the subspace invariant" condition in maths. + // We can always relax it if we know how to take advantage of a conversion + // layout being block-diagonal in the future. + std::optional quotient(ArrayRef dimNames) const; + + // Gets a layout with only these in/out dimensions. + // + // In other words, gets a layout where the in-dims not mentioned in inDimNames + // are set to 0, and the out-dims not mentioned in outDimNames are omitted. + // + // The output-dim sizes are unchanged. The order of the in/out dims in the + // returned layout matches the order of the original layout, not the order of + // the arguments. + LinearLayout sublayout(ArrayRef inDimNames, + ArrayRef outDimNames) const; + + // Is the sublayout restricted to inDimNames + outDimNames all zeros? + bool sublayoutIsZero(ArrayRef inDimNames, + ArrayRef outDimNames) const; + + // Computes and returns L(x, y, z). + // + // If you want to apply the layout to mlir Values instead of integers, that + // function lives in TritonGPUToLLVM/Utility.h. + SmallVector> + apply(ArrayRef> ins) const; + + // Creates a new layout which is equivalent to running this layout, then + // running `outer`. That is, + // + // - let this layout be L(x), and + // - let `outer` be O(x). + // - Then compose(outer) returns the layout (O∘L)(x), aka O(L(x)). + // + // Requires: + // - The output dimensions of this layout equal the input dimensions of + // outer (order doesn't matter). + // - For each output dim d of this layout, this->getOutDimSize(d) <= + // outer.getInDimSize(d). + // + // Postcondition: The result is surjective iff `this` and `outer` are + // surjective and this->getOutDimSize(d) == outer.getInDimSize(d) for each of + // this->getOutDimNames(). + // + [[nodiscard]] LinearLayout compose(const LinearLayout &outer) const; + + // Inverts or pseudo-inverts `outer` and composes it with `this`. + // + // Formally, if C = A.invertAndCompose(B), then for all x, C(x) = y implies + // A(x) = B(y), or in other words A(x) = B(C(x)). If B is invertible, then + // C(x) = B^-1(A(x)), which is how this function gets its name. + // + // For example, suppose you have the following two LLs. + // + // - R is an LL representing registers, mapping (lane, warp) to a 2D index. + // - S is an LL representing shared memory, mapping offset to a 2D index. + // + // Suppose you want to store tensor values from registers into shared memory. + // That is, given a (lane, warp), you want to know the corresponding shared + // memory offset to store into. + // + // This is equivalent to converting a (lane, warp) into a 2D index (i.e. + // applying R), then converting a 2D index into a shmem offset (i.e. applying + // the inverse of S). R.invertAndCompose(S) computes this transformation. + // + // Notice the following requirements in order for this to work. + // + // - R and S must have the same output dimension names (different order is + // allowed). + // - S must be surjective, i.e. there must be some offset for each output + // dimension of S. This way when we compose S^-1 with R, every possible + // 2D index that we might get from R has some shmem offset. + // - The codomain of S must be at least as large as the codomain of R. + // Otherwise, R could map some tensor index that is not stored in S. + // + // One requirement we *don't* have is that S is injective; we allow two shmem + // offsets to hold the same 2D index. If S is not injective, + // the algorithm chooses the smallest offset for a given (lane, warp). + [[nodiscard]] LinearLayout invertAndCompose(const LinearLayout &outer) const; + + // Get the layout that is the inverse of this layout. + [[nodiscard]] LinearLayout invert() const; + // Compute and return a psueodinverse of this layout. This is a layout such + // that `B = A.psuedoinvert()` implies that `A(B(x)) = I`. If `A` is + // invertible, then this returns `A^-1`. + [[nodiscard]] LinearLayout pseudoinvert() const; + + // For each in-dim, returns a bitmask of the "free variables" in the layout + // function. + // + // These are the bits in the input that can be changed without changing the + // output. If all of the free variables are 0, then the layout is injective + // (i.e. every input bit affects the output). + llvm::MapVector getFreeVariableMasks() const; + + // Take the current linear layout and remove all zero bases for the provided + // dimension and return the resulting layout. This is useful for deriving a + // layout that returns just the unique output values when varying a given + // input dimension that has broadcasting. + [[nodiscard]] LinearLayout removeZeroBasesAlongDim(StringAttr stripDim) const; + + std::string toString() const; + + friend bool operator==(LinearLayout lhs, LinearLayout rhs); + friend bool operator!=(LinearLayout lhs, LinearLayout rhs) { + return !(lhs == rhs); + } + bool equalIgnoringOutDimSizes(const LinearLayout &other) const; + friend size_t hash_value(const LinearLayout &layout); + +private: + // Factory function that gracefully fails rather than asserts if the layout is + // not well-formed. + static std::optional + tryCreate(BasesT bases, ArrayRef> outDims, + bool requireSurjective); + + // Constructor that does not check invariants. Used by tryCreate. + struct NoCheckInvariants {}; + LinearLayout(BasesT bases, ArrayRef> outDims, + NoCheckInvariants); + + [[nodiscard]] std::optional + checkInvariants(bool requireSurjective); +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +inline std::ostream &operator<<(std::ostream &os, const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +} // namespace mlir::triton + +#endif // TRITON_TOOLS_LINEARLAYOUT_H diff --git a/third_party/enflame/include/triton/include/triton/Tools/StrUtil.h b/third_party/enflame/include/triton/include/triton/Tools/StrUtil.h new file mode 100644 index 000000000..8b59f7d2b --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Tools/StrUtil.h @@ -0,0 +1,54 @@ +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::triton { + +// Better version of llvm::join. This one works when T is an integer or any +// other type which defines operator<<(raw_ostream). +template +std::string join(C &&container, llvm::StringRef sep = ", ") { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + s << elem; + } + return ret; +} + +// Joins a container of elements into a string, using `sep` as a separator. +// +// fn is called to transform each element of the container before it's added to +// the string. fn must have one of the following two signatures. +// +// - void fn(llvm::raw_ostream&, E), where E is the element type of the +// container, or +// - T fn(E), where T is a type which can be passed to +// raw_ostream::operator<<. +// +template +std::string join(C &&container, llvm::StringRef sep, Fn &&fn) { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + + if constexpr (std::is_invocable_v) { + static_assert( + std::is_void_v< + std::invoke_result_t>); + fn(s, elem); + } else { + s << fn(elem); + } + } + return ret; +} + +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/include/triton/Tools/Sys/GetEnv.hpp b/third_party/enflame/include/triton/include/triton/Tools/Sys/GetEnv.hpp new file mode 100644 index 000000000..3ff8e5ef3 --- /dev/null +++ b/third_party/enflame/include/triton/include/triton/Tools/Sys/GetEnv.hpp @@ -0,0 +1,99 @@ +#ifndef TRITON_TOOLS_SYS_GETENV_HPP +#define TRITON_TOOLS_SYS_GETENV_HPP + +#include +#include +#include +#include +#include +#include + +namespace mlir::triton { + +inline const std::set CACHE_INVALIDATING_ENV_VARS = { + // clang-format off + "AMDGCN_ENABLE_DUMP", + "AMDGCN_USE_BUFFER_OPS", + "DISABLE_FAST_REDUCTION", + "DISABLE_LLVM_OPT", + "DISABLE_MMA_V3", + "DISABLE_MMA_V5", + "DISABLE_PTXAS_OPT", + "LLVM_IR_ENABLE_DUMP", + "LLVM_ENABLE_TIMING", + "LLVM_PASS_PLUGIN_PATH", + "MLIR_ENABLE_DIAGNOSTICS", + "MLIR_ENABLE_DUMP", + "MLIR_DUMP_PATH", + "MLIR_ENABLE_TIMING", + "TRITON_DEFAULT_FP_FUSION", + "TRITON_DISABLE_LINE_INFO", + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE", + "TRITON_ENABLE_LLVM_DEBUG", + "TRITON_HIP_GLOBAL_PREFETCH", + "TRITON_HIP_LOCAL_PREFETCH", + "TRITON_HIP_USE_BLOCK_PINGPONG", + "TRITON_LLVM_DEBUG_ONLY", + "TRITON_ENABLE_ASAN", + "TRITON_OVERRIDE_ARCH", + "USE_IR_LOC", + "NVPTX_ENABLE_DUMP", + "STORE_TMEM_TO_GLOBAL_BYPASS_SMEM", + "ALLOW_LHS_TMEM_LAYOUT_CONVERSION", + "ENABLE_LHS_TO_TMEM", + "TRITON_F32_DEFAULT", + "ENABLE_PINGPONG", + // clang-format on +}; + +inline const std::set CACHE_NEUTRAL_ENV_VARS = { + // clang-format off + "TRITON_REPRODUCER_PATH", + "TRITON_ENABLE_PYTHON_STACKTRACE" + // clang-format on +}; + +namespace tools { + +inline void assertIsRecognized(const std::string &env) { + bool is_invalidating = CACHE_INVALIDATING_ENV_VARS.find(env.c_str()) != + CACHE_INVALIDATING_ENV_VARS.end(); + bool is_neutral = + CACHE_NEUTRAL_ENV_VARS.find(env.c_str()) != CACHE_NEUTRAL_ENV_VARS.end(); + std::string errmsg = env + "is not recognized. " + "Please add it to triton/tools/sys/getenv.hpp"; + assert((is_invalidating || is_neutral) && errmsg.c_str()); +} + +inline std::string getStrEnv(const std::string &env) { + assertIsRecognized(env); + const char *cstr = std::getenv(env.c_str()); + if (!cstr) + return ""; + std::string result(cstr); + return result; +} + +// return value of a cache-invalidating boolean environment variable +inline bool getBoolEnv(const std::string &env) { + assertIsRecognized(env); + const char *s = std::getenv(env.c_str()); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return str == "on" || str == "true" || str == "1"; +} + +inline std::optional isEnvValueBool(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (str == "on" || str == "true" || str == "1") + return true; + if (str == "off" || str == "false" || str == "0") + return false; + return std::nullopt; +} +} // namespace tools +} // namespace mlir::triton + +#endif diff --git a/third_party/enflame/include/triton/lib/Analysis/Alias.cpp b/third_party/enflame/include/triton/lib/Analysis/Alias.cpp new file mode 100644 index 000000000..020f513ba --- /dev/null +++ b/third_party/enflame/include/triton/lib/Analysis/Alias.cpp @@ -0,0 +1,72 @@ +#include "triton/Analysis/Alias.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir { + +AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) { + if (lhs == rhs) + return lhs; + AliasInfo ret; + for (auto value : lhs.allocs) { + ret.insert(value); + } + for (auto value : rhs.allocs) { + ret.insert(value); + } + return ret; +} + +LogicalResult SharedMemoryAliasAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + AliasInfo aliasInfo; + bool pessimistic = true; + auto result = op->getResult(0); + // skip ops that return memdesc in a different memory space. + if (auto memdescTy = dyn_cast(result.getType())) { + if (!isa_and_nonnull( + memdescTy.getMemorySpace())) + return success(); + } + + // Only LocalAllocOp creates a new buffer. + if (isa(op)) { + aliasInfo.insert(result); + pessimistic = false; + } else if (isa( + op)) { + aliasInfo = AliasInfo(operands[0]->getValue()); + pessimistic = false; + } else { + assert(!isa(result.getType()) && + "unknown operation creating memory descriptor"); + } + + if (pessimistic) { + setAllToEntryStates(results); + return success(); + } + // Join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(aliasInfo)); + + return success(); +} + +AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) { + // TODO: implement + return AliasResult::MayAlias; +} + +ModRefResult SharedMemoryAliasAnalysis::getModRef(Operation *op, + Value location) { + // TODO: implement + return ModRefResult::getModAndRef(); +} + +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Analysis/Allocation.cpp b/third_party/enflame/include/triton/lib/Analysis/Allocation.cpp new file mode 100644 index 000000000..074d2a390 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Analysis/Allocation.cpp @@ -0,0 +1,852 @@ +#include "triton/Analysis/Allocation.h" + +#include +#include +#include + +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Alias.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "allocation-shared-memory" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { + +//===----------------------------------------------------------------------===// +// Shared Memory Allocation Analysis +//===----------------------------------------------------------------------===// +namespace triton { + +// Bitwidth of pointers +constexpr int kPtrBitWidth = 64; +// Max shmem LDS/STS instruction in bits +constexpr int kMaxShmemVecBitLength = 128; + +static SmallVector getRepShapeForCvt(RankedTensorType srcTy, + RankedTensorType dstTy) { + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (!cvtNeedsSharedMemory(srcTy, dstTy)) { + return {}; + } + + if (shouldUseDistSmem(srcLayout, dstLayout)) { + // TODO: padding to avoid bank conflicts + return convertType(gpu::getShapePerCTA(srcTy)); + } + + assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()"); + + auto srcShapePerCTA = gpu::getShapePerCTA(srcTy); + auto dstShapePerCTA = gpu::getShapePerCTA(dstTy); + auto srcShapePerCTATile = gpu::getShapePerCTATile(srcTy); + auto dstShapePerCTATile = gpu::getShapePerCTATile(dstTy); + + assert(srcTy.getRank() == dstTy.getRank() && + "src and dst must have the same rank"); + + unsigned rank = dstTy.getRank(); + SmallVector repShape(rank); + for (unsigned d = 0; d < rank; ++d) { + repShape[d] = + std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]), + std::min(dstShapePerCTA[d], dstShapePerCTATile[d])); + } + return repShape; +} + +// Both `atomic_cas` and `atomic_rmw need a single scratch element if returning +// a scalar value because Triton's block-based programming model ensures that +// all threads in each block see the same return value, even those threads that +// do not participate in the atomic operation +static SmallVector getRepShapeForAtomic(Value result) { + SmallVector smemShape; + if (atomicNeedsSharedMemory(result)) { + smemShape.push_back(1); + } + return smemShape; +} + +std::pair +getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy) { + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + auto srcLinAttr = gpu::toLinearEncoding(srcLayout, srcTy.getShape()); + auto dstLinAttr = gpu::toLinearEncoding(dstLayout, dstTy.getShape()); + auto inOrd = srcLinAttr.getOrder(); + auto outOrd = dstLinAttr.getOrder(); + + unsigned rank = srcTy.getRank(); + + unsigned srcContigPerThread = srcLinAttr.getContigPerThread()[inOrd[0]]; + unsigned dstContigPerThread = dstLinAttr.getContigPerThread()[outOrd[0]]; + // TODO: Fix the legacy issue that outOrd[0] == 0 always means + // that we cannot do vectorization. + unsigned innerDim = rank - 1; + unsigned inVec = outOrd[0] != innerDim ? 1 + : inOrd[0] != innerDim ? 1 + : srcContigPerThread; + unsigned outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread; + + if (isa(srcLayout) && + isa(dstLayout)) { + // when storing from mma layout and loading in blocked layout vectorizing + // the load back gives better performance even if there is a + // transposition. + outVec = dstContigPerThread; + } + return {inVec, outVec}; +} + +ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, + RankedTensorType dstTy) { + // Initialize vector sizes and stride + auto repShape = getRepShapeForCvt(srcTy, dstTy); + if (repShape.empty()) + return ScratchConfig({}, {}); + ScratchConfig scratchConfig(repShape, repShape); + auto rank = repShape.size(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + assert(cvtNeedsSharedMemory(srcTy, dstTy)); + auto outOrd = gpu::toLinearEncoding(dstLayout, dstTy.getShape()).getOrder(); + scratchConfig.order = outOrd; + + std::tie(scratchConfig.inVec, scratchConfig.outVec) = + getScratchCvtInOutVecLengths(srcTy, dstTy); + // We can't write a longer vector than the shape of shared memory. + // This shape might be smaller than the tensor shape in case we decided to + // do the conversion in multiple iterations. + unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]]; + scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim); + scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim); + // Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this + // is the max vectorisation + auto inBitWidth = isa(srcTy.getElementType()) + ? kPtrBitWidth + : srcTy.getElementTypeBitWidth(); + auto outBitWidth = isa(dstTy.getElementType()) + ? kPtrBitWidth + : dstTy.getElementTypeBitWidth(); + scratchConfig.inVec = + std::min(scratchConfig.inVec, kMaxShmemVecBitLength / inBitWidth); + scratchConfig.outVec = + std::min(scratchConfig.outVec, kMaxShmemVecBitLength / outBitWidth); + + // No padding is required if the tensor is 1-D, or if all dimensions except + // the first accessed dimension have a size of 1. + if (rank <= 1 || product(repShape) == repShape[outOrd[0]]) + return scratchConfig; + + auto paddedSize = std::max(scratchConfig.inVec, scratchConfig.outVec); + scratchConfig.paddedRepShape[outOrd[0]] += paddedSize; + return scratchConfig; +} + +unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) { + if (auto reduceOp = dyn_cast(op)) { + ReduceOpHelper helper(reduceOp); + return helper.getScratchSizeInBytes(); + } + if (auto scanOp = dyn_cast(op)) { + ScanLoweringHelper helper(scanOp); + return helper.getScratchSizeInBytes(); + } + if (auto gatherOp = dyn_cast(op)) { + GatherLoweringHelper helper(gatherOp); + return helper.getScratchSizeInBytes(); + } + if (auto histogram = dyn_cast(op)) { + auto dstTy = histogram.getType(); + int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp( + op->getParentOfType()); + return std::max(dstTy.getNumElements(), threadsPerWarp) * + std::max(8, dstTy.getElementTypeBitWidth()) / 8; + } + if (auto cvtLayout = dyn_cast(op)) { + auto srcTy = cvtLayout.getSrc().getType(); + auto dstTy = cvtLayout.getType(); + auto srcEncoding = srcTy.getEncoding(); + auto dstEncoding = dstTy.getEncoding(); + if (mlir::isa(srcEncoding) || + mlir::isa(dstEncoding)) { + // Conversions from/to shared memory do not need scratch memory. + return 0; + } + // ConvertLayoutOp with both input/output non-shared_layout + // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's + // also possible to realize it with other approaches in restricted + // conditions, such as warp-shuffle + auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); + auto elems = getNumScratchElements(scratchConfig.paddedRepShape); + return isa(srcTy.getElementType()) + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; + } + if (isa(op)) { + auto value = op->getOperand(0); + // only scalar requires scratch memory + // make it explicit for readability + if (dyn_cast(value.getType())) { + return 0; + } + auto smemShape = getRepShapeForAtomic(op->getResult(0)); + auto elems = getNumScratchElements(smemShape); + auto elemTy = cast(value.getType()).getPointeeType(); + assert(!isa(elemTy) && "unexpected pointer type"); + return elems * std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; + } + if (isa(op)) { + constexpr int32_t kTMASize = 128; + return kTMASize; + } + return 0; +} + +class AllocationAnalysis { +public: + AllocationAnalysis(Operation *operation, + Allocation::FuncAllocMapT *funcAllocMap, + Allocation *allocation, + AllocationAnalysisScratchSizeFn scratchSizeGetter) + : operation(operation), funcAllocMap(funcAllocMap), + allocation(allocation), scratchSizeGetter(scratchSizeGetter) { + run(); + } + +private: + using BufferT = Allocation::BufferT; + + /// Value -> Liveness Range + /// Use MapVector to ensure determinism. + using BufferRangeMapT = llvm::MapVector>; + /// Nodes -> Nodes + using GraphT = DenseMap>; + + void run() { + getValuesAndSizes(); + resolveLiveness(); + computeOffsets(); + } + + /// Initializes explicitly defined shared memory values for a given operation. + void getExplicitValueSize(Operation *op) { + auto alloc = dyn_cast(op); + if (!alloc || !alloc.isSharedMemoryAlloc()) + return; + // Bytes could be a different value once we support padding or other + // allocation policies. + auto allocType = alloc.getType(); + auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType); + auto bytes = + product(shapePerCTA) * allocType.getElementTypeBitWidth() / 8; + + auto alignment = alloc.getAlignmentOrDefault(); + LLVM_DEBUG({ + llvm::dbgs() << "check localAlloc in getExplicitValueSize: "; + alloc.dump(); + }); + int sharingGroup = -1; + if (alloc->hasAttr("allocation.shareGroup")) { + sharingGroup = + mlir::cast(alloc->getAttr("allocation.shareGroup")) + .getInt(); + LDBG("with shareGroup of " << sharingGroup); + } + allocation->addBuffer( + alloc, bytes, alignment, 0, sharingGroup); + } + + template + void maybeAddScratchBuffer(Operation *op, unsigned bytes, + unsigned alignment) { + if (bytes > 0) + allocation->addBuffer(op, bytes, alignment); + } + + template + void maybeAddScratchBuffer(Operation *op, unsigned bytes) { + if (bytes > 0) + allocation->addBuffer(op, bytes); + } + + /// Initializes temporary shared memory for a given operation. + void getScratchValueSize(Operation *op) { + constexpr size_t scratchAlignment = 128; + if (auto callOp = dyn_cast(op)) { + auto callable = callOp.resolveCallable(); + auto funcOp = dyn_cast(callable); + auto *funcAlloc = &(*funcAllocMap)[funcOp]; + auto bytes = funcAlloc->getSharedMemorySize(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + return; + } + if (auto ws = dyn_cast(op)) { + // `ttg.warp_specialize` needs memory to pass its explicit captures. Pack + // the captures like a struct. + auto [captureSize, captureAlign] = ws.getCaptureSizeAlign(); + maybeAddScratchBuffer(op, captureSize, + captureAlign); + return; + } + if (auto func = dyn_cast(op)) { + unsigned numWarpIndices = 0; + // Warp specialization communicates states over shared memory to each + // warp. Add space for an i8 for each warpgroup warp. + func.walk([&](gpu::WarpSpecializeOp op) { + numWarpIndices = std::max(numWarpIndices, op.getTotalPartitionWarps()); + }); + maybeAddScratchBuffer(op, numWarpIndices); + return; + } + unsigned bytes = scratchSizeGetter(op); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } + + void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) { + dataflow::Lattice *latticeElement = + analysis.getLatticeElement(value); + if (latticeElement) { + AliasInfo &info = latticeElement->getValue(); + if (!info.getAllocs().empty()) { + for (auto alloc : info.getAllocs()) { + allocation->addAlias(value, alloc); + } + } + } + } + + /// Extract all shared memory values and their sizes + void getValuesAndSizes() { + // Get the alloc values + operation->walk([&](Operation *op) { + getExplicitValueSize(op); + getScratchValueSize(op); + }); + LDBG("getValuesAndSizes --"); + for (auto valueBufferIter : allocation->valueBuffer) { + auto *buffer = valueBufferIter.second; + LLVM_DEBUG(llvm::dbgs() + << "-- buffer " << buffer->id << " " << buffer->size << " " + << buffer->offset << " " << buffer->sharingGroup << "\n"); + } + // Get the alias values + std::unique_ptr solver = createDataFlowSolver(); + SharedMemoryAliasAnalysis *aliasAnalysis = + solver->load(); + // Run the analysis rooted at every isolated from above operation, including + // the top-level function but also any nested regions. + operation->walk([&](Operation *op) { + if (op->hasTrait() && + failed(solver->initializeAndRun(op))) { + // TODO: return error instead of bailing out.. + llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); + } + }); + operation->walk([&](Operation *op) { + for (auto operand : op->getOperands()) { + getValueAlias(operand, *aliasAnalysis); + } + for (auto value : op->getResults()) { + getValueAlias(value, *aliasAnalysis); + } + }); + } + + /// Computes the liveness range of the allocated value. + /// Each buffer is allocated only once. + void resolveExplicitBufferLiveness( + function_ref(Value value, BufferT *buffer)> + getLiveness) { + for (auto valueBufferIter : allocation->valueBuffer) { + auto value = valueBufferIter.first; + auto *buffer = valueBufferIter.second; + bufferRange[buffer] = getLiveness(value, buffer); + LLVM_DEBUG({ + llvm::dbgs() << "-- buffer " << buffer->id << "; value: "; + value.dump(); + }); + } + } + + /// Extends the liveness range by unionizing the liveness range of the aliased + /// values because each allocated buffer could be an alias of others, if block + /// arguments are involved. + void resolveAliasBufferLiveness( + function_ref(Value value, BufferT *buffer)> + getLiveness) { + for (const auto &[value, buffers] : allocation->aliasBuffer) { + auto range = getLiveness(value, buffers.front()); + for (auto *buffer : buffers) { + auto minId = range.start(); + auto maxId = range.end(); + if (bufferRange.count(buffer)) { + // Extend the allocated buffer's range + minId = std::min(minId, bufferRange[buffer].start()); + maxId = std::max(maxId, bufferRange[buffer].end()); + } + bufferRange[buffer] = Interval(minId, maxId); + } + } + } + + /// Computes the liveness range of scratched buffers. + /// Some operations may have a temporary buffer that is not explicitly + /// allocated, but is used to store intermediate results. + void resolveScratchBufferLiveness( + const DenseMap &operationId) { + // Analyze liveness of scratch buffers and virtual buffers. + auto processScratchMemory = [&](const auto &container) { + for (auto [op, buffer] : container) { + // Buffers owned by the function are assumed live for the whole + // function. This memory is used for warp specialization codegen. + // FIXME: Spooky-action-at-a-distance. Find a better way to model this. + if (op == operation) { + bufferRange.insert( + {buffer, Interval(size_t(), std::numeric_limits::max())}); + continue; + } + + // Any scratch memory's live range is the current operation's live + // range. + // Extend live range when asyncTaskId is not empty (i.e when we have + // warp spec). + if (getAsyncTaskIds(op).empty()) { + bufferRange.insert( + {buffer, Interval(operationId.at(op), operationId.at(op) + 1)}); + } else { + for (auto tId : getAsyncTaskIds(op)) + buffer->regionIds.insert(tId); + // For warp-specialized code, we can assume each region has its own + // copy of a scratch buffer, i.e each region is for a single taskId. + // In that case, we don't need to extend the liveness of scratch + // buffers. + bufferRange.insert({buffer, Interval(operationId.lookup(op), + operationId.lookup(op) + 1)}); + } + LLVM_DEBUG({ + llvm::dbgs() << "-- buffer " << buffer->id << "; value: "; + op->dump(); + }); + } + }; + processScratchMemory(allocation->opScratch); + processScratchMemory(allocation->opVirtual); + } + + /// Resolves liveness of all values involved under the root operation. + void resolveLiveness() { + // Assign an ID to each operation using post-order traversal. + // To achieve the correct liveness range, the parent operation's ID + // should be greater than each of its child operation's ID . + // Example: + // ... + // %5 = triton.convert_layout %4 + // %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) { + // %2 = triton.convert_layout %5 + // ... + // scf.yield %arg0 + // } + // For example, %5 is defined in the parent region and used in + // the child region, and is not passed as a block argument. + // %6 should should have an ID greater than its child operations, + // otherwise %5 liveness range ends before the child operation's liveness + // range ends. + DenseMap operationId; + operation->walk( + [&](Operation *op) { operationId[op] = operationId.size(); }); + + // Analyze liveness of explicit buffers + Liveness liveness(operation); + auto getValueLivenessRange = [&](Value value, BufferT *buffer) { + auto liveOperations = liveness.resolveLiveness(value); + // Update regions for buffer. + std::for_each(liveOperations.begin(), liveOperations.end(), + [&](Operation *liveOp) { + for (auto rId : getAsyncTaskIds(liveOp)) { + buffer->regionIds.insert(rId); + } + }); + auto minId = std::numeric_limits::max(); + auto maxId = std::numeric_limits::min(); + std::for_each( + liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) { + if (buffer->regionIds.size() > 1 || buffer->sharingGroup >= 0) { + // For a buffer that is associated with warp + // specialization, due to producer-consumer channel, it + // should have at least two regions, and it will be live + // throughout. For a buffer that is local to a consumer: + // we need to make sure not to overlap with local + // buffers from another consumer. This will be handled + // when building the interference graph. + minId = 0; + maxId = operationId.size(); + return; + } + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if ((operationId[liveOp] + 1) > maxId) { + maxId = operationId[liveOp] + 1; + } + }); + return Interval(minId, maxId); + }; + + resolveExplicitBufferLiveness(getValueLivenessRange); + resolveAliasBufferLiveness(getValueLivenessRange); + resolveScratchBufferLiveness(operationId); + } + + void dumpBuffers() { + LDBG("Dump bufferRange: id size offset sharingGroup ---------"); + for (auto bufferIter : bufferRange) { + LLVM_DEBUG({ + llvm::dbgs() << "-- " << bufferIter.first->id << " " + << bufferIter.first->size << " " + << bufferIter.first->offset << " " + << bufferIter.first->sharingGroup << " regions ["; + for (auto tId : bufferIter.first->regionIds) { + llvm::dbgs() << tId << " "; + } + llvm::dbgs() << "] interval " << bufferIter.second.start() << " " + << bufferIter.second.end() << "\n"; + }); + } + } + + void dumpAllocationSize() const { + LDBG("Dump shared memory allocation size -----------"); + auto liveBuffers = allocation->getLiveBuffers(); + auto analyzedSize = 0; + for (auto [op, bufferIds] : liveBuffers) { + auto size = 0; + for (auto bufferId : bufferIds) { + auto bufferSize = allocation->getAllocatedSize(bufferId); + size += bufferSize; + } + analyzedSize = std::max(analyzedSize, size); + } + llvm::dbgs() << "Allocated: " << allocation->sharedMemorySize + << ", analyzed: " << analyzedSize << "\n"; + } + + void dumpInterferenceGraph(const GraphT &interference) const { + LDBG("\n"); + LDBG("Dump interference graph: \n"); + for (auto edges : interference) { + llvm::dbgs() << "-- from " << edges.first->id << " to "; + for (auto node : edges.second) { + llvm::dbgs() << node->id << "; "; + } + llvm::dbgs() << "\n"; + } + } + + /// Computes the shared memory offsets for all related values. + /// Paper: Algorithms for Compile-Time Memory Optimization + /// (https://dl.acm.org/doi/pdf/10.5555/314500.315082) + void computeOffsets() { + SmallVector buffers; + // Handle sharingGroup here. For allocations with the same sharingGroup + // get the union of the live range, and union of the regionIds. Put + // the + // largest buffer in buffers. + DenseMap> toGroup; + for (auto bufferIter : bufferRange) { + if (bufferIter.first->sharingGroup >= 0) + toGroup[bufferIter.first->sharingGroup].push_back(bufferIter.first); + } + DenseMap sharingIdToRep; + for (auto &kv : toGroup) { + size_t bigSize = 0; + BufferT *rep = nullptr; + for (auto *buf : kv.second) { + if (buf->size > bigSize) { + rep = buf; + bigSize = buf->size; + } + } + // FIXME: update live range and regionIds. + sharingIdToRep[kv.first] = rep; + } + for (auto bufferIter : bufferRange) { + if (sharingIdToRep.find(bufferIter.first->sharingGroup) != + sharingIdToRep.end()) { + if (bufferIter.first != + sharingIdToRep[bufferIter.first->sharingGroup]) { + LDBG("-- ignore shared buffer " << bufferIter.first->size << " " + << bufferIter.first->offset << " " + << bufferIter.first->sharingGroup); + continue; + } + } + buffers.emplace_back(bufferIter.first); + } + + // Sort buffers by size in descending order to reduce the fragmentation + // on big buffers caused by smaller buffers. Big buffers have a higher + // chance to overlap with multiple other buffers, and allocating them first + // (by calculateStarts) ensures a higher chance that they will occupy a + // standalone smem slot. + llvm::stable_sort( + buffers, [&](BufferT *A, BufferT *B) { return A->size > B->size; }); + + calculateStarts(buffers); + dumpBuffers(); + + // NOTE: The original paper doesn't consider interference between + // the bumped ranges. Buffers that previously do not interfere with + // could interfere after offset bumping if their liveness ranges overlap. + // Therefore, we rerun the interference graph algorithm after bumping so + // that we regroup the buffers and color them again. Since we always + // increase the buffer offset and keep reducing conflicts, we will + // eventually reach a fixed point. + GraphT interference; + buildInterferenceGraph(buffers, interference); + do { + allocate(buffers, interference); + buildInterferenceGraph(buffers, interference); + } while (!interference.empty()); + + LLVM_DEBUG(dumpAllocationSize()); + // Update allocation for sharingGroup. + for (auto &kv : toGroup) { + auto *rep = sharingIdToRep[kv.first]; + for (auto *buf : kv.second) { + if (buf != rep) { + buf->setOffsetAligned(rep->offset); + LDBG("-- set sharing buffer's offset " + << buf->size << " " << buf->offset << " " << buf->sharingGroup); + } + } + } + dumpBuffers(); + } + + /// Computes the initial shared memory offsets. + void calculateStarts(const SmallVector &buffers) { + // v = values in shared memory + // t = triplet of (size, start, end) + // shared memory space + // - + // | *******t4 + // | /|\ v2 inserts t4, t5, and t6 + // | | + // | ******t5 ************t6 + // | ^^^^^v2^^^^^^ + // | | *********************t2 + // | \|/ v2 erases t1 + // | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3 + // |---------------------------------------------| liveness range + // 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ... + // If the available triple's range is less than a given buffer range, + // we won't know if there has been an overlap without using graph coloring. + // Start -> Liveness Range + using TripleMapT = std::multimap>; + TripleMapT tripleMap; + tripleMap.insert(std::make_pair(0, Interval())); + SmallVector xBuffers = buffers; + while (!xBuffers.empty()) { + auto tripleIt = tripleMap.begin(); + auto offset = tripleIt->first; + auto range = tripleIt->second; + tripleMap.erase(tripleIt); + auto bufferIt = + std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) { + auto xRange = bufferRange[buffer]; + bool res = xRange.intersects(range); + for (const auto &val : tripleMap) + res = res && + !val.second.intersects(xRange); // only one buffer intersect + return res; + }); + if (bufferIt != xBuffers.end()) { + auto buffer = *bufferIt; + auto xSize = buffer->size; + auto xRange = bufferRange.lookup(buffer); + // TODO(Keren): A buffer's size shouldn't be determined here, have to + // clean it up + size_t alignOffset = buffer->setOffsetAligned(offset); + tripleMap.insert({alignOffset + xSize, + Interval{std::max(range.start(), xRange.start()), + std::min(range.end(), xRange.end())}}); + // We could either insert (range.start, xRange.start) or (range.start, + // xRange.end), both are correct and determine the potential buffer + // offset, and the graph coloring algorithm will solve the interference, + // if any + if (range.start() < xRange.start()) + tripleMap.insert({offset, Interval{range.start(), xRange.end()}}); + if (xRange.end() < range.end()) + tripleMap.insert({offset, Interval{xRange.start(), range.end()}}); + xBuffers.erase(bufferIt); + } + } + LLVM_DEBUG(dumpBuffers()); + } + + /// Builds a graph of all shared memory values. Edges are created between + /// shared memory values that are overlapping. + void buildInterferenceGraph(const SmallVector &buffers, + GraphT &interference) { + // Reset interference graph + auto inDifferentRegion = [&](BufferT *A, BufferT *B) { + auto tA = A->regionIds; + auto tB = B->regionIds; + if (tA.empty() && tB.empty()) + return false; + if (tA.empty() || tB.empty()) + return true; + for (auto t1 : tA) { + for (auto t2 : tB) { + if (t1 != t2) + return true; + } + } + return false; + }; + interference.clear(); + for (auto x : buffers) { + for (auto y : buffers) { + if (x == y) + continue; + auto xStart = x->offset; + auto yStart = y->offset; + auto xSize = x->size; + auto ySize = y->size; + Interval xSizeRange = {xStart, xStart + xSize}; + Interval ySizeRange = {yStart, yStart + ySize}; + auto xOpRange = bufferRange.lookup(x); + auto yOpRange = bufferRange.lookup(y); + + // Buffers interfere if their allocation offsets overlap and they are + // live at the same time. + if (xOpRange.intersects(yOpRange) && + xSizeRange.intersects(ySizeRange)) { + interference[x].insert(y); + } + + // Buffers also interfere if their allocation offsets overlap and they + // exist within regions that may execute simultaneously with respect to + // each other. + auto wsx = x->owner->getParentWithTrait(); + auto wsy = y->owner->getParentWithTrait(); + if (wsx && wsy && wsx == wsy && + x->owner->getParentRegion() != y->owner->getParentRegion() && + xSizeRange.intersects(ySizeRange)) { + interference[x].insert(y); + } + // if x and y belong to different regions (ignore producer region). + if (inDifferentRegion(x, y) && xSizeRange.intersects(ySizeRange)) + interference[x].insert(y); + } + } + + LLVM_DEBUG(dumpInterferenceGraph(interference)); + } + + /// Finalizes shared memory offsets considering interference. + void allocate(const SmallVector &buffers, + const GraphT &interference) { + // Reset shared memory size + allocation->sharedMemorySize = 0; + // First-fit graph coloring + // Neighbors are nodes that interfere with each other. + // We color a node by finding the index of the first available + // non-neighboring node or the first neighboring node without any color. + // Nodes with the same color do not interfere with each other. + DenseMap colors; + for (auto value : buffers) { + colors[value] = (value == buffers[0]) ? 0 : -1; + } + SmallVector available(buffers.size()); + for (auto x : buffers) { + std::fill(available.begin(), available.end(), true); + for (auto y : interference.lookup(x)) { + int color = colors[y]; + if (color >= 0) { + available[color] = false; + } + } + auto it = std::find(available.begin(), available.end(), true); + colors[x] = std::distance(available.begin(), it); + LLVM_DEBUG({ + llvm::dbgs() << "-- color " << x->id << " " << colors[x] << "\n"; + }); + } + // Finalize allocation + // color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15) + // color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24) + // color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42) + // TODO(Keren): We are wasting memory here. + // Nodes with color2 can actually start with 24. + for (auto x : buffers) { + size_t newOffset = 0; + for (auto y : interference.lookup(x)) { + newOffset = std::max(newOffset, y->offset + y->size); + } + if (colors.lookup(x) != 0) + x->setOffsetAligned(newOffset); + allocation->sharedMemorySize = + std::max(allocation->sharedMemorySize, x->offset + x->size); + } + LLVM_DEBUG(dumpBuffers()); + } + +private: + Operation *operation; + Allocation::FuncAllocMapT *funcAllocMap; + Allocation *allocation; + BufferRangeMapT bufferRange; + AllocationAnalysisScratchSizeFn scratchSizeGetter; +}; + +} // namespace triton + +void Allocation::run( + FuncAllocMapT &funcAllocMap, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) { + triton::AllocationAnalysis(getOperation(), &funcAllocMap, this, + scratchSizeGetter); +} + +std::map> +Allocation::getLiveBuffers() { + std::map> liveBuffers; + + Operation *rootOperation = getOperation(); + Liveness liveness(rootOperation); + auto analyzeOperation = [&](Operation *op) -> void { + auto scratchBuffer = getBufferId(op); + if (scratchBuffer != InvalidBufferId) + liveBuffers[op].push_back(scratchBuffer); + for (auto result : op->getOpResults()) { + auto bufferId = getBufferId(result); + if (bufferId == Allocation::InvalidBufferId) + continue; + auto liveOperations = liveness.resolveLiveness(result); + for (auto depOp : liveOperations) + liveBuffers[depOp].push_back(bufferId); + } + }; + rootOperation->walk(analyzeOperation); + return liveBuffers; +} + +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Analysis/AxisInfo.cpp b/third_party/enflame/include/triton/lib/Analysis/AxisInfo.cpp new file mode 100644 index 000000000..aad43186e --- /dev/null +++ b/third_party/enflame/include/triton/lib/Analysis/AxisInfo.cpp @@ -0,0 +1,1377 @@ +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define DEBUG_TYPE "axis-info" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir::triton { +namespace { + +int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) { + // Base Case + if (a == 0) { + *x = 0; + *y = 1; + return b; + } + int64_t x1, y1; // To store results of recursive call + int64_t gcd = gcdImpl(b % a, a, &x1, &y1); + // Update x and y using results of + // recursive call + *x = y1 - (b / a) * x1; + *y = x1; + return gcd; +} + +int64_t gcd(int64_t a, int64_t b) { + if (a == 0) + return b; + if (b == 0) + return a; + int64_t x, y; + return gcdImpl(a, b, &x, &y); +} + +constexpr int log2Int(int64_t num) { + return (num > 1) ? 1 + log2Int(num / 2) : 0; +} + +// If lhs * rhs overflows, return max value possible value for the type +int64_t multiplyDivisor(int64_t lhs, int64_t rhs) { + int64_t maxDivisor = highestPowOf2Divisor(0); + if (lhs > maxDivisor / rhs) + return maxDivisor; + return lhs * rhs; +} + +class AxisInfoVisitor { +public: + AxisInfoVisitor() = default; + virtual ~AxisInfoVisitor() = default; + + static bool isContiguousDim(const AxisInfo &info, ArrayRef shape, + int dim) { + return info.getContiguity(dim) == shape[dim]; + } + + static bool isConstantDim(const AxisInfo &info, ArrayRef shape, + int dim) { + return info.getConstancy(dim) == shape[dim]; + } + + virtual AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) = 0; + + virtual bool match(Operation *op) = 0; +}; + +// Base class for all operations +template class AxisInfoVisitorImpl : public AxisInfoVisitor { +public: + using AxisInfoVisitor::AxisInfoVisitor; + + AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) final { + return getAxisInfo(cast(op), operands); + } + + bool match(Operation *op) final { return isa(op); } + + virtual AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) = 0; +}; + +// Binary operations +template +class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + assert(operands.size() == 2 && "Expected two operands"); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + auto constantValue = getConstantValue(op, lhsInfo, rhsInfo); + for (auto d = 0; d < rank; ++d) { + if (constantValue.has_value()) { + contiguity.push_back(1); + constancy.push_back( + std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + highestPowOf2Divisor(constantValue.value())); + } else { + contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d)); + constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d)); + divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d)); + } + } + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +protected: + virtual int64_t getContiguity(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getDivisibility(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) { + return {}; + } +}; + +class AxisInfoVisitorList { +public: + template > + void append() { + (visitors.emplace_back(std::make_unique()), ...); + } + + AxisInfo apply(Operation *op, + ArrayRef *> operands) { + for (auto &visitor : visitors) + if (visitor->match(op)) + return visitor->getAxisInfo(op, operands); + return AxisInfo(); + } + +private: + std::vector> visitors; +}; + +class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +private: + AxisInfoVisitorList visitors; + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, lattice->join( + AxisInfo::getPessimisticValueState(lattice->getAnchor()))); + } + + void visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, + ArrayRef *> argLattices, + unsigned firstIndex) override { + if (auto forOp = dyn_cast(op)) { + visitForOpInductionVar(forOp, argLattices); + } else { + setAllToEntryStates(argLattices.take_front(firstIndex)); + setAllToEntryStates(argLattices.drop_front( + firstIndex + successor.getSuccessorInputs().size())); + } + } + +public: + AxisInfoAnalysis(DataFlowSolver &solver); + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + using FuncAxisInfoMapT = DenseMap; + + LogicalResult + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; + void + visitForOpInductionVar(scf::ForOp op, + ArrayRef *> argLattices); +}; + +template +class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + return operands[0]->getValue(); + } +}; + +class MakeRangeOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::MakeRangeOp op, + ArrayRef *> operands) override { + auto start = op.getStart(); + auto end = op.getEnd(); + return AxisInfo(/*contiguity=*/{end - start}, + /*divisibility=*/{highestPowOf2Divisor(start)}, + /*constancy=*/{1}); + } +}; + +class ConstantOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(arith::ConstantOp op, + ArrayRef *> operands) override { + auto intAttr = dyn_cast(op.getValue()); + auto boolAttr = dyn_cast(op.getValue()); + if (intAttr || boolAttr) { + int64_t value{}; + if (intAttr) + value = intAttr.getValue().getZExtValue(); + else + value = boolAttr.getValue() ? 1 : 0; + return AxisInfo(/*contiguity=*/{1}, + /*divisibility=*/{highestPowOf2Divisor(value)}, + /*constancy=*/{1}, + /*knownConstantValue=*/{value}); + } + // TODO: generalize to dense attr + auto splatAttr = dyn_cast(op.getValue()); + if (splatAttr && splatAttr.getElementType().isIntOrIndex()) { + int64_t value = splatAttr.template getSplatValue().getZExtValue(); + TensorType ty = cast(splatAttr.getType()); + return AxisInfo( + /*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1), + /*divisibility=*/ + AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)), + /*constancy=*/ + AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()), + /*knownConstantValue=*/{value}); + } + return AxisInfo(); + } +}; + +class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(ub::PoisonOp op, + ArrayRef *> operands) override { + constexpr int64_t largePowerOf2 = int64_t(1) << 32; + // Poison values are never accessed, thus assume optimistic values. + if (auto shape = dyn_cast(op.getType())) { + unsigned rank = shape.getRank(); + return AxisInfo( + /*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2), + /*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2), + /*constancy=*/AxisInfo::DimVectorT(shape.getShape())); + } + + return AxisInfo(/*contiguity=*/{1}, /*divisibility=*/{largePowerOf2}, + /*constancy=*/{1}); + } +}; + +template +class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // Contiguity assumes an increasing sequence. So for SubIOp contiguous + // RHS doesn't produce a contiguous result. + if (isa(op)) + return gcd(lhs.getContiguity(dim), rhs.getConstancy(dim)); + + return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)), + gcd(lhs.getContiguity(dim), rhs.getConstancy(dim))); + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs) + // rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs) + // lhs + rhs = k * d_lhs + p * d_rhs = (k * k' + p * p') * gcd(d_lhs, d_rhs) + auto rhsDivisibility = rhs.getDivisibility(dim); + if constexpr (std::is_same_v) { + // %ptr = addptr %lhs, %rhs + // is equivalent to + // %0 = mul %rhs, %elemSize + // %ptr = add %lhs, %0 + // The result will still be contiguous in terms of elements but not bytes + // For example: + // addptr [16] : !ptr, [0, 1, 2, 3] : i32 -> !ptr + // returns: + // [16, 20, 24, 28] : !ptr + // with element locations: + // [4, 5, 6, 7] + // It is "strided contiguous" with a divisilibity of 16 bytes + auto rank = lhs.getRank(); + auto elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + rhsDivisibility = multiplyDivisor(rhs.getDivisibility(dim), elemSize); + } + return gcd(lhs.getDivisibility(dim), rhsDivisibility); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() + + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() - + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + auto rank = lhs.getRank(); + auto elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + auto rhsValue = rhs.getConstantValue().value() * elemSize; + return {lhs.getConstantValue().value() + rhsValue}; + } + } + return {}; + } +}; + +class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + // lhs * 1 = lhs + auto lhsContiguity = + rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1 + ? lhs.getContiguity(dim) + : 1; + // 1 * rhs = rhs + auto rhsContiguity = + lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1 + ? rhs.getContiguity(dim) + : 1; + return std::max(lhsContiguity, rhsContiguity); + } + + int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && + !(rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto rhsDivisibility = rhs.getDivisibility(dim); + if (rhs.getContiguity(dim) > 1 && + !(lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + rhsDivisibility = 1; + } + return multiplyDivisor(lhsDivisibility, rhsDivisibility); + } + + std::optional getConstantValue(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() * rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs / 1 = lhs + return rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1 + ? lhs.getContiguity(dim) + : 1; + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + // Case 1: both lhs and rhs are constants. + auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + // Case 2: lhs contiguous, rhs constant. + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p), + // ..., (d_lhs * k + n) / (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // the minimal constancy is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual constancy. + if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + constancy = std::max(constancy, gcd(lhs.getContiguity(dim), + gcd(lhs.getDivisibility(dim), + rhs.getDivisibility(dim)))); + } + return constancy; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // Case 1: lhs is 0 + if (lhs.getConstantValue().has_value() && + lhs.getConstantValue().value() == 0) + return lhs.getDivisibility(dim); + // Case 2: rhs is 1 + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return lhs.getDivisibility(dim); + // otherwise: return 1 + return 1; + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() / rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getContiguity(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + int64_t contiguity = 1; + // lhs contiguous, rhs constant + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs % rhs = d_lhs * k % (d_rhs * p), (d_lhs * k + 1) % (d_rhs * p), + // ..., (d_lhs * k + n) % (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // The minimal contiguity is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual contiguity. + if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + contiguity = std::max(contiguity, gcd(lhs.getContiguity(dim), + gcd(lhs.getDivisibility(dim), + rhs.getDivisibility(dim)))); + } + return contiguity; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k'' + // rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p'' + // lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r + // r must be divisible by gcd(d_lhs, d_rhs) + return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim)); + }; + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + // lhs % 1 = 0 + return rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1 + ? shape[dim] + : gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() % rhs.getConstantValue().value()}; + else if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return {0}; + return {}; + } +}; + +class SplatOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::SplatOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + TensorType retTy = cast(_retTy); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(opInfo.getDivisibility(0)); + constancy.push_back(retTy.getShape()[d]); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::LoadOp op, + ArrayRef *> operands) override { + // If pointers and mask both have constancy properties, those properties + // will also extend to output. + AxisInfo ptrInfo = operands[0]->getValue(); + std::optional maskInfo; + if (operands.size() > 1) { + maskInfo = operands[1]->getValue(); + } + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + + for (int d = 0; d < ptrInfo.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(1); + constancy.push_back( + gcd(ptrInfo.getConstancy(d), + maskInfo.has_value() ? maskInfo->getConstancy(d) : 0)); + } + + return AxisInfo(contiguity, divisibility, constancy); + } +}; + +class ExpandDimsOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::ExpandDimsOp op, + ArrayRef *> operands) override { + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity = opInfo.getContiguity(); + AxisInfo::DimVectorT divisibility = opInfo.getDivisibility(); + AxisInfo::DimVectorT constancy = opInfo.getConstancy(); + int64_t newDivisibility = 1; + if (opInfo.getConstantValue().has_value()) { + // The tensor is constant, same as ConstantOpAxisInfoVisitor + newDivisibility = highestPowOf2Divisor(opInfo.getConstantValue().value()); + } else if (opInfo.getRank()) { + // Otherwise, calculate the GCD as the new divisibility + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + newDivisibility = + opInfo.getContiguity(0) > 1 ? 1 : opInfo.getDivisibility(0); + for (int d = 1; d < opInfo.getRank(); ++d) { + newDivisibility = + gcd(newDivisibility, + opInfo.getContiguity(d) > 1 ? 1 : opInfo.getDivisibility(d)); + } + } + contiguity.insert(contiguity.begin() + op.getAxis(), 1); + divisibility.insert(divisibility.begin() + op.getAxis(), newDivisibility); + constancy.insert(constancy.begin() + op.getAxis(), 1); + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class BroadcastOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::BroadcastOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + Type _opTy = *op->operand_type_begin(); + TensorType retTy = cast(_retTy); + TensorType opTy = cast(_opTy); + ArrayRef retShape = retTy.getShape(); + ArrayRef opShape = opTy.getShape(); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d)); + divisibility.push_back(opInfo.getDivisibility(d)); + constancy.push_back(opShape[d] == 1 ? retShape[d] + : opInfo.getConstancy(d)); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +template +class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return AxisInfo(); + auto shape = resTy.getShape(); + short rank = resTy.getRank(); + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + for (short d = 0; d < rank; ++d) { + int64_t constHint = 1; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + constHint = lhsInfo.getConstancy(d); + constantValue = + compare(getPredicate(op), lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value()) + ? 1 + : 0; + } else { + // Case 1: lhs and rhs are both partial constants + constHint = gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)); + if ((gtPredicate(getPredicate(op)) || lePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(lhsInfo, shape, d)) { + // Case 2: lhs all constant, rhs all contiguous + // NOTE: + // lhs: 4 4 4 4 + // rhs: 4 5 6 7 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs lt rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 1, 1, 1 + // lhs ge rhs: 1, 0, 0, 0 + // lhs gt rhs: 0, 0, 0, 0 + constHint = std::max(constHint, gcd(rhsInfo.getContiguity(d), + gcd(lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d)))); + } else if ((ltPredicate(getPredicate(op)) || + gePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(rhsInfo, shape, d)) { + // Case 3: lhs all contiguous, rhs all constant + // NOTE + // lhs: 4 5 6 7 + // rhs: 4 4 4 4 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 0, 0, 0 + // lhs lt rhs: 0, 0, 0, 0 + // lhs gt rhs: 0, 1, 1, 1 + // lhs ge rhs: 1, 1, 1, 1 + constHint = std::max(constHint, gcd(lhsInfo.getContiguity(d), + gcd(lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d)))); + } + } + + constancy.push_back(constHint); + divisibility.push_back(1); + contiguity.push_back(1); + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +private: + static arith::CmpIPredicate getPredicate(arith::CmpIOp op) { + return op.getPredicate(); + } + + static bool gtPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sgt || + predicate == arith::CmpIPredicate::ugt; + } + + static bool gePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sge || + predicate == arith::CmpIPredicate::uge; + } + + static bool ltPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::slt || + predicate == arith::CmpIPredicate::ult; + } + + static bool lePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sle || + predicate == arith::CmpIPredicate::ule; + } + + static bool compare(arith::CmpIPredicate predicate, int64_t lhs, + int64_t rhs) { + switch (predicate) { + case arith::CmpIPredicate::eq: + return lhs == rhs; + case arith::CmpIPredicate::ne: + return lhs != rhs; + case arith::CmpIPredicate::slt: + return lhs < rhs; + case arith::CmpIPredicate::sle: + return lhs <= rhs; + case arith::CmpIPredicate::sgt: + return lhs > rhs; + case arith::CmpIPredicate::sge: + return lhs >= rhs; + case arith::CmpIPredicate::ult: + return (uint64_t)lhs < (uint64_t)rhs; + case arith::CmpIPredicate::ule: + return (uint64_t)lhs <= (uint64_t)rhs; + case arith::CmpIPredicate::ugt: + return (uint64_t)lhs > (uint64_t)rhs; + case arith::CmpIPredicate::uge: + return (uint64_t)lhs >= (uint64_t)rhs; + default: + break; + } + llvm_unreachable("unknown comparison predicate"); + } +}; + +template +class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto condConstancy = operands[0]->getValue().getConstancy(); + auto lhsInfo = operands[1]->getValue(); + auto rhsInfo = operands[2]->getValue(); + auto rank = lhsInfo.getRank(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + if (operands[0]->getValue().getConstantValue().has_value()) { + if (operands[0]->getValue().getConstantValue() == 0) { + contiguity = rhsInfo.getContiguity(); + divisibility = rhsInfo.getDivisibility(); + constancy = rhsInfo.getConstancy(); + constantValue = rhsInfo.getConstantValue(); + } else { + contiguity = lhsInfo.getContiguity(); + divisibility = lhsInfo.getDivisibility(); + constancy = lhsInfo.getConstancy(); + constantValue = lhsInfo.getConstantValue(); + } + } else { + // The condition can be either a tensor or i1. + // If i1 is used as the condition, the entire tensor of either + // lhs or rhs is selected. + bool i1Cond = isa(op.getOperand(0).getType()); + for (auto d = 0; d < rank; ++d) { + if (i1Cond) { + constancy.push_back( + std::min(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + contiguity.push_back( + std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } else { + constancy.push_back( + std::min(gcd(lhsInfo.getConstancy(d), condConstancy[d]), + gcd(rhsInfo.getConstancy(d), condConstancy[d]))); + contiguity.push_back( + std::min(gcd(lhsInfo.getContiguity(d), condConstancy[d]), + gcd(rhsInfo.getContiguity(d), condConstancy[d]))); + if (contiguity.back() == lhsInfo.getContiguity(d) && + contiguity.back() == rhsInfo.getContiguity(d)) { + // Contiguity not changed + divisibility.push_back( + gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + } else { + // Contiguity changed, we cannot use only divisibility. + // For example, the following example should have contiguity 2 and + // divisibility 2 + // [[0, 1], [4, 5]] + // [[16, 17, 18, 19]] + divisibility.push_back( + std::min(gcd(lhsInfo.getDivisibility(d), contiguity.back()), + gcd(rhsInfo.getDivisibility(d), contiguity.back()))); + } + } + } + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value() && + lhsInfo.getConstantValue() == rhsInfo.getConstantValue()) + constantValue = lhsInfo.getConstantValue(); + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } +}; + +template +class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() & + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() | + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() ^ + rhs.getConstantValue().value()}; + } + } + return {}; + } +}; + +class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto shift = rhs.getConstantValue().value_or(0); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto numBits = log2Int(lhsDivisibility); + return multiplyDivisor(lhsDivisibility, 1ll << shift); + } + + int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() << rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + if (!rhs.getConstantValue().has_value()) + return 1; + auto shift = rhs.getConstantValue().value(); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + return std::max(1, lhsDivisibility / (int64_t(1) << shift)); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() >> rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + std::optional constantValue; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::max(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } else if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::min(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } + return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1), + /*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1), + /*knownConstancy=*/AxisInfo::DimVectorT(rank, 1), + /*constantValue=*/constantValue); + } else { + AxisInfo::DimVectorT contiguity, divisibility, constancy; + for (auto d = 0; d < rank; ++d) { + constancy.push_back( + std::min(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + contiguity.push_back( + std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } + return AxisInfo(contiguity, divisibility, constancy, std::nullopt); + } + } +}; + +//===----------------------------------------------------------------------===// +// AxisInfoAnalysis +//===----------------------------------------------------------------------===// + +AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) + : dataflow::SparseForwardDataFlowAnalysis>( + solver) { + // UnrealizedConversionCast: + // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is + // in the process of a PartialConversion, where UnrealizedConversionCast + // may exist + visitors.append, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor>(); + visitors.append(); + visitors.append(); + visitors.append(); + visitors.append, + AddSubOpAxisInfoVisitor, + AddSubOpAxisInfoVisitor>(); + visitors.append(); + visitors.append, + DivOpAxisInfoVisitor>(); + visitors.append, + RemOpAxisInfoVisitor>(); + visitors.append(); + visitors.append(); + visitors.append(); + visitors.append>(); + visitors.append, + LogicalOpAxisInfoVisitor, + LogicalOpAxisInfoVisitor>(); + visitors.append>(); + visitors.append, + ShROpAxisInfoVisitor>(); + visitors.append, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor>(); + visitors.append(); +} + +LogicalResult AxisInfoAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + // TODO: For sure not the right way to do this + // but why is scf.if not initialized otherwise? + for (auto op : operands) + if (op->getValue().getRank() == 0) + setToEntryState((dataflow::Lattice *)op); + AxisInfo curr = visitors.apply(op, operands); + if (curr.getRank() == 0) { + setAllToEntryStates(results); + return success(); + } + // override with hint + auto newContiguity = curr.getContiguity(); + auto newDivisibility = curr.getDivisibility(); + auto newConstancy = curr.getConstancy(); + if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { + auto vals = cast(attr).getValues(); + newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { + auto vals = cast(attr).getValues(); + newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { + auto vals = cast(attr).getValues(); + newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + curr = AxisInfo(newContiguity, newDivisibility, newConstancy, + curr.getConstantValue()); + // join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(curr)); + return success(); +} + +void AxisInfoAnalysis::visitForOpInductionVar( + scf::ForOp op, ArrayRef *> argLattices) { + ProgramPoint *programPoint = getProgramPointAfter(op); + const auto &lb = + getLatticeElementFor(programPoint, op.getLowerBound())->getValue(); + const auto &step = + getLatticeElementFor(programPoint, op.getStep())->getValue(); + + AxisInfo::DimVectorT knownContiguity(1, 1); + AxisInfo::DimVectorT knownDivisibility(1, 1); + AxisInfo::DimVectorT knownConstancy(1, 1); + knownDivisibility[0] = gcd(lb.getDivisibility(0), step.getDivisibility(0)); + auto inductionVar = + AxisInfo(knownContiguity, knownDivisibility, knownConstancy); + (void)argLattices[0]->join(inductionVar); +} + +} // anonymous namespace + +template +void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, + DimVectorT *contiguity, + DimVectorT *divisibility, + DimVectorT *constancy) { + // liast of attributes that we care about + SmallVector> retVecs; + retVecs.push_back({contiguity, "tt.contiguity"}); + retVecs.push_back({divisibility, "tt.divisibility"}); + retVecs.push_back({constancy, "tt.constancy"}); + // initialize attributes one by one + for (auto [vec, attrName] : retVecs) { + Attribute attr = funcOp.getArgAttr(argNumber, attrName); + if (auto int_attr = dyn_cast_or_null(attr)) + *vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue()); + if (auto dense_attr = dyn_cast_or_null(attr)) { + auto vals = dense_attr.getValues(); + *vec = DimVectorT(vals.begin(), vals.end()); + } + } +} + +/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) { + auto rank = 1; + if (TensorType ty = dyn_cast(value.getType())) + rank = ty.getRank(); + if (triton::PointerType ty = dyn_cast(value.getType())) + if (TensorType elemTy = dyn_cast(ty.getPointeeType())) + rank = elemTy.getRank(); + + DimVectorT knownContiguity(rank, 1); + DimVectorT knownDivisibility(rank, 1); + DimVectorT knownConstancy(rank, 1); + + BlockArgument blockArg = dyn_cast(value); + + if (blockArg && blockArg.getOwner()->isEntryBlock()) { + Operation *op = blockArg.getOwner()->getParentOp(); + if (auto fun = dyn_cast(op)) { + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, + &knownContiguity, &knownDivisibility, + &knownConstancy); + } else if (isa(op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp + // Control flow operations are initialized with "unknown" state: + // the maximum possible divisibility, contiguity, and constancy. + knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); + knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); + knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); + } + } else if (Operation *op = value.getDefiningOp()) { + if (isa(op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp + // Control flow operations are initialized with "unknown" state: + // the maximum possible divisibility, contiguity, and constancy. + knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); + knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); + knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); + } + // Other operations are conservatively initialized with the lowest possible + // divisibility, contiguity, and constancy unless they have specified. + if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { + auto vals = cast(attr).getValues(); + knownDivisibility = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { + auto vals = cast(attr).getValues(); + knownContiguity = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { + auto vals = cast(attr).getValues(); + knownConstancy = DimVectorT(vals.begin(), vals.end()); + } + } + + return AxisInfo(knownContiguity, knownDivisibility, knownConstancy); +} + +/*static*/ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { + // If one argument is not initialized, return the other. + if (lhs.getRank() == 0) + return rhs; + if (rhs.getRank() == 0) + return lhs; + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + for (auto d = 0; d < lhs.getRank(); ++d) { + contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); + divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); + constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d))); + } + std::optional constantValue; + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value() && + lhs.getConstantValue() == rhs.getConstantValue()) + constantValue = lhs.getConstantValue(); + return AxisInfo(contiguity, divisibility, constancy, constantValue); +} + +unsigned ModuleAxisInfoAnalysis::getContiguity(Value value) { + auto tensorTy = dyn_cast(value.getType()); + if (!tensorTy) + return 1; + auto elemTy = tensorTy.getElementType(); + // Get the pointee type if we have a tensor of ptrs to compute contiguity for + if (auto ptrTy = dyn_cast(elemTy)) { + elemTy = ptrTy.getPointeeType(); + } + return getContiguity(value, elemTy.getIntOrFloatBitWidth()); +} + +unsigned ModuleAxisInfoAnalysis::getContiguity(Value offsetsValue, + unsigned elementBitWidth) { + // FIXME: This is not as good as it could be, as we don't need to restrict + // the analysis to one dimension. We should determine contiguity on the + // flattenOuts() layout + auto tensorTy = cast(offsetsValue.getType()); + auto linAttr = + gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape()); + auto order = linAttr.getOrder(); + unsigned align = getAlignment(offsetsValue, elementBitWidth); + + auto uniqueContigPerThread = linAttr.getContigPerThread(); + assert(order[0] < uniqueContigPerThread.size() && + "Unexpected uniqueContigPerThread size"); + unsigned contiguity = uniqueContigPerThread[order[0]]; + LDBG("getContiguity uniqueContigPerThread = " << contiguity); + contiguity = std::min(align, contiguity); + + return contiguity; +} + +unsigned ModuleAxisInfoAnalysis::getAlignment(Value value) { + auto tensorTy = dyn_cast(value.getType()); + if (!tensorTy) + return 1; + + auto elemTy = tensorTy.getElementType(); + // Get the pointee type if we have a tensor of ptrs to compute contiguity for + if (auto ptrTy = dyn_cast(elemTy)) { + elemTy = ptrTy.getPointeeType(); + } + return getAlignment(value, elemTy.getIntOrFloatBitWidth()); +} + +unsigned ModuleAxisInfoAnalysis::getAlignment(Value offsetsValue, + unsigned elementBitWidth) { + auto tensorTy = cast(offsetsValue.getType()); + auto *axisInfo = getAxisInfo(offsetsValue); + if (!axisInfo) + return 1; + auto linAttr = + gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape()); + auto order = linAttr.getOrder(); + auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); + auto maxContig = axisInfo->getContiguity(order[0]); + + auto elemNumBytes = std::max(elementBitWidth / 8, 1); + auto maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1); + unsigned alignment = std::min(maxMultiple, maxContig); + LDBG("getAlignment order[0] " + << order[0] << " maxMultipleBytes = " << maxMultipleBytes + << " maxContig = " << maxContig << " elemNumBits = " << elementBitWidth + << " maxMultiple = " << maxMultiple << " alignment " << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { + auto tensorTy = dyn_cast(mask.getType()); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfo(mask); + if (!axisInfo) + return 1; + auto linAttr = + gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape()); + auto maskOrder = linAttr.getOrder(); + auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); + LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment " + << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) { + std::unique_ptr solver = createDataFlowSolver(); + AxisInfoAnalysis *analysis = solver->load(); + WalkResult result = funcOp.walk([&](Operation *op) { + if (op->hasTrait() && + failed(solver->initializeAndRun(op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (result.wasInterrupted()) + return; + + auto *axisInfoMap = getFuncData(funcOp); + auto updateAxisInfoMap = [&](Value value) { + auto axisInfo = analysis->getLatticeElement(value)->getValue(); + AxisInfo curAxisInfo; + if (axisInfoMap->count(value)) { + curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value)); + } else { + curAxisInfo = axisInfo; + } + (*axisInfoMap)[value] = curAxisInfo; + }; + funcOp.walk([&](Operation *op) { + for (auto value : op->getResults()) { + updateAxisInfoMap(value); + } + }); + funcOp.walk([&](Block *block) { + for (auto value : block->getArguments()) { + updateAxisInfoMap(value); + } + }); +} + +void ModuleAxisInfoAnalysis::update(CallOpInterface callOp, + FunctionOpInterface callee) { + auto caller = callOp->getParentOfType(); + auto *axisInfoMap = getFuncData(caller); + for (auto entry : llvm::enumerate(callOp->getOperands())) { + auto index = entry.index(); + auto value = entry.value(); + auto setAttrFn = [&](StringRef attrName, int64_t prevValue) { + auto curValue = highestPowOf2Divisor(0); + if (callee.getArgAttrOfType(index, attrName)) { + curValue = + callee.getArgAttrOfType(index, attrName).getInt(); + } + auto attr = IntegerAttr::get(IntegerType::get(callee.getContext(), 64), + gcd(prevValue, curValue)); + callee.setArgAttr(index, attrName, attr); + }; + auto axisInfo = axisInfoMap->lookup(value); + assert(axisInfo.getRank() == 1 && "only scalar arguments are supported"); + setAttrFn("tt.contiguity", axisInfo.getContiguity(0)); + setAttrFn("tt.divisibility", axisInfo.getDivisibility(0)); + setAttrFn("tt.constancy", axisInfo.getConstancy(0)); + } +} + +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/lib/Analysis/CMakeLists.txt b/third_party/enflame/include/triton/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..693d222f2 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Analysis/CMakeLists.txt @@ -0,0 +1,20 @@ +add_triton_library(TritonAnalysis + AxisInfo.cpp + Allocation.cpp + Membar.cpp + Alias.cpp + Utility.cpp + + DEPENDS + TritonTableGen + TritonGPUTableGen + TritonGPUAttrDefsIncGen + TritonGPUTypeInterfacesIncGen + + LINK_LIBS PUBLIC + MLIRAnalysis + MLIRLLVMDialect + TritonIR + TritonGPUIR + TritonNvidiaGPUIR +) diff --git a/third_party/enflame/include/triton/lib/Analysis/Membar.cpp b/third_party/enflame/include/triton/lib/Analysis/Membar.cpp new file mode 100644 index 000000000..ca9cd507c --- /dev/null +++ b/third_party/enflame/include/triton/lib/Analysis/Membar.cpp @@ -0,0 +1,252 @@ +#include "triton/Analysis/Membar.h" +#include "triton/Analysis/Alias.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include + +namespace mlir { + +void MembarAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) { + FunctionOpInterface funcOp = + dyn_cast(allocation->getOperation()); + OpBuilder builder(funcOp.getContext()); + resolve(funcOp, &funcBlockInfoMap, &builder); +} + +void MembarAnalysis::resolve(FunctionOpInterface funcOp, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { + // Initialize the blockList. Operations are organized into "virtual blocks", + // which represent segments of straight-line code analyzed by each iteration + // of the dataflow analysis. Virtual blocks abstract over both control flow + // represented by basic blocks and block successors (i.e. `BranchOpInterface`) + // and control flow represented by regions (i.e. `RegionBranchOpInterface`). + // + // A virtual block consists of a parent block and a starting iterator, where + // the virtual block starts on the operation *after* the starting iterator. A + // null iterator is used to represent the beginning of the block. The virtual + // block ends at any region branch operation or the basic block terminator. + // Thus, basic blocks are broken up into multiple virtual blocks at each + // region operation. + // + // Entry virtual blocks are represented by a null iterator. Populate the + // blockList with the entry virtual blocks in the function. Then, each + // iteration scans until a terminator or region branch operation is found. + DenseMap inputBlockInfoMap; + DenseMap outputBlockInfoMap; + std::deque blockList; + funcOp.walk([&](Block *block) { + // Start the analysis from the entry blocks of any nested isolated from + // above regions. + if (block->isEntryBlock() && + !isa(block->getParentOp())) + blockList.emplace_back(block, Block::iterator()); + }); + + // A fixed point algorithm + while (!blockList.empty()) { + VirtualBlock block = blockList.front(); + blockList.pop_front(); + // Make a copy of the inputblockInfo but not update + auto inputBlockInfo = inputBlockInfoMap[block]; + SmallVector successors; + Block::iterator startIt = + block.second.isValid() ? std::next(block.second) : block.first->begin(); + for (Operation &op : llvm::make_range(startIt, block.first->end())) { + if (op.hasTrait() || + isa(op)) { + visitTerminator(&op, successors); + break; + } + update(&op, &inputBlockInfo, funcBlockInfoMap, builder); + } + // Get the reference because we want to update if it changed + if (outputBlockInfoMap.count(block) && + inputBlockInfo == outputBlockInfoMap[block]) { + // If we have seen the block before and the inputBlockInfo is the same as + // the outputBlockInfo, we skip the successors + continue; + } + // Update the current block. The block transfer function is not monotonic, + // so overwrite the output state entirely. + outputBlockInfoMap[block] = inputBlockInfo; + // Update the successors + for (VirtualBlock successor : successors) { + inputBlockInfoMap[successor].join(outputBlockInfoMap[block]); + blockList.emplace_back(successor); + } + } + + // Update the final dangling buffers that haven't been synced + BlockInfo &funcBlockInfo = (*funcBlockInfoMap)[funcOp]; + funcOp.walk([&](triton::ReturnOp returnOp) { + // A basic block can be broken into several virtual blocks. Find all virtual + // blocks that belong to the basic block containing the return. + SmallVector> virtualBlocks; + for (auto &[block, blockInfo] : outputBlockInfoMap) { + if (block.first == returnOp->getBlock()) + virtualBlocks.emplace_back(block, blockInfo); + } + // The return is a terminator, so the virtual block that contains this + // return starts after all other ones. Find it by comparing the start + // iterators of the virtual blocks. + auto maxIt = llvm::max_element(virtualBlocks, [&](auto &lhs, auto &rhs) { + assert(lhs.first.first == rhs.first.first); + Block::iterator lhsIt = lhs.first.second, rhsIt = rhs.first.second; + return !lhsIt.isValid() || + (rhsIt.isValid() && lhsIt->isBeforeInBlock(&*rhsIt)); + }); + + funcBlockInfo.join(maxIt->second); + }); +} + +void MembarAnalysis::visitTerminator(Operation *op, + SmallVector &successors) { + if (isa(op)) { + // Collect the block successors of the branch. + for (Block *successor : op->getSuccessors()) + successors.emplace_back(successor, Block::iterator()); + return; + } + + if (auto br = dyn_cast(op)) { + // The successors of an operation with regions can be queried via an + // interface. The operation branches to the entry blocks of its region + // successors. It can also branch to after itself. + SmallVector regions; + br.getSuccessorRegions(RegionBranchPoint::parent(), regions); + for (RegionSuccessor ®ion : regions) { + if (region.isParent()) { + successors.emplace_back(br->getBlock(), br->getIterator()); + } else { + Block &block = region.getSuccessor()->front(); + successors.emplace_back(&block, Block::iterator()); + } + } + return; + } + + // FIXME: `ReturnLike` adds `RegionBranchTerminatorOpInterface` for some + // reason. Check that the parent is actually a `RegionBranchOpInterface`. + auto br = dyn_cast(op); + if (br && isa(br->getParentOp())) { + // Check the successors of a region branch terminator. It can branch to + // another region of its parent operation or to after the parent op. + SmallVector operands(br->getNumOperands()); + SmallVector regions; + br.getSuccessorRegions(operands, regions); + for (RegionSuccessor ®ion : regions) { + if (region.isParent()) { + Operation *parent = br->getParentOp(); + successors.emplace_back(parent->getBlock(), parent->getIterator()); + } else { + Block &block = region.getSuccessor()->front(); + successors.emplace_back(&block, Block::iterator()); + } + } + return; + } + + // Otherwise, it could be a return op + if (op->hasTrait()) + return; + llvm_unreachable("Unknown terminator encountered in membar analysis"); +} + +void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) { + OpBuilder::InsertionGuard g(*builder); + ::insertBarrier(*builder, op); +} + +void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { + if (isa(op)) { + // If the current op is a barrier, we sync previous reads and writes + blockInfo->sync(); + return; + } + + if (isa(op) && + !isa(op->getNextNode())) { + // If the current op is an async wait and the next op is not a barrier we + // insert a barrier op and sync + builder->setInsertionPointAfter(op); + insertBarrier(op, builder); + blockInfo->sync(); + return; + } + + BlockInfo curBlockInfo; + auto scratchBufferId = Allocation::InvalidBufferId; + if (isa(op)) { + // Inter-function dependencies + auto callOpInterface = dyn_cast(op); + if (auto callee = + dyn_cast(callOpInterface.resolveCallable())) + curBlockInfo = funcBlockInfoMap->lookup(callee); + } else { + // Intra-function dependencies + if (auto memoryEffectOpInterface = dyn_cast(op)) { + // Explicit buffer + SmallVector> + effectInstances; + memoryEffectOpInterface.getEffects(effectInstances); + for (auto effectInstance : effectInstances) { + if (auto value = effectInstance.getValue()) { + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId) { + if (isa(effectInstance.getEffect())) + curBlockInfo + .syncWriteIntervals[allocation->getAllocatedInterval( + bufferId)] + .insert(op); + else if (isa(effectInstance.getEffect())) + curBlockInfo + .syncReadIntervals[allocation->getAllocatedInterval( + bufferId)] + .insert(op); + } + } + } + } + } + scratchBufferId = allocation->getBufferId(op); + } + + // Scratch buffer operations consist of a series of shared memory operations + // starting from a shared memory write, followed by a series of shared memory + // read/write operations, and ending with a shared memory read, i.e., shared + // memory write -> ... -> shared memory read. + if (scratchBufferId != Allocation::InvalidBufferId) { + if (!curBlockInfo.syncReadIntervals.empty() || + !curBlockInfo.syncWriteIntervals.empty()) { + llvm::report_fatal_error( + "scratch buffer operations should not have any shared memory " + "dependencies"); + } + auto interval = allocation->getAllocatedInterval(scratchBufferId); + curBlockInfo.syncWriteIntervals[interval].insert(op); + if (blockInfo->isIntersected(curBlockInfo, filter)) { + builder->setInsertionPoint(op); + insertBarrier(op, builder); + } + // Ops with a scratch buffer internally syncs read/write on shared memory + blockInfo->sync(); + curBlockInfo.syncReadIntervals[interval].insert(op); + } else if (blockInfo->isIntersected(curBlockInfo, filter)) { + builder->setInsertionPoint(op); + insertBarrier(op, builder); + blockInfo->sync(); + } + // Update the region info, even if barrier is inserted, we have to maintain + // the current op's read/write buffers. + blockInfo->join(curBlockInfo); +} +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Analysis/Utility.cpp b/third_party/enflame/include/triton/lib/Analysis/Utility.cpp new file mode 100644 index 000000000..8d3650c8a --- /dev/null +++ b/third_party/enflame/include/triton/lib/Analysis/Utility.cpp @@ -0,0 +1,1096 @@ +#include "triton/Analysis/Utility.h" + +#include + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace mlir { + +using namespace triton; +using namespace triton::gpu; + +SmallVector ReduceOpHelper::getOrderWithAxisAtBeginning() { + auto order = toLinearEncoding(srcEncoding, srcShape).getOrder(); + auto it = std::find(order.begin(), order.end(), axis); + // delete the axis from order + order.erase(it); + // insert axis at the beginning of order + order.insert(order.begin(), axis); + return order; +} + +// Thread offset is the thread index offset of two adjacent threads on the +// reduction axis within the warp. +unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { + auto *ctx = srcEncoding.getContext(); + auto linearLayout = toLinearLayout(srcShape, srcEncoding); + auto kLane = mlir::StringAttr::get(ctx, "lane"); + const auto &bases = linearLayout.getBases(); + const auto &lanes = bases.find(kLane)->second; + auto offset = 1; + for (const auto &lane : lanes) { + if (lane[axis] != 0) + break; + offset *= 2; + } + return offset; +} + +// Cases where distributed shared memory is not required in ConvertLayout: +// (1) numCTAs == 1 +// (2) numCTAs > 1 but srcCTALayout == dstCTALayout +// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented +// in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { + unsigned numCTAs = getNumCTAs(srcLayout); + assert(numCTAs == getNumCTAs(dstLayout) && + "Invalid layout conversion: the numbers of CTAs of src and dst " + "layouts are different"); + + // Case (1): Never use dsmem when numCTAs == 1 + if (numCTAs == 1) + return false; + + // Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not + // implemented yet + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + llvm::report_fatal_error("Layout conversion to be implemented"); + } + + // Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported + if (auto sliceLayout = mlir::dyn_cast(dstLayout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + return true; + } + + // The above two branches make sure that it is legal to call getCTALayout of + // srcLayout and dstLayout + + // Case (2): Do not use dsmem when srcCTALayout == dstCTALayout + auto srcCTALayout = getCTALayout(srcLayout); + auto dstCTALayout = getCTALayout(dstLayout); + if (srcCTALayout == dstCTALayout) + return false; + + // Dsmem access is required when srcCTALayout != dstCTALayout + return true; +} + +unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() { + return getWarpsPerCTAWithUniqueData(srcEncoding, srcShape)[axis]; +} + +unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() { + return getThreadsPerWarpWithUniqueData(srcEncoding, srcShape)[axis]; +} + +bool ReduceOpHelper::isWarpSynchronous() { + return getWarpsPerCTAWithUniqueData(srcEncoding, srcShape)[axis] == 1; +} + +SmallVector ReduceOpHelper::getScratchRepShape() { + SmallVector smemShape; + // This case doesn't need inter-warp communication + if (isWarpSynchronous()) + return {0, 0}; + + smemShape = convertType(srcShape); + smemShape[axis] = getInterWarpSizeWithUniqueData(); + + return smemShape; +} + +unsigned ReduceOpHelper::getScratchSizeInBytes() { + auto smemShape = getScratchRepShape(); + auto elems = product(smemShape); + + unsigned bytesPerElem = 0; + for (const auto &ty : srcElementTypes) { + bytesPerElem += ceil(ty.getIntOrFloatBitWidth(), 8); + } + return bytesPerElem * elems; +} + +bool ReduceOpHelper::isReduceWithinCTA() { + // TODO: Support reduce across CTAS + // Layout optimization passes such as PlanCTAPass and + // RemoveLayoutConversionPass should avoid cross-CTA reduction + return getCTASplitNum(srcEncoding)[axis] == 1; +} + +unsigned ScanLoweringHelper::getAxisNumElementsPerThread() { + return getEncoding().getContigPerThread()[getAxis()]; +} + +unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() { + auto contigPerThread = getEncoding().getContigPerThread(); + contigPerThread[getAxis()] = 1; + return product(contigPerThread); +} + +Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); } + +unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() { + return getEncoding().getThreadsPerWarp()[getAxis()]; +} + +unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() { + auto nThreads = product(getEncoding().getThreadsPerWarp()); + return nThreads / getAxisNumThreadsPerWarpWithUniqueData(); +} + +// Return the flat numbers of threads computing independent scan results. +unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() { + auto nWarps = product(getEncoding().getWarpsPerCTA()); + return (nWarps / getAxisNumWarpsWithUniqueData()) * + getNonAxisNumThreadsPerWarp(); +} + +unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() { + return getEncoding().getWarpsPerCTA()[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumBlocks() { + auto contigPerThread = getEncoding().getContigPerThread(); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + unsigned axis = getAxis(); + return ceil( + getShape()[axis], + (contigPerThread[axis] * threadsPerWarp[axis] * warpsPerCTA[axis])); +} + +unsigned ScanLoweringHelper::getNonAxisNumBlocks() { + auto contigPerThread = getEncoding().getContigPerThread(); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + auto rank = contigPerThread.size(); + unsigned axis = getAxis(); + unsigned numBlocks = 1; + for (unsigned i = 0; i < rank; i++) { + if (i == axis) + continue; + numBlocks *= + ceil(getShape()[i], (contigPerThread[i] * threadsPerWarp[i] * + warpsPerCTA[i])); + } + return numBlocks; +} + +bool ScanLoweringHelper::isSupported() { + // TODO: Support the following cases: + // 1. Scan on non-blocking encodings + if (!isa(legacyEncoding)) + return false; + return true; +} + +unsigned ScanLoweringHelper::getScratchSizeInElems() { + unsigned numWarps = product(getEncoding().getWarpsPerCTA()); + unsigned numNonAxisElementsPerWarp = + getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread(); + unsigned numElements = numWarps * numNonAxisElementsPerWarp * + getAxisNumBlocks() * getNonAxisNumBlocks(); + return numElements; +} + +unsigned ScanLoweringHelper::getScratchSizeInBytes() { + // Lowering will fail later if the layout is not supported. + if (!isSupported()) + return 0; + + unsigned axisNumWarps = getAxisNumWarpsWithUniqueData(); + if (axisNumWarps == 1) + return 0; + unsigned elementSizeInBytes = 0; + for (const auto &ty : srcElementTypes) { + elementSizeInBytes += ceil(ty.getIntOrFloatBitWidth(), 8); + } + return elementSizeInBytes * getScratchSizeInElems(); +} + +std::optional +getWarpLayoutConvertDecomposition(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto conversion = minimalCvtLayout(srcTy, dstTy); + + MLIRContext *ctx = srcTy.getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + + // We have already checked that data movement is only required within a warp, + // thus we can discard the block and warp dimensions. + LinearLayout C = conversion.sublayout({kLane, kRegister}, {kLane, kRegister}); + + // `C` is map from `(dst_lane, dst_reg) -> (src_lane, src_reg)`. From the + // perspetive of the destination lane, it tells us which register from which + // lane to get the value. Since the source and destination layouts are + // subpermutation matrices, the overall transformation amounts to permuting + // data around (plus broadcasting, if necessary). + // + // Warp shuffles allow indexing into another lane, but does not allowing + // selecting the register. Suppose we decompose `C` into `C = P1 ∘ W ∘ P2`, + // where `W` is a warp shuffle and `P1` and `P2` are (lane-dependent) register + // permutations within a lane. Start from `C` and work backwards. + // + // Given any `C`, is it possible that for a given destination register, two + // destination lanes map to different source registers in the same source + // lane. This is impossible to represent using a shuffle. This happens when, + // with respect to the identity layout, a register base is swapped with a lane + // base (when the destination lane changes, the source register changes but + // the lane does not). + // + // Example: + // + // src = {register = [[1,0], [2,0]], lane = [[0,1], [0,2]]} + // dst = {register = [[0,1], [2,0]], lane = [[1,0], [0,2]]} + // cvt = dst, since src is the identity layout + // + // The map from destination -> source looks like: + // + // dst_lane + // dst_reg 0 1 2 3 + // 0 T0:0 T0:1 T2:0 T2:1 + // 1 T1:0 T1:1 T3:0 T3:1 + // 2 T0:2 T0:3 T2:2 T2:3 + // 3 T1:2 T1:3 T3:2 T3:3 + // + // Note for each destination register, two lanes want two different registers + // in the same source lane (T0:0 -> T0:0, T1:0 -> T0:1). This is impossible to + // represent with a warp shuffle, because the source lane (e.g. T0) can only + // supply one of its registers as the shuffle value. + // + // The goal of `P2` is to permute registers within a thread so that this does + // not happen. Specifically, pick `P2` such that bases in + // `(P2^-1 ∘ C).sublayout(kLane, {kLane, kRegister})` has non-zero lane + // components when the register components are non-zero. + // + // P2 can only change the register mapping within a thread. Constrain P2 as: + // + // P2 = [ I 0 ] + // [ P I ] + // + // Then `P2^-1 ∘ C` is: + // + // [ I 0 ] [ C(r,r) C(r,l) ] = [ C(r,r) C(r,l) ] + // [ P' I ] [ C(l,r) C(l,l) ] [ P'*C(r,r)+C(l,r) P'*C(r,l)+C(l,l) ] + // + // Where addition in GF(2) is xor. + // + // We can see that P' selects rows (i.e. bases) from the upper half (register) + // and combines them with the lower half (lane). Because the goal is for P' to + // select register bases `i` where C(r,l)[i] != 0, we know P'*C(r,r) = 0, + // since the corresponding C(r,r)[i] element in the same row will be zero. + // + // Note that solutions for P' do not always exist (no register permutation + // will decompose C to make the warp shuffle possible), and this happens when + // there aren't enough non-zero bases in C(r,l). + // + // Find the indices of the missing lane bases: rows in the lower half where + // the register component is non-zero but the lane component is zero. + SmallVector missingLaneRows; + for (int i : llvm::seq(C.getInDimSizeLog2(kLane))) { + ArrayRef /*C(l,(r,l))[i]*/ lowerHalfRow = C.getBasis(kLane, i); + assert(lowerHalfRow.size() == 2); + if (/*C(l,r)[i]*/ lowerHalfRow[0] != 0) { + assert(/*C(l,l)[i]*/ lowerHalfRow[1] == 0); + missingLaneRows.push_back(i); + } else if (lowerHalfRow[1] == 0) { + // If there is broadcasting along the lane, then C'(l,l) below won't be + // invertible. Intuitively, the dst tensor contains a subset of the src + // tensor's data, so recovering the src tensor through permutation alone + // is impossible. We would need an affine component (bfly shuffle). + return {}; + } + } + + // Find rows in the upper-half of C (i.e. the (reg) -> (reg, lane) submatrix) + // that can be selected by P' to make the lane components in the lower half + // (i.e. the (lane) -> (lane) submatrix) non-zero. + std::vector> PPrimeLaneBases(C.getInDimSizeLog2(kLane), + {0}); + for (int i : llvm::seq(C.getInDimSizeLog2(kRegister))) { + ArrayRef /*C(r,(r,l))[i]*/ upperHalfRow = C.getBasis(kRegister, i); + assert(upperHalfRow.size() == 2); + if (/*C(r,l)[i]*/ upperHalfRow[1] == 0) + continue; + + assert(upperHalfRow[0] == 0); + int32_t laneBase = upperHalfRow[1]; + assert(/*C(r,r)[i]*/ upperHalfRow[0] == 0); + if (!missingLaneRows.empty()) { + // Select row i into row j from the missing rows. The order in which the + // missing rows are selected doesn't really matter. + PPrimeLaneBases[missingLaneRows.pop_back_val()][0] |= (1 << i); + } + } + if (!missingLaneRows.empty()) { + // The decomposition failed. No solution for P' is possible. + return {}; + } + + // P' outputs the destination register. + LinearLayout PPrime({{kLane, std::move(PPrimeLaneBases)}}, + {{kRegister, C.getInDimSize(kRegister)}}, + /*requiresSurjective=*/false); + + // Form P2^-1 from P'. + unsigned dstRegSize = C.getInDimSize(kRegister); + unsigned numLanes = C.getInDimSize(kLane); + LinearLayout P2invTop = + LinearLayout::identity1D(dstRegSize, kRegister, kRegister) + .concatOuts( + LinearLayout::zeros1D(dstRegSize, kRegister, kLane, numLanes)); + LinearLayout P2invBot = + PPrime.concatOuts(LinearLayout::identity1D(numLanes, kLane, kLane)); + LinearLayout P2inv = P2invTop.concatIns(P2invBot); + + // Check that P2^-1 was formed correctly. + assert(P2inv.sublayoutIsZero(kRegister, kLane)); + assert(squareSublayoutIsPermutation(P2inv, kLane)); + + LinearLayout Cp = P2inv.compose(C); + + // Now we have C' = P2^-1 ∘ C = W ∘ P1. W is considerably easier to compute. + // A warp shuffle is a function from `(register, lane) -> (lane)`, i.e. + // + // W = [ I R' ] + // [ 0 L ] + // + // `W^-1 ∘ C'` will be + // + // [ I R ] [ C'(r,r) C'(r,l) ] = [ ... C'(r,l) + R*C'(l,l) ] + // [ 0 L ] [ C'(l,r) C'(l,l) ] = [ ... L*C'(l,l) ] + // + // Since P1 cannot change lanes, we know that + // + // W^-1 ∘ C' = [ ... 0 ] + // [ ... I ] + // + // Thus L = C'(l,l)^-1, and R = -C'(r,l) * C'(l,l)^-1. (0 - LL) = LL in GF(2). + // We know that C'(l,l) has a suitable pseudo-inverse. + LinearLayout L = Cp.sublayout(kLane, kLane).pseudoinvert(); + LinearLayout R = Cp.sublayout(kRegister, kLane).compose(L); + + // Now form W^-1. + LinearLayout WinvLeft = + LinearLayout::identity1D(dstRegSize, kRegister, kRegister) + .concatIns( + LinearLayout::zeros1D(numLanes, kLane, kRegister, dstRegSize)); + LinearLayout Winv = WinvLeft.concatOuts(R.concatIns(L)); + + // Check that Winv was formed correctly. P1 is just what's left over. + LinearLayout P1 = Winv.compose(Cp); + assert(P1.sublayoutIsZero(kRegister, kLane)); + assert(squareSublayoutIsIdentity(P1, kLane)); + + // Grab just the interesting parts of the decomposed layouts. + P1 = P1.sublayout({kLane, kRegister}, kRegister); + P2inv = P2inv.sublayout({kLane, kRegister}, kRegister); + Cp = Cp.sublayout({kLane, kRegister}, kLane); + + // To minimize the number of selects emitted on the source side, determine the + // minimum set of registers that could be selected from each thread. + // InstCombine *might* be able to crush this, but if the sizePerThread is + // large, it's truly a huge number of selects that get emitted. + // If reducedP1 is trivial, then we will emit + // shflSrc = select(i == i, src[i], undef) and this will get trivially folded, + // so don't worry about this case. + LinearLayout reducedP1 = P1.removeZeroBasesAlongDim(kLane); + LinearLayout reducedP2 = P2inv.removeZeroBasesAlongDim(kLane); + + // The number of emitted selects can still be quite large if the layout is not + // cooperative. This happens when the source register is more correlated + // with the desination lane than the destination register (i.e. the number of + // non-zero bases). The number of selects impacts performance and grows + // exponentially with the number of non-zero bases. Experiments show that more + // than 1 select causes performance to be slower than shared memory. + if (reducedP1.getInDimSize(kLane) > 2 || reducedP2.getInDimSize(kLane) > 2) + return {}; + + // HACK: Workaround AMD codegen path generating transient invalid layouts. + auto isInvalidDotEnc = [](RankedTensorType type) { + auto dotEnc = dyn_cast(type.getEncoding()); + return dotEnc && dotEnc.getKWidth() == 0; + }; + if (isInvalidDotEnc(srcTy) || isInvalidDotEnc(dstTy)) + return {}; + + // When the element type is smaller than 32 bits, values are upcasted to i32 + // for shuffles. When the shared memory conversion can use vector stores of + // sufficiently large length, the shared memory conversion is faster. + // TODO: Implementing shuffling packed 16 and 8 bit values. + auto [inVec, outVec] = getScratchCvtInOutVecLengths(srcTy, dstTy); + if (!isa(srcTy.getElementType()) && + srcTy.getElementTypeBitWidth() < 32 && inVec > 4 && outVec > 4) + return {}; + + // Return just the interesting parts of the decomposed layouts. + return {{std::move(P1), std::move(Cp), std::move(P2inv), std::move(reducedP1), + std::move(reducedP2)}}; +} + +SmallVector, SmallVector>> +getReshapeDecomposition(ArrayRef srcShape, + ArrayRef dstShape) { + SmallVector, SmallVector>> ret; + + if (srcShape.empty()) { + assert(dstShape.empty()); + return ret; + } + ret.push_back({}); + + int srcIdx = 0; + int dstIdx = 0; + int srcNElems = 1; + int dstNElems = 1; + while (srcIdx < srcShape.size() || dstIdx < dstShape.size()) { + if (srcNElems < dstNElems || // + (srcIdx < srcShape.size() && srcNElems == 1) || + (srcIdx < srcShape.size() && srcShape[srcIdx] == 1)) { + assert(srcIdx < srcShape.size()); + srcNElems *= srcShape[srcIdx]; + ret.back().first.push_back(srcIdx); + srcIdx++; + } else if (dstNElems < srcNElems || + (dstIdx < dstShape.size() && dstShape[dstIdx] == 1)) { + assert(dstIdx < dstShape.size()); + dstNElems *= dstShape[dstIdx]; + ret.back().second.push_back(dstIdx); + dstIdx++; + } else { + ret.push_back({}); + srcNElems = 1; + dstNElems = 1; + } + } + return ret; +} + +unsigned ScanLoweringHelper::getAxisElementStride() { + auto order = getOrder(); + unsigned stride = 1; + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= getEncoding().getContigPerThread()[dim]; + } + llvm_unreachable("Axis not found in order"); +} + +unsigned ScanLoweringHelper::getAxisThreadStride() { + auto encoding = getEncoding(); + auto kThread = StringAttr::get(encoding.getContext(), "lane"); + // OOOGHHH This is nasty. We should implement this lowering via LLs natively + // to avoid this + auto threadsPerWarp = encoding.basesPerDim(kThread, /*skipBroadcast=*/false); + auto order = getOrder(); + unsigned stride = 1; + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= threadsPerWarp[dim]; + } + llvm_unreachable("Axis not found in order"); +} + +unsigned ScanLoweringHelper::getAxisBlockStride() { + auto order = getOrder(); + unsigned stride = 1; + auto contigPerThread = getEncoding().getContigPerThread(); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= ceil(getShape()[dim], contigPerThread[dim] * + threadsPerWarp[dim] * + warpsPerCTA[dim]); + } + llvm_unreachable("Axis not found in order"); +} + +GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp) + : gatherOp(gatherOp) {} + +unsigned GatherLoweringHelper::getScratchSizeInBytes() { + // If the gather is warp-local, no scratch space is needed. + if (isWarpLocal()) + return 0; + + // Otherwise, performing the gather will require scratch space to communicate + // the source tensor across threads. For now, assume the whole source tensor + // is written back to shared memory. + RankedTensorType srcType = gatherOp.getSrc().getType(); + return product(srcType.getShape()) * + ceil(srcType.getElementTypeBitWidth(), 8); +} + +bool GatherLoweringHelper::isWarpLocal() { + // The gather is warp-local if for each column along the gather axis in the + // source and index tensors, all the elements are owned by the same warp. + RankedTensorType srcType = gatherOp.getSrc().getType(); + RankedTensorType idxType = gatherOp.getIndices().getType(); + LinearLayout srcLayout = + toLinearLayout(srcType.getShape(), srcType.getEncoding()); + LinearLayout idxLayout = + toLinearLayout(idxType.getShape(), idxType.getEncoding()); + + Builder b(gatherOp.getContext()); + StringAttr kBlock = b.getStringAttr("block"); + StringAttr kWarp = b.getStringAttr("warp"); + StringAttr kLane = b.getStringAttr("lane"); + StringAttr kGatherDim = + b.getStringAttr("dim" + std::to_string(gatherOp.getAxis())); + + // The tensor layouts must be distributed layouts, where the basis matrix is a + // subpermutation matrix (permutation matrix plus zeros for broadcasting). + // FIXME(jeff): Check this invariant somehow. + // + // We want to know if all elements of a column along the gather axis are + // mapped to the same set of warps, which means the gather can be performed + // entirely within the warp. We need to query + // + // srcLayout.invert().sublayoutIsZero({kGatherDim}, {kBlock, kWarp}) + // + // But due to broadcasting, the matrix might not be invertible. But since the + // matrix is a permutation matrix (checked below), we can instead query + // + // srcLayout.sublayoutIsZero({kBlock, kWarp}, {kGatherDim}) + // + // Which implies that changing the warp will not change the gather dimension. + // And since there is no swizzling, this applies to all warps. + if (!srcLayout.sublayoutIsZero({kBlock, kWarp}, kGatherDim) || + !idxLayout.sublayoutIsZero({kBlock, kWarp}, kGatherDim)) + return false; + + SmallVector otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + if (dim != gatherOp.getAxis()) { + otherDims.push_back(b.getStringAttr("dim" + Twine(dim))); + } + } + + // If the gather axis `dimN` is invariant to the warp, but the `(block, warp)` + // mapping to all other dimensions must be the same for both layouts. If so, + // then the warp that owns a particular index element also owns all the source + // elements it could index into. + if (srcLayout.sublayout({kBlock, kWarp}, otherDims) != + idxLayout.sublayout({kBlock, kWarp}, otherDims)) + return false; + + // The two constraints above ensure that data-movement to perform the gather + // operation are contained within a warp. The subsequent constraints simplify + // codegen. + + // Require that for any given gather column, the threads mapped to the column + // in the index and source tensors are the same. This means we don't need to + // xor shuffle across threads before emitting index shuffles; we push warp + // shuffling to layout conversions. + return srcLayout.sublayout(kLane, otherDims) == + idxLayout.sublayout(kLane, otherDims); +} + +unsigned getNumScratchElements(ArrayRef shape) { + if (shape.empty()) + return 0; + return product(shape); +} + +bool supportMMA(triton::DotOp op, int version) { + // Refer to mma section for the data type supported by Volta and Hopper + // Tensor Core in + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + auto aElemTy = op.getA().getType().getElementType(); + auto bElemTy = op.getB().getType().getElementType(); + if (version == 5) { + if (triton::tools::getBoolEnv("DISABLE_MMA_V5")) + return false; + auto retType = op.getType(); + auto retShapePerCTA = getShapePerCTA(retType); + auto rank = retShapePerCTA.size(); + int numWarps = lookupNumWarps(op); + if (aElemTy.isInteger() || bElemTy.isInteger() || + retType.getElementType().isInteger()) + return false; + if (op.getType().getRank() != 2) + return false; + if (numWarps != 4 && numWarps != 8) { + // Currently only support numWarps 4 or 8 for TMEM load and store. + return false; + } + if (!(retShapePerCTA[rank - 2] % 64 == 0 && + retShapePerCTA[rank - 1] % 8 == 0)) + return false; + return true; + } + if (version == 3) { + if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) + return false; + auto retType = op.getType(); + RankedTensorType typeA = op.getA().getType(); + int k = typeA.getShape().back(); + // If k size is smaller than the native mma size, we cannot use MMA. + if (k < 256 / aElemTy.getIntOrFloatBitWidth()) + return false; + auto retShapePerCTA = getShapePerCTA(retType); + auto rank = retShapePerCTA.size(); + int numWarps = lookupNumWarps(op); + // TODO(Keren): for now, fallback to MMAv2 if handling batch matmul. + if (rank == 3) + return false; + if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && + retShapePerCTA[rank - 1] % 8 == 0 && + (llvm::isa(aElemTy) || + aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || + aElemTy.isF32()))) { + return false; + } + // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. + if (op.getMaxNumImpreciseAcc() < 32 && + (llvm::isa(aElemTy)) && + cast(op.getType()).getElementType().isF32()) { + return false; + } + } + if (aElemTy.isF32() && bElemTy.isF32()) { + return op.getInputPrecision() == InputPrecision::TF32 && version >= 2; + } + return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); +} + +bool supportMMA(Value value, int version) { + // Tell whether a DotOp support MMA by the operand type(either $a or $b). + // We cannot get both the operand types(in TypeConverter), here we assume the + // types of both the operands are identical here. + assert((version == 1 || version == 2 || version == 3) && + "Unexpected MMA layout version found"); + auto elemTy = + cast(value.getType()).getElementType(); + // FP8 is not natively supported on all mma versions but it can always be + // promoted to fp16 therefore we can always support it. + bool isFP8 = llvm::isa(elemTy); + return isFP8 || elemTy.isF16() || elemTy.isBF16() || + (elemTy.isF32() && version >= 2) || + (elemTy.isInteger(8) && version >= 2); +} + +// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases. +bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto mmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + if (!mmaLayout || !dotOperandLayout) { + return false; + } + int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth(); + auto parentTy = RankedTensorType::get( + srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent()); + auto ans = mmaLayout.getVersionMajor() == 3 && + dotOperandLayout.getOpIdx() == 0 && + mmaLayout.getWarpsPerCTA()[1] == 1 && + !cvtNeedsSharedMemory(parentTy, srcTy) && elementTypeSize == 8 && + dotOperandLayout.getKWidth() == 32 / elementTypeSize; + return ans; +} + +bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto mfmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + if (!mfmaLayout || !dotOperandLayout) + return false; + + // Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case + return dotOperandLayout.getParent() == mfmaLayout && + dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && + dotOperandLayout.getKWidth() == 8 && + ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) || + (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) && + triton::type::isFloat8(srcTy.getElementType()) && + triton::type::isFloat8(dstTy.getElementType()) && + mfmaLayout.getWarpsPerCTA()[1] == 1; +} + +// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity +// under kBlock, kWarp or kLane (in that order). The idea here is that if we +// have a transformation that's the identity on kBlock, we don't need to use +// distributed shared memory. If it's also the identity on kWarp, we can +// transfer via warp-shuffles, and if it's the identity on kLane just have to +// reorder the registers +LinearLayout minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy) { + MLIRContext *ctx = srcTy.getContext(); + LinearLayout srcLayout = + toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + + auto comp = dstLayout.invertAndCompose(srcLayout); + // We try to quotient by the largest subspace first + auto dims = SmallVector{"block", "warp", "lane", "register"}; + for (auto dim : dims) { + auto quotient = comp.quotient(StringAttr::get(ctx, dim)); + if (!quotient.has_value()) { + break; + } + comp = *quotient; + } + return comp; +} + +bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) { + auto layout = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + auto outDims = to_vector(layout.getOutDimNames()); + return outDims.empty() || ArrayRef(outDims) == ArrayRef({kRegister}); +} + +bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { + auto layout = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + return to_vector(layout.getOutDimNames()) == + SmallVector{kRegister, kLane}; +} + +bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { + // TODO(jlebar): Remove these special cases `isMfmaToDotShortcut` once + // they're fully subsumed by the linear-layout checks. + return !cvtReordersRegisters(srcTy, dstTy) && + !(cvtNeedsWarpShuffle(srcTy, dstTy) && + getWarpLayoutConvertDecomposition(srcTy, dstTy)) && + !matchMmaV3AndDotOperandLayout(srcTy, dstTy) && + // to be removed when generalized warp shuffle conversions + // are ready: + !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy); +} + +bool atomicNeedsSharedMemory(Value value) { + auto type = value.getType(); + if (isa(type) || value.use_empty()) + return false; + return true; +} + +namespace { + +/// A data structure similar to SetVector but maintains +/// a deque instead of a vector to allow for efficient +/// push_back and pop_front operations. +/// Using SetVector doesn't suffice our needs because +/// it only pushes and pops from the back. +/// For example, if we have a queue like this: +/// 0->4 1->2->3 +/// ^-------- +/// where 3 depends on 4, once we pop 3, we found +/// 4 is not ready, so we check 2 and push 3 back +/// to the queue. +struct DFSSubgraphState { + DFSSubgraphState() : set(), deque() {} + DenseSet set; + std::deque deque; + + bool push_back(Operation *op) { + if (set.insert(op).second) { + deque.push_back(op); + return true; + } + return false; + } + + Operation *pop_front() { + Operation *op = deque.front(); + deque.pop_front(); + set.erase(op); + return op; + } + + bool empty() { return deque.empty(); } +}; + +/// DFS post-order implementation that maintains a global count to work across +/// multiple invocations, to help implement topological sort on multi-root DAGs. +/// We traverse all operations but only record the ones that appear in +/// `toSort` for the final result. +struct DFSState { + DFSState(const SetVector &set) : toSort(set), seen() {} + const SetVector &toSort; + SmallVector topologicalCounts; + DenseSet seen; + + /// We mark each op as ready if all its operands and parents ops are seen. If + /// an op is ready, we add it to the queue. Otherwise, we keep adding its + /// operands to the ancestors set. + /// We always want an op to be scheduled after all its parents to handle + /// correctly cases with scf operations. + void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph, + SmallVector &readyQueue) { + bool ready = true; + for (Value operand : op->getOperands()) { + auto def = operand.getDefiningOp(); + if (def && !seen.count(def)) { + subGraph.push_back(def); + ready = false; + } + } + Operation *parent = op->getParentOp(); + while (parent) { + if (!seen.count(parent)) { + subGraph.push_back(parent); + ready = false; + } + parent = parent->getParentOp(); + } + if (ready) + readyQueue.push_back(op); + } +}; + +void dfsPostorder(Operation *root, DFSState *state) { + DFSSubgraphState subGraph; + subGraph.push_back(root); + SmallVector ops; + while (!subGraph.empty()) { + // Nodes in the ready queue are ready to be processed. + // Meaning that either their operands are all seen or they have null + // operands. + SmallVector readyQueue; + auto *current = subGraph.pop_front(); + state->addToReadyQueue(current, subGraph, readyQueue); + while (!readyQueue.empty()) { + Operation *current = readyQueue.pop_back_val(); + if (!state->seen.insert(current).second) + continue; + ops.push_back(current); + for (Value result : current->getResults()) { + for (Operation *op : result.getUsers()) + state->addToReadyQueue(op, subGraph, readyQueue); + } + for (Region ®ion : current->getRegions()) { + for (Operation &op : region.getOps()) + state->addToReadyQueue(&op, subGraph, readyQueue); + } + } + } + + for (Operation *op : llvm::reverse(ops)) { + if (state->toSort.count(op) > 0) + state->topologicalCounts.push_back(op); + } +} + +} // namespace + +SetVector +multiRootTopologicalSort(const SetVector &toSort) { + if (toSort.empty()) { + return toSort; + } + + // Run from each root with global count and `seen` set. + DFSState state(toSort); + for (auto *s : toSort) { + assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); + dfsPostorder(s, &state); + } + + // Reorder and return. + SetVector res; + for (auto it = state.topologicalCounts.rbegin(), + eit = state.topologicalCounts.rend(); + it != eit; ++it) { + res.insert(*it); + } + return res; +} + +SetVector multiRootGetSlice(Operation *op, + TransitiveFilter backwardFilter, + TransitiveFilter forwardFilter) { + SetVector slice; + slice.insert(op); + + unsigned currentIndex = 0; + SetVector backwardSlice; + SetVector forwardSlice; + while (currentIndex != slice.size()) { + auto *currentOp = (slice)[currentIndex]; + // Compute and insert the backwardSlice starting from currentOp. + backwardSlice.clear(); + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = backwardFilter; + getBackwardSlice(currentOp, &backwardSlice, opt); + slice.insert(backwardSlice.begin(), backwardSlice.end()); + + // Compute and insert the forwardSlice starting from currentOp. + forwardSlice.clear(); + getForwardSlice(currentOp, &forwardSlice, forwardFilter); + slice.insert(forwardSlice.begin(), forwardSlice.end()); + ++currentIndex; + } + return multiRootTopologicalSort(slice); +} + +namespace { +// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis +// interacts with constant propagation, but SparseConstantPropagation +// doesn't seem to be sufficient. +class ConstantAnalysis : public DataFlowAnalysis { +public: + using DataFlowAnalysis::DataFlowAnalysis; + + LogicalResult initialize(Operation *top) override { + WalkResult result = top->walk([&](Operation *op) { + ProgramPoint programPoint(op); + if (failed(visit(&programPoint))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return success(!result.wasInterrupted()); + } + + LogicalResult visit(ProgramPoint *point) override { + Operation *op = point->getOperation(); + Attribute value; + if (matchPattern(op, m_Constant(&value))) { + auto *constant = getOrCreate>( + op->getResult(0)); + propagateIfChanged(constant, constant->join(dataflow::ConstantValue( + value, op->getDialect()))); + return success(); + } + // Dead code analysis requires every operands has initialized ConstantValue + // state before it is visited. + // https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322 + // That's why we need to set all operands to unknown constants. + setAllToUnknownConstants(op->getResults()); + for (Region ®ion : op->getRegions()) { + for (Block &block : region.getBlocks()) + setAllToUnknownConstants(block.getArguments()); + } + return success(); + } + +private: + /// Set all given values as not constants. + void setAllToUnknownConstants(ValueRange values) { + dataflow::ConstantValue unknownConstant(nullptr, nullptr); + for (Value value : values) { + auto *constant = + getOrCreate>(value); + propagateIfChanged(constant, constant->join(unknownConstant)); + } + } +}; +} // namespace + +std::unique_ptr createDataFlowSolver() { + auto solver = std::make_unique(); + solver->load(); + solver->load(); + return solver; +} + +static MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { + + if (auto makeTensorPtrOp = dyn_cast(op)) { + return makeTensorPtrOp; + } + + if (auto advanceOp = dyn_cast(op)) { + return getMakeTensorPtrOp(advanceOp.getPtr()); + } + + if (auto branch = dyn_cast(op)) { + auto idx = cast(v).getResultNumber(); + llvm::SmallVector yieldOps; + op->walk([&](Operation *op) { + if (auto yieldOp = dyn_cast(op)) + yieldOps.push_back(yieldOp); + }); + + // benzh@ if multi yields, all yields operand should come from same arg. + Value newValue = yieldOps[0].getOperands()[idx]; + return getMakeTensorPtrOp(newValue); + } + + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +MakeTensorPtrOp getMakeTensorPtrOp(Value v) { + using BranchOps = llvm::SetVector>; + llvm::DenseMap blockToCFOps; + auto moduleOp = + v.getParentBlock()->getParentOp()->getParentOfType(); + + moduleOp.walk([&](Operation *op) { + if (auto br = dyn_cast(op)) { + Block *block = br.getDest(); + blockToCFOps[block].insert({op, -1}); + } + if (auto condBr = dyn_cast(op)) { + Block *blockT = condBr.getTrueDest(); + Block *blockF = condBr.getFalseDest(); + blockToCFOps[blockT].insert({condBr, 1}); + blockToCFOps[blockF].insert({condBr, 0}); + } + }); + + if (Operation *definingOp = v.getDefiningOp()) + return getMakeTensorPtrOpImpl(definingOp, v); + + // If there is no defining op, v must be a BlockArgument. + BlockArgument arg = cast(v); + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + + if (auto forOp = dyn_cast(argOwner)) + return getMakeTensorPtrOp( + forOp.getOperand(argNum + forOp.getNumControlOperands() - 1)); + if (auto funcOp = dyn_cast(argOwner)) { + Block *block = arg.getOwner(); + Operation *op; + int tOrF; + std::tie(op, tOrF) = blockToCFOps[block][0]; + if (auto br = dyn_cast(op)) + return getMakeTensorPtrOp(br.getDestOperands()[argNum]); + if (auto condBr = dyn_cast(op)) + return getMakeTensorPtrOp(tOrF ? condBr.getTrueDestOperands()[argNum] + : condBr.getFalseDestOperands()[argNum]); + return getMakeTensorPtrOp(argOwner->getOperand(argNum)); + } + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/CMakeLists.txt b/third_party/enflame/include/triton/lib/CMakeLists.txt new file mode 100644 index 000000000..e8ae340f2 --- /dev/null +++ b/third_party/enflame/include/triton/lib/CMakeLists.txt @@ -0,0 +1,6 @@ +add_subdirectory(Analysis) +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(Target) +add_subdirectory(Tools) +add_subdirectory(Instrumentation) diff --git a/third_party/enflame/include/triton/lib/Conversion/CMakeLists.txt b/third_party/enflame/include/triton/lib/Conversion/CMakeLists.txt new file mode 100644 index 000000000..143a4375a --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonToTritonGPU) +add_subdirectory(TritonGPUToLLVM) diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp new file mode 100644 index 000000000..f69c32a72 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -0,0 +1,51 @@ +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_ALLOCATESHAREDMEMORY +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton::gpu + +namespace { +struct AllocateSharedMemory + : public mlir::triton::gpu::impl::AllocateSharedMemoryBase< + AllocateSharedMemory> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + ModuleAllocation allocation(mod); + + mod.walk([&](FunctionOpInterface funcOp) { + auto *funcAllocation = allocation.getFuncData(funcOp); + funcOp.walk([&](Operation *op) { + auto oBufferId = funcAllocation->getBufferId(op); + int offset = -1; + if (oBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(oBufferId); + else if (op->getNumResults() == 1) { + Value value = op->getResult(0); + auto vBufferId = funcAllocation->getBufferId(value); + if (vBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(vBufferId); + } + if (offset == -1) + return; + if (op->hasAttr("allocation.offset")) + return; + op->setAttr("allocation.offset", + IntegerAttr::get(IntegerType::get(ctx, 32), offset)); + }); + return WalkResult::skip(); + }); + mod->setAttr("ttg.shared", + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), + allocation.getSharedMemorySize())); + } +}; +} // namespace diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp new file mode 100644 index 000000000..88e9f1eda --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp @@ -0,0 +1,58 @@ +#include "mlir/IR/BuiltinOps.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_TRITONGPUALLOCATEWARPGROUPS +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton::gpu + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { +struct AllocateWarpGroups + : public mlir::triton::gpu::impl::TritonGPUAllocateWarpGroupsBase< + AllocateWarpGroups> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // Compute the total number of warps required at any given time. + int baseNumWarps = lookupNumWarps(mod); + int maxExtraWarps = 0; + mod.walk([&](WarpSpecializeOp op) { + ArrayRef arr = op.getPartitionNumWarps(); + int req = op.getTotalPartitionWarps(); + maxExtraWarps = std::max(maxExtraWarps, req); + + // Allocate the start IDs such that the largest warpgroups have lower + // starting warp IDs. + // FIXME: Handle aligning warp group IDs to 4 for TMEM. + SmallVector> idxAndSize; + for (auto [i, size] : llvm::enumerate(arr)) + idxAndSize.emplace_back(i, size); + llvm::sort(idxAndSize, + [&](auto lhs, auto rhs) { return lhs.second > rhs.second; }); + + SmallVector startIds(arr.size()); + int startId = baseNumWarps; + for (auto [i, size] : idxAndSize) { + startIds[i] = startId; + startId += size; + } + op.setWarpGroupStartIds(startIds); + }); + + if (auto totalNumWarps = + mod->getAttrOfType("ttg.total-num-warps")) { + if (maxExtraWarps == 0) + // There is no WarpSpecializeOp and ttg.total-num-warps is already set. + return; + } + Builder b(&getContext()); + mod->setAttr("ttg.total-num-warps", + b.getI32IntegerAttr(baseNumWarps + maxExtraWarps)); + } +}; +} // namespace diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp new file mode 100644 index 000000000..1a5e0809b --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -0,0 +1,103 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; + +struct AssertOpConversion : public ConvertOpToLLVMPattern { + explicit AssertOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + auto elems = unpackLLElements(loc, adaptor.getCondition(), rewriter); + auto elemTy = elems[0].getType(); + Value condition = b.int_val(elemTy.getIntOrFloatBitWidth(), 0); + for (auto elem : elems) { + if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) { + condition = b.or_( + condition, + b.icmp_eq(elem, rewriter.create( + loc, elemTy, rewriter.getZeroAttr(elemTy)))); + } else { + assert(false && "Unsupported type for assert"); + return failure(); + } + } + llAssert(op, condition, adaptor.getMessage(), rewriter); + if (isa(op.getCondition().getType())) { + // Add a barrier to avoid a race condition in case an assert is followed + // by an op that may trap if the assert condition is true. Since the + // tensor in those two operations may have different layout we need to + // make sure all the threads are done executing the assert before going to + // the next op. + b.barrier(); + } + rewriter.eraseOp(op); + return success(); + } + // op: the op at which the assert is inserted. Unlike printf, we need to + // know about the op to split the block. + void llAssert(Operation *op, Value condition, StringRef message, + ConversionPatternRewriter &rewriter) const { + + auto ctx = rewriter.getContext(); + auto loc = op->getLoc(); + + StringRef file = "unknown"; + StringRef func = "unknown"; + int line = 0; + int col = 0; + + while (auto callLoc = dyn_cast(loc)) + loc = callLoc.getCallee(); + + if (auto fileLineColLoc = dyn_cast(loc)) { + file = fileLineColLoc.getFilename(); + line = fileLineColLoc.getLine(); + col = fileLineColLoc.getColumn(); + } + + // #block1 + // if (condition) { + // #block2 + // __assertfail(message); + // } + // #block3 + Block *prevBlock = op->getBlock(); + + Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator()); + rewriter.setInsertionPointToStart(ifBlock); + targetInfo.assertFail(rewriter, loc, message, file, func, line); + + // Split a block after the call. + Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator()); + rewriter.setInsertionPointToEnd(ifBlock); + rewriter.create(loc, thenBlock); + rewriter.setInsertionPointToEnd(prevBlock); + rewriter.create(loc, condition, ifBlock, thenBlock); + rewriter.setInsertionPointToStart(thenBlock); + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateAssertOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..e134423d6 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,40 @@ +add_triton_library(TritonGPUToLLVM + DotOpToLLVM/FMA.cpp + DotOpToLLVM/FMADotUtility.cpp + AllocateSharedMemory.cpp + AllocateWarpGroups.cpp + AssertOpToLLVM.cpp + ControlFlowOpToLLVM.cpp + ConvertLayoutOpToLLVM.cpp + DecomposeUnsupportedConversions.cpp + ElementwiseOpToLLVM.cpp + FuncOpToLLVM.cpp + GatherOpToLLVM.cpp + GlobalScratchMemoryAllocation.cpp + HistogramOpToLLVM.cpp + MakeRangeOpToLLVM.cpp + MemoryOpToLLVM.cpp + PrintOpToLLVM.cpp + ReduceOpToLLVM.cpp + ScanOpToLLVM.cpp + SPMDOpToLLVM.cpp + TypeConverter.cpp + Utility.cpp + ViewOpToLLVM.cpp + + DEPENDS + TritonGPUConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRGPUDialect + MLIRGPUToNVVMTransforms + MLIRGPUToROCDLTransforms + MLIRGPUTransforms + TritonAnalysis + TritonIR + TritonGPUIR + TritonGPUTransforms + TritonNvidiaGPUTransforms +) diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp new file mode 100644 index 000000000..97ac574a5 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -0,0 +1,161 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (funcOp->hasAttr("nvvm.kernel")) { + // A GPU kernel + if (op.getNumOperands() > 0) { + return rewriter.notifyMatchFailure( + op, "Kernel functions do not support return with operands"); + } + rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), + op->getAttrs()); + } else { + // A device function + LLVM::ReturnOp newOp; + if (adaptor.getOperands().size() < 2) { + // Single or no return value. + newOp = + rewriter.create(op.getLoc(), adaptor.getOperands()); + } else { + // Pack the results into a struct. + auto packedResultsTy = this->getTypeConverter()->packFunctionResults( + funcOp.getResultTypes()); + Value packedResults = + rewriter.create(op.getLoc(), packedResultsTy); + for (auto it : llvm::enumerate(adaptor.getOperands())) { + packedResults = b.insert_val(packedResultsTy, packedResults, + it.value(), it.index()); + } + newOp = rewriter.create(op.getLoc(), packedResults); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + } + return success(); + } +}; + +// CallOpInterfaceLowering is adapted from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 +struct CallOpConversion : public ConvertOpToLLVMPattern { + CallOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); + auto newCallOp = + convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); + if (!newCallOp) + return failure(); + auto results = getCallOpResults(callOp, newCallOp, rewriter); + rewriter.replaceOp(callOp, results); + return success(); + } + +private: + SmallVector + promoteOperands(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Get the last argument of the caller, which is the current stack pointer + // of shared memory and append it to the operands of the callOp. + auto loc = callOp.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto caller = callOp->getParentOfType(); + auto promotedOperands = this->getTypeConverter()->promoteOperands( + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); + if (!caller->hasAttr("allocation.offset")) { + auto base = LLVM::getStackPointer(rewriter, caller); + promotedOperands.push_back(base); + } else { + auto base = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, callOp); + promotedOperands.push_back(base); + } + + auto opOffsetAttr = callOp->getAttrOfType( + "ttg.global_scratch_memory_offset"); + Value opOffsetVal; + if (opOffsetAttr) { + auto opOffset = opOffsetAttr.getValue().getZExtValue(); + opOffsetVal = b.i32_val(opOffset); + } + + promotedOperands.push_back( + LLVM::getGlobalScratchPtr(loc, rewriter, caller, opOffsetVal)); + return promotedOperands; + } + + LLVM::CallOp + convertCallOpToLLVMCallOp(triton::CallOp callOp, + ArrayRef promotedOperands, + ConversionPatternRewriter &rewriter) const { + // Pack the result types into a struct. + Type packedResult = nullptr; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + if (numResults != 0) { + if (!(packedResult = + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), + promotedOperands, callOp->getAttrs()); + newCallOp.getProperties().setOpBundleSizes( + rewriter.getDenseI32ArrayAttr({})); + newCallOp.getProperties().setOperandSegmentSizes( + {static_cast(promotedOperands.size()), 0}); + return newCallOp; + } + + SmallVector + getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, + ConversionPatternRewriter &rewriter) const { + auto numResults = callOp.getNumResults(); + SmallVector results; + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newCallOp.result_begin(), newCallOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(rewriter.create( + callOp.getLoc(), newCallOp->getResult(0), i)); + } + } + return results; + } + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateControlFlowOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp new file mode 100644 index 000000000..e3c945391 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -0,0 +1,704 @@ +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton::gpu; + +#if 0 +// XXX(Keren): A temporary knob to control the use of legacy MMA conversion +// because LinearLayout seems to have some performance issues. +constexpr bool useLegacyMMAConversion = false; + +struct ConvertLayoutOpConversion + : public ConvertOpToLLVMPattern { +public: + ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (isSupported(srcLayout, dstLayout)) { + return lowerDistributedToDistributed(op, adaptor, rewriter, targetInfo); + } + return failure(); + } + +private: + bool isSupported(Attribute srcLayout, Attribute dstLayout) const { + return isa( + srcLayout) && + isa( + dstLayout); + } + + // shared memory rd/st for blocked or mma layout with data padding + void processReplica(Location loc, ConversionPatternRewriter &rewriter, + bool stNotRd, RankedTensorType type, + ArrayRef numCTAsEachRep, + ArrayRef multiDimRepId, unsigned vec, + ArrayRef paddedRepShape, + ArrayRef origRepShape, + ArrayRef outOrd, SmallVector &vals, + Value smemBase) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto accumNumCTAsEachRep = product(numCTAsEachRep); + auto layout = type.getEncoding(); + auto rank = type.getRank(); + auto sizePerThread = getSizePerThread(layout); + auto accumSizePerThread = product(sizePerThread); + SmallVector numCTATiles(rank); + auto shapePerCTATile = getShapePerCTATile(layout); + auto shapePerCTA = getShapePerCTA(layout, type.getShape()); + auto order = getOrder(layout); + for (unsigned d = 0; d < rank; ++d) { + numCTATiles[d] = ceil(shapePerCTA[d], shapePerCTATile[d]); + } + auto elemTy = type.getElementType(); + bool isInt1 = elemTy.isInteger(1); + bool isPtr = isa(elemTy); + auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy); + if (isInt1) + elemTy = IntegerType::get(elemTy.getContext(), 8); + else if (isPtr) + elemTy = IntegerType::get(elemTy.getContext(), 64); + + auto llvmElemTy = getTypeConverter()->convertType(elemTy); + + for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) { + auto multiDimCTAInRepId = + getMultiDimIndex(ctaId, numCTAsEachRep, order); + SmallVector multiDimCTAId(rank); + for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) { + auto d = it.index(); + multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value(); + } + + auto linearCTAId = + getLinearIndex(multiDimCTAId, numCTATiles, order); + // TODO: This is actually redundant index calculation, we should + // consider of caching the index calculation result in case + // of performance issue observed. + for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { + SmallVector multiDimOffset = + LLVM::getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, + type, multiDimCTAInRepId, shapePerCTATile); + SmallVector multiDimOffsetWrapped = + LLVM::getWrappedMultiDimOffset(rewriter, loc, multiDimOffset, + origRepShape, shapePerCTATile, + shapePerCTA); + Value offset = LLVM::linearize(rewriter, loc, multiDimOffsetWrapped, + paddedRepShape, outOrd); + auto elemPtrTy = smemBase.getType(); + Value ptr = b.gep(elemPtrTy, llvmElemTy, smemBase, offset); + auto vecTy = vec_ty(llvmElemTy, vec); + if (stNotRd) { + Value valVec = b.undef(vecTy); + for (unsigned v = 0; v < vec; ++v) { + auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v]; + if (isInt1) + currVal = b.zext(llvmElemTy, currVal); + else if (isPtr) + currVal = b.ptrtoint(llvmElemTy, currVal); + valVec = b.insert_element(vecTy, valVec, currVal, b.i32_val(v)); + } + b.store(valVec, ptr); + } else { + Value valVec = b.load(vecTy, ptr); + for (unsigned v = 0; v < vec; ++v) { + Value currVal = b.extract_element(llvmElemTy, valVec, b.i32_val(v)); + if (isInt1) + currVal = b.icmp_ne( + currVal, rewriter.create( + loc, i8_ty, rewriter.getI8IntegerAttr(0))); + else if (isPtr) + currVal = b.inttoptr(llvmElemTyOrig, currVal); + vals[elemId + linearCTAId * accumSizePerThread + v] = currVal; + } + } + } + } + } + // blocked/mma -> blocked/mma. + // Data padding in shared memory to avoid bank conflict. + LogicalResult + lowerDistributedToDistributed(ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto typeConverter = getTypeConverter(); + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (product(srcTy.getShape()) == 1) { + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector outVals(getTotalElemsPerThread(dstTy), inVals[0]); + Value result = + packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + return success(); + } + + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto shape = dstTy.getShape(); + unsigned rank = dstTy.getRank(); + SmallVector numReplicates(rank); + SmallVector inNumCTAsEachRep(rank); + SmallVector outNumCTAsEachRep(rank); + SmallVector inNumCTAs(rank); + SmallVector outNumCTAs(rank); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout); + auto shapePerCTA = getShapePerCTA(srcLayout, shape); + + for (unsigned d = 0; d < rank; ++d) { + unsigned inPerCTA = + std::min(shapePerCTA[d], srcShapePerCTATile[d]); + unsigned outPerCTA = + std::min(shapePerCTA[d], dstShapePerCTATile[d]); + unsigned maxPerCTA = std::max(inPerCTA, outPerCTA); + numReplicates[d] = ceil(shapePerCTA[d], maxPerCTA); + inNumCTAsEachRep[d] = maxPerCTA / inPerCTA; + outNumCTAsEachRep[d] = maxPerCTA / outPerCTA; + assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0); + inNumCTAs[d] = ceil(shapePerCTA[d], inPerCTA); + outNumCTAs[d] = ceil(shapePerCTA[d], outPerCTA); + } + + // Potentially we need to store for multiple CTAs in this replication + auto accumNumReplicates = product(numReplicates); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); + unsigned inVec = scratchConfig.inVec; + unsigned outVec = scratchConfig.outVec; + const auto &paddedRepShape = scratchConfig.paddedRepShape; + const auto &origRepShape = scratchConfig.repShape; + + unsigned outElems = getTotalElemsPerThread(dstTy); + auto outOrd = getOrder(dstLayout); + SmallVector outVals(outElems); + + for (unsigned repId = 0; repId < accumNumReplicates; ++repId) { + auto multiDimRepId = + getMultiDimIndex(repId, numReplicates, outOrd); + if (repId != 0) { + insertBarrier(rewriter, op); + } + processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, + multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd, + vals, smemBase); + insertBarrier(rewriter, op); + processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, + multiDimRepId, outVec, paddedRepShape, origRepShape, + outOrd, outVals, smemBase); + } + + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct ConvertLayoutOpBlockedToDotOpShortcutConversion + : public ConvertOpToLLVMPattern { + const TargetInfoBase &targetInfo; + explicit ConvertLayoutOpBlockedToDotOpShortcutConversion( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + + const auto &shape = op.getType().getShape(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + auto dstDotEncoding = dyn_cast(dstTy.getEncoding()); + if (!dstDotEncoding) + return failure(); + if (!isa(srcTy.getEncoding()) || + !isa(dstDotEncoding.getParent())) + return failure(); + if (cvtNeedsSharedMemory(srcTy, dstTy)) + return failure(); + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } +}; +#endif + +struct ConvertLayoutOpUsingLinearLayoutsConversion + : public ConvertOpToLLVMPattern { + const TargetInfoBase &targetInfo; + + // Set benefit to 2 so that this pattern applies before other convert-layout + // conversions. TODO(jlebar): Eventually we want this to be the only pattern. + explicit ConvertLayoutOpUsingLinearLayoutsConversion( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + + const auto &shape = op.getType().getShape(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); + LinearLayout srcLayout = + toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); + + assert(to_vector(conversion.getInDimNames()) == + to_vector(conversion.getOutDimNames())); + auto dims = conversion.getInDimNames(); + if (llvm::is_contained(dims, kBlock)) { + // Case 1: Transfer between values in different CTAs. + // This requires moving values through distributed shared memory. + return rewriter.notifyMatchFailure( + op, "NYI: Transfer between different CTAs"); + } else if (llvm::is_contained(dims, kWarp)) { + // Case 2: Transfer between values in the same CTA, in which case we move + // values through shared memory. + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); + } else if (llvm::is_contained(dims, kLane)) { + // Case 3. Transfer between values in the same warp, in which case we try + // to move values using warp shuffles, though if the pattern is + // complicated enough we may fall back to using shared memory + if (auto decomposedCvt = + getWarpLayoutConvertDecomposition(srcTy, dstTy)) { + transferWithinWarp(op, *decomposedCvt, adaptor, rewriter); + return success(); + } + // TODO: Since data is only transferred within a warp over shared memory, + // we should use `bar.warp.sync` instead of `barrier`, which will improve + // latency when warps issue barriers on different cycles. + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); + } else if (llvm::is_contained(dims, kRegister)) { + // Case 4. Transfer between values in the same thread, in which case we + // simply reorder the elements of adaptor.getSrc(). + return transferWithinThread(op, conversion, adaptor, rewriter); + } else { + // Cast 5. The two layouts are equivalent. We should probably remove + // these in RemoveLayoutConversion. + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } + } + + LogicalResult + transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + StringAttr kRegister = str_attr("register"); + assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); + + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector outVals(conversion.getInDimSize(kRegister)); + for (int i = 0; i < outVals.size(); i++) { + auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; + outVals[i] = inVals[srcIdx]; + } + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + + LogicalResult transferWithinBlock(ConvertLayoutOp op, + const LinearLayout &srcLayout, + const LinearLayout &dstLayout, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + + assert(cvtNeedsSharedMemory(srcTy, dstTy)); + + SmallVector inVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + assert(!inVals.empty()); + + // We munge the input values by converting i (n<8) elements to i8 and + // pointers to i64. This is necessary because TargetInfo::loadDShared and + // storeDShared can't handle vectors of pointers or sub-byte elements. + auto elemTy = srcTy.getElementType(); + auto isSubByteInt = + elemTy.isInteger() && elemTy.getIntOrFloatBitWidth() < 8; + auto isPtr = isa(elemTy); + auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy); + if (isSubByteInt) + elemTy = IntegerType::get(elemTy.getContext(), 8); + else if (isPtr) + elemTy = IntegerType::get(elemTy.getContext(), 64); + auto llvmElemTy = getTypeConverter()->convertType(elemTy); + + // Munge input values + for (const auto &it : llvm::enumerate(inVals)) { + if (isSubByteInt) { + inVals[it.index()] = b.zext(llvmElemTy, it.value()); + } else if (isPtr) { + inVals[it.index()] = b.ptrtoint(llvmElemTy, it.value()); + } + } + + // Pretty sure this is the identity function ATM + // It'd be better to simply call `quotient({kBlock})` and + // remove kBlock from transferWithinBlockImpl + auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout); + auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout); + SmallVector outVals = + transferWithinBlockImpl(inVals, op, srcLayoutWithinBlock, + dstLayoutWithinBlock, adaptor, rewriter); + + // Unmunge output values + for (const auto &it : llvm::enumerate(outVals)) { + if (isSubByteInt) { + outVals[it.index()] = b.trunc(llvmElemTyOrig, it.value()); + } else if (isPtr) { + outVals[it.index()] = b.inttoptr(llvmElemTyOrig, it.value()); + } + } + + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + + // Use warp shuffles to implement a layout conversion where data only needs to + // be moved within warps. + void transferWithinWarp(ConvertLayoutOp op, + DecomposedWarpConversion decomposed, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + + SmallVector + transferWithinBlockImpl(ArrayRef inVals, ConvertLayoutOp op, + const LinearLayout &srcLayout, + const LinearLayout &dstLayout, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + StringAttr kOffset = str_attr("offset"); + StringAttr kIteration = str_attr("iteration"); + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + + auto scratchConfig = + getScratchConfigForCvt(op.getSrc().getType(), op.getType()); + auto tensorShapePerCTA = convertType(getShapePerCTA( + op.getSrc().getType().getEncoding(), op.getType().getShape())); + // Input dims: [offset, iteration, block] + // Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape + LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion( + ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order); + + // Layout for the store from registers to shared memory. + // + // Note: If two threads in the same warp write to the same shmem offset, the + // hardware resolves that without a stall or a bank conflict. Therefore we + // don't need to avoid duplicate writes. + // Input dims: [reg, lane, warp] + // Output dims: [offset, iteration] + bool isStMatrix = targetInfo.canUseStMatrix( + op.getSrc().getType(), scratchConfig.repShape, + scratchConfig.paddedRepShape, scratchConfig.order, + /*swizzleByteSize=*/0); + LinearLayout shmemStoreLayout = + isStMatrix ? chooseStMatrixLayout(ctx, op.getSrc().getType(), + /*swizzleByteSize=*/0) + : srcLayout.invertAndCompose(sharedLayout); + + const int shmemAllocatedNumElems = + getNumScratchElements(scratchConfig.paddedRepShape); + assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems); + + // Layout for the load from shmem to registers. + LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout); + + // Check that the `register` fully determines the `iteration`. That is, + // each thread does exactly the same reads and writes to shmem on each + // iteration, just with different input/output registers. + assert( + shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); + assert( + shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); + + // iteration -> registers + SmallVector> inRegsForIter = + collectRegsForIter(ctx, shmemStoreLayout); + SmallVector> outRegsForIter = + collectRegsForIter(ctx, shmemLoadLayout); + + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto sharedPtrTy = smemBase.getType(); + Type elemTy = inVals[0].getType(); + auto outSize = shmemLoadLayout.getInDimSize(kRegister); + auto iterations = sharedLayout.getInDimSize(kIteration); + assert(scratchConfig.inVec * iterations <= inVals.size()); + assert(scratchConfig.outVec * iterations <= outSize); + + // Check only one dimension has been padded. + // This means the difference between the padded shape and the original shape + // should only be in one dimension, specifically in + // `scratchConfig.order[0]`. + auto rank = scratchConfig.repShape.size(); + for (auto i = 0; i < rank; i++) { + if (i == scratchConfig.order[0]) { + continue; + } + assert(scratchConfig.repShape[i] == scratchConfig.paddedRepShape[i]); + } + auto paddedStride = scratchConfig.repShape[scratchConfig.order[0]]; + auto paddedSize = + scratchConfig.paddedRepShape[scratchConfig.order[0]] - paddedStride; + + // Linear layout function is split in two parts below: + // + // L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0) + // offset = regBase xor regIdx + // + // It is the same hack as what we've done in the emitIndices function to get + // around performance issues on AMD GPUs + auto getVecAddr = [&](LinearLayout &layout, Value ®Base, + int regSlice) -> Value { + auto regIdx = layout + .apply({{kRegister, regSlice}, + {kLane, 0}, + {kWarp, 0}, + {kBlock, 0}})[0] + .second; + Value offset = b.xor_(regBase, b.i32_val(regIdx)); + if (paddedSize > 0) { + assert(llvm::isPowerOf2_32(paddedStride)); + assert(llvm::isPowerOf2_32(paddedSize)); + auto rshiftVal = llvm::Log2_32(paddedStride); + auto lshiftVal = llvm::Log2_32(paddedSize); + offset = b.add( + b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)), + offset); + } + auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset); + vecAddr.setInbounds(true); + return vecAddr; + }; + + auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout, + {{kRegister, b.i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, b.i32_val(0)}})[0] + .second; + auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout, + {{kRegister, b.i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, b.i32_val(0)}})[0] + .second; + // register idx -> Value + llvm::MapVector outVals; + for (int i = 0; i < iterations; i++) { + if (i != 0) + insertBarrier(rewriter, op); + + auto &inRegs = inRegsForIter[i]; + auto &outRegs = outRegsForIter[i]; + + // When using `stmatrix`, we can store `inVec` elements even if they are + // not contiguous + auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut() + : scratchConfig.inVec; + for (int j = 0; j < inVals.size() / iterations; j += inVec) { + auto inRegSlice = inRegs[j]; + Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice); + SmallVector inValsVec; + for (int k = 0; k < inVec; k++) + inValsVec.push_back(inVals[inRegSlice + k]); + Value valsVec = packLLVector(loc, inValsVec, rewriter); + if (isStMatrix) { + targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); + } else { + targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec, + /*pred=*/b.true_val()); + } + } + + insertBarrier(rewriter, op); + + for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) { + auto outRegSlice = outRegs[j]; + auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice); + Value valsVec = + targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt, + vec_ty(elemTy, scratchConfig.outVec), + /*pred=*/b.true_val()); + for (Value v : unpackLLVector(loc, valsVec, rewriter)) + outVals[outRegSlice++] = v; + } + } + + SmallVector outValsVec; + for (size_t i = 0; i < outVals.size(); i++) + outValsVec.push_back(outVals[i]); + return outValsVec; + } + + // Determine which registers are read/written in which iteration of the shmem + // transfer specified by `layout`. + SmallVector /*registers*/> + collectRegsForIter(MLIRContext *ctx, const LinearLayout &layout) const { + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + StringAttr kIteration = str_attr("iteration"); + + // The choice of iteration should be determined only by the register. That + // is, it should be correct to split the register dimension into iterations. + assert(layout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); + + LinearLayout sublayout = layout.sublayout({kRegister}, {kIteration}); + SmallVector> ret(sublayout.getOutDimSize(kIteration)); + for (int reg = 0; reg < sublayout.getInDimSize(kRegister); reg++) { + auto idx = sublayout.apply({{kRegister, reg}}); + ret[idx.begin()->second].push_back(reg); + } + return ret; + } +}; + +} // namespace + +void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp( + ConvertLayoutOp op, DecomposedWarpConversion decomposed, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); + auto [P1, Cp, P2inv, reducedP1, reducedP2] = std::move(decomposed); + + // Grab the source elements and prepare the outputs of just the shuffles. + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector shflOuts(Cp.getInDimSize(kRegister)); + + Value laneId = getLaneId(rewriter, loc); + + // Emit one shuffle per destination register. + for (int i : llvm::seq(shflOuts.size())) { + // 'Cp' maps a (dst_lane, dst_reg) -> (src_lane, src_reg), and we know that + // for a register, it does not map to different registers in the same lane. + // At the same time, for each register, P1 returns the source value index + // to provide as the shuffle value. + auto out = applyLinearLayout(loc, rewriter, P1, + {{kLane, laneId}, {kRegister, b.i32_val(i)}}); + assert(out.size() == 1); + Value srcRegIdx = out.front().second; + // The size of the input lane dimension is the number of selects to emit. + // TODO(jeff): For dtypes smaller than i32, we can use byte permutes and + // shuffle multiple values at a time. + Value shflSrc = b.undef(srcValues.front().getType()); + for (int j : llvm::seq(reducedP1.getInDimSize(kLane))) { + int32_t check = + reducedP1.apply({{kLane, j}, {kRegister, i}}).front().second; + shflSrc = b.select(b.icmp_eq(srcRegIdx, b.i32_val(check)), + srcValues[check], shflSrc); + } + + out = applyLinearLayout(loc, rewriter, Cp, + {{kLane, laneId}, {kRegister, b.i32_val(i)}}); + assert(out.size() == 1); + Value shflIdx = out.front().second; + shflOuts[i] = targetInfo.shuffleIdx(rewriter, loc, shflSrc, shflIdx); + } + + // Finally, we just need to apply P2 to the shflOuts to permute the registers + // into their final form. Use the same trick to reduce the number of emitted + // selects. + SmallVector results(shflOuts.size()); + for (int i : llvm::seq(results.size())) { + Value result = b.undef(srcValues.front().getType()); + + auto out = applyLinearLayout(loc, rewriter, P2inv, + {{kLane, laneId}, {kRegister, b.i32_val(i)}}); + Value resultIdx = out.front().second; + for (int j : llvm::seq(reducedP2.getInDimSize(kLane))) { + int32_t check = + reducedP2.apply({{kLane, j}, {kRegister, i}}).front().second; + result = b.select(b.icmp_eq(resultIdx, b.i32_val(check)), shflOuts[check], + result); + } + results[i] = result; + } + + Value result = + packLLElements(loc, getTypeConverter(), results, rewriter, op.getType()); + rewriter.replaceOp(op, result); +} + +void mlir::triton::populateConvertLayoutOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add( + typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp new file mode 100644 index 000000000..efc391de6 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -0,0 +1,96 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Patterns.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +static void addAttrs(Operation *op, ArrayRef attrs) { + for (const NamedAttribute attr : attrs) + op->setAttr(attr.getName(), attr.getValue()); +} + +} // namespace + +namespace mlir::triton::gpu { + +void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, + ShortcutFn shortcutFn) { + MLIRContext *ctx = module.getContext(); + int numCTAs = TritonGPUDialect::getNumCTAs(module); + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module); + + module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcMma = dyn_cast(srcType.getEncoding()); + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (!srcMma || !dstDotOp || shortcutFn(srcType, dstType)) + return; + + int numWarps = lookupNumWarps(cvtOp); + auto enc = BlockedEncodingAttr::get( + ctx, srcType.getShape(), getContigPerThread(srcType), getOrder(srcType), + numWarps, threadsPerWarp, numCTAs); + auto tmpType = RankedTensorType::get(dstType.getShape(), + dstType.getElementType(), enc); + + auto tmp = builder.create(cvtOp.getLoc(), tmpType, + cvtOp.getSrc()); + addAttrs(tmp, cvtOp->getAttrs()); + auto newConvert = + builder.create(cvtOp.getLoc(), dstType, tmp); + addAttrs(newConvert, cvtOp->getAttrs()); + + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + }); +} + +void decomposeBlockedToDotLayoutConversion(ModuleOp module) { + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); + + module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + if (!cvtNeedsSharedMemory(srcType, dstType)) + return; + auto srcBlocked = + dyn_cast(srcType.getEncoding()); + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (srcBlocked && dstDotOp) { + auto dotParent = dyn_cast(dstDotOp.getParent()); + if (dotParent) { + return; + } + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); + auto tmpType = MemDescType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::SwizzledSharedEncodingAttr::get( + module.getContext(), dstDotOp, srcType.getShape(), + srcBlocked.getOrder(), srcBlocked.getCTALayout(), + srcType.getElementType()), + sharedMemorySpace); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getSrc()); + addAttrs(tmp, cvtOp->getAttrs()); + auto newConvert = builder.create(cvtOp.getLoc(), + dstType, tmp); + addAttrs(newConvert, cvtOp->getAttrs()); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + } + }); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp new file mode 100644 index 000000000..0d6a0cad3 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -0,0 +1,38 @@ +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace ::mlir::triton::gpu; + +namespace { +class GenericFMAVectorMultiplier : public FMAVectorMultiplier { + OpBuilder &builder; + Location loc; + +public: + GenericFMAVectorMultiplier(OpBuilder &builder, Location loc) + : builder(builder), loc(loc) {} + + Value multiplyVectors(ArrayRef a, ArrayRef b, + Value c) override { + auto K = a.size(); + assert(b.size() == K); + Value accum = c; + for (auto [aElem, bElem] : llvm::zip(a, b)) + accum = builder.create(loc, aElem, bElem, accum); + return accum; + } +}; + +} // namespace + +LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + GenericFMAVectorMultiplier multiplier(rewriter, loc); + return parametricConvertFMADot(op, adaptor, typeConverter, rewriter, + multiplier); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp new file mode 100644 index 000000000..fa2c81472 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -0,0 +1,170 @@ +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; + +namespace { + +/// OperandValueKey structure represents compile time part +/// of spatial coordinates of a value in a tensor. +/// +/// Every Value spatial coordinates(i.e. [batch;nonK;k]) in tensor can be +/// defined as: +/// +/// batch = (bRepIdx * CTABSize + bIdx) + (laneBCoord + warpBCoord) +/// nonK = (nonKRepIdx * CTANKSize + nonKIdx) + (laneNonKCoord + warpNonKCoord) +/// k = kIdx +/// +/// Where: +/// CTABSize, CTANKSize: constants; +/// laneBCoord, warpBCoord, laneNonKCoord, warpNonKCoord: runtime components; +/// bRepIdx, nonKRepIdx, bIdx, nonKIdx, kIdx: compile time components. +struct OperandValueKey { + unsigned bRepIdx, nonKRepIdx; + unsigned bIdx, nonKIdx, kIdx; + + bool operator==(const OperandValueKey &other) const { + return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx && + bIdx == other.bIdx && nonKIdx == other.nonKIdx && + kIdx == other.kIdx); + } +}; + +} // namespace + +template <> struct std::hash { + std::size_t operator()(const OperandValueKey &k) const { + return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx, + k.kIdx); + } +}; + +namespace { + +using ValueTableFMA = std::unordered_map; + +ValueTableFMA getValueTableFromStructFMA( + Value val, ArrayRef perRepShape, ArrayRef repetitions, + unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter, + Location loc, ArrayRef inRepOrder, ArrayRef repOrder) { + ValueTableFMA res; + auto elems = unpackLLElements(loc, val, rewriter); + assert(perRepShape.size() == 3); + auto numElemsRep = product(perRepShape); + assert(elems.size() == numElemsRep * product(repetitions)); + assert(kDim == 1 || kDim == 2); + assert(nonKDim == 1 || nonKDim == 2); + const unsigned bDim = 0; + + for (unsigned idx = 0; idx < elems.size(); ++idx) { + auto inRepLinearIdx = idx % numElemsRep; + auto repLinearIdx = idx / numElemsRep; + auto inRepSpatialIdx = + mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder); + auto repSpatialIdx = + mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder); + OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim], + inRepSpatialIdx[0], inRepSpatialIdx[nonKDim], + inRepSpatialIdx[kDim]}; + res[key] = elems[idx]; + } + return res; +} + +} // namespace + +namespace mlir::triton::gpu { + +LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FMAVectorMultiplier &multiplier) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + auto A = op.getA(); + auto D = op.getResult(); + + auto aTensorTy = cast(A.getType()); + auto dTensorTy = cast(D.getType()); + + SmallVector aShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy))); + auto dShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy))); + + BlockedEncodingAttr dLayout = + cast(dTensorTy.getEncoding()); + // TODO process A and B operand separately + auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder()); + auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder()); + auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); + + Value llA = adaptor.getA(); + Value llB = adaptor.getB(); + + auto sizePerThread = getContigPerThread(dTensorTy); + auto numElemsPerThread = product(sizePerThread); + SmallVector shapePerCTATile; + for (auto [reg, thread, warp] : + llvm::zip(sizePerThread, dLayout.getThreadsPerWarp(), + dLayout.getWarpsPerCTA())) { + shapePerCTATile.push_back(reg * thread * warp); + } + shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(shapePerCTATile)); + sizePerThread = expandMatrixShapeWithBatch(ArrayRef(sizePerThread)); + + unsigned K = aShapePerCTA[2]; + + unsigned threadTileShape[3]; + unsigned repetitions[3]; + for (int i = 0; i < 3; ++i) { + repetitions[i] = + ceil(dShapePerCTA[i], static_cast(shapePerCTATile[i])); + } + + auto has = getValueTableFromStructFMA( + llA, {sizePerThread[0], sizePerThread[1], K}, + {repetitions[0], repetitions[1], 1}, + /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder); + auto hbs = getValueTableFromStructFMA( + llB, {sizePerThread[0], K, sizePerThread[2]}, + {repetitions[0], 1, repetitions[2]}, + /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder); + + SmallVector acc = cc; + + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) + for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) + for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) + for (unsigned b = 0; b < sizePerThread[0]; ++b) + for (unsigned m = 0; m < sizePerThread[1]; ++m) + for (unsigned n = 0; n < sizePerThread[2]; ++n) { + SmallVector multiDimAccumIdx = {b, m, n}; + unsigned linearInRepIdx = + LLVM::linearize(multiDimAccumIdx, sizePerThread, inRepOrder); + SmallVector multiDimRepIdx = {bRep, mRep, nRep}; + unsigned linearRepIdx = + LLVM::linearize(multiDimRepIdx, repetitions, repOrder); + unsigned linearAccumIdx = + linearInRepIdx + linearRepIdx * numElemsPerThread; + + SmallVector aOpVector; + SmallVector bOpVector; + + for (unsigned k = 0; k < K; ++k) { + aOpVector.push_back(has.at({bRep, mRep, b, m, k})); + bOpVector.push_back(hbs.at({bRep, nRep, b, n, k})); + } + + acc[linearAccumIdx] = multiplier.multiplyVectors( + aOpVector, bOpVector, acc[linearAccumIdx]); + } + + auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); + rewriter.replaceOp(op, res); + + return success(); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp new file mode 100644 index 000000000..4155bccf9 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -0,0 +1,665 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir::triton::gpu; + +namespace mlir::triton::gpu { + +Type getElementType(Value value) { + auto type = value.getType(); + if (auto tensorType = dyn_cast(type)) + return tensorType.getElementType(); + return type; +} + +int getNumElementsPerThreads(Type type, + const LLVMTypeConverter *typeConverter) { + int numElemsPerThread = 1; + if (auto tensorTy = dyn_cast(type)) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (structType) + numElemsPerThread = structType.getBody().size(); + } + return numElemsPerThread; +} + +} // namespace mlir::triton::gpu + +namespace { +struct AddPtrOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto resultTy = op.getType(); + auto typeConverter = getTypeConverter(); + auto resultTensorTy = dyn_cast(resultTy); + if (resultTensorTy) { + unsigned elems = getTotalElemsPerThread(resultTy); + Type elemTy = typeConverter->convertType( + cast(resultTensorTy.getElementType()).getPointeeType()); + Type ptrTy = typeConverter->convertType(resultTensorTy.getElementType()); + auto ptrs = unpackLLElements(loc, adaptor.getPtr(), rewriter); + auto offsets = unpackLLElements(loc, adaptor.getOffset(), rewriter); + SmallVector resultVals(elems); + for (unsigned i = 0; i < elems; ++i) { + resultVals[i] = b.gep(ptrTy, elemTy, ptrs[i], offsets[i]); + } + Value view = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, view); + } else { + assert(isa(resultTy)); + auto resultPtrTy = typeConverter->convertType(resultTy); + auto resultElemTy = typeConverter->convertType( + cast(resultTy).getPointeeType()); + Value result = b.gep(resultPtrTy, resultElemTy, adaptor.getPtr(), + adaptor.getOffset()); + rewriter.replaceOp(op, result); + } + return success(); + } +}; + +struct CmpIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, + MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create( + loc, elemTy, ArithCmpIPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::ICmpPredicate + ArithCmpIPredicateToLLVM(arith::CmpIPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__) \ + case arith::CmpIPredicate::item__: \ + return LLVM::ICmpPredicate::item__ + + __PRED_ENUM(eq); + __PRED_ENUM(ne); + __PRED_ENUM(sgt); + __PRED_ENUM(sge); + __PRED_ENUM(slt); + __PRED_ENUM(sle); + __PRED_ENUM(ugt); + __PRED_ENUM(uge); + __PRED_ENUM(ult); + __PRED_ENUM(ule); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpIPredicate"); + } +}; + +struct CmpFOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + static SmallVector + createDestOps(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Type elemTy, + MultipleOperandsRange operands, Location loc) { + return {rewriter.create( + loc, elemTy, ArithCmpFPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::FCmpPredicate + ArithCmpFPredicateToLLVM(arith::CmpFPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__, item1__) \ + case arith::CmpFPredicate::item__: \ + return LLVM::FCmpPredicate::item1__ + + __PRED_ENUM(OEQ, oeq); + __PRED_ENUM(ONE, one); + __PRED_ENUM(OGT, ogt); + __PRED_ENUM(OGE, oge); + __PRED_ENUM(OLT, olt); + __PRED_ENUM(OLE, ole); + __PRED_ENUM(ORD, ord); + __PRED_ENUM(UEQ, ueq); + __PRED_ENUM(UGT, ugt); + __PRED_ENUM(UGE, uge); + __PRED_ENUM(ULT, ult); + __PRED_ENUM(ULE, ule); + __PRED_ENUM(UNE, une); + __PRED_ENUM(UNO, uno); + __PRED_ENUM(AlwaysTrue, _true); + __PRED_ENUM(AlwaysFalse, _false); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpFPredicate"); + } +}; + +struct MulhiUIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + explicit MulhiUIOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(MulhiUIOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + + Type resultElementTy = getElementTypeOrSelf(op.getResult().getType()); + assert(resultElementTy.isInteger(32) || resultElementTy.isInteger(64)); + + auto funcName = targetInfo.getMulhiFuncName(resultElementTy); + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + return { + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +struct ExternElementwiseOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + SmallVector createDestOps(ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + StringRef funcName = op.getSymbol(); + if (funcName.empty()) + llvm::errs() << "ExternElementwiseOpConversion"; + + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp( + rewriter, op, funcName, funcType, op.getLibname(), op.getLibpath()); + return { + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; + } +}; + +struct ElementwiseInlineAsmOpConversion + : public ConvertOpToLLVMPattern { + using Base = ConvertOpToLLVMPattern; + + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + // If operand size is smaller than 32 bits, pack in groups of 32 bits. + SmallVector packOperands(ElementwiseInlineAsmOp op, + MultipleOperandsRange operands, + ConversionPatternRewriter &rewriter, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector packedOperands; + unsigned numPackedElements = op.getPackedElement(); + for (int i = 0, e = op.getNumOperands(); i < e; i++) { + Type elemTy = getElementType(op.getOperand(i)); + unsigned bitWidth = + elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : 64; + unsigned numElementPerReg = bitWidth < 32 ? 32 / bitWidth : 1; + numElementPerReg = std::min(numElementPerReg, numPackedElements); + for (int j = 0; j < numPackedElements; j += numElementPerReg) { + if (numElementPerReg == 1) { + packedOperands.push_back(operands[j][i]); + continue; + } + Type t = + vec_ty(getTypeConverter()->convertType(elemTy), numElementPerReg); + Value packed = b.undef(t); + for (int k = 0; k < numElementPerReg; k++) { + packed = b.insert_element(packed, operands[j + k][i], b.i32_val(k)); + } + packedOperands.push_back(packed); + } + } + return packedOperands; + } + + SmallVector> + createDestOps(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + MultipleOperandsRange operands, Location loc) const { + auto ctx = op->getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + if (operands.size() % op.getPackedElement() != 0) + llvm::report_fatal_error("Inline asm op has more packed elements than " + "number of elements per thread."); + + // Pack elems smaller than 32 bits into 32-bit registers. + SmallVector packedOperands = + packOperands(op, operands, rewriter, loc); + + // Types returned by the LLVM asm op. If there's more than one, they'll be + // wrapped in a struct. + SmallVector asmRetTypes; + for (auto result : op.getResult()) { + auto ty = getTypeConverter()->convertType(getElementType(result)); + + // Pack return elements into 32-bits. + unsigned bitWidth = ty.isIntOrFloat() ? ty.getIntOrFloatBitWidth() : 64; + unsigned numElemsPerReg = + std::min(bitWidth < 32 ? 32 / bitWidth : 1, op.getPackedElement()); + assert(op.getPackedElement() % numElemsPerReg == 0); + if (numElemsPerReg > 1) { + ty = vec_ty(ty, numElemsPerReg); + } + for (unsigned i = 0; i < op.getPackedElement() / numElemsPerReg; i++) { + asmRetTypes.push_back(ty); + } + } + Type asmRetType = + asmRetTypes.size() > 1 ? struct_ty(asmRetTypes) : asmRetTypes[0]; + + Value asmResults = + rewriter + .create( + loc, asmRetType, + /*operands=*/packedOperands, + /*asm_string=*/op.getAsmString(), + /*constraints=*/op.getConstraints(), + /*has_side_effects=*/!op.getPure(), + /*is_align_stack=*/false, + /*asm_dialect=*/ + LLVM::AsmDialectAttr::get(rewriter.getContext(), + LLVM::AsmDialect::AD_ATT), + /*operand_attrs=*/ArrayAttr()) + ->getResult(0); + + // asmResults is a flat struct; pack its values into + // [return_value][op.getPackedElement()]. + SmallVector> ret(op->getNumResults()); + int structIdx = 0; + for (int i = 0; i < op->getNumResults(); i++) { + for (int j = 0; j < op.getPackedElement(); j++) { + Value val; + if (asmRetTypes.size() > 1) { + val = b.extract_val(asmResults, structIdx++); + } else { + val = asmResults; + } + if (auto vectorTy = dyn_cast(val.getType())) { + for (int k = 0; k < vectorTy.getNumElements(); k++) { + ret[i].push_back(b.extract_element(val, b.i32_val(k))); + } + j += vectorTy.getNumElements() - 1; + } else { + ret[i].push_back(val); + } + } + } + return ret; + } + + LogicalResult + matchAndRewrite(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // Layout is unpackedOperands[operand][elem]. + SmallVector> unpackedOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto subOperands = unpackLLElements(loc, operand, rewriter); + unpackedOperands.push_back(subOperands); + } + + int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), + getTypeConverter()); + + // These are checked by the verifier, so we don't need to raise a nice + // error. + assert(all_of(unpackedOperands, [&](auto &operands) { + return operands.size() == numElemsPerThread; + })); + if (numElemsPerThread % op.getPackedElement() != 0) { + // Pad with the undef for each operand to have a multiple of + // op.getPackedElement() elements. + int numPaddedValue = + op.getPackedElement() - numElemsPerThread % op.getPackedElement(); + for (auto &operands : unpackedOperands) { + for (int i = 0; i < numPaddedValue; i++) { + operands.push_back(b.undef(operands[0].getType())); + } + } + } + + // Run the inline asm op on each block of elements. + // + // Layout is unpackedResults[result_idx][elem]. + // + // This loop always runs at least once, even when the asm has no input + // elements. + SmallVector> unpackedResults(op->getNumResults()); + for (unsigned i = 0; i < numElemsPerThread; i += op.getPackedElement()) { + // Block of elements to process with one call to the inline asm. This is + // ordered opposite `unpackedResults`: The outer dim is + // op.getPackedElement(), and the inner dim is the operand. + SmallVector> block(op.getPackedElement()); + for (auto &os : unpackedOperands) { + for (int j = 0; j < op.getPackedElement(); j++) { + block[j].push_back(os[i + j]); + } + } + auto cur = createDestOps(op, adaptor, rewriter, block, loc); + assert(cur.size() == unpackedResults.size()); + for (unsigned j = 0; j < cur.size(); j++) { + unpackedResults[j].insert(unpackedResults[j].end(), cur[j].begin(), + cur[j].end()); + } + } + for (auto &results : unpackedResults) { + results.resize(numElemsPerThread); + } + // Reorder and pack the results. + SmallVector outs; + for (int i = 0; i < unpackedResults.size(); i++) { + outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i], + rewriter, op->getResult(i).getType())); + } + + rewriter.replaceOp(op, outs); + return success(); + } +}; + +struct AbsIOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0][0], + /*is_int_min_poison=*/false)}; + } +}; + +struct AbsFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (llvm::isa(elemTy)) { + // Mask out the sign bit + auto num_bits = + getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); + assert(num_bits <= 16); + auto mask = (1u << (num_bits - 1u)) - 1u; + auto maskAttr = rewriter.getIntegerAttr(elemTy, mask); + auto maskConst = rewriter.create(loc, maskAttr); + return {b.and_(operands[0][0], maskConst)}; + } + + return {rewriter.create(loc, elemTy, operands[0][0])}; + } +}; + +struct SelectOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + std::array llvmOperands; + if (operands[0].size() == 2) { + // Case of scalar condition with tensor operands. + assert(op.getCondition().getType().isInteger(1)); + llvmOperands = {adaptor.getCondition(), operands[0][0], operands[0][1]}; + } else { + llvmOperands = {operands[0][0], operands[0][1], operands[0][2]}; + } + return {rewriter.create( + loc, llvmOperands[1].getType(), llvmOperands, + adaptor.getAttributes().getValue())}; + } +}; +template +struct MinMaxFOpConversion + : ElementwiseOpConversionBase> { + using Base = ElementwiseOpConversionBase>; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + static_assert(std::is_same::value || + std::is_same::value, + "OpTy must be arith::MinimumFOp or arith::MaximumFOp"); + + // Choose the destination op based on the OpTy. + using DestOpNanProp = + typename std::conditional::value, + LLVM::MinimumOp, LLVM::MaximumOp>::type; + using DestOpNoNanProp = + typename std::conditional::value, + LLVM::MinNumOp, LLVM::MaxNumOp>::type; + + explicit MinMaxFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + bool hwNanPropagationSupported, + PatternBenefit benefit = 1) + : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, + benefit), + hwNanPropagationSupported(hwNanPropagationSupported) {} + + SmallVector createDestOps(OpTy op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + if (hwNanPropagationSupported) { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } + // Handle workaround for NaN propagation, i.e. software emulation of NaN + // propagation. If any of the operands is NaN, return NaN. + auto lhs = operands[0][0]; + auto rhs = operands[0][1]; + auto lhsIsNan = + rewriter.create(loc, LLVM::FCmpPredicate::une, lhs, lhs); + auto rhsIsNan = + rewriter.create(loc, LLVM::FCmpPredicate::une, rhs, rhs); + auto isNan = rewriter.create(loc, lhsIsNan, rhsIsNan); + auto nonNanRes = rewriter.create(loc, elemTy, lhs, rhs); + + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + + // Select the result based on the isNan flag. + return {rewriter.create(loc, isNan, nan, nonNanRes)}; + } + +private: + bool hwNanPropagationSupported; +}; + +struct ClampFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit ClampFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + // Clip pattern not found, use min/max. + if (op.getPropagateNan() == PropagateNan::ALL) { + if (targetInfo.supportMaximumMinimum()) { + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create(loc, v, operands[0][2])}; + } + // On pre-80 compute capability, we need to handle NaN propagation + // manually. We need to check only the first operand for clamp. + auto lhs = operands[0][0]; + auto isNan = rewriter.create(loc, LLVM::FCmpPredicate::une, + lhs, lhs); + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + auto nonNanRes = rewriter.create(loc, v, operands[0][2]); + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + // Select the result based on the isNan flag. + return {rewriter.create(loc, isNan, nan, nonNanRes)}; + } + + // No NaN propagation. + assert(op.getPropagateNan() == PropagateNan::NONE); + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create(loc, v, operands[0][2])}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMinMaxFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, bool hwNanPropagationSupported, + PatternBenefit benefit) { + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); +} + +void mlir::triton::populateClampFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); +} + +void mlir::triton::populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { +#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp) + POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp) + POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp) + POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp) + POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp) + POPULATE_UNARY_OP(math::FloorOp, math::FloorOp) + POPULATE_UNARY_OP(math::CeilOp, math::CeilOp) + POPULATE_UNARY_OP(math::LogOp, math::LogOp) + POPULATE_UNARY_OP(math::Log2Op, math::Log2Op) + POPULATE_UNARY_OP(math::CosOp, math::CosOp) + POPULATE_UNARY_OP(math::SinOp, math::SinOp) + POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp) + POPULATE_UNARY_OP(math::RsqrtOp, math::RsqrtOp) + POPULATE_UNARY_OP(math::ExpOp, math::ExpOp) + POPULATE_UNARY_OP(math::Exp2Op, math::Exp2Op) + POPULATE_UNARY_OP(math::ErfOp, math::ErfOp) + POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp) + POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp) + POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) +#undef POPULATE_UNARY_OP + +#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // - + POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // + + POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // * + POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp) + POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp) + POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // % + POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp) + POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp) + POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & + POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | + POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ + POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << + POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> + POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> + // fmin (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MinNumFOp, LLVM::MinNumOp) + // fmax (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MaxNumFOp, LLVM::MaxNumOp) + POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin + POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax + POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin + POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax +#undef POPULATE_BINARY_OP + + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); + + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); + patterns.add(typeConverter, axisInfoAnalysis, + benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000..4ead774b7 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,198 @@ +#include "mlir/IR/BuiltinAttributes.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +// NOTE: [Additional Function Arguments] +// To support use of shared memory and global scratch memory inside of a +// function, the caller allocates a single large block of the relevant memory +// and calls the function with these extra arguments at the end. +// Specifically, the last argument is the global scratch memory allocation and +// the second to last is the shared memory allocation. +// +// For the kernel function itself, the shared memory base is a global symbol +// so no additional function argument is required but global scratch memory +// allocation is still passed in as the last argument. Though here the scratch +// memory is shared between all programs, so a linear offset based on the +// program id is required to get the local scratch base. + +/// FuncOp legalization pattern that converts MemRef arguments to pointers to +/// MemRef descriptors (LLVM struct data types) containing all the MemRef type +/// information. +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {} + + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. + static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + triton::FuncOp amendFuncOp(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { + // Push back two new arguments that indicate the current pointer to shared + // memory and global scratch memory. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + auto sharedPtrTy = + LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace()); + auto globalPtrTy = LLVM::LLVMPointerType::get(ctx, 1); + + // 1. Modify the function type to add the new arguments. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + bool isKernel = LLVM::isKernel(funcOp); + if (!isKernel) { + amendedInputTy.push_back(sharedPtrTy); + } + amendedInputTy.push_back(globalPtrTy); + auto amendedFuncTy = + FunctionType::get(ctx, amendedInputTy, funcTy.getResults()); + // 2. Modify the argument attributes to add the new argument. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + if (auto argAttrs = funcOp.getAllArgAttrs()) { + llvm::SmallVector amendedArgAttrs(argAttrs.begin(), + argAttrs.end()); + while (amendedArgAttrs.size() < amendedInputTy.size()) { + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + } + amendedAttrs.push_back( + rewriter.getNamedAttr(funcOp.getArgAttrsAttrName(), + rewriter.getArrayAttr(amendedArgAttrs))); + } + + // 3. Add the new arguments to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + if (!isKernel) { + region.addArgument(sharedPtrTy, loc); + } + region.addArgument(globalPtrTy, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } + + // Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM + // attributes. + static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) { + const bool isKernel = LLVM::isKernel(llvmFuncOp); + for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) { + const auto attrs = llvmFuncOp.getArgAttrDict(i); + if (!attrs) { + continue; + } + + for (const auto &attr : attrs) { + if (attr.getName() == "tt.nv_tma_desc") { + const auto i32_type = + mlir::IntegerType::get(llvmFuncOp.getContext(), 32); + assert(attr.getValue() == mlir::IntegerAttr::get(i32_type, 1)); + assert(isKernel && + "tt.nv_tma_desc is not supported for device functions"); + + // See + // https://github.com/google/jax/blob/main/jaxlib/mosaic/gpu/passes.cc + mlir::BlockArgument arg = llvmFuncOp.getArgument(i); + const auto byteType = + mlir::IntegerType::get(llvmFuncOp.getContext(), 8); + const auto arrayType = mlir::LLVM::LLVMArrayType::get( + llvmFuncOp.getContext(), byteType, 128); + llvmFuncOp.setArgAttr(i, "llvm.byval", + mlir::TypeAttr::get(arrayType)); + llvmFuncOp.setArgAttr(i, "nvvm.grid_constant", + mlir::UnitAttr::get(llvmFuncOp.getContext())); + llvmFuncOp.setArgAttr(i, "llvm.align", + mlir::IntegerAttr::get(i32_type, 64)); + } + } + } + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto amendedFuncOp = amendFuncOp(funcOp, rewriter, targetInfo); + + // FailureOr maybeNewFuncOp = + // mlir::convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter, + // *getTypeConverter()); + // if (failed(maybeNewFuncOp)) { + // return failure(); + // } + + // LLVM::LLVMFuncOp newFuncOp = *maybeNewFuncOp; + + // auto ctx = funcOp->getContext(); + + // if (LLVM::isKernel(funcOp)) { + // // Set an attribute to indicate this function is a kernel entry. + // newFuncOp->setAttr("nvvm.kernel", + // rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + // newFuncOp.setLinkage(LLVM::Linkage::External); + // } else { + // // The noinline attribute will be used by the LLVM codegen to prevent + // // inlining. + // // + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 + // newFuncOp.setPassthroughAttr( + // ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); + // newFuncOp.setLinkage(LLVM::Linkage::Internal); + // } + // // Set an attribute for reqntidx, it could be used in latter LLVM codegen + // // for `nvvm.annotation` metadata. + // int numWarps = triton::gpu::lookupNumWarps(funcOp); + // if (auto totalNumWarps = + // funcOp.getParentOp()->getAttrOfType( + // "ttg.total-num-warps")) + // numWarps = totalNumWarps.getInt(); + // newFuncOp->setAttr("nvvm.reqntid", + // rewriter.getDenseI32ArrayAttr(32 * numWarps)); + + // rewriter.eraseOp(funcOp); + // rewriter.eraseOp(amendedFuncOp); + + // // Add attributes for by-value TMA descriptor args (nvidia) + // handleByvalTmaDescArgs(newFuncOp); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateFuncOpConversionPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp new file mode 100644 index 000000000..109a58389 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -0,0 +1,350 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { +class GatherOpConversion : public ConvertOpToLLVMPattern { +public: + GatherOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + +private: + // Codegen the gather by storing the source tensor into shared memory and then + // gathering directly from shared memory. + void emitGatherInShared(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + // Codegen a warp-local gather by shuffling elements across the warp and + // selecting from them. + void emitWarpLocalGather(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + + const TargetInfoBase &targetInfo; +}; + +LogicalResult +GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + GatherLoweringHelper helper(op); + // Specialize the lowering based on the source layout. Given that the cost of + // a warp shuffle is approximately half the cost of a roundtrip to shared + // memory with zero bank conflicts, we will need a more precise heuristic to + // choose between the two codegen paths and rely on the middle end to pick the + // right layout. + if (helper.isWarpLocal()) { + emitWarpLocalGather(op, adaptor, rewriter); + } else { + emitGatherInShared(op, adaptor, rewriter); + } + return success(); +} + +static Value convertIndexToI32(Location loc, Value index, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned idxWidth = index.getType().getIntOrFloatBitWidth(); + // The LL index computations are performed with 32 bit integers. If the + // indices are something else, cast them to i32. + if (idxWidth > 32) { + index = b.trunc(i32_ty, index); + } else if (idxWidth < 32) { + // Negative indices don't make sense, so zero-extend. + index = b.zext(i32_ty, index); + } + return index; +} + +void GatherOpConversion::emitGatherInShared( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + RankedTensorType srcType = op.getSrc().getType(); + + // Compute the src subtensor shape owned by this CTA. + SmallVector srcShapePerCTA = + convertType(triton::gpu::getShapePerCTA(srcType)); + + // Grab the src values in this thread. + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + + // Emit the indices of the src values owned by this thread. + SmallVector> srcIndices = + emitIndices(loc, rewriter, targetInfo, srcType.getEncoding(), + op.getSrc().getType(), /*withCTAOffset=*/true); + + // Store the src values owned by the thread into their respective location in + // the scratch memory. + assert(srcValues.size() == srcIndices.size()); + + // Get the base pointer to the scratch memory. + Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + + // For each src element owned by the thread, index into the scratch memory and + // then store it. + Type elemType = getTypeConverter()->convertType(srcType.getElementType()); + for (auto [value, indices] : llvm::zip(srcValues, srcIndices)) { + // Convert the index at each dim into a single offset given the shape of the + // tensor. + Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); + // Emit the offset into the shared memory and then store the value. + Value ptr = b.gep(smemBase.getType(), elemType, smemBase, offset); + b.store(value, ptr); + } + + // Synchronize the whole CTA. + b.barrier(); + + // Grab the index values owned by this thread. + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + // Apply the layout of the destination tensor to obtain the indices of the + // column to gather along, then for each column, replace the index along the + // gather axis with the appropriate index value. + // + // I = LL(pid) + // idx = indices[I] + // I_gather = [I[d] if d != axis else idx for d in range(len(I))] + // out[I] = src[I_gather] + RankedTensorType dstType = op.getType(); + SmallVector> dstIndices = + emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType, + /*withCTAOffset=*/true); + + unsigned axis = op.getAxis(); + SmallVector results(dstIndices.size()); + for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) { + indices[axis] = convertIndexToI32(loc, idx, rewriter); + Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); + Value ptr = b.gep(smemBase.getType(), elemType, smemBase, offset); + results[i] = b.load(elemType, ptr); + } + + Value packed = + packLLElements(loc, getTypeConverter(), results, rewriter, dstType); + rewriter.replaceOp(op, packed); +} + +// High-level description of the algorithm: +// +// `isWarpLocal` checks that it is possible to compute each output element +// without data movement across warps. +// +// If the gather dim is `dimN`, then this means +// +// ll^-1(dimN)[(block, warp)] == 0 +// +// for both source and index tensors: moving along the gather axis does not +// change the warp. Broadcasted layouts are not supported, so we know the +// layouts are permutation matrices. +// +// We can check this with `ll((block, warp))[dimN] == 0`. +// +// Let `gatherCol` be a tuple of all dimensions except the gather dimension. +// We also check that the gather columns line up the same way with respect to +// the warp between the source and index tensors with +// +// ll_src((block, warp))[gatherCol] == ll_idx((block, warp))[gatherCol] +// +// This means that for all index columns, the corresponding column in the source +// tensor is owned by the same warp. +// +// We also check +// +// ll_src(lane)[gatherCol] == ll_idx(lane)[gatherCol] +// +// This boils down to the fact that the algorithm essentially emits a series of +// index shuffles for each index value owned by each thread, and then a pile of +// selects to pick the right value. We need to figure out given an index value +// in a particular column, what are the source register values it could read +// from and who owns them. +// +// If this relationship did not hold, then the possible source registers for +// each index value varies with the thread, meaning the value operand provided +// to each shuffle index instruction would depend on the thread ID. This isn't a +// big deal. It just means would have to emit a pile of selects before each +// shuffle as well, to pick the right source register value. But we choose not +// to handle this. +// +// The codegen algorithm emits code: +// - Given the thread ID and a particular index tensor register, figure out +// which gather column it belongs to using a layout. +// - Using the index value itself as the value for `dimN`, use another layout to +// figure out which lane in the warp owns the desired value and which register +// in that lane it is. +// - For the gather column, figure out the source registers in that column, and +// for each of them, emit an index shuffle with the same computed lane ID. +// - Use the register component to select the right value from the shuffle +// results. +void GatherOpConversion::emitWarpLocalGather( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + RankedTensorType srcType = op.getSrc().getType(); + RankedTensorType idxType = op.getIndices().getType(); + + // Layout dimension names. + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); + StringAttr kGatherDim = rewriter.getStringAttr("dim" + Twine(op.getAxis())); + SmallVector allDims, otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + allDims.push_back(str_attr("dim" + Twine(dim))); + if (dim != op.getAxis()) { + otherDims.push_back(allDims.back()); + } + } + + // Compute the src and idx layouts. + LinearLayout srcLayout = + toLinearLayout(srcType.getShape(), srcType.getEncoding()); + LinearLayout idxLayout = + toLinearLayout(idxType.getShape(), idxType.getEncoding()); + + // Let `ll_src` be the source layout and `ll_idx` be the index layout. + // Let `src_col` be a tuple of dimensions except the gather dimension, + // representing a specific column in the source tensor. Likewise for + // `idx_col`. Let `src_idx` be the index into gather dimension in the source + // tensor. + // + // `(src_lane, src_reg) = ll_src^-1(src_col, src_idx)`, where `src_lane` is + // the thread that contains the required element and `src_reg` is the register + // within that thread. + // + // Because `ll_src(block=0, warp=0, lane=0)[otherDims] == + // ll_idx(0, 0, 0)[otherDims]`, we know given any `idx_reg` (element in the + // index tensor) the thread will need to read from the same column in the + // source tensor. + // + // Thus, we can obtain + // + // (src_lane, src_reg) = (ll_src^-1)( + // ll_idx(black, warp, lane, idx_reg)[otherDims], + // idxValues[idx_reg] + // )[{"lane", "register"}] + // + // And the mapping will be the correct for each thread. + // + // Given `src_reg \in [0, K*N)`, we just need to emit N index shuffles for + // each `idx_reg` (the number of index shuffles is quadratic!) and + // `llvm.select` using `src_reg` to get the right one. `K` is the number of + // elements per column owned by a thread. + + // Invert the source layout. It doesn't matter whether it is fully invertible + // with respect to anything except the register input dimension, since we know + // those don't vary in ways that matter for codegen. + LinearLayout invSrcLayout = srcLayout.pseudoinvert(); + + // Sanity check: the warp must be invariant to the index because otherwise the + // gather would need to read across warps! + assert(invSrcLayout.sublayoutIsZero(kGatherDim, {kBlock, kWarp}) && + "expected a warp-local gather"); + invSrcLayout = invSrcLayout.sublayout(allDims, {kLane, kRegister}); + + LinearLayout idxColLayout = + idxLayout.sublayout({kBlock, kWarp, kLane, kRegister}, otherDims); + + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + Value blockId = targetInfo.getClusterCTAId(rewriter, loc); + + unsigned /*N=*/srcRegsPerThread = srcLayout.getInDimSize(kRegister); + assert(srcRegsPerThread == srcValues.size()); + + // Given a index value, we need to know which sources register values it could + // index into. This is invariant to anything other than the register, which we + // checked already. Compute the full reverse map from + // + // idx_reg -> gather_column -> (src_reg0, src_reg1, ...) + // + LinearLayout invertSrcRegMap = invSrcLayout.sublayout(allDims, {kRegister}); + // Remove zero bases in the gather dimension to make the function injective + // (for a given column) over the same codomain. + invertSrcRegMap = invertSrcRegMap.removeZeroBasesAlongDim(kGatherDim); + // We are left with only non-zero bases in the gather dimension, which means + // the number of registers per column is the size of the "gather dimension". + unsigned numRegsPerColumn = invertSrcRegMap.getInDimSize(kGatherDim); + // Get a map from idx_reg to the column it indexes into. + LinearLayout idxRegToCol = idxLayout.sublayout({kRegister}, otherDims); + // Now given `idx_reg`, we can compute the column it belongs to in both src + // and index tensors, then partially apply `invertSrcRegMap` with this to + // obtain a function that outputs the corresponding registers in the src + // tensor in the same column. + + // L(column, i) = L(column, 0) xor L(0, i) + LinearLayout invertSrcRegMapColPart = + invertSrcRegMap.sublayout(otherDims, {kRegister}); + LinearLayout invertSrcRegMapRest = + invertSrcRegMap.sublayout({kGatherDim}, {kRegister}); + + SmallVector results; + for (auto [idxReg, idxVal] : llvm::enumerate(idxValues)) { + SmallVector> column = + applyLinearLayout(loc, rewriter, idxColLayout, + {{kBlock, blockId}, + {kWarp, warpId}, + {kLane, laneId}, + {kRegister, b.i32_val(idxReg)}}); + assert(column.size() == otherDims.size()); + + // Combine the computed column with the data-dependent gather index. + column.emplace_back(kGatherDim, convertIndexToI32(loc, idxVal, rewriter)); + SmallVector> srcLaneAndReg = + applyLinearLayout(loc, rewriter, invSrcLayout, column); + + auto [srcLaneName, srcLane] = srcLaneAndReg.back(); + auto [srcRegName, srcReg] = srcLaneAndReg.front(); + assert(srcLaneName == kLane && srcRegName == kRegister); + + assert(!srcValues.empty() && "can't gather from an empty tensor"); + + // Figure out which src registers we need to index shuffle from. This is + // invariant to anything else. + SmallVector> normalizedColumn = + idxRegToCol.apply({{kRegister, idxReg}}); + int32_t srcBase = + invertSrcRegMapColPart.apply(normalizedColumn).front().second; + + Value result = b.undef(srcValues.front().getType()); + for (unsigned i = 0; i != numRegsPerColumn; ++i) { + int32_t rest = + invertSrcRegMapRest.apply({{kGatherDim, i}}).front().second; + int32_t srcRegIdx = srcBase ^ rest; + + Value value = + targetInfo.shuffleIdx(rewriter, loc, srcValues[srcRegIdx], srcLane); + result = b.select(b.icmp_eq(b.i32_val(srcRegIdx), srcReg), value, result); + } + + results.push_back(result); + } + + rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), results, + rewriter, op.getType())); +} + +} // namespace + +void triton::populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.insert(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp new file mode 100644 index 000000000..07299ea1c --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp @@ -0,0 +1,103 @@ +#include "mlir/Analysis/Liveness.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_TRITONGPUGLOBALSCRATCHALLOCATIONPASS +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton::gpu + +static int32_t roundUp(int32_t val, int32_t step) { + auto t = val + step - 1; + return t - (t % step); +} + +static void allocateGMem(Operation *parentOp, + llvm::SetVector &callStack) { + // Recursively visit any dependency functions + parentOp->walk([&](triton::CallOp call) { + auto callable = call.resolveCallable(); + if (!callable->hasAttr("ttg.global_scratch_memory_size")) { + auto inserted = callStack.insert(parentOp); + assert(inserted && "call cycle detected"); + allocateGMem(callable, callStack); + callStack.remove(parentOp); + } + }); + + MLIRContext *ctx = parentOp->getContext(); + OpBuilder builder(ctx); + int32_t offset = 0; + uint32_t largestAlignment = 1; + + // Dumb allocation that ignores liveness and makes no attempt to minimize + // padding + // TODO: Use a real algorithm + parentOp->walk([&](Operation *op) { + uint32_t nbytes = 0; + uint32_t align = 0; + if (auto alloc = dyn_cast(op)) { + nbytes = alloc.getNbytes(); + align = alloc.getAlignment(); + } else if (auto callOp = dyn_cast(op)) { + auto callable = callOp.resolveCallable(); + auto nbytes_attr = callable->getAttrOfType( + "ttg.global_scratch_memory_size"); + auto align_attr = callable->getAttrOfType( + "ttg.global_scratch_memory_alignment"); + assert(nbytes_attr); + assert(align_attr); + + nbytes = nbytes_attr.getValue().getZExtValue(); + align = align_attr.getValue().getZExtValue(); + } + if (nbytes > 0) { + offset = roundUp(offset, align); + op->setAttr("ttg.global_scratch_memory_offset", + builder.getI32IntegerAttr(offset)); + offset += nbytes; + largestAlignment = std::max(largestAlignment, align); + } + }); + int32_t totalMemorySize = roundUp(offset, largestAlignment); + parentOp->setAttr("ttg.global_scratch_memory_size", + builder.getI32IntegerAttr(totalMemorySize)); + parentOp->setAttr("ttg.global_scratch_memory_alignment", + builder.getI32IntegerAttr(largestAlignment)); +} + +namespace { +class TritonGPUGlobalScratchAllocationPass + : public mlir::triton::gpu::impl::TritonGPUGlobalScratchAllocationPassBase< + TritonGPUGlobalScratchAllocationPass> { +public: + void runOnOperation() override { + ModuleOp mod = getOperation(); + + bool seenKernel = false; + + SetVector callStack; + mod->walk([&](triton::FuncOp func) { + allocateGMem(func, callStack); + + if (func.getVisibility() == SymbolTable::Visibility::Public) { + assert(!seenKernel); + seenKernel = true; + auto size = + func->getAttrOfType("ttg.global_scratch_memory_size"); + auto align = func->getAttrOfType( + "ttg.global_scratch_memory_alignment"); + assert(size); + assert(align); + mod->setAttr("ttg.global_scratch_memory_size", size); + mod->setAttr("ttg.global_scratch_memory_alignment", align); + } + }); + assert(seenKernel); + } +}; +} // namespace diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp new file mode 100644 index 000000000..e2327bccb --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp @@ -0,0 +1,214 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +// Compute a histogram within a warp. This uses an algorithm by @apgoucher +// that does the following: +// Create a ballot for each bit of the bin index (there +// are only log2(num_bins) of these) and then apply bitwise operations to get +// the indicator functions for the bins owned by this particular thread, and +// only popcount those. +static SmallVector computeWarpLevelHistogram( + Location loc, RankedTensorType srcType, SmallVector &srcValues, + int numBins, int numThreadPerWarp, Value threadId, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(numBins % numThreadPerWarp == 0 && + "numBins must be divisible by numThreadPerWarp"); + Value zero = b.i32_val(0); + int numBits = llvm::Log2_64(numBins); + int numBitsLaneId = llvm::Log2_64(numThreadPerWarp); + unsigned numElementsPerThreads = triton::gpu::getTotalElemsPerThread(srcType); + unsigned numThreadWithUniqueData = + triton::gpu::getThreadsPerWarpWithUniqueData(srcType.getEncoding(), + srcType.getShape())[0]; + // The histogram is distributed across threads, each thread owns `numBins / + // numThreadPerWarp` bins. + SmallVector warpLevelHistogram(numBins / numThreadPerWarp, zero); + for (int i = 0; i < numElementsPerThreads; ++i) { + Value value = srcValues[i]; + SmallVector ballotBits; + for (int j = 0; j < numBits; ++j) { + Value bitSet = b.and_(value, b.i32_val(1 << j)); + Value cmp = b.icmp_ne(bitSet, zero); + Value bit = + targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp), cmp); + ballotBits.push_back(bit); + } + uint64_t fullMaskValue = + numThreadPerWarp == 32 ? 0xFFFFFFFF : 0xFFFFFFFFFFFFFFFF; + Value fullMask = b.int_val(numThreadPerWarp, fullMaskValue); + Value mask = fullMask; + // If not all threads have unique data, mask out the redundant ones. + if (numThreadWithUniqueData < numThreadPerWarp) { + mask = b.int_val(numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1); + } + for (int i = 0; i < numBitsLaneId; i++) { + Value updateMask = + b.select(b.icmp_ne(b.and_(threadId, b.i32_val(1 << i)), zero), + b.int_val(numThreadPerWarp, 0), fullMask); + mask = b.and_( + mask, b.xor_(ballotBits[i + numBits - numBitsLaneId], updateMask)); + } + // at this point, 'mask' tells you which elements are in a bin owned by this + // thread. + for (int k = 0; k < warpLevelHistogram.size(); k++) { + Value binMask = mask; + for (int j = 0; j < numBits - numBitsLaneId; j++) { + Value updateMask = + b.int_val(numThreadPerWarp, ((k & (1 << j)) ? 0 : fullMaskValue)); + binMask = b.and_(binMask, b.xor_(ballotBits[j], updateMask)); + } + // at this point, 'bin_mask' tells you which elements are in the kth bin + // owned by this thread. + Value bitCount = rewriter.create( + loc, int_ty(numThreadPerWarp), binMask); + if (numThreadPerWarp > 32) + bitCount = b.trunc(i32_ty, bitCount); + warpLevelHistogram[k] = b.add(warpLevelHistogram[k], bitCount); + } + } + return warpLevelHistogram; +} + +static void atomicAdd(Value ptr, Value val, Location loc, + ConversionPatternRewriter &rewriter) { + rewriter.create(loc, LLVM::AtomicBinOp::add, ptr, val, + LLVM::AtomicOrdering::monotonic); +} + +static SmallVector computeCrossWarpHistogram( + Location loc, ConversionPatternRewriter &rewriter, RankedTensorType srcType, + Value baseSharedMemPtr, const SmallVector &warpLevelHistogram, + int numBins, int numThreadPerWarp, const SmallVector &indices, + Value threadId, int numWarps) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector histogramValues; + unsigned numWarpsWithUniqueData = + mlir::triton::gpu::getWarpsPerCTAWithUniqueData(srcType.getEncoding(), + srcType.getShape())[0]; + Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1)); + // Initialize the shared memory with zeros. + int64_t numElementPerThread = + ceil(numBins, numThreadPerWarp * numWarps); + for (int i = 0; i < numElementPerThread; ++i) { + Value offset = + b.add(threadId, b.i32_val((i * numWarps * numThreadPerWarp))); + offset = b.urem(offset, b.i32_val(numBins)); + Value sharedMemPtr = + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + b.store(b.i32_val(0), sharedMemPtr); + } + b.barrier(); + Block *afterAtomics = nullptr; + // If some warps have replicated data we need to skip those warps when + // accumulating. + if (numWarpsWithUniqueData < numWarps) { + Block *currentBlock = rewriter.getInsertionBlock(); + afterAtomics = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *atomicBlock = rewriter.createBlock(afterAtomics); + rewriter.setInsertionPointToEnd(currentBlock); + Value cond = b.icmp_ult( + threadId, b.i32_val(numWarpsWithUniqueData * numThreadPerWarp)); + rewriter.create(loc, cond, atomicBlock, afterAtomics); + rewriter.setInsertionPointToStart(atomicBlock); + } + // Apply atomic add to update the histogram in shared memory. + for (int i = 0; i < warpLevelHistogram.size(); ++i) { + Value warpLevelHistogramValue = warpLevelHistogram[i]; + Value offset = b.add(b.mul(laneId, b.i32_val(warpLevelHistogram.size())), + b.i32_val(i)); + Value sharedMemPtr = + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + atomicAdd(sharedMemPtr, warpLevelHistogramValue, loc, rewriter); + } + if (afterAtomics) { + rewriter.create(loc, afterAtomics); + rewriter.setInsertionPointToStart(afterAtomics); + } + b.barrier(); + // load the histogram to register with the right layout. + for (Value index : indices) { + Value sharedMemPtr = + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, index); + Value val = b.load(i32_ty, sharedMemPtr); + histogramValues.push_back(val); + } + return histogramValues; +} + +namespace { +struct HistogramOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + explicit HistogramOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + SmallVector srcValues = unpackLLElements(loc, input, rewriter); + int numBins = op.getType().getDimSize(0); + auto mod = op->getParentOfType(); + int numThreadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + assert(numThreadsPerWarp == 32 || + numThreadsPerWarp == 64 && + "Only supports 32 or 64 threads per warp"); + int numWarps = triton::gpu::lookupNumWarps(op); + // Pad out the bins so that we have at least one bin per thread within a + // warp. + numBins = std::max(numBins, numThreadsPerWarp); + Value threadId = getThreadId(rewriter, loc); + auto srcType = op.getSrc().getType(); + // First compute a warp local histogram based on values owned by each warps. + SmallVector warpLevelHistogram = computeWarpLevelHistogram( + loc, srcType, srcValues, numBins, numThreadsPerWarp, threadId, rewriter, + targetInfo); + + // Then use atomic to update the histogram in shared memory. + // TODO: we could skip this for cases with num_warps=1 as long as we can + // generate the right layout. Currently the warp level histogram generates + // data in the default blocked layout. + Value baseSharedMemPtr = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto dstType = op.getType(); + Attribute dstEncoding = dstType.getEncoding(); + auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding, + dstType, true); + SmallVector innerDimIndices; + for (int i = 0; i < indices.size(); ++i) + innerDimIndices.push_back(indices[i][0]); + SmallVector histogramValue = computeCrossWarpHistogram( + loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins, + numThreadsPerWarp, innerDimIndices, threadId, numWarps); + + Value results = packLLElements(loc, typeConverter, histogramValue, rewriter, + op.getType()); + rewriter.replaceOp(op, results); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; +} // namespace + +void mlir::triton::populateHistogramOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp new file mode 100644 index 000000000..8060b4431 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp @@ -0,0 +1,54 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +struct MakeRangeOpConversion + : public ConvertOpToLLVMPattern { + MakeRangeOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + RankedTensorType ty = op.getType(); + auto shape = ty.getShape(); + auto layout = ty.getEncoding(); + auto elemTy = ty.getElementType(); + assert(elemTy.isInteger(32)); + Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart()); + auto idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, true); + unsigned elems = idxs.size(); + SmallVector retVals(elems); + // TODO: slice layout has more elements than expected. + // Unexpected behavior for make range, but generally OK when followed by + // expand dims + broadcast. very weird behavior otherwise potentially. + for (const auto &multiDim : llvm::enumerate(idxs)) { + assert(multiDim.value().size() == 1); + retVals[multiDim.index()] = b.add(multiDim.value()[0], start); + } + auto typeConverter = getTypeConverter(); + Value result = packLLElements(loc, typeConverter, retVals, rewriter, ty); + rewriter.replaceOp(op, result); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMakeRangeOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 000000000..ea14908b9 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,198 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// blocked -> shared. +// Swizzling in shared memory to avoid bank conflict. Normally used for +// A/B operands of dots. +void lowerDistributedToShared( + Location loc, Value src, Value dst, Value adaptorSrc, + const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, + std::pair *const llvmOpCount = nullptr) { + auto srcTy = cast(src.getType()); + auto dstTy = cast(dst.getType()); + auto elemTy = typeConverter->convertType(srcTy.getElementType()); + + auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); + storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter, + targetInfo, llvmOpCount); +} + +struct GlobalScratchAllocOpConversion + : public ConvertOpToLLVMPattern { + GlobalScratchAllocOpConversion(LLVMTypeConverter &converter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto opOffsetAttr = op->getAttrOfType( + "ttg.global_scratch_memory_offset"); + assert(opOffsetAttr); + auto opOffset = opOffsetAttr.getValue().getZExtValue(); + + auto funcOp = op->getParentOfType(); + if (!funcOp) { + return failure(); + } + Value ptr = + LLVM::getGlobalScratchPtr(loc, rewriter, funcOp, b.i32_val(opOffset)); + + rewriter.replaceOp(op, ptr); + return success(); + } +}; + +struct LocalAllocOpConversion + : public ConvertOpToLLVMPattern { + LocalAllocOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.isSharedMemoryAlloc()) + return failure(); + Location loc = op->getLoc(); + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto resultTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + + auto llvmElemTy = typeConverter->convertType(resultTy.getElementType()); + auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, resultTy.getRank(), + loc, rewriter); + // If there is an initial tensor, store it into the shared memory. + if (op.getSrc()) { + lowerDistributedToShared(loc, op.getSrc(), op.getResult(), + adaptor.getSrc(), smemObj, typeConverter, + rewriter, targetInfo); + } + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct LocalDeallocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::LocalDeallocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::LocalDeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { +public: + LocalLoadOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter); + } + +private: + LogicalResult + lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getResult().getType(); + + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getSrc(), + typeConverter->convertType(srcTy.getElementType()), rewriter); + auto elemLlvmTy = typeConverter->convertType(dstTy.getElementType()); + + SmallVector outVals = loadSharedToDistributed( + dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo); + + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct LocalStoreOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern; + + LocalStoreOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value memDescVal = op.getDst(); + auto llvmElemTy = + getTypeConverter()->convertType(op.getDst().getType().getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + + std::pair llvmOpCount; + lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(), + adaptor.getSrc(), smemObj, getTypeConverter(), + rewriter, targetInfo, &llvmOpCount); + + targetInfo.storeOpAnnotation(op, llvmOpCount.first, llvmOpCount.second); + + rewriter.eraseOp(op); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMemoryOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp new file mode 100644 index 000000000..5cb27bb48 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp @@ -0,0 +1,241 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace { + +// The input print op contains: +// - a "prefix" (string) specified by the user, and +// - one or more "operands" (tensors). +// +// For each operand, we print all of the values contained in this GPU thread, +// one per line, along with the index of the value in its tensor. +struct PrintOpConversion : public ConvertOpToLLVMPattern { + explicit PrintOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + auto getPid = [&](int axis) { + return targetInfo.programId(rewriter, loc, + op->getParentOfType(), axis); + }; + std::array pid = {getPid(0), getPid(1), getPid(2)}; + + // Simple printf of a string without any tensors. + if (op.getNumOperands() == 0) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "pid (" << getFormatSubstr(pid[0]) << ", " + << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" + << op.getPrefix(); + llPrintf(formatStr, {pid[0], pid[1], pid[2]}, rewriter); + rewriter.eraseOp(op); + return success(); + } + + assert(op.getNumOperands() == op.getIsSigned().size()); + + for (size_t i = 0; i < op.getNumOperands(); i++) { + bool isSigned = op.getIsSigned()[i] > 0; + // Elements of the tensor that are resident in this GPU thread. + auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); + + // Get the indices of `elems` within the tensor. Note that if `elems` + // has an "interesting" layout, then these will not be in any + // particularly nice order. + + // Extract the shape of the tensor being printed and use it to figure + // out how many digits we need for each of the dimensions. + SmallVector dimWidths; + SmallVector> indices; + if (auto rankedTy = + dyn_cast(op.getOperand(i).getType())) { + indices = emitIndices(loc, rewriter, targetInfo, rankedTy.getEncoding(), + rankedTy, true); + for (int64_t dim : rankedTy.getShape()) { + if (dim > 0) { + dimWidths.push_back(static_cast(std::ceil(std::log10(dim)))); + } else { + dimWidths.push_back(0); + } + } + } else { + // We're printing a scalar. + assert(elems.size() == 1); + indices.push_back({}); + } + + if (!elems.empty()) { + printTensor(op.getPrefix(), /*operand=*/i, + /*numOperands=*/op.getNumOperands(), elems, pid, indices, + dimWidths, op.getHex(), rewriter, isSigned); + } + } + rewriter.eraseOp(op); + return success(); + } + + void printTensor(StringRef prefixStr, size_t operand, size_t numOperands, + ArrayRef elems, std::array pid, + ArrayRef> indices, + ArrayRef dimWidths, bool hex, + ConversionPatternRewriter &rewriter, bool isSigned) const { + assert(!elems.empty()); + assert(elems.size() == indices.size()); + assert(dimWidths.size() == indices.front().size()); + + size_t rank = dimWidths.size(); + + // Format is: + // pid (, , ) idx (, , ...) (operand ) + // where we leave off "(operand )" if there's only one operand. + // + // The Python wrapper munges `prefix` so that it prints nicely (e.g. starts + // with " " and ends with ": "). + + Value formatStrValue; + int formatStrByteCount = 0; + for (int i = 0; i < elems.size(); i++) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + + // nvptx printf can only accept 32 args; if we pass more than that, it + // will print garbage for the trailing args. + constexpr int kMaxPrintfOperands = 32; + SmallVector printfOperands; + + // TODO(jlebar): We really should pad the pid, but because the max pid is + // not known at compile-time, this would require nontrivial device-side + // work. + os << "pid ("; + for (int j = 0; j < pid.size(); j++) { + if (j != 0) { + os << ", "; + } + os << getFormatSubstr(pid[j]); + printfOperands.push_back(pid[j]); + } + os << ") "; + + // If `rank` is large enough, we could end up exceeding + // kMaxPrintfOperands. In that case, just truncate the index. + // (Subtract 2 because we're going to add two operands after the index.) + int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2; + + os << "idx ("; + const auto &index = indices[i]; + for (size_t dim = 0; dim < index.size(); dim++) { + if (dim != 0) { + os << ", "; + } + if (dim == maxAllowedRank) { + os << "... (truncated)"; + break; + } + os << getFormatSubstr(index[dim], /*hex=*/false, + /*width=*/dimWidths[dim]); + printfOperands.push_back(index[dim]); + } + os << ")" << prefixStr; + + if (numOperands > 1) { + os << "(operand " << operand << ") "; + } + + auto elem = elems[i]; + + os << getFormatSubstr(elem, hex, /*width=*/std::nullopt, isSigned); + printfOperands.push_back(elem); + + // It's the same format string each iteration, but it's a lot easier if we + // construct the format string at the same time as we populate + // printfOperands. But we don't want to create BLOCK_SIZE duplicate + // strings, so we cache the Value. + if (i == 0) { + formatStrValue = + llPrintf(formatStr, printfOperands, rewriter, &formatStrByteCount); + } else { + targetInfo.printf(rewriter, formatStrValue, formatStrByteCount, + printfOperands); + } + } + } + + std::string getFormatSubstr(Value value, bool hex = false, + std::optional width = std::nullopt, + bool isSigned = false) const { + Type type = value.getType(); + // If the `value` is a pointer, just return %p. + if (isa(type)) { + return "%p"; + } + // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the + // type (so 4 for fp16, 8 for int32, 16 for int64). + if (hex) { + // Ignore `width` for `hex` values, pad to typeWidth. + std::string ret = + "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); + if (type.getIntOrFloatBitWidth() > 32) { + ret += "ll"; + } + ret += "x"; + return ret; + } + + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } + + if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return prefix + "f"; + } else if (type.isInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + (isSigned ? "lli" : "llu"); + else + return prefix + (isSigned ? "i" : "u"); + } + assert(false && "not supported type"); + return ""; + } + + // Returns a Value for the format string, which you can reuse. Writes the byte + // count for the string to |formatStrByteCount| if not null. + Value llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter, + int *formatStrByteCount = nullptr) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), + rewriter, "printfFormat_", msgNewline); + targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); + if (formatStrByteCount) + *formatStrByteCount = msgNewline.size_in_bytes(); + return msgValue; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populatePrintOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp new file mode 100644 index 000000000..910c4bc01 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -0,0 +1,392 @@ +#include "ReduceScanCommon.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::DistributedEncodingTrait; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getThreadOrder; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +namespace { +struct ReduceOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + ReduceOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ReduceOpHelper helper(op); + assert(helper.isReduceWithinCTA() && + "Unexpected srcLayout in ReduceOpConversion"); + Location loc = op->getLoc(); + + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::map, SmallVector> accs; + std::map, SmallVector> indices; + // First reduce all the values along axis within each thread. + reduceWithinThreads(helper, srcValues, accs, indices, rewriter); + + // Then reduce across threads within a warp. + reduceWithinWarps(helper, accs, rewriter); + + if (helper.isWarpSynchronous()) { + // If all the values to be reduced are within the same warp there is + // nothing left to do. + packResults(helper, accs, rewriter); + return success(); + } + + // Compute a shared memory base per operand. + auto smemShape = helper.getScratchRepShape(); + + SmallVector smemBases = + getSmemBases(op, product(smemShape), rewriter, targetInfo); + + storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); + + sync(rewriter, loc, op); + + // The second round of shuffle reduction + // now the problem size: sizeInterWarps, s1, s2, .. , sn + // where sizeInterWarps is 2^m + // + // Each thread needs to process: + // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads + accumulatePartialReductions(helper, smemBases, rewriter); + + // We could avoid this barrier in some of the layouts, however this is not + // the general case. + // TODO: optimize the barrier in case the layouts are accepted. + sync(rewriter, loc, op); + + // set output values + loadReductionAndPackResult(helper, smemShape, smemBases, rewriter); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; + + void accumulate(Location loc, ConversionPatternRewriter &rewriter, + Region &combineOp, SmallVector &acc, ValueRange cur, + Value pred = {}) const { + auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); + if (acc.size() < results.size()) { + acc.resize(results.size()); + } + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = results[i]; + } + } + + SmallVector> + unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; + } + + void sync(ConversionPatternRewriter &rewriter, Location loc, + triton::ReduceOp op) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + b.barrier(); + } + + // Reduce along op axis for elements that are in the same thread. The + // accumulated value is stored in accs. + void reduceWithinThreads( + ReduceOpHelper &helper, SmallVector> &srcValues, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + RankedTensorType operandType = op.getInputTypes()[0]; + // Assumes offsets don't actually depend on type + SmallVector> offsets = + emitOffsetForLayout(helper.getSrcLayout(), operandType); + + // Thread X might hold the same input value in two registers. Get the + // indices in `offsets` that hold unique values, and only accumulate over + // those. + llvm::MapVector, int> uniqueOffsets; + for (int i = 0; i < offsets.size(); ++i) { + uniqueOffsets.insert({offsets[i], i}); + } + + unsigned srcElems = getTotalElemsPerThread(operandType); + auto *combineOp = &op.getCombineOp(); + auto srcIndices = emitIndices(op.getLoc(), rewriter, targetInfo, + helper.getSrcLayout(), operandType, true); + // reduce within threads + for (const auto &[_, i] : uniqueOffsets) { + SmallVector key = offsets[i]; + key[op.getAxis()] = 0; + bool isFirst = accs.find(key) == accs.end(); + accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]); + if (isFirst) + indices[key] = srcIndices[i]; + } + } + + // Apply warp reduction across the given number of contiguous lanes using op + // region and the accumulator values as source. + void warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce, unsigned interleave, + Value pred = {}) const { + auto success = targetInfo.warpReduce(rewriter, loc, acc, op, + numLaneToReduce, interleave); + if (success) + return; + + for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { + SmallVector shfl(acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave); + } + accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred); + } + } + + // Reduce across threads within each warp. + void + reduceWithinWarps(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData(); + unsigned threadOffsetOnReductionAxis = + helper.getThreadOffsetOnReductionAxis(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = accs[key]; + warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps, + threadOffsetOnReductionAxis); + } + } + + // Pack the accumulator values and replace the reduce op with the result. + void packResults(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + unsigned axis = op.getAxis(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + SmallVector> resultOffset = + emitOffsetForLayout(resultLayout, resultTy); + SmallVector resultVals; + for (int j = 0; j < resultElems; j++) { + auto key = resultOffset[j]; + key.insert(key.begin() + axis, 0); + resultVals.push_back(accs[key][i]); + } + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else + results[i] = accs.begin()->second[i]; + } + rewriter.replaceOp(op, results); + } + + void storeWarpReduceToSharedMemory( + ReduceOpHelper &helper, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcLayout = + mlir::cast(helper.getSrcLayout()); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + unsigned axis = op.getAxis(); + auto smemShape = helper.getScratchRepShape(); + + // Lezcano: We should move all the shared memory logic to use LLs natively + auto srcShape = helper.getSrcShape(); + auto kLane = rewriter.getStringAttr("lane"); + auto [multiDimLaneId, isRepresentativeLane] = + delinearize(rewriter, loc, srcLayout, srcShape, kLane, laneId); + auto kWarp = rewriter.getStringAttr("warp"); + auto [multiDimWarpId, isRepresentativeWarp] = + delinearize(rewriter, loc, srcLayout, srcShape, kWarp, warpId); + + Value laneIdAxis = multiDimLaneId[axis]; + Value laneZero = b.icmp_eq(laneIdAxis, b.i32_val(0)); + Value write = + b.and_(b.and_(isRepresentativeLane, isRepresentativeWarp), laneZero); + + Value warpIdAxis = multiDimWarpId[axis]; + + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = it.second; + + SmallVector writeIdx = indices[key]; + writeIdx[axis] = warpIdAxis; + Value writeOffset = + linearize(rewriter, loc, writeIdx, smemShape, smemOrder); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value writePtr = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); + targetInfo.storeShared(rewriter, loc, writePtr, acc[i], write); + } + } + } + + // Load the reduction of each warp and accumulate them to a final value and + // store back to shared memory. + void accumulatePartialReductions(ReduceOpHelper &helper, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + auto smemShape = helper.getScratchRepShape(); + unsigned elems = product(smemShape); + unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto mod = op->getParentOfType(); + int numLanes = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + int numWarps = triton::gpu::lookupNumWarps(op); + int numThreads = numLanes * numWarps; + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = b.i32_val(numLanes); + Value laneId = b.urem(threadId, warpSize); + Value zero = b.i32_val(0); + + unsigned elemsPerThread = std::max(elems / numThreads, 1); + Value threadIsNeeded = b.icmp_slt(threadId, b.i32_val(elems)); + Value readOffset = threadId; + for (unsigned round = 0; round < elemsPerThread; ++round) { + SmallVector acc(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value readPtr = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); + acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy, + threadIsNeeded); + } + warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */, + threadIsNeeded); + // only the first thread in each sizeInterWarps is writing + Value writeOffset = readOffset; + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + writePtrs[i] = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); + } + + Value laneIdModSizeInterWarps = b.urem(laneId, b.i32_val(sizeInterWarps)); + Value laneIdModSizeInterWarpsIsZero = + b.icmp_eq(laneIdModSizeInterWarps, zero); + Value pred = b.and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); + + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + targetInfo.storeShared(rewriter, loc, writePtrs[i], acc[i], pred); + } + + if (round != elemsPerThread - 1) { + readOffset = b.add(readOffset, b.i32_val(numThreads)); + } + } + } + + // Load the final reduction from shared memory and replace the reduce result + // with it. + void loadReductionAndPackResult(ReduceOpHelper &helper, + SmallVector smemShape, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcLayout = helper.getSrcLayout(); + auto axis = op.getAxis(); + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + // nd-tensor where n >= 1 + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, targetInfo, + resultLayout, resultTy, true); + auto resultShape = resultTy.getShape(); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (size_t j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + readIdx.insert(readIdx.begin() + op.getAxis(), b.i32_val(0)); + for (size_t resultIdx = 0, resultDim = resultShape.size(); + resultIdx < resultDim; ++resultIdx) { + auto smemIdx = resultIdx < op.getAxis() ? resultIdx : resultIdx + 1; + if (resultShape[resultIdx] > smemShape[smemIdx]) { + // When srcShape smaller then src sizePerThread, only srcShape + // elements is accumulated in smem. Modulo smemShape effectively + // replicates srcShape elements to src sizePerThread. + readIdx[smemIdx] = + b.urem(readIdx[smemIdx], b.i32_val(smemShape[smemIdx])); + } + } + Value readOffset = + linearize(rewriter, loc, readIdx, smemShape, smemOrder); + Value readPtr = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); + resultVals[j] = b.load(elemTy, readPtr); + } + + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = b.load(elemTy, smemBases[i]); + } + } + rewriter.replaceOp(op, results); + } +}; +} // namespace + +void mlir::triton::populateReduceOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h new file mode 100644 index 000000000..e3012d29d --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h @@ -0,0 +1,163 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H + +// TODO: refactor so that it doesn't fail if Allocation.h +// is included after utility.h (due to conflict in `store` macro +// and +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" + +// +#include "mlir/IR/TypeUtilities.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include +#include + +#define DEBUG_TYPE "ttgpu_to_llvm" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton { +class ReduceOp; +class ScanOp; + +inline SmallVector +inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock, + Block *insertionBlock, Block::iterator insertionPoint, + ValueRange combineArgs) { + auto returnOp = combineBlock.getTerminator(); + rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint, + combineArgs); + + auto results = SmallVector(returnOp->getOperands()); + + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + return results; +} + +inline SmallVector applyCombineOp(Location loc, + ConversionPatternRewriter &rewriter, + Region &combineOp, ValueRange acc, + ValueRange cur, Value pred = {}) { + // Allows for passing an uninitialized acc and use cur as the neutral element + if (acc.size() == 0) { + return cur; + } + assert(cur.size() == acc.size()); + + // Create a new copy of the combine block, and try to speculatively inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + + rewriter.cloneRegionBefore(combineOp, parent, + std::next(currentBlock->getIterator())); + Block &newCombine = *currentBlock->getNextNode(); + + llvm::SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + auto isRegionSpeculatable = + std::all_of(newCombine.begin(), newCombine.end(), + [](auto &op) { return isSpeculatable(&op); }); + + if (!pred || isRegionSpeculatable) { + // Fast path, region has no side effects so we can unconditionally execute + return inlineCombineBlock(rewriter, newCombine, currentBlock, + rewriter.getInsertionPoint(), combineArgs); + } + + // Slow case, create an if to only execute region when pred is true + // #currentBlock + // if (pred) { + // #newCombine + // results = combineOp(cur, acc) + // yield results + // } else { + // yield undef + // } + // #thenBlock + Block *thenBlock = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + + auto returnOp = newCombine.getTerminator(); + auto results = SmallVector(returnOp->getOperands()); + + rewriter.setInsertionPointToEnd(currentBlock); + SmallVector thenBlockArgs; + thenBlockArgs.reserve(results.size()); + for (auto result : results) { + auto ty = result.getType(); + auto undef = rewriter.create(loc, ty); + thenBlockArgs.push_back(undef); + thenBlock->addArgument(ty, loc); + } + rewriter.create(loc, pred, &newCombine, combineArgs, + thenBlock, thenBlockArgs); + + // Split a block after the call. + rewriter.setInsertionPointToEnd(&newCombine); + rewriter.replaceOpWithNewOp(returnOp, thenBlock, results); + rewriter.setInsertionPointToStart(thenBlock); + return SmallVector(thenBlock->getArguments()); +} + +} // namespace mlir::triton + +template +class ConvertTritonGPUReduceScanToLLVMPattern + : public ConvertOpToLLVMPattern { +public: + // Make sure the class is only instantiated with Reduce and Scan + static_assert(std::is_same_v || + std::is_same_v); + + using ConvertOpToLLVMPattern::getTypeConverter; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + // Return the pointee type of the shared memory pointer for operand i. + Type getElementType(SourceOp op, int i) const { + auto ty = op.getInputTypes()[i].getElementType(); + return getTypeConverter()->convertType(ty); + } + + // Helper to compute the smem bases in both reductions and scans + SmallVector getSmemBases(SourceOp op, unsigned elems, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + // indices will store the index of the op operands in descending order + // of their bitwidths + std::vector indices(op.getNumOperands()); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) { + return op.getElementTypes()[i].getIntOrFloatBitWidth() > + op.getElementTypes()[j].getIntOrFloatBitWidth(); + }); + // Assign base index to each operand in their order in indices + std::map indexToBase; + auto basePtr = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + indexToBase[indices[0]] = basePtr; + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + indexToBase[indices[i]] = + b.gep(basePtr.getType(), getElementType(op, indices[i - 1]), + indexToBase[indices[i - 1]], b.i32_val(elems)); + } + // smemBases[k] is the base pointer for the k-th operand + SmallVector smemBases(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemBases[i] = indexToBase[i]; + } + return smemBases; + } +}; + +#endif diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 000000000..972fc5592 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,38 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct GetProgramIdOpConversion + : public ConvertOpToLLVMPattern { + explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value programId = targetInfo.programId(rewriter, op->getLoc(), + op->getParentOfType(), + op.getAxisAsInt()); + rewriter.replaceOp(op, programId); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp new file mode 100644 index 000000000..ac9198103 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -0,0 +1,573 @@ +#include "ReduceScanCommon.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::toLinearEncoding; + +// apply combine region to acc and cur and accumulate it into acc +static SmallVector accumulate(ScanLoweringHelper &helper, + ConversionPatternRewriter &rewriter, + ValueRange acc, ValueRange cur, + Value pred = {}) { + auto loc = helper.getLoc(); + auto &combineOp = helper.getCombineOp(); + return applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); +} + +// Scan a contiguous elements within a thread and update `srcValues` in place. +static void +scanThreadContiguousElements(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper) { + // Depending on layout contiguous elements along axis dim may not be + // contiguous in srcValues. Keep track of what elements belong to the same + // chunk of contiguous elements. + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned numChunks = srcValues.size() / scanElementsPerThreads; + unsigned stride = helper.getAxisElementStride(); + SmallVector> accs(numChunks); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + // Change this into emitOffsetForLayout? + unsigned accIndex = (srcIndex % stride) + + ((srcIndex / stride) / scanElementsPerThreads) * stride; + + accs[accIndex] = + accumulate(helper, rewriter, accs[accIndex], srcValues[srcIndex]); + srcValues[srcIndex] = accs[accIndex]; + } +} + +// Apply a scan across threads of the warp for the last element of each +// contiguous group of elements. +static void warpScan(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value laneIdAxis) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Reduce within warps. + SmallVector acc = srcValues[srcIndex]; + for (unsigned i = 1; i <= scanDim / 2; i <<= 1) { + SmallVector shfl(acc.size()); + for (unsigned j = 0; j < acc.size(); ++j) { + shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride); + } + Value mask = b.icmp_sge(laneIdAxis, b.i32_val(i)); + SmallVector tempAcc = + accumulate(helper, rewriter, shfl, acc, mask); + for (unsigned j = 0; j < acc.size(); ++j) { + acc[j] = b.select(mask, tempAcc[j], acc[j]); + } + } + srcValues[srcIndex] = std::move(acc); + } +} + +// For each set of contiguous elements within a thread we store the partial +// reduction into shared memory. Each parallel scan and each warp will store its +// own partial reductions. The shared memory is organized as follow: +// ----------------------------------------------------------------- +// chunk 0: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +// chunk 1: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +static void storeWarpAccumulator(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId, SmallVector smemBases, + SmallVector smemTypes, + Value parallelLaneId, Value isRepresentative, + const TargetInfoBase &targetInfo) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned chunkId = 0; + unsigned elementStride = helper.getAxisElementStride(); + + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + auto lastElement = srcValues[srcIndex]; + Value mask = b.icmp_eq(laneId, b.i32_val(scanDim - 1)); + mask = b.and_(mask, isRepresentative); + Value index = + b.add(parallelLaneId, b.mul(warpId, b.i32_val(numParallelLane))); + index = b.add(index, b.i32_val(chunkId * numParallelLane * axisNumWarps)); + for (unsigned i = 0; i < lastElement.size(); ++i) { + Value writePtr = + b.gep(smemBases[i].getType(), smemTypes[i], smemBases[i], index); + targetInfo.storeShared(rewriter, loc, writePtr, lastElement[i], mask); + } + chunkId++; + } +} + +// Read the partial reductions from shared memory from each chunk of contiguous +// elements for each warp and parallel scan. Then combine the partial reduction +// with the right elements. Within a given contiguous element chunk we update +// all the elements by accumulating the value from the last element of the +// reduced value from the previous lane. +static void AddPartialReduce(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, + ArrayRef smemBases, + ArrayRef smemTypes, Value warpId, + Value laneIdAxis, Value parallelLaneId) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + Value maskNotFirstWarp = b.icmp_ne(warpId, b.i32_val(0)); + Value maskNotFirstLane = b.icmp_ne(laneIdAxis, b.i32_val(0)); + Value maskNotFirstThread = b.or_(maskNotFirstWarp, maskNotFirstLane); + struct Accumulator { + SmallVector acc; + SmallVector maskedAcc; + }; + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Accumulate the partial reduction from shared memory. Decide which + // accumulator to combine based on whether the elements belong to the same + // dimension along axis. + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + Accumulator &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + for (unsigned i = 0; i < axisNumWarps; ++i) { + Value index = + b.add(parallelLaneId, + b.i32_val(numParallelLane * (i + chunkId * axisNumWarps))); + SmallVector partialReduce(helper.getNumOperands()); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + auto elemTy = smemTypes[j]; + Value ptr = b.gep(smemBases[j].getType(), elemTy, smemBases[j], index); + partialReduce[j] = b.load(elemTy, ptr); + } + + if (accumulator.acc.size() == 0) { + accumulator.acc = partialReduce; + accumulator.maskedAcc = partialReduce; + continue; + } + Value mask = b.icmp_sge(warpId, b.i32_val(i + 1)); + accumulator.acc = + accumulate(helper, rewriter, accumulator.acc, partialReduce); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + accumulator.maskedAcc[j] = + b.select(mask, accumulator.acc[j], accumulator.maskedAcc[j]); + } + } + + Value pred = axisBlockId == 0 ? maskNotFirstWarp : Value{}; + auto temp = accumulate(helper, rewriter, accumulator.maskedAcc, + srcValues[srcIndex], pred); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + auto val = srcValues[srcIndex]; + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + temp[i] = b.select(maskNotFirstWarp, temp[i], val[i]); + } + } + srcValues[srcIndex] = temp; + // Update the rest of the contiguous elements. + SmallVector lastElement(helper.getNumOperands()); + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride); + lastElement[i] = + b.select(maskNotFirstLane, elem, accumulator.maskedAcc[i]); + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + pred = axisBlockId == 0 ? maskNotFirstThread : Value{}; + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = accumulate(helper, rewriter, lastElement, laneValue, pred); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + laneValue[j] = b.select(maskNotFirstThread, laneValue[j], + srcValues[srcIndex - i * elementStride][j]); + } + } + srcValues[srcIndex - i * elementStride] = std::move(laneValue); + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + accumulator.maskedAcc = accumulator.acc; + chunkId++; + } +} + +static void AddPartialReduceOneWarp(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value warpId, + Value laneIdAxis, Value laneIdLast) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + Value maskFirstWarp = b.icmp_eq(warpId, b.i32_val(0)); + Value maskFirstLane = b.icmp_eq(laneIdAxis, b.i32_val(0)); + Value maskFirstThread = b.and_(maskFirstWarp, maskFirstLane); + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector> accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + auto &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + if (axisBlockId == 0) // First chunk and first block + accumulator = srcValues[srcIndex]; + else + srcValues[srcIndex] = + accumulate(helper, rewriter, accumulator, srcValues[srcIndex]); + // Update the rest of the contiguous elements. + auto lastElement = srcValues[srcIndex]; + if (scanDim > 1) { + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + lastElement[i] = targetInfo.shuffleUp( + rewriter, loc, srcValues[srcIndex][i], threadStride); + lastElement[i] = + b.select(maskFirstLane, accumulator[i], lastElement[i]); + if (numScanBlocks > 1) + // Update accumulator with the value from the last lane. + accumulator[i] = targetInfo.shuffleIdx( + rewriter, loc, srcValues[srcIndex][i], laneIdLast); + } + } else if (numScanBlocks > 1) { + accumulator = srcValues[srcIndex]; + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = accumulate(helper, rewriter, lastElement, laneValue); + if (axisBlockId == 0) { + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + // For the first warp and first chunk we don't have anything to + // accumulate. + laneValue[j] = b.select(maskFirstThread, + srcValues[srcIndex - i * elementStride][j], + laneValue[j]); + } + } + srcValues[srcIndex - i * elementStride] = std::move(laneValue); + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + chunkId++; + } +} + +namespace { +struct ScanOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + using ConvertTritonGPUReduceScanToLLVMPattern< + triton::ScanOp>::ConvertTritonGPUReduceScanToLLVMPattern; + explicit ScanOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (succeeded(emitFastScan(op, adaptor, rewriter, targetInfo))) + return success(); + return failure(); + } + +private: + const TargetInfoBase &targetInfo; + std::tuple, Value> + getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId) const; + std::tuple, Value> + getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value warpId) const; + std::tuple + getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const; + LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const; +}; + +std::tuple, Value> +ScanOpConversion::getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const { + auto loc = helper.getLoc(); + auto srcEncoding = helper.getEncoding(); + auto kWarp = rewriter.getStringAttr("lane"); + return delinearize(rewriter, loc, srcEncoding, helper.getShape(), kWarp, + laneId); +} + +std::tuple, Value> +ScanOpConversion::getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const { + auto loc = helper.getLoc(); + auto srcEncoding = helper.getEncoding(); + auto kWarp = rewriter.getStringAttr("warp"); + return delinearize(rewriter, loc, srcEncoding, helper.getShape(), kWarp, + warpId); +} + +// Break up the threadId into lane and warp id along the scan dimension and +// compute a flat id for the parallel dimensions. +std::tuple +ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const { + auto loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto [multiDimLaneId, isRepresentativeLane] = + getMultiDimLaneId(rewriter, helper, laneId); + auto [multiDimWarpId, isRepresentativeWarp] = + getMultiDimWarpId(rewriter, helper, warpId); + + Value laneIdAxis = multiDimLaneId[axis]; + Value warpIdAxis = multiDimWarpId[axis]; + + multiDimLaneId[axis] = b.i32_val(0); + threadsPerWarp[axis] = 1; + Value laneIdParallel = linearize(rewriter, loc, multiDimLaneId, + threadsPerWarp, helper.getOrder()); + multiDimWarpId[axis] = b.i32_val(0); + warpsPerCTA[axis] = 1; + Value warpIdParallel = + linearize(rewriter, loc, multiDimWarpId, warpsPerCTA, helper.getOrder()); + Value flatIdParallel = b.add( + laneIdParallel, + b.mul(warpIdParallel, b.i32_val(helper.getNonAxisNumThreadsPerWarp()))); + auto isRepresentative = b.and_(isRepresentativeLane, isRepresentativeWarp); + return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel, + isRepresentative); +} + +SmallVector> +unpackInputs(Location loc, triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter) { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; +} + +// Flip the srcValues. Both reverses the chunks and reverses the lanes. +// Lane reversal is done with a butterfly shuffle flip (divide and flip). +SmallVector> +flipSrcValues(Location loc, triton::ScanOp op, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + SmallVector> srcValues, int iWarpSize) { + SmallVector> values(srcValues.size()); + for (int i = 0; i < srcValues.size(); ++i) { + int revIndex = srcValues.size() - i - 1; + for (unsigned j = 0; j < op.getNumOperands(); ++j) { + for (unsigned k = iWarpSize / 2; k >= 1; k = k / 2) { + srcValues[revIndex][j] = + targetInfo.shuffleXor(rewriter, loc, srcValues[revIndex][j], k); + } + values[i].push_back(srcValues[revIndex][j]); + } + } + return values; +} + +// Lowering using warp shuffle operations to do warp level scan. +LogicalResult +ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { + ScanLoweringHelper helper(op); + auto loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (!helper.isSupported()) + return op.emitError("TODO: unsupported scan layout"); + + Value threadId = getThreadId(rewriter, loc); + auto mod = op->getParentOfType(); + unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpSize = b.i32_val(iWarpSize); + Value warpId = b.udiv(threadId, warpSize); + Value laneId = b.urem(threadId, warpSize); + + auto [laneIdAxis, warpIdAxis, flatIdParallel, isRepresentative] = + getDelinearizedIds(rewriter, helper, laneId, warpId); + auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + auto srcValues = + unpackInputs(loc, op, adaptor, rewriter, *getTypeConverter()); + + // For the reverse option we apply flip(scan(flip()) in + // order to avoid having a separate code path in the reverse direction. + // We do this by 1) reversing chunks, 2) reversing lanes, 3) reversing + // warp ids and then undoing this below. + // (Note: Tried pretty hard to get shflDownSync to work but I ended up + // having to add a lot of the complex cross warp code (if rev switch + // first/last etc). Reverse first seems more maintainable.) + if (op.getReverse()) { + warpIdAxis = b.sub(b.i32_val(axisNumWarps - 1), warpIdAxis); + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + // Scan contiguous elements in a thread and update `srcValues`. + scanThreadContiguousElements(srcValues, rewriter, helper); + // Apply warp level scan to the last element of each chunk of contiguous + // elements. + warpScan(srcValues, rewriter, targetInfo, helper, laneIdAxis); + + if (axisNumWarps > 1) { + // Slow path for the case where there are multiple warps with unique data on + // the axis. + auto elems = helper.getScratchSizeInElems(); + SmallVector smemBases = + getSmemBases(op, elems, rewriter, targetInfo); + SmallVector smemTypes(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemTypes[i] = getElementType(op, i); + } + + // Store the partial reducing for each warp into shared memory. + storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis, + smemBases, smemTypes, flatIdParallel, isRepresentative, + targetInfo); + b.barrier(); + // Read back the partial reduction of each warp and accumulate them based on + // warpId. Then update each chunk of contiguous elements by adding the + // accumulated value from the previous lane. + AddPartialReduce(srcValues, rewriter, targetInfo, helper, smemBases, + smemTypes, warpIdAxis, laneIdAxis, flatIdParallel); + } else if (srcValues.size() > 1) { + // Fast path for the case where there is only one warp with unique data on + // the axis. + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + auto multiDimLaneId = + std::get<0>(getMultiDimLaneId(rewriter, helper, laneId)); + multiDimLaneId[helper.getAxis()] = b.i32_val(scanDim - 1); + auto linearEncoding = helper.getEncoding(); + auto threadsPerWarp = linearEncoding.getThreadsPerWarp(); + auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, + helper.getOrder()); + AddPartialReduceOneWarp(srcValues, rewriter, targetInfo, helper, warpIdAxis, + laneIdAxis, laneIdLast); + } // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do. + + auto transpose = [](const SmallVector> &v) { + assert(v.size() > 0 && v[0].size() > 0); + auto ret = SmallVector>(v[0].size(), + SmallVector(v.size())); + for (int i = 0; i < v.size(); ++i) { + for (int j = 0; j < v[0].size(); ++j) { + ret[j][i] = v[i][j]; + } + } + return ret; + }; + + SmallVector results(op.getNumOperands()); + if (op.getReverse()) { + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + auto valuesTransposed = transpose(srcValues); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto resultTy = dyn_cast(op.getResult()[i].getType()); + results[i] = packLLElements(loc, getTypeConverter(), valuesTransposed[i], + rewriter, resultTy); + } + rewriter.replaceOp(op, results); + return success(); +} +} // namespace + +void mlir::triton::populateScanOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp new file mode 100644 index 000000000..d40e146ad --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -0,0 +1,79 @@ +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::MemDescType; + +TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( + MLIRContext *ctx, const TargetInfoBase &targetInfo, + const DataLayoutAnalysis *analysis) + : TritonGPUToLLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), targetInfo, + analysis) {} + +TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( + MLIRContext *ctx, const LowerToLLVMOptions &options, + const TargetInfoBase &targetInfo, const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, options, analysis) { + addConversion([ctx](triton::PointerType type) -> std::optional { + return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); + }); + addConversion([ctx](TensorDescType type) -> std::optional { + return LLVM::LLVMPointerType::get(ctx, 1); + }); + addConversion([&](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type, targetInfo); + }); + addConversion([&](MemDescType type) -> std::optional { + return convertMemDescType(type, targetInfo); + }); + addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional { + return convertAsyncTokenType(type); + }); + + convertFP8Type(); +} + +Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( + RankedTensorType type, const TargetInfoBase &targetInfo) { + auto ctx = type.getContext(); + Type eltType = convertType(type.getElementType()); + unsigned numElementsPerThread = getTotalElemsPerThread(type); + SmallVector types(numElementsPerThread, eltType); + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertMemDescType( + MemDescType type, const TargetInfoBase &targetInfo) { + auto ctx = type.getContext(); + // base ptr + auto ptrType = LLVM::LLVMPointerType::get( + ctx, targetInfo.getAddressSpace(type.getMemorySpace())); + + if (isa( + type.getEncoding())) { + return ptrType; + } + + SmallVector types; + types.push_back(ptrType); + auto rank = type.getRank(); + // offsets + for (auto i = 0; i < rank; i++) { + types.push_back(IntegerType::get(ctx, 32)); + } + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertAsyncTokenType( + triton::gpu::AsyncTokenType type) { + return IntegerType::get(type.getContext(), 32); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp new file mode 100644 index 000000000..d4fd686e4 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -0,0 +1,929 @@ +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Attributes.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" + +#if defined(_MSC_VER) && !defined(__clang__) +// from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0 +#include + +static int __builtin_clz(unsigned x) { + unsigned long r; + _BitScanReverse(&r, x); + return static_cast(r ^ 31); +} + +static int __builtin_ctz(unsigned x) { + unsigned long r; + _BitScanForward(&r, x); + return static_cast(r); +} + +#endif + +// This reverts #5645, because it introduced increased register pressure in AMD +// backend. +// TODO: remove when new implementation performance reaches target level +namespace { + +LinearLayout getRegToSharedLayout(MLIRContext *ctx, ArrayRef shape, + LinearLayout regLayout, + triton::gpu::SharedEncodingTrait dstEnc, + int elemBitWidth) { + StringAttr kBlock = StringAttr::get(ctx, ("block")); + int rank = shape.size(); + + LinearLayout sharedLayout = triton::gpu::toLinearLayout(shape, dstEnc); + auto sharedOrder = triton::gpu::getOrder(dstEnc, shape); + + // sharedLayout's in-dims are currently (offset, block). Reshape to + // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional + // shmem strides. (The offsetX's appear in minor-to-major order.) + auto sharedLegacy = cast(dstEnc); + SmallVector> multiDimSharedSize; + for (int i = 0; i < rank; i++) { + int dim = sharedOrder[i]; + int64_t size = std::max( + int64_t{1}, + shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]); + multiDimSharedSize.push_back( + {StringAttr::get(ctx, ("offset" + std::to_string(dim))), size}); + } + multiDimSharedSize.push_back({kBlock, sharedLayout.getInDimSize(kBlock)}); + sharedLayout = sharedLayout.reshapeIns(multiDimSharedSize); + + // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1, + // ..., offsetXN, block), where the offsetX's are in minor-to-major order. + return regLayout.invertAndCompose(sharedLayout); +} + +} // namespace + +namespace mlir { + +namespace triton::gpu { +Type getFunctionType(Type resultType, ValueRange operands) { + SmallVector operandTypes(operands.getTypes()); + return LLVM::LLVMFunctionType::get(resultType, operandTypes); +} + +LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type funcType, + StringRef libname /*= ""*/, + StringRef libpath /*= ""*/) { + using LLVM::LLVMFuncOp; + + auto funcAttr = StringAttr::get(op->getContext(), funcName); + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + if (funcOp) + return cast(*funcOp); + + Operation *parent = op; + if (!isa(op)) + parent = op->getParentOfType(); + OpBuilder b(parent); + auto ret = b.create(op->getLoc(), funcName, funcType); + ret.getOperation()->setAttr("libname", + StringAttr::get(op->getContext(), libname)); + ret.getOperation()->setAttr("libpath", + StringAttr::get(op->getContext(), libpath)); + return ret; +} +} // namespace triton::gpu + +SmallVector> +applyLinearLayout(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(layout.getNumInDims() == indices.size()); + for (auto [inDimName, idx] : indices) { + assert(layout.hasInDim(inDimName) && "Invalid inDimName"); + } + + // This function can emit a lot of MLIR code, which ultimately makes + // compilation slow. (We think this shouldn't be the case -- it's not *that* + // much code -- but we're not clear on how to fix the slowness, which happens + // in the bowels of MLIR.) + // + // As a result we go through some contortions to avoid emitting code where + // possible. + + // Manually constant-fold the layout where possible. + SmallVector> constantIns; + for (auto [inDimName, idx] : indices) { + if (auto constant = idx.getDefiningOp()) { + constantIns.push_back( + {inDimName, cast(constant.getValue()).getInt()}); + } else { + constantIns.push_back({inDimName, 0}); + } + } + SmallVector constantComponent = + llvm::to_vector(llvm::make_second_range(layout.apply(constantIns))); + + Value zero = b.i32_val(0); + SmallVector> outIndices; + for (auto [i, outDimName] : llvm::enumerate(layout.getOutDimNames())) { + if (constantComponent[i] == 0) + outIndices.push_back({outDimName, zero}); + else + outIndices.push_back({outDimName, b.i32_val(constantComponent[i])}); + } + + for (auto [inDimName, idx] : indices) { + if (idx.getDefiningOp()) { + continue; + } + + int nBits = layout.getInDimSizeLog2(inDimName); + for (int i = 0; i < nBits; i++) { + Value bit = b.and_(idx, b.i32_val(1 << i)); + Value bit_is_zero = b.icmp_eq(bit, zero); + for (auto &[outDimName, outIdx] : outIndices) { + int32_t basis = layout.getBasis(inDimName, i, outDimName); + if (basis == 0) + continue; + outIdx = b.xor_(outIdx, b.select(bit_is_zero, zero, b.i32_val(basis))); + } + } + } + + return outIndices; +} + +std::optional getWarpGroupStartThreadId(Block *block) { + using namespace triton::gpu; + + // Look for an enclosing `ttg.warp_specialize` op. + while (block && block->getParentOp() && + !isa(block->getParentOp())) + block = block->getParentOp()->getBlock(); + if (!block || !block->getParentOp()) + return {}; + + auto partitions = cast(block->getParentOp()); + unsigned idx = block->getParent()->getRegionNumber(); + WarpSpecializeOp ws = partitions.getParentOp(); + std::optional> startIds = ws.getWarpGroupStartIds(); + assert(startIds && "cannot get warp group ID before warp group allocation"); + int32_t warpStartId = (*startIds)[idx]; + int threadsPerWarp = + TritonGPUDialect::getThreadsPerWarp(ws->getParentOfType()); + return warpStartId * threadsPerWarp; +} + +Value getThreadId(OpBuilder &rewriter, Location loc) { + Value tid = + rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); + tid = rewriter.create(loc, i32_ty, tid); + + // If this is being created inside a warp specialize op, compute the relative + // thread ID within the warp group. + if (std::optional startId = + getWarpGroupStartThreadId(rewriter.getInsertionBlock())) { + TritonLLVMOpBuilder b(loc, rewriter); + tid = rewriter.create(loc, tid, b.i32_val(*startId)); + } + + return tid; +} + +static int lookupThreadsPerWarp(OpBuilder &rewriter) { + assert(rewriter.getInsertionBlock() && "expected an insertion point"); + Operation *op = rewriter.getInsertionBlock()->getParentOp(); + while (op && !isa(op)) + op = op->getParentOp(); + assert(op && "cannot create thread ID outside of module"); + return triton::gpu::TritonGPUDialect::getThreadsPerWarp(cast(op)); +} + +Value getLaneId(OpBuilder &rewriter, Location loc) { + TritonLLVMOpBuilder b(loc, rewriter); + Value tid = getThreadId(rewriter, loc); + int threadsPerWarp = lookupThreadsPerWarp(rewriter); + return b.urem(tid, b.i32_val(threadsPerWarp)); +} + +std::pair getLaneAndWarpId(OpBuilder &rewriter, Location loc) { + TritonLLVMOpBuilder b(loc, rewriter); + Value tid = getThreadId(rewriter, loc); + int threadsPerWarp = lookupThreadsPerWarp(rewriter); + Value warpSizeVal = b.i32_val(threadsPerWarp); + + Value laneId = b.urem(tid, warpSizeVal); + Value warpId = b.udiv(tid, warpSizeVal); + return {laneId, warpId}; +} + +SmallVector> +emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + Attribute layout, RankedTensorType type, bool withCTAOffset) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + auto shape = type.getShape(); + + LinearLayout ll = triton::gpu::toLinearLayout(shape, layout); + + // TODO(jlebar): We could add strong typing if we wanted; for now this is + // "stringly typed". + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + Value blockId = + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0); + unsigned rank = shape.size(); + SmallVector> ret; + // Linear layout function is split in two parts below: + // L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0) + // idxs = idxsBase xor idxsReg + // + // L(0, t, w, b) part is the same for all registers, + // so we hoist it out of the main register loop in the below. + // + // This approach produces code with lower register pressure and + // less computations, compared to fused L(r,t,w,b) method. + auto idxsBase = applyLinearLayout(loc, rewriter, ll, + {{kRegister, b.i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}}); + for (unsigned reg = 0; reg < ll.getInDimSize(str_attr("register")); reg++) { + auto idxsReg = + ll.apply({{kRegister, reg}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + SmallVector> idxs; + for (auto [idxBase, idxReg] : llvm::zip(idxsBase, idxsReg)) { + auto dimName = idxBase.first; + assert(dimName == idxReg.first && + "dim names of block+warp+thread and register idx should be equal"); + auto idx = b.xor_(idxBase.second, b.i32_val(idxReg.second)); + idxs.emplace_back(dimName, idx); + } + assert(idxs.size() == rank); + for (unsigned k = 0; k < rank; ++k) { + assert(idxs[k].first == str_attr("dim" + std::to_string(k))); + } + ret.push_back(llvm::to_vector(llvm::make_second_range(idxs))); + } + + return ret; +} + +namespace { + +Value getSmemVecAddr(const LinearLayout ®Layout, + const LinearLayout ®ToSharedLayout, + const LinearLayout &invertAllocSharedLayout, + const SharedMemoryObject &smemObj, + triton::gpu::MemDescType sharedTy, Type elemLlvmTy, + Value regId, Value laneId, Value warpId, Value blockId, + Location loc, RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + StringAttr kBlock = str_attr("block"); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + auto shape = sharedTy.getShape(); + auto allocShape = sharedTy.getAllocShape(); + auto rank = shape.size(); + auto sharedEnc = + cast(sharedTy.getEncoding()); + + auto smemBase = smemObj.getBase(); + auto smemOffsets = smemObj.getOffsets(); + auto smemStrides = smemObj.getStrides(sharedTy, loc, rewriter); + Value smemOffset; + // When loading or storing to shared memory, we consider two cases for + // performance reasons: + // + // 1. Non-swizzled shared memory. + // 2. Swizzled shared memory. + // + // Consider lowering `ttg.local_load %a`. In the first case, we can + // directly construct a linear layout using `%a`'s shape and shared memory + // encoding, irrespective of `%a`'s rank or whether it represents a slice of a + // larger tensor. + // + // The method does not apply for swizzled shared memory in some scenarios. + // Key properties of swizzling in Triton are: + // + // - Swizzling applies only to tensors with rank ≥ 2. + // - It is restricted to the last two dimensions of the tensor. + // - These last two dimensions are always treated as the most "minor." + // + // An important edge case arises when `%a` results from `%a = ttg.subview %b`, + // where `%b` is swizzled (and so is `%a`). In this case, constructing a + // layout and determining shared memory offsets using `%a`'s shape is + // incorrect. This is because swizzling depends on the original shape of `%b`, + // which differs from `%a`'s shape. As a result, some locations may fall + // outside `%a`'s contiguous view of memory. Specifically, an element `[i + // (row_idx), j (col_idx)]` in `%a` might map to `[i, j']` after swizzling, + // where `j'` lies outside `%a`'s shape but still within `%b`'s shape. + // + // We propose case 2 (see comments below), which provides a more general + // solution for all swizzled shared memory scenarios, including the edge case + // mentioned above. + if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1 + smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}})[0] + .second; + // This reverts #5645, because it introduced increased register pressure in + // AMD backend. + // TODO: remove when new implementation performance reaches target level + if (auto swizzledSharedEnc = + mlir::dyn_cast( + sharedEnc)) { + auto regToSharedLayout = + getRegToSharedLayout(ctx, shape, regLayout, swizzledSharedEnc, + elemLlvmTy.getIntOrFloatBitWidth()); + auto smemOrder = swizzledSharedEnc.getOrder(); + smemOffsets = llvm::to_vector(llvm::drop_end(llvm::make_second_range( + applyLinearLayout(loc, rewriter, regToSharedLayout, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, b.i32_val(0)}})))); + // Reorder strides according to `order`. This way they match the + // multi-dimensional offsets in regToSharedLayout. + smemOffset = dot(rewriter, loc, smemOffsets, + applyPermutation(smemStrides, smemOrder)); + } + } else { // Case 2 -> rank-reduced swizzling + assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2"); + assert(isa(sharedEnc) && + "NVMMA layout not supported for sliced tensors"); + // We define both tensor offsets and shared memory offsets: + // + // - Tensor offsets: Relative offsets within a given tensor. + // - Shared memory offsets: Absolute offsets within the shared memory. + // + // In Triton, the shared memory layout provides an invertible, one-to-one + // mapping between tensor offsets and shared memory offsets. The `base` + // field of any shared memory object represents both the shared memory + // offset and the tensor offset relative to the original tensor at + // allocation, prior to any subview operations. + // + // To determine the shared memory offsets for a specific register when + // dealing with swizzled and sliced tensors, the process involves: + // + // 1. Retrieving the original tensor's `invertAllocSharedLayout`, which + // maps the allocated tensor's offsets back to shared memory offsets. + // 2. Reconstructing the register's offsets in the allocated tensor by + // summing: + // - The shared memory offsets of the current view's base, and + // - The relative tensor offsets of the register. + // + // This approach ensures that "absolute" tensor offsets can be + // mapped to the correct shared memory addresses using + // `invertAllocSharedLayout`. + auto multiDimTensorOffsets = + llvm::to_vector(applyLinearLayout(loc, rewriter, regLayout, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}})); + for (auto i = 0; i < rank; i++) { + multiDimTensorOffsets[i].second = + b.add(multiDimTensorOffsets[i].second, smemOffsets[i]); + } + smemOffset = applyLinearLayout(loc, rewriter, invertAllocSharedLayout, + multiDimTensorOffsets)[0] + .second; + Value baseToAllocBaseDist = dot(rewriter, loc, smemOffsets, smemStrides); + smemOffset = b.sub(smemOffset, baseToAllocBaseDist); + } + auto ptrTy = smemBase.getType(); + auto vecAddr = b.gep(ptrTy, elemLlvmTy, smemBase, smemOffset); + vecAddr.setInbounds(true); + return vecAddr; +} + +} // namespace + +bool emitTransferBetweenRegistersAndShared( + LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, + std::optional maxVecElems, const SharedMemoryObject &smemObj, + Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + std::function perVectorCallback) { + MLIRContext *ctx = rewriter.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + StringAttr kBlock = str_attr("block"); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + + auto shape = sharedTy.getShape(); + LinearLayout sharedLayout = + triton::gpu::toLinearLayout(shape, sharedTy.getEncoding()); + LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + + // TODO(jlebar): We don't currently support loading from shared memory in a + // different CTA. We'd need to emit `mapa.shared::cluster` instructions. + for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock); + inBlock *= 2) { + auto idx = regToSharedLayout.apply( + {{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}}); + // Intra-block offset must be 0 + int32_t offset = idx[0].second; + if (offset != 0) { + return false; + } + // Check if there's any cross CTA load. + int32_t outBlock = idx[1].second; + if (outBlock != inBlock) { + return false; + } + } + + // Determine how many consecutive registers map to consecutive shmem elements + // in out-dimension offsetN. This is our load instruction's vector width. + // + // It's OK if the vector width we choose here is wider than the hardware + // supports; LLVM will legalize it. + const int vecElems = + std::min(regToSharedLayout.getNumConsecutiveInOut(), + maxVecElems.value_or(std::numeric_limits::max())); + + auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1; + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + Value blockId = + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0); + + // For kernels with a single CTA, `allocSharedLayout.sublayout(S("block"), + // outDims) == 0`. We need to take out the "block" dimension in order to use + // `invert`. + // For kernels with multiple CTAs per CGA, + // `allocSharedLayout.sublayout(S("block"), outDims) != 0`. We do not need to + // take out the "block" dimension. + // Thus we use `pseudoinvert` instead of `invert` here for simplicity. + auto allocShape = sharedTy.getAllocShape(); + LinearLayout invertAllocSharedLayout = + triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()), + sharedTy.getEncoding()) + .pseudoinvert(); + + int numElems = regToSharedLayout.getInDimSize(kRegister); + auto vecTy = vec_ty(elemLlvmTy, vecElems); + SmallVector ret; + for (int i = 0; i < numElems / vecElems; i++) { + auto regId = b.i32_val(i * vecElems); + auto vecAddr = getSmemVecAddr( + regLayout, regToSharedLayout, invertAllocSharedLayout, smemObj, + sharedTy, elemLlvmTy, regId, laneId, warpId, blockId, loc, rewriter); + perVectorCallback(vecTy, vecAddr); + } + return true; +} + +bool emitTransferBetweenRegistersAndShared( + RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, + Type elemLlvmTy, std::optional maxVecElems, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, + std::function perVectorCallback) { + auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(), + registerTy.getEncoding()); + return emitTransferBetweenRegistersAndShared( + regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter, + target, perVectorCallback); +} + +SmallVector loadSharedToDistributed(RankedTensorType dstTy, + triton::gpu::MemDescType srcTy, + Type elemLlvmTy, + const SharedMemoryObject &smemObj, + Location loc, RewriterBase &rewriter, + const TargetInfoBase &target) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector ret; + bool success = emitTransferBetweenRegistersAndShared( + dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, + rewriter, target, [&](VectorType vecTy, Value vecAddr) { + auto vecVal = b.load(vecTy, vecAddr); + vecVal.setAlignment(vecTy.getNumElements() * + elemLlvmTy.getIntOrFloatBitWidth() / 8); + + for (int v = 0; v < vecTy.getNumElements(); v++) { + ret.push_back(b.extract_element(elemLlvmTy, vecVal, b.i32_val(v))); + } + }); + if (!success) + llvm::report_fatal_error("Failed to emit transfer from shared to register"); + + return ret; +} + +void storeDistributedToShared(triton::gpu::MemDescType dstTy, + RankedTensorType srcTy, Type elemLlvmTy, + ArrayRef srcVals, + const SharedMemoryObject &smemObj, Location loc, + RewriterBase &rewriter, + const TargetInfoBase &target, + std::pair *const llvmOpCount) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + bool success = emitTransferBetweenRegistersAndShared( + srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, + rewriter, target, [&](VectorType vecTy, Value vecAddr) { + ArrayRef vals = srcVals.take_front(vecTy.getNumElements()); + srcVals = srcVals.drop_front(vecTy.getNumElements()); + + Value vec = b.undef(vecTy); + for (int i = 0; i < vals.size(); i++) { + vec = b.insert_element(vec, vals[i], b.i32_val(i)); + } + b.store(vec, vecAddr) + .setAlignment(vecTy.getNumElements() * + elemLlvmTy.getIntOrFloatBitWidth() / 8); + if (llvmOpCount) { + ++(llvmOpCount->first); + llvmOpCount->second = vecTy; + } + }); + + if (!success) + llvm::report_fatal_error("Failed to emit transfer from register to shared"); +} + +SmallVector> emitOffsetForLayout(Attribute layout, + RankedTensorType type) { + MLIRContext *ctx = layout.getContext(); + auto shape = type.getShape(); + unsigned rank = shape.size(); + + auto ll = triton::gpu::toLinearLayout(shape, layout); + + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + + SmallVector> offsets; + for (int i = 0; i < ll.getInDimSize(str_attr("register")); i++) { + auto idxs = ll.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + assert(idxs.size() == rank); + for (unsigned k = 0; k < rank; ++k) { + assert(idxs[k].first == str_attr("dim" + std::to_string(k))); + } + offsets.push_back( + llvm::to_vector_of(llvm::make_second_range(idxs))); + } + return offsets; +} + +namespace LLVM { +using namespace mlir::triton; +using mlir::triton::gpu::getOrder; + +Value createConstantI1(Location loc, OpBuilder &rewriter, bool v) { + auto i1ty = rewriter.getIntegerType(1); + return rewriter.create(loc, i1ty, + IntegerAttr::get(i1ty, v)); +} + +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) { + auto i32ty = rewriter.getIntegerType(32); + return rewriter.create(loc, i32ty, + IntegerAttr::get(i32ty, v)); +} + +Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v) { + auto i64ty = rewriter.getIntegerType(64); + return rewriter.create(loc, i64ty, + IntegerAttr::get(i64ty, v)); +} + +Value createConstantF16(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f16Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF16FloatAttr(v)); +} + +Value createConstantBF16(Location loc, OpBuilder &rewriter, float v) { + APFloat apf(v); + bool ignored; + apf.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &ignored); + auto type = type::bf16Ty(rewriter.getContext()); + auto attr = FloatAttr::get(type, apf); + return rewriter.create(loc, type, attr); +} + +Value createConstantF32(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f32Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF32FloatAttr(v)); +} + +Value createConstantF64(Location loc, OpBuilder &rewriter, double v) { + auto type = type::f64Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF64FloatAttr(v)); +} + +Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type) { + if (!isa(type)) { + llvm::report_fatal_error("Creating NaN constant for non-float type!"); + } + return rewriter.create( + loc, type, APFloat::getNaN(cast(type).getFloatSemantics())); +} + +// Create an index type constant. +Value createIndexConstant(OpBuilder &builder, Location loc, + const TypeConverter *converter, int64_t value) { + Type ty = converter->convertType(builder.getIndexType()); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + +// Create an integer constant of \param width bits. +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value) { + Type ty = builder.getIntegerType(width); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + +LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc, + LLVMFuncOp funcOp, ValueRange args) { + auto op = builder.create(loc, funcOp, args); + op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({})); + op.getProperties().setOperandSegmentSizes({static_cast(args.size()), 0}); + return op; +} + +LLVM::CallIntrinsicOp +createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic, + TypeRange types, ValueRange args) { + auto op = builder.create(loc, types, args); + op.getProperties().setIntrin(builder.getStringAttr(intrinsic)); + op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({})); + op.getProperties().setOperandSegmentSizes({static_cast(args.size()), 0}); + return op; +} + +bool isConstantZero(Value v) { + if (auto constantOp = v.getDefiningOp()) { + if (auto attr = dyn_cast(constantOp.getValue())) { + return attr.getValue().isZero(); + } + if (auto attr = dyn_cast(constantOp.getValue())) { + return attr.getValue().isZero(); + } + } + return false; +} + +Value getStructFromSharedMemoryObject(Location loc, + const SharedMemoryObject &smemObj, + RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto elems = smemObj.getElems(); + auto types = smemObj.getTypes(); + auto structTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); + // pack into struct + Value llvmStruct = rewriter.create(loc, structTy); + for (const auto &v : llvm::enumerate(elems)) { + assert(v.value() && "can not insert null values"); + llvmStruct = b.insert_val(structTy, llvmStruct, v.value(), v.index()); + } + return llvmStruct; +} + +SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, + Value llvmStruct, + Type elemTy, + RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector elems(types.size()); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + elems[i] = b.extract_val(type, llvmStruct, i); + } + return {/*base=*/elems[0], + /*baseElemType=*/elemTy, + /*offsets=*/{elems.begin() + 1, elems.end()}}; +} + +// Extract the bits of `a` that are set in `mask` +Value pext_i32(RewriterBase &rewriter, Location loc, Value a, uint32_t mask) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(a.getType() == i32_ty && "a must be i32"); + // Handle width = 32 to avoid doing 1 << 32 + if (mask == 0xFFFFFFFF) + return a; + + // Implements the blocked algorithm from + // https://forums.developer.nvidia.com/t/pdep-and-pext-functionality-for-cuda/270973 + uint32_t mskConst = mask; + uint32_t extcnt = 0; + Value result = b.i32_val(0); + while (mskConst) { + uint32_t oldmsk = mskConst; + uint32_t bitgrplsb = mskConst & (-mskConst); + mskConst &= bitgrplsb + mskConst; + uint32_t bitgrp = mskConst ^ oldmsk; + uint32_t lsbpos = 31 - __builtin_clz(bitgrplsb); + // like popcount for a number 0..01..1..0 but portable + uint32_t grplen = __builtin_ctz(~(bitgrp >> lsbpos)); + uint32_t shift = lsbpos - extcnt; + extcnt += grplen; + result = + b.or_(result, b.lshr(b.and_(b.i32_val(bitgrp), a), b.i32_val(shift))); + } + return result; +} + +std::tuple, Value> +delinearize(RewriterBase &rewriter, Location loc, + triton::gpu::DistributedEncodingTrait layout, + ArrayRef shape, StringAttr dimName, Value linear) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ll = triton::gpu::toLinearLayout(shape, layout); + auto linearLayout = + triton::gpu::LinearEncodingAttr::get(rewriter.getContext(), ll); + assert(ll.hasInDim(dimName)); + int32_t freeVarMask = ll.getFreeVariableMasks()[dimName]; + auto isRepresentative = b.true_val(); + if (freeVarMask != 0) { + isRepresentative = + b.icmp_eq(b.and_(b.i32_val(freeVarMask), linear), b.i32_val(0)); + // We remove the bits of linear that are set to one in freeVarMask + int32_t nonFreeVarMask = ~freeVarMask & (ll.getInDimSize(dimName) - 1); + linear = pext_i32(rewriter, loc, linear, nonFreeVarMask); + } + + auto orderDim = linearLayout.orderPerDim(dimName, linearLayout.getOrder()); + auto shapeDim = linearLayout.basesPerDim(dimName); + auto multiDim = delinearize(rewriter, loc, linear, shapeDim, orderDim); + + return std::make_tuple(std::move(multiDim), isRepresentative); +} + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = applyPermutation(shape, order); + SmallVector reorderedMultiDim(rank); + if (auto constantOp = linear.getDefiningOp()) { + unsigned intVal = mlir::cast(constantOp.getValue()) + .getValue() + .getSExtValue(); + reorderedMultiDim = delinearize(rewriter, loc, intVal, reordered); + } else { + reorderedMultiDim = delinearize(rewriter, loc, linear, reordered); + } + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + unsigned linear, ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + unsigned remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + unsigned dimSize = en.value(); + multiDim[en.index()] = b.i32_val(remained % dimSize); + remained = remained / dimSize; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + Value remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + Value dimSize = b.i32_val(en.value()); + multiDim[en.index()] = b.urem(remained, dimSize); + remained = b.udiv(remained, dimSize); + } + return multiDim; +} + +SmallVector delinearize(unsigned linear, ArrayRef shape, + ArrayRef order) { + auto rank = shape.size(); + assert(order.size() == rank); + SmallVector multiDim(rank); + for (auto dim : order) { + multiDim[dim] = linear % shape[dim]; + linear /= shape[dim]; + } + assert(linear == 0); + return multiDim; +} + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { + return linearize(rewriter, loc, applyPermutation(multiDim, order), + applyPermutation(shape, order)); +} + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto rank = multiDim.size(); + Value linear = b.i32_val(0); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = b.i32_val(dimShape); + linear = b.add(b.mul(linear, dimSize), dim); + } + } + return linear; +} + +size_t linearize(ArrayRef multiDim, ArrayRef shape, + ArrayRef order) { + size_t linear = 0; + for (unsigned dim : llvm::reverse(order)) + linear = linear * shape[dim] + multiDim[dim]; + return linear; +} + +Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, + StringRef content) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + auto ctx = moduleOp.getContext(); + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (key + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + + llvm::SmallString<64> contentStr(content); + size_t contentSize = contentStr.size_in_bytes(); + auto globalType = LLVM::LLVMArrayType::get(i8_ty, contentSize); + + LLVM::GlobalOp global; + { + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create( + UnknownLoc::get(ctx), globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(contentStr)); + } + + Value zero = b.i32_val(0); + Type globalPtrType = LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()); + Value globalPtr = rewriter.create( + UnknownLoc::get(ctx), globalPtrType, global.getSymName()); + Value stringStart = + b.gep(ptr_ty(ctx), i8_ty, globalPtr, SmallVector({zero})); + return stringStart; +} + +} // namespace LLVM + +SharedMemoryObject +getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, + SharedMemoryObject smemObj, + ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(shape.size() == 2 || shape.size() == 3); + auto offsets = smemObj.getOffsets(); + auto rank = offsets.size(); + assert(rank == shape.size()); + if (rank == 3) + return smemObj; + offsets.insert(offsets.begin(), b.i32_val(0)); + auto expandedSmemObj = + SharedMemoryObject(smemObj.getBase(), smemObj.getBaseElemType(), offsets); + return expandedSmemObj; +} + +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp new file mode 100644 index 000000000..f0ebdf6a0 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -0,0 +1,422 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +namespace { +struct SplatOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a + // LLVM::StructType value. + // + // @elemType: the element type in operand. + // @resType: the return type of the Splat-like op. + // @constVal: a LLVM::ConstantOp or other scalar value. + static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Location loc) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto tensorTy = cast(resType); + // Check the converted type for the tensor as depending on the encoding the + // converter may pick different element types. + auto srcType = typeConverter->convertType(tensorTy); + if (auto structTy = dyn_cast(srcType)) + srcType = structTy.getBody()[0]; + // If the type sizes don't match we need to pack constants. + if (srcType.isIntOrFloat() && constVal.getType().getIntOrFloatBitWidth() != + srcType.getIntOrFloatBitWidth()) { + unsigned cstBitWidth = constVal.getType().getIntOrFloatBitWidth(); + unsigned srcBitWidth = srcType.getIntOrFloatBitWidth(); + assert(cstBitWidth <= srcBitWidth && srcBitWidth % cstBitWidth == 0); + unsigned ratio = srcBitWidth / cstBitWidth; + Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth); + VectorType vecType = VectorType::get(ratio, intTy); + Value intCst = b.bitcast(constVal, intTy); + Value vec = b.undef(vecType); + for (unsigned i = 0; i < ratio; ++i) + vec = b.insert_element(vecType, vec, intCst, b.int_val(32, i)); + constVal = vec; + } + auto llSrc = b.bitcast(constVal, srcType); + size_t elemsPerThread = getTotalElemsPerThread(tensorTy); + llvm::SmallVector elems(elemsPerThread, llSrc); + return packLLElements(loc, typeConverter, elems, rewriter, resType); + } + LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + auto src = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src, + typeConverter, rewriter, loc); + rewriter.replaceOp(op, {llStruct}); + return success(); + } +}; +// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), +// the logic is the same as triton::SplatOp, so the underlying implementation +// is reused. +struct ArithConstantSplatOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto value = op.getValue(); + if (!mlir::dyn_cast(value)) + return failure(); + auto loc = op->getLoc(); + LLVM::ConstantOp arithConstantOp; + auto values = mlir::dyn_cast(op.getValue()); + auto elemType = values.getElementType(); + Attribute val; + if (type::isFloat(elemType)) { + val = values.getValues()[0]; + } else if (type::isInt(elemType)) { + val = values.getValues()[0]; + } else { + llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: " + << value.getType() << "\n"; + return failure(); + } + // Lower FP8 constant to int8 constant since FP8 types are not supported on + // LLVM IR. + if (type::isFloat8(elemType)) + elemType = rewriter.getIntegerType(8); + auto constOp = rewriter.create(loc, elemType, val); + auto typeConverter = getTypeConverter(); + auto llStruct = SplatOpConversion::convertSplatLikeOp( + elemType, op.getType(), constOp, typeConverter, rewriter, loc); + rewriter.replaceOp(op, llStruct); + return success(); + } +}; +struct CatOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename CatOp::Adaptor; + explicit CatOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + unsigned elems = getTotalElemsPerThread(resultTy); + auto typeConverter = getTypeConverter(); + Type elemTy = typeConverter->convertType(resultTy.getElementType()); + SmallVector types(elems, elemTy); + // unpack input values + auto lhsVals = unpackLLElements(loc, adaptor.getLhs(), rewriter); + auto rhsVals = unpackLLElements(loc, adaptor.getRhs(), rewriter); + // concatenate (and potentially reorder) values + SmallVector retVals; + for (Value v : lhsVals) + retVals.push_back(v); + for (Value v : rhsVals) + retVals.push_back(v); + // pack and replace + Value ret = packLLElements(loc, typeConverter, retVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct JoinOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename JoinOp::Adaptor; + explicit JoinOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The op has a blocked encoding. + // - The last dimension (the one we're joining) is also the most minor + // dimension. + // - The input and output encodings are the same, except the output has + // 2 elements per thread in the last dim. + // + // With these invariants, join is trivial: We just return the i'th element + // from lhs, followed by the i'th elem from rhs. + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + SmallVector lhsVals = + unpackLLElements(loc, adaptor.getLhs(), rewriter); + SmallVector rhsVals = + unpackLLElements(loc, adaptor.getRhs(), rewriter); + assert(lhsVals.size() == rhsVals.size()); + SmallVector joinedVals; + for (int i = 0; i < lhsVals.size(); i++) { + joinedVals.push_back(lhsVals[i]); + joinedVals.push_back(rhsVals[i]); + } + Value ret = + packLLElements(loc, typeConverter, joinedVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct SplitOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename SplitOp::Adaptor; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The layout distribute the last dimension along registers + // - The last dimension (the one we're splitting) has sizePerThread=2, + // threadPerWarp=1 and warpPerBlock=1. + // + // With these invariants, split is trivial: We can count how many contiguous + // registers belong to the same chunk then we separate the registers between + // two different chunks. + auto srcTy = cast(op.getSrc().getType()); + auto ll = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + int splitDim = srcTy.getRank() - 1; + auto kReg = mlir::StringAttr::get(srcTy.getContext(), "register"); + const auto &bases = ll.getBases(); + const auto ®s = bases.find(kReg)->second; + int numContiguousValues = 1; + bool found = false; + for (const auto ® : regs) { + if (reg[splitDim] != 0) { + found = true; + break; + } + numContiguousValues *= 2; + } + assert(found && "Split dimension is not distributed along registers."); + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + SmallVector srcVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + assert(srcVals.size() % 2 == 0); + SmallVector outLhsVals; + SmallVector outRhsVals; + for (int i = 0; i < srcVals.size(); i += 2 * numContiguousValues) { + for (int j = 0; j < numContiguousValues; j++) { + outLhsVals.push_back(srcVals[i + j]); + outRhsVals.push_back(srcVals[i + numContiguousValues + j]); + } + } + auto resultTy = cast(op.getResult(0).getType()); + Value retLhs = + packLLElements(loc, typeConverter, outLhsVals, rewriter, resultTy); + Value retRhs = + packLLElements(loc, typeConverter, outRhsVals, rewriter, resultTy); + rewriter.replaceOp(op, {retLhs, retRhs}); + return success(); + } +}; +struct ReshapeOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ReshapeOp::Adaptor; + explicit ReshapeOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + if (triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType())) { + return emitOptionalError(loc, + "expensive view not supported on reshape op"); + } + auto resultTy = cast(op.getType()); + auto srcTy = cast(op.getSrc().getType()); + auto typeConverter = getTypeConverter(); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value ret = packLLElements(loc, typeConverter, vals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ExpandDimsOp::Adaptor; + explicit ExpandDimsOpConversion( + LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(op.getType()); + auto srcLayout = dyn_cast(srcTy.getEncoding()); + if (!srcLayout) { + return emitOptionalError( + loc, "ExpandDimsOp only supports SliceEncodingAttr as its input"); + } + auto resultLayout = resultTy.getEncoding(); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + offset.erase(offset.begin() + srcLayout.getDim()); + resultVals.push_back(srcValues.at(offset)); + } + Value ret = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct MemDescTransOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(MemDescTransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + auto llvmElemTy = + getTypeConverter()->convertType(resultTy.getElementType()); + auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto dstSmemObj = SharedMemoryObject( + srcSmemObj.getBase(), srcSmemObj.getBaseElemType(), + /*offsets=*/applyPermutation(srcSmemObj.getOffsets(), op.getOrder())); + auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; + +struct TransOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // By construction, TransOp::inferReturnTypes ensures that the src encoding + // is the same as the dst encoding so that this op is a no-op. + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } +}; + +struct BroadcastOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Following the order of indices in the legacy code, a broadcast of: + // [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)] + // => + // [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)] + // + // logically maps to a broadcast within a thread's scope: + // [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1), + // 1,spt(k+1)..spt(n-1)] + // => + // [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)] + // + // regardless of the order of the layout + // + Location loc = op->getLoc(); + Value src = adaptor.getSrc(); + Value result = op.getResult(); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(result.getType()); + auto srcLayout = srcTy.getEncoding(); + auto resultLayout = resultTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto resultShape = resultTy.getShape(); + unsigned rank = srcTy.getRank(); + auto typeConverter = getTypeConverter(); + assert(rank == resultTy.getRank()); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + SmallVector srcVals = unpackLLElements(loc, src, rewriter); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + for (size_t j = 0; j < srcShape.size(); j++) + if (srcShape[j] == 1) + offset[j] = 0; + resultVals.push_back(srcValues.at(offset)); + } + Value resultStruct = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct MemDescSubviewOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::MemDescSubviewOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcTy = op.getSrc().getType(); + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto layoutOrder = getOrder(srcTy); + + // newBase = base + offset + auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter); + SmallVector opOffsetVals = op.getOffsets(); + SmallVector opSmemStrides(smemStrides.end() - opOffsetVals.size(), + smemStrides.end()); + SmallVector offsetVals; + auto destRank = op.getResult().getType().getRank(); + auto rankReduced = srcTy.getRank() - destRank; + for (int i = rankReduced; i < opOffsetVals.size(); i++) { + offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i])); + } + // Compute the offset based on the original strides of the shared memory + // object + auto offset = dot(rewriter, loc, opOffsetVals, opSmemStrides); + auto elemPtrTy = smemObj.getBase().getType(); + smemObj = SharedMemoryObject( + b.gep(elemPtrTy, llvmElemTy, smemObj.getBase(), offset), llvmElemTy, + offsetVals); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; +} // namespace + +void mlir::triton::populateViewOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/enflame/include/triton/lib/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 000000000..7f4f5be20 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,16 @@ +add_triton_library(TritonToTritonGPU + TritonGPUConversion.cpp + TritonToTritonGPUPass.cpp + + DEPENDS + TritonConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + TritonIR +# ProtonIR + TritonGPUIR + TritonGPUTransforms +) diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp new file mode 100644 index 000000000..773c01e4a --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -0,0 +1,124 @@ +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" + +#include +#include + +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +using namespace mlir::triton::gpu; + +// +// TypeConverter +// +TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, + int numWarps, int threadsPerWarp, + int numCTAs) + : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp), + numCTAs(numCTAs) { + addConversion([](Type type) { return type; }); + + // Add encoding for tensor + addConversion([this](RankedTensorType tensorType) -> RankedTensorType { + // types with encoding are already in the right format + // TODO: check for layout encodings more specifically + if (tensorType.getEncoding()) + return tensorType; + ArrayRef shape = tensorType.getShape(); + triton::gpu::BlockedEncodingAttr encoding = + getDefaultBlockedEncoding(this->context, shape, this->numWarps, + this->threadsPerWarp, this->numCTAs); + return RankedTensorType::get(shape, tensorType.getElementType(), encoding); + }); + + // Add encoding for tensor pointer + addConversion([this](triton::PointerType ptrType) -> triton::PointerType { + // Check whether tensor pointer `tt.ptr>` + auto pointeeTensorType = + dyn_cast(ptrType.getPointeeType()); + if (pointeeTensorType == nullptr) + return ptrType; + + // Add layout into the tensor + auto convertedTensorType = convertType(pointeeTensorType); + return triton::PointerType::get(convertedTensorType, + ptrType.getAddressSpace()); + }); + + // + // Materializations + // + // This will be called when (newArgType != origArgType) + // This will create newArg, and map(origArg, newArg) + addArgumentMaterialization([&](OpBuilder &builder, + RankedTensorType tensorType, ValueRange inputs, + Location loc) -> Value { + llvm_unreachable("Argument rematerialization should not happen in Triton " + "-> TritonGPU conversion"); + return {}; + }); + + // If the origValue still has live user(s), use this to + // convert origValue to newValue + addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) -> Value { + llvm_unreachable("Source rematerialization should not happen in Triton -> " + "TritonGPU Conversion"); + return {}; + }); + + // This will be called when (desiredType != newOperandType) + // where, desiredType = typeConverter->convertType(origType) + // NOTE: only for remapped values. + addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) { + auto cast = + builder.create(loc, tensorType, inputs); + return cast.getResult(); + }); +} + +// +// TritonGPUConversion +// +TritonGPUConversionTarget::TritonGPUConversionTarget( + MLIRContext &context, TritonGPUTypeConverter &typeConverter) + : ConversionTarget(context) { + // TODO: we should also verify ops of TritonGPUDialect + addLegalDialect(); + + // Some ops from SCF are illegal + addIllegalOp(); + + addDynamicallyLegalDialect( + [&](Operation *op) { + bool hasLegalRegions = true; + for (auto ®ion : op->getRegions()) { + hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); + } + if (hasLegalRegions && typeConverter.isLegal(op)) { + return true; + } + return false; + }); + + // We have requirements for the data layouts + addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { + Attribute aEncoding = + cast(dotOp.getA().getType()).getEncoding(); + Attribute bEncoding = + cast(dotOp.getB().getType()).getEncoding(); + if (aEncoding && isa(aEncoding) && + bEncoding && isa(bEncoding)) + return true; + return false; + }); +} diff --git a/third_party/enflame/include/triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/third_party/enflame/include/triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp new file mode 100644 index 000000000..39ed637dd --- /dev/null +++ b/third_party/enflame/include/triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -0,0 +1,903 @@ +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "llvm/ADT/APSInt.h" +#include + +#define GEN_PASS_CLASSES +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +// #include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// pass named attrs (e.g., tt.contiguity) from Triton to Triton +static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { + for (const NamedAttribute attr : dictAttrs.getValue()) + if (!op->hasAttr(attr.getName())) + op->setAttr(attr.getName(), attr.getValue()); +} + +template struct GenericOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector retTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + retTypes))) + return failure(); + rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), + op->getAttrs()); + + return success(); + } +}; + +class ArithConstantPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = getTypeConverter()->convertType(op.getType()); + auto retShapedType = cast(retType); + auto value = dyn_cast(adaptor.getValue()); + if (isa(retShapedType)) { + assert(value && "expected a dense elements attribute"); + // This is a hack. We just want to add encoding. + value = value.reshape(retShapedType); + } + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retShapedType, value), + adaptor.getAttributes()); + return success(); + } +}; + +void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + // -------------- + // Add legality and rewrite pattern rules for operations + // from the Arith dialect. The basic premise is that + // Arith operations require both inputs to have the same + // non-null encoding + // -------------- + MLIRContext *context = patterns.getContext(); + // TODO: there's probably a better way to avoid adding all ops one-by-one + patterns.add< + ArithConstantPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, // NegFOp + // Floating point + GenericOpPattern, GenericOpPattern, + // MaxMin + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + // Floating point + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + // Cmp + GenericOpPattern, GenericOpPattern, + // Select + GenericOpPattern, + // Cast Ops + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern>(typeConverter, context); +} + +void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + // Rewrite rule + patterns.add, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern>( + typeConverter, context); +} + +// +// Triton patterns +// +struct TritonExpandDimsPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Type retType = op.getType()); + RankedTensorType argType = + cast(adaptor.getSrc().getType()); + Attribute _argEncoding = argType.getEncoding(); + if (!_argEncoding) + return failure(); + auto argEncoding = cast(_argEncoding); + // return shape + auto retShape = argType.getShape().vec(); + retShape.insert(retShape.begin() + op.getAxis(), 1); + // return encoding + auto retSizePerThread = llvm::to_vector(argEncoding.getSizePerThread()); + retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1); + auto retThreadsPerWarp = argEncoding.getThreadsPerWarp(); + retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1); + auto retWarpsPerCTA = argEncoding.getWarpsPerCTA(); + retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1); + SmallVector retOrder(retShape.size()); + std::iota(retOrder.begin(), retOrder.end(), 0); + + auto argCTALayout = argEncoding.getCTALayout(); + auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis()); + auto retCTASplitNum = + insertOne(argCTALayout.getCTASplitNum(), op.getAxis()); + auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis()); + auto retCTALayout = triton::gpu::CTALayoutAttr::get( + getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder); + + triton::gpu::BlockedEncodingAttr retEncoding = + triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread, + retThreadsPerWarp, retWarpsPerCTA, + retOrder, retCTALayout); + // convert operand to slice of return type + Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get( + getContext(), op.getAxis(), retEncoding); + RankedTensorType newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), newArgEncoding); + // construct new op + auto newSrc = rewriter.create( + op.getLoc(), newArgType, adaptor.getSrc()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newSrc, adaptor.getAxis()), + adaptor.getAttributes()); + return success(); + } + +private: + template + SmallVector insertOne(ArrayRef vec, unsigned axis) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + axis, 1); + return res; + } + + // Example: order = [ 0, 2, 1, 3], dim = 2 + // resOrder = [2, 0, 3, 1, 4] + SmallVector insertOrder(ArrayRef order, + unsigned axis) const { + SmallVector resOrder(order.begin(), order.end()); + for (unsigned i = 0; i < resOrder.size(); ++i) + if (resOrder[i] >= axis) + ++resOrder[i]; + resOrder.insert(resOrder.begin(), axis); + return resOrder; + } +}; + +struct TritonDotPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType origType = op.getType(); + auto origShape = origType.getShape(); + auto typeConverter = getTypeConverter(); + int numWarps = typeConverter->getNumWarps(); + int threadsPerWarp = typeConverter->getThreadsPerWarp(); + int numCTAs = typeConverter->getNumCTAs(); + auto rank = origShape.size(); + SmallVector retSizePerThread(rank, 1); + auto numElements = product(origShape); + if (numElements / (numWarps * threadsPerWarp) >= 4) { + retSizePerThread[rank - 1] = 2; + retSizePerThread[rank - 2] = 2; + } + if (numElements / (numWarps * threadsPerWarp) >= 16) { + retSizePerThread[rank - 1] = 4; + retSizePerThread[rank - 2] = 4; + } + retSizePerThread[rank - 1] = std::min( + retSizePerThread[rank - 1], static_cast(origShape[rank - 1])); + retSizePerThread[rank - 2] = std::min( + retSizePerThread[rank - 2], static_cast(origShape[rank - 2])); + + SmallVector retOrder(rank); + for (unsigned i = 0; i < rank; ++i) + retOrder[i] = rank - 1 - i; + Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get( + getContext(), origShape, retSizePerThread, retOrder, numWarps, + threadsPerWarp, numCTAs); + RankedTensorType retType = + RankedTensorType::get(origShape, origType.getElementType(), dEncoding); + // a & b must be of smem layout + auto aType = cast(adaptor.getA().getType()); + auto bType = cast(adaptor.getB().getType()); + Type aEltType = aType.getElementType(); + Type bEltType = bType.getElementType(); + Attribute aEncoding = aType.getEncoding(); + Attribute bEncoding = bType.getEncoding(); + if (!aEncoding || !bEncoding) + return failure(); + Value a = adaptor.getA(); + Value b = adaptor.getB(); + Value c = adaptor.getC(); + if (!mlir::isa(aEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 0, dEncoding, aEltType); + auto dstType = + RankedTensorType::get(aType.getShape(), aEltType, encoding); + a = rewriter.create(a.getLoc(), dstType, a); + } + if (!mlir::isa(bEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 1, dEncoding, bEltType); + auto dstType = + RankedTensorType::get(bType.getShape(), bEltType, encoding); + b = rewriter.create(b.getLoc(), dstType, b); + } + c = rewriter.create(c.getLoc(), retType, c); + + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, a, b, c, adaptor.getInputPrecision(), + adaptor.getMaxNumImpreciseAcc()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonCatPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The cat op satisfy two conditions: + // 1. output.numel = lhs.numel + rhs.numel + // 2. output.total_elems_per_thread = + // next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread) + // For now, this behaves like generic, but this + // will evolve when we add support for `can_reorder=False`. + auto retType = cast( + this->getTypeConverter()->convertType(op.getType())); + auto retEncoding = + cast(retType.getEncoding()); + auto lhsType = adaptor.getLhs().getType(); + auto rhsType = adaptor.getRhs().getType(); + auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType); + auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType); + auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType); + auto retShape = retType.getShape(); + auto retOrder = retEncoding.getOrder(); + auto retThreadsPerWarp = retEncoding.getThreadsPerWarp(); + auto retWarpsPerCTA = retEncoding.getWarpsPerCTA(); + // Get new retSizePerThread if ret elems per thread is not enough. + // We have to round it up to the next power of 2 due to triton's tensor size + // constraint. + auto newRetTotalElemsPerThread = + nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread); + auto newRetSizePerThread = llvm::to_vector(retEncoding.getSizePerThread()); + newRetSizePerThread[retOrder[0]] *= + newRetTotalElemsPerThread / retTotalElemsPerThread; + triton::gpu::BlockedEncodingAttr newRetEncoding = + triton::gpu::BlockedEncodingAttr::get( + getContext(), newRetSizePerThread, retThreadsPerWarp, + retWarpsPerCTA, retOrder, retEncoding.getCTALayout()); + auto newRetType = RankedTensorType::get(retShape, retType.getElementType(), + newRetEncoding); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newRetType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonJoinOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Simply rely on type inference for this op. (Notably, GenericOpPattern + // does not do this, instead it assigns the default layout to the ins and + // outs.) + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, adaptor.getLhs(), adaptor.getRhs()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonSplitOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = dyn_cast(srcTy.getEncoding()); + int rank = srcEnc.getOrder().size(); + auto typeConverter = getTypeConverter(); + + // The operand to split must have: + // - a blocked layout, with + // - sizePerThread = 2 in the last dimension, + // - threadsPerWarp, warpsPerCTA, and CTAsPerCGA = 1 in the last dim, and + // - the last dimension minor. + // If that's not the case, add a convert before the split. + if (!srcEnc || srcEnc.getSizePerThread().back() != 2 || + srcEnc.getOrder().front() != rank - 1) { + // If we take the default encoding for the op's result (i.e. post-split) + // and add 1 to the end of each dim, that gives us what we want. Other + // than making a legal src encoding, our choice of layout doesn't matter; + // it'll get fixed by RemoveLayoutConversions. + auto defaultEnc = getDefaultBlockedEncoding( + getContext(), + cast(op.getResult(0).getType()).getShape(), + typeConverter->getNumWarps(), typeConverter->getThreadsPerWarp(), + typeConverter->getNumCTAs()); + + auto append = [&](ArrayRef vals, unsigned val) { + SmallVector res(vals); + res.push_back(val); + return res; + }; + auto prepend = [&](ArrayRef vals, unsigned val) { + SmallVector res; + res.push_back(val); + res.append(vals.begin(), vals.end()); + return res; + }; + + srcEnc = BlockedEncodingAttr::get( + getContext(), append(defaultEnc.getSizePerThread(), 2), + append(defaultEnc.getThreadsPerWarp(), 1), + append(defaultEnc.getWarpsPerCTA(), 1), + prepend(defaultEnc.getOrder(), rank - 1), + CTALayoutAttr::get(getContext(), + append(defaultEnc.getCTAsPerCGA(), 1), + append(defaultEnc.getCTASplitNum(), 1), + prepend(defaultEnc.getCTAOrder(), rank - 1))); + srcTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), + srcEnc); + src = rewriter.create(op.getLoc(), srcTy, src); + } + + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src), + adaptor.getAttributes()); + return success(); + } +}; + +// This function returns the layout to use for gather/scatter indices. The +// `gather4` and `scatter4` TMA instructions require 4 consecutive indices. +// Thus, threads issuing these instructions must have all 4 index elements +// available. +static RankedTensorType getNewIndicesType(RankedTensorType type) { + assert(type.getRank() == 1); + auto enc = cast(type.getEncoding()); + + // Technically any layout where we have a pack of 4 neighbouring elements plus + // broadcasted over the warp dimension is okay but for now we just pick a + // layout. + unsigned numThreadsPerWarp = product(enc.getThreadsPerWarp()); + unsigned numWarps = product(enc.getWarpsPerCTA()); + std::array sizePerThread{1, 4}; + std::array threadsPerWarp = {numThreadsPerWarp, 1}; + std::array order = {1, 0}; + std::array warpsPerCta = {1, static_cast(numWarps)}; + + MLIRContext *ctx = type.getContext(); + auto ctaLayout = CTALayoutAttr::getDefault(ctx, /*rank=*/2); + auto parentEncoding = BlockedEncodingAttr::get( + ctx, sizePerThread, threadsPerWarp, warpsPerCta, order, ctaLayout); + auto newEncoding = SliceEncodingAttr::get(ctx, /*dim=*/0, parentEncoding); + if (enc == newEncoding) + return {}; + + return RankedTensorType::get(type.getShape(), type.getElementType(), + newEncoding); +} + +struct TritonDescriptorGatherPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExperimentalDescriptorGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType newType = getNewIndicesType( + cast(adaptor.getXOffsets().getType())); + if (!newType) + return failure(); + + Value newInd = rewriter.create(op.getLoc(), newType, + adaptor.getXOffsets()); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), adaptor.getDesc(), + newInd, adaptor.getYOffset()); + return success(); + } +}; + +struct TritonDescriptorScatterPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExperimentalDescriptorScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType newType = getNewIndicesType( + cast(adaptor.getXOffsets().getType())); + if (!newType) + return failure(); + + Value newInd = rewriter.create(op.getLoc(), newType, + adaptor.getXOffsets()); + rewriter.replaceOpWithNewOp( + op, adaptor.getDesc(), newInd, adaptor.getYOffset(), adaptor.getSrc()); + return success(); + } +}; + +struct TritonTransPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = srcTy.getEncoding(); + if (!srcEnc) + return failure(); + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src, op.getOrder()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonBroadcastPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // This creates a tensor with the new shape but the argument's layout + LogicalResult + matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(adaptor.getSrc().getType()); + auto srcEncoding = srcType.getEncoding(); + if (!srcEncoding) + return failure(); + Type retType = RankedTensorType::get( + op.getType().getShape(), op.getType().getElementType(), srcEncoding); + // Type retType = this->getTypeConverter()->convertType(op.getType()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonReducePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newReduce = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis()); + addNamedAttrs(newReduce, adaptor.getAttributes()); + + auto &newCombineOp = newReduce.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newReduce.getResult()); + return success(); + } +}; + +struct TritonScanPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newScan = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis(), op.getReverse()); + addNamedAttrs(newScan, adaptor.getAttributes()); + + auto &newCombineOp = newScan.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newScan.getResult()); + return success(); + } +}; + +class TritonFuncOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getName(), op.getFunctionType()); + addNamedAttrs(newOp, adaptor.getAttributes()); + rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(), + newOp.getBody().end()); + if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter))) + return failure(); + + return success(); + } +}; + +class TritonCallOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getCallee(), op.getResultTypes(), adaptor.getOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + return success(); + } +}; + +class TritonReturnOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, unsigned numCTAs) { + MLIRContext *context = patterns.getContext(); + patterns.insert< // TODO: view should have custom pattern that views the + // layout + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + TritonBroadcastPattern, GenericOpPattern, + TritonCatPattern, TritonJoinOpPattern, TritonSplitOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, TritonReducePattern, + GenericOpPattern, TritonScanPattern, + GenericOpPattern, + GenericOpPattern, TritonExpandDimsPattern, + TritonTransPattern, TritonDotPattern, TritonDescriptorGatherPattern, + TritonDescriptorScatterPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + // this assumes the right layout will be set later for dot scaled. + GenericOpPattern, GenericOpPattern, + TritonFuncOpPattern>(typeConverter, context); +} +// Proton patterns +// NOTE: Because Proton's inputs are scalars and not tensors this conversion +// isn't strictly nessessary however you could envision a case where we pass in +// tensors in for Triton object specific tracing operations in which case we +// would need to fill in the OpConversionPattern +void populateProtonPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + // patterns.add>(typeConverter, + // context); +} +// +// SCF patterns +// +// This is borrowed from ConvertForOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +struct SCFForPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + // Ref: ConvertForOpTypes + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Now, update all the types. + + // Convert the types of block arguments within the given region. This + // replaces each block with a new block containing the updated signature. + // The entry block may have a special conversion if `entryConversion` is + // provided. On success, the new entry block to the region is returned for + // convenience. Otherwise, failure is returned. + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), + *getTypeConverter()))) { + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + // Change the clone to use the updated operands. We could have cloned with + // a IRMapping, but this seems a bit more direct. + newOp->setOperands(adaptor.getOperands()); + // Update the result types to the new converted types. + SmallVector newResultTypes; + for (Type type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + + rewriter.replaceOp(op, newOp.getResults()); + + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFIfPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: Generalize this to any type conversion, not just 1:1. + // + // We need to implement something more sophisticated here that tracks which + // types convert to which other types and does the appropriate + // materialization logic. + // For example, it's possible that one result type converts to 0 types and + // another to 2 types, so newResultTypes would at least be the right size to + // not crash in the llvm::zip call below, but then we would set the the + // wrong type on the SSA values! These edge cases are also why we cannot + // safely use the TypeConverter::convertTypes helper here. + SmallVector newResultTypes; + for (auto type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + // See comments in the ForOp pattern for why we clone without regions and + // then inline. + scf::IfOp newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), + newOp.getThenRegion().end()); + rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), + newOp.getElseRegion().end()); + + // Update the operands and types. + newOp->setOperands(adaptor.getOperands()); + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFWhilePattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + assert(converter); + SmallVector newResultTypes; + if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) + return failure(); + + auto newOp = rewriter.create(op.getLoc(), newResultTypes, + adaptor.getOperands()); + for (auto i : {0u, 1u}) { + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + if (failed(rewriter.convertRegionTypes(&dstRegion, *converter))) + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +class SCFConditionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.modifyOpInPlace(op, + [&]() { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add, SCFForPattern, SCFIfPattern, + SCFWhilePattern, SCFConditionPattern>(typeConverter, context); +} + +// CF + +class CFBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getSuccessor(), adaptor.getOperands()); + if (failed(rewriter.convertRegionTypes(newOp.getSuccessor()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +class CFCondBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + + if (failed(rewriter.convertRegionTypes(newOp.getTrueDest()->getParent(), + *converter))) + return failure(); + if (failed(rewriter.convertRegionTypes(newOp.getFalseDest()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +void populateCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add(typeConverter, context); +} +// + +class ConvertTritonToTritonGPU + : public ConvertTritonToTritonGPUBase { +public: + ConvertTritonToTritonGPU() = default; + // constructor with some parameters set explicitly. + ConvertTritonToTritonGPU(const std::string &target, int numWarps, + int threadsPerWarp, int numCTAs) { + this->numWarps = numWarps; + this->threadsPerWarp = threadsPerWarp; + this->numCTAs = numCTAs; + this->target = target; + } + + void runOnOperation() override { + if (target.getValue().empty()) { + mlir::emitError( + getOperation().getLoc(), + "'convert-triton-to-tritongpu' requires 'target' option to be set"); + return signalPassFailure(); + } + + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + // type converter + TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp, + numCTAs); + TritonGPUConversionTarget target(*context, typeConverter); + // rewrite patterns + RewritePatternSet patterns(context); + // add rules + populateArithPatternsAndLegality(typeConverter, patterns, target); + populateMathPatternsAndLegality(typeConverter, patterns, target); + populateTritonPatterns(typeConverter, patterns, numCTAs); + populateProtonPatterns(typeConverter, patterns); + // TODO: can we use + // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? + populateSCFPatterns(typeConverter, patterns); + populateCFPatterns(typeConverter, patterns); + patterns.insert>(typeConverter, context); + + auto inti = llvm::APSInt(32, false); + + Builder b(&getContext()); + mod->setAttr(AttrNumWarpsName, b.getI32IntegerAttr(numWarps)); + mod->setAttr(AttrNumThreadsPerWarp, b.getI32IntegerAttr(threadsPerWarp)); + mod->setAttr(AttrNumCTAsName, b.getI32IntegerAttr(numCTAs)); + mod->setAttr(AttrTargetName, b.getStringAttr(this->target.getValue())); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + + // update layouts + // broadcast src => multicast, dst => broadcasted + // if (failed(target.refineLayouts(mod, numWarps))) + // return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonGPUPass(const std::string &target, + int numWarps, + int threadsPerWarp, + int numCTAs) { + return std::make_unique<::ConvertTritonToTritonGPU>(target, numWarps, + threadsPerWarp, numCTAs); +} + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonGPUPass() { + return std::make_unique<::ConvertTritonToTritonGPU>(); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/CMakeLists.txt b/third_party/enflame/include/triton/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..6ef40db00 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Triton) +add_subdirectory(TritonGPU) +add_subdirectory(TritonNvidiaGPU) diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/CMakeLists.txt b/third_party/enflame/include/triton/lib/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/IR/CMakeLists.txt b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..788774a7c --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,21 @@ +set(LLVM_TARGET_DEFINITIONS Canonicalize.td) +mlir_tablegen(TritonCanonicalize.inc -gen-rewriters) +add_public_tablegen_target(TritonCanonicalizeIncGen) + +add_triton_library(TritonIR + Dialect.cpp + Ops.cpp + Traits.cpp + Types.cpp + OpInterfaces.cpp + + DEPENDS + TritonTableGen + TritonCanonicalizeIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithDialect + MLIRMathDialect + MLIRSCFDialect +) diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Canonicalize.td b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Canonicalize.td new file mode 100644 index 000000000..dc3771033 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Canonicalize.td @@ -0,0 +1,17 @@ +#ifndef TT_PATTERNS +#define TT_PATTERNS + +include "mlir/IR/PatternBase.td" +include "triton/Dialect/Triton/IR/TritonOps.td" + +// broadcast(splat(x)) -> splat(x) +def BroadcastSplatPattern : + Pat<(TT_BroadcastOp (TT_SplatOp $x)), + (TT_SplatOp $x)>; + +// broadcast(broadcast(x)) -> broadcast(x) +def BroadcastBroadcastPattern : + Pat<(TT_BroadcastOp (TT_BroadcastOp $x)), + (TT_BroadcastOp $x)>; + +#endif diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Dialect.cpp b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Dialect.cpp new file mode 100644 index 000000000..2874a3f56 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Dialect.cpp @@ -0,0 +1,98 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/DialectImplementation.h" + +#include "mlir/Transforms/InliningUtils.h" +#include "triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc" +#include "triton/Dialect/Triton/IR/Dialect.cpp.inc" +#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; + +//===----------------------------------------------------------------------===// +// TritonDialect Dialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { +struct TritonInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + auto funcOp = dyn_cast(callable); + if (!funcOp) + return true; + if (funcOp->hasAttr("noinline")) + return !funcOp->getAttrOfType("noinline").getValue(); + return true; + } + + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + + bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, + IRMapping &) const final { + return true; + } + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, Block *newDest) const final { + // Only return needs to be handled here. + auto returnOp = dyn_cast(op); + if (!returnOp) + return; + + // Replace the return with a branch to the dest. + OpBuilder builder(op); + builder.create(op->getLoc(), newDest, + returnOp.getOperands()); + op->erase(); + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { + // Only return needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } +}; + +} // namespace + +void TritonDialect::initialize() { + registerTypes(); + + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + >(); + + // We can also add interface here. + addInterfaces(); +} + +Operation *TritonDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/IR/OpInterfaces.cpp b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/OpInterfaces.cpp new file mode 100644 index 000000000..cc7792a08 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/OpInterfaces.cpp @@ -0,0 +1,78 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LogicalResult.h" + +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir { +namespace triton { +namespace impl { + +LogicalResult verifyTransposeOpInterface(Operation *op) { + TransposeOpInterface transposeOp = cast(op); + auto rank = cast(transposeOp.getSrc().getType()).getRank(); + auto order = transposeOp.getOrder(); + if (rank != order.size()) { + return op->emitError( + "order must have the same size as the rank of the operand and result"); + } + + SmallVector sortedOrder(order); + llvm::sort(sortedOrder); + for (int32_t i = 0; i < sortedOrder.size(); i++) { + if (sortedOrder[i] != i) { + return op->emitError("order must be a permutation of [0, ..., rank - 1]"); + } + } + + return success(); +} + +// A DotOpInterface operation should have at least three operands. +// The first two operands should share a common dimension, and the result +// should have the dimensions of the two operands that are not shared. +// A DotOpInterface operation can be either 2d or 3d. +// In the 3d case, the first dimension of operands is the batch dimension. +LogicalResult verifyDotOpInterface(Operation *op) { + DotOpInterface dotOp = cast(op); + + if (dotOp->getNumOperands() < 3) + return dotOp->emitOpError("expected at least 3 operands"); + auto aTy = cast(dotOp->getOperand(0).getType()); + auto bTy = cast(dotOp->getOperand(1).getType()); + auto cTy = cast(dotOp->getOperand(2).getType()); + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + auto cShape = cTy.getShape(); + // Check if all 3d or all 2d + if (aShape.size() != 2 && aShape.size() != 3) + return dotOp->emitOpError("expected operands to be 2d or 3d"); + if (aShape.size() != bShape.size() || aShape.size() != cShape.size()) + return dotOp->emitOpError("expected all operands to have the same rank"); + + // Check for valid A, B input shapes for dot + if (!dotOp.verifyDims()) + return dotOp->emitOpError( + "expected the last dimension of the first operand " + "to be equal to the second-to-last dimension of " + "the second operand"); + + // Check the batch dimension + if (aShape.size() == 3 && (aShape[0] != cShape[0] || bShape[0] != cShape[0])) + return dotOp->emitOpError("expected the first dimension of the first " + "operand to be equal to the first dimension of " + "the result"); + // Check the output shape + if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] || + cShape[cShape.size() - 1] != bShape[aShape.size() - 1]) + return dotOp->emitOpError( + "expected the output shape to be the concatenation of the last " + "dimension of the first operand and the last dimension of the " + "second "); + return success(); +} + +} // namespace impl +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Ops.cpp b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Ops.cpp new file mode 100644 index 000000000..70e8811f3 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Ops.cpp @@ -0,0 +1,1294 @@ +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/Support/ErrorHandling.h" + +namespace mlir { +namespace triton { + +void LoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable(), + triton::GlobalMemory::get()); + if (getIsVolatile()) + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace triton +} // namespace mlir + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc" + +#include "TritonCanonicalize.inc" + +namespace mlir { +namespace triton { + +//-- LoadOp -- +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + CacheModifier cache, EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, + cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, + padding, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, CacheModifier cache, EvictionPolicy evict, + bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, other, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + auto paddingAttr = + padding.has_value() + ? PaddingOptionAttr::get(builder.getContext(), padding.value()) + : PaddingOptionAttr(); + LoadOp::build(builder, state, ptr, mask, other, + builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache, + evict, isVolatile); +} + +// load(ptr, splat(1), ...) -> load(ptr, ...) +// load(ptr, splat(0), other, ...) -> other +struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { + CanonicalizeMaskedLoadPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override { + auto mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = mask.getDefiningOp(); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + } else { + // mask = splat(0) + + // If there's no "other", the value is "undef". Perhaps we want to + // optimize it in the future.x + auto otherVal = loadOp.getOther(); + if (!otherVal) + return failure(); + rewriter.replaceOp(loadOp, otherVal); + } + return success(); + } +}; + +void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- StoreOp -- +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + /*boundaryCheck=*/{}, cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, Value mask, CacheModifier cache, + EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{}, + cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, ArrayRef boundaryCheck, + CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + builder.getDenseI32ArrayAttr(boundaryCheck), cache, + evict); +} + +// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) +// store(ptr, value, splat(0), ...) -> [none] +struct CanonicalizeMaskedStorePattern : public OpRewritePattern { + CanonicalizeMaskedStorePattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override { + auto mask = storeOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = mask.getDefiningOp(); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), + storeOp.getEvict()); + } else { + // mask = splat(0) + rewriter.eraseOp(storeOp); + } + return success(); + } +}; + +void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- TransOp -- +OpFoldResult TransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + // If the source and result types are the same, we can return the source + // If their layout is different (even if structurally equivalent), we need + // to insert a convert_layout in between as otherwise ::fold complains + // We do this in CanonicalizeConvertFromTranspose + if (getSrc().getType() == getType()) { + return getSrc(); + } + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + // Eliminate splat constant transpose ops. + if (auto attr = + llvm::dyn_cast_if_present(adaptor.getSrc())) + return attr.reshape(getType()); + + return {}; +} + +LogicalResult TransOp::verify() { + auto order = getOrder(); + auto srcTy = cast(getSrc().getType()); + if (order.size() != srcTy.getShape().size()) { + return emitError("order must have the same size as the source tensor"); + } + if (!isPermutationOfIota(order)) { + return emitError("order must be a permutation of 0..n-1"); + } + SmallVector retShape = applyPermutation(srcTy.getShape(), order); + if (retShape != getType().getShape()) { + return emitError( + "result shape must match the permutation of the source shape"); + } + return success(); +} + +LogicalResult TransOp::inferReturnTypes( + MLIRContext *context, std::optional location, + TransOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + + // type is the same as the input + auto argTy = cast(adaptor.getSrc().getType()); + auto shape = argTy.getShape(); + auto order = adaptor.getOrder(); + SmallVector retShape = applyPermutation(shape, order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (inferLayoutInterface + ->inferTransOpEncoding(argEncoding, shape, order, retEncoding) + .failed()) { + return failure(); + } + } + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + return success(); +} + +//-- DotOp -- +LogicalResult +DotOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc && retEnc); + Dialect &dialect = retEnc.getDialect(); + auto interface = cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult DotOp::verify() { + auto aTy = getA().getType(); + auto bTy = getB().getType(); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); + auto aEncoding = aTy.getEncoding(); + auto bEncoding = bTy.getEncoding(); + if (!aEncoding && !bEncoding) + return success(); + // Verify that the encodings are valid. + if (!aEncoding || !bEncoding) + return emitError("mismatching encoding between A and B operands"); + auto accTy = getC().getType(); + auto retEnc = accTy.getEncoding(); + if (!retEnc) + return emitError("miss encoding of C operand"); + Dialect &dialect = retEnc.getDialect(); + auto interface = cast(&dialect); + return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, + bEncoding); +} + +bool DotOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + return aShape[aShape.size() - 1] == bShape[aShape.size() - 2]; +} + +//-- DotScaledOp -- +bool DotScaledOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + auto aKdim = aShape[aShape.size() - 1]; + auto bKdim = bShape[aShape.size() - 2]; + if (this->getAElemType() == ScaleDotElemType::E2M1) + aKdim *= 2; + if (this->getBElemType() == ScaleDotElemType::E2M1) + bKdim *= 2; + + return aKdim == bKdim; +} + +//-- MakeRangeOp -- +OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { + // make_range(start, start + 1) -> constant(start) + if (adaptor.getStart() + 1 == adaptor.getEnd()) { + auto shapedType = cast(getType()); + return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); + } + return {}; +} + +LogicalResult MakeRangeOp::verify() { + int64_t start = getStartAttr().getInt(); + int64_t end = getEndAttr().getInt(); + if (start > end) { + return this->emitOpError() << "start must be less than or equal to end"; + } + auto ty = getType(); + if (ty.getShape().size() != 1) { + return this->emitOpError() << "return type must be a 1D tensor"; + } + if (end - start != ty.getShape()[0]) { + return this->emitOpError() + << "number of elements in returned tensor, " << ty.getShape()[0] + << ", must match size of range [" << start << ", " << end + << "), which has " << end - start << " elements"; + } + if (!ty.getElementType().isInteger(32)) { + return this->emitOpError() << "returned tensor must have i32 elements"; + } + return success(); +} + +//-- ReduceOp -- +static LogicalResult +inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis, + SmallVectorImpl &inferredReturnTypes) { + auto retShape = argTy.getShape().vec(); + retShape.erase(retShape.begin() + axis); + if (retShape.empty()) { + // 0d-tensor -> scalar + inferredReturnTypes.push_back(retEltTy); + } else { + // nd-tensor where n >= 1 + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (inferLayoutInterface + ->inferReduceOpEncoding(argEncoding, axis, retEncoding) + .failed()) { + llvm::report_fatal_error("failed to infer layout for ReduceOp"); + return failure(); + } + } + // create type + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +void ReduceOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis) { + SmallVector inferredReturnTypes; + for (unsigned i = 0; i < operands.size(); ++i) { + auto argTy = cast(operands[i].getType()); + auto retEltTy = argTy.getElementType(); + (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); + } + + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); +} + +LogicalResult ReduceOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + for (auto arg : operands) { + auto argTy = cast(arg.getType()); + auto retEltTy = argTy.getElementType(); + if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes) + .failed()) { + return failure(); + } + } + return success(); +} + +// Helpers for Reductions and Scans +template LogicalResult verifyReduceScan(Op &op) { + if (op.getOperands().empty()) { + return op.emitOpError() << "must have at least 1 operand"; + } + if (op.getNumOperands() != op.getNumResults()) { + return op.emitOpError() << "must have the same number of inputs as outputs"; + } + + auto getElementType = [](Type ty) { + if (auto tensorType = dyn_cast(ty)) { + return tensorType.getElementType(); + } + return ty; + }; + + for (auto [opElemTy, resTy] : + llvm::zip(op.getElementTypes(), op.getResultTypes())) { + if (opElemTy != getElementType(resTy)) { + return op.emitOpError() << "operand types and result types must agree"; + } + } + return success(); +} + +template +static LogicalResult verifyRegionsImpl(Op &op) { + auto argElementTypes = op.getElementTypes(); + const auto &operands = op.getOperands(); + const auto numArgs = 2 * operands.size(); + auto &block = *op.getBody(); + if (block.getNumArguments() != numArgs) { + return op.emitOpError() << "nested block must take " << numArgs + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; + } + unsigned i = 0; + const auto &blockArgTypes = block.getArgumentTypes(); + for (unsigned i = 0; i < numArgs; ++i) { + const auto &blockArgTy = blockArgTypes[i]; + const auto &argElemTy = argElementTypes[i % operands.size()]; + if (blockArgTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; + } + } + + auto terminator = dyn_cast(block.getTerminator()); + if (!terminator) { + return op.emitOpError() + << "combine operation must be terminated " + << "with a ReduceReturnOp but got " << block.getTerminator(); + } + const auto &combineResults = terminator->getOperands(); + if (combineResults.size() != operands.size()) { + return op.emitOpError() + << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); + } + for (unsigned i = 0; i < combineResults.size(); ++i) { + const auto &resultTy = combineResults[i].getType(); + const auto &argElemTy = argElementTypes[i]; + if (resultTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; + } + } + return success(); +} + +static llvm::SmallVector +getInputTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcTys; + srcTys.reserve(operands.size()); + for (const auto &ty : operands.getTypes()) { + srcTys.push_back(cast(ty)); + } + return srcTys; +} + +static llvm::SmallVector +getElementTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcElemTys; + srcElemTys.reserve(operands.size()); + for (const auto &op : operands) { + srcElemTys.push_back(cast(op.getType()).getElementType()); + } + return srcElemTys; +} + +LogicalResult ReduceOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ReduceOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ReduceOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ReduceOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +::mlir::Operation *ReduceOp::getSingleCombiner() { + if (getNumOperands() != 1 || getNumResults() != 1) + return nullptr; + Block *block = &(*getCombineOp().begin()); + Operation *yield = block->getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) + return nullptr; + if (reduceOp->getOperand(0) != block->getArgument(0) || + reduceOp->getOperand(1) != block->getArgument(1)) + return nullptr; + + return reduceOp; +} + +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } + +//-- ScanOp -- +void ScanOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis, bool reverse) { + SmallVector inferredReturnTypes; + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + ScanOp::build(builder, state, inferredReturnTypes, operands, axis, reverse); +} + +LogicalResult +ScanOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + return success(); +} + +LogicalResult ScanOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ScanOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ScanOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ScanOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ScanOp::getNumOperands() { return this->getOperands().size(); } + +//-- SplatOp -- +OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { + auto value = adaptor.getSrc(); + if (!value) + return {}; + if (!isa(value)) + return {}; + auto shapedType = cast(getType()); + auto ret = SplatElementsAttr::get(shapedType, ArrayRef(value)); + return ret; +} + +//-- ExpandDimsOp -- +LogicalResult ExpandDimsOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // infer shape + auto arg = operands[0]; + auto argTy = cast(arg.getType()); + auto retShape = argTy.getShape().vec(); + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + retShape.insert(retShape.begin() + axis, 1); + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (inferLayoutInterface + ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc) + .failed()) + return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp"); + } + // create type + auto argEltTy = argTy.getElementType(); + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, argEltTy, retEncoding)); + return success(); +} + +LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + // expand_dims(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + // expand_dims(broadcast(x)) -> broadcast(expand_dims(x)) + // + // On its own this doesn't do much, but consider + // broadcast(expand_dims(broadcast)) + // -> broadcast(broadcast(expand_dims)) + // -> broadcast(expand_dims) + if (auto broadcast = dyn_cast(definingOp)) { + auto src = broadcast.getSrc(); + auto srcTy = src.getType(); + SmallVector newExpandShape(srcTy.getShape()); + newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1); + + // Infer the encoding of the new expand op, if encodings are present. + Attribute newExpandEnc; + if (auto srcEnc = srcTy.getEncoding()) { + if (cast(&srcEnc.getDialect()) + ->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc, + op.getLoc()) + .failed()) { + return emitOptionalError(op.getLoc(), + "failed to infer layout for ExpandDimsOp"); + } + } + + auto newExpandTy = RankedTensorType::get( + newExpandShape, srcTy.getElementType(), newExpandEnc); + auto newExpand = rewriter.create(op.getLoc(), newExpandTy, + src, op.getAxis()); + auto newBroadcast = rewriter.create( + broadcast.getLoc(), op.getType(), newExpand.getResult()); + rewriter.replaceOp(op, {newBroadcast.getResult()}); + return success(); + } + + return failure(); +} + +template +static OpFoldResult foldViewLikeOp(ViewLikeOp op, Attribute value) { + if (!value) + return {}; + + auto shapedType = cast(op.getType()); + if (auto denseElemsAttr = dyn_cast(value)) { + if (denseElemsAttr.isSplat()) { + return denseElemsAttr.resizeSplat(shapedType); + } else { + return denseElemsAttr.reshape(shapedType); + } + } + return {}; +} + +OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) { + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +//-- ReshapeOp -- + +void ReshapeOp::build(OpBuilder &builder, OperationState &state, + ArrayRef shape, + TypedValue src) { + auto srcTy = src.getType(); + auto srcEnc = srcTy.getEncoding(); + Attribute dstEnc; + if (srcEnc) { + auto result = cast(&srcEnc.getDialect()) + ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape, + dstEnc, state.location); + assert(succeeded(result)); + } + auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc); + build(builder, state, dstTy, src); +} + +LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { + if (op.getEfficientLayout()) + return failure(); + + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + + // reshape(reshape) -> reshape + if (auto parentReshape = dyn_cast(definingOp)) { + // Allow reorder if either reshape allowed it + const bool allowReorder = + (op.getAllowReorder() || parentReshape.getAllowReorder()); + rewriter.replaceOpWithNewOp(op, op.getType(), + parentReshape.getSrc(), allowReorder, + op.getEfficientLayout()); + return success(); + } + + // reshape(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + + return failure(); +} + +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType() && !getAllowReorder()) { + // no-op + return getSrc(); + } + + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +LogicalResult ReshapeOp::verify() { + auto dstTy = getType(); + auto srcTy = getSrc().getType(); + if (getType().getNumElements() != srcTy.getNumElements()) { + return emitError( + "number of src and dst elements of reshape must be the same"); + } + + Attribute srcEnc = srcTy.getEncoding(); + Attribute dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("Op requires that either (a) src and dst both have " + "encodings, or (b) neither does."); + } + + if (!srcEnc || getAllowReorder()) { + return success(); + } + + // Check that we can infer the dst encoding from the src encoding + // and that the inferred dst encoding is the same as the given dst encoding + Attribute inferredDstEnc; + auto result = + cast(&srcEnc.getDialect()) + ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(), + inferredDstEnc, getLoc()); + assert(succeeded(result)); + return cast(&srcEnc.getDialect()) + ->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc, + getLoc()); +} + +//-- FpToFpOp -- + +// Fold FpToFpOp when the input operand is a constant zero. +OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) { + auto srcVal = getSrc(); + auto dstTy = getType(); + // Fold trivial cast + if (srcVal.getType() == dstTy) { + return srcVal; + } + + auto resElemType = cast(getElementTypeOrSelf(getType())); + const llvm::fltSemantics &semantic = resElemType.getFloatSemantics(); + + if (matchPattern(srcVal, m_PosZeroFloat())) { + llvm::APFloat posZero = + llvm::APFloat::getZero(semantic, /*negative=*/false); + if (auto tensorTy = dyn_cast(dstTy)) + return DenseElementsAttr::get(tensorTy, posZero); + return Builder(getContext()).getFloatAttr(resElemType, posZero); + } + + if (matchPattern(srcVal, m_NegZeroFloat())) { + llvm::APFloat negZero = llvm::APFloat::getZero(semantic, /*negative=*/true); + if (auto tensorTy = dyn_cast(dstTy)) + return DenseElementsAttr::get(tensorTy, negZero); + return Builder(getContext()).getFloatAttr(resElemType, negZero); + } + + return {}; +} + +LogicalResult FpToFpOp::verify() { + auto dstType = getType(); + auto srcType = getSrc().getType(); + if (auto dstTensorType = dyn_cast(dstType)) + dstType = dstTensorType.getElementType(); + if (auto srcTensorType = dyn_cast(srcType)) + srcType = srcTensorType.getElementType(); + if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) && + (!getRounding().has_value())) { + return emitError("Rounding mode is required for FP downcast"); + } + return success(); +} + +//-- BroadcastOp -- +void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + auto value = adaptor.getSrc(); + if (!value) + return {}; + + if (auto denseElemsAttr = dyn_cast(value)) { + auto shapedType = cast(getType()); + return denseElemsAttr.resizeSplat(shapedType); + } + return {}; +} + +LogicalResult BroadcastOp::verify() { + auto src = getSrc(); + auto srcTensorType = cast(src.getType()); + auto srcShape = srcTensorType.getShape(); + auto result = getResult(); + auto resultTensorType = cast(result.getType()); + auto resultShape = resultTensorType.getShape(); + if (srcShape.size() != resultShape.size()) { + return emitError("rank of source must be same as rank of result"); + } + for (int i = 0; i < srcShape.size(); i++) { + if (srcShape[i] != 1 && srcShape[i] != resultShape[i]) { + return emitError("Different dimensions at index ") + << i << " between source and result. " + << "Broadcast requires the source dimension to be 1."; + } + } + return success(); +} + +//-- MakeTensorPtrOp -- +void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ValueRange offsets, ArrayRef tensorShape, + ArrayRef order) { + // Get pointer type from `base` + auto pointerType = cast(base.getType()); + assert(pointerType != nullptr); + + // Build type `tt.ptr>` + auto tensorType = RankedTensorType::get( + SmallVector(tensorShape.begin(), tensorShape.end()), + pointerType.getPointeeType()); + auto result = PointerType::get(tensorType, pointerType.getAddressSpace()); + + return build(builder, state, result, base, shape, strides, offsets, + builder.getDenseI32ArrayAttr(order)); +} + +//-- AddPtrOp -- +OpFoldResult AddPtrOp::fold(FoldAdaptor adaptor) { + // addptr(ptr, 0) -> ptr + if (matchPattern(adaptor.getOffset(), m_Zero())) { + return getPtr(); + } + return {}; +} + +//-- AdvanceOp -- +OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { + // advance(ptr, 0, 0) -> ptr + SmallVector rawOffsets = getOffsets(); + auto offsets = getConstantIntValues(rawOffsets); + if (!offsets.has_value()) + return {}; + for (int64_t offset : offsets.value()) + if (offset != 0) + return {}; + return getPtr(); +} + +//-- MakeTensorDescOp -- +void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ArrayRef blockShape) { + auto ptrTy = dyn_cast(base.getType()); + if (!ptrTy) { + llvm::report_fatal_error("Expected pointer type"); + } + auto elemTy = ptrTy.getPointeeType(); + + SmallVector blockShape64(blockShape); + auto blockTy = RankedTensorType::get(blockShape64, elemTy); + auto descTy = TensorDescType::get(builder.getContext(), blockTy); + return build(builder, state, descTy, base, shape, strides); +} + +// The following ops, including `call`, `func`, and `return` are copied and +// modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp +// We could revert it back once MLIR has a better inliner interface. +//-- FuncOp -- +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); + call_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(OpAsmPrinter &printer) { + function_interface_impl::printFunctionOp( + printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +// -- CallOp -- +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this).getProperties().callee; + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + + return success(); +} + +// -- ReturnOp -- +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); + + return success(); +} + +// -- JoinOp -- +LogicalResult +JoinOp::inferReturnTypes(MLIRContext *context, std::optional location, + JoinOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + auto srcTy = cast(adaptor.getLhs().getType()); + + SmallVector retShape(srcTy.getShape()); + retShape.push_back(2); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (cast(&srcEnc.getDialect()) + ->inferJoinOpEncoding(srcEnc, retEnc, srcTy.getShape(), location) + .failed()) { + return failure(); + } + } + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, srcTy.getElementType(), retEnc)); + return success(); +} + +// -- SplitOp -- +LogicalResult SplitOp::inferReturnTypes( + MLIRContext *context, std::optional location, + SplitOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + auto srcTy = cast(adaptor.getSrc().getType()); + auto srcShape = srcTy.getShape(); + + if (srcShape.empty() || srcShape.back() != 2) { + return emitOptionalError(location, + "last dimension of input tensor must be 2"); + } + ArrayRef retShape(srcShape.begin(), srcShape.end() - 1); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (cast(&srcEnc.getDialect()) + ->inferSplitOpEncoding(srcEnc, retEnc, srcTy.getShape(), location) + .failed()) { + return failure(); + } + } + auto retTy = RankedTensorType::get(retShape, srcTy.getElementType(), retEnc); + inferredReturnTypes.push_back(retTy); + inferredReturnTypes.push_back(retTy); + return success(); +} + +// -- ElementwiseInlineAsmOp -- +void ElementwiseInlineAsmOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +Speculation::Speculatability ElementwiseInlineAsmOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + +LogicalResult ElementwiseInlineAsmOp::verify() { + if (getNumOperands() >= 1) { + auto tensorType = dyn_cast(getOperand(0).getType()); + size_t numInputElems = tensorType ? tensorType.getNumElements() : 0; + if (numInputElems % this->getPackedElement() != 0) { + return emitError("number of input elements ") + << numInputElems + << " must be a multiple of the op's packed_element attribute, " + << getPackedElement(); + } + } + return success(); +} + +// -- ExternElementwiseOp -- +void ExternElementwiseOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + +// -- GatherOp -- +LogicalResult GatherOp::verify() { + RankedTensorType indicesTy = getIndices().getType(); + RankedTensorType srcTy = getSrc().getType(); + RankedTensorType resTy = getResult().getType(); + + if (indicesTy.getShape() != resTy.getShape()) { + return emitOpError("indices and output shapes must match"); + } + if (indicesTy.getEncoding() != resTy.getEncoding()) { + return emitOpError("indices and output encodings must match"); + } + if (srcTy.getElementType() != resTy.getElementType()) { + return emitOpError("input and output element types must match"); + } + if (srcTy.getRank() != indicesTy.getRank()) { + return emitOpError("input and indices ranks must match"); + } + if (getAxis() >= srcTy.getRank()) { + return emitOpError("gather dimension must be less than the input rank"); + } + for (int dim = 0; dim < indicesTy.getRank(); ++dim) { + if (dim == getAxis()) + continue; + if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) { + return emitOpError("indices dimension ") + << dim << " must match the corresponding input dimension"; + } + } + + return success(); +} + +LogicalResult GatherOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + GatherOpAdaptor adaptor(operands, attributes, properties, regions); + auto indicesType = cast(adaptor.getIndices().getType()); + auto srcType = cast(adaptor.getSrc().getType()); + + // Shape and encoding of the indices with the element type of the src. + inferredReturnTypes.push_back( + RankedTensorType::get(indicesType.getShape(), srcType.getElementType(), + indicesType.getEncoding())); + return success(); +} + +// -- ExperimentalDescriptorGatherOp +LogicalResult +ExperimentalDescriptorGatherOp::verifyResultType(Operation *op, + mlir::ShapedType type) { + if (type.getRank() != 2) + return op->emitOpError("result must be a 2D tensor, but got ") << type; + + // The swizzling of TMA accesses matches that of the MMAv3 shared memory + // layouts. However, these have minimum size requirements. + // TODO: We can support smaller gather sizes by padding the `local_alloc` this + // lowers to to the nearest minimum tile size. + if (unsigned rows = type.getShape()[0]; rows < 8) { + return op->emitOpError("gather must have at least 8 rows, but got ") + << rows; + } + + Type dtype = type.getElementType(); + if (dtype.getIntOrFloatBitWidth() > 32) + return op->emitOpError("TMA dtype cannot be greater than 32 bits"); + + unsigned minCols = 32 / dtype.getIntOrFloatBitWidth() * 8; + if (unsigned cols = type.getShape()[1]; cols < minCols) { + return op->emitOpError("gather of ") + << dtype << " must have at least " << minCols << " columns, but got " + << cols; + } + + return success(); +} + +LogicalResult ExperimentalDescriptorGatherOp::verify() { + RankedTensorType blockType = getDesc().getType().getBlockType(); + // Gather from `!tt.tensordesc>`. + if (blockType.getRank() != 2) + return emitOpError("block must be a 2D tensor, but got ") << blockType; + if (blockType.getShape()[0] != 1) + return emitOpError("block must have exactly 1 row, but got ") << blockType; + + // With x offsets `tensor`. + RankedTensorType indicesType = getXOffsets().getType(); + if (indicesType.getRank() != 1) + return emitOpError("x offsets must be a 1D tensor, but got ") + << indicesType; + + // Into `tensor`. + RankedTensorType resultType = getType(); + if (failed(verifyResultType(*this, resultType))) + return failure(); + + if (resultType.getShape()[0] != indicesType.getShape()[0]) { + return emitOpError("result tensor must have as many rows as indices (") + << indicesType.getShape()[0] << "), but got " << resultType; + } + if (resultType.getShape()[1] != blockType.getShape()[1]) { + return emitOpError("result tensor number of columns must match block (") + << blockType.getShape()[1] << "), but got " << resultType; + } + if (resultType.getElementType() != blockType.getElementType()) { + return emitOpError("result tensor element type must match block (") + << blockType.getElementType() << "), but got " << resultType; + } + + return success(); +} + +// -- ExperimentalDesciptorLoadOp -- +static LogicalResult verifyDesciptorLoadStoreType(Operation *op, + TensorDescType desc, + RankedTensorType tensor) { + RankedTensorType block = desc.getBlockType(); + ArrayRef blockShape = block.getShape(); + ArrayRef tensorShape = tensor.getShape(); + if (blockShape.size() > tensorShape.size()) { + // Allow ranked reduced load if the leading dimensions are all 1s. + for (int i = 0; i < blockShape.size() - tensorShape.size(); ++i) { + if (blockShape[i] != 1) + return op->emitOpError( + "ranked reduce load only allowed for unit dimension leading dim."); + } + blockShape = blockShape.take_back(tensorShape.size()); + } + + if (blockShape == tensorShape && + block.getElementType() == tensor.getElementType()) + return success(); + return op->emitOpError("tensor desciptor block and tensor types must match"); +} + +LogicalResult ExperimentalDescriptorLoadOp::verify() { + return verifyDesciptorLoadStoreType(*this, getDesc().getType(), getType()); +} + +// -- ExperimentalDesciptorStoreOp -- +LogicalResult ExperimentalDescriptorStoreOp::verify() { + return verifyDesciptorLoadStoreType(*this, getDesc().getType(), + getSrc().getType()); +} + +// -- ExperimentalTensormapCreateOp -- +LogicalResult ExperimentalTensormapCreateOp::verify() { + auto rank = getBoxDim().size(); + if (getGlobalDim().size() != rank) { + return emitError("Rank mismatch for global dim. Got ") + << getGlobalDim().size() << " but expected " << rank; + } + if (getGlobalStride().size() + 1 != rank) { + return emitError("Rank mismatch for global stride. Got ") + << getGlobalStride().size() << " but expected " << rank - 1; + } + if (getElementStride().size() != rank) { + return emitError("Rank mismatch for element stride. Got ") + << getElementStride().size() << " but expected " << rank; + } + return success(); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Traits.cpp b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Traits.cpp new file mode 100644 index 000000000..6c45e5a8d --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Traits.cpp @@ -0,0 +1,217 @@ +#include "triton/Dialect/Triton/IR/Traits.h" + +#include + +#include "mlir/IR/TypeUtilities.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; + +LogicalResult OpTrait::impl::verifyEquivalentType(Type typeA, Type typeB) { + auto tensorTypeA = dyn_cast(typeA); + auto tensorTypeB = dyn_cast(typeB); + if (!(bool(tensorTypeA) && bool(tensorTypeB))) + return typeA == typeB ? success() : failure(); + auto encodingA = tensorTypeA.getEncoding(); + auto encodingB = tensorTypeB.getEncoding(); + auto shapeA = tensorTypeA.getShape(); + auto shapeB = tensorTypeB.getShape(); + if (shapeA != shapeB) + return failure(); + if (tensorTypeA.getElementType() != tensorTypeB.getElementType()) + return failure(); + // If there's no encoding or the encodings are the same + if (encodingA == encodingB) + return success(); + + return cast(&encodingA.getDialect()) + ->verifyLayoutsAreEqual(shapeA, encodingA, encodingB, {}); +} + +static LogicalResult verifySameEncoding(Type typeA, Type typeB, + bool allowTensorPointerType) { + // TODO(Keren): the allowTensorPointerType argument is a hack to allow. + // The type checking code is kind of a mess with the current design. + auto getEncoding = [=](Type type) -> Attribute { + Attribute ret; + if (auto tensorType = dyn_cast(type)) { + ret = tensorType.getEncoding(); + } + if (!allowTensorPointerType) { + assert(!triton::isTensorPointerType(type)); + } + return ret; + }; + auto encodingA = getEncoding(typeA); + auto encodingB = getEncoding(typeB); + if (!encodingA || !encodingB) + return success(); + return encodingA == encodingB ? success() : failure(); +} + +LogicalResult +OpTrait::impl::verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifySameEncoding(opType, type, allowTensorPointerType))) + return op->emitOpError() << "requires the same encoding for all operands"; + + return success(); +} + +LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding( + Operation *op, bool allowTensorPointerType) { + if (op->getNumOperands() == 0) + return success(); + + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto resultType : op->getResultTypes()) + if (failed(verifySameEncoding(resultType, type, allowTensorPointerType))) + return op->emitOpError() + << "requires the same encoding for all operands and results"; + + return verifySameOperandsEncoding(op, allowTensorPointerType); +} + +LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { + for (auto opType : op->getOperandTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; + } + } + for (auto opType : op->getResultTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; + } + } + return success(); +} + +// Check that the Triton layouts on op's operands and return types are valid. +// For example, we check that the number of warps per block in a Triton GPU +// blocked layout matches that of its module. +// +// It's a little weird to check these properties of a layout only when the +// layout is used in an op, since most of the properties don't actually depend +// on the op. They do depend on the *module*, though, and a layout is attached +// to a module only by virtue of being used in one of the module's ops. +LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { + auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult { + // Only ranked tensors can have layouts. + auto rankedTy = dyn_cast(val.getType()); + if (!rankedTy) + return success(); + + mlir::Attribute layout = rankedTy.getEncoding(); + if (!layout) + return success(); + + Dialect &dialect = layout.getDialect(); + auto verifyLayoutInterface = + dyn_cast(&dialect); + if (verifyLayoutInterface) { + return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, op, + makeErr); + } + + return success(); + }; + + for (size_t i = 0; i < op->getNumOperands(); i++) { + auto operand = op->getOperand(i); + auto err = checkLayout(operand, [&]() { + // Stringify the operand using `printAsOperand`. This prints e.g. "%42" + // rather than the full definition. + std::string operandStr; + llvm::raw_string_ostream os(operandStr); + // If we don't assume verified, dump() will recursively call this + // function! + operand.printAsOperand(os, OpPrintingFlags().assumeVerified()); + + return op->emitError("Operand ") + << i << " (" << operand << ") has an invalid layout: "; + }); + if (!err.succeeded()) + return err; + } + + for (size_t i = 0; i < op->getNumResults(); i++) { + auto result = op->getResult(i); + auto err = checkLayout(result, [&]() { + if (op->getNumResults() == 1) { + return op->emitError("Result has an invalid layout: "); + } else { + return op->emitError("Result ") << i << " has an invalid layout: "; + } + }); + if (!err.succeeded()) + return err; + } + + return success(); +} + +static ArrayRef getTypeShape(Type type) { + auto rankedType = dyn_cast(type); + if (auto ptrType = dyn_cast(type)) + rankedType = dyn_cast(ptrType.getPointeeType()); + return rankedType ? rankedType.getShape() : ArrayRef(); +} + +LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() << "requires the same shape for all operands"; + + return success(); +} + +LogicalResult +OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : op->getResultTypes()) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() + << "requires the same shape for all operands and results"; + + return verifySameLoadStoreOperandsShape(op); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Types.cpp b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Types.cpp new file mode 100644 index 000000000..de8925cbf --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/IR/Types.cpp @@ -0,0 +1,142 @@ +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void TritonDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + >(); +} + +Type PointerType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + Type pointeeType; + if (parser.parseType(pointeeType)) + return Type(); + + int addressSpace = 1; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseInteger(addressSpace)) + return Type(); + } + + if (parser.parseGreater()) + return Type(); + + return PointerType::get(pointeeType, addressSpace); +} + +void PointerType::print(AsmPrinter &printer) const { + if (getAddressSpace() == 1) { + printer << "<" << getPointeeType() << ">"; + } else { + printer << "<" << getPointeeType() << ", " << getAddressSpace() << ">"; + } +} + +namespace mlir { + +namespace triton { + +unsigned getPointeeBitWidth(Type type) { + auto pointeeType = getPointeeType(type); + if (auto tensorTy = dyn_cast(pointeeType)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + return pointeeType.getIntOrFloatBitWidth(); +} + +Type getI1SameShape(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i1Type, + tensorTy.getEncoding()); + return i1Type; +} + +Type getPointeeType(Type type) { + if (auto tensorTy = dyn_cast(type)) { + // Tensor of pointers + auto shape = tensorTy.getShape(); + auto ptrType = dyn_cast(tensorTy.getElementType()); + Type pointeeType = ptrType.getPointeeType(); + return RankedTensorType::get(shape, pointeeType, tensorTy.getEncoding()); + } else if (auto ptrType = dyn_cast(type)) { + // scalar pointer + Type pointeeType = ptrType.getPointeeType(); + return pointeeType; + } + return type; +} + +Type getI32SameShape(Type type) { + auto i32Type = IntegerType::get(type.getContext(), 32); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i32Type, + tensorTy.getEncoding()); + return i32Type; +} + +Type getPointerTypeSameShape(Type type) { + if (auto tensorTy = dyn_cast(type)) { + Type elementType = tensorTy.getElementType(); + auto shape = tensorTy.getShape(); + PointerType ptrType = PointerType::get(elementType, 1); + return RankedTensorType::get(shape, ptrType, tensorTy.getEncoding()); + } else { + return PointerType::get(type, 1); + } +} + +Type getPointerTypeToElement(Type type) { + Type elementType = getElementTypeOrSelf(type); + PointerType ptrType = PointerType::get(elementType, 1); + return ptrType; +} + +// upstream Triton only uses address space 1 for Pointer Type +Type getPointerType(Type type, int addressSpace) { + return PointerType::get(type, addressSpace); +} + +int getAddressSpace(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getAddressSpace(); + return 1; +} + +bool isTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + return isa(ptrType.getPointeeType()); + return false; +} + +bool isTensorOrTensorPointerType(Type type) { + return isa(type) || isTensorPointerType(type); +} + +Type getElementTypeOfTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + if (auto tensorTy = dyn_cast(ptrType.getPointeeType())) + return tensorTy.getElementType(); + return {}; +} + +} // namespace triton + +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 000000000..cda076d4e --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,19 @@ +set(LLVM_TARGET_DEFINITIONS Combine.td) +mlir_tablegen(TritonCombine.inc -gen-rewriters) +add_public_tablegen_target(TritonCombineIncGen) + +add_triton_library(TritonTransforms + Combine.cpp + LoopUnroll.cpp + ReorderBroadcast.cpp + RewriteTensorPointer.cpp + + DEPENDS + TritonTransformsIncGen + TritonCombineIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTransformUtils + TritonIR +) diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/Combine.cpp b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/Combine.cpp new file mode 100644 index 000000000..7382875f6 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/Combine.cpp @@ -0,0 +1,274 @@ +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace mlir::triton { +namespace { + +bool isZero(Value val) { + return (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat())); +} + +bool isAddPtrOffsetCombinable(Value first, Value second) { + auto GetConstantIntValue = [](Value val) -> std::optional { + DenseElementsAttr constAttr; + auto defOp = val.getDefiningOp(); + if (defOp) { + if (auto splatOp = llvm::dyn_cast(defOp)) + val = splatOp.getSrc(); + else if (matchPattern(defOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto attr = constAttr.getSplatValue(); + // Check IntegerAttr + if (auto intAttr = dyn_cast_or_null(attr)) + return intAttr.getValue(); + } + } + + // Check constant value. + llvm::APInt intVal; + if (matchPattern(val, m_ConstantInt(&intVal))) + return intVal; + + return std::nullopt; + }; + + if (first.getType() == second.getType()) { + // Whether bitwidth of element type is equal to pointer + if (getElementTypeOrSelf(first.getType()).getIntOrFloatBitWidth() == 64) + return true; + + // first + second does not overflow + auto firstVal = GetConstantIntValue(first); + auto secondVal = GetConstantIntValue(second); + if (firstVal && secondVal) { + bool overflow = false; + auto resVal = firstVal->sadd_ov(*secondVal, overflow); + return !overflow; + } + } + return false; +} + +// TODO(csigg): remove after next LLVM integrate. +using FastMathFlags = arith::FastMathFlags; + +#include "TritonCombine.inc" + +// select(cond, load(ptrs, splat(cond), ???), other) +// => load(ptrs, splat(cond), other) +class CombineSelectMaskedLoadPattern : public RewritePattern { +public: + CombineSelectMaskedLoadPattern(MLIRContext *context) + : RewritePattern(arith::SelectOp::getOperationName(), 3, context, + {LoadOp::getOperationName()}) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto selectOp = llvm::dyn_cast(op); + if (!selectOp) + return failure(); + + Value trueValue = selectOp.getTrueValue(); + Value falseValue = selectOp.getFalseValue(); + Value condSelect = selectOp.getCondition(); + + auto loadOp = trueValue.getDefiningOp(); + if (!loadOp) + return failure(); + + Value mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto splatOp = mask.getDefiningOp(); + if (!splatOp) + return failure(); + + auto splatCond = splatOp.getSrc(); + if (splatCond != condSelect) + return failure(); + + rewriter.replaceOpWithNewOp( + op, loadOp.getPtr(), loadOp.getMask(), /*other=*/falseValue, + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + return success(); + } +}; + +// sum(x[:, :, None] * y[None, :, :], 1) +// -> dot(x, y) +class CombineBroadcastMulReducePattern : public RewritePattern { +private: + static bool isAddF32(const Operation *op) { + if (auto addf = dyn_cast_or_null(op)) + return addf.getType().getIntOrFloatBitWidth() <= 32; + return false; + } + + static SmallVector getEqualIndices(ArrayRef x, + ArrayRef y) { + SmallVector res; + for (int i = 0; i < x.size(); ++i) + if (x[i] == y[i]) + res.push_back(i); + return res; + } + +public: + CombineBroadcastMulReducePattern(MLIRContext *context) + : RewritePattern(ReduceOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + auto reduceOp = llvm::dyn_cast(op); + if (!reduceOp) + return failure(); + // only support reduce with simple addition + Region &combineOp = reduceOp.getCombineOp(); + bool isReduceAdd = combineOp.hasOneBlock() && + combineOp.front().getOperations().size() == 2 && + isAddF32(&*combineOp.front().getOperations().begin()); + if (!isReduceAdd) + return failure(); + // operand of reduce has to be mul + auto mulOp = reduceOp.getOperand(0).getDefiningOp(); + if (!mulOp) + return failure(); + // mul operand has to be broadcast + auto broadcastLhsOp = mulOp.getOperand(0).getDefiningOp(); + if (!broadcastLhsOp) + return failure(); + auto broadcastRhsOp = mulOp.getOperand(1).getDefiningOp(); + if (!broadcastRhsOp) + return failure(); + // broadcast operand is expand dims + auto expandLhsOp = broadcastLhsOp.getSrc().getDefiningOp(); + if (!expandLhsOp) + return failure(); + auto expandRhsOp = broadcastRhsOp.getSrc().getDefiningOp(); + if (!expandRhsOp) + return failure(); + // get not-broadcast dimensions + int expandLhsAxis = expandLhsOp.getAxis(); + int expandRhsAxis = expandRhsOp.getAxis(); + if (expandLhsAxis != 2 || expandRhsAxis != 0) + return failure(); + auto broadcastLhsShape = + cast(broadcastLhsOp.getType()).getShape(); + auto broadcastRhsShape = + cast(broadcastLhsOp.getType()).getShape(); + if (broadcastLhsShape[2] < 16 || broadcastRhsShape[0] < 16) + return failure(); + Type newAccType = RankedTensorType::get( + {broadcastLhsShape[0], broadcastRhsShape[2]}, + cast(broadcastLhsOp.getSrc().getType()).getElementType()); + rewriter.setInsertionPoint(op); + auto newAcc = rewriter.create( + op->getLoc(), newAccType, + rewriter.create(op->getLoc(), + rewriter.getF32FloatAttr(0))); + rewriter.replaceOpWithNewOp(op, expandLhsOp.getSrc(), + expandRhsOp.getSrc(), newAcc, + InputPrecision::TF32, 0); + return success(); + } +}; + +// When reducing a 1D tensor the order of elements of the tensor doesn't matter. +// Therefore we can relax the reshape to allow it to re-order elements. +class CombineReshapeReducePatterns : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp reshapeOp, + mlir::PatternRewriter &rewriter) const override { + if (reshapeOp.getAllowReorder()) + return failure(); + if (reshapeOp.getType().getRank() != 1) + return failure(); + for (Operation *user : reshapeOp->getUsers()) { + if (!isa(user)) + return failure(); + } + rewriter.modifyOpInPlace(reshapeOp, + [&]() { reshapeOp.setAllowReorder(true); }); + return success(); + } +}; + +class RankedReduceDescriptorLoads : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp reshapeOp, + mlir::PatternRewriter &rewriter) const override { + auto loadDef = reshapeOp.getSrc() + .getDefiningOp(); + if (!loadDef || !loadDef->hasOneUse()) + return failure(); + int loadRank = loadDef.getType().getRank(); + int reshapeRank = reshapeOp.getType().getRank(); + if (!(reshapeRank < loadRank)) + return failure(); + ArrayRef loadShape = loadDef.getType().getShape(); + ArrayRef reshapeShape = reshapeOp.getType().getShape(); + for (int i = 0; i < loadRank - reshapeRank; ++i) { + // Only rank reduce unit dims. + if (loadShape[i] != 1) + return failure(); + } + if (loadShape.take_back(reshapeRank) != reshapeShape) + return failure(); + rewriter.modifyOpInPlace( + loadDef, [&]() { loadDef.getResult().setType(reshapeOp.getType()); }); + rewriter.replaceOp(reshapeOp, loadDef.getResult()); + return success(); + } +}; + +class CombineOpsPass : public TritonCombineOpsBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + // Dot Add %{ + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + // %} + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // anonymous namespace + +std::unique_ptr createCombineOpsPass() { + return std::make_unique(); +} + +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/Combine.td b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/Combine.td new file mode 100644 index 000000000..e3588f587 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/Combine.td @@ -0,0 +1,47 @@ +#ifndef TRITON_PATTERNS +#define TRITON_PATTERNS + +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "triton/Dialect/Triton/IR/TritonOps.td" +include "mlir/IR/PatternBase.td" + + +// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) + +// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +def CombineDotAddIPattern : Pat< + (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $overflow), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; +def CombineDotAddFPattern : Pat< + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + +def CombineDotAddIRevPattern : Pat< + (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $overflow), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; +def CombineDotAddFRevPattern : Pat< + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + +// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1)) +// Note: leave (sub %c0, %c0) canceling to ArithDialect +// (ref: ArithCanonicalization.td) +defvar DefOverflow = ConstantEnumCase; +def CombineAddPtrPattern : Pat< + (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1), + (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)), + [(Constraint> $idx0, $idx1)]>; + +#endif diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/LoopUnroll.cpp b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/LoopUnroll.cpp new file mode 100644 index 000000000..cb25d41a2 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/LoopUnroll.cpp @@ -0,0 +1,75 @@ +#include + +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "triton-loop-unroll" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir::triton { + +namespace { + +class LoopUnrollPass : public TritonLoopUnrollBase { + + int getUnrollFactorOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists otherwise set the + // factor to 1 to suppress the unrolling. + if (auto factor = + forOp->getAttrOfType(loopUnrollFactorAttrName)) + return factor.getInt(); + return 1; + } + + const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor"; + const char *pipelineStagesAttrName = "tt.num_stages"; + +public: + LoopUnrollPass() = default; + LoopUnrollPass(const LoopUnrollPass &) {} + void runOnOperation() override { + LDBG("Loop unroll pass"); + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with unroll factor <= 1. + if (getUnrollFactorOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + auto ctx = getOperation()->getContext(); + for (auto loop : loops) { + auto unrollFactor = getUnrollFactorOrDefault(loop); + loop->removeAttr(loopUnrollFactorAttrName); + LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop); + auto resultLoops = loopUnrollByFactor(loop, unrollFactor); + // Do not pipeline the epilog loop. + if (succeeded(resultLoops) && resultLoops->epilogueLoopOp) { + (*resultLoops->epilogueLoopOp) + ->setAttr(pipelineStagesAttrName, + mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1)); + } + } + } +}; + +} // anonymous namespace + +std::unique_ptr createLoopUnrollPass() { + return std::make_unique(); +} + +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp new file mode 100644 index 000000000..db8166085 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -0,0 +1,235 @@ +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +// TODO(jlebar): Move this and all other generatede code into namespace +// mlir::triton. +#define GEN_PASS_DEF_TRITONREORDERBROADCAST +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace mlir::triton { +namespace { + +Operation *cloneWithNewArgsAndResultTypes(PatternRewriter &rewriter, + Operation *op, ValueRange newOperands, + TypeRange newTypes) { + OperationState newElementwiseState(op->getLoc(), op->getName()); + newElementwiseState.addOperands(newOperands); + newElementwiseState.addTypes(newTypes); + newElementwiseState.addAttributes(op->getAttrs()); + return rewriter.create(newElementwiseState); +} + +bool isSplat(Operation *op) { + if (auto splatOp = llvm::dyn_cast(op)) { + return true; + } + DenseElementsAttr constAttr; + return (matchPattern(op, m_Constant(&constAttr)) && constAttr.isSplat()); +} + +// elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) +struct MoveSplatAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveSplatAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult match(Operation *op) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + for (auto operand : op->getOperands()) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + + if (!isSplat(definingOp)) { + return failure(); + } + } + return success(op->getNumOperands() > 0); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto operands = op->getOperands(); + + llvm::SmallVector scalarOperands(operands.size()); + for (unsigned iOp = 0; iOp < operands.size(); ++iOp) { + auto definingOp = operands[iOp].getDefiningOp(); + + DenseElementsAttr constAttr; + if (auto splatOp = llvm::dyn_cast(definingOp)) { + scalarOperands[iOp] = splatOp.getSrc(); + } else if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto value = constAttr.getSplatValue(); + scalarOperands[iOp] = arith::ConstantOp::materialize( + rewriter, value, constAttr.getElementType(), loc); + } else { + llvm_unreachable("Expected a splat"); + } + } + + auto resultTypes = op->getResultTypes(); + llvm::SmallVector scalarResultTys; + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + scalarResultTys.push_back(elemTy); + } + + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, scalarOperands, + scalarResultTys); + + for (unsigned iRes = 0; iRes < resultTypes.size(); ++iRes) { + auto newResult = rewriter.create(loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + } +}; + +// elementwise(broadcast(a)) => broadcast(elementwise(a)) +// This also generalizes to multiple arguments when the rest are splat-like +// Not handled: multiple broadcasted arguments +struct MoveBroadcastAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveBroadcastAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult match(Operation *op) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + auto operands = op->getOperands(); + bool seenBroadcast = false; + ArrayRef srcShape; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) { + return failure(); + } + auto getSrcShape = [](BroadcastOp b) { + return b.getSrc().getType().getShape(); + }; + if (auto broadcastOp = llvm::dyn_cast(definingOp)) { + if (!seenBroadcast) { + seenBroadcast = true; + srcShape = getSrcShape(broadcastOp); + } else if (srcShape != getSrcShape(broadcastOp)) { + // If the broadcast have different types we cannot re-order. + return failure(); + } + } else if (!isSplat(definingOp)) { + // Not splat or broadcast + return failure(); + } + } + return success(seenBroadcast); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // Find broadcast op + auto operands = op->getOperands(); + BroadcastOp broadcastOp; + for (auto operand : operands) { + broadcastOp = operand.getDefiningOp(); + if (broadcastOp) { + break; + } + } + + auto srcTy = broadcastOp.getSrc().getType(); + auto srcShape = srcTy.getShape(); + auto srcEncoding = srcTy.getEncoding(); + + // Reshape operands to match srcShape + llvm::SmallVector newOperands; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (auto broadcastSrcOp = llvm::dyn_cast(definingOp)) { + newOperands.push_back(broadcastSrcOp.getSrc()); + continue; + } + auto elemTy = + dyn_cast(operand.getType()).getElementType(); + auto newTy = RankedTensorType::get(srcShape, elemTy, srcEncoding); + if (auto splatOp = llvm::dyn_cast(definingOp)) { + auto newSplat = rewriter.create(loc, newTy, splatOp.getSrc()); + newOperands.push_back(newSplat); + continue; + } + DenseElementsAttr constAttr; + if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto scalarValue = constAttr.getSplatValue(); + auto splatValue = SplatElementsAttr::get(newTy, scalarValue); + auto newConstant = + rewriter.create(loc, newTy, splatValue); + newOperands.push_back(newConstant); + continue; + } + llvm_unreachable("Expected broadcast or splat"); + } + + // Reshape results to match srcShape + llvm::SmallVector newResultTypes; + auto resultTypes = op->getResultTypes(); + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + newResultTypes.push_back( + RankedTensorType::get(srcShape, elemTy, srcEncoding)); + } + + // Create new op and broadcast results + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, newOperands, + newResultTypes); + for (unsigned iRes = 0; iRes < newResultTypes.size(); ++iRes) { + auto newResult = rewriter.create(loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + } +}; + +class ReorderBroadcastPass + : public ::impl::TritonReorderBroadcastBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + BroadcastOp::getCanonicalizationPatterns(patterns, context); + ExpandDimsOp::getCanonicalizationPatterns(patterns, context); + // elementwise(broadcast(a)) => broadcast(elementwise(a)) + patterns.add(context); + // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) + patterns.add(context); + + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr createReorderBroadcastPass() { + return std::make_unique(); +} + +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp new file mode 100644 index 000000000..b2e58cf24 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -0,0 +1,569 @@ +#include +#include + +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +/// An additional struct to record the meta information of operations +/// with tensor pointers +struct RewritedInfo { +private: + Value base; + SmallVector shape; + SmallVector strides; + SmallVector offsets; + ArrayRef tensorShape; + + // A cache to avoid generating the same offset with range + DenseMap cachedOffsetWithRange; + +public: + RewritedInfo() = default; + + RewritedInfo(const RewritedInfo &other) = default; + + RewritedInfo(Value base, const SmallVector &shape, + const SmallVector &strides, + const SmallVector &offsets, + const ArrayRef &tensorShape) + : base(base), shape(shape), strides(strides), offsets(offsets), + tensorShape(tensorShape) { + assert(shape.size() == strides.size() && shape.size() == offsets.size() && + shape.size() == tensorShape.size()); + } + + unsigned int length() const { return shape.size(); } + + Value getOffset(unsigned i) { return offsets[i]; } + + SmallVector getOffsets() { return offsets; } + + void setOffset(unsigned i, Value newOffset) { + offsets[i] = newOffset; + cachedOffsetWithRange.clear(); + } + + void setOffsets(const SmallVector &newOffsets) { + offsets = newOffsets; + cachedOffsetWithRange.clear(); + } + + Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc, + unsigned i) { + if (cachedOffsetWithRange.count(i)) + return cachedOffsetWithRange[i]; + + // Add range + auto indexI32RowType = + RankedTensorType::get({tensorShape[i]}, builder.getI32Type()); + auto indexRowType = + RankedTensorType::get({tensorShape[i]}, builder.getI64Type()); + Value splatOffset = + builder.create(loc, indexRowType, offsets[i]); + Value range = builder.create(loc, indexI32RowType, 0, + tensorShape[i]); + Value i64Range = builder.create(loc, indexRowType, range); + + // Expand dimensions + Value expandedResult = + builder.create(loc, splatOffset, i64Range); + for (int j = 0; j < tensorShape.size(); ++j) { + if (j == i) + continue; + expandedResult = + builder.create(loc, expandedResult, j); + } + + return cachedOffsetWithRange[i] = expandedResult; + } + + Value generatePtr(OpBuilder &builder, const Location &loc) { + assert(tensorShape.size() == offsets.size() && + tensorShape.size() == strides.size()); + auto indexTensorType = + RankedTensorType::get(tensorShape, builder.getI64Type()); + auto ptrType = cast(base.getType()); + auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType); + + // Generate offsets per dimension + Value ptr = builder.create(loc, ptrTensorType, base); + for (unsigned i = 0; i < tensorShape.size(); ++i) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // We must splat strides into the expanded shape not a row for retaining + // the divisibility information given by strides + Value splatStride = builder.create( + loc, offsetWithRange.getType(), strides[i]); + Value offsetWithStride = + builder.create(loc, offsetWithRange, splatStride); + Value broadcasted = builder.create( + loc, indexTensorType, offsetWithStride); + + // Add to the pointer + ptr = builder.create(loc, ptrTensorType, ptr, + broadcasted); + } + + return ptr; + } + + Value generateMask(OpBuilder &builder, const Location &loc, + const std::optional> &boundaryCheck) { + if (!boundaryCheck.has_value()) + return {}; + + // Generate mask per dimension + auto maskTensorType = + RankedTensorType::get(tensorShape, builder.getI1Type()); + Value mask; + for (auto i : boundaryCheck.value()) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // Compare with lower bound + Value lowerBound = builder.create( + loc, 0, builder.getI64Type()); + Value splatLowerBound = builder.create( + loc, offsetWithRange.getType(), lowerBound); + Value cmpLower = builder.create( + loc, arith::CmpIPredicate::sge, offsetWithRange, splatLowerBound); + + // Compare with upper bound + Value splatUpperBound = builder.create( + loc, offsetWithRange.getType(), shape[i]); + Value cmpUpper = builder.create( + loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound); + + // And and broadcast + Value andResult = builder.create(loc, cmpLower, cmpUpper); + Value broadcasted = + builder.create(loc, maskTensorType, andResult); + + // And up all results + if (!mask) { + mask = broadcasted; + } else { + mask = builder.create(loc, mask, broadcasted); + } + } + + return mask; + } + + Value generateOther(OpBuilder &builder, const Location &loc, + const std::optional &padding) { + if (!padding.has_value()) + return Value(); + + // Create element attribute + auto elementType = + cast(base.getType()).getPointeeType(); + auto otherTensorType = RankedTensorType::get(tensorShape, elementType); + + // Set zero padding value + TypedAttr attr = builder.getZeroAttr(elementType); + + // Float NaN padding case + if (padding.value() == triton::PaddingOption::PAD_NAN) { + assert(!elementType.isIntOrIndex()); + auto apNaN = llvm::APFloat::getNaN( + cast(attr).getValue().getSemantics()); + attr = builder.getFloatAttr(elementType, apNaN); + } + + // Create tensor + Value constant = builder.create(loc, attr); + return builder.create(loc, otherTensorType, constant); + } +}; + +} // namespace + +// TODO: this pass relies on assumptions of how block pointers are created and +// on pattern matches that walks the SSA links to find the base/strides. This is +// very fragile and to solve we should expose convert Ptr of tensor to a +// structure containins all values and not only offsets. +class RewriteTensorPointerPass + : public TritonRewriteTensorPointerBase { +private: + DenseMap rewritedInfo; + +public: + static bool needRewrite(Operation *op) { + return std::any_of(op->getOperands().begin(), op->getOperands().end(), + [](Value operand) { + return triton::isTensorPointerType(operand.getType()); + }); + } + + static void generateNewOperands(SmallVector &oldOperands, + unsigned index, ArrayRef newValues) { + size_t size = oldOperands.size(); + assert(index < size); + SmallVector operands = oldOperands; + oldOperands.reserve(size - 1 + newValues.size()); + oldOperands.clear(); + if (index != 0) { + oldOperands.append(operands.begin(), operands.begin() + index); + } + oldOperands.append(newValues.begin(), newValues.end()); + if (index != size - 1) { + oldOperands.append(operands.begin() + index + 1, operands.end()); + } + } + + Operation *rewriteMakeTensorPtrOp(OpBuilder &builder, + triton::MakeTensorPtrOp op, + std::stack &eraser) { + // Save info for later use + auto ptrType = cast(op.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + + // Cast I32 offsets into I64 + SmallVector i64Offsets; + for (auto offset : op.getOffsets()) { + auto i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), offset); + i64Offsets.push_back(i64Offset); + } + + // Save information + rewritedInfo[op.getResult()] = + RewritedInfo(op.getBase(), op.getShape(), op.getStrides(), i64Offsets, + tensorType.getShape()); + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteAdvanceOp(OpBuilder &builder, triton::AdvanceOp op, + std::stack &eraser) { + // Get info from previous results + assert(rewritedInfo.count(op.getPtr())); + auto info = rewritedInfo[op.getPtr()]; + + // Calculate new offsets + assert(info.length() == op.getOffsets().size()); + SmallVector newOffsets; + for (int i = 0; i < info.length(); ++i) { + Value i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), op.getOffsets()[i]); + Value newOffset = builder.create( + op.getLoc(), info.getOffset(i), i64Offset); + newOffsets.push_back(newOffset); + } + + // Save info for later use + info.setOffsets(newOffsets); + rewritedInfo[op.getResult()] = info; + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op, + std::stack &eraser) { + assert(isa(op) || isa(op)); + + // We only have to rewrite load/stores with tensor pointers + auto ptr = op->getOperand(0); + if (!triton::isTensorPointerType(ptr.getType())) + return nullptr; + + // Get info from previous results + assert(rewritedInfo.count(ptr)); + auto info = rewritedInfo[ptr]; + + // Load/store with tensor pointers implicitly will check the bound while + // accessing memory, so we should set `mask` and `other` (according to the + // padding). Also note that load with tensor pointers do not have `mask` and + // `other` while building IR from Python AST + std::optional> boundaryCheck; + if (auto loadOp = dyn_cast(op)) { + assert(!loadOp.getMask() && !loadOp.getOther()); + boundaryCheck = loadOp.getBoundaryCheck(); + } else if (auto storeOp = dyn_cast(op)) { + assert(!storeOp.getMask()); + boundaryCheck = storeOp.getBoundaryCheck(); + } + + // Generate new `ptr`, `mask` and `other` + auto newPtr = info.generatePtr(builder, op->getLoc()); + auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck); + Value newOther; + if (auto loadOp = dyn_cast(op)) + newOther = info.generateOther(builder, op->getLoc(), loadOp.getPadding()); + + // Create a new operation + if (auto loadOp = dyn_cast(op)) { + auto newResult = builder.create( + loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(), + loadOp.getEvict(), loadOp.getIsVolatile()); + op->getResult(0).replaceAllUsesWith(newResult); + if (op->getAttr("async_task_id")) + newResult->setAttr("async_task_id", op->getAttr("async_task_id")); + } else if (auto storeOp = dyn_cast(op)) { + auto newOp = builder.create( + storeOp.getLoc(), newPtr, storeOp.getValue(), newMask, + storeOp.getCache(), storeOp.getEvict()); + if (op->getAttr("async_task_id")) + newOp->setAttr("async_task_id", op->getAttr("async_task_id")); + } + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteIfOp(OpBuilder &builder, scf::IfOp op, + std::stack &eraser) { + auto thenYieldOp = op.thenYield(); + assert(op.getNumResults() == thenYieldOp.getNumOperands()); + SmallVector results = thenYieldOp.getOperands(); + + // get new result types + SmallVector newRetTypes; + bool needRewrite = false; + for (unsigned i = 0; i < results.size(); ++i) { + if (!triton::isTensorPointerType(results[i].getType())) { + newRetTypes.push_back(results[i].getType()); + continue; + } + needRewrite = true; + auto makeTensorPtrOp = getMakeTensorPtrOp(results[i]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + const auto &info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + newRetTypes.push_back(builder.getI64Type()); + } + } + if (!needRewrite) + return op; + // create and clone new IfOp + bool hasElse = !op.getElseRegion().empty(); + scf::IfOp newOp = builder.create(op.getLoc(), newRetTypes, + op.getCondition(), hasElse); + IRMapping mapping; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + mapping.map(op->getOperand(i), newOp->getOperand(i)); + } + auto rematerialize = [&](Block *block) { + for (Operation &opInIf : block->getOperations()) { + builder.clone(opInIf, mapping); + } + }; + builder.setInsertionPointToStart(newOp.thenBlock()); + rematerialize(op.thenBlock()); + if (hasElse) { + builder.setInsertionPointToStart(newOp.elseBlock()); + rematerialize(op.elseBlock()); + } + + // update rewritedInfo + auto opResults = op.getResults(); + unsigned oldResIdx = 0, newResIdx = 0; + while (oldResIdx < results.size()) { + if (!triton::isTensorPointerType(results[oldResIdx].getType())) { + opResults[oldResIdx].replaceAllUsesWith(newOp.getResult(newResIdx)); + oldResIdx++; + newResIdx++; + } else { + auto makeTensorPtrOp = getMakeTensorPtrOp(results[oldResIdx]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + info.setOffset(j, newOp->getResult(newResIdx++)); + } + rewritedInfo[op.getResult(oldResIdx)] = info; + oldResIdx++; + } + } + + eraser.push(op); + return newOp; + } + + Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op, + std::stack &eraser) { + // Generate new iteration operands and set rewritten information + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; + ++i, ++oldI) { + if (!triton::isTensorPointerType(newIterOperands[i].getType())) + continue; + + // Expand the tensor pointer into offsets + assert(rewritedInfo.count(newIterOperands[i])); + auto info = rewritedInfo[newIterOperands[i]]; + generateNewOperands(newIterOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + + // Rebuild the loop type + auto newForOp = builder.create(op.getLoc(), op.getLowerBound(), + op.getUpperBound(), op.getStep(), + newIterOperands); + newForOp->setAttrs(op->getAttrs()); + + // Create value mapping. Note that for tensor pointers, we use identity + // mapping. It may refer to a value in the old loop, but we will rewrite it + // later + IRMapping mapping; + for (unsigned i = 0, oldI = 0, sz = op.getInitArgs().size(); oldI < sz; + ++i, ++oldI) { + auto oldRegionIterArg = op.getRegionIterArg(oldI); + if (triton::isTensorPointerType(oldRegionIterArg.getType())) { + // Pass rewritten info inside + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + mapping.map(oldRegionIterArg, oldRegionIterArg); + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getRegionIterArg(i + j)); + rewritedInfo[oldRegionIterArg] = info; + i += info.length() - 1; + } else { + mapping.map(oldRegionIterArg, newForOp.getRegionIterArg(i)); + } + } + mapping.map(op.getInductionVar(), newForOp.getInductionVar()); + + // Clone body + builder.setInsertionPointToStart(newForOp.getBody()); + for (auto &opInFor : *op.getBody()) { + builder.clone(opInFor, mapping); + } + + // Replace later usages + assert(op.getNumResults() == op.getInitArgs().size()); + for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { + auto oldResult = op.getResult(oldI); + if (triton::isTensorPointerType(oldResult.getType())) { + // Pack new offsets into rewritten info + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getResult(i + j)); + i += info.length() - 1; + rewritedInfo[oldResult] = info; + } else { + oldResult.replaceAllUsesWith(newForOp.getResult(i)); + } + } + + // Erase later + eraser.push(op); + return newForOp; + } + + Operation *rewriteYieldOp(OpBuilder &builder, scf::YieldOp op, + std::stack &eraser) { + // Replace tensor pointers with offsets + SmallVector newOperands = op->getOperands(); + for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) { + if (!triton::isTensorPointerType(newOperands[i].getType())) + continue; + + assert(rewritedInfo.count(newOperands[i])); + auto info = rewritedInfo[newOperands[i]]; + generateNewOperands(newOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + op->setOperands(newOperands); + + // No need to erase + return nullptr; + } + + Operation *rewriteOp(Operation *op, std::stack &eraser) { + OpBuilder builder(op); + + // Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers + // Rewriting functions return the next operation to visit, if there is no + // next one, simply return `nullptr` + if (auto makeTensorPtrOp = dyn_cast(op)) { + return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser); + } else if (auto advanceOp = dyn_cast(op)) { + return rewriteAdvanceOp(builder, advanceOp, eraser); + } else if (isa(op) || isa(op)) { + return rewriteLoadStoreOp(builder, op, eraser); + } else if (isa(op->getDialect())) { + if (auto ifOp = dyn_cast(op)) { + return rewriteIfOp(builder, ifOp, eraser); + } + if (!needRewrite(op)) + return op; + + if (auto forOp = dyn_cast(op)) { + return rewriteForOp(builder, forOp, eraser); + } else if (auto yieldOp = dyn_cast(op)) { + return rewriteYieldOp(builder, yieldOp, eraser); + } else { + llvm_unreachable("Currently we only support tensor pointer usages " + "inside a `scf::ForOp` or `scf::IfOp`, others such as " + "`scf::WhileOp`, `cf::BranchOp` or `cf::CondBranchOp` " + "are not supported yet"); + } + } + + // Otherwise return the original one + return op; + } + + void visitOperation(Operation *op, std::stack &eraser) { + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Operation &nestedOp : llvm::make_early_inc_range(block)) { + if (auto newOp = rewriteOp(&nestedOp, eraser)) { + visitOperation(newOp, eraser); + } + } + } + } + } + + void runOnOperation() override { + // NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because + // MLIR does not support one-multiple value mapping. For example, if we use + // `ConversionPatternRewriter`, we can not make a type converter, which + // converts `ptr` into multiple types `ptr<>, int64, int64, ...` + // (containing the base/offsets/strides...). What we can do is to convert + // `ptr` into a single type `Tuple, int64, int64, ...>`. But + // in this way, we also have to define `PackTuple` and `UnpackTuple` + // operations and make a canonicalization pass to optimize, which is much + // So here we recursively build the IR, to be specific, we have to rewrite + // `tt.make_tensor_ptr`, `tt.advance`, `tt.load`, `tt.store`, + // `scf.for` (tensor pointer usages may be in a loop fashion) + std::stack eraser; + visitOperation(getOperation(), eraser); + + // The operation could not be erased during visit, because they may have + // later usages, so we erase after visit + rewritedInfo.clear(); + while (!eraser.empty()) { + auto op = eraser.top(); + eraser.pop(); + op->erase(); + } + } +}; + +std::unique_ptr triton::createRewriteTensorPointerPass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/CMakeLists.txt b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..7486d72f3 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,16 @@ +add_triton_library(TritonGPUIR + Dialect.cpp + LinearLayoutConversions.cpp + Ops.cpp + Types.cpp + + DEPENDS + TritonGPUTableGen + TritonGPUAttrDefsIncGen + TritonGPUTypeInterfacesIncGen + + LINK_LIBS PUBLIC + MLIRGPUDialect + TritonIR + TritonTools +) diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/Dialect.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/Dialect.cpp new file mode 100644 index 000000000..a16f6bc65 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -0,0 +1,3412 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include +#include + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/TypeSwitch.h" + +// Include TableGen'erated code +#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/TypeInterfaces.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// Utility +namespace mlir { +namespace triton { +namespace gpu { + +LinearEncodingAttr toLinearEncoding(RankedTensorType type) { + return toLinearEncoding(type.getEncoding(), type.getShape()); +} + +LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef shape) { + auto linearLayout = toLinearLayout(shape, layout); + return LinearEncodingAttr::get(layout.getContext(), std::move(linearLayout)); +} + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape) { + return toLinearEncoding(layout, shape).getTotalElemsPerThread(shape); +} + +SmallVector getElemsPerThread(Attribute layout, + ArrayRef shape) { + return toLinearEncoding(layout, shape).getElemsPerThread(shape); +} + +SmallVector getElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return SmallVector(1, 1); + auto tensorType = cast(type); + return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape()); +} + +unsigned getTotalElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return 1; + auto tensorType = cast(type); + return getTotalElemsPerThread(tensorType.getEncoding(), + tensorType.getShape()); +} + +SmallVector getThreadsPerWarp(Attribute layout) { + if (auto distributedLayout = dyn_cast(layout)) { + return distributedLayout.getThreadsPerWarp(); + } else { + llvm::report_fatal_error("getThreadsPerWarp not implemented"); + return SmallVector(); + } +} + +unsigned getWarpSize(Attribute layout) { + unsigned size = 1; + auto threadsPerWarp = getThreadsPerWarp(layout); + for (auto e : threadsPerWarp) { + size *= e; + } + return size; +} + +SmallVector +getThreadsPerWarpWithUniqueData(Attribute layout, + ArrayRef tensorShape) { + return toLinearEncoding(layout, tensorShape).getThreadsPerWarp(); +} + +SmallVector getWarpsPerCTA(Attribute layout) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getWarpsPerCTA(); + } + + llvm::report_fatal_error("getWarpsPerCTA not implemented"); + return SmallVector(); +} + +SmallVector +getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape) { + auto linearLayout = toLinearLayout(tensorShape, layout); + auto llAttr = LinearEncodingAttr::get(layout.getContext(), linearLayout); + return llAttr.getWarpsPerCTA(); +} + +SmallVector getContigPerThread(RankedTensorType tensorType) { + auto layout = tensorType.getEncoding(); + auto shape = tensorType.getShape(); + auto linearLayout = toLinearLayout(shape, layout); + auto llAttr = LinearEncodingAttr::get(tensorType.getContext(), linearLayout); + return llAttr.getContigPerThread(); +} + +SmallVector getShapePerCTATile(RankedTensorType type) { + return toLinearEncoding(type).getShapePerCTATile(); +} + +bool isExpensiveView(Type srcType, Type dstType) { + auto tensorSrcType = cast(srcType); + auto tensorDstType = cast(dstType); + auto llSrc = + toLinearLayout(tensorSrcType.getShape(), tensorSrcType.getEncoding()); + auto llDst = + toLinearLayout(tensorDstType.getShape(), tensorDstType.getEncoding()); + // In case there are replicated value we need to make sure the new and old + // layout have matching masks. + for (auto [srcMask, dstMask] : + llvm::zip(llSrc.getFreeVariableMasks(), llDst.getFreeVariableMasks())) { + assert(srcMask.first == dstMask.first); + if (srcMask.second != dstMask.second) + return true; + } + return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); +} + +/* Utility function used by get.*Order methods of SliceEncodingAttr. + * Erase dim and decrease all values larger than dim by 1. + * Example: order = [0, 2, 4, 3, 1], dim = 2 + * resOrder = [0, 3, 2, 1] + */ +static SmallVector eraseOrder(ArrayRef order, + unsigned dim) { + unsigned rank = order.size(); + assert(dim < rank && "Invalid dim to erase"); + SmallVector resOrder; + for (unsigned i : order) + if (i < dim) + resOrder.push_back(i); + else if (i > dim) + resOrder.push_back(i - 1); + return resOrder; +} + +SmallVector getMatrixOrder(unsigned rank, bool rowMajor) { + // Return the order that represents that the batch is in row-major or + // column-major order for a batch of matrices of shape [*, m, n] with + // len(shape) == rank. + assert(rank >= 2); + SmallVector order(rank); + std::iota(order.rbegin(), order.rend(), 0); + if (!rowMajor) { + std::swap(order[0], order[1]); + } + return order; +} + +SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, + bool kContig) { + // kContig: if true, the matrix is fastest-running on k, + // otherwise it is on m (resp. n) + // opIdx=0: [batch, m, k] if rank == 3 else [m, k] + // opIdx=1: [batch, k, n] if rank == 3 else [k, n] + // batch (if rank == 3) is always the slowest running dimension + assert(rank == 2 || rank == 3); + assert(opIdx == 0 || opIdx == 1); + auto rowMajor = bool(opIdx) != kContig; + return getMatrixOrder(rank, rowMajor); +} + +SmallVector getRepOrder(RankedTensorType type) { + auto layout = type.getEncoding(); + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getRepOrder(); + else + llvm::report_fatal_error("Unimplemented usage of getRepOrder"); + return {}; +} + +// Legacy impl for now +// This one's not terribly bad as we don't broadcast ShareEncodings +SmallVector getOrder(SharedEncodingTrait layout, + ArrayRef shape) { + if (auto swizzledLayout = + mlir::dyn_cast(layout)) { + return llvm::to_vector(swizzledLayout.getOrder()); + } + if (auto sharedLayout = mlir::dyn_cast(layout)) { + return sharedLayout.getOrder(); + } + llvm::report_fatal_error("Unimplemented usage of getOrder for MemDescType"); + return {}; +} + +// Convenience functions +SmallVector getOrder(TensorOrMemDesc type) { + if (auto memDesc = dyn_cast(type)) { + return getOrder(memDesc); + } else { + auto tensorTy = cast(type); + return getOrder(tensorTy); + } +} + +SmallVector getOrder(MemDescType type) { + return getOrder(cast(type.getEncoding()), + type.getShape()); +} + +// Legacy impl for now +SmallVector getOrder(DistributedEncodingTrait layout, + ArrayRef shape) { + return layout.getDefaultOrder(); +} + +// Convenience function +SmallVector getOrder(RankedTensorType type) { + return getOrder(cast(type.getEncoding()), + type.getShape()); +} + +SmallVector getDefaultMmaOrder(MmaEncodingTrait layout) { + auto distributedLayout = cast(layout); + auto rank = distributedLayout.getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} + +// Legacy impl for now +SmallVector getThreadOrder(DistributedEncodingTrait layout, + ArrayRef shape) { + return layout.getDefaultThreadOrder(); +} + +// Convenience function +SmallVector getThreadOrder(RankedTensorType type) { + return getThreadOrder(cast(type.getEncoding()), + type.getShape()); +} + +// Legacy impl for now +SmallVector getWarpOrder(DistributedEncodingTrait layout, + ArrayRef shape) { + return layout.getDefaultWarpOrder(); +} + +// Convenience function +SmallVector getWarpOrder(RankedTensorType type) { + return getWarpOrder(cast(type.getEncoding()), + type.getShape()); +} + +CTALayoutAttr getCTALayout(Attribute layout) { + if (auto ttgLayout = mlir::dyn_cast(layout)) { + return CTALayoutAttr::get(layout.getContext(), getCTAsPerCGA(ttgLayout), + getCTASplitNum(ttgLayout), + getCTAOrder(ttgLayout)); + } + llvm::report_fatal_error("Unimplemented usage of getCTALayout"); + return {}; +} + +SmallVector getCTAsPerCGA(Attribute layout) { + ArrayRef ref; + if (auto ttgLayout = mlir::dyn_cast(layout)) + return ttgLayout.getCTAsPerCGA(); + else + llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA"); + return SmallVector(ref.begin(), ref.end()); +} + +SmallVector getCTASplitNum(Attribute layout) { + SmallVector res; + if (auto ttgLayout = mlir::dyn_cast(layout)) { + return ttgLayout.getCTASplitNum(); + } else if (auto tmemLayout = + mlir::dyn_cast( + layout)) { + res.resize(2); + res[0] = tmemLayout.getCTASplitM(); + res[1] = tmemLayout.getCTASplitN(); + } else if (auto tmemScaleLayout = mlir::dyn_cast< + triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(layout)) { + res.resize(2); + res[0] = tmemScaleLayout.getCTASplitM(); + res[1] = tmemScaleLayout.getCTASplitN(); + } else { + assert(false && "Unimplemented usage of getCTASplitNum"); + } + return res; +} + +SmallVector getCTAOrder(Attribute layout) { + SmallVector res; + if (auto ttgLayout = mlir::dyn_cast(layout)) { + res = ttgLayout.getCTAOrder(); + } else { + llvm::report_fatal_error("Unimplemented usage of getCTAOrder"); + } + return res; +} + +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape) { + unsigned rank = shape.size(); + SmallVector shapePerCTA(rank); + for (unsigned i = 0; i < rank; ++i) { + unsigned splitNum = std::min(shape[i], CTASplitNum[i]); + shapePerCTA[i] = shape[i] / splitNum; + } + return shapePerCTA; +} + +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape) { + if (mlir::isa(layout)) { + // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D. + // The first dim of shape is numStages. This is a work around, otherwise + // too many places would have to be modified in pipeline pass. Maybe we + // need to refactor this logic in the future. + auto CTASplitNum = cast(layout).getCTASplitNum(); + if (shape.size() == CTASplitNum.size() + 1) { + auto res = getShapePerCTA(CTASplitNum, shape.drop_front()); + res.insert(res.begin(), shape.front()); + return res; + } + } + SmallVector splitNum = getCTASplitNum(layout); + if (auto tmem = dyn_cast(layout)) { + if (shape.size() > splitNum.size()) { + splitNum.insert(splitNum.begin(), shape.size() - splitNum.size(), 1); + } + } + return getShapePerCTA(splitNum, shape); +} + +SmallVector getAllocationShapePerCTA(Attribute layout, + ArrayRef shapeLogical) { + SmallVector shape(shapeLogical); + if (auto sharedMMALayout = mlir::dyn_cast(layout)) { + if (sharedMMALayout.getFp4Padded()) { + auto packedAxis = getOrder(sharedMMALayout, shapeLogical)[0]; + if (shape.size() == 3) { + // Take into account multi buffering + shape[1 + packedAxis] *= 2; + } else { + shape[packedAxis] *= 2; + } + } + } + return getShapePerCTA(layout, shape); +} + +SmallVector getShapePerCTA(Type type) { + auto tensorType = cast(type); + return getShapePerCTA(tensorType.getEncoding(), tensorType.getShape()); +} + +SmallVector getAllocationShapePerCTA(Type type) { + auto tensorType = cast(type); + return getAllocationShapePerCTA(tensorType.getEncoding(), + tensorType.getShape()); +} + +unsigned getNumWarpsPerCTA(Attribute layout) { + SmallVector warpsPerCTA; + if (auto blockedLayout = dyn_cast(layout)) + warpsPerCTA = blockedLayout.getWarpsPerCTA(); + else if (auto sliceLayout = dyn_cast(layout)) + return getNumWarpsPerCTA(sliceLayout.getParent()); + else if (auto mmaLayout = dyn_cast(layout)) { + // Use the distributed layout interface to get the number of warps per + // CTA. + auto distributedLayout = cast(layout); + warpsPerCTA = distributedLayout.getWarpsPerCTA(); + } else if (auto mfmaLayout = dyn_cast(layout)) + warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + else if (auto wmmaLayout = dyn_cast(layout)) + warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + else if (auto dotLayout = dyn_cast(layout)) + warpsPerCTA = dotLayout.getWarpsPerCTA(); + else + llvm::report_fatal_error("Unimplemented usage of getNumWarpsPerCTA"); + return product(warpsPerCTA); +} + +unsigned getNumCTAs(Attribute layout) { + return product(getCTAsPerCGA(layout)); +} + +bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { + // If the new elements per thread is less than the old one, we will need to + // do convert encoding that goes through shared memory anyway. So we + // consider it as expensive. + RankedTensorType tensorTy = cat.getType(); + auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); + auto shape = tensorTy.getShape(); + auto newTotalElemsPerThread = + gpu::getTotalElemsPerThread(targetEncoding, shape); + return newTotalElemsPerThread < totalElemsPerThread; +} + +LogicalResult CTALayoutAttr::verify( + function_ref emitError, ArrayRef CTAsPerCGA, + ArrayRef CTASplitNum, ArrayRef CTAOrder) { + if (CTAsPerCGA.size() != CTASplitNum.size() || + CTASplitNum.size() != CTAOrder.size()) { + return emitError() << "CTAsPerCGA, CTASplitNum, and CTAOrder must all have " + "the same rank."; + } + + if (!isPermutationOfIota(CTAOrder)) { + return emitError() + << "CTAOrder must be a permutation of 0..(rank-1), but was [" + << CTAOrder << "]"; + } + + if (llvm::any_of(CTAsPerCGA, [](unsigned x) { return x == 0; })) { + return emitError() << "Every element in CTAsPerCGA must be greater than 0."; + } + + if (llvm::any_of(CTASplitNum, [](unsigned x) { return x == 0; })) { + return emitError() + << "Every element in CTASplitNum must be greater than 0."; + } + + return success(); +} + +LogicalResult +BlockedEncodingAttr::verify(function_ref emitError, + ArrayRef sizePerThread, + ArrayRef threadsPerWarp, + ArrayRef warpsPerCTA, + ArrayRef order, CTALayoutAttr CTALayout) { + if (sizePerThread.size() != threadsPerWarp.size() || + threadsPerWarp.size() != warpsPerCTA.size() || + warpsPerCTA.size() != order.size()) { + return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and " + "order must all have the same rank."; + } + + // Empty CTALayout is allowed, but if it's present its rank must match the + // BlockedEncodingAttr's rank. + if (CTALayout.getCTASplitNum().size() != 0 && + sizePerThread.size() != CTALayout.getCTASplitNum().size()) { + return emitError() << "BlockedEncodingAttr and CTALayout's fields must " + "have the same rank."; + } + if (!isPermutationOfIota(order)) { + return emitError() + << "order must be a permutation of 0..(rank-1), but was [" << order + << "]"; + } + return success(); +} + +// 1 element per thread +// order = reverse(arange(rank)) +triton::gpu::BlockedEncodingAttr +getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, + int numWarps, int threadsPerWarp, int numCTAs) { + int rank = shape.size(); + llvm::SmallVector order(rank); + std::iota(order.begin(), order.end(), 0); + std::reverse(order.begin(), order.end()); + llvm::SmallVector sizePerThread(rank, 1); + triton::gpu::BlockedEncodingAttr encoding = + triton::gpu::BlockedEncodingAttr::get(context, shape, sizePerThread, + order, numWarps, threadsPerWarp, + numCTAs); + return encoding; +} + +LogicalResult tryJoinOnAxis(MLIRContext *ctx, const LinearLayout &inLl, + LinearLayout &outLl, bool fwdInference, int axis, + std::optional loc) { + auto kRegister = StringAttr::get(ctx, "register"); + auto outDims = llvm::to_vector(inLl.getOutDimNames()); + if (fwdInference) { + auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]); + outLl = split * inLl; + } else { + // TODO This requires a division algorithm! + // Implement manually ll.divideLeft(split) + auto contiguousElems = + LinearEncodingAttr::get(ctx, inLl).getContigPerThread(); + if (contiguousElems[axis] > 1) { + LinearLayout::BasesT newBases; + for (const auto &basesDim : inLl.getBases()) { + std::vector> newBasesDim; + for (auto base : basesDim.second) { + if (base[axis] == 1) { + continue; + } + base[axis] /= 2; + newBasesDim.push_back(std::move(base)); + } + newBases.insert({basesDim.first, std::move(newBasesDim)}); + } + outLl = LinearLayout(std::move(newBases), std::move(outDims)); + } else { + return emitOptionalError(loc, + "Fp4ToFpOp/SplitOp requires at least 2 elements " + "per thread in the axis/last dimension"); + } + } + return success(); +} + +} // namespace gpu +} // namespace triton +} // namespace mlir + +static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, + unsigned &value, StringRef desc) { + auto intAttr = mlir::dyn_cast(attr); + if (!intAttr) { + parser.emitError(parser.getNameLoc(), "expected an integer type in ") + << desc; + return failure(); + } + if (intAttr.getType().isSignedInteger()) { + int64_t attrVal = intAttr.getSInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else if (intAttr.getType().isSignlessInteger()) { + int64_t attrVal = intAttr.getInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else { + value = intAttr.getUInt(); + } + return success(); +} + +static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr, + bool &value, StringRef desc) { + auto boolAttr = mlir::dyn_cast(attr); + if (!boolAttr) { + parser.emitError(parser.getNameLoc(), "expected an bool type in ") << desc; + return failure(); + } + value = boolAttr.getValue(); + return success(); +} + +// parse an array of integers +static LogicalResult parseIntArrayAttr(AsmParser &parser, + const NamedAttribute &attr, + SmallVector &res, + StringRef desc) { + auto arrayAttr = mlir::dyn_cast(attr.getValue()); + if (!arrayAttr) { + parser.emitError(parser.getNameLoc(), "expected an array for ") << desc; + return failure(); + } + for (Attribute i : arrayAttr) { + unsigned value; + if (parseIntAttrValue(parser, i, value, desc).failed()) + return failure(); + res.push_back(value); + } + return success(); +}; + +static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, + unsigned &value, StringRef desc) { + return parseIntAttrValue(parser, attr.getValue(), value, desc); +}; + +static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr, + bool &value, StringRef desc) { + return parseBoolAttrValue(parser, attr.getValue(), value, desc); +}; + +// Print the CTALayout if it's not equal to the default. +static void maybePrintCTALayout(mlir::MLIRContext *context, + mlir::AsmPrinter &printer, CTALayoutAttr layout, + unsigned rank) { + if (layout != CTALayoutAttr::getDefault(context, rank)) { + printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]" + << ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]" + << ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]"; + } +} + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc" + +// If we only had BlockedEncodingAttr, we could simply return ArrayRefs here. +// But we need to have a consistent interface with e.g. SliceEncodingAttr, which +// computes some of these fields. +SmallVector BlockedEncodingAttr::getRepOrder() const { + return SmallVector(getDefaultOrder()); +} +SmallVector BlockedEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector BlockedEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector BlockedEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector BlockedEncodingAttr::getDefaultOrder() const { + return SmallVector(getOrder()); +} +SmallVector BlockedEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector BlockedEncodingAttr::getDefaultWarpOrder() const { + return SmallVector(getDefaultOrder()); +} +SmallVector BlockedEncodingAttr::getThreadsPerWarp() const { + return SmallVector(getThreadsPerWarp__()); +} +SmallVector BlockedEncodingAttr::getDefaultThreadOrder() const { + return SmallVector(getDefaultOrder()); +} + +template +SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { + size_t rank = shape.size(); + unsigned dim = getDim(); + SmallVector retShape(rank + 1); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d < dim) + retShape[d] = shape[d]; + else if (d == dim) + retShape[d] = 1; + else + retShape[d] = shape[d - 1]; + } + return retShape; +} +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; +SmallVector SliceEncodingAttr::getRepOrder() const { + auto parentRepOrder = getParent().getRepOrder(); + return eraseOrder(parentRepOrder, getDim()); +} +SmallVector SliceEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + res.erase(res.begin() + getDim()); + return res; +} +SmallVector SliceEncodingAttr::getCTAOrder() const { + auto parentCTAOrder = ::getCTAOrder(getParent()); + return eraseOrder(parentCTAOrder, getDim()); +} +SmallVector SliceEncodingAttr::getCTAsPerCGA() const { + auto parentCTAsPerCGA = ::getCTAsPerCGA(getParent()); + if (parentCTAsPerCGA[getDim()] == 1) { + parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + getDim()); + return parentCTAsPerCGA; + } + /* For getCTAsPerCGA of a slice layout, we have two choices: + * (1) Return CTAsPerCGA of its parent. This is not a perfect solution + * because the rank of the returned CTAsPerCGA does not match the rank of + * tensorShape. + * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a + * perfect solution because the product of the returned CTAsPerCGA might not + * match numCTAs. + * To avoid introducing inconsistencies to the shape and + * layout system, the usage of directly getting CTAsPerCGA of a slice layout + * in which the sliced dim is not 1 is banned. You should always consider + * slice layout as a special case and use getCTAsPerCGA(layout.getParent()) + * in the branch where layout is an instance of SliceEncodingAttr. This is + * inconvenient but safe. + */ + llvm::report_fatal_error( + "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); +} +SmallVector SliceEncodingAttr::getWarpsPerCTA() const { + auto parent = getParent(); + auto parentWarpsPerCTA = ::getWarpsPerCTA(parent); + SmallVector warpsPerCTA = parentWarpsPerCTA; + warpsPerCTA.erase(warpsPerCTA.begin() + getDim()); + int32_t nextDim = getDim() < warpsPerCTA.size() ? getDim() : getDim() - 1; + warpsPerCTA[nextDim] *= parentWarpsPerCTA[getDim()]; + return warpsPerCTA; +} +SmallVector SliceEncodingAttr::getDefaultWarpOrder() const { + auto parentWarpOrder = getParent().getDefaultWarpOrder(); + return eraseOrder(parentWarpOrder, getDim()); +} +SmallVector SliceEncodingAttr::getThreadsPerWarp() const { + auto parent = getParent(); + auto parentThreadsPerWarp = ::getThreadsPerWarp(parent); + SmallVector threadsPerWarp = parentThreadsPerWarp; + threadsPerWarp.erase(threadsPerWarp.begin() + getDim()); + int32_t nextDim = getDim() < threadsPerWarp.size() ? getDim() : getDim() - 1; + threadsPerWarp[nextDim] *= parentThreadsPerWarp[getDim()]; + return threadsPerWarp; +} +SmallVector SliceEncodingAttr::getDefaultThreadOrder() const { + auto parentThreadOrder = getParent().getDefaultThreadOrder(); + return eraseOrder(parentThreadOrder, getDim()); +} +SmallVector SliceEncodingAttr::getDefaultOrder() const { + SmallVector parentOrder = getParent().getDefaultOrder(); + unsigned dim = getDim(); + SmallVector order; + for (unsigned d : parentOrder) { + if (d != dim) + order.push_back(d > dim ? d - 1 : d); + } + return order; +} + +// + +// Wmma encoding + +int32_t SwizzledSharedEncodingAttr::getAlignment() const { return 16; } + +SmallVector SwizzledSharedEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector SwizzledSharedEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector SwizzledSharedEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + +int32_t NVMMASharedEncodingAttr::getAlignment() const { return 1024; } + +SmallVector NVMMASharedEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector NVMMASharedEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector NVMMASharedEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + +SmallVector DotOperandEncodingAttr::getCTAsPerCGA() const { + return ::getCTAsPerCGA(getParent()); +} +SmallVector DotOperandEncodingAttr::getCTAOrder() const { + return ::getCTAOrder(getParent()); +} +SmallVector DotOperandEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + auto rank = res.size(); + assert(rank == 2 || rank == 3 && "Invalid dotLayout"); + + // Do not split CTA in K dimension + auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2; + res[kDim] = 1; + return res; +} +SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { + auto distributedLayout = mlir::cast(getParent()); + auto warps = distributedLayout.getWarpsPerCTA(); + auto rank = warps.size(); + auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2; + warps[kDim] = 1; + return warps; +} +SmallVector DotOperandEncodingAttr::getDefaultOrder() const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(getOpIdx(), rank, /*kContig*/ true); +} +SmallVector DotOperandEncodingAttr::getDefaultWarpOrder() const { + // FIXME(Lezcano): Preexisting. Do we want to have this path at all? + if (mlir::isa(getParent())) { + return mlir::cast(getParent()) + .getDefaultWarpOrder(); + } + llvm::report_fatal_error( + "DotOperandEncoding::getDefaultWarpOrder not implemented"); + return {}; +} +SmallVector DotOperandEncodingAttr::getDefaultThreadOrder() const { + return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), + /*kContig*/ true); +} + +LogicalResult DotOperandEncodingAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + unsigned opIdx, Attribute parent, unsigned kWidth) { + if (opIdx != 0 && opIdx != 1) { + return emitError() << "ttg.dot_op opIdx parameter can be 0 or 1, got: " + << opIdx; + } + if (!parent) { + return emitError() << "ttg.dot_op parent parameter cannot be null"; + } + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter can only be " + "non-zero for Ampere or Hopper MMA parent"; + if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "Ampere or Hopper MMA parent"; + if (opIdx != 0 && parentAttr.isHopper()) + return emitError() + << "ttg.dot_op opIdx parameter must be 0 for " + "Hopper MMA parent, since Hopper WGMMA only allows first " + "operand to be in registers"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 16 && parentAttr.getVersion() == 1 || + kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 2) + return emitError() << "ttg.dot_op kWidth parameter must be 16 for " + "gfx11 and 8/16 for gfx12"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth == 0) + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "MFMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0) + return emitError() << "ttg.dot_op kWidth parameter is not supported " + "when the parent is a blocked layout"; + return success(); + } + + return emitError() << "ttg.dot_op unexpected parent layout: " << parent; +} + +//===----------------------------------------------------------------------===// +// Blocked Encoding +//===----------------------------------------------------------------------===// + +static std::optional getCTALayoutOrError( + AsmParser &parser, std::optional> CTAsPerCGA, + std::optional> CTASplitNum, + std::optional> CTAOrder, unsigned rank) { + if (CTAsPerCGA && CTASplitNum && CTAOrder) { + return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum, + *CTAOrder); + } + if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) { + return CTALayoutAttr::getDefault(parser.getContext(), rank); + } + parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder " + "must all be present or all be absent"); + return std::nullopt; +} + +Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector sizePerThread; + SmallVector threadsPerWarp; + SmallVector warpsPerCTA; + SmallVector order; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "sizePerThread") { + if (parseIntArrayAttr(parser, attr, sizePerThread, + "number of elements per thread") + .failed()) + return {}; + } else if (attr.getName() == "threadsPerWarp") { + if (parseIntArrayAttr(parser, attr, threadsPerWarp, + "number of threads per warp") + .failed()) + return {}; + } else if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, + "number of warps per CTA") + .failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/sizePerThread.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), + sizePerThread, threadsPerWarp, + warpsPerCTA, order, *CTALayout); +} + +void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "sizePerThread = [" << ArrayRef(getSizePerThread()) << "]" + << ", threadsPerWarp = [" << ArrayRef(getThreadsPerWarp()) << "]" + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]" + << ", order = [" << getOrder() << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getSizePerThread().size()); + + printer << "}>"; +} + +// FIXME Can we take the LinearLayout by const&? +LogicalResult +LinearEncodingAttr::verify(function_ref emitError, + LinearLayout linearLayout) { + // Example of LinearEncodingAttr + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + // The input dims must be {register, lane, warp, block} + // The output dims of the linear layout should be dim0..dim[rank-1] + + static const auto expectedInDims = + SmallVector({"register", "lane", "warp", "block"}); + for (const auto &[i, dims] : llvm::enumerate( + llvm::zip(linearLayout.getInDimNames(), expectedInDims))) { + const auto &[dim, expectedDimStr] = dims; + if (dim.str() != expectedDimStr) { + return emitError() << "Expected input dimension " << i << " to be '" + << expectedDimStr << "'. Got " << dim; + } + } + + // outDims are ['dim0', 'dim1', ...] + for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) { + if (dim.str() != ("dim" + llvm::Twine(i)).str()) { + return emitError() + << "Expected output dimensions to be ['dim0', 'dim1', ...]. Got " + << dim << " at position " << i; + } + } + + const auto &bases = linearLayout.getBases(); + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &dimBases : llvm::make_second_range(bases)) { + if (!llvm::all_of(dimBases, [&](const auto &basis) { + return std::count_if(basis.begin(), basis.end(), nonZero) <= 1; + })) { + return emitError() + << "In a distributed layout, each base must move in at most one " + "dimension."; + } + } + + return success(); +} + +void LinearEncodingAttr::print(mlir::AsmPrinter &printer) const { + // We don't use the default implementation as it's a bit too verbose + // This prints in the following format that is shape agnostic, in the sense + // that we don't print explicitly the outShape of the LL + // We always assume LLs to be surjective + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + auto ll = getLinearLayout(); + printer << "<{" << join(ll.getBases(), ", ", [](const auto &base) { + return base.first.str() + " = " + "[" + + join(base.second, ", ", + [](const std::vector &vec) { + return "[" + join(vec, ", ") + "]"; + }) + + "]"; + }) << "}>"; +} + +Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + + if (parser.parseGreater().failed()) + return {}; + + LinearLayout::BasesT bases; + + // Parse the basis names in order (the order is relevant) + std::vector inDimNames = {"register", "lane", "warp", "block"}; + + for (const auto &inDimNameStr : inDimNames) { + auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr); + Attribute value = dict.get(inDimName); + + // Expecting an array of arrays + auto arrayOfArraysAttr = mlir::dyn_cast(value); + if (!arrayOfArraysAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of arrays for basis of '") + << inDimName.getValue() << "'"; + return {}; + } + + std::vector> inDimBases; + for (Attribute arrayAttr : arrayOfArraysAttr) { + auto intArrayAttr = mlir::dyn_cast(arrayAttr); + if (!intArrayAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of integers in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + std::vector basis; + for (Attribute intAttr : intArrayAttr) { + auto intValueAttr = mlir::dyn_cast(intAttr); + if (!intValueAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected integer in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + basis.push_back(intValueAttr.getInt()); + } + inDimBases.push_back(std::move(basis)); + } + bases[inDimName] = std::move(inDimBases); + } + size_t rank = 0; + for (const auto &basesDim : llvm::make_second_range(bases)) { + if (!basesDim.empty()) { + rank = basesDim[0].size(); + break; + } + } + + // To implement this we'd need to serialise the rank as well. + // We can do this if we ever need it + if (rank == 0) { + parser.emitError(parser.getCurrentLocation(), "Empty Layout not supported"); + return {}; + } + + // Generate standared outDimNames (dim0, dim1, ...) + SmallVector outDimNames; + for (int i = 0; i < rank; ++i) { + outDimNames.push_back( + StringAttr::get(parser.getContext(), "dim" + llvm::Twine(i))); + } + + // Create LinearLayout + LinearLayout linearLayout(std::move(bases), std::move(outDimNames)); + + // Create and return the LinearEncodingAttr + return parser.getChecked(parser.getContext(), + std::move(linearLayout)); +} + +SmallVector basesPerDimImpl(const LinearLayout::BasesT &namedBases, + StringAttr dimName, size_t rank, + bool skipBroadcast = true) { + const auto &bases = namedBases.find(dimName)->second; + + if (bases.empty()) { + return SmallVector(rank, 1); + } + + SmallVector ret(rank, 1); + auto nonZero = [](auto val) { return val != 0; }; + int nonZeroIdx = 0; + for (const auto &basis : bases) { + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + // Bases can have one or zero non-zero elements + // Skip a basis if it's broadcasting (all zeros) + // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout) + if (it != basis.end()) { + nonZeroIdx = it - basis.begin(); + ret[nonZeroIdx] *= 2; + } else if (!skipBroadcast) { + // If we've seen a non-zero basis, we double the size of the previous dim + // This is just needed to count the CTAsPerCGA + ret[nonZeroIdx] *= 2; + } + } + return ret; +} + +SmallVector +LinearEncodingAttr::basesPerDim(StringAttr dimName, bool skipBroadcast) const { + auto ll = getLinearLayout(); + auto rank = ll.getNumOutDims(); + return basesPerDimImpl(ll.getBases(), dimName, rank, skipBroadcast); +} + +SmallVector +LinearEncodingAttr::orderPerDim(StringAttr dimName, + ArrayRef defaultOrder) const { + auto ll = getLinearLayout(); + const auto &bases = ll.getBases().find(dimName)->second; + llvm::SetVector order; + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &basis : bases) { + // Bases can have one or zero non-zero elements + // Skip a basis if it's broadcasting (all zeros) + // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout) + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + if (it != basis.end()) { + auto i = it - basis.begin(); + order.insert(i); + } + } + // If any dim is missing, we add them in the defaultOrder + for (auto i : defaultOrder) { + order.insert(i); + } + return SmallVector(order.begin(), order.end()); +} + +// [Note. Divergence of methods wrt. legacy layouts] +// For smaller shapes where the CTATile is larger than the output +// tensor, some methods return different values than the legacy layouts. I think +// this is benign tho. An example: what is the the vector of `warpsPerCTA` if +// all the warps hold the same data? I think it should be [1, 1], even if we +// have 4 warps. But perhaps for this we have to add some masking in some +// places... We'll see +SmallVector LinearEncodingAttr::getRepOrder() const { + // This is not correct, but: + // - It happens to agree in most places with the legacy layout + // - getRepOrder does not make sense for LinearEncodingAttr as it already has + // the same shape as the tensor that uses it + return getOrder(); +} +SmallVector LinearEncodingAttr::getCTAsPerCGA() const { + // CTAs are split into an identity part (SplitNum) and a broadcast part + return basesPerDim(StringAttr::get(getContext(), "block"), + /*skipBroadcast=*/false); +} +SmallVector LinearEncodingAttr::getCTAOrder() const { + return orderPerDim(StringAttr::get(getContext(), "block"), getOrder()); +} +SmallVector LinearEncodingAttr::getCTASplitNum() const { + return basesPerDim(StringAttr::get(getContext(), "block")); +} +SmallVector LinearEncodingAttr::getWarpsPerCTA() const { + return basesPerDim(StringAttr::get(getContext(), "warp")); +} +SmallVector LinearEncodingAttr::getDefaultWarpOrder() const { + return getWarpOrder(); +} +SmallVector LinearEncodingAttr::getWarpOrder() const { + return orderPerDim(StringAttr::get(getContext(), "warp"), getOrder()); +} +SmallVector LinearEncodingAttr::getThreadsPerWarp() const { + return basesPerDim(StringAttr::get(getContext(), "lane")); +} +SmallVector LinearEncodingAttr::getDefaultThreadOrder() const { + return getThreadOrder(); +} +SmallVector LinearEncodingAttr::getThreadOrder() const { + return orderPerDim(StringAttr::get(getContext(), "lane"), getDefaultOrder()); +} +SmallVector LinearEncodingAttr::getSizePerThread() const { + auto rank = getOrder().size(); + auto ll = getLinearLayout(); + auto ctx = getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + + // We canonicalize on the spot, as if we use CGAs the regs are not in + // canonical form The order is [reg, lane, warp, rep, block], so we first + // remove the blocks + llvm::SmallVector ctaShape; + for (auto [shape, cgaNum] : + llvm::zip(ll.getOutDimSizes(), getCTASplitNum())) { + ctaShape.push_back(shape / cgaNum); + } + LinearLayout::BasesT bases = ll.getBases(); + + llvm::SetVector reverseRepOrder; + auto nonZero = [](auto val) { return val != 0; }; + auto ®isters = bases[kRegister]; + while (!registers.empty()) { + auto &basis = registers.back(); + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + // If there's broadcasting (base == zeros) there are no more reps + if (it == basis.end()) { + break; + } + auto dim = it - basis.begin(); + reverseRepOrder.insert(dim); + // As soon as we stop finding reps, we stop + if (dim != reverseRepOrder.back() || 2 * basis[dim] != ctaShape[dim]) { + break; + } + ctaShape[dim] /= 2; + registers.pop_back(); + } + return basesPerDimImpl(bases, kRegister, rank); +} + +SmallVector LinearEncodingAttr::getShapePerCTATile() const { + auto sizePerThread = getSizePerThread(); + auto threadsPerWarp = getThreadsPerWarp(); + auto warpsPerCTA = getWarpsPerCTA(); + SmallVector shape; + for (auto [size, thread, warp] : + llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) { + shape.push_back(size * thread * warp); + } + return shape; +} + +SmallVector LinearEncodingAttr::getDefaultOrder() const { + return getOrder(); +} + +SmallVector LinearEncodingAttr::getOrder() const { + auto rank = getLinearLayout().getNumOutDims(); + SmallVector order(rank); + // Choose [rank-1, rank-2, ... 0] as the default order in case + // there are dims that do not move in the register + // This order is as good as any really + std::iota(order.rbegin(), order.rend(), 0); + + return orderPerDim(StringAttr::get(getContext(), "register"), order); +} + +LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto ll = getLinearLayout(); + auto canonicalDims = llvm::to_vector(ll.getOutDimNames()); + llvm::SmallDenseMap namedShape; + llvm::SmallVector permutedDims; + for (auto dim : getRepOrder()) { + permutedDims.push_back(canonicalDims[dim]); + namedShape[canonicalDims[dim]] = shape[dim]; + } + ll = ll.transposeOuts(permutedDims); + ll = ensureLayoutNotSmallerThan(ll, namedShape); + ll = ensureLayoutNotLargerThan(ll, namedShape, /*broadcastRegisters=*/false); + ll = ll.transposeOuts(canonicalDims); + return ll; +} + +SmallVector +LinearEncodingAttr::getElemsPerThread(ArrayRef shape) const { + // When broadcasting the layout the shape changes, otherwise the shape is + // the same as the shape of the tensor + // We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep + // the invariant that the shape of the LL is that of the tensor + // We choose the former for BC + auto scaledLayout = get(getContext(), toLinearLayout(shape)); + auto kRegister = StringAttr::get(getContext(), "register"); + return scaledLayout.basesPerDim(kRegister, /*skipBroadcast=*/false); +} + +SmallVector +LinearEncodingAttr::getContig(const char *inDim, + SmallVector lowerContig) const { + auto ll = getLinearLayout(); + const auto &bases = + ll.getBases().find(StringAttr::get(getContext(), inDim))->second; + auto order = getOrder(); + auto rank = order.size(); + + SmallVector contig(lowerContig); + auto basisIt = bases.begin(); + for (unsigned dim : order) { + std::vector basis(rank, 0); + basis[dim] = contig[dim]; + + while (basisIt != bases.end() && *basisIt == basis) { + contig[dim] *= 2; + basis[dim] *= 2; + ++basisIt; + } + } + return contig; +} + +SmallVector LinearEncodingAttr::getContigPerThread() const { + SmallVector contig(getOrder().size(), 1); + return getContig("register", contig); +} + +SmallVector LinearEncodingAttr::getContigPerWarp() const { + return getContig("lane", getContigPerThread()); +} + +unsigned +LinearEncodingAttr::getTotalElemsPerThread(ArrayRef shape) const { + return product(getElemsPerThread(shape)); +} + +//===----------------------------------------------------------------------===// +// MMA encoding +//===----------------------------------------------------------------------===// + +Attribute NvidiaMmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + SmallVector instrShape; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { + return {}; + } + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, *CTALayout, + instrShape); +} + +void NvidiaMmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + + printer << ", instrShape = [" << getInstrShape() << "]}>"; +} + +//===----------------------------------------------------------------------===// +// MFMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + SmallVector instrShape; + bool isTransposed; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) + return {}; + } + + if (attr.getName() == "isTransposed") { + if (parseBool(parser, attr, isTransposed, "isTransposed").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, + instrShape[0], instrShape[1], isTransposed, *CTALayout); +} + +void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() // + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]" // + << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" // + << ", isTransposed = " << getIsTransposed(); + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + printer << "}>"; +} + +LogicalResult +AMDMfmaEncodingAttr::verify(function_ref emitError, + unsigned versionMajor, unsigned versionMinor, + llvm::ArrayRef warpsPerCTA, + unsigned mDim, unsigned nDim, bool isTransposed, + mlir::triton::gpu::CTALayoutAttr) { + if (!(versionMajor >= 0 && versionMajor <= 4)) { + return emitError() << "major version must be in the [0, 4] range"; + } + if (versionMinor != 0) { + return emitError() << "minor version must be 0"; + } + if (!((mDim == 32 && nDim == 32) || (mDim == 16 && nDim == 16))) { + return emitError() + << "(M, N) cases other than (32, 32) or (16, 16) unimplemented"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// WMMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned version = 0; + bool isTransposed = false; + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "version") { + if (parseUInt(parser, attr, version, "version").failed()) + return {}; + } + if (attr.getName() == "isTranspose") { + if (parseBool(parser, attr, isTransposed, "isTranspose").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), version, isTransposed, warpsPerCTA, *CTALayout); +} + +void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "version = " << getVersion() + << ", isTranspose = " << getIsTransposed() << ", warpsPerCTA = [" + << ArrayRef(getWarpsPerCTA()) << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + printer << "}>"; +} + +LogicalResult +AMDWmmaEncodingAttr::verify(function_ref emitError, + unsigned version, bool isTransposed, + llvm::ArrayRef warpsPerCTA, + mlir::triton::gpu::CTALayoutAttr) { + if (version != 1 && version != 2) { + return emitError() << "WMMA version must be in the [1, 2] range"; + } + // Transposed layout is needed for bypassing LDS between multiple dots. + // Version 1 tt.dot results and tt.dot operand layouts are different, + // therefore we test and support transposed only for version 2. + if (version != 2 && isTransposed) { + return emitError() << "Transposed WMMA is supported only for version 2"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Sliced Encoding +//===----------------------------------------------------------------------===// + +Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + unsigned dim = mlir::cast(attrs.get("dim")).getInt(); + auto parent = mlir::dyn_cast(attrs.get("parent")); + if (!parent) { + parser.emitError(parser.getNameLoc(), + "expected a distributed encoding trait"); + return {}; + } + return parser.getChecked(parser.getContext(), dim, parent); +} + +void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "dim = " << getDim() << ", " + << "parent = " << getParent() << "}>"; +} + +//===----------------------------------------------------------------------===// +// SwizzledShared encoding +//===----------------------------------------------------------------------===// + +Attribute SwizzledSharedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned vec = 0; + unsigned perPhase = 0; + unsigned maxPhase = 0; + SmallVector order; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "vec") { + if (parseUInt(parser, attr, vec, "vec").failed()) + return {}; + } else if (attr.getName() == "perPhase") { + if (parseUInt(parser, attr, perPhase, "perPhase").failed()) + return {}; + } else if (attr.getName() == "maxPhase") { + if (parseUInt(parser, attr, maxPhase, "maxPhase").failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/order.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), vec, perPhase, maxPhase, order, *CTALayout); +} + +void SwizzledSharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "vec = " << getVec() // + << ", perPhase = " << getPerPhase() + << ", maxPhase = " << getMaxPhase() // + << ", order = [" << getOrder() << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getOrder().size()); + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// NVMMAShared encoding +//===----------------------------------------------------------------------===// + +Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned swizzlingByteWidth; + bool transposed = false; + bool fp4Padded = false; + unsigned elementBitWidth; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "swizzlingByteWidth") { + if (parseUInt(parser, attr, swizzlingByteWidth, "swizzlingByteWidth") + .failed()) + return {}; + } else if (attr.getName() == "transposed") { + if (parseBool(parser, attr, transposed, "transposed").failed()) + return {}; + } else if (attr.getName() == "elementBitWidth") { + if (parseUInt(parser, attr, elementBitWidth, "elementBitWidth").failed()) + return {}; + } else if (attr.getName() == "fp4Padded") { + if (parseBool(parser, attr, fp4Padded, "fp4Padded").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/2); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), swizzlingByteWidth, transposed, elementBitWidth, + fp4Padded, *CTALayout); +} + +void NVMMASharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "swizzlingByteWidth = " << getSwizzlingByteWidth() // + << ", transposed = " << getTransposed() // + << ", elementBitWidth = " << getElementBitWidth(); + if (getFp4Padded()) { + // Print only in this case to reduce the noise for the more common case. + printer << ", fp4Padded = true"; + } + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/2); + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// Mfma encoding +//===----------------------------------------------------------------------===// +// TODO: there is a lot of common code with MmaEncoding here + +SmallVector AMDMfmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector AMDMfmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector AMDMfmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector AMDMfmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector AMDMfmaEncodingAttr::getDefaultOrder() const { + return getDefaultMmaOrder(*this); +} +SmallVector AMDMfmaEncodingAttr::getDefaultWarpOrder() const { + return getDefaultOrder(); +} +SmallVector AMDMfmaEncodingAttr::getDefaultThreadOrder() const { + auto order = getDefaultOrder(); + if (getIsTransposed()) + std::swap(order[0], order[1]); + return order; +} +SmallVector AMDMfmaEncodingAttr::getThreadsPerWarp() const { + unsigned rows, cols; + auto rank = getDefaultOrder().size(); + SmallVector res(rank, 1); + if (getMDim() == 32) { + cols = 2; + rows = 32; + } else { + assert(getMDim() == 16); + cols = 4; + rows = 16; + } + if (getIsTransposed()) { + res[rank - 1] = cols; + res[rank - 2] = rows; + } else { + res[rank - 1] = rows; + res[rank - 2] = cols; + } + return res; +} + +SmallVector +AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const { + unsigned mDim = getMDim(); + unsigned nDim = getNDim(); + assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + constexpr int warpSize = 64; // MFMA is always based on the 64-wide warps. + int kGroups = -1; + if (mDim == nDim) + kGroups = warpSize / mDim; + if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) + kGroups = 1; + int64_t kDim = kWidth * kGroups; + if (opIdx == 0) + return {mDim, kDim}; + else + assert(opIdx == 1); + return {kDim, nDim}; +} + +SmallVector AMDMfmaEncodingAttr::getRepOrder() const { + auto rank = getDefaultOrder().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} + +SmallVector +AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getDefaultOrder().size(); + return getOrderForDotOperand(opIdx, rank, /*kContig*/ true); +} + +SmallVector +AMDMfmaEncodingAttr::getThreadsPerWarpForOperand(int opIdx) const { + auto rank = getDefaultOrder().size(); + SmallVector threads(rank, 1); + unsigned kThreads; + unsigned nonKThreads; + switch (getMDim()) { + case 32: + assert(getNDim() == 32); + kThreads = 2; + nonKThreads = 32; + break; + case 16: + assert(getNDim() == 16); + kThreads = 4; + nonKThreads = 16; + break; + default: + llvm::report_fatal_error( + "unexpected mfma shape encountered in getThreadsPerWarpForOperand"); + } + int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2; + int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1; + threads[kDimIdx] = kThreads; + threads[nonKDimIdx] = nonKThreads; + return threads; +} + +SmallVector +AMDMfmaEncodingAttr::getRepForOperand(ArrayRef operandShape, + int kWidth, int opIdx) const { + auto operandTileShape = getInstrShapeForOperand(kWidth, opIdx); + auto rank = operandShape.size(); + auto warpsPerCTA = getWarpsPerCTA(); + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; + if (opIdx == 0) + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * warpsPerCTA[rank - 2])), + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; + else { + assert(opIdx == 1); + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / (operandTileShape[1] * + warpsPerCTA[rank - 1]))}; + } +} + +//===----------------------------------------------------------------------===// +// Wmma encoding +//===----------------------------------------------------------------------===// + +SmallVector AMDWmmaEncodingAttr::getRepOrder() const { + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} + +SmallVector +AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getDefaultOrder().size(); + return getOrderForDotOperand(opIdx, rank, /*kContig*/ true); +} + +SmallVector +AMDWmmaEncodingAttr::getThreadsPerWarpForOperand(int opIdx) const { + auto rank = getDefaultOrder().size(); + SmallVector threads(rank, 1); + unsigned kThreads; + unsigned nonKThreads; + switch (getVersion()) { + case 1: + // kThreads * onKThreads != 32, + // because values in lanes (n, n + 16) duplicates + kThreads = 1; + nonKThreads = 16; + break; + case 2: + kThreads = 2; + nonKThreads = 16; + break; + default: + llvm::report_fatal_error( + "unsupported WMMA version in getThreadsPerWarpForOperand"); + } + int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2; + int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1; + threads[kDimIdx] = kThreads; + threads[nonKDimIdx] = nonKThreads; + return threads; +} + +SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector AMDWmmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector AMDWmmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector AMDWmmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector AMDWmmaEncodingAttr::getDefaultOrder() const { + return getDefaultMmaOrder(*this); +} +SmallVector AMDWmmaEncodingAttr::getDefaultWarpOrder() const { + return getDefaultOrder(); +} +SmallVector AMDWmmaEncodingAttr::getDefaultThreadOrder() const { + auto order = getDefaultOrder(); + if (getIsTransposed()) + std::swap(order[0], order[1]); + return order; +} +SmallVector AMDWmmaEncodingAttr::getThreadsPerWarp() const { + auto rank = getWarpsPerCTA().size(); + SmallVector threads(rank, 1); + auto mnkInstr = getMNKDimPerInstr(); + mnkInstr[getIsTransposed() ? 1 : 0] /= 8; + threads[rank - 2] = mnkInstr[0]; + threads[rank - 1] = mnkInstr[1]; + return threads; +} + +SmallVector AMDWmmaEncodingAttr::getElemsPerInstrForOperands() const { + return {16, 16}; +} + +SmallVector +AMDWmmaEncodingAttr::getRepForOperand(ArrayRef operandShape, + Type elemType, int kWidth, + int opIdx) const { + auto operandTileShape = getElemsPerInstrForOperands(); + assert(operandTileShape.size() == 2); + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = operandShape.size(); + assert(rank == 2 || rank == 3); + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; + if (opIdx == 0) + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * warpsPerCTA[rank - 2])), + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; + else { + assert(opIdx == 1); + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / (operandTileShape[1] * + warpsPerCTA[rank - 1]))}; + } +} + +SmallVector AMDWmmaEncodingAttr::getMNKDimPerInstr() { + // TODO: move magic numbers out of the code + return {16, 16, 16}; +} + +unsigned AMDWmmaEncodingAttr::getKWidthForOperands() const { + auto rank = getWarpsPerCTA().size(); + SmallVector sizePerThread(rank, 1); + auto numReplicated = getVersion() == 1 ? 2 : 1; + auto elemsPerInstr = numReplicated * product(getElemsPerInstrForOperands()) / + product(getThreadsPerWarp()); + return elemsPerInstr; +} + +//===----------------------------------------------------------------------===// +// Mma encoding +//===----------------------------------------------------------------------===// + +bool NvidiaMmaEncodingAttr::isVolta() const { return getVersionMajor() == 1; } + +bool NvidiaMmaEncodingAttr::isTuring() const { + return getVersionMajor() == 2 && getVersionMinor() == 1; +} + +bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; } + +bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } + +SmallVector NvidiaMmaEncodingAttr::getRepOrder() const { + auto rank = getDefaultOrder().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} +SmallVector NvidiaMmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector NvidiaMmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector NvidiaMmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector NvidiaMmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector NvidiaMmaEncodingAttr::getDefaultOrder() const { + return getDefaultMmaOrder(*this); +} +SmallVector NvidiaMmaEncodingAttr::getDefaultWarpOrder() const { + auto rank = getDefaultOrder().size(); + // Hopper (wgmma) uses column-major as this is embedded in the instruction + // For Ampere we can choose either row-major or column-major. + // We choose row-major as the legacy path did so + return getMatrixOrder(rank, /*rowMajor*/ !isHopper()); +} +SmallVector NvidiaMmaEncodingAttr::getThreadsPerWarp() const { + auto rank = getDefaultOrder().size(); + SmallVector res(rank, 1); + if (isAmpere()) { + res[rank - 2] = 8; + res[rank - 1] = 4; + return res; + } + if (isHopper()) { + res[rank - 2] = 8; + res[rank - 1] = 4; + return res; + } + llvm::report_fatal_error( + "getThreadsPerWarp not implemented for unknown Mma version "); +} +SmallVector NvidiaMmaEncodingAttr::getDefaultThreadOrder() const { + auto rank = getDefaultOrder().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} + +SmallVector +NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getDefaultOrder().size(); + return getOrderForDotOperand(opIdx, rank, /*kContig*/ true); +} + +SmallVector +NvidiaMmaEncodingAttr::getThreadsPerWarpForOperand(int opIdx) const { + auto threadsPerWarp = getThreadsPerWarp(); + auto rank = threadsPerWarp.size(); + if (opIdx == 1) + std::swap(threadsPerWarp[rank - 2], threadsPerWarp[rank - 1]); + return threadsPerWarp; +} + +SmallVector +NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, + int kWidth, int opIdx) const { + assert( + kWidth >= 32 / bitwidth && + "kWidth must be >= 32 / bitwidth for this function to be well-defined"); + auto rank = shape.size(); + // Broadcast long K + auto warpsPerCTA = getWarpsPerCTA(); + auto kDim = opIdx == 0 ? rank - 1 : rank - 2; + warpsPerCTA[kDim] = 1; + + SmallVector tileSize; + if (rank == 3) { + tileSize.push_back(1); + } + if (opIdx == 0) { + // m x k + tileSize.push_back(16); + tileSize.push_back(4 * 64 / bitwidth); + } else { + // k x n + // Hopper path never uses the n value, since this method is only invoked + // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF + // so it's fine if the n is incorrect here + tileSize.push_back(4 * 64 / bitwidth); + tileSize.push_back(8); + } + + SmallVector numRep; + // Lezcano: This is odd. Why do we always return a vector of size 3? + if (rank != 3) { + numRep.push_back(1); + } + for (auto [s, size, warp] : llvm::zip(shape, tileSize, warpsPerCTA)) { + numRep.push_back(std::max(1, s / (size * warp))); + } + return numRep; +} + +//===----------------------------------------------------------------------===// +// DotOperand Encoding +//===----------------------------------------------------------------------===// +SmallVector DotOperandEncodingAttr::getRepOrder() const { + if (auto mma = mlir::dyn_cast(getParent())) { + return mma.getRepOrderForOperand(getOpIdx()); + } + llvm::report_fatal_error( + "getRepOrder not implemented for DotOperandEncodingAttr"); + return {}; +} + +SmallVector DotOperandEncodingAttr::getThreadsPerWarp() const { + if (auto mma = mlir::dyn_cast(getParent())) { + return mma.getThreadsPerWarpForOperand(getOpIdx()); + } + llvm::report_fatal_error( + "getThreadsPerWarp not implemented for DotOperandEncodingAttr"); + return {}; +} + +//===----------------------------------------------------------------------===// +// ASM Interface (i.e.: alias) +//===----------------------------------------------------------------------===// + +class TritonGPUOpAsmInterface : public OpAsmDialectInterface { +public: + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + // Encoding attributes + if (auto mmaAttr = mlir::dyn_cast(attr)) { + os << "mma"; + return AliasResult::FinalAlias; + } else if (auto sharedAttr = mlir::dyn_cast(attr)) { + os << "shared"; + return AliasResult::FinalAlias; + } else if (auto blockedAttr = mlir::dyn_cast(attr)) { + os << "blocked"; + return AliasResult::FinalAlias; + } else if (auto linearAttr = mlir::dyn_cast(attr)) { + os << "linear"; + return AliasResult::FinalAlias; + } /* else if (auto sliceAttr = dyn_cast(attr)) { + os << "slice"; + return AliasResult::FinalAlias; + } */ + // Memory space attributes + if (auto smem = mlir::dyn_cast(attr)) { + os << "smem"; + return AliasResult::FinalAlias; + } + return OpAsmDialectInterface::getAlias(attr, os); + } +}; + +struct TritonGPUInferLayoutInterface + : public triton::DialectInferLayoutInterface { + using DialectInferLayoutInterface::DialectInferLayoutInterface; + + LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding) const override { + resultEncoding = + SliceEncodingAttr::get(getDialect()->getContext(), axis, + cast(operandEncoding)); + return success(); + } + + // Infer the encoding of a tt.trans(x) given the encoding of x. + // + // Our goal is to choose an encoding so that the trans is a "nop". For + // example, in a blocked encoding, the same GPU threads hold the same + // elements, they're just "renamed" -- what was element [i,j] of the tensor is + // now element [j,i], but that element is held by the same GPU thread. + // + // For most properties of the encoding, we let + // outputEnc.prop = inputEnc.prop * trans.order, + // where `x * y` means we apply permutation y to x. + // + // This works because prop[i] tells you something about the i'th dimension of + // the tensor. (For example, sizePerThread[2] == 4 means that one GPU thread + // contains 4 elements along dim 2 of the tensor.) The transpose reorders the + // dimensions according to the perm trans.order, so we achieve our goal of + // having a "nop" transpose by reordering the values in the prop the same way. + // + // The big exception to this is the encoding's `order`. + // + // An encoding's order is a list of dimensions, from fastest moving (most + // minor) to slowest moving. Thus enc.order[i] does not tell you something + // about the i'th dimension of the tensor, and it would be disasterously + // incorrect to do enc.order * trans.order. + // + // But! If we invert enc.order, it *does* meet this criterion. For example, + // if enc.order = [2,0,1], inverse(enc.order) = [1,2,0]. If you stare at it, + // you'll see that inverse(enc.order)[i] == j means that dimension i is the + // j'th most minor. Therefore we can safely permute *this* by trans.order. + // + // Thus we have + // + // outputEnc.order = inverse(inverse(inputEnc.order) * trans.order) + // = inverse(trans.order) * inputEnc.order. + // + LogicalResult inferTransOpEncoding(Attribute operandEncoding, + ArrayRef shape, + ArrayRef order, // trans order + Attribute &resultEncoding) const override { + // Note: inferFooOpEncoding should not crash if given invalid inputs, which + // happens when someone creates invalid IR. If we return failure() on + // error, then MLIR will generate a helpful error message. + + auto *ctx = getDialect()->getContext(); + auto invOrder = inversePermutation(order); + SmallVector invOrderUnsigned(invOrder.begin(), invOrder.end()); + + auto permuteCTALayout = + [&](const CTALayoutAttr &layout) -> FailureOr { + auto n = order.size(); + if (layout.getCTAsPerCGA().size() != n || + layout.getCTASplitNum().size() != n || + layout.getCTAOrder().size() != n) { + return failure(); + } + + return CTALayoutAttr::get( + ctx, applyPermutation(layout.getCTAsPerCGA(), order), + applyPermutation(layout.getCTASplitNum(), order), + applyPermutation(invOrderUnsigned, layout.getCTAOrder())); + }; + + if (auto enc = + mlir::dyn_cast(operandEncoding)) { + if (enc.getOrder().size() != order.size()) { + return failure(); + } + FailureOr ctaLayout = permuteCTALayout(enc.getCTALayout()); + if (failed(ctaLayout)) { + return failure(); + } + resultEncoding = SwizzledSharedEncodingAttr::get( + ctx, enc.getVec(), enc.getPerPhase(), enc.getMaxPhase(), + applyPermutation(invOrderUnsigned, enc.getOrder()), *ctaLayout); + return success(); + } + + if (auto enc = mlir::dyn_cast(operandEncoding)) { + if (order != ArrayRef({1, 0})) { + return failure(); + } + FailureOr ctaLayout = permuteCTALayout(enc.getCTALayout()); + if (failed(ctaLayout)) { + return failure(); + } + resultEncoding = NVMMASharedEncodingAttr::get( + ctx, enc.getSwizzlingByteWidth(), !enc.getTransposed(), + enc.getElementBitWidth(), enc.getFp4Padded(), *ctaLayout); + return success(); + } + + if (auto enc = mlir::dyn_cast(operandEncoding)) { + auto n = order.size(); + if (enc.getSizePerThread().size() != n || + enc.getThreadsPerWarp().size() != n || + enc.getWarpsPerCTA().size() != n || enc.getOrder().size() != n) { + return failure(); + } + FailureOr ctaLayout = permuteCTALayout(enc.getCTALayout()); + if (failed(ctaLayout)) { + return failure(); + } + resultEncoding = BlockedEncodingAttr::get( + ctx, applyPermutation(enc.getSizePerThread(), order), + applyPermutation(enc.getThreadsPerWarp(), order), + applyPermutation(enc.getWarpsPerCTA(), order), + applyPermutation(invOrderUnsigned, enc.getOrder()), *ctaLayout); + return success(); + } + auto ll = toLinearLayout(shape, operandEncoding); + auto namedBases = ll.getBases(); + for (auto &bases : llvm::make_second_range(namedBases)) { + for (auto &b : bases) { + std::vector newB; + for (auto i : order) { + newB.push_back(b[i]); + } + b = std::move(newB); + } + } + auto retLl = LinearLayout(std::move(namedBases), + llvm::to_vector(ll.getOutDimNames())); + resultEncoding = LinearEncodingAttr::get(ctx, std::move(retLl)); + return success(); + } + + LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const override { + auto sliceEncoding = mlir::dyn_cast(operandEncoding); + if (!sliceEncoding) + return emitOptionalError( + location, "ExpandDimsOp operand encoding must be SliceEncodingAttr"); + if (sliceEncoding.getDim() != axis) + return emitOptionalError( + location, "Incompatible slice dimension for ExpandDimsOp operand"); + resultEncoding = sliceEncoding.getParent(); + return success(); + } + + LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional location) const override { + auto mmaRetEncoding = mlir::dyn_cast(retEncoding); + if (mmaRetEncoding && mmaRetEncoding.isHopper()) { + auto dotOpEnc = mlir::dyn_cast(operandEncoding); + if (!mlir::isa(operandEncoding) && + !(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 && + mlir::isa(dotOpEnc.getParent()))) { + return emitOptionalError( + location, "unexpected operand layout for NvidiaMmaEncodingAttr v3"); + } + } else if (auto dotOpEnc = + mlir::dyn_cast(operandEncoding)) { + if (opIdx != dotOpEnc.getOpIdx()) + return emitOptionalError(location, "Wrong opIdx"); + if (retEncoding != dotOpEnc.getParent()) + return emitOptionalError(location, "Incompatible parent encoding"); + } else + return emitOptionalError( + location, "Dot's a/b's encoding should be of DotOperandEncodingAttr"); + return success(); + } + + LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const override { + auto aEncoding = + mlir::dyn_cast(operandEncodingA); + auto bEncoding = + mlir::dyn_cast(operandEncodingB); + if (!aEncoding && !bEncoding) + return mlir::success(); + auto mmaAEncoding = + mlir::dyn_cast_or_null(aEncoding.getParent()); + if (mmaAEncoding && mmaAEncoding.isHopper()) + return success(); + // Verify that the encodings are valid. + if (!aEncoding || !bEncoding) + return op->emitError("mismatching encoding between A and B operands"); + if (aEncoding.getKWidth() != bEncoding.getKWidth()) + return op->emitError("mismatching kWidth between A and B operands"); + return success(); + } + + // Given a src shape + encoding and a dst shape, our goal is to compute a dst + // encoding that makes the reshape a "nop". That is, if GPU thread [x,y,z] + // contains elements [a,b,c,d] before the reshape, it contains those same + // elements after the reshape, they're just "renamed". + // + // Using legacy layouts, a dst encoding that satisfies this property may not + // exist. Here are some positive and negative examples. + // + // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so + // dim 1 is the fastest-changing in the dst, but the src has the opposite + // order. + // - OK: 2x2x32 order=[1,0,2] -> 4x32. We choose dst order [0,1]. + // What's important is that the 2x2 dimensions appear in major-to-minor + // order. + // - NOT OK: 32x32 sizePerThread=[2,2] -> 1024. Thread 0 in the src + // contains elements [(0,0), (0,1), (1,0), and (1,1)]. We cannot express + // this with an encoding based on the dst shape. + // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will + // contain the same elements as before. + // + // With linear layouts, we can always find a dst encoding that satisfies + // this property. See inferReshapeOpEncoding. + // + // Users of this function require that it is symmetrical: if + // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) => + // srcEnc. + LogicalResult inferReshapeOpLegacyEncoding(ArrayRef srcShape, + Attribute srcEnc, + ArrayRef dstShape, + Attribute &dstEnc) const { + auto src = mlir::dyn_cast(srcEnc); + if (!src) { + return failure(); + } + + // Nop reshape; we can always infer an encoding. + if (srcShape == dstShape) { + dstEnc = srcEnc; + return success(); + } + + // default -> default encoding is always a nop. + auto context = srcEnc.getContext(); + int32_t numWarps = product(src.getWarpsPerCTA()); + int32_t threadsPerWarp = product(src.getThreadsPerWarp()); + int32_t numCTAs = product(src.getCTALayout().getCTAsPerCGA()); + if (srcEnc == getDefaultBlockedEncoding(context, srcShape, numWarps, + threadsPerWarp, numCTAs)) { + dstEnc = getDefaultBlockedEncoding(context, dstShape, numWarps, + threadsPerWarp, numCTAs); + return success(); + } + + // Feature flag to disable this routine while it's relatively new. + // TODO(jlebar): Remove this once we're confident in the code. + if (triton::tools::getBoolEnv( + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE")) { + return failure(); + } + + // Cowardly refuse to handle encodings with multiple CTAs. CTAsPerCGA + // should be like the other fields in blocked encoding, but I'm not sure how + // to handle CTASplitNum. + if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) || + !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) { + return failure(); + } + + // Cowardly refuse to handle encodings where shape[dim] is not divisible by + // sizePerThread[dim], threadsPerWarp[dim], and warpsPerCTA[dim]. (We make + // an exception if the block is larger than the shape.) + auto checkDivisibility = [&](StringRef name, ArrayRef subblock) { + for (int dim = 0; dim < srcShape.size(); dim++) { + if (srcShape[dim] >= subblock[dim] && + srcShape[dim] % subblock[dim] != 0) { + return failure(); + } + } + return success(); + }; + if (!succeeded( + checkDivisibility("sizePerThread", src.getSizePerThread())) || + !succeeded( + checkDivisibility("threadsPerWarp", src.getThreadsPerWarp())) || + !succeeded(checkDivisibility("warpsPerCTA", src.getWarpsPerCTA()))) { + return failure(); + } + + SmallVector, SmallVector>> decomp = + getReshapeDecomposition(srcShape, dstShape); + + // enc.order[i] == j means that dimension j is the enc.order[i]'th most + // minor. But what we usually want is the inverse: inverse(enc.order)[i] = j + // means that dimension i is the j'th most minor (larger means more major). + auto srcInvOrder = inversePermutation(src.getOrder()); + + // If src dims [a,b,c] are to be merged, then they must be consecutive in + // physical order, with `a` being the most major. + for (const auto &[srcDims, dstDims] : decomp) { + if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) { + return failure(); + } + } + + // If src dims [a,b,c] are to be merged, then `c` must fill up sizePerThread + // / threadsPerWarp / blocksPerCTA before `b` can have any non-1 values. + // Examples: + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,2,2]. + // The total sizePerThread for dim 2 is 2, which is less than dim 2's + // size of 4. Therefore dim 1 cannot have non-1 sizePerThread. + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4]. + // Dim 2's sizePerThread covers its whole size, so dim 1 is allowed to + // have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[2,1,4]. + // Dim 1's sizePerThread does not cover its whole size, so dim 0 is not + // allowed to have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,1,2], + // threadsPerWarp=[1,2,1]. + // Dim 2 has 2 elems per thread and 1 thread per warp. 2*1 is less than + // dim 2's size. Therefore dim 1 must have threadsPerWarp=1. + // + // In addition, the encoding's block can be larger than the shape, but only + // in the most-major dimension of each decomposed chunk, and only after + // we've "used up" the more minor dims. Examples: + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4], threadsPerWarp=[16,2,1], + // warpsPerCTA=[4,1,1]. + // The whole size of dims 0 and 1 are covered by sizePerThread * + // threadsPerWarp. Therefore dim 2 is allowed to have threadsPerWarp and + // warpsPerCTA larger than its size. + for (const auto &[srcDims, dstDims] : decomp) { + auto shapeRemaining = gather(srcShape, srcDims); + auto checkSubblock = [&, srcDims = srcDims](ArrayRef subblock) { + // Iterate minor-to-major (i==0 is most major). + for (int i = srcDims.size() - 1; i >= 0; i--) { + int dim = srcDims[i]; + if (subblock[dim] == 1) { + continue; + } + + // Check that more-minor dims all have 1 in shapeRemaining. + for (int j = i + 1; j < srcDims.size(); j++) { + if (shapeRemaining[j] != 1) { + return failure(); + } + } + + if (shapeRemaining[i] >= subblock[dim]) { + assert(shapeRemaining[i] % subblock[dim] == 0); // checked earlier + shapeRemaining[i] /= subblock[dim]; + } else { + shapeRemaining[i] = 0; + } + + // Is the block larger than the shape in this dimension? This is OK + // only if we're the most-major dimension of the chunk and in all + // future chunks, only this most-major dim has a non-1 size. + if (shapeRemaining[i] == 0 && i != 0) { + return failure(); + } + } + return success(); + }; + if (!succeeded(checkSubblock(src.getSizePerThread())) || + !succeeded(checkSubblock(src.getThreadsPerWarp())) || + !succeeded(checkSubblock(src.getWarpsPerCTA()))) { + return failure(); + } + } + + // Given e.g. src.getSizePerThread(), computeSubblockSize computes e.g. + // dst.getSizePerThread(). This should be called for each of sizePerThread, + // threadsPerWarp, and warpsPerCTA, in that order. + SmallVector dstShapeRemaining(dstShape); + auto computeSubblockSize = [&](ArrayRef srcSubblock, + SmallVector &dstSubblock, + StringRef fieldName) -> LogicalResult { + // The dst subblock is "filled up" greedily starting with the most minor + // dim. When we're done, we are left with a smaller shape, of size + // dstShape / dstSubblock, which we store in dstShapeRemaining and use for + // the next call to computeSubblockSize. + dstSubblock.resize(dstShape.size()); + for (const auto &[srcDims, dstDims] : decomp) { + int64_t subblockRemaining = product(gather(srcSubblock, srcDims)); + for (int i = dstDims.size() - 1; i >= 0; i--) { + auto &val = dstSubblock[dstDims[i]]; + auto &shapeRemaining = dstShapeRemaining[dstDims[i]]; + val = std::min(subblockRemaining, shapeRemaining); + + assert(shapeRemaining % val == 0); // Checked earlier. + subblockRemaining /= val; + shapeRemaining /= val; + } + + // If there are any elems remaining in the subblock, it must be because + // the block is larger than the shape. This excess goes into the + // most-major dim of the subblock. + dstSubblock[dstDims[0]] *= subblockRemaining; + } + return success(); + }; + + SmallVector dstSizePerThread; + SmallVector dstThreadsPerWarp; + SmallVector dstWarpsPerCTA; + if (!succeeded(computeSubblockSize(src.getSizePerThread(), dstSizePerThread, + "sizePerThread")) || + !succeeded(computeSubblockSize(src.getThreadsPerWarp(), + dstThreadsPerWarp, "threadsPerWarp")) || + !succeeded(computeSubblockSize(src.getWarpsPerCTA(), dstWarpsPerCTA, + "warpsPerCTA"))) { + return failure(); + } + + // Since we know that each set of srcDims is consecutive, we can + // meaningfully sort decomp by the physical order of the src dimensions, + // major-to-minor. This will also be the order of the dst dimensions. + llvm::sort(decomp, [&](const auto &a, const auto &b) { + const auto &[srcDimsA, dstDimsA] = a; + const auto &[srcDimsB, dstDimsB] = b; + return srcInvOrder[srcDimsA.front()] < srcInvOrder[srcDimsB.front()]; + }); + + // Compute the dst order. Make the dimensions appear in the same order as + // their corresponding src dimensions. + SmallVector dstInvOrder(dstShape.size()); + int i = 0; + for (const auto &[srcDims, dstDims] : decomp) { + for (auto dim : reverse(dstDims)) { + dstInvOrder[dim] = i++; + } + } + auto dstOrder = inversePermutation(dstInvOrder); + + // CTALayout can be all 1's because we bailed on multi-CTA layouts above. + auto CTALayout = CTALayoutAttr::get( + src.getContext(), + /*CTAsPerCGA=*/SmallVector(dstShape.size(), 1), + /*CTASplitNum=*/SmallVector(dstShape.size(), 1), + /*CTAOrder=*/llvm::to_vector(llvm::seq(dstShape.size()))); + + dstEnc = BlockedEncodingAttr::get(src.getContext(), dstSizePerThread, + dstThreadsPerWarp, dstWarpsPerCTA, + dstOrder, CTALayout); + + return success(); + } + + LogicalResult + verifyLayoutsAreEqual(ArrayRef shape, Attribute expected, + Attribute got, + std::optional loc) const override { + if (expected == got) { + return success(); + } + if (!expected || !got) + return failure(); + // Check whether the encodings are structurally the same. + auto expectedLL = triton::gpu::toLinearLayout(shape, expected); + auto gotLL = triton::gpu::toLinearLayout(shape, got); + if (expectedLL != gotLL) { + return emitOptionalError(loc, "Expected result encoding ", expected, + " but was ", got); + } + return success(); + } + + LogicalResult + inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { + auto result = + inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc); + if (succeeded(result)) { + return result; + } + + // If the legacy encoding failed use LinearLayouts. + // Once LinearLayouts are more widely used, we can remove + // inferReshapeOpLegacyEncoding and simply use LLs. + auto *ctx = getContext(); + auto src = toLinearLayout(srcShape, srcEnc); + + if (product(srcShape) != product(dstShape)) { + return emitOptionalError(loc, "numel of dst shape does not match " + "numel of src shape"); + } + + auto newRank = dstShape.size(); + + auto newOutDims = standardOutDimPairs(ctx, dstShape); + + // reshapeOp assumes minor-to-major, so we need to transpose the out dims + // before the reshape + auto srcOutDims = to_vector(src.getOutDimNames()); + std::reverse(srcOutDims.begin(), srcOutDims.end()); + std::reverse(newOutDims.begin(), newOutDims.end()); + auto dst = src.transposeOuts(srcOutDims) + .reshapeOuts(newOutDims) + .transposeOuts(standardOutDimNames(ctx, newRank)); + dstEnc = LinearEncodingAttr::get(ctx, dst); + return success(); + } + + LogicalResult + inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const override { + if (auto enc = mlir::dyn_cast(srcEnc)) { + // JoinOp takes two tensors of shape AxBxC and generates a tensor of shape + // AxBxCx2. The encoding is the same as the input, but with 2 elems per + // thread in the new dimension. The new dimension is most-minor. + auto append = [](ArrayRef vals, int val) { + SmallVector ret(vals); + ret.push_back(val); + return ret; + }; + auto appendMinorDim = [](ArrayRef order) { + SmallVector ret(order); + ret.insert(ret.begin(), ret.size()); + return ret; + }; + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), // + append(enc.getSizePerThread(), 2), // + append(enc.getThreadsPerWarp(), 1), // + append(enc.getWarpsPerCTA(), 1), // + appendMinorDim(enc.getOrder()), // + CTALayoutAttr::get(enc.getContext(), // + append(enc.getCTAsPerCGA(), 1), + append(enc.getCTASplitNum(), 1), + appendMinorDim(enc.getCTAOrder()))); + return success(); + } + + auto ctx = getContext(); + + // Append dim to shape + auto ll = toLinearLayout(shape, srcEnc); + SmallVector dstShape(shape.begin(), shape.end()); + dstShape.push_back(1); + ll = ll.reshapeOuts(standardOutDimPairs(ctx, dstShape)); + + // Try join on last dim + auto axis = dstShape.size() - 1; + auto newLl = LinearLayout::empty(); + auto result = + tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/true, axis, loc); + + assert(result.succeeded()); + dstEnc = LinearEncodingAttr::get(ctx, newLl); + return success(); + } + + LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const override { + auto enc = mlir::dyn_cast(srcEnc); + if (enc) { + // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of + // shape AxBxC. The input must have 2 elements per thread in the last + // dimension, which must be most-minor. The result encoding is the same + // as the input, but with the last dimension removed. + if (enc.getSizePerThread().back() != 2) { + return emitOptionalError( + loc, "SplitOp requires 2 elements per thread in the " + "last dimension of the input"); + } + if (enc.getThreadsPerWarp().back() != 1 || + enc.getWarpsPerCTA().back() != 1 || enc.getCTAsPerCGA().back() != 1) { + return emitOptionalError( + loc, "SplitOp requires threadsPerWarp, warpsPerCTA, " + "and CTAsPerCGA = 1 for the last dimension of the input"); + } + if (enc.getCTALayout().getCTAsPerCGA().back() != 1) { + return emitOptionalError( + loc, + "SplitOp requires the last dimension to be most-minor in CTAOrder"); + } + SmallVector newOrder(enc.getOrder()); + int splitDim = newOrder.size() - 1; + // Remove splitDim from order. + newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim), + newOrder.end()); + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), // + ArrayRef(enc.getSizePerThread()).drop_back(1), + ArrayRef(enc.getThreadsPerWarp()).drop_back(1), + ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder), + CTALayoutAttr::get(enc.getContext(), // + ArrayRef(enc.getCTAsPerCGA()).drop_back(1), + ArrayRef(enc.getCTASplitNum()).drop_back(1), + ArrayRef(enc.getCTAOrder()).drop_front(1))); + return success(); + } + + auto axis = shape.size() - 1; + assert(shape[axis] == 2 && + "SplitOp input shape should have 2 in the last dim"); + + auto ctx = getContext(); + + // Split on last dim + auto ll = toLinearLayout(shape, srcEnc); + auto newLl = LinearLayout::empty(); + auto result = + tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/false, axis, loc); + if (!result.succeeded()) { + return failure(); + } + + // Remove last dim from newLl (which should be 1) + SmallVector dstShape(shape.begin(), shape.end()); + dstShape.pop_back(); + newLl = newLl.reshapeOuts(standardOutDimPairs(ctx, dstShape)); + dstEnc = LinearEncodingAttr::get(ctx, newLl); + return success(); + } + + LogicalResult + inferFp4ToFpOpEncoding(ArrayRef shape, int axis, Attribute inEnc, + Attribute &outEnc, bool fwdInference, + std::optional loc) const override { + // We implement two legacy layout propagations + // Once we fully migrate to LinearLayouts, we can remove these. + auto *ctx = getContext(); + auto rank = shape.size(); + // The output encoding will only be a legacy encoding if the axis is the + // fastest running dimension. + // FIXME: We should make sure that there are enough elements along the axis + // axis whenever fwdInference is false + if (cast(inEnc).getDefaultOrder()[axis] == 0) { + // Dot operand: double kWidth if kDim == axis. + if (auto dotEnc = mlir::dyn_cast(inEnc)) { + auto kWidth = dotEnc.getKWidth(); + if (fwdInference) { + kWidth *= 2; + } else { + if (kWidth > 1) { + // bwd inference + kWidth /= 2; + } else { + return emitOptionalError(loc, + "Fp4ToFpOp requires at least 2 elements " + "per thread in the axis dimension"); + } + } + outEnc = DotOperandEncodingAttr::get(ctx, dotEnc.getOpIdx(), + dotEnc.getParent(), kWidth); + return success(); + } + + // Blocked layout: double elemsPerThread[axis]. + if (auto blockedEnc = mlir::dyn_cast(inEnc)) { + auto sizePerThread = llvm::to_vector(blockedEnc.getSizePerThread()); + if (fwdInference) { + sizePerThread[axis] *= 2; + } else { + if (sizePerThread[axis] > 1) { + sizePerThread[axis] /= 2; + } else { + return emitOptionalError( + loc, "Fp4ToFpOp requires at least 2 elements per " + "thread in the axis dimension"); + } + } + outEnc = BlockedEncodingAttr::get( + ctx, sizePerThread, blockedEnc.getThreadsPerWarp(), + blockedEnc.getWarpsPerCTA(), blockedEnc.getOrder(), + blockedEnc.getCTALayout()); + return success(); + } + } + + auto ll = toLinearLayout(shape, inEnc); + auto newLl = LinearLayout::empty(); + auto result = tryJoinOnAxis(ctx, ll, newLl, fwdInference, axis, loc); + if (!result.succeeded()) + return result; + outEnc = LinearEncodingAttr::get(ctx, newLl); + return success(); + } +}; + +struct TritonGPUVerifyTensorLayoutInterface + : public triton::DialectVerifyTensorLayoutInterface { + using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface; + + LogicalResult verifyTensorLayout( + Attribute layout, RankedTensorType rankedTy, Operation *op, + function_ref makeErr) const override { + if (isa(layout)) + return makeErr() << "Shared layout is not allowed on tensor type."; + // TODO(jlebar): Currently this only checks blocked layouts, but other + // layouts also have invariants! + + // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr. + if (auto blocked = dyn_cast(layout)) { + ModuleOp module = op->getParentOfType(); + + // A different verifier should have checked that the layout itself is + // valid, including that threads-per-warp has the same rank as + // warps-per-block etc. + auto layoutRank = blocked.getThreadsPerWarp().size(); + if (layoutRank != rankedTy.getRank()) { + return makeErr() << layout << ".\nLayout has rank " << layoutRank + << ", but the tensor it's attached to has rank " + << rankedTy.getRank() << "."; + } + + int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module); + int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp()); + if (layoutThreadsPerWarp != moduleThreadsPerWarp) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutThreadsPerWarp + << " threads per warp, but the module specifies " + << moduleThreadsPerWarp << " threads per warp."; + } + + std::optional moduleWarpsPerCTA = maybeLookupNumWarps(op); + if (!moduleWarpsPerCTA) { + return makeErr() + << "Could not determine the number of warps per CTA. Operation " + "is not in a context with `ttg.num-warps`."; + } + int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA()); + if (layoutWarpsPerCTA != *moduleWarpsPerCTA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutWarpsPerCTA + << " warps per CTA, but the context requires " + << *moduleWarpsPerCTA << " warps per CTA."; + } + + if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) { + int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module); + int64_t layoutCTAsPerCGA = + product(blocked.getCTALayout().getCTAsPerCGA()); + if (layoutCTAsPerCGA != moduleCTAsPerCGA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutCTAsPerCGA + << " CTAs per CGA, but the module specifies " + << moduleCTAsPerCGA << " CTAs per CGA."; + } + } + } + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Layout debug printing +//===----------------------------------------------------------------------===// + +// Return N-D delinearized indices from a linear index. +static SmallVector delinearizeIndex(int64_t idx, + ArrayRef shape) { + SmallVector ret(shape.size()); + for (int i = shape.size() - 1; i >= 0; i--) { + ret[i] = idx % shape[i]; + idx /= shape[i]; + } + return ret; +} + +// Returns how many padding characters are needed for the string representation +// of value to be the same as max. +static int numCharacterPadding(int value, int max) { + return std::to_string(max).size() - std::to_string(value).size(); +} + +// return the string padded to have the same length as max. +static std::string paddedString(int value, int max) { + int nbChar = numCharacterPadding(value, max); + std::string str; + for (int i = 0; i < nbChar; i++) + str += " "; + str += std::to_string(value); + return str; +} + +std::string getSharedLayoutStr(RankedTensorType tensorType, + bool useHWPointOfView) { + auto layout = tensorType.getEncoding(); + if (!layout) + return ""; + + LinearLayout ll = triton::gpu::toLinearLayout(tensorType.getShape(), layout); + + StringAttr kOffset = StringAttr::get(tensorType.getContext(), "offset"); + StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block"); + int64_t tensorSize = product(tensorType.getShape()); + unsigned numBlocks = getNumCTAs(layout); + int32_t blockSize = tensorSize / numBlocks; + + // elementMapping is for the non-hw layout, offsetMapping for hw-layout + std::vector elementMapping(tensorSize); + std::vector offsetMapping; + + // Shared layouts are a mapping of (block, offset) --> (...) + + // We can just use a single int to index into elementMapping because + // the 'swizzle' operation rearranges the indicies---and we want to keep it + // that way + int32_t idx = 0; + // Enumerate all the offsets for each block + for (int32_t block = 0; block < numBlocks; block++) { + for (int32_t offset = 0; offset < blockSize; offset++) { + SmallVector> inputs = { + {kBlock, block}, + {kOffset, offset}, + }; + + SmallVector> outputs = ll.apply(inputs); + + std::string sharedInfo = "("; + std::string &value = elementMapping[idx]; + + if (!value.empty()) + value += "|"; + + value += "("; + // We can build up both strings (for hw/non-hw layouts) concurrently + for (int i = 0; i < outputs.size(); i++) { + // Based on the formatting from LinearLayout::toString, the format for + // the hw layout is slightly different. HW layouts use "," vs ":". + if (i > 0) { + sharedInfo += ","; + value += ":"; + } + auto index = paddedString(outputs[i].second, tensorType.getDimSize(i)); + sharedInfo += index; + value += index; + } + value += ")"; + sharedInfo += ")"; + + offsetMapping.push_back(sharedInfo); + + idx++; + } + } + + std::string layoutStr; + + if (!useHWPointOfView) { + int rank = tensorType.getRank(); + bool newLine = true; + for (int i = 0; i < tensorSize; i++) { + auto indices = delinearizeIndex(i, tensorType.getShape()); + int numOpenBracket = 0; + for (int j = rank - 1; j >= 0; j--) { + if (indices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "["; + numOpenBracket++; + } + if (newLine) { + for (int j = 0; j < rank - numOpenBracket; j++) + layoutStr += " "; + newLine = false; + } + + layoutStr += elementMapping[i]; + auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape()); + for (int j = rank - 1; j >= 0; j--) { + if (nextIndices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "]"; + } + if (nextIndices.back() % tensorType.getShape().back() == 0) { + layoutStr += "\n"; + newLine = true; + } else { + layoutStr += ","; + } + } + } else { + // For the HW view here, print the (block, offset) --> (r,c) mapping + uint32_t idx = 0; + for (int32_t block = 0; block < numBlocks; block++) { + layoutStr += "Block: " + std::to_string(block) + ":\n"; + for (int32_t offset = 0; offset < (tensorSize / numBlocks); offset++) { + layoutStr += "Offset: " + std::to_string(offset) + " -> "; + layoutStr += offsetMapping[idx]; + layoutStr += "\n"; + idx++; + } + } + } + + return layoutStr; +} + +std::string getDistributedLayoutStr(RankedTensorType tensorType, + bool useHWPointOfView) { + auto layout = tensorType.getEncoding(); + if (!layout) + return ""; + + StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register"); + StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane"); + StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp"); + StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block"); + + LinearLayout ll = triton::gpu::toLinearLayout(tensorType.getShape(), layout); + int64_t tensorSize = product(tensorType.getShape()); + std::vector elementMapping(tensorSize); + std::vector threadMapping; + unsigned threadsPerWarp = ll.getInDimSize(kLane); + unsigned numWarpsPerCTA = ll.getInDimSize(kWarp); + unsigned numBlocks = ll.getInDimSize(kBlock); + int numElementsPerThreads = ll.getInDimSize(kRegister); + for (int blockId = 0; blockId < numBlocks; ++blockId) { + for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) { + for (int tid = 0; tid < threadsPerWarp; ++tid) { + for (int idx = 0; idx < numElementsPerThreads; ++idx) { + SmallVector> inputs = { + {kBlock, blockId}, + {kWarp, warpId}, + {kLane, tid}, + {kRegister, idx}}; + SmallVector> outputs = + ll.apply(inputs); + int32_t linearizedIdx = 0; + int stride = 1; + for (int i = outputs.size() - 1; i >= 0; i--) { + linearizedIdx += outputs[i].second * stride; + stride *= tensorType.getDimSize(i); + } + std::string &value = elementMapping[linearizedIdx]; + if (!value.empty()) + value += "|"; + int padding = numCharacterPadding(blockId, numBlocks) + + numCharacterPadding(tid + warpId * threadsPerWarp, + numWarpsPerCTA * threadsPerWarp) + + numCharacterPadding(idx, numElementsPerThreads); + for (int i = 0; i < padding; i++) + value += " "; + if (numBlocks > 1) + value += "B" + std::to_string(blockId) + ":"; + value += "T" + std::to_string(tid + warpId * threadsPerWarp) + ":" + + std::to_string(idx); + // Now also compute the thread mapping. + std::string threadInfo = "("; + for (int i = 0; i < outputs.size(); i++) { + if (i > 0) + threadInfo += ","; + threadInfo += + paddedString(outputs[i].second, tensorType.getDimSize(i)); + } + threadInfo += ")"; + threadMapping.push_back(threadInfo); + } + } + } + } + std::string layoutStr; + if (!useHWPointOfView) { + // Printing the threads containing each elements of the tensor. + int rank = tensorType.getRank(); + bool newLine = true; + for (int i = 0; i < tensorSize; i++) { + auto indices = delinearizeIndex(i, tensorType.getShape()); + int numOpenBracket = 0; + for (int j = rank - 1; j >= 0; j--) { + if (indices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "["; + numOpenBracket++; + } + if (newLine) { + for (int j = 0; j < rank - numOpenBracket; j++) + layoutStr += " "; + newLine = false; + } + + layoutStr += elementMapping[i]; + auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape()); + for (int j = rank - 1; j >= 0; j--) { + if (nextIndices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "]"; + } + if (nextIndices.back() % tensorType.getShape().back() == 0) { + layoutStr += "\n"; + newLine = true; + } else { + layoutStr += ", "; + } + } + } else { + // Printing the elements in each physical reg/warps/threads. + for (int blockId = 0; blockId < numBlocks; blockId++) { + if (numBlocks > 1) + layoutStr += "Block" + std::to_string(blockId) + ":\n"; + for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) { + layoutStr += "Warp" + std::to_string(warpId) + ":\n"; + for (int idx = 0; idx < numElementsPerThreads; ++idx) { + for (int tid = 0; tid < threadsPerWarp; ++tid) { + int linearizedIdx = + blockId * numWarpsPerCTA * threadsPerWarp * + numElementsPerThreads + + warpId * threadsPerWarp * numElementsPerThreads + + tid * numElementsPerThreads + idx; + layoutStr += threadMapping[linearizedIdx]; + if (tid < threadsPerWarp - 1) + layoutStr += ", "; + } + layoutStr += "\n"; + } + } + } + } + return layoutStr; +} + +template +llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch(llvm::ArrayRef s) { + auto rank = s.size(); + assert(rank == 2 || rank == 3); + if (rank == 3) + return llvm::SmallVector(s); + return {1, s[0], s[1]}; +} + +template llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch( + llvm::ArrayRef s); + +template llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch( + llvm::ArrayRef s); + +llvm::SmallVector +mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef o) { + int rank = o.size(); + assert(rank == 2 || rank == 3); + if (rank == 3) + return llvm::SmallVector(o); + llvm::SmallVector expanded(3, 0); + for (int i = 0; i < rank; ++i) + expanded[i] += o[i] + 1; + return expanded; +} + +std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, + bool useHWPointOfView) { + auto layout = tensorType.getEncoding(); + + // tensorType is needed later on (e.g., getDimSize(j)), so we still have to + // pass it as a param + if (mlir::isa(layout)) { + return getSharedLayoutStr(tensorType, useHWPointOfView); + } else if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return getDistributedLayoutStr(tensorType, useHWPointOfView); + } + + // else unimplemented, return error + llvm::report_fatal_error("Unimplemented usage of getLayoutStr"); + return ""; +} + +void mlir::triton::gpu::dumpLayout(RankedTensorType tensorType) { + llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/false); +} + +void mlir::triton::gpu::dumpHWLayout(RankedTensorType tensorType) { + llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/true); +} + +namespace { +struct TensorModel + : public triton::gpu::TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getRank(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementTypeBitWidth(); + } +}; + +struct MemDescModel + : public triton::gpu::TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getShape().size(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementType().getIntOrFloatBitWidth(); + } +}; +} // namespace + +void TritonGPUDialect::initialize() { + registerTypes(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc" + >(); + addInterfaces(); + addInterfaces(); + addInterfaces(); + + RankedTensorType::attachInterface(*getContext()); + MemDescType::attachInterface(*getContext()); +} + +LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // Verify that dialect attributes are attached to the right ops. + if (llvm::is_contained( + {AttrNumCTAsName, AttrTargetName, AttrNumThreadsPerWarp}, + attr.getName()) && + !isa(op)) { + return op->emitOpError("has unexpected attribute ") + << attr.getName() << " which is expected only on `module` ops"; + } + if (attr.getName() == AttrNumWarpsName && !isa(op)) { + return op->emitOpError("has unexpected attribute ") + << attr.getName() + << " which is expected only on `module` or `tt.func` ops"; + } + + return success(); +} + +int TritonGPUDialect::getNumCTAs(ModuleOp module) { + if (auto attr = module->getAttrOfType(AttrNumCTAsName)) + return attr.getInt(); + return 1; +} + +int TritonGPUDialect::getThreadsPerWarp(ModuleOp module) { + if (auto attr = module->getAttrOfType(AttrNumThreadsPerWarp)) + return attr.getInt(); + return 32; +} + +std::optional triton::gpu::maybeLookupNumWarps(Operation *op) { + if (isa(op)) { + if (auto attr = op->getAttrOfType(AttrNumWarpsName)) + return attr.getInt(); + } else if (auto partitions = + dyn_cast(op->getParentOp())) { + unsigned idx = op->getParentRegion()->getRegionNumber(); + return partitions.getParentOp().getPartitionNumWarps()[idx]; + } + if (Operation *parent = op->getParentOp()) + return maybeLookupNumWarps(parent); + return {}; +} + +int triton::gpu::lookupNumWarps(Operation *op) { + std::optional numWarps = maybeLookupNumWarps(op); + if (!numWarps) { + op->emitOpError( + "is not contained within a context that specifies the number of warps"); + llvm::report_fatal_error("failed to lookup the number of warps, the " + "surrounding module should contain a " + + Twine(AttrNumWarpsName) + " attribute"); + } + return *numWarps; +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp new file mode 100644 index 000000000..5f7b8fe6f --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -0,0 +1,1559 @@ +#include + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" + +using mlir::triton::ScaleDotElemType; + +namespace mlir::triton::gpu { +namespace { + +// We use the following nomenclature in this file. +// +// - ctaLayout: A layout for one block, i.e. input dims [register, lane, warp] +// for register layouts, and input dims [offset] for shared layouts. +// - cgaLayout: Arrangement of multiple blocks, i.e. input dims [block]. +// +// Note that this is inconsistent with the type name CTALayoutAttr. That type +// is equivalent to our cgaLayout. +// +// IMO the name CTALayoutAttr is wrong. If we tried to be consistent anyway, +// then we'd have to rename ctaLayout to "warpLayout". I think that's more +// confusing than being inconsistent about "cgaLayout", especially when we have +// to consider the size of the warpLayout (surely that's not the "warpSize"). + +#define S(v) StringAttr::get(ctx, (v)) + +// TODO Have order be a mandatory argument of standardOutDimNames. +SmallVector permuteDimNames(const SmallVector &names, + const SmallVector &order) { + assert(names.size() == order.size()); + SmallVector ret; + for (unsigned i : order) { + ret.push_back(names[i]); + } + return ret; +} + +// Make a LinearLayout that maps a block-id to an N-dimensional index. +// +// The tensor is split up into CTAsPerCGA pieces, which are distributed among +// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups). +// +// See the nomenclature note at the top of the file for an explanation of why +// this is called makeCgaLayout when it accepts a CTALayoutAttr. +LinearLayout makeCgaLayout(CTALayoutAttr layout) { + MLIRContext *ctx = layout.getContext(); + StringAttr kBlock = S("block"); + + int rank = layout.getCTAOrder().size(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < rank; i++) { + // Start with the most minor dimension, which is order[0]. + int dim = layout.getCTAOrder()[i]; + int split = layout.getCTASplitNum()[dim]; + int ctas = layout.getCTAsPerCGA()[dim]; + assert(ctas % split == 0); + ret *= LinearLayout::identity1D(split, kBlock, outDimNames[dim]) * + LinearLayout::zeros1D(ctas / split, kBlock, outDimNames[dim]); + } + + // Transpose to standard order (dim0, dim1, ...). + return ret.transposeOuts(outDimNames); +} + +// Combines the layout of a CTA (input dims [register, lane, warp]) with the +// layout of a CGA (i.e. a block), and ensures that the resulting layout has the +// given shape. +// +// See the nomenclature note at the top of the file for why the variable with +// type CTALayoutAttr is called cgaLayoutAttr. +LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, + CTALayoutAttr cgaLayoutAttr, + ArrayRef shape) { + int rank = shape.size(); + assert(ctaLayout.getNumOutDims() == rank); + assert(cgaLayoutAttr.getCTAOrder().size() == rank); + MLIRContext *ctx = cgaLayoutAttr.getContext(); + + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + llvm::SmallDenseMap labeledShape; + for (auto [dim, size] : llvm::zip(outDimNames, shape)) { + labeledShape[dim] = size; + } + + LinearLayout cgaLayout = + ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + // Calculate the shape of the ctaLayout, which is `shape` divided by the + // cgaLayout's size. + llvm::SmallDenseMap ctaShape; + assert(llvm::to_vector(ctaLayout.getOutDimNames()) == + llvm::to_vector(cgaLayout.getOutDimNames())); + for (auto dim : ctaLayout.getOutDimNames()) { + ctaShape[dim] = + std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim)); + } + + ctaLayout = ensureLayoutNotSmallerThan(ctaLayout, ctaShape); + ctaLayout = ensureLayoutNotLargerThan(ctaLayout, ctaShape); + + LinearLayout ret = (ctaLayout * cgaLayout).transposeOuts(outDimNames); + for (auto dim : ret.getOutDimNames()) { + assert(ret.getOutDimSize(dim) == labeledShape[dim]); + } + return ret; +} + +LinearLayout +sharedToLinearLayoutNoLeadingOffset(ArrayRef shape, + SwizzledSharedEncodingAttr shared) { + MLIRContext *ctx = shared.getContext(); + int rank = shape.size(); + if (rank == 1) { + return combineCtaCgaWithShape( + LinearLayout::identity1D(shape[0], S("offset"), S("dim0")), + shared.getCTALayout(), shape); + } + + auto outDimNames = standardOutDimNames(ctx, rank); + + // Construct bases for the 2 most minor dimensions of the layout. These are + // the dims that get swizzled. + assert(shape.size() >= 2); + int colDim = shared.getOrder()[0]; + int rowDim = shared.getOrder()[1]; + int numCols = shape[colDim]; + int numRows = shape[rowDim]; + StringAttr colDimName = outDimNames[colDim]; + StringAttr rowDimName = outDimNames[rowDim]; + + std::vector> bases2D; + for (int logCol = 0; logCol < llvm::Log2_32(numCols); logCol++) { + bases2D.push_back({0, 1 << logCol}); + } + for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) { + int row = 1 << logRow; + int vec = shared.getVec(); + int perPhase = shared.getPerPhase(); + int maxPhase = shared.getMaxPhase(); + bases2D.push_back({row, (vec * ((row / perPhase) % maxPhase)) % numCols}); + } + LinearLayout ctaLayout = + LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); + + // Add the remaining dimensions. + for (int i = 2; i < rank; i++) { + int dim = shared.getOrder()[i]; + ctaLayout *= + LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); + } + + return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape); +} + +} // namespace + +LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, + NVMMASharedEncodingAttr shared, + bool disableSwizzle) { + MLIRContext *ctx = shared.getContext(); + int rank = shape.size(); + if (rank == 1) { + // TODO: Not sure if this is correct. + return combineCtaCgaWithShape( + LinearLayout::identity1D(shape[0], S("offset"), S("dim0")), + shared.getCTALayout(), shape); + } + int elemBitWidth = shared.getElementBitWidth(); + int tileWidthBytes = shared.getSwizzlingByteWidth(); + int vec = 128 / elemBitWidth; + int perPhase = 0; + int maxPhase = 0; + if (tileWidthBytes == 32) { + perPhase = 4; + maxPhase = 2; + } else if (tileWidthBytes == 64) { + perPhase = 2; + maxPhase = 4; + } else if (tileWidthBytes == 128) { + perPhase = 1; + maxPhase = 8; + } + auto outDimNames = standardOutDimNames(ctx, rank); + + // Construct bases for a the layout's 2-dimensional tile. + assert(rank >= 2); + int batchDims = rank - 2; + int colDim = batchDims + (shared.getTransposed() ? 0 : 1); + int rowDim = batchDims + (shared.getTransposed() ? 1 : 0); + + int tileRows = 8; + int tileCols = 8 * tileWidthBytes / elemBitWidth; + bool isFp4Padded = false; + if (auto sharedMMALayout = + dyn_cast(shared)) { + if (sharedMMALayout.getFp4Padded()) { + isFp4Padded = true; + } + } + int packingFactor = isFp4Padded ? 2 : 1; + + if (shape[colDim] * packingFactor < tileCols || shape[rowDim] < tileRows) { + llvm::errs() << "Illegal shared layout; expected shape to be at least [" + << tileRows << ", " << tileCols << "], shape: [" + << shape[rowDim] << ", " << shape[colDim] << "]\n"; + llvm::report_fatal_error("Illegal shared layout"); + } + + StringAttr colDimName = outDimNames[colDim]; + StringAttr rowDimName = outDimNames[rowDim]; + + std::vector> bases2D; + for (int logCol = 0; logCol < llvm::Log2_32(tileCols); logCol++) { + if (isFp4Padded) { + int colPadded = 1 << logCol; + // Each group of 16 offsets consists of 8 "real" and 8 "padded" offsets. + // We represent the padded layout by mapping 8 padded offsets to the same + // coordinates as the real ones. When computing the inverse of this LL, + // the offsets correspoding to the real ones are picked in the image by + // invertAndCompose. + int colPacked = colPadded / 16 * 8 + colPadded % 8; + bases2D.push_back({0, colPacked}); + } else { + bases2D.push_back({0, 1 << logCol}); + } + } + for (int logRow = 0; logRow < llvm::Log2_32(tileRows); logRow++) { + int row = 1 << logRow; + if (disableSwizzle) { + bases2D.push_back({row, 0}); + continue; + } + if (isFp4Padded) { + int colPadded = vec * ((row / perPhase) % maxPhase); + int colPacked = colPadded / 16 * 8 + colPadded % 8; + bases2D.push_back({row, colPacked}); + } else { + bases2D.push_back({row, vec * ((row / perPhase) % maxPhase)}); + } + } + LinearLayout tileLayout = + LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); + + // Add the remaining dimensions. + for (int dim = batchDims - 1; dim >= 0; --dim) { + tileLayout *= + LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); + } + + return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape); +} + +/// Function to generate lane and warp layout for dot operands. +static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx, + ArrayRef shape, + ArrayRef order, + unsigned kDim, + StringAttr inDimName) { + // Let warpsPerCTAMma = {2, 2}, then + // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB + // assume warpOrder = {1, 0} + // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that + // the C is owned as per the following layout: + // C: 0 | 1 + // - | - + // 2 | 3 + // In order to be able to compute C, we need the following warp tiling of + // A and B: + // A: 0 1 | 0 1 B: 0 2 | 1 3 + // - - | - - - - | - - + // 2 3 | 2 3 0 2 | 1 3 + // In other words, we need to broadcast along K + auto rank = shape.size(); + auto dimNames = standardOutDimNames(ctx, rank); + LinearLayout layout = LinearLayout::empty(); + + // We have to broadcast along the inner dimension + // For A, when moving along M we go from 0 to 2. + // For B, when moving along N we go from 0 to 1. + // As such, choosing the order of A {1, 0}, gives us the correct broadcasting + // Same happens if the warpOrder is {0, 1}, like in Hopper + for (auto d : order) { + if (d == kDim) { + layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]); + } else { + layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]); + } + } + return layout; +} + +LinearLayout +AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + int rank = shape.size(); + assert(rank == getWarpsPerCTA().size()); + + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + int nIndex = 1 + hasBatchDim; + (void)mIndex, (void)nIndex; + + assert(((getMDim() == 32 && getNDim() == 32) || + (getMDim() == 16 && getNDim() == 16)) && + "Unsupported mfma type"); + + MLIRContext *ctx = getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + + // https://github.com/ROCm/amd_matrix_instruction_calculator can print the + // register and lane layout for mfma instructions. + + // We use the order from fastest varying to slowest varying. So each base + // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices. + SmallVector order = getDefaultOrder(); + auto tileLayout = LinearLayout::empty(); + + if (getMDim() == 32) { + // For mfma with 32x32 output, each of the 64 threads holds 16 elements. + // + // For the register (i.e., element) dimension, these 16 elements are along + // the matrix C's M dimension, with 4 consecutive elements spanning 4 rows + // and then the next 4 rows being a gap. + // + // For the lane (i.e., thread) dimension, these threads are along the + // matrix C's N dimension, with 32 consecutive threads covering a whole + // row and the next 32 threads start after a gap spanning 4 rows. + tileLayout = LinearLayout( + {{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + // For mfma.transposed layout, the element ownership among threads are + // "transposed" within each warp. + if (getIsTransposed()) + tileLayout = LinearLayout( + {{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}}, + {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + } else { + assert(getMDim() == 16); + // For mfma with 16x16 output, each of the 64 threads holds 4 elements. + // + // For the register (i.e., element) dimension, these 4 elements are along + // the matrix C's M dimension, with 4 consecutive elements spanning 4 rows. + // + // For the lane (i.e., thread) dimension, these threads are along the + // matrix C's N dimension, with 16 consecutive threads covering a whole + // row and the next 16 threads start after a gap spanning 4 rows. + tileLayout = LinearLayout( + {{kRegister, {{0, 1}, {0, 2}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + // For mfma.transposed layout, the element ownership among threads are + // "transposed" within each warp. + if (getIsTransposed()) + tileLayout = LinearLayout( + {{kRegister, {{1, 0}, {2, 0}}}, + {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + } + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + // And each warp takes the same register and lane sub-layout. So multiply with + // an identity layout for the warp. + LinearLayout warpLayout = + identityStandardND(S("warp"), getWarpsPerCTA(), order); + LinearLayout ctaLayout = tileLayout * warpLayout; + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape, + int32_t elemBitWidth) { + auto mfmaLayout = llvm::cast(dotMfmaLayout.getParent()); + assert(mfmaLayout.getMDim() == 16 || mfmaLayout.getNDim() == 32); + assert(elemBitWidth == 16 || elemBitWidth == 8); + + auto rank = shape.size(); + bool hasBatchDim = rank == 3; + int32_t kWidthDot = dotMfmaLayout.getKWidth(); + // Number of bits loaded by an LDS read. ds_read_tr primarily supports 64-bit + // loads for most element sizes (16b, 8b, 4b). + const int32_t ldsReadWidth = 64; + int32_t kWidthTransRead = ldsReadWidth / elemBitWidth; + const int elemByteWidth = elemBitWidth / 8; + auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + + int32_t kSize = shape[kDim]; + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + + MLIRContext *ctx = dotMfmaLayout.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + // register order + // operand A: [1, 0] / [2, 1, 0] + // operand B: [0, 1] / [1, 2, 0] + // Regular dot mfma order for both cases is [k, nonk]/[k, nonk, batch] + // For LDS transpose layout swap order to [nonk, k]/[nonk, k, batch] + SmallVector order = dotMfmaLayout.getDefaultOrder(); + std::swap(order[0], order[1]); + + // For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes) + // of data. The smallest unit for transposition is a + // [non-K, K] = {16, kWidthTransRead} sub-tile of elements, + // where each thread reads kWidthTransRead elements along the non-K dimension. + // Due to the transposition mechanism, each thread ends up with + // kWidthTransRead elements along the K dimension. + // + // The MFMA selection logic prioritizes double-rate MFMA instructions whenever + // possible: + // + // - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k + // is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice. + // + // - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is + // selected; otherwise (blockK ≤ k), mfma32x32xk is used. + // + // NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA + // instructions are used. + // + // In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead + // elements along the K dimension: + // - The first kWidthTransRead elements belong to the first sub-tile. + // - The next kWidthTransRead elements belong to the second sub-tile. + // + // These elements are then grouped into larger tiles, each consisting of + // 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data + // for one MFMA instruction. The shape of these tiles depends on the MFMA + // instruction used. + // + // For single-rate MFMA instructions, each thread holds kWidthTransRead + // elements along the K dimension. This means that the larger tile + // (corresponding to one MFMA instruction) consists of 4 {16, kWidthTransRead} + // sub-tiles. + std::vector> registerBase; + std::vector> laneBase; + + // Populate register base for first subtile + for (int i = 1; i < kWidthTransRead; i *= 2) { + registerBase.push_back({i, 0}); + } + + const int threadsPerSubtileNonK = 16 / kWidthTransRead; + const int threadsPerSubtileK = kWidthTransRead; + + // Populate lane base for first subtile + for (int i = 1; i < threadsPerSubtileNonK; i *= 2) { + laneBase.push_back({i * kWidthTransRead, 0}); + } + for (int i = 1; i < threadsPerSubtileK; i *= 2) { + laneBase.push_back({0, i}); + } + + // Function to extend register base for multiple tiles K dim. + auto extendRegisterBaseForKDim = [&](int kTileSize) { + const int regsPerTile = kWidthTransRead * 2; // Two subtiles per tile + int totalRegs = (kSize / kTileSize) * regsPerTile; + + for (int reg = regsPerTile; reg < totalRegs; reg *= 2) { + registerBase.push_back({0, (reg / regsPerTile) * kTileSize}); + } + }; + + const bool isMfma32 = (mfmaLayout.getMDim() == 32); + const bool isMfma16 = (mfmaLayout.getMDim() == 16); + const int kTileSize = isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth; + const bool largeKSize = kSize >= kTileSize; + + // Extend register base for large K sizes. + if (largeKSize) { + registerBase.push_back({0, threadsPerSubtileK}); // Second subtile + extendRegisterBaseForKDim(kTileSize); + } + + // Extend lane base based on MFMA size. + const int numSubtilesPerTile = largeKSize ? 2 : 1; + std::vector> laneBaseExt; + + if (isMfma32) { + laneBaseExt = {{16, 0}, {0, numSubtilesPerTile * threadsPerSubtileK}}; + } else { + laneBaseExt = {{0, numSubtilesPerTile * threadsPerSubtileK}, + {0, 2 * numSubtilesPerTile * threadsPerSubtileK}}; + } + + laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end()); + + // Base vectors above are defined in a fixed order [non-k-dim, k-dim]. + // To assign them to actual matrix dimensions `order` array is used. + // For operand A: non-k-dim -> dim0, k-dim -> dim1 + // For operand B: non-k-dim -> dim1, k-dim -> dim0 + LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + // warp order + // common for both operand A and B: [0, 1] / [0, 1, 2] + // in both cases it is [M dim, N dim]/[batch, M dim, N dim] + SmallVector warpOrder = dotMfmaLayout.getDefaultWarpOrder(); + LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder); + + LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * + warpLayout.transposeOuts(outDimNames); + return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape); +} + +LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape) { + + // Current linear layout conversion for dot operand is only necessary to + // enable LDS bypass for operand B in the MFMA dot path. To achieve + // performance gains from bypassing LDS, the following conditions must be met: + // + // 1) opIdx == 1: Currently, only the B tensor (e.g. weights in moe-like + // kernels) bypasses LDS. This constraint is not strict and support for + // bypassing operand A (e.g. Q tensor in flash attention) will be added in + // the future. + // + // 2) B tensor must be column major: This is required to support vectorized + // global load instructions, as MFMA instructions expect threads to hold B + // operand elements along the K dimension. + // + // 3) kWidth == 8: Ensures maximum global load vectorization for fp16 + // operations. + // TODO: Generalize conversion to handle maximum kWidth for other types + // (i.e. fp8). + // + // 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is + // held by exactly one thread, maintaining the same number of global loads + // as in a blocked layout. + // + // Other use of Linear layout is a support of rare corner cases, + // for example one instruction tile is larger than tensor + auto mfmaLayout = llvm::cast(dotMfmaLayout.getParent()); + + auto rank = shape.size(); + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + + int32_t kWidth = dotMfmaLayout.getKWidth(); + auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + int32_t kSize = shape[kDim]; + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + + MLIRContext *ctx = dotMfmaLayout.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + // register order + // operand A: [1, 0] / [2, 1, 0] + // operand B: [0, 1] / [1, 2, 0] + // for both cases it is [k, nonk]/[k, nonk, batch] + SmallVector order = dotMfmaLayout.getDefaultOrder(); + // warp order + // common for both operand A and B: [0, 1] / [0, 1, 2] + // in both cases it is [M dim, N dim]/[batch, M dim, N dim] + SmallVector warpOrder = dotMfmaLayout.getDefaultWarpOrder(); + + // Lane holds kWidth consecutive elements along k dimension, so + // base register vectors for one tile are initialized in following way: + // {1, 0}, {2, 0} ... {kWidth/2, 0} + std::vector> registerBase; + for (int32_t elem = 1; elem < kWidth; elem *= 2) + registerBase.emplace_back(std::vector{elem, 0}); + + std::vector> laneBase; + int32_t kTileSize = -1; + + if (mfmaLayout.getMDim() == 32) { + // Canonical MFMA linear layout handles 4 consecutive elements along + // the register dimension. Dot operand handles variable kWidth consecutive + // elements. For lane dim, since the MFMA thread arrangement is {K, N} = {2, + // 32}, this means that mapping of first 5 base (up to thread 16) vectors + // will be an identity along N dim. Thread 32 will be mapped to element + // kWidth in K dimension. + laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}}; + kTileSize = kWidth * 2; + } else { + assert(mfmaLayout.getMDim() == 16); + // For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this + // means that mapping of first 4 base (up to thread 16) vectors will be an + // identity along N dim. Thread 16 will be mapped to element kWisth in K + // dimension. Thread 32 is mapped to element 2*kWidth in K dim. + laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}}; + kTileSize = kWidth * 4; + } + assert(kTileSize != -1); + // Add repeats of registers along K dimension to register base vectors + for (int32_t elem = kTileSize; elem < kSize; elem *= 2) + registerBase.emplace_back(std::vector{elem, 0}); + + // Base vectors above are defined in a fixed order [non-k-dim, k-dim]. + // To assign them to actual matrix dimensions `order` array is used. + // For operand A: non-k-dim -> dim0, k-dim -> dim1 + // For operand B: non-k-dim -> dim1, k-dim -> dim0 + LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder); + + LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * + warpLayout.transposeOuts(outDimNames); + + return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape); +} + +LinearLayout +AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + int rank = shape.size(); + assert(rank == getWarpsPerCTA().size()); + + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + int nIndex = 1 + hasBatchDim; + (void)mIndex, (void)nIndex; + + SmallVector mnkDim = getMNKDimPerInstr(); + unsigned mDim = mnkDim[0], nDim = mnkDim[1]; + (void)mDim, (void)nDim; + + assert(((shape[mIndex] == 1 || shape[mIndex] >= mDim) && + (shape[nIndex] == 1 || shape[nIndex] >= nDim)) && + "Unsupported tensor shape for given wmma layout"); + + MLIRContext *ctx = getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + + // https://github.com/ROCm/amd_matrix_instruction_calculator can print the + // register and lane layout for mfma instructions. + + // We use the order from fastest varying to slowest varying. So each base + // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices. + SmallVector threadOrder = getDefaultThreadOrder(); + assert(threadOrder[0] == mIndex || threadOrder[0] == nIndex); + assert(threadOrder[1] == mIndex || threadOrder[1] == nIndex); + + // For wmma with 16x16 output, each of the 32 threads holds 8 elements. + // + // The first version of WMMA layout has following specific: + // for the register (i.e., element) dimension, these 8 elements are + // along the matrix C's M dimension, with 1 consecutive elements + // spanning 1 row and then the next 1 row being a gap. + // + // For the lane (i.e., thread) dimension, these threads are along the + // matrix C's N dimension, with 16 consecutive threads covering a whole + // row and the next 16 threads start at the next row. + // + // The second version of wmma layout is less tricky: + // for the register dimension 8 elements are along the matrix C's M + // dimension. First 16 lanes take 0-8 elems along M, second 16 take 8-15. + // We have 16 pair of threads in each warp, one pair covers the whole + // column. + // + // Please also check explaining comments in TritonGPUAttrDefs.td at the + // AMDWmmaEncodingAttr section. + unsigned ver = getVersion(); + assert(ver == 1 || ver == 2); + LinearLayout tileLayout = + ver == 1 + ? LinearLayout( + {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}}, + {outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]}) + : LinearLayout( + {{kRegister, {{0, 1}, {0, 2}, {0, 4}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 8}}}}, + {outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]}); + + if (hasBatchDim) { + int batchIndex = 0; + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= + LinearLayout::identity1D(1, kRegister, outDimNames[batchIndex]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[batchIndex]); + } + + // And each warp takes the same register and lane sub-layout. So multiply with + // an identity layout for the warp. + auto warpOrder = getDefaultWarpOrder(); + LinearLayout warpLayout = + identityStandardND(S("warp"), getWarpsPerCTA(), warpOrder); + // reorder dim names in rep order, so combineCtaCgaWithShape generate proper + // extension of layout + auto repOrder = getRepOrder(); + SmallVector repDimNames; + for (auto dim : repOrder) + repDimNames.push_back(outDimNames[dim]); + LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) * + warpLayout.transposeOuts(repDimNames); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout, + ArrayRef shape) { + auto wmmaLayout = llvm::cast(dotWmmaLayout.getParent()); + auto rank = shape.size(); + bool hasBatchDim = rank == 3; + auto kDim = dotWmmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + int32_t kSize = shape[kDim]; + MLIRContext *ctx = dotWmmaLayout.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + // lane order + // operand A: [1, 0] / [2, 1, 0] + // operand B: [0, 1] / [1, 2, 0] + // for both cases it is [k, nonk]/[k, nonk, batch] + SmallVector laneOrder = dotWmmaLayout.getDefaultOrder(); + // generate continuous part of register bases(i.e. kWidth) + std::vector> registerBase; + const int32_t kWidth = dotWmmaLayout.getKWidth(); + for (int i = 1; i < kWidth; i *= 2) + registerBase.push_back(std::vector{i, 0}); + std::vector> laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}}; + switch (wmmaLayout.getVersion()) { + case 1: + // WMMA version 1 duplicates values in lanes 0-15 and 16-31 + laneBase.push_back({0, 0}); + break; + case 2: + // WMMA version 2 offset values in lanes 0-15 and 16-31 across k dimensions + laneBase.push_back({kWidth, 0}); + break; + default: + assert(false && "unexpected version"); + } + // Generate layout for one wmma instruction + LinearLayout tileLayout( + {{kRegister, registerBase}, {kLane, laneBase}}, + {outDimNames[laneOrder[0]], outDimNames[laneOrder[1]]}); + if (hasBatchDim) { + assert(laneOrder[2] == 0); + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= + LinearLayout::identity1D(1, kRegister, outDimNames[laneOrder[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[laneOrder[2]]); + } + + // Generate warp layout + auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + auto warpOrder = dotWmmaLayout.getDefaultWarpOrder(); + LinearLayout warpLayout = + broadcastedDotOperandLayout(ctx, warpsPerCTA, warpOrder, kDim, S("warp")); + + // reorder dim names in rep order, so combineCtaCgaWithShape generate proper + // extension of layout + auto repOrder = wmmaLayout.getRepOrderForOperand(dotWmmaLayout.getOpIdx()); + SmallVector repDimNames; + for (auto dim : repOrder) + repDimNames.push_back(outDimNames[dim]); + + // join instruction layout and warps using repetition order of dimensions + LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) * + warpLayout.transposeOuts(repDimNames); + + return combineCtaCgaWithShape(ctaLayout, wmmaLayout.getCTALayout(), shape); +} + +LinearLayout +BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { + assert(shape.size() == getDefaultOrder().size()); + MLIRContext *ctx = getContext(); + + const auto &order = getDefaultOrder(); + LinearLayout ctaLayout = + identityStandardND(S("register"), getSizePerThread(), order) * + identityStandardND(S("lane"), getThreadsPerWarp(), order) * + identityStandardND(S("warp"), getWarpsPerCTA(), order); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +LinearLayout fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout, + ArrayRef shape) { + int rank = shape.size(); + auto blocked = cast(operandLayout.getParent()); + MLIRContext *ctx = operandLayout.getContext(); + + // TODO: introduce registerOrder or use getDefaultOrder(operandLayout) + // Currently this order is used in legacy converter, because we do not + // have access to full dot operand layout, only parent part. + auto regOrder = blocked.getDefaultOrder(); + // TODO: use operandLayout.getDefaultThreadOrder() + auto threadOrder = blocked.getDefaultThreadOrder(); + auto warpOrder = blocked.getDefaultWarpOrder(); + auto repOrder = blocked.getRepOrder(); + + StringAttr kReg = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + auto threadSize = llvm::to_vector(blocked.getSizePerThread()); + auto kDimIdx = operandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + threadSize[kDimIdx] = shape[kDimIdx]; + auto threadShape = blocked.getThreadsPerWarp(); + auto warpShape = blocked.getWarpsPerCTA(); + + SmallVector repDimNames = + permuteDimNames(standardOutDimNames(ctx, rank), repOrder); + + auto registersLayout = identityStandardND(kReg, threadSize, regOrder); + auto lanesLayout = broadcastedDotOperandLayout(ctx, threadShape, threadOrder, + kDimIdx, kLane); + auto warpsLayout = + broadcastedDotOperandLayout(ctx, warpShape, warpOrder, kDimIdx, kWarp); + + LinearLayout ctaLayout = registersLayout.transposeOuts(repDimNames) * + lanesLayout.transposeOuts(repDimNames) * + warpsLayout.transposeOuts(repDimNames); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(operandLayout), shape); +} + +LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef tileShape, + unsigned kWidth, ArrayRef order, + ArrayRef repOrder) { + // Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder + // Like LinearLayout::empty() but with a rank and an order + int rank = repOrder.size(); + auto dimNames = standardOutDimNames(ctx, rank); + auto trivialShape = SmallVector(rank, 1); + LinearLayout ctaLayout = + identityStandardND(S("register"), trivialShape, repOrder); + + assert(rank >= 2); + auto inner = order[0]; + auto outer = order[1]; + + assert(tileShape.size() == rank); + int m = tileShape[outer]; + int n = tileShape[inner]; + + // The relative order of registers and lanes is given by: + // - Inner dim: kWidth registers + // - Inner dim: 4 lanes + // - Outer dim: 8 lanes + // - Outer dim: repeat m / 8 times + // - Inner dim: repeat n / (kWidth * 4) times + assert(m % 8 == 0); + assert(n % (kWidth * 4) == 0); + // There is at least one subtile on the inner-most dimension + // FIXME. We should implement operator* in terms of operator*= + // and chain *= instead of using * + auto outDimNames = llvm::to_vector(ctaLayout.getOutDimNames()); + ctaLayout = ctaLayout * + LinearLayout::identity1D(kWidth, S("register"), dimNames[inner]) * + LinearLayout::identity1D(4, S("lane"), dimNames[inner]) * + LinearLayout::identity1D(8, S("lane"), dimNames[outer]) * + LinearLayout::identity1D(m / 8, S("register"), dimNames[outer]) * + LinearLayout::identity1D(n / (kWidth * 4), S("register"), + dimNames[inner]); + return ctaLayout; +} + +LinearLayout +NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto ctx = getContext(); + int rank = shape.size(); + + SmallVector tileShape; + if (isAmpere()) { + // Ampere.getInstrShape() returns the tile shape + tileShape = SmallVector(getInstrShape()); + } else { + assert(isHopper()); + auto instrShapeMNK = getInstrShape(); + tileShape = SmallVector({instrShapeMNK[0], instrShapeMNK[1]}); + } + // nvidiamma layout always assumes kWidth = 2 + constexpr auto kWidth = 2; + auto ctaLayout = + nvidiaMmaTile(ctx, tileShape, kWidth, getDefaultOrder(), getRepOrder()); + + // The triton orders are defined on [dim0, dim1, ...], so we need to pass + // those dims Then, for some reason, operator* requires the orders to match + // so we need to reorder the outs to match + ctaLayout *= + identityStandardND(S("warp"), getWarpsPerCTA(), getDefaultWarpOrder()) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +LinearLayout nvidiaDotToLinearLayout(ArrayRef shape, + DotOperandEncodingAttr dot) { + int rank = shape.size(); + auto mma = cast(dot.getParent()); + int kWidth = dot.getKWidth(); + bool isA = dot.getOpIdx() == 0; + MLIRContext *ctx = mma.getContext(); + + SmallVector tileShape(rank, 1); + if (isA) { + tileShape[rank - 2] = 16; + tileShape[rank - 1] = kWidth * 8; + } else { + // Hopper takes the rhs via shared memory + assert(mma.isAmpere()); + tileShape[rank - 2] = kWidth * 8; + tileShape[rank - 1] = 8; + } + auto ctaLayout = nvidiaMmaTile(ctx, tileShape, kWidth, dot.getDefaultOrder(), + dot.getRepOrder()); + auto kDim = isA ? rank - 1 : rank - 2; + ctaLayout *= + broadcastedDotOperandLayout(ctx, mma.getWarpsPerCTA(), + mma.getDefaultWarpOrder(), kDim, S("warp")) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape); +} + +LinearLayout +DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto parent = getParent(); + if (auto blockedLayout = mlir::dyn_cast(parent)) { + return fmaDotToLinearLayout(*this, shape); + } else if (auto mfmaLayout = mlir::dyn_cast(parent)) { + return mfmaDotToLinearLayout(*this, shape); + } else if (auto wmmaLayout = mlir::dyn_cast(parent)) { + return wmmaDotOperandToLinearLayout(*this, shape); + } else { + auto mma = mlir::cast(parent); + return nvidiaDotToLinearLayout(shape, *this); + } +} + +LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { + MLIRContext *ctx = getContext(); + + // First compute the linear layout for this layout's parent. + SmallVector parentShape(shape); + parentShape.insert(parentShape.begin() + getDim(), 1); + LinearLayout parentLL = triton::gpu::toLinearLayout(parentShape, getParent()); + + // Remove dimension getDim() from the parent layout. + // + // 1. Construct a layout `transform` from parent-out-dims to slice-out-dims + // that removes the relevant out-dim. + // 2. Compute linearSlice = parent.compose(transform). Now linearSlice maps + // from parent in-dims to slice out-dims. + // 3. Fix up duplicate registers introduced by slicing. + auto outDimNames = standardOutDimNames(ctx, shape.size() + 1); + LinearLayout transform = LinearLayout::empty(); + for (auto [idx, outDim] : llvm::enumerate(parentLL.getOutDimNames())) { + if (idx == getDim()) { + // Because we're multiplying by all zeros, we could replace outDimNames[0] + // with any other valid out-dim; the layout will be the same. + transform *= LinearLayout::zeros1D(parentLL.getOutDimSize(outDim), outDim, + outDimNames[0]); + } else { + transform *= + LinearLayout::identity1D(parentLL.getOutDimSize(outDim), outDim, + outDimNames[idx - (idx < getDim() ? 0 : 1)]); + } + } + LinearLayout sliceLL = parentLL.compose(transform); + + // Step 3: Along the "register" dim, remove any all-zero bases. + auto bases = sliceLL.getBases(); + std::vector> newRegBases; + for (const auto &basis : bases[S("register")]) { + if (llvm::any_of(basis, [](int b) { return b != 0; })) { + newRegBases.push_back(basis); + } + } + bases[S("register")] = newRegBases; + + return LinearLayout(std::move(bases), + llvm::to_vector(sliceLL.getOutDimNames())); +} + +LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef shape, + Attribute layout) { + CacheKey key{std::vector(shape.begin(), shape.end()), layout}; + if (auto result = llCache.get(key)) { + return *result; + } + + // Layouts are distributed or shared in triton core + // To add a new layout add an else-if clause + LinearLayout result = LinearLayout::empty(); + if (auto distributed = dyn_cast(layout)) { + result = distributed.toLinearLayout(shape); + } else { + if (auto shared = dyn_cast(layout)) { + result = sharedToLinearLayoutNoLeadingOffset(shape, shared); + } else if (auto shared = dyn_cast(layout)) { + result = sharedToLinearLayoutLeadingOffset(shape, shared); + } else { + assert(0 && "unknown layout"); + } + } + + llCache.set(std::move(key), result); + return result; +} + +LinearLayout toLinearLayout(ArrayRef shape, Attribute layout) { + auto *ctx = layout.getContext(); + return ctx->getLoadedDialect()->toLinearLayout(shape, + layout); +} + +LinearLayout getLayoutWithinBlock(const LinearLayout &layout) { + assert(!layout.getInDimNames().empty()); + MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); + + StringAttr kBlock = S("block"); + assert(layout.hasInDim(kBlock)); + auto bases = layout.getBases(); + bases[kBlock] = {}; + return LinearLayout(bases, llvm::to_vector<4>(layout.getOutDimNames())); +} + +LinearLayout chooseShemLayoutForRegToRegConversion( + MLIRContext *ctx, ArrayRef tensorShape, + ArrayRef repShape, ArrayRef order) { + auto outDimNames = standardOutDimNames(ctx, tensorShape.size()); + LinearLayout layout = LinearLayout::empty(); + SmallVector kRepDims; + SmallVector kOffsetDims; + auto totalIters = 1; + auto totalOffsets = 1; + for (int i = 0; i < tensorShape.size(); i++) { + int dim = order[i]; + StringAttr kIteration = S("iteration" + std::to_string(dim)); + StringAttr kOffset = S("offset" + std::to_string(dim)); + kRepDims.push_back(kIteration); + kOffsetDims.push_back(kOffset); + assert(llvm::isPowerOf2_32(repShape[dim])); + assert(llvm::isPowerOf2_32(tensorShape[dim])); + auto numIters = tensorShape[dim] / repShape[dim]; + layout *= + LinearLayout::identity1D(repShape[dim], kOffset, outDimNames[dim]); + layout *= LinearLayout::identity1D(numIters, kIteration, outDimNames[dim]); + totalIters *= numIters; + totalOffsets *= repShape[dim]; + } + StringAttr kOffset = S("offset"); + StringAttr kIteration = S("iteration"); + StringAttr kBlock = S("block"); + SmallVector newDims; + newDims.append(kOffsetDims.begin(), kOffsetDims.end()); + newDims.append(kRepDims.begin(), kRepDims.end()); + // Transpose layout from [offset0, rep0, offset1, rep1, ...] to + // [offset0, offset1, ..., rep0, rep1, ...] + auto ret = layout.transposeIns(newDims); + // Reshape layout from [offset0, offset1, ..., rep0, rep1, ...] to + // [offset, rep, block] + return ret.reshapeIns( + {{kOffset, totalOffsets}, {kIteration, totalIters}, {kBlock, 1}}); +} + +namespace { +LinearLayout chooseStMatrixLayoutLeadingOffset(MLIRContext *ctx, + RankedTensorType tensorTy, + int swizzleByteSize) { + int perPhase; + int maxPhase; + if (swizzleByteSize == 32) { + perPhase = 4; + maxPhase = 2; + } else if (swizzleByteSize == 64) { + perPhase = 2; + maxPhase = 4; + } else if (swizzleByteSize == 128) { + perPhase = 1; + maxPhase = 8; + } else { + llvm::errs() << "Illegal swizzleByteSize: " << swizzleByteSize << "\n"; + llvm::report_fatal_error("Illegal swizzleByteSize"); + } + + // stmatrix only supports 16-bit elements, and each vector has 8 elements + int elemBitWidth = 16; + int vecSize = 8; + int numRowsPerTile = 16; + int numColsPerChunk = 8 * swizzleByteSize / elemBitWidth; + + // Construct a single stmatrix.x4 (16x16) tile + std::vector> basesReg = {{1, 0}, {2, 0}, {4, 0}}; + std::vector> basesLane; + for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile); logRow++) { + int row = 1 << logRow; + basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row}); + } + basesLane.push_back({8, 0}); + + auto mma = cast(tensorTy.getEncoding()); + assert(mma.getVersionMajor() >= 3 && "Only MMAv3 is supported"); + int instrM = mma.getInstrShape()[0]; + int instrN = mma.getInstrShape()[1]; + + // TODO(Keren): The following logic can be simplified by using the + // `divideLeft` function in `LinearLayout` once it's available. + // Construct the bases for a single chunk + // In theory the following situation is valid but it will be + // suboptimal. Swizzling should happen within a warp. + assert(instrN >= numColsPerChunk && + "Each chunk is filled in with a single warp"); + for (int logCol = 0; logCol < llvm::Log2_32(numColsPerChunk / 16); logCol++) { + int col = 1 << logCol; + basesReg.push_back({16 * col, 0}); + } + + // Construct the bases for warpsPerCTA[0] + std::vector> basesWarp; + auto warpsPerCTA = mma.getWarpsPerCTA(); + auto shape = tensorTy.getShape(); + for (int logWarp = 0; logWarp < llvm::Log2_32(warpsPerCTA[0]); logWarp++) { + int warp = 1 << logWarp; + basesWarp.push_back({0, warp * instrM}); + } + + // Expand the `register` dimension so the size of columns matches `shape[1] / + // warpsPerCTA[1]` + auto numColsPerWarp = std::max(instrN, shape[1] / warpsPerCTA[1]); + assert(warpsPerCTA[1] * instrN >= shape[1] && + "There must be enough columns to use MMAv3"); + auto logNumCols = llvm::Log2_32(numColsPerWarp / numColsPerChunk); + for (int logCol = 0; logCol < logNumCols; logCol++) { + int chunk = 1 << logCol; + int basis = chunk * shape[0]; + basesReg.push_back({0, basis}); + } + + // Expand the `register` dimension so that the size of rows matches `shape[0]` + assert(warpsPerCTA[0] * instrM <= shape[0] && + "There must be enough rows to use MMAv3"); + auto logNumRows = llvm::Log2_32(shape[0] / (warpsPerCTA[0] * instrM)); + for (int logRow = 0; logRow < logNumRows; logRow++) { + int chunk = 1 << logRow; + int basis = chunk * warpsPerCTA[0] * instrM; + basesReg.push_back({0, basis}); + } + + // Expand the `warp` dimension so that the size of cols matches `shape[1]` + for (int logWarp = 0; logWarp < llvm::Log2_32(warpsPerCTA[1]); logWarp++) { + int warp = 1 << logWarp; + if (warp * numColsPerWarp >= shape[1]) { + basesWarp.push_back({0, 0}); + } else { + int basis = (warp * numColsPerWarp) / numColsPerChunk * shape[0]; + basesWarp.push_back({0, basis}); + } + } + + auto layout = LinearLayout({{S("register"), basesReg}, + {S("lane"), basesLane}, + {S("warp"), basesWarp}, + {S("block"), {}}}, + {S("offset1"), S("offset0")}); + return layout.reshapeOuts( + {{S("offset"), layout.getTotalOutDimSize()}, {S("iteration"), 1}}); +} + +LinearLayout chooseStMatrixLayoutNoLeadingOffset(MLIRContext *ctx, + Attribute encoding, + ArrayRef shape) { + StringAttr kReg = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + StringAttr kCol = S("dim1"); + StringAttr kRow = S("dim0"); + StringAttr kBlock = S("block"); + + std::vector> basesReg = {{1, 0}, {2, 0}, {4, 0}}; + std::vector> basesLane = { + {0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}}; + LinearLayout layout = + LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow}); + + // Expand the `register` dimension so the size of columns matches `n`. + auto mma = cast(encoding); + int n = mma.getInstrShape()[1]; + layout *= + LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol); + + // Expand the `warp` dimension according to warpsPerCTA. + layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}) + .transposeOuts(llvm::to_vector(layout.getOutDimNames())); + auto ret = combineCtaCgaWithShape(layout, mma.getCTALayout(), shape); + auto tensorShapePerCTA = getShapePerCTA(mma, shape); + llvm::SmallDenseMap namedTensorShape; + namedTensorShape[kRow] = tensorShapePerCTA[0]; + namedTensorShape[kCol] = tensorShapePerCTA[1]; + ret = ensureLayoutNotSmallerThan(ret, namedTensorShape); + ret = ensureLayoutNotLargerThan(ret, namedTensorShape); + return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames())) + .reshapeOuts( + {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}}); +} + +LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot, + ArrayRef shape, bool needTrans, + int32_t elemBitWidth) { + auto ctx = dot.getContext(); + auto mma = cast(dot.getParent()); + auto rank = shape.size(); + auto opIdx = dot.getOpIdx(); + int kDim = (opIdx == 0) ? rank - 1 : rank - 2; + int nonKDim = (opIdx == 0) ? rank - 2 : rank - 1; + + StringAttr kReg = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + StringAttr kBlock = S("block"); + StringAttr kInner = opIdx == 0 ? (needTrans ? S("dim0") : S("dim1")) + : (needTrans ? S("dim1") : S("dim0")); + StringAttr kOuter = opIdx == 0 ? (needTrans ? S("dim1") : S("dim0")) + : (needTrans ? S("dim0") : S("dim1")); + + std::vector> basesReg; + for (int logReg = 0; logReg < llvm::Log2_32(8 * 16 / elemBitWidth); + logReg++) { + auto reg = 1 << logReg; + basesReg.push_back({0, reg}); + } + std::vector> basesLane = { + {1, 0}, {2, 0}, {4, 0}, {0, 0}, {0, 0}}; + bool kX2 = shape[kDim] > 8 * 16 / elemBitWidth; + bool kX4 = shape[kDim] > 16 * 16 / elemBitWidth; + bool nonKX2 = shape[nonKDim] > 8; + // Construct a tile consisting of 4 8x8x16bits sub-tiles to use ldmatrix + // efficiently. opIdx=0 and opIdx=1 are handled differently. + if (opIdx == 0) { + // The matrix elements of thread 0 are distributed in the following pattern + // (fp16): + // + // col0 col8 + // row0 reg[0-1] reg[4-5] + // row8 reg[2-3] reg[6-7] + if (needTrans) { + assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are " + "supported in the transposed mode"); + if (nonKX2) + basesLane[3] = {0, 8}; + if (kX2) + basesLane[4] = {8 * 16 / elemBitWidth, 0}; + } else { + if (nonKX2) + basesLane[3] = {8, 0}; + if (kX2) + basesLane[4] = {0, 8 * 16 / elemBitWidth}; + } + } else { + // The matrix elements of thread 0 are distributed in the following pattern + // (fp16): + // + // col0 col8 col16 col24 + // row0 reg[0-1] reg[2-3] reg[4-5] reg[6-7] + if (needTrans) { + assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are " + "supported in the transposed mode"); + if (kX2) + basesLane[3] = {8, 0}; + if (kX4) + basesLane[4] = {16, 0}; + } else { + if (kX2) + basesLane[3] = {0, 8 * 16 / elemBitWidth}; + if (kX4) + basesLane[4] = {0, 16 * 16 / elemBitWidth}; + } + } + int numTileCols = + (8 * 16 / elemBitWidth) + << (static_cast(kX2) + static_cast(kX4 && opIdx == 1)); + // Expand the `register` dimension so the size of columns matches `K`. + auto layout = + LinearLayout({{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}}, + {kOuter, kInner}) * + LinearLayout::identity1D(shape[kDim] / numTileCols, kReg, + S("dim" + std::to_string(kDim))); + // Expand the `warp` dimension according to warpsPerCTA. + auto warpsPerCTA = mma.getWarpsPerCTA(); + layout *= broadcastedDotOperandLayout(ctx, warpsPerCTA, + mma.getDefaultWarpOrder(), kDim, kWarp) + .transposeOuts(llvm::to_vector(layout.getOutDimNames())); + return combineCtaCgaWithShape(layout, getCTALayout(dot), shape); +} + +} // anonymous namespace + +LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, + int swizzleByteSize) { + if (swizzleByteSize == 0) + return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy.getEncoding(), + tensorTy.getShape()); + else + return chooseStMatrixLayoutLeadingOffset(ctx, tensorTy, swizzleByteSize); +} + +LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef shape, + bool needTrans, int32_t elemBitWidth) { + auto dot = cast(enc); + return chooseDotLdMatrixLayout(dot, shape, needTrans, elemBitWidth); +} + +LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef shape, + int32_t elemBitWidth) { + auto dot = cast(enc); + return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth); +} + +LinearLayout +chooseScaledMfmaOperandLayout(AMDMfmaEncodingAttr mfmaEnc, int kWidth, + int dotOperandIdx, ScaleDotElemType elemType, + llvm::ArrayRef dotOperandShape) { + MLIRContext *ctx = mfmaEnc.getContext(); + unsigned mDim = mfmaEnc.getMDim(); + if (elemType == ScaleDotElemType::E2M1) { + auto newEncoding = + DotOperandEncodingAttr::get(ctx, dotOperandIdx, mfmaEnc, kWidth / 2); + return newEncoding.toLinearLayout(dotOperandShape); + } + + // For mxfp8, each lane contains 32 elements, consisting of two blocks + // of 16 consecutive elements. There's a gap between these two blocks, + // which is not supported by normal dot layout. + assert(elemType == ScaleDotElemType::E4M3 || + elemType == ScaleDotElemType::E5M2); + using basisT = std::vector>; + unsigned rank = dotOperandShape.size(); + auto standardOutDims = standardOutDimNames(ctx, rank); + auto warpOrder = mfmaEnc.getDefaultWarpOrder(); + + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + + basisT regBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}}; + basisT laneBase = {{1, 0}, {2, 0}, {4, 0}, {8, 0}}; + int32_t kTileSize; + if (mDim == 16) { + regBase.emplace_back(std::vector{0, 64}); + laneBase.emplace_back(std::vector{0, 16}); + laneBase.emplace_back(std::vector{0, 32}); + kTileSize = kWidth * 4; + } else { + assert(mDim == 32); + regBase.emplace_back(std::vector{0, 32}); + laneBase.emplace_back(std::vector{16, 0}); + laneBase.emplace_back(std::vector{0, 16}); + kTileSize = kWidth * 2; + } + // Add repeats of registers along K dimension to register base vectors + int64_t kSize = dotOperandIdx == 0 ? dotOperandShape[1] : dotOperandShape[0]; + for (int32_t elem = kTileSize; elem < kSize; elem *= 2) { + regBase.emplace_back(std::vector{0, elem}); + } + + // Order of dimensionality changes on A/B operand, so here we need to reverse + // if it's operand B. + std::vector repOrder = {0, 1}; + if (dotOperandIdx == 1) { + std::reverse(repOrder.begin(), repOrder.end()); + } + + auto regLanes = LinearLayout( + {{kRegister, regBase}, {kLane, laneBase}}, + {standardOutDims[repOrder[0]], standardOutDims[repOrder[1]]}); + + auto warps = identityStandardND(kWarp, mfmaEnc.getWarpsPerCTA(), warpOrder); + + return combineCtaCgaWithShape(regLanes.transposeOuts(standardOutDims) * + warps.transposeOuts(standardOutDims), + mfmaEnc.getCTALayout(), dotOperandShape); +} + +LinearLayout chooseScaledMfmaScaleLayout( + MLIRContext *ctx, int dotOperandIdx, + const std::vector> &dotOperandWarpBasis, + ArrayRef dotOperandShape, unsigned mfmaMDim) { + using basisT = std::vector>; + unsigned rank = dotOperandShape.size(); + auto order = mlir::triton::gpu::getMatrixOrder(rank, /*rowMajor=*/true); + auto standardOutDims = standardOutDimNames(ctx, rank); + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + // Init register layout. Will be adjusted later + auto regs = mlir::triton::identityStandardND(kRegister, {1, 1}, order); + LinearLayout lanes = LinearLayout::empty(); + // In scaled dot, the shapes of operands(without batch dimension) are, + // respectively: + // - A: [M, K] + // - B: [K, N] + // - aScale: [M, K / 32] + // - bScale: [N, K / 32] + // + // To correctly feed A/B and its scale into instruction, we need to + // distribute aScale/bScale among warps in the same way as A/B. But bScale + // is not transposed like B. So we need to transpose the warp layout of + // bScale. + // + // The tricky part is, our desired outputs are [dim0, dim1], but + // at this position, the layouts are transposed to [dim1, dim0]. So + // instead of reverse bScale's layout, we need to reverse aScale's. There + // will be a transpose in the end to correct everything. + basisT warps = dotOperandWarpBasis; + if (dotOperandIdx == 0) { + for (auto &basis : warps) { + std::reverse(basis.begin(), basis.end()); + } + } + // In general, for both 32x32 and 16x16 scaled mfma, and no matter what + // data type the A/B operand is, each lane takes 32 elements from A/B + // alone K dim, and 1 or 2 elements from scale accordingly. The number of + // scale's elements in a lane varies because the 32 elements from A/B may + // not be consecutive. + // + // For mxfp4, these 32 elements are consecutive, so only 1 scale element + // is required. But for mxfp6/mxfp8, there are 2 16-consecutive elements + // blocks, so 2 scale elements are required. + if (mfmaMDim == 32) { + // For ROCDL::mfma_scale_f32_32x32x64_f8f6f4 with fp4 input, each lane + // takes 32 consecutive elements from A alone K dimension. The first + // 32 lanes collectively handle A[0:32][0:32], and the other 32 lanes + // collectively handle A[0:32][32:64]. Each lane take 1 scale element + // accordingly. Similar to B and bScale. + lanes = LinearLayout( + {{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}}}, + {kWarp, warps}, + {kBlock, {}}}, + {standardOutDims[order[0]], standardOutDims[order[1]]}); + } else { + assert(mfmaMDim == 16); + // For ROCDL::mfma_scale_f32_16x16x128_f8f6f4 with fp4 input, each lane + // takes 32 consecutive elements from A alone K dimension. The first + // 16 lanes collectively handle A[0:16][0:32], and another 16 lanes + // collectively handle A[0:16][32:64] and so on. Each lane take 1 scale + // element accordingly. Similar to B and bScale. + lanes = + LinearLayout({{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}}, + {kWarp, warps}, + {kBlock, {}}}, + {standardOutDims[order[0]], standardOutDims[order[1]]}); + } + LinearLayout newLL = regs * lanes; + + // Adjust register-level layout to fill the shape, at this level, both + // aScale and bScale should align with A operand. + SmallVector repOrder = {1, 0}; + for (auto d : repOrder) { + auto outDim = standardOutDims[d]; + auto dimSize = newLL.getOutDimSize(outDim); + newLL *= LinearLayout::identity1D(dotOperandShape[d] / dimSize, kRegister, + outDim); + } + newLL = newLL.transposeOuts(standardOutDims); + return newLL; +} + +LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType, + int numWarps) { + assert(numWarps == 4 || numWarps == 8); + MLIRContext *ctx = scaleType.getContext(); + + using basisT = std::vector>; + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + + int64_t M = scaleType.getDimSize(0); + int64_t N = scaleType.getDimSize(1); + auto CTALayout = getCTALayout(scaleType.getEncoding()); + basisT regBase; + + // Pick a layout that will be trivial to store into the following TMEM layout: + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x + // Pack 4 scales together, if there are less than 4 we replicate the data. + for (int i = 1; i < 4; i = i << 1) { + if (i >= N) + regBase.push_back({0, 0}); + else + regBase.push_back({0, i}); + } + // Distribute 32 elements of M along a warp. + basisT laneBase = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}}; + // The data are replicated across all the warps of each warpgroups. + basisT warpBase = {{0, 0}, {0, 0}}; + for (int i = 32; i < M; i = i << 1) { + regBase.push_back({i, 0}); + } + for (int i = 4; i < N; i = i << 1) { + regBase.push_back({0, i}); + } + // If we have 8 warps distribute the last dimension on the second warp group. + if (numWarps == 8) { + warpBase.push_back(regBase.back()); + regBase.pop_back(); + } + + SmallVector outDimNames = standardOutDimNames(ctx, 2); + auto regLanes = + LinearLayout({{kRegister, regBase}, {kLane, laneBase}, {kWarp, warpBase}}, + {outDimNames[0], outDimNames[1]}); + + return combineCtaCgaWithShape(regLanes, CTALayout, scaleType.getShape()); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/Ops.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/Ops.cpp new file mode 100644 index 000000000..0c87871d0 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -0,0 +1,809 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/DebugStringHelper.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" + +namespace mlir::triton::gpu { + +namespace { + +template bool hasEncoding(Value value) { + auto type = value.getType(); + if (auto tensorType = dyn_cast(type)) { + auto encoding = tensorType.getEncoding(); + return encoding && isa(encoding); + } + return false; +} + +bool hasDotOperandEncoding(Value value) { + return hasEncoding(value); +} + +bool isConvertTrivial(ConvertLayoutOp op) { + auto srcType = op.getSrc().getType(); + auto dstType = op.getType(); + auto srcEncoding = srcType.getEncoding(); + auto dstEncoding = dstType.getEncoding(); + return cast(&srcEncoding.getDialect()) + ->verifyLayoutsAreEqual(srcType.getShape(), srcEncoding, dstEncoding, {}) + .succeeded(); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Canonicalizer +//===----------------------------------------------------------------------===// + +// reshape(cvt) -> reshape +struct CanonicalizeConvertFromReshape + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + // If the layouts are structurally the same, the convert is trivial + if (isConvertTrivial(convert)) { + rewriter.replaceOpWithNewOp( + op, op.getType(), convert.getSrc(), op.getAllowReorder(), + op.getEfficientLayout()); + return success(); + } + + if (isExpensiveView(convert.getSrc().getType(), op.getType())) + return failure(); + if (!op.getAllowReorder()) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), convert.getSrc(), op.getAllowReorder(), + op.getEfficientLayout()); + return mlir::success(); + } +}; + +// TODO We should do this generically for op(cvt) -> op +// We have similar patterns for reshape and split... +// See https://github.com/triton-lang/triton/pull/5403#discussion_r1920091671 + +// trans(cvt) -> trans +struct CanonicalizeConvertFromTranspose + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::TransOp op, + PatternRewriter &rewriter) const override { + // transpose(x, order=[0, 1, ...]) -> x + // We turn it into a (trivial) convert_layout that may be folded away + if (isIota(op.getOrder())) { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getSrc()); + return success(); + } + + // If the layouts are structurally the same, the convert is trivial + auto convert = op.getSrc().getDefiningOp(); + if (!convert || !isConvertTrivial(convert)) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), convert.getSrc(), op.getOrder()); + return success(); + } +}; + +// histogram(cvt) -> histogram +struct CanonicalizeConvertFromHistogram + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::HistogramOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), convert.getSrc()); + return mlir::success(); + } +}; + +// If the gather does not have an optimized layout attached, then the source +// layout does not matter since the gather will be codegen'd by storing the +// source tensor into shared memory. Thus, we can fold conversions into the +// source operand. +// +// gather(cvt(src), idx) -> gather(src, idx) +struct CanonicalizeConvertFromGatherSource : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override { + // Don't do this if the compiler picked an optimized layout. + if (op.getEfficientLayout()) + return failure(); + + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + + rewriter.replaceOpWithNewOp(op, convert.getSrc(), op.getIndices(), + op.getAxis()); + return success(); + } +}; + +// alloc(cvt) -> alloc +struct CanonicalizeConvertFromAlloc + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, + PatternRewriter &baseRewriter) const override { + if (!op.getSrc()) + return failure(); + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), convert.getSrc()); + return mlir::success(); + } +}; + +// local_store(cvt) -> local_store +struct CanonicalizeConvertFromLocalStore + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, + PatternRewriter &baseRewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); + rewriter.replaceOpWithNewOp(op, convert.getSrc(), + op.getDst()); + return mlir::success(); + } +}; + +struct CanonicalizeConvertFromSplit + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::SplitOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + auto srcEncoding = convert.getSrc().getType().getEncoding(); + // Multiple source layout can give the same output layout, if the source + // layout of the convert gives the same destination layout we can skip the + // convert. + auto dstEncoding = inferDstEncoding(op, srcEncoding); + if (dstEncoding != op.getOutLHS().getType().getEncoding()) + return failure(); + rewriter.replaceOpWithNewOp(op, convert.getSrc()); + return mlir::success(); + } +}; + +struct CanonicalizeConvertFromConvert + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ConvertLayoutOp op, + PatternRewriter &rewriter) const override { + // Convert to the same layout is redundant. + if (op->getResultTypes() == op->getOperandTypes()) { + rewriter.replaceOp(op, op->getOperands()); + return success(); + } + + // We don't handle conversions to DotOperandEncodingAttr. This is a + // heuristic to accommodate fused attention. + auto srcType = op.getSrc().getType(); + auto dstType = op.getType(); + if (mlir::isa(dstType.getEncoding()) && + mlir::isa(srcType.getEncoding())) + return failure(); + + Operation *arg = op.getSrc().getDefiningOp(); + if (!arg) + return failure(); + + // cvt(reshape) -> reshape + if (auto reshape = dyn_cast(arg)) { + if (!reshape.getAllowReorder() || reshape.getEfficientLayout() || + isExpensiveView(reshape.getSrc().getType(), op.getType())) + return failure(); + + // In TritonGPUToLLVM phase, ViewOp is converted to unpacking and packing + // operations, which requires the element type to match between unpacking + // and packing. However, part of values with dot operand encoding will be + // packed/unpacked as i32 elements instead of the underlying element type. + // To avoid errors, skip this folding when either the operand or result + // of view has a dot operand encoding. + if (hasDotOperandEncoding(op->getOperand(0)) || + hasDotOperandEncoding(op->getResult(0))) + return failure(); + + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + reshape.getResult(), + reshape.getAllowReorder()); + return success(); + } + + // cvt(histogram) -> histogram + if (auto histogram = dyn_cast(arg)) { + // For histogram ops the input and output layouts are independent, so we + // can always fold convert into the histogram op. + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + histogram.getSrc()); + return success(); + } + + // cvt(local_load) -> local_load. + if (auto sharedLoad = dyn_cast(arg)) { + // Shared_load can load to any layout so we can always fold convert into + // it. + // We insert at the point of the original op as there could be ops with + // memory side-effects between the LocalLoad op and the ConvertLayout op + rewriter.setInsertionPoint(arg); + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + sharedLoad.getSrc()); + + return success(); + } + + // cvt(cat) -> cat + if (auto cat = dyn_cast(arg)) { + if (isExpensiveCat(cat, op.getType().getEncoding())) + return failure(); + + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + cat.getOperands()); + return success(); + } + + // cvt(cvt(x, type1), type2) -> cvt(x, type2) + if (auto cvt = dyn_cast(arg)) { + PatternRewriterWithAsyncTaskIds rewriterTask(rewriter, cvt); + rewriterTask.replaceOpWithNewOp( + op, op->getResultTypes().front(), cvt.getSrc()); + return success(); + } + + // cvt(type1, splat(type2, x)) -> splat(type1, x) + if (auto splat = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + splat.getSrc()); + return success(); + } + + // cvt(type1, make_range(type2, x)) -> make_range(type1, x) + if (auto range = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), range.getStart(), range.getEnd()); + return success(); + } + + // cvt(type, constant) -> constant + if (auto cst = llvm::dyn_cast(arg)) + if (auto ret = dyn_cast(cst.getValue())) { + auto ty = cast(op->getResultTypes().front()); + auto newRet = + SplatElementsAttr::get(ty, ret.getSplatValue()); + rewriter.replaceOpWithNewOp(op, newRet); + return success(); + } + return failure(); + } +}; + +void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); +} + +LogicalResult Fp4ToFpOp::verify() { + auto srcTy = cast(getSrc().getType()); + auto resTy = cast(getResult().getType()); + auto rank = srcTy.getRank(); + + if (rank != resTy.getRank()) + return emitError() << "source rank " << rank << " != result rank " + << resTy.getRank(); + + auto srcShape = srcTy.getShape(); + auto resShape = resTy.getShape(); + auto axis = getAxis(); + + if (!(0 <= axis && axis < rank)) + return emitError() << "axis " << axis << " out of range for rank " << rank; + + auto elemType = resTy.getElementType(); + if (!(elemType.isBF16() || elemType.isF16())) + return emitError() << "only bf16 or f16 is supported for now, got " + << elemType; + + for (int i = 0; i < rank; ++i) { + if (i == axis) { + if (resShape[i] != srcShape[i] * 2) + return emitError() << "axis " << axis + << " dimension must be 2x source dimension (src=" + << srcShape[i] << ", dst=" << resShape[i] << ")"; + } else { + if (resShape[i] != srcShape[i]) + return emitError() << "dimension " << i + << " mismatch (src=" << srcShape[i] + << ", dst=" << resShape[i] << ", axis=" << axis + << ")"; + } + } + return success(); +} + +void Fp4ToFpOp::build(OpBuilder &builder, OperationState &state, + TypedValue src, Type elemType, + int32_t axis) { + auto srcTy = src.getType(); + auto shape = llvm::to_vector(srcTy.getShape()); + auto rank = srcTy.getRank(); + assert(0 <= axis && axis < rank); + shape[axis] *= 2; + + Attribute inEnc = srcTy.getEncoding(); + Attribute outEnc; + auto result = + inEnc.getDialect() + .getRegisteredInterface() + ->inferFp4ToFpOpEncoding(shape, axis, inEnc, outEnc, + /*fwdInference=*/true, state.location); + assert(succeeded(result)); + + auto resultTy = RankedTensorType::get(shape, elemType, outEnc); + build(builder, state, resultTy, src, axis); +} + +OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + return getSrc(); + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + return {}; +} + +LogicalResult +MemDescTransOp::inferReturnTypes(MLIRContext *context, + std::optional location, + MemDescTransOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + + // type is the same as the input + auto argTy = cast(adaptor.getSrc().getType()); + auto shape = argTy.getShape(); + auto order = adaptor.getOrder(); + SmallVector retShape = applyPermutation(shape, order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (inferLayoutInterface + ->inferTransOpEncoding(argEncoding, shape, order, retEncoding) + .failed()) { + return failure(); + } + } + inferredReturnTypes.push_back( + MemDescType::get(retShape, retEltTy, retEncoding, argTy.getMemorySpace(), + argTy.getMutableMemory())); + return success(); +} + +// LocalAllocOp +void LocalAllocOp::getEffects( + SmallVectorImpl> + &effects) { + Operation *op = getOperation(); + // If allocation is immutable, mark it as no side effect allow things like + // CSE, DCE to work in early compiler passes. + // After the memory offset is computed, we attach the true side effect to the + // op. + if (!getType().getMutableMemory() && !op->hasAttr("allocation.offset")) + return; + effects.emplace_back(MemoryEffects::Allocate::get(), + mlir::triton::gpu::SharedMemory::get()); + if (getSrc()) + effects.emplace_back(MemoryEffects::Write::get(), + getOperation()->getOpResult(0), + mlir::triton::gpu::SharedMemory::get()); +} + +OpFoldResult LocalAllocOp::fold(FoldAdaptor adaptor) { + if (getType().getMutableMemory()) + return {}; + auto src = getSrc(); + if (!src) + return {}; + auto localLoadOp = src.getDefiningOp(); + if (!localLoadOp) + return {}; + auto loadSrc = localLoadOp.getSrc(); + if (loadSrc.getType() != getType()) + return {}; + return loadSrc; +} + +LogicalResult LocalAllocOp::verify() { + if (!getSrc()) { + if (!getType().getMutableMemory()) + return emitError("uninitialized alloc must have a mutable memdesc type"); + return success(); + } + auto srcTy = getSrc().getType(); + auto dstTy = getType(); + + if (srcTy.getElementType() != dstTy.getElementType()) { + return emitError("result element type must match desc element type"); + } + return success(); +} + +// LocalLoadOp +void LocalLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +// LocalStoreOp +LogicalResult LocalStoreOp::verify() { + if (!getDst().getType().getMutableMemory()) + return emitOpError("Cannot store into immutable memory"); + return success(); +} + +void LocalStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +// AsyncCopyGlobalToLocalOp +LogicalResult AsyncCopyGlobalToLocalOp::verify() { + if (!getResult().getType().getMutableMemory()) + return emitOpError("Cannot store into immutable memory"); + return success(); +} + +void AsyncCopyGlobalToLocalOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), + mlir::triton::GlobalMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getResultMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +LogicalResult MemDescSubviewOp::verify() { + auto srcTy = getSrc().getType(); + auto dstTy = getType(); + + if (srcTy.getElementType() != dstTy.getElementType()) { + return emitError("result element type must match desc element type"); + } + if (getOffsets().size() != srcTy.getRank()) { + return emitError("offsets must have the same rank as input"); + } + if (srcTy.getRank() < dstTy.getRank()) { + return emitError("result rank must be less than or equal to input rank"); + } + auto rankDiff = srcTy.getRank() - dstTy.getRank(); + for (int i = 0; i < dstTy.getRank(); i++) { + if (dstTy.getDimSize(i) > srcTy.getDimSize(i + rankDiff)) { + return emitError( + "result shape cannot be larger than input shape at dimension ") + << i; + } + } + + auto srcEnc = srcTy.getEncoding(); + auto dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("src and result must both have or not have an encoding"); + } + + if (!isa(srcEnc) && + !isa(srcEnc)) { + return emitError("src encoding must be SharedEncodingTrait"); + } + if (!isa(dstEnc) && + !isa(srcEnc)) { + return emitError("result encoding must be SharedEncodingTrait"); + } + + if (isa(srcEnc)) { + // We support only 3D -> 2D subviews with only first offset being non-zero. + if (srcTy.getRank() != 3 || dstTy.getRank() != 2) { + return emitError("only 3D -> 2D subviews are supported for " + "TensorMemoryEncodingAttr"); + } + for (int i = 1; i < srcTy.getRank(); i++) { + if (auto constOp = getOffsets()[i].getDefiningOp()) { + if (!isa(constOp.getValue()) || + cast(constOp.getValue()).getInt() != 0) { + return emitError("only first offset can be non-zero for the subview" + "of TensorMemoryEncodingAttr"); + } + } else { + return emitError( + "offsets other than the first one must be constant zeros"); + } + } + } + + // TODO(jlebar): Currently we generate illegal encodings, so we can't add a + // verifier for them. In particular, we use the same encoding for the src and + // dst of a subview op, when the subview removes a dimension. That generates + // an illegal shared encoding (because the size of `order` doesn't match the + // rank of the tensor), but it's not checked anywhere, and we believe the + // resulting code ultimately works. + + return success(); +} + +// -- LocalAllocOp -- + +int32_t LocalAllocOp::getAlignmentOrDefault() { + auto align = getAlignment(); + if (align) { + return *align; + } + + auto ty = getType(); + auto enc = dyn_cast(ty.getEncoding()); + return enc ? enc.getAlignment() : 16; +} + +// -- WarpSpecializeOp -- + +static Type removeEncodingIfTensor(Type type) { + if (auto tensorType = dyn_cast(type)) { + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType()); + } + return type; +} + +RegionRange WarpSpecializeOp::getPartitionRegions() { + return cast( + getPartitionOpHolder().front().front()) + .getPartitionRegions(); +} + +void WarpSpecializeOp::getSuccessorRegions( + RegionBranchPoint src, SmallVectorImpl &successors) { + // The parent branches transparently into the default region. + if (src.isParent()) { + successors.emplace_back(&getDefaultRegion()); + return; + } + // And the default region branches transparently back to the parent. + assert(src.getRegionOrNull() == &getDefaultRegion()); + successors.push_back(RegionSuccessor(getResults())); +} + +LogicalResult WarpSpecializeOp::verify() { + // The default region is not isolated from above but the partition regions + // have to be. MLIR does not support this, so we hide an op inside another + // region that contains the isolated regions. Check that it is there. + if (!isa( + getPartitionOpHolder().front().front())) { + return emitOpError( + "expected to find only a `ttg.warp_specialize.partitions` op inside " + "its second region"); + } + + // Verify the partitions. + if (getPartitionRegions().size() != getPartitionNumWarps().size()) { + return emitOpError("has ") << getPartitionRegions().size() + << " partitions but `partitionNumWarps` has " + << getPartitionNumWarps().size() << " elements"; + } + for (auto [i, numWarps] : llvm::enumerate(getPartitionNumWarps())) { + if (llvm::isPowerOf2_32(numWarps)) + continue; + return emitOpError("partition #") + << i << " number of warps (" << numWarps << ") must be a power of 2"; + } + if (std::optional> startIds = getWarpGroupStartIds()) { + if (startIds->size() != getPartitionNumWarps().size()) { + return emitOpError("has ") + << startIds->size() << " warp group start IDs but expected " + << getPartitionNumWarps().size(); + } + } + + for (auto [i, region] : llvm::enumerate(getPartitionRegions())) { + if (region->getNumArguments() != getNumOperands()) { + return emitOpError("partition region #") + << i << " has " << region->getNumArguments() + << " arguments but expected " << getNumOperands(); + } + for (auto [argIdx, argType, capType] : llvm::enumerate( + region->getArgumentTypes(), getExplicitCaptures().getTypes())) { + if (argType == capType) + continue; + return emitOpError("partition region #") + << i << " argument #" << argIdx << " has type " << argType + << " but corresponding capture has type " << capType; + } + } + + // This op cannot be nested inside itself. + if ((*this)->getParentOfType()) { + return emitOpError( + "cannot be nested inside another `ttg.warp_specialize` op"); + } + + return success(); +} + +ParseResult WarpSpecializeOp::parse(OpAsmParser &p, OperationState &result) { + SmallVector operands; + SMLoc operandLoc = p.getCurrentLocation(); + if (p.parseOperandList(operands, AsmParser::Delimiter::Paren) || + p.parseOptionalAttrDictWithKeyword(result.attributes) || + p.parseKeyword("default") || p.parseRegion(*result.addRegion())) + return failure(); + + OperationState partitionOpState( + p.getEncodedSourceLoc(p.getCurrentLocation()), + WarpSpecializePartitionsOp::getOperationName()); + + SmallVector partitionNumWarps; + SmallVector partitionArgs; + while (succeeded(p.parseOptionalKeyword( + ("partition" + Twine(partitionNumWarps.size()).str())))) { + partitionArgs.clear(); + SMLoc regionLoc = p.getCurrentLocation(); + if (p.parseArgumentList(partitionArgs, AsmParser::Delimiter::Paren, + /*allowType=*/true) || + p.parseKeyword("num_warps") || p.parseLParen() || + p.parseInteger(partitionNumWarps.emplace_back()) || p.parseRParen() || + p.parseRegion(*partitionOpState.addRegion(), partitionArgs)) + return failure(); + } + + FunctionType types; + if (p.parseColon() || p.parseType(types) || + p.resolveOperands(operands, types.getInputs(), operandLoc, + result.operands)) + return failure(); + + result.addTypes(types.getResults()); + result.addAttribute(getPartitionNumWarpsAttrName(result.name), + p.getBuilder().getDenseI32ArrayAttr(partitionNumWarps)); + + Block &holder = result.addRegion()->emplaceBlock(); + OpBuilder b(p.getContext()); + b.setInsertionPointToStart(&holder); + b.create(partitionOpState); + return success(); +} + +void WarpSpecializeOp::print(OpAsmPrinter &p) { + p << '('; + p.printOperands(getOperands()); + p << ')'; + p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(), + {getPartitionNumWarpsAttrName()}); + + p.printNewline(); + p << "default "; + p.printRegion(getDefaultRegion(), /*printEntryBlockArgs=*/false); + + for (auto [i, region, numWarps] : + llvm::enumerate(getPartitionRegions(), getPartitionNumWarps())) { + p.printNewline(); + p << "partition" << i << '('; + llvm::interleaveComma(region->getArguments(), p, [&](BlockArgument arg) { + p.printRegionArgument(arg); + }); + p << ") num_warps(" << numWarps << ") "; + p.printRegion(*region, /*printEntryBlockArgs=*/false); + } + p << " : "; + p.printFunctionalType(*this); +} + +LogicalResult WarpYieldOp::verify() { + if (getNumOperands() != getParentOp().getNumResults()) { + return emitOpError("has ") + << getNumOperands() << " operands but parent op expected " + << getParentOp().getNumResults(); + } + for (auto [i, result, type] : + llvm::enumerate(getParentOp().getResultTypes(), getOperandTypes())) { + if (result != type) { + return emitOpError("operand #") << i << " has type " << type + << " but parent op expected " << result; + } + } + return success(); +} + +// Get the size of a scalar type when stored in shared memory. +// TODO: Generalize this as needed. +static size_t getSharedMemorySize(Type type) { + if (isa(type)) + return llvm::divideCeil(type.getIntOrFloatBitWidth(), 8); + if (isa(type)) + return 8; + if (auto desc = dyn_cast(type)) { + if (!isa(desc.getMemorySpace())) + return 8; + return 8 + desc.getRank() * 4; + } + llvm::report_fatal_error( + Twine("shared memory size for scalar type is unspecified: ") + + mlir::debugString(type)); +} + +std::pair WarpSpecializeOp::getCaptureSizeAlign() { + uint64_t captureSize = 0; + // Tightly pack the captures in memory. + for (Type type : getOperandTypes()) { + captureSize += getSharedMemorySize(type); + } + // Align the captures to 8 bytes. + return {captureSize, 8}; +} + +unsigned WarpSpecializeOp::getTotalPartitionWarps() { + ArrayRef numWarps = getPartitionNumWarps(); + return std::accumulate(numWarps.begin(), numWarps.end(), 0); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/Types.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/Types.cpp new file mode 100644 index 000000000..ef9c6c4a3 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/IR/Types.cpp @@ -0,0 +1,117 @@ +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::gpu; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + +Type TokenType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + int type = 1; + if (parser.parseInteger(type)) + return Type(); + + if (parser.parseGreater()) + return Type(); + + return TokenType::get(parser.getContext(), type); +} + +void TokenType::print(AsmPrinter &printer) const { + printer << "<" << getType() << ">"; +} + +static constexpr llvm::StringRef kMutableMemory = "mutable"; + +Type MemDescType::parse(AsmParser &parser) { + if (failed(parser.parseLess())) + return Type(); + + SmallVector dimensions; // required + if (failed(parser.parseDimensionList(dimensions, /*allowDynamic=*/false))) + return Type(); + + Type elementType; // required + if (failed(parser.parseType(elementType))) + return Type(); + + Attribute encoding; // required + if (failed(parser.parseComma()) || failed(parser.parseAttribute(encoding))) + return Type(); + + Attribute memorySpace; // required + if (failed(parser.parseComma()) || failed(parser.parseAttribute(memorySpace))) + return Type(); + + bool mutableMemory = false; // optional + SmallVector allocShape; // optional + if (succeeded(parser.parseOptionalComma())) { + if (succeeded(parser.parseOptionalKeyword(kMutableMemory))) { + mutableMemory = true; + if (succeeded(parser.parseOptionalComma())) { + if (failed(parser.parseDimensionList(allocShape, /*allowDynamic=*/false, + /*withTrailingX=*/false))) { + return Type(); + } + } + } else if (failed(parser.parseDimensionList(allocShape, + /*allowDynamic=*/false, + /*withTrailingX=*/false))) { + return Type(); + } + } + + if (parser.parseGreater()) + return Type(); + + return MemDescType::get(parser.getContext(), dimensions, elementType, + encoding, memorySpace, mutableMemory, dimensions); +} + +void MemDescType::print(AsmPrinter &printer) const { + printer << "<"; + auto shape = getShape(); + for (auto dim : shape) + printer << dim << "x"; + printer << getElementType(); + if (getEncoding()) + printer << ", " << getEncoding(); + if (getMemorySpace()) + printer << ", " << getMemorySpace(); + if (getMutableMemory()) + printer << ", " << kMutableMemory; + auto allocShape = getAllocShape(); + if (allocShape != shape) { + printer << ", " << allocShape[0]; + for (auto dim : allocShape.drop_front(1)) { + printer << "x" << dim; + } + } + printer << ">"; +} + +LogicalResult MemDescType::verify(function_ref emitError, + ArrayRef shape, Type elementType, + Attribute encoding, Attribute memorySpace, + bool mutableMemory, + ArrayRef allocShape) { + if (allocShape.size() < shape.size()) + emitError() << "alloc shape must have at least as many dimensions as shape"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::gpu::TritonGPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + >(); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp new file mode 100644 index 000000000..cd3fd737f --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -0,0 +1,862 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/StrUtil.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +// Get the highest version supported for the hardware and the dot. +static int getMMAVersionSafe(int computeCapability, DotOp op) { + // List supported mma version in order of preference. + SmallVector versionsSupported; + if (computeCapability < 75) { + versionsSupported = {1}; + } else if (computeCapability < 90) { + versionsSupported = {2}; + } else if (computeCapability < 100) { + versionsSupported = {3, 2}; + } else if (computeCapability < 110) { + versionsSupported = {5, 2}; + } else if (computeCapability < 130) { + versionsSupported = {2}; + } else { + assert(false && "computeCapability not supported"); + } + for (int baseVersion : versionsSupported) { + if (supportMMA(op, baseVersion)) + return baseVersion; + if (baseVersion == 3) + op.emitRemark() << "Warning: can't use MMA V3 for the dot op"; + } + return 0; +} + +SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, + int numWarps) { + auto rank = shape.size(); + // Early exit for batched matmul + if (rank == 3) + return {(unsigned)numWarps, 1, 1}; + + auto filter = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion() && + !isa(op); + }; + auto slices = multiRootGetSlice(dotOp, {filter}, {filter}); + bool hasChainedDot = false; + for (Operation *op : slices) { + if (isa(op) && (op != dotOp)) { + auto chainedDot = cast(op); + auto resTy = chainedDot.getResult().getType(); + if (resTy.getRank() != rank) { + continue; + } + if (auto mmaEncoding = + dyn_cast(resTy.getEncoding())) { + return getWarpsPerCTA(mmaEncoding); + } + hasChainedDot = true; + } + } + if (hasChainedDot) { + if (shape[0] >= shape[1]) { + return {(unsigned)numWarps, 1}; + } else { + return {1, (unsigned)numWarps}; + } + } + + assert(rank == 2); + SmallVector shapePerWarp = {16, 8}; + SmallVector warps = {1, 1}; + // Compute repM and repN + SmallVector reps = {ceil(shape[0], shapePerWarp[0]), + ceil(shape[1], shapePerWarp[1])}; + // The formula for the number of registers given the reps is + // repM * 4 * repK + repN * 2 * repK + regsC + // where regsC = repM * repN * 4, which does not depend on the warp shape + // + // As such, to minimize the register pressure, we need to balance + // repM and repN. We then untie towards M, as the lhs tile has 4 elements, + // and the rhs tile has just 2. + while (product(warps) < numWarps) { + if (reps[0] >= reps[1]) { + warps[0] *= 2; + // Too many warps for this mma (repM == repN == 1). + // We allocate the remaining warps to the left (arbitrary choice) + if (reps[0] != 1) { + reps[0] /= 2; + } + } else { + warps[1] *= 2; + reps[1] /= 2; + } + } + return {(unsigned)warps[0], (unsigned)warps[1]}; +} + +SmallVector +warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, + const SmallVector &instrShape) { + SetVector slices; + mlir::getForwardSlice(dotOp.getResult(), &slices); + // Contains a chained dot. We prefer to assign warps to one axis + // to facilitate use cases like flash attention, allowing reductions within + // the same warp. + if (llvm::find_if(slices, [](Operation *op) { + return isa(op); + }) != slices.end()) + return {(unsigned)numWarps, 1}; + + // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). + SmallVector ret = {4, 1}; + SmallVector shapePerWarp = {16, instrShape[1]}; + do { + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] > shapePerWarp[0] * ret[0]) { + ret[0] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + return ret; +} + +// Returns a shared memory allocation that can be used by a dotMMA op for the +// given value. +static Value +getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx, + bool allowTranspose, bool isMMAv5Fp4Padded = false, + Operation *op = nullptr /*only for diagnostic*/) { + OpBuilder::InsertionGuard g(rewriter); + Value arg = v; + if (auto cvtOp = v.getDefiningOp()) + arg = cvtOp.getSrc(); + auto argType = cast(arg.getType()); + assert(argType.getEncoding() && "unexpected tensor type"); + auto newOrder = getOrder(argType); + + // If the MMA op doesn't support transpose pick the layout expected by the MMA + // op. + if (!allowTranspose) { + if (opIdx == 1) { + newOrder = {0, 1}; + } else { + newOrder = {1, 0}; + } + } + + if (newOrder != getOrder(argType) && op) { + op->emitWarning("Warning: Forcing a different order [") + << newOrder[0] << ", " << newOrder[1] + << "] on SMEM than the register order for the opreand " << opIdx + << ". Registers will be transposed before SMEM store and the pipelined " + "load for this operand will be disabled, so poor performance is " + "expected."; + } + + Attribute SharedMemorySpace = + SharedMemorySpaceAttr::get(argType.getContext()); + auto CTALayout = getCTALayout(argType.getEncoding()); + auto newLayout = NVMMASharedEncodingAttr::get( + argType.getContext(), argType.getShape(), newOrder, CTALayout, + argType.getElementType(), isMMAv5Fp4Padded); + auto newType = MemDescType::get(argType.getShape(), argType.getElementType(), + newLayout, SharedMemorySpace); + rewriter.setInsertionPointAfterValue(arg); + return rewriter.create(arg.getLoc(), newType, arg); +} + +static LocalAllocOp +getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) { + OpBuilder::InsertionGuard g(rewriter); + auto argType = cast(arg.getType()); + assert(argType.getEncoding() && "unexpected tensor type"); + auto newOrder = getOrder(argType); + + Attribute SharedMemorySpace = + SharedMemorySpaceAttr::get(argType.getContext()); + auto CTALayout = getCTALayout(argType.getEncoding()); + // No swizzling for scale for now + auto newLayout = SwizzledSharedEncodingAttr::get(argType.getContext(), 1, 1, + 1, newOrder, CTALayout); + auto newType = MemDescType::get(argType.getShape(), argType.getElementType(), + newLayout, SharedMemorySpace); + rewriter.setInsertionPointAfterValue(arg); + return rewriter.create(loc, newType, arg); +} + +SmallVector +getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, + int numWarps, const SmallVector &instrShape) { + switch (version) { + case 2: + return warpsPerTileV2(dotOp, shape, numWarps); + case 3: + return warpsPerTileV3(dotOp, shape, numWarps, instrShape); + default: + assert(false && "not supported version"); + return {0, 0}; + } +} + +static bool bwdFilter(Operation *op) { + return (op->hasTrait() && isMemoryEffectFree(op)) || + isView(op) || + isa(op); +} + +// Finds the bitwidth with which the value x is loaded +static int computeOrigBitWidth(Value x) { + SetVector slice; + mlir::BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = bwdFilter; + getBackwardSlice(x, &slice, opt); + + // TODO: This heuristic may be a bit too coarse and may need improving + // If the chain contains a fp4 to fp16/bf16 conversion, then the original + // bitwidth is 4. + if (llvm::any_of(slice, [](Operation *op) { return isa(op); })) + return 4; + + int origBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth(); + for (auto op : slice) { + if (isa(op)) { + if (auto tensorTy = + dyn_cast(op->getResultTypes().front())) { + origBitWidth = + std::min(origBitWidth, tensorTy.getElementTypeBitWidth()); + } + } + } + + // If JoinOp occurred at least once, in backward layout propagation, + // the kWidth will be split in half as we pass through the JoinOp. + // Hence we divide origBitWidth by 2 here to compensate for that and + // improve our load width. + // This won't be optimal if there is a tree of multiple JoinOps, which + // would require counting the max number of JoinOp's along any path. + // + // In the future we might want to do something like trying a large kWidth, + // run layout backpropagation and see what's the contiguity that you + // get at the loads that feed into it. + if (llvm::any_of(slice, [](Operation *op) { return isa(op); })) + origBitWidth /= 2; + + return origBitWidth; +} + +class BlockedToMMA : public mlir::OpRewritePattern { + int computeCapability; + mutable llvm::DenseMap dotOpInstNs; + +public: + BlockedToMMA(mlir::MLIRContext *context, int computeCapability, int benefit) + : OpRewritePattern(context, benefit), + computeCapability(computeCapability) {} + + mlir::LogicalResult + matchAndRewrite(triton::DotOp dotOp, + mlir::PatternRewriter &rewriter) const override { + if (computeCapability < 70) + return failure(); + if (computeCapability < 80) { + dotOp.emitRemark() + << "Dot op using MMA for compute capability " << computeCapability + << " has been deprecated. It falls back to the FMA path."; + return failure(); + } + // TODO: Check data-types and SM compatibility + if (!dotOp.getType().getEncoding() || + mlir::isa(dotOp.getType().getEncoding())) + return failure(); + + int numWarps = lookupNumWarps(dotOp); + int versionMajor = getMMAVersionSafe(computeCapability, dotOp); + if (!(versionMajor >= 1 && versionMajor <= 3)) + return failure(); + + // If both of the operands are not loads, we fallback to MMAv2 + // otherwise the reg-smem roundtrip will tank the MMAv3 performance + auto comesFromLoadOrBlockArg = [](Value v) -> bool { + // Peel out the original cvt dot_op<..., #blocked> + // and any other potential cvt/trans ops + while (true) { + if (auto cvtOp = v.getDefiningOp()) { + v = cvtOp.getSrc(); + continue; + } + if (auto transOp = v.getDefiningOp()) { + v = transOp.getSrc(); + continue; + } + break; + } + // We also accept block arguments as they appear in many MLIR tests + // If this is problematic we can totally drop them + return isa(v) || + (v.getDefiningOp() && + isa(v.getDefiningOp())); + }; + + bool aFromLoad = comesFromLoadOrBlockArg(dotOp.getA()); + bool bFromLoad = comesFromLoadOrBlockArg(dotOp.getB()); + auto origDotOp = dotOp; + + Value a = dotOp.getA(); + Value b = dotOp.getB(); + auto oldAType = cast(a.getType()); + auto oldBType = cast(b.getType()); + auto oldRetType = cast(dotOp.getType()); + + // get MMA encoding for the given number of warps + auto CTALayout = getCTALayout(oldRetType.getEncoding()); + auto retShapePerCTA = getShapePerCTA(oldRetType); + auto instrShape = mmaVersionToInstrShape( + versionMajor, retShapePerCTA, oldAType.getElementType(), numWarps); + + assert(versionMajor == 2 || versionMajor == 3); + int versionMinor = computeCapability == 75 ? 1 : 0; + auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, + numWarps, instrShape); + auto mmaEnc = NvidiaMmaEncodingAttr::get( + oldRetType.getContext(), versionMajor, versionMinor, warpsPerTile, + CTALayout, instrShape); + PatternRewriterWithAsyncTaskIds taskIdRewriter(rewriter, dotOp); + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); + // convert accumulator + auto oldAcc = dotOp.getOperand(2); + auto newAcc = + rewriter.create(oldAcc.getLoc(), newRetType, oldAcc); + + auto getDotOperand = [&](Value v, int opIdx, int bitwidth) { + auto minType = + bitwidth > 0 ? rewriter.getIntegerType(bitwidth) : v.getType(); + auto vType = cast(v.getType()); + auto newVEncoding = DotOperandEncodingAttr::get( + v.getContext(), opIdx, newRetType.getEncoding(), minType); + auto newVType = RankedTensorType::get( + vType.getShape(), vType.getElementType(), newVEncoding); + return rewriter.create(v.getLoc(), newVType, v); + }; + + Operation *newDot = nullptr; + if (versionMajor == 3) { + auto eltType = dotOp.getA().getType().getElementType(); + // In MMAV3 transpose is only supported for f16 and bf16. + bool allowTranspose = eltType.isF16() || eltType.isBF16(); + if (!aFromLoad) { + int bitwidth = getElementTypeOrSelf(a).getIntOrFloatBitWidth(); + a = getDotOperand(a, 0, bitwidth); + } else { + a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); + } + b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); + newDot = taskIdRewriter.create( + dotOp.getLoc(), newRetType, a, b, newAcc, nullptr, + dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc(), false); + } else { + // convert operands + int minBitwidth = + std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); + + a = getDotOperand(a, 0, minBitwidth); + b = getDotOperand(b, 1, minBitwidth); + newDot = taskIdRewriter.create(dotOp.getLoc(), newRetType, a, b, + newAcc, dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc()); + } + // convert dot instruction + rewriter.replaceOpWithNewOp(origDotOp, origDotOp.getType(), + newDot->getResult(0)); + return success(); + } +}; + +// Pick the layout to match MXFP scales layout in register so that it can be +// copied directly using tmem st. +static Attribute getTmemScales(RankedTensorType type, unsigned numWarps) { + return triton::gpu::LinearEncodingAttr::get( + type.getContext(), getScaleTMEMStoreLinearLayout(type, numWarps)); +} + +static bool canUseTwoCTAs(triton::DotOp dotOp) { + RankedTensorType retType = dotOp.getType(); + auto retShapePerCTA = getShapePerCTA(retType); + // TODO: we could support 2 CTAs matmul with numCTAs > 2. + SmallVector splitNum = getCTASplitNum(retType.getEncoding()); + if (splitNum.size() != 2 || splitNum[0] != 2 || splitNum[1] != 1) + return false; + int m = retShapePerCTA[0]; + int n = retShapePerCTA[1]; + // minimum size supported by 2CTAs mmav5. + if (m < 64 || n < 32) + return false; + Value b = dotOp.getB(); + // Skip convert layouts. + while (auto cvtOp = b.getDefiningOp()) + b = cvtOp.getSrc(); + if (!b.getDefiningOp()) + return false; + return true; +} + +static DistributedEncodingTrait +replaceCTALayout(DistributedEncodingTrait layout, + const triton::gpu::CTALayoutAttr &newCTALayout) { + if (auto blockedLayout = mlir::dyn_cast(layout)) { + return BlockedEncodingAttr::get( + layout.getContext(), blockedLayout.getSizePerThread(), + blockedLayout.getThreadsPerWarp(), blockedLayout.getWarpsPerCTA(), + blockedLayout.getDefaultOrder(), newCTALayout); + } else if (auto sliceLayout = mlir::dyn_cast(layout)) { + return SliceEncodingAttr::get( + layout.getContext(), sliceLayout.getDim(), + replaceCTALayout(sliceLayout.getParent(), newCTALayout)); + } else { + llvm::report_fatal_error("not implemented"); + return layout; + } +} + +static Value splitBOperand(Value b, mlir::PatternRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + MLIRContext *ctx = b.getContext(); + while (auto cvtOp = b.getDefiningOp()) + b = cvtOp.getSrc(); + auto loadOp = b.getDefiningOp(); + assert(loadOp && "expected LoadOp"); + RankedTensorType bType = cast(b.getType()); + auto currentLayout = cast(bType.getEncoding()); + auto newCTALayout = + CTALayoutAttr::get(ctx, {1, 2}, {1, 2}, getCTAOrder(currentLayout)); + Attribute newLayout = replaceCTALayout(currentLayout, newCTALayout); + rewriter.setInsertionPoint(loadOp); + for (OpOperand &operand : loadOp->getOpOperands()) { + auto tensorType = dyn_cast(operand.get().getType()); + if (!tensorType) + continue; + Value newOperand = rewriter.create( + operand.get().getLoc(), + RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), newLayout), + operand.get()); + loadOp.setOperand(operand.getOperandNumber(), newOperand); + } + loadOp.getResult().setType(RankedTensorType::get( + bType.getShape(), bType.getElementType(), newLayout)); + Value newB = loadOp.getResult(); + rewriter.setInsertionPointAfter(loadOp); + auto cvt = + rewriter.create(b.getLoc(), bType, loadOp.getResult()); + rewriter.replaceAllUsesExcept(loadOp.getResult(), cvt.getResult(), cvt); + return newB; +} + +class BlockedToMMAv5 : public mlir::OpRewritePattern { + int computeCapability; + +public: + BlockedToMMAv5(mlir::MLIRContext *context, int computeCapability, int benefit) + : OpRewritePattern(context, benefit), + computeCapability(computeCapability) {} + + mlir::LogicalResult + matchAndRewrite(triton::DotOp dotOp, + mlir::PatternRewriter &rewriter) const override { + RankedTensorType oldRetType = dotOp.getType(); + if (!oldRetType.getEncoding() || + mlir::isa(oldRetType.getEncoding())) + return failure(); + + // get MMA encoding for the given number of warps + auto retShapePerCTA = getShapePerCTA(oldRetType); + int numWarps = lookupNumWarps(dotOp); + auto CTALayout = getCTALayout(oldRetType.getEncoding()); + + int versionMajor = getMMAVersionSafe(computeCapability, dotOp); + if (versionMajor != 5) + return failure(); + Location loc = dotOp.getLoc(); + // operands + Value a = dotOp.getA(); + Value b = dotOp.getB(); + if (std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)) >= 32 && + dotOp.getInputPrecision() != InputPrecision::TF32) + return failure(); + auto oldAType = dotOp.getA().getType(); + auto oldBType = dotOp.getB().getType(); + bool useTwoCTAs = canUseTwoCTAs(dotOp); + if (useTwoCTAs) { + b = splitBOperand(b, rewriter); + } + // TF32 transpose is only supported with 128 swizzle mode with 32B + // atomicity. As we currently don't support this layout we disallow + // transpose for TF32 inputs. + bool allowTranspose = !dotOp.getA().getType().getElementType().isF32(); + a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); + b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); + MLIRContext *context = dotOp->getContext(); + auto instrShape = mmaVersionToInstrShape( + versionMajor, retShapePerCTA, oldAType.getElementType(), numWarps); + ArrayRef CTASplitNum = CTALayout.getCTASplitNum(); + Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get( + context, instrShape[0], instrShape[1], /*unpacked=*/true, + CTASplitNum[0], CTASplitNum[1]); + Attribute tensorMemorySpace = + triton::nvidia_gpu::TensorMemorySpaceAttr::get(context); + Type accMemDescType = triton::gpu::MemDescType::get( + oldRetType.getShape(), oldRetType.getElementType(), accEncoding, + tensorMemorySpace, + /*mutableMemory=*/true); + Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout( + instrShape[0], instrShape[1], retShapePerCTA, numWarps, CTALayout); + auto newAccType = RankedTensorType::get(oldRetType.getShape(), + oldRetType.getElementType(), + newDistributedEncoding); + Value cvtAcc = + rewriter.create(loc, newAccType, dotOp.getOperand(2)); + auto acc = rewriter.create( + loc, accMemDescType, cvtAcc); + PatternRewriterWithAsyncTaskIds taskIdRewriter(rewriter, dotOp); + auto vTrue = rewriter.create(dotOp.getLoc(), 1, 1); + auto mma = taskIdRewriter.create( + loc, a, b, acc, vTrue, vTrue, Value(), UnitAttr()); + mma.setTwoCtas(useTwoCTAs); + + auto ld = + rewriter.create(loc, newAccType, acc); + rewriter.replaceOpWithNewOp(dotOp, oldRetType, ld); + return success(); + } +}; + +Value addSmemStageToScaleLoad(Value scale, mlir::PatternRewriter &rewriter) { + /* + Rewrite load(scale) -> local_load(local_alloc(load(scale))). + This function does not add anything to the final IR when num_stages > 1, + but it makes it easy to apply TMEM copy rewriting later. + + Since scales are stored in TMEM for MMAv5 scaled dot, loading of scales do + not needs to be put into SMEM. But in practice, the software pipeliner puts + loading of scales into multi-buffered SMEM. At that point, the SMEM + allocation created here is eliminated. + */ + OpBuilder::InsertionGuard g(rewriter); + auto op = scale.getDefiningOp(); + Operation *loadConsumer = nullptr; + + if (!op) + return scale; + + while (!isa(op)) { + if (auto reshape = dyn_cast(op)) { + op = reshape.getSrc().getDefiningOp(); + loadConsumer = reshape; + } else if (auto trans = dyn_cast(op)) { + op = trans.getSrc().getDefiningOp(); + loadConsumer = trans; + } else if (auto cvt = dyn_cast(op)) { + op = cvt.getSrc().getDefiningOp(); + loadConsumer = cvt; + } else { + // Unrecognized pattern, bail out. In practice, this implies that MMA + // pipelining will not apply to the scaled dot op, since scales will not + // be in passed through SMEM to tc_gen5_mma_scaled. + return scale; + } + } + + auto scaleAfterLoad = op->getResult(0); + auto scaleSmemAlloc = + getSharedMemoryScale(scaleAfterLoad, rewriter, op->getLoc()); + + rewriter.setInsertionPointAfterValue(scaleSmemAlloc); + auto localLoad = rewriter.create( + op->getLoc(), scaleAfterLoad.getType(), scaleSmemAlloc); + + rewriter.replaceAllUsesExcept(scaleAfterLoad, localLoad.getResult(), + scaleSmemAlloc); + + if (loadConsumer) { + return scale; + } else { + return localLoad; + } +} + +class ScaledBlockedToMMAv5 + : public mlir::OpRewritePattern { + int computeCapability; + +public: + ScaledBlockedToMMAv5(mlir::MLIRContext *context, int computeCapability, + int benefit) + : mlir::OpRewritePattern(context, benefit), + computeCapability(computeCapability) {} + + mlir::LogicalResult + matchAndRewrite(triton::DotScaledOp dotOp, + mlir::PatternRewriter &rewriter) const override { + RankedTensorType oldRetType = dotOp.getType(); + if (!oldRetType.getEncoding() || + mlir::isa(oldRetType.getEncoding())) + return failure(); + + if (dotOp.getAScale() == nullptr || dotOp.getBScale() == nullptr) { + return failure(); + } + + // get MMA encoding for the given number of warps + auto retShapePerCTA = getShapePerCTA(oldRetType); + int numWarps = lookupNumWarps(dotOp); + auto CTALayout = getCTALayout(oldRetType.getEncoding()); + if (computeCapability < 100) + return failure(); + if (retShapePerCTA[0] < 128 || retShapePerCTA[1] < 8) + return failure(); + Location loc = dotOp.getLoc(); + // operands + Value a = dotOp.getA(); + Value b = dotOp.getB(); + auto oldAType = a.getType(); + auto oldBType = b.getType(); + + bool IsAMixedPrecFp4 = false; + bool IsBMixedPrecFp4 = false; + + if (dotOp.getAElemType() != dotOp.getBElemType()) { + if (dotOp.getAElemType() == ScaleDotElemType::E2M1) + IsAMixedPrecFp4 = true; + else if (dotOp.getBElemType() == ScaleDotElemType::E2M1) + IsBMixedPrecFp4 = true; + } + + // For mixed-precision fp4 operands, set allowTranspose = false, to force + // the packed axis, K, to be contiguous in SMEM + a = getSharedMemoryMMAOperand(a, rewriter, 0, + /*allowTranspose=*/!IsAMixedPrecFp4, + IsAMixedPrecFp4, dotOp); + b = getSharedMemoryMMAOperand(b, rewriter, 1, + /*allowTranspose=*/!IsBMixedPrecFp4, + IsBMixedPrecFp4, dotOp); + + MLIRContext *context = dotOp->getContext(); + unsigned m = 128; + unsigned n = retShapePerCTA[1] >= 256 ? 256 : retShapePerCTA[1]; + unsigned k = 32; + // If both operands are E2M1, target the FP4 tensor core implicitly. + // This may result in a downstream compile-time error if the scaled TC + // descriptor requires options that are unavailable to the .kind=mxf4 mma. + // This is likely preferable over a silent runtime performance degradation + // from running f4xf4 via .kind=mxf8f6f4 + if (dotOp.getAElemType() == ScaleDotElemType::E2M1 && + dotOp.getBElemType() == ScaleDotElemType::E2M1) { + k = 64; + } + SmallVector instrShape = {m, n, k}; + ArrayRef CTASplitNum = CTALayout.getCTASplitNum(); + Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get( + context, instrShape[0], instrShape[1], /*unpacked=*/true, + CTASplitNum[0], CTASplitNum[1]); + Attribute tensorMemorySpace = + triton::nvidia_gpu::TensorMemorySpaceAttr::get(context); + Type accMemDescType = triton::gpu::MemDescType::get( + oldRetType.getShape(), oldRetType.getElementType(), accEncoding, + tensorMemorySpace, + /*mutableMemory=*/true); + Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout( + instrShape[0], instrShape[1], retShapePerCTA, numWarps, CTALayout); + auto newAccType = RankedTensorType::get(oldRetType.getShape(), + oldRetType.getElementType(), + newDistributedEncoding); + Value cvtAcc = + rewriter.create(loc, newAccType, dotOp.getOperand(2)); + auto acc = rewriter.create( + loc, accMemDescType, cvtAcc); + + RankedTensorType oldScaleAType = dotOp.getAScale().getType(); + RankedTensorType oldScaleBType = dotOp.getBScale().getType(); + + Attribute scaleEncoding = + triton::nvidia_gpu::TensorMemoryScalesEncodingAttr::get( + context, CTASplitNum[0], CTASplitNum[1]); + Type scaleAType = triton::gpu::MemDescType::get( + oldScaleAType.getShape(), oldScaleAType.getElementType(), scaleEncoding, + tensorMemorySpace, + /*mutableMemory=*/false); + Type scaleBType = triton::gpu::MemDescType::get( + oldScaleBType.getShape(), oldScaleBType.getElementType(), scaleEncoding, + tensorMemorySpace, + /*mutableMemory=*/false); + Attribute scaleALayout = getTmemScales(oldScaleAType, numWarps); + Attribute scaleBLayout = getTmemScales(oldScaleBType, numWarps); + RankedTensorType newScaleAType = RankedTensorType::get( + oldScaleAType.getShape(), oldScaleAType.getElementType(), scaleALayout); + RankedTensorType newScaleBType = RankedTensorType::get( + oldScaleBType.getShape(), oldScaleBType.getElementType(), scaleBLayout); + + auto lhsScale = addSmemStageToScaleLoad(dotOp.getAScale(), rewriter); + auto rhsScale = addSmemStageToScaleLoad(dotOp.getBScale(), rewriter); + + Value newScaleA = + rewriter.create(loc, newScaleAType, lhsScale); + Value newScaleB = + rewriter.create(loc, newScaleBType, rhsScale); + Value scaleA = rewriter.create( + loc, scaleAType, newScaleA); + Value scaleB = rewriter.create( + loc, scaleBType, newScaleB); + auto vTrue = rewriter.create(dotOp.getLoc(), 1, 1); + PatternRewriterWithAsyncTaskIds taskIdRewriter(rewriter, dotOp); + taskIdRewriter.create( + loc, a, b, acc, scaleA, scaleB, dotOp.getAElemType(), + dotOp.getBElemType(), vTrue, vTrue, Value()); + + auto ld = + rewriter.create(loc, newAccType, acc); + rewriter.replaceOpWithNewOp(dotOp, oldRetType, ld); + return success(); + } +}; +} // namespace + +static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, + Type promotedType) { + Type tensorPromotedType = cast(operand.getType()) + .cloneWith(std::nullopt, promotedType); + return builder.create(loc, tensorPromotedType, operand); +} + +// promote operands of dot op if the existing combination is not natively +// supported. +static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) { + mod.walk([=](DotOp dotOp) -> void { + auto D = dotOp.getD(); + OpBuilder builder(dotOp); + Type AElType = dotOp.getA().getType().getElementType(); + Type promoteType; + NvidiaMmaEncodingAttr mmaLayout = + dyn_cast(D.getType().getEncoding()); + if (mmaLayout) { + bool isNativeFP8 = llvm::isa(AElType); + // promote operands for sm < 89 since fp8 mma is not natively supported + // promote operands for sm >= 90 when mma is not v3 + if (!isNativeFP8 || + (isNativeFP8 && (computeCapability == 89 || mmaLayout.isHopper()))) + return; + promoteType = builder.getF16Type(); + } else { + // FMA case. + Type AElType = dotOp.getA().getType().getElementType(); + Type DElType = D.getType().getElementType(); + if (AElType == DElType) + return; + promoteType = DElType; + } + Location loc = dotOp.getLoc(); + Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType); + Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType); + dotOp.setOperand(0, promotedA); + dotOp.setOperand(1, promotedB); + }); +} + +// Transpose scaled_dot ops that have a scale on lhs. +static void transposeDotOp(DotScaledOp dotOp) { + OpBuilder builder(dotOp); + Value lhs = dotOp.getA(); + std::array transOrder = {1, 0}; + Value lhsTransposed = builder.create(lhs.getLoc(), lhs, transOrder); + Value rhs = dotOp.getB(); + Value rhsTransposed = builder.create(rhs.getLoc(), rhs, transOrder); + Value c = dotOp.getC(); + Value cTransposed = builder.create(c.getLoc(), c, transOrder); + Value result = builder.create( + dotOp.getLoc(), cTransposed.getType(), rhsTransposed, lhsTransposed, + cTransposed, dotOp.getBScale(), dotOp.getAScale(), dotOp.getBElemType(), + dotOp.getAElemType(), dotOp.getFastMath()); + Operation *transposedResult = + builder.create(result.getLoc(), result, transOrder); + dotOp.replaceAllUsesWith(transposedResult); + dotOp.erase(); +} + +static void transposeDots(ModuleOp m) { + // TODO: extend to regular dot when it is profitable. For instance when we may + // want to use rhs from register for mmav3. + SmallVector toTranspose; + m.walk([&](DotScaledOp dotOp) -> void { + if (dotOp.getAScale() == nullptr && dotOp.getBScale() != nullptr) + toTranspose.push_back(dotOp); + }); + for (DotScaledOp dotOp : toTranspose) { + transposeDotOp(dotOp); + } +} + +#define GEN_PASS_DEF_TRITONGPUACCELERATEMATMUL +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUAccelerateMatmulPass + : public impl::TritonGPUAccelerateMatmulBase< + TritonGPUAccelerateMatmulPass> { +public: + using impl::TritonGPUAccelerateMatmulBase< + TritonGPUAccelerateMatmulPass>::TritonGPUAccelerateMatmulBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + auto computeCapability = getNVIDIAComputeCapability(m); + // We could do this generically if we manage to improve the heuristics + // reverted in these two PRs https://github.com/triton-lang/triton/pull/5834 + // https://github.com/triton-lang/triton/pull/5837 + transposeDots(m); + + mlir::RewritePatternSet patterns(context); + constexpr int benefitDefault = 1; + constexpr int benefitMMAv5 = 10; + patterns.add(context, computeCapability, benefitDefault); + populateDecomposeScaledBlockedPatterns(patterns, benefitDefault); + patterns.add( + context, computeCapability, benefitMMAv5); + + if (applyPatternsGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + // Now that we have picked the mma type, decompose dot that are not natively + // supported. + decomposeMixedModeDotOp(m, computeCapability); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..2b4969982 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,50 @@ +add_triton_library(TritonGPUTransforms + AccelerateMatmul.cpp + Coalesce.cpp + F32DotTC.cpp + FuseNestedLoops.cpp + CombineTensorSelectAndIf.cpp + DecomposeScaledBlocked.cpp + ReduceDataDuplication.cpp + OptimizeAccumulatorInit.cpp + OptimizeDotOperands.cpp + OptimizeThreadLocality.cpp + Pipeliner/AssignLatencies.cpp + Pipeliner/LowerLoops.cpp + Pipeliner/ScheduleLoops.cpp + Pipeliner/WGMMAPipeline.cpp + Pipeliner/PipelineExpander.cpp + Pipeliner/TestPipelineAssignLatencies.cpp + Pipeliner/TestPipelineScheduleLoop.cpp + Pipeliner/TestPipelineLowerLoop.cpp + Pipeliner/SoftwarePipeliner.cpp + Pipeliner/TC05MMAPipeline.cpp + Pipeliner/TMAStoresPipeline.cpp + Pipeliner/ModifiedAccMMAPipeline.cpp + Pipeliner/PipeliningUtility.cpp + Pipeliner/Schedule.cpp + Prefetch.cpp + RemoveLayoutConversions.cpp + ReorderInstructions.cpp + CoalesceAsyncCopy.cpp + Utility.cpp + TaskIdPropagate.cpp + WSTaskPartition.cpp + WSDataPartition.cpp + WSCodePartition.cpp + WSLowering.cpp + PingPong.cpp + WSCanonicalization.cpp + + DEPENDS + TritonGPUTransformsIncGen + + LINK_LIBS PUBLIC + MLIRTransforms + MLIRTransformUtils + TritonAnalysis + TritonIR + TritonGPUIR + TritonNvidiaGPUIR + MLIRTransformUtils +) diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp new file mode 100644 index 000000000..c9545f043 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -0,0 +1,195 @@ +#include +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tritongpu-coalesce" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOALESCE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct CoalescePass : public impl::TritonGPUCoalesceBase { + void + setCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, + int numWarps, int threadsPerWarp, + llvm::MapVector &layoutMap) { + Value ptr = getMemAccessPtr(op); + auto refTensorType = cast(ptr.getType()); + + LDBG("Considering op: " << *op); + LLVM_DEBUG({ + DBGS() << "axis info of pointer: "; + axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity(); + SmallVector order = argSort(contiguity); + LDBG("order=[" << triton::join(order, ", ") << "]"); + + auto matchesShape = [&refTensorType](const Value &val) { + auto rttType = dyn_cast(val.getType()); + return rttType && rttType.getShape() == refTensorType.getShape(); + }; + + // The desired divisibility is the maximum divisibility among all dependent + // pointers which have the same shape and order as `ptr`. + llvm::SmallSetVector memAccessesSameOrder; + memAccessesSameOrder.insert(op); + if (ptr.getDefiningOp()) { + for (Operation *use : mlir::multiRootGetSlice(op)) { + Value val = getMemAccessPtr(use); + if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use)) + continue; + auto currOrder = + argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); + if (order == currOrder) { + LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use); + memAccessesSameOrder.insert(use); + } + } + } + + auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType); + LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]"); + + int numElems = product(shapePerCTA); + int numThreads = numWarps * threadsPerWarp; + + unsigned perThread = getNumElementsPerThread(op, order, axisInfoAnalysis); + LDBG("perThread for op: " << perThread); + + for (Operation *opSameOrder : memAccessesSameOrder) { + if (opSameOrder == op) + continue; + unsigned currPerThread = + getNumElementsPerThread(opSameOrder, order, axisInfoAnalysis); + LDBG("perThread for opSameOrder: " << currPerThread); + perThread = std::max(perThread, currPerThread); + } + + perThread = std::min(perThread, std::max(numElems / numThreads, 1)); + LDBG("perThread: " << perThread); + + if (!dyn_cast(op)) { + // For ops that can result in a global memory write, we should enforce + // that each thread handles at most 128 bits, which is the widest + // available vectorized store op; otherwise, the store will have "gaps" + // in the memory write at the warp level, resulting in worse performance. + // For loads, we can expect that the gaps won't matter due to the L1 + // cache. + perThread = std::min( + perThread, getNumElementsPerThread(op, order, axisInfoAnalysis)); + } + SmallVector sizePerThread(refTensorType.getRank(), 1); + sizePerThread[order[0]] = perThread; + + auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding()); + layoutMap[op] = triton::gpu::BlockedEncodingAttr::get( + &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps, + threadsPerWarp, CTALayout); + } + + static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + } + + void coalesceOp(Attribute encoding, Operation *op) { + OpBuilder builder(op); + // Convert operands + // For load/store with tensor pointers, we don't have to change the + // operands' type, we do this by changing the outputs' type of + // `make_tensor_ptr` + SmallVector newArgs; + for (auto operand : op->getOperands()) { + auto tensorType = dyn_cast(operand.getType()); + if (tensorType && + !isa(tensorType.getEncoding())) { + Type newType = getNewType(tensorType, encoding); + newArgs.push_back(builder.create( + op->getLoc(), newType, operand)); + } else { + newArgs.push_back(operand); + } + } + + // Convert output types + SmallVector newTypes; + for (auto t : op->getResultTypes()) { + bool isAsync = isa(op); + newTypes.push_back(isAsync ? t : getNewType(t, encoding)); + } + + // Construct new op with the new encoding + Operation *newOp = + builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs, + newTypes, op->getAttrs()); + + // Cast the results back to the original layout + for (size_t i = 0; i < op->getNumResults(); i++) { + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = builder.create( + op->getLoc(), op->getResult(i).getType(), newResult); + } + op->getResult(i).replaceAllUsesWith(newResult); + } + op->erase(); + } + + void runOnOperation() override { + // Run axis info analysis + ModuleOp moduleOp = getOperation(); + ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // For each i/o operation, we determine what layout + // the pointers should have for best memory coalescing + llvm::MapVector layoutMap; + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(moduleOp); + moduleOp.walk([&](Operation *curr) { + Value ptr = getMemAccessPtr(curr); + if (!ptr) + return; + // We only convert `tensor>` load/store + bool isPtrTensor = false; + if (auto tensorType = dyn_cast(ptr.getType())) + isPtrTensor = isa(tensorType.getElementType()); + if (!isPtrTensor) + return; + int numWarps = lookupNumWarps(curr); + setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp, + layoutMap); + }); + + // For each memory op that has a layout L1: + // 1. Create a coalesced memory layout L2 of the pointer operands + // 2. Convert all operands from layout L1 to layout L2 + // 3. Create a new memory op that consumes these operands and + // produces a tensor with layout L2 + // 4. Convert the output of this new memory op back to L1 + // 5. Replace all the uses of the original memory op by the new one + for (auto &kv : layoutMap) { + coalesceOp(kv.second, kv.first); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp new file mode 100644 index 000000000..9ce7e1714 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp @@ -0,0 +1,128 @@ +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOALESCEASYNCCOPY +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// This pass currently only applies if the following are all true... +// 1) Operand A for WGMMA is to be loaded in registers +// 2) We upcast operand A in registers before the WGMMA +// (downcasting is not yet supported) +// 3) Pipelining is enabled for loading A +// +// ...then for the AsyncCopyGlobalToLocal op, the SharedEncoding +// vec will be less than BlockedEncoding's sizePerThread for k-dim. E.g. if +// we're upcasting from int8 to bf16, then shared vec is 8 and sizePerThread +// for k is 16. In this case, AsyncCopyGlobalToLocal will generate two +// 8-byte-cp.async's for each contiguous 16B global data owned by each +// thread. This breaks coalescing (i.e. results 2x the minimum required +// transactions). +// +// This issue occurs for cp.async because it combines load and store into one +// instruction. The fix is to clip each dim of sizePerThread by shared vec, so +// that the vectorization of load and store are equal along the contiguous +// dimension. In the above example, each thread will then only own 8B contiguous +// global data. +struct ClipAsyncCopySizePerThread + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp, + PatternRewriter &rewriter) const override { + Value src = copyOp.getSrc(); + Value mask = copyOp.getMask(); + Value other = copyOp.getOther(); + auto srcTy = cast(src.getType()); + auto dstTy = cast(copyOp.getResult().getType()); + auto blockedEnc = dyn_cast(srcTy.getEncoding()); + if (!blockedEnc) + return rewriter.notifyMatchFailure(copyOp, + "src must be of blocked encoding"); + auto sharedEnc = dyn_cast(dstTy.getEncoding()); + if (!sharedEnc) + return failure(); + auto sharedVec = sharedEnc.getVec(); + + // obtain max contiguous copy size + // Note this can be further optimized, as copyContigSize can be even + // smaller when lowering, depending on contiguity and mask alignment + // (see AsyncCopyGlobalToLocalOpConversion) + LinearLayout regLayout = + triton::gpu::toLinearLayout(srcTy.getShape(), blockedEnc); + LinearLayout sharedLayout = + triton::gpu::toLinearLayout(srcTy.getShape(), sharedEnc); + auto copyContigSize = + regLayout.invertAndCompose(sharedLayout).getNumConsecutiveInOut(); + + // obtain block sizePerThread along contig dim + auto contigPerThread = getContigPerThread(srcTy); + auto blockContigSize = contigPerThread[blockedEnc.getOrder()[0]]; + + if (blockContigSize <= copyContigSize) + return rewriter.notifyMatchFailure( + copyOp, + "blocked sizePerThread along contiguous dim must be greater than the " + "max contiguous copy size "); + + contigPerThread[blockedEnc.getOrder()[0]] = copyContigSize; + + // obtain new blockedEnc based on clipped sizePerThread + auto mod = copyOp->getParentOfType(); + int numWarps = lookupNumWarps(copyOp); + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + auto newBlockEnc = BlockedEncodingAttr::get( + copyOp.getContext(), srcTy.getShape(), contigPerThread, + blockedEnc.getOrder(), numWarps, threadsPerWarp, + blockedEnc.getCTALayout()); + + // insert cvt's after src, mask, and other + auto convertBlockLayout = [&](Value src, BlockedEncodingAttr enc) { + auto ty = cast(src.getType()); + auto newTy = + RankedTensorType::get(ty.getShape(), ty.getElementType(), enc); + auto cvt = rewriter.create(copyOp->getLoc(), newTy, src); + return cvt.getResult(); + }; + src = convertBlockLayout(src, newBlockEnc); + if (mask) + mask = convertBlockLayout(mask, newBlockEnc); + if (other) + other = convertBlockLayout(other, newBlockEnc); + + rewriter.modifyOpInPlace(copyOp, [&]() { + copyOp.getSrcMutable().assign(src); + if (mask) + copyOp.getMaskMutable().assign(mask); + if (other) + copyOp.getOtherMutable().assign(other); + }); + + return success(); + } +}; + +class CoalesceAsyncCopyPass + : public impl::TritonGPUCoalesceAsyncCopyBase { +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + MLIRContext *context = &getContext(); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + + if (failed(applyPatternsGreedily(m, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp new file mode 100644 index 000000000..9963db357 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp @@ -0,0 +1,174 @@ +#include "mlir/IR/Dominance.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +/// The user of select maybe inside either the ThenRegion or ElseRegion of +/// the scf.if. So, canonicalize user of select in scf.if first. +static void canonicalizeSelectUsersInSCFIf(ModuleOp input) { + llvm::MapVector, SmallVector> + usersNeedreplaced; + input.walk([&](arith::SelectOp selectOp) { + auto *parentBlock = selectOp->getBlock(); + Value condition = selectOp.getOperand(0); + Value trueVal = selectOp.getOperand(1); + Value falseVal = selectOp.getOperand(2); + Value resVal = selectOp.getResult(); + for (auto *condUser : condition.getUsers()) { + if (!llvm::isa(condUser)) + continue; + scf::IfOp ifOp = llvm::cast(condUser); + for (auto *resUser : resVal.getUsers()) { + if (ifOp->isProperAncestor(resUser)) { + if (ifOp.getThenRegion().findAncestorOpInRegion(*resUser) != + nullptr) { + // The user is inside the ThenRegion of the scf.if. + usersNeedreplaced[std::make_pair(resVal, trueVal)].push_back( + resUser); + } else { + // The user is inside the ElseRegion of the scf.if. + usersNeedreplaced[std::make_pair(resVal, falseVal)].push_back( + resUser); + } + } + } + } + }); + + // Replace the operand of user. + for (auto [replacedSrcAndDst, users] : + llvm::make_early_inc_range(usersNeedreplaced)) { + Value srcVal = replacedSrcAndDst.first; + Value dstVal = replacedSrcAndDst.second; + for (Operation *user : llvm::make_early_inc_range(users)) { + srcVal.replaceUsesWithIf( + dstVal, [&](OpOperand &use) { return use.getOwner() == user; }); + } + } +} + +/// Return true if the select could be merged into the If without breaking SSA +/// rules. +static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp, + DominanceInfo &dom) { + // If needs to be dominated by the select. + if (!dom.dominates(selectOp.getOperation(), ifOp.getOperation())) { + return false; + } + // If needs to dominate all the select's users. + for (auto user : selectOp.getResult().getUsers()) { + if (!dom.dominates(ifOp, user)) { + return false; + } + } + return true; +} + +class CombineTensorSelectAndIfPass + : public impl::TritonGPUCombineTensorSelectAndIfBase< + CombineTensorSelectAndIfPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + canonicalizeSelectUsersInSCFIf(m); + + // Go over the arith.select ops, look if there is an if + // with the same condition. + DominanceInfo dom(m); + llvm::MapVector> selectToIf; + m.walk([&](arith::SelectOp selectOp) { + // Apply only to selects with a tensor result. Scalars are cheap enough to + // predicate. + if (!isa(selectOp.getResult().getType())) + return; + // Look if there is an if in the same block, with the same condition. + auto *parentBlock = selectOp->getBlock(); + Value condition = selectOp.getOperand(0); + SetVector conditionUsers(condition.getUsers().begin(), + condition.getUsers().end()); + // sort the users in topological order. + conditionUsers = multiRootTopologicalSort(conditionUsers); + // Get condition's users + for (Operation *user : conditionUsers) { + auto ifOp = dyn_cast(user); + if (!ifOp || ifOp->getBlock() != parentBlock) + continue; + if (canMergeIntoIf(selectOp, ifOp, dom)) { + selectToIf[ifOp].push_back(selectOp); + break; + } + } + }); + + for (auto [ifOp, selectOps] : selectToIf) { + // Add new return value to the if (and create else block if necessary), + // then yield the select value in the then block and the else block. + OpBuilder builder(ifOp); + auto loc = ifOp.getLoc(); + // Create an scf::IfOp with extra return value. + SmallVector newResultTypes = {ifOp.getResultTypes().begin(), + ifOp.getResultTypes().end()}; + for (arith::SelectOp selectOp : selectOps) { + newResultTypes.push_back(selectOp.getResult().getType()); + } + auto newIfOp = builder.create( + loc, newResultTypes, ifOp.getCondition(), /*hasElse*/ true); + // Move the existing blocks to the new if. + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + + if (ifOp.elseBlock()) { + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + } else { + // Create an empty yield + auto yieldOp = newIfOp.getElseBodyBuilder().create(loc); + } + + SmallVector ifYieldOperands = newIfOp.thenYield().getOperands(); + SmallVector elseYieldOperands = newIfOp.elseYield().getOperands(); + for (arith::SelectOp selectOp : selectOps) { + Value thenValue = selectOp.getTrueValue(); + Value elseValue = selectOp.getFalseValue(); + ifYieldOperands.push_back(thenValue); + elseYieldOperands.push_back(elseValue); + } + // Update yields + auto updateYield = [&](scf::YieldOp yield, SmallVector &operands) { + builder.setInsertionPoint(yield); + builder.create(loc, operands); + yield.erase(); + }; + updateYield(newIfOp.thenYield(), ifYieldOperands); + updateYield(newIfOp.elseYield(), elseYieldOperands); + + int resultIdx = 0; + // Replace old if with the new one. + for (auto result : ifOp.getResults()) { + result.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + } + // Replace the select with the new return value. + for (arith::SelectOp selectOp : selectOps) { + selectOp.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + selectOp.erase(); + } + + ifOp.erase(); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.cpp new file mode 100644 index 000000000..13b56bd60 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.cpp @@ -0,0 +1,251 @@ +#include "triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h" + +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { + +SmallVector getTransposeOrder(int rank) { + assert(rank >= 2); + auto transOrder = llvm::to_vector<2>(llvm::seq(rank - 2)); + transOrder.push_back(rank - 1); + transOrder.push_back(rank - 2); + return transOrder; +} + +class DecomposeScaledBlocked : public OpRewritePattern { + +public: + DecomposeScaledBlocked(MLIRContext *context, int benefit) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(DotScaledOp scaledDotOp, + PatternRewriter &rewriter) const override { + // Types + auto computeType = getComputeType(scaledDotOp.getAElemType(), + scaledDotOp.getBElemType(), rewriter); + auto loc = scaledDotOp.getLoc(); + + auto cvtDotOperand = [&](TypedValue v, + int opIdx) -> TypedValue { + auto *ctx = rewriter.getContext(); + auto retEnc = scaledDotOp.getType().getEncoding(); + auto vType = v.getType(); + auto encoding = DotOperandEncodingAttr::get(ctx, opIdx, retEnc, + vType.getElementType()); + auto retTy = RankedTensorType::get(vType.getShape(), + vType.getElementType(), encoding); + return rewriter.create(loc, retTy, v); + }; + + auto scaledA = scaleArg(rewriter, scaledDotOp, 0, computeType); + scaledA = cvtDotOperand(scaledA, 0); + auto scaledB = scaleArg(rewriter, scaledDotOp, 1, computeType); + scaledB = cvtDotOperand(scaledB, 1); + auto newDot = rewriter.create(scaledDotOp.getLoc(), scaledA, scaledB, + scaledDotOp.getC()); + + rewriter.replaceOpWithNewOp(scaledDotOp, + scaledDotOp.getType(), newDot); + return success(); + } + +private: + FloatType getComputeType(ScaleDotElemType aType, ScaleDotElemType bType, + PatternRewriter &rewriter) const { + if (aType == ScaleDotElemType::FP16 || bType == ScaleDotElemType::FP16) + return rewriter.getF16Type(); + return rewriter.getBF16Type(); + } + + TypedValue scaleTo16(PatternRewriter &rewriter, + TypedValue scale, + FloatType computeType) const { + auto loc = scale.getLoc(); + auto scaleTy = scale.getType(); + assert(computeType == rewriter.getBF16Type() || + computeType == rewriter.getF16Type()); + + // Choose an fp type that can fit the scale value. + FloatType largeFpType = computeType == rewriter.getF16Type() + ? rewriter.getF32Type() + : computeType; + int intWidth = largeFpType.getIntOrFloatBitWidth(); + auto intType = rewriter.getIntegerType(intWidth); + + auto zexted = + rewriter.create(loc, scaleTy.clone(intType), scale); + // getFpMantissaWidth() returns the number of bits in the mantissa plus the + // sign bit! + int shiftValue = largeFpType.getFPMantissaWidth() - 1; + auto shiftConst = + rewriter.create(loc, shiftValue, intWidth); + auto shift = + rewriter.create(loc, scaleTy.clone(intType), shiftConst); + auto shlRes = rewriter.create(loc, zexted, shift); + Value scaleFP = + rewriter.create(loc, scaleTy.clone(largeFpType), shlRes); + if (largeFpType != computeType) { + scaleFP = rewriter.create( + loc, scaleTy.clone(computeType), scaleFP); + } + return cast>(scaleFP); + } + + TypedValue + broadcastScale(PatternRewriter &rewriter, DotScaledOp scaledDotOp, + ModuleOp mod, TypedValue scale, + int dim) const { + auto *ctx = rewriter.getContext(); + auto loc = scale.getLoc(); + auto scaleTy = scale.getType(); + auto rank = scaleTy.getRank(); + // 2.1) Expand dims along the last dimension + { + // 2.1.1) Find default encoding for ExpandDims + auto shape = to_vector(scaleTy.getShape()); + shape.insert(shape.end(), 1); + auto nWarps = lookupNumWarps(scaledDotOp); + auto threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + auto numCTAs = TritonGPUDialect::getNumCTAs(mod); + auto blockedEnc = getDefaultBlockedEncoding(ctx, shape, nWarps, + threadsPerWarp, numCTAs); + // 2.1.2) Cast scale16 to SliceEncoding + auto sliceEnc = SliceEncodingAttr::get(ctx, rank, blockedEnc); + auto sliceType = RankedTensorType::get( + scaleTy.getShape(), scaleTy.getElementType(), sliceEnc); + scale = rewriter.create(loc, sliceType, scale); + } + auto expandScale = rewriter.create(loc, scale, rank); + // 2.2) Broadcast the dimension to size 32 + auto scaleShape = to_vector(scaleTy.getShape()); + scaleShape.push_back(32); + auto broadcastScale = rewriter.create( + loc, expandScale.getType().clone(scaleShape), expandScale); + // 2.3) Transpose the dimension to the scaled dimension + auto transposeOrder = llvm::to_vector(llvm::seq(rank)); + transposeOrder.insert(transposeOrder.begin() + dim + 1, rank); + auto transposedScale = + rewriter.create(loc, broadcastScale, transposeOrder); + // 2.4) Reshape to the shape of v + scaleShape.pop_back(); + scaleShape[dim] *= 32; + auto reshapeScale = + rewriter.create(loc, scaleShape, transposedScale); + return reshapeScale; + } + + TypedValue maskNan(PatternRewriter &rewriter, + DotScaledOp scaledDotOp, ModuleOp mod, + TypedValue mxfp, + TypedValue scale, + int dim) const { + // Implement tl.where(scale == 0xFF, float("nan"), mxfp) + auto loc = scale.getLoc(); + + // Scale is NaN + auto scaleTy = scale.getType(); + auto constFF = rewriter.create( + loc, scaleTy, + DenseElementsAttr::get(scaleTy, + APInt(scaleTy.getElementTypeBitWidth(), 0xff))); + auto scaleIsNan = cast>( + rewriter + .create(loc, arith::CmpIPredicate::eq, scale, + constFF) + .getResult()); + auto cond = broadcastScale(rewriter, scaledDotOp, mod, scaleIsNan, dim); + // Make scale is NaN compatible with mxfp + auto condTy = cond.getType(); + condTy = RankedTensorType::get(condTy.getShape(), condTy.getElementType(), + mxfp.getType().getEncoding()); + cond = rewriter.create(loc, condTy, cond); + + // Create NaN + auto mxfpTy = mxfp.getType(); + auto nan = APFloat::getNaN( + cast(mxfpTy.getElementType()).getFloatSemantics()); + auto constNan = rewriter.create( + loc, mxfpTy, DenseElementsAttr::get(mxfpTy, nan)); + + auto result = rewriter.create(loc, cond, constNan, mxfp); + return cast>(result.getResult()); + } + + TypedValue scaleArg(PatternRewriter &rewriter, + DotScaledOp scaledDotOp, int opIdx, + FloatType computeType) const { + auto v = opIdx == 0 ? scaledDotOp.getA() : scaledDotOp.getB(); + auto scale = opIdx == 0 ? scaledDotOp.getAScale() : scaledDotOp.getBScale(); + auto isFp4 = + ScaleDotElemType::E2M1 == + (opIdx == 0 ? scaledDotOp.getAElemType() : scaledDotOp.getBElemType()); + auto fastMath = scaledDotOp.getFastMath(); + + auto *ctx = rewriter.getContext(); + auto loc = v.getLoc(); + auto mod = scaledDotOp->getParentOfType(); + auto rank = v.getType().getRank(); + auto kDim = opIdx == 0 ? rank - 1 : rank - 2; + + // 0) Upcast value to computeType (fp16/bf16) + if (isFp4) { + // We always pack along the fastest moving dimension, kDim + v = rewriter.create(loc, v, computeType, kDim); + } else { + auto vType16 = v.getType().clone(computeType); + v = cast>( + rewriter.create(loc, vType16, v).getResult()); + } + if (!scale) + return v; + + // For some weird reason, we take the scale with shape as if it were coming + // from the lhs even when it's the rhs. In a normal world, we should accept + // this parametre transposed, as we do with the mxfp. + if (opIdx == 1) { + auto order = getTransposeOrder(rank); + scale = rewriter.create(loc, scale, order); + } + + // 1) Cast scale to compute type (fp16/bf16) + auto scale16 = scaleTo16(rewriter, scale, computeType); + + // 2) Broadcast scale to the same shape and layout as v + auto reshapeScale = + broadcastScale(rewriter, scaledDotOp, mod, scale16, kDim); + reshapeScale = + rewriter.create(loc, v.getType(), reshapeScale); + + // 3) Multiply + auto mxfp = cast>( + rewriter.create(loc, v, reshapeScale).getResult()); + + // Skip NaN checks if fastMath + if (fastMath) + return mxfp; + + // 4) If the scale is NaN, return NaN, else return the scaled value. + return maskNan(rewriter, scaledDotOp, mod, mxfp, scale, kDim); + } +}; + +} // namespace + +namespace mlir::triton::gpu { + +void populateDecomposeScaledBlockedPatterns(RewritePatternSet &patterns, + int benefit) { + patterns.add(patterns.getContext(), benefit); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp new file mode 100644 index 000000000..6fe35aebd --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -0,0 +1,120 @@ +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUF32DOTTC +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +// nb. We call the trick TF32x3 as C++ disallows variables starting with numbers +// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385 +// For a, b f32 +// dot(a, b, inputPrecision="tf32x3") -> +// let aBig = f32ToTF32(a), aSmall = a - aBig; +// let bBig = f32ToTF32(b), bSmall = b - bBig; +// let small = dot(aSmall, bBig, inputPrecision="tf32") + +// dot(aBig, bSmall, inputPrecision="tf32") +// let masked_nans = replaceNansWithZeros(small) +// let big = dot(aBig, bBig, inputPrecision="tf32") +// return big + masked_nans; +class TF32x3 : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DotOp dotOp, + PatternRewriter &rewriter) const override { + + auto isF32 = [](Value operand) { + return cast(operand.getType()).getElementType().isF32(); + }; + + if (!(dotOp.getInputPrecision() == InputPrecision::TF32x3 && + isF32(dotOp.getA()) && isF32(dotOp.getB()))) { + return failure(); + } + + // Aux functions + auto f32ToTF32 = [&](Value value) -> Value { + return rewriter + .create(dotOp.getLoc(), value.getType(), + "cvt.rna.tf32.f32 $0, $1;", "=r,r", + /*isPure=*/true, /*pack=*/1, + ArrayRef{value}) + .getResult()[0]; + }; + auto zeroLike = [&](Value c) -> Value { + return rewriter.create( + dotOp->getLoc(), c.getType(), + rewriter.create(dotOp->getLoc(), + rewriter.getF32FloatAttr(0))); + }; + auto add = [&](Value a, Value b) -> Value { + return rewriter.create(dotOp.getLoc(), a, b); + }; + auto sub = [&](Value a, Value b) -> Value { + return rewriter.create(dotOp.getLoc(), a, b); + }; + auto dot = [&](Value a, Value b, Value c) -> Value { + return rewriter.create(dotOp->getLoc(), c.getType(), a, b, c, + InputPrecision::TF32, + dotOp.getMaxNumImpreciseAcc()); + }; + auto replaceNansWithZeros = [&](Value value) -> Value { + auto nans = rewriter.create( + dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); + auto zero = zeroLike(value); + return rewriter.create(dotOp->getLoc(), nans, zero, + value); + }; + + auto aBig = f32ToTF32(dotOp.getA()); + auto aSmall = sub(dotOp.getA(), aBig); + + auto bBig = f32ToTF32(dotOp.getB()); + auto bSmall = sub(dotOp.getB(), bBig); + + auto zero = zeroLike(dotOp.getC()); + + auto dot1 = dot(aSmall, bBig, zero); + auto dot2 = dot(aBig, bSmall, dot1); + + // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. + // If rhs is +infinity, we will have: + // +infinity * 1.0 = +infinity + // +infinity * 0.0 = NaN + // We would get the wrong result if we sum these partial products. Instead, + // we must override any accumulated result if the last partial product is + // non-finite. + auto dot2withZeroedNans = replaceNansWithZeros(dot2); + auto dot3 = dot(aBig, bBig, dot2withZeroedNans); + + auto sum = add(dot3, dotOp.getC()); + + rewriter.replaceOp(dotOp, sum); + return success(); + } +}; + +} // anonymous namespace + +struct F32DotTCPass : public impl::TritonGPUF32DotTCBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + RewritePatternSet decomposePatterns(context); + decomposePatterns.add(context); + if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) { + signalPassFailure(); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp new file mode 100644 index 000000000..7ab936cdc --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -0,0 +1,1081 @@ +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "llvm/Support/Debug.h" +#include + +namespace mlir { +namespace triton { +namespace gpu { + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +#define GEN_PASS_DEF_TRITONGPUFUSENESTEDLOOPS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// This attribute is set by the front-end to control whether fusion is on. +static constexpr llvm::StringLiteral kFlattenAttr = "tt.flatten"; +// This attribute indicates the inner loop length has been speculated. +static constexpr llvm::StringLiteral kMustExecuteAttrName = "ttg.must-execute"; +// This attribute is just used for testing the pass. +static constexpr llvm::StringLiteral kAlwaysFuseAttrName = "ttg.always-fuse"; + +namespace { +struct FuseNestedLoopsPass + : public impl::TritonGPUFuseNestedLoopsBase { + using TritonGPUFuseNestedLoopsBase::TritonGPUFuseNestedLoopsBase; + + void runOnOperation() override; +}; + +//===----------------------------------------------------------------------===// +// LoopNest +//===----------------------------------------------------------------------===// + +// A node in the loop nest represents a single for loop with a list of +// immediately nested loops. +struct LoopNestNode { + LoopNestNode(scf::ForOp loop) : loop(loop) {} + + // The for loop. + scf::ForOp loop; + // Loops nested immediately below this loop. + SmallVector children; +}; + +// A loop nest is a tree of loops. +struct LoopNest { + LoopNest(scf::ForOp outermost); + + // Print the loop nest. + void print(raw_ostream &os) const; + // Dump the loop nest for debugging. + LLVM_DUMP_METHOD void dump() const; + + // Owner of the memory of the nodes. + SmallVector> nodes; + + // The outermost loop in the nest, which has no preconditions. Even if the + // outermost loop is contained within an if, its preconditions relative to the + // loop nest are empty. + LoopNestNode *root; +}; +} // namespace + +LoopNest::LoopNest(scf::ForOp outermost) + : root( + nodes.emplace_back(std::make_unique(outermost)).get()) { +} + +void LoopNest::print(raw_ostream &os) const { + // Print just the first line of the loop's textual IR. + std::string buffer; + auto printLoopFirstLine = [&](scf::ForOp loop) { + buffer.clear(); + llvm::raw_string_ostream str(buffer); + loop.print(str); + os << buffer.substr(0, buffer.find('\n')); + }; + + os << "LoopNest:\n"; + SmallVector> stack; + stack.emplace_back(root, 0); + while (!stack.empty()) { + auto [node, indent] = stack.pop_back_val(); + + // Print the current loop. + os << std::string(indent * 2, ' '); + printLoopFirstLine(node->loop); + os << "\n"; + + // Push the children of the current loop. + for (LoopNestNode *child : node->children) + stack.emplace_back(child, indent + 1); + } + os << "\n"; +} + +void LoopNest::dump() const { print(llvm::dbgs()); } + +//===----------------------------------------------------------------------===// +// findLoopNests +//===----------------------------------------------------------------------===// + +// Forward declaration. +static void findLoopNests(Operation *container, + SmallVectorImpl &nests); + +// Recursively construct a loop nest. +static void constructLoopNest(LoopNestNode *parent, LoopNest &nest, + SmallVectorImpl &nests) { + parent->loop->walk([&](Operation *op) { + if (op == parent->loop) + return WalkResult::advance(); + + if (auto forOp = dyn_cast(op)) { + auto &child = + nest.nodes.emplace_back(std::make_unique(forOp)); + parent->children.push_back(child.get()); + // Recurse with the current loop nest. + constructLoopNest(child.get(), nest, nests); + return WalkResult::skip(); + } + + // If the traversal encounters any other operation with regions, restart the + // traversal and construct new loop nests. This means ops like `scf.while` + // divide the analysis domain, but it also means loop fusion won't "see" + // across `scf.if`, for example. + // TODO: Handle loop nests with preconditions. The traversal can keep a + // stack of `scf.if` preconditions while constructing the loop nest. + if (op->getNumRegions()) { + findLoopNests(op, nests); + return WalkResult::skip(); + } + + return WalkResult::advance(); + }); +} + +// Find all the loop nests in the operation. The only region operation that +// allows CFG regions is `tt.func`. That means we can just walk starting from +// the function body and can build loop nests directly off the region trees +// contained in the function -- we don't have to worry about CFGs inside the +// nested region trees. +static void findLoopNests(Operation *container, + SmallVectorImpl &nests) { + container->walk([&](scf::ForOp loop) { + LoopNest nest(loop); + constructLoopNest(nest.root, nest, nests); + nests.push_back(std::move(nest)); + return WalkResult::skip(); + }); +} + +//===----------------------------------------------------------------------===// +// Logue +//===----------------------------------------------------------------------===// + +namespace { +// A prologue or epilogue. +struct Logue { + // Move the ops in the logue before the iterator. + void moveBefore(Block *block, Block::iterator it) { + for (Operation *op : ops) + op->moveBefore(block, it); + } + + // Replace all uses of the logue results with the given values, where `logue` + // comprises all the ops in `containingRegion`. + void replaceAllUsesWith(ValueRange values, Region &containingRegion) { + for (auto [newOut, output] : llvm::zip(values, outputs)) { + // Replace uses of the prologue outputs that are not in the prologue, i.e. + // inside the `then` region where it got spliced. + output.replaceUsesWithIf(newOut, [&](OpOperand &use) { + return !containingRegion.isAncestor(use.getOwner()->getParentRegion()); + }); + } + } + + // Get the number of outputs. + unsigned getNumOutputs() const { return outputs.size(); } + // Get the outputs as a `ValueRange`. + ValueRange getOutputs() const { return outputs; } + // Get the types of the outputs. + TypeRange getOutputTypes() const { return getOutputs().getTypes(); } + + // A contiguous range of ops representing the prologue or epilogue. + SmallVector ops; + // The outputs of the logue. These are the SSA value results of `ops` that are + // used by ops outside of `ops`. + SmallVector outputs; +}; +} // namespace + +// Given a range of ops, form it into a logue by finding the outputs. +static Logue createLogueFrom(llvm::iterator_range ops, + mlir::DominanceInfo &domInfo) { + Logue logue; + for (Operation &op : ops) + logue.ops.push_back(&op); + + if (ops.empty()) + return logue; + + // An op result is an output of the logue if the last operation in the logue + // dominates any of its users. + Operation &lastOp = *std::prev(ops.end()); + auto isOutput = [&](OpResult result) { + for (Operation *user : result.getUsers()) { + if (domInfo.properlyDominates(&lastOp, user)) + return true; + } + return false; + }; + + // Find the outputs. + for (Operation &op : ops) { + for (OpResult result : op.getOpResults()) { + if (isOutput(result)) + logue.outputs.push_back(result); + } + } + + return logue; +} + +//===----------------------------------------------------------------------===// +// fuseOneLevel +//===----------------------------------------------------------------------===// + +// Only hoist operations that are side-effect free and "cheap" (i.e. only scalar +// operands). Importantly, we need to be able to hoist code generated by fusing +// children loops into their parents so the algorithm can be applied +// recursively. +static bool canHoistLoopBoundComputation(Operation *op) { + auto isScalar = [](Type type) { return type.isIntOrIndexOrFloat(); }; + return isMemoryEffectFree(op) && + llvm::all_of(op->getOperandTypes(), isScalar) && + llvm::all_of(op->getResultTypes(), isScalar); +} + +// Determine if all of `values` are or can be made invariant to the outer loop +// by hoisting operations. `toHoist` is shared across all child loop bounds. +static bool isOuterLoopInvariant(mlir::DominanceInfo &domInfo, scf::ForOp outer, + ArrayRef values, + llvm::SetVector &toHoist) { + // The set of operations within `outer` that are being checked if they can be + // hoisted. This set prevents checking operations twice but also if the + // computation can be hoisted, this becomes the set of operations to hoist. + llvm::SetVector visited; + + // Climb the use-def chain breadth-first so that operations can be hoisted in + // the reverse visitation order. + std::queue queue; + for (Value value : values) + queue.push(value); + + while (!queue.empty()) { + Value value = queue.front(); + queue.pop(); + + // If the value properly dominates the outer loop, then it must be invariant + // to it. + if (domInfo.properlyDominates(value, outer)) + continue; + // If the value is a block argument, it cannot be hoisted. + if (auto arg = dyn_cast(value)) + return false; + + Operation *op = value.getDefiningOp(); + // Check if the op was already visited. + if (visited.contains(op)) + continue; + // If the defining op cannot be hoisted, then the value cannot be made loop + // invariant. + if (!canHoistLoopBoundComputation(op)) + return false; + visited.insert(op); + // Recurse on the operands of the op. + for (Value operand : op->getOperands()) + queue.push(operand); + } + + // The operations in `visited` must be hoisted. Note that operations are not + // added to `toHoist` unless all of `values` can be hoisted. This is to avoid + // hoisting operations for loops that don't end up getting fused if one of + // their bounds operands cannot be hoisted. + toHoist.insert(visited.begin(), visited.end()); + + return true; +} + +// Pessimistically assume the internal storage bitwidth for index types. +static unsigned getIntTypeWidth(Type type) { + if (isa(type)) + return IndexType::kInternalStorageBitWidth; + return cast(type).getWidth(); +} + +// Generate IR to compute the number of iterations of a loop. +static Value computeNumIters(ImplicitLocOpBuilder &b, scf::ForOp loop) { + // len(range(lb, ub, step)) = ceildiv(ub - lb, step) + // This works even if step is negative. + Value diff = + b.create(loop.getUpperBound(), loop.getLowerBound()); + // Let someone else prove it can be unsigned. + return b.create(diff, loop.getStep()); +} + +// Cast an integer or index value to an integer or index `type`, if necessary. +static Value castIntIfNecessary(ImplicitLocOpBuilder &b, Value value, + Type type) { + if (value.getType() == type) + return value; + if (isa(value.getType()) || isa(type)) + return b.create(type, value); + if (cast(value.getType()).getWidth() > + cast(type).getWidth()) + return b.create(type, value); + return b.create(type, value); +} + +// To model an "undef" value, i.e. a value that is known to never be read on +// live code paths, create a zero-valued constant where possible, otherwise use +// a poison value. PTXAS appears to generate better code with zeros compared to +// poison values. +static Value createPoisonOrZero(ImplicitLocOpBuilder &b, Type type) { + Type elTy = getElementTypeOrSelf(type); + if (!elTy.isIntOrIndexOrFloat() || + (!isa(type) && type != elTy)) + return b.create(type); + + TypedAttr attr = isa(elTy) ? TypedAttr(b.getFloatAttr(elTy, 0)) + : b.getIntegerAttr(elTy, 0); + if (auto tensor = dyn_cast(type)) + attr = SplatElementsAttr::get(tensor, attr); + return b.create(attr); +} + +static scf::YieldOp getYield(Region &body) { + return cast(body.front().back()); +} + +static scf::IfOp eraseIfResults(ImplicitLocOpBuilder &b, scf::IfOp ifOp, + llvm::BitVector indices, + SmallVector replaceWith) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(ifOp); + while (indices.size() < ifOp.getNumResults()) + indices.push_back(false); + + getYield(ifOp.getThenRegion())->eraseOperands(indices); + getYield(ifOp.getElseRegion())->eraseOperands(indices); + + TypeRange newTypes = getYield(ifOp.getThenRegion()).getOperandTypes(); + auto newIf = b.create(newTypes, ifOp.getCondition()); + newIf.getThenRegion().takeBody(ifOp.getThenRegion()); + newIf.getElseRegion().takeBody(ifOp.getElseRegion()); + + SmallVector replacements; + auto replIt = replaceWith.begin(); + auto resIt = newIf->result_begin(); + for (unsigned i : llvm::seq(ifOp.getNumResults())) + replacements.push_back(indices[i] ? *replIt++ : *resIt++); + assert(ValueRange(replacements).getTypes() == ifOp.getResultTypes()); + ifOp.replaceAllUsesWith(replacements); + ifOp.erase(); + return newIf; +} + +// Given a one level loop nest in the form +// +// for i in range(lbi, ubi, stepi): +// prologue0(i) +// for j0 in range(lbj0, ubj0, stepj0): +// body0(i, j0) +// epilogue1(i) +// for j1 in range(lbj1, ubj1, stepj1): +// body1(i, j1) +// epilogue2(i) +// ... +// for jN in range(lbjN, ubjN, stepjN): +// bodyN(i, jN) +// epilogue(i) +// +// Rewrite this into a single loop in the form: +// +// len_i = len(range(lbi, ubi, stepi)) +// len_j0 = len(range(lbj0, ubj0, stepj0)) +// len_j1 = len(range(lbj1, ubj1, stepj1)) +// ... +// len_jN = len(range(lbjN, ubjN, stepjN)) +// inner_len = max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN) - N +// total_iters = len_i * inner_len +// +// T = -1 +// i = lbi - stepi +// for _ in range(total_iters): +// T = 0 if T == (inner_len - 1) else T + 1 +// +// if T == 0: +// i += stepi +// prologue0(i) +// j0 = lbj0 +// if T >= 0 and T < len_j0: +// body0(i, j0) +// j0 += stepj0 +// +// if T == max(1, len_j0) - 1: +// prologue1(i) +// j1 = lbj1 +// if T >= max(1, len_j0) - 1 +// and T < max(1, len_j0) - 1 + len_j1: +// body1(i, j1) +// j1 += stepj1 +// +// if T == max(1, len_j0) + max(1, len_j1) - 2: +// prologue2(i) +// j2 = lbj2 +// if T >= max(1, len_j0) + max(1, len_j1) - 2 +// and T < max(1, len_j0) + max(1, len_j1) - 2 + len_j2: +// body2(i, j2) +// j2 += stepj2 +// +// ... +// +// if T == max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN-1) - N: +// prologueN(i) +// jN = lbjN +// if T >= max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN-1) - N +// and T < max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN-1) - N + +// len_jN: +// bodyN(i, jN) +// jN += stepjN +// +// if T == max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN) - (N + 1): +// epilogue(i) +// +// This routine can be applied recursively on a loop nest tree, leaf-to-root, to +// flatten the loop nest into a single loop. However, this routine only fuses +// child loops whose loop bounds are invariant to the parent loop. For child +// loops where this is not the case, the function will ignore them. +// +// We could fuse loops with parent-loop-variant or even data-dependent bounds, +// but this will require generating `scf.while` in a form that is not friendly +// to the pipeliner. In order to effectively fuse and pipeline these kinds of +// loop nests, loop nest fusion and the pipeliner need to share a higher-level +// representation (or perhaps be the same pass). +// +// Note that there are many potential forms of the fused loop. This routine will +// attempt to minimize the number of fused loop iterations by overlapping the +// iteration spaces of the child loops and the epilogues. E.g. the last +// iteration of bodyjK will execute on the same fused loop iteration as +// epilogueK and the first iteration of bodyj(K+1). Hence the `- N` term in the +// total number of iterations. +// +// What the above Python-psuedo-code glosses over is SSA dependency management. +// To interpret the pseudocode as SSA IR, just imagine everything is put back +// into allocas and SSA formation re-runs after fusion, which one should note +// will introduce undefs. +// +// Handling dependencies will require turning implicit captures into +// loop-carried dependencies. Consider: +// +// scf.for %i = %lbi to %ubi step %stepi { +// %a = tt.call @func(%i) +// scf.for %j = %lbj to %ubj step %stepj { +// %b = tt.call @use(%a, %j) +// } +// } +// +// This needs to be rewritten into: +// +// %poison = ub.poison +// %Tlast, %ilast, %jlast, %alast = scf.for %unused = ... +// iter_args(%Tprev = %c-1_i32, +// %iprev = %lbi - %stepi, +// %jprev = %poison, +// %aprev = %poison) -> (i32, i32, i32, i32) { +// %T = (%Tprev + 1) mod (...) +// %a, %i, %j = scf.if %T == 0 { +// %inext = %iprev + 1 +// %jnext = %lbj - %stepj +// +// %anext = tt.call @func(%i) +// yield %inext, %jnext, %anext +// } else { +// yield %iprev, %jprev, %aprev +// } +// +// scf.if %T >= 0 and %T < ... { +// tt.call @use(%a, %j) +// } +// +// Note: the induction variables will be initialized to their lower bound to +// avoid underflow in lbjk - stepjk, with the exception of the outer loop +// induction variable, which needs to be incremented inside the prologue to +// avoid a dependency on the epilogue. This helps the scheduler behave. +// +// Any inputs and outputs of the loop bodies would also need to be handled +// similarly: initialized as undef if appropriate and carried through the fused +// loop. This is why fusion will increase liveranges. To minimize the number of +// additional loop-carried values, the routine will analyze the subblock of IR +// inside each `prologueK` and determine its "outputs" as intermediate SSA +// values that are used later in the loop nest. +static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { + scf::ForOp outer = parent->loop; + + SmallVector innerLoops; + llvm::SetVector toHoist; + for (LoopNestNode *child : parent->children) { + scf::ForOp inner = child->loop; + assert(child->children.empty() && "fuseOneLevel runs leaf-to-root"); + + // Check if the inner loop bounds are or can be made invariant to the outer + // loop. Check them all at once to avoid adding ops to `toHoist` if not + // necessary. + if (!isOuterLoopInvariant( + domInfo, outer, + {inner.getLowerBound(), inner.getUpperBound(), inner.getStep()}, + toHoist)) + continue; + + // Add this child to the list of loops to fuse. + innerLoops.push_back(child->loop); + } + + // From the perspective of the overall analysis, we can delete all the + // children of the current loop node. Child loops that cannot be fused are now + // treated opaquely by the rest of the analysis. This allows partial fusing of + // the constructed loop nest. + parent->children.clear(); + + // If there are no child loops to fuse, then there is nothing to do. + if (innerLoops.empty()) + return; + + // The transformation will definitely succeed on `childrenToFuse`. `toHoist` + // only contains the operations that must be hoisted for `childrenToFuse` to + // be fusible. + toHoist = topologicalSort(toHoist); + for (Operation *op : toHoist) + op->moveBefore(outer); + + // Determine the integer type to use for the length computations. Use an + // integer bitwidth twice the size of the largest integer, up to 64 bits, to + // avoid overflow. + unsigned intTyWidth = getIntTypeWidth(outer.getInductionVar().getType()); + + // Generate the computations of the fused loop bounds. + Location loc = outer.getLoc(); + ImplicitLocOpBuilder b(loc, outer); + Value lenOuter = computeNumIters(b, outer); + SmallVector lenInners; + for (scf::ForOp loop : innerLoops) { + // len_jk = len(range(lbjk, ubjk, stepjk)) + Value lenInner = computeNumIters(b, loop); + intTyWidth = std::max(intTyWidth, getIntTypeWidth(lenInner.getType())); + lenInners.push_back(lenInner); + } + auto intTy = b.getIntegerType(intTyWidth); + + auto intTyCst = [&](int64_t v) { + return b.create(IntegerAttr::get(intTy, v)); + }; + + // inner_len = max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jN) - N + unsigned N = innerLoops.size() - 1; + Value innerLen = intTyCst(0); + // Keep all the partial sums because we need them later. + SmallVector partialInnerSums; + partialInnerSums.push_back(innerLen); + for (Value lenInner : lenInners) { + lenInner = castIntIfNecessary(b, lenInner, intTy); + lenInner = b.create(intTyCst(1), lenInner); + innerLen = b.create(innerLen, lenInner); + partialInnerSums.push_back(innerLen); + } + innerLen = b.create(innerLen, intTyCst(N)); + + // total_iters = len_i * inner_len + Value totalIters = + b.create(castIntIfNecessary(b, lenOuter, intTy), innerLen); + + // The outputs of the prologue, each epilogue, and all inner loop bodies need + // to carried through the fused loop. + SmallVector logues; + auto addLogue = [&](Block::iterator begin, Block::iterator end) { + logues.push_back(createLogueFrom({begin, end}, domInfo)); + }; + // prologue0 + addLogue(outer.getBody()->begin(), innerLoops.front()->getIterator()); + // prologuek where 0 < k <= N + for (auto i : llvm::seq(0, innerLoops.size() - 1)) { + addLogue(std::next(innerLoops[i]->getIterator()), + innerLoops[i + 1]->getIterator()); + } + // epilogue + addLogue(std::next(innerLoops.back()->getIterator()), + // Don't include the outer loop yield. + std::prev(outer.getBody()->end())); + + // We need iter args for: + // - The fused loop induction var + // - The outer loop induction var + // - The outer loop iter args + // - The induction vars for each inner loop + // - The outputs of each child loop + // - The outputs of each logue + SmallVector fusedInits; + + // T = -1 + fusedInits.push_back(intTyCst(-1)); + // i = lbi - stepi + fusedInits.push_back( + b.create(outer.getLowerBound(), outer.getStep())); + + unsigned outerArgsStartIdx = fusedInits.size(); + llvm::append_range(fusedInits, outer.getInits()); + + // Everything else is initialized to undef. + unsigned ivarStartIdx = fusedInits.size(); + for (scf::ForOp loop : innerLoops) { + fusedInits.push_back( + createPoisonOrZero(b, loop.getInductionVar().getType())); + } + unsigned innerOutsStartIdx = fusedInits.size(); + for (scf::ForOp loop : innerLoops) { + for (Type resultType : loop.getResultTypes()) + fusedInits.push_back(createPoisonOrZero(b, resultType)); + } + unsigned logueOutsStartIdx = fusedInits.size(); + for (Logue &logue : llvm::drop_end(logues)) { + for (Type outputType : logue.getOutputTypes()) + fusedInits.push_back(createPoisonOrZero(b, outputType)); + } + + // for _ in range(total_iters): + auto fused = + b.create(intTyCst(0), totalIters, intTyCst(1), fusedInits); + // Replace the outer loop args with the args in the fused loop args. + for (auto [arg, fusedArg] : + llvm::zip(outer.getRegionIterArgs(), + fused.getRegionIterArgs().slice(outerArgsStartIdx))) { + arg.replaceAllUsesWith(fusedArg); + } + b.setInsertionPointToStart(fused.getBody()); + + // T = 0 if T == (inner_len - 1) else T + 1 + Value T = fused.getRegionIterArg(0); + Value nextT = b.create(T, intTyCst(1)); + Value rollover = + b.create(arith::CmpIPredicate::eq, T, + b.create(innerLen, intTyCst(1))); + T = b.create(rollover, intTyCst(0), nextT); + + // `i` is computed inside the first prologue. + Value curI = fused.getRegionIterArg(1); + Value i; + + assert(partialInnerSums.size() == N + 2); + ArrayRef ivars = fused.getRegionIterArgs().slice(ivarStartIdx); + auto bodyOutsIt = + ValueRange(fused.getRegionIterArgs()).begin() + innerOutsStartIdx; + auto logueOutsIt = + ValueRange(fused.getRegionIterArgs()).begin() + logueOutsStartIdx; + SmallVector prologueIfs, bodyIfs; + for (unsigned k = 0; k <= N; ++k) { + // if T == max(1, len_j0) + ... max(1, len_jk-1) - k + // [[if k == 0]] i += stepi + // prologuek(i) + // jk = lbjk + Value innerStartT = + b.create(partialInnerSums[k], intTyCst(k)); + Value prologueCond = + b.create(arith::CmpIPredicate::eq, T, innerStartT); + + // The `scf.if` outputs will be `jk` and the outputs of prologuek. We also + // have to initialize the inner loop iter args. + scf::ForOp inner = innerLoops[k]; + Logue &prologue = logues[k]; + + SmallVector prologueOutTypes{inner.getInductionVar().getType()}; + llvm::append_range(prologueOutTypes, prologue.getOutputTypes()); + llvm::append_range(prologueOutTypes, inner.getInits().getTypes()); + if (k == 0) + prologueOutTypes.push_back(curI.getType()); + auto prologueIf = b.create(prologueOutTypes, prologueCond); + prologueIfs.push_back(prologueIf); + + // Splice prologuek into the `then` region. + Block *thenBlock = b.createBlock(&prologueIf.getThenRegion()); + prologue.moveBefore(thenBlock, thenBlock->end()); + + if (k == 0) { + // Increment `i` and replace its uses inside the prologue. + b.setInsertionPointToStart(thenBlock); + i = b.create(curI, outer.getStep()); + mlir::replaceAllUsesInRegionWith(outer.getInductionVar(), i, + prologueIf.getThenRegion()); + } + + // Yield the initialized jk, the prologue outputs, and the initial values of + // the inner loop. + b.setInsertionPointToEnd(thenBlock); + SmallVector thenOuts{inner.getLowerBound()}; + llvm::append_range(thenOuts, prologue.getOutputs()); + llvm::append_range(thenOuts, inner.getInits()); + if (k == 0) + thenOuts.push_back(i); + b.create(thenOuts); + + // In the `else` region, just yield the last values of jk, the outputs, and + // the iter args. + b.createBlock(&prologueIf.getElseRegion()); + Value lastJk = ivars[k]; + unsigned numOuts = prologue.getNumOutputs(); + SmallVector elseOuts{lastJk}; + elseOuts.append(logueOutsIt, logueOutsIt + numOuts); + elseOuts.append(bodyOutsIt, bodyOutsIt + inner.getNumResults()); + if (k == 0) + elseOuts.push_back(curI); + logueOutsIt += numOuts; + b.create(elseOuts); + + // The results of the `scf.if` become the values of jk and the prologue + // outputs for the rest of the fused loop. + Value jk = prologueIf.getResult(0); + ValueRange prologueOuts = prologueIf.getResults().slice(1, numOuts); + ValueRange prologueInits = + prologueIf.getResults().slice(1 + numOuts, inner.getNumResults()); + inner.getInductionVar().replaceAllUsesWith(jk); + prologue.replaceAllUsesWith(prologueOuts, prologueIf.getThenRegion()); + for (auto [init, iterArg] : + llvm::zip(prologueInits, inner.getRegionIterArgs())) + iterArg.replaceAllUsesWith(init); + // Replace uses of `i` elsewhere with the prologue result. + if (k == 0) { + i = prologueIf.getResults().back(); + outer.getInductionVar().replaceAllUsesWith(i); + } + + // if T >= max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jk-1) - k + // and T < max(1, len_j0) + max(1, len_j1) + ... + max(1, len_jk-1) - k + + // len_jk + // bodyk(i, jk) + // jk += stepjk + b.setInsertionPointAfter(prologueIf); + Value innerEndT = b.create( + innerStartT, castIntIfNecessary(b, lenInners[k], intTy)); + Value ge = + b.create(arith::CmpIPredicate::sge, T, innerStartT); + Value lt = b.create(arith::CmpIPredicate::slt, T, innerEndT); + Value bodyCond = b.create(ge, lt); + + // The outputs will be the outputs of the inner loop body and the next jk. + SmallVector bodyOutTypes{jk.getType()}; + llvm::append_range(bodyOutTypes, inner->getResultTypes()); + auto bodyIf = b.create(bodyOutTypes, bodyCond); + bodyIfs.push_back(bodyIf); + + // Splice bodyk into the `then` region. + inner.getBody()->eraseArguments([](Value arg) { return true; }); + bodyIf.getThenRegion().takeBody(inner.getBodyRegion()); + auto yield = getYield(bodyIf.getThenRegion()); + b.setInsertionPoint(yield); + Value nextJk = b.create(jk, inner.getStep()); + yield->insertOperands(0, nextJk); + + // The `else` region just forwards the values. + b.createBlock(&bodyIf.getElseRegion()); + SmallVector bodyForwardedOuts{jk}; + bodyForwardedOuts.append(bodyOutsIt, bodyOutsIt + inner.getNumResults()); + bodyOutsIt += inner->getNumResults(); + b.create(bodyForwardedOuts); + + // Now we can replace the results of the inner loop with the outputs of the + // body if. + inner.replaceAllUsesWith( + bodyIf.getResults().slice(1, inner.getNumResults())); + + // If the inner loop must execute, then its body does not have to be wrapped + // in a conditional. + if (inner->hasAttr(kMustExecuteAttrName)) { + b.setInsertionPoint(bodyIf); + bodyIf.getConditionMutable().assign( + b.create(b.getBoolAttr(true))); + } + + // Move the insertion point for the next iteration. + b.setInsertionPointAfter(bodyIf); + } + + // if T == len_j0 + len_j1 + ... + len_jN - N - 1: + // epilogue(i) + Logue &epilogue = logues.back(); + + // The only possible use of an epilogue output is the yield. + auto outerYield = cast(outer.getBody()->getTerminator()); + SmallVector usedIterArgs; + for (Value output : epilogue.getOutputs()) { + for (OpOperand &use : output.getUses()) { + if (use.getOwner() == outerYield) { + usedIterArgs.push_back(fused.getRegionIterArgs().drop_front( + outerArgsStartIdx)[use.getOperandNumber()]); + } + } + } + + auto epilogueCond = + b.create(arith::CmpIPredicate::eq, T, + b.create(innerLen, intTyCst(1))); + auto epilogueIf = + b.create(epilogue.getOutputTypes(), epilogueCond); + + Block *thenBlock = b.createBlock(&epilogueIf.getThenRegion()); + epilogue.moveBefore(thenBlock, thenBlock->end()); + + b.setInsertionPointToEnd(thenBlock); + b.create(epilogue.getOutputs()); + b.createBlock(&epilogueIf.getElseRegion()); + b.create(usedIterArgs); + epilogue.replaceAllUsesWith(epilogueIf.getResults(), + epilogueIf.getThenRegion()); + + // Finally, create the yield of the fused loop. + SmallVector outerOuts{T, i}; + llvm::append_range(outerOuts, outerYield.getOperands()); + for (scf::IfOp bodyIf : bodyIfs) + outerOuts.push_back(/*jk=*/bodyIf.getResult(0)); + for (auto [bodyIf, loop] : llvm::zip(bodyIfs, innerLoops)) { + llvm::append_range(outerOuts, + bodyIf.getResults().slice(1, loop.getNumResults())); + } + for (auto [logueIf, logue] : llvm::zip(prologueIfs, llvm::drop_end(logues))) { + llvm::append_range(outerOuts, + logueIf.getResults().slice(1, logue.getNumOutputs())); + } + + b.setInsertionPointToEnd(fused.getBody()); + b.create(outerOuts); + outer.replaceAllUsesWith( + fused.getResults().slice(outerArgsStartIdx, outer.getNumResults())); + + // Reduce dependencies across inner loops by hoisting the initialization of + // inner loop iter args to the outer loop when possible, and then placing the + // reset of these values in the epilogue. + auto fusedInitsIt = fused.getInitsMutable().begin() + innerOutsStartIdx; + auto fusedArgsIt = fused.getRegionIterArgs().begin() + innerOutsStartIdx; + auto fusedYieldIt = getYield(fused.getBodyRegion())->getOpOperands().begin() + + innerOutsStartIdx; + SmallVector yieldsToUpdate; + SmallVector reset, forwarded; + for (auto [loop, ifOp, bodyIf, prologue] : + llvm::zip(innerLoops, prologueIfs, bodyIfs, logues)) { + unsigned numResults = loop.getNumResults(); + unsigned prologueSkip = 1 + prologue.getNumOutputs(); + + llvm::BitVector removeIndices(prologueSkip + numResults); + SmallVector replaceWith; + for (auto [i, init] : llvm::enumerate(loop.getInits())) { + if (init.getParentRegion() == &fused.getBodyRegion()) + continue; + // Initialize this in the outer loop. + fusedInitsIt[i].assign(init); + replaceWith.push_back(fusedArgsIt[i]); + removeIndices.set(prologueSkip + i); + yieldsToUpdate.push_back(&fusedYieldIt[i]); + forwarded.push_back(bodyIf.getResult(1 + i)); + reset.push_back(init); + } + // Remove the initializers in the corresponding prologue. + eraseIfResults(b, ifOp, removeIndices, replaceWith); + + fusedInitsIt += numResults; + fusedArgsIt += numResults; + fusedYieldIt += numResults; + } + if (!yieldsToUpdate.empty()) { + MutableOperandRange(getYield(epilogueIf.getThenRegion())).append(reset); + MutableOperandRange(getYield(epilogueIf.getElseRegion())).append(forwarded); + b.setInsertionPoint(epilogueIf); + TypeRange newTypes = getYield(epilogueIf.getThenRegion()).getOperandTypes(); + auto newIf = b.create(newTypes, epilogueIf.getCondition()); + newIf.getThenRegion().takeBody(epilogueIf.getThenRegion()); + newIf.getElseRegion().takeBody(epilogueIf.getElseRegion()); + epilogueIf.replaceAllUsesWith( + newIf.getResults().take_front(epilogueIf.getNumResults())); + ResultRange newResults = + newIf.getResults().drop_front(epilogueIf.getNumResults()); + for (auto [i, yieldOperand] : llvm::enumerate(yieldsToUpdate)) + yieldOperand->set(newResults[i]); + epilogueIf.erase(); + } + + // Update the parent's loop to the fused loop. Set the new stage count to the + // max stage count of the inner loops. + int numStages = 1; + for (scf::ForOp loop : innerLoops) { + if (auto stageAttr = loop->getAttrOfType(kNumStagesAttrName)) + numStages = std::max(numStages, stageAttr.getInt()); + loop.erase(); + } + outer.erase(); + parent->loop = fused; + if (numStages > 1) + fused->setAttr(kNumStagesAttrName, b.getI32IntegerAttr(numStages)); +} + +//===----------------------------------------------------------------------===// +// flattenLoopNest +//===----------------------------------------------------------------------===// + +// Completely flatten a loop nest by recursively fusing loops in a post-order +// traversal with `fuseOneLevel`. +static void flattenLoopNest(LoopNestNode *node, mlir::DominanceInfo &domInfo) { + for (LoopNestNode *child : node->children) + flattenLoopNest(child, domInfo); + fuseOneLevel(node, domInfo); +} + +//===----------------------------------------------------------------------===// +// Pass Implementation +//===----------------------------------------------------------------------===// + +// Fuse simple loop nests with a single outer and inner loop, and where the +// inner loop has a `tt.dot` operation. +static bool shouldFuse(const LoopNest &nest) { + if (nest.root->loop->hasAttr(kAlwaysFuseAttrName)) + return true; + + // Only fuse simple loop nests. + return nest.nodes.size() == 2 && nest.root->children.size() == 1 && + nest.root->loop->hasAttr(kFlattenAttr); +} + +// This function identifies a subgraph of cheap ops that can be sunk between two +// regions in the loop nest and moves them, reducing their liveranges. +static void sinkOps(Region &limit, Block *sinkBlock, Block::iterator sinkBefore, + llvm::iterator_range prologue, + function_ref inSinkRegion) { + llvm::SetVector sunkOps; + auto canBeSunk = [&](Operation &op) -> std::pair { + if (!isPure(&op) || isa(op)) + return {false, false}; + // An op can be sunk if all its users are inside the inner loop or are + // marked for sinking. + bool isRoot = true; + for (Operation *user : op.getUsers()) { + if (inSinkRegion(user)) + continue; + isRoot = false; + if (sunkOps.contains(user)) + continue; + return {false, false}; + } + return {true, isRoot}; + }; + + // Find the subgraph of operations that can be sunk. + SmallVector roots; + for (Operation &op : llvm::reverse(prologue)) { + auto [canSink, isRoot] = canBeSunk(op); + if (canSink) + sunkOps.insert(&op); + if (isRoot) + roots.push_back(&op); + } + if (sunkOps.empty()) + return; + + sunkOps = topologicalSort(sunkOps); + for (Operation *op : sunkOps) + op->moveBefore(sinkBlock, sinkBefore); +} + +// Sink ops from the prologue into the epilogue when possible. +static void optimizeEpilogueDependencies(scf::ForOp outerLoop, + scf::ForOp innerLoop, + mlir::DominanceInfo &domInfo) { + auto inEpilogue = [&](Operation *op) { + return domInfo.properlyDominates(innerLoop, op, /*enclosingOpOk=*/false); + }; + Region &limit = outerLoop.getBodyRegion(); + sinkOps(limit, outerLoop.getBody(), std::next(innerLoop->getIterator()), + {outerLoop.getBody()->begin(), innerLoop->getIterator()}, inEpilogue); +} + +// Speculate the length of the inner loop such that the loop is known to execute +// at least once. This way, the inner loop body does not have to be placed +// inside a conditional in the fused loop, which interacts better with the +// pipeliner. +static LogicalResult speculateInnerLoopLength(scf::ForOp outerLoop, + scf::ForOp innerLoop, + mlir::DominanceInfo &domInfo) { + // The inner loop bounds must be outer-loop invariant to speculate from + // outside the loop nest. + Location loc = innerLoop.getLoc(); + llvm::SetVector toHoist; + if (!isOuterLoopInvariant(domInfo, outerLoop, + {innerLoop.getLowerBound(), + innerLoop.getUpperBound(), innerLoop.getStep()}, + toHoist)) + return failure(); + + // Hoist the inner loop bounds computations if necessary. + toHoist = topologicalSort(toHoist); + for (Operation *op : toHoist) + op->moveBefore(outerLoop); + + // Mark the inner loop. + ImplicitLocOpBuilder b(loc, outerLoop); + innerLoop->setAttr(kMustExecuteAttrName, b.getUnitAttr()); + + // Speculate on whether the length of the inner loop is zero. + Value lenInner = computeNumIters(b, innerLoop); + auto zeroAttr = IntegerAttr::get(lenInner.getType(), 0); + Value innerLoopEmpty = + b.create(arith::CmpIPredicate::eq, lenInner, + b.create(zeroAttr)); + auto ifOp = b.create(outerLoop.getResultTypes(), innerLoopEmpty); + + // In the `then` branch, the inner loop does not execute. Clone the loop nest + // into it and remove the inner loop. + mlir::IRMapping map; + b.createBlock(&ifOp.getThenRegion()); + auto newLoop = cast(b.clone(*outerLoop, map)); + b.create(newLoop.getResults()); + auto newInnerLoop = cast(map.lookup(innerLoop)); + newInnerLoop.replaceAllUsesWith(newInnerLoop.getInits()); + newInnerLoop.erase(); + + // Move the loop nest into the `else` branch. + outerLoop.replaceAllUsesWith(ifOp.getResults()); + Block *block = b.createBlock(&ifOp.getElseRegion()); + outerLoop->remove(); + b.insert(outerLoop); + b.create(outerLoop.getResults()); + + return success(); +} + +static LogicalResult preprocessLoopNest(const LoopNest &nest, + mlir::DominanceInfo &domInfo) { + assert(nest.nodes.size() == 2 && nest.root->children.size() == 1); + + scf::ForOp &outerLoop = nest.root->loop; + scf::ForOp &innerLoop = nest.root->children.front()->loop; + + moveLoopInvariantCode(outerLoop); + optimizeEpilogueDependencies(outerLoop, innerLoop, domInfo); + return speculateInnerLoopLength(outerLoop, innerLoop, domInfo); +} + +void FuseNestedLoopsPass::runOnOperation() { + auto &domInfo = getAnalysis(); + + for (auto func : getOperation().getOps()) { + SmallVector nests; + findLoopNests(func, nests); + for (LoopNest &nest : nests) { + if (!shouldFuse(nest)) + continue; + if (!nest.root->loop->hasAttr(kAlwaysFuseAttrName) && + failed(preprocessLoopNest(nest, domInfo))) + continue; + flattenLoopNest(nest.root, domInfo); + } + } +} + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp new file mode 100644 index 000000000..17e27255c --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp @@ -0,0 +1,281 @@ +#include "mlir/Transforms/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZEACCUMULATORINIT +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { +class TMEMAllocWithUnusedInit + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::nvidia_gpu::TMEMAllocOp op, + PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + if (op.getSrc() == nullptr) + return failure(); + SmallVector users(op.getResult().getUsers().begin(), + op.getResult().getUsers().end()); + if (users.size() > 2) + return failure(); + triton::nvidia_gpu::MMAv5OpInterface mmaOp = nullptr; + triton::nvidia_gpu::TMEMLoadOp tmemLoad = nullptr; + for (auto user : users) { + if (auto load = dyn_cast(user)) { + tmemLoad = load; + } else if (auto mma = + dyn_cast(user)) { + mmaOp = mma; + } + } + if (!mmaOp) + return failure(); + if (tmemLoad && !mmaOp->isBeforeInBlock(tmemLoad)) + return failure(); + Value useAccFlag = mmaOp.useAccumulator(); + if (!useAccFlag) + return failure(); + auto flagConstOp = useAccFlag.getDefiningOp(); + if (!flagConstOp) + return failure(); + if (cast(flagConstOp.getValue()).getInt() != 0) + return failure(); + op.getSrcMutable().clear(); + return success(); + } +}; + +bool dotSupportsAccInitFlag(Operation *op) { + assert(isa(op) && + "Expected an op which implements a DotOpInterface"); + + if (auto wgDotOp = dyn_cast(op)) { + // Partial accumulation would require a select op to handle the + // initialization that would degrade the performance. + return !wgDotOp.needsPartialAccumulator(); + } + if (isa(op)) { + return true; + } + return false; +} + +std::pair getAccumulatorUseAndDef(Operation *op) { + assert(isa(op) && + "Expected an op which implements a DotOpInterface"); + + if (auto wgDotOp = dyn_cast(op)) { + return std::make_pair(wgDotOp.getC(), wgDotOp); + } + if (auto tc05MmaOp = dyn_cast(op)) { + auto accVal = tc05MmaOp.getAccumulator(); + auto tmemAlloc = accVal.getDefiningOp(); + if (!tmemAlloc || + tmemAlloc->getParentRegion() != tc05MmaOp->getParentRegion()) + return std::make_pair(nullptr, nullptr); + triton::nvidia_gpu::TMEMLoadOp tmemLoad = nullptr; + for (auto user : tmemAlloc.getResult().getUsers()) { + if (auto load = dyn_cast(user)) { + tmemLoad = load; + break; + } + } + if (!tmemLoad || + tmemLoad->getParentRegion() != tc05MmaOp->getParentRegion()) + return std::make_pair(nullptr, nullptr); + return std::make_pair(tmemAlloc.getSrc(), tmemLoad); + } + assert(false && "Unexpected op which implements a DotOpInterface"); + return std::make_pair(nullptr, nullptr); +} + +void setUseAccFlag(Operation *op, Value useAcc) { + assert(isa(op) && + "Expected an op which implements a DotOpInterface"); + + if (auto wgDotOp = dyn_cast(op)) { + wgDotOp.getUseCMutable().assign(useAcc); + } else if (auto tc05MmaOp = + dyn_cast(op)) { + tc05MmaOp.setUseAccumulator(useAcc); + } else { + assert(false && "Unexpected op which implements a DotOpInterface"); + } +} + +bool isConstantZeroTensor(Value v) { + return (matchPattern(v, m_Zero()) || matchPattern(v, m_AnyZeroFloat())); +} + +std::optional> +findZeroInitOp(Value accUse, scf::ForOp forOp, bool &loopArgIsZero) { + Value v = accUse; + if (auto arg = dyn_cast(v)) { + assert(arg.getOwner() == forOp.getBody()); + if (isConstantZeroTensor(forOp.getInitArgs()[arg.getArgNumber() - 1])) { + loopArgIsZero = true; + } + v = forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + + auto defOp = v.getDefiningOp(); + if (!defOp) { + return std::nullopt; + } + if (auto selOp = dyn_cast(defOp)) { + if (!selOp.getCondition().getType().isInteger(1)) + return std::nullopt; + if (isConstantZeroTensor(selOp.getTrueValue()) || + isConstantZeroTensor(selOp.getFalseValue())) { + return std::make_pair(selOp, 0); + } + } + if (auto ifOp = dyn_cast(defOp)) { + unsigned resultIndex = cast(v).getResultNumber(); + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + if (isConstantZeroTensor(thenVal) || isConstantZeroTensor(elseVal)) { + // Make sure that the other value is not defined in the if itself, but + // passed from outside + if (thenVal.getParentBlock()->getParentOp() == ifOp || + elseVal.getParentBlock()->getParentOp() == ifOp) { + return std::nullopt; + } + return std::make_pair(ifOp, resultIndex); + } + } + return std::nullopt; +} + +} // namespace + +class OptimizeAccumulatorInitPass + : public impl::TritonGPUOptimizeAccumulatorInitBase< + OptimizeAccumulatorInitPass> { +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + SmallVector mmaOps; + m.walk([&](Operation *op) { + if (isa(op) && dotSupportsAccInitFlag(op)) + mmaOps.push_back(op); + }); + + // for each mma op, find where the accumulator is initialized with zero + // It can be: + // 1. A constant zero + // 2. Initialized with zero as the loop argument + // 3. Initialized with zero in the if op or with a select op in current + // or any of the previous loop iterations + for (Operation *mmaOp : mmaOps) { + Location loc = mmaOp->getLoc(); + + scf::ForOp forOp = dyn_cast(mmaOp->getParentOp()); + if (!forOp) { + continue; + } + + IRRewriter rewriter(forOp); + rewriter.setInsertionPoint(forOp); + + Value vTrue = + rewriter.create(loc, rewriter.getBoolAttr(true)); + Value vFalse = + rewriter.create(loc, rewriter.getBoolAttr(false)); + + // Find the accumulator + auto [accUse, accDef] = getAccumulatorUseAndDef(mmaOp); + if (!accUse || !accDef) { + continue; + } + if (isConstantZeroTensor(accUse)) { + setUseAccFlag(mmaOp, vFalse); + continue; + } + + bool loopArgIsZero = false; + std::optional> zeroInitOp = + findZeroInitOp(accUse, forOp, loopArgIsZero); + if (!zeroInitOp) { + continue; + } + + Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue; + scf::ForOp newForOp = + replaceForOpWithNewSignature(rewriter, forOp, {loopArgFlagValue}); + forOp.erase(); + forOp = newForOp; + loopArgFlagValue = + forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1); + + Value condition = nullptr; + Value oldValue = nullptr; + Value zeroValue = nullptr; + bool thenInitsToZero = false; + if (auto selOp = dyn_cast(zeroInitOp->first)) { + condition = selOp.getCondition(); + oldValue = isConstantZeroTensor(selOp.getTrueValue()) + ? selOp.getFalseValue() + : selOp.getTrueValue(); + zeroValue = isConstantZeroTensor(selOp.getTrueValue()) + ? selOp.getTrueValue() + : selOp.getFalseValue(); + thenInitsToZero = isConstantZeroTensor(selOp.getTrueValue()); + } else { + assert(isa(*zeroInitOp->first) && "Expected an if op"); + auto ifOp = cast(zeroInitOp->first); + unsigned resultIndex = zeroInitOp->second; + condition = ifOp.getCondition(); + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + oldValue = isConstantZeroTensor(thenVal) ? elseVal : thenVal; + zeroValue = isConstantZeroTensor(thenVal) ? thenVal : elseVal; + thenInitsToZero = isConstantZeroTensor(thenVal); + } + + // Create a select op that updates the flag + rewriter.setInsertionPoint(zeroInitOp->first); + bool zeroingBeforeMMA = zeroInitOp->first->isBeforeInBlock(mmaOp); + Value prevFlagValue = zeroingBeforeMMA ? loopArgFlagValue : vTrue; + auto selectFlagOp = rewriter.create( + loc, condition, thenInitsToZero ? vFalse : prevFlagValue, + thenInitsToZero ? prevFlagValue : vFalse); + setUseAccFlag(mmaOp, zeroingBeforeMMA ? selectFlagOp : loopArgFlagValue); + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield->insertOperands(forYield->getNumOperands(), + {zeroingBeforeMMA ? vTrue : selectFlagOp}); + + // Stop clearing out the accumulator with zero + if (auto selOp = dyn_cast(zeroInitOp->first)) { + rewriter.setInsertionPoint(selOp); + rewriter.replaceOp(selOp, oldValue); + } else { + auto ifOp = cast(zeroInitOp->first); + int resultIndex = zeroInitOp->second; + auto zeroingYield = + thenInitsToZero ? ifOp.thenYield() : ifOp.elseYield(); + zeroingYield.setOperand(resultIndex, oldValue); + } + } + + // Cleanup unused init values in tmem allocs + mlir::RewritePatternSet patterns(m.getContext()); + patterns.add(m.getContext()); + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp new file mode 100644 index 000000000..b03080252 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -0,0 +1,302 @@ +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include + +namespace mlir::triton::gpu { + +namespace { +// Given +// dot(convert(trans(src)) #dot_operand) -> +// dot(convert(local_load(trans(alloc(src))))) +// change the encoding of the inner convert to a special, swizzled shared +// encoding. +class SwizzleShmemConvert : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp, + PatternRewriter &rewriter) const override { + if (!cvtOp->hasOneUse() || + !isa(cvtOp->use_begin()->getOwner())) + return failure(); + // Match outerCvt(trans(innerCvt(x))). + auto trans = cvtOp.getSrc().getDefiningOp(); + if (!trans || trans.getOrder() != ArrayRef{1, 0}) + return failure(); + + RankedTensorType srcTy = trans.getSrc().getType(); + + if (auto srcCvt = trans.getSrc().getDefiningOp()) { + srcTy = srcCvt.getSrc().getType(); + } + RankedTensorType sharedLoadTy = cvtOp.getType(); + auto cvtEncoding = + dyn_cast(sharedLoadTy.getEncoding()); + if (!cvtEncoding) + return failure(); + + // TODO(Qingyi): need to check whether the CTALayout of innerCvtEnc should + // be used here. For tests where numCTAs = 1, this is not a problem since + // all CTALayouts are the same. + // + // Set needTrans to true here. newInnerCvtEnc is computed based on + // argEncoding which is before the transpose. Without needTrans we will + // compute vec and maxPhase based on incorrect m, n and k size of mma. The + // type inference of MemDescTransOp simply swap the order but doesn't fix + // the vec and maxPhase for the YType, hence it would causing incorrect + // swizzling code. + auto newInnerCvtEnc = SwizzledSharedEncodingAttr::get( + getContext(), cvtEncoding, srcTy.getShape(), + /*order=*/getOrder(srcTy), + triton::gpu::getCTALayout(srcTy.getEncoding()), srcTy.getElementType(), + /*needTrans=*/true); + if (newInnerCvtEnc == cvtEncoding) + return failure(); + rewriter.setInsertionPoint(trans); + auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext()); + auto alloc = rewriter.create( + trans.getLoc(), + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), + newInnerCvtEnc, sharedMemorySpace), + trans.getSrc()); + auto newTrans = rewriter.create(trans.getLoc(), alloc, + ArrayRef({1, 0})); + rewriter.replaceOpWithNewOp(trans, sharedLoadTy, newTrans); + return success(); + } +}; + +// Rewrite +// +// dot(alloc(trans() #shared1) -> +// dot(trans(alloc() #shared2)) +// +// if dot is an MMAv3/v5 (because MMAv3/v5 allows us to fold transposes). +class FuseTransMMAV3Plus : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LocalAllocOp allocOp, + PatternRewriter &rewriter) const override { + if (!allocOp.getSrc() || !allocOp->hasOneUse() || + !isa( + *allocOp->getUsers().begin())) + return failure(); + + auto dot = *allocOp->getUsers().begin(); + // Match outerCvt(trans(innerCvt(x))). + auto trans = allocOp.getSrc().getDefiningOp(); + if (!trans || trans.getOrder() != ArrayRef({1, 0})) + return failure(); + + MemDescType allocType = allocOp.getType(); + auto allocEncoding = cast(allocType.getEncoding()); + RankedTensorType srcTy = trans.getSrc().getType(); + + // MMAv3 with transpose only supports f16 and bf16. Fall back to MMAv3 + // without transpose for other data types.) + auto newInnerCvtOrder = getOrder(srcTy); + if (auto cvt = trans.getSrc().getDefiningOp()) { + newInnerCvtOrder = getOrder(cvt.getSrc().getType()); + } + auto srcElemTy = allocType.getElementType(); + if (!srcElemTy.isF16() && !srcElemTy.isBF16()) { + if (allocOp.getResult() == dot->getOperand(0)) { + newInnerCvtOrder = {0, 1}; + } else if (allocOp.getResult() == dot->getOperand(1)) { + newInnerCvtOrder = {1, 0}; + } + } + + // TODO(Qingyi): need to check whether the CTALayout of innerCvtEnc should + // be used here. For tests where numCTAs = 1, this is not a problem since + // all CTALayouts are the same. + auto newInnerEnc = NVMMASharedEncodingAttr::get( + getContext(), srcTy.getShape(), newInnerCvtOrder, + allocEncoding.getCTALayout(), srcTy.getElementType(), + allocEncoding.getFp4Padded()); + + MemDescType innerTy = + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc, + allocType.getMemorySpace()); + auto newAlloc = rewriter.create(allocOp.getLoc(), innerTy, + trans.getSrc()); + rewriter.replaceOpWithNewOp(allocOp, newAlloc, + ArrayRef({1, 0})); + return success(); + } +}; + +// Inject TMEM copy instructions into IR to efficiently load blocked scales for +// scaled dot +class UseShmemForScales + : public OpRewritePattern { +public: + using OpRewritePattern< + triton::nvidia_gpu::TCGen5MMAScaledOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::nvidia_gpu::TCGen5MMAScaledOp mmaOp, + PatternRewriter &rewriter) const override { + auto aScale = mmaOp.getAScale(); + auto bScale = mmaOp.getBScale(); + LogicalResult ret = failure(); + if (aScale && isa( + aScale.getType().getEncoding())) { + if (rewriteOperand(mmaOp.getAScaleMutable(), rewriter).succeeded()) + ret = success(); + } + if (bScale && isa( + bScale.getType().getEncoding())) { + if (rewriteOperand(mmaOp.getBScaleMutable(), rewriter).succeeded()) + ret = success(); + } + return ret; + } + +private: + LogicalResult rewriteOperand(OpOperand &opOperand, + PatternRewriter &rewriter) const { + auto src = cast>(opOperand.get()); + auto tmemAlloc = src.getDefiningOp(); + if (!tmemAlloc) { + return failure(); + } + auto dstType = tmemAlloc.getResult().getType(); + + if (!tmemAlloc.getSrc()) { + return failure(); + } + + // Look for a sequence + // local_load + // -> reshape(..., (BLOCK_MN / 128, BLOCK_K / scale_vec_size / 4, 32, 4, + // 4) + // -> transpose(..., (0, 3, 2, 1, 4)) + // -> reshape(..., (BLOCK_MN, BLOCK_K / scale_vec_size) + // -> tmem_alloc + // -> tc_gen_mma_scaled + // and replace it with local_alloc -> tc_gen_mma_scaled + auto scale2DShape = dstType.getShape(); + auto blockMN = scale2DShape[0]; + auto numScales = scale2DShape[1]; + const SmallVector transposeOrder{0, 3, 2, 1, 4}; + const SmallVector reshape5DShape{blockMN / 128, numScales / 4, 32, + 4, 4}; + + auto reshapeOp2D = getNextOp(tmemAlloc.getSrc()); + if (!reshapeOp2D || + reshapeOp2D.getResult().getType().getShape() != scale2DShape) { + return failure(); + } + + auto transOp = getNextOp(reshapeOp2D.getSrc()); + if (!transOp || transOp.getOrder() != ArrayRef(transposeOrder)) { + return failure(); + } + + auto reshapeOp5D = getNextOp(transOp.getSrc()); + if (!reshapeOp5D || reshapeOp5D.getResult().getType().getShape() != + ArrayRef(reshape5DShape)) { + return failure(); + } + + auto localLoad = getNextOp(reshapeOp5D.getSrc()); + if (!localLoad || !isTmemCopyCompatible(localLoad.getSrc().getType())) { + return failure(); + } + opOperand.assign(localLoad.getSrc()); + return success(); + } + + template Op getNextOp(Value op) const { + while (auto cvtOp = op.getDefiningOp()) { + op = cvtOp.getSrc(); + } + return op.getDefiningOp(); + } + + bool isDescendingOrder(triton::gpu::MemDescType scale) const { + auto order = triton::gpu::getOrder(scale); + auto rank = scale.getRank(); + for (int i = 0; i < rank; ++i) { + if (order[i] != rank - 1 - i) + return false; + } + return true; + } + + bool isTmemCopyCompatible(triton::gpu::MemDescType scaleType) const { + // TMEM copy expects that blocked scale "chunks" in SMEM are stored in + // innermost axes contiguously. + if (!isDescendingOrder(scaleType)) + return false; + + auto sharedEnc = + cast(scaleType.getEncoding()); + if (sharedEnc.getMaxPhase() != 1 || sharedEnc.getPerPhase() != 1 || + sharedEnc.getVec() != 1) { + // For now, we do not expect swizzling to be applied to the scale SMEM. + // This is currently true for non-matmul operand SMEM allocated during + // pipelining. + return false; + } + + if (scaleType.getRank() != 2) { + // TODO: Add support for higher rank when 5D coalesced load is fixed + // or 4D TMA is supported. + return false; + } + + auto elemBits = scaleType.getElementType().getIntOrFloatBitWidth(); + + // We assume that 32x128b chunks are flattened into the inner most axis. + auto innerMostBits = + scaleType.getDimSize(scaleType.getRank() - 1) * elemBits; + return innerMostBits % (32 * 128) == 0; + } +}; + +} // namespace + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUOptimizeDotOperandsPass + : public impl::TritonGPUOptimizeDotOperandsBase< + TritonGPUOptimizeDotOperandsPass> { +public: + using impl::TritonGPUOptimizeDotOperandsBase< + TritonGPUOptimizeDotOperandsPass>::TritonGPUOptimizeDotOperandsBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + OpPassManager pm; + pm.addPass(mlir::createCanonicalizerPass()); + if (failed(runPipeline(pm, m))) + return signalPassFailure(); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); + if (failed(applyPatternsGreedily(m, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace mlir::triton::gpu diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp new file mode 100644 index 000000000..0ded1b366 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -0,0 +1,591 @@ +#include +#include + +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZETHREADLOCALITY +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { +// Change the destination layout of reshape ops allowing reorder when used by a +// reduction in order to minimize the amount of cross thread communication for +// the reduction. +struct OptimizeReshapeLayoutPattern : public OpRewritePattern { + OptimizeReshapeLayoutPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(ReshapeOp viewOp, + PatternRewriter &rewriter) const override { + if (!viewOp.getAllowReorder()) + return failure(); + std::optional reductionAxis; + for (Operation *user : viewOp.getResult().getUsers()) { + if (auto reduceOp = dyn_cast(user)) { + if (reductionAxis) { + if (reductionAxis != reduceOp.getAxis()) + return failure(); + } else { + reductionAxis = reduceOp.getAxis(); + } + } + } + if (!reductionAxis) + return failure(); + RankedTensorType tensorType = viewOp.getType(); + if (auto blocked = + mlir::dyn_cast(tensorType.getEncoding())) { + // If the layout already has all the elements along the reduction + // dimension in the same thread we can skip. + if (blocked.getThreadsPerWarp()[*reductionAxis] == 1 && + blocked.getWarpsPerCTA()[*reductionAxis] == 1 && + blocked.getCTAsPerCGA()[*reductionAxis] == 1) + return failure(); + } + ArrayRef shape = tensorType.getShape(); + SmallVector order; + for (int i : triton::gpu::getOrder(tensorType)) { + if (i != *reductionAxis) + order.push_back(i); + } + // Make the reduction axis last so that elements won't be distributed + // amongst threads along this dimension. + order.push_back(*reductionAxis); + SmallVector sizePerThread(shape.size(), 1); + auto mod = viewOp->getParentOfType(); + int numWarps = lookupNumWarps(viewOp); + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + int numCTAs = TritonGPUDialect::getNumCTAs(mod); + auto encoding = + BlockedEncodingAttr::get(viewOp.getContext(), shape, sizePerThread, + order, numWarps, threadsPerWarp, numCTAs); + if (encoding == tensorType.getEncoding()) + return failure(); + RankedTensorType newType = + RankedTensorType::get(shape, tensorType.getElementType(), encoding); + if (triton::gpu::isExpensiveView(viewOp.getSrc().getType(), newType)) + return failure(); + rewriter.setInsertionPointAfter(viewOp); + rewriter.modifyOpInPlace(viewOp, [&]() { + viewOp.getResult().setType(newType); + viewOp.setEfficientLayout(true); + }); + auto cvt = rewriter.create(viewOp.getLoc(), tensorType, + viewOp.getResult()); + rewriter.replaceAllUsesExcept(viewOp.getResult(), cvt.getResult(), cvt); + return success(); + } +}; +} // namespace + +static RankedTensorType replaceEncoding(RankedTensorType oldType, + Attribute newEncoding) { + return RankedTensorType::get(oldType.getShape(), oldType.getElementType(), + newEncoding); +} + +// This function considers a gather op in isolation and attempts to determine +// whether an optimized layout can be applied to the source and index tensors. +static void setOptimizedGatherLayout(GatherOp op, RewriterBase &b) { + RankedTensorType srcType = op.getSrc().getType(); + RankedTensorType idxType = op.getIndices().getType(); + + // Determine a warp-local gather layout that minimizes the number of emitted + // warp shuffles. + unsigned numThreadsPerWarp = + product(triton::gpu::getThreadsPerWarp(srcType.getEncoding())); + unsigned numWarps = + product(triton::gpu::getWarpsPerCTA(srcType.getEncoding())); + + // If in a gather column, each thread owns `srcSizePerThread[axis]` elements + // in the source tensor and `idxSizePerThread[axis]` elements in the index + // tensor (including broadcasting), then the number of index shuffles per + // column is `srcSizePerThread[axis] * idxSizePerThread[axis]`. This is then + // replicated over the number of columns in which a thread owns (an equal + // number of) elements, which is `product(srcSizePerThread[i] for i != axis)`. + // + // Thus, the total number of index shuffles is `product(srcSizePerThread) * + // idxSizePerThread[axis]`. Since we cannot alter the number of threads per + // warp or the number of warps, `product(srcSizePerThread)` is just a function + // of the shape. + // + // So we want to minimize `idxSizePerThread[axis]`. Note that broadcasting is + // forbidden in the source tensor but allowed in the index tensor. Choose the + // smallest value while still ensuring that a warp spans whole columns. + // + // In order to prevent broadcasting in the source tensor layout, ensure + // + // sizePerThread(i) * threadsPerWarp(i) * warpsPerCTA(i) = shape(i) + // + // For all i != axis in the source tensor. The same relationship must hold for + // the index tensor. This means we can't just set `idxSizePerThread[axis]` to + // 1 and compute the rest from that. Find the smallest value where this + // relationship is still respected. + + // We know that the layouts will be the same between the two tensors except + // for `sizePerThread[axis]`. + unsigned axis = op.getAxis(); + unsigned rank = srcType.getRank(); + SmallVector threadsPerWarp(rank); + SmallVector warpsPerCTA(rank); + SmallVector order; + order.push_back(axis); + + // Minimize `sizePerThread[axis]` by putting as many theads along the axis as + // possible, limited to the actual size of the dimension. + unsigned maxThreadsInAxis = + std::min(srcType.getDimSize(axis), numThreadsPerWarp); + threadsPerWarp[axis] = maxThreadsInAxis; + + // Now spread them along the other dimensions. Do this according to order + // (arbitrary). + unsigned threadsToAlloc = numThreadsPerWarp / maxThreadsInAxis; + for (unsigned dim : getThreadOrder(srcType)) { + if (dim == axis) + continue; + // The gather axis is now the fastest-changing dimension. + order.push_back(dim); + unsigned nextThreadAlloc = + std::min(srcType.getDimSize(dim), threadsToAlloc); + threadsPerWarp[dim] = nextThreadAlloc; + threadsToAlloc /= nextThreadAlloc; + } + assert(llvm::none_of(threadsPerWarp, [](unsigned c) { return c == 0; })); + + // There must be one warp along the gather axis. + warpsPerCTA[axis] = 1; + // Allocate the remaining warps in the same manner. + unsigned warpsToAlloc = numWarps; + for (unsigned dim : getWarpOrder(srcType)) { + if (dim == axis) + continue; + unsigned warpsCanFit = srcType.getDimSize(dim) / threadsPerWarp[dim]; + assert(warpsCanFit != 0); + unsigned nextWarpAlloc = std::min(warpsCanFit, warpsToAlloc); + warpsPerCTA[dim] = nextWarpAlloc; + warpsToAlloc /= nextWarpAlloc; + } + assert(llvm::none_of(warpsPerCTA, [](unsigned c) { return c == 0; })); + + // Just set `sizePerThread` to 1 along other dimensions and let broadcasting + // handling it. This also means we can use the same layout between the source + // and index tensors for simplicity. + SmallVector sizePerThread(rank, 1); + sizePerThread[axis] = srcType.getDimSize(axis) / threadsPerWarp[axis]; + + // Overflow by broadcasting along the gather axis since this is the most + // predictable. + threadsPerWarp[axis] *= threadsToAlloc; + warpsPerCTA[axis] *= warpsToAlloc; + + assert(product(threadsPerWarp) == numThreadsPerWarp); + assert(product(warpsPerCTA) == numWarps); + + // Construct the new layout. + MLIRContext *ctx = srcType.getContext(); + auto baseLayout = cast(srcType.getEncoding()); + auto ctaLayout = + CTALayoutAttr::get(ctx, baseLayout.getCTAsPerCGA(), + baseLayout.getCTASplitNum(), baseLayout.getCTAOrder()); + auto newLayout = BlockedEncodingAttr::get(ctx, sizePerThread, threadsPerWarp, + warpsPerCTA, order, ctaLayout); + + // Update the layout on the gather op and insert conversions. + auto cvtSrc = b.create( + op.getLoc(), replaceEncoding(srcType, newLayout), op.getSrc()); + auto cvtIdx = b.create( + op.getLoc(), replaceEncoding(idxType, newLayout), op.getIndices()); + + b.setInsertionPointAfter(op); + auto cvtOut = + b.create(op.getLoc(), op.getType(), op.getResult()); + b.replaceAllUsesExcept(op.getResult(), cvtOut, cvtOut); + + b.modifyOpInPlace(op, [&] { + op.getSrcMutable().set(cvtSrc); + op.getIndicesMutable().set(cvtIdx); + op.getResult().setType(replaceEncoding(op.getType(), newLayout)); + + // Mark the layout as optimized on the op to prevent it from being changed. + op.setEfficientLayout(true); + }); + + // Make sure we did this right. + assert(GatherLoweringHelper(op).isWarpLocal()); +} + +namespace { +struct OptimizeGatherLayoutPattern : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GatherOp op, + PatternRewriter &rewriter) const override { + if (op.getEfficientLayout()) + return failure(); + setOptimizedGatherLayout(op, rewriter); + return success(); + } +}; +} // namespace + +namespace { +class TritonGPUOptimizeThreadLocalityPass + : public impl::TritonGPUOptimizeThreadLocalityBase< + TritonGPUOptimizeThreadLocalityPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // First try to optimize the layout of views and gathers. + mlir::RewritePatternSet layoutPatterns(&getContext()); + layoutPatterns.add(&getContext()); + layoutPatterns.add(&getContext()); + if (mlir::applyPatternsGreedily(mod, std::move(layoutPatterns)).failed()) { + signalPassFailure(); + } + + DenseSet reduceOps; + mod.walk([&](triton::ReduceOp reduce) -> void { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto reductionOp = getReductionOp(reduce); + if (!reductionOp || + !isa( + reductionOp.value())) + return; + // TODO: relax this restriction + if (!(isa(srcEncoding) && rank > 1)) + return; + // The code currently assumes that the reduction is happening on the most + // inner dim. + if (reduce.getAxis() != rank - 1) + return; + for (auto operand : reduce->getOperands()) { + if (!operand.getDefiningOp()) + return; + } + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + // Not worth applying this optimization if there is only one element per + // thread on the reduction axis + if (elemsPerThread == 1) + return; + if (!reduce->hasOneUse()) + return; + Operation *user = *(reduce->getUsers().begin()); + if (!user->hasOneUse()) + return; + OpOperand &yieldOpOperand = *(user->getUses().begin()); + auto yieldOp = dyn_cast(yieldOpOperand.getOwner()); + if (!yieldOp) + return; + auto operandNumber = yieldOpOperand.getOperandNumber(); + Block *block = reduce->getBlock(); + Operation *parentOp = block->getParentOp(); + auto forOp = dyn_cast(parentOp); + if (!forOp) + return; + auto argNum = yieldOpOperand.getOperandNumber(); + auto oldAccum = forOp.getInitArgs()[argNum]; + auto cstOp = oldAccum.getDefiningOp(); + if (!cstOp) + return; + reduceOps.insert(reduce); + }); + + IRRewriter builder(&getContext()); + for (auto reduce : reduceOps) { + builder.setInsertionPoint(reduce); + auto srcType = cast(reduce.getOperands()[0].getType()); + auto srcShape = srcType.getShape(); + auto srcEncoding = srcType.getEncoding(); + assert(isa(srcEncoding) && + "Thread locality optimization only supports blocked encoding"); + auto blocked = dyn_cast(srcEncoding); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto rank = srcShape.size(); + // create new layouts + auto blocked3d = getThreadLocalityOptimizedEncoding(reduce); + auto viewOpTensorShape = getThreadLocalityOptimizedShape(reduce); + auto viewOpTensorType = RankedTensorType::get( + viewOpTensorShape, srcType.getElementType(), blocked3d); + auto slice2d = triton::gpu::SliceEncodingAttr::get(mod.getContext(), rank, + blocked3d); + // Get forOp + assert(reduce->hasOneUse()); + OpOperand &use = *(reduce->getUses().begin()); + auto operandNumber = use.getOperandNumber(); + auto oldUpdate = use.getOwner(); + assert(oldUpdate->getNumOperands() == 2); + auto accumOperandNumber = (operandNumber == 0) ? 1 : 0; + auto accumOperand = oldUpdate->getOperand(accumOperandNumber); + assert(isa(accumOperand)); + auto blockArg = dyn_cast(accumOperand); + auto blockArgNum = blockArg.getArgNumber(); + auto forOp = dyn_cast(blockArg.getOwner()->getParentOp()); + // get oldAccum + auto oldAccum = + forOp.getInitArgs()[blockArgNum - forOp.getNumInductionVars()]; + // get old loop user + Value loopResult = + forOp.getResult(blockArgNum - forOp.getNumInductionVars()); + assert(loopResult.hasOneUse()); + OpOperand &loopUse = *(loopResult.getUses().begin()); + Operation *loopUser = loopUse.getOwner(); + // get old loop yield + auto oldYield = cast(forOp.getBody()->getTerminator()); + // create newAccum initialization + auto newAccum = + createAccum(builder, reduce, oldAccum, viewOpTensorShape, slice2d); + // create new loop by copying the old for op signature and appending + // newAccum to the block arguments + auto newLoop = replaceForOpWithNewSignature( + builder, forOp, ValueRange{newAccum->getResult(0)}); + // create thread local reduction (also adds viewOps) + auto newReduce = createReduce(builder, reduce, viewOpTensorType); + + // create new accum update + auto newUpdate = createUpdate(builder, newLoop, newReduce, oldUpdate); + // create new yield + auto newYield = createYield(builder, newLoop, oldYield, + newUpdate->getResult(0), blockArgNum); + // create post loop reduction on the original reduce axis + auto newReduce2 = createPostLoopReduce(builder, newLoop, reduce); + // add convert_layout to get back to original layout, the result layout + // should now match the layout of the old accumulator (%cst) + Type destType = loopResult.getType(); + auto cvtLayout = createConvertLayout(builder, destType, newReduce2); + // incorporate the original accumulator value into the final result + auto finalOp = incorporateOriginalAccumulatorValue(builder, oldUpdate, + cvtLayout, oldAccum); + // Replace the old loop user with the final result + loopUser->setOperand(loopUse.getOperandNumber(), finalOp->getResult(0)); + + // cleanup + oldYield.erase(); + forOp.erase(); + } + }; + +private: + std::optional getReductionOp(triton::ReduceOp reduce) const { + auto numRegions = reduce->getNumRegions(); + if (numRegions != 1) + return std::nullopt; + Region ®ion = reduce->getRegion(0); + auto numBlocks = region.getBlocks().size(); + if (numBlocks != 1) + return std::nullopt; + Block &block = region.front(); + auto blockWithoutTerminator = block.without_terminator(); + auto blockSizeWithoutTerminator = std::distance( + blockWithoutTerminator.begin(), blockWithoutTerminator.end()); + if (blockSizeWithoutTerminator != 1) + return std::nullopt; + Operation *op = &block.front(); + return std::optional(op); + } + Operation *incorporateOriginalAccumulatorValue(OpBuilder &builder, + Operation *oldUpdate, + Operation *cvtLayout, + Value oldAccum) const { + builder.setInsertionPointAfter(cvtLayout); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), oldAccum); + mapping.map(oldUpdate->getOperand(1), cvtLayout->getResult(0)); + auto finalOp = cloneWithInferType(builder, &(*oldUpdate), mapping); + return finalOp; + } + Operation *createConvertLayout(OpBuilder &builder, Type destType, + Operation *newReduce) const { + builder.setInsertionPointAfter(newReduce); + auto newCvt = builder.create( + newReduce->getLoc(), destType, newReduce->getResult(0)); + return newCvt; + } + + Operation *createPostLoopReduce(OpBuilder &builder, scf::ForOp &loop, + triton::ReduceOp &reduce) const { + auto resultIndex = + loop.getBody()->getNumArguments() - 1 - loop.getNumInductionVars(); + auto newLoopResult = loop.getResult(resultIndex); + builder.setInsertionPointAfter(loop); + IRMapping mapping; + mapping.map(*(reduce.getOperands().begin()), newLoopResult); + auto newReduce2 = cloneWithInferType(builder, &(*reduce), mapping); + return newReduce2; + } + + Operation *createYield(OpBuilder &builder, scf::ForOp &loop, + scf::YieldOp &oldYield, Value newUpdate, + int oldAccumBlockArgNum) const { + builder.setInsertionPoint(oldYield); + SmallVector yieldValues = llvm::to_vector(oldYield.getOperands()); + yieldValues[oldAccumBlockArgNum - 1] = + loop.getBody()->getArgument(oldAccumBlockArgNum); + yieldValues.push_back(newUpdate); + auto newYield = + builder.create(oldYield.getLoc(), yieldValues); + return newYield; + } + + Operation *createUpdate(OpBuilder &builder, scf::ForOp &loop, + Operation *newReduce, Operation *oldUpdate) const { + auto blockArgNum = loop.getBody()->getNumArguments() - 1; + auto newArg = loop.getBody()->getArgument(blockArgNum); + builder.setInsertionPointAfter(newReduce); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), newArg); + mapping.map(oldUpdate->getOperand(1), newReduce->getResult(0)); + auto newUpdate = cloneWithInferType(builder, oldUpdate, mapping); + return newUpdate; + } + + Operation *createReduce(OpBuilder &builder, triton::ReduceOp reduce, + Type viewOpTensorType) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + builder.setInsertionPointAfter(reduce); + IRMapping mapping; + for (auto operand : reduce.getOperands()) { + auto viewOp = builder.create( + reduce.getLoc(), viewOpTensorType, operand, + /*allowReorder=*/true, /*efficientLayout=*/true); + mapping.map(operand, viewOp); + } + + auto newReduce = cloneWithInferType(builder, &(*reduce), mapping); + newReduce->setAttr("axis", builder.getI32IntegerAttr(rank)); + auto typeInfer = dyn_cast(newReduce); + if (typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newReduce->getContext(), newReduce->getLoc(), + newReduce->getOperands(), newReduce->getAttrDictionary(), + newReduce->getPropertiesStorage(), newReduce->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newReduce->getResult(i).setType(newTypes[i]); + } + } + return newReduce; + } + + // Work around the lack of support for MaxNumFOp and MinNumFOp in + // arith::getNeutralElement. + std::optional getNeutralElement(Operation *op) const { + if (isa(op)) { + OpBuilder builder(op->getContext()); + + Type resultType = op->getResult(0).getType(); + const llvm::fltSemantics &semantic = + llvm::cast(resultType).getFloatSemantics(); + if (isa(op)) { + return builder.getFloatAttr( + resultType, APFloat::getInf(semantic, /*Negative=*/true)); + } + if (isa(op)) { + return builder.getFloatAttr( + resultType, APFloat::getInf(semantic, /*Negative=*/false)); + } + } else { + return mlir::arith::getNeutralElement(op); + } + llvm_unreachable("Unhandled reduction op"); + return std::nullopt; + } + + Operation *createAccum(OpBuilder &builder, triton::ReduceOp reduce, + Value &oldAccum, SmallVector &shape, + Attribute &slice2d) const { + // Drop the last dimension (thread locality dimension) + SmallVector accumShape(shape.begin(), shape.end() - 1); + auto elemType = cast(oldAccum.getType()).getElementType(); + // Create tensor type for the new accumulator + auto accumType = RankedTensorType::get(accumShape, elemType, slice2d); + // Create new accumulator + builder.setInsertionPointAfter(oldAccum.getDefiningOp()); + auto reductionOp = getReductionOp(reduce); + assert(reductionOp && "Processing a reduce that is not supported!"); + auto neutralVal = getNeutralElement(reductionOp.value()); + assert(neutralVal && "Could not find neutral value for reduction op!"); + auto denseAttr = DenseElementsAttr::get(accumType, neutralVal.value()); + auto newAccum = builder.create(oldAccum.getLoc(), + accumType, denseAttr); + return newAccum; + } + + SmallVector + getThreadLocalityOptimizedShape(triton::ReduceOp reduce) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto srcShape = srcType.getShape(); + auto rank = srcShape.size(); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto viewOpTensorShape = insertValue(srcShape, rank, 1); + viewOpTensorShape[reduce.getAxis()] /= elemsPerThread; + viewOpTensorShape[rank] = elemsPerThread; + return viewOpTensorShape; + } + + BlockedEncodingAttr + getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto blocked = dyn_cast(srcEncoding); + auto sizePerThread3d = + insertValue(blocked.getSizePerThread(), rank, + blocked.getSizePerThread()[reduce.getAxis()]); + sizePerThread3d[reduce.getAxis()] = 1; + auto threadsPerWarp3d = insertValue(blocked.getThreadsPerWarp(), rank, 1); + auto warsPerCTA3d = insertValue(blocked.getWarpsPerCTA(), rank, 1); + auto order3d = insertValue(blocked.getOrder(), 0, rank); + auto ctasPerCGA3d = + insertValue(blocked.getCTALayout().getCTAsPerCGA(), rank, 1); + auto ctasSplitNum3d = + insertValue(blocked.getCTALayout().getCTASplitNum(), rank, 1); + auto ctaOrder3d = + insertValue(blocked.getCTALayout().getCTAOrder(), rank, rank); + auto ctaLayout3d = triton::gpu::CTALayoutAttr::get( + reduce.getContext(), ctasPerCGA3d, ctasSplitNum3d, ctaOrder3d); + auto blocked3d = triton::gpu::BlockedEncodingAttr::get( + reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warsPerCTA3d, + order3d, ctaLayout3d); + return blocked3d; + } + + template + SmallVector insertValue(ArrayRef vec, unsigned index, int value) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + index, static_cast(value)); + return res; + } + template + SmallVector insertValue(const SmallVector &vec, unsigned index, + int value) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + index, static_cast(value)); + return res; + } +}; +} // namespace + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/PingPong.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/PingPong.cpp new file mode 100644 index 000000000..419461ef6 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/PingPong.cpp @@ -0,0 +1,320 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include + +#define DEBUG_TYPE "triton-ping-pong-sync" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace tt = mlir::triton; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +// Returns the taskId if op has a single taskId, otherwise, returns -1. +static int getSingleTaskId(Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + if (asyncTasks.size() != 1) + return -1; + return asyncTasks[0]; +} + +// Treat exp2, mulf, addf, reduce as expensive computation when data type is +// a tensor type of 1D or higher. +static bool isExpensiveComp(Operation *op) { + if (!isa(op) && !isa(op) && + !isa(op) && !isa(op)) + return false; + auto tensorTy = dyn_cast(op->getOperand(0).getType()); + return tensorTy && tensorTy.getRank() >= 1; +} + +static Value createGetAsyncTaskId(OpBuilder &builder, Operation *op) { + auto loc = op->getLoc(); + return builder.create(loc); +} + +static bool isInnermostLoop(scf::ForOp forOp) { + for (Operation &nestedOp : forOp.getBody()->getOperations()) { + if (isa(nestedOp)) { + return false; + } + } + return true; +} + +#define GEN_PASS_DEF_TRITONGPUPINGPONGSYNC +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUPingPongSyncPass + : public impl::TritonGPUPingPongSyncBase { +public: + using impl::TritonGPUPingPongSyncBase< + TritonGPUPingPongSyncPass>::TritonGPUPingPongSyncBase; + + enum class ResourceType { + Gemm, + OtherComp, + }; + const int PING_BARRIER = 9; + const int PONG_BARRIER = 10; + + unsigned getLoopDepth(Operation *op) { + unsigned depth = 0; + auto pOp = op->getParentOfType(); + while (pOp) { + ++depth; + pOp = pOp->getParentOfType(); + } + return depth; + } + + void + getNestedFor(scf::IfOp ifOp, + DenseMap> &loopDepthMap) { + ifOp->walk([&](Operation *subOp) { + if (dyn_cast(subOp)) { + unsigned tDepth = getLoopDepth(subOp); + loopDepthMap[tDepth].push_back(subOp); + } + }); + } + + Operation *moveBackward(Operation *endofGemm, scf::ForOp forOp) { + SmallVector opList; + for (auto &op : forOp.getBody()->without_terminator()) { + opList.push_back(&op); + } + bool found = false; + Operation *newEnd = endofGemm; + for (auto it = opList.rbegin(); it != opList.rend(); ++it) { + Operation *op = *it; + if (op == endofGemm) { + found = true; + continue; + } + if (found && isa(op)) { + break; + } + if (found) + newEnd = op; + } + return newEnd; + } + + bool categorizeIf(scf::IfOp ifOp, bool &hasDot, bool &hasExpCudaOp) { + hasDot = false; + hasExpCudaOp = false; + bool hasFor = false; + ifOp->walk([&](Operation *subOp) { + LLVM_DEBUG({ + LDBG("walk if"); + subOp->dump(); + }); + if (isa(subOp)) { + hasFor = true; + } else if (isa(subOp)) { + hasDot = true; + } else if (isExpensiveComp(subOp)) { + hasExpCudaOp = true; + } + LDBG("---- " << hasDot << " " << hasExpCudaOp << " " << hasFor); + }); + LDBG("after walk if " << hasDot << " " << hasExpCudaOp << " " << hasFor); + return hasFor; + } + void runOnFuncOp(triton::FuncOp funcOp) { + // Insert sync points in ForOp for consumer warp groups. Enable this pass + // when number of consumer warp groups == 2. + if (numConsumerGroups != 2) + return; + if (!mlir::triton::tools::getBoolEnv("ENABLE_PINGPONG")) + return; + + SmallVector loops; + // Identify ForOps for consumer warp groups. Here we assume taskId 0 is for + // producer. This pass handles the case of a single forOp for two consumer + // warp groups. + // Find top-most IfOps, and find the top level ForOp, assuming only one top + // level ForOp. A few use cases: 1> Persistent with ForOp containing both + // cuda and tensor 2> Persistent with ForOp containing tensor, then epilogue + // with cuda + DenseMap> loopDepthMap1; + DenseMap> loopDepthMap2; + for (auto &block : funcOp.getBody().getBlocks()) { + for (Operation &bodyOp : block.getOperations()) { + Operation *op = &bodyOp; + if (auto ifOp = dyn_cast(op)) { + int wgId = getSingleTaskId(op); + // Assume taskId 0 is for producer. Assume we will visit taskId 1 + // first. + if (wgId == 1) { + getNestedFor(ifOp, loopDepthMap1); + } else if (wgId == 2) { + getNestedFor(ifOp, loopDepthMap2); + } + } + } + } + // Verify loopDepthMap1 and loopDepthMap2: a single ForOp at depth of 0 + // and a single ForOp at depth of 1. + LDBG("Found loops: " << loopDepthMap1.size()); + if (loopDepthMap1.empty()) + return; + if (loopDepthMap1[0].size() != 1) + return; + bool hasPersistent = loopDepthMap1.find(1) != loopDepthMap1.end(); + + // Assume two loops have the same ops. Check innermost loop. + SmallVector starts, ends; + for (unsigned iter = 0; iter < 2; ++iter) { + Operation *op = iter == 0 ? loopDepthMap1[hasPersistent ? 1 : 0][0] + : loopDepthMap2[hasPersistent ? 1 : 0][0]; + auto forOp = dyn_cast(op); + Operation *startOfGemm = nullptr; + Operation *endOfGemm = nullptr; + // A simple heuristic for now: + // Mark the start of a gemm section when hitting a DotLike op. + // Mark the end of a gemm section once hitting an expensive non-dot + // computation op. + for (auto &op : forOp.getBody()->without_terminator()) { + if (startOfGemm && endOfGemm) + break; + bool hasDot, isCudaCore; + bool hasError = false; + if (auto ifOp = dyn_cast(op)) { + // if containing expensive cuda core op + hasError = categorizeIf(ifOp, hasDot, isCudaCore); + } else { + LLVM_DEBUG({ + LDBG("walk for"); + op.dump(); + }); + hasDot = isa(op); + // hasDot = isa(&op); + isCudaCore = isExpensiveComp(&op); + LDBG("walk for " << hasDot << " " << isCudaCore); + } + if (hasError || (hasDot && isCudaCore)) + break; + if (hasDot && !isCudaCore && startOfGemm == nullptr) { + startOfGemm = &op; + continue; + } + if (!hasDot && isCudaCore && startOfGemm) { + endOfGemm = &op; + break; + } + } + if (startOfGemm) { + LLVM_DEBUG({ + LDBG("found start of tensor core ops"); + startOfGemm->dump(); + }); + } + if (endOfGemm) { + LLVM_DEBUG({ + LDBG("found end of tensor core ops"); + endOfGemm->dump(); + }); + } + + if (!startOfGemm || !endOfGemm) + return; + starts.push_back(startOfGemm); + ends.push_back(endOfGemm); + } + // TODO: epilogue overlapping. + { + // "bar.arrive 9, 256" only when task Id is 2. + Operation *outerLoopTask2 = loopDepthMap2[0][0]; + OpBuilder builder(outerLoopTask2); + builder.setInsertionPoint(outerLoopTask2); + auto forLoc = outerLoopTask2->getLoc(); + Value pingBarrier = + builder.create(forLoc, PING_BARRIER, 32); + Value numThreads = builder.create(forLoc, 256, 32); + builder.create(forLoc, pingBarrier, + numThreads); + } + for (unsigned idx = 0; idx < 2; ++idx) { + Operation *op = idx == 0 ? loopDepthMap1[hasPersistent ? 1 : 0][0] + : loopDepthMap2[hasPersistent ? 1 : 0][0]; + auto forOp = dyn_cast(op); + OpBuilder builder(forOp); + Operation *startOfGemm = starts[idx]; + Operation *endOfGemm = ends[idx]; + + // FIXME: hard-code using named barrier 9 and 10 in this pass. + // Prior to the forOp, add "bar.arrive 9, 256" only when task Id is 2. + // At startOfGemm, insert "bar.sync 8+taskId, 256" + // At endOfGemm, insert "bar.arrive 11-taskId, 256" + builder.setInsertionPoint(forOp); + auto forLoc = forOp->getLoc(); + + // FIXME: hard-code total number of threads to be 256 when + // numConsumerGroups is 2. + Value numThreads = builder.create(forLoc, 256, 32); + // for taskId of 1, generate: bar.sync pingBarrier; bar.arrive pongBarrier + // for taskId of 2, outside of the loop, generate bar.arrive pingBarrier + // inside the loop, generate bar.sync pongBarrier; bar.arrive + // pingBarrier + Value pingBarrier = + builder.create(forLoc, PING_BARRIER, 32); + + int wgId = getSingleTaskId(forOp); + // At startOfGemm, insert "bar.sync 9 or 10, 256" + builder.setInsertionPoint(startOfGemm); + auto loc = startOfGemm->getLoc(); + Value syncBarrier = builder.create( + loc, wgId == 1 ? PING_BARRIER : PONG_BARRIER, 32); + builder.create(loc, syncBarrier, numThreads); + + // At endOfGemm, insert "bar.arrive 10 or 9, 256" + Operation *insertBefore = endOfGemm; + insertBefore = moveBackward(endOfGemm, forOp); + builder.setInsertionPoint(insertBefore); + auto loc2 = endOfGemm->getLoc(); + Value arriveBarrier = builder.create( + loc2, wgId == 1 ? PONG_BARRIER : PING_BARRIER, 32); + builder.create(loc2, arriveBarrier, + numThreads); + } + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + LLVM_DEBUG({ + LDBG("post pass"); + getOperation()->dump(); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp new file mode 100644 index 000000000..424d76501 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp @@ -0,0 +1,286 @@ +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-loop-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +// Return true if the preconditions for pipelining the loop are met. +bool preCondition(scf::ForOp forOp) { + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (loopHasDistGreaterThanOne(forOp)) + return false; + // Don't pipeline outer loops. + if (isOuterLoop(forOp)) + return false; + return true; +} + +bool canHaveSharedEncoding(tt::LoadOp op) { + // If used by an user with DotOp encoding, all the uses must be compatible. + bool incompatible = false; + getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatible); + return !incompatible; +} + +bool isSmallLoad(tt::LoadOp loadOp, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return true; + auto ty = cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16. + // 2. It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. + LDBG("Load " << *loadOp << " has width " << width); + return width < 32; +} + +bool isPipeliningBeneficial(Operation *op, Operation *finalUser, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + if (auto loadOp = dyn_cast(op)) { + if (isSmallLoad(loadOp, axisInfoAnalysis)) { + LDBG("Load " << *loadOp << " is too small for pipelining"); + return false; + } + } + if (isa( + op)) + return true; + if (!canHaveSharedEncoding(cast(op))) { + LDBG("Load " << *op << " cannot have shared encoding"); + return false; + } + + ttg::SharedEncodingTrait localAllocEnc; + if (llvm::any_of(op->getUsers(), [&](Operation *user) { + return isa(user); + })) { + for (auto user : op->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + auto enc = mlir::cast( + localAlloc.getType().getEncoding()); + if (!localAllocEnc) { + localAllocEnc = enc; + } + if (enc != localAllocEnc) { + // If the load is used by a LocalAllocOp, all the users need to have the + // same encoding. + return false; + } + } + } + + if (localAllocEnc) { + auto registerTy = cast(op->getResultTypes()[0]); + auto vecBytes = getCopyVecBytes(registerTy, localAllocEnc); + if (vecBytes < 4) { + // At least 4 bytes need to be consecutive for cp.async + return false; + } + } + + return true; +} + +// Create a map from load ops to their indirection level and the +// final use of the load op (another load op, or a dot op). +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +llvm::MapVector +loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + llvm::MapVector loadOpToIndLevel; + DenseSet seen; + DenseSet excluded; + + std::function dfs = + [&](Operation *op, Operation *finalUser, int distance) { + if (!seen.insert(op).second || excluded.count(op)) + return; + if (isa(op)) { + if (!isPipeliningBeneficial(op, finalUser, axisInfoAnalysis)) + return; + if (loadOpToIndLevel.count(op)) { + int level = loadOpToIndLevel[op]; + if (level != distance) { + // If we have multiple uses at different distances, we don't know + // which one to pick. + LDBG("Load " << *op + << " has multiple uses at different distances:" + << level << " and " << distance); + loadOpToIndLevel.erase(op); + excluded.insert(op); + return; + } + } else { + LDBG("Load " << *op << " considered for pipelining with distance " + << distance); + loadOpToIndLevel[op] = distance; + } + finalUser = op; + distance++; + } + for (Value operand : getNestedOperands(op)) { + if (isa(op)) { + // Heuristic: only pipeline A and B operands of the dot op. + if (operand == op->getOperand(2)) + continue; + } + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, finalUser, distance); + } + } + if (auto tmemAlloc = dyn_cast(op)) { + if (!tmemAlloc.getSrc()) { + for (auto user : tmemAlloc.getResult().getUsers()) { + if (auto tmemCopy = dyn_cast(user)) { + dfs(tmemCopy.getSrc().getDefiningOp(), finalUser, distance); + break; + } + } + } + } + }; + + bool seenDot = false; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + continue; + seenDot = true; + seen.clear(); + dfs(&op, &op, 0); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (pipelineWithoutDot && !seenDot) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, &op, 0); + } + } + + return loadOpToIndLevel; +} + +bool hasLatenciesAssigned(scf::ForOp forOp) { + for (auto &op : forOp.getBody()->without_terminator()) { + if (op.hasAttr("tt_latency")) + return true; + } + return false; +} + +void assignUserProvidedLatencies(scf::ForOp forOp, + DenseMap &opLatency) { + for (auto &op : forOp.getBody()->without_terminator()) { + if (auto latencyAttr = op.getAttr("tt_latency")) { + opLatency[&op] = mlir::cast(latencyAttr).getInt(); + } + } +} + +} // namespace + +// Look for load ops that directly or indirectly feed into dot ops. Based +// on the requested number of stages assign the latencies in a way that +// cover all the stages with the sum of latencies in the chain from the first +// load to the final dot op. +void assignLatencies(ModuleOp moduleOp, int defaultNumStages) { + auto getNumStagesOrDefault = [defaultNumStages](scf::ForOp forOp) -> int { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) + return defaultNumStages; + return mlir::cast( + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); + }; + + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (preCondition(forOp) && getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + if (loops.empty()) + return; + + DenseMap opLatency; + for (auto forOp : loops) { + if (hasLatenciesAssigned(forOp)) { + assignUserProvidedLatencies(forOp, opLatency); + continue; + } + int numStages = getNumStagesOrDefault(forOp); + bool pipelineWithoutDot = forOp->hasAttr(mlir::triton::kNumStagesAttrName); + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + llvm::MapVector loadOpToIndLevel = + loadOpsToIndirectionLevel(forOp, pipelineWithoutDot, axisInfoAnalysis); + if (loadOpToIndLevel.empty()) + continue; + + // We assume loads with different dist are assigned to different stages. + // If numStages is 2, we will have no stage available for indirect loads + // with dist >= 1. In general, when dist is equal to numStages - 1, we + // should not pipeline it. + for (auto iter = loadOpToIndLevel.begin(); + iter != loadOpToIndLevel.end();) { + if (iter->second >= numStages - 1) + iter = loadOpToIndLevel.erase(iter); + else + ++iter; + } + + // Calculate the stage distance between applicable loads. + auto vals = llvm::make_second_range(loadOpToIndLevel); + int maxIndirectionLevel = vals.empty() ? 0 : *llvm::max_element(vals); + unsigned loadLatency = (numStages - 1) / (maxIndirectionLevel + 1); + + for (auto [loadOp, dist] : loadOpToIndLevel) { + opLatency[loadOp] = loadLatency; + } + } + serializeLatencies(moduleOp, opLatency); +} +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp new file mode 100644 index 000000000..f1dcbc905 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp @@ -0,0 +1,919 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-loop-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +///////////////////////////// +// UTILS +///////////////////////////// + +class OpBuilderForStage : public OpBuilder { + std::optional _stage; + std::optional _cluster; + CoarseSchedule &_schedule; + +public: + explicit OpBuilderForStage(Operation *op, CoarseSchedule &schedule, int stage, + CoarseSchedule::Cluster cluster) + : OpBuilder(op, nullptr), _schedule(schedule), _stage(stage), + _cluster(cluster) {} + explicit OpBuilderForStage(Operation *op, CoarseSchedule &schedule) + : OpBuilder(op, nullptr), _schedule(schedule) { + if (_schedule.count(op)) { + auto sc = _schedule[op]; + _stage = sc.first; + _cluster = sc.second; + } + } + void setStageCluster(std::pair stageCluster) { + _stage = stageCluster.first; + _cluster = stageCluster.second; + } + + template OpTy create(Args &&...args) { + OpTy op = OpBuilder::create(std::forward(args)...); + if (_stage && _cluster) { + _schedule.insert(op, *_stage, *_cluster); + } + return op; + } +}; + +bool isTMALoad(Operation *op) { + return isa(op); +} + +DenseSet getTopLevelUsersInLoop(Operation *op, scf::ForOp forOp) { + DenseSet topLevelUsers; + SmallVector q; + for (auto &use : op->getUses()) + q.push_back(&use); + while (!q.empty()) { + auto use = q.pop_back_val(); + auto yieldOp = dyn_cast(use->getOwner()); + if (yieldOp && yieldOp->getParentOp() == forOp) { + for (auto &use : + forOp.getRegionIterArgs()[use->getOperandNumber()].getUses()) + q.push_back(&use); + continue; + } + Operation *topLevelUser = + forOp.getBody()->findAncestorOpInBlock(*use->getOwner()); + topLevelUsers.insert(topLevelUser); + } + return topLevelUsers; +} + +Operation *getFirstUseOfPipelinedOp(SmallVector ops, + scf::ForOp forOp, + CoarseSchedule &schedule) { + Operation *firstUser = nullptr; + DenseSet topLevelUsers; + for (Operation *op : ops) { + auto users = getTopLevelUsersInLoop(op, forOp); + topLevelUsers.insert(users.begin(), users.end()); + } + for (Operation *topLevelUser : topLevelUsers) { + assert(schedule.count(topLevelUser) && "op user not found in the schedule"); + auto [_useStage, _useCluster] = schedule[topLevelUser]; + if (!firstUser) { + firstUser = topLevelUser; + } else { + auto [_firstUserStage, _firstUserCluster] = schedule[firstUser]; + if (_useStage < _firstUserStage || + (_useStage == _firstUserStage && + schedule.clusters.isBefore(_useCluster, _firstUserCluster))) { + firstUser = topLevelUser; + } + } + } + return firstUser; +} + +int getDefUseStageDiff(Operation *op, scf::ForOp forOp, + CoarseSchedule &schedule) { + assert(schedule.count(op) && "LoadOp not found in the schedule"); + auto [defStage, _] = schedule[op]; + std::optional useStage; + DenseSet topLevelUsers = getTopLevelUsersInLoop(op, forOp); + for (Operation *topLevelUser : topLevelUsers) { + auto [_useStage, _] = schedule[topLevelUser]; + useStage = std::min(_useStage, useStage.value_or(_useStage)); + } + if (!useStage) + return 0; + assert(useStage >= defStage && "LoadOp used before defined"); + return useStage.value() - defStage; +} + +template +Value createIncrementModulo(BuilderT &builder, Location loc, Value counter, + Value modulus, Value zero, Value one, + Value *outCond = nullptr) { + Value addOne = builder.template create(loc, counter, one); + Value inRangeCond = builder.template create( + loc, arith::CmpIPredicate::slt, addOne, modulus); + if (outCond) + *outCond = inRangeCond; + return builder.template create(loc, inRangeCond, addOne, + zero); +} + +///////////////////////////// +// LOWER LOADS +///////////////////////////// + +ttg::SharedEncodingTrait getSharedEncoding(Operation *op) { + // Try to use local alloc encoding if possible. + ttg::SharedEncodingTrait localAllocEnc; + if (llvm::any_of(op->getUsers(), [&](Operation *user) { + return isa(user); + })) { + for (auto user : op->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + auto enc = mlir::cast( + localAlloc.getType().getEncoding()); + if (!localAllocEnc) { + localAllocEnc = enc; + } + if (enc != localAllocEnc) { + // Some users have different encoding than others. + // Use one of the encodings, and warn about the performance issue. + op->emitRemark() + << "Pipelining load with different use encodings. This will lead " + "to layout conversions and performance degradation."; + continue; + } + } + } + + auto ty = cast(op->getResultTypes()[0]); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + auto order = ttg::getOrder(ty); + if (isTMALoad(op)) { + // For TMA, the encoding compatible with it takes precedence over local + // alloc created for the MMA operand. + if (localAllocEnc) { + if (auto sharedMMALayout = + dyn_cast(localAllocEnc)) { + assert(!sharedMMALayout.getFp4Padded() && + "TMA load for mixed precision MMAv5 is not supported yet."); + } + } + return ttg::NVMMASharedEncodingAttr::get( + ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType(), + /*fp4Padded*/ false); + } + + if (localAllocEnc) + return localAllocEnc; + + // Try to use dot encoding if possible. + bool incompatible = false; + localAllocEnc = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible) + .value_or(nullptr); + + if (localAllocEnc) + return localAllocEnc; + + // Use generic layout. This won't be optimal for 2D tensors. + return ttg::SwizzledSharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, + ctaLayout); +} + +// Create an allocation that can hold distance number of loadOp shapes. +static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, + ttg::SharedEncodingTrait sharedEnc, + unsigned distance) { + OpBuilder builder(forOp); + Attribute sharedMemorySpace = + ttg::SharedMemorySpaceAttr::get(forOp.getContext()); + auto ty = cast(loadOp->getResultTypes()[0]); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(), + sharedEnc, sharedMemorySpace, + /*mutableMemory=*/true); + Value alloc = + builder.create(loadOp->getLoc(), memdescType); + + builder.setInsertionPointAfter(forOp); + builder.create(forOp.getLoc(), alloc); + return alloc; +} + +template +Operation *createWithStage(BuilderT &builder, Location loc, int stage, + CoarseSchedule::Cluster cluster, Args &&...args) { + Operation *op = builder.template create( + loc, std::forward(args)...); + + return op; +} + +// Check if the load can be pipelined entirely in shared memory, with user +// consuming directly the shared memory, without going through registers. +bool canBeShmemPipelined(Operation *op) { + if (auto loadOp = dyn_cast(op)) { + // AsyncCopyGlobalToLocalOp does not support the non-zero "other" value. + // With consumer consuming directly the shared memory, there would be no way + // to replace masked values with the "other" value. + if (loadOp.getOther() && !isZeroConst(loadOp.getOther())) + return false; + } + + if (!op->hasOneUse()) + return false; + Operation *user = *op->getUsers().begin(); + if (auto alloc = dyn_cast(user)) { + return isa(alloc.getType().getEncoding()); + } + return false; +} + +void createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, + CoarseSchedule &schedule) { + OpBuilderForStage builder(forOp, schedule); + Value zero = builder.create(forOp.getLoc(), 0, 32); + + Operation *firstUse = getFirstUseOfPipelinedOp({loadOp}, forOp, schedule); + assert(firstUse && "LoadOp has no users"); + // Replace the load with async copy, wait and loal_load. + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(loadOp); + builder.setStageCluster(schedule[loadOp]); + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); + ttg::MemDescType allocTy = cast(alloc.getType()); + + // Create async copy + SmallVector copyOffsets(allocTy.getRank(), zero); + copyOffsets[0] = insertIdx; + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + ttg::MemDescType subviewTy = ttg::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true, + /*allocShape=*/allocTy.getAllocShape()); + auto view = + builder.create(loc, subviewTy, alloc, copyOffsets); + Operation *copy = builder.create( + loc, src, view, mask, other, loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile()); + Operation *commit = + builder.create(loc, copy->getResult(0)); + + // Create wait and local load + builder.setStageCluster(schedule[firstUse]); + Operation *wait = + builder.create(loc, commit->getResult(0), 0); + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + + if (!loadOp.getOther() || isZeroConst(loadOp.getOther())) { + // Remove redundant local_load -> local_alloc, but only if + // we are not using the other value. AsyncCopyGlobalToLocalOp does not + // support the masking. + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto userAlloc = dyn_cast(user)) { + if (allocTy.getEncoding() == userAlloc.getType().getEncoding()) { + tt::replaceUsesAndPropagateType(builder, userAlloc, + viewLoad.getResult()); + allocsToErase.push_back(userAlloc); + } + } + } + for (auto alloc : allocsToErase) { + alloc.erase(); + } + } + + // If there are some uses that were not local_allocs, we need to create a + // local_load for them. + if (loadOp->use_begin() != loadOp->use_end()) { + auto sharedLoad = builder.create( + loc, loadOp.getType(), viewLoad, wait->getResult(0)); + auto result = sharedLoad->getResults(); + + // Create a select for non-zero other values as they are not handled by + // AsyncCopyGlobalToLocalOp for now. + if (other && !isZeroConst(other)) { + auto select = builder.create( + loc, loadOp.getType(), + // Use the mask operand from the original load, not the one with a + // potentially transformed layout. + loadOp.getMask(), sharedLoad.getResult(), other); + result = select->getResults(); + } + loadOp->replaceAllUsesWith(result); + } + schedule.erase(loadOp); + loadOp->erase(); +} + +void createTMAAsyncCopy( + scf::ForOp forOp, Operation *loadOp, Value desc, Value alloc, + Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, + CoarseSchedule &schedule, + function_ref + createCopy) { + OpBuilderForStage builder(forOp, schedule); + Value zero = builder.create(forOp.getLoc(), 0, 32); + + Operation *firstUse = getFirstUseOfPipelinedOp({loadOp}, forOp, schedule); + assert(firstUse && "LoadOp has no users"); + Attribute sharedMemorySpace = + ttg::SharedMemorySpaceAttr::get(forOp.getContext()); + Location loc = loadOp->getLoc(); + + builder.setInsertionPoint(loadOp); + builder.setStageCluster(schedule[loadOp]); + ttg::MemDescType allocTy = cast(alloc.getType()); + + // Create async copy + SmallVector copyOffsets(allocTy.getRank(), zero); + copyOffsets[0] = insertIdx; + ttg::MemDescType subviewTy = ttg::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true, + /*allocShape=*/allocTy.getAllocShape()); + auto view = + builder.create(loc, subviewTy, alloc, copyOffsets); + + Value pred = builder.create(loc, 1, 1); + Value tmaPtr = + builder.create(loc, desc); + createCopy(builder, tmaPtr, barrier, view, pred); + + // Create local load after the wait + builder.setInsertionPointAfter(waitOp); + builder.setStageCluster(schedule[firstUse]); + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + // Remove redundant local_load -> local_alloc + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto userAlloc = dyn_cast(user)) { + if (allocTy.getEncoding() == userAlloc.getType().getEncoding()) { + tt::replaceUsesAndPropagateType(builder, userAlloc, + viewLoad.getResult()); + allocsToErase.push_back(userAlloc); + } + } + } + for (auto alloc : allocsToErase) { + alloc.erase(); + } + + // If there are some uses that were not local_allocs, we need to create a + // local_load for them. + if (loadOp->use_begin() != loadOp->use_end()) { + auto sharedLoad = builder.create( + loc, loadOp->getResultTypes().front(), viewLoad); + auto result = sharedLoad->getResults(); + loadOp->replaceAllUsesWith(result); + } + schedule.erase(loadOp); + loadOp->erase(); +} + +void createTMAAsyncLoad(scf::ForOp forOp, + tt::ExperimentalDescriptorLoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, Value barrier, + Operation *waitOp, CoarseSchedule &schedule) { + return createTMAAsyncCopy(forOp, loadOp, loadOp.getDesc(), alloc, insertIdx, + extractIdx, barrier, waitOp, schedule, + [&](OpBuilderForStage &builder, Value tmaPtr, + Value barrier, Value view, Value pred) { + builder.create( + loadOp.getLoc(), tmaPtr, loadOp.getIndices(), + barrier, view, pred); + }); +} + +void createTMAAsyncGather(scf::ForOp forOp, + tt::ExperimentalDescriptorGatherOp gatherOp, + Value alloc, Value insertIdx, Value extractIdx, + Value barrier, Operation *waitOp, + CoarseSchedule &schedule) { + return createTMAAsyncCopy(forOp, gatherOp, gatherOp.getDesc(), alloc, + insertIdx, extractIdx, barrier, waitOp, schedule, + [&](OpBuilderForStage &builder, Value tmaPtr, + Value barrier, Value view, Value pred) { + builder.create( + gatherOp.getLoc(), tmaPtr, + gatherOp.getXOffsets(), gatherOp.getYOffset(), + barrier, view, pred); + }); +} + +struct AsyncLoad { + int stageDiff; + Value alloc; + Value barrier; + Operation *waitOp; + SharedEncodingTrait sharedEncoding; +}; +struct LoadGroupInfo { + Value insertIdx; + Value extractIdx; + Value phase; + bool hasTMALoad = false; +}; + +void createTMABarrierAndWait( + scf::ForOp forOp, llvm::MapVector &asyncLoads, + const llvm::MapVector &loadGroups, + CoarseSchedule &schedule) { + SmallVector> commonWaitGroups; + llvm::SmallDenseSet visited; + // Find groups of loads that can share the same barrier. We look consecutive + // loads and check that there are uses in between. + for (auto &[loadOp, asyncLoad] : asyncLoads) { + if (!isTMALoad(loadOp) || visited.count(loadOp)) + continue; + llvm::SmallDenseSet users; + SmallVector group; + Block *loadBlock = loadOp->getBlock(); + auto addToGroup = [&](Operation *loadOp) { + group.push_back(loadOp); + visited.insert(loadOp); + for (Operation *user : loadOp->getUsers()) { + // Special case for MMAv3 loads, we can ignore the alloc and only + // consider uses of the alloc op since it will be removed. + if (canBeShmemPipelined(loadOp)) { + auto alloc = cast(*loadOp->getUsers().begin()); + if (alloc->getBlock() == loadBlock) { + users.insert(alloc->getUsers().begin(), alloc->getUsers().end()); + continue; + } + } + Operation *userInBlock = loadBlock->findAncestorOpInBlock(*user); + if (userInBlock) + users.insert(userInBlock); + } + }; + addToGroup(loadOp); + Operation *nextOp = loadOp->getNextNode(); + int numBuffers = asyncLoad.stageDiff; + while (nextOp) { + if (users.count(nextOp) || visited.count(nextOp)) + break; + if (isTMALoad(nextOp) && asyncLoads.count(nextOp)) { + if (asyncLoads[nextOp].stageDiff != numBuffers) + break; + if (group.size() > 0 && schedule[group[0]] == schedule[nextOp]) { + addToGroup(nextOp); + } + } + nextOp = nextOp->getNextNode(); + } + commonWaitGroups.push_back(group); + } + + // For each group calculate the size and insert the barrier after the last + // load. + for (SmallVector &group : commonWaitGroups) { + int sizeInBytes = 0; + int numBuffers = asyncLoads[group[0]].stageDiff; + const LoadGroupInfo loadGroup = loadGroups.find(numBuffers)->second; + for (Operation *op : group) { + auto tensorTy = cast(op->getResultTypes()[0]); + int loadSize = product(tensorTy.getShape()); + sizeInBytes += + loadSize * tensorTy.getElementType().getIntOrFloatBitWidth() / 8; + } + + Value barrierAlloc = triton::createBarrierAlloc(forOp, numBuffers); + Location loc = forOp.getLoc(); + OpBuilderForStage builder(group[0], schedule); + Value barrier = triton::createSingleBufferView(builder, barrierAlloc, + loadGroup.insertIdx); + Value pred = builder.create(loc, 1, 1); + builder.create(loc, barrier, sizeInBytes, pred); + + builder.setInsertionPointAfter(group.back()); + Operation *firstUse = getFirstUseOfPipelinedOp(group, forOp, schedule); + builder.setStageCluster(schedule[firstUse]); + Value barrierViewWait = triton::createSingleBufferView( + builder, barrierAlloc, loadGroup.extractIdx); + auto wait = builder.create(loc, barrierViewWait, + loadGroup.phase); + + // Update the async loads info. + for (Operation *op : group) { + asyncLoads[op].barrier = barrier; + asyncLoads[op].waitOp = wait; + } + + // Invalidate and deallocate barrier + builder.setInsertionPointAfter(forOp); + for (int i = 0; i < numBuffers; i++) { + Value barrierView = + triton::createSingleBufferView(builder, barrierAlloc, i); + builder.create(loc, barrierView); + } + builder.create(loc, barrierAlloc); + } +} + +// Check if load requires additional buffer for a mma pipelining +bool loadRequiresAdditionalBuffer(Operation *loadOp) { + // TODO: Limit the cases to only the wgmma pipelining once mmav5 + // pipelining is integrated with the new pipeliner + if (canBeShmemPipelined(loadOp)) { + return true; + } + // Pattern match the op sequence used for mmav5 scales + if (loadOp->hasOneUse()) { + ttg::LocalAllocOp alloc = + dyn_cast(*loadOp->getUsers().begin()); + if (alloc && alloc->hasOneUse()) { + if (isa(*alloc->getUsers().begin())) { + return true; + } + } + } + return false; +} + +scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) { + llvm::MapVector asyncLoads; + llvm::MapVector loadGroups; + // Only visit the top level ops, we do not support pipelining conditional + // loads for now + for (auto &op : forOp.getBody()->without_terminator()) { + if (isa(op)) { + int stageDiff = getDefUseStageDiff(&op, forOp, schedule); + if (stageDiff == 0 || !isa(op.getResultTypes()[0])) { + // Don't care about non-pipelined loads. Don't use async loads for + // scalar values. + continue; + } + SharedEncodingTrait sharedEncoding = getSharedEncoding(&op); + // Do not create async loads for small loads (cp.async requires at least 4 + // bytes) + int copyVecBytes = getCopyVecBytes( + cast(op.getResultTypes()[0]), sharedEncoding); + if (copyVecBytes >= 4 || isTMALoad(&op)) { + if (loadRequiresAdditionalBuffer(&op)) { + // Allocate additional buffer required by the wgmma pipelining. + stageDiff += 1; + } + // asyncLoads[&op] = {.stageDiff = stageDiff, + // .sharedEncoding = sharedEncoding}; + } else if (stageDiff > 1) { + // Distance-1 loads can in most cases be pipelined in registers without + // any performance degradation, as the schedule will usually reorder the + // user and the producer so there is no liverange overlap, and no copy + // needed. + op.emitRemark() << "Pipelining load that cannot use vectorized " + "copy. This will likely " + "lead to pipelining in registers and severe " + "performance degradation."; + } + } + } + + if (asyncLoads.empty()) + return forOp; + + for (auto &[loadOp, asyncLoad] : asyncLoads) { + Value alloc = createAlloc(forOp, loadOp, asyncLoad.sharedEncoding, + asyncLoad.stageDiff); + asyncLoad.alloc = alloc; + loadGroups.insert({asyncLoad.stageDiff, {}}); + if (isTMALoad(loadOp)) { + loadGroups[asyncLoad.stageDiff].hasTMALoad = true; + } + } + + IRRewriter builder(forOp); + builder.setInsertionPoint(forOp); + Location loc = forOp.getLoc(); + // Create a counter to index into the allocations per loop iteration. + // NOTE: We create two duplicates values, insertIdx and extractIdx so that the + // pipeliner will re-materialize the value in later stages of the pipeline + // instead of carrying it as a dependency across multiple iterations. + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + SmallVector newOperands; + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + for (auto [_, loadGroup] : loadGroups) { + newOperands.push_back(minusOne); // insertIdx + newOperands.push_back(minusOne); // extractIdx + if (loadGroup.hasTMALoad) { + // A single barrier arrival sequence is a "phase" and two phases can + // overlap, provided the phases are differentiated with an alternating + // boolean value. + newOperands.push_back(zero); // phase + } + } + + // Patch the loop to add the new loop carried dependencies. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + forOp.erase(); + forOp = newForOp; + + // Update yield op with temporary yield values + auto forYield = cast(newForOp.getBody()->getTerminator()); + for (unsigned i = 0; i < newOperands.size(); ++i) { + forYield.getResultsMutable().append(newOperands[i]); + } + + builder.setInsertionPoint(forOp); + loc = forOp.getLoc(); + int argIdx = newOperandIndex; + for (auto &[numBuffers, loadGroup] : loadGroups) { + Value insertIdx = newForOp.getBody()->getArgument(argIdx); + argIdx++; + Value extractIdx = newForOp.getBody()->getArgument(argIdx); + argIdx++; + Value phase = nullptr; + if (loadGroup.hasTMALoad) { + phase = newForOp.getBody()->getArgument(argIdx); + argIdx++; + } + + // Create two counters for the insert and extract indices to avoid creating + // long liverange. + builder.setInsertionPoint(forOp.getBody(), forOp.getBody()->begin()); + + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + loadGroup.insertIdx = createIncrementModulo(builder, loc, insertIdx, + numBuffersVal, zero, one); + Value cndExt = nullptr; + loadGroup.extractIdx = createIncrementModulo( + builder, loc, extractIdx, numBuffersVal, zero, one, &cndExt); + if (phase) { + Value nextPhase = builder.create(loc, phase, one); + phase = builder.create(loc, cndExt, phase, nextPhase); + loadGroup.phase = phase; + } + } + + createTMABarrierAndWait(forOp, asyncLoads, loadGroups, schedule); + + for (auto [op, asyncLoad] : asyncLoads) { + auto [insertIdx, extractIdx, phase, _] = loadGroups[asyncLoad.stageDiff]; + if (auto loadOp = dyn_cast(op)) { + createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, + schedule); + } else if (auto loadOp = dyn_cast(op)) { + createTMAAsyncLoad(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, + asyncLoad.barrier, asyncLoad.waitOp, schedule); + } else if (auto loadOp = dyn_cast(op)) { + createTMAAsyncGather(forOp, loadOp, asyncLoad.alloc, insertIdx, + extractIdx, asyncLoad.barrier, asyncLoad.waitOp, + schedule); + } + } + // Patch the yield with the updated counters. Subtract to account for the loop + // counter. + argIdx = newOperandIndex - 1; + for (auto &[numBuffers, loadGroup] : loadGroups) { + forYield.setOperand(argIdx++, loadGroup.insertIdx); + forYield.setOperand(argIdx++, loadGroup.extractIdx); + if (loadGroup.phase) + forYield.setOperand(argIdx++, loadGroup.phase); + } + + // Automatically discover dependencies and schedule new insert/extract ops to + // correct stages. + scheduleDependencies(forOp, schedule); + + // Insert sync point for any possibly outstanding loads after the loop. This + // can happen as we speculatively execute loads in the loop. + builder.setInsertionPointAfter(forOp); + builder.create(loc, ValueRange({}), 0); + + // Make sure all ops have attributes. + for (Operation &op : forOp.getBody()->without_terminator()) { + assert(schedule.count(&op) && "op not found in the schedule"); + } + return forOp; +} + +///////////////////////////// +// LOWER TMA DESCRIPTORS +///////////////////////////// + +LogicalResult +allocTMABuffers(scf::ForOp forOp, + llvm::MapVector &tmaBufferMapping, + int numStages) { + IRRewriter rewriter(forOp); + + // Create a multi-buffered allocation for each MakeTensorDescOp call in the + // loop + forOp.walk([&](tt::MakeTensorDescOp op) { + // TODO peter: walk to loop yield to find the init value if this is a + // loop-carried value. That would save us from allocating another buffer + // just for the init value + auto loc = op.getLoc(); + Value alloc = rewriter.create( + loc, triton::getPointerType(rewriter.getI8Type()), + numStages * ttng::TMA_SIZE_BYTES, ttng::TMA_ALIGN); + tmaBufferMapping[op.getOperation()] = alloc; + }); + return success(); +} + +template +Value subviewTMADescriptor(BuilderT &builder, Location loc, Value alloc, + Value counter) { + Value tmaSizeVal = builder.template create( + loc, ttng::TMA_SIZE_BYTES, 32); + Value offset = + builder.template create(loc, tmaSizeVal, counter); + return builder.template create(loc, alloc.getType(), alloc, + offset); +} + +LogicalResult rewriteTMABufferUpdates( + scf::ForOp forOp, + const llvm::MapVector &tmaBufferMapping, + ArrayRef tmaCounters, int numStages, Value one, Value zero, + CoarseSchedule &schedule) { + assert(tmaBufferMapping.size() == tmaCounters.size()); + + Value numStagesVal = mlir::OpBuilder(forOp).create( + forOp.getLoc(), numStages, 32); + + for (auto [iOp, pair] : llvm::enumerate(tmaBufferMapping)) { + auto &[op, alloc] = pair; + + // Rewriter MakeTensorDescOp as writing a TMA descriptor + auto makeDescOp = cast(op); + + OpBuilderForStage stageBuilder(makeDescOp, schedule); + auto loc = makeDescOp.getLoc(); + + BlockArgument counter = tmaCounters[iOp]; + Value nextBuf = subviewTMADescriptor(stageBuilder, loc, alloc, counter); + if (failed(ttng::createTMADesc(nextBuf, makeDescOp, stageBuilder))) { + return failure(); + } + stageBuilder.create( + loc, nextBuf); + Value nextDesc = stageBuilder.create( + loc, makeDescOp.getType(), nextBuf); + + makeDescOp.getResult().replaceAllUsesWith(nextDesc); + + // Increment the buffer index counter + Value nextCounter = createIncrementModulo(stageBuilder, loc, counter, + numStagesVal, zero, one); + + // If we are in a (potentially nested) if region, propagate the counter + // up to the main for op body scope + Operation *curOp = op; + Operation *parent = op->getParentOp(); + while (parent != forOp.getOperation()) { + auto ifOp = dyn_cast(parent); + if (!ifOp) { + std::string msg; + llvm::raw_string_ostream ss(msg); + ss << "Cannot pipeline MakeTensorDescOp inside:\n"; + parent->print(ss); + ss << "\nOnly scf.if regions are supported"; + return makeDescOp->emitOpError(std::move(msg)); + } + + IRRewriter rewriter(parent); + auto newIfOp = + replaceIfOpWithNewSignature(rewriter, ifOp, {nextCounter.getType()}); + + auto yieldNewBlock = newIfOp.thenBlock(); + auto yieldOldBlock = newIfOp.elseBlock(); + + if (yieldNewBlock != curOp->getBlock()) { + std::swap(yieldNewBlock, yieldOldBlock); + } + cast(yieldNewBlock->getTerminator()) + .getResultsMutable() + .append(nextCounter); + cast(yieldOldBlock->getTerminator()) + .getResultsMutable() + .append(counter); + + ifOp.erase(); + nextCounter = newIfOp.getResults().back(); + curOp = newIfOp; + parent = newIfOp->getParentOp(); + } + + // Finally, rewrite the loop level yield + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield.setOperand(counter.getArgNumber() - 1, nextCounter); + } + return success(); +} + +scf::ForOp lowerTMADescriptors(scf::ForOp forOp, CoarseSchedule &schedule) { + llvm::MapVector tmaBufferMapping; + if (failed( + allocTMABuffers(forOp, tmaBufferMapping, schedule.getNumStages()))) { + llvm_unreachable("TMA pipelining failed"); + } + + IRRewriter builder(forOp); + Location loc = forOp.getLoc(); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + SmallVector newOperands; + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Create one counter per TMA buffer. This allows the descriptors to be + // updated independently without needing to write duplicate of existing tma + // descriptors. + unsigned tmaCounterArgsStartIdx = newOperandIndex + newOperands.size(); + for (int i = 0; i < tmaBufferMapping.size(); ++i) { + newOperands.push_back(zero); + } + + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + forOp.erase(); + forOp = newForOp; + + auto tmaCounters = ArrayRef(newForOp.getBody()->getArguments()) + .slice(tmaCounterArgsStartIdx); + + // Update yield op with temporary yield values + auto forYield = cast(newForOp.getBody()->getTerminator()); + for (unsigned i = 0; i < newOperands.size(); ++i) { + forYield.getResultsMutable().append(newOperands[i]); + } + + if (failed(rewriteTMABufferUpdates(newForOp, tmaBufferMapping, tmaCounters, + schedule.getNumStages(), one, zero, + schedule))) { + llvm_unreachable("Failed to rewrite TMA ops"); + } + return newForOp; +} + +///////////////////////////// +// LOWER LOOP +///////////////////////////// + +void lowerLoop(scf::ForOp forOp) { + CoarseSchedule schedule; + if (failed(schedule.deSerialize(forOp))) { + return; + } + scf::ForOp newForOp = lowerLoads(forOp, schedule); + newForOp = lowerTMADescriptors(newForOp, schedule); + schedule.serialize(newForOp); +} + +} // namespace + +void lowerLoops(ModuleOp moduleOp) { + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + if (loops.empty()) + return; + for (auto forOp : loops) { + lowerLoop(forOp); + } +} + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/ModifiedAccMMAPipeline.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/ModifiedAccMMAPipeline.cpp new file mode 100644 index 000000000..48fa2dbe0 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/ModifiedAccMMAPipeline.cpp @@ -0,0 +1,170 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +ttng::TMEMAllocOp createTMemAlloc(IRRewriter &rewriter, + ttng::TMEMAllocOp oldTMemAllocOp, + Value initValue) { + Location loc = oldTMemAllocOp.getLoc(); + auto oldRetType = oldTMemAllocOp.getType(); + SmallVector shape = {oldRetType.getShape().begin(), + oldRetType.getShape().end()}; + Type accMemDescType = triton::gpu::MemDescType::get( + shape, oldRetType.getElementType(), oldRetType.getEncoding(), + oldRetType.getMemorySpace(), /*mutableMemory=*/true); + return rewriter.create(oldTMemAllocOp.getLoc(), + accMemDescType, initValue); +} + +Value createBarrierAlloc(IRRewriter &rewriter, scf::ForOp forOp) { + MLIRContext *ctx = forOp.getContext(); + Location loc = forOp.getLoc(); + unsigned numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs( + forOp->getParentOfType()); + Attribute sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(ctx); + auto barrierCTALayout = ttg::CTALayoutAttr::get( + /*context=*/ctx, /*CTAsPerCGA=*/{numCTAs}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, {0}, barrierCTALayout); + ttg::MemDescType barrierMemDescType = + ttg::MemDescType::get({1}, rewriter.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value barrierAlloc = + rewriter.create(loc, barrierMemDescType, Value()); + rewriter.create(forOp->getLoc(), barrierAlloc, 1); + return barrierAlloc; +} + +scf::ForOp pipelineDot(scf::ForOp forOp, ttng::TCGen5MMAOp dotOp, + ttng::TMEMLoadOp loadOp, ttng::TMEMAllocOp allocOp, + Operation *accModOp, int yieldArgNo) { + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + Location loc = forOp.getLoc(); + Value vTrue = rewriter.create(loc, 1, 1); + Value accInitValue = forOp.getInitArgs()[yieldArgNo]; + auto newAlloc = createTMemAlloc(rewriter, allocOp, accInitValue); + Value barrier = createBarrierAlloc(rewriter, forOp); + Value phase = rewriter.create(loc, 0, 32); + Value notZerothIter = rewriter.create(loc, 0, 1); + Type loadTy = loadOp.getType(); + scf::ForOp newForOp = + replaceForOpWithNewSignature(rewriter, forOp, {phase, notZerothIter}); + forOp.erase(); + forOp = newForOp; + phase = forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 2); + notZerothIter = forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1); + Value oldAccValue = forOp.getRegionIterArg(yieldArgNo); + + rewriter.setInsertionPoint(accModOp ? accModOp : dotOp); + loc = dotOp.getLoc(); + rewriter.create(loc, barrier, phase, notZerothIter); + auto flag = rewriter.create(loc, rewriter.getI32Type(), + notZerothIter); + phase = rewriter.create(loc, phase, flag); + if (accModOp) { + // Update the source of the modifying op + loc = accModOp->getLoc(); + auto accValue = rewriter.create(loc, loadTy, newAlloc); + accModOp->replaceUsesOfWith(oldAccValue, accValue); + rewriter.setInsertionPointAfter(accModOp); + rewriter.create(loc, newAlloc, accModOp->getResult(0), + vTrue); + } + + // Update the dot op + dotOp.getDMutable().assign(newAlloc); + dotOp.getBarrierMutable().assign(barrier); + + // Update the yield + appendToForOpYield(forOp, {phase, vTrue}); + + // Short-circuit the loop carry value that was holding the accumulator value, + // removing the last reference to the loaded accumulator. + forOp.getBody()->getTerminator()->setOperand(yieldArgNo, oldAccValue); + + // Remove the old alloc and load + loadOp.erase(); + allocOp.erase(); + + // Update the uses outside the loop + rewriter.setInsertionPointAfter(forOp); + phase = forOp.getResult(forOp.getNumResults() - 2); + notZerothIter = forOp.getResult(forOp.getNumResults() - 1); + rewriter.create(dotOp.getLoc(), barrier, phase, + notZerothIter); + auto afterLoopLoad = + rewriter.create(forOp.getLoc(), loadTy, newAlloc); + forOp->getResult(yieldArgNo).replaceAllUsesWith(afterLoopLoad.getResult()); + + rewriter.create(dotOp.getLoc(), barrier); + rewriter.create(dotOp.getLoc(), barrier); + + return forOp; +} + +scf::ForOp mlir::triton::pipelineMMAWithScaledAcc(scf::ForOp forOp) { + // Look for chained mmas for which the tmem access is not pipelined yet, with + // an operation modifying the acc value before the mma. + SmallVector dotOps; + forOp.walk([&](ttng::TCGen5MMAOp mmaOp) { + if (mmaOp->getBlock() != forOp.getBody()) { + return; + } + dotOps.push_back(mmaOp); + }); + + dotOps = getMMAsWithMultiBufferredOperands(forOp, dotOps); + + for (auto op : dotOps) { + auto dotOp = llvm::cast(op); + auto tmemAlloc = dotOp.getD().getDefiningOp(); + if (!tmemAlloc || tmemAlloc->getBlock() != dotOp->getBlock()) { + continue; + } + if (tmemAlloc.getSrc() == nullptr) { + continue; + } + ttng::TMEMLoadOp tmemLoad = nullptr; + for (auto user : tmemAlloc.getResult().getUsers()) { + if (auto load = dyn_cast(user)) { + tmemLoad = load; + break; + } + } + if (!tmemLoad || tmemLoad->getBlock() != dotOp->getBlock() || + !tmemLoad.getResult().hasOneUse()) { + continue; + } + OpOperand &tmemLoadUse = *tmemLoad.getResult().getUses().begin(); + auto yieldOp = dyn_cast(tmemLoadUse.getOwner()); + if (!yieldOp || yieldOp->getParentOfType() != forOp) { + continue; + } + int yieldArgNo = tmemLoadUse.getOperandNumber(); + if (!forOp.getRegionIterArg(yieldArgNo).hasOneUse()) { + continue; + } + Operation *accModOp = + *forOp.getRegionIterArg(yieldArgNo).getUsers().begin(); + if (accModOp == tmemAlloc) { + accModOp = nullptr; // not really an acc modification + } else { + if (!accModOp->hasOneUse() && + *accModOp->getUsers().begin() != tmemAlloc) { + continue; + } + } + + forOp = + pipelineDot(forOp, dotOp, tmemLoad, tmemAlloc, accModOp, yieldArgNo); + } + return forOp; +}; diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp new file mode 100644 index 000000000..c453ea979 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -0,0 +1,864 @@ +//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop software pipelining +// +//===----------------------------------------------------------------------===// + +// Fork of upstream pipeliner. This will be merged upstream once things are +// stable. Modifications so far are: +// -Bug fix for def with a distance of 1 scheduled in stage 0. +// -Support dynamic loops and predicate operations in the prologue. +// -Support for non-index type for induction variable. +// -Support source with distance of 1 used multiple stages later. +// -Fix bug when a value yield is used outside the loop and the value def is not +// in the last stage. If we are not peeling the epilgue we need to remap the +// output correctly. +// -Allow for distance of 2 or more between producer and consumer for the cases +// where the producer is in the same stage as the consumer. + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" + +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" + +// FIXME: PipelineExpander should not depend on Triton-specific headers! +#include "triton/Dialect/TritonGPU/IR/Types.h" + +#define DEBUG_TYPE "triton-loop-pipelining" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::scf; +using namespace mlir::triton; + +namespace { + +/// Helper to keep internal information during pipelining transformation. +struct LoopPipelinerInternal { + /// Coarse liverange information for ops used across stages. + struct LiverangeInfo { + unsigned lastUseStage = 0; + unsigned defStage = 0; + }; + +protected: + ForOp forOp; + unsigned maxStage = 0; + DenseMap stages; + std::vector opOrder; + Value ub; + Value lb; + Value step; + bool dynamicLoop; + triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr; + bool peelEpilogue; + triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr; + + // When peeling the kernel we generate several version of each value for + // different stage of the prologue. This map tracks the mapping between + // original Values in the loop and the different versions + // peeled from the loop. + DenseMap> valueMapping; + + /// Assign a value to `valueMapping`, this means `val` represents the version + /// `idx` of `key` in the epilogue. + void setValueMapping(Value key, Value el, int64_t idx); + + /// Return the defining op of the given value, if the Value is an argument of + /// the loop return the associated defining op in the loop and its distance to + /// the Value. + std::pair getDefiningOpAndDistance(Value value); + + /// Return true if the schedule is possible and return false otherwise. A + /// schedule is correct if all definitions are scheduled before uses. + bool verifySchedule(); + +public: + /// Initialize the information for the given `op`, return true if it + /// satisfies the pre-condition to apply pipelining. + bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options); + /// Emits the prologue, this creates `maxStage - 1` part which will contain + /// operations from stages [0; i], where i is the part index. + LogicalResult emitPrologue(RewriterBase &rewriter); + /// Gather liverange information for Values that are used in a different stage + /// than its definition. + llvm::MapVector analyzeCrossStageValues(); + scf::ForOp createKernelLoop( + const llvm::MapVector &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap); + /// Emits the pipelined kernel. This clones loop operations following user + /// order and remaps operands defined in a different stage as their use. + LogicalResult createKernel( + scf::ForOp newForOp, + const llvm::MapVector &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter); + /// Emits the epilogue, this creates `maxStage - 1` part which will contain + /// operations from stages [i; maxStage], where i is the part index. + LogicalResult emitEpilogue(RewriterBase &rewriter, + llvm::SmallVector &returnValues); +}; + +/// Find operands of all the nested operations within `op`. +static SetVector getNestedOperands(Operation *op) { + SetVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + operands.insert(operand); + } + }); + return operands; +} + +bool LoopPipelinerInternal::initializeLoopInfo( + ForOp op, const triton::PipeliningOption &options) { + LDBG("Start initializeLoopInfo"); + forOp = op; + ub = forOp.getUpperBound(); + lb = forOp.getLowerBound(); + step = forOp.getStep(); + + dynamicLoop = true; + auto upperBoundCst = ub.getDefiningOp(); + auto lowerBoundCst = lb.getDefiningOp(); + auto stepCst = step.getDefiningOp(); + if (!upperBoundCst || !lowerBoundCst || !stepCst) { + if (!options.supportDynamicLoops) { + LDBG("--dynamic loop not supported -> BAIL"); + return false; + } + } else { + int64_t ubImm = upperBoundCst.value(); + int64_t lbImm = lowerBoundCst.value(); + int64_t stepImm = stepCst.value(); + int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm); + if (numIteration > maxStage) { + dynamicLoop = false; + } else if (!options.supportDynamicLoops) { + LDBG("--fewer loop iterations than pipeline stages -> BAIL"); + return false; + } + } + peelEpilogue = options.peelEpilogue; + predicateFn = options.predicateFn; + if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { + LDBG("--no epilogue or predicate set -> BAIL"); + return false; + } + std::vector> schedule; + options.getScheduleFn(forOp, schedule); + if (schedule.empty()) { + LDBG("--empty schedule -> BAIL"); + return false; + } + + opOrder.reserve(schedule.size()); + for (auto &opSchedule : schedule) { + maxStage = std::max(maxStage, opSchedule.second); + stages[opSchedule.first] = opSchedule.second; + opOrder.push_back(opSchedule.first); + } + + // All operations need to have a stage. + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!stages.contains(&op)) { + op.emitOpError("not assigned a pipeline stage"); + LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); + return false; + } + } + + if (!verifySchedule()) { + LDBG("--invalid schedule: " << op << " -> BAIL"); + return false; + } + + // Currently, we do not support assigning stages to ops in nested regions. The + // block of all operations assigned a stage should be the single `scf.for` + // body block. + for (const auto &[op, stageNum] : stages) { + (void)stageNum; + if (op == forOp.getBody()->getTerminator()) { + op->emitError("terminator should not be assigned a stage"); + LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); + return false; + } + if (op->getBlock() != forOp.getBody()) { + op->emitOpError("the owning Block of all operations assigned a stage " + "should be the loop body block"); + LDBG("--the owning Block of all operations assigned a stage " + "should be the loop body block: " + << *op << " -> BAIL"); + return false; + } + } + + // Support only loop-carried dependencies with a distance of one iteration or + // those defined outside of the loop. This means that any dependency within a + // loop should either be on the immediately preceding iteration, the current + // iteration, or on variables whose values are set before entering the loop. + for (auto &op : forOp.getBody()->without_terminator()) { + for (auto operand : getNestedOperands(&op)) { + auto [def, distance] = getDefiningOpAndDistance(operand); + if (!def) + continue; + if (distance > 1 && (stages[def] != stages[&op])) { + // Allow the case of loop carried dependency between the ops in the same + // stage. + LDBG("--only support loop carried dependency with a distance of 1 or " + "defined outside of the loop -> BAIL"); + return false; + } + } + } + annotateFn = options.annotateFn; + return true; +} + +/// Compute unrolled cycles of each op (consumer) and verify that each op is +/// scheduled after its operands (producers) while adjusting for the distance +/// between producer and consumer. +bool LoopPipelinerInternal::verifySchedule() { + int64_t numCylesPerIter = opOrder.size(); + // Pre-compute the unrolled cycle of each op. + DenseMap unrolledCyles; + for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) { + Operation *def = opOrder[cycle]; + auto it = stages.find(def); + assert(it != stages.end()); + int64_t stage = it->second; + unrolledCyles[def] = cycle + stage * numCylesPerIter; + } + for (Operation *consumer : opOrder) { + int64_t consumerCycle = unrolledCyles[consumer]; + for (Value operand : getNestedOperands(consumer)) { + auto [producer, distance] = getDefiningOpAndDistance(operand); + if (!producer) + continue; + auto it = unrolledCyles.find(producer); + // Skip producer coming from outside the loop. + if (it == unrolledCyles.end()) + continue; + int64_t producerCycle = it->second; + if (consumerCycle < producerCycle - numCylesPerIter * distance) { + InFlightDiagnostic diag = + consumer->emitWarning("operation scheduled before its operands. " + "Pipelining will be disabled."); + diag.attachNote(producer->getLoc()) + .append("operand defined here: ") + .appendOp(*producer, OpPrintingFlags().printGenericOpForm()); + return false; + } + } + } + return true; +} + +/// Clone `op` and call `callback` on the cloned op's operands as well as any +/// operands of nested ops that: +/// 1) aren't defined within the new op or +/// 2) are block arguments. +static Operation * +cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, + function_ref callback) { + Operation *clone = rewriter.clone(*op); + clone->walk([&](Operation *nested) { + // 'clone' itself will be visited first. + for (OpOperand &operand : nested->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if ((def && !clone->isAncestor(def)) || isa(operand.get())) + callback(&operand); + } + }); + return clone; +} + +LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { + // Initialize the iteration argument to the loop initiale values. + for (auto [arg, operand] : + llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { + setValueMapping(arg, operand.get(), 0); + } + + // If the incoming value to an iter arg from the loop yield is defined outside + // the loop, then that means the iter arg takes that value for all stages + // after the first stage. + auto yield = cast(forOp.getBody()->getTerminator()); + for (auto [arg, operand] : + llvm::zip(forOp.getRegionIterArgs(), yield->getOpOperands())) { + if (forOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) + continue; + for (int64_t i = 1; i < maxStage; ++i) + setValueMapping(arg, operand.get(), i); + } + + Location loc = forOp.getLoc(); + SmallVector predicates(maxStage); + for (int64_t i = 0; i < maxStage; i++) { + // special handling for induction variable as the increment is implicit. + // iv = lb + i * step + Type t = lb.getType(); + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, + rewriter.getIntegerAttr(t, i)))); + setValueMapping(forOp.getInductionVar(), iv, i); + + if (dynamicLoop) { + // pred = ub > lb + (i * step) + predicates[i] = rewriter.create( + loc, arith::CmpIPredicate::slt, iv, ub); + } + + for (Operation *op : opOrder) { + if (stages[op] > i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[i - stages[op]]; + newOperand->set(replacement); + } + }); + int predicateIdx = i - stages[op]; + if (predicates[predicateIdx]) { + OpBuilder::InsertionGuard insertGuard(rewriter); + newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); + if (newOp == nullptr) + return failure(); + } + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + Value source = newOp->getResult(destId); + // If the value is a loop carried dependency update the loop argument + for (OpOperand &operand : yield->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + if (predicates[predicateIdx] && + !forOp.getResult(operand.getOperandNumber()).use_empty()) { + // If the value is used outside the loop, we need to make sure we + // return the correct version of it. + Value prevValue = valueMapping + [forOp.getRegionIterArgs()[operand.getOperandNumber()]] + [i - stages[op]]; + source = rewriter.create( + loc, predicates[predicateIdx], source, prevValue); + } + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + source, i - stages[op] + 1); + } + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); + } + } + } + return success(); +} + +llvm::MapVector +LoopPipelinerInternal::analyzeCrossStageValues() { + llvm::MapVector crossStageValues; + for (Operation *op : opOrder) { + unsigned stage = stages[op]; + + auto analyzeOperand = [&](OpOperand &operand) { + auto [def, distance] = getDefiningOpAndDistance(operand.get()); + if (!def) + return; + auto defStage = stages.find(def); + if (defStage == stages.end() || defStage->second == stage || + defStage->second == stage + distance) + return; + assert(stage > defStage->second); + LiverangeInfo &info = crossStageValues[operand.get()]; + info.defStage = defStage->second; + info.lastUseStage = std::max(info.lastUseStage, stage); + }; + + for (OpOperand &operand : op->getOpOperands()) + analyzeOperand(operand); + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { + analyzeOperand(*operand); + }); + } + return crossStageValues; +} + +std::pair +LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { + int64_t distance = 0; + while (auto arg = dyn_cast(value)) { + if (arg.getOwner() != forOp.getBody()) + return {nullptr, 0}; + // Ignore induction variable. + if (arg.getArgNumber() == 0) + return {nullptr, 0}; + distance++; + value = + forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + Operation *def = value.getDefiningOp(); + if (!def) + return {nullptr, 0}; + return {def, distance}; +} + +scf::ForOp LoopPipelinerInternal::createKernelLoop( + const llvm::MapVector + &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap) { + // Creates the list of initial values associated to values used across + // stages. The initial values come from the prologue created above. + // Keep track of the kernel argument associated to each version of the + // values passed to the kernel. + llvm::SmallVector newLoopArg; + // For existing loop argument initialize them with the right version from the + // prologue. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + auto defStage = stages.find(def); + if (defStage != stages.end()) { + Value valueVersion = + valueMapping[forOp.getRegionIterArgs()[retVal.index()]] + [maxStage - defStage->second]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + } else + newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]); + } + for (auto escape : crossStageValues) { + LiverangeInfo &info = escape.second; + Value value = escape.first; + for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; + stageIdx++) { + Value valueVersion = + valueMapping[value][maxStage - info.lastUseStage + stageIdx]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - + stageIdx)] = newLoopArg.size() - 1; + } + } + + // Create the new kernel loop. When we peel the epilgue we need to peel + // `numStages - 1` iterations. Then we adjust the upper bound to remove those + // iterations. + Value newUb = forOp.getUpperBound(); + if (peelEpilogue) { + Type t = ub.getType(); + Location loc = forOp.getLoc(); + // newUb = ub - maxStage * step + Value maxStageValue = rewriter.create( + loc, rewriter.getIntegerAttr(t, maxStage)); + Value maxStageByStep = + rewriter.create(loc, step, maxStageValue); + newUb = rewriter.create(loc, ub, maxStageByStep); + } + auto newForOp = + rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, + forOp.getStep(), newLoopArg); + newForOp->setAttrs(forOp->getAttrs()); + // When there are no iter args, the loop body terminator will be created. + // Since we always create it below, remove the terminator if it was created. + if (!newForOp.getBody()->empty()) + rewriter.eraseOp(newForOp.getBody()->getTerminator()); + return newForOp; +} + +LogicalResult LoopPipelinerInternal::createKernel( + scf::ForOp newForOp, + const llvm::MapVector + &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter) { + valueMapping.clear(); + + // Create the kernel, we clone instruction based on the order given by + // user and remap operands coming from a previous stages. + rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + } + SmallVector predicates(maxStage + 1, nullptr); + if (!peelEpilogue) { + // Create a predicate for each stage except the last stage. + Location loc = newForOp.getLoc(); + Type t = ub.getType(); + for (unsigned i = 0; i < maxStage; i++) { + // c = ub - (maxStage - i) * step + Value c = rewriter.create( + loc, ub, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i))))); + + Value pred = rewriter.create( + newForOp.getLoc(), arith::CmpIPredicate::slt, + newForOp.getInductionVar(), c); + predicates[i] = pred; + } + } + for (Operation *op : opOrder) { + int64_t useStage = stages[op]; + auto *newOp = rewriter.clone(*op, mapping); + SmallVector operands; + // Collect all the operands for the cloned op and its nested ops. + op->walk([&operands](Operation *nestedOp) { + for (OpOperand &operand : nestedOp->getOpOperands()) { + operands.push_back(&operand); + } + }); + for (OpOperand *operand : operands) { + Operation *nestedNewOp = mapping.lookup(operand->getOwner()); + // Special case for the induction variable uses. We replace it with a + // version incremented based on the stage where it is used. + if (operand->get() == forOp.getInductionVar()) { + rewriter.setInsertionPoint(newOp); + + // offset = (maxStage - stages[op]) * step + Type t = step.getType(); + Value offset = rewriter.create( + forOp.getLoc(), step, + rewriter.create( + forOp.getLoc(), + rewriter.getIntegerAttr(t, maxStage - stages[op]))); + Value iv = rewriter.create( + forOp.getLoc(), newForOp.getInductionVar(), offset); + nestedNewOp->setOperand(operand->getOperandNumber(), iv); + rewriter.setInsertionPointAfter(newOp); + continue; + } + Value source = operand->get(); + auto arg = dyn_cast(source); + if (arg && arg.getOwner() == forOp.getBody()) { + Value ret = forOp.getBody()->getTerminator()->getOperand( + arg.getArgNumber() - 1); + Operation *dep = ret.getDefiningOp(); + if (!dep) + continue; + auto stageDep = stages.find(dep); + if (stageDep == stages.end() || stageDep->second == useStage) + continue; + // If the value is a loop carried value coming from stage N + 1 remap, + // it will become a direct use. + if (stageDep->second == useStage + 1) { + nestedNewOp->setOperand(operand->getOperandNumber(), + mapping.lookupOrDefault(ret)); + continue; + } + source = ret; + } + // For operands defined in a previous stage we need to remap it to use + // the correct region argument. We look for the right version of the + // Value based on the stage where it is used. + Operation *def = source.getDefiningOp(); + if (!def) + continue; + auto stageDef = stages.find(def); + if (stageDef == stages.end() || stageDef->second == useStage) + continue; + auto remap = loopArgMap.find( + std::make_pair(operand->get(), useStage - stageDef->second)); + assert(remap != loopArgMap.end()); + nestedNewOp->setOperand(operand->getOperandNumber(), + newForOp.getRegionIterArgs()[remap->second]); + } + + if (predicates[useStage]) { + OpBuilder::InsertionGuard insertGuard(rewriter); + newOp = predicateFn(rewriter, newOp, predicates[useStage]); + if (!newOp) + return failure(); + // Remap the results to the new predicated one. + for (auto values : llvm::zip(op->getResults(), newOp->getResults())) + mapping.map(std::get<0>(values), std::get<1>(values)); + } + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0); + } + + // Collect the Values that need to be returned by the forOp. For each + // value we need to have `LastUseStage - DefStage` number of versions + // returned. + // We create a mapping between original values and the associated loop + // returned values that will be needed by the epilogue. + llvm::SmallVector yieldOperands; + for (OpOperand &yieldOperand : + forOp.getBody()->getTerminator()->getOpOperands()) { + Value source = mapping.lookupOrDefault(yieldOperand.get()); + // When we don't peel the epilogue and the yield value is used outside the + // loop we need to make sure we return the version from numStages - + // defStage. + if (!peelEpilogue && + !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) { + Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first; + if (def) { + auto defStage = stages.find(def); + if (defStage != stages.end() && defStage->second < maxStage) { + Value pred = predicates[defStage->second]; + source = rewriter.create( + pred.getLoc(), pred, source, + newForOp.getBody() + ->getArguments()[yieldOperand.getOperandNumber() + 1]); + } + } + } + yieldOperands.push_back(source); + } + + for (auto &it : crossStageValues) { + int64_t version = maxStage - it.second.lastUseStage + 1; + unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; + // add the original version to yield ops. + // If there is a live range spanning across more than 2 stages we need to + // add extra arg. + for (unsigned i = 1; i < numVersionReturned; i++) { + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back( + newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + + newForOp.getNumInductionVars()]); + } + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back(mapping.lookupOrDefault(it.first)); + } + // Map the yield operand to the forOp returned value. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + auto defStage = stages.find(def); + if (defStage == stages.end()) { + for (unsigned int stage = 1; stage <= maxStage; stage++) + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + retVal.value(), stage); + } else if (defStage->second > 0) { + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + newForOp->getResult(retVal.index()), + maxStage - defStage->second + 1); + } + } + rewriter.create(forOp.getLoc(), yieldOperands); + return success(); +} + +LogicalResult +LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, + llvm::SmallVector &returnValues) { + Location loc = forOp.getLoc(); + Type t = lb.getType(); + // Emit different versions of the induction variable. They will be + // removed by dead code if not used. + + auto createConst = [&](int v) { + return rewriter.create(loc, + rewriter.getIntegerAttr(t, v)); + }; + + // total_iterations = cdiv(range_diff, step); + // - range_diff = ub - lb + // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step + Value zero = createConst(0); + Value one = createConst(1); + Value stepLessZero = rewriter.create( + loc, arith::CmpIPredicate::slt, step, zero); + Value stepDecr = + rewriter.create(loc, stepLessZero, one, createConst(-1)); + + Value rangeDiff = rewriter.create(loc, ub, lb); + Value rangeIncrStep = rewriter.create(loc, rangeDiff, step); + Value rangeDecr = + rewriter.create(loc, rangeIncrStep, stepDecr); + Value totalIterations = rewriter.create(loc, rangeDecr, step); + + // If total_iters < max_stage, start the epilogue at zero to match the + // ramp-up in the prologue. + // start_iter = max(0, total_iters - max_stage) + Value iterI = rewriter.create(loc, totalIterations, + createConst(maxStage)); + iterI = rewriter.create(loc, zero, iterI); + + // Capture predicates for dynamic loops. + SmallVector predicates(maxStage + 1); + + for (int64_t i = 1; i <= maxStage; i++) { + // newLastIter = lb + step * iterI + Value newlastIter = rewriter.create( + loc, lb, rewriter.create(loc, step, iterI)); + + setValueMapping(forOp.getInductionVar(), newlastIter, i); + + // increment to next iterI + iterI = rewriter.create(loc, iterI, one); + + if (dynamicLoop) { + // Disable stages when `i` is greater than total_iters. + // pred = total_iters >= i + predicates[i] = rewriter.create( + loc, arith::CmpIPredicate::sge, totalIterations, createConst(i)); + } + } + + // Emit `maxStage - 1` epilogue part that includes operations from stages + // [i; maxStage]. + for (int64_t i = 1; i <= maxStage; i++) { + SmallVector> returnMap(returnValues.size()); + for (Operation *op : opOrder) { + if (stages[op] < i) + continue; + unsigned currentVersion = maxStage - stages[op] + i; + unsigned nextVersion = currentVersion + 1; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[currentVersion]; + newOperand->set(replacement); + } + }); + if (dynamicLoop) { + OpBuilder::InsertionGuard insertGuard(rewriter); + newOp = predicateFn(rewriter, newOp, predicates[currentVersion]); + if (!newOp) + return failure(); + } + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue, + i - 1); + for (auto [opRes, newRes] : + llvm::zip(op->getResults(), newOp->getResults())) { + setValueMapping(opRes, newRes, currentVersion); + // If the value is a loop carried dependency update the loop argument + // mapping and keep track of the last version to replace the original + // forOp uses. + for (OpOperand &operand : + forOp.getBody()->getTerminator()->getOpOperands()) { + if (operand.get() != opRes) + continue; + // If the version is greater than maxStage it means it maps to the + // original forOp returned value. + unsigned ri = operand.getOperandNumber(); + returnValues[ri] = newRes; + Value mapVal = forOp.getRegionIterArgs()[ri]; + returnMap[ri] = std::make_pair(mapVal, currentVersion); + if (nextVersion <= maxStage) + setValueMapping(mapVal, newRes, nextVersion); + } + } + } + if (dynamicLoop) { + // Select return values from this stage (live outs) based on predication. + // If the stage is valid select the peeled value, else use previous stage + // value. + for (auto pair : llvm::enumerate(returnValues)) { + unsigned ri = pair.index(); + auto [mapVal, currentVersion] = returnMap[ri]; + if (mapVal) { + unsigned nextVersion = currentVersion + 1; + Value pred = predicates[currentVersion]; + Value prevValue = valueMapping[mapVal][currentVersion]; + auto selOp = rewriter.create(loc, pred, pair.value(), + prevValue); + returnValues[ri] = selOp; + if (nextVersion <= maxStage) + setValueMapping(mapVal, selOp, nextVersion); + } + } + } + } + return success(); +} + +void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { + auto it = valueMapping.find(key); + // If the value is not in the map yet add a vector big enough to store all + // versions. + if (it == valueMapping.end()) + it = + valueMapping + .insert(std::make_pair(key, llvm::SmallVector(maxStage + 1))) + .first; + it->second[idx] = el; +} + +} // namespace + +FailureOr +mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, + const triton::PipeliningOption &options, + bool *modifiedIR) { + if (modifiedIR) + *modifiedIR = false; + LoopPipelinerInternal pipeliner; + if (!pipeliner.initializeLoopInfo(forOp, options)) + return failure(); + + if (modifiedIR) + *modifiedIR = true; + + // 1. Emit prologue. + if (failed(pipeliner.emitPrologue(rewriter))) + return failure(); + + // 2. Track values used across stages. When a value cross stages it will + // need to be passed as loop iteration arguments. + // We first collect the values that are used in a different stage than where + // they are defined. + llvm::MapVector + crossStageValues = pipeliner.analyzeCrossStageValues(); + + // Mapping between original loop values used cross stage and the block + // arguments associated after pipelining. A Value may map to several + // arguments if its liverange spans across more than 2 stages. + llvm::DenseMap, unsigned> loopArgMap; + // 3. Create the new kernel loop and return the block arguments mapping. + ForOp newForOp = + pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); + // Create the kernel block, order ops based on user choice and remap + // operands. + if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, + rewriter))) + return failure(); + + llvm::SmallVector returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. + rewriter.setInsertionPointAfter(newForOp); + if (failed(pipeliner.emitEpilogue(rewriter, returnValues))) + return failure(); + } + // 5. Erase the original loop and replace the uses with the epilogue output. + if (forOp->getNumResults() > 0) + rewriter.replaceOp(forOp, returnValues); + else + rewriter.eraseOp(forOp); + + return newForOp; +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp new file mode 100644 index 000000000..9200ccdfa --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -0,0 +1,308 @@ +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Casting.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +bool mlir::triton::loopHasDistGreaterThanOne(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + }); +} + +bool mlir::triton::isOuterLoop(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->getOperations(), [](Operation &op) { + return isa(op); + }); +} + +// Combine the current mask with the given predicate. +static Value getPredMask(RewriterBase &rewriter, Type typeLike, + Value currentMask, Value pred) { + Type maskType = tt::getI1SameShape(typeLike); + Location loc = pred.getLoc(); + Value mask = pred; + if (isa(maskType)) { + mask = rewriter.create(loc, maskType, pred); + } + if (currentMask) { + mask = rewriter.create(loc, mask, currentMask); + } + return mask; +} + +// Function to mask operations during scheduling. +Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { + OpBuilder::InsertionGuard guard(rewriter); + if (mlir::isMemoryEffectFree(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (auto ifOp = dyn_cast(op)) { + rewriter.setInsertionPoint(op); + Value cnd = getPredMask(rewriter, ifOp.getCondition().getType(), + ifOp.getCondition(), pred); + ifOp.getConditionMutable().assign(cnd); + return op; + } + if (auto asyncCopyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(asyncCopyOp); + Value mask = getPredMask(rewriter, asyncCopyOp.getSrc().getType(), + asyncCopyOp.getMask(), pred); + asyncCopyOp.getMaskMutable().assign(mask); + return op; + } + if (auto loadOp = dyn_cast(op)) { + rewriter.setInsertionPoint(loadOp); + Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), pred); + loadOp.getMaskMutable().assign(mask); + return op; + } + if (auto copyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(copyOp); + Value mask = getPredMask(rewriter, copyOp.getPred().getType(), + copyOp.getPred(), pred); + copyOp.getPredMutable().assign(mask); + return op; + } + if (auto gatherOp = dyn_cast(op)) { + rewriter.setInsertionPoint(gatherOp); + Value mask = getPredMask(rewriter, gatherOp.getPred().getType(), + gatherOp.getPred(), pred); + gatherOp.getPredMutable().assign(mask); + return op; + } + if (auto expectOp = dyn_cast(op)) { + rewriter.setInsertionPoint(expectOp); + Value mask = getPredMask(rewriter, expectOp.getPred().getType(), + expectOp.getPred(), pred); + expectOp.getPredMutable().assign(mask); + return op; + } + if (auto mmav5Op = dyn_cast(op)) { + rewriter.setInsertionPoint(mmav5Op); + auto currPred = mmav5Op.getPredicate(); + Value mask = getPredMask(rewriter, currPred.getType(), currPred, pred); + mmav5Op.setPredicate(mask); + return op; + } + if (auto tmemStoreOp = dyn_cast(op)) { + rewriter.setInsertionPoint(tmemStoreOp); + Value mask = getPredMask(rewriter, tmemStoreOp.getPred().getType(), + tmemStoreOp.getPred(), pred); + tmemStoreOp.getPredMutable().assign(mask); + return op; + } + if (auto waitBarrier = dyn_cast(op)) { + rewriter.setInsertionPoint(waitBarrier); + Value mask = pred; + Value currentPred = waitBarrier.getPred(); + if (currentPred) { + mask = getPredMask(rewriter, currentPred.getType(), currentPred, pred); + } + waitBarrier.getPredMutable().assign(mask); + return op; + } + if (auto storeOp = dyn_cast(op)) { + rewriter.setInsertionPoint(storeOp); + Value mask = getPredMask(rewriter, storeOp.getPtr().getType(), + storeOp.getMask(), pred); + storeOp.getMaskMutable().assign(mask); + return op; + } + if (auto atomicRMWOp = dyn_cast(op)) { + rewriter.setInsertionPoint(atomicRMWOp); + Value mask = getPredMask(rewriter, atomicRMWOp.getPtr().getType(), + atomicRMWOp.getMask(), pred); + atomicRMWOp.getMaskMutable().assign(mask); + return op; + } + + op->emitError("pipeliner doesn't know how to predicate this op."); + llvm::report_fatal_error("Fatal pipeliner error"); + return op; +} + +void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, + Operation *oldUse, Value val) { + SmallVector opsToDelete; + SmallVector operandsToReplace; + + // Save the operand to replace / delete later (avoid iterator invalidation). + // TODO: can we use an early_inc iterator? + for (OpOperand &use : oldUse->getUses()) { + // Non-subview/trans ops will be replaced by `val`. + if (!isa( + use.getOwner())) { + operandsToReplace.push_back(&use); + continue; + } + Operation *user = use.getOwner(); + // `subview(old_op)` is replaced by a new `subview(val)`. + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(user); + Value newVal; + if (auto subview = dyn_cast(user)) { + triton::gpu::MemDescType oldType = subview.getType(); + bool isMutable = + cast(val.getType()).getMutableMemory(); + Type newDstType = triton::gpu::MemDescType::get( + oldType.getShape(), oldType.getElementType(), oldType.getEncoding(), + oldType.getMemorySpace(), isMutable); + newVal = builder.create( + subview.getLoc(), newDstType, val, subview.getOffsets()); + newVal.getDefiningOp()->setAttrs(user->getAttrs()); + } else if (auto trans = dyn_cast(user)) { + newVal = builder.create(trans.getLoc(), val, + trans.getOrder()); + newVal.getDefiningOp()->setAttrs(user->getAttrs()); + } + assert(newVal); + newVal.getDefiningOp()->setAttrs(user->getAttrs()); + replaceUsesAndPropagateType(builder, user, newVal); + opsToDelete.push_back(use.getOwner()); + } + + // Perform late replacement. + for (OpOperand *operand : operandsToReplace) { + Operation *op = operand->getOwner(); + operand->set(val); + } + + // Perform late op erasure. + for (Operation *op : opsToDelete) + op->erase(); +} + +// Return true if the given ForOp has the attribute +// `tt.disallow_acc_multi_buffer` set to true. +bool mlir::triton::getDisallowAccMultiBuffer(scf::ForOp forOp) { + return forOp->hasAttr(mlir::triton::kDisallowAccMultiBufferAttrName); +} + +void mlir::triton::visitNestedOperands(Operation *op, + function_ref visitor) { + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isProperAncestor(op)) + visitor(operand); + } + }); +} + +SetVector mlir::triton::getNestedOperands(Operation *op) { + SetVector result; + visitNestedOperands(op, [&](Value operand) { result.insert(operand); }); + return result; +} + +int mlir::triton::getCopyVecBytes(RankedTensorType registerTy, + ttg::SharedEncodingTrait sharedEnc) { + auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(), + registerTy.getEncoding()); + auto sharedLayout = + triton::gpu::toLinearLayout(registerTy.getShape(), sharedEnc); + auto regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + const int vecElems = regToSharedLayout.getNumConsecutiveInOut(); + return vecElems * registerTy.getElementTypeBitWidth() / 8; +} + +void mlir::triton::serializeLatencies(ModuleOp module, + DenseMap &opLatency) { + for (auto &[op, latency] : opLatency) { + op->setAttr( + kLatencyAttrName, + IntegerAttr::get(IntegerType::get(module.getContext(), 32), latency)); + } +} + +DenseMap mlir::triton::deserializeLatencies(ModuleOp module) { + DenseMap opLatency; + module.walk([&](Operation *op) { + if (op->hasAttr(kLatencyAttrName)) { + opLatency[op] = op->getAttrOfType(kLatencyAttrName).getInt(); + op->removeAttr(kLatencyAttrName); + } + }); + return opLatency; +} + +// Create an allocation and init the mbarriers. +Value mlir::triton::createBarrierAlloc(scf::ForOp forOp, int numBarriers) { + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + MLIRContext *ctx = forOp.getContext(); + Location loc = forOp.getLoc(); + unsigned numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs( + forOp->getParentOfType()); + Attribute sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(ctx); + auto barrierCTALayout = ttg::CTALayoutAttr::get( + /*context=*/ctx, /*CTAsPerCGA=*/{numCTAs}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, {0}, barrierCTALayout); + ttg::MemDescType barrierMemDescType = ttg::MemDescType::get( + {numBarriers}, rewriter.getI64Type(), barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Value barrierAlloc = + rewriter.create(loc, barrierMemDescType, Value()); + for (unsigned i = 0; i < numBarriers; i++) { + Value barrierView = createSingleBufferView(rewriter, barrierAlloc, i); + rewriter.create(forOp->getLoc(), barrierView, 1); + } + return barrierAlloc; +} + +Value mlir::triton::createSingleBufferView(OpBuilder &builder, Value alloc, + Value idx) { + assert(isa(alloc.getType()) && + "Expected MemDescType"); + auto allocDescType = cast(alloc.getType()); + SmallVector shape; + if (allocDescType.getShape().size() > 1) { + shape.insert(shape.end(), allocDescType.getShape().begin() + 1, + allocDescType.getShape().end()); + } else { + shape.push_back(1); + } + auto viewDescType = triton::gpu::MemDescType::get( + shape, allocDescType.getElementType(), allocDescType.getEncoding(), + allocDescType.getMemorySpace(), allocDescType.getMutableMemory(), + /*allocShape=*/allocDescType.getAllocShape()); + SmallVector idxs = {idx}; + if (allocDescType.getShape().size() > 1) { + Value zero = + builder.template create(alloc.getLoc(), 0, 32); + for (unsigned i = 1; i < allocDescType.getShape().size(); i++) { + idxs.push_back(zero); + } + } + return builder.template create( + alloc.getLoc(), viewDescType, alloc, idxs); +} + +Value mlir::triton::createSingleBufferView(OpBuilder &builder, Value alloc, + int idx) { + return mlir::triton::createSingleBufferView( + builder, alloc, + builder.create(alloc.getLoc(), idx, 32)); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp new file mode 100644 index 000000000..301d6382f --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp @@ -0,0 +1,231 @@ +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +bool tt::CoarseSchedule::insertMinimum(Operation *op, int stage, + Cluster cluster) { + auto res = opToStageAndCluster.insert({op, {stage, cluster}}); + if (res.second) { + return true; + } + + auto &[existingStage, existingCluster] = res.first->second; + + // Always insert if the stage is earlier. + if (stage < existingStage) { + existingStage = stage; + existingCluster = cluster; + return true; + } + + // If the stage is later, no change. + if (stage > existingStage) { + return false; + } + + // If existingCluster is reachable from cluster, + // then cluster is earlier in the list + auto it = cluster; + for (auto it = cluster; it != clusters.end(); ++it) { + if (it == existingCluster) { + existingCluster = cluster; + return true; + } + } + + // Didn't change the cluster. + return false; +} + +bool tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, + tt::CoarseSchedule::Cluster cluster, + bool includeArg, bool insertIfEarlier) { + auto tryInsert = [&](Operation *op, int stage, + tt::CoarseSchedule::Cluster cluster) { + if (!insertIfEarlier) + return insertIfAbsent(op, stage, cluster); + return insertMinimum(op, stage, cluster); + }; + + bool inserted = false; + for (Value operand : getNestedOperands(op)) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = dyn_cast(v)) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + if (tryInsert(defOp, stage, cluster)) { + inserted = true; + insertDepsOfOp(defOp, stage, cluster, includeArg); + } + } + } + return inserted; +} + +SmallVector> +tt::CoarseSchedule::getOpsInOrder(scf::ForOp forOp) { + SmallVector>, 8> + orderClusters(clusters.size()); + for (auto &op : forOp.getBody()->without_terminator()) { + if (opToStageAndCluster.count(&op) == 0) { + continue; + } + assert(opToStageAndCluster[&op].first < numStages && + "Op with invalid stage!"); + int clusterId = *opToStageAndCluster[&op].second; + assert(clusterId == std::distance(clusters.begin(), + opToStageAndCluster[&op].second) && + "Cluster ID mismatch!"); + orderClusters[clusterId].push_back(make_tuple( + &op, opToStageAndCluster[&op].first, opToStageAndCluster[&op].second)); + } + SmallVector> opsInOrder; + for (int i = 0; i < orderClusters.size(); i++) { + for (auto [op, stage, cluster] : orderClusters[i]) { + opsInOrder.push_back({op, stage, cluster}); + } + } + + return opsInOrder; +} + +std::vector> +tt::CoarseSchedule::createFinalSchedule(scf::ForOp forOp) { + SmallVector> + opsInOrder = getOpsInOrder(forOp); + std::vector> schedule; + for (auto [op, stage, cluster] : opsInOrder) + schedule.push_back({op, stage}); + return schedule; +} + +void tt::CoarseSchedule::dump() { + assert(numStages > 0 && "Invalid number of stages"); + for (int i = 0; i < numStages; i++) { + llvm::dbgs() << "\n---- Ops in stage " << i << "\n"; + for (auto &[op, stageAndCluster] : opToStageAndCluster) { + if (i == stageAndCluster.first) { + llvm::dbgs() << " cluster: " << *stageAndCluster.second + << ":\n\t" << *op << "\n"; + } + } + } +} + +static void setStageCluster(Operation *op, int stage, int cluster) { + auto ctx = op->getContext(); + op->setAttr(mlir::triton::kLoopStageAttrName, + IntegerAttr::get(IntegerType::get(ctx, 32), stage)); + op->setAttr(mlir::triton::kLoopClusterAttrName, + IntegerAttr::get(IntegerType::get(ctx, 32), cluster)); +} + +static std::pair getStageCluster(Operation *op) { + auto stage = op->getAttrOfType(tt::kLoopStageAttrName); + auto clusterId = op->getAttrOfType(tt::kLoopClusterAttrName); + assert(stage && clusterId && + "Operation is missing stage & cluster attribute"); + return {stage.getValue().getSExtValue(), clusterId.getValue().getSExtValue()}; +} + +static std::pair getMinMaxCluster(scf::ForOp &forOp) { + int minClusterId = -1, maxClusterId = -1; + for (auto &op : forOp.getBody()->without_terminator()) { + if (!op.hasAttr(mlir::triton::kLoopStageAttrName) || + !op.hasAttr(mlir::triton::kLoopClusterAttrName)) + continue; + auto [_, cluster] = getStageCluster(&op); + if (maxClusterId < 0) { + minClusterId = cluster; + maxClusterId = cluster; + continue; + } + maxClusterId = cluster > maxClusterId ? cluster : maxClusterId; + minClusterId = cluster < minClusterId ? cluster : minClusterId; + } + return std::make_pair(minClusterId, maxClusterId); +} + +static std::optional tryGetMaxStage(scf::ForOp &forOp) { + std::optional maxStage = std::nullopt; + if (forOp->hasAttr(mlir::triton::kScheduledMaxStageAttrName)) { + return forOp + ->getAttrOfType(mlir::triton::kScheduledMaxStageAttrName) + .getValue() + .getSExtValue(); + } + return maxStage; +} + +// Set based on CoarseSchedule. +void tt::CoarseSchedule::serialize(scf::ForOp &forOp) { + for (auto [op, stage, cluster] : getOpsInOrder(forOp)) { + setStageCluster(op, stage, *cluster); + } + forOp->setAttr(mlir::triton::kScheduledMaxStageAttrName, + IntegerAttr::get(IntegerType::get(forOp.getContext(), 32), + numStages - 1)); +} + +// Create a CoarseSchedule based on forOp's . +LogicalResult tt::CoarseSchedule::deSerialize(scf::ForOp &forOp) { + auto [minClusterId, maxClusterId] = getMinMaxCluster(forOp); + std::optional maxStage = tryGetMaxStage(forOp); + if (!maxStage) { + return failure(); + } + numStages = *maxStage + 1; + + DenseMap clustersMap; + for (int i = minClusterId; i < maxClusterId + 1; i++) { + clustersMap.insert({i, clusters.newAtBack()}); + } + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.hasAttr(mlir::triton::kLoopStageAttrName)) + continue; + auto [stage, clusterId] = getStageCluster(&op); + insert(&op, stage, clustersMap[clusterId]); + } + return success(); +} + +// TODO: Should this be moved somewhere else? +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +void tt::scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule) { + int numStages = schedule.getNumStages(); + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, false); + } + } +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp new file mode 100644 index 000000000..1ff1aeb3e --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp @@ -0,0 +1,337 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-loop-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +bool hasGpuBarriers(scf::ForOp forOp) { + WalkResult result = forOp.walk( + [&](mlir::gpu::BarrierOp barrier) { return WalkResult::interrupt(); }); + return result.wasInterrupted(); +} + +// Return true if the preconditions for pipelining the loop are met. +bool isSafeToPipeline(scf::ForOp forOp) { + // Skip loop with distance > 1. + if (loopHasDistGreaterThanOne(forOp)) + return false; + // Don't pipeline outer loops. + if (isOuterLoop(forOp)) + return false; + // Skip loops with barriers. + if (hasGpuBarriers(forOp)) + return false; + return true; +} + +bool hasLatenciesAssigned(scf::ForOp forOp, + const DenseMap &opLatency) { + for (auto &op : forOp.getBody()->without_terminator()) { + if (opLatency.count(&op)) + return true; + } + return false; +} + +CoarseSchedule scheduleKeyOps(scf::ForOp forOp, + const DenseMap &opLatency) { + llvm::MapVector opToStage; + // Find terminator for later reference + auto terminator = cast(forOp.getBody()->getTerminator()); + // Determine all operations that have a non-zero latency + SmallVector latOps; + for (auto &op : forOp.getBody()->without_terminator()) { + if (opLatency.count(&op)) + latOps.push_back(&op); + } + // If no latency ops, nothing to schedule + if (latOps.empty()) + return CoarseSchedule(0); + + // Compute the longest path to the yield for each operation reachable + // from any latency operation. + DenseMap distance; + std::function computeDistance = [&](Operation *op) -> int { + auto it = distance.find(op); + if (it != distance.end()) + return it->second; + // Compute max distance among all users that are inside the loop body + int maxDist = -1; + for (Operation *user : op->getUsers()) { + // Only consider users inside the same block and not the terminator + Operation *inBlockUser = forOp.getBody()->findAncestorOpInBlock(*user); + if (!inBlockUser || inBlockUser == terminator) + continue; + int distUser = computeDistance(inBlockUser); + if (distUser > maxDist) + maxDist = distUser; + } + int lat = 0; + if (opLatency.count(op)) + lat = opLatency.lookup(op); + // If an op has no users (maxDist == -1) but has latency, we include its + // latency otherwise it contributes 0 to the distance. + int d = lat + (maxDist < 0 ? 0 : maxDist); + distance[op] = d; + return d; + }; + + // Compute distances for all latency-starting ops + int maxDistance = 0; + for (Operation *latOp : latOps) { + int d = computeDistance(latOp); + if (d > maxDistance) + maxDistance = d; + } + + // Assign stage to each op reachable from a latency op + for (auto [op, dist] : distance) { + // We only schedule ops that are downstream of a latency op + // (had a non-negative distance due to a latency op). + if (dist >= 0) + opToStage[op] = maxDistance - dist; + } + + auto stages = llvm::make_second_range(opToStage); + int maxStage = *llvm::max_element(stages); + CoarseSchedule schedule(maxStage + 1); + SmallVector clusters(maxStage + 1); + for (int i = 0; i <= maxStage; i++) { + clusters[i] = schedule.clusters.newAtBack(); + } + // Assign ops to the clusters in reverse-stage order; + // ops with higher stage numbers are assigned first. This way we will + // end up with roughly reverse program order in the clusters. + for (auto [op, stage] : opToStage) + schedule.insert(op, stage, clusters[maxStage - stage]); + + // Move `scf.if` ops in the current schedule (forward slice of the latency + // ops) into a new epilogue cluster at the end of the schedule, pushing them + // as close to the end of the loop body as possible. + CoarseSchedule::Cluster epilogue = schedule.clusters.newAtBack(); + for (auto [op, stage] : opToStage) { + auto ifOp = dyn_cast(op); + if (!ifOp) + continue; + // If the `scf.if` op itself is a latency op, skip it. + if (opLatency.contains(ifOp)) + continue; + // Ensure this does not create scheduling conflicts by ensuring the forward + // slice of the `scf.if` does not contain ops that are already scheduled, as + // this will cause the `scf.if` to be scheduled after its dependents. + SetVector slice; + getForwardSlice(ifOp, &slice); + if (llvm::any_of(slice, [&](Operation *op) { return opToStage.count(op); })) + continue; + schedule.insert(ifOp, stage, epilogue); + } + + return schedule; +} + +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +void scheduleDistanceOneDependencies(scf::ForOp forOp, + CoarseSchedule &schedule) { + int numStages = schedule.getNumStages(); + + // Mapping from the cluster to the cluster before it. + DenseMap dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + if (auto arg = dyn_cast(operand)) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op.getBlock()) { + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (defOp && schedule.count(defOp) == 0) { + if (isa(defOp)) { + // Exception: Schedule loads with a distance of 1 together + // with the current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, + /*includeArg=*/true, + /*insertIfEarlier=*/true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], + /*includeArg=*/true, + /*includeIfEarlier=*/true); + } + } + } + } + } + } +} + +// Schedule the prologue and epilogue `if` ops in the loop, pushing them as +// close to the loop boundaries as possible. Return the cluster after the +// prologue (or the beginning of the loop if there is no prologue). +CoarseSchedule::Cluster schedulePrologueAndEpilogue(scf::ForOp forOp, + CoarseSchedule &schedule) { + int numStages = schedule.getNumStages(); + CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + // Look for the IfOp that is in the backward slice any of the currently + // scheduled ops and put it at the beginning of the loop. + DenseMap ifsToStage; + // Go stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { + if (stage_ != stage) + continue; + SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.omitUsesFromAbove = false; + getBackwardSlice((Operation *)op, &backwardSlice, opt); + + for (auto op : backwardSlice) { + if (auto ifOp = dyn_cast(op)) { + ifsToStage.insert({ifOp, stage}); + } + } + } + } + if (!ifsToStage.empty()) { + CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + for (auto [ifOp, stage] : ifsToStage) { + schedule.insertIfAbsent(ifOp, stage, prologueCluster); + } + } + + // Other IfOps should be pushed to the end. + CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + for (auto &op : forOp.getBody()->without_terminator()) { + if (auto ifOp = dyn_cast(op)) { + if (ifsToStage.count(ifOp) == 0) { + schedule.insertIfAbsent(ifOp, numStages - 1, + epilogueCluster); // after prefetch extracts + } + } + } + return afterPrologue; +} + +void scheduleRemainingToLastStage(scf::ForOp forOp, CoarseSchedule &schedule, + CoarseSchedule::Cluster afterPrologue) { + int numStages = schedule.getNumStages(); + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) { + opToCluster[&op] = afterPrologue; + } + } + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == numStages - 1) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + CoarseSchedule::Cluster userCluster = opToCluster[user]; + CoarseSchedule::Cluster opCluster; + if (schedule.count(op)) + opCluster = schedule[op].second; + else + opCluster = opToCluster[op]; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, numStages - 1, cluster); + } +} + +void scheduleLoop(scf::ForOp forOp, + const DenseMap &opLatency) { + if (!hasLatenciesAssigned(forOp, opLatency) || !isSafeToPipeline(forOp)) + return; + // Based on the latencies, schedule the key ops to the stages. + CoarseSchedule schedule = scheduleKeyOps(forOp, opLatency); + if (schedule.empty()) + return; + LLVM_DEBUG({ + schedule.serialize(forOp); + DBGS() << "Initial coarse schedule:\n" << forOp << "\n"; + }); + // Schedule the dependencies + CoarseSchedule::Cluster afterPrologue = + schedulePrologueAndEpilogue(forOp, schedule); + LLVM_DEBUG({ + schedule.serialize(forOp); + DBGS() << "Coarse schedule with prologue and epilogue:\n" << forOp << "\n"; + }); + scheduleDependencies(forOp, schedule); + LLVM_DEBUG({ + schedule.serialize(forOp); + DBGS() << "Coarse schedule with dependencies:\n" << forOp << "\n"; + }); + scheduleDistanceOneDependencies(forOp, schedule); + LLVM_DEBUG({ + schedule.serialize(forOp); + DBGS() << "Coarse schedule with dist 1:\n" << forOp << "\n"; + }); + scheduleRemainingToLastStage(forOp, schedule, afterPrologue); + LLVM_DEBUG({ + schedule.serialize(forOp); + DBGS() << "Final coarse schedule:\n" << forOp << "\n"; + }); + + // Write the schedule to the IR + schedule.serialize(forOp); +} + +} // namespace + +void scheduleLoops(ModuleOp moduleOp) { + DenseMap opLatency = deserializeLatencies(moduleOp); + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + if (loops.empty()) + return; + for (auto forOp : loops) { + scheduleLoop(forOp, opLatency); + } +} + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp new file mode 100644 index 000000000..dbf8e01fe --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -0,0 +1,176 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create async operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop. +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUPIPELINE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static void pipelineWgmma(ModuleOp moduleOp) { + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + + for (scf::ForOp forOp : loops) { + mlir::triton::asyncLaunchDots(forOp); + } +} + +static void expandLoops(ModuleOp moduleOp) { + SmallVector loops; + moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + for (scf::ForOp forOp : loops) { + CoarseSchedule schedule; + if (failed(schedule.deSerialize(forOp))) { + continue; + } + + std::vector> finalSchedule = + schedule.createFinalSchedule(forOp); + triton::PipeliningOption options; + options.supportDynamicLoops = true; + options.peelEpilogue = false; + options.predicateFn = triton::predicateOp; + options.getScheduleFn = + [&](scf::ForOp forOp, + std::vector> &schedule) { + schedule = finalSchedule; + }; + IRRewriter rewriter(forOp); + FailureOr newForOp = + triton::pipelineForLoop(rewriter, forOp, options); + } +} + +static void removeAttributes(ModuleOp moduleOp) { + moduleOp->walk([&](Operation *op) { + op->removeAttr(mlir::triton::kLoopStageAttrName); + op->removeAttr(mlir::triton::kLoopClusterAttrName); + op->removeAttr(mlir::triton::kScheduledMaxStageAttrName); + }); +} + +struct PipelinePass : public impl::TritonGPUPipelineBase { + + using impl::TritonGPUPipelineBase::TritonGPUPipelineBase; + + int getNumStagesOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) + return numStages; + return mlir::cast( + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); + } + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + // Go over the interesting ops and assign latencies (based on the + // numStages) to the them, trying to populate the allowed stages. This + // step will be at some point extracted to separate pass that will be run + // only for loops missing the latency information. + assignLatencies(moduleOp, numStages); + if (dumpIntermediateSteps) { + llvm::dbgs() << "// -----// SoftwarePipeliner internal IR Dump After: " + "AssignLatencies\n" + << moduleOp << "\n\n\n"; + } + // numStages should not be used below this point. We should know + // everything based on the assigned stages + + // Schedule the loops + scheduleLoops(moduleOp); + if (dumpIntermediateSteps) { + llvm::dbgs() << "// -----// SoftwarePipeliner internal IR Dump After: " + "ScheduleLoops\n" + << moduleOp << "\n\n\n"; + } + + // Transform the loop by introducing async operations to prepare it for + // pipeline expansion. + lowerLoops(moduleOp); + if (dumpIntermediateSteps) { + llvm::dbgs() + << "// -----// SoftwarePipeliner internal IR Dump After: LowerLoops\n" + << moduleOp << "\n\n\n"; + } + + // Apply the pipeline expansion. + expandLoops(moduleOp); + if (dumpIntermediateSteps) { + llvm::dbgs() << "// -----// SoftwarePipeliner internal IR Dump After: " + "ExpandLoops\n" + << moduleOp << "\n\n\n"; + } + + // Cleanup the IR from the pipeline attributes. + removeAttributes(moduleOp); + + pipelineWgmma(moduleOp); + + // There is a hard dependency between load pipelining and the TC05MMA + // pipelining. We can pipeline the TC05MMA only after the loads are + // pipelined and buffers are allocated. + mlir::triton::pipelineTC05MMALoops(moduleOp, 2); + + // schedule the waits + mlir::triton::updateWaits(getOperation()); + + // Clean up arithmetic before applying the next level of pipelining to + // simplify the IR. + auto arithDialect = + getOperation().getContext()->getLoadedDialect(); + RewritePatternSet patterns(getOperation().getContext()); + arithDialect->getCanonicalizationPatterns(patterns); + if (applyPatternsGreedily(getOperation(), std::move(patterns)).failed()) + return signalPassFailure(); + + { + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + for (scf::ForOp forOp : loops) { + mlir::triton::pipelineTMAStores(forOp); + } + + for (scf::ForOp forOp : loops) { + mlir::triton::pipelineMMAWithScaledAcc(forOp); + } + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TC05MMAPipeline.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TC05MMAPipeline.cpp new file mode 100644 index 000000000..8d8c45998 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TC05MMAPipeline.cpp @@ -0,0 +1,939 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace { + +const char *kPipelineStageAttrName = "triton.pipeline_stage"; +const char *kPipelineAttrName = "triton.pipeline"; + +// Utils: +void replaceAllUsesDominatedBy(Operation *domOp, Value newValue, + Value oldValue) { + DominanceInfo domOpInfo(domOp->getParentOp()); + oldValue.replaceUsesWithIf(newValue, [&](OpOperand &use) { + return domOpInfo.properlyDominates(domOp, use.getOwner()); + }); +} + +void annotateWithPipelineStage(IRRewriter &builder, Operation *op, int stage) { + op->setAttr(kPipelineStageAttrName, + IntegerAttr::get(builder.getI32Type(), stage)); +} + +int getPipelineStage(Operation *op) { + return op->getAttrOfType(kPipelineStageAttrName).getInt(); +} + +struct MMAInfo { + struct AccOverridePoint { + Operation *op; + Value condition = nullptr; + Value initValue = nullptr; + int distance = 0; + bool isFlag = false; + }; + + ttng::TMEMAllocOp accAlloc; // Directly precedes the dot, allocating tmem + // for the accumulator + ttng::TMEMLoadOp + accLoad; // Directly follows the dot, loading accumulator from tmem + std::optional accDef; + std::optional yieldArgNo; + bool accIsMultiBuffered; + + Value phase = nullptr; + Value barrierIdx = nullptr; + Value accInsertIdx = nullptr; + Value accExtractIdx = nullptr; + Value barrierAlloc = nullptr; +}; + +// Returns the TMEMAllocOp and TMEMLoadOp that are used to allocate and load the +// accumulator for the given MMA operation. The TMEMAllocOp and TMEMLoadOp must +// be in the same region as the MMA operation. +std::optional> +getTMemAllocAndLoad(ttng::MMAv5OpInterface mmaOp) { + auto acc = mmaOp->getOperand(2).getDefiningOp(); + if (!acc || acc->getParentRegion() != mmaOp->getParentRegion()) { + return std::nullopt; + } + for (auto user : acc->getUsers()) { + if (auto load = dyn_cast(user)) { + if (load->getParentRegion() == mmaOp->getParentRegion()) { + return std::make_pair(acc, load); + } + } + } + return std::nullopt; +} + +// Check if the accumulator is being used by the same MMA in the next iteration. +// If so, return the yield argument number that the accumulator is being used +// as. Also, check if accumulator has runtime divergent uses - uses that may not +// be known at the compile time. +std::optional trackAccChain(scf::ForOp forOp, ttng::TMEMLoadOp accDef, + ttng::TMEMAllocOp accAlloc, + bool &hasDivergentUses) { + hasDivergentUses = false; + struct UseInfo { + Value value = nullptr; + std::optional yieldArgNo = std::nullopt; + bool divergentUse = false; + }; + SmallVector queue; + std::optional yieldArgNo = std::nullopt; + queue.push_back({accDef.getResult(), std::nullopt, false}); + while (!queue.empty()) { + UseInfo info = queue.pop_back_val(); + for (auto &use : info.value.getUses()) { + if (auto yieldOp = dyn_cast(use.getOwner())) { + if (yieldOp->getParentOp() == forOp) { + queue.push_back({forOp.getRegionIterArg(use.getOperandNumber()), + use.getOperandNumber(), true}); // divergent use + continue; + } + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + queue.push_back({ifOp.getResult(use.getOperandNumber()), + info.yieldArgNo, true}); // divergent use + continue; + } + assert(0 && "Unexpected use of accumulator"); + } else if (auto selectOp = dyn_cast(use.getOwner())) { + queue.push_back({selectOp.getResult(), info.yieldArgNo, true}); + } else if (use.getOwner() == accAlloc) { + yieldArgNo = info.yieldArgNo; + } else { + // Op other than yield or accAlloc. Mark as divergent use if + // we had to go through selectOp or ifOp. + hasDivergentUses = info.divergentUse; + } + } + } + return yieldArgNo; +} + +SmallVector getDirectAccUses(ttng::TMEMLoadOp accDef) { + SmallVector accUses; + for (auto user : accDef.getResult().getUsers()) { + if (!isa(user) && !isa(user)) { + accUses.push_back(user); + } + } + return accUses; +} + +std::optional +getAccOverridePointInLoop(scf::ForOp forOp, ttng::TMEMAllocOp accUse, + ttng::TMEMLoadOp accDef) { + MMAInfo::AccOverridePoint accOverridePoint; + accOverridePoint.isFlag = false; + DenseSet seen; + Value v = accUse.getSrc(); + if (v == nullptr) { + // Uninitialized accumulator means unused accumulator + accOverridePoint.op = accUse; + return accOverridePoint; + } + int dist = 0; + while (auto blockArg = dyn_cast(v)) { + if (!seen.insert(v).second) { + return std::nullopt; + } + assert(blockArg.getOwner() == forOp.getBody()); + auto yieldOp = cast(blockArg.getOwner()->getTerminator()); + v = yieldOp.getOperand(blockArg.getArgNumber() - 1); + dist++; + } + if (!v.getDefiningOp()) { + return std::nullopt; + } + accOverridePoint.distance = dist; + bool thenOverrides = false; + if (auto selectOp = dyn_cast(v.getDefiningOp())) { + accOverridePoint.op = selectOp; + bool trueIsConst = + (selectOp.getTrueValue().getDefiningOp() != nullptr); + bool falseIsConst = + (selectOp.getFalseValue().getDefiningOp() != + nullptr); + if (trueIsConst && falseIsConst) { + // Both values are constant, so the select overrides unconditionally + accOverridePoint.initValue = v; + return accOverridePoint; + } else if (trueIsConst) { + accOverridePoint.initValue = selectOp.getTrueValue(); + thenOverrides = true; + } else if (falseIsConst) { + accOverridePoint.initValue = selectOp.getFalseValue(); + thenOverrides = false; + } else { + return std::nullopt; + } + accOverridePoint.condition = selectOp.getCondition(); + if (!thenOverrides) { + IRRewriter builder(selectOp); + Value vTrue = builder.create( + selectOp.getLoc(), builder.getBoolAttr(true)); + accOverridePoint.condition = builder.create( + selectOp.getLoc(), accOverridePoint.condition, vTrue); + } + } else if (v.getDefiningOp() != accDef) { + assert(!isa(v.getDefiningOp()) && + "Expected unconditional override op"); + accOverridePoint.op = v.getDefiningOp(); + accOverridePoint.initValue = v; + } else { + return std::nullopt; + } + + return accOverridePoint; +} + +std::optional +getAccUseFlagFalseInLoop(scf::ForOp forOp, Value useAccFlagUse) { + DenseSet seen; + Value v = useAccFlagUse; + int dist = 0; + while (auto blockArg = dyn_cast(v)) { + if (!seen.insert(v).second) { + return {}; + } + assert(blockArg.getOwner() == forOp.getBody()); + auto yieldOp = cast(blockArg.getOwner()->getTerminator()); + v = yieldOp.getOperand(blockArg.getArgNumber() - 1); + dist++; + } + if (!v.getDefiningOp() || !forOp->isAncestor(v.getDefiningOp())) { + return std::nullopt; + } + assert(v.getType().isInteger(1)); + + IRRewriter builder(v.getDefiningOp()->getNextNode()); + MMAInfo::AccOverridePoint accOverridePoint; + accOverridePoint.isFlag = true; + accOverridePoint.distance = dist; + Location loc = v.getDefiningOp()->getLoc(); + auto vTrue = + builder.create(loc, builder.getBoolAttr(true)); + accOverridePoint.op = v.getDefiningOp(); + accOverridePoint.condition = builder.create(loc, v, vTrue); + + return accOverridePoint; +} + +std::optional +getAccOverrideOrFlagFalseInLoop(scf::ForOp forOp, + ttng::MMAv5OpInterface mmaOp) { + auto tmemAllocAndLoad = getTMemAllocAndLoad(mmaOp); + assert(tmemAllocAndLoad.has_value() && "Expected tmem alloc and load"); + auto [accAlloc, accLoad] = tmemAllocAndLoad.value(); + auto accOverridePoint = getAccOverridePointInLoop(forOp, accAlloc, accLoad); + + if (!accOverridePoint.has_value()) { + auto useAccFlag = mmaOp.useAccumulator(); + accOverridePoint = getAccUseFlagFalseInLoop(forOp, useAccFlag); + } + + return accOverridePoint; +} + +ttng::TMEMAllocOp createTMemAlloc(IRRewriter &builder, + ttng::TMEMAllocOp oldTMemAllocOp, + bool multiBufferred, int numStages) { + Location loc = oldTMemAllocOp.getLoc(); + auto oldRetType = oldTMemAllocOp.getType(); + SmallVector shape = {oldRetType.getShape().begin(), + oldRetType.getShape().end()}; + if (multiBufferred) { + shape.insert(shape.begin(), numStages); + } + Type accMemDescType = triton::gpu::MemDescType::get( + shape, oldRetType.getElementType(), oldRetType.getEncoding(), + oldRetType.getMemorySpace(), /*mutableMemory=*/true); + return builder.create(oldTMemAllocOp.getLoc(), + accMemDescType, nullptr); +} + +void createInitStore(IRRewriter &builder, ttng::TMEMAllocOp allocOp, + Value initVal, bool multiBufferred) { + Value bufferSlice = allocOp; + if (multiBufferred) { + bufferSlice = triton::createSingleBufferView(builder, allocOp, 0); + } + Value vTrue = builder.create(allocOp.getLoc(), 1, 1); + builder.create(allocOp.getLoc(), bufferSlice, initVal, + vTrue); +} + +Operation *findNearestCommonDominator(ArrayRef ops, + DominanceInfo &domInfo) { + if (ops.size() == 0) { + return nullptr; + } + if (ops.size() == 1) { + return ops[0]; + } + llvm::SmallPtrSet blocks; + for (auto op : ops) { + blocks.insert(op->getBlock()); + } + Block *domBlock = domInfo.findNearestCommonDominator(blocks); + if (domBlock == nullptr) { + return nullptr; + } + SmallVector ancestorOps; + for (auto op : ops) { + ancestorOps.push_back(domBlock->findAncestorOpInBlock(*op)); + } + Operation *dom = ancestorOps[0]; + for (unsigned i = 1; i < ops.size(); i++) { + if (ancestorOps[i]->isBeforeInBlock(dom)) { + dom = ancestorOps[i]; + } + } + return dom; +} + +void updateAccUsesInLoop(IRRewriter &builder, scf::ForOp forOp, MMAInfo &info, + ttng::TMEMAllocOp newAlloc, int numStages) { + DominanceInfo domInfo(forOp); + SmallVector directUses = getDirectAccUses(info.accLoad); + if (!directUses.empty()) { + Operation *domOp = findNearestCommonDominator(directUses, domInfo); + assert(domOp != nullptr && "Could not find a common dominator"); + builder.setInsertionPoint(domOp); + Value extractSlice = newAlloc; + if (info.accIsMultiBuffered) { + extractSlice = + triton::createSingleBufferView(builder, newAlloc, info.accExtractIdx); + } + auto load = builder.create( + domOp->getLoc(), info.accLoad.getType(), extractSlice); + // If accumulator is multi-buffered, it is implicit that we put the load + // in the last stage. + int pipelineStage = info.accIsMultiBuffered ? numStages - 1 : 0; + annotateWithPipelineStage( + builder, forOp.getBody()->findAncestorOpInBlock(*load.getOperation()), + pipelineStage); + for (auto user : directUses) { + user->replaceUsesOfWith(info.accLoad, load); + } + } +} + +void updateAccUsesOutsideLoop(IRRewriter &builder, scf::ForOp forOp, + const MMAInfo &info, ttng::TMEMAllocOp newAlloc, + int extractIdxArgNo) { + builder.setInsertionPointAfter(forOp); + if (!info.yieldArgNo.has_value()) { + return; + } + if (forOp.getResult(info.yieldArgNo.value()).getUsers().empty()) { + return; + } + Value bufferSlice = newAlloc; + if (info.accIsMultiBuffered) { + Value extractIdxVal = forOp.getResult(extractIdxArgNo); + bufferSlice = + triton::createSingleBufferView(builder, newAlloc, extractIdxVal); + } + auto load = builder.create( + forOp.getLoc(), forOp.getResult(info.yieldArgNo.value()).getType(), + bufferSlice); + forOp.getResult(info.yieldArgNo.value()).replaceAllUsesWith(load); +} + +void updateAccDefsInLoop(IRRewriter &builder, scf::ForOp forOp, MMAInfo &info, + ttng::TMEMAllocOp newAlloc, int numStages) { + assert(info.accDef.has_value()); + Operation *def = info.accDef->op; + Value condition = info.accDef->condition; + Location loc = def->getLoc(); + + builder.setInsertionPointAfter(def); + if (condition && condition.getDefiningOp()) { + builder.setInsertionPointAfter(condition.getDefiningOp()); + } + // if insertion point is outside the loop body, move it inside + if (builder.getBlock() != forOp.getBody()) { + builder.setInsertionPointAfter(&forOp.getBody()->front()); + } + Value numStagesVal = builder.create(loc, numStages, 32); + + Value newInsertIdx = builder.create( + loc, info.accInsertIdx, builder.create(loc, 1, 32)); + Value insWrap = builder.create(loc, arith::CmpIPredicate::eq, + newInsertIdx, numStagesVal); + newInsertIdx = builder.create( + loc, newInsertIdx.getType(), insWrap, + builder.create(loc, 0, 32), newInsertIdx); + if (condition) { + newInsertIdx = + builder.create(loc, newInsertIdx.getType(), condition, + newInsertIdx, info.accInsertIdx); + } + annotateWithPipelineStage(builder, newInsertIdx.getDefiningOp(), 0); + + Value newExtractIdx = builder.create( + loc, info.accExtractIdx, + builder.create(loc, 1, 32)); + auto extWrap = builder.create(loc, arith::CmpIPredicate::eq, + newExtractIdx, numStagesVal); + newExtractIdx = builder.create( + loc, newExtractIdx.getType(), extWrap, + builder.create(loc, 0, 32), newExtractIdx); + if (info.accDef->condition) { + newExtractIdx = builder.create( + loc, newExtractIdx.getType(), info.accDef->condition, newExtractIdx, + info.accExtractIdx); + } + annotateWithPipelineStage(builder, newExtractIdx.getDefiningOp(), 1); + + if (info.accDef->initValue) { + Value bufferSlice = + triton::createSingleBufferView(builder, newAlloc, newInsertIdx); + Value vTrue = builder.create(loc, 1, 1); + auto tmemStore = builder.create( + loc, bufferSlice, info.accDef->initValue, + condition ? condition : vTrue); + annotateWithPipelineStage(builder, tmemStore, 0); + } + + // Always update the for yield with the new insert and extract indices + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield->replaceUsesOfWith(info.accInsertIdx, newInsertIdx); + forYield->replaceUsesOfWith(info.accExtractIdx, newExtractIdx); + + // Only update rest of the uses if the override is dist 0 (the same + // loop iteration) + if (info.accDef->distance == 0) { + replaceAllUsesDominatedBy(newInsertIdx.getDefiningOp(), newInsertIdx, + info.accInsertIdx); + replaceAllUsesDominatedBy(newExtractIdx.getDefiningOp(), newExtractIdx, + info.accExtractIdx); + } + + if (info.accDef->initValue && condition) { + assert(isa(info.accDef->op)); + info.accDef->op->erase(); + } + + info.accInsertIdx = newInsertIdx; + info.accExtractIdx = newExtractIdx; +} + +// Hoist tmem_allocs outside of the loop and update the mma ops to use the +// hoisted tmem allocs. Also, update the acc loads and stores to use the new +// tmem allocs. +void hoistAndUseTMemAlloc(IRRewriter &builder, scf::ForOp forOp, + ttng::MMAv5OpInterface mmaOp, MMAInfo &info, + int numStages) { + builder.setInsertionPoint(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + Value one = builder.create(forOp.getLoc(), 1, 32); + Value numStagesVal = + builder.create(forOp.getLoc(), numStages, 32); + Value vTrue = builder.create(forOp.getLoc(), 1, 1); + + builder.setInsertionPoint(forOp); + ttng::TMEMAllocOp newAlloc = createTMemAlloc( + builder, info.accAlloc, info.accIsMultiBuffered, numStages); + bool chainedAcc = info.yieldArgNo.has_value(); + if (chainedAcc) { + Value accInitValue = forOp.getInitArgs()[info.yieldArgNo.value()]; + createInitStore(builder, newAlloc, accInitValue, info.accIsMultiBuffered); + } + + // Update mma ops to use the hoisted tmem allocs + Value insertSlice = newAlloc; + if (info.accIsMultiBuffered) { + builder.setInsertionPoint(mmaOp); + insertSlice = + triton::createSingleBufferView(builder, insertSlice, info.accInsertIdx); + } + + mmaOp.setAccumulator(insertSlice); + + updateAccUsesInLoop(builder, forOp, info, newAlloc, numStages); + assert(isa(info.accExtractIdx)); + int extractIdxArgNo = + cast(info.accExtractIdx).getArgNumber() - 1; + updateAccUsesOutsideLoop(builder, forOp, info, newAlloc, extractIdxArgNo); + + // Short circuit loop carry value that was holding the accumulator value, + // removing the last reference to the loaded accumulator. + if (info.yieldArgNo.has_value()) { + forOp.getBody()->getTerminator()->setOperand( + info.yieldArgNo.value(), forOp.getInitArgs()[info.yieldArgNo.value()]); + } + + if (info.accIsMultiBuffered) { + updateAccDefsInLoop(builder, forOp, info, newAlloc, numStages); + } + + info.accLoad.erase(); + info.accAlloc.erase(); + info.accAlloc = newAlloc; +} + +// Create multi-buffered barrier allocs and lower the MMA to MMA + wait barrier +void createBarrierAndWaitOps(IRRewriter &builder, scf::ForOp forOp, + ttng::MMAv5OpInterface mmaOp, MMAInfo &info, + int numStages) { + builder.setInsertionPoint(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + Value one = builder.create(forOp.getLoc(), 1, 32); + Value numStagesVal = + builder.create(forOp.getLoc(), numStages, 32); + + info.barrierAlloc = triton::createBarrierAlloc(forOp, numStages); + + Location loc = mmaOp->getLoc(); + builder.setInsertionPoint(mmaOp); + + Value barrierSlice = triton::createSingleBufferView( + builder, info.barrierAlloc, info.barrierIdx); + mmaOp.setBarrier(barrierSlice); + + builder.setInsertionPointAfter(mmaOp); + auto waitOp = + builder.create(loc, barrierSlice, info.phase); + annotateWithPipelineStage(builder, waitOp, numStages - 1); + + Value newBarrierIdx = + builder.create(loc, info.barrierIdx, one); + auto barWrap = builder.create(loc, arith::CmpIPredicate::eq, + newBarrierIdx, numStagesVal); + + // New barrierIdx and phase are in the first stage, so they can be used by + // the ops that are ahead of them in either order or stages. + newBarrierIdx = builder.create(loc, newBarrierIdx.getType(), + barWrap, zero, newBarrierIdx); + replaceAllUsesDominatedBy(newBarrierIdx.getDefiningOp(), newBarrierIdx, + info.barrierIdx); + info.barrierIdx = newBarrierIdx; + annotateWithPipelineStage(builder, info.barrierIdx.getDefiningOp(), 0); + + Value originalPhase = info.phase; + Value newPhase = builder.create( + loc, info.phase.getType(), barWrap, + builder.create(loc, info.phase, one), info.phase); + replaceAllUsesDominatedBy(newPhase.getDefiningOp(), newPhase, info.phase); + info.phase = newPhase; + annotateWithPipelineStage(builder, info.phase.getDefiningOp(), 0); + + // We need to add a barrier before load from the accumulator, if it is in the + // same stage as the dot. + ttng::TMEMLoadOp tmemLoad = nullptr; + SmallVector users = {info.accAlloc->getUsers().begin(), + info.accAlloc->getUsers().end()}; + while (!users.empty()) { + auto user = users.pop_back_val(); + if (isa(user)) { + users.append(user->getUsers().begin(), user->getUsers().end()); + } + if (isa(user) && forOp->isAncestor(user)) { + if (tmemLoad) { + assert(tmemLoad == cast(user) && + "Should have only one tmem load from the accumulator"); + } + tmemLoad = cast(user); + } + } + if (tmemLoad) { + int loadStage = + getPipelineStage(forOp.getBody()->findAncestorOpInBlock(*tmemLoad)); + int mmaOpStage = getPipelineStage(mmaOp); + if (loadStage == mmaOpStage) { + builder.setInsertionPoint(tmemLoad); + auto barrier = + builder.create(loc, barrierSlice, originalPhase); + annotateWithPipelineStage( + builder, forOp.getBody()->findAncestorOpInBlock(*barrier), + mmaOpStage); + } + } +} + +bool isSafeToPipeline(ttng::TCGen5MMAScaledOp scaledDot, scf::ForOp forOp) { + // MMAv5 scaled dot (tcgen05.mma mxf8f6f4) is safe to be pipelined only + // when its scales in TMEM are stored by the TMEMCopy op (tcgen05.cp). + // That condition is equivalent to scale arguments of + // ttng::TCGen5MMAScaledOp being in SMEM during SWP in our convention. + auto isCopiedByTMEMCopy = [&](Value scale) { + auto scaleAlloc = findShmemAlloc(scale); + if (!scaleAlloc || !forOp.isDefinedOutsideOfLoop(scaleAlloc)) + return false; + return true; + }; + + return isCopiedByTMEMCopy(scaledDot.getAScale()) && + isCopiedByTMEMCopy(scaledDot.getBScale()); +} + +// Find MMAs eligible for pipelining and lower them by: +// 1. Hoisting the accumulator allocation outside of the loop. +// 2. Creating a barrier alloc and lowering the MMA to MMA + wait barrier. +// 3. Updating the uses of the accumulator in the loop to use the new tmem +// alloc. +FailureOr preProcessLoopForTC05MMAPipelining(scf::ForOp forOp, + int numStages) { + SmallVector mmaOps; + forOp.walk([&](Operation *op) { + // Skip MMA nested in another forOp + if (op->getParentOfType() == forOp) { + if (isa(op)) { + mmaOps.push_back(op); + } else if (auto scaledDot = dyn_cast(op)) { + if (isSafeToPipeline(scaledDot, forOp)) { + mmaOps.push_back(op); + } else { + op->emitWarning("Skipping pipelining of an MMAv5 scaled op because " + "TMEM copy is not used."); + } + } + } + }); + + // Temporarily disable mma pipelining if there are more than one mmaOp in the + // loop. This is a workaround for difficult to solve scheduling issues with + // loads feeding into non-0 stage ops. + if (mmaOps.empty() || mmaOps.size() > 1) { + return failure(); + } + + mmaOps = getMMAsWithMultiBufferredOperands(forOp, mmaOps); + + if (mmaOps.empty()) { + return failure(); + } + + IRRewriter builder(forOp->getContext()); + for (auto op : mmaOps) { + // Avoid pipelining if in the backward slice of the mmaOp there is an + // operation that is already assigned a stage, as it would make the pipeline + // deeper than we are prepared for. + auto mmaOp = cast(op); + SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + getBackwardSlice(mmaOp, &backwardSlice, opt); + if (llvm::any_of(backwardSlice, [&](Operation *op) { + return op->hasAttr(kPipelineStageAttrName); + })) { + continue; + } + + auto allocAndLoadOpt = getTMemAllocAndLoad(mmaOp); + if (!allocAndLoadOpt) { + continue; + } + auto [accAlloc, accLoad] = allocAndLoadOpt.value(); + bool hasDivergentUses = false; + std::optional yieldArgNo = + trackAccChain(forOp, accLoad, accAlloc, hasDivergentUses); + if (hasDivergentUses) { + // If we can't tell for sure that the value is coming from the mma + // accumulator, skip. + continue; + } + if (yieldArgNo.has_value()) { + int accInitArgNo = + cast(accAlloc.getSrc()).getArgNumber() - 1; + assert(yieldArgNo.value() == accInitArgNo); + } + + std::optional accOverridePoint = + getAccOverrideOrFlagFalseInLoop(forOp, mmaOp); + + if (accOverridePoint.has_value() && accOverridePoint->distance > 1) { + // We only support an override up to 1 iteration back. + continue; + } + + SmallVector accUses = getDirectAccUses(accLoad); + DominanceInfo domOpInfo(forOp); + Operation *newAccLoadInsertPoint = + findNearestCommonDominator(accUses, domOpInfo); + // Check pipelining and multi-buffering constraints + // 1. Really needs multibuffering - if the acc is used unconditionally in + // the loop, or under different conditions. If we cannot multibuffer in this + // case, we may as well not pipeline at all, as we will have to wait after + // the dot in every loop iteration. + scf::IfOp topLevelIf = + newAccLoadInsertPoint + ? dyn_cast(forOp.getBody()->findAncestorOpInBlock( + *newAccLoadInsertPoint)) + : nullptr; + bool requiresMultiBuffer = accUses.size() > 0 && !topLevelIf; + // If we override the acc in the loop, it is generally hard to handle it + // without multibuffering. We make an exception if it not a physical + // override of a value, but just setting a flag that acc is not used. In + // this case we don't need different buffer to store init value. + requiresMultiBuffer |= + accOverridePoint.has_value() && !accOverridePoint->isFlag; + + // 2. If the acc is not owerwritten in the loop (by op other than the dot), + // it cannot be multi-buffered. This is because the overwrite is the only + // way to initialize next buffer without incurring a copy. + bool canMultiBuffer = accOverridePoint.has_value() && + !mlir::triton::getDisallowAccMultiBuffer(forOp); + if (requiresMultiBuffer && !canMultiBuffer) { + continue; + } + + MMAInfo mmaInfo = {.accAlloc = accAlloc, + .accLoad = accLoad, + .accDef = accOverridePoint, + .yieldArgNo = yieldArgNo, + .accIsMultiBuffered = canMultiBuffer}; + + builder.setInsertionPoint(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + + // Update for loop with new arguments + SmallVector newOperands; + const int argsPerMMA = 4; + newOperands.push_back(zero); // phase + newOperands.push_back(zero); // barrierIdx + newOperands.push_back(zero); // accInsertIdx + newOperands.push_back(zero); // accExtractIdx + assert(newOperands.size() == argsPerMMA); + + int firstNewOperandIndex = forOp.getInitArgs().size(); + auto newForOp = replaceForOpWithNewSignature(builder, forOp, newOperands); + forOp.erase(); + forOp = newForOp; + + mmaInfo.phase = forOp.getRegionIterArg(firstNewOperandIndex + 0); + mmaInfo.barrierIdx = forOp.getRegionIterArg(firstNewOperandIndex + 1); + mmaInfo.accInsertIdx = forOp.getRegionIterArg(firstNewOperandIndex + 2); + mmaInfo.accExtractIdx = forOp.getRegionIterArg(firstNewOperandIndex + 3); + + SmallVector newYieldOperands; + newYieldOperands.push_back(mmaInfo.phase); + newYieldOperands.push_back(mmaInfo.barrierIdx); + newYieldOperands.push_back(mmaInfo.accInsertIdx); + newYieldOperands.push_back(mmaInfo.accExtractIdx); + + appendToForOpYield(forOp, newYieldOperands); + + annotateWithPipelineStage(builder, mmaOp, 0); + hoistAndUseTMemAlloc(builder, forOp, mmaOp, mmaInfo, numStages); + createBarrierAndWaitOps(builder, forOp, mmaOp, mmaInfo, numStages); + + // Invalidate and dealloc barrier + builder.setInsertionPointAfter(forOp); + Location loc = mmaOp->getLoc(); + for (int i = 0; i < numStages; i++) { + Value barrierView = + triton::createSingleBufferView(builder, mmaInfo.barrierAlloc, i); + builder.create(loc, barrierView); + } + builder.create(loc, mmaInfo.barrierAlloc); + } + + return forOp; +} + +bool insertUsersOfOp(tt::CoarseSchedule &coarseSchedule, Operation *op, + int stage, tt::CoarseSchedule::Cluster cluster) { + bool changed = false; + for (auto user : op->getUsers()) { + // Let wait barriers be scheduled based on the stage of async op it waits + // for. + if (!isa(user) && coarseSchedule.count(user) == 0) { + changed = true; + coarseSchedule.insert(user, stage, cluster); + insertUsersOfOp(coarseSchedule, user, stage, cluster); + } + } + return changed; +} + +bool getTC05MMASchedule(scf::ForOp &forOp, int numStages, + tt::PipeliningOption &options) { + tt::CoarseSchedule coarseSchedule(numStages); + tt::CoarseSchedule::Cluster cluster = coarseSchedule.clusters.newAtFront(); + for (auto &op : forOp.getBody()->without_terminator()) { + if (op.hasAttr(kPipelineStageAttrName)) { + int stage = + op.getAttrOfType(kPipelineStageAttrName).getInt(); + coarseSchedule.insert(&op, stage, cluster); + } + } + + auto scheduleDependencies = [&]() { + bool fixedPoint = false; + while (!fixedPoint) { + fixedPoint = true; + // Schedule upstream dependencies + for (int stage = 0; stage < numStages; stage++) { + for (auto &op : forOp.getBody()->without_terminator()) { + if (coarseSchedule.count(&op) && coarseSchedule[&op].first == stage) { + bool changed = coarseSchedule.insertDepsOfOp(&op, stage, cluster, + /*includeArg=*/false); + fixedPoint &= !changed; + } + } + } + // Schedule downstream dependencies + for (int stage = numStages - 1; stage >= 0; stage--) { + for (auto &op : forOp.getBody()->without_terminator()) { + if (coarseSchedule.count(&op) && coarseSchedule[&op].first == stage) { + bool changed = insertUsersOfOp(coarseSchedule, &op, stage, cluster); + fixedPoint &= !changed; + } + } + } + } + }; + + scheduleDependencies(); + + // Make sure that async loads are scheduled in the same stage they are used. + DenseMap allocToStage; + DenseMap allocToBarrierWait; + for (auto &op : forOp.getBody()->without_terminator()) { + if (auto barrierWait = dyn_cast(op)) { + auto localAlloc = findShmemAlloc(barrierWait.getAlloc()); + assert(localAlloc); + assert(allocToBarrierWait.count(localAlloc) == 0); + allocToBarrierWait[localAlloc] = barrierWait; + continue; + } + if (!coarseSchedule.count(&op)) + continue; + + auto [stage, cluster] = coarseSchedule[&op]; + for (auto arg : op.getOperands()) { + auto memDescTy = dyn_cast(arg.getType()); + if (!memDescTy) + continue; + + auto localAlloc = findShmemAlloc(arg); + if (!localAlloc) + continue; + + allocToStage[localAlloc] = stage; + } + } + + for (auto &op : forOp.getBody()->without_terminator()) { + Value memDesc; + Value barrier; + if (auto copyOp = dyn_cast(op)) { + memDesc = copyOp.getResult(); + } else if (auto copyOp = dyn_cast(op)) { + memDesc = copyOp.getResult(); + barrier = copyOp.getBarrier(); + } else if (auto gatherOp = dyn_cast(op)) { + memDesc = gatherOp.getResult(); + barrier = gatherOp.getBarrier(); + } else if (auto storeOp = dyn_cast(op)) { + memDesc = storeOp.getSrc(); + } else if (auto scatterOp = dyn_cast(op)) { + memDesc = scatterOp.getSrc(); + } else { + continue; + } + auto localAlloc = findShmemAlloc(memDesc); + assert(localAlloc); + int stage = allocToStage[localAlloc]; + coarseSchedule.insert(&op, stage, cluster); + + // Schedule any barrier wait in the same stage as well, otherwise we will + // change the loop distance to the wait. + if (!barrier) + continue; + auto barrierAlloc = findShmemAlloc(barrier); + assert(barrierAlloc); + auto waitOp = allocToBarrierWait[barrierAlloc]; + // NOTE: barriers can be grouped onto multiple loads, so schedule into the + // eariest stage where the result is used. This means we reduce the distance + // between the tma issue and wait, but it is at least correct. + coarseSchedule.insertMinimum(waitOp, stage, cluster); + } + + scheduleDependencies(); + + // Schedule everything else to stage 0 + for (auto &op : forOp.getBody()->without_terminator()) { + op.removeAttr(kPipelineStageAttrName); + if (coarseSchedule.count(&op) == 0) { + coarseSchedule.insert(&op, 0, cluster); + } + } + + std::vector> schedule = + coarseSchedule.createFinalSchedule(forOp); + + options.getScheduleFn = + [schedule](scf::ForOp forOp, + std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = tt::predicateOp; + options.supportDynamicLoops = true; + + return true; +} + +} // namespace + +void mlir::triton::pipelineTC05MMALoops(ModuleOp module, int numStages, + bool disableExpander) { + SmallVector forOps; + module->walk([&](scf::ForOp forOp) { forOps.push_back(forOp); }); + + for (auto forOp : forOps) { + FailureOr newForOp = + preProcessLoopForTC05MMAPipelining(forOp, numStages); + if (succeeded(newForOp)) { + (*newForOp)->setAttr(kPipelineAttrName, + UnitAttr::get(module.getContext())); + } + } + // Run canonicalization to clean up the short-circuited loop carried values. + mlir::RewritePatternSet patterns(module.getContext()); + scf::ForOp::getCanonicalizationPatterns(patterns, module.getContext()); + if (applyPatternsGreedily(module, std::move(patterns)).failed()) { + llvm::errs() << "Failed to canonicalize the module\n"; + return; + } + + if (!disableExpander) { + SmallVector loops; + module->walk([&](scf::ForOp forOp) { + if (forOp->getAttr(kPipelineAttrName)) + loops.push_back(forOp); + }); + + for (auto forOp : loops) { + mlir::triton::PipeliningOption options; + bool foundSchedule = getTC05MMASchedule(forOp, /*numStages=*/2, options); + assert(foundSchedule && "Failed to find a schedule for TC05MMA"); + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + FailureOr newForOp = + mlir::triton::pipelineForLoop(rewriter, forOp, options); + } + } +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp new file mode 100644 index 000000000..1b2c0701b --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -0,0 +1,120 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +struct TMAStore { + Operation *op; + mlir::TypedValue desc; + mlir::TypedValue src; +}; + +static SmallVector getTMAStores(scf::ForOp forOp) { + SmallVector tmaStores; + + forOp.getBody()->walk([&](Operation *op) { + if (auto storeOp = dyn_cast(op)) { + tmaStores.push_back({storeOp, storeOp.getDesc(), storeOp.getSrc()}); + } else if (auto scatterOp = + dyn_cast(op)) { + tmaStores.push_back({scatterOp, scatterOp.getDesc(), scatterOp.getSrc()}); + + // Don't walk into nested loops. + } else if (isa(op)) { + return WalkResult::skip(); + } + return WalkResult::advance(); + }); + + return tmaStores; +} + +static Value createAlloc(scf::ForOp &forOp, const TMAStore &store) { + OpBuilderWithAsyncTaskIds builder(forOp); + RankedTensorType ty = store.src.getType(); + auto order = ttg::getOrder(ty); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + Attribute encoding = ttg::SwizzledSharedEncodingAttr::get( + ty.getContext(), 1, 1, 1, order, ctaLayout); + if (ty.getRank() > 1) { + encoding = ttg::NVMMASharedEncodingAttr::get( + ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType(), + /*fp4Padded*/ false); + } + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()); + Type memdescType = + ttg::MemDescType::get(ty.getShape(), ty.getElementType(), encoding, + sharedMemorySpace, /*mutableMemory*/ true); + Value alloc = builder.createWithAsyncTaskIds( + store.op->getLoc(), memdescType); + return alloc; +} + +static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store, + Value alloc) { + OpBuilderWithAsyncTaskIds builder(store.op); + Location loc = store.op->getLoc(); + RankedTensorType ty = store.src.getType(); + auto order = ttg::getOrder(ty); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + + // Put wait before the local_store make the store truly async. We know + // that we are the only user of the CopyLocalToGlobal. + builder.createWithAsyncTaskIds(loc, 0); + builder.createWithAsyncTaskIds(loc, store.src, alloc); + builder.createWithAsyncTaskIds(loc, false); + Value tmaPtr = + builder.createWithAsyncTaskIds( + loc, store.desc); + if (auto storeOp = dyn_cast(store.op)) { + builder.createWithAsyncTaskIds( + loc, tmaPtr, storeOp.getIndices(), alloc); + } else { + auto scatterOp = cast(store.op); + builder.createWithAsyncTaskIds( + loc, tmaPtr, scatterOp.getXOffsets(), scatterOp.getYOffset(), alloc); + } + + store.op->erase(); +} + +bool mlir::triton::pipelineTMAStores(scf::ForOp forOp) { + SmallVector tmaStores = getTMAStores(forOp); + if (tmaStores.empty()) + return false; + + DenseMap storeToAlloc; + DenseMap, Type>, Value> allocs; + for (const TMAStore &store : tmaStores) { + // Reuse allocations for stores of the same shape and types. This allows + // saving shared memory usage. It is valid since we have a wait 0 before + // every local_store. We could pipeline more aggressively if we didn't + // reuse but there is a tradeoff with shared memory usage. + RankedTensorType srcTy = store.src.getType(); + auto key = std::make_pair(srcTy.getShape(), srcTy.getElementType()); + Value &alloc = allocs[key]; + if (!alloc) { + alloc = createAlloc(forOp, store); + } + storeToAlloc[store.op] = alloc; + } + + for (const TMAStore &store : tmaStores) { + createTMAAsyncCopy(forOp, store, storeToAlloc[store.op]); + } + + // Deallocate shared memory buffers. + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + builder.create(forOp->getLoc(), 0); + for (auto it : storeToAlloc) { + builder.create(forOp->getLoc(), it.second); + } + return true; +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineAssignLatencies.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineAssignLatencies.cpp new file mode 100644 index 000000000..a75099952 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineAssignLatencies.cpp @@ -0,0 +1,29 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTESTPIPELINEASSIGNLATENCIES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct TestPipelineAssignLatencies + : public impl::TritonGPUTestPipelineAssignLatenciesBase< + TestPipelineAssignLatencies> { + using impl::TritonGPUTestPipelineAssignLatenciesBase< + TestPipelineAssignLatencies>::TritonGPUTestPipelineAssignLatenciesBase; + + void runOnOperation() override { assignLatencies(getOperation(), numStages); } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineLowerLoop.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineLowerLoop.cpp new file mode 100644 index 000000000..7602bb476 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineLowerLoop.cpp @@ -0,0 +1,32 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTESTPIPELINELOWERLOOP +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct TestPipelineLowerLoop + : public impl::TritonGPUTestPipelineLowerLoopBase { + using impl::TritonGPUTestPipelineLowerLoopBase< + TestPipelineLowerLoop>::TritonGPUTestPipelineLowerLoopBase; + + void runOnOperation() override { + ModuleOp m = getOperation(); + + lowerLoops(m); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp new file mode 100644 index 000000000..a95688aa0 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp @@ -0,0 +1,31 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTESTPIPELINESCHEDULELOOP +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static const char *kLatencyAttrName = "tt.latency"; + +struct TestPipelineScheduleLoop + : public impl::TritonGPUTestPipelineScheduleLoopBase< + TestPipelineScheduleLoop> { + using impl::TritonGPUTestPipelineScheduleLoopBase< + TestPipelineScheduleLoop>::TritonGPUTestPipelineScheduleLoopBase; + + void runOnOperation() override { scheduleLoops(getOperation()); } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp new file mode 100644 index 000000000..502791d2e --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp @@ -0,0 +1,558 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-wgmma-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#define int_attr(num) builder.getI64IntegerAttr(num) + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +/// Find the minimum number of async_commit_group ops between the wait +/// and the associated async_commit_group. This can be safely used as the wait +/// number. +static int minNumInterleavedCommitOps(Operation *waitOp) { + auto countCommitsBetween = [](Operation *op1, Operation *op2) { + int count = 0; + for (auto op = op1; op != op2; op = op->getNextNode()) { + if (isa(op)) + count++; + // Intentionally skip block ops' children. This will give us + // convervatively low number of insert ops. + } + return count; + }; + + int minCommitNumber = INT_MAX; + + // DFS the def chain of the extract op to find the insert op. On each path + // we calculate the number of async_commit. Then we select the minimum number + // of async_commit ops among all the paths. + std::function minOverHistories = + [&](Value val, Operation *sinkOp, int thisHistorySum) -> int { + if (Operation *defOp = val.getDefiningOp()) { + thisHistorySum += countCommitsBetween(defOp->getNextNode(), sinkOp); + minCommitNumber = std::min(minCommitNumber, thisHistorySum); + return minCommitNumber; + } + if (auto arg = mlir::dyn_cast(val)) { + Block *block = arg.getOwner(); + auto forOp = dyn_cast(block->getParentOp()); + + // Failed to track, return 0 conservatively. + if (!forOp) + return 0; + + Operation *firstForInst = &*forOp.getBody()->begin(); + int insertsBetween = countCommitsBetween(firstForInst, sinkOp); + thisHistorySum += insertsBetween; + if (thisHistorySum >= minCommitNumber) + return minCommitNumber; + + // get the value assigned to the argument coming from outside the loop + Value incomingVal = forOp.getInitArgs()[arg.getArgNumber() - 1]; + int min1 = minOverHistories(incomingVal, forOp, thisHistorySum); + + // get the value assigned to the argument coming from the previous + // iteration + Operation *yieldOp = block->getTerminator(); + Value prevVal = yieldOp->getOperand(arg.getArgNumber() - 1); + int min2 = minOverHistories(prevVal, yieldOp, thisHistorySum); + return std::min(std::min(min1, min2), minCommitNumber); + } + // Failed to track, return 0 conservatively. + return 0; + }; + + if (waitOp->getNumOperands() != 1) + return 0; + Value val = waitOp->getOperand(0); + // If the value resides in a region other than the region of the wait op, then + // the wait op must be in some nested region. Measure the number of commits + // between the definition value and the parent op. + // TODO: We could measure commits in nested regions along the path if + // necessary. + while (waitOp->getParentRegion() != val.getParentRegion()) + waitOp = waitOp->getParentOp(); + int minCommits = minOverHistories(val, waitOp, 0); + return minCommits; +} + +// Look for consecutive wait ops and combine them into a single wait op. +static void +combineRedundantWaitOps(llvm::SmallSetVector &waitOps) { + llvm::MapVector toDelete; + for (auto waitOp : waitOps) { + if (toDelete.count(waitOp)) + continue; + SmallVector waitGroup = {waitOp}; + SmallVector depTokens; + unsigned minWaitNumber = waitOp.getNum(); + Operation *next = waitOp->getNextNode(); + while (next && !isa(next)) { + if (auto nextWait = dyn_cast(next)) { + waitGroup.push_back(nextWait); + minWaitNumber = std::min(minWaitNumber, nextWait.getNum()); + depTokens.append(nextWait.getOperands().begin(), + nextWait.getOperands().end()); + } + next = next->getNextNode(); + } + if (waitGroup.size() == 1) + continue; + OpBuilder builder(waitGroup.front()); + auto newWaitOp = builder.create(waitOp.getLoc(), + depTokens, minWaitNumber); + for (auto waitOp : waitGroup) { + toDelete[waitOp] = newWaitOp; + } + } + for (auto waitOp : toDelete) { + waitOp.first->replaceAllUsesWith(waitOp.second); + waitOp.first->erase(); + } +} + +/// Update wait op number by analyzing the number of async_commit_group ops +/// along all paths. +void mlir::triton::updateWaits(ModuleOp module) { + llvm::SmallSetVector waitOps; + module.walk([&](ttg::AsyncWaitOp waitOp) { + int minNumCommits = minNumInterleavedCommitOps(waitOp); + waitOp.setNum(minNumCommits); + waitOps.insert(waitOp); + }); + combineRedundantWaitOps(waitOps); +} + +// Add the given values as operands of the given wait, and replace all uses of +// the values with the wait. Also adds related MemDesc's to the wait. +// +// Threading %a through the wait transforms +// +// %a = <...> +// (%x', %y') = ttng.async_wait %x, %y +// %b = fn(%a) +// +// into +// +// %a = <...> +// (%x', %y', %a') = ttng.async_wait %x, %y, %a +// %b = fn(%a') +// +// The wait must dominate all uses of the elements of `values`. +// +// In addition to adding each value from `values` to the wait, this function +// also adds some MemDesc's to the wait. The idea is that if you have +// +// %alloc = ttg.local_alloc ... +// %a = ttng.warp_group_dot %alloc +// %a1 = ttng.warp_group_dot_wait %a +// +// then we want the wait to depend on %alloc as well as %a. This extends the +// live range of %alloc, so that it won't be destroyed until after the dot is +// waited on. +// +// Specifically, this function finds all warp_group_dot ops that elements of +// `values` depend on. Then it adds the MemDesc operands of those dots to the +// wait. +static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait, + MutableArrayRef values) { + IRRewriter builder(wait.getContext()); + builder.setInsertionPoint(wait); + + // Operands are only added to the wait through this function, so we can have + // the invariant that the wait has no duplicates. This makes things a bit + // easier below. + size_t origNumOperands = wait.getNumOperands(); + SetVector newOperands(wait.getOperands().begin(), + wait.getOperands().end()); + assert(newOperands.size() == origNumOperands && + "Wait op has duplicate operands."); + + newOperands.insert(values.begin(), values.end()); + + // Find memdefs depended on by `values` through async dot ops. + SmallVector asyncDots; + for (Value v : values) { + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.filter = [&](Operation *op) { + if (auto dot = dyn_cast(op)) { + asyncDots.push_back(dot); + return false; + } + return op->getBlock() == wait->getBlock(); + }; + SetVector slice; + getBackwardSlice(v, &slice, options); + } + + for (ttng::WarpGroupDotOp dot : asyncDots) { + for (Value operand : dot.getOperands()) { + if (isa(operand.getType())) { + newOperands.insert(operand); + } + } + } + + // We can't use replaceWithNewOp because we're changing the number of return + // values in the operation. + auto newWait = builder.create( + wait.getLoc(), llvm::to_vector(newOperands), wait.getPendings()); + + auto dominatedByNewWait = [&](OpOperand &operand) { + auto opInThisBlock = + newWait->getBlock()->findAncestorOpInBlock(*operand.getOwner()); + return opInThisBlock && newWait->isBeforeInBlock(opInThisBlock); + }; + for (int i = 0; i < origNumOperands; i++) { + Value operand = wait.getResult(i); + if (!isa(operand.getType())) + operand.replaceAllUsesWith(newWait.getResult(i)); + } + for (int i = origNumOperands; i < newOperands.size(); i++) { + Value operand = newWait.getOperand(i); + if (!isa(operand.getType())) + operand.replaceUsesWithIf(newWait.getResult(i), dominatedByNewWait); + } + wait->erase(); +} + +// Determines whether a given MMAv3 dot op, represented as ttng.warp_group_dot, +// needs a wait immediately after it. +// +// In PTX, MMAv3 exists only as an asynchronous op. In Triton, we can represent +// MMAv3 ops as either ttng.warp_group_dot {isAsync=True} or ttng.warp_group_dot +// {isAsync=False}. But even if we use ttng.warp_group_dot {isAsync=True}, the +// conservative thing is to make a dot "effectively synchronous" by inserting a +// `ttng.warp_group_dot_wait {pendings=0}` right after it. +// +// We can omit the wait and create a "properly async" dot if all of the +// following are true. +// +// 1. All operands that touch shared memory are multi-buffered, i.e. can't read +// an incomplete value while it's being written asynchronously by a load. +// 1a. If operand A is in registers, these registers cannot be updated +// inside +// the loop. +// **Exception** if the operand is produced by a preceding WGMMA, +// then this op can be properly async. Either the f16 shortcut is +// possible and the WGMMA's can run back-to-back (see rule 3 below), or +// elementwise truncate is needed, in which case the preceding WGMMA is +// not async and a WarpGroupDotWait is inserted right after, which +// guarantees exclusive access to the operand registers. +// +// 2. If the dot is used by any op in the loop, it must be used under an `if`, +// and will be synced with a `wait 0` at the beginning of the `if` block. +// +// 3. During iteration i, between the start of the loop up until the first +// `ttng.warp_group_dot_wait {pendings=0}` op, the result of the dot from +// iteration i-1 is consumed only by other MMAv3 dots as the `c` operand. +// +// This is safe because the following pseudo-PTX is valid: +// +// %accum = warp_group_dot %a1, %b1, %c1 +// %accum = warp_group_dot %a2, %b2, %accum +// +// That is, the second async dot can use the result of the first one without +// an intervening wait. However, the only operation that can legally read +// %accum before the wait is another warp_group_dot, and this only works for +// the `c` operand, not `a` or `b`. See +// https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence +// (ttng::WarpGroupDotOp corresponds to wgmma.fence followed by one or more +// wgmma.async ops, so our understanding is that the two +// ttng::WarpGroupDotOps don't have to correspond to wgmma.async ops with +// the same shapes as specified in the docs, because there's an intervening +// fence.) +// +// If the op can be properly async, this function returns the index of the dot +// in the loop's iter_args. (Rule (2) above ensures this is well-defined.) +// +static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, + scf::ForOp forOp) { + LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp); + + // Rule 1: All shmem operands are multi-buffered. + auto checkOperand = [&](Value operand) { + if (!isa( + cast(operand.getType()).getEncoding())) { + // Rule 1a: Register operands must not be modified within the loop. + // First, check for chained WGMMA as an exception. + if (auto cvt = dyn_cast(operand.getDefiningOp())) { + return isa( + cvt.getSrc().getType().getEncoding()); + } + // And then, do a stricter-than-necessary check for now, that the operand + // is defined outside the loop. + return forOp.isDefinedOutsideOfLoop(operand); + } + + // If it's a shmem operand, it must either be defined outside the loop, or + // come from an MemDescSubview op. Only ConvertLayout and Trans ops are + // allowed in between. + Value transitiveOperand = operand; + while (isa_and_nonnull( + transitiveOperand.getDefiningOp()) || + isa(transitiveOperand)) { + auto blockArg = dyn_cast(transitiveOperand); + if (blockArg && blockArg.getOwner() == forOp.getBody()) { + transitiveOperand = + cast(blockArg.getOwner()->getTerminator()) + .getOperand(blockArg.getArgNumber() - 1); + } else if (Operation *def = transitiveOperand.getDefiningOp()) { + transitiveOperand = def->getOperand(0); + } + } + return forOp.isDefinedOutsideOfLoop(transitiveOperand) || + transitiveOperand.getDefiningOp(); + }; + + // We don't have to call checkOperand on getC() because it's always in + // registers, never in shmem. + assert(isa(dotOp.getC().getType().getEncoding())); + if (!checkOperand(dotOp.getA()) || !checkOperand(dotOp.getB())) { + LDBG("Can't make dot async because shmem operands aren't multi-buffered"); + return std::nullopt; + } + + // Rule 2: The dot cannot be unconditionally used by any op in the loop. + // Uses under `if` are allowed, as can be explicitly synced with a `wait 0`. + int iterArgIdx = -1; + Value iterArg = nullptr; + SmallVector> queue; + for (auto &use : dotOp->getUses()) { + queue.push_back({use.getOwner(), use.getOperandNumber()}); + } + while (!queue.empty()) { + auto [user, argIdx] = queue.pop_back_val(); + if (user->getParentOp() == forOp) { + if (isa(user)) { + if (iterArg) { + // The dot is used by the loop's yield, but we can't have any other + // uses. + LDBG("Can't make dot async because dot is used by multiple ops in " + "the loop."); + return std::nullopt; + } + iterArgIdx = argIdx; + iterArg = forOp.getRegionIterArg(argIdx); + continue; + } + LDBG("Can't make dot async because dot is unconditionally used in the " + "loop."); + return std::nullopt; + } + if (auto ifOp = dyn_cast(user->getParentOp())) { + if (isa(user)) { + // The result is returned by the if, follow it further. + auto uses = ifOp.getResult(argIdx).getUses(); + for (auto &use : uses) { + queue.push_back({use.getOwner(), use.getOperandNumber()}); + } + } + } else { + return std::nullopt; + } + } + + // Rule 3a: Are the only users of the dot's result from iteration i-1 other + // MMAv3 dots? If so, we're done, this dot can be properly async. + if (llvm::all_of(iterArg.getUses(), [&](OpOperand &use) { + return isa(use.getOwner()) && + use.getOperandNumber() == 2; + })) { + return iterArgIdx; + } + + // Rule 3b: Are all users of the dot's result from iteration i-1 after the + // first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be + // properly async, but we have to thread its result from iteration i-1 through + // the wait. + auto waitOps = forOp.getBody()->getOps(); + auto firstWaitOpIter = llvm::find_if( + waitOps, [&](auto waitOp) { return waitOp.getPendings() == 0; }); + if (firstWaitOpIter != waitOps.end() && + llvm::all_of(iterArg.getUsers(), [&](Operation *user) { + assert(forOp->isAncestor(user)); + while (user->getParentOp() != forOp) { + user = user->getParentOp(); + } + return (*firstWaitOpIter)->isBeforeInBlock(user); + })) { + LDBG("MMAv3 dot can be properly async because it follows a " + "warp_group_dot_wait " + "{pendings=0}.\n" + << " wait: " << *firstWaitOpIter << "\n" + << " dot: " << dotOp); + threadValuesThroughWait(*firstWaitOpIter, {iterArg}); + return iterArgIdx; + } + + LDBG("Can't make dot async because its result from i-1 is used by " + "something other than another MMAv3 dot as the `c` operand."); + return std::nullopt; +} + +// If necessary, insert a dot-wait inside the loop, waiting for the results of +// the properly-async dots from iteration i-1 to complete. (We pipeline to +// depth 2, so there are at most 2 copies of each warp_group_dot in flight at a +// time.) +// +// We can skip inserting the wait if we have a `warp_group_dot_wait +// {pendings=0}` somewhere in the loop. To see why, consider: +// +// warp_group_dot +// warp_group_dot; wait 0 // synchronous dot +// warp_group_dot +// warp_group_dot +// +// In this example, there are three properly-async dots, so we'd normally put +// `wait 3` at the end of the loop, meaning "wait until there are 3 or fewer +// pending async dots". But note that when this iteration of the loop +// completes, there are only *two* pending async dots from this iteration, so +// this wait would do nothing. This is true in general, no matter where the +// `wait 0` appears. +static void insertAsyncWarpGroupDotWaitInLoop( + scf::ForOp forOp, + const llvm::MapVector &properlyAsyncDots) { + if (properlyAsyncDots.empty()) + return; + + if (llvm::any_of(forOp.getBody()->getOps(), + [](auto wait) { return wait.getPendings() == 0; })) { + return; + } + + // Insert waits before the users of the properly async dots other than loop + // yield. + for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { + SmallVector uses; + for (auto &use : asyncDot->getUses()) { + if (auto yieldOp = dyn_cast(use.getOwner())) { + continue; + } + uses.push_back(&use); + } + + DenseMap> blockToUsers; + for (auto use : uses) { + auto block = use->getOwner()->getBlock(); + blockToUsers[block].push_back(use->get()); + } + + for (auto [block, users] : blockToUsers) { + OpBuilder builder(block, block->begin()); + auto newWait = builder.create( + asyncDot->getLoc(), ArrayRef{}, 0); + + threadValuesThroughWait(newWait, users); + } + } + + // Add the wait right after the last properly-async dot. This only needs to + // wait for all properly-async dots from the i-1'th iteration to complete, IOW + // we wait until there are most `asyncDots.size()` dots in flight. + // + // (You might want to put the wait at the end of the loop instead of right + // after the last dot, but there could be a load into shmem between the last + // async dot and the end of the loop, and that could clobber memory being used + // by a dot.) + IRRewriter builder(forOp.getContext()); + auto lastAsyncDot = properlyAsyncDots.back().first; + builder.setInsertionPointAfter(lastAsyncDot); + auto wait = builder.create( + lastAsyncDot->getLoc(), + /*inputs=*/ArrayRef{}, properlyAsyncDots.size()); + + // Thread the results of the async dots through the wait. + SmallVector addlWaitOperands; + for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { + addlWaitOperands.push_back(asyncDot->getResult(0)); + } + threadValuesThroughWait(wait, addlWaitOperands); +} + +// Convert MMAv3 ttng::WarpGroupDotOps {isAsync = False} (i.e. Hopper wgmma) +// into ttng::WarpGroupDotOps {isAsync = True} and insert +// ttng::WarpGroupDotWaitOps as necessary. +// +// We assume we have space for each dot to be pipelined to depth 2, i.e. each +// dot op in the loop can have at most 2 warp_group_dot ops in flight at once. +// (Each warp_group_dot op usually corresponds to a series of wgmma.async ops.) +void triton::asyncLaunchDots(scf::ForOp forOp) { + LDBG("Original loop:\n" << *forOp); + + // First, change every MMAv3 ttng.warp_group_dot {isAsync=false} + // into ttng.warp_group_dot {isAsync=true}. + // The rest of this function is concerned with inserting + // ttng.warp_group_dot_wait ops in the appropriate places. + // + // We call those dots that don't need to be followed immediately by a `wait 0` + // "properly async", or sometimes just "async". + // + // For each dot, determine whether it can be properly async, or if it needs a + // sync immediately after. If it can be properly async, we know its only use + // is in the loop's `yield` statement; asyncDots maps the op to its index in + // the yield op. + IRRewriter builder(forOp.getContext()); + llvm::MapVector properlyAsyncDots; + for (auto WarpGroupDotOp : forOp.getBody()->getOps()) { + WarpGroupDotOp.setIsAsync(true); + if (auto iterArgIdx = dotCanBeProperlyAsync(WarpGroupDotOp, forOp)) { + properlyAsyncDots[WarpGroupDotOp] = *iterArgIdx; + } else { + builder.setInsertionPointAfter(WarpGroupDotOp); + auto wait = builder.create( + WarpGroupDotOp.getLoc(), ArrayRef{}, + /*pendings=*/0); + SmallVector waitOperands = {WarpGroupDotOp.getResult()}; + threadValuesThroughWait(wait, waitOperands); + } + } + + if (properlyAsyncDots.empty()) { + LDBG("No properly async dots."); + return; + } + + // Next, insert a wait inside the loop. We pipeline to depth 2, so the third + // iteration's set of asynchronous dots (and their corresponding async copies + // from global to shmem) can't start until the first iteration's set has + // completed. + insertAsyncWarpGroupDotWaitInLoop(forOp, properlyAsyncDots); + + // Finally, insert a wait after the loop, waiting for dots from the final + // iteration of the loop. + SmallVector waitOperands; + for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { + waitOperands.push_back(forOp.getResult(iterArgIdx)); + } + // Wait until there are 0 outstanding async dot ops. + builder.setInsertionPointAfter(forOp); + auto WarpGroupDotWaitAfterLoop = builder.create( + forOp.getLoc(), ArrayRef{}, 0); + threadValuesThroughWait(WarpGroupDotWaitAfterLoop, waitOperands); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp new file mode 100644 index 000000000..248871ebd --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -0,0 +1,461 @@ +//===----------------------------------------------------------------------===// +// +// This pass tries to prefetch operands (a and b) of tt.dot. +// Those ConvertLayoutOps will be lowered to shared memory loads. +// +// For example: +// %a: tensor<128x32xf16, #enc> +// scf.for %iv = ... iter_args(%a_arg = %a, ...) { +// %d = tt.dot %a_arg, %b, %c +// ... +// scf.yield %a_next, ... +// } +// +// will be translated to +// +// %a: tensor<128x32xf16, #enc> +// %a_tmp = tensor.subview %a[0, 0] [128, 16] +// %a_prefetch = ttg.local_load %a_tmp +// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch) +// { +// %x = tt.dot %a_prefetch_arg, %b, %c +// %a_tmp_rem = tensor.subview %a_buf[0, 16] [128, 16] +// %a_prefetch_next = ttg.local_load %a_tmp_rem +// ... +// scf.yield %next_a, ..., %a_prefetch_next +// } +//===----------------------------------------------------------------------===// + +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tritongpu-prefetch" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUPREFETCH +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +class Prefetcher { + /// cache the ForOp we are working on + scf::ForOp forOp; + /// cache the YieldOp of this ForOp + scf::YieldOp yieldOp; + /// + // TODO: add a hook to infer prefetchWidth + unsigned prefetchWidth = 32; + + /// dots to be prefetched + SetVector dots; + /// dot => dot operand + DenseMap dot2aLoopArg; + DenseMap dot2aHeaderDef; + DenseMap dot2bLoopArg; + DenseMap dot2bHeaderDef; + DenseMap dot2aYield; + DenseMap dot2bYield; + DenseMap> dot2aVals; + DenseMap> dot2bVals; + /// operand => defining + DenseMap operand2headPrefetch; + + LogicalResult isForOpOperand(Value v); + + Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue, + Attribute dotEncoding, OpBuilder &builder, + std::optional offsetK = std::nullopt, + std::optional shapeK = std::nullopt); + + void cloneElementwiseOps(Value &bRem, const SmallVector &vals, + OpBuilder &builder); + +public: + Prefetcher() = delete; + + Prefetcher(scf::ForOp forOp) : forOp(forOp) { + yieldOp = cast(forOp.getBody()->getTerminator()); + } + + LogicalResult initialize(); + + void emitPrologue(); + + scf::ForOp createNewForOp(); +}; + +void Prefetcher::cloneElementwiseOps(Value &ret, const SmallVector &vals, + OpBuilder &builder) { + IRMapping mapping; + mapping.map(vals[1], ret); + for (int i = 2; i < vals.size(); i++) { + Value v = vals[i]; + Value curr = builder.clone(*v.getDefiningOp(), mapping)->getResult(0); + if (isa(curr.getType())) { + auto retType = RankedTensorType::get( + cast(ret.getType()).getShape(), + cast(curr.getType()).getElementType(), + cast(curr.getDefiningOp()->getOperand(0).getType()) + .getEncoding()); + curr.setType(retType); + } + mapping.map(v, curr); + } + if (vals.size() > 1) + ret = mapping.lookup(vals.back()); +} + +Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, + Attribute dotEncoding, OpBuilder &builder, + std::optional offsetK, + std::optional shapeK) { + // opIdx: 0 => a, 1 => b + auto type = cast(v.getType()); + SmallVector shape{type.getShape().begin(), type.getShape().end()}; + auto rank = shape.size(); + SmallVector offset(rank, 0); + Type elementType = type.getElementType(); + + // k => (prefetchWidth, k - prefetchWidth) + int64_t kIdx = opIdx == 0 ? rank - 1 : rank - 2; + + offset[kIdx] = isPrologue ? 0 : prefetchWidth; + shape[kIdx] = isPrologue ? prefetchWidth : (shape[kIdx] - prefetchWidth); + + if (shapeK) + shape[kIdx] = *shapeK; + if (offsetK) + offset[kIdx] = *offsetK; + + SmallVector offsetsVal; + for (int64_t off : offset) + offsetsVal.push_back( + builder.create(v.getLoc(), off, 32)); + Value newSmem = builder.create( + v.getLoc(), + triton::gpu::MemDescType::get( + shape, elementType, type.getEncoding(), type.getMemorySpace(), + type.getMutableMemory(), type.getAllocShape()), + v, offsetsVal); + + auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( + builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); + Value prefetchSlice = builder.create( + v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), + newSmem); + + return prefetchSlice; +} + +LogicalResult Prefetcher::initialize() { + Block *loop = forOp.getBody(); + + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + + SmallVector dotsInFor; + for (Operation &op : *loop) + if (auto dotOp = dyn_cast(op)) { + // Only accepts dotOps encoded as Nvidia MMA v2 or AMD MFMA + auto dstMmaEnc = + dyn_cast(getEncoding(dotOp.getResult())); + auto dstMfmaEnc = + dyn_cast(getEncoding(dotOp.getResult())); + if (!dstMfmaEnc && (!dstMmaEnc || dstMmaEnc.getVersionMajor() != 2)) + // Don't rewrite if any other type is found. + return failure(); + dotsInFor.push_back(dotOp); + } + + if (dotsInFor.empty()) + return failure(); + + // TODO: segfault (original for still has uses) + // when used in flash attention that has 2 dots in the loop + if (dotsInFor.size() > 1) + return failure(); + + // returns source of cvt + auto getPrefetchSrc = [](Value v) -> SmallVector { + // walk back to conversion + Operation *op = v.getDefiningOp(); + bool foundConvertFromShared = false; + SmallVector rets; + rets.push_back(op->getResult(0)); + LDBG("Prefetch src: " << *op); + while (op) { + if (op->getNumOperands() != 1) + break; + if (!op->getResult(0).hasOneUse()) + break; + rets.push_back(op->getOperand(0)); + if (auto cvt = dyn_cast(op)) { + // NYI for other encodings, for example if we have transpose + // in the chain + if (isa(cvt.getType().getEncoding())) + foundConvertFromShared = true; + break; + } + op = op->getOperand(0).getDefiningOp(); + if (op) + LDBG("op: " << *op); + } + std::reverse(rets.begin(), rets.end()); + + if (foundConvertFromShared) + return rets; + return {}; + }; + + auto getIncomingOp = [this](Value v) -> Value { + if (auto arg = mlir::dyn_cast(v)) + if (arg.getOwner()->getParentOp() == forOp.getOperation()) + return forOp.getTiedLoopInit(arg)->get(); + return Value(); + }; + + auto getYieldOperand = [this](Value v) -> Value { + auto arg = mlir::cast(v); + unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars(); + return yieldOp.getOperand(yieldIdx); + }; + + for (triton::DotOp dot : dotsInFor) { + auto aType = dot.getA().getType(); + auto bType = dot.getB().getType(); + auto aEnc = + mlir::cast(aType.getEncoding()); + auto bEnc = + mlir::cast(bType.getEncoding()); + int aKWidth = aEnc.getKWidth(); + int bKWidth = bEnc.getKWidth(); + assert(aKWidth == bKWidth); + + auto kSize = aType.getShape().back(); + + // works better with nvidia tensor cores + unsigned elementWidth = aType.getElementTypeBitWidth(); + if (aKWidth == 0) + prefetchWidth = 256 / elementWidth; + else + prefetchWidth = 8 * aKWidth; + + // Skip prefetching if kSize is less than prefetchWidth + if (kSize < prefetchWidth) + continue; + auto aVals = getPrefetchSrc(dot.getA()); + auto bVals = getPrefetchSrc(dot.getB()); + + if (aVals.size() && bVals.size()) { + Value aSmem = aVals.front(); + Value bSmem = bVals.front(); + Value aHeaderDef = getIncomingOp(aSmem); + Value bHeaderDef = getIncomingOp(bSmem); + // Only prefetch loop arg + if (aHeaderDef && bHeaderDef) { + dots.insert(dot); + dot2aVals[dot] = aVals; + dot2bVals[dot] = bVals; + dot2aHeaderDef[dot] = aHeaderDef; + dot2bHeaderDef[dot] = bHeaderDef; + dot2aLoopArg[dot] = aSmem; + dot2bLoopArg[dot] = bSmem; + dot2aYield[dot] = getYieldOperand(aSmem); + dot2bYield[dot] = getYieldOperand(bSmem); + } + } + } + + return success(); +} + +void Prefetcher::emitPrologue() { + OpBuilder builder(forOp); + + for (triton::DotOp dot : dots) { + Attribute dotEncoding = dot.getType().getEncoding(); + Value aPrefetched = + generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder); + cloneElementwiseOps(aPrefetched, dot2aVals[dot], builder); + Value bPrefetched = + generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder); + cloneElementwiseOps(bPrefetched, dot2bVals[dot], builder); + + operand2headPrefetch[dot.getA()] = aPrefetched; + operand2headPrefetch[dot.getB()] = bPrefetched; + } +} + +scf::ForOp Prefetcher::createNewForOp() { + OpBuilder builder(forOp); + + SmallVector loopArgs; + for (auto v : forOp.getInitArgs()) + loopArgs.push_back(v); + for (triton::DotOp dot : dots) { + loopArgs.push_back(operand2headPrefetch[dot.getA()]); + loopArgs.push_back(operand2headPrefetch[dot.getB()]); + } + + auto newForOp = builder.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), loopArgs); + + builder.setInsertionPointToStart(newForOp.getBody()); + IRMapping mapping; + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + + // The insertion point should be placed before the yield op + auto setInsertionPointBeforeYield = [](OpBuilder &builder, + scf::ForOp newForOp) { + if (newForOp.getBody()->mightHaveTerminator()) { + builder.setInsertionPoint(newForOp.getBody()->getTerminator()); + } else { + builder.setInsertionPointToEnd(newForOp.getBody()); + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + // If we're currently trying to sink a prefetched dot, we need to stop + // sinking it (by resetting the insertion point to the end) if we find + // control flow, or anything that depends on the dot op. + if (op.getNumRegions() > 0) { + setInsertionPointBeforeYield(builder, newForOp); + } + for (auto operand : op.getOperands()) { + if (auto def = operand.getDefiningOp()) { + auto dot = dyn_cast(def); + if (dot && dots.contains(dot)) { + setInsertionPointBeforeYield(builder, newForOp); + } + } + } + Operation *newOp = builder.clone(op, mapping); + auto dot = dyn_cast(&op); + if (dot && dots.contains(dot)) { + Attribute dotEncoding = dot.getType().getEncoding(); + // prefetched dot + Operation *firstDot = builder.clone(*dot, mapping); + if (Value a = operand2headPrefetch.lookup(dot.getA())) + firstDot->setOperand( + 0, newForOp.getTiedLoopRegionIterArg(&*a.use_begin())); + if (Value b = operand2headPrefetch.lookup(dot.getB())) + firstDot->setOperand( + 1, newForOp.getTiedLoopRegionIterArg(&*b.use_begin())); + + // remaining part + int64_t kOff = prefetchWidth; + int64_t kRem = dot.getA().getType().getShape().back() - prefetchWidth; + Operation *prevDot = firstDot; + if (kRem == 0) { + // There is only one dot while prefetchWidth == kSize so delay issuing + // it. Meanwhile, newOp should be set to firstDot to make sure the dot + // result is updated to yield. + builder.setInsertionPoint(prevDot); + newOp = firstDot; + } + + while (kRem != 0) { + // int64_t kShape = largestPow2(kRem); + int64_t kShape = prefetchWidth; + auto insertionPoint = builder.saveInsertionPoint(); + builder.setInsertionPoint(prevDot); + Value aRem = + generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false, + dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(aRem, dot2aVals[dot], builder); + Value bRem = + generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false, + dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(bRem, dot2bVals[dot], builder); + builder.restoreInsertionPoint(insertionPoint); + newOp = builder.clone(*dot, mapping); + newOp->setOperand(0, aRem); + newOp->setOperand(1, bRem); + newOp->setOperand(2, prevDot->getResult(0)); + prevDot = newOp; + kOff += kShape; + kRem -= kShape; + if (kRem == 0) { + // We want to delay issuing the last dot as long as possible, ideally + // until after the prefetch. To accomplish this, set the insertion + // point above the dot. If we find anything dependent on the dot (at + // the top of this loop), we resume inserting after it. + builder.setInsertionPoint(prevDot); + } + } + } + // update mapping of results + for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) + mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); + } + + // prefetch next iteration + SmallVector yieldValues; + for (Value v : forOp.getBody()->getTerminator()->getOperands()) + yieldValues.push_back(mapping.lookupOrDefault(v)); + for (triton::DotOp dot : dots) { + Attribute dotEncoding = dot.getType().getEncoding(); + Value aToYield = generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, true, + dotEncoding, builder); + cloneElementwiseOps(aToYield, dot2aVals[dot], builder); + yieldValues.push_back(aToYield); + // bToYield + Value bToYield = generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, true, + dotEncoding, builder); + cloneElementwiseOps(bToYield, dot2bVals[dot], builder); + yieldValues.push_back(bToYield); + } + // Update ops of yield + builder.setInsertionPointToEnd(newForOp.getBody()); + if (!yieldValues.empty()) + builder.create(yieldOp.getLoc(), yieldValues); + return newForOp; +} + +} // anonymous namespace + +struct PrefetchPass : public impl::TritonGPUPrefetchBase { + void runOnOperation() override { + + // Canonicalize convert ops to make the pattern matching easier. + RewritePatternSet cleanUpPatterns(&getContext()); + triton::gpu::ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, + &getContext()); + if (mlir::applyPatternsGreedily(getOperation(), std::move(cleanUpPatterns)) + .failed()) { + signalPassFailure(); + } + getOperation()->walk([&](scf::ForOp forOp) { + Prefetcher prefetcher(forOp); + + if (prefetcher.initialize().failed()) + return; + + prefetcher.emitPrologue(); + + scf::ForOp newForOp = prefetcher.createNewForOp(); + + // replace the original loop + for (unsigned i = 0; i < forOp->getNumResults(); ++i) + forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); + forOp->erase(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp new file mode 100644 index 000000000..8b252d068 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -0,0 +1,79 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREDUCEDATADUPLICATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUReduceDataDuplicationPass + : public impl::TritonGPUReduceDataDuplicationBase< + TritonGPUReduceDataDuplicationPass> { +public: + void runOnOperation() override { + ModuleOp mod = getOperation(); + mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcEncoding = srcType.getEncoding(); + if (isa(srcEncoding)) + return; + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (!dstDotOp) + return; + if (!cvtNeedsSharedMemory(srcType, dstType)) + return; + auto srcOrder = triton::gpu::getOrder(srcType); + auto rank = srcOrder.size(); + SmallVector sharedOrder; + if (rank == 3) { + // add all elements except the element that is zero + for (unsigned i = 0; i < rank; ++i) + if (srcOrder[i] != 0) + sharedOrder.emplace_back(srcOrder[i]); + sharedOrder.emplace_back(0); + } else { + sharedOrder = srcOrder; + } + auto sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); + auto tmpType = triton::gpu::MemDescType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::SwizzledSharedEncodingAttr::get( + mod.getContext(), dstDotOp, srcType.getShape(), sharedOrder, + triton::gpu::getCTALayout(srcEncoding), srcType.getElementType()), + sharedMemorySpace); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getSrc()); + auto newConvert = builder.create(cvtOp.getLoc(), + dstType, tmp); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp new file mode 100644 index 000000000..1c779eb40 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -0,0 +1,1534 @@ +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include +#include + +namespace mlir::triton::gpu { + +#define GEN_PASS_DEF_TRITONGPUREMOVELAYOUTCONVERSIONS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "tritongpu-remove-layout-conversions" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +// ----------------------------------------------------------------------------- +// +// ----------------------------------------------------------------------------- + +// The current algorithm works by analyzing the IR and doing a one-shot rewrite +// based on the analysis. The algorithm is as follows. +// +// 1. Find all the anchor ops. These are ops that have a layout we want to +// preserve. +// +// 2. For each anchor, propagate its layout to all its descendants. +// An op can have multiple ancestors that are anchors, so at this stage an op +// may have multiple layouts associated with it. +// +// 3. Resolve conflicts by deciding which of the multiple layouts the op should +// keep, inserting convert-layout ops to resolve conflicts. After this +// stage, each value has only one layout associated with it. +// +// 4. Rewrite the IR by walking the function in dominance order. Since we +// assume the IR is structured we just need to process the regions in the +// correct order. For each op, rewrite it using the layout decided by the +// analysis phase. +class LayoutPropagation { +public: + // Structure to keep track of the layout associated to a value. + struct LayoutInfo { + LayoutInfo(Attribute encoding) { encodings.insert(encoding); } + LayoutInfo() {} + llvm::SmallSetVector encodings; + }; + LayoutPropagation(FuncOp F) : funcOp(F) {} + // Find the anchor ops and set their layout in the data structure. + void initAnchorLayout(); + // Recursively Propagate the layout to all the users of the anchor ops until + // we reach a fix point. + void propagateLayout(); + // Add layouts given in `Info` to the uses of `value`. + SmallVector propagateToUsers(Value value, LayoutInfo &info); + // Set the encoding to all the values and fill out the values with new layout + // in `changed`. + void setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, Operation *op); + // Resolve cases where a value has multiple layouts associated to it. + void resolveConflicts(); + // Rewrite the IR for the full module. + void rewrite(); + // Rewrite the IR for a region. + void rewriteRegion(Region &R); + // Rewrite an op based on the layout picked by the analysis. + Operation *rewriteOp(Operation *op); + // Rewrite a for op based on the layout picked by the analysis. + Operation *rewriteForOp(scf::ForOp forOp); + Operation *rewriteWhileOp(scf::WhileOp whileOp); + Operation *rewriteIfOp(scf::IfOp ifOp); + void rewriteYieldOp(scf::YieldOp yieldOp); + void rewriteConditionOp(scf::ConditionOp conditionOp); + void rewriteReduceToScalar(Operation *reduceOp); + void rewriteAssertOp(AssertOp assertOp); + Operation *cloneElementwise(OpBuilder &rewriter, Operation *op, + Attribute encoding); + // Map the original value to the rewritten one. + void map(Value old, Value newV); + // Return the mapped value in the given encoding. This will insert a convert + // if the encoding is different than the encoding decided at resolve time. + Value getValueAs(Value value, Attribute encoding); + // Dump the current stage of layout information. + void dump(); + +private: + // map from value to layout information. + llvm::MapVector layouts; + // map of the values rewrite based on their encoding. + DenseMap, Value> rewriteMapping; + SetVector opToDelete; + FuncOp funcOp; +}; + +class LayoutRematerialization { +public: + LayoutRematerialization(FuncOp F) : funcOp(F) {} + + // Map the original value to the remat'ed one. + void addRematValue(Value old, Attribute encoding, Value newV); + // Get the remat'ed value in the given encoding, if one already exists and + // is different then the layout conversion root. + Value getRematValue(Value value, Attribute encoding) const { + return rematMapping.lookup({value, encoding}); + } + + void cleanup(); + void backwardRematerialization(); + void backwardRematerialization(ConvertLayoutOp convertOp); + // TODO: Merge the three hoistConvert*(); functions as they are duplicate code + void hoistConvertDotOperand(); + void hoistConvertDotOperand(ConvertLayoutOp convertOp); + void hoistConvertOnTopOfExtOrBroadcast(); + void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp); + void hoistConvertIntoConditionals(); + void hoistConvertIntoConditionals(ConvertLayoutOp convertOp); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp, IRMapping &mapping); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp); + + LogicalResult + getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding, + SetVector &slice, + DenseMap &layout, + std::function stopPropagation); + + LogicalResult getRematerializableSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation = nullptr); + +private: + void updateRematMapping(SmallVector> &values); + // Existing tuples of (value, layout) that needs to be updated when recreating + // scf ops. This prevents keeping track of Values that have been delete when + // rewriting slices. + DenseMap mappedValues; + // map of the values remat based on encoding. + DenseMap, Value> rematMapping; + // DenseMap, Operation*> + SetVector opToDelete; + FuncOp funcOp; + DominanceInfo domInfo; + PostDominanceInfo postDomInfo; +}; + +void LayoutRematerialization::addRematValue(Value old, Attribute encoding, + Value newV) { + LDBG("addRematValue " << old << " encoding " << encoding << " " << newV); + rematMapping[{old, encoding}] = newV; + mappedValues[old] = encoding; +} + +// Remove unneeded values now that we are done with the rematMapping. +void LayoutRematerialization::cleanup() { + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +// Return true if the op is an op with a layout we don't want to change. We will +// propagate the layout starting from anchor ops. +bool isLayoutAnchor(Operation *op) { + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return true; + if (auto gatherOp = dyn_cast(op)) + return gatherOp.getEfficientLayout(); + + // Heuristic: Mark permuting reshape as a layout anchor. Its dst can be + // anything, so it stops forward-propagation of layouts. We rely on the + // backwards pass to fix it up if necessary. (If we didn't do this, then + // anything following the reshape won't be covered by the forward pass at + // all.) + if (auto reshape = dyn_cast(op)) + return reshape.getAllowReorder(); + + return false; +} + +void LayoutPropagation::initAnchorLayout() { + auto addAnchor = [&](Value v) { + if (auto tensorType = dyn_cast(v.getType())) { + layouts.insert({v, LayoutInfo(tensorType.getEncoding())}); + } + }; + + // Consider function args as anchors. This makes it easier to write tests -- + // you can pass a tensor with an encoding as an arg, instead of explicitly + // calling tt.load. + for (auto arg : funcOp.getArguments()) { + addAnchor(arg); + } + + funcOp.walk([&](Operation *op) { + if (isLayoutAnchor(op)) { + for (auto result : op->getResults()) { + addAnchor(result); + } + } + }); +} + +void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, + Operation *op) { + for (Value value : values) { + if (!isa(value.getType())) + continue; + bool hasChanged = false; + for (auto encoding : info.encodings) { + Attribute dstEncoding; + if (isa(op)) { + // Try to remove the convert by making the dst encoding match the source + // encoding. + dstEncoding = encoding; + } else { + dstEncoding = inferDstEncoding(op, encoding); + } + if (dstEncoding) + hasChanged |= layouts[value].encodings.insert(dstEncoding); + } + if (hasChanged) + changed.push_back(value); + } +} + +SmallVector LayoutPropagation::propagateToUsers(Value value, + LayoutInfo &info) { + SmallVector changed; + for (OpOperand &use : value.getUses()) { + Operation *user = use.getOwner(); + if (auto forOp = dyn_cast(user)) { + Value arg = forOp.getTiedLoopRegionIterArg(&use); + Value result = forOp.getTiedLoopResult(&use); + setEncoding({arg, result}, info, changed, user); + continue; + } + if (auto whileOp = dyn_cast(user)) { + Value arg = whileOp.getBeforeArguments()[use.getOperandNumber()]; + setEncoding({arg}, info, changed, user); + continue; + } + if (auto yieldOp = dyn_cast(user)) { + auto parent = yieldOp->getParentOp(); + SmallVector valuesToPropagate; + if (isa(parent)) + valuesToPropagate.push_back(parent->getResult(use.getOperandNumber())); + if (auto forOp = dyn_cast(parent)) + valuesToPropagate.push_back( + forOp.getRegionIterArg(use.getOperandNumber())); + if (auto whileOp = dyn_cast(parent)) + valuesToPropagate.push_back( + whileOp.getBeforeArguments()[use.getOperandNumber()]); + if (isa(parent)) + setEncoding(valuesToPropagate, info, changed, user); + continue; + } + if (auto conditionOp = dyn_cast(user)) { + auto whileOp = cast(conditionOp->getParentOp()); + // Skip arg 0 as it is the condition. + unsigned argIndex = use.getOperandNumber() - 1; + Value afterArg = whileOp.getAfterArguments()[argIndex]; + Value result = whileOp->getResult(argIndex); + setEncoding({afterArg, result}, info, changed, user); + continue; + } + if (auto dotWaitOp = dyn_cast(user)) { + unsigned opIndex = use.getOperandNumber(); + Value result = dotWaitOp->getResult(opIndex); + setEncoding(result, info, changed, user); + continue; + } + if (auto gatherOp = dyn_cast(user)) { + // Propagate the layout through the indices only, and if the layout does + // not have an efficient layout set. + if (!gatherOp.getEfficientLayout() && + &use == &gatherOp.getIndicesMutable()) { + setEncoding(gatherOp.getResult(), info, changed, user); + continue; + } + } + if (user->hasTrait() || + user->hasTrait() || + isa(user)) { + setEncoding(user->getResults(), info, changed, user); + continue; + } + } + return changed; +} + +void LayoutPropagation::propagateLayout() { + SmallVector queue; + for (auto it : layouts) { + queue.push_back(it.first); + } + while (!queue.empty()) { + Value currentValue = queue.back(); + LayoutInfo info = layouts[currentValue]; + queue.pop_back(); + SmallVector changed = propagateToUsers(currentValue, info); + + LLVM_DEBUG({ + DBGS() << "propagateLayout considering " << currentValue << ", which has " + << info.encodings.size() << " candidate encoding(s):\n"; + for (Attribute encoding : info.encodings) + DBGS() << " " << encoding << "\n"; + }); + + queue.insert(queue.end(), changed.begin(), changed.end()); + } +} + +void LayoutPropagation::resolveConflicts() { + for (auto &it : layouts) { + Operation *op = it.first.getDefiningOp(); + LayoutInfo &info = it.second; + if (info.encodings.size() <= 1) + continue; + // Hacky resolve, prefer block encoding. + // TODO: add a proper heuristic. + Attribute encoding = *info.encodings.begin(); + bool isLoadOrStore = + op && isa(op); + for (Attribute e : info.encodings) { + if ((isLoadOrStore && isa(e)) || + (!isLoadOrStore && isa(e))) { + encoding = e; + break; + } + } + info.encodings.clear(); + info.encodings.insert(encoding); + } +} + +void LayoutPropagation::dump() { + for (auto it : layouts) { + llvm::errs() << "Value: "; + OpPrintingFlags flags; + flags.skipRegions(); + it.first.print(llvm::errs(), flags); + llvm::errs() << " \n encoding:\n"; + for (auto encoding : it.second.encodings) { + encoding.print(llvm::errs()); + llvm::errs() << "\n"; + } + llvm::errs() << "--\n"; + } +} + +void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); } + +bool reduceToScalar(Operation *op) { + // For reductions returning a scalar we can change the src encoding without + // affecting the output. + return isa(op) && !isa(op->getResultTypes()[0]); +} + +void LayoutPropagation::rewriteRegion(Region ®ion) { + std::deque queue = {®ion}; + while (!queue.empty()) { + Region *currentRegion = queue.front(); + queue.pop_front(); + for (Operation &op : currentRegion->getOps()) { + bool needRewrite = false; + SmallVector results = op.getResults(); + for (Value result : results) { + auto it = layouts.find(result); + // If we haven't mapped this value skip. + if (it == layouts.end()) + continue; + LayoutInfo &info = it->second; + assert(info.encodings.size() == 1 && + "we should have resolved to a single encoding"); + auto encoding = cast(result.getType()).getEncoding(); + // If the encoding is already what we want skip. + if (encoding == *info.encodings.begin()) + continue; + needRewrite = true; + } + if (needRewrite) { + Operation *newOp = rewriteOp(&op); + for (Region &R : newOp->getRegions()) + queue.push_back(&R); + } else if (auto yieldOp = dyn_cast(&op)) { + rewriteYieldOp(yieldOp); + } else if (auto conditionOp = dyn_cast(&op)) { + rewriteConditionOp(conditionOp); + } else if (reduceToScalar(&op)) { + rewriteReduceToScalar(&op); + } else if (auto assertOp = dyn_cast(&op)) { + rewriteAssertOp(assertOp); + } else { + // If we don't need to rewrite the op we still need to remap the + // operands. + for (OpOperand &operand : op.getOpOperands()) { + auto it = layouts.find(operand.get()); + if (it == layouts.end()) + continue; + Attribute encoding = + cast(operand.get().getType()).getEncoding(); + Value newOperand = getValueAs(operand.get(), encoding); + op.setOperand(operand.getOperandNumber(), newOperand); + } + for (Region &R : op.getRegions()) + queue.push_back(&R); + } + } + } + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +void LayoutPropagation::map(Value old, Value newV) { + rewriteMapping[{old, cast(newV.getType()).getEncoding()}] = + newV; +} + +Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { + if (auto tensorType = dyn_cast(value.getType())) { + Value rewrittenValue; + auto layoutIt = layouts.find(value); + if (layoutIt == layouts.end()) { + rewrittenValue = value; + } else { + assert(layoutIt->second.encodings.size() == 1 && + "we should have resolved to a single encoding"); + Attribute encodingPicked = *(layoutIt->second.encodings.begin()); + if (encodingPicked == tensorType.getEncoding()) + rewrittenValue = value; + else + rewrittenValue = rewriteMapping[{value, encodingPicked}]; + } + assert(rewrittenValue); + if (cast(rewrittenValue.getType()).getEncoding() == + encoding) + return rewrittenValue; + OpBuilder rewriter(value.getContext()); + rewriter.setInsertionPointAfterValue(rewrittenValue); + auto tmpType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + Value converted = rewriter.create(value.getLoc(), tmpType, + rewrittenValue); + // TODO: we could cache the conversion. + return converted; + } + return value; +} + +Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter, + Operation *op, + Attribute encoding) { + Operation *newOp = rewriter.clone(*op); + + Attribute operandEnc; + if (op->getNumOperands() > 0) { + operandEnc = inferSrcEncoding(op, encoding); + assert(operandEnc); + } + + for (OpOperand &operand : op->getOpOperands()) { + newOp->setOperand(operand.getOperandNumber(), + getValueAs(operand.get(), operandEnc)); + } + + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) { + auto origType = dyn_cast(op->getResult(i).getType()); + if (!origType) + continue; + auto newType = RankedTensorType::get(origType.getShape(), + origType.getElementType(), encoding); + newOp->getResult(i).setType(newType); + } + return newOp; +} + +Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) { + SmallVector operands; + OpBuilder rewriter(forOp); + for (auto [operand, result] : + llvm::zip(forOp.getInitArgs(), forOp.getResults())) { + Value convertedOperand = operand; + if (layouts.count(result)) + convertedOperand = + getValueAs(operand, *layouts[result].encodings.begin()); + operands.push_back(convertedOperand); + } + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), operands); + newForOp->setAttrs(forOp->getAttrs()); + newForOp.getBody()->getOperations().splice( + newForOp.getBody()->getOperations().begin(), + forOp.getBody()->getOperations()); + + for (auto [oldResult, newResult] : + llvm::zip(forOp.getResults(), newForOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + + for (auto [oldArg, newArg] : llvm::zip(forOp.getBody()->getArguments(), + newForOp.getBody()->getArguments())) { + if (oldArg.getType() == newArg.getType()) { + oldArg.replaceAllUsesWith(newArg); + continue; + } + map(oldArg, newArg); + } + return newForOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) { + SmallVector operands; + SmallVector returnTypes; + OpBuilder rewriter(whileOp); + for (auto [operand, arg] : + llvm::zip(whileOp->getOperands(), whileOp.getBeforeArguments())) { + Value convertedOperand = operand; + if (layouts.count(arg)) + convertedOperand = getValueAs(operand, *layouts[arg].encodings.begin()); + operands.push_back(convertedOperand); + } + for (Value ret : whileOp.getResults()) { + auto it = layouts.find(ret); + if (it == layouts.end()) { + returnTypes.push_back(ret.getType()); + continue; + } + auto origType = dyn_cast(ret.getType()); + auto newType = + RankedTensorType::get(origType.getShape(), origType.getElementType(), + it->second.encodings[0]); + returnTypes.push_back(newType); + } + + auto newWhileOp = + rewriter.create(whileOp.getLoc(), returnTypes, operands); + SmallVector argsTypesBefore; + for (Value operand : operands) + argsTypesBefore.push_back(operand.getType()); + SmallVector bbArgLocsBefore(argsTypesBefore.size(), + whileOp.getLoc()); + SmallVector bbArgLocsAfter(returnTypes.size(), whileOp.getLoc()); + rewriter.createBlock(&newWhileOp.getBefore(), {}, argsTypesBefore, + bbArgLocsBefore); + rewriter.createBlock(&newWhileOp.getAfter(), {}, returnTypes, bbArgLocsAfter); + + for (int i = 0; i < whileOp.getNumRegions(); ++i) { + newWhileOp->getRegion(i).front().getOperations().splice( + newWhileOp->getRegion(i).front().getOperations().begin(), + whileOp->getRegion(i).front().getOperations()); + } + + auto remapArg = [&](Value oldVal, Value newVal) { + if (oldVal.getType() == newVal.getType()) + oldVal.replaceAllUsesWith(newVal); + else + map(oldVal, newVal); + }; + for (auto [oldResult, newResult] : + llvm::zip(whileOp.getResults(), newWhileOp.getResults())) + remapArg(oldResult, newResult); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getBeforeArguments(), newWhileOp.getBeforeArguments())) + remapArg(oldArg, newArg); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getAfterArguments(), newWhileOp.getAfterArguments())) + remapArg(oldArg, newArg); + return newWhileOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) { + SmallVector operands; + OpBuilder rewriter(ifOp); + SmallVector newResultTypes(ifOp->getResultTypes()); + for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i) { + auto it = layouts.find(ifOp->getResult(i)); + if (it == layouts.end()) + continue; + auto origType = cast(ifOp->getResult(i).getType()); + Attribute encoding = *(it->second.encodings.begin()); + newResultTypes[i] = RankedTensorType::get( + origType.getShape(), origType.getElementType(), encoding); + } + auto newIfOp = rewriter.create(ifOp.getLoc(), newResultTypes, + ifOp.getCondition(), true, true); + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + for (auto [oldResult, newResult] : + llvm::zip(ifOp.getResults(), newIfOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newIfOp.getOperation(); +} + +void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) { + Operation *parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + Type yieldType = operand.get().getType(); + if (isa(parentOp)) + yieldType = parentOp->getResult(operand.getOperandNumber()).getType(); + if (auto whileOp = dyn_cast(parentOp)) + yieldType = + whileOp.getBeforeArguments()[operand.getOperandNumber()].getType(); + auto tensorType = dyn_cast(yieldType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + yieldOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) { + scf::WhileOp whileOp = cast(conditionOp->getParentOp()); + for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) { + OpOperand &operand = conditionOp->getOpOperand(i); + Type argType = whileOp->getResult(operand.getOperandNumber() - 1).getType(); + auto tensorType = dyn_cast(argType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + conditionOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteReduceToScalar(Operation *reduceOp) { + OpBuilder rewriter(reduceOp); + Attribute srcEncoding; + // Since all the operands need to have the same encoding pick the first one + // and use it for all the operands. + for (Value operand : reduceOp->getOperands()) { + auto it = layouts.find(operand); + if (it != layouts.end()) { + srcEncoding = it->second.encodings[0]; + break; + } + } + if (!srcEncoding) + return; + for (OpOperand &operand : reduceOp->getOpOperands()) { + Value newOperand = getValueAs(operand.get(), srcEncoding); + reduceOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) { + Attribute srcEncoding; + // Only need to deal with the first operand which is the condition tensor. + Value operand = assertOp->getOperand(0); + auto it = layouts.find(operand); + if (it == layouts.end()) + return; + srcEncoding = it->second.encodings[0]; + Value newOperand = getValueAs(operand, srcEncoding); + assertOp->setOperand(0, newOperand); +} + +Operation *LayoutPropagation::rewriteOp(Operation *op) { + opToDelete.insert(op); + if (auto forOp = dyn_cast(op)) + return rewriteForOp(forOp); + if (auto whileOp = dyn_cast(op)) + return rewriteWhileOp(whileOp); + if (auto ifOp = dyn_cast(op)) + return rewriteIfOp(ifOp); + OpBuilder rewriter(op); + Attribute encoding = *layouts[op->getResult(0)].encodings.begin(); + if (auto convertOp = dyn_cast(op)) { + Attribute srcEncoding = convertOp.getSrc().getType().getEncoding(); + auto it = layouts.find(convertOp.getSrc()); + if (it != layouts.end()) + srcEncoding = *(it->second.encodings.begin()); + Value src = getValueAs(convertOp.getSrc(), srcEncoding); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), newType, src); + cvt->setAttrs(op->getAttrs()); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (canFoldIntoConversion(op, encoding)) { + Operation *newOp = rewriter.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), newType, + newOp->getResult(0)); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (op->hasTrait() || + op->hasTrait() || + isa(op)) { + Operation *newOp = cloneElementwise(rewriter, op, encoding); + for (auto [oldResult, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newOp; + } + llvm::report_fatal_error("unexpected op in rewrite"); + return nullptr; +} + +bool canBeRemat(Operation *op) { + if (isa(op)) + return !isExpensiveLoadOrStore(op); + if (isa(op)) + return false; + if (auto gather = dyn_cast(op)) + return !gather.getEfficientLayout(); + + if (isa(op)) + return false; + + return true; +} + +void LayoutRematerialization::updateRematMapping( + SmallVector> &values) { + for (auto [old, newV] : values) { + auto it = mappedValues.find(old); + if (it != mappedValues.end()) { + Attribute encoding = it->second; + auto rematIt = rematMapping.find({old, it->second}); + assert(rematIt != rematMapping.end()); + Value replacedValue = rematIt->second; + rematMapping.erase(rematIt); + mappedValues.erase(it); + // Loop through the replacement value to find the new version of remat + // value. This should be okay as the number of values should be small. + for (auto [before, after] : values) { + if (before == replacedValue) { + replacedValue = after; + break; + } + } + rematMapping[{newV, encoding}] = replacedValue; + mappedValues[newV] = encoding; + } + } +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp, + IRMapping &mapping) { + SetVector opsToRewrite; + // Keep track of yield operands that need to be duplicated. + DenseMap> yieldOperandsMap; + // Keep these around to remove them from the slice after our collection pass + // This ensures we don't duplicate them during an for rewrite or causing the + // for/yield to fall out of sync + SetVector valuesWithExistingRemat; + for (Value v : slice) { + auto layoutIt = layout.find(v); + assert(layoutIt != layout.end()); + // If we already have a remat value for this value, use it. + if (Value remat = getRematValue(v, layoutIt->second)) { + mapping.map(v, remat); + valuesWithExistingRemat.insert(v); + continue; + } + if (v.getDefiningOp()) { + opsToRewrite.insert(v.getDefiningOp()); + if (auto ifOp = v.getDefiningOp()) { + unsigned operandIdx = cast(v).getResultNumber(); + opsToRewrite.insert(ifOp.thenYield().getOperation()); + yieldOperandsMap[ifOp.thenYield()].push_back(operandIdx); + opsToRewrite.insert(ifOp.elseYield().getOperation()); + yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx); + } + } else { + BlockArgument blockArg = cast(v); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + if (auto loopOp = cast(parentOp)) { + opsToRewrite.insert(loopOp.getOperation()); + OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg); + auto yieldOp = blockArg.getOwner()->getTerminator(); + yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber()); + opsToRewrite.insert(yieldOp); + } + } + } + slice.set_subtract(valuesWithExistingRemat); + opsToRewrite = multiRootTopologicalSort(opsToRewrite); + + // replaceAllUsesWith calls delayed until after initial rewrite. + // This is required for slice.count(value) to work mid rewrite. + SmallVector> replacements; + + SmallVector deadOps; + IRRewriter builder(slice.begin()->getContext()); + for (Operation *op : opsToRewrite) { + if (auto forOp = dyn_cast(op)) { + // Keep a mapping of the operands index to the new operands index. + SmallVector> argMapping; + SmallVector newOperands; + for (auto arg : forOp.getRegionIterArgs()) { + if (slice.count(arg)) { + OpOperand &initVal = *forOp.getTiedLoopInit(arg); + argMapping.push_back(std::make_pair( + forOp.getTiedLoopResult(&initVal).getResultNumber(), + forOp.getInitArgs().size() + newOperands.size())); + newOperands.push_back(mapping.lookup(initVal.get())); + } + } + // Create a new for loop with the new operands. + scf::ForOp newForOp = replaceForOpWithNewSignature( + builder, forOp, newOperands, replacements); + deadOps.push_back(forOp.getOperation()); + Block &loopBody = *newForOp.getBody(); + for (auto m : argMapping) { + mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second)); + int numIndVars = newForOp.getNumInductionVars(); + mapping.map(loopBody.getArgument(m.first + numIndVars), + loopBody.getArgument(m.second + numIndVars)); + LLVM_DEBUG({ + DBGS() << "mapping forOp " + << loopBody.getArgument(m.first + numIndVars) << " to " + << loopBody.getArgument(m.second + numIndVars) << '\n'; + }); + // The result is not in the layout/slice, the argument is. + Value oldArg = loopBody.getArgument(m.first + numIndVars); + addRematValue(newForOp.getResult(m.first), layout[oldArg], + newForOp.getResult(m.second)); + addRematValue(oldArg, layout[oldArg], + loopBody.getArgument(m.second + numIndVars)); + } + continue; + } + if (auto ifOp = dyn_cast(op)) { + SmallVector newTypes; + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + auto it = layout.find(res); + assert(it != layout.end()); + + auto oldType = cast(res.getType()); + auto newType = RankedTensorType::get( + oldType.getShape(), oldType.getElementType(), it->second); + newTypes.push_back(newType); + } + } + scf::IfOp newIfOp = + replaceIfOpWithNewSignature(builder, ifOp, newTypes, replacements); + unsigned oldIdx = 0; + unsigned newIdx = ifOp.getNumResults(); + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + // Why can't we use res instead of ifOp.getResult(oldIdx)? + mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx)); + addRematValue(ifOp.getResult(oldIdx), layout[res], + newIfOp.getResult(newIdx)); + ++newIdx; + } + ++oldIdx; + } + deadOps.push_back(ifOp.getOperation()); + continue; + } + builder.setInsertionPoint(op); + if (auto yieldOp = dyn_cast(op)) { + auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); + SmallVector operandsToRewrite = yieldOperandsMap[op]; + // Sort so that operands are added in the same order as the new scf + // results/arguments. + std::sort(operandsToRewrite.begin(), operandsToRewrite.end()); + for (int operandIdx : operandsToRewrite) { + yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx))); + } + builder.create(op->getLoc(), yieldOperands); + op->erase(); + continue; + } + if (isa(op)) { + Operation *newOp = builder.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), + layout[op->getResult(0)]); + auto cvt = builder.create(op->getLoc(), newType, + newOp->getResult(0)); + mapping.map(op->getResult(0), cvt.getResult()); + addRematValue(op->getResult(0), layout[op->getResult(0)], + cvt.getResult()); + continue; + } + Operation *newOp = builder.clone(*op, mapping); + for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) { + auto it = layout.find(old); + if (it == layout.end()) + continue; + auto newType = RankedTensorType::get( + cast(old.getType()).getShape(), + cast(old.getType()).getElementType(), it->second); + newV.setType(newType); + addRematValue(old, it->second, newV); + } + } + // Check mapping and see if there are existing convertOps on the old Argument + convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc())); + opToDelete.insert(convertOp); + + updateRematMapping(replacements); + for (auto &kv : replacements) { + builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + + for (Operation *op : deadOps) + opToDelete.insert(op); +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp) { + IRMapping mapping; + rewriteSlice(slice, layout, convertOp, mapping); +} + +LogicalResult LayoutRematerialization::getConvertBackwardSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation) { + // Allow re-using existing conversions for a value. Check dominance of any + // reusable materializations against the root value. This is sufficient + // because the conversions are processed in post-order. + auto getExistingConversion = [&](OpOperand &value, Attribute encoding) { + Value remat = getRematValue(value.get(), encoding); + if (!remat) + return Value(); + // `value` can be replaced with an existing rematerialization if it + // dominates the current use of value. + Operation *user = value.getOwner(); + if (domInfo.properlyDominates(remat, user)) { + return remat; + } + // FIXME: If the current user is a conversion, then we know it will become + // a no-op when its operand is replaced with `remat`, but we need to check + // that its users are all dominated by `remat` so the IR is valid. + // if (isa(user) && remat.getDefiningOp() && + // domInfo.properlyDominates(user, remat.getDefiningOp())) { + // for (Operation *op : user->getUsers()) { + // if (!domInfo.dominates(remat, op)) + // return Value(); + // } + // return remat; + // } + return Value(); + }; + + return mlir::getConvertBackwardSlice(root, slice, rootEncoding, layout, + stopPropagation, getExistingConversion); +} + +LogicalResult LayoutRematerialization::getRematerializableSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation) { + LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice, + layout, stopPropagation); + if (result.failed() || slice.empty()) + return failure(); + + // Check if all the operations in the slice can be rematerialized. + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + if (!canBeRemat(op)) + return failure(); + } + } + return success(); +} + +void LayoutRematerialization::backwardRematerialization() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + backwardRematerialization(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertOnTopOfExtOrBroadcast(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +void LayoutRematerialization::hoistConvertIntoConditionals() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertIntoConditionals(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +void LayoutRematerialization::backwardRematerialization( + ConvertLayoutOp convertOp) { + // DotOperand is hoisted by hoistDotOperand + RankedTensorType targetType = convertOp.getType(); + if (isa(targetType.getEncoding())) + return; + Value oldV = convertOp.getSrc(); + LDBG("check backward remat with source " << oldV << " encoding " + << targetType.getEncoding()); + // Check to see if there are existing remat'ed values for the pair of oldValue + // and encoding. Make sure it dominates the current conversion. + Value newV = getRematValue(oldV, targetType.getEncoding()); + if (newV && domInfo.properlyDominates(newV, convertOp)) { + // Replace it with the remat'ed value. + convertOp.replaceAllUsesWith(newV); + opToDelete.insert(convertOp); + LDBG("found remat'ed value" << newV); + return; + } + + // 1. Take a backward slice of all the tensor dependencies that can be + // rematerialized. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout); + if (result.failed()) { + LDBG(" getRematerializableSlice failed"); + return; + } + + LLVM_DEBUG({ + DBGS() << " remat convert op " << convertOp << '\n'; + for (Value v : slice) + DBGS() << " " << v << '\n'; + }); + // 2. Rewrite the slice. + rewriteSlice(slice, layout, convertOp); +} + +void LayoutRematerialization::hoistConvertDotOperand() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertDotOperand(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +void LayoutRematerialization::hoistConvertDotOperand( + ConvertLayoutOp convertOp) { + auto targetType = convertOp.getType(); + // The pass is targeted to Nvidia mma/wgmma dot operands + + auto canBePipelined = [&](ConvertLayoutOp convertOp) { + // FIXME: Check that the parent is a for loop + auto parent = convertOp->getParentOp(); + if (!parent) + return false; + + // Find all the dot-like ops in the for loop that have a nvidia dot operand + // encoding on the lhs and check if any of them post-dominates the load + + // cvt + SmallVector dotLikeOps; + parent->walk([&](Operation *op) { + if (!isa(op)) + return; + auto opType = dyn_cast(op->getOperand(0).getType()); + if (!opType) + return; + auto dotEnc = dyn_cast(opType.getEncoding()); + if (!dotEnc) + return; + if (isa(dotEnc.getParent())) + dotLikeOps.push_back(op); + }); + if (dotLikeOps.empty()) + return false; + return llvm::any_of(dotLikeOps, [&](Operation *dot) { + return postDomInfo.postDominates(dot, convertOp); + }); + }; + + // We move convert #dot_operand next to their loads. This is done + // so that it's then easy to pipeline these loads + if (!canBePipelined(convertOp)) + return; + + // We hoist over any operation that can be done without data movement between + // threads We do views and elementwise pure ops for now + auto noDataMovement = [](Operation *op) { + return (op->hasTrait() && isMemoryEffectFree(op)) || + isa(op) || isView(op); + }; + // Stop the slice as soon as we find an operation that cannot be done without + // data movement between threads + auto stop = std::not_fn(noDataMovement); + + SetVector slice; + DenseMap layout; + // Set-up the conversion "cache" + LogicalResult result = getConvertBackwardSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, stop); + if (result.failed()) + return; + + IRMapping mapping; + OpBuilder builder(convertOp.getContext()); + SetVector innerSlice; + for (Value v : slice) { + if (!v.getDefiningOp()) { + LLVM_DEBUG( + { DBGS() << " Block arguments not supported. Got " << v << "\n"; }); + return; + } + auto loadOp = dyn_cast(v.getDefiningOp()); + // We expect the leaves of the slice to be Load or arith::Constant + // This could be generalised if necessary + if (!loadOp) { + auto op = v.getDefiningOp(); + if (isa(op) || noDataMovement(op)) { + innerSlice.insert(v); + continue; + } else { + LLVM_DEBUG({ + DBGS() << " Leaves must be Load or Constant. Got " << v << "\n"; + }); + return; + } + } + builder.setInsertionPointAfter(loadOp); + auto type = dyn_cast(loadOp.getType()); + if (!type) + continue; + auto newType = RankedTensorType::get(type.getShape(), type.getElementType(), + layout[loadOp]); + auto newConvertOp = builder.create( + convertOp.getLoc(), newType, loadOp.getResult()); + mapping.map(loadOp.getResult(), newConvertOp.getResult()); + } + + if (innerSlice.empty()) { + return; + } + + LLVM_DEBUG({ + DBGS() << " Hoisting " << convertOp << '\n'; + for (Value v : innerSlice) + DBGS() << " " << v << '\n'; + }); + + rewriteSlice(innerSlice, layout, convertOp, mapping); +} + +// For convert left we try to hoist them above type extension to reduce the cost +// of the convert. +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( + ConvertLayoutOp convertOp) { + // DotOperand is hoisted by hoistDotOperand + RankedTensorType targetType = convertOp.getType(); + if (isa(targetType.getEncoding())) + return; + + auto isExtOrBroadcastOp = [](Operation *op) { + if (isa(op)) { + return true; + } + if (auto fpToFpOp = dyn_cast(op)) { + auto srcType = cast(fpToFpOp.getOperand().getType()); + return getElementBitWidth(srcType) < + getElementBitWidth(cast(fpToFpOp.getType())); + } + return false; + }; + // 1. Take a backward slice of all the tensor dependencies. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, + isExtOrBroadcastOp); + if (result.failed()) + return; + + Operation *extOrBroadcatOp = nullptr; + unsigned sliceSize = slice.size(); + for (unsigned i = 0; i < sliceSize; i++) { + Value v = slice[i]; + Operation *op = v.getDefiningOp(); + if (!op) + continue; + if (isExtOrBroadcastOp(op)) { + SetVector tempSlice; + DenseMap tempLayout; + Attribute srcEncoding = inferSrcEncoding(op, layout[v]); + if (!srcEncoding) + return; + LogicalResult result = getRematerializableSlice( + op->getOpOperand(0), srcEncoding, tempSlice, tempLayout); + // If we can rematerialize the rest of the ext slice we can ignore this + // ext as it won't need a convert. + if (result.succeeded()) { + slice.insert(tempSlice.begin(), tempSlice.end()); + layout.insert(tempLayout.begin(), tempLayout.end()); + continue; + } + // Only apply it if there is a single ext op otherwise we would have to + // duplicate the convert. + if (extOrBroadcatOp != nullptr) + return; + extOrBroadcatOp = op; + } + } + + if (extOrBroadcatOp == nullptr) + return; + Attribute dstEncoding = layout[extOrBroadcatOp->getResult(0)]; + Attribute srcEncoding = inferSrcEncoding(extOrBroadcatOp, dstEncoding); + if (!srcEncoding) + return; + // Move the convert before the ext op and rewrite the slice. + OpBuilder builder(extOrBroadcatOp); + auto tensorType = + cast(extOrBroadcatOp->getOperand(0).getType()); + auto newType = RankedTensorType::get( + tensorType.getShape(), tensorType.getElementType(), srcEncoding); + auto newConvertOp = builder.create( + convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); + newConvertOp->setAttrs(convertOp->getAttrs()); + Operation *newExtOrBroadcast = builder.clone(*extOrBroadcatOp); + newExtOrBroadcast->setOperand(0, newConvertOp.getResult()); + auto oldExtOrBroadcastType = + cast(extOrBroadcatOp->getResult(0).getType()); + Type newExtOrBroadcasrType = RankedTensorType::get( + oldExtOrBroadcastType.getShape(), oldExtOrBroadcastType.getElementType(), + dstEncoding); + newExtOrBroadcast->getResult(0).setType(newExtOrBroadcasrType); + IRMapping mapping; + mapping.map(extOrBroadcatOp->getResult(0), newExtOrBroadcast->getResult(0)); + slice.remove(extOrBroadcatOp->getResult(0)); + // 3. Rewrite the slice. + rewriteSlice(slice, layout, convertOp, mapping); +} + +void LayoutRematerialization::hoistConvertIntoConditionals( + ConvertLayoutOp convertOp) { + // Take the backward slice of tensor dependencies rooted at the conversion, + // stopping at conditionals. This subslice is used to initialize the analysis. + SetVector slice; + DenseMap layout; + auto isIfOp = [](Operation *op) { return isa(op); }; + if (failed(getRematerializableSlice(convertOp.getSrcMutable(), + convertOp.getType().getEncoding(), slice, + layout, isIfOp))) + return; + + // These are the conditional edges above which conversions should be hoisted. + // The value represents the `scf.if` op result and the operand represents the + // edge into one of the branches. + SmallVector> hoistAbove; + + // The list of `scf.if` op results in the slice that are not rematerializable. + // Hoisting is terminated at these values. + SmallVector terminals; + + // This loop recurses through the subslices of the backwards dependencies, so + // re-query the size of `slice`. + for (unsigned i = 0; i != slice.size(); ++i) { + Value v = slice[i]; + auto ifOp = v.getDefiningOp(); + if (!ifOp) + continue; + + Attribute rootLayout = layout.at(v); + unsigned resIdx = cast(v).getResultNumber(); + + // Take the backward slice along each branch. + auto thenYield = + cast(ifOp.getThenRegion().front().getTerminator()); + auto elseYield = + cast(ifOp.getElseRegion().front().getTerminator()); + + OpOperand &thenRes = thenYield.getResultsMutable()[resIdx]; + OpOperand &elseRes = elseYield.getResultsMutable()[resIdx]; + + SetVector thenSlice, elseSlice; + DenseMap thenLayout, elseLayout; + + LogicalResult thenResult = getRematerializableSlice( + thenRes, rootLayout, thenSlice, thenLayout, isIfOp); + LogicalResult elseResult = getRematerializableSlice( + elseRes, rootLayout, elseSlice, elseLayout, isIfOp); + + // If propagation across both edges of this conditional succeeded, then we + // don't need to hoist across it. Merge into the current slice. + if (succeeded(thenResult) && succeeded(elseResult)) { + slice.insert(thenSlice.begin(), thenSlice.end()); + slice.insert(elseSlice.begin(), elseSlice.end()); + layout.insert(thenLayout.begin(), thenLayout.end()); + layout.insert(elseLayout.begin(), elseLayout.end()); + continue; + } + + // If propagation across both edges failed, then this conditional + // terminates backwards rematerialization. + if (failed(thenResult) && failed(elseResult)) { + terminals.push_back(cast(v)); + continue; + } + + // Only hoist into conditionals inside loops. The assumption is that an if + // inside a loop executes fewer than the total number of loop iterations, + // making this hoist profitable. + if (!isa(ifOp->getParentOp())) { + terminals.push_back(cast(v)); + continue; + } + + // The layout conversion can be rematerialized along one edge but not the + // other. We can hoist the conversion into the other branch. Push this + // into the subslice list for analysis. + if (succeeded(thenResult)) { + hoistAbove.emplace_back(v, &elseRes); + slice.insert(thenSlice.begin(), thenSlice.end()); + layout.insert(thenLayout.begin(), thenLayout.end()); + } else { + hoistAbove.emplace_back(v, &thenRes); + slice.insert(elseSlice.begin(), elseSlice.end()); + layout.insert(elseLayout.begin(), elseLayout.end()); + } + } + + // Exit early if there is nothing to do. + if (hoistAbove.empty()) + return; + + // Rematerialize failed hoists right before the condtional, and hoist those + // that succeeded into the branch and then rewrite the slice. + IRMapping mapping; + auto hoistRemat = [&](OpBuilder &b, Value v, Attribute encoding) { + auto tensorType = cast(v.getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + Value newCvt = b.create(convertOp.getLoc(), newType, v); + + mapping.map(v, newCvt); + slice.remove(v); + }; + for (Value v : terminals) { + OpBuilder b(v.getContext()); + b.setInsertionPointAfter(v.getDefiningOp()); + hoistRemat(b, v, layout.at(v)); + } + for (auto [result, edge] : hoistAbove) { + OpBuilder b(edge->getOwner()); + hoistRemat(b, edge->get(), layout.at(result)); + } + rewriteSlice(slice, layout, convertOp, mapping); +} + +void backwardRematerialization(ModuleOp module) { + module.walk([](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + layoutRemat.backwardRematerialization(); + layoutRemat.cleanup(); + }); +} + +void hoistConvert(ModuleOp module) { + SmallVector convertOps; + module.walk([](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + layoutRemat.hoistConvertOnTopOfExtOrBroadcast(); + layoutRemat.cleanup(); + + layoutRemat = LayoutRematerialization(funcOp); + layoutRemat.hoistConvertIntoConditionals(); + layoutRemat.cleanup(); + + layoutRemat = LayoutRematerialization(funcOp); + layoutRemat.hoistConvertDotOperand(); + layoutRemat.cleanup(); + }); +} +} // namespace + +class TritonGPURemoveLayoutConversionsPass + : public impl::TritonGPURemoveLayoutConversionsBase< + TritonGPURemoveLayoutConversionsPass> { +public: + // Cleanup convert ops. + void cleanupConvertOps() { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + RewritePatternSet cleanUpPatterns(context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, context); + if (applyPatternsGreedily(m, std::move(cleanUpPatterns)).failed()) { + signalPassFailure(); + } + + LLVM_DEBUG({ + DBGS() << "Module after canonicalizing:\n"; + m.dump(); + }); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + // 1. Propagate layout forward starting from "anchor" ops. + m.walk([](FuncOp funcOp) { + LayoutPropagation layoutPropagation(funcOp); + layoutPropagation.initAnchorLayout(); + layoutPropagation.propagateLayout(); + layoutPropagation.resolveConflicts(); + layoutPropagation.rewrite(); + }); + + LLVM_DEBUG({ + DBGS() << "Module after propagating layouts forward:\n"; + m.dump(); + }); + + cleanupConvertOps(); + + // 2. For remaining convert ops, try to rematerialize the slice of producer + // operation to avoid having to convert. + backwardRematerialization(m); + LLVM_DEBUG({ + DBGS() << "Module after backward remat:\n"; + m.dump(); + }); + + // Cleanup dummy converts created during backward remat. + cleanupConvertOps(); + + // 3. For remaining converts, try to hoist them above cast generating larger + // size types in order to reduce the cost of the convert op. + hoistConvert(m); + LLVM_DEBUG({ + DBGS() << "Module after hoisting converts:\n"; + m.dump(); + }); + + // 4. Apply clean up patterns to remove remove dead convert and dead code + // generated by the previous transformations. + RewritePatternSet cleanUpPatterns2(context); + DenseSet opsCanBeTriviallyDead; + populateForOpDeadArgumentElimination(cleanUpPatterns2, + opsCanBeTriviallyDead); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + scf::IfOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + if (applyPatternsGreedily(m, std::move(cleanUpPatterns2)).failed()) { + signalPassFailure(); + } + LLVM_DEBUG({ + DBGS() << "Module after final cleanups:\n"; + m.dump(); + }); + } +}; + +} // namespace mlir::triton::gpu diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp new file mode 100644 index 000000000..da6afb9a6 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp @@ -0,0 +1,141 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREORDERINSTRUCTIONS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static bool willIncreaseRegisterPressure(Operation *op) { + if (isa(op)) + return true; + auto cvt = dyn_cast(op); + if (!cvt) + return false; + if (mlir::isa( + cvt.getType().getEncoding())) + return true; + return false; +} + +class TritonGPUReorderInstructionsPass + : public impl::TritonGPUReorderInstructionsBase< + TritonGPUReorderInstructionsPass> { +public: + TritonGPUReorderInstructionsPass() = default; + + Operation *getFirstUse(Operation *op) { + std::vector users; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + users.push_back(ancestor); + } + auto minOpIt = + llvm::min_element(users, [](mlir::Operation *a, mlir::Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != users.end() ? *minOpIt : nullptr; + } + + void runOnOperation() override { + ModuleOp m = getOperation(); + mlir::DominanceInfo dom(m); + // sink conversion after the last dealloc + // before the first use ancestor in its block + m.walk([&](triton::gpu::ConvertLayoutOp op) { + auto curr = mlir::Block::iterator(op); + auto end = op->getBlock()->end(); + for (; curr != end && &*curr != getFirstUse(op); curr++) + if (isa(&*curr)) + op->moveAfter(&*curr); + }); + // Sink conversions into loops when they will increase + // register pressure + DenseMap opToMove; + auto moveAfter = [](Operation *lhs, Operation *rhs) { + lhs->moveAfter(rhs); + }; + m.walk([&](Operation *op) { + if (!willIncreaseRegisterPressure(op)) + return; + auto user_begin = op->user_begin(); + auto user_end = op->user_end(); + if (std::distance(user_begin, user_end) != 1) + return; + if (user_begin->getParentOfType() == + op->getParentOfType()) + return; + opToMove.insert({op, *user_begin}); + }); + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); + // Move alloc(load) immediately after dependent load + m.walk([&](triton::gpu::LocalAllocOp op) { + if (!op.getSrc()) + return; + Operation *argOp = op.getSrc().getDefiningOp(); + if (!argOp) + return; + moveAfter(op, argOp); + }); + // Move transpositions just after their definition + opToMove.clear(); + m.walk([&](triton::TransposeOpInterface op) { + Operation *argOp = op.getSrc().getDefiningOp(); + if (!argOp) + return; + moveAfter(op, argOp); + }); + // Move `dot` operand so that conversions to opIdx=1 happens after + // conversions to opIdx=0 + m.walk([&](triton::gpu::LocalLoadOp op) { + auto dstEncoding = mlir::dyn_cast( + op.getType().getEncoding()); + if (!dstEncoding) + return; + int opIdx = dstEncoding.getOpIdx(); + if (opIdx != 1) + return; + if (!op->hasOneUse()) + return; + auto dotUser = dyn_cast(*op->user_begin()); + if (!dotUser) + return; + auto AOp = + dotUser.getOperand(0).getDefiningOp(); + if (!AOp) + return; + // Check that the conversion to OpIdx=1 happens before and can be moved + // after the conversion to OpIdx=0. + if (!dom.dominates(op.getOperation(), AOp.getOperation())) + return; + moveAfter(op, AOp); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp new file mode 100644 index 000000000..db74231d3 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp @@ -0,0 +1,487 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "triton-gpu-taskid-propagate" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = ::mlir::triton; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTASKIDPROPAGATE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// Return all Ops that are marked with target task +void getAsyncTaskOps(triton::FuncOp funcOp, DenseSet &asyncTaskOps, + int asyncTaskId) { + funcOp.walk([&](Operation *op) -> void { + if (auto attr = + op->getAttrOfType("async_task_id")) { + for (auto val : attr.getValues()) { + if (val == asyncTaskId) { + asyncTaskOps.insert(op); + break; + } + } + } + }); +} + +void getAllParentOps(DenseSet &parentOps, Operation *targetOp) { + auto op = targetOp; + while (auto parent = op->getParentOp()) { + if (!isa(parent) && !isa(parent)) { + parentOps.insert(parent); + op = parent; + } else { + break; + } + } +} + +void getAllParentOps(triton::FuncOp funcOp, DenseSet &parentOps, + int asyncTaskId) { + DenseSet targetOps; + getAsyncTaskOps(funcOp, targetOps, asyncTaskId); + for (auto op : targetOps) { + getAllParentOps(parentOps, op); + } +} + +void labelByUsers(Operation *op, ArrayRef allAsyncTasks) { + for (Value result : op->getResults()) { + for (Operation *userOp : result.getUsers()) { + if (!userOp->hasAttr("async_task_id")) { + labelByUsers(userOp, allAsyncTasks); + } + addAsyncTaskIds(op, getAsyncTaskIds(userOp)); + } + } + if (!op->hasAttr("async_task_id")) { + addAsyncTaskIds(op, allAsyncTasks); + } +} + +/// Because we set some special filter rules in populateAsyncTaskRegion, +/// there may be unlabeled Ops, e.g. YieldOps, some definingOps of ForOps. +/// or Ops without relations to asyncTaskOps +void populateUnlabledOpsAtLast(triton::FuncOp funcOp, + ArrayRef allAsyncTasks) { + // Label asyncTasks' parentOps + for (int i : allAsyncTasks) { + DenseSet asyncTaskParentOps; + getAllParentOps(funcOp, asyncTaskParentOps, i); + for (auto op : asyncTaskParentOps) { + addAsyncTaskIds(op, {i}); + } + } + + // Get unlabeled Ops + DenseSet unlabeledOps; + funcOp.walk([&](Operation *op) -> void { + if (isa(op) || isa(op) || + isa(op)) { + return; + } + if (!op->hasAttr("async_task_id")) { + unlabeledOps.insert(op); + } + }); + + // Label Ops using its parentOp + for (auto op : unlabeledOps) { + if (auto parent = op->getParentOp()) { + if (!isa(parent)) { + if (!parent->hasAttr("async_task_id")) { + LLVM_DEBUG({ + LDBG("op and parent: "); + op->dump(); + parent->dump(); + }); + continue; + } + assert(parent->hasAttr("async_task_id")); + auto asyncTasks = getAsyncTaskIds(parent); + setAsyncTaskIds(op, asyncTasks); + unlabeledOps.erase(op); + } + } + } + + // Label Ops using dependency + for (auto op : unlabeledOps) { + labelByUsers(op, allAsyncTasks); + unlabeledOps.erase(op); + } + assert(unlabeledOps.size() == 0); +} + +#ifndef NDEBUG +static bool oneVecCoversTheOther(SmallVector &one, + SmallVector &other) { + // Every element of other appears in one. + for (AsyncTaskId t : other) { + // If t doesn't appear in one, return false. + bool found = false; + for (AsyncTaskId t2 : one) { + if (t2 == t) { + found = true; + break; + } + } + if (!found) + return false; + } + return true; +} + +struct AsyncTaskIdsCompare { + static SmallVector getEmptyKey() { + SmallVector V; + V.push_back(reinterpret_cast(-1)); + return V; + } + + static SmallVector getTombstoneKey() { + SmallVector V; + V.push_back(reinterpret_cast(-2)); + return V; + } + + static unsigned getHashValue(const SmallVector &V) { + return static_cast(llvm::hash_combine_range(V.begin(), V.end())); + } + + static bool isEqual(const SmallVector &LHS, + const SmallVector &RHS) { + return LHS == RHS; + } +}; + +// Make sure the def chain contains the right taskId. +bool verifyTaskId(triton::FuncOp &funcOp, + const llvm::DenseSet &anchorOps) { + bool retCode = true; + DenseSet, AsyncTaskIdsCompare> anchorAsyncTasks; + for (auto anchorOp : anchorOps) { + anchorAsyncTasks.insert(getAsyncTaskIds(anchorOp)); + } + + funcOp.walk([&](Operation *op) { + // Skip control ops + if (llvm::isa(op)) + return; + + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.empty()) { + LLVM_DEBUG({ + LDBG("Op does not have task id"); + op->dump(); + }); + llvm_unreachable("Op does not have task id"); + } + + auto partitionShouldBeUsedSpecified = [](Operation *op) { + if (isa(op)) + return true; + if (isa(op)) + return true; + if (isa(op)) + return true; + return false; + }; + + if (!anchorAsyncTasks.contains(asyncTaskIds)) { + if (partitionShouldBeUsedSpecified(op)) { + LLVM_DEBUG({ + LDBG("async tasks not specified by user"); + op->dump(); + }); + llvm_unreachable("async tasks not specified by user"); + } + } + + assert(!asyncTaskIds.empty() && "Op does not have task id"); + + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (!defOp) + continue; + if (llvm::isa(defOp)) + continue; + auto defTaskIds = getAsyncTaskIds(defOp); + // Make sure defTaskIds cover asyncTaskIds. Call addAsyncTaskIds if + // necessary. + LLVM_DEBUG({ + if (!oneVecCoversTheOther(defTaskIds, asyncTaskIds)) { + // print defOp and op + LDBG("Def op does not cover op"); + LDBG("Def op"); + defOp->dump(); + LDBG("op"); + op->dump(); + } + }); + assert(oneVecCoversTheOther(defTaskIds, asyncTaskIds) && + "defTaskIds should cover asyncTaskIds"); + } + }); + return retCode; +} +#endif + +void backwardPropagateTaskIds(Operation *op, + const llvm::DenseSet &anchors) { + SmallVector queue; + auto asyncTasks = getAsyncTaskIds(op); + for (Value operand : op->getOperands()) { + queue.push_back(operand); + } + + DenseSet seen; + for (auto anchor : anchors) { + if (anchor != op) + for (auto result : anchor->getResults()) + seen.insert(result); + } + + while (!queue.empty()) { + auto value = queue.pop_back_val(); + if (!seen.insert(value).second) { + continue; + } + + // Handle BlockArguments of for loops (i.e. loop carried dependences). + if (auto blockArg = dyn_cast(value)) { + auto parent = blockArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(parent)) { + // Propagate to the control operands. + auto control = + forOp.getOperands().take_front(forOp.getNumControlOperands()); + queue.insert(queue.end(), control.begin(), control.end()); + // Propagate to the initializer. + if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { + queue.push_back(forOp.getTiedLoopInit(blockArg)->get()); + // Propagate to the yield. + auto idx = blockArg.getArgNumber() - forOp.getNumInductionVars(); + queue.push_back(forOp.getBody()->getTerminator()->getOperand(idx)); + addAsyncTaskIds(forOp, asyncTasks); + } + } + continue; + } + + auto op = value.getDefiningOp(); + if (!anchors.count(op)) + addAsyncTaskIds(op, asyncTasks); + + // Handle for loops. + if (auto forOp = dyn_cast(op)) { + // Propagate to control operands. + auto control = + forOp.getOperands().take_front(forOp.getNumControlOperands()); + queue.insert(queue.end(), control.begin(), control.end()); + // Propagate to arguments. + unsigned idx = cast(value).getResultNumber(); + queue.push_back(forOp.getOperand(idx + forOp.getNumControlOperands())); + // Propagate to yield. + queue.push_back(forOp.getBody()->getTerminator()->getOperand(idx)); + continue; + } + + // Handle conditionals. + if (auto ifOp = dyn_cast(op)) { + queue.push_back(ifOp.getCondition()); + unsigned idx = cast(value).getResultNumber(); + if (ifOp.elseBlock()) { + queue.push_back(ifOp.elseYield()->getOperand(idx)); + } + queue.push_back(ifOp.thenYield()->getOperand(idx)); + continue; + } + + // Handle normal ops. + for (Value operand : op->getOperands()) { + queue.push_back(operand); + } + } +} + +void backwardPropagateTaskIds(llvm::DenseSet &rootOps, + llvm::DenseSet &anchorOps) { + for (Operation *op : rootOps) { + backwardPropagateTaskIds(op, anchorOps); + } +} + +void forwardPropagateTaskIds(Operation *root, + const llvm::DenseSet &anchors) { + auto asyncTasks = getAsyncTaskIds(root); + SmallVector queue; + for (Value result : root->getResults()) + queue.push_back(result); + + DenseSet seen; + for (auto anchor : anchors) { + if (anchor != root) + for (auto result : anchor->getResults()) + seen.insert(result); + } + + while (!queue.empty()) { + auto v = queue.back(); + queue.pop_back(); + if (!seen.insert(v).second) + continue; + + for (Operation *depOp : v.getUsers()) { + auto depAsyncTasks = getAsyncTaskIds(depOp); + // Skip depOp that already has task ids. Those could be either anchorOps + // or propagated backward from anchor ops. + if (!depAsyncTasks.empty() && depAsyncTasks != asyncTasks) + continue; + setAsyncTaskIds(depOp, asyncTasks); + // Go through yieldOp to propagate task ids to the result of parentOp. + if (auto yieldOp = dyn_cast(depOp)) { + auto parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + if (operand.get() == v) { + queue.push_back(parentOp->getResult(operand.getOperandNumber())); + break; + } + } + } else { + for (Value result : depOp->getResults()) + queue.push_back(result); + } + } + } +} + +void forwardPropagateTaskIds(llvm::DenseSet &anchorOps) { + for (Operation *op : anchorOps) { + forwardPropagateTaskIds(op, anchorOps); + } +} + +void populateTaskIdsForControlDependencies( + llvm::DenseSet &anchorOps) { + for (auto op : anchorOps) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (!asyncTaskIds.empty()) { + while (auto parent = op->getParentOp()) { + if (!isa(parent) && !isa(parent)) { + setAsyncTaskIds(parent, asyncTaskIds); + backwardPropagateTaskIds(parent, anchorOps); + op = parent; + } else { + break; + } + } + } + } +} + +class TritonGPUTaskIdPropagatePass + : public impl::TritonGPUTaskIdPropagateBase { +public: + using impl::TritonGPUTaskIdPropagateBase< + TritonGPUTaskIdPropagatePass>::TritonGPUTaskIdPropagateBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + llvm::DenseSet anchorOps; + funcOp.walk([&](mlir::Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + if (asyncTasks.empty()) + return; + std::sort(asyncTasks.begin(), asyncTasks.end()); + setAsyncTaskIds(op, asyncTasks); + if (!isa(op)) + anchorOps.insert(op); + }); + + // If there is no anchorOp, task id propagation is not needed. + if (anchorOps.empty()) + return; + populateTaskIdsForControlDependencies(anchorOps); + + LLVM_DEBUG({ + LDBG("after populateTaskIdsForControlDependencies "); + funcOp->dump(); + }); + + backwardPropagateTaskIds(anchorOps, anchorOps); + + LLVM_DEBUG({ + LDBG("after backwardPropagateTaskIds "); + funcOp->dump(); + }); + + forwardPropagateTaskIds(anchorOps); + + LLVM_DEBUG({ + LDBG("after forwardPropagateTaskIds "); + funcOp->dump(); + }); + + llvm::DenseSet rootOps; + funcOp.walk([&](mlir::Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + if (!asyncTasks.empty() && + !isa(op)) + rootOps.insert(op); + }); + backwardPropagateTaskIds(rootOps, anchorOps); + LLVM_DEBUG({ + LDBG("after final backwardPropagateTaskIds "); + funcOp->dump(); + }); + + DenseSet allAsyncTasks; + funcOp->walk([&](Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + allAsyncTasks.insert(asyncTasks.begin(), asyncTasks.end()); + }); + SmallVector allAsyncTasksVec(allAsyncTasks.begin(), + allAsyncTasks.end()); + populateUnlabledOpsAtLast(funcOp, allAsyncTasksVec); + + LLVM_DEBUG({ + LDBG("after populateUnlabledOpsAtLast "); + funcOp->dump(); + }); + +#ifndef NDEBUG + verifyTaskId(funcOp, anchorOps); +#endif + } + + void runOnOperation() override { + if (numConsumerGroups == 0) { + getOperation()->walk([&](triton::FuncOp funcOp) { + funcOp.walk([&](mlir::Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + if (!asyncTasks.empty()) + op->removeAttr("async_task_id"); + }); + }); + return; + } + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp new file mode 100644 index 000000000..0adf3d949 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -0,0 +1,1280 @@ +#include "triton/Analysis/Utility.h" + +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" +#define DEBUG_TYPE "ttg-utility" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace mlir { + +using namespace triton; + +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + Type eltType, int numWarps) { + if (version == 1) + return {16, 16}; + else if (version == 2) { + auto rank = shape.size(); + SmallVector ret(rank, 1); + ret[rank - 1] = 8; + ret[rank - 2] = 16; + return ret; + } else if (version == 3) { + unsigned k = 256 / eltType.getIntOrFloatBitWidth(); + if (shape[0] % 64 != 0 || shape[1] % 8 != 0) { + assert(false && "type not supported"); + return {0, 0, 0}; + } + SmallVector validN; + + // MMAv3 with larger instruction shape is preferred. + if (llvm::isa( + eltType) || + eltType.isF16() || eltType.isBF16() || eltType.isF32()) { + validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, + 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, + 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); + } + + if (eltType.isInteger(8)) { + validN.assign({224, 208, 192, 176, 160, 144, 128, 112, 96, 80, 64, 48, 32, + 24, 16, 8}); + } + + unsigned m = 16; + unsigned mWarps = std::max(shape[0] / m, 1); + unsigned nWarps = std::max(numWarps / mWarps, 1); + unsigned maxN = std::max(shape[1] / nWarps, 8); + for (auto n : validN) { + if (shape[1] % n == 0 && n <= maxN) { + return {m, n, k}; + } + } + + assert(false && "type not supported"); + return {0, 0, 0}; + } else if (version == 5) { + unsigned m = shape[0] >= 128 ? 128 : 64; + // Right now default to distributing along N. TODO: For cases where we have + // dot followed by reduction we need to be able to distribute along M. + // if (numWarps > 4) + // m = 64; + unsigned n = shape[1] >= 256 ? 256 : shape[1]; + unsigned k = 256 / eltType.getIntOrFloatBitWidth(); + return {m, n, k}; + } else { + assert(false && "version not supported"); + return {0, 0}; + } +} + +bool isLoadFromTensorPtr(triton::LoadOp op) { + return mlir::triton::isTensorPointerType(op.getPtr().getType()); +} + +SmallVector argSort(const SmallVector &arr) { + SmallVector ret(arr.size()); + std::iota(ret.begin(), ret.end(), 0); + std::stable_sort(ret.begin(), ret.end(), + [&](unsigned x, unsigned y) { return arr[x] > arr[y]; }); + return ret; +} + +Value getMemAccessPtr(Operation *op) { + if (auto ld = dyn_cast(op)) + return ld.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto copy = dyn_cast(op)) + return copy.getSrc(); + if (auto store = dyn_cast(op)) + return store.getPtr(); + return nullptr; +} + +unsigned getElementBitWidth(RankedTensorType type) { + auto typeForMem = + isa(type.getElementType()) + ? cast(type.getElementType()).getPointeeType() + : type.getElementType(); + return typeForMem.getIntOrFloatBitWidth(); +} + +unsigned getNumElementsPerThread(Operation *op, SmallVector order, + ModuleAxisInfoAnalysis &axisInfoAnalysis) { + Value val = getMemAccessPtr(op); + auto ty = cast(val.getType()); + auto shapePerCTA = triton::gpu::getShapePerCTA(ty); + AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); + unsigned elemNumBits = getElementBitWidth(ty); + unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); + unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); + unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); + unsigned maxContig = + std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]); + unsigned alignment = std::min(maxMultiple, maxContig); + unsigned currPerThread = std::min(alignment, 128 / elemNumBits); + LDBG("elemNumBytes: " << elemNumBytes + << ", divisibility: " << maxMultipleBytes + << ", contig: " << valInfo.getContiguity(order[0]) + << ", alignment: " << alignment); + return currPerThread; +} + +bool isView(Operation *op) { + return isa(op); +} + +//===----------------------------------------------------------------------===// +// GraphDumper +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphDumper::onValue(Value value) const { + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +GraphDumper::NodeInfo GraphDumper::onOperation(Operation *op) const { + return {{"shape", "ellipse"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +std::string GraphDumper::dump(triton::FuncOp func) const { + llvm::SetVector values; + llvm::SetVector operations; + + func.walk([&](Operation *op) { + operations.insert(op); + for (Value operand : op->getOperands()) + values.insert(operand); + for (Value result : op->getResults()) + values.insert(result); + }); + + std::ostringstream oss; + oss << "// Generated by Triton GraphDumper\n" + << "\n" + << "digraph {\n"; + + oss << " // Value Nodes\n"; + for (Value value : values) + oss << " " << emitValueNode(value) << "\n"; + oss << "\n"; + + oss << " // Operation Nodes\n"; + for (Operation *op : operations) + oss << " " << emitOperationNode(op) << "\n"; + oss << "\n"; + + oss << " // Edges\n"; + for (Operation *op : operations) { + for (Value operand : op->getOperands()) + oss << " " << emitEdge(getUniqueId(operand), getUniqueId(op)) << "\n"; + for (Value result : op->getResults()) + oss << " " << emitEdge(getUniqueId(op), getUniqueId(result)) << "\n"; + } + + oss << "}\n"; + return oss.str(); +} + +void GraphDumper::dumpToFile(triton::FuncOp func, + const std::string &filename) const { + std::ofstream ofs(filename); + ofs << dump(func); +} + +std::string GraphDumper::getShapeStr(const Type &type) const { + std::ostringstream oss; + oss << "["; + if (auto tensorTy = dyn_cast(type)) { + auto shape = tensorTy.getShape(); + for (unsigned i = 0; i < shape.size(); ++i) { + if (i > 0) + oss << ", "; + oss << shape[i]; + } + } + oss << "]"; + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Value value) const { + std::ostringstream oss; + oss << value.getImpl(); + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Operation *op) const { + std::ostringstream oss; + oss << op; + return oss.str(); +} + +std::string GraphDumper::emitNode(const std::string &id, + const GraphDumper::NodeInfo info) const { + std::ostringstream oss; + oss << "\"" << id << "\" ["; + for (auto it = info.begin(); it != info.end(); ++it) { + if (it != info.begin()) + oss << ", "; + oss << it->first << " = \"" << it->second << "\""; + } + oss << "];"; + return oss.str(); +} + +std::string GraphDumper::emitEdge(const std::string &srcId, + const std::string &destId) const { + std::ostringstream oss; + oss << "\"" << srcId << "\" -> \"" << destId << "\";"; + return oss.str(); +} + +std::string GraphDumper::emitValueNode(Value value) const { + NodeInfo info = onValue(value); + if (info.find("label") == info.end()) { + std::string shapeStr = getShapeStr(value.getType()); + if (auto arg = mlir::dyn_cast(value)) + info["label"] = + "BlockArg" + std::to_string(arg.getArgNumber()) + " " + shapeStr; + else + info["label"] = shapeStr; + } + return emitNode(getUniqueId(value), info); +} + +std::string GraphDumper::emitOperationNode(Operation *op) const { + NodeInfo info = onOperation(op); + if (info.find("label") == info.end()) + info["label"] = op->getName().getStringRef().str(); + return emitNode(getUniqueId(op), info); +} + +//===----------------------------------------------------------------------===// +// GraphLayoutMarker +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphLayoutMarker::onValue(Value value) const { + std::string color = getColor(value.getType()); + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", color}}; +} + +std::string GraphLayoutMarker::getColor(const Type &type) const { + if (auto tensorTy = dyn_cast(type)) { + auto layout = tensorTy.getEncoding(); + if (isa(layout)) + return "green"; + else if (isa(layout)) + return "yellow"; + else if (isa(layout)) + return "lightslateblue"; + else if (isa(layout)) + return "orange"; + else if (isa(layout)) + return "orangered"; + else { + llvm::report_fatal_error("Unrecognized layout"); + return "unknown"; + } + } else { + return "white"; + } +} +// -------------------------------------------------------------------------- // + +static Attribute inferDstEncoding(triton::ReduceOp op, Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get( + op->getContext(), op.getAxis(), + cast(encoding)); +} + +static Attribute inferDstEncoding(triton::ExpandDimsOp op, Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return {}; + if (op.getAxis() != sliceEncoding.getDim()) + return {}; + return sliceEncoding.getParent(); +} + +static Attribute inferDstEncoding(JoinOp op, Attribute srcEnc) { + Attribute dstEnc; + auto shape = op.getLhs().getType().getShape(); + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferJoinOpEncoding(srcEnc, dstEnc, shape, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return {}; +} + +static Attribute inferDstEncoding(SplitOp op, Attribute srcEnc) { + Attribute dstEnc; + auto shape = op.getSrc().getType().getShape(); + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(srcEnc, dstEnc, shape, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return {}; +} + +static Attribute inferSrcEncoding(triton::ReduceOp op, Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return {}; + if (op.getAxis() != sliceEncoding.getDim()) + return {}; + return sliceEncoding.getParent(); +} + +static Attribute inferSrcEncoding(triton::ExpandDimsOp op, Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get( + op->getContext(), op.getAxis(), + cast(encoding)); +} + +static Attribute inferSrcEncoding(JoinOp op, Attribute dstEnc) { + // Split is the inverse of join. + auto shape = op.getResult().getType().getShape(); + Attribute srcEnc; + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(dstEnc, srcEnc, shape, /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return {}; +} + +static Attribute inferSrcEncoding(SplitOp op, Attribute dstEnc) { + // Join is the inverse of split. + Attribute srcEnc; + auto shape = op.getOutLHS().getType().getShape(); + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferJoinOpEncoding(dstEnc, srcEnc, shape, /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return {}; +} + +static Attribute inferSrcEncoding(GatherOp op, Attribute dstEnc) { + // The index encoding is the same as the output encoding. + return dstEnc; +} + +static Attribute inferTransOpDstEncoding(Attribute srcEnc, + ArrayRef shape, + ArrayRef order) { + // Simply forward to the existing inferTransOpEncoding function. + Attribute retEncoding; + if (succeeded( + srcEnc.getDialect() + .getRegisteredInterface() + ->inferTransOpEncoding(srcEnc, shape, order, retEncoding))) { + return retEncoding; + } + return {}; +} + +static Attribute inferDstEncoding(triton::gpu::Fp4ToFpOp op, Attribute srcEnc) { + Attribute dstEnc; + auto shape = op.getSrc().getType().getShape(); + auto result = + srcEnc.getDialect() + .getRegisteredInterface() + ->inferFp4ToFpOpEncoding(shape, op.getAxis(), srcEnc, dstEnc, + /*fwdInference*/ true, std::nullopt); + assert(succeeded(result)); + return dstEnc; +} + +static Attribute inferSrcEncoding(triton::gpu::Fp4ToFpOp op, Attribute dstEnc) { + Attribute srcEnc; + auto shape = op.getSrc().getType().getShape(); + if (succeeded( + dstEnc.getDialect() + .getRegisteredInterface() + ->inferFp4ToFpOpEncoding(shape, op.getAxis(), dstEnc, srcEnc, + /*fwdInference*/ false, std::nullopt))) { + return srcEnc; + } + return {}; +} + +static Attribute inferDstEncoding(triton::TransposeOpInterface op, + Attribute encoding) { + return inferTransOpDstEncoding( + encoding, cast(op.getSrc().getType()).getShape(), + op.getOrder()); +} + +static Attribute inferSrcEncoding(triton::TransposeOpInterface op, + Attribute encoding) { + // We want to solve for srcEnc in + // transpose(srcEnc, order) -> dstEnc. + // Given the identity + // transpose(transpose(x, order), inverse(order)) == x, + // we can see this is equivalent to + // transpose(dstEnc, inverse(order)) -> srcEnc. + auto shape = cast(op->getResult(0).getType()).getShape(); + return inferTransOpDstEncoding(encoding, shape, + triton::inversePermutation(op.getOrder())); +} + +static Attribute inferReshapeOpDstEncoding(ArrayRef srcShape, + Attribute srcEnc, + ArrayRef dstShape, + bool allowReorder) { + // We don't do anything smart to allow-reorder reshapes here. They are + // handled in OptimizeThreadLocality. + if (allowReorder) + return {}; + + Attribute dstEnc; + auto result = + srcEnc.getDialect() + .getRegisteredInterface() + ->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc, + /*loc=*/std::nullopt); + assert(succeeded(result)); + return dstEnc; +} + +static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) { + return inferReshapeOpDstEncoding(op.getSrc().getType().getShape(), encoding, + op.getType().getShape(), + op.getAllowReorder()); +} + +static Attribute inferDstEncoding(GatherOp op, Attribute encoding) { + // The output encoding is the same as the index encoding. + // FIXME: This assumes `encoding` is the index encoding, which can be + // different than the source encoding. + return encoding; +} + +static Attribute inferSrcEncoding(triton::ReshapeOp op, Attribute encoding) { + // The encoding of x given the encoding of y in `reshape(x) -> y` is the same + // as the encoding of x given the encoding of y in `reshape(y) -> x`. It's an + // invariant of inferReshapeOpNoReorderEncoding that it's symmetric in this + // way. + return inferReshapeOpDstEncoding(op.getType().getShape(), encoding, + op.getSrc().getType().getShape(), + op.getAllowReorder()); +} + +static bool isSingleValue(Value value) { + // Don't consider load as expensive if it is loading a scalar. + if (auto tensorTy = dyn_cast(value.getType())) + return tensorTy.getNumElements() == 1; + // TODO: Handle other cases. + // For example, when ptr is a tensor of single value. + // It means that ptr is a resultant of broadcast or generated through + // a chain of broadcast and other operations. + // Rematerialize it without considering contiguous memory access pattern is + // fine. + return true; +} + +Attribute inferSrcEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + // Scan only supports blocked encoding at the moment. + if (!isa(encoding)) + return {}; + } + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa(op)) { + return encoding; + } + + if (auto reduceOp = dyn_cast(op)) + return inferSrcEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferSrcEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferSrcEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferSrcEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferSrcEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferSrcEncoding(reshape, encoding); + if (auto gather = dyn_cast(op)) + return inferSrcEncoding(gather, encoding); + if (auto fp4ToFp = dyn_cast(op)) + return inferSrcEncoding(fp4ToFp, encoding); + + return {}; +} + +Attribute inferDstEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + if (!isa(encoding)) + return {}; + } + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa(op)) + return encoding; + if (auto reduceOp = dyn_cast(op)) + return inferDstEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferDstEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferDstEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferDstEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferDstEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferDstEncoding(reshape, encoding); + if (auto gather = dyn_cast(op)) + return inferDstEncoding(gather, encoding); + if (auto fp4ToFp = dyn_cast(op)) + return inferDstEncoding(fp4ToFp, encoding); + + return {}; +} + +bool isExpensiveLoadOrStore(Operation *op) { + // Case 1: Pointer of tensor is always expensive + auto operandType = op->getOperand(0).getType(); + if (triton::isTensorPointerType(operandType)) + return true; + // Case 2a: A size 1 tensor is not expensive since all threads will load the + // same + if (isSingleValue(op->getOperand(0))) + return false; + // Case 2b: Tensor of pointers has more threads than elements + // we can presume a high hit-rate that makes it cheap to load + auto ptrType = cast(op->getOperand(0).getType()); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::lookupNumWarps(op); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + if (ptrType.getNumElements() < numWarps * threadsPerWarp) + return false; + return true; +} + +bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { + if (!op) + return true; + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return triton::gpu::isExpensiveCat(cast(op), targetEncoding); + if (isa(op)) + return true; + if (isa( + op)) + return true; + return false; +} + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { + if (isa(op)) + return !triton::gpu::isExpensiveCat(cast(op), + targetEncoding); + if (auto convert = dyn_cast(op)) { + if (mlir::isa(targetEncoding)) { + auto srcEncoding = convert.getSrc().getType().getEncoding(); + if (targetEncoding != srcEncoding) + return false; + } + return true; + } + + if (auto reshape = dyn_cast(op)) { + auto reshapeDstType = reshape.getType(); + RankedTensorType newDstType = + RankedTensorType::get(reshapeDstType.getShape(), + reshapeDstType.getElementType(), targetEncoding); + return reshape.getAllowReorder() && !reshape.getEfficientLayout() && + !triton::gpu::isExpensiveView(reshape.getSrc().getType(), + newDstType); + } + return isa(op); +} + +scf::ForOp replaceForOpWithNewSignature( + RewriterBase &rewriter, scf::ForOp loop, ValueRange newIterOperands, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + auto operands = llvm::to_vector<4>(loop.getInitArgs()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + scf::ForOp newLoop = rewriter.create( + loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), + operands); + newLoop->setAttrs(loop->getAttrs()); + newLoop.getBody()->erase(); + newLoop.getRegion().getBlocks().splice( + newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); + for (Value operand : newIterOperands) + newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); + + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + replacements.push_back(it); + return newLoop; +} + +scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, + ValueRange newIterOperands) { + SmallVector> replacements; + auto newForOp = replaceForOpWithNewSignature(rewriter, loop, newIterOperands, + replacements); + for (auto &kv : replacements) { + rewriter.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + return newForOp; +} + +scf::WhileOp replaceWhileOpWithNewSignature( + RewriterBase &rewriter, scf::WhileOp loop, ValueRange newIterOperands, + TypeRange newResultTypes, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + auto operands = llvm::to_vector<4>(loop.getInits()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + + // Result and operand types + SmallVector resultTypes; + SmallVector argsTypesBefore; + for (auto res : loop.getResults()) + resultTypes.push_back(res.getType()); + for (auto type : newResultTypes) + resultTypes.push_back(type); + for (Value operand : operands) + argsTypesBefore.push_back(operand.getType()); + scf::WhileOp newLoop = + rewriter.create(loop.getLoc(), resultTypes, operands); + newLoop->setAttrs(loop->getAttrs()); + + SmallVector bbArgLocsBefore(argsTypesBefore.size(), loop.getLoc()); + SmallVector bbArgLocsAfter(resultTypes.size(), loop.getLoc()); + rewriter.createBlock(&newLoop.getBefore(), {}, argsTypesBefore, + bbArgLocsBefore); + rewriter.createBlock(&newLoop.getAfter(), {}, resultTypes, bbArgLocsAfter); + + // Copy regions + for (int i = 0; i < loop.getNumRegions(); ++i) + newLoop->getRegion(i).front().getOperations().splice( + newLoop->getRegion(i).front().getOperations().begin(), + loop->getRegion(i).front().getOperations()); + + // Remap arguments + for (auto [oldArg, newArg] : llvm::zip( + loop.getBeforeArguments(), newLoop.getBeforeArguments().take_front( + loop.getBeforeArguments().size()))) + rewriter.replaceAllUsesWith(oldArg, newArg); + for (auto [oldArg, newArg] : llvm::zip(loop.getAfterArguments(), + newLoop.getAfterArguments().take_front( + loop.getAfterArguments().size()))) + rewriter.replaceAllUsesWith(oldArg, newArg); + + // Stack the new results + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + replacements.push_back(it); + + return newLoop; +} + +scf::WhileOp replaceWhileOpWithNewSignature(RewriterBase &rewriter, + scf::WhileOp loop, + ValueRange newIterOperands, + TypeRange newResultTypes) { + SmallVector> replacements; + auto newWhileOp = replaceWhileOpWithNewSignature( + rewriter, loop, newIterOperands, newResultTypes, replacements); + for (auto &kv : replacements) { + rewriter.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + return newWhileOp; +} + +scf::IfOp replaceIfOpWithNewSignature( + RewriterBase &rewriter, scf::IfOp ifOp, TypeRange newResultTypes, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(ifOp); + + // Create a new loop before the existing one, with the extra operands. + auto resultTypes = llvm::to_vector<4>(ifOp.getResults().getTypes()); + resultTypes.append(newResultTypes.begin(), newResultTypes.end()); + scf::IfOp newIf = rewriter.create(ifOp.getLoc(), resultTypes, + ifOp.getCondition()); + newIf->setAttrs(ifOp->getAttrs()); + + newIf.getThenRegion().takeBody(ifOp.getThenRegion()); + newIf.getElseRegion().takeBody(ifOp.getElseRegion()); + scf::IfOp::ensureTerminator(newIf.getThenRegion(), rewriter, ifOp.getLoc()); + scf::IfOp::ensureTerminator(newIf.getElseRegion(), rewriter, ifOp.getLoc()); + + for (auto it : llvm::zip(ifOp.getResults(), + newIf.getResults().take_front(ifOp.getNumResults()))) + replacements.push_back(it); + return newIf; +} + +void appendToForOpYield(scf::ForOp forOp, ArrayRef newOperands) { + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands()); + operands.append(newOperands.begin(), newOperands.end()); + + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + +scf::IfOp replaceIfOpWithNewSignature(RewriterBase &rewriter, scf::IfOp ifOp, + TypeRange newResultTypes) { + SmallVector> replacements; + auto newIfOp = + replaceIfOpWithNewSignature(rewriter, ifOp, newResultTypes, replacements); + for (auto &kv : replacements) + rewriter.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + return newIfOp; +} + +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, + IRMapping &mapping) { + Operation *newOp = rewriter.clone(*op, mapping); + // if input types haven't changed, we're done + bool preserveTypes = + std::all_of(op->operand_begin(), op->operand_end(), [&](Value v) { + return !mapping.contains(v) || + v.getType() == mapping.lookup(v).getType(); + }); + if (preserveTypes) + return newOp; + + if (newOp->getNumResults() == 0) + return newOp; + auto origType = dyn_cast(op->getResult(0).getType()); + auto argType = dyn_cast(newOp->getOperand(0).getType()); + if (!origType || !argType) + return newOp; + auto newType = RankedTensorType::get( + origType.getShape(), origType.getElementType(), argType.getEncoding()); + newOp->getResult(0).setType(newType); + auto typeInfer = dyn_cast(newOp); + if (typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newOp->getContext(), newOp->getLoc(), newOp->getOperands(), + newOp->getAttrDictionary(), newOp->getPropertiesStorage(), + newOp->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newOp->getResult(i).setType(newTypes[i]); + } + } + return newOp; +} + +// Check if the convert will be performed by reordering registers. +static bool isFreeConvert(Operation *op) { + auto convertOp = dyn_cast(op); + if (!convertOp) + return false; + return cvtReordersRegisters(convertOp.getSrc().getType(), + convertOp.getType()); +} + +LogicalResult getConvertBackwardSlice( + OpOperand &root, SetVector &slice, Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation, + std::function getExistingConversion) { + DenseSet> seen; + SmallVector> queue; + + auto enqueue = [&](OpOperand &operand, Attribute encoding) { + auto x = std::make_pair(&operand, encoding); + if (!seen.insert(x).second) { + return; // Already enqueued, skip + } + queue.push_back(x); + }; + enqueue(root, rootEncoding); + + auto updateLayout = [&](Value value, Attribute encoding) { + assert((isa(value.getType()))); + slice.insert(value); + Attribute &existing = layout[value]; + if (existing && existing != encoding) + return failure(); + existing = encoding; + return success(); + }; + + while (!queue.empty()) { + auto [currentValueUse, encoding] = queue.back(); + Value currentValue = currentValueUse->get(); + queue.pop_back(); + if (!isa(currentValue.getType())) + continue; + // Skip propagating through for op results for now. + // TODO: enable this based on needs. + if (currentValue.getDefiningOp()) + return failure(); + if (failed(updateLayout(currentValue, encoding))) + return failure(); + + Value existing; + if (getExistingConversion && + (existing = getExistingConversion(*currentValueUse, encoding))) { + if (failed(updateLayout(existing, encoding))) + return failure(); + currentValue = existing; + } + + if (auto ifOp = currentValue.getDefiningOp()) { + if (stopPropagation && stopPropagation(ifOp)) + continue; + unsigned argIdx = mlir::cast(currentValue).getResultNumber(); + + OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx); + OpOperand &elseValue = ifOp.elseYield()->getOpOperand(argIdx); + + enqueue(thenValue, encoding); + enqueue(elseValue, encoding); + + continue; + } + if (auto *definingOp = currentValue.getDefiningOp()) { + // If the op has multiple results we need to update all results layout. + for (Value result : definingOp->getResults()) { + if (result == currentValue || !isa(result.getType())) + continue; + if (failed(updateLayout(result, encoding))) + return failure(); + } + if (isFreeConvert(definingOp)) { + enqueue(definingOp->getOpOperand(0), encoding); + continue; + } + if (canFoldIntoConversion(definingOp, encoding)) + continue; + if (stopPropagation && stopPropagation(definingOp)) + continue; + if (isa(definingOp)) + return failure(); + if (auto gather = dyn_cast(definingOp)) { + // Specially handle gather since its transfer function only applies + // between its index operand and result. + auto srcEncoding = inferSrcEncoding(gather, encoding); + if (!srcEncoding) + return failure(); + enqueue(gather.getIndicesMutable(), srcEncoding); + continue; + } + for (auto [i, operand] : llvm::enumerate(definingOp->getOpOperands())) { + auto srcEncoding = inferSrcEncoding(definingOp, encoding); + if (!srcEncoding) + return failure(); + enqueue(operand, srcEncoding); + } + continue; + } + auto blockArg = cast(currentValue); + Block *block = blockArg.getOwner(); + Operation *parentOp = block->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); + OpOperand &yieldOperand = forOp.getBody()->getTerminator()->getOpOperand( + blockArg.getArgNumber() - forOp.getNumInductionVars()); + enqueue(*initOperand, encoding); + enqueue(yieldOperand, encoding); + continue; + } + // TODO: add support for WhileOp and other region types. + return failure(); + } + return success(); +} + +// TODO(thomas): this is duplicated with what is in GPUToLLVM +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = triton::applyPermutation(shape, order); + auto reorderedMultiDim = delinearize(b, loc, linear, reordered); + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + if (rank == 1) { + multiDim[0] = linear; + } else { + Value remained = linear; + for (auto &&en : llvm::enumerate(shape.drop_back())) { + auto dimSize = b.create(loc, en.value(), 32); + multiDim[en.index()] = b.create(loc, remained, dimSize); + remained = b.create(loc, remained, dimSize); + } + multiDim[rank - 1] = remained; + } + return multiDim; +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { + return linearize(b, loc, triton::applyPermutation(multiDim, order), + triton::applyPermutation(shape, order)); +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape) { + auto rank = multiDim.size(); + Value linear = b.create(loc, 0, 32); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = b.create(loc, dimShape, 32); + linear = b.create( + loc, b.create(loc, linear, dimSize), dim); + } + } + return linear; +} + +bool isPureUnaryInlineAsm(Operation *op) { + auto inlineAsmOp = dyn_cast(op); + if (!inlineAsmOp) + return false; + return op->getNumOperands() == 1 && op->getNumResults() == 1 && + inlineAsmOp.getPure(); +} + +int getNVIDIAComputeCapability(Operation *module) { + StringAttr targetAttr = + module->getAttrOfType(triton::gpu::AttrTargetName); + assert(targetAttr && "Expected a target attribute on the module operation"); + + StringRef ref = targetAttr.strref(); + assert(ref.starts_with("cuda:") && + "expected target attribute to be prefixed with \"cuda:\""); + + StringRef capabilityStr = ref.drop_front(5); // drop the "cuda:" + int computeCapability; + bool parseError = capabilityStr.getAsInteger(10, computeCapability); + assert(!parseError && + "invalid compute capability string in target attribute"); + + return computeCapability; +} + +StringRef getAMDArch(Operation *module) { + StringAttr targetAttr = + module->getAttrOfType(triton::gpu::AttrTargetName); + assert(targetAttr && "Expected a target attribute on the module operation"); + + StringRef ref = targetAttr.strref(); + assert(ref.starts_with("hip:") && + "expected target attribute to be prefixed with \"hip:\""); + + return ref.drop_front(4); // drop the "hip:" +} + +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return the shared encoding that needs to be +// used to be compatible with users' layouts. If there are incompatible shared +// encodings, set incompatible to true. +std::optional +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { + ttg::SwizzledSharedEncodingAttr attr; + incompatible = false; + for (Operation *user : val.getUsers()) { + ttg::SwizzledSharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = + dyn_cast(memDesc.getEncoding()); + if (!tempAttr) + return std::nullopt; + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) + .has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + auto dotOpEnc = dyn_cast( + cast(user->getResult(0).getType()) + .getEncoding()); + if (!dotOpEnc) + return std::nullopt; + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout, + bitWidth, /*needTrans=*/false); + } + // Check that the shared encodings needed by the users are compatible. + if (attr != nullptr && attr != tempAttr) { + incompatible = true; + return std::nullopt; + } + attr = tempAttr; + } + return attr; +} + +namespace { + +/// Detect dead arguments in scf.for op by assuming all the values are dead and +/// propagate liveness property. +class ForOpDeadArgElimination : public OpRewritePattern { + DenseSet opsCanBeTriviallyDead; + +public: + using OpRewritePattern::OpRewritePattern; + + explicit ForOpDeadArgElimination( + MLIRContext *context, const DenseSet &opsCanBeTriviallyDead) + : OpRewritePattern(context), + opsCanBeTriviallyDead(opsCanBeTriviallyDead) {} + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const final { + Block &block = *forOp.getBody(); + auto yieldOp = cast(block.getTerminator()); + // Assume that nothing is live at the beginning and mark values as live + // based on uses. + DenseSet aliveValues; + SmallVector queue; + // Helper to mark values as live and add them to the queue of value to + // propagate if it is the first time we detect the value as live. + auto markLive = [&](Value val) { + if (!forOp->isAncestor(val.getParentRegion()->getParentOp())) + return; + if (aliveValues.insert(val).second) + queue.push_back(val); + }; + // Mark all yield operands as live if the associated forOp result has any + // use. + for (auto result : llvm::enumerate(forOp.getResults())) { + if (!result.value().use_empty()) + markLive(yieldOp.getOperand(result.index())); + } + if (aliveValues.size() == forOp.getNumResults()) + return failure(); + // Operations with side-effects are always live. Mark all theirs operands as + // live. + block.walk([&](Operation *op) { + if (!isa(op) && + !(wouldOpBeTriviallyDead(op) || opsCanBeTriviallyDead.contains(op))) { + for (Value operand : op->getOperands()) + markLive(operand); + } + }); + // Propagate live property until reaching a fixed point. + while (!queue.empty()) { + Value value = queue.pop_back_val(); + if (auto nestedFor = value.getDefiningOp()) { + auto result = mlir::cast(value); + OpOperand &forOperand = *nestedFor.getTiedLoopInit(result); + markLive(forOperand.get()); + auto nestedYieldOp = + cast(nestedFor.getBody()->getTerminator()); + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + continue; + } + if (auto nestedIf = value.getDefiningOp()) { + auto result = mlir::cast(value); + // mark condition as live. + markLive(nestedIf.getCondition()); + for (scf::YieldOp nestedYieldOp : + {nestedIf.thenYield(), nestedIf.elseYield()}) { + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + } + continue; + } + if (Operation *def = value.getDefiningOp()) { + // TODO: support while ops. + if (isa(def)) + return failure(); + for (Value operand : def->getOperands()) + markLive(operand); + continue; + } + // If an argument block is live then the associated yield operand and + // forOp operand are live. + auto arg = mlir::cast(value); + if (auto forOwner = dyn_cast(arg.getOwner()->getParentOp())) { + if (arg.getArgNumber() < forOwner.getNumInductionVars()) + continue; + unsigned iterIdx = arg.getArgNumber() - forOwner.getNumInductionVars(); + Value yieldOperand = + forOwner.getBody()->getTerminator()->getOperand(iterIdx); + markLive(yieldOperand); + markLive(forOwner.getInitArgs()[iterIdx]); + } + } + SmallVector deadArg; + for (auto yieldOperand : llvm::enumerate(yieldOp->getOperands())) { + if (aliveValues.contains(yieldOperand.value())) + continue; + if (yieldOperand.value() == block.getArgument(yieldOperand.index() + 1)) + continue; + + // The yield operand might live outside the loop, e.g. + // %init = ... + // %x = ... + // %y = for iter_args(%unused = %init) { + // yield %x + // } + // + // In this case, the loop returns %x if it runs 1 or more times, and + // otherwise it returns %init. We cowardly refuse to remove this operand + // from the yield. (We could, but we'd need to prove that the loop runs 0 + // or >=1 times.) + // + // As a special case, if it doesn't matter whether the loop runs 0 or >=1 + // times (because the loop returns the same value in both cases) then we + // can still mark the operand as dead. This occurs in the above example + // when %init is the same as %x. + if (!forOp->isAncestor( + yieldOperand.value().getParentRegion()->getParentOp()) && + yieldOperand.value() != forOp.getInitArgs()[yieldOperand.index()]) + continue; + + deadArg.push_back(yieldOperand.index()); + } + if (deadArg.empty()) + return failure(); + rewriter.modifyOpInPlace(forOp, [&]() { + // For simplicity we just change the dead yield operand to use the + // associated argument and leave the operations and argument removal to + // dead code elimination. + for (unsigned deadArgIdx : deadArg) { + BlockArgument arg = block.getArgument(deadArgIdx + 1); + yieldOp.setOperand(deadArgIdx, arg); + } + }); + return success(); + } +}; + +} // namespace + +void populateForOpDeadArgumentElimination( + RewritePatternSet &patterns, DenseSet &opsCanBeTriviallyDead) { + patterns.add(patterns.getContext(), + opsCanBeTriviallyDead); +} + +ttg::LocalAllocOp findShmemAlloc(Value operand) { + // If it's a shmem operand, it must either be defined outside the loop, or + // come from an MemDescSubview op. Only ConvertLayout and Trans ops are + // allowed in between. + Value transitiveOperand = operand; + while ( + isa_and_nonnull( + transitiveOperand.getDefiningOp()) || + isa(transitiveOperand)) { + if (auto blockArg = dyn_cast(transitiveOperand)) { + assert(isa(blockArg.getOwner()->getParentOp()) && + "Block argument must come from a for loop"); + transitiveOperand = + cast(blockArg.getOwner()->getTerminator()) + .getOperand(blockArg.getArgNumber() - 1); + } else { + transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0); + } + } + if (auto subView = + dyn_cast(transitiveOperand.getDefiningOp())) { + // Multi-buffered operand + return dyn_cast(subView.getSrc().getDefiningOp()); + } else { + // Single bufferred operand that does not require a subview (not loaded in + // the loop) + return dyn_cast(transitiveOperand.getDefiningOp()); + } + return nullptr; +} + +SmallVector +getMMAsWithMultiBufferredOperands(scf::ForOp forOp, + SmallVector &mmaOps) { + // The A and B operands of the mmaOp should be multi-buffered + SmallVector eligible; + for (auto mmaOp : mmaOps) { + auto a = findShmemAlloc(mmaOp->getOperand(0)); + auto b = findShmemAlloc(mmaOp->getOperand(1)); + if (a && forOp.isDefinedOutsideOfLoop(a) && b && + forOp.isDefinedOutsideOfLoop(b)) { + eligible.push_back(mmaOp); + } + } + + return eligible; +} + +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSCanonicalization.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSCanonicalization.cpp new file mode 100644 index 000000000..7886e357a --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSCanonicalization.cpp @@ -0,0 +1,116 @@ +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +#include + +#include "mlir/IR/OperationSupport.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define DEBUG_TYPE "tritongpu-warp-spec-canonicalization" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#define GEN_PASS_DEF_TRITONGPUWSCANONICALIZATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUWSCanonicalization + : public impl::TritonGPUWSCanonicalizationBase< + TritonGPUWSCanonicalization> { +public: + using impl::TritonGPUWSCanonicalizationBase< + TritonGPUWSCanonicalization>::TritonGPUWSCanonicalizationBase; + + void runOnOperation() override { + if (numConsumerGroups == 0) + return; + + // Find top-level ifOp that specializes warps. Such ifOp has the following + // form: + // %51 = ttng.get_canonical_warp_id + // %52 = arith.divui %51, %c4_i32 + // %53 = arith.cmpi eq, %52, %c0_i32 + // scf.if %53 { + // ... + // } + DenseMap ifOpToTaskId; + getOperation()->walk([&](scf::IfOp ifOp) { + // Skip ifOp that has more than one region + if (ifOp.elseBlock()) + return; + + // Get the condition of scf.if + Value cond = ifOp.getCondition(); + auto cmpOp = cond.getDefiningOp(); + if (!cmpOp || cmpOp.getPredicate() != arith::CmpIPredicate::eq) + return; + + // Ensure the LHS of comparison is from an arith.divui + Value divResult = cmpOp.getLhs(); + auto divOp = divResult.getDefiningOp(); + if (!divOp) + return; + + // Ensure the RHS of comparison is a constant + auto warpGroupId = cmpOp.getRhs().getDefiningOp(); + if (!warpGroupId) + return; + + // Ensure the divisor is 4 + auto divisorCst = divOp.getRhs().getDefiningOp(); + if (!divisorCst || divisorCst.value() != 4) + return; + + // Ensure the dividend is from ttng.get_canonical_warp_id + Value warpId = divOp.getLhs(); + auto warpOp = warpId.getDefiningOp(); + if (!warpOp) + return; + + // Al conditions matc + LLVM_DEBUG({ + LDBG("Warp specialization region:"); + ifOp.dump(); + }); + + auto asyncTaskIds = getAsyncTaskIds(ifOp); + assert(asyncTaskIds.size() == 1 && "Expecting one async task id"); + auto taskId = asyncTaskIds[0]; + assert(taskId == warpGroupId.value() && + "Expecting task id to match warp group id"); + ifOpToTaskId[ifOp] = taskId; + }); + + // Fix up the async task ids for each op in the specialized region + for (const auto &item : ifOpToTaskId) { + auto ifOp = item.first; + auto taskId = item.second; + SmallVector regionTaskIds = {taskId}; + ifOp->walk([&](Operation *op) { + // Fix up the async task ids + if (getAsyncTaskIds(op) != regionTaskIds) { + LLVM_DEBUG({ + LDBG("Fixing up async task ids to " << taskId << " for "); + op->dump(); + }); + setAsyncTaskIds(op, regionTaskIds); + } + }); + } + } +}; +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp new file mode 100644 index 000000000..6d7d818a4 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp @@ -0,0 +1,2444 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include +#include + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUWSCODEPARTITION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "tritongpu-warp-spec-code-partition" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +std::pair scanRegUsage(Block *block, AsyncTaskId asyncTaskId, + int regDecProducer, int regIncConsumer) { + // TODO: scan ops to estimate register usage + if (asyncTaskId == 0) { + // deallocate registers + return {regDecProducer == 0 ? 40 : regDecProducer, false}; + } else { + // allocate registers + return {regIncConsumer == 0 ? 232 : regIncConsumer, true}; + } +} + +unsigned getNumBuffersOrDefault(scf::ForOp forOp, unsigned numBuffers) { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) + return numBuffers; + return mlir::cast( + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); +} + +// Collect argument indices that are used by the specific taskId. +static SmallVector collectBlockArgsForTask(scf::ForOp forOp, + int asyncTaskId) { + + // Collect argument indices that can be reached along the definition chain. + SetVector argIndices; + std::function dfs = [&](Value arg, unsigned argIdx) { + for (auto user : arg.getUsers()) { + // Skip ops that are not in the same async task + if (!hasAsyncTaskId(user, asyncTaskId)) + continue; + + if (isa(user)) { + if (auto ifOp = dyn_cast(user->getParentOp())) { + // For block arguments, we need to check the initial value as well. + if (auto blockArg = dyn_cast(arg)) { + auto initArg = forOp.getInitArgs()[blockArg.getArgNumber() - 1]; + if (Operation *def = initArg.getDefiningOp()) { + if (hasAsyncTaskId(def, asyncTaskId)) { + argIndices.insert(argIdx); + } + } else { + llvm_unreachable("Initial value should have a defining op"); + } + } + } + + // Skip control flow ops that are shared by all async tasks + continue; + } + + // Found a real user, the arg is needed + if (user->getNumRegions() == 0) { + argIndices.insert(argIdx); + return; + } + + // Iterate through all regions of the user operation + for (auto ®ion : user->getRegions()) { + for (auto regionArg : region.getArguments()) { + if (arg == regionArg) + dfs(regionArg, argIdx); + } + } + } + }; + + // check dependency with DFS traversal for loop args and results. + mlir::Block &block = forOp.getRegion().front(); + for (unsigned i = forOp.getNumInductionVars(); i < block.getNumArguments(); + ++i) { + auto arg = block.getArgument(i); + dfs(arg, i - forOp.getNumInductionVars()); + } + for (unsigned i = 0; i < forOp.getNumResults(); ++i) { + auto result = forOp->getResult(i); + dfs(result, i); + } + + SmallVector args(argIndices.begin(), argIndices.end()); + llvm::sort(args); + return args; +} + +Operation *SpecializeOp(Operation *op, IRMapping &mapping, + OpBuilderWithAsyncTaskIds &builder, + AsyncTaskId asyncTaskId); + +// Return the argument that tracks accumLoopCount if there is an outer +// ForOp. +Value getAccumLoopCountArg(scf::ForOp parentForOp) { + assert(parentForOp); + auto tSize = parentForOp.getBody()->getArguments().size(); + assert(tSize >= 3); // accum, bufferIdx, phase + Value tmpAccumLoopCount = parentForOp.getBody()->getArgument(tSize - 3); + return tmpAccumLoopCount; +} + +// Check to see if op is enclosed under ifOp. +static bool enclosing(scf::IfOp ifOp, Operation *op) { + auto pOp = op->getParentOfType(); + while (pOp) { + if (pOp == ifOp) + return true; + pOp = pOp->getParentOfType(); + } + return false; +} + +// Check to see if there is no outer loop that is enclosed under ifOp. +static bool immediateEnclosing(scf::IfOp ifOp, Operation *subOp) { + auto pOp = subOp->getParentOfType(); + if (!pOp) + return true; + return !enclosing(ifOp, pOp.getOperation()); +} + +// Return true if the IfOp contains a ForOp that is in opsWithBufferReuse. +// We want to support reuse between channels in a loop and channels in a IfOp. +static bool +needAccumulatedLoopCnt(scf::IfOp ifOp, + SmallVector &opsWithBufferReuse) { + bool needAccum = false; + ifOp.walk([&](Operation *subOp) { + for (auto tOp : opsWithBufferReuse) { + if (auto forOp = dyn_cast(subOp)) { + // For the case of ifOp contains forOp, which contains subOp, no need to + // generate accumLoopCount for ifOp. + if (subOp == tOp && immediateEnclosing(ifOp, tOp)) { + needAccum = true; + break; + } + } else { + if (subOp == tOp) { + needAccum = true; + break; + } + } + } + }); + return needAccum; +} + +Value updateAccumLoopCount(SmallVector &opList, + unsigned numBuffers, + SmallVector &taskTopOps, + Operation *commonOuterLoop, + SmallVector &opsWithBufferReuse, + Value prevAccum); + +scf::ForOp createNewLoopWrapper(scf::ForOp origForOp, unsigned numBuffers, + SmallVector &taskTopOps, + Operation *commonOuterLoop, + SmallVector &opsWithBufferReuse, + Value prevAccum); + +// For certain cases, we need to add an additional output for +// IfOp to track the accumulatedLoopCount, we may need to add +// a corresponding elseBlock with yieldOp. +scf::IfOp rewriteIfOp(scf::IfOp ifOp, unsigned numBuffers, + SmallVector &taskTopOps, + Operation *commonOuterLoop, + SmallVector &opsWithBufferReuse, + Value prevAccum) { + LLVM_DEBUG({ + LDBG("rewrite ifOp for smem sharing "); + ifOp.dump(); + }); + + OpBuilderWithAsyncTaskIds ifBuilder(ifOp.getContext()); + ifBuilder.setAsynTaskIdsFromArray(getNestedAsyncTaskIds(ifOp)); + ifBuilder.setInsertionPoint(ifOp); + + SmallVector newResultTypes(ifOp->getResultTypes()); + // Add an output for the IfOp for accumulated loop count. + newResultTypes.push_back(ifBuilder.getI64Type()); + // Create else block if we need to generate accumulated loop count. + auto newIfOp = ifBuilder.createWithAsyncTaskIds( + ifOp.getLoc(), newResultTypes, ifOp.getCondition(), true, true); + + // Move the existing blocks to the new if. + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + + ifBuilder.setInsertionPointToEnd(newIfOp.thenBlock()); + SmallVector opList; + for (Operation &op : newIfOp.thenBlock()->getOperations()) { + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + } + + // Update yields + auto loc = ifOp.getLoc(); + auto updateYield = [&](scf::YieldOp yield, SmallVector &operands) { + ifBuilder.setInsertionPoint(yield); + ifBuilder.createWithAsyncTaskIds(loc, operands); + yield.erase(); + }; + + // Add one more operand to then Yield. + Value endAccum = + updateAccumLoopCount(opList, numBuffers, taskTopOps, commonOuterLoop, + opsWithBufferReuse, prevAccum); + + SmallVector ifYieldOperands = newIfOp.thenYield().getOperands(); + ifYieldOperands.push_back(endAccum); + updateYield(newIfOp.thenYield(), ifYieldOperands); + + // Handle elseRegion of the IfOp. + if (ifOp.elseBlock()) { + ifBuilder.setInsertionPointToEnd(newIfOp.elseBlock()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + opList.clear(); + for (Operation &op : newIfOp.elseBlock()->getOperations()) { + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + } + endAccum = + updateAccumLoopCount(opList, numBuffers, taskTopOps, commonOuterLoop, + opsWithBufferReuse, prevAccum); + } else { + // Create an empty yield + auto yieldOp = + newIfOp.getElseBodyBuilder().create(ifOp.getLoc()); + endAccum = prevAccum; + } + // Add one more operand to else Yield. + SmallVector elseYieldOperands = newIfOp.elseYield().getOperands(); + elseYieldOperands.push_back(endAccum); + updateYield(newIfOp.elseYield(), elseYieldOperands); + int resultIdx = 0; + // Replace old if with the new one. + for (auto result : ifOp.getResults()) { + result.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + } + + // If ifOp is in opsWithBufferReuse, replace. + auto tmpIter = std::find(opsWithBufferReuse.begin(), opsWithBufferReuse.end(), + ifOp.getOperation()); + if (tmpIter != opsWithBufferReuse.end()) { + *tmpIter = newIfOp.getOperation(); + } + + ifOp.erase(); + return newIfOp; +} + +Operation *SpecializeIfOp(scf::IfOp ifOp, IRMapping &mapping, + OpBuilderWithAsyncTaskIds &builder, + AsyncTaskId asyncTaskId) { + LLVM_DEBUG({ + LDBG("specialize ifOp "); + ifOp.dump(); + }); + + // It is possible that we need to reduce the results. One example + // is that the defining op for the yield operation is not for this + // taskId and the defining op is not specialized, thus we should + // remove the result. + // We need to update the result types correctly here. + unsigned resultIdx = 0; + SmallVector keptResultVec; + if (!ifOp->getResultTypes().empty()) { + for (Value yieldV : ifOp.thenYield().getOperands()) { + // Check the defining op for the corresponding result. + if (Operation *def = yieldV.getDefiningOp()) { + bool hasTaskId = hasAsyncTaskId(def, asyncTaskId); + if (hasTaskId) { + keptResultVec.push_back(resultIdx); + } + } else { + assert(isa(yieldV) && "Unexpected yield value"); + auto bbArg = cast(yieldV); + // Find transitive defining op for the block arg + Operation *bbAargOwner = bbArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(bbAargOwner)) { + // track initial value + auto initArg = forOp.getInitArgs()[bbArg.getArgNumber() - 1]; + if (Operation *def = initArg.getDefiningOp()) { + if (hasAsyncTaskId(def, asyncTaskId)) + keptResultVec.push_back(resultIdx); + } else { + llvm_unreachable("Initial value should have a defining op"); + } + } else { + llvm_unreachable("Unexpected block argument owner"); + } + } + ++resultIdx; + } + } + + SmallVector newResultTypes; + for (auto idx : keptResultVec) { + newResultTypes.push_back(ifOp->getResultTypes()[idx]); + } + auto newIfOp = builder.createWithAsyncTaskIds( + ifOp.getLoc(), newResultTypes, mapping.lookup(ifOp.getCondition()), true, + ifOp.elseBlock()); + + OpBuilderWithAsyncTaskIds ifBuilder(ifOp.getContext()); + ifBuilder.setAsynTaskIdsFromArray({asyncTaskId}); + + // Handle thenRegion of this IfOp. + ifBuilder.setInsertionPointToEnd(newIfOp.thenBlock()); + for (Operation &thenOp : ifOp.thenBlock()->getOperations()) { + SpecializeOp(&thenOp, mapping, ifBuilder, asyncTaskId); + } + + // Update yields + auto updateYield = [&](scf::YieldOp yield, SmallVector &operands) { + ifBuilder.setInsertionPoint(yield); + ifBuilder.createWithAsyncTaskIds(yield.getLoc(), operands); + yield.erase(); + }; + if (keptResultVec.size() < ifOp->getResultTypes().size()) { + SmallVector ifYieldOperands; + for (auto idx : keptResultVec) { + ifYieldOperands.push_back(newIfOp.thenYield().getOperand(idx)); + } + updateYield(newIfOp.thenYield(), ifYieldOperands); + } + + // Handle elseRegion of the IfOp. + if (ifOp.elseBlock()) { + ifBuilder.setInsertionPointToEnd(newIfOp.elseBlock()); + for (Operation &elseOp : ifOp.elseBlock()->getOperations()) { + SpecializeOp(&elseOp, mapping, ifBuilder, asyncTaskId); + } + if (keptResultVec.size() < ifOp->getResultTypes().size()) { + SmallVector elseYieldOperands; + for (auto idx : keptResultVec) { + elseYieldOperands.push_back(newIfOp.elseYield().getOperand(idx)); + } + updateYield(newIfOp.elseYield(), elseYieldOperands); + } + } + + unsigned newResIdx = 0; + for (auto idx : keptResultVec) { + mapping.map(ifOp.getResult(idx), newIfOp.getResult(newResIdx)); + ++newResIdx; + } + return newIfOp; +} + +Operation *SpecializeForOp(scf::ForOp forOp, IRMapping &mapping, + OpBuilderWithAsyncTaskIds &builder, + AsyncTaskId asyncTaskId) { + // Create newForOp for each task Id. + auto usedArgs = collectBlockArgsForTask(forOp, asyncTaskId); + + // Prepare newLoopArgs. + SmallVector newLoopArgs; + for (unsigned argNumber : usedArgs) { + auto arg = forOp.getInitArgs()[argNumber]; + auto newArg = mapping.lookupOrDefault(arg); + assert(newArg && "Unexpected missing mapping"); + newLoopArgs.push_back(newArg); + } + + // Prepare loop bounds. + auto newLowerBound = mapping.lookupOrDefault(forOp.getLowerBound()); + auto newUpperBound = mapping.lookupOrDefault(forOp.getUpperBound()); + auto newStep = mapping.lookupOrDefault(forOp.getStep()); + + // Create newForOp. + auto newForOp = builder.createWithAsyncTaskIds( + forOp.getLoc(), newLowerBound, newUpperBound, newStep, newLoopArgs); + if (forOp->getAttr("tt.loop_schedule")) + newForOp->setAttr("tt.loop_schedule", forOp->getAttr("tt.loop_schedule")); + + // Initialize Value mapping from forOp to newForOp + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (unsigned i = 0; i < usedArgs.size(); ++i) { + auto oldArg = forOp.getRegionIterArgs()[usedArgs[i]]; + auto newArg = newForOp.getRegionIterArgs()[i]; + mapping.map(oldArg, newArg); + } + + // Recursively clone all operations with this asyncTaskId to newForOp. + OpBuilderWithAsyncTaskIds forBuilder(forOp.getContext()); + forBuilder.setAsynTaskIdsFromArray({asyncTaskId}); + forBuilder.setInsertionPointToStart(newForOp.getBody()); + for (Operation &op : forOp.getBody()->without_terminator()) { + SpecializeOp(&op, mapping, forBuilder, asyncTaskId); + } + + // Create YieldOp for newForOp. + auto yieldOp = llvm::cast(forOp.getBody()->getTerminator()); + SmallVector newYieldOperands; + for (unsigned i : usedArgs) + newYieldOperands.push_back(mapping.lookup(yieldOp.getOperand(i))); + + bool createNewYield = true; + if (newForOp.getBody()->mightHaveTerminator()) { + auto initialYield = + llvm::cast(newForOp.getBody()->getTerminator()); + if (newYieldOperands.size() == 0) { + setAsyncTaskIds(initialYield, {asyncTaskId}); + createNewYield = false; + } + } + if (createNewYield) { + auto newYieldOp = + forBuilder.create(yieldOp.getLoc(), newYieldOperands); + setAsyncTaskIds(newYieldOp, {asyncTaskId}); + } + + // Replace results of forOp with results of newForOp. + for (unsigned i = 0; i < usedArgs.size(); ++i) { + auto oldResult = forOp.getResult(usedArgs[i]); + auto newResult = newForOp.getResult(i); + mapping.map(oldResult, newResult); + } + + return newForOp; +} + +Operation *SpecializeOp(Operation *op, IRMapping &mapping, + OpBuilderWithAsyncTaskIds &builder, + AsyncTaskId asyncTaskId) { + auto taskIds = getAsyncTaskIds(op); + // yieldOp are sometimes implict, meaning they do not necessarily have a task + // id, but they should be shared by all async tasks. + if (!hasAsyncTaskId(op, asyncTaskId) && !isa(op)) + return nullptr; + + if (op->getNumRegions() == 0) { + Operation *newOp = builder.clone(*op, mapping); + setAsyncTaskIds(newOp, asyncTaskId); + for (unsigned i = 0; i < op->getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + return newOp; + } else { + if (auto ifOp = dyn_cast(op)) { + return SpecializeIfOp(ifOp, mapping, builder, asyncTaskId); + } else if (auto forOp = dyn_cast(op)) { + return SpecializeForOp(forOp, mapping, builder, asyncTaskId); + } else if (auto reduceOp = dyn_cast(op)) { + Operation *newOp = builder.clone(*op, mapping); + // recursively set async task ids for child ops + newOp->walk( + [&](Operation *childOp) { setAsyncTaskIds(childOp, asyncTaskId); }); + for (unsigned i = 0; i < op->getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + return newOp; + } else { + llvm_unreachable("Unexpected Op with regions"); + } + } + + return nullptr; +} + +// Create IfOp for each ayncTaskId. +DenseMap SpecializeRegion(triton::FuncOp funcOp, + int regDecProducer, + int regIncConsumer) { + + LLVM_DEBUG({ + LDBG("\n\n"); + LDBG("Start specializing region"); + }); + + MLIRContext *context = funcOp.getContext(); + OpBuilder builder(context); + auto loc = funcOp.getLoc(); + + // Collect original operations + SmallVector opList; + for (auto &block : funcOp.getBody().getBlocks()) { + for (Operation &op : block.getOperations()) { + auto taskIds = getAsyncTaskIds(&op); + if (!taskIds.empty()) + opList.push_back(&op); + } + } + + LLVM_DEBUG({ + LDBG("ops to be specialized: "); + for (Operation *op : opList) { + op->dump(); + } + }); + + // Create GetAsyncTaskIdOp. + Block *lastBlock = &funcOp.getBody().back(); + auto returnOp = llvm::cast(lastBlock->getTerminator()); + builder.setInsertionPoint(returnOp); + Value curAsyncTaskId = builder.create(loc); + + DenseMap tasksToIfOp; + + // Clone all operations into the corresponding if blocks. If the operation + // has multiple taskIds, it will be cloned for multiple if blocks. + // If the original code has an IfOp, we should only clone its + // body with the right asyncTaskId, instead of cloning the IfOp. + for (AsyncTaskId asyncTaskId : getNestedAsyncTaskIds(funcOp)) { + // Create IfOp for each asyncTaskId. + Value cond = builder.create( + loc, arith::CmpIPredicate::eq, curAsyncTaskId, + builder.create(loc, asyncTaskId, 32)); + + auto ifOp = builder.create(loc, cond); + tasksToIfOp[asyncTaskId] = ifOp; + setAsyncTaskIds(ifOp, {asyncTaskId}); + + OpBuilderWithAsyncTaskIds taskBuilder(context); + taskBuilder.setAsynTaskIdsFromArray({asyncTaskId}); + + // Set insertion point before yieldOp. + auto yieldOp = ifOp.thenYield(); + setAsyncTaskIds(yieldOp, {asyncTaskId}); + taskBuilder.setInsertionPoint(yieldOp); + + IRMapping mapping; + for (Operation *op : opList) { + SpecializeOp(op, mapping, taskBuilder, asyncTaskId); + } + } + + // Decide if this taskId is a producer or a consumer, and create either + // RegAllocOp or RegDeallocOp accordingly. + for (auto ifOps : tasksToIfOp) { + AsyncTaskId asyncTaskId = ifOps.first; + auto ifOp = ifOps.second; + OpBuilderWithAsyncTaskIds taskBuilder(ifOp.getContext()); + taskBuilder.setAsynTaskIdsFromArray({asyncTaskId}); + auto regAlloc = scanRegUsage(ifOp.thenBlock(), asyncTaskId, regDecProducer, + regIncConsumer); + taskBuilder.setInsertionPointToStart(&(ifOp.getThenRegion().front())); + if (regAlloc.second) + taskBuilder.create( + loc, taskBuilder.getI32IntegerAttr(regAlloc.first)); + else + taskBuilder.create( + loc, taskBuilder.getI32IntegerAttr(regAlloc.first)); + } + + LLVM_DEBUG({ + LDBG("\n\nWith task Id checks"); + funcOp.dump(); + }); + + // Remove original operations that have been cloned in reverse order. + for (auto it = opList.rbegin(); it != opList.rend(); ++it) { + Operation *op = *it; + LLVM_DEBUG({ + LDBG("erasing op "); + op->dump(); + }); + // For debugging purposes, check to see if the original op is still in use. + bool hasUse = false; + for (unsigned i = 0; i < op->getNumResults(); ++i) { + for (Operation *user : op->getResult(i).getUsers()) { + hasUse = true; + LLVM_DEBUG({ + LDBG("op has use "); + user->dump(); + }); + } + } + op->erase(); + } + return tasksToIfOp; +} + +struct Channel { +public: + using Relation = std::pair>; + + Channel(int producer, SmallVector &consumers, Operation *op, + unsigned operandIdx, unsigned numBuffers) + : relation(producer, consumers), op(op), operandIdx(operandIdx), + numBuffers(numBuffers) {} + + bool operator==(const Channel &c) { + return relation == c.relation && operandIdx == c.operandIdx && op == c.op; + } + + Operation *getDstOp() { return op; } + unsigned getDstOperandIdx() { return operandIdx; } + Value getSrcOperand() { return op->getOperand(operandIdx); } + Operation *getSrcOp() { return getSrcOperand().getDefiningOp(); } + + Relation relation; // producer task Id, a list of consumer task Ids + Operation *op; + unsigned operandIdx; + unsigned numBuffers; +}; + +// Find transitive users of the root op. Track through control flow ops (such as +// yield) to get to the real users. +void getTransitiveUsers(Value root, + SetVector> &users) { + for (Operation *userOp : root.getUsers()) { + if (auto yieldOp = dyn_cast(userOp)) { + for (OpOperand &operand : yieldOp->getOpOperands()) { + if (operand.get() == root) { + auto result = + yieldOp->getParentOp()->getResult(operand.getOperandNumber()); + getTransitiveUsers(result, users); + } + } + } else { + // find operand index of root + unsigned operandIndex = 0; + for (OpOperand &operand : userOp->getOpOperands()) { + if (operand.get() == root) { + break; + } + operandIndex++; + } + assert(operandIndex < userOp->getNumOperands() && + "root is not an operand of userOp"); + users.insert({userOp, operandIndex}); + } + } +} + +// Loads will be in producer warp groups. For now, we only allow a single +// warp group/task for a producer. For each LoadOp, create a channel from it +// to any direct user which belongs to a different taskId. +void collectAsyncChannels(SmallVector> &channels, + triton::FuncOp &funcOp, unsigned numBuffers) { + funcOp.walk([&](Operation *op) { + if (isa(op) || + isa(op)) { + auto producerTaskIds = getAsyncTaskIds(op); + if (producerTaskIds.empty() || producerTaskIds.size() > 1) { + LLVM_DEBUG({ + LDBG(" ignoring load ops without async task id or with multiple task " + "ids: "); + op->dump(); + }); + return; + } + auto producerTaskId = producerTaskIds.front(); + unsigned producerNumBuffers = numBuffers; + if (auto forOp = op->getParentOfType()) { + producerNumBuffers = getNumBuffersOrDefault(forOp, numBuffers); + } + + for (auto result : op->getResults()) { + if (result.use_empty()) { + continue; + } + + SetVector> users; + getTransitiveUsers(result, users); + for (auto user : users) { + auto userOp = user.first; + auto consumerTaskIds = getAsyncTaskIds(userOp); + if (consumerTaskIds.empty()) + continue; + // Remove producer task id from consumerTaskIds. + auto iter = std::remove(consumerTaskIds.begin(), + consumerTaskIds.end(), producerTaskId); + consumerTaskIds.erase(iter, consumerTaskIds.end()); + // Add a channel from the single producer task to consumerTaskIds. + if (consumerTaskIds.size() > 0) { + channels.push_back(std::make_unique( + producerTaskId, consumerTaskIds, userOp, user.second, + producerNumBuffers)); + } + } + } + } + }); + + LLVM_DEBUG({ + LDBG("Async channels:"); + for (auto &channel : channels) { + LDBG("producer op: " << channel->relation.first); + channel->getSrcOp()->dump(); + for (auto &asyncTaskId : channel->relation.second) + LDBG("consumer: " << asyncTaskId); + channel->getDstOp()->dump(); + LDBG("numBuffers: " << channel->numBuffers); + } + }); +} + +// Group channels in two ways: +// - by producer ops. One producer corresponds to multiple channels. This +// grouping will be used to create buffers per shared producer. +// - by consumer ops. One consumer corresponds to multiple channels. This +// grouping will be used to create barriers per shared consumer. +// Also compute orderedChannels, which will be keyed by getDstOp() of channels, +// to enforce deterministic order for map. +void groupChannels( + SmallVector &channels, + DenseMap> &channelsGroupedByProducers, + DenseMap> &channelsGroupedByConsumers, + SmallVector &orderedChannels) { + + // Group channels by producer op. + DenseMap> producerChannels; + for (auto channel : channels) { + producerChannels[channel->getSrcOp()].push_back(channel); + } + +#ifndef NDEBUG + // Some sanity checks. + for (auto &item : producerChannels) { + auto &channels = item.second; + unsigned numBuffers = channels.front()->numBuffers; + for (auto c : channels) { + assert(c->numBuffers == numBuffers && "Unmatched number of buffers"); + } + } +#endif + + // Group channels by consumer op. + DenseMap> consumerChannels; + + // Two channels can be combined if + // src1 and src2 are in the same block and + // (dst1 == dst2 or + // (dst1 and dst2 are in the same block, both have a single user, and + // dst1User == dst2User and dst1User is in the same block as dst1)) + auto channelCanBeMerged = [](Channel *c1, Channel *c2) -> bool { + if (c1->getSrcOp()->getBlock() != c2->getSrcOp()->getBlock()) + return false; + Operation *dst1 = c1->getDstOp(), *dst2 = c2->getDstOp(); + if (dst1 == dst2) + return true; + if (dst1->getBlock() != dst2->getBlock() || !dst1->hasOneUse() || + !dst2->hasOneUse()) + return false; + // Check taskIds on dstOps. + if (getAsyncTaskIds(dst1) != getAsyncTaskIds(dst2)) + return false; + Operation *dst1User = *(dst1->getUsers().begin()); + Operation *dst2User = *(dst2->getUsers().begin()); + return dst1User == dst2User && dst1User->getBlock() == dst1->getBlock(); + }; + assert(channels.size() > 0 && "channel size is zero"); + // Compare with existing channels in the consumerChannels to see if + // it can be combined. + for (auto *c0 : channels) { + bool merged = false; + for (auto &kv : consumerChannels) { + if (kv.second.size() > 0 && channelCanBeMerged(c0, kv.second.front())) { + kv.second.push_back(c0); + merged = true; + break; + } + } + if (!merged) { // Create a new entry. + auto *keyOp = c0->getDstOp(); + if (!consumerChannels.count(keyOp)) + orderedChannels.push_back(c0); + consumerChannels[keyOp].push_back(c0); + } + } + + // Reorder channels associated with one entry based on program order of the + // producers. + for (auto &kv : consumerChannels) { + if (kv.second.size() > 1) { + auto &allOps = kv.second.front()->getSrcOp()->getBlock()->getOperations(); + std::sort( + kv.second.begin(), kv.second.end(), [&](Channel *a, Channel *b) { + auto itrA = + std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { + Operation *opPointer = &op; + return opPointer == a->getSrcOp(); + }); + auto itrB = + std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { + Operation *opPointer = &op; + return opPointer == b->getSrcOp(); + }); + assert(itrA != allOps.end() && itrB != allOps.end()); + return std::distance(itrA, itrB) < 0; + }); + } + } + + // Switch to using channel as the key instead of ops as ops can be volatile. + for (auto &kv : producerChannels) { + channelsGroupedByProducers[kv.second.front()] = kv.second; + } + for (auto &kv : consumerChannels) { + channelsGroupedByConsumers[kv.second.front()] = kv.second; + } + + LLVM_DEBUG({ + DBGS() << "\n\n"; + LDBG("Grouped channels by producer:"); + unsigned i = 0; + for (auto &kv : channelsGroupedByProducers) { + DBGS() << "Channel " << ++i << ":\n"; + DBGS() << "producer: "; + kv.getFirst()->getSrcOp()->dump(); + for (auto &channel : kv.second) { + DBGS() << "consumer: "; + channel->getDstOp()->dump(); + DBGS() << "] "; + LDBG("numBuffers: " << channel->numBuffers); + DBGS() << "\n"; + } + } + + DBGS() << "\n\n"; + LDBG("Grouped channels by consumer:"); + i = 0; + for (auto &kv : channelsGroupedByConsumers) { + DBGS() << "Channel " << ++i << ":\n"; + DBGS() << "consumer: "; + kv.getFirst()->getDstOp()->dump(); + for (auto &channel : kv.second) { + DBGS() << "producer: "; + channel->getSrcOp()->dump(); + for (auto &asyncTaskId : channel->relation.second) + DBGS() << asyncTaskId << ", "; + DBGS() << "] "; + LDBG("numBuffers: " << channel->numBuffers); + DBGS() << "\n"; + } + DBGS() << "\n"; + } + }); +} + +// Reorder producer ops to unblock consumers interleavingly. +void reorderProducerOps(SmallVector &channels) { + if (channels.size() <= 1) + return; + + // Bail out if channels are not in the same block + auto block = channels.front()->getSrcOp()->getBlock(); + for (auto &channel : channels) { + if (channel->getSrcOp()->getBlock() != block) { + return; + } + } + + // Group channels by the first consumer taskId of each channel. Smaller taskId + // has higher priority. + // TODO: consider consumer priority + std::map> groupedProducerOps; + for (auto &channel : channels) { + auto asyncTaskId = channel->relation.second.front(); + groupedProducerOps[asyncTaskId].push_back(channel); + } + + // No need to reorder if all channels are in the same group. + if (groupedProducerOps.size() <= 1) + return; + + // Sort each group by number of consumers. + for (auto &group : groupedProducerOps) { + std::sort(group.second.begin(), group.second.end(), + [&](Channel *a, Channel *b) { + return a->relation.second.size() < b->relation.second.size(); + }); + } + + // Start from the first producer in channels. Iterate through the groups + // which are ordered by the first consumer taskId. Within each group, channels + // are ordered by number of consumers. + Operation *currOp = channels.front()->getSrcOp(); + for (auto &group : groupedProducerOps) { + for (auto &channel : group.second) { + channel->getSrcOp()->moveAfter(currOp); + currOp = channel->getSrcOp(); + } + } + + // Move backward dependency slice close to producer ops. + // Start from the last producer op backwards and move backward slice to + // before each op. This guarantees that the backward slice of each op is + // scheduled as late as possible. + for (auto &group : reverse(groupedProducerOps)) { + for (auto &channel : reverse(group.second)) { + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + SetVector backwardSlice; + getBackwardSlice(channel->getSrcOp(), &backwardSlice, opt); + for (auto &op : backwardSlice) { + if (op->getBlock() == block) + op->moveBefore(channel->getSrcOp()); + } + } + } + + LLVM_DEBUG({ + LDBG("\n"); + LDBG("after reordering producer ops"); + currOp->getParentOfType().dump(); + LDBG("\n"); + }); +} + +unsigned getLoopDepth(Operation *op) { + unsigned depth = 0; + auto pOp = op->getParentOfType(); + while (pOp) { + ++depth; + pOp = pOp->getParentOfType(); + } + return depth; +} + +#if 0 +bool isInnermostLoop(scf::ForOp forOp) { + bool isInner = true; + forOp.walk([&](Operation *subOp) { + if (subOp != forOp.getOperation()) + if (auto forOp = dyn_cast(subOp)) + isInner = false; + }); + return isInner; +} +#endif + +// Generate code +// numSteps = ((upperBound - lowerBound) + forOpStep - 1) / forOpStep +Value getNumSteps(scf::ForOp forOp, OpBuilderWithAsyncTaskIds &builder) { + auto loc = forOp.getLoc(); + // numSteps = ((upperBound - lowerBound) + forOpStep - 1) / forOpStep + Value numSteps = builder.createWithAsyncTaskIds( + loc, forOp.getUpperBound(), forOp.getLowerBound()); + numSteps = builder.createWithAsyncTaskIds(loc, numSteps, + forOp.getStep()); + if (forOp.getStep().getType() != builder.getI64Type()) + numSteps = builder.createWithAsyncTaskIds( + loc, builder.getI64Type(), numSteps); + + Value one = builder.createWithAsyncTaskIds(loc, 1, 64); + numSteps = builder.createWithAsyncTaskIds(loc, numSteps, one); + Value innerForStep = forOp.getStep(); + if (forOp.getStep().getType() != builder.getI64Type()) + innerForStep = builder.createWithAsyncTaskIds( + loc, builder.getI64Type(), forOp.getStep()); + numSteps = builder.createWithAsyncTaskIds(loc, numSteps, + innerForStep); + return numSteps; +} + +// Add phase and bufferIndex to be used when lowering the producer. +// When hasParallelReuse is true (i.e this is the innermost loop), we pass in +// accumulatedLoopCount, which is used to initialize initBufferIdx. +// When isOuterOfReuse is true, we add an additional arg for accumLoopCount. +scf::ForOp createNewLoop(scf::ForOp forOp, int numBuffers, + scf::ForOp &parentForOp, Value accumulatedLoopCount, + bool hasParallelReuse, bool isOuterOfReuse) { + auto loc = forOp.getLoc(); + Block *body = forOp.getBody(); + + OpBuilderWithAsyncTaskIds builder(forOp.getContext()); + builder.setAsynTaskIdsFromArray(getNestedAsyncTaskIds(forOp)); + builder.setInsertionPoint(forOp); + if (hasParallelReuse) { + LLVM_DEBUG({ + LDBG("createNewLoop hasParallelReuse: "); + accumulatedLoopCount.dump(); + }); + } + + Value numBuffersVal = + builder.createWithAsyncTaskIds(loc, numBuffers, 32); + + // Step 1: Append bufferIdx and phase as forOp arguments. + Value tmpAccumLoopCount; + if (isOuterOfReuse) { + tmpAccumLoopCount = body->insertArgument(body->getNumArguments(), + builder.getI64Type(), loc); + } + Value phase = + body->insertArgument(body->getNumArguments(), builder.getI1Type(), loc); + Value bufferIdx = + body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc); + + // Step 2: Generate bufferIdx and phase for next iteration: + // nextBufferIdx = bufferIdx + 1 + // nextPhase = ((nextBufferIdx < numBuffers && curPhase) || + // (nextBufferIdx >= numBuffers && curPhase^1)) + // nextBufferIdx = nextBufferIdx >= numBuffers ? 0 : nextBufferIdx + auto yieldOp = llvm::cast(body->getTerminator()); + builder.setInsertionPoint(yieldOp); + Value one = builder.createWithAsyncTaskIds(loc, 1, 32); + Value _1_1b = builder.createWithAsyncTaskIds(loc, 1, 1); + // nextBufferIdx = bufferIdx + 1 + Value nextBufferIdx = + builder.createWithAsyncTaskIds(loc, bufferIdx, one); + Value bufferGECond = builder.createWithAsyncTaskIds( + loc, arith::CmpIPredicate::uge, nextBufferIdx, numBuffersVal); + Value bufferLTCond = builder.createWithAsyncTaskIds( + loc, arith::CmpIPredicate::ult, nextBufferIdx, numBuffersVal); + // nextBufferIdx >= numBuffers ? nextBufferIdx - numBuffers : nextBufferIdx + Value moduloBufferIdx = builder.createWithAsyncTaskIds( + loc, nextBufferIdx, numBuffersVal); + nextBufferIdx = builder.createWithAsyncTaskIds( + loc, bufferGECond, moduloBufferIdx, nextBufferIdx); + + // nextPhase = ((nextBufferIdx < numBuffers && curPhase) || + // (nextBufferIdx >= numBuffers && curPhase^1)) + Value flipPhase = + builder.createWithAsyncTaskIds(loc, phase, _1_1b); + Value cond0 = builder.createWithAsyncTaskIds( + loc, bufferGECond, flipPhase); + Value cond1 = builder.createWithAsyncTaskIds( + loc, bufferLTCond, phase); + Value nextPhase = + builder.createWithAsyncTaskIds(loc, cond0, cond1); + + // Step 3: Add nextBufferIdx and nextPhase to yieldOp. + if (isOuterOfReuse) { + // We have not iterated through the body yet, so do not have the right value + // for nextTmpIdx. This will be fixed in the caller. + Value nextTmpIdx = tmpAccumLoopCount; + yieldOp->insertOperands(yieldOp.getNumOperands(), + {nextTmpIdx, nextPhase, nextBufferIdx}); + } else + yieldOp->insertOperands(yieldOp.getNumOperands(), + {nextPhase, nextBufferIdx}); + + // Step 4: Create loop arguments for the new ForOp. + SmallVector newLoopArgs; + for (auto operand : forOp.getInitArgs()) + newLoopArgs.push_back(operand); + + builder.setInsertionPoint(forOp); + Value initBufferIdx, initPhase; + // Set initial values for bufferIdx and phase. + if (parentForOp) { + if (hasParallelReuse) { + // Handling ForOp with an outer loop, use the passed-in value as initial + // value. + initBufferIdx = accumulatedLoopCount; + } else { + // It is possible that parent loop induction variable has different type. + // Here we promote to 64 bit. + // numSteps = ((upperBound - lowerBound) + forOpStep - 1) / forOpStep + Value numSteps = getNumSteps(forOp, builder); + + // TODO: use a global flattened iteration space index for multi-dim loops. + // initBufferIdx = (parentInductionVar - parentLowBound) / parentStep * + // numSteps + Value parentIterIdx = builder.createWithAsyncTaskIds( + loc, parentForOp.getInductionVar(), parentForOp.getLowerBound()); + parentIterIdx = builder.createWithAsyncTaskIds( + loc, parentIterIdx, parentForOp.getStep()); + if (parentForOp.getStep().getType() != builder.getI64Type()) + parentIterIdx = builder.createWithAsyncTaskIds( + loc, builder.getI64Type(), parentIterIdx); + initBufferIdx = builder.createWithAsyncTaskIds( + loc, parentIterIdx, numSteps); + } + + numBuffersVal = builder.createWithAsyncTaskIds( + loc, builder.getI64Type(), numBuffersVal); + // Calculate tmpIdx / numBuffers + // initBufferIdx = tmpIdx - tmpIdx / numBuffers * numBuffers + // initPhase = (tmpIdx / numBuffers) & 1 + Value bufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, numBuffersVal); + initBufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, + builder.createWithAsyncTaskIds(loc, bufferIdx, + numBuffersVal)); + initBufferIdx = builder.createWithAsyncTaskIds( + loc, builder.getI32Type(), initBufferIdx); + + Value one = + builder.createWithAsyncTaskIds(loc, 1, 64); + bufferIdx = + builder.createWithAsyncTaskIds(loc, bufferIdx, one); + initPhase = builder.createWithAsyncTaskIds( + loc, builder.getI1Type(), bufferIdx); + } else { + if (hasParallelReuse) { + // Handling ForOp without outer loop. + // tmpIdx = accumulatedLoopCount + initBufferIdx = accumulatedLoopCount; + numBuffersVal = builder.createWithAsyncTaskIds( + loc, builder.getI64Type(), numBuffersVal); + // bufferIdx = tmpIdx / numBuffers + Value bufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, numBuffersVal); + // initBufferIdx = tmpIdx - tmpIdx/numBuffers * numBuffers (modulo) + initBufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, + builder.createWithAsyncTaskIds(loc, bufferIdx, + numBuffersVal)); + initBufferIdx = builder.createWithAsyncTaskIds( + loc, builder.getI32Type(), initBufferIdx); + + Value one = + builder.createWithAsyncTaskIds(loc, 1, 64); + // initPhase = (tmpIdx / numBuffers) & 1 + bufferIdx = + builder.createWithAsyncTaskIds(loc, bufferIdx, one); + initPhase = builder.createWithAsyncTaskIds( + loc, builder.getI1Type(), bufferIdx); + } else { + // Set initial phase to false, and initial bufferIdx to 0. + initBufferIdx = + builder.createWithAsyncTaskIds(loc, 0, 32); + initPhase = + builder.createWithAsyncTaskIds(loc, 0, 1); + } + } + if (isOuterOfReuse) { + assert(!hasParallelReuse); + Value initTmpIdx = + builder.createWithAsyncTaskIds(loc, 0, 64); + newLoopArgs.append({initTmpIdx, initPhase, initBufferIdx}); + } else + newLoopArgs.append({initPhase, initBufferIdx}); + + // Step 5: Create newForOp and take the region of the original forOp. + auto newForOp = builder.createWithAsyncTaskIds( + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), + newLoopArgs); + if (forOp->getAttr("tt.loop_schedule")) + newForOp->setAttr("tt.loop_schedule", forOp->getAttr("tt.loop_schedule")); + newForOp.getRegion().takeBody(forOp.getRegion()); + + // Step 6: Replace forOp with newForOp. + for (unsigned i = 0; i < forOp.getNumResults(); ++i) + forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i)); + forOp.erase(); + + return newForOp; +} + +// Find top-level ops which contain at least one channel. If a channel's +// getSrcOp() and getDstOp() belong to the inner loop, the outer loop will be +// part of asyncTaskOps. +SmallVector +getTaskTopRegion(triton::FuncOp funcOp, + const SmallVector &channels) { + SmallVector asyncTaskOps; + auto isAsyncTaskTopOp = [&](Operation *taskTopOp) -> bool { + for (auto c : channels) { + Operation *producer = c->getSrcOp(), *consumer = c->getDstOp(); + while (producer && !isa(producer->getParentOp())) { + producer = producer->getParentOp(); + } + while (consumer && !isa(consumer->getParentOp())) { + consumer = consumer->getParentOp(); + } + if (producer == taskTopOp && consumer == taskTopOp) + return true; + } + return false; + }; + for (auto &block : funcOp.getBody().getBlocks()) { + for (Operation &bodyOp : block.getOperations()) { + Operation *op = &bodyOp; + if (op->getNumRegions() <= 0) + continue; + // If this op does not contain both a producer taskId and a consumer + // taskId, continue. + if (getAsyncTaskIds(op).size() == 1) + continue; + if (isAsyncTaskTopOp(op)) + asyncTaskOps.push_back(op); + } + } + + LLVM_DEBUG({ + LDBG("\nTop Task Bodies"); + for (auto op : asyncTaskOps) { + LDBG("\nTask Body:"); + op->dump(); + } + }); + return asyncTaskOps; +} + +static unsigned getNumChannelsInOp(Operation *op, + const SmallVector &channels, + SmallVector &channelsInOp) { + unsigned num = 0; + for (auto *ch : channels) { + // Get the immediate parent. + auto srcParent = ch->getSrcOp()->getParentOp(); + auto dstParent = ch->getDstOp()->getParentOp(); + if (srcParent == op && dstParent == op) + channelsInOp.push_back(ch); + } + return channelsInOp.size(); +} + +void reuseBuffers(SmallVector &taskTopOps, + const SmallVector &channels, + DenseMap &mapToRepresenting, + SmallVector &opsWithBufferReuse) { + // For the case of multiple parallel ForOps with same number of channels, + // we can try reusing the buffers across the parallel ForOps or across ForOps + // and IfOps. Case 1: + // ForOp_A + // ForOp_B + // --> opsWithBufferReuse: ForOp_A ForOp_B + // Case 2: + // ForOp (persistent) + // ForOp_A + // ForOp_B + // --> opsWithBufferReuse: ForOp_A ForOp_B + // Case 3: + // ForOp (persistent) + // ForOp_A + // --> --> opsWithBufferReuse: ForOp_A + // Case 4: + // ForOp + // IfOp + // --> opsWithBufferReuse: ForOp IfOp + // We use accumLoopCount to update bufferIdx for the sharing groups. If there + // is an outer loop, we will need to add an argument to it. Assume we handle + // outer ForOp first, then inner ForOp in program order. + unsigned maxDepth = 0; + DenseMap> loopDepthMap; + for (auto &op : taskTopOps) { + op->walk([&](Operation *subOp) { + if (dyn_cast(subOp) || dyn_cast(subOp)) { + unsigned tDepth = getLoopDepth(subOp); + loopDepthMap[tDepth].push_back(subOp); + if (tDepth > maxDepth) + maxDepth = tDepth; + } + }); + } + // A list of IfOps/ForOps at the innermost level: loopDepthMap[maxDepth] + auto &opsAtMaxDepth = loopDepthMap[maxDepth]; + LDBG("reuseBuffers number of inner ops: " << opsAtMaxDepth.size()); + if (opsAtMaxDepth.empty()) + return; + if (opsAtMaxDepth.size() == 1 && dyn_cast(opsAtMaxDepth[0]) && + maxDepth > 0) { + // Persistent with a single inner loop. There is no sharing group, but + // we can use the logic to generate accumLoopCount for persistent case. + opsWithBufferReuse = opsAtMaxDepth; + LDBG("-- opsWithBufferReuse with size 1"); + return; + } + // Find ops that contain immediate channels. And the ops do not overlap + // live range. For example + // If + // For + // --> If and For can overlap. But + // For + // If + // --> can't overlap + SmallVector innerOps; + SmallVector innerLoops; + for (auto *innerOp : opsAtMaxDepth) { + SmallVector channelsInOp; + getNumChannelsInOp(innerOp, channels, channelsInOp); + if (channelsInOp.empty()) + continue; + innerOps.push_back(innerOp); + if (dyn_cast(innerOp)) + innerLoops.push_back(innerOp); + } + // Make sure opsWithBufferReuse are under the same ForOp or at the top level. + // Make sure opsWithBufferReuse contain the same number of channels, and the + // same numBuffers for the channels. Channels in the first op will be the + // representing channels. All sharing groups will span the same set of regions + // in opsWithBufferReuse. + bool firstOp = true; + Operation *outerLoop = nullptr; + unsigned numChannels = 0, numBuffers = 0; + SmallVector channelsInOpOne; + for (auto *innerOp : innerOps) { + // Ignore IfOps that overlap with innerLoops. + if (dyn_cast(innerOp)) { + bool ignore = false; + for (auto *innerLoop : innerLoops) { + if (innerOp == innerLoop->getParentOp()) { + ignore = true; + break; + } + } + if (ignore) + continue; + } + scf::ForOp parentForOp = innerOp->getParentOfType(); + SmallVector channelsInOp; + getNumChannelsInOp(innerOp, channels, channelsInOp); + if (firstOp) { + outerLoop = parentForOp.getOperation(); + numChannels = channelsInOp.size(); + channelsInOpOne = channelsInOp; + numBuffers = channelsInOp[0]->numBuffers; + opsWithBufferReuse.push_back(innerOp); + } else { + if (outerLoop != parentForOp.getOperation() || + numChannels != channelsInOp.size()) + // Not under the same outer loop. + return; + if (numBuffers != channelsInOp[0]->numBuffers) + return; + unsigned idx = 0; + for (auto *ch : channelsInOp) { + // TODO: sort the channels in the loop according to buffer size. + mapToRepresenting[ch] = channelsInOpOne[idx++]; + } + opsWithBufferReuse.push_back(innerOp); + } + firstOp = false; + } + if (opsWithBufferReuse.size() == 1 && maxDepth == 0) + // A single op in buffer reuse and there is no outer loop. + opsWithBufferReuse.clear(); + LLVM_DEBUG({ + LDBG("reuseBuffers: " << numChannels << " channels opsWithBufferReuse " + << opsWithBufferReuse.size()); + for (auto &kv : mapToRepresenting) { + llvm::dbgs() << "---- from "; + kv.first->getDstOp()->dump(); + llvm::dbgs() << "---- to "; + kv.second->getDstOp()->dump(); + } + }); + // opsWithBufferReuse = innerOps; +} + +// Go through a list of operations under one scope. +// prevAccum can be null if there is an outer loop for the reuse loops. +Value updateAccumLoopCount(SmallVector &opList, + unsigned numBuffers, + SmallVector &taskTopOps, + Operation *commonOuterLoop, + SmallVector &opsWithBufferReuse, + Value prevAccum) { + for (Operation *op : opList) { + if (auto forOp = dyn_cast(op)) { + auto newForOp = + createNewLoopWrapper(forOp, numBuffers, taskTopOps, commonOuterLoop, + opsWithBufferReuse, prevAccum); + // Update prevAccum to be after the loop. + // If the loop is in opsWithBufferReuse, generate prevAccum + numSteps. + bool hasReuse = false; + for (auto tLoop : opsWithBufferReuse) + if (newForOp.getOperation() == tLoop) { + hasReuse = true; + break; + } + if (hasReuse) { + // Update accumLoopCount = prevAccum + numSteps. + OpBuilderWithAsyncTaskIds builder(newForOp.getContext()); + builder.setAsynTaskIdsFromArray(getNestedAsyncTaskIds(newForOp)); + builder.setInsertionPointAfter(newForOp); + + Value numSteps = getNumSteps(newForOp, builder); + prevAccum = builder.createWithAsyncTaskIds( + newForOp.getLoc(), prevAccum, numSteps); + } + // If the loop is the outer loop for a reuse loop, we are done. + // At this point, op is no longer valid. + } else if (auto ifOp = dyn_cast(op)) { + if (needAccumulatedLoopCnt(ifOp, opsWithBufferReuse)) { + auto newIfOp = + rewriteIfOp(ifOp, numBuffers, taskTopOps, commonOuterLoop, + opsWithBufferReuse, prevAccum); + // update prevAccum to be result of the new IfOp. + assert(newIfOp.getNumResults() >= 1); + auto numRes = newIfOp.getNumResults(); + LDBG("update prevAccum with result from IfOp"); + prevAccum = newIfOp.getResult(numRes - 1); // last result + } else { + // Still need to process ForOps in pre-order. + SmallVector innerForOps; + ifOp->walk([&](Operation *subOp) { + if (auto forOp = dyn_cast(subOp)) { + innerForOps.push_back(forOp); + } + }); + for (auto innerFor : innerForOps) + createNewLoopWrapper(innerFor, numBuffers, taskTopOps, + commonOuterLoop, opsWithBufferReuse, prevAccum); + } + } + } + return prevAccum; +} + +scf::ForOp createNewLoopWrapper(scf::ForOp origForOp, unsigned numBuffers, + SmallVector &taskTopOps, + Operation *commonOuterLoop, + SmallVector &opsWithBufferReuse, + Value prevAccum) { + LLVM_DEBUG({ + LDBG("call createNewLoop on"); + origForOp.dump(); + }); + + scf::ForOp parentForOp = origForOp->getParentOfType(); + scf::ForOp newForOp; + // for(...) -> for(..., phase, bufferIdx) + unsigned loopNumBuffers = getNumBuffersOrDefault(origForOp, numBuffers); + + bool isOuterOfReuse = + commonOuterLoop && commonOuterLoop == origForOp.getOperation(); + bool hasReuse = false; + for (auto tLoop : opsWithBufferReuse) + if (origForOp.getOperation() == tLoop) { + hasReuse = true; + break; + } + // Set accumulatedLoopCount when this is a loop in opsWithBufferReuse. If + // this loop has an outer loop, an extra arg for accumLoopCount should have + // been added to the outer loop. + Value accumulatedLoopCount = prevAccum; // Value(); + newForOp = createNewLoop(origForOp, loopNumBuffers, parentForOp, + accumulatedLoopCount, hasReuse, isOuterOfReuse); + LLVM_DEBUG({ + LDBG("after createNewLoop "); + newForOp.dump(); + }); + // origForOp is erased in createNewLoop. If origForOp is a top operation + // (i.e in taskTopOps), make sure taskTopOps is updated with the newForOp. + auto asyncTaskLoopForItr = + std::find(taskTopOps.begin(), taskTopOps.end(), origForOp.getOperation()); + if (asyncTaskLoopForItr != taskTopOps.end()) { + // Update taskTopOps. + *asyncTaskLoopForItr = newForOp.getOperation(); + } + + // origForOp is erased in createNewLoop. If origForOp is in + // opsWithBufferReuse, replace. + auto tmpIter = std::find(opsWithBufferReuse.begin(), opsWithBufferReuse.end(), + origForOp.getOperation()); + if (tmpIter != opsWithBufferReuse.end()) { + *tmpIter = newForOp.getOperation(); + } + + // Handle ops in loop body, only IfOps and ForOps. + SmallVector opList; + for (Operation &op : newForOp.getBody()->without_terminator()) { + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + } + Value endAccum = updateAccumLoopCount( + opList, numBuffers, taskTopOps, commonOuterLoop, opsWithBufferReuse, + isOuterOfReuse ? getAccumLoopCountArg(newForOp) : prevAccum); + + // Update yieldOp. + if (isOuterOfReuse) { + Value arg = getAccumLoopCountArg(newForOp); + Operation *yieldOp = newForOp.getBody()->getTerminator(); + yieldOp->replaceUsesOfWith(arg, endAccum); + } + return newForOp; +} + +// This function takes a list of channels, a mapping from a channel +// to its representing channel if the key shares smem space with the +// representing channel, and a list of loops that are sharing smem spaces. Note +// that every loop in opsWithBufferReuse either has the same outer loop or has +// no outer loop. +// For ForOps in taskTopOps, create new ForOp for each by adding phase, +// bufferIdx to the arguments. In the case of sharing smem, we need to traverse +// and update IfOps when necessary. We call updateAccumLoopCount on the list +// of top level Ops that are ForOps or IfOps enclosing a loop with buffer reuse. +// updateAccumLoopCount calls createNewLoopWrapper on ForOps, and rewriteIfOp on +// IfOps. Both will call updateAccumLoopCount on the list of Ops in the ForOp +// body or the thenBlock, elseBlock for IfOp. +Value appendBufferIdxArgs( + SmallVector &taskTopOps, unsigned numBuffers, + const SmallVector &channels, + const DenseMap &mapToRepresenting, + SmallVector &opsWithBufferReuse) { + // In order to handle sharing smem for a list of loops, we have two cases, + // one is the top-level op containing all loops in opsWithBufferReuse is + // a ForOp. + bool genAccumLoopCount = !opsWithBufferReuse.empty(); + Operation *commonOuterLoop = nullptr; + if (genAccumLoopCount) { + auto oneFor = opsWithBufferReuse[0]; + scf::ForOp parentForOp = oneFor->getParentOfType(); + if (parentForOp) + commonOuterLoop = parentForOp.getOperation(); + } + + // When there is no outer loop, we need to create a place holder for + // tmpAccumLoopCount. Every forOp in opsWithBufferReuse either has the same + // outer loop or has no outer loop. + Value tmpAccumLoopCount; + if (opsWithBufferReuse.size() > 1 && !commonOuterLoop) { + auto oneFor = opsWithBufferReuse[0]; + // Initialize tmpAccumLoopCount to be 0. + OpBuilderWithAsyncTaskIds builder(taskTopOps[0]->getContext()); + builder.setAsynTaskIdsFromArray(getNestedAsyncTaskIds(oneFor)); + builder.setInsertionPoint(taskTopOps[0]); + tmpAccumLoopCount = builder.createWithAsyncTaskIds( + oneFor->getLoc(), 0, 64); + } + + SmallVector opList; + for (auto &op : taskTopOps) { + if (auto origIfOp = dyn_cast(op)) { + opList.push_back(op); + } + if (auto origForOp = dyn_cast(op)) + opList.push_back(op); + } + updateAccumLoopCount(opList, numBuffers, taskTopOps, commonOuterLoop, + opsWithBufferReuse, tmpAccumLoopCount); + + return tmpAccumLoopCount; +} + +// Create an allocation to hold the mbarriers. +static Value createBarrierAlloc(triton::FuncOp funcOp, unsigned distance) { + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(funcOp.getContext()); + Location loc = funcOp.getLoc(); + auto context = funcOp.getContext(); + auto barrierCTALayout = + ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = ttg::SwizzledSharedEncodingAttr::get( + context, 1, 1, 1, {0}, barrierCTALayout); + Type barrierMemDescType = ttg::MemDescType::get( + {distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = + ttg::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value barrierAlloc = builder.create( + loc, barrierMemDescType, Value()); + for (unsigned i = 0; i < distance; i++) { + Value idx = builder.create(loc, i, 32); + Value barrierView = builder.create( + loc, singleBarrierMemDescType, barrierAlloc, idx); + builder.create(funcOp->getLoc(), barrierView, 1); + } + return barrierAlloc; +} + +// channelsGroupedByConsumers: channels are grouped together. +// Go through each group, check the first channel in the group, create a token +// for each consumer taskId. Return a map that maps each channel + consumer +// taskId to a token. Also update barrierAllocMap that maps each channel + +// consumer taskId to a BarrierAlloc. +DenseMap> createToken( + const DenseMap> + &channelsGroupedByConsumers, + const SmallVector &orderedChannels, triton::FuncOp funcOp, + int numConsumerGroups, + const DenseMap> ©OpMap, + DenseMap> &channelReuse, + DenseMap> &barrierAllocMap) { + DenseMap> ret; + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + for (auto *key : orderedChannels) { + auto it = channelsGroupedByConsumers.find(key); + Channel *channel = it->second.front(); + if (!channelReuse.count(channel)) + continue; + for (auto consumerAsyncTaskId : channel->relation.second) { + ttng::TokenLoadType tokenLoadType; + auto copyOp = copyOpMap.find(channel)->second.first; + if (isa(copyOp)) { + tokenLoadType = ttng::TokenLoadType::AsyncLoadOp; + } else if (isa(copyOp)) { + tokenLoadType = ttng::TokenLoadType::TMALoadOp; + } else if (isa(copyOp)) { + tokenLoadType = ttng::TokenLoadType::LocalStoreOp; + } else { + llvm_unreachable("Unexpected load type"); + } + + Value v; + if (it->second.front()->getSrcOp()->getParentOfType()) { + v = builder.create( + funcOp.getLoc(), channel->numBuffers, tokenLoadType); + } else { + v = builder.create(funcOp.getLoc(), 1, + tokenLoadType); + } + // Channels in the group share the same set of tokens. + for (auto &c : it->second) { + ret[c][consumerAsyncTaskId] = v; + } + for (auto *reuse : channelReuse[channel]) { + ret[reuse][consumerAsyncTaskId] = v; + } + + auto producerOp = it->second.front()->getSrcOp(); + if (isa(producerOp)) { + Value bAlloc = createBarrierAlloc(funcOp, channel->numBuffers); + // Channels in the group share the same set of tokens. + for (auto &c : it->second) { + ret[c][consumerAsyncTaskId] = v; + barrierAllocMap[c][consumerAsyncTaskId] = bAlloc; + } + for (auto *reuse : channelReuse[channel]) { + ret[reuse][consumerAsyncTaskId] = v; + barrierAllocMap[reuse][consumerAsyncTaskId] = bAlloc; + } + } + } + } + return ret; +} + +// Create a buffer array for each producer op, if the producer is in a ForOp, +// the buffer array will contain numBuffers. +DenseMap createBuffer( + DenseMap> &channelsGroupedByProducers, + triton::FuncOp funcOp, int numConsumerGroups, + DenseMap &mapToRepresenting, + DenseMap> &channelReuse) { + + DenseMap bufferMap; + MLIRContext *context = funcOp.getContext(); + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + DenseSet visited; + for (auto &item : channelsGroupedByProducers) { + auto &channels = item.second; + for (auto c : channels) { + assert(!visited.count(c)); + visited.insert(c); + if (mapToRepresenting.count(c)) { + channelReuse[mapToRepresenting[c]].push_back(c); + LDBG("update channelReuse key " << mapToRepresenting[c] << " " << c); + } else { + channelReuse[c].push_back(c); + LDBG("update channelReuse key " << c << " " << c); + } + } + } + for (auto &item : channelsGroupedByProducers) { + auto &channels = item.second; + auto srcValue = item.first->getSrcOperand(); + auto srcOp = item.first->getSrcOp(); + unsigned numBuffers = channels.front()->numBuffers; + + if (auto tensorType = dyn_cast(srcValue.getType())) { + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::NVMMASharedEncodingAttr::get( + context, sliceShape, order, CTALayout, elemType, /*fp4Padded*/ false); + auto sliceType = + RankedTensorType::get(sliceShape, elemType, sharedLayout); + + // Get shape, layout and type of the complete buffer + SmallVector bufferShape(sliceShape.begin(), sliceShape.end()); + if (srcOp->getParentOfType()) + bufferShape.insert(bufferShape.begin(), numBuffers); + else + bufferShape.insert(bufferShape.begin(), 1); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + auto bufferType = + RankedTensorType::get(bufferShape, elemType, sharedLayout); + Type memdescType = + ttg::MemDescType::get(bufferShape, elemType, sharedLayout, + sharedMemorySpace, /*mutableMemory*/ true); + Value buffer = + builder.create(funcOp.getLoc(), memdescType); + + // Channels in the group share the same buffer. + for (auto c : channels) + bufferMap[c] = buffer; + } else { + llvm_unreachable("Unexpected result type"); + } + } + unsigned groupId = 0; + for (auto &kv : channelReuse) { + if (kv.second.size() <= 1) + continue; + bufferMap[kv.first].getDefiningOp()->setAttr( + "allocation.shareGroup", + IntegerAttr::get(IntegerType::get(context, 32), groupId)); + for (auto *c : kv.second) + bufferMap[c].getDefiningOp()->setAttr( + "allocation.shareGroup", + IntegerAttr::get(IntegerType::get(context, 32), groupId)); + ++groupId; + } + return bufferMap; +} + +static std::pair +createAsyncCopy(const DenseMap &bufferMap, Channel *c, + Operation *op, SmallVector &asyncTasksPC, + Value bufferIdx, Value bufferIdxExtract) { + auto loadOp = cast(op); + auto buffer = bufferMap.find(c)->second; + MLIRContext *context = loadOp->getContext(); + OpBuilderWithAsyncTaskIds builder(context); + builder.setInsertionPoint(loadOp->getParentOp()); + builder.setAsynTaskIdsFromArray(asyncTasksPC); + + builder.setInsertionPoint(loadOp); + Value loadResult = loadOp.getResult(); + auto tensorType = dyn_cast(loadResult.getType()); + if (!tensorType) + return {nullptr, nullptr}; + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::NVMMASharedEncodingAttr::get( + context, sliceShape, order, CTALayout, elemType, /*fp4Padded*/ false); + auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + ttg::MemDescType subviewTy = + ttg::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding(), sharedMemorySpace, + /*mutableMemory=*/true); + Value zero = builder.createWithAsyncTaskIds( + loadOp.getLoc(), 0, 32); + SmallVector copyOffsets(sliceType.getRank() + 1, zero); + copyOffsets[0] = bufferIdx; + builder.setAsyncTaskIdsFromOp(loadOp); + builder.setInsertionPointAfter(loadOp); + auto view = builder.createWithAsyncTaskIds( + loadOp.getLoc(), subviewTy, buffer, copyOffsets); + // Create cp.async + Operation *copy = + builder.createWithAsyncTaskIds( + loadOp.getLoc(), loadOp.getPtr(), view, loadOp.getMask(), + loadOp.getOther(), loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile()); + + // Extract part. + builder.setAsyncTaskIdsFromValueUsers(loadResult); + builder.setInsertionPoint(c->getDstOp()); + SmallVector loadOffsets(sliceType.getRank() + 1, zero); + loadOffsets[0] = bufferIdxExtract; + auto viewLoad = builder.createWithAsyncTaskIds( + loadOp.getLoc(), subviewTy, buffer, loadOffsets); + auto sharedLoad = builder.createWithAsyncTaskIds( + loadOp.getLoc(), loadOp.getType(), viewLoad /*,wait->getResult(0)*/); + // Replace all uses of loadResult + loadResult.replaceAllUsesWith(sharedLoad.getResult()); + loadOp.erase(); + return {copy, sharedLoad}; +} + +// Create a local copy for a channel that is populated by the producer and +// accessed by the consumer. +static std::pair +createLocalCopy(const DenseMap &bufferMap, Channel *channel, + Value srcBufferIdx, Value dstBufferIdx) { + Operation *srcOp = channel->getSrcOp(); + Operation *dstOp = channel->getDstOp(); + MLIRContext *context = srcOp->getContext(); + auto buffer = bufferMap.find(channel)->second; + + Value srcValue = channel->getSrcOperand(); + auto tensorType = dyn_cast(srcValue.getType()); + if (!tensorType) + return {nullptr, nullptr}; + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::NVMMASharedEncodingAttr::get( + context, sliceShape, order, CTALayout, elemType, /*fp4Padded*/ false); + auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + ttg::MemDescType subviewTy = + ttg::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding(), sharedMemorySpace, + /*mutableMemory=*/true); + + // Consumer part. + OpBuilderWithAsyncTaskIds builder(dstOp); + builder.setAsyncTaskIdsFromOp(dstOp); + builder.setInsertionPoint(dstOp); + Value zero = builder.createWithAsyncTaskIds( + dstOp->getLoc(), 0, 32); + SmallVector loadOffsets(sliceType.getRank() + 1, zero); + loadOffsets[0] = dstBufferIdx; + auto dstView = builder.createWithAsyncTaskIds( + dstOp->getLoc(), subviewTy, buffer, loadOffsets); + auto sharedLoad = builder.createWithAsyncTaskIds( + dstOp->getLoc(), srcValue.getType(), dstView); + srcValue.replaceAllUsesWith(sharedLoad.getResult()); + + // Producer part. Create local_store for new producers. + builder.setAsynTaskIdsFromArray(channel->relation.first); + builder.setInsertionPoint(srcOp->getParentOp()); + zero = builder.createWithAsyncTaskIds(srcOp->getLoc(), + 0, 32); + SmallVector storeOffsets(sliceType.getRank() + 1, zero); + storeOffsets[0] = srcBufferIdx; + builder.setInsertionPointAfter(srcOp); + auto srcView = builder.createWithAsyncTaskIds( + srcOp->getLoc(), subviewTy, buffer, storeOffsets); + // Create local_alloc + Operation *copy = builder.createWithAsyncTaskIds( + srcOp->getLoc(), srcValue, srcView); + return {copy, sharedLoad}; +} + +static int getTMALoadSize(tt::ExperimentalDescriptorLoadOp &tmaLoad) { + auto tensorTy = cast(tmaLoad->getResult(0).getType()); + int loadSize = product(tensorTy.getShape()); + return loadSize * tensorTy.getElementType().getIntOrFloatBitWidth() / 8; +} + +Value getBarrierForPipelineStage(OpBuilderWithAsyncTaskIds &builder, + Value barrierAlloc, Value bufferIdx) { + auto context = barrierAlloc.getContext(); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + ttg::MemDescType barrierTy = ttg::MemDescType::get( + {1}, builder.getI64Type(), + cast(barrierAlloc.getType()).getEncoding(), + sharedMemorySpace, + /*mutableMemory=*/true); + + // Create barrierForTMA from barrierAlloc. + return builder.createWithAsyncTaskIds( + barrierAlloc.getLoc(), barrierTy, barrierAlloc, + ArrayRef({bufferIdx})); +} + +Value getBufferForPipelineStage(OpBuilderWithAsyncTaskIds &builder, + Type loadType, Value buffer, Value bufferIdx, + bool mutableMem) { + auto context = buffer.getContext(); + auto tensorType = dyn_cast(loadType); + assert(tensorType); + + auto order = ttg::getOrder(tensorType); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::NVMMASharedEncodingAttr::get( + context, sliceShape, order, CTALayout, elemType, /*fp4Padded*/ false); + auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + ttg::MemDescType subviewTy = + ttg::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding(), sharedMemorySpace, + /*mutableMemOry=*/mutableMem); + + Value zero = builder.createWithAsyncTaskIds( + buffer.getLoc(), 0, 32); + SmallVector copyOffsets(sliceType.getRank() + 1, zero); + copyOffsets[0] = bufferIdx; + + return builder.createWithAsyncTaskIds( + buffer.getLoc(), subviewTy, buffer, copyOffsets); +} + +Operation * +optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, + SmallVector &tmaLoads, + SmallVector &buffers, Value barrierAlloc, + Value bufferIdx, Value bufferIdxExtract, Value phase, + Operation *headProducer, Operation *headConsumer) { + auto loc = barrierAlloc.getLoc(); + + // Compute the total size of the loads. + int sizeInBytes = 0; + for (auto &tmaLoad : tmaLoads) { + sizeInBytes += getTMALoadSize(tmaLoad); + } + + // For each of the following ops, we will operate on a subview of each value + // according to the pipeline stage. + + // Create a barrier_expect with the appropriate size and insert it before the + // first load. + builder.setInsertionPoint(headProducer); + builder.setAsyncTaskIdsFromOp(headProducer); + auto prodBarrier = + getBarrierForPipelineStage(builder, barrierAlloc, bufferIdx); + auto pred = builder.createWithAsyncTaskIds(loc, 1, 1); + auto expect = builder.createWithAsyncTaskIds( + loc, prodBarrier, sizeInBytes, pred); + + // Convert all the producers to async_tma_copy_global_to_local + Operation *copy = nullptr; + for (auto [tmaLoad, buffer] : zip(tmaLoads, buffers)) { + builder.setInsertionPoint(tmaLoad); + auto pipelineBuffer = getBufferForPipelineStage(builder, tmaLoad.getType(), + buffer, bufferIdx, true); + Value tmaPtr = + builder + .createWithAsyncTaskIds( + loc, tmaLoad.getDesc()); + copy = builder.createWithAsyncTaskIds( + loc, tmaPtr, tmaLoad.getIndices(), prodBarrier, pipelineBuffer, pred); + } + + // Create a wait_barrier before the first consumer. + builder.setInsertionPoint(headConsumer); + builder.setAsyncTaskIdsFromOp(headConsumer); + auto consBarrier = + getBarrierForPipelineStage(builder, barrierAlloc, bufferIdxExtract); + phase = builder.createWithAsyncTaskIds( + loc, builder.getI32Type(), phase); + auto wait = builder.createWithAsyncTaskIds( + loc, consBarrier, phase); + + // Convert all the consumers to local_load + for (auto [tmaLoad, buffer] : zip(tmaLoads, buffers)) { + auto pipelineBuffer = getBufferForPipelineStage( + builder, tmaLoad.getType(), buffer, bufferIdxExtract, false); + auto sharedLoad = builder.createWithAsyncTaskIds( + loc, tmaLoad.getType(), pipelineBuffer); + + Value loadResult = tmaLoad.getResult(); + tmaLoad.getResult().replaceAllUsesWith(sharedLoad.getResult()); + tmaLoad.erase(); + } + return copy; +} + +// Lower producers for channels. Here channels are grouped in +// "channelsGroupedByConsumers". tokenMap tracks the set of tokens for each +// channel. +void insertAsyncComm( + triton::FuncOp funcOp, + const DenseMap> + &channelsGroupedByConsumers, + const DenseMap> &tokenMap, + const DenseMap> &barrierAllocMap, + const DenseMap &bufferMap, + const DenseMap> ©OpMap, + int numConsumerGroups) { + + // Find the operation that is along producer's parent chain, and its parent + // is the same op as producer's parent. Here p is producer, and c is consumer. + auto getSameLevelOp = [](Operation *p, Operation *c) -> Operation * { + while (!isa(c)) { + if (c->getParentOp() == p->getParentOp()) { + return c; + } + c = c->getParentOp(); + } + llvm_unreachable("Failed to find consumer's same level Op with producer"); + }; + + auto consumerReleaseHeuristic = [&](Operation *p, Operation *c, + int consumerAsyncTaskId) -> Operation * { + if (c->getBlock() != p->getBlock()) + return getSameLevelOp(p, c); + + // Find a common place for all users of the consumer, which would be the + // common post dominator. + mlir::PostDominanceInfo dom(funcOp); + std::unordered_set mutuallyNonDominatingUsers; + SmallVector users; + for (auto user : c->getUsers()) { + if (isa(user)) { + // TransOp is not a real consumer. It caculates the shared memory + // address for the real consumer. Continue to find its transitive users + // recursively. + DenseSet visited; + SmallVector transUsers; + transUsers.push_back(user); + while (!transUsers.empty()) { + auto transUser = transUsers.pop_back_val(); + visited.insert(transUser); + if (isa(transUser)) { + for (auto transitiveUser : transUser->getUsers()) { + if (!visited.count(transitiveUser)) + transUsers.push_back(transitiveUser); + } + } else { + users.push_back(transUser); + } + } + } else { + users.push_back(user); + } + } + + for (auto user : users) { + auto it = mutuallyNonDominatingUsers.begin(); + while (it != mutuallyNonDominatingUsers.end()) { + if (dom.properlyPostDominates(user, *it)) { + it = mutuallyNonDominatingUsers.erase(it); + } else if (dom.properlyPostDominates(*it, user)) { + break; + } else { + ++it; + } + } + if (it == mutuallyNonDominatingUsers.end()) + mutuallyNonDominatingUsers.insert(user); + } + + if (mutuallyNonDominatingUsers.size() == 1) { + // Find the common parent of this user and c + auto user = *mutuallyNonDominatingUsers.begin(); + while (user && user->getParentOp() != c->getParentOp()) + user = user->getParentOp(); + assert(user && "Failed to find common parent of this user and c"); + return user; + } + + for (auto &op : reverse(c->getBlock()->getOperations())) { + auto asyncTasks = getAsyncTaskIds(&op); + if (asyncTasks.size() == 1 && asyncTasks[0] == consumerAsyncTaskId) + return &op; + } + + return nullptr; + }; + + // Go through each channel group. + for (auto kv : channelsGroupedByConsumers) { + // Find head and tail ops. + DenseSet producerOps; + DenseSet consumerOps; + for (auto &c : kv.second) { + auto pcOp = copyOpMap.find(c)->second; + producerOps.insert(pcOp.first); + consumerOps.insert(pcOp.second); + consumerOps.insert(c->getDstOp()); + } + + // Find head producer + auto producerBlock = kv.second.front()->getSrcOp()->getBlock(); + Operation *headProducer = nullptr; + for (auto &op : producerBlock->getOperations()) { + if (producerOps.count(&op)) { + headProducer = &op; + break; + } + } + // Find tail producer + Operation *tailProducer = nullptr; + for (auto &op : reverse(producerBlock->getOperations())) { + if (producerOps.count(&op)) { + tailProducer = &op; + break; + } + } + + // Find head consumer and tail consumer + auto consumerBlock = kv.second.front()->getDstOp()->getBlock(); + Operation *headConsumer = nullptr; + for (auto &op : consumerBlock->getOperations()) { + if (consumerOps.count(&op)) { + headConsumer = &op; + break; + } + } + Operation *tailConsumer = nullptr; + for (auto &op : reverse(consumerBlock->getOperations())) { + if (consumerOps.count(&op)) { + tailConsumer = &op; + break; + } + } + + // We have one set of tokens for each channel group. + auto tokens = tokenMap.find(kv.second.front())->second; + auto masterChannel = kv.getFirst(); + + SmallVector asyncTaskP; + asyncTaskP.push_back(masterChannel->relation.first); + SmallVector &asyncTaskC = masterChannel->relation.second; + SmallVector asyncTasksPC = asyncTaskP; + asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskC.begin(), + asyncTaskC.end()); + + OpBuilderWithAsyncTaskIds builder(headProducer->getContext()); + if (auto funcOp = dyn_cast(headProducer->getParentOp())) { + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + } else { + builder.setInsertionPoint(headProducer->getParentOp()); + } + builder.setAsynTaskIdsFromArray(asyncTasksPC); + + Value bufferIdx; + Value phase = Value(); + if (auto forOp = headProducer->getParentOfType()) { + // We already added phase, bufferIdx to the ForOp. + auto tSize = forOp.getBody()->getArguments().size(); + assert(tSize >= 2); + bufferIdx = forOp.getBody()->getArguments().back(); + phase = forOp.getBody()->getArgument(tSize - 2); // next to last argument + } else { + // Producer is not in a ForOp, create phase and bufferIdx here. + bufferIdx = builder.createWithAsyncTaskIds( + headProducer->getLoc(), 0, 32); + phase = builder.createWithAsyncTaskIds( + headProducer->getLoc(), 0, 1); + } + + builder.setAsynTaskIdsFromArray(masterChannel->relation.first); + for (auto token : tokens) { + // Insert ProducerAcquireOp before the producer. + builder.setInsertionPoint(headProducer); + builder.createWithAsyncTaskIds( + headProducer->getLoc(), token.second, bufferIdx, phase); + + // Insert ProducerCommitOp if producer is LoadOp. For TMA, TMA lowering + // will handle the ProducerCommit. + if (!isa(headProducer)) { + builder.setInsertionPointAfter(tailProducer); + builder.createWithAsyncTaskIds( + tailProducer->getLoc(), token.second, bufferIdx); + } + } + + for (auto token : tokens) { + builder.setAsynTaskIdsFromArray(token.first); + // Insert ConsumerWaitOp + if (!isa(headProducer)) { + auto consumerWaitPoint = getSameLevelOp(headProducer, headConsumer); + builder.setInsertionPoint(consumerWaitPoint); + builder.createWithAsyncTaskIds( + headConsumer->getLoc(), token.second, bufferIdx, phase); + } + + // Insert ConsumerReleaseOp. + auto consumerReleasePoint = + consumerReleaseHeuristic(tailProducer, tailConsumer, token.first); + builder.setInsertionPointAfter(consumerReleasePoint); + builder.createWithAsyncTaskIds( + consumerReleasePoint->getLoc(), token.second, bufferIdx); + } + + SmallVector tmaLoads; + SmallVector buffers; + DenseMap producerCopyMap; + // Go through all channels in this channel group. + for (auto &c : kv.second) { + if (auto tmaLoad = + dyn_cast(c->getSrcOp())) { + tmaLoads.push_back(tmaLoad); + buffers.push_back(bufferMap.find(c)->second); + } + } + + // Optimize TMA loads. + if (tmaLoads.size() > 0) { + auto barrierAllocs = barrierAllocMap.find(kv.second.front())->second; + // TODO: we created one Alloc for each consumer taskId, but here, we + // only use the first Alloc. + auto barrierAlloc = barrierAllocs.begin()->second; + optimizeTMALoads(builder, tmaLoads, buffers, barrierAlloc, bufferIdx, + bufferIdx, phase, headProducer, headConsumer); + } + } +} + +// Lower producers for channels. Here channels are grouped in +// "channelsGroupedByProducers" +void insertAsyncCopy( + triton::FuncOp funcOp, + const DenseMap> + &channelsGroupedByProducers, + const DenseMap &bufferMap, + DenseMap> ©OpMap) { + // For each producer op, create a async_copy or local_store from the producer + // to the buffer. Create a local_load from the buffer at the dominating + // consumer. + mlir::DominanceInfo dom(funcOp); + + for (auto kv : channelsGroupedByProducers) { + // Finding the dominating channel if possible. + std::unordered_set mutuallyNonDominatingChannels; + for (auto &c : kv.second) { + // check if c is dominating all other previous channels. + auto it = mutuallyNonDominatingChannels.begin(); + while (it != mutuallyNonDominatingChannels.end()) { + auto channel = *it; + if (dom.properlyDominates(c->getDstOp(), channel->getDstOp())) { + it = mutuallyNonDominatingChannels.erase(it); + } else if (dom.properlyDominates(channel->getDstOp(), c->getDstOp())) { + break; + } else { + ++it; + } + } + if (it == mutuallyNonDominatingChannels.end()) + mutuallyNonDominatingChannels.insert(c); + } + + auto srcOp = kv.getFirst()->getSrcOp(); + Value bufferIdx; + Value phase = Value(); + if (auto forOp = srcOp->getParentOfType()) { + // We already added phase, bufferIdx to the ForOp. + auto tSize = forOp.getBody()->getArguments().size(); + assert(tSize >= 2); + bufferIdx = forOp.getBody()->getArguments().back(); + } else { + // Producer is not in a ForOp, create phase and bufferIdx here which will + // be used by both producer and consumers. + OpBuilderWithAsyncTaskIds builder(srcOp); + SmallVector asyncTasksPC = getAsyncTaskIds(srcOp); + for (auto channel : mutuallyNonDominatingChannels) + asyncTasksPC.append(getAsyncTaskIds(channel->getDstOp())); + builder.setAsynTaskIdsFromArray(asyncTasksPC); + bufferIdx = builder.createWithAsyncTaskIds( + srcOp->getLoc(), 0, 32); + } + + assert(mutuallyNonDominatingChannels.size() == 1 && + "conditional consumers not supported"); + + auto domininatingChannel = *mutuallyNonDominatingChannels.begin(); + std::pair producerConsumerOps{nullptr, nullptr}; + + // No need to create async copy for TMA load which will be handled in + // insertAsyncComm. + if (isa(srcOp)) { + producerConsumerOps = {srcOp, domininatingChannel->getDstOp()}; + } else if (isa(srcOp)) { + SmallVector asyncTasksPC = getAsyncTaskIds(srcOp); + asyncTasksPC.append(getAsyncTaskIds(domininatingChannel->getDstOp())); + // After createAsyncCopy, c->getSrcOp()/headProducer are no longer + // valid. + producerConsumerOps = createAsyncCopy(bufferMap, domininatingChannel, + domininatingChannel->getSrcOp(), + asyncTasksPC, bufferIdx, bufferIdx); + } else { + assert(!isa(srcOp) && + "LocalLoadOp buffer should be reused"); + producerConsumerOps = + createLocalCopy(bufferMap, domininatingChannel, bufferIdx, bufferIdx); + } + + for (auto &channel : kv.second) { + copyOpMap[channel] = producerConsumerOps; + } + } +} + +void foldLocalLoads(triton::FuncOp funcOp) { + // If loadResult has a single use which is LocalAlloc, we can get rid of + // sharedLoad and replace all uses of LocalAlloc with viewLoad. + DenseMap opsToReplace; + funcOp.walk([&](ttg::LocalAllocOp localAlloc) { + if (auto src = localAlloc.getSrc()) { + if (auto localLoad = dyn_cast(src.getDefiningOp())) { + // Only fold within the same tasks + if (getAsyncTaskIds(localLoad) == getAsyncTaskIds(localAlloc)) { + opsToReplace[localAlloc] = localLoad.getSrc(); + } + } + } + }); + OpBuilderWithAsyncTaskIds builder(funcOp.getContext()); + for (auto kv : opsToReplace) + replaceUsesAndPropagateType(builder, kv.getFirst(), kv.getSecond()); +} + +class TritonGPUWSCodePartitionPass + : public impl::TritonGPUWSCodePartitionBase { +public: + using impl::TritonGPUWSCodePartitionBase< + TritonGPUWSCodePartitionPass>::TritonGPUWSCodePartitionBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + // Disable code partitioning when numBuffers is 0. + if (numBuffers == 0) + return; + + // Step 1: collect all communications between producers and consumers. + SmallVector> channelsOrigin; + collectAsyncChannels(channelsOrigin, funcOp, numBuffers); + SmallVector channels; + for (const auto &c : channelsOrigin) { + channels.push_back(c.get()); + } + if (channels.empty()) { + return; + } + + // Step 2: group channels + // - each entry of the channelsGroupedByProducers is keyed by the srcOp. + // - each entry of the channelsGroupedByConsumers is keyed by the dstOp. + DenseMap> channelsGroupedByProducers; + DenseMap> channelsGroupedByConsumers; + SmallVector orderedChannels; + groupChannels(channels, channelsGroupedByProducers, + channelsGroupedByConsumers, orderedChannels); + + // Step 3: reorder producer ops and the backward slices of the producer ops. + reorderProducerOps(channels); + + // Step 4: find top-level ops that contain a channel, also create new ForOps + // by adding phase and bufferIdx to the original ForOps, erase the original + // ForOps. + SmallVector asyncTaskTopOps = + getTaskTopRegion(funcOp, channels); + // Update mapToRepresenting that maps a channel to the representing channel + // in the sharing group. + DenseMap mapToRepresenting; + SmallVector opsWithBufferReuse; + reuseBuffers(asyncTaskTopOps, channels, mapToRepresenting, + opsWithBufferReuse); + // Use and update opsWithBufferReuse. + appendBufferIdxArgs(asyncTaskTopOps, numBuffers, channels, + mapToRepresenting, opsWithBufferReuse); + LLVM_DEBUG({ + LDBG("\n\nafter appendBufferIdxArgs"); + funcOp.dump(); + }); + + // Step 5: Create buffers. An array of buffers for each channel. Update + // channelReuse that maps from a representing channel to the group of + // channels that share buffers. + DenseMap> channelReuse; + DenseMap bufferMap = + createBuffer(channelsGroupedByProducers, funcOp, numConsumerGroups, + mapToRepresenting, channelReuse); + LLVM_DEBUG({ + LDBG("\n\nafter createBuffer"); + funcOp.dump(); + }); + + // Step 6: Lower the loads. Also add local copy ops for non-load + // producers. + DenseMap> copyOpMap; + insertAsyncCopy(funcOp, channelsGroupedByProducers, bufferMap, copyOpMap); + LLVM_DEBUG({ + LDBG("\n\nwith async copy"); + funcOp.dump(); + }); + + // Step 7: Create tokens. A set of tokens for each group of channels for + // each channel. + DenseMap> barrierAllocMap; + DenseMap> tokenMap = createToken( + channelsGroupedByConsumers, orderedChannels, funcOp, numConsumerGroups, + copyOpMap, channelReuse, barrierAllocMap); + LLVM_DEBUG({ + LDBG("\n\nafter createToken"); + funcOp.dump(); + }); + + // Step 8: add async communication ops (ProducerAcquire etc). Also lower + // TMA loads. + insertAsyncComm(funcOp, channelsGroupedByConsumers, tokenMap, + barrierAllocMap, bufferMap, copyOpMap, numConsumerGroups); + LLVM_DEBUG({ + LDBG("\n\nwith SyncOps"); + funcOp.dump(); + }); + + // If loadResult has a single use which is LocalAlloc, we can get rid of + // sharedLoad and replace all uses of LocalAlloc with viewLoad. + foldLocalLoads(funcOp); + LLVM_DEBUG({ + LDBG("\n\nsimplify localLoad + localAlloc"); + funcOp.dump(); + }); + + auto ret = SpecializeRegion(funcOp, regDecProducer, regIncConsumer); + LLVM_DEBUG({ + LDBG("\n\nwith SpecializeRegion"); + funcOp.dump(); + }); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + LLVM_DEBUG({ + LDBG("post pass"); + getOperation()->dump(); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp new file mode 100644 index 000000000..3a5b51a8c --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp @@ -0,0 +1,956 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define DEBUG_TYPE "tritongpu-warp-spec-data-partition" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +static bool oneVecCoversTheOther(SmallVector &one, + SmallVector &other) { + // Every element of other appears in one. + for (AsyncTaskId t : other) { + // If t doesn't appear in one, return false. + bool found = false; + for (AsyncTaskId t2 : one) { + if (t2 == t) { + found = true; + break; + } + } + if (!found) + return false; + } + return true; +} + +// Make sure the def chain contains the right taskId. +void fixTaskId(triton::FuncOp &funcOp) { + funcOp.walk([&](Operation *op) { + auto asyncTaskIds = getAsyncTaskIds(op); + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (!defOp) + continue; + // Do not update loads. + if (isa(defOp)) + continue; + auto defTaskIds = getAsyncTaskIds(defOp); + // Make sure defTaskIds cover asyncTaskIds. Call addAsyncTaskIds if + // necessary. + if (!oneVecCoversTheOther(defTaskIds, asyncTaskIds)) { + // Skip control flow ops. + if (isa(op)) + continue; + // Const ops with same value but different task ids can be folded. + if (defOp->getDialect()->getNamespace() == "arith") { + LLVM_DEBUG({ + LDBG("backward fixing taskId for"); + defOp->dump(); + }); + addAsyncTaskIds(defOp, asyncTaskIds); + LLVM_DEBUG({ + LDBG("resulting"); + defOp->dump(); + }); + } + } + if (operand.hasOneUse() && + !oneVecCoversTheOther(asyncTaskIds, defTaskIds)) { + // YieldOp may lose task attribute during MLIR canonicalization. + if (isa(op)) { + LLVM_DEBUG({ + LDBG("forward fixing taskId for"); + defOp->dump(); + }); + addAsyncTaskIds(op, defTaskIds); + LLVM_DEBUG({ + LDBG("resulting"); + defOp->dump(); + }); + } + } + } + }); +} + +static SmallVector getShape(Value v) { + auto type = v.getType(); + if (auto type = dyn_cast(v.getType())) { + return {type.getShape().begin(), type.getShape().end()}; + } else if (auto type = dyn_cast(v.getType())) { + return {type.getShape().begin(), type.getShape().end()}; + } else if (auto type = dyn_cast(v.getType())) { + return {type.getBlockType().getShape().begin(), + type.getBlockType().getShape().end()}; + } + return {}; +} + +bool needToSlice(Value v, int dim, int size) { + auto shape = getShape(v); + return shape.size() > dim && shape[dim] > size; +} + +void getBackwardSliceToPartition(Value root, unsigned dim, int sliceSize, + SetVector &backwardSlice) { + SmallVector queue = {root}; + while (!queue.empty()) { + auto v = queue.back(); + queue.pop_back(); + if (!needToSlice(v, dim, sliceSize)) + continue; + if (auto op = v.getDefiningOp()) { + if (backwardSlice.insert(op)) { + if (op->hasTrait() || + isa(op)) { + for (Value operand : op->getOperands()) + queue.push_back(operand); + } else if (auto dotOp = dyn_cast(op)) { + queue.push_back(dim == 0 ? dotOp.getA() : dotOp.getB()); + queue.push_back(dotOp.getC()); + } else if (auto dotOp = dyn_cast(op)) { + queue.push_back(dim == 0 ? dotOp.getA() : dotOp.getB()); + queue.push_back(dotOp.getD()); + } else if (auto tensorDescOp = dyn_cast(op)) { + continue; + } else if (auto ifOp = dyn_cast(op)) { + // track yield value + // find result index of v + unsigned resultIndex = 0; + for (int i = 0; i < op->getNumResults(); ++i) { + if (op->getResult(i) == v) { + resultIndex = i; + break; + } + } + + auto thenYieldArg = ifOp.thenYield().getOperand(resultIndex); + backwardSlice.insert(ifOp.thenYield()); + queue.push_back(thenYieldArg); + auto elseYieldArg = ifOp.elseYield().getOperand(resultIndex); + backwardSlice.insert(ifOp.elseYield()); + queue.push_back(elseYieldArg); + } else { + llvm_unreachable("Unexpected op"); + } + } + } else { + assert(isa(v) && "value is not an operation or block "); + auto bbArg = cast(v); + Operation *bbAargOwner = bbArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(bbAargOwner)) { + // track initial value + auto initArg = forOp.getInitArgs()[bbArg.getArgNumber() - 1]; + queue.push_back(initArg); + // track yield value + auto yieldArg = forOp.getYieldedValues()[bbArg.getArgNumber() - 1]; + queue.push_back(yieldArg); + } + } + } +}; + +void getForwardSliceToPartition(Value root, unsigned dim, int sliceSize, + SetVector &forwardSlice) { + SmallVector queue = {root}; + llvm::SmallDenseSet seen; + while (!queue.empty()) { + auto v = queue.back(); + queue.pop_back(); + if (!seen.insert(v).second) + continue; + if (!needToSlice(v, dim, sliceSize)) + continue; + getForwardSlice(v, &forwardSlice); + for (Operation *op : forwardSlice) { + if (op->getNumResults() > 0) + seen.insert(op->getResult(0)); + if (auto yieldOp = dyn_cast(op)) { + auto parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + if (seen.count(operand.get())) { + queue.push_back(parentOp->getResult(operand.getOperandNumber())); + forwardSlice.insert(parentOp); + } + } + } + } + } +}; + +// Compute a closure of all ops originated from or being dependent on by the +// root op. +void getSliceToPartition(Value root, unsigned dim, int sliceSize, + SetVector &slice) { + getBackwardSliceToPartition(root, dim, sliceSize, slice); + SetVector forwardSlice; + getForwardSliceToPartition(root, dim, sliceSize, forwardSlice); + slice.insert(forwardSlice.begin(), forwardSlice.end()); + for (auto op : forwardSlice) { + if (op->hasTrait() || + isa(op)) { + for (OpOperand &operand : op->getOpOperands()) { + getBackwardSliceToPartition(operand.get(), dim, sliceSize, slice); + } + } else if (auto dotOp = dyn_cast(op)) { + getBackwardSliceToPartition(dim == 0 ? dotOp.getA() : dotOp.getB(), dim, + sliceSize, slice); + getBackwardSliceToPartition(dotOp.getC(), dim, sliceSize, slice); + } else if (auto dotOp = dyn_cast(op)) { + getBackwardSliceToPartition(dim == 0 ? dotOp.getA() : dotOp.getB(), dim, + sliceSize, slice); + getBackwardSliceToPartition(dotOp.getD(), dim, sliceSize, slice); + } + } +} + +struct DataPartitionScheme { + // Which dimension to partition. For dot, dim 0 means along M dimension, 1 + // means along N dimensiont. + unsigned partitionDim = 0; + unsigned numPartitions = 0; + SetVector ops; +}; + +bool computePartitionScheme(triton::FuncOp &funcOp, + DataPartitionScheme &partitionScheme) { + // Do not partition producer tasks + + // Use dot to drive the partition + SetVector dots; + + // check all dot ops that have more than one async task id + funcOp.walk([&](Operation *op) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() > 1) { + if (isa(op)) { + dots.insert(op); + } + } + }); + + // Checking if all dots can be partitioned in the same way + int numWarps = mlir::triton::gpu::lookupNumWarps(funcOp); + for (auto op : dots) { + // partition along M first, otherwise along N + Value opndA, opndB, accumulator; + + if (auto dotOp = dyn_cast(op)) { + opndA = dotOp.getA(); + opndB = dotOp.getB(); + accumulator = dotOp.getD(); + } else if (auto dotOp = dyn_cast(op)) { + opndA = dotOp.getA(); + opndB = dotOp.getB(); + accumulator = dotOp.getD(); + } + + auto dotType = accumulator.getType(); + LLVM_DEBUG({ + LDBG("Computing partition scheme for"); + op->dump(); + LDBG("\n"); + }); + auto shapePerCTA = getShapePerCTA(dotType); + if (shapePerCTA.size() != 2) { + LDBG("partition not possible: shapePerCTA " << shapePerCTA.size()); + return false; + } + auto asyncTaskIds = getAsyncTaskIds(op); + int sliceSizeM = shapePerCTA[0] / asyncTaskIds.size(); + int sliceSizeN = shapePerCTA[1] / asyncTaskIds.size(); + int partitionDim, partitionSize; + Value partitionOperand; + + if (sliceSizeM >= 64) { + LLVM_DEBUG({ LDBG("partition along M\n"); }); + partitionDim = 0; + partitionSize = sliceSizeM; + partitionOperand = opndA; + } else if (sliceSizeN >= 256) { + LLVM_DEBUG({ LDBG("partition along N\n"); }); + partitionDim = 1; + partitionSize = sliceSizeN; + partitionOperand = opndB; + } else { + LDBG("partition not possible: " << sliceSizeM << " " << sliceSizeN); + return false; + } + + if (partitionScheme.numPartitions == 0) { + partitionScheme.partitionDim = partitionDim; + partitionScheme.numPartitions = asyncTaskIds.size(); + } else { + if (partitionScheme.partitionDim != partitionDim || + partitionScheme.numPartitions != asyncTaskIds.size()) { + LDBG("partition not possible, in conflict with previous partition\n"); + return false; + } + } + + // Partition the slice closure + SetVector &slice = partitionScheme.ops; + getSliceToPartition(accumulator, partitionDim, partitionSize, slice); + + LLVM_DEBUG({ + partitionOperand.dump(); + LDBG("\n"); + LDBG(" slice:"); + for (auto &op : slice) { + op->dump(); + } + LDBG("\n"); + }); + + for (auto op : partitionScheme.ops) { + auto opTaskIds = getAsyncTaskIds(op); + // skip check for control flow ops + if (isa(op)) + continue; +#if 0 + if (opTaskIds.size() > partitionScheme.numPartitions) { + LLVM_DEBUG({ + LDBG("partition not possible: numPartitions" << opTaskIds.size() << " " << partitionScheme.numPartitions); + op->dump(); + }); + return false; + } +#endif + } + } + + return !partitionScheme.ops.empty(); +} + +Operation *sliceOp(Value v, int offset, IRMapping &mappings, + IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme); + +Operation *sliceOp(Operation *op, int offset, IRMapping &mappings, + IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme) { + if (!partitionScheme.ops.contains(op)) + return op; + if (mappings.contains(op)) + return mappings.lookupOrNull(op); + if (reverseMappings.contains(op)) + return op; + + LLVM_DEBUG({ + LDBG("slicing:"); + op->dump(); + LDBG("\n"); + }); + + int dim = partitionScheme.partitionDim; + int numOfPartitions = partitionScheme.numPartitions; + + auto asyncTaskIds = getAsyncTaskIds(op); + SmallVector sliceTaskIds; + if (asyncTaskIds.size() == numOfPartitions) { + // We are slicing the op for consumer only + sliceTaskIds.push_back(asyncTaskIds[offset]); + } else if (asyncTaskIds.size() == 1) { + // We are slicing the op for producer only + sliceTaskIds.push_back(asyncTaskIds.front()); + } else if (asyncTaskIds.size() > numOfPartitions) { + // We are slicing the op for both producer and consumer + sliceTaskIds.push_back(asyncTaskIds.front()); + sliceTaskIds.push_back(asyncTaskIds[offset + 1]); + } else { + llvm_unreachable("Unexpected asyncTaskIds.size()"); + } + + OpBuilderWithAsyncTaskIds builder(op->getContext()); + builder.setAsynTaskIdsFromArray(sliceTaskIds); + auto cloneAndSetResultType = [&](Operation *op) { + builder.setInsertionPoint(op); + auto newOp = builder.clone(*op, mappings); + setAsyncTaskIds(newOp, sliceTaskIds); + mappings.map(op, newOp); + reverseMappings.map(newOp, op); + // set result shape + if (!op->getResults().empty()) { + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + if (auto type = dyn_cast(v.getType())) { + SmallVector shape{type.getShape().begin(), + type.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + // change encoding for ttng.tensor_memory_encoding to match gen5. + if (auto tmem = dyn_cast( + type.getEncoding())) { + Attribute accEncoding = + triton::nvidia_gpu::TensorMemoryEncodingAttr::get( + builder.getContext(), + dim == 0 ? tmem.getBlockM() / 2 : tmem.getBlockM(), + dim == 1 ? tmem.getBlockN() / 2 : tmem.getBlockN(), + tmem.getUnpacked(), tmem.getCTASplitM(), tmem.getCTASplitN()); + auto newType = + MemDescType::get(shape, type.getElementType(), accEncoding, + type.getMemorySpace(), type.getMutableMemory()); + newV.setType(newType); + } else { + auto newType = + MemDescType::get(shape, type.getElementType(), type.getEncoding(), + type.getMemorySpace(), type.getMutableMemory()); + newV.setType(newType); + } + } else if (auto type = dyn_cast(v.getType())) { + SmallVector shape{type.getShape().begin(), + type.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newType = RankedTensorType::get(shape, type.getElementType(), + type.getEncoding()); + newV.setType(newType); + } else if (auto type = dyn_cast(v.getType())) { + auto blockType = type.getBlockType(); + SmallVector shape{blockType.getShape().begin(), + blockType.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newBlockType = + RankedTensorType::get(shape, blockType.getElementType()); + auto newType = TensorDescType::get(builder.getContext(), newBlockType); + newV.setType(newType); + } + + mappings.map(v, newV); + reverseMappings.map(newV, v); + } + return newOp; + }; + + // slice operands first + Operation *newOp; + if (op->hasTrait() || + isa(op)) { + for (Value operand : op->getOperands()) + sliceOp(operand, offset, mappings, reverseMappings, partitionScheme); + newOp = cloneAndSetResultType(op); + } else if (auto tmemLdOp = dyn_cast(op)) { + for (Value operand : op->getOperands()) + sliceOp(operand, offset, mappings, reverseMappings, partitionScheme); + auto srcTy = mappings.lookupOrNull(tmemLdOp.getSrc()).getType(); + auto type = cast(srcTy); + auto tmem = cast(type.getEncoding()); + + RankedTensorType oldRetType = tmemLdOp.getType(); + auto retShapePerCTA = getShapePerCTA(oldRetType); + int numWarps = mlir::triton::gpu::lookupNumWarps(op); + auto CTALayout = getCTALayout(oldRetType.getEncoding()); + builder.setInsertionPoint(op); + // The source op is already sliced at this point, so srcTy, type, tmem is + // sliced. We use getTmemCompatibleLayout to get a block layout that is for + // the sliced tmem here. + Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout( + tmem.getBlockM(), tmem.getBlockN(), retShapePerCTA, numWarps, + CTALayout); + + // oldRetType is the desired output, we slice it and convert from the + // compatible layout to the sliced desired output. + SmallVector shape{oldRetType.getShape().begin(), + oldRetType.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newAccType = RankedTensorType::get(shape, oldRetType.getElementType(), + newDistributedEncoding); + auto ld = builder.createWithAsyncTaskIds( + op->getLoc(), newAccType, mappings.lookupOrNull(tmemLdOp.getSrc())); + + auto newType = RankedTensorType::get(shape, oldRetType.getElementType(), + oldRetType.getEncoding()); + auto cvtOp = builder.createWithAsyncTaskIds(op->getLoc(), + newType, ld); + auto v = tmemLdOp->getResult(0); + auto newV = cvtOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + newOp = cvtOp; + } else if (auto tmemAllocOp = dyn_cast(op)) { + for (Value operand : op->getOperands()) + sliceOp(operand, offset, mappings, reverseMappings, partitionScheme); + // Check for src. + if (tmemAllocOp.getSrc()) { + // src is blocked layout. apply convert layout on src + auto srcTy = cast( + mappings.lookupOrNull(tmemAllocOp.getSrc()).getType()); + + // convert from srcTy to a compatible blocked layout. + auto retShapePerCTA = getShapePerCTA(srcTy); + int numWarps = mlir::triton::gpu::lookupNumWarps(op); + auto CTALayout = getCTALayout(srcTy.getEncoding()); + builder.setInsertionPoint(op); + + // calculate new tmem type. + auto retType = cast(tmemAllocOp.getType()); + SmallVector shape{retType.getShape().begin(), + retType.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto tmem = + cast(retType.getEncoding()); + auto accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get( + builder.getContext(), + dim == 0 ? tmem.getBlockM() / 2 : tmem.getBlockM(), + dim == 1 ? tmem.getBlockN() / 2 : tmem.getBlockN(), + tmem.getUnpacked(), tmem.getCTASplitM(), tmem.getCTASplitN()); + auto newType = MemDescType::get(shape, retType.getElementType(), + accEncoding, retType.getMemorySpace(), + retType.getMutableMemory()); + + Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout( + accEncoding.getBlockM(), accEncoding.getBlockN(), retShapePerCTA, + numWarps, CTALayout); + auto newAccType = RankedTensorType::get( + srcTy.getShape(), srcTy.getElementType(), newDistributedEncoding); + auto cvtOp = builder.createWithAsyncTaskIds( + op->getLoc(), newAccType, + mappings.lookupOrNull(tmemAllocOp.getSrc())); + + // replace tmemAllocOp with alloc, where the src is cvtOp. + auto alloc = + builder.createWithAsyncTaskIds( + op->getLoc(), newType, cvtOp); + + auto v = tmemAllocOp->getResult(0); + auto newV = alloc->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + newOp = alloc; + } else + newOp = cloneAndSetResultType(op); + } else if (auto constOp = dyn_cast(op)) { + builder.setInsertionPoint(op); + auto valAttr = cast(constOp.getValueAttr()); + auto valType = cast(valAttr.getType()); + SmallVector shape{valType.getShape().begin(), + valType.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newValType = valType.clone(shape); + auto newValAttr = valAttr.resizeSplat(newValType); + newOp = builder.createWithAsyncTaskIds(op->getLoc(), + newValAttr); + // Do not drop original task id as constant folding may lose one constant. + setAsyncTaskIds(newOp, getAsyncTaskIds(op)); + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } else if (auto makeRangeOp = dyn_cast(op)) { + builder.setInsertionPoint(op); + int newRangeStart = makeRangeOp.getStart(); + int newRangeEnd = makeRangeOp.getEnd(); + int sliceSize = (newRangeEnd - newRangeStart) / numOfPartitions; + newRangeStart += offset * sliceSize; + newRangeEnd = newRangeStart + sliceSize; + auto v = op->getResult(0); + auto type = cast(v.getType()); + auto newType = RankedTensorType::get({sliceSize}, builder.getI32Type(), + type.getEncoding()); + newOp = builder.createWithAsyncTaskIds( + op->getLoc(), newType, newRangeStart, newRangeEnd); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } else if (isa(op)) { + for (Value operand : op->getOperands()) + sliceOp(operand, offset, mappings, reverseMappings, partitionScheme); + // TODO: slice store base ptr + newOp = cloneAndSetResultType(op); + } else if (isa( + op)) { + SmallVector shape; + Value coordVal; + if (auto loadOp = dyn_cast(op)) { + sliceOp(loadOp.getDesc(), offset, mappings, reverseMappings, + partitionScheme); + coordVal = loadOp.getIndices()[dim]; + shape = getShape(loadOp.getResult()); + } else if (auto storeOp = dyn_cast(op)) { + sliceOp(storeOp.getDesc(), offset, mappings, reverseMappings, + partitionScheme); + coordVal = storeOp.getIndices()[dim]; + shape = getShape(storeOp.getSrc()); + } + auto newCoordVal = coordVal; + if (offset) { + builder.setInsertionPointAfter(coordVal.getDefiningOp()); + Value offsetVal = builder.createWithAsyncTaskIds( + op->getLoc(), offset * shape[dim] / numOfPartitions, 32); + newCoordVal = builder.createWithAsyncTaskIds( + op->getLoc(), coordVal, offsetVal); + mappings.map(coordVal, newCoordVal); + reverseMappings.map(newCoordVal, coordVal); + } + + newOp = cloneAndSetResultType(op); + if (isa(op)) { + // map load result + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } + } else if (auto tensorDescOp = dyn_cast(op)) { + newOp = cloneAndSetResultType(op); + } else if (auto dotOp = dyn_cast(op)) { + // Only hanlde A and accumulator + sliceOp(dim == 0 ? dotOp.getA() : dotOp.getB(), offset, mappings, + reverseMappings, partitionScheme); + sliceOp(dotOp.getC(), offset, mappings, reverseMappings, partitionScheme); + newOp = cloneAndSetResultType(op); + } else if (auto dotOp = dyn_cast(op)) { + // Only hanlde A and accumulator + sliceOp(dim == 0 ? dotOp.getA() : dotOp.getB(), offset, mappings, + reverseMappings, partitionScheme); + sliceOp(dotOp.getD(), offset, mappings, reverseMappings, partitionScheme); + newOp = cloneAndSetResultType(op); + } else if (auto forOp = dyn_cast(op)) { + // Add new loop arguments + SmallVector newLoopArgs; + for (auto initArg : forOp.getInitArgs()) + newLoopArgs.push_back(initArg); + DenseMap newArgIdices; + for (unsigned i = 0; i < forOp.getInitArgs().size(); i++) { + auto initArg = forOp.getInitArgs()[i]; + Value newInitArg; + auto newInitArgOp = + sliceOp(initArg, offset, mappings, reverseMappings, partitionScheme); + if (auto bbArg = dyn_cast(initArg)) { + // find the corresponding new block argument + Block *parentBlock = bbArg.getOwner(); + unsigned argIndex = parentBlock->getNumArguments(); + for (unsigned i = 0; i < parentBlock->getNumArguments(); ++i) { + if (parentBlock->getArgument(i) == bbArg) { + argIndex = i; + break; + } + } + assert(argIndex < parentBlock->getNumArguments() && + "new init argment not found"); + Region *parentRegion = parentBlock->getParent(); + Region &newParentRegion = + newInitArgOp->getRegion(parentRegion->getRegionNumber()); + newInitArg = parentRegion->getArgument(argIndex); + } else { + auto initArgOp = initArg.getDefiningOp(); + unsigned resultIndex = cast(initArg).getResultNumber(); + newInitArg = newInitArgOp->getResult(resultIndex); + } + + if (newInitArg != initArg) { + newLoopArgs.append({newInitArg}); + forOp.getBody()->insertArgument(forOp.getBody()->getNumArguments(), + newInitArg.getType(), forOp.getLoc()); + newArgIdices[i] = newLoopArgs.size() - 1; + } + } + + // Create newForOp and take the region of forOp + builder.setInsertionPoint(op); + auto newForOp = builder.createWithAsyncTaskIds( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newLoopArgs); + assert(newForOp.getRegionIterArgs().size() == + newForOp.getInitArgs().size()); + newForOp->setAttrs(forOp->getAttrs()); + partitionScheme.ops.insert(newForOp); + newOp = newForOp; + + // Replace forOp with newForOp + newForOp.getRegion().takeBody(forOp.getRegion()); + for (unsigned i = 0; i < forOp.getNumResults(); ++i) + forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i)); + op->setAttr("to_be_removed", builder.getUnitAttr()); + + // Map new loop arguments + for (auto argIndex : newArgIdices) { + Value v = newForOp.getResult(argIndex.first); + Value newV = newForOp.getResult(argIndex.second); + mappings.map(v, newV); + reverseMappings.map(newV, v); + + auto regionArg = newForOp.getRegionIterArg(argIndex.first); + auto newRegionArg = newForOp.getRegionIterArg(argIndex.second); + mappings.map(regionArg, newRegionArg); + reverseMappings.map(newRegionArg, regionArg); + } + } else if (auto ifOp = dyn_cast(op)) { + // Slice the yield op and update if results + auto thenYieldOp = ifOp.thenYield(); + auto elseYieldOp = ifOp.elseYield(); + auto newThenYieldOp = sliceOp(thenYieldOp, offset, mappings, + reverseMappings, partitionScheme); + sliceOp(elseYieldOp, offset, mappings, reverseMappings, partitionScheme); + assert(newThenYieldOp->getNumOperands() > ifOp->getNumResults() && + "no need to slice if op"); + // Clone ifOp with updated results but re-use the original regions. + builder.setInsertionPoint(op); + SmallVector newResultTypes; + for (auto thenResult : thenYieldOp.getResults()) { + newResultTypes.push_back(thenResult.getType()); + } + auto newIfOp = builder.create(ifOp.getLoc(), newResultTypes, + ifOp.getCondition()); + // Move the original regions to the cloned operation. + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + newOp = newIfOp; + newIfOp->setAttrs(ifOp->getAttrs()); + partitionScheme.ops.insert(newIfOp); + ifOp->setAttr("to_be_removed", builder.getUnitAttr()); + + // Replace ifOp with newIfOp + for (unsigned i = 0; i < ifOp.getNumResults(); ++i) + ifOp.getResult(i).replaceAllUsesWith(newIfOp.getResult(i)); + + // Map if results based on the mapping for yield + for (auto &v : thenYieldOp->getOpOperands()) { + auto newV = mappings.lookupOrNull(v.get()); + if (newV) { + int operandIndex = v.getOperandNumber(); + // find the corresponding operand index of newV in newYieldOp + int newOperandIndex = -1; + for (int i = 0; i < newThenYieldOp->getNumOperands(); ++i) { + if (newThenYieldOp->getOperand(i) == newV) { + newOperandIndex = i; + break; + } + } + assert(newOperandIndex >= 0 && "newV not found in newYieldOp"); + auto newResult = newIfOp.getResult(operandIndex); + auto newSlicedResult = newIfOp.getResult(newOperandIndex); + mappings.map(newResult, newSlicedResult); + reverseMappings.map(newSlicedResult, newResult); + } + } + } else if (auto yieldOp = dyn_cast(op)) { + int num = yieldOp.getNumOperands(); + for (int i = 0; i < num; i++) { + auto operand = yieldOp.getOperand(i); + sliceOp(operand, offset, mappings, reverseMappings, partitionScheme); + if (auto newV = mappings.lookupOrNull(operand)) + yieldOp->insertOperands(op->getNumOperands(), newV); + } + newOp = op; + } else if (auto reduceOp = dyn_cast(op)) { + assert(reduceOp.getAxis() != partitionScheme.partitionDim && + "reduce should not happen on the partitioned dimension"); + for (Value operand : op->getOperands()) + sliceOp(operand, offset, mappings, reverseMappings, partitionScheme); + newOp = cloneAndSetResultType(op); + // recursively set async task ids for child ops + newOp->walk( + [&](Operation *childOp) { setAsyncTaskIds(childOp, sliceTaskIds); }); + } else { + llvm_unreachable("unsupported op type"); + } + + LLVM_DEBUG({ + LDBG("resulting"); + newOp->dump(); + LDBG("\n"); + }); + mappings.map(op, newOp); + reverseMappings.map(newOp, op); + return newOp; +} + +Operation *sliceOp(Value v, int offset, IRMapping &mappings, + IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme) { + if (auto op = v.getDefiningOp()) { + return sliceOp(op, offset, mappings, reverseMappings, partitionScheme); + } else { + assert(isa(v) && "value is not an operation or block "); + auto bbArg = cast(v); + Operation *bbAargOwner = bbArg.getOwner()->getParentOp(); + return sliceOp(bbAargOwner, offset, mappings, reverseMappings, + partitionScheme); + } +} + +void doDeepCleanup(triton::FuncOp &funcOp, + DataPartitionScheme &partitionScheme) { + SmallVector opsToDelete; + DenseSet opsCanBeTriviallyDead; + + do { + opsToDelete.clear(); + opsCanBeTriviallyDead.clear(); + + // Identify root ops that are not used so to be deleted. + funcOp.walk([&](Operation *op) { + if (isa(op)) + return; + if (!partitionScheme.ops.contains(op)) + return; + + // Ignore the side effect of ops that are already sliced. The + // resulting ops preserve the side effect. + if (!isMemoryEffectFree(op)) + opsCanBeTriviallyDead.insert(op); + + bool notUsed = true; + for (auto result : op->getResults()) { + if (!result.getUsers().empty()) { + notUsed = false; + break; + } + } + if (notUsed) + opsToDelete.push_back(op); + }); + + LLVM_DEBUG({ + LDBG("opsToDelete:\n"); + for (auto op : opsToDelete) { + LDBG("op: "); + op->dump(); + } + LDBG("\n"); + }); + + if (opsToDelete.empty()) + return; + + // Delete root ops. + for (auto op : opsToDelete) { + partitionScheme.ops.remove(op); + op->erase(); + } + + LLVM_DEBUG({ + LDBG("prior to loop arg deletion:"); + funcOp.dump(); + }); + + // delete block arguments + RewritePatternSet cleanUpPatterns(funcOp.getContext()); + populateForOpDeadArgumentElimination(cleanUpPatterns, + opsCanBeTriviallyDead); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns, + funcOp.getContext()); + scf::IfOp::getCanonicalizationPatterns(cleanUpPatterns, + funcOp.getContext()); + if (applyPatternsGreedily(funcOp, std::move(cleanUpPatterns)).failed()) { + llvm_unreachable("failed to clean up"); + // signalPassFailure(); + } + } while (!opsToDelete.empty()); +} + +void partitionTasks(triton::FuncOp &funcOp, int numConsumerGroups) { + + // op -> (partition dim, num of partitions) + DataPartitionScheme partitionScheme; + if (!computePartitionScheme(funcOp, partitionScheme)) { + if (numConsumerGroups > 1) + llvm::errs() << "computePartitionScheme failed when requested\n"; + return; + } + + for (int i = 0; i < partitionScheme.numPartitions; i++) { + IRMapping mappings, reverseMappings; + + LLVM_DEBUG({ LDBG("partitioning op for task " << i << ":\n"); }); + + // TODO: compute a topological order for partitionScheme.ops and + // slice in that order. + int numOps = partitionScheme.ops.size(); + for (int j = 0; j < numOps; j++) { + auto op = partitionScheme.ops[j]; + sliceOp(op, i, mappings, reverseMappings, partitionScheme); + } + + // clean up + LLVM_DEBUG({ + LDBG("prior to clean up:"); + funcOp.dump(); + }); + SmallVector opsToDelete; + for (auto op : partitionScheme.ops) { + if (op->hasAttr("to_be_removed")) + opsToDelete.push_back(op); + } + for (auto op : opsToDelete) { + partitionScheme.ops.remove(op); + op->erase(); + } + } + + LLVM_DEBUG({ + LDBG("prior to final cleanup:"); + funcOp.dump(); + }); + + // Make sure original ops are not used + doDeepCleanup(funcOp, partitionScheme); + + // Make sure original ops are not used + LLVM_DEBUG({ + LDBG("after partition"); + funcOp.dump(); + LDBG("\n"); + }); + fixTaskId(funcOp); +} + +#define GEN_PASS_DEF_TRITONGPUWSDATAPARTITION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUWSDataPartitionPass + : public impl::TritonGPUWSDataPartitionBase { +public: + using impl::TritonGPUWSDataPartitionBase< + TritonGPUWSDataPartitionPass>::TritonGPUWSDataPartitionBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + if (numConsumerGroups == 0) + return; + partitionTasks(funcOp, numConsumerGroups); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp new file mode 100644 index 000000000..8adbdee5d --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp @@ -0,0 +1,274 @@ +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +#include + +#include "mlir/IR/OperationSupport.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define DEBUG_TYPE "tritongpu-warp-spec-lowering" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +static Value createThreadIdOp(OpBuilder &builder, Location loc) { + Value threadId = builder.create<::mlir::gpu::ThreadIdOp>( + loc, builder.getIndexType(), ::mlir::gpu::Dimension::x); + auto cast = builder.create( + loc, TypeRange{builder.getIntegerType(32)}, ValueRange{threadId}); + return cast.getResult(0); +} + +// Lower to use GetCanonicalWarpIdOp. +// In Hopper, each task is a warpgroup consisting of 4 warps. +static const int WARPS_PER_TASK = 4; +static const int THREADS_PER_TASK = 128; +void lowerGetAsyncTaskIdOp(Operation *parentOp, int numConsumerGroups) { + DenseSet eraseOps; + parentOp->walk([&](ttng::GetAsyncTaskIdOp op) { + auto loc = op.getLoc(); + OpBuilder builder(op); + Value _4 = builder.create(loc, WARPS_PER_TASK, 32); + Value warpId = builder.create(loc); + Value asyncTaskId = builder.create(loc, warpId, _4); + op.getResult().replaceAllUsesWith(asyncTaskId); + + LLVM_DEBUG({ + LDBG("erasing GetAsyncTask"); + op->dump(); + }); + eraseOps.insert(op); + }); + for (Operation *op : eraseOps) + op->erase(); +} + +Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op, + bool emptyBarrier) { + auto loc = op->getLoc(); + assert(isa(op) || isa(op)); + Value curPhase; + if (auto acq = dyn_cast(op)) + curPhase = acq.getPhase(); + else if (auto wait = dyn_cast(op)) + curPhase = wait.getPhase(); + if (emptyBarrier) { + // curPhase = curPhase xor True for emptyBarrier. + Value _1_1b = builder.create(loc, 1, 1); + curPhase = builder.create(loc, curPhase, _1_1b); + } + LLVM_DEBUG(curPhase.dump()); + return curPhase; +} + +void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op, + Value bufferEmpty) { + auto loc = op.getLoc(); + Value phase = getMBarrierPhaseBit(builder, op, true); + auto i32Ty = builder.getIntegerType(32); + phase = builder.create(loc, i32Ty, phase); + auto waitOp = builder.create(loc, bufferEmpty, phase); + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(waitOp, getAsyncTaskIds(op.getOperation())); +} + +void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op, + Value bufferFull, ttng::TokenLoadType loadType) { + auto loc = op.getLoc(); + int txCnt = 0; + ttng::MBarrierArriveOp arriveOp; + + if (loadType == ttng::TokenLoadType::TMALoadOp) { + // Only thread 0 arrives for TMA load. + Value _0 = builder.create(loc, 0, 32); + Value threadId = createThreadIdOp(builder, loc); + Value pred = builder.create(loc, arith::CmpIPredicate::eq, + threadId, _0); + arriveOp = builder.create( + loc, bufferFull, pred, /*remoteCTAId*/ nullptr, /*trackAsyncOp*/ false, + txCnt); + } else { + // Each thread arrives. + Value pred = builder.create(loc, 1, 1); + arriveOp = builder.create( + loc, bufferFull, pred, /*remoteCTAId*/ nullptr, /*trackAsyncOp*/ true, + txCnt); + } + + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation())); +} + +void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op, + Value bufferFull) { + auto loc = op.getLoc(); + Value phase = getMBarrierPhaseBit(builder, op, false); + auto i32Ty = builder.getIntegerType(32); + phase = builder.create(loc, i32Ty, phase); + auto waitOp = builder.create(loc, bufferFull, phase); + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(waitOp, getAsyncTaskIds(op.getOperation())); +} + +void processConsumerReleaseOp(OpBuilder &builder, ttng::ConsumerReleaseOp op, + Value bufferEmpty, int numCTAs) { + auto loc = op.getLoc(); + auto arriveOp = builder.create( + loc, bufferEmpty, nullptr, nullptr, false, 0); + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation())); +} + +void lowerTokenOperations(Operation *parentOp, int numCTAs, + int numConsumerGroups) { + SmallVector deprecatedOps; + parentOp->walk([&](ttng::CreateTokenOp createTokenOp) { + ttng::TokenLoadType loadType = createTokenOp.getLoadType(); + MLIRContext *context = createTokenOp.getContext(); + OpBuilder builder(createTokenOp); + Location loc = createTokenOp.getLoc(); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + auto barrierCTALayout = + ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = ttg::SwizzledSharedEncodingAttr::get( + context, 1, 1, 1, {0}, barrierCTALayout); + Type barrierMemDescType = + ttg::MemDescType::get({createTokenOp.getNum()}, builder.getI64Type(), + barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = + ttg::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value bufferFullArray = builder.create( + loc, barrierMemDescType, Value()); + Value bufferEmptyArray = builder.create( + loc, barrierMemDescType, Value()); + + for (unsigned i = 0; i < createTokenOp.getNum(); i++) { + Value idx = builder.create(loc, i, 32); + Value barrierFullView = builder.create( + loc, singleBarrierMemDescType, bufferFullArray, idx); + unsigned bufferFullCount = + loadType == ttng::TokenLoadType::TMALoadOp ? 1 : THREADS_PER_TASK; + builder.create(loc, barrierFullView, + bufferFullCount); + + Value barrierEmptyView = builder.create( + loc, singleBarrierMemDescType, bufferEmptyArray, idx); + builder.create(loc, barrierEmptyView, + THREADS_PER_TASK); + } + + assert(numCTAs == 1 && "remote CTA is not supported yet"); + builder.create(loc); + + // Helper function for extracting one index from bufferFullArray. + auto extractBufferFull = [&](Location loc, Value idx) -> Value { + return builder.create( + loc, singleBarrierMemDescType, bufferFullArray, idx); + }; + + // Helper function for extracting one index from bufferEmptyArray. + auto extractBufferEmpty = [&](Location loc, Value idx) -> Value { + return builder.create( + loc, singleBarrierMemDescType, bufferEmptyArray, idx); + }; + + // Process token users: ProducerAcquireOp, ProducerCommitOp, ConsumerWaitOp, + // and ConsumerReleaseOp. + for (Operation *user : createTokenOp.getResult().getUsers()) { + auto loc = user->getLoc(); + builder.setInsertionPoint(user); + if (auto op = dyn_cast(user)) { + Value bufferEmpty = extractBufferEmpty(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferEmpty.getDefiningOp(), getAsyncTaskIds(user)); + processProducerAcquireOp(builder, op, bufferEmpty); + } else if (auto op = dyn_cast(user)) { + Value bufferFull = extractBufferFull(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferFull.getDefiningOp(), getAsyncTaskIds(user)); + processProducerCommitOp(builder, op, bufferFull, loadType); + } else if (auto op = dyn_cast(user)) { + Value bufferFull = extractBufferFull(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferFull.getDefiningOp(), getAsyncTaskIds(user)); + processConsumerWaitOp(builder, op, bufferFull); + } else if (auto op = dyn_cast(user)) { + Value bufferEmpty = extractBufferEmpty(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferEmpty.getDefiningOp(), getAsyncTaskIds(user)); + processConsumerReleaseOp(builder, op, bufferEmpty, numCTAs); + } else { + llvm_unreachable("Unexpected user of token"); + } + deprecatedOps.push_back(user); + } + + deprecatedOps.push_back(createTokenOp); + }); + for (auto op : deprecatedOps) { + op->erase(); + } + + assert(numCTAs == 1 && "remote CTA is not supported yet"); +} + +#define GEN_PASS_DEF_TRITONGPUWSLOWERING +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// This pass lowers WS-specific operations. +class TritonGPUWSLowering + : public impl::TritonGPUWSLoweringBase { +public: + using impl::TritonGPUWSLoweringBase< + TritonGPUWSLowering>::TritonGPUWSLoweringBase; + + void runOnOperation() override { + // Disable WarpSpec if numConsumerGroups is zero. + if (numConsumerGroups == 0) + return; + ModuleOp mod = getOperation(); + int numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod); + + lowerGetAsyncTaskIdOp(mod, numConsumerGroups); + lowerTokenOperations(mod, numCTAs, numConsumerGroups); + + // We assume number of warps per warp group is 4. + // With Warp Spec, the effective warps per CTA is + // number of warp groups * 4, but within each warp group, layout will use + // num_warps of 4, since tensors are not distributed between the groups. + // + // Loads usually happen in one producer warp groups. num_warps of 4 makes + // sense because only the 4 warps from the producer warp group are + // participating in the load. + // + // But at some point (at least when we launch the kernel!) we really do need + // to know that the CTA has 8 or 12 warps in it. Attribute + // "num-warp-groups-per-cta" can be used to calculate the total number of + // warps. + auto builder = OpBuilder::atBlockBegin(mod.getBody()); + int numWarps = triton::gpu::lookupNumWarps(mod); + mod->setAttr("ttg.total-num-warps", + builder.getI32IntegerAttr(numWarps * (1 + numConsumerGroups))); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSTaskPartition.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSTaskPartition.cpp new file mode 100644 index 000000000..c1165beb9 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonGPU/Transforms/WSTaskPartition.cpp @@ -0,0 +1,170 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define DEBUG_TYPE "tritongpu-warp-task-partition" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#define GEN_PASS_DEF_TRITONGPUWSTASKPARTITION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct TaskSchedule { + unsigned numTasks = 0; + DenseMap opToTaskId; +}; + +// Compute a partition schedule for later passes to actually partition the +// program into async tasks. +void doPartition(triton::FuncOp &funcOp, unsigned numConsumerGroups) { + + // Bail out in the presence of user annotations. + DenseSet allAsyncTasks; + funcOp->walk([&](Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + allAsyncTasks.insert(asyncTasks.begin(), asyncTasks.end()); + }); + + if (!allAsyncTasks.empty()) + return; + + SmallVector loops; + SmallVector loads; + SmallVector dots; + + funcOp.walk([&](Operation *op) { + if (scf::ForOp forOp = dyn_cast(op)) + loops.push_back(forOp); + else if (isa(op)) + dots.push_back(op); + else if (isa(op)) + loads.push_back(op); + }); + + if (loops.empty() || loads.empty() || dots.empty()) + return; + + auto getLoopLevel = [&](Operation *op) { + // Compute loop depth + unsigned depth = 0; + Operation *parent = op->getParentOp(); + while (parent) { + if (isa(parent)) { + ++depth; + } + parent = parent->getParentOp(); + } + return depth; + }; + + // Step 1. Select loads into the first task, which is the producer task by + // default. Place dots into the second task, which is the consumer. + // Only consider loads that are connected to a dot op in a loop. + DenseSet producerOps; + SmallVector consumerOps; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.inclusive = true; + + for (auto op : dots) { + consumerOps.push_back(op); + auto dotOp = dyn_cast(op); + if (!dotOp) + continue; + SetVector backwardSlice; + getBackwardSlice(dotOp.getA(), &backwardSlice, opt); + getBackwardSlice(dotOp.getB(), &backwardSlice, opt); + for (auto depOp : backwardSlice) { + if (isa(depOp)) { + producerOps.insert(depOp); + } else if (isa(depOp) && isExpensiveLoadOrStore(depOp)) { + producerOps.insert(depOp); + } + } + } + + LLVM_DEBUG({ + LDBG("Producer ops:\n"); + for (auto op : producerOps) { + op->dump(); + } + + LDBG("\n"); + LDBG("Consumer ops:\n"); + for (auto op : consumerOps) { + op->dump(); + } + + LDBG("\n"); + }); + + if (consumerOps.empty() || producerOps.empty()) + return; + + // Annoate the program with task ids + SmallVector producerTaskIds{0}; + SmallVector consumerTaskIds; + for (unsigned i = 0; i < numConsumerGroups; ++i) { + consumerTaskIds.push_back(i + producerTaskIds.size()); + } + + for (auto op : producerOps) { + setAsyncTaskIds(op, producerTaskIds); + } + + for (auto op : consumerOps) { + setAsyncTaskIds(op, consumerTaskIds); + } + + LLVM_DEBUG({ + LDBG("After task partition"); + funcOp.dump(); + LDBG("\n"); + }); +} + +class TritonGPUWSTaskPartitionPass + : public impl::TritonGPUWSTaskPartitionBase { +public: + using impl::TritonGPUWSTaskPartitionBase< + TritonGPUWSTaskPartitionPass>::TritonGPUWSTaskPartitionBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + if (numConsumerGroups == 0) + return; + doPartition(funcOp, numConsumerGroups); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..7415bfafb --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(TritonNvidiaGPUIR + Dialect.cpp + Ops.cpp + Types.cpp + + DEPENDS + TritonNvidiaGPUTableGen + TritonNvidiaGPUAttrDefsIncGen + TritonNvidiaGPUOpInterfacesIncGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR +) diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp new file mode 100644 index 000000000..d73d21b6c --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp @@ -0,0 +1,271 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#include + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc" + +using namespace mlir; +using namespace mlir::triton::gpu; +using namespace mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +static constexpr int numTmemRows = 128; + +TMemAllocation getTmemAllocSizes(MemDescType memDescType) { + const int rowSizeInBytes = 4; + auto shapePerCTA = triton::gpu::getShapePerCTA(memDescType); + if (isa(memDescType.getEncoding())) { + // For scales the data are packed and replicated 4 times. + assert(memDescType.getElementType().getIntOrFloatBitWidth() == 8); + assert(memDescType.getShape().size() == 2 && + "TODO handle multibuffering of scales."); + int k = shapePerCTA[1]; + int m = shapePerCTA[0]; + int numColumn = ceil(m, 32) * ceil(k, 4); + return TMemAllocation(numColumn, numTmemRows); + } + assert(isa( + memDescType.getEncoding()) && + "Expecting a tensor memory encoding attribute"); + triton::nvidia_gpu::TensorMemoryEncodingAttr attr = + cast( + memDescType.getEncoding()); + bool isUnpacked = attr.getUnpacked(); + int64_t elementSizeInBytes = + isUnpacked ? rowSizeInBytes + : memDescType.getElementType().getIntOrFloatBitWidth() / 8; + int sizeInBytes = product(shapePerCTA) * elementSizeInBytes; + int numRows = numTmemRows; + // BlockM of 64 is and interleaved format, where for single message only the + // first 16 rows are used. For multiple blocks, the rows are interleaved, i.e. + // 0 N/2 N + // --------------------------------------------- + // 0 0,0 0,1... 0,N/2-1 0,N/2 0,N/2+1 ... 0, N-1 \ + //... Block 0 + // 15 15,0 15,1..15,N/2-1 15,N/2 15,N/2+1...15, N-1 / + // 16 0,0 0,1... 0,N/2-1 0,N/2 0,N/2+1 ... 0, N-1 \ + //... Block 1 + // 31 15,0 15,1..15,N/2-1 15,N/2 15,N/2+1...15, N-1 / + // Note that allocations that consists of single block of 64 rows are + // "sparse" and only half of the rows are used. + // Note that even for 3D shapes for which 2D slices are big enough to fit + // entire tensor block, we will use "sparse" allocation. + int blockM = attr.getBlockM(); + int blockN = attr.getBlockN(); + int lastDim = shapePerCTA.size() - 1; + int isSingleBlock = + (shapePerCTA[lastDim - 1] <= blockM) && (shapePerCTA[lastDim] <= blockN); + if (blockM == 64 && isSingleBlock) + numRows = 64; + int numColumn = ceil(sizeInBytes, (numRows * rowSizeInBytes)); + return TMemAllocation(numColumn, numRows); +} + +Attribute getTmemCompatibleLayout(unsigned M, unsigned N, + ArrayRef shape, unsigned numWarps, + triton::gpu::CTALayoutAttr ctaLayout) { + assert(numWarps == 4 || numWarps == 8); + assert(shape.size() == 2); + SmallVector sizePerThread; + SmallVector threadsPerWarp; + SmallVector warpsPerCTA; + SmallVector order; + SmallVector blocksPerTile = {(unsigned)shape[0] / M, + (unsigned)shape[1] / N}; + int numBlocks = blocksPerTile[0] * blocksPerTile[1]; + if (M == 64) { + unsigned numWarpGroups = numWarps / 4; + if (numBlocks == 1) { + // Split along the N dimension + sizePerThread = {1, N / (numWarpGroups * 2)}; + threadsPerWarp = {16, 2}; + warpsPerCTA = {4, numWarpGroups}; + } else { + sizePerThread = {1, N / 2}; + threadsPerWarp = {16, 2}; + warpsPerCTA = {0, 0}; + // Distribute at most as many warp groups as there is blocks + // along M dimension. + warpsPerCTA[0] = 4 * std::min(blocksPerTile[0], numWarpGroups); + // Distribute rest of the warp groups along N dimension. + warpsPerCTA[1] = ceil(numWarpGroups, warpsPerCTA[0] / 4); + } + } else { + unsigned numWarpGroups = numWarps / 4; + if (shape[0] > 128) { + // Split along M dimension + sizePerThread = {1, N}; + threadsPerWarp = {32, 1}; + warpsPerCTA = {4 * numWarpGroups, 1}; + } else { + // Split along N dimension + sizePerThread = {1, N / numWarpGroups}; + threadsPerWarp = {32, 1}; + warpsPerCTA = {4, numWarpGroups}; + } + } + order = {0, 1}; + return triton::gpu::BlockedEncodingAttr::get(ctaLayout.getContext(), + sizePerThread, threadsPerWarp, + warpsPerCTA, order, ctaLayout); +} + +// Verify if the distributed layout can be mapped onto tensor memory. +bool isDistributedLayoutTMemCompatible(Operation *op, + RankedTensorType tensorType, + MemDescType memType) { + int numWarps = lookupNumWarps(op); + assert(numWarps % 4 == 0); + int numWarpGroups = numWarps / 4; + + int blockM = 0; + int blockN = 0; + bool scalesEncoding = false; + if (auto attr = dyn_cast( + memType.getEncoding())) { + blockM = attr.getBlockM(); + blockN = attr.getBlockN(); + } else { + assert(isa( + memType.getEncoding()) && + "Expecting a tensor memory encoding attribute"); + return tensorType.getEncoding() == + triton::gpu::LinearEncodingAttr::get( + tensorType.getContext(), + getScaleTMEMStoreLinearLayout(tensorType, numWarps)); + } + auto shapePerCTA = mlir::triton::gpu::getShapePerCTA(tensorType); + int numElements = product(shapePerCTA); + int numBlocks = ceil(numElements, blockM * blockN); + bool useStridedMessage = blockM == 64; + + int numWarpGroupsPerBlock = ceil(numWarpGroups, numBlocks); + + auto tensorEncoding = + cast(tensorType.getEncoding()); + auto sizePerThread = tensorEncoding.getSizePerThread(); + auto threadsPerWarp = tensorEncoding.getThreadsPerWarp(); + auto warpsPerCTA = tensorEncoding.getWarpsPerCTA(); + auto order = tensorEncoding.getOrder(); + + if (order.size() != 2) + return false; + + if (order[0] != 0 || order[1] != 1) + return false; + + if (useStridedMessage) { + // For blockM=64 we need to use 16x32bx2 message, meaning the distributed + // layout needs to be organized into 16x2 threads per warp and one row + // access per thread. + if (threadsPerWarp[0] != 16 || threadsPerWarp[1] != 2 || + sizePerThread[0] != 1) + return false; + + if (numBlocks == 1) { + // with blockM=64 and just single block we cannot split along the M + // dimension. Check that if we split, we split along N. + if (numWarpGroupsPerBlock > 1) { + if (warpsPerCTA[1] == 1) + return false; + } + } + } else { + // For blockM=128, we need to use a 32x32b message, which requires 32 + // threads to be sequentially ordered across the M dimension, ensuring + // that each thread accesses a single and unique TMEM datapath. + if (threadsPerWarp[0] != 32 || sizePerThread[0] != 1) + return false; + } + return true; +} + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc" + +//===----------------------------------------------------------------------===// +// ASM Interface (i.e.: alias) +//===----------------------------------------------------------------------===// +namespace { +class TritonGPUOpAsmInterface : public OpAsmDialectInterface { +public: + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + if (auto sharedAttr = mlir::dyn_cast(attr)) { + os << "tmem"; + return AliasResult::FinalAlias; + } + if (mlir::isa(attr)) { + os << "tmem_scales"; + return AliasResult::FinalAlias; + } + return OpAsmDialectInterface::getAlias(attr, os); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// + +void TritonNvidiaGPUDialect::initialize() { + registerTypes(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc" + >(); + addInterfaces(); +} + +// verify TritonNvidiaGPU ops +LogicalResult +TritonNvidiaGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // TODO: fill this. + return success(); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp new file mode 100644 index 000000000..697ef2a59 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -0,0 +1,480 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/IR/Builders.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.cpp.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +// -- WarpGroupDotOp -- +mlir::LogicalResult WarpGroupDotOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = + cast(operands[0].getType()).getEncoding(); + auto bEnc = + cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc); + Dialect &dialect = aEnc.getDialect(); + auto interface = cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return mlir::failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return mlir::failure(); + } + return mlir::success(); +} + +void WarpGroupDotOp::getEffects( + SmallVectorImpl> + &effects) { + auto &a = getAMutable(); + auto &b = getBMutable(); + if (isa(a.get().getType())) + effects.emplace_back(MemoryEffects::Read::get(), &a, + mlir::triton::gpu::SharedMemory::get()); + if (isa(b.get().getType())) + effects.emplace_back(MemoryEffects::Read::get(), &b, + mlir::triton::gpu::SharedMemory::get()); +} + +bool WarpGroupDotOp::needsPartialAccumulator() { + const auto &a = getA(); + const auto &d = getD(); + auto aTensorTy = cast(a.getType()); + auto aElTy = cast(a.getType()).getElementType(); + bool isFP8 = llvm::isa(aElTy); + bool accFP32 = + cast(d.getType()).getElementType().isF32(); + uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc(); + return isFP8 && accFP32 && maxNumImpreciseAcc <= aTensorTy.getShape()[1]; +} + +bool WarpGroupDotOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + return aShape[aShape.size() - 1] == bShape[aShape.size() - 2]; +} + +// -- WarpGroupDotWaitOp -- +LogicalResult WarpGroupDotWaitOp::inferReturnTypes( + ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + for (Value operand : operands) + inferredReturnTypes.push_back(operand.getType()); + return mlir::success(); +} + +///--- Async related ops --- +void GetAsyncTaskIdOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state) { + build(builder, state, builder.getI32Type()); +} + +void CreateTokenOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, uint32_t num, + TokenLoadType loadType) { + auto tokenType = TokenType::get(builder.getContext()); + auto resultType = RankedTensorType::get({num}, tokenType); + build(builder, state, resultType, num, loadType); +} + +static LogicalResult +verifyBarrierType(Operation *op, mlir::triton::gpu::MemDescType barrierType) { + if (!barrierType.getElementType().isInteger(64) || + barrierType.getShape() != ArrayRef({1})) + return op->emitOpError( + "barrier allocation must be a descriptor of 1xi64 type"); + return success(); +} + +// -- InitBarrierOp -- +LogicalResult InitBarrierOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +void InitBarrierOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- InvalBarrierOp -- +LogicalResult InvalBarrierOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +void InvalBarrierOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- BarrierExpectOp -- +LogicalResult BarrierExpectOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +void BarrierExpectOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- WaitBarrierOp -- +LogicalResult WaitBarrierOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +void WaitBarrierOp::getEffects( + SmallVectorImpl> + &effects) { + // The wait will flip the phase therefore it reads and writes the barrier. + effects.emplace_back(MemoryEffects::Read::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- TensorDescToTMAPtrOp -- +LogicalResult TensorDescToTMAPtrOp::canonicalize(TensorDescToTMAPtrOp op, + PatternRewriter &rewriter) { + // tensor_desc_to_tma_ptr(reinterpret_tensor_desc(ptr)) -> ptr + if (auto reinterpret = + op.getDesc().getDefiningOp()) { + rewriter.replaceOp(op, reinterpret.getRawDesc()); + return success(); + } + return failure(); +} + +// -- AsyncTMACopyGlobalToLocalOp -- +LogicalResult AsyncTMACopyGlobalToLocalOp::verify() { + if (failed(verifyBarrierType(*this, getBarrier().getType()))) + return failure(); + if (getCoord().size() < 1 || getCoord().size() > 5) + return emitOpError("TMA copies must have between 1 and 5 coordinates"); + if (!getResult().getType().getMutableMemory()) + return emitOpError("Cannot store into immutable memory"); + return success(); +} + +void AsyncTMACopyGlobalToLocalOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getDescPtrMutable(), + mlir::triton::GlobalMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getBarrierMutable(), + mlir::triton::gpu::SharedMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getResultMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- AsyncTMACopyLocalToGlobalOp -- +void AsyncTMACopyLocalToGlobalOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDescPtrMutable(), + mlir::triton::GlobalMemory::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- AsyncTMAGatherOp -- +LogicalResult AsyncTMAGatherOp::verify() { + if (failed(verifyBarrierType(*this, getBarrier().getType()))) + return failure(); + + triton::gpu::MemDescType resultType = getResult().getType(); + if (!resultType.getMutableMemory()) + return emitOpError("cannot store into immutable memory"); + return ExperimentalDescriptorGatherOp::verifyResultType(*this, resultType); +} + +void AsyncTMAGatherOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getDescPtrMutable(), + mlir::triton::GlobalMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getBarrierMutable(), + mlir::triton::gpu::SharedMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getResultMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- AsyncTMAScatter -- +void AsyncTMAScatterOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDescPtrMutable(), + mlir::triton::GlobalMemory::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- TCGen5MMAOp -- +void TCGen5MMAOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDMutable(), + mlir::triton::nvidia_gpu::TensorMemory::get()); + if (isa( + getA().getType().getMemorySpace())) { + effects.emplace_back(MemoryEffects::Read::get(), &getAMutable(), + mlir::triton::gpu::SharedMemory::get()); + + } else { + effects.emplace_back(MemoryEffects::Read::get(), &getAMutable(), + mlir::triton::nvidia_gpu::TensorMemory::get()); + } + effects.emplace_back(MemoryEffects::Read::get(), &getBMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +bool TCGen5MMAOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + return aShape[aShape.size() - 1] == bShape[aShape.size() - 2]; +} + +Value TCGen5MMAOp::useAccumulator() { return getUseD(); } + +void TCGen5MMAOp::setUseAccumulator(Value flag) { + getUseDMutable().assign(flag); +} + +void TCGen5MMAOp::setBarrier(Value barrier) { + getBarrierMutable().assign(barrier); +} + +Value TCGen5MMAOp::getAccumulator() { return getD(); } + +void TCGen5MMAOp::setAccumulator(Value accum) { getDMutable().assign(accum); } + +Value TCGen5MMAOp::getPredicate() { return getPred(); } + +void TCGen5MMAOp::setPredicate(Value pred) { getPredMutable().assign(pred); } + +// -- TMEMStoreOp -- +LogicalResult TMEMStoreOp::verify() { + if (!isa( + getDst().getType().getMemorySpace())) + return emitOpError("destination must be a tensor memory buffer."); + if (!isa(getDst().getType().getEncoding())) + return emitOpError("should use tensor memory encoding."); + if (!getDst().getType().getMutableMemory()) { + return emitOpError("Cannot store into an immutable alloc"); + } + return success(); +} + +// -- TCGen5MMAScaledOp -- +void TCGen5MMAScaledOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDMutable(), + mlir::triton::nvidia_gpu::TensorMemory::get()); + if (isa( + getA().getType().getMemorySpace())) { + effects.emplace_back(MemoryEffects::Read::get(), &getAMutable(), + mlir::triton::gpu::SharedMemory::get()); + + } else { + effects.emplace_back(MemoryEffects::Read::get(), &getAMutable(), + mlir::triton::nvidia_gpu::TensorMemory::get()); + } + effects.emplace_back(MemoryEffects::Read::get(), &getBMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +bool TCGen5MMAScaledOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + auto aKdim = aShape[aShape.size() - 1]; + auto bKdim = bShape[aShape.size() - 2]; + if (this->getAType() == ScaleDotElemType::E2M1) + aKdim *= 2; + if (this->getBType() == ScaleDotElemType::E2M1) + bKdim *= 2; + + return aKdim == bKdim; +} + +Value TCGen5MMAScaledOp::useAccumulator() { return getUseD(); } + +void TCGen5MMAScaledOp::setUseAccumulator(Value flag) { + getUseDMutable().assign(flag); +} + +void TCGen5MMAScaledOp::setBarrier(Value barrier) { + getBarrierMutable().assign(barrier); +} + +Value TCGen5MMAScaledOp::getAccumulator() { return getD(); } + +void TCGen5MMAScaledOp::setAccumulator(Value accum) { + getDMutable().assign(accum); +} + +Value TCGen5MMAScaledOp::getPredicate() { return getPred(); } + +void TCGen5MMAScaledOp::setPredicate(Value pred) { + getPredMutable().assign(pred); +} + +// -- TMEMLoadOp -- +LogicalResult TMEMLoadOp::verify() { + if (!isa( + getSrc().getType().getMemorySpace())) + return emitOpError("source must be a tensor memory buffer."); + if (!isa( + getSrc().getType().getEncoding())) + return emitOpError("should use tensor memory encoding."); + return success(); +} + +// -- TMEMAllocOp -- +LogicalResult TMEMAllocOp::verify() { + if (!isa( + getType().getMemorySpace())) + return emitOpError("should create a buffer of tensor memory."); + if (!isa(getType().getEncoding())) + return emitOpError("should use tensor memory encoding."); + if (!getSrc()) { + if (!getType().getMutableMemory()) + return emitError("uninitialized alloc must have a mutable memdesc type"); + } + return success(); +} + +void TMEMAllocOp::getEffects( + SmallVectorImpl> + &effects) { + Operation *op = getOperation(); + // If allocation is immutable, mark it as no side effect allow things like + // CSE, DCE to work in early compiler passes. + // After the memory offset is computed, we attach the true side effect to the + // op. + if (!getType().getMutableMemory() && !op->hasAttr("tensor_memory_col_offset")) + return; + effects.emplace_back(MemoryEffects::Allocate::get(), + mlir::triton::nvidia_gpu::TensorMemory::get()); + if (getSrc()) + effects.emplace_back(MemoryEffects::Write::get(), + getOperation()->getOpResult(0), + mlir::triton::nvidia_gpu::TensorMemory::get()); +} + +bool isDescendingOrder(triton::gpu::MemDescType type) { + auto order = triton::gpu::getOrder(type); + auto rank = type.getRank(); + for (int i = 0; i < rank; ++i) { + if (order[i] != rank - 1 - i) + return false; + } + return true; +} + +// -- TMEMCopyOp -- +LogicalResult TMEMCopyOp::verify() { + if (!isa( + getSrc().getType().getMemorySpace())) + return emitOpError("The source must be a shared memory buffer"); + if (!isa( + getDst().getType().getEncoding())) + return emitOpError("The destination must be a tensor memory buffer."); + + if (getBarrier() && !isa( + getBarrier().getType().getMemorySpace())) { + return emitOpError("The optional barrier should be a shared memory buffer"); + } + if (!getDst().getType().getMutableMemory()) { + return emitOpError("Cannot copy into an immutable alloc"); + } + + auto srcTy = cast(getSrc().getType()); + auto sharedEnc = + cast(srcTy.getEncoding()); + + if (sharedEnc.getMaxPhase() != 1 || sharedEnc.getPerPhase() != 1 || + sharedEnc.getVec() != 1) + return emitOpError("The source should not have swizzling applied for now"); + + if (!isDescendingOrder(srcTy)) { + return emitOpError("The source must be in a row-major order."); + } + + // Given that we want to support flexible input SMEM shapes, kinds of shape + // checking we can do here are limited. For simplicity, shape checking is + // omitted. + return success(); +} + +void TMEMCopyOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), + mlir::triton::nvidia_gpu::TensorMemory::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), + mlir::triton::gpu::SharedMemory::get()); +} + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/Types.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/Types.cpp new file mode 100644 index 000000000..326f4948a --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/IR/Types.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::nvidia_gpu; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc" + >(); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..92e3669d0 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -0,0 +1,20 @@ +add_triton_library(TritonNvidiaGPUTransforms + FenceInsertion.cpp + MMALowering.cpp + KeepAccInTMem.cpp + PlanCTA.cpp + PromoteLHSToTMem.cpp + TensorMemoryAllocation.cpp + TMALowering.cpp + Utility.cpp + + DEPENDS + TritonNvidiaGPUTransformsIncGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR + TritonGPUTransforms + TritonNvidiaGPUIR + MLIRTransformUtils +) diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp new file mode 100644 index 000000000..fc34ddda7 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -0,0 +1,136 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// +// This pass works after all other passes, inserting fences to ensure that +// memory operations are properly ordered across generic and async proxy. +// +//===----------------------------------------------------------------------===// + +using namespace mlir; +namespace tt = ::mlir::triton; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +struct FenceInsertionPass + : public TritonGPUFenceInsertionBase { + +public: + FenceInsertionPass() = default; + FenceInsertionPass(int computeCapability) { + this->computeCapability = computeCapability; + } + // TODO: support more general patterns to insert fences. eg. any op(generic) + // to shared in use-def chain which refers by async proxy. We have generic( + // convertlayout with sts/stmatix) + fence + async(wgmma) up to now + void runOnOperation() override { + // Only insert fences for compute capability 9.0 + if (computeCapability < 90) + return; + if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) + return; + ModuleOp mod = getOperation(); + mod.walk([&](Operation *op) { + if (!isa(op)) + return WalkResult::advance(); + OpBuilder builder(op); + auto a = op->getOperand(0); + auto b = op->getOperand(1); + auto mmaEncoding = dyn_cast( + cast(op->getResult(0).getType()).getEncoding()); + if (!mmaEncoding || !mmaEncoding.isHopper()) + return WalkResult::advance(); + bool aDependsOnShared = dependOnSharedEncOperand(a); + bool bDependsOnShared = dependOnSharedEncOperand(b); + if (!aDependsOnShared && !bDependsOnShared) + return WalkResult::advance(); + Operation *fence = builder.create( + op->getLoc(), /*bCluster=*/false); + // If there is all the dependencies are outside of the loop try to hoist + // the fence. + while (auto loopOp = fence->getParentOfType()) { + if (aDependsOnShared && + loopOp->isAncestor(a.getParentBlock()->getParentOp())) + break; + if (bDependsOnShared && + loopOp->isAncestor(b.getParentBlock()->getParentOp())) + break; + loopOp.moveOutOfLoop(fence); + } + return WalkResult::advance(); + }); + } + +private: + bool dependOnSharedEncOperand(Value operand) { + static DenseSet> trace; + auto op = operand.getDefiningOp(); + // avoid redundant insertion + if (op && isa(op)) + return false; + // reach convertlayout + if (op && isa(op) && + cast(op).getSrc()) + return true; + // root and not BlockArgument + if (!op && !isa(operand)) + return false; + // op and not BlockArgument + if (op && !isa(operand)) { + for (auto v : op->getOperands()) { + if (dependOnSharedEncOperand(v)) + return true; + } + } + // reach BlockArgument + // TODO: support other scf ops, IfOp, WhileOp, etc. + if (BlockArgument arg = dyn_cast(operand)) { + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + // support ForOp only + if (auto forOp = dyn_cast(argOwner)) { + // prologue + auto iterOperands = forOp.getInitArgs(); + if (argNum == 0) + return false; + if (dependOnSharedEncOperand(iterOperands[argNum - 1])) + return true; + // yield + auto yieldOp = forOp.getBody()->getTerminator(); + Value v = yieldOp->getOperand(argNum - 1); + auto entry = std::make_pair(std::move(yieldOp), + std::move(argNum)); + // avoid cyclic + if (trace.contains(entry)) + return false; + else + trace.insert(entry); + + if (dependOnSharedEncOperand(v)) + return true; + } else if (auto whileOp = dyn_cast(argOwner)) { + assert(false && "FenceInsertionPass does not supported WhileOp"); + } else if (auto ifOp = dyn_cast(argOwner)) { + assert(false && "FenceInsertionPass does not supported IfOp"); + } + } + return false; + } +}; +} // namespace + +std::unique_ptr +mlir::createTritonNvidiaGPUFenceInsertionPass(int computeCapability) { + return std::make_unique(computeCapability); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/KeepAccInTMem.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/KeepAccInTMem.cpp new file mode 100644 index 000000000..f0a97a8d0 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/KeepAccInTMem.cpp @@ -0,0 +1,258 @@ +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include + +namespace { + +using namespace mlir; + +namespace ttng = triton::nvidia_gpu; +namespace ttg = triton::gpu; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +static bool bwdFilter(Operation *op) { + return isa(op) || + op->hasTrait() || + op->hasTrait(); +} + +static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = dyn_cast(type); + if (!tensorType) + return type; + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); +} + +class TMEMToGlobal : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::StoreOp op, + PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + Value data = op.getValue(); + auto tensorType = dyn_cast(data.getType()); + if (!tensorType) + return failure(); + llvm::SetVector slice; + mlir::BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = bwdFilter; + getBackwardSlice(data, &slice, opt); + Attribute encoding; + for (auto op : slice) { + if (auto tmemLoad = dyn_cast(op)) { + if (!encoding) + encoding = tmemLoad.getType().getEncoding(); + if (tmemLoad.getType().getEncoding() != encoding) + return failure(); + } + } + if (!encoding || tensorType.getEncoding() == encoding) + return failure(); + // Use tmem load encoding to avoid going through shared memory. + Value newData = rewriter.create( + loc, getNewType(data.getType(), encoding), data); + Value newPointer = rewriter.create( + loc, getNewType(op.getPtr().getType(), encoding), op.getPtr()); + Value newMask; + if (op.getMask()) + newMask = rewriter.create( + loc, getNewType(op.getMask().getType(), encoding), op.getMask()); + rewriter.create(loc, newPointer, newData, newMask, + op.getBoundaryCheck(), op.getCache(), + op.getEvict()); + rewriter.eraseOp(op); + return success(); + } +}; + +static void addTMEMLoad(IRRewriter &rewriter, ttng::TMEMAllocOp localAlloc, + Operation *user, int argNo) { + rewriter.setInsertionPoint(user); + auto load = rewriter.create( + user->getLoc(), user->getOperand(argNo).getType(), + localAlloc->getResult(0)); + user->setOperand(argNo, load); +} + +static bool canKeepAccInTmem(scf::ForOp forOp, Operation *mmaOp, + ttng::TMEMAllocOp &localAlloc, + ttng::TMEMLoadOp &localLoad, + SmallVector> &accUsers, + unsigned &yieldArgNo) { + // The expected sequence of instructions: + // %acc_tm = ttg.local_alloc %acc + // ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm + // %acc_res = ttg.local_load %acc_tm + localAlloc = mmaOp->getOperand(2).getDefiningOp(); + if (!localAlloc) { + return false; + } + for (auto user : localAlloc->getUsers()) { + if (isa(user)) { + localLoad = cast(user); + } else if (user != mmaOp) { + // The accumulator is used by another operation, not something we + // expect. + localLoad = nullptr; + return false; + } + } + + SmallVector queue; + queue.push_back(localLoad->getResult(0)); + bool foundDotCycle = false; + while (!queue.empty()) { + Value value = queue.pop_back_val(); + for (auto &use : value.getUses()) { + if (use.getOwner() == localAlloc) { + foundDotCycle = true; + continue; + } + if (auto yieldOp = dyn_cast(use.getOwner())) { + if (yieldOp->getParentOp() == forOp) { + yieldArgNo = use.getOperandNumber(); + queue.push_back(forOp.getRegionIterArg(yieldArgNo)); + continue; + } + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + // TODO: Accumulator being used in the yield of ifOp means that + // it is being modified in the other branch of the ifOp. This is not + // something we can handle yet. + return false; + } + // Not sure what are we doing here. Back out. + return false; + } + accUsers.emplace_back(use.getOwner(), use.getOperandNumber()); + } + } + return foundDotCycle; +} + +static void hoistReadModifyWrite(Operation *mmaOp, scf::ForOp forOp) { + // For the transformation to make sense, the accumulator must be + // reused by the same MMA operation in subsequent iterations. + SmallVector> accUsers; + ttng::TMEMAllocOp localAlloc = nullptr; + ttng::TMEMLoadOp localLoad = nullptr; + unsigned yieldArgNo; + if (!canKeepAccInTmem(forOp, mmaOp, localAlloc, localLoad, accUsers, + yieldArgNo)) { + return; + } + + assert(localLoad != nullptr); + assert(localAlloc != nullptr); + Type loadType = localLoad->getResult(0).getType(); + IRRewriter rewriter(forOp); + localAlloc->moveBefore(forOp); + localAlloc->setOperand(0, forOp.getInitArgs()[yieldArgNo]); + mmaOp->setOperand(2, localAlloc->getResult(0)); + // Unlink the local_load from the yield. Short circuit the unused yield + // value with the corresponding iter arg. + forOp.getBody()->getTerminator()->setOperand( + yieldArgNo, forOp.getRegionIterArg(yieldArgNo)); + + // Add TMEM loads before all the uses + // TODO: We could be more efficient here, reusing loads instead of + // creating new ones for each use. + for (auto [user, argNo] : accUsers) { + addTMEMLoad(rewriter, localAlloc, user, argNo); + } + + rewriter.setInsertionPointAfter(forOp); + auto afterLoopLoad = rewriter.create( + forOp.getLoc(), loadType, localAlloc->getResult(0)); + forOp->getResult(yieldArgNo).replaceAllUsesWith(afterLoopLoad->getResult(0)); + + localLoad->erase(); +} + +// Hoist invariant tmem_alloc. This could technically be done as general LICM +// but controlling tmem liveranga more precisley is likely to be important. +static void hoistInvariantInputs(Operation *mmaOp, scf::ForOp forOp) { + for (auto operand : mmaOp->getOperands()) { + if (forOp.isDefinedOutsideOfLoop(operand)) + continue; + auto tmemAllocOp = operand.getDefiningOp(); + if (!tmemAllocOp || tmemAllocOp.getType().getMutableMemory()) + continue; + assert(tmemAllocOp.getSrc()); + Value src = tmemAllocOp.getSrc(); + SmallVector opToHoist = {tmemAllocOp.getOperation()}; + // Also hoist simple unary elementwise that may have sinked into the loop. + while (Operation *defOp = src.getDefiningOp()) { + if (forOp.isDefinedOutsideOfLoop(src)) + break; + if (!(isMemoryEffectFree(defOp) && isSpeculatable(defOp) && + defOp->getNumOperands() == 1)) + break; + opToHoist.push_back(defOp); + src = defOp->getOperand(0); + } + if (!forOp.isDefinedOutsideOfLoop(src)) + continue; + for (auto op : llvm::reverse(opToHoist)) { + forOp.moveOutOfLoop(op); + } + } +} +class TritonNvidiaGPUKeepAccInTMemPass + : public TritonNvidiaGPUKeepAccInTMemPassBase< + TritonNvidiaGPUKeepAccInTMemPass> { +public: + using TritonNvidiaGPUKeepAccInTMemPassBase< + TritonNvidiaGPUKeepAccInTMemPass>::TritonNvidiaGPUKeepAccInTMemPassBase; + + void runOnOperation() override { + auto module = getOperation(); + + module.walk([&](scf::ForOp forOp) { runOnForOp(forOp); }); + + if (triton::tools::getBoolEnv("STORE_TMEM_TO_GLOBAL_BYPASS_SMEM")) { + mlir::RewritePatternSet patterns(module.getContext()); + patterns.add(module.getContext()); + if (applyPatternsGreedily(module, std::move(patterns)).failed()) + signalPassFailure(); + } + } + + void runOnForOp(scf::ForOp forOp) { + SmallVector mmaOps; + forOp.walk([&](Operation *mmaOp) { + // Skip MMA nested in another forOp + if (isa(mmaOp) && + mmaOp->getParentOfType() == forOp) { + mmaOps.push_back(mmaOp); + } + }); + if (mmaOps.empty()) { + return; + } + + for (auto mmaOp : mmaOps) { + hoistReadModifyWrite(mmaOp, forOp); + hoistInvariantInputs(mmaOp, forOp); + } + } +}; + +} // namespace + +std::unique_ptr mlir::createTritonNvidiaGPUKeepAccInTMemPass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp new file mode 100644 index 000000000..f206c3a36 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp @@ -0,0 +1,133 @@ +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#include + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; +using namespace triton::nvidia_gpu; + +template +class SyncMMALowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TCGen5MMAOpTy op, + PatternRewriter &rewriter) const override { + // If the op doesn't have synchronous semantic skip the pattern. + if (op.getBarrier()) + return failure(); + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + Attribute sharedMemorySpace = SharedMemorySpaceAttr::get(ctx); + auto barrierCTALayout = CTALayoutAttr::get( + /*context=*/ctx, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, {0}, barrierCTALayout); + MemDescType barrierMemDescType = + MemDescType::get({1}, rewriter.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value barrierAlloc = + rewriter.create(loc, barrierMemDescType, Value()); + rewriter.create(loc, barrierAlloc, 1); + op.getBarrierMutable().assign(barrierAlloc); + + rewriter.setInsertionPointAfter(op); + Value phase = rewriter.create(loc, 0, 32); + rewriter.create(loc, barrierAlloc, phase); + rewriter.create(loc, barrierAlloc); + return success(); + } +}; + +struct TCGen5MMAScaleSharedToTmemConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + // Create a tmem_copy of scales from shared memory to tmem. `rows` is the M or + // N of the MMA operation (for LHS or RHS respectively). + bool lowerScaleToTmem(OpOperand &operand, PatternRewriter &rewriter, + int rows) const { + Location loc = operand.getOwner()->getLoc(); + MLIRContext *context = operand.getOwner()->getContext(); + Attribute tensorMemorySpace = TensorMemorySpaceAttr::get(context); + auto oldType = cast(operand.get().getType()); + auto numElems = product(oldType.getShape()); + Type elType = oldType.getElementType(); + SwizzledSharedEncodingAttr oldEncoding = + cast(oldType.getEncoding()); + CTALayoutAttr CTALayout = getCTALayout(oldEncoding); + ArrayRef CTASplitNum = CTALayout.getCTASplitNum(); + // Distribute the scales across the rows of the MMA operation. + SmallVector shape = {rows, numElems / rows}; + Attribute scaleEncoding = TensorMemoryScalesEncodingAttr::get( + context, CTASplitNum[0], CTASplitNum[1]); + Type scaleAType = + MemDescType::get(shape, elType, scaleEncoding, tensorMemorySpace, + /*mutableMemory=*/true); + auto tmemAlloc = rewriter.create(loc, scaleAType, Value()); + rewriter.create(loc, operand.get(), tmemAlloc, + /*barrier*/ Value()); + operand.set(tmemAlloc); + return true; + } + + LogicalResult matchAndRewrite(TCGen5MMAScaledOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *context = op->getContext(); + auto aScaleType = op.getAScale().getType(); + auto bScaleType = op.getBScale().getType(); + int blockM = op.getA() + .getType() + .getShape()[op.getA().getType().getShape().size() - 2]; + int blockN = op.getB() + .getType() + .getShape()[op.getB().getType().getShape().size() - 1]; + int blockK = op.getA() + .getType() + .getShape()[op.getA().getType().getShape().size() - 1]; + bool anyChanged = false; + if (isa(aScaleType.getEncoding())) { + anyChanged = lowerScaleToTmem(op.getAScaleMutable(), rewriter, blockM); + } + if (isa(bScaleType.getEncoding())) { + anyChanged = lowerScaleToTmem(op.getBScaleMutable(), rewriter, blockN); + } + return LogicalResult::success(anyChanged); + } +}; + +class TritonNvidiaGPUMMALoweringPass + : public TritonNvidiaGPUMMALoweringPassBase< + TritonNvidiaGPUMMALoweringPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::RewritePatternSet patterns(context); + patterns + .add, SyncMMALowering, + TCGen5MMAScaleSharedToTmemConversion>(context); + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::createTritonNvidiaGPUMMALoweringPass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp new file mode 100644 index 000000000..552746190 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp @@ -0,0 +1,1061 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include + +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +using namespace mlir; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; + +// TODO: use ConvertLayoutOp +using CastOp = ::mlir::UnrealizedConversionCastOp; + +unsigned getNumUsers(Value value) { + return std::distance(value.user_begin(), value.user_end()); +} + +Type replaceLayout(const Type &type, const Attribute &newLayout) { + Type curType = type; + auto ptrTy = dyn_cast(curType); + if (ptrTy) + curType = ptrTy.getPointeeType(); + if (auto tensorTy = dyn_cast(curType)) + curType = RankedTensorType::get(tensorTy.getShape(), + tensorTy.getElementType(), newLayout); + if (ptrTy) + curType = triton::PointerType::get(curType, ptrTy.getAddressSpace()); + return curType; +} + +ttg::DistributedEncodingTrait +replaceCTALayout(ttg::DistributedEncodingTrait layout, + llvm::ArrayRef shape, + const ttg::CTALayoutAttr &newCTALayout) { + if (auto blockedLayout = mlir::dyn_cast(layout)) { + return ttg::BlockedEncodingAttr::get( + layout.getContext(), shape, blockedLayout.getSizePerThread(), + blockedLayout.getOrder(), ttg::getNumWarpsPerCTA(layout), 32, + newCTALayout); + } else if (auto sliceLayout = + mlir::dyn_cast(layout)) { + return ttg::SliceEncodingAttr::get( + layout.getContext(), sliceLayout.getDim(), + replaceCTALayout(sliceLayout.getParent(), shape, newCTALayout)); + } else { + // Other layouts are generated by passes after PlanCTAPass + llvm::report_fatal_error("replaceCTALayout not implemented"); + return layout; + } +} + +class CTAPlanner { +public: + CTAPlanner(ttng::ClusterInfo *clusterInfo_); + ~CTAPlanner(); + + void run(triton::FuncOp &funcOp); + +private: + CastOp markBackward(CastOp cast) const; + CastOp markForward(CastOp cast) const; + bool isBackward(CastOp cast) const; + bool isForward(CastOp cast) const; + + void setTiling(llvm::ArrayRef CTAsPerCGA); + bool processDot(triton::FuncOp &funcOp); + bool processReduce(triton::FuncOp &funcOp); + void processStoreLikeOps(triton::FuncOp &funcOp); + + bool propagate(CastOp cast); + bool propagateBackward(CastOp cast); + bool propagateForward(CastOp cast); + + void eraseCastOp(CastOp cast); + void eraseCastOpFromQueue(CastOp cast); + void eraseCastOpsFromQueue(llvm::ArrayRef casts); + + void insertCasts(Operation *op, llvm::ArrayRef newOperandLayouts, + llvm::ArrayRef newResultLayouts); + void eliminateAdjacentCasts(CastOp cast0, CastOp cast1); + + bool isLoadStoreOp(Operation *op) const; + bool processLoadStore(Operation *op, Attribute layout); + + bool isElementwiseOp(Operation *op) const; + bool processElementwise(Operation *op, Attribute layout); + + bool processConstant(arith::ConstantOp constant, Attribute layout); + bool processSplat(triton::SplatOp splat, Attribute layout); + bool processMakeRange(triton::MakeRangeOp makeRange, Attribute layout); + bool processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr, + Attribute layout); + + bool processBroadcast(triton::BroadcastOp broadcast, Attribute layout); + bool processExpandDimsBackward(triton::ExpandDimsOp expandDims, + ttg::DistributedEncodingTrait newResultLayout); + bool processExpandDimsForward(triton::ExpandDimsOp expandDims, + ttg::DistributedEncodingTrait newSrcLayout); + + bool processConvertLayoutBackward(ttg::ConvertLayoutOp convertLayout, + CastOp cast); + bool processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout, + CastOp cast); + + bool processIfOp(scf::IfOp ifOp, int index, const Type &newType); + bool processForOp(scf::ForOp forOp, int index, const Type &newType); + + bool processIfOpBackward(scf::IfOp ifOp, CastOp cast); + bool processForOpBackward(scf::ForOp forOp, CastOp cast); + bool processBlockArgBackward(BlockArgument arg, CastOp cast); + bool processForOpForward(scf::ForOp forOp, CastOp cast); + bool processYieldOpForward(scf::YieldOp yieldOp, CastOp cast); + + bool processOpFallback(Operation *op); + + bool processMultiUsersBackward(Value input, CastOp cast); + bool processMultiUsersForward(Value output, CastOp cast); + + // This flag indicates whether clusterInfo needs to be deleted in the + // destructor of CTAPlanner. The flag `ownInfo` is set to false when a + // non-null pointer to clusterInfo is passed to the constructor of CTAPlanner. + // Otherwise, a self-managed ClusterInfo will be created and the ownInfo will + // be set to true. + bool ownInfo; + ttng::ClusterInfo *clusterInfo; + bool tiled; + unsigned step; + unsigned stepUnchanged; + std::queue queue; +}; + +CTAPlanner::CTAPlanner(ttng::ClusterInfo *clusterInfo_) + : ownInfo(false), clusterInfo(clusterInfo_), tiled(false), step(0), + stepUnchanged(0) { + if (clusterInfo == nullptr) { + clusterInfo = new ttng::ClusterInfo(); + ownInfo = true; + } +} + +CTAPlanner::~CTAPlanner() { + if (ownInfo) { + delete clusterInfo; + // Actually not necessary but safer + ownInfo = false; + clusterInfo = nullptr; + } +} + +void CTAPlanner::run(triton::FuncOp &funcOp) { + assert(!tiled && "Please create a new CTAPlanner"); + static const unsigned maxSteps = 10000; + + auto nextStep = [&]() { + ++step; + assert(step < maxSteps && "Maximum number of steps exceeded"); + }; + + processDot(funcOp); + nextStep(); + + processReduce(funcOp); + nextStep(); + + if (!tiled) { + processStoreLikeOps(funcOp); + nextStep(); + } + + while (!queue.empty()) { + CastOp cast = queue.front(); + queue.pop(); + bool changed = propagate(cast); + if (changed) { + stepUnchanged = 0; + } else { + queue.push(cast); + ++stepUnchanged; + } + nextStep(); + } +} + +CastOp CTAPlanner::markBackward(CastOp cast) const { + cast->setAttr("direction", StringAttr::get(cast.getContext(), "backward")); + return cast; +} + +CastOp CTAPlanner::markForward(CastOp cast) const { + cast->setAttr("direction", StringAttr::get(cast.getContext(), "forward")); + return cast; +} + +bool CTAPlanner::isBackward(CastOp cast) const { + return cast->getAttrOfType("direction") == "backward"; +} + +bool CTAPlanner::isForward(CastOp cast) const { + return cast->getAttrOfType("direction") == "forward"; +} + +void CTAPlanner::setTiling(llvm::ArrayRef CTAsPerCGA) { + assert(!tiled && "CTA tiling is already determinted"); + assert(clusterInfo && "ClusterInfo pointer is null"); + assert(CTAsPerCGA.size() <= 3 && "setTiling not implemented"); + tiled = true; + unsigned numCTAs = 1; + for (unsigned cta : CTAsPerCGA) + numCTAs *= cta; + if (numCTAs == 2) { + // For 2 CTAs always use 2x1x1. + // TODO: can we always serialize the CTAs on X dimension? + clusterInfo->clusterDimX = 2; + return; + } + + if (CTAsPerCGA.size() > 0) + clusterInfo->clusterDimX = CTAsPerCGA[0]; + if (CTAsPerCGA.size() > 1) + clusterInfo->clusterDimY = CTAsPerCGA[1]; + if (CTAsPerCGA.size() > 2) + clusterInfo->clusterDimZ = CTAsPerCGA[2]; +} + +bool CTAPlanner::processDot(triton::FuncOp &funcOp) { + // TODO: This is a naive implementation and should be refactored + auto getCTATiling = [](int64_t M, int64_t N, int64_t K, + unsigned numCTAs) -> std::pair { + // prefer a larger chunk size, at most 128; first assign splitM. + unsigned chunk_m = 128; + auto isLegal = [](unsigned chunk) { return chunk >= 64; }; + unsigned splitM, splitN; + for (; isLegal(chunk_m); chunk_m /= 2) { + splitM = std::clamp(M / chunk_m, 1, numCTAs); + splitN = numCTAs / splitM; + if (isLegal(N / splitN)) // chunk_n; + break; + } + return {splitM, splitN}; + }; + + funcOp.walk([&](triton::DotOp dot) { + MLIRContext *ctx = dot.getContext(); + + auto aTy = cast(dot.getA().getType()); + auto bTy = cast(dot.getB().getType()); + auto dTy = cast(dot.getD().getType()); + + assert(isa(aTy.getEncoding()) && + isa(bTy.getEncoding()) && + isa(dTy.getEncoding()) && + "PlanCTAPass should follow immediately after CoalescePass"); + + auto aLayout = cast(aTy.getEncoding()); + auto bLayout = cast(bTy.getEncoding()); + auto dLayout = cast(dTy.getEncoding()); + + unsigned M = dTy.getShape()[0]; + unsigned N = dTy.getShape()[1]; + unsigned K = aTy.getShape()[1]; + + unsigned splitM, splitN; + std::tie(splitM, splitN) = getCTATiling(M, N, K, ttg::getNumCTAs(dLayout)); + // FIXME: Should consider IR with more than one DotOps + setTiling({splitM, splitN, 1}); + + auto newCTALayout = ttg::CTALayoutAttr::get(ctx, {splitM, splitN}, + {splitM, splitN}, {1, 0}); + auto newDLayout = ttg::BlockedEncodingAttr::get( + ctx, dTy.getShape(), dLayout.getSizePerThread(), dLayout.getOrder(), + ttg::getNumWarpsPerCTA(dLayout), 32, newCTALayout); + auto newALayout = ttg::DotOperandEncodingAttr::get(ctx, aLayout.getOpIdx(), + newDLayout, 0); + auto newBLayout = ttg::DotOperandEncodingAttr::get(ctx, bLayout.getOpIdx(), + newDLayout, 0); + + insertCasts(dot.getOperation(), {newALayout, newBLayout, newDLayout}, + {newDLayout}); + }); + + return true; +} + +bool CTAPlanner::processReduce(triton::FuncOp &funcOp) { + ModuleOp mod = funcOp->getParentOfType(); + unsigned numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod); + + funcOp.walk([&](triton::ReduceOp reduce) { + MLIRContext *context = reduce.getContext(); + Value src = reduce.getOperands()[0]; + unsigned axis = reduce.getAxis(); + + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + auto srcLayout = srcTy.getEncoding(); + + auto rank = srcShape.size(); + auto order = ttg::getOrder(srcTy); + auto sizePerThread = ttg::getContigPerThread(srcTy); + auto CTAOrder = ttg::getCTAOrder(srcLayout); + + llvm::SmallVector CTAsPerCGA(rank, 0); + unsigned remainingCTAs = numCTAs; + for (int i = rank - 1; i >= 0; --i) { + unsigned dim = order[i]; + if (dim == axis) { + CTAsPerCGA[dim] = 1; + } else { + CTAsPerCGA[dim] = std::min(srcShape[dim] / sizePerThread[dim], + remainingCTAs); + remainingCTAs /= CTAsPerCGA[dim]; + } + } + + for (int i = rank - 1; i >= 0; --i) { + unsigned dim = order[i]; + if (dim != axis) { + CTAsPerCGA[dim] *= remainingCTAs; + break; + } + } + + llvm::SmallVector CTASplitNum = CTAsPerCGA; + + // If numCTAs > 1 and the only dimension is the reduced dimension, after the + // above two for-loops, CTAsPerCGA = [0] and remainingCTAs = numCTAs. We set + // CTAsPerCGA[0] = numCTAs and keep CTASplitNum[0] = 1 to ensure that no + // cross-CTA reduction is required, although this will introduce duplicated + // calculation + if (remainingCTAs > 0) + CTAsPerCGA[order[rank - 1]] *= remainingCTAs; + + auto CTALayout = + ttg::CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + if (!tiled) + setTiling(CTALayout.getCTAsPerCGA()); + auto newSrcLayout = replaceCTALayout( + cast(srcLayout), srcShape, CTALayout); + auto newResultLayout = + ttg::SliceEncodingAttr::get(context, axis, newSrcLayout); + unsigned numOperands = reduce.getNumOperands(); + SmallVector newSrcLayoutVec(numOperands, newSrcLayout); + SmallVector newResultLayoutVec(numOperands, newResultLayout); + + insertCasts(reduce.getOperation(), newSrcLayoutVec, newResultLayoutVec); + }); + return true; +} + +void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) { + assert(!tiled && "CTA tiling is already determinted"); + + llvm::SmallVector stores; + funcOp.walk([&](Operation *op) { + if (llvm::isa( + op)) + stores.push_back(op); + }); + assert(stores.size() > 0 && "Cannot find store-like ops"); + + ttg::CTALayoutAttr CTALayout; + for (Operation *store : stores) { + if (auto tensorTy = + dyn_cast(store->getOperand(0).getType())) { + if (!tiled) { + // Use CTA tiling of the first store-like op as global CTA tiling + CTALayout = ttg::getCTALayout(tensorTy.getEncoding()); + setTiling(CTALayout.getCTAsPerCGA()); + } + auto newLayout = replaceCTALayout( + cast(tensorTy.getEncoding()), + tensorTy.getShape(), CTALayout); + processElementwise(store, newLayout); + } + } + + // If all store-like ops are processing scalar values and no ReduceOp is + // found, we can conclude that this is an all-scalar computation, since + // ReduceOp is the only op that converts tensor values to scalar values. + if (!tiled) + setTiling({1, 1, 1}); +} + +bool CTAPlanner::propagate(CastOp cast) { + return isBackward(cast) ? propagateBackward(cast) : propagateForward(cast); +} + +bool CTAPlanner::propagateBackward(CastOp cast) { + Value input = cast.getOperand(0); + Value output = cast.getResult(0); + unsigned numUsers = getNumUsers(input); + if (numUsers == 0) { + llvm::report_fatal_error("Unreachable branch"); + return false; + } else if (numUsers == 1) { + Type outTy = output.getType(); + if (auto ptrTy = dyn_cast(outTy)) + outTy = ptrTy.getPointeeType(); + auto layout = mlir::cast( + mlir::cast(outTy).getEncoding()); + Operation *op = input.getDefiningOp(); + if (op == nullptr) { + assert(isa(input) && + "Unexpected Value without defining op"); + processBlockArgBackward(llvm::cast(input), cast); + } else if (auto prevCast = llvm::dyn_cast(op)) { + eliminateAdjacentCasts(prevCast, cast); + } else if (isLoadStoreOp(op)) { + processLoadStore(op, layout); + } else if (isElementwiseOp(op)) { + processElementwise(op, layout); + } else if (auto constant = llvm::dyn_cast(op)) { + processConstant(constant, layout); + } else if (auto splat = llvm::dyn_cast(op)) { + processSplat(splat, layout); + } else if (auto makeRange = llvm::dyn_cast(op)) { + processMakeRange(makeRange, layout); + } else if (auto makeTensorPtr = + llvm::dyn_cast(op)) { + processMakeTensorPtr(makeTensorPtr, layout); + } else if (llvm::isa(op)) { + // ptr operand and result have the same layout, while other operands are + // scalar values + processElementwise(op, layout); + } else if (auto broadcast = llvm::dyn_cast(op)) { + processBroadcast(broadcast, layout); + } else if (auto expandDims = llvm::dyn_cast(op)) { + processExpandDimsBackward(expandDims, layout); + } else if (auto ifOp = llvm::dyn_cast(op)) { + processIfOpBackward(ifOp, cast); + } else if (auto forOp = llvm::dyn_cast(op)) { + processForOpBackward(forOp, cast); + } else if (auto convertLayout = llvm::dyn_cast(op)) { + return processConvertLayoutBackward(convertLayout, cast); + } else { + // Keep original layouts. This may result in a loss of performance. + return processOpFallback(op); + } + return true; + } else { + return processMultiUsersBackward(input, cast); + } +} + +bool CTAPlanner::propagateForward(CastOp cast) { + Value input = cast.getOperand(0); + Value output = cast.getResult(0); + unsigned numUsers = getNumUsers(output); + if (numUsers == 0) { + cast.erase(); + } else if (numUsers == 1) { + Type inTy = input.getType(); + if (auto ptrTy = dyn_cast(inTy)) + inTy = ptrTy.getPointeeType(); + Attribute layout = mlir::cast(inTy).getEncoding(); + Operation *op = *output.user_begin(); + if (auto nextCast = llvm::dyn_cast(op)) { + eliminateAdjacentCasts(cast, nextCast); + } else if (isLoadStoreOp(op)) { + processLoadStore(op, layout); + } else if (isElementwiseOp(op)) { + processElementwise(op, layout); + } else if (llvm::isa(op)) { + // ptr operand and result have the same layout, while other operands are + // scalar values + processElementwise(op, layout); + } else if (auto convertLayout = llvm::dyn_cast(op)) { + return processConvertLayoutForward(convertLayout, cast); + } else if (auto forOp = llvm::dyn_cast(op)) { + processForOpForward(forOp, cast); + } else if (auto yieldOp = llvm::dyn_cast(op)) { + processYieldOpForward(yieldOp, cast); + } else { + // Keep original layouts. This may result in a loss of performance. + return processOpFallback(op); + } + } else { + processMultiUsersForward(output, cast); + } + return true; +} + +void CTAPlanner::eraseCastOp(CastOp cast) { + Value output = cast.getResult(0); + assert(getNumUsers(output) == 0 && + "Cannot erase CastOp because it is still in use"); + cast.erase(); +} + +void CTAPlanner::eraseCastOpFromQueue(CastOp cast) { + eraseCastOpsFromQueue({cast}); +} + +void CTAPlanner::eraseCastOpsFromQueue(llvm::ArrayRef casts) { + llvm::DenseSet erased; + for (CastOp cast : casts) { + eraseCastOp(cast); + erased.insert(cast); + } + + decltype(queue) tempQueue; + std::swap(queue, tempQueue); + + // This is only a naive implementation. Should refactor with linked-list. + while (!tempQueue.empty()) { + auto cast = tempQueue.front(); + tempQueue.pop(); + if (!erased.contains(cast)) + queue.push(cast); + } +} + +void CTAPlanner::insertCasts(Operation *op, + llvm::ArrayRef newOperandLayouts, + llvm::ArrayRef newResultLayouts) { + assert(op->getNumOperands() == newOperandLayouts.size() && + "NumOperands mismatched"); + assert(op->getNumResults() == newResultLayouts.size() && + "NumResults mismatched"); + + Location loc = op->getLoc(); + OpBuilder builder(op->getContext()); + + builder.setInsertionPoint(op); + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + auto operandTy = operand.getType(); + if (triton::isTensorOrTensorPointerType(operandTy)) { + operandTy = replaceLayout(operandTy, newOperandLayouts[i]); + auto cast = markBackward(builder.create(loc, operandTy, operand)); + op->setOperand(i, cast.getResult(0)); + queue.push(cast); + } + } + + builder.setInsertionPointAfter(op); + for (unsigned i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); + auto resultTy = result.getType(); + if (triton::isTensorOrTensorPointerType(resultTy)) { + resultTy = replaceLayout(resultTy, newResultLayouts[i]); + auto cast = + markForward(builder.create(loc, result.getType(), result)); + result.setType(resultTy); + result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation()); + queue.push(cast); + } + } +} + +void CTAPlanner::eliminateAdjacentCasts(CastOp cast0, CastOp cast1) { + assert(cast0.getResult(0) == cast1.getOperand(0) && + "The two casts are not adjacent"); + assert(isForward(cast0) && isBackward(cast1) && + "Expected pattern of adjacent casts: forward + backward"); + + Value input = cast0.getOperand(0); + Value output = cast1.getResult(0); + + if (input.getType() == output.getType()) { + output.replaceAllUsesWith(input); + eraseCastOpsFromQueue({cast1, cast0}); + } else { + OpBuilder builder(cast1.getOperation()); + auto cvt = builder.create(cast1.getLoc(), + output.getType(), input); + output.replaceAllUsesWith(cvt.getResult()); + eraseCastOpsFromQueue({cast1, cast0}); + } +} + +bool CTAPlanner::isLoadStoreOp(Operation *op) const { + return llvm::isa(op); +} + +bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) { + // Special logic for: + // LoadOp -> SliceLayout + // Transform to: + // LoadOp -> originalLayout -> ConvertLayout(DSmem) -> SliceLayout + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = ttg::getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] > 1) { + // Find an input or output value of LoadOp or StoreOp to get its layout + Value val = + op->getNumResults() > 0 ? op->getResult(0) : op->getOperand(0); + Attribute originalLayout = + cast(val.getType()).getEncoding(); + // Insert casts using originalLayout. Adjacent casts will be eliminated + // and generate a ConvertLayoutOp with DSmem access + return processLoadStore(op, originalLayout); + } + } + + auto CTALayout = ttg::getCTALayout(layout); + + llvm::SmallVector newOperandLayouts; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + auto type = op->getOperand(i).getType(); + if (auto ptrTy = dyn_cast(type)) + type = ptrTy.getPointeeType(); + auto tensorTy = cast(type); + auto oldLayout = + cast(tensorTy.getEncoding()); + auto newLayout = + replaceCTALayout(oldLayout, tensorTy.getShape(), CTALayout); + newOperandLayouts.push_back(newLayout); + } + + llvm::SmallVector newResultLayouts; + for (unsigned i = 0; i < op->getNumResults(); ++i) { + auto type = op->getResult(i).getType(); + if (auto ptrTy = dyn_cast(type)) + type = ptrTy.getPointeeType(); + auto tensorTy = cast(type); + auto oldLayout = + cast(tensorTy.getEncoding()); + auto newLayout = + replaceCTALayout(oldLayout, tensorTy.getShape(), CTALayout); + newResultLayouts.push_back(newLayout); + } + + insertCasts(op, newOperandLayouts, newResultLayouts); + return true; +} + +bool CTAPlanner::isElementwiseOp(Operation *op) const { + if (llvm::isa(op)) + return true; + if (llvm::isa(op)) + return true; + if (llvm::isa(op)) + return true; + if (auto externElementwiseOp = dyn_cast(op)) + return externElementwiseOp.getPure(); + if (llvm::isa(op)) + return true; + return false; +} + +bool CTAPlanner::processElementwise(Operation *op, Attribute layout) { + llvm::SmallVector newOperandLayouts(op->getNumOperands(), layout); + llvm::SmallVector newResultLayouts(op->getNumResults(), layout); + insertCasts(op, newOperandLayouts, newResultLayouts); + return true; +} + +bool CTAPlanner::processConstant(arith::ConstantOp constant, Attribute layout) { + if (auto tensorTy = dyn_cast(constant.getType())) { + if (auto attr = dyn_cast(constant.getValue())) { + + auto newTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), layout); + constant.setValueAttr( + SplatElementsAttr::get(newTensorTy, attr.getSplatValue())); + } + } + insertCasts(constant.getOperation(), {}, {layout}); + return true; +} + +bool CTAPlanner::processSplat(triton::SplatOp splat, Attribute layout) { + insertCasts(splat.getOperation(), {{}}, {layout}); + return true; +} + +bool CTAPlanner::processMakeRange(triton::MakeRangeOp makeRange, + Attribute layout) { + insertCasts(makeRange.getOperation(), {}, {layout}); + return true; +} + +bool CTAPlanner::processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr, + Attribute layout) { + // All inputs of `makeTensorPtr` are scalar types + llvm::SmallVector dummyInAttrs(makeTensorPtr.getNumOperands(), {}); + insertCasts(makeTensorPtr.getOperation(), dummyInAttrs, {layout}); + return true; +} + +bool CTAPlanner::processBroadcast(triton::BroadcastOp broadcast, + Attribute layout) { + insertCasts(broadcast.getOperation(), {layout}, {layout}); + return true; +} + +bool CTAPlanner::processExpandDimsBackward( + triton::ExpandDimsOp expandDims, + ttg::DistributedEncodingTrait newResultLayout) { + auto newSrcLayout = ttg::SliceEncodingAttr::get( + newResultLayout.getContext(), expandDims.getAxis(), newResultLayout); + insertCasts(expandDims.getOperation(), {newSrcLayout}, {newResultLayout}); + return true; +} + +bool CTAPlanner::processExpandDimsForward( + triton::ExpandDimsOp expandDims, + ttg::DistributedEncodingTrait newSrcLayout) { + llvm::report_fatal_error("processExpandDimsForward not implemented yet"); + return true; +} + +bool CTAPlanner::processConvertLayoutBackward( + ttg::ConvertLayoutOp convertLayout, CastOp cast) { + Value src = convertLayout.getSrc(); + Value result = convertLayout.getResult(); + assert(getNumUsers(result) == 1 && + "Expect to call processMultiUsersBackward first"); + result.replaceAllUsesWith(src); + convertLayout.erase(); + queue.push(cast); + return true; +} + +bool CTAPlanner::processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout, + CastOp cast) { + Value src = convertLayout.getSrc(); + Value result = convertLayout.getResult(); + assert(getNumUsers(src) == 1 && + "Expect to call processMultiUsersForward first"); + src.setType(result.getType()); + result.replaceAllUsesWith(src); + convertLayout.erase(); + queue.push(cast); + return true; +} + +bool CTAPlanner::processIfOp(scf::IfOp ifOp, int index, const Type &newType) { + // Check index + assert(index < ifOp.getNumResults() && "Invalid result index of IfOp"); + assert(index < ifOp.thenYield().getNumOperands() && + "Invalid operand index of YieldOp"); + assert(index < ifOp.elseYield().getNumOperands() && + "Invalid operand index of YieldOp"); + + Location loc = ifOp.getLoc(); + OpBuilder builder(ifOp.getContext()); + + // Insert forward cast after ifOp + Value result = ifOp.getResult(index); + builder.setInsertionPointAfter(ifOp.getOperation()); + auto newCast = + markForward(builder.create(loc, result.getType(), result)); + result.setType(newType); + result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + // Insert backward casts before yield + for (scf::YieldOp yield : {ifOp.thenYield(), ifOp.elseYield()}) { + Value yieldSrc = yield.getOperand(index); + builder.setInsertionPoint(yield.getOperation()); + newCast = markBackward(builder.create(loc, newType, yieldSrc)); + yield->setOperand(index, newCast.getResult(0)); + queue.push(newCast); + } + + return true; +} + +bool CTAPlanner::processForOp(scf::ForOp forOp, int index, + const Type &newType) { + Block *body = forOp.getBody(); + auto yield = llvm::cast(forOp.getBody()->getTerminator()); + + // Check index + assert(index + forOp.getNumControlOperands() < forOp.getNumOperands() && + "Invalid operand index of ForOp"); + assert(index + forOp.getNumInductionVars() < body->getNumArguments() && + "Invalid block arg index of ForOp"); + assert(index < yield.getNumOperands() && "Invalid operand index of YieldOp"); + assert(index < forOp.getNumResults() && "Invalid result index of IfOp"); + + Location loc = forOp.getLoc(); + OpBuilder builder(forOp.getContext()); + + // Insert backward cast before forOp + OpOperand &operand = + forOp->getOpOperand(index + forOp.getNumControlOperands()); + builder.setInsertionPoint(forOp.getOperation()); + auto newCast = + markBackward(builder.create(loc, newType, operand.get())); + operand.set(newCast.getResult(0)); + queue.push(newCast); + + // Insert forward cast after block arg + Value arg = body->getArgument(index + forOp.getNumInductionVars()); + builder.setInsertionPointToStart(body); + newCast = markForward(builder.create(loc, arg.getType(), arg)); + arg.setType(newType); + arg.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + // Insert backward cast before yield + Value yieldSrc = yield.getOperand(index); + builder.setInsertionPoint(yield.getOperation()); + newCast = markBackward(builder.create(loc, newType, yieldSrc)); + yield->setOperand(index, newCast.getResult(0)); + queue.push(newCast); + + // Insert forward cast after forOp + Value result = forOp.getResult(index); + builder.setInsertionPointAfter(forOp.getOperation()); + newCast = markForward(builder.create(loc, result.getType(), result)); + result.setType(newType); + result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + return true; +} + +int findResultIndex(Operation *op, Value result) { + for (int i = 0; i < op->getNumResults(); ++i) + if (op->getResult(i) == result) + return i; + llvm::report_fatal_error("Invalid index of op result"); + return -1; +} + +bool CTAPlanner::processIfOpBackward(scf::IfOp ifOp, CastOp cast) { + int index = findResultIndex(ifOp.getOperation(), cast.getOperand(0)); + auto newType = cast.getResult(0).getType(); + return processIfOp(ifOp, index, newType); +} + +bool CTAPlanner::processForOpBackward(scf::ForOp forOp, CastOp cast) { + int index = findResultIndex(forOp.getOperation(), cast.getOperand(0)); + auto newType = cast.getResult(0).getType(); + return processForOp(forOp, index, newType); +} + +bool CTAPlanner::processBlockArgBackward(BlockArgument arg, CastOp cast) { + if (auto forOp = llvm::dyn_cast(arg.getOwner()->getParentOp())) { + int index = int(arg.getArgNumber()) - forOp.getNumInductionVars(); + auto newType = cast.getResult(0).getType(); + return processForOp(forOp, index, newType); + } else { + llvm::report_fatal_error("Unexpected parent op of block argument"); + return true; + } +} + +bool CTAPlanner::processForOpForward(scf::ForOp forOp, CastOp cast) { + int index = cast.getResult(0).use_begin()->getOperandNumber() - + forOp.getNumControlOperands(); + auto newType = cast.getOperand(0).getType(); + return processForOp(forOp, index, newType); +} + +bool CTAPlanner::processYieldOpForward(scf::YieldOp yieldOp, CastOp cast) { + int index = cast.getResult(0).use_begin()->getOperandNumber(); + auto newType = cast.getOperand(0).getType(); + if (auto ifOp = llvm::dyn_cast(yieldOp->getParentOp())) + return processIfOp(ifOp, index, newType); + else if (auto forOp = llvm::dyn_cast(yieldOp->getParentOp())) + return processForOp(forOp, index, newType); + else + llvm::report_fatal_error("Unexpected parent op of YieldOp"); + return true; +} + +bool CTAPlanner::processOpFallback(Operation *op) { + Location loc = op->getLoc(); + OpBuilder builder(op->getContext()); + + builder.setInsertionPoint(op); + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + auto operandTy = operand.getType(); + if (triton::isTensorOrTensorPointerType(operandTy)) { + auto cast = markBackward(builder.create(loc, operandTy, operand)); + op->setOperand(i, cast.getResult(0)); + queue.push(cast); + } + } + + builder.setInsertionPointAfter(op); + for (unsigned i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); + auto resultTy = result.getType(); + if (triton::isTensorOrTensorPointerType(resultTy)) { + auto cast = markForward(builder.create(loc, resultTy, result)); + result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation()); + queue.push(cast); + } + } + + return true; +} + +bool CTAPlanner::processMultiUsersBackward(Value input, CastOp cast) { + Location loc = input.getLoc(); + OpBuilder builder(input.getContext()); + + llvm::DenseMap> typeToIndices; + for (OpOperand &operand : input.getUses()) { + auto brotherCast = llvm::dyn_cast(operand.getOwner()); + if (!brotherCast) { + if (stepUnchanged <= queue.size()) + return false; + builder.setInsertionPoint(operand.getOwner()); + brotherCast = markBackward( + builder.create(loc, cast.getResult(0).getType(), input)); + auto newCast = markForward(builder.create( + loc, input.getType(), brotherCast.getResult(0))); + operand.set(newCast.getResult(0)); + queue.push(brotherCast); + queue.push(newCast); + } + auto type = brotherCast.getResult(0).getType(); + typeToIndices[type].push_back(brotherCast); + } + + bool first = true; + for (auto it : typeToIndices) { + Type &type = it.first; + llvm::SmallVector &casts = it.second; + Value newInput = input; + if (!first) { + if (Operation *defOp = input.getDefiningOp()) { + builder.setInsertionPointAfter(defOp); + Operation *clonedOp = builder.clone(*defOp); + newInput = clonedOp->getResult(0); + } else { + llvm::report_fatal_error("Layout conflict for block arg"); // TODO + return false; + } + } + first = false; + if (Operation *defOp = newInput.getDefiningOp()) { + builder.setInsertionPointAfter(defOp); + } else { + assert(isa(newInput) && + "Unexpected Value without defining op"); + builder.setInsertionPointToStart( + llvm::cast(newInput).getOwner()); + } + auto newCast = markBackward(builder.create(loc, type, newInput)); + queue.push(newCast); + auto newResult = newCast.getResult(0); + for (CastOp &brotherCast : casts) { + brotherCast.getResult(0).replaceAllUsesWith(newResult); + eraseCastOpFromQueue(brotherCast); + } + } + return true; +} + +bool CTAPlanner::processMultiUsersForward(Value castResult, CastOp cast) { + Value castSrc = cast.getOperand(0); + + Location loc = cast.getLoc(); + OpBuilder builder(cast.getContext()); + builder.setInsertionPointAfter(cast.getOperation()); + + while (!castResult.use_empty()) { + auto newCast = + markForward(builder.create(loc, castResult.getType(), castSrc)); + castResult.use_begin()->set(newCast.getResult(0)); + queue.push(newCast); + } + + eraseCastOp(cast); + return true; +} + +struct PlanCTAPass : public TritonGPUPlanCTAPassBase { + PlanCTAPass(ttng::ClusterInfo *clusterInfo_ = nullptr) + : clusterInfo(clusterInfo_) {} + + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // Skip PlanCTAPass when numCTAs == 1 + if (ttg::TritonGPUDialect::getNumCTAs(mod) == 1) + return; + + mod.walk([&](triton::FuncOp funcOp) { + CTAPlanner planner(clusterInfo); + planner.run(funcOp); + + // FIXME: Clone funcOp so that the IR change can be identified after + // PlanCTAPass. Without this, the change after PlanCTAPass will not be + // displayed when MLIR_ENABLE_DUMP=1. This is not reasonable and should + // be fixed later. + OpBuilder builder(funcOp); + builder.clone(*funcOp.getOperation()); + funcOp.erase(); + }); + } + + ttng::ClusterInfo *clusterInfo; +}; + +} // namespace + +std::unique_ptr +mlir::createTritonNvidiaGPUPlanCTAPass(ttng::ClusterInfo *clusterInfo) { + return std::make_unique(clusterInfo); +} + +/* TODO + * - Use ConvertLayoutOp instead of UnrealizedConversionCastOp. + * - Move PlanCTAPass to the front of CoalescePass. + * - Design better tiling strategy for DotOp and ReduceOp. + * - Consider cases where there are more than one DotOps. + * - Use better data structure for erasing CastOps from queue (linked list?). + * - Process eliminable CastOps in higher priority. + * - Fix the clone func bug in PlanCTAPass::runOnOperation. + * - Add some comments to introduce the overall idea of this pass. + * - Add some lit tests for this pass. + */ diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp new file mode 100644 index 000000000..342698c11 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp @@ -0,0 +1,124 @@ +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include + +namespace { + +using namespace mlir; + +namespace ttng = triton::nvidia_gpu; +namespace ttg = triton::gpu; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { +template +Attribute getLHSTMemLayout(MMAOpTy tcGen5MMAOp, + ttg::BlockedEncodingAttr srcLayout) { + auto CTALayout = getCTALayout(srcLayout); + int numWarps = ttg::lookupNumWarps(tcGen5MMAOp); + auto accTmemEncoding = dyn_cast( + tcGen5MMAOp.getD().getType().getEncoding()); + auto lhs = tcGen5MMAOp.getA(); + auto lhsShape = lhs.getType().getShape(); + // M has to follow the MMA size, as it is related to the message we are using. + // N has to follow the number of columns in the LHS. + int M = accTmemEncoding.getBlockM(); + int N = lhsShape[1]; + Attribute resLayout = + ttng::getTmemCompatibleLayout(M, N, lhsShape, numWarps, CTALayout); + return resLayout; +} + +template class LHSToTMem : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MMAOpTy tcGen5MMAOp, + PatternRewriter &rewriter) const override { + MLIRContext *context = tcGen5MMAOp->getContext(); + Location loc = tcGen5MMAOp.getLoc(); + auto lhs = tcGen5MMAOp.getA(); + auto localAllocOp = lhs.template getDefiningOp(); + if (!localAllocOp) + return failure(); + // Limit the liverange of the TMem allocations to single block. + if (localAllocOp->getParentRegion() != tcGen5MMAOp->getParentRegion()) + return failure(); + Value src = localAllocOp.getSrc(); + auto srcType = cast(src.getType()); + auto srcLayout = cast(srcType.getEncoding()); + bool layoutTmemCompatible = ttng::isDistributedLayoutTMemCompatible( + tcGen5MMAOp, srcType, tcGen5MMAOp.getD().getType()); + Attribute newLayout = srcLayout; + if (!layoutTmemCompatible) { + if (triton::tools::getBoolEnv("ALLOW_LHS_TMEM_LAYOUT_CONVERSION")) { + newLayout = getLHSTMemLayout(tcGen5MMAOp, srcLayout); + } else { + return failure(); + } + } + rewriter.setInsertionPointAfter(localAllocOp); + if (newLayout != srcLayout) { + auto ty = cast(src.getType()); + auto newTy = + RankedTensorType::get(ty.getShape(), ty.getElementType(), newLayout); + src = rewriter.create(loc, newTy, src); + } + auto accTMemEncoding = dyn_cast( + tcGen5MMAOp.getD().getType().getEncoding()); + ArrayRef CTASplitNum = srcLayout.getCTALayout().getCTASplitNum(); + // TMem encoding for A operand is the same as for D (Acc), but unpacked. + auto aTMemEncoding = ttng::TensorMemoryEncodingAttr::get( + context, accTMemEncoding.getBlockM(), lhs.getType().getShape()[1], + /*unpacked=*/false, CTASplitNum[0], CTASplitNum[1]); + Attribute tensorMemorySpace = + triton::nvidia_gpu::TensorMemorySpaceAttr::get(context); + Type lhsMemDescType = triton::gpu::MemDescType::get( + lhs.getType().getShape(), lhs.getType().getElementType(), aTMemEncoding, + tensorMemorySpace, + /*mutableMemory=*/false); + Value tMemAlloc = + rewriter.create(loc, lhsMemDescType, src); + tcGen5MMAOp.getAMutable().assign(tMemAlloc); + return success(); + } +}; +} // namespace + +class TritonNvidiaGPUPromoteLHSToTMemPass + : public TritonNvidiaGPUPromoteLHSToTMemPassBase< + TritonNvidiaGPUPromoteLHSToTMemPass> { +public: + using TritonNvidiaGPUPromoteLHSToTMemPassBase< + TritonNvidiaGPUPromoteLHSToTMemPass>:: + TritonNvidiaGPUPromoteLHSToTMemPassBase; + + void runOnOperation() override { + if (!triton::tools::getBoolEnv("ENABLE_LHS_TO_TMEM")) + return; + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + RewritePatternSet patterns(context); + patterns.add>(context); + patterns.add>(context); + if (applyPatternsGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr mlir::createTritonNvidiaGPUPromoteLHSToTMemPass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp new file mode 100644 index 000000000..2c0bf093d --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -0,0 +1,216 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" + +#include + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; +using namespace triton::nvidia_gpu; + +static void +lowerTMALoad(Operation *op, RankedTensorType tensorType, Value desc, + function_ref createLoad, + PatternRewriterWithAsyncTaskIds &rewriter, + PatternRewriter &baseRewriter) { + MLIRContext *ctx = op->getContext(); + Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); + auto loc = op->getLoc(); + auto order = getOrder(tensorType); + auto ctaLayout = getCTALayout(tensorType.getEncoding()); + Attribute encoding = SwizzledSharedEncodingAttr::get( + tensorType.getContext(), 1, 1, 1, order, ctaLayout); + if (tensorType.getRank() > 1) { + encoding = NVMMASharedEncodingAttr::get( + tensorType.getContext(), tensorType.getShape(), order, ctaLayout, + tensorType.getElementType(), /*fp4Padded*/ false); + } + MemDescType memDescType = + MemDescType::get(tensorType.getShape(), tensorType.getElementType(), + encoding, sharedMemorySpace, /*mutableMemory=*/true); + Value alloc = rewriter.create(loc, memDescType); + auto barrierCTALayout = CTALayoutAttr::get( + /*context=*/tensorType.getContext(), /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = SwizzledSharedEncodingAttr::get( + tensorType.getContext(), 1, 1, 1, {0}, barrierCTALayout); + MemDescType barrierMemDescType = + MemDescType::get({1}, baseRewriter.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value barrierAlloc = rewriter.create(loc, barrierMemDescType); + rewriter.create(loc, barrierAlloc, 1); + int sizeInBytes = product(tensorType.getShape()) * + tensorType.getElementType().getIntOrFloatBitWidth() / 8; + Value pred = rewriter.create(loc, 1, 1); + rewriter.create(loc, barrierAlloc, + sizeInBytes, pred); + Value tmaPtr = + rewriter.create(loc, desc); + createLoad(tmaPtr, barrierAlloc, alloc, pred); + Value phase = rewriter.create(loc, 0, 32); + rewriter.create(loc, barrierAlloc, phase); + rewriter.create(loc, barrierAlloc); + rewriter.replaceOpWithNewOp(op, tensorType, alloc); +} + +class TMALoadLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExperimentalDescriptorLoadOp op, + PatternRewriter &baseRewriter) const override { + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); + auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc, + Value pred) { + rewriter.create( + op.getLoc(), tmaPtr, op.getIndices(), barrierAlloc, alloc, pred); + }; + lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter, + baseRewriter); + return success(); + } +}; + +struct TMAGatherLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExperimentalDescriptorGatherOp op, + PatternRewriter &baseRewriter) const override { + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); + auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc, + Value pred) { + rewriter.create( + op.getLoc(), tmaPtr, op.getXOffsets(), op.getYOffset(), barrierAlloc, + alloc, pred); + }; + lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter, + baseRewriter); + return success(); + } +}; + +static void lowerTMAStore(Operation *op, mlir::TypedValue src, + Value desc, + function_ref createStore, + PatternRewriterWithAsyncTaskIds &rewriter, + PatternRewriter &baseRewriter) { + MLIRContext *ctx = op->getContext(); + Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); + auto loc = op->getLoc(); + auto tensorType = src.getType(); + auto order = getOrder(tensorType); + auto ctaLayout = getCTALayout(tensorType.getEncoding()); + Attribute encoding = SwizzledSharedEncodingAttr::get( + tensorType.getContext(), 1, 1, 1, order, ctaLayout); + if (tensorType.getRank() > 1) { + encoding = NVMMASharedEncodingAttr::get( + tensorType.getContext(), tensorType.getShape(), order, ctaLayout, + tensorType.getElementType(), /*fp4Padded*/ false); + } + MemDescType memDescType = + MemDescType::get(tensorType.getShape(), tensorType.getElementType(), + encoding, sharedMemorySpace, /*mutableMemory=*/true); + Value alloc = rewriter.create(loc, memDescType, src); + rewriter.create(loc, false); + Value tmaPtr = + rewriter.create(loc, desc); + createStore(tmaPtr, alloc); + rewriter.create(loc, 0); + baseRewriter.eraseOp(op); +} + +struct TMAStoreLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExperimentalDescriptorStoreOp op, + PatternRewriter &baseRewriter) const override { + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); + auto createStore = [&](Value tmaPtr, Value alloc) { + rewriter.create( + op.getLoc(), tmaPtr, op.getIndices(), alloc); + }; + lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter, + baseRewriter); + return success(); + } +}; + +struct TMAScatterLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExperimentalDescriptorScatterOp op, + PatternRewriter &baseRewriter) const override { + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); + auto createStore = [&](Value tmaPtr, Value alloc) { + rewriter.create( + op.getLoc(), tmaPtr, op.getXOffsets(), op.getYOffset(), alloc); + }; + lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter, + baseRewriter); + return success(); + } +}; + +class TMACreateDescLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MakeTensorDescOp op, + PatternRewriter &baseRewriter) const override { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); + auto alloc = rewriter.create( + loc, getPointerType(baseRewriter.getI8Type()), TMA_SIZE_BYTES, + TMA_ALIGN); + if (failed(createTMADesc(alloc, op, baseRewriter))) { + return failure(); + } + rewriter.create( + loc, alloc.getResult()); + auto newDesc = rewriter.create( + loc, op.getType(), alloc.getResult()); + baseRewriter.replaceOp(op, newDesc); + return success(); + } +}; + +class TritonNvidiaGPUTMALoweringPass + : public TritonNvidiaGPUTMALoweringPassBase< + TritonNvidiaGPUTMALoweringPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::createTritonNvidiaGPUTMALoweringPass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp new file mode 100644 index 000000000..4e92293fd --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp @@ -0,0 +1,332 @@ +#include "mlir/Analysis/Liveness.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/MapVector.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; +using namespace triton::nvidia_gpu; + +namespace { + +// Granularity of row allocations. +static constexpr int allocGranularity = 64; +struct TMemChunk { + int startRow; + int startCol; + int numCols; + int numRows; +}; + +// Use a simple bitmap to track memory usage. This is a slow but it allows us to +// handle 2D memory without extra algorithmic complexity. The number of +// allocations is expected to be small so the compile time is unlikely to be a +// problem. +struct MemoryBitMap { + MemoryBitMap() : elements(512 * kNumRows, false) {} + void free(const TMemChunk &chunk) { + for (int i = 0; i < chunk.numCols; i++) { + for (int j = 0; j < chunk.numRows; j++) { + setUsed(chunk.startRow + j, chunk.startCol + i, false); + } + } + } + void alloc(const TMemChunk &chunk) { + // Ensure the underlying data fits the allocation. + while ((chunk.startCol + chunk.numCols) * kNumRows >= elements.size()) + elements.resize(2 * elements.size(), false); + + for (int i = 0; i < chunk.numCols; i++) { + for (int j = 0; j < chunk.numRows; j++) { + setUsed(chunk.startRow + j, chunk.startCol + i, true); + } + } + } + + TMemChunk findFirstFit(TMemAllocation allocSize, + std::optional rowIdConstraint) const { + int numRows = allocSize.numRows / allocGranularity; + assert(kNumRows - numRows >= 0); + assert(allocSize.numRows % allocGranularity == 0); + int startCol = 0; + while (1) { + // Iterate over possible starting rows + for (int startRow = 0; startRow <= kNumRows - numRows; ++startRow) { + if (rowIdConstraint && *rowIdConstraint != startRow) + continue; + bool fits = true; + + // Check if the block starting at (startRow, startCol) is free + for (int i = 0; i < allocSize.numCols && fits; ++i) { + for (int j = 0; j < numRows; ++j) { + if (isUsed(startRow + j, startCol + i)) { + fits = false; + break; + } + } + } + + // If a suitable block is found, return it + if (fits) { + TMemChunk chunk; + chunk.startRow = startRow; + chunk.startCol = startCol; + chunk.numRows = numRows; + chunk.numCols = allocSize.numCols; + return chunk; + } + } + startCol++; + } + return TMemChunk(); + } + +private: + bool isUsed(int row, int col) const { + if (row + col * kNumRows >= elements.size()) + return false; + return elements[row + col * kNumRows]; + } + void setUsed(int row, int col, bool used) { + assert(row + col * kNumRows < elements.size()); + elements[row + col * kNumRows] = used; + } + + static constexpr int kNumRows = 2; + std::vector elements; +}; + +static Interval getLiveIntervals(Value value, Liveness &liveness, + DenseMap &operationId) { + auto liveOperations = liveness.resolveLiveness(value); + // Merge the alloc liverange with the liverange of any subview of the + // allocation. + SmallVector users(value.getUsers()); + while (!users.empty()) { + Operation *user = users.pop_back_val(); + if (!isa(user)) + continue; + auto usersLivness = liveness.resolveLiveness(user->getResult(0)); + liveOperations.insert(liveOperations.end(), usersLivness.begin(), + usersLivness.end()); + users.append(user->getResult(0).getUsers().begin(), + user->getResult(0).getUsers().end()); + } + auto minId = std::numeric_limits::max(); + auto maxId = std::numeric_limits::min(); + std::for_each(liveOperations.begin(), liveOperations.end(), + [&](Operation *liveOp) { + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if ((operationId[liveOp] + 1) > maxId) { + maxId = operationId[liveOp] + 1; + } + }); + return Interval(minId, maxId); +} + +static void updateMap(MemoryBitMap &memoryMap, Interval liveInterval, + std::map &intervalLiverangeEnd) { + int start = liveInterval.start(); + // Add any dead liverange to the list of free intervals. + for (auto it = intervalLiverangeEnd.begin(); + it != intervalLiverangeEnd.end();) { + if (it->first > start) + break; + memoryMap.free(it->second); + it = intervalLiverangeEnd.erase(it); + } +} + +static TMemChunk allocFirstFit(MemoryBitMap &memoryMap, + TMemAllocation allocSize, + std::optional rowIdConstraint, + ArrayRef coexistingChunks) { + // `coexistingChunks` are all the allocations that might need to be live at + // the same time as the current allocation plus what is known to be currently + // live. Union those allocations with a copy of the current memory map and use + // that to find the actual offsets. + MemoryBitMap mapForAlloc = memoryMap; + for (const TMemChunk &chunk : coexistingChunks) + mapForAlloc.alloc(chunk); + TMemChunk chunk = mapForAlloc.findFirstFit(allocSize, rowIdConstraint); + + // Mark this chunk as allocated in the actual memory map. + memoryMap.alloc(chunk); + return chunk; +} + +static Operation *getAlloc(Value value) { + Operation *op = value.getDefiningOp(); + while (isa(op)) { + op = op->getResult(0).getDefiningOp(); + } + assert(isa(op) && "Expected a TMEMAllocOp"); + return op; +} + +class RowIdConstraints { + llvm::EquivalenceClasses dependentAllocs; + llvm::SmallDenseMap rowIndex; + +public: + void joinOps(Operation *op1, Operation *op2) { + dependentAllocs.unionSets(op1, op2); + } + + std::optional getRowIdConstraint(Operation *op) { + auto it = dependentAllocs.findLeader(op); + if (it == dependentAllocs.member_end()) + return std::nullopt; + auto rowIt = rowIndex.find(*it); + if (rowIt == rowIndex.end()) + return std::nullopt; + return rowIt->second; + } + + void addConstraints(Operation *op, int rowId) { + auto it = dependentAllocs.findLeader(op); + if (it == dependentAllocs.member_end()) + return; + rowIndex[*it] = rowId; + } +}; + +static int +allocateTMem(Operation *parentOp, + DenseMap &offsets) { + SmallVector allocs; + DenseMap operationId; + RowIdConstraints rowIdConstraints; + parentOp->walk([&](Operation *op) { + operationId[op] = operationId.size(); + if (auto alloc = dyn_cast(op)) { + allocs.push_back(alloc); + } + if (auto mmaOp = dyn_cast(op)) { + if (isa( + mmaOp.getA().getType().getEncoding())) { + TMemAllocation allocSize = getTmemAllocSizes(mmaOp.getA().getType()); + if (allocSize.numRows == 64) { + // HW restriction, the A alloc and accumulator needs to be in the same + // rows. + rowIdConstraints.joinOps(getAlloc(mmaOp.getA()), + getAlloc(mmaOp.getD())); + } else { + // TODO: we need to handle cases where the format is blockM and we + // have multiple blocks. + assert((cast( + mmaOp.getA().getType().getEncoding()) + .getBlockM() != 64 && + cast( + mmaOp.getD().getType().getEncoding()) + .getBlockM() != 64) && + "interleaved layout with TMEM operand is not supported yet."); + } + } + } + }); + int totalMemorySize = 0; + MemoryBitMap memoryMap; + Liveness liveness(parentOp); + std::map intervalLiverangeEnd; + DenseMap allocChunks; + // Implement a linear scan first fit algorithm. We expect that fragmentation + // won't be a problem, if it is this should be revisited. + for (auto it = allocs.begin(), e = allocs.end(); it != e; ++it) { + TMEMAllocOp alloc = *it; + + // Find all allocations in code that may execute at the same time. Only look + // at processed allocations. + SmallVector coexistingChunks; + if (auto ws = alloc->getParentOfType()) { + for (auto prevIt = allocs.begin(); prevIt != it; ++prevIt) { + TMEMAllocOp prevAlloc = *prevIt; + auto prevWs = prevAlloc->getParentOfType(); + if (prevWs && prevWs == ws && + alloc->getParentRegion() != prevAlloc->getParentRegion()) + coexistingChunks.push_back(allocChunks.at(prevAlloc)); + } + } + + Interval liveInterval = getLiveIntervals(alloc, liveness, operationId); + auto memDescType = alloc.getType(); + TMemAllocation allocSize = getTmemAllocSizes(memDescType); + updateMap(memoryMap, liveInterval, intervalLiverangeEnd); + + std::optional rowIdConstraint = + rowIdConstraints.getRowIdConstraint(alloc); + TMemChunk chunkAllocated = + allocFirstFit(memoryMap, allocSize, rowIdConstraint, coexistingChunks); + allocChunks.insert({alloc, chunkAllocated}); + // currently naively constraint allocs based on the first one we find. + rowIdConstraints.addConstraints(alloc, chunkAllocated.startRow); + intervalLiverangeEnd[liveInterval.end()] = chunkAllocated; + int colOffset = chunkAllocated.startCol; + int rowOffset = chunkAllocated.startRow * 16; + + alloc->setAttr( + "tensor_memory_col_offset", + IntegerAttr::get(IntegerType::get(parentOp->getContext(), 32), + colOffset)); + alloc->setAttr( + "tensor_memory_row_offset", + IntegerAttr::get(IntegerType::get(parentOp->getContext(), 32), + rowOffset)); + totalMemorySize = std::max(totalMemorySize, colOffset + allocSize.numCols); + } + return totalMemorySize; +} + +class TritionTensorMemoryAllocationPass + : public TritionTensorMemoryAllocationPassBase< + TritionTensorMemoryAllocationPass> { +public: + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + + DenseMap offsets; + // TODO: handle cases with multiple function with TMEMAllocOp. + int totalMemorySize = allocateTMem(mod, offsets); + + std::array possibleAllocations = {0, 32, 64, 128, 256, 512}; + if (totalMemorySize <= 512) { + for (int size : possibleAllocations) { + if (totalMemorySize <= size) { + totalMemorySize = size; + break; + } + } + } + // if totalMemorySize > 512 we exceeded the maximum amount of tensor memory, + // let the compilation finish so that we can raise an exception in python + // for auto-tuner. + if (totalMemorySize > 0) { + assert(mod->getAttr("ttg.shared") != nullptr && + cast(mod->getAttr("ttg.shared")).getInt() != 0 && + "Shared memory is required for allocation of Tensor Core memory."); + } + + mod->setAttr("ttg.tensor_memory_size", + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), + totalMemorySize)); + } +}; + +} // namespace + +std::unique_ptr mlir::createTensorMemoryAllocationPass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp new file mode 100644 index 000000000..ba46d52eb --- /dev/null +++ b/third_party/enflame/include/triton/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp @@ -0,0 +1,88 @@ + +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" +#include + +namespace mlir { + +namespace ttg = triton::gpu; + +//===----------------------------------------------------------------------===// +// Helper functions for async task +//===----------------------------------------------------------------------===// + +SmallVector getAsyncTaskIds(Operation *op) { + SmallVector asyncTaskIds; + if (auto attr = op->getAttrOfType("async_task_id")) + for (AsyncTaskId asyncTaskId : attr.getValues()) + asyncTaskIds.push_back(asyncTaskId); + return asyncTaskIds; +} + +bool hasAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId) { + for (AsyncTaskId candidate : getAsyncTaskIds(op)) + if (candidate == asyncTaskId) + return true; + return false; +} + +void setAsyncTaskIds(Operation *op, ArrayRef asyncTaskIds) { + SmallVector sortedAsyncTaskIds(asyncTaskIds.begin(), + asyncTaskIds.end()); + sort(sortedAsyncTaskIds); + auto i32Ty = IntegerType::get(op->getContext(), 32); + auto size = static_cast(sortedAsyncTaskIds.size()); + auto vecTy = VectorType::get(size, i32Ty); + op->setAttr("async_task_id", + DenseIntElementsAttr::get(vecTy, sortedAsyncTaskIds)); +} + +SmallVector getNestedAsyncTaskIds(Operation *op) { + SetVector asyncTaskIds; + op->walk([&](Operation *curOp) { + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(curOp)) + asyncTaskIds.insert(asyncTaskId); + }); + SmallVector res(asyncTaskIds.begin(), asyncTaskIds.end()); + llvm::sort(res); + return res; +} + +void addAsyncTaskIds(Operation *op, ArrayRef asyncTasks) { + auto asyncTasksVec = getAsyncTaskIds(op); + DenseSet asyncTasksSet(asyncTasksVec.begin(), asyncTasksVec.end()); + for (int a : asyncTasks) { + if (!asyncTasksSet.contains(a)) { + asyncTasksVec.push_back(a); + } + } + if (asyncTasksVec.size() > 0) { + setAsyncTaskIds(op, asyncTasksVec); + } +} + +void removeAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId) { + auto origAsyncTaskIds = getAsyncTaskIds(op); + auto end = std::remove(origAsyncTaskIds.begin(), origAsyncTaskIds.end(), + asyncTaskId); + origAsyncTaskIds.erase(end, origAsyncTaskIds.end()); + if (origAsyncTaskIds.empty()) + op->removeAttr("async_task_id"); + else + setAsyncTaskIds(op, origAsyncTaskIds); +} + +void removeAsyncTaskIds(Operation *op) { op->removeAttr("async_task_id"); } +//===----------------------------------------------------------------------===// +// Implementations for general auto WS +//===----------------------------------------------------------------------===// + +} // namespace mlir diff --git a/third_party/enflame/include/triton/lib/Instrumentation/CMakeLists.txt b/third_party/enflame/include/triton/lib/Instrumentation/CMakeLists.txt new file mode 100644 index 000000000..6e6da2351 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Instrumentation/CMakeLists.txt @@ -0,0 +1,42 @@ +set(GPU_INSTRUMENTATION_PASSES + PrintLoadStoreMemSpaces + ) + +set(PrintLoadStoreMemSpaces_SOURCES + PrintLoadStoreMemSpaces.cpp + ) + + +foreach( plugin ${GPU_INSTRUMENTATION_PASSES} ) + add_library( + ${plugin} + SHARED + ${${plugin}_SOURCES} + ) + + target_link_libraries( + ${plugin} + PRIVATE + LLVMCore + LLVMSupport + LLVMTransformUtils + "$<$:-undefined dynamic_lookup>" + ) + # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python + # build. It is empty if building directly from the root + # CMakeLists.txt file. Therefore if not building from Python just + # use the default CMake shared lib path otherwise this causes a hard + # build error + if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + set_target_properties(${plugin} PROPERTIES + LIBRARY_OUTPUT_DIRECTORY + "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation") + endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + + # This is set to -fvisibility=hidden in the top level CMake file + # which causes the llvmGetPassPluginInfo symbol to be hidden and + # an "entry point not found" error. Reset it just for this target + if(NOT MSVC) + target_compile_options(${plugin} PRIVATE -fvisibility=default -fno-rtti) + endif() +endforeach() diff --git a/third_party/enflame/include/triton/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp b/third_party/enflame/include/triton/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp new file mode 100644 index 000000000..c243fc149 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp @@ -0,0 +1,101 @@ +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" +#include + +using namespace llvm; + +namespace { + +struct LoadStoreMemSpace : public PassInfoMixin { + PreservedAnalyses run(llvm::Module &module, ModuleAnalysisManager &) { + bool modifiedCodeGen = runOnModule(module); + + return (modifiedCodeGen ? llvm::PreservedAnalyses::none() + : llvm::PreservedAnalyses::all()); + } + bool runOnModule(llvm::Module &module); + // isRequired being set to true keeps this pass from being skipped + // if it has the optnone LLVM attribute + static bool isRequired() { return true; } +}; + +} // end anonymous namespace + +std::map AddrSpaceMap = { + {0, "FLAT"}, {1, "GLOBAL"}, {3, "SHARED"}, {4, "CONSTANT"}, {5, "SCRATCH"}}; + +std::map LocationCounterSourceMap; + +std::string LoadOrStoreMap(const BasicBlock::iterator &I) { + if (LoadInst *LI = dyn_cast(I)) + return "LOAD"; + else if (StoreInst *SI = dyn_cast(I)) + return "STORE"; + else + throw std::runtime_error("Error: unknown operation type"); +} +template +void InstrumentationFunction(const BasicBlock::iterator &I, const Function &F, + const llvm::Module &M, uint32_t &LocationCounter) { + auto LSI = dyn_cast(I); + if (not LSI) + return; + Value *Op = LSI->getPointerOperand()->stripPointerCasts(); + uint32_t AddrSpace = cast(Op->getType())->getAddressSpace(); + DILocation *DL = dyn_cast(I)->getDebugLoc(); + + std::string SourceAndAddrSpaceInfo = + (F.getName() + " " + DL->getFilename() + ":" + Twine(DL->getLine()) + + ":" + Twine(DL->getColumn())) + .str() + + " " + AddrSpaceMap[AddrSpace] + " " + LoadOrStoreMap(I); + + if (LocationCounterSourceMap.find(SourceAndAddrSpaceInfo) == + LocationCounterSourceMap.end()) { + errs() << LocationCounter << " " << SourceAndAddrSpaceInfo << "\n"; + LocationCounterSourceMap[SourceAndAddrSpaceInfo] = LocationCounter; + LocationCounter++; + } +} + +bool LoadStoreMemSpace::runOnModule(Module &M) { + bool ModifiedCodeGen = false; + uint32_t LocationCounter = 0; + for (auto &F : M) { + if (F.isIntrinsic()) + continue; + StringRef functionName = F.getName(); + if (F.getCallingConv() == CallingConv::AMDGPU_KERNEL || + F.getCallingConv() == CallingConv::PTX_Kernel || + functionName.contains("kernel")) { + for (Function::iterator BB = F.begin(); BB != F.end(); BB++) { + for (BasicBlock::iterator I = BB->begin(); I != BB->end(); I++) { + if (LoadInst *LI = dyn_cast(I)) { + InstrumentationFunction(I, F, M, LocationCounter); + } else if (StoreInst *SI = dyn_cast(I)) { + InstrumentationFunction(I, F, M, LocationCounter); + } + } + } + } + } + return ModifiedCodeGen; +} + +PassPluginLibraryInfo getPassPluginInfo() { + const auto callback = [](PassBuilder &PB) { + PB.registerOptimizerLastEPCallback([&](ModulePassManager &MPM, auto, auto) { + MPM.addPass(LoadStoreMemSpace()); + return true; + }); + }; + + return {LLVM_PLUGIN_API_VERSION, "print-mem-space", LLVM_VERSION_STRING, + callback}; +}; + +extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo llvmGetPassPluginInfo() { + return getPassPluginInfo(); +} diff --git a/third_party/enflame/include/triton/lib/Target/CMakeLists.txt b/third_party/enflame/include/triton/lib/Target/CMakeLists.txt new file mode 100644 index 000000000..39d31dc9b --- /dev/null +++ b/third_party/enflame/include/triton/lib/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/enflame/include/triton/lib/Target/LLVMIR/CMakeLists.txt b/third_party/enflame/include/triton/lib/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 000000000..f2f9adf8f --- /dev/null +++ b/third_party/enflame/include/triton/lib/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,28 @@ +add_triton_library(TritonLLVMIR + LLVMDIScope.cpp + LLVMIRBreakPhiStruct.cpp + + DEPENDS + LLVMIRIncGen + + LINK_LIBS + ${CMAKE_DL_LIBS} + PUBLIC + MLIRArithToLLVM + MLIRBuiltinToLLVMIRTranslation + MLIRIndexToLLVM + MLIRIR + MLIRLLVMDialect + MLIRLLVMToLLVMIRTranslation + MLIRNVVMToLLVMIRTranslation + MLIRROCDLToLLVMIRTranslation + MLIRSCFToControlFlow + MLIRSupport + MLIRTargetLLVMIRExport + TritonGPUToLLVM + ) + +set_source_files_properties( + LLVMIRTranslation.cpp + PROPERTIES + COMPILE_FLAGS "-D__BUILD_DIR__=\\\"${CMAKE_BINARY_DIR}\\\"") diff --git a/third_party/enflame/include/triton/lib/Target/LLVMIR/LLVMDIScope.cpp b/third_party/enflame/include/triton/lib/Target/LLVMIR/LLVMDIScope.cpp new file mode 100644 index 000000000..4aa9828cd --- /dev/null +++ b/third_party/enflame/include/triton/lib/Target/LLVMIR/LLVMDIScope.cpp @@ -0,0 +1,161 @@ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Target/LLVMIR/Passes.h" +#include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Path.h" + +//===----------------------------------------------------------------------===// +// This file implements a pass to add debug info scope to LLVM operations, and +// is inspired by the DIScopeForLLVMFuncOpPass in LLVM/MLIR. Different from the +// DIScopeForLLVMFuncOpPass, this pass also handles inlined functions. +//===----------------------------------------------------------------------===// + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Target/LLVMIR/Passes.h.inc" + +namespace { + +/// Attempt to extract a filename for the given loc. +FileLineColLoc extractFileLoc(Location loc) { + if (auto fileLoc = dyn_cast(loc)) + return fileLoc; + if (auto nameLoc = dyn_cast(loc)) + return extractFileLoc(nameLoc.getChildLoc()); + if (auto opaqueLoc = dyn_cast(loc)) + return extractFileLoc(opaqueLoc.getFallbackLocation()); + if (auto fusedLoc = dyn_cast(loc)) + return extractFileLoc(fusedLoc.getLocations().front()); + if (auto callerLoc = dyn_cast(loc)) + return extractFileLoc(callerLoc.getCaller()); + StringAttr unknownFile = mlir::StringAttr::get(loc.getContext(), ""); + return mlir::FileLineColLoc::get(unknownFile, 0, 0); +} + +/// Add a debug info scope to LLVMFuncOp that are missing it. +struct LLVMDIScopePass : public LLVMDIScopeBase { + LLVMDIScopePass() = default; + + void setSubprogramAttr(LLVM::LLVMFuncOp funcOp) { + Location loc = funcOp.getLoc(); + if (loc->findInstanceOf>()) + return; + + MLIRContext *context = &getContext(); + + // To find a DICompileUnitAttr attached to a parent (the module for + // example), otherwise create a default one. + LLVM::DICompileUnitAttr compileUnitAttr; + if (ModuleOp module = funcOp->getParentOfType()) { + auto fusedCompileUnitAttr = + module->getLoc() + ->findInstanceOf>(); + if (fusedCompileUnitAttr) + compileUnitAttr = fusedCompileUnitAttr.getMetadata(); + } + + // Filename, line and colmun to associate to the function. + LLVM::DIFileAttr fileAttr; + int64_t line = 1, col = 1; + FileLineColLoc fileLoc = extractFileLoc(loc); + if (!fileLoc && compileUnitAttr) { + fileAttr = compileUnitAttr.getFile(); + } else if (!fileLoc) { + fileAttr = LLVM::DIFileAttr::get(context, "", ""); + } else { + line = fileLoc.getLine(); + col = fileLoc.getColumn(); + StringRef inputFilePath = fileLoc.getFilename().getValue(); + fileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(inputFilePath), + llvm::sys::path::parent_path(inputFilePath)); + } + auto subroutineTypeAttr = + LLVM::DISubroutineTypeAttr::get(context, llvm::dwarf::DW_CC_normal, {}); + + // Figure out debug information (`subprogramFlags` and `compileUnitAttr`) to + // attach to the function definition / declaration. External functions are + // declarations only, and are defined in a different compile unit, so mark + // them appropriately in `subprogramFlags`, and set an empty + // `compileUnitAttr`. + DistinctAttr distinctId; + auto subprogramFlags = LLVM::DISubprogramFlags::Optimized; + if (!funcOp.isExternal()) { + distinctId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + if (!compileUnitAttr) { + compileUnitAttr = LLVM::DICompileUnitAttr::get( + distinctId, llvm::dwarf::DW_LANG_C, fileAttr, + StringAttr::get(context, "triton"), + /*isOptimized=*/true, LLVM::DIEmissionKind::LineTablesOnly); + } + subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition; + } else { + compileUnitAttr = {}; + } + + StringAttr funcNameAttr = funcOp.getNameAttr(); + // Note that scopeline is set differently from LLVM's + // DIScopeForLLVMFuncOpPass. I don't find reasons why scopeline should be + // the column offset + auto subprogramAttr = LLVM::DISubprogramAttr::get( + context, distinctId, compileUnitAttr, fileAttr, funcNameAttr, + funcNameAttr, fileAttr, /*line=*/line, /*scopeline=*/line, + subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{}, + /*annotations=*/{}); + funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr)); + } + + // Get a nested loc for inlined functions + Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr, + Location calleeLoc) { + auto calleeFileName = extractFileLoc(calleeLoc).getFilename(); + auto context = op->getContext(); + LLVM::DIFileAttr calleeFileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(calleeFileName), + llvm::sys::path::parent_path(calleeFileName)); + auto lexicalBlockFileAttr = LLVM::DILexicalBlockFileAttr::get( + context, scopeAttr, calleeFileAttr, /*discriminator=*/0); + Location loc = calleeLoc; + if (mlir::isa(calleeLoc)) { + auto nestedLoc = mlir::cast(calleeLoc).getCallee(); + loc = getNestedLoc(op, lexicalBlockFileAttr, nestedLoc); + } + return FusedLoc::get(context, {loc}, lexicalBlockFileAttr); + } + + void setLexicalBlockFileAttr(Operation *op) { + auto opLoc = op->getLoc(); + if (auto callSiteLoc = dyn_cast(opLoc)) { + auto callerLoc = callSiteLoc.getCaller(); + auto calleeLoc = callSiteLoc.getCallee(); + LLVM::DIScopeAttr scopeAttr; + // We assemble the full inline stack so the parent of this loc must be a + // function + auto funcOp = op->getParentOfType(); + auto funcOpLoc = mlir::cast(funcOp.getLoc()); + scopeAttr = mlir::cast(funcOpLoc.getMetadata()); + auto loc = + CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc); + op->setLoc(loc); + } + } + + void runOnOperation() override { + getOperation()->walk([&](Operation *op) -> void { + if (isa(op)) + setSubprogramAttr(cast(op)); + else + setLexicalBlockFileAttr(op); + }); + } +}; + +} // end anonymous namespace + +std::unique_ptr mlir::createLLVMDIScopePass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp b/third_party/enflame/include/triton/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp new file mode 100644 index 000000000..a3c6d6995 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +/// Implements a trivial pass breaking up 1 level deep structure in phi nodes. +/// This handles the common case generated by Triton and allow better +/// optimizations down the compiler pipeline. +//===----------------------------------------------------------------------===// +#include "LLVMPasses.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +static bool processPhiStruct(PHINode *phiNode) { + StructType *STy = dyn_cast(phiNode->getType()); + if (!STy) + return false; + IRBuilder<> builder(phiNode); + unsigned numOperands = phiNode->getNumIncomingValues(); + unsigned numScalarEl = STy->getNumElements(); + Value *newStruct = UndefValue::get(STy); + builder.SetInsertPoint(phiNode->getParent()->getFirstNonPHIIt()); + llvm::IRBuilderBase::InsertPoint insertInsertPt = builder.saveIP(); + for (unsigned i = 0; i < numScalarEl; i++) { + builder.SetInsertPoint(phiNode); + PHINode *newPhiNode = + builder.CreatePHI(STy->getElementType(i), numOperands); + for (unsigned j = 0; j < numOperands; ++j) { + Value *operand = phiNode->getIncomingValue(j); + builder.SetInsertPoint(phiNode->getIncomingBlock(j)->getTerminator()); + newPhiNode->addIncoming(builder.CreateExtractValue(operand, i), + phiNode->getIncomingBlock(j)); + } + builder.restoreIP(insertInsertPt); + newStruct = builder.CreateInsertValue(newStruct, newPhiNode, i); + insertInsertPt = builder.saveIP(); + } + phiNode->replaceAllUsesWith(newStruct); + return true; +} + +static bool runOnFunction(Function &F) { + bool Changed = false; + SmallVector PhiNodes; + for (BasicBlock &BB : F) { + for (Instruction &inst : BB) { + if (PHINode *phiNode = dyn_cast(&inst)) { + Changed |= processPhiStruct(phiNode); + continue; + } + break; + } + } + return Changed; +} + +PreservedAnalyses BreakStructPhiNodesPass::run(Function &F, + FunctionAnalysisManager &AM) { + + bool b = runOnFunction(F); + return b ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/third_party/enflame/include/triton/lib/Target/LLVMIR/LLVMPasses.h b/third_party/enflame/include/triton/lib/Target/LLVMIR/LLVMPasses.h new file mode 100644 index 000000000..1dcdb2992 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Target/LLVMIR/LLVMPasses.h @@ -0,0 +1,16 @@ +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CodeGen.h" + +namespace llvm { + +// Pass to pre-process LLVM IR before optimization and break up phi of struct. +// Breaking up those phis into elementary types allows better optimizations +// downstream. +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; + +} // namespace llvm diff --git a/third_party/enflame/include/triton/lib/Tools/CMakeLists.txt b/third_party/enflame/include/triton/lib/Tools/CMakeLists.txt new file mode 100644 index 000000000..2b6e66ee1 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Tools/CMakeLists.txt @@ -0,0 +1,11 @@ +add_triton_library(TritonTools + LayoutUtils.cpp + LinearLayout.cpp + + DEPENDS + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + f2reduce +) diff --git a/third_party/enflame/include/triton/lib/Tools/LayoutUtils.cpp b/third_party/enflame/include/triton/lib/Tools/LayoutUtils.cpp new file mode 100644 index 000000000..563241974 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Tools/LayoutUtils.cpp @@ -0,0 +1,251 @@ +#include "triton/Tools/LayoutUtils.h" + +namespace mlir::triton { + +static bool checkSquareSublayout(const LinearLayout &ll, + ArrayRef dimNames, + function_ref checkBasis) { + // The empty layout is the identity + if (dimNames.size() == 0) { + return true; + } + // Check that the input-output sizes are the same + LinearLayout sl = ll.sublayout(dimNames, dimNames); + for (StringAttr dim : dimNames) { + if (ll.getInDimSize(dim) != ll.getOutDimSize(dim)) { + return false; + } + } + // Once the inputs and output dimensions are the same, we can just check + // that the basis for the single remaining dimension is the identity. + sl = sl.flattenIns().flattenOuts(); + const auto &inDimBases = sl.getBases().begin()->second; + for (auto [b, basis] : llvm::enumerate(inDimBases)) { + if (!checkBasis(b, basis[0])) { + return false; + } + } + return true; +} + +bool squareSublayoutIsIdentity(const LinearLayout &ll, + ArrayRef dimNames) { + return checkSquareSublayout( + ll, dimNames, [](int b, int32_t basis) { return basis == (1 << b); }); +} + +bool squareSublayoutIsPermutation(const LinearLayout &ll, + ArrayRef dimNames) { + int32_t mask = 0; + return checkSquareSublayout(ll, dimNames, [&](int b, int32_t basis) { + if (!llvm::isPowerOf2_32(basis)) + return false; + if (mask & basis) + return false; // check if this bit is already set + mask |= basis; + return true; + }); +} + +LinearLayout +ensureLayoutNotLargerThan(const LinearLayout &layout, + const llvm::SmallDenseMap &shape, + bool broadcastRegisters) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + MLIRContext *ctx = shape.begin()->first.getContext(); + + auto bases = layout.getBases(); + + auto kRegister = StringAttr::get(ctx, "register"); + std::set broadcastedDims; + + for (auto outDim : llvm::enumerate(layout.getOutDimNames())) { + auto outDimName = outDim.value(); + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + if (actualSize <= desiredSize) { + continue; + } + assert(actualSize % desiredSize == 0); + // + std::vector> sortedBases; + for (auto [inDimName, basis] : bases) { + for (size_t basisIdx = 0; basisIdx < basis.size(); basisIdx++) { + auto outValue = basis[basisIdx][outDim.index()]; + if (outValue == 0) { + continue; + } + assert(llvm::isPowerOf2_32(outValue)); + sortedBases.emplace_back(inDimName, basisIdx, outValue); + } + } + // From the largest basis to the smallest. + llvm::sort(sortedBases, + [](auto a, auto b) { return std::get<2>(a) > std::get<2>(b); }); + for (auto [inDimName, basisIdx, outValue] : sortedBases) { + if (actualSize <= desiredSize) { + break; + } + if (!broadcastRegisters && inDimName == kRegister) { + broadcastedDims.insert(basisIdx); + } else { + bases[inDimName][basisIdx][outDim.index()] = 0; + } + actualSize >>= 1; + } + } + if (!broadcastRegisters) { + // Remove broadcasted registers + std::vector> newBasesRegister; + for (auto [idx, basis] : llvm::enumerate(bases[kRegister])) { + // Remove if it's broadcasted + if (broadcastedDims.find(idx) == broadcastedDims.end()) { + newBasesRegister.push_back(std::move(basis)); + } + } + bases[kRegister] = std::move(newBasesRegister); + } + + return LinearLayout(std::move(bases), + llvm::to_vector(layout.getOutDimNames())); +} + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// smaller than shape[d]. Do this by increasing the size of the layout's inputs +// along its most-minor dimension ("register" for register layouts, "offset" for +// shared layouts). +// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + + StringAttr kDim = *layout.getInDimNames().begin(); + assert(kDim == "register" || kDim == "offset"); + + LinearLayout ret = layout; + for (StringAttr outDimName : layout.getOutDimNames()) { + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + assert(actualSize > desiredSize || desiredSize % actualSize == 0); + ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName); + assert(ret.getOutDimSize(outDimName) >= desiredSize); + } + return ret; +} + +// Returns ["dim0", "dim1", ..., "dim"]. +SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { + SmallVector ret; + for (int i = 0; i < rank; i++) { + ret.push_back(StringAttr::get(ctx, "dim" + llvm::Twine(i))); + } + return ret; +} + +// Returns [("dim0", dstShape[0]), ("dim1", dstShape[1]), ..., +// ("dim", dstShape[rank-1])]. +SmallVector> +standardOutDimPairs(MLIRContext *ctx, ArrayRef dstShape) { + auto newRank = dstShape.size(); + SmallVector> newOutDims; + for (auto [dim, size] : + llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) { + newOutDims.emplace_back(dim, size); + } + return newOutDims; +} + +// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to +// creating a 1D -> 1D mapping of size product(shape) and then reshaping to +// permute(shape, order). +LinearLayout identityStandardND(StringAttr inDimName, ArrayRef shape, + ArrayRef order) { + assert(shape.size() == order.size()); + MLIRContext *ctx = inDimName.getContext(); + auto rank = shape.size(); + + // The order in triton is written wrt. [dim0, dim1, ...]. + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < shape.size(); i++) { + // Start with the most-minor dimension, which is order[0]. + int dim = order[i]; + ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]); + } + return ret; +} + +// Compute the supremum of two lists. +// If the supremum is not unique, we return the first list first +// Error out if the supremum does not exist +// e.g. sup([a, b], [a, c]) = [a, b, c], sup([a, b], [b, c]) = [a, b, c] +// sup([a, b], [b, a]) = error! Supremum does not exist. +SmallVector supremum(const SmallVector &x, + const SmallVector &y) { + llvm::SetVector result; + DenseMap posX, posY; + for (auto [idx, elem] : llvm::enumerate(x)) + posX[elem] = idx; + for (auto [idx, elem] : llvm::enumerate(y)) + posY[elem] = idx; + int i = 0, j = 0; + const int INF = std::numeric_limits::max(); + while (i < x.size() || j < y.size()) { + while (i < x.size() && result.contains(x[i])) + ++i; + while (j < y.size() && result.contains(y[j])) + ++j; + if (i >= x.size() && j >= y.size()) + break; + if (i < x.size() && j < y.size() && x[i] == y[j]) { + if (posY[x[i]] < j) + llvm_unreachable("Supremum does not exist"); + result.insert(x[i]); + ++i, ++j; + continue; + } + int candX = INF, candY = INF; + if (i < x.size()) { + if (posY.count(x[i]) && posY[x[i]] >= j) + candX = posY[x[i]]; + } + if (j < y.size()) { + if (posX.count(y[j]) && posX[y[j]] >= i) + candY = posX[y[j]]; + } + if (i < x.size() && candX == INF) { + result.insert(x[i]); + ++i; + continue; + } + if (j < y.size() && candY == INF) { + result.insert(y[j]); + ++j; + continue; + } + if (candX <= candY) { + if (posY[x[i]] < j) + llvm_unreachable("Supremum does not exist"); + result.insert(x[i]); + ++i; + } else { + if (posX[y[j]] < i) + llvm_unreachable("Supremum does not exist"); + result.insert(y[j]); + ++j; + } + } + return to_vector(result); +} + +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/lib/Tools/LinearLayout.cpp b/third_party/enflame/include/triton/lib/Tools/LinearLayout.cpp new file mode 100644 index 000000000..ae591d955 --- /dev/null +++ b/third_party/enflame/include/triton/lib/Tools/LinearLayout.cpp @@ -0,0 +1,1111 @@ +#include "triton/Tools/LinearLayout.h" + +#include +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "third_party/f2reduce/f2reduce.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" + +#define DEBUG_TYPE "linear_layout" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#if defined(_MSC_VER) && !defined(__clang__) +// from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0 +#include + +static int __builtin_ctz(unsigned x) { + unsigned long r; + _BitScanForward(&r, x); + return static_cast(r); +} + +static int __builtin_ctzll(unsigned long long x) { + unsigned long r; + _BitScanForward64(&r, x); + return static_cast(r); +} + +#endif + +namespace mlir::triton { + +namespace { +using BasesT = LinearLayout::BasesT; +using llvm::SmallDenseSet; +using llvm::Twine; + +BasesT makeBasesMap( + ArrayRef>>> bases) { + BasesT ret; + for (const auto &[inDim, inDimBases] : bases) { + ret[inDim] = inDimBases; + } + return ret; +} + +// Dump the matrix to stderr in a human-readable format for debugging. +void dumpMatrix(uint64_t *m, int numRows, int numCols) { + assert(numCols <= 64); + for (int r = 0; r < numRows; r++) { + llvm::errs() << "0b"; + for (int c = 0; c < numCols; c++) { + llvm::errs() << ((m[r] & (1 << c)) != 0 ? "1" : "0"); + } + llvm::errs() << "\n"; + } +} + +// Build a matrix of size sum(outDimSizeLog2) x sum(inDimSizeLog2) representing +// the bases of the given layout. This can then be used by f2reduce. +// +// This function is called from the constructor of LinearLayout, so be careful +// not to use any functions that create LLs in here. +std::unique_ptr getMatrix(const LinearLayout &layout) { + int numRows = layout.getTotalOutDimSizeLog2(); + int numCols = layout.getTotalInDimSizeLog2(); + + // Don't handle giant LLs. This makes some things easier; for example, each + // row can be a single uint64_t. + assert(numCols <= 64 && "LinearLayout too large"); + assert(numRows <= 64 && "LinearLayout too large"); + + // Suppose we have a layout specified by the following values. + // + // L(0,1) = (0b01, 0b1) + // L(0,2) = (0b10, 0b0) + // L(1,0) = (0b10, 0b0) + // L(2,0) = (0b11, 0b0) + // + // We will create one column per entry above. The max bit width of the + // codomain is (2,1), so our matrix will have 2+1=3 rows. The final matrix + // will be + // + // | L(0,1)[0] L(0,2)[0] L(1,0)[0] L(2,0)[0] | | 0b1001 | + // | ↓ ↓ ↓ ↓ | | 0b0111 | + // | L(0,1)[1] L(0,2)[1] L(1,0)[1] L(2,0)[1] | = | 0b1000 | + // | ↓ ↓ ↓ ↓ | + // + // Note `new uint64_t[n]()` is zero-initialized, but `new uint64_t[n]` is not. + std::unique_ptr m(new uint64_t[numRows]()); + int r = 0; + for (StringAttr outDim : layout.getOutDimNames()) { + int c = 0; + for (StringAttr inDim : layout.getInDimNames()) { + for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) { + uint64_t basis = layout.getBasis(inDim, i, outDim); + for (int j = 0; j < layout.getOutDimSizeLog2(outDim); j++) { + m[r + j] |= ((basis >> j) & 1) << c; + } + c++; + } + } + r += layout.getOutDimSizeLog2(outDim); + } + + return m; +} + +// Compute the rank of the matrix formed by taking the bases for the given +// outDim as columns. In other words, finds the number of linearly-independent +// bases for this output dimension. +int getMatrixRank(std::unique_ptr m, int numRows, int numCols) { + // stride is specified in number of 64-bit words per row, and we pack our + // matrix so that there's only one uint64_t per row. + assert(numCols <= 64); + f2reduce::inplace_rref_strided(m.get(), numRows, numCols, /*stride=*/1); + + // The rank of the reduced matrix is simply the number of nonzero rows. + int rank = 0; + for (int i = 0; i < numRows; i++) { + if (m[i] != 0) + rank++; + } + return rank; +} + +template +void assertDimsEqualIgnoringOrder(T &&a, U &&b) { + SmallDenseSet as(a.begin(), a.end()); + SmallDenseSet bs(b.begin(), b.end()); + if (as != bs) { + llvm::report_fatal_error("Dimensions must match, ignoring order, but they " + "don't. Got dims: [" + + Twine(triton::join(a, ", ")) + "] and [" + + triton::join(b, ", ") + "]"); + } +} + +template +void assertDimsSubsetIgnoringOrder(T &&small, U &&big) { + SmallDenseSet smallSet(small.begin(), small.end()); + SmallDenseSet bigSet(big.begin(), big.end()); + if (!llvm::set_is_subset(smallSet, bigSet)) { + llvm::report_fatal_error("Dimensions must be a subset, ignoring order, but " + "they aren't. Got dims: [" + + Twine(triton::join(small, ", ")) + "] and [" + + triton::join(big, ", ") + "]"); + } +} +} // anonymous namespace + +/*static*/ std::optional +LinearLayout::tryCreate(BasesT bases, + ArrayRef> outDims, + bool requireSurjective) { + LinearLayout ll(std::move(bases), std::move(outDims), NoCheckInvariants{}); + std::optional error = ll.checkInvariants(requireSurjective); + if (error) { + return std::nullopt; + } + return ll; +} + +LinearLayout::LinearLayout(BasesT bases, + ArrayRef> outDims, + NoCheckInvariants) + : bases(std::move(bases)) { + for (auto [outDim, size] : outDims) { + this->outDims[outDim] = size; + } +} + +LinearLayout::LinearLayout(BasesT bases, ArrayRef outDimNames) + : bases(std::move(bases)) { + // Infer out-dim sizes. + for (StringAttr outDim : outDimNames) { + outDims[outDim] = 1; + } + for (const auto &[inDim, inDimBases] : this->bases) { + for (const auto &basis : inDimBases) { + for (int i = 0; i < basis.size(); i++) { + int32_t &size = outDims[outDimNames[i]]; + size = std::max(size, llvm::NextPowerOf2(basis[i])); + } + } + } + + std::optional error = + checkInvariants(/*requireSurjective=*/true); + if (error.has_value()) { + llvm::report_fatal_error(StringRef(*error)); + } +} + +LinearLayout::LinearLayout(BasesT bases, + ArrayRef> outDims, + bool requireSurjective) + : LinearLayout(std::move(bases), std::move(outDims), NoCheckInvariants{}) { + std::optional error = checkInvariants(requireSurjective); + if (error.has_value()) { + llvm::report_fatal_error(StringRef(*error)); + } +} + +std::optional +LinearLayout::checkInvariants(bool requireSurjective) { + LDBG("checkInvariants: " << toString()); + // Check that basis values are non-negative. + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (llvm::any_of(basis, [](int32_t b) { return b < 0; })) { + return "Invalid bases passed to LinearLayout. Expected all basis " + "values to be non-negative, but found a negative value for " + "in dimension '" + + inDim.str() + "'. Full list of bases:" + toString() + "\n"; + } + } + } + + // Check that the bases all have length equal to outDimNames.size(). + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (basis.size() != outDims.size()) { + return "Invalid bases passed to LinearLayout. Expect all bases to " + "have the same size, equal to outDimNames.size() (" + + std::to_string(outDims.size()) + + "). But this failed for in dimension '" + inDim.str() + + "'. Full list of bases:" + toString() + "\n"; + } + } + } + + // Check that the out-dim sizes are powers of 2. + for (const auto &[outDim, size] : outDims) { + if (!llvm::isPowerOf2_32(size)) { + return "Invalid out-dim size " + std::to_string(size) + " for out-dim '" + + outDim.str() + "'. Out-dim sizes must be powers of 2.\n"; + } + } + + // Check that the bases are smaller than the out-dim sizes. + SmallVector outDimNames = llvm::to_vector(getOutDimNames()); + for (const auto &[inDim, inDimBases] : this->bases) { + for (const auto &basis : inDimBases) { + for (int i = 0; i < basis.size(); i++) { + if (basis[i] >= outDims[outDimNames[i]]) { + return "Invalid basis " + std::to_string(basis[i]) + " for in-dim '" + + inDim.str() + "' and out-dim '" + outDimNames[i].str() + + "'. Basis must be less than the out-dim size.\n"; + } + } + } + } + + // Determine whether the this layout is surjective, i.e. that every `out` + // coordinate can be reached by some `in` coordinate. + // + // It's prohibitively slow to calculate this naively, but thankfully, this + // is equivalent to checking that the number of linearly-independent bases + // is equal to sum(getOutDimSizeLog2). This can be computed by finding + // the rank of the matrix whose columns are those bases. We can compute + // the rank of our matrix using Gaussian elimination, which runs in O(n^3) + // for an n x n matrix. Our matrix size is sum(inDimSizeLog2) x + // sum(outDimSizeLog2), so this should be plenty fast. + this->surjective = + getMatrixRank(getMatrix(*this), /*numRows=*/getTotalOutDimSizeLog2(), + /*numCols=*/getTotalInDimSizeLog2()) == + getTotalOutDimSizeLog2(); + + if (requireSurjective && !surjective) { + return "Layout is expected to be surjective, i.e. every `out` coordinate " + "can be reached by some `in` coordinate, but was not:" + + toString(); + } + + return std::nullopt; +} + +LinearLayout::LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames) + : LinearLayout(makeBasesMap(bases), outDimNames) {} + +LinearLayout::LinearLayout( + ArrayRef>>> bases, + ArrayRef> outDims, bool requireSurjective) + : LinearLayout(makeBasesMap(bases), outDims, requireSurjective) {} + +/*static*/ LinearLayout LinearLayout::identity1D(int32_t size, + StringAttr inDimName, + StringAttr outDimName) { + if (size == 0) + return LinearLayout::empty(); + + assert(llvm::isPowerOf2_32(size)); + std::vector> powersOf2; + for (int32_t i = 1; i < size; i *= 2) { + powersOf2.emplace_back().push_back(i); + } + return LinearLayout({{inDimName, std::move(powersOf2)}}, {outDimName}); +} + +/*static*/ LinearLayout LinearLayout::zeros1D(int32_t size, + StringAttr inDimName, + StringAttr outDimName, + int32_t outDimSize) { + if (size == 0) + return LinearLayout::empty(); + + assert(llvm::isPowerOf2_32(size)); + std::vector> zeros; + for (int i = 0; i < llvm::Log2_32(size); i++) { + zeros.emplace_back().push_back(0); + } + return LinearLayout({{inDimName, zeros}}, {{outDimName, outDimSize}}, + /*requiresSurjective=*/outDimSize == 1); +} + +int32_t LinearLayout::getOutDimIndex(StringAttr outDim) const { + int i = 0; + for (auto [name, _] : outDims) { + if (name == outDim) { + return i; + } + i++; + } + llvm::report_fatal_error("outDim " + Twine(outDim) + " is not in layout" + + toString()); +} + +int32_t LinearLayout::getInDimSizeLog2(StringAttr inDim) const { + auto it = bases.find(inDim); + assert(it != bases.end()); + return it->second.size(); +} + +int32_t LinearLayout::getTotalInDimSizeLog2() const { + return std::accumulate(getInDimNames().begin(), getInDimNames().end(), 0, + [&](int32_t acc, StringAttr inDim) { + return acc + getInDimSizeLog2(inDim); + }); +} + +int32_t LinearLayout::getOutDimSizeLog2(StringAttr outDim) const { + auto it = outDims.find(outDim); + assert(it != outDims.end()); + return llvm::Log2_32(it->second); +} + +int32_t LinearLayout::getTotalOutDimSizeLog2() const { + return std::accumulate(getOutDimNames().begin(), getOutDimNames().end(), 0, + [&](int32_t acc, StringAttr outDim) { + return acc + getOutDimSizeLog2(outDim); + }); +} + +int32_t LinearLayout::getNumConsecutiveInOut() const { + if (bases.empty() || getNumOutDims() == 0) + return 1; + + // Count how many of the initial bases for the first in-dim are + // (2^i, 0, ..., 0). + const auto &firstInDimBases = bases.begin()->second; + int consec = 0; + for (; consec < firstInDimBases.size(); consec++) { + const auto &basis = firstInDimBases[consec]; + if (basis[0] != (1 << consec) || + !std::all_of(basis.begin() + 1, basis.end(), + [](int32_t x) { return x == 0; })) { + break; + } + } + + // `or` together all other bases' first out-dim. + int32_t otherBits = 0; + for (const auto &[inDim, inDimBases] : bases) { + for (int i = 0; i < inDimBases.size(); i++) { + if (inDim != bases.begin()->first || i >= consec) { + otherBits |= inDimBases[i][0]; + } + } + } + int32_t trailingZeros = otherBits != 0 ? __builtin_ctz(otherBits) : 31; + + return 1 << std::min(consec, trailingZeros); +} + +LinearLayout LinearLayout::transposeIns(ArrayRef newInDims) const { + assertDimsEqualIgnoringOrder(newInDims, getInDimNames()); + + BasesT newBases; + for (const auto &inDim : newInDims) { + newBases[inDim] = bases.find(inDim)->second; + } + return LinearLayout(std::move(newBases), llvm::to_vector(outDims), + surjective); +} + +LinearLayout +LinearLayout::transposeOuts(ArrayRef newOutDims) const { + assertDimsEqualIgnoringOrder(newOutDims, getOutDimNames()); + + std::vector permutation; + for (const auto &outDim : newOutDims) { + permutation.push_back(getOutDimIndex(outDim)); + } + + BasesT newBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &newInDimBases = newBases[inDim]; + for (const auto &basis : inDimBases) { + std::vector newBasis; + for (int32_t i : permutation) { + newBasis.push_back(basis[i]); + } + newInDimBases.push_back(std::move(newBasis)); + } + } + + SmallVector> newOutDimSizes; + for (auto outDim : newOutDims) { + newOutDimSizes.push_back({outDim, getOutDimSize(outDim)}); + } + return LinearLayout(std::move(newBases), newOutDimSizes, surjective); +} + +LinearLayout LinearLayout::reshapeIns( + ArrayRef> newInDims) const { + assert(llvm::all_of(newInDims, [&](auto &inDim) { + return llvm::isPowerOf2_32(inDim.second); + })); + assert(getTotalInDimSize() == std::accumulate(newInDims.begin(), + newInDims.end(), 1, + [&](int32_t acc, auto &inDim) { + return acc * inDim.second; + })); + + // First flatten into a single in-dimension. Then split it up according + // to `newInDims`. + SmallVector> flatBases; + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + flatBases.push_back(basis); + } + } + + BasesT newBases; + int i = 0; + for (const auto &[inDim, inDimSize] : newInDims) { + auto &newInDimBases = newBases[inDim]; + for (int j = 0; j < llvm::Log2_32(inDimSize); j++) { + newInDimBases.push_back(flatBases[i++]); + } + } + return LinearLayout(std::move(newBases), llvm::to_vector(outDims), + surjective); +} + +LinearLayout LinearLayout::reshapeOuts( + ArrayRef> newOutDims) const { + assert(llvm::all_of(newOutDims, [&](auto &outDim) { + return llvm::isPowerOf2_32(outDim.second); + })); + assert(getTotalOutDimSize() == + std::accumulate( + newOutDims.begin(), newOutDims.end(), 1, + [&](int32_t acc, auto &outDim) { return acc * outDim.second; })); + + SmallVector shifts; + shifts.push_back(0); + for (StringAttr outDim : getOutDimNames()) { + shifts.push_back(shifts.back() + getOutDimSizeLog2(outDim)); + } + + // Flatten into a single out-dimension. Then split it up according to + // `newOutDims`. + llvm::MapVector> flatBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &flatInBases = flatBases[inDim]; + for (const auto &basis : inDimBases) { + int b = 0; + for (int i = 0; i < basis.size(); i++) { + b += basis[i] << shifts[i]; + } + flatInBases.push_back(b); + } + } + + BasesT newBases; + for (const auto &[inDim, flatInBases] : flatBases) { + std::vector> &newInDimBases = newBases[inDim]; + for (int32_t b : flatInBases) { + std::vector multiDimBasis; + for (int32_t newSize : llvm::make_second_range(newOutDims)) { + multiDimBasis.push_back(b % newSize); + b /= newSize; + } + newInDimBases.push_back(std::move(multiDimBasis)); + } + } + + return LinearLayout(std::move(newBases), newOutDims, surjective); +} + +LinearLayout LinearLayout::concatIns(const LinearLayout &other) const { + assert(llvm::to_vector(getOutDimNames()) == + llvm::to_vector(other.getOutDimNames()) && + "layouts must have the same output dimensions"); + for (StringAttr outDim : getOutDimNames()) { + assert(getOutDimSize(outDim) == other.getOutDimSize(outDim) && + "layouts must have the same output dimension sizes"); + } + + LinearLayout::BasesT resultBases = getBases(); + for (auto &bases : other.getBases()) + resultBases.insert(bases); + SmallVector> newOutDims; + for (auto &[outDim, outDimSize] : outDims) + newOutDims.emplace_back(outDim, outDimSize); + return LinearLayout(std::move(resultBases), newOutDims, + /*requiresSurjective=*/false); +} + +LinearLayout LinearLayout::concatOuts(const LinearLayout &other) const { + assert(llvm::to_vector(getInDimNames()) == + llvm::to_vector(other.getInDimNames()) && + "layouts must have the same input dimensions"); + for (StringAttr inDim : getInDimNames()) { + assert(getInDimSize(inDim) == other.getInDimSize(inDim) && + "layouts must have the same input dimension sizes"); + } + + LinearLayout::BasesT result; + for (auto [lhsBases, rhsBases] : llvm::zip(getBases(), other.getBases())) { + auto &resultBases = result[lhsBases.first]; + assert(lhsBases.first == rhsBases.first); + for (auto [lhsBasis, rhsBasis] : + llvm::zip(lhsBases.second, rhsBases.second)) { + std::vector resultBasis; + llvm::append_range(resultBasis, lhsBasis); + llvm::append_range(resultBasis, rhsBasis); + resultBases.push_back(std::move(resultBasis)); + } + } + SmallVector> newOutDims; + for (auto &[outDim, outDimSize] : outDims) + newOutDims.emplace_back(outDim, outDimSize); + for (auto &[outDim, outDimSize] : other.outDims) + newOutDims.emplace_back(outDim, outDimSize); + return LinearLayout(std::move(result), newOutDims, + /*requiresSurjective=*/false); +} + +LinearLayout operator*(LinearLayout inner, LinearLayout outer) { + // Check that dims common to outer and inner have the same relative order. + auto inDims = supremum(llvm::to_vector(inner.getInDimNames()), + llvm::to_vector(outer.getInDimNames())); + auto outDims = supremum(llvm::to_vector(inner.getOutDimNames()), + llvm::to_vector(outer.getOutDimNames())); + + // Get the sizeLog2 of all input and output dimensions we're going to + // consider, in order. `inner` is more minor, so its dimensions come + // first. + llvm::MapVector inDimSizesLog2; + llvm::MapVector outDimSizesLog2; + for (const auto &dim : inDims) + inDimSizesLog2.insert({dim, 0}); + for (const auto &dim : outDims) + outDimSizesLog2.insert({dim, 0}); + for (const auto &layout : {inner, outer}) { + for (StringAttr inDim : layout.getInDimNames()) { + inDimSizesLog2[inDim] += layout.getInDimSizeLog2(inDim); + } + for (StringAttr outDim : layout.getOutDimNames()) { + outDimSizesLog2[outDim] += layout.getOutDimSizeLog2(outDim); + } + } + + BasesT allBases; + for (auto [inDimName, inDimSizeLog2] : inDimSizesLog2) { + std::vector> &inDimBases = allBases[inDimName]; + + // Fill with zeros. + inDimBases = std::vector>( + inDimSizeLog2, std::vector(outDimSizesLog2.size(), 0)); + + for (auto [outDimIdx, outDimNameAndSize] : + llvm::enumerate(outDimSizesLog2)) { + auto [outDimName, outDimSize] = outDimNameAndSize; + if (inner.hasInDim(inDimName) && inner.hasOutDim(outDimName)) { + for (int i = 0; i < inner.getInDimSizeLog2(inDimName); i++) { + inDimBases[i][outDimIdx] = inner.getBasis(inDimName, i, outDimName); + } + } + if (outer.hasInDim(inDimName) && outer.hasOutDim(outDimName)) { + int offset = + inner.hasInDim(inDimName) ? inner.getInDimSizeLog2(inDimName) : 0; + int shift = inner.hasOutDim(outDimName) + ? inner.getOutDimSizeLog2(outDimName) + : 0; + for (int i = 0; i < outer.getInDimSizeLog2(inDimName); i++) { + inDimBases[offset + i][outDimIdx] = + outer.getBasis(inDimName, i, outDimName) << shift; + } + } + } + } + + llvm::SmallVector> outDimSizes; + for (auto [outDim, sizeLog2] : outDimSizesLog2) { + outDimSizes.push_back({outDim, 1 << sizeLog2}); + } + return LinearLayout(std::move(allBases), outDimSizes, + inner.isSurjective() && outer.isSurjective()); +} + +bool LinearLayout::isTrivialOver(ArrayRef dimNames) const { + for (StringAttr dim : dimNames) { + if (!llvm::is_contained(getInDimNames(), dim) && + !llvm::is_contained(getOutDimNames(), dim)) { + return false; + } + } + + auto getRemainingDimNames = [&](auto allDimNames) { + SmallVector remainingDimNames; + for (StringAttr dim : allDimNames) { + if (!llvm::is_contained(dimNames, dim)) { + remainingDimNames.push_back(dim); + } + } + return remainingDimNames; + }; + SmallVector remainingInDimNames = + getRemainingDimNames(getInDimNames()); + SmallVector remainingOutDimNames = + getRemainingDimNames(getOutDimNames()); + + // Think of this as a block-matrix multiplying a vector: + // [[A, B], * [v_1, + // [C, D]] v_2] + // where v_2 is the dimNames and v_1 is the remainingInDimNames + // We can quotient out dimNames iff they don't affect the remainingInDimNames + // in the result. In other words, we want to check that B is zero, and C is + // zero, and D is the identity + return squareSublayoutIsIdentity(*this, dimNames) && + sublayoutIsZero(remainingInDimNames, dimNames) && + sublayoutIsZero(dimNames, remainingOutDimNames); +} + +std::optional +LinearLayout::quotient(ArrayRef dimNames) const { + if (!isTrivialOver(dimNames)) { + return std::nullopt; + } + + // This should probably be even less general, where we ask inDimNames == + // outDimNames + auto getRemainingDimNames = [&](auto allDimNames) { + SmallVector remainingDimNames; + for (StringAttr dim : allDimNames) { + if (!llvm::is_contained(dimNames, dim)) { + remainingDimNames.push_back(dim); + } + } + return remainingDimNames; + }; + + SmallVector inDimNames = getRemainingDimNames(getInDimNames()); + SmallVector outDimNames = getRemainingDimNames(getOutDimNames()); + + return sublayout(inDimNames, outDimNames); +} + +LinearLayout LinearLayout::sublayout(ArrayRef inDimNames, + ArrayRef outDimNames) const { + assertDimsSubsetIgnoringOrder(inDimNames, getInDimNames()); + assertDimsSubsetIgnoringOrder(outDimNames, getOutDimNames()); + SmallDenseSet inDimSet(inDimNames.begin(), inDimNames.end()); + SmallDenseSet outDimSet(outDimNames.begin(), outDimNames.end()); + + SmallVector outDimIndicesToKeep; + for (auto [i, outDim] : llvm::enumerate(getOutDimNames())) { + if (outDimSet.contains(outDim)) { + outDimIndicesToKeep.push_back(i); + } + } + BasesT newBases; + for (auto [inDim, inDimBases] : bases) { + if (!inDimSet.contains(inDim)) { + continue; + } + auto &newInDimBases = newBases[inDim]; + for (auto &basis : inDimBases) { + auto &newBasis = newInDimBases.emplace_back(); + for (int i : outDimIndicesToKeep) { + newBasis.push_back(basis[i]); + } + } + } + + SmallVector> newOutDims; + for (auto [outDim, outDimSize] : outDims) { + if (outDimSet.contains(outDim)) { + newOutDims.push_back({outDim, outDimSize}); + } + } + return LinearLayout(std::move(newBases), std::move(newOutDims), + /*requireSurjective=*/false); +} + +bool LinearLayout::sublayoutIsZero(ArrayRef inDimNames, + ArrayRef outDimNames) const { + LinearLayout ss = sublayout(inDimNames, outDimNames); + for (auto [inDim, inDimBases] : ss.bases) { + for (auto basis : inDimBases) { + if (!llvm::all_of(basis, [](int32_t b) { return b == 0; })) { + return false; + } + } + } + return true; +} + +SmallVector> +LinearLayout::apply(ArrayRef> ins) const { + assertDimsEqualIgnoringOrder(llvm::make_first_range(ins), getInDimNames()); + + SmallVector> ret; + for (StringAttr outDim : getOutDimNames()) { + int32_t outVal = 0; + for (auto &[inDim, val] : ins) { + for (int i = 0; i < getInDimSizeLog2(inDim); i++) { + if (val & (1 << i)) + outVal ^= getBasis(inDim, i, outDim); + } + } + ret.push_back({outDim, outVal}); + } + return ret; +} + +LinearLayout LinearLayout::compose(const LinearLayout &outer) const { + assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getInDimNames()); + for (StringAttr outDim : getOutDimNames()) { + assert(getOutDimSize(outDim) <= outer.getInDimSize(outDim)); + } + + BasesT newBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &newInDimBases = newBases[inDim]; + for (const auto &basis : inDimBases) { + SmallVector> bases; + for (auto [outDim, b] : llvm::zip(getOutDimNames(), basis)) { + bases.push_back({outDim, b}); + } + auto newBases = outer.apply(bases); + auto newBasesRange = llvm::make_second_range(newBases); + newInDimBases.push_back( + std::vector(newBasesRange.begin(), newBasesRange.end())); + } + } + + bool compositionIsSurjective = + isSurjective() && outer.isSurjective() && + llvm::all_of(getOutDimNames(), [&](StringAttr outDim) { + return getOutDimSize(outDim) == outer.getInDimSize(outDim); + }); + return LinearLayout(std::move(newBases), llvm::to_vector(outer.outDims), + compositionIsSurjective); +} + +namespace { +std::unique_ptr concatMatrices(const LinearLayout &A, + const LinearLayout &B) { + // In plain words, "convert_layout does not change the shape of a tensor" + assert(A.getTotalOutDimSizeLog2() == B.getTotalOutDimSizeLog2() && + "Matrices must have the same number of output dimensions"); + int numRows = A.getTotalOutDimSizeLog2(); + int numColsA = A.getTotalInDimSizeLog2(); + + // rref expects the lower bits to be the lower indices of the matrix + auto concat = getMatrix(A); + auto BMat = getMatrix(B); + for (int r = 0; r < numRows; r++) { + concat[r] |= BMat[r] << numColsA; + } + return concat; +} + +LinearLayout lstsq(const LinearLayout &A, const LinearLayout &B) { + // Solve the least square system AX = B for A = outer, B = *this + // and return the least square solution X of minimal norm + // A and B may not be surjective, but we assume that Im(B) \subset Im(A) + // Sketch of the algorithm: + // https://github.com/triton-lang/triton/pull/5309#discussion_r1869084111 + int numRows = A.getTotalOutDimSizeLog2(); + int numColsA = A.getTotalInDimSizeLog2(); + int numColsB = B.getTotalInDimSizeLog2(); + int numCols = numColsA + numColsB; + std::unique_ptr combinedMat = concatMatrices(A, B); + f2reduce::inplace_rref_strided(combinedMat.get(), numRows, numCols, + /*stride=*/1); + + // Compute the pivot columns + // Since A and B have the same image, each row will either have a pivot + // or will be all zeros + SmallVector pivotCols; + for (int r = 0; r < numRows; r++) { + auto row = combinedMat[r]; + if (row == 0) { + continue; + } + int c = __builtin_ctzll(row); + assert(c < numColsA && "Precondition broken. Im(B) not contained in Im(A)"); + assert(pivotCols.empty() || + pivotCols.back() < c && "Pivot columns are not in increasing order"); + pivotCols.push_back(c); + } + + // Extract A^{-1}B and complete the matrix using zeros + std::unique_ptr retMat(new uint64_t[numColsA]()); + int j = 0; + for (int r = 0; r < numColsA; r++) { + auto isPivot = j < pivotCols.size() && pivotCols[j] == r; + retMat[r] = isPivot ? combinedMat[j++] >> numColsA : 0; + } + + // We need names for the in/out dim of the flattened layout we're going to + // read off from `m`. These could be anything, doesn't matter. + StringAttr inDim1D = *A.getInDimNames().begin(); + StringAttr outDim1D = *A.getOutDimNames().begin(); + + // Read off the new bases. These are for a flattened 1D -> 1D + LinearLayout::BasesT retBases; + auto &bs = retBases[inDim1D]; + for (int c = 0; c < numColsB; c++) { + int32_t basis = 0; + for (int r = 0; r < numColsA; r++) { + basis |= (retMat[r] >> c & 1) << r; + } + bs.push_back({basis}); + } + + LinearLayout retFlattened(std::move(retBases), + {{outDim1D, A.getTotalInDimSize()}}, + /*requireSurjective=*/false); + + SmallVector> retInDims; + SmallVector> retOutDims; + for (StringAttr dim : B.getInDimNames()) { + retInDims.push_back({dim, B.getInDimSize(dim)}); + } + for (StringAttr dim : A.getInDimNames()) { + retOutDims.push_back({dim, A.getInDimSize(dim)}); + } + return retFlattened.reshapeIns(retInDims).reshapeOuts(retOutDims); +} + +} // namespace + +LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { + // TODO(Lezcano) Make friend and perhaps rename to `convertFrom` or `lstsq` + // For this, we need to implement our LLVM lowerings by inverting the "outer" + // layout, and then iterating over the elements from the "this" layout and + // fetching the corresponding element from the "outer" layout. This exercises + // the broadcasting that we incentivise via choosing the minimum norm solution + // in lstsq. + + // The order of dims does not matter. We choose to transpose outer + auto outDims = llvm::to_vector(getOutDimNames()); + assertDimsEqualIgnoringOrder(outDims, outer.getOutDimNames()); + const auto &B = *this; + const auto A = outer.transposeOuts(outDims); + for (auto dim : outDims) { + assert(A.getOutDimSize(dim) == B.getOutDimSize(dim) && + "Convert layout does not change the shape of a tensor"); + } + + // We'll write A^{-1} to mean the inverse or the pseudo-inverse of A + // We are computing A^{-1}B so A must be surjective so that + // it has a left inverse. + assert(A.isSurjective()); + + // Broadcasting heuristic + // Imagine we have two layouts with `warps = [[0, 0],  [0, 0]]` + // (broadcasting) on both layouts. We could map any warp to any warp in the + // conversion. Now, we want to map them as the identity map, to mark that + // nothing needs to be done there (`lstsq` would map all the warps to the + // zero warp, minimum norm solution). The heuristic here is as follows: + // - If a dimension is the same for both layouts, we want to map it as the + // identity + // Equivalently, we don't add it to the conversion + // - Otherwise, we just call lstsq (i.e. map all the equivalent elements + // to the same input element) to take advantage of broadcasting in shared + // memory and avoid saving repeated elements in shared memory + SmallVector identityDims; + for (auto dim : A.getInDimNames()) { + if (B.hasInDim(dim) && + A.sublayout(dim, outDims) == B.sublayout(dim, outDims)) { + identityDims.push_back(dim); + } + } + SmallVector ANonIdentityInDims; + SmallVector BNonIdentityInDims; + for (auto dim : A.getInDimNames()) { + if (!llvm::is_contained(identityDims, dim)) { + ANonIdentityInDims.push_back(dim); + } + } + for (auto dim : B.getInDimNames()) { + if (!llvm::is_contained(identityDims, dim)) { + BNonIdentityInDims.push_back(dim); + } + } + + auto AReduced = A.sublayout(ANonIdentityInDims, outDims); + auto BReduced = B.sublayout(BNonIdentityInDims, outDims); + + // If one is empty, the other must be empty as well + assert((AReduced == LinearLayout::empty()) == + (BReduced == LinearLayout::empty())); + bool isEmpty = AReduced == LinearLayout::empty(); + + auto ret = isEmpty ? LinearLayout::empty() : lstsq(AReduced, BReduced); + + // TODO(Lezcano): We should return the reduced layout instead of re-adding the + // identity maps. With this, we'll be able to kill `minimalCvtLayout` + + // Add the identity maps for the dimensions that are the same for both layouts + for (auto dim : identityDims) { + ret *= LinearLayout::identity1D(A.getInDimSize(dim), dim, dim); + } + + // Reorder the dimensions in the result to match the order expected by the + // current and outer layouts. + return ret.transposeIns(llvm::to_vector(B.getInDimNames())) + .transposeOuts(llvm::to_vector(A.getInDimNames())); +} + +LinearLayout LinearLayout::invert() const { + assert(isInvertible() && + "A linear layout must be surjective and square to be invertible"); + return pseudoinvert(); +} + +LinearLayout LinearLayout::pseudoinvert() const { + // A^-1(x) = A^-1(I(x)), thus A.invert() = I.invertAndCompose(A) + assert(isSurjective() && + "A linear layout must be surjective to compute its pseudoinverse"); + LinearLayout identity = LinearLayout::empty(); + for (auto outDim : getOutDimNames()) { + identity *= LinearLayout::identity1D(getOutDimSize(outDim), outDim, outDim); + } + return identity.invertAndCompose(*this); +} + +llvm::MapVector +LinearLayout::getFreeVariableMasks() const { + std::unique_ptr mat = getMatrix(*this); + int numRows = getTotalOutDimSizeLog2(); + int numCols = getTotalInDimSizeLog2(); + + // stride is specified in number of 64-bit words per row, and we pack our + // matrix so that there's only one uint64_t per row. + assert(numCols <= 64); + f2reduce::inplace_rref_strided(mat.get(), numRows, numCols, /*stride=*/1); + + // For each row in the RREF matrix, identify the column with the first "1". + // These columns correspond to the basic (i.e. non-free) variables. + std::set basicVars; + for (int r = 0; r < numRows; r++) { + if (mat[r] == 0) { + continue; + } + basicVars.insert(__builtin_ctzll(mat[r])); + } + + llvm::MapVector ret; + int c = 0; + for (StringAttr dim : getInDimNames()) { + int32_t mask = 0; + for (int i = 0; i < getInDimSizeLog2(dim); i++, c++) { + if (basicVars.count(c) == 0) { + mask |= (1 << i); + } + } + ret[dim] = mask; + } + return ret; +} + +LinearLayout LinearLayout::removeZeroBasesAlongDim(StringAttr stripDim) const { + LinearLayout::BasesT result; + for (auto &[inDim, inDimBases] : getBases()) { + auto &newInDimBases = result[inDim]; + if (inDim != stripDim) { + newInDimBases = inDimBases; + continue; + } + for (auto &basis : inDimBases) { + if (llvm::any_of(basis, [](int32_t val) { return val != 0; })) { + newInDimBases.push_back(basis); + } + } + } + return LinearLayout(std::move(result), llvm::to_vector(getOutDimNames())); +} + +size_t hash_value(const LinearLayout &layout) { + size_t seed = 0; + + // Hash the bases + for (const auto &base : layout.getBases()) { + // Hash the input dimension name + seed = llvm::hash_combine(seed, base.first); + + // Hash the vectors in bases + for (const auto &vec : base.second) { + for (int32_t val : vec) { + seed = llvm::hash_combine(seed, val); + } + } + } + + // Hash the output dimensions and their sizes + for (const auto &outDim : layout.getOutDimNames()) { + seed = llvm::hash_combine(seed, outDim, layout.getOutDimSize(outDim)); + } + // Don't hash the surjective flag as it's a cached property + return seed; +} + +bool operator==(LinearLayout lhs, LinearLayout rhs) { + if (!lhs.equalIgnoringOutDimSizes(rhs)) + return false; + + for (const auto &[lhsOutDimAndSize, rhsOutDimAndSize] : + llvm::zip(lhs.outDims, rhs.outDims)) { + if (lhsOutDimAndSize.second != rhsOutDimAndSize.second) + return false; + } + return true; +} + +bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const { + // llvm::MapVector doesn't have an operator== :(. + if (llvm::to_vector(this->getOutDimNames()) != + llvm::to_vector(other.getOutDimNames())) + return false; + if (this->bases.size() != other.bases.size()) + return false; + for (auto it1 = this->bases.begin(), it2 = other.bases.begin(); + it1 != this->bases.end(); ++it1, ++it2) { + if (*it1 != *it2) + return false; + } + return true; +} + +std::string LinearLayout::toString() const { + // Start with a newline because we print out a bulleted list; it doesn't + // make sense for the first line of this list to be on the same line as + // any previous text. + std::string ret = "\n"; + std::string outDimsStr = + "[" + + join(outDims, ", ", + [](auto dimAndSize) { + auto [outDim, size] = dimAndSize; + return outDim.str() + " (size " + std::to_string(size) + ")"; + }) + + "]"; + + if (bases.empty()) { + if (outDims.empty()) { + return "\n(empty layout)"; + } else { + return "\n(empty layout with out-dims " + outDimsStr + ")"; + } + } + + // TODO: Add spaces for alignment. + for (const auto &[inDim, inDimBases] : bases) { + if (inDimBases.empty()) { + ret += " - " + inDim.str() + " is a size 1 dimension\n"; + continue; + } + + ret += " - " + + join(llvm::seq(inDimBases.size()), "\n ", + [&, &inDim = inDim, &inDimBases = inDimBases](int i) { + return inDim.str() + "=" + std::to_string(1 << i) + " -> (" + + join(inDimBases[i], ", ") + ")"; + }) + + "\n"; + } + ret += "where out dims are: " + outDimsStr; + return ret; +} + +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/pyproject.toml b/third_party/enflame/include/triton/pyproject.toml new file mode 100644 index 000000000..525f303ef --- /dev/null +++ b/third_party/enflame/include/triton/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1"] + +[tool.yapf] +based_on_style = "pep8" +column_limit = 120 +disable_split_list_with_comment = true +each_dict_entry_on_separate_line=false +split_before_named_assigns = false +split_complex_comprehension = true + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +ignore = ["E501", "E701", "E731", "E741"] diff --git a/third_party/enflame/include/triton/python/MANIFEST.in b/third_party/enflame/include/triton/python/MANIFEST.in new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/include/triton/python/build_helpers.py b/third_party/enflame/include/triton/python/build_helpers.py new file mode 100644 index 000000000..da2efd30c --- /dev/null +++ b/third_party/enflame/include/triton/python/build_helpers.py @@ -0,0 +1,17 @@ +import os +import sysconfig +import sys +from pathlib import Path + + +def get_base_dir(): + return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) + + +def get_cmake_dir(): + plat_name = sysconfig.get_platform() + python_version = sysconfig.get_python_version() + dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}" + cmake_dir = Path(get_base_dir()) / "python" / "build" / dir_name + cmake_dir.mkdir(parents=True, exist_ok=True) + return cmake_dir diff --git a/third_party/enflame/include/triton/python/pyproject.toml b/third_party/enflame/include/triton/python/pyproject.toml new file mode 100644 index 000000000..d96af50a5 --- /dev/null +++ b/third_party/enflame/include/triton/python/pyproject.toml @@ -0,0 +1,15 @@ + +[build-system] +requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1", "pybind11>=2.13.1"] + +# We're incrementally switching from autopep8 to ruff. +[tool.autopep8] +aggressive = 1 +ignore = "E501,E701,E731,W690,W503" +max_line_length = 88 + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +ignore = ["E501", "E701", "E731", "E741"] diff --git a/third_party/enflame/include/triton/python/requirements.txt b/third_party/enflame/include/triton/python/requirements.txt new file mode 100644 index 000000000..1b46cd5e6 --- /dev/null +++ b/third_party/enflame/include/triton/python/requirements.txt @@ -0,0 +1,8 @@ +ninja +cmake +setuptools>=40.8.0 +wheel +cmake>=3.18,<4.0 +ninja>=1.11.1 +pybind11>=2.13.1 +lit diff --git a/third_party/enflame/include/triton/python/setup.py b/third_party/enflame/include/triton/python/setup.py new file mode 100644 index 000000000..9d312c63a --- /dev/null +++ b/third_party/enflame/include/triton/python/setup.py @@ -0,0 +1,809 @@ +import os +import platform +import re +import contextlib +import shlex +import shutil +import subprocess +import sys +import sysconfig +import tarfile +import zipfile +import urllib.request +import json +from io import BytesIO +from distutils.command.clean import clean +from pathlib import Path +from typing import List, Optional + +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext +from setuptools.command.build_py import build_py +from dataclasses import dataclass + +from distutils.command.install import install +from setuptools.command.develop import develop +from setuptools.command.egg_info import egg_info +from wheel.bdist_wheel import bdist_wheel + +import pybind11 + +from build_helpers import get_base_dir, get_cmake_dir + + +@dataclass +class Backend: + name: str + package_data: List[str] + language_package_data: List[str] + tools_package_data: List[str] + src_dir: str + backend_dir: str + language_dir: Optional[str] + tools_dir: Optional[str] + install_dir: str + is_external: bool + + +class BackendInstaller: + + @staticmethod + def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool = False): + # Initialize submodule if there is one for in-tree backends. + if not is_external: + root_dir = os.path.join(os.pardir, "third_party") + assert backend_name in os.listdir( + root_dir), f"{backend_name} is requested for install but not present in {root_dir}" + + try: + subprocess.run(["git", "submodule", "update", "--init", f"{backend_name}"], check=True, + stdout=subprocess.DEVNULL, cwd=root_dir) + except subprocess.CalledProcessError: + pass + except FileNotFoundError: + pass + + backend_src_dir = os.path.join(root_dir, backend_name) + + backend_path = os.path.abspath(os.path.join(backend_src_dir, "backend")) + assert os.path.exists(backend_path), f"{backend_path} does not exist!" + + language_dir = os.path.abspath(os.path.join(backend_src_dir, "language")) + if not os.path.exists(language_dir): + language_dir = None + + tools_dir = os.path.abspath(os.path.join(backend_src_dir, "tools")) + if not os.path.exists(tools_dir): + tools_dir = None + + for file in ["compiler.py", "driver.py"]: + assert os.path.exists(os.path.join(backend_path, file)), f"${file} does not exist in ${backend_path}" + + install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend_name) + package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _, in os.walk(backend_path)] + + language_package_data = [] + if language_dir is not None: + language_package_data = [f"{os.path.relpath(p, language_dir)}/*" for p, _, _, in os.walk(language_dir)] + + tools_package_data = [] + if tools_dir is not None: + tools_package_data = [f"{os.path.relpath(p, tools_dir)}/*" for p, _, _, in os.walk(tools_dir)] + + return Backend(name=backend_name, package_data=package_data, language_package_data=language_package_data, + tools_package_data=tools_package_data, src_dir=backend_src_dir, backend_dir=backend_path, + language_dir=language_dir, tools_dir=tools_dir, install_dir=install_dir, is_external=is_external) + + # Copy all in-tree backends under triton/third_party. + @staticmethod + def copy(active): + return [BackendInstaller.prepare(backend) for backend in active] + + # Copy all external plugins provided by the `TRITON_PLUGIN_DIRS` env var. + # TRITON_PLUGIN_DIRS is a semicolon-separated list of paths to the plugins. + # Expect to find the name of the backend under dir/backend/name.conf + @staticmethod + def copy_externals(): + backend_dirs = os.getenv("TRITON_PLUGIN_DIRS") + if backend_dirs is None: + return [] + backend_dirs = backend_dirs.strip().split(";") + backend_names = [Path(os.path.join(dir, "backend", "name.conf")).read_text().strip() for dir in backend_dirs] + return [ + BackendInstaller.prepare(backend_name, backend_src_dir=backend_src_dir, is_external=True) + for backend_name, backend_src_dir in zip(backend_names, backend_dirs) + ] + + +# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py +def check_env_flag(name: str, default: str = "") -> bool: + return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] + + +def get_build_type(): + if check_env_flag("DEBUG"): + return "Debug" + elif check_env_flag("REL_WITH_DEB_INFO"): + return "RelWithDebInfo" + elif check_env_flag("TRITON_REL_BUILD_WITH_ASSERTS"): + return "TritonRelBuildWithAsserts" + elif check_env_flag("TRITON_BUILD_WITH_O1"): + return "TritonBuildWithO1" + else: + # TODO: change to release when stable enough + return "TritonRelBuildWithAsserts" + + +def get_env_with_keys(key: list): + for k in key: + if k in os.environ: + return os.environ[k] + return "" + + +def is_offline_build() -> bool: + """ + Downstream projects and distributions which bootstrap their own dependencies from scratch + and run builds in offline sandboxes + may set `TRITON_OFFLINE_BUILD` in the build environment to prevent any attempts at downloading + pinned dependencies from the internet or at using dependencies vendored in-tree. + + Dependencies must be defined using respective search paths (cf. `syspath_var_name` in `Package`). + Missing dependencies lead to an early abortion. + Dependencies' compatibility is not verified. + + Note that this flag isn't tested by the CI and does not provide any guarantees. + """ + return check_env_flag("TRITON_OFFLINE_BUILD", "") + + +# --- third party packages ----- + + +@dataclass +class Package: + package: str + name: str + url: str + include_flag: str + lib_flag: str + syspath_var_name: str + sym_name: Optional[str] = None + + +# json +def get_json_package_info(): + url = "https://github.com/nlohmann/json/releases/download/v3.11.3/include.zip" + return Package("json", "", url, "JSON_INCLUDE_DIR", "", "JSON_SYSPATH") + + +def is_linux_os(id): + if os.path.exists("/etc/os-release"): + with open("/etc/os-release", "r") as f: + os_release_content = f.read() + return f'ID="{id}"' in os_release_content + return False + + +# llvm +def get_llvm_package_info(): + print("PyDbg - enflame/setup.py: get_llvm_package_info") + system = platform.system() + try: + arch = {"x86_64": "x64", "arm64": "arm64", "aarch64": "arm64"}[platform.machine()] + except KeyError: + arch = platform.machine() + if system == "Darwin": + system_suffix = f"macos-{arch}" + elif system == "Linux": + if arch == 'arm64' and is_linux_os('almalinux'): + system_suffix = 'almalinux-arm64' + elif arch == 'arm64': + system_suffix = 'ubuntu-arm64' + elif arch == 'x64': + vglibc = tuple(map(int, platform.libc_ver()[1].split('.'))) + vglibc = vglibc[0] * 100 + vglibc[1] + if vglibc > 228: + # Ubuntu 24 LTS (v2.39) + # Ubuntu 22 LTS (v2.35) + # Ubuntu 20 LTS (v2.31) + system_suffix = "ubuntu-x64" + elif vglibc > 217: + # Manylinux_2.28 (v2.28) + # AlmaLinux 8 (v2.28) + system_suffix = "almalinux-x64" + else: + # Manylinux_2014 (v2.17) + # CentOS 7 (v2.17) + system_suffix = "centos-x64" + else: + print( + f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build." + ) + return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") + else: + print( + f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build." + ) + return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") + # use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False") + # release_suffix = "assert" if use_assert_enabled_llvm else "release" + llvm_hash_path = os.path.join(get_base_dir(), "cmake", "llvm-hash.txt") + with open(llvm_hash_path, "r") as llvm_hash_file: + rev = llvm_hash_file.read(8) + name = f"llvm-{rev}-{system_suffix}" + # Create a stable symlink that doesn't include revision + sym_name = f"llvm-{system_suffix}" + url = f"https://oaitriton.blob.core.windows.net/public/llvm-builds/{name}.tar.gz" + return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH", sym_name=sym_name) + + +def open_url(url): + user_agent = 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0' + headers = { + 'User-Agent': user_agent, + } + request = urllib.request.Request(url, None, headers) + # Set timeout to 300 seconds to prevent the request from hanging forever. + return urllib.request.urlopen(request, timeout=300) + + +# ---- package data --- + + +def get_triton_cache_path(): + user_home = os.getenv("TRITON_HOME") + if not user_home: + user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or os.getenv("HOMEPATH") or None + if not user_home: + raise RuntimeError("Could not find user home directory") + return os.path.join(user_home, ".triton") + + +def update_symlink(link_path, source_path): + source_path = Path(source_path) + link_path = Path(link_path) + + if link_path.is_symlink(): + link_path.unlink() + elif link_path.exists(): + shutil.rmtree(link_path) + + print(f"creating symlink: {link_path} -> {source_path}", file=sys.stderr) + link_path.absolute().parent.mkdir(parents=True, exist_ok=True) # Ensure link's parent directory exists + link_path.symlink_to(source_path, target_is_directory=True) + + +def get_thirdparty_packages(packages: list): + triton_cache_path = get_triton_cache_path() + thirdparty_cmake_args = [] + for p in packages: + package_root_dir = os.path.join(triton_cache_path, p.package) + package_dir = os.path.join(package_root_dir, p.name) + if os.environ.get(p.syspath_var_name): + package_dir = os.environ[p.syspath_var_name] + version_file_path = os.path.join(package_dir, "version.txt") + + input_defined = p.syspath_var_name in os.environ + input_exists = os.path.exists(version_file_path) + input_compatible = input_exists and Path(version_file_path).read_text() == p.url + + if is_offline_build() and not input_defined: + raise RuntimeError(f"Requested an offline build but {p.syspath_var_name} is not set") + if not is_offline_build() and not input_defined and not input_compatible: + with contextlib.suppress(Exception): + shutil.rmtree(package_root_dir) + os.makedirs(package_root_dir, exist_ok=True) + print(f'downloading and extracting {p.url} ...') + with open_url(p.url) as response: + if p.url.endswith(".zip"): + file_bytes = BytesIO(response.read()) + with zipfile.ZipFile(file_bytes, "r") as file: + file.extractall(path=package_root_dir) + else: + with tarfile.open(fileobj=response, mode="r|*") as file: + file.extractall(path=package_root_dir) + # write version url to package_dir + with open(os.path.join(package_dir, "version.txt"), "w") as f: + f.write(p.url) + if p.include_flag: + thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include") + if p.lib_flag: + thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib") + if p.sym_name is not None: + sym_link_path = os.path.join(package_root_dir, p.sym_name) + update_symlink(sym_link_path, package_dir) + + return thirdparty_cmake_args + + +def download_and_copy(name, src_func, dst_path, variable, version, url_func): + if is_offline_build(): + return + triton_cache_path = get_triton_cache_path() + if variable in os.environ: + return + base_dir = os.path.dirname(__file__) + system = platform.system() + arch = platform.machine() + # NOTE: This might be wrong for jetson if both grace chips and jetson chips return aarch64 + arch = {"arm64": "sbsa", "aarch64": "sbsa"}.get(arch, arch) + supported = {"Linux": "linux", "Darwin": "linux"} + url = url_func(supported[system], arch, version) + src_path = src_func(supported[system], arch, version) + tmp_path = os.path.join(triton_cache_path, "nvidia", name) # path to cache the download + dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path + src_path = os.path.join(tmp_path, src_path) + download = not os.path.exists(src_path) + if os.path.exists(dst_path) and system == "Linux" and shutil.which(dst_path) is not None: + curr_version = subprocess.check_output([dst_path, "--version"]).decode("utf-8").strip() + curr_version = re.search(r"V([.|\d]+)", curr_version) + assert curr_version is not None, f"No version information for {dst_path}" + download = download or curr_version.group(1) != version + if download: + print(f'downloading and extracting {url} ...') + file = tarfile.open(fileobj=open_url(url), mode="r|*") + file.extractall(path=tmp_path) + os.makedirs(os.path.split(dst_path)[0], exist_ok=True) + print(f'copy {src_path} to {dst_path} ...') + if os.path.isdir(src_path): + shutil.copytree(src_path, dst_path, dirs_exist_ok=True) + else: + shutil.copy(src_path, dst_path) + + +# ---- cmake extension ---- + + +class CMakeClean(clean): + + def initialize_options(self): + clean.initialize_options(self) + self.build_temp = get_cmake_dir() + + +class CMakeBuildPy(build_py): + + def run(self) -> None: + self.run_command('build_ext') + return super().run() + + +class CMakeExtension(Extension): + + def __init__(self, name, path, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + self.path = path + + +class CMakeBuild(build_ext): + + user_options = build_ext.user_options + \ + [('base-dir=', None, 'base directory of Triton')] + + def initialize_options(self): + build_ext.initialize_options(self) + self.base_dir = get_base_dir() + + def finalize_options(self): + build_ext.finalize_options(self) + + def run(self): + try: + out = subprocess.check_output(["cmake", "--version"]) + except OSError: + raise RuntimeError("CMake must be installed to build the following extensions: " + + ", ".join(e.name for e in self.extensions)) + + match = re.search(r"version\s*(?P\d+)\.(?P\d+)([\d.]+)?", out.decode()) + cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor")) + if (cmake_major, cmake_minor) < (3, 18): + raise RuntimeError("CMake >= 3.18.0 is required") + + for ext in self.extensions: + self.build_extension(ext) + + def get_pybind11_cmake_args(self): + pybind11_sys_path = get_env_with_keys(["PYBIND11_SYSPATH"]) + if pybind11_sys_path: + pybind11_include_dir = os.path.join(pybind11_sys_path, "include") + else: + pybind11_include_dir = pybind11.get_include() + return [f"-Dpybind11_INCLUDE_DIR='{pybind11_include_dir}'", f"-Dpybind11_DIR='{pybind11.get_cmake_dir()}'"] + + def get_proton_cmake_args(self): + cmake_args = get_thirdparty_packages([get_json_package_info()]) + cmake_args += self.get_pybind11_cmake_args() + cupti_include_dir = get_env_with_keys(["TRITON_CUPTI_INCLUDE_PATH"]) + if cupti_include_dir == "": + cupti_include_dir = os.path.join(get_base_dir(), "third_party", "nvidia", "backend", "include") + cmake_args += ["-DCUPTI_INCLUDE_DIR=" + cupti_include_dir] + roctracer_include_dir = get_env_with_keys(["TRITON_ROCTRACER_INCLUDE_PATH"]) + if roctracer_include_dir == "": + roctracer_include_dir = os.path.join(get_base_dir(), "third_party", "amd", "backend", "include") + cmake_args += ["-DROCTRACER_INCLUDE_DIR=" + roctracer_include_dir] + return cmake_args + + def build_extension(self, ext): + lit_dir = shutil.which('lit') + ninja_dir = shutil.which('ninja') + # lit is used by the test suite + thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()]) + thirdparty_cmake_args += self.get_pybind11_cmake_args() + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) + # create build directories + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + # python directories + python_include_dir = sysconfig.get_path("platinclude") + cmake_args = [ + "-G", "Ninja", # Ninja is much faster than make + "-DCMAKE_MAKE_PROGRAM=" + + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path + "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON", + "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF", + "-DTRITON_BUILD_PYTHON_MODULE=ON", "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable, + "-DPython3_INCLUDE_DIR=" + python_include_dir, + "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]), + "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]) + ] + if lit_dir is not None: + cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir) + cmake_args.extend(thirdparty_cmake_args) + + # configuration + cfg = get_build_type() + build_args = ["--config", cfg] + + cmake_args += [f"-DCMAKE_BUILD_TYPE={cfg}"] + if platform.system() == "Windows": + cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] + else: + max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count())) + build_args += ['-j' + max_jobs] + + if check_env_flag("TRITON_BUILD_WITH_CLANG_LLD"): + cmake_args += [ + "-DCMAKE_C_COMPILER=clang", + "-DCMAKE_CXX_COMPILER=clang++", + "-DCMAKE_LINKER=lld", + "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld", + "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld", + "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld", + ] + + # Note that asan doesn't work with binaries that use the GPU, so this is + # only useful for tools like triton-opt that don't run code on the GPU. + # + # I tried and gave up getting msan to work. It seems that libstdc++'s + # std::string does not play nicely with clang's msan (I didn't try + # gcc's). I was unable to configure clang to ignore the error, and I + # also wasn't able to get libc++ to work, but that doesn't mean it's + # impossible. :) + if check_env_flag("TRITON_BUILD_WITH_ASAN"): + cmake_args += [ + "-DCMAKE_C_FLAGS=-fsanitize=address", + "-DCMAKE_CXX_FLAGS=-fsanitize=address", + ] + + # environment variables we will pass through to cmake + passthrough_args = [ + "TRITON_BUILD_PROTON", + "TRITON_BUILD_TUTORIALS", + "TRITON_BUILD_WITH_CCACHE", + "TRITON_PARALLEL_LINK_JOBS", + ] + cmake_args += [f"-D{option}={os.getenv(option)}" for option in passthrough_args if option in os.environ] + + if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON + cmake_args += self.get_proton_cmake_args() + + if is_offline_build(): + # unit test builds fetch googletests from GitHub + cmake_args += ["-DTRITON_BUILD_UT=OFF"] + + cmake_args_append = os.getenv("TRITON_APPEND_CMAKE_ARGS") + if cmake_args_append is not None: + cmake_args += shlex.split(cmake_args_append) + + env = os.environ.copy() + cmake_dir = get_cmake_dir() + subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env) + subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir) + subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) + + +nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json") +with open(nvidia_version_path, "r") as nvidia_version_file: + # parse this json file to get the version of the nvidia toolchain + NVIDIA_TOOLCHAIN_VERSION = json.load(nvidia_version_file) + +exe_extension = sysconfig.get_config_var("EXE") +download_and_copy( + name="nvcc", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas{exe_extension}", + dst_path="bin/ptxas", + variable="TRITON_PTXAS_PATH", + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", +) +# We download a separate ptxas for blackwell, since there are some bugs when using it for hopper +download_and_copy( + name="nvcc", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas{exe_extension}", + dst_path="bin/ptxas-blackwell", + variable="TRITON_PTXAS_PATH", + version=NVIDIA_TOOLCHAIN_VERSION["ptxas-blackwell"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", +) +download_and_copy( + name="cuobjdump", + src_func=lambda system, arch, version: + f"cuda_cuobjdump-{system}-{arch}-{version}-archive/bin/cuobjdump{exe_extension}", + dst_path="bin/cuobjdump", + variable="TRITON_CUOBJDUMP_PATH", + version=NVIDIA_TOOLCHAIN_VERSION["cuobjdump"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_cuobjdump/{system}-{arch}/cuda_cuobjdump-{system}-{arch}-{version}-archive.tar.xz", +) +download_and_copy( + name="nvdisasm", + src_func=lambda system, arch, version: + f"cuda_nvdisasm-{system}-{arch}-{version}-archive/bin/nvdisasm{exe_extension}", + dst_path="bin/nvdisasm", + variable="TRITON_NVDISASM_PATH", + version=NVIDIA_TOOLCHAIN_VERSION["nvdisasm"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvdisasm/{system}-{arch}/cuda_nvdisasm-{system}-{arch}-{version}-archive.tar.xz", +) +download_and_copy( + name="nvcc", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/include", + dst_path="include", + variable="TRITON_CUDACRT_PATH", + version=NVIDIA_TOOLCHAIN_VERSION["cudacrt"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", +) +download_and_copy( + name="cudart", + src_func=lambda system, arch, version: f"cuda_cudart-{system}-{arch}-{version}-archive/include", + dst_path="include", + variable="TRITON_CUDART_PATH", + version=NVIDIA_TOOLCHAIN_VERSION["cudart"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/{system}-{arch}/cuda_cudart-{system}-{arch}-{version}-archive.tar.xz", +) +download_and_copy( + name="cupti", + src_func=lambda system, arch, version: f"cuda_cupti-{system}-{arch}-{version}-archive/include", + dst_path="include", + variable="TRITON_CUPTI_INCLUDE_PATH", + version=NVIDIA_TOOLCHAIN_VERSION["cupti"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_cupti/{system}-{arch}/cuda_cupti-{system}-{arch}-{version}-archive.tar.xz", +) +download_and_copy( + name="cupti", + src_func=lambda system, arch, version: f"cuda_cupti-{system}-{arch}-{version}-archive/lib", + dst_path="lib/cupti", + variable="TRITON_CUPTI_LIB_PATH", + version=NVIDIA_TOOLCHAIN_VERSION["cupti"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_cupti/{system}-{arch}/cuda_cupti-{system}-{arch}-{version}-archive.tar.xz", +) +backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()] + + +def add_link_to_backends(): + for backend in backends: + update_symlink(backend.install_dir, backend.backend_dir) + + if backend.language_dir: + # Link the contents of each backend's `language` directory into + # `triton.language.extra`. + extra_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "triton", "language", "extra")) + for x in os.listdir(backend.language_dir): + src_dir = os.path.join(backend.language_dir, x) + install_dir = os.path.join(extra_dir, x) + update_symlink(install_dir, src_dir) + + if backend.tools_dir: + # Link the contents of each backend's `tools` directory into + # `triton.tools.extra`. + extra_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "triton", "tools", "extra")) + for x in os.listdir(backend.tools_dir): + src_dir = os.path.join(backend.tools_dir, x) + install_dir = os.path.join(extra_dir, x) + update_symlink(install_dir, src_dir) + + +def add_link_to_proton(): + proton_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "third_party", "proton", "proton")) + proton_install_dir = os.path.join(os.path.dirname(__file__), "triton", "profiler") + update_symlink(proton_install_dir, proton_dir) + + +def add_links(): + add_link_to_backends() + if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON + add_link_to_proton() + + +class plugin_install(install): + + def run(self): + add_links() + install.run(self) + + +class plugin_develop(develop): + + def run(self): + add_links() + develop.run(self) + + +class plugin_bdist_wheel(bdist_wheel): + + def run(self): + add_links() + bdist_wheel.run(self) + + +class plugin_egginfo(egg_info): + + def run(self): + add_links() + egg_info.run(self) + + +package_data = { + "triton/tools/extra": sum((b.tools_package_data for b in backends), []), + **{f"triton/backends/{b.name}": b.package_data + for b in backends}, "triton/language/extra": sum((b.language_package_data for b in backends), []) +} + + +def get_extra_packages(extra_name): + packages = [] + extra_file_extensions = {"language": (".py"), "tools": (".c", ".h", ".cpp")} + assert extra_name in extra_file_extensions, f"{extra_name} extra is not valid" + + for backend in backends: + backend_extra_dir = getattr(backend, f"{extra_name}_dir", None) + if backend_extra_dir is None: + continue + + # Walk the specified directory of each backend to enumerate + # any subpackages, which will be added to extra_package. + for dir, dirs, files in os.walk(backend_extra_dir, followlinks=True): + if not any(f for f in files if f.endswith(extra_file_extensions[extra_name])) or dir == backend_extra_dir: + # Ignore directories with no relevant files + # or the root directory + continue + subpackage = os.path.relpath(dir, backend_extra_dir) + package = os.path.join(f"triton/{extra_name}/extra", subpackage) + packages.append(package) + + return list(packages) + + +def get_packages(): + packages = [ + "triton", + "triton/_C", + "triton/compiler", + "triton/language", + "triton/language/extra", + "triton/runtime", + "triton/backends", + "triton/tools", + "triton/tools/extra", + ] + packages += [f'triton/backends/{backend.name}' for backend in backends] + packages += get_extra_packages("language") + packages += get_extra_packages("tools") + if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON + packages += ["triton/profiler"] + + return packages + + +def get_entry_points(): + entry_points = {} + if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON + entry_points["console_scripts"] = [ + "proton-viewer = triton.profiler.viewer:main", + "proton = triton.profiler.proton:main", + ] + return entry_points + + +def get_git_commit_hash(length=8): + try: + cmd = ['git', 'rev-parse', f'--short={length}', 'HEAD'] + return "+git{}".format(subprocess.check_output(cmd).strip().decode('utf-8')) + except Exception: + return "" + + +def get_git_branch(): + try: + cmd = ['git', 'rev-parse', '--abbrev-ref', 'HEAD'] + return subprocess.check_output(cmd).strip().decode('utf-8') + except Exception: + return "" + + +def get_git_version_suffix(): + branch = get_git_branch() + if branch.startswith("release"): + return "" + else: + return get_git_commit_hash() + + +setup( + name=os.environ.get("TRITON_WHEEL_NAME", "triton"), + version="3.3.1" + get_git_version_suffix() + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""), + author="Philippe Tillet", + author_email="phil@openai.com", + description="A language and compiler for custom Deep Learning operations", + long_description="", + install_requires=["setuptools>=40.8.0"], + packages=get_packages(), + entry_points=get_entry_points(), + package_data=package_data, + include_package_data=True, + ext_modules=[CMakeExtension("triton", "triton/_C/")], + cmdclass={ + "build_ext": CMakeBuild, + "build_py": CMakeBuildPy, + "clean": CMakeClean, + "install": plugin_install, + "develop": plugin_develop, + "bdist_wheel": plugin_bdist_wheel, + "egg_info": plugin_egginfo, + }, + zip_safe=False, + # for PyPI + keywords=["Compiler", "Deep Learning"], + url="https://github.com/triton-lang/triton/", + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Topic :: Software Development :: Build Tools", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + ], + test_suite="tests", + extras_require={ + "build": [ + "cmake>=3.20", + "lit", + ], + "tests": [ + "autopep8", + "isort", + "numpy", + "pytest", + "pytest-forked", + "pytest-xdist", + "scipy>=1.7.1", + "llnl-hatchet", + ], + "tutorials": [ + "matplotlib", + "pandas", + "tabulate", + ], + }, +) diff --git a/third_party/enflame/include/triton/python/src/interpreter.cc b/third_party/enflame/include/triton/python/src/interpreter.cc new file mode 100644 index 000000000..747a0cc17 --- /dev/null +++ b/third_party/enflame/include/triton/python/src/interpreter.cc @@ -0,0 +1,740 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace { + +struct npy_half { + uint16_t value; +}; + +enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED }; + +std::mutex atomic_op_guard; + +template +constexpr bool is_reinterpret_cast_to_atomic_safe = + std::is_trivially_copyable_v && + std::is_trivially_copyable_v> && + std::is_standard_layout_v && std::is_standard_layout_v> && + sizeof(T) == sizeof(std::atomic) && + alignof(T) == alignof(std::atomic); + +enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX }; + +std::map mem_semantic_map = { + {MemSemantic::ACQUIRE_RELEASE, std::memory_order_acq_rel}, + {MemSemantic::ACQUIRE, std::memory_order_acquire}, + {MemSemantic::RELEASE, std::memory_order_release}, + {MemSemantic::RELAXED, std::memory_order_relaxed}, +}; + +template +T atomic_cmp(T *ptr, T val, std::memory_order order) { + auto cmp = [](T old, T val) { + if constexpr (is_min) { + return old > val; + } else { + return old < val; + } + }; + + T old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_ptr = reinterpret_cast *>(ptr); + old_val = atomic_ptr->load(order); + while (cmp(old_val, val)) { + if (atomic_ptr->compare_exchange_weak(old_val, val, order, order)) { + break; + } + } + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *ptr; + if (cmp(old_val, val)) { + *ptr = val; + } + } + return old_val; +} + +template T atomic_fadd(T *loc, T value, std::memory_order order) { + static_assert(std::is_floating_point::value, + "T must be a floating-point type"); + T old_value; + + if constexpr (is_reinterpret_cast_to_atomic_safe) { + T new_value; + std::atomic *atomic_loc = reinterpret_cast *>(loc); + old_value = atomic_loc->load(order); + do { + new_value = old_value + value; + } while ( + !atomic_loc->compare_exchange_weak(old_value, new_value, order, order)); + } else { + const std::lock_guard lock(atomic_op_guard); + old_value = *loc; + *loc = old_value + value; + } + + return old_value; +} + +/** Create a value of type `To` from the bits of `from`. + * + * similar to `std::bit_cast` but compatible with C++17, + * should perform similar to `*reinterpret_cast(&from)` + * or through punning without expecting any undefined behaviors. + * + * Note: taken from + * https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/utils.hpp#L32 + * with simplification. + */ +template +inline To BitCast(const From &from) noexcept { + static_assert(sizeof(To) == sizeof(From), + "both data types must have the same size"); + + static_assert(std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + "both data types must be trivially copyable"); + + To to; + memcpy(&to, &from, sizeof(from)); + return to; +} + +// Taken from +// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L14 +template +inline uint16_t FromFloatBits(uint32_t f) { + uint32_t f_exp, f_sig; + uint16_t h_sgn, h_exp, h_sig; + + h_sgn = (uint16_t)((f & 0x80000000u) >> 16); + f_exp = (f & 0x7f800000u); + + /* Exponent overflow/NaN converts to signed inf/NaN */ + if (f_exp >= 0x47800000u) { + if (f_exp == 0x7f800000u) { + /* Inf or NaN */ + f_sig = (f & 0x007fffffu); + if (f_sig != 0) { + /* NaN - propagate the flag in the significand... */ + uint16_t ret = (uint16_t)(0x7c00u + (f_sig >> 13)); + /* ...but make sure it stays a NaN */ + if (ret == 0x7c00u) { + ret++; + } + return h_sgn + ret; + } else { + /* signed inf */ + return (uint16_t)(h_sgn + 0x7c00u); + } + } else { + if constexpr (gen_overflow) { + // FloatStatus::RaiseOverflow(); + throw std::overflow_error("overflow to signed inf"); + } + return (uint16_t)(h_sgn + 0x7c00u); + } + } + + /* Exponent underflow converts to a subnormal half or signed zero */ + if (f_exp <= 0x38000000u) { + /* + * Signed zeros, subnormal floats, and floats with small + * exponents all convert to signed zero half-floats. + */ + if (f_exp < 0x33000000u) { + if constexpr (gen_underflow) { + /* If f != 0, it underflowed to 0 */ + if ((f & 0x7fffffff) != 0) { + // FloatStatus::RaiseUnderflow(); + throw std::underflow_error(""); + } + } + return h_sgn; + } + /* Make the subnormal significand */ + f_exp >>= 23; + f_sig = (0x00800000u + (f & 0x007fffffu)); + if constexpr (gen_underflow) { + /* If it's not exactly represented, it underflowed */ + if ((f_sig & (((uint32_t)1 << (126 - f_exp)) - 1)) != 0) { + // FloatStatus::RaiseUnderflow(); + throw std::underflow_error(""); + } + } + /* + * Usually the significand is shifted by 13. For subnormals an + * additional shift needs to occur. This shift is one for the largest + * exponent giving a subnormal `f_exp = 0x38000000 >> 23 = 112`, which + * offsets the new first bit. At most the shift can be 1+10 bits. + */ + f_sig >>= (113 - f_exp); + /* Handle rounding by adding 1 to the bit beyond half precision */ + if constexpr (round_even) { + /* + * If the last bit in the half significand is 0 (already even), and + * the remaining bit pattern is 1000...0, then we do not add one + * to the bit after the half significand. However, the (113 - f_exp) + * shift can lose up to 11 bits, so the || checks them in the original. + * In all other cases, we can just add one. + */ + if (((f_sig & 0x00003fffu) != 0x00001000u) || (f & 0x000007ffu)) { + f_sig += 0x00001000u; + } + } else { + f_sig += 0x00001000u; + } + h_sig = (uint16_t)(f_sig >> 13); + /* + * If the rounding causes a bit to spill into h_exp, it will + * increment h_exp from zero to one and h_sig will be zero. + * This is the correct result. + */ + return (uint16_t)(h_sgn + h_sig); + } + + /* Regular case with no overflow or underflow */ + h_exp = (uint16_t)((f_exp - 0x38000000u) >> 13); + /* Handle rounding by adding 1 to the bit beyond half precision */ + f_sig = (f & 0x007fffffu); + if constexpr (round_even) { + /* + * If the last bit in the half significand is 0 (already even), and + * the remaining bit pattern is 1000...0, then we do not add one + * to the bit after the half significand. In all other cases, we do. + */ + if ((f_sig & 0x00003fffu) != 0x00001000u) { + f_sig += 0x00001000u; + } + } else { + f_sig += 0x00001000u; + } + h_sig = (uint16_t)(f_sig >> 13); + /* + * If the rounding causes a bit to spill into h_exp, it will + * increment h_exp by one and h_sig will be zero. This is the + * correct result. h_exp may increment to 15, at greatest, in + * which case the result overflows to a signed inf. + */ + if constexpr (gen_overflow) { + h_sig += h_exp; + if (h_sig == 0x7c00u) { + // FloatStatus::RaiseOverflow(); + throw std::overflow_error(""); + } + return h_sgn + h_sig; + } else { + return h_sgn + h_exp + h_sig; + } +} + +// Taken from +// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L269 +constexpr uint32_t ToFloatBits(uint16_t h) { + uint16_t h_exp = (h & 0x7c00u); + uint32_t f_sgn = ((uint32_t)h & 0x8000u) << 16; + switch (h_exp) { + case 0x0000u: { // 0 or subnormal + uint16_t h_sig = (h & 0x03ffu); + // Signed zero + if (h_sig == 0) { + return f_sgn; + } + // Subnormal + h_sig <<= 1; + while ((h_sig & 0x0400u) == 0) { + h_sig <<= 1; + h_exp++; + } + uint32_t f_exp = ((uint32_t)(127 - 15 - h_exp)) << 23; + uint32_t f_sig = ((uint32_t)(h_sig & 0x03ffu)) << 13; + return f_sgn + f_exp + f_sig; + } + case 0x7c00u: // inf or NaN + // All-ones exponent and a copy of the significand + return f_sgn + 0x7f800000u + (((uint32_t)(h & 0x03ffu)) << 13); + default: // normalized + // Just need to adjust the exponent and shift + return f_sgn + (((uint32_t)(h & 0x7fffu) + 0x1c000u) << 13); + } +} + +npy_half npy_float_to_half(float f) { + return {FromFloatBits(BitCast(f))}; +} + +float npy_half_to_float(npy_half h) { + return BitCast(ToFloatBits(h.value)); +} + +template <> +npy_half atomic_fadd(npy_half *loc, npy_half value, + std::memory_order order) { + npy_half old_value; + + const std::lock_guard lock(atomic_op_guard); + old_value = *loc; + *loc = npy_float_to_half(npy_half_to_float(old_value) + + npy_half_to_float(value)); + + return old_value; +} + +class AtomicOp { +public: + AtomicOp(const uint64_t *ptr, size_t numel, std::memory_order order) + : ptr(ptr), numel(numel), order(order) {} + + void apply() { + for (size_t i = 0; i < numel; ++i) { + applyAt(reinterpret_cast(ptr[i]), i); + } + } + + virtual ~AtomicOp() = default; + +protected: + virtual void applyAt(void *, size_t i) = 0; + + const uint64_t *ptr; + size_t numel; + std::memory_order order; +}; + +template class AtomicRMWOpBase : public AtomicOp { +public: + AtomicRMWOpBase(const uint64_t *ptr, const void *val, void *ret, + const bool *mask, size_t numel, std::memory_order order) + : AtomicOp(ptr, numel, order), val(val), ret(ret), mask(mask) {} + +protected: + void applyAt(void *loc, size_t i) override final { + if (mask[i]) { + DType *ptr = static_cast(loc); + *(static_cast(ret) + i) = + applyAtMasked(ptr, *(static_cast(val) + i), order); + } + } + + virtual DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) = 0; + + const void *val; + void *ret; + const bool *mask; +}; + +template +class AtomicRMWOp : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_add_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc + value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + return atomic_fadd(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_and_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc & value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_or_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc | value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_xor_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc ^ value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = atomic_loc->exchange(value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = value; + } + return old_val; + } +}; + +template +void atomic_compare_exchange_strong(void *loc, void *expected, + const void *desired, size_t i, + std::memory_order order) { + T desired_val = *(static_cast(desired) + i); + T *expected_uint = static_cast(expected) + i; + + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = reinterpret_cast *>(loc); + atomic_loc->compare_exchange_strong(*expected_uint, desired_val, order, + order); + } else { + const std::lock_guard lock(atomic_op_guard); + T *atomic_loc = static_cast(loc); + if (*atomic_loc == *expected_uint) { + *atomic_loc = desired_val; + } else { + *expected_uint = *atomic_loc; + } + } +} + +class AtomicCASOp : public AtomicOp { +public: + AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired, + size_t itemsize, size_t numel, std::memory_order order) + : AtomicOp(ptr, numel, order), expected(expected), desired(desired), + itemsize(itemsize) {} + +protected: + void applyAt(void *loc, size_t i) override { + // Atomic operations perform bitwise comparison, so it's safe to + // use number of bytes (itemsize) to determine the type of pointers + if (itemsize == 1) { + atomic_compare_exchange_strong(loc, expected, desired, i, order); + } else if (itemsize == 2) { + atomic_compare_exchange_strong(loc, expected, desired, i, + order); + } else if (itemsize == 4) { + atomic_compare_exchange_strong(loc, expected, desired, i, + order); + } else if (itemsize == 8) { + atomic_compare_exchange_strong(loc, expected, desired, i, + order); + } else { + throw std::invalid_argument("Invalid byte size"); + } + } + +private: + void *expected; + const void *desired; + size_t itemsize; +}; + +// This is a workaround because explicit template parameter list for lambdas is +// a C++20 extension: +// auto try_make_op = [&]() { +// if (dtype.is(pybind11::dtype::of())) { +// atomic_op = std::make_unique>(ptr, val, ret, mask, +// numel, order); +// } +// }; +template struct OpCreator { + pybind11::dtype dtype; + const uint64_t *ptr; + const void *val; + void *ret; + const bool *mask; + size_t numel; + std::memory_order order; + std::unique_ptr &atomic_op; + + template void create() { + if (!atomic_op && dtype.is(pybind11::dtype::of())) { + atomic_op = std::make_unique>(ptr, val, ret, mask, + numel, order); + } + } +}; + +template <> template <> void OpCreator::create() { + if (!atomic_op && dtype.char_() == 'e') { // float16 + // workaround until https://github.com/pybind/pybind11/issues/4061 is + // implemented + atomic_op = std::make_unique>( + ptr, val, ret, mask, numel, order); + } +}; + +template +std::unique_ptr +makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val, + void *ret, const bool *mask, size_t numel, + std::memory_order order) { + // Iterate over all supported data types, make one that matches, and return + std::unique_ptr atomic_op; + OpCreator try_make_op{dtype, ptr, val, ret, + mask, numel, order, atomic_op}; + + (try_make_op.template create(), ...); + if (!atomic_op) { + throw std::invalid_argument("Unsupported data type"); + } + // Make it a unique_ptr + return atomic_op; +} + +} // namespace + +void init_triton_interpreter(py::module &&m) { + using ret = py::return_value_policy; + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "RMW_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX) + .export_values(); + + m.def("load", + [](py::array_t ptr, py::array_t mask, py::array other, + py::dtype ret_dtype) -> py::array { + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_others = other.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) + memcpy(ret.mutable_data(i), + reinterpret_cast(reshaped_ptr.at(i)), + ret_dtype.itemsize()); + else + memcpy(ret.mutable_data(i), reshaped_others.data(i), + ret_dtype.itemsize()); + } + return ret.reshape(shape); + }); + + m.def("store", + [](py::array_t ptr, py::array value, py::array_t mask) { + int numel = ptr.size(); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_value = value.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) { + memcpy(reinterpret_cast(reshaped_ptr.mutable_at(i)), + reshaped_value.data(i), value.dtype().itemsize()); + } + } + }); + + m.def("atomic_rmw", + [](RMWOp rmw_op, py::array_t ptr, py::array val, + py::array_t mask, MemSemantic sem) -> py::array { + std::memory_order order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = val.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto *ptr_data = reshaped_ptr.data(); + auto *mask_data = reshaped_mask.data(); + auto *val_data = static_cast(reshaped_val.data()); + auto *ret_data = static_cast(ret.mutable_data()); + + std::unique_ptr atomic_op; + +#define MAKE_ATOMIC_RMW_OP(OP_NAME, ...) \ + case OP_NAME: \ + atomic_op = makeAtomicRMWOp( \ + ret_dtype, ptr_data, val_data, ret_data, mask_data, numel, order); \ + break; + + switch (rmw_op) { + MAKE_ATOMIC_RMW_OP(RMWOp::ADD, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::FADD, npy_half, float, double) + MAKE_ATOMIC_RMW_OP(RMWOp::AND, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::OR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XOR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MAX, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMAX, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MIN, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMIN, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XCHG, int32_t, uint32_t, int64_t, + uint64_t) + default: + throw std::invalid_argument("Unsupported RMW operation"); + } + +#undef MAKE_ATOMIC_RMW_OP + + atomic_op->apply(); + return ret.reshape(shape); + }); + + m.def("atomic_cas", + [](py::array_t ptr, py::array &cmp, py::array &val, + MemSemantic sem) -> py::array { + std::memory_order order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = cmp.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array reshaped_cmp = cmp.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto itemsize = cmp.itemsize(); + memcpy(static_cast(ret.mutable_data()), + static_cast(reshaped_cmp.data()), + itemsize * numel); + AtomicCASOp(reshaped_ptr.data(), ret.mutable_data(), + static_cast(reshaped_val.data()), itemsize, + numel, order) + .apply(); + return ret.reshape(shape); + }); +} diff --git a/third_party/enflame/include/triton/python/src/ir.cc b/third_party/enflame/include/triton/python/src/ir.cc new file mode 100644 index 000000000..680b6ee12 --- /dev/null +++ b/third_party/enflame/include/triton/python/src/ir.cc @@ -0,0 +1,1916 @@ +#include +#include +#include +#include + +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Transforms/LocationSnapshot.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/SourceMgr.h" + +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" + +namespace { + +namespace py = pybind11; +using namespace mlir; +using namespace triton; + +llvm::raw_fd_ostream &mlir_dumps() { + std::error_code EC; + static llvm::raw_fd_ostream S(::triton::tools::getStrEnv("MLIR_DUMP_PATH"), + EC, llvm::sys::fs::CD_CreateAlways); + assert(!EC); + return S; +} + +llvm::raw_ostream &mlir_dumps_or_dbgs() { + if (!::triton::tools::getStrEnv("MLIR_DUMP_PATH").empty()) { + return mlir_dumps(); + } else { + return llvm::dbgs(); + } +} + +// A custom op builder that keeps track of the last location +class TritonOpBuilder { +public: + TritonOpBuilder(MLIRContext *context) { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + } + + OpBuilder &getBuilder() { return *builder; } + MLIRContext *getContext() { return builder->getContext(); } + + bool isLineInfoEnabled() { return lineInfoEnabled; } + + void setLastLoc(Location loc) { + if (lineInfoEnabled) + lastLoc = std::make_unique(loc); + } + + void setLastLoc(const std::string &fileName, int line, int column) { + auto context = builder->getContext(); + setLastLoc(FileLineColLoc::get(context, fileName, line, column)); + } + + Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + + void setInsertionPointToStart(Block &block) { + if (!block.empty()) + setLastLoc(block.begin()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToStart(&block); + } + + void setInsertionPointToEnd(Block &block) { + if (!block.empty()) + setLastLoc(block.back().getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToEnd(&block); + } + + void setInsertionPointAfter(Operation &op) { + setLastLoc(op.getLoc()); + builder->setInsertionPointAfter(&op); + } + + void restoreInsertionPoint(OpBuilder::InsertPoint pt) { + if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->restoreInsertionPoint(pt); + } + + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + return builder->create(loc, std::forward(args)...); + } + + // Overload to create or fold a single result operation. + template + std::enable_if_t(), Value> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + // Overload to create or fold a zero result operation. + template + std::enable_if_t(), OpTy> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + +private: + std::unique_ptr builder; + std::unique_ptr lastLoc; + bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); +}; + +// Run the pass manager under a source manager diagnostic handler, which +// enables emitted MLIR diagnostics to directly reference Python source +// code. This diagnostic handler supports filtering diagnostic info by +// severity levels. +struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler { + TritonSourceMgrDiagnosticHandler(MLIRContext *ctx, + DiagnosticSeverity minSeverity) + : SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) { + setHandler([this, minSeverity](Diagnostic &diag) { + auto severity = diag.getSeverity(); + switch (severity) { + case DiagnosticSeverity::Error: + break; + case DiagnosticSeverity::Warning: + if (minSeverity == DiagnosticSeverity::Error) + return success(); + break; + case DiagnosticSeverity::Remark: + if (minSeverity == DiagnosticSeverity::Error || + minSeverity == DiagnosticSeverity::Warning) + return success(); + break; + case DiagnosticSeverity::Note: + // notes are handled somewhere else. + return failure(); + default: + llvm_unreachable("Unknown diagnostic severity"); + } + emitDiagnostic(diag); + return success(); + }); + } + + llvm::SourceMgr sourceMgr; +}; + +std::string locationToString(Location loc) { + std::string str; + llvm::raw_string_ostream os(str); + loc.print(os); + os.flush(); // Make sure all the content is dumped into the 'str' string + return str; +} + +// Function to parse a comma-separated string into a vector of C-style strings +llvm::SmallVector +parseCommaSeparatedValues(const std::string &input, + llvm::SmallVector &storage) { + llvm::SmallVector split; + llvm::SmallVector result; + StringRef(input.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) { + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. + storage.push_back(str.str()); + return storage.back().c_str(); + }); + return result; +} + +void outputWarning(Location loc, const std::string &msg) { + std::string locStr = locationToString(loc); + + PyErr_WarnEx(PyExc_UserWarning, (locStr + ": " + msg).c_str(), + /*stack_level=*/2); +} + +// Allow dump a reproducer in the console on crash. +struct ConsoleReproducerStream : public mlir::ReproducerStream { + ~ConsoleReproducerStream() override {} + + StringRef description() override { + return "std::errs, please share the reproducer above with Triton project."; + } + raw_ostream &os() override { return llvm::errs(); } +}; + +static ReproducerStreamFactory makeConsoleReproducer() { + return [](std::string &error) -> std::unique_ptr { + return std::make_unique(); + }; +} + +} // anonymous namespace + +/*****************************************************************************/ +/* Python bindings for ir */ +/*****************************************************************************/ + +void init_triton_ir(py::module &&m) { + using ret = py::return_value_policy; + using namespace pybind11::literals; + + py::enum_(m, "PADDING_OPTION", py::module_local()) + .value("PAD_ZERO", PaddingOption::PAD_ZERO) + .value("PAD_NAN", PaddingOption::PAD_NAN) + .export_values(); + + py::enum_(m, "CACHE_MODIFIER", py::module_local()) + .value("NONE", CacheModifier::NONE) + .value("CA", CacheModifier::CA) + .value("CG", CacheModifier::CG) + .value("WB", CacheModifier::WB) + .value("CS", CacheModifier::CS) + .value("WT", CacheModifier::WT) + .value("CV", CacheModifier::CV) + .export_values(); + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "MEM_SYNC_SCOPE", py::module_local()) + .value("GPU", MemSyncScope::GPU) + .value("CTA", MemSyncScope::CTA) + .value("SYSTEM", MemSyncScope::SYSTEM) + .export_values(); + + py::enum_(m, "EVICTION_POLICY", py::module_local()) + .value("NORMAL", EvictionPolicy::NORMAL) + .value("EVICT_FIRST", EvictionPolicy::EVICT_FIRST) + .value("EVICT_LAST", EvictionPolicy::EVICT_LAST) + .export_values(); + + py::enum_(m, "ATOMIC_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX); + + py::enum_(m, "ROUNDING_MODE", py::module_local()) + .value("RTZ", RoundingMode::RTZ) + .value("RTNE", RoundingMode::RTNE); + + py::enum_(m, "PROPAGATE_NAN", py::module_local()) + .value("NONE", PropagateNan::NONE) + .value("ALL", PropagateNan::ALL); + + py::enum_(m, "INPUT_PRECISION", py::module_local()) + .value("TF32", InputPrecision::TF32) + .value("TF32x3", InputPrecision::TF32x3) + .value("IEEE", InputPrecision::IEEE) + .export_values(); + + py::enum_(m, "ScaleDotElemTypeTY", py::module_local()) + .value("E4M3", ScaleDotElemType::E4M3) + .value("E5M2", ScaleDotElemType::E5M2) + .value("E2M3", ScaleDotElemType::E2M3) + .value("E3M2", ScaleDotElemType::E3M2) + .value("E2M1", ScaleDotElemType::E2M1) + .value("BF16", ScaleDotElemType::BF16) + .value("FP16", ScaleDotElemType::FP16) + .export_values(); + + py::class_(m, "context", py::module_local()) + .def(py::init<>()) + .def("printOpOnDiagnostic", + [](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); }) + .def("printStackTraceOnDiagnostic", + [](MLIRContext &self, bool v) { + self.printStackTraceOnDiagnostic(v); + }) + .def("disable_multithreading", + [](MLIRContext &self) { self.disableMultithreading(); }); + + py::class_(m, "source_mgr_diag", + py::module_local()) + .def(py::init()); + + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + mlir::LLVM::registerInlinerInterface(registry); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + mlir::LLVM::registerInlinerInterface(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_(m, "type", py::module_local()) + .def("is_integer", + [](Type &self, unsigned width) { return self.isInteger(width); }) + .def("is_fp16", &Type::isF16) + .def("__eq__", + [](Type &self, py::object &other) { + Type *other_ty = py::cast(other); + return (other_ty != nullptr) && (*other_ty == self); + }) + .def("__ne__", + [](Type &self, py::object &other) { + Type *other_ty = py::cast(other); + return (other_ty == nullptr) || (*other_ty != self); + }) + .def("__str__", [](Type &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "function_type", py::module_local()) + .def("param_types", [](FunctionType &self) { + return std::vector(self.getInputs().begin(), + self.getInputs().end()); + }); + + py::class_(m, "location", py::module_local()) + .def("__str__", [](Location &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "value", py::module_local()) + .def("set_attr", + [](Value &self, std::string &name, Attribute &attr) -> void { + if (Operation *definingOp = self.getDefiningOp()) + definingOp->setAttr(name, attr); + else { + auto arg = mlir::cast(self); + int id = arg.getArgNumber(); + std::string attrName = name + "_arg" + std::to_string(id); + Block *owner = arg.getOwner(); + if (owner->isEntryBlock() && + !isa(owner->getParentOp())) { + owner->getParentOp()->setAttr(attrName, attr); + } + } + }) + .def("get_context", &Value::getContext) + .def("replace_all_uses_with", + [](Value &self, Value &newValue) { + self.replaceAllUsesWith(newValue); + }) + .def("get_type", &Value::getType) + .def("id", [](Value &self) { + // The Value is identified by and compared with + // other Values via the underlying ValueImpl + return (uint64_t)self.getImpl(); + }); + + py::class_(m, "op_result", py::module_local()); + + py::class_(m, "block_argument", py::module_local()); + + py::class_(m, "region", py::module_local()) + .def("get_parent_region", &Region::getParentRegion, ret::reference) + .def("size", [](Region &self) { return self.getBlocks().size(); }) + .def("empty", &Region::empty) + .def("id", [](Region &self) { return (uint64_t)&self; }); + + py::class_(m, "block", py::module_local()) + .def("arg", + [](Block &self, int index) -> BlockArgument { + if (index >= self.getNumArguments()) + throw pybind11::index_error("Block argument index out of range"); + return self.getArgument(index); + }) + .def("add_argument", + [](Block &self, Type ty) { + auto loc = UnknownLoc::get(ty.getContext()); + self.addArgument(ty, loc); + }) + .def("get_num_arguments", &Block::getNumArguments) + .def("get_argument", &Block::getArgument) + .def("dump", &Block::dump) + .def("move_before", + [](Block &self, Block &dst) { self.moveBefore(&dst); }) + .def("insert_before", &Block::insertBefore) + .def("get_parent", &Block::getParent, ret::reference) + .def("merge_block_before", + [](Block &self, Block &dst) { + // ref: RewriterBase::mergeBlocks() + if (self.getNumArguments() != 0) + throw std::runtime_error( + "This block has arguments, don't merge"); + dst.getOperations().splice(dst.begin(), self.getOperations()); + self.dropAllUses(); + self.erase(); + }) + .def("replace_use_in_block_with", + [](Block &self, Value &v, Value &newVal) { + v.replaceUsesWithIf(newVal, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + Block *currentBlock = user->getBlock(); + while (currentBlock) { + if (currentBlock == &self) + return true; + // Move up one level + currentBlock = + currentBlock->getParent()->getParentOp()->getBlock(); + } + return false; + }); + }) + .def("__str__", + [](Block &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return str; + }) + .def("has_terminator", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("has_return", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("erase", [](Block &self) { self.erase(); }) + .def("id", [](Block &self) { return (uint64_t)&self; }); + + py::class_(m, "attribute", py::module_local()); + py::class_(m, "integer_attr", py::module_local()); + py::class_(m, "bool_attr", py::module_local()); + py::class_(m, "unit_attr", py::module_local()); + + // Ops + py::class_(m, "OpState", py::module_local()) + .def("set_attr", + [](OpState &self, std::string &name, Attribute &attr) -> void { + self->setAttr(name, attr); + }) + .def("get_num_results", + [](OpState &self) -> unsigned { return self->getNumResults(); }) + .def("get_result", + [](OpState &self, unsigned idx) -> Value { + if (idx >= self->getNumResults()) + throw pybind11::index_error("Op result index out of range"); + return self->getResult(idx); + }) + .def( + "get_region", + [](OpState &self, unsigned idx) -> Region & { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self->getRegion(idx); + }, + ret::reference) + .def( + "get_body", + [](scf::ForOp &self, unsigned idx) -> Block * { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self.getBody(idx); + }, + ret::reference) + .def("dump", [](OpState &self) { self->dump(); }) + .def("__str__", + [](OpState &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self->print(os, printingFlags); + return str; + }) + .def("append_operand", + [](OpState &self, Value &val) { + self->insertOperands(self->getNumOperands(), val); + }) + .def("verify", [](OpState &self) -> bool { + return succeeded(verify(self.getOperation())); + }); + // scf Ops + py::class_(m, "ForOp", py::module_local()) + .def("get_induction_var", &scf::ForOp::getInductionVar); + + py::class_(m, "IfOp", py::module_local()) + .def("get_then_block", &scf::IfOp::thenBlock, ret::reference) + .def("get_else_block", &scf::IfOp::elseBlock, ret::reference) + .def("get_then_yield", &scf::IfOp::thenYield) + .def("get_else_yield", &scf::IfOp::elseYield); + py::class_(m, "YieldOp", py::module_local()); + py::class_(m, "WhileOp", py::module_local()) + .def("get_before", &scf::WhileOp::getBefore, ret::reference) + .def("get_after", &scf::WhileOp::getAfter, ret::reference); + py::class_(m, "ConditionOp", py::module_local()); + + py::class_>( + m, "operation", py::module_local()) + .def("get_name", + [](Operation &self) { + llvm::StringRef opName = self.getName().getStringRef(); + return opName.str(); + }) + .def("get_num_operands", &Operation::getNumOperands) + .def("get_operand", &Operation::getOperand) + .def("get_num_results", &Operation::getNumResults) + .def("get_result", &Operation::getResult) + .def("get_num_regions", &Operation::getNumRegions) + .def("get_region", &Operation::getRegion, ret::reference) + .def("get_block", &Operation::getBlock, ret::reference) + .def("get_str_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }) + .def("get_bool_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::bool_(ret.getValue()); + }) + .def("get_flat_symbol_ref_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }); + + // dynamic_attr is used to transfer ownership of the MLIR context to the + // module + py::class_(m, "module", py::module_local(), + py::dynamic_attr()) + .def("dump", &ModuleOp::dump) + .def("str", + [](ModuleOp &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self.print(os, printingFlags); + return str; + }) + .def("push_back", + [](ModuleOp &self, FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + .def("get_entry_func_name", + [](ModuleOp &self) -> std::string { + for (auto &op : self.getOps()) { + if (auto func = dyn_cast(op)) { + if (LLVM::isKernel(func)) + return func.getName().str(); + } + } + return ""; + }) + .def("has_function", + [](ModuleOp &self, std::string &funcName) -> bool { + if (self.lookupSymbol(funcName)) + return true; + return false; + }) + .def("get_function", + [](ModuleOp &self, std::string &funcName) -> FuncOp { + return self.lookupSymbol(funcName); + }) + /* + * def ty_to_cpp(ty) is the consumer of this function. + * If the type is a ptr it expects ty[0] == '*', else the type itself. + */ + + .def("get_function_signature", + [](ModuleOp &self, FuncOp &func) -> std::vector { + std::vector strVec; + + auto type = func.getFunctionType(); + unsigned numArgs = type.getNumInputs(); + for (unsigned i = 0; i != numArgs; ++i) { + std::string tempType; + llvm::raw_string_ostream os(tempType); + + auto ty = type.getInput(i); + if (auto attributes = func.getCallableArgAttrs()) { + Attribute attr = attributes[i]; + // Check for tt.nv_tma_desc = 1 + if (auto dAttr = dyn_cast(attr)) { + if (dAttr.contains("tt.nv_tma_desc")) { + strVec.push_back("nvTmaDesc"); + continue; + } + } + } + if (auto ptrType = dyn_cast(ty)) { + auto pType = ptrType.getPointeeType(); + os << "*"; + pType.print(os); + } else { + ty.print(os); + } + strVec.push_back(tempType); + } + return strVec; + }) + .def("get_int_attr", + [](ModuleOp &self, std::string name) -> py::object { + auto ret = self->getAttrOfType(name); + if (!ret) + return py::none(); + return py::int_(ret.getInt()); + }) + .def("create_location_snapshot", + [](ModuleOp &self, const std::string &fileName) -> void { + generateLocationsFromIR(/*raw_ostream=*/llvm::nulls(), + /*fileName=*/fileName, + /*op=*/self, /*flags=*/{}); + }) + .def("walk", + [](ModuleOp &self, const std::function &fn) { + self.walk(fn); + }); + + m.def("make_attr", [](const std::vector &values, MLIRContext &context) { + return mlir::cast(DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(values.size())}, + IntegerType::get(&context, 32)), + values)); + }); + + m.def( + "parse_mlir_module", + [](const std::string &inputFilename, MLIRContext &context) { + // parse module + OwningOpRef module = + parseSourceFile(inputFilename, &context); + if (!module) + throw std::runtime_error("Parse MLIR file failed."); + return module->clone(); + }, + ret::take_ownership); + + py::class_(m, "function", py::module_local()) + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + .def("args", + [](FuncOp &self, unsigned idx) -> BlockArgument { + if (idx >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + return self.getArgument(idx); + }) + .def("get_num_args", &FuncOp::getNumArguments) + .def( + "add_entry_block", + [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, + ret::reference) + .def( + "set_arg_attr", + [](FuncOp &self, int arg_no, const std::string &name, int val) { + if (arg_no >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + // set arg attributes "name" to value "val" + auto attrTy = IntegerType::get(self.getContext(), 32); + self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val)); + }, + ret::reference) + // .def("has_attr", &::FuncOp::hasAttr) + .def("finalize", + [](FuncOp &self) -> void { + // Check if the result of tl.advance is used + self.walk([&](AdvanceOp op) { + if (op->getResult(0).use_empty()) + outputWarning(op->getLoc(), "The result of tl.advance is not " + "being used. Note that tl.advance " + "does not have any side effects. " + "To move the block pointer, you " + "need to assign the result of " + "tl.advance to a variable."); + }); + }) + .def_property_readonly("type", &FuncOp::getFunctionType) + .def("reset_type", &FuncOp::setType); + + py::class_(m, "InsertPoint", py::module_local()); + + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()) + .def(py::init()) + // getters + .def("create_module", + [](TritonOpBuilder &self) -> ModuleOp { + return self.create(); + }) + // insertion block/point + .def("set_insertion_point_to_start", + [](TritonOpBuilder &self, Block &block) -> void { + self.setInsertionPointToStart(block); + }) + .def("set_insertion_point_to_end", + [](TritonOpBuilder &self, Block &block) { + self.setInsertionPointToEnd(block); + }) + .def("set_insertion_point_after", + [](TritonOpBuilder &self, Operation &op) { + self.setInsertionPointAfter(op); + }) + .def( + "get_insertion_block", + [](TritonOpBuilder &self) -> Block * { + return self.getBuilder().getInsertionBlock(); + }, + ret::reference) + .def("get_insertion_point", + [](TritonOpBuilder &self) { + return self.getBuilder().saveInsertionPoint(); + }) + .def("restore_insertion_point", + [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { + self.restoreInsertionPoint(pt); + }) + // Attr + .def( + "get_unit_attr", + [](TritonOpBuilder &self) { return self.getBuilder().getUnitAttr(); }) + .def("get_bool_attr", + [](TritonOpBuilder &self, bool value) { + return self.getBuilder().getBoolAttr(value); + }) + .def("get_int32_attr", + [](TritonOpBuilder &self, int32_t value) { + return self.getBuilder().getI32IntegerAttr(value); + }) + // Use arith.ConstantOp to create constants + // Constants + .def("get_int1", + [](TritonOpBuilder &self, bool v) -> Value { + return Value(self.create( + v, self.getBuilder().getI1Type())); + }) + .def("get_int8", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_int16", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_int32", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_int64", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_uint8", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_uint16", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_uint32", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_uint64", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_bf16", + [](TritonOpBuilder &self, float v) -> Value { + auto type = self.getBuilder().getBF16Type(); + return self.create( + APFloat(type.getFloatSemantics(), std::to_string(v)), type); + }) + .def("get_fp16", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF16FloatAttr(v)); + }) + .def("get_fp32", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF32FloatAttr(v)); + }) + .def("get_fp64", + [](TritonOpBuilder &self, double v) -> Value { + return self.create( + self.getBuilder().getF64FloatAttr(v)); + }) + .def("get_null_value", + [](TritonOpBuilder &self, Type type) -> Value { + if (auto floatTy = dyn_cast(type)) + return self.create( + APFloat(floatTy.getFloatSemantics(), 0), floatTy); + else if (auto intTy = dyn_cast(type)) + return self.create(0, intTy); + else + throw std::runtime_error("Not implemented"); + }) + .def("get_all_ones_value", + [](TritonOpBuilder &self, Type type) -> Value { + uint64_t val = 0xFFFFFFFFFFFFFFFF; + if (auto intTy = dyn_cast(type)) + return self.create(val, intTy); + else + throw std::runtime_error("Not implemented"); + }) + + // Types + .def("get_void_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getNoneType(); + }) + .def("get_int1_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI1Type(); + }) // or ret::copy? + .def("get_int8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_int16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(16); + }) + .def("get_int32_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI32Type(); + }) + .def("get_int64_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI64Type(); + }) + .def("get_fp8e4nv_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b15_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_fp8e5_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e5b16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_half_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF16Type(); + }) + .def("get_bf16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getBF16Type(); + }) + .def("get_float_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF32Type(); + }) + .def("get_double_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF64Type(); + }) + .def("get_ptr_ty", + [](TritonOpBuilder &self, Type &type, int addrSpace) -> Type { + return PointerType::get(type, addrSpace); + }) + .def("get_block_ty", + [](TritonOpBuilder &self, Type &elementType, + std::vector &shape) -> Type { + return RankedTensorType::get(shape, elementType); + }) + .def("get_function_ty", + [](TritonOpBuilder &self, std::vector inTypes, + std::vector outTypes) -> Type { + return self.getBuilder().getFunctionType(inTypes, outTypes); + }) + // locs + .def("set_loc", + [](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); }) + .def("set_loc", + [](TritonOpBuilder &self, const std::string &fileName, int line, + int column) { self.setLastLoc(fileName, line, column); }) + .def("get_loc", + [](TritonOpBuilder &self) -> Location { return self.getLastLoc(); }) + + // Ops + .def("get_or_insert_function", + [](TritonOpBuilder &self, ModuleOp &module, std::string &funcName, + Type &funcType, std::string &visibility, + bool noinline) -> FuncOp { + if (Operation *funcOperation = module.lookupSymbol(funcName)) + return llvm::dyn_cast(funcOperation); + if (auto funcTy = dyn_cast(funcType)) { + llvm::SmallVector attrs = { + NamedAttribute( + self.getBuilder().getStringAttr("sym_visibility"), + self.getBuilder().getStringAttr(visibility)), + NamedAttribute(self.getBuilder().getStringAttr("noinline"), + self.getBuilder().getBoolAttr(noinline))}; + return self.create(funcName, funcTy, attrs); + } + throw std::invalid_argument("invalid function type"); + }) + .def( + "create_block", + [](TritonOpBuilder &self) -> Block * { + Region *parent = self.getBuilder().getBlock()->getParent(); + return self.getBuilder().createBlock(parent); + }, + ret::reference) + .def( + "create_block_with_parent", + [](TritonOpBuilder &self, Region &parent, + std::vector &argTypes) -> Block * { + // TODO: update arg loc + auto loc = self.getBuilder().getUnknownLoc(); + llvm::SmallVector argLocs(argTypes.size(), loc); + return self.getBuilder().createBlock(&parent, {}, argTypes, + argLocs); + }, + ret::reference) + .def( + "new_block", + [](TritonOpBuilder &self) -> Block * { return new Block(); }, + ret::reference) + // Function + .def("ret", + [](TritonOpBuilder &self, std::vector &vals) -> OpState { + return self.create(vals); + }) + .def("call", + [](TritonOpBuilder &self, FuncOp &func, std::vector &args) + -> OpState { return self.create(func, args); }) + // Unstructured control flow + .def("create_cond_branch", + [](TritonOpBuilder &self, Value condition, Block *trueDest, + Block *falseDest) -> OpState { + return self.create(condition, trueDest, + falseDest); + }) + .def("create_branch", + [](TritonOpBuilder &self, Block *dest, std::vector &args) + -> OpState { return self.create(dest, args); }) + // Structured control flow + .def("create_for_op", + [](TritonOpBuilder &self, Value &lb, Value &ub, Value &step, + std::vector &initArgs) -> scf::ForOp { + return self.create(lb, ub, step, initArgs); + }) + .def("create_if_op", + [](TritonOpBuilder &self, std::vector &retTypes, + Value &condition, bool withElse) -> scf::IfOp { + return self.create(retTypes, condition, withElse); + }) + .def("create_yield_op", + [](TritonOpBuilder &self, std::vector &yields) + -> scf::YieldOp { return self.create(yields); }) + .def("create_while_op", + [](TritonOpBuilder &self, std::vector &retTypes, + std::vector &initArgs) -> scf::WhileOp { + return self.create(retTypes, initArgs); + }) + .def("create_condition_op", + [](TritonOpBuilder &self, Value &cond, + std::vector &args) -> scf::ConditionOp { + return self.create(cond, args); + }) + + // miscellaneous + .def("create_make_range", + [](TritonOpBuilder &self, int start, int end) -> Value { + auto retType = RankedTensorType::get( + {end - start}, self.getBuilder().getI32Type()); + return self.create(retType, start, end); + }) + + // Cast instructions + // Conversions for custom FP types (FP8 and non-standard rounding modes) + .def("create_fp_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType, + std::optional roundingMode) -> Value { + if (roundingMode.has_value()) + return self.create( + dstType, src, + RoundingModeAttr::get(self.getBuilder().getContext(), + roundingMode.value())); + else + return self.create(dstType, src); + }) + // Conversions for standard LLVM builtin types + .def("create_bitcast", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_si_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_ui_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_si", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_ui", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_ext", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_trunc", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_int_cast", + [](TritonOpBuilder &self, Value &src, Type &dstType, + bool isSigned) -> Value { + // get element type if necessary + Type srcType = src.getType(); + auto srcTensorType = dyn_cast(srcType); + auto dstTensorType = dyn_cast(dstType); + Type srcEltType = srcType; + Type dstEltType = dstType; + if (dstTensorType && srcTensorType) { + dstEltType = dstTensorType.getElementType(); + srcEltType = srcTensorType.getElementType(); + } + unsigned srcWidth = srcEltType.getIntOrFloatBitWidth(); + unsigned dstWidth = dstEltType.getIntOrFloatBitWidth(); + if (srcWidth == dstWidth) + return self.create(dstType, src); + else if (srcWidth > dstWidth) + return self.create(dstType, src); + else if (isSigned) + return self.create(dstType, src); + else + return self.create(dstType, src); + }) + .def("create_fmul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_frem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fadd", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fsub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_mul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_umulhi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_udiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_srem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_urem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_add", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_fma", + [](TritonOpBuilder &self, Value &a, Value &b, Value &c) -> Value { + return Value(self.create(a, b, c)); + }) + .def("create_shl", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_lshr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_ashr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minimumf follows the torch.minimum convention and returns NaN if either + // operand is NaN + .def("create_minimumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minnumf follows the torch.fmin convention and returns the non-NaN + // operand + .def("create_minnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maximumf follows the torch.maximum convention and returns NaN if either + // operand is NaN + .def("create_maximumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maxnumf follows the torch.fmax convention and returns the non-NaN + // operand + .def("create_maxnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_clampf", + [](TritonOpBuilder &self, Value &input, Value &min, Value &max, + PropagateNan propagateNan) -> Value { + return Value(self.create(input, min, max, propagateNan)); + }) + .def("create_precise_sqrt", + [](TritonOpBuilder &self, Value &input) -> Value { + return Value(self.create(input)); + }) + .def("create_precise_divf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // AddPtr (similar to GEP) + .def("create_addptr", + [](TritonOpBuilder &self, Value &ptr, Value &offset) -> Value { + return self.create(ptr.getType(), ptr, offset); + }) + // Comparison (int) + .def("create_icmpSLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sle, lhs, + rhs); + }) + .def("create_icmpSLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::slt, lhs, + rhs); + }) + .def("create_icmpSGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sge, lhs, + rhs); + }) + .def("create_icmpSGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sgt, lhs, + rhs); + }) + .def("create_icmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ule, lhs, + rhs); + }) + .def("create_icmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ult, lhs, + rhs); + }) + .def("create_icmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::uge, lhs, + rhs); + }) + .def("create_icmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ugt, lhs, + rhs); + }) + .def("create_icmpEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::eq, lhs, + rhs); + }) + .def("create_icmpNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ne, lhs, + rhs); + }) + // Comparison (float) + .def("create_fcmpOLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLT, lhs, + rhs); + }) + .def("create_fcmpOGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGT, lhs, + rhs); + }) + .def("create_fcmpOLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLE, lhs, + rhs); + }) + .def("create_fcmpOGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGE, lhs, + rhs); + }) + .def("create_fcmpOEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OEQ, lhs, + rhs); + }) + .def("create_fcmpONE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ONE, lhs, + rhs); + }) + .def("create_fcmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULT, lhs, + rhs); + }) + .def("create_fcmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGT, lhs, + rhs); + }) + .def("create_fcmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULE, lhs, + rhs); + }) + .def("create_fcmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGE, lhs, + rhs); + }) + .def("create_fcmpUEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UEQ, lhs, + rhs); + }) + .def("create_fcmpUNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UNE, lhs, + rhs); + }) + // // Logical + .def("create_and", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_xor", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_or", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + // Input/Output + .def("create_load", + [](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_store", + [](TritonOpBuilder &self, Value &ptrs, Value &value, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, value, cacheModifier, evictionPolicy); + }) + .def("create_tensor_pointer_load", + [](TritonOpBuilder &self, Value &ptr, + std::vector &boundaryCheck, + std::optional paddingOption, + CacheModifier cacheModifier, EvictionPolicy evictionPolicy, + bool isVolatile) -> Value { + return self.create(ptr, boundaryCheck, paddingOption, + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_tensor_pointer_store", + [](TritonOpBuilder &self, Value &ptr, Value &val, + std::vector &boundaryCheck, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptr, val, boundaryCheck, cacheModifier, + evictionPolicy); + }) + .def("create_masked_load", + [](TritonOpBuilder &self, Value &ptrs, Value &mask, + std::optional &other, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, mask, other.value_or(Value()), + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_masked_store", + [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, val, mask, cacheModifier, + evictionPolicy); + }) + .def("create_tensor_descriptor_type", + [](TritonOpBuilder &self, Type blockTy) -> Type { + auto ctx = self.getContext(); + return triton::TensorDescType::get( + ctx, cast(blockTy)); + }) + .def("create_reinterpret_tensor_descriptor", + [](TritonOpBuilder &self, Value desc_ptr, Type blockTy) -> Value { + auto ctx = self.getContext(); + auto resultTy = triton::TensorDescType::get( + ctx, cast(blockTy)); + return self.create(resultTy, desc_ptr); + }) + .def("create_descriptor_load", + [](TritonOpBuilder &self, Value desc, std::vector &indices, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> Value { + auto descTy = cast(desc.getType()); + auto resTy = descTy.getBlockType(); + return self.create( + resTy, desc, indices, cacheModifier, evictionPolicy); + }) + .def("create_descriptor_gather", + [](TritonOpBuilder &self, Value desc, Value x_indices, Value y_index, + Type type) -> Value { + return self.create( + type, desc, x_indices, y_index); + }) + .def("create_descriptor_store", + [](TritonOpBuilder &self, Value desc, Value value, + std::vector &indices) -> void { + self.create(desc, value, indices); + }) + .def("create_descriptor_scatter", + [](TritonOpBuilder &self, Value desc, Value value, Value x_indices, + Value y_index) -> void { + self.create(desc, x_indices, + y_index, value); + }) + .def("create_tensormap_create", + [](TritonOpBuilder &self, Value desc_ptr, Value global_address, + std::vector box_dim, std::vector global_dim, + std::vector global_stride, + std::vector element_stride, int32_t elem_type, + int32_t interleave_layout, int32_t swizzle_mode, + int32_t fill_mode) { + self.create( + desc_ptr, global_address, box_dim, global_dim, global_stride, + element_stride, elem_type, interleave_layout, swizzle_mode, + fill_mode); + }) + .def("create_tensormap_fenceproxy_acquire", + [](TritonOpBuilder &self, Value desc_ptr) { + self.create(desc_ptr); + }) + .def("create_reshape", + [](TritonOpBuilder &self, Value &arg, std::vector &shape, + bool allowReorder) -> Value { + auto argType = + cast(arg.getType()).getElementType(); + return self.create( + RankedTensorType::get(shape, argType), arg, allowReorder); + }) + .def("create_expand_dims", + [](TritonOpBuilder &self, Value &arg, int axis) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + std::vector retShape = argType.getShape(); + retShape.insert(retShape.begin() + axis, 1); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, axis); + }) + .def("create_cat", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + auto lhsType = dyn_cast(lhs.getType()); + auto rhsType = dyn_cast(rhs.getType()); + if (!(lhsType.getShape().size() == 1 && + rhsType.getShape().size() == 1)) + throw std::invalid_argument( + "shape not supported by cat. Expecting rank-1 inputs"); + std::vector shape{lhsType.getShape()[0] + + rhsType.getShape()[0]}; + return self.create( + RankedTensorType::get(shape, lhsType.getElementType()), lhs, + rhs); + }) + .def("create_join", + [](TritonOpBuilder &self, Value &a, Value &b) -> Value { + return self.create(a, b); + }) + .def("create_split", + [](TritonOpBuilder &self, Value &a) -> std::vector { + auto op = self.create(a); + return std::vector(op->result_begin(), op->result_end()); + }) + // Implements tl.trans and tl.permute. + .def("create_trans", + [](TritonOpBuilder &self, Value &arg, + std::vector &order) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + auto retShape = applyPermutation(argType.getShape(), order); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, order); + }) + .def("create_broadcast", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + if (auto argType = dyn_cast(arg.getType())) + return self.createOrFold( + RankedTensorType::get(shape, argType.getElementType()), arg); + throw std::invalid_argument( + "arg is not of RankedTensorType, use create_splat"); + }) + .def("create_splat", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + auto argType = arg.getType(); + auto ret = self.createOrFold( + RankedTensorType::get(shape, argType), arg); + return ret; + }) + // // atomic + .def("create_atomic_cas", + [](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val, + MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, ptr, cmp, val, sem, + scope); + }) + .def("create_atomic_rmw", + [](TritonOpBuilder &self, RMWOp rmwOp, Value &ptr, Value &val, + Value &mask, MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, rmwOp, ptr, val, mask, + sem, scope); + }) + // External + .def("create_extern_elementwise", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, Type retType, bool isPure) -> Value { + return self.create(retType, argList, libName, + libPath, symbol, isPure); + }) + // Built-in instruction + .def("create_get_program_id", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create(axis); + }) + .def("create_get_num_programs", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create(axis); + }) + .def("create_dot", + [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, + mlir::Value &c, InputPrecision inputPrecision, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create(c.getType(), a, b, c, inputPrecision, + maxNumImpreciseAcc); + }) + .def("create_dot_scaled", + [](TritonOpBuilder &self, mlir::Value &lhs, + std::optional &lhs_scale, + ScaleDotElemType lhs_format, mlir::Value &rhs, + std::optional &rhs_scale, + ScaleDotElemType rhs_format, bool fast_math, + mlir::Value &c) -> mlir::Value { + return self.create(c.getType(), lhs, rhs, c, + lhs_scale.value_or(Value()), + rhs_scale.value_or(Value()), + lhs_format, rhs_format, fast_math); + }) + .def("create_floor", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_ceil", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_cos", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sin", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_erf", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_rsqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_fabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_iabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_reduce", + [](TritonOpBuilder &self, std::vector operands, int axis) + -> OpState { return self.create(operands, axis); }) + .def("create_reduce_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_scan", + [](TritonOpBuilder &self, std::vector operands, int axis, + bool reverse) -> OpState { + return self.create(operands, axis, reverse); + }) + .def("create_scan_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_ptr_to_int", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_int_to_ptr", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_select", + [](TritonOpBuilder &self, Value &condition, Value &trueValue, + Value &falseValue) -> Value { + return self.create(condition, trueValue, + falseValue); + }) + .def("create_inline_asm", + [](TritonOpBuilder &self, const std::string &inlineAsm, + const std::string &constraints, const std::vector &values, + const std::vector &types, bool isPure, + int pack) -> OpState { + return self.create( + types, inlineAsm, constraints, isPure, pack, values); + }) + .def("create_print", + [](TritonOpBuilder &self, const std::string &prefix, bool hex, + const std::vector &values, + const std::vector &isSigned) -> void { + auto prefixAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(prefix)); + self.create(prefixAttr, hex, values, isSigned); + }) + .def("create_assert", + [](TritonOpBuilder &self, Value &condition, + const std::string &message) -> void { + auto messageAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(message)); + self.create(condition, messageAttr); + }) + .def("create_assume", + [](TritonOpBuilder &self, Value &condition) { + self.create(condition); + }) + .def("create_poison", + [](TritonOpBuilder &self, Type &type) -> Value { + return self.create(type); + }) + .def("create_histogram", + [](TritonOpBuilder &self, Value operand, int numBins) -> Value { + return self.create( + RankedTensorType::get( + {static_cast(numBins)}, + IntegerType::get(operand.getContext(), 32)), + operand); + }) + .def("create_gather", + [](TritonOpBuilder &self, Value src, Value indices, int axis) + -> Value { return self.create(src, indices, axis); }) + // Force GPU barrier + .def("create_barrier", + [](TritonOpBuilder &self) { self.create(); }) + // Make a block pointer (tensor pointer in Triton IR) + .def("create_make_block_ptr", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &offsets, + std::vector &tensorShape, + std::vector &order) -> Value { + return self.create(base, shape, strides, offsets, + tensorShape, order); + }) + // Advance a block pointer + .def("create_advance", + [](TritonOpBuilder &self, Value &ptr, + std::vector &offsets) -> Value { + return self.create(ptr.getType(), ptr, offsets); + }) + // Make a tensor descriptor + .def("create_make_tensor_descriptor", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, + std::vector &tensorShape) -> Value { + return self.create(base, shape, strides, + tensorShape); + }) + // Proton Ops + .def("create_proton_record", + [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void { + self.create(isStart, regionId); + }); + + py::class_(m, "pass_manager", py::module_local()) + .def(py::init()) + .def("enable_debug", + [](PassManager &self) -> bool { + auto *context = self.getContext(); + bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); + std::string funcToDump; + if (!haveDump) { + funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP"); + bool isEnvValueBool = + triton::tools::isEnvValueBool(funcToDump).has_value(); + if (!funcToDump.empty() && !isEnvValueBool) + haveDump = true; + } + if (haveDump) { + context->disableMultithreading(); + auto printingFlags = OpPrintingFlags(); + printingFlags.elideLargeElementsAttrs(16); + printingFlags.enableDebugInfo(); + auto printAlways = [funcToDump](Pass *, Operation *op) -> bool { + if (funcToDump.empty()) + return true; + if (auto mod = dyn_cast(op)) { + return mod.lookupSymbol(funcToDump); + } + if (auto func = dyn_cast(op)) { + return SymbolTable::getSymbolName(func).getValue() == + funcToDump; + } + + return false; + }; + self.enableIRPrinting( + /*shouldPrintBeforePass=*/printAlways, + /*shouldPrintAfterPass=*/printAlways, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/false, + /*printAfterOnlyOnFailure*/ true, mlir_dumps_or_dbgs(), + printingFlags); + } + return haveDump; + }) + .def("get_pipeline_str", + [](PassManager &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.printAsTextualPipeline(os); + return str; + }) + .def("run", [](PassManager &self, ModuleOp &mod) { + // TODO: maybe dump module to file and print error for better + // diagnostics + + auto *context = mod.getContext(); + + auto reproducerPath = + triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); + if (!reproducerPath.empty()) { + auto anchorName = self.getOpAnchorName(); + auto passes = self.getPasses(); + Operation *op = mod.getOperation(); + // Save a reproducer for the current pass manager invocation + // immediately. + makeReproducer(anchorName, passes, op, reproducerPath); + // But if the pass manager crashes, attempt to generate a local + // reproducer instead. + context->disableMultithreading(); + self.enableCrashReproducerGeneration(reproducerPath, + /*genLocalReproducer=*/true); + } else { + self.enableCrashReproducerGeneration(makeConsoleReproducer()); + } + + if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { + ::llvm::DebugFlag = true; + } + + if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); + !debugOnly.empty()) { + llvm::SmallVector storage; + llvm::SmallVector debugTypes = + parseCommaSeparatedValues(debugOnly, storage); + ::llvm::DebugFlag = true; + using namespace llvm; + setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + } + + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + + // setting up diagnostics + bool showOperations = false, showStacktraces = false, + showRemarks = false, showWarnings = false; + + if (auto enableDiagnostics = + triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS"); + !enableDiagnostics.empty()) { + llvm::SmallVector storage; + parseCommaSeparatedValues(enableDiagnostics, storage); + for (auto &str : storage) { + if (str == "warnings") { + showWarnings = true; + } else if (str == "remarks") { + showRemarks = true; + } else if (str == "stacktraces") { + showStacktraces = true; + } else if (str == "operations") { + showOperations = true; + } + // we show errors by default, so no need to set it + } + } + + DiagnosticSeverity minSeverity = showWarnings + ? DiagnosticSeverity::Warning + : DiagnosticSeverity::Error; + minSeverity = showRemarks ? DiagnosticSeverity::Remark : minSeverity; + + TritonSourceMgrDiagnosticHandler diagHandler(context, minSeverity); + + context->printOpOnDiagnostic(showOperations); + context->printStackTraceOnDiagnostic(showStacktraces); + if (showStacktraces) { + context->disableMultithreading(); + } + if (failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); + }); +} + +void init_triton_env_vars(py::module &m) { + m.def("get_cache_invalidating_env_vars", + []() -> std::map { + std::map ret; + for (const auto &envVar : CACHE_INVALIDATING_ENV_VARS) { + auto strVal = triton::tools::getStrEnv(envVar); + if (strVal.empty()) + continue; + auto boolV = triton::tools::isEnvValueBool(strVal); + if (boolV.has_value()) + ret[envVar] = boolV.value() ? "true" : "false"; + else + ret[envVar] = strVal; + } + return ret; + }); +} diff --git a/third_party/enflame/include/triton/python/src/llvm.cc b/third_party/enflame/include/triton/python/src/llvm.cc new file mode 100644 index 000000000..c86bf671a --- /dev/null +++ b/third_party/enflame/include/triton/python/src/llvm.cc @@ -0,0 +1,497 @@ +#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Pass.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" +#include "llvm/Passes/StandardInstrumentations.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/Signals.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/Instrumentation/AddressSanitizer.h" +#include "llvm/Transforms/Instrumentation/AddressSanitizerOptions.h" +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace llvm { +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; +} // namespace llvm + +using namespace llvm; + +std::unique_ptr +createTargetMachine(llvm::Module *module, std::string proc, + bool enable_fp_fusion, const std::string &features) { + std::string error; + auto target = + llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + llvm::TargetOptions opt; + bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (enable_fp_fusion) + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; + opt.UnsafeFPMath = false; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + opt.TrapUnreachable = true; + opt.MCOptions.AsmVerbose = true; + opt.MCOptions.PreserveAsmComments = true; + std::unique_ptr machine{target->createTargetMachine( + module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + std::nullopt, + disableLLVMOpt ? llvm::CodeGenOptLevel::None + : llvm::CodeGenOptLevel::Aggressive)}; + return machine; +} + +std::string translateLLVMIRToASM(llvm::Module &module, + const std::string &triple, + const std::string &proc, + const std::string &features, + const std::vector &flags, + bool enable_fp_fusion, bool isObject) { + using namespace mlir; + // options + auto options = llvm::cl::getRegisteredOptions(); + for (std::string flag : flags) { + auto *shortPtr = static_cast *>(options[flag]); + assert(shortPtr); + shortPtr->setValue(true); + } + if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optIt = options.find("print-after-all"); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + bool disableLLVMOpt = triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (!disableLLVMOpt) { + // Check to see if we are passing a list of flags to disable optimizations. + auto flagList = triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + } + + // inline everything + for (llvm::Function &f : module.functions()) + if (!f.hasFnAttribute(llvm::Attribute::NoInline)) + f.addFnAttr(llvm::Attribute::AlwaysInline); + // verify and store llvm + llvm::legacy::PassManager pm; + pm.add(llvm::createAlwaysInlinerLegacyPass()); + pm.add(llvm::createVerifierPass()); + + const bool enabledTiming = triton::tools::getBoolEnv("LLVM_ENABLE_TIMING"); + if (enabledTiming) { + llvm::TimePassesIsEnabled = true; + llvm::TimePassesPerRun = true; + } + + pm.run(module); + + SmallString<0> timePassesStr; + raw_svector_ostream reportStream(timePassesStr); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + // module->print(llvm::outs(), nullptr); + + // create machine + module.setTargetTriple(triple); + auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features); + // set data layout + module.setDataLayout(machine->createDataLayout()); + // emit machine code + std::string result; + { + llvm::raw_string_ostream stream(result); + llvm::buffer_ostream pstream(stream); + llvm::legacy::PassManager pass; + // emit + auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile + : llvm::CodeGenFileType::AssemblyFile; + machine->addPassesToEmitFile(pass, pstream, nullptr, fileType); + pass.run(module); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + } + return result; +} + +using ret = py::return_value_policy; + +void init_triton_llvm(py::module &&m) { + + py::class_(m, "context", py::module_local()) + .def(py::init<>()); + py::class_(m, "source_mgr", py::module_local()) + .def(py::init<>()); + + py::class_(m, "function_list") + .def( + "__iter__", + [](llvm::Module::FunctionListType &s) { + return py::make_iterator(s.begin(), s.end()); + }, + py::keep_alive<0, 1>()); + + // Module Flag behavior. See + // https://llvm.org/doxygen/classllvm_1_1Module.html#a0a5c55e12c97b80021330fe82b642293 + // for details. + py::class_(m, "module_flag_behavior", + py::module_local()); + m.attr("MODULE_FLAG_BEHAVIOR_ERROR") = llvm::Module::Error; + m.attr("MODULE_FLAG_BEHAVIOR_WARNING") = llvm::Module::Warning; + m.attr("MODULE_FLAG_BEHAVIOR_REQUIRE") = llvm::Module::Require; + m.attr("MODULE_FLAG_BEHAVIOR_OVERRIDE") = llvm::Module::Override; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND") = llvm::Module::Append; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND_UNIQUE") = llvm::Module::AppendUnique; + m.attr("MODULE_FLAG_BEHAVIOR_MAX") = llvm::Module::Max; + m.attr("MODULE_FLAG_BEHAVIOR_MIN") = llvm::Module::Min; + + py::class_(m, "module", py::module_local()) + .def( + "__str__", + [](llvm::Module *self) { + std::string str; + llvm::raw_string_ostream os(str); + os << *self; + return os.str(); + }, + ret::take_ownership) + .def( + "get_functions", + [](llvm::Module *mod) -> llvm::Module::FunctionListType & { + // Note: Backends assume that we are compiling exactly one kernel + // (i.e. one function that's that's called by the CPU) and that it's + // the first function in this list. + return mod->getFunctionList(); + }, + ret::reference_internal) + .def("add_flag", + [](llvm::Module *mod, llvm::Module::ModFlagBehavior behavior, + std::string &key, uint32_t value) { + return mod->addModuleFlag(behavior, key, value); + }); + + py::class_(m, "function", py::module_local()) + .def_property_readonly( + "name", [](llvm::Function *fn) { return fn->getName().str(); }) + .def("set_calling_conv", &llvm::Function::setCallingConv) + .def("add_fn_attr", [](llvm::Function *fn, std::string &name, + std::string &val) { fn->addFnAttr(name, val); }) + .def("add_fn_asan_attr", + [](llvm::Function *fn) { + fn->addFnAttr(llvm::Attribute::SanitizeAddress); + }) + .def("add_fn_target_feature", + [](llvm::Function *fn, std::string &val) { + fn->addFnAttr("target-features", val); + }) + // Sets the nvvm.maxreg property on the given function. + .def("set_nvvm_maxnreg", + [](llvm::Function *fn, int maxnreg) { + auto op = MDNode::get( + fn->getContext(), + { + ValueAsMetadata::get(fn), + MDString::get(fn->getContext(), "maxnreg"), + ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(fn->getContext()), maxnreg)), + }); + fn->getParent() + ->getOrInsertNamedMetadata("nvvm.annotations") + ->addOperand(op); + }) + // External functions that are definitions (i.e. not declarations) are + // kernel functions. + .def("is_declaration", &llvm::Function::isDeclaration) + .def("is_external_linkage", [](llvm::Function *fn) { + return fn->getLinkage() == llvm::GlobalValue::ExternalLinkage; + }); + + // optimization levels + py::class_(m, "optimization_level", + py::module_local()); + m.attr("OPTIMIZE_O0") = llvm::OptimizationLevel::O0; + m.attr("OPTIMIZE_O1") = llvm::OptimizationLevel::O1; + m.attr("OPTIMIZE_O2") = llvm::OptimizationLevel::O2; + m.attr("OPTIMIZE_O3") = llvm::OptimizationLevel::O3; + m.attr("OPTIMIZE_Os") = llvm::OptimizationLevel::Os; + m.attr("OPTIMIZE_Oz") = llvm::OptimizationLevel::Oz; + + m.def( + "to_module", + [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) { + return mlir::translateModuleToLLVMIR(mod, ctx); + }, + py::keep_alive<0, 2>()); + + m.def("attach_datalayout", [](llvm::Module *mod, const std::string triple, + const std::string proc, + const std::string features) { + std::string error; + auto target = llvm::TargetRegistry::lookupTarget(triple, error); + if (!target) { + throw std::runtime_error("target lookup error: " + error); + } + llvm::TargetOptions opt; + // Target machine is only used to create the data layout. + std::unique_ptr machine{target->createTargetMachine( + triple, proc, features, opt, llvm::Reloc::PIC_, std::nullopt, + llvm::CodeGenOptLevel::None)}; + // set data layout + mod->setDataLayout(machine->createDataLayout()); + }); + + m.def( + "optimize_module", + [](llvm::Module *mod, const llvm::OptimizationLevel &opt, + std::string arch, std::string features, std::vector flags, + bool enable_fp_fusion) { + if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT")) + return; + // Check to see if we are passing a list of flags to disable + // optimizations. + auto flagList = mlir::triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + auto options = llvm::cl::getRegisteredOptions(); + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + using namespace llvm; + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + + PassInstrumentationCallbacks *instrCbPtr = nullptr; + PassInstrumentationCallbacks passInstrCb; + StandardInstrumentations standardInstr(mod->getContext(), + /*DebugLogging*/ true); + if (mlir::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optMap = llvm::cl::getRegisteredOptions(); + auto optIt = optMap.find("print-after-all"); + if (optIt != optMap.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + standardInstr.registerCallbacks(passInstrCb, &mam); + instrCbPtr = &passInstrCb; + } + + PipelineTuningOptions tuningOptions; + tuningOptions.LoopUnrolling = true; + tuningOptions.LoopInterleaving = true; + tuningOptions.LoopVectorization = true; + // TODO: currently we run SLP vectorizer with an empty target machine. + // This cause the vectorizer to create larger vector which could be bad. + // Disabling it would currently cause regressions as this pass also + // applies some scheduling that helps performance in some cases. We + // should work on using NVPTX target instead and address the performance + // regressions with some scheduling solution. + tuningOptions.SLPVectorization = true; + + std::string pluginFile = + mlir::triton::tools::getStrEnv("LLVM_PASS_PLUGIN_PATH"); + + // We don't pass the targetMachine to the LLVM-IR pass builder, unless + // `arch` is specified. + // + // Don't set target machine in LLVM pass builder when using LLVM IR + // level plugins. LLVM IR level plugin passes typically want to insert + // calls to externally generated code (i.e. precompile a Cuda/Hip kernel + // with Clang and then insert a call to it within an instrumentation + // pass) setting the targetMachine value here can can cause a mismatch + // in the target machine between the MLIR and Clang generated kernels + // and break the lowering of some target specific intrinsics. + std::unique_ptr targetMachine = nullptr; + if (!arch.empty() && pluginFile.empty()) + targetMachine = + createTargetMachine(mod, arch, enable_fp_fusion, features); + PassBuilder pb(/*targetMachine=*/targetMachine.get(), tuningOptions, + std::nullopt, instrCbPtr); + + if (!pluginFile.empty()) { + // TODO: Add some logging here that we inserted a pass into the LLVM + // pass pipeline + auto passPlugin = llvm::PassPlugin::Load(pluginFile); + if (!passPlugin) { + llvm::Error Err = passPlugin.takeError(); + std::string ErrMsg = + "Pass Plugin Error: " + llvm::toString(std::move(Err)); + throw std::runtime_error(ErrMsg); + } + passPlugin->registerPassBuilderCallbacks(pb); + } + + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + pb.registerVectorizerStartEPCallback( + [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) { + // Triton generates large structure of scalars which may pessimise + // optimizations, we run a pass to break up phi of struct to make + // sure all the struct are removed for the following passes. + fpm.addPass(BreakStructPhiNodesPass()); + fpm.addPass(InstCombinePass()); + }); + bool enableAddressSanitizer = + mlir::triton::tools::getBoolEnv("TRITON_ENABLE_ASAN"); + if (enableAddressSanitizer) { + AddressSanitizerOptions Opts; + mpm.addPass(AddressSanitizerPass(Opts)); + } + mpm.addPass(pb.buildPerModuleDefaultPipeline(opt)); + mpm.run(*mod, mam); + }, + // Mandatory parameters + py::arg("mod"), py::arg("opt"), + // If we want to specify the target machine, we require additional + // (optional) parameters + py::arg("arch") = "", py::arg("features") = "", + py::arg("flags") = std::vector{}, + py::arg("enable_fp_fusion") = false); + + m.def( + "translate_to_asm", + [](std::string llvmIR, std::string triple, std::string proc, + std::string features, std::vector flags, + bool enable_fp_fusion, bool isObject) -> py::object { + std::string obj; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + obj = translateLLVMIRToASM(*module, triple, proc, features, flags, + enable_fp_fusion, isObject); + } + if (isObject) + return py::bytes(obj); + else + return py::str(obj); + }, + ret::take_ownership); + + m.def("init_targets", []() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + }); + }); + + m.def("link_extern_libs", [](llvm::Module *dstMod, + const std::vector &paths) { + if (paths.empty()) + return; + + LLVMContext &ctx = dstMod->getContext(); + llvm::Linker linker(*dstMod); + for (const std::string &path : paths) { + llvm::SMDiagnostic err; + std::unique_ptr libMod = llvm::parseIRFile(path, err, ctx); + if (!libMod) { + std::string message = "Failed to parse library at " + path; + throw std::invalid_argument(message); + } + libMod->setTargetTriple(dstMod->getTargetTriple()); + libMod->setDataLayout(dstMod->getDataLayout()); + + std::unordered_set externalFns; + for (llvm::Function &fn : libMod->functions()) { + if (!fn.isDeclaration()) + externalFns.insert(fn.getName().str()); + } + + if (linker.linkInModule(std::move(libMod), + llvm::Linker::Flags::LinkOnlyNeeded)) { + std::string message = "Failed to link library at " + path; + throw std::invalid_argument(message); + } + + // Mark linked-in functions as internal because backends use external + // linkage as a signifier of kernel functions. + for (llvm::Function &fn : dstMod->functions()) { + if (externalFns.count(fn.getName().str())) { + fn.setLinkage(llvm::GlobalValue::InternalLinkage); + } + } + } + }); +} + +void triton_stacktrace_signal_handler(void *) { + llvm::sys::PrintStackTrace(llvm::errs()); + raise(SIGABRT); +} + +void init_triton_stacktrace_hook(pybind11::module &m) { + if (mlir::triton::tools::getBoolEnv("TRITON_ENABLE_PYTHON_STACKTRACE")) { + llvm::sys::AddSignalHandler(triton_stacktrace_signal_handler, nullptr); + } +} diff --git a/third_party/enflame/include/triton/python/src/main.cc b/third_party/enflame/include/triton/python/src/main.cc new file mode 100644 index 000000000..82289edc0 --- /dev/null +++ b/third_party/enflame/include/triton/python/src/main.cc @@ -0,0 +1,55 @@ +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Signals.h" +#include + +namespace py = pybind11; + +#define FOR_EACH_1(MACRO, X) MACRO(X) +#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__) +#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__) +#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__) + +#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N()) +#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_ARG_N(_1, _2, _3, _4, N, ...) N +#define FOR_EACH_RSEQ_N() 4, 3, 2, 1, 0 + +#define CONCATENATE(x, y) CONCATENATE1(x, y) +#define CONCATENATE1(x, y) x##y + +#define FOR_EACH(MACRO, ...) \ + CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__) +#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__) + +// New macro to remove parentheses +#define REMOVE_PARENS(...) __VA_ARGS__ + +// Intermediate macro to ensure correct expansion +#define FOR_EACH_P_INTERMEDIATE(MACRO, ...) FOR_EACH(MACRO, __VA_ARGS__) + +// Modified FOR_EACH to handle parentheses +#define FOR_EACH_P(MACRO, ARGS_WITH_PARENS) \ + FOR_EACH_P_INTERMEDIATE(MACRO, REMOVE_PARENS ARGS_WITH_PARENS) + +#define DECLARE_BACKEND(name) void init_triton_##name(pybind11::module &&m); + +#define INIT_BACKEND(name) init_triton_##name(m.def_submodule(#name)); + +void init_triton_env_vars(pybind11::module &m); +void init_triton_ir(pybind11::module &&m); +void init_triton_llvm(pybind11::module &&m); +void init_triton_interpreter(pybind11::module &&m); +void init_triton_passes(pybind11::module &&m); +void init_triton_stacktrace_hook(pybind11::module &m); +FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) + +PYBIND11_MODULE(libtriton, m) { + m.doc() = "Python bindings to the C++ Triton API"; + init_triton_stacktrace_hook(m); + init_triton_env_vars(m); + init_triton_ir(m.def_submodule("ir")); + init_triton_passes(m.def_submodule("passes")); + init_triton_interpreter(m.def_submodule("interpreter")); + init_triton_llvm(m.def_submodule("llvm")); + FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE) +} diff --git a/third_party/enflame/include/triton/python/src/passes.cc b/third_party/enflame/include/triton/python/src/passes.cc new file mode 100644 index 000000000..0800ad989 --- /dev/null +++ b/third_party/enflame/include/triton/python/src/passes.cc @@ -0,0 +1,113 @@ +#include "mlir/Transforms/Passes.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" +#include +#include + +namespace py = pybind11; + +void init_triton_analysis(py::module &&m) { + py::class_(m, "allocation", py::module_local()) + .def(py::init()); + py::class_(m, "membar", py::module_local()) + .def(py::init()) + .def("run", &mlir::ModuleMembarAnalysis::run); +} + +void init_triton_passes_common(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_sccp", createSCCPPass); + ADD_PASS_WRAPPER_0("add_symbol_dce", createSymbolDCEPass); + ADD_PASS_WRAPPER_0("add_inliner", createInlinerPass); + ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass); + ADD_PASS_WRAPPER_0("add_cse", createCSEPass); + ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass); + ADD_PASS_WRAPPER_0("print_ir", createPrintIRPass); +} + +void init_triton_passes_ttir(py::module &&m) { + using namespace mlir::triton; + ADD_PASS_WRAPPER_0("add_combine", createCombineOpsPass); + ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass); + ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", + createRewriteTensorPointerPass); + ADD_PASS_WRAPPER_0("add_loop_unroll", createLoopUnrollPass); + ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", + createConvertTritonToTritonGPUPass, const std::string &, + int, int, int); +} + +void init_triton_passes_ttgpuir(py::module &&m) { + using namespace mlir::triton::gpu; + ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce); + ADD_PASS_WRAPPER_0("add_optimize_thread_locality", + createTritonGPUOptimizeThreadLocality); + ADD_PASS_OPTION_WRAPPER_2("add_pipeline", createTritonGPUPipeline, int, bool); + ADD_PASS_WRAPPER_0("add_prefetch", createTritonGPUPrefetch); + ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul); + ADD_PASS_WRAPPER_0("add_reorder_instructions", + createTritonGPUReorderInstructions); + ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC); + ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands", + createTritonGPUOptimizeDotOperands, bool); + ADD_PASS_WRAPPER_0("add_remove_layout_conversions", + createTritonGPURemoveLayoutConversions); + ADD_PASS_WRAPPER_0("add_reduce_data_duplication", + createTritonGPUReduceDataDuplication); + ADD_PASS_WRAPPER_0("add_allocate_warp_groups", + createTritonGPUAllocateWarpGroups); + ADD_PASS_WRAPPER_0("add_allocate_shared_memory", createAllocateSharedMemory); + ADD_PASS_WRAPPER_0("add_allocate_global_scratch_memory", + createTritonGPUGlobalScratchAllocationPass); + ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if", + createTritonGPUCombineTensorSelectAndIf); + ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", + createTritonGPUOptimizeAccumulatorInit); + ADD_PASS_WRAPPER_0("add_fuse_nested_loops", createTritonGPUFuseNestedLoops); + ADD_PASS_WRAPPER_0("add_coalesce_async_copy", + createTritonGPUCoalesceAsyncCopy); + ADD_PASS_OPTION_WRAPPER_1("add_ws_task_partition", + createTritonGPUWSTaskPartition, int); + ADD_PASS_OPTION_WRAPPER_1("add_ws_data_partition", + createTritonGPUWSDataPartition, int); + ADD_PASS_OPTION_WRAPPER_1("add_ws_lowering", createTritonGPUWSLowering, int); + ADD_PASS_OPTION_WRAPPER_1("add_taskid_propagate", + createTritonGPUTaskIdPropagate, int); + ADD_PASS_OPTION_WRAPPER_4("add_ws_code_partition", + createTritonGPUWSCodePartition, int, int, int, int); + ADD_PASS_OPTION_WRAPPER_1("add_ping_pong_sync", createTritonGPUPingPongSync, + int); + ADD_PASS_OPTION_WRAPPER_1("add_ws_canonicalization", + createTritonGPUWSCanonicalization, int); +} + +void init_triton_passes_convert(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_scf_to_cf", createSCFToControlFlowPass); + ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); + ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); + ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); +} + +void init_triton_passes_llvmir(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_di_scope", createLLVMDIScopePass); +} + +void init_triton_passes(py::module &&m) { + init_triton_analysis(m.def_submodule("analysis")); + init_triton_passes_common(m.def_submodule("common")); + init_triton_passes_convert(m.def_submodule("convert")); + init_triton_passes_ttir(m.def_submodule("ttir")); + init_triton_passes_ttgpuir(m.def_submodule("ttgpuir")); + init_triton_passes_llvmir(m.def_submodule("llvmir")); +} diff --git a/third_party/enflame/include/triton/python/src/passes.h b/third_party/enflame/include/triton/python/src/passes.h new file mode 100644 index 000000000..629fe362d --- /dev/null +++ b/third_party/enflame/include/triton/python/src/passes.h @@ -0,0 +1,38 @@ +#define ADD_PASS_WRAPPER_0(name, builder) \ + m.def(name, [](mlir::PassManager &pm) { pm.addPass(builder()); }) + +#define ADD_PASS_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder(val0)); }) + +#define ADD_PASS_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder(val0, val1)); \ + }) + +#define ADD_PASS_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder(val0, val1, val2)); \ + }) + +#define ADD_PASS_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); }) + +#define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); }) + +#define ADD_PASS_OPTION_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder({val0, val1})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder({val0, val1, val2})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3) { pm.addPass(builder({val0, val1, val2, val3})); }) diff --git a/third_party/enflame/include/triton/python/test-requirements.txt b/third_party/enflame/include/triton/python/test-requirements.txt new file mode 100644 index 000000000..07429ac00 --- /dev/null +++ b/third_party/enflame/include/triton/python/test-requirements.txt @@ -0,0 +1,8 @@ +autopep8 +isort +numpy +pytest +pytest-forked +pytest-xdist +scipy>=1.7.1 +llnl-hatchet diff --git a/third_party/enflame/include/triton/python/test/backend/extension_backend.c b/third_party/enflame/include/triton/python/test/backend/extension_backend.c new file mode 100644 index 000000000..4a1e08bf0 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/backend/extension_backend.c @@ -0,0 +1,42 @@ +#include +#include +#include + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + // create a struct to hold device properties + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", 1024, + "multiprocessor_count", 16, "sm_clock_rate", 2100, + "mem_clock_rate", 2300, "mem_bus_width", 2400); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + // get allocated registers and spilled registers from the function + int n_regs = 0; + int n_spills = 0; + int mod = 0; + int fun = 0; + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load dummy binary for the extension device"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for the extension device"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "ext_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_ext_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + return m; +} diff --git a/third_party/enflame/include/triton/python/test/backend/test_device_backend.py b/third_party/enflame/include/triton/python/test/backend/test_device_backend.py new file mode 100644 index 000000000..fdf53eae1 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/backend/test_device_backend.py @@ -0,0 +1,235 @@ +import functools +import hashlib +import importlib +import os +import shutil +import subprocess +import sysconfig +import tempfile +from pathlib import Path + +import torch + +import triton +import triton.language as tl +from triton.common.backend import (BaseBackend, compute_core_version_key, register_backend) +from triton.compiler.make_launcher import make_so_cache_key +from triton.runtime.cache import get_cache_manager +from triton.runtime.driver import DriverBase + + +def build_for_backend(name, src, srcdir): + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + gcc = shutil.which("gcc") + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + + subprocess.check_call([cc, src, f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so]) + return so + + +class ExtensionUtils: + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(ExtensionUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + dirname = os.path.dirname(os.path.realpath(__file__)) + src = Path(os.path.join(dirname, "extension_backend.c")).read_text() + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + fname = "ext_utils.so" + cache_path = cache.get_file(fname) + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = build_for_backend("ext_utils", src_path, tmpdir) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), fname, binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location("ext_utils", cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + + +class ExtensionDriver(DriverBase): + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(ExtensionDriver, cls).__new__(cls) + return cls.instance + + def __init__(self): + self.utils = ExtensionUtils() + + +class ExtensionBackend(BaseBackend): + stub_so_path = "" + + def __init__(self, device_type: str) -> None: + super(ExtensionBackend, self).__init__(device_type) + self.driver = ExtensionDriver() + self.version_key = None + + def add_stages(self, arch, extern_libs, stages): + filter_in_stages = ["ast", "ttir", "ttgir"] + filter_out_stages = [] + for key, _ in stages.items(): + if key not in filter_in_stages: + filter_out_stages.append(key) + for filter_out_key in filter_out_stages: + stages.pop(filter_out_key) + + def add_meta_info(self, ir, cur_module, next_module, metadata, asm): + metadata["name"] = "extension_backend_name" + + def get_driver(self): + return self.driver + + def get_stream(self): + return "" + + @functools.lru_cache(None) + def get_device_properties(self, device): + return self.driver.utils.get_device_properties() + + def get_current_device(self): + return torch.device("cpu") + + def set_current_device(self, device): + pass + + def get_load_binary_fn(self): + return self.driver.utils.load_binary + + def get_kernel_bin(self): + return "ttgir" + + def get_architecture_descriptor(self, **kwargs): + return "" + + def get_version_key(self): + if self.version_key is None: + self.version_key = compute_core_version_key() + return self.version_key + + def make_launcher_stub(self, name, signature, constants): + # name of files that are cached + so_cache_key = make_so_cache_key(self.get_version_key(), signature, constants) + so_cache_manager = get_cache_manager(so_cache_key) + so_name = f"{name}.so" + # retrieve stub from cache if it exists + cache_path = so_cache_manager.get_file(so_name) + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src = self._generate_launcher(constants, signature) + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = build_for_backend(name, src_path, tmpdir) + with open(so, "rb") as f: + so_path = so_cache_manager.put(f.read(), so_name, binary=True) + type(self).stub_so_path = so_path + return so_path + else: + type(self).stub_so_path = cache_path + return cache_path + + def _generate_launcher(self, constants, signature): + # generate glue code + src = """ + #define __EXTENSION_BACKEND__ + #include + #include + + static PyObject* launch_counter(PyObject* self, PyObject* args) { + static int64_t launch_counter = 0; + launch_counter += 1; + return PyLong_FromLong(launch_counter); + } + + static PyObject* launch(PyObject* self, PyObject* args) { + if (PyErr_Occurred()) { + return NULL; + } + launch_counter(self, args); + // return None + Py_INCREF(Py_None); + return Py_None; + } + + static PyMethodDef ModuleMethods[] = { + {"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}, + {"launch_counter", launch_counter, METH_VARARGS, "Entry point to get launch counter"}, + {NULL, NULL, 0, NULL} // sentinel + }; + + static struct PyModuleDef ModuleDef = { + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods + }; + + PyMODINIT_FUNC PyInit___triton_launcher(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + return m; + } + """ + + return src + + +def test_dummy_backend(): + register_backend("cpu", ExtensionBackend) + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + inp = torch.randn(10) + out = torch.randn(10) + kernel[(10, )](inp, out, 10, XBLOCK=16) + spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + launch_counter = getattr(mod, "launch_counter") + + for _ in range(100): + kernel[(10, )](inp, out, 10, XBLOCK=16) + + assert launch_counter() > 0 diff --git a/third_party/enflame/include/triton/python/test/kernel_comparison/kernels.yml b/third_party/enflame/include/triton/python/test/kernel_comparison/kernels.yml new file mode 100644 index 000000000..d557e6c66 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/kernel_comparison/kernels.yml @@ -0,0 +1,33 @@ +name_and_extension: + - name: _kernel_0d1d2d3de4de5de6c7de8de9c10de11c + extension: ptx + - name: _kernel_0d1d2d3de4de5de6de7c8de9c10de11c + extension: ptx + - name: _kernel_0d1d2d345de6c789c1011c + extension: ptx + - name: _kernel_0d1d2d3456c789c1011c + extension: ptx + - name: _kernel_0d1d2d3de4de5de6c7de8c9de10de11c + extension: ptx + - name: _kernel_0d1d2d34567c8c91011c + extension: ptx + - name: _kernel_0d1d2d3456c78c91011c + extension: ptx + - name: _kernel_0d1d2d3de4de5de6de7c8c9de10de11c + extension: ptx + - name: _kernel_0d1d2d34567c89c1011c + extension: ptx + - name: _kernel_0d1d2d345de6de7c89c1011c + extension: ptx + - name: _kernel_0d1d2d345de6de7c8c9de1011c + extension: ptx + - name: kernel_0d1d2de + extension: ptx + - name: _kernel_0d1d2d345de6c78c9de1011c + extension: ptx + - name: _bwd_kernel_0d1d2d34d5d6d7d8d9d10d11de12de13de14de15c16de17de18de19c20de21de22de23c2425de26de + extension: ptx + - name: _fwd_kernel_0d1d2d34d5d6de7de8de9c10de11de12de13c14de15de16de17c18de19de20de21c2223de24de + extension: ptx + - name: _bwd_preprocess_0d1d2d + extension: ptx diff --git a/third_party/enflame/include/triton/python/test/regression/conftest.py b/third_party/enflame/include/triton/python/test/regression/conftest.py new file mode 100644 index 000000000..d88687b45 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/regression/conftest.py @@ -0,0 +1,22 @@ +import os +import pytest +import tempfile + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default="cuda") + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") + + +@pytest.fixture +def fresh_triton_cache(): + with tempfile.TemporaryDirectory() as tmpdir: + try: + os.environ["TRITON_CACHE_DIR"] = tmpdir + yield tmpdir + finally: + os.environ.pop("TRITON_CACHE_DIR", None) diff --git a/third_party/enflame/include/triton/python/test/regression/test_cast_matmul.py b/third_party/enflame/include/triton/python/test/regression/test_cast_matmul.py new file mode 100644 index 000000000..c10ca4ce7 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/regression/test_cast_matmul.py @@ -0,0 +1,128 @@ +""" +Mixed precision tests for matmul (tl.dot) with cast (tl.to) + +issue: https://github.com/triton-lang/triton/issues/2523 + +TODO: float8 types +""" + +import pytest +import torch + +import triton +import triton.language as tl +from triton._internal_testing import is_hip_mi300, is_cuda, is_hip + +input_dtypes = ["bfloat16", "float16", "float32", "float64"] +if is_cuda(): + input_dtypes += ["int8", "float8_e5m2"] + cc = torch.cuda.get_device_capability(0) + if cc >= (8, 9): + input_dtypes += ["float8_e4m3fn"] +elif is_hip_mi300(): + input_dtypes += [ + "int8", + "float8_e5m2", + # natively supported on mi300 (see CDNA3 ISA, section 7.2) + "float8_e4m3fnuz", + ] + +out_dtypes = ["float16", "float32"] + + +@triton.jit +def matmul_kernel(A, B, C, M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + dot_out_dtype: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, # + BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr): + # matrix multiplication + pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_remaining = K - k * BLOCK_K + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + a = a.to(C.dtype.element_ty) + b = b.to(C.dtype.element_ty) + acc += tl.dot(a, b, out_dtype=dot_out_dtype) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C, acc, mask=mask) + + +@pytest.mark.parametrize("M, K, N, BLOCK_K, BLOCK_M, BLOCK_N, w_dtype, x_dtype, out_dtype", + [(M, K, N, BLOCK_K, BLOCK_M, BLOCK_N, w, x, o) # + for BLOCK_K in [16, 32] # + for BLOCK_M in [16, 64] # + for BLOCK_N in [16, 64] # + for (M, K, N) in [(128, 128, 128), (768, 768, 1024)] # + for w in input_dtypes + for x in input_dtypes # + for o in out_dtypes]) +def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, BLOCK_N, w_dtype, x_dtype, out_dtype, device): + if x_dtype == w_dtype: + pytest.skip("skip the same input dtype") + if is_hip() and BLOCK_M == 64 and w_dtype in ["float8_e5m2", "float8_e4m3fnuz"]: + pytest.skip("skip due to bug on HIP path") + x_dtype: torch.dtype = getattr(torch, x_dtype) + w_dtype: torch.dtype = getattr(torch, w_dtype) + + def init_tensor(dtype, shape): + if dtype == torch.int8: + return torch.randint(0, 2, shape, device=device, dtype=dtype) + elif dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2): + return torch.randn(shape, device=device, dtype=torch.float16).to(dtype) + else: + return torch.randn(shape, device=device, dtype=dtype) + + torch.manual_seed(42) + a = init_tensor(w_dtype, (M, K)) + b = init_tensor(x_dtype, (K, N)) + + torch_dtype = getattr(torch, out_dtype) + triton_dtype = getattr(tl, out_dtype) # <- here force dot_out_dtype + out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype)) + out_triton = torch.empty((M, N), device=device, dtype=torch_dtype) + + # launch kernel + block_m, block_n, block_k = BLOCK_M, BLOCK_N, BLOCK_K + grid = ((triton.cdiv(M, block_m) * triton.cdiv(N, block_n)), 1) + + matmul_kernel[grid]( + a, b, out_triton, M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, # + GROUP_M=8, # + BLOCK_M=block_m, # + BLOCK_N=block_n, # + BLOCK_K=block_k) + + torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01) diff --git a/third_party/enflame/include/triton/python/test/regression/test_functional_regressions.py b/third_party/enflame/include/triton/python/test/regression/test_functional_regressions.py new file mode 100644 index 000000000..b6143b178 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/regression/test_functional_regressions.py @@ -0,0 +1,278 @@ +import numpy as np +import pytest +import torch +from numpy.random import RandomState + +import triton +import triton.language as tl + + +def test_chained_matmul(device): + # Regression test for issue #1601 + def chained_matmul_reference(a, b, c): + intermediate = torch.einsum('MK,NK->MN', a, b) + return torch.einsum('MN,NK->MK', intermediate, c) + + @triton.jit + def chained_matmul_kernel(A, # shape: (m, k) + B, # shape: (n, k) + C, # shape: (n, k) + out, # shape: (m, k) + m, n, k: tl.constexpr, # + block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr): + + tl.static_assert(block_k == k, f"expected block_k == k but got {block_k} != {k}") + + block_ix = tl.program_id(0) + a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \ + + tl.arange(0, block_k)[None, :] + + a = tl.load(A + a_tile, mask=a_tile < m * k, other=0.0) + + acc = tl.zeros([block_m, block_k], dtype=tl.float32) + + for loop_block_start in range(0, n, block_n): + bc_tile = (loop_block_start + tl.arange(0, block_n))[:, None] * block_k \ + + tl.arange(0, block_k)[None, :] + b = tl.load(B + bc_tile, mask=bc_tile < n * k, other=0.0) + + intermediate = tl.dot(a, tl.trans(b)) + intermediate_mask = ((loop_block_start + tl.arange(0, block_n)) < n)[None, :] \ + * (tl.arange(0, block_m) < m)[:, None] + + intermediate = tl.where(intermediate_mask, intermediate, 0.0) + + c = tl.load(C + bc_tile, mask=bc_tile < n * k) + + acc += tl.dot(intermediate.to(A.dtype.element_ty), c) + + tl.store(out + a_tile, acc.to(A.dtype.element_ty), mask=a_tile < m * k) + + m, n, k = 32, 64, 128 + block_m, block_n, block_k = 16, 32, k + + grid = (triton.cdiv(m, block_m), ) + a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, device=device) + b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, device=device) + c = torch.randint_like(b, low=0, high=2) + triton_result = torch.zeros_like(a) + + torch_result = chained_matmul_reference(a, b, c) + chained_matmul_kernel[grid]( + a, b, c, triton_result, m, n, k, # + block_m=block_m, block_n=block_n, block_k=block_k) + + assert (torch_result == triton_result).all() + + +def test_vecmat(device): + + @triton.jit + def batched_vecmat( + # inputs + A, # shape: [dim_m, dim_k] + B, # shape: [dim_m, dim_n, dim_k] + # dimensions + dim_m, dim_n, dim_k, + # outputs + output, + # block information + block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr): + m_index = tl.program_id(0) + n_index = tl.program_id(1) + # Output tile + output_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_n \ + + (n_index * block_n + tl.arange(0, block_n))[None, :] + + vecmat = tl.zeros([block_m, block_n], dtype=A.dtype.element_ty) + k_blocks = dim_k // block_k + for k_index in range(k_blocks): + # Load A tile + a_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_k \ + + (k_index * block_k + tl.arange(0, block_k))[None, :] + a = tl.load(A + a_tile) + + # Load B tile, transposed to [n, m, k] in order to broadcast A on a + # leading dimension. + b_tile = (m_index * block_m + tl.arange(0, block_m))[None, :, None] * dim_n * dim_k \ + + (n_index * block_n + tl.arange(0, block_n))[:, None, None] * dim_k \ + + (k_index * block_k + tl.arange(0, block_k))[None, None, :] + b = tl.load(B + b_tile) + + expanded_a, _ = tl.broadcast(a, b) + vecmat += tl.trans(tl.sum(expanded_a * b, axis=2)) + + tl.store(output + output_tile, vecmat) + + M, N, K = 128, 128, 128 + block_m, block_n, block_k = 16, 32, 64 + + rs = RandomState(17) + A_vec = rs.randint(0, 4, (M, K)).astype('float32') + B_vec = rs.randint(0, 4, (M, N, K)).astype('float32') + A = A_vec + B = B_vec + + A_tri = torch.tensor(A, device=device) + B_tri = torch.tensor(B, device=device) + C_tri = torch.zeros((M, N), dtype=torch.float32, device=device) + + grid = (M // block_m, N // block_n) + + batched_vecmat[grid]( + A_tri, B_tri, M, N, K, C_tri, # + block_m=block_m, block_n=block_n, block_k=block_k, # + num_warps=4, num_stages=1) + + A_expanded = A[:, np.newaxis, :] + A_broadcasted = np.broadcast_to(A_expanded, (M, N, K)) + AB = A_broadcasted * B + C_ref = np.sum(AB, axis=2) + + np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@pytest.mark.parametrize("type", + ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"]) +def test_iv_dependent_matmul(type, device): + + @triton.jit + def kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + type: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptr = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptr = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + a_ptrs = a_ptr + b_ptrs = b_ptr + if type == "post_load_two_iters": + a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak + b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk + elif type == "post_load_three_iters": + a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak + b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk + a_ptrs_next_next = a_ptr + 2 * BLOCK_SIZE_K * stride_ak + b_ptrs_next_next = b_ptr + 2 * BLOCK_SIZE_K * stride_bk + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if type == "pre_load": + a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak + b_ptrs = b_ptr + k * BLOCK_SIZE_K * stride_bk + elif type == "post_pre_mixed": + a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + if type == "post_load": + a_ptrs = a_ptr + (k + 1) * BLOCK_SIZE_K * stride_ak + b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk + elif type == "post_pre_mixed": + b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk + elif type == "post_load_two_iters": + a_ptrs = a_ptrs_next + b_ptrs = b_ptrs_next + a_ptrs_next = a_ptr + (k + 2) * BLOCK_SIZE_K * stride_ak + b_ptrs_next = b_ptr + (k + 2) * BLOCK_SIZE_K * stride_bk + elif type == "post_load_three_iters": + a_ptrs = a_ptrs_next + b_ptrs = b_ptrs_next + a_ptrs_next = a_ptrs_next_next + b_ptrs_next = b_ptrs_next_next + a_ptrs_next_next = a_ptr + (k + 3) * BLOCK_SIZE_K * stride_ak + b_ptrs_next_next = b_ptr + (k + 3) * BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + M = 256 + K = 256 + N = 256 + BLOCK_SIZE_K = 32 + BLOCK_SIZE_N = 32 + BLOCK_SIZE_M = 32 + + a = torch.rand((M, K), device=device) + b = torch.rand((K, N), device=device) + + torch_output = torch.mm(a, b) + triton_output = torch.empty_like(torch_output, device=torch_output.device) + + def grid(META): + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + + num_stages = 4 if type == "post_load_three_iters" else 3 + kernel[grid]( + a, b, triton_output, M, N, K, # + a.stride(0), a.stride(1), b.stride(0), b.stride(1), # + triton_output.stride(0), triton_output.stride(1), # + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, # + num_stages=num_stages) + torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2) + + +def test_reverse_range(device): + + @triton.jit + def kernel(in_ptr, out_ptr): + x0 = tl.arange(0, 512) + tmp0 = tl.load(in_ptr + (512 - x0)) + tl.store(out_ptr + x0, tmp0) + + data = torch.randn((516, ), dtype=torch.float32, device=device) + res = torch.empty((512, ), dtype=torch.float32, device=device) + kernel[(1, )](data, res) + ref = torch.flip(data[1:513], [0]) + assert (res == ref).all() + + +@triton.jit +def _triton_cummax_helper_fn(arg0_0, arg0_1, arg1_0, arg1_1): + tmp0 = arg0_0 > arg1_0 + tmp1 = arg0_0 == arg1_0 + tmp2 = arg0_1 > arg1_1 + tmp3 = tmp1 & tmp2 + tmp4 = tmp0 | tmp3 + tmp5 = tl.where(tmp4, arg0_0, arg1_0) + tmp6 = tl.where(tmp4, arg0_1, arg1_1) + return tmp5, tmp6 + + +def test_inductor_cummax_bool(device): + + @triton.jit + def triton_(in_ptr0, out_ptr0, out_ptr1, XBLOCK: tl.constexpr): + offset = tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + offset).to(tl.int1) + tmp1 = tmp0.to(tl.int1) + tmp3 = offset.to(tl.int64) + tmp5, tmp6, = tl.associative_scan(( + tmp1, + tmp3, + ), 0, _triton_cummax_helper_fn) + tl.store(out_ptr0 + offset, tmp5) + tl.store(out_ptr1 + offset, tmp6) + + a = torch.randn((64, ), device=device) > 0 + values = torch.empty((64, ), dtype=torch.bool, device=device) + indices = torch.empty((64, ), dtype=torch.int64, device=device) + ref = torch.cummax(a, dim=0) + + triton_[(1, )](a, values, indices, 64) + torch.testing.assert_close(ref.values, values) + torch.testing.assert_close(ref.indices, indices) diff --git a/third_party/enflame/include/triton/python/test/unit/blackwell/test_tmem.py b/third_party/enflame/include/triton/python/test/unit/blackwell/test_tmem.py new file mode 100644 index 000000000..801d36e4e --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/blackwell/test_tmem.py @@ -0,0 +1,91 @@ +import pytest +import torch +import tempfile + +import triton +from triton.backends.compiler import GPUTarget + + +def test_tmem_copy_2d(): + if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 10: + pytest.skip("Test requires Blackwell target.") + + device = "cuda" + + smem_h = 256 + num_cols = smem_h * 4 // 32 + + copy_ops = """ +%93 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> +ttng.init_barrier %93, 1 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> +%tmem_alloc = ttng.tmem_alloc {{tensor_memory_offset = 0 : i32}}: () -> !ttg.memdesc<128x{num_cols}xi32, #tmem, #ttng.tensor_memory, mutable> +ttng.tmem_copy %17, %tmem_alloc, %93 : (!ttg.memdesc<{smem_h}x4xi32, #shared, #ttg.shared_memory>, !ttg.memdesc<128x{num_cols}xi32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>) -> () + +%c0_i32 = arith.constant 0 : i32 +ttng.wait_barrier %93, %c0_i32 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + """.format(num_cols=num_cols, smem_h=smem_h) + + ir_body = """ + + %cst = arith.constant dense<4> : tensor<{smem_h}x1xi32, #blocked> + %0 = tt.make_range {{end = {smem_h} : i32, start = 0 : i32}} : tensor<{smem_h}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.make_range {{end = 4 : i32, start = 0 : i32}} : tensor<4xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{smem_h}x4x!tt.ptr, #blocked> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<128x{num_cols}x!tt.ptr, #blocked> + + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{smem_h}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{smem_h}x1xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<{smem_h}x1xi32, #blocked> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<4xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x4xi32, #blocked> + %7 = tt.broadcast %6 : tensor<1x4xi32, #blocked> -> tensor<{smem_h}x4xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{smem_h}x1xi32, #blocked> -> tensor<{smem_h}x4xi32, #blocked> + %9 = arith.addi %8, %7 : tensor<{smem_h}x4xi32, #blocked> + %10 = tt.addptr %2, %9 : tensor<{smem_h}x4x!tt.ptr, #blocked>, tensor<{smem_h}x4xi32, #blocked> + + %01 = tt.make_range {{end = 128 : i32, start = 0 : i32}} : tensor<128xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %41 = tt.expand_dims %01 {{axis = 1 : i32}} : tensor<128xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<128x1xi32, #blocked> + %cst1 = arith.constant dense<{num_cols}> : tensor<128x1xi32, #blocked> + %51 = arith.muli %41, %cst1 : tensor<128x1xi32, #blocked> + %31 = tt.make_range {{end = {num_cols} : i32, start = 0 : i32}} : tensor<{num_cols}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %21 = tt.expand_dims %31 {{axis = 0 : i32}} : tensor<{num_cols}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{num_cols}xi32, #blocked> + %71 = tt.broadcast %21 : tensor<1x{num_cols}xi32, #blocked> -> tensor<128x{num_cols}xi32, #blocked> + %81 = tt.broadcast %51 : tensor<128x1xi32, #blocked> -> tensor<128x{num_cols}xi32, #blocked> + %91 = arith.addi %81, %71 : tensor<128x{num_cols}xi32, #blocked> + %14 = tt.addptr %3, %91 : tensor<128x{num_cols}x!tt.ptr, #blocked>, tensor<128x{num_cols}xi32, #blocked> + + %11 = tt.load %10 : tensor<{smem_h}x4x!tt.ptr, #blocked> + %17 = ttg.local_alloc %11 : (tensor<{smem_h}x4xi32, #blocked>) -> !ttg.memdesc<{smem_h}x4xi32, #shared, #ttg.shared_memory> + {copy_ops} + %22 = ttng.tmem_load %tmem_alloc : !ttg.memdesc<128x{num_cols}xi32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x{num_cols}xi32, #blocked> + tt.store %14, %22 : tensor<128x{num_cols}x!tt.ptr, #blocked> + + tt.return + + """.format(copy_ops=copy_ops, num_cols=num_cols, smem_h=smem_h) + + ir = """ + #blocked = #ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], order=[0, 1]}> + #shared = #ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=1, order=[1, 0]}> + #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> + #tmem = #ttng.tensor_memory_encoding + module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + """ + ir_body + """ + } + } + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name, target=GPUTarget("cuda", 100, 32)) + + x = torch.randint(size=(smem_h, 4), low=-100, high=100, dtype=torch.int32).to(device) + z_tri = torch.zeros(size=(128, num_cols), dtype=torch.int32).to(device) + kernel[(1, 1, 1)](x, z_tri) + + num_rep_m = smem_h // 32 + + for m in range(num_rep_m): + col_offset = m * 4 + for i in range(4): + # Copied values are duplicated across warps + assert torch.equal(x[m * 32:(m + 1) * 32], z_tri[32 * i:32 * (i + 1), col_offset:(col_offset + 4)]) diff --git a/third_party/enflame/include/triton/python/test/unit/conftest.py b/third_party/enflame/include/triton/python/test/unit/conftest.py new file mode 100644 index 000000000..913acdc30 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/conftest.py @@ -0,0 +1,26 @@ +import os +import pytest +import tempfile + + +def pytest_configure(config): + config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default="cuda") + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") + + +@pytest.fixture +def fresh_triton_cache(): + with tempfile.TemporaryDirectory() as tmpdir: + try: + os.environ["TRITON_CACHE_DIR"] = tmpdir + yield tmpdir + finally: + os.environ.pop("TRITON_CACHE_DIR", None) diff --git a/third_party/enflame/include/triton/python/test/unit/cuda/__init__.py b/third_party/enflame/include/triton/python/test/unit/cuda/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/include/triton/python/test/unit/cuda/test_experimental_tma.py b/third_party/enflame/include/triton/python/test/unit/cuda/test_experimental_tma.py new file mode 100644 index 000000000..5ef046b1d --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/cuda/test_experimental_tma.py @@ -0,0 +1,1179 @@ +import pytest +import torch + +import triton +import triton.language as tl +from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor) +from triton._internal_testing import dtypes_with_bfloat16, is_interpreter, numpy_random, to_triton, requires_tma, supports_tma, tma_skip_msg + +from typing import Optional + + +def create_tma_desc_gmem_ptr(ptr, dims, block_dims, element_size): + cpu_desc = torch.empty(128, device="cpu") + if len(dims) == 1: + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, + cpu_desc.data_ptr()) + else: + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], block_dims[1], + element_size, cpu_desc.data_ptr()) + return cpu_desc.cuda() + + +def unwrap_tensor(t: torch.Tensor | triton.runtime.jit.TensorWrapper): + if isinstance(t, triton.runtime.jit.TensorWrapper): + return t.base + return t + + +tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"}) + + +@pytest.mark.parametrize("byval_tma", [True, False]) +def test_experimetal_descriptor_load(byval_tma): + if not supports_tma(byval_tma): + pytest.skip(tma_skip_msg(byval_tma)) + + device = "cuda" + SIZE = 128 + + @triton.jit + def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr): + if not BYVAL_TMA: + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc) + off_desc = 0 + off = tl.arange(0, SIZE) + x = tl._experimental_descriptor_load(desc, [off_desc], [SIZE], Z.dtype.element_ty) + tl.store(Z + off, x) + + x = torch.randn(SIZE, dtype=torch.float32, device=device) + if byval_tma: + desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size()) + else: + desc = create_tma_desc_gmem_ptr(x.data_ptr(), [SIZE], [SIZE], x.element_size()) + z_tri = torch.empty_like(x) + compiled_kernel = kernel[(1, )](z_tri, desc, SIZE=SIZE, BYVAL_TMA=byval_tma, num_warps=4) + assert torch.equal(x, z_tri) + if byval_tma: + assert ".param .align 64 .b8" in compiled_kernel.asm["ptx"] + + +@triton.jit +def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # + M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BYVAL_TMA: tl.constexpr, dtype: tl.constexpr): + if not BYVAL_TMA: + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + offs_k = 0 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], dtype) + accumulator = tl.dot(a, b, acc=accumulator) + offs_k += BLOCK_SIZE_K + accumulator = accumulator.to(dtype) + tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn]) + + +@pytest.mark.parametrize("num_stages", [1, 4]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) +@pytest.mark.parametrize("byval_tma", [True, False]) +def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma): + if not supports_tma(byval_tma): + pytest.skip(tma_skip_msg(byval_tma)) + + device = "cuda" + M, N, K = 8192, 8192, 1024 + torch.manual_seed(42) + A = torch.randn((M, K), dtype=torch.float16, device=device) + B = torch.randn((K, N), dtype=torch.float16, device=device) + C = torch.empty((M, N), dtype=torch.float16, device=device) + if byval_tma: + desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size()) + desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size()) + desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size()) + else: + desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size()) + desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size()) + desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size()) + kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, + 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma, + num_warps=8, num_stages=num_stages, dtype=tl.float16) + ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16) + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + if BLOCK_M >= 64 and BLOCK_N >= 64 and torch.cuda.get_device_capability()[0] == 9: + # TODO: The use of stmatrix for Blackwell is currently not supported. + # Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4. + assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] + if byval_tma: + assert ".param .align 64 .b8" in kernel.asm["ptx"] + + +@triton.jit +def device_tensormap_kernel2d(in_ptr, out_ptr, in_desc, out_desc, ready_flag, M, N, M_BLOCK: tl.constexpr, + N_BLOCK: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + if pid_m == 0 and pid_n == 0: + # Write out descriptor + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=in_desc, + global_address=in_ptr, + load_size=[M_BLOCK, N_BLOCK], + global_size=[M, N], + element_ty=in_ptr.dtype.element_ty, + ) + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=out_desc, + global_address=out_ptr, + load_size=[M_BLOCK, N_BLOCK], + global_size=[M, N], + element_ty=out_ptr.dtype.element_ty, + ) + tl.atomic_xchg(ready_flag, 1, sem="release") + else: + # Spin until descriptor is ready + flag = tl.full([], 0, tl.int32) + while flag == 0: + flag = tl.atomic_add(ready_flag, 0, sem="acquire") + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(in_desc) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(out_desc) + + moffset = pid_m * M_BLOCK + noffset = pid_n * N_BLOCK + + x = tl._experimental_descriptor_load(in_desc, [moffset, noffset], [M_BLOCK, N_BLOCK], in_ptr.dtype.element_ty) + tl._experimental_descriptor_store(out_desc, x, [moffset, noffset]) + + +@requires_tma +@pytest.mark.parametrize("dtype_str", tma_dtypes) +def test_device_tensormap2d(dtype_str): + M_BLOCK, N_BLOCK = 32, 64 + M_GRID, N_GRID = 2, 4 + + shape = (M_BLOCK * M_GRID, M_BLOCK * N_GRID) + device = "cuda" + inp = to_triton(numpy_random(shape, dtype_str=dtype_str), device=device, dst_type=dtype_str) + inp_copy = inp.clone() + out = to_triton(numpy_random(shape, dtype_str=dtype_str), device=device, dst_type=dtype_str) + + in_desc = torch.randint(0, 256, size=(128, ), dtype=torch.uint8, device="cuda") + out_desc = torch.randint(0, 256, size=(128, ), dtype=torch.uint8, device="cuda") + ready_flag = torch.zeros((), dtype=torch.int32, device="cuda") + + device_tensormap_kernel2d[M_GRID, N_GRID](inp, out, in_desc, out_desc, ready_flag, *shape, M_BLOCK=M_BLOCK, + N_BLOCK=N_BLOCK) + + # Check results are correct + torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(out)) + torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(inp_copy)) + + +@triton.jit +def device_tensormap_kernel1d(in_ptr, out_ptr, in_desc, out_desc, ready_flag, numel, BLOCK: tl.constexpr): + pid = tl.program_id(axis=0) + + if pid == 0: + # Write out descriptor + tl.extra.cuda.experimental_device_tensormap_create1d( + desc_ptr=in_desc, + global_address=in_ptr, + load_size=BLOCK, + global_size=numel, + element_ty=in_ptr.dtype.element_ty, + ) + tl.extra.cuda.experimental_device_tensormap_create1d( + desc_ptr=out_desc, + global_address=out_ptr, + load_size=BLOCK, + global_size=numel, + element_ty=out_ptr.dtype.element_ty, + ) + tl.atomic_xchg(ready_flag, 1, sem="release") + else: + # Spin until descriptor is ready + flag = tl.full([], 0, tl.int32) + while flag == 0: + flag = tl.atomic_add(ready_flag, 0, sem="acquire") + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(in_desc) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(out_desc) + + offset = pid * BLOCK + + x = tl._experimental_descriptor_load(in_desc, [offset], [BLOCK], in_ptr.dtype.element_ty) + tl._experimental_descriptor_store(out_desc, x, [offset]) + + +@requires_tma +@pytest.mark.parametrize("dtype_str", tma_dtypes) +def test_device_tensormap1d(dtype_str): + BLOCK = 256 + GRID = 8 + + shape = (BLOCK * GRID, ) + device = "cuda" + inp = to_triton(numpy_random(shape, dtype_str=dtype_str), device=device, dst_type=dtype_str) + inp_copy = inp.clone() + out = to_triton(numpy_random(shape, dtype_str=dtype_str), device=device, dst_type=dtype_str) + + in_desc = torch.randint(0, 256, size=(128, ), dtype=torch.uint8, device="cuda") + out_desc = torch.randint(0, 256, size=(128, ), dtype=torch.uint8, device="cuda") + ready_flag = torch.zeros((), dtype=torch.int32, device="cuda") + + device_tensormap_kernel1d[ + 1, + ](inp, out, in_desc, out_desc, ready_flag, *shape, BLOCK=BLOCK) + + # Check results are correct + torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(out)) + torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(inp_copy)) + + +@requires_tma +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", tma_dtypes) +def test_tensor_descriptor_load(dtype_str): + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + assert desc.shape[0] == M + assert desc.shape[1] == N + assert desc.strides[0] == N + assert desc.strides[1] == 1 + assert desc.block_shape == [M_BLOCK, N_BLOCK] + block = desc.load([M_BLOCK, 2 * N_BLOCK]) + idx = tl.arange(0, M_BLOCK)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :] + tl.store(out_ptr + idx, block) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert size == 128 + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + M, N = 32, 128 + inp = to_triton(numpy_random((M, N), dtype_str), device="cuda", dst_type=dtype_str) + + M_BLOCK = 8 + N_BLOCK = 32 + out = inp.new_empty((M_BLOCK, N_BLOCK)) + + kernel[(1, )](out, inp, M, N, M_BLOCK, N_BLOCK) + + expect = unwrap_tensor(inp)[1 * M_BLOCK:2 * M_BLOCK, 2 * N_BLOCK:3 * N_BLOCK] + torch.testing.assert_close(expect, unwrap_tensor(out)) + + +@requires_tma +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", tma_dtypes) +def test_tensor_descriptor_store(dtype_str): + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + + midx = moffset + tl.arange(0, M_BLOCK)[:, None] + nidx = noffset + tl.arange(0, N_BLOCK)[None, :] + idx = midx * N + nidx + + val = tl.load(a_ptr + idx) + + desc = tl._experimental_make_tensor_descriptor( + out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + assert desc.shape[0] == M + assert desc.shape[1] == N + assert desc.strides[0] == N + assert desc.strides[1] == 1 + assert desc.block_shape == [M_BLOCK, N_BLOCK] + desc.store([moffset, noffset], val) + + M, N = 32, 128 + inp = to_triton(numpy_random((M, N), dtype_str), device="cuda", dst_type=dtype_str) + + M_BLOCK = 8 + N_BLOCK = 32 + out = inp.new_empty((M, N)) + + grid_m = M // M_BLOCK + grid_n = N // N_BLOCK + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert size == 128 * (grid_m * grid_n) + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + kernel[(grid_m, grid_n)](out, inp, M, N, M_BLOCK, N_BLOCK) + + torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(out)) + + +@requires_tma +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", tma_dtypes) +@pytest.mark.parametrize("K_BLOCK", [16, 32, 64, 128]) +def test_tensor_descriptor_load3d(dtype_str, K_BLOCK): + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, K, stride_m, stride_n, stride_k, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, + K_BLOCK: tl.constexpr): + desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, N, K], + strides=[stride_m, stride_n, stride_k], + block_shape=[M_BLOCK, N_BLOCK, K_BLOCK], + ) + + pid_m, pid_n, pid_k = tl.program_id(0), tl.program_id(1), tl.program_id(2) + offs = pid_m * M_BLOCK, pid_n * N_BLOCK, pid_k * K_BLOCK + + block = desc.load(offs) + + idx_m = offs[0] + tl.arange(0, M_BLOCK)[:, None, None] + idx_n = offs[1] + tl.arange(0, N_BLOCK)[None, :, None] + idx_k = offs[2] + tl.arange(0, K_BLOCK)[None, None, :] + idx = idx_m * N * K + idx_n * K + idx_k + mask = (idx_m < M) & (idx_n < N) & (idx_k < K) + tl.store(out_ptr + idx, block, mask) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + inp = to_triton(numpy_random((10, 64, 128), dtype_str), device="cuda", dst_type=dtype_str) + inp.data = inp.data[:, :50, :119] + + if K_BLOCK * inp.element_size() < 32: + return pytest.skip("Invalid last dim size") + + M_BLOCK, N_BLOCK = 8, 8 + out = inp.new_empty(inp.shape) + + grid = tuple(triton.cdiv(size, block) for size, block in zip(inp.shape, (M_BLOCK, N_BLOCK, K_BLOCK))) + kernel[grid](out, inp, *inp.shape, *inp.stride(), M_BLOCK, N_BLOCK, K_BLOCK) + + actual = unwrap_tensor(out) + expect = unwrap_tensor(inp) + torch.testing.assert_close(expect, actual) + + +@requires_tma +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", tma_dtypes) +@pytest.mark.parametrize("K_BLOCK", [16, 32, 64, 128]) +def test_tensor_descriptor_store3d(dtype_str, K_BLOCK): + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, K, stride_m, stride_n, stride_k, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, + K_BLOCK: tl.constexpr): + desc = tl._experimental_make_tensor_descriptor( + out_ptr, + shape=[M, N, K], + strides=[stride_m, stride_n, stride_k], + block_shape=[M_BLOCK, N_BLOCK, K_BLOCK], + ) + + pid_m, pid_n, pid_k = tl.program_id(0), tl.program_id(1), tl.program_id(2) + offs = pid_m * M_BLOCK, pid_n * N_BLOCK, pid_k * K_BLOCK + + idx_m = offs[0] + tl.arange(0, M_BLOCK)[:, None, None] + idx_n = offs[1] + tl.arange(0, N_BLOCK)[None, :, None] + idx_k = offs[2] + tl.arange(0, K_BLOCK)[None, None, :] + idx = idx_m * N * K + idx_n * K + idx_k + mask = (idx_m < M) & (idx_n < N) & (idx_k < K) + block = tl.load(a_ptr + idx, mask) + + desc.store(offs, block) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + inp = to_triton(numpy_random((10, 50, 119), dtype_str), device="cuda", dst_type=dtype_str) + + if K_BLOCK * inp.element_size() < 32: + return pytest.skip("Invalid last dim size") + + M_BLOCK, N_BLOCK = 8, 8 + out = inp.new_empty((10, 64, 128)) + + grid = tuple(triton.cdiv(size, block) for size, block in zip(inp.shape, (M_BLOCK, N_BLOCK, K_BLOCK))) + kernel[grid](out, inp, *inp.shape, *out.stride(), M_BLOCK, N_BLOCK, K_BLOCK) + + expect = unwrap_tensor(inp) + actual = unwrap_tensor(out)[:, :50, :119] + torch.testing.assert_close(expect, actual) + + +@requires_tma +@pytest.mark.parametrize("dtype_str", tma_dtypes) +@pytest.mark.parametrize("ndim", [2, 3, 4, 5]) +@pytest.mark.parametrize("INNER_BLOCK", [16, 32, 64, 128]) +def test_tensor_descriptor_load_nd(dtype_str, ndim, INNER_BLOCK): + + @triton.jit + def kernel(out_ptr, a_ptr, shape, strides, BLOCK_SHAPE): + desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=shape, + strides=strides, + block_shape=BLOCK_SHAPE, + ) + ndim: tl.constexpr = len(BLOCK_SHAPE) + + offs = (0, ) * ndim + block = desc.load(offs) + + idx = tl.full(BLOCK_SHAPE, 0, tl.int32) + stride = 1 + for k in tl.static_range(ndim - 1, -1, -1): + arange = tl.arange(0, BLOCK_SHAPE[k]) + for _ in tl.static_range(k): + arange = tl.expand_dims(arange, 0) + for _ in tl.static_range(k + 1, ndim): + arange = tl.expand_dims(arange, -1) + + idx += arange * stride + stride *= BLOCK_SHAPE[k] + + tl.store(out_ptr + idx, block) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + alloc_shape = (1, 1, 3, 7, INNER_BLOCK)[-ndim:] + inp = to_triton(numpy_random(alloc_shape, dtype_str), device="cuda", dst_type=dtype_str) + inp.data = inp.data[..., :INNER_BLOCK - 3] + + if INNER_BLOCK * inp.element_size() < 32: + return pytest.skip("Invalid last dim size") + + BLOCK_SHAPE = (2, 2, 4, 8, INNER_BLOCK)[-ndim:] + out = inp.new_empty(BLOCK_SHAPE) + + constexpr_block_shape = tuple(tl.constexpr(v) for v in BLOCK_SHAPE) + kernel[(1, )](out, inp, inp.shape, inp.stride(), constexpr_block_shape) + + # Check in-bounds + actual = unwrap_tensor(out) + expect = unwrap_tensor(inp) + idx = [slice(None, s) for s in inp.shape] + torch.testing.assert_close(expect, actual[idx]) + + # Check out-of-bounds + actual[idx].zero_() + expect = expect.new_zeros(BLOCK_SHAPE) + torch.testing.assert_close(expect, actual) + + +@requires_tma +@pytest.mark.parametrize("dtype_str", tma_dtypes) +@pytest.mark.parametrize("ndim", [2, 3, 4, 5]) +@pytest.mark.parametrize("INNER_BLOCK", [16, 32, 64, 128]) +def test_tensor_descriptor_store_nd(dtype_str, ndim, INNER_BLOCK): + + @triton.jit + def kernel(out_ptr, a_ptr, shape, strides, BLOCK_SHAPE): + desc = tl._experimental_make_tensor_descriptor( + out_ptr, + shape=shape, + strides=strides, + block_shape=BLOCK_SHAPE, + ) + ndim: tl.constexpr = len(BLOCK_SHAPE) + + idx = tl.full(BLOCK_SHAPE, 0, tl.int32) + stride = 1 + for k in tl.static_range(ndim - 1, -1, -1): + arange = tl.arange(0, BLOCK_SHAPE[k]) + for _ in tl.static_range(k): + arange = tl.expand_dims(arange, 0) + for _ in tl.static_range(k + 1, ndim): + arange = tl.expand_dims(arange, -1) + + idx += arange * stride + stride *= BLOCK_SHAPE[k] + + block = tl.load(a_ptr + idx) + + offs = (0, ) * ndim + desc.store(offs, block) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + BLOCK_SHAPE = (2, 2, 4, 8, INNER_BLOCK)[-ndim:] + inp = to_triton(numpy_random(BLOCK_SHAPE, dtype_str), device="cuda", dst_type=dtype_str) + + if INNER_BLOCK * inp.element_size() < 32: + return pytest.skip("Invalid last dim size") + + out = inp.new_empty(BLOCK_SHAPE) + out.data.fill_(-1) + + desc_shape = (1, 1, 3, 7, INNER_BLOCK)[-ndim:] + constexpr_block_shape = tuple(tl.constexpr(v) for v in BLOCK_SHAPE) + kernel[(1, )](out, inp, desc_shape, out.stride(), constexpr_block_shape) + + # Check in-bounds + actual = unwrap_tensor(out) + expect = unwrap_tensor(inp) + idx = [slice(None, s) for s in desc_shape] + torch.testing.assert_close(expect[idx], actual[idx]) + + # Check out-of-bounds + actual[idx].fill_(-1) + expect = expect.new_full(BLOCK_SHAPE, -1) + torch.testing.assert_close(expect, actual) + + +@triton.jit(noinline=True) +def tensor_descriptor_in_function_helper(out_ptr, in_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + in_desc = tl._experimental_make_tensor_descriptor( + in_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + out_desc = tl._experimental_make_tensor_descriptor( + out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + value = in_desc.load([moffset, noffset]) + out_desc.store([moffset, noffset], value.abs()) + + +@requires_tma +@pytest.mark.interpreter +def test_tensor_descriptor_in_function(): + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + tensor_descriptor_in_function_helper(out_ptr, a_ptr, M, N, M_BLOCK, N_BLOCK) + + M, N = 32, 128 + inp = torch.randn((M, N), device="cuda") + + M_BLOCK = 8 + N_BLOCK = 32 + out = inp.new_empty((M, N)) + + grid_m = M // M_BLOCK + grid_n = N // N_BLOCK + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert size == 2 * 128 * (grid_m * grid_n) + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + expect = inp.abs() + kernel[(grid_m, grid_n)](out, inp, M, N, M_BLOCK, N_BLOCK) + torch.testing.assert_close(expect, out) + + +@triton.jit(noinline=True) +def tensor_descriptor_return_helper(ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + return tl._experimental_make_tensor_descriptor( + ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + +@requires_tma +@pytest.mark.interpreter +def test_tensor_descriptor_return_value(): + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + in_desc = tensor_descriptor_return_helper(a_ptr, M, N, M_BLOCK, N_BLOCK) + out_desc = tensor_descriptor_return_helper(out_ptr, M, N, M_BLOCK, N_BLOCK) + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + value = in_desc.load([moffset, noffset]) + out_desc.store([moffset, noffset], value.abs()) + + M, N = 32, 128 + inp = torch.randn((M, N), device="cuda") + + M_BLOCK = 8 + N_BLOCK = 32 + out = inp.new_zeros((M, N)) + + def alloc_fn(size: int, align: int, stream: Optional[int]) -> torch.Tensor: + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + expect = inp.abs() + kernel[(M // M_BLOCK, N // N_BLOCK)](out, inp, M, N, M_BLOCK, N_BLOCK) + torch.testing.assert_close(expect, out) + + +@triton.jit +def matmul_kernel_make_tensor_desciptor(a_ptr, b_ptr, c_ptr, # + M, N, K, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + ): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + offs_k = 0 + + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + b_desc = tl._experimental_make_tensor_descriptor( + b_ptr, + shape=[K, N], + strides=[N, 1], + block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N], + ) + c_desc = tl._experimental_make_tensor_descriptor( + c_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_k, offs_bn]) + accumulator = tl.dot(a, b, acc=accumulator) + offs_k += BLOCK_SIZE_K + accumulator = accumulator.to(a_desc.dtype) + c_desc.store([offs_am, offs_bn], accumulator) + + +@requires_tma +@pytest.mark.interpreter +@pytest.mark.parametrize("num_stages", [1, 4]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) +def test_experimental_make_tensor_descriptor_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K): + device = "cuda" + if is_interpreter(): + M, N, K = BLOCK_M, BLOCK_N, BLOCK_K + else: + M, N, K = 8192, 8192, 1024 + torch.manual_seed(42) + A = torch.randn((M, K), dtype=torch.float16, device=device) + B = torch.randn((K, N), dtype=torch.float16, device=device) + C = torch.empty((M, N), dtype=torch.float16, device=device) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert size == 3 * 128 * grid[0] + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + kernel = matmul_kernel_make_tensor_desciptor[grid]( + A, + B, + C, + M, + N, + K, + BLOCK_M, + BLOCK_N, + BLOCK_K, + num_warps=8, + num_stages=num_stages, + ) + ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16) + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + if is_interpreter(): + return + + assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"] + if BLOCK_M >= 64 and BLOCK_N >= 64 and torch.cuda.get_device_capability()[0] == 9: + # TODO: The use of stmatrix for Blackwell is currently not supported. + # Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4. + assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] + + +@triton.jit +def kernel_make_tensor_desciptor_loop_carried(a_ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): + # Test that descriptors work with + pid = tl.program_id(0) + moffset = MBLOCK * pid + + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[MBLOCK, NBLOCK], + ) + + for i in range(0, N, NBLOCK): + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + if i % (3 * NBLOCK) == 0: + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[MBLOCK, NBLOCK], + ) + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + a = a_desc.load([moffset, i]) + a_desc.store([moffset, i], a + 10) + + n = 0 + while n < N: + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + if n % (3 * NBLOCK) == 0: + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[MBLOCK, NBLOCK], + ) + assert isinstance(a_desc, tl._experimental_tensor_descriptor) + a = a_desc.load([moffset, n]) + a_desc.store([moffset, n], a + 5) + + n += NBLOCK + + +@requires_tma +@pytest.mark.interpreter +def test_experimental_make_tensor_descriptor_loop_carried(): + device = "cuda" + M, N = 64, 512 + torch.manual_seed(42) + A = torch.randn((M, N), dtype=torch.float32, device=device) + MBLOCK, NBLOCK = 8, 128 + grid = (triton.cdiv(M, MBLOCK), ) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert size == 128 * grid[0] + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + ref_out = A + 15 + kernel = kernel_make_tensor_desciptor_loop_carried[grid]( + A, + M, + N, + MBLOCK, + NBLOCK, + ) + torch.testing.assert_close(ref_out, A) + if not is_interpreter(): + assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm[ + "ptx"] + + +@triton.jit +def batched_gemm_2d_tma_kernel(a_ptr, b_ptr, c_ptr, # + B, M, N, K, # + dtype: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SMS: tl.constexpr): + start_pid = tl.program_id(axis=0) + num_tiles_m = tl.cdiv(M, BLOCK_M) + num_tiles_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles_per_batch = num_tiles_m * num_tiles_n + num_tiles = B * num_tiles_per_batch + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + tile_m = 0 + tile_n = 0 + tile_b = 0 + + offs_m = 0 + offs_n = 0 + offs_b = 0 + + a_desc = tl._experimental_make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], [BLOCK_M, BLOCK_K]) + b_desc = tl._experimental_make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], [BLOCK_N, BLOCK_K]) + c_desc = tl._experimental_make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], [BLOCK_M, BLOCK_N]) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for _ in range(k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + tile_b = tile_id // num_tiles_per_batch + tile_m = (tile_id // num_tiles_n) % num_tiles_m + tile_n = tile_id % num_tiles_n + + offs_b = tile_b + offs_m = tile_m * BLOCK_M + offs_n = tile_n * BLOCK_N + + a_desc = tl._experimental_make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], + [BLOCK_M, BLOCK_K]) + b_desc = tl._experimental_make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], + [BLOCK_N, BLOCK_K]) + c_desc = tl._experimental_make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], + [BLOCK_M, BLOCK_N]) + + offs_k = ki * BLOCK_K + + a = a_desc.load([offs_m, offs_k]) + b = b_desc.load([offs_n, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + if ki == k_tiles - 1: + c = accumulator.to(dtype) + + c_desc.store([offs_m, offs_n], c) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@requires_tma +@pytest.mark.interpreter +def test_tensor_descriptor_batched_gemm_2d_tma(): + device = "cuda" + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64 + if is_interpreter(): + B, M, N, K = 2, BLOCK_M, BLOCK_N, BLOCK_K + else: + B, M, N, K = 2, 1024, 1024, 128 + NUM_SMS = 96 + num_stages = 3 + + grid = (min(NUM_SMS, B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), ) + + a = torch.randn((B, M, K), device=device, dtype=torch.float16) + b = torch.randn((B, N, K), device=device, dtype=torch.float16) + c = torch.empty((B, M, N), device=device, dtype=torch.float16) + + expect = torch.bmm(a, b.mT) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + # TODO: should only need num_stages * 3 descriptors per SM + assert size == 128 * 3 * (num_stages + 1) * grid[0] + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + batched_gemm_2d_tma_kernel[grid]( + a, b, c, # + B, M, N, K, # + tl.float16, # + BLOCK_M, BLOCK_N, BLOCK_K, # + NUM_SMS, # + num_stages=num_stages, num_warps=8) + torch.cuda.synchronize() + + torch.testing.assert_close(c, expect, rtol=1e-3, atol=1e-3) + + +@triton.jit +def batched_gemm_3d_tma_kernel(a_ptr, b_ptr, c_ptr, # + B, M, N, K, # + dtype: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SMS: tl.constexpr): + start_pid = tl.program_id(axis=0) + num_tiles_m = tl.cdiv(M, BLOCK_M) + num_tiles_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles_per_batch = num_tiles_m * num_tiles_n + num_tiles = B * num_tiles_per_batch + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + tile_m = 0 + tile_n = 0 + tile_b = 0 + + offs_m = 0 + offs_n = 0 + offs_b = 0 + + a_desc = tl._experimental_make_tensor_descriptor(a_ptr, [B, M, K], [K * M, K, 1], [1, BLOCK_M, BLOCK_K]) + b_desc = tl._experimental_make_tensor_descriptor(b_ptr, [B, N, K], [N * K, K, 1], [1, BLOCK_N, BLOCK_K]) + c_desc = tl._experimental_make_tensor_descriptor(c_ptr, [B, M, N], [M * N, N, 1], [1, BLOCK_M, BLOCK_N]) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for _ in range(k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + tile_b = tile_id // num_tiles_per_batch + tile_m = (tile_id // num_tiles_n) % num_tiles_m + tile_n = tile_id % num_tiles_n + + offs_b = tile_b + offs_m = tile_m * BLOCK_M + offs_n = tile_n * BLOCK_N + + offs_k = ki * BLOCK_K + + a = a_desc.load([offs_b, offs_m, offs_k]).reshape([BLOCK_M, BLOCK_K]) + b = b_desc.load([offs_b, offs_n, offs_k]).reshape([BLOCK_N, BLOCK_K]) + accumulator = tl.dot(a, b.T, accumulator) + + if ki == k_tiles - 1: + c = accumulator.to(dtype) + + c_desc.store([offs_b, offs_m, offs_n], c.reshape((1, BLOCK_M, BLOCK_N))) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@requires_tma +@pytest.mark.interpreter +def test_tensor_descriptor_batched_gemm_3d_tma(): + device = "cuda" + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64 + if is_interpreter(): + B, M, N, K = 2, BLOCK_M, BLOCK_N, BLOCK_K + else: + B, M, N, K = 2, 1024, 1024, 128 + NUM_SMS = 96 + num_stages = 3 + + grid = (min(NUM_SMS, B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), ) + + a = torch.randn((B, M, K), device=device, dtype=torch.float16) + b = torch.randn((B, N, K), device=device, dtype=torch.float16) + c = torch.empty((B, M, N), device=device, dtype=torch.float16) + + expect = torch.bmm(a, b.mT) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + # TODO: should only need num_stages * 3 descriptors per SM + assert size == 128 * 3 * grid[0] + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + h = batched_gemm_3d_tma_kernel[grid]( + a, b, c, # + B, M, N, K, # + tl.float16, # + BLOCK_M, BLOCK_N, BLOCK_K, # + NUM_SMS, # + num_stages=num_stages, num_warps=8) + torch.cuda.synchronize() + + if not is_interpreter(): + capability = torch.cuda.get_device_capability(0)[0] + dot_op = {9: "warp_group_dot", 10: "tc_gen5_mma"} + assert dot_op[capability] in h.asm["ttgir"] + + torch.testing.assert_close(c, expect, rtol=1e-3, atol=1e-3) + + +@triton.jit +def tma_gather_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl.constexpr, BLOCK_X: tl.constexpr, + BLOCK_Y: tl.constexpr): + idx = tl.load(idx_ptr + tl.arange(0, BLOCK_X)) + desc = tl._experimental_make_tensor_descriptor(in_ptr, [X, Y], [Y, 1], [1, BLOCK_Y]) + out = desc.gather(idx, y) + tl.store(out_ptr + tl.arange(0, BLOCK_X)[:, None] * BLOCK_Y + tl.arange(0, BLOCK_Y)[None, :], out) + + +def torch_gather_rows(input, idx, y, block_y): + out = torch.empty(0, device=input.device, dtype=input.dtype) + for i in idx: + x = input[i][y:y + block_y] + out = torch.cat((out, x.reshape(1, x.shape[0])), dim=0) + return out + + +@pytest.mark.interpreter +@pytest.mark.parametrize("X, Y", [(128, 128), (64, 256)]) +@pytest.mark.parametrize("BLOCK_X, BLOCK_Y", [(32, 32), (64, 128), (16, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8]) +@pytest.mark.parametrize("y", [0, 32, 48]) +@pytest.mark.skipif(not is_interpreter() and torch.cuda.get_device_capability()[0] != 10, + reason="TMA Gather only works on cloud Blackwell Chips") +def test_tma_gather(X, Y, BLOCK_X, BLOCK_Y, dtype, y, device): + if BLOCK_X > X or y + BLOCK_Y > Y: + pytest.skip() + + torch.manual_seed(42) + if dtype != torch.int8: + input = torch.rand((X, Y), dtype=dtype, device=device) + else: + input = torch.arange(X * Y, dtype=dtype, device=device).reshape(X, Y) + output = torch.empty((BLOCK_X, BLOCK_Y), dtype=dtype, device=device) + + idx = torch.randint(BLOCK_X, (BLOCK_X, ), dtype=torch.int32, device=device) + + def alloc_fn(size: int, align: int, steam): + return torch.empty(size, dtype=torch.int8, device=device) + + triton.set_allocator(alloc_fn) + + tma_gather_rows_kernel[(1, )](output, input, idx, y, X, Y, BLOCK_X, BLOCK_Y) + + ref = torch_gather_rows(input, idx, y, BLOCK_Y) + torch.testing.assert_close(ref, output, atol=0, rtol=0) + + +@triton.jit +def tma_gather_dot_pipeline( # + a_ptr, b_ptr, output_ptr, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + K: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # +): + a_desc = tl._experimental_make_tensor_descriptor(a_ptr, [BLOCK_M, K], [K, 1], [1, BLOCK_K]) + b_desc = tl._experimental_make_tensor_descriptor(b_ptr, [K, BLOCK_N], [BLOCK_N, 1], [1, BLOCK_N]) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty) + for k in range(0, K, BLOCK_K): + a = a_desc.gather(tl.arange(0, BLOCK_M), k) + b = b_desc.gather(tl.arange(0, BLOCK_K) + k, 0) + accumulator = tl.dot(a, b, acc=accumulator) + + offs_cm = tl.arange(0, BLOCK_M) + offs_cn = tl.arange(0, BLOCK_N) + output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(output_ptrs, accumulator) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(16, 16, 16)]) +@pytest.mark.parametrize("K", [128]) +@pytest.mark.skipif(not is_interpreter() and torch.cuda.get_device_capability()[0] != 10, + reason="TMA Gather only works on cloud Blackwell Chips") +def test_tma_gather_dot_pipeline(BLOCK_M, BLOCK_N, BLOCK_K, K, device): + + def alloc_fn(size: int, align: int, steam): + return torch.empty(size, dtype=torch.int8, device=device) + + triton.set_allocator(alloc_fn) + + a = torch.arange(BLOCK_M * K, device=device).reshape(BLOCK_M, K).float() + b = torch.arange(K * BLOCK_N, device=device).reshape(K, BLOCK_N).float() + + c = a @ b + + output = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float32, device=device) + if not is_interpreter(): + kernel = tma_gather_dot_pipeline.warmup(a, b, output, a.stride(0), a.stride(1), b.stride(0), b.stride(1), + output.stride(0), output.stride(1), K, BLOCK_M, BLOCK_N, BLOCK_K, + grid=(1, )) + assert kernel.asm["ttgir"].count("ttng.async_tma_gather") == 6 + tma_gather_dot_pipeline[(1, 1, 1)](a, b, output, a.stride(0), a.stride(1), b.stride(0), b.stride(1), + output.stride(0), output.stride(1), K, BLOCK_M, BLOCK_N, BLOCK_K) + + torch.testing.assert_close(c, output) + + +def torch_scatter_rows(input, idx, y, block_y, X, Y): + out = torch.zeros((X, Y), dtype=input.dtype, device=input.device) + for i, j in enumerate(idx): + out[j][y:y + block_y] = input[i] + return out + + +@triton.jit +def tma_scatter_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl.constexpr, BLOCK_X: tl.constexpr, + BLOCK_Y: tl.constexpr): + idx = tl.load(idx_ptr + tl.arange(0, BLOCK_X)) + data = tl.load(in_ptr + tl.arange(0, BLOCK_X)[:, None] * BLOCK_Y + tl.arange(0, BLOCK_Y)[None, :]) + desc = tl._experimental_make_tensor_descriptor(out_ptr, [X, Y], [Y, 1], [1, BLOCK_Y]) + desc.scatter(data, idx, y) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("X, Y", [(128, 128), (64, 256)]) +@pytest.mark.parametrize("BLOCK_X, BLOCK_Y", [(32, 32), (64, 128), (16, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8]) +@pytest.mark.parametrize("y", [0, 32, 48]) +@pytest.mark.skipif(not is_interpreter() and torch.cuda.get_device_capability()[0] != 10, + reason="TMA Gather only works on cloud Blackwell Chips") +def test_tma_scatter(X, Y, BLOCK_X, BLOCK_Y, dtype, y): + if BLOCK_X > X or y + BLOCK_Y > Y: + pytest.skip() + + torch.manual_seed(42) + input = torch.arange(BLOCK_X * BLOCK_Y, dtype=dtype, device='cuda').reshape(BLOCK_X, BLOCK_Y) + output = torch.zeros((X, Y), dtype=dtype, device='cuda') + + idx = torch.randperm(BLOCK_X, dtype=torch.int32, device='cuda') + + def alloc_fn(size: int, align: int, steam): + return torch.empty(size, dtype=torch.int8, device='cuda') + + triton.set_allocator(alloc_fn) + + tma_scatter_rows_kernel[(1, )](output, input, idx, y, X, Y, BLOCK_X, BLOCK_Y) + + ref = torch_scatter_rows(input, idx, y, BLOCK_Y, X, Y) + torch.testing.assert_close(ref, output, atol=0, rtol=0) diff --git a/third_party/enflame/include/triton/python/test/unit/cuda/test_flashattention.py b/third_party/enflame/include/triton/python/test/unit/cuda/test_flashattention.py new file mode 100644 index 000000000..5053cfc4b --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/cuda/test_flashattention.py @@ -0,0 +1,465 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) +""" + +# import numpy as np +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel(Q, K, V, sm_scale, # + L, M, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, D0, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + + # TODO: may replace with TMA store without range offset + # initialize offsets for store + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_prev = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + stride_qh_2d = stride_qh // stride_qm // stride_qk + + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(D0, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(off_hz * stride_qh_2d, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(D0, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(off_hz * stride_qh_2d, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + out_tile_ptr = tl.make_block_ptr( + base=Out, + shape=(D0, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # load q: it will stay in SRAM throughout + q = tl.load(q_tile_ptr) + + # loop over k, v and update accumulators + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + # -- compute qk ---- + k = tl.load(k_tile_ptr, boundary_check=(0, 1)) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + # compute new m + m_curr = tl.maximum(tl.max(qk, 1), m_prev) + # correct old l + l_prev *= tl.exp(m_prev - m_curr) + # attention weights + p = tl.exp(qk - m_curr[:, None]) + l_curr = tl.sum(p, 1) + l_prev + # rescale operands of matmuls + l_rcp = 1. / l_curr + p *= l_rcp[:, None] + acc *= (l_prev * l_rcp)[:, None] + # update acc + p = p.to(tl.float16) + v = tl.load(v_tile_ptr, boundary_check=(0, 1)) + acc += tl.dot(p, v) + # update m_i and l_i + l_prev = l_curr + m_prev = m_curr + # update pointers + k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_N, 0]) + v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_N, 0]) + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_prev) + tl.store(m_ptrs, m_prev) + + acc = acc.to(tl.float16) + tl.store(out_tile_ptr, acc, boundary_check=(0, 1)) + + +@triton.jit +def _bwd_preprocess(Out, DO, L, # + NewDO, Delta, # + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel(Q, K, V, sm_scale, Out, DO, # + DQ, DK, DV, # + L, M, # + D, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + Z, H, N_CTX, D0, # + num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # init tile_ptr + stride_qz_2d = stride_qz // stride_qm // stride_qk + stride_qh_2d = stride_qh // stride_qm // stride_qk + + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(D0, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(D0, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + do_tile_ptr = tl.make_block_ptr( + base=DO, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dq_tile_ptr = tl.make_block_ptr( + base=DQ, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dk_tile_ptr = tl.make_block_ptr( + base=DK, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dv_tile_ptr = tl.make_block_ptr( + base=DV, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # offset pointers for batch/head + DQ += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_tile_ptr, boundary_check=(0, 1)) + v = tl.load(v_tile_ptr, boundary_check=(0, 1)) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_tile_ptr, boundary_check=(0, 1)) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, tl.trans(k)) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_tile_ptr, boundary_check=(0, 1)) + dv += tl.dot(tl.trans(p.to(tl.float16)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, tl.trans(v)) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(tl.float16)), q) + # compute dq + dq = tl.load(dq_tile_ptr) + dq += tl.dot(ds.to(tl.float16), k) + tl.store(dq_tile_ptr, dq) + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_tile_ptr = tl.advance(q_tile_ptr, [BLOCK_M, 0]) + do_tile_ptr = tl.advance(do_tile_ptr, [BLOCK_M, 0]) + dq_tile_ptr = tl.advance(dq_tile_ptr, [BLOCK_M, 0]) + q_tile_ptr = tl.advance(q_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0]) + do_tile_ptr = tl.advance(do_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0]) + dq_tile_ptr = tl.advance(dq_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0]) + # increment tile pointers + k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_M, 0]) + v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_M, 0]) + # write-back + tl.store(dv_tile_ptr, dv.to(tl.float16), boundary_check=(0, 1)) + tl.store(dk_tile_ptr, dk.to(tl.float16), boundary_check=(0, 1)) + dv_tile_ptr = tl.advance(dv_tile_ptr, [BLOCK_M, 0]) + dk_tile_ptr = tl.advance(dk_tile_ptr, [BLOCK_M, 0]) + + +empty = torch.empty(128, device="cuda") + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + D0 = q.shape[0] * q.shape[1] * q.shape[2] + _fwd_kernel[grid]( + q, k, v, sm_scale, # + L, m, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], D0, # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, # + num_warps=num_warps, num_stages=2) + + ctx.save_for_backward(q, k, v, o, L, m) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + return o + + @staticmethod + def backward(ctx, do): + BLOCK = 128 + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + D0 = q.shape[0] * q.shape[1] * q.shape[2] + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, do, l, # + do_scaled, delta, # + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL) + _bwd_kernel[(ctx.grid[1], )]( + q, k, v, ctx.sm_scale, # + o, do_scaled, # + dq, dk, dv, # + l, m, # + delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], D0, # + ctx.grid[0], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + num_warps=8, num_stages=1) + return dq, dk, dv, None + + +attention = _attention.apply + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 128, 64), + (4, 48, 256, 64), + (4, 48, 512, 64), + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + # (4, 48, 8192, 64), out of memory +]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") +def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() + sm_scale = 0.2 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + for z in range(Z): + for h in range(H): + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + # p = torch.exp(p) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # triton implementation + tri_out = attention(q, k, v, sm_scale) + # print(ref_out) + # print(tri_out) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=0) + torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0) + torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=0) + + +try: + from flash_attn.flash_attn_interface import flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +# vary seq length for fixed head and batch=4 +configs = [ + triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(10, 14)], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', + args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + }, + ) for mode in ['fwd', 'bwd'] +] + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, sm_scale) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn) + return ms + if provider == "flash": + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn) + return ms + + +# only works on post-Ampere GPUs right now +# bench_flash_attention.run(save_path='.', print_data=True) diff --git a/third_party/enflame/include/triton/python/test/unit/cuda/test_gemm.py b/third_party/enflame/include/triton/python/test/unit/cuda/test_gemm.py new file mode 100644 index 000000000..cf5ae6e55 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/cuda/test_gemm.py @@ -0,0 +1,463 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import itertools +import os +import re + +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@triton.jit +def matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr # + ): + a_block_ptr = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), + order=(1, 0), + ) + b_block_ptr = tl.make_block_ptr( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), + order=(0, 1), + ) + a = tl.load(a_block_ptr) + b = tl.load(b_block_ptr) + + c = tl.dot(a, b) + + if FLOAT16_OUTPUT: + c = c.to(tl.float16) + + if USE_TMA_EPILOGUE: + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + tl.store(c_block_ptr, c) + else: + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, c) + + +@pytest.mark.parametrize( + 'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE', + itertools.chain(*[[ + # numCTAs = 1, no TMA multicast: + [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], + [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], + [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], + [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], + [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + # static mask, cluster 4x1 + [256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], + [256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], + # dynamic mask, cluster 2x2 + [128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], + [128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], + # small M, N + [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + ] for USE_TMA_EPILOGUE in [True, False]])) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE): + if is_hip() and NUM_CTAS > 1: + pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") + + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + if OUTPUT_TYPE == "float16": + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + else: + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + + matmul_no_scf_kernel[(1, 1)]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # + num_warps=NUM_WARPS, # + num_ctas=NUM_CTAS, # + FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # + USE_TMA_EPILOGUE=USE_TMA_EPILOGUE) + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + golden = torch.matmul(a_f32, b_f32) + torch.set_printoptions(profile="full") + assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) + + +@triton.jit +def matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_wm, stride_wn, # + stride_zm, stride_zn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, # + out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, # + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, # + DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, # + A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, # + B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, # + W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, # + Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr # + ): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + block_offset_m = pid_m * BLOCK_M + block_offset_n = pid_n * BLOCK_N + + a_tile_ptr = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), + block_shape=(BLOCK_M, BLOCK_K), + order=(A_ORDER_0, A_ORDER_1), + ) + b_tile_ptr = tl.make_block_ptr( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), + block_shape=(BLOCK_K, BLOCK_N), + order=(B_ORDER_0, B_ORDER_1), + ) + # for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix + w_tile_ptr = tl.make_block_ptr( + base=w_ptr, + shape=(N, N), + strides=(stride_wm, stride_wn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_N), + order=(W_ORDER_0, W_ORDER_1), + ) + z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + offs_m = block_offset_m + tl.arange(0, BLOCK_M) + offs_n = block_offset_n + tl.arange(0, BLOCK_N) + z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn + bias_ptrs = bias_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn + mask = (offs_m < M)[:, None] & (offs_n < N)[None, :] + + for k in range(0, K, BLOCK_K): + a = tl.load(a_tile_ptr, boundary_check=(0, 1)) + b = tl.load(b_tile_ptr, boundary_check=(0, 1)) + z += tl.dot(a, b) + a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0]) + + z = z.to(out_dtype) + + if ADD_MATRIX: + z += tl.load(bias_ptrs, mask=mask) + if ADD_ROWS: + ZRs = bias_ptr + offs_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = bias_ptr + offs_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z.to(tl.float32)).to(max.dtype) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + w = tl.load(w_tile_ptr) + z = tl.dot(z.to(w.dtype), w) + z = z.to(out_dtype) + + if USE_TMA_STORE: + z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn), + offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), + order=(Z_ORDER_0, Z_ORDER_1)) + tl.store(z_block_ptr, z, boundary_check=(0, 1)) + else: + tl.store(z_ptrs, z, mask=mask) + + +@pytest.mark.parametrize( + 'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', + [ + # corner shapes + (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) + for shape_w_c in [ + [4096, 1, 1024, False, False, True], + [2048, 204, 1000, True, False, True], + [4096, 1, 1024, False, False, False], + [2048, 204, 1000, True, False, False], + ] + for out_dtype in ['float16', 'float32'] # + for use_tma_store in [False, True] # + ] + [ + # softmax epilogue + (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ + [64, 64, 16, 4, 1, 64, 64, 64], + [128, 128, 64, 4, 1, None, None, None], + [16, 16, 64, 4, 1, 16, 16, 64], + [64, 64, 32, 8, 1, 64, 64, 64], + [128, 128, 64, 4, 1, 128, 128, 128], + ] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for + trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] + ] + [ + # loop over epilogues besides of softmax + (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ + [64, 64, 16, 4, 1, 128, 128, 64], + *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], + # for chain-dot + [128, 128, 64, 4, 1, None, None, None], + [64, 64, 16, 4, 1, None, None, None], + # small BLOCK_M and BLOCK_K + [16, 16, 64, 4, 1, 128, 128, 64], + *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], + # repeat + [64, 64, 32, 8, 1, 128, 256, 64], + [64, 64, 16, 8, 2, 128, 128, 64], + # irregular shape + [128, 128, 64, 4, 1, 500, 200, 128], + [128, 128, 64, 4, 2, 513, 193, 192], + ] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' + ] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in + [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( + epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) + ] + [ + # loop over tile shapes and transpose combinations + (*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ + [64, 64, 32, 4, 1, 128, 256, 64], + [128, 128, 16, 4, 4, 512, 256, 64], + [128, 256, 32, 4, 8, 256, 256, 192], + [512, 256, 32, 4, 8, 1024, 256, 192], + # BLOCK_K >= 128 + [64, 128, 128, 4, 1, 512, 256, 256], + [128, 128, 128, 4, 1, 256, 256, 192], + [128, 128, 128, 4, 2, 256, 256, 192], + # small BLOCK_M and BLOCK_K + [16, 32, 32, 4, 1, 128, 256, 64], + [32, 32, 16, 4, 1, 256, 256, 192], + [16, 32, 64, 4, 4, 512, 256, 64], + ] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in + [False, True] for trans_output in [False, True] for num_stages in [3] + ] + [ + # loop over instr shapes & pipeline stages + (64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) + for n in [16, 32, 64, 128, 256] + for trans_output in [False] + for out_dtype in ['float32'] + for use_tma_store in [False] + for num_stages in [2, 4, 5, 7] + ] + [ + # irregular shapes + (*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ + [128, 128, 64, 4, 1], + [256, 128, 64, 4, 2], + [128, 128, 128, 4, 2], + ] for shape in [ + [512, 360, 1024], + [360, 4096, 512], + ] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in + [3, 4] + ]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, + out_dtype, USE_TMA_STORE, NUM_STAGES): + if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ + '16-32-64-4-4-512-256-64-True-False', + '16-32-64-4-4-512-256-64-True-True', + '16-32-64-4-4-512-256-64-False-False', + '16-32-64-4-4-512-256-64-False-True', + ]: + pytest.skip('shapePerCTA[1] < 16 not supported') + + if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ + '16-32-64-4-1-256-256-256-False', + '16-32-64-4-2-256-256-256-False', + '16-32-64-4-2-256-256-256-True', + '16-32-64-8-2-256-256-256-False', + '16-32-64-8-2-256-256-256-True', + ]: + pytest.skip('Known legacy issue, ldmatrix can only support x4') + + if is_hip() and NUM_CTAS > 1: + pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") + + if epilogue == 'add-rows' and NUM_CTAS > 1: + pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') + + M = BLOCK_M if M is None else M + N = BLOCK_N if N is None else N + K = BLOCK_K if K is None else K + + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + a_order = [0, 1] + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + a_order = [1, 0] + + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + b_order = [0, 1] + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + b_order = [1, 0] + + if out_dtype == 'float16' and epilogue != 'softmax': + # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will + # fail with the following error: 'llvm.fmul' op requires the same type + # for all operands and results + out_dtype = tl.float16 + torch_out_dtype = torch.float16 + else: + out_dtype = tl.float32 + torch_out_dtype = torch.float32 + + # avoid out of memory + if epilogue in ['add-matrix', 'add-rows', 'add-cols']: + if (TRANS_OUTPUT): + bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T + else: + bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) + else: + bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) + + # for chain-dot only + w = torch.randn((N, N), device='cuda', dtype=torch.float16).T + w_order = [0, 1] + + if (TRANS_OUTPUT): + z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T + z_order = [0, 1] + else: + z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) + z_order = [1, 0] + + # torch result + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + dot = torch.matmul(a_f32, b_f32) + + def process_epilogue(d, bias, w, epilogue): + if epilogue == 'add-matrix': + ref = d + bias + elif epilogue == 'add-rows': + ref = d + bias[:, 0][:, None] + elif epilogue == 'add-cols': + ref = d + bias[0, :][None, :] + elif epilogue == 'softmax': + num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) + denom = torch.sum(num, dim=-1, keepdims=True) + ref = num / denom + # ref = torch.softmax(d, 1) + elif epilogue == 'chain-dot': + ref = torch.matmul(d, w.to(torch.float32)) + else: + ref = d + return ref + + golden = process_epilogue(dot, bias, w, epilogue) + + def grid(META): + return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) + + pgm = matmul_kernel[grid]( + a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_wm=w.stride(0), stride_wn=w.stride(1), # + stride_zm=z.stride(0), stride_zn=z.stride(1), # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # + out_dtype=out_dtype, # + USE_TMA_STORE=USE_TMA_STORE, # + ADD_MATRIX=epilogue == 'add-matrix', # + ADD_ROWS=epilogue == 'add-rows', # + ADD_COLS=epilogue == 'add-cols', # + DO_SOFTMAX=epilogue == 'softmax', # + CHAIN_DOT=epilogue == 'chain-dot', # + A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # + B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # + W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # + Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) + + torch.set_printoptions(profile="full") + golden = torch.nn.functional.normalize(golden) + z = torch.nn.functional.normalize(z) + assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) + + # check is cuda backend specific + if is_hip(): + return + + disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() + if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: + is_tcgen5 = (torch.cuda.get_device_capability()[0] + == 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 + ptx = pgm.asm['ptx'] + if is_tcgen5: + assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) + else: + wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) + assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) diff --git a/third_party/enflame/include/triton/python/test/unit/cuda/test_gemm_fusion.py b/third_party/enflame/include/triton/python/test/unit/cuda/test_gemm_fusion.py new file mode 100644 index 000000000..bad5af09d --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/cuda/test_gemm_fusion.py @@ -0,0 +1,176 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def gemm_fusion_kernel(A, B, C, E, # + M, N, K, # + stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr): + pid = tl.program_id(0) + + a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) + c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) + e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + + acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32) + a = tl.load(a_tile_ptr) + for i in range(0, N, BLOCK_N): + b = tl.load(b_tile_ptr) + o_ab = tl.dot(a, tl.trans(b)) + c = tl.load(c_tile_ptr) + o_ab = o_ab.to(tl.float16) + acc_e += tl.dot(o_ab, c) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_N, 0]) + c_tile_ptr = tl.advance(c_tile_ptr, [BLOCK_N, 0]) + + acc_e = acc_e.to(tl.float16) + tl.store(e_tile_ptr, acc_e) + + +#TODO: blackwell mma pipeline regressed with https://github.com/openai/triton-private-blackwell/commit/5cc42002bc36fb94385481ed4dab178a143733be +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 9, reason="only works on hopper") +def test_gemm_fusion(): + M, N, K = 4096, 4096, 64 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64 + A = torch.empty((M, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + B = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + C = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + E = torch.empty((M, K), dtype=torch.float16, device='cuda') + ref_out = torch.matmul(torch.matmul(A, B.T), C) + num_warps = 4 + grid = (triton.cdiv(M, BLOCK_M), 1) + gemm_fusion_kernel[grid]( + A, B, C, E, M, N, K, # + A.stride(0), A.stride(1), # + B.stride(0), B.stride(1), # + C.stride(0), C.stride(1), # + E.stride(0), E.stride(1), # + BLOCK_M, BLOCK_N, BLOCK_K, # + num_warps=num_warps) + + torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0) + + +@triton.jit +def batched_gemm_fusion(Q, K, V, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, NH, N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_qz, stride_qh, stride_qm, stride_qk), + offsets=(off_hz // NH, off_hz % NH, start_m, 0), + block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_kz, stride_kh, stride_kn, stride_kk), + offsets=(off_hz // NH, off_hz % NH, 0, 0), + block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_vz, stride_vh, stride_vk, stride_vn), + offsets=(off_hz // NH, off_hz % NH, 0, 0), + block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + o_tile_ptr = tl.make_block_ptr( + base=Out, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_oz, stride_oh, stride_om, stride_on), + offsets=(off_hz // NH, off_hz % NH, start_m, 0), + block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + + q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3)) + q = tl.reshape(q, (BLOCK_M, BLOCK_DMODEL), can_reorder=True) + for i in range(0, N_CTX, BLOCK_N): + k = tl.load(k_tile_ptr, boundary_check=(0, 1, 2, 3)) + k = tl.reshape(k, (BLOCK_N, BLOCK_DMODEL), can_reorder=True) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + + p = qk.to(tl.float16) + v = tl.load(v_tile_ptr, boundary_check=(0, 1, 2, 3)) + v = tl.reshape(v, (BLOCK_N, BLOCK_DMODEL), can_reorder=True) + acc += tl.dot(p, v) + + k_tile_ptr = tl.advance(k_tile_ptr, [0, 0, BLOCK_N, 0]) + v_tile_ptr = tl.advance(v_tile_ptr, [0, 0, BLOCK_N, 0]) + + acc = tl.reshape(acc, (1, 1, BLOCK_M, BLOCK_DMODEL), can_reorder=True) + acc = acc.to(tl.float16) + tl.store(o_tile_ptr, acc) + + +@pytest.mark.skip(reason="don't support 4d across stack, left for future") +def test_batched_gemm_fusion(): + Z = 4 + NH = 48 + H = 64 + N_CTX = 2048 + BLOCK_M, BLOCK_N, BLOCK_DMODEL = 128, 128, H + torch.manual_seed(20) + A = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + B = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + C = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + E = torch.empty_like(A) + BT = B.transpose(-1, -2) + ref_out = torch.matmul(torch.matmul(A, BT), C) + num_warps = 4 + grid = (triton.cdiv(N_CTX, BLOCK_M), B * NH) + batched_gemm_fusion[grid]( + A, B, C, E, # + A.stride(0), A.stride(1), A.stride(2), A.stride(3), # + B.stride(0), B.stride(1), B.stride(2), B.stride(3), # + C.stride(0), C.stride(1), C.stride(2), C.stride(3), # + E.stride(0), E.stride(1), E.stride(2), E.stride(3), # + Z, NH, N_CTX, # + BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps) + + torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0) diff --git a/third_party/enflame/include/triton/python/test/unit/cuda/test_mixed_io.py b/third_party/enflame/include/triton/python/test/unit/cuda/test_mixed_io.py new file mode 100644 index 000000000..68ee474a4 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/cuda/test_mixed_io.py @@ -0,0 +1,81 @@ +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + +dtype_mapping = { + 'float16': torch.float16, + 'float32': torch.float32, +} + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x_block_ptr = tl.make_block_ptr(base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero') + + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + +@pytest.mark.parametrize('SIZE,BLOCK_SIZE,dtype_str', + [(98432, 1024, dtype_str) for dtype_str in ['float16', 'float32']]) +def test_add(SIZE, BLOCK_SIZE, dtype_str): + dtype = dtype_mapping[dtype_str] + output = torch.empty(SIZE, device='cuda', dtype=dtype) + x = torch.randn(SIZE, device='cuda', dtype=dtype) + y = torch.randn(SIZE, device='cuda', dtype=dtype) + + def grid(meta): + return (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + + add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE) + + output_torch = x + y + torch.set_printoptions(profile='full') + assert_close(output, output_torch, rtol=1e-2, atol=1e-3, check_dtype=False) + + +@triton.jit +def load_reduce_kernel( + x_ptr, + y_ptr, + stride_xm, + stride_xn, + stride_y, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x_ptr = tl.make_block_ptr(base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + x = tl.load(x_ptr) + y = tl.max(x, axis=1) + tl.store(y_ptr + tl.arange(0, BLOCK_M), y) + + +@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str', [(128, 64, dtype_str) for dtype_str in ['float16']]) +def test_load_reduce(BLOCK_M, BLOCK_N, dtype_str): + dtype = dtype_mapping[dtype_str] + x = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype) + y = torch.empty((BLOCK_M, ), device='cuda', dtype=dtype) + + load_reduce_kernel[(1, )](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N) + + golden = x.max(dim=1)[0] + torch.set_printoptions(profile='full') + assert_close(y, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git a/third_party/enflame/include/triton/python/test/unit/cuda/test_tma_descriptor.py b/third_party/enflame/include/triton/python/test/unit/cuda/test_tma_descriptor.py new file mode 100644 index 000000000..497248b6b --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/cuda/test_tma_descriptor.py @@ -0,0 +1,49 @@ +import pytest +import torch +from triton.tools.experimental_descriptor import create_1d_tma_descriptor, create_2d_tma_descriptor + + +@pytest.mark.parametrize("M, BLOCK_M, expect_error", [(128, 32, False), (127, 32, False), (128, 31, True)]) +def test_1d_tma_descriptor_exception(M, BLOCK_M, expect_error): + if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: + pytest.skip("Test requires Hopper target.") + return + + device = "cuda" + x = torch.randn(M, dtype=torch.float32, device=device) + # globalAddress in the tma descriptor must be aligned to 16 bytes for CU_TENSOR_MAP_INTERLEAVE_NONE. + # https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY + assert x.data_ptr() % 16 == 0 + is_error = False + + try: + create_1d_tma_descriptor(x.data_ptr(), M, BLOCK_M, x.element_size()) + except RuntimeError as e: + is_error = True + assert e.args[0] == "Triton Error [CUDA]: invalid argument" + + assert is_error == expect_error + + +@pytest.mark.parametrize("M, BLOCK_M", [(128, 32), (125, 33)]) +@pytest.mark.parametrize("N, BLOCK_N, expect_error", [(128, 32, False), (128, 30, True), (127, 32, True)]) +def test_2d_tma_descriptor_exception(M, N, BLOCK_M, BLOCK_N, expect_error): + if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: + pytest.skip("Test requires Hopper target.") + return + + device = "cuda" + torch.manual_seed(42) + A = torch.randn((M, N), dtype=torch.float16, device=device) + # globalAddress in the tma descriptor must be aligned to 16 bytes for CU_TENSOR_MAP_INTERLEAVE_NONE. + # https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY + assert A.data_ptr() % 16 == 0 + is_error = False + + try: + create_2d_tma_descriptor(A.data_ptr(), M, N, BLOCK_M, BLOCK_N, A.element_size()) + except RuntimeError as e: + is_error = True + assert e.args[0] == "Triton Error [CUDA]: invalid argument" + + assert is_error == expect_error diff --git a/third_party/enflame/include/triton/python/test/unit/cuda/test_tma_store_gemm.py b/third_party/enflame/include/triton/python/test/unit/cuda/test_tma_store_gemm.py new file mode 100644 index 000000000..b2fc3e874 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/cuda/test_tma_store_gemm.py @@ -0,0 +1,91 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +@triton.jit +def matmul_tma_load_store( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + OUTPUT_F16: tl.constexpr # +): + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + a = tl.load(a_block_ptr) + b = tl.load(b_block_ptr) + + c = tl.dot(a, b) + if OUTPUT_F16: + c = c.to(tl.float16) + + tl.store(c_block_ptr, c) + + +@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_F16', [ + [64, 64, 16, 1, 4, False, True, False], + [64, 64, 16, 1, 4, False, True, True], + [128, 64, 32, 1, 4, False, True, False], + [128, 64, 32, 1, 4, False, True, True], + [64, 128, 32, 1, 4, False, True, False], + [64, 128, 32, 1, 4, False, True, True], + [128, 128, 64, 1, 4, False, True, False], + [128, 128, 64, 1, 4, False, True, True], +]) +def test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F16): + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + if OUTPUT_F16: + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + matmul_tma_load_store[(1, 1)]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, # + OUTPUT_F16=OUTPUT_F16) + golden = torch.matmul(a, b) + torch.set_printoptions(profile="full") + assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git a/third_party/enflame/include/triton/python/test/unit/instrumentation/test_gpuhello.py b/third_party/enflame/include/triton/python/test/unit/instrumentation/test_gpuhello.py new file mode 100644 index 000000000..bdc6ca907 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/instrumentation/test_gpuhello.py @@ -0,0 +1,48 @@ +import torch + +import pytest +import os + +import triton +import triton.language as tl + +test_stdout = 'Hello From First Instruction of GPU Kernel: kernel1\ttest_gpuhello.py:17:4\n\ +Hello From First Instruction of GPU Kernel: kernel2\ttest_gpuhello.py:23:4\n\ +Hello From First Instruction of GPU Kernel: kernel3\ttest_gpuhello.py:29:4\n' + + +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel1(BLOCK_SIZE: tl.constexpr): + return + + +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel2(BLOCK_SIZE: tl.constexpr): + return + + +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel3(BLOCK_SIZE: tl.constexpr): + return + + +def func(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + kernel1[grid](BLOCK_SIZE=1024) + kernel2[grid](BLOCK_SIZE=1024) + kernel3[grid](BLOCK_SIZE=1024) + + +def test_op(capfd, device: str): + size = 98432 + x = torch.rand(size, device=device) + y = torch.rand(size, device=device) + func(x, y) + stdout, stderr = capfd.readouterr() + if 'LLVM_PASS_PLUGIN_PATH' in os.environ: + assert repr(stderr) == repr(test_stdout) diff --git a/third_party/enflame/include/triton/python/test/unit/language/print_helper.py b/third_party/enflame/include/triton/python/test/unit/language/print_helper.py new file mode 100644 index 000000000..dde1409c4 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/print_helper.py @@ -0,0 +1,160 @@ +import sys +import uuid + +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +def get_current_target_warp_size(): + return triton.runtime.driver.active.get_current_target().warp_size + + +@triton.jit +def kernel_device_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_hex(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x, hex=True) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # Triton should add a space after this prefix. + print("x:", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_scalar(SCALAR): + x = tl.load(SCALAR) + # Triton should add a space after this prefix. + print("x:", x) + + +@triton.jit +def kernel_device_print_large( + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32) + # Triton should change this prefix to "x: ". + tl.device_print("x ", x) + + +@triton.jit +def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + print("", x, y) + + +@triton.jit +def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + tl.device_print("", x, y) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr): + # This function takes an extra value as a tl.constexpr so this kernel is not + # cached. This way the static print is run every time. + x = tl.load(X + tl.arange(0, BLOCK)) + tl.static_print("", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_no_arg_print(): + print("", tl.program_id(0)) + + +@triton.jit +def kernel_print_no_arg(): + print("no arg") + + +@triton.jit +def kernel_print_pointer(X, Y, BLOCK: tl.constexpr): + tl.device_print("ptr ", X + tl.arange(0, BLOCK)) + + +@triton.jit +def kernel_print_2d_tensor(X, Y, BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.constexpr): + off_x = tl.arange(0, BLOCK_SIZE_X) + off_y = tl.arange(0, BLOCK_SIZE_Y) + x = tl.load(X + off_x[:, None] * BLOCK_SIZE_Y + off_y[None, :]) + tl.device_print("", x) + + +def test_print(func: str, data_type: str, device: str): + N = 128 # This value should match with test_print in test_subprocess.py. + # TODO(antiagainst): Currently the warp count is chosen to make sure we don't have multiple + # threads printing duplicated messages due to broadcasting. Improve print op lowering logic + # to filter out duplicated data range. + num_warps = N // get_current_target_warp_size() + + x = torch.arange(0, N, dtype=torch.int32, device=device).to(getattr(torch, data_type)) + y = torch.zeros((N, ), dtype=x.dtype, device=device) + if func == "device_print": + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_scalar": + scalar = torch.tensor(42, dtype=x.dtype, device=device) + kernel_device_print_scalar[(1, )](scalar, num_warps=num_warps) + elif func == "device_print_negative": + x = -x + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_uint": + x = torch.arange((1 << 31), (1 << 31) + N, device=device).to(getattr(torch, data_type)) + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "print": + kernel_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_large": + kernel_device_print_large[(1, 2)](BLOCK_M=64, num_warps=num_warps, BLOCK_N=N) + elif func == "print_multiple_args": + kernel_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_multiple_args": + kernel_device_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "static_print": + kernel_static_print[(1, )](x, y, num_warps=num_warps, BLOCK=N, PLACEHOLDER=uuid.uuid4()) + elif func == "no_arg_print": + kernel_no_arg_print[(1, )](num_warps=num_warps) + elif func == "print_no_arg": + kernel_print_no_arg[(1, )](num_warps=num_warps) + elif func == "device_print_hex": + kernel_device_print_hex[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_pointer": + kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_2d_tensor": + BLOCK_SIZE_X = num_warps + BLOCK_SIZE_Y = get_current_target_warp_size() + x_2d_tensor = x.reshape((BLOCK_SIZE_X, BLOCK_SIZE_Y)) + kernel_print_2d_tensor[(1, )](x_2d_tensor, y, num_warps=num_warps, BLOCK_SIZE_X=BLOCK_SIZE_X, + BLOCK_SIZE_Y=BLOCK_SIZE_Y) + else: + assert f"Unknown kernel: {func}" + + if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \ + func != "print_multiple_args" and func != "device_print_multiple_args" and \ + func != "device_print_pointer" and func != "device_print_scalar" and func != "device_print_2d_tensor": + assert_close(y, x) + + # Wait until driver complete all the jobs for the device_print, especially test_subprocess + # require this which captures stdout when child exits. + getattr(torch, device).synchronize() + + +if __name__ == "__main__": + fn = globals()[sys.argv[1]] + fn(*sys.argv[2:]) diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_annotations.py b/third_party/enflame/include/triton/python/test/unit/language/test_annotations.py new file mode 100644 index 000000000..087a14f50 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_annotations.py @@ -0,0 +1,51 @@ +from __future__ import annotations +import torch +import triton +import triton.language as tl +import pytest + + +def annotated_function(return_type=None, **arg_types): + """A decorator to add annotations to a function.""" + + def decorator(func): + func.__annotations__ = {**arg_types, 'return': return_type} + return func + + return decorator + + +# Test integer annotations +@pytest.mark.parametrize(("signed", "width"), [ + (signed, width) for signed in [False, True]\ + for width in [8, 16, 32, 64] +] + [(False, 1)] + ) +def test_int_annotation(signed, width, device): + + @triton.jit + @annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}") + def _kernel(X, v): + tl.store(X + v, v) + + h = _kernel[(1, )](torch.empty(1, device=device), 3) + pfx = 'si' if signed else 'ui' + if not signed and width < 64: + assert "arith.extui %arg1" in h.asm["ttir"] + assert f'%arg1: i{width}' in h.asm["ttir"] + assert f'arith.{pfx}tofp' in h.asm["ttir"] + + +# Test that unknown annotations do not emit an error +def test_unknown_annotation(device): + + @triton.jit + def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr): + pass + + x = torch.empty(1, device=device) + _kernel[(1, )](x, x.shape[0], 32) + try: + _kernel[(1, )](x.shape[0], x.shape[0], 32) + except AttributeError: + pass diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_block_pointer.py b/third_party/enflame/include/triton/python/test/unit/language/test_block_pointer.py new file mode 100644 index 000000000..aff7a29d8 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_block_pointer.py @@ -0,0 +1,118 @@ +import pytest +import torch + +import triton +import triton.language as tl +from test_core import check_type_supported + + +@triton.jit +def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, PADDING_OPTION: tl.constexpr, + TEST_LOWER_BOUND: tl.constexpr, TEST_UPPER_BOUND: tl.constexpr): + pid = tl.program_id(0) + offset = pid * BLOCK_SIZE + if TEST_LOWER_BOUND: + offset = -N + elif TEST_UPPER_BOUND: + offset = N + # We only copy half of the data to see if the padding works + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(offset, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(offset, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + if PADDING_OPTION is None: + a = tl.load(a_block_ptr, boundary_check=(0, )) + else: + a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=PADDING_OPTION) + tl.store(b_block_ptr, a, boundary_check=(0, )) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtypes_str, n, padding_option, boundary_check", [ # + (dtypes_str, n, padding, boundary_check) # + for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"), + ("float32", "float32"), ("bfloat16", "bfloat16")) + for n in (64, 128, 256, 512, 1024) + for padding in (None, "zero", "nan") # + for boundary_check in (None, "lower", "upper") +]) +def test_block_copy(dtypes_str, n, padding_option, boundary_check, device): + src_dtype_str = dtypes_str[0] + dst_dtype_str = dtypes_str[1] + src_dtype = getattr(torch, src_dtype_str) + dst_dtype = getattr(torch, dst_dtype_str) + check_type_supported(src_dtype, device) + check_type_supported(dst_dtype, device) + if src_dtype_str in ("bool", "int16", "int32"): + if padding_option == "nan": + pytest.skip("Padding with NaN is not supported for integer types") + a = torch.randint(0, 2, (n, ), device=device, dtype=src_dtype) + else: + a = torch.randn((n, ), device=device, dtype=src_dtype) + b = torch.zeros((n, ), device=device, dtype=dst_dtype) + + grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) + block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, PADDING_OPTION=padding_option, + TEST_LOWER_BOUND=boundary_check == "lower", TEST_UPPER_BOUND=boundary_check == "upper") + a.to(dst_dtype) + if (boundary_check == "lower") or (boundary_check == "upper"): + assert torch.all(b == 0) + else: + assert torch.all(a[0:n // 2] == b[0:n // 2]) + if padding_option == "zero": + assert torch.all(b[n // 2:n] == 0) + elif padding_option == "nan": + assert torch.all(torch.isnan(b[n // 2:n])) + + +@triton.jit +def matmul_no_scf_with_advance_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr # +): + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(1, 0)) + # Below two lines are just for testing negative offsets for the `advance` API, which could be removed + a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K)) + a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K)) + a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option="zero") + b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option="zero") + + c = tl.dot(a, b) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, c) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, num_warps", [ # + (shape, num_warps) for shape in [ + [64, 64, 16], + [64, 64, 32], + [64, 64, 64], + ] for num_warps in [4, 8] +]) +def test_block_ptr_matmul_no_scf(shape, num_warps, device): + m, n, k = shape + a = torch.randn((m, k), device=device, dtype=torch.float16) + b = torch.randn((k, n), device=device, dtype=torch.float16) + c = torch.empty((m, n), device=device, dtype=torch.float32) + + grid = lambda META: (1, ) + matmul_no_scf_with_advance_kernel[grid]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=m, N=n, K=k, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, # + num_warps=num_warps) + golden = torch.matmul(a, b) + torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_compile_errors.py b/third_party/enflame/include/triton/python/test/unit/language/test_compile_errors.py new file mode 100644 index 000000000..82b34efa7 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_compile_errors.py @@ -0,0 +1,455 @@ +import contextlib +import pytest +import os + +import torch +import triton +import triton.language as tl +from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure +import traceback +from triton._internal_testing import is_cuda, is_hip, is_hip_mi300, is_hip_mi350 + + +def format_exception(type, value, tb): + list_msg = traceback.format_exception(type, value, tb, chain=False) + return "\n".join(list_msg) + + +def test_err_undefined_variable(): + + @triton.jit + def kernel(): + a += 1 # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "is not defined" in err_msg, "error should mention the undefined variable" + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_operator(): + + @triton.jit + def kernel(): + 0 + "a" + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in err_msg, "error should point to the 0" + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_static_assert(): + + @triton.jit + def kernel(): + tl.static_assert(isinstance(0, tl.tensor)) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + assert isinstance(e.value, CompileTimeAssertionFailure) + assert e.value.__cause__ is None + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + print(err_msg) + assert "at 2:4:" in err_msg, "error should point to the static_assert call" + assert "" not in err_msg + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_unary_op(): + # Currently Triton can't evaluate `not` of a tuple at compile time. That's + # ok, but the error message needs to point to the correct spot. + @triton.jit + def kernel(): + not (0, 0) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + assert e.value.__cause__ is None + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in err_msg, "error should point to the `not`" + assert "" not in err_msg + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_op(): + + @triton.jit + def kernel(): + 1.0 << 1 + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in err_msg, "error should point to the 1.0" + assert "" not in err_msg + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +# This has to be defined as a top-level function; jit'ed functions can't call +# nested functions. +@triton.jit +def nested_call(): + xyz # noqa + + +def test_err_in_nested_call(): + + @triton.jit + def kernel(): + # this is a comment to push nested_call() onto the next line + nested_call() + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + inner_exc = e.value.__cause__ + inner = format_exception(inner_exc.__class__, inner_exc, inner_exc.__traceback__) + assert "at 2:4:" in inner, "error should point to xyz" + assert "" not in inner + assert "code_generator.py" not in inner + + outer = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 3:4" in outer, "error should point to the nested_call" + assert "" not in outer + assert "code_generator.py" not in outer + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_builtin(): + + # The root error here comes from core.py. Make sure the stacktrace reflects + # this. + @triton.jit + def kernel(): + tl.expand_dims(None, -1) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + inner_exc = e.value.__cause__ + inner = format_exception(inner_exc.__class__, inner_exc, inner_exc.__traceback__) + assert f"{os.sep}core.py" in inner, "error should point inside core.py" + assert "code_generator.py" not in inner + + outer = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in outer, "error should point to expand_dims call" + assert "" not in outer + assert "code_generator.py" not in outer + except AssertionError as assertion_err: + raise assertion_err from e.value + + +@triton.jit +def two_returns(): + return tl.arange(0, 4) + return tl.arange(0, 8) + + +def test_two_returns_no_err(): + # This program is valid; `a` has shape (10,). + @triton.jit + def kernel(): + a = two_returns() + a + tl.arange(0, 4) # only works if we took the first return + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +def test_not_const_annotate_no_err(): + + @triton.jit + def kernel(N: int = 1): + pass + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) + + +@triton.jit +def returns_branched_on_constexpr(N: tl.constexpr): + if N == 0: + return tl.arange(0, 4) + # Ideally this would work even without the `else`, but we're not that smart + # yet. + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_constexpr(): + + @triton.jit + def kernel1(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 4) + + triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0})) + + @triton.jit + def kernel2(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 8) + + triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1})) + + +@triton.jit +def returns_branched_on_non_constexpr(N: int): + if N == 0: + return tl.arange(0, 4) + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_non_constexpr(): + + @triton.jit + def kernel(N: int): + returns_branched_on_non_constexpr(N) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the function call" + assert "at 5:8:" in str(e.value.__cause__), "error should point to the second `return`" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_power_of_two_shapes(): + + @triton.jit + def kernel(): + tl.arange(2, 7) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert str(e.value.__cause__) == "arange's range must be a power of 2" + + +def test_power_of_two_shapes_2(): + + @triton.jit + def kernel(): + tl.full((33, ), 0, dtype=tl.int64) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert str(e.value.__cause__) == "Shape element 0 must be a power of 2" + + +def test_captured_var_access(): + + CAPTURED = 42 + + @triton.jit + def kernel(): + a = CAPTURED # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert "CAPTURED is not defined" in str(e.value) + + +GLOBAL = 42 + + +def test_global_var_access(): + + @triton.jit + def kernel(): + a = GLOBAL # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert "global variable" in str(e.value) + + +CONSTEXPR_ANNOTATED_GLOBAL: tl.constexpr = 42 + + +def test_constexpr_annotated_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_ANNOTATED_GLOBAL # noqa + + # No error. + try: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert False, "Using a constexpr annotated global variable should not be allowed" + except CompilationError as e: + assert "Cannot access global variable" in str(e) + + +CONSTEXPR_GLOBAL = tl.constexpr(42) + + +def test_constexpr_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +TYPE_ALIAS = tl.pointer_type(tl.int32) + + +def test_global_type_alias_access(): + + @triton.jit + def kernel(): + a = TYPE_ALIAS # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +def test_global_access_in_fn_default_arg(): + + @triton.jit + def kernel(a=GLOBAL): + pass + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={})) + + +def test_defaults_assign_no_err(): + + @triton.jit + def kernel(a=1, B: tl.constexpr = ""): + pass + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""})) + + +def test_where_warning(fresh_triton_cache): + + @triton.jit + def kernel(): + a = tl.full((64, ), 0, tl.uint32) + b = tl.full((64, ), 1, tl.float32) + c = tl.full((64, ), 2, tl.float32) + tl.where(a, b, c) + + with pytest.warns(UserWarning): + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]) +def test_fp8_support(fresh_triton_cache, dtype): + warning_dtypes = [] + supported_dtypes = [tl.float8e5] + if is_cuda(): + cc = torch.cuda.get_device_capability(0) + supported_dtypes.append(tl.float8e4b15) + if cc >= (9, 0): + warning_dtypes.append(tl.float8e4b15) + if cc >= (8, 9): + supported_dtypes.append(tl.float8e4nv) + elif is_hip(): + if is_hip_mi300(): + supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16] + if is_hip_mi350(): + supported_dtypes += [tl.float8e4nv] + + @triton.jit + def dtype_kernel(dtype: tl.constexpr): + _ = tl.full((256, ), 0.0, dtype) + + if dtype in warning_dtypes: + ctx = pytest.warns(UserWarning, match=r"fp8e4b15 is deprecated in this architecture") + elif dtype in supported_dtypes: + ctx = contextlib.nullcontext() + else: + ctx = pytest.raises(CompilationError, match="") + + with ctx as e: + triton.compile( + triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + + if dtype not in supported_dtypes: + try: + assert ("not supported in this architecture" in str(e.value.__cause__)) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +@pytest.mark.parametrize("dtype", [tl.float8e5, tl.int8, tl.float16]) +def test_min_dot_size(dtype): + error_msg = "Input shapes should have " + if is_cuda(): + if dtype.primitive_bitwidth == 8: + error_msg += "M >= 16, N >= 16 and K >= 32" + else: + error_msg = "M >= 16, N >= 16 and K >= 16" + elif is_hip(): + # hip supports arbitrary sizes + error_msg = None + else: + pytest.skip("Test only supported on CUDA and HIP") + + @triton.jit + def dot_kernel(dtype: tl.constexpr): + SIZE: tl.constexpr = 8 + a = tl.full((SIZE, SIZE), 0.0, dtype) + b = tl.full((SIZE, SIZE), 0.0, dtype) + tl.dot(a, b) + + if error_msg is None: + triton.compile( + triton.compiler.ASTSource(fn=dot_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + else: + with pytest.raises(CompilationError) as e: + triton.compile( + triton.compiler.ASTSource(fn=dot_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + try: + assert (error_msg in str(e.value.__cause__)) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_max_num_imprecise_acc_limit(): + + @triton.jit + def dot_kernel(): + SIZE: tl.constexpr = 64 + a = tl.full((SIZE, SIZE), 0.0, tl.float8e5) + b = tl.full((SIZE, SIZE), 0.0, tl.float8e5) + tl.dot(a, b, max_num_imprecise_acc=128) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={})) + try: + assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)") + except AssertionError as assertion_err: + raise assertion_err from e.value diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_compile_only.py b/third_party/enflame/include/triton/python/test/unit/language/test_compile_only.py new file mode 100644 index 000000000..fd9f23985 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_compile_only.py @@ -0,0 +1,158 @@ +import triton +import triton.language as tl +from triton.backends.compiler import GPUTarget +import re + + +def test_compile_only_sm100() -> None: + + @triton.jit + def kernel_add(a, b, c): + idx = tl.arange(0, 32) + tl.store(c + idx, tl.load(a + idx) + tl.load(b + idx)) + + k = triton.compile( + triton.compiler.ASTSource(fn=kernel_add, signature={"a": "*fp32", "b": "*fp32", "c": "*fp32"}, constexprs={}), + target=GPUTarget("cuda", 100, 32)) + ptx = k.asm["ptx"] + assert ".target sm_100a" in ptx + assert ".address_size 64" in ptx + assert k.asm["cubin"] != b"" + + +def test_compile_only_dot() -> None: + + @triton.jit + def simple_dot(a_base, b_base, out): + SIZE: tl.constexpr = 64 + a_ptr = a_base + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + b_ptr = b_base + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + a = tl.load(a_ptr) + b = tl.load(b_ptr) + c = tl.dot(a, b) + out_ptr = out + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + tl.store(out_ptr, c) + + k = triton.compile( + triton.compiler.ASTSource(fn=simple_dot, signature={"a_base": "*fp16", "b_base": "*fp16", "out": "*fp16"}, + constexprs={}), target=GPUTarget("cuda", 100, 32)) + ttgir = k.asm["ttgir"] + pattern = (r"%(?P\d+) = tt\.load" + r"(.|\n)*?" + r"%(?P\d+) = ttg\.local_alloc %(?P=A)" + r"(.|\n)*?" + r"%(?P\d+) = tt\.load" + r"(.|\n)*?" + r"%(?P\d+) = ttg\.local_alloc %(?P=B)" + r"(.|\n)*?" + r"%(?P\d+) = ttng\.tmem_alloc" + r"(.|\n)*?" + r"ttng\.tc_gen5_mma %(?P=A_SHMEM), %(?P=B_SHMEM), %(?P=TMEM_BASE)" + r"(.|\n)*?" + r"ttng\.tmem_load %(?P=TMEM_BASE)") + + assert re.search(pattern, str(ttgir)), "The TTGIR does not match the expected pattern." + + ptx = k.asm["ptx"] + pattern = (r"mov\.u32 %r(?P\d+), global_smem;" + r"(.|\n)*" + r"tcgen05\.alloc\.cta_group::1\.sync\.aligned\.shared::cta\.b32 \[%r(?P=G)], 64" + r"(.|\n)*" + r"tcgen05\.relinquish_alloc_permit\.cta_group::1\.sync\.aligned" + r"(.|\n)*" + r"tcgen05\.st\.sync\.aligned\.16x32bx2.x32.b32" + r"(.|\n)*" + r"tcgen05\.mma\.cta_group::1.kind::f16" + r"(.|\n)*" + r"tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64" + r"(.|\n)*" + r"mbarrier.try_wait.parity.shared.b64" + r"(.|\n)*" + r"tcgen05.ld.sync.aligned.16x32bx2.x32.b32" + r"(.|\n)*" + r"tcgen05.wait::ld.sync.aligned") + assert re.search(pattern, str(ptx)), "The PTX does not match the expected pattern." + assert k.asm["cubin"] != b"" + + +def test_compile_only_k_loop() -> None: + + @triton.jit + def k_loop(a_base, b_base, out, k_tiles): + SIZE: tl.constexpr = 128 + offs_k = tl.arange(0, SIZE) + c = tl.zeros((SIZE, SIZE), dtype=tl.float32) + for k in range(k_tiles): + a_ptr = a_base + tl.arange(0, SIZE)[:, None] * SIZE + offs_k[None, :] + b_ptr = b_base + offs_k[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + offs_k = offs_k + SIZE + a = tl.load(a_ptr) + b = tl.load(b_ptr) + c += tl.dot(a, b) + out_ptr = out + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :] + tl.store(out_ptr, c) + + k = triton.compile( + triton.compiler.ASTSource(fn=k_loop, + signature={"a_base": "*fp16", "b_base": "*fp16", "out": "*fp16", "k_tiles": + "i32"}, constexprs={}), target=GPUTarget("cuda", 100, 32)) + ttgir = k.asm["ttgir"] + + pattern = (r"%(?P\w+) = arith.constant dense<0.000000e\+00>" + r"(.|\n)*?" + r"%(?P\w+) = ttng\.tmem_alloc %(?P=TMEM_BASE)" + r"(.|\n)*?" + r"scf\.for" + r"(.|\n)*?" + r"%(?P\w+) = tt\.load" + r"(.|\n)*?" + r"%(?P\w+) = ttg\.local_alloc %(?P=A)" + r"(.|\n)*?" + r"%(?P\w+) = tt\.load" + r"(.|\n)*?" + r"%(?P\w+) = ttg\.local_alloc %(?P=B)" + r"(.|\n)*?" + r"ttng\.tc_gen5_mma %(?P=A_SHMEM), %(?P=B_SHMEM), %(?P=TMEM)" + r"(.|\n)*?" + r"scf\.yield") + + assert re.search(pattern, str(ttgir)), "The TTGIR does not match the expected pattern." + assert k.asm["cubin"] != b"" + + +def test_compile_only_dot_mxfp() -> None: + + @triton.jit + def simple_dot_mxfp(a_base, b_base, a_scale, b_scale, out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K + PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K + a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * PACKED_BLOCK_K_A + tl.arange(0, PACKED_BLOCK_K_A)[None, :] + b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + + SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 + scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] + scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] + + a = tl.load(a_ptr) + b = tl.load(b_ptr) + a_scale = tl.load(scale_a_ptr) + b_scale = tl.load(scale_b_ptr) + c = tl.dot_scaled(a, a_scale, "e4m3", b, b_scale, "e4m3") + out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + tl.store(out_ptr, c) + + k = triton.compile( + triton.compiler.ASTSource( + fn=simple_dot_mxfp, signature={ + "a_base": "*u8", "b_base": "*u8", "a_scale": "*u8", "b_scale": "*u8", "out": "*fp32", "BLOCK_M": + "constexpr", "BLOCK_N": "constexpr", "BLOCK_K": "constexpr" + }, constexprs={"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}), target=GPUTarget("cuda", 100, 32)) + ttgir = k.asm["ttgir"] + pattern = (r"ttng.tc_gen5_mma_scaled (.*) lhs = e4m3 rhs = e4m3") + assert re.search(pattern, str(ttgir)), "The TTGIR does not match the expected pattern." + + ptx = k.asm["ptx"] + pattern = (r"tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X") + assert re.search(pattern, str(ptx)), "The PTX does not match the expected pattern." + assert k.asm["cubin"] != b"" diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_conversions.py b/third_party/enflame/include/triton/python/test/unit/language/test_conversions.py new file mode 100644 index 000000000..7cb4a82bb --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_conversions.py @@ -0,0 +1,365 @@ +# fmt: off + + +import numpy as np +import torch +import pytest +import triton +import triton.language as tl + +from triton._internal_testing import is_cuda, is_hip, is_hip_mi300 + + +def matching_int(dtype): + if dtype.primitive_bitwidth == 8: + return torch.int8 + elif dtype.primitive_bitwidth == 16: + return torch.int16 + elif dtype.primitive_bitwidth == 32: + return torch.int32 + elif dtype.primitive_bitwidth == 64: + return torch.int64 + else: + raise ValueError('unsupported number of bits') + +@triton.jit +def type_convert_triton(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding) + tl.store(dst + idxs, y) + + +def launch_type_convert_triton(src, src_dtype, dst_dtype, device, rounding=None, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + type_convert_triton[(src.shape[0] // BLOCK_SIZE,)](triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE) + return dst + + +@triton.jit +def exhaustive_populate(dst, offset, BLOCK_SIZE : tl.constexpr, force_odd : tl.constexpr, output_bits : tl.constexpr, max_repr : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + vals = (idxs + offset).to(tl.uint32) + + # pseudorandom permutation: + multiplier = vals << 1 + multiplier += 3511 + vals *= multiplier + + if force_odd: + vals *= 2 + vals += 1 + + if (output_bits == 8): + vals &= 0xff + avals = vals & 0x7f + elif (output_bits == 16): + vals &= 0xffff + avals = vals & 0x7fff + elif (output_bits == 32): + avals = vals & 0x7fffffff + + vals = tl.where(avals <= max_repr, vals, 0) + + if (output_bits == 8): + vals = vals.to(tl.uint8) + elif (output_bits == 16): + vals = vals.to(tl.uint16) + + vals = vals.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, vals) + + +def launch_exhaustive_populate(dst_dtype, offset, numel, force_odd, output_bits, max_repr, device, BLOCK_SIZE=4096): + + assert(numel % BLOCK_SIZE == 0) + dst = torch.empty((numel,), dtype=matching_int(dst_dtype), device=device) + exhaustive_populate[(numel // BLOCK_SIZE,)](triton.reinterpret(dst, dst_dtype), offset, BLOCK_SIZE, force_odd, output_bits, max_repr) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. We don't need to have that + # as input to the conversion kernels. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(x.dtype == tl.float32, "input must be float32") + numbits_dst : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_dst == 8) or (numbits_dst == 16), "numbits_dst must be 8 or 16") + + x = x.to(tl.uint32, bitcast=True) + + mantissa = (x & 0x7fffff) + exponent = ((x >> 23) & 0xff).to(tl.int32) + mantissa = tl.where(exponent == 0, mantissa, mantissa + 0x800000).to(tl.int32) + exponent = tl.where(exponent == 0, exponent, exponent - 1) + + sign = (x >> 31) + + exponent = exponent + exponent_bias - 127 + adjustment : tl.constexpr = 0.5 ** (23 - mantissa_bits) + mantissa = mantissa.to(tl.float32) * adjustment + + # make exponent nonnegative: + mantissa = tl.where(exponent > -16, mantissa, 0.0) # destination has fewer than 16 mantissa bits, so safe + exponent = tl.where(exponent > -16, exponent, 0) + mantissa = tl.where(exponent > -8, mantissa, mantissa * 0.00390625) + exponent = tl.where(exponent > -8, exponent, exponent + 8) + mantissa = tl.where(exponent > -4, mantissa, mantissa * 0.0625) + exponent = tl.where(exponent > -4, exponent, exponent + 4) + mantissa = tl.where(exponent > -2, mantissa, mantissa * 0.25) + exponent = tl.where(exponent > -2, exponent, exponent + 2) + mantissa = tl.where(exponent > -1, mantissa, mantissa * 0.5) + exponent = tl.where(exponent > -1, exponent, exponent + 1) + + if rounding == 'rtne': + # Bring the value to the range [2 ** 23, 2 ** 24] + # where the representable floats map exactly to integers. + # Addition has RTNE semantics. + mantissa += 0x800000 + # Bring the value back to the original range. + mantissa -= 0x800000 + mantissa = mantissa.to(tl.int32) + elif rounding == 'rtz': + mantissa = mantissa.to(tl.int32) + else: + raise ValueError('unrecognized rounding mode') + + # Reassemble output floating-point representation: + exponent = exponent.to(tl.uint32) + y = (sign << (exponent_bits + mantissa_bits)) + (exponent << mantissa_bits) + mantissa + if numbits_dst == 8: + y = y.to(tl.uint8) + elif numbits_dst == 16: + y = y.to(tl.uint16) + return y + + +@triton.jit +def downcast_emulated(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(src.dtype.element_ty == tl.float32, "src dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = arbitrary_fp32_downcast(x, rounding, exponent_bits, mantissa_bits, exponent_bias) + y = y.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, y) + + +def launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + downcast_emulated[(src.shape[0] // BLOCK_SIZE,)]( + triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. downcast_emulated kernel will + # convert -0. in higher precision to 0x80 and thus need to fix the result to 0. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def upcast_emulated(src, dst, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + exponent_compensator : tl.constexpr = 2.0 ** (127 - exponent_bias) + + numbits_src : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_src == 8) or (numbits_src == 16), "numbits_src must be 8 or 16") + tl.static_assert(dst.dtype.element_ty == tl.float32, "dst dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + + if numbits_src == 8: + x = x.to(tl.uint8, bitcast=True) + elif numbits_src == 16: + x = x.to(tl.uint16, bitcast=True) + + x = x.to(tl.uint32) + + mantissa_mask : tl.constexpr = (1 << mantissa_bits) - 1 + exponent_mask : tl.constexpr = (1 << exponent_bits) - 1 + + mantissa = x & mantissa_mask + exponent = (x >> mantissa_bits) & exponent_mask + sign = (x >> (numbits_src - 1)) + + y = (sign << 31) | (exponent << 23) | (mantissa << (23 - mantissa_bits)) + y = y.to(tl.float32, bitcast=True) + y = y * exponent_compensator + + tl.store(dst + idxs, y) + + +def launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=torch.int32, device=device) + upcast_emulated[(src.shape[0] // BLOCK_SIZE,)](src, triton.reinterpret(dst, tl.float32), BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + return dst + + +def downcast_test(src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, max_repr, offset, device): + + src = launch_exhaustive_populate(src_dtype, offset << 24, 2**24, False, src_dtype.primitive_bitwidth, max_repr, device) + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device, rounding=rounding) + src = launch_type_convert_triton(src, src_dtype, tl.float32, device=device) + + dst2 = launch_downcast_emulated(src, tl.float32, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device=device) + + dst = launch_upcast_emulated(dst, exponent_bits, mantissa_bits, exponent_bias, device=device) + dst2 = launch_upcast_emulated(dst2, exponent_bits, mantissa_bits, exponent_bias, device=device) + + if not (torch.equal(dst, dst2)): + print('Error!!!') + + dst = dst.cpu().detach().numpy() + dst2 = dst2.cpu().detach().numpy() + src = src.cpu().detach().numpy() + + print(src[dst != dst2][0]) + print(dst[dst != dst2][0]) + print(dst2[dst != dst2][0]) + print(hex(src.view(np.uint32)[dst != dst2][0])) + print(hex(dst.view(np.uint32)[dst != dst2][0])) + print(hex(dst2.view(np.uint32)[dst != dst2][0])) + print('') + raise ValueError('%d elements mismatch' % (dst != dst2).sum()) + + +def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bias, max_repr, device): + + numbits_src = exponent_bits + mantissa_bits + 1 + + src = launch_exhaustive_populate(src_dtype, 0, 65536, False, numbits_src, max_repr, device=device) + + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device) + dst_to_float32 = launch_type_convert_triton(dst, dst_dtype, tl.float32, device=device) + + src_emulated_to_float32 = launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device=device) + + assert(torch.equal(src_emulated_to_float32, dst_to_float32)) + + +@pytest.mark.parametrize("src_dtype, dst_dtype", [ + ('float16', 'float32'), + ('bfloat16', 'float32'), + + ('float8e5', 'float16'), + ('float8e5', 'bfloat16'), + ('float8e5', 'float32'), + + ('float8e4b15', 'float16'), + # ('float8e4b15', 'bfloat16'), # Unsupported conversion from f8E4M3B11FNUZ to bf16 + ('float8e4b15', 'float32'), + + ('float8e4nv', 'float16'), + ('float8e4nv', 'bfloat16'), + ('float8e4nv', 'float32'), + + ('float8e4b8', 'float32'), + ('float8e4b8', 'float16'), + + ('float8e5b16', 'float32'), + ('float8e5b16', 'float16'), +]) +def test_typeconvert_upcast(src_dtype, dst_dtype, device): + + # On HIP, fp8e4nv upcasting is only supported to bf16 and fp16, and it's only supported on MI300. + if is_cuda(): + if ((src_dtype == 'float8e4nv' and torch.cuda.get_device_capability(0) < (8, 9)) + or src_dtype in ('float8e4b8', 'float8e5b16')): + # If the dtype should error out in the given device, we assert that and return + with pytest.raises(triton.CompilationError, match="not supported in this architecture"): + launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) + return + elif is_hip(): + if src_dtype == 'float8e4nv' and ( + dst_dtype == 'float32' or ((dst_dtype in ('bfloat16')) and not is_hip_mi300())): + pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture") + if (src_dtype in ('float8e4b15') or + (src_dtype in ('float8e4b8', 'float8e5b16') and not is_hip_mi300())): + # If the dtype should error out in the given device, we assert that and return + with pytest.raises(triton.CompilationError, match="not supported in this architecture"): + launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) + return + + # dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr) + stuff = { + 'float8e4b15': (4, 3, 15, 0x7e), + 'float8e4nv': (4, 3, 7, 0x7e), + 'float8e5': (5, 2, 15, 0x7b), + 'float8e4b8': (4, 3, 8, 0x7f), + 'float8e5b16': (5, 2, 16, 0x7f), + 'float16': (5, 10, 15, 0x7bff), + 'bfloat16': (8, 7, 127, 0x7f7f), + }[src_dtype] + + upcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), *stuff, device=device) + +@pytest.mark.parametrize("src_dtype, dst_dtype, rounding, max_repr", [ + ('float32', 'float16', 'rtne', 0x477fe000), + ('float32', 'float16', 'rtz', 0x477fe000), + ('float32', 'bfloat16', 'rtne', 0x7f7f0000), + ('float32', 'bfloat16', 'rtz', 0x7f7f0000), + ('float32', 'float8e5', 'rtne', 0x47600000), + ('float32', 'float8e5', 'rtz', 0x47600000), + ('float32', 'float8e4nv', 'rtne', 0x43e00000), + ('float32', 'float8e4b8', 'rtne', 0x43700000), + ('float32', 'float8e5b16', 'rtne', 0x47600000), + # ('float32', 'float8e4b15', 'rtne', 0x3fe00000), # Skip, no HW rtne conversion from f32 to f8e4b15 + + ('bfloat16', 'float8e5', 'rtne', 0x4760), + ('bfloat16', 'float8e4nv', 'rtne', 0x43e0), + + ('float16', 'float8e5', 'rtne', 0x7b00), + ('float16', 'float8e4nv', 'rtne', 0x5f00), + + ('bfloat16', 'float8e5b16', 'rtne', 0x4760), + ('bfloat16', 'float8e4b8', 'rtne', 0x4370), + + ('float16', 'float8e5b16', 'rtne', 0x7b00), + ('float16', 'float8e4b8', 'rtne', 0x5b80), +]) +def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device): + + if is_cuda(): + if src_dtype != 'float32' and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne': + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300") + + if is_hip(): + if dst_dtype == 'float8e5' and rounding == 'rtne': + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype == 'float8e4nv' and not (src_dtype == 'float16' and rounding == 'rtne' and is_hip_mi300()): + pytest.skip("float8e4nv downcast tests only supported from float16, with RTNE rounding, and on AMDGPU MI300") + + if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and not is_hip_mi300(): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias) + stuff = { + 'float16': (5, 10, 15), + 'bfloat16': (8, 7, 127), + 'float8e5': (5, 2, 15), + 'float8e4b15': (4, 3, 15), + 'float8e4nv': (4, 3, 7), + 'float8e4b8': (4, 3, 8), + 'float8e5b16': (5, 2, 16), + }[dst_dtype] + + for i in range(256): + downcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), rounding, *stuff, max_repr, i, device=device) diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_core.py b/third_party/enflame/include/triton/python/test/unit/language/test_core.py new file mode 100644 index 000000000..f70b87f18 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_core.py @@ -0,0 +1,7264 @@ +# ruff: noqa: F821,F841 +import contextlib +import itertools +import re +from typing import Optional +import math +import textwrap +import pathlib + +import numpy as np +import pytest +import torch +import os +import inspect +from numpy.random import RandomState + +import triton +import triton.language as tl +from triton.language.extra import libdevice + +from triton._internal_testing import ( + integral_dtypes, + int_dtypes, + str_to_triton_dtype, + uint_dtypes, + float_dtypes, + float_dtypes_with_bfloat16, + dtypes, + dtypes_with_bfloat16, + is_cuda, + is_interpreter, + is_hopper, + is_hip, + is_hip_cdna, + is_hip_mi200, + is_hip_mi300, + is_hip_mi350, + is_xpu, + get_arch, + torch_float8_dtypes, + torch_dtypes, + numpy_random, + to_triton, + torch_dtype_name, + to_numpy, +) +from triton.runtime.errors import InterpreterError + + +@contextlib.contextmanager +def promotion_numpy_2_0(): + state = np._get_promotion_state() + np._set_promotion_state("weak") + try: + yield + finally: + np._set_promotion_state(state) + + +# No need to emulate NumPy 2.0 if the user has NumPy 2.0 +if np.__version__[0] != "1": + promotion_numpy_2_0 = contextlib.nullcontext + +# TODO: enable multiple cta cluster testing. +# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] +num_ctas_list = [1] + +mma_nonk_sizes = [] + +GPU_DIALECT = "ttg" +if is_interpreter(): + THREADS_PER_WARP = 1 +elif is_hip(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size + # for CDNA multiple variants of mma instructions are supported: + # mfma 16x16/mfma 32x32 + # 0 is a special value for automatic heuristic + if is_hip_cdna(): + mma_nonk_sizes = [0, 16, 32] +else: + THREADS_PER_WARP = 32 + + +def _bitwidth(dtype: str) -> int: + # ex.: "int64" -> 64 + return int(re.search(r'(\d+)$', dtype).group(1)) + + +def _dtype(dtype: str) -> str: + # ex.: "int64" -> "int" + return re.match(r'([a-zA-Z]+)', dtype).group(0) + + +def patch_kernel(template, to_replace): + if is_interpreter(): + local_namespace = {} + src = textwrap.dedent(inspect.getsource(template.fn)) + for k, v in to_replace.items(): + src = src.replace(k, v) + exec(src, globals(), local_namespace) + return local_namespace[template.fn.__name__] + else: + kernel = triton.JITFunction(template.fn) + for key, value in to_replace.items(): + kernel._unsafe_update_src(kernel.src.replace(key, value)) + return kernel + + +def check_cuda_or_hip(device): + # CUDA and HIP both use pytorch device 'cuda'. Other backends like Intel + # GPU do not. + if device not in ['cuda']: + pytest.skip("Only for cuda or HIP") + + +def check_type_supported(dtype, device): + ''' + skip test if dtype is not supported on the current device + ''' + if device in ['cuda']: + cc = torch.cuda.get_device_capability() + if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): + pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") + if cc[0] < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}: + pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90") + if is_interpreter(): + if dtype in [tl.bfloat16, "bfloat16", torch.bfloat16]: + pytest.skip("bfloat16 is not supported in the interpreter") + + +class MfmaLayout: + + def __init__(self, version, warps_per_cta, instr_shape, is_transposed): + self.version = version + self.warps_per_cta = warps_per_cta + self.instr_shape = instr_shape + self.is_transposed = is_transposed + + def __str__(self): + return f"#{GPU_DIALECT}.amd_mfma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA = {self.warps_per_cta}, instrShape={self.instr_shape}, isTransposed = {str(self.is_transposed).lower()}}}>" + + +class WmmaLayout: + + def __init__(self, version, warps_per_cta): + self.version = version + self.warps_per_cta = warps_per_cta + + def __str__(self): + return f"#{GPU_DIALECT}.amd_wmma<{{version = {self.version}, warpsPerCTA = {self.warps_per_cta}}}>" + + +class MmaLayout: + + def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape): + self.version = version + self.warps_per_cta = warps_per_cta + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + self.instr_shape = instr_shape + + def __str__(self): + return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" + + +class DotOperandLayout: + + def __init__(self, parent, op_idx, k_width): + self.parent = parent + self.op_idx = op_idx + self.k_width = k_width + + def __str__(self): + return f"#{GPU_DIALECT}.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>" + + +class SliceLayout: + + def __init__(self, dim, parent): + self.dim = dim + self.parent = parent + + def __str__(self): + return f"#{GPU_DIALECT}.slice<{{dim = {self.dim}, parent = {self.parent}}}>" + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1], + cta_split_num=[1, 1], cta_order=[0, 1]): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +class SharedLayout: + + def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): + self.vec = vec + self.per_phase = per_phase + self.max_phase = max_phase + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +class NVMMASharedLayout: + + def __init__(self, swizzle, transpose, element_bit_width, ctas_per_cga, cta_split_num, cta_order): + self.swizzle = swizzle + self.transpose = transpose + self.element_bit_width = element_bit_width + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + transpose_str = "true" if self.transpose else "false" + return f"#{GPU_DIALECT}.nvmma_shared<{{swizzlingByteWidth={self.swizzle}, transposed={transpose_str}, elementBitWidth={self.element_bit_width}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +class LinearLayout: + + def __init__(self, register, lane, warp, block): + self.register = register + self.lane = lane + self.warp = warp + self.block = block + + def __str__(self): + return f"#{GPU_DIALECT}.linear<{{register={self.register}, lane={self.lane}, warp={self.warp}, block={self.block}}}>" + + +# Python impl of LinearEncodingAttr::basesPerDim +def bases_per_dim(layout, dim, rank, skip_broadcast=True): + assert isinstance(layout, LinearLayout) + bases = getattr(layout, dim) + result = [1] * rank + + if not bases: + return result + + non_zero_idx = None + + for basis in bases: + # Find the first non-zero index in the current basis + idx = next((i for i, v in enumerate(basis) if v != 0), None) + if idx is not None: + non_zero_idx = idx + result[idx] *= 2 + elif not skip_broadcast: + # If no non-zero found and we're not skipping broadcasts, use the last found non-zero index + assert non_zero_idx is not None + result[non_zero_idx] *= 2 + + return result + + +def warps_per_cta(layout, shape): + if isinstance(layout, LinearLayout): + return bases_per_dim(layout, 'warp', len(shape)) + elif isinstance(layout, (SliceLayout, DotOperandLayout)): + return warps_per_cta(layout.parent, shape) + else: + return layout.warps_per_cta + + +def is_layout_applicable(layout) -> bool: + if isinstance(layout, (BlockedLayout, SharedLayout, LinearLayout)): + return True + elif isinstance(layout, SliceLayout): + return is_layout_applicable(layout.parent) + elif is_cuda(): + mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout + if not isinstance(mma_layout, MmaLayout): + return False + if mma_layout.version[0] >= 3 and not is_hopper(): + return False + return True + elif is_hip(): + target_arch = triton.runtime.driver.active.get_current_target().arch + if "gfx11" in target_arch: + # RDNA 3 + return isinstance(layout, WmmaLayout) + elif any(arch for arch in ["gfx8", "gfx9"] if arch in target_arch): + # CDNA 1, 2, 3 + return isinstance(layout, MfmaLayout) + else: + return False + else: + return True + + +def filter_layouts(layouts): + return [l for l in layouts if is_layout_applicable(l)] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) +def test_empty_kernel(dtype_x, device): + SIZE = 128 + + @triton.jit + def kernel(X, SIZE: tl.constexpr): + pass + + check_type_supported(dtype_x, device) + x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) + kernel[(1, )](x, SIZE=SIZE, num_warps=4) + + +def test_scalar_overflow(device): + + @triton.jit + def kernel(): + huge_int: tl.constexpr = 0xFFFFFFFFFFFFFF + x = tl.full((), 32, dtype=tl.int32) + y = x + huge_int + + with pytest.raises(triton.TritonError, match="out of range"): + kernel[(1, )]() + + +# generic test functions +def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda', num_ctas=1): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) + # inputs + x = numpy_random(SIZE, dtype_str=dtype_x) + if 'log' in expr: + x = np.abs(x) + 0.01 + # reference result + z_ref = eval(expr if numpy_expr is None else numpy_expr) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_x) + kernel[(1, )](Z=z_tri, X=x_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + # compare + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + + +def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: + """ + Given two dtype strings, returns the numpy dtype Triton thinks binary + operations on the two types should return. Returns None if the return value + matches numpy. This is generally needed because Triton and pytorch return + narrower floating point types than numpy in mixed operations, and because + Triton follows C/C++ semantics around mixed signed/unsigned operations, and + numpy/pytorch do not. + """ + overrides = { + ('float16', 'int16'): np.float16, + ('float16', 'int32'): np.float16, + ('float16', 'int64'): np.float16, + ('float16', 'uint16'): np.float16, + ('float16', 'uint32'): np.float16, + ('float16', 'uint64'): np.float16, + ('int8', 'uint8'): np.uint8, + ('int8', 'uint16'): np.uint16, + ('int8', 'uint32'): np.uint32, + ('int8', 'uint64'): np.uint64, + ('int16', 'uint16'): np.uint16, + ('int16', 'uint32'): np.uint32, + ('int16', 'uint64'): np.uint64, + ('int32', 'uint32'): np.uint32, + ('int32', 'uint64'): np.uint64, + ('int64', 'uint64'): np.uint64, + } + key = (a, b) if a < b else (b, a) + return overrides.get(key) + + +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, + x_low=None, x_high=None, y_low=None, y_high=None, filter_y=None, test_broadcast=True, + test_scalar=True): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + check_type_supported(dtype_y, device) + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_lhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_rhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_scalar_rhs(Z, X, y: tl.constexpr, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + replacements = {'GENERATE_TEST_HERE': expr} + kernel = patch_kernel(kernel, replacements) + kernel_broadcast_lhs = patch_kernel(kernel_broadcast_lhs, replacements) + kernel_broadcast_rhs = patch_kernel(kernel_broadcast_rhs, replacements) + kernel_scalar_rhs = patch_kernel(kernel_scalar_rhs, replacements) + + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs, low=x_low, high=x_high) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) + if filter_y: + y[filter_y(y)] = 1 + if mode_x == 'nan': + x[:] = float('nan') + if mode_y == 'nan': + y[:] = float('nan') + + def do_test(x, y, kernel_fn): + x_is_scalar = isinstance(x, (bool, int, float)) + y_is_scalar = isinstance(y, (bool, int, float)) + scalar_test = x_is_scalar or y_is_scalar + + # For scalars, we follow the NumPy 2.0 (and JAX/PyTorch pretty much) casting rules. + if scalar_test: + # We remove any explicit casting + pattern = r'\.astype\(np\.\w+\)' + scalar_expr = expr if numpy_expr is None else re.sub(pattern, '', numpy_expr) + with promotion_numpy_2_0(): + z_ref = eval(scalar_expr) + else: + z_ref = eval(expr if numpy_expr is None else numpy_expr) + + dtype_z = _binary_op_dtype_override(dtype_x, dtype_y) + if not scalar_test and dtype_z is not None: + z_ref = z_ref.astype(dtype_z) + + # triton result + x_tri = x if x_is_scalar else to_triton(x, device=device, dst_type=dtype_x) + y_tri = y if y_is_scalar else to_triton(y, device=device, dst_type=dtype_y) + z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) + kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + err_msg = f"{expr}, {kernel_fn.__name__}" + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=7e-3, rtol=0.01) + + def get_scalar(x, dtype, low, high, filter): + # If dtype is int, don't choose a huge number for the scalar + # as it'll overflow easily when converted to the other dtype + if dtype in integral_dtypes: + # Choose in range [-7, 7] ([0, 7] for uints) + low_x = 0 if dtype in uint_dtypes else -7 + if low is not None: + low_x = max(low_x, low) + high_x = 7 + if high is not None: + high_x = min(high_x, high) + scalar = numpy_random((), dtype_str=dtype, rs=rs, low=low_x, high=high_x).item() + if filter and filter(scalar): + # https://xkcd.com/221/ + scalar = 4 + else: + scalar = x.flat[0].item() + return scalar + + do_test(x, y, kernel) + if mode_y != 'nan' and test_scalar: + if dtype_x in uint_dtypes: + low = 0 if y_low is None else max(y_low, 0) + else: + low = y_low + y_scalar = get_scalar(y, dtype_y, low, y_high, filter_y) + do_test(x, y_scalar, kernel_scalar_rhs) + if test_broadcast: + do_test(x[:1].reshape(()), y, kernel_broadcast_lhs) + do_test(x, y[:1].reshape(()), kernel_broadcast_rhs) + + +def _min_max_integral_mod_value(dtype_x, dtype_y) -> Optional[int]: + """ + Limit min/max values for integral types for mod values. Leads to + overflow/underflow when casting large integral types to floats. + """ + x_bitwidth = _bitwidth(dtype_x) + y_bitwidth = _bitwidth(dtype_y) + + # hard cap max value bit-width to 32 if 64 bit-width types + min_bitwidth = min(x_bitwidth, y_bitwidth, 32) + + # Limit max value bit-width to be one integral type less than the min bit-width + # For example: + # int64, float32 -> int16 + # uint16, float16 -> uint8 + x_dtype = _dtype(dtype_x) + max_bitwidth = max(min_bitwidth >> 1, 8) + dtype_max = x_dtype + str(max_bitwidth) + + max_info = np.iinfo(getattr(np, dtype_max)) + + # Still need to limit values here for uints + if max_bitwidth >= 16 and dtype_max in uint_dtypes: + return max_info.min, max_info.max // 4 + else: + return max_info.min, max_info.max + + +def test_dtype_codegen(): + for dtype in dtypes_with_bfloat16: + full_name = f"triton.language.{dtype}" + assert repr(eval(full_name)) == full_name + + +# --------------- +# test binary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['+', '-', '*', '/', '%'] + for dtype_x in dtypes_with_bfloat16 + for dtype_y in dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + np_expr_gen = (lambda x, y: f'{x} {op} {y}') if op != '%' else (lambda x, y: f'np.fmod({x}, {y})') + + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. + def promote_to_fp32(dtype_x, dtype_y): + return dtype_x in ('float16', 'bfloat16') and dtype_y not in ('float32', 'float64') + + if op in ('/', '%') and (promote_to_fp32(dtype_x, dtype_y) or promote_to_fp32(dtype_y, dtype_x)): + numpy_expr = np_expr_gen('x.astype(np.float32)', 'y.astype(np.float32)') + elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_x})', f'y.astype(np.{dtype_x})') + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_y})', f'y.astype(np.{dtype_y})') + elif op == '%': + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = np_expr_gen('x', 'y') + else: + numpy_expr = None + + if (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + else: + # skip when bfloat16, as NumPy's ref performs the computation in float32 + # while Triton performs it in bfloat16 + skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y) + or (op in ('/', '%') and dtype_x in ("float16", "bfloat16"))) + # can't divide by zero + not_zero = op in ('/', '%') and dtype_x in integral_dtypes and dtype_y in integral_dtypes + # can't represent -int(max) + not_minus_one = op in ('*', '/') and dtype_x in int_dtypes and dtype_y in int_dtypes + if not_zero or not_minus_one: + filter_y = lambda y: not_zero * (y == 0) | not_minus_one * (y == -1) + else: + filter_y = None + + if op == "%" and dtype_x in integral_dtypes and dtype_y in float_dtypes_with_bfloat16: + x_low, x_high = _min_max_integral_mod_value(dtype_x, dtype_y) + else: + x_low, x_high = None, None + + _test_binary( + dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, + # fails with values where fmod(x, y) is roughly zero, but happens to + # pass with the random values chosen for non-broadcast tests + test_broadcast=(op != "%"), x_low=x_low, x_high=x_high, filter_y=filter_y, test_scalar=not skip_scalar_test) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) +def test_addptr(dtype, order, device): + check_type_supported(dtype, device) + + @triton.jit + def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): + offs = tl.arange(0, SIZE) + if ORDER == 0: + tl.store(y + offs, tl.load(x + offs)) + else: + tl.store(offs + y, tl.load(offs + x)) + + SIZE = 1024 + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + x_tri = to_triton(x, dst_type=dtype, device=device) + y_tri = to_triton(y, dst_type=dtype, device=device) + y = x + kernel[ + 1, + ](x_tri, y_tri, order, SIZE) + np.testing.assert_allclose(y, to_numpy(y_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y", [ # + (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes +] + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_floordiv(dtype_x, dtype_y, num_ctas, device): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + expr = 'x // y' + numpy_expr = '((x - np.fmod(x, y)) / y)' + # can't represent -int(max) + not_minus_one = dtype_x in int_dtypes and dtype_y in int_dtypes + if not_minus_one: + filter_y = lambda y: y == -1 + else: + filter_y = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, filter_y=filter_y, device=device, num_ctas=num_ctas) + + +def test_unsigned_name_mangling(device): + # Test that uint32 and int32 are mangled differently by the compiler + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(O1, O2, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + out1 = tl.abs(x) # uint32 -> nop + out2 = tl.abs(-y) # int32 -> should have an effect + tl.store(O1 + off, out1) + tl.store(O2 + off, out2) + + dtype_x = 'uint32' + dtype_y = 'int32' + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + # reference result + expect = (np.abs(x), np.abs(-y)) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + actual = tuple(to_triton(np.empty_like(e), device=device) for e in expect) + kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4) + + # Bitwise op, so expect exact equality + assert (expect[0] == to_numpy(actual[0])).all() + assert (expect[1] == to_numpy(actual[1])).all() + + +# test bitwise ops +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['&', '|', '^'] + for dtype_x in dtypes + dtypes_with_bfloat16 + for dtype_y in dtypes + dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if 'float' in dtype_x + dtype_y: + # The CompilationError must have been caused by a C++ exception with this text. + with pytest.raises(triton.TritonError, match='invalid operands of type'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device, num_ctas=num_ctas) + else: + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes for dtype_y in uint_dtypes +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) + if dtype_x.startswith('int'): + dtype_z = f'int{bw}' + else: + dtype_z = f'uint{bw}' + numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, y_low=0, y_high=bw) + + +# --------------- +# test compare ops +# --------------- +ops = ['==', '!=', '>', '<', '>=', '<='] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "dtype_x, dtype_y, op, mode_x, mode_y", + # real + [(dtype_x, dtype_y, op, 'real', 'real') for op in ops for dtype_x in dtypes for dtype_y in dtypes] + # NaNs + + [('float32', 'float32', op, mode_x, mode_y) + for op in ops + for mode_x, mode_y in [('nan', 'real'), ('real', 'nan'), ('nan', 'nan')]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas) + + +# --------------- +# test broadcast +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) +def test_broadcast(dtype, device): + check_type_supported(dtype, device) + + @triton.jit + def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) + y = tl.load(y_ptr + offset2) + _, y_broadcasted = tl.broadcast(x, y) + tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) + + M = 32 + N = 64 + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype, rs=rs) + y = numpy_random(N, dtype_str=dtype, rs=rs) + _, y_broadcasted_np = np.broadcast_arrays(x, y) + + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) + + broadcast_kernel[(1, )](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) + assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() + + +# ---------- +# test slice +# ---------- + + +@pytest.mark.interpreter +def test_slice(device): + + @triton.jit + def slice_kernel(XBLOCK: tl.constexpr): + data = tl.arange(0, XBLOCK) + tl.static_assert(data.shape == [XBLOCK]) + + t = data[None, :] + tl.static_assert(t.shape == [1, XBLOCK]) + + t = data[None, :, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + scalar = tl.full([], 1, tl.int32) + tl.static_assert(scalar.shape == []) + + t = scalar[None] + tl.static_assert(t.shape == [1]) + + t = scalar[None, None] + tl.static_assert(t.shape == [1, 1]) + + slice_kernel[(1, )](XBLOCK=32) + + +# ------------------ +# test invalid slice +# ------------------ + + +@pytest.mark.interpreter +def test_invalid_slice(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + dst[10:] + + with pytest.raises(triton.TritonError, match='unsupported tensor index'): + _kernel[(1, )](dst=dst) + + +# ---------------- +# test expand_dims +# ---------------- +@pytest.mark.interpreter +def test_expand_dims(device): + + @triton.jit + def expand_dims_kernel(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 0) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, 1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -2) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, (0, -1)) + tl.static_assert(t.shape == [1, N, 1]) + + t = tl.expand_dims(offset1, (0, 1, 3)) + tl.static_assert(t.shape == [1, 1, N, 1]) + + t = tl.expand_dims(offset1, (-4, 2, -1)) + tl.static_assert(t.shape == [1, N, 1, 1]) + + t = tl.expand_dims(offset1, (3, 1, 2)) + tl.static_assert(t.shape == [N, 1, 1, 1]) + + scalar = tl.sum(offset1) + tl.static_assert(scalar.shape == []) + t = tl.expand_dims(scalar, 0) + tl.static_assert(t.shape == [1]) + + t = tl.expand_dims(scalar, -1) + tl.static_assert(t.shape == [1]) + + # N is a scalar that's not even a tl.tensor -- this should work too. + t = tl.expand_dims(N, -1) + tl.static_assert(t.shape == [1]) + + N = 32 + dummy_tensor = torch.empty((), device=device) + expand_dims_kernel[(1, )](dummy_tensor, N) + + +@pytest.mark.interpreter +def test_expand_dims_error_cases(device): + + @triton.jit + def dim_out_of_range1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, -2) + t = tl.expand_dims(offset1, -3) + + @triton.jit + def dim_out_of_range2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 1) + t = tl.expand_dims(offset1, 2) + + @triton.jit + def dim_out_of_range3(dummy, N: tl.constexpr): + offset1 = tl.arange(0, 1) + scalar = tl.sum(offset1) + + t = tl.expand_dims(scalar, 1) + + @triton.jit + def duplicate_dim1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, 0)) + + @triton.jit + def duplicate_dim2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, -3)) + + N = 32 + dummy_tensor = torch.empty((), device=device) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range1[(1, )](dummy_tensor, N) + assert "invalid axis -3" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range2[(1, )](dummy_tensor, N) + assert "invalid axis 2" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range3[(1, )](dummy_tensor, N) + assert "invalid axis 1" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim1[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim2[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + +# ---------------------------- +# test invalid program id axis +# ---------------------------- +@pytest.mark.interpreter +def test_invalid_pid_axis(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pid = tl.program_id(20) + + with pytest.raises(triton.TritonError) as exc_info: + _kernel[(1, )](dst) + assert re.search(r"program_id axis must be 0, 1, or 2 but got 20", str(exc_info.value.__cause__)) + + +# --------------- +# test where +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where(dtype, num_ctas, device): + select_ptrs = False + if dtype == "*int32": + dtype = "int64" + select_ptrs = True + check_type_supported(dtype, device) + + @triton.jit + def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + decide = tl.load(cond_ptr + offsets, mask=mask) + if TEST_SCALAR_POINTERS: + ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr) + output = tl.load(ptr + offsets, mask=mask) + else: + if TEST_POINTERS: + a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t) + b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t) + else: + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + output = tl.where(decide, a, b) + tl.store(output_ptr + offsets, output, mask=mask) + + SIZE = 1_000 + rs = RandomState(17) + cond = numpy_random(SIZE, 'bool', rs) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + z = np.where(cond, x, y) + + cond_tri = to_triton(cond, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device=device, dst_type=dtype) + + grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) + assert (z == to_numpy(z_tri)).all() + if select_ptrs: + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=True) + z = np.where(cond[0], x, y) + assert (z == to_numpy(z_tri)).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where_broadcast(num_ctas, device): + + @triton.jit + def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + + mask = tl.load(cond_ptr + yoffsets) + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + @triton.jit + def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + mask = False + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + SIZE = 32 + dtype = 'float32' + rs = RandomState(17) + x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs) + mask = numpy_random(SIZE, 'bool', rs=rs) + z = np.where(mask, x, 0) + cond_tri = to_triton(mask, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device=device, dst_type=dtype) + where_kernel[(1, )](cond_tri, x_tri, z_tri, SIZE) + assert (z == to_numpy(z_tri)).all() + where_scalar_condition[(1, )](x_tri, z_tri, SIZE, num_ctas=num_ctas) + z = np.where(0, x, 0) + assert (z == to_numpy(z_tri)).all() + + +# --------------- +# test unary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr", + [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') + for dtype_x in int_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_unary_op(dtype_x, expr, num_ctas, device): + _test_unary(dtype_x, expr, device=device, num_ctas=num_ctas) + + +# ---------------- +# test math ops +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr, x", + [(dtype_x, expr, x) + for dtype_x in ["float32", "float64"] + for expr in ['exp', 'log', 'cos', 'sin', 'exp2', 'log2', 'sqrt', 'floor', 'ceil'] + for x in ['x', '3.0']]) +def test_math_op(dtype_x, expr, x, device): + _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_erf_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.math.erf(x) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = torch.erf(x) + z_tri = torch.zeros_like(x) + kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_fma_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, Y, W, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + w = tl.load(W + off) + z = tl.math.fma(x, y, w) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + y = torch.randn(SIZE, dtype=torch_dtype, device=device) + w = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = x * y + w + z_tri = torch.zeros_like(x) + kernel[(1, )](z_tri, x, y, w, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_math_divide_op(expr, num_ctas, device): + numpy_expr = "x / y" + dtype = "float32" + _test_binary(dtype, dtype, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +# ------------- +# test precise math +# ------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("expr_prec, expr_ref", + [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), + ('tl.math.div_rn(x,y)', '(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)')]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_precise_math(expr_prec, expr_ref, num_ctas, device): + + @triton.jit + def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + prec = PREC_CALC + ref = REF_CALC + tl.store(OUT + tl.arange(0, BLOCK), prec) + tl.store(OUT_REF + tl.arange(0, BLOCK), ref) + + shape = (128, ) + out = torch.zeros(shape, dtype=torch.float32, device=device) + out_ref = torch.zeros(shape, dtype=torch.float32, device=device) + + x = torch.randn(shape, dtype=torch.float32, device=device) + y = torch.randn(shape, dtype=torch.float32, device=device) + + if (expr_prec.count('sqrt') > 0): + x = torch.abs(x) + + if (expr_prec.count('div') > 0): + y += 1e-6 + + kernel = patch_kernel(kernel, {'PREC_CALC': expr_prec, 'REF_CALC': expr_ref}) + + kernel[(1, )](x, y, out, out_ref, BLOCK=shape[0], num_ctas=num_ctas) + assert torch.all(out == out_ref) # bitwise exact + + +# ---------------- +# test abs +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_abs(dtype_x, device): + _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5]) +def test_abs_fp8(in_dtype, device): + if is_hip(): + pytest.skip('test_abs_fp8 not supported on HIP.') + elif is_cuda(): + cc = torch.cuda.get_device_capability() + if in_dtype == tl.float8e4b15 and cc >= (9, 0): + pytest.skip("float8e4b15 not supported on CUDA >= 9.0") + if in_dtype == tl.float8e4nv and cc < (8, 9): + pytest.skip("float8e4nv not supported on CUDA < 8.9") + + @triton.jit + def abs_kernel(X, Z, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.abs(x) + tl.store(Z + off, z) + + f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device=device) + # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan + all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width + f8_tensor[all_exp_ones] = 0 + f8 = triton.reinterpret(f8_tensor, in_dtype) + n_elements = f8_tensor.numel() + out_f8 = torch.empty_like(f8_tensor) + abs_kernel[(1, )](f8, triton.reinterpret(out_f8, in_dtype), n_elements) + + f32_tensor = convert_float_to_float32(f8_tensor, in_dtype) + expect = f32_tensor.abs() + actual_f8 = convert_float_to_float32(out_f8, in_dtype) + torch.testing.assert_close(actual_f8, expect, equal_nan=True) + + +# ---------------- +# test passing shapes as individual params rather than tuples +# ---------------- + + +@pytest.mark.interpreter +def test_shapes_as_params(device): + + @triton.jit + def kernel(): + a = tl.arange(0, 32).expand_dims(-1).broadcast_to(32, 32) + tl.static_assert(a.shape == [tl.constexpr(32), tl.constexpr(32)]) + + a = tl.arange(0, 32).reshape(4, 8).permute(1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + + a = tl.arange(0, 32).reshape(4, 8).trans() + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + + a = tl.arange(0, 32).reshape(4, 8).reshape(32) + tl.static_assert(a.shape == [tl.constexpr(32)]) + + a = tl.arange(0, 64).reshape(2, 4, 8).trans(2, 1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + + a = tl.arange(0, 64).reshape(2, 4, 8).trans((2, 1, 0)) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + + a = tl.arange(0, 64).view(2, 4, 8) + tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)]) + + kernel[(1, )]() + + +# ---------------- +# test transpose +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_transpose(dtype_x, device): + check_type_supported(dtype_x, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + off2d = off[None, :] + (tl.arange(0, 2) * SIZE)[:, None] + x = tl.load(X + off2d) + z = x.T + tl.store(Z + off2d.T, z) + + x = numpy_random([SIZE, 2], dtype_str=dtype_x) + z_ref = x.T + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE) + np.testing.assert_allclose(z_ref, to_numpy(z_tri)) + + +# ---------------- +# test indexing +# ---------------- + + +def make_ptr_str(name, shape): + rank = len(shape) + offsets = [] + stride = 1 + for i in reversed(range(rank)): + idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) + offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}'] + stride *= shape[i] + return f"{name} + {' + '.join(offsets)}" + + +# TODO: handle `%4 = ttg.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>`` +@pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16']]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_index1d(expr, dtype_str, num_ctas, device): + rank_x = expr.count(':') + rank_y = expr.count(',') + 1 + shape_x = [32 for _ in range(rank_x)] + shape_z = [32 for _ in range(rank_y)] + shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)] + shape_z_dim_mismatch = [64 for _ in range(rank_y)] + + # Triton kernel + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + m = tl.arange(0, SIZE) + n = tl.arange(0, SIZE) + x = tl.load(X_PTR_EXPR) + z = GENERATE_TEST_HERE + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + 'GENERATE_TEST_HERE': expr, + } + return patch_kernel(kernel, to_replace) + + kernel_match = generate_kernel(shape_x, shape_z) + kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch) + kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch) + + # torch result + x = numpy_random(shape_x, dtype_str=dtype_str) + y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) + z_ref = eval(expr) + y + # triton result + z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x, device=device) + kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + # compare + assert (z_ref == to_numpy(z_tri)).all() + + def catch_compilation_error(kernel): + try: + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0], num_ctas=num_ctas) + except triton.CompilationError as e: + np.testing.assert_(True) + except BaseException: + np.testing.assert_(False) + + catch_compilation_error(kernel_dim_mismatch) + catch_compilation_error(kernel_rank_mismatch) + + +# --------------- +# test tuples +# --------------- + + +@triton.jit +def tuples_fn(a, b): + return a + b, \ + a - b, \ + a * b + + +@pytest.mark.interpreter +def test_tuples(device): + + @triton.jit + def with_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = tuples_fn(x, y) + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + @triton.jit + def without_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = x + y, x - y, x * y + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + x = torch.tensor([1.3], device=device, dtype=torch.float32) + y = torch.tensor([1.9], device=device, dtype=torch.float32) + a_tri = torch.tensor([0], device=device, dtype=torch.float32) + b_tri = torch.tensor([0], device=device, dtype=torch.float32) + c_tri = torch.tensor([0], device=device, dtype=torch.float32) + for kernel in [with_fn, without_fn]: + kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1) + a_ref, b_ref, c_ref = x + y, x - y, x * y + assert a_tri == a_ref + assert b_tri == b_ref + assert c_tri == c_ref + + +@triton.jit(noinline=True) +def noinline_simple_fn(x, y, Z): + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_graph_fn1(x): + return x + 1 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn2(y): + return y + 2 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn(x, y, Z): + t0 = noinline_call_graph_fn1(x) + t1 = noinline_call_graph_fn2(y) + z = t0 + t1 + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_shared_fn(x, y, Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + x + y + tl.store(Z + offs, z) + + +@triton.jit(noinline=True) +def noinline_dynamic_fn(x, y, Z): + if x >= 1: + x = noinline_call_graph_fn1(x) + else: + x = noinline_call_graph_fn2(x) + if y >= 2: + y = noinline_call_graph_fn2(y) + else: + y = noinline_call_graph_fn1(y) + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_multi_values_fn(x, y): + return x + 1, y + 2 + + +@triton.jit(noinline=True) +def noinline_multi_values_fn(x, y, Z): + x, y = noinline_call_multi_values_fn(x, y) + z = x + y + tl.store(Z, z) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) +def test_noinline(mode, device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + GENERATE_TEST_HERE(x, y, Z) + + func_name = f'noinline_{mode}_fn' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': func_name}) + x = torch.tensor([1.0], device=device, dtype=torch.float32) + y = torch.tensor([2.0], device=device, dtype=torch.float32) + if mode == "shared": + z = torch.ones((16, 16), device=device, dtype=torch.float32) + else: + z = torch.tensor([0.0], device=device, dtype=torch.float32) + kernel[(1, )](x, y, z, num_warps=1) + if mode == "simple": + assert torch.equal(z, x + y) + elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values": + assert torch.equal(z, x + 1 + y + 2) + elif mode == "shared": + ref = torch.full((16, 16), 16, device=device, dtype=torch.float32) + assert torch.equal(z, ref + x + y) + + +# --------------- +# test atomics +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_x_str, mode, sem", + itertools.chain.from_iterable([[ + ('add', 'float16', mode, sem), + ('add', 'uint32', mode, sem), + ('add', 'int32', mode, sem), + ('add', 'float32', mode, sem), + ('add', 'uint64', mode, sem), + ('add', 'int64', mode, sem), + ('add', 'float64', mode, sem), + ('max', 'uint32', mode, sem), + ('max', 'int32', mode, sem), + ('max', 'float32', mode, sem), + ('max', 'uint64', mode, sem), + ('max', 'int64', mode, sem), + ('max', 'float64', mode, sem), + ('min', 'uint32', mode, sem), + ('min', 'int32', mode, sem), + ('min', 'float32', mode, sem), + ('min', 'uint64', mode, sem), + ('min', 'int64', mode, sem), + ('min', 'float64', mode, sem), + ] + for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] + for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) +def test_atomic_rmw(op, dtype_x_str, mode, sem, device): + check_type_supported(dtype_x_str, device) + if is_interpreter(): + if dtype_x_str == 'float16': + pytest.skip("Only test atomic float16 ops on GPU") + + n_programs = 5 + + # triton kernel + @triton.jit + def kernel(X, Z): + pid = tl.program_id(0) + x = tl.load(X + pid) + old = GENERATE_TEST_HERE + tl.static_assert(old.dtype == x.dtype) + + sem_arg = sem if sem is None else f'"{sem}"' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'}) + numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op] + max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min + min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max + neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op] + + # triton result + rs = RandomState(17) + x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str)) + if mode == 'all_neg': + x = -np.abs(x) + if mode == 'all_pos': + x = np.abs(x) + if mode == 'min_neg': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = -np.max(np.abs(x)) - 1 + if mode == 'max_pos': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = np.max(np.abs(x)) + 1 + x_tri = to_triton(x, device=device) + + z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device) + h = kernel[(n_programs, )](x_tri, z_tri) + # torch result + z_ref = numpy_op(x).astype(getattr(np, dtype_x_str)) + # compare + exact = op not in ['add'] + if exact: + assert z_ref.item() == to_numpy(z_tri).item() + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + sem_str = "acq_rel" if sem is None else sem + if not is_cuda(): + return + + assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_rmw_predicate(num_ctas, device): + + @triton.jit + def kernel(X): + val = tl.program_id(0) + if val < 64: + tl.atomic_max(X, val) + + x = torch.zeros((1, ), device=device, dtype=torch.int32) + kernel[(4096, )](x, num_ctas=num_ctas) + assert x.item() == 63 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, axis, num_ctas, dtype_x_str, check_return_val", + [(shape, axis, num_ctas, dtype_x_str, check_return_val) + for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] + for axis in [0, 1] + for num_ctas in num_ctas_list + for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64'] + for check_return_val in ([True, False] if is_hip() else [True])]) +def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device): + check_type_supported(dtype_x_str, device) + shape0, shape1 = shape + # triton kernel + + @triton.jit + def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr, + RETURN_VAL: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) + + if DTYPE == tl.float16: + # sum can have bad numerics when accumulating in float16. + # if we're dealing with float16, do the sum in float32. + x = x.to(tl.float32) + + z = tl.sum(x, axis=AXIS) + + if DTYPE == tl.float16: + z = z.to(DTYPE) + + if AXIS == 1: + old = tl.atomic_add(Z + off0, z) + if RETURN_VAL: + tl.store(OLD + off0, old) + else: + old = tl.atomic_add(Z + off1, z) + if RETURN_VAL: + tl.store(OLD + off1, old) + + rs = RandomState(17) + x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs) + z_shape = (shape0, ) if axis == 1 else (shape1, ) + z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs) + old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str)) + # reference results + if x.dtype == np.float16: + # do the sum in float32 to reduce numerical variation + z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype) + else: + z_ref = z + np.sum(x, axis=axis, keepdims=False) + old_ref = np.copy(z) + # triton result + x_tri = to_triton(x, device=device) + z_tri = to_triton(z, device=device) + old_tri = to_triton(old, device=device) + + def torch_to_triton_dtype(t): + if t == torch.float16: + return tl.float16 + return None + + kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), check_return_val, + num_ctas=num_ctas) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) + if check_return_val: + np.testing.assert_equal(old_ref, to_numpy(old_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str) + for size in [2, 4, 8, 32, 64, 128] + for num_ctas in num_ctas_list + for dtype_x_str in ['float16', 'float32']]) +def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device): + + @triton.jit + def kernel(X, val, NUM: tl.constexpr): + off = tl.arange(0, NUM) + offset = off[:, None] * NUM + off[None, :] + val = tl.load(val + offset) + tl.atomic_add(X + offset // 2, val) + + shape = (size // 2, size) + x = torch.zeros(shape, dtype=getattr(torch, dtype_x_str), device=device) + val = torch.randn((size**2), dtype=getattr(torch, dtype_x_str), device=device) + kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas) + ref = val[0::2] + val[1::2] + torch.testing.assert_close(ref, x.reshape(math.prod(shape))) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, idx_order, mask_step, num_ctas, dtype_x_str", + [(shape, idx_order, mask_step, num_ctas, dtype_x_str) + for shape in [(2, 2), (4, 4), (5, 5), (6, 6), (8, 8)] + for idx_order in ['increase', 'decrease', 'random_no_duplication', 'random'] + for mask_step in range(1, 5) + for num_ctas in num_ctas_list + for dtype_x_str in ['float16', 'float32']]) +def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas, dtype_x_str, device): + check_type_supported(dtype_x_str, device) + if is_interpreter(): + pytest.skip("not supported in the interpreter") + + @triton.jit + def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + x_idx = xoffset + tl.arange(0, XBLOCK)[:] + mask = x_idx < shape0 * shape1 + mask = mask and (x_idx % mask_step != 0) + idx_base = shape1 * (x_idx // shape1) + idx_offset = tl.load(idx_ptr + x_idx, mask) + in_elem = tl.load(in_ptr + x_idx, mask) + tl.atomic_add(out_ptr + (idx_offset + idx_base), in_elem, mask, sem='relaxed') + + shape0, shape1 = shape + idx_row = torch.arange(0, shape1, device=device) + if idx_order == 'increase': + idx = torch.stack([idx_row.repeat_interleave(i + 1)[:shape1] for i in range(shape0)]) + if idx_order == 'decrease': + idx = torch.stack([idx_row.flip(0).repeat_interleave(i + 1)[:shape1] for i in range(shape0)]) + if idx_order == 'random_no_duplication': + idx = torch.stack([torch.randperm(shape1, device=device) for _ in idx_row]) + if idx_order == 'random': + idx = torch.randint(0, shape1, size=(shape0, shape1), device=device) + + val = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device) + dst = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device) + + dst_ref = dst.clone() + + cnt = 0 + for i, row in enumerate(idx): + for j, elem in enumerate(row): + if cnt % mask_step != 0: + dst_ref[i][elem] += val[i][j] + cnt += 1 + + kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas) + np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_rmw_block(num_ctas, device): + shape = (8, 8) + + @triton.jit + def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + offs = off0[:, None] * SHAPE1 + off1[None, :] + val = offs.to(tl.float32) + x = X + offs + tl.atomic_min(x, val) + + x = torch.ones((8, 8), device=device, dtype=torch.float32) + kernel[(2, )](x, shape[0], shape[1], num_ctas=num_ctas) + assert torch.min(x).item() == 0.0 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_cas(sem, num_ctas, device): + # 1. make sure that atomic_cas changes the original value (Lock) + @triton.jit + def change_value(Lock): + tl.atomic_cas(Lock, 0, 1) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + change_value[(1, )](Lock) + + assert (Lock[0] == 1) + + # 2. only one block enters the critical section + @triton.jit + def serialized_add(data, Lock, SEM: tl.constexpr): + ptrs = data + tl.arange(0, 128) + while tl.atomic_cas(Lock, 0, 1, SEM) == 1: + pass + + tl.store(ptrs, tl.load(ptrs) + 1.0) + + # insert barrier to set a fence between tl.store and + # tl.atomic_xchg in a block. + tl.debug_barrier() + + # release lock + tl.atomic_xchg(Lock, 0) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + data = torch.zeros((128, ), device=device, dtype=torch.float32) + ref = torch.full((128, ), 2000.0) + h = serialized_add[(2000, )](data, Lock, SEM=sem, num_ctas=num_ctas) + sem_str = "acq_rel" if sem is None else sem + np.testing.assert_allclose(to_numpy(data), to_numpy(ref)) + if not is_cuda(): + return + assert f"atom.global.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_cas(sem, num_ctas, device): + + @triton.jit + def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + t1 = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64) + t2 = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64) + tl.atomic_cas(X + offsets, t1, t2, sem=sem) + + X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64) + Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=torch.int64) + + change_value[(2, )](X, 4, sem) + assert (torch.equal(X, Y)) + + +@pytest.mark.interpreter +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, + reason="Requires compute capability >= 9 for NV") +def test_load_scope_sem_coop_grid_cta_not_one(device): + + @triton.jit + def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr): + numel = 512 + offset = tl.program_id(0) * BLOCK_SIZE + index = offset + mask = index < numel + a = tl.load(ptrs, mask=mask) + tl.store(ptrs, a) + + block_size = 128 + data = torch.zeros((128, ), device=device, dtype=torch.float32) + + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=True) + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=False) + + +@pytest.mark.interpreter +def test_load_scope_sem_coop_grid_cta_one(device): + + @triton.jit + def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr): + numel = 512 + offset = tl.program_id(0) * BLOCK_SIZE + index = offset + mask = index < numel + a = tl.load(ptrs, mask=mask) + tl.store(ptrs, a) + + block_size = 128 + data = torch.zeros((128, ), device=device, dtype=torch.float32) + + # Should do nothing different for num_ctas=1 (with coop launch grid) + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=True) + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=False) + + +# --------------- +# test cast +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", + [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ + ('float32', 'bfloat16', False, 1024), + ('bfloat16', 'float32', False, 1024), + ('float32', 'int32', True, 1024), + ('float32', 'bool', False, 1024), + ('int8', 'bfloat16', False, 1024), + ] + [(f'uint{x}', f'int{x}', True, 1024) + for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) + for x in [8, 16, 32, 64]] + + (([(dtype_x, dtype_z, False, size) + for dtype_x in torch_float8_dtypes + for dtype_z in ["float16", "float32", "bfloat16"] + for size in [1024, 32]] # + + [(dtype_x, dtype_z, False, size) + for dtype_z in torch_float8_dtypes + for dtype_x in ["float16", "float32", "bfloat16"] + for size in [1024, 32]]) if torch.__version__ >= "2.1" else [])) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): + # CUDA: bfloat16 on cc < 80 will not be tested + # Interpreter: Only bfloat16 <-> float32 is supported + if not is_interpreter() or \ + (is_interpreter() and not ((dtype_z == 'bfloat16' and dtype_x == 'float32') + or (dtype_z == 'float32' and dtype_x == 'bfloat16'))): + check_type_supported(dtype_x, device) + check_type_supported(dtype_z, device) + + if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"): + pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') + + torch.manual_seed(0) + # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. + if dtype_x.startswith('bfloat'): + x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device) + elif dtype_x.startswith('float8'): + x_tri = torch.randn(size, dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_x)) + else: + x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10 + # Triton clamps negative values to zero, while numpy wraps around + # intmax, so avoid negatives for now. + # TODO: figure out which one should actually be happening, and test it + if dtype_z in uint_dtypes: + x = np.absolute(x) + x_tri = to_triton(x, device=device) + if 'float' in dtype_z and 'float' in dtype_x: + # make sure we use values that can be represented in both types + x_tri = x_tri.to(getattr(torch, dtype_z)).to(getattr(torch, dtype_x)) + # triton kernel + + @triton.jit + def kernel(X, Z, TO_TYPE: tl.constexpr, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): + x_ptr = X + tl.arange(0, SIZE) + z_ptr = Z + tl.arange(0, SIZE) + x = tl.load(x_ptr) + + # Depending on the value of ARG_HASH (a "random" number determined by + # the test parameters), spell the cast one of three different ways. + if ARG_HASH % 4 == 0: + z = x.to(Z.dtype.element_ty, bitcast=BITCAST) + elif ARG_HASH % 4 == 1: + z = x.cast(Z.dtype.element_ty, bitcast=BITCAST) + elif ARG_HASH % 4 == 2: + z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST) + else: + z = tl.cast(x, TO_TYPE, bitcast=BITCAST) + + tl.store(z_ptr, z) + + # "Random" number used inside the kernel to determine how we spell the cast. + # This way we don't have to increase the number of tests. + arg_hash = hash((dtype_x, dtype_z, bitcast, size, num_ctas)) + + dtype_z_np = dtype_z if dtype_z != 'bool' else 'bool_' + # triton result + if dtype_z.startswith('bfloat'): + z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device) + elif dtype_z.startswith('float8'): + z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) + else: + z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) + + dtype_z_tri = str_to_triton_dtype(dtype_z) + kernel[(1, )](x_tri, z_tri, TO_TYPE=dtype_z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, + num_ctas=num_ctas) + # torch result + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( + 'float8') or dtype_x.startswith('float8'): + assert bitcast is False + z_ref = x_tri.to(z_tri.dtype) + if dtype_z.startswith('float8') and device not in ['cuda']: + t = z_ref.byte() ^ z_tri.byte() + torch.testing.assert_close(torch.zeros_like(t, dtype=torch.uint8), t) + else: + torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0) + else: + if bitcast: + z_ref = x.view(getattr(np, dtype_z_np)) + else: + z_ref = x.astype(getattr(np, dtype_z_np)) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, num_warps", + [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) +def test_cat(dtype_str, num_warps, device): + check_type_supported(dtype_str, device) + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.cat(x, y, can_reorder=True) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(getattr(torch, dtype_str)) + y = torch.arange(-128, 0, device=device).to(getattr(torch, dtype_str)) + z_ref = torch.cat([x, y], dim=0).sum() + z = torch.zeros((256, ), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](x, y, z, N=128, num_warps=num_warps) + assert z.sum() == z_ref + # check if there's no duplicate value in z + assert z.unique().size(0) == z.size(0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", list(torch_dtypes)) +@pytest.mark.parametrize("constant_field", ["value", "mask"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant(num_ctas, dtype_str, constant_field, device): + check_type_supported(dtype_str, device) + + @triton.jit + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if CONSTANT_FIELD == "value": + value = 1 + output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) + mask = offsets < n_elements + elif CONSTANT_FIELD == "mask": + output = offsets < n_elements + mask = False + tl.store(output_ptr + offsets, output, mask=mask) + + block_size = 128 + ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) + output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field) + + if constant_field == "value": + assert torch.all(output == ref) + else: + assert torch.all(output == 0) + + +def test_load_store_same_ptr(device): + + @triton.jit() + def kernel(in_out_ptr): + pid = tl.program_id(axis=0) + x = tl.load(in_out_ptr + pid) + out = x * 2 + tl.store(in_out_ptr + pid, out) + + for _ in range(1000): + x = torch.ones((65536, ), device=device, dtype=torch.float32) + if is_hip(): + kernel[(65536, )](x, num_warps=16) # threads per Warp for ROCM is 64 + else: + kernel[(65536, )](x, num_warps=32) + assert torch.all(x == 2) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['int32']) +def test_umulhi(dtype_str, device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.umulhi(x, y) + tl.store(Z + tl.arange(0, N), z) + + def umulhi32(a, b): + # Convert to 64-bit unsigned integers to prevent overflow + a_64 = a.astype(np.int64) + b_64 = b.astype(np.int64) + + # Perform the multiplication in 64-bit + product_64 = a_64 * b_64 + + # Shift right by 32 bits to get the high part of the product + result_high_32 = product_64 >> 32 + return result_high_32 + + rs = RandomState(17) + N = 128 + x = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + x_tri = to_triton(x, device=device) + y = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + y_tri = to_triton(y, device=device) + z_tri = torch.zeros_like(x_tri) + kernel[(1, )](x_tri, y_tri, z_tri, N=N) + + z_ref = umulhi32(x, y) + np.testing.assert_equal(z_ref, to_numpy(z_tri)) + + +@pytest.mark.interpreter +def test_join(device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.join(x, y) + tl.store(Z + tl.arange(0, N)[:, None] * 2 + tl.arange(0, 2)[None, :], z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(-128, 0, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, y, z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + z = tl.join(x, y) + tl.static_assert(z.shape == [2]) + tl.store(Z + tl.arange(0, 2), z) + + x = torch.full([1], 42, device=device).to(torch.int32) + y = torch.full([1], 100, device=device).to(torch.int32) + z = torch.zeros([2], device=device) + kernel[(1, )](x, y, z) + + np.testing.assert_equal([42, 100], to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_with_mma(device): + + @triton.jit + def kernel(X, Z): + x = tl.load(X + 16 * tl.arange(0, 32)[:, None] + tl.arange(0, 16)[None, :]) # (32,16) + x2 = tl.join(x, 2 * x) # (32,16,2) + x3 = tl.reshape(x2, (32, 32)) + z = tl.dot(x3, x3) # (32,32) + tl.store(Z + 32 * tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :], z) + + x = torch.arange(0, 32 * 16, device=device, dtype=torch.float32).reshape((32, 16)) + r = torch.stack([x, 2 * x], dim=-1).reshape((32, 32)) + z_ref = torch.matmul(r, r) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, z) + + torch.testing.assert_close(z, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("debug", [False, True]) +def test_interleave(device, debug): + + @triton.jit(debug=debug) + def kernel(Z, N: tl.constexpr): + z = tl.interleave(tl.arange(0, N), tl.arange(N, 2 * N)) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(128, 256, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1).reshape(256) + z = torch.zeros_like(z_ref) + kernel[(1, )](z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_interleave_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + z = tl.interleave(X, Y) + tl.static_assert(z.shape == [tl.constexpr(2)]) + tl.store(Z + tl.arange(0, 2), z) + + z = torch.zeros(2, device=device) + kernel[(1, )](10, 20, z) + + np.testing.assert_equal([10, 20], to_numpy(z)) + + +@pytest.mark.interpreter +def test_split(device): + + @triton.jit + def kernel(X, Z1, Z2, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + x1 = tl.reshape(x, (N // 2, 2)) + z1, z2 = tl.split(x1) + tl.store(Z1 + tl.arange(0, N // 2), z1) + tl.store(Z2 + tl.arange(0, N // 2), z2) + + x = torch.arange(0, 256, device=device).to(torch.int32).reshape((128, 2)) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2, N=256) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +@pytest.mark.interpreter +def test_split_to_scalar(device): + + @triton.jit + def kernel(X, Z1, Z2): + offs = tl.arange(0, 2) + x = tl.load(X + offs) + z1, z2 = tl.split(x) + tl.static_assert(isinstance(z1, tl.tensor)) + tl.static_assert(isinstance(z2, tl.tensor)) + tl.static_assert(z1.shape == []) + tl.static_assert(z2.shape == []) + tl.store(Z1, z1) + tl.store(Z2, z2) + + N = 2 + x = torch.arange(0, N, device=device).reshape(N // 2, 2) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +def convert_float_to_float32(fp: torch.tensor, dtype=None): + if not dtype: + dtype = getattr(tl, torch_dtype_name(fp.dtype)) + + fp = fp.view(getattr(torch, f"int{dtype.primitive_bitwidth}")) + exp_width = dtype.primitive_bitwidth - dtype.fp_mantissa_width - 1 + exp_bias = dtype.exponent_bias + sign = ((fp >> (dtype.primitive_bitwidth - 1)) & 0x01).int() + exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int() + frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int() + + output = torch.where( + exp == 0, + # subnormal + ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (frac / (2.0**dtype.fp_mantissa_width)), + # normal + ((-1.0)**sign) * (2.0**(exp - exp_bias)) * (1.0 + frac / (2.0**dtype.fp_mantissa_width))).float() + + extended_exp = ( + (1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width + # special cases, exp is 0b11..1 + if dtype in [tl.float8e4nv, tl.float8e4b15]: + # float8e4m3nv does not have infinities + output[fp == 0b01111111] = torch.nan + output[fp == 0b11111111] = torch.nan + else: + output = torch.where(exp == (1 << exp_width) - 1, + ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp + | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))) # + .view(torch.float32), output) + return output + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16]) +def test_convert_float16_to_float32(in_dtype, device): + """Tests that check convert_float_to_float32 function""" + check_type_supported(in_dtype, device) + + f16_input = torch.tensor(range(-int(2**(16 - 1)), int(2**(16 - 1))), dtype=torch.int16).view(in_dtype) + f32_output = convert_float_to_float32(f16_input) + + nan = f16_input.isnan() + assert torch.all(f32_output[nan].isnan()) + inf = f16_input.isinf() + assert torch.all(f32_output[inf].isinf()) + other = torch.logical_not(torch.logical_or(nan, inf)) + assert torch.all(f16_input[other] == f32_output[other]) + + +def serialize_fp8(np_data, in_dtype): + return np_data + + +# inverse of `serialize_fp8` + + +def deserialize_fp8(np_data, in_dtype): + return np_data + + +# --------------- +# test reduce +# --------------- + + +@pytest.mark.interpreter +def test_max_returns_zero(device): + # Simple test with a tl.max call that returns 0. The interpreter had a bug + # where it didn't handle this correctly. + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + z = tl.max(x) + tl.store(Z, z) + + BLOCK = 128 + x = torch.zeros((BLOCK, ), device=device) + z = torch.ones((1, ), device=device) + + kernel[(1, )](x, z, BLOCK=BLOCK) + assert z[0] == 0 + + +def get_reduced_dtype(dtype_str, op): + if op in ('argmin', 'argmax'): + return 'int32' + if dtype_str == 'bfloat16': + return 'float32' + return dtype_str + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [ + 'min', + 'max', + 'min-with-indices', + 'max-with-indices', + 'argmin-tie-break-left', + 'argmax-tie-break-left', + 'sum', +] for dtype in dtypes_with_bfloat16 for shape in [32, 64, 128, 512]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce1d(op, dtype_str, shape, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + # triton kernel + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + GENERATE_TEST_HERE + tl.store(Z, z) + + if 'with-indices' in op: + patch = f'z, _ = tl.{op.split("-")[0]}(x, axis=0, return_indices=True)' + elif 'arg' in op: + tie_break_left = 'tie-break-left' in op + patch = f'z = tl.{op.split("-")[0]}(x, axis=0, tie_break_left={tie_break_left})' + else: + patch = f'z = tl.{op}(x, axis=0)' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs) + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + 'max-with-indices': np.max, + 'min-with-indices': np.min, + 'argmin-tie-break-left': np.argmin, + 'argmax-tie-break-left': np.argmax, + }[op] + if 'tie-break-left' in op: + x[3:10] = x[numpy_op(x)] + x_tri = to_triton(x, device=device) + # numpy result + z_dtype_str = 'int32' if 'tie-break-left' in op else dtype_str + z_tri_dtype_str = z_dtype_str + if 'tie-break-left' not in op and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + z_tri_dtype_str = 'bfloat16' + else: + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # triton result + z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas) + z_tri = to_numpy(z_tri) + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if 'tie-break-left' in op: + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + np.testing.assert_equal(x[z_ref], x[z_tri]) + else: + np.testing.assert_equal(z_ref, z_tri) + + +# TODO: [Qingyi] Fix argmin / argmax +reduce_configs1 = [(op, dtype, (1, 1024), axis, False) + for dtype in dtypes_with_bfloat16 + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [1]] + +# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory +# exceeds the limit of 99KB +reduce2d_shapes = [(2, 32), (4, 32), (4, 128)] +# TODO: fix and uncomment +# , (32, 64), (64, 128)] +if is_cuda() and 'V100' in torch.cuda.get_device_name(0): + reduce2d_shapes += [(128, 256) and (32, 1024)] + +reduce_configs2 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce2d_shapes + for axis in [0, 1]] + [(op, 'float32', [16, 32], None, False) for op in ['min', 'max', 'sum']] + +reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)] +reduce_configs3 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce3d_shapes + for axis in [0, 1, 2]] +invalid_config = [('sum', 'float32', (32, 32), axis, False) for axis in [2, 3]] +negative_config = [('sum', 'float32', (32, 32), -1, False)] +keep_dims_2d_configs = [(op, 'float32', (32, 32), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1]] + [(op, 'float32', (32, 32), None, True) for op in ['min', 'max', 'sum']] +keep_dims_3d_configs = [(op, 'float32', (32, 2, 16), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1, 2]] + [(op, 'float32', (32, 2, 16), None, True) + for op in ['min', 'max', 'sum']] +reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + + negative_config + keep_dims_2d_configs + keep_dims_3d_configs + reduce_bool) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, + AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr, USE_I1: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + range_k = tl.arange(0, BLOCK_K) + if IS_3D: + x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + + range_k[None, None, :]) + else: + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + if USE_I1: + x = tl.cast(x, tl.int1) + z = GENERATE_TEST_HERE + z_ptr = Z + if KEEP_DIMS and AXIS is None: + if IS_3D: + z_ptr = z_ptr[None, None, None, :] + else: + z_ptr = z_ptr[None, None, :] + if IS_3D: + if AXIS == 0: + z_ptr = Z + range_n[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 1 or AXIS == -2: + z_ptr = Z + range_m[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 2 or AXIS == -1: + z_ptr = Z + range_m[:, None] * BLOCK_N + range_n[None, :] + else: + if AXIS == 0: + z_ptr = Z + range_n + elif AXIS == 1 or AXIS == -1: + z_ptr = Z + range_m + if KEEP_DIMS and AXIS is not None: + z_ptr = tl.expand_dims(z_ptr, axis=AXIS) + tl.store(z_ptr, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS, keep_dims=KEEP_DIMS)'}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_tri = to_triton(x, device=device) + numpy_op = { + 'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax, 'xor_sum': + np.bitwise_xor.reduce + }[op] + z_dtype_str = get_reduced_dtype(dtype_str, op) + z_tri_dtype_str = z_dtype_str + if z_dtype_str == 'bool': + z_dtype_str = 'int8' + + # numpy result + # Silence numpy error on axis out of bounds, to give triton a chance to fail + np_axis = axis if axis is not None and axis < len(shape) else None + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_tri_dtype_str = 'bfloat16' + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + else: + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + + # triton result + z_shape = z_ref.shape + z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + BLOCK_K = 1 if len(shape) == 2 else shape[2] + IS_3D = bool(len(shape) == 3) + USE_I1 = dtype_str == 'bool' + if axis is not None and axis >= len(shape): + with pytest.raises(triton.TritonError): + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, USE_I1=USE_I1, num_ctas=num_ctas) + return + else: + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, USE_I1=USE_I1, num_ctas=num_ctas) + + z_tri = to_numpy(z_tri) + + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if op in ('argmin', 'argmax'): + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + z_ref_index = z_ref + z_tri_index = z_tri + if not keep_dims: + z_ref_index = np.expand_dims(z_ref, axis=axis) + z_tri_index = np.expand_dims(z_tri, axis=axis) + z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis) + z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis) + np.testing.assert_equal(z_ref_value, z_tri_value) + else: + np.testing.assert_equal(z_ref, z_tri) + + +scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)] + +scan_configs = [(op, type, shape, axis, reverse, num_warps) + for num_warps in [4, 16] + for type in ['int32', 'float32', 'bfloat16'] + for axis in [1, 0] + for reverse in [True, False] + for shape in scan2d_shapes + for op in ['cumsum', 'cumprod', 'get_first_element', 'linear_recurrence', 'cummax', 'roll']] +negative_config = [('cumsum', 'float32', (32, 32), -1, False, 4)] + + +def test_sum_dtype(device): + + @triton.jit + def kernel_dtype(out_ptr, init, in_dtype: tl.constexpr, out_dtype: tl.constexpr): + x = tl.full((32, 32), init, dtype=in_dtype) + x = tl.sum(x, dtype=out_dtype) + tl.store(out_ptr, x.to(tl.int32)) + + @triton.jit + def kernel_default_int(out_ptr): + x = tl.full((32, 32), 1, dtype=tl.int1) + x = tl.sum(x) + tl.store(out_ptr, x) + + @triton.jit + def kernel_default_float(out_ptr): + x = tl.full((32, 32), 1.0, dtype=tl.bfloat16) + x = tl.sum(x) + tl.store(out_ptr, x) + + out = torch.empty(1, dtype=torch.int32, device=device) + kernel_dtype[(1, )](out, init=1, in_dtype=tl.int1, out_dtype=None) + assert out[0] == 32 * 32 + + kernel_dtype[(1, )](out, init=1, in_dtype=tl.int1, out_dtype=tl.int1) + assert out[0] == 0 + + kernel_dtype[(1, )](out, init=7, in_dtype=tl.int8, out_dtype=tl.int8) + assert out[0] == (7 * 32 * 32) % 256 + + kernel_dtype[(1, )](out, init=1, in_dtype=tl.int32, out_dtype=None) + assert out[0] == 32 * 32 + + kernel_default_int[(1, )](out) + assert out[0] == 32 * 32 + + out = torch.empty(1, dtype=torch.bfloat16, device=device) + kernel_default_float[(1, )](out) + torch.testing.assert_close(out[0], torch.tensor(32 * 32, dtype=torch.bfloat16, device=device)) + + +@triton.jit +# trivial associative but not commutative function +def get_first_element(a, b): + return a + + +# Compute x_i = a_i * x_{i-1} + b_i +@triton.jit +def linear_recurrence(a1, b1, a2, b2): + return a1 * a2, b1 * a2 + b2 + + +@triton.jit +def cummax(v0, i0, v1, i1): + gt = v0 > v1 + return tl.where(gt, v0, v1), tl.where(gt, i0, i1) + + +@triton.jit +def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur): + return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config) +def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device): + check_type_supported(dtype_str, device) + if dtype_str == 'bfloat16': + if op == 'cummax': + pytest.skip("bfloat16 compare not supported before sm90") + if op == 'linear_recurrence': + pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy issues") + numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str + + # triton kernel + @triton.jit + def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + y = tl.load(Y + range_m[:, None] * BLOCK_N + range_n[None, :]) + GENERATE_TEST_HERE + tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z) + + if op == 'cumsum' or op == 'cumprod': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'z = tl.{op}(x, axis={axis}, reverse={reverse})'}) + elif op == 'get_first_element': + kernel = patch_kernel( + kernel, + {'GENERATE_TEST_HERE': f'z = tl.associative_scan(x, axis={axis}, combine_fn={op}, reverse={reverse})'}) + elif op == 'cummax': + rg = "range_m[:, None]" if axis == 0 else "range_n[None, :]" + rg = f"tl.broadcast_to({rg}.to(tl.int64), [BLOCK_M, BLOCK_N])" + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, {rg}), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + elif op == 'roll': + assert op == 'roll' + kernel = patch_kernel( + kernel, { + 'GENERATE_TEST_HERE': + f'_, z, _ = tl.associative_scan((1 + 0* x, 0 * x, x), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + else: + assert op == 'linear_recurrence' + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, y), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + # input + rs = RandomState(17) + if op == 'linear_recurrence' and dtype_str in int_dtypes: + # If the numbers are too large the op will overflow + # We sample numbers in -1, 0, 1 + x = rs.randint(-1, 2, shape, dtype=dtype_str) + y = rs.randint(-1, 2, shape, dtype=dtype_str) + else: + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + # y is just used in linear_recurrence + y = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_in = x + if reverse: + x_in = np.flip(x, axis) + z = np.empty_like(x) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + y_tri = to_triton(y, device=device, dst_type=dtype_str) + if op == 'cumsum' or op == 'cumprod': + numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] + z_ref = numpy_op(x_in, axis=axis).astype(getattr(np, numpy_dtype_str)) + if reverse: + z_ref = np.flip(z_ref, axis) + + elif op == 'cummax': + # NumPy does not have cummax + z = z.astype(np.int64) + z_ref = torch.cummax(torch.from_numpy(x_in.copy()), axis=axis).indices.numpy() + if reverse: + z_ref = x_in.shape[axis] - np.flip(z_ref, axis) - 1 + elif op == 'roll': + ROLL = 1 + z_ref = np.roll(x_in.copy(), ROLL, axis=axis) + if axis == 0: + z_ref[:ROLL] = 0 + else: + z_ref[:, :ROLL] = 0 + + if reverse: + z_ref = np.flip(z_ref, axis) + elif op == 'linear_recurrence': + # Simplify to the axis=1 case + x_ref = x.T if axis == 0 else x + y_ref = y.T if axis == 0 else y + if reverse: + x_ref = np.flip(x_ref, 1) + y_ref = np.flip(y_ref, 1) + + result = [] + for x_refi, y_refi in zip(x_ref, y_ref): + li = [] + acc = 0 + for xi, yi in zip(x_refi, y_refi): + acc = xi * acc + yi + li.append(acc) + result.append(li) + z_ref = np.array(result) + if reverse: + z_ref = np.flip(z_ref, 1) + + if axis == 0: + z_ref = z_ref.T + else: + assert op == 'get_first_element' + z_ref = x + if axis == 0: + if reverse: + z_ref[:-1] = x[-1] + else: + z_ref[1:] = x[0] + else: + if reverse: + z_ref[:, :-1] = x[:, -1:] + else: + z_ref[:, 1:] = x[:, 0:1] + + # triton result + # we don't cast the `fp32 = bf16 op bf16` result to bfloat16 to alleviate accuracy issues + z_tri = to_triton(z, device=device) + kernel[(1, )](x_tri, y_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) + + z_tri = to_numpy(z_tri) + # compare + if dtype_str not in int_dtypes: + if op == 'cumprod': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01, atol=1e-3) + else: + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + np.testing.assert_equal(z_ref, z_tri) + + +# --------------- +# test histogram +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]]) +def test_histogram(M, N, device): + + @triton.jit + def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + offset1) + z = tl.histogram(x, N) + bias = tl.full([M, N], 1, dtype=tl.int32) + # check that histogram produces object compatible with broadcasting + biased = z + bias + tl.store(z_ptr + offset2, z) + + torch.manual_seed(17) + x = torch.randint(0, N, (M, ), device=device, dtype=torch.int32) + z = torch.empty(N, dtype=torch.int32, device=device) + # torch.histc does not work when the input type is not float and the device is CPU + # https://github.com/pytorch/pytorch/issues/74236 + # This is a workload by converting the input to float + z_torch = torch.histc(x.float(), bins=N, min=0, max=N - 1) + histogram_kernel[(1, )](x, z, M=M, N=N) + assert (z_torch == z).all() + + +@pytest.mark.parametrize("M, N", [(1, 64), (2, 32), (4, 16), (8, 8), (16, 4), (32, 2), (64, 1)]) +def test_scan_1d(M, N, device): + + @triton.jit + def scan_kernel(out_ptr, in_ptr, M: tl.constexpr, N: tl.constexpr): + input = tl.load(in_ptr + tl.arange(0, M)) + output = tl.cumsum(input).reshape([1, M]).broadcast_to([N, M]) + tl.store(out_ptr + tl.arange(0, M * N), output.reshape([M * N])) + + x = torch.randint(-100, 100, (M, ), dtype=torch.int32, device=device) + output = torch.empty(M * N, dtype=torch.int32, device=device) + + scan_kernel[(1, )](output, x, M, N) + + ref = torch.cumsum(x, dim=0).reshape([1, M]).broadcast_to([N, M]).reshape([M * N]) + torch.testing.assert_close(ref.to(torch.int32), output, atol=0, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['sum', 'max', 'min']) +@pytest.mark.parametrize("BLOCK_N", [32, 64, 128]) +@pytest.mark.parametrize("N", [512, 1024, 2048]) +@pytest.mark.parametrize("num_pid_n", [2, 4]) +def test_optimize_thread_locality(op, BLOCK_N, N, num_pid_n, device): + + @triton.jit + def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + start_m = tl.program_id(0) + pid_n = tl.program_id(1) + num_pid_n = tl.num_programs(1) + local = INITIALIZE_PATCH + off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), num_pid_n): + off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * N + off_n[None, :] + x = tl.load(Xs) + local = ACCUMULATE_PATCH + tl.store(Y + off_m * num_pid_n + pid_n, local) + + initialize_patch = { + 'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)', + 'max': 'tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)', + 'min': 'tl.full([BLOCK_M], float("inf"), dtype=tl.float32)', + }[op] + reduce_patch = { + 'sum': 'local + tl.sum(x, axis=1)', + 'max': 'tl.maximum(local, tl.max(x, axis=1))', + 'min': 'tl.minimum(local, tl.min(x, axis=1))', + }[op] + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + }[op] + kernel = patch_kernel(kernel, {'ACCUMULATE_PATCH': reduce_patch, 'INITIALIZE_PATCH': initialize_patch}) + torch.manual_seed(0) + BLOCK_M = 32 + x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device) + y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device) + h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N) + if not is_interpreter(): + assert h.asm['ttgir'].count( + '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" + y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True) + y_tri = numpy_op(y.cpu().numpy(), axis=1, keepdims=True) + np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3) + + +scan_layouts = [ + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) +@pytest.mark.parametrize("src_layout", scan_layouts) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("add_overflow_check", [False, True]) +def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path): + + overflow_check = """ + %17 = arith.extsi %arg2 : i32 to i64 + %18 = arith.extsi %arg3 : i32 to i64 + %19 = arith.addi %17, %18 : i64 + %i32.min = arith.constant -2147483648: i64 + %i32.max = arith.constant 2147483647: i64 + %20 = arith.cmpi slt, %19, %i32.max : i64 + %21 = arith.cmpi sge, %19, %i32.min : i64 + %22 = arith.andi %20, %21 : i1 + tt.assert %22, "overflow detected" : i1 + """ + + ir = f""" + #blocked = {src_layout} + module attributes {{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %7 = tt.broadcast %4 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %8 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #blocked> + %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{ + ^bb0(%arg2: i32, %arg3: i32): + %16 = arith.addi %arg2, %arg3 : i32{overflow_check if add_overflow_check else ""} + tt.scan.return %16 : i32 + }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %14 = tt.broadcast %13 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + tt.store %15, %11 : tensor<{M}x{N}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + + temp_file = tmp_path / "test_scan_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + rs = RandomState(17) + x = rs.randint(-100, 100, (M, N)).astype('int32') + + z = np.zeros((M, N)).astype('int32') + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + kernel[(1, 1, 1)](x_tri, z_tri) + + z_ref = np.cumsum(x, axis=axis) + + np.testing.assert_equal(z_ref, z_tri.cpu().numpy()) + + +layouts = [ + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[2, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 16, 16]), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=True), + WmmaLayout(version=1, warps_per_cta=[4, 1]), + WmmaLayout(version=1, warps_per_cta=[1, 4]), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 4], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + DotOperandLayout(parent=MmaLayout([3, 0], [8, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2), + # FIXME: Do not enable these tests until the SLPVectorizor problem with nvptx target has been resolved + # SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 1, 4], [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2])), + # SliceLayout(dim=0, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 4, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2])), + SliceLayout(dim=0, parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8])), + SliceLayout( + dim=1, parent=DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), + op_idx=1, k_width=2)), + LinearLayout(register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], + [0, 8]], warp=[[32, 0], [0, 32]], + block=[]), +] + + +@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128], [32, 32], [16, 16]]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) +@pytest.mark.parametrize("dtype_str,add_overflow_check", [("int32", False), ("int32", True), ("float32", False), + ("float16", False)]) +@pytest.mark.parametrize("reduce_op", ["sum", "max"]) +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_overflow_check, reduce_op, device, + tmp_path: pathlib.Path): + if isinstance(src_layout, + (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): + pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") + if is_hip() and isinstance(src_layout, MfmaLayout) and ((M, N) == (128, 128)): + pytest.skip("Skipping test because it runs out of shared memory") + if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024: + pytest.skip("Skipping sum reduction on float16 due to accuracy issues") + if is_hip() and isinstance(src_layout, LinearLayout): + pytest.skip("FIXME: LinearLayout not supported on HIP") + + if isinstance(src_layout, MmaLayout) and src_layout.version == 3: + src_layout.instr_shape[2] = 16 if dtype_str == "float16" else 8 + + overflow_check = """ + %18 = arith.extsi %arg3 : i32 to i64 + %19 = arith.extsi %arg4 : i32 to i64 + %20 = arith.addi %18, %19 : i64 + %i32.min = arith.constant -2147483648: i64 + %i32.max = arith.constant 2147483647: i64 + %21 = arith.cmpi slt, %20, %i32.max : i64 + %22 = arith.cmpi sge, %20, %i32.min : i64 + %23 = arith.andi %21, %22 : i1 + tt.assert %23, "overflow detected" : i1 + """ + + ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] + arith_op = { + "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, # + "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"} + }[reduce_op][dtype_str] + numpy_op = {"max": np.max, "sum": np.sum}[reduce_op] + rdims_1d = f"{N}" if axis == 0 else f"{M}" + rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" + store_range = "%7" if axis == 0 else "%1" + warps = warps_per_cta(src_layout, [M, N]) + num_warps = np.prod(warps) + blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, num_warps // 4], [0, 1], [1, 1], [1, 1], [0, 1]) + one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [num_warps], [0], [1], [1], [0]) + + expanded_shape = f"1x{N}" if axis == 0 else f"{M}x1" + other_axis = 1 - axis + epilogue = { + "reduce1d": + f""" + %14 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked> + %16 = {GPU_DIALECT}.convert_layout %13 : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> + %17 = tt.expand_dims %16 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> -> tensor<{rdims_2d}x{ty}, #blocked> + tt.store %15, %17 : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + tt.return + }} + }} + """, "reduce2d": + f""" + %14 = "tt.reduce"(%13) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} + tt.reduce.return %17 : {ty} + }}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty} + tt.store %arg2, %14 : !tt.ptr<{ty}> + tt.return + }} + }} + """, "expand_reduce2d": + f""" + %14 = tt.expand_dims %13 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{expanded_shape}x{ty}, #src> + %15 = "tt.reduce"(%14) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} + tt.reduce.return %17 : {ty} + }}) {{axis = {other_axis} : i32}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>>) + %16 = ttg.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> + %17 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.store %17, %16 : tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.return + }} + }} + """ + }[epilogue_kind] + + ir = f""" + #blocked = {blocked} + #src = {src_layout} + #one_d_layout = {one_d_layout} + module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = tt.splat %arg1 : i32 -> tensor<{M}x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked> + %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked> + %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> + %7 = tt.expand_dims %6 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %9 = tt.broadcast %7 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src> + %13 = "tt.reduce"(%12) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} + tt.reduce.return %17 : {ty} + }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> + """ + epilogue + + temp_file = tmp_path / "test_reduce_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) + reduce2d = 'reduce2d' in epilogue_kind + z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1) + z = np.zeros(z_shape).astype(dtype_str) + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri) + z_ref = numpy_op(x) if reduce2d else numpy_op(x, axis=axis, keepdims=True) + + if dtype_str == 'float16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) +] + + +@pytest.mark.parametrize("M", [32, 64, 128, 256]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +def test_store_op(M, src_layout, device, tmp_path: pathlib.Path): + + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 1 : i32}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xf32, #src> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %6 = tt.expand_dims %5 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #src> + %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> + tt.store %8, %4 : tensor<{M}x1x!tt.ptr, #src> + tt.return + }} + }} + """ + + temp_file = tmp_path / "test_store_op.ttgir" + temp_file.write_text(ir) + store_kernel = triton.compile(str(temp_file)) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, 1)).astype('float32') + y = np.zeros((M, 1), dtype='float32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + + pgm = store_kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +layouts = [ + # TODO (lixun): Add MfmaLayout + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), +] + + +@pytest.mark.parametrize("M", [64, 128, 256]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("src_dim", [0, 1]) +@pytest.mark.parametrize("dst_dim", [0, 1]) +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path: pathlib.Path): + + ir = f""" + #dst = {dst_layout} + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %7 = {GPU_DIALECT}.convert_layout %3 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.store %6, %7 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.return + }} + }} + """ + temp_file = tmp_path / "test_convert1d.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, )).astype('int32') + y = np.zeros((M, ), dtype='int32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + pgm = kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@triton.jit +def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = weight_2 / new_weight + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + # [HIP] TO DO: some tests are flaky with the layout, so turn off them for now. + # BlockedLayout([1, 4], [1, THREADS_PER_WARP], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]) +] + + +@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]]) +@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("op", ["sum", "max"]) +@pytest.mark.parametrize("first_axis", [0, 1]) +def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathlib.Path): + + op_str = "" + if op == "sum": + op_str = """ + %13 = arith.addi %arg2, %arg3 : i32 + tt.reduce.return %13 : i32""" + elif op == "max": + op_str = """ + %13 = arith.cmpi "sgt", %arg2, %arg3 : i32 + %14 = arith.select %13, %arg2, %arg3 : i32 + tt.reduce.return %14 : i32""" + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src> + %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %5 = tt.broadcast %2 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %6 = tt.broadcast %4 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #src> + %11 = "tt.reduce"(%10) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = {first_axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>> + %12 = "tt.reduce"(%11) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = 0 : i32}} : (tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>) -> i32 + tt.store %arg1, %12 : !tt.ptr + tt.return + }} + }} + """ + temp_file = tmp_path / "test_chain_reduce.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, N)).astype('int32') + + z = np.zeros((1, )).astype('int32') + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, z_tri) + if op == "sum": + z_ref = np.sum(x) + elif op == "max": + z_ref = np.max(x) + + np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@pytest.mark.interpreter +def test_generic_reduction(device): + + @triton.jit + def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): + xindex = tl.arange(0, BLOCK) + x = tl.load(X + xindex) + mean = x + m2 = tl.zeros_like(x) + weight = tl.full(x.shape, 1, x.dtype) + (mean, m2, weight) = tl.reduce((mean, m2, weight), 0, _welford_combine) + tl.store(out_mean, mean) + tl.store(out_var, m2 / weight) + + SIZE = 512 + x = torch.rand(SIZE, device=device) + out_mean = torch.empty((), device=device) + out_var = torch.empty((), device=device) + + var_mean_kernel[(1, )](x, out_mean, out_var, BLOCK=SIZE) + + expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0) + torch.testing.assert_close(out_mean, expect_mean) + torch.testing.assert_close(out_var, expect_var) + + +# --------------- +# test permute +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) + # TODO: bfloat16 + for dtype in ['float8e4b15', 'float16', 'float32'] + for shape in [(64, 64), (128, 128)] + for perm in [(1, 0)]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_permute(dtype_str, shape, perm, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + if dtype_str == "float8e4b15" and (is_hip() or (is_cuda() and torch.cuda.get_device_capability() >= (9, 0))): + pytest.skip("float8e4b15 not supported on ROCm or CUDA >= 9.0") + if is_hip(): + if shape == (128, 128) and dtype_str == 'float32': + pytest.skip("TODO Out of LDS for float32 with shape 128x128") + + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + tl.store(Zs, tl.load(Xs)) + + # input + x = numpy_random(shape, dtype_str=dtype_str) + # triton result + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), + x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), + z_tri_contiguous.stride(1), BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + # numpy result + if dtype_str == 'float8e4b15': + ty = tl.float8e4b15 + z_ref = serialize_fp8(deserialize_fp8(x, ty).T.copy(), ty) + z_tri = z_tri.base + z_tri_contiguous = z_tri_contiguous.base + else: + z_ref = x.transpose(*perm) + # compare + np.testing.assert_allclose(to_numpy(z_tri), z_ref) + np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref) + + if not is_cuda(): + return + + # parse ptx to make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + ptx = pgm_contiguous.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 4), (16, 16)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1]))) +def test_trans_2d(dtype_str, shape, perm, device): + + @triton.jit + def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1: tl.constexpr, + ou_shape2: tl.constexpr, trans1: tl.constexpr, trans2: tl.constexpr): + in_offs = tl.arange(0, in_shape1)[:, None] * in_shape2 + tl.arange(0, in_shape2)[None, :] + ou_offs = tl.arange(0, ou_shape1)[:, None] * ou_shape2 + tl.arange(0, ou_shape2)[None, :] + tl.store(Out + ou_offs, tl.permute(tl.load(In + in_offs), (trans1, trans2))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 4)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1, 2, 3]))) +def test_trans_4d(dtype_str, shape, perm, device): + + @triton.jit + def kernel(In, Out, # + in_shape1: tl.constexpr, in_shape2: tl.constexpr, in_shape3: tl.constexpr, in_shape4: tl.constexpr, + ou_shape1: tl.constexpr, ou_shape2: tl.constexpr, ou_shape3: tl.constexpr, ou_shape4: tl.constexpr, + trans1: tl.constexpr, trans2: tl.constexpr, trans3: tl.constexpr, trans4: tl.constexpr): + in_ptr = tl.make_block_ptr( + base=In, + shape=(in_shape1, in_shape2, in_shape3, in_shape4), + strides=(in_shape4 * in_shape3 * in_shape2, in_shape4 * in_shape3, in_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(in_shape1, in_shape2, in_shape3, in_shape4), + order=(3, 2, 1, 0), + ) + out_ptr = tl.make_block_ptr( + base=Out, + shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + strides=(ou_shape4 * ou_shape3 * ou_shape2, ou_shape4 * ou_shape3, ou_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + order=(3, 2, 1, 0), + ) + tl.store(out_ptr, tl.load(in_ptr).permute((trans1, trans2, trans3, trans4))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm, num_warps=8) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# --------------- +# test dot +# --------------- + + +def convert_fp8_to_fp32(x, device, dtype_str): + if dtype_str == 'float8e4nv': + return torch.tensor(x, device=device).view(torch.float8_e4m3fn).to(torch.float32) + elif dtype_str == 'float8e5': + return torch.tensor(x, device=device).view(torch.float8_e5m2).to(torch.float32) + elif dtype_str == 'float8e4b8': + return torch.tensor(x, device=device).view(torch.float8_e4m3fnuz).to(torch.float32) + elif dtype_str == 'float8e5b16': + return torch.tensor(x, device=device).view(torch.float8_e5m2fnuz).to(torch.float32) + assert "Unsupported float8 dtype" + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +def get_test_dot_base_cases(): + return [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None) + for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] + for input_precision in ['tf32', 'tf32x3', 'ieee'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')] + if not (input_precision != 'ieee' and (in_dtype in ['float16']))] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +def get_test_dot_mixed_sizes_cases(): + available_kpack = [1, 2 if is_hip() else 1] + available_precision = ["tf32" if is_cuda() else "ieee"] + return [ + (*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack, None) + for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], + [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] + for input_precision in available_precision + for col_a in [True, False] + for col_b in [True, False] + for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', + 'float32'), ('float32', 'float32')] + for kpack in available_kpack + ] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #2370 +def get_test_dot_transposed_op_base_cases(): + return [(64, 64, 64, 4, col_a, col_b, 'none', 'ieee', 'float32', 'float32', 1, None) + for col_a in [True, False] + for col_b in [True, False]] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# Introduced in #2750 +def get_test_dot_h100_shortcut_cases(): + return [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32', 1, None)] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #3908 +def get_test_dot_mfma_edge_cases(): + if not is_hip_cdna(): + return [] + return [(16, 16, 8, 4, False, False, 'None', 'ieee', 'float32', 'float32', 1, None), + (32, 16, 8, 4, False, False, 'None', 'ieee', 'float16', 'float16', 1, None)] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #3370 +def get_test_dot_fp8_output_cases(): + return [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1, None) + for float8_type in ["float8e5", "float8e4nv"]] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #5406 +def get_test_dot_small_k_mfma_cases(): + if not is_hip_cdna(): + return [] + return [(32, 32, k_size, 4, False, False, 'None', 'ieee', in_dtype, out_dtype, 1, mma_nonk_size) + for k_size in [1, 2, 4, 8] + for in_dtype, out_dtype in [('float16', 'float32'), ('int8', 'int32')] + for mma_nonk_size in mma_nonk_sizes] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #4516 +def get_test_dot_small_mn_fma_cases(): + return [(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1, None) + for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]] + + +def get_test_dot_double_rate_cases(): + if not is_hip_cdna(): + return [] + return [(32, 32, 16, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None), + (32, 32, 16, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None), + (16, 16, 32, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None), + (16, 16, 32, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None)] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size", + get_test_dot_double_rate_cases() + \ + get_test_dot_base_cases() + \ + get_test_dot_mixed_sizes_cases() + \ + get_test_dot_transposed_op_base_cases() + \ + get_test_dot_h100_shortcut_cases() + \ + get_test_dot_mfma_edge_cases() + \ + get_test_dot_fp8_output_cases() + \ + get_test_dot_small_k_mfma_cases() + \ + get_test_dot_small_mn_fma_cases()) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size, + num_ctas, device): + if is_interpreter(): + if in_dtype == 'bfloat16': + pytest.skip("bfloat16 is not supported in the interpreter") + else: + if not is_hip() and (M < 16 or N < 16 or K < 16): + pytest.skip("small dots are supported only on HIP at the moment") + if is_cuda(): + capability = torch.cuda.get_device_capability() + + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + if capability[0] < 8: + if capability[1] == 0 and in_dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 75") + if input_precision != "ieee": + pytest.skip("Only test tf32 on devices with sm >= 80") + if capability[0] == 7: + if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: + pytest.skip("shared memory out of resource") + if out_dtype == 'float16': + # TODO: support out_dtype=float16 for tl.dot on V100 + pytest.skip("Only test out_dtype=float16 on devices with sm >=80") + if capability[0] < 9 and in_dtype == 'float8e4nv': + pytest.skip("float8e4nv not supported on sm <= 80") + + if is_hip(): + if in_dtype in ("float8e5", "float8e4nv") and not is_hip_mi350(): + pytest.skip(f"{in_dtype} only supported on mi350") + if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_mi300(): + pytest.skip(f"{in_dtype} only supported on mi300") + if not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_mi300())): + pytest.skip(f"{input_precision} not supported on HIP") + if kpack == 2 and in_dtype == 'int8' and K < 64: + pytest.skip("kpack too large for K") + if not is_hip() and kpack == 2: + pytest.skip("Skip duplicated tests on nv path") + + torch.backends.cuda.matmul.allow_tf32 = input_precision == "tf32" + + if num_ctas > 1 and in_dtype == 'int8': + # FIXME: mma v2 with num_ctas > 1 does not work + pytest.skip() + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, stride_wl, Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, + ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, INPUT_PRECISION: tl.constexpr, DO_SOFTMAX: tl.constexpr, + CHAIN_DOT: tl.constexpr, COL_A: tl.constexpr, COL_B: tl.constexpr, out_dtype: tl.constexpr = tl.float32): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_l = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk + Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn + Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + x = tl.load(Xs) + y = tl.load(Ys) + z = tl.dot(x, y, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + if ADD_MATRIX: + z += tl.load(Zs) + if ADD_ROWS: + ZRs = Z + off_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = Z + off_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z.to(tl.float32)).to(max.dtype) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + w = tl.load(Ws) + z = tl.dot(z.to(w.dtype), w, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + tl.store(Zs, z) + + # input + rs = RandomState(17) + if col_a: + x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T + else: + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + if col_b: + y = numpy_random((N, K), dtype_str=in_dtype, rs=rs).T + else: + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + w = numpy_random((N, N), dtype_str=in_dtype, rs=rs) + if 'int' not in in_dtype and 'float8' not in in_dtype: + x *= .1 + y *= .1 + if in_dtype == 'float32' and input_precision == "tf32": + x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') + y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') + w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') + x_tri = to_triton(x, device=device, dst_type=in_dtype) + y_tri = to_triton(y, device=device, dst_type=in_dtype) + w_tri = to_triton(w, device=device, dst_type=in_dtype) + # triton result + if out_dtype == 'int8': + z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs) + else: + z = 1 + numpy_random((M, N), dtype_str=in_dtype, rs=rs) * .1 + + z_tri = to_triton(z, device=device) + if epilogue == 'trans': + z_tri = torch.as_strided(z_tri, (M, N), [1, M]) + + if out_dtype == 'int8': + out_dtype = tl.int8 + elif out_dtype == 'float16' and epilogue != 'softmax': + # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will + # fail with the following error: 'llvm.fmul' op requires the same type + # for all operands and results + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + kern_kwargs = { + 'COL_A': col_a, 'COL_B': col_b, 'BLOCK_M': M, 'BLOCK_K': K, 'BLOCK_N': N, 'ADD_MATRIX': + epilogue == 'add-matrix', 'ADD_ROWS': epilogue == 'add-rows', 'ADD_COLS': epilogue == 'add-cols', 'DO_SOFTMAX': + epilogue == 'softmax', 'CHAIN_DOT': epilogue == 'chain-dot', 'INPUT_PRECISION': input_precision, 'num_warps': + num_warps, 'num_ctas': num_ctas, 'out_dtype': out_dtype + } + + if is_hip(): + kern_kwargs['kpack'] = kpack + if mma_nonk_size is not None: + kern_kwargs['matrix_instr_nonkdim'] = mma_nonk_size + + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, + w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs) + + # torch result + if in_dtype == 'int8': + z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32) + elif 'float8' in in_dtype: + x = convert_fp8_to_fp32(x, device, in_dtype) + y = convert_fp8_to_fp32(y, device, in_dtype) + z_ref = to_numpy(torch.matmul(x, y)) + else: + z_ref = np.matmul(x, y) + + if epilogue == 'add-matrix': + z_ref += z + if epilogue == 'add-rows': + z_ref += z[:, 0][:, None] + if epilogue == 'add-cols': + z_ref += z[0, :][None, :] + if epilogue == 'softmax': + num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True)) + denom = np.sum(num, axis=-1, keepdims=True) + z_ref = num / denom + if epilogue == 'chain-dot': + if 'float8' in in_dtype: + # Reduce z_ref's precision to fp8 to match the kernel behavior + if in_dtype == 'float8e4nv': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fn) + elif in_dtype == 'float8e5': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2) + elif in_dtype == 'float8e4b8': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fnuz) + elif in_dtype == 'float8e5b16': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2fnuz) + else: + assert "Unsupported float8 dtype" + z_ref = to_numpy(z_fp8.to(torch.float32)) + w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype)) + z_ref = np.matmul(z_ref, w) + # compare + if in_dtype == 'float32': + # XXX: Somehow there's a larger difference when we use float32 + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + elif out_dtype == tl.float16 or in_dtype == 'bfloat16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + # added atol, to loose precision for float16xfloat16->float32 case + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + if not is_cuda(): + return + # make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): + # XXX: skip small sizes because they are not vectorized + assert 'ld.global.v4' in ptx + if 'float8' in in_dtype: + assert 'st.global.v2' in ptx + else: + assert 'st.global.v4' in ptx + + is_tcgen5 = (capability[0] == 10) and (num_warps % 4) == 0 and (M % 64) == 0 and (N % 8) == 0 + + if in_dtype == 'float32' and input_precision != "ieee": + if is_tcgen5: + assert re.search(r'tcgen05.mma.cta_group::1.kind::tf32', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float32: + if is_tcgen5: + assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) + elif capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float16: + if is_tcgen5: + assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) + elif capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx) + elif in_dtype == 'int8': + if capability[0] == 7 and capability[1] == 5: # Turing + assert 'mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32' in ptx + else: + assert 'wgmma.mma_async.sync.aligned' in ptx or\ + 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx + elif in_dtype == "float8e5" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2' in ptx + elif in_dtype == "float8e4nv" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx + + +@pytest.mark.parametrize("M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, num_warps, mma, kpack", + [(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, 4, mma, kpack) + for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) + for col_a, col_b in itertools.product([True, False], repeat=2) + for rhs_scale in [False, True] + for mxfp_type in ["e2m1", "e4m3", "e5m2"] + for normal_type in ["e4m3", "e5m2", "bf16", "fp16"] + for mma in (mma_nonk_sizes if is_hip() else [16]) + for kpack in ([1, 2] if is_hip() else [1])]) +def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, num_warps, mma, kpack, device): + if is_cuda(): + cc = torch.cuda.get_device_capability() + if cc < (8, 9): + pytest.skip("float8e4nv not supported on CUDA < 8.9") + if is_hip(): + if not is_hip_cdna(): + pytest.skip("scaled_dot only implemented for HIP CDNA") + if "e4m3" in (mxfp_type, normal_type): + if not (is_hip_mi300() or is_hip_mi350()): + pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) only implemented for MI300 and MI350") + if mma == 16 and K == 64: + pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot") + + @triton.jit + def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, + type_b: tl.constexpr): + DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1 + DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1 + PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A + PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B + a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, + PACKED_BLOCK_K_A)[None, :] * stride_a1 + b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, + BLOCK_N)[None, :] * stride_b1 + + a = tl.load(a_ptr) + b = tl.load(b_ptr) + SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 + if a_scale is not None: + scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, + SCALE_BLOCK_K)[None, :] + a_scale = tl.load(scale_a_ptr) + if b_scale is not None: + scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, + SCALE_BLOCK_K)[None, :] + b_scale = tl.load(scale_b_ptr) + c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b) + out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + tl.store(out_ptr, c.to(tl.bfloat16)) + + @triton.jit + def mxfp_upcast_kernel( + x_ptr, + scale_ptr, + mxfp_ptr, + N, + e_bits: tl.constexpr, + m_bits: tl.constexpr, + to_type: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ): + # x.shape == (N, 32) for fp8 or (N, 16) for fp4 + # scale.shape == (N,) + # out.shape == (N, 32) + is_fp8: tl.constexpr = e_bits + m_bits == 7 + # fp8: BLOCK_SIZE -> BLOCK_SIZE // 32, 32 + # fp4: BLOCK_SIZE // 2 -> BLOCK_SIZE // 32 , 16 + PARALLEL_DIM: tl.constexpr = BLOCK_SIZE // 32 + LAST_DIM: tl.constexpr = 32 if is_fp8 else 16 + LOAD_SIZE: tl.constexpr = LAST_DIM * PARALLEL_DIM + + offsets = (tl.program_id(0) * LOAD_SIZE + tl.arange(0, PARALLEL_DIM)[:, None] * LAST_DIM + + tl.arange(0, LAST_DIM)[None, :]) + x = tl.load(x_ptr + offsets, mask=offsets < N * LAST_DIM) + + offsets = tl.program_id(0) * PARALLEL_DIM + tl.arange(0, PARALLEL_DIM)[:, None] + scale = tl.load(scale_ptr + offsets, mask=offsets < N) + tl.static_assert(scale.dtype == tl.uint8) + tl.static_assert(x.dtype == tl.uint8) + + if to_type == tl.bfloat16: + upcasted_scale = (scale.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True) + else: + tl.static_assert(to_type == tl.float16) + scale_fp32 = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + upcasted_scale = scale_fp32.to(tl.float16) + + to_e_bits: tl.constexpr = 8 if to_type == tl.bfloat16 else 5 + to_m_bits: tl.constexpr = 7 if to_type == tl.bfloat16 else 10 + if is_fp8: + if e_bits == 5 and m_bits == 2: + x_f8 = x.to(tl.float8e5, bitcast=True) + upcasted_x = x_f8.to(to_type) + # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! + non_finite_mask: tl.constexpr = ((1 << e_bits) - 1) << m_bits + non_finite_mask_16bit: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits + upcasted_x = tl.where( + x & non_finite_mask == non_finite_mask, + (upcasted_x.to(tl.uint16, bitcast=True) | non_finite_mask_16bit).to(to_type, bitcast=True), + upcasted_x, + ) + else: + tl.static_assert(e_bits == 4 and m_bits == 3) + x_f8 = x.to(tl.float8e4nv, bitcast=True) + upcasted_x = x_f8.to(to_type) + else: + to_bias: tl.constexpr = 127 if to_type == tl.bfloat16 else 15 + to_point5: tl.constexpr = 16128 if to_type == tl.bfloat16 else 0x3800 + # e2m1 + em0 = x & 0x7 + em1 = x & 0x70 + x0 = (em0.to(tl.uint16) << (to_m_bits - 1)) | ((x & 0x8).to(tl.uint16) << 12) + x1 = (em1.to(tl.uint16) << (to_m_bits - 1 - 4)) | ((x & 0x80).to(tl.uint16) << 8) + # Three cases: + # 1) x is normal and non-zero: Correct bias + x0 = tl.where((em0 & 0x6) != 0, x0 + ((to_bias - 1) << to_m_bits), x0) + x1 = tl.where((em1 & 0x60) != 0, x1 + ((to_bias - 1) << to_m_bits), x1) + # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16 + x0 = tl.where(em0 == 0x1, to_point5 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x10, to_point5 | (x1 & 0x8000), x1) + # 3) x is zero, do nothing + upcasted_x = tl.interleave(x0, x1).to(to_type, bitcast=True) + # Multiplication preserves infs and NaNs in upcasted_x + mxfp = upcasted_x * upcasted_scale + # If scale is NaN, we encode it as an inf, so we need to correct for that + mxfp = tl.where(scale == 0xFF, float("nan"), mxfp) + + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32) + + def dot_scale_ref(x, scale_x, y, scale_y, type_x, type_y): + + def upcast(v, scale, type, comp_dtype, transposed): + if scale is None: + type = { + "e4m3": torch.float8_e4m3fn, + "e5m2": torch.float8_e5m2, + "bf16": torch.bfloat16, + "fp16": torch.float16, + }[type] + return v.view(type).to(comp_dtype) + e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type] + # Packing is always on the K dimension so we transpose before upcasting then transpose back. + if transposed: + v = v.mT.contiguous() + v = v.contiguous() + v_upcast = v.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) + N = v_upcast.numel() + BLOCK_SIZE = 512 + grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) + comp_dtype = tl.float16 if comp_dtype == torch.float16 else tl.bfloat16 + mxfp_upcast_kernel[grid](v, scale, v_upcast, scale.numel(), e_bits, m_bits, comp_dtype, BLOCK_SIZE, + num_warps=num_warps) + assert v_upcast.isfinite().all() + if transposed: + v_upcast = v_upcast.mT + return v_upcast + + # Upcast to fp16 if one of the input is fp16 + comp_dtype = torch.float16 if "fp16" in (type_x, type_y) else torch.bfloat16 + + x_upcast = upcast(x, scale_x, type_x, comp_dtype, False) + y_upcast = upcast(y, scale_y, type_y, comp_dtype, True) + + class AccumulateInFp32: + + def __enter__(self): + self.prev_value = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value + + with AccumulateInFp32(): + return torch.matmul(x_upcast, y_upcast) + + comp_dtype = torch.float16 if normal_type == "fp16" else torch.bfloat16 + # The max exponent we use to initialize data in the x/y and associated scale tensor to avoid + # overflow when scaling. + comp_dtype_max_exp = 6 if normal_type == "fp16" else 15 + + torch.manual_seed(0) + + def make_arg(shape, ty, col_major=False): + if col_major: + shape = shape[:-2] + (shape[-1], shape[-2]) + if ty == "bf16" or ty == "fp16": + ret = torch.randn(shape, dtype=comp_dtype, device=device) + # Clamp to avoid relative error issues + ret.clamp_(-2**comp_dtype_max_exp, 2**comp_dtype_max_exp - 1) + else: + if is_hip_mi350(): + # On other chips, the A/B operands are upcasted to fp16/bf16 + # before matmul, which has larger range to avoid overflow. + # On MI350, we use the V_MFMA_*_F8F6F4 instructions to + # directly calculate matmul on F8F6F4 data. So we need + # to narrow down the range of input to avoid overflow. + ret = torch.randint(20, 40, shape, dtype=torch.uint8, device=device) + else: + ret = torch.randint(256, shape, dtype=torch.uint8, device=device) + if col_major: + ret = ret.mT + return ret + + type_a = normal_type if rhs_scale else mxfp_type + type_b = mxfp_type if rhs_scale else normal_type + + DIV_FACTOR_A = 2 if type_a == "e2m1" else 1 + DIV_FACTOR_B = 2 if type_b == "e2m1" else 1 + x = make_arg((M, K // DIV_FACTOR_A), type_a, col_major=col_a) + y = make_arg((K // DIV_FACTOR_B, N), type_b, col_major=col_b) + + min_scale, max_scale = (0, 142) if comp_dtype == torch.bfloat16 else (124, 131) + scale_x = torch.randint(min_scale, max_scale + 1, (M, K // 32), dtype=torch.uint8, device=device) + scale_y = torch.randint(min_scale, max_scale + 1, (N, K // 32), dtype=torch.uint8, device=device) + if rhs_scale: + scale_x = None + else: + scale_y = None + + def make_finite(x, dtype): + # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and + # Fp8E5M2_to_Bf16 doesn't preserve NaNs (fixme) + if dtype not in ("e5m2", "e4m3"): + return x + if dtype == "e5m2" and comp_dtype == torch.float16: + x = x & 0xB + mask = 0x7C if dtype == "e5m2" else 0x7F + finite = torch.arange(x.numel(), device=device, dtype=torch.uint8).reshape_as(x) % mask + x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x) + x.copy_(x_finite) + return x + + x = make_finite(x, type_a) + y = make_finite(y, type_b) + kernel_kwargs = {"num_warps": num_warps} + if is_hip(): + kernel_kwargs["kpack"] = kpack + kernel_kwargs["matrix_instr_nonkdim"] = mma + z = x.new_empty((M, N), dtype=comp_dtype) + pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b, + **kernel_kwargs) + z_ref = dot_scale_ref(x, scale_x, y, scale_y, type_a, type_b) + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and output denormal values + # to zero. Detailed info is at: + # https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + atol = 2e-4 if is_hip_mi200() else 1e-5 + rtol = 2e-2 if is_hip_mi200() else 1e-2 + torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) + + # make sure ld/st are vectorized + if is_cuda(): + ptx = pgm.asm['ptx'] + if (max(M, N) * K) // (num_warps * 32) >= 4: + assert 'ld.global.v4' in ptx + if M * N // (num_warps * 32) >= 4: + assert 'st.global.v4' in ptx + assert (re.search(r'(mma|wgmma.mma_async).sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.(f|bf)16.(f|bf)16', ptx) + or "tcgen05.mma.cta_group::1.kind::f16" in ptx) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str", + [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) + for B in [1, 2, 4, 8] + for num_warps in [1, 2, 4, 8, 16] + for BLOCK_M, BLOCK_N in [(32, 32)] + for M, N, K in [(64, 64, 64), (32, 32, 32)] + for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), + ('float16', 'float32'), ('float32', 'float32')]] + + # Large block sizes + [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')] + + # Small block sizes + [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) + for B in [1, 2, 8] + for num_warps in [1, 2, 4] + for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)] + for M, N, K in [(32, 32, 32)] + for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]]) +def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str, device): + if is_hip(): + # hip does not support tf32 precision, so use ieee for all tests + input_precision = "ieee" + arch = triton.runtime.driver.active.get_current_target().arch + if "gfx11" in arch or "gfx12" in arch: + if in_dtype_str == "float32": + pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") + if out_dtype_str == "float16": + pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") + else: + input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee" + if not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16): + pytest.skip("small dots are supported only on HIP at the moment") + + if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32": + if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties( + triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 131072: + pytest.skip( + "Skipping tests with B = 8, M = 64, in_type = float32, out_type = float32 due to insufficient shared memory (less than 128 KB per SM) on this GPU." + ) + + @triton.jit + def kernel( + q_ptr, + k_ptr, + o_ptr, + stride_qb, + stride_qm, + stride_qk, + stride_kb, + stride_kk, + stride_kn, + stride_ob, + stride_om, + stride_on, + BLOCK_B: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + INPUT_PRECISION: tl.constexpr, + out_dtype: tl.constexpr = tl.float32, + ): + startm = tl.program_id(0) * BLOCK_M + startn = tl.program_id(1) * BLOCK_N + offs_b = tl.arange(0, BLOCK_B) + offs_m = startm + tl.arange(0, BLOCK_M) + offs_n = startn + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + q_ptrs = q_ptr + offs_b[:, None, None] * stride_qb + offs_m[None, :, None] * stride_qm + offs_k[ + None, None, :] * stride_qk + k_ptrs = k_ptr + offs_b[:, None, None] * stride_kb + offs_k[None, :, None] * stride_kk + offs_n[ + None, None, :] * stride_kn + q = tl.load(q_ptrs) + k = tl.load(k_ptrs) + qk = tl.dot(q, k, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + o_ptrs = o_ptr + offs_b[:, None, None] * stride_ob + offs_m[None, :, None] * stride_om + offs_n[ + None, None, :] * stride_on + tl.store(o_ptrs, qk) + + if out_dtype_str == 'int8': + out_dtype = tl.int8 + elif out_dtype_str == 'float16': + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + rs = RandomState(17) + x = numpy_random((B, M, K), dtype_str=in_dtype_str, rs=rs) + y = numpy_random((B, K, N), dtype_str=in_dtype_str, rs=rs) + if in_dtype_str == 'int8': + out = numpy_random((B, M, N), dtype_str='int32', rs=rs) + else: + if is_hip() and (BLOCK_M < 16 or BLOCK_N < 16) and out_dtype_str == 'float16': + # float16 accumulator in FMA dot loose precision too fast + x *= 0.1 + y *= 0.1 + out = numpy_random((B, M, N), dtype_str=out_dtype_str, rs=rs) + + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + out_tri = to_triton(out, device=device) + + BLOCK_B = B + BLOCK_K = K + + grid = ( + triton.cdiv(M, BLOCK_M), + triton.cdiv(N, BLOCK_N), + ) + kernel[grid]( + x_tri, + y_tri, + out_tri, + x_tri.stride(0), + x_tri.stride(1), + x_tri.stride(2), + y_tri.stride(0), + y_tri.stride(1), + y_tri.stride(2), + out_tri.stride(0), + out_tri.stride(1), + out_tri.stride(2), + BLOCK_B=BLOCK_B, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + INPUT_PRECISION=input_precision, + out_dtype=out_dtype, + num_warps=num_warps, + ) + + if in_dtype_str == 'int8': + out_ref = np.matmul(x.astype(np.float32), y.astype(np.float32)).astype(np.int32) + else: + out_ref = np.matmul(x, y) + np.testing.assert_allclose(out_ref, to_numpy(out_tri), rtol=0.01, atol=1e-2) + + +@pytest.mark.parametrize('in_dtype', ['float32']) +def test_dot_mulbroadcasted(in_dtype, device): + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + pytest.skip("Requires sm >= 80 to run") + + @triton.jit + def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.constexpr, BN: tl.constexpr, + BK: tl.constexpr): + pidn = tl.program_id(1) + pidm = tl.program_id(0) + offm = tl.arange(0, BM)[:, None] + offn = tl.arange(0, BN)[None, :] + offak = tl.arange(0, BK)[None, :] + offbk = tl.arange(0, BK)[:, None] + acc = tl.full((BM, BN), 0.0, tl.float32) + for ridx5 in range(0, K // BK): + x = tl.load(X + ((pidm * K * BM) + (offm * K) + (ridx5 * BK) + offak)) + y = tl.load(Y + ((pidn * BN) + (offbk * N) + (ridx5 * N * BK) + offn)) + x = tl.expand_dims(x, axis=2) + y = tl.expand_dims(y, axis=0) + t = tl.sum(x * y, axis=1) + acc = t + acc + tl.store(Z + ((pidm * BM * N) + (pidn * BN) + (offm * N) + offn), acc) + + M, N, K = 256, 192, 160 + BM, BN, BK = 128, 32, 32 + rs = RandomState(17) + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + x = x * 0.1 + y = y * 0.1 + z = numpy_random((M, N), dtype_str=in_dtype, rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(z, device=device) + grid = M // BM, N // BN + h = kernel[grid](z_tri, x_tri, y_tri, M, N, K, BM, BN, BK) + z_ref = np.matmul(x, y) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), atol=0.01) + + if not is_cuda(): + return + assert "tt.dot" in h.asm['ttir'] + assert re.search(r"ttg.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) +@pytest.mark.parametrize("shape", [(), (1, ), (128, )]) +def test_full(dtype_str, shape, device): + if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): + # PyTorch only has unsigned 8, but not 16, 32, or 64 + dtype = getattr(torch, dtype_str[1:]) # uintx -> intx + else: + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel_static(out): + a = GENERATE_TEST_HERE + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + @triton.jit + def kernel_dynamic(out, val, dtype: tl.constexpr): + a = tl.full(SHAPE, val, dtype) + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + kernel_static_patched = patch_kernel(kernel_static, { + 'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})", + 'SHAPE': str(list(shape)), + }) + out_static = torch.zeros((128), dtype=dtype, device=device) + kernel_static_patched[(1, )](out_static) + assert torch.all(out_static == 2) + + kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) + out_dynamic = torch.zeros((128), dtype=dtype, device=device) + kernel_dynamic_patched[(1, )](out_dynamic, 2, getattr(triton.language, dtype_str)) + assert torch.all(out_dynamic == 2) + + +@pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), + ('float("-inf")', "f32"), ('float("nan")', "f32"), + ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) +def test_constexpr(literal, dtype_str, device): + + @triton.jit + def kernel(out_ptr): + val = GENERATE_TEST_HERE + tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val) + + kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"}) + out = torch.zeros((1, ), dtype=torch.float32, device=device) + h = kernel_patched[(1, )](out) + assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None + + +@triton.jit +def pass_const(a, b, choose_b): + if choose_b: + return b + else: + return a + + +@pytest.mark.parametrize("choose_const", [True, False]) +@pytest.mark.parametrize("constexpr", [True, False]) +@pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"]) +def test_const(device, choose_const, constexpr, mode): + + @triton.jit(do_not_specialize=["choose_const"]) + def kernel(in_ptr: tl.const, out, c_out: tl.const, choose_const, n_elems: tl.int32, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + @triton.jit + def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.constexpr, n_elems: tl.int32, + BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + if mode == "direct": + if choose_const: + LOSE_TAIL = "final_out = c_out" + else: + LOSE_TAIL = "final_out = out" + elif mode == "call": + LOSE_TAIL = "final_out = pass_const(out, c_out, choose_const)" + elif mode == "ternary": + LOSE_TAIL = "final_out = c_out if choose_const else out" + elif mode == "if": + LOSE_TAIL = """ + if choose_const: + final_out = c_out + else: + final_out = out +""" + + SIZE = 128 + input = torch.randn((SIZE, ), dtype=torch.float32, device=device) + output = torch.zeros((SIZE, ), dtype=torch.float32, device=device) + patched_kernel = patch_kernel(kernel_constexpr if constexpr else kernel, {'LOSE_TAIL': LOSE_TAIL, 'CONSTEXPR': ''}) + + expect_fail = (not constexpr and mode != "direct") or choose_const + if expect_fail: + with pytest.raises(triton.CompilationError) as exc_info: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + if constexpr: + error = "Cannot store to a constant pointer" + else: + if mode == "call": + error = "Inconsistent return types" + elif mode == "if": + error = "Mismatched type for final_out" + elif mode == "ternary": + error = "Ternary expression with dynamic condition has inconsistent type" + else: + assert mode == "direct" and choose_const + error = "Cannot store to a constant pointer" + error_msg = exc_info.value.error_message or str(exc_info.value.__cause__) + assert error in error_msg, "Wrong error message!" + else: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + assert torch.all(input == output) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['float32', 'float16']) +def test_dot_without_load(dtype_str, device): + + @triton.jit + def _kernel(out): + a = GENERATE_TEST_HERE + b = GENERATE_TEST_HERE + c = tl.dot(a, b) + out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) + a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + out_ref = torch.matmul(a, b) + out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](out) + assert torch.all(out == out_ref) + + +# --------------- +# test arange +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("start", [0, 1, 7, 16]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_arange(start, num_ctas, device): + BLOCK = 128 + z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): + off = tl.arange(0, BLOCK) + val = tl.arange(START, END) + tl.store(z + off, val) + + _kernel[(1, )](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) + z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) + np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref)) + + +# --------------- +# test load +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, size, size_diff, other) + for dtype_str in torch_dtypes + for size in [128, 512] + for size_diff in [0, 1, 2, 3, 4] + for other in [0, 1]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_masked_load(dtype_str, size, size_diff, other, num_ctas, device): + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + input_size = size - size_diff + output_size = size + if dtype_str == 'bool': + input = torch.randint(0, 2, (input_size, ), dtype=dtype, device=device) + elif dtype_str in int_dtypes or dtype_str in uint_dtypes: + input = torch.randint(0, 127, (input_size, ), dtype=dtype, device=device) + else: + input = torch.rand(input_size, dtype=dtype, device=device) + output = torch.zeros((output_size, ), dtype=dtype, device=device) + + @triton.jit + def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): + in_offsets = tl.arange(0, out_size) + # Load inputs. + x = GENERATE_TEST_HERE + # Store output + output_offsets = tl.arange(0, out_size) + tl.store(out_ptr + output_offsets, x) + + mask_str = f"mask=in_offsets < in_size, other={other}" if size_diff > 0 else "None" + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"}) + kernel[(1, )](input, output, input_size, output_size, num_ctas=num_ctas) + + reference_out = torch.cat((input, torch.full((size_diff, ), other, dtype=dtype, device=device))) + torch.testing.assert_close(output, reference_out) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +@pytest.mark.parametrize("mask_val", [True, False]) +@pytest.mark.parametrize("other_val", [0, 1]) +def test_masked_load_scalar(num_ctas, mask_val, other_val, device): + input_val = 4.0 + size = 128 + dtype = torch.float32 + input = torch.full((size, ), input_val, dtype=dtype, device=device) + output = torch.zeros((size, ), dtype=dtype, device=device) + + @triton.jit + def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.constexpr): + offsets = tl.arange(0, size) + x = tl.load(in_ptr + offsets, mask=mask, other=other) + tl.store(out_ptr + offsets, x) + + kernel[(1, )](input, output, size, mask_val, other_val, num_ctas=num_ctas) + + if mask_val: + reference_out = torch.full((size, ), input_val, dtype=dtype, device=device) + else: + reference_out = torch.full((size, ), other_val, dtype=dtype, device=device) + + torch.testing.assert_close(output, reference_out) + + +# Testing masked loads with a copy to shared memory. +# FIXME: Shape too small for ldmatrix when num_ctas=4 +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_masked_load_shared_memory(dtype, device): + + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + M = 32 + N = 32 + K = 16 + + in1 = torch.rand((M, K), dtype=dtype, device=device) + in2 = torch.rand((K, N), dtype=dtype, device=device) + out = torch.zeros((M, N), dtype=dtype, device=device) + + @triton.jit + def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_numel, in2_numel, out_numel, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): + + M_offsets = tl.arange(0, M) + N_offsets = tl.arange(0, N) + K_offsets = tl.arange(0, K) + + in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :] + in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :] + + # Load inputs. + x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K) + w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N) + + # Without a dot product the memory doesn't get promoted to shared. + o = tl.dot(x, w, out_dtype=tl.float32) + + # Store output + output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] + tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N) + + pgm = _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], in1.numel(), in2.numel(), + out.numel(), M=M, N=N, K=K) + + reference_out = torch.matmul(in1, in2) + torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cv"]) +def test_load_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets, cache_modifier=CACHE) + tl.store(dst + offsets, x) + + pgm = _kernel[(1, )](dst, src, CACHE=cache) + + if is_hip(): + target_arch = get_arch() + # TODO: support testing for remaining architectures + if 'gfx94' not in target_arch: + return + amdgcn = pgm.asm['amdgcn'] + cg_cache_modifier_str = 'nt' + cv_cache_modifier_str = 'sc0 sc1' + buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line] + global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line] + load_line = global_load_line[0] if global_load_line else buffer_load_line[0] + if cache == '' or cache == '.ca': + assert cg_cache_modifier_str not in load_line + if cache == '.cg': + assert cg_cache_modifier_str in load_line + if cache == '.cv': + assert cv_cache_modifier_str in load_line + + if is_cuda(): + ptx = pgm.asm['ptx'] + if cache == '': + assert 'ld.global.ca' not in ptx + assert 'ld.global.cg' not in ptx + if cache == '.cg': + assert 'ld.global.cg' in ptx + assert 'ld.global.ca' not in ptx + if cache == '.ca': + assert 'ld.global.ca' in ptx + assert 'ld.global.cg' not in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("N", [16, 10, 11, 1024]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_vectorization(N, num_ctas, device): + block_size = 1024 * num_ctas + src = torch.randn(block_size, device=device) + dst = torch.empty(block_size, device=device) + + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, N=N, BLOCK_SIZE=block_size) + + if not is_cuda(): + return + + ptx = pgm.asm["ptx"] + if N % 16 == 0: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.b32" in ptx + torch.testing.assert_close(dst[:N], src[:N], atol=1e-6, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("has_hints", [False, True]) +def test_vectorization_hints(has_hints, device): + src = torch.empty(1024, device=device) + dst = torch.empty(1024, device=device) + off = torch.zeros(1, device=device, dtype=torch.int32) + + @triton.jit + def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offsets = offsets + tl.load(off) + if HINT: + tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) + if not is_cuda(): + return + + ptx = pgm.asm["ptx"] + if has_hints: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.v4.b32" not in ptx + + +@pytest.mark.interpreter +def test_assume(device): + + @triton.jit + def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr): + current_size = N - tl.program_id(0) * BLOCK_N + tl.assume(current_size >= BLOCK_N) + if current_size >= 128: + tl.store(out_ptr + tl.program_id(0), current_size) + else: + tl.store(out_ptr + tl.program_id(0), current_size + 101024) + + output = torch.zeros(1024 // 128, device=device) + pgm = _kernel[(1024 // 128, )](output, N=1024, BLOCK_N=128) + + if is_interpreter(): + return + + assert 'llvm.assume' in pgm.asm['llir'] + + +# --------------- +# test store +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs", ".wt"]) +def test_store_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets) + tl.store(dst + offsets, x, cache_modifier=CACHE) + + pgm = _kernel[(1, )](dst, src, CACHE=cache) + + if is_hip(): + target_arch = get_arch() + # TODO: support testing for remaining architectures + if 'gfx94' not in target_arch: + return + amdgcn = pgm.asm['amdgcn'] + cs_cache_modifier_str = 'nt' + wt_cache_modifier_str = 'sc0 sc1' + buffer_store_line = [line for line in amdgcn.splitlines() if "buffer_store" in line] + global_store_line = [line for line in amdgcn.splitlines() if "global_store" in line] + store_line = global_store_line[0] if global_store_line else buffer_store_line[0] + if cache == '' or cache == '.cg': + assert cs_cache_modifier_str not in store_line + assert wt_cache_modifier_str not in store_line + if cache == '.cs': + assert cs_cache_modifier_str in store_line + assert wt_cache_modifier_str not in store_line + if cache == '.wt': + assert cs_cache_modifier_str not in store_line + assert wt_cache_modifier_str in store_line + + if is_cuda(): + ptx = pgm.asm['ptx'] + if cache == '': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.wb': + assert 'st.global.wb' in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cg': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cs': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' in ptx + assert 'st.global.wt' not in ptx + if cache == '.wt': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("eviction_policy", ["", "evict_last", "evict_first"]) +def test_store_eviction_policy(eviction_policy, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, POLICY: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets) + tl.store(dst + offsets, x, eviction_policy=POLICY) + + if not is_cuda(): + return + pgm = _kernel[(1, )](dst, src, POLICY=eviction_policy) + ptx = pgm.asm['ptx'] + if eviction_policy == '': + assert 'evict_last' not in ptx + assert 'evict_first' not in ptx + if eviction_policy == 'evict_last': + assert 'evict_last' in ptx + assert 'evict_first' not in ptx + if eviction_policy == 'evict_first': + assert 'evict_last' not in ptx + assert 'evict_first' in ptx + + +# --------------- +# test default +# --------------- +# TODO: can't be local to test_default + + +@triton.jit +def _impl(value=10): + return value + + +@pytest.mark.interpreter +def test_default(device): + value = 5 + ret0 = torch.zeros(1, dtype=torch.int32, device=device) + ret1 = torch.zeros(1, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(ret0, ret1, value=3): + tl.store(ret0, _impl()) + tl.store(ret1, _impl(value)) + + _kernel[(1, )](ret0, ret1, value) + assert ret0.item() == 10 + assert ret1.item() == value + + _kernel[(1, )](ret0, ret1) + assert ret0.item() == 10 + assert ret1.item() == 3 + + +# --------------- +# test noop +# ---------------- + + +@pytest.mark.interpreter +def test_noop(device): + + @triton.jit + def kernel(x): + pass + + x = to_triton(numpy_random((1, ), dtype_str='int32'), device=device) + kernel[(1, )](x) + + +@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned']) +def test_pointer_arguments(device): + + @triton.jit + def kernel(x): + pass + + pin_memory = 'pinned' in device + x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory) + if device == "cpu": + with pytest.raises(ValueError): + kernel[(1, )](x) + else: + kernel[(1, )](x) + + +@pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) +def test_value_specialization(value: int, value_type: str, device) -> None: + + def repr(specialization): + ty = specialization.signature["value1"] + cst = '_'.join([k for k, v in specialization.constants.items() if isinstance(k, str) and v == 1]) + return f"kernel_{ty}_{cst}" + + @triton.jit(repr=repr) + def kernel(value1, is_one, X): + pass + + x = torch.tensor([3.14159], device=device) + h = kernel[(1, )](value, 1, x) + assert "is_one" in h.name + assert value_type in h.name + + +# -------------------- +# value specialization +# -------------------- + + +@pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) +def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device=device) + + if overflow: + with pytest.raises(OverflowError): + kernel[(1, )](value, x) + else: + kernel[(1, )](value, x) + + +# ---------------- +# test constexpr +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) +@pytest.mark.parametrize("is_lhs_constexpr", [False, True]) +@pytest.mark.parametrize("is_rhs_constexpr", [True, False]) +def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device): + + @triton.jit + def kernel(Z, X, Y): + x = tl.load(X) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z, z) + + if op in ['<<', '>>', '&', '^', '|']: # int op + x_str = "3" if is_lhs_constexpr else "x" + y_str = "4" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="int32") + + # NOTE: bitshifting beyond bitwidth can lead to undefined behavior + if op in ['<<', '>>']: + y = numpy_random((1, ), dtype_str="int32", low=0, high=_bitwidth("int32")) + else: + y = numpy_random((1, ), dtype_str="int32") + else: + x_str = "3.14" if is_lhs_constexpr else "x" + y_str = "4.13" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="float32") + y = numpy_random((1, ), dtype_str="float32") + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"}) + z = np.array(eval(f"{x_str} {op} {y_str}")) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(np.empty((1, ), dtype=z.dtype), device=device) + kernel[(1, )](z_tri, x_tri, y_tri) + np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) + + +@pytest.mark.interpreter +def test_constexpr_shape(device): + + @triton.jit + def kernel(X): + off = tl.arange(0, 128 + 128) + tl.store(X + off, off) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) + + +@pytest.mark.interpreter +def test_constexpr_scalar_shape(device): + + @triton.jit + def kernel(X, s): + off = tl.arange(0, 256) + val = off % (256 // s) + tl.store(X + off, val) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri, 32) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) + + +reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("formats", reshape_list) +def test_reshape(formats, device): + in_format, out_format = formats + + @triton.jit + def kernel(Z, X, out_tuple: tl.constexpr): + x = tl.load(X_PTR_EXPR) + z = tl.reshape(x, out_tuple) + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + } + return patch_kernel(kernel, to_replace) + + x = numpy_random(in_format, dtype_str="int32") + z = x.reshape(out_format) + x_tri = to_triton(x, device=device) + patched_kernel = generate_kernel(in_format, out_format) + z_tri = to_triton(np.empty(out_format, dtype=np.int32), device=device) + patched_kernel[(1, )](z_tri, x_tri, out_format) + np.testing.assert_equal(z, to_numpy(z_tri)) + + +def test_reshape_err(device): + + @triton.jit + def kernel(): + x = tl.arange(0, 8 * 8) + y = tl.reshape(x, (8 * 4, )) + + with pytest.raises(triton.CompilationError) as exc_info: + kernel[(1, )]() + + assert "reshape" in str(exc_info.value) + + +@pytest.mark.interpreter +def test_tma_load_block_shape_err(device): + + @triton.jit + def kernel(ptr): + desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 32]) + desc.load([0, 0]) + + input = torch.empty((128, 128), dtype=torch.int32, device=device) + errc = triton.CompilationError if not is_interpreter() else InterpreterError + with pytest.raises(errc) as e: + kernel[(1, )](input) + + assert "tensor descriptor block shape must have at least 8 rows" in str(e.value.__cause__) + + +@pytest.mark.interpreter +def test_tma_store_block_shape_err(device): + + @triton.jit + def kernel(ptr): + desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [8, 8]) + desc.store([0, 0], tl.zeros((1, 32), dtype=tl.int16)) + + input = torch.empty((128, 128), dtype=torch.int16, device=device) + errc = triton.CompilationError if not is_interpreter() else InterpreterError + with pytest.raises(errc) as e: + kernel[(1, )](input) + + assert "int16 tensor descriptor block shape must have at least 16 columns" in str(e.value.__cause__) + + +def test_trans_reshape(device): + + @triton.jit + def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.constexpr): + + in_block_ptr = tl.make_block_ptr( + base=in_base_ptr, + shape=(IN_SHAPE0, IN_SHAPE1), + strides=(IN_SHAPE1, 1), + offsets=(0, 0), + block_shape=(IN_SHAPE0, IN_SHAPE1), + order=(1, 0), + ) + x = tl.load(in_block_ptr) + x = tl.reshape(x, (32, 4, 4, 2)) + x = tl.permute(x, (1, 2, 3, 0)) + x = tl.reshape(x, (IN_SHAPE0 * IN_SHAPE1, )) + tl.store(out_base_ptr + tl.arange(0, IN_SHAPE0 * IN_SHAPE1), x) + + shape = (32, 32) + input = torch.arange(math.prod(shape), dtype=torch.int32, device=device).reshape(shape) + expected = torch.permute(input, (1, 0)) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=torch.int32, device=device) + + k = kernel[(1, )](input, actual, shape[0], shape[1]) + assert k.asm['ttgir'].count( + 'ttg.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# ------------- +# test call +# ------------- + + +@triton.jit +def val_multiplier(val, i): + return val * i + + +@triton.jit(noinline=True) +def val_multiplier_noinline(val, i): + return val * i + + +@triton.jit +def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): + pid = tl.program_id(axis=0) + offsets = pid * 128 + tl.arange(0, 128) + mask = offsets < n_elements + vec = tl.load(ptr + offsets, mask=mask) + for i in range(1, rep): + if type == "inline": + vec = val_multiplier(vec, i) + else: + vec = val_multiplier_noinline(vec, i) + tl.store(ptr + offsets, vec, mask=mask) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("type", ["inline", "noinline"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_call(type, num_ctas, device): + + @triton.jit + def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): + vecmul_kernel(ptr, n_elements, num1, type) + vecmul_kernel(ptr, n_elements, num2, type) + + size = 1024 + rand_val = numpy_random((size, ), dtype_str="float32") + rand_val_tri = to_triton(rand_val, device=device) + err_msg = "" + try: + kernel[(size // 128, )](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) + except Exception as e: + err_msg = str(e) + + if type == "noinline" and not is_interpreter(): + assert err_msg != "" + else: + ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 + np.testing.assert_equal(to_numpy(rand_val_tri), ans) + + +# ------------- +# test if +# ------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("if_type", [ + "if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", + "if_and_static" +]) +def test_if(if_type, device): + + @triton.jit + def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr): + pid = tl.program_id(0) + cond = tl.load(Cond) + if IfType == "if": + if pid % 2 == 0: # eq + tl.store(Ret, tl.load(XTrue)) + elif 1 == pid % 2: # req + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_dynamic": + val = tl.load(XTrue) if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_constexpr": + val = 3.14 if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_void": + tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_static": + tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_dynamic": + if BoolVar and (1 != pid % 2 and pid % 2 != 1): # rne and ne + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_static": + if StaticVaue != 0 and StaticVaue != 0: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + + cond = torch.ones(1, dtype=torch.int32, device=device) + x_true = torch.tensor([3.14], dtype=torch.float32, device=device) + x_false = torch.tensor([1.51], dtype=torch.float32, device=device) + ret = torch.zeros(1, dtype=torch.float32, device=device) + + kernel[(1, )](cond, x_true, x_false, ret, if_type, True, 1) + assert torch.equal(ret, x_true) + + +def test_num_warps_pow2(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pass + + with pytest.raises(AssertionError, match='must be a power of 2'): + _kernel[(1, )](dst=dst, num_warps=3) + _kernel[(1, )](dst=dst, num_warps=1) + _kernel[(1, )](dst=dst, num_warps=2) + _kernel[(1, )](dst=dst, num_warps=4) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func_str", ['sqrt', 'rsqrt', 'exp', 'exp2', 'log', 'log2', 'sin', 'cos']) +def test_unary_math(func_str, device): + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.FUNC_STR(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + kernel = patch_kernel(kernel, {'FUNC_STR': func_str}) + + shape = (128, ) + x = torch.randn(shape, dtype=torch.float32, device=device) + if func_str in ['sqrt', 'rsqrt']: + x = torch.abs(x) + if func_str in ['log', 'log2']: + x = torch.max(x, torch.tensor(1e-6, dtype=torch.float32, device=device)) + y = torch.zeros(shape, dtype=torch.float32, device=device) + + kernel[(1, )](x, y, BLOCK=shape[0]) + torch.allclose(getattr(torch, func_str)(x), y, rtol=1e-3) + + +# ----------------------- +# test inline asm +# ----------------------- + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm(num_ctas, device): + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + s = tl.full([BLOCK], n, tl.int32) + z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, + is_pure=True, pack=1) + tl.store(Z + tl.arange(0, BLOCK), z) + + shape = (128, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint32', rs=rs) + y = numpy_random(shape, dtype_str='uint32', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + n = 17 + z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = (y << n) | (x >> (32 - n)) + # compare + np.testing.assert_equal(y_ref, to_numpy(z_tri)) + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm_packed(num_ctas, device): + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # shift 4x8bits values together. + y = tl.inline_asm_elementwise( + "and.b32 $0, $1, 0x1F1F1F1F; \ + shl.b32 $0, $0, 3;", "=r,r", [ + x, + ], dtype=tl.int8, is_pure=True, pack=4) + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +@pytest.mark.parametrize('num_ctas', num_ctas_list) +def test_inline_asm_with_pointers(num_ctas, device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x_ptrs = X + tl.arange(0, BLOCK) + y_ptrs = Y + tl.arange(0, BLOCK) + tl.inline_asm_elementwise( + "ld.global.b8 $0, [$1]; \ + shl.b32 $0, $0, 3; \ + st.global.b8 [$2], $0;", "=r,l,l", [x_ptrs, y_ptrs], dtype=tl.int8, is_pure=False, + pack=1) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +def test_inline_asm_multiple_outputs(device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # C = A - B + # D = B - A + (c, d) = tl.inline_asm_elementwise( + asm=""" + sub.u32 $0, $2, $3; // C = A - B + sub.u32 $1, $3, $2; // D = B - A + """, + constraints=( + # 2 output registers: $0=C and $1=D. + "=r,=r," + # 2 input registers: $2=A and $3=B. + "r,r"), + args=[a, b], + dtype=(tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint32', rs=rs) + B = numpy_random(shape, dtype_str='uint32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A - B + D_ref = B - A + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +def test_inline_asm_packed_multiple_outputs(device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint8', rs=rs) + B = numpy_random(shape, dtype_str='float32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='int32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='float32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A.astype(np.int32) + D_ref = np.maximum(A.astype(np.float32), B) + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +# ----------------------- +# test control flow +# ----------------------- + + +@pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), + (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) +def test_for_iv(lo, hi, iv, device): + + @triton.jit + def kernel(Out, lo, hi, iv: tl.constexpr): + acc = 0 + acc = acc.to(tl.int64) + for i in range(lo, hi, iv): + acc += i + tl.store(Out, acc) + + lo = 2**35 + hi = 2**35 + 20 + out = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + kernel[(1, )](out, lo, hi, iv) + assert out[0] == sum(range(lo, hi, iv)) + + +@pytest.mark.interpreter +def test_if_else(device): + + @triton.jit + def kernel(Cond, TrueVal, FalseVal, Out): + if tl.load(Cond): + val = tl.load(TrueVal) + else: + val = tl.load(FalseVal) + tl.store(Out, val) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + true_val = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + false_val = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + cond = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # True + cond[0] = True + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == true_val[0] + # False + cond[0] = False + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == false_val[0] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["dynamic", "static"]) +def test_if_return(mode, device): + + @triton.jit + def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr): + if mode == "dynamic": + if tl.load(ExitEarly): + tl.store(Out, 0) + return + else: + if cond: + tl.store(Out, 0) + return + tl.store(Out, 1) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + exit_early = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # exit early path taken + exit_early[0] = 1 + kernel[(1, )](exit_early, out, True, mode) + assert to_numpy(out)[0] == 0 + # exit early path not taken + exit_early[0] = 0 + kernel[(1, )](exit_early, out, False, mode) + assert to_numpy(out)[0] == 1 + + +@triton.jit +def add_fn(x): + return x + 1 + + +@triton.jit(noinline=True) +def add_fn_noinline(x): + return x + 1 + + +@triton.jit +def add_fn_return(x, pid): + if pid == 0: + return x + 1 + else: + return x + 2 + + +@triton.jit +def add_fn_expr(Out, x): + tl.store(Out, x) + + +@triton.jit +def add_fn_static_cond(x, cond: tl.constexpr): + if cond == "": + return x + else: + return x + 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "call_type", + ["attribute", "attribute_jit", "jit", "jit_if", "jit_expr", "jit_static_cond", "jit_noinline", "jit_extern"]) +def test_if_call(call_type, device): + + @triton.jit + def kernel(Out, call_type: tl.constexpr): + pid = tl.program_id(0) + o = tl.load(Out) + if call_type == "attribute": + # call attribute + if pid == 0: + a = o + a = a.to(tl.int32).to(tl.int32) + 1 + o = a + elif call_type == "attribute_jit": + # call attribute and jit function + if pid == 0: + a = o + a = tl.load(Out + add_fn(a) - 1).to(tl.int32) + 1 + o = a + elif call_type == "jit": + if pid == 0: + # regular function call + a = o + a = add_fn(a) + o = a + elif call_type == "jit_if": + # function without end_if block + if pid == 0: + a = o + a = add_fn_return(a, pid) + o = a + elif call_type == "jit_if_exp": + # ifexp expression + if pid == 0: + a = o + a = add_fn(a) if pid == 0 else add_fn_return(a, pid) + o = a + elif call_type == "jit_expr": + # call without return + if pid == 0: + a = o + 1 + add_fn_expr(Out, a) + o = a + elif call_type == "jit_static_cond": + if pid == 0: + a = o + 1 + add_fn_static_cond(o, call_type) + o = a + elif call_type == "jit_noinline": + if pid == 0: + a = o + 1 + add_fn_noinline(a) + o = a + elif call_type == "jit_extern": + if pid == 0: + a = o + 1 + tl.cdiv(a, a) + o = a + + tl.store(Out, o) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + kernel[(1, )](out, call_type) + assert to_numpy(out)[0] == 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("_cond1", [True, False]) +@pytest.mark.parametrize("_cond2", [True, False]) +@pytest.mark.parametrize("_cond3", [True, False]) +def test_nested_if_else_return(_cond1, _cond2, _cond3, device): + + @triton.jit + def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): + val = 0 + if tl.load(Cond1): + if tl.load(Cond2): + val = tl.load(Val1) + else: + return + else: + if tl.load(Cond3): + val = tl.load(Val2) + else: + val = tl.load(Val3) + tl.store(Out, val) + + out = to_triton(np.full((1, ), -1, dtype=np.int32), device=device) + cond1 = to_triton(np.full((1, ), _cond1, dtype=np.int32), device=device) + cond2 = to_triton(np.full((1, ), _cond2, dtype=np.int32), device=device) + cond3 = to_triton(np.full((1, ), _cond3, dtype=np.int32), device=device) + val1 = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + val2 = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + val3 = to_triton(np.full((1, ), 3, dtype=np.int32), device=device) + kernel[(1, )](cond1, cond2, cond3, val1, val2, val3, out) + targets = { + (True, True, True): val1[0], + (True, True, False): val1[0], + (True, False, True): out[0], + (True, False, False): out[0], + (False, True, True): val2[0], + (False, True, False): val3[0], + (False, False, True): val2[0], + (False, False, False): val3[0], + } + assert out[0] == targets[(_cond1, _cond2, _cond3)] + + +@pytest.mark.interpreter +def test_while(device): + + @triton.jit + def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): + init_i = tl.load(InitI) + curr_i = init_i + j = 0 + # Check that init_i is not updated by the loop + while j < tl.load(Bound): + curr_i = curr_i + (j == tl.load(CutOff)) + j += 1 + tl.store(OutInitI, init_i) + tl.store(OutI, curr_i) + tl.store(OutJ, j) + + out_i = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + out_j = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + init_i = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + out_init_i = to_triton(np.full((1, ), 0, dtype=np.int32), device=device) + bound = to_triton(np.full((1, ), 10, dtype=np.int32), device=device) + cut_off = to_triton(np.full((1, ), 5, dtype=np.int32), device=device) + kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j) + assert out_init_i[0] == init_i[0] + assert out_i[0] == init_i[0] + 1 + assert out_j[0] == bound[0] + + +@pytest.mark.interpreter +def test_nested_while(device): + + @triton.jit + def nested_while(data, countPtr): + for i in range(10): + count = tl.load(countPtr) + while count > 0: + tl.store(data, tl.load(data) + 1.0) + count = count - 2 + + counter = torch.tensor([8], dtype=torch.int32, device=device) + data = torch.zeros((1, ), device=device, dtype=torch.float32) + nested_while[(1, )](data, counter) + assert data[0] == 40 + + +def test_constexpr_if_return(device): + # Reproducer for #4883, return statement in an if with a constexpr causes + # errors when combined with non-trivial control flow graphs + + @triton.jit + def kernel(Semaphore, Out, total: tl.constexpr): + if total == 1: + tl.store(Out, tl.program_id(0)) + return + + prev = tl.atomic_add(Semaphore, 1) + if prev + 1 != total: + return + + tl.store(Out, tl.program_id(0) + prev) + + sem = torch.zeros((), device=device, dtype=torch.int32) + out = torch.empty((), device=device, dtype=torch.int32) + kernel[(1, )](sem, out, 1) + assert out.item() == 0 + + sem = torch.zeros((), device=device, dtype=torch.int32) + out = torch.full((), fill_value=-1, device=device, dtype=torch.int32) + kernel[(4, )](sem, out, 4) + assert out.item() >= 0 + + +@triton.jit +def return_poison(x): + a = False + if a: + return x + + +def test_poison_return(device): + + @triton.jit + def kernel(Out): + tl.store(Out, return_poison(0)) + + a = torch.empty((), device=device, dtype=torch.int32) + h = kernel[(1, )](a) + assert "ub.poison" in h.asm["ttir"], h.asm["ttir"] + # hip/xpu uses llvm.store, which in this case is removed by the optimizer + if not (is_hip() or is_xpu()): + assert "poison" in h.asm["llir"], h.asm["llir"] + + +# ----------------------- +# test extra +# ----------------------- + + +def test_num_threads(device): + if is_hip(): + pytest.skip("test_num_threads is not supported in HIP") + + @triton.jit + def kernel(Out): + num_threads: tl.constexpr = tl.extra.cuda.num_threads() + offs = tl.arange(0, num_threads) + tl.store(Out + offs, 1) + + num_threads = 256 + out = to_triton(np.zeros((num_threads, ), dtype=np.int32), device=device) + kernel[(1, )](out, num_warps=num_threads // 32) + assert torch.sum(out) == 256 + + +def test_globaltimer(device): + if is_hip(): + pytest.skip("test_globaltimer is not supported in HIP") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out1, Out2): + start = tl.extra.cuda.globaltimer() + off = tl.arange(0, 128) + for i in range(10000): + tl.store(Out1 + off, tl.load(Out1 + off) + 1) + end = tl.extra.cuda.globaltimer() + tl.store(Out2, end - start) + + out1 = to_triton(np.zeros((128, ), dtype=np.int64), device=device) + out2 = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + h = kernel[(1, )](out1, out2) + assert out2[0] > 0 + assert h.asm["ptx"].count("%globaltimer") == 2 + + +def test_smid(device): + if is_hip(): + pytest.skip("test_smid is not supported in HIP") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out): + tl.store(Out + tl.program_id(0), tl.extra.cuda.smid()) + + out = to_triton(np.zeros((1024, ), dtype=np.int32), device=device) + h = kernel[(out.shape[0], )](out) + assert out.sort()[0].unique().shape[0] > 0 + assert h.asm["ptx"].count("%smid") == 1 + + +# ----------------------- +# test layout conversions +# ----------------------- +# TODO: backend should be tested separately + +layouts = [ + BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1), + MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + SliceLayout( + dim=1, + parent=DotOperandLayout(parent=MmaLayout([3, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [16, 32, 16]), + op_idx=0, k_width=2)), + SliceLayout( + dim=1, parent=DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), + op_idx=1, k_width=2)), +] + +intermediate_layouts = [ + None, + SharedLayout(1, 1, 1, [0, 1], [1, 1], [1, 1], [0, 1]), + SharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), +] + + +def compute_rep_shape(layout): + if type(layout) is BlockedLayout: + warp_shape = np.multiply(layout.sz_per_thread, layout.threads_per_warp) + rep_shape = np.multiply(warp_shape, layout.warps_per_cta) + return rep_shape + else: + assert False, "TODO: support compute_rep_shape for layout " + str(type(layout)) + + +# This function gives a lower bound approximation of scratch buffer shape for convert_layout operation +def compute_scratch_buffer_shape(src_layout, dst_layout, shape): + src_rep_shape = compute_rep_shape(src_layout) + dst_rep_shape = compute_rep_shape(dst_layout) + full_scratch_shape = np.maximum(src_rep_shape, dst_rep_shape) + return np.minimum(full_scratch_shape, shape) + + +@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("interm_layout", intermediate_layouts) +@pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path): + if str(src_layout) == str(dst_layout): + pytest.skip() + if (isinstance(src_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)) or (isinstance(dst_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)): + pytest.skip("DotOperandLayout <-> SharedLayout conversion is not completely supported") + if is_hip(): + try: + scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N)) + except AssertionError: + pytest.skip("Can't compute scratch buffer size") + lds_size = 65536 + # consider int32 dtype in scratch buffer size, + # because it is the largest dtype used in convert_layout in this test + int32_size = 4 + # skip even if scratch buffer equal to lds_size, because real scratch buffer is typically larger due to padding + if scratch_shape[0] * scratch_shape[1] * int32_size >= lds_size: + pytest.skip("Scratch buffer is too large") + + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + #smem = #ttg.shared_memory + """ if interm_layout is None else f""" + #src = {src_layout} + #interm = {interm_layout} + #dst = {dst_layout} + #smem = #ttg.shared_memory + """ + + conversion = f""" + %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ if interm_layout is None else f""" + %15 = ttg.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !ttg.memdesc<{M}x{N}xi32, #interm, #smem> + %16 = ttg.local_load %15 : !ttg.memdesc<{M}x{N}xi32, #interm, #smem> -> tensor<{M}x{N}xi32, #src> + %17 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !ttg.memdesc<{M}x{N}xf16, #interm, #smem> + %18 = ttg.local_load %17 : !ttg.memdesc<{M}x{N}xf16, #interm, #smem> -> tensor<{M}x{N}xf16, #src> + + %12 = ttg.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + ir = layouts + f""" + module attributes {{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} +}} +""" + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_convert2d.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + torch.testing.assert_close(z, x, rtol=0, atol=0) + + +layouts_3d = [ + BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), op_idx=0, + k_width=1), +] + +shared_layouts_3d = [ + SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), +] + + +@pytest.mark.parametrize("M, N, K", [[8, 16, 32]]) +@pytest.mark.parametrize("shared_layout", shared_layouts_3d) +@pytest.mark.parametrize("dist_layout", filter_layouts(layouts_3d)) +def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path): + layouts = f""" + #dist = {dist_layout} + #shared = {shared_layout} + #smem = #ttg.shared_memory + """ + ir = layouts + f""" + module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %cst = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist> + %cst_0 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist> + %cst_1 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist> + %cst_2 = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist> + %0 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> + %1 = tt.expand_dims %0 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<1x1x{K}x!tt.ptr, #dist> + %4 = tt.addptr %3, %2 : tensor<1x1x{K}x!tt.ptr, #dist>, tensor<1x1x{K}xi32, #dist> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %7 = tt.expand_dims %6 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist> + %8 = arith.muli %7, %cst_2 : tensor<1x{N}x1xi32, #dist> + %9 = tt.broadcast %4 : tensor<1x1x{K}x!tt.ptr, #dist> -> tensor<1x{N}x{K}x!tt.ptr, #dist> + %10 = tt.broadcast %8 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist> + %11 = tt.addptr %9, %10 : tensor<1x{N}x{K}x!tt.ptr, #dist>, tensor<1x{N}x{K}xi32, #dist> + %12 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %13 = tt.expand_dims %12 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %14 = tt.expand_dims %13 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist> + %15 = arith.muli %14, %cst_1 : tensor<{M}x1x1xi32, #dist> + %16 = tt.broadcast %11 : tensor<1x{N}x{K}x!tt.ptr, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist> + %18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr, #dist>, tensor<{M}x{N}x{K}xi32, #dist> + %19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem> + %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem> -> tensor<{M}x{N}x{K}xi32, #dist> + %22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> + %23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist> + %25 = tt.splat %arg1 : !tt.ptr -> tensor<1x1x{K}x!tt.ptr, #dist> + %26 = tt.addptr %25, %24 : tensor<1x1x{K}x!tt.ptr, #dist>, tensor<1x1x{K}xi32, #dist> + %27 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %28 = tt.expand_dims %27 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %29 = tt.expand_dims %28 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist> + %30 = arith.muli %29, %cst : tensor<1x{N}x1xi32, #dist> + %31 = tt.broadcast %26 : tensor<1x1x{K}x!tt.ptr, #dist> -> tensor<1x{N}x{K}x!tt.ptr, #dist> + %32 = tt.broadcast %30 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist> + %33 = tt.addptr %31, %32 : tensor<1x{N}x{K}x!tt.ptr, #dist>, tensor<1x{N}x{K}xi32, #dist> + %34 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %35 = tt.expand_dims %34 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %36 = tt.expand_dims %35 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist> + %37 = arith.muli %36, %cst_0 : tensor<{M}x1x1xi32, #dist> + %38 = tt.broadcast %33 : tensor<1x{N}x{K}x!tt.ptr, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %39 = tt.broadcast %37 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist> + %40 = tt.addptr %38, %39 : tensor<{M}x{N}x{K}x!tt.ptr, #dist>, tensor<{M}x{N}x{K}xi32, #dist> + tt.store %40, %21 : tensor<{M}x{N}x{K}x!tt.ptr, #dist> + tt.return + }} +}} +""" + + x = torch.arange(0, M * N * K, device=device, dtype=torch.int32).reshape(M, N, K) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_local_load_store.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x, z) + assert torch.equal(z, x) + + +dot_layouts = [ + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=4), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=1, k_width=4), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=0, k_width=1), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=1), +] + +shared_layouts = [ + SharedLayout(4, 2, 4, [0, 1], [1, 1], [1, 1], [0, 1]), + SharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(16, 1, 16, [1, 0], [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N", [[16, 32]]) +@pytest.mark.parametrize("dtype", ['float16', 'float8e5', 'float32']) +@pytest.mark.parametrize("shared_layout", shared_layouts) +@pytest.mark.parametrize("dist_layout", filter_layouts(dot_layouts)) +def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, tmp_path: pathlib.Path): + if dtype == "float32": + mlir_dtype = "f32" + elif dtype == "float16": + mlir_dtype = "f16" + elif dtype == "float8e5": + mlir_dtype = "f8E5M2" + + layouts = f""" + #dist = {dist_layout} + #shared = {shared_layout} + #smem = #ttg.shared_memory + """ + ir = layouts + f""" + module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #dist> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> + %2 = tt.splat %arg0 : !tt.ptr<{mlir_dtype}> -> tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + %3 = tt.splat %arg1 : !tt.ptr<{mlir_dtype}> -> tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<{M}x1xi32, #dist> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #dist> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> -> tensor<1x{N}xi32, #dist> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #dist> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>, tensor<{M}x{N}xi32, #dist> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + %12 = ttg.local_alloc %11 : (tensor<{M}x{N}x{mlir_dtype}, #dist>) -> !ttg.memdesc<{M}x{N}x{mlir_dtype}, #shared, #smem> + %13 = ttg.local_load %12 : !ttg.memdesc<{M}x{N}x{mlir_dtype}, #shared, #smem> -> tensor<{M}x{N}x{mlir_dtype}, #dist> + %14 = tt.addptr %3, %9 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>, tensor<{M}x{N}xi32, #dist> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + tt.return + }} +}} +""" + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_local_load_store_dot.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x, z) + assert torch.equal(z, x) + + +mma_layouts = [ + MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # simple 4 warps case + MmaLayout((3, 0), [8, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # simple 8 warps case + MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # multiple warps on the row + MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 64, 16]), # small instrN + MmaLayout((3, 0), [8, 4], [1, 1], [1, 1], [0, 1], [16, 64, 16]), # large number of warps +] + +shared_layouts = [ + SharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), + NVMMASharedLayout(64, False, 16, [1, 1], [1, 1], [0, 1]), + NVMMASharedLayout(128, False, 16, [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N", [[128, 128]]) +@pytest.mark.parametrize("mma_layout", filter_layouts(mma_layouts)) +@pytest.mark.parametrize("shared_layout", shared_layouts) +def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path: pathlib.Path): + num_warps = np.prod(mma_layout.warps_per_cta) + + layouts = f""" + #dist = {mma_layout} + #shared = {shared_layout} + #smem = #ttg.shared_memory + """ + ir = layouts + f""" + module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #dist> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dist> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dist> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<{M}x1xi32, #dist> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #dist> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> -> tensor<1x{N}xi32, #dist> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #dist> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #dist>, tensor<{M}x{N}xi32, #dist> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #dist> + %12 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #dist>) -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem> + %13 = ttg.local_load %12 : !ttg.memdesc<{M}x{N}xf16, #shared, #smem> -> tensor<{M}x{N}xf16, #dist> + %14 = tt.addptr %3, %9 : tensor<{M}x{N}x!tt.ptr, #dist>, tensor<{M}x{N}xi32, #dist> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dist> + tt.return + }} +}} +""" + + x = torch.arange(0, M * N, device=device, dtype=torch.float16).reshape(M, N) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_local_load_store_mma.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x, z) + assert torch.equal(z, x) + + if isinstance(shared_layout, NVMMASharedLayout) and mma_layout.version[0] >= 3: + assert "stmatrix" in kernel.asm["ptx"] + + +def filter_layout_pairs(layout_pairs): + return [pair for pair in layout_pairs if is_layout_applicable(pair[0]) and is_layout_applicable(pair[1])] + + +mma_pairs = [ + [ + MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + ], + [ + MmaLayout((3, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + ], + [ + MmaLayout((3, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + MmaLayout((3, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 16]), + ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 16]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), + ], + [ + WmmaLayout(1, [4, 4]), + WmmaLayout(1, [16, 1]), + ], + [ + WmmaLayout(1, [16, 1]), + WmmaLayout(1, [4, 4]), + ], + [ + WmmaLayout(2, [4, 4]), + WmmaLayout(2, [16, 1]), + ], + [ + WmmaLayout(2, [16, 1]), + WmmaLayout(2, [4, 4]), + ], + [ + MfmaLayout([2, 0], [2, 2], [32, 32], False), + MfmaLayout([2, 0], [4, 1], [32, 32], False), + ], + [ + MfmaLayout([2, 0], [4, 1], [32, 32], False), + MfmaLayout([2, 0], [2, 2], [32, 32], False), + ], + [ + MfmaLayout([2, 0], [2, 2], [32, 32], False), + MfmaLayout([2, 0], [4, 1], [32, 32], True), + ], + [ + MfmaLayout([2, 0], [4, 1], [32, 32], False), + MfmaLayout([2, 0], [2, 2], [32, 32], True), + ], + [ + MfmaLayout([2, 0], [4, 4], [16, 16], False), + MfmaLayout([2, 0], [16, 1], [16, 16], False), + ], + [ + MfmaLayout([2, 0], [16, 1], [16, 16], False), + MfmaLayout([2, 0], [4, 4], [16, 16], False), + ], + [ + MfmaLayout([2, 0], [4, 4], [16, 16], False), + MfmaLayout([2, 0], [16, 1], [16, 16], True), + ], + [ + MfmaLayout([2, 0], [16, 1], [16, 16], False), + MfmaLayout([2, 0], [4, 4], [16, 16], True), + ], +] + + +@pytest.mark.parametrize("M, N", [[16, 16], [64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("mma_pair", filter_layout_pairs(mma_pairs)) +def test_convert_mma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path): + if is_hip(): + if isinstance(mma_pair[1], MfmaLayout) and (mma_pair[1].instr_shape[1] > M or mma_pair[1].instr_shape[1] > N): + pytest.skip("HIP do not fully support skinny tensor store") + + src_layout, _ = mma_pair + num_warps = np.prod(src_layout.warps_per_cta) + warp_size = THREADS_PER_WARP + + def do_test(src_layout, dst_layout): + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + """ + + ir = layouts + f""" + module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {warp_size} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} + }} + """ + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x) + + temp_file = tmp_path / "test_convert_mma2mma.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + assert torch.equal(z, x) + + do_test(mma_pair[0], mma_pair[1]) + do_test(mma_pair[1], mma_pair[0]) + + +single_warp_layouts = [ + BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [1, 1], [1, 0]), + BlockedLayout([1, 1], [THREADS_PER_WARP // 2, 2], [1, 1], [1, 0]), + BlockedLayout([1, 1], [THREADS_PER_WARP // 4, 4], [1, 1], [1, 0]), + BlockedLayout([1, 1], [THREADS_PER_WARP // 8, 8], [1, 1], [1, 0]), + BlockedLayout([1, 1], [THREADS_PER_WARP // 16, 16], [1, 1], [1, 0]), + BlockedLayout([1, 1], [THREADS_PER_WARP // 32, 32], [1, 1], [1, 0]), + BlockedLayout([32, 1], [1, THREADS_PER_WARP], [1, 1], [1, 0]), + BlockedLayout([16, 1], [2, THREADS_PER_WARP // 2], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP, 1], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 2, 2], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 4, 4], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 8, 8], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 16, 16], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 1], [1, 0]), +] + + +@pytest.mark.parametrize("M, N", [[32, 32], [64, 64]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("src_layout", single_warp_layouts) +@pytest.mark.parametrize("dst_layout", single_warp_layouts) +def test_convert_warp_local(M, N, src_layout, dst_layout, dtype, device, tmp_path: pathlib.Path): + if str(src_layout) == str(dst_layout): + pytest.skip() + if np.prod(src_layout.threads_per_warp) == 0 or np.prod(dst_layout.threads_per_warp) == 0: + pytest.skip() + + # Test layout pairs that are likely to codegen warp shuffles. + a, b = list(np.array(src_layout.threads_per_warp) // np.array(dst_layout.threads_per_warp)) + c = a if a != 0 else b + if c > 2: + pytest.skip() + + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + #smem = #ttg.shared_memory + """ + + conversion = f""" + %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + ir = layouts + f""" + module attributes {{"ttg.num-warps" = 1 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} +}} +""" + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_convert_warp_local.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + torch.testing.assert_close(z, x, rtol=0, atol=0) + + +@pytest.mark.interpreter +def test_load_scalar_with_mask(device): + + @triton.jit + def kernel(Input, Index, Out, N: int): + index = tl.load(Index) + scalar = tl.load(Input + index, mask=index < N, other=0) + tl.store(Out, scalar, mask=index < N) + + Index = torch.tensor([0], dtype=torch.int32, device=device) + Input = torch.tensor([0], dtype=torch.int32, device=device) + Out = torch.empty_like(Index, device=device) + kernel[(1, )](Input, Index, Out, Index.numel()) + assert Out.data[0] == 0 + + +# This test is used to test our own PTX codegen for float16 and int16 conversions +# maybe delete it later after ptxas has been fixed +@pytest.mark.parametrize("dtype_str", ['float16', 'int16']) +def test_ptx_cast(dtype_str, device): + + @triton.jit + def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + rbase = tl.arange(0, RBLOCK)[None, :] + x0 = xindex + _tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype) + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + r1 = rindex + tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype) + tmp1 = 2 + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(dtype) + tmp5 = _tmp4 < tmp3 + _tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4) + tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask) + + torch.manual_seed(123) + if dtype_str == 'int16': + torch_dtype = torch.int16 + triton_dtype = tl.int32 + else: + torch_dtype = torch.float16 + triton_dtype = tl.float32 + + s0 = 4 + buf11 = -torch.ones((6 * s0, 197, 197), device=device, dtype=torch_dtype) + buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype) + kernel[(4728, )](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) + assert buf14.to(torch.float32).mean() == -2.0 + + +# ----------------------- +# test fp8 -> fp32 dot +# ----------------------- + + +def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + +@triton.jit +def matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + low_precision_acc: tl.constexpr, # + num_stages: tl.constexpr = 3 # +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, accumulator) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N, K", [(128, 256, 256)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), (64, 64, 64)]) +@pytest.mark.parametrize( + "in_type_str", + ['float8e5', 'float8e5b16', 'float8e4b8', 'float8e4nv'] if is_hip() else ['float8e5', 'float8e4nv', 'float8e4b15']) +@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) +def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device): + num_stages = 3 + if is_cuda(): + cc = torch.cuda.get_device_capability() + if cc[0] >= 9 and in_type_str == "float8e4b15": + pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90") + elif is_hip(): + num_stages = 2 + if in_type_str in ("float8e5b16", "float8e4b8") and not is_hip_mi300(): + pytest.skip(f"{in_type_str} only supported on mi300") + if in_type_str in ("float8e5", "float8e4nv") and not is_hip_mi350(): + pytest.skip(f"{in_type_str} only supported on mi350") + + check_type_supported(in_type_str, device) + A = numpy_random((M, K), dtype_str=in_type_str) + B = numpy_random((K, N), dtype_str=in_type_str) + C = torch.empty((M, N), dtype=torch.float32, device=device) + num_warps = 8 + a = to_triton(A, device=device, dst_type=in_type_str) + b = to_triton(B, device=device, dst_type=in_type_str) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None + h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), + C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps, + num_stages=num_stages) + torch_a = torch.from_numpy(A).to(device=device) + th_a = f8_to_f16(torch_a, in_type_str) + torch_b = torch.from_numpy(B).to(device=device) + th_b = f8_to_f16(torch_b, in_type_str) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if in_type_str == 'float8e4nv': + torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) + else: + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + if is_cuda() and low_precision_acc > 0 and torch.cuda.get_device_capability()[0] == 9: + assert h.asm["ptx"].count("add.f32") == (BLOCK_M * BLOCK_N) // (32 * num_warps) * (BLOCK_K // low_precision_acc) + + +# ----------------------- +# test enable_fp_fusion +# ----------------------- + + +@pytest.mark.parametrize("enable_fp_fusion", [False, True]) +@pytest.mark.parametrize("default_override", [False, True]) +def test_enable_fp_fusion(enable_fp_fusion, default_override, device): + if is_hip(): + pytest.skip( + 'test_enable_fp_fusion for HIP currently broken in https://github.com/triton-lang/triton. Use https://github.com/ROCmSoftwarePlatform/triton' + ) + + # Sequential multiply add can be fused by backend + @triton.jit + def mul_add(data): + ptrs = data + tl.arange(0, 128) + tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0) + + data = torch.randn((128, ), device=device, dtype=torch.float32) + if default_override: + os.environ["TRITON_DEFAULT_FP_FUSION"] = "1" if enable_fp_fusion else "0" + h = mul_add[(1, )](data) + else: + h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) + + if not is_cuda(): + return + found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None + assert found_fma == enable_fp_fusion + + +# ----------------------- +# test override_arch +# ----------------------- + + +@pytest.mark.parametrize("arch", ["sm70", "sm80", "sm90"]) +@pytest.mark.parametrize("env_var_override", [False, True]) +def test_override_arch(arch, env_var_override, device): + if not is_cuda(): + pytest.skip('arch only for CUDA') + + @triton.jit + def simple(data, out): + in_ptrs = data + tl.arange(0, 128) + out_ptrs = out + tl.arange(0, 128) + tl.store(out_ptrs, tl.load(in_ptrs) * 1.5 + 1.0) + + data = torch.randn((128, ), device=device, dtype=torch.float32) + out = torch.empty_like(data) + + if env_var_override: + os.environ["TRITON_OVERRIDE_ARCH"] = str(arch) + h = simple[(1, )](data, out) + os.environ.pop("TRITON_OVERRIDE_ARCH") + else: + h = simple[(1, )](data, out, arch=arch) + torch.testing.assert_close(data * 1.5 + 1.0, out) + ttgir_cc = re.search(r'cuda:(\d+)', h.asm["ttgir"]) + assert ttgir_cc.group(1) == arch[2:] + + +# ----------------------- +# test propagate_nan +# ----------------------- + + +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +@pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL']) +@pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp']) +def test_propagate_nan(dtype, propagate_nan, func, device): + + @triton.jit + def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr): + if func == 'clamp': + tl.store( + C, + getattr(tl, func)(tl.load(A), -tl.load(B), tl.load(B), + propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + else: + tl.store(C, + getattr(tl, func)(tl.load(A), tl.load(B), propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + + for mode in ['A', 'B', 'both']: + if func == 'clamp' and mode == 'B': + # clamp does not guarantee propagation from 'min' and 'max' args + continue + A = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'A' or mode == 'both': A[0] = torch.nan + B = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'B' or mode == 'both': B[0] = torch.nan + C = torch.zeros_like(A, device=device, dtype=getattr(torch, dtype)) + kernel[(1, )](A, B, C, propagate_nan, func) + + if mode == 'both' or propagate_nan == 'ALL': + assert torch.isnan(C[0]) + else: + assert not torch.isnan(C[0]) + + +# ----------------------- +# test clamp +# ----------------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp(dtype, device): + + @triton.jit + def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + min = tl.load(min_ptr + off, mask=mask) + max = tl.load(max_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, min, max), mask=mask) + ref_val = tl.minimum(tl.maximum(x, min), max) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + a = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + b = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + min = torch.min(a, b) + max = torch.max(a, b) + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, min, max, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# Test for symmetric clamp(x, -limit, limit), as it may go through optimized +# codegen in the backends +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['bfloat16', 'float16', 'float32']) +def test_clamp_symmetric(dtype, device): + + @triton.jit + def kernel(x_ptr, limit_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + limit = tl.load(limit_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, -limit, limit), mask=mask) + ref_val = tl.minimum(tl.maximum(x, -limit), limit) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + limit = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)).abs() + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, limit, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# ----------------------- +# test iterators +# ----------------------- + + +@pytest.mark.interpreter +def test_static_range(device): + + @triton.jit + def loop_kernel(Z, N: tl.constexpr, step: tl.constexpr): + acc = 0 + for i in tl.static_range(0, N, step=step): + acc += i + tl.store(Z, acc) + + N = 100 + step = 7 + Out = torch.empty(1, dtype=torch.int32, device=device) + loop_kernel[(1, )](Out, N, step) + Acc = torch.tensor([0], dtype=torch.int32, device=device) + for i in range(0, N, step): + Acc += i + assert (Out == Acc).all(), (Out, Acc) + + +@pytest.mark.interpreter +def test_tl_range_num_stages(device): + if is_hip(): + pytest.skip("test_tl_range is not supported in HIP") + M, N, K = 64, 64, 512 + BLOCK_M, BLOCK_N, BLOCK_K = M, N, 64 + a = torch.randn((M, K), device=device, dtype=torch.float16) + b = torch.randn((K, N), device=device, dtype=torch.float16) + c = torch.empty((M, N), dtype=torch.float32, device=device) + pgm = matmul_kernel[ + 1, + ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, + BLOCK_K, 0, num_stages=5) + ref_out = torch.matmul(a, b).to(torch.float32) + if is_interpreter(): + # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. + # Thus we use a higher tolerance + torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1) + else: + torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3) + if device in ['cuda']: + capability = torch.cuda.get_device_capability() + if capability[0] >= 8: + ptx = pgm.asm['ptx'] + # check that the loop got pipelined with the right number of stages. + assert 'cp.async.wait_group 6' in ptx + + +def test_tl_range_fuse(): + if is_hip(): + pytest.skip("loop fusion is not enabled on AMD") + + @triton.jit + def kernel(ub): + for i in tl.range(0, ub, flatten=True): + for j in tl.range(0, ub): + print("i", i) + + compiled_kernel = kernel.warmup(10, grid=(1, )) + assert "tt.flatten" in compiled_kernel.asm["ttir"] + assert compiled_kernel.asm["ttgir"].count("scf.for") == 1 + + +def test_tl_range_option_none(): + + @triton.jit + def kernel(ub): + for i in tl.range(0, ub, num_stages=None, loop_unroll_factor=None): + print("i", i) + + compiled_kernel = kernel.warmup(10, grid=(1, )) + assert "num_stages" not in compiled_kernel.asm["ttir"] + assert "loop_unroll_factor" not in compiled_kernel.asm["ttir"] + + +@triton.jit(noinline=True) +def maxnreg_noinline1(X): + tl.store(X, 0) + + +@triton.jit(noinline=True) +def maxnreg_noinline2(X): + tl.store(X, 0) + + +@pytest.mark.interpreter +def test_maxnreg(device): + if not is_cuda(): + pytest.skip('maxnreg only works on CUDA') + + # triton kernel + @triton.jit + def kernel(X): + maxnreg_noinline1(X) + tl.store(X, 0) + maxnreg_noinline2(X) + + X = torch.empty(1, dtype=torch.int32, device=device) + k = kernel[(1, )](X, maxnreg=42) + + if not is_interpreter(): + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). + try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise + + +@pytest.mark.interpreter +def test_temp_var_in_loop(device): + + @triton.jit + def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): + acc = tl.full((BLOCK, ), 0, dtype=tl.int32) + for i in range(N): + if i == 0: + temp = tl.full((BLOCK, ), 2, dtype=tl.int32) + acc = temp + else: + acc += tl.full((BLOCK, ), 1, dtype=tl.int32) + # reuse the temp variable and make sure to check that it isn't creating incorrect IR. + temp = tl.full((BLOCK, ), 1, dtype=tl.int32) + acc += temp + z = Z + tl.arange(0, BLOCK) + tl.store(z, acc) + + N = 10 + BLOCK = 32 + out = torch.empty((BLOCK, ), dtype=torch.int32, device=device) + temp_in_loop[(1, )](out, N, BLOCK) + acc = torch.full((BLOCK, ), 0, dtype=torch.int32, device=device) + for i in range(N): + if i == 0: + temp = torch.full((BLOCK, ), 2, dtype=torch.int32, device=device) + acc = temp + else: + acc += torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + temp = torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + acc += temp + assert (acc == out).all() + + +@pytest.mark.interpreter +def test_num_programs(device): + # Assuming that the kernel is launched with a grid of (11, 21, 31) + grid = (11, 21, 31) + input = torch.empty((3, ), dtype=torch.int32, device=device) + + @triton.jit + def kernel(input): + num_programs_0 = tl.num_programs(0) + num_programs_1 = tl.num_programs(1) + num_programs_2 = tl.num_programs(2) + tl.store(input, num_programs_0) + tl.store(input + 1, num_programs_1) + tl.store(input + 2, num_programs_2) + + kernel[grid](input) + assert torch.all(input == torch.tensor(grid, device=device)) + + +# ----------------------- +# test extern functions +# ----------------------- + + +@pytest.mark.parametrize("dtype_str", ['float32', 'float64']) +def test_math_extern(dtype_str, device): + if is_interpreter(): + pytest.skip('math_extern does not work in the interpreter mode') + + @triton.jit + def kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = libdevice.tanh(x) + tl.store(y_ptr + offsets, y, mask=mask) + + shape = (128, ) + rs = RandomState(17) + + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + y_ref = np.tanh(x) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str=dtype_str, rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, shape[0], BLOCK_SIZE=shape[0]) + # compare + np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) + + +# ----------------------- +# test loop unrolling +# ----------------------- + + +def test_unroll_attr(device): + + @triton.jit + def _kernel(dst, unroll_factor: tl.constexpr): + pid = tl.program_id(axis=0) + for i in tl.range(0, 10, loop_unroll_factor=unroll_factor): + tl.atomic_add(dst + pid, i + pid) + + def check_loop_unroll_count(ir, opStr, loop_unroll_factor): + for line in ir.splitlines(): + if opStr in line: + loop_unroll_factor = loop_unroll_factor - 1 + # Sometimes we get a remainder loop + assert loop_unroll_factor <= 0 + + # Try for all different loop unroll factors: + for unroll_factor in [1, 2, 4, 5, 8]: + h = _kernel[(1, )](torch.empty(1, device=device), unroll_factor) + check_loop_unroll_count(h.asm["ttir"], 'tt.atomic_rmw', unroll_factor) + + +@triton.jit +def sanitize_add(a, b): + a64 = a.to(tl.int64) + b64 = b.to(tl.int64) + r64 = a64 + b64 + tl.device_assert((r64 >= -2**31) & (r64 <= 2**31 - 1)) + return a + b + + +def test_side_effectful_reduction(device): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.reduce(vals, 0, sanitize_add) + tl.store(Z, z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros((), device="cuda", dtype=torch.int32) + sanitize_sum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.sum().to(torch.int32)) + + +@pytest.mark.parametrize("reduce_dim", [0, 1]) +def test_side_effectful_reduction_2d(device, reduce_dim): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, reduce_dim: tl.constexpr, + NON_REDUCE_DIM: tl.constexpr): + offsets = tl.arange(0, BLOCK_0)[:, None] * BLOCK_1 + tl.arange(0, BLOCK_1)[None, :] + vals = tl.load(X + offsets) + z = tl.reduce(vals, reduce_dim, sanitize_add) + tl.store(Z + tl.arange(0, NON_REDUCE_DIM), z) + + BLOCK_0 = 16 + BLOCK_1 = 32 + NON_REDUCE_DIM = BLOCK_1 if reduce_dim == 0 else BLOCK_0 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK_0, BLOCK_1], device="cuda", dtype=torch.int32) + Z = torch.zeros([NON_REDUCE_DIM], device="cuda", dtype=torch.int32) + sanitize_sum_2d_kernel[(1, )](Z, X, BLOCK_0=BLOCK_0, BLOCK_1=BLOCK_1, reduce_dim=reduce_dim, + NON_REDUCE_DIM=NON_REDUCE_DIM) + torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32)) + + +def test_dtype(device): + + @triton.jit + def kernel(X): + dtype_x: tl.constexpr = X.dtype.element_ty + tl.static_assert(dtype_x == tl.int32) + tl.static_assert(dtype_x == tl.constexpr(tl.int32)) + tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32)) + + X = torch.zeros(1, dtype=torch.int32, device=device) + kernel[(1, )](X) + + +def test_side_effectful_scan(device): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.associative_scan(vals, 0, sanitize_add) + tl.store(Z + tl.arange(0, BLOCK), z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros_like(X) + sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32)) + + +# stress test slice layout usages in reductions. +@pytest.mark.parametrize("in_shape, perm, red_dims", [ + ((4, 32, 32, 4, 2), [2, 1, 0, 3, 4], [3, 1, 0]), + ((8, 2, 32, 4, 16), [4, 0, 1, 3, 2], [0, 2, 0]), +]) +def test_chained_reductions(in_shape, perm, red_dims, device): + + @triton.jit + def kernel(In, Out, # + dim_0: tl.constexpr, dim_1: tl.constexpr, dim_2: tl.constexpr, dim_3: tl.constexpr, dim_4: tl.constexpr, + perm_0: tl.constexpr, perm_1: tl.constexpr, perm_2: tl.constexpr, perm_3: tl.constexpr, + perm_4: tl.constexpr, red_dim_0: tl.constexpr, red_dim_1: tl.constexpr, red_dim_2: tl.constexpr): + idx = tl.arange(0, dim_0 * dim_1 * dim_2 * dim_3 * dim_4) + idx = idx.reshape(dim_0, dim_1, dim_2, dim_3, dim_4) + vals = tl.load(In + idx) + vals = tl.permute(vals, [perm_0, perm_1, perm_2, perm_3, perm_4]) + r = tl.sum(tl.sum(tl.sum(vals, red_dim_0), red_dim_1), red_dim_2) + st_idx = tl.arange(0, r.shape[0] * r.shape[1]).reshape(r.shape) + tl.store(Out + st_idx, r) + + input = torch.randint(0, 1000, in_shape, device=device, dtype=torch.int32) + temp = torch.permute(input, perm).contiguous() + ref = torch.sum(torch.sum(torch.sum(temp, dim=red_dims[0]), dim=red_dims[1]), dim=red_dims[2]) + result = torch.empty_like(ref) + kernel[(1, )](input, result, input.shape[0], input.shape[1], input.shape[2], input.shape[3], input.shape[4], + perm[0], perm[1], perm[2], perm[3], perm[4], red_dims[0], red_dims[1], red_dims[2]) + + assert torch.all(ref == result) + + +@triton.jit +def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, + src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, + idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, + out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, + out_stride1: tl.constexpr): + src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) + src = tl.load(src_ptr + src_offs) + + idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) + idx = tl.load(idx_ptr + idx_offs) + + out = tl.gather(src, idx, axis) + + out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) + tl.store(out_ptr + out_offs, out) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("src_shape, indices_shape, axis", [ + ([4, 4], [8, 4], 0), + ([128, 64], [256, 64], 0), + ([128, 64], [128, 128], 1), +]) +def test_gather(src_shape, indices_shape, axis, device): + + def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + + gather_test_kernel[(1, )](src, indices, output, axis, src.shape[0], + src.shape[1], src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], + indices.stride(0), indices.stride(1), output.shape[0], output.shape[1], + output.stride(0), output.stride(1)) + + return output + + src = torch.randn(src_shape, device=device) + indices = torch.randint(0, src.shape[axis], indices_shape, device=device) + ref = torch.gather(src, axis, indices) + result = triton_gather(src, axis, indices) + torch.testing.assert_close(result, ref, rtol=0, atol=0) + + +# These layouts are specially chosen to trigger the warp shuffle codegen. +@pytest.mark.parametrize("src_shape, indices_shape, axis, src_layout, indices_layout", [ + ([32, 16], [32, 16], 0, + "linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), + ([128, 64], [256, 64], 0, + "linear<{register = [[0, 2], [32, 0], [2, 0], [0, 16], [0, 32], [64, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[0, 2], [32, 0], [0, 32], [2, 0], [0, 16], [64, 0], [128, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), +]) +def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path, + device): + if is_hip(): + pytest.skip("warp-local gather has issues on HIP") + + def prepare_kernel(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + compiled = gather_test_kernel.warmup(src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0), + src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), + indices.stride(1), output.shape[0], output.shape[1], output.stride(0), + output.stride(1), grid=(1, )) + return output, compiled + + def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout, idx_layout): + ir = f""" +#src_layout = #ttg.{src_layout} +#idx_layout = #ttg.{idx_layout} +{ir}""" + + dtypes = {torch.int32: "i32", torch.float32: "f32", torch.int64: "i64", torch.float64: "f64"} + + src_spec = f"{src.shape[0]}x{src.shape[1]}x{dtypes[src.dtype]}" + indices_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[indices.dtype]}" + output_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[src.dtype]}" + + pat = r"(%[0-9]+) = tt.gather (%[0-9]+)\[(%[0-9]+)\] {axis = " + pat += str(axis) + pat += r" : i32} : \(tensor\<" + pat += src_spec + pat += r", (#[a-z]+[0-9]+)\>, tensor\<" + pat += indices_spec + pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<" + pat += output_spec + pat += r", (#[a-z]+[0-9]+)\>" + + repl = r""" + %src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout> + %idx = ttg.convert_layout \3 : tensor<""" + indices_spec + r""", \5> -> tensor<""" + indices_spec + r""", #idx_layout> + %out = tt.gather %src[%idx] {axis = """ + str( + axis + ) + r""" : i32} : (tensor<""" + src_spec + r""", #src_layout>, tensor<""" + indices_spec + r""", #idx_layout>) -> tensor<""" + output_spec + r""", #idx_layout> + \1 = ttg.convert_layout %out : tensor<""" + output_spec + r""", #idx_layout> -> tensor<""" + output_spec + r""", \6>""" + return re.sub(pat, repl, ir) + + src = torch.randn(src_shape, device=device) + indices = torch.randint(0, src.shape[axis], indices_shape, device=device) + ref = torch.gather(src, axis, indices) + + output, compiled = prepare_kernel(src, axis, indices) + ir = compiled.asm["ttgir"] + ir = inject_layout(ir, src, axis, indices, src_layout, indices_layout) + + temp_file = tmp_path / "test_warp_gather.ttgir" + temp_file.write_text(ir) + + kernel = triton.compile(str(temp_file)) + assert ("nvvm.shfl.sync.idx" in kernel.asm["llir"]) or ("llvm.amdgcn.ds.bpermute" in kernel.asm["llir"]) + + kernel[(1, 1, 1)](src, indices, output) + + torch.testing.assert_close(output, ref, rtol=0, atol=0) + + +@triton.jit +def mul_jit_function(x, y): + return x * y + + +@triton.jit +def apply_binary_op(x, combine_op): + return combine_op(x, x) + + +def test_jit_function_arg(device): + + @triton.jit + def square_kernel_jit_function(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + in_data = tl.load(in_ptr + offsets) + out_data = apply_binary_op(in_data, mul_jit_function) # pass a JITFunction into another JITFunction + tl.store(out_ptr + offsets, out_data) + + BLOCK_SIZE = 16 + x = torch.full((BLOCK_SIZE, ), 3.0, device=device) + out = torch.empty((BLOCK_SIZE, ), device=device) + expect = torch.full((BLOCK_SIZE, ), 9.0, dtype=x.dtype, device=device) + + square_kernel_jit_function[(1, )](x, out, BLOCK_SIZE) + + torch.testing.assert_close(out, expect) + + +@pytest.mark.interpreter +def test_zero_strided_tensors(device): + + @triton.jit + def _simple_add( + X, + stride_x_a, + stride_x_b, + ): + pid_a = tl.program_id(0) + pid_b = tl.program_id(1) + + # doesn't directly index c dim, so relies on 0-strided c dim to affect every element + x_ptr = X + pid_a * stride_x_a + pid_b * stride_x_b + + tl.atomic_add(x_ptr, 1) + + x = torch.zeros((2, 2, 1), device=device) + c_dim = 3 + x = x.expand((2, 2, c_dim)) + + a, b, c = x.shape + grid = (a, b, c) + with torch.cuda.device(x.device.index): + _simple_add[grid](x, x.stride(0), x.stride(1)) + + assert torch.allclose(x, torch.ones_like(x) * c_dim) + + +@pytest.mark.interpreter +def test_aliasing(device): + + @triton.jit + def aliasing_kernel(buffer, buffer2): + triton.language.store(buffer, 1) + + buffer = torch.zeros(1, device=device) + aliasing_kernel[(1, )](buffer, buffer) + assert buffer[0] == 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_strided_load(dtype, device): + + @triton.jit + def take_every_second_element(x_ptr, output_ptr, BLOCK_SIZE: tl.constexpr): + strided_offsets = tl.arange(0, BLOCK_SIZE) * 2 + linear_offsets = tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + strided_offsets) + tl.store(output_ptr + linear_offsets, x) + + STRIDE = 2 + SIZE = 512 + OUT_SIZE = SIZE // STRIDE + + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + out_tri = torch.empty(OUT_SIZE, device=device) + take_every_second_element[(1, 1)](x_tri, out_tri, OUT_SIZE) + + # Test that every second element (starting from [0]) from x is stored in out_tri + np.testing.assert_allclose(x[::2], to_numpy(out_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_strided_store(dtype, device): + + @triton.jit + def store_into_every_second(x_ptr, output_ptr, BLOCK_SIZE: tl.constexpr): + strided_offsets = tl.arange(0, BLOCK_SIZE) * 2 + linear_offsets = tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + linear_offsets) + tl.store(output_ptr + strided_offsets, x) + + STRIDE = 2 + SIZE = 512 + OUT_SIZE = SIZE * STRIDE + + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + out_tri = torch.zeros(OUT_SIZE, device=device) + store_into_every_second[(1, 1)](x_tri, out_tri, SIZE) + + # Test that every second element (starting from [0]) is the same as in x + np.testing.assert_allclose(x, to_numpy(out_tri)[::2]) + # Test that every second element (starting from [1]) is still zero + np.testing.assert_allclose(np.zeros_like(x), to_numpy(out_tri)[1::2]) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_indirect_load(dtype, device): + + @triton.jit + def indirect_load(offset_ptr, x_ptr, output_ptr, SIZE: tl.constexpr): + linear_offsets = tl.arange(0, SIZE) + offsets = tl.load(offset_ptr + linear_offsets) + x = tl.load(x_ptr + offsets) + tl.store(output_ptr + linear_offsets, x) + + SIZE = 512 + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + # Flip the range to load the tensor in reverse order + ptr = torch.arange(SIZE, device=device, dtype=torch.int32).flip(0) + out_tri = torch.empty(SIZE, device=device) + indirect_load[(1, 1)](ptr, x_tri, out_tri, SIZE) + + np.testing.assert_allclose(np.flip(x), to_numpy(out_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_indirect_store(dtype, device): + + @triton.jit + def indirect_store(offset_ptr, x_ptr, output_ptr, SIZE: tl.constexpr): + linear_offsets = tl.arange(0, SIZE) + offsets = tl.load(offset_ptr + linear_offsets) + x = tl.load(x_ptr + linear_offsets) + tl.store(output_ptr + offsets, x) + + SIZE = 512 + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + # Flip the range to store the tensor in reverse order + ptr = torch.arange(SIZE, device=device, dtype=torch.int32).flip(0) + out_tri = torch.empty(SIZE, device=device) + indirect_store[(1, 1)](ptr, x_tri, out_tri, SIZE) + + np.testing.assert_allclose(np.flip(x), to_numpy(out_tri)) diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_decorator.py b/third_party/enflame/include/triton/python/test/unit/language/test_decorator.py new file mode 100644 index 000000000..42207cc1f --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_decorator.py @@ -0,0 +1,50 @@ +import torch + +import triton +import triton.language as tl +import pytest + + +def test_decorator_with_def(device): + + def triton_heuristics_pointwise(**kwargs): + + def decorator(func): + return func + + return decorator + + # "def" might appear in a decorator call, e.g. a hash string argument. + # This test makes sure the compiler can find the right position of function + # definition. + @triton_heuristics_pointwise(inductor_meta={'backend_hash': 'def0aeffabe53b3f8'}, ) + @triton.jit + def kernel(): + pass + + try: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + except Exception as e: + pytest.fail(f"triton compile failed with error: {e}") + + +def test_triton_heuristic(device): + N = 1023 + src = torch.empty(N, device=device) + dst = torch.zeros(N, device=device) + + do_bench = lambda kernel, quantiles: triton.testing.do_bench(kernel, quantiles=quantiles, warmup=1, rep=1) + + @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], do_bench=do_bench) + @triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0}) # test kwargs + @triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0}) # test args + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr, EVEN_N: tl.constexpr, EVEN_src: tl.constexpr): + tl.store(dst, EVEN_N) + tl.store(dst + 1, EVEN_src) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + assert dst[0].item() == 0.0 + assert dst[1].item() == 1.0 + assert _kernel.base_fn.__name__ == "_kernel" diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_libdevice.py b/third_party/enflame/include/triton/python/test/unit/language/test_libdevice.py new file mode 100644 index 000000000..2573aef5c --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_libdevice.py @@ -0,0 +1,57 @@ +import pytest +import torch + +import triton +import triton.language as tl + +from triton.language.extra import libdevice +from triton.language.extra.libdevice import fast_dividef as my_fast_dividef + + +@pytest.mark.parametrize("dtype_str", ["float32", "float64"]) +@pytest.mark.parametrize( + "libdevice_fn, torch_special_fn", + [ + ("j0", "bessel_j0"), + ("j1", "bessel_j1"), + ("y0", "bessel_y0"), + ("y1", "bessel_y1"), + ("cyl_bessel_i0", "i0"), + ("cyl_bessel_i1", "i1"), + ], +) +def test_bessel(dtype_str, libdevice_fn, torch_special_fn, device): + SIZE = 128 + dtype = getattr(torch, dtype_str) + + x = torch.randn((SIZE, ), dtype=dtype, device=device) + y_exp = torch.empty((SIZE, ), dtype=dtype, device=device) + y_ref = getattr(torch.special, torch_special_fn)(x) + + @triton.jit + def kernel(in_p, out_p, fn: tl.constexpr, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(in_p + off) + res = getattr(libdevice, fn)(x) + tl.store(out_p + off, res) + + kernel[(1, )](x, y_exp, fn=libdevice_fn, SIZE=SIZE, num_warps=4, num_ctas=1) + + torch.testing.assert_close(y_ref, y_exp, equal_nan=True) + + +def test_libdevice_rename(device): + # mark the import as used by this test + _ = my_fast_dividef + + @triton.jit + def triton_copy(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + data = tl.load(in_ptr + offsets) + tl.store(out_ptr + offsets, data) + + BLOCK_SIZE = 256 + inp = torch.randn(BLOCK_SIZE, device=device) + out = torch.empty_like(inp) + + triton_copy[(1, )](inp, out, BLOCK_SIZE) diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_line_info.py b/third_party/enflame/include/triton/python/test/unit/language/test_line_info.py new file mode 100644 index 000000000..eba96ade9 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_line_info.py @@ -0,0 +1,232 @@ +import subprocess +import tempfile + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def kernel_single(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def device_inline(x): + return x + x + + +@triton.jit +def kernel_call(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = device_inline(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit(noinline=True) +def device_noinline(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = x + x + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_call_noinline(X, Y, BLOCK: tl.constexpr): + device_noinline(X, Y, BLOCK) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK": 128}, num_warps=4), + ], + key=[], +) +@triton.jit +def kernel_autotune(X, Y, SIZE: tl.constexpr, BLOCK: tl.constexpr): + for i in range(0, SIZE, BLOCK): + x = tl.load(X + i + tl.arange(0, BLOCK)) + tl.store(Y + i + tl.arange(0, BLOCK), x) + + +# AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +# Since the + symbol will take effect in the dot op after combination, +# it seems making sense to annotate with the same line as dot. +@triton.jit +def kernel_dot_combine(x): + c = tl.full((32, 32), 4, dtype=tl.int8) + a = (tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :]).to(tl.int8) + d = tl.dot(a, a) + d = d + c + tl.device_print("", d) + + +# Call another jit function (cdiv) not in this file +@triton.jit +def kernel_cdiv(x): + c = tl.full((32, 32), 4, dtype=tl.int8) + d = tl.cdiv(c, 4) + tl.device_print("", d) + + +def get_disassembler_command_and_debug_line_format(): + """Gets backend specific disassembler information. + + Returns a tuple: (object file kind, disassembler tool command, + debug line anchor, debug line file and line number separator). + """ + backend = triton.runtime.driver.active.get_current_target().backend + + if backend == "cuda": + from triton.backends.nvidia.compiler import _path_to_binary + nvdisasm, _ = _path_to_binary("nvdisasm") + return ("cubin", [nvdisasm, "-g"], "## File", ",") + + if backend == "hip": + import shutil + # Try to find llvm-objdump from the current PATH to disassmble hsaco. + tool = shutil.which("llvm-objdump") + if tool is not None: + return ("hsaco", [tool, "-D", "-l", "--arch=amdgcn"], ";", ":") + raise RuntimeError("llvm-objdump not found in PATH") + + raise RuntimeError(f"unknown backend {backend}") + + +def extract_file_lines(command, anchor, separator, asm): + fd, path = tempfile.mkstemp() + with open(fd, 'wb') as cubin: + cubin.write(asm) + asm = subprocess.check_output(command + [path]).decode("utf-8") + file_lines = [] + lines = asm.splitlines() + for line in lines: + # We are looking for an anchor string and a separator between the file name and line number. + if anchor in line and separator in line: + entries = line[line.index(anchor):].split(separator) + if len(entries) == 2 and all(len(e) != 0 for e in entries): + file_lines.append((entries[0].strip(), entries[1].strip())) + return file_lines + + +def check_file_lines(file_lines, file_name, lineno, should_contain=True): + """ + Check if the file name and line number is in the file_lines + + Args: + file_lines: list of (file_name, line_number) + file_name: file name + lineno: line number, -1 means do not check line number + should_contain: whether the file name and line number should be in the file_lines + """ + for file, line in file_lines: + if lineno == -1 and file_name in file: + return True + if file_name in file and str(lineno) in line: + return should_contain + return not should_contain + + +func_types = ["single", "call", "call_noinline", "autotune", "dot_combine", "cdiv"] + + +def is_interpreter(): + import os + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +@pytest.mark.parametrize("func", func_types) +def test_line_info(func: str): + if is_interpreter(): + pytest.skip("interpreter does not support warmup compilation") + + try: + obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format() + except BaseException: + pytest.skip("disassembler is not available") + + shape = (128, ) + kernel_info = {} + if func == "single": + kernel_info = kernel_single.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1, )) + elif func == "call": + kernel_info = kernel_call.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1, )) + elif func == "call_noinline": + kernel_info = kernel_call_noinline.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1, )) + elif func == "autotune": + kernel_info = kernel_autotune.warmup(torch.float32, torch.float32, SIZE=shape[0], grid=(1, ))[0] + elif func == "dot_combine": + kernel_info = kernel_dot_combine.warmup(20, grid=(1, )) + elif func == "cdiv": + kernel_info = kernel_cdiv.warmup(20, grid=(1, )) + + file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind]) + if func == "single": + assert (check_file_lines(file_lines, "test_line_info.py", 13)) + assert (check_file_lines(file_lines, "test_line_info.py", 14)) + elif func == "call": + assert (check_file_lines(file_lines, "test_line_info.py", 24)) + assert (check_file_lines(file_lines, "test_line_info.py", 26)) + elif func == "call_noinline": + assert (check_file_lines(file_lines, "test_line_info.py", 38)) + assert (check_file_lines(file_lines, "test_line_info.py", 31)) + assert (check_file_lines(file_lines, "test_line_info.py", 31)) + elif func == "autotune": + assert (check_file_lines(file_lines, "test_line_info.py", 49)) + assert (check_file_lines(file_lines, "test_line_info.py", 50)) + assert (check_file_lines(file_lines, "test_line_info.py", 51)) + elif func == "dot_combine": + assert (check_file_lines(file_lines, "test_line_info.py", 61)) + assert (check_file_lines(file_lines, "test_line_info.py", 62, should_contain=False)) + elif func == "cdiv": + assert (check_file_lines(file_lines, "test_line_info.py", 71)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func", func_types) +def test_line_info_interpreter(func: str): + if not is_interpreter(): + pytest.skip("interpreter is not enabled") + + kernel = None + expected_def_lineno = 0 + if func == "single": + kernel = kernel_single + expected_def_lineno = 12 + elif func == "call": + kernel = kernel_call + expected_def_lineno = 23 + elif func == "call_noinline": + kernel = kernel_call_noinline + expected_def_lineno = 37 + elif func == "autotune": + kernel = kernel_autotune.fn + expected_def_lineno = 48 + elif func == "dot_combine": + kernel = kernel_dot_combine + expected_def_lineno = 58 + elif func == "cdiv": + kernel = kernel_cdiv + expected_def_lineno = 68 + kernel.rewrite() + assert kernel.rewriter.def_file_lineno == expected_def_lineno + + +@pytest.mark.parametrize("status", ["0", "1"]) +def test_line_info_env(monkeypatch, status: str): + if is_interpreter(): + pytest.skip("interpreter does not support warmup compilation") + + try: + obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format() + except BaseException: + pytest.skip("disassembler is not available") + + shape = (128, ) + monkeypatch.setenv("TRITON_DISABLE_LINE_INFO", status) + kernel_single.device_caches.clear() + kernel_info = kernel_single.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1, )) + file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind]) + assert len(file_lines) == 0 if status == "1" else len(file_lines) > 0 diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_matmul.py b/third_party/enflame/include/triton/python/test/unit/language/test_matmul.py new file mode 100644 index 000000000..fbb3cc70d --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_matmul.py @@ -0,0 +1,957 @@ +import math +import pytest +import torch +import triton +import triton.language as tl +import triton.tools.experimental_descriptor +from test_mxfp import MXFP4Tensor, MXScaleTensor +import re +from triton._internal_testing import is_cuda, is_hip, is_hip_mi300, is_hip_mi350, is_hip_cdna + + +def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + +@triton.jit +def matmul_kernel( # + a_ptr, b_ptr, output_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_STAGES: tl.constexpr, SCALE_A: tl.constexpr = None, PRECISION: tl.constexpr = "ieee"): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty) + for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES): + a = tl.load(a_ptrs) + if SCALE_A is not None: + a = a * SCALE_A + b = tl.load(b_ptrs) + accumulator = tl.dot(a, b, acc=accumulator, out_dtype=output_ptr.dtype.element_ty, input_precision=PRECISION) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(output_ptrs, accumulator, mask=mask_c) + + +def get_src_element_ty_size(dtype_str): + if dtype_str == "float8e5": + return 1 + if dtype_str == "float16": + return 2 + if dtype_str == "float32" or dtype_str == "tensorfloat32": + return 4 + raise ValueError(f"Unknown dtype {dtype_str}") + + +@pytest.mark.parametrize("dtype_src_str", ["float32", "tensorfloat32", "float16", "float8e5"]) +@pytest.mark.parametrize("dtype_dst_str", ["float32", "float16"]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES", [(128, 128, 16, 4), (64, 128, 32, 4), (32, 32, 32, 4), + (256, 128, 32, 4), (64, 512, 32, 2), + (512, 64, 32, 2), (64, 16, 16, 4)]) +@pytest.mark.parametrize("NUM_CTAS", [1, 2]) +@pytest.mark.parametrize("NUM_WARPS", [4, 8]) +def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, NUM_CTAS, + device): + if NUM_CTAS > 1 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 9): + pytest.skip("Clusters requires nvidia compute capability >= 9") + if is_hip() and ((BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str) + > 65536): + pytest.skip("HIP path requires less than 64KB of shared memory") + if is_hip() and (not is_hip_mi300()) and dtype_src_str == "tensorfloat32": + pytest.skip("tensorfloat32 is only supported on HIP MI300") + if dtype_src_str == "float8e5" and BLOCK_K == 16: + pytest.skip("Skipping cases small K for float8") + if dtype_src_str == "float8e5" and device == "cuda" and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("Float8 requires compute capability >= 9") + if "float32" in dtype_src_str and dtype_dst_str == "float16": + pytest.skip("Skipping unsupported case") + if "float32" == dtype_src_str and NUM_CTAS > 1: + pytest.skip("FMA matmul not supported for multiple CTAs") + if (BLOCK_M < 64 or (BLOCK_M == 64 and BLOCK_N == 16)) and NUM_CTAS > 1: + pytest.skip("multi-CTAs is broken for mmav2") + M, N, K = 1024, 512, 256 + torch.manual_seed(42) + precision = "tf32" if dtype_src_str == "tensorfloat32" else "ieee" + dtype_src_str = "float32" if dtype_src_str == "tensorfloat32" else dtype_src_str + if dtype_src_str == "float8e5": + a = torch.randint(20, 40, (M, K), dtype=torch.int8, device=device).view(torch.float8_e5m2) + b = torch.randint(20, 40, (K, N), dtype=torch.int8, device=device).view(torch.float8_e5m2) + A = f8_to_f16(a, dtype_src_str) + B = f8_to_f16(b, dtype_src_str) + else: + dtype_src = getattr(torch, dtype_src_str) + a = torch.randn(M, K, dtype=dtype_src, device=device) + b = torch.randn(K, N, dtype=dtype_src, device=device) + A = a + B = b + dtype_dst = getattr(torch, dtype_dst_str) + output = torch.empty((M, N), dtype=dtype_dst, device=device) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + k = matmul_kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0), + output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES, PRECISION=precision, + num_warps=NUM_WARPS, num_ctas=NUM_CTAS) + ref_out = torch.matmul(A, B).to(torch.float32) + output = output.to(torch.float32) + if dtype_src_str == "float32": + # TF32 has lower precision than torch.float32 + atol = 0.03 + rtol = 0.03 + elif dtype_dst_str == "float16": + atol = 0.06 + rtol = 0.06 + else: + atol = 0.01 + rtol = 0.01 + torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol) + # Make sure the mma is pipelined by checking if in the TTGIR we are waiting for the + # barrier coming from the loop args (previous iteration). + # This applies only if TCv5 MMA is used (M % 64 == 0 and N % 8 == 0) and + # when MMA arguments loads are pipelined (N > 16) + if (device == "cuda" and torch.cuda.get_device_capability()[0] == 10 and NUM_STAGES > 1 and BLOCK_M % 64 == 0 + and BLOCK_N % 8 == 0 and BLOCK_N > 16 and not (precision == "ieee" and dtype_src_str == "float32")): + ttgir = k.asm["ttgir"] + pattern = (r"ttng.wait_barrier %arg") + assert re.search(pattern, str(ttgir)), "The TTGIR does not match the expected pattern." + + +# persistent matmul with fused loops +@triton.jit +def simple_persistent_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr, + DISALLOW_ACC_MULTI_BUFFER: tl.constexpr): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + tile_id_c = start_pid - NUM_SMS # remat value to use in the epilogue + ki = -1 + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in tl.range(0, k_tiles * tiles_per_SM, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + + if ki == k_tiles - 1: + tile_id_c += NUM_SMS + group_id = tile_id_c // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id_c % group_size_m) + pid_n = (tile_id_c % num_pid_in_group) // group_size_m + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if (c_ptr.dtype == tl.float8e4nv): + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 16), (64, 128, 32), (32, 32, 32), (256, 128, 16), + (64, 512, 16), (512, 64, 16), (64, 16, 16)]) +@pytest.mark.parametrize("NUM_WARPS", [4, 8]) +@pytest.mark.parametrize("DISALLOW_ACC_MULTI_BUFFER", [True, False]) +def test_simple_persistent_matmul(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, DISALLOW_ACC_MULTI_BUFFER, device): + M, N, K = 1024, 512, 256 + NUM_STAGES = 3 + a = torch.randn(M, K, dtype=torch.float16, device=device) + b = torch.randn(K, N, dtype=torch.float16, device=device) + output = torch.empty((M, N), dtype=torch.float16, device=device) + + # Fake small number of SMS to test that persistent kernel works reliably + NUM_SMS = 8 + + grid = (min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), ) + k = simple_persistent_kernel[grid]( + a, b, output, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + output.stride(0), output.stride(1), # + BLOCK_SIZE_M=BLOCK_M, BLOCK_SIZE_N=BLOCK_N, BLOCK_SIZE_K=BLOCK_K, # + GROUP_SIZE_M=8, NUM_SMS=NUM_SMS, DISALLOW_ACC_MULTI_BUFFER=DISALLOW_ACC_MULTI_BUFFER, num_stages=NUM_STAGES, + num_warps=NUM_WARPS) + ref_out = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float16) + + torch.testing.assert_close(ref_out, output, atol=0.01, rtol=0.01) + + # Make sure the mma is pipelined by checking if in the TTGIR we are waiting for the + # barrier coming from the loop args (previous iteration). + # This applies only if TCv5 MMA is used (M % 64 == 0 and N % 8 == 0) and + # when MMA arguments loads are pipelined (N > 16) + if (device == "cuda" and torch.cuda.get_device_capability()[0] == 10 and BLOCK_M % 64 == 0 and BLOCK_N % 8 == 0 + and BLOCK_N > 16): + ttgir = k.asm["ttgir"] + pattern = "ttng.wait_barrier %arg" + assert ttgir.count(pattern) > 0, "Expect barrier coming from the previous iteration." + + +@triton.jit +def mxfp_matmul( # + a_ptr, b_ptr, output_ptr, # + a_scale, b_scale, # + M, N, K, # + stride_scale: tl.constexpr, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_STAGES: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + offs_scale_k = tl.arange(0, BLOCK_K // 32) + a_scale_ptr = a_scale + offs_am[:, None] * stride_scale + offs_scale_k[None, :] + b_scale_ptr = b_scale + offs_bn[:, None] * stride_scale + offs_scale_k[None, :] + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty) + for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES): + k_remaining = K - k * BLOCK_K + valid_k = offs_k < k_remaining + a = tl.load(a_ptrs, mask=valid_k[None, :], other=0.) + b = tl.load(b_ptrs, mask=valid_k[:, None], other=0.) + scale_a = tl.load(a_scale_ptr) + scale_b = tl.load(b_scale_ptr) + accumulator = tl.dot_scaled(a, scale_a, "e5m2", b, scale_b, "e5m2", accumulator) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + a_scale_ptr += BLOCK_K // 32 + b_scale_ptr += BLOCK_K // 32 + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(output_ptrs, accumulator, mask=c_mask) + + +def fp8e8m0_to_float32(scale): + scale = scale.view(torch.uint8) + scale = scale.to(torch.int32) + scale = scale << 23 + scale = scale.view(torch.float32) + return scale + + +@pytest.mark.parametrize("M, N, K", [(1024, 512, 256), (128, 256, 256), (128, 128, 128), (2, 4, 32), (2, 4, 64), + (256, 16, 32)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128), + (128, 256, 256), (128, 128, 64), (128, 64, 128)]) +@pytest.mark.parametrize("NUM_STAGES", [1, 3]) +@pytest.mark.parametrize("NUM_WARPS", [4, 8]) +@pytest.mark.parametrize("nonKDim", ([0, 16, 32] if is_hip_cdna() else [0])) +def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS, device): + if is_cuda() and torch.cuda.get_device_capability()[0] < 10: + pytest.skip("Requires compute capability >= 10") + elif is_hip(): + if not is_hip_mi350(): + pytest.skip("Scaled mxfp8 matmul is only natively supported on MI350") + if (M == 2 and N == 4 and K == 32) or (M == 256 and N == 16 and K == 32): + pytest.skip(f"Input shape {M=}, {N=}, {K=} is not supported yet") + if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64): + pytest.skip(f"MI350 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") + + if BLOCK_N == 256 and BLOCK_K == 256: + NUM_STAGES = min(NUM_STAGES, 2) + torch.manual_seed(42) + dtype_src_str = "float8e5" + dtype_dst_str = "float32" + a = torch.randint(20, 40, (M, K), dtype=torch.uint8, device=device).view(torch.float8_e5m2) + a_f16 = f8_to_f16(a, dtype_src_str) + b = torch.randint(20, 40, (K, N), dtype=torch.uint8, device=device).view(torch.float8_e5m2) + b_f16 = f8_to_f16(b, dtype_src_str) + a_scale = torch.randint(130, (M, K // 32), dtype=torch.uint8, device=device) + b_scale = torch.randint(130, (N, K // 32), dtype=torch.uint8, device=device) + + dtype_dst = getattr(torch, dtype_dst_str) + output = torch.empty((M, N), dtype=dtype_dst, device=device) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + kernel_kwargs = {} + if is_hip(): + kernel_kwargs["matrix_instr_nonkdim"] = nonKDim + out = mxfp_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, a_scale.stride(0), a.stride(0), a.stride(1), + b.stride(0), b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, + NUM_STAGES=NUM_STAGES, **kernel_kwargs, num_warps=NUM_WARPS) + a_scale_f32 = fp8e8m0_to_float32(a_scale) + b_scale_f32 = fp8e8m0_to_float32(b_scale) + a_scale_f32 = a_scale_f32.repeat_interleave(32, dim=1) + b_scale_f32 = b_scale_f32.repeat_interleave(32, dim=1) + + # b_scales are always col major + b_scale_f32 = b_scale_f32.T.contiguous() + + a = a_f16 * a_scale_f32 + b = b_f16 * b_scale_f32 + ref_out = torch.matmul(a, b).to(torch.float32) + output = output.to(torch.float32) + atol = 0.0001 + rtol = 0.0001 + torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol) + + # Pipelining of dot_scaled requires tmem_copy to be used, which in turn + # requires the scales to be in the blocked layout in global memory. + assert "ttng.wait_barrier" not in out.asm["ttgir"] + + +def _knob_promote_lhs_to_tmem(monkeypatch): + # Promoting the LHS to TMEM should be patched because it will otherwise + # unintentionally be enabled for all consecutive tests if using os.environ + monkeypatch.setenv("ALLOW_LHS_TMEM_LAYOUT_CONVERSION", "1") + + +@triton.jit +def block_scale_mxfp_matmul( # + a_ptr, b_ptr, output_ptr, # + a_scale, b_scale, # + M, N, K, # + stride_sk, stride_sb, stride_sc, stride_sd: tl.constexpr, # Need tl.constexpr to pipeline scale load. Why? + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_STAGES: tl.constexpr, USE_2D_SCALE_LOAD: tl.constexpr): + ## This kernel assumes a_scale and b_scale are coming in with shapes + ## [BLOCK_M(or N) // 128, BLOCK_K // 128, 32, 4, 4] for optimial performance + ## on nvidia sm100+ HW + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + + offs_sm = (pid_m * (BLOCK_M // 128) + tl.arange(0, BLOCK_M // 128)) + offs_sn = (pid_n * (BLOCK_N // 128) + tl.arange(0, BLOCK_N // 128)) + + if USE_2D_SCALE_LOAD: + offs_inner = tl.arange(0, (BLOCK_K // 128) * 32 * 4 * 4) + a_scale_ptr = a_scale + offs_sm[:, None] * stride_sk + offs_inner[None, :] + b_scale_ptr = b_scale + offs_sn[:, None] * stride_sk + offs_inner[None, :] + else: + offs_sk = tl.arange(0, (BLOCK_K // 128)) + offs_sc = tl.arange(0, 32) + offs_sd = tl.arange(0, 4) + a_scale_ptr = a_scale + (offs_sm[:, None, None, None, None] * stride_sk + offs_sk[None, :, None, None, None] * + stride_sb + offs_sc[None, None, :, None, None] * stride_sc + + offs_sd[None, None, None, :, None] * stride_sd + offs_sd[None, None, None, None, :]) + b_scale_ptr = b_scale + (offs_sn[:, None, None, None, None] * stride_sk + offs_sk[None, :, None, None, None] * + stride_sb + offs_sc[None, None, :, None, None] * stride_sc + + offs_sd[None, None, None, :, None] * stride_sd + offs_sd[None, None, None, None, :]) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty) + for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + scale_a = tl.load(a_scale_ptr) + scale_b = tl.load(b_scale_ptr) + + if USE_2D_SCALE_LOAD: + scale_a = scale_a.reshape(BLOCK_M // 128, BLOCK_K // 128, 32, 4, 4) + scale_b = scale_b.reshape(BLOCK_N // 128, BLOCK_K // 128, 32, 4, 4) + + # Scales are comming in for optimial peformance, but we reshape here for + # the canonical inputs to dot_scaled + # These reshapes and transposes will be optimized away during lowering + scale_a = scale_a.trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // 32) + scale_b = scale_b.trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K // 32) + accumulator = tl.dot_scaled(a, scale_a, "e5m2", b, scale_b, "e5m2", accumulator) + + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + a_scale_ptr += BLOCK_K // 128 * stride_sb + b_scale_ptr += BLOCK_K // 128 * stride_sb + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(output_ptrs, accumulator, mask=c_mask) + + +@pytest.mark.parametrize("M, N, K", [(1024, 512, 512), (998, 111, 512), (63, 128, 512)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128), + (128, 128, 256), (128, 256, 256)]) +@pytest.mark.parametrize("NUM_STAGES", [1, 2, 4]) +@pytest.mark.parametrize("USE_2D_SCALE_LOAD", [False, True]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10") +def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device, monkeypatch): + if BLOCK_N == 256 and BLOCK_K == 256: + NUM_STAGES = min(NUM_STAGES, 2) + elif BLOCK_K == 256: + NUM_STAGES = min(NUM_STAGES, 3) + + torch.manual_seed(42) + dtype_src_str = "float8e5" + dtype_dst_str = "float32" + a = torch.randint(20, 40, (M, K), dtype=torch.uint8, device=device).view(torch.float8_e5m2) + A = f8_to_f16(a, dtype_src_str) + b = torch.randint(20, 40, (K, N), dtype=torch.uint8, device=device).view(torch.float8_e5m2) + B = f8_to_f16(b, dtype_src_str) + ceildiv = lambda a, b: math.ceil(a / b) + a_scale = torch.randint(130, (ceildiv(M, 128), ceildiv(K, 128), 32, 4, 4), dtype=torch.uint8).to(device) + b_scale = torch.randint(130, (ceildiv(N, 128), ceildiv(K, 128), 32, 4, 4), dtype=torch.uint8).to(device) + + dtype_dst = getattr(torch, dtype_dst_str) + output = torch.empty((M, N), dtype=dtype_dst, device=device) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + out = block_scale_mxfp_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, a_scale.stride(0), a_scale.stride(1), + a_scale.stride(2), a_scale.stride(3), a.stride(0), a.stride(1), b.stride(0), + b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, + NUM_STAGES=NUM_STAGES, USE_2D_SCALE_LOAD=USE_2D_SCALE_LOAD) + ttgir = out.asm["ttgir"] + ptx = out.asm["ptx"] + + def flatten_scale(scale): + num_chunk_m, num_chunk_k, _, _, _ = scale.shape + return scale.permute(0, 3, 2, 1, 4).reshape(num_chunk_m * 128, num_chunk_k * 4).contiguous() + + a_scale_f32 = flatten_scale(fp8e8m0_to_float32(a_scale))[:M] + b_scale_f32 = flatten_scale(fp8e8m0_to_float32(b_scale))[:N] + a_scale_f32 = a_scale_f32.repeat_interleave(32, dim=1) + b_scale_f32 = b_scale_f32.repeat_interleave(32, dim=1) + + # b_scales are always col major + b_scale_f32 = b_scale_f32.T.contiguous() + + a = A * a_scale_f32 + b = B * b_scale_f32 + ref_out = torch.matmul(a, b).to(torch.float32) + output = output.to(torch.float32) + atol = 0.0001 + rtol = 0.0001 + torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol) + + if USE_2D_SCALE_LOAD: + # Due to an issue in the coalescing pass, tmem_copy can not be generated for the 5D load. + # The issue is fixed using the patch from https://github.com/triton-lang/triton/pull/4914 + assert "tcgen05.cp" in ptx + if NUM_STAGES > 1: + if BLOCK_M == BLOCK_K and BLOCK_N == BLOCK_K: + load_pipelined = ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_M}x{BLOCK_K}") == 2 + else: + load_pipelined = (ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_M}x{BLOCK_K}") and + ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_K}x{BLOCK_N}")) + + if load_pipelined and USE_2D_SCALE_LOAD: + # If load is pipelined and tmem_copy is used, MMA pipelining should also kick in + assert "ttng.wait_barrier" in ttgir + elif not load_pipelined: + # The behavior of load pipelining seems to depend on the size of input tensors. + # In this test, it fails to pipeline the RHS tensor when N is not a multiple of 128. Pipelining of the LHS tensor + # does not seem to be affected by the value of M, though. + print(f"SWP failed for M = {M}, N = {N}") + + +@triton.jit +def lhs_in_tmem_kernel( # + a_ptr, b_ptr, output_ptr, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, A_TRANS: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + + if not A_TRANS: + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + else: + a_ptrs = a_ptr + (offs_k[:, None] * stride_am + offs_am[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=1): + k_remaining = K - k * BLOCK_K + valid_k = offs_k < k_remaining + m_remaining = M - pid_m * BLOCK_M + valid_m = offs_am < m_remaining + a = tl.load(a_ptrs, mask=(valid_k[None, :] & valid_m[:, None]), other=0.0) + if A_TRANS: + a = a.T + n_remaining = N - pid_n * BLOCK_N + valid_n = offs_bn < n_remaining + b = tl.load(b_ptrs, mask=(valid_k[:, None] & valid_n[None, :]), other=0.0) + accumulator = tl.dot(a, b, acc=accumulator) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(output_ptrs, accumulator, mask=mask_c) + + +@pytest.mark.parametrize("M, N, K", [(128, 64, 64), (1024, 512, 256)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128), + (128, 256, 256), (128, 128, 64), (128, 64, 128)]) +@pytest.mark.parametrize("a_trans", [False, True]) +@pytest.mark.parametrize("dtype_src_str", ["float32", "float16", "float8e5"]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10") +def test_lhs_in_tmem(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, a_trans, dtype_src_str, device, monkeypatch): + _knob_promote_lhs_to_tmem(monkeypatch) + if M != BLOCK_M or N != BLOCK_N or K != BLOCK_K: + # TODO: Make LHS TMEM promotion work for all problem sizes regardless of block dims + pytest.xfail( + "LHS TMEM promotion produces incorrect results when the workload dimensions are not equal to the block dims" + ) + torch.manual_seed(42) + if dtype_src_str == "float8e5": + a = torch.randint(20, 40, (M, K), dtype=torch.int8, device=device).view(torch.float8_e5m2) + if (a_trans): + a = a.T + b = torch.randint(20, 40, (K, N), dtype=torch.int8, device=device).view(torch.float8_e5m2) + A = f8_to_f16(a, dtype_src_str) + B = f8_to_f16(b, dtype_src_str) + else: + dtype_src = getattr(torch, dtype_src_str) + a = torch.randn(M, K, dtype=dtype_src, device=device) + if (a_trans): + a = a.T + b = torch.randn(K, N, dtype=dtype_src, device=device) + A = a + B = b + output = torch.empty((M, N), dtype=torch.float16, device=device) + grid = (1, 1) + lhs_in_tmem_kernel[grid](a, b, output, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0), + output.stride(1), M, N, K, A_TRANS=a_trans, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K) + ref_out = torch.matmul(A if not a_trans else A.T, B).to(torch.float16) + + atol = 0.03 + rtol = 0.03 + torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol) + + +@triton.jit +def lhs_in_tmem_kernel_mxfp( # + a_ptr, b_ptr, output_ptr, # + a_scale, b_scale, # + stride_scale, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): + offs_am = tl.arange(0, M) + offs_bn = tl.arange(0, N) + offs_k = tl.arange(0, K) + offs_scale_k = tl.arange(0, K // 32) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + a_scale_ptr = a_scale + offs_am[:, None] * stride_scale + offs_scale_k[None, :] + b_scale_ptr = b_scale + offs_bn[:, None] * stride_scale + offs_scale_k[None, :] + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + scale_a = tl.load(a_scale_ptr) + scale_b = tl.load(b_scale_ptr) + accumulator = tl.dot_scaled(a, scale_a, "e5m2", b, scale_b, "e5m2") + offs_cm = tl.arange(0, M) + offs_cn = tl.arange(0, N) + output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(output_ptrs, accumulator) + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10") +def test_lhs_in_tmem_mxfp(device, monkeypatch): + _knob_promote_lhs_to_tmem(monkeypatch) + M, N, K = 128, 64, 32 + torch.manual_seed(42) + a = torch.randint(20, 40, (M, K), dtype=torch.uint8, device=device) + b = torch.randint(20, 40, (K, N), dtype=torch.uint8, device=device) + A = f8_to_f16(a, "float8e5") + B = f8_to_f16(b, "float8e5") + a_scale = torch.randint(124, 130, (M, K // 32), dtype=torch.uint8, device=device) + b_scale = torch.randint(124, 130, (N, K // 32), dtype=torch.uint8, device=device) + output = torch.empty((M, N), dtype=torch.float16, device=device) + grid = (1, 1) + lhs_in_tmem_kernel_mxfp[grid](a, b, output, a_scale, b_scale, a_scale.stride(0), a.stride(0), a.stride(1), + b.stride(0), b.stride(1), output.stride(0), output.stride(1), M, N, K) + a_scale_f32 = fp8e8m0_to_float32(a_scale) + b_scale_f32 = fp8e8m0_to_float32(b_scale) + a_scale_f32 = a_scale_f32.repeat_interleave(32, dim=1) + b_scale_f32 = b_scale_f32.repeat_interleave(32, dim=1) + + # b_scales are always col major + b_scale_f32 = b_scale_f32.T.contiguous() + + a = A * a_scale_f32 + b = B * b_scale_f32 + ref_out = torch.matmul(a, b).to(torch.float16) + atol = 0.003 + rtol = 0.003 + torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol) + + +@triton.jit +def block_scale_fp4_matmul( # + a_ptr, b_ptr, output_ptr, # + a_scale, b_scale, # + M, N, K, # + stride_scale, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + VEC_SIZE: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + BLOCK_K: tl.constexpr, # + NUM_STAGES: tl.constexpr): # + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + + # Two e2m1 values per K + offs_k = tl.arange(0, BLOCK_K // 2) + offs_scale_k = tl.arange(0, BLOCK_K // VEC_SIZE) + if a_scale is not None: + a_scale_ptr = a_scale + offs_am[:, None] * stride_scale + offs_scale_k[None, :] + if b_scale is not None: + b_scale_ptr = b_scale + offs_bn[:, None] * stride_scale + offs_scale_k[None, :] + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty) + for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES): + k_remaining = tl.cdiv(K - k * BLOCK_K, 2) + valid_k = offs_k < k_remaining + a = tl.load(a_ptrs, mask=valid_k[None, :], other=0) + b = tl.load(b_ptrs, mask=valid_k[:, None], other=0) + if a_scale is not None: + scale_a = tl.load(a_scale_ptr) + else: + scale_a = None + if b_scale is not None: + scale_b = tl.load(b_scale_ptr) + else: + scale_b = None + accumulator = tl.dot_scaled(a, scale_a, "e2m1", b, scale_b, "e2m1", accumulator) + a_ptrs += (BLOCK_K // 2) * stride_ak + b_ptrs += (BLOCK_K // 2) * stride_bk + if a_scale is not None: + a_scale_ptr += BLOCK_K // VEC_SIZE + if b_scale is not None: + b_scale_ptr += BLOCK_K // VEC_SIZE + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(output_ptrs, accumulator, mask=c_mask) + + +@pytest.mark.parametrize("M, N, K", [(1024, 512, 256), (2, 4, 64)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128), + (128, 256, 256), (128, 128, 64), (128, 64, 128)]) +@pytest.mark.parametrize("with_a_scale", [True, False]) +@pytest.mark.parametrize("with_b_scale", [True, False]) +@pytest.mark.parametrize(("scale_type", "VEC_SIZE"), [("float8_e8m0fnu", 32), ("float8_e4m3fn", 16)], + ids=["mxfp4", "nvfp4"]) +@pytest.mark.parametrize("nonKDim", ([0, 16, 32] if is_hip_cdna() else [0])) +def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, with_a_scale, with_b_scale, scale_type, nonKDim, + device): + if is_cuda(): + if torch.cuda.get_device_capability()[0] < 10: + pytest.skip("Requires compute capability >= 10") + if not (with_a_scale and with_b_scale): + pytest.skip("None aScale/bScale is only tested on AMD backend for now") + elif is_hip(): + if not is_hip_mi350(): + pytest.skip("Scaled fp4 matmul is only natively supported on MI350") + if scale_type != 'float8_e8m0fnu': + pytest.skip("MI350 only supports E8M0 scale") + if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64): + pytest.skip(f"MI350 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") + + NUM_STAGES = 1 + torch.manual_seed(42) + a_mxfp4 = MXFP4Tensor(size=(M, K), device=device).random() + a = a_mxfp4.to_packed_tensor(dim=1) + # Generate b with k-major layout, pack two e2m1 along k, then logical transpose to K, N + b_mxfp4 = MXFP4Tensor(size=(N, K), device=device).random() + b = b_mxfp4.to_packed_tensor(dim=1).T + # No need to pack along K since we convert each e2m1 to f32 directly for the reference matmul + b_ref = b_mxfp4.to(torch.float32).T + + a_size = (M, (K + VEC_SIZE - 1) // VEC_SIZE) + b_size = (N, (K + VEC_SIZE - 1) // VEC_SIZE) + a_scale = torch.rand(a_size, device=device) + b_scale = torch.rand(b_size, device=device) + if scale_type == "float8_e8m0fnu": + a_scale_ref = MXScaleTensor(a_scale) + b_scale_ref = MXScaleTensor(b_scale) + a_scale = a_scale_ref.data + b_scale = b_scale_ref.data + elif scale_type == "float8_e4m3fn": + a_scale = a_scale.to(torch.float8_e4m3fn) + b_scale = b_scale.to(torch.float8_e4m3fn) + a_scale_ref = a_scale + b_scale_ref = b_scale + + a_scale_ref = a_scale_ref.to(torch.float32).repeat_interleave(VEC_SIZE, dim=1)[:M, :K] + b_scale_ref = b_scale_ref.to(torch.float32).repeat_interleave(VEC_SIZE, dim=1).T.contiguous()[:K, :N] + stride_scale = a_scale.stride(0) + if not with_a_scale: + a_scale = None + a_scale_ref = 1.0 + if not with_b_scale: + b_scale = None + b_scale_ref = 1.0 + ref_out = torch.matmul(a_mxfp4.to(torch.float32) * a_scale_ref, b_ref * b_scale_ref) + + output = a.new_empty((M, N), dtype=torch.float32) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + kernel_kwargs = {} + if is_hip(): + kernel_kwargs["matrix_instr_nonkdim"] = nonKDim + block_scale_fp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1), + b.stride(0), b.stride(1), output.stride(0), output.stride(1), VEC_SIZE, BLOCK_M, + BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES, **kernel_kwargs) + + torch.testing.assert_close(ref_out, output, atol=1e-2, rtol=1e-2) + + +@triton.jit +def mxfp8_mxfp4_matmul( # + a_ptr, b_ptr, output_ptr, # + a_scale, b_scale, # + M, N, K, # + stride_scale, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + tensor_scale: tl.constexpr, # + DTYPE_A: tl.constexpr, # + DTYPE_B: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + BLOCK_K: tl.constexpr, # + NUM_STAGES: tl.constexpr): # + DIV_FACTOR_A: tl.constexpr = 2 if DTYPE_A == "e2m1" else 1 + DIV_FACTOR_B: tl.constexpr = 2 if DTYPE_B == "e2m1" else 1 + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_ak = tl.arange(0, BLOCK_K // DIV_FACTOR_A) + offs_bk = tl.arange(0, BLOCK_K // DIV_FACTOR_B) + offs_scale_k = tl.arange(0, BLOCK_K // 32) + + if a_scale is not None: + a_scale_ptr = a_scale + offs_am[:, None] * stride_scale + offs_scale_k[None, :] + if b_scale is not None: + b_scale_ptr = b_scale + offs_bn[:, None] * stride_scale + offs_scale_k[None, :] + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty) + + for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + if a_scale is not None: + if tensor_scale: + scale_a = tl.load(a_scale_ptr) + else: + scale_a = tl.full(a_scale_ptr.shape, a_scale.to(tl.int8), dtype=tl.int8) + else: + scale_a = None + if b_scale is not None: + scale_b = tl.load(b_scale_ptr) + else: + scale_b = None + accumulator = tl.dot_scaled(a, scale_a, DTYPE_A, b, scale_b, DTYPE_B, accumulator) + a_ptrs += (BLOCK_K // DIV_FACTOR_A) * stride_ak + b_ptrs += (BLOCK_K // DIV_FACTOR_B) * stride_bk + if a_scale is not None: + a_scale_ptr += BLOCK_K // 32 + if b_scale is not None: + b_scale_ptr += BLOCK_K // 32 + + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(output_ptrs, accumulator, mask=c_mask) + + +@pytest.mark.parametrize("M, N, K", [(1024, 512, 512)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128), + (128, 256, 256), (128, 128, 64), (128, 64, 128)]) +@pytest.mark.parametrize("NUM_STAGES", [1, 3]) +@pytest.mark.parametrize("B_TRANS", [True, False]) +@pytest.mark.parametrize("CONST_SCALE", [True, False]) +@pytest.mark.parametrize("A_DATA_TYPE", ["float8e5", "float8e4nv", "float4"]) +@pytest.mark.parametrize("B_DATA_TYPE", ["float8e5", "float8e4nv", "float4"]) +@pytest.mark.parametrize("WITH_A_SCALE", [True, False]) +@pytest.mark.parametrize("WITH_B_SCALE", [True, False]) +@pytest.mark.parametrize("nonKDim", ([0, 16, 32] if is_hip_cdna() else [0])) +def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TRANS, CONST_SCALE, A_DATA_TYPE, + B_DATA_TYPE, WITH_A_SCALE, WITH_B_SCALE, nonKDim, device): + if is_cuda(): + if torch.cuda.get_device_capability()[0] < 10: + pytest.skip("Requires compute capability >= 10") + if not (WITH_A_SCALE and WITH_B_SCALE): + pytest.skip("None scale has not been tested on NV backend") + if not (A_DATA_TYPE == "float8e5" and B_DATA_TYPE == "float4"): + pytest.skip(f"(A: {A_DATA_TYPE}, B: {B_DATA_TYPE}) has not been tested on NV backend") + elif is_hip(): + if not is_hip_mi350(): + pytest.skip("Scaled mxfp4 & mxfp8 matmul is only natively supported on MI350") + if CONST_SCALE: + pytest.skip("Constant scale is not supported in AMD backend for now") + if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64): + pytest.skip(f"MI350 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") + if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE): + pytest.skip("Float4 without scale is tested in test_block_scale_fp4") + + if B_DATA_TYPE != 'float4' and B_TRANS: + pytest.skip(f'No need to transpose B for {B_DATA_TYPE}') + + if not is_hip() and BLOCK_N == 256 and BLOCK_K == 256: + NUM_STAGES = 2 + + torch.manual_seed(42) + + def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bool = False): + if dtype == "float8e5": + v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e5m2).to(device) + v_ref = f8_to_f16(v.view(torch.float8_e5m2), dtype).to(torch.float32) + elif dtype == "float8e4nv": + v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device) + v_ref = f8_to_f16(v.view(torch.float8_e4m3fn), dtype).to(torch.float32) + else: + # float4 + if transpose: + v_mxfp4 = MXFP4Tensor(size=(size0, size1), device=device).random() + v = v_mxfp4.to_packed_tensor(dim=k_dim) + v_ref = v_mxfp4.to(torch.float32) + else: + v_mxfp4 = MXFP4Tensor(size=(size1, size0), device=device).random() + v = v_mxfp4.to_packed_tensor(dim=(k_dim + 1) % 2).T + v_ref = v_mxfp4.to(torch.float32).T + return v, v_ref + + dtype_converter = {'float8e5': 'e5m2', 'float8e4nv': 'e4m3', 'float4': 'e2m1'} + + a, a_ref = create_operand(A_DATA_TYPE, M, K, 1) + b, b_ref = create_operand(B_DATA_TYPE, K, N, 0, B_TRANS) + + a_scale_mxfp4 = MXScaleTensor(size=(M, (K + 32 - 1) // 32), device=device).random(high=64.0) + b_scale_mxfp4 = MXScaleTensor(size=(N, (K + 32 - 1) // 32), device=device).random(high=64.0) + a_scale = a_scale_mxfp4.data + b_scale = b_scale_mxfp4.data + + a_scale_ref = a_scale_mxfp4.to(torch.float32).repeat_interleave(32, dim=1)[:M, :K] + if CONST_SCALE: + a_scale_ref = torch.full_like(a_scale_ref, 2.0) + a_scale = 128 # 2.0 in e8m0 + b_scale_ref = b_scale_mxfp4.to(torch.float32).repeat_interleave(32, dim=1).T.contiguous()[:K, :N] + stride_scale = b_scale.stride(0) + if not WITH_A_SCALE: + a_scale = None + a_scale_ref = 1.0 + if not WITH_B_SCALE: + b_scale = None + b_scale_ref = 1.0 + + ref_out = torch.matmul(a_ref * a_scale_ref, b_ref * b_scale_ref) + + output = a.new_empty((M, N), dtype=torch.float32) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + kernel_kwargs = {} + if is_hip(): + kernel_kwargs["matrix_instr_nonkdim"] = nonKDim + out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1), + b.stride(0), b.stride(1), output.stride(0), output.stride(1), not CONST_SCALE, + dtype_converter[A_DATA_TYPE], dtype_converter[B_DATA_TYPE], BLOCK_M, BLOCK_N, + BLOCK_K, NUM_STAGES=NUM_STAGES, **kernel_kwargs) + if is_cuda(): + ttgir = out.asm["ttgir"] + assert "fp4Padded = true" in ttgir + + torch.testing.assert_close(ref_out, output, atol=1e-3, rtol=1e-3) diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_mxfp.py b/third_party/enflame/include/triton/python/test/unit/language/test_mxfp.py new file mode 100644 index 000000000..3e0d6c050 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_mxfp.py @@ -0,0 +1,127 @@ +import pytest +import torch +from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor + + +class MXBaseTest: + + @pytest.fixture + def device(self): + return "cpu" + + +class TestMXFP4Tensor(MXBaseTest): + + @pytest.mark.parametrize("K, N", [(64, 128), (128, 256)]) + def test_roundtrip(self, K, N, device): + tensor = MXFP4Tensor(size=(K, N), device=device).random() + tensor2 = MXFP4Tensor(tensor.to(torch.float32)) + torch.testing.assert_close(tensor.data, tensor2.data) + + @pytest.mark.parametrize("K, N, dim", [(64, 128, 0), (64, 128, 1)]) + def test_packed_tensor(self, K, N, dim, device): + tensor = MXFP4Tensor(size=(K, N), device=device).random() + packed = tensor.to_packed_tensor(dim=dim) + unpacked = tensor.unpack_packed_tensor(packed, dim=dim, original_shape=(K, N)) + torch.testing.assert_close(tensor.data, unpacked) + + def test_padding(self, device): + tensor_pad = MXFP4Tensor(torch.tensor([4], device=device)) + pad_packed = tensor_pad.to_packed_tensor(dim=0) + torch.testing.assert_close(tensor_pad.data, + tensor_pad.unpack_packed_tensor(pad_packed, dim=0, original_shape=(1, ))) + + def test_zero_values(self, device): + test_values = torch.tensor([0.0, -0.0], device=device) + tensor = MXFP4Tensor(test_values) + expected_encodings = torch.tensor([0b0000, 0b1000], dtype=torch.uint8, device=device) + assert torch.equal(tensor.data, expected_encodings), "Zero values should be encoded as 0" + torch.testing.assert_close(tensor.to(torch.float32), test_values) + + def test_out_of_range_values(self, device): + test_values = torch.tensor([7.0, -7.0, float('inf'), float('-inf')], device=device) + tensor = MXFP4Tensor(test_values) + expected_values = torch.tensor([6.0, -6.0, 6.0, -6.0], device=device) + torch.testing.assert_close(tensor.to(torch.float32), expected_values) + + def test_subnormal_numbers(self, device): + test_values = torch.tensor([0.1, 0.2, 0.3, 0.4], device=device) + tensor = MXFP4Tensor(test_values) + expected_values = torch.tensor([0.0, 0.0, 0.5, 0.5], device=device) + torch.testing.assert_close(tensor.to(torch.float32), expected_values) + + def test_rounding_edge_cases(self, device): + test_values = torch.tensor([0.75, 1.25, 1.75, 2.5, 3.5, 5.0], device=device) + expected_values = torch.tensor([1.0, 1.0, 2.0, 2.0, 4.0, 4.0], device=device) + tensor = MXFP4Tensor(test_values) + torch.testing.assert_close(tensor.to(torch.float32), expected_values) + + def test_negative_values(self, device): + test_values = torch.tensor([-0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], device=device) + tensor = MXFP4Tensor(test_values) + torch.testing.assert_close(tensor.to(torch.float32), test_values) + + def test_negative_out_of_range(self, device): + tensor = MXFP4Tensor(torch.tensor([-7.0, -8.0, -10.0], device=device)) + expected_values = torch.tensor([-6.0, -6.0, -6.0], device=device) + torch.testing.assert_close(tensor.to(torch.float32), expected_values) + + @pytest.mark.parametrize("shape, dim", [ + ((1024, ), 0), + ((128, 256), 0), + ((128, 256), 1), + ((64, 64, 64), 2), + ]) + def test_packing(self, shape, dim, device): + tensor = MXFP4Tensor(size=shape, device=device).random() + packed = tensor.to_packed_tensor(dim=dim) + unpacked = tensor.unpack_packed_tensor(packed, dim=dim, original_shape=shape) + torch.testing.assert_close(tensor.data, unpacked) + + def test_packing_with_padding(self, device): + shape = (7, 5) + dim = 1 + tensor = MXFP4Tensor(size=shape, device=device).random() + packed = tensor.to_packed_tensor(dim=dim) + unpacked = tensor.unpack_packed_tensor(packed, dim=dim, original_shape=shape) + torch.testing.assert_close(tensor.data, unpacked) + + def test_invalid_packing_dimension(self, device): + tensor = MXFP4Tensor(size=(4, 4), device=device).random() + with pytest.raises(AssertionError): + tensor.to_packed_tensor(dim=2) # Invalid dimension + + def test_empty_tensor(self, device): + tensor = MXFP4Tensor(torch.tensor([], device=device)) + assert tensor.to(torch.float32).numel() == 0 + + +class TestMXScaleTensor(MXBaseTest): + + def test_positive_values(self, device): + values = torch.tensor([1.0, 2.0, 4.0, 8.0], device=device) + data = MXScaleTensor(values) + torch.testing.assert_close(data.to(torch.float32), values) + + def test_special_values(self, device): + values = torch.tensor([0.0, -1.0, float('nan'), float('inf'), float('-inf')], device=device) + tensor = MXScaleTensor(values) + expected_data = torch.tensor([255, 255, 255, 255, 255], dtype=torch.uint8, device=device) + assert torch.equal(expected_data, tensor.data), "Special values should be encoded as NaN (255)" + + def test_e8m0_nan_to_float_nan(self, device): + tensor = MXScaleTensor(size=(1, ), device=device) + tensor.data = torch.tensor([255], device=device, dtype=torch.uint8) + assert torch.isnan(tensor.to(torch.float32)), "E8M0 NaN encoding should convert to float32 NaN" + + def test_random_generation(self, device): + data = MXScaleTensor(size=(1000, ), device=device).random() + data = data.data + assert ((data >= 0) & (data <= 254)).all(), "Generated data should be between 0 and 254" + assert (data != 255).all(), "Generated data should not include NaN encoding (255)" + + @pytest.mark.parametrize("K, N", [(64, 128), (128, 256)]) + def test_roundtrip(self, K, N, device): + tensor = MXScaleTensor(size=(K, N), device=device).random() + tensor2 = MXScaleTensor(tensor.to(torch.float32)) + torch.testing.assert_close(tensor.data, tensor2.data) diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_pipeliner.py b/third_party/enflame/include/triton/python/test/unit/language/test_pipeliner.py new file mode 100644 index 000000000..0831eeea1 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_pipeliner.py @@ -0,0 +1,548 @@ +# End-to-end tests to check the correctness of the pipeliner + +import pytest +import torch +import triton +import triton.language as tl +import triton.tools.experimental_descriptor + +from triton._internal_testing import is_cuda, is_hopper, is_hip_cdna, is_hip_mi200, is_hip + + +def check_capabilities(): + if is_cuda(): + cc = torch.cuda.get_device_capability() + if cc[0] < 8: + pytest.skip("CUDA 8.0+ required") + + +@triton.jit +def matmul_kernel( # + a_ptr, scale_ptr, b_ptr, output_ptr, # + M, N, K_MXFP, # K_MXFP is the number of mxfp vectors in a row of a. Otherwise it's just K + stride_am, stride_ak, # + stride_sm, stride_sk, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_STAGES: tl.constexpr, a_type: tl.constexpr, b_type: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + IS_SCALED: tl.constexpr = a_type is not None and b_type is not None + DIV_FACTOR: tl.constexpr = 2 if IS_SCALED and a_type == "e2m1" else 1 + # We pass K_MXFP to make explicit that KB is multiple of 32 and KA is multiple of 16 or 32 + # for the pipeliner divisibility condition + KA = K_MXFP if not IS_SCALED else K_MXFP * (32 // DIV_FACTOR) + KB = K_MXFP if not IS_SCALED else K_MXFP * 32 + BLOCK_AK: tl.constexpr = BLOCK_K // DIV_FACTOR + offs_k = tl.arange(0, BLOCK_K) + offs_ak = tl.arange(0, BLOCK_AK) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + if IS_SCALED: + BLOCK_SK: tl.constexpr = BLOCK_K // 32 + offs_sk = tl.arange(0, BLOCK_SK) + scale_ptrs = scale_ptr + (offs_am[:, None] * stride_sm + offs_sk[None, :] * stride_sk) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(KB, BLOCK_K), num_stages=NUM_STAGES): + mask_a = (offs_am[:, None] < M) & (offs_ak[None, :] + k * BLOCK_AK < KA) + mask_b = ((offs_k[:, None] + k * BLOCK_K) < KB) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=mask_a, other=0) + b = tl.load(b_ptrs, mask=mask_b, other=0) + if IS_SCALED: + # Adapted scale indexing and dot_scaled operation + mask_scale = (offs_am[:, None] < M) & (offs_sk[None, :] + k * BLOCK_SK < K_MXFP) + a_scale = tl.load(scale_ptrs, mask=mask_scale, other=0) + accumulator = tl.dot_scaled(a, a_scale, a_type, b, None, b_type, acc=accumulator) + else: + accumulator = tl.dot(a, b, acc=accumulator) + a_ptrs += BLOCK_AK * stride_ak + b_ptrs += BLOCK_K * stride_bk + if IS_SCALED: + scale_ptrs += BLOCK_SK * stride_sk + OUT_DTYPE = tl.bfloat16 if IS_SCALED else tl.float16 + accumulator = accumulator.to(OUT_DTYPE) + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(output_ptrs, accumulator, mask=mask_c) + + +@triton.jit +def matmul_kernel_tma( # + a_ptr, b_ptr, output_ptr, # + M, N, K, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_STAGES: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_M) % M + offs_bn = (pid_n * BLOCK_N) % N + offs_am = tl.multiple_of(offs_am, BLOCK_M) + offs_bn = tl.multiple_of(offs_bn, BLOCK_N) + offs_k = 0 + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for _ in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES): + a = tl._experimental_descriptor_load(a_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], tl.float16) + b = tl._experimental_descriptor_load(b_ptr, [offs_k, offs_bn], [BLOCK_K, BLOCK_N], tl.float16) + accumulator = tl.dot(a, b, acc=accumulator) + offs_k += BLOCK_K + accumulator = accumulator.to(tl.float16) + tl._experimental_descriptor_store(output_ptr, accumulator, [offs_am, offs_bn]) + + +@triton.jit +def vecadd_kernel(a_ptr, b_ptr, output_ptr, n_elements, num_blocks, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE * num_blocks + offsets = block_start + tl.arange(0, BLOCK_SIZE) + for _ in tl.range(0, num_blocks, num_stages=NUM_STAGES): + mask = offsets < n_elements + x = tl.load(a_ptr + offsets, mask=mask) + y = tl.load(b_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + offsets += BLOCK_SIZE + + +@triton.jit +def mxfp_to_bf16_kernel( + x_ptr, + scale_ptr, + mxfp_ptr, + N, + e_bits: tl.constexpr, + m_bits: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + # x.shape == (N, 32) for fp8 or (N, 16) for fp4 + # scale.shape == (N,) + # out.shape == (N, 32) + is_fp8: tl.constexpr = e_bits + m_bits == 7 + # fp8: BLOCK_SIZE -> BLOCK_SIZE // 32, 32 + # fp4: BLOCK_SIZE // 2 -> BLOCK_SIZE // 32 , 16 + PARALLEL_DIM: tl.constexpr = BLOCK_SIZE // 32 + LAST_DIM: tl.constexpr = 32 if is_fp8 else 16 + LOAD_SIZE: tl.constexpr = LAST_DIM * PARALLEL_DIM + + offsets = (tl.program_id(0) * LOAD_SIZE + tl.arange(0, PARALLEL_DIM)[:, None] * LAST_DIM + + tl.arange(0, LAST_DIM)[None, :]) + x = tl.load(x_ptr + offsets, mask=offsets < N * LAST_DIM) + + offsets = tl.program_id(0) * PARALLEL_DIM + tl.arange(0, PARALLEL_DIM)[:, None] + scale = tl.load(scale_ptr + offsets, mask=offsets < N) + tl.static_assert(scale.dtype == tl.uint8) + tl.static_assert(x.dtype == tl.uint8) + + scale_bf16 = (scale.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True) + if is_fp8: + if e_bits == 5 and m_bits == 2: + x_f8 = x.to(tl.float8e5, bitcast=True) + x_bf16 = x_f8.to(tl.bfloat16) + # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! + non_finite_mask: tl.constexpr = ((1 << e_bits) - 1) << m_bits + non_finite_mask_bf16: tl.constexpr = ((1 << 8) - 1) << 7 + x_bf16 = tl.where( + x & non_finite_mask == non_finite_mask, + (x_bf16.to(tl.uint16, bitcast=True) | non_finite_mask_bf16).to(tl.bfloat16, bitcast=True), + x_bf16, + ) + else: + tl.static_assert(e_bits == 4 and m_bits == 3) + x_f8 = x.to(tl.float8e4nv, bitcast=True) + x_bf16 = x_f8.to(tl.bfloat16) + else: + # e2m1 + em0 = x & 0x7 + em1 = x & 0x70 + x0 = (em0.to(tl.uint16) << 2 + 4) | ((x & 0x8).to(tl.uint16) << 8 + 4) + x1 = (em1.to(tl.uint16) << (2)) | ((x & 0x80).to(tl.uint16) << (8)) + # Three cases: + # 1) x is normal and non-zero: Correct bias + x0 = tl.where((em0 & 0x6) != 0, x0 + ((127 - 1) << 7), x0) + x1 = tl.where((em1 & 0x60) != 0, x1 + ((127 - 1) << 7), x1) + # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16 + x0 = tl.where(em0 == 0x1, 16128 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x10, 16128 | (x1 & 0x8000), x1) + # 3) x is zero, do nothing + x_bf16 = tl.interleave(x0, x1).to(tl.bfloat16, bitcast=True) + # Multiplication preserves infs and NaNs in x_bf16 + mxfp = x_bf16 * scale_bf16 + # If scale is NaN, we encode it as an bf16 inf, so we need to correct for that + mxfp = tl.where(scale == 0xFF, float("nan"), mxfp) + + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32) + + +def dot_scale_ref(x, scale, y, type_x, type_y): + e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x] + type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type_y] + + out_dtype = torch.bfloat16 + + x = x.contiguous() + x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=out_dtype) + + N = x_upcast.numel() + BLOCK_SIZE = 512 + grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) + mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=4) + y_upcast = y if type_y == "bf16" else y.view(type_fp8_y).to(out_dtype) + assert x_upcast.dtype == out_dtype + assert y_upcast.dtype == out_dtype + + class AccumulateInFp32: + + def __enter__(self): + self.prev_value = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value + + with AccumulateInFp32(): + return torch.matmul(x_upcast, y_upcast) + + +@pytest.mark.parametrize("scale", [True, False]) +def test_pipeline_matmul(scale, device): + check_capabilities() + if scale and not (is_cuda() or is_hip_cdna()): + pytest.skip("NYI: scale_dot just implemented in CUDA/HIP") + M, N, K = 512, 512, 128 + BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32 + NUM_STAGES = 4 + + if scale: + # Large enough tile to let our heuristics to pipeline small tensor kick in + # for the scales + BLOCK_M = 256 + BLOCK_K = 128 + K = BLOCK_K * NUM_STAGES + a_type = "e2m1" + DIV_FACTOR = 2 if a_type == "e2m1" else 1 + a = torch.randint(256, (M, K // DIV_FACTOR), device=device, dtype=torch.uint8) + # Sample small-ish scales to avoid overflow + scale_a = torch.randint(74, (M, K // 32), device=device, dtype=torch.uint8) + # Use e5m2 for Ampere, as it does not support fp_to_fp conversions for fp8e4m3 + # Use bf16 for Hopper as the rhs must come from shmem + b_type = "bf16" if is_hopper() else "e5m2" + if b_type == "bf16": + b = torch.randn((K, N), device=device, dtype=torch.bfloat16) + else: + b = torch.randint(256, (K, N), device=device, dtype=torch.uint8) + # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and + # Fp8E5M2_to_Bf16 doesn't preserve NaNs (fixme) + finite = torch.arange(K * N, device=device, dtype=torch.uint8).reshape(K, N) % 0x7C + b = torch.where(b & 0x7C == 0x7C, finite | (0x80 & b), b) + output = torch.empty((M, N), dtype=torch.bfloat16, device=device) + else: + a = torch.randn(M, K, device=device, dtype=torch.float16) + b = torch.randn(K, N, device=device, dtype=torch.float16) + scale_a = None + a_type, b_type = None, None + output = torch.empty((M, N), dtype=torch.float16, device=device) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + use_tma = not scale and is_hopper() + + if use_tma: + a_tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor(a.data_ptr(), M, K, BLOCK_M, BLOCK_K, + a.element_size()) + b_tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor(b.data_ptr(), K, N, BLOCK_K, BLOCK_N, + b.element_size()) + output_tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor(output.data_ptr(), M, N, BLOCK_M, + BLOCK_N, output.element_size()) + handler = matmul_kernel_tma[grid](a_tma, b_tma, output_tma, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, + NUM_STAGES=NUM_STAGES) + else: + # Pass K_MXFP to make explicit that KB is multiple of 32 and KA is multiple of 16 or 32º + if scale: + K = scale_a.shape[-1] + stride_sm, stride_sk = scale_a.stride() if scale else (0, 0) + handler = matmul_kernel[grid](a, scale_a, b, output, M, N, K, a.stride(0), a.stride(1), stride_sm, stride_sk, + b.stride(0), b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, + BLOCK_K, NUM_STAGES=NUM_STAGES, a_type=a_type, b_type=b_type) + if scale: + ref_out = dot_scale_ref(a, scale_a, b, a_type, b_type) + else: + ref_out = torch.matmul(a, b) + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and + # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + atol = 1e-2 if is_hip_mi200() or scale else None + rtol = 1e-2 if is_hip_mi200() or scale else None + torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol, equal_nan=scale) + if is_cuda(): + ttgir = handler.asm["ttgir"] + if use_tma: + assert ttgir.count("ttng.async_tma_copy_global_to_local") != 0, "async tma copy not found" + assert ttgir.count(f"num = {NUM_STAGES} : i32") == 0, "num_stages not match" + assert ttgir.count("ttng.barrier_expect") != 0, "barrier_expect not found" + assert ttgir.count("ttng.wait_barrier") != 0, "wait_barrier not found" + + if torch.cuda.get_device_capability()[0] == 9: + # a_tma, b_tma, output_tma, barriar_tma + assert ttgir.count("ttg.local_alloc") == 4, "alloc number not match" + assert ttgir.count("ttng.warp_group_dot") != 0, "warp_group_dot not found" + elif torch.cuda.get_device_capability()[0] == 10: + # a_tma, b_tma, output_tma, barriar_tma, barriar_mma + assert ttgir.count("ttg.local_alloc") == 5, "alloc number not match" + assert ttgir.count("ttng.tc_gen5_mma") != 0, "warp_group_dot not found" + else: + # 1. check async + assert ttgir.count("ttg.async_copy_global_to_local") != 0, "async copy not found" + # 2. check sync point + assert ttgir.count("num = 0 : i32") == 1, "only one sync point for the loads after the loop" + # 3. check alloc + if torch.cuda.get_device_capability()[0] == 10: + if scale: + # A, B, scale, decomposed A shmem + # MMA pipelining fails to identify the MMA pattern in this case, so the barrier is not inserted. + count = 4 + else: + # A, B, MMA barrier + count = 3 + assert ttgir.count("ttg.local_alloc") == count, "alloc number not match" + else: + assert ttgir.count("ttg.local_alloc") == (3 if scale else 2), "alloc number not match" + + # 4. check dot + cc = torch.cuda.get_device_capability() + if cc[0] == 9: + assert ttgir.count("ttng.warp_group_dot") != 0, "warp_group_dot not found" + elif cc[0] < 9: + assert ttgir.count("ttg.dot") != 0, "dot not found" + + +def test_pipeline_vecadd(device): + check_capabilities() + SIZE = 4096 + NUM_BLOCKS = 4 + BLOCK_SIZE = 256 + NUM_STAGES = 3 + a = torch.randn(SIZE, dtype=torch.float16, device=device) + b = torch.randn(SIZE, dtype=torch.float16, device=device) + output = torch.empty(SIZE, dtype=torch.float16, device=device) + grid = (triton.cdiv(SIZE, NUM_BLOCKS * BLOCK_SIZE), 1) + handler = vecadd_kernel[grid](a, b, output, SIZE, NUM_BLOCKS, BLOCK_SIZE, NUM_STAGES) + ref_out = a + b + torch.testing.assert_close(ref_out, output) + if is_cuda(): + ttgir = handler.asm["ttgir"] + # 1. check number of stages + assert ttgir.count("ttg.async_copy_global_to_local") / 2 == NUM_STAGES, "num_stages not match" + # 2. check alloc + assert ttgir.count("ttg.local_alloc") == 2, "alloc number not match" + + +@pytest.mark.parametrize("ROW_COUNT", [0, 1, 2, 3]) +@pytest.mark.parametrize("NUM_STAGES", [1, 2, 3, 4, 5]) +def test_pipeline_epilogue(ROW_COUNT, NUM_STAGES, device): + + @triton.jit + def kernel_up(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, + NUM_STAGES: tl.constexpr): + row_step = tl.num_programs(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + for row_idx in tl.range(0, n_rows, row_step, num_stages=NUM_STAGES): + row_start_ptr = input_ptr + row_idx * input_row_stride + input_ptrs = row_start_ptr + col_offsets + val = tl.load(input_ptrs, mask=mask, other=-float('inf')) + val += 1.0 + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, val, mask=mask) + + width = ROW_COUNT + depth = 78 + x = torch.zeros(width, depth, device=device) + y0 = torch.rand_like(x) + n_rows, n_cols = x.shape + BLOCK_SIZE = triton.next_power_of_2(n_cols) + kernel_up[(1, )](y0, x, x.stride(0), y0.stride(0), n_rows, n_cols, BLOCK_SIZE, NUM_STAGES) + assert (y0 == torch.ones_like(x)).all() + + +def random_bfloat16(shape, device): + """ + Creates a random bfloat16 tensor where every element is a multiple of 1/8. + This should avoid floating-point errors in downstream calculations, allowing + for exact comparisons. + """ + + X = torch.randn(shape, device=device, dtype=torch.bfloat16) + X *= 8.0 + X = torch.round(X) + X *= 0.125 + return X + + +@triton.jit +def indirect_matmul_kernel( + Out, + stride_out1, + A, + stride_a1, + B, + stride_b1, + Indices, + K, + + # output tile size: + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, +): + index_ptrs = Indices + tl.arange(0, BLOCK_K) + + m_offs = tl.arange(0, BLOCK_M) + n_offs = tl.arange(0, BLOCK_N)[None, :] + + A_ptrs = A + n_offs + B_ptrs = B + m_offs + + acc = tl.zeros([BLOCK_M, BLOCK_N], tl.float32) + for k in range(0, K, BLOCK_K): + idx = tl.load(index_ptrs) + + a = tl.load(A_ptrs + idx[:, None] * stride_a1) + b = tl.load(B_ptrs + idx[:, None] * stride_b1) + + acc = tl.dot(b.T, a, acc=acc) + index_ptrs += BLOCK_K + + # now write out the accumulator: + Out_ptrs = Out + m_offs[:, None] + n_offs * stride_out1 + tl.store(Out_ptrs, acc) + + +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (128, 128, 64), (128, 64, 128)]) +@pytest.mark.parametrize("num_stages", [1, 3, 5]) +def test_indirect_matmul(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, device): + if num_stages > 3 and is_hip(): + pytest.skip("Not enough shared memory on HIP.") + M = BLOCK_M + N = BLOCK_N + + K = BLOCK_K * 2 + A = random_bfloat16((K, N), device=device) + B = random_bfloat16((K, M), device=device) + + # Use arange for indices so it's numerically just a matmul + Indices = torch.arange(K, device=device) + Out = torch.empty((N, M), device=device, dtype=torch.float32) + + expect = torch.matmul(A.mT.to(torch.float32), B.to(torch.float32)) + + indirect_matmul_kernel[(1, )]( + Out, + Out.stride(0), + A, + A.stride(0), + B, + B.stride(0), + Indices, + K, + BLOCK_M, + BLOCK_K, + BLOCK_N, + num_warps=4, + num_stages=num_stages, + ) + torch.testing.assert_close(expect, Out) + + +@triton.jit +def matmul_kernel_persistent_scatter(a_ptr, b_ptr, c_ptr, # + M, N, K, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: tl.constexpr): # + # Matmul using TMA and device-side descriptor creation + dtype = c_ptr.dtype.element_ty + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + b_desc = tl._experimental_make_tensor_descriptor( + b_ptr, + shape=[N, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + c_desc = tl._experimental_make_tensor_descriptor( + c_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[1, BLOCK_SIZE_N], + ) + + for tile_id in range(start_pid, num_tiles, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + c = accumulator.to(dtype) + c_desc.scatter(c, offs_am + tl.arange(0, BLOCK_SIZE_M), offs_bn) + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, + reason="TMA Scatter only works on cloud Blackwell Chips") +def test_scatter_pipeline(device): + + def alloc_fn(size, alignment, stream): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + M, N, K, = 1024, 1024, 1024 + BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32 + GROUP_SIZE_M = 4 + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + grid_x = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)) + + a = torch.randn(M, K, device=device, dtype=torch.float16) + b = torch.randn(N, K, device=device, dtype=torch.float16) + c = torch.empty((M, N), device=device, dtype=torch.float16) + + kernel = matmul_kernel_persistent_scatter[(grid_x, )](a, b, c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_SIZE_M, + NUM_SMS) + + ref = torch.matmul(a, b.T) + torch.testing.assert_close(c, ref) + + assert kernel.asm["ttgir"].count("tma_store_wait") == 2, "expected pipelined TMA scatter" diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_random.py b/third_party/enflame/include/triton/python/test/unit/language/test_random.py new file mode 100644 index 000000000..a34691f4b --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_random.py @@ -0,0 +1,273 @@ +import numpy as np +import pytest +import scipy.stats +import torch + +import triton +import triton.language as tl + +##################################### +# Reference Philox Implementation +##################################### + + +class PhiloxConfig: + + def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): + self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) + self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE) + self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE) + self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE) + self.DTYPE = DTYPE + + +# This is better for GPU +PHILOX_32 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B9, + PHILOX_KEY_B=0xBB67AE85, + PHILOX_ROUND_A=0xD2511F53, + PHILOX_ROUND_B=0xCD9E8D57, + DTYPE=np.uint32, +) + +# This is what numpy implements +PHILOX_64 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B97F4A7C15, + PHILOX_KEY_B=0xBB67AE8584CAA73B, + PHILOX_ROUND_A=0xD2E7470EE14C6C93, + PHILOX_ROUND_B=0xCA5A826395121157, + DTYPE=np.uint64, +) + + +class CustomPhilox4x: + + def __init__(self, seed, config): + self._config = config + seed = self._into_pieces(seed) + self._key = np.array(seed[:2], dtype=self._dtype) + self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype) + + @property + def _dtype(self): + return self._config.DTYPE + + def _into_pieces(self, n, pad=4): + res = [] + bits = np.dtype(self._dtype).itemsize * 8 + while len(res) < pad: + res.append(np.array((n & ((1 << bits) - 1)), dtype=self._dtype)) + n >>= bits + assert n == 0 + return tuple(res) + + def _multiply_low_high(self, a, b): + low = a * b + high = int(a) * int(b) + high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype) + return low, high + + def _single_round(self, counter, key): + lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0]) + lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2]) + ret0 = hi1 ^ counter[1] ^ key[0] + ret1 = lo1 + ret2 = hi0 ^ counter[3] ^ key[1] + ret3 = lo0 + return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) + + def _raise_key(self, key): + pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] + return key + np.array(pk, dtype=self._dtype) + + def random_raw(self): + counter = self._counter + key = self._key + for _ in range(10): + counter = self._single_round(counter, key) + key = self._raise_key(key) + self.advance(1) + return counter + + def advance(self, n_steps): + self._counter[0] += n_steps + assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets" + + +class CustomPhilox(CustomPhilox4x): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.buffer = [] + + def random_raw(self): + if len(self.buffer) == 0: + self.buffer = list(super().random_raw())[::-1] + return int(self.buffer.pop()) + + +##################################### +# Unit Tests +##################################### + +BLOCK = tl.constexpr(1024) + +# test generation of random uint32 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in ['10', '4,53', '400'] + for seed in [0, 42, 124, 54, 0xffffffff, 0x0000000fcafeb0ba] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_randint(size, seed, device, dtype, const_seed): + size = list(map(int, size.split(','))) + torch_dtype = getattr(torch, dtype) + numpy_dtype = getattr(np, f"u{dtype}") + config = {'int32': PHILOX_32, 'int64': PHILOX_64}[dtype] + + @triton.jit + def kernel(X, N, seed): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch_dtype, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK.value), ) + if const_seed: + const_kernel[grid](x, N, seed=seed) + else: + kernel[grid](x, N, seed) + out_tri = x.cpu().numpy().astype(numpy_dtype).flatten().tolist() + # reference result + gen = CustomPhilox4x(seed, config=config) + out_ref = [gen.random_raw()[0] for _ in out_tri] + assert out_tri == out_ref + + +# test uniform PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_rand(size, seed, dtype, device, const_seed): + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK.value), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert all((x >= 0) & (x <= 1)) + assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 + + +def test_seed_is_int(device): + + @triton.jit + def kernel(X, seed): + offset = tl.arange(0, 1) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand) + + x = torch.empty(1, dtype=torch.float32, device=device) + with pytest.raises(triton.compiler.errors.CompilationError): + seed0 = torch.zeros(1, dtype=torch.int32, device=device) + kernel[(1, )](x, seed0) + with pytest.raises(triton.compiler.errors.CompilationError): + seed1 = 2.3 + kernel[(1, )](x, seed1) + + +# test normal PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_randn(size, seed, dtype, device, const_seed): + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK.value), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert abs(x.mean()) < 1e-2 + assert abs(x.std() - 1) < 1e-2 + + +# tl.rand() should never produce >=1.0 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('dtype', ['int32', 'int64']) +def test_rand_limits(dtype, device): + + @triton.jit + def kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = tl.random.uint_to_uniform_float(x) + tl.store(output + idx, y) + + torch_dtype = getattr(torch, dtype) + min_max_int = torch.tensor([ + torch.iinfo(torch_dtype).min, + torch.iinfo(torch_dtype).max, + ], dtype=torch_dtype, device=device) + output = torch.empty(2, dtype=torch.float32, device=device) + kernel[(1, )](min_max_int, output, 2) + + assert output[0] == output[1] + assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0 diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_reproducer.py b/third_party/enflame/include/triton/python/test/unit/language/test_reproducer.py new file mode 100644 index 000000000..a045e8f30 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_reproducer.py @@ -0,0 +1,42 @@ +import os +import shutil + +import pytest + +import torch +import triton +import re + + +@triton.jit +def triton_(): + return + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") +def test_reproducer(): + tmpdir = ".tmp" + reproducer = 'triton-reproducer.mlir' + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) + os.environ["TRITON_CACHE_DIR"] = tmpdir + os.environ["TRITON_REPRODUCER_PATH"] = reproducer + triton_[(1, )]() + foundPipeline = "" + with open(reproducer, 'r') as f: + line = f.read() + if 'pipeline:' in line: + foundPipeline = line + if 0 == len(foundPipeline): + raise Exception("Failed to find pipeline info in reproducer file.") + + ttgir_to_llvm_pass = re.compile("convert-triton-{{.*}}gpu-to-llvm") + if ttgir_to_llvm_pass.search(foundPipeline): + raise Exception("Failed to find triton passes in pipeline") + # cleanup + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_standard.py b/third_party/enflame/include/triton/python/test/unit/language/test_standard.py new file mode 100644 index 000000000..df5784d92 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_standard.py @@ -0,0 +1,132 @@ +import triton +import pytest +import torch +import triton.language as tl + +from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random + +# --------------- +# test maximum/minimum ops +# --------------- + + +# TODO: Tests with unsigned integers failed at compilation stage. +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", int_dtypes + uint_dtypes + float_dtypes + ["bfloat16"]) +@pytest.mark.parametrize("op", ["maximum", "minimum"]) +def test_maximum_minium(dtype, op, device): + expr = f'tl.{op}(x, y)' + numpy_expr = f'np.{op}(x, y)' + _test_binary(dtype, dtype, expr, numpy_expr, device=device) + + +# --------------- +# test sort op +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16']) +def test_sort(M, N, descending, dtype_str, device): + + @triton.jit + def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.sort(x, descending=descending) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.sort(x, descending=descending)[0] + z = torch.empty_like(x) + sort_kernel[(1, )](x, z, N, M, descending, num_warps=8) + assert (y == z).all(), (y, z) + + +# --------------- +# test flip op +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16']) +def test_flip(M, N, dtype_str, device): + + @triton.jit + def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.flip(x) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.flip(x, (1, )) + z = torch.empty_like(x, device=device) + flip_kernel[(1, )](x, z, N, M, num_warps=8) + assert (y == z).all(), (y, z) + + +@pytest.mark.interpreter +def test_flip_inf(device): + # Reproducer for https://github.com/triton-lang/triton/issues/5439 + + @triton.jit + def triton_flip_kernel(out_ptr, x_ptr, N: tl.constexpr): + pid = tl.program_id(0) + x = tl.load(x_ptr + pid * N + tl.arange(0, N)) + shape: tl.constexpr = (N // 2, 2) + y = x.reshape(shape) + y = tl.flip(y, dim=1).reshape(x.shape) + tl.store(out_ptr + pid * N + tl.arange(0, N), y) + + x = torch.arange(0, 16, device=device).unsqueeze(0).float() + x[:, -1] = float('inf') + + expect = x.reshape(-1, 8, 2).flip(-1).reshape(-1, 16) + actual = torch.empty_like(x) + triton_flip_kernel[(x.shape[0], )](actual, x, x.shape[1]) + + torch.testing.assert_close(expect, actual) + + +@pytest.mark.interpreter +def test_ravel(device): + + @triton.jit + def triton_ravel(out_ptr): + a = tl.arange(0, 256) + a = tl.reshape(a, (32, 8)) + a = tl.ravel(a) + tl.store(out_ptr + tl.arange(0, 256), a) + + out = torch.empty((256, ), device=device, dtype=torch.int32) + triton_ravel[(1, )](out) + + assert (out == torch.arange(0, 256, device=device)).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("size_i, size_j, size_g", [[5, 7, 3]]) +def test_swizzle2d(size_i, size_j, size_g, device): + + @triton.jit + def swizzle2d_kernel(output, size_i, size_j, size_g): + for i in tl.range(0, size_i, 1): + for j in tl.range(0, size_j, 1): + new_i, new_j = tl.swizzle2d(i, j, size_i, size_j, size_g) + tl.store(output + new_i * size_j + new_j, i * size_j + j) + + output = torch.zeros(size_i, size_j).to(device) + swizzle2d_kernel[(1, )](output, size_i, size_j, size_g) + expected_order = torch.tensor([[0, 3, 6, 9, 12, 15, 18], [1, 4, 7, 10, 13, 16, 19], [2, 5, 8, 11, 14, 17, 20], + [21, 23, 25, 27, 29, 31, 33], [22, 24, 26, 28, 30, 32, 34]]).to(device) + assert (output == expected_order).all(), (output, expected_order) diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_subprocess.py b/third_party/enflame/include/triton/python/test/unit/language/test_subprocess.py new file mode 100644 index 000000000..f1e415bbb --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_subprocess.py @@ -0,0 +1,127 @@ +import itertools +import os +import subprocess +import sys +from collections import Counter + +import triton + +import pytest + +dir_path = os.path.dirname(os.path.realpath(__file__)) +print_path = os.path.join(dir_path, "print_helper.py") +torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +# TODO: Print with multiple operands + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func_type, data_type", [(fn, data_type) + for fn in ["device_print", "device_print_scalar"] + for data_type in torch_types] + [ + ("print", "int32"), + ("static_print", "int32"), + ("no_arg_print", "int32"), + ("print_no_arg", "int32"), + ("device_print_large", "int32"), + ("print_multiple_args", "int32"), + ("device_print_multiple_args", "int32"), + ("device_print_hex", "int16"), + ("device_print_hex", "int32"), + ("device_print_hex", "int64"), + ("device_print_pointer", "int32"), + ("device_print_negative", "int32"), + ("device_print_uint", "uint32"), + ("device_print_2d_tensor", "int32"), + ]) +def test_print(func_type: str, data_type: str, device: str): + proc = subprocess.run( + [sys.executable, print_path, "test_print", func_type, data_type, device], + capture_output=True, + ) + assert proc.returncode == 0 + + if is_interpreter() and func_type != "static_assert": + # Interpreter uses a different format for device_print + # Only check if there's no error + assert proc.stderr == b'' + return + + outs = [line for line in proc.stdout.decode("UTF-8").splitlines() if line] + # The total number of elements in the 1-D tensor to print. + N = 128 + + # Constant for testing the printing of scalar values + SCALAR_VAL = 42 + + # Format is + # pid (, , ) idx (, , ...) (operand ) + expected_lines = Counter() + if func_type in ("print", "device_print", "device_print_uint"): + for i in range(N): + offset = (1 << 31) if data_type == "uint32" else 0 + line = f"pid (0, 0, 0) idx ({i:3}) x: {i + offset}" + if data_type.startswith("float"): + line += ".000000" + expected_lines[line] = 1 + elif func_type == "device_print_scalar": + line = f"pid (0, 0, 0) idx () x: {SCALAR_VAL}" + if data_type.startswith("float"): + line += ".000000" + expected_lines[line] = N + elif func_type == "device_print_negative": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: {-i}" + expected_lines[line] = 1 + elif func_type == "device_print_hex": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: 0x" + if data_type == "int16": + line += f"{i:04x}" + if data_type == "int32": + line += f"{i:08x}" + if data_type == "int64": + line += f"{i:016x}" + expected_lines[line] = 1 + elif func_type == "static_print": + expected_lines[f" int32[constexpr[{N}]]"] = 1 + elif func_type == "no_arg_print": + expected_lines["pid (0, 0, 0) idx (): 0"] = N + elif func_type == "print_no_arg": + expected_lines["pid (0, 0, 0) no arg"] = N + elif func_type == "device_print_large": + for i, j, k in itertools.product(range(2), range(64), range(N)): + expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1 + elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1 + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1 + elif func_type == "device_print_pointer": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}) ptr: 0x"] = 1 + elif func_type == "device_print_2d_tensor": + warp_size = triton.runtime.driver.active.get_current_target().warp_size + x_dim = N // warp_size + y_dim = warp_size + for x in range(x_dim): + for y in range(y_dim): + expected_lines[f"pid (0, 0, 0) idx ({x}, {y:2}): {(x * y_dim + y)}"] = 1 + + actual_lines = Counter() + for line in outs: + # Trim the exact pointer address in the output--they can change per run. + line = (line.split(':')[0] + ": 0x") if func_type == "device_print_pointer" else line + actual_lines[line] += 1 + + diff = Counter(actual_lines) + diff.subtract(expected_lines) + for line, delta in diff.items(): + if delta == 0: + continue + print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') + assert all(delta == 0 for delta in diff.values()) diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_tuple.py b/third_party/enflame/include/triton/python/test/unit/language/test_tuple.py new file mode 100644 index 000000000..0c08413c3 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_tuple.py @@ -0,0 +1,164 @@ +import pytest +import triton +import triton.language as tl +from typing import NamedTuple +import torch + + +@triton.jit +def _tuple_increment(values): + for i in tl.static_range(len(values)): + values[i] = values[i] + 1 + return values + + +@triton.jit +def _tuple_index_func(Ptrs, values): + for i in tl.static_range(len(values)): + tl.store(Ptrs[i], values[i]) + + +@triton.jit +def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4): + values = _tuple_increment(values) + _tuple_index_func(Ptrs, values) + + +@pytest.mark.parametrize("size", [0, 1, 2, 3, 4]) +def test_index(size, device): + vals = tuple([i + 1 for i in range(size)]) + rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals]) + _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0) + assert vals == tuple([x.item() - 1 for x in rets]) + + +# ---- + + +@triton.jit +def _tuple_assign(XPtrs, YPtrs, values): + # assign from tuple + X0, X1 = XPtrs + x0, x1 = values + tl.store(X0, x0) + tl.store(X1, x1) + # assign to tuple + Y0, Y1, Y2 = YPtrs + Y = Y0, Y1, Y2 + y = x0, 10, x1 + tl.store(Y[0], y[0]) + tl.store(Y[1], y[1]) + tl.store(Y[2], y[2]) + + +@pytest.mark.interpreter +def test_assign(device): + vals = (2., 3.) + x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)]) + y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)]) + _tuple_assign[(1, )](x, y, vals) + assert x[0] == vals[0] + assert x[1] == vals[1] + assert y[0] == vals[0] + assert y[1] == 10 + assert y[2] == vals[1] + + +# ------- + + +@triton.jit +def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1): + tl.static_assert(tuple1[1] is None) + tl.store(Ptr + 5, cst2) + tl.store(Ptr + 6, tuple1[0]) + tl.store(Ptr + 7, tl.load(tuple1[2][0])) + tl.store(Ptr + 8, tuple1[2][1][0]) + tl.store(Ptr + 9, tl.load(tuple1[2][1][2])) + + +# test serialization/deserialization of tuple arguments in +# the frontend. +@triton.jit +def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2): + tl.static_assert(N1 is None) + tl.static_assert(tuple1[1][1] is None) + tl.static_assert(tuple1[1][3] == 4) + tl.store(Ptr + 0, tl.load(tuple1[0])) + tl.store(Ptr + 1, tuple1[1][0]) + tl.store(Ptr + 2, tl.load(tuple1[1][2])) + tl.store(Ptr + 3, cst1 + val1) + tl.store(Ptr + 4, tl.load(tuple2[0])) + _tuple_fn0(Ptr, 15, (-1, None, tuple1)) + + +@pytest.mark.interpreter +def test_serialize(device): + x0 = torch.tensor([8], dtype=torch.int32, device=device) + x1 = torch.tensor([12], dtype=torch.int32, device=device) + y0 = torch.tensor([10], dtype=torch.int32, device=device) + z = torch.empty((10, ), dtype=torch.int32, device=device) + # we want to check that JIT specialization propagates to tuples: + _tuple_serialize[(1, )](z, None, (x0, (1, None, x1, tl.constexpr(4))), 20, 1, (y0, )) + ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device) + assert torch.equal(z, ref) + + +class Function(NamedTuple): + fn: tl.constexpr + captured: tuple + + +class Tensor(NamedTuple): + ptr: any + shape: tuple + stride: tuple + + +@triton.jit +def _namedtuple_create_func0(shape, ptr, stride): + return Tensor(shape=shape, ptr=ptr, stride=stride) + + +@triton.jit +def _namedtuple_create_func1(shape, ptr, stride): + tensor = Tensor(shape=shape, ptr=ptr, stride=stride) + return tensor + + +@triton.jit +def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + mask = (offs_m[:, None] < Tensor.shape[0]) & (offs_n[None, :] < Tensor.shape[1]) + return mask + + +@triton.jit +def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + X = _namedtuple_create_func0(_X.shape, _X.ptr, _X.stride) + Y = _namedtuple_create_func1(Y.shape, Y.ptr, Y.stride) + Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1] + Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1] + x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0) + y = closure.fn(x, *closure.captured) + tl.store(Ys, y, mask=_namedtuple_mask_func(Y, BLOCK_M, BLOCK_N)) + + +@pytest.mark.interpreter +def test_namedtuple(device): + x = torch.randn((32, 32), dtype=torch.float32, device=device) + y = torch.empty((16, 16), dtype=torch.float32, device=device) + a = torch.tensor([5.2], dtype=torch.float32, device=device) + + @triton.jit + def mul(x, a): + return x * tl.load(a) + + function = Function(mul, (a, )) + tx = Tensor(x, x.shape, x.stride()) + ty = Tensor(y, y.shape, y.stride()) + _namedtuple_kernel[(1, )](function, tx, ty, 64, 64) + assert torch.allclose(y, x[:16, :16] * a) diff --git a/third_party/enflame/include/triton/python/test/unit/language/test_warp_specialization.py b/third_party/enflame/include/triton/python/test/unit/language/test_warp_specialization.py new file mode 100644 index 000000000..52257601c --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/language/test_warp_specialization.py @@ -0,0 +1,103 @@ +import torch +import pytest +import pathlib +import triton + +from triton._internal_testing import is_cuda + + +@pytest.mark.skipif(not is_cuda(), reason="warp specialization is only supported on NVIDIA") +def test_warp_specialize_basic_ir(tmp_path: pathlib.Path): + ir = """ + tt.func @kernel(%arg0: !tt.ptr) { + %c42_i32 = arith.constant 42 : i32 + gpu.barrier + ttg.warp_specialize(%arg0) + default { + tt.store %arg0, %c42_i32 : !tt.ptr + gpu.barrier + ttg.warp_yield + } + partition0(%arg1: !tt.ptr) num_warps(1) { + %c5555_i32 = arith.constant 5555 : i32 + %c1_i32 = arith.constant 1 : i32 + gpu.barrier + %ptr = tt.addptr %arg1, %c1_i32 : !tt.ptr, i32 + tt.store %ptr, %c5555_i32 : !tt.ptr + ttg.warp_return + } : (!tt.ptr) -> () + tt.return + } + """ + + temp_file = tmp_path / "test_warp_specialize_basic_ir.ttir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + input = torch.empty(2, dtype=torch.int32, device='cuda') + kernel[(1, 1, 1)](input) + assert input[0] == 42 + assert input[1] == 5555 + + +@pytest.mark.skipif(not is_cuda(), reason="warp specialization is only supported on NVIDIA") +def test_warpgroup_reduction(tmp_path: pathlib.Path): + + def template(i, num_warps, in_ptr, out_ptr): + return f""" + %range = tt.make_range {{end = {(i+1)*256} : i32, start = {i*256} : i32}} : tensor<256xi32, #blocked{num_warps}> + %splatted = tt.splat {in_ptr} : !tt.ptr -> tensor<256x!tt.ptr, #blocked{num_warps}> + %ptrs = tt.addptr %splatted, %range : tensor<256x!tt.ptr, #blocked{num_warps}>, tensor<256xi32, #blocked{num_warps}> + %input = tt.load %ptrs : tensor<256x!tt.ptr, #blocked{num_warps}> + %result = "tt.reduce"(%input) ({{ + ^bb0(%lhs: i32, %rhs: i32): + %result = arith.addi %lhs, %rhs : i32 + tt.reduce.return %result : i32 + }}) {{axis = 0 : i32}} : (tensor<256xi32, #blocked{num_warps}>) -> i32 + %offset = arith.constant {i} : i32 + %output = tt.addptr {out_ptr}, %offset : !tt.ptr, i32 + tt.store %output, %result : !tt.ptr + """ + + ir = """ + #blocked4 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + #blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> + #blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + + module attributes {"ttg.num-warps" = 4 : i32} { + + tt.func @kernel(%arg0: !tt.ptr, %arg1: !tt.ptr) { + ttg.warp_specialize(%arg0, %arg1) + default { + """ + template(0, 4, "%arg0", "%arg1") + """ + ttg.warp_yield + } + partition0(%arg2: !tt.ptr, %arg3: !tt.ptr) num_warps(4) { + """ + template(1, 4, "%arg2", "%arg3") + """ + ttg.warp_return + } + partition1(%arg4: !tt.ptr, %arg5: !tt.ptr) num_warps(2) { + """ + template(2, 2, "%arg4", "%arg5") + """ + ttg.warp_return + } + partition2(%arg6: !tt.ptr, %arg7: !tt.ptr) num_warps(1) { + """ + template(3, 1, "%arg6", "%arg7") + """ + ttg.warp_return + } : (!tt.ptr, !tt.ptr) -> () + tt.return + } + + } + """ + + temp_file = tmp_path / "test_warpgroup_reduction.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + input = torch.arange(1024, dtype=torch.int32, device='cuda') + output = torch.empty(4, dtype=torch.int32, device='cuda') + kernel[(1, 1, 1)](input, output) + assert output[0] == torch.arange(0, 256).sum() + assert output[1] == torch.arange(256, 512).sum() + assert output[2] == torch.arange(512, 768).sum() + assert output[3] == torch.arange(768, 1024).sum() diff --git a/third_party/enflame/include/triton/python/test/unit/runtime/test_autotuner.py b/third_party/enflame/include/triton/python/test/unit/runtime/test_autotuner.py new file mode 100644 index 000000000..fa835eeeb --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/runtime/test_autotuner.py @@ -0,0 +1,185 @@ +import torch + +import triton +import triton.language as tl +import pytest + + +def do_bench(kernel_call, quantiles): + return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1) + + +@pytest.mark.parametrize('use_cuda_graph', [False, True]) +def test_kwargs(use_cuda_graph: bool, device: str): + if use_cuda_graph and not torch.cuda.is_available(): + pytest.xfail("CUDA is not available") + + M, N = 1024, 16 + src = torch.randn(M * N, device=device) + dst = torch.empty(M * N, device=device) + + configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})] + + @triton.autotune(configs=configs, key=['M'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph, do_bench=do_bench) + @triton.jit + def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): + offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M) + offsets_n = tl.arange(0, BLOCK_SIZE_N) + x = tl.load(src + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :]) + tl.store(dst + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :], x) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE_M']), ) + _kernel[grid](dst, src, N, M, N) + # the key word args could be in arbitrary order. + _kernel[grid](dst=dst, src=src, M=M // 2, stride_m=N, BLOCK_SIZE_N=N) + assert len(_kernel.cache) == 2 + + +@pytest.mark.parametrize('pass_kwargs_to_kernel', [False, True]) +def test_restore(pass_kwargs_to_kernel, device): + N = 1024 + src = torch.zeros(N, device=device) + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], restore_value=['src'], do_bench=do_bench) + @triton.jit + def _kernel(src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + 1 + tl.store(src + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + if pass_kwargs_to_kernel: + _kernel[grid](src=src, N=N) + else: + _kernel[grid](src, N) + triton.testing.assert_close(src, torch.ones_like(src)) + + +def test_hooks(device): + # Autotuner's pre- and post- hooks should be called the same number of times + N = 4096 + src = torch.zeros(N, device=device) + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 4096}), triton.Config(kwargs={'BLOCK_SIZE': 32})] + + values = {"counter": 0, "has_exception": False} + + def _pre_hook(*args, **kwargs): + values["counter"] += 1 + + def _post_hook(*args, exception): + values["counter"] -= 1 + if exception is not None: + values["has_exception"] = True + assert values["counter"] == 0 + + @triton.autotune(configs=configs, key=['N'], do_bench=do_bench, pre_hook=_pre_hook, post_hook=_post_hook) + @triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4}) + @triton.jit + def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + max_iters = tl.cdiv(N, BLOCK_SIZE) + for _ in tl.range(max_iters, num_stages=N_STAGES): + x = tl.load(src + offsets, mask=offsets < N) + tl.store(src + offsets, x, mask=offsets < N) + offsets += BLOCK_SIZE + + _kernel[(1, )](src, N) + + # On NVIDIA GPUs: + # The tuning knob `num_stages` can be set by users. + # This will cause out of resources when N_STAGES = 100 + # shared memory bytes = N_STAGES * BLOCK_SIZE * sizeof(float) + # On AMD GPUs: + # `num_stages` is a fixed value of 2, so it won't cause out of resources + if triton.runtime.driver.active.get_current_target().backend == "cuda": + assert values["has_exception"] is True + else: + assert values["has_exception"] is False + + +@pytest.mark.parametrize('with_perf_model', [False, True]) +def test_prune_configs(with_perf_model: bool, device: str): + N = 1024 + src = torch.randn(N, device=device) + dst = torch.empty(N, device=device) + records = {} + + def early_config_prune(configs, named_args, **kwargs): + records['run_early_config_prune'] = True + if "N" in kwargs and kwargs["N"] == 1024: + records['capture_kwargs'] = True + if "dst" in named_args and "src" in named_args and len(named_args) == 2: + records['capture_named_args'] = True + return [configs[0]] + + def perf_model(*args, **kwargs): + records['run_perf_model'] = True + return kwargs['BLOCK_SIZE'] + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + if with_perf_model: + prune_configs_by = {'perf_model': perf_model, 'top_k': 1} + else: + prune_configs_by = {'early_config_prune': early_config_prune} + + @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, do_bench=do_bench) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + torch.testing.assert_close(src, dst) + if with_perf_model: + assert len(records) == 1 + assert records['run_perf_model'] + else: + assert len(records) == 3 + assert records['run_early_config_prune'] + assert records['capture_kwargs'] + assert records['capture_named_args'] + + +def test_exceed_tmem(device): + if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 10: + pytest.skip("Test requires tensor memory.") + N = 512 + dst = torch.empty((N, ), device=device, dtype=torch.float32) + configs = [triton.Config(kwargs={'BLOCK_SIZE': 128}), triton.Config(kwargs={'BLOCK_SIZE': 32})] + exception_out_of_resource = None + + def _post_hook(*args, exception): + nonlocal exception_out_of_resource + if exception is not None: + exception_out_of_resource = exception + + @triton.autotune(configs=configs, key=['N'], do_bench=do_bench, pre_hook=None, post_hook=_post_hook) + @triton.jit + def dot_kernel(dst, BLOCK_SIZE: tl.constexpr): + a = tl.full((BLOCK_SIZE, BLOCK_SIZE), 0.0, tl.float16) + b = tl.full((BLOCK_SIZE, BLOCK_SIZE), 0.0, tl.float16) + c0 = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32) + c1 = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32) + c2 = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32) + c3 = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32) + c4 = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32) + for i in range(0, 100): + c0 = tl.dot(a, b, c0) + c1 = tl.dot(a, b, c1) + c2 = tl.dot(a, b, c2) + c3 = tl.dot(a, b, c3) + c4 = tl.dot(a, b, c4) + c = c4 + c3 + c2 + c1 + c0 + c = c.reshape([BLOCK_SIZE * BLOCK_SIZE]) + tl.store(dst + tl.arange(0, BLOCK_SIZE * BLOCK_SIZE), c) + + dot_kernel[(1, )](dst) + assert exception_out_of_resource is not None and str( + exception_out_of_resource + ) == "out of resource: tensor memory, Required: 640, Hardware limit: 512. Reducing block sizes or `num_stages` may help." diff --git a/third_party/enflame/include/triton/python/test/unit/runtime/test_bindings.py b/third_party/enflame/include/triton/python/test/unit/runtime/test_bindings.py new file mode 100644 index 000000000..b899af4b7 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/runtime/test_bindings.py @@ -0,0 +1,104 @@ +import triton +import triton.language as tl + +import torch +import math + + +@triton.jit +def add_helper(x, y): + return x + y + + +@triton.jit +def add_kernel( + in_ptr0, + in_ptr1, + n_elements, + out_ptr, + BLOCK_SIZE: "tl.constexpr", +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = add_helper(x, y) + tl.store(out_ptr + offsets, output, mask=mask) + + +def test_module_walk(device): + """ + Test the MLIR bindings exposed for the out-of-tree walk. + """ + + def walk_fn(op): + name = op.get_name() + for i in range(op.get_num_results()): + op.get_result(i).id() + for i in range(op.get_num_operands()): + op.get_operand(i).id() + for i in range(op.get_num_regions()): + op.get_region(i).id() + block = op.get_block() + if block is not None: + block.id() + for i in range(block.get_num_arguments()): + block.get_argument(i) + if name == "tt.func": + op.get_str_attr("sym_name") + if name == "tt.call": + op.get_flat_symbol_ref_attr("callee") + + kernel = add_kernel + args = [ + torch.empty((32, 32), device=device), # in_ptr0 + torch.empty((32, 32), device=device), # in_ptr1 + 1024, # n_elements + torch.empty((32, 32), device=device), # out_ptr + 16, # BLOCK_SIZE + ] + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + src = triton.compiler.compiler.ASTSource( + fn=kernel, + signature={kernel.arg_names[i]: triton.runtime.jit.mangle_type(arg) + for i, arg in enumerate(args)}, + constexprs={kernel.arg_names[i]: arg + for i, arg in enumerate(args) + if not isinstance(arg, torch.Tensor)}, + ) + + context = triton._C.libtriton.ir.context() + options = backend.parse_options(dict()) + codegen_fns = dict() + module_map = backend.get_module_map() + triton._C.libtriton.ir.load_dialects(context) + backend.load_dialects(context) + + ttir_module = src.make_ir(options, codegen_fns, module_map, context) + ttir_module.walk(walk_fn) + + +def test_python_func_in_visit_call(device): + + @triton.jit + def test_py_call_const_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + log2e: tl.constexpr = math.log2(math.e) + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = x * log2e + tl.store(out_ptr + offsets, output, mask=mask) + + x = torch.randn(4, device=device) + out = torch.zeros_like(x) + test_py_call_const_kernel[(4, )](x, out, 4, 4) diff --git a/third_party/enflame/include/triton/python/test/unit/runtime/test_cache.py b/third_party/enflame/include/triton/python/test/unit/runtime/test_cache.py new file mode 100644 index 000000000..a4982bf8c --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/runtime/test_cache.py @@ -0,0 +1,629 @@ +import importlib.util +import itertools +import os +import shutil +import pathlib + +import pytest +import torch + +import triton +import triton.language as tl +from triton.runtime.jit import JITFunction +from triton._internal_testing import is_hip + + +@triton.jit +def function_0(i): + return i + 1 + + +@triton.jit +def function_1(i): + i = i + 1 + cond: tl.constexpr = True + if cond: + FN: tl.constexpr = function_2 + else: + FN: tl.constexpr = function_0 + return FN(i) + + +@triton.jit +def function_2(i): + i = i + 1 + return i + + +@triton.jit +def combine_fn(a, b): + return COMBINE_OP # noqa: F821 + + +@triton.jit +def kernel(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit(do_not_specialize=["i"]) +def kernel_nospec(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit(do_not_specialize_on_alignment=["i"]) +def kernel_nospec_on_alignment(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit +def kernel_with_combine_fn(X, BLOCK: tl.constexpr): + i = tl.arange(0, BLOCK) + i = REDUCE_OR_SCAN(i, 0, combine_fn) # noqa: F821 + tl.store(X, i) + + +def apply_src_change(target, old, new, to_modify): + kernel.hash = None + function_0.hash = None + function_1.hash = None + function_2.hash = None + to_modify._unsafe_update_src(to_modify.src.replace(old, new)) + ret = target.cache_key + to_modify._unsafe_update_src(to_modify.src.replace(new, old)) + return ret + + +def test_nochange(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 1', function_1) + assert baseline == updated + + +def test_toplevel_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_1) + assert baseline != updated + + +def test_nested1_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_2) + assert baseline != updated + + +def test_nested2_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_0) + assert baseline != updated + + +def test_combine_fn_change(): + # Test that tl.reduce and associative_scan calls include + # the combine_fn in the hash + + orig_combine_fn_src = combine_fn.src + orig_kernel_src = kernel_with_combine_fn.src + seen_keys = set() + + for reduce_or_scan, combine_op in itertools.product( + ["tl.reduce", "tl.associative_scan"], + ["a + b", "a * b"], + ): + combine_fn._unsafe_update_src(orig_combine_fn_src.replace("COMBINE_OP", combine_op)) + kernel_with_combine_fn._unsafe_update_src(orig_kernel_src.replace("REDUCE_OR_SCAN", reduce_or_scan)) + try: + key = kernel_with_combine_fn.cache_key + finally: + combine_fn._unsafe_update_src(orig_combine_fn_src) + kernel_with_combine_fn._unsafe_update_src(orig_kernel_src) + + assert key not in seen_keys + seen_keys.add(key) + + +def write_and_load_module(temp_file: pathlib.Path, code, num_extra_lines): + temp_file.write_text(('# extra line\n' * num_extra_lines) + code) + spec = importlib.util.spec_from_file_location("module.name", str(temp_file)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_changed_line_numbers_invalidate_cache(tmp_path: pathlib.Path): + from textwrap import dedent + code = dedent(""" + import triton + @triton.jit + def test_kernel(i): + i = i + 1 + """) + temp_file0 = tmp_path / "test_changed_line_numbers_invalidate_cache0.py" + orig_mod = write_and_load_module(temp_file0, code, 0) + orig_cache_key = orig_mod.test_kernel.cache_key + + temp_file1 = tmp_path / "test_changed_line_numbers_invalidate_cache1.py" + updated_mod = write_and_load_module(temp_file1, code, 1) + updated_cache_key = updated_mod.test_kernel.cache_key + assert orig_cache_key != updated_cache_key + + +def test_reuse(device, fresh_triton_cache): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + x = torch.empty(1, dtype=torch.int32, device=device) + for i in range(10): + kernel[(1, )](x, 1, BLOCK=1024) + assert counter == 1 + + +@pytest.mark.parametrize('mode', ['enable', 'disable', 'disable_on_alignment']) +def test_specialize(mode, device, fresh_triton_cache): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + x = torch.empty(1, dtype=torch.int32, device=device) + function = {'enable': kernel, 'disable': kernel_nospec, 'disable_on_alignment': kernel_nospec_on_alignment}[mode] + target = {'enable': 3, 'disable': 1, 'disable_on_alignment': 2}[mode] + for i in [1, 2, 4, 8, 16, 32]: + function[(1, )](x, i, BLOCK=512) + assert counter == target + + +def test_annotation(device): + + @triton.jit + def kernel(X, i: tl.int32): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device=device) + + device = getattr(torch, device).current_device() + kernel[(1, )](x, 1) + kernel[(1, )](x, 8) + kernel[(1, )](x, 16) + kernel[(1, )](x, 17) + assert len(kernel.device_caches[device][0]) == 3 + + +GLOBAL_DEFAULT_ARG = 1 + + +def test_kernel_default_arg(device): + global GLOBAL_DEFAULT_ARG + + @triton.jit + def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device=device) + kernel[(1, )](x) + assert x == torch.ones_like(x) + + # Changing the global variable should not change the default argument in + # `kernel`. That value gets set at the time the function is declared. + GLOBAL_DEFAULT_ARG = 2 + kernel[(1, )](x) + assert x == torch.ones_like(x) + + device = getattr(torch, device).current_device() + assert len(kernel.device_caches[device][0]) == 1 + + +GLOBAL_VAR = tl.constexpr(1) + + +def test_kernel_global_var_change(device): + global GLOBAL_VAR + + @triton.jit + def kernel(X): + tl.store(X, GLOBAL_VAR) + + x = torch.empty(1, dtype=torch.int32, device=device) + kernel[(1, )](x) + assert x == torch.ones_like(x) + + GLOBAL_VAR = 2 + with pytest.raises(RuntimeError) as e: + kernel[(1, )](x) + + assert "global variable" in str(e.value).lower() + + +GLOBAL = 42 # noqa + + +def test_local_shadows_global(): + global GLOBAL + + @triton.jit + def kernel(): + _, GLOBAL = 0, 0 # noqa + a = GLOBAL # noqa + + # No error because the `GLOBAL` we're modifying is not the same `GLOBAL` as + # inside the kernel. + GLOBAL = 42 + kernel[(1, )]() + GLOBAL = 43 + kernel[(1, )]() + + +CONSTEXPR_GLOBAL = tl.constexpr(42) + + +def test_local_does_not_shadow_global(): + global CONSTEXPR_GLOBAL + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + _, CONSTEXPR_GLOBAL = 0, 0 # noqa + + CONSTEXPR_GLOBAL = tl.constexpr(42) + kernel[(1, )]() + CONSTEXPR_GLOBAL = tl.constexpr(43) + + # Error because the `CONSTEXPR_GLOBAL` we're modifying is the same + # `CONSTEXPR_GLOBAL` that's read inside `kernel`. (Alternatively, we could + # make this kernel an error altogether, as it is if it's a pure Python + # function -- the fact that we store to `CONSTEXPR_GLOBAL` inside the kernel + # makes the first read a read of the local variable, which doesn't exist + # yet.) + with pytest.raises(RuntimeError): + kernel[(1, )]() + + +CONFLICTING_GLOBAL = tl.constexpr(0) + + +@triton.jit +def conflicting_global_inner(): + a = CONFLICTING_GLOBAL # noqa + + +def test_conflicting_global_in_inner_function(): + global CONFLICTING_GLOBAL + + @triton.jit + def kernel1(): + a = CONFLICTING_GLOBAL # noqa + conflicting_global_inner() + + @triton.jit + def kernel2(): + a = CONFLICTING_GLOBAL #noqa + conflicting_global_inner() + + kernel1[(1, )]() + + # This should be an error because kernel2 calls conflicting_global_inner, + # which saw a value for 42 for the global when it was first compiled. + CONFLICTING_GLOBAL = 1 + + with pytest.raises(RuntimeError) as e: + kernel2[(1, )]() + + assert "Global variable CONFLICTING_GLOBAL has value" in str(e.value) + + +def test_use_builtin(): + + @triton.jit + def kernel(): + a = float(0) # noqa + + # No error about the value of `float` changing. + kernel[(1, )]() + kernel[(1, )]() + + +def test_no_cache_module_as_global(): + + @triton.jit + def kernel(): + tl.arange(0, 16) + + kernel[(1, )]() + # `tl` should not be entered into used_global_vals + assert not kernel.used_global_vals + + +BUILTIN_AS_GLOBAL = tl.int32 + + +def test_cache_builtin_as_global(): + global BUILTIN_AS_GLOBAL + + @triton.jit + def kernel(): + x = BUILTIN_AS_GLOBAL # noqa + + kernel[(1, )]() + + BUILTIN_AS_GLOBAL = tl.int64 + with pytest.raises(RuntimeError) as e: + kernel[(1, )]() + + assert "global variable" in str(e.value).lower() + + +@triton.jit +def no_cache_callable_inner(): + pass + + +def test_no_cache_callable(): + + @triton.jit + def kernel(): + no_cache_callable_inner() + + kernel[(1, )]() + # `no_cache_callable_inner` should not be entered into used_global_vals. + assert not kernel.used_global_vals + + +def test_jit_warmup_cache(device) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + args = [ + torch.randn(32, dtype=torch.float32, device=device), + torch.randn(32, dtype=torch.float32, device=device), + torch.randn(32, dtype=torch.float32, device=device), + 32, + ] + device = getattr(torch, device).current_device() + assert len(kernel_add.device_caches[device][0]) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.device_caches[device][0]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.device_caches[device][0]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.device_caches[device][0]) == 1 + + +def test_jit_debug(device) -> None: + + @triton.jit + def kernel(tmp): + tl.device_assert(tl.load(tmp) == 1, "tmp == 1") + + device = getattr(torch, device).current_device() + tmp = torch.tensor([1], dtype=torch.int32, device=device) + assert len(kernel.device_caches[device][0]) == 0 + kernel[(1, )](tmp, debug=False) + assert len(kernel.device_caches[device][0]) == 1 + kernel[(1, )](tmp, debug=True) + assert len(kernel.device_caches[device][0]) == 2 + bins = list(kernel.device_caches[device][0].values()) + assert bins[0].asm['ttir'] != bins[1].asm['ttir'] + + +@triton.jit +def add_fn(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + +def test_jit_noinline(device) -> None: + + @triton.jit + def kernel_add_device(a, b, o, N: tl.constexpr): + add_fn(a, b, o, N) + + device = getattr(torch, device).current_device() + assert len(kernel_add_device.device_caches[device][0]) == 0 + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.device_caches[device][0]) == 1 + bins = list(kernel_add_device.device_caches[device][0].values()) + inline_ttir = bins[0].asm['ttir'] + add_fn.noinline = True + add_fn.hash = None + kernel_add_device.hash = None + kernel_add_device.device_caches[device][0].clear() + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.device_caches[device][0]) == 1 + bins = list(kernel_add_device.device_caches[device][0].values()) + noinline_ttir = bins[0].asm['ttir'] + assert inline_ttir != noinline_ttir + + +def test_memory_leak() -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + +def test_preload(device, fresh_triton_cache) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx)) + + device = getattr(torch, device).current_device() + + # get the serialized specialization data + specialization_data = None + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + + JITFunction.cache_hook = cache_hook + pre_compile = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + hash = pre_compile.hash + assert specialization_data is not None + + # clear the cache + shutil.rmtree(fresh_triton_cache) + kernel_add.device_caches[device][0].clear() + + # preload the kernel + kernel_preload = kernel_add.preload(specialization_data) + assert kernel_preload.hash == hash + assert len(kernel_add.device_caches[device][0]) == 1 + + # we should hit the cache and not compile anything + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + final_kernel = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + JITFunction.cache_hook = None + assert counter == 0 + assert len(kernel_add.device_caches[device][0]) == 1 + assert final_kernel.hash == hash + + # test that we can't preload a mismatched kernel + with pytest.raises(RuntimeError, match="Specialization data is for"): + kernel_sub.preload(specialization_data) + + +def test_hooks(device, fresh_triton_cache) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + # get the serialized specialization data + specialization_data = None + is_warmup = False + key = 0 + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + nonlocal is_warmup + is_warmup = kwargs["compile"]["is_warmup"] + nonlocal key + key = kwargs["compile"]["key"] + + specialization_data_compiled = None + + def compiled_hook(*args, **kwargs): + nonlocal specialization_data_compiled + specialization_data_compiled = kwargs["compile"]["specialization_data"] + + JITFunction.cache_hook = cache_hook + JITFunction.compiled_hook = compiled_hook + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + assert specialization_data is not None and specialization_data_compiled == specialization_data + assert is_warmup is True + assert key in kernel_add.device_caches[getattr(torch, device).current_device()][0] + + +@pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip()) +def test_within_2gb(device, fresh_triton_cache) -> None: + default_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") + from triton.backends import backends + + amd_backend = backends["amd"] + try: + use_buffer_ops_opts = ["1", "0"] + # The ranges should only be available when buffer ops are enabled + pointer_ranges = [[(0, )], []] + for use_buffer_ops, pointer_range in zip(use_buffer_ops_opts, pointer_ranges): + # Set AMDGCN_USE_BUFFER_OPS + amd_backend.compiler.use_buffer_ops.cache_clear() + os.environ["AMDGCN_USE_BUFFER_OPS"] = use_buffer_ops + + @triton.jit + def kernel_add(a): + tl.load(a) + + # This is the attribute we want to test + pointer_range_32 = None + + def cache_hook(*args, **kwargs): + nonlocal pointer_range_32 + pointer_range_32 = [ + k for k, v in kwargs["compile"]["configs"][0].items() if ["tt.pointer_range", 32] in v + ] + + JITFunction.cache_hook = cache_hook + # In warmup we assume that the pointer range is 32 bits + kernel_add.warmup(torch.float32, grid=(1, )) + assert pointer_range_32 == pointer_range + # Torch tensor > 2GB + kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device)) + assert len(pointer_range_32) == 0 + # Torch tensor <= 2GB + kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device)) + assert pointer_range_32 == pointer_range + finally: + amd_backend.compiler.use_buffer_ops.cache_clear() + os.environ["AMDGCN_USE_BUFFER_OPS"] = default_buffer_ops + + +def test_function_arguments(device): + + @triton.jit + def func1(): + return 1 + + @triton.jit + def func2(): + return 2 + + @triton.jit + def func3(x): + return x + + @triton.jit + def func4(x, y): + return x + y + + @triton.jit + def kernel(Y, fn: tl.constexpr, fn_args): + tl.store(Y, fn(*fn_args)) + + JITFunction.cache_hook = None + JITFunction.compiled_hook = None + y = torch.zeros((5, ), dtype=torch.int32, device=device) + kernel[(1, )](y[0], func1, tuple()) + kernel[(1, )](y[1], func2, tuple()) + kernel[(1, )](y[2], func3, (3, )) + kernel[(1, )](y[3], func4, (3, 4)) + kernel[(1, )](y[4], func1, tuple()) + assert len(kernel.device_caches[0][0]) == 4 + assert y.tolist() == [1, 2, 3, 7, 1] diff --git a/third_party/enflame/include/triton/python/test/unit/runtime/test_cublas.py b/third_party/enflame/include/triton/python/test/unit/runtime/test_cublas.py new file mode 100644 index 000000000..a4315fc3c --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/runtime/test_cublas.py @@ -0,0 +1,49 @@ +import pytest +import torch +import triton +import os + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cuda(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cuda" + + +@pytest.mark.parametrize("m, n, k", [(16, 16, 16), (32, 16, 16), (16, 32, 16), (16, 16, 32)]) +@pytest.mark.parametrize("dtype_str", ["float8_e4m3fn", "float16"]) +def test_cublas(m, n, k, dtype_str, device): + dtype = getattr(torch, dtype_str) + if not is_cuda(): + pytest.skip("test_cublas is only supported on CUDA") + if dtype == torch.float8_e4m3fn and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("fp8 is only supported on CUDA with cc >= 90") + + from triton._C.libtriton import nvidia + + torch.manual_seed(123) + workspace_size = 32 * 1024 * 1024 + + def limited_rand(elements, shape): + total_elems = torch.prod(torch.tensor(shape)).item() + indices = torch.randint(0, len(elements), (total_elems, ), device=device) + return elements[indices].view(shape) + + elements = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32, device=device) + a = limited_rand(elements, (m, k)).to(dtype) + b = limited_rand(elements, (k, n)).to(dtype) + c = torch.zeros((m, n), dtype=dtype, device=device) + + b = b.T.contiguous() + + workspace = torch.empty(workspace_size, dtype=torch.int8, device=device) + + cublas = nvidia.cublas.CublasLt(workspace) + cublas.matmul(a, b, c) + + ref = torch.matmul(a.to(torch.float16), b.to(torch.float16).T) + + assert torch.allclose(c.to(torch.float16), ref, atol=2.0) diff --git a/third_party/enflame/include/triton/python/test/unit/runtime/test_driver.py b/third_party/enflame/include/triton/python/test/unit/runtime/test_driver.py new file mode 100644 index 000000000..9bd51cc2b --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/runtime/test_driver.py @@ -0,0 +1,41 @@ +import sys +from concurrent.futures import ThreadPoolExecutor +import torch + +import triton +import triton.language as tl + + +def test_is_lazy(): + from importlib import reload + reload(sys.modules["triton.runtime.driver"]) + reload(sys.modules["triton.runtime"]) + mod = sys.modules[triton.runtime.driver.__module__] + assert isinstance(triton.runtime.driver.active, getattr(mod, "LazyProxy")) + assert triton.runtime.driver.active._obj is None + utils = triton.runtime.driver.active.utils # noqa: F841 + assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase")) + + +def test_kernel_in_thread(device): + # Test calling in a new thread sets a valid device context + buf = torch.zeros((38016 * 1024, ), dtype=torch.float32, device=device) + + @triton.jit + def _kernel(P, BLOCK: tl.constexpr): + pid = tl.program_id(0).to(tl.int64) + offset = pid * BLOCK + tl.arange(0, BLOCK) + + p = tl.load(P + offset) + tl.store(P + offset, p) + + def call_triton(): + N = buf.numel() + grid = lambda meta: (triton.cdiv(N, meta["BLOCK"]), ) + _kernel[grid](buf, BLOCK=1024) + getattr(torch, device).synchronize() + + call_triton() + with ThreadPoolExecutor(1) as pool: + future = pool.submit(call_triton) + future.result() diff --git a/third_party/enflame/include/triton/python/test/unit/runtime/test_jit.py b/third_party/enflame/include/triton/python/test/unit/runtime/test_jit.py new file mode 100644 index 000000000..5892494c4 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/runtime/test_jit.py @@ -0,0 +1,42 @@ +import itertools +import pytest +import torch + +import triton +import triton.language as tl + + +def test_pre_call_hooks(device): + + @triton.jit + def add_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + class MyTensor(torch.Tensor): + pass + + def my_hook(*args, **kwargs): + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, MyTensor): + raise Exception("MyTensor is not allowed") + + add_kernel.add_pre_run_hook(my_hook) + + x = torch.randn(4, device=device) + y = MyTensor(x) + out = torch.zeros_like(x) + with pytest.raises(Exception): + add_kernel[(4, )](x, y, out, 4, 4) diff --git a/third_party/enflame/include/triton/python/test/unit/runtime/test_launch.py b/third_party/enflame/include/triton/python/test/unit/runtime/test_launch.py new file mode 100644 index 000000000..91fc6e19b --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/runtime/test_launch.py @@ -0,0 +1,134 @@ +import gc +# import importlib +# import os +# import sys +# import tempfile +# import textwrap +# import time +import tracemalloc + +import torch + +import triton +import triton.language as tl + +# from typing import Tuple + + +def test_metadata() -> None: + + used_hook = False + + def _launch_metadata(grid, kernel, args): + ret = dict() + ret["grid"] = grid + ret["value"] = args["x"] + return ret + + def hook(launch_metadata): + nonlocal used_hook + metadata = launch_metadata.get() + assert metadata["grid"] == (1, 3, 2) + assert metadata["value"] == 6 + used_hook = True + + @triton.jit(launch_metadata=_launch_metadata) + def kernel(x): + pass + + # launch kernel + triton.compiler.CompiledKernel.launch_enter_hook = hook + kernel[(1, 3, 2)](6) + triton.compiler.CompiledKernel.launch_enter_hook = None + assert used_hook + + +def test_memory_leak(device) -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + tracemalloc.start() + try: + inp = torch.randn(10, device=device) + out = torch.randn(10, device=device) + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + begin, _ = tracemalloc.get_traced_memory() + for _ in range(100): + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + end, _ = tracemalloc.get_traced_memory() + assert end - begin < 30000 + finally: + tracemalloc.stop() + + +# LATENCY_THRESHOLD_US = 46 + +# def test_kernel_launch_latency() -> None: +# def define_kernel(kernel_name: str, num_tensor_args: int) -> str: +# arg_str = ",".join([f"arg{i}: torch.Tensor" for i in range(num_tensor_args)]) +# arg_str += ", n_elements: int, BLOCK_SIZE: tl.constexpr" +# func_str = f""" +# import torch + +# import triton +# import triton.language as tl + +# @triton.jit +# def {kernel_name}({arg_str}): +# pass +# """ +# with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py", delete=False) as temp_file: +# temp_file.write(textwrap.dedent(func_str)) +# temp_file_path = temp_file.name + +# return temp_file_path + +# def import_kernel(file_path, kernel_name): +# directory, filename = os.path.split(file_path) +# module_name, _ = os.path.splitext(filename) +# sys.path.insert(0, directory) + +# module = importlib.import_module(module_name) +# kernel = getattr(module, kernel_name) +# return kernel + +# def empty(*kernel_args: Tuple[torch.Tensor]): +# first_arg = kernel_args[0] +# n_elements = first_arg.numel() +# grid = (triton.cdiv(n_elements, 1024),) +# device = torch.cuda.current_device() +# # Warmup +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# torch.cuda.synchronize() +# # Measure launch overhead at steady state +# num_runs = 1000 +# start_time = time.time() +# for i in range(num_runs): +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# end_time = time.time() +# latency_us = (end_time - start_time) / num_runs * 1e6 + +# assert latency_us < LATENCY_THRESHOLD_US, "Kernel launch time has increased!" + +# num_tensor_args = 40 +# kernel_name = 'empty_kernel' +# file_path = define_kernel(kernel_name, num_tensor_args) +# empty_kernel = import_kernel(file_path, kernel_name) + +# # Initialize random tensors for the empty_kernel +# torch.manual_seed(0) +# size = 1024 +# kernel_args = (torch.rand(size, device='cuda') for i in range(num_tensor_args)) + +# # Run empty, which would run empty_kernel internally +# empty(*kernel_args) diff --git a/third_party/enflame/include/triton/python/test/unit/runtime/test_subproc.py b/third_party/enflame/include/triton/python/test/unit/runtime/test_subproc.py new file mode 100644 index 000000000..928b6e6a8 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/runtime/test_subproc.py @@ -0,0 +1,102 @@ +import multiprocessing +import shutil + +import triton +import triton.language as tl +from triton.compiler import ASTSource + +target = triton.runtime.driver.active.get_current_target() +start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn' + + +def compile_fn(): + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) + + src = ASTSource( + fn=kernel_sub, + constexprs={'N': 32}, + signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32", 'N': 'constexpr'}, + ) + triton.compile(src=src, target=target) + + +def test_compile_in_subproc() -> None: + mp_ctx = multiprocessing.get_context(start_method) + proc = mp_ctx.Process(target=compile_fn) + proc.start() + proc.join() + assert proc.exitcode == 0 + + +def compile_fn_dot(): + + @triton.jit + def kernel_dot(Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + tl.store(Z + offs, z) + + src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}) + triton.compile(src=src, target=target) + + +def test_compile_in_forked_subproc(fresh_triton_cache) -> None: + mp_ctx = multiprocessing.get_context(start_method) + proc = mp_ctx.Process(target=compile_fn_dot) + proc.start() + proc.join() + assert proc.exitcode == 0 + + +def compile_empty_kernel_with_gc(): + + @triton.jit + def empty_kernel(): + pass + + import gc + gc.collect() + src = ASTSource(fn=empty_kernel, signature={}) + triton.compile(src=src, target=target) + + +def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None: + ''' + Tests that compilation artifacts can safely live in forked process. + + Scenario being tested here ("p" stands for parent process, "c" is child process): + 1. p compiles a kernel 1, and produces compilation artifacts. + 2. p forks the process to create c. + 3. c deletes compilation artifacts inherited from p, compiles kernel 2, and terminates. + 3. p wait for c and join it. + + This is a regression test that ensures thread pool in MLIRContext is released + safely after compilation. + ''' + import gc + old_gc_state = gc.isenabled() + # disable GC to manage resources manually in the manner described in comment above + gc.disable() + + # stage 1.p + compile_empty_kernel_with_gc() + + # stage 2.p + shutil.rmtree(fresh_triton_cache) + mp_ctx = multiprocessing.get_context(start_method) + proc = mp_ctx.Process(target=compile_empty_kernel_with_gc) + + # stage 3.c + proc.start() + # stage 3.p + proc.join() + + # restore gc state + if old_gc_state: + gc.enable() + assert proc.exitcode == 0 diff --git a/third_party/enflame/include/triton/python/test/unit/test_debug.py b/third_party/enflame/include/triton/python/test/unit/test_debug.py new file mode 100644 index 000000000..8ea621202 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/test_debug.py @@ -0,0 +1,142 @@ +import pytest +import torch +import triton.language as tl +import triton + + +@pytest.mark.parametrize('cond', [True, False]) +@pytest.mark.parametrize('opt_flag', [True, False, None]) +@pytest.mark.parametrize('env_var', [True, False]) +@pytest.mark.parametrize('jit_flag', [True, False]) +@pytest.mark.forked +def test_device_assert(monkeypatch, cond, opt_flag, env_var, jit_flag, device): + monkeypatch.setenv("TRITON_DEBUG", str(int(env_var))) + torch.zeros([1], dtype=torch.int32, device=device) + + @triton.jit(debug=jit_flag) + def _kernel(COND: tl.constexpr): + tl.device_assert(COND, 'test') + + is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag) + + kwargs = {} + if opt_flag is not None: + kwargs["debug"] = opt_flag + + if not cond and is_debug: + with pytest.raises(RuntimeError): + _kernel[(1, )](cond, **kwargs) + getattr(torch, device).synchronize() + return + + _kernel[(1, )](cond, **kwargs) + getattr(torch, device).synchronize() + + +def test_device_assert_barrier(monkeypatch, device): + monkeypatch.setenv("TRITON_DEBUG", "1") + tensor = torch.zeros([16], dtype=torch.int32, device=device) + + @triton.jit + def _kernel(in_ptr0): + xindex = tl.arange(0, 8) + tmp0 = tl.load(in_ptr0 + xindex) + tl.device_assert(tmp0 < 1) + + _kernel[(1, )](tensor) + getattr(torch, device).synchronize() + + +@pytest.mark.parametrize("cond", [False, True]) +def test_static_assert(cond): + + @triton.jit + def _kernel(COND: tl.constexpr): + tl.static_assert(COND) + + if not cond: + with pytest.raises(triton.compiler.errors.CompileTimeAssertionFailure): + _kernel[(1, )](cond) + return + + _kernel[(1, )](cond) + + +def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref_func, device): + x = torch.tensor([x], dtype=getattr(torch, x_dtype), device=device) + y = torch.tensor([y], dtype=getattr(torch, y_dtype), device=device) + z = torch.empty_like(x) + if should_overflow and debug: + with pytest.raises(RuntimeError) as exc_info: + tri_func[(1, )](x, y, z, debug=debug) + getattr(torch, device).synchronize() + assert "device-side assert" in str(exc_info.value) + else: + tri_func[(1, )](x, y, z, debug=debug) + getattr(torch, device).synchronize() + assert int(z) == int(ref_func(x, y)) + + +# integer overflow sanitization + + +@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ + (-2**31, -1, 'int32', 'int32', False, False), + (-2**31, -1, 'int32', 'int32', True, True), + (2**31 - 1, 1, 'int32', 'int32', True, True), + (2**31 - 1, 100, 'int32', 'int32', True, True), + (-2**31, 0, 'int32', 'int32', True, False), + (-2**31, 2, 'int32', 'int32', True, False), + (0, -1, 'int32', 'int32', True, False), + (-2**15, -1, 'int16', 'int16', True, True), + (2**15 - 1, 1, 'int16', 'int16', True, True), +]) +@pytest.mark.forked +def test_sanitize_int_add_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device): + + @triton.jit + def _kernel_add(X, Y, Z): + tl.store(Z, tl.load(X) + tl.load(Y)) + + _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_add, lambda x, y: x + y, device) + + +# mul overflow + + +@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ + (2**30, 4, 'int32', 'int32', False, False), + (2**30, 4, 'int32', 'int32', True, True), + (2**30, 2, 'int32', 'int32', True, True), + (-2**30, -4, 'int32', 'int32', True, True), + (-2**31, 1, 'int32', 'int32', True, False), + (-2**30, 2, 'int32', 'int32', True, False), +]) +@pytest.mark.forked +def test_sanitize_int_mul_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device): + + @triton.jit + def _kernel_mul(X, Y, Z): + tl.store(Z, tl.load(X) * tl.load(Y)) + + _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_mul, lambda x, y: x * y, device) + + +# sub overflow + + +@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ + (-2**31, 1, 'int32', 'int32', False, False), + (-2**31, 1, 'int32', 'int32', True, True), + (2**31 - 1, -1, 'int32', 'int32', True, True), + (2**31 - 1, 1, 'int32', 'int32', True, False), + (-2**31, -1, 'int32', 'int32', True, False), +]) +@pytest.mark.forked +def test_sanitize_int_sub_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device): + + @triton.jit + def _kernel_sub(X, Y, Z): + tl.store(Z, tl.load(X) - tl.load(Y)) + + _test_overflow(x, y, x_dtype, y_dtype, should_overflow, debug, _kernel_sub, lambda x, y: x - y, device) diff --git a/third_party/enflame/include/triton/python/test/unit/test_debug_dump.py b/third_party/enflame/include/triton/python/test/unit/test_debug_dump.py new file mode 100644 index 000000000..4f522941e --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/test_debug_dump.py @@ -0,0 +1,49 @@ +import os +from contextlib import contextmanager + +import torch +import triton +import triton.language as tl + + +@contextmanager +def enable_dump_context(pass_name="1"): + try: + os.environ["MLIR_ENABLE_DUMP"] = pass_name + yield + finally: + os.environ["MLIR_ENABLE_DUMP"] = "0" + + +def test_fn_dump(capfd, device, fresh_triton_cache): + N = 1024 + src = torch.zeros(N, device=device) + + grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]), ) + + @triton.jit + def _kernel(src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + 1 + tl.store(src + offsets, x, mask=offsets < N) + + with enable_dump_context(): + BLOCK_SIZE = 16 + _kernel[grid](src, N, BLOCK_SIZE) + captured = capfd.readouterr() + print(captured.err) + assert "IR Dump Before" in captured.err + assert "tt.func public @_kernel" in captured.err + + with enable_dump_context("_kernel"): + BLOCK_SIZE = 32 + _kernel[grid](src, N, BLOCK_SIZE) + captured = capfd.readouterr() + assert "IR Dump Before" in captured.err + assert "tt.func public @_kernel" in captured.err + + with enable_dump_context("_kernel2"): + BLOCK_SIZE = 64 + _kernel[grid](src, N, BLOCK_SIZE) + captured = capfd.readouterr() + assert "IR Dump Before" not in captured.err diff --git a/third_party/enflame/include/triton/python/test/unit/test_perf_warning.py b/third_party/enflame/include/triton/python/test/unit/test_perf_warning.py new file mode 100644 index 000000000..86bebdd71 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/test_perf_warning.py @@ -0,0 +1,188 @@ +import os +from contextlib import contextmanager + +import pytest +import torch +import triton +import triton.language as tl + + +@contextmanager +def enable_diagnostics_context(value): + try: + os.environ["MLIR_ENABLE_DIAGNOSTICS"] = value + yield + finally: + os.environ["MLIR_ENABLE_DIAGNOSTICS"] = "" + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def test_mma_remark(capfd, fresh_triton_cache): + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability[0] != 9: + pytest.skip("Requires sm = 90 to run") + + @triton.jit + def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + ): + a_block_ptr = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(0, 0), + block_shape=(32, 128), + order=(1, 0), + ) + b_block_ptr = tl.make_block_ptr( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, 0), + block_shape=(128, 32), + order=(0, 1), + ) + c_block_ptr = tl.make_block_ptr( + base=c_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + offsets=(0, 0), + block_shape=(32, 32), + order=(1, 0), + ) + a = tl.load(a_block_ptr) + b = tl.load(b_block_ptr) + c = tl.dot(a, b) + tl.store(c_block_ptr, c) + + signature = { + "a_ptr": "*fp32", + "b_ptr": "*fp32", + "c_ptr": "*fp32", + "M": "i32", + "N": "i32", + "K": "i32", + "stride_am": "i32", + "stride_ak": "i32", + "stride_bk": "i32", + "stride_bn": "i32", + "stride_cm": "i32", + "stride_cn": "i32", + } + with enable_diagnostics_context('remarks'): + triton.compile(triton.compiler.ASTSource( + fn=matmul_kernel, + signature=signature, + constexprs={}, + )) + captured = capfd.readouterr() + + assert ("can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark" + assert "note: see current operation:" not in captured.err + + with enable_diagnostics_context('remarks,operations,stacktraces'): + triton.compile(triton.compiler.ASTSource( + fn=matmul_kernel, + signature=signature, + constexprs={}, + )) + captured = capfd.readouterr() + assert "note: diagnostic emitted with trace:" in captured.err + assert "note: see current operation:" in captured.err + + +def test_remark_vectorization(capfd, fresh_triton_cache): + + @triton.jit + def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + x0 = xindex % 9 + x2 = (xindex // 3456) % 512 + x1 = (xindex // 9) % 384 + x4 = xindex + tmp0 = tl.load(in_ptr0 + (x2 + (512 * x0)), None, eviction_policy="evict_last") + tmp1 = tmp0 + 520 + tmp2 = tmp0 < 0 + tmp3 = tl.where(tmp2, tmp1, tmp0) + tmp9 = (-4) + tmp3 + tmp12 = tl.full([1], 512, tl.int64) + tmp14 = tmp9 < tmp12 + tmp16 = tl.load(in_ptr3 + (x1), tmp14, eviction_policy="evict_last", other=0.0) + tmp18 = tmp16.to(tl.float32) + tmp19 = tmp18.to(tl.float32) + tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype) + tmp21 = tl.where(tmp14, tmp19, tmp20) + tmp22 = tmp21.to(tl.float32) + tl.store(out_ptr0 + (x4), tmp22, None) + + XBLOCK = 1024 + + astsource_args = { + "fn": ldst_vec, + "signature": { + "in_ptr0": "*i64", + "in_ptr1": "*i64", + "in_ptr2": "*fp16", + "in_ptr3": "*fp32", + "out_ptr0": "*fp16", + "XBLOCK": "constexpr", + }, + "constexprs": {"XBLOCK": XBLOCK}, + } + + with enable_diagnostics_context('remarks'): + triton.compile( + triton.compiler.ASTSource(**astsource_args), + options={"num_warps": 1}, + ) + + _, err = capfd.readouterr() + assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark" + assert "note: see current operation:" not in err + + with enable_diagnostics_context('remarks,operations,stacktraces'): + triton.compile( + triton.compiler.ASTSource(**astsource_args), + options={"num_warps": 1}, + ) + + _, err = capfd.readouterr() + assert "note: see current operation:" in err + assert "note: diagnostic emitted with trace:" in err + + +def test_remark_swp_op_before_operands(capfd, fresh_triton_cache): + + @triton.jit + def kernel_pipe_error(in_ptr, out_ptr): + SIZE: tl.constexpr = 64 + in_ptrs = in_ptr + tl.arange(0, SIZE) + val = tl.zeros((SIZE, ), dtype=tl.float32) + k = 0 + for i in tl.range(0, 64, num_stages=3): + in_ptrs = in_ptr + tl.arange(0, SIZE) + SIZE * k + val = tl.load(in_ptrs) + out_ptrs = out_ptr + (tl.arange(0, SIZE) + i * SIZE) + tl.store(out_ptrs, val) + if tl.max(val) > 0: + k += 1 + + i = torch.empty(64 * 64, dtype=torch.float32).cuda() + o = torch.empty(64 * 64, dtype=torch.float32).cuda() + kernel_pipe_error[(1, )](i, o) diff --git a/third_party/enflame/include/triton/python/test/unit/tools/test_aot.py b/third_party/enflame/include/triton/python/test/unit/tools/test_aot.py new file mode 100644 index 000000000..d80c79cf6 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/tools/test_aot.py @@ -0,0 +1,442 @@ +import glob +import os +import subprocess +import sys +import tempfile + +import numpy as np + +import triton +from triton.backends.compiler import GPUTarget +from triton.backends.nvidia.driver import include_dir, library_dirs + +kernel_utils_src = """ +import triton + +@triton.jit +def mul(x, y): + return x * y +""" + +kernel_src = """ +import triton +import triton.language as tl +import kernel_utils + +@triton.jit +def kernel(C, A, B, M, N, K, + stride_cm, stride_cn, + stride_am, stride_ak, + stride_bk, stride_bn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + c = kernel_utils.mul(accumulator, accumulator) + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, c) +""" + +test_utils_src = """ +#include +#include +#include +#include +#include +#include "kernel.h" + +static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) { + FILE *file = fopen(filename, "w"); + if (file == NULL) { + printf("Could not open file %s\\n", filename); + return; + } + for (int i = 0; i < size; i++) { + fprintf(file, "%d", buffer[i]); + if (i < size - 1) { + fprintf(file, ","); + } + } + fclose(file); +} + +static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) { + FILE *file = fopen(filename, "r"); + if (file == NULL) { + printf("Could not open file %s\\n", filename); + return; + } + int index = 0; + while (fscanf(file, "%hd,", &buffer[index]) != EOF && index < size) { + index++; + } + fclose(file); +}""" + + +def gen_kernel_library(dir, libname): + c_files = glob.glob(os.path.join(dir, "*.c")) + subprocess.run( + ["gcc"] + c_files + ["-I", include_dir[0], "-c", "-fPIC"], + check=True, + cwd=dir, + ) + o_files = glob.glob(os.path.join(dir, "*.o")) + + command = ["gcc", *o_files, "-shared", "-o", libname] + for lib_dir in library_dirs(): + command.extend(["-L", lib_dir]) + subprocess.run(command, check=True, cwd=dir) + + +def gen_test_bin(dir, M, N, K, exe="test", algo_id=0): + test_src = f""" +int main(int argc, char **argv) {{ + int M = {M}, N = {N}, K = {K}; + + // initialize CUDA handles + CUdevice dev; + CUcontext ctx; + CUstream stream; + CUdeviceptr A, B, C; + CUresult err = 0; + cuInit(0); + cuDeviceGet(&dev, 0); + cuCtxCreate(&ctx, 0, dev); + cuMemAlloc(&A, M * K * 2); + cuMemAlloc(&B, K * N * 2); + cuMemAlloc(&C, M * N * 4); + cuStreamCreate(&stream, 0); + load_matmul_fp16(); + + // initialize input data + int16_t hA[M*K]; + int16_t hB[K*N]; + memset(hA, 0, M*K*2); + memset(hB, 0, K*N*2); + read_csv_to_buffer(argv[1], hA, M*K); + read_csv_to_buffer(argv[2], hB, K*N); + cuMemcpyHtoD(A, hA, M*K*2); + cuMemcpyHtoD(B, hB, K*N*2); + + // launch kernel + cuStreamSynchronize(stream); + CUresult ret; + int algo_id = {algo_id}; + if (algo_id == 0) {{ + ret = matmul_fp16_default(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1); + }} else {{ + ret = matmul_fp16(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1, {algo_id}); + }} + if (ret != 0) fprintf(stderr, "kernel launch failed\\n"); + assert(ret == 0); + + cuStreamSynchronize(stream); + + // read data + int32_t hC[M*N]; + memset(hC, 0, M*N*4); + cuMemcpyDtoH(hC, C, M*N*4); + write_buffer_to_csv(argv[3], hC, M*N); + + // free cuda handles + unload_matmul_fp16(); + cuMemFree(A); + cuMemFree(B); + cuMemFree(C); + cuCtxDestroy(ctx); +}} +""" + src = test_utils_src + test_src + with open(os.path.join(dir, "test.c"), "w") as file: + file.write(src) + + command = ["gcc", "test.c"] + for inc_dir in include_dir: + command.extend(["-I", inc_dir]) + for lib_dir in library_dirs(): + command.extend(["-L", lib_dir]) + command.extend(["-l", "cuda", "-L", dir, "-l", "kernel", "-o", exe]) + subprocess.run(command, check=True, cwd=dir) + + +def write_triton_kernels(dir, src, util_src): + kernel_path = os.path.join(dir, "kernel.py") + with open(kernel_path, "w") as file: + file.write(src) + + kernel_utils_path = os.path.join(dir, "kernel_utils.py") + with open(kernel_utils_path, "w") as file: + file.write(util_src) + + return kernel_path + + +def _compile_kernel(dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path): + compiler_path = os.path.join(triton.tools.__path__[0], "compile.py") + + subprocess.run( + [ + sys.executable, + compiler_path, + "-n", + kernel_name, + "--signature", + signature, + "--out-name", + out_name, + "-o", + out_path, + "-w", + str(num_warps), + "-g", + grid, + kernel_path, + ], + check=True, + cwd=dir, + ) + + +# Edge case kernel with no specialization +def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK): + # compile all desired configs + sig = f"*fp32, *{dtype}, *{dtype}, i32, i32, i32, i32, i32, i32, i32, i32, i32, {BM}, {BN}, {BK}" + name = f"matmul_{dtype}" + grid = f"M/{BM}, N/{BN}, 1" + _compile_kernel( + dir=dir, + signature=sig, + kernel_name="kernel", + out_name=name, + out_path=name, + num_warps=1, + grid=grid, + kernel_path=kernel_path, + ) + + +def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints): + # compile all desired configs + for ha in ha_hb_hints: + for hb in ha_hb_hints: + sig = f"*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}" + name = f"matmul_{dtype}" + grid = f"M/{BM}, N/{BN}, 1" + _compile_kernel( + dir=dir, + signature=sig, + kernel_name="kernel", + out_name=name, + out_path=name, + num_warps=1, + grid=grid, + kernel_path=kernel_path, + ) + + +def link_aot_kernels(dir): + linker_path = os.path.join(triton.tools.__path__[0], "link.py") + + # link all desired configs + h_files = glob.glob(os.path.join(dir, "*.h")) + subprocess.run([sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=dir) + + +def generate_matmul_test_data(dir, M, N, K): + a = np.random.randn(M * K).astype(np.float16).reshape((M, K)) + b = np.random.randn(M * K).astype(np.float16).reshape((K, N)) + a_path = os.path.join(dir, "a.csv") + b_path = os.path.join(dir, "b.csv") + c_path = os.path.join(dir, "c.csv") + for x, path in [(a, a_path), (b, b_path)]: + x.view(np.int16).ravel().tofile(path, sep=",") + return a, b, a_path, b_path, c_path + + +# Test edge case where the provided kernel signature has no specializations +def test_compile_link_matmul_no_specialization(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + dtype = "fp16" + BM, BN, BK = 16, 16, 16 + + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK) + link_aot_kernels(tmp_dir) + + # compile test case + M, N, K = 16, 16, 16 + gen_kernel_library(tmp_dir, "libkernel.so") + gen_test_bin(tmp_dir, M, N, K) + + # initialize test data + a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) + + # run test case + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir) + + # read data and compare against reference + c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) + c_tri = c.reshape((M, N)).view(np.float32) + c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0) + + +def test_compile_link_matmul(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + dtype = "fp16" + BM, BN, BK = 16, 16, 16 + + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]) + link_aot_kernels(tmp_dir) + + # compile test case + M, N, K = 16, 16, 16 + gen_kernel_library(tmp_dir, "libkernel.so") + gen_test_bin(tmp_dir, M, N, K) + + # initialize test data + a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) + + # run test case + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir) + + # read data and compare against reference + c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) + c_tri = c.reshape((M, N)).view(np.float32) + c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0) + + +def test_launcher_has_no_available_kernel(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + dtype = "fp16" + BM, BN, BK = 16, 16, 16 + + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[":1"]) + link_aot_kernels(tmp_dir) + + # compile test case + M, N, K = 16, 16, 16 + gen_kernel_library(tmp_dir, "libkernel.so") + gen_test_bin(tmp_dir, M, N, K) + + # initialize test data + a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) + + # run test case + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + result = subprocess.run( + ["./test", a_path, b_path, c_path], + env=env, + cwd=tmp_dir, + capture_output=True, + text=True, + ) + + # It should fail since the launcher requires all the strides be 1 while they are not. + assert result.returncode == -6 + assert "kernel launch failed" in result.stderr + + +def test_compile_link_autotune_matmul(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + dtype = "fp16" + + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + + tile_sizes = [ + [16, 16, 16], + [32, 32, 16], + [32, 32, 32], + [64, 64, 32], + ] + + for ts in tile_sizes: + BM, BN, BK = ts[0], ts[1], ts[2] + compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]) + + link_aot_kernels(tmp_dir) + + gen_kernel_library(tmp_dir, "libkernel.so") + + # compile test case + M, N, K = 64, 64, 64 + # initialize test data + a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) + c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) + + for algo_id in range(len(tile_sizes)): + # generate and run test case + test_name = f"test_{algo_id}" + gen_test_bin(tmp_dir, M, N, K, exe=test_name, algo_id=algo_id) + + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + subprocess.run( + [f"./{test_name}", a_path, b_path, c_path], + check=True, + cwd=tmp_dir, + env=env, + ) + + # read data and compare against reference + c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) + c_tri = c.reshape((M, N)).view(np.float32) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=1e-4) + + +def test_ttgir_to_ptx(): + src = """ +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { + tt.return + } +} +""" + with tempfile.TemporaryDirectory() as tmp_dir: + kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir") + with open(kernel_path, "w") as fp: + fp.write(src) + k = triton.compile(kernel_path, target=GPUTarget("cuda", 80, 32)) + ptx = k.asm["ptx"] + assert ".target sm_80" in ptx + assert ".address_size 64" in ptx diff --git a/third_party/enflame/include/triton/python/test/unit/tools/test_disasm.py b/third_party/enflame/include/triton/python/test/unit/tools/test_disasm.py new file mode 100644 index 000000000..cc4982706 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/tools/test_disasm.py @@ -0,0 +1,21 @@ +import torch + +import triton +import pytest +import triton.language as tl + + +def test_disam_cubin(): + if not triton.runtime.driver.active.get_current_target().backend == "cuda": + pytest.skip("Test requires CUDA.") + + @triton.jit + def kernel(X, i: tl.constexpr): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + h = kernel[(1, )](x, i=12) + assert x[0] == 12 + sass = h.asm["sass"] + # check that the sass has a store instruction. + assert "STG.E" in sass diff --git a/third_party/enflame/include/triton/python/test/unit/tools/test_irsource.py b/third_party/enflame/include/triton/python/test/unit/tools/test_irsource.py new file mode 100644 index 000000000..0c0f25ce7 --- /dev/null +++ b/third_party/enflame/include/triton/python/test/unit/tools/test_irsource.py @@ -0,0 +1,92 @@ +import pathlib +import triton +from triton.compiler import IRSource, make_backend +from triton._C.libtriton import ir + +target = triton.runtime.driver.active.get_current_target() +backend = make_backend(target) + + +def test_mlir_attribute_parsing(tmp_path: pathlib.Path) -> None: + ''' + Tests that MLIR attributes are parsed correctly from input ttir/ttgir. + + Checks for the following: + 1. Name and type signature are parsed correctly + 2. _get_num_warps_from_ir_str() works + 3. tt.nv_tma_desc attribute is parsed correctly + ''' + + sample_ttgir = r""" +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}, + %arg4: i32 {tt.divisibility = 16 : i32}, + %arg5: i32 {tt.divisibility = 16 : i32}, + %arg6: i32 {tt.divisibility = 16 : i32}, + %arg7: i32 {tt.divisibility = 16 : i32}, + %arg8: i32 {tt.divisibility = 16 : i32, tt.nv_tma_desc = 0 : i32}, + %desc: !tt.ptr {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} { + tt.return + } +} +""" + temp_file = tmp_path / "test_mlir_attribute_parsing0.ttgir" + temp_file.write_text(sample_ttgir) + context = ir.context() + src = IRSource(str(temp_file), context, backend) + + # check name and type signature + # should match ty_to_cpp(...) + assert src.signature == \ + {0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \ + 4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"} + assert src.name == "@matmul_kernel" + + # check num warps + assert src.parse_options()['num_warps'] == 8 + + sample_ttgir_vector_add = r""" + #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}) + attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr, #blocked> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr, #blocked> + %13 = arith.addi %9, %12 : tensor<1024xi32, #blocked> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %15, %13, %6 : tensor<1024x!tt.ptr, #blocked> + tt.return + } + } + """ + temp_file = tmp_path / "test_mlir_attribute_parsing1.ttgir" + temp_file.write_text(sample_ttgir_vector_add) + context = ir.context() + src = IRSource(str(temp_file), context, backend) + + # now test compilation + triton.compile(str(temp_file), target=target) diff --git a/third_party/enflame/include/triton/python/triton/_C/include b/third_party/enflame/include/triton/python/triton/_C/include new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/include/triton/python/triton/__init__.py b/third_party/enflame/include/triton/python/triton/__init__.py new file mode 100644 index 000000000..46f198d18 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/__init__.py @@ -0,0 +1,73 @@ +"""isort:skip_file""" +__version__ = '3.3.1' + +# --------------------------------------- +# Note: import order is significant here. + +# submodules +from .runtime import ( + autotune, + Config, + heuristics, + JITFunction, + KernelInterface, + reinterpret, + TensorWrapper, + OutOfResources, + InterpreterError, + MockTensor, +) +from .runtime.jit import jit +from .compiler import compile, CompilationError +from .errors import TritonError +from .runtime._allocation import set_allocator + +from . import language +from . import testing +from . import tools + +__all__ = [ + "autotune", + "cdiv", + "CompilationError", + "compile", + "Config", + "heuristics", + "InterpreterError", + "jit", + "JITFunction", + "KernelInterface", + "language", + "MockTensor", + "next_power_of_2", + "OutOfResources", + "reinterpret", + "runtime", + "set_allocator", + "TensorWrapper", + "TritonError", + "testing", + "tools", +] + +# ------------------------------------- +# misc. utilities that don't fit well +# into any specific module +# ------------------------------------- + + +def cdiv(x: int, y: int): + return (x + y - 1) // y + + +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n diff --git a/third_party/enflame/include/triton/python/triton/_internal_testing.py b/third_party/enflame/include/triton/python/triton/_internal_testing.py new file mode 100644 index 000000000..dbc2d0179 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/_internal_testing.py @@ -0,0 +1,178 @@ +import os +import re +import numpy as np +import torch +import triton +import triton.language as tl +from triton.backends.nvidia.compiler import _path_to_binary +import pytest + +from numpy.random import RandomState +from typing import Optional, Union +from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict + +int_dtypes = ['int8', 'int16', 'int32', 'int64'] +uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] +integral_dtypes = int_dtypes + uint_dtypes +float_dtypes = ['float16', 'float32', 'float64'] +float_dtypes_with_bfloat16 = float_dtypes + ['bfloat16'] +dtypes = integral_dtypes + float_dtypes +dtypes_with_bfloat16 = dtypes + ['bfloat16'] +torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2'] +torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def get_current_target(): + if is_interpreter(): + return None + return triton.runtime.driver.active.get_current_target() + + +def is_cuda(): + target = get_current_target() + return False if target is None else target.backend == "cuda" + + +def is_hopper(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + +def is_hip(): + target = get_current_target() + return False if target is None else target.backend == "hip" + + +def is_hip_mi200(): + target = get_current_target() + if target is None or target.backend != 'hip': + return False + return target.arch == 'gfx90a' + + +def is_hip_mi300(): + target = get_current_target() + if target is None or target.backend != 'hip': + return False + return target.arch in ('gfx940', 'gfx941', 'gfx942') + + +def is_hip_mi350(): + target = get_current_target() + if target is None or target.backend != 'hip': + return False + return target.arch in ('gfx950') + + +def is_hip_cdna(): + return is_hip_mi200() or is_hip_mi300() or is_hip_mi350() + + +def is_xpu(): + target = get_current_target() + return False if target is None else target.backend == "xpu" + + +def get_arch(): + target = get_current_target() + return "" if target is None else str(target.arch) + + +def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): + """ + Override `rs` if you're calling this function twice and don't want the same + result for both calls. + """ + if isinstance(shape, int): + shape = (shape, ) + if rs is None: + rs = RandomState(seed=17) + if dtype_str in int_dtypes + uint_dtypes: + iinfo = np.iinfo(getattr(np, dtype_str)) + low = iinfo.min if low is None else max(low, iinfo.min) + high = iinfo.max if high is None else min(high, iinfo.max) + dtype = getattr(np, dtype_str) + x = rs.randint(low, high, shape, dtype=dtype) + x[x == 0] = 1 # Workaround. Never return zero so tests of division don't error out. + return x + elif dtype_str and 'float8' in dtype_str: + x = rs.randint(20, 40, shape, dtype=np.int8) + return x + elif dtype_str in float_dtypes: + return rs.normal(0, 1, shape).astype(dtype_str) + elif dtype_str == 'bfloat16': + return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') + elif dtype_str in ['bool', 'int1', 'bool_']: + return rs.normal(0, 1, shape) > 0.0 + else: + raise RuntimeError(f'Unknown dtype {dtype_str}') + + +def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]: + ''' + Note: We need dst_type because the type of x can be different from dst_type. + For example: x is of type `float32`, dst_type is `bfloat16`. + If dst_type is None, we infer dst_type from x. + ''' + t = x.dtype.name + if t in uint_dtypes: + signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" + x_signed = x.astype(getattr(np, signed_type_name)) + return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) + else: + if dst_type and 'float8' in dst_type: + return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type)) + if t == 'float32' and dst_type == 'bfloat16': + return torch.tensor(x, device=device).bfloat16() + return torch.tensor(x, device=device) + + +def str_to_triton_dtype(x: str) -> tl.dtype: + return tl.str_to_ty(type_canonicalisation_dict[x]) + + +def torch_dtype_name(dtype) -> str: + if isinstance(dtype, triton.language.dtype): + return dtype.name + elif isinstance(dtype, torch.dtype): + # 'torch.int64' -> 'int64' + m = re.match(r'^torch\.(\w+)$', str(dtype)) + return m.group(1) + else: + raise TypeError(f'not a triton or torch dtype: {type(dtype)}') + + +def to_numpy(x): + if isinstance(x, TensorWrapper): + return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) + elif isinstance(x, torch.Tensor): + if x.dtype is torch.bfloat16: + return x.cpu().float().numpy() + return x.cpu().numpy() + else: + raise ValueError(f"Not a triton-compatible tensor: {x}") + + +def supports_tma(byval_only=False): + if is_interpreter(): + return True + if not is_cuda(): + return False + _, cuda_version = _path_to_binary("ptxas") + min_cuda_version = (12, 0) if byval_only else (12, 3) + cuda_version_tuple = tuple(map(int, cuda_version.split("."))) + assert len(cuda_version_tuple) == 2, cuda_version_tuple + return torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version + + +def tma_skip_msg(byval_only=False): + if byval_only: + return "Requires __grid_constant__ TMA support (NVIDIA Hopper or higher, CUDA 12.0 or higher)" + else: + return "Requires advanced TMA support (NVIDIA Hopper or higher, CUDA 12.3 or higher)" + + +requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg()) diff --git a/third_party/enflame/include/triton/python/triton/_utils.py b/third_party/enflame/include/triton/python/triton/_utils.py new file mode 100644 index 000000000..e89894a1e --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/_utils.py @@ -0,0 +1,35 @@ +from functools import reduce + + +def get_iterable_path(iterable, path): + return reduce(lambda a, idx: a[idx], path, iterable) + + +def set_iterable_path(iterable, path, val): + prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1]) + prev[path[-1]] = val + + +def find_paths_if(iterable, pred): + from .language import core + is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type)) + ret = dict() + + def _impl(current, path): + path = (path[0], ) if len(path) == 1 else tuple(path) + if is_iterable(current): + for idx, item in enumerate(current): + _impl(item, path + (idx, )) + elif pred(path, current): + if len(path) == 1: + ret[(path[0], )] = None + else: + ret[tuple(path)] = None + + if is_iterable(iterable): + _impl(iterable, []) + elif pred(list(), iterable): + ret = {tuple(): None} + else: + ret = dict() + return list(ret.keys()) diff --git a/third_party/enflame/include/triton/python/triton/backends/__init__.py b/third_party/enflame/include/triton/python/triton/backends/__init__.py new file mode 100644 index 000000000..92ba144ba --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/backends/__init__.py @@ -0,0 +1,50 @@ +import os +import importlib.util +import inspect +from dataclasses import dataclass +from .driver import DriverBase +from .compiler import BaseBackend + + +def _load_module(name, path): + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _find_concrete_subclasses(module, base_class): + ret = [] + for attr_name in dir(module): + attr = getattr(module, attr_name) + if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr): + ret.append(attr) + if len(ret) == 0: + raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}") + if len(ret) > 1: + raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}") + return ret[0] + + +@dataclass(frozen=True) +class Backend: + compiler: BaseBackend = None + driver: DriverBase = None + + +def _discover_backends(): + backends = dict() + root = os.path.dirname(__file__) + for name in os.listdir(root): + if not os.path.isdir(os.path.join(root, name)): + continue + if name.startswith('__'): + continue + compiler = _load_module(name, os.path.join(root, name, 'compiler.py')) + driver = _load_module(name, os.path.join(root, name, 'driver.py')) + backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), + _find_concrete_subclasses(driver, DriverBase)) + return backends + + +backends = _discover_backends() diff --git a/third_party/enflame/include/triton/python/triton/backends/compiler.py b/third_party/enflame/include/triton/python/triton/backends/compiler.py new file mode 100644 index 000000000..3583429de --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/backends/compiler.py @@ -0,0 +1,104 @@ +import os +import re +import subprocess +import sysconfig +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import Dict, Union +from types import ModuleType + + +@dataclass(frozen=True) +class GPUTarget(object): + # Target backend, e.g., cuda, hip + backend: str + # Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip) + arch: Union[int, str] + warp_size: int + + +class BaseBackend(metaclass=ABCMeta): + + def __init__(self, target: GPUTarget) -> None: + self.target = target + assert self.supports_target(target) + + @staticmethod + def _path_to_binary(binary: str): + binary += sysconfig.get_config_var("EXE") + base_dir = os.path.join(os.path.dirname(__file__), os.pardir) + paths = [ + os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), + os.path.join(base_dir, "third_party", "cuda", "bin", binary), + ] + for path in paths: + if os.path.exists(path) and os.path.isfile(path): + result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT) + if result is not None: + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is not None: + return path, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + @classmethod + @abstractmethod + def supports_target(target: GPUTarget): + raise NotImplementedError + + @abstractmethod + def hash(self) -> str: + """Returns a unique identifier for this backend""" + raise NotImplementedError + + @abstractmethod + def parse_options(self, options: dict) -> object: + """ + Converts an `options` dictionary into an arbitrary object and returns it. + This function may contain target-specific heuristics and check the legality of the provided options + """ + raise NotImplementedError + + @abstractmethod + def add_stages(self, stages: dict, options: object) -> None: + """ + Populates `stages` dictionary with entries of the form: + ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes] + The value of each entry may populate a `metadata` dictionary. + Stages will be run sequentially (in inseriton order) and can communicate using `metadata`. + All stages are expected to return a `str` object, except for the last stage which returns + a `bytes` object for execution by the launcher. + """ + raise NotImplementedError + + @abstractmethod + def load_dialects(self, context): + """ + Load additional MLIR dialects into the provided `context` + """ + raise NotImplementedError + + @abstractmethod + def get_module_map(self) -> Dict[str, ModuleType]: + """ + Return a map of interface modules to their device-specific implementations + """ + raise NotImplementedError + + @staticmethod + def parse_attr(desc): + assert isinstance(desc, str) + ret = [] + if "D" in desc: + ret += [["tt.divisibility", 16]] + return ret + + @staticmethod + def get_arg_specialization(arg, ty, **kwargs): + """ + Return a string unique to each possible specialization of the argument + """ + if ty == "int" and arg % 16 == 0 and kwargs.get("align", False): + return "D" + if ty == "tensor" and arg.data_ptr() % 16 == 0 and kwargs.get("align", False): + return "D" + return "" diff --git a/third_party/enflame/include/triton/python/triton/backends/driver.py b/third_party/enflame/include/triton/python/triton/backends/driver.py new file mode 100644 index 000000000..6606b21ca --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/backends/driver.py @@ -0,0 +1,53 @@ +from abc import ABCMeta, abstractmethod +from typing import Callable, List, Protocol, Sequence + + +class Benchmarker(Protocol): + + def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]: + pass + + +class DriverBase(metaclass=ABCMeta): + + @classmethod + @abstractmethod + def is_active(self): + pass + + @abstractmethod + def get_current_target(self): + pass + + @abstractmethod + def get_active_torch_device(self): + pass + + @abstractmethod + def get_benchmarker(self) -> Benchmarker: + """ + Return the benchmarking function that this backend should use by default. + """ + raise NotImplementedError + + def __init__(self) -> None: + pass + + +class GPUDriver(DriverBase): + + def __init__(self): + # TODO: support other frameworks than torch + import torch + self.get_device_capability = torch.cuda.get_device_capability + try: + from torch._C import _cuda_getCurrentRawStream + self.get_current_stream = _cuda_getCurrentRawStream + except ImportError: + self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream + self.get_current_device = torch.cuda.current_device + self.set_current_device = torch.cuda.set_device + + # TODO: remove once TMA is cleaned up + def assemble_tensormap_to_arg(self, tensormaps_info, args): + return args diff --git a/third_party/enflame/include/triton/python/triton/compiler/__init__.py b/third_party/enflame/include/triton/python/triton/compiler/__init__.py new file mode 100644 index 000000000..f055926fa --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/compiler/__init__.py @@ -0,0 +1,4 @@ +from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict +from .errors import CompilationError + +__all__ = ["compile", "make_backend", "ASTSource", "IRSource", "CompiledKernel", "CompilationError", "LazyDict"] diff --git a/third_party/enflame/include/triton/python/triton/compiler/code_generator.py b/third_party/enflame/include/triton/python/triton/compiler/code_generator.py new file mode 100644 index 000000000..9b1847957 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/compiler/code_generator.py @@ -0,0 +1,1406 @@ +import ast +import inspect +import re +import warnings +import os +import textwrap +import itertools +from types import ModuleType +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List + +from .. import language +from .._C.libtriton import ir +from ..language import constexpr, semantic, str_to_ty, tensor +from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, base_value, base_type +from ..runtime.jit import get_jit_fn_file_line +# ideally we wouldn't need any runtime component +from ..runtime import JITFunction +from .._utils import find_paths_if, get_iterable_path, set_iterable_path + +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) + + +def check_identifier_legality(name, type): + pattern = r'^[a-zA-Z_][a-zA-Z0-9_]*$' + if not re.match(pattern, name): + raise CompilationError(f"invalid {type} identifier: {name}", name) + return name + + +def mangle_ty(ty): + if ty.is_tuple(): + return 'T' + '_'.join(map(mangle_ty, ty.types)) + 'T' + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + SIGNED = language.dtype.SIGNEDNESS.SIGNED + prefix = 'i' if ty.int_signedness == SIGNED else 'u' + return prefix + str(ty.int_bitwidth) + if ty.is_floating(): + return str(ty) + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + if ty.is_void(): + return 'V' + raise TypeError(f'Unsupported type {ty}') + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + +def _is_triton_value(o: Any) -> bool: + return isinstance(o, base_value) + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return o is None or isinstance(o, (constexpr, language.core.dtype)) + + +def _is_triton_scalar(o: Any) -> bool: + return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) + + +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and not _is_triton_scalar(arg): + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) + + +def _is_namedtuple(val): + return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields") + + +def _apply_to_tuple_values(value, fn): + if _is_namedtuple(type(value)): + fields = value._fields + elif isinstance(value, language.tuple): + fields = value.type.fields + else: + assert False, f"Unsupported type {type(value)}" + + vals = [fn(v) for v in value] + types = [v.type for v in vals] + return language.tuple(vals, language.tuple_type(types, fields)) + + +def flatten_values_to_ir(values: Iterable[base_value]): + handles = [] + for v in values: + v._flatten_ir(handles) + return handles + + +def unflatten_ir_values(handles: List[ir.value], types: List[base_type]): + cursor = 0 + for ty in types: + value, cursor = ty._unflatten_ir(handles, cursor) + yield value + assert cursor == len(handles) + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels + + +class enter_sub_region: + + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return +class ContainsReturnChecker(ast.NodeVisitor): + + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + return any(self.visit(s) for s in body) + + def _visit_function(self, fn) -> bool: + # Currently we only support JITFunctions defined in the global scope + if isinstance(fn, JITFunction) and not fn.noinline: + fn_node = fn.parse() + return ContainsReturnChecker(self.gscope).visit(fn_node) + return False + + def generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) is ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... + return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class ASTFunction: + + def __init__(self, ret_types, arg_types, constants, attrs): + self.ret_types = ret_types + self.arg_types = arg_types + self.constants = constants + self.attrs = attrs + + def return_types_ir(self, builder: ir.builder): + ret_types = [] + for ret_ty in self.ret_types: + if ret_ty is None: + continue + ir_ty = ret_ty.to_ir(builder) + if isinstance(ir_ty, list): + ret_types.extend(ir_ty) + else: + ret_types.append(ir_ty) + return ret_types + + def serialize(self, builder: ir.builder): + # fill up IR values in template + # > build function + is_val = lambda path, _: path not in self.constants and _ is not None + val_paths = list(find_paths_if(self.arg_types, is_val)) + arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths] + ret_types = self.return_types_ir(builder) + return builder.get_function_ty(arg_types, ret_types) + + def deserialize(self, fn): + # create "template" + def make_template(ty): + if isinstance(ty, (list, tuple, language.tuple_type)): + return language.tuple([make_template(x) for x in ty], ty) + return language.constexpr(None) + + vals = make_template(self.arg_types) + is_val = lambda path, _: path not in self.constants and _ is not None + val_paths = list(find_paths_if(self.arg_types, is_val)) + # > set attributes + for attr_path, attr_specs in self.attrs.items(): + for attr_name, attr_val in attr_specs: + if attr_path in val_paths: + fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val) + for i, path in enumerate(val_paths): + ty = get_iterable_path(self.arg_types, path) + if isinstance(ty, nv_tma_desc_type): + fn.set_arg_attr(i, "tt.nv_tma_desc", 1) + # > add IR values to the template + for i, path in enumerate(val_paths): + ty = get_iterable_path(self.arg_types, path) + set_iterable_path(vals, path, language.tensor(fn.args(i), ty)) + # > add constexpr values to the template + constants = self.constants + for path, val in constants.items(): + set_iterable_path(vals, path, language.constexpr(val)) + return vals + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map, + module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False, + file_name: Optional[str] = None, begin_line=0): + self.context = context + self.builder = ir.builder(context) + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) + self.builder.codegen_fns = codegen_fns + self.builder.module_map = {} if module_map is None else module_map + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + + self.gscope = {} + for k, v in gscope.items(): + if isinstance(v, ModuleType): + self.gscope[k] = module_map.get(v.__name__, v) + continue + + module_name = getattr(v, "__module__", "") + if module_name in module_map: + self.gscope[k] = getattr(module_map[module_name], v.__name__) + else: + self.gscope[k] = v + + self.lscope = {} + self.jit_fn = jit_fn + # TODO: we currently generate illegal names for non-kernel functions involving constexprs! + if is_kernel: + function_name = check_identifier_legality(function_name, "function") + self.function_name = function_name + self.is_kernel = is_kernel + self.cur_node = None + self.noinline = noinline + self.scf_stack = [] + self.ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + self.fn = None + # Are we currently visiting an ast.arg's default value? These have some + # special handling. + self.visiting_arg_default_value = False + + builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.minimum), + ('max', language.maximum), + )) + + def _unsupported(self, node, message): + return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) + + def _is_constexpr_global(self, name): + absent_marker = object() + val = self.gscope.get(name, absent_marker) + if val is absent_marker: + return False + + if _is_constexpr(val): + return True + + return False + + def _define_name_lookup(self): + + def local_lookup(name: str, absent): + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + return self.lscope.get(name, absent) + + def global_lookup(name: str, absent): + val = self.gscope.get(name, absent) + # The high-level rule is that only constexpr globals are allowed. + # But actually a bunch of other things, such as module imports, are + # technically Python globals. We have to allow these too! + if any([ + val is absent, name in self.builtin_namespace, # + type(val) is ModuleType, # + isinstance(val, JITFunction), # + getattr(val, "__triton_builtin__", False), # + getattr(val, "__module__", "").startswith("triton.language"), # + isinstance(val, language.dtype), # + _is_namedtuple(val), + self._is_constexpr_global(name), # + # Allow accesses to globals while visiting an ast.arg + # because you should be able to do + # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... + self.visiting_arg_default_value, # + os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1" + ]): + return val + raise NameError( + textwrap.dedent(f"""\ + Cannot access global variable {name} from within @jit'ed + function. Triton kernels can only access global variables that + are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from + annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported. Alternatively, set the + envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not + promise to support this forever.""").replace("\n", " ")) + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + def set_value(self, name: str, value: Union[base_value, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self): + # XXX: this is a hack to get the location of the insertion point. + # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + loc = self.builder.get_loc() + ip = self.builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc): + self.builder.restore_insertion_point(ip) + self.builder.set_loc(loc) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] + for stmt in stmts: + self.visit(stmt) + # Stop parsing as soon as we hit a `return` statement; everything + # after this is dead code. + if isinstance(stmt, ast.Return): + break + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = language.tuple([self.visit(elt) for elt in node.elts]) + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + handles = [] + + def decay(value): + if isinstance(value, language.tuple): + return _apply_to_tuple_values(value, decay) + elif isinstance(value, (language.constexpr, int, float)): + return semantic.to_tensor(value, self.builder) + return value + + ret_value = decay(ret_value) + + if ret_value is None: + ret_ty = language.void + else: + assert isinstance(ret_value, language.core.base_value) + ret_value._flatten_ir(handles) + ret_ty = ret_value.type + self.builder.ret(handles) + if self.ret_type is None: + self.ret_type = ret_ty + elif self.ret_type != ret_ty: + raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') + + # A return op must always terminate the basic block, so we create a dead + # basic block in case there are any ops after the return. + post_ret_block = self.builder.create_block() + self.builder.set_insertion_point_to_end(post_ret_block) + + def visit_Starred(self, node) -> Any: + args = self.visit(node.value) + assert isinstance(args, language.core.tuple) + return args.values + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + if self.fn: + raise self._unsupported(node, "nested function definition is not supported.") + # initialize defaults + for i, default_value in enumerate(node.args.defaults[::-1]): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + self.visit(init_node) + finally: + self.visiting_arg_default_value = False + + # initialize function + visibility = "public" if self.is_kernel else "private" + fn_ty = self.prototype.serialize(self.builder) + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline) + self.module.push_back(self.fn) + entry = self.fn.add_entry_block() + arg_values = self.prototype.deserialize(self.fn) + # bind arguments to symbols + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + insert_pt = self.builder.get_insertion_block() + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + + # finalize function + assert not self.builder.get_insertion_block().has_terminator() + if self.ret_type is None or self.ret_type == language.void: + self.ret_type = language.void + self.builder.ret([]) + else: + if isinstance(self.ret_type, language.tuple_type): + self.prototype.ret_types = self.ret_type.types + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.serialize(self.builder)) + self.builder.ret([self.builder.create_poison(ty) for ty in self.prototype.return_types_ir(self.builder)]) + self.fn.finalize() + + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def assignTarget(self, target, value): + if isinstance(target, ast.Subscript): + assert target.ctx.__class__.__name__ == "Store" + return self.visit_Subscript_Store(target, value) + if isinstance(target, ast.Tuple): + assert target.ctx.__class__.__name__ == "Store" + for i, name in enumerate(target.elts): + self.set_value(self.visit(name), value.values[i]) + return + assert isinstance(target, ast.Name) + self.set_value(self.visit(target), value) + + def visit_Assign(self, node): + # construct values to assign + def _sanitize_value(value): + if isinstance(value, language.tuple): + return _apply_to_tuple_values(value, _sanitize_value) + native_nontensor_types = (language.dtype, language.tuple) + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_value(value) and \ + not isinstance(value, native_nontensor_types): + value = semantic.to_tensor(value, self.builder) + return value + + values = _sanitize_value(self.visit(node.value)) + targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets + assert len(targets) == 1 + self.assignTarget(targets[0], values) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.dereference_name(name) + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return language.tuple(args) + + def _apply_binary_method(self, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _builder=self.builder) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder) + return getattr(lhs, method_name)(rhs) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise self._unsupported(node, + "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + # else block + else_defs = {} + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + + # update block arguments + names = [] + # variables in livein whose value is updated in `if` + for name in liveins: + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + if name in defs: + type_equal = type(defs[name]) == type(liveins[name]) # noqa: E721 + assert type_equal and defs[name].type == liveins[name].type, \ + f'initial value for `{name}` is of type {liveins[name]}, '\ + f'but the {block_name} block redefines it as {defs[name]}' + if name in then_defs or name in else_defs: + names.append(name) + # variable defined in then but not in else + if name in then_defs and name not in else_defs: + else_defs[name] = liveins[name] + # variable defined in else but not in then + if name in else_defs and name not in then_defs: + then_defs[name] = liveins[name] + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in sorted(then_defs.keys() & else_defs.keys()): + if name in names: + continue + then_val = then_defs[name] + then_ty = then_val.type + else_val = else_defs[name] + else_ty = else_val.type + type_equal = type(then_val) == type(else_val) # noqa: E721 + assert type_equal and then_ty == else_ty, \ + f'Mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + + return then_defs, else_defs, then_block, else_block, names + + def visit_if_top_level(self, cond, node): + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create basic-block after conditional + endif_block = self.builder.create_block() + # then terminator + self.builder.set_insertion_point_to_end(then_block) + assert not then_block.has_terminator(), f"{then_block}" + then_handles = flatten_values_to_ir(then_defs[name] for name in names) + self.builder.create_branch(endif_block, then_handles) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + assert not else_block.has_terminator(), f"{else_block}" + else_handles = flatten_values_to_ir(else_defs[name] for name in names) + self.builder.create_branch(endif_block, else_handles) + assert len(then_handles) == len(else_handles) + for then_h, else_h in zip(then_handles, else_handles): + ty = then_h.get_type() + assert ty == else_h.get_type() + endif_block.add_argument(ty) + + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + res_handles = [endif_block.arg(i) for i in range(len(then_handles))] + types = [then_defs[name].type for name in names] + new_values = unflatten_ir_values(res_handles, types) + for name, new_value in zip(names, new_values): + self.set_value(name, new_value) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: + liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + then_handles = flatten_values_to_ir(then_defs[name] for name in names) + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op(then_handles) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + else_handles = flatten_values_to_ir(else_defs[name] for name in names) + self.builder.create_yield_op(else_handles) + # update values + res_handles = [if_op.get_result(i) for i in range(len(then_handles))] + types = [then_defs[name].type for name in names] + new_values = unflatten_ir_values(res_handles, types) + for name, new_value in zip(names, new_values): + self.set_value(name, new_value) + + def visit_If(self, node): + cond = self.visit(node.test) + + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + contains_return = ContainsReturnChecker(self.gscope).visit(node) + if contains_return: + if self.scf_stack: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton " + "(note that this also applies to `return` statements that are inside functions " + "transitively called from within `while`/`for` statements)") + self.visit_if_top_level(cond, node) + else: + self.visit_if_scf(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + + active_block = node.body if cond else node.orelse + self.visit_compound_statement(active_block) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = semantic.to_tensor(self.visit(node.body), self.builder) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = semantic.to_tensor(self.visit(node.orelse), self.builder) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None + else: + cond = _unwrap_if_constexpr(cond) + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + if not (len(node.comparators) == 1 and len(node.ops) == 1): + raise self._unsupported(node, "simultaneous multiple comparison is not supported") + lhs = self.visit(node.left) + rhs = self.visit(node.comparators[0]) + lhs_value = _unwrap_if_constexpr(lhs) + rhs_value = _unwrap_if_constexpr(rhs) + if type(node.ops[0]) is ast.Is: + return constexpr(lhs_value is rhs_value) + if type(node.ops[0]) is ast.IsNot: + return constexpr(lhs_value is not rhs_value) + method_name = self._method_name_for_comp_op.get(type(node.ops[0])) + if method_name is None: + raise self._unsupported( + node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { + ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' + } + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + fn = self._method_name_for_unary_op.get(type(node.op)) + if fn is None: + raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.") + if _is_triton_tensor(operand): + return getattr(operand, fn)(_builder=self.builder) + try: + return getattr(operand, fn)() + except AttributeError: + raise self._unsupported( + node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}") + + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { + ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' + } + + def _verify_loop_carried_variable(self, name, loop_val, live_val): + assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop' + assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop' + assert type(loop_val) is type(live_val), f'Loop carried variable {name} changed type' + assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \ + f'Loop-carried variable {name} has initial type {live_val.type} '\ + f'but is re-assigned to {loop_val.type} in loop! '\ + f'Please make sure that the type stays consistent.' + + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # loop body (the after region) + # loop_block = self.builder.create_block() + dummy = self.builder.create_block() + self.builder.set_insertion_point_to_start(dummy) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + dummy.erase() + + # collect loop-carried values + names = [] + init_args = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + loop_val = loop_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) + + # these are loop-carried values + names.append(name) + init_args.append(live_val) + + init_handles = flatten_values_to_ir(init_args) + init_tys = [h.get_type() for h in init_handles] + init_fe_tys = [a.type for a in init_args] + self._set_insertion_point_and_loc(ip, last_loc) + while_op = self.builder.create_while_op(init_tys, init_handles) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), init_tys) + self.builder.set_insertion_point_to_start(before_block) + block_args = [before_block.arg(i) for i in range(len(init_handles))] + condition_args = unflatten_ir_values(block_args, init_fe_tys) + for name, val in zip(names, condition_args): + self.lscope[name] = val + self.local_defs[name] = val + cond = self.visit(node.test) + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condition_op(cond.handle, block_args) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), init_tys) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + body_handles = [after_block.arg(i) for i in range(len(init_handles))] + body_args = unflatten_ir_values(body_handles, init_fe_tys) + for name, val in zip(names, body_args): + self.lscope[name] = val + self.local_defs[name] = val + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + yields = [] + for name in loop_defs: + if name in liveins: + loop_defs[name]._flatten_ir(yields) + + self.builder.create_yield_op(yields) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + result_handles = [while_op.get_result(i) for i in range(len(init_handles))] + result_vals = unflatten_ir_values(result_handles, init_fe_tys) + for name, new_def in zip(names, result_vals): + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript_Load(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_tensor(lhs): + return lhs.__getitem__(slices, _builder=self.builder) + return lhs[slices] + + def visit_Subscript_Store(self, node, value): + assert node.ctx.__class__.__name__ == "Store" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + assert isinstance(lhs, language.tuple) + lhs.__setitem__(slices, value) + + def visit_Subscript(self, node): + return self.visit_Subscript_Load(node) + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + num_stages = None + loop_unroll_factor = None + disallow_acc_multi_buffer = False + flatten = False + if IteratorClass is language.range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iterator.start + ub = iterator.end + step = iterator.step + num_stages = iterator.num_stages + loop_unroll_factor = iterator.loop_unroll_factor + disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer + flatten = iterator.flatten + elif IteratorClass is range: + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + else: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = semantic.to_tensor(lb, self.builder) + ub = semantic.to_tensor(ub, self.builder) + step = semantic.to_tensor(step, self.builder) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + iv_type = semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv = self.builder.create_poison(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + loop_val = self.local_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) + + names.append(name) + init_args.append(live_val) + yields.append(loop_val) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + init_handles = flatten_values_to_ir(init_args) + init_tys = [v.type for v in init_args] + for_op = self.builder.create_for_op(lb, ub, step, init_handles) + if _unwrap_if_constexpr(num_stages) is not None: + for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + if _unwrap_if_constexpr(loop_unroll_factor) is not None: + for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) + if disallow_acc_multi_buffer: + for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr()) + if flatten: + for_op.set_attr("tt.flatten", self.builder.get_unit_attr()) + + self.scf_stack.append(node) + for_op_body = for_op.get_body(0) + self.builder.set_insertion_point_to_start(for_op_body) + # reset local scope to not pick up local defs from the previous dry run. + self.lscope = liveins.copy() + self.local_defs = {} + block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))] + block_args = unflatten_ir_values(block_handles, init_tys) + for name, val in zip(names, block_args): + self.set_value(name, val) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yields = [] + for name in self.local_defs: + if name in liveins: + local = self.local_defs[name] + if isinstance(local, constexpr): + local = semantic.to_tensor(local, self.builder) + yields.append(local) + + # create YieldOp + if len(yields) > 0: + yield_handles = flatten_values_to_ir(yields) + self.builder.create_yield_op(yield_handles) + for_op_region = for_op_body.get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op_body) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + # update lscope & local_defs (ForOp defines new values) + result_handles = [for_op.get_result(i) for i in range(len(init_handles))] + result_values = unflatten_ir_values(result_handles, init_tys) + for name, val in zip(names, result_values): + self.set_value(name, val) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return language.slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + test = self.visit(node.test) + msg = self.visit(node.msg) if node.msg is not None else "" + return language.core.device_assert(test, msg, _builder=self.builder) + + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + for i, arg in enumerate(args): + if isinstance(arg, (language.dtype, float, int, bool, JITFunction)): + args[i] = language.core.constexpr(arg) + args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x)) + args_cst = {path: get_iterable_path(args, path) for path in args_cst} + args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x)) + args_val = [get_iterable_path(args, path) for path in args_path] + # mangle + fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst) + # generate function def if necessary + if not self.module.has_function(fn_name): + gscope = fn.__globals__ + # If the callee is not set, we use the same debug setting as the caller + file_name, begin_line = get_jit_fn_file_line(fn) + arg_types = [ + language.core.constexpr if arg is None or isinstance(arg, + (bool, int, language.core.dtype)) else arg.type + for arg in args + ] + prototype = ASTFunction([], arg_types, args_cst, dict()) + generator = CodeGenerator(self.context, prototype, gscope, module=self.module, jit_fn=fn, + function_name=fn_name, function_types=self.function_ret_types, + noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + options=self.builder.options, codegen_fns=self.builder.codegen_fns, + module_map=self.builder.module_map) + try: + generator.visit(fn.parse()) + except Exception as e: + # Wrap the error in the callee with the location of the call. + raise CompilationError(self.jit_fn.src, self.cur_node, None) from e + + callee_ret_type = generator.ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + args_val = [arg.handle for arg in args_val] + call_op = self.builder.call(symbol, args_val) + if callee_ret_type == language.void: + return None + handles = [call_op.get_result(i) for i in range(call_op.get_num_results())] + return next(unflatten_ir_values(handles, [callee_ret_type])) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [self.visit(arg) for arg in node.args] + args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args)) + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn): + extra_kwargs = {"_builder": self.builder} + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + try: + ret = fn(*args, **extra_kwargs, **kws) + # builtin functions return plain tuples for readability + if isinstance(ret, tuple): + ret = language.tuple(ret) + return ret + except Exception as e: + # Normally when we raise a CompilationError, we raise it as + # `from None`, because the original fileline from the exception + # is not relevant (and often points into code_generator.py + # itself). But when calling a function, we raise as `from e` to + # preserve the traceback of the original error, which may e.g. + # be in core.py. + raise CompilationError(self.jit_fn.src, node, None) from e + + if fn in self.builtin_namespace.values(): + args = map(_unwrap_if_constexpr, args) + ret = fn(*args, **kws) + return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + if len(node.values) != 2: + raise self._unsupported( + node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + lhs = self.visit(node.values[0]) + rhs = self.visit(node.values[1]) + method_name = self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise self._unsupported( + node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if _is_triton_tensor(lhs) and node.attr == "T": + return semantic.permute(lhs, (1, 0), builder=self.builder) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise self._unsupported( + node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + last_node = self.cur_node + last_loc = self.builder.get_loc() + self.cur_node = node + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + last_loc = self.builder.get_loc() + try: + ret = super().visit(node) + except CompilationError: + raise + except Exception as e: + # Wrap the error in a CompilationError which contains the source + # of the @jit function. + raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None + + # Reset the location to the last one before the visit + if last_loc: + self.cur_node = last_node + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = _unwrap_if_constexpr(self.visit(node.args[0])) + if not isinstance(passed, bool): + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) + return None + + def static_executor(python_fn): + + def ret(self, node: ast.Call): + kws = { + name: _unwrap_if_constexpr(value) + for name, value in (self.visit(keyword) for keyword in node.keywords) + } + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + return constexpr(python_fn(*args, **kws)) + + return ret + + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: static_executor(print), + int: static_executor(int), + len: static_executor(len), + } + + +def ast_to_ttir(fn, src, context, options, codegen_fns, module_map): + arg_types = list(map(str_to_ty, src.signature.values())) + prototype = ASTFunction([], arg_types, src.constants, src.attrs) + file_name, begin_line = get_jit_fn_file_line(fn) + # query function representation + from collections import namedtuple + leaves = filter(lambda v: len(v) == 1, src.constants) + constants = {fn.arg_names[i[0]]: src.constants[i] for i in leaves} + signature = src.signature + proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature) + generator = CodeGenerator(context, prototype, gscope=fn.__globals__.copy(), function_name=fn.repr(proxy), jit_fn=fn, + is_kernel=True, file_name=file_name, begin_line=begin_line, options=options, + codegen_fns=codegen_fns, module_map=module_map) + generator.visit(fn.parse()) + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret diff --git a/third_party/enflame/include/triton/python/triton/compiler/compiler.py b/third_party/enflame/include/triton/python/triton/compiler/compiler.py new file mode 100644 index 000000000..772b76d1e --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/compiler/compiler.py @@ -0,0 +1,441 @@ +from __future__ import annotations +import hashlib +import json +from .._C.libtriton import get_cache_invalidating_env_vars, ir +from ..backends import backends +from ..backends.compiler import GPUTarget +from .. import __version__ +from ..runtime.autotuner import OutOfResources +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager +from ..runtime.driver import driver +from ..tools.disasm import get_sass +# TODO: this shouldn't be here +from .code_generator import ast_to_ttir +from pathlib import Path +import re +import functools +import os +import sysconfig + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ptx": ptx_prototype_pattern, +} + +ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + tma = re.search(r'tt.nv_tma_desc = 1', x) + if tma is not None: + return 'nvTmaDesc' + x = re.sub(r' {[^}]+}', '', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +class ASTSource: + + def __init__(self, fn, signature, constexprs=None, attrs=None) -> None: + self.fn = fn + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = dict() + if constexprs is not None: + for k, v in constexprs.items(): + k = (fn.arg_names.index(k), ) if isinstance(k, str) else k + assert isinstance(k, tuple) + self.constants[k] = v + self.attrs = attrs or dict() + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + else: + for k in self.signature.keys(): + if not isinstance(k, str): + raise TypeError("Signature keys must be string") + + def hash(self): + sorted_sig = [v for k, v in sorted(self.signature.items())] + get_key = lambda x: x.cache_key if hasattr(x, 'cache_key') else str(x) + constants_key = '-'.join([get_key(v) for k, v in sorted(self.constants.items())]) + key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, module_map, context): + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, + module_map=module_map) + + def parse_options(self): + return dict() + + +class IRSource: + + def __init__(self, path, context, backend): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.src = path.read_text() + ir.load_dialects(context) + backend.load_dialects(context) + + # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now. + # TODO - replace with a proper parser + if self.ext == "ptx": + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + else: + self.module = ir.parse_mlir_module(self.path, context) + fn_name = self.module.get_entry_func_name() + self.name = "@" + fn_name + funcOp = self.module.get_function(fn_name) + func_ty = self.module.get_function_signature(funcOp) + self.signature = {k: ty for k, ty in enumerate(func_ty)} + + def hash(self): + return hashlib.sha256(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, module_map, context): + self.module.context = context + return self.module + + def parse_options(self): + if self.ext == "ttgir": + num_warps = self.module.get_int_attr("ttg.num-warps") + assert num_warps is not None, "Unable to parse ttg.num-warps attribute" + return {'num_warps': num_warps} + return dict() + + +@functools.lru_cache() +def triton_key(): + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + # compiler + path_prefixes = [ + (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), + (os.path.join(TRITON_PATH, "backends"), "triton.backends."), + ] + for path, prefix in path_prefixes: + for lib in pkgutil.walk_packages([path], prefix=prefix): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + + # backend + libtriton_hash = hashlib.sha256() + ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] + with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.walk_packages([language_path], prefix="triton.language."): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + return f'{__version__}' + '-'.join(contents) + + +def parse(full_name, ext, context): + if ext == "ttir" or ext == "ttgir": + module = ir.parse_mlir_module(full_name, context) + module.context = context + return module + if ext == "llir" or ext == "ptx" or ext == "amdgcn": + return Path(full_name).read_text() + if ext == "cubin" or ext == "hsaco": + return Path(full_name).read_bytes() + + +def filter_traceback(e: BaseException): + """ + Removes code_generator.py and related files from tracebacks. + + These are uninteresting to the user -- "just show me *my* code!" + """ + if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1": + return + + if e.__cause__ is not None: + filter_traceback(e.__cause__) + if e.__context__ is not None: + filter_traceback(e.__context__) + + # If a user has a file that matches one of these, they're out of luck. + BAD_FILES = [ + "/triton/compiler/code_generator.py", + "/ast.py", + ] + BAD_FILES = [bad_file.replace("/", os.sep) for bad_file in BAD_FILES] + + tb = e.__traceback__ + frames = [] + while tb is not None: + if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): + frames.append(tb) + tb = tb.tb_next + + for (cur_frame, next_frame) in zip(frames, frames[1:]): + cur_frame.tb_next = next_frame + + if not frames: + e.__traceback__ = None + else: + frames[-1].tb_next = None + e.__traceback__ = frames[0] + + +def compile(src, target=None, options=None): + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + context = ir.context() + src = IRSource(src, context, backend) + + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" + enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" + store_only_binary = os.environ.get("TRITON_STORE_BINARY_ONLY", "0") == "1" + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms. + # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}". + # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate + # the file name to 150 characters to be safe. + file_name = src.name[:150] + metadata_filename = f"{file_name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" + if not always_compile and metadata_path is not None: + # cache hit! + return CompiledKernel(src, metadata_group, hash) + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + metadata["triton_version"] = __version__ + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. + if ir_source: + first_stage += 1 + + # For IRSource, we have already grabbed the context + called both + # ir.load_dialects and backend.load_dialects. + if not isinstance(src, IRSource): + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + + codegen_fns = backend.get_codegen_implementation(options) + module_map = backend.get_module_map() + try: + module = src.make_ir(options, codegen_fns, module_map, context) + except Exception as e: + filter_traceback(e) + raise + use_ir_loc = os.environ.get("USE_IR_LOC", None) + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + ir_filename = f"{file_name}.{ext}" + if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): + print(f"\nOverriding kernel with file {full_name}") + next_module = parse(full_name, ext, context) + # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json + if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")): + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(next_module, ir_filename) + # use an env variable to parse ir from file + if use_ir_loc == ext: + ir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ir_full_name) + print(f"Creating new locations for {ir_full_name}") + module = next_module + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + # Compilation completed, disabling multithreading in context. + # This is needed to safely finalize threads pool inside context: if current process forks before + # python GC deletes context object, thread pool in child process will be invalid, which could + # lead to child crash or hang. + # + # However disabling multithreading causes the code to hang if the ASAN pass is enabled + # this is likely due to the llvm-symbolizer forking a process + # TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling + # multithreading in the MLIR context + if not os.environ.get("TRITON_ENABLE_ASAN", "0") == "1": + context.disable_multithreading() + # return handle to compiled kernel + return CompiledKernel(src, metadata_group, hash) + + +def make_backend(target): + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.") + return actives[0](target) + + +class LazyDict: + + def __init__(self, data): + self.data = data + self.extras = [] + + def get(self) -> None: + for func, args in self.extras: + self.data = self.data | func(*args) + self.extras.clear() + return self.data + + def add(self, func, args): + self.extras.append((func, args)) + + +class AsmDict(dict): + + def __missing__(self, key): + + if key == "sass": + value = get_sass(self["cubin"]) + else: + raise KeyError("Unknown key: '%s'" % key) + + self[key] = value + return value + + +class CompiledKernel: + + # Hooks for external tools to monitor the execution of triton kernels + # TODO: move out of this namespace since it's a runtime thing + launch_enter_hook = None + launch_exit_hook = None + + def __init__(self, src, metadata_group, hash): + from collections import namedtuple + metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) + metadata = json.loads(metadata_path.read_text()) + metadata['cluster_dims'] = tuple(metadata['cluster_dims']) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. + target = metadata['target'] + metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) + KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + backend = make_backend(self.metadata.target) + self.packed_metadata = backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + # stores the text of each level of IR that was generated during compilation + asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + binary_ext = backend.binary_ext + self.asm = AsmDict({ + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() + for file in asm_files + }) + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + + def _init_handles(self): + if self.module is not None: + return + device = driver.active.get_current_device() + # create launcher + self.run = driver.active.launcher_cls(self.src, self.metadata) + # not enough shared memory to run the kernel + max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] + if self.metadata.shared > max_shared: + raise OutOfResources(self.metadata.shared, max_shared, "shared memory") + if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None: + # Use blackwell max tmem size for now, this should be moved in device properties + max_tmem_size = 512 # tmem size in number of columns + if self.metadata.tmem_size > max_tmem_size: + raise OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory") + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.shared, device) + + def __getattribute__(self, name): + if name == 'run': + self._init_handles() + return super().__getattribute__(name) + + def launch_metadata(self, grid, stream, *args): + if CompiledKernel.launch_enter_hook is None: + return None + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: + return ret + arg_dict = {} + arg_idx = 0 + for i, arg_name in enumerate(self.src.fn.arg_names): + arg_dict[arg_name] = args[arg_idx] + arg_idx += 1 + ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) + return ret + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + if stream is None: + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args) + + return runner diff --git a/third_party/enflame/include/triton/python/triton/compiler/errors.py b/third_party/enflame/include/triton/python/triton/compiler/errors.py new file mode 100644 index 000000000..39e6c4dfb --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/compiler/errors.py @@ -0,0 +1,51 @@ +import ast +from typing import Optional +from ..errors import TritonError + + +class CompilationError(TritonError): + """Base class for all errors raised during compilation""" + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + if hasattr(node, 'lineno'): + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + else: + source_excerpt = self.src + + message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( + node, 'lineno') else source_excerpt + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def __str__(self): + return self.message + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass diff --git a/third_party/enflame/include/triton/python/triton/compiler/make_launcher.py b/third_party/enflame/include/triton/python/triton/compiler/make_launcher.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/include/triton/python/triton/errors.py b/third_party/enflame/include/triton/python/triton/errors.py new file mode 100644 index 000000000..3a0a86355 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/errors.py @@ -0,0 +1,5 @@ +"""Base class for all errors raised by Triton""" + + +class TritonError(Exception): + ... diff --git a/third_party/enflame/include/triton/python/triton/language/__init__.py b/third_party/enflame/include/triton/python/triton/language/__init__.py new file mode 100644 index 000000000..b4cb1df50 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/language/__init__.py @@ -0,0 +1,311 @@ +"""isort:skip_file""" +# Import order is significant here. + +from . import math +from . import extra +from .standard import ( + argmax, + argmin, + cdiv, + cumprod, + cumsum, + flip, + interleave, + max, + min, + ravel, + sigmoid, + softmax, + sort, + sum, + swizzle2d, + xor_sum, + zeros, + zeros_like, +) +from .core import ( + PropagateNan, + TRITON_MAX_TENSOR_NUMEL, + _experimental_descriptor_load, + _experimental_descriptor_store, + _experimental_make_tensor_descriptor, + _experimental_reinterpret_tensor_descriptor, + _experimental_tensor_descriptor, + add, + advance, + arange, + associative_scan, + assume, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + broadcast, + broadcast_to, + cat, + cast, + clamp, + const, + constexpr, + debug_barrier, + device_assert, + device_print, + dot, + dot_scaled, + dtype, + expand_dims, + float16, + float32, + float64, + float8e4b15, + float8e4nv, + float8e4b8, + float8e5, + float8e5b16, + full, + gather, + histogram, + inline_asm_elementwise, + int1, + int16, + int32, + int64, + int8, + join, + load, + make_block_ptr, + max_constancy, + max_contiguous, + maximum, + minimum, + multiple_of, + num_programs, + permute, + pi32_t, + pointer_type, + nv_tma_desc_type, + program_id, + range, + reduce, + reshape, + slice, + split, + static_assert, + static_print, + static_range, + store, + tensor, + trans, + tuple, + tuple_type, + uint16, + uint32, + uint64, + uint8, + view, + void, + where, +) +from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, + ceil) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint_to_uniform_float, +) + +__all__ = [ + "PropagateNan", + "TRITON_MAX_TENSOR_NUMEL", + "_experimental_descriptor_load", + "_experimental_descriptor_store", + "_experimental_make_tensor_descriptor", + "_experimental_reinterpret_tensor_descriptor", + "_experimental_tensor_descriptor", + "abs", + "add", + "advance", + "arange", + "argmax", + "argmin", + "associative_scan", + "assume", + "atomic_add", + "atomic_and", + "atomic_cas", + "atomic_max", + "atomic_min", + "atomic_or", + "atomic_xchg", + "atomic_xor", + "bfloat16", + "block_type", + "broadcast", + "broadcast_to", + "cat", + "cast", + "cdiv", + "ceil", + "clamp", + "const", + "constexpr", + "cos", + "cumprod", + "cumsum", + "debug_barrier", + "device_assert", + "device_print", + "div_rn", + "dot", + "dot_scaled", + "dtype", + "erf", + "exp", + "exp2", + "expand_dims", + "extra", + "fdiv", + "flip", + "float16", + "float32", + "float64", + "float8e4b15", + "float8e4nv", + "float8e4b8", + "float8e5", + "float8e5b16", + "floor", + "fma", + "full", + "gather", + "histogram", + "inline_asm_elementwise", + "interleave", + "int1", + "int16", + "int32", + "int64", + "int8", + "join", + "load", + "log", + "log2", + "make_block_ptr", + "math", + "max", + "max_constancy", + "max_contiguous", + "maximum", + "min", + "minimum", + "multiple_of", + "num_programs", + "pair_uniform_to_normal", + "permute", + "philox", + "philox_impl", + "pi32_t", + "pointer_type", + "nv_tma_desc_type", + "program_id", + "rand", + "rand4x", + "randint", + "randint4x", + "randn", + "randn4x", + "range", + "ravel", + "reduce", + "reshape", + "rsqrt", + "slice", + "sigmoid", + "sin", + "softmax", + "sort", + "split", + "sqrt", + "sqrt_rn", + "static_assert", + "static_print", + "static_range", + "store", + "sum", + "swizzle2d", + "tensor", + "trans", + "tuple", + "uint16", + "uint32", + "uint64", + "uint8", + "uint_to_uniform_float", + "umulhi", + "view", + "void", + "where", + "xor_sum", + "zeros", + "zeros_like", +] + + +def str_to_ty(name): + from builtins import tuple + + if isinstance(name, tuple): + fields = type(name).__dict__.get("_fields", None) + return tuple_type([str_to_ty(x) for x in name], fields) + + if name[0] == "*": + name = name[1:] + const = False + if name[0] == "k": + name = name[1:] + const = True + ty = str_to_ty(name) + return pointer_type(element_ty=ty, const=const) + + if name == "nvTmaDesc": + return nv_tma_desc_type() + + if name == "constexpr": + return constexpr + + tys = { + "fp8e4nv": float8e4nv, + "fp8e4b8": float8e4b8, + "fp8e5": float8e5, + "fp8e5b16": float8e5b16, + "fp8e4b15": float8e4b15, + "fp16": float16, + "bf16": bfloat16, + "fp32": float32, + "fp64": float64, + "i1": int1, + "i8": int8, + "i16": int16, + "i32": int32, + "i64": int64, + "u1": int1, + "u8": uint8, + "u16": uint16, + "u32": uint32, + "u64": uint64, + "B": int1, + } + return tys[name] diff --git a/third_party/enflame/include/triton/python/triton/language/_utils.py b/third_party/enflame/include/triton/python/triton/language/_utils.py new file mode 100644 index 000000000..b9aa69071 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/language/_utils.py @@ -0,0 +1,21 @@ +from typing import List + +TRITON_MAX_TENSOR_NUMEL = 1048576 + + +def is_power_of_two(x): + return (x & (x - 1)) == 0 + + +def validate_block_shape(shape: List[int]): + numel = 1 + for i, d in enumerate(shape): + if not isinstance(d, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") + if not is_power_of_two(d): + raise ValueError(f"Shape element {i} must be a power of 2") + numel *= d + + if numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + return numel diff --git a/third_party/enflame/include/triton/python/triton/language/core.py b/third_party/enflame/include/triton/python/triton/language/core.py new file mode 100644 index 000000000..4c9bba7e6 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/language/core.py @@ -0,0 +1,3068 @@ +from __future__ import annotations + +from warnings import warn +from contextlib import contextmanager +from enum import Enum +from functools import partial, wraps +import typing +from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple +import builtins +from ..runtime.jit import jit +import inspect +import os + +from .._C.libtriton import ir +from . import semantic +from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape + +T = TypeVar('T') + +TRITON_BUILTIN = "__triton_builtin__" + +PropagateNan = ir.PROPAGATE_NAN + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + + return wrapper + + +def _tensor_member_fn(fn: T) -> T: + """Decorator that adds this free function as a member fn on class tensor. + + When called as a member function on class tensor, the first argument to `fn` + is `self`, i.e. the tensor object. + + If there are multiple decorators on a function, you probably want this one + to be the highest one (i.e. furthest from the function's `def`), so it's + applied last. + + Unfortunately you still need to add a type stub to the body of class tensor + in order for pytype to know about it. + """ + assert callable(fn) + orig_sig = inspect.signature(fn) + # Does fn take args other than _builder, _generator, and the tensor itself? + has_args = len(orig_sig.parameters.keys() - {"_builder", "_generator"}) > 1 + + if not fn.__doc__: + fn.__doc__ = "" + fn.__doc__ += f""" + This function can also be called as a member function on :py:class:`tensor`, + as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of + :code:`{fn.__name__}(x{", ..." if has_args else ""})`. + """ + + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Match the signature of `fn`, but change the first arg to `self` so the + # docs are a little less weird. + new_params = list(orig_sig.parameters.values()) + new_params[0] = new_params[0].replace(name='self') + new_sig = orig_sig.replace(parameters=new_params) + wrapper.__signature__ = new_sig + wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function" + # If fn is a builtin, mark the wrapper as a builtin too. + if is_builtin(fn): + setattr(wrapper, TRITON_BUILTIN, True) + + setattr(tensor, fn.__name__, wrapper) + return fn + + +def _unwrap_iterable(x): + """Returns x[0] if x has one element and x[0] is iterable.""" + if len(x) == 1: + # Determine whether x[0] is iterable. + # + # You might want to use collections.abc.Iterable instead of this + # try/except block. Unfortunately, this doesn't work with constexpr. + # + # The problem is that abc.Iterable checks for __iter__ on the *class*. + # But we want constexpr to expose an __iter__ method if and only if the + # wrapped *object* (i.e. self.value) is iterable. Therefore there's no + # right answer for whether the class constexpr defines __iter__, and + # abc.Iterable doesn't work (at least not without some metaclass magic). + try: + iter(x[0]) + return x[0] + except TypeError: + pass + + return x + + +def is_builtin(fn) -> bool: + """Is this a registered triton builtin function?""" + return getattr(fn, TRITON_BUILTIN, False) + + +@builtin +def to_tensor(x, _builder=None): + return semantic.to_tensor(x, _builder) + + +# ----------------------- +# constexpr +# ----------------------- + + +class const: + """ + This class is used as a type annotation to mark pointers to constant data. + The `store` function cannot be called with a pointer to const. Constness + is part of the pointer type and the usual Triton type consistency rules + apply. For example you cannot have a function that returns constant pointer + in one return statement and non-constant pointer in another. + """ + pass + + +class constexpr: + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + self.type = constexpr + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def __index__(self): + return self.value + + # In interpreter mode, constant values are not wrapped in constexpr, + # and therefore do not have a .value attribute. + # As a result, from here and below, we need to call the _constexpr_to_value + # function to obtain either constexpr.value or the value itself. + def __add__(self, other): + return constexpr(self.value + _constexpr_to_value(other)) + + def __radd__(self, other): + return constexpr(_constexpr_to_value(other) + self.value) + + def __sub__(self, other): + return constexpr(self.value - _constexpr_to_value(other)) + + def __rsub__(self, other): + return constexpr(_constexpr_to_value(other) - self.value) + + def __mul__(self, other): + return constexpr(self.value * _constexpr_to_value(other)) + + def __mod__(self, other): + return constexpr(self.value % _constexpr_to_value(other)) + + def __rmul__(self, other): + return constexpr(_constexpr_to_value(other) * self.value) + + def __truediv__(self, other): + return constexpr(self.value / _constexpr_to_value(other)) + + def __rtruediv__(self, other): + return constexpr(_constexpr_to_value(other) / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // _constexpr_to_value(other)) + + def __rfloordiv__(self, other): + return constexpr(_constexpr_to_value(other) // self.value) + + def __gt__(self, other): + return constexpr(self.value > _constexpr_to_value(other)) + + def __rgt__(self, other): + return constexpr(_constexpr_to_value(other) > self.value) + + def __ge__(self, other): + return constexpr(self.value >= _constexpr_to_value(other)) + + def __rge__(self, other): + return constexpr(_constexpr_to_value(other) >= self.value) + + def __lt__(self, other): + return constexpr(self.value < _constexpr_to_value(other)) + + def __rlt__(self, other): + return constexpr(_constexpr_to_value(other) < self.value) + + def __le__(self, other): + return constexpr(self.value <= _constexpr_to_value(other)) + + def __rle__(self, other): + return constexpr(_constexpr_to_value(other) <= self.value) + + def __eq__(self, other): + return constexpr(self.value == _constexpr_to_value(other)) + + def __ne__(self, other): + return constexpr(self.value != _constexpr_to_value(other)) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & _constexpr_to_value(other)) + + def logical_and(self, other): + return constexpr(self.value and _constexpr_to_value(other)) + + def __or__(self, other): + return constexpr(self.value | _constexpr_to_value(other)) + + def __xor__(self, other): + return constexpr(self.value ^ _constexpr_to_value(other)) + + def logical_or(self, other): + return constexpr(self.value or _constexpr_to_value(other)) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value**_constexpr_to_value(other)) + + def __rpow__(self, other): + return constexpr(_constexpr_to_value(other)**self.value) + + def __rshift__(self, other): + return constexpr(self.value >> _constexpr_to_value(other)) + + def __lshift__(self, other): + return constexpr(self.value << _constexpr_to_value(other)) + + def __not__(self): + return constexpr(not self.value) + + def __iter__(self): + return iter(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + +CONSTEXPR_0 = constexpr(0) + + +def _unwrap_if_constexpr(o): + return o.value if isinstance(o, constexpr) else o + + +def check_bit_width(value, shift_value): + if isinstance(value, tensor) and isinstance(shift_value, constexpr): + bitwidth = value.type.scalar.primitive_bitwidth + if shift_value.value >= bitwidth: + warn( + f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior." + ) + + +class base_value: + """Base class of values that exist in the triton IR (i.e. not constexprs). + """ + type: base_type + + def _flatten_ir(self, handles: List[ir.value]) -> None: + """Flatten frontend value into a sequence of mlir handles, which are appended + to the output list + """ + raise NotImplementedError + + +class base_type: + + def __eq__(self, other): + raise NotImplementedError("Types must implement __eq__") + + def __ne__(self, other): + return not (self == other) + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + """Build a frontend value with the current dtype, wrapping a list of existing handles. + cursor is the index of the first handle relevant to this value, and the function + should return the updated cursor position after any handles consumed by the created value. + """ + raise NotImplementedError + + +# ----------------------- +# dtype +# ----------------------- + + +class dtype(base_type): + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + class KIND(Enum): + BOOLEAN = 0 + INTEGRAL = 1 + FLOATING = 2 + + def __init__(self, name): + name = _unwrap_if_constexpr(name) + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.primitive_bitwidth = 16 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.primitive_bitwidth = 16 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.primitive_bitwidth = 32 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 52 + self.primitive_bitwidth = 64 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + elif name == 'void': + self.primitive_bitwidth = 0 + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + def kind(self): + # Return int value following the type ordering bool < integer < fp + if self.is_bool(): + return dtype.KIND.BOOLEAN + elif self.is_int(): + return dtype.KIND.INTEGRAL + else: + assert self.is_floating() + return dtype.KIND.FLOATING + + def get_int_max_value(self): + if self.is_int_signed(): + return 2**(self.int_bitwidth - 1) - 1 + if self.is_int_unsigned(): + return 2**self.int_bitwidth - 1 + assert False + + def get_int_min_value(self): + if self.is_int_signed(): + return -2**(self.int_bitwidth - 1) + if self.is_int_unsigned(): + return 0 + assert False + + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + @staticmethod + def is_tuple(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name.startswith("fp8"): + if self.name not in builder.options.supported_fp8_dtypes: + raise ValueError(f'type {self} not supported in this architecture. ' + f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}') + if self.name in builder.options.deprecated_fp8_dtypes: + warn(f"{self.name} is deprecated in this architecture and will be removed in a future triton release") + + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + return tensor(handles[cursor], self), cursor + 1 + + +# Some functions have a param named `dtype`, which shadows the `dtype` class. +# We can't change the param name because it is part of function's public API. +# Declare an alias so those functions can still reference the dtype class. +_DtypeClass = dtype + + +class pointer_type(dtype): + + def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = False): + element_ty = _unwrap_if_constexpr(element_ty) + if not isinstance(element_ty, dtype): + raise TypeError(f'element_ty has type `{type(element_ty).__name__}`; expected `dtype`.') + self.element_ty = element_ty + self.address_space = address_space + self.const = const + self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def is_const(self): + return self.const + + def __eq__(self, other: pointer_type) -> bool: + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const + + @property + def scalar(self): + return self + + +class nv_tma_desc_type(pointer_type): + + def __init__(self, const=True, address_space=0): + super().__init__(uint8, const=const, address_space=address_space) + self.name = 'nv_tma_desc_type' + + +class block_type(dtype): + + def __init__(self, element_ty: dtype, shape: List): + self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + assert (isinstance(shape, (list, tuple))) + + # shape can be empty ([]) when an input is a 0D tensor. + self.shape = tuple(_unwrap_shape(shape)) + if not self.shape: + raise TypeError('0d block_type is forbidden') + + self.numel = validate_block_shape(self.shape) + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> List[int]: + return self.shape + + def __eq__(self, other) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + @property + def scalar(self): + return self.element_ty + + +class tuple_type(base_type): + + def __init__(self, types, fields=None): + self.types = types + self.fields = fields or [''] * len(types) + self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']' + + def __str__(self): + return self.name + + def __iter__(self): + return iter(self.types) + + def to_ir(self, builder: ir.builder): + return [ty.to_ir(builder) for ty in self.types] + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def is_tuple(self): + return True + + def __eq__(self, other): + return type(self) is type(other) and self.types == other.types and self.fields == other.fields + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]: + values = [] + for ty in self.types: + value, cursor = ty._unflatten_ir(handles, cursor) + values.append(value) + return tuple(values, self), cursor + + +class slice_type(dtype): + + def __init__(self): + self.name = 'slice_type' + + +# scalar types +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8e5 = dtype('fp8e5') +float8e5b16 = dtype('fp8e5b16') +float8e4nv = dtype('fp8e4nv') +float8e4b8 = dtype('fp8e4b8') +float8e4b15 = dtype('fp8e4b15') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') +# pointer types +pi32_t = pointer_type(int32) + + +def get_int_dtype(bitwidth: int, signed: bool) -> dtype: + if bitwidth == 1: + return int1 + elif bitwidth == 8 and signed: + return int8 + elif bitwidth == 8 and not signed: + return uint8 + elif bitwidth == 16 and signed: + return int16 + elif bitwidth == 16 and not signed: + return uint16 + elif bitwidth == 32 and signed: + return int32 + elif bitwidth == 32 and not signed: + return uint32 + elif bitwidth == 64 and signed: + return int64 + elif bitwidth == 64 and not signed: + return uint64 + else: + raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') + + +# ----------------------- +# tensor +# ----------------------- + + +class tensor(base_value): + """Represents an N-dimensional array of values or pointers. + + :code:`tensor` is the fundamental data structure in Triton programs. Most + functions in :py:mod:`triton.language` operate on and return tensors. + + Most of the named member functions here are duplicates of the free functions + in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is + equivalent to :code:`x.sqrt()`. + + :code:`tensor` also defines most of the magic/dunder methods, so you can + write :code:`x+y`, :code:`x << 2`, etc. + + .. rubric:: Constructors + .. + For some reason Sphinx includes __init__ before printing the full table + of methods. Not what I want, but I can't figure out how to fix it. Give + it its own section so it looks intentional. :) + """ + + def __init__(self, handle, type: dtype): + """Not called by user code.""" + super().__init__() + # IR handle + self.handle = handle + # Block shape + self.shape = type.shape if type.is_block() else () + self.numel = 1 + for s in self.shape: + self.numel *= s + self.numel = constexpr(self.numel) + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar + self.shape = tuple([constexpr(s) for s in self.shape]) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + def __str__(self) -> str: + # ex. "float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' + + @builtin + def __add__(self, other, _builder=None): + return add(self, other, sanitize_overflow=True, _builder=_builder) + + @builtin + def __radd__(self, other, _builder=None): + return add(other, self, sanitize_overflow=True, _builder=_builder) + + @builtin + def __sub__(self, other, _builder=None): + return sub(self, other, sanitize_overflow=True, _builder=_builder) + + @builtin + def __rsub__(self, other, _builder=None): + return sub(other, self, sanitize_overflow=True, _builder=_builder) + + @builtin + def __mul__(self, other, _builder=None): + return mul(self, other, sanitize_overflow=True, _builder=_builder) + + @builtin + def __rmul__(self, other, _builder=None): + return mul(other, self, sanitize_overflow=True, _builder=_builder) + + @builtin + def __truediv__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.truediv(self, other, _builder) + + @builtin + def __rtruediv__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.truediv(other, self, _builder) + + @builtin + def __floordiv__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.floordiv(self, other, _builder) + + @builtin + def __rfloordiv__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.floordiv(other, self, _builder) + + @builtin + def __mod__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.mod(self, other, _builder) + + @builtin + def __rmod__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.mod(other, self, _builder) + + # unary operators + @builtin + def __neg__(self, _builder=None): + return semantic.minus(self, _builder) + + @builtin + def __invert__(self, _builder=None): + return semantic.invert(self, _builder) + + # bitwise operators + + @builtin + def __and__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.and_(self, other, _builder) + + @builtin + def __rand__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.and_(other, self, _builder) + + @builtin + def __or__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.or_(self, other, _builder) + + @builtin + def __ror__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.or_(other, self, _builder) + + @builtin + def __xor__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.xor_(self, other, _builder) + + @builtin + def __rxor__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.xor_(other, self, _builder) + + @builtin + def __lshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _unwrap_if_constexpr(other) + return semantic.shl(self, other, _builder) + + @builtin + def __rlshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _unwrap_if_constexpr(other) + return semantic.shl(other, self, _builder) + + @builtin + def __rshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _unwrap_if_constexpr(other) + if self.dtype.is_int_signed(): + return semantic.ashr(self, other, _builder) + else: + return semantic.lshr(self, other, _builder) + + @builtin + def __rrshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _unwrap_if_constexpr(other) + if self.dtype.is_int_signed(): + return semantic.ashr(other, self, _builder) + else: + return semantic.lshr(other, self, _builder) + + # > + @builtin + def __gt__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.greater_than(self, other, _builder) + + @builtin + def __rgt__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.greater_than(other, self, _builder) + + # >= + @builtin + def __ge__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.greater_equal(self, other, _builder) + + @builtin + def __rge__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.greater_equal(other, self, _builder) + + # < + @builtin + def __lt__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.less_than(self, other, _builder) + + @builtin + def __rlt__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.less_than(other, self, _builder) + + # <= + @builtin + def __le__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.less_equal(self, other, _builder) + + @builtin + def __rle__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.less_equal(other, self, _builder) + + # == + @builtin + def __eq__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.equal(self, other, _builder) + + @builtin + def __req__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.equal(other, self, _builder) + + @builtin + def __ne__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.not_equal(self, other, _builder) + + @builtin + def __rne__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.not_equal(other, self, _builder) + + @builtin + def logical_and(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.logical_and(self, other, _builder) + + @builtin + def logical_or(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.logical_or(self, other, _builder) + + # note: __not__ isn't actually a magic method in python + # but it's ok because our ASTVisitor handles it + @builtin + def __not__(self, _builder=None): + return semantic.not_(self, _builder) + + @builtin + def __getitem__(self, slices, _builder=None): + import builtins + if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None: + slices = [slices] + if isinstance(slices, tuple): + slices = slices.values + ret = self + for dim, sl in enumerate(slices): + if sl is None or isinstance(sl, constexpr) and sl.value is None: + ret = semantic.expand_dims(ret, dim, _builder) + elif isinstance(sl, (builtins.slice, slice)) and sl.start is None and sl.stop is None and sl.step is None: + pass + else: + raise ValueError(f"unsupported tensor index: {sl}") + return ret + + @property + def T(self): + """Transposes a 2D tensor.""" + assert False, "Transposition must be created by the AST Visitor" + + @builtin + def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Alias for :py:func:`tensor.cast`. + """ + return cast(self, dtype, fp_downcast_rounding, bitcast, _builder=_builder) + + # Type stubs for functions added by the _tensor_member_fn decorator. + # (Unfortunately these can't be created automatically.) + # + # We couldn't write these definitions out even if we wanted to, because some + # of these functions are defined in standard.py. + def broadcast_to(self, *shape) -> tensor: + ... + + def trans(self, *dims) -> tensor: + ... + + def permute(self, *dims) -> tensor: + ... + + def split(self) -> tuple[tensor, tensor]: + ... + + def view(self, *shape) -> tensor: + ... + + def reshape(self, *shape) -> tensor: + ... + + def expand_dims(self, axis) -> tensor: + ... + + def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor: + ... + + def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor: + ... + + def advance(self, offsets) -> tensor: + ... + + def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor: + ... + + def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def exp(self) -> tensor: + ... + + def log(self) -> tensor: + ... + + def cos(self) -> tensor: + ... + + def sin(self) -> tensor: + ... + + def sqrt(self) -> tensor: + ... + + def rsqrt(self) -> tensor: + ... + + def abs(self) -> tensor: + ... + + def reduce(self, axis, combine_fn, keep_dims=False) -> tensor: + ... + + def associative_scan(self, axis, combine_fn, reverse=False) -> tensor: + ... + + def gather(self, indices, axis) -> tensor: + ... + + def histogram(self, num_bins) -> tensor: + ... + + def cdiv(self, div) -> tensor: + ... + + def sigmoid(self) -> tensor: + ... + + def softmax(self, ieee_rounding=False) -> tensor: + ... + + def ravel(self) -> tensor: + ... + + def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def sum(self, axis=None, keep_dims=False, dtype=None) -> tensor: + ... + + def xor_sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def cumsum(self, axis=0, reverse=False) -> tensor: + ... + + def cumprod(self, axis=0, reverse=False) -> tensor: + ... + + def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor: + ... + + def flip(self, dim=None) -> tensor: + ... + + +class tuple(base_value): + + def __init__(self, args: list, type: tuple_type = None): + self.values = [i for i in args] + + def get_type(x): + if isinstance(x, dtype): + return dtype + if isinstance(x, int): + return constexpr + return x.type + + self.type = type or tuple_type([get_type(x) for x in self.values]) + + def __getitem__(self, idx: constexpr): + if isinstance(idx, int): + idx = constexpr(idx) + if isinstance(idx, constexpr): + return self.values[idx] + else: + import builtins + assert isinstance(idx, (slice, builtins.slice)) + return tuple(self.values[idx.start:idx.stop:idx.step]) + + def __getattr__(self, name): + return self.values[self.type.fields.index(name)] + + # TODO: remove + def __setitem__(self, idx: constexpr, value): + if isinstance(idx, int): + idx = constexpr(idx) + assert isinstance(idx, constexpr) + self.values[idx] = value + + def __add__(self, other): + if isinstance(other, list): + other = tuple(other) + return tuple(self.values + other.values) + # return tuple(a + b for a, b in zip(self.values, other.values)) + + def __mul__(self, other): + assert isinstance(other, constexpr) + return tuple(self.values * other.value) + + def __eq__(self, other): + import builtins + if isinstance(other, (list, builtins.tuple)): + other = tuple(other) + return constexpr(self.values == other.values) + + def __hash__(self): + import builtins + return hash(builtins.tuple(self.values)) + + def __str__(self): + return str([str(x) for x in self.values]) + + def __iter__(self): + return iter(self.values) + + def __len__(self): + return len(self.values) + + def _flatten_ir(self, handles: List[ir.value]): + for v in self.values: + v._flatten_ir(handles) + + +class slice: + + def __init__(self, start, stop, step): + self.start = start + self.stop = stop + self.step = step + self.type = slice_type() + + +class tensor_descriptor_base_type(base_type): + + def __init__(self, block_type: block_type): + self.block_type = block_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[_experimental_tensor_descriptor_base, int]: + value = _experimental_tensor_descriptor_base(handles[cursor], self.block_type) + return value, cursor + 1 + + def to_ir(self, builder: ir.builder): + return builder.create_tensor_descriptor_type(self.block_type.to_ir(builder)) + + def __str__(self) -> str: + # ex. "tensor_descriptor" + return f"tensor_descriptor<{self.block_type}>" + + def __eq__(self, other) -> bool: + if type(other) is not type(self): + return False + return self.block_type == other.block_type + + def __neq__(self, other) -> bool: + return not (self == other) + + +class _experimental_tensor_descriptor_base(base_value): + """" + A tensor descriptor with unknown shape and strides + """ + + def __init__(self, handle, block_type: block_type): + """Not called by user code.""" + super().__init__() + + self.handle = handle # IR handle + self.type = tensor_descriptor_base_type(block_type) # Tensor type (block_type) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + @property + def block_type(self): + return self.type.block_type + + @property + def block_shape(self): + return self.type.block_type.shape + + @property + def dtype(self): + return self.type.block_type.element_ty + + def __str__(self) -> str: + return str(self.type) + + @builtin + def load(self, offsets: Sequence[constexpr | tensor], _builder=None) -> tensor: + """Load a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be filled with zeros. + + :note: Offset must be a multiple of 16-bytes + """ + return semantic.descriptor_load(self, offsets, "", "", _builder) + + @builtin + def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _builder=None) -> tensor: + """Store a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be ignored. + + :note: Offset must be a multiple of 16-bytes + """ + return semantic.descriptor_store(self, value, offsets, _builder) + + @builtin + def gather(self, *args, _builder=None) -> tensor: + """Gather multiple descriptors worth of data""" + assert len(args) == 2, f"descriptor gather only supports 2D indexing, but got {len(args)}" + x_offsets = args[0] + y_offset = args[1] + return semantic.descriptor_gather(self, x_offsets, y_offset, "", "", _builder) + + @builtin + def scatter(self, value, *args, _builder=None) -> tensor: + """Scatter multiple descriptors worth of data""" + assert len(args) == 2, f"descriptor scatter only supports 2D indexing, but got {len(args)}" + x_offsets = args[0] + y_offset = args[1] + return semantic.descriptor_scatter(self, value, x_offsets, y_offset, _builder) + + +class tensor_descriptor_type(tensor_descriptor_base_type): + + def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type): + self.block_type = block_type + self.shape_type = shape_type + self.strides_type = strides_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[_experimental_tensor_descriptor_base, int]: + handle = handles[cursor] + cursor += 1 + shape, cursor = self.shape_type._unflatten_ir(handles, cursor) + strides, cursor = self.strides_type._unflatten_ir(handles, cursor) + shape = shape.values + strides = strides.values + value = _experimental_tensor_descriptor(handle, shape, strides, self.block_type) + return value, cursor + + def to_ir(self, builder: ir.builder): + return [super().to_ir(builder), *self.shape_type.to_ir(builder), *self.strides_type.to_ir(builder)] + + def __eq__(self, other): + return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type + == other.strides_type) + + +class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base): + """A descriptor representing a tensor in global memory. + """ + + def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type): + """Not called by user code.""" + # IR handle + super().__init__(handle, block_type) + self.type = tensor_descriptor_type( + block_type, + shape_type=tuple_type([s.type for s in shape]), + strides_type=tuple_type([s.type for s in strides]), + ) + # Global shape + self.shape = shape + self.strides = strides + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + handles.extend(s.handle for s in self.shape) + handles.extend(s.handle for s in self.strides) + + +def get_bool_env_var(var_name): + v = os.getenv(var_name, "0") + return v == "1" or v == "true" or v == "on" + + +# ----------------------- +# SPMD Programming Model +# ----------------------- +def _constexpr_to_value(v): + if isinstance(v, constexpr): + return v.value + return v + + +@builtin +def program_id(axis, _builder=None): + """ + Returns the id of the current program instance along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + # if axis == -1: + # pid0 = program_id(0, _builder) + # pid1 = program_id(1, _builder) + # pid2 = program_id(2, _builder) + # npg0 = num_programs(0, _builder) + # npg1 = num_programs(1, _builder) + # return pid0 + pid1*npg0 + pid2*npg0*npg1 + axis = _constexpr_to_value(axis) + return semantic.program_id(axis, _builder) + + +@builtin +def num_programs(axis, _builder=None): + """ + Returns the number of program instances launched along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.num_programs(axis, _builder) + + +# ----------------------- +# Block Initialization +# ----------------------- + + +@builtin +def arange(start, end, _builder=None): + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) + + +arange.__doc__ = f""" + Returns contiguous values within the half-open interval :code:`[start, + end)`. :code:`end - start` must be less than or equal to + :code:`TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}` + + :param start: Start of the interval. Must be a power of two. + :type start: int32 + :param end: End of the interval. Must be a power of two greater than + :code:`start`. + :type end: int32 +""" + + +def _unwrap_shape(shape): + shape = _constexpr_to_value(shape) + return [_constexpr_to_value(s) for s in shape] + + +def _shape_check_impl(shape): + shape = _unwrap_shape(shape) + validate_block_shape(shape) + return shape + + +@builtin +def full(shape, value, dtype, _builder=None): + """ + Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param value: A scalar value to fill the array with + :type value: scalar + :param dtype: Data type of the new array, e.g., :code:`tl.float16` + :type dtype: tl.dtype + """ + shape = _shape_check_impl(shape) + value = _constexpr_to_value(value) + dtype = _constexpr_to_value(dtype) + return semantic.full(shape, value, dtype, _builder) + + +# ----------------------- +# Shape Manipulation +# ----------------------- + + +@builtin +def broadcast(input, other, _builder=None): + """ + Tries to broadcast the two given blocks to a common compatible shape. + + :param input: The first input tensor. + :type input: Block + :param other: The second input tensor. + :type other: Block + """ + return semantic.broadcast_impl_value(input, other, _builder) + + +@_tensor_member_fn +@builtin +def broadcast_to(input, *shape, _builder=None): + """ + Tries to broadcast the given tensor to a new :code:`shape`. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + :type shape: + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + broadcast_to(x, (32, 32)) + broadcast_to(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.broadcast_impl_shape(input, shape, _builder) + + +@_tensor_member_fn +@builtin +def trans(input: tensor, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + If the parameter :code:`dims` is not specified, the function defaults to a (1,0) permutation, + effectively transposing a 2D tensor. + + :param input: The input tensor. + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + trans(x, (2, 1, 0)) + trans(x, 2, 1, 0) + + :py:func:`permute` is equivalent to this function, except it doesn't + have the special case when no permutation is specified. + """ + dims = _unwrap_iterable(dims) + if not dims: + dims = (1, 0) + return semantic.permute(input, dims, _builder) + + +@_tensor_member_fn +@builtin +def permute(input, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + :param input: The input tensor. + :type input: Block + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + permute(x, (2, 1, 0)) + permute(x, 2, 1, 0) + + :py:func:`trans` is equivalent to this function, except when + :code:`dims` is empty, it tries to do a (1,0) permutation. + """ + dims = _unwrap_iterable(dims) + return semantic.permute(input, dims, _builder) + + +@builtin +def cat(input, other, can_reorder=False, _builder=None): + """ + Concatenate the given blocks + + :param input: The first input tensor. + :type input: Tensor + :param other: The second input tensor. + :type other: Tensor + :param reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. Only use if the + order does not matter (e.g., result is only used in reduction ops). + Current implementation of `cat` supports only can_reorder=True. + """ + return semantic.cat(input, other, can_reorder, _builder) + + +@builtin +def join(a, b, _builder=None): + """ + Join the given tensors in a new, minor dimension. + + For example, given two tensors of shape (4,8), produces a new tensor of + shape (4,8,2). Given two scalars, returns a tensor of shape (2). + + The two inputs are broadcasted to be the same shape. + + If you want to join more than two elements, you can use multiple calls to + this function. This reflects the constraint in Triton that tensors must + have power-of-two sizes. + + join is the inverse of split. + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + return semantic.join(a, b, _builder) + + +@jit +def _take_first(a, b): + return a + + +@_tensor_member_fn +@builtin +def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]: + """ + Split a tensor in two along its last dim, which must have size 2. + + For example, given a tensor of shape (4,8,2), produces two tensors of shape + (4,8). Given a tensor of shape (2), returns two scalars. + + If you want to split into more than two pieces, you can use multiple calls + to this function (probably plus calling reshape). This reflects the + constraint in Triton that tensors must have power-of-two sizes. + + split is the inverse of join. + + :param a: The tensor to split. + :type a: Tensor + """ + # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars. + # But semantic.split can only handle returning tensors. Work around this by + # expanding the input to shape [1,2] and then reducing the result. + was_rank_1 = len(a.shape) == 1 + if was_rank_1: + a = semantic.expand_dims(a, 0, _builder) + + out_lhs, out_rhs = semantic.split(a, _builder) + + if was_rank_1: + # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar. + out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator)) + out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator)) + + return out_lhs, out_rhs + + +@_tensor_member_fn +@builtin +def view(input, *shape, _builder=None): + """ + Returns a tensor with the same elements as `input` but a different shape. + The order of the elements may not be preserved. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + view(x, (32, 32)) + view(x, 32, 32) + """ + warn("view is deprecated, please use reshape with can_reorder being true.") + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder=True, builder=_builder) + + +@_tensor_member_fn +@builtin +def reshape(input, *shape, can_reorder=False, _builder=None): + """ + Returns a tensor with the same number of elements as input but with the + provided shape. + + :param input: The input tensor. + :type input: Block + :param shape: The new shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + reshape(x, (32, 32)) + reshape(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder, _builder) + + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@_tensor_member_fn +@builtin +def expand_dims(input, axis, _builder=None): + """ + Expand the shape of a tensor, by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor. + :type input: tl.tensor + :param axis: The indices to add new axes + :type axis: int | Sequence[int] + + """ + input = semantic.to_tensor(input, _builder) + axis = _constexpr_to_value(axis) + axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = semantic.expand_dims(ret, a, _builder) + return ret + + +@_tensor_member_fn +@builtin +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :type dtype: tl.dtype + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :type fp_downcast_rounding: str, optional + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. + :type bitcast: bool, optional + """ + input = semantic.to_tensor(input, _builder) + dtype = _constexpr_to_value(dtype) + fp_downcast_rounding = _constexpr_to_value(fp_downcast_rounding) + bitcast = _constexpr_to_value(bitcast) + if bitcast: + return semantic.bitcast(input, dtype, _builder) + return semantic.cast(input, dtype, _builder, fp_downcast_rounding) + + +# ----------------------- +# Linear Algebra +# ----------------------- + + +@builtin +def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, + _builder=None): + """ + Returns the matrix product of two blocks. + + The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions. + For three-dimensional blocks, `tl.dot` performs the batched matrix product, + where the first dimension of each block represents the batch dimension. + + :param input: The first tensor to be multiplied. + :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param acc: The accumulator tensor. If not None, the result is added to this tensor. + :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`} + :param input_precision: How to exercise the Tensor Cores for f32 x f32. If + the device does not have Tensor Cores or the inputs are not of dtype f32, + this option is ignored. For devices that do have tensor cores, the + default precision is tf32. + :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`. + :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". + Only one of :code:`input_precision` and :code:`allow_tf32` can be + specified (i.e. at least one must be :code:`None`). + """ + assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + if input_precision is None: + supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions + default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" + input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) + + input_precision = _constexpr_to_value(input_precision) + out_dtype = _constexpr_to_value(out_dtype) + max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) + return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) + + +@builtin +def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, out_dtype=float32, + _builder=None): + """ + Returns the matrix product of two blocks in microscaling format. + + lhs and rhs use microscaling formats described here: + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + + Software emulation enables targeting hardware architectures without native microscaling + operation support. Right now for such case, microscaled lhs/rhs are upcasted to + :code:`bf16` element type beforehand for dot computation, with one exception: + for AMD CDNA3 specifically, if one of the inputs is of :code:`fp16` element type, + the other input is also upcasted to :code:`fp16` element type instead. + This behavior is experimental and may be subject to change in the future. + + :param lhs: The first tensor to be multiplied. + :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type. + :param lhs_scale: Scale factor for lhs tensor. + :type lhs_scale: e8m0 type represented as an uint8 tensor. + :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}. + :type lhs_format: str + :param rhs: The second tensor to be multiplied. + :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type. + :param rhs_scale: Scale factor for rhs tensor. + :type rhs_scale: e8m0 type represented as an uint8 tensor. + :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}. + :type rhs_format: str + :param acc: The accumulator tensor. If not None, the result is added to this tensor. + """ + out_dtype = _constexpr_to_value(out_dtype) + assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment" + return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, out_dtype, + _builder) + + +# ----------------------- +# Non-Atomic Memory Operations +# ----------------------- + + +@builtin +def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="", + volatile=False, _builder=None): + """ + Return a tensor of data whose values are loaded from memory at location defined by `pointer`: + + (1) If `pointer` is a single element pointer, a scalar is be loaded. In + this case: + + - `mask` and `other` must also be scalars, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional tensor is loaded. In this case: + + - `mask` and `other` are implicitly broadcast to `pointer.shape`, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a + tensor is loaded. In this case: + + - `mask` and `other` must be `None`, and + - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access. + + :param pointer: Pointer to the data to be loaded + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]` + (must be `None` with block pointers) + :type mask: Block of `triton.int1`, optional + :param other: if `mask[idx]` is false, return `other[idx]` + :type other: Block, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value. + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional, should be one of {"", "ca", "cg"}, where "ca" stands for + cache at all levels and "cg" stands for cache at global level (cache in L2 and below, not L1), see + `cache operator `_ for more details. + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + :param volatile: changes volatile option in NVIDIA PTX + :type volatile: bool, optional + """ + # `mask` and `other` can be constexpr + mask = _constexpr_to_value(mask) + other = _constexpr_to_value(other) + if mask is not None: + mask = semantic.to_tensor(mask, _builder) + if other is not None: + other = semantic.to_tensor(other, _builder) + padding_option = _constexpr_to_value(padding_option) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + volatile = _constexpr_to_value(volatile) + return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, + volatile, _builder) + + +@builtin +def _experimental_reinterpret_tensor_descriptor(desc_ptr, block_shape, dtype, + _builder=None) -> _experimental_tensor_descriptor_base: + """ + Reinterpret a generic pointer as a TMA-backed tensor descriptor object. + """ + block_ty = block_type(_constexpr_to_value(dtype), block_shape) + return semantic.reinterpret_tensor_descriptor(desc_ptr, block_ty, _builder) + + +@builtin +def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None): + """ + Experimental feature to access TMA descriptors loads. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This loads a tensor of data based on the descriptor and offsets. + """ + desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, shape, dtype, _builder=_builder) + return desc.load(offsets, _builder=_builder) + + +@builtin +def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): + """ + Experimental feature to access TMA descriptors stores. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This stores a tensor of data based on the descriptor and offsets. + """ + desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, value.shape, value.dtype, _builder=_builder) + return desc.store(offsets, value, _builder=_builder) + + +@_tensor_member_fn +@builtin +def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None): + """ + Store a tensor of data into memory locations defined by `pointer`. + + (1) If `pointer` is a single element pointer, a scalar is stored. In + this case: + + - `mask` must also be scalar, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional block is stored. In this case: + + - `mask` is implicitly broadcast to `pointer.shape`, and + - `boundary_check` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block + of data is stored. In this case: + + - `mask` must be None, and + - `boundary_check` can be specified to control the behavior of out-of-bound access. + + `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`. + + :param pointer: The memory location where the elements of `value` are stored + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param value: The tensor of elements to be stored + :type value: Block + :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]` + :type mask: Block of triton.int1, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional, should be one of {"", ".wb", ".cg", ".cs", ".wt"}, where ".wb" stands for + cache write-back all coherent levels, ".cg" stands for cache global, ".cs" stands for cache streaming, ".wt" + stands for cache write-through, see `cache operator `_ for more details. + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"} + """ + # `value` can be constexpr + value = semantic.to_tensor(value, _builder) + mask = _constexpr_to_value(mask) + if mask is not None: + mask = semantic.to_tensor(mask, _builder) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder) + + +@builtin +def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None): + """ + Returns a pointer to a block in a parent tensor + + :param base: The base pointer to the parent tensor + :param shape: The shape of the parent tensor + :param strides: The strides of the parent tensor + :param offsets: The offsets to the block + :param block_shape: The shape of the block + :param order: The order of the original data format + """ + return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder) + + +@_tensor_member_fn +@builtin +def advance(base, offsets, _builder=None): + """ + Advance a block pointer + + :param base: the block pointer to advance + :param offsets: the offsets to advance, a tuple by dimension + """ + return semantic.advance(base, offsets, _builder) + + +@builtin +def _experimental_make_tensor_descriptor( + base: tensor, + shape: List[tensor], + strides: List[tensor], + block_shape: List[constexpr], + _builder=None, +) -> _experimental_tensor_descriptor: + """Make an experimental tensor descriptor object + + :param base: the base pointer of the tensor, must be 16-byte aligned + :param shape: A list of non-negative integers representing the tensor shape + :param strides: A list of tensor strides. Leading dimensions must be multiples + of 16-byte strides and the last dimension must be contiguous. + :param block_shape: The shape of block to be loaded/stored from global memory + + Notes + ***** + On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object + and loads and stores from the descriptor will be backed by the TMA hardware. + + Currently only 2-5 dimensional tensors are supported. + + Example + ******* + .. code-block:: python + + @triton.jit + def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + desc = tl._experimental_make_tensor_descriptor( + in_out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + + value = desc.load([moffset, noffset]) + desc.store([moffset, noffset], tl.abs(value)) + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + M, N = 256, 256 + x = torch.randn(M, N, device="cuda") + M_BLOCK, N_BLOCK = 32, 32 + grid = (M / M_BLOCK, N / N_BLOCK) + inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK) + + """ + return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder) + + +# ----------------------- +# Atomic Memory Operations +# ----------------------- + + +def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = f""" + Performs an atomic {name} at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. + + :param pointer: The memory locations to operate on + :type pointer: Block of dtype=triton.PointerDType""" + if has_cmp: + docstr += """ + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=pointer.dtype.element_ty""" + docstr += """ + :param val: The values with which to perform the atomic operation + :type val: Block of dtype=pointer.dtype.element_ty + :param sem: Specifies the memory semantics for the operation. Acceptable values are "acquire", + "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, + the function defaults to using "acq_rel" semantics. + :type sem: str, optional + :param scope: Defines the scope of threads that observe the synchronizing effect of the atomic operation. + Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + :type scope: str, optional + """ + func.__doc__ = docstr + return func + + return _decorator + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("compare-and-swap", has_cmp=True) +def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None): + cmp = semantic.to_tensor(cmp, _builder) + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("exchange") +def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("add") +def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_add(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("max") +def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_max(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("min") +def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_min(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical and") +def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_and(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical or") +def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_or(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical xor") +def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder) + + +# ----------------------- +# Conditioning +# ----------------------- + + +@builtin +def where(condition, x, y, _builder=None): + """ + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + + Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. + + If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. + + The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`. + :code:`x` and :code:`y` must have the same data type. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + """ + condition = semantic.to_tensor(condition, _builder) + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return semantic.where(condition, x, y, _builder) + + +# ----------------------- +# Math +# ----------------------- + + +@builtin +def add(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return semantic.add(x, y, sanitize_overflow, _builder) + + +@builtin +def sub(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return semantic.sub(x, y, sanitize_overflow, _builder) + + +@builtin +def mul(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return semantic.mul(x, y, sanitize_overflow, _builder) + + +@builtin +def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise minimum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.minimum(x, y, propagate_nan, _builder) + + +@builtin +def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise maximum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.maximum(x, y, propagate_nan, _builder) + + +@builtin +def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Clamps the input tensor :code:`x` within the range [min, max]. + Behavior when :code:`min` > :code:`max` is undefined. + + :param x: the input tensor + :type x: Block + :param min: the lower bound for clamping + :type min: Block + :param max: the upper bound for clamping + :type max: Block + :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor. + If either :code:`min` or :code:`max` is NaN, the result is undefined. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = semantic.to_tensor(x, _builder) + min = semantic.to_tensor(min, _builder) + max = semantic.to_tensor(max, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + min = _promote_bfloat16_to_float32(min, _builder=_builder) + max = _promote_bfloat16_to_float32(max, _builder=_builder) + + propagate_nan = _constexpr_to_value(propagate_nan) + + return semantic.clamp(x, min, max, propagate_nan, _builder) + + +# ----------------------- +# Reductions +# ----------------------- + + +def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None, + dtype_arg: str = None) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :type input: Tensor + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :type axis: int + :param keep_dims: if true, keep the reduced dimensions with length 1 + :type keep_dims: bool""" + if return_indices_arg is not None: + docstr += f""" + :param {return_indices_arg}: if true, return index corresponding to the {name} value + :type {return_indices_arg}: bool""" + if tie_break_arg is not None: + docstr += f""" + :param {tie_break_arg}: if true, in case of a tie (i.e., multiple elements have the same {name} value), return the left-most index for values that aren't NaN + :type {tie_break_arg}: bool""" + if dtype_arg is not None: + docstr += f""" + :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. This is useful for preventing data overflows. If not specified, integer and bool dtypes are upcasted to :code:`tl.int32` and float dtypes are upcasted to at least :code:`tl.float32`. + :type {dtype_arg}: tl.dtype""" + + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + + +@_tensor_member_fn +@builtin +def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + + :param input: the input tensor, or tuple of tensors + :type input: Tensor + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :type axis: int | None + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :type combine_fn: Callable + :param keep_dims: if true, keep the reduced dimensions with length 1 + :type keep_dims: bool + + """ + if isinstance(input, tensor): + return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(reduce_op): + param_types = [t.type.scalar for t in input] * 2 + region = reduce_op.get_region(0) + with _insertion_guard(_builder): + to_ir = lambda T: T.to_ir(_builder) + block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_reduce_ret(*handles) + + def expand_ndims(t, ndims): + for _ in builtins.range(ndims): + t = expand_dims(t, 0, _builder=_builder) + return t + + axis = _constexpr_to_value(axis) + keep_dims = _constexpr_to_value(keep_dims) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + ret = semantic.reduction(input, axis, make_combine_region, _builder) + if keep_dims: + if axis is not None: + ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret) + else: + ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret) + return ret + + +@builtin +def _promote_bfloat16_to_float32(t, _builder=None): + scalar_ty = t.type.scalar + + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _builder=_builder) + return t + + +@builtin +def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + axis = _constexpr_to_value(axis) + n = input.shape[axis] + index = arange(0, n, _builder=_builder) + + if len(input.shape) > 1: + # Broadcast index across the non-reduced axes + axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _builder=_builder) + index = broadcast_to(index, input.shape, _builder=_builder) + + rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, + _generator=_generator) + return rvalue, rindices + + +# ----------------------- +# Scans +# ----------------------- + + +def _add_scan_docstr(name: str) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :type input: Tensor + :param axis: the dimension along which the scan should be done + :type axis: int""" + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@_tensor_member_fn +@builtin +def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None): + """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry + + :param input: the input tensor, or tuple of tensors + :type input: Tensor + :param axis: the dimension along which the reduction should be done + :type axis: int + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :type combine_fn: Callable + :param reverse: whether to apply the associative scan in the reverse direction along axis + :type reverse: bool + + """ + if isinstance(input, tensor): + return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(scan_op): + param_types = [t.type.scalar for t in input] * 2 + region = scan_op.get_region(0) + with _insertion_guard(_builder): + to_ir = lambda T: T.to_ir(_builder) + block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_scan_ret(*handles) + + axis = _constexpr_to_value(axis) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + return semantic.associative_scan(input, axis, make_combine_region, reverse, _builder) + + +@_tensor_member_fn +@builtin +def histogram(input, num_bins, _builder=None, _generator=None): + """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0. + + :param input: the input tensor + :type input: Tensor + :param num_bins: number of histogram bins + :type num_bins: int + + """ + num_bins = _constexpr_to_value(num_bins) + return semantic.histogram(input, num_bins, _builder) + + +@_tensor_member_fn +@builtin +def gather(src, index, axis, _builder=None): + """Gather from a tensor along a given dimension. + + :param src: the source tensor + :type src: Tensor + :param index: the index tensor + :type index: Tensor + :param axis: the dimension to gather along + :type axis: int + + """ + axis = _constexpr_to_value(axis) + return semantic.gather(src, index, axis, _builder) + + +# ----------------------- +# Compiler Hint Ops +# ----------------------- + + +@builtin +def debug_barrier(_builder=None): + ''' + Insert a barrier to synchronize all threads in a block. + ''' + return semantic.debug_barrier(_builder) + + +@builtin +def multiple_of(input, values, _builder=None): + """ + Let the compiler know that the values in :code:`input` are all multiples of :code:`value`. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.multiple_of(input, values) + + +@builtin +def max_contiguous(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are contiguous. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_contiguous(input, values) + + +@builtin +def max_constancy(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are constant. + + e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal, + for example [0, 0, 0, 0, 1, 1, 1, 1]. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_constancy(input, values) + + +@builtin +def assume(cond, _builder=None): + ''' + Allow compiler to assume the :code:`cond` is True. + ''' + return semantic.assume(semantic.to_tensor(cond, _builder), _builder) + + +# ----------------------- +# Debugging functions +# ----------------------- + + +@builtin +def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None): + ''' + Print the values at compile time. The parameters are the same as the builtin :code:`print`. + + NOTE: Calling the Python builtin :code:`print` is not the same as calling this, it instead maps to :code:`device_print`, + which has special requirements for the arguments. + + .. highlight:: python + .. code-block:: python + + tl.static_print(f"BLOCK_SIZE={BLOCK_SIZE}") + ''' + pass + + +@builtin +def static_assert(cond, msg="", _builder=None): + ''' + Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable + is set. + + .. highlight:: python + .. code-block:: python + + tl.static_assert(BLOCK_SIZE == 1024) + ''' + pass + + +@builtin +def device_print(prefix, *args, hex=False, _builder=None): + ''' + Print the values at runtime from the device. String formatting does not work for runtime values, so you should + provide the values you want to print as arguments. The first value must be a string, all following values must + be scalars or tensors. + + Calling the Python builtin :code:`print` is the same as calling this function, and the requirements for the arguments will match + this function (not the normal requirements for :code:`print`). + + .. highlight:: python + .. code-block:: python + + tl.device_print("pid", pid) + print("pid", pid) + + On CUDA, printfs are streamed through a buffer of limited size (on one host, + we measured the default as 6912 KiB, but this may not be consistent across + GPUs and CUDA versions). If you notice some printfs are being dropped, you + can increase the buffer size by calling + + .. highlight:: python + .. code-block:: python + + triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes) + + CUDA may raise an error if you try to change this value after running a + kernel that uses printfs. The value set here may only affect the current + device (so if you have multiple GPUs, you'd need to call it multiple times). + + :param prefix: a prefix to print before the values. This is required to be a string literal. + :param args: the values to print. They can be any tensor or scalar. + :param hex: print all values as hex instead of decimal + ''' + import string + prefix = _constexpr_to_value(prefix) + assert isinstance(prefix, str), f"{prefix} is not string" + b_ascii = True + for ch in prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(semantic.to_tensor(arg, _builder)) + return semantic.device_print(prefix, new_args, hex, _builder) + + +@builtin +def device_assert(cond, msg="", _builder=None): + ''' + Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG` + is set to a value besides :code:`0` in order for this to have any effect. + + Using the Python :code:`assert` statement is the same as calling this function, except that the second argument + must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must + be set for this :code:`assert` statement to have any effect. + + .. highlight:: python + .. code-block:: python + + tl.device_assert(pid == 0) + assert pid == 0, f"pid != 0" + + :param cond: the condition to assert. This is required to be a boolean tensor. + :param msg: the message to print if the assertion fails. This is required to be a string literal. + ''' + msg = _constexpr_to_value(msg) + return semantic.device_assert(semantic.to_tensor(cond, _builder), msg, _builder) + + +@builtin +def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]], + is_pure: bool, pack: int, _builder=None): + ''' + Execute inline assembly over a tensor. Essentially, this is :code:`map` + where the function is inline assembly. + + The input tensors :code:`args` are implicitly broadcasted to the same shape. + + :code:`dtype` can be a tuple of types, in which case the output is a + tuple of tensors. + + Each invocation of the inline asm processes :code:`pack` elements at a + time. Exactly which set of inputs a block receives is unspecified. + Input elements of size less than 4 bytes are packed into 4-byte + registers. + + This op does not support empty :code:`dtype` -- the inline asm must + return at least one tensor, even if you don't need it. You can work + around this by returning a dummy tensor of arbitrary type; it shouldn't + cost you anything if you don't use it. + + Example using + `PTX `_ + assembly: + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor + b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + :param asm: assembly to run. Must match target's assembly format. + :param constraints: asm constraints in + `LLVM format `_ + :param args: the input tensors, whose values are passed to the asm block + :param dtype: the element type(s) of the returned tensor(s) + :param is_pure: if true, the compiler assumes the asm block has no side-effects + :param pack: the number of elements to be processed by one instance of inline assembly + :param _builder: the builder + :return: one tensor or a tuple of tensors of the given dtypes + ''' + asm = _constexpr_to_value(asm) + constraints = _constexpr_to_value(constraints) + pack = _constexpr_to_value(pack) + is_pure = _constexpr_to_value(is_pure) + + # Wrap `dtype` in a tuple if it's not already. + try: + iter(dtype) # type: ignore + has_multiple_outputs = True + except TypeError: + has_multiple_outputs = False + dtype = (dtype, ) # type: ignore + + dtype = typing.cast(Sequence[_DtypeClass], dtype) + + res_tys = dtype + if dispatch_args := [semantic.to_tensor(arg, _builder) for arg in args]: + bin_op_type_checking = partial( + semantic.binary_op_type_checking_impl, + builder=_builder, + arithmetic_check=False, + allow_lhs_ptr=True, + allow_rhs_ptr=True, + ) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = bin_op_type_checking(item, broadcast_arg) + if broadcast_arg.shape: + # Change the shape of each argument based on the broadcast shape + for i, item in enumerate(dispatch_args): + dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg) + res_tys = [block_type(dt, broadcast_arg.shape) for dt in dtype] + handles = [t.handle for t in dispatch_args] + call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in res_tys], is_pure, pack) + + if not has_multiple_outputs: + return tensor(call.get_result(0), res_tys[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) + + +# ----------------------- +# Iterators +# ----------------------- + + +class static_range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.static_range(10): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + """ + + def __init__(self, arg1, arg2=None, step=None): + assert isinstance(arg1, constexpr), f"{arg1} used as tl.static_range start value is not a constexpr" + if step is None: + self.step = constexpr(1) + else: + assert isinstance(step, constexpr), f"{step} used as tl.static_range step value is not a constexpr" + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + assert isinstance(arg2, constexpr), f"{arg2} used as tl.static_range end value is not a constexpr" + self.start = arg1 + self.end = arg2 + + def __iter__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + +class range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.range(10, num_stages=3): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + :param num_stages: pipeline the loop into this many stages (so there are + :code:`num_stages` iterations of the loop in flight at once). + + Note this is subtly different than passing :code:`num_stages` as a + kernel argument. The kernel argument only pipelines loads that feed + into :code:`dot` operations, while this attribute tries to pipeline most + (though not all) loads in this loop. + :param loop_unroll_factor: Tells the Triton IR level loop unroller how many + times to unroll a for loop that this range is used with. Less than 2 for + this value implies no unrolling. + :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot + operation in the loop to be multi-buffered, if applicable. + :param flatten: automatically flatten the loop nest starting at this loop to + create a single flattened loop. The compiler will try to pipeline the + flattened loop which can avoid stage stalling. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, + disallow_acc_multi_buffer=False, flatten=False): + if step is None: + self.step = constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + self.num_stages = num_stages + self.loop_unroll_factor = loop_unroll_factor + self.disallow_acc_multi_buffer = disallow_acc_multi_buffer + self.flatten = flatten + + def __iter__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + +# ----------------------- +# Extern functions +# ----------------------- + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, + is_pure: bool, _builder=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + if ret_shape: + ret_type = block_type(ret_type, ret_shape) + return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type) + + +@builtin +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _builder=None): + ''' + Dispatch an elementwise function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :param _builder: the builder + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + ret_shape = None + arg_types = [] + for i in builtins.range(len(dispatch_args)): + dispatch_args[i] = semantic.to_tensor(dispatch_args[i], _builder) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if len(arg_types) > 0: + arg_types = tuple(arg_types) + arithmetic_check = True + # If there's a type tuple that is not supported by the library, we will do arithmetic check + if arg_types in arg_type_symbol_dict: + arithmetic_check = False + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + # Change the shape of each argument based on the broadcast shape + for i in builtins.range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + if not all_scalar: + ret_shape = broadcast_arg.shape + func = _builder.create_extern_elementwise + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder) + + +def binary_op_type_legalization(lhs, rhs, builder): + ''' + Convert both operands to a single common type + :param lhs: the left operand + :param rhs: the right operand + :param builder: the builder + ''' + return semantic.binary_op_type_checking_impl(lhs, rhs, builder) + + +def extern(fn): + """A decorator for external functions.""" + return builtin(fn) diff --git a/third_party/enflame/include/triton/python/triton/language/extra/__init__.py b/third_party/enflame/include/triton/python/triton/language/extra/__init__.py new file mode 100644 index 000000000..3f8c70a71 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/language/extra/__init__.py @@ -0,0 +1,26 @@ +import pkgutil +from importlib.util import module_from_spec +from sys import modules + +_backends = [] +for module_finder, module_name, is_pkg in pkgutil.iter_modules( + __path__, + prefix=__name__ + ".", +): + # skip .py files (like libdevice.py) + if not is_pkg: + continue + + # import backends (like cuda and hip) that are included during setup.py + spec = module_finder.find_spec(module_name) + if spec is None or spec.loader is None: + continue + module = module_from_spec(spec) + spec.loader.exec_module(module) + + _backends.append(module_name) + modules[module_name] = module + +__all__ = _backends + +del _backends diff --git a/third_party/enflame/include/triton/python/triton/language/extra/libdevice.py b/third_party/enflame/include/triton/python/triton/language/extra/libdevice.py new file mode 100644 index 000000000..76627035d --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/language/extra/libdevice.py @@ -0,0 +1,786 @@ +def clz(arg0): + ... + + +def popc(arg0): + ... + + +def byte_perm(arg0, arg1, arg2): + ... + + +def mulhi(arg0, arg1): + ... + + +def mul24(arg0, arg1): + ... + + +def brev(arg0): + ... + + +def sad(arg0, arg1, arg2): + ... + + +def abs(arg0): + ... + + +def floor(arg0): + ... + + +def rcp64h(arg0): + ... + + +def rsqrt(arg0): + ... + + +def ceil(arg0): + ... + + +def trunc(arg0): + ... + + +def exp2(arg0): + ... + + +def saturatef(arg0): + ... + + +def fma_rn(arg0, arg1, arg2): + ... + + +def fma_rz(arg0, arg1, arg2): + ... + + +def fma_rd(arg0, arg1, arg2): + ... + + +def fma_ru(arg0, arg1, arg2): + ... + + +def fast_dividef(arg0, arg1): + ... + + +def div_rn(arg0, arg1): + ... + + +def div_rz(arg0, arg1): + ... + + +def div_rd(arg0, arg1): + ... + + +def div_ru(arg0, arg1): + ... + + +def rcp_rn(arg0): + ... + + +def rcp_rz(arg0): + ... + + +def rcp_rd(arg0): + ... + + +def rcp_ru(arg0): + ... + + +def sqrt_rn(arg0): + ... + + +def sqrt_rz(arg0): + ... + + +def sqrt_rd(arg0): + ... + + +def sqrt_ru(arg0): + ... + + +def sqrt(arg0): + ... + + +def add_rn(arg0, arg1): + ... + + +def add_rz(arg0, arg1): + ... + + +def add_rd(arg0, arg1): + ... + + +def add_ru(arg0, arg1): + ... + + +def mul_rn(arg0, arg1): + ... + + +def mul_rz(arg0, arg1): + ... + + +def mul_rd(arg0, arg1): + ... + + +def mul_ru(arg0, arg1): + ... + + +def double2float_rn(arg0): + ... + + +def double2float_rz(arg0): + ... + + +def double2float_rd(arg0): + ... + + +def double2float_ru(arg0): + ... + + +def double2int_rn(arg0): + ... + + +def double2int_rz(arg0): + ... + + +def double2int_rd(arg0): + ... + + +def double2int_ru(arg0): + ... + + +def double2uint_rn(arg0): + ... + + +def double2uint_rz(arg0): + ... + + +def double2uint_rd(arg0): + ... + + +def double2uint_ru(arg0): + ... + + +def int2double_rn(arg0): + ... + + +def uint2double_rn(arg0): + ... + + +def float2int_rn(arg0): + ... + + +def float2int_rz(arg0): + ... + + +def float2int_rd(arg0): + ... + + +def float2int_ru(arg0): + ... + + +def float2uint_rn(arg0): + ... + + +def float2uint_rz(arg0): + ... + + +def float2uint_rd(arg0): + ... + + +def float2uint_ru(arg0): + ... + + +def int2float_rn(arg0): + ... + + +def int2float_rz(arg0): + ... + + +def int2float_rd(arg0): + ... + + +def int2float_ru(arg0): + ... + + +def uint2float_rn(arg0): + ... + + +def uint2float_rz(arg0): + ... + + +def uint2float_rd(arg0): + ... + + +def uint2float_ru(arg0): + ... + + +def hiloint2double(arg0, arg1): + ... + + +def double2loint(arg0): + ... + + +def double2hiint(arg0): + ... + + +def float2ll_rn(arg0): + ... + + +def float2ll_rz(arg0): + ... + + +def float2ll_rd(arg0): + ... + + +def float2ll_ru(arg0): + ... + + +def float2ull_rn(arg0): + ... + + +def float2ull_rz(arg0): + ... + + +def float2ull_rd(arg0): + ... + + +def float2ull_ru(arg0): + ... + + +def double2ll_rn(arg0): + ... + + +def double2ll_rz(arg0): + ... + + +def double2ll_rd(arg0): + ... + + +def double2ll_ru(arg0): + ... + + +def double2ull_rn(arg0): + ... + + +def double2ull_rz(arg0): + ... + + +def double2ull_rd(arg0): + ... + + +def double2ull_ru(arg0): + ... + + +def ll2float_rn(arg0): + ... + + +def ll2float_rz(arg0): + ... + + +def ll2float_rd(arg0): + ... + + +def ll2float_ru(arg0): + ... + + +def ull2float_rn(arg0): + ... + + +def ull2float_rz(arg0): + ... + + +def ull2float_rd(arg0): + ... + + +def ull2float_ru(arg0): + ... + + +def ll2double_rn(arg0): + ... + + +def ll2double_rz(arg0): + ... + + +def ll2double_rd(arg0): + ... + + +def ll2double_ru(arg0): + ... + + +def ull2double_rn(arg0): + ... + + +def ull2double_rz(arg0): + ... + + +def ull2double_rd(arg0): + ... + + +def ull2double_ru(arg0): + ... + + +def int_as_float(arg0): + ... + + +def float_as_int(arg0): + ... + + +def uint_as_float(arg0): + ... + + +def float_as_uint(arg0): + ... + + +def longlong_as_double(arg0): + ... + + +def double_as_longlong(arg0): + ... + + +def fast_sinf(arg0): + ... + + +def fast_cosf(arg0): + ... + + +def fast_log2f(arg0): + ... + + +def fast_logf(arg0): + ... + + +def fast_expf(arg0): + ... + + +def fast_tanf(arg0): + ... + + +def fast_exp10f(arg0): + ... + + +def fast_log10f(arg0): + ... + + +def fast_powf(arg0, arg1): + ... + + +def hadd(arg0, arg1): + ... + + +def rhadd(arg0, arg1): + ... + + +def sub_rn(arg0, arg1): + ... + + +def sub_rz(arg0, arg1): + ... + + +def sub_rd(arg0, arg1): + ... + + +def sub_ru(arg0, arg1): + ... + + +def rsqrt_rn(arg0): + ... + + +def ffs(arg0): + ... + + +def rint(arg0): + ... + + +def llrint(arg0): + ... + + +def nearbyint(arg0): + ... + + +def isnan(arg0): + ... + + +def signbit(arg0): + ... + + +def copysign(arg0, arg1): + ... + + +def finitef(arg0): + ... + + +def isinf(arg0): + ... + + +def nextafter(arg0, arg1): + ... + + +def sin(arg0): + ... + + +def cos(arg0): + ... + + +def sinpi(arg0): + ... + + +def cospi(arg0): + ... + + +def tan(arg0): + ... + + +def log2(arg0): + ... + + +def exp(arg0): + ... + + +def exp10(arg0): + ... + + +def cosh(arg0): + ... + + +def sinh(arg0): + ... + + +def tanh(arg0): + ... + + +def atan2(arg0, arg1): + ... + + +def atan(arg0): + ... + + +def asin(arg0): + ... + + +def acos(arg0): + ... + + +def log(arg0): + ... + + +def log10(arg0): + ... + + +def log1p(arg0): + ... + + +def acosh(arg0): + ... + + +def asinh(arg0): + ... + + +def atanh(arg0): + ... + + +def expm1(arg0): + ... + + +def hypot(arg0, arg1): + ... + + +def rhypot(arg0, arg1): + ... + + +def norm3d(arg0, arg1, arg2): + ... + + +def rnorm3d(arg0, arg1, arg2): + ... + + +def norm4d(arg0, arg1, arg2, arg3): + ... + + +def rnorm4d(arg0, arg1, arg2, arg3): + ... + + +def cbrt(arg0): + ... + + +def rcbrt(arg0): + ... + + +def j0(arg0): + ... + + +def j1(arg0): + ... + + +def y0(arg0): + ... + + +def y1(arg0): + ... + + +def yn(arg0, arg1): + ... + + +def jn(arg0, arg1): + ... + + +def cyl_bessel_i0(arg0): + ... + + +def cyl_bessel_i1(arg0): + ... + + +def erf(arg0): + ... + + +def erfinv(arg0): + ... + + +def erfc(arg0): + ... + + +def erfcx(arg0): + ... + + +def erfcinv(arg0): + ... + + +def normcdfinv(arg0): + ... + + +def normcdf(arg0): + ... + + +def lgamma(arg0): + ... + + +def ldexp(arg0, arg1): + ... + + +def scalbn(arg0, arg1): + ... + + +def fmod(arg0, arg1): + ... + + +def remainder(arg0, arg1): + ... + + +def fma(arg0, arg1, arg2): + ... + + +def pow(arg0, arg1): + ... + + +def tgamma(arg0): + ... + + +def round(arg0): + ... + + +def llround(arg0): + ... + + +def fdim(arg0, arg1): + ... + + +def ilogb(arg0): + ... + + +def logb(arg0): + ... + + +def isfinited(arg0): + ... diff --git a/third_party/enflame/include/triton/python/triton/language/math.py b/third_party/enflame/include/triton/python/triton/language/math.py new file mode 100644 index 000000000..0ecdcf2ea --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/language/math.py @@ -0,0 +1,250 @@ +from . import core +from . import semantic +from functools import wraps +from typing import List + +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. + """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + if arg.type.scalar.name not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`. + + :param x: the input values + :type x: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x` and :code:`y`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + :param z: the input values + :type z: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@core.builtin +@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _builder=None): + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_cos(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sin(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)") +@core._tensor_member_fn +def sqrt_rn(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_rsqrt(x.handle), x.type) + + +@core._tensor_member_fn +@core.builtin +@_add_math_1arg_docstr("absolute value") +def abs(x, _builder=None): + x = semantic.to_tensor(x, _builder) + dtype = x.dtype + if dtype.is_fp8e4b15(): + mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder) + return core.tensor(_builder.create_and(x.handle, mask.handle), x.type) + elif dtype.is_floating(): + return core.tensor(_builder.create_fabs(x.handle), x.type) + elif dtype.is_int_signed(): + return core.tensor(_builder.create_iabs(x.handle), x.type) + elif dtype.is_int_unsigned(): + return x # no-op + else: + assert False, f"Unexpected dtype {dtype}" + + +@core.builtin +@_add_math_2arg_docstr("fast division") +def fdiv(x, y, ieee_rounding=False, _builder=None): + ieee_rounding = core._constexpr_to_value(ieee_rounding) + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + return semantic.fdiv(x, y, ieee_rounding, _builder) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)") +def div_rn(x, y, _builder=None): + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def erf(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_erf(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("floor") +@core._tensor_member_fn +def floor(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_floor(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("ceil") +@core._tensor_member_fn +def ceil(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_ceil(x.handle), x.type) + + +@core.builtin +@_add_math_3arg_docstr("fused multiply-add") +def fma(x, y, z, _builder=None): + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + z = semantic.to_tensor(z, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + z, x = core.binary_op_type_legalization(z, x, _builder) + z, y = core.binary_op_type_legalization(z, y, _builder) + return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type) diff --git a/third_party/enflame/include/triton/python/triton/language/random.py b/third_party/enflame/include/triton/python/triton/language/random.py new file mode 100644 index 000000000..58dd20569 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/language/random.py @@ -0,0 +1,208 @@ +from ..runtime.jit import jit +from . import core as tl +from . import math + +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +# ------------------- +# randint +# ------------------- + + +@jit +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). + """ + if c0.dtype == tl.uint32: + PHILOX_KEY_A: tl.constexpr = 0x9E3779B9 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE85 + PHILOX_ROUND_A: tl.constexpr = 0xD2511F53 + PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57 + else: + tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl") + PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B + PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93 + PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157 + + for _ in tl.static_range(n_rounds): + # for _ in range(n_rounds): + # update random state + A = PHILOX_ROUND_A + B = PHILOX_ROUND_B + _c0, _c2 = c0, c2 + c0 = math.umulhi(B, _c2) ^ c1 ^ k0 + c2 = math.umulhi(A, _c0) ^ c3 ^ k1 + c1 = tl.mul(B, _c2, sanitize_overflow=False) + c3 = tl.mul(A, _c0, sanitize_overflow=False) + # raise key + k0 = tl.add(k0, PHILOX_KEY_A, sanitize_overflow=False) + k1 = tl.add(k1, PHILOX_KEY_B, sanitize_overflow=False) + return c0, c1, c2, c3 + + +@jit +def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + seed = tl.to_tensor(seed) + tl.static_assert(seed.dtype.is_int()) + seed = seed.to(tl.uint64) + c0 = tl.to_tensor(c0) + c1 = tl.to_tensor(c1) + c2 = tl.to_tensor(c2) + c3 = tl.to_tensor(c3) + if tl.constexpr(c0.dtype.primitive_bitwidth) == 32: + int_dtype = tl.uint32 + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) + else: + tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox") + int_dtype = tl.uint64 + seed_hi = tl.full((1, ), 0, dtype=int_dtype) + seed_lo = seed + c0 = c0.to(int_dtype, bitcast=True) + c1 = c1.to(int_dtype, bitcast=True) + c2 = c2.to(int_dtype, bitcast=True) + c3 = c3.to(int_dtype, bitcast=True) + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +@jit +def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offset: The offsets to generate random numbers for. + """ + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +@jit +def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point + to Triton's Philox pseudo-random number generator. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + # _0 = tl.zeros(offset.shape, offset.dtype) + _0 = offset * 0 + return philox(seed, offset, _0, _0, _0, n_rounds) + + +# ------------------- +# rand +# ------------------- + +# @jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + + +@jit +def uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). + """ + # TODO: fix frontend issues and cleanup + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32): + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + x = x.to(tl.int32, bitcast=True) + scale = 4.6566127342e-10 + else: + tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)) + x = x.to(tl.int64, bitcast=True) + scale = 1.0842020432385337e-19 + x = tl.where(x < 0, -x - 1, x) + return x * scale + + +@jit +def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + source = randint(seed, offset, n_rounds) + return uint_to_uniform_float(source) + + +@jit +def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offsets` block, + returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + u3 = uint_to_uniform_float(i3) + u4 = uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +# ------------------- +# randn +# ------------------- + + +@jit +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = math.sqrt(-2.0 * math.log(u1)) + return r * math.cos(th), r * math.sin(th) + + +@jit +def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + + +@jit +def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + u1, u2, u3, u4 = rand4x(seed, offset, n_rounds) + n1, n2 = pair_uniform_to_normal(u1, u2) + n3, n4 = pair_uniform_to_normal(u3, u4) + return n1, n2, n3, n4 diff --git a/third_party/enflame/include/triton/python/triton/language/semantic.py b/third_party/enflame/include/triton/python/triton/language/semantic.py new file mode 100644 index 000000000..431893560 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/language/semantic.py @@ -0,0 +1,1950 @@ +from __future__ import annotations # remove after python 3.11 +import warnings + +from typing import List, Optional, Sequence, Tuple, TypeVar +import numbers + +from .._C.libtriton import ir +from . import core as tl + +T = TypeVar('T') + + +class IncompatibleTypeErrorImpl(Exception): + + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorImpl, self).__init__(self.message) + + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + + +def program_id(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_program_id(axis), tl.int32) + + +def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + +def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. + if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + raise TypeError(f"unexpected signedness {a_sn} and {b_sn}") + + +def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool, + div_or_mod: bool) -> tl.dtype: + # 0) For scalars we follow semantics similar to PyTorch, namely: + # - If the scalar is of a lower or equal kind (bool < uint < int < fp), + # it doesn't participate in the promotion + if a_is_scalar != b_is_scalar: + scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty) + if scalar_ty.kind().value <= tensor_ty.kind().value: + # Upcast because of 3) and 4) below! + if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)): + return tl.float32 + return tensor_ty + + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() and b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + else: + return tl.bfloat16 + if a_ty.is_bf16() or b_ty.is_bf16(): + return tl.float32 + # 5) return fp16 if operands are different fp8 + if a_ty.is_fp8() and b_ty.is_fp8(): + return a_ty if a_ty == b_ty else tl.float16 + if not a_ty.is_int() or not b_ty.is_int(): + raise TypeError(f"unexpected type {a_ty} and {b_ty}") + # 6 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + + " because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + return integer_promote_impl(a_ty, b_ty) + + +def to_tensor(x, builder, check_type: bool = True): + if isinstance(x, bool): + return tl.tensor(builder.get_int1(x), tl.int1) + # Note: compile-time const integers are represented by unsigned values + elif isinstance(x, int): + if -2**31 <= x < 2**31: + dtype = tl.int32 + elif 2**31 <= x < 2**32: + dtype = tl.uint32 + elif -2**63 <= x < 2**63: + dtype = tl.int64 + elif 2**63 <= x < 2**64: + dtype = tl.uint64 + else: + raise ValueError(f'Nonrepresentable integer {x}.') + return full((), x, dtype=dtype, builder=builder) + elif isinstance(x, float): + min_float32 = 2**-126 + max_float32 = (2 - 2**-23) * 2**127 + abs_x = __builtins__['abs'](x) + if abs_x == float("inf") or\ + abs_x == 0.0 or \ + x != x or \ + min_float32 <= abs_x <= max_float32: + dtype = tl.float32 + else: + dtype = tl.float64 + return full((), x, dtype=dtype, builder=builder) + + elif isinstance(x, tl.constexpr): + return to_tensor(x.value, builder) + elif isinstance(x, tl.tensor): + return x + if check_type: + raise TypeError(f"cannot convert {x} of type {type(x)} to tensor") + return x + + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + +def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorImpl(type_a, type_b) + + +def binary_op_type_checking_impl(lhs: tl.tensor | numbers.Number, rhs: tl.tensor | numbers.Number, builder: ir.builder, + allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True, + div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]: + lhs_is_scalar = isinstance(lhs, numbers.Number) + rhs_is_scalar = isinstance(rhs, numbers.Number) + if lhs_is_scalar: + lhs_scalar = lhs + lhs = to_tensor(lhs, builder) + if rhs_is_scalar: + rhs_scalar = rhs + rhs = to_tensor(rhs, builder) + + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod) + if (lhs_is_scalar and lhs_scalar < 0 and ret_sca_ty.is_int_unsigned() + or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()): + raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. " + "Perform a explicit cast on one of them.") + if ret_sca_ty.is_int(): + if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <= ret_sca_ty.get_int_max_value()): + raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}") + if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <= ret_sca_ty.get_int_max_value()): + raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}") + lhs = full( + (), lhs_scalar, dtype=ret_sca_ty, builder=builder) if lhs_is_scalar else cast(lhs, ret_sca_ty, builder) + rhs = full( + (), rhs_scalar, dtype=ret_sca_ty, builder=builder) if rhs_is_scalar else cast(rhs, ret_sca_ty, builder) + + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + return lhs, rhs + + +def binary_op_sanitize_overflow_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, binary_op: callable): + if lhs.type.scalar.int_bitwidth >= 64 or not builder.options.sanitize_overflow: + return + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + assert lhs_sca_ty == rhs_sca_ty + assert lhs_sca_ty.is_int() + lhs = cast(lhs, tl.int64, builder) + rhs = cast(rhs, tl.int64, builder) + ret = binary_op(lhs, rhs, False, builder) + max_value = lhs_sca_ty.get_int_max_value() + max_value = tl.tensor(builder.get_int64(max_value), tl.int64) + min_value = lhs_sca_ty.get_int_min_value() + min_value = tl.tensor(builder.get_int64(min_value), tl.int64) + cond = and_(less_equal(ret, max_value, builder), greater_equal(ret, min_value, builder), builder) + msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}" + device_assert(cond, msg, builder) + + +def add(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): + raise TypeError("cannot add pointers together") + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr(): + other_handle = other.handle + if other.dtype.is_int_unsigned() and other.dtype.int_bitwidth < 64: + # addptr treats offset as signed. Zero-extend unsigned offsets to ensure they're positive + if other.type.is_block(): + i64_ty = tl.block_type(tl.int64, other.type.get_block_shapes()).to_ir(builder) + else: + i64_ty = tl.int64.to_ir(builder) + other_handle = builder.create_int_cast(other.handle, i64_ty, False) + return tl.tensor(builder.create_addptr(input.handle, other_handle), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, add) + return tl.tensor(builder.create_add(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def sub(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type) + # float - float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, sub) + return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def mul(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) + # int * int + elif scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, mul) + return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def truediv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and other_scalar_ty.is_int(): + other = cast(other, input_scalar_ty, builder) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = cast(input, other_scalar_ty, builder) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = cast(input, tl.float32, builder) + other = cast(other, tl.float32, builder) + # float / float (cast to the highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = cast(other, input_scalar_ty, builder) + else: + input = cast(input, other_scalar_ty, builder) + # unreachable + else: + raise TypeError(f"unexpected type {input_scalar_ty}") + return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) + + +def floordiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def fdiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, ieee_rounding: bool, + builder: ir.builder) -> tl.tensor: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise TypeError("both operands of fdiv must have floating scalar type") + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) + ret = builder.create_fdiv(input.handle, other.handle) + return tl.tensor(ret, input.type) + + +def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_frem(input.handle, other.handle), input.type) + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +############## +# other arithmetic ops +############## + + +def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_minnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_minsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_minui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_maxnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_maxsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_maxui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + min, max = binary_op_type_checking_impl(min, max, builder) + x, min = binary_op_type_checking_impl(x, min, builder) + x, max = binary_op_type_checking_impl(x, max, builder) + + dtype = x.dtype + if dtype.is_floating(): + return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported") + + +############## +# bitwise ops +############## + + +def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + input, other = binary_op_type_checking_impl(input, other, builder) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty) + ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = cast(input, ret_sca_ty, builder) + if ret_sca_ty != other_sca_ty: + other = cast(other, ret_sca_ty, builder) + return input, other + + +def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_and(input.handle, other.handle), input.type) + + +def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_or(input.handle, other.handle), input.type) + + +def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) + + +def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return and_(input, other, builder) + + +def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return or_(input, other, builder) + + +def not_(input: tl.tensor, builder: ir.builder): + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + return invert(input, builder) + + +def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) + + +def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type) + + +def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) + + +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// + + +def plus(input: tl.tensor) -> tl.tensor: + return input + + +def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return sub(_0, input, True, builder) + + +def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return xor_(input, _1, builder) + + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// +def _bool_like(v: tl.tensor) -> tl.block_type: + if not v.type.is_block(): + return tl.int1 + shape = v.type.shape + return tl.block_type(tl.int1, shape) + + +def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + range = end - start + if (range & (range - 1)) != 0: + raise ValueError("arange's range must be a power of 2") + shape = [range] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.create_make_range(start, end), ret_ty) + + +def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + if isinstance(value, tl.tensor): + assert value.numel.value == 1, "only accepts size-1 tensor" + value = cast(value, dtype, builder) + else: + # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") + if value == 0: + value = builder.get_null_value(dtype.to_ir(builder)) + else: + get_value_fn = getattr(builder, f"get_{dtype.name}") + value = get_value_fn(value) + value = tl.tensor(value, dtype) + + return splat(value, shape, builder) + + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + +def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) + + +def reshape(input: tl.tensor, dst_shape: List[int], can_reorder: bool, builder: ir.builder) -> tl.tensor: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("reshape() cannot change total number of elements in tensor") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty) + + +def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + dst_shape = [tl._constexpr_to_value(x) for x in input.shape] + dst_shape.insert(axis, 1) + + if not input.type.is_block(): + return splat(input, shape=dst_shape, builder=builder) + + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) + + +def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor: + assert can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type) + + +def join(a: tl.tensor, b: tl.tensor, builder: ir.builder) -> tl.tensor: + a, b = broadcast_impl_value(a, b, builder) + + # The IR can't handle joining two scalars, so upcast them to 1D tensors, + # then downcast the result. + was_rank_1 = a.shape == [] + if was_rank_1: + a = expand_dims(a, 0, builder) + b = expand_dims(b, 0, builder) + + if isinstance(a.shape[-1], tl.constexpr): + two = tl.constexpr(2) + else: + two = 2 + new_shape = a.shape + [two] + + ret_type = tl.block_type(a.type.scalar, new_shape) + ret = tl.tensor(builder.create_join(a.handle, b.handle), ret_type) + + if was_rank_1: + ret = reshape(ret, [2], can_reorder=False, builder=builder) + + return ret + + +def split(a: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + assert (len(a.shape) > 0) + assert (tl._constexpr_to_value(a.shape[-1]) == 2) + + new_shape = a.shape[:-1] + ret_type = tl.block_type(a.type.scalar, new_shape) + outLHS, outRHS = builder.create_split(a.handle) + return ( + tl.tensor(outLHS, ret_type), + tl.tensor(outRHS, ret_type), + ) + + +def permute(input: tl.tensor, dims: Tuple[int], builder: ir.builder) -> tl.tensor: + if len(input.shape) != len(dims): + raise ValueError("permute dims must have the same length as input shape") + if sorted(tl._constexpr_to_value(d) for d in dims) != list(range(len(dims))): + raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}") + + ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims]) + return tl.tensor(builder.create_trans(input.handle, dims), ret_type) + + +def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + if not input.type.is_block(): + ret_ty = tl.block_type(input.type, shape) + return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = tl.block_type(input.type.scalar, shape) + return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) + + +def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) + rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) + lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + + if len(lhs_shape) < len(rhs_shape): + # Add new axes to lhs + for _ in range(len(lhs_shape), len(rhs_shape)): + lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), + tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values)) + lhs_ty = lhs.type + lhs_shape = lhs_ty.get_block_shapes() + elif len(rhs_shape) < len(lhs_shape): + # Add new axes to rhs + for _ in range(len(rhs_shape), len(lhs_shape)): + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), + tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values)) + rhs_ty = rhs.type + rhs_shape = rhs_ty.get_block_shapes() + assert len(rhs_shape) == len(lhs_shape) + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif (right == 1) or (right == left): + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + + +####### +# cast +####### + + +def _str_to_rounding_mode(rounding_mode: Optional[str]): + if rounding_mode is None: + return None + if rounding_mode == 'rtne': + return ir.ROUNDING_MODE.RTNE + if rounding_mode == 'rtz': + return ir.ROUNDING_MODE.RTZ + raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.") + + +def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return cast(input, dst_ty, builder) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " + "data-type of size " + str(dst_bits)) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + +def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, + fp_downcast_rounding: Optional[str] = None) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." + return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + else: + return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + + +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// + + +def _str_to_load_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cv": + cache = ir.CACHE_MODIFIER.CV + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_store_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".wb": + cache = ir.CACHE_MODIFIER.WB + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cs": + cache = ir.CACHE_MODIFIER.CS + elif cache_modifier == ".wt": + cache = ir.CACHE_MODIFIER.WT + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_eviction_policy(eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction + + +def _str_to_padding_option(padding_option): + padding = None # default + if padding_option: + if padding_option == "zero": + padding = ir.PADDING_OPTION.PAD_ZERO + elif padding_option == "nan": + padding = ir.PADDING_OPTION.PAD_NAN + else: + raise ValueError(f"Padding option {padding_option} not supported") + return padding + + +def _str_to_sem(sem_option): + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + if sem_option: + if sem_option == "acquire": + sem = ir.MEM_SEMANTIC.ACQUIRE + elif sem_option == "release": + sem = ir.MEM_SEMANTIC.RELEASE + elif sem_option == "acq_rel": + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + elif sem_option == "relaxed": + sem = ir.MEM_SEMANTIC.RELAXED + else: + raise ValueError(f"Memory semantic {sem_option} not supported") + return sem + + +def _str_to_scope(scope_option): + scope = ir.MEM_SYNC_SCOPE.GPU + if scope_option: + if scope_option == "gpu": + scope = ir.MEM_SYNC_SCOPE.GPU + elif scope_option == "cta": + scope = ir.MEM_SYNC_SCOPE.CTA + elif scope_option == "sys": + scope = ir.MEM_SYNC_SCOPE.SYSTEM + else: + raise ValueError(f"Memory semantic {scope_option} not supported") + return scope + + +def _canonicalize_boundary_check(boundary_check, block_shape): + if boundary_check: + if not hasattr(boundary_check, "__iter__"): + boundary_check = [boundary_check] + boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check] + for dim in boundary_check: + assert isinstance(dim, int) and 0 <= dim < len(block_shape) + assert len(boundary_check) > 0 + assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`" + return sorted(boundary_check) + return () + + +def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a block pointer: `pointer_type>` + # Block pointer can not have `mask` and `other` arguments + if mask is not None or other is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`" + if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer block pointers") + + # `dst_ty` is de-referenced type of the pointer type + dst_ty = ptr.type.element_ty + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + + # Build IR + return tl.tensor( + builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) + + +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") + + # Check `mask`, `other`, `boundary_check`, and `padding` arguments + if mask is None and other is not None: + raise ValueError("`other` cannot be provided without `mask`") + if padding or boundary_check: + raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" + "pointers or loading a scalar. Because the compiler does not know the boundary; please " + "use block pointers (defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `mask` and `other` + if not ptr.type.is_block(): + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + if other and other.type.is_block(): + raise ValueError("Other argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `other` into the same shape as `ptr` + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if other is not None: + other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) + + # Get `pointer_type` and `elt_ty` + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + is_bool = elt_ty == tl.int1 + if is_bool: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast `other` into `elt_ty` type + if other is not None: + other = cast(other, elt_ty, builder) + + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + + # Build IR + if mask is None: + ret = tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + else: + ret = tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, + is_volatile), dst_ty) + if is_bool: + ret = cast(ret, tl.int1, builder) + return ret + + +def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple, + padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool, + builder: ir.builder) -> tl.tensor: + # Cache, eviction and padding options + cache = _str_to_load_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + padding = _str_to_padding_option(padding_option) + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Load by a block pointer: `pointer_type>` + return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + else: + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + + +def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type, builder: ir.builder): + handle = builder.create_reinterpret_tensor_descriptor(desc_ptr.handle, block_ty.to_ir(builder)) + return tl._experimental_tensor_descriptor_base(handle, block_ty) + + +def validate_descriptor_block(shape, dtype): + if len(shape) != 2: + return + # Due to limitations of the shared memory encoding, the TMA bounding box has + # to be at least as big as the swizzle tile. + assert shape[0] >= 8, f"tensor descriptor block shape must have at least 8 rows, but got {shape[0]}" + min_cols = 32 // dtype.primitive_bitwidth * 8 + assert shape[ + 1] >= min_cols, f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}" + + +def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache_modifier: str, eviction_policy: str, + builder: ir.builder) -> tl.tensor: + assert isinstance(desc, tl._experimental_tensor_descriptor_base) + validate_descriptor_block(desc.block_shape, desc.dtype) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier), + _str_to_eviction_policy(eviction_policy)) + return tl.tensor(x, desc.block_type) + + +def descriptor_store(desc: tl._experimental_tensor_descriptor_base, value: tl.tensor, offsets, + builder: ir.builder) -> tl.tensor: + assert isinstance(desc, tl._experimental_tensor_descriptor_base) + validate_descriptor_block(desc.block_shape, desc.dtype) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + assert value.shape == desc.block_shape + + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void) + + +def descriptor_gather(desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str, + builder: ir.builder) -> tl.tensor: + assert isinstance(desc, tl._experimental_tensor_descriptor_base) + assert cache_modifier == "", "cache modifier is not supported yet" + assert eviction_policy == "", "eviction policy is not supported yet" + + # Validate descriptor. + assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}" + assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}" + + # Validate offsets. + assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}" + + # Validate minimum block size. + assert x_offsets.shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {x_offsets.shape}" + dtype = desc.dtype + min_cols = 32 // dtype.primitive_bitwidth * 8 + assert desc.block_shape[ + 1] >= min_cols, f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}" + + type = tl.block_type(desc.dtype, [x_offsets.shape[0], desc.block_shape[1]]) + y_offset = _convert_to_ir_values(builder, (y_offset, ), require_i64=False)[0] + x = builder.create_descriptor_gather(desc.handle, x_offsets.handle, y_offset, type.to_ir(builder)) + return tl.tensor(x, type) + + +def descriptor_scatter(desc, value: tl.tensor, x_offsets, y_offset, builder: ir.builder) -> tl.tensor: + assert isinstance(desc, tl._experimental_tensor_descriptor_base) + + # Validate descriptor. + assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}" + assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}" + + # Validate offsets. + assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shapae}" + + # Validate minimum block size. + assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}" + dtype = desc.dtype + min_cols = 32 // dtype.primitive_bitwidth * 8 + assert desc.block_shape[ + 1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}" + + y_offset = _convert_to_ir_values(builder, (y_offset, ), require_i64=False)[0] + builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset) + return tl.tensor(None, tl.void) + + +def tensormap_create( + desc_ptr: tl.tensor, + global_address: tl.tensor, + box_dim: List[tl.tensor], + global_dim: List[tl.tensor], + global_stride: List[tl.tensor], + element_stride: List[tl.tensor], + elem_type: int, + interleave_layout: int, + swizzle_mode: int, + fill_mode: int, + builder: ir.builder, +) -> tl.tensor: + assert not global_stride or global_stride[0].dtype == tl.int64 + return tl.tensor( + builder.create_tensormap_create( + desc_ptr.handle, + global_address.handle, + [x.handle for x in box_dim], + [x.handle for x in global_dim], + [x.handle for x in global_stride], + [x.handle for x in element_stride], + elem_type, + interleave_layout, + swizzle_mode, + fill_mode, + ), + tl.void, + ) + + +def tensormap_fenceproxy_acquire(desc_ptr: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_tensormap_fenceproxy_acquire(desc_ptr.handle), tl.void) + + +def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a block pointer: `pointer_type>` + # Block pointers can not have the `mask` argument + if mask is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + # Check same shape and element type + block_shape = ptr.type.element_ty.get_block_shapes() + if not val.type.is_block(): + val = broadcast_impl_shape(val, block_shape, builder) + assert val.type.is_block(), "Value argument must be block type or a scalar" + assert block_shape == val.type.get_block_shapes( + ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`" + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, block_shape) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), + tl.void) + + +def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`") + + # Check `boundary_check` argument + if boundary_check: + raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a " + "scalar. Because the compiler does not know the boundary; please use block pointers " + "(defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `val` and `mask` + if not ptr.type.is_block(): + if val.type.is_block(): + raise ValueError("Value argument cannot be block type if pointer argument is not a block") + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `val` into the same shape as `ptr` + if ptr.type.is_block(): + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + if mask is None: + return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) + + +def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tl.tensor: + # Cache and eviction options + cache = _str_to_store_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + + if ptr.type.is_const() or ptr.type.scalar.is_const(): + raise ValueError("Cannot store to a constant pointer") + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Store by a block pointer: `pointer_type>` + return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder) + else: + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder) + + +######### +# atomic +######### + + +def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) + + +def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_const() or ptr.type.element_ty.is_const(): + raise ValueError("Cannot store to a constant pointer") + element_ty = ptr.type.scalar.element_ty + if element_ty is tl.float16 and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16") + if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]: + raise ValueError("atomic_" + op + " does not support " + str(element_ty)) + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val is not None: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if mask is None: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_max not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_val.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_min not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_ptr.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + +def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + + +def _str_to_dot_input_precision(input_precision, builder): + assert input_precision.lower() in builder.options.allowed_dot_input_precisions, \ + f"input_precision must be one of {builder.options.allowed_dot_input_precisions}. Got {input_precision}" + input_precision = input_precision.upper() + if input_precision == "TF32X3": + input_precision = "TF32x3" + return getattr(ir.INPUT_PRECISION, input_precision) + + +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, + out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + # All combinations of supported fp8 x fp8 are permitted + pass + else: + assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported lhs dtype {lhs.dtype}" + assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported rhs dtype {rhs.dtype}" + assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}" + + if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + # We upcast because there's no fp8e4b15 type in MLIR + lhs = cast(lhs, tl.float16, builder) + rhs = cast(rhs, tl.float16, builder) + + if input_precision is None: + input_precision = builder.options.default_dot_input_precision + + input_precision = _str_to_dot_input_precision(input_precision, builder) + + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert lhs.shape[-1].value == rhs.shape[ + -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" + assert builder.codegen_fns.get("min_dot_size") is not None, "target doesn't provide lower shape bounds for dot." + min_dot_size = builder.codegen_fns["min_dot_size"](lhs.type, rhs.type) + assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \ + and rhs.shape[-1].value >= min_dot_size[1], \ + f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}" + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError( + "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = builder.get_fp32(0) + ret_scalar_ty = tl.float32 + else: + _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[-2] + N = rhs.type.shape[-1] + K = lhs.type.shape[-1] + B = lhs.type.shape[0] if lhs_rank == 3 else None + ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 + if max_num_imprecise_acc is None: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default + else: + max_num_imprecise_acc = 0 + else: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K: + raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})") + + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), + ret_ty) + + +def _str_to_fp_type(float_format: str): + ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None) + if ty_enum is None: + raise ValueError(f"Invalid float format: {float_format}.") + return ty_enum + + +def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): + """ + If float_format is subbyte, make sure it's packed as uint8 and return it. + Otherwise, return a tensor (perhaps bitcasting) of the specified float format. + """ + triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format) + if triton_ty is None: + assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" + assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}" + return val + if val.dtype == triton_ty: + return val + else: + unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format] + assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}" + return bitcast(val, triton_ty, builder) + + +def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], + rhs_format: str, acc: tl.tensor | None, fast_math: bool, out_dtype: tl.dtype, + builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + #TODO: validate types. + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + lhs_format: str = lhs_format.value + rhs_format: str = rhs_format.value + lhs_format_enum = _str_to_fp_type(lhs_format) + rhs_format_enum = _str_to_fp_type(rhs_format) + allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"} + assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" + assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" + rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) + lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) + lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) + rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) + + M = lhs.type.shape[-2] + K, N = rhs.type.shape[-2:] + PACKED_A = 2 if lhs_format == "e2m1" else 1 + PACKED_B = 2 if rhs_format == "e2m1" else 1 + assert K * PACKED_B == PACKED_A * lhs.type.shape[ + -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + #assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}" + B = lhs.type.shape[0] if lhs_rank == 3 else None + + ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) + _0 = builder.get_fp32(0) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle + lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle + return tl.tensor( + builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle, + rhs_format_enum, fast_math, acc_handle), ret_ty) + + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + + +def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + if condition.dtype != tl.int1: + warnings.warn( + f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}" + ) + condition = cast(condition, tl.int1, builder) + x, y = binary_op_type_checking_impl(x, y, builder, True, True) + # x, y are broadcasted + if condition.type.is_block(): + condition, x = broadcast_impl_value(condition, x, builder) + x, y = broadcast_impl_value(x, y, builder) + else: + condition, _ = broadcast_impl_value(condition, x, builder) + ret_ty = x.type + return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + + +# ===----------------------------------------------------------------------===// +# Reduction +# ===----------------------------------------------------------------------=== + + +def wrap_tensor(x, scalar_ty, ret_shape): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return tl.tensor(x, res_ty) + + +def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]: + if axis is None: + inputs = tuple(reshape(t, [t.numel.value], can_reorder=True, builder=builder) for t in inputs) + axis = 0 + # get result shape + shape = inputs[0].type.shape + rank = len(shape) + assert axis < rank, f"reduction axis must be < inputs rank ({rank})" + ret_shape = [s for i, s in enumerate(shape) if i != axis] + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" + + reduce_op = builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + reduce_op.verify() + + return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Associative Scan +# ===----------------------------------------------------------------------=== + + +def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, reverse: bool, + builder: ir.builder) -> Tuple[tl.tensor, ...]: + shape = inputs[0].type.shape + rank = len(shape) + + assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})" + + if axis < 0: + axis += rank + + for t in inputs: + assert t.type.shape == shape, "all scan inputs must have the same shape" + + scan_op = builder.create_scan([t.handle for t in inputs], axis, reverse) + region_builder_fn(scan_op) + scan_op.verify() + + return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Gather +# ===----------------------------------------------------------------------=== + + +def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + assert index.dtype.is_int(), "index must be an integer tensor" + + rank = len(src.type.shape) + assert len(index.type.shape) == rank, "source and index tensors must have the same rank" + + assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})" + if axis < 0: + axis += rank + + for d in range(rank): + if d == axis: + continue + assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim" + + gather = builder.create_gather(src.handle, index.handle, axis) + return wrap_tensor(gather, src.type.scalar, index.type.shape) + + +# ===----------------------------------------------------------------------=== +# Histogram +# ===----------------------------------------------------------------------=== + + +def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor: + assert len(input.shape) == 1, "histogram only supports 1D input" + assert input.dtype.is_int(), "histogram only supports integer input" + return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, [num_bins])) + + +def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: + if max(1, len(x.shape)) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_constancy does not match the length of values") + x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context())) + return x + + +def debug_barrier(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_barrier(), tl.void) + + +def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". + if not prefix.endswith(" ") and args: + prefix += " " + if not prefix.endswith(": ") and args: + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + + new_args = [arg.handle for arg in args] + is_signed = [arg.dtype in (tl.int1, tl.int8, tl.int16, tl.int32, tl.int64) for arg in args] + return tl.tensor(builder.create_print(prefix, hex, new_args, is_signed), tl.void) + + +def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor: + if not builder.options.debug: + return + return tl.tensor(builder.create_assert(cond.handle, msg), tl.void) + + +def assume(cond, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_assume(cond.handle), tl.void) + + +def _convert_elem_to_ir_value(builder, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) + if isinstance(elem, tl.constexpr): + if require_i64: + assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int64(elem.value) + else: + assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int32(elem.value) + elif isinstance(elem, tl.tensor): + assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets" + assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets" + if elem.dtype != tl.int64 and require_i64: + return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed()) + elif elem.dtype != tl.int32 and not require_i64: + assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \ + "add a `.to(tl.int32)` or use regular indexing for 64 bit support" + return elem.handle + assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" + + +def _convert_to_ir_values(builder, list_like, require_i64=True): + if hasattr(list_like, "__iter__"): + return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like] + return [_convert_elem_to_ir_value(builder, list_like, require_i64)] + + +def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor: + # Convert dynamic arguments to IR values + # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t` + shape = _convert_to_ir_values(builder, shape) + strides = _convert_to_ir_values(builder, strides) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Check `base` type + if not base.type.is_ptr() or base.type.element_ty.is_block(): + raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)") + + # Treat `pointer_type` as `pointer_type` + if base.type.element_ty == tl.int1: + base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder) + + # Check whether `block_shape` is static + if not hasattr(block_shape, "__iter__"): + block_shape = [block_shape] + block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape] + assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \ + "Expected a list of constant integers (`int32_t` range) in `block_shape`" + + # Check `order` + if not hasattr(order, "__iter__"): + order = [order] + order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order] + assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order" + + # Must have same length + assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \ + "Expected shape/strides/offsets/block_shape to have the same length" + + # Build value, the type is: + # `pointer_type>` in Python + # `tt.ptr>` in MLIR + handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order) + return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))) + + +def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + # Convert dynamic offsets to IR values + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Advanced block pointer type is the same as before + return tl.tensor(builder.create_advance(base.handle, offsets), base.type) + + +def make_tensor_descriptor( + base: tl.tensor, + shape: List[tl.tensor], + strides: List[tl.tensor], + block_shape: List[tl.constexpr], + builder: ir.builder, +) -> tl._experimental_tensor_descriptor: + ndim = len(shape) + if not (2 <= ndim <= 5): + raise ValueError(f"Expected 2 <= ndim <= 5 but got {ndim} dimensions") + if len(strides) != ndim: + raise ValueError(f"Expected {ndim} strides but got {len(strides)}") + if len(block_shape) != ndim: + raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}") + + strides[-1] = tl._constexpr_to_value(strides[-1]) + if strides[-1] != 1: + raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}") + + shape = [to_tensor(x, builder) for x in shape] + strides = [to_tensor(x, builder).to(tl.int64, _builder=builder) for x in strides] + + # Check whether `block_shape` is static + block_shape = tl._unwrap_shape(block_shape) + + assert isinstance(base.type, tl.pointer_type) + type = tl.block_type(base.type.element_ty, block_shape) + handle = builder.create_make_tensor_descriptor(base.handle, [s.handle for s in shape], [s.handle for s in strides], + block_shape) + return tl._experimental_tensor_descriptor(handle, shape, strides, type) diff --git a/third_party/enflame/include/triton/python/triton/language/standard.py b/third_party/enflame/include/triton/python/triton/language/standard.py new file mode 100644 index 000000000..66d351689 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/language/standard.py @@ -0,0 +1,469 @@ +from __future__ import annotations + +from ..runtime.jit import jit +from . import core +from . import math + +# constexpr utilities + + +def _log2(i: core.constexpr): + log2 = 0 + n = i.value + while n > 1: + n >>= 1 + log2 += 1 + return core.constexpr(log2) + + +def _is_power_of_two(i: core.constexpr): + n = i.value + return core.constexpr((n & (n - 1)) == 0 and n != 0) + + +# ----------------------- +# Standard library +# ----------------------- + + +@core._tensor_member_fn +@jit +def cdiv(x, div): + """ + Computes the ceiling division of :code:`x` by :code:`div` + + :param x: the input number + :type x: Block + :param div: the divisor + :type div: Block + """ + return (x + div - 1) // div + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("sigmoid") +def sigmoid(x): + return 1 / (1 + math.exp(-x)) + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("softmax") +def softmax(x, ieee_rounding=False): + z = x - max(x, 0) + num = math.exp(z) + den = sum(num, 0) + return math.fdiv(num, den, ieee_rounding) + + +@core._tensor_member_fn +@jit +def ravel(x, can_reorder=False): + """ + Returns a contiguous flattened view of :code:`x`. + + :param x: the input tensor + :type x: Block + """ + return core.reshape(x, [x.numel], can_reorder=can_reorder) + + +@jit +def swizzle2d(i, j, size_i, size_j, size_g): + """ + Transforms the indices of a row-major `size_i * size_j` matrix into + the indices of a column-major matrix for each group of `size_g` rows. + + For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will + transform :: + + [[0 , 1 , 2 , 3 ], + [4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11], + [12, 13, 14, 15]] + + into :: + + [[0, 2, 4 , 6 ], + [1, 3, 5 , 7 ], + [8, 10, 12, 14], + [9, 11, 13, 15]] + """ + # "unrolled index in array" + ij = i * size_j + j + # number of elements in `size_g` groups + # of `size_j` columns + size_gj = size_g * size_j + # index of the group in which (i,j) is + group_id = ij // size_gj + # row-index of the first element of this group + off_i = group_id * size_g + # last group may have fewer rows + size_g = core.minimum(size_i - off_i, size_g) + # linear index with respect to the first element in this group + ij = ij % size_gj + # new row and column indices + new_i = off_i + ij % size_g + new_j = ij // size_g + return new_i, new_j + + +@jit +def zeros(shape, dtype): + """ + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + return core.full(shape, 0, dtype) + + +@jit +def zeros_like(input): + """ + Returns a tensor of zeros with the same shape and type as a given tensor. + + :param input: input tensor + :type input: Tensor + """ + return zeros(input.shape, input.dtype) + + +# max and argmax + + +@jit +def _argmax_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + gt = value1 > value2 or tie + v_ret = core.where(gt, value1, value2) + i_ret = core.where(gt, index1, index2) + return v_ret, i_ret + + +@jit +def _argmax_combine_tie_break_left(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, True) + + +@jit +def _argmax_combine_tie_break_fast(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_max(a, b): + return core.maximum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32): + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left") +def argmax(input, axis, tie_break_left=True, keep_dims=False): + (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +# min and argmin + + +@jit +def _argmin_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + lt = value1 < value2 or tie + value_ret = core.where(lt, value1, value2) + index_ret = core.where(lt, index1, index2) + return value_ret, index_ret + + +@jit +def _argmin_combine_tie_break_left(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, True) + + +@jit +def _argmin_combine_tie_break_fast(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_min(a, b): + return core.minimum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left") +def argmin(input, axis, tie_break_left=True, keep_dims=False): + _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +@jit +def _sum_combine(a, b): + return a + b + + +# sum + + +def _pick_sum_dtype(in_dtype: core.constexpr, dtype: core.constexpr): + dtype = core._unwrap_if_constexpr(dtype) + if dtype is not None: + return dtype + + # For integer bitwidths less than 32, pick int32 with the same sign to + # avoid overflow. + out_dtype = None + if in_dtype.is_int_signed(): + out_dtype = core.int32 if in_dtype.int_bitwidth < 32 else None + elif in_dtype.is_int_unsigned(): + out_dtype = core.uint32 if in_dtype.int_bitwidth < 32 else None + return out_dtype + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("sum", dtype_arg="dtype") +def sum(input, axis=None, keep_dims=False, dtype: core.constexpr = None): + # Pick a default dtype for the reduction if one was not specified. + out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype) + + if out_dtype is not None: + input = input.to(out_dtype) + return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims) + + +@jit +def _xor_combine(a, b): + return a ^ b + + +# xor sum + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("xor sum") +def xor_sum(input, axis=None, keep_dims=False): + core.static_assert(input.type.scalar.is_int(), "xor_sum only supported for integers") + return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims) + + +# cumsum + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumsum") +def cumsum(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _sum_combine, reverse) + + +# cumprod + + +@jit +def _prod_combine(a, b): + return a * b + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumprod") +def cumprod(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _prod_combine, reverse) + + +# sort + + +@jit +def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr): + n_outer: core.constexpr = x.numel >> n_dims + shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = core.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = core.arange(0, 2)[None, :, None] + left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype) + right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype) + left = core.reshape(left, x.shape) + right = core.reshape(right, x.shape) + # actual compare-and-swap + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + ret = ix ^ core.where((left > right) != flip, ileft ^ iright, zeros_like(ix)) + return ret.to(x.dtype, bitcast=True) + + +@jit +def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr): + ''' + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + ''' + n_outer: core.constexpr = x.numel >> n_dims + core.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in core.static_range(stage): + x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims) + return x + + +@core._tensor_member_fn +@jit +def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + """ + Sorts a tensor along a specified dimension. + + :param x: The input tensor to be sorted. + :type x: Tensor + :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported. + :type dim: int, optional + :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order. + :type descending: bool, optional + """ + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + # iteratively run bitonic merge-sort steps + n_dims: core.constexpr = _log2(x.shape[_dim]) + for i in core.static_range(1, n_dims + 1): + x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims) + return x + + +# flip + + +def _get_flip_dim(dim, shape): + dim = core._unwrap_if_constexpr(dim) + shape = core._unwrap_if_constexpr(shape) + if dim is None: + dim = len(shape) - 1 + assert dim == len(shape) - 1, "Currently only support flipping the last dimension" + return core.constexpr(dim) + + +@core._tensor_member_fn +@jit +def flip(x, dim=None): + """ + Flips a tensor `x` along the dimension `dim`. + + :param x: the first input tensor + :type x: Block + :param dim: the dimension to flip along (currently only final dimension supported) + :type dim: int + """ + core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)])) + core.static_assert(_is_power_of_two(x.numel)) + # reshape the tensor to have all dimensions be 2. + # TODO: We shouldn't have to change the dimensions not sorted. + steps: core.constexpr = _log2(x.numel) + start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)]) + + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + y = core.reshape(x.to(idtype, bitcast=True), [2] * steps) + y = core.expand_dims(y, start) + flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2)) + for i in core.static_range(start, steps): + flip2 = flip + for j in core.static_range(0, steps + 1): + if j != i and j != i + 1: + flip2 = core.expand_dims(flip2, j) + y = sum(y * flip2, i + 1, keep_dims=True, dtype=y.dtype) + x = core.reshape(y, x.shape).to(x.dtype, bitcast=True) + return x + + +@jit +def interleave(a, b): + """ + Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape. + Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])` + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + c = core.join(a, b) + + if len(c.shape) == 1: + # We must have interleaved two scalars. + return c + else: + # This `else` is necessary because Triton's AST parser doesn't + # understand that if we take the `if` above we definitely don't run this + # `else`. + return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]]) diff --git a/third_party/enflame/include/triton/python/triton/runtime/__init__.py b/third_party/enflame/include/triton/python/triton/runtime/__init__.py new file mode 100644 index 000000000..0b3979d28 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/runtime/__init__.py @@ -0,0 +1,23 @@ +from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics) +from .cache import RedisRemoteCacheBackend, RemoteCacheBackend +from .driver import driver +from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret +from .errors import OutOfResources, InterpreterError + +__all__ = [ + "autotune", + "Autotuner", + "Config", + "driver", + "Heuristics", + "heuristics", + "InterpreterError", + "JITFunction", + "KernelInterface", + "MockTensor", + "OutOfResources", + "RedisRemoteCacheBackend", + "reinterpret", + "RemoteCacheBackend", + "TensorWrapper", +] diff --git a/third_party/enflame/include/triton/python/triton/runtime/_allocation.py b/third_party/enflame/include/triton/python/triton/runtime/_allocation.py new file mode 100644 index 000000000..aa8a45488 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/runtime/_allocation.py @@ -0,0 +1,32 @@ +from typing import Optional, Protocol + + +class Buffer(Protocol): + + def data_ptr(self) -> int: + ... + + +class Allocator(Protocol): + + def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer: + ... + + +class NullAllocator: + + def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer: + raise RuntimeError("Kernel requires a runtime memory allocation, but no allocator was set. " + + "Use triton.set_allocator to specify an allocator.") + + +_allocator: Allocator = NullAllocator() + + +def set_allocator(allocator: Allocator): + """ + The allocator function is called during kernel launch for kernels that + require additional global memory workspace. + """ + global _allocator + _allocator = allocator diff --git a/third_party/enflame/include/triton/python/triton/runtime/autotuner.py b/third_party/enflame/include/triton/python/triton/runtime/autotuner.py new file mode 100644 index 000000000..f5fb73c4f --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/runtime/autotuner.py @@ -0,0 +1,419 @@ +from __future__ import annotations + +import builtins +import os +import time +import inspect +from typing import Dict, Tuple, List, Optional + +from .jit import KernelInterface +from .errors import OutOfResources, PTXASError +from .driver import driver + + +class Autotuner(KernelInterface): + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Optional[Dict] = None, + warmup=None, + rep=None, + use_cuda_graph=False, + do_bench=None, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. + """ + if not configs: + self.configs = [ + Config({}, num_warps=4, num_stages=3, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0) + ] + else: + self.configs = configs + self.keys = key + self.cache: Dict[Tuple, Config] = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_to_zero = [] + if reset_to_zero is not None: + self.reset_to_zero = list(reset_to_zero) + self.restore_value = [] + if restore_value is not None: + self.restore_value = list(restore_value) + + # Hook to reset or restore for required tensors + self.pre_hook = lambda kwargs, reset_only=False: 0 + self.post_hook = lambda kwargs, exception: 0 + self.user_defined_pre_hook = False + self.user_defined_post_hook = False + if pre_hook: + self.pre_hook = pre_hook + self.user_defined_pre_hook = True + elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0): + + def _pre_hook(kwargs, reset_only=False): + for name in self.reset_to_zero: + kwargs[name].zero_() + if not reset_only: + self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value} + + self.pre_hook = _pre_hook + + if post_hook: + self.post_hook = post_hook + self.user_defined_post_hook = True + elif len(self.restore_value) > 0: + + def _post_hook(kwargs, exception): + for name in self.restore_value: + kwargs[name].copy_(self.restore_copies[name]) + self.restore_copies = {} + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) + + self.fn = fn + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + + self.num_warmups = warmup + self.num_reps = rep + self.use_cuda_graph = use_cuda_graph + + # If we got explicitly called via the old interface, raise a warning + # and proceed with the old behavior. + if warmup is not None or rep is not None or use_cuda_graph: + import warnings + warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See " + "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning, + stacklevel=1) + if use_cuda_graph: + from ..testing import do_bench_cudagraph + self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( + kernel_call, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + import triton.testing + self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( + kernel_call, + warmup=warmup if warmup is not None else 25, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + if do_bench is None: + self.do_bench = driver.active.get_benchmarker() + else: + self.do_bench = do_bench + + def _bench(self, *args, config, **meta): + from ..compiler.errors import CompileTimeAssertionFailure + + verbose = os.environ.get("TRITON_PRINT_AUTOTUNING", None) == "1" + if verbose: + print(f"Autotuning kernel {self.base_fn.__name__} with config {config}") + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(full_nargs) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(full_nargs, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(full_nargs, exception=None) + + try: + return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e: + if verbose: + print(f"Autotuning failed with {e}") + return [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + key = [_args[key] for key in self.keys if key in _args] + for _, arg in _args.items(): + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + # prune configs + used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} + self.pre_hook(full_nargs, reset_only=True) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};") + if config.pre_hook is not None: + full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} + config.pre_hook(full_nargs) + ret = self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs: Dict) -> List[Config]: + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + elif not isinstance(top_k, int): + # Slice index must be an integer + raise TypeError("Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int") + + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs(), + )) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.num_buffers_warp_spec = num_buffers_warp_spec + self.num_consumer_groups = num_consumer_groups + self.reg_dec_producer = reg_dec_producer + self.reg_inc_consumer = reg_inc_consumer + self.maxnreg = maxnreg + self.pre_hook = pre_hook + + def all_kwargs(self): + return { + **self.kwargs, **{ + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("num_buffers_warp_spec", self.num_buffers_warp_spec), + ("num_consumer_groups", self.num_consumer_groups), + ("reg_dec_producer", self.reg_dec_producer), + ("reg_inc_consumer", self.reg_inc_consumer), + ("maxnreg", self.maxnreg), + ) if v is not None + } + } + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}") + res.append(f"num_consumer_groups: {self.num_consumer_groups}") + res.append(f"reg_dec_producer: {self.reg_dec_producer}") + res.append(f"reg_inc_consumer: {self.reg_inc_consumer}") + res.append(f"maxnreg: {self.maxnreg}") + return ", ".join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=None, rep=None, use_cuda_graph=False, do_bench=None): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr): + ... + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: warmup time (in ms) to pass to benchmarking (deprecated). + :type warmup: int + :param rep: repetition time (in ms) to pass to benchmarking (deprecated). + :type rep: int + :param do_bench: a benchmark function to measure the time of each run. + :type do_bench: lambda fn, quantiles + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph, do_bench=do_bench) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + # smallest power-of-two >= x_size + @triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])}) + @triton.jit + def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr): + ... + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[dict[str, Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/third_party/enflame/include/triton/python/triton/runtime/build.py b/third_party/enflame/include/triton/python/triton/runtime/build.py new file mode 100644 index 000000000..1b76548d4 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/runtime/build.py @@ -0,0 +1,37 @@ +import sysconfig +import os +import shutil +import subprocess + + +def _build(name, src, srcdir, library_dirs, include_dirs, libraries): + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + gcc = shutil.which("gcc") + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH')) + include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] + # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] + subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL) + return so diff --git a/third_party/enflame/include/triton/python/triton/runtime/cache.py b/third_party/enflame/include/triton/python/triton/runtime/cache.py new file mode 100644 index 000000000..62895508b --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/runtime/cache.py @@ -0,0 +1,295 @@ +import importlib +import json +import os +import uuid +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, List, Optional +import base64 +import hashlib + + +def get_home_dir(): + return os.getenv("TRITON_HOME", Path.home()) + + +def default_cache_dir(): + return os.path.join(get_home_dir(), ".triton", "cache") + + +def default_override_dir(): + return os.path.join(get_home_dir(), ".triton", "override") + + +def default_dump_dir(): + return os.path.join(get_home_dir(), ".triton", "dump") + + +class CacheManager(ABC): + + def __init__(self, key): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + + @abstractmethod + def put_group(self, filename: str, group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = os.getenv("TRITON_DUMP_DIR", "").strip() or default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = os.getenv("TRITON_OVERRIDE_DIR", "").strip() or default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + def _make_path(self, filename) -> str: + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename) -> bool: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + return os.path.exists(self._make_path(filename)) + + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c, p in child_paths.items(): + if os.path.exists(p): + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + grp_contents = json.dumps({"child_paths": group}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + # we use the PID in case a bunch of these around so we can see what PID made it + pid = os.getpid() + # use temp dir to be robust against program interruptions + temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, filename) + + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + os.removedirs(temp_dir) + return filepath + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. + """ + + def __init__(self, key: str): + pass + + @abstractmethod + def get(self, filenames: List[str]) -> Dict[str, bytes]: + pass + + @abstractmethod + def put(self, filename: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + + def __init__(self, key): + import redis + self._key = key + self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}") + self._redis = redis.Redis( + host=os.environ.get("TRITON_REDIS_HOST", "localhost"), + port=int(os.environ.get("TRITON_REDIS_PORT", 6379)), + ) + + def _get_key(self, filename: str) -> str: + return self._key_fmt.format(key=self._key, filename=filename) + + def get(self, filenames: List[str]) -> Dict[str, str]: + results = self._redis.mget([self._get_key(f) for f in filenames]) + return {filename: result for filename, result in zip(filenames, results) if result is not None} + + def put(self, filename: str, data: bytes) -> Dict[str, bytes]: + self._redis.set(self._get_key(filename), data) + + +class RemoteCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`. + remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"] + module_path, clz_nme = remote_cache_manager.split(":") + module = importlib.import_module(module_path) + remote_cache_cls = getattr(module, clz_nme) + self._backend = remote_cache_cls(key) + + self._override = override + self._dump = dump + + # Use a `FileCacheManager` to materialize remote cache paths locally. + self._file_cache_manager = FileCacheManager(key, override=override, dump=dump) + + def _materialize(self, filename: str, data: bytes): + # We use a backing `FileCacheManager` to provide the materialized data. + return self._file_cache_manager.put(data, filename, binary=True) + + def get_file(self, filename: str) -> Optional[str]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_file(filename) + + # We always check the remote cache backend -- even if our internal file- + # based cache has the item -- to make sure LRU accounting works as + # expected. + results = self._backend.get([filename]) + if len(results) == 0: + return None + (_, data), = results.items() + return self._materialize(filename, data) + + def put(self, data, filename: str, binary=True) -> str: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put(data, filename, binary=binary) + + if not isinstance(data, bytes): + data = str(data).encode("utf-8") + self._backend.put(filename, data) + return self._materialize(filename, data) + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_group(filename) + + grp_filename = f"__grp__{filename}" + grp_filepath = self.get_file(grp_filename) + if grp_filepath is None: + return None + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + + result = None + + # Found group data. + if child_paths is not None: + result = {} + for child_path, data in self._backend.get(child_paths).items(): + result[child_path] = self._materialize(child_path, data) + + return result + + def put_group(self, filename: str, group: Dict[str, str]): + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put_group(filename, group) + + grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename) + + +__cache_cls = FileCacheManager +__cache_cls_nme = "DEFAULT" + + +def _base32(key): + # Assume key is a hex string. + return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") + + +def get_cache_manager(key) -> CacheManager: + import os + + user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None) + global __cache_cls + global __cache_cls_nme + + if user_cache_manager is not None and user_cache_manager != __cache_cls_nme: + module_path, clz_nme = user_cache_manager.split(":") + module = importlib.import_module(module_path) + __cache_cls = getattr(module, clz_nme) + __cache_cls_nme = user_cache_manager + + return __cache_cls(_base32(key)) + + +def get_override_manager(key) -> CacheManager: + return __cache_cls(_base32(key), override=True) + + +def get_dump_manager(key) -> CacheManager: + return __cache_cls(_base32(key), dump=True) + + +def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + return _base32(key) diff --git a/third_party/enflame/include/triton/python/triton/runtime/driver.py b/third_party/enflame/include/triton/python/triton/runtime/driver.py new file mode 100644 index 000000000..c3b97a764 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/runtime/driver.py @@ -0,0 +1,60 @@ +from ..backends import backends +from ..backends import DriverBase + + +def _create_driver(): + actives = [x.driver for x in backends.values() if x.driver.is_active()] + if len(actives) != 1: + raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.") + return actives[0]() + + +class LazyProxy: + + def __init__(self, init_fn): + self._init_fn = init_fn + self._obj = None + + def _initialize_obj(self): + if self._obj is None: + self._obj = self._init_fn() + + def __getattr__(self, name): + self._initialize_obj() + return getattr(self._obj, name) + + def __setattr__(self, name, value): + if name in ["_init_fn", "_obj"]: + super().__setattr__(name, value) + else: + self._initialize_obj() + setattr(self._obj, name, value) + + def __delattr__(self, name): + self._initialize_obj() + delattr(self._obj, name) + + def __repr__(self): + if self._obj is None: + return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>" + return repr(self._obj) + + def __str__(self): + self._initialize_obj() + return str(self._obj) + + +class DriverConfig: + + def __init__(self): + self.default = LazyProxy(_create_driver) + self.active = self.default + + def set_active(self, driver: DriverBase): + self.active = driver + + def reset_active(self): + self.active = self.default + + +driver = DriverConfig() diff --git a/third_party/enflame/include/triton/python/triton/runtime/errors.py b/third_party/enflame/include/triton/python/triton/runtime/errors.py new file mode 100644 index 000000000..1a8046430 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/runtime/errors.py @@ -0,0 +1,36 @@ +from ..errors import TritonError +from typing import Optional + + +class InterpreterError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + return self.error_message or "" + + +class OutOfResources(TritonError): + + def __init__(self, required, limit, name): + self.required = required + self.limit = limit + self.name = name + + def __str__(self) -> str: + return f"out of resource: {self.name}, Required: {self.required}, Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help." + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) + + +class PTXASError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + error_message = self.error_message or "" + return f"PTXAS error: {error_message}" diff --git a/third_party/enflame/include/triton/python/triton/runtime/interpreter.py b/third_party/enflame/include/triton/python/triton/runtime/interpreter.py new file mode 100644 index 000000000..8ce79df04 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/runtime/interpreter.py @@ -0,0 +1,1369 @@ +import ast +import textwrap +import inspect +from typing import Tuple, List + +import math +import numpy as np + +import triton +import triton.language as tl +from dataclasses import dataclass +from .errors import InterpreterError +from functools import partial +from .._C.libtriton import interpreter as _interpreter +from .._C.libtriton import ir as _ir + + +class TensorHandle: + + def __init__(self, data, dtype): + ''' + data: numpy array + dtype: triton type, either pointer_type or scalar_type. + we don't store block_type here because the shape information is already available in the data field + attr: a dictionary of attributes + ''' + self.data = data + self.dtype = dtype + self.attr = {} + + def __bool__(self): + return bool(self.data.all()) + + def get_element_ty(self): + dtype = self.dtype + while hasattr(dtype, "element_ty"): + dtype = dtype.element_ty + return dtype + + def clone(self): + return TensorHandle(self.data.copy(), self.dtype) + + def set_attr(self, key, value): + self.attr[key] = value + + +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, block_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.block_shape = block_shape + self.order = order + + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.get_element_ty() + n_bytes = dtype_tt.primitive_bitwidth // 8 + ptrs = np.broadcast_to(self.base.data, self.block_shape) + masks = np.ones(self.block_shape, dtype=bool) + for dim in range(len(self.block_shape)): + bcast_dims = [1] * len(self.block_shape) + bcast_dims[dim] = self.block_shape[dim] + off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = masks & (off < self.shape[dim].data) & (off >= 0) + ptrs = TensorHandle(ptrs, self.base.dtype.scalar) + return ptrs, masks + + +class TensorDescHandle: + + def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle], + block_shape: List[int]): + self.base = base + self.ndim = len(shape) + self.shape = shape + self.strides = strides + self.block_shape = block_shape + + def validate(self): + assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned" + assert len(self.strides) == self.ndim + assert len(self.block_shape) == self.ndim + + for stride in self.strides[:-1]: + assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned" + assert self.strides[-1].data.item() == 1, "last dim must be contiguous" + + def materialize_pointers(self, offsets: List[TensorHandle]): + assert len(offsets) == self.ndim + scalar_ty = self.base.dtype.element_ty + itemsize = scalar_ty.primitive_bitwidth // 8 + assert (offsets[-1].data * itemsize) % 16 == 0, "block offset start must be 16-byte aligned" + + ptrs = np.broadcast_to(self.base.data, self.block_shape) + masks = np.ones(self.block_shape, dtype=bool) + for dim in range(len(self.block_shape)): + bcast_dims = [1] * len(self.block_shape) + bcast_dims[dim] = self.block_shape[dim] + off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64) + masks = masks & (0 <= off) & (off < self.shape[dim].data) + ptrs = TensorHandle(ptrs, self.base.dtype.scalar) + return ptrs, masks + + +@dataclass(frozen=True) +class InterpreterOptions: + extern_libs: dict = None + debug: bool = False + sanitize_overflow: bool = True + arch: str = None + supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15") + deprecated_fp8_dtypes: Tuple[str] = () + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: int = 0 + backend_name: str = "interpreter" + + +def _get_signed_np_dtype(dtype): + if dtype == np.uint8: + return np.int8 + if dtype == np.uint16: + return np.int16 + if dtype == np.uint32: + return np.int32 + if dtype == np.uint64: + return np.int64 + return dtype + + +def _get_np_dtype(tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) + np_types = { + tl.int1: np.dtype(bool), + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + # bfloat16 types are stored as uint16 + tl.bfloat16: np.dtype(np.uint16), + # float8 types are stored as uint8 + tl.float8e5: np.dtype(np.uint8), + tl.float8e5b16: np.dtype(np.uint8), + tl.float8e4nv: np.dtype(np.uint8), + tl.float8e4b8: np.dtype(np.uint8), + tl.float8e4b15: np.dtype(np.uint8), + } + if isinstance(tt_dtype, tl.block_type): + if isinstance(tt_dtype.element_ty, tl.pointer_type): + return np.dtype(np.uint64) + return np_types[tt_dtype.element_ty] + return np_types[tt_dtype] + + +def _convert_float(input, input_dtype, output_dtype, rounding_mode): + input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}") + output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}") + input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype) + sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01 + input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1 + output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1 + significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1) + bias_input = input_dtype.exponent_bias + bias_output = output_dtype.exponent_bias + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + subnormal_index = exponent == 0 + if np.any(subnormal_index): + # Credit to Phil: phil@openai.com + # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0)) + bit_pos = np.zeros_like(input_bin, dtype=np.int32) + # Find the most significant bit of the mantissa in the significand + for i in range(input_dtype.fp_mantissa_width): + bit_index = ((significand >> i) & 0x01) + # pos should be >= 1 + bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i + zero_significand_index = significand == 0 + exponent[subnormal_index] = 1 - bit_pos[subnormal_index] + # 0 significand and subnormal should be treated as 0 + exponent[zero_significand_index & subnormal_index] = bias_input - bias_output + significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & ( + (1 << input_dtype.fp_mantissa_width) - 1) + # Prevent overflow and underflow + exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1)) + exponent_output = exponent_output.astype(output_unint_dtype) + sign_output = sign.astype(output_unint_dtype) + if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast + significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearst even + # find the cut-off bit + cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1)) + significand_output = significand_output + (cut_off > 0) + significand_output = significand_output.astype(output_unint_dtype) + else: # Upcast + significand_output = (significand.astype(output_unint_dtype) << + (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + subnormal_index = exponent_output == 0 + if np.any(subnormal_index): # underflow + # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # shift = (1 - exp_bias_output) - (exp - exp_bias_input) + # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift)) + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + non_zero_exponent_index = exponent != 0 + # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa + subnormal_index = subnormal_index & non_zero_exponent_index + shift = np.zeros_like(input_bin, dtype=np.int32) + shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input) + significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | ( + 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index])) + output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | ( + exponent_output << output_dtype.fp_mantissa_width) | significand_output + return output.reshape(input.shape) + + +def _erf(x): + # Numpy does not support erf + return math.erf(x) + + +def _umulhi_64(a, b): + # Numpy does not support 128-bit multiplication + # So we have to implement it manually + return (int(a) * int(b)) >> 64 + + +np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32]) +np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64]) +np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64]) + + +class ExtraFunctions: + + @staticmethod + def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _builder): + return tl.tensor(_builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty) + + +class InterpreterBuilder: + ir_sem_to_interpreter_sem = { + _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE, + _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE, + _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED, + _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE, + } + + ir_rmw_op_to_interpreter_rmw_op = { + _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD, + _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD, + _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN, + _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN, + _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX, + _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX, + _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND, + _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR, + _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR, + _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG, + } + + def __init__(self) -> None: + self.arch = None + self.options = InterpreterOptions() + self.codegen_fns = {} + self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types + self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (1, 1, 1) + + def set_grid_idx(self, x, y, z): + if not x < self.grid_dim[0]: + raise ValueError("x >= grid_dim[0]") + if not y < self.grid_dim[1]: + raise ValueError("y >= grid_dim[1]") + if not z < self.grid_dim[2]: + raise ValueError("z >= grid_dim[2]") + self.grid_idx = (x, y, z) + + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + + # constants + + def get_half_ty(self): + return tl.float16 + + def get_bf16_ty(self): + return tl.bfloat16 + + def get_float_ty(self): + return tl.float32 + + def get_double_ty(self): + return tl.float64 + + def get_int8_ty(self): + return tl.int8 + + def get_uint8_ty(self): + return tl.uint8 + + def get_int16_ty(self): + return tl.int16 + + def get_uint16_ty(self): + return tl.uint16 + + def get_int32_ty(self): + return tl.int32 + + def get_uint32_ty(self): + return tl.uint32 + + def get_int64_ty(self): + return tl.int64 + + def get_uint64_ty(self): + return tl.uint64 + + def get_fp8e4nv_ty(self): + return tl.float8e4nv + + def get_fp8e4b15_ty(self): + return tl.float8e4b15 + + def get_fp8e4b8_ty(self): + return tl.float8e4b8 + + def get_fp8e5_ty(self): + return tl.float8e5 + + def get_fp8e5b16_ty(self): + return tl.float8e5b16 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + + def get_block_ty(self, dtype, shape): + return tl.block_type(dtype, shape) + + def get_int1(self, value): + return TensorHandle(np.array([value], dtype=np.bool_), tl.int1) + + def get_uint8(self, value): + return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8) + + def get_int8(self, value): + return TensorHandle(np.array([value], dtype=np.int8), tl.int8) + + def get_uint16(self, value): + return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16) + + def get_int16(self, value): + return TensorHandle(np.array([value], dtype=np.int16), tl.int16) + + def get_uint32(self, value): + return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32) + + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + def get_uint64(self, value): + return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64) + + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_fp64(self, value): + return TensorHandle(np.array([value], dtype=np.float64), tl.float64) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type) + + # programming model + def create_get_program_id(self, axis): + if self.grid_idx is None: + raise ValueError("grid_idx is None") + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) + + def create_get_num_programs(self, axis): + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) + + # memory ops + def create_load(self, ptr, _0, _1, is_volatile): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) + + def create_store(self, ptr, val, _0, _1): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + return self.create_masked_store(ptr, val, mask, None, None) + + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if other is None: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) + return TensorHandle(ret, dtype_tt) + + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) + + # casting ops + def cast_impl(self, src, dst_type): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \ + (src_element_type == tl.float32 and dst_element_type == tl.bfloat16): + data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + else: + return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar) + + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) + + def create_fp_to_fp(self, src, dst_type, rounding_mode): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + + def create_bitcast(self, src, dst_type): + return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar) + + # binary operators + def binary_op(self, lhs, rhs, op): + return TensorHandle(op(lhs.data, rhs.data), lhs.dtype.scalar) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + create_int_to_ptr = create_bitcast + create_ptr_to_int = create_bitcast + + def create_idiv(self, lhs, rhs): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar) + + def create_ashr(self, lhs, rhs): + # Triton's rshift operator depends on the signedness of the left operand + lhs_dtype = _get_signed_np_dtype(lhs.data.dtype) + rhs_dtype = _get_signed_np_dtype(rhs.data.dtype) + lhs.data = lhs.data.astype(lhs_dtype) + rhs.data = rhs.data.astype(rhs_dtype) + return self.binary_op(lhs, rhs, np.right_shift) + + def create_umulhi(self, lhs, rhs): + dtype = lhs.data.dtype + if dtype == np.int64 or dtype == np.uint64: + return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar) + else: + compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}") + lhs_data = lhs.data.astype(compute_dtype) + rhs_data = rhs.data.astype(compute_dtype) + ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8) + return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar) + + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype.scalar) + + create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + def create_fma(self, x, y, z): + return TensorHandle(x.data * y.data + z.data, z.dtype.scalar) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype.scalar) + + def create_fabs(self, arg): + # Mask out the sign bit based on the primitive length + dtype_tt = arg.dtype + mask_bitwidth = dtype_tt.primitive_bitwidth - 1 + np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}") + data = arg.data.view(np_uint_dtype) + mask = (1 << mask_bitwidth) - 1 + ret = (data & mask).view(_get_np_dtype(dtype_tt)) + return TensorHandle(ret, arg.dtype.scalar) + + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + create_floor = lambda self, arg: self.unary_op(arg, np.floor) + create_ceil = lambda self, arg: self.unary_op(arg, np.ceil) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_log2 = lambda self, arg: self.unary_op(arg, np.log2) + create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + + def create_erf(self, arg): + ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data) + return TensorHandle(ret, arg.dtype.scalar) + + def create_rsqrt(self, arg): + return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar) + + # tensor operators + create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar) + + def create_trans(self, arg, perm): + return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar) + + def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc): + a_data = a.data + b_data = b.data + if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \ + (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()): + a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16) + b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16) + return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar) + + def create_make_range(self, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + + def create_histogram(self, data, bins): + return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32) + + def create_gather(self, src, indices, axis): + return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar) + + # pointer arithmetic + + def create_addptr(self, ptr, offset): + dtype_tt = ptr.get_element_ty() + element_bitwidth = dtype_tt.primitive_bitwidth + # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic + element_bytewidth = max(1, element_bitwidth // 8) + return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype) + + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, + is_volatile): + ptrs, masks = ptr.materialize_pointers(boundary_check) + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if padding_option is None: + other = None + elif padding_option == _ir.PADDING_OPTION.PAD_ZERO: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + elif padding_option == _ir.PADDING_OPTION.PAD_NAN: + other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt) + else: + raise ValueError(f"unsupported padding option {padding_option}") + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs, masks = ptr.materialize_pointers(boundary_check) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) + + def create_expand_dims(self, arg, axis): + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar) + + def create_broadcast(self, arg, shape): + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar) + + def create_cat(self, lhs, rhs): + return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar) + + def create_join(self, lhs, rhs): + # Triton only supports joining two original tensors into a new one along the last axis + return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar) + + def create_split(self, val): + # Triton only supports splitting the original tensor into two along the last axis + return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar)) + + def create_splat(self, arg, shape): + if isinstance(arg.dtype, tl.block_type): + return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + else: # scalar + return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + + def create_atomic_cas(self, ptr, cmp, val, sem, scope): + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar) + + def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope): + if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op: + raise ValueError(f"unsupported rmwOp {rmwOp}") + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp] + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar) + + def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + raise NotImplementedError("extern_elementwise not supported in interpreter mode") + + def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + raise NotImplementedError("inline_asm not supported in interpreter mode") + + def create_print(self, prefix, hex, values, isSigned): + # NOTE: the `isSigned` variable is not really used here; because Signness is already known + # by `values` themselves in python interpreter, thus not really needed here; + # it is only used for triton PrintOpToLLVM to correctly construct the format specifier. + # Interpreter's device_print function has a different format than Triton's device_print + msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})" + if prefix: + msg += f" {prefix}" + if hex: + np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"}) + for value in values: + print(msg + f" {value.data}") + if hex: + np.set_printoptions(formatter=None) + + def create_assert(self, condition, message): + # Interpreter's device_assert function has a different format than Triton's device_assert + assert condition, f"{message}" + + def create_assume(self, condition): + assert condition, "Assume failed" + + def create_barrier(self): + # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter + pass + + def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order): + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in offsets] + return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order) + + def create_advance(self, ptr, offsets): + if len(ptr.offsets) != len(offsets): + raise ValueError("len(ptr.offsets) != len(offsets)") + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in ptr.offsets] + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret + + def create_make_tensor_descriptor( + self, + base: TensorHandle, + shape: List[TensorHandle], + strides: List[TensorHandle], + tensor_shape: List[int], + ): + desc = TensorDescHandle(base, shape, strides, tensor_shape) + desc.validate() + return desc + + def create_descriptor_load(self, desc: TensorDescHandle, indices: List[TensorHandle], cache_modifier, + eviction_policy): + assert isinstance(desc, TensorDescHandle) + ptrs, mask = desc.materialize_pointers(indices) + return self.create_masked_load(ptrs, mask, other=None, cache_modifier=cache_modifier, + eviction_policy=eviction_policy, is_volatile=False) + + def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]): + ptrs, mask = desc.materialize_pointers(indices) + return self.create_masked_store(ptrs, value, mask, None, None) + + def create_descriptor_gather(self, desc: TensorDescHandle, x_offsets: TensorHandle, y_offset: TensorHandle, type): + dtype = desc.base.dtype.element_ty + np_dtype = _get_np_dtype(dtype) + result = np.zeros([x_offsets.data.shape[0], desc.block_shape[-1]], dtype=np_dtype) + cache_modifier = None + eviction_policy = None + for i, x_offset in enumerate(x_offsets.data): + indices = [TensorHandle(x_offset, tl.int32), y_offset] + result[i, :] = self.create_descriptor_load(desc, indices, cache_modifier, eviction_policy).data + return TensorHandle(result, dtype) + + def create_descriptor_scatter(self, desc: TensorDescHandle, value: TensorHandle, x_offsets: TensorHandle, + y_offset: TensorHandle): + for i, x_offset in enumerate(x_offsets.data): + slice = TensorHandle(value.data[i], value.dtype) + indices = [TensorHandle(x_offset, tl.int32), y_offset] + self.create_descriptor_store(desc, slice, indices) + + def get_all_ones_value(self, type): + np_type = _get_np_dtype(type) + if "int" in np_type.name: + return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar) + else: + raise TypeError(f"unsupported type {type}") + + +def _patch_attr(obj, name, member, builder): + new_member = lambda *args, member=member, **kwargs: (member(*args, ** + {k: v + for k, v in kwargs.items() + if k != "_builder"}, _builder=builder)) + setattr(obj, name, new_member) + + +def _patch_builtin(pkg, builder): + for name, member in inspect.getmembers(pkg): + if tl.core.is_builtin(member): + _patch_attr(pkg, name, member, builder) + + +def _patch_lang_tensor(tensor): + + def _get_bool(self): + data = self.handle.data + # in triton, only scalars can be converted to booleans + # here we need this hack because all scalars are tensors + return bool(data) if data.size == 1 else True + + def _get_transpose(self): + handle = TensorHandle(np.transpose(self.handle.data), self.handle.dtype) + assert self.type.is_block() + block_shape = list(self.type.shape) + block_shape[-1], block_shape[-2] = block_shape[-2], block_shape[-1] + res_ty = tl.core.block_type(self.dtype, block_shape) + return tl.core.tensor(handle, res_ty) + + tensor.__index__ = lambda self: int(self.handle.data) + tensor.__bool__ = lambda self: _get_bool(self) + tensor.__repr__ = lambda self: repr(self.handle.data) + tensor.__str__ = lambda self: str(self.handle.data) + tensor.T = property(_get_transpose) + + +class ReduceScanOpInterface: + + def __init__(self, axis, combine_fn): + self.axis = axis + self.combine_fn = combine_fn + + def check_axis(self, shape, axis): + if axis is not None and axis >= len(shape): + raise ValueError(f"axis {axis} out of bounds for shape {shape}") + + def check_tensor(self, input): + for arg in input: + if not isinstance(arg, tl.core.tensor): + raise ValueError(f"input must be a tensor, got {type(arg)}") + self.check_axis(arg.shape, self.axis) + + def to_tensor(self, ret, dtype): + np_dtype = _get_np_dtype(dtype) + if hasattr(ret, "shape") and ret.shape: + ret = ret.astype(np_dtype) + ret_type = tl.block_type(dtype, list(ret.shape)) + else: + ret = np.array([ret], dtype=np_dtype) + ret_type = dtype + return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type) + + def apply(self, input): + if not isinstance(input, tuple): + input = (input, ) + self.check_tensor(input) + return self.apply_impl(input) + + def apply_impl(self, input): + raise NotImplementedError("apply_impl not implemented") + + +class ReduceOps(ReduceScanOpInterface): + + def __init__(self, axis, combine_fn, keep_dims): + super().__init__(axis, combine_fn) + self.keep_dims = keep_dims + + def unravel(self, input, axis): + ret = [] + for data in input: + if axis is not None: + ret.append(data) + else: + axis = 0 + ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype)) + return tuple(ret), axis + + def generic_reduce(self, input): + original_axis = self.axis + input, axis = self.unravel(input, self.axis) + input_data = [] + output_data = [] + input_shape = input[0].handle.data.shape + output_shape = input_shape[0:axis] + input_shape[axis + 1:] + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype)) + # Reduce on axis + for i in range(input_data[0].size): + # Recover input_index from i using input_shape + input_index = np.unravel_index(i, input_shape) + output_index = input_index[0:axis] + input_index[axis + 1:] + input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data)) + if input_index[axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][output_index] = input_tuple[j].handle.data.item() + else: + acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + if self.keep_dims: + if original_axis is not None: + data = np.expand_dims(data, axis) + else: + for _ in range(len(input_shape)): + data = np.expand_dims(data, 0) + + elif original_axis is None: + # Take a scalar + data = data.item() + ret.append(self.to_tensor(data, input[i].dtype)) + return ret[0] if len(ret) == 1 else tuple(ret) + + def min_max(self, input, val_reduce_op, idx_reduce_op=None): + # If input is a tuple, it must be (val, index), and we only take val + input = input[0] if isinstance(input, tuple) else input + val = None + idx = None + if val_reduce_op: + val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + if idx_reduce_op: + idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32) + if val is not None and idx is not None: + return val, idx + elif val is not None: + return val + elif idx is not None: + return idx + else: + raise ValueError("val_reduce_op and idx_reduce_op are both None") + + def sum(self, input): + return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + + def apply_impl(self, input): + if self.combine_fn == tl.standard._argmin_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin) + elif self.combine_fn == tl.standard._argmax_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax) + elif self.combine_fn == tl.standard._elementwise_max: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None) + elif self.combine_fn == tl.standard._elementwise_min: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None) + elif self.combine_fn == tl.standard._sum_combine: + return self.sum(input[0]) + else: + # Fall back to the slow mode + return self.generic_reduce(input) + + +class ScanOps(ReduceScanOpInterface): + + def __init__(self, axis, combine_fn, reverse): + super().__init__(axis, combine_fn) + self.reverse = reverse + + def cumsum(self, input): + return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def cumprod(self, input): + return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def generic_scan(self, input): + input_data = [] + output_data = [] + shape = input[0].handle.data.shape + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype)) + # Scan on axis + for i in range(input_data[0].size): + # Recover index from i using shape + index = np.unravel_index(i, shape) + data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data)) + if index[self.axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][index] = data[j].handle.data.item() + else: + prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index))) + acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + ret.append(self.to_tensor(data, input[i].dtype)) + return ret + + def apply_impl(self, input): + new_input = [] + if self.reverse: + for arg in input: + new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype)) + else: + new_input = input + if self.combine_fn == tl.standard._sum_combine: + ret = self.cumsum(new_input[0]) + elif self.combine_fn == tl.standard._prod_combine: + ret = self.cumprod(new_input[0]) + else: + # Fall back to the slow mode + ret = self.generic_scan(new_input) + if self.reverse: + for arg in ret: + arg.handle.data = np.flip(arg.handle.data, axis=self.axis) + return len(ret) == 1 and ret[0] or tuple(ret) + + +def _patch_reduce_scan(): + # Because interpreter doesn't support region_builder_fn, we cannot patch the builder + # to use the new reduce and scan functions. + # Instead, we need to patch reduce and reduce functions in tl and tl.core + def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs): + return ReduceOps(axis, combine_fn, keep_dims).apply(input) + + def _new_scan(input, axis, combine_fn, reverse=False, **kwargs): + return ScanOps(axis, combine_fn, reverse).apply(input) + + tl.reduce = _new_reduce + tl.associative_scan = _new_scan + tl.core.reduce = _new_reduce + tl.core.associative_scan = _new_scan + + +def _patch_lang_core(lang): + + def _new_to_ir(self, builder): + # We need to specify signedness for integer types in the numpy mode + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name == 'int8': + return builder.get_int8_ty() + elif self.name == 'uint8': + return builder.get_uint8_ty() + elif self.name == 'int16': + return builder.get_int16_ty() + elif self.name == 'uint16': + return builder.get_uint16_ty() + elif self.name == 'int32': + return builder.get_int32_ty() + elif self.name == 'uint32': + return builder.get_uint32_ty() + elif self.name == 'int64': + return builder.get_int64_ty() + elif self.name == 'uint64': + return builder.get_uint64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + # can't just map lang.static_range to `range`, because `tl.static_range` + # can get `step` passed by keyword + def _new_range(arg1, arg2=None, step=None, **kwargs): + if step is None: + step = 1 + if arg2 is None: + start, end = 0, arg1 + else: + start, end = arg1, arg2 + return range(start, end, step) + + def _new_static_assert(cond, msg=""): + assert cond, msg + + def _set_attr(input, values, name): + # skip non tensor types. This may happen for induction variables. + if not isinstance(input, tl.tensor): + return input + # Unwrap constexpr + values = [values] if not isinstance(values, (list, tuple)) else values + values = [v.value if isinstance(v, tl.constexpr) else v for v in values] + if len(values) != max(1, len(input.shape)): + raise ValueError(f"len(values) != len(input.shape) for {name}") + input.handle.set_attr(name, values) + return input + + lang.range = _new_range + lang.static_range = _new_range + lang.static_assert = _new_static_assert + lang.static_print = print + lang.dtype.to_ir = _new_to_ir + lang.multiple_of = partial(_set_attr, name="tt.divisibility") + lang.max_contiguous = partial(_set_attr, name="tt.contiguity") + lang.max_constancy = partial(_set_attr, name="tt.constancy") + + _patch_reduce_scan() + + +def _patch_lang(fn): + langs = [value for _, value in fn.__globals__.items() if inspect.ismodule(value) and value in [tl, tl.core]] + assert len(langs) >= 1, "triton.language must be visible from within jit'd function" + for lang in langs: + _patch_builtin(lang, interpreter_builder) + _patch_builtin(lang.tensor, interpreter_builder) + if lang == tl: + _patch_builtin(lang.math, interpreter_builder) + _patch_lang_tensor(lang.tensor) + _patch_lang_core(lang) + _patch_builtin(tl.core._experimental_tensor_descriptor_base, interpreter_builder) + + +def _tuple_create(arg, contents): + # NamedTuples and tuples have different construction semantics. NamedTuple + # has a constructor that takes individual arguments, while tuple takes an + # iterable. Both have type "tuple" making it difficult to distinguish + # between them, but only NamedTuple has "_fields" and apparently this is how + # everyone does the check. + return type(arg)(*contents) if hasattr(arg, "_fields") else type(arg)(contents) + + +# TODO: wrap everything in triton tensors +def _implicit_cvt(arg): + if isinstance(arg, int): + ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg)) + dtype = np.int32 + if -2**31 <= arg < 2**31: + dtype = np.int32 + elif 2**31 <= arg < 2**32: + dtype = np.uint32 + elif -2**63 <= arg < 2**63: + dtype = np.int64 + elif 2**63 <= arg < 2**64: + dtype = np.uint64 + else: + raise ValueError(f"Unsupported integer value {arg}") + handle = TensorHandle(np.array([arg], dtype=dtype), ty) + return tl.tensor(handle, ty) + if hasattr(arg, "data_ptr"): + ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg)) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + elif isinstance(arg, tuple): + return _tuple_create(arg, map(_implicit_cvt, arg)) + return arg + + +interpreter_builder = InterpreterBuilder() + + +def _unwrap_tensor(t): + if isinstance(t, triton.runtime.jit.TensorWrapper): + return t.base + return t + + +def _rewrap_tensor(t, original_tensor): + if isinstance(original_tensor, triton.runtime.jit.TensorWrapper): + return triton.runtime.jit.TensorWrapper(t, original_tensor.dtype) + return t + + +class GridExecutor: + + def __init__(self, fn, arg_names, grid): + from .jit import _normalize_ty # TODO: modularize + + self.fn = fn + self.arg_names = arg_names + self.grid = grid + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"] + + def _init_args_hst(self, args_dev, kwargs): + storages = {} + + def _to_cpu(arg): + if isinstance(arg, tuple): + return _tuple_create(arg, map(_to_cpu, arg)) + elif not hasattr(arg, "data_ptr"): + return arg + + unwrapped_arg = _unwrap_tensor(arg) + if unwrapped_arg.untyped_storage().data_ptr() not in storages: + storage = unwrapped_arg.untyped_storage() + storages[storage.data_ptr()] = storage.cpu() + + storage = storages[unwrapped_arg.untyped_storage().data_ptr()] + cpu_arg = unwrapped_arg.new_empty(0, device='cpu') + cpu_arg.set_(storage, unwrapped_arg.storage_offset(), unwrapped_arg.size(), unwrapped_arg.stride()) + cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg) + return cpu_arg + + args_hst = [_to_cpu(arg) for arg in args_dev] + + # Process keyword arguments + kwargs_hst = {} + for key, value in kwargs.items(): + kwargs_hst[key] = _to_cpu(value) + return args_hst, kwargs_hst + + def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst): + storages = {} + + def _from_cpu(arg_dev, arg_hst): + if hasattr(arg_dev, "data_ptr"): + # No need to rewrap because this just modifies internal + arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst) + storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage()) + elif isinstance(arg_dev, tuple): + for (arg_dev, arg_hst) in zip(arg_dev, arg_hst): + _from_cpu(arg_dev, arg_hst) + + for arg_dev, arg_hst in zip(args_dev, args_hst): + _from_cpu(arg_dev, arg_hst) + + # Restore keyword arguments + for key, kwarg_dev in kwargs.items(): + kwarg_hst = kwargs_hst[key] + _from_cpu(kwarg_dev, kwarg_hst) + + for (arg_dev, arg_hst) in storages.values(): + arg_dev.copy_(arg_hst) + + def __call__(self, *args_dev, **kwargs): + if kwargs.pop("warmup", False): + return + # Removes not used reserved keywords from kwargs + # Triton doesn't support keyword-only, variable positional or variable keyword arguments + # It's safe to inspect only positional or keyword arguments (i.e., argspec.args) + argspec = inspect.getfullargspec(self.fn) + kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} + # copy arguments to the host + args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) + # remaps core language functions to interpreted ones + _patch_lang(self.fn) + # we need to copy arguments to the host for the interpreter + # implicitly convert tensor arguments to their base pointers + args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} + # iterate through grid + grid = self.grid(args) if callable(self.grid) else self.grid + assert len(grid) <= 3, "grid must have at most 3 dimensions" + grid = grid + (1, ) * (3 - len(grid)) + interpreter_builder.set_grid_dim(*grid) + try: + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + interpreter_builder.set_grid_idx(x, y, z) + self.fn(**args) + except Exception as e: + raise InterpreterError(repr(e)) from e + # copy arguments back to propagate side-effects + self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) + + +class ASTTransformer(ast.NodeTransformer): + + def visit_Assign(self, node): + names = [] + for target in node.targets: + names += [self.visit(target)] + if len(names) > 1: + raise ValueError("Multiple assignments are not supported") + # Modify the assignment x = value to + # triton.language.semantic.to_tensor(value, interpreter_builder, False) + node.value = ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()), + attr='semantic', ctx=ast.Load()), attr='to_tensor', ctx=ast.Load()), + args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()), + ast.Constant(value=False)], keywords=[]) + return node + + +class FunctionRewriter: + ast_transformer = ASTTransformer() + + def __init__(self, fn, **kwargs): + self.fn = fn + self.kwargs = kwargs + self.filename: str = "" + # Absolute line number in the file + self.def_file_lineno: int = 0 + + def rewrite_ast(self): + # If exception is raise, it means the function does not have source code available, + # e.g., dynamically generated functions, we cannot rewrite it so just return the original function + try: + lines, _ = inspect.getsourcelines(self.fn) + except Exception: + return self.fn + + # truncate lines before def + # @triton.autotune(...) + # ... + # @triton.jit + # ... + # def foo(...): <- this line is the function definition + self.filename, self.def_file_lineno = self._get_jit_fn_file_line() + self.def_lineno = self._find_def(lines) + src = self._prepare_source(lines) + transformed_ast = self._transform_ast(src) + return self._compile_and_exec(transformed_ast) + + def _get_jit_fn_file_line(self): + from .jit import get_jit_fn_file_line, JITFunction + return get_jit_fn_file_line(JITFunction(self.fn)) + + def _find_def(self, lines): + def_lineno = 0 + # Line numbers start from 1 + for i, line in enumerate(lines): + if line.strip().startswith("def "): + def_lineno = i + 1 + return def_lineno + + def _prepare_source(self, lines): + lines = lines[self.def_lineno - 1:] + src = ''.join(lines) + return textwrap.dedent(src) + + def _transform_ast(self, src): + # src is like: + # 1: def foo(...): + # 2: ... + parsed_ast = ast.parse(src) + transformed_ast = self.ast_transformer.visit(parsed_ast) + ast.fix_missing_locations(transformed_ast) + inc_lineno = self.def_file_lineno - 1 + ast.increment_lineno(transformed_ast, inc_lineno) + return transformed_ast + + def _compile_and_exec(self, transformed_ast): + compiled_code = compile(transformed_ast, filename=self.filename, mode='exec') + local_namespace = {**self.kwargs} + fn_globals = self.fn.__globals__ + for key, value in globals().items(): + if key not in fn_globals: + fn_globals[key] = value + exec(compiled_code, fn_globals, local_namespace) + return local_namespace[self.fn.__name__] + + +class InterpretedFunction: + # Cache all rewritten functions + rewritten_fn = {} + + def __init__(self, fn, **kwargs) -> None: + self.fn = fn + self.rewriter = FunctionRewriter(fn, **kwargs) + + def run(*args, **kwargs): + grid = kwargs["grid"] + fn = self.rewrite() + return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs) + + self.run = run + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + def rewrite(self): + if self.fn not in self.rewritten_fn: + self.rewritten_fn[self.fn] = self.rewriter.rewrite_ast() + return self.rewritten_fn[self.fn] + + @property + def __name__(self): + return self.fn.__name__ + + def __getitem__(self, grid): + fn = self.rewrite() + return GridExecutor(fn, self.arg_names, grid) + + def __call__(self, *args, **kwargs): + # This is a device function call + _patch_lang(self.fn) + fn = self.rewrite() + try: + return fn(*args, **kwargs) + except Exception as e: + raise InterpreterError(repr(e)) from e diff --git a/third_party/enflame/include/triton/python/triton/runtime/jit.py b/third_party/enflame/include/triton/python/triton/runtime/jit.py new file mode 100644 index 000000000..e7567de42 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/runtime/jit.py @@ -0,0 +1,910 @@ +from __future__ import annotations, division +import ast +import hashlib +import inspect +import itertools +import os +import re +import textwrap +from collections import defaultdict +from functools import cached_property +from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple +from ..runtime.driver import driver +from types import ModuleType +from .._utils import find_paths_if, get_iterable_path + +TRITON_MODULE = __name__[:-len(".runtime.jit")] + +T = TypeVar("T") + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. + + This visitor also keeps track of the global variables touched by the + JITFunction. When we launch the kernel, we check that these have the same + values as they did when we ran this visitor. If not, we raise an error (or + otherwise we could recompile). + """ + + def __init__(self, name, globals, src) -> None: + super().__init__() + self.name = name + self.hasher = hashlib.sha256(src.encode("utf-8")) + + # This function's __globals__ dict. + self.globals = globals + + # Python builtins that can be accessed from Triton kernels. + self.supported_python_builtins = { + 'float', + 'getattr', + 'int', + 'isinstance', + 'len', + 'list', + 'max', + 'min', + 'print', + 'range', + } + + # used_global_vals tells us which global variables are used by this + # function and all those it transitively calls, plus the values of those + # variables when each function was initially run. (That is, if A calls + # C, and B calls C, then the values for C in used_global_vals will be + # from the first time C was run, either by A or B.) + # + # Each function may have a different __globals__ dict, so the global + # variable `foo` may actually have a different value in the different + # functions. Thus this map is actually + # (var_name, id(__globals__)) -> (var_value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + self.visiting_arg_default_value = False + + @property + def ret(self): + return self.hasher.hexdigest() + + def _is_triton_builtin(self, node, func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + def _update_hash(self, func): + if isinstance(func, JITFunction): + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. + for k in self.used_global_vals.keys() & func.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = func.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." + ) + self.used_global_vals.update(func.used_global_vals) + # update hash + func_key = func.cache_key + func_key += str(getattr(func, "noinline", False)) + self.hasher.update(func_key.encode("utf-8")) + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + + if node.id in self.local_names: + # The global name is hidden by the local name. + return None + + val = self.globals.get(node.id, None) + + # Only keep track of "interesting" global variables, that non-evil users + # might change. Don't consider functions, modules, builtins, etc. This + # helps keep the list of vars we have to check small. + if (val is not None # + # Python default arguments are resolved only once, when the + # function is defined. So if you do `foo(a=A)` and the value of + # A changes, foo will still use the old value of A. + and not self.visiting_arg_default_value + # It would be pretty evil if someone did `import x` and then + # `x = blah`. + and type(val) is not ModuleType + # It would be pretty evil if we used function `foo` inside of + # `bar` and then someone did `foo = baz`. + and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # + and node.id not in self.supported_python_builtins): + self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) + + self._update_hash(val) + return val + + def visit_Tuple(self, node): + # We need to explicitly return the tuple values so that visit_Assign can + # access them in the case of `a, b = ...`. + return [self.visit(elt) for elt in node.elts] + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): + return None + ret = getattr(lhs, node.attr) + self._update_hash(ret) + return ret + + def visit_FunctionDef(self, node): + # Save the local name, which may hide the global name. + self.local_names = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_arguments(self, node): + # The purpose of this function is to visit everything in `arguments` + # just like `generic_visit`, except when we're visiting default values + # (i.e. the `foo` part of `def fn(x = foo)`), we set + # self.visiting_arg_default_value = True. This allows visit_Name to be + # aware that we're inside function default values, which have special + # semantics. + + # According to the AST docs, the arguments node has the following structure. + # + # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + # expr* kw_defaults, arg? kwarg, expr* defaults) + def visit_defaults(defaults): + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + for expr in defaults: + if expr is not None: + self.visit(expr) + finally: + self.visiting_arg_default_value = False + + for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): + self.visit(arg) + + visit_defaults(node.kw_defaults) + + if node.kwarg is not None: + self.visit(node.kwarg) + + visit_defaults(node.defaults) + + def visitAssnTarget(self, node): + # Target is either a single string, or a list of strings (if the assn + # target is a tuple). + target = self.visit(node) + if isinstance(target, list): + self.local_names |= set(target) + else: + self.local_names.add(target) + + def visit_Assign(self, node): + if len(node.targets) != 1: + # TODO(jlebar): I don't actually know how to hit this. You don't + # get it from `a, b = ...` -- in that case, node.targets is a single + # Tuple, and in fact we *do* need to handle that case if we want + # existing code to work. + raise TypeError("Simultaneous multiple assignment is not supported.") + + self.visitAssnTarget(node.targets[0]) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_AnnAssign(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_For(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's fine. + self.generic_visit(node) + + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +def _normalize_ty(ty) -> str: + if isinstance(ty, type): + return ty.__name__ + elif isinstance(ty, str): + return ty + return repr(ty) + + +class KernelParam: + """Represents a parameter (name plus metadata) to a @jit'ed function.""" + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool, + do_not_specialize_on_alignment: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self): + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def annotation_type(self): + annotation = self.annotation + for ty1, ty2 in [("uint", 'u'), ("int", 'i')]: + width = annotation[annotation.find(ty1) + len(ty1):] + if width and ty1 in annotation: + return f"{ty2}{width}" + if annotation == "bool": + return "u1" + return "" + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @cached_property + def is_const(self): + return "const" in self.annotation and not self.is_constexpr + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +dtype2str = {} +specialize_impl_cache = [] + + +def create_specialize_impl(): + if specialize_impl_cache: + return specialize_impl_cache[-1] + + from ..language import constexpr + + def specialize_impl(arg, specialize_extra, is_const=False, specialize_value=True, align=True): + + if arg is None: + return ("constexpr", None) + elif isinstance(arg, JITFunction): + return ("constexpr", arg.cache_key) + elif isinstance(arg, constexpr): + return ("constexpr", arg) + elif isinstance(arg, bool): + return ("i1", None) + elif isinstance(arg, int): + key = specialize_extra(arg, "int", align=align) if specialize_value else None + if arg == 1 and specialize_value: + return ("constexpr", 1) + elif -(2**31) <= arg and arg <= 2**31 - 1: + return ("i32", key) + elif 2**63 <= arg and arg <= 2**64 - 1: + return ("u64", key) + else: + return ("i64", key) + elif isinstance(arg, float): + return ("fp32", None) + elif hasattr(arg, "tma_desc_cpu_ptr"): + return ("nvTmaDesc", None) + elif isinstance(arg, tuple): + spec = [specialize_impl(x, specialize_extra) for x in arg] + make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals) + tys = make_tuple([x[0] for x in spec]) + keys = make_tuple([x[1] for x in spec]) + return (tys, keys) + else: + # dtypes are hashable so we can memoize this mapping: + dsk = (arg.dtype, is_const) + res = dtype2str.get(dsk, None) + if res is None: + res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]] + dtype2str[dsk] = res + key = specialize_extra(arg, "tensor", align=align) if specialize_value else None + return (res, key) + + specialize_impl_cache.append(specialize_impl) + return specialize_impl + + +def mangle_type(arg, specialize=False): + specialize_impl = create_specialize_impl() + return specialize_impl(arg, lambda _, **kwargs: None, specialize_value=specialize)[0] + + +class KernelInterface(Generic[T]): + run: T + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +def serialize_specialization_data(name, signature, constants, attrs, options, key): + constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} + import json + obj = { + 'name': name, 'signature': signature, 'constant_keys': [list(x) for x in constants.keys()], 'constant_vals': + list(constants.values()), 'attrs_keys': [list(x) for x in attrs.keys()], 'attrs_vals': list(attrs.values()), + 'options': options.__dict__, 'key': key + } + serialized_obj = json.dumps(obj) + return serialized_obj + + +def create_function_from_signature(sig, kparams, backend): + """ + Equivalent to sig.bind followed by apply_defaults. This generates a + native Python function (using exec) which can be memoized on a per-kernel + basis to avoid having to run these expensive functions -- which constitute + much of the kernel launch overhead -- every time we run the kernel. + """ + assert len(sig.parameters) == len(kparams) + # Create the function argument list and the dict entries for the return statement + specialization = [] + # signature + for name, kp in zip(sig.parameters.keys(), kparams): + if kp.is_constexpr: + specialization.append(f'("constexpr", {name})') + else: + is_const = 'True' if kp.is_const else 'False' + specialize = 'False' if kp.do_not_specialize else 'True' + align = 'False' if kp.do_not_specialize_on_alignment else 'True' + ret = f"specialize_impl({name}, specialize_extra, {is_const}, {specialize}, {align})" + if kp.annotation_type: + specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]') + else: + specialization.append(f"{ret}") + + # compute argument string for a given parameter + arg = lambda x: x[0] if x[1].default is inspect.Parameter.empty else f"{x[0]}=default_{x[0]}" + # Join all arguments into a function definition string + func_body = f""" +def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options"])}): + params = {{{', '.join([f"'{name}': {name}" for name in sig.parameters.keys()])}}} + specialization = [{','.join(specialization)}] + return params, specialization, options +""" + # Prepare defaults to be inserted into function namespace + func_namespace = { + f"default_{name}": param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + func_namespace["JITFunction"] = JITFunction + func_namespace["specialize_impl"] = create_specialize_impl() + func_namespace["specialize_extra"] = backend.get_arg_specialization + + # Execute the function string in func_namespace to create the function + exec(func_body, func_namespace) + + # Extract the newly created function from the namespace + return func_namespace['dynamic_func'] + + +type_canonicalisation_dict = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +class JITFunction(KernelInterface[T]): + # Hook for inspecting compiled functions and modules + cache_hook = None + # Hook to signal that a kernel is done compiling and inspect compiled function. + # cache_hook will always be called before compilation and compiled_hook after. + compiled_hook = None + + def _call_hook( + self, + key, + signature, + device, + constants, + options, + configs, + is_warmup, + before, + ): + hook = JITFunction.cache_hook if before else JITFunction.compiled_hook + if hook is None: + return False + + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})" + + class JitFunctionInfo: + + def __init__(self, module, name, jit_function): + self.module = module + self.name = name + self.jit_function = jit_function + pass + + specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key) + + kwargs = { + 'signature': signature, + 'device': device, + 'constants': constants, + 'num_warps': options.num_warps, + 'num_ctas': options.num_ctas, + 'num_stages': options.num_stages, + 'enable_fp_fusion': options.enable_fp_fusion, + 'launch_cooperative_grid': options.launch_cooperative_grid, + 'extern_libs': options.extern_libs, + 'configs': configs, + 'specialization_data': specialization_data, + 'is_warmup': is_warmup, + } + + return hook( + key=key, + repr=repr, + fn=JitFunctionInfo(module, name, self), + compile={"key": key, **kwargs}, + is_manual_warmup=is_warmup, + already_compiled=False, + ) + + def add_pre_run_hook(self, hook): + ''' + Add a hook that will be executed prior to the execution of run + function with args and kwargs passed into the kernel + ''' + assert callable(hook) + self.pre_run_hooks.append(hook) + + def create_binder(self): + """ + Precompute as much as possible. + """ + from ..compiler import CompiledKernel, compile, ASTSource, make_backend + target = driver.active.get_current_target() + backend = make_backend(target) + self.CompiledKernel = CompiledKernel + self.compile = compile + self.ASTSource = ASTSource + binder = create_function_from_signature(self.signature, self.params, backend) + return {}, target, backend, binder + + def run(self, *args, grid, warmup, **kwargs): + kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1" + + # parse options + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + + # Execute pre run hooks with args and kwargs + for hook in self.pre_run_hooks: + hook(*args, **kwargs) + + kernel_cache, target, backend, binder = self.device_caches[device] + bound_args, specialization, options = binder(*args, **kwargs) + + # compute cache key + key = str(specialization) + str(options) + kernel = kernel_cache.get(key, None) + + # Kernel is not cached; we have to compile. + if kernel is None: + # options + options = backend.parse_options(kwargs) + # signature + sigkeys = [x.name for x in self.params] + sigvals = [x[0] for x in specialization] + signature = {k: v for (k, v) in zip(sigkeys, sigvals)} + # check arguments + assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" + assert "device" not in kwargs, "device option is deprecated; current device will be used" + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in kwargs: + if k not in options.__dict__ and k not in sigkeys: + raise KeyError("Keyword argument %s was specified but unrecognised" % k) + # constexprs + constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr") + constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs} + # attributes + attrvals = [x[1] for x in specialization] + attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str)) + attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs} + if self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=True): + return None + # compile the kernel + src = self.ASTSource(self, signature, constexprs, attrs) + kernel = self.compile(src, target=target, options=options.__dict__) + kernel_cache[key] = kernel + self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False) + + # Check that used global values have not changed. + not_present = object() + for (name, _), (val, globals_dict) in self.used_global_vals.items(): + if (newVal := globals_dict.get(name, not_present)) != val: + raise RuntimeError( + f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") + + if not warmup: + # canonicalize grid + assert grid is not None + if callable(grid): + grid = grid(bound_args) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + # launch kernel + launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values()) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, + launch_metadata, self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, + *bound_args.values()) + return kernel + + def repr(self, _): + return self._fn_name if self._repr is None else self._repr(_) + + def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None, + noinline=None, repr=None, launch_metadata=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else [] + + self.fn = fn + self.module = fn.__module__ + self.version = version + self.signature = inspect.signature(fn) + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + self.starting_line_number = inspect.getsourcelines(fn)[1] + self._repr = repr + self._fn_name = fn.__name__ + self.launch_metadata = launch_metadata + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = i in do_not_specialize or param.name in do_not_specialize + dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment + self.params.append(KernelParam(i, param, dns, dns_oa)) + + # function source code (without decorators) + src = textwrap.dedent(inspect.getsource(fn)) + src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():] + self._unsafe_update_src(src) + # cache of just-in-time compiled kernels + self.device_caches = defaultdict(self.create_binder) + self.hash = None + + # Map of global variables used by the function and any functions it + # transitively calls, plus their values. The values are collected when + # the function is first compiled. Then every time we run the function, + # we check that the values of the globals match what's expected, + # otherwise we raise an error. + # + # Different functions can have different __globals__ maps, so the map + # key is actually (var name, id(__globals__)), and the map value is + # (value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel = None + self.debug = debug + self.noinline = noinline + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. + self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + + # Hooks that will be called prior to executing "run" + self.pre_run_hooks = [] + + # reuse docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + str(self.starting_line_number) + self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) + return self.hash + + def warmup(self, *args, grid, **kwargs): + return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) + + def preload(self, specialization_data): + from ..compiler import compile, ASTSource + import json + import triton.language as tl + device = driver.active.get_current_device() + deserialized_obj = json.loads(specialization_data) + if deserialized_obj['name'] != self.fn.__name__: + raise RuntimeError( + f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") + constant_keys = map(tuple, deserialized_obj['constant_keys']) + constant_vals = deserialized_obj['constant_vals'] + constants = { + key: tl.dtype(value) if tl.dtype.is_dtype(value) else value + for key, value in zip(constant_keys, constant_vals) + } + attrs_keys = map(tuple, deserialized_obj['attrs_keys']) + attrs_vals = deserialized_obj['attrs_vals'] + attrs = dict(zip(attrs_keys, attrs_vals)) + signature = dict(deserialized_obj['signature'].items()) + src = ASTSource(self, signature, constants, attrs) + options = { + key: tuple(value) if isinstance(value, list) else value + for key, value in deserialized_obj['options'].items() + } + key = deserialized_obj['key'] + kernel = compile(src, None, options) + self.device_caches[device][0][key] = kernel + return kernel + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + # - when `.src` attribute is set, cache key of all callers need to be re-computed + if name == "src": + raise AttributeError(f"Cannot set attribute '{name}' directly. " + f"Use '_unsafe_update_src()' and manually clear `.hash` of all callers" + f"instead.") + super(JITFunction, self).__setattr__(name, value) + + def _unsafe_update_src(self, new_src): + """ + The only method allowed to modify src. + Bypasses the __setattr__ restriction by calling super().__setattr__ directly. + """ + self.hash = None + super().__setattr__('src', new_src) + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... + + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if os.getenv("TRITON_INTERPRET", "0") == "1": + from .interpreter import InterpretedFunction + return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug, + noinline=noinline, repr=repr, launch_metadata=launch_metadata) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator + + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + @staticmethod + def ptr_range(): + return 0 # optimistically assumes 32 bit pointer range + + +class TensorWrapper: + + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.data = base.data + self.device = base.device + self.shape = self.base.shape + + def data_ptr(self): + return self.base.data_ptr() + + def stride(self, *args): + return self.base.stride(*args) + + def __str__(self) -> str: + return f"TensorWrapper[{self.dtype}]({self.base})" + + def element_size(self): + return self.base.element_size() + + def cpu(self): + return TensorWrapper(self.base.cpu(), self.dtype) + + def copy_(self, other): + self.base.copy_(other.base) + + def clone(self): + return TensorWrapper(self.base.clone(), self.dtype) + + def to(self, device): + return TensorWrapper(self.base.to(device), self.dtype) + + def new_empty(self, sizes): + return TensorWrapper(self.base.new_empty(sizes), self.dtype) + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") + + +def get_jit_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line diff --git a/third_party/enflame/include/triton/python/triton/testing.py b/third_party/enflame/include/triton/python/triton/testing.py new file mode 100644 index 000000000..9a338f11c --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/testing.py @@ -0,0 +1,539 @@ +import functools +import math +import os +import statistics +import subprocess +import sys +from contextlib import contextmanager +from typing import Any, Dict, List +from . import language as tl +from . import runtime + + +def nvsmi(attrs): + attrs = ','.join(attrs) + cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(',') + ret = [int(x) for x in ret] + return ret + + +# pure Python implementation of np.quantile/torch.quantile +# to avoid unnecessary runtime dependency on numpy/torch + + +def _quantile(a, q): + n = len(a) + a = sorted(a) + + def get_quantile(q): + if not (0 <= q <= 1): + raise ValueError("Quantiles must be in the range [0, 1]") + point = q * (n - 1) + lower = math.floor(point) + upper = math.ceil(point) + t = point - lower + return (1 - t) * a[lower] + t * a[upper] + + return [get_quantile(q) for q in q] + + +def _summarize_statistics(times, quantiles, return_mode): + if quantiles is not None: + ret = _quantile(times, quantiles) + if len(ret) == 1: + ret = ret[0] + return ret + if return_mode == "all": + return times + elif return_mode == "min": + return min(times) + elif return_mode == "max": + return max(times) + elif return_mode == "mean": + return statistics.mean(times) + elif return_mode == "median": + return statistics.median(times) + + +def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. + + :param fn: Function to benchmark + :type fn: Callable + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean". + :type return_mode: str + """ + import torch + assert return_mode in ["min", "max", "mean", "median", "all"] + + with torch.cuda.stream(torch.cuda.Stream()): + # warmup + fn() + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive, + # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2 + # cache flush). + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for _ in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + return _summarize_statistics(ret, quantiles, return_mode) + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float], optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean". + :type return_mode: str + """ + assert return_mode in ["min", "max", "mean", "median", "all"] + + di = runtime.driver.active.get_device_interface() + + fn() + di.synchronize() + + cache = runtime.driver.active.get_empty_cache_for_benchmark() + + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + start_event.record() + for _ in range(5): + runtime.driver.active.clear_cache(cache) + fn() + end_event.record() + di.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + runtime.driver.active.clear_cache(cache) + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + di.synchronize() + times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)] + return _summarize_statistics(times, quantiles, return_mode) + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + """ + Asserts that two inputs are close within a certain tolerance. + + :param x: The first input. + :type x: scala, list, numpy.ndarray, or torch.Tensor + :param y: The second input. + :type y: scala, list, numpy.ndarray, or torch.Tensor + :param atol: The absolute tolerance. Default value is 1e-2. + :type atol: float, optional + :param rtol: The relative tolerance. Default value is 0. + :type rtol: float, optional + :param err_msg: The error message to use if the assertion fails. + :type err_msg: str + """ + import numpy as np + import torch + + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. + rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, + styles=None, + ): + """ + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. + + :param x_names: Name of the arguments that should appear on the x axis of the plot. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[Any] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. + :type args: Dict[str, Any] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. + :type y_log: bool, optional + :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle. + :type styles: list[tuple[str, str]] + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, + save_precision=6, **kwrags): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) + for x in bench.x_vals: + # x can be a single value or a sequence of values. + if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) + try: + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + + if bench.plot_name: + plt.figure() + ax = plt.subplot() + # Plot first x value on x axis if there are multiple. + first_x = x_names[0] + for i, y in enumerate(bench.line_names): + y_min, y_max = df[y + '-min'], df[y + '-max'] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + ax.set_xlabel(bench.xlabel or first_x) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + df = df[x_names + bench.line_names] + if diff_col and df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + + if print_data: + print(bench.plot_name + ':') + print(df.to_string()) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f", + index=False) + return df + + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] + if save_path: + # Create directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + html = open(os.path.join(save_path, "results.html"), "w") + html.write("\n") + for bench in benchmarks: + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) + if save_path: + html.write(f"\n") + if save_path: + html.write("\n") + html.close() + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None + + +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. + :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +def get_dram_gbps(device=None): + ''' return DRAM bandwidth in GB/s ''' + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz + bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"] + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s + return bw_gbps + + +def get_max_tensorcore_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8: + assert dtype == torch.float16 + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + if dtype in [torch.float32, torch.int32]: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16, torch.int16]: + ops_per_sub_core = 512 + elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + + +def get_max_simd_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops diff --git a/third_party/enflame/include/triton/python/triton/tools/__init__.py b/third_party/enflame/include/triton/python/triton/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/include/triton/python/triton/tools/build_extern.py b/third_party/enflame/include/triton/python/triton/tools/build_extern.py new file mode 100644 index 000000000..8f0168d59 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/tools/build_extern.py @@ -0,0 +1,365 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + + +class Symbol: + _name: str + _op_name: str + _ret_type: str + _arg_names: List[str] + _arg_types: List[str] + + def __init__( + self, + name: str, + op_name: str, + ret_type: str, + arg_names: List[str], + arg_types: List[str], + ) -> None: + ''' + A symbol is a function declaration. + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = list(arg_names) + self._arg_types = list(arg_types) + + @property + def name(self) -> str: + return self._name + + @property + def op_name(self) -> str: + return self._op_name + + @property + def ret_type(self) -> str: + return self._ret_type + + @property + def arg_names(self) -> List[str]: + return self._arg_names + + @property + def arg_types(self) -> List[str]: + return self._arg_types + + +def convert_type(type_str) -> Optional[str]: + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str) -> str: + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + _name: str + _path: str + _symbols: Dict[str, Symbol] + _format: bool + _grouping: bool + + def __init__( + self, + name: str, + path: str, + format: bool = True, + grouping: bool = True, + ) -> None: + ''' + Abstract class for extern library. + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = format + self._grouping = grouping + + @property + def name(self) -> str: + return self._name + + @property + def path(self) -> str: + return self._path + + @property + def symbols(self) -> Dict[str, Symbol]: + return self._symbols + + @property + def grouping(self) -> bool: + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file) -> None: + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir) -> None: + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + _symbol_groups: Dict[str, List[Symbol]] + + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + self.is_pure = True + + @staticmethod + def _extract_symbol(line) -> Optional[Symbol]: + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + if 'ieee' in op_name: + return None + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self) -> None: + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + + # Group functions together by renaming. + renaming = { + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': + 'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz': + 'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh', + 'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos', + 'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', + 'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf': + 'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2', + 'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll': + 'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru', + 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff': + 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f': + 'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax': + 'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min', + 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', + 'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24', + 'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': + 'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv', + 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', + 'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru', + 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt', + 'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit', + 'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd': + 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', + 'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn', + 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf': + 'yn' + } + + for symbol in self._symbols.values(): + op_name = symbol.op_name + if op_name in renaming: + op_name = renaming[op_name] + symbol._op_name = op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file) -> None: + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self) -> str: + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return core.extern_elementwise("libdevice", , , , _builder) + import_str = "from . import core\n" + + header_str = "" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@core.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn core.extern_elementwise(\"{self._name}\", libdevice_path(), [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),' + ret_type = f'core.dtype("{symbol.ret_type}")' + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += f", is_pure={self.is_pure}" + return_str += ", _builder=_builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + _path: str + _ll_file: str + + def __init__(self, path) -> None: + ''' + Invoke llvm-dis to disassemble the given file. + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path: str) -> None: + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self) -> str: + return self._ll_file + + @property + def path(self) -> str: + return self._path + + +extern_libs = ["libdevice"] + + +def build( + llvm_dis_path: str, + lib_path: str, + lib_name: str, + output_dir: str, +) -> None: + ''' + Interface function to build the library file. + :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library") + parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/third_party/enflame/include/triton/python/triton/tools/compile.py b/third_party/enflame/include/triton/python/triton/tools/compile.py new file mode 100644 index 000000000..7eed34389 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/tools/compile.py @@ -0,0 +1,162 @@ +import binascii +import hashlib +import importlib.util +import sys +from argparse import ArgumentParser +from pathlib import Path +from typing import List + +import triton +import triton.backends +from triton.backends.nvidia.driver import ty_to_cpp + +desc = """ +Triton ahead-of-time compiler: + +This program compiles the kernel with name `kernel-name` in the file at the +provided `path` into self-contained C source-code that embeds the `cubin` +data along with utilities to load, unload and launch the kernel. + +signature is provided as a list of (optionally divisibility-hinted) types +or constexpr values, e.g. + +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` + +will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. +Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, +and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. + +The resulting entry point will have signature + +CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) + +Different such specialized entry points can be combined using the `linker.py` script. + +NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter +used to run this `compile.py` script +""" + +if __name__ == "__main__": + + # command-line arguments + parser = ArgumentParser(description=desc) + parser.add_argument("path", + help="Path to Python source containing desired kernel in its scope. File will be executed.") + parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", + required=True) + parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") + parser.add_argument("--num-stages", "-ns", type=int, default=3, + help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel") + parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename") + parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) + parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) + args = parser.parse_args() + + out_name = args.out_name if args.out_name else args.kernel_name + out_path = args.out_path if args.out_path else Path(out_name) + + # execute python sources and extract functions wrapped in JITFunction + arg_path = Path(args.path) + sys.path.insert(0, str(arg_path.parent)) + spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + kernel = getattr(mod, args.kernel_name) + grid = args.grid.split(",") + assert len(grid) == 3 + + # validate and parse signature + signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) + + def hash_signature(signature: List[str]): + m = hashlib.sha256() + m.update(" ".join(signature).encode()) + return m.hexdigest()[:8] + + meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" + sig_hash = hash_signature(signature + [meta_sig]) + + def constexpr(s): + try: + ret = int(s) + return ret + except ValueError: + pass + try: + ret = float(s) + return ret + except ValueError: + pass + return None + + hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {k: v for k, v in hints.items() if v is not None} + constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} + constants = {k: v for k, v in constants.items() if v is not None} + for key, value in hints.items(): + if value == 1: + constants[kernel.arg_names[key[0]]] = value + signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)} + for key in constants: + signature[key] = 'constexpr' + const_sig = 'x'.join([str(v) for v in constants.values()]) + doc_string = [f"{k}={v}" for k, v in constants.items()] + doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] + # compile ast into cubin + for h in hints.values(): + assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" + attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16} + src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs) + opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} + ccinfo = triton.compile(src, options=opts) + if ccinfo.metadata.global_scratch_size > 0: + raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented") + + arg_names = [] + arg_types = [] + arg_names_not_1 = [] + arg_types_not_1 = [] + for i, arg_name in enumerate(kernel.arg_names): + if arg_name not in constants: + arg_names.append(arg_name) + arg_types.append(signature[arg_name]) + arg_names_not_1.append(arg_name) + arg_types_not_1.append(signature[arg_name]) + elif hints.get((i, ), None) == 1: + arg_names.append(arg_name) + arg_types.append("i32") + + # dump C stub code + suffix = '' + for i, ty in enumerate(signature.values()): + suffix += str(i) + if hints.get((i, ), None) == 1: + suffix += 'c' + if hints.get((i, ), None) == 16: + suffix += 'd' + func_name = '_'.join([out_name, sig_hash, suffix]) + asm = ccinfo.asm["cubin"] # store binary data once + hex_ = str(binascii.hexlify(asm))[2:-1] + params = { + "kernel_name": func_name, + "triton_kernel_name": args.kernel_name, + "bin_size": len(asm), + "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]), + "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"]), + "num_args": len(arg_names_not_1) + 1, + "kernel_docstring": doc_string, + "shared": ccinfo.metadata.shared, + "num_warps": args.num_warps, + "algo_info": '_'.join([const_sig, meta_sig]), + "gridX": grid[0], + "gridY": grid[1], + "gridZ": grid[2], + "_placeholder": "", + } + for ext in ['h', 'c']: + template_path = Path(__file__).parent / "extra" / "cuda" / f"compile.{ext}" + with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp: + fp.write(Path(template_path).read_text().format(**params)) diff --git a/third_party/enflame/include/triton/python/triton/tools/disasm.py b/third_party/enflame/include/triton/python/triton/tools/disasm.py new file mode 100644 index 000000000..002c4e9b5 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/tools/disasm.py @@ -0,0 +1,144 @@ +# MIT License + +# Copyright (c) 2020 Da Yan @ HKUST + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import functools +import os +import re +import subprocess +import tempfile + +FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') +SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') +FNAME_RE = re.compile(r'\s*Function : (\w+)\s*') +BRA_RE = re.compile(r'(.*BRA(?:\.U)? )(0x\w+);') + + +def parseCtrl(sline): + enc = int(SLINE_RE.match(sline).group(1), 16) + stall = (enc >> 41) & 0xf + yld = (enc >> 45) & 0x1 + wrtdb = (enc >> 46) & 0x7 + readb = (enc >> 49) & 0x7 + watdb = (enc >> 52) & 0x3f + + yld_str = 'Y' if yld == 0 else '-' + wrtdb_str = '-' if wrtdb == 7 else str(wrtdb) + readb_str = '-' if readb == 7 else str(readb) + watdb_str = '--' if watdb == 0 else f'{watdb:02d}' + return f'{watdb_str}:{readb_str}:{wrtdb_str}:{yld_str}:{stall:x}' + + +def processSassLines(fline, sline, labels): + asm = FLINE_RE.match(fline).group(1) + # Remove tailing space + if asm.endswith(" ;"): + asm = asm[:-2] + ";" + ctrl = parseCtrl(sline) + # BRA target address + if BRA_RE.match(asm) is not None: + target = int(BRA_RE.match(asm).group(2), 16) + if target in labels: + pass + else: + labels[target] = len(labels) + return (f'{ctrl}', f'{asm}') + + +@functools.lru_cache() +def get_sass(cubin_asm, fun=None): + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(cubin_asm) + sass = extract(path, fun) + finally: + os.remove(path) + return sass + + +@functools.lru_cache() +def path_to_cuobjdump(): + from triton.backends.nvidia.compiler import _path_to_binary + return _path_to_binary("cuobjdump") + + +def extract(file_path, fun): + cuobjdump, _ = path_to_cuobjdump() + if fun is None: + sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) + else: + sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path]) + sass_lines = sass_str.splitlines() + line_idx = 0 + while line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + # format: + # function : + # .headerflags: ... + # /*0000*/ asmstr /*0x...*/ + # /*0x...*/ + + # Looking for new function header (function: ) + while FNAME_RE.match(line) is None: + line_idx += 1 + if line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + else: + return + + fname = FNAME_RE.match(line).group(1) + ret = '' + ret += f'Function:{fname}\n' + line_idx += 2 # bypass .headerflags + line = sass_lines[line_idx].decode() + # Remapping address to label + labels = {} # address -> label_idx + # store sass asm in buffer and them print them (for labels) + # (ctrl, asm) + asm_buffer = [] + while FLINE_RE.match(line) is not None: + # First line (Offset ASM Encoding) + fline = sass_lines[line_idx].decode() + line_idx += 1 + # Second line (Encoding) + sline = sass_lines[line_idx].decode() + line_idx += 1 + asm_buffer.append(processSassLines(fline, sline, labels)) + # peek the next line + line = sass_lines[line_idx].decode() + # Print sass + # label naming convention: LBB#i + for idx, (ctrl, asm) in enumerate(asm_buffer): + # Print label if this is BRA target + offset = idx * 16 + if offset in labels: + label_name = f'LBB{labels[offset]}' + ret += f'{label_name}:\n' + ret += ctrl + '\t' + # if this is BRA, remap offset to label + if BRA_RE.match(asm): + target = int(BRA_RE.match(asm).group(2), 16) + target_name = f'LBB{labels[target]}' + asm = BRA_RE.sub(rf'\1{target_name};', asm) + ret += asm + '\n' + ret += '\n' + return ret diff --git a/third_party/enflame/include/triton/python/triton/tools/experimental_descriptor.py b/third_party/enflame/include/triton/python/triton/tools/experimental_descriptor.py new file mode 100644 index 000000000..6077cab6f --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/tools/experimental_descriptor.py @@ -0,0 +1,32 @@ +import torch + +import triton + + +class TmaDescKernelParam: + TMA_DESC_SIZE = 128 + + def __init__(self, ptr, dims, block_dims, element_size): + self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device="cpu") + assert len(dims) == len(block_dims) + assert 1 <= len(dims) <= 2 + assert self.desc.data_ptr() % 64 == 0 + + if len(dims) == 1: + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, + self.desc.data_ptr()) + else: + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], + block_dims[1], element_size, self.desc.data_ptr()) + + # Return a CUtensorMap* pointer in host memory + def tma_desc_cpu_ptr(self): + return self.desc.data_ptr() + + +def create_1d_tma_descriptor(ptr, dim, block_dim, element_size): + return TmaDescKernelParam(ptr, [dim], [block_dim], element_size) + + +def create_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size): + return TmaDescKernelParam(ptr, [dim1, dim0], [block_dim1, block_dim0], element_size) diff --git a/third_party/enflame/include/triton/python/triton/tools/link.py b/third_party/enflame/include/triton/python/triton/tools/link.py new file mode 100644 index 000000000..75a1157a5 --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/tools/link.py @@ -0,0 +1,322 @@ +from collections import defaultdict +from pathlib import Path +from typing import Sequence, Union + +from dataclasses import dataclass + + +def _exists(x): + return x is not None + + +class LinkerError(Exception): + pass + + +@dataclass +class KernelLinkerMeta: + orig_kernel_name: str + arg_names: Sequence[str] + arg_ctypes: Sequence[str] + sizes: Sequence[Union[int, None]] + sig_hash: str + triton_suffix: str + suffix: str + num_specs: int + """ number of specialized arguments """ + + +class HeaderParser: + + def __init__(self) -> None: + import re + + # [kernel_name, c signature] + self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)") + # [name, hash, suffix] + self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$") + # [(type, name)] + self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?") + # [d|c] + self.arg_suffix = re.compile("[c,d]") + + self.kernels = defaultdict(list) + + def extract_linker_meta(self, header: str): + for ln in header.splitlines(): + if ln.startswith("//"): + m = self.linker_directives.match(ln) + if _exists(m): + ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3) + name, sig_hash, suffix = self._match_name(ker_name) + c_types, arg_names = self._match_c_sig(c_sig) + num_specs, sizes = self._match_suffix(suffix, c_sig) + self._add_kernel( + "_".join([name, algo_info]), + KernelLinkerMeta( + orig_kernel_name=name, + arg_names=arg_names, + arg_ctypes=c_types, + sizes=sizes, + sig_hash=sig_hash, + triton_suffix=suffix, + suffix=suffix, + num_specs=num_specs, + ), + ) + + def _match_name(self, ker_name: str): + m = self.kernel_name.match(ker_name) + if _exists(m): + name, sig_hash, suffix = m.group(1), m.group(2), m.group(3) + return name, sig_hash, suffix + raise LinkerError(f"{ker_name} is not a valid kernel name") + + def _match_c_sig(self, c_sig: str): + m = self.c_sig.findall(c_sig) + if len(m): + tys, args = [], [] + for ty, arg_name in m: + tys.append(ty) + args.append(arg_name) + return tys, args + + raise LinkerError(f"{c_sig} is not a valid argument signature") + + def _match_suffix(self, suffix: str, c_sig: str): + args = c_sig.split(",") + s2i = {"c": 1, "d": 16} + num_specs = 0 + sizes = [] + # scan through suffix, first find the index, + # then see if it is followed by d or c + for i in range(len(args)): + pos = suffix.find(str(i)) + if pos == -1: + raise LinkerError(f"{suffix} is not a valid kernel suffix") + pos += len(str(i)) + if self.arg_suffix.match(suffix, pos): + num_specs += 1 + sizes.extend([None] * (i - len(sizes))) + sizes.append(s2i[suffix[pos]]) + pos += 1 + if i < len(args) - 1: + suffix = suffix[pos:] + else: + sizes.extend([None] * (len(args) - len(sizes))) + return num_specs, sizes + + def _add_kernel(self, name: str, ker: KernelLinkerMeta): + if name in self.kernels: + last: KernelLinkerMeta = self.kernels[name][-1] + + for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes): + if cur != new_: + raise LinkerError( + f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}" + ) + + self.kernels[name].append(ker) + + +def gen_signature_with_full_args(m): + return ", ".join([f"{ty} {arg}" for ty, arg in zip(m.arg_ctypes, m.arg_names)]) + + +def gen_signature(m): + arg_types = [ty for ty, hint in zip(m.arg_ctypes, m.sizes) if hint != 1] + arg_names = [arg for arg, hint in zip(m.arg_names, m.sizes) if hint != 1] + sig = ", ".join([f"{ty} {arg}" for ty, arg in zip(arg_types, arg_names)]) + return sig + + +# generate declarations of kernels with meta-parameter and constant values +def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + return f""" +CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}); +void load_{name}(); +void unload_{name}(); + """ + + +# generate declarations of kernels with meta-parameter and constant values +def make_global_decl(meta: KernelLinkerMeta) -> str: + return f""" +CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}); +CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id); +void load_{meta.orig_kernel_name}(); +void unload_{meta.orig_kernel_name}(); + """ + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n" + src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different integer value hints +def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + src = f"// launcher for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n" + src += "\n" + + src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{") + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + cond_fn = ( # + lambda val, hint: f"({val} % {hint} == 0)" # + if hint == 16 # + else f"({val} == {hint})" # + if hint == 1 # + else None) + conds = " && ".join([ # + cond_fn(val, hint) # + for val, hint in zip(meta.arg_names, meta.sizes) # + if hint is not None + ]) + src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" + ) # Edge case where no specializations hence no dispatching required + arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] + src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" + src += "\n" + src += " return CUDA_ERROR_INVALID_VALUE;\n" + src += "}\n" + + for mode in ["load", "unload"]: + src += f"\n// {mode} for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"void {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" + src += f"void {mode}_{name}() {{" + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n" + src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n" + src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n" + src += "}\n" + return src + + +# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values +def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str: + # the table of hint dispatchers + src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n" + src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n" + for name in names: + src += f" {name},\n" + src += "};\n" + return src + + +# generate definition for load/unload functions for kernels with different meta-parameter and constant values +def make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str: + src = "" + for mode in ["load", "unload"]: + src += f"void {mode}_{meta.orig_kernel_name}(void){{\n" + for name in names: + src += f" {mode}_{name}();\n" + src += "}\n\n" + return src + + +def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void);" + return src + + +def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void){{\n" + src += f" return (int)(sizeof({meta.orig_kernel_name}_kernels) / sizeof({meta.orig_kernel_name}_kernels[0]));\n" + src += "}\n" + return src + + +desc = """ +Triton ahead-of-time linker: + +This program takes in header files generated by compile.py, and generates a +single entry-point responsible for dispatching the user's input to the right +kernel given the specializations that were compiled. + +Example usage: +python link.py /path/to/headers/*.h -o kernel_name +""" + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser(description=desc) + parser.add_argument( + "headers", + nargs="+", + help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)", + ) + parser.add_argument("--out", "-o", type=Path, help="Out filename") + parser.add_argument( + "--prefix", + type=str, + default="", + help="String to prefix kernel dispatcher names", + ) + args = parser.parse_args() + + # metadata + parser = HeaderParser() + includes = [] + for header in args.headers: + h_path = Path(header) + h_str = h_path.read_text() + includes.append(h_path.name) + parser.extract_linker_meta(h_str) + + # generate headers + algo_decls = [make_algo_decls(name, meta) for name, meta in parser.kernels.items()] + meta_lists = [meta for name, meta in parser.kernels.items()] + meta = meta_lists[0][0] + get_num_algos_decl = make_get_num_algos_decl(meta) + global_decl = make_global_decl(meta) + with args.out.with_suffix(".h").open("w") as fp: + out = "#include \n" + out += "\n".join(algo_decls) + out += "\n" + out += get_num_algos_decl + out += "\n" + out += global_decl + fp.write(out) + + # generate source + defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()] + names = [name for name in parser.kernels.keys()] + func_pointers_def = make_func_pointers(names, meta) + meta_const_def = make_kernel_meta_const_dispatcher(meta) + load_unload_def = make_kernel_load_def(names, meta) + get_num_algos_def = make_get_num_algos_def(meta) + default_algo_kernel = make_default_algo_kernel(meta) + with args.out.with_suffix(".c").open("w") as fp: + out = "" + out += "#include \n" + out += "#include \n" + out += "#include \n" + out += "\n" + out += "\n".join(defs) + out += "\n" + out += func_pointers_def + out += "\n" + out += get_num_algos_def + out += "\n" + out += meta_const_def + out += "\n" + out += load_unload_def + out += "\n" + out += default_algo_kernel + fp.write(out) diff --git a/third_party/enflame/include/triton/python/triton/tools/mxfp.py b/third_party/enflame/include/triton/python/triton/tools/mxfp.py new file mode 100644 index 000000000..1b129c1ae --- /dev/null +++ b/third_party/enflame/include/triton/python/triton/tools/mxfp.py @@ -0,0 +1,301 @@ +""" +Helper classes for working with low precision floating point types that +align with the opencompute (OCP) microscaling (MX) specification. + * MXFP4Tensor: 4-bit E2M1 floating point data + * MXScaleTensor: 8-bit E8M0 floating point data +Reference: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf +""" + +import torch + + +class MXFP4Tensor: + + def __init__(self, data=None, size=None, device=None): + """ + Tensor class for working with four bit E2M1 floating point data as defined by the + opencompute microscaling specification. + + + Parameters: + - data: A torch tensor of float32 numbers to convert to fp4e2m1 microscaling format. + - size: The size of the tensor to create. + - device: The device on which to create the tensor. + """ + self.device = device + if data is not None: + assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor" + self.device = data.device + self.data = self._from_float(data) + elif size is not None: + self.size = size if isinstance(size, tuple) else (size, ) + else: + raise ValueError("Either parameter data or size must be provided") + + def random(self): + S = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device) + E = torch.randint(0, 4, size=self.size, dtype=torch.uint8, device=self.device) + M = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device) + + self.data = ((S << 3) | (E << 1) | M).type(torch.uint8) + return self + + def to(self, dtype): + """ + Convert fp4e2m1 data to float32. + + Returns: + - A torch tensor of type dtype representing the fp4e2m1 data. + """ + assert dtype == torch.float32, "Currently only float32 is supported for fp4e2m1 to float conversion" + + data = self.data + S = ((data >> 3) & 0x1).type(dtype) + E = ((data >> 1) & 0x3).type(dtype) + M = (data & 0x1).type(dtype) + + # The MXF4 E2M1 spec defines 0bS000 as zero + value = torch.zeros_like(S) + is_zero = (E == 0) & (M == 0) + non_zero_mask = ~is_zero + if non_zero_mask.any(): + S_nz = S[non_zero_mask] + E_nz = E[non_zero_mask] + M_nz = M[non_zero_mask] + + sign = torch.pow(-1, S_nz) + # Normal and subnormal handling for the exponent and mantissa + exponent = torch.where(E_nz == 0, E_nz, E_nz - 1) + mantissa = torch.where(E_nz == 0, M_nz * 0.5, 1.0 + M_nz * 0.5) + value_nz = sign * torch.pow(2, exponent) * mantissa + + value[non_zero_mask] = value_nz + + # For zeros, the values must remain zero with the correct sign + value[is_zero & (S == 1)] *= -1 + return value.type(torch.float32) + + def _from_float(self, values): + """ + Convert float32 numbers to mxf4 e2m1 format. + * No encodings are reserved for Inf or NaN in mxf4. + * Conversion from float supports roundTiesToEven rounding mode. + * If a value exceeds the mxf4 representable range after rounding, + clamps to the maximum mxf4 magnitude, preserving the sign. + * If a value has magnitude less than the minimum subnormal magnitude + in mxf4 after rounding, converts to zero. + + Parameters: + - values: A torch tensor of float32 numbers to convert to fp4 format. + """ + S = torch.signbit(values).type(torch.uint8) + abs_values = torch.abs(values) + + is_zero = (abs_values == 0) + is_invalid = torch.isnan(values) | torch.isinf(values) + + # Enumerate all possible E2M1 exponent and mantissa values. We will + # use these to compare the distance between float32 and all possible + # E2M1 floats to find the nearest E2M1 representable value + E_bits = torch.tensor([0, 1, 2, 3], dtype=torch.uint8, device=self.device) + M_bits = torch.tensor([0, 1], dtype=torch.uint8, device=self.device) + + candidate_values = [] + candidate_E = [] + candidate_M = [] + + for E in E_bits: + if E == 0: + # Subnormals + exponent = 0 + for M in M_bits: + significand = M * 0.5 + value = significand * (2**exponent) + candidate_values.append(value) + candidate_E.append(E) + candidate_M.append(M) + else: + # Normals + exponent = E.item() - 1 + for M in M_bits: + significand = 1.0 + M * 0.5 + value = significand * (2**exponent) + candidate_values.append(value) + candidate_E.append(E) + candidate_M.append(M) + + candidates = torch.tensor(candidate_values, dtype=torch.float32, device=self.device) + candidate_E = torch.tensor(candidate_E, dtype=torch.uint8, device=self.device) + candidate_M = torch.tensor(candidate_M, dtype=torch.uint8, device=self.device) + + abs_values_flat = abs_values.view(-1) + N = abs_values_flat.shape[0] + abs_values_expanded = abs_values_flat.unsqueeze(1) + + # Clamp invalid values to the max e2m1 representable value + max_candidate_value = candidates.max().item() + abs_values_flat[is_invalid.view(-1)] = max_candidate_value + + # Compute distance between all abs_values and candidate e2m1 values + errors = torch.abs(abs_values_expanded - candidates.unsqueeze(0)) + + # To implement roundTiesToEven, we need to break ties by preferring + # even mantissas (M == 0). We do so by adding an epsilon bias to shift + # the closest candidate with an even mantissa closer to the float value + min_errors, _ = torch.min(errors, dim=1, keepdim=True) + is_tie = (errors == min_errors) + # More than one candidate has the min error for some float value + if is_tie.sum() > 1: + M_bits_expanded = candidate_M.unsqueeze(0).expand(N, -1) + tie_breaker = (M_bits_expanded == 0).type(torch.int32) + + errors = errors - (tie_breaker * 1e-6) + + best_indices = torch.argmin(errors, dim=1) + + E_selected = candidate_E[best_indices] + M_selected = candidate_M[best_indices] + E = E_selected.view(abs_values.shape) + M = M_selected.view(abs_values.shape) + + E[is_zero] = 0 + M[is_zero] = 0 + + return ((S << 3) | (E << 1) | M).type(torch.uint8) + + def to_packed_tensor(self, dim): + """ + Packs two e2m1 elements into a single uint8 along the specified dimension. + + Parameters: + - dim: The dimension along which to pack the elements. + + Returns: + - A torch tensor of dtype uint8 with two e2m1 elements packed into one uint8. + """ + data = self.data + assert 0 <= dim < data.ndim, \ + "The dimension to pack along is not within the range of tensor dimensions" + + size_along_dim = data.size(dim) + new_size_along_dim = (size_along_dim + 1) // 2 + + # If the size is odd, we pad the data along dim with zeros at the end + if size_along_dim % 2 != 0: + pad_sizes = [0] * (2 * data.ndim) + pad_index = (data.ndim - dim - 1) * 2 + 1 + pad_sizes[pad_index] = 1 + data = torch.nn.functional.pad(data, pad_sizes, mode='constant', value=0) + + new_shape = list(data.shape) + new_shape[dim] = new_size_along_dim + new_shape.insert(dim + 1, 2) # packed dimension of length 2 + data = data.reshape(*new_shape) + + low = data.select(dim + 1, 0) + high = data.select(dim + 1, 1) + packed = (high << 4) | low + + return packed + + def unpack_packed_tensor(self, packed_tensor, dim, original_shape): + """ + Unpacks a tensor where two fp4 elements are packed into a single uint8. + + Parameters: + - packed_tensor: The packed tensor + - dim: The dimension along which the tensor was packed. + - original_shape: The shape of the original tensor before packing. + + Returns: + - A tensor with the original data unpacked into uint8 elements containing one + fp4e2m1 element in the least significant bits. + """ + high = (packed_tensor >> 4) & 0xF + low = packed_tensor & 0xF + + stacked = torch.stack((low, high), dim=dim + 1) + + # Flatten along dim and dim+1 and then merge + shape = list(stacked.shape) + new_shape = shape[:dim] + [shape[dim] * 2] + shape[dim + 2:] + data = stacked.reshape(*new_shape) + + # Remove any padding + if original_shape[dim] % 2 != 0: + indices = [slice(None)] * data.ndim + indices[dim] = slice(0, original_shape[dim]) + data = data[tuple(indices)] + + return data.type(torch.uint8) + + +class MXScaleTensor: + + def __init__(self, data=None, size=None, device=None): + """ + Tensor class for working with microscaling E8M0 block scale factors. + + Parameters: + - data: A torch tensor of float32 numbers to convert to fp8e8m0 microscaling format. + - size: The size of the tensor to create. + - device: The device on which to create the tensor. + """ + self.device = device + if data is not None: + assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor" + self.device = data.device + self.data = self._from_float(data) + elif size is not None: + self.size = size if isinstance(size, tuple) else (size, ) + else: + raise ValueError("Either parameter data or size must be provided") + + def random(self, low=None, high=None): + """ + Generate random E8M0 data within a specified range. + * Excludes the NaN encoding (255). + """ + bias = 127 + + min_exponent = 0 if low is None else max(0, int(torch.log2(torch.tensor(low))) + bias) + max_exponent = 254 if high is None else min(254, max(0, int(torch.log2(torch.tensor(high))) + bias)) + assert min_exponent <= max_exponent, "Low must be less than or equal to high" + + E = torch.randint(min_exponent, max_exponent + 1, size=self.size, dtype=torch.uint8, device=self.device) + self.data = E + return self + + def to(self, dtype): + assert dtype == torch.float32, "Currently only float32 is supported for f8e8m0 to float conversion" + data = self.data.type(dtype) + is_nan = (data == 255) + e_biased = data.clone() + e_biased[is_nan] = 0 + e = e_biased - 127 + value = torch.pow(2.0, e) + value[is_nan] = torch.nan + return value.type(dtype) + + def _from_float(self, values): + """ + Convert float32 numbers to E8M0 format. + * Values <= 0, NaNs, and Infs are converted to the NaN encoding (255). + * Positive values are converted by computing the floor of log2(value) to get the exponent. + + Parameters: + - values: A torch tensor of float32 numbers to convert to E8M0 format. + """ + result = torch.empty_like(values, dtype=torch.uint8, device=self.device) + + is_invalid = torch.isnan(values) | torch.isinf(values) | (values <= 0) + result[is_invalid] = 255 + + valid_values = values[~is_invalid] + e = torch.floor(torch.log2(valid_values)) + e_biased = e + 127 + e_biased_int = e_biased.type(torch.int32) + e_biased_clamped = torch.clamp(e_biased_int, 0, 254) + result[~is_invalid] = e_biased_clamped.type(torch.uint8) + + return result diff --git a/third_party/enflame/include/triton/python/tutorials/01-vector-add.py b/third_party/enflame/include/triton/python/tutorials/01-vector-add.py new file mode 100644 index 000000000..e527e5fc7 --- /dev/null +++ b/third_party/enflame/include/triton/python/tutorials/01-vector-add.py @@ -0,0 +1,135 @@ +""" +Vector Addition +=============== + +In this tutorial, you will write a simple vector addition using Triton. + +In doing so, you will learn about: + +* The basic programming model of Triton. + +* The `triton.jit` decorator, which is used to define Triton kernels. + +* The best practices for validating and benchmarking your custom ops against native reference implementations. + +""" + +# %% +# Compute Kernel +# -------------- + +import torch + +import triton +import triton.language as tl + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +# %% +# Let's also declare a helper function to (1) allocate the `z` tensor +# and (2) enqueue the above kernel with appropriate grid/block sizes: + + +def add(x: torch.Tensor, y: torch.Tensor): + # We need to preallocate the output. + output = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements = output.numel() + # The SPMD launch grid denotes the number of kernel instances that run in parallel. + # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. + # In this case, we use a 1D grid where the size is the number of blocks: + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + # NOTE: + # - Each torch.tensor object is implicitly converted into a pointer to its first element. + # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. + # - Don't forget to pass meta-parameters as keywords arguments. + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still + # running asynchronously at this point. + return output + + +# %% +# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: + +torch.manual_seed(0) +size = 98432 +x = torch.rand(size, device=DEVICE) +y = torch.rand(size, device=DEVICE) +output_torch = x + y +output_triton = add(x, y) +print(output_torch) +print(output_triton) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + +# %% +# Seems like we're good to go! + +# %% +# Benchmark +# --------- +# +# We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch. +# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom ops. +# for different problem sizes. + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['size'], # Argument names to use as an x-axis for the plot. + x_vals=[2**i for i in range(12, 28, 1)], # Different possible values for `x_name`. + x_log=True, # x axis is logarithmic. + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=['triton', 'torch'], # Possible values for `line_arg`. + line_names=['Triton', 'Torch'], # Label name for the lines. + styles=[('blue', '-'), ('green', '-')], # Line styles. + ylabel='GB/s', # Label name for the y-axis. + plot_name='vector-add-performance', # Name for the plot. Used also as a file name for saving the plot. + args={}, # Values for function arguments not in `x_names` and `y_name`. + )) +def benchmark(size, provider): + x = torch.rand(size, device=DEVICE, dtype=torch.float32) + y = torch.rand(size, device=DEVICE, dtype=torch.float32) + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles) + gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +# %% +# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or +# `save_path='/path/to/results/' to save them to disk along with raw CSV data: +benchmark.run(print_data=True, show_plots=True) diff --git a/third_party/enflame/include/triton/python/tutorials/02-fused-softmax.py b/third_party/enflame/include/triton/python/tutorials/02-fused-softmax.py new file mode 100644 index 000000000..832718a19 --- /dev/null +++ b/third_party/enflame/include/triton/python/tutorials/02-fused-softmax.py @@ -0,0 +1,235 @@ +""" +Fused Softmax +============= + +In this tutorial, you will write a fused softmax operation that is significantly faster +than PyTorch's native op for a particular class of matrices: those whose rows can fit in +the GPU's SRAM. + +In doing so, you will learn about: + +* The benefits of kernel fusion for bandwidth-bound operations. + +* Reduction operators in Triton. + +""" + +# %% +# Motivations +# ----------- +# +# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. +# Let us consider instead the case of a simple (numerically stabilized) softmax operation: + +import torch + +import triton +import triton.language as tl +from triton.runtime import driver + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_cdna(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', + 'gfx90a', 'gfx908') + + +def naive_softmax(x): + """Compute row-wise softmax of X using native pytorch + + We subtract the maximum element in order to avoid overflows. Softmax is invariant to + this shift. + """ + # read MN elements ; write M elements + x_max = x.max(dim=1)[0] + # read MN + M elements ; write MN elements + z = x - x_max[:, None] + # read MN elements ; write MN elements + numerator = torch.exp(z) + # read MN elements ; write M elements + denominator = numerator.sum(dim=1) + # read MN + M elements ; write MN elements + ret = numerator / denominator[:, None] + # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements + return ret + + +# %% +# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` +# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements. +# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads +# X once and does all the necessary computations on-chip. +# Doing so would require reading and writing back only :math:`MN` bytes, so we could +# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`). +# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically +# but, as we will see later, it is still far from ideal. + +# %% +# Compute Kernel +# -------------- +# +# Our softmax kernel works as follows: each program loads a set of rows of the input matrix X strided by number of programs, +# normalizes it and writes back the result to the output Y. +# +# Note that one important limitation of Triton is that each block must have a +# power-of-two number of elements, so we need to internally "pad" each row and guard the +# memory operations properly if we want to handle any possible input shapes: + + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, + num_stages: tl.constexpr): + # starting row of the program + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + mask = col_offsets < n_cols + row = tl.load(input_ptrs, mask=mask, other=-float('inf')) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=mask) + + +# %% +# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor. + +properties = driver.active.utils.get_device_properties(DEVICE.index) +NUM_SM = properties["multiprocessor_count"] +NUM_REGS = properties["max_num_regs"] +SIZE_SMEM = properties["max_shared_mem"] +WARP_SIZE = properties["warpSize"] +target = triton.runtime.driver.active.get_current_target() +kernels = {} + + +def softmax(x): + n_rows, n_cols = x.shape + + # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x` + BLOCK_SIZE = triton.next_power_of_2(n_cols) + + # Another trick we can use is to ask the compiler to use more threads per row by + # increasing the number of warps (`num_warps`) over which each row is distributed. + # You will see in the next tutorial how to auto-tune this value in a more natural + # way so you don't have to come up with manual heuristics yourself. + num_warps = 8 + + # Number of software pipelining stages. + num_stages = 4 if SIZE_SMEM > 200000 else 2 + + # Allocate output + y = torch.empty_like(x) + + # pre-compile kernel to get register usage and compute thread occupancy. + kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, + num_stages=num_stages, num_warps=num_warps, grid=(1, )) + kernel._init_handles() + n_regs = kernel.n_regs + size_smem = kernel.metadata.shared + if is_hip(): + # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available. + # However, this is not always the case. In most cases all registers can be used as regular purpose registers. + # ISA SECTION (3.6.4 for CDNA3) + # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used + # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total + # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is + # not required to be equal numbers of both types. + if is_cdna(): + NUM_GPRS = NUM_REGS * 2 + + # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor. + # When we divide this number with WARP_SIZE we get maximum number of waves that can + # execute on a CU (multi-processor) in parallel. + MAX_NUM_THREADS = properties["max_threads_per_sm"] + max_num_waves = MAX_NUM_THREADS // WARP_SIZE + occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps + else: + occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) + occupancy = min(occupancy, SIZE_SMEM // size_smem) + num_programs = NUM_SM * occupancy + + num_programs = min(num_programs, n_rows) + + # Create a number of persistent programs. + kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages) + return y + + +# %% +# Unit Test +# --------- + +# %% +# We make sure that we test our kernel on a matrix with an irregular number of rows and columns. +# This will allow us to verify that our padding mechanism works. + +torch.manual_seed(0) +x = torch.randn(1823, 781, device=DEVICE) +y_triton = softmax(x) +y_torch = torch.softmax(x, axis=1) +assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) + +# %% +# As expected, the results are identical. + +# %% +# Benchmark +# --------- +# +# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows. +# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above. + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['N'], # argument names to use as an x-axis for the plot + x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name` + line_arg='provider', # argument name whose value corresponds to a different line in the plot + line_vals=['triton', 'torch'], # possible values for `line_arg`` + line_names=[ + "Triton", + "Torch", + ], # label name for the lines + styles=[('blue', '-'), ('green', '-')], # line styles + ylabel="GB/s", # label name for the y-axis + plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. + args={'M': 4096}, # values for function arguments not in `x_names` and `y_name` + )) +def benchmark(M, N, provider): + x = torch.randn(M, N, device=DEVICE, dtype=torch.float32) + stream = getattr(torch, DEVICE.type).Stream() + getattr(torch, DEVICE.type).set_stream(stream) + if provider == 'torch': + ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1)) + if provider == 'triton': + ms = triton.testing.do_bench(lambda: softmax(x)) + gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms) + + +benchmark.run(show_plots=True, print_data=True) + +# %% +# In the above plot, we can see that: +# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here. +# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. +# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape. diff --git a/third_party/enflame/include/triton/python/tutorials/03-matrix-multiplication.py b/third_party/enflame/include/triton/python/tutorials/03-matrix-multiplication.py new file mode 100644 index 000000000..7b838b17a --- /dev/null +++ b/third_party/enflame/include/triton/python/tutorials/03-matrix-multiplication.py @@ -0,0 +1,443 @@ +""" +Matrix Multiplication +===================== +In this tutorial, you will write a very short high-performance FP16 matrix multiplication kernel that achieves +performance on par with cuBLAS or rocBLAS. + +You will specifically learn about: + +* Block-level matrix multiplications. + +* Multi-dimensional pointer arithmetic. + +* Program re-ordering for improved L2 cache hit rate. + +* Automatic performance tuning. + +""" + +# %% +# Motivations +# ----------- +# +# Matrix multiplications are a key building block of most modern high-performance computing systems. +# They are notoriously hard to optimize, hence their implementation is generally done by +# hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). +# Unfortunately, these libraries are often proprietary and cannot be easily customized +# to accommodate the needs of modern deep learning workloads (e.g., fused activation functions). +# In this tutorial, you will learn how to implement efficient matrix multiplications by +# yourself with Triton, in a way that is easy to customize and extend. +# +# Roughly speaking, the kernel that we will write will implement the following blocked +# algorithm to multiply a (M, K) by a (K, N) matrix: +# +# .. code-block:: python +# +# # Do in parallel +# for m in range(0, M, BLOCK_SIZE_M): +# # Do in parallel +# for n in range(0, N, BLOCK_SIZE_N): +# acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32) +# for k in range(0, K, BLOCK_SIZE_K): +# a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] +# b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] +# acc += dot(a, b) +# C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc +# +# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance. + +# %% +# Compute Kernel +# -------------- +# +# The above algorithm is, actually, fairly straightforward to implement in Triton. +# The main difficulty comes from the computation of the memory locations at which blocks +# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need +# multi-dimensional pointer arithmetic. +# +# Pointer Arithmetic +# ~~~~~~~~~~~~~~~~~~~ +# +# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given +# by :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`. +# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and +# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as: +# +# .. code-block:: python +# +# &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1); +# &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1); +# +# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as the following +# code. Also note that we need an extra modulo to handle the case where :code:`M` is not a multiple of +# :code:`BLOCK_SIZE_M` or :code:`N` is not a multiple of :code:`BLOCK_SIZE_N`, in which case we can pad the data with +# some useless values, which will not contribute to the results. For the :code:`K` dimension, we will handle that later +# using masking load semantics. +# +# .. code-block:: python +# +# offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M +# offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N +# offs_k = tl.arange(0, BLOCK_SIZE_K) +# a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak) +# b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn) +# +# And then updated in the inner loop as follows: +# +# .. code-block:: python +# +# a_ptrs += BLOCK_SIZE_K * stride_ak; +# b_ptrs += BLOCK_SIZE_K * stride_bk; +# +# +# L2 Cache Optimizations +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]` +# block of :code:`C`. +# It is important to remember that the order in which these blocks are computed does +# matter, since it affects the L2 cache hit rate of our program, and unfortunately, a +# simple row-major ordering +# +# .. code-block:: Python +# +# pid = tl.program_id(axis=0) +# grid_n = tl.cdiv(N, BLOCK_SIZE_N) +# pid_m = pid // grid_n +# pid_n = pid % grid_n +# +# is just not going to cut it. +# +# One possible solution is to launch blocks in an order that promotes data reuse. +# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before +# switching to the next column: +# +# .. code-block:: python +# +# # Program ID +# pid = tl.program_id(axis=0) +# # Number of program ids along the M axis +# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) +# # Number of programs ids along the N axis +# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) +# # Number of programs in group +# num_pid_in_group = GROUP_SIZE_M * num_pid_n +# # Id of the group this program is in +# group_id = pid // num_pid_in_group +# # Row-id of the first program in the group +# first_pid_m = group_id * GROUP_SIZE_M +# # If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller +# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) +# # *Within groups*, programs are ordered in a column-major order +# # Row-id of the program in the *launch grid* +# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) +# # Col-id of the program in the *launch grid* +# pid_n = (pid % num_pid_in_group) // group_size_m +# +# For example, in the following matmul where each matrix is 9 blocks by 9 blocks, +# we can see that if we compute the output in row-major ordering, we need to load 90 +# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped +# ordering, we only need to load 54 blocks. +# +# .. image:: grouped_vs_row_major_ordering.png +# +# In practice, this can improve the performance of our matrix multiplication kernel by +# more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100). +# + +# %% +# Final Result +# ------------ + +import torch + +import triton +import triton.language as tl + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip_mi200(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == 'hip' and target.arch == 'gfx90a' + + +def get_cuda_autotune_config(): + return [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + # Good config for fp8 inputs. + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4) + ] + + +def get_hip_autotune_config(): + return [ + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=4, num_stages=2), + triton.Config( + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, + num_warps=8, num_stages=2), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=8, num_stages=2), + triton.Config( + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, + num_warps=4, num_stages=2), + triton.Config( + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, + num_warps=4, num_stages=2), + ] + + +def get_autotune_config(): + if is_cuda(): + return get_cuda_autotune_config() + else: + return get_hip_autotune_config() + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. +@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ACTIVATION=activation # + ) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS). + +torch.manual_seed(0) +a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) +b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) +triton_output = matmul(a, b) +torch_output = torch.matmul(a, b) +print(f"triton_output_with_fp16_inputs={triton_output}") +print(f"torch_output_with_fp16_inputs={torch_output}") +# Bigger tolerance for AMD MI200 devices. +# MI200 devices use reduced precision fp16 and bf16 and flush input and +# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices +rtol = 1e-2 if is_hip_mi200() else 0 +if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ Triton and Torch match") +else: + print("❌ Triton and Torch differ") + +TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2") +if TORCH_HAS_FP8 and is_cuda(): + torch.manual_seed(0) + a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) + b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) + a = a.to(torch.float8_e5m2) + # pre-transpose b for efficiency. + b = b.T + b = b.to(torch.float8_e5m2) + triton_output = matmul(a, b) + torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16)) + print(f"triton_output_with_fp8_inputs={triton_output}") + print(f"torch_output_with_fp8_inputs={torch_output}") + if torch.allclose(triton_output, torch_output, atol=0.125, rtol=0): + print("✅ Triton and Torch match") + else: + print("❌ Triton and Torch differ") + +# %% +# Benchmark +# --------- +# +# Square Matrix Performance +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can now compare the performance of our kernel against that of cuBLAS or rocBLAS. Here we focus on square matrices, +# but feel free to arrange this script as you wish to benchmark any other matrix shape. + +ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS' + +configs = [] +for fp8_inputs in [False, True]: + if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()): + continue + configs.append( + triton.testing.Benchmark( + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + x_vals=[128 * i for i in range(2, 33)], # Different possible values for `x_name` + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. + line_vals=["triton"] if fp8_inputs else [ref_lib.lower(), "triton"], # Label name for the lines + line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"], # Line styles + styles=[("green", "-"), ("blue", "-")], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="matmul-performance-" + + ("fp16" if not fp8_inputs else "fp8"), # Name for the plot, used also as a file name for saving the plot. + args={"fp8_inputs": fp8_inputs}, + )) + + +@triton.testing.perf_report(configs) +def benchmark(M, N, K, provider, fp8_inputs): + a = torch.randn((M, K), device=DEVICE, dtype=torch.float16) + b = torch.randn((K, N), device=DEVICE, dtype=torch.float16) + if TORCH_HAS_FP8 and fp8_inputs: + a = a.to(torch.float8_e5m2) + b = b.T + b = b.to(torch.float8_e5m2) + quantiles = [0.5, 0.2, 0.8] + if provider == ref_lib.lower(): + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +benchmark.run(show_plots=True, print_data=True) diff --git a/third_party/enflame/include/triton/python/tutorials/04-low-memory-dropout.py b/third_party/enflame/include/triton/python/tutorials/04-low-memory-dropout.py new file mode 100644 index 000000000..3dd84da47 --- /dev/null +++ b/third_party/enflame/include/triton/python/tutorials/04-low-memory-dropout.py @@ -0,0 +1,175 @@ +""" +Low-Memory Dropout +================== + +In this tutorial, you will write a memory-efficient implementation of dropout whose state +will be composed of a single int32 seed. This differs from more traditional implementations of dropout, +whose state is generally composed of a bit mask tensor of the same shape as the input. + +In doing so, you will learn about: + +* The limitations of naive implementations of Dropout with PyTorch. + +* Parallel pseudo-random number generation in Triton. + +""" + +# %% +# Baseline +# -------- +# +# The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the performance +# of deep neural networks in low-data regime (i.e. regularization). +# +# It takes a vector as input and produces a vector of the same shape as output. Each scalar in the +# output has a probability :math:`p` of being changed to zero and otherwise it is copied from the input. +# This forces the network to perform well even when only :math:`1 - p` scalars from the input are available. +# +# At evaluation time we want to use the full power of the network so we set :math:`p=0`. Naively this would +# increase the norm of the output (which can be a bad thing, e.g. it can lead to artificial decrease +# in the output softmax temperature). To prevent this we multiply the output by :math:`\frac{1}{1 - p}`, which +# keeps the norm consistent regardless of the dropout probability. +# +# Let's first take a look at the baseline implementation. + +import tabulate +import torch + +import triton +import triton.language as tl + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def _dropout( + x_ptr, # pointer to the input + x_keep_ptr, # pointer to a mask of 0s and 1s + output_ptr, # pointer to the output + n_elements, # number of elements in the `x` tensor + p, # probability that an element of `x` is changed to zero + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + # Load data + x = tl.load(x_ptr + offsets, mask=mask) + x_keep = tl.load(x_keep_ptr + offsets, mask=mask) + # The line below is the crucial part, described in the paragraph above! + output = tl.where(x_keep, x / (1 - p), 0.0) + # Write-back output + tl.store(output_ptr + offsets, output, mask=mask) + + +def dropout(x, x_keep, p): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) + return output + + +# Input tensor +x = torch.randn(size=(10, ), device=DEVICE) +# Dropout mask +p = 0.5 +x_keep = (torch.rand(size=(10, ), device=DEVICE) > p).to(torch.int32) +# +output = dropout(x, x_keep=x_keep, p=p) +print(tabulate.tabulate([ + ["input"] + x.tolist(), + ["keep mask"] + x_keep.tolist(), + ["output"] + output.tolist(), +])) + +# %% +# Seeded dropout +# -------------- +# +# The above implementation of dropout works fine, but it can be a bit awkward to deal with. Firstly +# we need to store the dropout mask for backpropagation. Secondly, dropout state management can get +# very tricky when using recompute/checkpointing (e.g. see all the notes about `preserve_rng_state` in +# https://pytorch.org/docs/stable/checkpoint.html). In this tutorial we'll describe an alternative implementation +# that (1) has a smaller memory footprint; (2) requires less data movement; and (3) simplifies the management +# of persisting randomness across multiple invocations of the kernel. +# +# Pseudo-random number generation in Triton is simple! In this tutorial we will use the +# :code:`triton.language.rand` function which generates a block of uniformly distributed :code:`float32` +# values in [0, 1), given a seed and a block of :code:`int32` offsets. But if you need it, Triton also provides +# other :ref:`random number generation strategies`. +# +# .. note:: +# Triton's implementation of PRNG is based on the Philox algorithm (described on [SALMON2011]_). +# +# Let's put it all together. + + +@triton.jit +def _seeded_dropout( + x_ptr, + output_ptr, + n_elements, + p, + seed, + BLOCK_SIZE: tl.constexpr, +): + # compute memory offsets of elements handled by this instance + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # load data from x + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + # randomly prune it + random = tl.rand(seed, offsets) + x_keep = random > p + # write-back + output = tl.where(x_keep, x / (1 - p), 0.0) + tl.store(output_ptr + offsets, output, mask=mask) + + +def seeded_dropout(x, p, seed): + output = torch.empty_like(x) + assert x.is_contiguous() + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) + return output + + +x = torch.randn(size=(10, ), device=DEVICE) +# Compare this to the baseline - dropout mask is never instantiated! +output = seeded_dropout(x, p=0.5, seed=123) +output2 = seeded_dropout(x, p=0.5, seed=123) +output3 = seeded_dropout(x, p=0.5, seed=512) + +print( + tabulate.tabulate([ + ["input"] + x.tolist(), + ["output (seed = 123)"] + output.tolist(), + ["output (seed = 123)"] + output2.tolist(), + ["output (seed = 512)"] + output3.tolist(), + ])) + +# %% +# Et Voilà! We have a triton kernel that applies the same dropout mask provided the seed is the same! +# If you'd like explore further applications of pseudorandomness in GPU programming, we encourage you +# to explore the `python/triton/language/random.py`! + +# %% +# Exercises +# --------- +# +# 1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row. +# 2. Add support for striding. +# 3. (challenge) Implement a kernel for sparse Johnson-Lindenstrauss transform which generates the projection matrix on the fly each time using a seed. + +# %% +# References +# ---------- +# +# .. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel Random Numbers: As Easy as 1, 2, 3", 2011 +# .. [SRIVASTAVA2014] Nitish Srivastava and Geoffrey Hinton and Alex Krizhevsky and Ilya Sutskever and Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting", JMLR 2014 diff --git a/third_party/enflame/include/triton/python/tutorials/05-layer-norm.py b/third_party/enflame/include/triton/python/tutorials/05-layer-norm.py new file mode 100644 index 000000000..5be07a9ea --- /dev/null +++ b/third_party/enflame/include/triton/python/tutorials/05-layer-norm.py @@ -0,0 +1,376 @@ +""" +Layer Normalization +==================== +In this tutorial, you will write a high-performance layer normalization +kernel that runs faster than the PyTorch implementation. + +In doing so, you will learn about: + +* Implementing backward pass in Triton. + +* Implementing parallel reduction in Triton. + +""" + +# %% +# Motivations +# ----------- +# +# The *LayerNorm* operator was first introduced in [BA2016]_ as a way to improve the performance +# of sequential models (e.g., Transformers) or neural networks with small batch size. +# It takes a vector :math:`x` as input and produces a vector :math:`y` of the same shape as output. +# The normalization is performed by subtracting the mean and dividing by the standard deviation of :math:`x`. +# After the normalization, a learnable linear transformation with weights :math:`w` and biases :math:`b` is applied. +# The forward pass can be expressed as follows: +# +# .. math:: +# y = \frac{ x - \text{E}[x] }{ \sqrt{\text{Var}(x) + \epsilon} } * w + b +# +# where :math:`\epsilon` is a small constant added to the denominator for numerical stability. +# Let’s first take a look at the forward pass implementation. + +import torch + +import triton +import triton.language as tl + +try: + # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it + # should not be added to extras_require in setup.py. + import apex + HAS_APEX = True +except ModuleNotFoundError: + HAS_APEX = False + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +# %% +# Backward pass +# ------------- +# +# The backward pass for the layer normalization operator is a bit more involved than the forward pass. +# Let :math:`\hat{x}` be the normalized inputs :math:`\frac{ x - \text{E}[x] }{ \sqrt{\text{Var}(x) + \epsilon} }` before the linear transformation, +# the Vector-Jacobian Products (VJP) :math:`\nabla_{x}` of :math:`x` are given by: +# +# .. math:: +# \nabla_{x} = \frac{1}{\sigma}\Big( \nabla_{y} \odot w - \underbrace{ \big( \frac{1}{N} \hat{x} \cdot (\nabla_{y} \odot w) \big) }_{c_1} \odot \hat{x} - \underbrace{ \frac{1}{N} \nabla_{y} \cdot w }_{c_2} \Big) +# +# where :math:`\odot` denotes the element-wise multiplication, :math:`\cdot` denotes the dot product, and :math:`\sigma` is the standard deviation. +# :math:`c_1` and :math:`c_2` are intermediate constants that improve the readability of the following implementation. +# +# For the weights :math:`w` and biases :math:`b`, the VJPs :math:`\nabla_{w}` and :math:`\nabla_{b}` are more straightforward: +# +# .. math:: +# \nabla_{w} = \nabla_{y} \odot \hat{x} \quad \text{and} \quad \nabla_{b} = \nabla_{y} +# +# Since the same weights :math:`w` and biases :math:`b` are used for all rows in the same batch, their gradients need to sum up. +# To perform this step efficiently, we use a parallel reduction strategy: each kernel instance accumulates +# partial :math:`\nabla_{w}` and :math:`\nabla_{b}` across certain rows into one of :math:`\text{GROUP_SIZE_M}` independent buffers. +# These buffers stay in the L2 cache and then are further reduced by another function to compute the actual :math:`\nabla_{w}` and :math:`\nabla_{b}`. +# +# Let the number of input rows :math:`M = 4` and :math:`\text{GROUP_SIZE_M} = 2`, +# here's a diagram of the parallel reduction strategy for :math:`\nabla_{w}` (:math:`\nabla_{b}` is omitted for brevity): +# +# .. image:: parallel_reduction.png +# +# In Stage 1, the rows of X that have the same color share the same buffer and thus a lock is used to ensure that only one kernel instance writes to the buffer at a time. +# In Stage 2, the buffers are further reduced to compute the final :math:`\nabla_{w}` and :math:`\nabla_{b}`. +# In the following implementation, Stage 1 is implemented by the function :code:`_layer_norm_bwd_dx_fused` and Stage 2 is implemented by the function :code:`_layer_norm_bwd_dwdb`. + + +@triton.jit +def _layer_norm_bwd_dx_fused(DX, # pointer to the input gradient + DY, # pointer to the output gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + X, # pointer to the input + W, # pointer to the weights + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + Lock, # pointer to the lock + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + # Map the program id to the elements of X, DX, and DY it should compute. + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE_N) + mask = cols < N + X += row * stride + DY += row * stride + DX += row * stride + # Offset locks and weights/biases gradient pointer for parallel reduction + lock_id = row % GROUP_SIZE_M + Lock += lock_id + Count = Lock + GROUP_SIZE_M + DW = DW + lock_id * N + cols + DB = DB + lock_id * N + cols + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + w = tl.load(W + cols, mask=mask).to(tl.float32) + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd + wdy = w * dy + xhat = tl.where(mask, xhat, 0.) + wdy = tl.where(mask, wdy, 0.) + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + # Write dx + tl.store(DX + cols, dx, mask=mask) + # Accumulate partial sums for dw/db + partial_dw = (dy * xhat).to(w.dtype) + partial_db = (dy).to(w.dtype) + while tl.atomic_cas(Lock, 0, 1) == 1: + pass + count = tl.load(Count) + # First store doesn't accumulate + if count == 0: + tl.atomic_xchg(Count, 1) + else: + partial_dw += tl.load(DW, mask=mask) + partial_db += tl.load(DB, mask=mask) + tl.store(DW, partial_dw, mask=mask) + tl.store(DB, partial_db, mask=mask) + # Release the lock + tl.atomic_xchg(Lock, 0) + + +@triton.jit +def _layer_norm_bwd_dwdb(DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + FINAL_DW, # pointer to the weights gradient + FINAL_DB, # pointer to the biases gradient + M, # GROUP_SIZE_M + N, # number of columns + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + # Map the program id to the elements of DW and DB it should compute. + pid = tl.program_id(0) + cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Iterate through the rows of DW and DB to sum the partial sums. + for i in range(0, M, BLOCK_SIZE_M): + rows = i + tl.arange(0, BLOCK_SIZE_M) + mask = (rows[:, None] < M) & (cols[None, :] < N) + offs = rows[:, None] * N + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.) + db += tl.load(DB + offs, mask=mask, other=0.) + # Write the final sum to the output. + sum_dw = tl.sum(dw, axis=0) + sum_db = tl.sum(db, axis=0) + tl.store(FINAL_DW + cols, sum_dw, mask=cols < N) + tl.store(FINAL_DB + cols, sum_db, mask=cols < N) + + +# %% +# Benchmark +# --------- +# +# We can now compare the performance of our kernel against that of PyTorch. +# Here we focus on inputs that have Less than 64KB per feature. +# Specifically, one can set :code:`'mode': 'backward'` to benchmark the backward pass. + + +class LayerNorm(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, normalized_shape, weight, bias, eps): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M, ), dtype=torch.float32, device=x.device) + rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # enqueue kernel + _layer_norm_fwd_fused[(M, )]( # + x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, eps, # + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + ctx.save_for_backward(x, weight, bias, mean, rstd) + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.eps = eps + return y + + @staticmethod + def backward(ctx, dy): + x, w, b, m, v = ctx.saved_tensors + # heuristics for amount of parallel reduction stream for DW/DB + N = w.shape[0] + GROUP_SIZE_M = 64 + if N <= 8192: GROUP_SIZE_M = 96 + if N <= 4096: GROUP_SIZE_M = 128 + if N <= 1024: GROUP_SIZE_M = 256 + # allocate output + locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device) + _dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) + _db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) + dw = torch.empty((N, ), dtype=w.dtype, device=w.device) + db = torch.empty((N, ), dtype=w.dtype, device=w.device) + dx = torch.empty_like(dy) + # enqueue kernel using forward pass heuristics + # also compute partial sums for DW and DB + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + _layer_norm_bwd_dx_fused[(M, )]( # + dx, dy, _dw, _db, x, w, m, v, locks, # + x_arg.stride(0), N, # + BLOCK_SIZE_N=ctx.BLOCK_SIZE, # + GROUP_SIZE_M=GROUP_SIZE_M, # + num_warps=ctx.num_warps) + grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] + # accumulate partial sums in separate kernel + _layer_norm_bwd_dwdb[grid]( + _dw, _db, dw, db, min(GROUP_SIZE_M, M), N, # + BLOCK_SIZE_M=32, # + BLOCK_SIZE_N=128, num_ctas=1) + return dx, None, dw, db, None + + +layer_norm = LayerNorm.apply + + +def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(True) + # forward pass + y_tri = layer_norm(x, w_shape, weight, bias, eps) + y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + # backward pass (triton) + y_tri.backward(dy, retain_graph=True) + dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]] + x.grad, weight.grad, bias.grad = None, None, None + # backward pass (torch) + y_ref.backward(dy, retain_graph=True) + dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]] + # compare + assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) + assert torch.allclose(dx_tri, dx_ref, atol=1e-2, rtol=0) + assert torch.allclose(db_tri, db_ref, atol=1e-2, rtol=0) + assert torch.allclose(dw_tri, dw_ref, atol=1e-2, rtol=0) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['N'], + x_vals=[512 * i for i in range(2, 32)], + line_arg='provider', + line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []), + line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []), + styles=[('blue', '-'), ('green', '-'), ('orange', '-')], + ylabel='GB/s', + plot_name='layer-norm-backward', + args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}, + )) +def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device=DEVICE): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(True) + quantiles = [0.5, 0.2, 0.8] + + def y_fwd(): + + if provider == "triton": + return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + + if provider == "torch": + return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + + if provider == "apex": + apex_layer_norm = (apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)) + return apex_layer_norm(x) # noqa: F811, E704 + + # forward pass + if mode == 'forward': + gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) + ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) + # backward pass + if mode == 'backward': + y = y_fwd() + gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) # noqa: F811, E704 + ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles, + grad_to_none=[x], rep=500) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +test_layer_norm(1151, 8192, torch.float16) +bench_layer_norm.run(save_path='.', print_data=True) + +# %% +# References +# ---------- +# +# .. [BA2016] Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton, "Layer Normalization", Arxiv 2016 diff --git a/third_party/enflame/include/triton/python/tutorials/06-fused-attention.py b/third_party/enflame/include/triton/python/tutorials/06-fused-attention.py new file mode 100644 index 000000000..e65635c49 --- /dev/null +++ b/third_party/enflame/include/triton/python/tutorials/06-fused-attention.py @@ -0,0 +1,899 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Credits: OpenAI kernel team + +Extra Credits: + +* Original flash attention paper (https://arxiv.org/abs/2205.14135) +* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +""" + +import pytest +import torch +import triton.tools.experimental_descriptor + +import triton +import triton.language as tl + +# ENABLE_LHS_TO_TMEM is an experimental environment variable for Blackwell. +# If it is set to 1 it can improve performance of Blackwell attention. However, +# it defaults to 0 as it is known to cause correctness issues outside of the +# _attn_fwd_tma kernel below. + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def supports_tma(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + +HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) + +if HAS_TMA_DESC: + print("TMA benchmarks will be running with experimental grid constant TMA descriptor.", ) +else: + print("TMA benchmarks will be running without grid constant TMA descriptor.", ) + + +# TmaAutoTuneHelper used in htyu's PR #5622 +class TmaAutoTuneHelper: + + # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 + class KernelParamWrapper: + + def __init__(self, desc): + self.desc = desc + + def tma_desc_cpu_ptr(self): + return self.desc.data_ptr() + + TMA_SIZE = 128 + + def __init__(self): + self.fill_1d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_1d_tma_descriptor) + self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor) + if HAS_TMA_DESC: + self.descriptors = {} + else: + self.cuda_descriptors = {} + + # Call this method outside of the lambda function for grid size + def init_tma_descriptor(self, name): + if HAS_TMA_DESC: + self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8) + else: + self.cuda_descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8) + + # Call this method inside the lambda function for grid size + def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr()) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr()) + desc_x.copy_(buf_x, non_blocking=True) + + # Call this method inside the lambda function for grid size + def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()) + desc_x.copy_(buf_x, non_blocking=True) + + def get_tma_descriptor_kernel_param(self, name): + if HAS_TMA_DESC: + assert self.descriptors[name] is not None + return self.KernelParamWrapper(self.descriptors[name]) + else: + assert self.cuda_descriptors[name] is not None + return self.cuda_descriptors[name] + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, # + K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # + N_CTX: tl.constexpr, fp8_v: tl.constexpr): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + qk = tl.dot(q, k) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + v = tl.load(V_block_ptr) + if fp8_v: + p = p.to(tl.float8e5) + else: + p = p.to(tl.float16) + acc = tl.dot(p, v, acc) + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd_inner_tma(acc, l_i, m_i, q, # + desc_k, desc_v, # + offset_y, dtype: tl.constexpr, start_m, qk_scale, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # + N_CTX: tl.constexpr): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX + offsetkv_y = offset_y + lo + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl._experimental_descriptor_load(desc_k, [offsetkv_y, 0], [BLOCK_N, HEAD_DIM], dtype).T + qk = tl.dot(q, k) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + v = tl._experimental_descriptor_load(desc_v, [offsetkv_y, 0], [BLOCK_N, HEAD_DIM], dtype) + p = p.to(dtype) + # note that this non transposed v for FP8 is only supported on Blackwell + acc = tl.dot(p, v, acc) + # update m_i and l_i + m_i = m_ij + offsetkv_y += BLOCK_N + return acc, l_i, m_i + + +# We don't run auto-tuning every time to keep the tutorial fast. Keeping +# the code below and commenting out the equivalent parameters is convenient for +# re-tuning. +configs = [ + triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ + for BM in [64, 128]\ + for BN in [32, 64]\ + for s in ([1] if is_hip() else [3, 4, 7])\ + for w in [4, 8]\ +] + + +def keep(conf): + BLOCK_M = conf.kwargs["BLOCK_M"] + BLOCK_N = conf.kwargs["BLOCK_N"] + if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: + return False + return True + + +@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) +@triton.jit +def _attn_fwd(Q, K, V, sm_scale, M, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr # + ): + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + + # block pointers + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=v_order, + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(HEAD_DIM, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compielr to schedule the + # two loops independently + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # epilogue + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +# We don't run auto-tuning every time to keep the tutorial fast. Keeping +# the code below and commenting out the equivalent parameters is convenient for +# re-tuning. +configs_tma = [ + triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ + for BM in [64, 128]\ + for BN in [32, 64, 128]\ + for s in [2, 3, 4, 6]\ + for w in [4, 8]\ +] + + +def keep_tma(conf): + BLOCK_M = conf.kwargs["BLOCK_M"] + BLOCK_N = conf.kwargs["BLOCK_N"] + if (torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8): + return False + return True + + +@triton.autotune(configs=list(filter(keep_tma, configs_tma)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"]) +@triton.jit +def _attn_fwd_tma(sm_scale, M, # + Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + FP8_OUTPUT: tl.constexpr, # + STAGE: tl.constexpr # + ): + dtype = tl.float8e5 if FP8_OUTPUT else tl.float16 + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + + offset_y = off_z + off_h * N_CTX + qo_offset_y = offset_y + start_m * BLOCK_M + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout + q = tl._experimental_descriptor_load(desc_q, [qo_offset_y, 0], [BLOCK_M, HEAD_DIM], dtype) + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner_tma(acc, l_i, m_i, q, # + desc_k, desc_v, # + offset_y, dtype, start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, # + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compielr to schedule the + # two loops independently + acc, l_i, m_i = _attn_fwd_inner_tma(acc, l_i, m_i, q, # + desc_k, desc_v, # + offset_y, dtype, start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, # + ) + # epilogue + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + tl._experimental_descriptor_store(desc_o, acc.to(dtype), [qo_offset_y, 0]) + + +@triton.jit +def _attn_bwd_preprocess(O, DO, # + Delta, # + Z, H, N_CTX, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # + ): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_hz = tl.program_id(1) + off_n = tl.arange(0, HEAD_DIM) + # load + o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) + do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hz * N_CTX + off_m, delta) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + # Filled in by the wrapper. + start_n, start_m, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, HEAD_DIM) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + qkT = tl.dot(k, qT) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq(dq, q, K, V, # + do, m, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok + return dq + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, # + DO, # + DQ, DK, DV, # + M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, HEAD_DIM) + + start_n = pid * BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=True # + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. + dk, dv = _attn_bwd_dkdv( # + dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=False # + ) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # + MASK=True # + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * BLOCK_N2, num_steps, # + MASK=False # + ) + # Write back dQ. + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, USE_TMA=True): + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + o = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + # Tuning for AMD target + if is_hip(): + waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 + extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} + + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + if USE_TMA and supports_tma() and not (torch.cuda.get_device_capability()[0] == 9 + and q.dtype == torch.float8_e5m2): + # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor + y_dim = q.shape[0] * q.shape[1] * q.shape[2] + + desc_helper = TmaAutoTuneHelper() + desc_helper.init_tma_descriptor("q") + desc_helper.init_tma_descriptor("v") + desc_helper.init_tma_descriptor("k") + desc_helper.init_tma_descriptor("o") + + def grid(META): + nonlocal desc_helper + + desc_helper.fill_2d_tma_descriptor("q", q.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_M"], HEAD_DIM_K, + q.element_size()) + + desc_helper.fill_2d_tma_descriptor("v", v.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_N"], HEAD_DIM_K, + v.element_size()) + + desc_helper.fill_2d_tma_descriptor("k", k.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_N"], HEAD_DIM_K, + k.element_size()) + + desc_helper.fill_2d_tma_descriptor("o", o.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_M"], HEAD_DIM_K, + o.element_size()) + + return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1) + + desc_q = desc_helper.get_tma_descriptor_kernel_param("q") + desc_v = desc_helper.get_tma_descriptor_kernel_param("v") + desc_k = desc_helper.get_tma_descriptor_kernel_param("k") + desc_o = desc_helper.get_tma_descriptor_kernel_param("o") + + ctx.grid = grid + _attn_fwd_tma[grid]( + sm_scale, M, # + q.shape[0], q.shape[1], # + desc_q, desc_k, desc_v, desc_o, # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + FP8_OUTPUT=q.dtype == torch.float8_e5m2, # + STAGE=stage, # + **extra_kern_args) + else: + grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) + ctx.grid = grid + _attn_fwd[grid]( + q, k, v, sm_scale, M, o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + **extra_kern_args) + + ctx.save_for_backward(q, k, v, o, M) + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + NUM_WARPS, NUM_STAGES = 4, 5 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 + assert N_CTX % PRE_BLOCK == 0 + pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) + delta = torch.empty_like(M) + _attn_bwd_preprocess[pre_grid]( + o, do, # + delta, # + BATCH, N_HEAD, N_CTX, # + BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # + ) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # + M, delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + N_HEAD, N_CTX, # + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES # + ) + + return dq, dk, dv, None, None + + +attention = _attention.apply + + +@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)]) +@pytest.mark.parametrize("causal", [True]) +def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): + torch.manual_seed(20) + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE)) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + # p = torch.exp(p) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # triton implementation + tri_out = attention(q, k, v, causal, sm_scale).half() + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) + rtol = 0.0 + # Relative tolerance workaround for known hardware limitation of MI200 GPU. + # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": + rtol = 1e-2 + assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol) + assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol) + assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol) + + +try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') +BATCH, N_HEADS, HEAD_DIM = 4, 32, 64 +# vary seq length for fixed head and batch=4 +configs = [] +for mode in ["fwd", "bwd"]: + for causal in [True, False]: + if mode == "bwd" and not causal: + continue + configs.append( + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(10, 15)], + line_arg="provider", + line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + + (["flash"] if HAS_FLASH else []), + line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + + (["Flash-2"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-"), ("green", "-")], + ylabel="TFLOPS", + plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}", + args={ + "H": N_HEADS, + "BATCH": BATCH, + "HEAD_DIM": HEAD_DIM, + "mode": mode, + "causal": causal, + }, + )) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device=DEVICE): + assert mode in ["fwd", "bwd"] + dtype = torch.float16 + if "triton" in provider: + q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + if mode == "fwd" and "fp8" in provider: + q = q.to(torch.float8_e5m2) + k = k.to(torch.float8_e5m2) + v = v.permute(0, 1, 3, 2).contiguous() + v = v.permute(0, 1, 3, 2) + v = v.to(torch.float8_e5m2) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, causal, sm_scale) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn) + if provider == "flash": + qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + fn = lambda: flash_attn_func(qkv, causal=causal) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn) + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM + total_flops = 2 * flops_per_matmul + if causal: + total_flops *= 0.5 + if mode == "bwd": + total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) + return total_flops * 1e-12 / (ms * 1e-3) + + +if __name__ == "__main__": + # only works on post-Ampere GPUs right now + bench_flash_attention.run(save_path=".", print_data=True) diff --git a/third_party/enflame/include/triton/python/tutorials/07-extern-functions.py b/third_party/enflame/include/triton/python/tutorials/07-extern-functions.py new file mode 100644 index 000000000..800563701 --- /dev/null +++ b/third_party/enflame/include/triton/python/tutorials/07-extern-functions.py @@ -0,0 +1,99 @@ +""" +Libdevice (`tl.extra.libdevice`) function +============================== +Triton can invoke a custom function from an external library. +In this example, we will use the `libdevice` library to apply `asin` on a tensor. + +Please refer to `CUDA libdevice-users-guide `_ and/or `HIP device-lib source code `_ regarding the semantics of all available libdevice functions. + +In `libdevice.py`, we try to aggregate functions with the same computation but different data types together. +For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`. +Triton automatically selects the correct underlying device function to invoke based on input and output types. +""" + +# %% +# asin Kernel +# ------------ + +import torch + +import triton +import triton.language as tl +import inspect +import os +from triton.language.extra import libdevice + +from pathlib import Path + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def asin_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) + + +# %% +# Using the default libdevice library path +# ----------------------------------------- +# We can use the default libdevice library path encoded in `triton/language/math.py` + +torch.manual_seed(0) +size = 98432 +x = torch.rand(size, device=DEVICE) +output_triton = torch.zeros(size, device=DEVICE) +output_torch = torch.asin(x) +assert x.is_cuda and output_triton.is_cuda +n_elements = output_torch.numel() +grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) +asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) +print(output_torch) +print(output_triton) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + + +# %% +# Customize the libdevice library path +# ------------------------------------- +# We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel. +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +current_file = inspect.getfile(inspect.currentframe()) +current_dir = Path(os.path.dirname(os.path.abspath(current_file))) + +if is_cuda(): + libdir = current_dir.parent.parent / 'third_party/nvidia/backend/lib' + extern_libs = {'libdevice': str(libdir / 'libdevice.10.bc')} +elif is_hip(): + libdir = current_dir.parent.parent / 'third_party/amd/backend/lib' + extern_libs = {} + libs = ["ocml", "ockl"] + for lib in libs: + extern_libs[lib] = str(libdir / f'{lib}.bc') +else: + raise RuntimeError('unknown backend') + +output_triton = torch.empty_like(x) +asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, extern_libs=extern_libs) +print(output_torch) +print(output_triton) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') diff --git a/third_party/enflame/include/triton/python/tutorials/08-grouped-gemm.py b/third_party/enflame/include/triton/python/tutorials/08-grouped-gemm.py new file mode 100644 index 000000000..78e6ff566 --- /dev/null +++ b/third_party/enflame/include/triton/python/tutorials/08-grouped-gemm.py @@ -0,0 +1,565 @@ +""" +Group GEMM +============================ +This group gemm kernel launches a fixed number of CTA to compute a group +of gemms. The scheduling is static and we do it on device. +""" + +# Copyright (c) 2023 - 2025 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +from typing import Optional +import torch + +import triton +import triton.language as tl + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def supports_tma(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + +def num_sms(): + if is_cuda(): + return torch.cuda.get_device_properties("cuda").multi_processor_count + return 148 + + +@triton.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 84, + }), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 128, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 84, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 128, + }), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64, + 'NUM_SM': num_sms(), + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64, + 'NUM_SM': num_sms(), + }), + ], + key=['group_size'], +) +@triton.jit +def grouped_matmul_kernel( + # device tensor of matrices pointers + group_a_ptrs, + group_b_ptrs, + group_c_ptrs, + # device tensor of gemm sizes. its shape is [group_size, 3] + # dim 0 is group_size, dim 1 is the values of of each gemm + group_gemm_sizes, + # device tensor of leading dimension sizes. its shape is [group_size, 3] + # dim 0 is group_size, dim 1 is the values of of each gemm + g_lds, + # number of gemms + group_size, + # number of virtual SM + NUM_SM: tl.constexpr, + # tile sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + tile_idx = tl.program_id(0) + last_problem_end = 0 + for g in range(group_size): + # get the gemm size of the current problem + gm = tl.load(group_gemm_sizes + g * 3) + gn = tl.load(group_gemm_sizes + g * 3 + 1) + gk = tl.load(group_gemm_sizes + g * 3 + 2) + num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + # iterate through the tiles in the current gemm problem + while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles): + # pick up a tile from the current gemm problem + k = gk + lda = tl.load(g_lds + g * 3) + ldb = tl.load(g_lds + g * 3 + 1) + ldc = tl.load(g_lds + g * 3 + 2) + a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16)) + b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16)) + c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16)) + # figure out tile coordinates + tile_idx_in_gemm = tile_idx - last_problem_end + tile_m_idx = tile_idx_in_gemm // num_n_tiles + tile_n_idx = tile_idx_in_gemm % num_n_tiles + + # do regular gemm here + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :] + b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + # hint to Triton compiler to do proper loop pipelining + tl.multiple_of(a_ptrs, [16, 16]) + tl.multiple_of(b_ptrs, [16, 16]) + # assume full tile for now + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * ldb + c = accumulator.to(tl.float16) + + offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :] + + # assumes full tile for now + tl.store(c_ptrs, c) + + # go to the next tile by advancing NUM_SM + tile_idx += NUM_SM + + # get ready to go to the next gemm problem + last_problem_end = last_problem_end + num_tiles + + +def group_gemm_fn(group_A, group_B): + assert len(group_A) == len(group_B) + group_size = len(group_A) + + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = group_A[i] + B = group_B[i] + assert A.shape[1] == B.shape[0] + M, K = A.shape + K, N = B.shape + C = torch.empty((M, N), device=DEVICE, dtype=A.dtype) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [M, N, K] + g_lds += [A.stride(0), B.stride(0), C.stride(0)] + + # note these are device tensors + d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) + d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) + d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) + # we use a fixed number of CTA, and it's auto-tunable + grid = lambda META: (META['NUM_SM'], ) + grouped_matmul_kernel[grid]( + d_a_ptrs, + d_b_ptrs, + d_c_ptrs, + d_g_sizes, + d_g_lds, + group_size, + ) + + return group_C + + +tma_configs = [ + triton.Config({'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, 'BLOCK_SIZE_K' : BK}, num_stages=s, num_warps=w) \ + for BM in [128]\ + for BN in [128, 256]\ + for BK in [64, 128]\ + for s in ([3, 4])\ + for w in [4, 8]\ +] + + +@triton.autotune( + tma_configs, + key=['group_a_ptrs', 'group_b_ptrs', 'gropup_c_ptrs', 'group_size'], +) +@triton.jit +def grouped_matmul_tma_kernel( + # device tensor of matrices pointers + group_a_ptrs, + group_b_ptrs, + group_c_ptrs, + # device tensor of gemm sizes. its shape is [group_size, 3] + # dim 0 is group_size, dim 1 is the values of of each gemm + group_gemm_sizes, + # device tensor of leading dimension sizes. its shape is [group_size, 3] + # dim 0 is group_size, dim 1 is the values of of each gemm + g_lds, + # number of gemms + group_size, + # number of virtual SM + NUM_SM: tl.constexpr, + # tile sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + # is the output FP8 or FP16 + FP8: tl.constexpr, +): + dtype = tl.float8e4nv if FP8 else tl.float16 + tile_idx = tl.program_id(0) + last_problem_end = 0 + for g in range(group_size): + # get the gemm size of the current problem + gm = tl.load(group_gemm_sizes + g * 3) + gn = tl.load(group_gemm_sizes + g * 3 + 1) + gk = tl.load(group_gemm_sizes + g * 3 + 2) + num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles: + # pick up a tile from the current gemm problem + lda = tl.load(g_lds + g * 3) + ldb = tl.load(g_lds + g * 3 + 1) + ldc = tl.load(g_lds + g * 3 + 2) + + a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype)) + b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype)) + c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype)) + + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[gm, gk], + strides=[lda, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + + b_desc = tl._experimental_make_tensor_descriptor( + b_ptr, + shape=[gn, gk], + strides=[ldb, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + c_desc = tl._experimental_make_tensor_descriptor( + c_ptr, + shape=[gm, gn], + strides=[ldc, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + + # iterate through the tiles in the current gemm problem + while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles): + k = gk + # figure out tile coordinates + tile_idx_in_gemm = tile_idx - last_problem_end + tile_m_idx = tile_idx_in_gemm // num_n_tiles + tile_n_idx = tile_idx_in_gemm % num_n_tiles + + # do regular gemm here + offs_am = tile_m_idx * BLOCK_SIZE_M + offs_bn = tile_n_idx * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + a = a_desc.load([offs_am, kk * BLOCK_SIZE_K]) + b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K]) + accumulator += tl.dot(a, b.T) + + offs_cm = tile_m_idx * BLOCK_SIZE_M + offs_cn = tile_n_idx * BLOCK_SIZE_N + + c = accumulator.to(dtype) + c_desc.store([offs_cm, offs_cn], c) + + # go to the next tile by advancing NUM_SM + tile_idx += NUM_SM + + # get ready to go to the next gemm problem + last_problem_end = last_problem_end + num_tiles + + +def group_gemm_tma_fn(group_A, group_B): + + assert supports_tma() + + assert len(group_A) == len(group_B) + group_size = len(group_A) + + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = group_A[i] + B = group_B[i] + assert A.shape[1] == B.shape[1] + M, K = A.shape + N, K = B.shape + C = torch.empty((M, N), device=DEVICE, dtype=A.dtype) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [M, N, K] + g_lds += [A.stride(0), B.stride(0), C.stride(0)] + # note these are device tensors + d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) + d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) + d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) + + # we use a fixed number of CTA, and it's auto-tunable + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + grid = lambda META: (META['NUM_SM'], ) + grouped_matmul_tma_kernel[grid](d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size, + FP8=torch.float8_e4m3fn == group_A[0].dtype, NUM_SM=num_sms()) + return group_C + + +group_m = [1024, 512, 256, 128] +group_n = [1024, 512, 256, 128] +group_k = [1024, 512, 256, 128] +group_A = [] +group_B = [] +group_B_T = [] +assert len(group_m) == len(group_n) +assert len(group_n) == len(group_k) +group_size = len(group_m) +for i in range(group_size): + M = group_m[i] + N = group_n[i] + K = group_k[i] + A = torch.rand((M, K), device=DEVICE, dtype=torch.float16) + B = torch.rand((K, N), device=DEVICE, dtype=torch.float16) + B_T = B.T.contiguous() + group_A.append(A) + group_B.append(B) + group_B_T.append(B_T) + +tri_out = group_gemm_fn(group_A, group_B) +ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)] +for i in range(group_size): + assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0) + +if supports_tma(): + tri_tma_out = group_gemm_tma_fn(group_A, group_B_T) + for i in range(group_size): + assert torch.allclose(ref_out[i], tri_tma_out[i], atol=1e-2, rtol=0) + + +# only launch the kernel, no tensor preparation here to remove all overhead +def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): + grid = lambda META: (META['NUM_SM'], ) + grouped_matmul_kernel[grid]( + a_ptrs, + b_ptrs, + c_ptrs, + sizes, + lds, + group_size, + ) + + +def triton_tma_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, dtype): + grid = lambda META: (META['NUM_SM'], ) + grouped_matmul_tma_kernel[grid](a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, FP8=torch.float8_e4m3fn == dtype, + NUM_SM=num_sms()) + + +def torch_perf_fn(group_A, group_B): + for a, b in zip(group_A, group_B): + torch.matmul(a, b) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['N'], + x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name` + line_arg='provider', + # argument name whose value corresponds to a different line in the plot + # possible values for `line_arg`` + line_vals=['cublas', 'triton'] + (['triton-tma'] if supports_tma() else []), + # label name for the lines + line_names=["cuBLAS", "Triton"] + (['Triton + TMA'] if supports_tma() else []), + # line styles + styles=[('green', '-'), ('blue', '-')] + ([('red', '-')] if supports_tma() else []), + ylabel="runtime(ms)", # label name for the y-axis + plot_name="group-gemm-performance", + # name for the plot. Used also as a file name for saving the plot. + args={}, + )) +def benchmark_square_matrices(N, provider): + group_size = 4 + group_A = [] + group_B = [] + group_B_T = [] + A_addrs = [] + B_addrs = [] + B_T_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = torch.rand((N, N), device=DEVICE, dtype=torch.float16) + B = torch.rand((N, N), device=DEVICE, dtype=torch.float16) + C = torch.empty((N, N), device=DEVICE, dtype=torch.float16) + B_T = B.T.contiguous() + group_A.append(A) + group_B.append(B) + group_B_T.append(B_T) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + B_T_addrs.append(B_T.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [N, N, N] + g_lds += [N, N, N] + + d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) + d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) + d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE) + d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) + + quantiles = [0.5, 0.2, 0.8] + if provider == 'cublas': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles) + if provider == 'triton-tma': + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_tma_perf_fn(d_a_ptrs, d_b_t_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size, dtype=torch. + float16), quantiles=quantiles) + return ms, max_ms, min_ms + + +@triton.testing.perf_report( + triton.testing.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['M'], + x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name` + line_arg='provider', + # argument name whose value corresponds to a different line in the plot + # possible values for `line_arg`` + line_vals=['cublas', 'triton'] + (['triton-tma'] if supports_tma() else []), + # label name for the lines + line_names=["cuBLAS", "Triton"] + (['Triton + TMA'] if supports_tma() else []), + # line styles + styles=[('green', '-'), ('blue', '-')] + ([('red', '-')] if supports_tma() else []), + ylabel="runtime(ms)", # label name for the y-axis + plot_name="group-gemm-performance-m-8192-k-8192", + # name for the plot. Used also as a file name for saving the plot. + args={}, + )) +def benchmark_batches(M, provider): + N = 8192 + K = 8192 + group_size = 4 + group_A = [] + group_B = [] + group_B_T = [] + A_addrs = [] + B_addrs = [] + B_T_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + g_T_lds = [] + group_C = [] + for i in range(group_size): + A = torch.rand((M, K), device=DEVICE, dtype=torch.float16) + B = torch.rand((K, N), device=DEVICE, dtype=torch.float16) + C = torch.empty((M, N), device=DEVICE, dtype=torch.float16) + B_T = B.T.contiguous() + group_A.append(A) + group_B.append(B) + group_B_T.append(B_T) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + B_T_addrs.append(B_T.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [M, N, K] + g_lds += [A.stride(0), B.stride(0), C.stride(0)] + g_T_lds += [A.stride(0), B_T.stride(0), C.stride(0)] + + d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) + d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) + d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE) + d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) + d_g_t_lds = torch.tensor(g_T_lds, dtype=torch.int32, device=DEVICE) + + quantiles = [0.5, 0.2, 0.8] + if provider == 'cublas': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles) + if provider == 'triton-tma': + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_tma_perf_fn(d_a_ptrs, d_b_t_ptrs, d_c_ptrs, d_g_sizes, d_g_t_lds, group_size, dtype=torch. + float16), quantiles=quantiles) + return ms, max_ms, min_ms + + +benchmark_square_matrices.run(show_plots=True, print_data=True) +benchmark_batches.run(show_plots=True, print_data=True) diff --git a/third_party/enflame/include/triton/python/tutorials/09-persistent-matmul.py b/third_party/enflame/include/triton/python/tutorials/09-persistent-matmul.py new file mode 100644 index 000000000..b65b303ae --- /dev/null +++ b/third_party/enflame/include/triton/python/tutorials/09-persistent-matmul.py @@ -0,0 +1,856 @@ +""" +Persistent Matmul +===================== +This script demonstrates persistent kernel implementations of matrix multiplication using Triton. +Various matmul methods are included, such as naive, persistent, and TMA (Tensor Memory Accelerator) based approaches. +The kernels support both FP16 and FP8 data types but the FP8 implementation is only available on CUDA devices with compute capability >= 9.0. + +Triton and cuBLAS implementations are benchmarked under different configurations and evaluated using the proton profiler. +Users can pass command-line arguments to specify matrix dimensions and iteration steps flexibly. + +.. code-block:: bash + + # FP8 + python 09-persistent-matmul.py --prec fp8 --K_range 128 1024 --K_step 128 + + # FP16 + python 09-persistent-matmul.py --prec fp16 --K_range 128 1024 --K_step 128 + +Note that currently this tutorial will fail on devices with a small shared memory size, such as RTX-4090. +""" + +import argparse + +import torch +import triton +import triton.language as tl +import triton.tools.experimental_descriptor +import triton.profiler as proton +from contextlib import contextmanager + +from typing import Optional + +if torch.cuda.is_available(): + from triton._C.libtriton import nvidia + cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) + cublas = nvidia.cublas.CublasLt(cublas_workspace) +else: + cublas = None + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def supports_tma(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + +def _matmul_launch_metadata(grid, kernel, args): + ret = {} + M, N, K = args["M"], args["N"], args["K"] + ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" + if "c_ptr" in args: + bytes_per_elem = args["c_ptr"].element_size() + else: + bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 + ret[f"flops{bytes_per_elem * 8}"] = 2. * M * N * K + ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N) + return ret + + +HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) + +if HAS_TMA_DESC: + print("TMA benchmarks will be running with experimental grid constant TMA descriptor.", ) +else: + print("TMA benchmarks will be running without grid constant TMA descriptor.", ) + + +# TmaAutoTuneHelper used in htyu's PR #5622 +class TmaAutoTuneHelper: + + # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 + class KernelParamWrapper: + + def __init__(self, desc): + self.desc = desc + + def tma_desc_cpu_ptr(self): + return self.desc.data_ptr() + + TMA_SIZE = 128 + + def __init__(self): + self.fill_1d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_1d_tma_descriptor) + self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor) + if HAS_TMA_DESC: + self.descriptors = {} + else: + self.cuda_descriptors = {} + + # Call this method outside of the lambda function for grid size + def init_tma_descriptor(self, name): + if HAS_TMA_DESC: + self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8) + else: + self.cuda_descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8) + + # Call this method inside the lambda function for grid size + def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr()) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr()) + desc_x.copy_(buf_x, non_blocking=True) + + # Call this method inside the lambda function for grid size + def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()) + desc_x.copy_(buf_x, non_blocking=True) + + def get_tma_descriptor_kernel_param(self, name): + if HAS_TMA_DESC: + assert self.descriptors[name] is not None + return self.KernelParamWrapper(self.descriptors[name]) + else: + assert self.cuda_descriptors[name] is not None + return self.cuda_descriptors[name] + + +def matmul_get_configs(): + return [ + triton.Config({'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K" : BK, "GROUP_SIZE_M" : 8}, num_stages=s, num_warps=w) \ + for BM in [128] \ + for BN in [128, 256] \ + for BK in [64,128] \ + for s in ([3,4]) \ + for w in [4,8] \ + ] + + +@triton.autotune( + configs=matmul_get_configs(), + key=["M", "N", "K"], +) +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if (c_ptr.dtype.element_ty == tl.float8e4nv): + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + M, K = a.shape + K, N = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ) + return c + + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + + +@triton.autotune( + configs=matmul_get_configs(), + key=["M", "N", "K"], +) +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: tl.constexpr, # + ): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + # NOTE: There is currently a bug in blackwell pipelining that means it can't handle a value being + # used in both the prologue and epilogue, so we duplicate the counters as a work-around. + tile_id_c = start_pid - NUM_SMS + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if (c_ptr.dtype.element_ty == tl.float8e4nv): + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul_persistent(a, b): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=dtype) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) + matmul_kernel_persistent[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + NUM_SMS=NUM_SMS, # + ) + return c + + +def matmul_tma_persistent_get_configs(): + return [ + triton.Config({'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K" : BK, "GROUP_SIZE_M" : 8, "EPILOGUE_SUBTILE" : SUBTILE}, num_stages=s, num_warps=w) \ + for BM in [128] \ + for BN in [128, 256] \ + for BK in [64, 128] \ + for s in ([3, 4]) \ + for w in [4, 8] \ + for SUBTILE in [True, False] \ + ] + + +@triton.autotune( + configs=matmul_tma_persistent_get_configs(), + key=["M", "N", "K"], +) +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # + M, N, K, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + FP8_OUTPUT: tl.constexpr, # + EPILOGUE_SUBTILE: tl.constexpr, # + NUM_SMS: tl.constexpr): # + dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + offs_am_c = pid_m * BLOCK_SIZE_M + offs_bn_c = pid_n * BLOCK_SIZE_N + + # Epilogue subtiling is a technique to break our computation and stores into multiple pieces + # By subtiling we can reduce shared memory consumption by the epilogue and instead use that + # memory to increase our stage count. + # In this case we partition the accumulator into 2 BLOCK_SIZE_M x BLOCK_SIZE_N // 2 tensors + if EPILOGUE_SUBTILE: + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + tl._experimental_descriptor_store(c_desc_ptr, c0, [offs_am_c, offs_bn_c]) + c1 = acc1.to(dtype) + tl._experimental_descriptor_store(c_desc_ptr, c1, [offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + else: + accumulator = accumulator.to(dtype) + tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am_c, offs_bn_c]) + + +def matmul_tma_persistent(a, b): + # Check constraints. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed + assert a.dtype == b.dtype, "Incompatible dtypes" + + M, K = a.shape + N, K = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + desc_helper = TmaAutoTuneHelper() + desc_helper.init_tma_descriptor("a") + desc_helper.init_tma_descriptor("b") + desc_helper.init_tma_descriptor("c") + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + def grid(META): + nonlocal desc_helper + desc_helper.fill_2d_tma_descriptor( + "a", + a.data_ptr(), + M, + K, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_K"], + a.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "b", + b.data_ptr(), + N, + K, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + b.element_size(), + ) + + store_block_n = META["BLOCK_SIZE_N"] + + if META["EPILOGUE_SUBTILE"]: + store_block_n = store_block_n // 2 + + desc_helper.fill_2d_tma_descriptor( + "c", + c.data_ptr(), + M, + N, + META["BLOCK_SIZE_M"], + store_block_n, + c.element_size(), + ) + + return (min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ), ) + + desc_a = desc_helper.get_tma_descriptor_kernel_param("a") + desc_b = desc_helper.get_tma_descriptor_kernel_param("b") + desc_c = desc_helper.get_tma_descriptor_kernel_param("c") + + matmul_kernel_tma_persistent[grid]( + desc_a, desc_b, desc_c, # + M, N, K, # + FP8_OUTPUT=dtype == torch.float8_e4m3fn, # + NUM_SMS=NUM_SMS, # + ) + return c + + +@triton.autotune( + configs=matmul_tma_persistent_get_configs(), + key=["M", "N", "K"], +) +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # + M, N, K, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + EPILOGUE_SUBTILE: tl.constexpr, # + NUM_SMS: tl.constexpr): # + # Matmul using TMA and device-side descriptor creation + dtype = c_ptr.dtype.element_ty + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + b_desc = tl._experimental_make_tensor_descriptor( + b_ptr, + shape=[N, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + c_desc = tl._experimental_make_tensor_descriptor( + c_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2], + ) + + # tile_id_c is used in the epilogue to break the dependency between + # the prologue and the epilogue + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + offs_cm = pid_m * BLOCK_SIZE_M + offs_cn = pid_n * BLOCK_SIZE_N + + if EPILOGUE_SUBTILE: + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c_desc.store([offs_cm, offs_cn], c0) + c1 = acc1.to(dtype) + c_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c1) + else: + c = accumulator.to(dtype) + c_desc.store([offs_cm, offs_cn], c) + + +def matmul_descriptor_persistent(a, b): + # Check constraints. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed + assert a.dtype == b.dtype, "Incompatible dtypes" + + M, K = a.shape + N, K = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) + matmul_kernel_descriptor_persistent[grid]( + a, b, c, # + M, N, K, # + NUM_SMS=NUM_SMS, # + ) + return c + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "NUM_CONSUMER_GROUPS": 2, + }, + num_stages=2, + num_warps=4, + num_consumer_groups=2, + num_buffers_warp_spec=3, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "NUM_CONSUMER_GROUPS": 1, + }, + num_stages=3, + num_warps=4, + num_consumer_groups=0, # disable warp specialization + num_buffers_warp_spec=3, + ), + ], + key=["M", "N", "K"], + use_cuda_graph=True, +) +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_persistent_tma_ws_cooperative_kernel( + a_desc_ptr, + b_desc_ptr, + c_desc_ptr, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + FP8_OUTPUT: tl.constexpr, # + NUM_CONSUMER_GROUPS: tl.constexpr, +): + dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 + num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)): + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + offs_k = 0 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl._experimental_descriptor_load( + a_desc_ptr, + [offs_am, offs_k], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + dtype, + ) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) + + accumulator = tl.dot(a, b.T, accumulator) + offs_k += BLOCK_SIZE_K + + c = accumulator.to(dtype) + tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn]) + + +def matmul_persistent_tma_ws_cooperative(a, b): + # Check constraints. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + M, K = a.shape + N, K = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + desc_helper = TmaAutoTuneHelper() + desc_helper.init_tma_descriptor("a") + desc_helper.init_tma_descriptor("b") + desc_helper.init_tma_descriptor("c") + + def grid(META): + nonlocal desc_helper + desc_helper.fill_2d_tma_descriptor( + "a", + a.data_ptr(), + M, + K, + META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"], + META["BLOCK_SIZE_K"], + a.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "b", + b.data_ptr(), + N, + K, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + b.element_size(), + ) + desc_helper.fill_2d_tma_descriptor( + "c", + c.data_ptr(), + M, + N, + META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"], + META["BLOCK_SIZE_N"], + c.element_size(), + ) + return (min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ), ) + + desc_a = desc_helper.get_tma_descriptor_kernel_param("a") + desc_b = desc_helper.get_tma_descriptor_kernel_param("b") + desc_c = desc_helper.get_tma_descriptor_kernel_param("c") + + matmul_persistent_tma_ws_cooperative_kernel[grid]( + desc_a, desc_b, desc_c, # + M, N, K, # + FP8_OUTPUT=dtype == torch.float8_e4m3fn, # + ) + return c + + +def cublas_matmul(a, b): + # Check constraints. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed + M, K = a.shape + N, K = b.shape + dtype = a.dtype + c = torch.empty((M, N), device=a.device, dtype=dtype) + bytes_per_elem = a.element_size() + flops_str = f"flops{bytes_per_elem * 8}" + with proton.scope(f"cublas [M={M}, N={N}, K={K}]", + {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}): + cublas.matmul(a, b, c) + return c + + +def torch_matmul(a, b): + M, K = a.shape + N, K = b.shape + bytes_per_elem = a.element_size() + flops_str = f"flops{bytes_per_elem * 8}" + with proton.scope(f"torch [M={M}, N={N}, K={K}]", + {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}): + c = torch.matmul(a, b.T) + return c + + +@contextmanager +def proton_context(): + proton.activate(0) + try: + yield + finally: + proton.deactivate(0) + + +def bench_fn(reps, warmup_reps, fn, *args): + for _ in range(warmup_reps): + fn(*args) + with proton_context(): + for _ in range(reps): + fn(*args) + + +def bench(K, dtype, reps=1000, warmup_reps=10000): + M = 8192 + N = 8192 + a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) + b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype) + + b = b.T.contiguous() + + if cublas is not None: + bench_fn(reps, warmup_reps, cublas_matmul, a, b) + if dtype == torch.float16: + bench_fn(reps, warmup_reps, torch_matmul, a, b) + bench_fn(reps, warmup_reps, matmul, a, b.T) + bench_fn(reps, warmup_reps, matmul_persistent, a, b.T) + if supports_tma(): + bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b) + bench_fn(reps, warmup_reps, matmul_persistent_tma_ws_cooperative, a, b) + bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b) + + +def validate(M, N, K, dtype): + a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) + b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype) + b = b.T.contiguous() + + torch_result = torch_matmul(a, b) if dtype == torch.float16 else None + cublas_result = cublas_matmul(a, b) if cublas is not None else None + naive_result = matmul(a, b.T) + persistent_result = matmul_persistent(a, b.T) + tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None + descriptor_persistent_result = matmul_descriptor_persistent(a, b) if supports_tma() else None + matmul_persistent_tma_ws_cooperative_result = matmul_persistent_tma_ws_cooperative(a, b) if supports_tma() else None + + if torch_result is not None: + naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16), + atol=1.0) else "❌" + if cublas_result is not None: + naive_vs_cublas = "✅" if torch.allclose(naive_result.to(torch.float16), cublas_result.to(torch.float16), + atol=1.0) else "❌" + naive_vs_persistent = "✅" if torch.allclose(naive_result.to(torch.float16), persistent_result.to(torch.float16), + atol=1.0) else "❌" + if tma_persistent_result is not None: + naive_vs_tma_persistent = "✅" if torch.allclose(cublas_result.to(torch.float16), + tma_persistent_result.to(torch.float16), atol=1.0) else "❌" + if descriptor_persistent_result is not None: + naive_vs_descriptor_persistent = "✅" if torch.allclose(cublas_result.to( + torch.float16), descriptor_persistent_result.to(torch.float16), atol=1.0) else "❌" + if matmul_persistent_tma_ws_cooperative_result is not None: + naive_vs_matmul_persistent_tma_ws_cooperative = "✅" if torch.allclose( + cublas_result.to(torch.float16), matmul_persistent_tma_ws_cooperative_result.to(torch.float16), + atol=1.0) else "❌" + print(f"M={M}, N={N}, K={K} verification naive vs: ", end="") + if torch_result is not None: + print(f"torch: {naive_vs_torch} ", end="") + if cublas_result is not None: + print(f"cublas: {naive_vs_cublas} ", end="") + print(f"persistent: {naive_vs_persistent} ", end="") + if tma_persistent_result is not None: + print(f"TMA persistent: {naive_vs_tma_persistent} ", end="") + if descriptor_persistent_result is not None: + print(f"Tensor descriptor persistent: {naive_vs_descriptor_persistent} ", end="") + if matmul_persistent_tma_ws_cooperative_result is not None: + print(f"TMA persistent with warp specialization: {naive_vs_matmul_persistent_tma_ws_cooperative} ", end="") + print() + + +def show_profile(precision, profile_name): + import triton.profiler.viewer as proton_viewer + metric_names = ["time/ms"] + if precision == 'fp8': + metric_names = ["tflop8/s"] + metric_names + elif precision == 'fp16': + metric_names = ["tflop16/s"] + metric_names + file_name = f"{profile_name}.hatchet" + tree, metrics = proton_viewer.parse(metric_names, file_name) + proton_viewer.print_tree(tree, metrics) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-K", type=int, required=False, default=512) + parser.add_argument("--K_range", type=int, nargs=2) + parser.add_argument("--K_step", type=int, default=512) + parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16") + args = parser.parse_args() + + if args.prec == 'fp8' and (not hasattr(torch, "float8_e4m3fn") or not is_cuda()): + print("This example requires CUDA with fp8 support.") + else: + dtype = torch.float8_e4m3fn if args.prec == 'fp8' else torch.float16 + + if args.K and args.K_range is None: + args.K_range = [args.K, args.K] + args.K_step = 1 # doesn't matter as long as it's not 0 + + torch.manual_seed(0) + + validate(32, 32, 32, dtype) + validate(8192, 8192, args.K_range[0], dtype) + + proton.start("matmul", hook="triton") + for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): + bench(K, dtype) + proton.finalize() + show_profile(args.prec, "matmul") diff --git a/third_party/enflame/include/triton/python/tutorials/10-block-scaled-matmul.py b/third_party/enflame/include/triton/python/tutorials/10-block-scaled-matmul.py new file mode 100644 index 000000000..a97d1118b --- /dev/null +++ b/third_party/enflame/include/triton/python/tutorials/10-block-scaled-matmul.py @@ -0,0 +1,339 @@ +""" +Block Scaled Matrix Multiplication +================================== +This tutorial demonstrates a Triton implementation of block scaled matrix multiplication +which is generic over FP4 and FP8 formats. The formats supported in the tutorial are the OCP microscaling +formats, including mxfp4 and mxfp8, as well as NVIDIA's nvfp4 format. These matrix multiplications +are accelerated by fifth generation tensor core instructions on CUDA devices with compute capability 10. + +Users can run the tutorial with each of the supported formats by passing the `--format` +argument and can benchmark the performance of each by specifying matrix dimensions +and iteration steps. + +.. code-block:: bash + + # FP4 + python 10-block-scaled-matmul.py --format nvfp4 + python 10-block-scaled-matmul.py --format mxfp4 --K_range 512 8192 --bench + + # FP8 + python 10-block-scaled-matmul.py --format mxfp8 --K_range 8192 16384 --K_step 2048 --bench + +Future updates to this tutorial which support mixed precision block scaled matmul are planned. +""" + +# %% +# Background +# ---------- +# +# CUDA devices that support PTX 8.7 and later can utlize block scaled matrix multiply +# instructions. In order for low latency access to these scale factors in the fast +# inner loop over tensor core MMAs, it is important to ensure that the blocked +# scale factors are stored in a contiguous memory layout according to their access +# pattern. +# +# The block scaled matmul tensor core instructions compute the following product: +# +# C = (A * scale_a) @ (B * scale_b) +# +# where scale_a and scale_b are the blocked scale factors for the A and B matrices. +# Under block scaled matmul, each scale factor is broadcast and multiplied across a +# vector of elements from the A and B matrices, usually along their respective K axes. +# The number of elements of A and B over which each scale factor is broadcast is herein +# refered to as the vector size (VEC_SIZE). +# +# In a linear row-major layout, the scale factors would take the shape +# +# (M, K // VEC_SIZE) and (N, K // VEC_SIZE) [1] +# +# in global memory. However, to avoid non-contiguous memory access, it is beneficial to +# instead store the scale factors in a packed block layout. For the LHS matrix this layout +# is given by +# +# (M // 32 // 4, K // VEC_SIZE // 4, 32, 4, 4) [2]. +# +# In this way, each tensor core MMA in the fast inner loop over K blocks can achieve contiguous +# access of a block of 128 rows of scale factors along the M axis, for each BLOCK_M x BLOCK_K +# subtile of the matrix A. +# +# In order to conform with Triton's language semantics for dot_scaled, the scale factors +# are prepared in the above 5D layout [2], but are then logically transposed and reshaped into +# the 2D layout [1] expected by tl.dot_scaled. +# +# For more detailed information on the scale factor layout, see +# 1. https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x +# 2. https://docs.nvidia.com/cuda/cublas/#d-block-scaling-factors-layout +# + +import argparse + +import torch +import triton +import triton.language as tl +import triton.tools.experimental_descriptor +import triton.profiler as proton +from triton.tools.experimental_descriptor import TmaDescKernelParam +from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def supports_block_scaling(): + return is_cuda() and torch.cuda.get_device_capability()[0] == 10 + + +def _matmul_launch_metadata(grid, kernel, args): + ret = {} + M, N, K = args["M"], args["N"], args["K"] + kernel_name = kernel.name + if "ELEM_PER_BYTE" and "VEC_SIZE" in args: + if args["ELEM_PER_BYTE"] == 1: + kernel_name += "_mxfp8" + elif args["ELEM_PER_BYTE"] == 2: + if args["VEC_SIZE"] == 16: + kernel_name += "_nvfp4" + elif args["VEC_SIZE"] == 32: + kernel_name += "_mxfp4" + ret["name"] = f"{kernel_name} [M={M}, N={N}, K={K}]" + ret["flops"] = 2. * M * N * K + return ret + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def block_scaled_matmul_kernel( # + a_desc, a_scale, # + b_desc, b_scale, # + c_desc, # + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, # + stride_sk: tl.constexpr, stride_sb: tl.constexpr, stride_sc: tl.constexpr, stride_sd: tl.constexpr, + output_type: tl.constexpr, # + ELEM_PER_BYTE: tl.constexpr, # + VEC_SIZE: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + BLOCK_K: tl.constexpr, # + NUM_STAGES: tl.constexpr, # + USE_2D_SCALE_LOAD: tl.constexpr): # + + if ELEM_PER_BYTE == 1: + dtype = tl.float8e4nv + elif ELEM_PER_BYTE == 2: + dtype = tl.dtype("uint8") + + if output_type == 0: + output_dtype = tl.float32 + elif output_type == 1: + output_dtype = tl.float16 + elif output_type == 2: + output_dtype = tl.float8e4nv + + tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [a_desc], dtype=tl.int32, is_pure=False, + pack=1) + tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [b_desc], dtype=tl.int32, is_pure=False, + pack=1) + tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [c_desc], dtype=tl.int32, is_pure=False, + pack=1) + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + offs_k = 0 + + ## block scale offsets + offs_sm = (pid_m * (BLOCK_M // 128) + tl.arange(0, BLOCK_M // 128)) % M + offs_sn = (pid_n * (BLOCK_N // 128) + tl.arange(0, BLOCK_N // 128)) % N + + # For now it is recommended to use 2D scale loads for better performance. + # In the future we will bring additional optimizations to either allow 5D loads, + # the use of TMAs for scale factors, or both. + if USE_2D_SCALE_LOAD: + offs_inner = tl.arange(0, (BLOCK_K // VEC_SIZE // 4) * 32 * 4 * 4) + a_scale_ptr = a_scale + offs_sm[:, None] * stride_sk + offs_inner[None, :] + b_scale_ptr = b_scale + offs_sn[:, None] * stride_sk + offs_inner[None, :] + else: + offs_sk = tl.arange(0, (BLOCK_K // VEC_SIZE // 4)) + # MN spatial offsets for 32 element blocking + offs_sc = tl.arange(0, 32) + # offsets for both scale factor column ID (along K) + # and spatial block column ID (along MN) + offs_sd = tl.arange(0, 4) + a_scale_ptr = a_scale + (offs_sm[:, None, None, None, None] * stride_sk + offs_sk[None, :, None, None, None] * + stride_sb + offs_sc[None, None, :, None, None] * stride_sc + + offs_sd[None, None, None, :, None] * stride_sd + offs_sd[None, None, None, None, :]) + b_scale_ptr = b_scale + (offs_sn[:, None, None, None, None] * stride_sk + offs_sk[None, :, None, None, None] * + stride_sb + offs_sc[None, None, :, None, None] * stride_sc + + offs_sd[None, None, None, :, None] * stride_sd + offs_sd[None, None, None, None, :]) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES): + a = tl._experimental_descriptor_load(a_desc, [offs_am, offs_k], [BLOCK_M, BLOCK_K // ELEM_PER_BYTE], dtype) + b = tl._experimental_descriptor_load(b_desc, [offs_bn, offs_k], [BLOCK_N, BLOCK_K // ELEM_PER_BYTE], dtype) + scale_a = tl.load(a_scale_ptr) + scale_b = tl.load(b_scale_ptr) + if USE_2D_SCALE_LOAD: + scale_a = scale_a.reshape(BLOCK_M // 128, BLOCK_K // VEC_SIZE // 4, 32, 4, 4) + scale_b = scale_b.reshape(BLOCK_N // 128, BLOCK_K // VEC_SIZE // 4, 32, 4, 4) + scale_a = scale_a.trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // VEC_SIZE) + scale_b = scale_b.trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K // VEC_SIZE) + if ELEM_PER_BYTE == 2: + accumulator = tl.dot_scaled(a, scale_a, "e2m1", b.T, scale_b, "e2m1", accumulator) + else: + accumulator = tl.dot_scaled(a, scale_a, "e4m3", b.T, scale_b, "e4m3", accumulator) + offs_k += BLOCK_K // ELEM_PER_BYTE + a_scale_ptr += (BLOCK_K // VEC_SIZE // 4) * stride_sb + b_scale_ptr += (BLOCK_K // VEC_SIZE // 4) * stride_sb + tl._experimental_descriptor_store(c_desc, accumulator.to(output_dtype), [offs_am, offs_bn]) + + +def block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, dtype_dst, M, N, K, configs): + output = torch.empty((M, N), dtype=dtype_dst, device="cuda") + if dtype_dst == torch.float32: + dtype_dst = 0 + elif dtype_dst == torch.float16: + dtype_dst = 1 + elif dtype_dst == torch.float8_e4m3fn: + dtype_dst = 2 + else: + raise ValueError(f"Unsupported dtype: {dtype_dst}") + + c_desc = TmaDescKernelParam(output.data_ptr(), output.shape, [configs["BLOCK_SIZE_M"], configs["BLOCK_SIZE_N"]], + output.element_size()) + + grid = (triton.cdiv(M, configs["BLOCK_SIZE_M"]) * triton.cdiv(N, configs["BLOCK_SIZE_N"]), 1) + block_scaled_matmul_kernel[grid](a_desc, a_scale, b_desc, b_scale, c_desc, M, N, K, a_scale.stride(0), + a_scale.stride(1), a_scale.stride(2), a_scale.stride(3), dtype_dst, + configs["ELEM_PER_BYTE"], configs["VEC_SIZE"], configs["BLOCK_SIZE_M"], + configs["BLOCK_SIZE_N"], configs["BLOCK_SIZE_K"], configs["num_stages"], + USE_2D_SCALE_LOAD=True) + return output + + +def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference=False): + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 256 if "fp4" in block_scale_type else 128 + VEC_SIZE = 16 if block_scale_type == "nvfp4" else 32 + assert block_scale_type in ["nvfp4", "mxfp4", "mxfp8"], f"Invalid block scale type: {block_scale_type}" + ELEM_PER_BYTE = 2 if "fp4" in block_scale_type else 1 + + device = "cuda" + a_ref = MXFP4Tensor(size=(M, K), device=device).random() + # Similar to Hopper's wgmma symmetric fp8 instruction, the RHS is expected + # to be in col-major layout for Blackwell's tcgen05.mma when using fp4 operands. + # To conform to the expected semantics of tl.dot_scaled, (M, K) x (K, N), + # the data is generated in col-major layout, packed along K for fp4, and then + # logically transposed. Note that if one operand is of fp8 precision, unlike Hopper, + # Blackwell supports both row-major and col-major layouts for the RHS matrix. + b_ref = MXFP4Tensor(size=(N, K), device=device).random() + if block_scale_type == "mxfp8": + a_ref = a_ref.to(torch.float32) + b_ref = b_ref.to(torch.float32) + a = a_ref.to(torch.float8_e4m3fn) + b = b_ref.to(torch.float8_e4m3fn) + else: + # Pack two fp4 elements per byte along K + a = a_ref.to_packed_tensor(dim=1) + b = b_ref.to_packed_tensor(dim=1) + b_ref = b_ref.to(torch.float32).T + + a_desc = TmaDescKernelParam(a.data_ptr(), a.shape, [BLOCK_M, BLOCK_K // ELEM_PER_BYTE], 1) + b_desc = TmaDescKernelParam(b.data_ptr(), b.shape, [BLOCK_N, BLOCK_K // ELEM_PER_BYTE], 1) + + epsilon = 1e-8 + a_scale = torch.rand((M // 128, K // VEC_SIZE // 4, 32, 4, 4), device=device) + epsilon + b_scale = torch.rand((N // 128, K // VEC_SIZE // 4, 32, 4, 4), device=device) + epsilon + if block_scale_type == "nvfp4": + a_scale = a_scale.to(torch.float8_e4m3fn) + b_scale = b_scale.to(torch.float8_e4m3fn) + a_scale_ref = a_scale + b_scale_ref = b_scale + elif block_scale_type in ["mxfp4", "mxfp8"]: + a_scale_ref = MXScaleTensor(a_scale) + b_scale_ref = MXScaleTensor(b_scale) + a_scale = a_scale_ref.data + b_scale = b_scale_ref.data + + reference = None + if compute_reference: + a_scale_ref = a_scale_ref.to(torch.float32) + b_scale_ref = b_scale_ref.to(torch.float32) + + def unpack_scale(packed): + num_chunk_m, num_chunk_k, _, _, _ = packed.shape + return packed.permute(0, 3, 2, 1, 4).reshape(num_chunk_m * 128, num_chunk_k * 4).contiguous() + + a_scale_ref = unpack_scale(a_scale_ref).repeat_interleave(VEC_SIZE, dim=1)[:M, :K] + b_scale_ref = unpack_scale(b_scale_ref).repeat_interleave(VEC_SIZE, dim=1).T.contiguous()[:K, :N] + reference = torch.matmul(a_ref.to(torch.float32) * a_scale_ref, b_ref * b_scale_ref) + + configs = { + "BLOCK_SIZE_M": BLOCK_M, + "BLOCK_SIZE_N": BLOCK_N, + "BLOCK_SIZE_K": BLOCK_K, + "num_stages": 4, + "ELEM_PER_BYTE": ELEM_PER_BYTE, + "VEC_SIZE": VEC_SIZE, + } + return a_desc, a_scale, b_desc, b_scale, configs, reference + + +def validate_block_scaled(M, N, K, block_scale_type="nvfp4"): + a_desc, a_scale, b_desc, b_scale, configs, reference = initialize_block_scaled(M, N, K, block_scale_type, + compute_reference=True) + output = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs) + torch.testing.assert_close(reference, output.to(torch.float32), atol=1e-3, rtol=1e-3) + print(f"✅ (pass {block_scale_type})") + + +def bench_block_scaled(K, block_scale_type="nvfp4", reps=10): + assert K % 128 == 0 + M = 8192 + N = 8192 + print(f"Problem Shape = {M}x{N}x{K}") + + a_desc, a_scale, b_desc, b_scale, configs, _ = initialize_block_scaled(M, N, K, block_scale_type, + compute_reference=False) + _ = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs) + + proton.activate(0) + for _ in range(reps): + _ = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs) + proton.deactivate(0) + print("Done benchmarking") + + +def show_profile(profile_name): + import triton.profiler.viewer as proton_viewer + metric_names = ["time/ms"] + metric_names = ["tflop/s"] + metric_names + file_name = f"{profile_name}.hatchet" + tree, metrics = proton_viewer.parse(metric_names, file_name) + proton_viewer.print_tree(tree, metrics) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--K_range", type=int, nargs=2) + parser.add_argument("--K_step", type=int, default=512) + parser.add_argument("--bench", action="store_true") + parser.add_argument("--format", type=str, choices=["mxfp4", "nvfp4", "mxfp8"], default="nvfp4") + args = parser.parse_args() + + if not supports_block_scaling(): + print("⛔ This example requires GPU support for block scaled matmul") + else: + torch.manual_seed(42) + + validate_block_scaled(8192, 8192, 8192, block_scale_type=args.format) + + if args.bench: + proton.start("block_scaled_matmul", hook="triton") + for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): + bench_block_scaled(K, reps=10000, block_scale_type=args.format) + proton.finalize() + show_profile("block_scaled_matmul") diff --git a/third_party/enflame/include/triton/python/tutorials/README.rst b/third_party/enflame/include/triton/python/tutorials/README.rst new file mode 100644 index 000000000..1dfa5f4dc --- /dev/null +++ b/third_party/enflame/include/triton/python/tutorials/README.rst @@ -0,0 +1,11 @@ +Tutorials +========= + +Below is a gallery of tutorials for writing various basic operations with Triton. It is recommended that you read through the tutorials in order, starting with the simplest one. + +To install the dependencies for the tutorials: + +.. code-block:: bash + + cd triton + pip install -e './python[tutorials]' diff --git a/third_party/enflame/include/triton/test/Analysis/test-alias.mlir b/third_party/enflame/include/triton/test/Analysis/test-alias.mlir new file mode 100644 index 000000000..660b66c96 --- /dev/null +++ b/third_party/enflame/include/triton/test/Analysis/test-alias.mlir @@ -0,0 +1,193 @@ +// RUN: triton-opt %s -mlir-disable-threading -test-print-alias -verify-diagnostics -o /dev/null + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED_T = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> +#B_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + +// There shouldn't be any aliasing with the dot op encoding. +tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return +} + +tt.func @alloc(%A : !tt.ptr) { + // expected-remark @below {{%0 -> %0}} + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +tt.func @alloc_init(%A : !tt.ptr) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + // expected-remark @below {{%0 -> %0}} + %cst1 = ttg.local_alloc %cst0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + tt.return +} + +tt.func @trans(%A : !tt.ptr) { + // expected-remark @below {{%0 -> %0}} + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%1 -> %0}} + %b = ttg.memdesc_trans %tensor {order=array} : !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x16xf16, #A_SHARED_T, #ttg.shared_memory, mutable> + tt.return +} + +tt.func @subview(%A : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory>) { + %index = arith.constant 0 : i32 + // expected-remark @below {{%0 -> %0}} + %a = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%1 -> %0}} + %cst1 = ttg.memdesc_subview %a[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +tt.func @if_alias(%i1 : i1) { + // expected-remark @below {{%0 -> %0}} + %a = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%1 -> %1}} + %b = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%2 -> %0,%1}} + %cst2 = scf.if %i1 -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> { + scf.yield %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + } else { + scf.yield %b : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + tt.return +} + +tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + // expected-remark @below {{%0 -> %0}} + %a = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%1 -> %1}} + %b = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%2 -> %2}} + %c = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%arg6 -> %0}} + // expected-remark @below {{%arg7 -> %1}} + // expected-remark @below {{%arg8 -> %2}} + // expected-remark @below {{%3#0 -> %0,%1}} + // expected-remark @below {{%3#1 -> %0,%1}} + // expected-remark @below {{%3#2 -> %0,%1,%2}} + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a, %b_shared = %b, %c_shared = %c) -> + (!ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + tt.return +} + +tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // expected-remark @below {{%0 -> %0}} + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%1 -> %1}} + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%2 -> %2}} + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%arg7 -> %0}} + // expected-remark @below {{%arg8 -> %1}} + // expected-remark @below {{%arg9 -> %2}} + // expected-remark @below {{%3#0 -> %0,%1}} + // expected-remark @below {{%3#1 -> %0,%1}} + // expected-remark @below {{%3#2 -> %0,%1,%2}} + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> + (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + scf.if %i1 { + %index = arith.constant 8 : i32 + // expected-remark @below {{%4 -> %0,%1}} + %cst0 = ttg.memdesc_subview %a_shared[%index, %index] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield + } + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + tt.return +} + +tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // expected-remark @below {{%0 -> %0}} + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%1 -> %1}} + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%2 -> %2}} + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%arg7 -> %0}} + // expected-remark @below {{%arg8 -> %1}} + // expected-remark @below {{%arg9 -> %2}} + // expected-remark @below {{%3#0 -> %0}} + // expected-remark @below {{%3#1 -> %1}} + // expected-remark @below {{%3#2 -> %2,%6,%6}} + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> + (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + // expected-remark @below {{%arg11 -> %2,%6,%6}} + // expected-remark @below {{%4 -> %2,%6,%6}} + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + // expected-remark @below {{%5 -> %6,%6}} + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> { + // expected-remark @below {{%6 -> %6}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } else { + // expected-remark @below {{%6 -> %6}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + scf.yield %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + scf.yield %a_shared, %b_shared, %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + tt.return +} + +tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, %arg4: !tt.ptr) { + %idx = arith.constant 0 : i32 + // expected-remark @below {{%0 -> %0}} + %cst = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%1 -> %1}} + %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{%2 -> %0}} + %0 = ttg.memdesc_subview %cst[%idx, %idx] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + gpu.barrier + // expected-remark @below {{%3 -> %3}} + %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) +^bb1(%1: index, %2: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, %3: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, %4: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>): // 2 preds: ^bb0, ^bb2 + %5 = arith.cmpi slt, %1, %arg1 : index + // expected-remark @below {{%5 -> %0,%1,%3}} + // expected-remark @below {{%6 -> %0,%1,%3}} + // expected-remark @below {{%7 -> %0,%1,%3}} + cf.cond_br %5, ^bb2, ^bb3 +^bb2: // pred: ^bb1 + gpu.barrier + %8 = arith.addi %1, %arg2 : index + cf.br ^bb1(%8, %4, %2, %3 : index, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) +^bb3: // pred: ^bb1 + gpu.barrier + // expected-remark @below {{%10 -> %0}} + %9 = ttg.memdesc_subview %0[%idx, %idx] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +} // module diff --git a/third_party/enflame/include/triton/test/Analysis/test-alignment.mlir b/third_party/enflame/include/triton/test/Analysis/test-alignment.mlir new file mode 100644 index 000000000..5f8616667 --- /dev/null +++ b/third_party/enflame/include/triton/test/Analysis/test-alignment.mlir @@ -0,0 +1,845 @@ +// RUN: triton-opt %s -test-print-alignment -split-input-file -o /dev/null + +tt.func @cast() { + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1}} + %cst = arith.constant 1 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1}} + %0 = arith.extsi %cst : i32 to i64 + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1}} + %cst_tensor = arith.constant dense<1> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1}} + %1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64> + tt.return +} + +// ----- + +tt.func @add() { + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1}} + %1 = arith.constant dense<1> : tensor<128xi32> + // expeted-remark @below {{contiguity = [128], divisibility = [1], constancy = [1], constant_value = }} + %2 = arith.addi %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = 127}} + %3 = arith.constant dense<127> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128}} + %4 = arith.addi %1, %3 : tensor<128xi32> + tt.return +} + +// ----- + +tt.func @addptr(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) { + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1}} + %cst1 = arith.constant 1 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %0 = tt.addptr %arg0, %cst1 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %1 = tt.addptr %arg1, %cst1 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [1], divisibility = [2], constancy = [1], constant_value = }} + %2 = tt.addptr %arg2, %cst1 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = }} + %3 = tt.addptr %arg3, %cst1 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [1], divisibility = [8], constancy = [1], constant_value = }} + %4 = tt.addptr %arg4, %cst1 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = 4}} + %cst4 = arith.constant 4 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = }} + %5 = tt.addptr %arg0, %cst4 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = }} + %6 = tt.addptr %arg1, %cst4 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [1], divisibility = [8], constancy = [1], constant_value = }} + %7 = tt.addptr %arg2, %cst4 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [1], divisibility = [16], constancy = [1], constant_value = }} + %8 = tt.addptr %arg3, %cst4 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [1], divisibility = [16], constancy = [1], constant_value = }} + %9 = tt.addptr %arg4, %cst4 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = }} + %11 = tt.expand_dims %10 {axis = 0: i32} : tensor<128xi32> -> tensor<1x128xi32> + // expeted-remark @below {{contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = }} + %12 = tt.broadcast %11 : tensor<1x128xi32> -> tensor<128x128xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = }} + %13 = tt.splat %arg0 : !tt.ptr -> tensor<128x128x!tt.ptr> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = }} + %14 = tt.splat %arg1 : !tt.ptr -> tensor<128x128x!tt.ptr> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = }} + %15 = tt.splat %arg2 : !tt.ptr -> tensor<128x128x!tt.ptr> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = }} + %16 = tt.splat %arg3 : !tt.ptr -> tensor<128x128x!tt.ptr> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = }} + %17 = tt.splat %arg4 : !tt.ptr -> tensor<128x128x!tt.ptr> + // expeted-remark @below {{contiguity = [1, 128], divisibility = [1, 16], constancy = [128, 1], constant_value = }} + %18 = tt.addptr %13, %12 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // expeted-remark @below {{contiguity = [1, 128], divisibility = [1, 16], constancy = [128, 1], constant_value = }} + %19 = tt.addptr %14, %12 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // expeted-remark @below {{contiguity = [1, 128], divisibility = [2, 16], constancy = [128, 1], constant_value = }} + %20 = tt.addptr %15, %12 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // expeted-remark @below {{contiguity = [1, 128], divisibility = [4, 16], constancy = [128, 1], constant_value = }} + %21 = tt.addptr %16, %12 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // expeted-remark @below {{contiguity = [1, 128], divisibility = [8, 16], constancy = [128, 1], constant_value = }} + %22 = tt.addptr %17, %12 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + tt.return +} + +// ----- + +tt.func @sub() { + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1}} + %1 = arith.constant dense<1> : tensor<128xi32> + // expeted-remark @below {{contiguity = [128], divisibility = [1], constancy = [1], constant_value = }} + %2 = arith.subi %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %3 = arith.subi %1, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = 129}} + %4 = arith.constant dense<129> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128}} + %5 = arith.subi %4, %1 : tensor<128xi32> + tt.return +} + +// ----- + +tt.func @mul(%arg0: i64 {tt.divisibility = 16 : i32}) { + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1}} + %1 = arith.constant dense<1> : tensor<128xi32> + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %2 = arith.muli %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128}} + %3 = arith.constant dense<128> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128}} + %4 = arith.muli %3, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2}} + %5 = arith.constant dense<2> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [256], constancy = [128], constant_value = 256}} + %6 = arith.muli %4, %5 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 4611686018427387904}} + %7 = arith.constant 4611686018427387904: i64 + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = }} + %8 = arith.muli %arg0, %7 : i64 + tt.return +} + +// ----- + +tt.func @div() { + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1}} + %1 = arith.constant dense<1> : tensor<128xi32> + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %2 = arith.divsi %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %3 = arith.divui %1, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64}} + %4 = arith.constant dense<64> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [64], constant_value = }} + %5 = arith.divsi %0, %4 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %6 = arith.divsi %4, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64}} + %7 = arith.divsi %4, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66}} + %8 = arith.constant dense<66> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [2], constant_value = }} + %9 = arith.divui %0, %8 : tensor<128xi32> + // expeted-remark @below {{contiguity = [128], divisibility = [8192], constancy = [1], constant_value = }} + %10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [64], constant_value = }} + %11 = arith.divsi %10, %4 : tensor<128xi32> + tt.return +} + + +// ----- + +tt.func @rem() { + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1}} + %1 = arith.constant dense<1> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0}} + %2 = arith.remsi %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %3 = arith.remui %1, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64}} + %4 = arith.constant dense<64> : tensor<128xi32> + // expeted-remark @below {{contiguity = [64], divisibility = [64], constancy = [1], constant_value = }} + %5 = arith.remsi %0, %4 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [64], constancy = [1], constant_value = }} + %6 = arith.remsi %4, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66}} + %7 = arith.constant dense<66> : tensor<128xi32> + // expeted-remark @below {{contiguity = [2], divisibility = [2], constancy = [1], constant_value = }} + %8 = arith.remui %0, %7 : tensor<128xi32> + tt.return +} + +// ----- + +tt.func @expanddims() { + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2}} + %1 = arith.constant dense<2> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [2], constancy = [1], constant_value = }} + %2 = arith.muli %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [2, 2], constancy = [1, 1], constant_value = }} + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + tt.return +} + +// ----- + +tt.func @broadcast() { + // expeted-remark @below {{contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64}} + %0 = arith.constant dense<64> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 1], constant_value = 64}} + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 128], constant_value = 64}} + %2 = tt.broadcast %1 : tensor<128x1xi32> -> tensor<128x128xi32> + tt.return +} + +// ----- + +tt.func @splat(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = }} + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x128x!tt.ptr> + tt.return +} + +// ----- + +tt.func @cmp_all_contiguous() { + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0}} + %1 = arith.constant dense<0> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %2 = arith.cmpi eq, %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %3 = arith.cmpi ne, %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = }} + %4 = arith.cmpi slt, %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %5 = arith.cmpi sle, %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = }} + %6 = arith.cmpi sge, %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %7 = arith.cmpi sgt, %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %8 = arith.cmpi eq, %1, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %9 = arith.cmpi ne, %1, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %10 = arith.cmpi slt, %1, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = }} + %11 = arith.cmpi sle, %1, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %12 = arith.cmpi sge, %1, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = }} + %13 = arith.cmpi sgt, %1, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8}} + %14 = arith.constant dense<8> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [8], constant_value = }} + %15 = arith.cmpi sgt, %14, %0 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1}} + %16 = arith.cmpi sgt, %14, %1 : tensor<128xi32> + tt.return +} + +tt.func @cmp_partial_contiguous() { + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8}} + %1 = arith.constant dense<8> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [32], constancy = [128], constant_value = 32}} + %3 = arith.constant dense<32> : tensor<128xi32> + // expeted-remark @below {{contiguity = [32], divisibility = [32], constancy = [1], constant_value = }} + %4 = arith.remsi %0, %3 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %5 = arith.cmpi eq, %4, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %6 = arith.cmpi ne, %4, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [8], constant_value = }} + %7 = arith.cmpi slt, %4, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %8 = arith.cmpi sle, %4, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [8], constant_value = }} + %9 = arith.cmpi sge, %4, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %10 = arith.cmpi sgt, %4, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %11 = arith.cmpi eq, %1, %4 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %12 = arith.cmpi ne, %1, %4 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %13 = arith.cmpi slt, %1, %4 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [8], constant_value = }} + %14 = arith.cmpi sle, %1, %4 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %15 = arith.cmpi sge, %1, %4 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [8], constant_value = }} + %16 = arith.cmpi sgt, %1, %4 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [16], constancy = [128], constant_value = 48}} + %17 = arith.constant dense<48> : tensor<128xi32> + // expeted-remark @below {{contiguity = [16], divisibility = [16], constancy = [1], constant_value = }} + %18 = arith.remsi %0, %17 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %19 = arith.cmpi eq, %18, %3 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %20 = arith.cmpi ne, %18, %3 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [16], constant_value = }} + %21 = arith.cmpi slt, %18, %3 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %22 = arith.cmpi sle, %18, %3 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [16], constant_value = }} + %23 = arith.cmpi sge, %18, %3 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %24 = arith.cmpi sgt, %18, %3 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %25 = arith.cmpi eq, %3, %18 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %26 = arith.cmpi ne, %3, %18 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %27 = arith.cmpi slt, %3, %18 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [16], constant_value = }} + %28 = arith.cmpi sle, %3, %18 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %29 = arith.cmpi sge, %3, %18 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [16], constant_value = + tt.return +} + +// ----- + +tt.func @logic() { + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64}} + %1 = arith.constant dense<64> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [64], constant_value = }} + %2 = arith.divsi %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8}} + %3 = arith.constant dense<8> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [8], constant_value = }} + %4 = arith.divsi %0, %3 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %5 = arith.andi %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %6 = arith.ori %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %7 = arith.xori %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [8], constant_value = }} + %8 = arith.andi %2, %4 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [8], constant_value = }} + %9 = arith.ori %2, %4 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [8], constant_value = }} + %10 = arith.xori %2, %4 : tensor<128xi32> + tt.return +} + +// ----- + +tt.func @select(%arg0 : i1, %arg1 : tensor<4xi1>) { + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0}} + %1 = arith.constant dense<0> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %2 = arith.cmpi eq, %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [128], constant_value = }} + %3 = arith.cmpi slt, %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0}} + %4 = arith.constant 0 : i1 + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0}} + %7 = tt.splat %4 : i1 -> tensor<128xi1> + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0}} + %5 = arith.select %4, %3, %7 : tensor<128xi1> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %8 = arith.select %7, %3, %2 : tensor<128xi1>, tensor<128xi1> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %9 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi1> -> tensor<128x1xi1> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 1], constant_value = }} + %10 = tt.expand_dims %3 {axis = 1 : i32} : tensor<128xi1> -> tensor<128x1xi1> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %11 = arith.select %arg0, %9, %10 : tensor<128x1xi1> + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [4], constant_value = 4}} + %cst = arith.constant dense<4> : tensor<4xi32> + // expeted-remark @below {{contiguity = [4], divisibility = [1073741824], constancy = [1], constant_value = }} + %12 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = }} + %13 = arith.muli %12, %cst : tensor<4xi32> + // expeted-remark @below {{contiguity = [4], divisibility = [16], constancy = [1], constant_value = }} + %14 = tt.make_range {end = 20 : i32, start = 16 : i32} : tensor<4xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %15 = arith.select %arg1, %12, %13 : tensor<4xi1>, tensor<4xi32> + tt.return +} + +// ----- + +tt.func @shift(%arg0: i32 {tt.divisibility = 4 : i32}) { + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [128], constant_value = }} + %s = tt.splat %arg0 : i32 -> tensor<128xi32> + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8}} + %1 = arith.constant dense<8> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4}} + %2 = arith.constant dense<4> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [256], constancy = [1], constant_value = }} + %3 = arith.shli %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %4 = arith.shrsi %0, %2 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128}} + %5 = arith.shli %1, %2 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [8], constancy = [128], constant_value = }} + %6 = arith.shli %1, %s : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %7 = arith.shrsi %0, %s : tensor<128xi32> + tt.return +} + +// ----- + +tt.func @max_min() { + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [128], divisibility = [64], constancy = [1], constant_value = }} + %1 = tt.make_range {end = 192 : i32, start = 64 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [128], divisibility = [64], constancy = [1], constant_value = }} + %2 = arith.maxsi %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [128], divisibility = [64], constancy = [1], constant_value = }} + %3 = arith.minsi %0, %1 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8}} + %4 = arith.constant dense<8> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4}} + %5 = arith.constant dense<4> : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8}} + %6 = arith.maxsi %4, %5 : tensor<128xi32> + tt.return +} + +// ----- + +tt.func @if(%i1 : i1) { + // expeted-remark @below {{contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 32], constant_value = 64}} + %cst_64 = arith.constant dense<64> : tensor<128x32xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1}} + %cst_1 = arith.constant dense<1> : tensor<128x32xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 32], constant_value = 64}} + %a = arith.muli %cst_64, %cst_1 : tensor<128x32xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = }} + %ret = scf.if %i1 -> tensor<128x32xi32> { + scf.yield %a : tensor<128x32xi32> + } else { + scf.yield %cst_1 : tensor<128x32xi32> + } + tt.return +} + +// ----- + +tt.func @for() { + // expeted-remark @below {{contiguity = [1, 1], divisibility = [4611686018427387904, 4611686018427387904], constancy = [128, 32], constant_value = 0}} + %a_init = arith.constant dense<0> : tensor<128x32xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1}} + %b_init = arith.constant dense<1> : tensor<128x32xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4}} + %c_init = arith.constant dense<4> : tensor<128x32xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128}} + %ub = arith.constant 128 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0}} + %lb = arith.constant 0 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16}} + %step = arith.constant 16 : i32 + %a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) : i32 { + // expeted-remark @below {{contiguity = [1], divisibility = [16], constancy = [1], constant_value = }} + %t = arith.addi %iv, %lb : i32 + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = }} + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = }} + // expeted-remark @below {{contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4}} + scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32> + } + tt.return +} + +// ----- + +tt.func @for_dynamic(%lb: i32 {tt.divisibility = 16 : i32}, %step: i32 {tt.divisibility = 8 : i32}, %ub: i32) { + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0}} + %c0 = arith.constant 0 : i32 + scf.for %iv = %lb to %ub step %step : i32 { + // expeted-remark @below {{contiguity = [1], divisibility = [8], constancy = [1], constant_value = }} + %t = arith.addi %iv, %c0 : i32 + } + tt.return +} + +// ----- + +tt.func @for_if(%i1: i1, %arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0}} + %c0_i32 = arith.constant 0 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1}} + %c1_i32 = arith.constant 1 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [2], constancy = [1], constant_value = 10}} + %c10_i32 = arith.constant 10 : i32 + // expeted-remark @below {{contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 64], constant_value = 64}} + %cst = arith.constant dense<64> : tensor<128x64xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = }} + %1 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> + %2 = scf.for %arg9 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg1 = %1) -> (tensor<128x64x!tt.ptr>): i32 { + // expeted-remark @below {{scf.if}} + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = }} + %3 = scf.if %i1 -> (tensor<128x64x!tt.ptr>) { + scf.yield %arg1 : tensor<128x64x!tt.ptr> + } else { + scf.yield %arg1 : tensor<128x64x!tt.ptr> + } + // expeted-remark @below {{tt.addptr}} + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = }} + %4 = tt.addptr %3, %cst : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + // expeted-remark @below {{scf.for}} + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = }} + scf.yield %1 : tensor<128x64x!tt.ptr> + } + tt.return +} + +// ----- + +tt.func @for_if_for(%i1: i1, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}) { + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0}} + %c0_i32 = arith.constant 0 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1}} + %c1_i32 = arith.constant 1 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [2], constancy = [1], constant_value = 10}} + %c10_i32 = arith.constant 10 : i32 + // expeted-remark @below {{contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 64], constant_value = 64}} + %cst = arith.constant dense<64> : tensor<128x64xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = }} + %1 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = }} + %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr> + // expeted-remark @below {{scf.for}} + // expeted-remark @below {{contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = }} + // expeted-remark @below {{scf.if}} + // expeted-remark @below {{contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = }} + // expeted-remark @below {{tt.addptr}} + // expeted-remark @below {{contiguity = [1, 1], divisibility = [8, 8], constancy = [128, 64], constant_value = }} + // expeted-remark @below {{scf.for}} + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 64], constant_value = }} + %3 = scf.for %arg9 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg2 = %1) -> (tensor<128x64x!tt.ptr>) : i32 { + %4 = scf.if %i1 -> (tensor<128x64x!tt.ptr>) { + %5 = scf.for %arg10 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg3 = %2) -> (tensor<128x64x!tt.ptr>) : i32 { + scf.yield %arg3 : tensor<128x64x!tt.ptr> + } + scf.yield %5 : tensor<128x64x!tt.ptr> + } else { + scf.yield %arg2 : tensor<128x64x!tt.ptr> + } + %6 = tt.addptr %4, %cst : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + scf.yield %1 : tensor<128x64x!tt.ptr> + } + tt.return +} + +// ----- + +tt.func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 128], constant_value = 1}} + %cst = arith.constant dense : tensor<128x128xi1> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = }} + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = }} + %3 = tt.splat %arg1 : i32 -> tensor<128x1xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = }} + %4 = arith.muli %2, %3 : tensor<128x1xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = }} + %5 = tt.splat %arg0 : !tt.ptr -> tensor<128x1x!tt.ptr> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = }} + %6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + // expeted-remark @below {{contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = }} + %7 = tt.expand_dims %1 {axis = 0 : i32}: tensor<128xi32> -> tensor<1x128xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = }} + %8 = tt.broadcast %6 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> + // expeted-remark @below {{contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = }} + %9 = tt.broadcast %7 : tensor<1x128xi32> -> tensor<128x128xi32> + // expeted-remark @below {{contiguity = [1, 128], divisibility = [4, 16], constancy = [1, 1], constant_value = }} + %10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // expeted-remark @below {{contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = }} + %11 = tt.expand_dims %0 {axis = 1 : i32}: tensor<128xi32> -> tensor<128x1xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = }} + %12 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> + // expeted-remark @below {{contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = }} + %13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + // expeted-remark @below {{contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = }} + %14 = tt.expand_dims %1 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = }} + %15 = tt.splat %arg3 : i32 -> tensor<1x128xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = }} + %16 = arith.muli %14, %15 : tensor<1x128xi32> + // expeted-remark @below {{contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 128], constant_value = }} + %17 = tt.broadcast %13 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = }} + %18 = tt.broadcast %16 : tensor<1x128xi32> -> tensor<128x128xi32> + // expeted-remark @below {{contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = }} + %19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %20 = tt.load %10, %cst, %cst_0 : tensor<128x128x!tt.ptr> + tt.store %19, %20, %cst : tensor<128x128x!tt.ptr> + tt.return +} + +// ----- + +tt.func @load_constancy(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 1 : i32}) { + // expeted-remark @below {{divisibility = [16]}} + %sixteen = arith.constant dense<16> : tensor<1024xi32> + // expeted-remark @below {{divisibility = [8]}} + %eight = arith.constant dense<8> : tensor<1024xi32> + // expeted-remark @below {{contiguity = [1024], divisibility = [1073741824], constancy = [1]}} + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // expeted-remark @below {{constancy = [16]}} + %2 = arith.divsi %1, %sixteen : tensor<1024xi32> + // expeted-remark @below {{constancy = [1024]}} + %3 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + // expeted-remark @below {{constancy = [1024]}} + %4 = tt.splat %arg1 : i32 -> tensor<1024xi32> + // expeted-remark @below {{constancy = [8]}} + %5 = arith.divsi %1, %eight : tensor<1024xi32> + // expeted-remark @below {{constancy = [8]}} + %6 = arith.cmpi slt, %5, %4 : tensor<1024xi32> + // expeted-remark @below {{constancy = [16]}} + %7 = tt.addptr %3, %2 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // expeted-remark @below {{constancy = [16]}} + %8 = tt.load %7 : tensor<1024x!tt.ptr> + // expeted-remark @below {{constancy = [8]}} + %9 = tt.load %7, %6 : tensor<1024x!tt.ptr> + tt.return +} + +// ----- + +// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer. +tt.func @store_constant_align(%addr: !tt.ptr {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) { + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %pid = tt.get_program_id x : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128}} + %c128_i32 = arith.constant 128 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [128], constancy = [1], constant_value = }} + %1 = arith.muli %pid, %c128_i32 : i32 + // expeted-remark @below {{contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = }} + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [128], constancy = [128], constant_value = }} + %3 = tt.splat %1 : i32 -> tensor<128xi32> + // expeted-remark @below {{contiguity = [128], divisibility = [128], constancy = [1], constant_value = }} + %4 = arith.addi %3, %2 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [16], constancy = [128], constant_value = }} + %5 = tt.splat %addr : !tt.ptr -> tensor<128x!tt.ptr> + // expeted-remark @below {{contiguity = [128], divisibility = [16], constancy = [1], constant_value = }} + %6 = tt.addptr %5, %4 : tensor<128x!tt.ptr>, tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [16], constancy = [128], constant_value = }} + %9 = tt.splat %n : i32 -> tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [16], constant_value = }} + %mask = arith.cmpi slt, %4, %9 : tensor<128xi32> + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %cst = arith.constant dense<0.0> : tensor<128xf32> + tt.store %5, %cst, %mask : tensor<128x!tt.ptr> + tt.return +} + +// ----- + +// This IR is dumped from vecadd test. +// Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask. +tt.func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %3 = tt.splat %1 : i32 -> tensor<64xi32> + %4 = arith.addi %3, %2 : tensor<64xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + %9 = tt.splat %n_elements : i32 -> tensor<64xi32> + // expeted-remark @below {{arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [16], constant_value = }} + %mask = arith.cmpi slt, %4, %9 : tensor<64xi32> + %11 = tt.load %6, %mask : tensor<64x!tt.ptr> + %12 = tt.load %8, %mask : tensor<64x!tt.ptr> + %13 = arith.addf %11, %12 : tensor<64xf32> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<64x!tt.ptr> + // expeted-remark @below {{tt.addptr %{{.*}} => contiguity = [64], divisibility = [16], constancy = [1], constant_value = }} + %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %15, %13, %mask : tensor<64x!tt.ptr> + tt.return +} + +// ----- + +// This IR is dumped from vecadd test. +// Note, there is no divisibility hint for %n_elements, Triton should assume its divisibility to be 1 by default. +tt.func @vecadd_mask_align_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %3 = tt.splat %1 : i32 -> tensor<64xi32> + %4 = arith.addi %3, %2 : tensor<64xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + %9 = tt.splat %n_elements : i32 -> tensor<64xi32> + // expeted-remark @below {{arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %10 = arith.cmpi slt, %4, %9 : tensor<64xi32> + %11 = tt.load %6, %10 : tensor<64x!tt.ptr> + %12 = tt.load %8, %10 : tensor<64x!tt.ptr> + %13 = arith.addf %11, %12 : tensor<64xf32> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<64x!tt.ptr> + %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %15, %13, %10 : tensor<64x!tt.ptr> + tt.return +} + +// ----- + +module { + +// We don't use function cloning here, so the alignment info is the gcd of all call sites. +tt.func @addptr_hints(%arg0: !tt.ptr) { + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1}} + %cst1 = arith.constant 1 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = }} + %1 = tt.addptr %arg0, %cst1 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = 4}} + %cst4 = arith.constant 4 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = }} + %2 = tt.addptr %arg0, %cst4 : !tt.ptr, i32 + // expeted-remark @below {{contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16}} + %cst16 = arith.constant 16 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = }} + %3 = tt.addptr %arg0, %cst4 : !tt.ptr, i32 + tt.return +} + +tt.func @kernel_div16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + tt.call @addptr_hints(%arg0) : (!tt.ptr) -> () + tt.return +} + +tt.func @kernel_div8(%arg0: !tt.ptr {tt.divisibility = 8 : i32}) { + tt.call @addptr_hints(%arg0) : (!tt.ptr) -> () + tt.return +} + +tt.func @kernel_div4(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { + tt.call @addptr_hints(%arg0) : (!tt.ptr) -> () + tt.return +} + +} + +// ----- + +module { + +// We don't use function cloning here, so the alignment info is the gcd of all call sites. +tt.func @mul(%arg0: i32) { + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1}} + %cst1 = arith.constant 1 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = }} + %1 = arith.muli %arg0, %cst1 : i32 + tt.return +} + +tt.func @bar(%arg0: i32) { + tt.call @mul(%arg0) : (i32) -> () + tt.return +} + +tt.func @foo(%arg0: i32) { + tt.call @mul(%arg0) : (i32) -> () + tt.return +} + +tt.func @call_graph(%arg0: i32) { + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = 12}} + %cst12 = arith.constant 12 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4], constancy = [1], constant_value = }} + %0 = arith.muli %arg0, %cst12 : i32 + tt.call @foo(%0) : (i32) -> () + // expeted-remark @below {{contiguity = [1], divisibility = [8], constancy = [1], constant_value = 8}} + %cst8 = arith.constant 8 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [8], constancy = [1], constant_value = }} + %1 = arith.muli %arg0, %cst8 : i32 + tt.call @bar(%1) : (i32) -> () + tt.return +} + +} + +// ----- + +tt.func @tensor_ptr(%arg0: !tt.ptr, 1>) { + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %0 = tt.load %arg0 : !tt.ptr, 1> + tt.return +} + + +// ----- + +tt.func public @chained_for(%8: tensor<128x64x!tt.ptr> {tt.divisibility = 16 : i32}) { + // expeted-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> + // expeted-remark @below {{contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16}} + %c16_i32 = arith.constant 16 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1}} + %c1_i32 = arith.constant 1 : i32 + // expeted-remark @below {{contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0}} + %c0_i32 = arith.constant 0 : i32 + // expeted-remark @below {{contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 64], constant_value = 64}} + %cst_0 = arith.constant dense<64> : tensor<128x64xi32> + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = }} + %9 = scf.for %arg7 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg8 = %8) -> (tensor<128x64x!tt.ptr>) : i32 { + %11 = tt.addptr %arg8, %cst_0 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + scf.yield %11 : tensor<128x64x!tt.ptr> + } + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = }} + // expeted-remark @below {{contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = }} + %10 = scf.for %arg7 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg8 = %9) -> (tensor<128x64x!tt.ptr>) : i32 { + tt.store %arg8, %cst : tensor<128x64x!tt.ptr> + %11 = tt.addptr %arg8, %cst_0 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + scf.yield %11 : tensor<128x64x!tt.ptr> + } + tt.return +} + +// ----- + +module { + tt.func @int_min_does_not_underflow_in_analysis() -> i64 { + // expeted-remark @below {{divisibility = [4611686018427387904]}} + %int_min = arith.constant -9223372036854775808 : i64 + tt.return %int_min : i64 + } +} diff --git a/third_party/enflame/include/triton/test/Analysis/test-allocation.mlir b/third_party/enflame/include/triton/test/Analysis/test-allocation.mlir new file mode 100644 index 000000000..d70eaba4b --- /dev/null +++ b/third_party/enflame/include/triton/test/Analysis/test-allocation.mlir @@ -0,0 +1,914 @@ +// RUN: triton-opt %s -allow-unregistered-dialect -test-print-allocation -verify-diagnostics -o /dev/null +// RUN: triton-opt %s -allow-unregistered-dialect -test-print-allocation="get-scratch-size-function=ValidConstant" 2>&1 | FileCheck %s --check-prefix=CHECK-128 + +// Check there are no lines with a size different to 128 and we have at least a line with size 128. + +// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}} +// CHECK-128: scratch offset = {{.*}}, size = 128 +// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}} + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#sliceAd0 = #ttg.slice<{dim = 0, parent = #AL}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED_T = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> +#B_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { + +// expected-remark @below {{empty}} +// expected-remark @below {{size = 0}} +tt.func @empty(%A : !tt.ptr) { + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> + tt.return +} + +// expected-remark @below {{matmul_loop}} +// expected-remark @below {{size = 4608}} +tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + // expected-remark @below {{scratch offset = 0, size = 4608}} + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + // expected-remark @below {{scratch offset = 0, size = 2304}} + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return +} + +// Shared memory is available after a tensor's liveness range ends +// expected-remark @below {{reusable}} +// expected-remark @below {{size = 4608}} +tt.func @reusable(%A : !tt.ptr) { + %cst1 = arith.constant dense : tensor<128x32xi1, #AL> + %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %cst3 = arith.constant dense : tensor<32x128xi1, #AL> + %cst4 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #AL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_ptr = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %b_ptr = tt.splat %A : !tt.ptr -> tensor<32x128x!tt.ptr, #AL> + %a1_ = tt.load %a_ptr, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> + // expected-remark @below {{scratch offset = 0, size = 4608}} + %a1 = ttg.convert_layout %a1_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %a2_ = tt.load %b_ptr, %cst3, %cst4 : tensor<32x128x!tt.ptr, #AL> + // expected-remark @below {{scratch offset = 0, size = 1088}} + %a2 = ttg.convert_layout %a2_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> + %a3_ = tt.load %a_ptr, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> + // expected-remark @below {{scratch offset = 0, size = 4608}} + %a3 = ttg.convert_layout %a3_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %c = tt.dot %a1, %a2, %c_init : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %a4_ = tt.load %b_ptr, %cst3, %cst4 : tensor<32x128x!tt.ptr, #AL> + // expected-remark @below {{scratch offset = 0, size = 1088}} + %a4 = ttg.convert_layout %a4_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> + %c1 = tt.dot %a3, %a4, %c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + tt.return +} + +// A tensor's shared memory offset is larger than it needs to accommodate further tensors +// %cst0->%c +// %cst1->%cst4 +// %cst3->%g->%h->%i +// expected-remark @below {{preallocate}} +// expected-remark @below {{size = 12288}} +tt.func @preallocate(%A : !tt.ptr) { + // expected-remark @below {{offset = 2048, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 3072, size = 512}} + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 3584, size = 512}} + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 1024}} + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 1024, size = 1024}} + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 2048, size = 1024}} + %c = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + + // expected-remark @below {{offset = 3072, size = 1024}} + %cst4 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 4096, size = 2048}} + %e = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %a : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 6144, size = 2048}} + %d = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %b : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 8192, size = 2048}} + %f = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst4 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %c : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 10240, size = 2048}} + %cst5 = ttg.local_alloc : () -> !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 4096}} + %g = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %e : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 4096}} + %h = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %d : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 4096}} + %i = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %f : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst5 : !ttg.memdesc<64x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +// Unused tensors are immediately released +// expected-remark @below {{unused}} +// expected-remark @below {{size = 1024}} +tt.func @unused(%A : !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> + // expected-remark @below {{0, size = 1024}} + %cst0 = ttg.local_alloc %cst : (tensor<32x16xf16, #AL>) -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> + // expected-remark @below {{offset = 0, size = 512}} + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 512}} + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +// cst0 is alive through the entire function, it cannot be released before the end of the function +// expected-remark @below {{longlive}} +// expected-remark @below {{size = 2560}} +tt.func @longlive(%A : !tt.ptr) { + // expected-remark @below {{offset = 2048, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 1024, size = 512}} + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 1536, size = 512}} + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 1024}} + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + + // expected-remark @below {{offset = 1024, size = 512}} + %cst3 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 1536, size = 512}} + %cst4 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 1024}} + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 512}} + %cst5 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 512}} + %cst6 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 1024}} + %c = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst3 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst4 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 1024}} + %d = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +// This example triggers graph coloring with > 1 colors. +// expected-remark @below {{multi_color}} +// expected-remark @below {{size = 1504}} +tt.func @multi_color(%A : !tt.ptr) { + // expected-remark @below {{offset = 1152, size = 64}} + %cst = ttg.local_alloc : () -> !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 1472, size = 32}} + %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 1216, size = 128}} + %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // expected-remark @below {{scratch offset = 0, size = 1152}} + %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %1 = ttg.local_load %cst : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL> + // expected-remark @below {{offset = 0, size = 128}} + %cst_3 = ttg.local_alloc : () -> !ttg.memdesc<4x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %2 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> + // expected-remark @below {{scratch offset = 0, size = 1152}} + %3 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + // expected-remark @below {{offset = 512, size = 256}} + %cst_4 = ttg.local_alloc : () -> !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 64}} + %cst_5 = ttg.local_alloc : () -> !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> + %4 = ttg.local_load %cst_5 : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL> + %5 = ttg.local_load %cst_5 : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL> + // expected-remark @below {{offset = 0, size = 512}} + %cst_6 = ttg.local_alloc : () -> !ttg.memdesc<8x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 1344, size = 128}} + %cst_7 = ttg.local_alloc : () -> !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %6 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> + // expected-remark @below {{offset = 0, size = 512}} + %cst_8 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 32}} + %cst_9 = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 512}} + %cst_10 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %7 = ttg.local_load %cst_1 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL> + %8 = ttg.local_load %cst_4 : !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x32xf16, #AL> + // expected-remark @below {{scratch offset = 0, size = 1152}} + %9 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %cst_11 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #AL> + %10 = ttg.local_load %cst_7 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL> + %cst_12 = arith.constant dense<0.000000e+00> : tensor<4x16xf16, #AL> + %cst_13 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #AL> + tt.return +} + +// This example triggers graph coloring with multiple rounds +// expected-remark @below {{multi_color_multi_rounds}} +// expected-remark @below {{size = 9504}} +tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { + // expected-remark @below {{offset = 9472, size = 32}} + %cst = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 9344, size = 128}} + %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 8192}} + %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // expected-remark @below {{scratch offset = 8192, size = 1152}} + %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %1 = ttg.local_load %cst : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> + // expected-remark @below {{offset = 8704, size = 128}} + %cst_3 = ttg.local_alloc : () -> !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %2 = ttg.local_load %cst : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL> + // expected-remark @below {{offset = 8192, size = 512}} + %cst_4 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %3 = ttg.local_load %cst_0 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL> + %4 = ttg.local_load %cst_1 : !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<1024x4xf16, #AL> + // expected-remark @below {{scratch offset = 0, size = 1152}} + %5 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %6 = ttg.local_load %cst_3 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL> + tt.return +} + + +// expected-remark @below {{alloc}} +// expected-remark @below {{size = 512}} +tt.func @alloc(%A : !tt.ptr) { + // expected-remark @below {{offset = 0, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // expected-remark @below {{offset = 0, size = 512}} + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + + +// expected-remark @below {{dealloc}} +// expected-remark @below {{size = 2048}} +tt.func @dealloc(%A : !tt.ptr) { + // expected-remark @below {{offset = 0, size = 1024}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 1024, size = 1024}} + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +// expected-remark @below {{scratch}} +// expected-remark @below {{size = 128}} +tt.func @scratch() { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + // expected-remark @below {{scratch offset = 0, size = 128}} + %b = "tt.reduce" (%cst0) ({ + ^bb0(%arg0: f16, %arg1: f16): + %add = arith.addf %arg0, %arg1 : f16 + tt.reduce.return %add : f16 + }) {axis = 0 : i32} : (tensor<16x16xf16, #AL>) -> tensor<16xf16, #sliceAd0> + tt.return +} + +// expected-remark @below {{trans}} +// expected-remark @below {{size = 1024}} +tt.func @trans(%A : !tt.ptr) { + // expected-remark @below {{offset = 0, size = 1024}} + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %b = ttg.memdesc_trans %tensor {order=array} : !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x16xf16, #A_SHARED_T, #ttg.shared_memory, mutable> + tt.return +} + + +// expected-remark @below {{extract_slice}} +// expected-remark @below {{size = 512}} +tt.func @extract_slice(%A : !tt.ptr) { + // expected-remark @below {{offset = 0, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %index = arith.constant 0 : i32 + %cst1 = ttg.memdesc_subview %cst0[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +// expected-remark @below {{atomic_scalar}} +// expected-remark @below {{size = 8196}} +tt.func @atomic_scalar(%arg3: !tt.ptr) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> + // expected-remark @below {{offset = 0, size = 8192}} + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + // expected-remark @below {{scratch offset = 8192, size = 4}} + %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + tt.return %4 : i32 +} + +// expected-remark @below {{atomic_scalar_no_use}} +// expected-remark @below {{size = 8192}} +tt.func @atomic_scalar_no_use(%arg3: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> + // expected-remark @below {{offset = 0, size = 8192}} + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + tt.return +} + +// B0 -> (B1) -> B0 +// Memory used by B1 can be reused by B0. +// expected-remark @below {{if}} +// expected-remark @below {{size = 2048}} +tt.func @if(%i1 : i1) { + // expected-remark @below {{offset = 1024, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 1536, size = 512}} + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.if %i1 { + // expected-remark @below {{offset = 0, size = 1024}} + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 1024}} + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + // expected-remark @below {{offset = 1024, size = 512}} + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 1536, size = 512}} + %cst3 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 1024}} + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst3 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +// B0 -> (B1) -> (B2) -> B0 +// Memory used by B0 cannot be reused by B1 or B2. +// expected-remark @below {{if_else}} +// expected-remark @below {{size = 3072}} +tt.func @if_else(%i1 : i1) { + // expected-remark @below {{offset = 1536, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 2048, size = 512}} + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.if %i1 { + // expected-remark @below {{offset = 0, size = 1024}} + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 1024}} + %b = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + } else { + // expected-remark @below {{offset = 1024, size = 512}} + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 2560, size = 512}} + %cst3 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 1024}} + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst2 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst3 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + // expected-remark @below {{offset = 0, size = 1024}} + %a = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst1 : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +// Block arguments and yields are memory aliases that do not trigger a new +// allocation. +// expected-remark @below {{for}} +// expected-remark @below {{size = 24576}} +tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + // expected-remark @below {{offset = 0, size = 8192}} + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 8192, size = 8192}} + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 16384, size = 8192}} + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + tt.return + // CHECK-NEXT: size = 24576 +} + +// expected-remark @below {{for_if_slice}} +// expected-remark @below {{size = 24576}} +tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // expected-remark @below {{offset = 0, size = 8192}} + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 8192, size = 8192}} + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 16384, size = 8192}} + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + scf.if %i1 { + %index = arith.constant 8 : i32 + %cst0 = ttg.memdesc_subview %a_shared[%index, %index] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield + } + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + tt.return +} + +// c0 cannot be released in the loop +// expected-remark @below {{for_use_ancestor}} +// expected-remark @below {{size = 32768}} +tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // expected-remark @below {{offset = 0, size = 8192}} + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 8192, size = 8192}} + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 16384, size = 8192}} + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + %c0 = ttg.memdesc_trans %c_shared_init {order=array} : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x128xf16, #A_SHARED_T, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 24576, size = 8192}} + %c1 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %b_shared, %a_shared: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + tt.return +} + +// a_shared_init, b_shared_init, and c_shared_init's liveness ranges are span over the entire function before cst2. +// So they cannot be reused by cst0 and cst1, but can be reused by cst2. +// expected-remark @below {{for_for_if}} +// expected-remark @below {{size = 40960}} +tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + // expected-remark @below {{offset = 0, size = 8192}} + %a_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 8192, size = 8192}} + %b_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 16384, size = 8192}} + %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) { + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> { + // expected-remark @below {{offset = 24576, size = 8192}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } else { + // expected-remark @below {{offset = 32768, size = 8192}} + %cst1 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + scf.yield %cst1 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + scf.yield %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + scf.yield %a_shared, %b_shared, %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + // expected-remark @below {{offset = 0, size = 8192}} + %cst2 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +// expected-remark @below {{alloc1}} +// expected-remark @below {{size = 512}} +tt.func @alloc1(%A : !tt.ptr) { + // expected-remark @below {{offset = 0, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +// expected-remark @below {{alloc2}} +// expected-remark @below {{size = 1024}} +tt.func @alloc2(%A : !tt.ptr) { + // expected-remark @below {{offset = 0, size = 1024}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +// expected-remark @below {{alloc3}} +// expected-remark @below {{size = 1024}} +tt.func @alloc3(%cond : i1) { + scf.if %cond { + // expected-remark @below {{offset = 0, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + } else { + // expected-remark @below {{offset = 0, size = 1024}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + tt.return +} + +// expected-remark @below {{alloc4}} +// expected-remark @below {{size = 1024}} +tt.func @alloc4(%A : !tt.ptr, %cond : i1) { + scf.if %cond { + // expected-remark @below {{virtual offset = 0, size = 1024}} + tt.call @alloc3(%cond) : (i1) -> () + } else { + // expected-remark @below {{virtual offset = 0, size = 512}} + tt.call @alloc1(%A) : (!tt.ptr) -> () + } + tt.return +} + +// expected-remark @below {{single_call}} +// expected-remark @below {{size = 512}} +tt.func @single_call(%A : !tt.ptr) { + // expected-remark @below {{offset = 0, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // expected-remark @below {{virtual offset = 0, size = 512}} + tt.call @alloc1(%A) : (!tt.ptr) -> () + tt.return +} + +// expected-remark @below {{multiple_calls}} +// expected-remark @below {{size = 1024}} +tt.func @multiple_calls(%A : !tt.ptr) { + // expected-remark @below {{offset = 0, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{virtual offset = 0, size = 512}} + tt.call @alloc1(%A) : (!tt.ptr) -> () + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // expected-remark @below {{virtual offset = 0, size = 1024}} + tt.call @alloc2(%A) : (!tt.ptr) -> () + tt.return +} + +// expected-remark @below {{if_else_calls}} +// expected-remark @below {{size = 1024}} +tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + scf.if %cond { + // expected-remark @below {{offset = 0, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 0, size = 1024}} + %cst1 = ttg.local_alloc %cst : (tensor<16x32xf16, #AL>) -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{virtual offset = 0, size = 512}} + tt.call @alloc1(%A) : (!tt.ptr) -> () + } else { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // expected-remark @below {{virtual offset = 0, size = 1024}} + tt.call @alloc2(%A) : (!tt.ptr) -> () + } + tt.return +} + +// expected-remark @below {{for_calls}} +// expected-remark @below {{size = 512}} +tt.func @for_calls(%A : !tt.ptr, %cond : i1) { + // expected-remark @below {{offset = 0, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + %lb = arith.constant 0 : index + %ub = arith.constant 10 : index + %step = arith.constant 1 : index + scf.for %iv = %lb to %ub step %step { + // expected-remark @below {{virtual offset = 0, size = 512}} + tt.call @alloc1(%A) : (!tt.ptr) -> () + } + tt.return + // CHECK-NEXT: size = 512 +} + +// expected-remark @below {{call_graph_1}} +// expected-remark @below {{size = 1024}} +tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { + // expected-remark @below {{offset = 0, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{virtual offset = 0, size = 1024}} + tt.call @alloc3(%cond) : (i1) -> () + tt.return +} + +// expected-remark @below {{call_graph_2}} +// expected-remark @below {{size = 1024}} +tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { + // expected-remark @below {{offset = 0, size = 512}} + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{virtual offset = 0, size = 1024}} + tt.call @alloc4(%A, %cond) : (!tt.ptr, i1) -> () + tt.return +} + +// expected-remark @below {{scan_alloc}} +// expected-remark @below {{size = 128}} +tt.func @scan_alloc(%x : tensor<8x16xf32, #AL>) { + // expected-remark @below {{offset = 0, size = 128}} + %a = "tt.scan"(%x) <{axis = 0 : i32, reverse = false}>({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.scan.return %add : f32 + }) : (tensor<8x16xf32, #AL>) -> tensor<8x16xf32, #AL> + tt.return +} + +// expected-remark @below {{warp_specialize_default_region}} +// expected-remark @below {{size = 33}} +// expected-remark @below {{offset = 32, size = 1}} +tt.func @warp_specialize_default_region() { + // expected-remark @below {{offset = 0, size = 16}} + %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_specialize() + default { + // expected-remark @below {{offset = 16, size = 16}} + %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_yield + } + partition0() num_warps(1) { + ttg.warp_return + } : () -> () + "use"(%0) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + + tt.return +} + +// expected-remark @below {{nonoverlapping_liveness_in_default_region}} +// expected-remark @below {{size = 33}} +// expected-remark @below {{offset = 32, size = 1}} +tt.func @nonoverlapping_liveness_in_default_region() { + // expected-remark @below {{offset = 0, size = 16}} + %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_specialize() + default { + // expected-remark @below {{offset = 16, size = 16}} + %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + "use"(%1) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + // expected-remark @below {{offset = 16, size = 16}} + %2 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + "use"(%2) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + ttg.warp_yield + } + partition0() num_warps(1) { + ttg.warp_return + } : () -> () + "use"(%0) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + + tt.return +} + +// expected-remark @below {{overlapping_liveness_in_default_region}} +// expected-remark @below {{size = 49}} +// expected-remark @below {{offset = 48, size = 1}} +tt.func @overlapping_liveness_in_default_region() { + // expected-remark @below {{offset = 0, size = 16}} + %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_specialize() + default { + // expected-remark @below {{offset = 16, size = 16}} + %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + // expected-remark @below {{offset = 32, size = 16}} + %2 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + "use"(%1) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + "use"(%2) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + ttg.warp_yield + } + partition0() num_warps(1) { + ttg.warp_return + } : () -> () + "use"(%0) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + + tt.return +} + +// expected-remark @below {{alias_through_default_outputs}} +// expected-remark @below {{size = 33}} +// expected-remark @below {{offset = 32, size = 1}} +tt.func @alias_through_default_outputs() { + // expected-remark @below {{offset = 0, size = 16}} + %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + %1 = ttg.warp_specialize() + default { + ttg.warp_yield %0 : !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + } + partition0() num_warps(1) { + ttg.warp_return + } : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + // expected-remark @below {{offset = 16, size = 16}} + %2 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + "use"(%1) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + tt.return +} + +// expected-remark @below {{implicit_capture_liveness}} +// expected-remark @below {{size = 33}} +// expected-remark @below {{offset = 32, size = 1}} +tt.func @implicit_capture_liveness() { + // expected-remark @below {{offset = 0, size = 16}} + %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_specialize() + default { + // expected-remark @below {{offset = 16, size = 16}} + %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + "use"(%0) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + ttg.warp_yield + } + partition0() num_warps(1) { + ttg.warp_return + } : () -> () + tt.return +} + +// expected-remark @below {{implicit_and_explicit_capture_liveness}} +// expected-remark @below {{size = 45}} +// expected-remark @below {{offset = 44, size = 1}} +tt.func @implicit_and_explicit_capture_liveness() { + // expected-remark @below {{offset = 0, size = 16}} + %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + // expected-remark @below {{offset = 16, size = 16}} + %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + // expected-remark @below {{offset = 32, size = 12}} + ttg.warp_specialize(%1) + default { + "use"(%0) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + ttg.warp_yield + } + partition0(%arg0: !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) num_warps(1) { + ttg.warp_return + } : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + tt.return +} + +// expected-remark @below {{explicit_capture_liveness}} +// expected-remark @below {{size = 33}} +// expected-remark @below {{offset = 32, size = 1}} +tt.func @explicit_capture_liveness() { + // expected-remark @below {{offset = 0, size = 16}} + %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + // expected-remark @below {{offset = 16, size = 12}} + ttg.warp_specialize(%0) + default { + // expected-remark @below {{offset = 16, size = 16}} + %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_yield + } + partition0(%arg0: !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) num_warps(1) { + ttg.warp_return + } : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + tt.return +} + +// expected-remark @below {{implicit_capture_liveness_default}} +// expected-remark @below {{size = 33}} +// expected-remark @below {{offset = 32, size = 1}} +tt.func @implicit_capture_liveness_default() { + // expected-remark @below {{offset = 0, size = 16}} + %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_specialize() + default { + // FIXME: This is correct, but not optimal. The memory for `%0` should be + // reused for the next allocation. The same problem happens with `scf.if`. + "use"(%0) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + // expected-remark @below {{offset = 16, size = 16}} + %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_yield + } + partition0() num_warps(1) { + ttg.warp_return + } : () -> () + tt.return +} + +// expected-remark @below {{liveness_in_partition}} +// expected-remark @below {{size = 36}} +// expected-remark @below {{offset = 32, size = 4}} +tt.func @liveness_in_partition() { + ttg.warp_specialize() + default { + ttg.warp_yield + } + partition0() num_warps(4) { + // expected-remark @below {{offset = 0, size = 16}} + %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + // expected-remark @below {{offset = 16, size = 16}} + %2 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + "use"(%0) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + ttg.warp_return + } : () -> () + tt.return +} + +// expected-remark @below {{aliasing_in_partition}} +// expected-remark @below {{size = 36}} +// expected-remark @below {{offset = 32, size = 4}} +tt.func @aliasing_in_partition() { + ttg.warp_specialize() + default { + ttg.warp_yield + } + partition0() num_warps(4) { + // expected-remark @below {{offset = 0, size = 16}} + %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + %c0_i32 = arith.constant 0 : i32 + %1 = ttg.memdesc_subview %0[%c0_i32] : !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> -> !ttg.memdesc<1xi64, #A_SHARED, #smem, mutable> + // expected-remark @below {{offset = 16, size = 16}} + %2 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + "use"(%1) : (!ttg.memdesc<1xi64, #A_SHARED, #smem, mutable>) -> () + ttg.warp_return + } : () -> () + tt.return +} + +// expected-remark @below {{partition_region_interference}} +// expected-remark @below {{size = 88}} +// expected-remark @below {{offset = 80, size = 8}} +tt.func @partition_region_interference() { + // expected-remark @below {{offset = 0, size = 16}} + %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_specialize() + default { + // expected-remark @below {{offset = 16, size = 16}} + %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_yield + } + partition0() num_warps(4) { + // expected-remark @below {{offset = 32, size = 16}} + %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + // expected-remark @below {{offset = 48, size = 16}} + %2 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + "use"(%1) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + ttg.warp_return + } + partition1() num_warps(4) { + // expected-remark @below {{offset = 64, size = 16}} + %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + // expected-remark @below {{offset = 64, size = 16}} + %2 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_return + } : () -> () + "use"(%0) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> () + tt.return +} + +// expected-remark @below {{two_different_ws}} +// expected-remark @below {{size = 17}} +// expected-remark @below {{offset = 16, size = 1}} +tt.func @two_different_ws() { + ttg.warp_specialize() + default { + // expected-remark @below {{offset = 0, size = 16}} + ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_yield + } + partition0() num_warps(1) { + ttg.warp_return + } : () -> () + ttg.warp_specialize() + default { + ttg.warp_yield + } + partition0() num_warps(1) { + // expected-remark @below {{offset = 0, size = 16}} + ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> + ttg.warp_return + } : () -> () + tt.return +} + +// expected-remark @below {{ptr_allocation_datalayout}} +// expected-remark @below {{size = 8}} +tt.func @ptr_allocation_datalayout(%arg0: !tt.ptr) { + // expected-remark @below {{offset = 0, size = 8}} + ttg.warp_specialize(%arg0) + default { + ttg.warp_yield + } : (!tt.ptr) -> () + tt.return +} + +// expected-remark @below {{tightly_packed_captures}} +// expected-remark @below {{size = 9}} +tt.func @tightly_packed_captures(%arg0: i8, %arg1: i64) { + // expected-remark @below {{offset = 0, size = 9}} + ttg.warp_specialize(%arg0, %arg1) + default { + ttg.warp_yield + } : (i8, i64) -> () + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/Analysis/test-membar.mlir b/third_party/enflame/include/triton/test/Analysis/test-membar.mlir new file mode 100644 index 000000000..de97a1a4c --- /dev/null +++ b/third_party/enflame/include/triton/test/Analysis/test-membar.mlir @@ -0,0 +1,989 @@ +// RUN: triton-opt %s -split-input-file --convert-scf-to-cf --allocate-shared-memory -test-print-membar | FileCheck %s --check-prefix=CHECK --check-prefix=CF +// RUN: triton-opt %s -split-input-file --allocate-shared-memory -test-print-membar | FileCheck %s --check-prefix=CHECK --check-prefix=SCF +// RUN: triton-opt %s -split-input-file --convert-scf-to-cf --allocate-shared-memory -test-print-membar | FileCheck %s --check-prefix=CHECK --check-prefix=CF +// RUN: triton-opt %s -split-input-file --allocate-shared-memory -test-print-membar | FileCheck %s --check-prefix=CHECK --check-prefix=SCF + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#sliceAd0 = #ttg.slice<{dim = 0, parent = #AL}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED_T = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { + +// CHECK-LABEL: matmul_loop +// There shouldn't be any membar with the dot op encoding. +tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return +} + +// CHECK-LABEL: raw_single_block +tt.func @raw_single_block(%A : !tt.ptr) { + %cst1 = arith.constant dense : tensor<128x32xi1, #AL> + %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + tt.return +} + +// CHECK-LABEL: war_single_block +tt.func @war_single_block(%A : !tt.ptr) { + %cst1 = arith.constant dense : tensor<128x32xi1, #AL> + %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: ttg.local_alloc + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + // CHECK: gpu.barrier + // CHECK-NEXT: %4 = ttg.local_alloc + %4 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + tt.return +} + +// CHECK-LABEL: war_single_block_local_store +tt.func @war_single_block_local_store(%A : !tt.ptr) { + %cst1 = arith.constant dense : tensor<128x32xi1, #AL> + %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_alloc + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<128x32xf16, #AL> + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_store + ttg.local_store %1, %2 : tensor<128x32xf16, #AL> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + +// CHECK-LABEL: scratch +tt.func @scratch(%arg: tensor<16x16xf16, #AL>) { + %cst0 = ttg.local_alloc %arg : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + // CHECK: gpu.barrier + // CHECK: tt.reduce + %1 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %2 = "tt.reduce" (%1) ({ + ^bb0(%arg1: f16, %arg2: f16): + %add = arith.addf %arg1, %arg2 : f16 + tt.reduce.return %add : f16 + }) {axis = 0 : i32} : (tensor<16x16xf16, #AL>) -> tensor<16xf16, #sliceAd0> + tt.return +} + +// CHECK-LABEL: async_wait +tt.func @async_wait(%arg: tensor<32x16xf16, #AL>) { + %cst0 = ttg.local_alloc %arg : (tensor<32x16xf16, #AL>) -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: ttg.async_wait + ttg.async_wait {num = 4 : i32} + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<32x16xf16, #AL> + tt.return +} + +// CHECK-LABEL: subview +tt.func @subview() { + %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> + %a = ttg.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> + %index = arith.constant 0 : i32 + %0 = ttg.memdesc_subview %a[%index, %index] : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_alloc + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + tt.return +} + +// CHECK-LABEL: trans +tt.func @trans(%a: !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory>) { + // CHECK-NOT: gpu.barrier + %b = ttg.memdesc_trans %a {order=array} : !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory> -> !ttg.memdesc<32x16xf16, #A_SHARED_T, #ttg.shared_memory> + tt.return +} + +// CHECK-LABEL: async_copy_global_to_local +tt.func @async_copy_global_to_local(%A : !tt.ptr, %i1 : i1) { + %index = arith.constant 0 : i32 + %a_ptr = tt.splat %A : !tt.ptr -> tensor<16x16x!tt.ptr, #AL> + %mask = tt.splat %i1 : i1 -> tensor<16x16xi1, #AL> + %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %alloc = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %subview = ttg.memdesc_subview %alloc[%index, %index, %index] : !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %1 = ttg.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %4 = ttg.local_load %subview : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> + tt.return +} +// If branch inserted a barrier for %cst0, but else didn't, then the barrier should be inserted in the parent region +// CHECK-LABEL: multi_blocks +tt.func @multi_blocks(%i1 : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.if %i1 { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + scf.yield + } else { + %cst1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield + } + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + tt.return +} + +// Both branches inserted a barrier for %cst0 and %cst1, then the barrier doesn't need to be inserted in the parent region +// CHECK-LABEL: multi_blocks_join_barrier +tt.func @multi_blocks_join_barrier(%i1 : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.if %i1 { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + scf.yield + } else { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + scf.yield + } + // CHECK-NOT: gpu.barrier + // CHECK: tt.return + %a_ = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + tt.return +} + +// Read yielded tensor requires a barrier +// CHECK-LABEL: multi_blocks_yield +tt.func @multi_blocks_yield(%i1 : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %a = scf.if %i1 -> (!ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + } else { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %3 = ttg.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %3 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + } + %a_ = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK: ttg.local_load + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: ttg.local_load + %4 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + tt.return +} + +// Even though the entry block doesn't have a barrier, the successors should have barriers +// CHECK-LABEL: multi_blocks_entry_no_shared +tt.func @multi_blocks_entry_no_shared(%i1 : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %a = scf.if %i1 -> (!ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_alloc + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: ttg.local_load + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: ttg.local_alloc + %cst1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %0 = ttg.local_load %cst1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + } else { + // CHECK-NOT: gpu.barrier + // CHECK: ttg.local_alloc + %cst1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + } + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + tt.return +} + +// Conservatively add a barrier as if the branch (%i1) is never taken +// CHECK-LABEL: multi_blocks_noelse +tt.func @multi_blocks_noelse(%i1 : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + scf.if %i1 { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + scf.yield + } + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + tt.return +} + +// Conservatively add a barrier as if the branch (%i2) is never taken +// CHECK-LABEL: multi_blocks_nested_scf +tt.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.if %i1 { + scf.if %i2 { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield + } + scf.yield + } else { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield + } + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + tt.return +} + +// CHECK-LABEL: for +tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } + tt.return +} + +// Although a_shared and b_shared are synced before entering the loop, +// they are reassociated with aliases (c_shared) and thus require a barrier. +// CHECK-LABEL: for_alias +tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %0 = ttg.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %a1 = ttg.local_load %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = ttg.local_load %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %c_shared, %a_shared, %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %r = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + tt.return +} + +// Although cst2 is not an argument of scf.yield, its memory is reused by cst1. +// So we need a barrier both before and after cst1 +// CHECK-LABEL: for_reuse +tt.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %0 = ttg.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_alloc + %a1 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %1 = ttg.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_alloc + %a2 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b2 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %2 = ttg.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %c_shared, %a_shared, %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %r = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + tt.return +} + +// CHECK-LABEL: for_reuse_nested +tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %a0 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %0 = ttg.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_alloc + %a1 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %1 = ttg.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_alloc + %a2 = ttg.local_load %a_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %b2 = ttg.local_load %b_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + %2 = ttg.local_alloc %a2 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } + scf.yield %c_shared, %a_shared, %b_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %r = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + tt.return +} + +// repeatedly write to the same shared memory addresses +// CHECK-LABEL: for_for_if +tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_alloc + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } else { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_alloc + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } + scf.yield %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } + scf.yield %a_shared, %b_shared, %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } + tt.return +} + +// c_block_next can either be converted from c_shared_init or c_shared_next_next +// CHECK-LABEL: for_if_for +tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> + %a_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %b_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %c_shared_init = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: gpu.barrier + %c_blocked = ttg.local_load %c_shared_init : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_alloc + %cst0 = ttg.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + scf.yield %cst0 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } else { + %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>) { + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %c_blocked_next = ttg.local_load %c_shared_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %c_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } + scf.yield %c_shared_ : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } + // CHECK-NOT: gpu.barrier + %b_blocked_next = ttg.local_load %b_shared: !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %a_shared, %b_shared, %c_shared_next_next : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + } + tt.return +} + +// CHECK-LABEL: cf_if +tt.func @cf_if(%i1 : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %a = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + cf.cond_br %i1, ^bb1, ^bb2 +^bb1: // pred: ^bb0 + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + cf.br ^bb2 +^bb2: // 2 preds: ^bb0, ^bb1 + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %1 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + tt.return +} + +// CHECK-LABEL: cf_if_else +tt.func @cf_if_else(%i1 : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %a = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + cf.cond_br %i1, ^bb1, ^bb2 +^bb1: // pred: ^bb0 + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + cf.br ^bb3(%1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) +^bb2: // pred: ^bb0 + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %3 = ttg.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + cf.br ^bb3(%3 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>) +^bb3(%arg: !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory>): // 2 preds: ^bb1, ^bb2 + cf.br ^bb4 +^bb4: // pred: ^bb3 + // CHECK: ttg.local_load + %4 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %5 = ttg.local_load %arg : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + tt.return +} + +// CHECK-LABEL: cf_if_else_return +tt.func @cf_if_else_return(%i1 : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %a = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %b = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + cf.cond_br %i1, ^bb1, ^bb2 +^bb1: // pred: ^bb0 + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %0 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %1 = ttg.local_load %b : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + tt.return +^bb2: // pred: ^bb0 + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %2 = ttg.local_load %a : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + %3 = ttg.local_load %b : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> -> tensor<16x16xf16, #AL> + tt.return +} + +// CHECK-LABEL: atomic_scalar +tt.func @atomic_scalar(%arg3: !tt.ptr) -> i32 { + // CHECK-NOT: gpu.barrier + %c0_i32 = arith.constant 0 : i32 + %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + tt.return %4 : i32 +} + +// CHECK-LABEL: atomic_scalar_no_use +tt.func @atomic_scalar_no_use(%arg3: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> + %2 = ttg.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> + %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %3 = ttg.local_load %2 : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory> -> tensor<128x32xf16, #AL> + tt.return +} + +} + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { + +// CHECK-LABEL: convert_layout1 +tt.func @convert_layout1(%A : !tt.ptr) { + // CHECK-NOT: gpu.barrier + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> + tt.return +} + +// CHECK-LABEL: convert_layout2 +tt.func @convert_layout2(%A : !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_load + // CHECK-NEXT: gpu.barrier + // CHECK: ttg.local_load + %3 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> + %4 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> + tt.return +} + +// CHECK-LABEL: convert_layout3 +tt.func @convert_layout3(%cond : i1) { + scf.if %cond { + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_load + // CHECK-NOT: gpu.barrier + %1 = ttg.local_load %0 : !ttg.memdesc<16x64xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x64xf16, #AL> + } else { + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_load + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: ttg.local_alloc + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + } + tt.return +} + +// CHEKC-LABEL: convert_layout4 +tt.func @convert_layout4(%A : !tt.ptr, %cond : i1) { + // CHECK-NOT: gpu.barrier + scf.if %cond { + tt.call @convert_layout3(%cond) : (i1) -> () + } else { + tt.call @convert_layout2(%A) : (!tt.ptr) -> () + } + tt.return +} + +// CHECK-LABEL: convert_layout5 +tt.func @convert_layout5(%A : !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %0 = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + %1 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> + // CHECK: ttg.local_load + // CHECK-NEXT: gpu.barrier + // CHECK: ttg.local_load + %3 = ttg.local_load %0 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> + %4 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL> + tt.return +} + +// CHECK-LABEL: single_call_sync +tt.func @single_call_sync(%A : !tt.ptr) { + %0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // CHECK: tt.call + // CHECK-NEXT: gpu.barrier + tt.call @convert_layout1(%A) : (!tt.ptr) -> () + %1 = ttg.convert_layout %0 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + tt.return +} + +// CHECK-LABEL: single_call_no_sync +// %1 can reuse %0 in convert_layout2, which has been synced +tt.func @single_call_no_sync(%A : !tt.ptr) { + // CHECK-NOT: gpu.barrier + %0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + tt.call @convert_layout5(%A) : (!tt.ptr) -> () + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #BL> + tt.return +} + +// CHECK-LABEL: multiple_calls +tt.func @multiple_calls(%A : !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + tt.call @convert_layout1(%A) : (!tt.ptr) -> () + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + tt.call @convert_layout2(%A) : (!tt.ptr) -> () + tt.return +} + +// CHECK-LABEL: if_else_calls +tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { + scf.if %cond { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %cst_ = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + // CHECK: gpu.barrier + // CHECK-NEXT: tt.call + // CHECK-NEXT: gpu.barrier + tt.call @convert_layout1(%A) : (!tt.ptr) -> () + %cst1 = ttg.local_alloc %cst_ : (tensor<16x32xf16, #AL>) -> !ttg.memdesc<16x32xf16, #A_SHARED, #ttg.shared_memory> + } else { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // CHECK: tt.call + // CHECK-NOT: gpu.barrier + tt.call @convert_layout2(%A) : (!tt.ptr) -> () + } + tt.return +} + +// CHECK-LABEL: for_calls +tt.func @for_calls(%A : !tt.ptr, %cond : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + %lb = arith.constant 0 : index + %ub = arith.constant 10 : index + %step = arith.constant 1 : index + scf.for %iv = %lb to %ub step %step { + // CHECK: gpu.barrier + // CHECK-NEXT: tt.call + tt.call @convert_layout1(%A) : (!tt.ptr) -> () + } + tt.return +} + +// CHECK-LABEL: call_graph_1 +tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> // CHECK: gpu.barrier + // CHECK-NEXT: tt.call + tt.call @convert_layout3(%cond) : (i1) -> () + tt.return +} + +// CHECK-LABEL: call_graph_2 +tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + tt.call @convert_layout4(%A, %cond) : (!tt.ptr, i1) -> () + // CHECK: tt.call + // CHECK-NEXT: gpu.barrier + %cst0 = ttg.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory> + tt.return +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 18944 : i32} { + tt.func public @kernel(%arg3: !tt.ptr, %arg4: !tt.ptr, %arg12: tensor<32x128xf16, #blocked>, %arg13: tensor<32x128xf32, #blocked>, %arg14: tensor<32x32xf16, #blocked1>) { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #blocked> + %37 = ttg.local_alloc %arg14 {allocation.offset = 0 : i32} : (tensor<32x32xf16, #blocked1>) -> !ttg.memdesc<32x32xf16, #shared, #ttg.shared_memory> + %58 = ttg.local_alloc %arg12 : (tensor<32x128xf16, #blocked>) -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> + cf.br ^bb1 + ^bb1: // 2 preds: ^bb0, ^bb1 + %59 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 + %60 = arith.cmpi eq, %59, %c0_i32 : i32 + cf.cond_br %60, ^bb1, ^bb2 + ^bb2: // pred: ^bb1 + %72 = ttg.convert_layout %arg13 : tensor<32x128xf32, #blocked> -> tensor<32x128xf32, #mma> + %73 = ttg.local_load %37 : !ttg.memdesc<32x32xf16, #shared, #ttg.shared_memory> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %74 = ttg.local_load %58 : !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %75 = tt.dot %73, %74, %72, inputPrecision = tf32 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x128xf32, #mma> + %76 = ttg.convert_layout %75 {allocation.offset = 0 : i32} : tensor<32x128xf32, #mma> -> tensor<32x128xf32, #blocked> + %77 = arith.truncf %76 : tensor<32x128xf32, #blocked> to tensor<32x128xf16, #blocked> + %78 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + tt.store %78, %77 : tensor<32x128x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 18944 : i32} { +// CHECK-LABEL: tma_special_cases +tt.func @tma_special_cases(%arg1: !tt.ptr) -> (tensor<256x64xf16, #blocked>){ + %true = arith.constant 1 : i1 + %cx = arith.constant dense<1> : tensor<32xi32> + %c0 = arith.constant 0 : i32 + %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + // CHECK: ttng.init_barrier + // CHECK-NEXT: ttng.init_barrier + ttng.init_barrier %barrier, 1 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.init_barrier %barrier, 1 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: ttng.barrier_expect + // CHECK-NEXT: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.wait_barrier + ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.ptr, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + + // CHECK-NEXT: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.barrier_expect + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: ttng.wait_barrier + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.ptr, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + + // CHECK-NEXT: ttg.local_load + %t = ttg.local_load %alloc : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #blocked> + + // CHECK-NEXT: ttng.barrier_expect + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.wait_barrier + ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.ptr, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + + // CHECK-NEXT: ttng.barrier_expect + // CHECK-NEXT: ttng.async_tma_gather + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: ttng.wait_barrier + ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.async_tma_gather %arg1[%cx, %c0] %alloc, %barrier, %true : !tt.ptr, tensor<32xi32>, i32, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, i1 + ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: ttng.inval_barrier + // CHECK-NEXT: ttng.inval_barrier + ttng.inval_barrier %barrier : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.inval_barrier %barrier : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + + tt.return %t : tensor<256x64xf16, #blocked> +} +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 18944 : i32} { +// CHECK-LABEL: tma_special_cases_cf +tt.func @tma_special_cases_cf(%arg1: !tt.ptr, %i1 : i1, %arg2: tensor<256x64xf16, #blocked>) -> (tensor<256x64xf16, #blocked>){ + %true = arith.constant 1 : i1 + %c0 = arith.constant 0 : i32 + %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + // CF: cf.cond_br + // SCF: scf.if + scf.if %i1 { + // CHECK-NOT: gpu.barrier + // CHECK: ttng.async_tma_copy_global_to_local + // CHECK-NEXT: ttng.barrier_expect + // CHECK-NEXT: ttng.wait_barrier + // CF-NEXT: cf.br + // SCF-NEXT: } else { + ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.ptr, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> + } else { + // CHECK-NOT: gpu.barrier + // CHECK: ttg.local_store + // CF-NEXT: cf.br + // SCF-NEXT: } + ttg.local_store %arg2, %alloc : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + } + // CHECK: gpu.barrier + // CHECK-NEXT: ttg.local_load + %t = ttg.local_load %alloc : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #blocked> + tt.return %t : tensor<256x64xf16, #blocked> +} +} + +// ----- + +#layout = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#smem = #ttg.shared_memory + +// CHECK-LABEL: @warp_specialize_isolated_regions +tt.func @warp_specialize_isolated_regions(%arg0: tensor<1xi64>) { + // CHECK-NEXT: local_alloc + %0 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #layout, #smem, mutable> + // CHECK-NEXT: local_store + ttg.local_store %arg0, %0 : tensor<1xi64> -> !ttg.memdesc<1xi64, #layout, #smem, mutable> + // CHECK-NEXT: barrier + // CHECK-NEXT: local_load + ttg.local_load %0 : !ttg.memdesc<1xi64, #layout, #smem, mutable> -> tensor<1xi64> + + // CHECK-NEXT: warp_specialize + ttg.warp_specialize() + default { + ttg.warp_yield + } + // CHECK: partition0 + partition0() num_warps(4) { + %cst = arith.constant dense<0> : tensor<1xi64> + // CHECK: local_alloc + %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #layout, #smem, mutable> + // CHECK-NEXT: local_store + ttg.local_store %cst, %1 : tensor<1xi64> -> !ttg.memdesc<1xi64, #layout, #smem, mutable> + // CHECK-NEXT: barrier + // CHECK-NEXT: local_load + ttg.local_load %1 : !ttg.memdesc<1xi64, #layout, #smem, mutable> -> tensor<1xi64> + // CHECK-NEXT: warp_return + ttg.warp_return + } : () -> () + + tt.return +} + +// CHECK-LABEL: @warp_specialize_into_default +tt.func @warp_specialize_into_default(%arg0: tensor<1xi64>) { + // CHECK-NEXT: local_alloc + %0 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #layout, #smem, mutable> + // CHECK-NEXT: local_store + ttg.local_store %arg0, %0 : tensor<1xi64> -> !ttg.memdesc<1xi64, #layout, #smem, mutable> + // CHECK-NEXT: warp_specialize + ttg.warp_specialize() + // CHECK-NEXT: default + default { + // CHECK-NEXT: barrier + // CHECK-NEXT: local_load + ttg.local_load %0 : !ttg.memdesc<1xi64, #layout, #smem, mutable> -> tensor<1xi64> + // CHECK-NEXT: barrier + gpu.barrier + // CHECK-NEXT: warp_yield + ttg.warp_yield + // CHECK-NEXT: () -> () + } : () -> () + // CHECK-NEXT: local_store + ttg.local_store %arg0, %0 : tensor<1xi64> -> !ttg.memdesc<1xi64, #layout, #smem, mutable> + tt.return +} + +// CHECK-LABEL: @default_region_cfg +tt.func @default_region_cfg(%arg0: tensor<1xi64>, %arg1: i1) { + // CHECK-NEXT: local_alloc + %0 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #layout, #smem, mutable> + // CHECK-NEXT: local_store + ttg.local_store %arg0, %0 : tensor<1xi64> -> !ttg.memdesc<1xi64, #layout, #smem, mutable> + // CHECK-NEXT: warp_specialize + ttg.warp_specialize() + // CHECK-NEXT: default + default { + // CHECK-NEXT: barrier + // CHECK-NEXT: local_load + ttg.local_load %0 : !ttg.memdesc<1xi64, #layout, #smem, mutable> -> tensor<1xi64> + cf.cond_br %arg1, ^bb1, ^bb2 + // CHECK: ^bb1: + ^bb1: + // CHECK-NEXT: barrier + gpu.barrier + cf.br ^bb3 + ^bb2: + cf.br ^bb3 + // CHECK: ^bb3: + ^bb3: + // CHECK-NEXT: warp_yield + ttg.warp_yield + // CHECK-NEXT: () -> () + } : () -> () + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: local_store + ttg.local_store %arg0, %0 : tensor<1xi64> -> !ttg.memdesc<1xi64, #layout, #smem, mutable> + tt.return +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-warps" = 4 : i32} { + +// CHECK-LABEL: @direct_backedge_within_loop +tt.func @direct_backedge_within_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, %arg4: !tt.ptr, %arg5: i1) { + // CHECK-NEXT: constant + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked> + // CHECK-NEXT: local_alloc + %0 = ttg.local_alloc %cst : (tensor<128x32xf16, #blocked>) -> !ttg.memdesc<128x32xf16, #shared, #smem> + // CHECK-NEXT: barrier + // CHECK-NEXT: local_load + %1 = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #blocked> + // CHECK-NEXT: br + cf.br ^bb1(%arg0, %0 : index, !ttg.memdesc<128x32xf16, #shared, #smem>) +^bb1(%2: index, %3: !ttg.memdesc<128x32xf16, #shared, #smem>): + cf.cond_br %arg5, ^bb2, ^bb3 +// CHECK: ^bb2: +^bb2: + // CHECK-NEXT: barrier + // CHECK-NEXT: local_alloc + %4 = ttg.local_alloc %cst : (tensor<128x32xf16, #blocked>) -> !ttg.memdesc<128x32xf16, #shared, #smem> + // CHECK-NEXT: br + cf.br ^bb1(%arg1, %4 : index, !ttg.memdesc<128x32xf16, #shared, #smem>) +// CHECK: ^bb3 +^bb3: + // CHECK-NEXT: barrier + // CHECK-NEXT: local_load + %5 = ttg.local_load %3 : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #blocked> + // CHECK-NEXT: cond_br + cf.cond_br %arg5, ^bb3, ^bb4 +^bb4: + tt.return +} + +} + +// ----- + +// CHECK-LABEL: tmem_copy_after_alloc +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +#tmem_scales = #ttng.tensor_memory_scales_encoding<> +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @tmem_copy_after_alloc(%arg0: tensor<1x2048xf8E4M3FN, #blocked>) { + // CHECK: local_alloc + %0 = ttg.local_alloc %arg0 {allocation.offset = 53248 : i32} : (tensor<1x2048xf8E4M3FN, #blocked>) -> !ttg.memdesc<1x2048xf8E4M3FN, #shared, #smem> + // CHECK: tmem_alloc + %1 = ttng.tmem_alloc {tensor_memory_col_offset = 256 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x16xf8E4M3FN, #tmem_scales, #ttng.tensor_memory, mutable> + // gpu.barrier + // CHECK: tmem_copy + ttng.tmem_copy %0, %1, : (!ttg.memdesc<1x2048xf8E4M3FN, #shared, #smem>, !ttg.memdesc<128x16xf8E4M3FN, #tmem_scales, #ttng.tensor_memory, mutable>) -> () + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/CMakeLists.txt b/third_party/enflame/include/triton/test/CMakeLists.txt new file mode 100644 index 000000000..edef53315 --- /dev/null +++ b/third_party/enflame/include/triton/test/CMakeLists.txt @@ -0,0 +1,31 @@ +add_subdirectory(lib) + +llvm_canonicalize_cmake_booleans( + MLIR_ENABLE_BINDINGS_PYTHON +) + +configure_lit_site_cfg( + ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in + ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py + MAIN_CONFIG + ${CMAKE_CURRENT_SOURCe_DIR}/lit.cfg.py +) + +set(TRITON_TEST_DEPENDS + triton-opt + triton-tensor-layout + triton-llvm-opt +) + +set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck") +set(LIT_ARGS "-Dfilecheck=${FILECHECK_PATH}") + +add_lit_testsuite(check-triton-lit-tests "Running the triton regression tests" + ${CMAKE_CURRENT_BINARY_DIR} + ARGS ${LIT_ARGS} + DEPENDS ${TRITON_TEST_DEPENDS} + ) + +set_target_properties(check-triton-lit-tests PROPERTIES FOLDER "Tests") + +add_lit_testsuites(TRITON-LIT-TESTS ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TRITON_TEST_DEPENDS}) diff --git a/third_party/enflame/include/triton/test/Conversion/allocate_shared_memory.mlir b/third_party/enflame/include/triton/test/Conversion/allocate_shared_memory.mlir new file mode 100644 index 000000000..8f7e4fbaa --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/allocate_shared_memory.mlir @@ -0,0 +1,17 @@ +// RUN: triton-opt %s --allocate-shared-memory | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}> + +// CHECK-LABEL: module +// CHECK-SAME: ttg.shared = 131072 : i32 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK-LABEL: @gather_op +// TODO(jeff): Optimize the lowering to reduce shared memory usage. +tt.func @gather_op(%arg0: tensor<1024x256xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked>) { + // CHECK-NEXT: allocation.offset = 0 : i32 + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x256xi32, #blocked>) -> tensor<1024x256xf32, #blocked> + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/Conversion/allocate_warp_groups.mlir b/third_party/enflame/include/triton/test/Conversion/allocate_warp_groups.mlir new file mode 100644 index 000000000..6d2c4fd9e --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/allocate_warp_groups.mlir @@ -0,0 +1,65 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-allocate-warp-groups | FileCheck %s + +// CHECK: module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 4 : i32} +module attributes {"ttg.num-warps" = 4 : i32} { +} + +// ----- + +// CHECK: module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 17 : i32} +module attributes {"ttg.num-warps" = 4 : i32} { + +tt.func @kernel() { + // CHECK: ttg.warp_specialize() attributes {warpGroupStartIds = array} + ttg.warp_specialize() + default { + ttg.warp_yield + } + partition0() num_warps(1) { + ttg.warp_return + } + partition1() num_warps(8) { + ttg.warp_return + } + partition2() num_warps(4) { + ttg.warp_return + } : () -> () + tt.return +} + +} + +// ----- + +// CHECK: module attributes {"ttg.num-warps" = 2 : i32, "ttg.total-num-warps" = 11 : i32} +module attributes {"ttg.num-warps" = 2 : i32} { + +tt.func @two_warp_specialize() { + // CHECK: ttg.warp_specialize() attributes {warpGroupStartIds = array} + ttg.warp_specialize() + default { + ttg.warp_yield + } + partition0() num_warps(2) { + ttg.warp_return + } + partition1() num_warps(1) { + ttg.warp_return + } : () -> () + + // CHECK: ttg.warp_specialize() attributes {warpGroupStartIds = array} + ttg.warp_specialize() + default { + ttg.warp_yield + } + partition0() num_warps(1) { + ttg.warp_return + } + partition1() num_warps(8) { + ttg.warp_return + } : () -> () + + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/async_ops_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/amd/async_ops_to_llvm.mlir new file mode 100644 index 000000000..bda9cb813 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/async_ops_to_llvm.mlir @@ -0,0 +1,205 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950 +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy + tt.func public @async_copy(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // We need the splat to allow the AxisAnalysis to work during lowering + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + // Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds + // CHECK-COUNT-8: rocdl.global.load.lds + // CHECK-NOT: rocdl.global.load.lds + %2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy_vectorized_2xf16 + tt.func public @async_copy_vectorized_2xf16(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // We need the index calculation so AxisAnalysis sees that we can vectorize the load + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + + // Each thread needs to load 8 elements and we load 2 (sizePerThread) per global.load.lds + // CHECK-COUNT-4: rocdl.global.load.lds + // CHECK-NOT: rocdl.global.load.lds + %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // GFX950-LABEL: async_copy_vectorized_8xf16 + tt.func public @async_copy_vectorized_8xf16(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // We need the index calculation so AxisAnalysis sees that we can vectorize the load + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + + // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds + // GFX950: rocdl.global.load.lds + // GFX950-next: llvm.return + + // GFX942 does not support vectorization > 4bytes + // expected-error@+1 {{failed to legalize operation 'ttg.async_copy_global_to_local' that was explicitly marked illegal}} + %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr, #blocked> -> <32x64xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_wait + tt.func public @async_wait(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // The waitcnt stores all counters in one i32 bits 15:14 and 3:0 store the vmcnt we have to wait on + // CHECK: rocdl.s.waitcnt -49168 + // CHECK: rocdl.barrier + ttg.async_wait {num = 0 : i32} + // CHECK: rocdl.s.waitcnt -49167 + // CHECK: rocdl.barrier + ttg.async_wait {num = 1 : i32} + // CHECK: rocdl.s.waitcnt -2 + // CHECK: rocdl.barrier + ttg.async_wait {num = 62 : i32} + // CHECK: rocdl.s.waitcnt -1 + // CHECK: rocdl.barrier + ttg.async_wait {num = 63 : i32} + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_commit_group + tt.func public @async_commit_group(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: llvm.return + ttg.async_commit_group + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy_mask_other + tt.func public @async_copy_mask_other(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>, + %arg3: i32 {tt.divisibility = 16 : i32}) { + // We need the splat to allow the AxisAnalysis to work during lowering + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c31_i32 = arith.constant 31 : i32 + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %29 = arith.addi %arg3, %c31_i32 : i32 + %30 = arith.divsi %29, %c32_i32 : i32 + %31 = arith.cmpi sgt, %30, %c0_i32 : i32 + + %51 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %52 = tt.expand_dims %51 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %65 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked> + %66 = arith.cmpi slt, %52, %65 : tensor<32x1xi32, #blocked> + %67 = tt.broadcast %66 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + + %70 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked> + %71 = arith.andi %70, %67 : tensor<32x32xi1, #blocked> + + // Each thread needs to load 4 elements and we load 1 (sizePerThread) per global.load.lds + // Note that mask/other alignment is 1 so we need 4 conditionals + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + // CHECK: llvm.cond_br + // CHECK: rocdl.global.load.lds + // CHECK-NEXT: llvm.br + // CHECK: _predicated_store + + %2 = ttg.async_copy_global_to_local %1, %arg2 mask %67 other %cst_0 : tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [16, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: async_copy_cache_mods + tt.func public @async_copy_cache_mods(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>) { + // We need the splat to allow the AxisAnalysis to work during lowering + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + // Each thread needs to load 1 element and we load 1 (sizePerThread) per global.load.lds + + // CHECK: llvm.getelementptr + // CHECK: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]] + %2 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = ca: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // CHECK: llvm.getelementptr + // CHECK: %[[aux_cg:.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cg]] + %3 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cg: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + // CHECK: llvm.getelementptr + // CHECK: %[[aux_cv:.*]] = llvm.mlir.constant(17 : i32) : i32 + // CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cv]] + %4 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cv: tensor<32x32x!tt.ptr, #blocked> -> <32x32xf16, #shared, #smem, mutable> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/buffer_load_store.mlir b/third_party/enflame/include/triton/test/Conversion/amd/buffer_load_store.mlir new file mode 100644 index 000000000..8f344e7c2 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/buffer_load_store.mlir @@ -0,0 +1,280 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s + +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load + tt.func @buffer_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}) { + // CHECK: %[[c_mask:.*]] = llvm.mlir.constant(true) : i1 + // CHECK: %[[offset:.*]] = llvm.select %[[c_mask]] + // CHECK: %[[aux:.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]], {{.*}}, %[[aux]] + %ret = amdgpu.buffer_load %arg0[%offset] cacheModifier = cs : tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_mask + tt.func @buffer_load_mask(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + // CHECK: %[[mask:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // CHECK: %[[offset:.*]] = llvm.select %[[mask]] + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]] + %ret = amdgpu.buffer_load %arg0[%offset], %7 stride = %c256_i32 : tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_mask_other + tt.func @buffer_load_mask_other(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + %other = arith.constant dense<0.00e+00> : tensor<128xf32, #blocked0> + // CHECK: %[[mask:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // CHECK: %[[offset:.*]] = llvm.select %[[mask]] + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]] + // CHECK: llvm.select + %ret = amdgpu.buffer_load %arg0[%offset], %7, %other stride = %c256_i32: tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_store + tt.func @buffer_store(%value : tensor<128xf32, #blocked0>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}) { + // CHECK: llvm.mlir.constant(true) : i1 + // CHECK: %[[c_mask:.*]] = llvm.mlir.constant(true) : i1 + // CHECK: %[[w_mask:.*]] = llvm.mlir.constant(true) : i1 + // CHECK: %[[mask:.*]] = llvm.and %[[c_mask]], %[[w_mask]] + // CHECK: %[[offset:.*]] = llvm.select %[[mask]] + // CHECK: %[[aux:.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: rocdl.raw.ptr.buffer.store {{.*}}, {{.*}}, %[[offset]], {{.*}}, %[[aux]] + %c256_i32 = arith.constant 256 : i32 + amdgpu.buffer_store %value, %arg0[%offset] cacheModifier = cs stride = %c256_i32 : tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_store_mask + tt.func @buffer_store_mask(%value : tensor<128xf32, #blocked0>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + // CHECK: %[[mask0:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // CHECK: %[[mask1:.*]] = llvm.and %[[mask0]], {{.*}} + // CHECK: %[[offset:.*]] = llvm.select %[[mask1]] + // CHECK: rocdl.raw.ptr.buffer.store {{.*}}, {{.*}}, %[[offset]] + amdgpu.buffer_store %value, %arg0[%offset], %7 stride = %N : tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_store_vec4 + tt.func @buffer_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + // Load 8 elements from A with two vectorized load instructions + // CHECK-COUNT-2: rocdl.raw.ptr.buffer.load {{.*}} : vector<4xf32> + %9 = amdgpu.buffer_load %arg0[%4] stride = %arg3 : tensor<256xf32, #blocked0> + // Load 8 elements from B with two vectorized load instructions + // CHECK-COUNT-2: rocdl.raw.ptr.buffer.load {{.*}} : vector<4xf32> + %10 = amdgpu.buffer_load %arg1[%4] stride = %arg3 : tensor<256xf32, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + // Store 8 elements into C with two vectorized store instructions + // CHECK-COUNT-2: rocdl.raw.ptr.buffer.store {{.*}} : vector<4xf32> + amdgpu.buffer_store %11, %arg2[%4] stride = %arg3 : tensor<256xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: buffer_load_8xf16 + tt.func public @buffer_load_8xf16(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %1 = tt.splat %arg2 : i32 -> tensor<256x64xi32, #blocked> + %2 = tt.expand_dims %0 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> + %4 = arith.addi %3, %1 : tensor<256x64xi32, #blocked> + // Load 16 f16 elements check for correct vector size of instruction (4xi32 = 8xf16) + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.load {{.*}} : vector<4xi32> + %5 = amdgpu.buffer_load %arg0[%4] : tensor<256x64xf16, #blocked> + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.store {{.*}} : vector<4xi32> + amdgpu.buffer_store %5, %arg0[%4] : tensor<256x64xf16, #blocked> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_store_vec1 + tt.func @buffer_load_store_vec1(%arg0: !tt.ptr , %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg3 : i32 -> tensor<256xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<256xi32, #blocked0> + // Load 8 elements from A with eight scalar load instructions + // CHECK-COUNT-8: rocdl.raw.ptr.buffer.load {{.*}} : f32 + %9 = amdgpu.buffer_load %arg0[%4], %7 stride = %arg3 : tensor<256xf32, #blocked0> + // Load 8 elements from B with two scalar load instructions + // CHECK-COUNT-8: rocdl.raw.ptr.buffer.load {{.*}} : f32 + %10 = amdgpu.buffer_load %arg1[%4], %7 stride = %arg3 : tensor<256xf32, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + // Store 8 elements into C with two scalar store instructions + // CHECK-COUNT-8: rocdl.raw.ptr.buffer.store {{.*}} : f32 + amdgpu.buffer_store %11, %arg2[%4], %7 stride = %arg3 : tensor<256xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_store_vec2 + tt.func @buffer_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr{tt.divisibility = 4 : i32}, %arg2: !tt.ptr{tt.divisibility = 4: i32}, %arg3: i32{tt.divisibility = 4: i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg3 : i32 -> tensor<256xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<256xi32, #blocked0> + // Load 8 fp16 elements from A with four i32 scalar load instructions + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.load {{.*}} : i32 + %9 = amdgpu.buffer_load %arg0[%4], %7 stride = %arg3 : tensor<256xf16, #blocked0> + // Load 8 fp16 elements from B with four i32 scalar load instructions + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.load {{.*}} : i32 + %10 = amdgpu.buffer_load %arg1[%4], %7 stride = %arg3 : tensor<256xf16, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf16, #blocked0> + // Store 8 fp16 elements into C with four i32 scalar store instructionss + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.store {{.*}} : i32 + amdgpu.buffer_store %11, %arg2[%4], %7 stride = %arg3 : tensor<256xf16, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_atomic + tt.func @buffer_atomic_rmw_fadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}, %N: i32, %values : tensor<128xf32, #blocked0>, %stride: i32 {tt.divisibility=16:i32}) { + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %mask = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + // CHECK: %[[mask0:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // There should be a single release fence before any atomics + // CHECK: llvm.fence syncscope("agent") release + // CHECK: %[[mask1:.*]] = llvm.and %[[mask0]], {{.*}} + // CHECK: %[[offset:.*]] = llvm.select %[[mask1]] + + // We will have 4 calls to fadd, since the sizePerThread is 4. We should have a vmcnt between each call. + %ret = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %values, %arg0[%offset], %mask stride = %stride : tensor<128xf32, #blocked0> + + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void + // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32 + + // There should be a single acquire fence after all of the atomics + // CHECK: llvm.fence syncscope("agent") acquire + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: buffer_load_layout_vectorization + tt.func public @buffer_load_layout_vectorization(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { + %c1_i32 = arith.constant 1 : i32 + %21 = tt.splat %c1_i32 : i32 -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %22 = tt.expand_dims %21 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %23 = tt.broadcast %22 : tensor<1x16xi32, #blocked> -> tensor<8x16xi32, #blocked> + // Each thread has to load 8xi16 + // We expect vector size == 1 (i16) for the generated loads as sizePerThread = [1, 1] + // CHECK-COUNT-8: rocdl.raw.ptr.buffer.load {{.*}}, {{.*}}, {{.*}}, {{.*}} : i16 + // CHECK-NOT: rocdl.raw.ptr.buffer.load + %24 = amdgpu.buffer_load %arg0[%23] : tensor<8x16xf16, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: strided_buffer_load_and_store + tt.func public @strided_buffer_load_and_store(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<2> : tensor<1024xi32, #blocked> + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %1 = arith.muli %0, %cst : tensor<1024xi32, #blocked> + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.load {{.*}}, {{.*}}, {{.*}}, {{.*}} : f32 + // CHECK-NOT: rocdl.raw.ptr.buffer.load + %2 = amdgpu.buffer_load %arg0[%1] : tensor<1024xf32, #blocked> + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.store {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : f32 + // CHECK-NOT: rocdl.raw.ptr.buffer.store + amdgpu.buffer_store %2, %arg1[%1] : tensor<1024xf32, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir new file mode 100644 index 000000000..66f498d4f --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir @@ -0,0 +1,164 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefixes=COMMON,GFX950 +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s --check-prefixes=COMMON,GFX942 + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // COMMON-LABEL: buffer_load_to_local_simple + tt.func public @buffer_load_to_local_simple(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: !tt.ptr, + %arg2: tensor<32x64xi32, #blocked>, + %arg3: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { + // Each thread needs to load 8 elements and we load 1 (sizePerThread) per buffer load instruction + // COMMON: rocdl.make.buffer.rsrc + // COMMON-NOT: rocdl.make.buffer.rsrc + // COMMON-COUNT-8: rocdl.raw.ptr.buffer.load.lds + // COMMON-NOT: rocdl.raw.ptr.buffer.load.lds + %65 = amdgpu.buffer_load_to_local %arg1[%arg2] into %arg3 {OpIdx = #amdgpu.OpIdx<1>} : [tensor<32x64xi32, #blocked>] -> <32x64xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 2], warpsPerCTA = [1, 32], order = [0, 1]}> +#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 1, maxPhase = 1, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.shared = 0 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // COMMON-LABEL: buffer_load_to_local_vectorized_2xf16 + tt.func public @buffer_load_to_local_vectorized_2xf16(%arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>) { + %cst = arith.constant dense<64> : tensor<1x64xi32, #blocked> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %3 = tt.broadcast %2 : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %4 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<1x64xi32, #blocked> + %6 = tt.broadcast %5 : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked> + %7 = arith.addi %3, %6 : tensor<64x64xi32, #blocked> + + // Each thread needs to load 2 elements and we load 2 (sizePerThread) per buffer load instruction + // COMMON: rocdl.make.buffer.rsrc + // COMMON-NOT: rocdl.make.buffer.rsrc + // COMMON: rocdl.raw.ptr.buffer.load.lds + // COMMON-NOT: rocdl.raw.ptr.buffer.load.lds + %8 = amdgpu.buffer_load_to_local %arg1[%7] into %arg2 : [tensor<64x64xi32, #blocked>] -> <64x64xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 32], order = [0, 1]}> +#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 1, maxPhase = 1, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.shared = 0 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // COMMON-LABEL: buffer_load_to_local_vectorized_8xf16 + tt.func public @buffer_load_to_local_vectorized_8xf16(%arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>) { + %cst = arith.constant dense<64> : tensor<1x64xi32, #blocked> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %3 = tt.broadcast %2 : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %4 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<1x64xi32, #blocked> + %6 = tt.broadcast %5 : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked> + %7 = arith.addi %3, %6 : tensor<64x64xi32, #blocked> + + // Each thread needs to load 8 elements and we load 8 (sizePerThread) per buffer load instruction + // GFX950: rocdl.make.buffer.rsrc + // GFX950-NOT: rocdl.make.buffer.rsrc + // GFX950: rocdl.raw.ptr.buffer.load.lds + // GFX950-NOT: rocdl.raw.ptr.buffer.load.lds + + // GFX942 does not support vectorization > 4bytes so we cannot lower it + // GFX942-NOT: rocdl.raw.ptr.buffer.load.lds + // GFX942: amdgpu.buffer_load_to_local + %8 = amdgpu.buffer_load_to_local %arg1[%7] into %arg2 : [tensor<64x64xi32, #blocked>] -> <64x64xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // COMMON-LABEL: buffer_load_to_local_mask_other + tt.func public @buffer_load_to_local_mask_other(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: !tt.ptr, + %arg2: tensor<32x32xi32, #blocked>, + %arg3: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>, + %arg4: i32) { + // We need the splat to allow the AxisAnalysis to work during lowering + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c31_i32 = arith.constant 31 : i32 + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %29 = arith.addi %arg4, %c31_i32 : i32 + %30 = arith.divsi %29, %c32_i32 : i32 + %31 = arith.cmpi sgt, %30, %c0_i32 : i32 + + %51 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %52 = tt.expand_dims %51 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %65 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked> + %66 = arith.cmpi slt, %52, %65 : tensor<32x1xi32, #blocked> + %67 = tt.broadcast %66 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + + %70 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked> + %71 = arith.andi %70, %67 : tensor<32x32xi1, #blocked> + + // Each thread needs to load 4 elements and we load 1 (sizePerThread) per buffer load instruction + // Note that mask/other alignment is 1 so we need 4 conditionals + + // COMMON: rocdl.raw.ptr.buffer.load.lds + // COMMON: _predicated_store + + // COMMON: rocdl.raw.ptr.buffer.load.lds + // COMMON: _predicated_store + + // COMMON: rocdl.raw.ptr.buffer.load.lds + // COMMON: _predicated_store + + // COMMON: rocdl.raw.ptr.buffer.load.lds + // COMMON: _predicated_store + + // COMMON-NOT: rocdl.raw.ptr.buffer.load.lds + // COMMON-NOT: _predicated_store + + amdgpu.buffer_load_to_local %arg1[%arg2] mask=%67 other=%cst_0 into %arg3 {OpIdx = #amdgpu.OpIdx<1>} : [tensor<32x32xi32, #blocked>] tensor<32x32xf16, #blocked> -> <32x32xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + // COMMON-LABEL: buffer_load_to_local_cache_mods + tt.func public @buffer_load_to_local_cache_mods(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg2: !ttg.memdesc<64xf32, #shared, #smem, mutable>) { + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked> + // The first constant 0 skips the LDS offset which is also 0 + // COMMON: llvm.getelementptr + // COMMON: llvm.mlir.constant(0 : i32) : i32 + // COMMON: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32 + // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]] + %1 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = ca into %arg2: [tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable> + // COMMON: llvm.getelementptr + // COMMON: %[[aux_cg:.*]] = llvm.mlir.constant(3 : i32) : i32 + // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cg]] + %2 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = cg into %arg2: [tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable> + // COMMON: llvm.getelementptr + // COMMON: %[[aux_cv:.*]] = llvm.mlir.constant(17 : i32) : i32 + // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cv]] + %3 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = cv into %arg2: [tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable> + + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/builtin_func_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/amd/builtin_func_to_llvm.mlir new file mode 100644 index 000000000..645881730 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/builtin_func_to_llvm.mlir @@ -0,0 +1,12 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm="ftz=True" | FileCheck %s --check-prefix=LLVM_FTZ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm="ftz=False" | FileCheck %s --check-prefix=LLVM_NO_FTZ + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @test_fast_expf(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ: llvm.amdgcn.exp2.f32 + // LLVM_NO_FTZ: llvm.exp2.f32 + %0 = tt.extern_elementwise %arg0 {libname = "libdevice", libpath = "", pure = true, symbol = "__triton_hip_fast_expf"} : (tensor<64xf32, #blocked>) -> tensor<64xf32, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/compute-base-ptr.mlir b/third_party/enflame/include/triton/test/Conversion/amd/compute-base-ptr.mlir new file mode 100644 index 000000000..5caa90afc --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/compute-base-ptr.mlir @@ -0,0 +1,20 @@ +// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}> +#shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @local_load_offset + tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) { + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1) + %1 = ttg.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> loc(#loc2) + // This catches base ptr calculation in the computeBasePtr, checks if the gep has correct element type. + // CHECK: llvm.getelementptr {{.*}} (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 local_load:3:0 + %2 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3) + tt.return + } +} +#loc1 = loc("conert_layout":1:0) +#loc2 = loc("local_alloc":2:0) +#loc3 = loc("local_load":3:0) diff --git a/third_party/enflame/include/triton/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir b/third_party/enflame/include/triton/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir new file mode 100644 index 000000000..848e13118 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir @@ -0,0 +1,33 @@ +// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx942 | FileCheck %s + +// CHECK-DAG: #[[DST_ENC:.+]] = #ttg.blocked<{{.*}}> +// CHECK-DAG: #[[SRC_ENC:.+]] = #ttg.amd_mfma<{{.*}}> +// CHECK-DAG: #[[TMP_ENC:.+]] = #ttg.amd_mfma<{{.*}}> +// CHECK: large_tensor_conversion +#src = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = false}> +#dst = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @large_tensor_conversion(%arg0: tensor<128x128xf32, #src>) { + // CHECK: %[[TMP:.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, #[[SRC_ENC]]> -> tensor<128x128xf32, #[[TMP_ENC]]> + // CHECK: {{.*}} = ttg.convert_layout %[[TMP]] : tensor<128x128xf32, #[[TMP_ENC]]> -> tensor<128x128xf32, #[[DST_ENC]]> + %0 = ttg.convert_layout %arg0 : tensor<128x128xf32, #src> -> tensor<128x128xf32, #dst> + tt.return + } +} + +// ----- + +// CHECK-DAG: #[[DST_ENC:.+]] = #ttg.blocked<{{.*}}> +// CHECK-DAG: #[[SRC_ENC:.+]] = #ttg.amd_mfma<{{.*}}> +// CHECK-DAG: #[[TMP_ENC:.+]] = #ttg.amd_mfma<{{.*}}> +// CHECK: large_tensor_3d_conversion +#src = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 1, 2], instrShape = [32, 32], isTransposed = false}> +#dst = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 64, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @large_tensor_3d_conversion(%arg0: tensor<2x128x64xf32, #src>) { + // CHECK: %[[TMP:.*]] = ttg.convert_layout {{.*}} : tensor<2x128x64xf32, #[[SRC_ENC]]> -> tensor<2x128x64xf32, #[[TMP_ENC]]> + // CHECK: {{.*}} = ttg.convert_layout %[[TMP]] : tensor<2x128x64xf32, #[[TMP_ENC]]> -> tensor<2x128x64xf32, #[[DST_ENC]]> + %0 = ttg.convert_layout %arg0 : tensor<2x128x64xf32, #src> -> tensor<2x128x64xf32, #dst> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/decompose-unsupported-conversions.mlir b/third_party/enflame/include/triton/test/Conversion/amd/decompose-unsupported-conversions.mlir new file mode 100644 index 000000000..d4930d24f --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/decompose-unsupported-conversions.mlir @@ -0,0 +1,107 @@ +// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions | FileCheck %s + +// CHECK: #[[$BLOCKED:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #ttg.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #ttg.swizzled_shared<{{.*}}> +// CHECK-LABEL: wmma_to_wmma_dot_op +#mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1130", "ttg.threads-per-warp" = 32 : i32} { + tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) { + // CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<16x16xf16, #[[$SHARED]], #smem> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + tt.return + } +} + +// ----- + +// CHECK: #[[$BLOCKED:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #ttg.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #ttg.swizzled_shared<{{.*}}> +// CHECK-LABEL: wmma_to_wmma_dot3d_op +#mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) { + // CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<2x16x16xf16, #[[$SHARED]], #smem> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> + %0 = ttg.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx1130 +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1130", "ttg.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot_op_shortcut_gfx1130(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: ttg.local_alloc + // CHECK: ttg.convert_layout + // CHECK-NOT: ttg.local_alloc + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx942 +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @blocked_to_dot_op_shortcut_gfx942(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: ttg.local_alloc + // CHECK: ttg.convert_layout + // CHECK-NOT: ttg.local_alloc + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_elems_gfx942 +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @neg_blocked_to_dot_op_incompatible_elems_gfx942(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: ttg.convert_layout + // CHECK: ttg.local_alloc + // CHECK: ttg.local_load + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_threads_gfx942 +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @neg_blocked_to_dot_op_incompatible_threads_gfx942(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: ttg.convert_layout + // CHECK: ttg.local_alloc + // CHECK: ttg.local_load + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_warp_gfx942 +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx942(%arg0: tensor<128x128xf16, #blocked>) { + // CHECK-NOT: ttg.convert_layout + // CHECK: ttg.local_alloc + // CHECK: ttg.local_load + %0 = ttg.convert_layout %arg0 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/dedup-by-constancy.mlir b/third_party/enflame/include/triton/test/Conversion/amd/dedup-by-constancy.mlir new file mode 100644 index 000000000..66a224bce --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/dedup-by-constancy.mlir @@ -0,0 +1,30 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s + +// CHECK-LABEL: dedup_by_constancy_mfma +// CHECK-COUNT-4: llvm.icmp "slt" +// CHECK-NOT: llvm.icmp "slt" +// Here is why we expect exactly 4 icmp: +// For a 32x32 tensor A with mfma layout, each thread holds 16 elements, which are divided +// into 4 groups. E.g. thread 0 holds elements A[0:3,0], A[8:11,0], A[16:19,0], and A[24:27,0]. +// In this example, constancy of the tensor is 16 for dim 0, meaning A[0:15,0] have same values +// and A[16:31,0] have same values. Therefore, for thread 0, the first 8 elements are duplicated +// and the last 8 elements are duplicated. Ideally, thread 0 only needs two icmp, one for the +// first 8 elements and the other for the last 8 elements. In practice, the dedup analysis +// only allows duplication within each group of 4 elemnets. Therefore, we expect 4 icmp, one +// for each group of 4 elements. +// In the future, we can reduce the icmp to 2 in such case. +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @dedup_by_constancy_mfma(%arg0: i32 {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %1 = tt.splat %arg0 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %2 = arith.cmpi slt, %0, %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi1, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi1, #mma> + %4 = tt.broadcast %3 : tensor<32x1xi1, #mma> -> tensor<32x32xi1, #mma> + %cst = arith.constant dense<0.100000e+00> : tensor<32x32xf16, #mma> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #mma> + %6 = tt.broadcast %5 : tensor<32x1x!tt.ptr, #mma> -> tensor<32x32x!tt.ptr, #mma> + tt.store %6, %cst, %4 : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/ds_transpose.mlir b/third_party/enflame/include/triton/test/Conversion/amd/ds_transpose.mlir new file mode 100644 index 000000000..b887dfff9 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/ds_transpose.mlir @@ -0,0 +1,145 @@ +// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s + +#mma16 = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}> +#mma32 = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: ds_transpose_n_t_fp16_mfma_16 + tt.func @ds_transpose_n_t_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) { + // CHECK-COUNT-32: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_t_t_fp16_mfma_16 + tt.func @ds_transpose_t_t_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) { + // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>> + // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_n_n_fp16_mfma_16 + tt.func @ds_transpose_n_n_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) { + // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>> + // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_t_n_fp16_mfma_16 + tt.func @ds_transpose_t_n_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) { + // CHECK-NOT: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_n_t_fp16_mfma32 + tt.func @ds_transpose_n_t_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) { + // CHECK-COUNT-32: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_t_t_fp16_mfma32 + tt.func @ds_transpose_t_t_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) { + // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>> + // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_n_n_fp16_mfma32 + tt.func @ds_transpose_n_n_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) { + // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>> + // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_t_n_fp16_mfma32 + tt.func @ds_transpose_t_n_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) { + // CHECK-NOT: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_n_t_i8_mfma_16 + tt.func @ds_transpose_n_t_i8_mfma_16(%arg0: !ttg.memdesc<128x64xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared1, #smem, mutable>) { + // CHECK-COUNT-16: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 16}>> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared1, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_t_t_i8_mfma_16 + tt.func @ds_transpose_t_t_i8_mfma_16(%arg0: !ttg.memdesc<128x64xi8, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared1, #smem, mutable>) { + // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared1, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 16}>> + // CHECK-COUNT-8: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared1, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_n_n_i8_mfma_16 + tt.func @ds_transpose_n_n_i8_mfma_16(%arg0: !ttg.memdesc<128x64xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared, #smem, mutable>) { + // CHECK-COUNT-8: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 16}>> + // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_t_n_i8_mfma_16 + tt.func @ds_transpose_t_n_i8_mfma_16(%arg0: !ttg.memdesc<128x64xi8, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared, #smem, mutable>) { + // CHECK-NOT: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared1, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 16}>> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_n_t_i8_mfma32 + tt.func @ds_transpose_n_t_i8_mfma32(%arg0: !ttg.memdesc<128x64xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared1, #smem, mutable>) { + // CHECK-COUNT-16: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared1, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_t_t_i8_mfma32 + tt.func @ds_transpose_t_t_i8_mfma32(%arg0: !ttg.memdesc<128x64xi8, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared1, #smem, mutable>) { + // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared1, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>> + // CHECK-COUNT-6: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared1, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_n_n_i8_mfma32 + tt.func @ds_transpose_n_n_i8_mfma32(%arg0: !ttg.memdesc<128x64xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared, #smem, mutable>) { + // CHECK-COUNT-8: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>> + // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: ds_transpose_t_n_i8_mfma32 + tt.func @ds_transpose_t_n_i8_mfma32(%arg0: !ttg.memdesc<128x64xi8, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared, #smem, mutable>) { + // CHECK-NOT: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32> + %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared1, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>> + %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/fp_to_fp.mlir b/third_party/enflame/include/triton/test/Conversion/amd/fp_to_fp.mlir new file mode 100644 index 000000000..8aceb4e3e --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/fp_to_fp.mlir @@ -0,0 +1,38 @@ +// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s + +// CHECK-LABEL: f16_to_f32 +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @f16_to_f32(%arg0: tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>) { + // CHECK-COUNT-8: llvm.fpext %{{.+}} : f16 to f32 + %0 = tt.fp_to_fp %arg0 : tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: bf16_to_f32 +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @bf16_to_f32(%arg0: tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) { + // CHECK-COUNT-8: llvm.bitcast + %0 = tt.fp_to_fp %arg0 : tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: f32_to_f16 +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @f32_to_f16(%arg0: tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) { + // CHECK-COUNT-8: llvm.intr.experimental.constrained.fptrunc %{{.+}} tonearest ignore : f32 to f16 + %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> + // CHECK-COUNT-8: llvm.inline_asm asm_dialect {{.*}}s_setreg_imm32_b32{{.+}}v_cvt_f16_f32{{.+}}s_setreg_imm32_b32{{.+}} : (f32) -> f16 + + %1 = tt.fp_to_fp %arg0, rounding = rtz : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/invalid_extractslice_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/amd/invalid_extractslice_to_llvm.mlir new file mode 100644 index 000000000..9730f9eac --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/invalid_extractslice_to_llvm.mlir @@ -0,0 +1,111 @@ +// RUN: triton-opt -split-input-file %s --convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics + +// Invalid size +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTATile [256, 16]}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x2xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid zero source dimension +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x0xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{source tensor dimension size zero at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x0xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid zero result dimension +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result tensor dimension size zero at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x0xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid offset, not multiple of shapePerTile +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTATile [256, 16]}} + %1 = amdgpu.extract_slice %arg0 [0,5] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid offset, out of bounds for dimension +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{invalid offset 128 at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,128] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid result layout +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result layout must match source layout}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked2> + tt.return +} + +// ----- + +// Invalid result element type +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_element_type(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result element type must match source element type}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi64, #blocked1> + tt.return +} + +// ----- + +// Invalid result rank +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result rank must be equal to source rank}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid result shape +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{result shape cannot be larger than input shape at dimension 1}} + %1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x256xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid rank +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // expected-error @+1 {{currently only 2D tensors are supported}} + %1 = amdgpu.extract_slice %arg0 [0,0,0] : tensor<256x128x2xi32, #blocked1> to tensor<256x16x2xi32, #blocked1> + tt.return +} + +// ----- + +// Invalid non static offset +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @invalid_non_static_offset(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) { + // expected-error @+2 {{expected ']'}} + // expected-error @+1 {{expected integer value}} + %2 = amdgpu.extract_slice %arg0 [%arg1, 0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/load_store.mlir b/third_party/enflame/include/triton/test/Conversion/amd/load_store.mlir new file mode 100644 index 000000000..e0834d4db --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/load_store.mlir @@ -0,0 +1,58 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s + +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: global_load_store_vec8 + tt.func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + // Load 8 elements from A with two vectorized load instruction + // CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr<1> -> vector<4xf32> + %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #blocked0> + // Load 8 elements from B with two vectorized load instruction + // CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr<1> -> vector<4xf32> + %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: global_store_mfma_vec16 + tt.func public @global_store_mfma_vec16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma> + %1 = math.exp2 %0 : tensor<32x32xf32, #mma> + %2 = arith.truncf %1 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %c32_i32 = arith.constant 32 : i32 + %100 = tt.get_program_id x : i32 + %101 = arith.muli %100, %c32_i32 : i32 + %102 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %300 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma> + %200 = tt.broadcast %300 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma> + %103 = tt.splat %101 : i32 -> tensor<32x32xi32, #mma> + %104 = arith.addi %103, %200 : tensor<32x32xi32, #mma> + %105 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + %106 = tt.addptr %105, %104 : tensor<32x32x!tt.ptr, #mma>, tensor<32x32xi32, #mma> + // Store 16 elements with four vectorized store instruction + // CHECK-COUNT-4: llvm.store {{.*}} : vector<4xf16>, !llvm.ptr<1> + tt.store %106, %2 : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/math-denorm-handling.mlir b/third_party/enflame/include/triton/test/Conversion/amd/math-denorm-handling.mlir new file mode 100644 index 000000000..ba8ca82e4 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/math-denorm-handling.mlir @@ -0,0 +1,98 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" | FileCheck %s --check-prefixes=COMMON,LLVM_FTZ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" | FileCheck %s --check-prefixes=COMMON,LLVM_NO_FTZ + + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ: llvm.amdgcn.exp2.f32 + // LLVM_NO_FTZ: llvm.exp2.f32 + %0 = math.exp2 %arg0 : tensor<64xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @test_exp(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ: llvm.exp2.f32 + // LLVM_NO_FTZ: llvm.exp2.f32 + %0 = math.exp %arg0 : tensor<64xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @test_rsqrt(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ: llvm.amdgcn.rsq.f32 + // LLVM_NO_FTZ: _ocml_rsqrt_f32 + %0 = math.rsqrt %arg0 : tensor<64xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @test_sqrt_f32(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ-LABEL: test_sqrt_f32 + // LLVM_FTZ-NOT: llvm.fcmp "ogt" + // LLVM_FTZ: llvm.amdgcn.sqrt.f32 + // LLVM_FTZ-NOT: llvm.fmul + // LLVM_FTZ-NOT: llvm.select + // + // LLVM_NO_FTZ-LABEL: test_sqrt_f32 + // LLVM_NO_FTZ: llvm.fcmp "ogt" + // LLVM_NO_FTZ: llvm.fmul + // LLVM_NO_FTZ-NEXT: llvm.select + // LLVM_NO_FTZ-NEXT: llvm.amdgcn.sqrt.f32 + // LLVM_NO_FTZ: llvm.fmul + // LLVM_NO_FTZ-NEXT: llvm.select + %0 = math.sqrt %arg0 : tensor<64xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @test_sqrt_rn_f32(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ-LABEL: test_sqrt_rn_f32 + // LLVM_FTZ: llvm.amdgcn.rsq.f32 + // LLVM_FTZ: llvm.fmul + // LLVM_FTZ: llvm.fmul + // LLVM_FTZ: llvm.fneg + // LLVM_FTZ: llvm.intr.fma + // LLVM_FTZ-NEXT: llvm.intr.fma + // LLVM_FTZ-NEXT: llvm.intr.fma + // LLVM_FTZ-NEXT: llvm.fneg + // LLVM_FTZ-NEXT: llvm.intr.fma + // LLVM_FTZ-NEXT: llvm.intr.fma + // LLVM_FTZ-NEXT: llvm.intr.is.fpclass + // LLVM_FTZ-NEXT: llvm.select + // + // LLVM_NO_FTZ-LABEL: test_sqrt_rn_f32 + // LLVM_NO_FTZ: llvm.intr.sqrt + %0 = tt.precise_sqrt %arg0 : tensor<64xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @test_sqrt_rn_f64(%arg0: tensor<64xf64, #blocked>) attributes {noinline = false} { + // COMMON-LABEL: test_sqrt_rn_f64 + // COMMON: llvm.intr.sqrt + %0 = tt.precise_sqrt %arg0 : tensor<64xf64, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/mfma-shortcut.mlir b/third_party/enflame/include/triton/test/Conversion/amd/mfma-shortcut.mlir new file mode 100644 index 000000000..bdf6db4e6 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/mfma-shortcut.mlir @@ -0,0 +1,217 @@ +// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s + +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: shortcut_mfma16 + tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: no_shortcut_mfma16 + tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { + // CHECK: store + // CHECK: load + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_f8_mfma32 + tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + + // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] + + // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) + + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> + + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]] + + // Input (8 values): (vec0, vec1) + // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): + // resVec0 resVec1 + // lanes 0-31: (vec0 , vec0 >> 32) (mask0=1) + // lanes 32-63: (vec1 >> 32, vec1 ) (mask0=0) + + // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]] + // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]] + + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] + + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_bf8_mfma32 + tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: rocdl.ds_bpermute + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_f8_mfma16 + tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + + // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] + + // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) + // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32) + // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) + + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + + // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] + // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] + // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] + // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> + + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]] + // CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]] + + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]] + // CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]] + + // Input (8 values): (vec0, vec1) + // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): + // resVec0 resVec1 + // lanes 0-15: (vec0 , vec0 >> 16) (mask0=1, mask1=1) + // lanes 16-31: (vec0 >> 16, vec0 >> 32) (mask0=1, mask1=0) + // lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1) + // lanes 48-63: (vec1 >> 48, vec1 ) (mask0=0, mask1=0) + + // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8> + // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + + // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8> + // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] + + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_bf8_mfma16 + tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: rocdl.ds_bpermute + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/tritongpu_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/amd/tritongpu_to_llvm.mlir new file mode 100644 index 000000000..931950d17 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -0,0 +1,343 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f32_scalar + tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { + // CHECK: llvm.cond_br + // CHECK: llvm.atomicrmw + // CHECK: llvm.store + // CHECK: llvm.br + // CHECK: rocdl.barrier + // CHECK: llvm.load + // CHECK: llvm.store + %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (!tt.ptr, f32, i1) -> f32 + tt.store %arg0, %0 : !tt.ptr + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f32 + tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.cond_br + // CHECK: llvm.atomicrmw + // CHECK: llvm.atomicrmw + // CHECK: llvm.store + // CHECK: llvm.store + %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0> + tt.store %arg0, %0 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +// Smoke test to check that mfma 32 and dot operand layouts can work with small tensors, for example with shape 16x16 +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> +#dotop1 = #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: small_mfma_tensor_conversions + tt.func public @small_mfma_tensor_conversions(%arg0: tensor<16x16xf16, #mfma>, %arg1: tensor<16x16x!tt.ptr, #mfma>) { + // CHECK-NOT: ttg.convert_layout + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #mfma>) -> !ttg.memdesc<16x16xf16, #shared, #smem> + // CHECK-4: store {{.*}} vector<4xf16> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #dotop0> + // CHECK-2: load {{.*}} vector<4xf16> + %2 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #dotop1> + // CHECK-8: load {{.*}} vector<1xf16> + %3 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #mfma> + // CHECK-4: load {{.*}} vector<4xf16> + %4 = tt.fp_to_fp %3 : tensor<16x16xf16, #mfma> -> tensor<16x16xf32, #mfma> + + %5 = tt.dot %1, %2, %4 : tensor<16x16xf16, #dotop0> * tensor<16x16xf16, #dotop1> -> tensor<16x16xf32, #mfma> + // Store result to prevent DCE from removing all conversion related code + %6 = ttg.local_alloc %5 : (tensor<16x16xf32, #mfma>) -> !ttg.memdesc<16x16xf32, #shared, #smem> + tt.return + } +} + +// ----- + +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f16x2 + tt.func @atomic_add_f16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> + // CHECK: llvm.cond_br + // CHECK-NOT: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> + // CHECK-NOT: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> + tt.return + } +} + +// ----- + +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_bf16x2 + tt.func @atomic_add_bf16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> + // CHECK: llvm.cond_br + // CHECK-NOT: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> + // CHECK-NOT: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> + tt.return + } +} + +// ----- + +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f16_dpp + tt.func @atomic_add_f16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> + // CHECK: llvm.cond_br + // CHECK: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> + // CHECK: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> + tt.return + } +} + +// ----- + +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_bf16_dpp + tt.func @atomic_add_bf16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> + // CHECK: llvm.cond_br + // CHECK: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> + // CHECK: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> + tt.return + } +} + +// ----- + +#blocked3 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: reduce_dpp_max + tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) { + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 280, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 276, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 274, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 273, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 322, 10, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK-NEXT: rocdl.update.dpp + // CHECK-SAME: with 323, 15, 15, true : f32 + // CHECK-NEXT: llvm.intr.maxnum + + // CHECK: llvm.amdgcn.readlane + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<64xf32, #blocked3>) -> f32 + tt.return + } +} + +// ----- + +#blocked4 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: reduce_xor_max + tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) { + // CHECK: rocdl.ds_swizzle + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 280, 15, 12, false : i32 + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 264, 15, 3, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 276, 15, 10, false : i32 + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 260, 15, 5, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 78, 15, 15, false : i32 + // CHECK: llvm.intr.maxnum + + // CHECK: rocdl.update.dpp + // CHECK-SAME: with 177, 15, 15, false : i32 + %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32xf32, #blocked4>) -> f32 + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomicrmw_scope_memsemantics + tt.func @atomicrmw_scope_memsemantics(%arg0 : tensor<128x!tt.ptr, #blocked0>, %arg1 : tensor<128xi1, #blocked0>, %arg2 : tensor<128xf32, #blocked0>) { + // relaxed + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} monotonic + %0 = tt.atomic_rmw fadd, relaxed, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) monotonic + %1 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) monotonic + %2 = tt.atomic_rmw fadd, relaxed, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + + // acquire + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} acquire + %3 = tt.atomic_rmw fadd, acquire, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) acquire + %4 = tt.atomic_rmw fadd, acquire, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) acquire + %5 = tt.atomic_rmw fadd, acquire, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + + // release + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} release + %6 = tt.atomic_rmw fadd, release, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) release + %7 = tt.atomic_rmw fadd, release, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) release + %8 = tt.atomic_rmw fadd, release, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + + // acq_rel + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} acq_rel + %9 = tt.atomic_rmw fadd, acq_rel, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) acq_rel + %10 = tt.atomic_rmw fadd, acq_rel, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) acq_rel + %11 = tt.atomic_rmw fadd, acq_rel, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + + tt.return + } +} + +// ----- + +#blocked5 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: atomic_runtime_lds_reduction + tt.func @atomic_runtime_lds_reduction(%arg0 : tensor<64x!tt.ptr, #blocked5>, %arg2 : tensor<64xf32, #blocked5>) { + + // CHECK: llvm.zext + // CHECK-COUNT-7: rocdl.update.dpp + // CHECK: llvm.bitcast + // CHECK-COUNT: llvm.amdgcqn.ds.permute + // CHECK: llvm.bitcast + // CHECK: llvm.ptrtoint + // CHECK: llvm.bitcast + // CHECK-COUNT-2: llvm.amdgcn.ds.permute + // CHECK: llvm.bitcast + // CHECK: llvm.inttoptr + // CHECK: llvm.amdgcn.ballot + // CHECK: llvm.ptrtoint + // CHECK: llvm.amdgcn.ballot + + // loop body: + // CHECK: llvm.bitcast + // CHECK-COUNT-2: llvm.amdgcn.readfirstlane + // CHECK: llvm.bitcast + // CHECK: llvm.amdgcn.ballot + // CHECK: rocdl.mbcnt.lo + // CHECK: rocdl.mbcnt.hi + + // share info: + // 1. address + // CHECK: llvm.bitcast + // CHECK-COUNT-2: llvm.amdgcn.ds.permute + // CHECK: llvm.bitcast + // 2. value + // CHECK: llvm.amdgcn.ds.permute + // CHECK: llvm.bitcast + // 3. packed methadata + // CHECK: llvm.bitcast + // CHECK: llvm.amdgcn.ds.permute + // CHECK: llvm.bitcast + + // CHECK: llvm.amdgcn.ballot + + // reduction: + // CHECK-COUNT-6: llvm.amdgcn.ds.bpermute + + // CHECK: inttoptr + // CHECK: llvm.atomicrmw + %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2 {allocation.offset = 0 : i32} : (tensor<64x!tt.ptr, #blocked5>, tensor<64xf32, #blocked5>) -> tensor<64xf32, #blocked5> + tt.return + } +} + +// ----- + +// CHECK-LABEL: v_dot_i8 +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @v_dot_i8(%arg0: tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<16x16xi32, #blocked>) { + // CHECK-4: llvm.call_intrinsic "llvm.amdgcn.sdot4" + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xi32, #blocked> + tt.return + } +} + +// ----- + +// CHECK-LABEL: v_dot_fp16 +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @v_dot_fp16(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<16x16xf32, #blocked>) { + // CHECK-8: llvm.call_intrinsic "llvm.amdgcn.fdot2" + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf32, #blocked> + tt.return + } +} + +// ----- + +// CHECK-LABEL: v_dot_fp16_fp16 +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @v_dot_fp16_fp16(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<16x16xf16, #blocked>) { + // CHECK-COUNT-16: llvm.call_intrinsic "llvm.fmuladd.f16" + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir new file mode 100644 index 000000000..ce1576b55 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir @@ -0,0 +1,205 @@ +// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#mma2 = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> +#mma2_transposed = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2], isTranspose = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma1_dot_operand + tt.func @wmma1_dot_operand(%arg0: !ttg.memdesc<64x64xf16, #shared, #smem>) { + // 2 CTA * 4 rep * load_per_thread_per_instr + // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> + %0 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> + // CHECK-COUNT-128: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> + %1 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: wmma2_dot_operand + tt.func @wmma2_dot_operand(%arg0: !ttg.memdesc<64x64xf16, #shared, #smem>) { + // 2 CTA * 4 rep * load_per_thread_per_instr + // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16> + %0 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> + // CHECK-COUNT-64: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> + %1 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> + tt.return + } + + // CHECK-LABEL: wmma1_dot + tt.func @wmma1_dot(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) { + // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)> + // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + // CHECK: llvm.mlir.undef : vector<16xf16> + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xf16> + // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xf16, #mma1> + // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16> + // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + tt.return + } + + // CHECK-LABEL: wmma1_dot_bf16 + tt.func @wmma1_dot_bf16(%arg0: tensor<16x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma1>) { + // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> + // CHECK: llvm.bitcast %{{.*}} : vector<16xbf16> to vector<16xi16> + // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> + // CHECK: llvm.bitcast %{{.*}} : vector<16xbf16> to vector<16xi16> + // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> + // CHECK: llvm.mlir.undef : vector<16xbf16> + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xbf16> + // CHECK: rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xbf16, #mma1> + tt.return + } + + // CHECK-LABEL: wmma1_dot_int8_32 + tt.func @wmma1_dot_int8_32(%arg0: tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { + // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi8> + // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32> + // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi8> + // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32> + // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32> + %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> + // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + tt.return + } + + // CHECK-LABEL: wmma1_dot_int4_32 + tt.func @wmma1_dot_int4_32(%arg0: tensor<16x16xi4, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { + // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4)> + // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4> + // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32> + // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4)> + // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4> + // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32> + // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: rocdl.wmma.i32.16x16x16.iu4 {{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> + %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> + // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + tt.return + } + + // CHECK-LABEL: wmma2_dot + tt.func @wmma2_dot(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>>, %arg2: tensor<16x16xf16, #mma2>) { + // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + // CHECK: llvm.mlir.undef : vector<8xf16> + // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + // CHECK: llvm.mlir.undef : vector<8xf16> + // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + // CHECK: llvm.mlir.undef : vector<8xf16> + // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v8f16"{{.*}} : (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> -> tensor<16x16xf16, #mma2> + // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf16> + // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + tt.return + } + + // CHECK-LABEL: wmma2_transposed_dot + tt.func @wmma2_transposed_dot(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2_transposed, kWidth = 8}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2_transposed, kWidth = 8}>>, %arg2: tensor<16x16xf16, #mma2_transposed>) { + // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v8f16"{{.*}} : (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2_transposed, kWidth = 8}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2_transposed, kWidth = 8}>> -> tensor<16x16xf16, #mma2_transposed> + tt.return + } + + // CHECK-LABEL: blocked_to_wmma1 + tt.func @blocked_to_wmma1(%arg0: tensor<128x16xi32, #blocked>) { + // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma1> + tt.return + } + + // CHECK-LABEL: slice_blocked_to_wmma1 + tt.func @slice_blocked_to_wmma1(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>> + tt.return + } + + // CHECK-LABEL: wmma1_to_blocked + tt.func @wmma1_to_blocked(%arg0: tensor<128x16xi32, #mma1>) { + // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma1> -> tensor<128x16xi32, #blocked> + tt.return + } + + // CHECK-LABEL: slice_wmma1_to_blocked + tt.func @slice_wmma1_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>>) { + // CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)> + // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + tt.return + } + + // CHECK-LABEL: blocked_to_wmma2 + tt.func @blocked_to_wmma2(%arg0: tensor<128x16xi32, #blocked>) { + // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma2> + tt.return + } + + // CHECK-LABEL: slice_blocked_to_wmma2 + tt.func @slice_blocked_to_wmma2(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { + // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>> + tt.return + } + + // CHECK-LABEL: wmma2_to_blocked + tt.func @wmma2_to_blocked(%arg0: tensor<128x16xi32, #mma2>) { + // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma2> -> tensor<128x16xi32, #blocked> + tt.return + } + + // CHECK-LABEL: slice_wmma2_to_blocked + tt.func @slice_wmma2_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>>) { + // CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)> + // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0]}> +#mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 1, 4]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_dot_operand3d + tt.func @wmma_dot_operand3d(%arg0: !ttg.memdesc<4x16x32xf16, #shared, #smem>) { + // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> + %0 = ttg.local_load %arg0 : !ttg.memdesc<4x16x32xf16, #shared, #smem> -> tensor<4x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> + // CHECK-COUNT-32: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> + %1 = ttg.local_load %arg0 : !ttg.memdesc<4x16x32xf16, #shared, #smem> -> tensor<4x16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: wmma_dot3d + tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma1>) { + // CHECK-COUNT-32: llvm.extractvalue %arg0 + // CHECK-COUNT-32: llvm.insertelement + // CHECK-COUNT-32: llvm.extractvalue %arg1 + // CHECK-COUNT-32: llvm.insertelement + // CHECK-COUNT-8: llvm.extractvalue %arg2 + // CHECK-COUNT-8: llvm.insertelement + // CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<2x32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<2x16x16xf16, #mma1> + // CHECK-COUNT-8: llvm.extractelement + // CHECK-COUNT-8: llvm.insertvalue + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/atomic_ldst.mlir b/third_party/enflame/include/triton/test/Conversion/atomic_ldst.mlir new file mode 100644 index 000000000..4c1e63c40 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/atomic_ldst.mlir @@ -0,0 +1,29 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s --check-prefix=CHECK-TTG2NVGPU +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 --convert-nv-gpu-to-llvm 2>&1 | FileCheck %s --check-prefix=CHECK-NVGPU2LLVM +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel_r(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : f32 + %true = arith.constant true + %c128_i32 = arith.constant 128 : i32 + %c512_i32 = arith.constant 512 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = arith.cmpi slt, %1, %c512_i32 : i32 + + // CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, gpu + // CHECK-NVGPU2LLVM: ld.global.gpu.acquire.b32 + %3 = tt.atomic_rmw fadd, acquire, gpu, %arg0, %cst, %2 : (!tt.ptr, f32, i1) -> f32 + tt.store %arg0, %3 : !tt.ptr + + // CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, cta + // CHECK-NVGPU2LLVM: ld.global.cta.acquire.b32 + %4 = tt.atomic_rmw fadd, acquire, cta, %arg0, %cst, %true : (!tt.ptr, f32, i1) -> f32 + tt.store %arg0, %4 : !tt.ptr + + // CHECK-TTG2NVGPU: nvgpu.ld_acquire acquire, sys + // CHECK-NVGPU2LLVM: ld.global.sys.acquire.b32 + %5 = tt.atomic_rmw fadd, acquire, sys, %arg0, %cst, %2 : (!tt.ptr, f32, i1) -> f32 + tt.store %arg0, %5 : !tt.ptr + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/cvt_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/cvt_to_llvm.mlir new file mode 100644 index 000000000..f577bc5af --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/cvt_to_llvm.mlir @@ -0,0 +1,153 @@ +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> + +#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + +// CHECK-LABEL: convert_layout_blocked_blocked_vec +tt.func private @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xi32, #blocked0>) -> tensor<16x16xi32, #blocked2> { + + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue {{.*}} %0, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue {{.*}} %0, 1 + // CHECK-NEXT: [[SRC2:%.*]] = extractvalue {{.*}} %0, 2 + // CHECK-NEXT: [[SRC3:%.*]] = extractvalue {{.*}} %0, 3 + // CHECK-NEXT: [[SRC4:%.*]] = extractvalue {{.*}} %0, 4 + // CHECK-NEXT: [[SRC5:%.*]] = extractvalue {{.*}} %0, 5 + // CHECK-NEXT: [[SRC6:%.*]] = extractvalue {{.*}} %0, 6 + // CHECK-NEXT: [[SRC7:%.*]] = extractvalue {{.*}} %0, 7 + + // CHECK-NEXT: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + + // The layout conversion looks like + // dst_lane + // dst_reg 0 1 2 3 ... 16 17 18 19 ... + // 0 T0:0 T1:0 T4:0 T5:0 T0:4 T1:4 T4:4 T5:4 + // 1 T0:1 T1:1 T4:1 T5:1 T0:5 T1:5 T4:5 T5:5 + // ... + // 4 T2:0 T3:0 T6:0 T7:0 T2:4 T3:4 T6:4 T7:4 + // 5 T2:1 T3:1 T6:1 T7:1 T2:5 T3:5 T6:5 T7:5 + // ... + // + // This subsection is tiled to fill the rest of the lanes and registers. + // + // There will need to be one select per shuffle input and one select per + // shuffle output due to src registers (i%4, (i%4)+4) mapped to the same dst + // register. + + // Lanes [2, 3, 6, 7, ...] will send register i+4 while the others send i+0. + + // CHECK-DAG: [[IS_UPPER_HALF:%.*]] = and i32 [[TID]], 2 + // CHECK-DAG: [[IS_LOWER_HALF:%.*]] = icmp eq i32 [[IS_UPPER_HALF]], 0 + + // For register [0, 4), the lane shuffle idx is essentially computed as + // `(x//2*4 + x%2)%16 + (x>=16)*2` + + // CHECK-DAG: [[X_MOD_2:%.*]] = and i32 [[TID]], 1 + // CHECK-DAG: [[X_2_4_LOWER:%.*]] = shl {{.*}} i32 [[IS_UPPER_HALF]], 1 + // CHECK-DAG: [[X_2_4_UPPER0:%.*]] = shl i32 [[TID]], 1 + // CHECK-DAG: [[X_2_4_UPPER1:%.*]] = and i32 [[X_2_4_UPPER0]], 24 + // CHECK-DAG: [[X_GE_16:%.*]] = and i32 [[TID]], 16 + // CHECK-DAG: [[X_GE_16_2:%.*]] = lshr exact i32 [[X_GE_16]], 3 + + // CHECK-DAG: [[IDX0:%.*]] = or disjoint i32 [[X_2_4_LOWER]], [[X_MOD_2]] + // CHECK-DAG: [[IDX1:%.*]] = or disjoint i32 [[IDX0]], [[X_2_4_UPPER1]] + // CHECK-DAG: [[IDX2:%.*]] = or disjoint i32 [[IDX1]], [[X_GE_16_2]] + + // CHECK-DAG: [[SHFLSRC0:%.*]] = select i1 [[IS_LOWER_HALF]], i32 [[SRC0]], i32 [[SRC4]] + // CHECK-DAG: [[SHFLSRC1:%.*]] = select i1 [[IS_LOWER_HALF]], i32 [[SRC1]], i32 [[SRC5]] + // CHECK-DAG: [[SHFLSRC2:%.*]] = select i1 [[IS_LOWER_HALF]], i32 [[SRC2]], i32 [[SRC6]] + // CHECK-DAG: [[SHFLSRC3:%.*]] = select i1 [[IS_LOWER_HALF]], i32 [[SRC3]], i32 [[SRC7]] + // CHECK-DAG: [[SHFLSRC4:%.*]] = select i1 [[IS_LOWER_HALF]], i32 [[SRC4]], i32 [[SRC0]] + // CHECK-DAG: [[SHFLSRC5:%.*]] = select i1 [[IS_LOWER_HALF]], i32 [[SRC5]], i32 [[SRC1]] + // CHECK-DAG: [[SHFLSRC6:%.*]] = select i1 [[IS_LOWER_HALF]], i32 [[SRC6]], i32 [[SRC2]] + // CHECK-DAG: [[SHFLSRC7:%.*]] = select i1 [[IS_LOWER_HALF]], i32 [[SRC7]], i32 [[SRC3]] + + // CHECK-DAG: [[SHFLOUT0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[SHFLSRC0]], i32 [[IDX2]], i32 31) + // CHECK-DAG: [[SHFLOUT1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[SHFLSRC1]], i32 [[IDX2]], i32 31) + // CHECK-DAG: [[SHFLOUT2:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[SHFLSRC2]], i32 [[IDX2]], i32 31) + // CHECK-DAG: [[SHFLOUT3:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[SHFLSRC3]], i32 [[IDX2]], i32 31) + + // For register [4, 8), the upper and lower halves swap. + + // CHECK-DAG: [[IDX3:%.*]] = or disjoint i32 [[IDX1]], 2 + // CHECK-DAG: [[IDX4:%.*]] = xor i32 [[IDX3]], [[X_GE_16_2]] + + // CHECK-DAG: [[SHFLOUT4:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[SHFLSRC4]], i32 [[IDX4]], i32 31) + // CHECK-DAG: [[SHFLOUT5:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[SHFLSRC5]], i32 [[IDX4]], i32 31) + // CHECK-DAG: [[SHFLOUT6:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[SHFLSRC6]], i32 [[IDX4]], i32 31) + // CHECK-DAG: [[SHFLOUT7:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[SHFLSRC7]], i32 [[IDX4]], i32 31) + + // For lanes [16, 32), swap the two results. + + // CHECK-DAG: [[SWAP_RESULTS:%.*]] = icmp eq i32 [[X_GE_16]], 0 + + // CHECK: [[DST0:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT0]], i32 [[SHFLOUT4]] + // CHECK: [[DST1:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT1]], i32 [[SHFLOUT5]] + // CHECK: [[DST2:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT2]], i32 [[SHFLOUT6]] + // CHECK: [[DST3:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT3]], i32 [[SHFLOUT7]] + // CHECK: [[DST4:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT4]], i32 [[SHFLOUT0]] + // CHECK: [[DST5:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT5]], i32 [[SHFLOUT1]] + // CHECK: [[DST6:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT6]], i32 [[SHFLOUT2]] + // CHECK: [[DST7:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT7]], i32 [[SHFLOUT3]] + + // CHECK: insertvalue {{.*}}, i32 [[DST0]], 0 + // CHECK: insertvalue {{.*}}, i32 [[DST1]], 1 + // CHECK: insertvalue {{.*}}, i32 [[DST2]], 2 + // CHECK: insertvalue {{.*}}, i32 [[DST3]], 3 + // CHECK: insertvalue {{.*}}, i32 [[DST4]], 4 + // CHECK: insertvalue {{.*}}, i32 [[DST5]], 5 + // CHECK: insertvalue {{.*}}, i32 [[DST6]], 6 + // CHECK: insertvalue {{.*}}, i32 [[DST7]], 7 + + %0 = ttg.convert_layout %arg0 : tensor<16x16xi32, #blocked0> -> tensor<16x16xi32, #blocked2> + tt.return %0 : tensor<16x16xi32, #blocked2> +} + +// CHECK-LABEL: convert_layout_blocked_blocked +tt.func private @convert_layout_blocked_blocked(%arg0: tensor<16x16xi32, #blocked0>) -> tensor<16x16xi32, #blocked1> { + // This conversion looks like: + // dst_lane + // dst_reg 0 1 ... 16 17 ... + // 0 T0:0 T16:0 T1:0 T17:0 + // 1 T4:0 T20:0 T5:0 T21:0 + // 2 T8:0 T24:0 T9:0 T25:0 + // 3 T12:0 T28:0 T13:0 T29:0 + // 4 T2:0 T18:0 T3:0 T19:0 + // 5 T6:0 T22:0 T7:0 T23:0 + // 6 T10:0 T26:0 T11:0 T27:0 + // 7 T14:0 T30:0 T15:0 T31:0 + // + // Where the registers change every 2 lanes like [0, 4, 1, 5, 2, 6, 3, 7] and + // wraps around at lane 16. Due to this, there needs to be 8 selects per + // shuffle input and output. The lane mapping also changes every register. Due + // to this, we choose to fall back to the shared memory implementation. + + // CHECK-NOT: shfl.sync.idx + // CHECK: st.shared + + %0 = ttg.convert_layout %arg0 : tensor<16x16xi32, #blocked0> -> tensor<16x16xi32, #blocked1> + tt.return %0 : tensor<16x16xi32, #blocked1> +} + +tt.func private @cvt_mma_to_dot_fp8(%a: tensor<128x64xi32, #mma>) -> tensor<128x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> { + %opA = ttg.convert_layout %a : tensor<128x64xi32, #mma> -> tensor<128x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return %opA : tensor<128x64xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +} + +tt.func @anchor(%ptr: !llvm.ptr, %arg0: tensor<16x16xi32, #blocked0>, %arg1: tensor<128x64xi32, #mma>) { + %0 = tt.call @convert_layout_blocked_blocked(%arg0) : (tensor<16x16xi32, #blocked0>) -> tensor<16x16xi32, #blocked1> + %1 = builtin.unrealized_conversion_cast %0 : tensor<16x16xi32, #blocked1> to !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + llvm.store volatile %1, %ptr : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>, !llvm.ptr + + %2 = tt.call @convert_layout_blocked_blocked_vec(%arg0) : (tensor<16x16xi32, #blocked0>) -> tensor<16x16xi32, #blocked2> + %3 = builtin.unrealized_conversion_cast %2 : tensor<16x16xi32, #blocked2> to !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + llvm.store volatile %3, %ptr : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>, !llvm.ptr + + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/Conversion/dedup-by-constancy.mlir b/third_party/enflame/include/triton/test/Conversion/dedup-by-constancy.mlir new file mode 100644 index 000000000..dc2cda84a --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/dedup-by-constancy.mlir @@ -0,0 +1,72 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm --llvm-optimize-for-nvvm-target | FileCheck %s + +// CHECK-LABEL: dedup_by_constancy_full +// CHECK-COUNT-2: llvm.add +// CHECK-NOT: llvm.add +// CHECK: llvm.icmp "slt" +// CHECK-NOT: llvm.icmp "slt" +// CHECK: llvm.sdiv +// CHECK-NOT: llvm.sdiv +// CHECK: llvm.getelementptr %arg0[[[REGISTER:%[0-9]+]]] +// CHECK-COUNT-7: llvm.getelementptr %arg0[[[REGISTER]]] +// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER]]] +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @dedup_by_constancy_full(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<256> : tensor<1024xi32, #blocked> + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg2 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> + %7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %10 = tt.load %9, %6 : tensor<1024x!tt.ptr, #blocked> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %12, %10, %6 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK-LABEL: dedup_by_constancy_partial +// CHECK-COUNT-4: llvm.add +// CHECK-NOT: llvm.add +// CHECK: llvm.icmp "slt" +// CHECK-NOT: llvm.icmp "slt" +// CHECK-COUNT-2: llvm.sdiv +// CHECK-NOT: llvm.sdiv +// CHECK: llvm.getelementptr %arg0[[[REGISTER1:%[0-9]+]]] +// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER1]]] +// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER1]]] +// CHECK: llvm.getelementptr %arg0[[[REGISTER2:%[0-9]+]]] +// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER2]]] +// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER2]]] +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @dedup_by_constancy_partial(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<4> : tensor<1024xi32, #blocked> + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg2 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> + %7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %10 = tt.load %9, %6 : tensor<1024x!tt.ptr, #blocked> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %12, %10, %6 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/divide-by-0.mlir b/third_party/enflame/include/triton/test/Conversion/divide-by-0.mlir new file mode 100644 index 000000000..f12fd1bc7 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/divide-by-0.mlir @@ -0,0 +1,14 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm --cse | FileCheck %s + +// CHECK-LABEL: dont_divide_0 +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-NOT: llvm.urem %{{.*}}, %[[C0]] +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 8]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @dont_divide_0() attributes {noinline = false} { + %zero = arith.constant dense<0.000000e+00> : tensor<16x1xf32, #mma> + %cvt = ttg.convert_layout %zero : tensor<16x1xf32, #mma> -> tensor<16x1xf32, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/gather_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/gather_to_llvm.mlir new file mode 100644 index 000000000..1882c1046 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/gather_to_llvm.mlir @@ -0,0 +1,339 @@ +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s + +// Check the optimized LLVMIR, since InstCombine makes the linear layout +// logic understandable enough (in simple cases) to check correctness by eye. + +#trivial_layout = #ttg.linear<{register = [], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> + +#trivial_layout_wider = #ttg.linear<{register = [[32]], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> + +#trivial_layout_wider_reg_stride_1 = #ttg.linear<{register = [[1]], lane = [[2], [4], [8], [16], [32]], warp = [], block = []}> + +#trivial_2d_one_col = #ttg.linear<{register = [[0, 1]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [], block = []}> + +#span_2d_cols = #ttg.linear<{register = [[1, 0]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [], block = []}> + +#crazy_2d_src = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}> +#crazy_2d_idx = #ttg.linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}> + +#broadcasted_lane_1d = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#broadcasted_warp_2d = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + +// Each source element is mapped to a single thread, so we expect one index shuffle. +// CHECK-LABEL: @gather_warp_local_trivial +tt.func private @gather_warp_local_trivial(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32 + // CHECK-NEXT: [[RES_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32xf32, #trivial_layout>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Same as above, but there are two index elements per thread. Expect 2 index shuffles +// with the results packed together. +// CHECK-LABEL: @gather_warp_local_larger_output +tt.func private @gather_warp_local_larger_output(%arg0: tensor<64xi32, #trivial_layout_wider>, %arg1: tensor<32xf32, #trivial_layout>) -> tensor<64xf32, #trivial_layout_wider> { + // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[IDX0]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32 + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID0]], i32 31) + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[IDX1]], 31 + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID1]], i32 31) + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32xf32, #trivial_layout>, tensor<64xi32, #trivial_layout_wider>) -> tensor<64xf32, #trivial_layout_wider> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<64xf32, #trivial_layout_wider> +} + +// Each thread has 2 elements of the source tensor, strided 32 apart, so we +// expect two index shuffles, using the MSB to select between the two. +// CHECK-LABEL: @gather_warp_local_larger_input +tt.func private @gather_warp_local_larger_input(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<64xf32, #trivial_layout_wider>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 31 + // CHECK-NEXT: [[REGID:%.*]] = and i32 [[IDX]], 32 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64xf32, #trivial_layout_wider>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: [[PICK:%.*]] = icmp eq i32 [[REGID]], 0 + // CHECK-NEXT: [[RES_i32:%.*]] = select i1 [[PICK]], i32 [[RES0]], i32 [[RES1]] + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Same as above, except the RegID comes from the LSB. +// CHECK-LABEL: @gather_warp_local_larger_input +tt.func private @gather_warp_local_larger_input_stride_1(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<64xf32, #trivial_layout_wider_reg_stride_1>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[REGID:%.*]] = and i32 [[IDX]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX]], 1 + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[TMP]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64xf32, #trivial_layout_wider_reg_stride_1>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: [[PICK:%.*]] = icmp eq i32 [[REGID]], 0 + // CHECK-NEXT: [[RES_i32:%.*]] = select i1 [[PICK]], i32 [[RES0]], i32 [[RES1]] + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Each thread has 1 element in 2 gather columns, so this is the same as the +// trivial case except now it's 2D. We expect 2 independent index shuffles. +// CHECK-LABEL: @gather_2d_trivial +tt.func private @gather_2d_trivial(%arg0: tensor<32x2xi32, #trivial_2d_one_col>, %arg1: tensor<32x2xf32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[IDX0]], 31 + // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID0]], i32 31) + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[IDX1]], 31 + // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID1]], i32 31) + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x2xf32, #trivial_2d_one_col>, tensor<32x2xi32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<32x2xf32, #trivial_2d_one_col> +} + +// The single warp is split into two columns. Each column has half contiguous +// threads, each with 2 contiguous elements. Expect 4 index shuffles: two per +// column. Thus, the index should be dependent on the thread id, since the +// register alone is not enough to determine the column. +// CHECK-LABEL: @gather_2d_span_2 +tt.func private @gather_2d_span_2(%arg0: tensor<32x2xi32, #span_2d_cols>, %arg1: tensor<32x2xf32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // This uses tid to select between the two columns: + // CHECK-NEXT: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK-NEXT: [[COL:%.*]] = and i32 [[TID]], 16 + + // Break the index into reg and thread (within column) components: + // CHECK-NEXT: [[REGID0:%.*]] = and i32 [[IDX0]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX0]], 1 + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[TMP]], 15 + + // CHECK-NEXT: [[SHUFFLE_IDX:%.*]] = or disjoint i32 [[LANEID0]], [[COL]] + + // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[SRES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[SHUFFLE_IDX]], i32 31) + // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[SRES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[SHUFFLE_IDX]], i32 31) + + // Use the reg id to select between the two results: + // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID0]], 0 + // CHECK-NEXT: [[RES0_i32:%.*]] = select i1 [[PICK0]], i32 [[SRES0]], i32 [[SRES1]] + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[REGID1:%.*]] = and i32 [[IDX1]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX1]], 1 + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[TMP]], 15 + + // CHECK-NEXT: [[SHUFFLE_IDX:%.*]] = or disjoint i32 [[LANEID1]], [[COL]] + + // CHECK-NEXT: [[SRES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[SHUFFLE_IDX]], i32 31) + // CHECK-NEXT: [[SRES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[SHUFFLE_IDX]], i32 31) + + // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID1]], 0 + // CHECK-NEXT: [[RES1_i32:%.*]] = select i1 [[PICK0]], i32 [[SRES0]], i32 [[SRES1]] + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x2xf32, #span_2d_cols>, tensor<32x2xi32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<32x2xf32, #span_2d_cols> +} + +// CHECK-LABEL: @gather_2d_crazy +tt.func private @gather_2d_crazy(%arg0: tensor<32x16xi32, #crazy_2d_idx>, %arg1: tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> { + // The specific logic becomes hard to grasp here. Just check the shuffles. + + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float, float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float, float, float } %1, 1 + // CHECK-NEXT: [[SRC2:%.*]] = extractvalue { float, float, float, float } %1, 2 + // CHECK-NEXT: [[SRC3:%.*]] = extractvalue { float, float, float, float } %1, 3 + + // CHECK: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], + // CHECK-NEXT: [[VALUE2:%.*]] = bitcast float [[SRC2]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]], + + // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]], + + // CHECK: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], + // CHECK-NEXT: [[VALUE3:%.*]] = bitcast float [[SRC3]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]], + + // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]], + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x16xf32, #crazy_2d_src>, tensor<32x16xi32, #crazy_2d_idx>) -> tensor<32x16xf32, #crazy_2d_idx> + tt.return %0 : tensor<32x16xf32, #crazy_2d_idx> +} + +// There are 16 elements in the tensor. For each warp, each half-warp is mapped +// to the 16 elements, so it doesn't matter if the second half [16, 32) indexes +// into [0, 16), since they contain the same data. +// CHECK-LABEL: @gather_broadcasted_lane_1d +tt.func private @gather_broadcasted_lane_1d(%arg0: tensor<16xi32, #broadcasted_lane_1d>, %arg1: tensor<16xf32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d> { + // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 15 + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32 + // CHECK-NEXT: [[RES_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<16xf32, #broadcasted_lane_1d>, tensor<16xi32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d> + + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<16xf32, #broadcasted_lane_1d> +} + +// Single gather column with 64 elements, all of which have to fit into a single +// warp, so the whole column is broadcasted across the 4 warps. Each process the +// same data so the warp doesn't matter. +// CHECK-LABEL: @gather_broadcasted_warp_2d +tt.func private @gather_broadcasted_warp_2d(%arg0: tensor<64x1xi32, #broadcasted_warp_2d>, %arg1: tensor<64x1xf32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // CHECK-NEXT: [[REGID0:%.*]] = and i32 [[IDX0]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX0]], 1 + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[TMP]], 31 + + // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID0]], i32 31) + // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID0]], i32 31) + + // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID0]], 0 + // CHECK-NEXT: select i1 [[PICK0]], i32 [[RES0_i32]], i32 [[RES1_i32]] + + // CHECK: [[REGID1:%.*]] = and i32 [[IDX1]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX1]], 1 + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[TMP]], 31 + + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID1]], i32 31) + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID1]], i32 31) + + // CHECK-NEXT: [[PICK1:%.*]] = icmp eq i32 [[REGID1]], 0 + // CHECK-NEXT: select i1 [[PICK1]], i32 [[RES0_i32]], i32 [[RES1_i32]] + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64x1xf32, #broadcasted_warp_2d>, tensor<64x1xi32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d> + tt.return %0 : tensor<64x1xf32, #broadcasted_warp_2d> +} + +// Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM +// from removing unused function results. +tt.func @anchor(%ptr: !llvm.ptr, + %arg0: tensor<32xi32, #trivial_layout>, + %arg1: tensor<32xf32, #trivial_layout>, + %arg2: tensor<64xi32, #trivial_layout_wider>, + %arg3: tensor<64xf32, #trivial_layout_wider>, + %arg4: tensor<64xf32, #trivial_layout_wider_reg_stride_1>, + %arg5: tensor<32x2xi32, #trivial_2d_one_col>, + %arg6: tensor<32x2xf32, #trivial_2d_one_col>, + %arg7: tensor<32x2xi32, #span_2d_cols>, + %arg8: tensor<32x2xf32, #span_2d_cols>, + %arg9: tensor<32x16xi32, #crazy_2d_idx>, + %arg10: tensor<32x16xf32, #crazy_2d_src>, + %arg11: tensor<16xi32, #broadcasted_lane_1d>, + %arg12: tensor<16xf32, #broadcasted_lane_1d>, + %arg13: tensor<64x1xi32, #broadcasted_warp_2d>, + %arg14: tensor<64x1xf32, #broadcasted_warp_2d>) { + + %0 = tt.call @gather_warp_local_trivial(%arg0, %arg1) : (tensor<32xi32, #trivial_layout>, tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + %1 = builtin.unrealized_conversion_cast %0 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %1, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %2 = tt.call @gather_warp_local_larger_output(%arg2, %arg1) : (tensor<64xi32, #trivial_layout_wider>, tensor<32xf32, #trivial_layout>) -> tensor<64xf32, #trivial_layout_wider> + %3 = builtin.unrealized_conversion_cast %2 : tensor<64xf32, #trivial_layout_wider> to !llvm.struct<(f32, f32)> + llvm.store volatile %3, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %4 = tt.call @gather_warp_local_larger_input(%arg0, %arg3) : (tensor<32xi32, #trivial_layout>, tensor<64xf32, #trivial_layout_wider>) -> tensor<32xf32, #trivial_layout> + %5 = builtin.unrealized_conversion_cast %4 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %5, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %6 = tt.call @gather_warp_local_larger_input_stride_1(%arg0, %arg4) : (tensor<32xi32, #trivial_layout>, tensor<64xf32, #trivial_layout_wider_reg_stride_1>) -> tensor<32xf32, #trivial_layout> + %7 = builtin.unrealized_conversion_cast %6 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %7, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %8 = tt.call @gather_2d_trivial(%arg5, %arg6) : (tensor<32x2xi32, #trivial_2d_one_col>, tensor<32x2xf32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> + %9 = builtin.unrealized_conversion_cast %8 : tensor<32x2xf32, #trivial_2d_one_col> to !llvm.struct<(f32, f32)> + llvm.store volatile %9, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %10 = tt.call @gather_2d_span_2(%arg7, %arg8) : (tensor<32x2xi32, #span_2d_cols>, tensor<32x2xf32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> + %11 = builtin.unrealized_conversion_cast %10 : tensor<32x2xf32, #span_2d_cols> to !llvm.struct<(f32, f32)> + llvm.store volatile %11, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %12 = tt.call @gather_2d_crazy(%arg9, %arg10) : (tensor<32x16xi32, #crazy_2d_idx>, tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> + %13 = builtin.unrealized_conversion_cast %12 : tensor<32x16xf32, #crazy_2d_idx> to !llvm.struct<(f32, f32, f32, f32)> + llvm.store volatile %13, %ptr : !llvm.struct<(f32, f32, f32, f32)>, !llvm.ptr + + %14 = tt.call @gather_broadcasted_lane_1d(%arg11, %arg12) : (tensor<16xi32, #broadcasted_lane_1d>, tensor<16xf32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d> + %15 = builtin.unrealized_conversion_cast %14 : tensor<16xf32, #broadcasted_lane_1d> to !llvm.struct<(f32)> + llvm.store volatile %15, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %16 = tt.call @gather_broadcasted_warp_2d(%arg13, %arg14) : (tensor<64x1xi32, #broadcasted_warp_2d>, tensor<64x1xf32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d> + %17 = builtin.unrealized_conversion_cast %16 : tensor<64x1xf32, #broadcasted_warp_2d> to !llvm.struct<(f32, f32)> + llvm.store volatile %17, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/Conversion/nvgpu_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/nvgpu_to_llvm.mlir new file mode 100644 index 000000000..11fe40bd7 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/nvgpu_to_llvm.mlir @@ -0,0 +1,187 @@ +// RUN: triton-opt %s --convert-nv-gpu-to-llvm -allow-unregistered-dialect -split-input-file | FileCheck %s + +// CHECK-LABEL: @nvvm_syncs +llvm.func @nvvm_syncs() { + // CHECK: wgmma.fence.sync.aligned; + nvgpu.wgmma_fence + + // CHECK: wgmma.commit_group.sync.aligned; + nvgpu.wgmma_commit_group + + // CHECK: barrier.cluster.wait.aligned; + nvgpu.cluster_wait + + // CHECK: fence.proxy.async.shared::cta; + nvgpu.fence_async_shared {bCluster = false} + // CHECK: fence.proxy.async.shared::cluster; + nvgpu.fence_async_shared {bCluster = true} + + // CHECK: barrier.cluster.arrive.aligned; + nvgpu.cluster_arrive {relaxed = false} + // CHECK: barrier.cluster.arrive.relaxed.aligned; + nvgpu.cluster_arrive {relaxed = true} + + llvm.return +} + +// CHECK-LABEL: @cluster_id +llvm.func @cluster_id() -> i32 { + // CHECK: %cluster_ctaid.x; + // CHECK-SAME: %cluster_ctaid.y; + // CHECK-SAME: %cluster_ctaid.z; + // CHECK-SAME: %cluster_nctaid.x; + // CHECK-SAME: %cluster_nctaid.y; + %id = nvgpu.cluster_id + llvm.return %id : i32 +} + +// ----- + +// CHECK-LABEL: @stmatrix +llvm.func @stmatrix(%i: i32, %ptr: !llvm.ptr<3>) { + // CHECK: stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4}; + nvgpu.stmatrix %ptr, %i, %i, %i, %i : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [$0], {$1, $2, $3, $4}; + nvgpu.stmatrix %ptr, %i, %i, %i, %i {trans} : !llvm.ptr<3>, i32, i32, i32, i32 + llvm.return +} + +// ----- + +// CHECK-LABEL: @ldmatrix +llvm.func @ldmatrix(%ptr: !llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> { + // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4]; + %0 = nvgpu.ldmatrix %ptr : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {$0, $1, $2, $3}, [$4]; + %1 = nvgpu.ldmatrix %ptr {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32, i32, i32)> + %3 = llvm.insertvalue %2, %0[0] : !llvm.struct<(i32, i32, i32, i32)> + llvm.return %3 : !llvm.struct<(i32, i32, i32, i32)> +} + +// ----- + +!struct_128xf32 = !llvm.struct<( + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32 +)> + +!struct_64xf32 = !llvm.struct<( + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32 +)> + +// CHECK-LABEL: @wgmma +llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) { +// CHECK: wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 +%false = llvm.mlir.constant(false) : i1 +%acc0 = nvgpu.wgmma %desc, %desc, %false { + eltTypeA = 3 : i32, + eltTypeB = 3 : i32, + eltTypeC = 7 : i32, + layoutA = 0 : i32, + layoutB = 1 : i32, + m = 64 : i32, + n = 256 : i32, + k = 32 : i32 +} : (i64, i64, i1) -> !struct_128xf32 + + // CHECK: // wait for regs: $0,$1,$2,{{.*}},$127 + // CHECK: wgmma.wait_group.sync.aligned 0; + %out = nvgpu.wgmma_wait_group %in {pendings = 0 : i32} : !struct_64xf32 + llvm.return +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @tensor_memory_base_lowering + // CHECK: %[[TID:.+]] = nvvm.read.ptx.sreg.tid.x : i32 + // CHECK: %[[C32:.+]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK: %[[PRED:.+]] = llvm.icmp "ult" %[[TID]], %[[C32]] : i32 + // CHECK: %[[SHMEM:.+]] = llvm.mlir.addressof @global_smem : !llvm.ptr<3> + // CHECK: %[[A:.+]] = llvm.inline_asm has_side_effects + // CHECK-SAME: "@$0 tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$1], 128;", "b,r" %[[PRED]], %[[SHMEM]] : (i1, !llvm.ptr<3>) -> !llvm.void + // CHECK: %[[AR:.+]] = llvm.load %[[SHMEM]] : !llvm.ptr<3> -> i32 + // CHECK: nvvm.barrier0 + // CHECK: "@$0 tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;", "b" %[[PRED]] : (i1) -> !llvm.void + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 tcgen05.dealloc.cta_group::1.sync.aligned.b32 $1, 128;", "b,r" %[[PRED]], %{{.+}} : (i1, !llvm.ptr<6>) -> !llvm.void + llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + llvm.func @tensor_memory_base_lowering() -> i32 attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = array} { + %263 = nvgpu.tensor_memory_base + %264 = llvm.ptrtoint %263 : !llvm.ptr<6> to i32 + llvm.return %264 : i32 + } +} + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: @warpid_warp_specialize +llvm.func @warpid_warp_specialize() { + // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK: [[TIDX:%.*]] = nvvm.read.ptx.sreg.tid.x + // CHECK: [[ID:%.*]] = llvm.udiv [[TIDX]], [[C32]] + // CHECK: [[UNIFORM:%.*]] = nvvm.shfl.sync idx {{%[0-9]+}}, [[ID]] + %0 = nvgpu.warp_id + // CHECK: "use"([[UNIFORM]]) + "use"(%0) : (i32) -> () + + // CHECK: ttg.warp_specialize + ttg.warp_specialize() attributes {warpGroupStartIds = array} + // CHECK: default + default { + // CHECK: [[TIDX:%.*]] = nvvm.read.ptx.sreg.tid.x + // CHECK: [[ID:%.*]] = llvm.udiv [[TIDX]], [[C32]] + // CHECK: [[UNIFORM:%.*]] = nvvm.shfl.sync idx {{%[0-9]+}}, [[ID]] + %1 = nvgpu.warp_id + // CHECK: "use"([[UNIFORM]]) + "use"(%1) : (i32) -> () + ttg.warp_yield + } + // CHECK: partition0 + partition0() num_warps(4) { + // 6*32 = 196 + + // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK: [[C192:%.*]] = llvm.mlir.constant(192 : i32) + // CHECK: [[TIDX:%.*]] = nvvm.read.ptx.sreg.tid.x + // CHECK: [[REL_TIDX:%.*]] = llvm.sub [[TIDX]], [[C192]] + // CHECK: [[ID:%.*]] = llvm.udiv [[REL_TIDX]], [[C32]] + // CHECK: [[UNIFORM:%.*]] = nvvm.shfl.sync idx {{%[0-9]+}}, [[ID]] + %1 = nvgpu.warp_id + // CHECK: "use"([[UNIFORM]]) + "use"(%1) : (i32) -> () + ttg.warp_return + } + partition1() num_warps(2) { + // 4*32 = 128 + + // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK: [[C128:%.*]] = llvm.mlir.constant(128 : i32) + // CHECK: [[TIDX:%.*]] = nvvm.read.ptx.sreg.tid.x + // CHECK: [[REL_TIDX:%.*]] = llvm.sub [[TIDX]], [[C128]] + // CHECK: [[ID:%.*]] = llvm.udiv [[REL_TIDX]], [[C32]] + // CHECK: [[UNIFORM:%.*]] = nvvm.shfl.sync idx {{%[0-9]+}}, [[ID]] + %1 = nvgpu.warp_id + // CHECK: "use"([[UNIFORM]]) + "use"(%1) : (i32) -> () + ttg.warp_return + } : () -> () + llvm.return +} + +} diff --git a/third_party/enflame/include/triton/test/Conversion/reduce_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/reduce_to_llvm.mlir new file mode 100644 index 000000000..0bbcecbd9 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/reduce_to_llvm.mlir @@ -0,0 +1,70 @@ +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s + +#linear = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: @reduce_linear_layout +tt.func private @reduce_linear_layout(%arg0: tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue {{.*}} %0, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue {{.*}} %0, 1 + // CHECK-NEXT: [[SRC2:%.*]] = extractvalue {{.*}} %0, 2 + // CHECK-NEXT: [[SRC3:%.*]] = extractvalue {{.*}} %0, 3 + + // The layout looks lke + // [[ T0:0, T32:0, T0:1, T32:1, ... + // [ T4:0, T36:0, T4:1, T36:1, ... + // [ T0:2, T32:2, T0:3, T32:3, ... + // [ T4:2, T36:2, T4:3, T36:3, + // ... + // + // A reduction along axis=0 consists of adding registers (0, 2) and (1, 3) + // before shuffling. + // + // Columns along axis=0 are contained within a warp, so reduction arcoss warps + // is not needed. + + // Reduce within threads + // CHECK-NEXT: [[SUM0:%.*]] = add i32 [[SRC0]], [[SRC2]] + // CHECK-NEXT: [[SUM1:%.*]] = add i32 [[SRC1]], [[SRC3]] + + // Reduce within warp. + // CHECK-NEXT: [[W0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM0]], i32 16, i32 31) + // CHECK-NEXT: [[WSUM0:%.*]] = add i32 [[W0]], [[SUM0]] + // CHECK-NEXT: [[W1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM0]], i32 8, i32 31) + // CHECK-NEXT: [[WSUM1:%.*]] = add i32 [[WSUM0]], [[W1]] + // CHECK-NEXT: [[W2:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM1]], i32 4, i32 31) + // CHECK-NEXT: [[WSUM2:%.*]] = add i32 [[WSUM1]], [[W2]] + // CHECK-NEXT: [[W3:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM2]], i32 2, i32 31) + // CHECK-NEXT: [[WSUM3:%.*]] = add i32 [[WSUM2]], [[W3]] + + // CHECK-NEXT: [[W4:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM1]], i32 16, i32 31) + // CHECK-NEXT: [[WSUM4:%.*]] = add i32 [[W4]], [[SUM1]] + // CHECK-NEXT: [[W5:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM4]], i32 8, i32 31) + // CHECK-NEXT: [[WSUM5:%.*]] = add i32 [[WSUM4]], [[W5]] + // CHECK-NEXT: [[W6:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM5]], i32 4, i32 31) + // CHECK-NEXT: [[WSUM6:%.*]] = add i32 [[WSUM5]], [[W6]] + // CHECK-NEXT: [[W7:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM6]], i32 2, i32 31) + // CHECK-NEXT: [[WSUM7:%.*]] = add i32 [[WSUM6]], [[W7]] + + // CHECK-NEXT: [[DST0:%.*]] = insertvalue { i32, i32 } undef, i32 [[WSUM3]], 0 + // CHECK-NEXT: [[DST1:%.*]] = insertvalue { i32, i32 } [[DST0]], i32 [[WSUM7]], 1 + + %0 = "tt.reduce"(%arg0) ({ + ^bb0(%arg1: i32, %arg2: i32): + %1 = arith.addi %arg1, %arg2 : i32 + tt.reduce.return %1 : i32 + }) {axis = 0 : i32} : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> + + // CHECK-NEXT: ret { i32, i32 } [[DST1]] + tt.return %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> +} + +tt.func @anchor(%ptr: !llvm.ptr, %arg0: tensor<32x16xi32, #linear>) { + %0 = tt.call @reduce_linear_layout(%arg0) : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> + %1 = builtin.unrealized_conversion_cast %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> to !llvm.struct<(i32, i32)> + llvm.store volatile %1, %ptr : !llvm.struct<(i32, i32)>, !llvm.ptr + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/Conversion/scan_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/scan_to_llvm.mlir new file mode 100644 index 000000000..60b06b664 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/scan_to_llvm.mlir @@ -0,0 +1,68 @@ +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --canonicalize | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s + +#layout = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}> +#layout_adj = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}> +#layout_2d = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 2], warpsPerCTA = [2, 1], order = [0,1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 16 : i32} { + +// CHECK-LABEL: @test_1d_simple +tt.func private @test_1d_simple(%arg0: tensor<8xi32, #layout>) -> tensor<8xi32, #layout> { + // CHECK: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK: [[LANEID_AXIS:%.*]] = and i32 [[TID]], 7 + // CHECK: icmp eq i32 [[LANEID_AXIS]], 0 + %0 = "tt.scan"(%arg0) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg1: i32, %arg2: i32): + %1 = arith.addi %arg1, %arg2 : i32 + tt.scan.return %1 : i32 + }) : (tensor<8xi32, #layout>) -> tensor<8xi32, #layout> + tt.return %0 : tensor<8xi32, #layout> +} + +// CHECK-LABEL: @test_1d_grouped +tt.func private @test_1d_grouped(%arg0: tensor<8xi32, #layout_adj>) -> tensor<8xi32, #layout_adj> { + // CHECK: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK: [[LANEID_AXIS:%.*]] = and i32 [[TID]], 3 + // CHECK: icmp eq i32 [[LANEID_AXIS]], 0 + %0 = "tt.scan"(%arg0) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg1: i32, %arg2: i32): + %1 = arith.addi %arg1, %arg2 : i32 + tt.scan.return %1 : i32 + }) : (tensor<8xi32, #layout_adj>) -> tensor<8xi32, #layout_adj> + tt.return %0 : tensor<8xi32, #layout_adj> +} + +// CHECK-LABEL: @test_2d_grouped +tt.func private @test_2d_grouped(%arg0: tensor<16x1xi32, #layout_2d>) -> tensor<16x1xi32, #layout_2d> { + // CHECK: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK: [[LANEID_AXIS:%.*]] = and i32 [[TID]], 7 + // CHECK: icmp eq i32 [[LANEID_AXIS]], 0 + %0 = "tt.scan"(%arg0) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg1: i32, %arg2: i32): + %1 = arith.addi %arg1, %arg2 : i32 + tt.scan.return %1 : i32 + }) : (tensor<16x1xi32, #layout_2d>) -> tensor<16x1xi32, #layout_2d> + tt.return %0 : tensor<16x1xi32, #layout_2d> +} + +// This just prevents the test functions from being DCE'd. +tt.func public @anchor(%ptr: !llvm.ptr, %arg0: !llvm.struct<(i32)>, %arg1: !llvm.struct<(i32, i32)>, %arg2: !llvm.struct<(i32)>) { + %0 = builtin.unrealized_conversion_cast %arg0 : !llvm.struct<(i32)> to tensor<8xi32, #layout> + %1 = tt.call @test_1d_simple(%0) : (tensor<8xi32, #layout>) -> tensor<8xi32, #layout> + %2 = builtin.unrealized_conversion_cast %1 : tensor<8xi32, #layout> to !llvm.struct<(i32)> + llvm.store volatile %2, %ptr : !llvm.struct<(i32)>, !llvm.ptr + + %3 = builtin.unrealized_conversion_cast %arg1 : !llvm.struct<(i32, i32)> to tensor<8xi32, #layout_adj> + %4 = tt.call @test_1d_grouped(%3) : (tensor<8xi32, #layout_adj>) -> tensor<8xi32, #layout_adj> + %5 = builtin.unrealized_conversion_cast %4 : tensor<8xi32, #layout_adj> to !llvm.struct<(i32, i32)> + llvm.store volatile %5, %ptr : !llvm.struct<(i32, i32)>, !llvm.ptr + + %6 = builtin.unrealized_conversion_cast %arg2 : !llvm.struct<(i32)> to tensor<16x1xi32, #layout_2d> + %7 = tt.call @test_2d_grouped(%6) : (tensor<16x1xi32, #layout_2d>) -> tensor<16x1xi32, #layout_2d> + %8 = builtin.unrealized_conversion_cast %7 : tensor<16x1xi32, #layout_2d> to !llvm.struct<(i32)> + llvm.store volatile %8, %ptr : !llvm.struct<(i32)>, !llvm.ptr + + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/Conversion/tma_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/tma_to_llvm.mlir new file mode 100644 index 000000000..a22db0fe3 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/tma_to_llvm.mlir @@ -0,0 +1,177 @@ +// RUN: triton-opt %s --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#linear = #ttg.linear<{register = [[1], [2], [16], [0]], lane = [[0], [0], [0], [0], [0]], warp = [[4], [8]], block = []}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { + +// CHECK-LABEL: @tma_gather_simple +// CHECK-SAME: i32 [[Y0:%3]] +tt.func @tma_gather_simple(%arg0: !tt.ptr, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { + // There are 32 indices distributed to 4 warps, so each warp as 8 indices. + + // CHECK: [[BAR:%.*]] = extractvalue {{.*}} %1, 0 + // CHECK: [[BASE_PTR:%.*]] = extractvalue {{.*}} %4, 0 + + // CHECK: [[TIDX:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK: [[WIDX:%.*]] = lshr i32 [[TIDX]], 5 + // CHECK: [[WARP_ID:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[WIDX]], + + // CHECK: [[ELECT:%.*]] = tail call { i32, i1 } @llvm.nvvm.elect.sync + // CHECK: [[ELECT_PRED:%.*]] = extractvalue { i32, i1 } [[ELECT]], 1 + // CHECK: [[PRED:%.*]] = and i1 %5, [[ELECT_PRED]] + + // CHECK: [[IDX0:%.*]] = extractvalue {{.*}} %2, 0 + // CHECK: [[IDX1:%.*]] = extractvalue {{.*}} %2, 1 + // CHECK: [[IDX2:%.*]] = extractvalue {{.*}} %2, 2 + // CHECK: [[IDX3:%.*]] = extractvalue {{.*}} %2, 3 + + // CHECK: [[IDX4:%.*]] = extractvalue {{.*}} %2, 4 + // CHECK: [[IDX5:%.*]] = extractvalue {{.*}} %2, 5 + // CHECK: [[IDX6:%.*]] = extractvalue {{.*}} %2, 6 + // CHECK: [[IDX7:%.*]] = extractvalue {{.*}} %2, 7 + + // There are 32x128 = 4096 elements. Each gather4 will read 4*128/2 = 256 + // elements into smem. We need to issue 16 gather4 messages. Each warp will + // execute 4 gather4 instructions. + // + // The 64-element (128-byte) row segments are organized into shared memory + // by segments. I.e. + // + // [ t[0, 0:128], t[1: 0:128], ..., t[31: 0:128], t[0, 128:256], ..., t[31: 128:256] ]. + // + // This is captured by the `nvmma_shared` smem layout. + // + // Each warp will handle 4 consecutive row segments at a time, or 4*128 bytes + // per transaction, thus reading: + // + // t[warpId, 0:128], t[warpId, 128:256], t[warpId+16, 0:128], t[warpId+16, 128:256] + // + // Each group of 4 segments are 4*128/2 = 256 elements apart. So the starting + // addresses are [x, x+2048, x+1024, x+3072], where `x = warpId*256`. + // + // Note that result smem layout has a swizzle tile of [8, 64], and 8 such + // tiles comprise the result space. That means every other group of 4 row + // segments land in the middle of a swizzle tile, where the 0th logical column + // element may not be at the start of the tile. + + // CHECK: [[WARP_STRIDE_TMP:%.*]] = shl i32 [[WARP_ID]], 8 + // CHECK: [[WARP_STRIDE:%.*]] = and i32 [[WARP_STRIDE_TMP]], 768 + + // CHECK: [[OFFSET0:%.*]] = zext nneg i32 [[WARP_STRIDE]] to i64 + // CHECK: [[BASEPTR0:%.*]] = getelementptr bfloat, ptr addrspace(3) [[BASE_PTR]], i64 [[OFFSET0]] + // CHECK: "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4, $5, $6, $7}], [$8];", "b,r,l,r,r,r,r,r,r" + // CHECK-SAME: (i1 [[PRED]], ptr addrspace(3) [[BASEPTR0]], ptr addrspace(1) %0, i32 [[Y0]], i32 [[IDX0]], i32 [[IDX1]], i32 [[IDX2]], i32 [[IDX3]], ptr addrspace(3) [[BAR]]) + + // CHECK: [[OFFSET1_TMP:%.*]] = or disjoint i32 [[WARP_STRIDE]], 2048 + // CHECK: [[OFFSET1:%.*]] = zext nneg i32 [[OFFSET1_TMP]] to i64 + // CHECK: [[BASEPTR1:%.*]] = getelementptr bfloat, ptr addrspace(3) [[BASE_PTR]], i64 [[OFFSET1]] + // CHECK: [[Y1:%.*]] = add i32 [[Y0]], 64 + // CHECK: cp.async.bulk.tensor.2d.tile::gather4 + // CHECK-SAME: (i1 [[PRED]], ptr addrspace(3) [[BASEPTR1]], ptr addrspace(1) %0, i32 [[Y1]], i32 [[IDX0]], i32 [[IDX1]], i32 [[IDX2]], i32 [[IDX3]], ptr addrspace(3) [[BAR]]) + + // CHECK: [[OFFSET2_TMP:%.*]] = or disjoint i32 [[WARP_STRIDE]], 1024 + // CHECK: [[OFFSET2:%.*]] = zext nneg i32 [[OFFSET2_TMP]] to i64 + // CHECK: [[BASEPTR2:%.*]] = getelementptr bfloat, ptr addrspace(3) [[BASE_PTR]], i64 [[OFFSET2]] + // CHECK: cp.async.bulk.tensor.2d.tile::gather4 + // CHECK-SAME: (i1 [[PRED]], ptr addrspace(3) [[BASEPTR2]], ptr addrspace(1) %0, i32 [[Y0]], i32 [[IDX4]], i32 [[IDX5]], i32 [[IDX6]], i32 [[IDX7]], ptr addrspace(3) [[BAR]]) + + // CHECK: [[OFFSET3_TMP:%.*]] = or disjoint i32 [[WARP_STRIDE]], 3072 + // CHECK: [[OFFSET3:%.*]] = zext nneg i32 [[OFFSET3_TMP]] to i64 + // CHECK: [[BASEPTR3:%.*]] = getelementptr bfloat, ptr addrspace(3) [[BASE_PTR]], i64 [[OFFSET3]] + // CHECK: cp.async.bulk.tensor.2d.tile::gather4 + // CHECK-SAME: (i1 [[PRED]], ptr addrspace(3) [[BASEPTR3]], ptr addrspace(1) %0, i32 [[Y1]], i32 [[IDX4]], i32 [[IDX5]], i32 [[IDX6]], i32 [[IDX7]], ptr addrspace(3) [[BAR]]) + ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 : !tt.ptr, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 + + // CHECK-NEXT: ret void + tt.return +} + +// CHECK-LABEL: @tma_gather_8_consecutive_indices +tt.func @tma_gather_8_consecutive_indices(%arg0: !tt.ptr, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { + // Due to the `sizePerThread = [1, 8]`, each warp now handles 8 consecutive + // rows, where each row is divided into 2 segments for a total of 4 gather4s. + // + // t[warpId, 0:128], t[warpId, 128:256], t[warpId+4, 0:128], t[warpId+4, 128:256] + // + // So the base addresses are [x, x+2048, x+256, x+2048+256], where `x = warpId*256`. + + // CHECK: [[WARP_ID:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32 + // CHECK: [[WARP_STRIDE_TMP:%.*]] = shl i32 [[WARP_ID]], 9 + // CHECK: [[OFFSET0:%.*]] = and i32 [[WARP_STRIDE_TMP]], 1536 + + // CHECK: zext nneg i32 [[OFFSET0]] to i64 + // CHECK: cp.async.bulk.tensor + + // CHECK: [[OFFSET1:%.*]] = or disjoint i32 [[OFFSET0]], 2048 + // CHECK: zext nneg i32 [[OFFSET1]] to i64 + // CHECK: cp.async.bulk.tensor + + // CHECK: [[OFFSET2:%.*]] = or disjoint i32 [[OFFSET0]], 256 + // CHECK: zext nneg i32 [[OFFSET2]] to i64 + // CHECK: cp.async.bulk.tensor + + // CHECK: [[OFFSET3:%.*]] = or disjoint i32 [[OFFSET0]], 2304 + // CHECK: zext nneg i32 [[OFFSET3]] to i64 + // CHECK: cp.async.bulk.tensor + ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 : !tt.ptr, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 + + // CHECK-NEXT: ret void + tt.return +} + +// CHECK-LABEL: @tma_gather_redundant_indices +tt.func @tma_gather_redundant_indices(%arg0: !tt.ptr, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #linear>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { + // Codegen for this case is actually incorrect due to linear layouts + // incorrectly handling register broadcasting, but the test outcome is nonetheless + // the same. + + // CHECK-COUNT-4: cp.async.bulk.tensor + ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 : !tt.ptr, tensor<32xi32, #linear>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 + // CHECK-NEXT: ret void + tt.return +} + +// CHECK-LABEL: @tma_gather_redundant_warps +tt.func @tma_gather_redundant_warps(%arg0: !tt.ptr, %arg1: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %arg2: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, %arg3: i32, %arg4: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, %arg5: i1) { + // CHECK: [[WARP_ID:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32 + // CHECK: [[WARP_SELECT:%.*]] = and i32 [[WARP_ID]], 2 + // CHECK: [[WARP_PRED:%.*]] = icmp eq i32 [[WARP_SELECT]], 0 + // CHECK: [[PRED_TMP:%.*]] = and i1 %5, [[WARP_PRED]] + // CHECK: [[ELECT:%.*]] = tail call { i32, i1 } @llvm.nvvm.elect.sync + // CHECK: [[ELECT_PRED:%.*]] = extractvalue { i32, i1 } [[ELECT]], 1 + // CHECK: [[PRED:%.*]] = and i1 [[ELECT_PRED]], [[PRED_TMP]] + + // CHECK-COUNT-8: cp.async.bulk.tensor{{.*}}(i1 [[PRED]], + ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 : !tt.ptr, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1 + + // CHECK-NEXT: ret void + tt.return +} + +// CHECK-LABEL: @tma_scatter +tt.func @tma_scatter(%arg0: !tt.ptr, %arg1: tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, %arg2: i32, %arg3: !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>) { + // The lowering for `async_tma_scatter` shares practically all of its logic + // with `async_tma_gather`, so we don't need to re-test the indexing logic. + + // CHECK: [[BASE_PTR:%.*]] = extractvalue {{.*}} %3, 0 + // CHECK: [[ELECT:%.*]] = tail call { i32, i1 } @llvm.nvvm.elect.sync + // CHECK: [[PRED:%.*]] = extractvalue { i32, i1 } [[ELECT]], 1 + + // CHECK: [[PTR:%.*]] = getelementptr {{.*}} [[BASE_PTR]] + // CHECK-NEXT: "@$0 cp.async.bulk.tensor.2d.tile::scatter4.global.shared::cta.bulk_group [$1, {$2, $3, $4, $5, $6}], [$7];" + // CHECK-SAME: (i1 [[PRED]], ptr addrspace(1) %0, i32 %2, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, ptr addrspace(3) [[PTR]]) + ttng.async_tma_scatter %arg0[%arg1, %arg2] %arg3 : !tt.ptr, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable> + + // CHECK: nvvm.cp.async.bulk.commit.group() + + // CHECK-NEXT: ret void + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/Conversion/triton_to_tritongpu.mlir b/third_party/enflame/include/triton/test/Conversion/triton_to_tritongpu.mlir new file mode 100644 index 000000000..5c677a05f --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/triton_to_tritongpu.mlir @@ -0,0 +1,176 @@ +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=cuda:80 num-warps=2' | FileCheck %s + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { +tt.func @ops() { + // CHECK: module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {{.*}} + %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> + %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> + %c = arith.constant dense<3.00e+00> : tensor<128x128xf32> + %0 = tt.dot %a, %b, %c : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> + tt.return +} +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { +tt.func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { + // Test if LoadOp is lowered properly (see #771) + %ptrs = tt.splat %ptr : !tt.ptr -> tensor<128x!tt.ptr> + %mask = arith.constant dense : tensor<128xi1> + %other = arith.constant dense<0.0e+0> : tensor<128xf32> + // CHECK: %{{.*}} = tt.load %{{.*}} : {{.*}} + %a = tt.load %ptrs : tensor<128x!tt.ptr> + // CHECK: %{{.*}} = tt.load %{{.*}}, %{{.*}} : {{.*}} + %b = tt.load %ptrs, %mask : tensor<128x!tt.ptr> + // CHECK: %{{.*}} = tt.load %{{.*}}, %{{.*}}, %{{.*}} : {{.*}} + %c = tt.load %ptrs, %mask, %other : tensor<128x!tt.ptr> + tt.store %ptrs, %a : tensor<128x!tt.ptr> + tt.store %ptrs, %b : tensor<128x!tt.ptr> + tt.store %ptrs, %c : tensor<128x!tt.ptr> + tt.return +} +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { +tt.func @reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { + // Test if the total number of threadsPerWarp is 32 + // Test if the total number of warps is 2 + // CHECK: #[[blocked0:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> + // CHECK: #[[blocked1:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> + // CHECK: #[[blocked2:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> + // CHECK: module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {{.*}} + %c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32> + %c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32> + %c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32> + // CHECK: (tensor<4x4xf32, #[[blocked0]]>) -> tensor<4xf32, #ttg.slice<{dim = 0, parent = #[[blocked0]]}>> + %c0_ = "tt.reduce" (%c0) ({ + ^bb0(%arg1: f32, %arg2: f32): + %add = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<4x4xf32>) -> tensor<4xf32> + // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<2xf32, #ttg.slice<{dim = 0, parent = #[[blocked1]]}> + %c1_ = "tt.reduce" (%c1) ({ + ^bb0(%arg3: f32, %arg4: f32): + %add = arith.addf %arg3, %arg4 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<8x2xf32>) -> tensor<2xf32> + // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<8xf32, #ttg.slice<{dim = 1, parent = #[[blocked1]]}>> + %c2_ = "tt.reduce" (%c1) ({ + ^bb0(%arg5: f32, %arg6: f32): + %add = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<8x2xf32>) -> tensor<8xf32> + // CHECK: (tensor<16x16xf32, #[[blocked2]]>) -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[blocked2]]}>> + %c3_ = "tt.reduce" (%c2) ({ + ^bb0(%arg7: f32, %arg8: f32): + %add = arith.addf %arg7, %arg8 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<16x16xf32>) -> tensor<16xf32> + + tt.return +} +} + + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { +tt.func public @select_op(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i1) attributes {noinline = false} { + // CHECK-LABEL: select_op + %cst = arith.constant dense<0.000000e+00> : tensor<128xf32> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %3 = tt.load %2 : tensor<128x!tt.ptr> + + // CHECK: %{{.*}} = arith.select %arg2, %{{.*}}, %{{.*}} : tensor<128xf32, #blocked> + %4 = arith.select %arg2, %cst, %3 : tensor<128xf32> + + %5 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %6, %4 : tensor<128x!tt.ptr> + tt.return +} +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { +tt.func @arith_splat_bool(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK-LABEL: arith_splat_bool + + // Test arith.constant with splatted bool. + // CHECK-NEXT: arith.constant dense : tensor<128xi1, #{{.*}}> + %mask = arith.constant dense : tensor<128xi1> + tt.return +} +} + +// ----- + +// CHECK-LABEL: gather_op +tt.func @gather_op() { + %cst = arith.constant dense<1.0> : tensor<128x4xf32> + %cst_0 = arith.constant dense<1> : tensor<256x4xi32> + // CHECK: tt.gather %{{.*}}[%{{.*}}] {axis = 0 : i32} : (tensor<128x4xf32, #blocked>, tensor<256x4xi32, #blocked>) -> tensor<256x4xf32, #blocked> + %0 = tt.gather %cst[%cst_0] {axis = 0 : i32} : (tensor<128x4xf32>, tensor<256x4xi32>) -> tensor<256x4xf32> + tt.return +} + +// ----- + +// CHECK: [[SLICE_PARENT:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [1, 0]}> + +// CHECK: @gather4_layout +tt.func @gather4_layout(%arg0: !tt.tensordesc>, %arg1: i32, %arg2: !tt.ptr) { + %cst = arith.constant dense<1> : tensor<32xi32> + // CHECK: [[IDX:%.*]] = ttg.convert_layout %cst : tensor<32xi32, #{{.*}}> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = [[SLICE_PARENT]]}>> + %0 = tt.experimental_descriptor_gather %arg0[%cst, %arg1] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x128xf32> + %1 = tt.splat %arg2 : !tt.ptr -> tensor<32x128x!tt.ptr> + tt.store %1, %0 : tensor<32x128x!tt.ptr> + tt.return +} + +// CHECK: @scatter4_layout +tt.func @scatter4_layout(%arg0: !tt.tensordesc>, %arg1: i32, %arg2: !tt.ptr) { + %cst = arith.constant dense<1> : tensor<32xi32> + %0 = tt.splat %arg2 : !tt.ptr -> tensor<32x128x!tt.ptr> + %1 = tt.load %0 : tensor<32x128x!tt.ptr> + // CHECK: [[IDX:%.*]] = ttg.convert_layout %cst : tensor<32xi32, #{{.*}}> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = [[SLICE_PARENT]]}>> + tt.experimental_descriptor_scatter %arg0[%cst, %arg1], %1 : !tt.tensordesc>, tensor<32xi32>, i32, tensor<32x128xf32> + tt.return +} + +// ----- + +// CHECK-LABEL: @ub_poison +tt.func @ub_poison() { + // CHECK-NEXT: ub.poison : tensor<128x64xf16, #blocked> + %0 = ub.poison : tensor<128x64xf16> + tt.return +} + +// ----- + +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> + +module attributes {"ttg.num-warps" = 4 : i32} { + +// CHECK-LABEL: @partition_axis_info +tt.func @partition_axis_info(%arg0: !tt.ptr, %arg1: !tt.ptr) { + ttg.warp_specialize(%arg0) + default { + ttg.warp_yield + } + partition0(%arg2: !tt.ptr) num_warps(2) { + %splatted = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> + %input = tt.load %splatted : tensor<256x!tt.ptr, #blocked2> + ttg.warp_return + } : (!tt.ptr) -> () + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm.mlir new file mode 100644 index 000000000..ea0716c4b --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm.mlir @@ -0,0 +1,2282 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s --dump-input-context 20 + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>) + // Here the 128 comes from the 4 in module attribute multiples 32 + // CHECK: nvvm.kernel = 1 : ui1, nvvm.reqntid = array + tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { + // CHECK: llvm.return + tt.return + } +} // end module + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_load + tt.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: mov.u32 $0, $1; + // CHECK-SAME: @$3 ld.global.b32 { $0 }, [ $2 + 0 ];", "=r,r,l,b" + // CHECK: llvm.inline_asm + %1 = tt.load %a_ptr_init, %cst, %cst_0 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: vectorized_load + tt.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: ld.global.b32 + // CHECK: llvm.inline_asm + // CHECK-SAME: ld.global.b32 + %1 = tt.load %a_ptr_init, %cst, %cst_0 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: vectorized_load_f16 + tt.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: ld.global.b16 + // CHECK: llvm.inline_asm + // CHECK-SAME: ld.global.b16 + %1 = tt.load %a_ptr_init, %cst, %cst_0 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +// TODO: masked load with vectorization is pending on TODO +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: masked_load_const_other + tt.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> + %1 = tt.load %a_ptr_init, %cst, %cst_0 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +// TODO: masked load with vectorization is pending on TODO +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: masked_load_const_other_vec + tt.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> + %1 = tt.load %a_ptr_init, %cst, %cst_0 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: store_with_cache_attr + tt.func @store_with_cache_attr(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: st.global.L1::evict_last.b32 + // CHECK: llvm.inline_asm + // CHECK-SAME: st.global.L1::evict_last.b32 + tt.store %a_ptr_init, %cst_0, %cst evictionPolicy = evict_last cacheModifier = ca : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { + // CHECK-LABEL: global_load_store_no_vec + tt.func @global_load_store_no_vec(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Load 4 elements from vector0 + // CHECK: mov.u32 $0, 0x0 + // CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: mov.u32 $0, 0x0 + // CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: mov.u32 $0, 0x0 + // CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: mov.u32 $0, 0x0 + // CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + + // Load 4 elements from vector1 + // CHECK: mov.u32 $0, 0x0 + // CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: mov.u32 $0, 0x0 + // CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: mov.u32 $0, 0x0 + // CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: mov.u32 $0, 0x0 + // CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> + %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Store 4 elements to global + // CHECK: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; + // CHECK: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; + // CHECK: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; + // CHECK: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; + tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { + // CHECK-LABEL: global_load_store_vec4 + tt.func @global_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Load 4 elements from A with single one vectorized load instruction + // CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + // Load 4 elements from B with single one vectorized load instruction + // CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> + %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Store 4 elements to global with single one vectorized store instruction + // CHECK: st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }; + tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +// This test verifies the vectorization of Load and Store Ops. +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +// Note, the %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes it's divisibility is 1, this should effect the mask's alignment and further restrict the load/store ops' vector width to be 1. +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { + tt.func @vecadd_masked_vec1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<64xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<64xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<64x!tt.ptr, #blocked>, tensor<64xi32, #blocked> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr, #blocked>, tensor<64xi32, #blocked> + %9 = tt.splat %n_elements : i32 -> tensor<64xi32, #blocked> + %10 = arith.cmpi "slt", %4, %9 : tensor<64xi32, #blocked> + // load op has a vector width = 1 due to the %mask's alignment + // CHECK: ld.global.b32 + %11 = tt.load %6, %10 : tensor<64x!tt.ptr, #blocked> + %12 = tt.load %8, %10 : tensor<64x!tt.ptr, #blocked> + %13 = arith.addf %11, %12 : tensor<64xf32, #blocked> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<64x!tt.ptr, #blocked> + %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr, #blocked>, tensor<64xi32, #blocked> + tt.store %15, %13, %10 : tensor<64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: global_load_store_vec2 + tt.func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Load 8 elements from A with four vectorized load instruction + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + // Load 8 elements from B with four vectorized load instruction + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> + %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Store 8 elements to global with four vectorized store instruction + // CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + // CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + // CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + // CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: global_load_store_vec2 + tt.func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Load 8 elements from A with four vectorized load instruction + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + // Load 8 elements from B with four vectorized load instruction + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> + %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Store 8 elements to global with four vectorized store instruction + // CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + // CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + // CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + // CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} }; + tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: global_load_store_vec8 + tt.func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Load 8 elements from A with two vectorized load instruction + // CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + // Load 8 elements from B with two vectorized load instruction + // CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> + %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + + // Store 8 elements to global with two vectorized store instruction + // CHECK: st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }; + // CHECK: st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }; + tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +// Slice layout with 2 unique elements, but 8 total elements per thread +#blocked2d = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +#slice = #ttg.slice<{dim = 1, parent = #blocked2d}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { + // CHECK-LABEL: global_load_store_slice + tt.func @global_load_store_slice(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #slice> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #slice> + %4 = arith.addi %3, %2 : tensor<128xi32, #slice> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr, #slice> + %6 = tt.addptr %5, %4 : tensor<128x!tt.ptr, #slice>, tensor<128xi32, #slice> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr, #slice> + %8 = tt.addptr %7, %4 : tensor<128x!tt.ptr, #slice>, tensor<128xi32, #slice> + + // Load 2 element from vector0 without predicate + // CHECK: mov.u32 $0, 0x0 + // CHECK-NOT: @{{.*}} ld.global + // CHECK-COUNT-2: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + + // Load 2 elements from vector1 without predicate + // CHECK: mov.u32 $0, 0x0 + // CHECK-NOT: @{{.*}} ld.global + // CHECK-COUNT-2: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + %9 = tt.load %6 : tensor<128x!tt.ptr, #slice> + %10 = tt.load %8 : tensor<128x!tt.ptr, #slice> + %11 = arith.addf %9, %10 : tensor<128xf32, #slice> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<128x!tt.ptr, #slice> + %13 = tt.addptr %12, %4 : tensor<128x!tt.ptr, #slice>, tensor<128xi32, #slice> + + // Store 2 element to global without predicate + // CHECK-NOT: @{{.*}} st.global + // CHECK-COUNT-2: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; + tt.store %13, %11 : tensor<128x!tt.ptr, #slice> + tt.return + } +} + +// TODO: Add a testcase to verify the optimization when ptr of the LoadOp +// is from an addptr with const idx + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_view_broadcast + tt.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) { + // CHECK: llvm.mlir.undef + // CHECK: %[[T0:.*]] = llvm.extractvalue + // CHECK: %[[T1:.*]] = llvm.extractvalue + %0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2> + // CHECK: llvm.mlir.undef + // CHECK: llvm.insertvalue %[[T0]] + // CHECK: llvm.insertvalue %[[T1]] + // CHECK: llvm.insertvalue %[[T0]] + // CHECK: llvm.insertvalue %[[T1]] + // CHECK: llvm.insertvalue %[[T0]] + // CHECK: llvm.insertvalue %[[T1]] + // CHECK: llvm.insertvalue %[[T0]] + // CHECK: llvm.insertvalue %[[T1]] + %1 = tt.broadcast %0 : tensor<256x1xf32,#blocked2> -> tensor<256x4xf32, #blocked2> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: basic_make_range + tt.func @basic_make_range() { + // CHECK: nvvm.read.ptx.sreg.tid.x + // CHECK: llvm.mlir.undef + // CHECK: llvm.insertvalue + // CHECK: llvm.insertvalue + %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + tt.return + } +} + + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: sliced_layout_make_range + tt.func @sliced_layout_make_range() { + // CHECK: nvvm.read.ptx.sreg.tid.x + // CHECK: llvm.mlir.undef + // CHECK: llvm.insertvalue + // CHECK: llvm.insertvalue + // CHECK: llvm.insertvalue + // CHECK: llvm.insertvalue + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked0}>> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_addf + tt.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) { + // CHECK: llvm.fadd + // CHECK: llvm.fadd + %1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_addi + tt.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { + // CHECK: llvm.add + // CHECK: llvm.add + %1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0> + tt.return + } +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_program_id + tt.func @basic_program_id() { + // CHECK: llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.x"() : () -> i32 + %0 = tt.get_program_id x : i32 + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_addptr + tt.func @basic_addptr(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { + // CHECK: llvm.getelementptr + // CHECK: llvm.getelementptr + %0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + tt.return + } +} + +// ----- + +#shared0 = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK: llvm.mlir.global external @global_smem + // CHECK-LABEL: basic_alloc_tensor + tt.func @basic_alloc_tensor() { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK-NEXT: llvm.getelementptr + // CHECK-NEXT: llvm.mlir.constant + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #shared0, #smem, mutable> + tt.return + } +} + +// ----- + +#shared0 = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK: llvm.mlir.global external @global_smem + // CHECK-LABEL: basic_subview + tt.func @basic_subview() { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: llvm.extractvalue + // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: llvm.mlir.constant(32 : i32) : i32 + // CHECK-NEXT: llvm.mlir.constant(512 : i32) : i32 + // CHECK-NEXT: llvm.add + // CHECK-NEXT: llvm.add + // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: llvm.mul + // CHECK-NEXT: llvm.add + // CHECK-NEXT: llvm.mul + // CHECK-NEXT: llvm.add + // CHECK-NEXT: llvm.mul + // CHECK-NEXT: llvm.add + // CHECK-NEXT: llvm.getelementptr + %index = arith.constant 1 : i32 + %zero = arith.constant 0 : i32 + %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16x32xf32, #shared0, #smem, mutable> + %1 = ttg.memdesc_subview %0[%index, %zero, %zero] : !ttg.memdesc<128x16x32xf32, #shared0, #smem, mutable> -> !ttg.memdesc<16x32xf32, #shared0, #smem, mutable> + tt.return + } +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_async_wait + tt.func @basic_async_wait() { + // CHECK: nvvm.cp.async.wait.group 4 + ttg.async_wait {num = 4: i32} + tt.return + } +} + +// ----- + +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#slice1d0 = #ttg.slice<{dim = 0, parent = #blocked1}> +#shared1D = #ttg.swizzled_shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#shared2D = #ttg.swizzled_shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: basic_insert_slice_async_1d + tt.func @basic_insert_slice_async_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<64> : tensor<64xi32, #slice1d0> + %58 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr, #slice1d0> + %24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1d0> + %59 = tt.addptr %58, %24 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> + %66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> + %71 = ttg.local_alloc : () -> !ttg.memdesc<2x64xi64, #shared2D, #smem, mutable> + %subview = ttg.memdesc_subview %71[%c0_i32, %c0_i32] : + !ttg.memdesc<2x64xi64, #shared2D, #smem, mutable> -> + !ttg.memdesc<64xi64, #shared1D, #smem, mutable> + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK: nvvm.cp.async.commit.group + %73 = ttg.async_copy_global_to_local %66, %subview : tensor<64x!tt.ptr, #slice1d0> -> !ttg.memdesc<64xi64, #shared1D, #smem, mutable> + ttg.async_commit_group %73 + tt.return + } +} + +// ----- + +#block0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice2d1 = #ttg.slice<{dim = 1, parent=#block2}> +#slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> +#AL = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_insert_slice_async_v4 + tt.func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 32 : i32}) { + %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> + %off1_ = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0> + %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : tensor<16xi32, #slice2d1> -> tensor<16x1xi32, #block2> + %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : tensor<64xi32, #slice3d0> -> tensor<1x64xi32, #block3> + %broadcast_off0_scalar = tt.broadcast %off0 : tensor<16x1xi32, #block2> -> tensor<16x64xi32, #block2> + %cst_scalar = arith.constant 64 : i32 + %cst = tt.splat %cst_scalar : i32 -> tensor<16x64xi32, #block2> + %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x64xi32, #block2> + %broadcast_off1_ = tt.broadcast %off1 : tensor<1x64xi32, #block3> -> tensor<16x64xi32, #block3> + %broadcast_off0 = ttg.convert_layout %broadcast_off0_ : tensor<16x64xi32, #block2> -> tensor<16x64xi32, #AL> + %broadcast_off1 = ttg.convert_layout %broadcast_off1_ : tensor<16x64xi32, #block3> -> tensor<16x64xi32, #AL> + %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL> + %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x64x!tt.ptr, #AL> + %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL>, tensor<16x64xi32, #AL> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x64xf32, #A, #smem, mutable> + %index = arith.constant 1 : i32 + + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;" + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10;" + // CHECK: nvvm.cp.async.commit.group + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !ttg.memdesc<16x64xf32, #A, #smem, mutable> + ttg.async_commit_group + tt.return + } +} + +// ----- + +#block0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice2d1 = #ttg.slice<{dim = 1, parent=#block2}> +#slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_insert_slice_async_v1 + tt.func @basic_insert_slice_async_v1(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { + %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> + %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0> + %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : tensor<16xi32, #slice2d1> -> tensor<16x1xi32, #block2> + %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : tensor<32xi32, #slice3d0> -> tensor<1x32xi32, #block3> + %broadcast_off0_scalar = tt.broadcast %off0 : tensor<16x1xi32, #block2> -> tensor<16x32xi32, #block2> + %cst_scalar = arith.constant 32 : i32 + %cst = tt.splat %cst_scalar : i32 -> tensor<16x32xi32, #block2> + %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<16x32xi32, #block2> + %broadcast_off1_ = tt.broadcast %off1 : tensor<1x32xi32, #block3> -> tensor<16x32xi32, #block3> + %broadcast_off0 = ttg.convert_layout %broadcast_off0_ : tensor<16x32xi32, #block2> -> tensor<16x32xi32, #AL> + %broadcast_off1 = ttg.convert_layout %broadcast_off1_ : tensor<16x32xi32, #block3> -> tensor<16x32xi32, #AL> + %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL> + %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x32x!tt.ptr, #AL> + %a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr, #AL>, tensor<16x32xi32, #AL> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x32xf32, #A, #smem, mutable> + %index = arith.constant 1 : i32 + + // CHECK: llvm.inline_asm + // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: nvvm.cp.async.commit.group + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !ttg.memdesc<16x32xf32, #A, #smem, mutable> + ttg.async_commit_group + tt.return + } +} + +// ----- + +#block0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#block2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#block3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice2d1 = #ttg.slice<{dim = 1, parent=#block2}> +#slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#A = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_insert_slice_async_v1_multictas + tt.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { + %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1> + %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0> + %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : tensor<32xi32, #slice2d1> -> tensor<32x1xi32, #block2> + %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : tensor<32xi32, #slice3d0> -> tensor<1x32xi32, #block3> + %broadcast_off0_scalar = tt.broadcast %off0 : tensor<32x1xi32, #block2> -> tensor<32x32xi32, #block2> + %cst_scalar = arith.constant 32 : i32 + %cst = tt.splat %cst_scalar : i32 -> tensor<32x32xi32, #block2> + %broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<32x32xi32, #block2> + %broadcast_off1_ = tt.broadcast %off1 : tensor<1x32xi32, #block3> -> tensor<32x32xi32, #block3> + %broadcast_off0 = ttg.convert_layout %broadcast_off0_ : tensor<32x32xi32, #block2> -> tensor<32x32xi32, #AL> + %broadcast_off1 = ttg.convert_layout %broadcast_off1_ : tensor<32x32xi32, #block3> -> tensor<32x32xi32, #AL> + %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL> + %a_init = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #A, #smem, mutable> + %index = arith.constant 1 : i32 + + // CHECK: llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.mlir.constant(16 : i32) : i32 + // CHECK: llvm.mul + // CHECK: llvm.add + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4;" + // CHECK: llvm.inline_asm + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: nvvm.cp.async.commit.group + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !ttg.memdesc<32x32xf32, #A, #smem, mutable> + ttg.async_commit_group + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK: basic_splat + tt.func @basic_splat(%ptr: !tt.ptr) { + // CHECK: llvm.mlir.undef + // CHECK: llvm.insertvalue + // CHECK: llvm.insertvalue + %0 = tt.splat %ptr : !tt.ptr -> tensor<256x!tt.ptr,#blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: basic_store + tt.func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; + // CHECK: llvm.inline_asm + // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; + tt.store %ptrs, %vals, %mask : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK: llvm.mlir.global external @global_smem + // CHECK-LABEL: convert_layout_blocked_blocked + tt.func @convert_layout_blocked_blocked(%arg0: tensor<32x32xf32, #blocked0>) { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared + // CHECK-: nvvm.barrier0 + // CHECK-COUNT-8: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked0> -> tensor<32x32xf32, #blocked1> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK: llvm.mlir.global external @global_smem + // CHECK-LABEL: convert_layout_blocked_blocked_vec + tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<32x32xf32, #blocked0>) { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + // CHECK: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked0> -> tensor<32x32xf32, #blocked1> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + +// CHECK-LABEL: convert_layout_ptr_element +tt.func @convert_layout_ptr_element(%arg0: tensor<16x16x!tt.ptr, #blocked0>) { + // CHECK: llvm.ptrtoint + // CHECK: llvm.inttoptr + %0 = ttg.convert_layout %arg0 : tensor<16x16x!tt.ptr, #blocked0> -> tensor<16x16x!tt.ptr, #blocked2> + tt.return +} + +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK: llvm.mlir.global external @global_smem + // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep + tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + // CHECK: llvm.load + // CHECK: nvvm.barrier0 + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + // CHECK: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_dot_ldmatrix + tt.func @convert_dot_ldmatrix(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { + %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> + %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> + // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK-NOT: nvgpu.ldmatrix + %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> + %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> + + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + %D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> + + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.swizzled_shared<{vec = 8, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_dot + tt.func @convert_dot_ldmatrix_swizzle(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { + %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> + %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> + // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK-NOT: nvgpu.ldmatrix + %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> + %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> + + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + %D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> + + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_dot + tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { + %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> + %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> + // CHECK-NOT: nvgpu.ldmatrix + %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> + %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> + + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + %D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> + + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_dot_mmav3_shared + tt.func @convert_dot_mmav3_shared(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) { + %AA = ttg.local_alloc %A : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem> + %BB = ttg.local_alloc %B : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem> + // CHECK-NOT: nvgpu.ldmatrix + %AA_DOT = ttg.local_load %AA : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_a> + %BB_DOT = ttg.local_load %BB : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_b> + %cst0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0> + + %D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<64x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<64x64xf32, #mma0> + + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.swizzled_shared<{vec = 16, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=4}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=4}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_dot_fp8 + tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) { + %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> + %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> + // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK-NOT: nvgpu.ldmatrix + %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a> + %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> + + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 + %D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf8E5M2, #dot_operand_a> * tensor<16x16xf8E5M2, #dot_operand_b> -> tensor<16x16xf32, #mma0> + + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK: llvm.mlir.global external @global_smem + // CHECK-LABEL: convert_layout_mmav2_block + tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: st.shared + // CHECK: llvm.inline_asm + // CHECK-SAME: st.shared + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<1x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<1x16xf16, #mma> -> tensor<1x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#slice = #ttg.slice<{dim = 0, parent = #mma}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_slice_mmav2_blocked_reg + tt.func @convert_layout_slice_mmav2_blocked_reg(%arg0: tensor<1xf16, #slice>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<1xf16, #slice> -> tensor<1xf16, #blocked> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_0 + tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_1 + tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_2 + tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_3 + tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_0 + tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_1 + tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_2 + tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_3 + tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = ttg.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK: llvm.mlir.global external @global_smem + // CHECK-LABEL: convert_layout_mmav3_transpose + tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) { + // CHECK-COUNT-16: st.shared.b8 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load {{.*}} -> vector<4xi32> + // CHECK-COUNT-16: st.shared.b8 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load {{.*}} -> vector<4xi32> + // CHECK-COUNT-16: st.shared.b8 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load {{.*}} -> vector<4xi32> + // CHECK-COUNT-16: st.shared.b8 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load {{.*}} -> vector<4xi32> + // CHECK-COUNT-16: st.shared.b8 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load {{.*}} -> vector<4xi32> + // CHECK-COUNT-16: st.shared.b8 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load {{.*}} -> vector<4xi32> + // CHECK-COUNT-16: st.shared.b8 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load {{.*}} -> vector<4xi32> + // CHECK-COUNT-16: st.shared.b8 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load {{.*}} -> vector<4xi32> + %0 = ttg.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked> + tt.return + } +} + +// ----- +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK: llvm.mlir.global external @global_smem + // CHECK-LABEL: convert_layout_blocked_shared + tt.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr<3> + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_blocked1d_to_slice0 + tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { + // CHECK: llvm.load {{.*}} -> vector<4xi32> + %cvt = ttg.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_blocked1d_to_slice1 + tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { + // CHECK-COUNT-8: llvm.load {{.*}} -> i32 + %cvt = ttg.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_blocked_to_blocked_ptr + tt.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr, #blocked0>) { + // CHECK: llvm.ptrtoint + // CHECK: inline_asm{{.*}}st.shared + // CHECK: nvvm.barrier0 + // CHECK: llvm.inttoptr + // CHECK-COUNT-4: llvm.insertvalue + %cvt = ttg.convert_layout %src : tensor<32x!tt.ptr, #blocked0> -> tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// Regression test for https://github.com/triton-lang/triton/issues/5745 +#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], warp = [[1, 0], [2, 0], [4, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 2]], lane = [[0, 0], [0, 0], [0, 0], [0, 0], [1, 0]], warp = [[2, 0], [4, 0], [0, 1]], block = []}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: linear_layout_with_multiple_iterations + tt.func @linear_layout_with_multiple_iterations(%src: tensor<8x4xbf16, #linear>) { + %cvt = ttg.convert_layout %src : tensor<8x4xbf16, #linear> -> tensor<8x4xbf16, #linear1> + // CHECK: inline_asm{{.*}}st.shared.v2 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + // CHECK: nvvm.barrier0 + // CHECK: inline_asm{{.*}}st.shared.v2 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, + %a:!ttg.memdesc<128x32xf16, #shared, #smem>, %b:!ttg.memdesc<32x256xf16, #shared, #smem>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + // CHECK: nvgpu.ldmatrix + %a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b> + + %28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> + %38 = ttg.convert_layout %28 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> + + %30 = tt.splat %ptr : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %36 = tt.broadcast %30 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked> + tt.store %36, %38 : tensor<128x256x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#blocked}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#blocked}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, + %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + // CHECK: llvm.intr.fmuladd + %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b> + + %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = ieee : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked> + %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + tt.store %36, %28 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=1}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: matmul_tf32dot + tt.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, + %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // CHECK: nvgpu.ldmatrix + // CHECK-SAME: (i32, i32, i32, i32) + // CHECK: nvgpu.ldmatrix + // CHECK-SAME: (i32, i32, i32, i32) + %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b> + + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 + %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> + %38 = ttg.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + + %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + tt.store %36, %38 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + // CHECK-LABEL: atomic_add_f32 + tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32 + %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0> + tt.return + } +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + // CHECK-LABEL: atomic_add_f32_scalar + tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { + // CHECK: llvm.icmp "eq" + // CHECK: llvm.inline_asm + // CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32 + %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (!tt.ptr, f32, i1) -> f32 + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + // CHECK-LABEL: atomic_add_f32 + tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: @$3 atom.global.sys.relaxed.add.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: @$3 atom.global.sys.relaxed.add.f32 + %0 = tt.atomic_rmw fadd, relaxed, sys, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_nomask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: store_f32 + tt.func @store_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: st.global.b32 + // CHECK: llvm.inline_asm + // CHECK-SAME: st.global.b32 + tt.store %arg0, %arg1 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: store_f32_scalar + tt.func @store_f32_scalar(%arg0 : !tt.ptr, %arg1 : f32) { + // CHECK: llvm.icmp "eq" + // CHECK: llvm.inline_asm + // CHECK-SAME: @$2 st.global.b32 + tt.store %arg0, %arg1 : !tt.ptr + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +// CHECK-LABEL: test_get_program_id +tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { + %blockidx = tt.get_program_id x: i32 + %blockidy = tt.get_program_id y: i32 + %blockidz = tt.get_program_id z: i32 + // CHECK: ctaid.x + // CHECK: ctaid.y + // CHECK: ctaid.z + %v0 = arith.addi %blockidx, %blockidy : i32 + %v1 = arith.addi %v0, %blockidz : i32 + %0 = tt.splat %v1 : i32 -> tensor<32xi32, #blocked0> + tt.store %a, %0 : tensor<32x!tt.ptr, #blocked0> + + tt.return +} + +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} { +// CHECK-LABEL: test_get_program_id +tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { + %blockidx = tt.get_program_id x: i32 + %blockidy = tt.get_program_id y: i32 + %blockidz = tt.get_program_id z : i32 + // CHECK: clusterid.x + // CHECK: clusterid.y + // CHECK: clusterid.z + %v0 = arith.addi %blockidx, %blockidy : i32 + %v1 = arith.addi %v0, %blockidz : i32 + %0 = tt.splat %v1 : i32 -> tensor<32xi32, #blocked0> + tt.store %a, %0 : tensor<32x!tt.ptr, #blocked0> + + tt.return +} + +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: test_get_num_program + tt.func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { + %blockdimx = tt.get_num_programs x : i32 + %blockdimy = tt.get_num_programs y : i32 + %blockdimz = tt.get_num_programs z : i32 + // CHECK: nctaid.x + // CHECK: nctaid.y + // CHECK: nctaid.z + %v0 = arith.addi %blockdimx, %blockdimy : i32 + %v1 = arith.addi %v0, %blockdimz : i32 + %0 = tt.splat %v1 : i32 -> tensor<32xi32, #blocked0> + tt.store %a, %0 : tensor<32x!tt.ptr, #blocked0> + + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} { + tt.func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { + %blockdimx = tt.get_num_programs x : i32 + %blockdimy = tt.get_num_programs y : i32 + %blockdimz = tt.get_num_programs z : i32 + // CHECK: nclusterid.x + // CHECK: nclusterid.y + // CHECK: nclusterid.z + %v0 = arith.addi %blockdimx, %blockdimy : i32 + %v1 = arith.addi %v0, %blockdimz : i32 + %0 = tt.splat %v1 : i32 -> tensor<32xi32, #blocked0> + tt.store %a, %0 : tensor<32x!tt.ptr, #blocked0> + + tt.return + } +} + +// ----- +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: test_index_cache + tt.func @test_index_cache() { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + tt.return + } +} + +// ----- +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: test_base_index_cache + tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> + %1 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> + tt.return + } +} + +// ----- +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: test_index_cache_different_block + tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> + cf.cond_br %arg1, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + %1 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> + cf.br ^bb2 + ^bb2: // 2 preds: ^bb0, ^bb1 + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=1}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: matmul_tf32_cst_b + tt.func @matmul_tf32_cst_b(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, + %a: tensor<32x16xf32, #dot_operand_a>, %c: tensor<32x32xf32, #mma>) { + // CHECK: %[[CST:.+]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 + // CHECK: %[[BC:.+]] = llvm.bitcast %[[CST]] : f32 to f32 + // CHECK: %[[SI:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + %b_mat = arith.constant dense<1.000000e+00> : tensor<16x32xf32, #dot_operand_b> + %28 = tt.dot %a, %b_mat, %c, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> + %38 = ttg.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + tt.store %36, %38 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: matmul_f16_cst_operands + tt.func public @matmul_f16_cst_operands(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // CHECK: %[[U:.+]] = llvm.mlir.undef : vector<2xf16> + // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[V0:.+]] = llvm.insertelement %{{.*}}, %[[U]][%[[C0]] : i32] : vector<2xf16> + // CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[V1:.+]] = llvm.insertelement %{{.*}}, %[[V0]][%[[C1]] : i32] : vector<2xf16> + // CHECK: %[[BC:.+]] = llvm.bitcast %[[V1]] : vector<2xf16> to i32 + %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_2 = arith.constant dense<32> : tensor<32x1xi32, #blocked> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %4 = arith.muli %3, %cst_2 : tensor<32x1xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %8 = tt.expand_dims %7 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %9 = tt.broadcast %6 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %10 = tt.broadcast %8 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %11 = tt.addptr %9, %10 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %12 = arith.truncf %1 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked> + tt.store %11, %12 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: test_s8_to_bf16_conversion + tt.func @test_s8_to_bf16_conversion(%in: tensor<32xi8, #blocked>) { + // We can't vectorize if we only process + // CHECK-NOT: llvm.inline_asm + // CHECK: llvm.sitofp + // CHECK-NOT: llvm.sitofp + %out = arith.sitofp %in : tensor<32xi8, #blocked> to tensor<32xbf16, #blocked> + tt.return + } +} + +// ----- +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: test_s8_to_bf16_vectorized_conversion + tt.func @test_s8_to_bf16_vectorized_conversion(%in: tensor<16x16xi8, #mma>) { + // CHECK-NOT: llvm.sitofp + // 8 elements per thread => we should process 2 vectors of 4 + // CHECK: llvm.inline_asm + // CHECK: llvm.inline_asm + // CHECK-NOT: llvm.inline_asm + %out = arith.sitofp %in : tensor<16x16xi8, #mma> to tensor<16x16xbf16, #mma> + tt.return + } +} + +// ----- + +// CHECK-LABEL: sum_reduction +// CHECK: %[[M:.+]] = llvm.mlir.constant(-1 : i32) : i32 +// CHECK: nvvm.redux.sync add %{{.*}}, %[[M]] +// CHECK: nvvm.barrier0 +// CHECK: nvvm.shfl.sync bfly +// CHECK: nvvm.shfl.sync bfly +// CHECK: nvvm.barrier0 +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @sum_reduction(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<1024> : tensor<1x1xi32, #blocked> + %0 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1> + %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1xi32, #blocked> + %3 = arith.muli %2, %cst : tensor<1x1xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<1x1x!tt.ptr, #blocked>, tensor<1x1xi32, #blocked> + %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<1024xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x1024xi32, #blocked> + %8 = tt.broadcast %5 : tensor<1x1x!tt.ptr, #blocked> -> tensor<1x1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %7 : tensor<1x1024x!tt.ptr, #blocked>, tensor<1x1024xi32, #blocked> + %10 = tt.load %9 : tensor<1x1024x!tt.ptr, #blocked> + %11 = "tt.reduce"(%10) <{axis = 1 : i32}> ({ + ^bb0(%arg2: i32, %arg3: i32): + %15 = arith.addi %arg2, %arg3 : i32 + tt.reduce.return %15 : i32 + }) : (tensor<1x1024xi32, #blocked>) -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %12 = ttg.convert_layout %11 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + %13 = tt.splat %arg1 : !tt.ptr -> tensor<1x!tt.ptr, #blocked1> + %14 = tt.addptr %13, %0 : tensor<1x!tt.ptr, #blocked1>, tensor<1xi32, #blocked1> + tt.store %14, %12 : tensor<1x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#slice = #ttg.slice<{dim = 1, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { + // CHECK-LABEL: reduce_bools + tt.func public @reduce_bools(%arg: tensor<256x2xi1, #blocked>) { + // CHECK: llvm.mlir.addressof @global_smem + %24 = "tt.reduce"(%arg) <{axis = 1 : i32}> ({ + ^bb0(%arg4: i1, %arg5: i1): + %48 = arith.ori %arg4, %arg5 : i1 + tt.reduce.return %48 : i1 + }) : (tensor<256x2xi1, #blocked>) -> tensor<256xi1, #slice> + tt.return + } +} + + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: inline_asm + tt.func public @inline_asm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<512x!tt.ptr, #blocked> + %2 = tt.addptr %1, %0 : tensor<512x!tt.ptr, #blocked>, tensor<512xi32, #blocked> + %3 = tt.load %2 : tensor<512x!tt.ptr, #blocked> +// CHECK: %{{.*}} = llvm.inline_asm asm_dialect = att "shl.b32 $0, $0, 3;", "=r,r" %{{.*}} : (vector<4xi8>) -> vector<4xi8> + %4 = tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" {constraints = "=r,r", packed_element = 4 : i32, pure = true} %3 : tensor<512xi8, #blocked> -> tensor<512xi8, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<512x!tt.ptr, #blocked> + %6 = tt.addptr %5, %0 : tensor<512x!tt.ptr, #blocked>, tensor<512xi32, #blocked> + tt.store %6, %4 : tensor<512x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: inline_asm_pack_16bit + tt.func public @inline_asm_pack_16bit(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<512x!tt.ptr, #blocked> + %2 = tt.addptr %1, %0 : tensor<512x!tt.ptr, #blocked>, tensor<512xi32, #blocked> + %3 = tt.load %2 : tensor<512x!tt.ptr, #blocked> +// CHECK: %{{.*}} = llvm.inline_asm asm_dialect = att "shl.b16 $0, $0, 3;", "=h,h" %{{.*}} : (vector<2xi8>) -> vector<2xi8> + %4 = tt.elementwise_inline_asm "shl.b16 $0, $0, 3;" {constraints = "=h,h", packed_element = 2 : i32, pure = true} %3 : tensor<512xi8, #blocked> -> tensor<512xi8, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<512x!tt.ptr, #blocked> + %6 = tt.addptr %5, %0 : tensor<512x!tt.ptr, #blocked>, tensor<512xi32, #blocked> + tt.store %6, %4 : tensor<512x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK-LABEL: reduce_slice +// CHECK-NOT: st.shared +// CHECK-NOT: ld.shared +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}> +#sliced2 = #ttg.slice<{dim = 2, parent = #blocked}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @reduce_slice() attributes {noinline = false} { + %cst = arith.constant dense : tensor<4x1xi1, #sliced2> + %0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({ + ^bb0(%arg0: i1, %arg1: i1): + %1 = arith.ori %arg0, %arg1 : i1 + tt.reduce.return %1 : i1 + }) : (tensor<4x1xi1, #sliced2>) -> tensor<4xi1, #ttg.slice<{dim = 1, parent = #sliced2}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: reduce_md_slice +// CHECK: st.shared +// CHECK: st.shared +// CHECK: ld.shared +// CHECK: st.shared +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 2, 2], order = [2, 1, 0]}> +#sliced = #ttg.slice<{dim = 2, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @reduce_md_slice(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<2x128xf32, #ttg.slice<{dim = 2, parent = #blocked}>> + %0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %18 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %18 : f32 + }) {allocation.offset = 0 : i32} : (tensor<2x128xf32, #sliced>) -> tensor<2xf32, #ttg.slice<{dim = 1, parent = #sliced}>> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #ttg.swizzled_shared<{vec = 8, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @i16_mma_layout(%f16_inp: tensor<16x16xf16, #blocked0>, %i16_inp: tensor<16x16xi16, #blocked0>) { + // CHECK-LABEL: @i16_mma_layout + + %f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> + %i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, #smem> + + // CHECK: nvgpu.ldmatrix + // CHECK: nvgpu.ldmatrix + + %f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> + %i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #smem> -> tensor<16x16xi16, #dot_operand_b> + + // CHECK: llvm.sitofp %{{.*}} : i16 to f16 + + %converted_i16 = arith.sitofp %i16_dot : tensor<16x16xi16, #dot_operand_b> to tensor<16x16xf16, #dot_operand_b> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + + %out = tt.dot %f16_dot, %converted_i16, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma> + + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +module attributes {"ttg.target" = "cuda:75", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: convert_single_element + // CHECK-NOT: llvm.store + // CHECK-NOT: llvm.load + // CHECK: llvm.return + tt.func public @convert_single_element() attributes {noinline = false} { + %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1> + %0 = ttg.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +module attributes {"ttg.target" = "cuda:75", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: convert_single_element_and_add + // CHECK-NOT: llvm.store + // CHECK-NOT: llvm.load + // CHECK: llvm.insertvalue + // CHECK: llvm.extractvalue + tt.func public @convert_single_element_and_add() attributes {noinline = false} { + %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1> + %cst2 = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked> + %0 = ttg.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> + %1 = arith.addf %0, %cst2 : tensor<1xf32, #blocked> + tt.return + } +} + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @vectorize_shmem_load + // CHECK: llvm.load + // CHECK-SAME: {alignment = 8 : i64} : !llvm.ptr<3> -> vector<8xi8> + // CHECK-NOT: llvm.load + tt.func public @vectorize_shmem_load(%shmem : !ttg.memdesc<16x16xi8, #shared, #smem>) { + %0 = ttg.local_load %shmem : !ttg.memdesc<16x16xi8, #shared, #smem> -> tensor<16x16xi8, #blocked> + tt.return + } +} + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @vectorize_shmem_store + // CHECK: llvm.store + // CHECK-SAME: {alignment = 64 : i64} : vector<16xi32>, !llvm.ptr<3> + // CHECK-NOT: llvm.store + tt.func public @vectorize_shmem_store(%block : tensor<64x64xi32, #blocked>) { + %0 = ttg.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !ttg.memdesc<64x64xi32, #shared, #smem> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: abs_is_int_min_poison + // CHECK: %{{.*}} = "llvm.intr.abs"(%{{.*}}) <{is_int_min_poison = false}> : (i32) -> i32 + tt.func @abs_is_int_min_poison(%arg0 : tensor<256xi32, #blocked0>) { + %abs = math.absi %arg0 : tensor<256xi32, #blocked0> + tt.return + } +} + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: test_local_load_bf16 + // CHECK: llvm.extractelement {{.*}} : vector<8xbf16> + tt.func public @test_local_load_bf16() { + %c0_i32 = arith.constant 0 : i32 + %19 = ttg.local_alloc : () -> !ttg.memdesc<1x1x2048xbf16, #shared, #smem, mutable> + %22 = ttg.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x1x2048xbf16, #shared, #smem, mutable> -> !ttg.memdesc<1x2048xbf16, #shared, #smem, mutable> + %39 = ttg.local_load %22 : !ttg.memdesc<1x2048xbf16, #shared, #smem, mutable> -> tensor<1x2048xbf16, #blocked> + %40 = arith.extf %39 : tensor<1x2048xbf16, #blocked> to tensor<1x2048xf32, #blocked> + tt.return + } +} + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: test_local_store + // CHECK: llvm.store + tt.func public @test_local_store(%arg0: tensor<1xf32, #blocked>) { + %c0_i32 = arith.constant 0 : i32 + %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + ttg.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + tt.return + } +} + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @tensor_memory_st + // CHECK: nvgpu.tensor_memory_base + // CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32 + // CHECK: tcgen05.wait::st.sync.aligned + tt.func public @tensor_memory_st(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %0 = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %true = arith.constant true + ttng.tmem_store %cst_0, %0, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: test_local_store_subview + // CHECK: llvm.store + tt.func public @test_local_store_subview(%arg0: tensor<1xf32, #blocked>) { + %c0_i32 = arith.constant 0 : i32 + %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + %sv = ttg.memdesc_subview %0[%c0_i32] : !ttg.memdesc<1xf32, #shared, #smem, mutable> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + ttg.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: print_ptr + // CHECK: llvm.call @vprintf(%{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 + tt.func @print_ptr(%arg0 : tensor<256x!tt.ptr, #blocked0>) { + tt.print "ptr: " {hex = false, isSigned = array} : %arg0 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // Test that %u format specifier is used if isSigned is false + // CHECK: llvm.mlir.global internal constant @printfFormat_0("{{.*}}int32 tensor: %u{{.*}}") + // CHECK-LABEL: print_int32_tensor_issigned_off + // CHECK: llvm.call @vprintf(%{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 + tt.func @print_int32_tensor_issigned_off(%arg0 : i32) { + tt.print "int32 tensor: " {hex = false, isSigned = array} : %arg0 : i32 + tt.return + } +} + +// ----- +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // Test that %i format specifier is used if isSigned is true + // CHECK: llvm.mlir.global internal constant @printfFormat_0("{{.*}}int32 tensor: %i{{.*}}") + // CHECK-LABEL: print_int32_tensor_issigned_on + // CHECK: llvm.call @vprintf(%{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 + tt.func @print_int32_tensor_issigned_on(%arg0 : i32) { + tt.print "int32 tensor: " {hex = false, isSigned = array} : %arg0 : i32 + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func @int32_to_bf16(%arg0: tensor<256xi32, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: @int32_to_bf16 + // CHECK: llvm.sitofp %{{.*}} : i32 to bf16 + %a = arith.sitofp %arg0 : tensor<256xi32, #blocked> to tensor<256xbf16, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func @bf16_to_int32(%arg0: tensor<256xbf16, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: @bf16_to_int32 + // CHECK: llvm.fptosi %{{.*}} : bf16 to i32 + %a = arith.fptosi %arg0 : tensor<256xbf16, #blocked> to tensor<256xi32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32} +// CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32} +// CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32} +// CHECK: llvm.call @__assertfail +// CHECK: nvvm.barrier0 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) { + tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5) + tt.return + } +} +#loc1 = loc("outer_call":33:8) +#loc2 = loc("top_func":47:8) +#loc3 = loc("inner_call":29:28) +#loc4 = loc(callsite(#loc3 at #loc1)) +#loc5 = loc(callsite(#loc4 at #loc2)) + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @log1pf_scan(%39: tensor<32x16xf32, #blocked>) attributes {noinline = false} { + // CHECK: log1pf_scan + // non-speculatable ops will introduce a cond_br; extern_elementwise with pure = true should be considered speculatable. + // CHECK-NOT: llvm.cond_br + %40 = "tt.scan"(%39) <{axis = 1 : i32, reverse = false}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %43 = tt.extern_elementwise %arg5 {libname = "", libpath = "", pure = true, symbol = "__nv_log1pf"} : (f32) -> f32 + %44 = arith.addf %43, %43 : f32 + tt.scan.return %44 : f32 + }) : (tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked> + tt.return + } +} + +// ----- + +// CHECK: inline_asm_pack +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // check specifically for the case where asm has two results, pack > 1, and the result bitwidth is < 32 + tt.func public @inline_asm_pack(%80: tensor<64x64xi8, #blocked>) attributes {noinline = false} { + // CHECK: llvm.inline_asm asm_dialect {{.*}} (vector<4xi8>) -> !llvm.struct<(vector<2xbf16>, vector<2xbf16>, vector<2xbf16>, vector<2xbf16>)> + %83:2 = tt.elementwise_inline_asm "" {constraints = "=r,=r,=r,=r,r", packed_element = 4 : i32, pure = true} %80 : tensor<64x64xi8, #blocked> -> tensor<64x64xbf16, #blocked>, tensor<64x64xbf16, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4xf32, #blocked>) { + // CHECK-LABEL: gather_in_shared + + // CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0] + + // CHECK: [[SMEM_BASE:%.*]] = llvm.mlir.addressof @global_smem + // CHECK-NEXT: [[SMEM:%.*]] = llvm.getelementptr [[SMEM_BASE]] + // CHECK: store [[S0]] + // CHECK-NEXT: nvvm.barrier0 + + // CHECK: [[I0:%.*]] = llvm.extractvalue %arg0[0] + + // CHECK: [[IDX:%.*]] = llvm.add {{.*}}, [[I0]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM]][[[IDX]]] + // CHECK-NEXT: [[OUT0:%.*]] = llvm.load [[PTR]] + + // CHECK: insertvalue [[OUT0]], {{.*}}[0] + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #blocked>, tensor<16x4xi32, #blocked1>) -> tensor<16x4xf32, #blocked1> + tt.return +} + +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}> +#dot = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +tt.func @gather_in_shared_dot_input(%arg0: tensor<16x4xi32, #blocked>, %arg1: tensor<8x4xf32, #dot>) { + // CHECK-LABEL: gather_in_shared_dot_input + + // CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0] + // CHECK: [[S1:%.*]] = llvm.extractvalue %arg1[1] + // CHECK: [[S2:%.*]] = llvm.extractvalue %arg1[2] + // CHECK: [[S3:%.*]] = llvm.extractvalue %arg1[3] + + // CHECK: [[SMEM_BASE:%.*]] = llvm.mlir.addressof @global_smem + // CHECK-NEXT: [[SMEM:%.*]] = llvm.getelementptr [[SMEM_BASE]] + // CHECK: store [[S0]] + // CHECK: store [[S1]] + // CHECK: store [[S2]] + // CHECK: store [[S3]] + // CHECK-NEXT: nvvm.barrier0 + + // CHECK: [[I0:%.*]] = llvm.extractvalue %arg0[0] + + // CHECK: [[IDX:%.*]] = llvm.add {{.*}}, [[I0]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM]][[[IDX]]] + // CHECK-NEXT: [[OUT0:%.*]] = llvm.load [[PTR]] + + // CHECK: insertvalue [[OUT0]], {{.*}}[0] + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #dot>, tensor<16x4xi32, #blocked>) -> tensor<16x4xf32, #blocked> + tt.return +} + +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + + tt.func public @ampere_s8_to_fp16_conversion_opIdx1(%1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) attributes {noinline = false} { + // CHECK-LABEL: ampere_s8_to_fp16_conversion_opIdx1 + // CHECK: llvm.sitofp %{{.*}} : i8 to f16 + %2 = arith.sitofp %1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + tt.return +} + +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @ampere_s8_to_fp16_conversion_opIdx0(%1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>) attributes {noinline = false} { + // CHECK-LABEL: @ampere_s8_to_fp16_conversion_opIdx0 + // CHECK: llvm.sitofp %{{.*}} : i8 to f16 + %2 = arith.sitofp %1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0 , parent = #mma, kWidth = 4}>> to tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [4, 4, 2], warpsPerCTA = [8, 1, 1], order = [2, 1, 0]}> +#linear = #ttg.linear<{register = [[0, 0], [0, 0], [0, 0], [0, 0]], lane = [[0, 0], [0, 1], [0, 2], [1, 0], [2, 0]], warp = [[4, 0], [8, 0], [16, 0]], block = []}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: expand_dims_linear_layout +tt.func private @expand_dims_linear_layout() -> tensor<1x4xi32, #linear> { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x4xi32, #linear> + // CHECK: return %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + tt.return %1 : tensor<1x4xi32, #linear> +} + +// CHECK-LABEL: reshape_linear_layout_broadcasting +tt.func private @reshape_linear_layout_broadcasting(%arg0: tensor<32x4xbf16, #linear>) -> tensor<32x4x1xbf16, #blocked> { + // CHECK-COUNT-16: extractvalue + // CHECK-COUNT-16: insertvalue + %0 = tt.reshape %arg0 : tensor<32x4xbf16, #linear> -> tensor<32x4x1xbf16, #blocked> + tt.return %0 : tensor<32x4x1xbf16, #blocked> +} + +} + + +// ----- + +#linear1 = #ttg.linear<{register = [[0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0], [16, 0, 0, 0], [32, 0, 0, 0], [64, 0, 0, 0]], lane = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0]], warp = [[4, 0, 0, 0], [8, 0, 0, 0]], block = []}> +#linear2 = #ttg.linear<{register = [[0, 0, 1], [0, 1, 0], [16, 0, 0], [32, 0, 0], [64, 0, 0]], lane = [[0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 0, 0]], warp = [[4, 0, 0], [8, 0, 0]], block = []}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: split_linear +tt.func @split_linear(%arg : tensor<128x2x2x2xf32, #linear1>) { + // CHECK: %[[E0:.+]] = llvm.extractvalue %{{.*}}[0] + // CHECK: %[[E1:.+]] = llvm.extractvalue %{{.*}}[1] + // CHECK: %[[E2:.+]] = llvm.extractvalue %{{.*}}[2] + // CHECK: %[[E3:.+]] = llvm.extractvalue %{{.*}}[3] + // CHECK: llvm.insertvalue %[[E0]], %{{.*}}[0] + // CHECK: llvm.insertvalue %[[E2]], %{{.*}}[1] + // CHECK: llvm.insertvalue %[[E1]], %{{.*}}[0] + // CHECK: llvm.insertvalue %[[E3]], %{{.*}}[1] + %outLHS, %outRHS = tt.split %arg : tensor<128x2x2x2xf32, #linear1> -> tensor<128x2x2xf32, #linear2> + tt.return +} +} diff --git a/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_blackwell.mlir b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_blackwell.mlir new file mode 100644 index 000000000..a4890fc95 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_blackwell.mlir @@ -0,0 +1,411 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=100 -cse | FileCheck %s + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: @tc_gen5_mma + // CHECK: %[[WID:.+]] = nvgpu.warp_id + // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[P0:.+]] = llvm.icmp "eq" %[[WID]], %[[C0]] : i32 + // CHECK: %[[P1:.+]] = llvm.and %{{.*}}, %[[P0]] : i1 + // CHECK: llvm.cond_br %[[P1]] + // CHECK: %[[E:.+]] = nvvm.elect.sync -> i1 + // CHECK-COUNT-8: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[E]] + // CHECK: @$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];", "b,l" %[[E]] + tt.func @tc_gen5_mma(%a: !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, + %b: !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory>, + %c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, + %useAcc: i1, + %pred: i1, + %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) { + ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier : + (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, + !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory>, + !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, + i1, i1, + !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) -> () + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: @tc_gen5_mma_multi_m_n + // CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32 + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-DAG: %[[C64:.+]] = llvm.mlir.constant(64 : i32) : i32 + // CHECK-DAG: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32 + // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T0]] + // CHECK: %[[T1:.+]] = llvm.add %[[TMEM_BASE]], %[[C64]] : i32 + // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T1]] + // 1048576 = row << 16 + col = 16 << 16 + 0 + // CHECK: %[[C1048576:.+]] = llvm.mlir.constant(1048576 : i32) : i32 + // CHECK: %[[T2:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048576]] : i32 + // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T2]] + // 1048640 = row << 16 + col = 16 << 16 + 64 + // CHECK: %[[C1048640:.+]] = llvm.mlir.constant(1048640 : i32) : i32 + // CHECK: %[[T3:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048640]] : i32 + // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T3]] + + tt.func @tc_gen5_mma_multi_m_n(%a: !ttg.memdesc<128x16xf16, #shared, #ttg.shared_memory>, + %b: !ttg.memdesc<16x128xf16, #shared1, #ttg.shared_memory>, + %c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, + %useAcc: i1, + %pred: i1, + %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) { + ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier : + (!ttg.memdesc<128x16xf16, #shared, #ttg.shared_memory>, + !ttg.memdesc<16x128xf16, #shared1, #ttg.shared_memory>, + !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, + i1, i1, + !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) -> () + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 2], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [2], CTASplitNum = [1], CTAOrder = [0]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: @tc_gen5_mma_multi_ctas + // CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32 + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-DAG: %[[C32:.+]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK-DAG: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32 + // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T0]] + // CHECK: %[[T1:.+]] = llvm.add %[[TMEM_BASE]], %[[C32]] : i32 + // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T1]] + // 1048576 = row << 16 + col = 16 << 16 + 0 + // CHECK: %[[C1048576:.+]] = llvm.mlir.constant(1048576 : i32) : i32 + // CHECK: %[[T2:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048576]] : i32 + // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T2]] + // 1048640 = row << 16 + col = 16 << 16 + 32 + // CHECK: %[[C1048608:.+]] = llvm.mlir.constant(1048608 : i32) : i32 + // CHECK: %[[T3:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048608]] : i32 + // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T3]] + + tt.func @tc_gen5_mma_multi_ctas(%a: !ttg.memdesc<128x16xf16, #shared, #ttg.shared_memory>, + %b: !ttg.memdesc<16x128xf16, #shared1, #ttg.shared_memory>, + %c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, + %useAcc: i1, + %pred: i1, + %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) { + ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier : + (!ttg.memdesc<128x16xf16, #shared, #ttg.shared_memory>, + !ttg.memdesc<16x128xf16, #shared1, #ttg.shared_memory>, + !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, + i1, i1, + !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) -> () + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @tensor_memory_ld + // CHECK: nvgpu.tensor_memory_base + // CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32 + // CHECK: tcgen05.wait::st.sync.aligned + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "tcgen05.ld.sync.aligned.32x32b.x128.b32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63, $64, $65, $66, $67, $68, $69, $70, $71, $72, $73, $74, $75, $76, $77, $78, $79, $80, $81, $82, $83, $84, $85, $86, $87, $88, $89, $90, $91, $92, $93, $94, $95, $96, $97, $98, $99, $100, $101, $102, $103, $104, $105, $106, $107, $108, $109, $110, $111, $112, $113, $114, $115, $116, $117, $118, $119, $120, $121, $122, $123, $124, $125, $126, $127}, [$128];", "=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,r" %{{.*}} : (i32) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: tcgen05.wait::ld.sync.aligned + tt.func public @tensor_memory_ld(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %20 = ttng.tmem_load %0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> + tt.return + } +} + +// ----- + +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @tensor_memory_allocation + // CHECK: llvm.mlir.constant(4194306 : i32) : i32 + tt.func public @tensor_memory_allocation() { + %0 = ttng.tmem_alloc {tensor_memory_col_offset = 2 : i32, tensor_memory_row_offset = 64 : i32} : () -> !ttg.memdesc<64x128xf32, #tmem, #ttng.tensor_memory, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @tensor_memory_ld_m64 + // CHECK: nvgpu.tensor_memory_base + // CHECK: tcgen05.st.sync.aligned.16x32bx2.x64.b32 + // CHECK: tcgen05.st.sync.aligned.16x32bx2.x64.b32 + // CHECK: tcgen05.wait::st.sync.aligned + // CHECK: tcgen05.ld.sync.aligned.16x32bx2.x64.b32 + // CHECK: tcgen05.ld.sync.aligned.16x32bx2.x64.b32 + // CHECK: tcgen05.wait::ld.sync.aligned + tt.func public @tensor_memory_ld_m64(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %20 = ttng.tmem_load %0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @tensor_memory_unpack_f16 + // CHECK: nvgpu.tensor_memory_base + // CHECK: tcgen05.st.sync.aligned.32x32b.x64.unpack::16b.b32 + // CHECK: tcgen05.wait::st.sync.aligned + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "tcgen05.ld.sync.aligned.32x32b.x64.pack::16b.b32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51, $52, $53, $54, $55, $56, $57, $58, $59, $60, $61, $62, $63}, [$64];", "=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,=r,r" %{{.*}} : (i32) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: tcgen05.wait::ld.sync.aligned + tt.func public @tensor_memory_unpack_f16() { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #blocked1> + %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> + %20 = ttng.tmem_load %0 : !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf16, #blocked1> + tt.return + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#tmem = #ttng.tensor_memory_encoding +#tmem_scales = #ttng.tensor_memory_scales_encoding<> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: @tc_gen5_mma_block_scale + // CHECK-SAME: (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[USE_ACC:.+]]: i1, %{{.*}}: i1, %{{.*}}) + // CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32 + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-DAG: %[[C32:.+]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32 + // CHECK: %[[DESC0:.+]] = llvm.mlir.constant(144774144 : i32) : i32 + // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC0]], %{{.+}}, %{{.+}}, %[[USE_ACC]] + // CHECK: %[[TRUE:.+]] = llvm.mlir.constant(true) : i1 + // CHECK: %[[DESC1:.+]] = llvm.mlir.constant(681645072 : i32) : i32 + // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC1]], %{{.+}}, %{{.+}}, %[[TRUE]] + tt.func @tc_gen5_mma_block_scale(%a: !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>, + %b: !ttg.memdesc<32x128xi8, #shared1, #ttg.shared_memory>, + %c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, + %scale_a: !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>, + %scale_b: !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>, + %useAcc: i1, + %pred: i1, + %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) { + ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e4m3 rhs = e2m1, %barrier : + (!ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>, + !ttg.memdesc<32x128xi8, #shared1, #ttg.shared_memory>, + !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, + !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>, + !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>, + i1, + i1, + !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) -> () + tt.return + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [2], CTASplitNum = [1], CTAOrder = [0]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + tt.func @tc_gen5_mma_2ctas(%a: !ttg.memdesc<256x32xf16, #shared, #ttg.shared_memory>, + %b: !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory>, + %c: !ttg.memdesc<256x128xf32, #tmem, #ttng.tensor_memory, mutable>, + %useAcc: i1, + %pred: i1, + %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) { + // CHECK: tcgen05.mma.cta_group::2.kind::f16 + // CHECK: tcgen05.mma.cta_group::2.kind::f16 + // CHECK: tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 + ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier {two_ctas} : + (!ttg.memdesc<256x32xf16, #shared, #ttg.shared_memory>, + !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory>, + !ttg.memdesc<256x128xf32, #tmem, #ttng.tensor_memory, mutable>, + i1, i1, + !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) -> () + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], order=[0, 1]}> +#shared = #ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=1, order=[1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tmem_copy_2d(%src: !ttg.memdesc<256x16xi8, #shared, #ttg.shared_memory>, + %dst: !ttg.memdesc<128x32xi32, #tmem, #ttng.tensor_memory, mutable>, + %barrier: !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory>) { + // CHECK-COUNT-8: tcgen05.cp.cta_group::1.warpx4.32x128b + // CHECK: tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 + ttng.tmem_copy %src, %dst, %barrier : (!ttg.memdesc<256x16xi8, #shared, #ttg.shared_memory>, !ttg.memdesc<128x32xi32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory>) -> () + + tt.return + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#tmem = #ttng.tensor_memory_encoding +#tmem_scales = #ttng.tensor_memory_scales_encoding<> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: @tc_gen5_mma_block_scale_nvfp4 + // CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32 + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32 + // CHECK: %[[DESC0:.+]] = llvm.mlir.constant(138413184 : i32) : i32 + // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC0]] + // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC0]] + tt.func @tc_gen5_mma_block_scale_nvfp4(%a: !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>, + %b: !ttg.memdesc<64x256xi8, #shared1, #ttg.shared_memory>, + %c: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>, + %scale_a: !ttg.memdesc<128x8xf8E4M3FN, #tmem_scales, #ttng.tensor_memory>, + %scale_b: !ttg.memdesc<256x8xf8E4M3FN, #tmem_scales, #ttng.tensor_memory>, + %useAcc: i1, + %pred: i1, + %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) { + ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e2m1 rhs = e2m1, %barrier : + (!ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>, + !ttg.memdesc<64x256xi8, #shared1, #ttg.shared_memory>, + !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>, + !ttg.memdesc<128x8xf8E4M3FN, #tmem_scales, #ttng.tensor_memory>, + !ttg.memdesc<256x8xf8E4M3FN, #tmem_scales, #ttng.tensor_memory>, + i1, + i1, + !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) -> () + tt.return + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#tmem = #ttng.tensor_memory_encoding +#tmem_scales = #ttng.tensor_memory_scales_encoding<> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: @tc_gen5_mma_block_scale_mxfp4 + // CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32 + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32 + // CHECK: %[[DESC0:.+]] = llvm.mlir.constant(146801792 : i32) : i32 + // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC0]] + // CHECK: %[[DESC1:.+]] = llvm.mlir.constant(1220543648 : i32) : i32 + // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0]], %{{.+}}, %{{.+}}, %[[DESC1]] + tt.func @tc_gen5_mma_block_scale_mxfp4(%a: !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>, + %b: !ttg.memdesc<64x256xi8, #shared1, #ttg.shared_memory>, + %c: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>, + %scale_a: !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>, + %scale_b: !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>, + %useAcc: i1, + %pred: i1, + %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) { + ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e2m1 rhs = e2m1, %barrier : + (!ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>, + !ttg.memdesc<64x256xi8, #shared1, #ttg.shared_memory>, + !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>, + !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>, + !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>, + i1, + i1, + !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>) -> () + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 256], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#tmem = #ttng.tensor_memory_encoding + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @tensor_memory_ld_128x256 + // CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32 + // CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32 + // CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32 + // CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32 + // CHECK: tcgen05.wait::st.sync.aligned + // CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32 + // CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32 + // CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32 + // CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32 + // CHECK: tcgen05.wait::ld.sync.aligned + tt.func public @tensor_memory_ld_128x256(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked1> + %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked1>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> + %20 = ttng.tmem_load %0 : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked1> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 256], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> +#tmem = #ttng.tensor_memory_encoding + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @tensor_memory_ld_128x256_8_warps + // CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32 + // CHECK: tcgen05.wait::st.sync.aligned + // CHECK: tcgen05.ld.sync.aligned.32x32b.x128.b32 + // CHECK: tcgen05.wait::ld.sync.aligned + tt.func public @tensor_memory_ld_128x256_8_warps(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked1> + %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked1>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> + %20 = ttng.tmem_load %0 : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked1> + tt.return + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +#tmem = #ttng.tensor_memory_encoding +#tmem1 = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + tt.func @tc_gen5_mma_lhs_tmem(%arg0: !ttg.memdesc<128x32xf16, #tmem, #ttng.tensor_memory>, %arg1: !ttg.memdesc<32x128xf16, #shared, #smem>, %arg2: !ttg.memdesc<128x128xf32, #tmem1, #ttng.tensor_memory, mutable>, %arg3: i1, %arg4: i1, %arg5: !ttg.memdesc<1xi64, #shared1, #smem>) { + // CHECK-LABEL: tc_gen5_mma_lhs_tmem + // CHECK: tcgen05.mma.cta_group::1.kind::f16 + ttng.tc_gen5_mma %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : ( + !ttg.memdesc<128x32xf16, #tmem, #ttng.tensor_memory>, + !ttg.memdesc<32x128xf16, #shared, #smem>, + !ttg.memdesc<128x128xf32, #tmem1, #ttng.tensor_memory, mutable>, + i1, i1, !ttg.memdesc<1xi64, #shared1, #smem>) -> () + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir new file mode 100644 index 000000000..a02db0f06 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir @@ -0,0 +1,47 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s + +// CHECK-LABEL: blocked_to_dot_op_shortcut_warp32 +#blocked = #ttg.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot_op_shortcut_warp32(%arg0: tensor<32x32xf16, #blocked>, %arg1: tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>) { + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot_op_shortcut_warp64 +#blocked = #ttg.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @blocked_to_dot_op_shortcut_warp64(%arg0: tensor<32x32xf16, #blocked>) { + %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp32 +#blocked = #ttg.blocked<{sizePerThread = [2, 32, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot3d_op_shortcut_warp32(%arg0: tensor<8x32x32xf16, #blocked>) { + %0 = ttg.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp64 +#blocked = #ttg.blocked<{sizePerThread = [1, 32, 1], threadsPerWarp = [1, 2, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @blocked_to_dot3d_op_shortcut_warp64(%arg0: tensor<8x32x32xf16, #blocked>) { + %0 = ttg.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_hopper.mlir b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_hopper.mlir new file mode 100644 index 000000000..234673676 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -0,0 +1,333 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' 2>&1 | FileCheck %s + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: @dot_high_precision_acc + tt.func @dot_high_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #smem>, %c: tensor<128x256xf32, #mma>) { + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + %m = ttng.warp_group_dot %a, %b, %c + {maxNumImpreciseAcc = 32 : i32, inputPrecision = 0 : i32} : + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: @dot_low_precision_acc + tt.func @dot_low_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #smem>, %c: tensor<128x256xf32, #mma>) { + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: llvm.return + %m = ttng.warp_group_dot %a, %b, %c + {maxNumImpreciseAcc = 129 : i32, inputPrecision = 0 : i32} : + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: @dot_mix_precision_acc + tt.func @dot_mix_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #smem>, %c: tensor<128x256xf32, #mma>) { + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: llvm.return + %m = ttng.warp_group_dot %a, %b, %c + {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} : + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: @dot_zero_acc + // Generate a wgmma with 2 sources. + // CHECK: nvgpu.wgmma %{{.*}}, %{{.*}} { + tt.func @dot_zero_acc(%a: !ttg.memdesc<128x64xf16, #shared, #smem>, %b: !ttg.memdesc<64x64xf16, #shared1, #smem>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %m = ttng.warp_group_dot %a, %b, %cst {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : + !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: @dot_reg_operand_A + // Generate a wgmma where the first operand is a struct. + // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !ttg.memdesc<64x64xf16, #shared, #smem>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %opA = ttg.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %m = ttng.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: + tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: @dot_reg_operand_A_fp8 + // Generate a wgmma where the first operand is a struct. + // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} + tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %b: !ttg.memdesc<128x256xf8E5M2, #shared, #smem>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> + %m = ttng.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } : + tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !ttg.memdesc<128x256xf8E5M2, #shared, #smem> -> tensor<128x256xf32, #mma1> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: dot_reg_operand_upcast + tt.func @dot_reg_operand_upcast(%a_desc: !ttg.memdesc<128x64xi8, #shared, #smem>, %b: !ttg.memdesc<64x64xf16, #shared, #smem>, %acc: tensor<128x64xf32, #mma>) { + %a_dotop = ttg.local_load %a_desc : !ttg.memdesc<128x64xi8, #shared, #smem> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %a_casted = arith.sitofp %a_dotop : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %res = ttng.warp_group_dot %a_casted, %b, %acc : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: test_fp8_to_f16_conversion + tt.func @test_fp8_to_f16_conversion( + %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>, + %in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) { + // CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> + %out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked> + // CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> + %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FN, #blocked> -> tensor<128xf16, #blocked> + // CHECK-COUNT-2: mul.rn.bf16x2 + %out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked> + + // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> + %out3 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked> + // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> + %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FN, #blocked> + + // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> + %out5 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked> + // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> + %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FN, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +// CHECK-LABEL: clamp +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @clamp(%x : tensor<1024xf32, #blocked>, %limit : tensor<1024xf32, #blocked>) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked> + %neg_limit = arith.subf %cst, %limit : tensor<1024xf32, #blocked> + + // CHECK-COUNT-8: nvvm.fmin.xorsign.abs.f + %12 = tt.clampf %x, %neg_limit, %limit, propagateNan = none : tensor<1024xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 16]}> +// CHECK-LABEL: convert_mma_to_blocked +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @convert_mma_to_blocked(%a: tensor<128x256xf16, #mma>) { + // CHECK-COUNT-16: nvgpu.stmatrix + // CHECK: nvvm.barrier0 + %c = ttg.convert_layout %a : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +// CHECK-LABEL: cvt_mma_to_dot_fp8 +// CHECK: nvvm.prmt +// CHECK: nvvm.prmt +// CHECK: nvvm.shfl.sync +// CHECK: nvvm.shfl.sync +// CHECK: nvvm.prmt +// CHECK: nvvm.prmt + tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) { + %opA = ttg.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 32]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +// CHECK-LABEL: dot_zero_acc_operand +// CHECK-COUNT-128: llvm.fadd + tt.func @dot_zero_acc_operand(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x128xf8E5M2, #shared1, #smem>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %m = ttng.warp_group_dot %a, %b, %cst {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} : + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x128xf8E5M2, #shared1, #smem> -> tensor<128x128xf32, #mma> + tt.return + } +} + + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#smem = #ttg.shared_memory +// CHECK-LABEL: distribute_to_shared_st_matrix +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @distribute_to_shared_st_matrix(%a: tensor<128x128xf16, #mma>) { + // CHECK-COUNT-16: nvgpu.stmatrix + // CHECK: llvm.return + %b = ttg.local_alloc %a {allocation.offset = 0 : i32} : (tensor<128x128xf16, #mma>) -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#smem = #ttg.shared_memory +// CHECK-LABEL: distribute_to_shared_st_matrix_local_store +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @distribute_to_shared_st_matrix_local_store(%a: tensor<128x128xf16, #mma>) { + // CHECK-COUNT-16: nvgpu.stmatrix + // CHECK: llvm.return + %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + ttg.local_store %a, %b : tensor<128x128xf16, #mma> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @fp8_const(%arg0: tensor<1024xi1, #blocked>, %arg1: tensor<1024xf8E4M3FNUZ, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: @fp8_const + // CHECK: llvm.mlir.constant(0.000000e+00 : f8E4M3FNUZ) : i8 + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf8E4M3FNUZ, #blocked> + %a = arith.select %arg0, %arg1, %cst : tensor<1024xi1, #blocked>, tensor<1024xf8E4M3FNUZ, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_nomask + // CHECK: atom.global.gpu.acq_rel.add.v4.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_withmask + // CHECK: atom.global.gpu.acq_rel.add.v2.f32 + // CHECK: atom.global.gpu.acq_rel.add.v2.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: test_fp8_to_fp16_dot_operand + // CHECK-COUNT-16: cvt.rn.f16x2.e5m2x2 + tt.func @test_fp8_to_fp16_dot_operand(%arg: tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>) { + %r = tt.fp_to_fp %arg : tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir new file mode 100644 index 000000000..1003f321d --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir @@ -0,0 +1,44 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=80' 2>&1 | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_nomask + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_withmask + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_volta.mlir b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_volta.mlir new file mode 100644 index 000000000..a5a428129 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_llvm_volta.mlir @@ -0,0 +1,18 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=70 2>&1 | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +// CHECK-LABEL: clamp +module attributes {"ttg.target" = "cuda:70", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @clamp(%x : tensor<1024xf32, #blocked>, %limit : tensor<1024xf32, #blocked>) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked> + %neg_limit = arith.subf %cst, %limit : tensor<1024xf32, #blocked> + + // CHECK: llvm.fcmp "une" %[[REG:[a-zA-Z0-9]+]], %[[REG]] + // CHECK-NEXT: llvm.intr.maxnum + // CHECK-NEXT: llvm.intr.minnum + // CHECK-NEXT: llvm.mlir.constant + // CHECK-NEXT: llvm.select + %12 = tt.clampf %x, %neg_limit, %limit, propagateNan = all : tensor<1024xf32, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/tritongpu_to_ptx.mlir b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_ptx.mlir new file mode 100644 index 000000000..ea0109c82 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/tritongpu_to_ptx.mlir @@ -0,0 +1,86 @@ +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=83' --convert-nv-gpu-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 -S | llc -mtriple nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx83 | FileCheck --check-prefixes CHECK,SM90 --dump-input-context=20 %s +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=80 ptx-version=83' --convert-nv-gpu-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 -S | llc -mtriple nvptx64-nvidia-cuda -mcpu=sm_80 -mattr=+ptx83 | FileCheck --check-prefixes CHECK,SM80 --dump-input-context=20 %s + + +#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @add_bf16(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %arg0: tensor<256xbf16, #blocked>, %arg1: tensor<256xbf16, #blocked>) { + // CHECK-LABEL: add_bf16 + // SM80-COUNT-4: fma.rn.bf16x2 + // SM90-COUNT-4: add.rn.bf16x2 + %0 = arith.addf %arg0, %arg1 : tensor<256xbf16, #blocked> + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked> + %2 = tt.splat %ptr : !tt.ptr -> tensor<256x!tt.ptr, #blocked> + %3 = tt.addptr %2, %1 : tensor<256x!tt.ptr, #blocked>, tensor<256xi32, #blocked> + tt.store %3, %0 : tensor<256x!tt.ptr, #blocked> + tt.return + } + + tt.func public @sub_bf16(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %arg0: tensor<256xbf16, #blocked>, %arg1: tensor<256xbf16, #blocked>) { + // CHECK-LABEL: sub_bf16 + // SM80-COUNT-4: fma.rn.bf16x2 + // SM90-COUNT-4: sub.rn.bf16x2 + %0 = arith.subf %arg0, %arg1 : tensor<256xbf16, #blocked> + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked> + %2 = tt.splat %ptr : !tt.ptr -> tensor<256x!tt.ptr, #blocked> + %3 = tt.addptr %2, %1 : tensor<256x!tt.ptr, #blocked>, tensor<256xi32, #blocked> + tt.store %3, %0 : tensor<256x!tt.ptr, #blocked> + tt.return + } + + tt.func public @mul_bf16(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %arg0: tensor<256xbf16, #blocked>, %arg1: tensor<256xbf16, #blocked>) { + // CHECK-LABEL: mul_bf16 + // SM80-COUNT-4: fma.rn.bf16x2 + // SM90-COUNT-4: mul.rn.bf16x2 + %0 = arith.mulf %arg0, %arg1 : tensor<256xbf16, #blocked> + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked> + %2 = tt.splat %ptr : !tt.ptr -> tensor<256x!tt.ptr, #blocked> + %3 = tt.addptr %2, %1 : tensor<256x!tt.ptr, #blocked>, tensor<256xi32, #blocked> + tt.store %3, %0 : tensor<256x!tt.ptr, #blocked> + tt.return + } + + tt.func public @extf_bf16(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %arg0: tensor<256xbf16, #blocked>) { + // CHECK-LABEL: extf_bf16 + // CHECK-COUNT-8: cvt.f32.bf16 + %0 = arith.extf %arg0 : tensor<256xbf16, #blocked> to tensor<256xf32, #blocked> + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked> + %2 = tt.splat %ptr : !tt.ptr -> tensor<256x!tt.ptr, #blocked> + %3 = tt.addptr %2, %1 : tensor<256x!tt.ptr, #blocked>, tensor<256xi32, #blocked> + tt.store %3, %0 : tensor<256x!tt.ptr, #blocked> + tt.return + } + + tt.func public @truncf_bf16(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %arg0: tensor<256xf32, #blocked>) { + // CHECK-LABEL: truncf_bf16 + // CHECK-COUNT-4: cvt.rn.bf16x2.f32 + %0 = arith.truncf %arg0 : tensor<256xf32, #blocked> to tensor<256xbf16, #blocked> + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked> + %2 = tt.splat %ptr : !tt.ptr -> tensor<256x!tt.ptr, #blocked> + %3 = tt.addptr %2, %1 : tensor<256x!tt.ptr, #blocked>, tensor<256xi32, #blocked> + tt.store %3, %0 : tensor<256x!tt.ptr, #blocked> + tt.return + } + + tt.func public @extf_f16(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %arg0: tensor<256xf16, #blocked>) { + // CHECK-LABEL: extf_f16 + // CHECK-COUNT-8: cvt.f32.f16 + %0 = arith.extf %arg0 : tensor<256xf16, #blocked> to tensor<256xf32, #blocked> + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked> + %2 = tt.splat %ptr : !tt.ptr -> tensor<256x!tt.ptr, #blocked> + %3 = tt.addptr %2, %1 : tensor<256x!tt.ptr, #blocked>, tensor<256xi32, #blocked> + tt.store %3, %0 : tensor<256x!tt.ptr, #blocked> + tt.return + } + + tt.func public @truncf_f16(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %arg0: tensor<256xf32, #blocked>) { + // CHECK-LABEL: truncf_f16 + // CHECK-COUNT-4: cvt.rn.f16x2.f32 + %0 = arith.truncf %arg0 : tensor<256xf32, #blocked> to tensor<256xf16, #blocked> + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked> + %2 = tt.splat %ptr : !tt.ptr -> tensor<256x!tt.ptr, #blocked> + %3 = tt.addptr %2, %1 : tensor<256x!tt.ptr, #blocked>, tensor<256xi32, #blocked> + tt.store %3, %0 : tensor<256x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/tritonnvidiagpu_to_llvm.mlir new file mode 100644 index 000000000..a040fe436 --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -0,0 +1,164 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s + +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: init_barrier + tt.func @init_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) { + // CHECK: "@$0 mbarrier.init.shared::cta.b64 [$1], 1;", "b,r" %{{.*}}, %{{.*}} : (i1, !llvm.ptr<3>) -> !llvm.void + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem> + tt.return + } +} + +// ----- + +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: wait_barrier + tt.func @wait_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %phase: i32) { + // CHECK: waitLoop: + // CHECK: mbarrier.try_wait.parity.shared.b64 + // CHECK: @!P1 bra.uni waitLoop + ttng.wait_barrier %alloc, %phase : !ttg.memdesc<1xi64, #shared0, #smem> + tt.return + } +} + + +// ----- + +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: tma_copy_global_to_local + // CHECK: elect.sync + // CHECK: "@$0 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4}], [$5];", "b,r,l,r,r,r" {{.*}} : (i1, !llvm.ptr<3>, !llvm.ptr<1>, i32, i32, !llvm.ptr<3>) -> !llvm.void + // CHECK-NOT: cp.async.bulk.tensor.2d.shared + // CHECK: return + tt.func @tma_copy_global_to_local(%tma: !tt.ptr, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) { + ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.ptr, !ttg.memdesc<1xi64, #shared0, #smem> -> !ttg.memdesc<128x128xf32, #shared1, #smem, mutable> + tt.return + } +} + +// ----- + +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: tma_copy_local_to_global + // CHECK: elect.sync + // CHECK: "@$0 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$1, {$2, $3}], [$4];", "b,l,r,r,r" {{.*}} : (i1, !llvm.ptr<1>, i32, i32, !llvm.ptr<3>) -> !llvm.void + // CHECK-NOT: cp.async.bulk.tensor.2d.global.shared::cta.bulk_group + // CHECK: nvvm.cp.async.bulk.commit.group + tt.func @tma_copy_local_to_global(%tma: !tt.ptr, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem>, %x: i32) { + ttng.async_tma_copy_local_to_global %tma[%x, %x] %alloc : !tt.ptr, !ttg.memdesc<128x128xf32, #shared1, #smem> + tt.return + } +} + +// ----- + +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: async_tma_store_wait + // CHECK: nvvm.cp.async.bulk.wait_group 0 {read} + tt.func @async_tma_store_wait() { + ttng.async_tma_store_wait {pendings = 0 : i32} + tt.return + } +} + +// ----- + +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: expect_barrier + // CHECK: @$0 mbarrier.arrive.expect_tx.shared.b64 _, [$1], 16384; + tt.func @expect_barrier(%barrier: !ttg.memdesc<1xi64, #shared0, #smem, mutable>, %pred: i1) { + ttng.barrier_expect %barrier, 16384, %pred : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + tt.return + } +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: byval_tma_desc + // CHECK: llvm.align = 64 + // CHECK: llvm.byval = !llvm.array<128 x i8> + // CHECK: nvvm.grid_constant + tt.func @byval_tma_desc(%desc: !tt.ptr {tt.nv_tma_desc = 1 : i32}) { + tt.return + } +} + +// ----- + +// CHECK-LABEL: device_tensormap_create1d +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @device_tensormap_create1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c256_i32 = arith.constant 256 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + // CHECK: st.shared.b32 + // CHECK: bar.warp.sync + // CHECK: tensormap.replace.tile.global_address.shared::cta.b1024.b64 [ $0 + 0 ], $1; + // CHECK: tensormap.replace.tile.rank.shared::cta.b1024.b32 [ $0 + 0 ], 0x0; + // CHECK: tensormap.replace.tile.box_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.element_stride.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.elemtype.shared::cta.b1024.b32 [ $0 + 0 ], 0x3; + // CHECK: tensormap.replace.tile.interleave_layout.shared::cta.b1024.b32 [ $0 + 0 ], 0x0; + // CHECK: tensormap.replace.tile.swizzle_mode.shared::cta.b1024.b32 [ $0 + 0 ], 0x2; + // CHECK: tensormap.replace.tile.fill_mode.shared::cta.b1024.b32 [ $0 + 0 ], 0x1; + // CHECK: tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [ $0 + 0 ], [ $1 + 0 ], 0x80; + tt.experimental_tensormap_create %arg1, %arg0, [%c256_i32], [%arg2], [], [%c1_i32] {elem_type = 3 : i32, fill_mode = 1 : i32, interleave_layout = 0 : i32, swizzle_mode = 2 : i32, allocation.offset = 0 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32) -> () + tt.return + } +} + +// ----- + +// CHECK-LABEL: device_tensormap_create2d +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @device_tensormap_create2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c256_i32 = arith.constant 256 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1024_i64 = arith.constant 1024 : i64 + // CHECK: st.shared.b32 + // CHECK: bar.warp.sync + // CHECK: tensormap.replace.tile.global_address.shared::cta.b1024.b64 [ $0 + 0 ], $1; + // CHECK: tensormap.replace.tile.rank.shared::cta.b1024.b32 [ $0 + 0 ], 0x1; + // CHECK: tensormap.replace.tile.box_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.box_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x1, $1; + // CHECK: tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x1, $1; + // CHECK: tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.element_stride.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.element_stride.shared::cta.b1024.b32 [ $0 + 0 ], 0x1, $1; + // CHECK: tensormap.replace.tile.elemtype.shared::cta.b1024.b32 [ $0 + 0 ], 0x3; + // CHECK: tensormap.replace.tile.interleave_layout.shared::cta.b1024.b32 [ $0 + 0 ], 0x0; + // CHECK: tensormap.replace.tile.swizzle_mode.shared::cta.b1024.b32 [ $0 + 0 ], 0x2; + // CHECK: tensormap.replace.tile.fill_mode.shared::cta.b1024.b32 [ $0 + 0 ], 0x1; + // CHECK: tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [ $0 + 0 ], [ $1 + 0 ], 0x80; + tt.experimental_tensormap_create %arg1, %arg0, [%c256_i32, %c256_i32], [%arg2, %arg2], [%c1024_i64], [%c1_i32, %c1_i32] {elem_type = 3 : i32, fill_mode = 1 : i32, interleave_layout = 0 : i32, swizzle_mode = 2 : i32, allocation.offset = 0 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () + tt.return + } +} + +// ----- + +// CHECK-LABEL: tensormap_fenceproxy_acquire +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tensormap_fenceproxy_acquire(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + // CHECK: fence.proxy.tensormap::generic.acquire.gpu [ $0 + 0 ], 0x80; + tt.experimental_tensormap_fenceproxy_acquire %arg0 : !tt.ptr + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/Conversion/warp_specialize_to_llvm.mlir b/third_party/enflame/include/triton/test/Conversion/warp_specialize_to_llvm.mlir new file mode 100644 index 000000000..a0356038f --- /dev/null +++ b/third_party/enflame/include/triton/test/Conversion/warp_specialize_to_llvm.mlir @@ -0,0 +1,626 @@ +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -convert-warp-specialize-to-llvm | FileCheck %s + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 11 : i32} { + +llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + +// CHECK-LABEL: @rewrite_barriers +llvm.func @rewrite_barriers() attributes {allocation.offset = 32 : i32} { + // CHECK: barrier.sync.aligned 2, 128 + // CHECK: barrier.sync.aligned 3, 64 + // CHECK: barrier.sync.aligned 4, 32 + + // CHECK: bb{{[0-9]+}}: + // CHECK-NEXT: barrier.sync.aligned 0, 128 + nvvm.barrier0 + ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array} + default { + // CHECK: barrier.sync.aligned 0, 128 + nvvm.barrier0 + ttg.warp_yield + } + partition0() num_warps(4) { + nvvm.barrier0 + ttg.warp_return + } + partition1() num_warps(2) { + nvvm.barrier0 + ttg.warp_return + } + partition2() num_warps(1) { + nvvm.barrier0 + ttg.warp_return + } : () -> () + // CHECK: barrier.sync.aligned 0, 128 + nvvm.barrier0 + llvm.return +} + +} + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 11 : i32} { + +llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + +// CHECK-LABEL: @generate_switch_loop +llvm.func @generate_switch_loop() attributes {allocation.offset = 32 : i32} { + // CHECK-NEXT: [[TIDX:%.*]] = nvvm.read.ptx.sreg.tid.x + // CHECK-NEXT: [[C32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-NEXT: [[WID:%.*]] = llvm.udiv [[TIDX]], [[C32]] + // CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32) + // CHECK-NEXT: [[C31:%.*]] = llvm.mlir.constant(31 : i32) + // CHECK-NEXT: [[CNEG1:%.*]] = llvm.mlir.constant(-1 : i32) + // CHECK-NEXT: [[WARP_ID:%.*]] = nvvm.shfl.sync idx [[CNEG1]], [[WID]], [[C0]], [[C31]] + // CHECK-NEXT: [[C4:%.*]] = llvm.mlir.constant(4 : i32) + // CHECK-NEXT: [[IS_DEFAULT:%.*]] = llvm.icmp "ult" [[WARP_ID]], [[C4]] + // CHECK-NEXT: llvm.cond_br [[IS_DEFAULT]], [[BODY:\^.*]], [[SWITCH_LOOP:\^.*]] + + // CHECK: [[SWITCH_LOOP]]: + // CHECK-NEXT: "barrier.sync 1 ;" + // CHECK-NEXT: [[C32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-NEXT: [[SMEM_ADDR:%.*]] = llvm.mlir.addressof @global_smem + // CHECK-NEXT: [[SMEM_BASE:%.*]] = llvm.getelementptr [[SMEM_ADDR]][[[C32]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK-NEXT: [[C4:%.*]] = llvm.mlir.constant(4 : i32) + // CHECK-NEXT: [[REL_WID:%.*]] = llvm.sub [[WARP_ID]], [[C4]] + + // CHECK-NEXT: [[STATE_PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][[[REL_WID]]] + // CHECK-NEXT: [[STATE:%.*]] = llvm.load [[STATE_PTR]] + // CHECK-NEXT: llvm.switch [[STATE]] : i8, [[DEFAULT:\^.*]] [ + // CHECK-NEXT: 0: [[PARTITION0:\^.*]], + // CHECK-NEXT: 1: [[PARTITION1:\^.*]], + // CHECK-NEXT: 2: [[PARTITION2:\^.*]], + // CHECK-NEXT: 3: [[EXIT:\^.*]] + + // CHECK: [[DEFAULT]]: + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: llvm.br [[SWITCH_LOOP]] + + // CHECK: [[EXIT]]: + // CHECK-NEXT: llvm.return + + // CHECK: [[PARTITION0]]: + // CHECK: barrier.sync 1 ; + // CHECK-NEXT: "partition0" + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: llvm.br [[SWITCH_LOOP]] + + // CHECK: [[PARTITION1]]: + // CHECK: barrier.sync 1 ; + // CHECK-NEXT: "partition1" + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: llvm.br [[SWITCH_LOOP]] + + // CHECK: [[PARTITION2]]: + // CHECK: barrier.sync 1 ; + // CHECK-NEXT: "partition2" + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: llvm.br [[SWITCH_LOOP]] + + // CHECK: [[BODY]]: + // CHECK-NEXT: "before" + // CHECK-NEXT: [[C32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-NEXT: [[SMEM_ADDR:%.*]] = llvm.mlir.addressof @global_smem + // CHECK-NEXT: [[SMEM_BASE:%.*]] = llvm.getelementptr [[SMEM_ADDR]][[[C32]]] + + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][1] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][2] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][3] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][4] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][5] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][6] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + + // CHECK: barrier.sync 1 ; + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: llvm.br [[DEFAULT_PARTITION:\^.*]] + // CHECK: [[DEFAULT_PARTITION]]: + // CHECK-NEXT: "default" + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: llvm.br [[AFTER:\^.*]] + "before"() : () -> () + ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array} + default { + "default"() : () -> () + ttg.warp_yield + } + partition0() num_warps(4) { + "partition0"() : () -> () + ttg.warp_return + } + partition1() num_warps(2) { + "partition1"() : () -> () + ttg.warp_return + } + partition2() num_warps(1) { + "partition2"() : () -> () + ttg.warp_return + } : () -> () + // CHECK: [[AFTER]]: + // CHECK-NEXT: "after" + + // CHECK-NEXT: [[C32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-NEXT: [[SMEM_ADDR:%.*]] = llvm.mlir.addressof @global_smem + // CHECK-NEXT: [[SMEM_BASE:%.*]] = llvm.getelementptr [[SMEM_ADDR]][[[C32]]] + + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(3 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][1] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][2] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][3] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][4] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][5] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][6] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: llvm.return + "after"() : () -> () + llvm.return +} + +} + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 8 : i32} { + +llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + +// CHECK-LABEL: @pass_captures +llvm.func @pass_captures(%arg0: i32, %arg1: i64) attributes {allocation.offset = 32 : i32} { + // CHECK: ^bb4: + // CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32) + // CHECK-NEXT: [[SMEM_ADDR:%.*]] = llvm.mlir.addressof @global_smem + // CHECK-NEXT: [[SMEM_BASE:%.*]] = llvm.getelementptr [[SMEM_ADDR]][[[C0]]] + + // CHECK-NEXT: [[ARG0_PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0, 0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.struct + // CHECK-NEXT: [[ARG0:%.*]] = llvm.load [[ARG0_PTR]] {alignment = 1 : i64} + // CHECK-NEXT: [[ARG1_PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0, 1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.struct + // CHECK-NEXT: [[ARG1:%.*]] = llvm.load [[ARG1_PTR]] {alignment = 1 : i64} + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: "use"([[ARG0]], [[ARG1]]) + // CHECK-NEXT: barrier.sync 1 ; + + // CHECK: ^bb5: + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) + // CHECK-NEXT: [[SMEM_ADDR:%.*]] = llvm.mlir.addressof @global_smem + // CHECK-NEXT: [[SMEM_BASE:%.*]] = llvm.getelementptr [[SMEM_ADDR]][[[C0]]] + // CHECK-NEXT: [[ARG0_PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0, 0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.struct + // CHECK-NEXT: llvm.store %arg0, [[ARG0_PTR]] {alignment = 1 : i64} + // CHECK-NEXT: [[ARG1_PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0, 1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.struct + // CHECK-NEXT: llvm.store %arg1, [[ARG1_PTR]] {alignment = 1 : i64} + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: barrier.sync 1 ; + ttg.warp_specialize(%arg0, %arg1) attributes {allocation.offset = 0 : i32, warpGroupStartIds = array} + default { + ttg.warp_yield + } + partition0(%arg2: i32, %arg3: i64) num_warps(4) { + "use"(%arg2, %arg3) : (i32, i64) -> () + ttg.warp_return + } : (i32, i64) -> () + llvm.return +} + +} + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 18 : i32} { + +llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + +// CHECK-LABEL: @partition_warpid_order +llvm.func @partition_warpid_order() attributes {allocation.offset = 32 : i32} { + // CHECK: llvm.switch + // CHECK-NEXT: 0: [[PARTITION0:\^.*]], + // CHECK-NEXT: 1: [[PARTITION1:\^.*]], + // CHECK-NEXT: 2: [[PARTITION2:\^.*]], + // CHECK-NEXT: 3: [[EXIT:\^.*]] + + // CHECK: [[PARTITION0]]: + // CHECK: "ws0_partition0" + // CHECK: [[PARTITION1]]: + // CHECK: "ws0_partition1" + // CHECK: [[PARTITION2]]: + // CHECK: "ws0_partition2" + + // CHECK: llvm.mlir.addressof @global_smem + // CHECK-NEXT: getelementptr + + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[0] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[1] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[2] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[3] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[4] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[5] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[6] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[7] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[8] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[9] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[10] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[11] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[12] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[13] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array} + default { + "ws0_default"() : () -> () + ttg.warp_yield + } + partition0() num_warps(4) { + "ws0_partition0"() : () -> () + ttg.warp_return + } + partition1() num_warps(2) { + "ws0_partition1"() : () -> () + ttg.warp_return + } + partition2() num_warps(8) { + "ws0_partition2"() : () -> () + ttg.warp_return + } : () -> () + llvm.return +} + +} + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 12 : i32} { + +llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + +// CHECK-LABEL: @multiple_specialize +llvm.func @multiple_specialize() attributes {allocation.offset = 32 : i32} { + // CHECK: llvm.switch + // CHECK-NEXT: 0: [[WS0_PARTITION0:\^.*]], + // CHECK-NEXT: 1: [[WS0_PARTITION1:\^.*]], + // CHECK-NEXT: 2: [[WS0_PARTITION2:\^.*]], + // CHECK-NEXT: 3: [[WS1_PARTITION0:\^.*]], + // CHECK-NEXT: 4: [[WS1_PARTITION1:\^.*]], + // CHECK-NEXT: 5: [[WS3_PARTITION0:\^.*]], + // CHECK-NEXT: 6: [[EXIT:\^.*]] + + // CHECK: [[WS0_PARTITION0]]: + // CHECK: "ws0_partition0" + // CHECK: [[WS0_PARTITION1]]: + // CHECK: "ws0_partition1" + // CHECK: [[WS0_PARTITION2]]: + // CHECK: "ws0_partition2" + // CHECK: [[WS1_PARTITION0]]: + // CHECK: "ws1_partition0" + // CHECK: [[WS1_PARTITION1]]: + // CHECK: "ws1_partition1" + // CHECK: [[WS3_PARTITION0]]: + // CHECK: "ws3_partition0" + + // CHECK: llvm.mlir.addressof @global_smem + // CHECK-NEXT: getelementptr + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[0] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[1] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[2] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[3] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[4] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[5] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[6] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[7] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK: barrier.sync 1 ; + // CHECK: "ws0_default" + ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array} + default { + "ws0_default"() : () -> () + ttg.warp_yield + } + partition0() num_warps(4) { + "ws0_partition0"() : () -> () + ttg.warp_return + } + partition1() num_warps(2) { + "ws0_partition1"() : () -> () + ttg.warp_return + } + partition2() num_warps(1) { + "ws0_partition2"() : () -> () + ttg.warp_return + } : () -> () + + // CHECK: llvm.mlir.addressof @global_smem + // CHECK-NEXT: getelementptr + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(4 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[0] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(4 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[1] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(4 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[2] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(4 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[3] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(3 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[4] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(3 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[5] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(3 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[6] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(3 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[7] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK: barrier.sync 1 ; + // CHECK: "ws1_default" + ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array} + default { + "ws1_default"() : () -> () + ttg.warp_yield + } + partition0() num_warps(4) { + "ws1_partition0"() : () -> () + ttg.warp_return + } + partition1() num_warps(4) { + "ws1_partition1"() : () -> () + ttg.warp_return + } : () -> () + + // CHECK: llvm.mlir.addressof @global_smem + // CHECK-NEXT: getelementptr + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[0] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[1] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[2] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[3] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[4] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[5] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[6] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[7] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK: barrier.sync 1 ; + // CHECK: "ws2_default" + ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array} + default { + "ws2_default"() : () -> () + ttg.warp_yield + } : () -> () + + // CHECK: llvm.mlir.addressof @global_smem + // CHECK-NEXT: getelementptr + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[0] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[1] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[2] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[3] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[4] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[5] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[6] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8) + // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[7] + // CHECK-NEXT: llvm.store [[CX]], [[PTR]] + // CHECK: barrier.sync 1 ; + // CHECK: "ws3_default" + ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array} + default { + "ws3_default"() : () -> () + ttg.warp_yield + } + partition0() num_warps(8) { + "ws3_partition0"() : () -> () + ttg.warp_return + }: () -> () + llvm.return +} + +} + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 8 : i32} { + +llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + +// CHECK-LABEL: @cfg +llvm.func @cfg() attributes {allocation.offset = 32 : i32} { + // CHECK: [[SWITCH_LOOP:\^bb1]]: + // CHECK: llvm.switch + // CHECK-NEXT: 0: [[PARTITION:\^.*]], + // CHECK-NEXT: 1: [[EXIT:\^.*]] + + // CHECK: [[PARTITION]]: + // CHECK: barrier.sync 1 ; + // CHECK-NEXT: "something"()[[[A:\^.*]], [[B:\^.*]]] + // CHECK: [[A]]: + // CHECK-NEXT: "A" + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: llvm.br [[SWITCH_LOOP]] + // CHECK: [[B]]: + // CHECK-NEXT: "B" + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: llvm.br [[SWITCH_LOOP]] + + // CHECK: barrier.sync 1 ; + // CHECK-NEXT: barrier.sync 1 ; + // CHECK: llvm.br [[DEFAULT:\^.*]] + // CHECK: [[DEFAULT]]: + // CHECK-NEXT: "something"()[[[A:\^.*]], [[B:\^.*]]] + // CHECK: [[A]]: + // CHECK-NEXT: "A" + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: llvm.br [[AFTER:\^.*]] + // CHECK: [[B]]: + // CHECK-NEXT: "B" + // CHECK-NEXT: barrier.sync 1 ; + // CHECK-NEXT: llvm.br [[AFTER]] + ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array} + default { + "something"()[^A, ^B] : () -> () + ^A: + "A"() : () -> () + ttg.warp_yield + ^B: + "B"() : () -> () + ttg.warp_yield + } + partition0() num_warps(4) { + "something"()[^A, ^B] : () -> () + ^A: + "A"() : () -> () + ttg.warp_return + ^B: + "B"() : () -> () + ttg.warp_return + } : () -> () + llvm.return +} + +} + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 8 : i32} { + +llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + +// CHECK-LABEL: @no_captures +llvm.func @no_captures() attributes {allocation.offset = 0 : i32} { + ttg.warp_specialize() attributes {warpGroupStartIds = array} + default { + ttg.warp_yield + } + partition0() num_warps(4) { + ttg.warp_return + } : () -> () + llvm.return +} + +} + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.total-num-warps" = 6 : i32} { + +llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + +// CHECK-LABEL: @type_conversion_results +// CHECK-NOT: !tt.ptr +// CHECK-NOT: unrealized_conversion_cast +llvm.func @type_conversion_results(%arg0: !llvm.ptr<1>) attributes {allocation.offset = 0 : i32} { + %0 = builtin.unrealized_conversion_cast %arg0 : !llvm.ptr<1> to !tt.ptr + %1 = ttg.warp_specialize(%0) attributes {allocation.offset = 0 : i32, warpGroupStartIds = array} + default { + // CHECK: llvm.br [[AFTER:\^.*]](%arg0 : !llvm.ptr<1>) + ttg.warp_yield %0 : !tt.ptr + } + partition0(%arg1: !tt.ptr) num_warps(2) { + %3 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to !llvm.ptr<1> + %4 = llvm.load %3 : !llvm.ptr<1> -> i32 + ttg.warp_return + } : (!tt.ptr) -> !tt.ptr + // CHECK: [[AFTER]]([[OUT:%.*]]: !llvm.ptr<1>): + %2 = builtin.unrealized_conversion_cast %1 : !tt.ptr to !llvm.ptr<1> + // CHECK-NEXT: "use"([[OUT]]) + "use"(%2) : (!llvm.ptr<1>) -> () + llvm.return +} + +} diff --git a/third_party/enflame/include/triton/test/LLVMIR/break-phi-struct.ll b/third_party/enflame/include/triton/test/LLVMIR/break-phi-struct.ll new file mode 100644 index 000000000..b27c87588 --- /dev/null +++ b/third_party/enflame/include/triton/test/LLVMIR/break-phi-struct.ll @@ -0,0 +1,33 @@ +; RUN: triton-llvm-opt -break-struct-phi-nodes %s | FileCheck %s + +; CHECK-LABEL: struct +define {i32, i32} @struct(i1 %c) { +; CHECK: br i1 %{{.*}}, label [[TRUE:%.*]], label [[FALSE:%.*]] + br i1 %c, label %true, label %false + +true: + %s.1 = insertvalue {i32, i32} undef, i32 20, 0 + %s.2 = insertvalue {i32, i32} %s.1, i32 200, 1 + +; CHECK-DAG: [[E0:%.*]] = extractvalue { i32, i32 } %{{.*}}, 0 +; CHECK-DAG: [[E1:%.*]] = extractvalue { i32, i32 } %{{.*}}, 1 +; CHECK: br + br label %exit + +false: + %s.3 = insertvalue {i32, i32} undef, i32 30, 0 + %s.4 = insertvalue {i32, i32} %s.3, i32 300, 1 +; CHECK-DAG: [[E2:%.*]] = extractvalue { i32, i32 } %{{.*}}, 0 +; CHECK-DAG: [[E3:%.*]] = extractvalue { i32, i32 } %{{.*}}, 1 +; CHECK: br + br label %exit + +exit: +; CHECK-DAG: [[PHI0:%.*]] = phi i32 [ [[E0]], [[TRUE]] ], [ [[E2]], [[FALSE]] ] +; CHECK-DAG: [[PHI1:%.*]] = phi i32 [ [[E1]], [[TRUE]] ], [ [[E3]], [[FALSE]] ] +; CHECK: [[S0:%.*]] = insertvalue { i32, i32 } undef, i32 [[PHI0]], 0 +; CHECK: [[S1:%.*]] = insertvalue { i32, i32 } [[S0]], i32 [[PHI1]], 1 +; CHECK: ret { i32, i32 } [[S1]] + %r = phi {i32, i32} [ %s.2, %true], [ %s.4, %false ] + ret {i32, i32} %r +} diff --git a/third_party/enflame/include/triton/test/Proton/ops.mlir b/third_party/enflame/include/triton/test/Proton/ops.mlir new file mode 100644 index 000000000..22a17e3f0 --- /dev/null +++ b/third_party/enflame/include/triton/test/Proton/ops.mlir @@ -0,0 +1,15 @@ +// RUN: triton-opt --split-input-file %s -cse -canonicalize | FileCheck %s + +module { + // CHECK-LABEL: proton_record + tt.func @proton_record() { + // CHECK: proton.record() {isStart = true, regionId = 1 : i32} + // CHECK-NEXT: proton.record() {isStart = false, regionId = 1 : i32} + // CHECK-NEXT: tt.return + proton.record() {isStart = true, regionId = 1 : i32} + proton.record() {isStart = false, regionId = 1 : i32} + tt.return + } +} // end module + +// ----- diff --git a/third_party/enflame/include/triton/test/Tools/tensor_layout_print.mlir b/third_party/enflame/include/triton/test/Tools/tensor_layout_print.mlir new file mode 100644 index 000000000..9f802d2e3 --- /dev/null +++ b/third_party/enflame/include/triton/test/Tools/tensor_layout_print.mlir @@ -0,0 +1,58 @@ +// RUN: triton-tensor-layout -i %s -alias-names="blocked" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-BLOCKED + +// RUN: triton-tensor-layout -i %s -alias-names="mfma" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-MFMA + +// RUN: triton-tensor-layout -l "#ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-MFMA + +// RUN: triton-tensor-layout -i %s -alias-names="mfma" -t "tensor<16x16xf16>" -use-hw-view | FileCheck %s --check-prefix=CHECK-HW + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +tt.func @print(%A : !tt.ptr) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #blocked> + %cst1 = arith.constant dense<0.00e+00> : tensor<16x16xf16, #mfma> + tt.return +} + +// CHECK-BLOCKED: Print layout attribute: #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-BLOCKED: T0:0| T4:0, T0:1| T4:1, T0:2| T4:2, T0:3| T4:3, T1:0| T5:0, T1:1| T5:1, T1:2| T5:2, T1:3| T5:3, T2:0| T6:0, T2:1| T6:1, T2:2| T6:2, T2:3| T6:3, T3:0| T7:0, T3:1| T7:1, T3:2| T7:2, T3:3| T7:3 +// CHECK-BLOCKED: T8:0| T12:0, T8:1| T12:1, T8:2| T12:2, T8:3| T12:3, T9:0| T13:0, T9:1| T13:1, T9:2| T13:2, T9:3| T13:3, T10:0| T14:0, T10:1| T14:1, T10:2| T14:2, T10:3| T14:3, T11:0| T15:0, T11:1| T15:1, T11:2| T15:2, T11:3| T15:3 +// CHECK-BLOCKED: T16:0| T20:0, T16:1| T20:1, T16:2| T20:2, T16:3| T20:3, T17:0| T21:0, T17:1| T21:1, T17:2| T21:2, T17:3| T21:3, T18:0| T22:0, T18:1| T22:1, T18:2| T22:2, T18:3| T22:3, T19:0| T23:0, T19:1| T23:1, T19:2| T23:2, T19:3| T23:3 +// CHECK-BLOCKED: T24:0| T28:0, T24:1| T28:1, T24:2| T28:2, T24:3| T28:3, T25:0| T29:0, T25:1| T29:1, T25:2| T29:2, T25:3| T29:3, T26:0| T30:0, T26:1| T30:1, T26:2| T30:2, T26:3| T30:3, T27:0| T31:0, T27:1| T31:1, T27:2| T31:2, T27:3| T31:3 +// CHECK-BLOCKED: T32:0| T36:0, T32:1| T36:1, T32:2| T36:2, T32:3| T36:3, T33:0| T37:0, T33:1| T37:1, T33:2| T37:2, T33:3| T37:3, T34:0| T38:0, T34:1| T38:1, T34:2| T38:2, T34:3| T38:3, T35:0| T39:0, T35:1| T39:1, T35:2| T39:2, T35:3| T39:3 +// CHECK-BLOCKED: T40:0| T44:0, T40:1| T44:1, T40:2| T44:2, T40:3| T44:3, T41:0| T45:0, T41:1| T45:1, T41:2| T45:2, T41:3| T45:3, T42:0| T46:0, T42:1| T46:1, T42:2| T46:2, T42:3| T46:3, T43:0| T47:0, T43:1| T47:1, T43:2| T47:2, T43:3| T47:3 +// CHECK-BLOCKED: T48:0| T52:0, T48:1| T52:1, T48:2| T52:2, T48:3| T52:3, T49:0| T53:0, T49:1| T53:1, T49:2| T53:2, T49:3| T53:3, T50:0| T54:0, T50:1| T54:1, T50:2| T54:2, T50:3| T54:3, T51:0| T55:0, T51:1| T55:1, T51:2| T55:2, T51:3| T55:3 +// CHECK-BLOCKED: T56:0| T60:0, T56:1| T60:1, T56:2| T60:2, T56:3| T60:3, T57:0| T61:0, T57:1| T61:1, T57:2| T61:2, T57:3| T61:3, T58:0| T62:0, T58:1| T62:1, T58:2| T62:2, T58:3| T62:3, T59:0| T63:0, T59:1| T63:1, T59:2| T63:2, T59:3| T63:3 +// CHECK-BLOCKED: T64:0| T68:0, T64:1| T68:1, T64:2| T68:2, T64:3| T68:3, T65:0| T69:0, T65:1| T69:1, T65:2| T69:2, T65:3| T69:3, T66:0| T70:0, T66:1| T70:1, T66:2| T70:2, T66:3| T70:3, T67:0| T71:0, T67:1| T71:1, T67:2| T71:2, T67:3| T71:3 +// CHECK-BLOCKED: T72:0| T76:0, T72:1| T76:1, T72:2| T76:2, T72:3| T76:3, T73:0| T77:0, T73:1| T77:1, T73:2| T77:2, T73:3| T77:3, T74:0| T78:0, T74:1| T78:1, T74:2| T78:2, T74:3| T78:3, T75:0| T79:0, T75:1| T79:1, T75:2| T79:2, T75:3| T79:3 +// CHECK-BLOCKED: T80:0| T84:0, T80:1| T84:1, T80:2| T84:2, T80:3| T84:3, T81:0| T85:0, T81:1| T85:1, T81:2| T85:2, T81:3| T85:3, T82:0| T86:0, T82:1| T86:1, T82:2| T86:2, T82:3| T86:3, T83:0| T87:0, T83:1| T87:1, T83:2| T87:2, T83:3| T87:3 +// CHECK-BLOCKED: T88:0| T92:0, T88:1| T92:1, T88:2| T92:2, T88:3| T92:3, T89:0| T93:0, T89:1| T93:1, T89:2| T93:2, T89:3| T93:3, T90:0| T94:0, T90:1| T94:1, T90:2| T94:2, T90:3| T94:3, T91:0| T95:0, T91:1| T95:1, T91:2| T95:2, T91:3| T95:3 +// CHECK-BLOCKED: T96:0|T100:0, T96:1|T100:1, T96:2|T100:2, T96:3|T100:3, T97:0|T101:0, T97:1|T101:1, T97:2|T101:2, T97:3|T101:3, T98:0|T102:0, T98:1|T102:1, T98:2|T102:2, T98:3|T102:3, T99:0|T103:0, T99:1|T103:1, T99:2|T103:2, T99:3|T103:3 +// CHECK-BLOCKED: T104:0|T108:0, T104:1|T108:1, T104:2|T108:2, T104:3|T108:3, T105:0|T109:0, T105:1|T109:1, T105:2|T109:2, T105:3|T109:3, T106:0|T110:0, T106:1|T110:1, T106:2|T110:2, T106:3|T110:3, T107:0|T111:0, T107:1|T111:1, T107:2|T111:2, T107:3|T111:3 +// CHECK-BLOCKED: T112:0|T116:0, T112:1|T116:1, T112:2|T116:2, T112:3|T116:3, T113:0|T117:0, T113:1|T117:1, T113:2|T117:2, T113:3|T117:3, T114:0|T118:0, T114:1|T118:1, T114:2|T118:2, T114:3|T118:3, T115:0|T119:0, T115:1|T119:1, T115:2|T119:2, T115:3|T119:3 +// CHECK-BLOCKED: T120:0|T124:0, T120:1|T124:1, T120:2|T124:2, T120:3|T124:3, T121:0|T125:0, T121:1|T125:1, T121:2|T125:2, T121:3|T125:3, T122:0|T126:0, T122:1|T126:1, T122:2|T126:2, T122:3|T126:3, T123:0|T127:0, T123:1|T127:1, T123:2|T127:2, T123:3|T127:3 + + +// CHECK-MFMA: Print layout attribute: {{.*}}#ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +// CHECK-MFMA: T0:0| T64:0|T128:0|T192:0, T0:1| T64:1|T128:1|T192:1, T0:2| T64:2|T128:2|T192:2, T0:3| T64:3|T128:3|T192:3, T16:0| T80:0|T144:0|T208:0, T16:1| T80:1|T144:1|T208:1, T16:2| T80:2|T144:2|T208:2, T16:3| T80:3|T144:3|T208:3, T32:0| T96:0|T160:0|T224:0, T32:1| T96:1|T160:1|T224:1, T32:2| T96:2|T160:2|T224:2, T32:3| T96:3|T160:3|T224:3, T48:0|T112:0|T176:0|T240:0, T48:1|T112:1|T176:1|T240:1, T48:2|T112:2|T176:2|T240:2, T48:3|T112:3|T176:3|T240:3 +// CHECK-MFMA: T1:0| T65:0|T129:0|T193:0, T1:1| T65:1|T129:1|T193:1, T1:2| T65:2|T129:2|T193:2, T1:3| T65:3|T129:3|T193:3, T17:0| T81:0|T145:0|T209:0, T17:1| T81:1|T145:1|T209:1, T17:2| T81:2|T145:2|T209:2, T17:3| T81:3|T145:3|T209:3, T33:0| T97:0|T161:0|T225:0, T33:1| T97:1|T161:1|T225:1, T33:2| T97:2|T161:2|T225:2, T33:3| T97:3|T161:3|T225:3, T49:0|T113:0|T177:0|T241:0, T49:1|T113:1|T177:1|T241:1, T49:2|T113:2|T177:2|T241:2, T49:3|T113:3|T177:3|T241:3 +// CHECK-MFMA: T2:0| T66:0|T130:0|T194:0, T2:1| T66:1|T130:1|T194:1, T2:2| T66:2|T130:2|T194:2, T2:3| T66:3|T130:3|T194:3, T18:0| T82:0|T146:0|T210:0, T18:1| T82:1|T146:1|T210:1, T18:2| T82:2|T146:2|T210:2, T18:3| T82:3|T146:3|T210:3, T34:0| T98:0|T162:0|T226:0, T34:1| T98:1|T162:1|T226:1, T34:2| T98:2|T162:2|T226:2, T34:3| T98:3|T162:3|T226:3, T50:0|T114:0|T178:0|T242:0, T50:1|T114:1|T178:1|T242:1, T50:2|T114:2|T178:2|T242:2, T50:3|T114:3|T178:3|T242:3 +// CHECK-MFMA: T3:0| T67:0|T131:0|T195:0, T3:1| T67:1|T131:1|T195:1, T3:2| T67:2|T131:2|T195:2, T3:3| T67:3|T131:3|T195:3, T19:0| T83:0|T147:0|T211:0, T19:1| T83:1|T147:1|T211:1, T19:2| T83:2|T147:2|T211:2, T19:3| T83:3|T147:3|T211:3, T35:0| T99:0|T163:0|T227:0, T35:1| T99:1|T163:1|T227:1, T35:2| T99:2|T163:2|T227:2, T35:3| T99:3|T163:3|T227:3, T51:0|T115:0|T179:0|T243:0, T51:1|T115:1|T179:1|T243:1, T51:2|T115:2|T179:2|T243:2, T51:3|T115:3|T179:3|T243:3 +// CHECK-MFMA: T4:0| T68:0|T132:0|T196:0, T4:1| T68:1|T132:1|T196:1, T4:2| T68:2|T132:2|T196:2, T4:3| T68:3|T132:3|T196:3, T20:0| T84:0|T148:0|T212:0, T20:1| T84:1|T148:1|T212:1, T20:2| T84:2|T148:2|T212:2, T20:3| T84:3|T148:3|T212:3, T36:0|T100:0|T164:0|T228:0, T36:1|T100:1|T164:1|T228:1, T36:2|T100:2|T164:2|T228:2, T36:3|T100:3|T164:3|T228:3, T52:0|T116:0|T180:0|T244:0, T52:1|T116:1|T180:1|T244:1, T52:2|T116:2|T180:2|T244:2, T52:3|T116:3|T180:3|T244:3 +// CHECK-MFMA: T5:0| T69:0|T133:0|T197:0, T5:1| T69:1|T133:1|T197:1, T5:2| T69:2|T133:2|T197:2, T5:3| T69:3|T133:3|T197:3, T21:0| T85:0|T149:0|T213:0, T21:1| T85:1|T149:1|T213:1, T21:2| T85:2|T149:2|T213:2, T21:3| T85:3|T149:3|T213:3, T37:0|T101:0|T165:0|T229:0, T37:1|T101:1|T165:1|T229:1, T37:2|T101:2|T165:2|T229:2, T37:3|T101:3|T165:3|T229:3, T53:0|T117:0|T181:0|T245:0, T53:1|T117:1|T181:1|T245:1, T53:2|T117:2|T181:2|T245:2, T53:3|T117:3|T181:3|T245:3 +// CHECK-MFMA: T6:0| T70:0|T134:0|T198:0, T6:1| T70:1|T134:1|T198:1, T6:2| T70:2|T134:2|T198:2, T6:3| T70:3|T134:3|T198:3, T22:0| T86:0|T150:0|T214:0, T22:1| T86:1|T150:1|T214:1, T22:2| T86:2|T150:2|T214:2, T22:3| T86:3|T150:3|T214:3, T38:0|T102:0|T166:0|T230:0, T38:1|T102:1|T166:1|T230:1, T38:2|T102:2|T166:2|T230:2, T38:3|T102:3|T166:3|T230:3, T54:0|T118:0|T182:0|T246:0, T54:1|T118:1|T182:1|T246:1, T54:2|T118:2|T182:2|T246:2, T54:3|T118:3|T182:3|T246:3 +// CHECK-MFMA: T7:0| T71:0|T135:0|T199:0, T7:1| T71:1|T135:1|T199:1, T7:2| T71:2|T135:2|T199:2, T7:3| T71:3|T135:3|T199:3, T23:0| T87:0|T151:0|T215:0, T23:1| T87:1|T151:1|T215:1, T23:2| T87:2|T151:2|T215:2, T23:3| T87:3|T151:3|T215:3, T39:0|T103:0|T167:0|T231:0, T39:1|T103:1|T167:1|T231:1, T39:2|T103:2|T167:2|T231:2, T39:3|T103:3|T167:3|T231:3, T55:0|T119:0|T183:0|T247:0, T55:1|T119:1|T183:1|T247:1, T55:2|T119:2|T183:2|T247:2, T55:3|T119:3|T183:3|T247:3 +// CHECK-MFMA: T8:0| T72:0|T136:0|T200:0, T8:1| T72:1|T136:1|T200:1, T8:2| T72:2|T136:2|T200:2, T8:3| T72:3|T136:3|T200:3, T24:0| T88:0|T152:0|T216:0, T24:1| T88:1|T152:1|T216:1, T24:2| T88:2|T152:2|T216:2, T24:3| T88:3|T152:3|T216:3, T40:0|T104:0|T168:0|T232:0, T40:1|T104:1|T168:1|T232:1, T40:2|T104:2|T168:2|T232:2, T40:3|T104:3|T168:3|T232:3, T56:0|T120:0|T184:0|T248:0, T56:1|T120:1|T184:1|T248:1, T56:2|T120:2|T184:2|T248:2, T56:3|T120:3|T184:3|T248:3 +// CHECK-MFMA: T9:0| T73:0|T137:0|T201:0, T9:1| T73:1|T137:1|T201:1, T9:2| T73:2|T137:2|T201:2, T9:3| T73:3|T137:3|T201:3, T25:0| T89:0|T153:0|T217:0, T25:1| T89:1|T153:1|T217:1, T25:2| T89:2|T153:2|T217:2, T25:3| T89:3|T153:3|T217:3, T41:0|T105:0|T169:0|T233:0, T41:1|T105:1|T169:1|T233:1, T41:2|T105:2|T169:2|T233:2, T41:3|T105:3|T169:3|T233:3, T57:0|T121:0|T185:0|T249:0, T57:1|T121:1|T185:1|T249:1, T57:2|T121:2|T185:2|T249:2, T57:3|T121:3|T185:3|T249:3 +// CHECK-MFMA: T10:0| T74:0|T138:0|T202:0, T10:1| T74:1|T138:1|T202:1, T10:2| T74:2|T138:2|T202:2, T10:3| T74:3|T138:3|T202:3, T26:0| T90:0|T154:0|T218:0, T26:1| T90:1|T154:1|T218:1, T26:2| T90:2|T154:2|T218:2, T26:3| T90:3|T154:3|T218:3, T42:0|T106:0|T170:0|T234:0, T42:1|T106:1|T170:1|T234:1, T42:2|T106:2|T170:2|T234:2, T42:3|T106:3|T170:3|T234:3, T58:0|T122:0|T186:0|T250:0, T58:1|T122:1|T186:1|T250:1, T58:2|T122:2|T186:2|T250:2, T58:3|T122:3|T186:3|T250:3 +// CHECK-MFMA: T11:0| T75:0|T139:0|T203:0, T11:1| T75:1|T139:1|T203:1, T11:2| T75:2|T139:2|T203:2, T11:3| T75:3|T139:3|T203:3, T27:0| T91:0|T155:0|T219:0, T27:1| T91:1|T155:1|T219:1, T27:2| T91:2|T155:2|T219:2, T27:3| T91:3|T155:3|T219:3, T43:0|T107:0|T171:0|T235:0, T43:1|T107:1|T171:1|T235:1, T43:2|T107:2|T171:2|T235:2, T43:3|T107:3|T171:3|T235:3, T59:0|T123:0|T187:0|T251:0, T59:1|T123:1|T187:1|T251:1, T59:2|T123:2|T187:2|T251:2, T59:3|T123:3|T187:3|T251:3 +// CHECK-MFMA: T12:0| T76:0|T140:0|T204:0, T12:1| T76:1|T140:1|T204:1, T12:2| T76:2|T140:2|T204:2, T12:3| T76:3|T140:3|T204:3, T28:0| T92:0|T156:0|T220:0, T28:1| T92:1|T156:1|T220:1, T28:2| T92:2|T156:2|T220:2, T28:3| T92:3|T156:3|T220:3, T44:0|T108:0|T172:0|T236:0, T44:1|T108:1|T172:1|T236:1, T44:2|T108:2|T172:2|T236:2, T44:3|T108:3|T172:3|T236:3, T60:0|T124:0|T188:0|T252:0, T60:1|T124:1|T188:1|T252:1, T60:2|T124:2|T188:2|T252:2, T60:3|T124:3|T188:3|T252:3 +// CHECK-MFMA: T13:0| T77:0|T141:0|T205:0, T13:1| T77:1|T141:1|T205:1, T13:2| T77:2|T141:2|T205:2, T13:3| T77:3|T141:3|T205:3, T29:0| T93:0|T157:0|T221:0, T29:1| T93:1|T157:1|T221:1, T29:2| T93:2|T157:2|T221:2, T29:3| T93:3|T157:3|T221:3, T45:0|T109:0|T173:0|T237:0, T45:1|T109:1|T173:1|T237:1, T45:2|T109:2|T173:2|T237:2, T45:3|T109:3|T173:3|T237:3, T61:0|T125:0|T189:0|T253:0, T61:1|T125:1|T189:1|T253:1, T61:2|T125:2|T189:2|T253:2, T61:3|T125:3|T189:3|T253:3 +// CHECK-MFMA: T14:0| T78:0|T142:0|T206:0, T14:1| T78:1|T142:1|T206:1, T14:2| T78:2|T142:2|T206:2, T14:3| T78:3|T142:3|T206:3, T30:0| T94:0|T158:0|T222:0, T30:1| T94:1|T158:1|T222:1, T30:2| T94:2|T158:2|T222:2, T30:3| T94:3|T158:3|T222:3, T46:0|T110:0|T174:0|T238:0, T46:1|T110:1|T174:1|T238:1, T46:2|T110:2|T174:2|T238:2, T46:3|T110:3|T174:3|T238:3, T62:0|T126:0|T190:0|T254:0, T62:1|T126:1|T190:1|T254:1, T62:2|T126:2|T190:2|T254:2, T62:3|T126:3|T190:3|T254:3 +// CHECK-MFMA: T15:0| T79:0|T143:0|T207:0, T15:1| T79:1|T143:1|T207:1, T15:2| T79:2|T143:2|T207:2, T15:3| T79:3|T143:3|T207:3, T31:0| T95:0|T159:0|T223:0, T31:1| T95:1|T159:1|T223:1, T31:2| T95:2|T159:2|T223:2, T31:3| T95:3|T159:3|T223:3, T47:0|T111:0|T175:0|T239:0, T47:1|T111:1|T175:1|T239:1, T47:2|T111:2|T175:2|T239:2, T47:3|T111:3|T175:3|T239:3, T63:0|T127:0|T191:0|T255:0, T63:1|T127:1|T191:1|T255:1, T63:2|T127:2|T191:2|T255:2, T63:3|T127:3|T191:3|T255:3 + + +// CHECK-HW: Warp0: +// CHECK-HW: Warp1: +// CHECK-HW: Warp2: +// CHECK-HW: Warp3: diff --git a/third_party/enflame/include/triton/test/Triton/canonicalize.mlir b/third_party/enflame/include/triton/test/Triton/canonicalize.mlir new file mode 100644 index 000000000..ef448d500 --- /dev/null +++ b/third_party/enflame/include/triton/test/Triton/canonicalize.mlir @@ -0,0 +1,186 @@ +// RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s + +// CHECK-LABEL: dead_load +tt.func @dead_load(%ptr: tensor<32x128x!tt.ptr>) { + %mask = arith.constant dense : tensor<32x128xi1> + %other = arith.constant dense<0.00e+00> : tensor<32x128xf16> + // CHECK-NOT: tt.load {{.*}}isVolatile = false + // CHECK: tt.load {{.*}}isVolatile = true + %a = tt.load %ptr, %mask, %other : tensor<32x128x!tt.ptr> + %b = tt.load %ptr, %mask, %other {isVolatile = true} : tensor<32x128x!tt.ptr> + tt.return +} + +// ----- + +// CHECK-LABEL: make_range +tt.func @make_range() -> (tensor<128x1xi32>, tensor<1xi32>) { + // CHECK-DAG: %[[c:.*]] = arith.constant dense<0> : tensor<128x1xi32> + %a = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> + %b = tt.expand_dims %a {axis = 1 : i32} : tensor<1xi32> -> tensor<1x1xi32> + %c = tt.broadcast %b : tensor<1x1xi32> -> tensor<128x1xi32> + + // CHECK-DAG: %[[d:.*]] = arith.constant dense<1> : tensor<1xi32> + %d = tt.make_range {end = 2 : i32, start = 1 : i32} : tensor<1xi32> + + // CHECK-DAG: tt.return %[[c]], %[[d]] : tensor<128x1xi32>, tensor<1xi32> + tt.return %c, %d : tensor<128x1xi32>, tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: fold_addptr +tt.func @fold_addptr(%arg: tensor<64x64x!tt.ptr>) -> (tensor<64x64x!tt.ptr>) { + // CHECK-NOT: tt.addptr + // CHECK-NOT: arith.constant + // CHECK: tt.return %arg + %c0_i32 = arith.constant dense<0> : tensor<64x64xi32> + %0 = tt.addptr %arg, %c0_i32 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> + tt.return %0 : tensor<64x64x!tt.ptr> +} + +// ----- + +// CHECK-LABEL: fold_addptr_scalar +tt.func @fold_addptr_scalar(%arg: !tt.ptr) -> (!tt.ptr) { + // CHECK-NOT: tt.addptr + // CHECK-NOT: arith.constant + // CHECK: tt.return %arg + %c0_i32 = arith.constant 0 : i32 + %0 = tt.addptr %arg, %c0_i32 : !tt.ptr, i32 + tt.return %0 : !tt.ptr +} + +// ----- + +// CHECK-LABEL: fold_advance +tt.func @fold_advance(%arg: !tt.ptr>) -> (!tt.ptr>) { + %c0_i32 = arith.constant 0 : i32 + %0 = tt.advance %arg, [%c0_i32, %c0_i32] : > + // CHECK-NOT: tt.advance + // CHECK: tt.return %arg + tt.return %0 : !tt.ptr> +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#sliced0 = #ttg.slice<{dim = 1, parent = #blocked0}> + +// CHECK-LABEL: fn +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { +tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){ + // CHECK: %[[a:.*]] = tt.expand_dims + // CHECK: tt.broadcast %[[a]] + %a = tt.broadcast %arg0 : tensor<1xf32, #sliced0> -> tensor<32xf32, #sliced0> + %b = tt.expand_dims %a {axis = 1 : i32} : tensor<32xf32, #sliced0> -> tensor<32x1xf32, #blocked0> + tt.return %b : tensor<32x1xf32, #blocked0> +} +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fp_to_fp_pos_zero_fold() -> tensor<32x128xf8E4M3FNUZ, #blocked> { + // CHECK-LABEL: fp_to_fp_pos_zero_fold + // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<0.000000e+00> : tensor<32x128xf8E4M3FNUZ, #blocked> + // CHECK-NEXT: tt.return %[[cst_folded]] + %cst = arith.constant dense<0.00e+00> : tensor<32x128xf32, #blocked> + %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked> + tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked> + } +} // end module + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fp_to_fp_pos_zero_fold_scalar() -> f8E4M3FNUZ { + // CHECK-LABEL: fp_to_fp_pos_zero_fold_scalar + // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant 0.000000e+00 : f8E4M3FNUZ + // CHECK-NEXT: tt.return %[[cst_folded]] + %cst = arith.constant 0.00e+00 : f32 + %cst_converted = tt.fp_to_fp %cst, rounding = rtne : f32 -> f8E4M3FNUZ + tt.return %cst_converted : f8E4M3FNUZ + } +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fp_to_fp_neg_zero_fold() -> tensor<32x128xf8E4M3FN, #blocked> { + // CHECK-LABEL: fp_to_fp_neg_zero_fold + // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<-0.000000e+00> : tensor<32x128xf8E4M3FN, #blocked> + // CHECK-NEXT: tt.return %[[cst_folded]] + %cst = arith.constant dense<-0.00e+00> : tensor<32x128xf32, #blocked> + %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FN, #blocked> + tt.return %cst_converted : tensor<32x128xf8E4M3FN, #blocked> + } +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fp_to_fp_neg_zero_fold() -> tensor<32x128xf8E4M3FNUZ, #blocked> { + // CHECK-LABEL: fp_to_fp_neg_zero_fold + // We fold to the positive zero here given by definition f8E4M3FNUZ does not have negative zero encoding. + // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<0.000000e+00> : tensor<32x128xf8E4M3FNUZ, #blocked> + // CHECK-NEXT: tt.return %[[cst_folded]] + %cst = arith.constant dense<-0.00e+00> : tensor<32x128xf32, #blocked> + %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked> + tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked> + } +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fold_fp_to_fp_non_zero_nofold() -> tensor<32x128xf8E4M3FNUZ, #blocked> { + // CHECK-LABEL: fold_fp_to_fp_non_zero_nofold + // CHECK-NEXT: %[[cst:.+]] = arith.constant dense<0xFF800000> : tensor<32x128xf32, #blocked> + // CHECK-NEXT: %[[cst_cvt:.+]] = tt.fp_to_fp %[[cst]] + // CHECK-NEXT: tt.return %[[cst_cvt]] + %cst = arith.constant dense<0xFF800000> : tensor<32x128xf32, #blocked> + %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked> + tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked> + } +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @fold_fp_to_fp_non_constant_nofold(%arg0: tensor<32x128xf32, #blocked>) -> tensor<32x128xf8E4M3FNUZ, #blocked> { + // CHECK-LABEL: fold_fp_to_fp_non_constant_nofold + // CHECK-NEXT: %[[arg_cvt:.+]] = tt.fp_to_fp %arg0 + // CHECK-NEXT: tt.return %[[arg_cvt]] + %cst_converted = tt.fp_to_fp %arg0, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked> + tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked> + } +} // end module + +// ----- + +// CHECK-LABEL: @fold_broadcast_constant_pattern +tt.func @fold_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { + // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32> + %const = arith.constant dense<1.0> : tensor<8x1xf32> + %bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32> + + // CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32> + tt.return %bst_out : tensor<8x2xf32> +} + +// ----- + +// CHECK-LABEL: @fold_transpose_constant +tt.func @fold_transpose_constant() -> tensor<128x16xf32> { + // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<128x16xf32> + %cst = arith.constant dense<1.0> : tensor<16x128xf32> + %r = tt.trans %cst {order = array} : tensor<16x128xf32> -> tensor<128x16xf32> + // CHECK-NEXT: tt.return %[[cst]] : tensor<128x16xf32> + tt.return %r : tensor<128x16xf32> +} diff --git a/third_party/enflame/include/triton/test/Triton/combine.mlir b/third_party/enflame/include/triton/test/Triton/combine.mlir new file mode 100644 index 000000000..7bb27aa01 --- /dev/null +++ b/third_party/enflame/include/triton/test/Triton/combine.mlir @@ -0,0 +1,392 @@ +// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s + +// We don't combine if the dot result is used by more than one op. +// CHECK-LABEL: @test_combine_dot_add_invalid_pattern +tt.func @test_combine_dot_add_invalid_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) { + // CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32> + // CHECK-DAG: %[[e:.*]] = arith.constant dense<4.000000e+00> : tensor<128x128xf32> + %a = arith.constant dense<1.0> : tensor<128x128xf32> + %b = arith.constant dense<2.0> : tensor<128x128xf32> + %zero = arith.constant dense<0.0> : tensor<128x128xf32> + %d = arith.constant dense<3.0> : tensor<128x128xf32> + %e = arith.constant dense<4.0> : tensor<128x128xf32> + + %dot_out = tt.dot %a, %b, %zero : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + + // CHECK: arith.addf %{{.*}}, %[[d]] : tensor<128x128xf32> + %res0 = arith.addf %dot_out, %d : tensor<128x128xf32> + + // CHECK-NEXT: arith.addf %{{.*}}, %[[e]] : tensor<128x128xf32> + %res1 = arith.addf %dot_out, %e : tensor<128x128xf32> + + tt.return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32> +} + + +// CHECK-LABEL: @test_combine_dot_add_pattern +tt.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>) { + // CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32> + // CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32> + // CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32> + %a = arith.constant dense<1.0> : tensor<128x128xf32> + %b = arith.constant dense<2.0> : tensor<128x128xf32> + %zero = arith.constant dense<0.0> : tensor<128x128xf32> + %d = arith.constant dense<3.0> : tensor<128x128xf32> + + %dot_out = tt.dot %a, %b, %zero : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + + // CHECK-NEXT: %[[res:.*]] = tt.dot %[[a]], %[[b]], %[[d]] : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: tt.return %[[res]] : tensor<128x128xf32> + %res = arith.addf %dot_out, %d : tensor<128x128xf32> + + tt.return %res : tensor<128x128xf32> +} + + +// CHECK-LABEL: @test_combine_dot_add_rev_pattern +tt.func @test_combine_dot_add_rev_pattern() -> (tensor<128x128xf32>) { + // CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32> + // CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32> + // CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32> + %a = arith.constant dense<1.0> : tensor<128x128xf32> + %b = arith.constant dense<2.0> : tensor<128x128xf32> + %zero = arith.constant dense<0.0> : tensor<128x128xf32> + %d = arith.constant dense<3.0> : tensor<128x128xf32> + + %dot_out = tt.dot %a, %b, %zero : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + + // CHECK-NEXT: %[[res:.*]] = tt.dot %[[a]], %[[b]], %[[d]] : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: tt.return %[[res]] : tensor<128x128xf32> + %res = arith.addf %d, %dot_out : tensor<128x128xf32> + + tt.return %res : tensor<128x128xf32> +} + + +// CHECK-LABEL: @test_combine_addptr_pattern +tt.func @test_combine_addptr_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> { + %off0 = arith.constant 10 : i32 + %off1 = arith.constant 15 : i32 + + // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32> + + %base_ = tt.splat %base : !tt.ptr -> tensor<8x!tt.ptr> + + // CHECK-NEXT: %[[tmp0:.*]] = tt.splat %{{.*}} : !tt.ptr -> tensor<8x!tt.ptr> + + %idx0 = tt.splat %off0 : i32 -> tensor<8xi32> + %idx1 = tt.splat %off1 : i32 -> tensor<8xi32> + + // CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr>, tensor<8xi32> + %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr>, tensor<8xi32> + %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr>, tensor<8xi32> + + tt.return %ptr1 : tensor<8x!tt.ptr> +} + +// CHECK-LABEL: @test_combine_addptr_pattern_i64 +tt.func @test_combine_addptr_pattern_i64(%base: !tt.ptr) -> tensor<8x!tt.ptr> { + %off0 = arith.constant 10 : i64 + %off1 = arith.constant dense<15> : tensor<8xi64> + + // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi64> + + %base_ = tt.splat %base : !tt.ptr -> tensor<8x!tt.ptr> + + // CHECK-NEXT: %[[tmp0:.*]] = tt.splat %{{.*}} : !tt.ptr -> tensor<8x!tt.ptr> + + %idx0 = tt.splat %off0 : i64 -> tensor<8xi64> + + // CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr>, tensor<8xi64> + %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr>, tensor<8xi64> + %ptr1 = tt.addptr %ptr0, %off1 : tensor<8x!tt.ptr>, tensor<8xi64> + + tt.return %ptr1 : tensor<8x!tt.ptr> +} + +// CHECK-LABEL: @test_combine_addptr_pattern_scalar +tt.func @test_combine_addptr_pattern_scalar(%base: !tt.ptr) -> !tt.ptr { + %off0 = arith.constant 10 : i32 + %off1 = arith.constant 15 : i32 + + // CHECK-NEXT: %[[cst:.*]] = arith.constant 25 : i32 + // CHECK-NEXT: %0 = tt.addptr %{{.*}}, %[[cst]] : !tt.ptr, i32 + %ptr0 = tt.addptr %base, %off0 : !tt.ptr, i32 + %ptr1 = tt.addptr %ptr0, %off1 : !tt.ptr, i32 + + tt.return %ptr1 : !tt.ptr +} + +// CHECK-LABEL: @test_not_combine_addptr_pattern_1 +tt.func @test_not_combine_addptr_pattern_1(%base: !tt.ptr, %idx0: tensor<8xi32>) -> tensor<8x!tt.ptr> { + %off1 = arith.constant 15 : i32 + + %base_ = tt.splat %base : !tt.ptr -> tensor<8x!tt.ptr> + %idx1 = tt.splat %off1 : i32 -> tensor<8xi32> + + // CHECK: tt.addptr + // CHECK-NEXT: tt.addptr + %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr>, tensor<8xi32> + %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.return %ptr1 : tensor<8x!tt.ptr> +} + +// CHECK-LABEL: @test_not_combine_addptr_pattern +tt.func @test_not_combine_addptr_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> { + %off0 = arith.constant 10 : i16 + %off1 = arith.constant 15 : i32 + + // CHECK-DAG: %[[cst:.*]] = arith.constant dense<10> : tensor<8xi16> + // CHECK-DAG: %[[cst1:.*]] = arith.constant dense<15> : tensor<8xi32> + + %base_ = tt.splat %base : !tt.ptr -> tensor<8x!tt.ptr> + + %idx0 = tt.splat %off0 : i16 -> tensor<8xi16> + %idx1 = tt.splat %off1 : i32 -> tensor<8xi32> + + %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr>, tensor<8xi16> + %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr>, tensor<8xi32> + + tt.return %ptr1 : tensor<8x!tt.ptr> +} + +// CHECK-LABEL: @test_not_combine_addptr_pattern_overflow +tt.func @test_not_combine_addptr_pattern_overflow(%base: !tt.ptr) -> tensor<8x!tt.ptr> { + %off0 = arith.constant 127 : i8 + %off1 = arith.constant 1 : i8 + + // CHECK-DAG: %[[cst:.*]] = arith.constant dense<127> : tensor<8xi8> + // CHECK-DAG: %[[cst1:.*]] = arith.constant dense<1> : tensor<8xi8> + + %base_ = tt.splat %base : !tt.ptr -> tensor<8x!tt.ptr> + + %idx0 = tt.splat %off0 : i8 -> tensor<8xi8> + %idx1 = tt.splat %off1 : i8 -> tensor<8xi8> + + %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr>, tensor<8xi8> + %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr>, tensor<8xi8> + + tt.return %ptr1 : tensor<8x!tt.ptr> +} + +// CHECK-LABEL: @test_combine_select_masked_load_pattern +tt.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) { + %mask = tt.splat %cond : i1 -> tensor<8xi1> + %false_val = arith.constant dense<0.0> : tensor<8xf32> + + // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} : tensor<8x!tt.ptr> + %x = tt.load %ptr, %mask, %false_val : tensor<8x!tt.ptr> + %0 = arith.select %cond, %x, %false_val : tensor<8xf32> + + // CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} : tensor<8x!tt.ptr> + %y = tt.load %ptr, %mask, %false_val : tensor<8x!tt.ptr> + %1 = arith.select %cond, %y, %false_val : tensor<8xf32> + + // CHECK: tt.return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32> + tt.return %0, %1 : tensor<8xf32>, tensor<8xf32> +} + +// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern +tt.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { + %false_val = arith.constant dense<0.0> : tensor<8xf32> + + // Case 1: value at the "load" position is not an "op". Select should not be canonicalized. + // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> + %0 = arith.select %cond0, %dummy_load, %false_val : tensor<8xf32> + + // Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized. + %real_load0 = tt.load %ptr, %dummy_broadcast, %false_val : tensor<8x!tt.ptr> + // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> + %1 = arith.select %cond0, %real_load0, %false_val : tensor<8xf32> + + // Case 3: condition of "broadcast" is not the same as the condition of "select". Select should not be canonicalized. + %cond0_ = tt.splat %cond0 : i1 -> tensor<8xi1> + %real_load1 = tt.load %ptr, %cond0_, %false_val : tensor<8x!tt.ptr> + // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> + %2 = arith.select %cond1, %real_load1, %false_val : tensor<8xf32> + + tt.return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32> +} + +// CHECK-LABEL: @test_canonicalize_masked_load_pattern +tt.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { + %true_mask = arith.constant dense : tensor<8xi1> + %false_mask = arith.constant dense : tensor<8xi1> + %other_val = arith.constant dense<0.0> : tensor<8xf32> + + // true_mask with other + // CHECK: %[[res1:.*]] = tt.load %{{.*}} : tensor<8x!tt.ptr> + %x = tt.load %ptr, %true_mask : tensor<8x!tt.ptr> + + // true_mask without other + // CHECK: %[[res2:.*]] = tt.load %{{.*}} : tensor<8x!tt.ptr> + %y = tt.load %ptr, %true_mask, %other_val : tensor<8x!tt.ptr> + + // false_mask with other. It should become "other" (i.e., %y) + %z = tt.load %ptr, %false_mask, %y : tensor<8x!tt.ptr> + + // CHECK: tt.return %[[res1]], %[[res2]], %[[res2]] : tensor<8xf32>, tensor<8xf32>, tensor<8xf32> + tt.return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32> +} + +// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern +tt.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) { + %other_val = arith.constant dense<0.0> : tensor<8xf32> + + // Case: value at the "mask" position is not an "op". Load should not be canonicalized. + // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}} : tensor<8x!tt.ptr> + %x = tt.load %ptr, %mask : tensor<8x!tt.ptr> + // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} : tensor<8x!tt.ptr> + %y = tt.load %ptr, %mask, %other_val : tensor<8x!tt.ptr> + + tt.return %x, %y: tensor<8xf32>, tensor<8xf32> +} + +// CHECK-LABEL: @test_canonicalize_masked_store_pattern +tt.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr>, %val: tensor<8xf32>) { + %true_mask = arith.constant dense : tensor<8xi1> + %false_mask = arith.constant dense : tensor<8xi1> + + // CHECK: tt.store %{{.*}}, %{{.*}} : tensor<8x!tt.ptr> + tt.store %ptr, %val, %true_mask : tensor<8x!tt.ptr> + + // The following store should disappear. + // CHECK-NEXT: tt.return + tt.store %ptr, %val, %false_mask : tensor<8x!tt.ptr> + tt.return +} + +// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern +tt.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr>, %val: tensor<8xf32>, %mask: tensor<8xi1>) { + // Case: value at the "mask" position is not an "op". Store should not be canonicalized. + // CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8x!tt.ptr> + tt.store %ptr, %val, %mask : tensor<8x!tt.ptr> + tt.return +} + +// CHECK-LABEL: @test_canonicalize_expand_dims +tt.func @test_canonicalize_expand_dims(%arg0: tensor, %arg1: tensor<1xf32>) -> (tensor<1x8xf32>, tensor<8x8xf32>) { + %splat = tt.splat %arg0 : tensor -> tensor<8xf32> + // CHECK: %{{.*}} = tt.splat %arg0 : tensor -> tensor<1x8xf32> + %ed = tt.expand_dims %splat {axis = 0 : i32} : tensor<8xf32> -> tensor<1x8xf32> + + // CHECK-NEXT: %[[ed2:.*]] = tt.expand_dims %arg1 {axis = 0 : i32} : tensor<1xf32> -> tensor<1x1xf32> + // CHECK-NEXT: %{{.*}} = tt.broadcast %[[ed2]] : tensor<1x1xf32> -> tensor<8x8xf32> + %bc = tt.broadcast %arg1 : tensor<1xf32> -> tensor<8xf32> + %ed2 = tt.expand_dims %bc {axis = 0 : i32} : tensor<8xf32> -> tensor<1x8xf32> + %bc2 = tt.broadcast %ed2 : tensor<1x8xf32> -> tensor<8x8xf32> + + tt.return %ed, %bc2 : tensor<1x8xf32>, tensor<8x8xf32> +} + +// CHECK-LABEL: @test_canonicalize_view +tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>, tensor<2x2x2xf32>) { + %view0 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x4xf32> + // CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<4x2xf32> + %view1 = tt.reshape %view0 allow_reorder : tensor<2x4xf32> -> tensor<4x2xf32> + + %splat = tt.splat %arg1 : tensor -> tensor<8xf32> + // CHECK: %{{.*}} = tt.splat %arg1 : tensor -> tensor<2x2x2xf32> + %view2 = tt.reshape %splat allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32> + + %view3 = tt.reshape %arg0 : tensor<8xf32> -> tensor<8xf32> + // CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32> + %add = arith.addf %view3, %arg0 : tensor<8xf32> + + // CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32> + %reshape = tt.reshape %view0 : tensor<2x4xf32> -> tensor<2x2x2xf32> + + tt.return %view1, %view2, %add, %reshape : tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>, tensor<2x2x2xf32> +} + +// CHECK-LABEL: @test_canonicalize_reshape +tt.func @test_canonicalize_reshape(%arg0: tensor<8xf32>, %arg1: tensor) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>, tensor<2x2x2xf32>) { + %reshape0 = tt.reshape %arg0 : tensor<8xf32> -> tensor<2x4xf32> + // CHECK: %{{.*}} = tt.reshape %arg0 : tensor<8xf32> -> tensor<4x2xf32> + %reshape1 = tt.reshape %reshape0 : tensor<2x4xf32> -> tensor<4x2xf32> + + %splat = tt.splat %arg1 : tensor -> tensor<8xf32> + // CHECK: %{{.*}} = tt.splat %arg1 : tensor -> tensor<2x2x2xf32> + %reshape2 = tt.reshape %splat : tensor<8xf32> -> tensor<2x2x2xf32> + + %reshape3 = tt.reshape %arg0 : tensor<8xf32> -> tensor<8xf32> + // CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32> + %add = arith.addf %reshape3, %arg0 : tensor<8xf32> + + // CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32> + %view = tt.reshape %reshape0 allow_reorder : tensor<2x4xf32> -> tensor<2x2x2xf32> + + tt.return %reshape1, %reshape2, %add, %view : tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>, tensor<2x2x2xf32> +} + +// CHECK-LABEL: @test_canonicalize_broadcast +tt.func @test_canonicalize_broadcast(%arg0: tensor<1x1x8xf32>, %arg1: tensor) -> (tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32>) { + %broadcast0 = tt.broadcast %arg0 : tensor<1x1x8xf32> -> tensor<1x2x8xf32> + // CHECK: %{{.*}} = tt.broadcast %arg0 : tensor<1x1x8xf32> -> tensor<4x2x8xf32> + %broadcast1 = tt.broadcast %broadcast0 : tensor<1x2x8xf32> -> tensor<4x2x8xf32> + + %splat = tt.splat %arg1 : tensor -> tensor<1x8xf32> + // CHECK: %{{.*}} = tt.splat %arg1 : tensor -> tensor<8x8xf32> + %broadcast2 = tt.broadcast %splat : tensor<1x8xf32> -> tensor<8x8xf32> + + %broadcast3 = tt.broadcast %arg0 : tensor<1x1x8xf32> -> tensor<1x1x8xf32> + // CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<1x1x8xf32> + %add = arith.addf %broadcast3, %arg0 : tensor<1x1x8xf32> + + tt.return %broadcast1, %broadcast2, %add : tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32> +} + +// CHECK-LABEL: @test_fold_views +tt.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32>) { + %a = arith.constant dense<1.0> : tensor<1x128xf32> + + // CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x8xf32> + %b = tt.reshape %a allow_reorder : tensor<1x128xf32> -> tensor<16x8xf32> + + // CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x128xf32> + %c = tt.broadcast %a : tensor<1x128xf32> -> tensor<16x128xf32> + + // CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<1x1x128xf32> + %d = tt.expand_dims %a {axis = 0: i32} : tensor<1x128xf32> -> tensor<1x1x128xf32> + + tt.return %b, %c, %d : tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32> +} + +// CHECK-LABEL: @test_nop_transpose +tt.func @test_nop_transpose(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>) { + %a = tt.trans %arg0 {order = array} : tensor<2x4xf32> -> tensor<2x4xf32> + // CHECK: tt.return %arg0 + tt.return %a : tensor<2x4xf32> +} + +// CHECK-LABEL: @test_nested_transpose +tt.func @test_nested_transpose(%arg0: tensor<2x4x8xf32>) -> (tensor<8x2x4xf32>) { + %a = tt.trans %arg0 {order = array} : tensor<2x4x8xf32> -> tensor<4x2x8xf32> + %b = tt.trans %a {order = array} : tensor<4x2x8xf32> -> tensor<8x2x4xf32> + // CHECK: %[[res:.*]] = tt.trans %arg0 {order = array} + // CHECK: tt.return %[[res]] + tt.return %b : tensor<8x2x4xf32> +} + +// CHECK-LABEL: test_reshape_reduce +tt.func @test_reshape_reduce(%0: tensor<32x4x2xi32>) -> (i32, tensor<16xi32>) { + // CHECK: tt.reshape %{{.+}} allow_reorder : tensor<32x4x2xi32> -> tensor<256xi32> + %1 = tt.reshape %0 : tensor<32x4x2xi32> -> tensor<256xi32> + %2 = "tt.reduce" (%1) ({ + ^bb0(%arg7: i32, %arg8: i32): + %add = arith.addi %arg7, %arg8 : i32 + tt.reduce.return %add : i32 + }) {axis = 0 : i32} : (tensor<256xi32>) -> i32 + %3 = tt.histogram %1 : tensor<256xi32> -> tensor<16xi32> + tt.return %2, %3 : i32, tensor<16xi32> +} + +// CHECK-LABEL: test_rank_reduce_desc_load +tt.func @test_rank_reduce_desc_load(%0: !tt.tensordesc>) -> (tensor<128x64xf16>) { + %c0 = arith.constant 0 : i32 + // CHECK: %[[R:.+]] = tt.experimental_descriptor_load {{.*}} : !tt.tensordesc> -> tensor<128x64xf16> + // CHECK: tt.return %[[R]] + %l = tt.experimental_descriptor_load %0[%c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x128x64xf16> + %r = tt.reshape %l : tensor<1x128x64xf16> -> tensor<128x64xf16> + tt.return %r : tensor<128x64xf16> +} diff --git a/third_party/enflame/include/triton/test/Triton/invalid.mlir b/third_party/enflame/include/triton/test/Triton/invalid.mlir new file mode 100644 index 000000000..1397c980b --- /dev/null +++ b/third_party/enflame/include/triton/test/Triton/invalid.mlir @@ -0,0 +1,531 @@ +// RUN: triton-opt --split-input-file %s --verify-diagnostics + +tt.func @fn(%v: i32) { + %b = tt.splat %v : i32 -> tensor<128xi32> + // expected-error @+1 {{rank of source must be same as rank of result}} + %c = tt.broadcast %b : tensor<128xi32> -> tensor<128x32xi32> + tt.return +} + +// ----- + +tt.func @fn(%v: i32) { + %b = tt.splat %v : i32 -> tensor<2x32xi32> + // expected-error @+1 {{Different dimensions at index 0 between source and result. Broadcast requires the source dimension to be 1.}} + %c = tt.broadcast %b : tensor<2x32xi32> -> tensor<128x32xi32> + tt.return +} + +// ----- + +tt.func public @fn(%arg0: tensor<128xf32>) { + // expected-error @+1 {{packed_element}} + %a = tt.elementwise_inline_asm "" + {constraints = "=r,r", packed_element=3:i32, pure=true} %arg0 : tensor<128xf32> -> tensor<128xf32> + tt.return +} + +// ----- + +tt.func public @fn(%arg0: tensor<128xf32>, %arg1: tensor<64xf32>) { + // expected-error @+1 {{same shape}} + %a = tt.elementwise_inline_asm "" + {constraints = "=r,r,r", packed_element=1:i32, pure=true} + %arg0, %arg1: tensor<128xf32>, tensor<64xf32> -> tensor<128xf32> + tt.return +} +// ----- + +tt.func public @reshape_different_num_elements(%arg0: tensor<32x128xf16>) { + // expected-error @+1 {{number of src and dst elements of reshape must be the same}} + %a = tt.reshape %arg0 : tensor<32x128xf16> -> tensor<64x32xf16> + tt.return +} + +// ----- + +// expected-note @+1 {{prior use}} +tt.func public @fn(%arg0: tensor<32xf32>, %arg1: tensor<33xf32>) { + // expected-error @+1 {{expects different type}} + %a = tt.join %arg0, %arg1 : tensor<32xf32> -> tensor<32x2xf32> + tt.return +} + +// ----- + +// expected-note @+1 {{prior use}} +tt.func public @fn(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf16>) { + // expected-error @+1 {{expects different type}} + %a = tt.join %arg0, %arg1 : tensor<32x32xf32> -> tensor<32x32x2xf32> + tt.return +} + +// ----- + +tt.func public @fn(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>) { + // expected-error @+2 {{op failed to infer returned types}} + // expected-error @+1 {{incompatible with return type}} + %a = tt.join %arg0, %arg1 : tensor<32xf32> -> tensor<64xf32> + tt.return +} + +// ----- + +tt.func public @fn(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) { + // expected-error @+2 {{op failed to infer returned types}} + // expected-error @+1 {{incompatible with return type}} + %a = tt.join %arg0, %arg1 : tensor<32x32xf32> -> tensor<32x64xf32> + tt.return +} + +// ----- + +// This one is OK +tt.func public @fn(%arg0: tensor, %arg1: tensor) { + %a = tt.join %arg0, %arg1 : tensor -> tensor<2xf32> + tt.return +} + +// ----- + +tt.func public @fn(%arg0: f32, %arg1: f32) { + // expected-error @+1 {{kind of type}} + %a = tt.join %arg0, %arg1 : f32 -> tensor<2xf32> + tt.return +} + +// ----- + +tt.func public @fn(%v: tensor<4x128xf64>) { + // expected-error @+1 {{operand types and result types}} + %a = "tt.reduce" (%v) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<4x128xf64>) -> tensor<128xf32> + tt.return +} + +// ----- + +tt.func @reduce_different_input_shapes(%arg0: tensor<32x32x64xf32>, %arg1: tensor<16x32x64xf32>) -> (tensor<32x64xf32>, tensor<16x64xf32>) { + // expected-error @below {{op requires the same shape for all operands}} + %0:2 = "tt.reduce" (%arg0, %arg1) <{axis = 1 : i32}> ({ + ^bb0(%acc0: f32, %acc1: f32, %cur0: f32, %cur1: f32): + %1 = arith.addf %acc0, %cur0 : f32 + %2 = arith.addf %acc1, %cur1 : f32 + tt.reduce.return %1, %2 : f32, f32 + }) : (tensor<32x32x64xf32>, tensor<16x32x64xf32>) -> (tensor<32x64xf32>, tensor<16x64xf32>) + tt.return %0#0, %0#1 : tensor<32x64xf32>, tensor<16x64xf32> +} + +// ----- + +tt.func public @fn(%v: tensor<4x128xf32>) { + // expected-error @+1 {{requires the same shape}} + %a = "tt.scan" (%v) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.scan.return %add : f32 + }) {axis = 0 : i32, reverse = false} : (tensor<4x128xf32>) -> tensor<128xf32> + tt.return +} + +// ----- + +tt.func public @fn(%v1: tensor<4x128xf32>, %v2: tensor<4x128xi64>) { + // expected-error @+1 {{operand types and result types}} + %a, %b = "tt.scan" (%v1, %v2) ({ + ^bb0(%arg0: f32, %arg1: i32, %arg2: f32, %arg3: i32): + %add = arith.addf %arg0, %arg2 : f32 + tt.scan.return %add, %arg1 : f32, i32 + }) {axis = 0 : i32, reverse = false} : (tensor<4x128xf32>, tensor<4x128xi64>) -> (tensor<4x128xi64>, tensor<4x128xf32>) + tt.return +} + +// ----- + +tt.func public @fn(%v1: tensor<4x128xf32>, %v2: tensor<4x128xi64>) { + // expected-error @+1 {{operand types and result types}} + %a, %b = "tt.reduce" (%v1, %v2) ({ + ^bb0(%arg0: f32, %arg1: i32, %arg2: f32, %arg3: i32): + %add = arith.addf %arg0, %arg2 : f32 + tt.reduce.return %add, %arg1 : f32, i32 + }) {axis = 0 : i32} : (tensor<4x128xf32>, tensor<4x128xi64>) -> (tensor<128xi64>, tensor<128xf32>) + tt.return +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<32xf32, #blocked>) { + // expected-error @+2 {{op failed to infer returned types}} + // expected-error @+1 {{incompatible with return type}} + %a = tt.join %arg0, %arg0 : tensor<32xf32, #blocked> -> tensor<32x2xf32> + tt.return +} +} // end module + +// ----- + +// Bad order; should be [1,0] +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [0,1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<32xf32, #blocked>) { + // expected-error @+2 {{incompatible with return type(s) of operation}} + // expected-error @+1 {{op failed to infer returned types}} + %a = tt.join %arg0, %arg0 : tensor<32xf32, #blocked> -> tensor<32x2xf32, #blocked1> + tt.return +} +} // end module + +// ----- + +tt.func public @fn(%arg0: tensor<32xf32>) { + // expected-error @+2 {{last dimension}} + // expected-error @+1 {{op failed to infer returned types}} + %a, %b = tt.split %arg0 : tensor<32xf32> -> tensor<16xf32> + tt.return +} + +// ----- + +tt.func public @fn(%arg0: tensor<32x2xf32>) { + // expected-error @+2 {{op inferred type}} + // expected-error @+1 {{op failed to infer returned types}} + %a, %b = tt.split %arg0 : tensor<32x2xf32> -> tensor<32xf16> + tt.return +} + +// ----- + +tt.func public @fn(%arg0: f32) { + // expected-error @+1 {{invalid kind of type}} + %a, %b = tt.split %arg0 : f32 -> f16 + tt.return +} +// ----- + +tt.func public @fn(%arg0: tensor<2xf32>) { + %a, %b = tt.split %arg0 : tensor<2xf32> -> tensor // OK + tt.return +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1,2,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}> +// Bad order, should be [1,0]. +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [1,0]}> + +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) { + // expected-error @+2 {{op inferred type}} + // expected-error @+1 {{op failed to infer returned types}} + %a, %b = tt.split %arg0 : tensor<2x2x2xf32, #blocked> -> tensor<2x2xf32, #blocked1> + tt.return +} +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}> +// bad sizePerThread; should be [1,1]. +#blocked1 = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [0,1]}> + +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) { + // expected-error @+2 {{op inferred type}} + // expected-error @+1 {{op failed to infer returned types}} + %a, %b = tt.split %arg0 : tensor<2x2x2xf32, #blocked> -> tensor<2x2xf32, #blocked1> + tt.return +} +} // end module + +// ----- + +// Valid ops. +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<16x32x64xf32>) { + %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32> -> tensor<16x32x64xf32> + %b = tt.trans %arg0 {order = array} : tensor<16x32x64xf32> -> tensor<32x16x64xf32> + tt.return +} +} // end module + +// ----- + +// Valid op with blocked encoding. +#blocked = #ttg.blocked<{sizePerThread = [1,2,3,4], threadsPerWarp = [2,4,2,2], warpsPerCTA = [4,2,4,2], order = [3,2,1,0], CTAsPerCGA = [1,2,2,2], CTASplitNum = [1,2,4,8], CTAOrder = [3,2,1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2,4,3,1], threadsPerWarp = [4,2,2,2], warpsPerCTA = [2,2,4,4], order = [1,2,0,3], CTAsPerCGA = [2,2,2,1], CTASplitNum = [2,8,4,1], CTAOrder = [1,2,0,3]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}> +#blocked3 = #ttg.blocked<{sizePerThread = [2,1,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<2x4x8x16xf32, #blocked>, %arg1: tensor<16x32x64xf32, #blocked2>) { + %a = tt.trans %arg0 {order = array} : tensor<2x4x8x16xf32, #blocked> -> tensor<4x16x8x2xf32, #blocked1> + %b = tt.trans %arg1 {order = array} : tensor<16x32x64xf32, #blocked2> -> tensor<32x16x64xf32, #blocked3> + tt.return +} +} // end module + +// ----- + +// Valid op with shared encoding. +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [3, 2, 1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 2, 0, 3]}> +#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, CTAsPerCGA = [1, 2], CTASplitNum = [2, 4], CTAOrder = [0, 1]}> +#shared3 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32, CTAsPerCGA = [2, 1], CTASplitNum = [4, 2], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: !ttg.memdesc<2x4x8x16xf32, #shared, #smem>, %arg1: !ttg.memdesc<16x32xf32, #shared2, #smem>) { + %a = ttg.memdesc_trans %arg0 {order = array} : !ttg.memdesc<2x4x8x16xf32, #shared, #smem> -> !ttg.memdesc<4x16x8x2xf32, #shared1, #smem> + %b = ttg.memdesc_trans %arg1 {order = array} : !ttg.memdesc<16x32xf32, #shared2, #smem> -> !ttg.memdesc<32x16xf32, #shared3, #smem> + tt.return +} +} // end module + +// ----- + +// Invalid blocked encoding. +#blocked = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<16x32x64xf32, #blocked>) { + // expected-error @+1 {{type}} + %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32, #blocked> -> tensor<32x16x64xf32, #blocked1> + tt.return +} +} // end module + +// ----- + +// Invalid shared encoding. +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1, 2]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<16x32x64xf32, #shared>) { + // expected-error @+1 {{type}} + %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32, #shared> -> tensor<32x16x64xf32, #shared1> + tt.return +} +} // end module + +// ----- + +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<16x32xf32>) { + // expected-error @+1 {{order}} + %a = tt.trans %arg0 {order = array} : tensor<16x32xf32> -> tensor<32x16xf32> + tt.return +} +} // end module + +// ----- + +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<16x32xf32>) { + // expected-error @+1 {{order}} + %a = tt.trans %arg0 {order = array} : tensor<16x32xf32> -> tensor<32x16xf32> + tt.return +} +} // end module + +// ----- + +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<16x32xf32>) { + // expected-error @+1 {{order must be a permutation}} + %a = tt.trans %arg0 {order = array} : tensor<16x32xf32> -> tensor<32x16xf32> + tt.return +} +} // end module + +// ----- + +// Invalid tensor with shared encoding. +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1, 2]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func public @fn(%arg0: tensor<16x32x64xf32, #shared>) { + // expected-error @+1 {{has an invalid layout: Shared layout is not allowed on tensor type.}} + %a = tt.trans %arg0 {order = array} : tensor<16x32x64xf32, #shared> -> tensor<32x16x64xf32, #shared1> + tt.return +} +} // end module + +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{indices and output shapes must match}} + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512xf32> + tt.return +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32, #blocked>) { + // expected-error @below {{indices and output encodings must match}} + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x4xi32, #blocked>) -> tensor<512x4xf32, #blocked1> + tt.return +} +} + +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf16>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{input and output element types must match}} + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf16>, tensor<512x4xi32>) -> tensor<512x4xf32> + tt.return +} + +// ----- + +tt.func @gather_op(%arg0: tensor<128xf32>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{input and indices ranks must match}} + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + tt.return +} + +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x32xi32>) { + // expected-error @below {{indices dimension 1 must match the corresponding input dimension}} + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x32xi32>) -> tensor<512x32xf32> + tt.return +} +// ----- + +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) { + // expected-error @below {{gather dimension must be less than the input rank}} + %0 = tt.gather %arg0[%arg1] {axis = 3 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32> + tt.return +} + +// ----- + +tt.func @invalid_desc_load(%arg0: !tt.tensordesc>) { + %c = arith.constant 0 : i32 + // expected-error @below {{ranked reduce load only allowed for unit dimension leading dim}} + tt.experimental_descriptor_load %arg0[%c, %c] : !tt.tensordesc> -> tensor<16xf32> + tt.return +} + +// ----- + +tt.func @invalid_desc_load(%arg0: !tt.tensordesc>) { + %c = arith.constant 0 : i32 + // expected-error @below {{tensor desciptor block and tensor types must match}} + tt.experimental_descriptor_load %arg0[%c, %c] : !tt.tensordesc> -> tensor<16x16xf16> + tt.return +} + +// ----- + +tt.func @invalid_desc_store(%arg0: !tt.tensordesc>, %arg1: tensor<32x16xf32>) { + %c = arith.constant 0 : i32 + // expected-error @below {{tensor desciptor block and tensor types must match}} + tt.experimental_descriptor_store %arg0[%c, %c], %arg1 : !tt.tensordesc>, tensor<32x16xf32> + tt.return +} + +// ----- + +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { + // expected-error @below {{block must be a 2D tensor}} + %0 = tt.experimental_descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32xbf16> + tt.return +} + +// ----- + +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { + // expected-error @below {{block must have exactly 1 row}} + %0 = tt.experimental_descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x128xbf16> + tt.return +} + +// ----- + +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<1x32xi32>, %arg2: i32) { + // expected-error @below {{x offsets must be a 1D tensor}} + %0 = tt.experimental_descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<1x32xi32>, i32) -> tensor<32x128xbf16> + tt.return +} + +// ----- + +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { + // expected-error @below {{result must be a 2D tensor}} + %0 = tt.experimental_descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<128xbf16> + tt.return +} + +// ----- + +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { + // expected-error @below {{result tensor number of columns must match block (128)}} + %0 = tt.experimental_descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x64xbf16> + tt.return +} + +// ----- + +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { + // expected-error @below {{result tensor must have as many rows as indices (32)}} + %0 = tt.experimental_descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<64x128xbf16> + tt.return +} + +// ----- + +tt.func @invalid_tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { + // expected-error @below {{result tensor element type must match block ('bf16')}} + %0 = tt.experimental_descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x128xf32> + tt.return +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @invalid_dot(%arg0: tensor<32x32x!tt.ptr, #blocked>, %arg1: tensor<16x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %10 = tt.load %arg1 : tensor<16x32x!tt.ptr, #blocked> + %11 = ttg.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %12 = ttg.local_alloc %10 : (tensor<16x32xf32, #blocked>) -> !ttg.memdesc<16x32xf32, #shared, #smem> + %13 = ttg.local_load %11 : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %14 = ttg.local_load %12 : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %15 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + + // expected-error @below {{'tt.dot' op expected the last dimension of the first operand to be equal to the second-to-last dimension of the second operand}} + %16 = tt.dot %13, %14, %15 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %17 = ttg.convert_layout %16 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %arg0, %17 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @dot_scaled_fp8( + %a: tensor<128x32xi8, #blocked2>, + %scale: tensor<128x2xi8, #blocked1>, + %b_fp8: tensor<128x128xf8E4M3FN, #blocked> + ) -> tensor<128x128xf32, #blocked> { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + // expected-error @below {{'tt.dot_scaled' op expected the last dimension of the first operand to be equal to the second-to-last dimension of the second operand}} + %result = tt.dot_scaled %a scale %scale, %b_fp8, %cst lhs = e2m1 rhs = e4m3 {fastMath = true} : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<128x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked> + tt.return %result : tensor<128x128xf32, #blocked> + } +} diff --git a/third_party/enflame/include/triton/test/Triton/loop-unroll.mlir b/third_party/enflame/include/triton/test/Triton/loop-unroll.mlir new file mode 100644 index 000000000..531a14fff --- /dev/null +++ b/third_party/enflame/include/triton/test/Triton/loop-unroll.mlir @@ -0,0 +1,46 @@ +// RUN: triton-opt --split-input-file %s -triton-loop-unroll | FileCheck %s + +tt.func @add_kernel_unroll(%arg0: tensor<256x!tt.ptr>, %arg1: i32) { + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %0 = tt.splat %c1_i32 : i32 -> tensor<256xi32> + %1 = tt.splat %cst : f32 -> tensor<256xf32> + // Check the loop is unrolled by factor of 2 and is followed by a reminder loop. + // CHECK-LABEL: add_kernel_unroll + // CHECK: scf.for + // CHECK-COUNT-2: tt.load + // CHECK-NOT: tt.load + // CHECK: scf.for + // CHECK: tt.load + // CHECK-NOT: tt.load + // CHECK: tt.num_stages = 1 : i32 + %2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr>) : i32 { + %3 = tt.load %arg5 : tensor<256x!tt.ptr> + %4 = arith.addf %arg4, %3 : tensor<256xf32> + %5 = tt.addptr %arg5, %0 : tensor<256x!tt.ptr>, tensor<256xi32> + scf.yield %4, %5 : tensor<256xf32>, tensor<256x!tt.ptr> + } {tt.loop_unroll_factor = 2 : i32} + tt.return +} + +// ----- + +tt.func @add_kernel_nounroll(%arg0: tensor<256x!tt.ptr>, %arg1: i32) { + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %0 = tt.splat %c1_i32 : i32 -> tensor<256xi32> + %1 = tt.splat %cst : f32 -> tensor<256xf32> + // Check the loop is not unrolled. + // CHECK-LABEL: add_kernel_nounroll + // CHECK: scf.for + // CHECK-COUNT-1: tt.load + // CHECK-NOT: tt.load + // CHECK-NOT: scf.for + %2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr>) : i32 { + %3 = tt.load %arg5 : tensor<256x!tt.ptr> + %4 = arith.addf %arg4, %3 : tensor<256xf32> + %5 = tt.addptr %arg5, %0 : tensor<256x!tt.ptr>, tensor<256xi32> + scf.yield %4, %5 : tensor<256xf32>, tensor<256x!tt.ptr> + } + tt.return +} diff --git a/third_party/enflame/include/triton/test/Triton/ops.mlir b/third_party/enflame/include/triton/test/Triton/ops.mlir new file mode 100644 index 000000000..535f9d6e8 --- /dev/null +++ b/third_party/enflame/include/triton/test/Triton/ops.mlir @@ -0,0 +1,273 @@ +// RUN: triton-opt %s | FileCheck %s + +// CHECK-LABEL: @cast_ops +tt.func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { + // scalar -> scalar + // CHECK: i64 -> !tt.ptr + %0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr + // CHECK: !tt.ptr -> i64 + %1 = tt.ptr_to_int %scalar_ptr : !tt.ptr -> i64 + // CHECK: f32 to f16 + %2 = arith.truncf %scalar_f32 : f32 to f16 + + // 0D tensor -> 0D tensor + %tensor_ptr_0d = tt.splat %scalar_ptr : !tt.ptr -> tensor> + %tensor_f32_0d = tt.splat %scalar_f32 : f32 -> tensor + %tensor_i64_0d = tt.splat %scalar_i64 : i64 -> tensor + + // CHECK: tensor -> tensor> + %3 = tt.int_to_ptr %tensor_i64_0d : tensor -> tensor> + // CHECK: tensor> -> tensor + %4 = tt.ptr_to_int %tensor_ptr_0d : tensor> -> tensor + // CHECK: tensor to tensor + %5 = arith.truncf %tensor_f32_0d : tensor to tensor + + // 1D tensor -> 1D tensor + %tensor_ptr_1d = tt.splat %scalar_ptr : !tt.ptr -> tensor<16x!tt.ptr> + %tensor_f32_1d = tt.splat %scalar_f32 : f32 -> tensor<16xf32> + %tensor_i64_1d = tt.splat %scalar_i64 : i64 -> tensor<16xi64> + + // CHECK: tensor<16xi64> -> tensor<16x!tt.ptr> + %6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr> + // CHECK: tensor<16x!tt.ptr> -> tensor<16xi64> + %7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr> -> tensor<16xi64> + // CHECK: tensor<16xf32> to tensor<16xf16> + %8 = arith.truncf %tensor_f32_1d : tensor<16xf32> to tensor<16xf16> + tt.return +} + +// CHECK-LABEL: @addptr_ops +tt.func @addptr_ops(%scalar_ptr: !tt.ptr, %scalar_i32: i32) { + // scalar -> scalar + // CHECK: !tt.ptr + %0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr, i32 + + // 0D tensor -> 0D tensor + %tensor_ptr_0d = tt.splat %scalar_ptr : !tt.ptr -> tensor> + %tensor_i32_0d = tt.splat %scalar_i32 : i32 -> tensor + // CHECK: tensor> + %1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor>, tensor + + // 1D tensor -> 1D tensor + %tensor_ptr_1d = tt.splat %scalar_ptr : !tt.ptr -> tensor<16x!tt.ptr> + %tensor_i32_1d = tt.splat %scalar_i32 : i32 -> tensor<16xi32> + // CHECK: tensor<16x!tt.ptr> + %2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr>, tensor<16xi32> + tt.return +} + +// CHECK-LABEL: @load_store_ops_scalar +tt.func @load_store_ops_scalar(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %mask : i1) { + // Test if Load/Store ops can handle scalar values + %other = arith.constant 0.0e+0 : f32 + + // load scalar + // CHECK: %[[L0:.*]] = tt.load %{{.*}} : !tt.ptr + %a = tt.load %ptr : !tt.ptr + // CHECK: %[[L1:.*]] = tt.load %{{.*}}, %{{.*}} : !tt.ptr + %b = tt.load %ptr, %mask : !tt.ptr + // CHECK: %[[L2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} : !tt.ptr + %c = tt.load %ptr, %mask, %other : !tt.ptr + + // store scalar + // CHECK: tt.store %{{.*}}, %[[L0]] : !tt.ptr + tt.store %ptr, %a : !tt.ptr + // CHECK: tt.store %{{.*}}, %[[L1]], %{{.*}} : !tt.ptr + tt.store %ptr, %b, %mask : !tt.ptr + // CHECK: tt.store %{{.*}}, %[[L2]], %{{.*}} : !tt.ptr + tt.store %ptr, %c, %mask : !tt.ptr + tt.return +} + +// CHECK-LABEL: reduce_ops_infer +tt.func @reduce_ops_infer(%ptr: !tt.ptr, %v : tensor<1x2x4xf32>) { + // Test if reduce ops infer types correctly + + // CHECK: tt.reduce + // CHECK-SAME: axis = 0 + // CHECK: tt.reduce.return + // CHECK-NEXT: (tensor<1x2x4xf32>) -> tensor<2x4xf32> + %a = "tt.reduce" (%v) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32> + + // CHECK: tt.reduce + // CHECK-SAME: axis = 1 + // CHECK: tt.reduce.return + // CHECK-NEXT: (tensor<1x2x4xf32>) -> tensor<1x4xf32> + %b = "tt.reduce" (%v) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32> + + // CHECK: tt.reduce + // CHECK-SAME: axis = 2 + // CHECK: tt.reduce.return + // CHECK-NEXT: (tensor<1x2x4xf32>) -> tensor<1x2xf32> + %c = "tt.reduce" (%v) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32> + + // CHECK: tt.reduce + // CHECK-SAME: axis = 1 + // CHECK: tt.reduce.return + // CHECK-NEXT: (tensor<1x4xf32>) -> tensor<1xf32> + %e = "tt.reduce" (%b) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32> + + // CHECK: tt.reduce + // CHECK-SAME: axis = 0 + // CHECK: tt.reduce.return + // CHECK-NEXT: (tensor<2x4xf32>) -> tensor<4xf32> + %f = "tt.reduce" (%a) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32> + + // CHECK: tt.reduce + // CHECK-SAME: axis = 0 + // CHECK: tt.reduce.return + // CHECK-NEXT: (tensor<4xf32>) -> f32 + %g = "tt.reduce" (%f) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<4xf32>) -> f32 + + // Avoid optimizations for c, e, and g + %ptr1x2 = tt.splat %ptr : !tt.ptr -> tensor<1x2x!tt.ptr> + %ptr1 = tt.splat %ptr : !tt.ptr -> tensor<1x!tt.ptr> + tt.store %ptr1x2, %c : tensor<1x2x!tt.ptr> + tt.store %ptr1, %e : tensor<1x!tt.ptr> + tt.store %ptr, %g : !tt.ptr + tt.return +} + +// CHECK-LABEL: @dot_ops_infer +tt.func @dot_ops_infer(%ptr: !tt.ptr, %v : f32) { + // Test if reduce ops infer types correctly + %v128x32 = tt.splat %v : f32 -> tensor<128x32xf32> + %v32x128 = tt.splat %v : f32 -> tensor<32x128xf32> + %v128x1 = tt.splat %v : f32 -> tensor<128x1xf32> + %v1x128 = tt.splat %v : f32 -> tensor<1x128xf32> + + %zero128x128 = arith.constant dense<0.00e+00> : tensor<128x128xf32> + %zero32x32 = arith.constant dense<0.00e+00> : tensor<32x32xf32> + %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32> + + // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> + %r1 = tt.dot %v128x32, %v32x128, %zero128x128 : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> + // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32> + %r2 = tt.dot %v32x128, %v128x32, %zero32x32 : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32> + // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> + %r3 = tt.dot %v128x1, %v1x128, %zero128x128 : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32> + // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32> + %r4 = tt.dot %v1x128, %v128x1, %zero1x1 : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32> + + %ptr128x128 = tt.splat %ptr : !tt.ptr -> tensor<128x128x!tt.ptr> + %ptr32x32 = tt.splat %ptr : !tt.ptr -> tensor<32x32x!tt.ptr> + %ptr1x1 = tt.splat %ptr : !tt.ptr -> tensor<1x1x!tt.ptr> + tt.store %ptr128x128, %r1 : tensor<128x128x!tt.ptr> + tt.store %ptr32x32, %r2 : tensor<32x32x!tt.ptr> + tt.store %ptr128x128, %r3 : tensor<128x128x!tt.ptr> + tt.store %ptr1x1, %r4 : tensor<1x1x!tt.ptr> + tt.return +} + +// CHECK-LABEL: @print_no_arg +tt.func @print_no_arg(%arg0: !tt.ptr) { +// CHECK: tt.print "test" + tt.print "test" { hex = false, isSigned = array} + %0 = tt.load %arg0 : !tt.ptr + tt.store %arg0, %0 : !tt.ptr + tt.return +} + +// CHECK-LABEL: scan_op +tt.func @scan_op(%ptr: tensor<1x2x4x!tt.ptr>, %v : tensor<1x2x4xf32>) { + // CHECK: tt.scan + // CHECK-SAME: axis = 1 + // CHECK: tt.scan.return + // CHECK-NEXT: (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + %a = "tt.scan"(%v) <{axis = 1 : i32, reverse = false}>({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.scan.return %add : f32 + }) : (tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + tt.store %ptr, %a : tensor<1x2x4x!tt.ptr> + tt.return +} + +// CHECK-LABEL: inline_asm +// CHECK: tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" +tt.func @inline_asm(%0: tensor<512xi8>) { + %1 = tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" + {constraints = "=r,r", packed_element = 4 : i32, pure = true} %0 : tensor<512xi8> -> tensor<512xi8> + tt.return +} + +// CHECK-LABEL: inline_asm_scalar +// CHECK: tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" {{.*}} : i32 -> i32 +tt.func @inline_asm_scalar(%0: i32) { + %1 = tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" + {constraints = "=r,r", packed_element = 1 : i32, pure = true} %0 : i32 -> i32 + tt.return +} + +// CHECK-LABEL: reshape +tt.func @reshape(%0: tensor<512xi32>) { + // CHECK: tt.reshape %{{.+}} : tensor<512xi32> -> tensor<16x32xi32> + %1 = tt.reshape %0 : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} allow_reorder : tensor<512xi32> -> tensor<16x32xi32> + %2 = tt.reshape %0 allow_reorder : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + %3 = tt.reshape %0 allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + %4 = tt.reshape %0 efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + tt.return +} + +// CHECK-LABEL: histogram +tt.func @histogram(%0: tensor<512xi32>) { + // CHECK: tt.histogram %{{.+}} : tensor<512xi32> -> tensor<16xi32> + %1 = tt.histogram %0 : tensor<512xi32> -> tensor<16xi32> + tt.return +} + +// CHECK-LABEL: experimental_descriptor_load +tt.func @experimental_descriptor_load(%0: !tt.tensordesc>) { + // CHECK: tt.experimental_descriptor_load %{{.+}}[%{{.+}}] : !tt.tensordesc> -> tensor<128xf32> + %c0_i32 = arith.constant 0 : i32 + %1 = tt.experimental_descriptor_load %0[%c0_i32] : !tt.tensordesc> -> tensor<128xf32> + tt.return +} + +// CHECK-LABEL: @gather_op +tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x16xi32>) -> tensor<512x16xf32> { + // CHECK-NEXT: %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x16xi32>) -> tensor<512x16xf32> + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<128x16xf32>, tensor<512x16xi32>) -> tensor<512x16xf32> + tt.return %0 : tensor<512x16xf32> +} + +// CHECK-LABEL: @tma_gather +tt.func @tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32) { + // CHECK-NEXT: %0 = tt.experimental_descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x128xbf16> + %0 = tt.experimental_descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32>, i32) -> tensor<32x128xbf16> + tt.return +} + +// CHECK-LABEL: @tma_scatter +tt.func @tma_scatter(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32>, %arg2: i32, %arg3: tensor<32x128xbf16>) { + // CHECK-NEXT: tt.experimental_descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc>, tensor<32xi32>, i32, tensor<32x128xbf16> + tt.experimental_descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc>, tensor<32xi32>, i32, tensor<32x128xbf16> + tt.return +} diff --git a/third_party/enflame/include/triton/test/Triton/reorder-broadcast.mlir b/third_party/enflame/include/triton/test/Triton/reorder-broadcast.mlir new file mode 100644 index 000000000..7d901feb5 --- /dev/null +++ b/third_party/enflame/include/triton/test/Triton/reorder-broadcast.mlir @@ -0,0 +1,67 @@ +// RUN: triton-opt %s -split-input-file -triton-reorder-broadcast | FileCheck %s + +// CHECK-LABEL: @test_splat_elementwise_pattern +tt.func @test_splat_elementwise_pattern(%arg0: f32) -> (tensor<128x128xf32>, tensor<128x128x!tt.ptr>) { + // CHECK-DAG: %[[a:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : i64 + %c1 = arith.constant 1 : i64 + %a = arith.constant dense<1.0> : tensor<128x128xf32> + + // CHECK-DAG: %[[add:.*]] = arith.addf %arg0, %[[a]] : f32 + // CHECK-NEXT: %[[splat:.*]] = tt.splat %[[add]] : f32 -> tensor<128x128xf32> + %b = tt.splat %arg0 : f32 -> tensor<128x128xf32> + %add = arith.addf %a, %b : tensor<128x128xf32> + + + // CHECK-NEXT: %[[ptr:.*]] = tt.int_to_ptr %[[c1]] : i64 -> !tt.ptr + // CHECK-NEXT: %{{.*}} = tt.splat %[[ptr]] : !tt.ptr -> tensor<128x128x!tt.ptr> + %c1_t = tt.splat %c1 : i64 -> tensor<128x128xi64> + %ptr = tt.int_to_ptr %c1_t : tensor<128x128xi64> -> tensor<128x128x!tt.ptr> + + tt.return %add, %ptr : tensor<128x128xf32>, tensor<128x128x!tt.ptr> +} + +// CHECK-LABEL: @test_broadcast_elementwise_pattern +tt.func @test_broadcast_elementwise_pattern(%arg0: tensor<128x1xf32>) -> (tensor<128x128xf32>, tensor<128x32xf32>) { + // CHECK: %[[one:.*]] = arith.constant dense<1.000000e+00> : tensor<128x1xf32> + + // CHECK-NEXT: %[[abs:.*]] = math.absf %arg0 : tensor<128x1xf32> + // CHECK-NEXT: %{{.*}} = tt.broadcast %[[abs]] : tensor<128x1xf32> -> tensor<128x128xf32> + %broadcast = tt.broadcast %arg0 : tensor<128x1xf32> -> tensor<128x128xf32> + %abs = math.absf %broadcast : tensor<128x128xf32> + + // CHECK-NEXT: %[[add:.*]] = arith.addf %arg0, %[[one]] : tensor<128x1xf32> + // CHECK-NEXT: %{{.*}} = tt.broadcast %[[add]] : tensor<128x1xf32> -> tensor<128x32xf32> + %broadcast2 = tt.broadcast %arg0 : tensor<128x1xf32> -> tensor<128x32xf32> + %one = arith.constant dense<1.0> : tensor<128x32xf32> + %add = arith.addf %one, %broadcast2 : tensor<128x32xf32> + + tt.return %abs, %add : tensor<128x128xf32>, tensor<128x32xf32> +} + +// CHECK-LABEL: @test_broadcast_binary_op_pattern +tt.func @test_broadcast_binary_op_pattern(%arg0: tensor<128x1xf32>, %arg1: tensor<128x1xf32>, %arg2: tensor<1x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) { + // CHECK: %[[mul:.*]] = arith.mulf %{{.*}}, %{{.*}} : tensor<128x1xf32> + // CHECK-NEXT: %{{.*}} = tt.broadcast %[[mul]] : tensor<128x1xf32> -> tensor<128x128xf32> + %broadcast0 = tt.broadcast %arg0 : tensor<128x1xf32> -> tensor<128x128xf32> + %broadcast1 = tt.broadcast %arg1 : tensor<128x1xf32> -> tensor<128x128xf32> + %mul = arith.mulf %broadcast0, %broadcast1 : tensor<128x128xf32> + + // CHECK: %[[mul:.*]] = arith.mulf %{{.*}}, %{{.*}} : tensor<128x128xf32> + %broadcast2 = tt.broadcast %arg2 : tensor<1x128xf32> -> tensor<128x128xf32> + %mul1 = arith.mulf %broadcast0, %broadcast2 : tensor<128x128xf32> + + tt.return %mul, %mul1 : tensor<128x128xf32>, tensor<128x128xf32> +} + +// CHECK-LABEL: @test_broadcast_mix_type_op_pattern +tt.func @test_broadcast_mix_type_op_pattern(%arg0: tensor<128x1xf32>, %arg1: f32, %arg2: tensor<1x128xf32>, %arg3: tensor<128x1xi1>) -> (tensor<128x128xf32>) { + // CHECK: %[[sel:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<128x1xi1>, tensor<128x1xf32> + // CHECK-NEXT: %{{.*}} = tt.broadcast %[[sel]] : tensor<128x1xf32> -> tensor<128x128xf32> + %broadcast0 = tt.broadcast %arg0 : tensor<128x1xf32> -> tensor<128x128xf32> + %broadcast1 = tt.splat %arg1 : f32 -> tensor<128x128xf32> + %cond = tt.broadcast %arg3 : tensor<128x1xi1> -> tensor<128x128xi1> + %sel = arith.select %cond, %broadcast0, %broadcast1 : tensor<128x128xi1>, tensor<128x128xf32> + + tt.return %sel : tensor<128x128xf32> +} diff --git a/third_party/enflame/include/triton/test/Triton/reproducer.mlir b/third_party/enflame/include/triton/test/Triton/reproducer.mlir new file mode 100644 index 000000000..5a6747d21 --- /dev/null +++ b/third_party/enflame/include/triton/test/Triton/reproducer.mlir @@ -0,0 +1,20 @@ +// RUN: triton-opt --verify-diagnostics --dump-pass-pipeline --run-reproducer %s 2>&1 | FileCheck %s + +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton__() attributes {noinline = false} { + tt.return + } +} + +{-# + external_resources: { + mlir_reproducer: { + pipeline: "builtin.module(any(convert-scf-to-cf,convert-index-to-llvm{index-bitwidth=0},convert-triton-gpu-to-llvm{compute-capability=90},convert-nv-gpu-to-llvm,convert-arith-to-llvm{index-bitwidth=0},canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse,symbol-dce,enable-line-info))", + disable_threading: false, + verify_each: false + } + } +#-} + +// CHECK: Pass Manager with +// CHECK-NEXT: convert-triton-gpu-to-llvm diff --git a/third_party/enflame/include/triton/test/Triton/rewrite-tensor-pointer.mlir b/third_party/enflame/include/triton/test/Triton/rewrite-tensor-pointer.mlir new file mode 100644 index 000000000..eb39dcac0 --- /dev/null +++ b/third_party/enflame/include/triton/test/Triton/rewrite-tensor-pointer.mlir @@ -0,0 +1,218 @@ +// RUN: triton-opt %s -triton-rewrite-tensor-pointer -split-input-file | FileCheck %s + +tt.func public @rewrite_load(%arg0: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16> + %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> + %load = tt.load %0 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @rewrite_load( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[C32_I64:.*]] = arith.constant 32 : i64 +// CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64 +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<128x32xf16> +// CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[SPLAT0:.*]] = tt.splat %[[ARG0]] : !tt.ptr -> tensor<128x32x!tt.ptr> +// CHECK: %[[SPLAT1:.*]] = tt.splat %[[EXTSI0]] : i64 -> tensor<128xi64> +// CHECK: %[[MAKE_RANGE0:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +// CHECK: %[[EXTSI2:.*]] = arith.extsi %[[MAKE_RANGE0]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[ADDI0:.*]] = arith.addi %[[SPLAT1]], %[[EXTSI2]] : tensor<128xi64> +// CHECK: %[[EXPAND_DIMS0:.*]] = tt.expand_dims %[[ADDI0]] {axis = 1 : i32} : tensor<128xi64> -> tensor<128x1xi64> +// CHECK: %[[SPLAT2:.*]] = tt.splat %[[C1_I64]] : i64 -> tensor<128x1xi64> +// CHECK: %[[MULI0:.*]] = arith.muli %[[EXPAND_DIMS0]], %[[SPLAT2]] : tensor<128x1xi64> +// CHECK: %[[BROADCAST0:.*]] = tt.broadcast %[[MULI0]] : tensor<128x1xi64> -> tensor<128x32xi64> +// CHECK: %[[ADDPTR0:.*]] = tt.addptr %[[SPLAT0]], %[[BROADCAST0]] : tensor<128x32x!tt.ptr>, tensor<128x32xi64> +// CHECK: %[[SPLAT3:.*]] = tt.splat %[[EXTSI1]] : i64 -> tensor<32xi64> +// CHECK: %[[MAKE_RANGE1:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> +// CHECK: %[[EXTSI3:.*]] = arith.extsi %[[MAKE_RANGE1]] : tensor<32xi32> to tensor<32xi64> +// CHECK: %[[ADDI1:.*]] = arith.addi %[[SPLAT3]], %[[EXTSI3]] : tensor<32xi64> +// CHECK: %[[EXPAND_DIMS1:.*]] = tt.expand_dims %[[ADDI1]] {axis = 0 : i32} : tensor<32xi64> -> tensor<1x32xi64> +// CHECK: %[[SPLAT4:.*]] = tt.splat %[[C1_I64]] : i64 -> tensor<1x32xi64> +// CHECK: %[[MULI1:.*]] = arith.muli %[[EXPAND_DIMS1]], %[[SPLAT4]] : tensor<1x32xi64> +// CHECK: %[[BROADCAST1:.*]] = tt.broadcast %[[MULI1]] : tensor<1x32xi64> -> tensor<128x32xi64> +// CHECK: %[[ADDPTR1:.*]] = tt.addptr %[[ADDPTR0]], %[[BROADCAST1]] : tensor<128x32x!tt.ptr>, tensor<128x32xi64> +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[SPLAT5:.*]] = tt.splat %[[C0_I64]] : i64 -> tensor<1x32xi64> +// CHECK: %[[CMP0:.*]] = arith.cmpi sge, %[[EXPAND_DIMS1]], %[[SPLAT5]] : tensor<1x32xi64> +// CHECK: %[[SPLAT6:.*]] = tt.splat %[[C32_I64]] : i64 -> tensor<1x32xi64> +// CHECK: %[[CMPI:.*]] = arith.cmpi slt, %[[EXPAND_DIMS1]], %[[SPLAT6]] : tensor<1x32xi64> +// CHECK: %[[ANDI:.*]] = arith.andi %[[CMP0]], %[[CMPI]] : tensor<1x32xi1> +// CHECK: %[[BROADCAST2:.*]] = tt.broadcast %[[ANDI]] : tensor<1x32xi1> -> tensor<128x32xi1> +// CHECK: %[[OTHER:.*]] = arith.constant 0x7E00 : f16 +// CHECK: %[[SPLAT7:.*]] = tt.splat %[[OTHER]] : f16 -> tensor<128x32xf16> +// CHECK: %[[LOAD:.*]] = tt.load %[[ADDPTR1]], %[[BROADCAST2]], %[[SPLAT7]] : tensor<128x32x!tt.ptr> +// CHECK: tt.return + +// ----- +tt.func public @rewrite_store(%arg0: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16> + %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> + tt.store %0, %cst: !tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @rewrite_store( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[C32_I64:.*]] = arith.constant 32 : i64 +// CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64 +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<128x32xf16> +// CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[SPLAT0:.*]] = tt.splat %[[ARG0]] : !tt.ptr -> tensor<128x32x!tt.ptr> +// CHECK: %[[SPLAT1:.*]] = tt.splat %[[EXTSI0]] : i64 -> tensor<128xi64> +// CHECK: %[[MAKE_RANGE0:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +// CHECK: %[[EXTSI2:.*]] = arith.extsi %[[MAKE_RANGE0]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[ADDI0:.*]] = arith.addi %[[SPLAT1]], %[[EXTSI2]] : tensor<128xi64> +// CHECK: %[[EXPAND_DIMS0:.*]] = tt.expand_dims %[[ADDI0]] {axis = 1 : i32} : tensor<128xi64> -> tensor<128x1xi64> +// CHECK: %[[SPLAT2:.*]] = tt.splat %[[C1_I64]] : i64 -> tensor<128x1xi64> +// CHECK: %[[MULI0:.*]] = arith.muli %[[EXPAND_DIMS0]], %[[SPLAT2]] : tensor<128x1xi64> +// CHECK: %[[BROADCAST0:.*]] = tt.broadcast %[[MULI0]] : tensor<128x1xi64> -> tensor<128x32xi64> +// CHECK: %[[ADDPTR0:.*]] = tt.addptr %[[SPLAT0]], %[[BROADCAST0]] : tensor<128x32x!tt.ptr>, tensor<128x32xi64> +// CHECK: %[[SPLAT3:.*]] = tt.splat %[[EXTSI1]] : i64 -> tensor<32xi64> +// CHECK: %[[MAKE_RANGE1:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> +// CHECK: %[[EXTSI3:.*]] = arith.extsi %[[MAKE_RANGE1]] : tensor<32xi32> to tensor<32xi64> +// CHECK: %[[ADDI1:.*]] = arith.addi %[[SPLAT3]], %[[EXTSI3]] : tensor<32xi64> +// CHECK: %[[EXPAND_DIMS1:.*]] = tt.expand_dims %[[ADDI1]] {axis = 0 : i32} : tensor<32xi64> -> tensor<1x32xi64> +// CHECK: %[[SPLAT4:.*]] = tt.splat %[[C1_I64]] : i64 -> tensor<1x32xi64> +// CHECK: %[[MULI1:.*]] = arith.muli %[[EXPAND_DIMS1]], %[[SPLAT4]] : tensor<1x32xi64> +// CHECK: %[[BROADCAST1:.*]] = tt.broadcast %[[MULI1]] : tensor<1x32xi64> -> tensor<128x32xi64> +// CHECK: %[[ADDPTR1:.*]] = tt.addptr %[[ADDPTR0]], %[[BROADCAST1]] : tensor<128x32x!tt.ptr>, tensor<128x32xi64> +// CHECK: tt.store %[[ADDPTR1]], %[[CST]] : tensor<128x32x!tt.ptr> +// CHECK: tt.return + +// ----- +tt.func public @rewrite_for(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16> + %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> + %1:2 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %cst, %arg4 = %0) -> (tensor<128x32xf16>, !tt.ptr>) { + %3 = tt.load %arg4 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> + %4 = arith.addf %arg3, %3 : tensor<128x32xf16> + %5 = tt.advance %arg4, [%c32_i32, %c0_i32] : !tt.ptr> + scf.yield %4, %5 : tensor<128x32xf16>, !tt.ptr> + } {tt.num_stages = 3 : i32} + %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x32x!tt.ptr> + tt.store %2, %1#0 : tensor<128x32x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @rewrite_for( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index +// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C32_I32:.*]] = arith.constant 32 : i32 +// CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[C32_I64:.*]] = arith.constant 32 : i64 +// CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64 +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<128x32xf16> +// CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[FOR:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C32]] step %[[C1]] +// CHECK-SAME: iter_args(%[[ARG3:.*]] = %[[CST]], %[[ARG4:.*]] = %[[EXTSI0]], %[[ARG5:.*]] = %[[EXTSI1]]) -> (tensor<128x32xf16>, i64, i64) +// CHECK: %[[EXTSI2:.*]] = arith.extsi %[[C32_I32]] : i32 to i64 +// CHECK: %[[ADDI0:.*]] = arith.addi %[[ARG4]], %[[EXTSI2]] : i64 +// CHECK: %[[EXTSI3:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[ADDI1:.*]] = arith.addi %[[ARG5]], %[[EXTSI3]] : i64 +// CHECK: scf.yield %{{.*}}, %[[ADDI0]], %[[ADDI1]] : tensor<128x32xf16>, i64, i64 +// CHECK: tt.num_stages = 3 + +// ----- +tt.func public @rewrite_if(%arg0: !tt.ptr, %arg1: i1, %arg2: tensor<128x32xf32>) -> tensor<128x32xf16> { + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> + %1:2 = scf.if %arg1 -> (tensor<128x32xf16>, !tt.ptr>) { + %2 = tt.advance %0, [%c32_i32, %c0_i32] : !tt.ptr> + %3 = arith.truncf %arg2 : tensor<128x32xf32> to tensor<128x32xf16> + scf.yield %3, %2 : tensor<128x32xf16>, !tt.ptr> + } else { + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16> + scf.yield %cst, %0 : tensor<128x32xf16>, !tt.ptr> + } + %4 = tt.load %1#1 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> + %5 = arith.addf %1#0, %4 : tensor<128x32xf16> + tt.return %5 : tensor<128x32xf16> +} + +// CHECK-LABEL: tt.func public @rewrite_if( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i1 +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<128x32xf32> +// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C32_I32:.*]] = arith.constant 32 : i32 +// CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[C32_I64:.*]] = arith.constant 32 : i64 +// CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64 +// CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[IF:.*]]:3 = scf.if %[[ARG1]] -> (tensor<128x32xf16>, i64, i64) { +// CHECK: %[[EXTSI2:.*]] = arith.extsi %[[C32_I32]] : i32 to i64 +// CHECK: %[[ADDI0:.*]] = arith.addi %[[EXTSI0]], %[[EXTSI2]] : i64 +// CHECK: %[[EXTSI3:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[ADDI1:.*]] = arith.addi %[[EXTSI1]], %[[EXTSI3]] : i64 +// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[ARG2]] : tensor<128x32xf32> to tensor<128x32xf16> +// CHECK: scf.yield %[[TRUNCF]], %[[ADDI0]], %[[ADDI1]] : tensor<128x32xf16>, i64, i64 +// CHECK: } else { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<128x32xf16> +// CHECK: scf.yield %[[CST]], %[[EXTSI0]], %[[EXTSI1]] : tensor<128x32xf16>, i64, i64 +// CHECK: } +// CHECK: %{{.*}} = tt.splat %[[IF]]#1 : i64 -> tensor<128xi64> +// CHECK: %{{.*}} = tt.splat %[[IF]]#2 : i64 -> tensor<32xi64> +// CHECK: %{{.*}} = arith.addf %[[IF]]#0, %{{.*}} : tensor<128x32xf16> + + +// ----- +tt.func public @asm_in_loop(%arg0: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i64 = arith.constant 0 : i64 + %c128_i64 = arith.constant 128 : i64 + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %1 = tt.make_tensor_ptr %arg0, [%c128_i64, %c128_i64], [%c128_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> + %2:1 = scf.for %arg1 = %c0_i32 to %c1_i32 step %c1_i32 iter_args(%arg2 = %1) -> (!tt.ptr>) : i32 { + %3:2 = tt.elementwise_inline_asm "asm_multiple_results" {constraints = "=r,=r,r", packed_element = 1 : i32, pure = true} %0 : tensor<16xi32> -> tensor<16xi16>, tensor<16xi16> + %4 = tt.advance %arg2, [%c0_i32, %c0_i32] : !tt.ptr> + scf.yield %4 : !tt.ptr> + } + tt.return +} + +// CHECK-LABEL: tt.func public @asm_in_loop( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C1_I32:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64 +// CHECK: %[[RANGE:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> +// CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[FOR:.*]]:2 = scf.for %[[ARG1:.*]] = %[[C0_I32]] to %[[C1_I32]] step %[[C1_I32]] +// CHECK-SAME: iter_args(%[[ARG2:.*]] = %[[EXTSI0]], %[[ARG3:.*]] = %[[EXTSI1]]) -> (i64, i64) +// CHECK: %[[ASM:.*]]:2 = tt.elementwise_inline_asm "asm_multiple_results" {{.*}} %[[RANGE]] : tensor<16xi32> -> tensor<16xi16>, tensor<16xi16> diff --git a/third_party/enflame/include/triton/test/Triton/vecadd.mlir b/third_party/enflame/include/triton/test/Triton/vecadd.mlir new file mode 100644 index 000000000..40c4210d2 --- /dev/null +++ b/third_party/enflame/include/triton/test/Triton/vecadd.mlir @@ -0,0 +1,127 @@ +// RUN: triton-opt %s -verify-diagnostics + +module { + tt.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { + %0 = tt.get_program_id x : i32 + %c256_i32 = arith.constant 256 : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %3 = tt.splat %1 : i32 -> tensor<256xi32> + %4 = arith.addi %3, %2 : tensor<256xi32> + %5 = tt.splat %arg3 : i32 -> tensor<256xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<256xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr> + %10 = tt.addptr %9, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + %cst = arith.constant 0.000000e+00 : f32 + %11 = tt.splat %cst : f32 -> tensor<256xf32> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %15:3 = scf.for %arg6 = %c0_i32 to %arg4 step %c32_i32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr>, tensor<256x!tt.ptr>) : i32 { + %cst_0 = arith.constant 0.000000e+00 : f32 + %18 = tt.splat %cst_0 : f32 -> tensor<256xf32> + %19 = tt.load %arg8, %6, %18 : tensor<256x!tt.ptr> + %cst_1 = arith.constant 0.000000e+00 : f32 + %20 = tt.splat %cst_1 : f32 -> tensor<256xf32> + %21 = tt.load %arg9, %6, %20 : tensor<256x!tt.ptr> + %22 = arith.addf %19, %21 : tensor<256xf32> + %23 = arith.addf %arg7, %22 : tensor<256xf32> + %24 = tt.splat %arg5 : i32 -> tensor<256xi32> + %25 = tt.addptr %arg8, %24 : tensor<256x!tt.ptr>, tensor<256xi32> + %26 = tt.splat %arg5 : i32 -> tensor<256xi32> + %27 = tt.addptr %arg9, %26 : tensor<256x!tt.ptr>, tensor<256xi32> + scf.yield %23, %25, %27 : tensor<256xf32>, tensor<256x!tt.ptr>, tensor<256x!tt.ptr> + } + %16 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr> + %17 = tt.addptr %16, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %17, %15#0, %6 : tensor<256x!tt.ptr> + tt.return + } +} +// module { +// tt.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { +// %c64 = arith.constant 64 : index +// %c32 = arith.constant 32 : index +// %c0 = arith.constant 0 : index +// %cst = arith.constant 0.000000e+00 : f32 +// %c256_i32 = arith.constant 256 : i32 +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c256_i32 : i32 +// %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg<"coalesced encoding">> +// %3 = tt.broadcast %1 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %4 = arith.addi %3, %2 : tensor<256xi32, #ttg<"coalesced encoding">> +// %5 = tt.broadcast %arg3 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %6 = arith.cmpi "slt", %4, %5 : (tensor<256xi32, #ttg<"coalesced encoding">>, tensor<256xi32, #ttg<"coalesced encoding">>) -> tensor<256xi1, #ttg<"coalesced encoding">> +// %7 = tt.broadcast %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #ttg<"coalesced encoding">> +// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %9 = tt.broadcast %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #ttg<"coalesced encoding">> +// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %11 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %12 = arith.index_cast %arg4 : i32 to index +// %13 = arith.cmpi slt, %c0, %12 : index +// %14 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %15 = tt.broadcast %13 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %16 = arith.andi %6, %15 : tensor<256xi1, #ttg<"coalesced encoding">> +// %17 = ttg.copy_async %8, %16, %14 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %18 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %19 = tt.broadcast %13 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %20 = arith.andi %6, %19 : tensor<256xi1, #ttg<"coalesced encoding">> +// %21 = ttg.copy_async %10, %20, %18 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %22 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %24 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %26 = arith.cmpi slt, %c32, %12 : index +// %27 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %28 = tt.broadcast %26 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %29 = arith.andi %6, %28 : tensor<256xi1, #ttg<"coalesced encoding">> +// %30 = ttg.copy_async %23, %29, %27 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %31 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %32 = tt.broadcast %26 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %33 = arith.andi %6, %32 : tensor<256xi1, #ttg<"coalesced encoding">> +// %34 = ttg.copy_async %25, %33, %31 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %35 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %37 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %39 = arith.cmpi slt, %c64, %12 : index +// %40 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %41 = tt.broadcast %39 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %42 = arith.andi %6, %41 : tensor<256xi1, #ttg<"coalesced encoding">> +// %43 = ttg.copy_async %36, %42, %40 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %44 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %45 = tt.broadcast %39 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %46 = arith.andi %6, %45 : tensor<256xi1, #ttg<"coalesced encoding">> +// %47 = ttg.copy_async %38, %46, %44 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %48 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %50 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, index) { +// %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #ttg<"coalesced encoding">> +// %56 = arith.addf %arg7, %55 : tensor<256xf32, #ttg<"coalesced encoding">> +// %57 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %59 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %61 = arith.addi %arg18, %c32 : index +// %62 = arith.cmpi slt, %61, %12 : index +// %63 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %64 = tt.broadcast %62 : i1 -> tensor<256xi1, #ttg<"coalesced encoding">> +// %65 = arith.andi %64, %6 : tensor<256xi1, #ttg<"coalesced encoding">> +// %66 = ttg.copy_async %arg17, %65, %63 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %67 = tt.broadcast %cst : f32 -> tensor<256xf32, #ttg<"coalesced encoding">> +// %68 = ttg.copy_async %arg16, %65, %67 : tensor<256x!tt.ptr, #ttg<"coalesced encoding">> -> tensor<256xf32, #ttg<"coalesced encoding">> +// %69 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// %71 = tt.broadcast %arg5 : i32 -> tensor<256xi32, #ttg<"coalesced encoding">> +// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256xf32, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, index +// } +// %53 = tt.broadcast %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #ttg<"coalesced encoding">> +// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr, #ttg<"coalesced encoding">>, tensor<256xi32> +// tt.store %54, %52#0, %6 : tensor<256xf32, #ttg<"coalesced encoding">> +// tt.return +// } +// } diff --git a/third_party/enflame/include/triton/test/Triton/verify-make-range.mlir b/third_party/enflame/include/triton/test/Triton/verify-make-range.mlir new file mode 100644 index 000000000..1dc439c83 --- /dev/null +++ b/third_party/enflame/include/triton/test/Triton/verify-make-range.mlir @@ -0,0 +1,35 @@ +// RUN: triton-opt --split-input-file %s --verify-diagnostics + +tt.func public @i64_tensor() { + // expected-error @+1 {{i32 elements}} + %a = tt.make_range { start = 0 : i32, end = 16 : i32 } : tensor<16xi64> + tt.return +} + +// ----- +tt.func public @i32_scalar() { + // expected-error @+1 {{invalid kind of type}} + %a = tt.make_range { start = 0 : i32, end = 16 : i32 } : i32 + tt.return +} + +// ----- +tt.func public @_2d_tensor() { + // expected-error @+1 {{must be a 1D tensor}} + %a = tt.make_range { start = 0 : i32, end = 16 : i32 } : tensor<16x1xi32> + tt.return +} + +// ----- +tt.func public @bad_start_end() { + // expected-error @+1 {{start must be less than or equal to end}} + %a = tt.make_range { start = 0 : i32, end = -16 : i32 } : tensor<16xi32> + tt.return +} + +// ----- +tt.func public @bad_num_elems() { + // expected-error @+1 {{number of elements}} + %a = tt.make_range { start = 0 : i32, end = 32 : i32 } : tensor<16xi32> + tt.return +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/accelerate-matmul.mlir b/third_party/enflame/include/triton/test/TritonGPU/accelerate-matmul.mlir new file mode 100644 index 000000000..9c478bd5a --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/accelerate-matmul.mlir @@ -0,0 +1,509 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul | FileCheck %s + +// CHECK: #[[MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +// CHECK: #[[MMA1:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +// CHECK: #[[MMA2:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK: mma_chain_loop + tt.func public @mma_chain_loop( + %170: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %171: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %179: tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>, + %164: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>, + %165: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>>, + %173: tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>, + %153: tensor<128x64x!tt.ptr, #blocked1>) { + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x16xf16, #blocked> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked2> + // CHECK: scf.for + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x16xf16, #[[MMA]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> + %115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 { + %172 = tt.dot %170, %171, %cst : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked> + %178 = ttg.convert_layout %172 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + %180 = tt.dot %178, %179, %arg16 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + scf.yield %180 : tensor<128x64xf16, #blocked1> + } + // CHECK: scf.for + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> + %149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 { + %166 = tt.dot %164, %165, %cst_2 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> + %172 = ttg.convert_layout %166 : tensor<128x32xf16, #blocked2> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + %174 = tt.dot %172, %173, %arg16 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + scf.yield %174 : tensor<128x64xf16, #blocked1> + } + tt.store %153, %149 : tensor<128x64x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: chained_dot + tt.func public @chained_dot( + %arg0: tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1> + // CHECK: tt.dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]> + %d = tt.dot %arg0, %arg1, %cst_0 : + tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> + %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> + %c = ttg.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + // CHECK: tt.dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]> + %r = tt.dot %c, %arg2, %cst_1 : + tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> + tt.return %r : tensor<64x128xf32, #blocked1> + } +} + +// ----- + +// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}> +// CHECK: #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: chained_dot + tt.func public @chained_dot_wgmma( + %arg0: tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x64xf32, #mma> + %d = tt.dot %arg0, %arg1, %cst_0 : + tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> + %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> + %c = ttg.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x128xf32, #mma1> + %r = tt.dot %c, %arg2, %cst_1 : + tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> + tt.return %r : tensor<64x128xf32, #blocked1> + } +} + +// ----- + +// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:89", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: fp8_dot + tt.func public @fp8_dot( + %arg0: tensor<64x128xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x64xf32, #blocked> { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> + // CHECK: tt.dot {{.*}} : tensor<64x128xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x64xf32, #[[$MMA]]> + %d = tt.dot %arg0, %arg1, %cst_0 : + tensor<64x128xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> + tt.return %d : tensor<64x64xf32, #blocked> + } +} + +// ----- + +// CHECK-DAG: #[[MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +// CHECK-DAG: #[[MMA1:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}> + +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [0, 1, 2]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 2, 2], threadsPerWarp = [1, 4, 8], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK: kernel_ + tt.func public @kernel_() attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<2x16x16xf32, #blocked> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked1> + %0 = ttg.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + %1 = ttg.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> + %2 = ttg.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #blocked1> + // CHECK: tt.dot {{.*}} -> tensor<16x16xf32, #[[MMA]]> + %3 = tt.dot %0, %1, %2, inputPrecision = tf32 : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<16x16xf32, #blocked1> + %4 = ttg.convert_layout %3 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<16x16xf32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16x16xf32, #blocked2> + %6 = ttg.convert_layout %5 : tensor<1x16x16xf32, #blocked2> -> tensor<1x16x16xf32, #blocked> + %7 = tt.broadcast %6 : tensor<1x16x16xf32, #blocked> -> tensor<2x16x16xf32, #blocked> + %8 = ttg.convert_layout %7 : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> + %9 = ttg.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> + %10 = ttg.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #blocked3> + // CHECK: tt.dot {{.*}} -> tensor<2x16x16xf32, #[[MMA1]]> + %11 = tt.dot %8, %9, %10, inputPrecision = tf32 : tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<2x16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<2x16x16xf32, #blocked3> + %12 = ttg.convert_layout %11 : tensor<2x16x16xf32, #blocked3> -> tensor<2x16x16xf32, #blocked> + tt.print ": " {hex = false, isSigned = array} : %12 : tensor<2x16x16xf32, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 3, {{.*}}, instrShape = [16, 32, 16]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: check_instrShape_per_warps + tt.func @check_instrShape_per_warps(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %mask = arith.constant dense : tensor<128x128xi1, #blocked> + %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %a = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %b = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + + %result = tt.dot %a, %b, %zero_f32 : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + %result_ptr = tt.splat %arg0 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked> + tt.store %result_ptr, %result, %mask : tensor<128x128x!tt.ptr, #blocked> + tt.return + } +} + + +// ----- + +// Verify that we use mmav2 when the k dim is too small for mmav3. +// CHECK: #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 8], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: small_k_size + tt.func @small_k_size( + %a: tensor<128x16xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %b: tensor<16x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) + -> tensor<128x128xf32, #blocked> { + %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %result = tt.dot %a, %b, %zero_f32 : tensor<128x16xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + tt.return %result : tensor<128x128xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding + // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> + // CHECK-DAG: #[[$T:.+]] = #ttg.blocked<{sizePerThread = [1, 256], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> + // CHECK-LABEL: mmav5 + // CHECK-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xf16, #{{.*}}>) -> !ttg.memdesc<128x64xf16, #{{.*}}, #smem + // CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<64x256xf16, #{{.*}}>) -> !ttg.memdesc<64x256xf16, #{{.*}}, #smem + // CHECK-DAG: %[[ACC:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x256xf32, #{{.*}}>) -> !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> + // CHECK: ttng.tc_gen5_mma %[[A]], %[[B]], %[[ACC]], %[[TRUE]], %[[TRUE]] : (!ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x256xf16, #shared, #smem>, !ttg.memdesc<128x256xf32, #[[$TMEM]], #ttng.tensor_memory, mutable>, i1, i1) -> () + // CHECK: %[[R:.+]] = ttng.tmem_load %[[ACC]] : !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> -> tensor<128x256xf32 + // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$T]]> -> tensor<128x256xf32, #[[$B]]> + // CHECK: tt.return %[[CVT]] : tensor<128x256xf32 + tt.func public @mmav5(%a: tensor<128x64xf16, #blocked2>, %b: tensor<64x256xf16, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> { + %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %bd = ttg.convert_layout %b : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %d = tt.dot %ad, %bd, %c, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + tt.return %d : tensor<128x256xf32, #blocked> + } +} + +// ----- + +// CHECK: #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 8], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [16, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [16, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-label: mmav5_fallback_v2_num_warps + tt.func public @mmav5_fallback_v2_num_warps(%a: tensor<128x64xf16, #blocked2>, %b: tensor<64x256xf16, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> { + %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %bd = ttg.convert_layout %b : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %d = tt.dot %ad, %bd, %c, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + tt.return %d : tensor<128x256xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: mmav5_fp32 + // CHECK-DAG: %[[AD:.+]] = ttg.convert_layout %{{.*}} : tensor<128x64xf32, + // CHECK-DAG: %[[BD:.+]] = ttg.convert_layout %{{.*}} : tensor<64x256xf32, + // CHECK-DAG: %[[D:.*]] = tt.dot %[[AD]], %[[BD]], %{{.*}} + // CHECK: tt.return %[[D]] : tensor<128x256xf32 + tt.func public @mmav5_fp32(%a: tensor<128x64xf32, #blocked2>, %b: tensor<64x256xf32, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> { + %ad = ttg.convert_layout %a : tensor<128x64xf32, #blocked2> -> tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %bd = ttg.convert_layout %b : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %d = tt.dot %ad, %bd, %c : tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + tt.return %d : tensor<128x256xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding + // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> + // CHECK-DAG: #[[$T:.+]] = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> + // CHECK-LABEL: mmav5 + // CHECK-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xf16, #{{.*}}>) -> !ttg.memdesc<128x64xf16, #{{.*}}, #smem + // CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<64x256xf16, #{{.*}}>) -> !ttg.memdesc<64x256xf16, #{{.*}}, #smem + // CHECK-DAG: %[[ACC:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x256xf32, #{{.*}}>) -> !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> + // CHECK: ttng.tc_gen5_mma %[[A]], %[[B]], %[[ACC]], %[[TRUE]], %[[TRUE]] : (!ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x256xf16, #shared, #smem>, !ttg.memdesc<128x256xf32, #[[$TMEM]], #ttng.tensor_memory, mutable>, i1, i1) -> () + // CHECK: %[[R:.+]] = ttng.tmem_load %[[ACC]] : !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> -> tensor<128x256xf32 + // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$T]]> -> tensor<128x256xf32, #[[$B]]> + // CHECK: tt.return %[[CVT]] : tensor<128x256xf32 + tt.func public @mmav5_multi_ctas(%a: tensor<128x64xf16, #blocked2>, %b: tensor<64x256xf16, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> { + %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %bd = ttg.convert_layout %b : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %d = tt.dot %ad, %bd, %c, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + tt.return %d : tensor<128x256xf32, #blocked> + } +} + + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding + // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> + // CHECK-DAG: #[[$T:.+]] = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> + // CHECK-DAG: #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> + // CHECK-DAG: #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> + // CHECK-LABEL: mmav5 + // CHECK-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xf16, #{{.*}}>) -> !ttg.memdesc<128x64xf16, #{{.*}}, #smem + // CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<64x256xf16, #{{.*}}>) -> !ttg.memdesc<64x256xf16, #{{.*}}, #smem + // CHECK-DAG: %[[ACC:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x256xf32, #{{.*}}>) -> !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> + // CHECK: ttng.tc_gen5_mma %[[A]], %[[B]], %[[ACC]], %[[TRUE]], %[[TRUE]] {two_ctas} : (!ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x256xf16, #shared1, #smem>, !ttg.memdesc<128x256xf32, #[[$TMEM]], #ttng.tensor_memory, mutable>, i1, i1) -> () + // CHECK: %[[R:.+]] = ttng.tmem_load %[[ACC]] : !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> -> tensor<128x256xf32 + // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$T]]> -> tensor<128x256xf32, #[[$B]]> + // CHECK: tt.return %[[CVT]] : tensor<128x256xf32 + tt.func public @mmav5_2ctas(%a: tensor<128x64xf16, #blocked2>, %b_ptr: tensor<64x256x!tt.ptr, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> { + %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %b = tt.load %b_ptr : tensor<64x256x!tt.ptr, #blocked1> + %bd = ttg.convert_layout %b : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %d = tt.dot %ad, %bd, %c, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + tt.return %d : tensor<128x256xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding + // CHECK-DAG: #[[$TMEM1:.+]] = #ttng.tensor_memory_scales_encoding + // CHECK-LABEL: mmav5_block_scaled + // CHECK-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xi8, #{{.*}}>) -> !ttg.memdesc<128x64xi8, #{{.*}}, #smem + // CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<64x128xi8, #{{.*}}>) -> !ttg.memdesc<64x128xi8, #{{.*}}, #smem + // CHECK-DAG: %[[SCALEA_LOCAL:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #{{.*}}, #smem> + // CHECK: ttg.local_load %[[SCALEA_LOCAL]] : !ttg.memdesc<128x2xi8, #{{.*}}, #smem> -> tensor<128x2xi8, #{{.*}}> + // CHECK-DAG: %[[SCALEB_LOCAL:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #{{.*}}, #smem> + // CHECK: ttg.local_load %[[SCALEB_LOCAL]] : !ttg.memdesc<128x2xi8, #{{.*}}, #smem> -> tensor<128x2xi8, #{{.*}}> + // CHECK-DAG: %[[ACC:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x128xf32, #{{.*}}>) -> !ttg.memdesc<128x128xf32, #{{.*}}, #ttng.tensor_memory, mutable> + // CHECK: %[[SCALEA:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #[[$TMEM1]], #ttng.tensor_memory> + // CHECK: %[[SCALEB:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #[[$TMEM1]], #ttng.tensor_memory> + // CHECK: ttng.tc_gen5_mma_scaled %[[A]], %[[B]], %[[ACC]], %[[SCALEA]], %[[SCALEB]], %[[TRUE]], %[[TRUE]] lhs = e4m3 rhs = e4m3 + tt.func public @mmav5_block_scaled(%a: tensor<128x64xi8, #blocked2>, %scale_a_ptr: tensor<128x2x!tt.ptr, #blocked1>, %b: tensor<64x128xi8, #blocked>, %scale_b_ptr: tensor<128x2x!tt.ptr, #blocked1>) -> tensor<128x128xf32, #blocked> { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %scale_a = tt.load %scale_a_ptr: tensor<128x2x!tt.ptr, #blocked1> + %scale_b = tt.load %scale_b_ptr: tensor<128x2x!tt.ptr, #blocked1> + %d = tt.dot_scaled %a scale %scale_a, %b scale %scale_b, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x64xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xi8, #blocked>, tensor<128x2xi8, #blocked1> -> tensor<128x128xf32, #blocked> + tt.return %d : tensor<128x128xf32, #blocked> + } +} + +// ----- + +// Verify that dot_scaled (mxfp4 x {bf16,fp8}) decomposes to mmav3 if it's bf16, otherwise it fallsback to mmav2 +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: #[[LINEAR:.+]] = #ttg.linear<{{.*}}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK: dot_scaled + tt.func @dot_scaled( + %a: tensor<128x32xi8, #blocked2>, + %scale: tensor<128x2xi8, #blocked1>, + %b_bf16: tensor<64x128xbf16, #blocked> + ) -> tensor<128x128xf32, #blocked> { + // CHECK: ttg.fp4_to_fp + // CHECK: ttng.warp_group_dot + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %result = tt.dot_scaled %a scale %scale, %b_bf16, %cst lhs = e2m1 rhs = bf16 {fastMath = false} : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> + tt.return %result : tensor<128x128xf32, #blocked> + } + + // Verify that dot_scaled (mxfp4 x fp8) decomposes into mmav3 as well + // CHECK: dot_scaled_fp8 + tt.func @dot_scaled_fp8( + %a: tensor<128x32xi8, #blocked2>, + %scale: tensor<128x2xi8, #blocked1>, + %b_fp8: tensor<64x128xf8E4M3FN, #blocked> + ) -> tensor<128x128xf32, #blocked> { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + // CHECK: ttg.fp4_to_fp + // CHECK: ttng.warp_group_dot + %result = tt.dot_scaled %a scale %scale, %b_fp8, %cst lhs = e2m1 rhs = e4m3 {fastMath = true} : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked> + tt.return %result : tensor<128x128xf32, #blocked> + } +} + +// ----- + +// Mixed dtype matmul with upcasting on the left is transposed and uses MMAv3 +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: mixed_dtype_matmul + tt.func @mixed_dtype_matmul( + %a: tensor<64x32xf32, #blocked2>, + %b: tensor<32x64xf8E4M3FN, #blocked1>, + %c: tensor<64x64xf32, #blocked> + ) -> tensor<64x64xf32, #blocked> { + %b_upcast = tt.fp_to_fp %b : tensor<32x64xf8E4M3FN, #blocked1> -> tensor<32x64xf32, #blocked1> + %a_cvt = ttg.convert_layout %a : tensor<64x32xf32, #blocked2> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %b_cvt = ttg.convert_layout %b_upcast : tensor<32x64xf32, #blocked1> -> tensor<32x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: ttng.warp_group_dot + %d = tt.dot %a_cvt, %b_cvt, %c, inputPrecision = tf32 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> + tt.return %d : tensor<64x64xf32, #blocked> +} +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> + // CHECK-DAG: #[[$S:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8, fp4Padded = true}> + tt.func public @mmav5_block_scaled_mixed_prec(%a: tensor<128x64xi8, #blocked2>, %scale_a: tensor<128x2xi8, #blocked1>, %b: tensor<32x128xi8, #blocked>, %scale_b: tensor<128x2xi8, #blocked1>) -> tensor<128x128xf32, #blocked> { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + // CHECK: ttg.local_alloc %arg2 : (tensor<32x128xi8, #[[$B]]>) -> !ttg.memdesc<32x128xi8, #[[$S]], #smem> + %d = tt.dot_scaled %a scale %scale_a, %b scale %scale_b, %cst lhs = e4m3 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<32x128xi8, #blocked>, tensor<128x2xi8, #blocked1> -> tensor<128x128xf32, #blocked> + tt.return %d : tensor<128x128xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 8, 4, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 4, 8, 1, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 1, 2, 3, 0]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[32, 0], [64, 0], [1, 0], [2, 0], [4, 0]], warp = [[8, 0], [16, 0]], block = []}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding + // CHECK-DAG: #[[$TMEM1:.+]] = #ttng.tensor_memory_scales_encoding + // CHECK-LABEL: mmav5_block_scaled_5d_scale + // CHECK-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x128xi8, #{{.*}}>) -> !ttg.memdesc<128x128xi8, #{{.*}}, #smem + // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x128xi8, #{{.*}}>) -> !ttg.memdesc<128x128xi8, #{{.*}}, #smem + // CHECK-DAG: %[[SCALEA_LOCAL:.+]] = ttg.local_alloc + // CHECK: ttg.local_load %[[SCALEA_LOCAL]] + // CHECK-DAG: %[[SCALEB_LOCAL:.+]] = ttg.local_alloc + // CHECK: ttg.local_load %[[SCALEB_LOCAL]] + // CHECK-DAG: %[[ACC:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x128xf32, #{{.*}}>) -> !ttg.memdesc<128x128xf32, #{{.*}}, #ttng.tensor_memory, mutable> + // CHECK: %[[SCALEA:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x4xi8, #{{.*}}>) -> !ttg.memdesc<128x4xi8, #[[$TMEM1]], #ttng.tensor_memory> + // CHECK: %[[SCALEB:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x4xi8, #{{.*}}>) -> !ttg.memdesc<128x4xi8, #[[$TMEM1]], #ttng.tensor_memory> + // CHECK: ttng.tc_gen5_mma_scaled %[[A]], %[[B]], %[[ACC]], %[[SCALEA]], %[[SCALEB]], %[[TRUE]], %[[TRUE]] lhs = e4m3 rhs = e4m3 + tt.func public @mmav5_block_scaled_5d_scale(%a: tensor<128x128xi8, #blocked2>, %scale_a_ptr: tensor<1x1x32x4x4x!tt.ptr, #blocked3>, %b: tensor<128x128xi8, #blocked>, %scale_b_ptr: tensor<1x1x32x4x4x!tt.ptr, #blocked3>) -> tensor<128x128xf32, #blocked> { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %scale_a_5d = tt.load %scale_a_ptr: tensor<1x1x32x4x4x!tt.ptr, #blocked3> + %scale_a_trans = tt.trans %scale_a_5d {order = array} : tensor<1x1x32x4x4xi8, #blocked3> -> tensor<1x4x32x1x4xi8, #blocked4> + %scale_a = tt.reshape %scale_a_trans : tensor<1x4x32x1x4xi8, #blocked4> -> tensor<128x4xi8, #linear> + %scale_b_5d = tt.load %scale_b_ptr: tensor<1x1x32x4x4x!tt.ptr, #blocked3> + %scale_b_trans = tt.trans %scale_b_5d {order = array} : tensor<1x1x32x4x4xi8, #blocked3> -> tensor<1x4x32x1x4xi8, #blocked4> + %scale_b = tt.reshape %scale_b_trans : tensor<1x4x32x1x4xi8, #blocked4> -> tensor<128x4xi8, #linear> + %d = tt.dot_scaled %a scale %scale_a, %b scale %scale_b, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xi8, #blocked2>, tensor<128x4xi8, #linear> * tensor<128x128xi8, #blocked>, tensor<128x4xi8, #linear> -> tensor<128x128xf32, #blocked> + tt.return %d : tensor<128x128xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +tt.func @scalar_load_in_bwd_slice(%arg0: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %arg1: !tt.tensordesc>, %arg2: !tt.ptr) -> tensor<128x128xf32, #blocked> { + %0 = tt.load %arg2 : !tt.ptr + %1 = tt.experimental_descriptor_load %arg1[%0, %0] : !tt.tensordesc> -> tensor<128x128xf8E5M2, #blocked1> + %2 = ttg.convert_layout %1 : tensor<128x128xf8E5M2, #blocked1> -> tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %3 = tt.dot %2, %arg0, %cst, inputPrecision = tf32 : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + tt.return %3 : tensor<128x128xf32, #blocked> +} +} + +// ----- + +// check for heuristic to increase kWidth when join is present +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 16, 2], threadsPerWarp = [4, 8, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked6 = #ttg.blocked<{sizePerThread = [32, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @join_reshape_upcast_mma_kwidth(%84: tensor<16x256x!tt.ptr, #blocked3>, %112: tensor<64x128x!tt.ptr, #blocked2>) -> tensor<16x64xf32, #blocked> { + %90 = tt.load %84 : tensor<16x256x!tt.ptr, #blocked3> + %118 = tt.load %112, : tensor<64x128x!tt.ptr, #blocked2> + %121:2 = tt.elementwise_inline_asm "" {constraints = "=r,=r,=r,=r,r", packed_element = 4 : i32, pure = true} %118 : tensor<64x128xi8, #blocked2> -> tensor<64x128xbf16, #blocked2>, tensor<64x128xbf16, #blocked2> + %122 = tt.join %121#0, %121#1 : tensor<64x128xbf16, #blocked2> -> tensor<64x128x2xbf16, #blocked4> + %123 = tt.reshape %122 : tensor<64x128x2xbf16, #blocked4> -> tensor<64x256xbf16, #blocked5> + %124 = tt.trans %123 {order = array} : tensor<64x256xbf16, #blocked5> -> tensor<256x64xbf16, #blocked6> + %125 = ttg.convert_layout %90 : tensor<16x256xbf16, #blocked3> -> tensor<16x256xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %126 = ttg.convert_layout %124 : tensor<256x64xbf16, #blocked6> -> tensor<256x64xbf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: {{.*}} = tt.dot {{.*}} tensor<16x256xbf16, #ttg.dot_op<{opIdx = 0, parent = {{.*}}, kWidth = 8}>> * tensor<256x64xbf16, #ttg.dot_op<{opIdx = 1, parent = {{.*}}, kWidth = 8}>> + %cst = arith.constant dense<0.000000e+00> : tensor<16x64xf32, #blocked> + %127 = tt.dot %125, %126, %cst, inputPrecision = tf32 : tensor<16x256xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x64xbf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x64xf32, #blocked> + tt.return %127 : tensor<16x64xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-DAG: #[[$TMEM1:.+]] = #ttng.tensor_memory_scales_encoding + // CHECK{LITERALE}-DAG: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> + // CHECK-LABEL: mmav5_block_scaled_8_warps + // CHECK: ttng.tmem_alloc %{{.*}} : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #[[$TMEM1]], #ttng.tensor_memory> + // CHECK: ttng.tmem_alloc %{{.*}} : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #[[$TMEM1]], #ttng.tensor_memory> + // CHECK: ttng.tc_gen5_mma_scaled + tt.func public @mmav5_block_scaled_8_warps(%a: tensor<128x256xi8, #blocked2>, %scale_a: tensor<128x8xi8, #blocked1>, %b: tensor<256x128xi8, #blocked>, %scale_b: tensor<128x8xi8, #blocked1>) -> tensor<128x128xf32, #blocked> { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %d = tt.dot_scaled %a scale %scale_a, %b scale %scale_b, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x256xi8, #blocked2>, tensor<128x8xi8, #blocked1> * tensor<256x128xi8, #blocked>, tensor<128x8xi8, #blocked1> -> tensor<128x128xf32, #blocked> + tt.return %d : tensor<128x128xf32, #blocked> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/accumulator-init.mlir b/third_party/enflame/include/triton/test/TritonGPU/accumulator-init.mlir new file mode 100644 index 000000000..3596410c7 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/accumulator-init.mlir @@ -0,0 +1,382 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-optimize-accumulator-init | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK-LABEL: @constant_init +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] + tt.func @constant_init(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = ttng.warp_group_dot %A, %B, %cst_2 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @constant_init_integer +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] + tt.func @constant_init_integer(%A: !ttg.memdesc<128x64xi8, #shared, #smem>, %B: !ttg.memdesc<64x16xi8, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xi32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0> : tensor<128x16xi32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xi32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = ttng.warp_group_dot %A, %B, %cst_2 : !ttg.memdesc<128x64xi8, #shared, #smem> * !ttg.memdesc<64x16xi8, #shared1, #smem> -> tensor<128x16xi32, #mma1> + scf.yield %acc: tensor<128x16xi32, #mma1> + } + tt.return %17 : tensor<128x16xi32, #mma1> + } + +// CHECK-LABEL: @if_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @if_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_after_mma_invert +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[CND]] + tt.func @if_after_mma_invert(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %acc : tensor<128x16xf32, #mma1> + } else { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_before_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[ACC]] +// CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_before_mma_invert +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[USE_ACC]], %[[FALSE]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[ACC]] +// CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_mma_invert(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } else { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @sel_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @sel_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @sel_before_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @sel_before_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + + +// Check that we look only at the zeroing directly preceding the mma + +// CHECK-LABEL: @if_before_and_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[ACC]] +// CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[C0_TENSOR]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_and_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_0 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } + %acc = ttng.warp_group_dot %A, %B, %acc_0 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %acc_1 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + scf.yield %acc_1: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @two_ifs_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[C0_TENSOR]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_CND]] +// CHECK: else +// CHECK: scf.yield %[[ACC_CND]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @two_ifs_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %acc_0 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + %acc_1 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc_0 : tensor<128x16xf32, #mma1> + } + scf.yield %acc_1: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// Check that we bail out in unsupported cases + +// CHECK-LABEL: @non_zero_init +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @non_zero_init(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @zero_init_dist_2 +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @zero_init_dist_2(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %cst_2) -> (tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = ttng.warp_group_dot %A, %B, %arg5 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_, %arg4: tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_defines_alternative +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @if_defines_alternative(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + %acc_alt = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1> + scf.yield %acc_alt : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @non_cond_override +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @non_cond_override(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %acc_ = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// If the condition is a tensor skip the optimization. +// CHECK-LABEL: @negative_sel_tensor +// CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc + tt.func @negative_sel_tensor(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-chain-dot.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-chain-dot.mlir new file mode 100644 index 000000000..8fe4e2c80 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-chain-dot.mlir @@ -0,0 +1,131 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=16' | FileCheck %s --check-prefixes MFMA16,CHECK +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32' | FileCheck %s --check-prefixes MFMA32,CHECK + +// Check the warpsPerCTA parameter of #mma layout of the two dot's. +// The 1st dot always has warpsPerCTA = [4, 1]. +// The warpsPerCTA for the 2nd dot depends on mfma instruction size and BLOCK_M size. + + +// BLOCK_M = 128 +// warpsPerCTA = [4, 1] for mfma16 and mfma32 +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #blocked}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #blocked}> +// MFMA16{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +// MFMA32{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +// CHECK-LABEL: mfma_chain_dot_BM128 +// CHECK: tt.dot {{.*}} : {{.*}} -> tensor<128x16xf32, #mma> +// CHECK: tt.dot {{.*}} : {{.*}} -> tensor<128x128xf32, #mma> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_chain_dot_BM128( + %q: tensor<128x128xf16, #dotOp0>, + %k: tensor<128x16xf16, #dotOp1>, + %v: tensor<16x128xf16, #dotOp1>, + %o_ptr: tensor<128x128x!tt.ptr, #blocked>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #blocked> + %cst1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %qk = tt.dot %q, %k, %cst : tensor<128x128xf16, #dotOp0> * tensor<128x16xf16, #dotOp1> -> tensor<128x16xf32, #blocked> + %qk_f16 = arith.truncf %qk : tensor<128x16xf32, #blocked> to tensor<128x16xf16, #blocked> + %p = ttg.convert_layout %qk_f16 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #dotOp0> + %o = tt.dot %p, %v, %cst1 : tensor<128x16xf16, #dotOp0> * tensor<16x128xf16, #dotOp1> -> tensor<128x128xf32, #blocked> + tt.store %o_ptr, %o : tensor<128x128x!tt.ptr, #blocked> + tt.return + } +} + + +// ----- + +// BLOCK_M = 64 +// warpsPerCTA = [4, 1] for mfma16 +// warpsPerCTA = [2, 2] for mfma32 +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #blocked}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #blocked}> +// MFMA16{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +// MFMA32{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +// MFMA32{LITERAL}: #mma1 = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}> +// CHECK-LABEL: mfma_chain_dot_BM64 +// CHECK: tt.dot {{.*}} : {{.*}} -> tensor<64x16xf32, #mma> +// MFMA16: tt.dot {{.*}} : {{.*}} -> tensor<64x128xf32, #mma> +// MFMA32: tt.dot {{.*}} : {{.*}} -> tensor<64x128xf32, #mma1> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_chain_dot_BM64( + %q: tensor<64x128xf16, #dotOp0>, + %k: tensor<128x16xf16, #dotOp1>, + %v: tensor<16x128xf16, #dotOp1>, + %o_ptr: tensor<64x128x!tt.ptr, #blocked>) { + %cst = arith.constant dense<0.000000e+00> : tensor<64x16xf32, #blocked> + %cst1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked> + %qk = tt.dot %q, %k, %cst : tensor<64x128xf16, #dotOp0> * tensor<128x16xf16, #dotOp1> -> tensor<64x16xf32, #blocked> + %qk_f16 = arith.truncf %qk : tensor<64x16xf32, #blocked> to tensor<64x16xf16, #blocked> + %p = ttg.convert_layout %qk_f16 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #dotOp0> + %o = tt.dot %p, %v, %cst1 : tensor<64x16xf16, #dotOp0> * tensor<16x128xf16, #dotOp1> -> tensor<64x128xf32, #blocked> + tt.store %o_ptr, %o : tensor<64x128x!tt.ptr, #blocked> + tt.return + } +} + + +// ----- + +// BLOCK_M = 32 +// warpsPerCTA = [2, 2] for mfma16 +// warpsPerCTA = [1, 4] for mfma32 +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #blocked}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #blocked}> +// MFMA16{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +// MFMA32{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +// MFMA16{LITERAL}: #mma1 = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}> +// MFMA32{LITERAL}: #mma1 = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}> +// CHECK-LABEL: mfma_chain_dot_BM32 +// CHECK: tt.dot {{.*}} : {{.*}} -> tensor<32x16xf32, #mma> +// MFMA16: tt.dot {{.*}} : {{.*}} -> tensor<32x128xf32, #mma1> +// MFMA32: tt.dot {{.*}} : {{.*}} -> tensor<32x128xf32, #mma1> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_chain_dot_BM32( + %q: tensor<32x128xf16, #dotOp0>, + %k: tensor<128x16xf16, #dotOp1>, + %v: tensor<16x128xf16, #dotOp1>, + %o_ptr: tensor<32x128x!tt.ptr, #blocked>) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x16xf32, #blocked> + %cst1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #blocked> + %qk = tt.dot %q, %k, %cst : tensor<32x128xf16, #dotOp0> * tensor<128x16xf16, #dotOp1> -> tensor<32x16xf32, #blocked> + %qk_f16 = arith.truncf %qk : tensor<32x16xf32, #blocked> to tensor<32x16xf16, #blocked> + %p = ttg.convert_layout %qk_f16 : tensor<32x16xf16, #blocked> -> tensor<32x16xf16, #dotOp0> + %o = tt.dot %p, %v, %cst1 : tensor<32x16xf16, #dotOp0> * tensor<16x128xf16, #dotOp1> -> tensor<32x128xf32, #blocked> + tt.store %o_ptr, %o : tensor<32x128x!tt.ptr, #blocked> + tt.return + } +} + + +// ----- + +// BLOCK_M = 16, only check mfma16 since it's too small for mfma32 +// warpsPerCTA = [1, 4] for mfma16 +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #blocked}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #blocked}> +// MFMA16{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +// MFMA16{LITERAL}: #mma1 = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 16], isTransposed = true}> +// CHECK-LABEL: mfma_chain_dot_BM16 +// CHECK: tt.dot {{.*}} : {{.*}} -> tensor<16x16xf32, #mma> +// MFMA16: tt.dot {{.*}} : {{.*}} -> tensor<16x128xf32, #mma1> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_chain_dot_BM16( + %q: tensor<16x128xf16, #dotOp0>, + %k: tensor<128x16xf16, #dotOp1>, + %v: tensor<16x128xf16, #dotOp1>, + %o_ptr: tensor<16x128x!tt.ptr, #blocked>) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #blocked> + %qk = tt.dot %q, %k, %cst : tensor<16x128xf16, #dotOp0> * tensor<128x16xf16, #dotOp1> -> tensor<16x16xf32, #blocked> + %qk_f16 = arith.truncf %qk : tensor<16x16xf32, #blocked> to tensor<16x16xf16, #blocked> + %p = ttg.convert_layout %qk_f16 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #dotOp0> + %o = tt.dot %p, %v, %cst1 : tensor<16x16xf16, #dotOp0> * tensor<16x128xf16, #dotOp1> -> tensor<16x128xf32, #blocked> + tt.store %o_ptr, %o : tensor<16x128x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-fma.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-fma.mlir new file mode 100644 index 000000000..49c93d235 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-fma.mlir @@ -0,0 +1,106 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942' | FileCheck %s + +// CHECK: fma_dot_fp16_fp16 +// CHECK: %[[D:.*]] = tt.dot {{.*}} : tensor<2x64xf16, {{.*}}> * tensor<64x64xf16, {{.*}}> -> tensor<2x64xf16, {{.*}}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_fp16_fp16( + %arg0: tensor<2x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0.0> : tensor<2x64xf16, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xf16, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: fma_dot_fp32_fp32 +// CHECK: tt.dot {{.*}} : tensor<2x64xf32, {{.*}}> * tensor<64x64xf32, {{.*}}> -> tensor<2x64xf32, {{.*}}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_fp32_fp32( + %arg0: tensor<2x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0.0> : tensor<2x64xf32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xf32, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: fma_dot_i8 +// CHECK: tt.dot {{.*}} : tensor<2x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<64x64xi8, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<2x64xi32, #[[BLOCKED]]> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_i8( + %arg0: tensor<2x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0> : tensor<2x64xi32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xi32, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: fma_dot_f16 +// CHECK: tt.dot {{.*}} : tensor<2x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<2x64xf32, #[[BLOCKED]]> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_f16( + %arg0: tensor<2x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0.0> : tensor<2x64xf32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xf32, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: fma_dot_f8 +// CHECK: tt.dot {{.*}} : tensor<2x64xf32, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<2x64xf32, #[[BLOCKED]]> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_f8( + %arg0: tensor<2x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0.0> : tensor<2x64xf32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xf32, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: fma_dot_i8_i8 +// CHECK-DAG: %[[A:.*]] = arith.sitofp +// CHECK-DAG: %[[B:.*]] = arith.sitofp +// CHECK: %[[D:.*]] = tt.dot %[[A]], %[[B]], {{.*}} : tensor<2x64xf16, {{.*}}> * tensor<64x64xf16, {{.*}}> -> tensor<2x64xf16, {{.*}}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_i8_i8( + %arg0: tensor<2x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0> : tensor<2x64xi8, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xi8, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-mfma-gfx950.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-mfma-gfx950.mlir new file mode 100644 index 000000000..6f8c66930 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-mfma-gfx950.mlir @@ -0,0 +1,229 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx950 matrix-instruction-size=0' | FileCheck %s --check-prefixes CHECK + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 32], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 16]], warp = [[0, 0], [32, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [32, 0], [0, 64]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [16, 0]], warp = [[0, 32], [0, 0]], block = []}> +// CHECK{LITERAL}: #linear2 = #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 0], [32, 0]], block = []}> +// CHECK{LITERAL}: #linear3 = #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[32, 0], [0, 0]], block = []}> +// CHECK-LABEL: mfma_dot_scaled_mxfp4_mxfp4 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_scaled_mxfp4_mxfp4( + %arg0: tensor<128x64xi8, #blocked>, + %arg1: tensor<64x128xi8, #blocked1>, + %arg2: tensor<128x4xi8, #blocked2>, + %arg3: tensor<128x4xi8, #blocked2>, + %arg4: tensor<128x128x!tt.ptr, #blocked1> + ) { + // CHECK-NOT: arith.constant dense<127> : tensor<128x4xi8, #linear> + // CHECK-NOT: arith.constant dense<127> : tensor<128x4xi8, #linear1> + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, #blocked1> -> tensor<128x128xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<128x64xi8, #blocked> -> tensor<128x64xi8, #linear> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<64x128xi8, #blocked1> -> tensor<64x128xi8, #linear1> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<128x4xi8, #blocked2> -> tensor<128x4xi8, #linear2> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<128x4xi8, #blocked2> -> tensor<128x4xi8, #linear3> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e2m1 rhs = e2m1 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #blocked>, tensor<128x4xi8, #blocked2> * tensor<64x128xi8, #blocked1>, tensor<128x4xi8, #blocked2> -> tensor<128x128xf32, #blocked1> + tt.store %arg4, %1 : tensor<128x128x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-LABEL: mfma_dot_scaled_mxfp4_fp4 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_scaled_mxfp4_fp4( + %arg0: tensor<128x64xi8, #blocked>, + %arg1: tensor<64x128xi8, #blocked1>, + %arg2: tensor<128x4xi8, #blocked2>, + %arg3: tensor<128x128x!tt.ptr, #blocked1> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[CST1:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<128x4xi8, #blocked2> -> tensor<128x4xi8, #linear3> + // CHECK: tt.dot_scaled {{.*}} scale %[[SCALE0]], {{.*}} scale %[[CST1]], {{.*}} lhs = e2m1 rhs = e2m1 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #blocked>, tensor<128x4xi8, #blocked2> * tensor<64x128xi8, #blocked1> -> tensor<128x128xf32, #blocked1> + tt.store %arg3, %1 : tensor<128x128x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-LABEL: mfma_dot_scaled_fp4_mxfp4 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_scaled_fp4_mxfp4( + %arg0: tensor<128x64xi8, #blocked>, + %arg1: tensor<64x128xi8, #blocked1>, + %arg2: tensor<128x4xi8, #blocked2>, + %arg3: tensor<128x128x!tt.ptr, #blocked1> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[CST0:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<128x4xi8, #blocked2> -> tensor<128x4xi8, #linear3> + // CHECK: tt.dot_scaled {{.*}} scale %[[CST0]], {{.*}} scale %[[SCALE1]], {{.*}} lhs = e2m1 rhs = e2m1 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %1 = tt.dot_scaled %arg0, %arg1 scale %arg2, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #blocked> * tensor<64x128xi8, #blocked1>, tensor<128x4xi8, #blocked2> -> tensor<128x128xf32, #blocked1> + tt.store %arg3, %1 : tensor<128x128x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}> +// #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-LABEL: mfma_dot_scaled_fp4_fp4 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_scaled_fp4_fp4( + %arg0: tensor<128x64xi8, #blocked>, + %arg1: tensor<64x128xi8, #blocked1>, + %arg2: tensor<128x128x!tt.ptr, #blocked1> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: tt.dot_scaled {{[^ ]+}}, {{[^ ]+}}, {{[^ ]+}} lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #linear> * tensor<64x128xi8, #linear1> -> tensor<128x128xf32, #mma> + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %1 = tt.dot_scaled %arg0, %arg1, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #blocked> * tensor<64x128xi8, #blocked1> -> tensor<128x128xf32, #blocked1> + tt.store %arg2, %1 : tensor<128x128x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 32], [0, 64], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 16]], warp = [[0, 0], [32, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [32, 0], [64, 0], [0, 64]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [16, 0]], warp = [[0, 32], [0, 0]], block = []}> +// CHECK{LITERAL}: #linear2 = #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 0], [32, 0]], block = []}> +// CHECK{LITERAL}: #linear3 = #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[32, 0], [0, 0]], block = []}> +// CHECK-LABEL: mfma_dot_scaled_mxfp8e4_mxfp8e4 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_scaled_mxfp8e4_mxfp8e4( + %arg0: tensor<128x128xf8E4M3FN, #blocked>, + %arg1: tensor<128x128xf8E4M3FN, #blocked>, + %arg2: tensor<128x4xi8, #blocked1>, + %arg3: tensor<128x4xi8, #blocked1>, + %arg4: tensor<128x128x!tt.ptr, #blocked> + ) { + // CHECK-NOT: arith.constant dense<127> : tensor<128x4xi8, #linear> + // CHECK-NOT: arith.constant dense<127> : tensor<128x4xi8, #linear1> + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<128x128xf8E4M3FN, #blocked> -> tensor<128x128xf8E4M3FN, #linear> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<128x128xf8E4M3FN, #blocked> -> tensor<128x128xf8E4M3FN, #linear1> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<128x4xi8, #blocked1> -> tensor<128x4xi8, #linear2> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<128x4xi8, #blocked1> -> tensor<128x4xi8, #linear3> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e4m3 rhs = e4m3 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E4M3FN, #blocked>, tensor<128x4xi8, #blocked1> * tensor<128x128xf8E4M3FN, #blocked>, tensor<128x4xi8, #blocked1> -> tensor<128x128xf32, #blocked> + tt.store %arg4, %1 : tensor<128x128x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-LABEL: mfma_dot_scaled_fp8e4_mxfp4 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_scaled_fp8e4_mxfp4( + %arg0: tensor<128x128xf8E4M3FN, #blocked>, + %arg1: tensor<64x128xi8, #blocked>, + %arg2: tensor<128x4xi8, #blocked1>, + %arg3: tensor<128x128x!tt.ptr, #blocked> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[CST0:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<128x4xi8, #blocked1> -> tensor<128x4xi8, #linear3> + // CHECK: tt.dot_scaled {{.*}} scale %[[CST0]], {{.*}} scale %[[SCALE1]], {{.*}} lhs = e4m3 rhs = e2m1 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %1 = tt.dot_scaled %arg0, %arg1 scale %arg2, %cst lhs = e4m3 rhs = e2m1 {fastMath = false} : tensor<128x128xf8E4M3FN, #blocked> * tensor<64x128xi8, #blocked>, tensor<128x4xi8, #blocked1> -> tensor<128x128xf32, #blocked> + tt.store %arg3, %1 : tensor<128x128x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-LABEL: mfma_dot_scaled_mxfp4_fp8e5 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_scaled_mxfp4_fp8e5( + %arg0: tensor<128x64xi8, #blocked>, + %arg1: tensor<128x128xf8E5M2, #blocked>, + %arg2: tensor<128x4xi8, #blocked1>, + %arg3: tensor<128x128x!tt.ptr, #blocked> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[CST1:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<128x4xi8, #blocked1> -> tensor<128x4xi8, #linear3> + // CHECK: tt.dot_scaled {{.*}} scale %[[SCALE0]], {{.*}} scale %[[CST1]], {{.*}} lhs = e2m1 rhs = e5m2 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1, %cst lhs = e2m1 rhs = e5m2 {fastMath = false} : tensor<128x64xi8, #blocked>, tensor<128x4xi8, #blocked1> * tensor<128x128xf8E5M2, #blocked> -> tensor<128x128xf32, #blocked> + tt.store %arg3, %1 : tensor<128x128x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#dot_op_a = #ttg.dot_op<{opIdx = 0, parent = #blocked}> +#dot_op_b = #ttg.dot_op<{opIdx = 1, parent = #blocked}> +// CHECK-LABEL: mfma_bf8_dot_to_dot_scaled +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_bf8_dot_to_dot_scaled( + %arg0: tensor<128x64xf8E5M2, #dot_op_a>, + %arg1: tensor<64x128xf8E5M2, #dot_op_b>, + %arg2: tensor<128x128x!tt.ptr, #blocked> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK-NOT: tt.dot {{.*}}, {{.*}}, {{.*}} + // CHECK-DAG: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x64xf8E5M2, #linear> + // CHECK-DAG: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<64x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x128xf8E5M2, #linear1> + // CHECK: tt.dot_scaled %[[A]], %[[B]], {{.*}} lhs = e5m2 rhs = e5m2 {fastMath = false} : tensor<128x64xf8E5M2, #linear> * tensor<64x128xf8E5M2, #linear1> -> tensor<128x128xf32, #mma> + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #dot_op_a> * tensor<64x128xf8E5M2, #dot_op_b> -> tensor<128x128xf32, #blocked> + tt.store %arg2, %1 : tensor<128x128x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#dot_op_a = #ttg.dot_op<{opIdx = 0, parent = #blocked}> +#dot_op_b = #ttg.dot_op<{opIdx = 1, parent = #blocked}> +// CHECK-LABEL: mfma_fp16_dot_to_dot +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_fp16_dot_to_dot( + %arg0: tensor<128x64xf16, #dot_op_a>, + %arg1: tensor<64x128xf16, #dot_op_b>, + %arg2: tensor<128x128x!tt.ptr, #blocked> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK-NOT: tt.dot_scaled + // CHECK-DAG: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + // CHECK-DAG: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + // CHECK: tt.dot %[[A]], %[[B]], {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x128xf32, #mma> + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf16, #dot_op_a> * tensor<64x128xf16, #dot_op_b> -> tensor<128x128xf32, #blocked> + tt.store %arg2, %1 : tensor<128x128x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir new file mode 100644 index 000000000..3da8d93ae --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir @@ -0,0 +1,57 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=0' | FileCheck %s --check-prefixes MFMA0,CHECK +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=16' | FileCheck %s --check-prefixes MFMA16,CHECK + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}> +// CHECK-LABEL: mfma_dot_fp8e5m2 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_fp8e5m2( + %arg0: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<128x256x!tt.ptr, #blocked>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> + // CHECK: %[[A0:.+]] = ttg.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + // CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + // CHECK: %[[B0:.+]] = ttg.convert_layout %arg1 : {{.*}} -> tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + // CHECK: %[[B1:.+]] = tt.fp_to_fp %[[B0]] : tensor<64x256xf8E5M2, {{.*}} -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + // CHECK: tt.dot %[[A1]], %[[B1]] + %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + tt.store %arg2, %1 : tensor<128x256x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// Verify that we use FMA when the N dimension is too small for any mma. +// MFMA0-NOT: #ttg.amd_mfma +// MFMA16: #ttg.amd_mfma +// CHECK-LABEL: small_n_size +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 64], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @small_n_size( + %a: tensor<4x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %b: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) + -> tensor<4x128xf32, #blocked> { + %zero_f32 = arith.constant dense<0.000000e+00> : tensor<4x128xf32, #blocked> + %result = tt.dot %a, %b, %zero_f32 : tensor<4x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<4x128xf32, #blocked> + tt.return %result : tensor<4x128xf32, #blocked> + } +} + +// ----- + +// MFMA0-NOT: amd_mfma +// MFMA16: amd_mfma +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}> +// CHECK-LABEL: mfma_dot_small_k +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_small_k( + %arg0: tensor<128x4xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<4x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<128x256x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<128x4xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<4x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + tt.store %arg2, %1 : tensor<128x256x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir new file mode 100644 index 000000000..843a943a0 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir @@ -0,0 +1,150 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx1100 matrix-instruction-size=0' | FileCheck %s + +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 1, isTranspose = false, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_cf32( + // CHECK: %[[DOT0_ARG_A:.+]]: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT0_ARG_B:.+]]: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<128x256x!tt.ptr, #blocked>) { + // CHECK: %[[DOT0_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT0_OP_C:.+]] = ttg.convert_layout %[[DOT0_ARG_C]] + // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] + %3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> + // CHECK: %[[DOT0_OP_A:.+]] = ttg.convert_layout %[[DOT0_ARG_A]] + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_OP_B:.+]] = ttg.convert_layout %[[DOT0_ARG_B]] + // CHECK-SAME: -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_WMMA_RES:.+]] = tt.dot %[[DOT0_OP_A]], %[[DOT0_OP_B]], %[[DOT0_OP_C]] + // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] + %4 = tt.dot %0, %1, %3 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + // CHECK: ttg.convert_layout %[[DOT0_WMMA_RES]] + // CHECK-SAME: -> tensor<128x256xf32, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<128x256x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 1, isTranspose = false, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_cf16( + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<32x32x!tt.ptr, #blocked>) { + // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT1_OP_C:.+]] = ttg.convert_layout %[[DOT1_ARG_C]] + // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] + %3 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked> + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] + // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] + // CHECK-SAME: -> tensor<32x32xf16, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 1, isTranspose = false, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_ab8_cf16( + // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT2_ARG_B:.+]]: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<32x64x!tt.ptr, #blocked>) { + // CHECK: %[[DOT2_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT2_OP_C:.+]] = ttg.convert_layout %[[DOT2_ARG_C]] + // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] + %3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #blocked> + // CHECK: %[[DOT2_OP_A_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_A]] + // CHECK-SAME: -> tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT2_OP_A_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_A_F8]] + // CHECK-SAME: -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 16}>> + // CHECK: %[[DOT2_OP_B_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_B]] + // CHECK-SAME: -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK: %[[DOT2_OP_B_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_B_F8]] + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 16}>> + // CHECK: %[[DOT2_WMMA_RES:.+]] = tt.dot %[[DOT2_OP_A_F16]], %[[DOT2_OP_B_F16]], %[[DOT2_OP_C]] + // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] + %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT2_WMMA_RES]] + // CHECK-SAME: -> tensor<32x64xf16, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<32x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 1, isTranspose = false, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_i8_i32( + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<32x32x!tt.ptr, #blocked>) { + // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0> : tensor<32x32xi32, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT1_OP_C:.+]] = ttg.convert_layout %[[DOT1_ARG_C]] + // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] + %3 = arith.constant dense<0> : tensor<32x32xi32, #blocked> + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] + // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] + // CHECK-SAME: -> tensor<32x32xi32, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @fma_dot_i16_i16( + // CHECK: %[[DOT3_ARG_A:.+]]: tensor<128x64xi16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<128x64xi16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT3_ARG_B:.+]]: tensor<64x32xi16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xi16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<128x32x!tt.ptr, #blocked>) { + // CHECK: %[[DOT3_OP_C:.+]] = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #[[DOT_OP_PARENT]]> + %3 = arith.constant dense<0> : tensor<128x32xi16, #blocked> + // CHECK: %[[DOT3_OP_A:.+]] = arith.sitofp %[[DOT3_ARG_A]] + // CHECK-SAME: to tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]] + // CHECK: %[[DOT3_OP_B:.+]] = arith.sitofp %[[DOT3_ARG_B]] + // CHECK-SAME: to tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]] + // CHECK: %[[DOT3_FMA_RES:.+]] = tt.dot %[[DOT3_OP_A]], %[[DOT3_OP_B]], %[[DOT3_OP_C]] + // CHECK-SAME: -> tensor<128x32xf32, #[[DOT_OP_PARENT]]> + %4 = tt.dot %0, %1, %3 : tensor<128x64xi16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xi16, #blocked> + // CHECK: arith.fptosi %[[DOT3_FMA_RES]] + // CHECK-SAME: to tensor<128x32xi16, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<128x32x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir new file mode 100644 index 000000000..f9d24964f --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir @@ -0,0 +1,123 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx1200 matrix-instruction-size=0' | FileCheck %s + +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 2, isTranspose = false, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_cf32( + // CHECK: %[[DOT0_ARG_A:.+]]: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT0_ARG_B:.+]]: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<128x256x!tt.ptr, #blocked>) { + // CHECK: %[[DOT0_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT0_OP_C:.+]] = ttg.convert_layout %[[DOT0_ARG_C]] + // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] + %3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> + // CHECK: %[[DOT0_OP_A:.+]] = ttg.convert_layout %[[DOT0_ARG_A]] + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_OP_B:.+]] = ttg.convert_layout %[[DOT0_ARG_B]] + // CHECK-SAME: -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_WMMA_RES:.+]] = tt.dot %[[DOT0_OP_A]], %[[DOT0_OP_B]], %[[DOT0_OP_C]] + // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] + %4 = tt.dot %0, %1, %3 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + // CHECK: ttg.convert_layout %[[DOT0_WMMA_RES]] + // CHECK-SAME: -> tensor<128x256xf32, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<128x256x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 2, isTranspose = false, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_cf16( + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<32x32x!tt.ptr, #blocked>) { + // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT1_OP_C:.+]] = ttg.convert_layout %[[DOT1_ARG_C]] + // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] + %3 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked> + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] + // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] + // CHECK-SAME: -> tensor<32x32xf16, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #ttg.amd_wmma<{version = 2, isTranspose = false, warpsPerCTA = [1, 4]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_ab8_cf16( + // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT2_ARG_B:.+]]: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<32x64x!tt.ptr, #blocked>) { + // CHECK: %[[DOT2_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT2_OP_C:.+]] = ttg.convert_layout %[[DOT2_ARG_C]] + // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] + %3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #blocked> + // CHECK: %[[DOT2_OP_A_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_A]] + // CHECK-SAME: -> tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT2_OP_A_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_A_F8]] + // CHECK-SAME: -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 8}>> + // CHECK: %[[DOT2_OP_B_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_B]] + // CHECK-SAME: -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK: %[[DOT2_OP_B_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_B_F8]] + // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 8}>> + // CHECK: %[[DOT2_WMMA_RES:.+]] = tt.dot %[[DOT2_OP_A_F16]], %[[DOT2_OP_B_F16]], %[[DOT2_OP_C]] + // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] + %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked> + // CHECK: ttg.convert_layout %[[DOT2_WMMA_RES]] + // CHECK-SAME: -> tensor<32x64xf16, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<32x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #ttg.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #ttg.amd_wmma<{version = 2, isTranspose = false, warpsPerCTA = [2, 2]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_i8_i32( + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<32x32x!tt.ptr, #blocked>) { + // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0> : tensor<32x32xi32, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT1_OP_C:.+]] = ttg.convert_layout %[[DOT1_ARG_C]] + // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] + %3 = arith.constant dense<0> : tensor<32x32xi32, #blocked> + // CHECK: %[[DOT1_OP_A:.+]] = ttg.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = ttg.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] + // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked> + // CHECK: ttg.convert_layout %[[DOT1_WMMA_RES]] + // CHECK-SAME: -> tensor<32x32xi32, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-block-pingpong.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-block-pingpong.mlir new file mode 100644 index 000000000..8cce34430 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-block-pingpong.mlir @@ -0,0 +1,888 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-block-pingpong | FileCheck %s + +//CHECK-LABEL: pingpong_small +//CHECK: ttg.local_load +//CHECK: rocdl.s.setprio 1 +//CHECK: tt.load +//CHECK: rocdl.sched.barrier +//CHECK: ttg.local_load +//CHECK: rocdl.s.setprio 0 +//CHECK: tt.load +//CHECK: rocdl.sched.barrier +//CHECK: rocdl.s.setprio 1 +//CHECK: tt.dot +//CHECK: rocdl.s.setprio 0 + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @pingpong_small(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked> + %cst_1 = arith.constant dense<64> : tensor<128x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %1 = tt.get_program_id x : i32 + %2 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = arith.addi %2, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %6 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked1> + %8 = tt.addptr %0, %7 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %13 = tt.addptr %9, %12 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked> + %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { + %26 = tt.addptr %arg12, %cst_1 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %27 = tt.load %26 : tensor<128x64x!tt.ptr, #blocked1> + %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %29 = tt.load %28 : tensor<64x128x!tt.ptr, #blocked> + %30 = ttg.local_load %arg15 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + %32 = tt.dot %30, %31, %arg11 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x128xf32, #mma> + %33 = arith.addi %arg14, %c1_i32 : i32 + %34 = arith.cmpi slt, %33, %c1_i32 : i32 + %35 = arith.select %34, %33, %c0_i32 : i32 + %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %27, %36 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %29, %37 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + scf.yield %32, %26, %28, %35, %36, %37 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + } + ttg.local_dealloc %21 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + tt.return + } +} + +// ----- + +// CHECK: gpu.barrier +// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x +// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]] +// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]] +// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]] +// CHECK: amdgpu.cond_barrier %[[WARPHIGH]] +// CHECK: scf.for +// CHECK: tt.load +// CHECK: %[[SLICEA0:.+]] = ttg.local_load +// CHECK: %[[SLICEB0:.+]] = ttg.local_load +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: tt.load +// CHECK: %[[SLICEA1:.+]] = ttg.local_load +// CHECK: %[[SLICEB1:.+]] = ttg.local_load +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: %[[SLICEA2:.+]] = ttg.local_load +// CHECK: %[[SLICEB2:.+]] = ttg.local_load +// CHECK: %[[SLICEA3:.+]] = ttg.local_load +// CHECK: %[[SLICEB3:.+]] = ttg.local_load +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: %[[DOT2:.+]] = tt.dot %[[SLICEA2]], %[[SLICEB2]], %[[DOT1]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: ttg.local_store +// CHECK: ttg.local_store +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: tt.dot %[[SLICEA3]], %[[SLICEB3]], %[[DOT2]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: scf.yield +// CHECK: amdgpu.cond_barrier %[[WARPLOW]] + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @pingpong_large(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<64> : tensor<64x256xi32, #blocked> + %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x1x!tt.ptr, #blocked1> + %1 = tt.get_program_id x : i32 + %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1> + %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1> + %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr, #blocked1>, tensor<256x1xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr, #blocked1> -> tensor<256x64x!tt.ptr, #blocked1> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1> + %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x256x!tt.ptr, #blocked> + %19 = tt.splat %arg7 : i32 -> tensor<64x256xi32, #blocked> + %20 = tt.addptr %18, %19 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { + %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> + %28 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %29 = tt.load %28 : tensor<64x256x!tt.ptr, #blocked> + %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %31 = ttg.local_load %arg16 : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma> + %33 = arith.addi %arg14, %c1_i32 : i32 + %34 = arith.cmpi slt, %33, %c1_i32 : i32 + %35 = arith.select %34, %33, %c0_i32 : i32 + %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %29, %37 : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + } + ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> + tt.return + } +} + +// ----- + +// CHECK: gpu.barrier +// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x +// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]] +// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]] +// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]] +// CHECK: amdgpu.cond_barrier %[[WARPHIGH]] +// CHECK: scf.for + +// CHECK: %[[SLICEA0:.+]] = ttg.local_load +// CHECK: %[[SLICEB0:.+]] = ttg.local_load +// CHECK: rocdl.sched.barrier 0 +// CHECK: tt.load +// CHECK: rocdl.sched.barrier 0 +// CHECK: %[[SLICEA1:.+]] = ttg.local_load +// CHECK: %[[SLICEB1:.+]] = ttg.local_load +// CHECK: rocdl.sched.barrier 0 +// CHECK: tt.load +// CHECK: rocdl.s.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: ttg.local_store +// CHECK: ttg.local_store +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: scf.yield +// CHECK: amdgpu.cond_barrier %[[WARPLOW]] + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @pingpong_medium(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked> + %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x1x!tt.ptr, #blocked1> + %1 = tt.get_program_id x : i32 + %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1> + %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1> + %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr, #blocked1>, tensor<256x1xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr, #blocked1> -> tensor<256x64x!tt.ptr, #blocked1> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1> + %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked> + %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { + %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> + %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %29 = tt.load %28 : tensor<64x128x!tt.ptr, #blocked> + %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma> + %33 = arith.addi %arg14, %c1_i32 : i32 + %34 = arith.cmpi slt, %33, %c1_i32 : i32 + %35 = arith.select %34, %33, %c0_i32 : i32 + %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %29, %37 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + } + ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + tt.return + } +} + +// ----- + +// CHECK-LABEL: pingpong_medium_cast +// CHECK-COUNT-2: local_load +// CHECK-NOT: setprio +// CHECK-NOT: barrier + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @pingpong_medium_cast(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked> + %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x1x!tt.ptr, #blocked1> + %1 = tt.get_program_id x : i32 + %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1> + %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1> + %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr, #blocked1>, tensor<256x1xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr, #blocked1> -> tensor<256x64x!tt.ptr, #blocked1> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1> + %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked> + %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> + %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>) : i32 { + %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> + %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %29 = tt.load %28 : tensor<64x128x!tt.ptr, #blocked> + %cast2 = tt.bitcast %29 : tensor<64x128xf16, #blocked> -> tensor<64x128xi16, #blocked> + %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xi16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %cast = tt.bitcast %31 : tensor<64x128xi16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %32 = tt.dot %30, %cast, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma> + %33 = arith.addi %arg14, %c1_i32 : i32 + %34 = arith.cmpi slt, %33, %c1_i32 : i32 + %35 = arith.select %34, %33, %c0_i32 : i32 + %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %cast2, %37 : tensor<64x128xi16, #blocked> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> + scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> + } + ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> + tt.return + } +} + + +// ----- + + +// CHECK-LABEL: pingpong_reject +// CHECK-COUNT-2: local_load +// CHECK-NOT: local_load +// CHECK-NOT: setprio +// CHECK-NOT: barrier + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @pingpong_reject(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<64> : tensor<16x256xi32, #blocked> + %cst_1 = arith.constant dense<64> : tensor<256x16xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x1x!tt.ptr, #blocked1> + %1 = tt.get_program_id x : i32 + %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1> + %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1> + %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr, #blocked1>, tensor<256x1xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr, #blocked1> -> tensor<256x16x!tt.ptr, #blocked1> + %10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x16xi32, #blocked1> -> tensor<256x16xi32, #blocked1> + %13 = tt.addptr %9, %12 : tensor<256x16x!tt.ptr, #blocked1>, tensor<256x16xi32, #blocked1> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked> + %15 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %17 = tt.addptr %14, %16 : tensor<16x1x!tt.ptr, #blocked>, tensor<16x1xi32, #blocked> + %18 = tt.broadcast %17 : tensor<16x1x!tt.ptr, #blocked> -> tensor<16x256x!tt.ptr, #blocked> + %19 = tt.splat %arg7 : i32 -> tensor<16x256xi32, #blocked> + %20 = tt.addptr %18, %19 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> + %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x16xf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<1x16x256xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable> + %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x256xf32, #mma>, tensor<256x16x!tt.ptr, #blocked1>, tensor<16x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { + %26 = tt.addptr %arg12, %cst_1 : tensor<256x16x!tt.ptr, #blocked1>, tensor<256x16xi32, #blocked1> + %27 = tt.load %26 : tensor<256x16x!tt.ptr, #blocked1> + %28 = tt.addptr %arg13, %cst_0 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> + %29 = tt.load %28 : tensor<16x256x!tt.ptr, #blocked> + %30 = ttg.local_load %arg15 : !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %31 = ttg.local_load %arg16 : !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<16x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + %32 = tt.dot %30, %31, %arg11 : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<256x256xf32, #mma> + %33 = arith.addi %arg14, %c1_i32 : i32 + %34 = arith.cmpi slt, %33, %c1_i32 : i32 + %35 = arith.select %34, %33, %c0_i32 : i32 + %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x16xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %27, %36 : tensor<256x16xf16, #blocked1> -> !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %29, %37 : tensor<16x256xf16, #blocked> -> !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable> + scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x256xf32, #mma>, tensor<256x16x!tt.ptr, #blocked1>, tensor<16x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x16xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<16x256xf16, #shared1, #ttg.shared_memory, mutable> + } + ttg.local_dealloc %21 : !ttg.memdesc<1x256x16xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<1x16x256xf16, #shared1, #ttg.shared_memory, mutable> + tt.return + } +} + +// ----- + +//CHECK-LABEL: pingpong_small_prologue_load +//CHECK: ttg.local_load +//CHECK: rocdl.s.setprio 1 +//CHECK: tt.load +//CHECK: rocdl.sched.barrier +//CHECK: ttg.local_load +//CHECK: rocdl.s.setprio 0 +//CHECK: tt.load +//CHECK: rocdl.sched.barrier +//CHECK: rocdl.s.setprio 1 +//CHECK: tt.dot +//CHECK: rocdl.s.setprio 0 + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @pingpong_small_prologue_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked> + %cst_1 = arith.constant dense<64> : tensor<128x64xi32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %1 = tt.get_program_id x : i32 + %2 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = arith.addi %2, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %6 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked1> + %8 = tt.addptr %0, %7 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %13 = tt.addptr %9, %12 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked> + %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { + %26 = arith.cmpi eq, %arg10, %c0_i32: i32 + %27 = scf.if %26 -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> { + %28 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %29 = tt.broadcast %28 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %30 = tt.load %29 : tensor<128x64x!tt.ptr, #blocked1> + %31 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> + %32 = ttg.memdesc_subview %31[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %30, %32 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %33 = ttg.local_load %32 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + scf.yield %33 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + } else { + scf.yield %cst_2 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + } + %34 = tt.addptr %arg12, %cst_1 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %35 = tt.load %34 : tensor<128x64x!tt.ptr, #blocked1> + %36 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %37 = tt.load %36 : tensor<64x128x!tt.ptr, #blocked> + %38 = ttg.local_load %arg15 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %39 = arith.addf %38, %27: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %40 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + %41 = tt.dot %39, %40, %arg11 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x128xf32, #mma> + %42 = arith.addi %arg14, %c1_i32 : i32 + %43 = arith.cmpi slt, %42, %c1_i32 : i32 + %44 = arith.select %43, %42, %c0_i32 : i32 + %45 = ttg.memdesc_subview %21[%44, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %35, %45 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %46 = ttg.memdesc_subview %22[%44, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %37, %46 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + scf.yield %41, %34, %36, %44, %45, %46 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + } + ttg.local_dealloc %21 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + tt.return + } +} + + +// ----- +// CHECK-LABEL: pingpong_medium_dependency + +// CHECK: gpu.barrier +// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x +// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]] +// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]] +// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]] +// CHECK: amdgpu.cond_barrier %[[WARPHIGH]] +// CHECK: scf.for + +// CHECK: %[[SLICEA0:.+]] = ttg.local_load +// CHECK: %[[SLICEB0:.+]] = ttg.local_load +// CHECK: rocdl.sched.barrier 0 +// CHECK: tt.load +// CHECK: rocdl.sched.barrier 0 +// CHECK: %[[SLICEA1:.+]] = ttg.local_load +// CHECK: %[[SLICEB1:.+]] = ttg.local_load +// CHECK: rocdl.sched.barrier 0 +// CHECK: tt.load +// CHECK: rocdl.s.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: ttg.local_store +// CHECK: ttg.local_store +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: scf.yield +// CHECK: amdgpu.cond_barrier %[[WARPLOW]] + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @pingpong_medium_dependency(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked> + %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1> + %cst_2 = arith.constant dense<1.000000e+00> : tensor<256x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x1x!tt.ptr, #blocked1> + %1 = tt.get_program_id x : i32 + %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1> + %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1> + %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr, #blocked1>, tensor<256x1xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr, #blocked1> -> tensor<256x64x!tt.ptr, #blocked1> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1> + %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked> + %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { + %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> + %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %29 = tt.load %28 : tensor<64x128x!tt.ptr, #blocked> + %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma> + %33 = arith.addf %32, %cst_2 : tensor<256x128xf32, #mma> + %34 = arith.addi %arg14, %c1_i32 : i32 + %35 = arith.cmpi slt, %34, %c1_i32 : i32 + %36 = arith.select %35, %34, %c0_i32 : i32 + %37 = ttg.memdesc_subview %21[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %27, %37 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %38 = ttg.memdesc_subview %22[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %29, %38 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + scf.yield %33, %26, %28, %36, %37, %38 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + } + ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + tt.return + } +} + +// ----- +// CHECK-LABEL: pingpong_large_dependency + +// CHECK: gpu.barrier +// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x +// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]] +// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]] +// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]] +// CHECK: amdgpu.cond_barrier %[[WARPHIGH]] +// CHECK: scf.for +// CHECK: tt.load +// CHECK: %[[SLICEA0:.+]] = ttg.local_load +// CHECK: %[[SLICEB0:.+]] = ttg.local_load +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: tt.load +// CHECK: %[[SLICEA1:.+]] = ttg.local_load +// CHECK: %[[SLICEB1:.+]] = ttg.local_load +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: %[[SLICEA2:.+]] = ttg.local_load +// CHECK: %[[SLICEB2:.+]] = ttg.local_load +// CHECK: %[[SLICEA3:.+]] = ttg.local_load +// CHECK: %[[SLICEB3:.+]] = ttg.local_load +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: %[[DOT2:.+]] = tt.dot %[[SLICEA2]], %[[SLICEB2]], %[[DOT1]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: ttg.local_store +// CHECK: ttg.local_store +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: rocdl.s.setprio 1 +// CHECK: tt.dot %[[SLICEA3]], %[[SLICEB3]], %[[DOT2]] +// CHECK: rocdl.s.setprio 0 +// CHECK: gpu.barrier +// CHECK: rocdl.sched.barrier 0 +// CHECK: scf.yield +// CHECK: amdgpu.cond_barrier %[[WARPLOW]] + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @pingpong_large_dependency(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<64> : tensor<64x256xi32, #blocked> + %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1> + %cst_2 = arith.constant dense<1.000000e+00> : tensor<256x256xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c63_i32 = arith.constant 63: i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x1x!tt.ptr, #blocked1> + %1 = tt.get_program_id x : i32 + %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1> + %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1> + %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr, #blocked1>, tensor<256x1xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr, #blocked1> -> tensor<256x64x!tt.ptr, #blocked1> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1> + %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x256x!tt.ptr, #blocked> + %19 = tt.splat %arg7 : i32 -> tensor<64x256xi32, #blocked> + %20 = tt.addptr %18, %19 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { + %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> + %28 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + %29 = tt.load %28 : tensor<64x256x!tt.ptr, #blocked> + %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %31 = ttg.local_load %arg16 : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma> + %33 = arith.addf %32, %cst_2 : tensor<256x256xf32, #mma> + %34 = arith.addi %arg14, %c1_i32 : i32 + %35 = arith.cmpi slt, %34, %c1_i32 : i32 + %36 = arith.select %35, %34, %c0_i32 : i32 + %37 = ttg.memdesc_subview %21[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %27, %37 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %38 = ttg.memdesc_subview %22[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %29, %38 : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + scf.yield %33, %26, %28, %36, %37, %38 : tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> + } + ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> + tt.return + } +} +// ----- +//CHECK-LABEL: pingpong_small_load_reorder +//CHECK: ttg.local_load +//CHECK: rocdl.s.setprio 1 +//CHECK: tt.load +//CHECK: rocdl.sched.barrier +//CHECK: ttg.local_load +//CHECK: rocdl.s.setprio 0 +//CHECK: tt.load +//CHECK: rocdl.sched.barrier +//CHECK: rocdl.s.setprio 1 +//CHECK: tt.dot +//CHECK: rocdl.s.setprio 0 + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @pingpong_small_load_reorder(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked> + %cst_1 = arith.constant dense<64> : tensor<128x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %1 = tt.get_program_id x : i32 + %2 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = arith.addi %2, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %6 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked1> + %8 = tt.addptr %0, %7 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %13 = tt.addptr %9, %12 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked> + %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { + // This swaps the assumption on the ordering of the local load and + // global load from the base test to ensure the one ping pong cluster + // is robust to different patterns. + %26 = ttg.local_load %arg15 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %27 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + %28 = tt.addptr %arg12, %cst_1 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %29 = tt.load %28 : tensor<128x64x!tt.ptr, #blocked1> + %30 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %31 = tt.load %30 : tensor<64x128x!tt.ptr, #blocked> + %32 = tt.dot %26, %27, %arg11 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x128xf32, #mma> + %33 = arith.addi %arg14, %c1_i32 : i32 + %34 = arith.cmpi slt, %33, %c1_i32 : i32 + %35 = arith.select %34, %33, %c0_i32 : i32 + %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %29, %36 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %31, %37 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + scf.yield %32, %28, %30, %35, %36, %37 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + } + ttg.local_dealloc %21 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + tt.return + } +} + + +// ----- +//CHECK-LABEL: pingpong_small_local_load_dep +//CHECK: ttg.local_load +//CHECK: rocdl.s.setprio 1 +//CHECK: tt.load +//CHECK: rocdl.sched.barrier +//CHECK: ttg.local_load +//CHECK: rocdl.s.setprio 0 +//CHECK: tt.load +//CHECK: rocdl.sched.barrier +//CHECK: rocdl.s.setprio 1 +//CHECK: tt.dot +//CHECK: rocdl.s.setprio 0 + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @pingpong_small_local_load_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked> + %cst_1 = arith.constant dense<64> : tensor<128x64xi32, #blocked1> + %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %1 = tt.get_program_id x : i32 + %2 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = arith.addi %2, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %6 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked1> + %8 = tt.addptr %0, %7 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %13 = tt.addptr %9, %12 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked> + %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 { + %26 = tt.addptr %arg12, %cst_1 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %27 = tt.load %26 : tensor<128x64x!tt.ptr, #blocked1> + %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %29 = tt.load %28 : tensor<64x128x!tt.ptr, #blocked> + %30 = ttg.local_load %arg15 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %31 = arith.addf %30, %cst_2 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %32 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + %33 = tt.dot %31, %32, %arg11 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x128xf32, #mma> + %34 = arith.addi %arg14, %c1_i32 : i32 + %35 = arith.cmpi slt, %34, %c1_i32 : i32 + %36 = arith.select %35, %34, %c0_i32 : i32 + %37 = ttg.memdesc_subview %21[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_store %27, %37 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> + %38 = ttg.memdesc_subview %22[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + ttg.local_store %29, %38 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + scf.yield %33, %26, %28, %36, %37, %38 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> + } + ttg.local_dealloc %21 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-canonicalize-pointers.mlir new file mode 100644 index 000000000..9a12cb1f2 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -0,0 +1,1317 @@ +// NOTE: Assertions have been autogenerated by mlir/utils/generate-test-checks.py + +// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers -canonicalize -verify-diagnostics | FileCheck %s + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion1(%arg0: !tt.ptr) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.splat %1 : i32 -> tensor<1024xi32> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %4 = tt.addptr %3, %2 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %5 = tt.load %4 : tensor<1024x!tt.ptr> + tt.return %5 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @conversion1( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<1024xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32 +// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_4]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_6]] : tensor<1024xf32> +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion2(%arg0: !tt.ptr) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %7 = tt.load %6 : tensor<1024x!tt.ptr> + tt.return %7 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @conversion2( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<1024xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32 +// CHECK: %[[VAL_4:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_6:.*]] = tt.splat %[[VAL_5]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_6]], %[[VAL_4]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_8:.*]] = tt.load %[[VAL_7]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_8]] : tensor<1024xf32> +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion3(%arg0: !tt.ptr) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @conversion3( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<1024xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32 +// CHECK: %[[VAL_4:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_6:.*]] = arith.extsi %[[VAL_4]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_5]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_4]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_6]] : tensor<1024xi64> +// CHECK: %[[VAL_10:.*]] = tt.splat %[[VAL_7]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_10]], %[[VAL_9]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_12:.*]] = tt.load %[[VAL_11]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_12]] : tensor<1024xf32> +// CHECK: } + +// ----- + + +// +// This is the same as conversion3, but now the `arith.extsi` operations +// disappeared and all the offsets are 32 bits. +// + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion4(%arg0: !tt.ptr {tt.pointer_range = 32 : i32}) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @conversion4( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.pointer_range = 32 : i32}) -> tensor<1024xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32 +// CHECK: %[[VAL_4:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_5]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : tensor<1024xi32> +// CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_6]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_8]], %[[VAL_7]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_10:.*]] = tt.load %[[VAL_9]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_10]] : tensor<1024xf32> +// CHECK: } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @convertLayoutOp(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: tensor<1024xi32, #blocked>) -> tensor<1024xf32, #blocked1> { + %0 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %2 = tt.addptr %1, %arg2 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %3 = tt.load %2 : tensor<1024x!tt.ptr, #blocked> + %4 = tt.addptr %0, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %5 = ttg.convert_layout %4 : tensor<1024x!tt.ptr, #blocked> -> tensor<1024x!tt.ptr, #blocked1> + %6 = tt.load %5 : tensor<1024x!tt.ptr, #blocked1> + tt.return %6 : tensor<1024xf32, #blocked1> + } +} + +// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +// CHECK: #[[$ATTR_1:.+]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> + +// CHECK-LABEL: tt.func public @convertLayoutOp( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: !tt.ptr, %[[VAL_2:.*]]: tensor<1024xi32, #[[$ATTR_0]]>) -> tensor<1024xf32, #[[$ATTR_1]]> { +// CHECK: %[[VAL_3:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<1024x!tt.ptr, #[[$ATTR_0]]> +// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_3]], %[[VAL_2]] : tensor<1024x!tt.ptr, #[[$ATTR_0]]>, tensor<1024xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : tensor<1024x!tt.ptr, #[[$ATTR_0]]> +// CHECK: %[[VAL_6:.*]] = arith.extsi %[[VAL_5]] : tensor<1024xi32, #[[$ATTR_0]]> to tensor<1024xi64, #[[$ATTR_0]]> +// CHECK: %[[VAL_7:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<1024xi64, #[[$ATTR_0]]> -> tensor<1024xi64, #[[$ATTR_1]]> +// CHECK: %[[VAL_8:.*]] = arith.trunci %[[VAL_7]] : tensor<1024xi64, #[[$ATTR_1]]> to tensor<1024xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<1024x!tt.ptr, #[[$ATTR_1]]> +// CHECK: %[[VAL_10:.*]] = tt.addptr %[[VAL_9]], %[[VAL_8]] : tensor<1024x!tt.ptr, #[[$ATTR_1]]>, tensor<1024xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_11:.*]] = tt.load %[[VAL_10]] : tensor<1024x!tt.ptr, #[[$ATTR_1]]> +// CHECK: tt.return %[[VAL_11]] : tensor<1024xf32, #[[$ATTR_1]]> +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forOp(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %7:2 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %6, %arg4 = %arg1) -> (tensor<1024x!tt.ptr>, tensor<1024xf32>) { + %10 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %11 = tt.load %10 : tensor<1024x!tt.ptr> + %12 = arith.addf %11, %arg4 : tensor<1024xf32> + scf.yield %10, %12 : tensor<1024x!tt.ptr>, tensor<1024xf32> + } + %8 = tt.addptr %7#0, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8 : tensor<1024x!tt.ptr> + tt.return %9 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @forOp( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1024 : i32 +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_10:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_11:.*]]:3 = scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_4]] step %[[VAL_5]] iter_args(%[[VAL_13:.*]] = %[[VAL_9]], %[[VAL_14:.*]] = %[[VAL_10]], %[[VAL_15:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_13]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_14]] : tensor<1024xi64> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_16]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_19]], %[[VAL_18]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_21:.*]] = tt.load %[[VAL_20]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_21]], %[[VAL_15]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_16]], %[[VAL_18]], %[[VAL_22]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_24:.*]]#0, %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_25:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_24]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_27:.*]] = tt.splat %[[VAL_23]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_27]], %[[VAL_26]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_29:.*]] = tt.load %[[VAL_28]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_29]] : tensor<1024xf32> +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forOp2(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6:2 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %5, %arg4 = %arg1) -> (tensor<1024x!tt.ptr>, tensor<1024xf32>) { + %9 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %9 : tensor<1024x!tt.ptr> + %11 = arith.addf %10, %arg4 : tensor<1024xf32> + scf.yield %9, %11 : tensor<1024x!tt.ptr>, tensor<1024xf32> + } + %7 = tt.addptr %6#0, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @forOp2( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1024 : i32 +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_9:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_10:.*]]:3 = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_2]], %[[VAL_14:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_12]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_16:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_13]] : tensor<1024xi64> +// CHECK: %[[VAL_18:.*]] = tt.splat %[[VAL_15]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_18]], %[[VAL_17]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_20:.*]] = tt.load %[[VAL_19]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_21:.*]] = arith.addf %[[VAL_20]], %[[VAL_14]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_15]], %[[VAL_17]], %[[VAL_21]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_22:.*]] = tt.addptr %[[VAL_23:.*]]#0, %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_24:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_23]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_22]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_28]] : tensor<1024xf32> +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forNested(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6:2 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %5, %arg4 = %arg1) -> (tensor<1024x!tt.ptr>, tensor<1024xf32>) { + %9:2 = scf.for %arg5 = %c0 to %c128 step %c1 iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (tensor<1024x!tt.ptr>, tensor<1024xf32>) { + %10 = tt.addptr %arg6, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %11 = tt.load %10 : tensor<1024x!tt.ptr> + %12 = arith.addf %11, %arg7 : tensor<1024xf32> + scf.yield %10, %12 : tensor<1024x!tt.ptr>, tensor<1024xf32> + } + scf.yield %9#0, %9#1 : tensor<1024x!tt.ptr>, tensor<1024xf32> + } + %7 = tt.addptr %6#0, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @forNested( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1024 : i32 +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_9:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_10:.*]]:3 = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_2]], %[[VAL_14:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_15:.*]]:3 = scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_14]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_17]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_18]] : tensor<1024xi64> +// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_20]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_22]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_25:.*]] = tt.load %[[VAL_24]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_19]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_20]], %[[VAL_22]], %[[VAL_26]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_27:.*]]#0, %[[VAL_27]]#1, %[[VAL_27]]#2 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_29:.*]]#0, %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_30:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_32:.*]] = tt.splat %[[VAL_28]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_33:.*]] = tt.addptr %[[VAL_32]], %[[VAL_31]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_34:.*]] = tt.load %[[VAL_33]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_34]] : tensor<1024xf32> +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @ifOp(%arg0: !tt.ptr, %arg1: tensor<1024xf32>, %arg2: i1) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = scf.if %arg2 -> (tensor<1024x!tt.ptr>) { + %8 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + scf.yield %8 : tensor<1024x!tt.ptr> + } else { + %8 = tt.addptr %5, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + scf.yield %8 : tensor<1024x!tt.ptr> + } + %7 = tt.load %6 : tensor<1024x!tt.ptr> + tt.return %7 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @ifOp( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: i1) -> tensor<1024xf32> { +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32 +// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_8:.*]]:2 = scf.if %[[VAL_2]] -> (!tt.ptr, tensor<1024xi64>) { +// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr, i32 +// CHECK: %[[VAL_10:.*]] = arith.extsi %[[VAL_7]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : !tt.ptr, tensor<1024xi64> +// CHECK: } else { +// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr, i32 +// CHECK: scf.yield %[[VAL_11]], %[[VAL_3]] : !tt.ptr, tensor<1024xi64> +// CHECK: } +// CHECK: %[[VAL_12:.*]] = arith.trunci %[[VAL_13:.*]]#1 : tensor<1024xi64> to tensor<1024xi32> +// CHECK: %[[VAL_14:.*]] = tt.splat %[[VAL_13]]#0 : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_14]], %[[VAL_12]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_16:.*]] = tt.load %[[VAL_15]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_16]] : tensor<1024xf32> +// CHECK: } + +// ----- + + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @whileOp(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6:2 = scf.while (%arg2 = %5, %arg3 = %2) : (tensor<1024x!tt.ptr>, tensor<1024xi32>) -> (tensor<1024x!tt.ptr> , tensor<1024xi32>) { + %8 = "dummy.evaluate_condition"() : () -> i1 + scf.condition(%8) %arg2, %arg3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + } do { + ^bb0(%arg2: tensor<1024x!tt.ptr>, %arg3: tensor<1024xi32>): + %res = tt.addptr %arg2, %arg3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + scf.yield %res, %arg3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + } + %7 = tt.load %6#0 : tensor<1024x!tt.ptr> + tt.return %7 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @whileOp( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK: %[[VAL_3:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_4:.*]] = scf.while (%[[VAL_5:.*]] = %[[VAL_2]]) : (tensor<1024xi64>) -> tensor<1024xi64> { +// CHECK: %[[VAL_6:.*]] = "dummy.evaluate_condition"() : () -> i1 +// CHECK: scf.condition(%[[VAL_6]]) %[[VAL_5]] : tensor<1024xi64> +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_7:.*]]: tensor<1024xi64>): +// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_3]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : tensor<1024xi64> +// CHECK: scf.yield %[[VAL_9]] : tensor<1024xi64> +// CHECK: } +// CHECK: %[[VAL_10:.*]] = arith.trunci %[[VAL_4]] : tensor<1024xi64> to tensor<1024xi32> +// CHECK: %[[VAL_11:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_12:.*]] = tt.addptr %[[VAL_11]], %[[VAL_10]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_13:.*]] = tt.load %[[VAL_12]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_13]] : tensor<1024xf32> +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @condBranch(%arg0: !tt.ptr, %arg1: i1) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + cf.cond_br %arg1, ^bb1(%5 : tensor<1024x!tt.ptr>), ^bb2(%6 : tensor<1024x!tt.ptr>) + ^bb1(%7: tensor<1024x!tt.ptr>): // pred: ^bb0 + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xf32> + ^bb2(%9: tensor<1024x!tt.ptr>): // pred: ^bb0 + %10 = tt.load %9 : tensor<1024x!tt.ptr> + tt.return %10 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @condBranch( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i1) -> tensor<1024xf32> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_4:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_5:.*]] = arith.muli %[[VAL_4]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_0]], %[[VAL_5]] : !tt.ptr, i32 +// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_6]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: cf.cond_br %[[VAL_1]], ^bb1(%[[VAL_0]], %[[VAL_2]] : !tt.ptr, tensor<1024xi64>), ^bb2(%[[VAL_7]], %[[VAL_8]] : !tt.ptr, tensor<1024xi64>) +// CHECK: ^bb1(%[[VAL_9:.*]]: !tt.ptr, %[[VAL_10:.*]]: tensor<1024xi64>): +// CHECK: %[[VAL_11:.*]] = arith.trunci %[[VAL_10]] : tensor<1024xi64> to tensor<1024xi32> +// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_9]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_13:.*]] = tt.addptr %[[VAL_12]], %[[VAL_11]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_14:.*]] = tt.load %[[VAL_13]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_14]] : tensor<1024xf32> +// CHECK: ^bb2(%[[VAL_15:.*]]: !tt.ptr, %[[VAL_16:.*]]: tensor<1024xi64>): +// CHECK: %[[VAL_17:.*]] = arith.trunci %[[VAL_16]] : tensor<1024xi64> to tensor<1024xi32> +// CHECK: %[[VAL_18:.*]] = tt.splat %[[VAL_15]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_18]], %[[VAL_17]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_20:.*]] = tt.load %[[VAL_19]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_20]] : tensor<1024xf32> +// CHECK: } + +// ----- + + +// REWRITE branch gets DCEd + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @branch(%arg0: !tt.ptr, %arg1: i1) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + cf.br ^bb1(%6 : tensor<1024x!tt.ptr>) + ^bb1(%7: tensor<1024x!tt.ptr>): // pred: ^bb0 + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @branch( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i1) -> tensor<1024xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_5:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_0]], %[[VAL_4]] : !tt.ptr, i32 +// CHECK: %[[VAL_7:.*]] = tt.splat %[[VAL_6]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_7]], %[[VAL_5]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_8]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_9]] : tensor<1024xf32> +// CHECK: } + +// ----- + + +// The following is a simple case of a tile offset like: (A*B + C + D) where B,C are Uniform and A,D are not. So +// we expect that the Uniform offset (which can be added to the scalar pointer) will be simply C and the NonUniform +// offset will be A*B+D +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @tile_offset(%arg0: !tt.ptr, %arg1: i32, %arg2: i32) -> tensor<16x256xf16, #blocked> { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %4 = arith.addi %3, %2 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %7 = tt.splat %arg2 : i32 -> tensor<16x1xi32, #blocked> + %8 = arith.muli %6, %7 : tensor<16x1xi32, #blocked> + %9 = tt.expand_dims %4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %10 = tt.broadcast %8 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> + %11 = tt.broadcast %9 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> + %12 = arith.addi %10, %11 : tensor<16x256xi32, #blocked> + %13 = tt.splat %arg0 : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> + %14 = tt.addptr %13, %12 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> + %15 = tt.load %14 : tensor<16x256x!tt.ptr, #blocked> + tt.return %15 : tensor<16x256xf16, #blocked> + } +} + +// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-LABEL: tt.func @tile_offset( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i32, +// CHECK-SAME: %[[VAL_2:.*]]: i32) -> tensor<16x256xf16, #[[$ATTR_0]]> { +// CHECK: %[[VAL_3:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_4:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_5:.*]] = arith.muli %[[VAL_4]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_0]]}>> +// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_0]]}>> +// CHECK: %[[VAL_8:.*]] = tt.expand_dims %[[VAL_7]] {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<16x1xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_2]] : i32 -> tensor<16x1xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_8]], %[[VAL_9]] : tensor<16x1xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_11:.*]] = tt.broadcast %[[VAL_10]] : tensor<16x1xi32, #[[$ATTR_0]]> -> tensor<16x256xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_12:.*]] = tt.expand_dims %[[VAL_6]] {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_0]]}>> -> tensor<1x256xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_13:.*]] = tt.broadcast %[[VAL_12]] : tensor<1x256xi32, #[[$ATTR_0]]> -> tensor<16x256xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_13]] : tensor<16x256xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_0]], %[[VAL_5]] : !tt.ptr, i32 +// CHECK: %[[VAL_16:.*]] = tt.splat %[[VAL_15]] : !tt.ptr -> tensor<16x256x!tt.ptr, #[[$ATTR_0]]> +// CHECK: %[[VAL_17:.*]] = tt.addptr %[[VAL_16]], %[[VAL_14]] : tensor<16x256x!tt.ptr, #[[$ATTR_0]]>, tensor<16x256xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_18:.*]] = tt.load %[[VAL_17]] : tensor<16x256x!tt.ptr, #[[$ATTR_0]]> +// CHECK: tt.return %[[VAL_18]] : tensor<16x256xf16, #[[$ATTR_0]]> +// CHECK: } + +// ----- + + +// The following is a more complex case where also a multiplication is involved. It's useful to walk through the case. +// We have that the offset to the pointer is the following: +// %12 = %10 + 11 +// This can be transformed in: +// = %7 + %9 +// = %5*%6 + %8 +// = %4*%arg1 + %8 +// = (%3+%2)*%arg1 + %8 +// = (%1 + %2) * %arg1 + %8 +// = (U + N)*U + N +// Where U means uniform (e.g., a splat) and N means NonUniform (e.g., a make_range) +// The scalar offset we want is (%1*%arg1), while the variable offset should be (%2*%arg1 + %8) +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) -> tensor<128x16xf16, #blocked> { + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = arith.addi %3, %2 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %6 = tt.splat %arg1 : i32 -> tensor<128x1xi32, #blocked> + %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked> + %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %10 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> + %11 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> + %12 = arith.addi %10, %11 : tensor<128x16xi32, #blocked> + %13 = tt.splat %arg0 : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> + %14 = tt.addptr %13, %12 : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> + %15 = tt.load %14 : tensor<128x16x!tt.ptr, #blocked> + tt.return %15 : tensor<128x16xf16, #blocked> + } +} + +// CHECK: #[[$ATTR_1:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-LABEL: tt.func public @matmul_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, +// CHECK-SAME: %[[VAL_1:.*]]: i32 {tt.divisibility = 16 : i32}) -> tensor<128x16xf16, #[[$ATTR_1]]> { +// CHECK: %[[VAL_2:.*]] = arith.constant 128 : i32 +// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_5:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_1]]}>> +// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_1]]}>> +// CHECK: %[[VAL_7:.*]] = tt.expand_dims %[[VAL_5]] {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128x1xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_4]], %[[VAL_1]] : i32 +// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_1]] : i32 -> tensor<128x1xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_7]], %[[VAL_9]] : tensor<128x1xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_11:.*]] = tt.broadcast %[[VAL_10]] : tensor<128x1xi32, #[[$ATTR_1]]> -> tensor<128x16xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_12:.*]] = tt.expand_dims %[[VAL_6]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_1]]}>> -> tensor<1x16xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_13:.*]] = tt.broadcast %[[VAL_12]] : tensor<1x16xi32, #[[$ATTR_1]]> -> tensor<128x16xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_13]] : tensor<128x16xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_0]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_16:.*]] = tt.splat %[[VAL_15]] : !tt.ptr -> tensor<128x16x!tt.ptr, #[[$ATTR_1]]> +// CHECK: %[[VAL_17:.*]] = tt.addptr %[[VAL_16]], %[[VAL_14]] : tensor<128x16x!tt.ptr, #[[$ATTR_1]]>, tensor<128x16xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_18:.*]] = tt.load %[[VAL_17]] : tensor<128x16x!tt.ptr, #[[$ATTR_1]]> +// CHECK: tt.return %[[VAL_18]] : tensor<128x16xf16, #[[$ATTR_1]]> +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @select(%arg0: !tt.ptr, %arg1: i1) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %7 = arith.select %arg1, %5, %6 : tensor<1024x!tt.ptr> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @select( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i1) -> tensor<1024xf32> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_4:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_5:.*]] = arith.muli %[[VAL_4]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_0]], %[[VAL_5]] : !tt.ptr, i32 +// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_6]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_1]], %[[VAL_0]], %[[VAL_7]] : !tt.ptr +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_1]], %[[VAL_2]], %[[VAL_8]] : tensor<1024xi64> +// CHECK: %[[VAL_11:.*]] = arith.trunci %[[VAL_10]] : tensor<1024xi64> to tensor<1024xi32> +// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_9]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_13:.*]] = tt.addptr %[[VAL_12]], %[[VAL_11]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_14:.*]] = tt.load %[[VAL_13]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_14]] : tensor<1024xf32> +// CHECK: } + +// ----- + + +module attributes {"ttg.num-ctas" = 1 : i32} { + tt.func @where_kernel(%arg0: !tt.ptr, %arg1: !tt.ptr, %cst: i8) -> tensor<1024xi64> { + %c0_i8 = arith.constant 0 : i8 + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = arith.cmpi ne, %c0_i8, %cst : i8 + %6 = arith.select %5, %arg0, %arg1 : !tt.ptr + %7 = tt.splat %6 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8 : tensor<1024x!tt.ptr> + tt.return %9 : tensor<1024xi64> + } +} + +// I don't know why but FileCheck doesn't like check-same here and elsewhere where I've removed them... + +// CHECK: tt.func @where_kernel(%[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: !tt.ptr, %[[VAL_3:.*]]: i8) -> tensor<1024xi64> { +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i8 +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_5]] : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_9:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i8 +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_9]], %[[VAL_0]], %[[VAL_1]] : !tt.ptr +// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_10]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_11]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_13:.*]] = tt.addptr %[[VAL_12]], %[[VAL_8]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_14:.*]] = tt.load %[[VAL_13]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_14]] : tensor<1024xi64> +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forOpWithHints(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %0 = tt.get_program_id x : i32 + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %2 = tt.splat %0 : i32 -> tensor<1024xi32> + %3 = arith.addi %2, %1 : tensor<1024xi32> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6:2 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %5, %arg4 = %arg1) -> (tensor<1024x!tt.ptr>, tensor<1024xf32>) { + %9 = tt.load %arg3 : tensor<1024x!tt.ptr> + %10 = tt.addptr %arg3, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %11 = tt.addptr %10, %2 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %12 = arith.addf %9, %arg4 : tensor<1024xf32> + scf.yield %11, %12 : tensor<1024x!tt.ptr>, tensor<1024xf32> + } {tt.divisibility_arg1 = dense<16> : tensor<1xi32>} + %7 = tt.addptr %6#0, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @forOpWithHints( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_0]], %[[VAL_5]] : !tt.ptr, i32 +// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_6]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_9:.*]]:3 = scf.for %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_11:.*]] = %[[VAL_7]], %[[VAL_12:.*]] = %[[VAL_8]], %[[VAL_13:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_14:.*]] = arith.trunci %[[VAL_12]] : tensor<1024xi64> to tensor<1024xi32> +// CHECK: %[[VAL_15:.*]] = tt.splat %[[VAL_11]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_15]], %[[VAL_14]] : tensor<1024x!tt.ptr>, tensor<1024xi32> +// CHECK: %[[VAL_17:.*]] = tt.load %[[VAL_16]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_18:.*]] = tt.addptr %[[VAL_11]], %[[VAL_5]] : !tt.ptr, i32 +// CHECK: %[[VAL_19:.*]] = arith.extsi %[[VAL_6]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_12]] : tensor<1024xi64> +// CHECK: %[[VAL_21:.*]] = tt.addptr %[[VAL_18]], %[[VAL_5]] : !tt.ptr, i32 +// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_17]], %[[VAL_13]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_21]], %[[VAL_20]], %[[VAL_22]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } {tt.divisibility_arg1 = dense<16> : tensor<1xi32>, tt.divisibility_arg2 = dense<16> : tensor<1xi32>} +// CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_24:.*]]#0, %[[VAL_5]] : !tt.ptr, i32 +// CHECK: %[[VAL_25:.*]] = arith.extsi %[[VAL_6]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_24]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_27:.*]] = tt.splat %[[VAL_23]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_27]], %[[VAL_26]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_29:.*]] = tt.load %[[VAL_28]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_29]] : tensor<1024xf32> +// CHECK: } + +// ----- + + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func public @scalar_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c0_i64 = arith.constant 0 : i64 + %c100_i32 = arith.constant 100 : i32 + %1 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + %2 = scf.for %arg3 = %c1_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %1) -> (!tt.ptr) : i32 { + tt.store %arg4, %c0_i64 : !tt.ptr + %3 = tt.addptr %arg4, %c1_i32 : !tt.ptr, i32 + scf.yield %3 : !tt.ptr + } + tt.return + } +} + +// CHECK: tt.func public @scalar_pointers(%[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 100 : i32 +// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_0]], %[[VAL_4]] : !tt.ptr, i32 +// CHECK: %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_9:.*]] = %[[VAL_6]]) -> (!tt.ptr) : i32 { +// CHECK: tt.store %[[VAL_9]], %[[VAL_3]] : !tt.ptr +// CHECK: %[[VAL_10:.*]] = tt.addptr %[[VAL_9]], %[[VAL_4]] : !tt.ptr, i32 +// CHECK: scf.yield %[[VAL_10]] : !tt.ptr +// CHECK: } +// CHECK: tt.return +// CHECK: } + +// ----- + + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @scalar_if(%arg0: !tt.ptr, %arg1: tensor<1024xf32>, %arg2: i1) -> f32 { + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %1 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + %2 = scf.if %arg2 -> (!tt.ptr) { + %4 = tt.addptr %1, %c1_i32 : !tt.ptr, i32 + scf.yield %4 : !tt.ptr + } else { + %4 = tt.addptr %1, %c100_i32 : !tt.ptr, i32 + scf.yield %4 : !tt.ptr + } + %3 = tt.load %2 : !tt.ptr + tt.return %3 : f32 + } +} + +// CHECK-LABEL: tt.func @scalar_if( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: i1) -> f32 { +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 100 : i32 +// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_2]] -> (!tt.ptr) { +// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_5]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: scf.yield %[[VAL_7]] : !tt.ptr +// CHECK: } else { +// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_5]], %[[VAL_4]] : !tt.ptr, i32 +// CHECK: scf.yield %[[VAL_8]] : !tt.ptr +// CHECK: } +// CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_6]] : !tt.ptr +// CHECK: tt.return %[[VAL_9]] : f32 +// CHECK: } + +// ----- + + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @scalar_while(%arg0: !tt.ptr, %arg1: f32) -> f32 { + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.addptr %arg0, %0 : !tt.ptr, i32 + %2 = scf.while (%arg2 = %1) : (!tt.ptr) -> !tt.ptr { + %4 = "dummy.evaluate_condition"() : () -> i1 + scf.condition(%4) %arg2 : !tt.ptr + } do { + ^bb0(%arg2: !tt.ptr): + %4 = tt.addptr %arg2, %c128_i32 : !tt.ptr, i32 + scf.yield %4 : !tt.ptr + } + %3 = tt.load %2 : !tt.ptr + tt.return %3 : f32 + } +} + +// CHECK-LABEL: tt.func @scalar_while( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: f32) -> f32 { +// CHECK: %[[VAL_2:.*]] = arith.constant 128 : i32 +// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_4]]) : (!tt.ptr) -> !tt.ptr { +// CHECK: %[[VAL_7:.*]] = "dummy.evaluate_condition"() : () -> i1 +// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : !tt.ptr +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_8:.*]]: !tt.ptr): +// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_8]], %[[VAL_2]] : !tt.ptr, i32 +// CHECK: scf.yield %[[VAL_9]] : !tt.ptr +// CHECK: } +// CHECK: %[[VAL_10:.*]] = tt.load %[[VAL_5]] : !tt.ptr +// CHECK: tt.return %[[VAL_10]] : f32 +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @scalar_cond_branch(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i1) -> f32 { + cf.cond_br %arg2, ^bb1(%arg0 : !tt.ptr), ^bb2(%arg1 : !tt.ptr) + ^bb1(%0: !tt.ptr): // pred: ^bb0 + %1 = tt.load %0 : !tt.ptr + tt.return %1 : f32 + ^bb2(%2: !tt.ptr): // pred: ^bb0 + %3 = tt.load %2 : !tt.ptr + tt.return %3 : f32 + } +} + +// CHECK-LABEL: tt.func @scalar_cond_branch( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: !tt.ptr, %[[VAL_2:.*]]: i1) -> f32 { +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i64 +// CHECK: cf.cond_br %[[VAL_2]], ^bb1(%[[VAL_0]], %[[VAL_3]] : !tt.ptr, i64), ^bb2(%[[VAL_1]], %[[VAL_3]] : !tt.ptr, i64) +// CHECK: ^bb1(%[[VAL_4:.*]]: !tt.ptr, %[[VAL_5:.*]]: i64): +// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_4]], %[[VAL_5]] : !tt.ptr, i64 +// CHECK: %[[VAL_7:.*]] = tt.load %[[VAL_6]] : !tt.ptr +// CHECK: tt.return %[[VAL_7]] : f32 +// CHECK: ^bb2(%[[VAL_8:.*]]: !tt.ptr, %[[VAL_9:.*]]: i64): +// CHECK: %[[VAL_10:.*]] = tt.addptr %[[VAL_8]], %[[VAL_9]] : !tt.ptr, i64 +// CHECK: %[[VAL_11:.*]] = tt.load %[[VAL_10]] : !tt.ptr +// CHECK: tt.return %[[VAL_11]] : f32 +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @flipFlopForOpSimple(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %60 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %7:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg30 = %60, %arg3 = %6, %arg4 = %arg1) -> (tensor<1024x!tt.ptr>, tensor<1024x!tt.ptr>, tensor<1024xf32>) { + %10 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %11 = tt.load %10 : tensor<1024x!tt.ptr> + %12 = arith.addf %11, %arg4 : tensor<1024xf32> + %100 = tt.addptr %arg30, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + scf.yield %10, %arg30, %12 : tensor<1024x!tt.ptr>, tensor<1024x!tt.ptr>, tensor<1024xf32> + } + %8 = tt.addptr %7#0, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8 : tensor<1024x!tt.ptr> + tt.return %9 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @flipFlopForOpSimple( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1024 : i32 +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_10:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_12:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_13:.*]]:5 = scf.for %[[VAL_14:.*]] = %[[VAL_3]] to %[[VAL_4]] step %[[VAL_5]] iter_args(%[[VAL_15:.*]] = %[[VAL_11]], %[[VAL_16:.*]] = %[[VAL_12]], %[[VAL_17:.*]] = %[[VAL_9]], %[[VAL_18:.*]] = %[[VAL_10]], %[[VAL_19:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_17]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_18]] : tensor<1024xi64> +// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_20]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_22]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_25:.*]] = tt.load %[[VAL_24]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_19]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_20]], %[[VAL_22]], %[[VAL_15]], %[[VAL_16]], %[[VAL_26]] : !tt.ptr, tensor<1024xi64>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_28:.*]]#0, %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_29:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_28]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_31:.*]] = tt.splat %[[VAL_27]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_32:.*]] = tt.addptr %[[VAL_31]], %[[VAL_30]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_33:.*]] = tt.load %[[VAL_32]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_33]] : tensor<1024xf32> +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @flipFlopForOpComplex(%arg0: !tt.ptr, %arg00: !tt.ptr, %arg1: tensor<1024xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) { + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %40 = arith.addi %3, %2 : tensor<1024xi32> + %50 = tt.splat %arg00 : !tt.ptr -> tensor<1024x!tt.ptr> + %60 = tt.addptr %50, %40 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %7:4 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %6, %arg4 = %arg1, %arg30 = %60, %arg40 = %arg1) -> (tensor<1024x!tt.ptr>, tensor<1024xf32>, tensor<1024x!tt.ptr>, tensor<1024xf32>) { + %10 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %11 = tt.load %10 : tensor<1024x!tt.ptr> + %12 = arith.addf %11, %arg4 : tensor<1024xf32> + %100 = tt.addptr %arg30, %40 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %110 = tt.load %100 : tensor<1024x!tt.ptr> + %120 = arith.addf %110, %arg40 : tensor<1024xf32> + scf.yield %100, %120, %10, %12 : tensor<1024x!tt.ptr>, tensor<1024xf32>, tensor<1024x!tt.ptr>, tensor<1024xf32> + } + %8 = tt.addptr %7#0, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8 : tensor<1024x!tt.ptr> + %80 = tt.addptr %7#2, %40 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %90 = tt.load %80 : tensor<1024x!tt.ptr> + tt.return %9, %90 : tensor<1024xf32>, tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @flipFlopForOpComplex( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: !tt.ptr, %[[VAL_2:.*]]: tensor<1024xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) { +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1024 : i32 +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_9:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_10:.*]] = tt.addptr %[[VAL_0]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_11:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_12:.*]] = tt.addptr %[[VAL_1]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_13:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_14:.*]]:6 = scf.for %[[VAL_15:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_16:.*]] = %[[VAL_10]], %[[VAL_17:.*]] = %[[VAL_11]], %[[VAL_18:.*]] = %[[VAL_2]], %[[VAL_19:.*]] = %[[VAL_12]], %[[VAL_20:.*]] = %[[VAL_13]], %[[VAL_21:.*]] = %[[VAL_2]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_22:.*]] = tt.addptr %[[VAL_16]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_23:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_17]] : tensor<1024xi64> +// CHECK: %[[VAL_25:.*]] = tt.splat %[[VAL_22]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_26:.*]] = tt.addptr %[[VAL_25]], %[[VAL_24]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_27:.*]] = tt.load %[[VAL_26]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_18]] : tensor<1024xf32> +// CHECK: %[[VAL_29:.*]] = tt.addptr %[[VAL_19]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_30:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_20]] : tensor<1024xi64> +// CHECK: %[[VAL_32:.*]] = tt.splat %[[VAL_29]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_33:.*]] = tt.addptr %[[VAL_32]], %[[VAL_31]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_34:.*]] = tt.load %[[VAL_33]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_35:.*]] = arith.addf %[[VAL_34]], %[[VAL_21]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_29]], %[[VAL_31]], %[[VAL_35]], %[[VAL_22]], %[[VAL_24]], %[[VAL_28]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_36:.*]] = tt.addptr %[[VAL_37:.*]]#0, %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_38:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_38]], %[[VAL_37]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_40:.*]] = tt.splat %[[VAL_36]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_41:.*]] = tt.addptr %[[VAL_40]], %[[VAL_39]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_42:.*]] = tt.load %[[VAL_41]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_43:.*]] = tt.addptr %[[VAL_37]]#3, %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_44:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_44]], %[[VAL_37]]#4 : tensor<1024xi64> +// CHECK: %[[VAL_46:.*]] = tt.splat %[[VAL_43]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_47:.*]] = tt.addptr %[[VAL_46]], %[[VAL_45]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_48:.*]] = tt.load %[[VAL_47]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_42]], %[[VAL_48]] : tensor<1024xf32>, tensor<1024xf32> +// CHECK: } + +// ----- + +// test_functional_regressions.test_inductor_cummax_bool +// tt.bitcast immediately materializes the fat pointer, ending the analysis +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @test_inductor_cummax_bool(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<64xi8, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %3 = tt.bitcast %2 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> -> tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %4 = tt.load %3 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %5 = arith.cmpi ne, %4, %cst : tensor<64xi8, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %6 = arith.extsi %0 : tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> to tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %7:2 = "tt.scan"(%5, %6) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg3: i1, %arg4: i64, %arg5: i1, %arg6: i64): + %14 = arith.cmpi ugt, %arg3, %arg5 : i1 + %15 = arith.cmpi eq, %arg3, %arg5 : i1 + %16 = arith.cmpi sgt, %arg4, %arg6 : i64 + %17 = arith.andi %15, %16 : i1 + %18 = arith.ori %14, %17 : i1 + %19 = arith.select %18, %arg3, %arg5 : i1 + %20 = arith.select %18, %arg4, %arg6 : i64 + tt.scan.return %19, %20 : i1, i64 + }) : (tensor<64xi1, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> (tensor<64xi1, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) + %8 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %9 = tt.addptr %8, %0 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %10 = tt.bitcast %9 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> -> tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %11 = arith.extui %7#0 : tensor<64xi1, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> to tensor<64xi8, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + tt.store %10, %11 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %13 = tt.addptr %12, %0 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + tt.store %13, %7#1 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + tt.return + } +} + +// CHECK-LABEL: tt.func public @test_inductor_cummax_bool( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : tensor<64xi8, #[[$ATTR_0]]> +// CHECK: %[[VAL_4:.*]] = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<64x!tt.ptr, #[[$ATTR_0]]> +// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_5]], %[[VAL_4]] : tensor<64x!tt.ptr, #[[$ATTR_0]]>, tensor<64xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_7:.*]] = tt.bitcast %[[VAL_6]] : tensor<64x!tt.ptr, #[[$ATTR_0]]> -> tensor<64x!tt.ptr, #[[$ATTR_0]]> +// CHECK: %[[VAL_8:.*]] = tt.load %[[VAL_7]] : tensor<64x!tt.ptr, #[[$ATTR_0]]> +// CHECK: %[[VAL_9:.*]] = arith.cmpi ne, %[[VAL_8]], %[[VAL_3]] : tensor<64xi8, #[[$ATTR_0]]> +// CHECK: %[[VAL_10:.*]] = arith.extsi %[[VAL_4]] : tensor<64xi32, #[[$ATTR_0]]> to tensor<64xi64, #[[$ATTR_0]]> +// CHECK: %[[VAL_11:.*]]:2 = "tt.scan"(%[[VAL_9]], %[[VAL_10]]) <{axis = 0 : i32, reverse = false}> ({ +// CHECK: ^bb0(%[[VAL_12:.*]]: i1, %[[VAL_13:.*]]: i64, %[[VAL_14:.*]]: i1, %[[VAL_15:.*]]: i64): +// CHECK: %[[VAL_16:.*]] = arith.cmpi ugt, %[[VAL_12]], %[[VAL_14]] : i1 +// CHECK: %[[VAL_17:.*]] = arith.cmpi eq, %[[VAL_12]], %[[VAL_14]] : i1 +// CHECK: %[[VAL_18:.*]] = arith.cmpi sgt, %[[VAL_13]], %[[VAL_15]] : i64 +// CHECK: %[[VAL_19:.*]] = arith.andi %[[VAL_17]], %[[VAL_18]] : i1 +// CHECK: %[[VAL_20:.*]] = arith.ori %[[VAL_16]], %[[VAL_19]] : i1 +// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_20]], %[[VAL_12]], %[[VAL_14]] : i1 +// CHECK: %[[VAL_22:.*]] = arith.select %[[VAL_20]], %[[VAL_13]], %[[VAL_15]] : i64 +// CHECK: tt.scan.return %[[VAL_21]], %[[VAL_22]] : i1, i64 +// CHECK: }) : (tensor<64xi1, #[[$ATTR_0]]>, tensor<64xi64, #[[$ATTR_0]]>) -> (tensor<64xi1, #[[$ATTR_0]]>, tensor<64xi64, #[[$ATTR_0]]>) +// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<64x!tt.ptr, #[[$ATTR_0]]> +// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_4]] : tensor<64x!tt.ptr, #[[$ATTR_0]]>, tensor<64xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_25:.*]] = tt.bitcast %[[VAL_24]] : tensor<64x!tt.ptr, #[[$ATTR_0]]> -> tensor<64x!tt.ptr, #[[$ATTR_0]]> +// CHECK: %[[VAL_26:.*]] = arith.extui %[[VAL_27:.*]]#0 : tensor<64xi1, #[[$ATTR_0]]> to tensor<64xi8, #[[$ATTR_0]]> +// CHECK: tt.store %[[VAL_25]], %[[VAL_26]] : tensor<64x!tt.ptr, #[[$ATTR_0]]> +// CHECK: %[[VAL_28:.*]] = tt.splat %[[VAL_2]] : !tt.ptr -> tensor<64x!tt.ptr, #[[$ATTR_0]]> +// CHECK: %[[VAL_29:.*]] = tt.addptr %[[VAL_28]], %[[VAL_4]] : tensor<64x!tt.ptr, #[[$ATTR_0]]>, tensor<64xi32, #[[$ATTR_0]]> +// CHECK: tt.store %[[VAL_29]], %[[VAL_27]]#1 : tensor<64x!tt.ptr, #[[$ATTR_0]]> +// CHECK: tt.return +// CHECK: } + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @test_atomic_rmw(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} { + %true = arith.constant true + %0 = tt.get_program_id x : i32 + %1 = tt.addptr %arg0, %0 : !tt.ptr, i32 + %2 = tt.load %1 : !tt.ptr + %3 = tt.atomic_rmw fadd, acq_rel, gpu, %arg1, %2, %true : (!tt.ptr, f16, i1) -> f16 + tt.return + } +} + +// CHECK-LABEL: tt.func public @test_atomic_rmw( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_2:.*]] = arith.constant true +// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr +// CHECK: %[[VAL_6:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, %[[VAL_1]], %[[VAL_5]], %[[VAL_2]] : (!tt.ptr, f16, i1) -> f16 +// CHECK: tt.return +// CHECK: } + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + // expected-remark@+1 {{expected at least 1 use of unrealized_cast}} + tt.func public @empty_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} { + tt.return + } +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @test_reduce(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<16> : tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %cst_0 = arith.constant dense<16> : tensor<1x2x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %cst_1 = arith.constant dense<16> : tensor<32x1x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %cst_2 = arith.constant dense<2> : tensor<32x1x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>}>> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>}>> -> tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %3 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>}>> -> tensor<32x1xi32, #ttg.slice<{dim = 2, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %4 = tt.expand_dims %3 {axis = 2 : i32} : tensor<32x1xi32, #ttg.slice<{dim = 2, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> -> tensor<32x1x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %5 = arith.muli %4, %cst_2 : tensor<32x1x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %6 = arith.muli %5, %cst_1 : tensor<32x1x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x1x1x!tt.ptr, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %8 = tt.addptr %7, %6 : tensor<32x1x1x!tt.ptr, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>>, tensor<32x1x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %9 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %11 = tt.expand_dims %10 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> -> tensor<1x2x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %12 = arith.muli %11, %cst_0 : tensor<1x2x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %13 = tt.broadcast %8 : tensor<32x1x1x!tt.ptr, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> -> tensor<32x2x1x!tt.ptr, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %14 = tt.broadcast %12 : tensor<1x2x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> -> tensor<32x2x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %15 = tt.addptr %13, %14 : tensor<32x2x1x!tt.ptr, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>>, tensor<32x2x1xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %16 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>}>> + %17 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>}>> + %18 = tt.expand_dims %16 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>}>> -> tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %19 = tt.expand_dims %17 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>}>> -> tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %20 = tt.expand_dims %19 {axis = 1 : i32} : tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> -> tensor<1x1x16xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %21 = tt.broadcast %15 : tensor<32x2x1x!tt.ptr, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> -> tensor<32x2x16x!tt.ptr, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %22 = tt.broadcast %20 : tensor<1x1x16xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> -> tensor<32x2x16xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %23 = tt.addptr %21, %22 : tensor<32x2x16x!tt.ptr, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>>, tensor<32x2x16xi32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %24 = tt.load %23 : tensor<32x2x16x!tt.ptr, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %25 = "tt.reduce"(%24) <{axis = 1 : i32}> ({ + ^bb0(%arg2: f32, %arg3: f32): + %34 = arith.maxnumf %arg2, %arg3 : f32 + tt.reduce.return %34 : f32 + }) : (tensor<32x2x16xf32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>>) -> tensor<32x16xf32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %26 = tt.expand_dims %25 {axis = 1 : i32} : tensor<32x16xf32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> -> tensor<32x1x16xf32, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + %27 = arith.muli %2, %cst : tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %28 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %29 = tt.addptr %28, %27 : tensor<32x1x!tt.ptr, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>>, tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %30 = tt.broadcast %29 : tensor<32x1x!tt.ptr, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> -> tensor<32x16x!tt.ptr, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %31 = tt.broadcast %18 : tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> -> tensor<32x16xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %32 = tt.addptr %30, %31 : tensor<32x16x!tt.ptr, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>>, tensor<32x16xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> + %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<32x16x!tt.ptr, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>}>> -> tensor<32x1x16x!tt.ptr, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + tt.store %33, %26 : tensor<32x1x16x!tt.ptr, #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>> + tt.return + } +} + +// CHECK: #[[$ATTR_3:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> +// CHECK-LABEL: tt.func public @test_reduce( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<16> : tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<16> : tensor<1x2x1xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_4:.*]] = arith.constant dense<16> : tensor<32x1x1xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_5:.*]] = arith.constant dense<2> : tensor<32x1x1xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>> +// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>}>> +// CHECK: %[[VAL_8:.*]] = tt.expand_dims %[[VAL_7]] {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>}>> -> tensor<32x1xi32, #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_9:.*]] = tt.expand_dims %[[VAL_8]] {axis = 2 : i32} : tensor<32x1xi32, #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>> -> tensor<32x1x1xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_5]] : tensor<32x1x1xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_4]] : tensor<32x1x1xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_12:.*]] = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>}>> +// CHECK: %[[VAL_13:.*]] = tt.broadcast %[[VAL_11]] : tensor<32x1x1xi32, #[[$ATTR_3]]> -> tensor<32x2x1xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_14:.*]] = tt.expand_dims %[[VAL_12]] {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_15:.*]] = tt.expand_dims %[[VAL_14]] {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>> -> tensor<1x2x1xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_15]], %[[VAL_3]] : tensor<1x2x1xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_17:.*]] = tt.broadcast %[[VAL_16]] : tensor<1x2x1xi32, #[[$ATTR_3]]> -> tensor<32x2x1xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_13]] : tensor<32x2x1xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_19:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>> +// CHECK: %[[VAL_20:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>> +// CHECK: %[[VAL_21:.*]] = tt.broadcast %[[VAL_18]] : tensor<32x2x1xi32, #[[$ATTR_3]]> -> tensor<32x2x16xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_22:.*]] = tt.expand_dims %[[VAL_20]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>> -> tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_23:.*]] = tt.expand_dims %[[VAL_22]] {axis = 1 : i32} : tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> -> tensor<1x1x16xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_24:.*]] = tt.broadcast %[[VAL_23]] : tensor<1x1x16xi32, #[[$ATTR_3]]> -> tensor<32x2x16xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_21]] : tensor<32x2x16xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<32x2x16x!tt.ptr, #[[$ATTR_3]]> +// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<32x2x16x!tt.ptr, #[[$ATTR_3]]>, tensor<32x2x16xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<32x2x16x!tt.ptr, #[[$ATTR_3]]> +// CHECK: %[[VAL_29:.*]] = "tt.reduce"(%[[VAL_28]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_30:.*]]: f32, %[[VAL_31:.*]]: f32): +// CHECK: %[[VAL_32:.*]] = arith.maxnumf %[[VAL_30]], %[[VAL_31]] : f32 +// CHECK: tt.reduce.return %[[VAL_32]] : f32 +// CHECK: }) : (tensor<32x2x16xf32, #[[$ATTR_3]]>) -> tensor<32x16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_33:.*]] = tt.expand_dims %[[VAL_29]] {axis = 1 : i32} : tensor<32x16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> -> tensor<32x1x16xf32, #[[$ATTR_3]]> +// CHECK: %[[VAL_34:.*]] = tt.expand_dims %[[VAL_6]] {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>> -> tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_35:.*]] = arith.muli %[[VAL_34]], %[[VAL_2]] : tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_36:.*]] = tt.broadcast %[[VAL_35]] : tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> -> tensor<32x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_37:.*]] = tt.expand_dims %[[VAL_19]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>> -> tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_38:.*]] = tt.broadcast %[[VAL_37]] : tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> -> tensor<32x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_38]], %[[VAL_36]] : tensor<32x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_40:.*]] = tt.expand_dims %[[VAL_39]] {axis = 1 : i32} : tensor<32x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> -> tensor<32x1x16xi32, #[[$ATTR_3]]> +// CHECK: %[[VAL_41:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<32x1x16x!tt.ptr, #[[$ATTR_3]]> +// CHECK: %[[VAL_42:.*]] = tt.addptr %[[VAL_41]], %[[VAL_40]] : tensor<32x1x16x!tt.ptr, #[[$ATTR_3]]>, tensor<32x1x16xi32, #[[$ATTR_3]]> +// CHECK: tt.store %[[VAL_42]], %[[VAL_33]] : tensor<32x1x16x!tt.ptr, #[[$ATTR_3]]> +// CHECK: tt.return +// CHECK: } + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @block_copy_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %c2_i32 = arith.constant 2 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = arith.divsi %arg2, %c2_i32 : i32 + %3 = arith.extsi %2 : i32 to i64 + %4 = tt.bitcast %arg0 : !tt.ptr -> !tt.ptr + %5 = arith.extsi %1 : i32 to i64 + %6 = arith.extsi %arg2 : i32 to i64 + %7 = tt.bitcast %arg1 : !tt.ptr -> !tt.ptr + %8 = tt.splat %4 : !tt.ptr -> tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %9 = tt.splat %5 : i64 -> tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %11 = arith.extsi %10 : tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> to tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %12 = arith.addi %9, %11 : tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %13 = tt.addptr %8, %12 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %14 = arith.cmpi sge, %12, %cst : tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %15 = tt.splat %3 : i64 -> tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %16 = arith.cmpi slt, %12, %15 : tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %17 = arith.andi %14, %16 : tensor<64xi1, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %18 = tt.load %13, %17 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %19 = tt.splat %7 : !tt.ptr -> tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %20 = tt.addptr %19, %12 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %21 = tt.splat %6 : i64 -> tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %22 = arith.cmpi slt, %12, %21 : tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %23 = arith.andi %14, %22 : tensor<64xi1, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + tt.store %20, %18, %23 : tensor<64x!tt.ptr, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + tt.return + } +} + +// CHECK: #[[$ATTR_4:.+]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK-LABEL: tt.func public @block_copy_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_2:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : tensor<64xi64, #[[$ATTR_4]]> +// CHECK: %[[VAL_4:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_5:.*]] = arith.constant 64 : i32 +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_5]] : i32 +// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_2]], %[[VAL_4]] : i32 +// CHECK: %[[VAL_9:.*]] = arith.extsi %[[VAL_8]] : i32 to i64 +// CHECK: %[[VAL_10:.*]] = tt.bitcast %[[VAL_0]] : !tt.ptr -> !tt.ptr +// CHECK: %[[VAL_11:.*]] = arith.extsi %[[VAL_7]] : i32 to i64 +// CHECK: %[[VAL_12:.*]] = arith.extsi %[[VAL_2]] : i32 to i64 +// CHECK: %[[VAL_13:.*]] = tt.bitcast %[[VAL_1]] : !tt.ptr -> !tt.ptr +// CHECK: %[[VAL_14:.*]] = tt.splat %[[VAL_10]] : !tt.ptr -> tensor<64x!tt.ptr, #[[$ATTR_4]]> +// CHECK: %[[VAL_15:.*]] = tt.splat %[[VAL_11]] : i64 -> tensor<64xi64, #[[$ATTR_4]]> +// CHECK: %[[VAL_16:.*]] = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #[[$ATTR_4]]> +// CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_16]] : tensor<64xi32, #[[$ATTR_4]]> to tensor<64xi64, #[[$ATTR_4]]> +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_17]] : tensor<64xi64, #[[$ATTR_4]]> +// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_14]], %[[VAL_18]] : tensor<64x!tt.ptr, #[[$ATTR_4]]>, tensor<64xi64, #[[$ATTR_4]]> +// CHECK: %[[VAL_20:.*]] = arith.cmpi sge, %[[VAL_18]], %[[VAL_3]] : tensor<64xi64, #[[$ATTR_4]]> +// CHECK: %[[VAL_21:.*]] = tt.splat %[[VAL_9]] : i64 -> tensor<64xi64, #[[$ATTR_4]]> +// CHECK: %[[VAL_22:.*]] = arith.cmpi slt, %[[VAL_18]], %[[VAL_21]] : tensor<64xi64, #[[$ATTR_4]]> +// CHECK: %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : tensor<64xi1, #[[$ATTR_4]]> +// CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_19]], %[[VAL_23]] : tensor<64x!tt.ptr, #[[$ATTR_4]]> +// CHECK: %[[VAL_25:.*]] = tt.splat %[[VAL_13]] : !tt.ptr -> tensor<64x!tt.ptr, #[[$ATTR_4]]> +// CHECK: %[[VAL_26:.*]] = tt.addptr %[[VAL_25]], %[[VAL_18]] : tensor<64x!tt.ptr, #[[$ATTR_4]]>, tensor<64xi64, #[[$ATTR_4]]> +// CHECK: %[[VAL_27:.*]] = tt.splat %[[VAL_12]] : i64 -> tensor<64xi64, #[[$ATTR_4]]> +// CHECK: %[[VAL_28:.*]] = arith.cmpi slt, %[[VAL_18]], %[[VAL_27]] : tensor<64xi64, #[[$ATTR_4]]> +// CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_20]], %[[VAL_28]] : tensor<64xi1, #[[$ATTR_4]]> +// CHECK: tt.store %[[VAL_26]], %[[VAL_24]], %[[VAL_29]] : tensor<64x!tt.ptr, #[[$ATTR_4]]> +// CHECK: tt.return +// CHECK: } diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-conditional-barrier.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-conditional-barrier.mlir new file mode 100644 index 000000000..ac232354a --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-conditional-barrier.mlir @@ -0,0 +1,33 @@ +// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s + +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @conditional_barrier() { + // CHECK-LABEL: llvm.func @conditional_barrier + + // CHECK: %[[CMP0:.+]] = llvm.icmp "ne" %3, %1 : i32 + // CHECK: %[[CMP1:.+]] = llvm.icmp "eq" %3, %1 : i32 + // CHECK: llvm.cond_br %[[CMP0]], ^bb1, ^bb2 + // CHECK: ^bb1: + // CHECK: rocdl.s.barrier + // CHECK: llvm.br ^bb2 + // CHECK: ^bb2: + // CHECK: llvm.add + // CHECK: llvm.cond_br %[[CMP1]], ^bb3, ^bb4 + // CHECK: ^bb3: + // CHECK: rocdl.s.barrier + // CHECK: llvm.br ^bb4 + // CHECK: ^bb4: + // CHECK: llvm.return + + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = rocdl.workitem.id.x : i32 + %1 = arith.divsi %0, %c256_i32 : i32 + %2 = arith.cmpi ne, %1, %c0_i32 : i32 + %3 = arith.cmpi eq, %1, %c0_i32 : i32 + amdgpu.cond_barrier %2 + %4 = arith.addi %0, %c256_i32 : i32 + amdgpu.cond_barrier %3 + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-convert-buffer-ops-range-analysis.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-convert-buffer-ops-range-analysis.mlir new file mode 100644 index 000000000..f0f5a2f8e --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-convert-buffer-ops-range-analysis.mlir @@ -0,0 +1,1051 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --tritonamdgpu-convert-buffer-ops='arch-generation-name=gfx942' | FileCheck %s + +// CHECK-LABEL: tt.func @conversion1( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<1024xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32 +// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_4]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_6]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion1(%arg0: !tt.ptr) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %3 = tt.splat %2 : !tt.ptr -> tensor<1024x!tt.ptr> + %4 = tt.load %3 : tensor<1024x!tt.ptr> + tt.return %4 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @conversion2( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<1024xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_5:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_0]], %[[VAL_4]] : !tt.ptr, i32 +// CHECK: %[[VAL_7:.*]] = amdgpu.buffer_load %[[VAL_6]]{{\[}}%[[VAL_5]]] : tensor<1024xf32> +// CHECK: tt.return %[[VAL_7]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion2(%arg0: !tt.ptr) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = tt.splat %3 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %2 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.load %5 : tensor<1024x!tt.ptr> + tt.return %6 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @conversion3( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<1024xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32 +// CHECK: %[[VAL_4:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_6:.*]] = arith.extsi %[[VAL_4]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_5]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_4]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_6]] : tensor<1024xi64> +// CHECK: %[[VAL_10:.*]] = tt.splat %[[VAL_7]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_10]], %[[VAL_9]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_12:.*]] = tt.load %[[VAL_11]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_12]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion3(%arg0: !tt.ptr) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5 = tt.addptr %3, %1 : !tt.ptr, i32 + %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %7 = arith.addi %6, %4 : tensor<1024xi64> + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %10 = tt.load %9 : tensor<1024x!tt.ptr> + tt.return %10 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @conversion4( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.pointer_range = 32 : i32}) -> tensor<1024xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_5:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_0]], %[[VAL_4]] : !tt.ptr, i32 +// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_6]], %[[VAL_4]] : !tt.ptr, i32 +// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : tensor<1024xi32> +// CHECK: %[[VAL_9:.*]] = amdgpu.buffer_load %[[VAL_7]]{{\[}}%[[VAL_8]]] : tensor<1024xf32> +// CHECK: tt.return %[[VAL_9]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion4(%arg0: !tt.ptr {tt.pointer_range = 32 : i32}) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = tt.addptr %3, %1 : !tt.ptr, i32 + %5 = arith.addi %2, %2 : tensor<1024xi32> + %6 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %5 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forOp( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_10:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_11:.*]]:3 = scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_4]] step %[[VAL_5]] iter_args(%[[VAL_13:.*]] = %[[VAL_9]], %[[VAL_14:.*]] = %[[VAL_10]], %[[VAL_15:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_13]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_14]] : tensor<1024xi64> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_16]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_19]], %[[VAL_18]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_21:.*]] = tt.load %[[VAL_20]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_21]], %[[VAL_15]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_16]], %[[VAL_18]], %[[VAL_22]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_24:.*]]#0, %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_25:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_24]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_27:.*]] = tt.splat %[[VAL_23]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_27]], %[[VAL_26]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_29:.*]] = tt.load %[[VAL_28]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_29]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forOp(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %3, %arg4 = %4, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %12 = tt.addptr %arg3, %1 : !tt.ptr, i32 + %13 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %14 = arith.addi %13, %arg4 : tensor<1024xi64> + %15 = tt.splat %12 : !tt.ptr -> tensor<1024x!tt.ptr> + %16 = tt.addptr %15, %14 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %17 = tt.load %16 : tensor<1024x!tt.ptr> + %18 = arith.addf %17, %arg5 : tensor<1024xf32> + scf.yield %12, %14, %18 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %6 = tt.addptr %5#0, %1 : !tt.ptr, i32 + %7 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %8 = arith.addi %7, %5#1 : tensor<1024xi64> + %9 = tt.splat %6 : !tt.ptr -> tensor<1024x!tt.ptr> + %10 = tt.addptr %9, %8 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %11 = tt.load %10 : tensor<1024x!tt.ptr> + tt.return %11 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forOp2( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK: %[[VAL_3:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_9:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_10:.*]]:3 = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_2]], %[[VAL_14:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_12]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_16:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_13]] : tensor<1024xi64> +// CHECK: %[[VAL_18:.*]] = tt.splat %[[VAL_15]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_18]], %[[VAL_17]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_20:.*]] = tt.load %[[VAL_19]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_21:.*]] = arith.addf %[[VAL_20]], %[[VAL_14]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_15]], %[[VAL_17]], %[[VAL_21]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_22:.*]] = tt.addptr %[[VAL_23:.*]]#0, %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_24:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_23]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_22]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_28]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forOp2(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %cst = arith.constant dense<0> : tensor<1024xi64> + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %10 = tt.addptr %arg3, %1 : !tt.ptr, i32 + %11 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %12 = arith.addi %11, %arg4 : tensor<1024xi64> + %13 = tt.splat %10 : !tt.ptr -> tensor<1024x!tt.ptr> + %14 = tt.addptr %13, %12 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %15 = tt.load %14 : tensor<1024x!tt.ptr> + %16 = arith.addf %15, %arg5 : tensor<1024xf32> + scf.yield %10, %12, %16 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %4 = tt.addptr %3#0, %1 : !tt.ptr, i32 + %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %6 = arith.addi %5, %3#1 : tensor<1024xi64> + %7 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %6 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %9 = tt.load %8 : tensor<1024x!tt.ptr> + tt.return %9 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forNested( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK: %[[VAL_3:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 16 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_9:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_10:.*]]:3 = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_2]], %[[VAL_14:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_15:.*]]:3 = scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_14]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_17]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_18]] : tensor<1024xi64> +// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_20]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_22]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_25:.*]] = tt.load %[[VAL_24]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_19]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_20]], %[[VAL_22]], %[[VAL_26]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_27:.*]]#0, %[[VAL_27]]#1, %[[VAL_27]]#2 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_29:.*]]#0, %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_30:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_32:.*]] = tt.splat %[[VAL_28]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_33:.*]] = tt.addptr %[[VAL_32]], %[[VAL_31]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_34:.*]] = tt.load %[[VAL_33]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_34]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forNested(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %cst = arith.constant dense<0> : tensor<1024xi64> + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3:3 = scf.for %arg2 = %c0 to %c16 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %10:3 = scf.for %arg6 = %c0 to %c16 step %c1 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %11 = tt.addptr %arg7, %1 : !tt.ptr, i32 + %12 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %13 = arith.addi %12, %arg8 : tensor<1024xi64> + %14 = tt.splat %11 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %16 = tt.load %15 : tensor<1024x!tt.ptr> + %17 = arith.addf %16, %arg9 : tensor<1024xf32> + scf.yield %11, %13, %17 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + scf.yield %10#0, %10#1, %10#2 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %4 = tt.addptr %3#0, %1 : !tt.ptr, i32 + %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %6 = arith.addi %5, %3#1 : tensor<1024xi64> + %7 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %6 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %9 = tt.load %8 : tensor<1024x!tt.ptr> + tt.return %9 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forNestedOverMaxTripCount( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK: %[[VAL_3:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_9:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_10:.*]]:3 = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_2]], %[[VAL_14:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_15:.*]]:3 = scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_14]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_17]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_18]] : tensor<1024xi64> +// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_20]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_22]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_25:.*]] = tt.load %[[VAL_24]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_19]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_20]], %[[VAL_22]], %[[VAL_26]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_27:.*]]#0, %[[VAL_27]]#1, %[[VAL_27]]#2 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_29:.*]]#0, %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_30:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_32:.*]] = tt.splat %[[VAL_28]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_33:.*]] = tt.addptr %[[VAL_32]], %[[VAL_31]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_34:.*]] = tt.load %[[VAL_33]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_34]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forNestedOverMaxTripCount(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %cst = arith.constant dense<0> : tensor<1024xi64> + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %10:3 = scf.for %arg6 = %c0 to %c128 step %c1 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %11 = tt.addptr %arg7, %1 : !tt.ptr, i32 + %12 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %13 = arith.addi %12, %arg8 : tensor<1024xi64> + %14 = tt.splat %11 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %16 = tt.load %15 : tensor<1024x!tt.ptr> + %17 = arith.addf %16, %arg9 : tensor<1024xf32> + scf.yield %11, %13, %17 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + scf.yield %10#0, %10#1, %10#2 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %4 = tt.addptr %3#0, %1 : !tt.ptr, i32 + %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %6 = arith.addi %5, %3#1 : tensor<1024xi64> + %7 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %6 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %9 = tt.load %8 : tensor<1024x!tt.ptr> + tt.return %9 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @ifOp( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: tensor<1024xf32>, %[[VAL_2:.*]]: i1) -> tensor<1024xf32> { +// CHECK: %[[VAL_4:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK: %[[VAL_5:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_5]] : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_9:.*]]:2 = scf.if %[[VAL_2]] -> (!tt.ptr, tensor<1024xi64>) { +// CHECK: %[[VAL_10:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_11:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: scf.yield %[[VAL_10]], %[[VAL_11]] : !tt.ptr, tensor<1024xi64> +// CHECK: } else { +// CHECK: %[[VAL_12:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: scf.yield %[[VAL_12]], %[[VAL_4]] : !tt.ptr, tensor<1024xi64> +// CHECK: } +// CHECK: %[[VAL_13:.*]] = arith.trunci %[[VAL_14:.*]]#1 : tensor<1024xi64> to tensor<1024xi32> +// CHECK: %[[VAL_15:.*]] = amdgpu.buffer_load %[[VAL_14]]#0{{\[}}%[[VAL_13]]] : tensor<1024xf32> +// CHECK: tt.return %[[VAL_15]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @ifOp(%arg0: !tt.ptr, %arg1: tensor<1024xf32>, %arg2: i1) -> tensor<1024xf32> { + %cst = arith.constant dense<0> : tensor<1024xi64> + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3:2 = scf.if %arg2 -> (!tt.ptr, tensor<1024xi64>) { + %8 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + scf.yield %8, %9 : !tt.ptr, tensor<1024xi64> + } else { + %8 = tt.addptr %arg0, %1 : !tt.ptr, i32 + scf.yield %8, %cst : !tt.ptr, tensor<1024xi64> + } + %4 = arith.trunci %3#1 : tensor<1024xi64> to tensor<1024xi32> + %5 = tt.splat %3#0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %7 = tt.load %6 : tensor<1024x!tt.ptr> + tt.return %7 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @condBranch( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: i1) -> tensor<1024xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK: %[[VAL_4:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32 +// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr, i32 +// CHECK: %[[VAL_9:.*]] = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> +// CHECK: cf.cond_br %[[VAL_1]], ^bb1(%[[VAL_0]], %[[VAL_2]] : !tt.ptr, tensor<1024xi64>), ^bb1(%[[VAL_8]], %[[VAL_9]] : !tt.ptr, tensor<1024xi64>) +// CHECK: ^bb1(%[[VAL_9:.*]]: !tt.ptr, %[[VAL_11:.*]]: tensor<1024xi64>): +// CHECK: %[[VAL_12:.*]] = arith.trunci %[[VAL_11]] : tensor<1024xi64> to tensor<1024xi32> +// CHECK: %[[VAL_13:.*]] = amdgpu.buffer_load %[[VAL_9]]{{\[}}%[[VAL_12]]] : tensor<1024xf32> +// CHECK: tt.return %[[VAL_13]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @condBranch(%arg0: !tt.ptr, %arg1: i1) -> tensor<1024xf32> { + %cst = arith.constant dense<0> : tensor<1024xi64> + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + cf.cond_br %arg1, ^bb1(%arg0, %cst : !tt.ptr, tensor<1024xi64>), ^bb2(%3, %4 : !tt.ptr, tensor<1024xi64>) + ^bb1(%5: !tt.ptr, %6: tensor<1024xi64>): // pred: ^bb0 + %7 = arith.trunci %6 : tensor<1024xi64> to tensor<1024xi32> + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %9 : tensor<1024x!tt.ptr> + tt.return %10 : tensor<1024xf32> + ^bb2(%11: !tt.ptr, %12: tensor<1024xi64>): // pred: ^bb0 + %13 = arith.trunci %12 : tensor<1024xi64> to tensor<1024xi32> + %14 = tt.splat %11 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %16 = tt.load %15 : tensor<1024x!tt.ptr> + tt.return %16 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @branch( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: i1) -> tensor<1024xf32> { +// CHECK: %[[VAL_3:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_4:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_5:.*]] = arith.muli %[[VAL_4]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_0]], %[[VAL_5]] : !tt.ptr, i32 +// CHECK: %[[VAL_8:.*]] = amdgpu.buffer_load %[[VAL_7]]{{\[}}%[[VAL_6]]] : tensor<1024xf32> +// CHECK: tt.return %[[VAL_8]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @branch(%arg0: !tt.ptr, %arg1: i1) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = tt.splat %3 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %2 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.load %5 : tensor<1024x!tt.ptr> + tt.return %6 : tensor<1024xf32> + } +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-LABEL: tt.func @tile_offset( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32) -> tensor<16x256xf16, #[[$ATTR_0]]> { +// CHECK: %[[VAL_3:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_4:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_5:.*]] = arith.muli %[[VAL_4]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_0]]}>> +// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_0]]}>> +// CHECK: %[[VAL_8:.*]] = tt.expand_dims %[[VAL_7]] {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<16x1xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_2]] : i32 -> tensor<16x1xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_8]], %[[VAL_9]] : tensor<16x1xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_11:.*]] = tt.broadcast %[[VAL_10]] : tensor<16x1xi32, #[[$ATTR_0]]> -> tensor<16x256xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_12:.*]] = tt.expand_dims %[[VAL_6]] {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_0]]}>> -> tensor<1x256xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_13:.*]] = tt.broadcast %[[VAL_12]] : tensor<1x256xi32, #[[$ATTR_0]]> -> tensor<16x256xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_13]] : tensor<16x256xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_0]], %[[VAL_5]] : !tt.ptr, i32 +// CHECK: %[[VAL_16:.*]] = tt.splat %[[VAL_15]] : !tt.ptr -> tensor<16x256x!tt.ptr, #[[$ATTR_0]]> +// CHECK: %[[VAL_17:.*]] = tt.addptr %[[VAL_16]], %[[VAL_14]] : tensor<16x256x!tt.ptr, #[[$ATTR_0]]>, tensor<16x256xi32, #[[$ATTR_0]]> +// CHECK: %[[VAL_18:.*]] = tt.load %[[VAL_17]] : tensor<16x256x!tt.ptr, #[[$ATTR_0]]> +// CHECK: tt.return %[[VAL_18]] : tensor<16x256xf16, #[[$ATTR_0]]> +// CHECK: } + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @tile_offset(%arg0: !tt.ptr, %arg1: i32, %arg2: i32) -> tensor<16x256xf16, #blocked> { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %5 = tt.splat %arg2 : i32 -> tensor<16x1xi32, #blocked> + %6 = arith.muli %4, %5 : tensor<16x1xi32, #blocked> + %7 = tt.broadcast %6 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> + %8 = tt.expand_dims %2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %9 = tt.broadcast %8 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> + %10 = arith.addi %7, %9 : tensor<16x256xi32, #blocked> + %11 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %12 = tt.splat %11 : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> + %13 = tt.addptr %12, %10 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> + %14 = tt.load %13 : tensor<16x256x!tt.ptr, #blocked> + tt.return %14 : tensor<16x256xf16, #blocked> + } +} + +// ----- + +// CHECK: #[[$ATTR_1:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-LABEL: tt.func public @matmul_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: i32 {tt.divisibility = 16 : i32}) -> tensor<128x16xf16, #[[$ATTR_1]]> { +// CHECK: %[[VAL_2:.*]] = arith.constant 128 : i32 +// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_5:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_1]]}>> +// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_1]]}>> +// CHECK: %[[VAL_7:.*]] = tt.expand_dims %[[VAL_5]] {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128x1xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_4]], %[[VAL_1]] : i32 +// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_1]] : i32 -> tensor<128x1xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_7]], %[[VAL_9]] : tensor<128x1xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_11:.*]] = tt.broadcast %[[VAL_10]] : tensor<128x1xi32, #[[$ATTR_1]]> -> tensor<128x16xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_12:.*]] = tt.expand_dims %[[VAL_6]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_1]]}>> -> tensor<1x16xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_13:.*]] = tt.broadcast %[[VAL_12]] : tensor<1x16xi32, #[[$ATTR_1]]> -> tensor<128x16xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_13]] : tensor<128x16xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_0]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_16:.*]] = tt.splat %[[VAL_15]] : !tt.ptr -> tensor<128x16x!tt.ptr, #[[$ATTR_1]]> +// CHECK: %[[VAL_17:.*]] = tt.addptr %[[VAL_16]], %[[VAL_14]] : tensor<128x16x!tt.ptr, #[[$ATTR_1]]>, tensor<128x16xi32, #[[$ATTR_1]]> +// CHECK: %[[VAL_18:.*]] = tt.load %[[VAL_17]] : tensor<128x16x!tt.ptr, #[[$ATTR_1]]> +// CHECK: tt.return %[[VAL_18]] : tensor<128x16xf16, #[[$ATTR_1]]> +// CHECK: } + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) -> tensor<128x16xf16, #blocked> { + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %4 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %5 = arith.muli %1, %arg1 : i32 + %6 = tt.splat %arg1 : i32 -> tensor<128x1xi32, #blocked> + %7 = arith.muli %4, %6 : tensor<128x1xi32, #blocked> + %8 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> + %9 = tt.expand_dims %3 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %10 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> + %11 = arith.addi %8, %10 : tensor<128x16xi32, #blocked> + %12 = tt.addptr %arg0, %5 : !tt.ptr, i32 + %13 = tt.splat %12 : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> + %14 = tt.addptr %13, %11 : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> + %15 = tt.load %14 : tensor<128x16x!tt.ptr, #blocked> + tt.return %15 : tensor<128x16xf16, #blocked> + } +} + +// ----- + +// CHECK-LABEL: tt.func @select( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: i1) -> tensor<1024xf32> { +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : tensor<1024xi64> +// CHECK: %[[VAL_4:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32 +// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr, i32 +// CHECK: %[[VAL_9:.*]] = arith.extsi %[[VAL_7]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_1]], %[[VAL_0]], %[[VAL_8]] : !tt.ptr +// CHECK: %[[VAL_11:.*]] = arith.select %[[VAL_1]], %[[VAL_3]], %[[VAL_9]] : tensor<1024xi64> +// CHECK: %[[VAL_12:.*]] = arith.trunci %[[VAL_11]] : tensor<1024xi64> to tensor<1024xi32> +// CHECK: %[[VAL_13:.*]] = amdgpu.buffer_load %[[VAL_10]]{{\[}}%[[VAL_12]]] : tensor<1024xf32> +// CHECK: tt.return %[[VAL_13]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @select(%arg0: !tt.ptr, %arg1: i1) -> tensor<1024xf32> { + %cst = arith.constant dense<0> : tensor<1024xi64> + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5 = arith.select %arg1, %arg0, %3 : !tt.ptr + %6 = arith.select %arg1, %cst, %4 : tensor<1024xi64> + %7 = arith.trunci %6 : tensor<1024xi64> to tensor<1024xi32> + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %9 : tensor<1024x!tt.ptr> + tt.return %10 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @where_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: !tt.ptr, %[[VAL_2:.*]]: i8) -> tensor<1024xi64> { +// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8 +// CHECK: %[[VAL_5:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_5]] : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_9:.*]] = arith.cmpi ne, %[[VAL_2]], %[[VAL_4]] : i8 +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_9]], %[[VAL_0]], %[[VAL_1]] : !tt.ptr +// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_10]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_12:.*]] = amdgpu.buffer_load %[[VAL_11]]{{\[}}%[[VAL_8]]] : tensor<1024xi64> +// CHECK: tt.return %[[VAL_12]] : tensor<1024xi64> +// CHECK: } + +module attributes {"ttg.num-ctas" = 1 : i32} { + tt.func @where_kernel(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i8) -> tensor<1024xi64> { + %c0_i8 = arith.constant 0 : i8 + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = arith.cmpi ne, %arg2, %c0_i8 : i8 + %4 = arith.select %3, %arg0, %arg1 : !tt.ptr + %5 = tt.addptr %4, %1 : !tt.ptr, i32 + %6 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %2 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xi64> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forOpWithHints( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr, i32 +// CHECK: %[[VAL_9:.*]] = arith.extsi %[[VAL_7]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_10:.*]]:3 = scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_12:.*]] = %[[VAL_8]], %[[VAL_13:.*]] = %[[VAL_9]], %[[VAL_14:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_15:.*]] = arith.trunci %[[VAL_13]] : tensor<1024xi64> to tensor<1024xi32> +// CHECK: %[[VAL_16:.*]] = amdgpu.buffer_load %[[VAL_12]]{{\[}}%[[VAL_15]]] : tensor<1024xf32> +// CHECK: %[[VAL_17:.*]] = tt.addptr %[[VAL_12]], %[[VAL_6]] : !tt.ptr, i32 +// CHECK: %[[VAL_18:.*]] = arith.extsi %[[VAL_7]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_13]] : tensor<1024xi64> +// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_17]], %[[VAL_6]] : !tt.ptr, i32 +// CHECK: %[[VAL_21:.*]] = arith.addf %[[VAL_16]], %[[VAL_14]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_20]], %[[VAL_19]], %[[VAL_21]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } {tt.divisibility_arg1 = dense<16> : tensor<1xi32>, tt.divisibility_arg2 = dense<16> : tensor<1xi32>} +// CHECK: %[[VAL_22:.*]] = tt.addptr %[[VAL_23:.*]]#0, %[[VAL_6]] : !tt.ptr, i32 +// CHECK: %[[VAL_24:.*]] = arith.extsi %[[VAL_7]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_23]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_22]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_28]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forOpWithHints(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %0 = tt.get_program_id x : i32 + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %2 = tt.addptr %arg0, %0 : !tt.ptr, i32 + %3 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> + %4:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %2, %arg4 = %3, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %11 = arith.trunci %arg4 : tensor<1024xi64> to tensor<1024xi32> + %12 = tt.splat %arg3 : !tt.ptr -> tensor<1024x!tt.ptr> + %13 = tt.addptr %12, %11 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %14 = tt.load %13 : tensor<1024x!tt.ptr> + %15 = tt.addptr %arg3, %0 : !tt.ptr, i32 + %16 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> + %17 = arith.addi %16, %arg4 : tensor<1024xi64> + %18 = tt.addptr %15, %0 : !tt.ptr, i32 + %19 = arith.addf %14, %arg5 : tensor<1024xf32> + scf.yield %18, %17, %19 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } {tt.divisibility_arg1 = dense<16> : tensor<1xi32>, tt.divisibility_arg2 = dense<16> : tensor<1xi32>} + %5 = tt.addptr %4#0, %0 : !tt.ptr, i32 + %6 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> + %7 = arith.addi %6, %4#1 : tensor<1024xi64> + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %10 = tt.load %9 : tensor<1024x!tt.ptr> + tt.return %10 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func public @scalar_pointers( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_3:.*]] = arith.constant 100 : i32 +// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_2]] : !tt.ptr, i32 +// CHECK: %[[VAL_5:.*]] = scf.for %[[VAL_6:.*]] = %[[VAL_2]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_7:.*]] = %[[VAL_4]]) -> (!tt.ptr) : i32 { +// CHECK: tt.store %[[VAL_7]], %[[VAL_1]] : !tt.ptr +// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_7]], %[[VAL_2]] : !tt.ptr, i32 +// CHECK: scf.yield %[[VAL_8]] : !tt.ptr +// CHECK: } +// CHECK: tt.return +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func public @scalar_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %0 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + %1 = scf.for %arg1 = %c1_i32 to %c100_i32 step %c1_i32 iter_args(%arg2 = %0) -> (!tt.ptr) : i32 { + tt.store %arg2, %c0_i64 : !tt.ptr + %2 = tt.addptr %arg2, %c1_i32 : !tt.ptr, i32 + scf.yield %2 : !tt.ptr + } + tt.return + } +} + +// ----- + +// CHECK-LABEL: tt.func @scalar_if( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: tensor<1024xf32>, %[[VAL_2:.*]]: i1) -> f32 { +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_4:.*]] = arith.constant 100 : i32 +// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_2]] -> (!tt.ptr) { +// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_5]], %[[VAL_3]] : !tt.ptr, i32 +// CHECK: scf.yield %[[VAL_7]] : !tt.ptr +// CHECK: } else { +// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_5]], %[[VAL_4]] : !tt.ptr, i32 +// CHECK: scf.yield %[[VAL_8]] : !tt.ptr +// CHECK: } +// CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_6]] : !tt.ptr +// CHECK: tt.return %[[VAL_9]] : f32 +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @scalar_if(%arg0: !tt.ptr, %arg1: tensor<1024xf32>, %arg2: i1) -> f32 { + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %0 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + %1 = scf.if %arg2 -> (!tt.ptr) { + %3 = tt.addptr %0, %c1_i32 : !tt.ptr, i32 + scf.yield %3 : !tt.ptr + } else { + %3 = tt.addptr %0, %c100_i32 : !tt.ptr, i32 + scf.yield %3 : !tt.ptr + } + %2 = tt.load %1 : !tt.ptr + tt.return %2 : f32 + } +} + +// ----- + +// CHECK-LABEL: tt.func @scalar_cond_branch( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: !tt.ptr, %[[VAL_2:.*]]: i1) -> f32 { +// CHECK: cf.cond_br %[[VAL_2]], ^bb1(%[[VAL_0]] : !tt.ptr), ^bb1(%[[VAL_1]] : !tt.ptr) +// CHECK: ^bb1(%[[VAL_3:.*]]: !tt.ptr): +// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr +// CHECK: tt.return %[[VAL_4]] : f32 +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @scalar_cond_branch(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i1) -> f32 { + cf.cond_br %arg2, ^bb1(%arg0 : !tt.ptr), ^bb2(%arg1 : !tt.ptr) + ^bb1(%0: !tt.ptr): // pred: ^bb0 + %1 = tt.load %0 : !tt.ptr + tt.return %1 : f32 + ^bb2(%2: !tt.ptr): // pred: ^bb0 + %3 = tt.load %2 : !tt.ptr + tt.return %3 : f32 + } +} + +// ----- + +// CHECK-LABEL: tt.func @flipFlopForOpSimple( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_10:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_12:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_13:.*]]:5 = scf.for %[[VAL_14:.*]] = %[[VAL_3]] to %[[VAL_4]] step %[[VAL_5]] iter_args(%[[VAL_15:.*]] = %[[VAL_11]], %[[VAL_16:.*]] = %[[VAL_12]], %[[VAL_17:.*]] = %[[VAL_9]], %[[VAL_18:.*]] = %[[VAL_10]], %[[VAL_19:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_17]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_18]] : tensor<1024xi64> +// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_20]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_22]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_25:.*]] = tt.load %[[VAL_24]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_19]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_20]], %[[VAL_22]], %[[VAL_15]], %[[VAL_16]], %[[VAL_26]] : !tt.ptr, tensor<1024xi64>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_28:.*]]#0, %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_29:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_28]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_31:.*]] = tt.splat %[[VAL_27]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_32:.*]] = tt.addptr %[[VAL_31]], %[[VAL_30]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_33:.*]] = tt.load %[[VAL_32]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_33]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @flipFlopForOpSimple(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %7:5 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %5, %arg4 = %6, %arg5 = %3, %arg6 = %4, %arg7 = %arg1) -> (!tt.ptr, tensor<1024xi64>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %14 = tt.addptr %arg5, %1 : !tt.ptr, i32 + %15 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %16 = arith.addi %15, %arg6 : tensor<1024xi64> + %17 = tt.splat %14 : !tt.ptr -> tensor<1024x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %19 = tt.load %18 : tensor<1024x!tt.ptr> + %20 = arith.addf %19, %arg7 : tensor<1024xf32> + scf.yield %14, %16, %arg3, %arg4, %20 : !tt.ptr, tensor<1024xi64>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %8 = tt.addptr %7#0, %1 : !tt.ptr, i32 + %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %10 = arith.addi %9, %7#1 : tensor<1024xi64> + %11 = tt.splat %8 : !tt.ptr -> tensor<1024x!tt.ptr> + %12 = tt.addptr %11, %10 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %13 = tt.load %12 : tensor<1024x!tt.ptr> + tt.return %13 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @flipFlopForOpComplex( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: !tt.ptr, %[[VAL_2:.*]]: tensor<1024xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) { +// CHECK: %[[VAL_3:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_9:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_10:.*]] = tt.addptr %[[VAL_0]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_11:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_12:.*]] = tt.addptr %[[VAL_1]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_13:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_14:.*]]:6 = scf.for %[[VAL_15:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_16:.*]] = %[[VAL_10]], %[[VAL_17:.*]] = %[[VAL_11]], %[[VAL_18:.*]] = %[[VAL_2]], %[[VAL_19:.*]] = %[[VAL_12]], %[[VAL_20:.*]] = %[[VAL_13]], %[[VAL_21:.*]] = %[[VAL_2]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_22:.*]] = tt.addptr %[[VAL_16]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_23:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_17]] : tensor<1024xi64> +// CHECK: %[[VAL_25:.*]] = tt.splat %[[VAL_22]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_26:.*]] = tt.addptr %[[VAL_25]], %[[VAL_24]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_27:.*]] = tt.load %[[VAL_26]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_18]] : tensor<1024xf32> +// CHECK: %[[VAL_29:.*]] = tt.addptr %[[VAL_19]], %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_30:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_20]] : tensor<1024xi64> +// CHECK: %[[VAL_32:.*]] = tt.splat %[[VAL_29]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_33:.*]] = tt.addptr %[[VAL_32]], %[[VAL_31]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_34:.*]] = tt.load %[[VAL_33]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_35:.*]] = arith.addf %[[VAL_34]], %[[VAL_21]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_29]], %[[VAL_31]], %[[VAL_35]], %[[VAL_22]], %[[VAL_24]], %[[VAL_28]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_36:.*]] = tt.addptr %[[VAL_37:.*]]#0, %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_38:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_38]], %[[VAL_37]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_40:.*]] = tt.splat %[[VAL_36]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_41:.*]] = tt.addptr %[[VAL_40]], %[[VAL_39]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_42:.*]] = tt.load %[[VAL_41]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_43:.*]] = tt.addptr %[[VAL_37]]#3, %[[VAL_8]] : !tt.ptr, i32 +// CHECK: %[[VAL_44:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_44]], %[[VAL_37]]#4 : tensor<1024xi64> +// CHECK: %[[VAL_46:.*]] = tt.splat %[[VAL_43]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_47:.*]] = tt.addptr %[[VAL_46]], %[[VAL_45]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_48:.*]] = tt.load %[[VAL_47]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_42]], %[[VAL_48]] : tensor<1024xf32>, tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @flipFlopForOpComplex(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: tensor<1024xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) { + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5 = tt.addptr %arg1, %1 : !tt.ptr, i32 + %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %7:6 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %3, %arg5 = %4, %arg6 = %arg2, %arg7 = %5, %arg8 = %6, %arg9 = %arg2) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %20 = tt.addptr %arg4, %1 : !tt.ptr, i32 + %21 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %22 = arith.addi %21, %arg5 : tensor<1024xi64> + %23 = tt.splat %20 : !tt.ptr -> tensor<1024x!tt.ptr> + %24 = tt.addptr %23, %22 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %25 = tt.load %24 : tensor<1024x!tt.ptr> + %26 = arith.addf %25, %arg6 : tensor<1024xf32> + %27 = tt.addptr %arg7, %1 : !tt.ptr, i32 + %28 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %29 = arith.addi %28, %arg8 : tensor<1024xi64> + %30 = tt.splat %27 : !tt.ptr -> tensor<1024x!tt.ptr> + %31 = tt.addptr %30, %29 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %32 = tt.load %31 : tensor<1024x!tt.ptr> + %33 = arith.addf %32, %arg9 : tensor<1024xf32> + scf.yield %27, %29, %33, %20, %22, %26 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %8 = tt.addptr %7#0, %1 : !tt.ptr, i32 + %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %10 = arith.addi %9, %7#1 : tensor<1024xi64> + %11 = tt.splat %8 : !tt.ptr -> tensor<1024x!tt.ptr> + %12 = tt.addptr %11, %10 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %13 = tt.load %12 : tensor<1024x!tt.ptr> + %14 = tt.addptr %7#3, %1 : !tt.ptr, i32 + %15 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %16 = arith.addi %15, %7#4 : tensor<1024xi64> + %17 = tt.splat %14 : !tt.ptr -> tensor<1024x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %19 = tt.load %18 : tensor<1024x!tt.ptr> + tt.return %13, %19 : tensor<1024xf32>, tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forOpDynamicKBound( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, %[[VAL_1:.*]]: tensor<1024xf32>, %[[VAL_2:.*]]: index) -> tensor<1024xf32> { +// CHECK: %[[VAL_3:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> +// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_10:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_11:.*]]:3 = scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_13:.*]] = %[[VAL_9]], %[[VAL_14:.*]] = %[[VAL_10]], %[[VAL_15:.*]] = %[[VAL_1]]) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { +// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_13]], %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_14]] : tensor<1024xi64> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_16]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_19]], %[[VAL_18]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_21:.*]] = tt.load %[[VAL_20]] : tensor<1024x!tt.ptr> +// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_21]], %[[VAL_15]] : tensor<1024xf32> +// CHECK: scf.yield %[[VAL_16]], %[[VAL_18]], %[[VAL_22]] : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> +// CHECK: } +// CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_24:.*]]#0, %[[VAL_7]] : !tt.ptr, i32 +// CHECK: %[[VAL_25:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64> +// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_24]]#1 : tensor<1024xi64> +// CHECK: %[[VAL_27:.*]] = tt.splat %[[VAL_23]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_27]], %[[VAL_26]] : tensor<1024x!tt.ptr>, tensor<1024xi64> +// CHECK: %[[VAL_29:.*]] = tt.load %[[VAL_28]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_29]] : tensor<1024xf32> +// CHECK: } + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forOpDynamicKBound(%arg0: !tt.ptr, %arg1: tensor<1024xf32>, %K: index) -> tensor<1024xf32> { + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5:3 = scf.for %arg2 = %c0 to %c128 step %K iter_args(%arg3 = %3, %arg4 = %4, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %12 = tt.addptr %arg3, %1 : !tt.ptr, i32 + %13 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %14 = arith.addi %13, %arg4 : tensor<1024xi64> + %15 = tt.splat %12 : !tt.ptr -> tensor<1024x!tt.ptr> + %16 = tt.addptr %15, %14 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %17 = tt.load %16 : tensor<1024x!tt.ptr> + %18 = arith.addf %17, %arg5 : tensor<1024xf32> + scf.yield %12, %14, %18 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %6 = tt.addptr %5#0, %1 : !tt.ptr, i32 + %7 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %8 = arith.addi %7, %5#1 : tensor<1024xi64> + %9 = tt.splat %6 : !tt.ptr -> tensor<1024x!tt.ptr> + %10 = tt.addptr %9, %8 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %11 = tt.load %10 : tensor<1024x!tt.ptr> + tt.return %11 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @whileOp +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @whileOp(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %2 = scf.while (%arg2 = %1) : (tensor<1024x!tt.ptr>) -> tensor<1024x!tt.ptr> { + %4 = "dummy.evaluate_condition"() : () -> i1 + scf.condition(%4) %arg2 : tensor<1024x!tt.ptr> + } do { + ^bb0(%arg2: tensor<1024x!tt.ptr>): + %4 = tt.addptr %arg2, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + scf.yield %4 : tensor<1024x!tt.ptr> + } + %3 = tt.load %2 : tensor<1024x!tt.ptr> + tt.return %3 : tensor<1024xf32> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-convert-buffer-ops.mlir new file mode 100644 index 000000000..199f0934e --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -0,0 +1,631 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops='arch-generation-name=gfx942'| FileCheck %s + +#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: simple + tt.func @simple(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 :i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + // CHECK: %[[offset:.*]] = arith.addi + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + // CHECK: buffer_load %arg0[%[[offset]]] + %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> + // CHECK: buffer_load %arg1[%[[offset]]] + %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> + // CHECK: %[[data:.*]] = arith.addf + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + // CHECK: buffer_store %[[data]], %arg2[%[[offset]]] + tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { +// CHECK-LABEL: buffer_stride + tt.func public @buffer_stride(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c48_i32 = arith.constant 48 : i32 + %c32_i32 = arith.constant 32 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %cmp = arith.cmpi sgt, %arg6, %c0_i32 : i32 + llvm.intr.assume %cmp : i1 + %2 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<256x1xi32, #blocked> + %4 = tt.addptr %arg0, %c32_i32 : !tt.ptr, i32 + %5 = tt.broadcast %3 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked> + %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %8 = tt.broadcast %7 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> + %9 = arith.addi %8, %5 : tensor<256x64xi32, #blocked> + %10 = tt.splat %4 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> + %11 = tt.addptr %10, %9 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> + + // CHECK: %[[splat:.*]] = tt.splat %arg[[#stride:]] + // CHECK: %[[mul:.*]] = arith.muli %[[#]], %[[splat]] + // CHECK: %[[ptr:.*]] = tt.addptr %arg0 + // CHECK: %[[bcast1:.*]] = tt.broadcast %[[mul]] + // CHECK: %[[bcast0:.*]] = tt.broadcast %[[#]] + // CHECK: %[[offset:.*]] = arith.addi %[[bcast0]], %[[bcast1]] + // CHECK: %[[buffer:.*]] = amdgpu.buffer_load %[[ptr]][%[[offset]]] stride = %arg[[#stride]] + + %12 = tt.load %11 {OpIdx = #amdgpu.OpIdx<0>} : tensor<256x64x!tt.ptr, #blocked> + %13 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %15 = tt.expand_dims %13 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %16 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %cmp1 = arith.cmpi sgt, %arg8, %c0_i32 : i32 + llvm.intr.assume %cmp1 : i1 + %17 = tt.splat %arg8 : i32 -> tensor<256x1xi32, #blocked> + %18 = arith.muli %17, %15 : tensor<256x1xi32, #blocked> + %19 = tt.addptr %arg2, %c48_i32 : !tt.ptr, i32 + %20 = tt.broadcast %18 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked> + %21 = tt.broadcast %16 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> + %22 = tt.addptr %19, %c48_i32 : !tt.ptr, i32 + %23 = arith.addi %21, %20 : tensor<256x64xi32, #blocked> + %24 = tt.splat %22 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> + %25 = tt.addptr %24, %23 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> + + // CHECK: %[[splatb:.*]] = tt.splat %arg[[#strideb:]] + // CHECK: %[[mulb:.*]] = arith.muli %[[splatb]], %[[#]] + // CHECK: %[[bcast1b:.*]] = tt.broadcast %[[mulb]] + // CHECK: %[[bcast0b:.*]] = tt.broadcast %[[#]] + // CHECK: %[[ptrb:.*]] = tt.addptr + // CHECK: %[[offsetb:.*]] = arith.addi %[[bcast0b]], %[[bcast1b]] + // CHECK: buffer_store %[[buffer]], %[[ptrb]][%[[offsetb]]] stride = %arg[[#strideb]] + + tt.store %25, %12 : tensor<256x64x!tt.ptr, #blocked> + tt.return + } +} +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_positive_offset + tt.func @assume_positive_offset(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %sub = arith.subi %1, %c128_i32 : i32 + %cmp = arith.cmpi sgt, %sub, %c0_i32 : i32 + llvm.intr.assume %cmp : i1 + %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[offset:.*]] = arith.addi + %4 = arith.addi %2, %3 : tensor<1024xi32, #blocked> + // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0 + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: buffer_load %[[scalar_ptr]][%[[offset]]] + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: offset_64_bits + tt.func @offset_64_bits(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked> { + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %sub = arith.subi %1, %c128_i32 : i32 + %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %ext2 = arith.extsi %2 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %ext3 = arith.extsi %3 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %4 = arith.addi %ext2, %ext3 : tensor<1024xi64, #blocked> + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi64, #blocked> + // CHECK: tt.load + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: offset_64_bits_narrow + tt.func public @offset_64_bits_narrow(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked> { + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.splat %1: i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %ext2 = arith.extsi %2 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %ext3 = arith.extsi %3 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %4 = arith.addi %ext2, %ext3 : tensor<1024xi64, #blocked> + // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0 + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[offset_32_bit:.*]] = arith.trunci + %narrow4 = arith.trunci %4 : tensor<1024xi64, #blocked> to tensor <1024xi32, #blocked> + %9 = tt.addptr %8, %narrow4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]] + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: non_canonical_ptr + tt.func @non_canonical_ptr(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: tensor<1024xi32, #blocked>) -> tensor<1024xf32, #blocked>{ + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %arg1: tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: tt.load + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_eq_non_neg + tt.func @assume_eq_non_neg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i32) { + %c10_i32 = arith.constant 10 : i32 + %0 = arith.cmpi eq, %arg2, %c10_i32 : i32 + llvm.intr.assume %0 : i1 + // CHECK: %[[range:.*]] = tt.make_range + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked> + // CHECK: %[[ptr:.*]] = tt.addptr %arg0, %arg2 + %2 = tt.addptr %arg0, %arg2: !tt.ptr, i32 + %3 = tt.splat %2 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %4 = tt.addptr %3, %1 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %6 = tt.addptr %5, %1 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg1[%1] + %7 = tt.load %6 : tensor<16x!tt.ptr, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %[[ptr]][%[[range]]] + tt.store %4, %7 : tensor<16x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_nonneg_less + tt.func @assume_nonneg_less(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i32) { + %c10_i32 = arith.constant 5 : i32 + %0 = arith.cmpi slt, %c10_i32, %arg2 : i32 + llvm.intr.assume %0 : i1 + // CHECK: %[[range:.*]] = tt.make_range + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked> + // CHECK: %[[ptr:.*]] = tt.addptr %arg0, %arg2 + %2 = tt.addptr %arg0, %arg2: !tt.ptr, i32 + %3 = tt.splat %2 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %4 = tt.addptr %3, %1 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %6 = tt.addptr %5, %1 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg1[%1] + %7 = tt.load %6 : tensor<16x!tt.ptr, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %[[ptr]][%[[range]]] + tt.store %4, %7 : tensor<16x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_cmp_non_const + tt.func @assume_cmp_non_const(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i32, %arg3 : i32, %arg4 : i32, %arg5 : i32, %arg6 : i32) { + %0 = arith.cmpi sgt, %arg2, %arg3 : i32 + llvm.intr.assume %0 : i1 + %1 = arith.subi %arg2, %arg3 : i32 + %2 = arith.cmpi sge, %1, %arg4 : i32 + llvm.intr.assume %2 : i1 + %3 = arith.subi %1, %arg4 : i32 + %4 = arith.cmpi slt, %3, %arg5 : i32 + llvm.intr.assume %4 : i1 + %5 = arith.subi %arg5, %3 : i32 + %6 = arith.cmpi sle, %5, %arg6 : i32 + llvm.intr.assume %6 : i1 + %7 = arith.subi %arg6, %5 : i32 + %8 = arith.minsi %1, %3 : i32 + %9 = arith.minsi %8, %5 : i32 + %10 = arith.minsi %9, %7 : i32 + // CHECK: %[[range:.*]] = tt.make_range + %11 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked> + %12 = tt.splat %10 : i32 -> tensor<16xi32, #blocked> + // CHECK: %[[offsets:.*]] = arith.addi + %offsets = arith.addi %11, %12 : tensor<16xi32, #blocked> + %13 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %14 = tt.addptr %13, %11 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + %15 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %16 = tt.addptr %15, %offsets : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg1[%[[offsets]]] + %17 = tt.load %16 : tensor<16x!tt.ptr, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg0[%[[range]]] + tt.store %14, %17 : tensor<16x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blockedtrans = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.slice<{dim=0, parent=#blocked}> +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: unary_triton_ops_transitive_nonneg + tt.func @unary_triton_ops_transitive_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { + %c10_i32 = arith.constant 5 : i32 + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blocked> + %2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<8x2xi32, #blocked> + %3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<2x8xi32, #blocked> + %4 = tt.trans %3 {order = array} : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blockedtrans> + %5 = ttg.convert_layout %4 : tensor<8x2xi32, #blockedtrans> -> tensor<8x2xi32, #blocked> + %6 = arith.addi %5, %2 : tensor<8x2xi32, #blocked> + %7 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked2> + %8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> tensor<8xi32, #blocked1> + %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked1> -> tensor<1x8xi32, #blocked> + %10 = tt.broadcast %9 : tensor<1x8xi32, #blocked> -> tensor<2x8xi32, #blocked> + %11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blocked> + %12 = tt.splat %c10_i32 : i32 -> tensor<8x2xi32, #blocked> + %13 = arith.addi %11, %12 : tensor<8x2xi32, #blocked> + %14 = arith.minsi %13, %5 : tensor<8x2xi32, #blocked> + // CHECK: %[[lhs:.*]], %[[rhs:.*]] = tt.split + %15, %16 = tt.split %11: tensor<8x2xi32, #blocked> -> tensor<8xi32, #blocked2> + %17 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked2> + %18 = tt.addptr %17, %15 : tensor<8x!tt.ptr, #blocked2>, tensor<8xi32, #blocked2> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%[[lhs]]] + %19 = tt.load %18 : tensor<8x!tt.ptr, #blocked2> + %20 = tt.addptr %17, %16 : tensor<8x!tt.ptr, #blocked2>, tensor<8xi32, #blocked2> + // CHECK: %[[loaded2:.*]] = amdgpu.buffer_load %arg0[%[[rhs]]] + %21 = tt.load %20 : tensor<8x!tt.ptr, #blocked2> + // CHECK: %[[added:.*]] = arith.addf %[[loaded]], %[[loaded2]] + %22 = arith.addf %19, %21 : tensor<8xbf16, #blocked2> + %23 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked2> + %24 = tt.addptr %23, %7 : tensor<8x!tt.ptr, #blocked2>, tensor<8xi32, #blocked2> + // CHECK: amdgpu.buffer_store %[[added]], %arg1[%{{.*}}] + tt.store %24, %22 : tensor<8x!tt.ptr, #blocked2> + tt.return + } +} + +// ----- + + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: join_cat_transitive_nonneg + tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked1> + %1 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked1> + %2 = tt.join %0, %1 : tensor<8xi32, #blocked1> -> tensor<8x2xi32, #blocked> + %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1> + %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1> + %5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked> + %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked> + %7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked> + %zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked> + %ones = arith.constant dense<1> : tensor<8x1xi32, #blocked> + %8 = tt.gather %7[%zeros] {axis = 1 : i32} : (tensor<8x2xi32, #blocked>, tensor<8x1xi32, #blocked>) -> tensor<8x1xi32, #blocked> + %9 = tt.gather %7[%ones] {axis = 1 : i32} : (tensor<8x2xi32, #blocked>, tensor<8x1xi32, #blocked>) -> tensor<8x1xi32, #blocked> + %10 = arith.addi %8, %9 : tensor<8x1xi32, #blocked> + %11 = tt.reshape %10 allow_reorder : tensor<8x1xi32, #blocked> -> tensor<8xi32, #blocked1> + %12 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked1> + %14 = tt.addptr %12, %11 : tensor<8x!tt.ptr, #blocked1>, tensor<8xi32, #blocked1> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %15 = tt.load %14 : tensor<8x!tt.ptr, #blocked1> + %16 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked1> + %17 = tt.addptr %16, %0 : tensor<8x!tt.ptr, #blocked1>, tensor<8xi32, #blocked1> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %17, %15 : tensor<8x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: histo_nonneg + tt.func @histo_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : tensor<256xi32, #blocked>) { + /// Purposely specify %arg2 so that we can't statically determine the input + /// data is nonneg. + // CHECK: tt.histogram + %0 = tt.histogram %arg2 : tensor<256xi32, #blocked> -> tensor<8xi32, #blocked> + %1 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %3 = tt.addptr %2, %0 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %4 = tt.load %3 : tensor<8x!tt.ptr, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %6 = tt.addptr %5, %1 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %6, %4 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: get_num_prog_nonneg + tt.func @get_num_prog_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32) { + %0 = tt.get_num_programs x : i32 + %1 = tt.get_num_programs y : i32 + %2 = tt.get_num_programs z : i32 + %3 = arith.minsi %0, %1 : i32 + %4 = arith.minsi %2, %3 : i32 + %5 = arith.maxsi %arg2, %4 : i32 + %6 = tt.splat %5 : i32 -> tensor<8xi32, #blocked> + %7 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %8 = arith.addi %6, %7 : tensor<8xi32, #blocked> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %10 = tt.addptr %9, %8 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %11 = tt.load %10 : tensor<8x!tt.ptr, #blocked> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %13 = tt.addptr %12, %7 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %13, %11 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: unsigned_ops + tt.func @unsigned_ops(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32, %arg4 : f32) { + %c5_i32 = arith.constant 5 : i32 + %0 = arith.ceildivui %arg2, %c5_i32 : i32 + %1 = arith.divui %arg3, %c5_i32 : i32 + %2 = arith.fptoui %arg4 : f32 to i32 + %4 = arith.maxui %arg2, %arg3 : i32 + %5 = arith.minui %arg2, %arg3 : i32 + %6 = arith.remui %arg2, %c5_i32 : i32 + %7 = arith.shrui %arg3, %c5_i32 : i32 + %8 = arith.addi %0, %1 : i32 + %10 = arith.addi %4, %5 : i32 + %11 = arith.addi %6, %7 : i32 + %12 = arith.addi %8, %2 : i32 + %13 = arith.addi %10, %11 : i32 + %14 = arith.addi %8, %13 : i32 + %15 = tt.splat %14 : i32 -> tensor<8xi32, #blocked> + %16 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %17 = arith.addi %15, %16 : tensor<8xi32, #blocked> + %18 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %19 = tt.addptr %18, %17 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %20 = tt.load %19 : tensor<8x!tt.ptr, #blocked> + %21 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %22 = tt.addptr %21, %16 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %22, %20 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: extui_nonneg + tt.func @extui_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32) { + %0 = arith.extui %arg2 : i32 to i64 + %1 = tt.splat %0 : i64 -> tensor<8xi64, #blocked> + %2 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %3 = arith.extui %2 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %4 = arith.addi %1, %3 : tensor<8xi64, #blocked> + %5 = arith.trunci %4 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> + %6 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %7 = tt.addptr %6, %5 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %8 = tt.load %7: tensor<8x!tt.ptr, #blocked> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %10 = tt.addptr %9, %2 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %10, %8 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: traverse_if + tt.func @traverse_if(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) { + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %c5_i32 = arith.constant 7 : i32 + %c7_i32 = arith.constant 5 : i32 + %0 = arith.extui %arg2 : i32 to i64 + %1 = arith.remui %arg2, %c2_i32 : i32 + %2 = arith.cmpi eq, %1, %c0_i32 : i32 + %3 = scf.if %2 -> tensor<8xi64, #blocked> { + %20 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %21 = arith.extui %20 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %22 = tt.splat %arg3 : i32 -> tensor<8xi32, #blocked> + %23 = arith.extui %22 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %24 = arith.addi %21, %23 : tensor<8xi64, #blocked> + scf.yield %24 : tensor<8xi64, #blocked> + } else { + %30 = tt.make_range {end = 16 : i32, start = 8 : i32} : tensor<8xi32, #blocked> + %31 = arith.extui %30 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %32 = tt.splat %0 : i64 -> tensor<8xi64, #blocked> + %33 = arith.addi %31, %32 : tensor<8xi64, #blocked> + scf.yield %33 : tensor<8xi64, #blocked> + } + %4 = arith.trunci %3 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %7 = tt.load %6: tensor<8x!tt.ptr, #blocked> + %8 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %10 = tt.addptr %9, %8 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %10, %7 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: traverse_if + tt.func @traverse_if(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) { + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %c5_i32 = arith.constant 7 : i32 + %c7_i32 = arith.constant 5 : i32 + %zeros = arith.constant dense<0> : tensor<8xi32, #blocked> + %0 = arith.extui %arg2 : i32 to i64 + %1 = arith.remui %arg2, %c2_i32 : i32 + %2 = arith.cmpi eq, %1, %c0_i32 : i32 + %3, %4 = scf.if %2 -> (tensor<8xi64, #blocked>, tensor<8xi32, #blocked>) { + %20 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %21 = arith.extui %20 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %22 = tt.splat %arg3 : i32 -> tensor<8xi32, #blocked> + %23 = arith.extui %22 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %24 = arith.addi %21, %23 : tensor<8xi64, #blocked> + %25 = tt.make_range {end = 9 : i32, start = 1 : i32} : tensor<8xi32, #blocked> + scf.yield %24, %25 : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } else { + %30 = tt.make_range {end = 16 : i32, start = 8 : i32} : tensor<8xi32, #blocked> + %31 = arith.extui %30 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %32 = tt.splat %0 : i64 -> tensor<8xi64, #blocked> + %33 = arith.addi %31, %32 : tensor<8xi64, #blocked> + scf.yield %33, %zeros : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } + %5 = arith.trunci %3 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> + %6 = arith.addi %4, %5 : tensor<8xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %8 = tt.addptr %7, %6 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %9 = tt.load %8: tensor<8x!tt.ptr, #blocked> + %10 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %12 = tt.addptr %11, %10 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %12, %9 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_positive_offset_buffer_atomic + tt.func @assume_positive_offset_buffer_atomic(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %sub = arith.subi %1, %c128_i32 : i32 + %cmp = arith.cmpi sgt, %sub, %c0_i32 : i32 + llvm.intr.assume %cmp : i1 + %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[offset:.*]] = arith.addi + %4 = arith.addi %2, %3 : tensor<1024xi32, #blocked> + // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0 + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %6 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]] + %8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked> + tt.return %8 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { +// CHECK-LABEL: buffer_load_to_local + tt.func public @buffer_load_to_local(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, + %arg10: !ttg.memdesc<256x64xf16, #shared, #smem, mutable>, %arg11: tensor<256x64xi1, #blocked>, %arg12: tensor<256x64xf16, #blocked>) { + %c48_i32 = arith.constant 48 : i32 + %c32_i32 = arith.constant 32 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %cmp = arith.cmpi sgt, %arg6, %c0_i32 : i32 + llvm.intr.assume %cmp : i1 + %2 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<256x1xi32, #blocked> + %4 = tt.addptr %arg0, %c32_i32 : !tt.ptr, i32 + %5 = tt.broadcast %3 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked> + %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %8 = tt.broadcast %7 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> + %9 = arith.addi %8, %5 : tensor<256x64xi32, #blocked> + %10 = tt.splat %4 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> + %11 = tt.addptr %10, %9 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> + + // CHECK: %[[splat:.*]] = tt.splat %arg[[#stride:]] + // CHECK: %[[mul:.*]] = arith.muli %[[#]], %[[splat]] + // CHECK: %[[ptr:.*]] = tt.addptr %arg0 + // CHECK: %[[bcast1:.*]] = tt.broadcast %[[mul]] + // CHECK: %[[bcast0:.*]] = tt.broadcast %[[#]] + // CHECK: %[[offset:.*]] = arith.addi %[[bcast0]], %[[bcast1]] + + // CHECK: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] stride = %arg[[#stride]] into %arg10 + %12 = ttg.async_copy_global_to_local %11, %arg10 : tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> + + // CHECK: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] other = %arg12 stride = %arg[[#stride]] into %arg10 + %13 = ttg.async_copy_global_to_local %11, %arg10 other %arg12: tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> + + // CHECK: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 stride = %arg[[#stride]] into %arg10 + %14 = ttg.async_copy_global_to_local %11, %arg10 mask %arg11: tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> + + // CHECK: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] into %arg10 + %15 = ttg.async_copy_global_to_local %11, %arg10 mask %arg11 other %arg12 : tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> + + // CHECK: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = ca into %arg10 + %16 = ttg.async_copy_global_to_local %11, %arg10 mask %arg11 other %arg12 cacheModifier = ca: tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> + + // CHECK: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = cg into %arg10 + %17 = ttg.async_copy_global_to_local %11, %arg10 mask %arg11 other %arg12 cacheModifier = cg: tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> + + // CHECK: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = cv into %arg10 + %18 = ttg.async_copy_global_to_local %11, %arg10 mask %arg11 other %arg12 cacheModifier = cv: tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-extractslice-op.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-extractslice-op.mlir new file mode 100644 index 000000000..bde77b475 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-extractslice-op.mlir @@ -0,0 +1,14 @@ +// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s + +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @basic_insert_slice(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) { + // CHECK: llvm.func @basic_insert_slice + // CHECK-COUNT-64: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %64 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK-COUNT-8: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + %72 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-hoist-cvtToDotOp.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-hoist-cvtToDotOp.mlir new file mode 100644 index 000000000..f8eba39a9 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-hoist-cvtToDotOp.mlir @@ -0,0 +1,86 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-hoist-layout-conversions | FileCheck %s + +// Hoist convert_layout out of the loop since the defining op of the src is out of the loop + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +// CHECK-LABEL: hoist_cvtToDotOp +// CHECK: %[[AF16:.*]] = arith.truncf +// CHECK-NEXT: %[[opA:.*]] = ttg.convert_layout %[[AF16]] +// CHECK-NEXT: scf.for +// CHECK: tt.dot %[[opA]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @hoist_cvtToDotOp(%opA: tensor<256x128xf32, #blocked>, %opB: tensor<128x256xf16, #dotOp1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = arith.truncf %opA : tensor<256x128xf32, #blocked> to tensor<256x128xf16, #blocked> + %1:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %2 = ttg.convert_layout %0 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #dotOp0> + %3 = tt.dot %2, %opB, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %1#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + + +// ----- + +// Keep convert_layout inside the loop since the defining op of the src is inside the loop + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +// CHECK-LABEL: defOp_in_loop +// CHECK: scf.for +// CHECK: %[[AF16:.*]] = arith.truncf +// CHECK-NEXT: %[[opA:.*]] = ttg.convert_layout %[[AF16]] +// CHECK: tt.dot %[[opA]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @defOp_in_loop(%opA: tensor<256x128xf32, #blocked>, %opB: tensor<128x256xf16, #dotOp1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %1:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %0 = arith.truncf %opA : tensor<256x128xf32, #blocked> to tensor<256x128xf16, #blocked> + %2 = ttg.convert_layout %0 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #dotOp0> + %3 = tt.dot %2, %opB, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %1#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + + +// ----- + +// Keep convert_layout inside the loop since the defining op is a block argument of the loop + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +// CHECK-LABEL: defOp_blockArg +// CHECK: scf.for +// CHECK-NEXT: %[[opA:.*]] = ttg.convert_layout +// CHECK: tt.dot %[[opA]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @defOp_blockArg(%opA: tensor<256x128xf16, #blocked>, %opB: tensor<128x256xf16, #dotOp1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %1:2 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst, %arg2 = %opA) -> (tensor<256x256xf32, #mma>, tensor<256x128xf16, #blocked>) : i32 { + %2 = ttg.convert_layout %arg2 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #dotOp0> + %3 = tt.dot %2, %opB, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + scf.yield %3, %arg2 : tensor<256x256xf32, #mma>, tensor<256x128xf16, #blocked> + } + tt.store %C_ptr, %1#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-instruction-sched.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-instruction-sched.mlir new file mode 100644 index 000000000..ad974b993 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -0,0 +1,169 @@ +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints='variant=llvm_iglp_0' -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints='variant=llvm_iglp_1' -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints='variant=local_prefetch' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints='variant=local_prefetch' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=16 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints='variant=local_prefetch' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='arch=gfx942 num_stages=2' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2 + +module { + // INSERT_IGLP0-LABEL: @test_dot_op + // INSERT_IGLP1-LABEL: @test_dot_op + // INSTR_COUNT_NS1-LABEL: @test_dot_op + // INSTR_COUNT_NS2-LABEL: @test_dot_op + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: @test_dot_op + // LABELING_PS_1-LABEL: @test_dot_op + // LABELING_PS_2-LABEL: @test_dot_op + tt.func @test_dot_op(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}, + %C : !tt.ptr {tt.divisibility = 16 : i32}) { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32> -> tensor<128x32xi32> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr>, tensor<128x32xi32> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32> -> tensor<32x128xi32> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr>, tensor<32x128xi32> + + %a_mask = arith.constant dense : tensor<128x32xi1> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16> + %b_mask = arith.constant dense : tensor<32x128xi1> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32> + + %a_off = arith.constant dense<4> : tensor<128x32xi32> + %b_off = arith.constant dense<4> : tensor<32x128xi32> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr>, tensor<32x128x!tt.ptr>, tensor<128x128xf32>) { + %a = tt.load %a_ptr : tensor<128x32x!tt.ptr> + %b = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr> + + // INSERT_IGLP0: rocdl.iglp.opt 0 + // INSERT_IGLP1: rocdl.iglp.opt 1 + + // INSTR_COUNT_NS1: amdgpu.instruction_sched_hint + // INSTR_COUNT_NS1-SAME: isBufferLoadsAEnabled = false + // INSTR_COUNT_NS1-SAME: isBufferLoadsBEnabled = false + // INSTR_COUNT_NS1-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS1-SAME: numDsWritesA = #amdgpu.InstCounter<0, none> + // INSTR_COUNT_NS1-SAME: numDsWritesB = #amdgpu.InstCounter<0, none> + // INSTR_COUNT_NS1-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> + + // INSTR_COUNT_NS2: amdgpu.instruction_sched_hint + // INSTR_COUNT_NS2-SAME: isBufferLoadsAEnabled = false + // INSTR_COUNT_NS2-SAME: isBufferLoadsBEnabled = false + // INSTR_COUNT_NS2-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS2-SAME: numDsWritesA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> + + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.barrier [[SCHED_GUARD:.+]] + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE:512]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA:8]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ:32]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ:256]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0 + // USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.barrier [[SCHED_GUARD]] + + + // LABELING_PS_1: scf.for + // LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_1: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} + // LABELING_PS_1: %[[REG1_OP0:.+]] = ttg.convert_layout %[[REG0_OP0]] + // LABELING_PS_1: %[[REG1_OP1:.+]] = ttg.convert_layout %[[REG0_OP1]] + // LABELING_PS_1: tt.dot %[[REG1_OP0]], %[[REG1_OP1]], {{.*}} + + // LABELING_PS_2: scf.for + // LABELING_PS_2: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} + // LABELING_PS_2: ttg.local_store %[[REG0_OP0]], %{{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: ttg.local_store %[[REG0_OP1]], %{{.*}} {OpIdx = #amdgpu.OpIdx<1>} + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr>, tensor<128x32xi32> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr>, tensor<32x128xi32> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr>, tensor<32x128x!tt.ptr>, tensor<128x128xf32> + } + + // C ptrs + %c_ptr_splat = tt.splat %C : !tt.ptr -> tensor<128x128x!tt.ptr> + %c_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32> + %c_tmp1 = tt.expand_dims %c_tmp0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %c_offs = tt.broadcast %c_tmp1 : tensor<1x128xi32> -> tensor<128x128xi32> + %c_ptr = tt.addptr %c_ptr_splat, %c_offs : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + + tt.store %c_ptr, %loop#2 : tensor<128x128x!tt.ptr> + tt.return +} +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-optimize-epilogue.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-optimize-epilogue.mlir new file mode 100644 index 000000000..8cc467e77 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-optimize-epilogue.mlir @@ -0,0 +1,42 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-optimize-epilogue | FileCheck %s + +// CHECK-LABEL: one_op_in_chain +// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> +// CHECK: tt.store %{{.*}}, %{{.*}} : tensor<32x32x!tt.ptr, #mma> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @one_op_in_chain(%arg0: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %2 = arith.truncf %1 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + tt.store %3, %2 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK-LABEL: two_ops_in_chain +// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> +// CHECK: tt.store %{{.*}}, %{{.*}} : tensor<32x32x!tt.ptr, #mma> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @two_ops_in_chain(%arg0: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %2 = math.exp2 %1 : tensor<32x32xf32, #blocked> + %3 = arith.truncf %2 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + tt.store %4, %3 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-range-analysis.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-range-analysis.mlir new file mode 100644 index 000000000..48018c479 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-range-analysis.mlir @@ -0,0 +1,1269 @@ +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -test-tritonamdgpu-range-analysis -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: tt.func @conversion1 +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion1(%arg0: !tt.ptr) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}} + // expected-remark@+1 {{non-neg}} + %numps = tt.get_num_programs x : i32 + %2 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %3 = tt.splat %2 : !tt.ptr -> tensor<1024x!tt.ptr> + %4 = tt.load %3 : tensor<1024x!tt.ptr> + tt.return %4 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @assumepid +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @assumepid(%arg0: !tt.ptr) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0 = arith.constant 0 : i32 + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 2147483647] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %pid = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %cmpsle = arith.cmpi sle, %pid, %c1024_i32 : i32 + llvm.intr.assume %cmpsle : i1 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %cmpsge = arith.cmpi sge, %pid, %c0 : i32 + llvm.intr.assume %cmpsge : i1 + // expected-remark@+2 {{unsigned : [0, 1048576] signed : [0, 1048576]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %pid, %c1024_i32 : i32 + %2 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %3 = tt.splat %2 : !tt.ptr -> tensor<1024x!tt.ptr> + %4 = tt.load %3 : tensor<1024x!tt.ptr> + tt.return %4 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @conversion2 +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion2(%arg0: !tt.ptr) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = tt.splat %3 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %2 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.load %5 : tensor<1024x!tt.ptr> + tt.return %6 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @conversion3 +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion3(%arg0: !tt.ptr) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5 = tt.addptr %3, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 2048] signed : [0, 2048]}} + // expected-remark@+1 {{non-neg}} + %7 = arith.addi %6, %4 : tensor<1024xi64> + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %10 = tt.load %9 : tensor<1024x!tt.ptr> + tt.return %10 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @conversion4 +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @conversion4(%arg0: !tt.ptr {tt.pointer_range = 32 : i32}) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = tt.addptr %3, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 2048] signed : [0, 2048]}} + // expected-remark@+1 {{non-neg}} + %5 = arith.addi %2, %2 : tensor<1024xi32> + %6 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %5 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forOp +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forOp(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0 = arith.constant 0 : index + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %c128 = arith.constant 128 : index + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1 = arith.constant 1 : index + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %3, %arg4 = %4, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %12 = tt.addptr %arg3, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %13 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+1 {{non-neg}} + %14 = arith.addi %13, %arg4 : tensor<1024xi64> + %15 = tt.splat %12 : !tt.ptr -> tensor<1024x!tt.ptr> + %16 = tt.addptr %15, %14 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %17 = tt.load %16 : tensor<1024x!tt.ptr> + %18 = arith.addf %17, %arg5 : tensor<1024xf32> + scf.yield %12, %14, %18 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %6 = tt.addptr %5#0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %7 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+1 {{non-neg}} + %8 = arith.addi %7, %5#1 : tensor<1024xi64> + %9 = tt.splat %6 : !tt.ptr -> tensor<1024x!tt.ptr> + %10 = tt.addptr %9, %8 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %11 = tt.load %10 : tensor<1024x!tt.ptr> + tt.return %11 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forOp2 +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forOp2(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %cst = arith.constant dense<0> : tensor<1024xi64> + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0 = arith.constant 0 : index + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %c128 = arith.constant 128 : index + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1 = arith.constant 1 : index + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %10 = tt.addptr %arg3, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %11 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+1 {{non-neg}} + %12 = arith.addi %11, %arg4 : tensor<1024xi64> + %13 = tt.splat %10 : !tt.ptr -> tensor<1024x!tt.ptr> + %14 = tt.addptr %13, %12 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %15 = tt.load %14 : tensor<1024x!tt.ptr> + %16 = arith.addf %15, %arg5 : tensor<1024xf32> + scf.yield %10, %12, %16 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %4 = tt.addptr %3#0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+1 {{non-neg}} + %6 = arith.addi %5, %3#1 : tensor<1024xi64> + %7 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %6 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %9 = tt.load %8 : tensor<1024x!tt.ptr> + tt.return %9 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forNested +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forNested(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %cst = arith.constant dense<0> : tensor<1024xi64> + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0 = arith.constant 0 : index + // expected-remark@+2 {{unsigned : [16, 16] signed : [16, 16]}} + // expected-remark@+1 {{non-neg}} + %c16 = arith.constant 16 : index + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1 = arith.constant 1 : index + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3:3 = scf.for %arg2 = %c0 to %c16 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %10:3 = scf.for %arg6 = %c0 to %c16 step %c1 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %11 = tt.addptr %arg7, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %12 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 262144] signed : [0, 262144]}} + // expected-remark@+1 {{non-neg}} + %13 = arith.addi %12, %arg8 : tensor<1024xi64> + %14 = tt.splat %11 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %16 = tt.load %15 : tensor<1024x!tt.ptr> + %17 = arith.addf %16, %arg9 : tensor<1024xf32> + scf.yield %11, %13, %17 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + scf.yield %10#0, %10#1, %10#2 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %4 = tt.addptr %3#0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 16384] signed : [0, 16384]}} + // expected-remark@+1 {{non-neg}} + %6 = arith.addi %5, %3#1 : tensor<1024xi64> + %7 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %6 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %9 = tt.load %8 : tensor<1024x!tt.ptr> + tt.return %9 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forNestedOverMaxTripCount +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forNestedOverMaxTripCount(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %cst = arith.constant dense<0> : tensor<1024xi64> + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0 = arith.constant 0 : index + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %c128 = arith.constant 128 : index + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1 = arith.constant 1 : index + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %10:3 = scf.for %arg6 = %c0 to %c128 step %c1 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %11 = tt.addptr %arg7, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %12 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} + %13 = arith.addi %12, %arg8 : tensor<1024xi64> + %14 = tt.splat %11 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %16 = tt.load %15 : tensor<1024x!tt.ptr> + %17 = arith.addf %16, %arg9 : tensor<1024xf32> + scf.yield %11, %13, %17 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + scf.yield %10#0, %10#1, %10#2 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %4 = tt.addptr %3#0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} + %6 = arith.addi %5, %3#1 : tensor<1024xi64> + %7 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %6 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %9 = tt.load %8 : tensor<1024x!tt.ptr> + tt.return %9 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @ifOp +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @ifOp(%arg0: !tt.ptr, %arg1: tensor<1024xf32>, %arg2: i1) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %cst = arith.constant dense<0> : tensor<1024xi64> + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3:2 = scf.if %arg2 -> (!tt.ptr, tensor<1024xi64>) { + %8 = tt.addptr %arg0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + scf.yield %8, %9 : !tt.ptr, tensor<1024xi64> + } else { + %8 = tt.addptr %arg0, %1 : !tt.ptr, i32 + scf.yield %8, %cst : !tt.ptr, tensor<1024xi64> + } + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %4 = arith.trunci %3#1 : tensor<1024xi64> to tensor<1024xi32> + %5 = tt.splat %3#0 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %7 = tt.load %6 : tensor<1024x!tt.ptr> + tt.return %7 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @condBranch +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @condBranch(%arg0: !tt.ptr, %arg1: i1) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %cst = arith.constant dense<0> : tensor<1024xi64> + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + cf.cond_br %arg1, ^bb1(%arg0, %cst : !tt.ptr, tensor<1024xi64>), ^bb2(%3, %4 : !tt.ptr, tensor<1024xi64>) + ^bb1(%5: !tt.ptr, %6: tensor<1024xi64>): // pred: ^bb0 + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %7 = arith.trunci %6 : tensor<1024xi64> to tensor<1024xi32> + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %9 : tensor<1024x!tt.ptr> + tt.return %10 : tensor<1024xf32> + ^bb2(%11: !tt.ptr, %12: tensor<1024xi64>): // pred: ^bb0 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %13 = arith.trunci %12 : tensor<1024xi64> to tensor<1024xi32> + %14 = tt.splat %11 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %16 = tt.load %15 : tensor<1024x!tt.ptr> + tt.return %16 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @branch +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @branch(%arg0: !tt.ptr, %arg1: i1) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = tt.splat %3 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.addptr %4, %2 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %6 = tt.load %5 : tensor<1024x!tt.ptr> + tt.return %6 : tensor<1024xf32> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @tile_offset(%arg0: !tt.ptr, %arg1: i32, %arg2: i32) -> tensor<16x256xf16, #blocked> { + // expected-remark@+2 {{unsigned : [256, 256] signed : [256, 256]}} + // expected-remark@+1 {{non-neg}} + %c256_i32 = arith.constant 256 : i32 + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 16776960] signed : [0, 16776960]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c256_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 256] signed : [0, 256]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+1 {{non-neg}} + %3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+1 {{non-neg}} + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %5 = tt.splat %arg2 : i32 -> tensor<16x1xi32, #blocked> + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %6 = arith.muli %4, %5 : tensor<16x1xi32, #blocked> + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %7 = tt.broadcast %6 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> + // expected-remark@+2 {{unsigned : [0, 256] signed : [0, 256]}} + // expected-remark@+1 {{non-neg}} + %8 = tt.expand_dims %2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + // expected-remark@+2 {{unsigned : [0, 256] signed : [0, 256]}} + // expected-remark@+1 {{non-neg}} + %9 = tt.broadcast %8 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %10 = arith.addi %7, %9 : tensor<16x256xi32, #blocked> + %11 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %12 = tt.splat %11 : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> + %13 = tt.addptr %12, %10 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> + %14 = tt.load %13 : tensor<16x256x!tt.ptr, #blocked> + tt.return %14 : tensor<16x256xf16, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) -> tensor<128x16xf16, #blocked> { + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %c128_i32 = arith.constant 128 : i32 + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 8388480] signed : [0, 8388480]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c128_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 128] signed : [0, 128]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+1 {{non-neg}} + %3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + // expected-remark@+2 {{unsigned : [0, 128] signed : [0, 128]}} + // expected-remark@+1 {{non-neg}} + %4 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %5 = arith.muli %1, %arg1 : i32 + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %6 = tt.splat %arg1 : i32 -> tensor<128x1xi32, #blocked> + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %7 = arith.muli %4, %6 : tensor<128x1xi32, #blocked> + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %8 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> + // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+1 {{non-neg}} + %9 = tt.expand_dims %3 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+1 {{non-neg}} + %10 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %11 = arith.addi %8, %10 : tensor<128x16xi32, #blocked> + %12 = tt.addptr %arg0, %5 : !tt.ptr, i32 + %13 = tt.splat %12 : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> + %14 = tt.addptr %13, %11 : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> + %15 = tt.load %14 : tensor<128x16x!tt.ptr, #blocked> + tt.return %15 : tensor<128x16xf16, #blocked> + } +} + +// ----- + +// CHECK-LABEL: tt.func @select +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @select(%arg0: !tt.ptr, %arg1: i1) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %cst = arith.constant dense<0> : tensor<1024xi64> + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5 = arith.select %arg1, %arg0, %3 : !tt.ptr + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %6 = arith.select %arg1, %cst, %4 : tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %7 = arith.trunci %6 : tensor<1024xi64> to tensor<1024xi32> + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %9 : tensor<1024x!tt.ptr> + tt.return %10 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @where_kernel +module attributes {"ttg.num-ctas" = 1 : i32} { + tt.func @where_kernel(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i8) -> tensor<1024xi64> { + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0_i8 = arith.constant 0 : i8 + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}} + %3 = arith.cmpi ne, %arg2, %c0_i8 : i8 + %4 = arith.select %3, %arg0, %arg1 : !tt.ptr + %5 = tt.addptr %4, %1 : !tt.ptr, i32 + %6 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> + %7 = tt.addptr %6, %2 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} + %8 = tt.load %7 : tensor<1024x!tt.ptr> + tt.return %8 : tensor<1024xi64> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forOpWithHints +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forOpWithHints(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0 = arith.constant 0 : index + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1 = arith.constant 1 : index + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %c128 = arith.constant 128 : index + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %2 = tt.addptr %arg0, %0 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %3 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> + %4:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %2, %arg4 = %3, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + // expected-remark@+2 {{unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+1 {{non-neg}} + %11 = arith.trunci %arg4 : tensor<1024xi64> to tensor<1024xi32> + %12 = tt.splat %arg3 : !tt.ptr -> tensor<1024x!tt.ptr> + %13 = tt.addptr %12, %11 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %14 = tt.load %13 : tensor<1024x!tt.ptr> + %15 = tt.addptr %arg3, %0 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %16 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+1 {{non-neg}} + %17 = arith.addi %16, %arg4 : tensor<1024xi64> + %18 = tt.addptr %15, %0 : !tt.ptr, i32 + %19 = arith.addf %14, %arg5 : tensor<1024xf32> + scf.yield %18, %17, %19 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } {tt.divisibility_arg1 = dense<16> : tensor<1xi32>, tt.divisibility_arg2 = dense<16> : tensor<1xi32>} + %5 = tt.addptr %4#0, %0 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %6 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+1 {{non-neg}} + %7 = arith.addi %6, %4#1 : tensor<1024xi64> + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %10 = tt.load %9 : tensor<1024x!tt.ptr> + tt.return %10 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func public @scalar_pointers +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func public @scalar_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0_i64 = arith.constant 0 : i64 + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1_i32 = arith.constant 1 : i32 + // expected-remark@+2 {{unsigned : [100, 100] signed : [100, 100]}} + // expected-remark@+1 {{non-neg}} + %c100_i32 = arith.constant 100 : i32 + %0 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + %1 = scf.for %arg1 = %c1_i32 to %c100_i32 step %c1_i32 iter_args(%arg2 = %0) -> (!tt.ptr) : i32 { + tt.store %arg2, %c0_i64 : !tt.ptr + %2 = tt.addptr %arg2, %c1_i32 : !tt.ptr, i32 + scf.yield %2 : !tt.ptr + } + tt.return + } +} + +// ----- + +// CHECK-LABEL: tt.func @scalar_if +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @scalar_if(%arg0: !tt.ptr, %arg1: tensor<1024xf32>, %arg2: i1) -> f32 { + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1_i32 = arith.constant 1 : i32 + // expected-remark@+2 {{unsigned : [100, 100] signed : [100, 100]}} + // expected-remark@+1 {{non-neg}} + %c100_i32 = arith.constant 100 : i32 + %0 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + %1 = scf.if %arg2 -> (!tt.ptr) { + %3 = tt.addptr %0, %c1_i32 : !tt.ptr, i32 + scf.yield %3 : !tt.ptr + } else { + %3 = tt.addptr %0, %c100_i32 : !tt.ptr, i32 + scf.yield %3 : !tt.ptr + } + %2 = tt.load %1 : !tt.ptr + tt.return %2 : f32 + } +} + +// ----- + +// CHECK-LABEL: tt.func @scalar_cond_branch +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @scalar_cond_branch(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i1) -> f32 { + cf.cond_br %arg2, ^bb1(%arg0 : !tt.ptr), ^bb2(%arg1 : !tt.ptr) + ^bb1(%0: !tt.ptr): // pred: ^bb0 + %1 = tt.load %0 : !tt.ptr + tt.return %1 : f32 + ^bb2(%2: !tt.ptr): // pred: ^bb0 + %3 = tt.load %2 : !tt.ptr + tt.return %3 : f32 + } +} + +// ----- + +// CHECK-LABEL: tt.func @flipFlopForOpSimple +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @flipFlopForOpSimple(%arg0: !tt.ptr, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0 = arith.constant 0 : index + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %c128 = arith.constant 128 : index + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1 = arith.constant 1 : index + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %7:5 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %5, %arg4 = %6, %arg5 = %3, %arg6 = %4, %arg7 = %arg1) -> (!tt.ptr, tensor<1024xi64>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %14 = tt.addptr %arg5, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %15 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+1 {{non-neg}} + %16 = arith.addi %15, %arg6 : tensor<1024xi64> + %17 = tt.splat %14 : !tt.ptr -> tensor<1024x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %19 = tt.load %18 : tensor<1024x!tt.ptr> + %20 = arith.addf %19, %arg7 : tensor<1024xf32> + scf.yield %14, %16, %arg3, %arg4, %20 : !tt.ptr, tensor<1024xi64>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %8 = tt.addptr %7#0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+1 {{non-neg}} + %10 = arith.addi %9, %7#1 : tensor<1024xi64> + %11 = tt.splat %8 : !tt.ptr -> tensor<1024x!tt.ptr> + %12 = tt.addptr %11, %10 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %13 = tt.load %12 : tensor<1024x!tt.ptr> + tt.return %13 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @flipFlopForOpComplex +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @flipFlopForOpComplex(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: tensor<1024xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) { + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0 = arith.constant 0 : index + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %c128 = arith.constant 128 : index + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1 = arith.constant 1 : index + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5 = tt.addptr %arg1, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %7:6 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %3, %arg5 = %4, %arg6 = %arg2, %arg7 = %5, %arg8 = %6, %arg9 = %arg2) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %20 = tt.addptr %arg4, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %21 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+1 {{non-neg}} + %22 = arith.addi %21, %arg5 : tensor<1024xi64> + %23 = tt.splat %20 : !tt.ptr -> tensor<1024x!tt.ptr> + %24 = tt.addptr %23, %22 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %25 = tt.load %24 : tensor<1024x!tt.ptr> + %26 = arith.addf %25, %arg6 : tensor<1024xf32> + %27 = tt.addptr %arg7, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %28 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+1 {{non-neg}} + %29 = arith.addi %28, %arg8 : tensor<1024xi64> + %30 = tt.splat %27 : !tt.ptr -> tensor<1024x!tt.ptr> + %31 = tt.addptr %30, %29 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %32 = tt.load %31 : tensor<1024x!tt.ptr> + %33 = arith.addf %32, %arg9 : tensor<1024xf32> + scf.yield %27, %29, %33, %20, %22, %26 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %8 = tt.addptr %7#0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+1 {{non-neg}} + %10 = arith.addi %9, %7#1 : tensor<1024xi64> + %11 = tt.splat %8 : !tt.ptr -> tensor<1024x!tt.ptr> + %12 = tt.addptr %11, %10 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %13 = tt.load %12 : tensor<1024x!tt.ptr> + %14 = tt.addptr %7#3, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %15 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+1 {{non-neg}} + %16 = arith.addi %15, %7#4 : tensor<1024xi64> + %17 = tt.splat %14 : !tt.ptr -> tensor<1024x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %19 = tt.load %18 : tensor<1024x!tt.ptr> + tt.return %13, %19 : tensor<1024xf32>, tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @forOpDynamicKBound +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @forOpDynamicKBound(%arg0: !tt.ptr, %arg1: tensor<1024xf32>, %K: index) -> tensor<1024xf32> { + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0 = arith.constant 0 : index + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %c128 = arith.constant 128 : index + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1 = arith.constant 1 : index + // expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_program_id x : i32 + // expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}} + // expected-remark@+1 {{non-neg}} + %1 = arith.muli %0, %c1024_i32 : i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + %5:3 = scf.for %arg2 = %c0 to %c128 step %K iter_args(%arg3 = %3, %arg4 = %4, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { + %12 = tt.addptr %arg3, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %13 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} + %14 = arith.addi %13, %arg4 : tensor<1024xi64> + %15 = tt.splat %12 : !tt.ptr -> tensor<1024x!tt.ptr> + %16 = tt.addptr %15, %14 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %17 = tt.load %16 : tensor<1024x!tt.ptr> + %18 = arith.addf %17, %arg5 : tensor<1024xf32> + scf.yield %12, %14, %18 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> + } + %6 = tt.addptr %5#0, %1 : !tt.ptr, i32 + // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+1 {{non-neg}} + %7 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> + // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} + %8 = arith.addi %7, %5#1 : tensor<1024xi64> + %9 = tt.splat %6 : !tt.ptr -> tensor<1024x!tt.ptr> + %10 = tt.addptr %9, %8 : tensor<1024x!tt.ptr>, tensor<1024xi64> + %11 = tt.load %10 : tensor<1024x!tt.ptr> + tt.return %11 : tensor<1024xf32> + } +} + +// ----- + +// CHECK-LABEL: tt.func @DynamicKBound +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @DynamicKBound(%K: i32) { + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %c128 = arith.constant 128 : i32 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %cmp = arith.cmpi sle, %K, %c128 : i32 + llvm.intr.assume %cmp : i1 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %condtest = arith.cmpi sle, %K, %c1024_i32 : i32 + tt.return + } +} + +// ----- + +// CHECK-LABEL: tt.func @unsupportedAssumption +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @unsupportedAssumption(%K: i32) { + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %c128 = arith.constant 128 : i32 + // expected-remark@+2 {{unsigned : [0, 1] signed : [-1, 0]}} + // expected-remark@+1 {{unsigned arithmetic not currently supported}} + %cmp = arith.cmpi ule, %K, %c128 : i32 + llvm.intr.assume %cmp : i1 + // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}} + %condtest = arith.cmpi sle, %K, %c1024_i32 : i32 + tt.return + } +} + +// ----- + +// CHECK-LABEL: tt.func @moreDynamicKBound +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @moreDynamicKBound( + %Keqlhs: i32, + %Ksgelhs: i32, + %Ksgtlhs: i32, + %Kslelhs: i32, + %Ksltlhs: i32, + %Keqrhs: i32, + %Ksgerhs: i32, + %Ksgtrhs: i32, + %Kslerhs: i32, + %Ksltrhs: i32 + ) { + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0 = arith.constant 0 : i32 + // expected-remark@+2 {{unsigned : [16, 16] signed : [16, 16]}} + // expected-remark@+1 {{non-neg}} + %c16 = arith.constant 16 : i32 + // expected-remark@+2 {{unsigned : [32, 32] signed : [32, 32]}} + // expected-remark@+1 {{non-neg}} + %c32 = arith.constant 32 : i32 + // expected-remark@+2 {{unsigned : [64, 64] signed : [64, 64]}} + // expected-remark@+1 {{non-neg}} + %c64 = arith.constant 64 : i32 + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %c128 = arith.constant 128 : i32 + // expected-remark@+2 {{unsigned : [256, 256] signed : [256, 256]}} + // expected-remark@+1 {{non-neg}} + %c256 = arith.constant 256 : i32 + // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}} + // expected-remark@+1 {{non-neg}} + %c1024_i32 = arith.constant 1024 : i32 + + //// eq comparison + + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %assumeeqlhs = arith.cmpi eq, %Keqlhs, %c128 : i32 + llvm.intr.assume %assumeeqlhs : i1 + // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}} + // expected-remark@+1 {{non-neg}} + %testeqlhs1 = arith.addi %Keqlhs, %c0 : i32 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %testeqlhs2 = arith.cmpi ne, %Keqlhs, %c256 : i32 + + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %assumeeqrhs = arith.cmpi eq, %c64, %Keqrhs : i32 + llvm.intr.assume %assumeeqrhs : i1 + // expected-remark@+2 {{unsigned : [64, 64] signed : [64, 64]}} + // expected-remark@+1 {{non-neg}} + %testeqrhs1 = arith.addi %Keqrhs, %c0 : i32 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %testeqrhs2 = arith.cmpi ne, %Keqrhs, %c256 : i32 + + //// sge comparison + + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %assumesgelhs = arith.cmpi sge, %Ksgelhs, %c128 : i32 + llvm.intr.assume %assumesgelhs : i1 + // expected-remark@+2 {{unsigned : [128, 2147483647] signed : [128, 2147483647]}} + // expected-remark@+1 {{non-neg}} + %testsgelhs1 = arith.addi %Ksgelhs, %c0 : i32 + // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}} + %testsgelhs2 = arith.cmpi sge, %Ksgelhs, %c1024_i32 : i32 + + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %assumesgerhs = arith.cmpi sge, %c128, %Ksgerhs : i32 + llvm.intr.assume %assumesgerhs : i1 + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 128]}} + %testsgerhs1 = arith.addi %Ksgerhs, %c0 : i32 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %testsgerhs2 = arith.cmpi sge, %c1024_i32, %Ksgerhs : i32 + + //// sgt comparison + + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %assumesgtlhs = arith.cmpi sgt, %Ksgtlhs, %c128 : i32 + llvm.intr.assume %assumesgtlhs : i1 + // expected-remark@+2 {{unsigned : [129, 2147483647] signed : [129, 2147483647]}} + // expected-remark@+1 {{non-neg}} + %testsgtlhs1 = arith.addi %Ksgtlhs, %c0 : i32 + // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}} + %testsgtlhs2 = arith.cmpi sgt, %Ksgtlhs, %c1024_i32 : i32 + + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %assumesgtrhs = arith.cmpi sgt, %c128, %Ksgtrhs : i32 + llvm.intr.assume %assumesgtrhs : i1 + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 127]}} + %testsgtrhs1 = arith.addi %Ksgtrhs, %c0 : i32 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %testsgtrhs2 = arith.cmpi sgt, %c1024_i32, %Ksgtrhs : i32 + + //// sle comparison + + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %assumeslelhs = arith.cmpi sle, %Kslelhs, %c128 : i32 + llvm.intr.assume %assumeslelhs : i1 + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 128]}} + %testslelhs1 = arith.addi %Kslelhs, %c0 : i32 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %testslelhs2 = arith.cmpi sle, %Kslelhs, %c1024_i32 : i32 + + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %assumeslerhs = arith.cmpi sle, %c128, %Kslerhs : i32 + llvm.intr.assume %assumeslerhs : i1 + // expected-remark@+2 {{unsigned : [128, 2147483647] signed : [128, 2147483647]}} + // expected-remark@+1 {{non-neg}} + %testslerhs1 = arith.addi %Kslerhs, %c0 : i32 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %testslerhs2 = arith.cmpi sle, %c64, %Kslerhs : i32 + + //// slt comparison + + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %assumesltlhs = arith.cmpi slt, %Ksltlhs, %c128 : i32 + llvm.intr.assume %assumesltlhs : i1 + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 127]}} + %testsltlhs1 = arith.addi %Ksltlhs, %c0 : i32 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %testsltlhs2 = arith.cmpi slt, %Ksltlhs, %c1024_i32 : i32 + + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %assumesltrhs = arith.cmpi slt, %c128, %Ksltrhs : i32 + llvm.intr.assume %assumesltrhs : i1 + // expected-remark@+2 {{unsigned : [129, 2147483647] signed : [129, 2147483647]}} + // expected-remark@+1 {{non-neg}} + %testsltrhs1 = arith.addi %Ksltrhs, %c0 : i32 + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %testsltrhs2 = arith.cmpi slt, %c64, %Ksltrhs : i32 + + tt.return + } +} + +// ----- + + +// CHECK-LABEL: join_cat_transitive_nonneg +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr, %arg1: !tt.ptr) { + // expected-remark@+2 {{unsigned : [0, 8] signed : [0, 8]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + // expected-remark@+2 {{unsigned : [2, 10] signed : [2, 10]}} + // expected-remark@+1 {{non-neg}} + %1 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32> + // expected-remark@+2 {{unsigned : [0, 10] signed : [0, 10]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.join %0, %1 : tensor<8xi32> -> tensor<8x2xi32> + // expected-remark@+2 {{unsigned : [0, 4] signed : [0, 4]}} + // expected-remark@+1 {{non-neg}} + %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + // expected-remark@+2 {{unsigned : [4, 8] signed : [4, 8]}} + // expected-remark@+1 {{non-neg}} + %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32> + // expected-remark@+2 {{unsigned : [0, 8] signed : [0, 8]}} + // expected-remark@+1 {{non-neg}} + %5 = tt.join %3, %4 : tensor<4xi32> -> tensor<4x2xi32> + // expected-remark@+2 {{unsigned : [0, 8] signed : [0, 8]}} + // expected-remark@+1 {{non-neg}} + %6 = tt.cat %5, %5 : tensor<4x2xi32> -> tensor<8x2xi32> + // expected-remark@+2 {{unsigned : [0, 18] signed : [0, 18]}} + // expected-remark@+1 {{non-neg}} + %7 = arith.addi %2, %6 : tensor<8x2xi32> + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %zeros = arith.constant dense<0> : tensor<8x1xi32> + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %ones = arith.constant dense<1> : tensor<8x1xi32> + // expected-remark@+2 {{unsigned : [0, 18] signed : [0, 18]}} + // expected-remark@+1 {{non-neg}} + %8 = tt.gather %7[%zeros] {axis = 1 : i32} : (tensor<8x2xi32>, tensor<8x1xi32>) -> tensor<8x1xi32> + // expected-remark@+2 {{unsigned : [0, 18] signed : [0, 18]}} + // expected-remark@+1 {{non-neg}} + %9 = tt.gather %7[%ones] {axis = 1 : i32} : (tensor<8x2xi32>, tensor<8x1xi32>) -> tensor<8x1xi32> + // expected-remark@+2 {{unsigned : [0, 36] signed : [0, 36]}} + // expected-remark@+1 {{non-neg}} + %10 = arith.addi %8, %9 : tensor<8x1xi32> + // expected-remark@+2 {{unsigned : [0, 36] signed : [0, 36]}} + // expected-remark@+1 {{non-neg}} + %11 = tt.reshape %10 allow_reorder : tensor<8x1xi32> -> tensor<8xi32> + tt.return + } +} + +// ----- + +// CHECK-LABEL: histo_nonneg +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func @histo_nonneg(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2 : tensor<256xi32>) { + // expected-remark@+2 {{unsigned : [0, 4294967295] signed : [0, -1]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.histogram %arg2 : tensor<256xi32> -> tensor<8xi32> + // expected-remark@+2 {{unsigned : [0, 8] signed : [0, 8]}} + // expected-remark@+1 {{non-neg}} + %1 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + tt.return + } +} + +// ----- + +// CHECK-LABEL: get_num_prog_nonneg +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func @get_num_prog_nonneg(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2 : i32) { + // expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}} + // expected-remark@+1 {{non-neg}} + %0 = tt.get_num_programs x : i32 + // expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}} + // expected-remark@+1 {{non-neg}} + %1 = tt.get_num_programs y : i32 + // expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.get_num_programs z : i32 + // expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}} + // expected-remark@+1 {{non-neg}} + %3 = arith.minsi %0, %1 : i32 + // expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}} + // expected-remark@+1 {{non-neg}} + %4 = arith.minsi %2, %3 : i32 + // expected-remark@+2 {{unsigned : [0, 2147483647] signed : [0, 2147483647]}} + // expected-remark@+1 {{non-neg}} + %5 = arith.maxsi %arg2, %4 : i32 + // expected-remark@+2 {{[0, 2147483647] signed : [0, 2147483647]}} + // expected-remark@+1 {{non-neg}} + %6 = tt.splat %5 : i32 -> tensor<8xi32> + // expected-remark@+2 {{unsigned : [0, 8] signed : [0, 8]}} + // expected-remark@+1 {{non-neg}} + %7 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + // expected-remark@+1 {{unsigned : [0, 2147483655] signed : [-2147483648, 2147483647]}} + %8 = arith.addi %6, %7 : tensor<8xi32> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-reorder-instructions.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-reorder-instructions.mlir new file mode 100644 index 000000000..e481b6742 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -0,0 +1,450 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +#smem = #ttg.shared_memory +// CHECK-LABEL: order_load_alloc_local_load_local_store +// CHECK: %[[LOAD:.+]] = tt.load +// CHECK: %[[ALLOC:.+]] = ttg.local_alloc +// CHECK: ttg.local_store %[[LOAD]], %[[ALLOC]] +// CHECK: ttg.local_load %[[ALLOC]] +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @order_load_alloc_local_load_local_store(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %10 = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + ttg.local_store %9, %10 : tensor<32x32xf32, #blocked> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %11 = ttg.local_load %10 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- +// Move loads (and independent local_stores) as early as possible. +// For example in the matmul_loop below, the scf.for loop looks like this after pipeliner: +// scf.for ... { +// // stage 1 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// // stage 0 +// %aptr = tt.addptr %aptr, %k +// %a_next = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next = tt.load %bptr +// tt.local_store %a_next +// tt.local_store %b_next +// yield +// } +// +// Solution for num_stages=2 : +// scf.for ... { +// // stage 0.a +// %aptr = tt.addptr %aptr, %k +// %a_next = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next = tt.load %bptr +// // stage 1 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// // stage 0.b +// tt.local_store %a_next +// tt.local_store %b_next +// yield +// } +// +// Solution for num_stages=3 (double-buffered) : +// scf.for ... { +// // stage 1 +// tt.local_store %a_next_1 +// tt.local_store %b_next_1 +// // stage 0 +// %aptr = tt.addptr %aptr, %k +// %a_next_2 = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next_2 = tt.load %bptr +// // stage 2 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// yield +// } + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared2 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}> +#shared3 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared4 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32, ttg.target = "hip:gfx942"} { + +// CHECK-LABEL: tt.func @matmul_loop +// CHECK: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) +// Stage 0.a +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]] +// CHECK: %[[ADDPTR_25:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_27:.*]] = tt.load %[[ADDPTR_25]], %[[SPLAT_26]] +// Stage 1 +// CHECK: %[[LOCAL_LOAD_28:.*]] = ttg.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_29:.*]] = ttg.local_load %[[ARG11]] +// CHECK: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}} +// CHECK: %[[DOT_31:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %[[ARG8]] +// Stage 0.b +// CHECK: %[[ADDI_32:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ADDI_32]], %{{.*}} +// CHECK: %[[SELECT_34:.*]] = arith.select %[[CMPI_33]], %[[ADDI_32]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_35:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_35]] +// CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: scf.yield %[[ADDPTR_20]], %[[ADDPTR_25]], %[[DOT_31]], %[[SELECT_34]], %[[MEMDESC_SUBVIEW_35]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: } + + tt.func @matmul_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %10 = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> + %11 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> + %12 = arith.cmpi slt, %arg0, %arg1 : index + %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> + %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> + %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> + %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + ttg.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + %19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>) { + %20 = arith.subi %arg1, %arg2 : index + %21 = arith.cmpi slt, %arg5, %20 : index + %22 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %23 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = arith.mulf %23, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %25 = tt.dot %22, %24, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %26 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %27 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %28 = tt.splat %21 : i1 -> tensor<128x32xi1, #blocked1> + %29 = tt.load %26, %28 : tensor<128x32x!tt.ptr, #blocked1> + %30 = tt.splat %21 : i1 -> tensor<32x128xi1, #blocked> + %31 = tt.load %27, %30, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %32 = arith.addi %arg9, %c1_i32 : i32 + %33 = arith.cmpi slt, %32, %c1_i32 : i32 + %34 = arith.select %33, %32, %c0_i32 : i32 + %35 = ttg.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + ttg.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + %36 = ttg.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + ttg.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + } + ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> + ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> + tt.return %19#2 : tensor<128x128xf32, #mma> + } + + +// This example tests that tt.load overlaps with independent ttg.local_store which +// overlaps with independent tt.dot. +// num_stages == 3, double buffered + +// CHECK-LABEL: tt.func @matmul_loop_mb +// CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) +// Stage 1 +// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} +// CHECK: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_31:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_31]] +// CHECK: %[[MEMDESC_SUBVIEW_32:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_32]] +// Stage 1 +// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[MULI_34:.*]] = arith.muli %{{.*}}, %{{.*}} +// CHECK: %[[SUBI_35:.*]] = arith.subi %{{.*}}, %[[MULI_34]] +// CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_35]] +// CHECK: %[[SPLAT_37:.*]] = tt.splat %[[CMPI_36]] +// CHECK: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_37]] +// CHECK: %[[ADDPTR_39:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[SPLAT_40:.*]] = tt.splat %[[CMPI_36]] +// CHECK: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_40]] +// Stage 2 +// CHECK: %[[LOCAL_LOAD_42:.*]] = ttg.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_43:.*]] = ttg.local_load %[[ARG11]] +// CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}} +// CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]] +// CHECK: scf.yield %[[ADDPTR_33]], %[[ADDPTR_39]], %[[DOT_45]], %[[SELECT_30]], %[[MEMDESC_SUBVIEW_31]], %[[MEMDESC_SUBVIEW_32]], %[[LOAD_38]], %[[LOAD_41]] +// CHECK: } + + tt.func @matmul_loop_mb(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + %c2 = arith.constant 2 : index + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %10 = ttg.local_alloc : () -> !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> + %11 = ttg.local_alloc : () -> !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> + %12 = arith.cmpi slt, %arg0, %arg1 : index + %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> + %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> + %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> + %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %17 = arith.addi %arg0, %arg2 : index + %18 = arith.cmpi slt, %17, %arg1 : index + %19 = tt.addptr %4, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %20 = tt.addptr %9, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %21 = tt.splat %18 : i1 -> tensor<128x32xi1, #blocked1> + %22 = tt.load %19, %21 : tensor<128x32x!tt.ptr, #blocked1> + %23 = tt.splat %18 : i1 -> tensor<32x128xi1, #blocked> + %24 = tt.load %20, %23, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %25 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + ttg.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + %26 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + ttg.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + %27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) { + %28 = arith.muli %arg2, %c2 : index + %29 = arith.subi %arg1, %28 : index + %30 = arith.cmpi slt, %arg5, %29 : index + %31 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %32 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %33 = arith.mulf %32, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %35 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %36 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %37 = tt.splat %30 : i1 -> tensor<128x32xi1, #blocked1> + %38 = tt.load %35, %37 : tensor<128x32x!tt.ptr, #blocked1> + %39 = tt.splat %30 : i1 -> tensor<32x128xi1, #blocked> + %40 = tt.load %36, %39, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %41 = arith.addi %arg9, %c1_i32 : i32 + %42 = arith.cmpi slt, %41, %c2_i32 : i32 + %43 = arith.select %42, %41, %c0_i32 : i32 + %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + ttg.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> + %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + ttg.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> + scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked> + } + ttg.local_dealloc %10 : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> + ttg.local_dealloc %11 : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> + tt.return %27#2 : tensor<128x128xf32, #mma> + } + +// This example shows dependent loads and verifies all are moved early. +// CHECK-LABEL: tt.func @indirect_bmm_vector +// CHECK: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) +// Stage 0 +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]] +// Stage 1.a +// CHECK: %[[EXPAND_DIMS_25:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} +// CHECK: %[[BROADCAST_26:.*]] = tt.broadcast %[[EXPAND_DIMS_25]] +// CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %[[BROADCAST_26]] +// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}, %[[MULI_27]] +// CHECK: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_29]] +// CHECK: %[[ADDPTR_31:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[SUBI_32:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_32]] +// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_33]] +// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_34]] +// Stage 2 +// CHECK: %[[LOCAL_LOAD_36:.*]] = ttg.local_load %[[ARG11]] +// CHECK: %[[LOCAL_LOAD_37:.*]] = ttg.local_load %[[ARG12]] +// CHECK: %[[DOT_38:.*]] = tt.dot %[[LOCAL_LOAD_36]], %[[LOCAL_LOAD_37]], %[[ARG7]] +// Stage 1.b +// CHECK: %[[ADDI_39:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_40:.*]] = arith.cmpi slt, %[[ADDI_39]], %{{.*}} +// CHECK: %[[SELECT_41:.*]] = arith.select %[[CMPI_40]], %[[ADDI_39]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_42:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_42]] +// CHECK: %[[MEMDESC_SUBVIEW_43:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: ttg.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_43]] +// CHECK: scf.yield %[[DOT_38]], %[[ADDPTR_20]], %[[ADDPTR_31]], %[[SELECT_41]], %[[MEMDESC_SUBVIEW_42]], %[[MEMDESC_SUBVIEW_43]], %[[LOAD_35]] +// CHECK: } + + tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { + %c2 = arith.constant 2 : index + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<1> : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> + %2 = arith.cmpi sgt, %arg1, %c0 : index + %3 = tt.splat %2 : i1 -> tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = tt.load %arg3, %3 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = arith.cmpi sgt, %arg1, %c1 : index + %6 = tt.addptr %arg3, %cst_0 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %7 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked1> + %8 = tt.load %arg2, %7 : tensor<16x16x!tt.ptr, #blocked1> + %9 = tt.expand_dims %4 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %10 = tt.broadcast %9 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %11 = arith.muli %arg0, %10 : tensor<16x16xi64, #blocked> + %12 = tt.addptr %arg5, %11 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %13 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked> + %14 = tt.load %12, %13 : tensor<16x16x!tt.ptr, #blocked> + %15 = tt.splat %5 : i1 -> tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.load %6, %15 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = ttg.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + ttg.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + %18 = ttg.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + ttg.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + %19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>>) { + %20 = arith.subi %arg1, %c2 : index + %21 = arith.cmpi slt, %arg6, %20 : index + %22 = arith.subi %arg1, %c1 : index + %23 = arith.cmpi slt, %arg6, %22 : index + %24 = ttg.local_load %arg11 : !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %25 = ttg.local_load %arg12 : !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %26 = tt.dot %24, %25, %arg7 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %27 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> + %28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %29 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked1> + %30 = tt.load %27, %29 : tensor<16x16x!tt.ptr, #blocked1> + %31 = tt.expand_dims %arg13 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %32 = tt.broadcast %31 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %33 = arith.muli %arg0, %32 : tensor<16x16xi64, #blocked> + %34 = tt.addptr %arg5, %33 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %35 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked> + %36 = tt.load %34, %35 : tensor<16x16x!tt.ptr, #blocked> + %37 = tt.splat %21 : i1 -> tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked}>> + %38 = tt.load %28, %37 : tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>> + %39 = arith.addi %arg10, %c1_i32 : i32 + %40 = arith.cmpi slt, %39, %c1_i32 : i32 + %41 = arith.select %40, %39, %c0_i32 : i32 + %42 = ttg.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + ttg.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + %43 = ttg.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + ttg.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> + scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + } + ttg.local_dealloc %0 : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> + tt.return %19#0 : tensor<16x16xf32, #mma> + } +} + +// ----- + +// CHECK-LABEL: sink_convert_dealloc +// CHECK-COUNT-2: ttg.local_dealloc %{{.+}} : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> +// CHECK: ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} { + %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + %2 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> + ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + %3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: anchor_barrier +// CHECK: gpu.barrier +// CHECK: tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @anchor_barrier(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + gpu.barrier + %2 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %1 = ttg.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + tt.return + } +} + + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: dont_hoist_scf_ops + // Make sure we don't hoist scf ops above its dependencies. + tt.func public @dont_hoist_scf_ops(%init: tensor<256x128xf32, #mfma>, + %base: tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, + %p1: tensor<128x128x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>, %i1: i1) -> (tensor<256x128xf32, #mfma>) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c4_i32 = arith.constant 4 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + // CHECK: scf.for + %54 = scf.for %arg21 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg = %init) -> (tensor<256x128xf32, #mfma>) : i32 { + // CHECK: arith.addi + %f = arith.addi %arg21, %c128_i32 : i32 + // CHECK: scf.if + // CHECK: tt.load + %p0 = scf.if %i1 -> tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>{ + %t = tt.splat %f : i32 -> tensor<256x128xi32> + %padd = tt.addptr %base, %t : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, tensor<256x128xi32> + scf.yield %padd : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + } else { + scf.yield %base : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + } + %l = tt.load %p0 : tensor<256x128x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %r = tt.load %p1 : tensor<128x128x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %acc = tt.dot %l, %r, %arg : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + scf.yield %acc : tensor<256x128xf32, #mfma> + } + tt.return %54 : tensor<256x128xf32, #mfma> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-sched-2nd-load.mlir new file mode 100644 index 000000000..d9a2e9965 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-sched-2nd-load.mlir @@ -0,0 +1,376 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s + +// Check the logic of sched-2nd-load optimizations +// + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + +// Category 1: Single dot with two loads, we make sure the optimization is applied when tile size is large enough +// The following tile sizes should apply the optimization +// 256x256x128 +// 256x256x64 +// The following tile sizes should NOT apply the optimization +// 256x64x128 +// 256x256x32 +// + +// Should apply: tile size 256x256x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x256xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> tensor<256x128xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !ttg.memdesc<128x256xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + +// Should apply: tile size 256x256x128 with nested single dot +// CHECK-LABEL: nested_sink_2nd_load_256x256x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @nested_sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x256xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + scf.for %arg2 = %c0 to %c1 step %c1 : i32 { + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> tensor<256x128xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !ttg.memdesc<128x256xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + } + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + +// Should apply: tile size 256x256x64 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x64 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x64xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x64xf16, #shared, #smem, mutable> -> tensor<256x64xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<64x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + +// Should NOT apply: tile size 256x64x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x64x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x64x!tt.ptr, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> tensor<256x128xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x64xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + +// Should NOT apply: tile size 256x256x32 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x32 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +// CHECK-NEXT: ttg.local_store %[[tileB]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr, #blocked>, %B_ptr: tensor<32x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x32xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<32x256xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x32xf16, #shared, #smem, mutable> -> tensor<256x32xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<32x256xf16, #shared1, #smem, mutable> -> tensor<32x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !ttg.memdesc<256x32xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !ttg.memdesc<32x256xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory + +// Category 2: single dot with two loads and tile size is large enough (128x128x128). +// We make sure the move is legal. +// Should NOT apply: the 2nd load has a user before the dot +// CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.store +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store %[[tileA]] +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr, #blocked>, %C_ptr: tensor<128x128x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x128xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<128x128xf16, #shared, #smem, mutable> -> tensor<128x128xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x128xf16, #shared1, #smem, mutable> -> tensor<128x128xf16, #dotOp1> + tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr, #blocked> + %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + scf.yield %3 : tensor<128x128xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr, #mma> + tt.return + } +} + + +// ----- + +// Category 3: two dots in the for loop. Make sure the optimization is not applied +// should NOT apply: two dots +// CHECK-LABEL: sink_2nd_load_256x256x64_two_dot +// CHECK: tt.load +// CHECK-NEXT: tt.load +// CHECK-NEXT: ttg.local_load +// CHECK-NEXT: ttg.local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: tt.dot +// CHECK-NEXT: ttg.local_store +// CHECK-NEXT: ttg.local_store +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !ttg.memdesc<256x64xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x64xf16, #shared, #smem, mutable> -> tensor<256x64xf16, #dotOp0> + %2 = ttg.local_load %B_LDS : !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<64x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %6 = tt.dot %1, %2, %3 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + ttg.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable> + ttg.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + + +// ----- + +// Category 4: load a scalar. Make sure the optimization is not applied +// should NOT apply: load scalar +// CHECK-LABEL: sink_2nd_load_scalar +// CHECK: tt.load +// CHECK-NEXT: tt.load +// CHECK-NEXT: tt.splat +// CHECK-NEXT: tt.broadcast +// CHECK-NEXT: ttg.convert_layout +// CHECK-NEXT: ttg.convert_layout +// CHECK-NEXT: tt.dot +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { +tt.func public @sink_2nd_load_scalar(%A_ptr: !tt.ptr, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = tt.load %A_ptr : !tt.ptr + %2 = tt.splat %1 : f16 -> tensor<1x64xf16, #blocked> + %3 = tt.broadcast %2 : tensor<1x64xf16, #blocked> -> tensor<256x64xf16, #blocked> + %4 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + %5 = ttg.convert_layout %3 : tensor<256x64xf16, #blocked> -> tensor<256x64xf16, #dotOp0> + %6 = ttg.convert_layout %4 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #dotOp1> + %7 = tt.dot %5, %6, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + scf.yield %7 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + + +// ----- + +// Category 5: load a 1D tensor. Make sure the optimization is not applied +// should NOT apply: load scalar +// CHECK-LABEL: sink_2nd_load_1D_tensor +// CHECK: tt.load +// CHECK-NEXT: tt.load +// CHECK-NEXT: tt.expand_dims +// CHECK-NEXT: tt.broadcast +// CHECK-NEXT: ttg.convert_layout +// CHECK-NEXT: ttg.convert_layout +// CHECK-NEXT: tt.dot +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { +tt.func public @sink_2nd_load_1D_tensor(%A_ptr: tensor<256x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = tt.load %A_ptr : tensor<256x!tt.ptr, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<256xf16, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xf16, #blocked> + %3 = tt.broadcast %2 : tensor<256x1xf16, #blocked> -> tensor<256x64xf16, #blocked> + %4 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + %5 = ttg.convert_layout %3 : tensor<256x64xf16, #blocked> -> tensor<256x64xf16, #dotOp0> + %6 = ttg.convert_layout %4 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #dotOp1> + %7 = tt.dot %5, %6, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + scf.yield %7 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/amd-stream-prefetch.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-stream-prefetch.mlir new file mode 100644 index 000000000..82d93b9d1 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/amd-stream-prefetch.mlir @@ -0,0 +1,121 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=3 global_prefetch=1" -canonicalize | FileCheck %s --check-prefixes=GLOBAL_1 +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=4 global_prefetch=2" -canonicalize | FileCheck %s --check-prefixes=GLOBAL_2 +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=3 global_prefetch=1 local_prefetch=1" -canonicalize | FileCheck %s --check-prefixes=GLOBAL_LOCAL_1 +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=2 local_prefetch=1" -canonicalize | FileCheck %s --check-prefixes=LOCAL_1 + +// matmul: 128x32 @ 32x128 -> 128x128 +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +// An extra register buffer for global loads. +// GLOBAL_1-LABEL: tt.func @matmul_loop +// GLOBAL_1-COUNT-2: tt.load +// GLOBAL_1-COUNT-2: ttg.local_store +// GLOBAL_1-COUNT-2: tt.load +// GLOBAL_1: scf.for +// GLOBAL_1-COUNT-2: ttg.local_load +// GLOBAL_1: tt.dot +// GLOBAL_1-COUNT-2: ttg.local_store +// GLOBAL_1-COUNT-2: tt.load +// GLOBAL_1: scf.yield +// GLOBAL_1-COUNT-2: tt.dot +// GLOBAL_1-NOT: tt.dot + +// Two extra register buffers for global loads. +// GLOBAL_2-LABEL: tt.func @matmul_loop +// GLOBAL_2-COUNT-4: tt.load +// GLOBAL_2-COUNT-2: ttg.local_store +// GLOBAL_2-COUNT-2: tt.load +// GLOBAL_2: scf.for +// GLOBAL_2-COUNT-2: ttg.local_load +// GLOBAL_2: tt.dot +// GLOBAL_2-COUNT-2: ttg.local_store +// GLOBAL_2-COUNT-2: tt.load +// GLOBAL_2: scf.yield +// GLOBAL_2-COUNT-3: tt.dot +// GLOBAL_2-NOT: tt.dot + +// An extra register buffer for global loads and an extra register buffer for local_loads. +// GLOBAL_LOCAL_1-LABEL: tt.func @matmul_loop +// GLOBAL_LOCAL_1-COUNT-2: tt.load +// GLOBAL_LOCAL_1-COUNT-2: ttg.local_store +// GLOBAL_LOCAL_1: tt.load +// GLOBAL_LOCAL_1: ttg.local_load +// GLOBAL_LOCAL_1: tt.load +// GLOBAL_LOCAL_1: ttg.local_load +// GLOBAL_LOCAL_1: scf.for +// GLOBAL_LOCAL_1-COUNT-2: ttg.local_store +// GLOBAL_LOCAL_1: tt.dot +// GLOBAL_LOCAL_1: tt.load +// GLOBAL_LOCAL_1: ttg.local_load +// GLOBAL_LOCAL_1: tt.load +// GLOBAL_LOCAL_1: ttg.local_load +// GLOBAL_LOCAL_1: scf.yield +// GLOBAL_LOCAL_1-COUNT-2: tt.dot +// GLOBAL_LOCAL_1-NOT: tt.dot + +// One Local buffer. +// LOCAL_1-LABEL: tt.func @matmul_loop +// LOCAL_1-COUNT-2: tt.load +// LOCAL_1-COUNT-2: ttg.local_store +// LOCAL_1-COUNT-2: ttg.local_load +// LOCAL_1: scf.for +// LOCAL_1-COUNT-2: tt.load +// LOCAL_1: tt.dot +// LOCAL_1-COUNT-2: ttg.local_store +// LOCAL_1-COUNT-2: ttg.local_load +// LOCAL_1: scf.yield +// LOCAL_1: tt.dot +// LOCAL_1-NOT: tt.dot + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func @matmul_loop(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %b_scale = arith.constant dense<4.> : tensor<32x128xf16, #B> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %b__ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b_ = ttg.convert_layout %b__ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/invalid.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/invalid.mlir new file mode 100644 index 000000000..2c4d86fbd --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/invalid.mlir @@ -0,0 +1,21 @@ +// RUN: triton-opt --split-input-file %s --verify-diagnostics + +// expected-error @+1 {{Transposed WMMA is supported only for version 2}} +#wmma = #ttg.amd_wmma<{version = 1, isTranspose = true, warpsPerCTA = [2, 2]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @fn(%arg0: !tt.ptr) { + %t = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #wmma> + tt.return + } +} + +// ----- + +// expected-error @+1 {{WMMA version must be in the [1, 2] range}} +#wmma = #ttg.amd_wmma<{version = 0, isTranspose = false, warpsPerCTA = [2, 2]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @fn(%arg0: !tt.ptr) { + %t = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #wmma> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/mfma-double-rate.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/mfma-double-rate.mlir new file mode 100644 index 000000000..97dd4b8a5 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/mfma-double-rate.mlir @@ -0,0 +1,60 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm='arch=gfx950' | FileCheck %s + +// CHECK-LABEL:mfma_16x16x32_f16 + +#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "tttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_16x16x32_f16(%arg0: tensor<16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>, + %arg1: tensor<32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + // CHECK: rocdl.mfma.f32.16x16x32.f16 {{.*}} : (vector<8xf16>, vector<8xf16> + %dot = tt.dot %arg0, %arg1, %cst : tensor<16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<16x16xf32, #mma> + tt.return + } +} + +// ----- + +// CHECK-LABEL:mfma_16x16x32_bf16 + +#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "tttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_16x16x32_bf16(%arg0: tensor<16x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>, + %arg1: tensor<32x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + // CHECK: rocdl.mfma.f32.16x16x32.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16> + %dot = tt.dot %arg0, %arg1, %cst : tensor<16x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<32x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<16x16xf32, #mma> + tt.return + } +} + +// ----- + +// CHECK-LABEL:mfma_32x32x16_f16 + +#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "tttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_32x32x16_f16(%arg0: tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>, + %arg1: tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // CHECK: rocdl.mfma.f32.32x32x16.f16 {{.*}} : (vector<8xf16>, vector<8xf16> + %dot = tt.dot %arg0, %arg1, %cst : tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<32x32xf32, #mma> + tt.return + } +} + + +// ----- + +// CHECK-LABEL:mfma_32x32x16_bf16 + +#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "tttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_32x32x16_bf16(%arg0: tensor<32x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>, + %arg1: tensor<16x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // CHECK: rocdl.mfma.f32.32x32x16.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16> + %dot = tt.dot %arg0, %arg1, %cst : tensor<32x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<32x32xf32, #mma> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/mfma-xf32.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/mfma-xf32.mlir new file mode 100644 index 000000000..2d81fbc71 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/mfma-xf32.mlir @@ -0,0 +1,39 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s + +// CHECK-LABEL:mfma_xf32 + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_xf32( + %arg0: tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, + %arg1: tensor<128x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + // Check that we generate xf32 instructions + // CHECK: rocdl.mfma.f32.16x16x8.xf32 + %dot = tt.dot %arg0, %arg1, %cst_0, inputPrecision = tf32 : + tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<128x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x64xf32, #mma> + tt.return + } +} + +// ----- + +// CHECK-LABEL:mfma_not_xf32 + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_not_xf32( + %arg0: tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, + %arg1: tensor<128x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + // Check that we don't generate xf32 instructions if the input precision is "ieee" + // CHECK: rocdl.mfma.f32.16x16x4f32 + %dot = tt.dot %arg0, %arg1, %cst_0, inputPrecision = ieee : + tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<128x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x64xf32, #mma> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/optimize-lds-usage.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/optimize-lds-usage.mlir new file mode 100644 index 000000000..2d769f000 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/optimize-lds-usage.mlir @@ -0,0 +1,167 @@ +// RUN: triton-opt %s -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a | FileCheck %s +// RUN: triton-opt %s -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a -optimize-amd-lds-usage=lds-limit=32768 | FileCheck %s --check-prefix=CHECK-32KLIMIT + +// Check that optimization detects overflow of LDS and decomposes layout convert so kernel fits into LDS +// CHECK-LABEL: alloc_convert_load +// CHECK-32KLIMIT-LABEL: alloc_convert_load +// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @alloc_convert_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) attributes {noinline = false} { + %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %2 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return + } +} + +// ----- + +// Check that optimization detects overflow of LDS and decomposes layout convert so kernel fits into LDS +// in case of relatively small scratch buffer +// CHECK-LABEL: alloc_convert_small_load +// CHECK-32KLIMIT-LABEL: alloc_convert_small_load +// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) attributes {noinline = false} { + %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %2 = ttg.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return + } +} + +// FIXME: This was broken in https://github.com/triton-lang/triton/pull/5840 +// // ----- + +// Check that optimization works with 3d tensors +// in case of relatively small scratch buffer +// DISABLE-CHECK-LABEL: alloc_convert_3d_load +// DISABLE-CHECK-32KLIMIT-LABEL: alloc_convert_3d_load +// DISABLE-CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// DISABLE-CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma +// DISABLE-CHECK: %2 = ttg.convert_layout %1 : {{.*}}#mma{{.*}}#mma1 +// DISABLE-CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x128x128xf16, #blocked>) attributes {noinline = false} { + %1 = ttg.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !ttg.memdesc<1x128x128xf16, #shared, #smem> + %2 = ttg.convert_layout %arg1 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<1x128x128xf16, #shared, #smem> -> tensor<1x128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return + } +} + +// ----- + +// Check that optimization triggers with custom LDS limit and do not triggers with default one +// CHECK-LABEL: alloc_convert_32k_limit +// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma +// CHECK: %2 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK-32KLIMIT-LABEL: alloc_convert_32k_limit +// CHECK-32KLIMIT: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK-32KLIMIT: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK-32KLIMIT: %2 = ttg.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK-32KLIMIT: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<64x128xf16, #blocked>) attributes {noinline = false} { + %1 = ttg.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared, #smem> + %2 = ttg.convert_layout %arg1 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #mma> + %3 = ttg.local_load %1 : !ttg.memdesc<64x128xf16, #shared, #smem> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>> + tt.return + } +} + +// FIXME: This was broken in https://github.com/triton-lang/triton/pull/5840 +// ----- + +// Check that optimization correctly handles LDS shortcut (see #mma2 -> #dotop2 conversion) +// DISABLE-CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +// DISABLE-CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}> +// DISABLE-CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +// DISABLE-CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +// DISABLE-CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> + +// DISABLE-CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}}) +// DISABLE-CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem> +// DISABLE-CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]> +// DISABLE-CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]> +// DISABLE-CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>> +// DISABLE-CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>> +#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma1 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#mma2 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#dotop1 = #ttg.dot_op<{opIdx=0, parent=#mma1, kWidth=4}> +#dotop2 = #ttg.dot_op<{opIdx=0, parent=#mma2, kWidth=4}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_shortcut(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>, %arg2: tensor<256x128xf16, #mma2>) attributes {noinline = false} { + %alloc = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %convert_1 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma1> + %convert_2 = ttg.convert_layout %arg2 : tensor<256x128xf16, #mma2> -> tensor<256x128xf16, #dotop2> + %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #dotop1> + tt.return + } +} +// ----- + +// Checks that optimization do not crash on 1d tensor +// CHECK-LABEL: convert_1d +// CHECK: ttg.local_alloc +// CHECK-NEXT: ttg.convert_layout +// CHECK-NEXT: ttg.local_load +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @convert_1d(%arg0: tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} { + %alloc = ttg.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !ttg.memdesc<128x128xf32, #shared, #smem> + %1 = ttg.convert_layout %arg0 : tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked> + %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf32, #shared, #smem> -> tensor<128x128xf32, #mma> + tt.return + } +} + +// ----- + +// Checks that optimization do not crash on linear encoding tensor +// CHECK-LABEL: convert_linear +// CHECK: ttg.local_alloc +// CHECK-NEXT: ttg.convert_layout +// CHECK-NEXT: ttg.convert_layout +// CHECK-NEXT: ttg.local_load +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [1, 0]}> +#linear = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], warp = [[0, 16], [32, 0]], block = []}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @convert_linear(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) attributes {noinline = false} { + %alloc = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %1 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #linear> + %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/amd/sink-setprio-mfma.mlir b/third_party/enflame/include/triton/test/TritonGPU/amd/sink-setprio-mfma.mlir new file mode 100644 index 000000000..e7a2234cc --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/amd/sink-setprio-mfma.mlir @@ -0,0 +1,25 @@ +// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s + +// CHECK-LABEL: llvm.func @sink_setprio +// CHECK: rocdl.mfma +// CHECK-NOT: rocdl.mfma +// CHECK: rocdl.s.setprio 1 +// CHECK-COUNT-15: rocdl.mfma +// CHECK-NOT: rocdl.mfma +// CHECK: rocdl.s.setprio 0 + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @sink_setprio( + %arg0: tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, + %arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + rocdl.s.setprio 1 + %dot = tt.dot %arg0, %arg1, %cst_0 : + tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma> + rocdl.s.setprio 0 + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/atomic-cas.mlir b/third_party/enflame/include/triton/test/TritonGPU/atomic-cas.mlir new file mode 100644 index 000000000..c4472e9c9 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/atomic-cas.mlir @@ -0,0 +1,27 @@ +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=target=cuda:80 2>&1 | FileCheck %s --check-prefix=GPU +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=target=cuda:80 -convert-triton-gpu-to-llvm 2>&1 | FileCheck %s --check-prefix=LLVM + +// GPU: %9 = tt.atomic_cas acq_rel, cta, %8, %cst_0, %cst : (tensor<2x!tt.ptr, #blocked>, tensor<2xi64, #blocked>, tensor<2xi64, #blocked>) -> tensor<2xi64, #blocked> +// LLVM: llvm.inline_asm {{.*}} "mov.u64 $0, 0x0;\0A\09@$4 atom.global.acq_rel.cta.cas.b64 $0, [ $1 + 0 ], $2, $3;", "=l,l,l,l,b" + +module { + tt.func public @atomic_cas_kernel_0d1d2e(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<2> : tensor<2xi64> + %cst_0 = arith.constant dense<1> : tensor<2xi64> + %c2_i32 = arith.constant 2 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c2_i32 : i32 + %2 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %3 = tt.splat %1 : i32 -> tensor<2xi32> + %4 = arith.addi %3, %2 : tensor<2xi32> + %5 = tt.splat %arg2 : i32 -> tensor<2xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<2xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<2x!tt.ptr>, tensor<2xi32> + %9 = tt.atomic_cas acq_rel, cta, %8, %cst_0, %cst : (tensor<2x!tt.ptr>, tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %11 = tt.addptr %10, %4 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %11, %9, %6 : tensor<2x!tt.ptr> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/blackwell_acc_tmem.mlir b/third_party/enflame/include/triton/test/TritonGPU/blackwell_acc_tmem.mlir new file mode 100644 index 000000000..f9121b514 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/blackwell_acc_tmem.mlir @@ -0,0 +1,143 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-keep-acc-in-tmem | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @trivial + // CHECK: %[[TMEM_BASE:.*]] = arith.constant dense<0.000000e+00> + // CHECK: %[[TMEM:.*]] = ttng.tmem_alloc %[[TMEM_BASE]] + // CHECK-NEXT: scf.for + // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM]] + // CHECK: scf.yield + // CHECK: %[[ACC_RES:.*]] = ttng.tmem_load %[[TMEM]] + // CHECK: %[[RES:.*]] = arith.truncf %[[ACC_RES]] + // CHECK: tt.return + tt.func public @trivial(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %arg3: i32) -> tensor<128x128xf16, #blocked> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked>) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + scf.yield %acc_res : tensor<128x128xf32, #blocked> + } + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + tt.return %res_f16 : tensor<128x128xf16, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @use_after_dot + // CHECK: %[[TMEM_BASE:.*]] = arith.constant dense<0.000000e+00> + // CHECK: %[[TMEM:.*]] = ttng.tmem_alloc %[[TMEM_BASE]] + // CHECK-NEXT: scf.for + // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM]] + // CHECK: %[[ACC_RES_INS:.*]] = ttng.tmem_load %[[TMEM]] + // CHECK: arith.addf %[[ACC_RES_INS]] + // CHECK: scf.yield + // CHECK: %[[ACC_RES:.*]] = ttng.tmem_load %[[TMEM]] + // CHECK: %[[RES:.*]] = arith.truncf %[[ACC_RES]] + // CHECK: tt.return + tt.func public @use_after_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %add: tensor<128x128xf32, #blocked>, %arg3: i32) -> tensor<128x128xf16, #blocked> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %res:2 = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst, %add_ = %cst) -> (tensor<128x128xf32, #blocked>, tensor<128x128xf32, #blocked>) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + %add_res = arith.addf %acc_res, %add_ : tensor<128x128xf32, #blocked> + scf.yield %acc_res, %add_res : tensor<128x128xf32, #blocked>, tensor<128x128xf32, #blocked> + } + %res_f16 = arith.truncf %res#0 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + tt.return %res_f16 : tensor<128x128xf16, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: @use_before_dot + // CHECK: %[[TMEM_BASE:.*]] = arith.constant dense<0.000000e+00> + // CHECK: %[[TMEM:.*]] = ttng.tmem_alloc %[[TMEM_BASE]] + // CHECK-NEXT: scf.for + // CHECK: %[[ACC_RES_INS:.*]] = ttng.tmem_load %[[TMEM]] + // CHECK: arith.addf %[[ACC_RES_INS]] + // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM]] + // CHECK: scf.yield + // CHECK: %[[ACC_RES:.*]] = ttng.tmem_load %[[TMEM]] + // CHECK: %[[RES:.*]] = arith.truncf %[[ACC_RES]] + // CHECK: tt.return + tt.func public @use_before_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %add: tensor<128x128xf32, #blocked>, %arg3: i32) -> tensor<128x128xf16, #blocked> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %res:2 = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst, %add_ = %cst) -> (tensor<128x128xf32, #blocked>, tensor<128x128xf32, #blocked>) : i32 { + %add_res = arith.addf %acc, %add_ : tensor<128x128xf32, #blocked> + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + scf.yield %acc_res, %add_res : tensor<128x128xf32, #blocked>, tensor<128x128xf32, #blocked> + } + %res_f16 = arith.truncf %res#0 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + tt.return %res_f16 : tensor<128x128xf16, #blocked> + } +} + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8, fp4Padded = true}> +#smem = #ttg.shared_memory +#tmem = #ttng.tensor_memory_encoding +#tmem_scales = #ttng.tensor_memory_scales_encoding<> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @hoist_constant_inputs + tt.func public @hoist_constant_inputs(%arg0: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %arg1: !ttg.memdesc<64x128xi8, #shared1, #smem>, %arg2: !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>, %arg3: i32, %arg4: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>) { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + // CHECK: arith.trunci + // CHECK: tt.splat + // CHECK: ttng.tmem_alloc + // CHECK: scf.for + // CHECK: ttng.tc_gen5_mma_scaled + scf.for %arg5 = %c0_i32 to %arg3 step %c1_i32 : i32 { + %0 = arith.trunci %arg3 : i32 to i8 + %1 = tt.splat %0 : i8 -> tensor<128x4xi8, #blocked1> + %2 = ttng.tmem_alloc %1 : (tensor<128x4xi8, #blocked1>) -> !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory> + ttng.tc_gen5_mma_scaled %arg0, %arg1, %arg4, %arg2, %2, %true, %true lhs = e5m2 rhs = e2m1 : (!ttg.memdesc<128x128xf8E5M2, #shared, #smem>, !ttg.memdesc<64x128xi8, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>, !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>, i1, i1) -> () + } + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/canonicalize.mlir b/third_party/enflame/include/triton/test/TritonGPU/canonicalize.mlir new file mode 100644 index 000000000..1566693a9 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/canonicalize.mlir @@ -0,0 +1,271 @@ +// RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s + + +// CHECK-LABEL: @test_canonicalize_convert_view +// CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32 +// CHECK-NOT: ttg.convert_layout +// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] allow_reorder +// CHECK: tt.return %[[V]] +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> + +module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} { +tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { + %c = ttg.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2> + %r = tt.reshape %c allow_reorder : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1> + tt.return %r : tensor<4096xf32, #blocked1> +} +} // end module + +// ----- + +// test that the convert doesn't get combined with view if the resulting operations +// is an expensive view which would require moving data across threads. +// CHECK-LABEL: @test_canonicalize_convert_expensive_view +// CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32 +// CHECK: %[[C:.+]] = ttg.convert_layout %[[ARG]] +// CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder +// CHECK: tt.return %[[V]] +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} { +tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { + %c = ttg.convert_layout %arg0 : tensor<256x16xf32, #blocked0> -> tensor<256x16xf32, #blocked2> + %r = tt.reshape %c allow_reorder : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1> + tt.return %r : tensor<4096xf32, #blocked1> +} +} // end module + +// ----- + +// test that the convert doesn't get combined with view if the resulting operations +// is an expensive view which would require moving data across threads. +// CHECK-LABEL: @test_canonicalize_convert_expensive_view +// CHECK-SAME: (%[[ARG:.+]]: tensor<2xf32 +// CHECK: %[[C:.+]] = ttg.convert_layout %[[ARG]] +// CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder +// CHECK: tt.return %[[V]] +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} { + tt.func @test_canonicalize_convert_expensive_view2(%arg0: tensor<2xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> tensor<2xf32, #blocked1> { + %c = ttg.convert_layout %arg0 : tensor<2xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<2xf32, #blocked1> + %r = tt.reshape %c allow_reorder : tensor<2xf32, #blocked1> -> tensor<2xf32, #blocked1> + tt.return %r : tensor<2xf32, #blocked1> + } +} + +// ----- + +// test that the convert does get combined with the view even if the resulting operation +// is an efficient view. +// CHECK-LABEL: @test_canonicalize_convert_view +// CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32 +// CHECK-NOT: ttg.convert_layout +// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] allow_reorder +// CHECK: tt.return %[[V]] +#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> + +module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} { +tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { + %c = ttg.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2> + %r = tt.reshape %c allow_reorder efficient_layout : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1> + tt.return %r : tensor<4096xf32, #blocked1> +} +} // end module + +// ----- + +// CHECK-LABEL: @test_canonicalize_convert_histogram +// CHECK-SAME: (%[[ARG:.+]]: tensor<256xi32 +// CHECK-NOT: ttg.convert_layout +// CHECK: %[[V:.+]] = tt.histogram %[[ARG]] +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.return %[[V]] +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} { +tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>) -> tensor<512xi32, #blocked2> { + %0 = ttg.convert_layout %arg0 : tensor<256xi32, #blocked1> -> tensor<256xi32, #blocked> + %1 = tt.histogram %0 : tensor<256xi32, #blocked> -> tensor<512xi32, #blocked> + %2 = ttg.convert_layout %1 : tensor<512xi32, #blocked> -> tensor<512xi32, #blocked2> + tt.return %2 : tensor<512xi32, #blocked2> +} +} // end module + +// ----- + +// CHECK-LABEL: @test_canonicalize_convert_local_load +// CHECK-NOT: gpu.barrier +// CHECK: %[[V:.+]] = ttg.local_load +// CHECK-NEXT: gpu.barrier +// CHECK-NEXT: tt.return %[[V]] + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.compute-capability" = 80} { +tt.func @test_canonicalize_convert_local_load() -> tensor<256xi32, #blocked1> { + %0 = ttg.local_alloc : () -> !ttg.memdesc<256xi32, #shared, #smem, mutable> + %1 = ttg.local_load %0 : !ttg.memdesc<256xi32, #shared, #smem, mutable> -> tensor<256xi32, #blocked> + gpu.barrier + %2 = ttg.convert_layout %1 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1> + tt.return %2 : tensor<256xi32, #blocked1> +} +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: local_alloc_nofold1 + tt.func @local_alloc_nofold1(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> { + // CHECK: %[[ARG:.+]] = ttg.local_alloc + // CHECK-NEXT: %[[ARG2:.+]] = ttg.local_load %[[ARG]] + // CHECK-NEXT: %[[ARG3:.+]] = ttg.local_alloc %[[ARG2]] + // CHECK-NEXT: tt.return %[[ARG3]] + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem, mutable> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem, mutable> -> tensor<16x16xf16, #blocked> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> + tt.return %2 : !ttg.memdesc<16x16xf16, #shared, #smem> + } +} // end module + + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: local_alloc_nofold2 + tt.func @local_alloc_nofold2(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared1, #smem> { + // CHECK: %[[ARG:.+]] = ttg.local_alloc + // CHECK-NEXT: %[[ARG2:.+]] = ttg.local_load %[[ARG]] + // CHECK-NEXT: %[[ARG3:.+]] = ttg.local_alloc %[[ARG2]] + // CHECK-NEXT: tt.return %[[ARG3]] + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #blocked> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared1, #smem> + tt.return %2 : !ttg.memdesc<16x16xf16, #shared1, #smem> + } +} // end module + + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func @local_alloc_fold(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> { + // CHECK-LABEL: local_alloc_fold + // CHECK-NEXT: %[[ARG:.+]] = ttg.local_alloc + // CHECK-NEXT: tt.return %[[ARG]] + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #blocked> + %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> + tt.return %2 : !ttg.memdesc<16x16xf16, #shared, #smem> + } +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: convert_layout_gather_src + tt.func @convert_layout_gather_src(%arg0: tensor<16x16xf16, #blocked>, %arg1: tensor<16x16xi32, #blocked>) -> tensor<16x16xf16, #blocked> { + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #blocked1> + // CHECK-NEXT: tt.gather %arg0[%arg1] + %1 = tt.gather %0[%arg1] {axis = 0 : i32} : (tensor<16x16xf16, #blocked1>, tensor<16x16xi32, #blocked>) -> tensor<16x16xf16, #blocked> + tt.return %1 : tensor<16x16xf16, #blocked> + } + + // CHECK-LABEL: gather_efficient_layout + tt.func @gather_efficient_layout(%arg0: tensor<16x16xf16, #blocked>, %arg1: tensor<16x16xi32, #blocked>) -> tensor<16x16xf16, #blocked> { + // CHECK-NEXT: convert_layout + %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #blocked1> + // CHECK-NEXT: tt.gather {{.*}} (tensor<16x16xf16, #blocked1> + %1 = tt.gather %0[%arg1] {axis = 0 : i32, efficient_layout} : (tensor<16x16xf16, #blocked1>, tensor<16x16xi32, #blocked>) -> tensor<16x16xf16, #blocked> + tt.return %1 : tensor<16x16xf16, #blocked> + } +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [8, 0], [16, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[0, 8], [0, 16]], block = []}> +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked_trans = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: @infer_trans +tt.func @infer_trans(%arg0: tensor<32x32xf32, #linear>) -> tensor<32x32xf32, #blocked_trans> { + // CHECK-NOT: ttg.convert_layout + %0 = ttg.convert_layout %arg0 : tensor<32x32xf32, #linear> -> tensor<32x32xf32, #blocked> + %1 = tt.trans %0 {order = array} : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked_trans> + tt.return %1 : tensor<32x32xf32, #blocked_trans> +} + +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#dot_t = #ttg.linear<{register = [[1, 0], [0, 8], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 64], [0, 128]], lane = [[2, 0], [4, 0], [0, 1], [0, 2], [0, 4]], warp = [[0, 16], [0, 32]], block = []}> +#dot_linear = #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [64, 0], [128, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = []}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @simplify_trans_trans + tt.func public @simplify_trans_trans(%arg0: tensor<256x256xf32, #dot_linear>) -> tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> { + // CHECK-NEXT: ttg.convert_layout + %a = tt.trans %arg0 {order=array} : tensor<256x256xf32, #dot_linear> -> tensor<256x256xf32, #dot_t> + %b = tt.trans %a {order=array} : tensor<256x256xf32, #dot_t> -> tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + tt.return %b : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + } +} + +// ----- + +// CHECK-LABEL: @warp_specialize_with_no_uses_and_effects +tt.func @warp_specialize_with_no_uses_and_effects(%arg0: i32) { + %0 = ttg.warp_specialize(%arg0) + default { + %1 = arith.addi %arg0, %arg0 : i32 + ttg.warp_yield %1 : i32 + } + partition0(%arg1: i32) num_warps(4) { + arith.addi %arg1, %arg1 : i32 + ttg.warp_return + } : (i32) -> i32 + // CHECK-NEXT: tt.return + tt.return +} + +// CHECK-LABEL: @canonicalize_within_warp_specialize +tt.func @canonicalize_within_warp_specialize(%arg0: i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %0 = ttg.warp_specialize() + default { + %1 = arith.addi %arg0, %c0_i32 : i32 + // CHECK: warp_yield %arg0 + ttg.warp_yield %1 : i32 + } + // CHECK: partition0 + partition0() num_warps(4) { + %c0_i32_0 = arith.constant 0 : i32 + // CHECK-NEXT: warp_return + ttg.warp_return + } : () -> i32 + tt.return %0 : i32 +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/coalesce-async-copy.mlir b/third_party/enflame/include/triton/test/TritonGPU/coalesce-async-copy.mlir new file mode 100644 index 000000000..5ce8230a3 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/coalesce-async-copy.mlir @@ -0,0 +1,37 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-coalesce-async-copy | FileCheck %s + +// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr, #blocked>, + %view: !ttg.memdesc<64x16xi8, #shared, #smem, mutable>, + %mask: tensor<64x16xi1, #blocked>, + %other: tensor<64x16xi8, #blocked>) { + %token = ttg.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #smem, mutable> + tt.return +} +} + +// ----- + +// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +tt.func @async_copy_i8_no_mask_or_other(%input: tensor<64x16x!tt.ptr, #blocked>, + %view: !ttg.memdesc<64x16xi8, #shared, #smem, mutable>) { + %token = ttg.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #smem, mutable> + tt.return +} +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/coalesce.mlir b/third_party/enflame/include/triton/test/TritonGPU/coalesce.mlir new file mode 100644 index 000000000..44eb3d47e --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/coalesce.mlir @@ -0,0 +1,201 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-coalesce | FileCheck %s + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}> +#slice2dim0 = #ttg.slice<{dim = 0, parent = #blocked2}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK: [[row_layout:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: [[col_layout:#.*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: [[load_ptr:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[row_layout]]> +// CHECK: [[load_mask:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]> +// CHECK: [[load_other:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]> +// CHECK: [[load_val:%.*]] = tt.load [[load_ptr]], [[load_mask]], [[load_other]] : tensor<64x64x!tt.ptr, [[row_layout]]> +// CHECK: [[store_ptr:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64x!tt.ptr, [[col_layout]]> +// CHECK: [[store_val:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]> +// CHECK: [[store_mask:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]> +// CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]] +tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: i32 {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense : tensor<64x64xi1, #blocked1> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1> + %00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1> + %01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0> + %1 = tt.expand_dims %00 {axis = 1 : i32} : tensor<64xi32, #slice1dim1> -> tensor<64x1xi32, #blocked1> + %2 = tt.splat %arg1 : i32 -> tensor<64x1xi32, #blocked1> + %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> + %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %11 = tt.splat %arg2 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %13 = tt.splat %arg3 : i32 -> tensor<1x64xi32, #blocked2> + %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2> + %15 = tt.broadcast %12 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %16 = tt.broadcast %14 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %17 = ttg.convert_layout %16 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %19 = tt.load %10, %cst, %cst_0 : tensor<64x64x!tt.ptr, #blocked1> + tt.store %18, %19, %cst : tensor<64x64x!tt.ptr, #blocked1> + tt.return +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + + +// CHECK: [[NARROW_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK: [[WIDE_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr, #blocked> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr, #blocked> + %13 = arith.extf %12 : tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked> + %14 = arith.addf %9, %13 : tensor<1024xf32, #blocked> + %15 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: tt.store {{.*}} : tensor<1024x!tt.ptr, [[WIDE_LAYOUT]]> + tt.store %16, %14, %6 : tensor<1024x!tt.ptr, #blocked> + tt.return +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-NOT: sizePerThread = [4] +// CHECK: #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK-NOT: sizePerThread = [4] +tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr, #blocked> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr, #blocked> + %13 = arith.extf %12 : tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked> + %14 = arith.addf %9, %13 : tensor<1024xf32, #blocked> + %15 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %17 = arith.truncf %14 : tensor<1024xf32, #blocked> to tensor<1024xf16, #blocked> + tt.store %16, %17, %6 : tensor<1024x!tt.ptr, #blocked> + tt.return +} + +} + +// ----- + +// COM: Reproducer for issue #3866 +// CHECK-LABEL: @test_3866 +// CHECK: tt.load {{.*}} : !tt.ptr +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} { + tt.func public @test_3866(%arg0: !tt.ptr, %arg1: i32, %arg2: i64) { + %0 = tt.make_tensor_ptr %arg0, [%arg2, %arg2], [%arg2, %arg2], [%arg1, %arg1] {order = array} : > + %1 = tt.load %0 : !tt.ptr> + tt.return + } +} + +// ----- + +// COM: Reproducer for issue #5122 +// CHECK-LABEL: @test_5122 +module { + tt.func public @test_5122(%arg0: i32) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %0 = arith.cmpi sgt, %arg0, %c1_i32 : i32 + scf.if %0 { + %1 = scf.if %0 -> (i32) { + scf.yield %c1_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + %2 = arith.cmpi sgt, %1, %c1_i32 : i32 + %3 = scf.if %2 -> (i32) { + scf.yield %c1_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + %4 = scf.for %arg1 = %1 to %1 step %c1_i32 iter_args(%arg2 = %3) -> (i32) : i32 { + %5 = arith.addi %arg2, %c1_i32 : i32 + scf.yield %5 : i32 + } + } + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> + +// CHECK: [[COALESCED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK: @coalesce_poison +tt.func @coalesce_poison(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i1) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1> + %2 = ttg.convert_layout %1 : tensor<128xi32, #blocked1> -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %4 = ttg.convert_layout %3 : tensor<128x1xi32, #blocked2> -> tensor<128x1xi32, #blocked3> + %5 = tt.broadcast %4 {axis = 1 : i32} : tensor<128x1xi32, #blocked3> -> tensor<128x64xi32, #blocked3> + %6 = ttg.convert_layout %5 : tensor<128x64xi32, #blocked3> -> tensor<128x64xi32, #blocked> + %7 = tt.addptr %0, %6 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + + %8 = ub.poison : tensor<128x64x!tt.ptr, #blocked> + // CHECK: scf.if + %9 = scf.if %arg2 -> (tensor<128x64x!tt.ptr, #blocked>) { + scf.yield %8 : tensor<128x64x!tt.ptr, #blocked> + } else { + scf.yield %7 : tensor<128x64x!tt.ptr, #blocked> + } + // CHECK: [[PTR:%.*]] = ttg.convert_layout %{{.*}} : tensor<128x64x!tt.ptr, #{{.*}}> -> tensor<128x64x!tt.ptr, [[COALESCED_LAYOUT]]> + // CHECK-NEXT: tt.load [[PTR]] + %10 = tt.load %9 : tensor<128x64x!tt.ptr, #blocked> + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/combine-select-if.mlir b/third_party/enflame/include/triton/test/TritonGPU/combine-select-if.mlir new file mode 100644 index 000000000..586e33ec4 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/combine-select-if.mlir @@ -0,0 +1,101 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-combine-tensor-select-and-if | FileCheck %s + +// CHECK-LABEL: @select_if_combine +tt.func public @select_if_combine(%arg0: tensor<64xf32>, %dst_ptr: tensor<64x!tt.ptr>, %cnd: i1) { + // CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> + %cst = arith.constant dense<0.000000e+00> : tensor<64xf32> + // CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32> + // CHECK-NOT: arith.select + %sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32> + // CHECK: %[[R:.+]] = scf.if %{{.*}} + // CHECK: tt.store %{{.*}}, %{{.*}} + // CHECK: scf.yield %[[CST0]] + // CHECK: } else { + // CHECK: scf.yield %[[CST1]] + // CHECK: } + scf.if %cnd { + tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr> + } + // CHECK: tt.store %{{.*}}, %[[R]] + tt.store %dst_ptr, %sel : tensor<64x!tt.ptr> + tt.return +} + +// ----- +// CHECK-LABEL: @if_multiple_sel +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func @if_multiple_sel(%arg0: i1, %arg1: tensor<64xi32, #blocked>, %arg2: tensor<64xi32, #blocked>, %arg3: tensor<64xf32, #blocked>, %arg4: tensor<64xf32, #blocked>) -> (tensor<64xi32, #blocked>, tensor<64xf32, #blocked>, tensor<64xi32, #blocked>){ + // CHECK-NOT: select + // CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>, tensor<64xf32, #blocked>) { + // CHECK: scf.yield {{.*}} : tensor<64xi32, #blocked>, tensor<64xi32, #blocked>, tensor<64xf32, #blocked> + // CHECK: } else { + // CHECK: scf.yield {{.*}} : tensor<64xi32, #blocked>, tensor<64xi32, #blocked>, tensor<64xf32, #blocked> + // CHECK: } + // CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : tensor<64xi32, #blocked>, tensor<64xf32, #blocked>, tensor<64xi32, #blocked> + %0 = arith.select %arg0, %arg1, %arg2 : tensor<64xi32, #blocked> + %1 = arith.select %arg0, %arg3, %arg4 : tensor<64xf32, #blocked> + %2 = scf.if %arg0 -> (tensor<64xi32, #blocked>) { + %3 = arith.subi %arg1, %arg2 : tensor<64xi32, #blocked> + scf.yield %3 : tensor<64xi32, #blocked> + } else { + scf.yield %arg1 : tensor<64xi32, #blocked> + } + tt.return %0, %1, %2 : tensor<64xi32, #blocked>, tensor<64xf32, #blocked>, tensor<64xi32, #blocked> + } +} + +// ----- + +tt.func @if_multiple_sel(%arg0: i1, %arg1: tensor<64xi32>, %arg2: tensor<64xi32>, %arg3: tensor<64xi32>, %arg4: tensor<64xi32>) -> (tensor<64xi32>, tensor<64xi32>, tensor<64xi32>){ + // CHECK-NOT: arith.select + %0 = arith.select %arg0, %arg1, %arg2 : tensor<64xi32> + %1 = arith.select %arg0, %arg3, %arg4 : tensor<64xi32> + // CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (tensor<64xi32>, tensor<64xi32>, tensor<64xi32>) { + // CHECK: scf.yield {{.*}} : tensor<64xi32>, tensor<64xi32>, tensor<64xi32> + // CHECK: } else { + // CHECK: scf.yield {{.*}} : tensor<64xi32>, tensor<64xi32>, tensor<64xi32> + // CHECK: } + %2 = scf.if %arg0 -> (tensor<64xi32>) { + %3 = arith.subi %arg1, %arg2 : tensor<64xi32> + scf.yield %3 : tensor<64xi32> + } else { + scf.yield %arg1 : tensor<64xi32> + } + // CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : tensor<64xi32>, tensor<64xi32>, tensor<64xi32> + tt.return %0, %1, %2 : tensor<64xi32>, tensor<64xi32>, tensor<64xi32> +} + +// ----- +// CHECK-LABEL: tt.func @users_in_if( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i1 +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<64xi32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<64xi32> +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<64xf32> +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<64xf32> +tt.func @users_in_if(%arg0: i1, %arg1: tensor<64xi32>, %arg2: tensor<64xi32>, %arg3: tensor<64xf32>, %arg4: tensor<64xf32>) -> (tensor<64xi32>, tensor<64xf32>, tensor<64xi32>, tensor<64xi32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<8> : tensor<64xi32> + %c8_i32 = arith.constant dense<8> : tensor<64xi32> + // CHECK-NOT: arith.select + %0 = arith.select %arg0, %arg1, %arg2 : tensor<64xi32> + %1 = arith.select %arg0, %arg3, %arg4 : tensor<64xf32> + // CHECK: %[[R:.+]]:4 = scf.if %[[ARG0]] -> (tensor<64xi32>, tensor<64xi32>, tensor<64xi32>, tensor<64xf32>) { + // CHECK: %[[MULI:.*]] = arith.muli %[[ARG1]], %[[ARG2]] : tensor<64xi32> + // CHECK: %[[ADDI:.*]] = arith.addi %[[ARG1]], %[[CST]] : tensor<64xi32> + // CHECK: scf.yield %[[MULI]], %[[ADDI]], %[[ARG1]], %[[ARG3]] : tensor<64xi32>, tensor<64xi32>, tensor<64xi32>, tensor<64xf32> + // CHECK: } else { + // CHECK: %[[ADDI:.*]] = arith.subi %[[ARG2]], %[[CST]] : tensor<64xi32> + // CHECK: scf.yield %[[ARG1]], %[[ADDI]], %[[ARG2]], %[[ARG4]] : tensor<64xi32>, tensor<64xi32>, tensor<64xi32>, tensor<64xf32> + // CHECK: } + %2:2 = scf.if %arg0 -> (tensor<64xi32>, tensor<64xi32>) { + %3 = arith.muli %0, %arg2 : tensor<64xi32> + %4 = arith.addi %0, %c8_i32 : tensor<64xi32> + scf.yield %3, %4 : tensor<64xi32>, tensor<64xi32> + } else { + %3 = arith.subi %0, %c8_i32 : tensor<64xi32> + scf.yield %arg1, %3 : tensor<64xi32>, tensor<64xi32> + } + // CHECK: tt.return %[[R]]#2, %[[R]]#3, %[[R]]#0, %[[R]]#1 : tensor<64xi32>, tensor<64xf32>, tensor<64xi32>, tensor<64xi32> + tt.return %0, %1, %2#0, %2#1 : tensor<64xi32>, tensor<64xf32>, tensor<64xi32>, tensor<64xi32> +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/combine.mlir b/third_party/enflame/include/triton/test/TritonGPU/combine.mlir new file mode 100644 index 000000000..013eda756 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/combine.mlir @@ -0,0 +1,3681 @@ +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-remove-layout-conversions -cse 2>&1 | FileCheck %s + +#layout0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#layout1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +#layout2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#layout3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> + + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { + +// CHECK: [[$target_layout:#.*]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK-LABEL: cst +tt.func @cst() -> tensor<1024xi32, #layout1> { + %cst = arith.constant dense<0> : tensor<1024xi32, #layout0> + %1 = ttg.convert_layout %cst : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.return %cst : tensor<1024xi32, [[$target_layout]]> + tt.return %1: tensor<1024xi32, #layout1> +} + +// CHECK-LABEL: range +tt.func @range() -> tensor<1024xi32, #layout1> { + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> + %1 = ttg.convert_layout %0 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.return %0 : tensor<1024xi32, [[$target_layout]]> + tt.return %1: tensor<1024xi32, #layout1> +} + +// CHECK-LABEL: splat +tt.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> { + %0 = tt.splat %arg0 : i32 -> tensor<1024xi32, #layout0> + %1 = ttg.convert_layout %0 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.return %0 : tensor<1024xi32, [[$target_layout]]> + tt.return %1: tensor<1024xi32, #layout1> +} + +// CHECK-LABEL: remat +tt.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> + %2 = arith.muli %0, %1 : tensor<1024xi32, #layout0> + %3 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %4 = tt.splat %arg0 : i32 -> tensor<1024xi32, #layout0> + %5 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + %6 = arith.addi %3, %5 : tensor<1024xi32, #layout1> + tt.return %6: tensor<1024xi32, #layout1> + // CHECK: %[[A:.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[$target_layout]]> + // CHECK: %[[C:.+]] = arith.muli %[[A]], %[[A]] : tensor<1024xi32, [[$target_layout]]> + // CHECK: %[[D:.+]] = arith.addi %[[C]], %[[C]] : tensor<1024xi32, [[$target_layout]]> + // CHECK: tt.return %[[D]] : tensor<1024xi32, [[$target_layout]]> +} + +// Always rematerialize single value loads +// CHECK-LABEL: remat_single_value +tt.func @remat_single_value(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.splat %arg : !tt.ptr -> tensor<1x!tt.ptr, #layout1> + %1 = tt.load %0 : tensor<1x!tt.ptr, #layout1> + // CHECK-NOT: ttg.convert_layout + %2 = ttg.convert_layout %1 : tensor<1xi32, #layout1> -> tensor<1xi32, #layout0> + %3 = ttg.convert_layout %0 : tensor<1x!tt.ptr, #layout1> -> tensor<1x!tt.ptr, #layout0> + tt.store %3, %2 : tensor<1x!tt.ptr, #layout0> + tt.return +} + +// CHECK-LABEL: remat_fast_load +tt.func @remat_fast_load(%arg: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.splat %arg : !tt.ptr -> tensor<16x!tt.ptr, #layout1> + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1> + %2 = tt.addptr %0, %1 : tensor<16x!tt.ptr, #layout1>, tensor<16xi32, #layout1> + %3 = tt.load %2 : tensor<16x!tt.ptr, #layout1> + // CHECK-NOT: ttg.convert_layout + %4 = ttg.convert_layout %3 : tensor<16xi32, #layout1> -> tensor<16xi32, #layout0> + %5 = ttg.convert_layout %2 : tensor<16x!tt.ptr, #layout1> -> tensor<16x!tt.ptr, #layout0> + tt.store %5, %4 : tensor<16x!tt.ptr, #layout0> + tt.return +} + +// Hoist the convert on top of ext to make it cheaper. +// CHECK-LABEL: hoist_above_ext +tt.func @hoist_above_ext(%arg0: tensor<1024xf16, #layout0>, %arg1: f32) -> tensor<1024xf32, #layout1> { +// CHECK: %[[CVT:.+]] = ttg.convert_layout +// CHECK: arith.extf %[[CVT]] +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.return + %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> + %1 = tt.splat %arg1 : f32 -> tensor<1024xf32, #layout0> + %2 = arith.addf %0, %1 : tensor<1024xf32, #layout0> + %3 = ttg.convert_layout %2 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> + tt.return %3 : tensor<1024xf32, #layout1> +} + +// CHECK-LABEL: hoist_above_ext2 +tt.func @hoist_above_ext2(%arg0: tensor<1024xf16, #layout0>, %arg1: f16) -> tensor<1024xf32, #layout1> { +// CHECK: %[[CVT:.+]] = ttg.convert_layout +// CHECK: arith.extf %[[CVT]] +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.return + %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> + %1 = tt.splat %arg1 : f16 -> tensor<1024xf16, #layout0> + %2 = arith.extf %1 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> + %3 = arith.addf %0, %2 : tensor<1024xf32, #layout0> + %4 = ttg.convert_layout %3 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> + tt.return %4 : tensor<1024xf32, #layout1> +} + +/// CHECK-LABEL: hoist_above_fptofp +tt.func @hoist_above_fptofp(%arg0: tensor<1024xf8E4M3FNUZ, #layout0>) -> tensor<1024xf32, #layout1> { +// CHECK: %[[CVT:.+]] = ttg.convert_layout +// CHECK: tt.fp_to_fp %[[CVT]] +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.return + %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf8E4M3FNUZ, #layout0> -> tensor<1024xf32, #layout0> + %1 = ttg.convert_layout %0 : tensor<1024xf32, #layout0> -> tensor<1024xf32, #layout1> + tt.return %1 : tensor<1024xf32, #layout1> +} + +/// CHECK-LABEL: dont_hoist_above_trunc_fptofp +tt.func @dont_hoist_above_trunc_fptofp(%arg0: tensor<1024xf32, #layout0>) -> tensor<1024xf8E4M3FNUZ, #layout1> { +// CHECK-NOT: ttg.convert_layout +// CHECK: %[[FP8:.+]] = tt.fp_to_fp +// CHECK: ttg.convert_layout %[[FP8]] +// CHECK: tt.return + %0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1024xf32, #layout0> -> tensor<1024xf8E4M3FNUZ, #layout0> + %1 = ttg.convert_layout %0 : tensor<1024xf8E4M3FNUZ, #layout0> -> tensor<1024xf8E4M3FNUZ, #layout1> + tt.return %1 : tensor<1024xf8E4M3FNUZ, #layout1> +} + +// Hoist the convert on top of broadcast to make it cheaper. +// CHECK-LABEL: hoist_above_broadcast +tt.func @hoist_above_broadcast(%arg0: tensor<1024x1xf32, #layout2>, %arg1: f32) -> tensor<1024x128xf32, #layout3> { +// CHECK: %[[CVT:.+]] = ttg.convert_layout +// CHECK: tt.broadcast %[[CVT]] +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.return + %0 = tt.broadcast %arg0 : tensor<1024x1xf32, #layout2> -> tensor<1024x128xf32, #layout2> + %1 = tt.splat %arg1 : f32 -> tensor<1024x128xf32, #layout2> + %2 = arith.addf %0, %1 : tensor<1024x128xf32, #layout2> + %3 = ttg.convert_layout %2 : tensor<1024x128xf32, #layout2> -> tensor<1024x128xf32, #layout3> + tt.return %3 : tensor<1024x128xf32, #layout3> +} + + +// CHECK-LABEL: if +tt.func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK-NOT: ttg.convert_layout + %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1> + %0 = tt.get_program_id x : i32 + %1 = tt.splat %0 : i32 -> tensor<1024xi32, #layout1> + %2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout1> + %3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout1> + %4 = arith.cmpi sgt, %0, %arg0 : i32 + %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #layout0> + scf.if %4 { + %6 = ttg.convert_layout %2 : tensor<1024xi32, #layout1> -> tensor<1024xi32, #layout0> + tt.store %5, %6 : tensor<1024x!tt.ptr, #layout0> + } + tt.return +} + +// CHECK-LABEL: if_convert_else_not +tt.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0> + %0 = tt.get_program_id x : i32 + %1 = tt.splat %0 : i32 -> tensor<1024xi32, #layout0> + %9 = tt.splat %0 : i32 -> tensor<1024xi32, #layout1> + %2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0> + %3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout0> + %4 = arith.cmpi sgt, %0, %arg0 : i32 + %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #layout1> + %8 = scf.if %4 -> tensor<1024xi32, #layout1> { + %6 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + scf.yield %6 : tensor<1024xi32, #layout1> + } else { + scf.yield %9 : tensor<1024xi32, #layout1> + } + // CHECK-NOT: ttg.convert_layout + tt.store %5, %8 : tensor<1024x!tt.ptr, #layout1> + tt.return +} + +// CHECK-LABEL: if_not_else_convert +tt.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0> + %0 = tt.get_program_id x : i32 + %1 = tt.splat %0 : i32 -> tensor<1024xi32, #layout0> + %9 = tt.splat %0 : i32 -> tensor<1024xi32, #layout1> + %2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0> + %3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout0> + %4 = arith.cmpi sgt, %0, %arg0 : i32 + %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #layout1> + %8 = scf.if %4 -> tensor<1024xi32, #layout1> { + scf.yield %9 : tensor<1024xi32, #layout1> + } else { + %7 = ttg.convert_layout %3 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + scf.yield %7 : tensor<1024xi32, #layout1> + } + // CHECK-NOT: ttg.convert_layout + tt.store %5, %8 : tensor<1024x!tt.ptr, #layout1> + tt.return +} + +// CHECK-LABEL: if_else_both_convert +tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0> + %0 = tt.get_program_id x : i32 + %1 = tt.splat %0 : i32 -> tensor<1024xi32, #layout0> + %2 = arith.muli %1, %c32_i32 : tensor<1024xi32, #layout0> + %3 = arith.addi %2, %c32_i32 : tensor<1024xi32, #layout0> + %4 = arith.cmpi sgt, %0, %arg0 : i32 + %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #layout1> + %8 = scf.if %4 -> tensor<1024xi32, #layout1> { + %6 = ttg.convert_layout %2 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + scf.yield %6 : tensor<1024xi32, #layout1> + } else { + %7 = ttg.convert_layout %3 : tensor<1024xi32, #layout0> -> tensor<1024xi32, #layout1> + scf.yield %7 : tensor<1024xi32, #layout1> + } + // TODO(csigg): seems like the whole function is converted to layout1. + // disabledCHECK: ttg.convert_layout + // CHECK-NOT: ttg.convert_layout + tt.store %5, %8 : tensor<1024x!tt.ptr, #layout1> + tt.return +} + +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked0a = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked2a = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#slice2dim0 = #ttg.slice<{dim = 0, parent = #blocked2}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked5 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +// CHECK-DAG: [[$row_layout:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +// CHECK-DAG: [[$col_layout:#.*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> +// CHECK-DAG: [[$col_layout_novec:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> + +// CHECK-LABEL: @transpose +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { + // CHECK-NOT: ttg.convert_layout + // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]> + // CHECK: [[cvt_val:%.*]] = ttg.convert_layout [[loaded_val]] : tensor<64x64xf32, [[$row_layout]]> -> tensor<64x64xf32, [[$col_layout]]> + // CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64x!tt.ptr, [[$col_layout]]> + // CHECK: tt.return + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1> + %cst_0 = arith.constant dense : tensor<64x64xi1, #blocked1> + %00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1> + %01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0> + %1 = tt.expand_dims %00 {axis = 1 : i32} : tensor<64xi32, #slice1dim1> -> tensor<64x1xi32, #blocked1> + %2 = tt.splat %arg1 : i32 -> tensor<64x1xi32, #blocked1> + %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> + %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %11 = tt.splat %arg2 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %13 = tt.splat %arg3 : i32 -> tensor<1x64xi32, #blocked2> + %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2> + %15 = tt.broadcast %12 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %16 = tt.broadcast %14 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %17 = ttg.convert_layout %16 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %19 = ttg.convert_layout %10 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> + %20 = ttg.convert_layout %cst_0 : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> + %21 = ttg.convert_layout %cst : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> + %22 = tt.load %19, %20, %21 : tensor<64x64x!tt.ptr, #blocked3> + %23 = ttg.convert_layout %22 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> + %24 = ttg.convert_layout %18 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked4> + %25 = ttg.convert_layout %23 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked4> + %26 = ttg.convert_layout %cst_0 : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked4> + tt.store %24, %25, %26 : tensor<64x64x!tt.ptr, #blocked4> + tt.return +} +} + +// CHECK-LABEL: loop +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { + // CHECK-NOT: ttg.convert_layout + // CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr, [[$row_layout]]>) + // CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]> + // CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[$row_layout]]> + // CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr, [[$row_layout]]>, tensor<64x64xi32, [[$row_layout]]> + // CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr, [[$row_layout]]> + // CHECK-NEXT: } + // CHECK-NOT: ttg.convert_layout + // CHECK: {{.*}} = ttg.convert_layout [[loop_ret]]#0 : tensor<64x64xf32, [[$row_layout]]> -> tensor<64x64xf32, [[$col_layout_novec]]> + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.return + %cst = arith.constant dense : tensor<64x64xi1, #blocked1> + %cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1> + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1> + %00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1> + %01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0> + %1 = tt.expand_dims %00 {axis = 1 : i32} : tensor<64xi32, #slice1dim1> -> tensor<64x1xi32, #blocked1> + %2 = tt.splat %arg1 : i32 -> tensor<64x1xi32, #blocked1> + %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> + %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1>) { + %23 = ttg.convert_layout %arg7 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> + %24 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> + %25 = ttg.convert_layout %cst_1 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> + %26 = tt.load %23, %24, %25 : tensor<64x64x!tt.ptr, #blocked3> + %27 = ttg.convert_layout %26 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> + %28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1> + %29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1> + } + %12 = tt.splat %arg2 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %14 = tt.splat %arg3 : i32 -> tensor<1x64xi32, #blocked2> + %15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2> + %16 = tt.broadcast %13 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %17 = tt.broadcast %15 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %18 = ttg.convert_layout %17 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %20 = ttg.convert_layout %19 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %21 = ttg.convert_layout %11#0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked1> + %22 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked1> + tt.store %20, %21, %22 : tensor<64x64x!tt.ptr, #blocked1> + tt.return +} +} + +// CHECK-LABEL: loop_if +// CHECK-NOT: ttg.convert_layout +// CHECK: scf.for +// CHECK-NOT: ttg.convert_layout +// CHECK: scf.if +// CHECK-NOT: ttg.convert_layout +// CHECK: scf.yield +// CHECK: else +// CHECK: scf.yield +// CHECK-NOT: ttg.convert_layout +// CHECK: scf.yield +// CHECK: ttg.convert_layout +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.store +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func @loop_if(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) { + %cst = arith.constant dense : tensor<64x64xi1, #blocked1> + %cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1> + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %i0 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1> + %00 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1dim1> + %01 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice2dim0> + %1 = tt.expand_dims %00 {axis = 1 : i32} : tensor<64xi32, #slice1dim1> -> tensor<64x1xi32, #blocked1> + %2 = tt.splat %arg1 : i32 -> tensor<64x1xi32, #blocked1> + %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %6 = tt.expand_dims %01 {axis = 0 : i32} : tensor<64xi32, #slice2dim0> -> tensor<1x64xi32, #blocked2> + %7 = tt.broadcast %5 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %8 = tt.broadcast %6 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %9 = ttg.convert_layout %8 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1>) { + %33 = arith.cmpi "sgt", %arg5, %c0 : index + %34 = scf.if %33 -> (tensor<64x64xf32, #blocked1>) { + %23 = ttg.convert_layout %arg7 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked3> + %24 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked3> + %25 = ttg.convert_layout %cst_1 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked3> + %26 = tt.load %23, %24, %25 : tensor<64x64x!tt.ptr, #blocked3> + %27 = ttg.convert_layout %26 : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #blocked1> + scf.yield %27 : tensor<64x64xf32, #blocked1> + } else { + scf.yield %arg6 : tensor<64x64xf32, #blocked1> + } + %28 = arith.addf %arg6, %34 : tensor<64x64xf32, #blocked1> + %29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1> + } + %12 = tt.splat %arg2 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %14 = tt.splat %arg3 : i32 -> tensor<1x64xi32, #blocked2> + %15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2> + %16 = tt.broadcast %13 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %17 = tt.broadcast %15 : tensor<1x64xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %18 = ttg.convert_layout %17 : tensor<64x64xi32, #blocked2> -> tensor<64x64xi32, #blocked1> + %19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> + %20 = ttg.convert_layout %19 : tensor<64x64x!tt.ptr, #blocked1> -> tensor<64x64x!tt.ptr, #blocked1> + %21 = ttg.convert_layout %11#0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked1> + %22 = ttg.convert_layout %cst : tensor<64x64xi1, #blocked1> -> tensor<64x64xi1, #blocked1> + tt.store %20, %21, %22 : tensor<64x64x!tt.ptr, #blocked1> + tt.return +} +} + +// CHECK-LABEL: vecadd +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + // CHECK-NOT: ttg.convert_layout + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.splat %1 : i32 -> tensor<256xi32, #blocked5> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked5> + %4 = tt.splat %1 : i32 -> tensor<256xi32, #blocked5> + %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked5> + %6 = tt.splat %1 : i32 -> tensor<256xi32, #blocked5> + %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked5> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked5> + %9 = arith.addi %6, %7 : tensor<256xi32, #blocked5> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked5> + %11 = arith.addi %4, %5 : tensor<256xi32, #blocked5> + %12 = tt.addptr %8, %9 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> + %13 = tt.load %12 : tensor<256x!tt.ptr, #blocked5> + %14 = ttg.convert_layout %13 : tensor<256xf32, #blocked5> -> tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %15 = tt.addptr %10, %11 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> + %16 = tt.load %15 : tensor<256x!tt.ptr, #blocked5> + %17 = ttg.convert_layout %16 : tensor<256xf32, #blocked5> -> tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %18 = arith.addf %14, %17 : tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> + %19 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked5> + %20 = arith.addi %2, %3 : tensor<256xi32, #blocked5> + %21 = tt.addptr %19, %20 : tensor<256x!tt.ptr, #blocked5>, tensor<256xi32, #blocked5> + %22 = ttg.convert_layout %18 : tensor<256xf32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> -> tensor<256xf32, #blocked5> + tt.store %21, %22 : tensor<256x!tt.ptr, #blocked5> + tt.return +} +} + +// Select has args with different element types +// CHECK-LABEL: select +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) { + // CHECK-NOT: ttg.convert_layout + %cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2> + %cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2> + %c512 = arith.constant 512 : i32 + %c30000 = arith.constant 30000 : i32 + %c0 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<2048> : tensor<1x1xi32, #blocked2> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2> + %0 = tt.get_program_id x : i32 + %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked0> + %2 = ttg.convert_layout %1 : tensor<1xi32, #blocked0> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1xi32, #blocked1> + %4 = ttg.convert_layout %3 : tensor<1x1xi32, #blocked1> -> tensor<1x1xi32, #blocked2> + %5 = tt.splat %0 : i32 -> tensor<1x1xi32, #blocked2> + %6 = arith.addi %5, %4 : tensor<1x1xi32, #blocked2> + %7 = arith.cmpi "slt", %6, %cst_1 : tensor<1x1xi32, #blocked2> + %8 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked0> + %9 = ttg.convert_layout %8 : tensor<512xi32, #blocked0> -> tensor<512xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<512xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x512xi32, #blocked2> + %11 = arith.muli %6, %cst : tensor<1x1xi32, #blocked2> + %12 = tt.broadcast %11 : tensor<1x1xi32, #blocked2> -> tensor<1x512xi32, #blocked2> + %13 = tt.splat %arg0 : !tt.ptr -> tensor<1x512x!tt.ptr, #blocked2> + %14 = tt.broadcast %7 : tensor<1x1xi1, #blocked2> -> tensor<1x512xi1, #blocked2> + %15 = scf.for %arg3 = %c0 to %c30000 step %c512 iter_args(%arg4 = %cst_2) -> (tensor<1x512xf64, #blocked2>) : i32 { + %17 = tt.splat %arg3 : i32 -> tensor<1x512xi32, #blocked2> + %18 = arith.addi %17, %10 : tensor<1x512xi32, #blocked2> + %19 = arith.cmpi "slt", %18, %cst_0 : tensor<1x512xi32, #blocked2> + %20 = arith.addi %18, %12 : tensor<1x512xi32, #blocked2> + %21 = tt.addptr %13, %20 : tensor<1x512x!tt.ptr, #blocked2>, tensor<1x512xi32, #blocked2> + %22 = arith.andi %19, %14 : tensor<1x512xi1, #blocked2> + %23 = ttg.convert_layout %21 : tensor<1x512x!tt.ptr, #blocked2> -> tensor<1x512x!tt.ptr, #blocked3> + %24 = ttg.convert_layout %22 : tensor<1x512xi1, #blocked2> -> tensor<1x512xi1, #blocked3> + %25 = tt.load %23, %24 : tensor<1x512x!tt.ptr, #blocked3> + %26 = ttg.convert_layout %25 : tensor<1x512xf64, #blocked3> -> tensor<1x512xf64, #blocked2> + %27 = arith.andi %14, %19 : tensor<1x512xi1, #blocked2> + %28 = arith.cmpf "olt", %arg4, %26 : tensor<1x512xf64, #blocked2> + %29 = arith.andi %27, %28 : tensor<1x512xi1, #blocked2> + %30 = arith.select %29, %26, %arg4 : tensor<1x512xi1, #blocked2>, tensor<1x512xf64, #blocked2> + %31 = ttg.convert_layout %21 : tensor<1x512x!tt.ptr, #blocked2> -> tensor<1x512x!tt.ptr, #blocked3> + %32 = ttg.convert_layout %30 : tensor<1x512xf64, #blocked2> -> tensor<1x512xf64, #blocked3> + %33 = ttg.convert_layout %27 : tensor<1x512xi1, #blocked2> -> tensor<1x512xi1, #blocked3> + tt.store %31, %32, %33 : tensor<1x512x!tt.ptr, #blocked3> + scf.yield %30 : tensor<1x512xf64, #blocked2> + } + tt.return +} +} + +// Make sure the following IR doesn't hang the compiler. +// CHECK-LABEL: long_func +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: !tt.ptr {tt.divisibility = 16 : i32}, %arg13: !tt.ptr {tt.divisibility = 16 : i32}, %arg14: !tt.ptr {tt.divisibility = 16 : i32}, %arg15: !tt.ptr {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0> + %cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0> + %cst_1 = arith.constant dense<0.999499976> : tensor<1024xf32, #blocked0> + %cst_2 = arith.constant dense<1.000000e+04> : tensor<1024xf32, #blocked0> + %cst_3 = arith.constant dense<5000> : tensor<1024xi32, #blocked0> + %cst_4 = arith.constant dense<150> : tensor<1024xi32, #blocked0> + %cst_5 = arith.constant dense : tensor<1024xi1, #blocked0> + %cst_6 = arith.constant dense<2> : tensor<1024xi32, #blocked0> + %cst_7 = arith.constant dense<4999> : tensor<1024xi32, #blocked0> + %cst_8 = arith.constant dense<2499> : tensor<1024xi32, #blocked0> + %cst_9 = arith.constant dense<2500> : tensor<1024xi32, #blocked0> + %cst_10 = arith.constant dense<0.91629076> : tensor<1024xf32, #blocked0> + %c2499_i32 = arith.constant 2499 : i32 + %cst_11 = arith.constant dense<1024> : tensor<1024xi32, #blocked0> + %c1024_i32 = arith.constant 1024 : i32 + %cst_12 = arith.constant dense<1> : tensor<1024xi32, #blocked0> + %cst_13 = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked0> + %cst_14 = arith.constant dense<0> : tensor<1024xi32, #blocked0> + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked0> + %5 = arith.cmpi "slt", %4, %cst_11 : tensor<1024xi32, #blocked0> + %6 = tt.splat %arg5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %8 = ttg.convert_layout %7 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %9 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %10 = tt.load %8, %9 : tensor<1024x!tt.ptr, #blocked0a> + %11 = ttg.convert_layout %10 : tensor<1024xf32, #blocked0a> -> tensor<1024xf32, #blocked0> + %12 = tt.splat %arg7 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %14 = ttg.convert_layout %13 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked2a> + %15 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked2a> + %16 = tt.load %14, %15 : tensor<1024x!tt.ptr, #blocked2a> + %17 = ttg.convert_layout %16 : tensor<1024xi64, #blocked2a> -> tensor<1024xi64, #blocked0> + %18 = tt.splat %arg8 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> + %19 = tt.addptr %18, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %20 = ttg.convert_layout %19 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %21 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + %22 = tt.load %20, %21 : tensor<1024x!tt.ptr, #blocked0a> + %23 = ttg.convert_layout %22 : tensor<1024xf32, #blocked0a> -> tensor<1024xf32, #blocked0> + %24 = arith.subf %cst_13, %11 : tensor<1024xf32, #blocked0> + %25 = math.exp %24 : tensor<1024xf32, #blocked0> + %26 = arith.sitofp %cst_12 : tensor<1024xi32, #blocked0> to tensor<1024xf32, #blocked0> + %27 = arith.addf %25, %26 : tensor<1024xf32, #blocked0> + %28 = arith.divf %26, %27 : tensor<1024xf32, #blocked0> + %29 = tt.addptr %arg6, %c2499_i32 : !tt.ptr, i32 + %30 = tt.load %29 : !tt.ptr + %31 = arith.subf %11, %cst_10 : tensor<1024xf32, #blocked0> + %32 = arith.subf %cst_13, %31 : tensor<1024xf32, #blocked0> + %33 = math.exp %32 : tensor<1024xf32, #blocked0> + %34 = arith.addf %33, %26 : tensor<1024xf32, #blocked0> + %35 = arith.divf %26, %34 : tensor<1024xf32, #blocked0> + %36 = tt.splat %30 : f32 -> tensor<1024xf32, #blocked0> + %37 = arith.cmpf "oge", %36, %35 : tensor<1024xf32, #blocked0> + %38 = arith.select %37, %cst_14, %cst_9 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %39 = arith.select %37, %cst_8, %cst_7 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %40 = arith.subi %39, %38 : tensor<1024xi32, #blocked0> + %41 = arith.cmpi "slt", %40, %cst_14 : tensor<1024xi32, #blocked0> + %42 = arith.cmpi "ne", %41, %cst_5 : tensor<1024xi1, #blocked0> + %43 = arith.remsi %40, %cst_6 : tensor<1024xi32, #blocked0> + %44 = arith.cmpi "ne", %43, %cst_14 : tensor<1024xi32, #blocked0> + %45 = arith.divsi %40, %cst_6 : tensor<1024xi32, #blocked0> + %46 = arith.subi %45, %cst_12 : tensor<1024xi32, #blocked0> + %47 = arith.select %44, %46, %45 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %48 = arith.select %42, %47, %45 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %49 = arith.addi %38, %48 : tensor<1024xi32, #blocked0> + %50 = arith.cmpi "slt", %38, %39 : tensor<1024xi32, #blocked0> + %51 = arith.select %50, %49, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %52 = tt.splat %arg6 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> + %53 = tt.addptr %52, %51 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %54 = ttg.convert_layout %53 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %55 = tt.load %54 : tensor<1024x!tt.ptr, #blocked0> + %56 = arith.cmpf "oge", %55, %35 :tensor<1024xf32, #blocked0> + %57 = arith.cmpi "eq", %56, %cst_5 : tensor<1024xi1, #blocked0> + %58 = arith.andi %57, %50 : tensor<1024xi1, #blocked0> + %59 = arith.addi %51, %cst_12 : tensor<1024xi32, #blocked0> + %60 = arith.select %58, %59, %38 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %61 = arith.andi %56, %50 : tensor<1024xi1, #blocked0> + %62 = arith.select %61, %51, %39 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %63 = arith.cmpi "slt", %60, %62 : tensor<1024xi32, #blocked0> + %64 = arith.subi %62, %60 : tensor<1024xi32, #blocked0> + %65 = arith.cmpi "slt", %64, %cst_14 : tensor<1024xi32, #blocked0> + %66 = arith.cmpi "ne", %65, %cst_5 : tensor<1024xi1, #blocked0> + %67 = arith.remsi %64, %cst_6 : tensor<1024xi32, #blocked0> + %68 = arith.cmpi "ne", %67, %cst_14 : tensor<1024xi32, #blocked0> + %69 = arith.divsi %64, %cst_6 : tensor<1024xi32, #blocked0> + %70 = arith.subi %69, %cst_12 : tensor<1024xi32, #blocked0> + %71 = arith.select %68, %70, %69 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %72 = arith.select %66, %71, %69 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %73 = arith.addi %60, %72 : tensor<1024xi32, #blocked0> + %74 = arith.select %63, %73, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %75 = tt.addptr %52, %74 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %76 = ttg.convert_layout %75 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %77 = tt.load %76 : tensor<1024x!tt.ptr, #blocked0> + %78 = arith.cmpf "oge", %77, %35 :tensor<1024xf32, #blocked0> + %79 = arith.cmpi "eq", %78, %cst_5 : tensor<1024xi1, #blocked0> + %80 = arith.andi %79, %63 : tensor<1024xi1, #blocked0> + %81 = arith.addi %74, %cst_12 : tensor<1024xi32, #blocked0> + %82 = arith.select %80, %81, %60 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %83 = arith.andi %78, %63 : tensor<1024xi1, #blocked0> + %84 = arith.select %83, %74, %62 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %85 = arith.cmpi "slt", %82, %84 : tensor<1024xi32, #blocked0> + %86 = arith.subi %84, %82 : tensor<1024xi32, #blocked0> + %87 = arith.cmpi "slt", %86, %cst_14 : tensor<1024xi32, #blocked0> + %88 = arith.cmpi "ne", %87, %cst_5 : tensor<1024xi1, #blocked0> + %89 = arith.remsi %86, %cst_6 : tensor<1024xi32, #blocked0> + %90 = arith.cmpi "ne", %89, %cst_14 : tensor<1024xi32, #blocked0> + %91 = arith.divsi %86, %cst_6 : tensor<1024xi32, #blocked0> + %92 = arith.subi %91, %cst_12 : tensor<1024xi32, #blocked0> + %93 = arith.select %90, %92, %91 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %94 = arith.select %88, %93, %91 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %95 = arith.addi %82, %94 : tensor<1024xi32, #blocked0> + %96 = arith.select %85, %95, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %97 = tt.addptr %52, %96 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %98 = ttg.convert_layout %97 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %99 = tt.load %98 : tensor<1024x!tt.ptr, #blocked0> + %100 = arith.cmpf "oge", %99, %35 : tensor<1024xf32, #blocked0> + %101 = arith.cmpi "eq", %100, %cst_5 : tensor<1024xi1, #blocked0> + %102 = arith.andi %101, %85 : tensor<1024xi1, #blocked0> + %103 = arith.addi %96, %cst_12 : tensor<1024xi32, #blocked0> + %104 = arith.select %102, %103, %82 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %105 = arith.andi %100, %85 : tensor<1024xi1, #blocked0> + %106 = arith.select %105, %96, %84 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %107 = arith.cmpi "slt", %104, %106 : tensor<1024xi32, #blocked0> + %108 = arith.subi %106, %104 : tensor<1024xi32, #blocked0> + %109 = arith.cmpi "slt", %108, %cst_14 : tensor<1024xi32, #blocked0> + %110 = arith.cmpi "ne", %109, %cst_5 : tensor<1024xi1, #blocked0> + %111 = arith.remsi %108, %cst_6 : tensor<1024xi32, #blocked0> + %112 = arith.cmpi "ne", %111, %cst_14 : tensor<1024xi32, #blocked0> + %113 = arith.divsi %108, %cst_6 : tensor<1024xi32, #blocked0> + %114 = arith.subi %113, %cst_12 : tensor<1024xi32, #blocked0> + %115 = arith.select %112, %114, %113 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %116 = arith.select %110, %115, %113 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %117 = arith.addi %104, %116 : tensor<1024xi32, #blocked0> + %118 = arith.select %107, %117, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %119 = tt.addptr %52, %118 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %120 = ttg.convert_layout %119 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %121 = tt.load %120 : tensor<1024x!tt.ptr, #blocked0> + %122 = arith.cmpf "oge", %121, %35 : tensor<1024xf32, #blocked0> + %123 = arith.cmpi "eq", %122, %cst_5 : tensor<1024xi1, #blocked0> + %124 = arith.andi %123, %107 : tensor<1024xi1, #blocked0> + %125 = arith.addi %118, %cst_12 : tensor<1024xi32, #blocked0> + %126 = arith.select %124, %125, %104 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %127 = arith.andi %122, %107 : tensor<1024xi1, #blocked0> + %128 = arith.select %127, %118, %106 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %129 = arith.cmpi "slt", %126, %128 : tensor<1024xi32, #blocked0> + %130 = arith.subi %128, %126 : tensor<1024xi32, #blocked0> + %131 = arith.cmpi "slt", %130, %cst_14 : tensor<1024xi32, #blocked0> + %132 = arith.cmpi "ne", %131, %cst_5 : tensor<1024xi1, #blocked0> + %133 = arith.remsi %130, %cst_6 : tensor<1024xi32, #blocked0> + %134 = arith.cmpi "ne", %133, %cst_14 : tensor<1024xi32, #blocked0> + %135 = arith.divsi %130, %cst_6 : tensor<1024xi32, #blocked0> + %136 = arith.subi %135, %cst_12 : tensor<1024xi32, #blocked0> + %137 = arith.select %134, %136, %135 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %138 = arith.select %132, %137, %135 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %139 = arith.addi %126, %138 : tensor<1024xi32, #blocked0> + %140 = arith.select %129, %139, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %141 = tt.addptr %52, %140 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %142 = ttg.convert_layout %141 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %143 = tt.load %142 : tensor<1024x!tt.ptr, #blocked0> + %144 = arith.cmpf "oge", %143, %35 : tensor<1024xf32, #blocked0> + %145 = arith.cmpi "eq", %144, %cst_5 : tensor<1024xi1, #blocked0> + %146 = arith.andi %145, %129 : tensor<1024xi1, #blocked0> + %147 = arith.addi %140, %cst_12 : tensor<1024xi32, #blocked0> + %148 = arith.select %146, %147, %126 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %149 = arith.andi %144, %129 : tensor<1024xi1, #blocked0> + %150 = arith.select %149, %140, %128 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %151 = arith.cmpi "slt", %148, %150 : tensor<1024xi32, #blocked0> + %152 = arith.subi %150, %148 : tensor<1024xi32, #blocked0> + %153 = arith.cmpi "slt", %152, %cst_14 : tensor<1024xi32, #blocked0> + %154 = arith.cmpi "ne", %153, %cst_5 : tensor<1024xi1, #blocked0> + %155 = arith.remsi %152, %cst_6 : tensor<1024xi32, #blocked0> + %156 = arith.cmpi "ne", %155, %cst_14 : tensor<1024xi32, #blocked0> + %157 = arith.divsi %152, %cst_6 : tensor<1024xi32, #blocked0> + %158 = arith.subi %157, %cst_12 : tensor<1024xi32, #blocked0> + %159 = arith.select %156, %158, %157 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %160 = arith.select %154, %159, %157 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %161 = arith.addi %148, %160 : tensor<1024xi32, #blocked0> + %162 = arith.select %151, %161, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %163 = tt.addptr %52, %162 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %164 = ttg.convert_layout %163 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %165 = tt.load %164 : tensor<1024x!tt.ptr, #blocked0> + %166 = arith.cmpf "oge", %165, %35 : tensor<1024xf32, #blocked0> + %167 = arith.cmpi "eq", %166, %cst_5 : tensor<1024xi1, #blocked0> + %168 = arith.andi %167, %151 : tensor<1024xi1, #blocked0> + %169 = arith.addi %162, %cst_12 : tensor<1024xi32, #blocked0> + %170 = arith.select %168, %169, %148 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %171 = arith.andi %166, %151 : tensor<1024xi1, #blocked0> + %172 = arith.select %171, %162, %150 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %173 = arith.cmpi "slt", %170, %172 : tensor<1024xi32, #blocked0> + %174 = arith.subi %172, %170 : tensor<1024xi32, #blocked0> + %175 = arith.cmpi "slt", %174, %cst_14 : tensor<1024xi32, #blocked0> + %176 = arith.cmpi "ne", %175, %cst_5 : tensor<1024xi1, #blocked0> + %177 = arith.remsi %174, %cst_6 : tensor<1024xi32, #blocked0> + %178 = arith.cmpi "ne", %177, %cst_14 : tensor<1024xi32, #blocked0> + %179 = arith.divsi %174, %cst_6 : tensor<1024xi32, #blocked0> + %180 = arith.subi %179, %cst_12 : tensor<1024xi32, #blocked0> + %181 = arith.select %178, %180, %179 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %182 = arith.select %176, %181, %179 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %183 = arith.addi %170, %182 : tensor<1024xi32, #blocked0> + %184 = arith.select %173, %183, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %185 = tt.addptr %52, %184 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %186 = ttg.convert_layout %185 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %187 = tt.load %186 : tensor<1024x!tt.ptr, #blocked0> + %188 = arith.cmpf "oge", %187, %35 : tensor<1024xf32, #blocked0> + %189 = arith.cmpi "eq", %188, %cst_5 : tensor<1024xi1, #blocked0> + %190 = arith.andi %189, %173 : tensor<1024xi1, #blocked0> + %191 = arith.addi %184, %cst_12 : tensor<1024xi32, #blocked0> + %192 = arith.select %190, %191, %170 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %193 = arith.andi %188, %173 : tensor<1024xi1, #blocked0> + %194 = arith.select %193, %184, %172 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %195 = arith.cmpi "slt", %192, %194 : tensor<1024xi32, #blocked0> + %196 = arith.subi %194, %192 : tensor<1024xi32, #blocked0> + %197 = arith.cmpi "slt", %196, %cst_14 : tensor<1024xi32, #blocked0> + %198 = arith.cmpi "ne", %197, %cst_5 : tensor<1024xi1, #blocked0> + %199 = arith.remsi %196, %cst_6 : tensor<1024xi32, #blocked0> + %200 = arith.cmpi "ne", %199, %cst_14 : tensor<1024xi32, #blocked0> + %201 = arith.divsi %196, %cst_6 : tensor<1024xi32, #blocked0> + %202 = arith.subi %201, %cst_12 : tensor<1024xi32, #blocked0> + %203 = arith.select %200, %202, %201 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %204 = arith.select %198, %203, %201 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %205 = arith.addi %192, %204 : tensor<1024xi32, #blocked0> + %206 = arith.select %195, %205, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %207 = tt.addptr %52, %206 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %208 = ttg.convert_layout %207 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %209 = tt.load %208 : tensor<1024x!tt.ptr, #blocked0> + %210 = arith.cmpf "oge", %209, %35 :tensor<1024xf32, #blocked0> + %211 = arith.cmpi "eq", %210, %cst_5 : tensor<1024xi1, #blocked0> + %212 = arith.andi %211, %195 : tensor<1024xi1, #blocked0> + %213 = arith.addi %206, %cst_12 : tensor<1024xi32, #blocked0> + %214 = arith.select %212, %213, %192 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %215 = arith.andi %210, %195 : tensor<1024xi1, #blocked0> + %216 = arith.select %215, %206, %194 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %217 = arith.cmpi "slt", %214, %216 : tensor<1024xi32, #blocked0> + %218 = arith.subi %216, %214 : tensor<1024xi32, #blocked0> + %219 = arith.cmpi "slt", %218, %cst_14 : tensor<1024xi32, #blocked0> + %220 = arith.cmpi "ne", %219, %cst_5 : tensor<1024xi1, #blocked0> + %221 = arith.remsi %218, %cst_6 : tensor<1024xi32, #blocked0> + %222 = arith.cmpi "ne", %221, %cst_14 : tensor<1024xi32, #blocked0> + %223 = arith.divsi %218, %cst_6 : tensor<1024xi32, #blocked0> + %224 = arith.subi %223, %cst_12 : tensor<1024xi32, #blocked0> + %225 = arith.select %222, %224, %223 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %226 = arith.select %220, %225, %223 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %227 = arith.addi %214, %226 : tensor<1024xi32, #blocked0> + %228 = arith.select %217, %227, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %229 = tt.addptr %52, %228 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %230 = ttg.convert_layout %229 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %231 = tt.load %230 : tensor<1024x!tt.ptr, #blocked0> + %232 = arith.cmpf "oge", %231, %35 : tensor<1024xf32, #blocked0> + %233 = arith.cmpi "eq", %232, %cst_5 : tensor<1024xi1, #blocked0> + %234 = arith.andi %233, %217 : tensor<1024xi1, #blocked0> + %235 = arith.addi %228, %cst_12 : tensor<1024xi32, #blocked0> + %236 = arith.select %234, %235, %214 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %237 = arith.andi %232, %217 : tensor<1024xi1, #blocked0> + %238 = arith.select %237, %228, %216 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %239 = arith.cmpi "slt", %236, %238 : tensor<1024xi32, #blocked0> + %240 = arith.subi %238, %236 : tensor<1024xi32, #blocked0> + %241 = arith.cmpi "slt", %240, %cst_14 : tensor<1024xi32, #blocked0> + %242 = arith.cmpi "ne", %241, %cst_5 : tensor<1024xi1, #blocked0> + %243 = arith.remsi %240, %cst_6 : tensor<1024xi32, #blocked0> + %244 = arith.cmpi "ne", %243, %cst_14 : tensor<1024xi32, #blocked0> + %245 = arith.divsi %240, %cst_6 : tensor<1024xi32, #blocked0> + %246 = arith.subi %245, %cst_12 : tensor<1024xi32, #blocked0> + %247 = arith.select %244, %246, %245 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %248 = arith.select %242, %247, %245 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %249 = arith.addi %236, %248 : tensor<1024xi32, #blocked0> + %250 = arith.select %239, %249, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %251 = tt.addptr %52, %250 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %252 = ttg.convert_layout %251 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %253 = tt.load %252 : tensor<1024x!tt.ptr, #blocked0> + %254 = arith.cmpf "oge", %253, %35 : tensor<1024xf32, #blocked0> + %255 = arith.cmpi "eq", %254, %cst_5 : tensor<1024xi1, #blocked0> + %256 = arith.andi %255, %239 : tensor<1024xi1, #blocked0> + %257 = arith.addi %250, %cst_12 : tensor<1024xi32, #blocked0> + %258 = arith.select %256, %257, %236 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %259 = arith.andi %254, %239 : tensor<1024xi1, #blocked0> + %260 = arith.select %259, %250, %238 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %261 = arith.cmpi "slt", %258, %260 : tensor<1024xi32, #blocked0> + %262 = arith.subi %260, %258 : tensor<1024xi32, #blocked0> + %263 = arith.cmpi "slt", %262, %cst_14 : tensor<1024xi32, #blocked0> + %264 = arith.cmpi "ne", %263, %cst_5 : tensor<1024xi1, #blocked0> + %265 = arith.remsi %262, %cst_6 : tensor<1024xi32, #blocked0> + %266 = arith.cmpi "ne", %265, %cst_14 : tensor<1024xi32, #blocked0> + %267 = arith.divsi %262, %cst_6 : tensor<1024xi32, #blocked0> + %268 = arith.subi %267, %cst_12 : tensor<1024xi32, #blocked0> + %269 = arith.select %266, %268, %267 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %270 = arith.select %264, %269, %267 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %271 = arith.addi %258, %270 : tensor<1024xi32, #blocked0> + %272 = arith.select %261, %271, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %273 = tt.addptr %52, %272 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %274 = ttg.convert_layout %273 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %275 = tt.load %274 : tensor<1024x!tt.ptr, #blocked0> + %276 = arith.cmpf "oge", %275, %35 : tensor<1024xf32, #blocked0> + %277 = arith.cmpi "eq", %276, %cst_5 : tensor<1024xi1, #blocked0> + %278 = arith.andi %277, %261 : tensor<1024xi1, #blocked0> + %279 = arith.addi %272, %cst_12 : tensor<1024xi32, #blocked0> + %280 = arith.select %278, %279, %258 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %281 = arith.andi %276, %261 : tensor<1024xi1, #blocked0> + %282 = arith.select %281, %272, %260 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %283 = arith.cmpi "slt", %280, %282 : tensor<1024xi32, #blocked0> + %284 = arith.subi %282, %280 : tensor<1024xi32, #blocked0> + %285 = arith.cmpi "slt", %284, %cst_14 : tensor<1024xi32, #blocked0> + %286 = arith.cmpi "ne", %285, %cst_5 : tensor<1024xi1, #blocked0> + %287 = arith.remsi %284, %cst_6 : tensor<1024xi32, #blocked0> + %288 = arith.cmpi "ne", %287, %cst_14 : tensor<1024xi32, #blocked0> + %289 = arith.divsi %284, %cst_6 : tensor<1024xi32, #blocked0> + %290 = arith.subi %289, %cst_12 : tensor<1024xi32, #blocked0> + %291 = arith.select %288, %290, %289 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %292 = arith.select %286, %291, %289 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %293 = arith.addi %280, %292 : tensor<1024xi32, #blocked0> + %294 = arith.select %283, %293, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %295 = tt.addptr %52, %294 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %296 = ttg.convert_layout %295 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %297 = tt.load %296 : tensor<1024x!tt.ptr, #blocked0> + %298 = arith.cmpf "oge", %297, %35 :tensor<1024xf32, #blocked0> + %299 = arith.cmpi "eq", %298, %cst_5 : tensor<1024xi1, #blocked0> + %300 = arith.andi %299, %283 : tensor<1024xi1, #blocked0> + %301 = arith.addi %294, %cst_12 : tensor<1024xi32, #blocked0> + %302 = arith.select %300, %301, %280 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %303 = arith.extsi %cst_12 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> + %304 = arith.cmpi "eq", %17, %303 : tensor<1024xi64, #blocked0> + %305 = arith.fptosi %23 : tensor<1024xf32, #blocked0> to tensor<1024xi64, #blocked0> + %306 = arith.extsi %cst_14 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> + %307 = arith.cmpi "sgt", %306, %305 : tensor<1024xi64, #blocked0> + %308 = arith.extsi %cst_4 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> + %309 = arith.cmpi "sgt", %305, %308 : tensor<1024xi64, #blocked0> + %310 = arith.select %309, %306, %305 : tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0> + %311 = arith.select %307, %306, %310 : tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0> + %312 = arith.select %304, %311, %306 : tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0> + %313 = arith.extsi %cst_3 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> + %314 = arith.muli %312, %313 : tensor<1024xi64, #blocked0> + %315 = arith.extsi %302 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> + %316 = arith.addi %315, %314 : tensor<1024xi64, #blocked0> + %317 = arith.trunci %316 : tensor<1024xi64, #blocked0> to tensor<1024xi32, #blocked0> + %318 = arith.extsi %317 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> + %319 = tt.splat %arg9 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> + %320 = tt.addptr %319, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> + %321 = ttg.convert_layout %320 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %322 = tt.load %321 : tensor<1024x!tt.ptr, #blocked0> + %323 = arith.extf %cst_2 : tensor<1024xf32, #blocked0> to tensor<1024xf64, #blocked0> + %324 = arith.cmpf "ogt", %322, %323 : tensor<1024xf64, #blocked0> + %325 = tt.splat %arg10 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> + %326 = tt.addptr %325, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> + %327 = ttg.convert_layout %326 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %328 = tt.load %327 : tensor<1024x!tt.ptr, #blocked0> + %329 = arith.divf %328, %322 : tensor<1024xf64, #blocked0> + %330 = arith.truncf %329 : tensor<1024xf64, #blocked0> to tensor<1024xf32, #blocked0> + %331 = arith.mulf %330, %cst_1 : tensor<1024xf32, #blocked0> + %332 = arith.mulf %35, %cst_0 : tensor<1024xf32, #blocked0> + %333 = arith.addf %331, %332 : tensor<1024xf32, #blocked0> + %334 = arith.select %324, %333, %35 : tensor<1024xi1, #blocked0>, tensor<1024xf32, #blocked0> + %335 = tt.addptr %319, %317 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %336 = ttg.convert_layout %335 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %337 = tt.load %336 : tensor<1024x!tt.ptr, #blocked0> + %338 = arith.extf %cst : tensor<1024xf32, #blocked0> to tensor<1024xf64, #blocked0> + %339 = arith.mulf %337, %338 : tensor<1024xf64, #blocked0> + %340 = tt.addptr %325, %317 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %341 = ttg.convert_layout %340 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %342 = tt.load %341 : tensor<1024x!tt.ptr, #blocked0> + %343 = arith.mulf %342, %338 : tensor<1024xf64, #blocked0> + %344 = tt.splat %arg11 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> + %345 = tt.addptr %344, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %346 = ttg.convert_layout %345 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %347 = ttg.convert_layout %28 : tensor<1024xf32, #blocked0> -> tensor<1024xf32, #blocked0a> + %348 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + tt.store %346, %347, %348 : tensor<1024x!tt.ptr, #blocked0a> + %349 = tt.splat %arg12 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> + %350 = tt.addptr %349, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %351 = ttg.convert_layout %350 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %352 = ttg.convert_layout %317 : tensor<1024xi32, #blocked0> -> tensor<1024xi32, #blocked0a> + %353 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + tt.store %351, %352, %353 : tensor<1024x!tt.ptr, #blocked0a> + %354 = tt.splat %arg13 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> + %355 = tt.addptr %354, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> + %356 = ttg.convert_layout %355 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0a> + %357 = ttg.convert_layout %334 : tensor<1024xf32, #blocked0> -> tensor<1024xf32, #blocked0a> + %358 = ttg.convert_layout %5 : tensor<1024xi1, #blocked0> -> tensor<1024xi1, #blocked0a> + tt.store %356, %357, %358 : tensor<1024x!tt.ptr, #blocked0a> + %359 = tt.splat %arg14 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> + %360 = tt.addptr %359, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> + %361 = ttg.convert_layout %360 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %362 = ttg.convert_layout %339 : tensor<1024xf64, #blocked0> -> tensor<1024xf64, #blocked0> + tt.store %361, %362 : tensor<1024x!tt.ptr, #blocked0> + %363 = tt.splat %arg15 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked0> + %364 = tt.addptr %363, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> + %365 = ttg.convert_layout %364 : tensor<1024x!tt.ptr, #blocked0> -> tensor<1024x!tt.ptr, #blocked0> + %366 = ttg.convert_layout %343 : tensor<1024xf64, #blocked0> -> tensor<1024xf64, #blocked0> + tt.store %365, %366 : tensor<1024x!tt.ptr, #blocked0> + tt.return +} +} + +// A mnist model from torch inductor. +// Check if topological sort is working correct and there's no unnecessary convert +// CHECK-LABEL: mnist +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) { + // CHECK-NOT: ttg.convert_layout + %cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2> + %cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3> + %c16_i32 = arith.constant 16 : i32 + %cst_1 = arith.constant dense<64> : tensor<16x1xi32, #blocked2> + %cst_2 = arith.constant dense<0xFF800000> : tensor<16x16xf32, #blocked2> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked2> + %cst_4 = arith.constant dense<0> : tensor<16x16xi32, #blocked2> + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c16_i32 : i32 + %2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked0> + %3 = ttg.convert_layout %2 : tensor<16xi32, #blocked0> -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<16x1xi32, #blocked1> -> tensor<16x1xi32, #blocked2> + %6 = tt.splat %1 : i32 -> tensor<16x1xi32, #blocked2> + %7 = arith.addi %6, %5 : tensor<16x1xi32, #blocked2> + %8 = arith.cmpi "slt", %7, %cst_1 : tensor<16x1xi32, #blocked2> + %9 = ttg.convert_layout %2 : tensor<16xi32, #blocked0> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> + %11 = arith.cmpi "slt", %10, %cst_0 : tensor<1x16xi32, #blocked3> + %12 = arith.muli %7, %cst : tensor<16x1xi32, #blocked2> + %13 = tt.broadcast %10 : tensor<1x16xi32, #blocked3> -> tensor<16x16xi32, #blocked3> + %14 = ttg.convert_layout %13 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked2> + %15 = tt.broadcast %12 : tensor<16x1xi32, #blocked2> -> tensor<16x16xi32, #blocked2> + %16 = arith.addi %14, %15 : tensor<16x16xi32, #blocked2> + %17 = tt.splat %arg0 : !tt.ptr -> tensor<16x16x!tt.ptr, #blocked2> + %18 = tt.addptr %17, %16 : tensor<16x16x!tt.ptr, #blocked2>, tensor<16x16xi32, #blocked2> + %19 = tt.broadcast %11 : tensor<1x16xi1, #blocked3> -> tensor<16x16xi1, #blocked3> + %20 = ttg.convert_layout %19 : tensor<16x16xi1, #blocked3> -> tensor<16x16xi1, #blocked2> + %21 = tt.broadcast %8 : tensor<16x1xi1, #blocked2> -> tensor<16x16xi1, #blocked2> + %22 = arith.andi %20, %21 : tensor<16x16xi1, #blocked2> + %23 = ttg.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %24 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + %25 = tt.load %23, %24 : tensor<16x16x!tt.ptr, #blocked4> + %26 = ttg.convert_layout %25 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> + %27 = arith.cmpf "olt", %cst_2, %26 : tensor<16x16xf32, #blocked2> + %28 = arith.andi %22, %27 : tensor<16x16xi1, #blocked2> + %29 = arith.select %28, %26, %cst_2 : tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2> + %30 = "tt.reduce" (%29) ({ + ^bb0(%arg4: f32, %arg5: f32): + %max = arith.maximumf %arg4, %arg5 : f32 + tt.reduce.return %max : f32 + }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %31 = ttg.convert_layout %30 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16xf32, #blocked0> + %32 = ttg.convert_layout %31 : tensor<16xf32, #blocked0> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xf32, #blocked1> + %34 = ttg.convert_layout %33 : tensor<16x1xf32, #blocked1> -> tensor<16x1xf32, #blocked2> + %35 = arith.sitofp %cst_4 : tensor<16x16xi32, #blocked2> to tensor<16x16xf32, #blocked2> + %36 = arith.addf %35, %cst_3 : tensor<16x16xf32, #blocked2> + %37 = ttg.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %38 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + %39 = tt.load %37, %38 : tensor<16x16x!tt.ptr, #blocked4> + %40 = ttg.convert_layout %39 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> + %41 = tt.broadcast %34 : tensor<16x1xf32, #blocked2> -> tensor<16x16xf32, #blocked2> + %42 = arith.subf %40, %41 : tensor<16x16xf32, #blocked2> + %43 = math.exp %42 : tensor<16x16xf32, #blocked2> + %44 = arith.addf %36, %43 : tensor<16x16xf32, #blocked2> + %45 = arith.select %22, %44, %36 : tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2> + %46 = "tt.reduce" (%45) ({ + ^bb0(%arg4: f32, %arg5: f32): + %add = arith.addf %arg4, %arg5 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %47 = ttg.convert_layout %46 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16xf32, #blocked0> + %48 = ttg.convert_layout %47 : tensor<16xf32, #blocked0> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %49 = tt.expand_dims %48 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xf32, #blocked1> + %50 = ttg.convert_layout %49 : tensor<16x1xf32, #blocked1> -> tensor<16x1xf32, #blocked2> + %51 = ttg.convert_layout %18 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %52 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + %53 = tt.load %51, %52 : tensor<16x16x!tt.ptr, #blocked4> + %54 = ttg.convert_layout %53 : tensor<16x16xf32, #blocked4> -> tensor<16x16xf32, #blocked2> + %55 = arith.subf %54, %41 : tensor<16x16xf32, #blocked2> + %56 = math.log %50 : tensor<16x1xf32, #blocked2> + %57 = tt.broadcast %56 : tensor<16x1xf32, #blocked2> -> tensor<16x16xf32, #blocked2> + %58 = arith.subf %55, %57 : tensor<16x16xf32, #blocked2> + %59 = tt.splat %arg1 : !tt.ptr -> tensor<16x16x!tt.ptr, #blocked2> + %60 = tt.addptr %59, %16 : tensor<16x16x!tt.ptr, #blocked2>, tensor<16x16xi32, #blocked2> + %61 = ttg.convert_layout %60 : tensor<16x16x!tt.ptr, #blocked2> -> tensor<16x16x!tt.ptr, #blocked4> + %62 = ttg.convert_layout %58 : tensor<16x16xf32, #blocked2> -> tensor<16x16xf32, #blocked4> + %63 = ttg.convert_layout %22 : tensor<16x16xi1, #blocked2> -> tensor<16x16xi1, #blocked4> + tt.store %61, %62, %63 : tensor<16x16x!tt.ptr, #blocked4> + tt.return +} +} + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +// cmpf and cmpi have different operands and result types +// CHECK-LABEL: cmp +module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + %c64 = arith.constant 64 : i32 + %c2048 = arith.constant 2048 : i32 + %c0 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant dense<-3.40282347E+38> : tensor<64x64xf32, #blocked2> + %cst_0 = arith.constant dense<4194304> : tensor<64x1xi32, #blocked2> + %cst_1 = arith.constant dense<12> : tensor<64x1xi32, #blocked2> + %cst_2 = arith.constant dense<2048> : tensor<1x64xi32, #blocked3> + %cst_3 = arith.constant dense<0> : tensor<64x64xi32, #blocked2> + %cst_4 = arith.constant dense<2048> : tensor<64x1xi32, #blocked2> + %cst_5 = arith.constant dense<49152> : tensor<64x1xi32, #blocked2> + %cst_6 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked2> + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0> + %3 = ttg.convert_layout %2 : tensor<64xi32, #blocked0> -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<64x1xi32, #blocked1> -> tensor<64x1xi32, #blocked2> + %6 = tt.splat %1 : i32 -> tensor<64x1xi32, #blocked2> + %7 = arith.addi %6, %5 : tensor<64x1xi32, #blocked2> + %8 = arith.cmpi "slt", %7, %cst_5 : tensor<64x1xi32, #blocked2> + %9 = ttg.convert_layout %2 : tensor<64xi32, #blocked0> -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x64xi32, #blocked3> + %11 = arith.remsi %7, %cst_4 : tensor<64x1xi32, #blocked2> + %12 = arith.divsi %7, %cst_4 : tensor<64x1xi32, #blocked2> + %13 = arith.sitofp %cst_3 : tensor<64x64xi32, #blocked2> to tensor<64x64xf32, #blocked2> + %14 = arith.addf %13, %cst_6 : tensor<64x64xf32, #blocked2> + %15 = arith.muli %7, %cst_4 : tensor<64x1xi32, #blocked2> + %16 = tt.broadcast %15 : tensor<64x1xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %17 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> + %18 = tt.broadcast %8 : tensor<64x1xi1, #blocked2> -> tensor<64x64xi1, #blocked2> + %19 = arith.muli %11, %cst_4 : tensor<64x1xi32, #blocked2> + %20 = tt.broadcast %19 : tensor<64x1xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %21 = arith.divsi %12, %cst_1 : tensor<64x1xi32, #blocked2> + %22 = arith.muli %21, %cst_0 : tensor<64x1xi32, #blocked2> + %23 = tt.broadcast %22 : tensor<64x1xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %24 = tt.splat %arg1 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> + %25 = scf.for %arg6 = %c0 to %c2048 step %c64 iter_args(%arg7 = %14) -> (tensor<64x64xf32, #blocked2>) : i32 { + %45 = tt.splat %arg6 : i32 -> tensor<1x64xi32, #blocked3> + %46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3> + %47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3> + %48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3> + %49 = ttg.convert_layout %48 : tensor<64x64xi32, #blocked3> -> tensor<64x64xi32, #blocked2> + %50 = arith.addi %49, %16 : tensor<64x64xi32, #blocked2> + %51 = tt.addptr %17, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> + %52 = tt.broadcast %47 : tensor<1x64xi1, #blocked3> -> tensor<64x64xi1, #blocked3> + %53 = ttg.convert_layout %52 : tensor<64x64xi1, #blocked3> -> tensor<64x64xi1, #blocked2> + %54 = arith.andi %53, %18 : tensor<64x64xi1, #blocked2> + %55 = ttg.convert_layout %51 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> + %56 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> + %57 = tt.load %55, %56 : tensor<64x64x!tt.ptr, #blocked4> + %58 = ttg.convert_layout %57 : tensor<64x64xf16, #blocked4> -> tensor<64x64xf16, #blocked2> + %59 = arith.extf %58 : tensor<64x64xf16, #blocked2> to tensor<64x64xf32, #blocked2> + %60 = arith.addi %49, %20 : tensor<64x64xi32, #blocked2> + %61 = arith.addi %60, %23 : tensor<64x64xi32, #blocked2> + %62 = tt.addptr %24, %61 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> + %63 = ttg.convert_layout %62 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> + %64 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> + %65 = tt.load %63, %64 : tensor<64x64x!tt.ptr, #blocked5> + %66 = ttg.convert_layout %65 : tensor<64x64xf32, #blocked5> -> tensor<64x64xf32, #blocked2> + %67 = arith.addf %59, %66 : tensor<64x64xf32, #blocked2> + %68 = arith.cmpf "une", %67, %67 : tensor<64x64xf32, #blocked2> + %69 = arith.cmpf "ogt", %67, %cst : tensor<64x64xf32, #blocked2> + %70 = arith.select %69, %67, %cst : tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2> + %71 = arith.select %68, %67, %70 : tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2> + %72 = math.exp %71 : tensor<64x64xf32, #blocked2> + %73 = arith.addf %arg7, %72 : tensor<64x64xf32, #blocked2> + %74 = arith.select %54, %73, %arg7 : tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2> + scf.yield %74 : tensor<64x64xf32, #blocked2> + } + %26 = "tt.reduce" (%25) ({ + ^bb0(%arg8: f32, %arg9: f32): + %add = arith.addf %arg8, %arg9 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %27 = ttg.convert_layout %26 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64xf32, #blocked0> + %28 = ttg.convert_layout %27 : tensor<64xf32, #blocked0> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %29 = tt.expand_dims %28 {axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xf32, #blocked1> + %30 = ttg.convert_layout %29 : tensor<64x1xf32, #blocked1> -> tensor<64x1xf32, #blocked2> + %31 = arith.muli %7, %cst_4 : tensor<64x1xi32, #blocked2> + %32 = tt.broadcast %31 : tensor<64x1xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %33 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> + %34 = tt.broadcast %8 : tensor<64x1xi1, #blocked2> -> tensor<64x64xi1, #blocked2> + %35 = arith.muli %11, %cst_4 : tensor<64x1xi32, #blocked2> + %36 = tt.broadcast %35 : tensor<64x1xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %37 = arith.divsi %12, %cst_1 : tensor<64x1xi32, #blocked2> + %38 = arith.muli %37, %cst_0 : tensor<64x1xi32, #blocked2> + %39 = tt.broadcast %38 : tensor<64x1xi32, #blocked2> -> tensor<64x64xi32, #blocked2> + %40 = tt.splat %arg1 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> + %41 = tt.broadcast %30 : tensor<64x1xf32, #blocked2> -> tensor<64x64xf32, #blocked2> + %42 = tt.splat %arg2 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> + %43 = tt.splat %arg3 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> + scf.for %arg6 = %c0 to %c2048 step %c64 : i32 { + %45 = tt.splat %arg6 : i32 -> tensor<1x64xi32, #blocked3> + %46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3> + %47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3> + %48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3> + %49 = ttg.convert_layout %48 : tensor<64x64xi32, #blocked3> -> tensor<64x64xi32, #blocked2> + %50 = arith.addi %49, %32 : tensor<64x64xi32, #blocked2> + %51 = tt.addptr %33, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> + %52 = tt.broadcast %47 : tensor<1x64xi1, #blocked3> -> tensor<64x64xi1, #blocked3> + %53 = ttg.convert_layout %52 : tensor<64x64xi1, #blocked3> -> tensor<64x64xi1, #blocked2> + %54 = arith.andi %53, %34 : tensor<64x64xi1, #blocked2> + %55 = ttg.convert_layout %51 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> + %56 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> + %57 = tt.load %55, %56 : tensor<64x64x!tt.ptr, #blocked4> + %58 = ttg.convert_layout %57 : tensor<64x64xf16, #blocked4> -> tensor<64x64xf16, #blocked2> + %59 = arith.extf %58 : tensor<64x64xf16, #blocked2> to tensor<64x64xf32, #blocked2> + %60 = arith.addi %49, %36 : tensor<64x64xi32, #blocked2> + %61 = arith.addi %60, %39 : tensor<64x64xi32, #blocked2> + %62 = tt.addptr %40, %61 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> + %63 = ttg.convert_layout %62 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> + %64 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> + %65 = tt.load %63, %64 : tensor<64x64x!tt.ptr, #blocked5> + %66 = ttg.convert_layout %65 : tensor<64x64xf32, #blocked5> -> tensor<64x64xf32, #blocked2> + %67 = arith.addf %59, %66 : tensor<64x64xf32, #blocked2> + %68 = arith.cmpf "une", %67, %67 : tensor<64x64xf32, #blocked2> + %69 = arith.cmpf "ogt", %67, %cst : tensor<64x64xf32, #blocked2> + %70 = arith.select %69, %67, %cst : tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2> + %71 = arith.select %68, %67, %70 : tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2> + %72 = math.exp %71 : tensor<64x64xf32, #blocked2> + %73 = arith.divf %72, %41 : tensor<64x64xf32, #blocked2> + %74 = tt.addptr %42, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> + %75 = ttg.convert_layout %74 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked5> + %76 = ttg.convert_layout %73 : tensor<64x64xf32, #blocked2> -> tensor<64x64xf32, #blocked5> + %77 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked5> + tt.store %75, %76, %77 : tensor<64x64x!tt.ptr, #blocked5> + %78 = tt.addptr %43, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> + %79 = arith.truncf %73 : tensor<64x64xf32, #blocked2> to tensor<64x64xf16, #blocked2> + %80 = ttg.convert_layout %78 : tensor<64x64x!tt.ptr, #blocked2> -> tensor<64x64x!tt.ptr, #blocked4> + %81 = ttg.convert_layout %79 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #blocked4> + %82 = ttg.convert_layout %54 : tensor<64x64xi1, #blocked2> -> tensor<64x64xi1, #blocked4> + tt.store %80, %81, %82 : tensor<64x64x!tt.ptr, #blocked4> + } + tt.return +} +} + +// ----- + +// Just make sure it doesn't crash on non-tensor types. +// CHECK-LABEL: if_no_tensor +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func public @if_no_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK-NOT: ttg.convert_layout + %c-1_i64 = arith.constant -1 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %c-1_i32 = arith.constant -1 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.addptr %arg3, %0 : !tt.ptr, i32 + %2 = tt.load %1 : !tt.ptr + %3 = arith.cmpi eq, %2, %c-1_i64 : i64 + %4 = arith.select %3, %c-1_i32, %arg2 : i32 + %5 = scf.if %3 -> (!tt.ptr) { + scf.yield %arg0 : !tt.ptr + } else { + %10 = tt.addptr %arg0, %2 : !tt.ptr, i64 + scf.yield %10 : !tt.ptr + } + %6 = arith.extsi %4 : i32 to i64 + %7 = arith.cmpi slt, %2, %6 : i64 + %8 = tt.load %5, %7, %cst : !tt.ptr + %9 = tt.addptr %arg1, %0 : !tt.ptr, i32 + tt.store %9, %8 : !tt.ptr + tt.return +} +} + +// ----- + +// Check if the SimplifyReduceCvt rewriter pattern doesn't hang. +// CHECK-LABEL: reduce_cvt +// CHECK-NOT: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 2 : i32, "ttg.num-ctas" = 1 : i32} { + tt.func public @reduce_cvt1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32) { + %cst = arith.constant dense<0> : tensor<1x2xi32, #blocked> + %cst_0 = arith.constant dense<2> : tensor<1x2xi32, #blocked> + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<2xi32, #blocked1> -> tensor<2xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x2xi32, #blocked> + %3 = arith.cmpi "slt", %2, %cst_0 : tensor<1x2xi32, #blocked> + %4 = "tt.reduce" (%cst) ({ + ^bb0(%arg3: i32, %arg4: i32): + %add = arith.addi %arg3, %arg4 : i32 + tt.reduce.return %add : i32 + }) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = ttg.convert_layout %4 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + %6 = ttg.convert_layout %5 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> + %8 = ttg.convert_layout %7 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<1x2x!tt.ptr, #blocked> + %10 = tt.addptr %9, %2 : tensor<1x2x!tt.ptr, #blocked>, tensor<1x2xi32, #blocked> + %11 = tt.broadcast %8 : tensor<1x1xi32, #blocked> -> tensor<1x2xi32, #blocked> + %12 = arith.extsi %11 : tensor<1x2xi32, #blocked> to tensor<1x2xi64, #blocked> + %13 = ttg.convert_layout %10 : tensor<1x2x!tt.ptr, #blocked> -> tensor<1x2x!tt.ptr, #blocked3> + %14 = ttg.convert_layout %12 : tensor<1x2xi64, #blocked> -> tensor<1x2xi64, #blocked3> + %15 = ttg.convert_layout %3 : tensor<1x2xi1, #blocked> -> tensor<1x2xi1, #blocked3> + tt.store %13, %14, %15 : tensor<1x2x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +// CHECK-LABEL: reduce_cvt2 +// Match the reduction +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.reduce +// CHECK-SAME: axis = 1 +// CHECK: (tensor<1x256xf32, #{{.*}}>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #{{.*}}}>> +// CHECK: ttg.convert_layout +// CHECK: tt.expand_dims +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.return +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { + tt.func public @reduce_cvt2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked> + %c3136_i32 = arith.constant 3136 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<3.136000e+03> : tensor<1x1xf32, #blocked> + %cst_1 = arith.constant dense<50176> : tensor<1x256xi32, #blocked> + %cst_2 = arith.constant dense<196> : tensor<1x1xi32, #blocked> + %cst_3 = arith.constant dense<196> : tensor<1x256xi32, #blocked> + %cst_4 = arith.constant dense<3136> : tensor<1x256xi32, #blocked> + %cst_5 = arith.constant dense<256> : tensor<1x1xi32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32, #blocked1> + %2 = ttg.convert_layout %1 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> + %4 = ttg.convert_layout %3 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked> + %5 = tt.splat %0 : i32 -> tensor<1x1xi32, #blocked> + %6 = arith.addi %5, %4 : tensor<1x1xi32, #blocked> + %7 = arith.cmpi "slt", %6, %cst_5 : tensor<1x1xi32, #blocked> + %8 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<256xi32, #blocked1> -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %11 = arith.muli %6, %cst_2 : tensor<1x1xi32, #blocked> + %12 = tt.broadcast %11 : tensor<1x1xi32, #blocked> -> tensor<1x256xi32, #blocked> + %13 = tt.splat %arg1 : !tt.ptr -> tensor<1x256x!tt.ptr, #blocked> + %14 = tt.broadcast %7 : tensor<1x1xi1, #blocked> -> tensor<1x256xi1, #blocked> + %15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args(%arg6 = %cst) -> (tensor<1x256xf32, #blocked>) : i32 { + %43 = tt.splat %arg5 : i32 -> tensor<1x256xi32, #blocked> + %44 = arith.addi %43, %10 : tensor<1x256xi32, #blocked> + %45 = arith.cmpi "slt", %44, %cst_4 : tensor<1x256xi32, #blocked> + %46 = arith.remsi %44, %cst_3 : tensor<1x256xi32, #blocked> + %47 = arith.divsi %44, %cst_3 : tensor<1x256xi32, #blocked> + %48 = arith.addi %46, %12 : tensor<1x256xi32, #blocked> + %49 = arith.muli %47, %cst_1 : tensor<1x256xi32, #blocked> + %50 = arith.addi %48, %49 : tensor<1x256xi32, #blocked> + %51 = tt.addptr %13, %50 : tensor<1x256x!tt.ptr, #blocked>, tensor<1x256xi32, #blocked> + %52 = arith.andi %45, %14 : tensor<1x256xi1, #blocked> + %53 = ttg.convert_layout %51 : tensor<1x256x!tt.ptr, #blocked> -> tensor<1x256x!tt.ptr, #blocked3> + %54 = ttg.convert_layout %52 : tensor<1x256xi1, #blocked> -> tensor<1x256xi1, #blocked3> + %55 = ttg.convert_layout %cst : tensor<1x256xf32, #blocked> -> tensor<1x256xf32, #blocked3> + %56 = tt.load %53, %54, %55 : tensor<1x256x!tt.ptr, #blocked3> + %57 = ttg.convert_layout %56 : tensor<1x256xf32, #blocked3> -> tensor<1x256xf32, #blocked> + %58 = arith.addf %arg6, %57 : tensor<1x256xf32, #blocked> + %59 = arith.select %52, %58, %arg6 : tensor<1x256xi1, #blocked>, tensor<1x256xf32, #blocked> + scf.yield %59 : tensor<1x256xf32, #blocked> + } + %16 = "tt.reduce" (%15) ({ + ^bb0(%arg7: f32, %arg8: f32): + %add = arith.addf %arg7, %arg8 : f32 + tt.reduce.return %add : f32 + + }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = ttg.convert_layout %16 : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %18 = ttg.convert_layout %17 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %19 = tt.expand_dims %18 {axis = 1 : i32} : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xf32, #blocked2> + %20 = ttg.convert_layout %19 : tensor<1x1xf32, #blocked2> -> tensor<1x1xf32, #blocked> + %21 = arith.divf %20, %cst_0 : tensor<1x1xf32, #blocked> + %22 = tt.splat %arg0 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> + %23 = tt.addptr %22, %6 : tensor<1x1x!tt.ptr, #blocked>, tensor<1x1xi32, #blocked> + %24 = ttg.convert_layout %23 : tensor<1x1x!tt.ptr, #blocked> -> tensor<1x1x!tt.ptr, #blocked> + %25 = ttg.convert_layout %21 : tensor<1x1xf32, #blocked> -> tensor<1x1xf32, #blocked> + %26 = ttg.convert_layout %7 : tensor<1x1xi1, #blocked> -> tensor<1x1xi1, #blocked> + tt.store %24, %25, %26 : tensor<1x1x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// Ensure that RematerializeForward doesn't apply when a convert has multiple uses +// CHECK-LABEL: loop_convert_multi_uses +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { + tt.func public @loop_convert_multi_uses(%arg0: i32 {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32, %arg13: !tt.ptr {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0xFF800000> : tensor<16xf32, #blocked> + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<16xf32, #blocked> + %cst_1 = arith.constant dense<1> : tensor<16xi32, #blocked> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked1> + %cst_3 = arith.constant dense<1> : tensor<16x1xi32, #blocked1> + %c16_i32 = arith.constant 16 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.divsi %1, %arg0 : i32 + %3 = arith.remsi %1, %arg0 : i32 + %4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked> + %5 = arith.muli %0, %c16_i32 : i32 + %6 = tt.splat %5 : i32 -> tensor<16xi32, #blocked> + %7 = arith.addi %6, %4 : tensor<16xi32, #blocked> + %8 = arith.muli %2, %arg3 : i32 + %9 = arith.muli %3, %arg4 : i32 + %10 = arith.addi %8, %9 : i32 + %11 = ttg.convert_layout %7 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %12 = tt.expand_dims %11 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2> + %13 = ttg.convert_layout %12 : tensor<16x1xi32, #blocked2> -> tensor<16x1xi32, #blocked1> + %14 = tt.splat %arg6 : i32 -> tensor<16x1xi32, #blocked1> + %15 = arith.muli %13, %14 : tensor<16x1xi32, #blocked1> + %16 = ttg.convert_layout %4 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %17 = tt.expand_dims %16 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> + %18 = tt.broadcast %15 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> + %19 = tt.broadcast %17 : tensor<1x16xi32, #blocked3> -> tensor<16x16xi32, #blocked3> + %20 = ttg.convert_layout %19 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked1> + %21 = arith.addi %18, %20 : tensor<16x16xi32, #blocked1> + %22 = tt.splat %arg2 : !tt.ptr -> tensor<16x16x!tt.ptr, #blocked1> + %23 = arith.cmpi "slt", %13, %cst_3 : tensor<16x1xi32, #blocked1> + %24 = tt.broadcast %23 : tensor<16x1xi1, #blocked1> -> tensor<16x16xi1, #blocked1> + %25 = arith.truncf %cst_2 : tensor<16x16xf32, #blocked1> to tensor<16x16xf16, #blocked1> + %26 = arith.muli %2, %arg11 : i32 + %27 = arith.muli %3, %arg12 : i32 + %28 = arith.addi %26, %27 : i32 + %29 = tt.splat %arg10 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %30 = arith.cmpi "slt", %7, %cst_1 : tensor<16xi32, #blocked> + %31 = arith.muli %2, %arg8 : i32 + %32 = arith.muli %3, %arg9 : i32 + %33 = arith.addi %31, %32 : i32 + %34 = tt.splat %arg7 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %35:3 = scf.for %arg17 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg18 = %cst_2, %arg19 = %cst_0, %arg20 = %cst) -> (tensor<16x16xf32, #blocked1>, tensor<16xf32, #blocked>, tensor<16xf32, #blocked>) : i32 { + %60 = arith.muli %arg17, %arg5 : i32 + %61 = arith.addi %10, %60 : i32 + %62 = tt.splat %61 : i32 -> tensor<16x16xi32, #blocked1> + %63 = arith.addi %62, %21 : tensor<16x16xi32, #blocked1> + %64 = tt.addptr %22, %63 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> + %65 = ttg.convert_layout %64 : tensor<16x16x!tt.ptr, #blocked1> -> tensor<16x16x!tt.ptr, #blocked4> + %66 = ttg.convert_layout %24 : tensor<16x16xi1, #blocked1> -> tensor<16x16xi1, #blocked4> + %67 = ttg.convert_layout %25 : tensor<16x16xf16, #blocked1> -> tensor<16x16xf16, #blocked4> + %68 = tt.load %65, %66, %67 : tensor<16x16x!tt.ptr, #blocked4> + %69 = ttg.convert_layout %68 : tensor<16x16xf16, #blocked4> -> tensor<16x16xf16, #blocked1> + %70 = arith.addi %28, %arg17 : i32 + %71 = tt.splat %70 : i32 -> tensor<16xi32, #blocked> + %72 = arith.addi %71, %7 : tensor<16xi32, #blocked> + %73 = tt.addptr %29, %72 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + %74 = ttg.convert_layout %73 : tensor<16x!tt.ptr, #blocked> -> tensor<16x!tt.ptr, #blocked> + %75 = ttg.convert_layout %30 : tensor<16xi1, #blocked> -> tensor<16xi1, #blocked> + %76 = ttg.convert_layout %cst_0 : tensor<16xf32, #blocked> -> tensor<16xf32, #blocked> + %77 = tt.load %74, %75, %76 : tensor<16x!tt.ptr, #blocked> + %78 = arith.addi %33, %arg17 : i32 + %79 = tt.splat %78 : i32 -> tensor<16xi32, #blocked> + %80 = arith.addi %79, %7 : tensor<16xi32, #blocked> + %81 = tt.addptr %34, %80 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + %82 = ttg.convert_layout %81 : tensor<16x!tt.ptr, #blocked> -> tensor<16x!tt.ptr, #blocked> + %83 = ttg.convert_layout %30 : tensor<16xi1, #blocked> -> tensor<16xi1, #blocked> + %84 = ttg.convert_layout %cst_0 : tensor<16xf32, #blocked> -> tensor<16xf32, #blocked> + %85 = tt.load %82, %83, %84 : tensor<16x!tt.ptr, #blocked> + %86 = arith.cmpf "ogt", %arg20, %85 : tensor<16xf32, #blocked> + %87 = arith.select %86, %arg20, %85 : tensor<16xi1, #blocked>, tensor<16xf32, #blocked> + %88 = arith.subf %arg20, %87 : tensor<16xf32, #blocked> + %89 = math.exp %88 : tensor<16xf32, #blocked> + %90 = arith.subf %85, %87 : tensor<16xf32, #blocked> + %91 = math.exp %90 : tensor<16xf32, #blocked> + %92 = arith.mulf %89, %arg19 : tensor<16xf32, #blocked> + %93 = arith.mulf %91, %77 : tensor<16xf32, #blocked> + %94 = arith.addf %92, %93 : tensor<16xf32, #blocked> + %95 = arith.divf %91, %94 : tensor<16xf32, #blocked> + %96 = arith.divf %arg19, %94 : tensor<16xf32, #blocked> + %97 = arith.mulf %96, %89 : tensor<16xf32, #blocked> + %98 = ttg.convert_layout %97 : tensor<16xf32, #blocked> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %99 = tt.expand_dims %98 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xf32, #blocked2> + %100 = ttg.convert_layout %99 : tensor<16x1xf32, #blocked2> -> tensor<16x1xf32, #blocked1> + %101 = tt.broadcast %100 : tensor<16x1xf32, #blocked1> -> tensor<16x16xf32, #blocked1> + %102 = arith.mulf %arg18, %101 : tensor<16x16xf32, #blocked1> + %103 = ttg.convert_layout %95 : tensor<16xf32, #blocked> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %104 = tt.expand_dims %103 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xf32, #blocked2> + %105 = ttg.convert_layout %104 : tensor<16x1xf32, #blocked2> -> tensor<16x1xf32, #blocked1> + %106 = tt.broadcast %105 : tensor<16x1xf32, #blocked1> -> tensor<16x16xf32, #blocked1> + %107 = arith.extf %69 : tensor<16x16xf16, #blocked1> to tensor<16x16xf32, #blocked1> + %108 = arith.mulf %107, %106 : tensor<16x16xf32, #blocked1> + %109 = arith.addf %102, %108 : tensor<16x16xf32, #blocked1> + scf.yield %109, %94, %87 : tensor<16x16xf32, #blocked1>, tensor<16xf32, #blocked>, tensor<16xf32, #blocked> + } + %36 = arith.muli %2, %arg14 : i32 + %37 = arith.muli %3, %arg15 : i32 + %38 = arith.addi %36, %37 : i32 + %39 = ttg.convert_layout %7 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2> + %41 = ttg.convert_layout %40 : tensor<16x1xi32, #blocked2> -> tensor<16x1xi32, #blocked1> + %42 = tt.splat %arg16 : i32 -> tensor<16x1xi32, #blocked1> + %43 = arith.muli %41, %42 : tensor<16x1xi32, #blocked1> + %44 = ttg.convert_layout %4 : tensor<16xi32, #blocked> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %45 = tt.expand_dims %44 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x16xi32, #blocked3> + %46 = tt.broadcast %43 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> + %47 = tt.broadcast %45 : tensor<1x16xi32, #blocked3> -> tensor<16x16xi32, #blocked3> + %48 = ttg.convert_layout %47 : tensor<16x16xi32, #blocked3> -> tensor<16x16xi32, #blocked1> + %49 = arith.addi %46, %48 : tensor<16x16xi32, #blocked1> + %50 = tt.splat %38 : i32 -> tensor<16x16xi32, #blocked1> + %51 = arith.addi %50, %49 : tensor<16x16xi32, #blocked1> + %52 = tt.splat %arg13 : !tt.ptr -> tensor<16x16x!tt.ptr, #blocked1> + %53 = tt.addptr %52, %51 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> + %54 = arith.cmpi "slt", %41, %cst_3 : tensor<16x1xi32, #blocked1> + %55 = tt.broadcast %54 : tensor<16x1xi1, #blocked1> -> tensor<16x16xi1, #blocked1> + %56 = arith.truncf %35#0 : tensor<16x16xf32, #blocked1> to tensor<16x16xf16, #blocked1> + %57 = ttg.convert_layout %53 : tensor<16x16x!tt.ptr, #blocked1> -> tensor<16x16x!tt.ptr, #blocked4> + %58 = ttg.convert_layout %56 : tensor<16x16xf16, #blocked1> -> tensor<16x16xf16, #blocked4> + %59 = ttg.convert_layout %55 : tensor<16x16xi1, #blocked1> -> tensor<16x16xi1, #blocked4> + tt.store %57, %58, %59 : tensor<16x16x!tt.ptr, #blocked4> + tt.return + } +} + +// ----- + +// Check if MoveConvertOutOfLoop hangs because of adding additional conversions +// CHECK-LABEL: @loop_print +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.return +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { + tt.func public @loop_print(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %c31_i32 = arith.constant 31 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<32> : tensor<32x128xi32, #blocked> + %cst_0 = arith.constant dense<32> : tensor<128x32xi32, #blocked1> + %cst_1 = arith.constant 0.000000e+00 : f32 + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked2> + %1 = ttg.convert_layout %0 : tensor<128xi32, #blocked2> -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %3 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1> + %4 = arith.muli %2, %3 : tensor<128x1xi32, #blocked1> + %5 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked2> + %6 = ttg.convert_layout %5 : tensor<32xi32, #blocked2> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %8 = tt.broadcast %4 : tensor<128x1xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %9 = tt.broadcast %7 : tensor<1x32xi32, #blocked3> -> tensor<128x32xi32, #blocked3> + %10 = ttg.convert_layout %9 : tensor<128x32xi32, #blocked3> -> tensor<128x32xi32, #blocked1> + %11 = arith.addi %8, %10 : tensor<128x32xi32, #blocked1> + %12 = ttg.convert_layout %5 : tensor<32xi32, #blocked2> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %14 = ttg.convert_layout %13 : tensor<32x1xi32, #blocked1> -> tensor<32x1xi32, #blocked> + %15 = ttg.convert_layout %0 : tensor<128xi32, #blocked2> -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> + %17 = tt.broadcast %14 : tensor<32x1xi32, #blocked> -> tensor<32x128xi32, #blocked> + %18 = tt.broadcast %16 : tensor<1x128xi32, #blocked3> -> tensor<32x128xi32, #blocked3> + %19 = ttg.convert_layout %18 : tensor<32x128xi32, #blocked3> -> tensor<32x128xi32, #blocked> + %20 = arith.addi %17, %19 : tensor<32x128xi32, #blocked> + %21 = arith.addi %arg5, %c31_i32 : i32 + %22 = arith.divsi %21, %c32_i32 : i32 + %23 = tt.splat %arg0 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %24 = tt.splat %arg1 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %25:3 = scf.for %arg7 = %c0_i32 to %22 step %c1_i32 iter_args(%arg8 = %cst_1, %arg9 = %11, %arg10 = %20) -> (f32, tensor<128x32xi32, #blocked1>, tensor<32x128xi32, #blocked>) : i32 { + tt.print "a_offsets: " { hex = false, isSigned = array } : %arg9 : tensor<128x32xi32, #blocked1> + %27 = tt.addptr %23, %arg9 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %28 = ttg.convert_layout %27 : tensor<128x32x!tt.ptr, #blocked1> -> tensor<128x32x!tt.ptr, #blocked4> + %29 = tt.load %28 : tensor<128x32x!tt.ptr, #blocked4> + %30 = ttg.convert_layout %29 : tensor<128x32xf16, #blocked4> -> tensor<128x32xf16, #blocked1> + %31 = tt.addptr %24, %arg10 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %32 = ttg.convert_layout %31 : tensor<32x128x!tt.ptr, #blocked> -> tensor<32x128x!tt.ptr, #blocked5> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked5> + %34 = ttg.convert_layout %33 : tensor<32x128xf16, #blocked5> -> tensor<32x128xf16, #blocked> + %35 = "tt.reduce"(%30) <{axis = 0 : i32}> ({ + ^bb0(%arg11: f16, %arg12: f16): + %46 = arith.addf %arg11, %arg12 : f16 + tt.reduce.return %46 : f16 + }) : (tensor<128x32xf16, #blocked1>) -> tensor<32xf16, #ttg.slice<{dim = 0, parent = #blocked1}>> + %36 = ttg.convert_layout %35 : tensor<32xf16, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<32xf16, #blocked2> + %37 = "tt.reduce"(%36) <{axis = 0 : i32}> ({ + ^bb0(%arg11: f16, %arg12: f16): + %46 = arith.addf %arg11, %arg12 : f16 + tt.reduce.return %46 : f16 + }) : (tensor<32xf16, #blocked2>) -> f16 + %38 = "tt.reduce"(%34) <{axis = 0 : i32}> ({ + ^bb0(%arg11: f16, %arg12: f16): + %46 = arith.addf %arg11, %arg12 : f16 + tt.reduce.return %46 : f16 + }) : (tensor<32x128xf16, #blocked>) -> tensor<128xf16, #ttg.slice<{dim = 0, parent = #blocked}>> + %39 = ttg.convert_layout %38 : tensor<128xf16, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<128xf16, #blocked2> + %40 = "tt.reduce"(%39) <{axis = 0 : i32}> ({ + ^bb0(%arg11: f16, %arg12: f16): + %46 = arith.addf %arg11, %arg12 : f16 + tt.reduce.return %46 : f16 + }) : (tensor<128xf16, #blocked2>) -> f16 + %41 = arith.addf %37, %40 : f16 + %42 = arith.extf %41 : f16 to f32 + %43 = arith.addf %arg8, %42 : f32 + %44 = arith.addi %arg9, %cst_0 : tensor<128x32xi32, #blocked1> + %45 = arith.addi %arg10, %cst : tensor<32x128xi32, #blocked> + scf.yield %43, %44, %45 : f32, tensor<128x32xi32, #blocked1>, tensor<32x128xi32, #blocked> + } + %26 = arith.truncf %25#0 : f32 to f16 + tt.store %arg2, %26 : !tt.ptr + tt.return + } +} + +// ----- + +// Check if SimplifyReduceCvt handles the cvt,reduce->reduce,cvt conversion but not the general push forward conversion +// CHECK-LABEL: reduce_cvt3 +// CHECK: tt.dot +// CHECK-NEXT: tt.reduce +// CHECK: ttg.convert_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} { + tt.func public @reduce_cvt3(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + %cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<32xi32, #blocked1> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xi32, #blocked2> + %3 = ttg.convert_layout %2 : tensor<32x1xi32, #blocked2> -> tensor<32x1xi32, #blocked> + %4 = arith.muli %3, %cst_0 : tensor<32x1xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %7 = ttg.convert_layout %0 : tensor<32xi32, #blocked1> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %8 = tt.expand_dims %7 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %9 = tt.broadcast %6 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %10 = tt.broadcast %8 : tensor<1x32xi32, #blocked3> -> tensor<32x32xi32, #blocked3> + %11 = ttg.convert_layout %10 : tensor<32x32xi32, #blocked3> -> tensor<32x32xi32, #blocked> + %12 = tt.addptr %9, %11 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %13 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %14 = tt.addptr %13, %4 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %15 = tt.broadcast %14 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %16 = tt.addptr %15, %11 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %17 = ttg.convert_layout %12 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked4> + %18 = tt.load %17 : tensor<32x32x!tt.ptr, #blocked4> + %19 = ttg.convert_layout %18 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked> + %20 = ttg.convert_layout %16 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked4> + %21 = tt.load %20 : tensor<32x32x!tt.ptr, #blocked4> + %22 = ttg.convert_layout %21 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked> + %23 = ttg.local_alloc %22 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<32x32xf16, #shared, #smem> + %24 = ttg.memdesc_trans %23 {order=array} : !ttg.memdesc<32x32xf16, #shared, #smem> -> !ttg.memdesc<32x32xf16, #shared1, #smem> + %25 = ttg.local_load %24 : !ttg.memdesc<32x32xf16, #shared1, #smem> -> tensor<32x32xf16, #blocked> + %26 = ttg.convert_layout %19 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked5}>> + %27 = ttg.convert_layout %25 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked5}>> + %28 = ttg.convert_layout %cst : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked5> + %29 = tt.dot %26, %27, %28 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5> + %30 = ttg.convert_layout %29 : tensor<32x32xf32, #blocked5> -> tensor<32x32xf32, #blocked> + %31:2 = "tt.reduce"(%30, %11) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32): + %37 = arith.cmpf "oeq", %arg3, %arg5 : f32 + %38 = arith.cmpi "slt", %arg4, %arg6 : i32 + %39 = arith.andi %37, %38 : i1 + %40 = arith.cmpf "ogt", %arg3, %arg5 : f32 + %41 = arith.ori %40, %39 : i1 + %42 = arith.select %41, %arg3, %arg5 : f32 + %43 = arith.select %41, %arg4, %arg6 : i32 + tt.reduce.return %42, %43 : f32, i32 + }) : (tensor<32x32xf32, #blocked>, tensor<32x32xi32, #blocked>) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>) + %32 = ttg.convert_layout %31#1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #blocked1> + %33 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr, #blocked1> + %34 = tt.addptr %33, %0 : tensor<32x!tt.ptr, #blocked1>, tensor<32xi32, #blocked1> + %35 = ttg.convert_layout %34 : tensor<32x!tt.ptr, #blocked1> -> tensor<32x!tt.ptr, #blocked1> + %36 = ttg.convert_layout %32 : tensor<32xi32, #blocked1> -> tensor<32xi32, #blocked1> + tt.store %35, %36 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + + +// ----- + +// Check that we don't have extra convert for flash attention IR. +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3a = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [4, 1, 8], warpsPerCTA = [4, 1, 1], order = [1, 2, 0]}> +#blocked4a = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [1, 4, 8], warpsPerCTA = [1, 4, 1], order = [0, 2, 1]}> +#blocked6a = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked6 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked7 = #ttg.blocked<{sizePerThread = [8, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [1, 1, 4], order = [1, 0, 2]}> +#blocked8 = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 8, 4], warpsPerCTA = [1, 1, 4], order = [0, 1, 2]}> +#blocked9 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @attention_fw(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %c0_i64 = arith.constant 0 : i64 + %c64_i64 = arith.constant 64 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked> + %cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #blocked1> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128xf32, #blocked1> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked2> + %cst_3 = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = arith.muli %1, %arg10 : i32 + %4 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %5 = arith.muli %0, %c128_i32 : i32 + %6 = arith.extsi %arg8 : i32 to i64 + %7 = arith.extsi %5 : i32 to i64 + %8 = tt.addptr %arg1, %3 : !tt.ptr, i32 + %9 = arith.addi %arg20, %arg21 : i32 + %10 = arith.extsi %arg11 : i32 to i64 + %11 = tt.addptr %arg2, %3 : !tt.ptr, i32 + %12 = arith.extsi %arg14 : i32 to i64 + %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1> + %14 = tt.splat %5 : i32 -> tensor<128xi32, #blocked1> + %15 = arith.addi %14, %13 : tensor<128xi32, #blocked1> + %16 = arith.mulf %arg3, %cst_3 : f32 + %17 = tt.splat %4 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked3> + %18 = tt.splat %7 : i64 -> tensor<128xi64, #blocked3a> + %19 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3a> + %20 = arith.extsi %19 : tensor<128xi32, #blocked3a> to tensor<128xi64, #blocked3a> + %21 = arith.addi %18, %20 : tensor<128xi64, #blocked3a> + %22 = ttg.convert_layout %21 : tensor<128xi64, #blocked3a> -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> + %23 = tt.expand_dims %22 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> -> tensor<128x1xi64, #blocked4a> + %24 = tt.splat %6 : i64 -> tensor<128x1xi64, #blocked4a> + %25 = arith.muli %23, %24 : tensor<128x1xi64, #blocked4a> + %26 = tt.broadcast %25 : tensor<128x1xi64, #blocked4a> -> tensor<128x64xi64, #blocked4a> + %27 = ttg.convert_layout %26 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> + %28 = tt.addptr %17, %27 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + %29 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> + %30 = arith.extsi %29 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> + %31 = ttg.convert_layout %30 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> + %32 = tt.expand_dims %31 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> -> tensor<1x64xi64, #blocked4a> + %33 = tt.broadcast %32 : tensor<1x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked4a> + %34 = ttg.convert_layout %33 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> + %35 = tt.addptr %28, %34 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + %36 = tt.load %35 : tensor<128x64x!tt.ptr, #blocked3> + %37 = ttg.convert_layout %36 : tensor<128x64xf16, #blocked3> -> tensor<128x64xf16, #blocked2> + %38 = tt.splat %16 : f32 -> tensor<128x64xf32, #blocked2> + %39 = arith.extf %37 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> + %40 = arith.mulf %39, %38 : tensor<128x64xf32, #blocked2> + %41 = arith.truncf %40 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> +// CHECK-NOT: ttg.convert_layout +// CHECK: scf.for +// CHECK-NOT: ttg.convert_layout +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.dot +// CHECK-NOT: ttg.convert_layout +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK: ttg.convert_layout %{{.*}} #ttg.dot_op +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.dot +// CHECK: scf.yield + %42:5 = scf.for %arg22 = %c0_i32 to %9 step %c64_i32 iter_args(%arg23 = %cst_2, %arg24 = %cst_1, %arg25 = %cst_0, %arg26 = %c0_i64, %arg27 = %c0_i64) -> (tensor<128x64xf32, #blocked2>, tensor<128xf32, #blocked1>, tensor<128xf32, #blocked1>, i64, i64) : i32 { + %78 = tt.splat %8 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked6> + %79 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked6a> + %80 = arith.extsi %79 : tensor<64xi32, #blocked6a> to tensor<64xi64, #blocked6a> + %81 = ttg.convert_layout %80 : tensor<64xi64, #blocked6a> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked6}>> + %82 = tt.expand_dims %81 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked6}>> -> tensor<64x1xi64, #blocked6> + %83 = tt.broadcast %82 : tensor<64x1xi64, #blocked6> -> tensor<64x64xi64, #blocked6> + %84 = ttg.convert_layout %83 : tensor<64x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> + %85 = tt.addptr %78, %84 : tensor<64x64x!tt.ptr, #blocked6>, tensor<64x64xi64, #blocked6> + %86 = tt.splat %arg26 : i64 -> tensor<64xi64, #blocked6a> + %87 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked6a> + %88 = arith.extsi %87 : tensor<64xi32, #blocked6a> to tensor<64xi64, #blocked6a> + %89 = arith.addi %86, %88 : tensor<64xi64, #blocked6a> + %90 = ttg.convert_layout %89 : tensor<64xi64, #blocked6a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> + %91 = tt.expand_dims %90 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> -> tensor<1x64xi64, #blocked6> + %92 = tt.splat %10 : i64 -> tensor<1x64xi64, #blocked6> + %93 = arith.muli %91, %92 : tensor<1x64xi64, #blocked6> + %94 = tt.broadcast %93 : tensor<1x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> + %95 = ttg.convert_layout %94 : tensor<64x64xi64, #blocked6> -> tensor<64x64xi64, #blocked6> + %96 = tt.addptr %85, %95 : tensor<64x64x!tt.ptr, #blocked6>, tensor<64x64xi64, #blocked6> + %97 = tt.load %96 : tensor<64x64x!tt.ptr, #blocked6> + %98 = tt.splat %11 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked3> + %99 = tt.splat %arg27 : i64 -> tensor<64xi64, #blocked3a> + %100 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> + %101 = arith.extsi %100 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> + %102 = arith.addi %99, %101 : tensor<64xi64, #blocked3a> + %103 = ttg.convert_layout %102 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> + %104 = tt.expand_dims %103 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi64, #blocked3> + %105 = tt.splat %12 : i64 -> tensor<64x1xi64, #blocked3> + %106 = arith.muli %104, %105 : tensor<64x1xi64, #blocked3> + %107 = tt.broadcast %106 : tensor<64x1xi64, #blocked3> -> tensor<64x64xi64, #blocked3> + %108 = ttg.convert_layout %107 : tensor<64x64xi64, #blocked3> -> tensor<64x64xi64, #blocked3> + %109 = tt.addptr %98, %108 : tensor<64x64x!tt.ptr, #blocked3>, tensor<64x64xi64, #blocked3> + %110 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> + %111 = arith.extsi %110 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> + %112 = ttg.convert_layout %111 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> + %113 = tt.expand_dims %112 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked4a}>> -> tensor<1x64xi64, #blocked4a> + %114 = tt.broadcast %113 : tensor<1x64xi64, #blocked4a> -> tensor<64x64xi64, #blocked4a> + %115 = ttg.convert_layout %114 : tensor<64x64xi64, #blocked4a> -> tensor<64x64xi64, #blocked3> + %116 = tt.addptr %109, %115 : tensor<64x64x!tt.ptr, #blocked3>, tensor<64x64xi64, #blocked3> + %117 = tt.load %116 : tensor<64x64x!tt.ptr, #blocked3> + %118 = ttg.convert_layout %41 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %119 = ttg.convert_layout %97 : tensor<64x64xf16, #blocked6> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %120 = tt.dot %118, %119, %cst : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> + %121 = ttg.convert_layout %120 : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #blocked2> + %122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> + %123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({ + ^bb0(%arg28: f32, %arg29: f32): + %153 = arith.maximumf %arg28, %arg29 : f32 + tt.reduce.return %153 : f32 + }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %124 = ttg.convert_layout %123 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128xf32, #blocked1> + %125 = arith.maximumf %arg25, %124 : tensor<128xf32, #blocked1> + %126 = arith.subf %arg25, %125 : tensor<128xf32, #blocked1> + %127 = tt.extern_elementwise %126 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1> + %128 = ttg.convert_layout %125 : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> + %129 = tt.expand_dims %128 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> + %130 = ttg.convert_layout %129 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> + %131 = tt.broadcast %130 : tensor<128x1xf32, #blocked2> -> tensor<128x64xf32, #blocked2> + %132 = arith.subf %122, %131 : tensor<128x64xf32, #blocked2> + %133 = tt.extern_elementwise %132 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #blocked2> + %134 = arith.mulf %arg24, %cst_1 : tensor<128xf32, #blocked1> + %135 = arith.addf %134, %127 : tensor<128xf32, #blocked1> + %136 = ttg.convert_layout %135 : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> + %137 = tt.expand_dims %136 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> + %138 = ttg.convert_layout %137 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> + %139 = tt.broadcast %138 : tensor<128x1xf32, #blocked2> -> tensor<128x64xf32, #blocked2> + %140 = arith.mulf %arg23, %139 : tensor<128x64xf32, #blocked2> + %141 = arith.truncf %133 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> + %142 = ttg.convert_layout %141 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %143 = ttg.convert_layout %117 : tensor<64x64xf16, #blocked3> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %144 = ttg.convert_layout %140 : tensor<128x64xf32, #blocked2> -> tensor<128x64xf32, #blocked> + %145 = tt.dot %142, %143, %144 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> + %146 = ttg.convert_layout %145 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked2> + %147 = arith.mulf %arg24, %127 : tensor<128xf32, #blocked1> + %148 = "tt.reduce"(%133) <{axis = 1 : i32}> ({ + ^bb0(%arg28: f32, %arg29: f32): + %153 = arith.addf %arg28, %arg29 : f32 + tt.reduce.return %153 : f32 + }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %149 = ttg.convert_layout %148 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128xf32, #blocked1> + %150 = arith.addf %147, %149 : tensor<128xf32, #blocked1> + %151 = arith.addi %arg26, %c64_i64 : i64 + %152 = arith.addi %arg27, %c64_i64 : i64 + scf.yield %146, %150, %125, %151, %152 : tensor<128x64xf32, #blocked2>, tensor<128xf32, #blocked1>, tensor<128xf32, #blocked1>, i64, i64 + } + %43 = ttg.convert_layout %42#1 : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> + %44 = tt.expand_dims %43 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1xf32, #blocked9> + %45 = ttg.convert_layout %44 : tensor<128x1xf32, #blocked9> -> tensor<128x1xf32, #blocked2> + %46 = tt.broadcast %45 : tensor<128x1xf32, #blocked2> -> tensor<128x64xf32, #blocked2> + %47 = arith.divf %42#0, %46 : tensor<128x64xf32, #blocked2> + %48 = arith.muli %1, %arg20 : i32 + %49 = tt.addptr %arg4, %48 : !tt.ptr, i32 + %50 = tt.splat %49 : !tt.ptr -> tensor<128x!tt.ptr, #blocked1> + %51 = tt.addptr %50, %15 : tensor<128x!tt.ptr, #blocked1>, tensor<128xi32, #blocked1> + %52 = tt.extern_elementwise %42#1 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_log2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1> + %53 = arith.addf %42#2, %52 : tensor<128xf32, #blocked1> + tt.store %51, %53 : tensor<128x!tt.ptr, #blocked1> + %54 = tt.addptr %arg5, %2 : !tt.ptr, i32 + %55 = arith.extsi %arg17 : i32 to i64 + %56 = arith.extsi %5 : i32 to i64 + %57 = arith.truncf %47 : tensor<128x64xf32, #blocked2> to tensor<128x64xf16, #blocked2> + %58 = ttg.convert_layout %57 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #blocked3> + %59 = tt.splat %54 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked3> + %60 = tt.splat %56 : i64 -> tensor<128xi64, #blocked3a> + %61 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3a> + %62 = arith.extsi %61 : tensor<128xi32, #blocked3a> to tensor<128xi64, #blocked3a> + %63 = arith.addi %60, %62 : tensor<128xi64, #blocked3a> + %64 = ttg.convert_layout %63 : tensor<128xi64, #blocked3a> -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> + %65 = tt.expand_dims %64 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked4a}>> -> tensor<128x1xi64, #blocked4a> + %66 = tt.splat %55 : i64 -> tensor<128x1xi64, #blocked4a> + %67 = arith.muli %65, %66 : tensor<128x1xi64, #blocked4a> + %68 = tt.broadcast %67 : tensor<128x1xi64, #blocked4a> -> tensor<128x64xi64, #blocked4a> + %69 = ttg.convert_layout %68 : tensor<128x64xi64, #blocked4a> -> tensor<128x64xi64, #blocked3> + %70 = tt.addptr %59, %69 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + %71 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked3a> + %72 = arith.extsi %71 : tensor<64xi32, #blocked3a> to tensor<64xi64, #blocked3a> + %73 = ttg.convert_layout %72 : tensor<64xi64, #blocked3a> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> + %74 = tt.expand_dims %73 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked6}>> -> tensor<1x64xi64, #blocked6> + %75 = tt.broadcast %74 : tensor<1x64xi64, #blocked6> -> tensor<128x64xi64, #blocked6> + %76 = ttg.convert_layout %75 : tensor<128x64xi64, #blocked6> -> tensor<128x64xi64, #blocked3> + %77 = tt.addptr %70, %76 : tensor<128x64x!tt.ptr, #blocked3>, tensor<128x64xi64, #blocked3> + tt.store %77, %58 : tensor<128x64x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK-LABEL: axis_mismatch +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func @axis_mismatch(%arg0: f32) -> tensor<1xf32, #ttg.slice<{dim = 0, parent = #blocked}>> { +// CHECK: %[[R:.+]] = "tt.reduce"(%0) <{axis = 1 : i32}> +// CHECK: %[[C:.+]] = ttg.convert_layout %[[R]] +// CHECK: tt.return %[[C]] + %0 = tt.splat %arg0 : f32 -> tensor<1x16xf32, #blocked> + %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ + ^bb0(%arg9: f32, %arg10: f32): + %60 = arith.addf %arg9, %arg10 : f32 + tt.reduce.return %60 : f32 + }) : (tensor<1x16xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = ttg.convert_layout %1 : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %3 = ttg.convert_layout %2 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 0, parent = #blocked}>> + tt.return %3: tensor<1xf32, #ttg.slice<{dim = 0, parent = #blocked}>> +} +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: reduce_to_scalar +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.return +tt.func @reduce_to_scalar(%ptr: tensor<1024x!tt.ptr, #blocked>) -> (f32, i32) { + %0 = tt.load %ptr : tensor<1024x!tt.ptr, #blocked> + %1 = ttg.convert_layout %0 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked1> + %3:2 = "tt.reduce"(%1, %2) <{axis = 0 : i32}> ({ + ^bb0(%arg7: f32, %arg8: i32, %arg9: f32, %arg10: i32): + %51 = arith.cmpf "oeq", %arg7, %arg9 : f32 + %52 = arith.cmpi "slt", %arg8, %arg10 : i32 + %53 = arith.andi %51, %52 : i1 + %54 = arith.cmpf "ogt", %arg7, %arg9 : f32 + %55 = arith.ori %54, %53 : i1 + %56 = arith.select %55, %arg7, %arg9 : f32 + %57 = arith.select %55, %arg8, %arg10 : i32 + tt.reduce.return %56, %57 : f32, i32 + }) : (tensor<1024xf32, #blocked1>, tensor<1024xi32, #blocked1>) -> (f32, i32) + tt.return %3#0, %3#1: f32, i32 +} +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: whileop +// CHECK: %[[L:.+]] = tt.load %{{.*}} : tensor<1024x!tt.ptr, #blocked> +// CHECK: %[[W:.+]] = scf.while (%[[I:.+]] = %[[L]], %{{.*}} = %{{.*}}) : (tensor<1024xf32, #blocked>, i1) -> tensor<1024xf32, #blocked> { +// CHECK: scf.condition(%{{.*}}) %[[I]] : tensor<1024xf32, #blocked> +// CHECK: } do { +// CHECK: ^bb0(%[[ARG1:.+]]: tensor<1024xf32, #blocked>): +// CHECK: %[[ADD:.+]] = arith.addf %[[ARG1]], %[[ARG1]] : tensor<1024xf32, #blocked> +// CHECK: scf.yield %[[ADD]], %{{.*}} : tensor<1024xf32, #blocked>, i1 +// CHECK: } +// CHECK: tt.store %{{.*}}, %[[W]] : tensor<1024x!tt.ptr, #blocked> +tt.func @whileop(%ptr: tensor<1024x!tt.ptr, #blocked>, %cond: i1) { + %0 = tt.load %ptr : tensor<1024x!tt.ptr, #blocked> + %1 = ttg.convert_layout %0 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> + %2 = scf.while (%arg0 = %1, %arg1 = %cond) : (tensor<1024xf32, #blocked1>, i1) -> (tensor<1024xf32, #blocked1>) { + scf.condition(%arg1) %arg0 : tensor<1024xf32, #blocked1> + } do { + ^bb0(%arg0: tensor<1024xf32, #blocked1>): + %4 = ttg.convert_layout %arg0 : tensor<1024xf32, #blocked1> -> tensor<1024xf32, #blocked> + %5 = arith.addf %4, %4 : tensor<1024xf32, #blocked> + %6 = ttg.convert_layout %5 : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked1> + scf.yield %6, %cond : tensor<1024xf32, #blocked1>, i1 + } + %3 = ttg.convert_layout %2 : tensor<1024xf32, #blocked1> -> tensor<1024xf32, #blocked> + tt.store %ptr, %3 : tensor<1024x!tt.ptr, #blocked> + tt.return +} +} + +// ----- + +// Suppose we have a loop which yields a value from outside the loop: +// %x = ... +// %y = ... +// %z = for iter_args(%unused = %x) { +// yield %y +// } +// return %z +// +// This loop returns %y if it runs 1 or more times; otherwise, it returns %x. +// +// Check that we don't transform this loop into `yield %x` on the incorrect +// theory that the yield is dead unless %x = %y. + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { + +// CHECK-LABEL @yield_outside_loop1 +tt.func public @yield_outside_loop1(%arg0: i32, %arg1: i32) -> (i32) { + %c0 = arith.constant 0 : index + %c5 = arith.constant 5 : index + %c1 = arith.constant 1 : index + %0 = scf.for %i = %c0 to %c5 step %c1 iter_args(%arg3 = %arg0) -> (i32) { + scf.yield %arg1 : i32 + } + + // We should return %arg1, not %arg0. (It would also be OK to return %0, if + // the loop didn't get eliminated.) + // + // CHECK: tt.return %arg1 + tt.return %0 : i32 +} // end function + +// CHECK-LABEL @yield_outside_loop2 +tt.func public @yield_outside_loop2(%arg0: i32, %arg1: i32) -> (i32, i32) { + %c0 = arith.constant 0 : index + %c5 = arith.constant 5 : index + %c1 = arith.constant 1 : index + %i0 = arith.constant 0 : i32 + // Only yield a single value + // CHECK: scf.yield %{{.*}} : i32 + %0, %1 = scf.for %i = %c0 to %c5 step %c1 iter_args(%arg3 = %arg0, %sum = %i0) -> (i32, i32) { + %sum1 = arith.addi %sum, %arg3 : i32 + scf.yield %arg0, %sum1 : i32, i32 + } + + tt.return %0, %1 : i32, i32 +} // end function + +} // end module + +// ----- + +// Check that we handle corner cases when hoisting conversions on top of extf because conversion operations on a smaller type are faster. +// For complex slices we may hoist convert on top of extf while the source of extf has multiple uses in the slice. +// In this case we want to make sure we don't replace other uses of extf source. +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK: [[$BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +// CHECK: [[$MMA:#.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> + +// CHECK-LABEL: @hoist_convert_above_extf_and_remat + tt.func public @hoist_convert_above_extf_and_remat(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<256> : tensor<32x1xi32, #blocked> + %cst_0 = arith.constant dense<256> : tensor<32x1xi32, #blocked1> + %cst_1 = arith.constant dense<256> : tensor<256x1xi32, #blocked> + %c64_i32 = arith.constant 64 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<1.000000e-03> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %cst_3 = arith.constant dense<2.560000e+02> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x256xf32, #blocked3> + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %4 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked> + %5 = arith.addi %4, %3 : tensor<32x1xi32, #blocked> + %6 = arith.muli %5, %cst : tensor<32x1xi32, #blocked> + %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %9 = tt.expand_dims %7 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %10 = tt.expand_dims %8 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %11 = tt.broadcast %9 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %14 = arith.muli %13, %cst_1 : tensor<256x1xi32, #blocked> + %15 = tt.broadcast %10 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> + %16 = tt.splat %arg0 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked> + %17 = tt.splat %arg1 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> + %18 = scf.for %arg7 = %c0_i32 to %c256_i32 step %c64_i32 iter_args(%arg8 = %cst_4) -> (tensor<32x256xf32, #blocked3>) : i32 { + %58 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #blocked> + %59 = arith.addi %6, %58 : tensor<32x1xi32, #blocked> + %60 = tt.broadcast %59 : tensor<32x1xi32, #blocked> -> tensor<32x64xi32, #blocked> + %61 = arith.addi %60, %11 : tensor<32x64xi32, #blocked> + %62 = tt.splat %arg7 : i32 -> tensor<256x1xi32, #blocked> + %63 = arith.addi %14, %62 : tensor<256x1xi32, #blocked> + %64 = tt.broadcast %63 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked> + %65 = arith.addi %64, %15 : tensor<256x64xi32, #blocked> + %66 = tt.addptr %16, %61 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + %67 = tt.load %66 : tensor<32x64x!tt.ptr, #blocked> + %68 = tt.addptr %17, %65 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> + %69 = tt.load %68 : tensor<256x64x!tt.ptr, #blocked> + %70 = ttg.local_alloc %69 : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #shared, #smem> + %71 = ttg.memdesc_trans %70 {order=array} : !ttg.memdesc<256x64xf16, #shared, #smem> -> !ttg.memdesc<64x256xf16, #shared1, #smem> + %72 = ttg.convert_layout %67 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> + %73 = ttg.local_load %71 : !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> + %74 = ttg.convert_layout %arg8 : tensor<32x256xf32, #blocked3> -> tensor<32x256xf32, #mma> + %75 = ttg.convert_layout %72 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %76 = ttg.convert_layout %73 : tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %77 = tt.dot %75, %76, %74 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x256xf32, #mma> + %78 = ttg.convert_layout %77 : tensor<32x256xf32, #mma> -> tensor<32x256xf32, #blocked3> + scf.yield %78 : tensor<32x256xf32, #blocked3> + } + %19 = arith.truncf %18 : tensor<32x256xf32, #blocked3> to tensor<32x256xf16, #blocked3> + %20 = ttg.convert_layout %19 : tensor<32x256xf16, #blocked3> -> tensor<32x256xf16, #blocked2> + %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %23 = tt.expand_dims %21 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %24 = tt.expand_dims %22 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %25 = tt.splat %arg2 : !tt.ptr -> tensor<1x256x!tt.ptr, #blocked2> + %26 = tt.addptr %25, %23 : tensor<1x256x!tt.ptr, #blocked2>, tensor<1x256xi32, #blocked2> + %27 = tt.load %26 : tensor<1x256x!tt.ptr, #blocked2> + %28 = tt.broadcast %27 : tensor<1x256xf16, #blocked2> -> tensor<32x256xf16, #blocked2> + %29 = arith.addf %20, %28 : tensor<32x256xf16, #blocked2> +// CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<1x256xf16, [[$BLOCKED]]> -> tensor<1x256xf16, [[$MMA]]> +// CHECK: %[[B:.+]] = tt.broadcast %[[A]] +// CHECK: %[[C:.+]] = arith.addf %[[B:.+]], {{.*}} +// CHECK: arith.extf %[[C]] : tensor<32x256xf16, [[$MMA]]> to tensor<32x256xf32, [[$MMA]]> + %30 = arith.extf %29 : tensor<32x256xf16, #blocked2> to tensor<32x256xf32, #blocked2> + %31 = "tt.reduce"(%30) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: f32): + %58 = arith.addf %arg7, %arg8 : f32 + tt.reduce.return %58 : f32 + }) : (tensor<32x256xf32, #blocked2>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %32 = arith.divf %31, %cst_3 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %33 = arith.mulf %30, %30 : tensor<32x256xf32, #blocked2> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: f32): + %58 = arith.addf %arg7, %arg8 : f32 + tt.reduce.return %58 : f32 + }) : (tensor<32x256xf32, #blocked2>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %35 = arith.divf %34, %cst_3 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %36 = arith.mulf %32, %32 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %37 = arith.subf %35, %36 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %38 = math.sqrt %37 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %39 = arith.addf %38, %cst_2 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %40 = tt.expand_dims %32 {axis = 1 : i32} : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xf32, #blocked2> + %41 = tt.expand_dims %39 {axis = 1 : i32} : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<32x1xf32, #blocked2> + %42 = tt.broadcast %40 : tensor<32x1xf32, #blocked2> -> tensor<32x256xf32, #blocked2> + %43 = arith.subf %30, %42 : tensor<32x256xf32, #blocked2> + %44 = tt.broadcast %41 : tensor<32x1xf32, #blocked2> -> tensor<32x256xf32, #blocked2> + %45 = arith.divf %43, %44 : tensor<32x256xf32, #blocked2> + %46 = arith.truncf %45 : tensor<32x256xf32, #blocked2> to tensor<32x256xf16, #blocked2> + %47 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %48 = tt.expand_dims %47 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %49 = arith.muli %48, %cst_0 : tensor<32x1xi32, #blocked1> + %50 = tt.splat %1 : i32 -> tensor<32x1xi32, #blocked1> + %51 = arith.addi %50, %49 : tensor<32x1xi32, #blocked1> + %52 = tt.broadcast %51 : tensor<32x1xi32, #blocked1> -> tensor<32x256xi32, #blocked1> + %53 = tt.broadcast %24 : tensor<1x256xi32, #blocked1> -> tensor<32x256xi32, #blocked1> + %54 = arith.addi %52, %53 : tensor<32x256xi32, #blocked1> + %55 = tt.splat %arg5 : !tt.ptr -> tensor<32x256x!tt.ptr, #blocked1> + %56 = tt.addptr %55, %54 : tensor<32x256x!tt.ptr, #blocked1>, tensor<32x256xi32, #blocked1> + %57 = ttg.convert_layout %46 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #blocked1> + tt.store %56, %57 : tensor<32x256x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: @backward_reduce_multiple_results +// CHECK-NOT: ttg.convert_layout +// CHECK: tt.return + tt.func public @backward_reduce_multiple_results() -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> { + %cst = arith.constant dense<0xFFF0000000000000> : tensor<1x32xf64, #blocked1> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2> + %2 = ttg.convert_layout %1 : tensor<1x32xi32, #blocked2> -> tensor<1x32xi32, #blocked1> + %3:2 = "tt.reduce"(%cst, %2) <{axis = 1 : i32}> ({ + ^bb0(%arg0: f64, %arg1: i32, %arg2: f64, %arg3: i32): + %5 = arith.addi %arg1, %arg3 : i32 + %6 = arith.addf %arg0, %arg2 : f64 + tt.reduce.return %6, %5 : f64, i32 + }) : (tensor<1x32xf64, #blocked1>, tensor<1x32xi32, #blocked1>) -> (tensor<1xf64, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>) + %4 = ttg.convert_layout %3#1 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + tt.return %4 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> +} +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @reshape_propagate + tt.func public @reshape_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf32, #blocked3> { + // CHECK-NOT: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> + %c = ttg.convert_layout %b : tensor<32xf32, #blocked2> -> tensor<32xf32, #blocked3> + tt.return %c : tensor<32xf32, #blocked3> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @reshape_sink_convert + tt.func public @reshape_sink_convert(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf32, #blocked2> { + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.reshape + // CHECK: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> + tt.return %b : tensor<32xf32, #blocked2> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @permuting_reshape_propagate + tt.func public @permuting_reshape_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf16, #blocked2> { + // CHECK-NOT: ttg.convert_layout + // CHECK: arith.truncf + // CHECK: ttg.convert_layout + %a = tt.reshape %arg0 allow_reorder efficient_layout : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1> + %b = ttg.convert_layout %a : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked2> + %c = arith.truncf %b : tensor<32xf32, #blocked2> to tensor<32xf16, #blocked2> + tt.return %c : tensor<32xf16, #blocked2> + } +} + +// ----- + +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#slice1dim1 = #ttg.slice<{dim = 1, parent = #blocked1}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: scan_propagation +tt.func @scan_propagation(%arg: tensor<1024xi32, #slice1dim1>) -> tensor<1024xi32, #slice1dim1> { + %1 = ttg.convert_layout %arg : tensor<1024xi32, #slice1dim1> -> tensor<1024xi32, #blocked2> + %2 = "tt.scan" (%1) ({ + ^bb0(%arg3: i32, %arg4: i32): + %add = arith.addi %arg3, %arg4 : i32 + tt.scan.return %add : i32 + }) {axis = 1 : i32, reverse = false} : (tensor<1024xi32, #blocked2>) -> tensor<1024xi32, #blocked2> + %3 = ttg.convert_layout %2 : tensor<1024xi32, #blocked2> -> tensor<1024xi32, #slice1dim1> + // don't allow non blocked layout to be propagated to scan + // CHECK: ttg.convert_layout + // CHECK: tt.scan + // CHECK: ttg.convert_layout + // CHECK: tt.return + tt.return %3: tensor<1024xi32, #slice1dim1> +} +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: fw_propagate_for_op + tt.func public @fw_propagate_for_op(%arg0: tensor<1024x4xi32, #blocked>, %arg1: tensor<1024x4x!tt.ptr, #blocked1>) { + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + + // CHECK-NOT: ttg.convert_layout + // CHECK: arith.muli + // CHECK: scf.for + // CHECK: scf.yield + // CHECK: ttg.convert_layout + // CHECK: tt.store + %0 = ttg.convert_layout %arg0 : tensor<1024x4xi32, #blocked> -> tensor<1024x4xi32, #blocked1> + %1 = arith.muli %0, %0 : tensor<1024x4xi32, #blocked1> + %2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %1) -> (tensor<1024x4xi32, #blocked1>) : i32 { + %3 = arith.addi %arg3, %arg3 : tensor<1024x4xi32, #blocked1> + scf.yield %3 : tensor<1024x4xi32, #blocked1> + } + tt.store %arg1, %2 : tensor<1024x4x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: @rematerialize_through_if + tt.func public @rematerialize_through_if(%arg0: i1, %arg1: f32) -> tensor<32xf32, #blocked> { + // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> + // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> + // CHECK: scf.if %arg0 -> (tensor<32xf32, #blocked>) { + // CHECK-NOT: ttg.convert_layout + %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1> + %cst_0 = arith.constant dense<2.000000e+00> : tensor<32xf32, #blocked1> + %0 = tt.splat %arg1 : f32 -> tensor<32xf32, #blocked1> + %3 = scf.if %arg0 -> (tensor<32xf32, #blocked1>) { + %1 = arith.addf %cst, %0 : tensor<32xf32, #blocked1> + scf.yield %1 : tensor<32xf32, #blocked1> + } else { + %2 = arith.addf %cst_0, %0 : tensor<32xf32, #blocked1> + scf.yield %2 : tensor<32xf32, #blocked1> + } + %4 = ttg.convert_layout %3 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + tt.return %4 : tensor<32xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: @rematerialize_if_inside_loop + tt.func public @rematerialize_if_inside_loop() -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) { + // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> + // CHECK: arith.constant {{.*}} : tensor<32xf32, #blocked> + // CHECK-NOT: ttg.convert_layout + // CHECK: %[[for:[0-9]*]]:2 = scf.for {{.*}} -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) + + // CHECK-NOT: ttg.convert_layout + // CHECK: scf.if %{{.*}} -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) + // CHECK-NOT: ttg.convert_layout + // CHECK: scf.yield {{.*}} : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> + // CHECK: scf.yield {{.*}} : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.return %[[for]]#1, %[[for]]#0 + %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1> + %cst_0 = arith.constant dense<2.000000e+00> : tensor<32xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c4096_i32 = arith.constant 4096 : i32 + %1:2 = scf.for %arg0 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg1 = %cst, %arg3 = %cst_0) -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>) : i32 { + %2 = arith.cmpi eq, %arg0, %c0_i32 : i32 + %3:2 = scf.if %2 -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>) { + scf.yield %cst, %cst_0 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> + } else { + %4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %6 = arith.mulf %arg3, %5 : tensor<32xf32, #blocked> + scf.yield %4, %6 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> + } + scf.yield %3#0, %3#1 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> + } + %7 = ttg.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + tt.return %7, %1#1 : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: rematerialize_loop_arg + tt.func public @rematerialize_loop_arg(%arg0: !tt.ptr) { + // CHECK-NOT: ttg.convert_layout + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c128_i32 = arith.constant 128 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_1 = arith.constant dense<64> : tensor<128x64xi32, #blocked> + %cst_2 = arith.constant dense<128> : tensor<128x64xi32, #blocked> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> + // CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %0) -> (tensor<128x64x!tt.ptr, #blocked>) + // CHECK-NOT: ttg.convert_layout + // CHECK: scf.yield %{{.*}} : tensor<128x64x!tt.ptr, #blocked> + %1 = scf.for %arg1 = %c0_i32 to %c128_i32 step %c1_i32 iter_args(%arg2 = %0) -> (tensor<128x64x!tt.ptr, #blocked>) : i32 { + %2 = tt.addptr %arg2, %cst_1 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %3 = ttg.convert_layout %2 : tensor<128x64x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked1> + tt.store %3, %cst_0 : tensor<128x64x!tt.ptr, #blocked1> + %4 = tt.addptr %arg2, %cst_2 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %5 = ttg.convert_layout %4 : tensor<128x64x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked1> + tt.store %5, %cst_0 : tensor<128x64x!tt.ptr, #blocked1> + scf.yield %2 : tensor<128x64x!tt.ptr, #blocked> + } + tt.return + } +} + + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: assertop +// CHECK: %[[L:.+]] = tt.load %{{.*}} : tensor<1024x!tt.ptr, #blocked> +// CHECK: tt.assert %[[L]] + +tt.func @assertop(%ptr: tensor<1024x!tt.ptr, #blocked>) { + %0 = tt.load %ptr : tensor<1024x!tt.ptr, #blocked> + %1 = ttg.convert_layout %0 : tensor<1024xi1, #blocked> -> tensor<1024xi1, #blocked1> + tt.assert %1, "cond must be true " : tensor<1024xi1, #blocked1> + tt.return +} +} + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @warp_group_dot_wait_propagate + tt.func public @warp_group_dot_wait_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<16x2xf32, #blocked> { + // CHECK-NOT: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + %b = ttng.warp_group_dot_wait %a {pendings = 0 : i32} : tensor<16x2xf32, #blocked1> + %c = ttg.convert_layout %b : tensor<16x2xf32, #blocked1> -> tensor<16x2xf32, #blocked> + tt.return %c : tensor<16x2xf32, #blocked> + } +} + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [1,0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2,4], threadsPerWarp = [16,2], warpsPerCTA = [1,1], order = [1,0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4,2], threadsPerWarp = [2,16], warpsPerCTA = [1,1], order = [0,1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @trans_propagate + tt.func public @trans_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<2x16xf32, #blocked2> { + // CHECK: tt.trans + // CHECK: ttg.convert_layout + %a = ttg.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + %b = tt.trans %a {order=array} : tensor<16x2xf32, #blocked1> -> tensor<2x16xf32, #blocked2> + tt.return %b : tensor<2x16xf32, #blocked2> + } +} + + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // Verify that we don't hoist the convert on top of the broadcast. In general we should hoist the convert to reduce its cost + // but because this would combine the 1st and 2nd convert and since the 1st convert is known to be a no-op this would + // generate more expensive code. + // CHECK-LABEL: @hoist_with_free_convert + tt.func public @hoist_with_free_convert(%arg0: tensor<128x256xf32, #mma1>, %arg1: tensor<128x1xf32, #mma>) -> tensor<128x256xf32, #blocked> { + // CHECK: ttg.convert_layout + // CHECK: tt.broadcast + // CHECK: ttg.convert_layout + // CHECK: tt.return + %0 = ttg.convert_layout %arg0 : tensor<128x256xf32, #mma1> -> tensor<128x256xf32, #mma> + %1 = tt.broadcast %arg1 : tensor<128x1xf32, #mma> -> tensor<128x256xf32, #mma> + %2 = arith.addf %0, %1 : tensor<128x256xf32, #mma> + %3 = ttg.convert_layout %2 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> + tt.return %3 : tensor<128x256xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @rematerialize_loop_arg + tt.func public @rematerialize_loop_arg() -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1>) { + %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1> + %cst_0 = arith.constant dense<2.000000e+00> : tensor<32xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c4096_i32 = arith.constant 4096 : i32 + // CHECK: %[[F:.+]]:3 = scf.for + // CHECK: %[[R:.+]] = arith.addf + // CHECK: arith.addf + // CHECK: scf.yield %{{.+}}, %{{.+}}, %[[R]] + // CHECK: } + // CHECK: tt.return %[[F]]#2, %[[F]]#1, %[[F]]#0 + %1:3 = scf.for %arg0 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg1 = %cst, %arg3 = %cst_0, %arg4 = %cst) -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1>) : i32 { + %4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %6 = arith.mulf %arg3, %5 : tensor<32xf32, #blocked> + scf.yield %4, %6, %4 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1> + } + %7 = ttg.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + tt.return %7, %1#1, %1#2 : tensor<32xf32, #blocked>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1> + + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // Regression test: + // Rematerialization of multiple loop-carried variables, where one is + // rematerialized to the same layout by multiple users. + // Previously this didn't interact correctly with the de-duplication mechanism. + // CHECK-LABEL: @multi_rematerialize_loop_arg + tt.func public @multi_rematerialize_loop_arg(%arg0: !tt.ptr, %arg1: !tt.ptr) -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) { + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c2048_i32 = arith.constant 2048 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %cst_0 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %cst_1 = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked1> + %1 = tt.load %0 : tensor<128x64x!tt.ptr, #blocked1> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> + // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) + // FIXME: The optimal number of conversions should be 4. + // CHECK-COUNT-5: convert_layout + // CHECK-NOT: convert_layout + // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + // CHECK: } + // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { + %6 = tt.load %2 : tensor<64x64x!tt.ptr, #blocked2> + %7 = ttg.convert_layout %1 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %8 = ttg.convert_layout %6 : tensor<64x64xf16, #blocked2> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %9 = tt.dot %7, %8, %cst_2, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %10 = tt.load %3 : tensor<128x64x!tt.ptr, #blocked> + %11 = tt.load %4 : tensor<128x64x!tt.ptr, #blocked> + %12 = arith.cmpi eq, %10, %11 : tensor<128x64xi8, #blocked> + %13 = ttg.convert_layout %12 : tensor<128x64xi1, #blocked> -> tensor<128x64xi1, #mma> + %14 = arith.select %13, %9, %cst_1 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma> + %15 = ttg.convert_layout %14 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> + %16 = "tt.reduce"(%15) <{axis = 1 : i32}> ({ + ^bb0(%arg6: f32, %arg7: f32): + %34 = arith.maxnumf %arg6, %arg7 : f32 + tt.reduce.return %34 : f32 + }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = arith.maxnumf %arg5, %16 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %18 = arith.cmpf oeq, %17, %cst_0 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %19 = ttg.convert_layout %18 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>> + %20 = arith.select %18, %cst, %17 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %21 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi1, #mma> + %22 = tt.broadcast %21 : tensor<128x1xi1, #mma> -> tensor<128x64xi1, #mma> + %23 = arith.select %22, %cst_2, %14 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma> + %24 = ttg.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> + %25 = arith.mulf %arg4, %cst : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %26 = ttg.convert_layout %25 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %27 = tt.expand_dims %26 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %28 = tt.broadcast %27 : tensor<128x1xf32, #mma> -> tensor<128x64xf32, #mma> + %29 = arith.mulf %arg3, %28 : tensor<128x64xf32, #mma> + %30 = ttg.convert_layout %23 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %31 = arith.mulf %arg4, %20 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %32 = "tt.reduce"(%24) <{axis = 1 : i32}> ({ + ^bb0(%arg6: f32, %arg7: f32): + %34 = arith.addf %arg6, %arg7 : f32 + tt.reduce.return %34 : f32 + }) : (tensor<128x64xf32, #blocked>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %33 = arith.addf %31, %32 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %29, %33, %17 : tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } + tt.return %5#1, %5#2 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } +} + +// ----- + +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked7 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // Regression test: + // The while loop use the result of the for loop as an argument. + // When propagating the layout, we should only "forward" propagate the layout to the argument and the result of the while loop + // CHECK-LABEL: @while_use_for + tt.func public @while_use_for(%arg0: !tt.ptr, %arg3: !tt.ptr, %arg6: i32) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %c0_i1 = arith.constant 1 : i1 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #blocked1> + %1000 = tt.splat %arg0 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked2> + %1001 = tt.splat %arg0 : !tt.ptr -> tensor<64x128x!tt.ptr, #blocked1> + %1002 = tt.splat %arg0 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %1003 = tt.splat %arg3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %74 = tt.load %1000 : tensor<256x64x!tt.ptr, #blocked2> + %67:2 = scf.for %arg11 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg12 = %cst_0, %arg14 = %1001) -> (tensor<256x128xf32, #blocked1>, tensor<64x128x!tt.ptr, #blocked1>) : i32 { + %76 = tt.load %arg14 : tensor<64x128x!tt.ptr, #blocked1> + %78 = ttg.convert_layout %74 : tensor<256x64xf16, #blocked2> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked7}>> + %79 = ttg.convert_layout %76 : tensor<64x128xf16, #blocked1> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked7}>> + %80 = ttg.convert_layout %arg12 : tensor<256x128xf32, #blocked1> -> tensor<256x128xf32, #blocked7> + %81 = tt.dot %78, %79, %80, inputPrecision = tf32 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked7}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked7}>> -> tensor<256x128xf32, #blocked7> + %82 = ttg.convert_layout %81 : tensor<256x128xf32, #blocked7> -> tensor<256x128xf32, #blocked1> + scf.yield %82, %arg14 : tensor<256x128xf32, #blocked1>, tensor<64x128x!tt.ptr, #blocked1> + } + %68:2 = scf.while (%arg11 = %67#0, %arg12 = %c1_i32) : (tensor<256x128xf32, #blocked1>, i32) -> (tensor<256x128xf32, #blocked1>, i32) { + scf.condition(%c0_i1) %arg11, %arg12 : tensor<256x128xf32, #blocked1>, i32 + } do { + ^bb0(%arg11: tensor<256x128xf32, #blocked1>, %arg12: i32): + %80 = ttg.convert_layout %1003 : tensor<256x128x!tt.ptr, #blocked1> -> tensor<256x128x!tt.ptr, #blocked1> + %81 = tt.load %80 : tensor<256x128x!tt.ptr, #blocked1> + %82 = arith.addf %arg11, %81 : tensor<256x128xf32, #blocked1> + %83 = arith.addi %arg12, %c1_i32 : i32 + scf.yield %82, %83 : tensor<256x128xf32, #blocked1>, i32 + } + %69 = arith.truncf %68#0 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %71 = ttg.convert_layout %69 : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #blocked1> + tt.store %1002, %71 : tensor<256x128x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- +// Minimized reproducer for https://github.com/pytorch/pytorch/issues/130101 +// Check that backward rematerialization bails out when the same tensor requires two different layouts + +// CHECK-LABEL: double_remat +// CHECK: %[[res:.*]] = ttg.convert_layout +// CHECK-NEXT: tt.return %[[res]] +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 2], order = [2, 1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:86", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @double_remat() -> tensor<1x256xi32, #blocked> attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<1x256xi32, #blocked1> + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> + %2 = tt.expand_dims %1 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> + %3 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<1x2x128xi32, #blocked2> + %4 = tt.reshape %3 : tensor<1x2x128xi32, #blocked2> -> tensor<1x256xi32, #blocked1> + %5 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<2x2x64xi32, #blocked2> + %6 = tt.reshape %5 : tensor<2x2x64xi32, #blocked2> -> tensor<1x256xi32, #blocked1> + %7 = arith.cmpi ne, %4, %cst : tensor<1x256xi32, #blocked1> + %8 = arith.select %7, %6, %cst : tensor<1x256xi1, #blocked1>, tensor<1x256xi32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<1x256xi32, #blocked1> -> tensor<1x256xi32, #blocked> + tt.return %9 : tensor<1x256xi32, #blocked> + } +} + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @if_condition_not_dead_inside_loop + // CHECK: scf.if + // CHECK-NOT: convert_layout + tt.func public @if_condition_not_dead_inside_loop(%arg0: i32) -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) { + %true = arith.constant true + %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1> + %cst_0 = arith.constant dense<2.000000e+00> : tensor<32xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c4096_i32 = arith.constant 4096 : i32 + %1:3 = scf.for %arg10 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg1 = %cst, %arg3 = %cst_0, %arg4 = %true) -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, i1) : i32 { + %3:2 = scf.if %arg4 -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>) { + scf.yield %cst, %cst_0 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> + } else { + %4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %6 = arith.mulf %arg3, %5 : tensor<32xf32, #blocked> + scf.yield %4, %6 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> + } + %119 = arith.cmpi eq, %arg10, %arg0 : i32 + scf.yield %3#0, %3#1, %119 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, i1 + } + %7 = ttg.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + tt.return %7, %1#1 : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> + } +} + +// ----- +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @dot_wait + tt.func public @dot_wait(%arg0: tensor<64x64xf32, #mma>, %arg1: tensor<64x128xf32, #mma1>) -> (tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1>) { + %0:2 = ttng.warp_group_dot_wait %arg0, %arg1 {pendings = 0 : i32} : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> + tt.return %0#0, %0#1 : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> + // CHECK: %[[W:.+]]:2 = ttng.warp_group_dot_wait + // CHECK: tt.return %[[W]]#0, %[[W]]#1 : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 64, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [1, 32, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @split_propagation + // CHECK-SAME: (%[[ARG:.+]]: tensor<128x64x2xf32 + // CHECK: %[[S:.+]], %{{.+}} = tt.split %[[ARG]] + // CHECK: %[[C:.+]] = ttg.convert_layout %[[S]] + // CHECK: tt.return %[[C]] + tt.func public @split_propagation(%arg0: tensor<128x64x2xf32, #blocked>) -> tensor<128x64xf32, #blocked1> { + %0 = ttg.convert_layout %arg0 : tensor<128x64x2xf32, #blocked> -> tensor<128x64x2xf32, #blocked2> + %outLHS, %outRHS = tt.split %0 : tensor<128x64x2xf32, #blocked2> -> tensor<128x64xf32, #blocked1> + tt.return %outLHS : tensor<128x64xf32, #blocked1> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 2, 1], order = [4, 0, 1, 2, 3]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 2, 2, 1, 1], order = [4, 0, 3, 2, 1]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 0, 1, 2, 3]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: lift_convert_to_local_load + // CHECK-NOT: convert_layout + // CHECK: tt.return + tt.func public @lift_convert_to_local_load(%arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #ttg.shared_memory, mutable>) -> tensor<2x4x32x1x4xi8, #blocked2> { + %1 = ttg.local_load %arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #ttg.shared_memory, mutable> -> tensor<2x1x32x4x4xi8, #blocked> + %2 = tt.trans %1 {order = array} : tensor<2x1x32x4x4xi8, #blocked> -> tensor<2x4x32x1x4xi8, #blocked1> + %3 = ttg.convert_layout %2 : tensor<2x4x32x1x4xi8, #blocked1> -> tensor<2x4x32x1x4xi8, #blocked2> + tt.return %3 : tensor<2x4x32x1x4xi8, #blocked2> + } +} + +// ----- + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#CL = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { + // CHECK-LABEL: matmul_add + tt.func @matmul_add(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %C : !tt.ptr) { + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %c_ptr_init = tt.splat %C : !tt.ptr -> tensor<128x128x!tt.ptr, #CL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #CL> + %cst = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %100:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL>) { + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %c = tt.dot %a, %b, %cst : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %t = ttg.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL> + // CHECK: %[[T0:.*]] = tt.dot + // CHECK: arith.addf %{{.*}}, %[[T0]] : tensor<128x128xf32, #mma> + %t2 = arith.addf %prev_c, %t : tensor<128x128xf32, #CL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: scf.yield + scf.yield %next_a_ptr, %next_b_ptr, %t2 : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL> + } + + // CHECK: ttg.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked + tt.store %c_ptr_init, %100#2 : tensor<128x128x!tt.ptr, #CL> + tt.return + } +} + +// ----- + +// Minimized reproducer for compiler crash during remove layouts conversions pass: +// If dot result transformed into tensor with shape smaller than one MFMA instruction size, it triggers various asserts. +// This is a smoke test that checks that compiler do not crash. +// +// CHECK-LABEL: small_tensor_mfma + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}> +#mma1 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @small_tensor_mfma(%arg0: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_2 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %cst_3 = arith.constant dense<1.230000e+02> : tensor<32x16xf32, #mma1> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %2 = "tt.reduce" (%1) ({ + ^bb0(%arg1: f32, %arg2: f32): + %3 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %3 : f32 + }) {axis = 1 : i32} : (tensor<32x32xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xf32, #blocked> + %5 = tt.broadcast %4 : tensor<32x1xf32, #blocked> -> tensor<32x16xf32, #blocked> + %6 = ttg.convert_layout %5 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %7 = tt.dot %cst_2, %6, %cst_3 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<32x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<32x16xf32, #mma1> + %addr = tt.splat %arg0 : !tt.ptr -> tensor<32x16x!tt.ptr, #blocked> + %8 = ttg.convert_layout %7 : tensor<32x16xf32, #mma1> -> tensor<32x16xf32, #blocked> + tt.store %addr, %8 : tensor<32x16x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 1, 2, 2, 1], order = [4, 0, 1, 2, 3]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 2, 2, 1, 1], order = [4, 0, 3, 2, 1]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 0, 1, 2, 3]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: lift_convert_to_local_load + // CHECK-NOT: convert_layout + // CHECK: tt.return + tt.func public @lift_convert_to_local_load(%arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #smem, mutable>) -> tensor<2x4x32x1x4xi8, #blocked2> { + %1 = ttg.local_load %arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #smem, mutable> -> tensor<2x1x32x4x4xi8, #blocked> + %2 = tt.trans %1 {order = array} : tensor<2x1x32x4x4xi8, #blocked> -> tensor<2x4x32x1x4xi8, #blocked1> + %3 = ttg.convert_layout %2 : tensor<2x4x32x1x4xi8, #blocked1> -> tensor<2x4x32x1x4xi8, #blocked2> + tt.return %3 : tensor<2x4x32x1x4xi8, #blocked2> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +tt.func @forward_propagate_layout_gather(%arg0: tensor<1024x256xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked1>) -> tensor<1024x256xf32, #blocked> { + // CHECK-LABEL: forward_propagate_layout_gather + + // CHECK-NOT: convert_layout + %0 = ttg.convert_layout %arg0 : tensor<1024x256xi32, #blocked> -> tensor<1024x256xi32, #blocked2> + %1 = tt.gather %arg1[%0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked1>, tensor<1024x256xi32, #blocked2>) -> tensor<1024x256xf32, #blocked2> + %2 = ttg.convert_layout %1 : tensor<1024x256xf32, #blocked2> -> tensor<1024x256xf32, #blocked> + tt.return %2 : tensor<1024x256xf32, #blocked> +} + +tt.func @forward_only_propagation(%arg0: tensor<1024x256xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked1>) -> tensor<1024x256xf32, #blocked1> { + // CHECK-LABEL: forward_only_propagation + + // CHECK-NEXT: [[GATHER:%.*]] = tt.gather + %0 = ttg.convert_layout %arg0 : tensor<1024x256xi32, #blocked> -> tensor<1024x256xi32, #blocked2> + %1 = tt.gather %arg1[%0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked1>, tensor<1024x256xi32, #blocked2>) -> tensor<1024x256xf32, #blocked2> + // CHECK-NEXT: [[RES:%.*]] = ttg.convert_layout [[GATHER]] : tensor<1024x256xf32, #blocked> -> tensor<1024x256xf32, #blocked1> + %2 = ttg.convert_layout %1 : tensor<1024x256xf32, #blocked2> -> tensor<1024x256xf32, #blocked1> + // CHECK-NEXT: return [[RES]] + tt.return %2 : tensor<1024x256xf32, #blocked1> +} + +tt.func @backward_remat_gather_layout(%arg0: tensor<64x64xf32, #blocked1>) -> tensor<1x64xf32, #blocked1> { + // CHECK-LABEL: backward_remat_gather_layout + + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %2 = tt.gather %arg0[%1] {axis = 0 : i32} : (tensor<64x64xf32, #blocked1>, tensor<1x64xi32, #blocked>) -> tensor<1x64xf32, #blocked> + + // CHECK-NOT: convert_layout + %3 = ttg.convert_layout %2 : tensor<1x64xf32, #blocked> -> tensor<1x64xf32, #blocked1> + tt.return %3 : tensor<1x64xf32, #blocked1> +} + +tt.func @do_not_propagate(%arg0: tensor<1024x256xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked1>) -> tensor<1024x256xf32, #blocked> { + // CHECK-LABEL: do_not_propagate + + %0 = ttg.convert_layout %arg0 : tensor<1024x256xi32, #blocked> -> tensor<1024x256xi32, #blocked2> + // CHECK: tt.gather {{.*}} (tensor<128x256xf32, #blocked1>, tensor<1024x256xi32, #blocked2>) -> tensor<1024x256xf32, #blocked2> + %1 = tt.gather %arg1[%0] {axis = 0 : i32, efficient_layout} : (tensor<128x256xf32, #blocked1>, tensor<1024x256xi32, #blocked2>) -> tensor<1024x256xf32, #blocked2> + %2 = ttg.convert_layout %1 : tensor<1024x256xf32, #blocked2> -> tensor<1024x256xf32, #blocked> + tt.return %2 : tensor<1024x256xf32, #blocked> +} + +tt.func @do_not_remat(%arg0: tensor<64x64xf32, #blocked1>) -> tensor<1x64xf32, #blocked1> { + // CHECK-LABEL: do_not_remat + + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + // CHECK: tt.gather {{.*}} (tensor<64x64xf32, #blocked1>, tensor<1x64xi32, #blocked>) -> tensor<1x64xf32, #blocked> + %2 = tt.gather %arg0[%1] {axis = 0 : i32, efficient_layout} : (tensor<64x64xf32, #blocked1>, tensor<1x64xi32, #blocked>) -> tensor<1x64xf32, #blocked> + + %3 = ttg.convert_layout %2 : tensor<1x64xf32, #blocked> -> tensor<1x64xf32, #blocked1> + tt.return %3 : tensor<1x64xf32, #blocked1> +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: reuse_layout_conversion +tt.func @reuse_layout_conversion(%arg0: tensor<64x64xf32, #blocked>) -> (tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>) { + // CHECK-NEXT: %cst = arith.constant {{.*}} tensor<64x64xf32, #blocked> + %cst = arith.constant dense<2.000000e+00> : tensor<64x64xf32, #blocked1> + // CHECK-NEXT: [[TRANS:%.*]] = tt.trans %arg0 {{.*}} tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + %0 = tt.trans %arg0 {order = array} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + // CHECK-NEXT: [[CVT:%.*]] = ttg.convert_layout [[TRANS]] : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> + %1 = ttg.convert_layout %0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> + // CHECK-NEXT: [[RESULT:%.*]] = arith.mulf [[CVT]], %cst : tensor<64x64xf32, #blocked> + %2 = arith.mulf %0, %cst : tensor<64x64xf32, #blocked1> + %3 = ttg.convert_layout %2 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> + // CHECK-NEXT: return [[CVT]], [[RESULT]] + tt.return %1, %3 : tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked> +} + +// CHECK-LABEL: respect_dominance +tt.func @respect_dominance(%arg0: tensor<64x64xf32, #blocked>) -> (tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>) { + %cst = arith.constant dense<2.000000e+00> : tensor<64x64xf32, #blocked1> + + // CHECK-COUNT-2: convert_layout + %0 = tt.trans %arg0 {order = array} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + + %2 = arith.mulf %0, %cst : tensor<64x64xf32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> + %3 = ttg.convert_layout %2 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> + tt.return %1, %3 : tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked> +} + +// CHECK-LABEL: remat_across_regions +tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) { + // CHECK-NEXT: scf.if + scf.if %arg0 { + // CHECK-NEXT: convert_layout + %0 = ttg.convert_layout %arg1 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1> + "test.keep"(%0) : (tensor<8x8xf32, #blocked1>) -> () + // CHECK: else + } else { + %0 = "test.dummy"() : () -> i32 + // CHECK: convert_layout + %1 = ttg.convert_layout %arg1 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1> + "test.keep"(%1) : (tensor<8x8xf32, #blocked1>) -> () + // CHECK: } + } + // CHECK-NEXT: return + tt.return +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { + +// CHECK-LABEL: @hoist_one_conditional +tt.func @hoist_one_conditional( + %arg0: i1, + %arg1: tensor<128x32x!tt.ptr, #blocked> +) -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> { + + // CHECK: arith.constant {{.*}} tensor<128x32xf32, #blocked> + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked> + // CHECK: scf.if + %0 = scf.if %arg0 -> (tensor<128x32xf32, #blocked>) { + // CHECK-NEXT: [[RES:%.*]] = tt.load + %3 = tt.load %arg1 : tensor<128x32x!tt.ptr, #blocked> + // CHECK-NEXT: yield [[RES]] + scf.yield %3 : tensor<128x32xf32, #blocked> + } else { + scf.yield %cst : tensor<128x32xf32, #blocked> + } + // CHECK: [[TRUNC:%.*]] = arith.truncf + %1 = arith.truncf %0 : tensor<128x32xf32, #blocked> to tensor<128x32xf16, #blocked> + // CHECK-NEXT: convert_layout [[TRUNC]] + %2 = ttg.convert_layout %1 : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + tt.return %2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +} + +// CHECK-LABEL: @hoist_multiple_conditional +tt.func @hoist_multiple_conditional( + %arg0: i1, + %arg1: i1, + %arg2: tensor<128x32x!tt.ptr, #blocked>, + %arg3: tensor<128x32x!tt.ptr, #blocked>, + %arg4: tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, + %arg5: tensor<128x128xf32, #mma> +) -> tensor<128x128xf32, #mma> { + // CHECK-COUNT-1: ttg.convert_layout + %cst0 = arith.constant dense<1.0> : tensor<128x32xf32, #blocked> + %cst1 = arith.constant dense<2.0> : tensor<128x32xf32, #blocked> + %0 = scf.if %arg0 -> (tensor<128x32xf32, #blocked>) { + %3 = tt.load %arg2 : tensor<128x32x!tt.ptr, #blocked> + scf.yield %3 : tensor<128x32xf32, #blocked> + } else { + scf.yield %cst0 : tensor<128x32xf32, #blocked> + } + %1 = scf.if %arg1 -> (tensor<128x32xf32, #blocked>) { + %4 = tt.load %arg3 : tensor<128x32x!tt.ptr, #blocked> + scf.yield %4 : tensor<128x32xf32, #blocked> + } else { + scf.yield %cst1 : tensor<128x32xf32, #blocked> + } + %2 = arith.addf %0, %1 : tensor<128x32xf32, #blocked> + %3 = ttg.convert_layout %2 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %4 = tt.dot %3, %arg4, %arg5 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + tt.return %4 : tensor<128x128xf32, #mma> +} + +// CHECK-LABEL: @hoist_across_loop +tt.func @hoist_across_loop( + %arg0: i1, + %arg1: tensor<128x32x!tt.ptr, #blocked>, + %arg2: tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, + %arg3: tensor<128x128xf32, #mma> +) -> tensor<128x128xf32, #mma> { + // CHECK: arith.constant {{.*}} tensor<128x32xf32, #ttg.dot_op + %cst = arith.constant dense<1.0> : tensor<128x32xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + // CHECK: scf.for + %0:2 = scf.for %i = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg4 = %cst, %acc = %arg3) -> (tensor<128x32xf32, #blocked>, tensor<128x128xf32, #mma>) : i32 { + // CHECK-NEXT: scf.if + %1 = scf.if %arg0 -> (tensor<128x32xf32, #blocked>) { + // CHECK-NEXT: [[RES:%.*]] = tt.load + // CHECK-NEXT: ttg.convert_layout [[RES]] + %3 = tt.load %arg1 : tensor<128x32x!tt.ptr, #blocked> + scf.yield %3 : tensor<128x32xf32, #blocked> + } else { + scf.yield %arg4 : tensor<128x32xf32, #blocked> + } + // CHECK-NOT: ttg.convert_layout + %2 = ttg.convert_layout %1 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %3 = tt.dot %2, %arg2, %acc : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + scf.yield %1, %3 : tensor<128x32xf32, #blocked>, tensor<128x128xf32, #mma> + } + tt.return %0#1 : tensor<128x128xf32, #mma> +} + +// CHECK-LABEL: @chained_if +tt.func @chained_if(%arg0: i1, %arg1: i1, %arg2: tensor<32x32x!tt.ptr, #blocked>, %arg3: tensor<32x32x!tt.ptr, #blocked>) -> tensor<32x32xf32, #mma> { + // CHECK-COUNT-1: ttg.convert_layout + %cst = arith.constant dense<1.0> : tensor<32x32xf32, #blocked> + %0 = scf.if %arg0 -> tensor<32x32xf32, #blocked> { + %anchor = tt.load %arg2 : tensor<32x32x!tt.ptr, #blocked> + scf.yield %anchor : tensor<32x32xf32, #blocked> + } else { + scf.yield %cst : tensor<32x32xf32, #blocked> + } + %1 = scf.if %arg1 -> tensor<32x32xf32, #blocked> { + %anchor = tt.load %arg3 : tensor<32x32x!tt.ptr, #blocked> + scf.yield %anchor : tensor<32x32xf32, #blocked> + } else { + scf.yield %0 : tensor<32x32xf32, #blocked> + } + %2 = ttg.convert_layout %1 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #mma> + tt.return %2 : tensor<32x32xf32, #mma> +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: @cvt_in_peeled_prologue +tt.func @cvt_in_peeled_prologue(%arg0: tensor<32x32x!tt.ptr, #blocked>, %arg1: i1, %arg2: i32, %arg3: i32, %arg4: i1) { + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<0.0> : tensor<32x32xbf16, #blocked1> + + // CHECK: scf.if + %0 = scf.if %arg1 -> (tensor<32x32xbf16, #blocked1>) { + // CHECK-NEXT: tt.load + %1 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %2 = ttg.convert_layout %1 : tensor<32x32xbf16, #blocked> -> tensor<32x32xbf16, #blocked1> + // CHECK-NEXT: yield + scf.yield %2 : tensor<32x32xbf16, #blocked1> + // CHECK-NEXT: else + } else { + // CHECK-NEXT: yield + scf.yield %cst : tensor<32x32xbf16, #blocked1> + // CHECK-NEXT: } + } + + // CHECK: [[PEEL1:%.*]] = scf.if + %1 = scf.if %arg4 -> (tensor<32x32xbf16, #blocked1>) { + // CHECK-NEXT: tt.load + %2 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %3 = ttg.convert_layout %2 : tensor<32x32xbf16, #blocked> -> tensor<32x32xbf16, #blocked1> + // CHECK-NEXT: yield + scf.yield %3 : tensor<32x32xbf16, #blocked1> + // CHECK-NEXT: else + } else { + // CHECK-NEXT: yield + scf.yield %0 : tensor<32x32xbf16, #blocked1> + // CHECK-NEXT: } + } + + // CHECK-NEXT: [[CVT:%.*]] = ttg.convert_layout [[PEEL1]] + // CHECK-NEXT: scf.for {{.*}} iter_args(%{{arg[0-9]+}} = [[CVT]]) + %3 = scf.for %i = %arg2 to %arg3 step %c1_i32 iter_args(%k = %1) -> (tensor<32x32xbf16, #blocked1>) : i32 { + // CHECK-NEXT: scf.if + %4 = scf.if %arg1 -> (tensor<32x32xbf16, #blocked1>) { + // CHECK-NEXT: tt.load + %5 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + // CHECK-NEXT: ttg.convert_layout + %6 = ttg.convert_layout %5 : tensor<32x32xbf16, #blocked> -> tensor<32x32xbf16, #blocked1> + scf.yield %6 : tensor<32x32xbf16, #blocked1> + } else { + scf.yield %k : tensor<32x32xbf16, #blocked1> + } + "use.it"(%4) : (tensor<32x32xbf16, #blocked1>) -> () + scf.yield %4 : tensor<32x32xbf16, #blocked1> + } + // CHECK-NOT: ttg.convert_layout + tt.return +} + +// CHECK-LABEL: @cvt_in_loop_if_slice +tt.func @cvt_in_loop_if_slice(%arg0: tensor<32x32x!tt.ptr, #blocked>, %arg1: i1, %arg2: i32, %arg3: i32, %arg4: i1) { + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<0.0> : tensor<32x32xbf16, #blocked> + + // CHECK: [[IF_OUT:%.*]] = scf.if + %0 = scf.if %arg1 -> (tensor<32x32xbf16, #blocked>) { + // CHECK-NEXT: tt.load + %1 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + // CHECK-NEXT: yield + scf.yield %1 : tensor<32x32xbf16, #blocked> + // CHECK-NEXT: else + } else { + // CHECK-NEXT: yield + scf.yield %cst : tensor<32x32xbf16, #blocked> + // CHECK-NEXT: } + } + + // CHECK-NEXT: [[CVT:%.*]] = ttg.convert_layout [[IF_OUT]] + // CHECK-NEXT: scf.for + %1 = scf.for %i = %arg2 to %arg3 step %c1_i32 iter_args(%k = %cst) -> tensor<32x32xbf16, #blocked> : i32 { + // CHECK-NEXT: scf.if + %4 = scf.if %arg4 -> (tensor<32x32xbf16, #blocked>) { + // CHECK-NEXT: tt.load + %5 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + // CHECK-NEXT: ttg.convert_layout + scf.yield %5 : tensor<32x32xbf16, #blocked> + } else { + scf.yield %k : tensor<32x32xbf16, #blocked> + } + %6 = arith.addf %4, %0 : tensor<32x32xbf16, #blocked> + // CHECK-NOT: ttg.convert_layout + %7 = ttg.convert_layout %6 : tensor<32x32xbf16, #blocked> -> tensor<32x32xbf16, #blocked1> + "use.it"(%7) : (tensor<32x32xbf16, #blocked1>) -> () + scf.yield %6 : tensor<32x32xbf16, #blocked> + } + + tt.return +} + +} + +// ----- + +#linear = #ttg.linear<{register = [[1, 0], [0, 8], [0, 16]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 2], [0, 4]], block = []}> +#blocked = #ttg.blocked<{sizePerThread = [2, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: reduce_linear_layouts +tt.func @reduce_linear_layouts(%arg0: tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> { + // CHECK-NOT: convert_layout + %0 = ttg.convert_layout %arg0 : tensor<32x32xi32, #linear> -> tensor<32x32xi32, #blocked> + // CHECK-NEXT: tt.reduce + %1 = "tt.reduce" (%0) ({ + ^bb0(%arg1: i32, %arg2: i32): + tt.reduce.return %arg1 : i32 + // CHECK: (tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}> + }) {axis = 1 : i32} : (tensor<32x32xi32, #blocked>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = ttg.convert_layout %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> + tt.return %2 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#linear = #ttg.linear<{register = [[16, 0]], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + + // Test that after dot_scaled with rhs scales is decomposed, we are able to get rid of the redundant convert_layout + // CHECK-LABEL: dot_scale_transpose + tt.func public @dot_scale_transpose(%arg0: tensor<128x64xf8E4M3FN, #blocked>, %arg1: tensor<32x32xi8, #blocked1>, %arg2: tensor<128x32x!tt.ptr, #blocked3>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked1> + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = scf.for %arg4 = %c0_i32 to %c100_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<128x32xf32, #blocked1>) : i32 { + %3 = tt.trans %arg0 {order = array} : tensor<128x64xf8E4M3FN, #blocked> -> tensor<64x128xf8E4M3FN, #blocked4> + %4 = tt.trans %arg1 {order = array} : tensor<32x32xi8, #blocked1> -> tensor<32x32xi8, #blocked5> + %5 = tt.trans %arg5 {order = array} : tensor<128x32xf32, #blocked1> -> tensor<32x128xf32, #blocked5> + %6 = ttg.convert_layout %5 : tensor<32x128xf32, #blocked5> -> tensor<32x128xf32, #mma> + %7 = ttg.convert_layout %4 : tensor<32x32xi8, #blocked5> -> tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %9 = ttg.fp4_to_fp %7 {axis = 1 : i32} : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %10 = ttg.convert_layout %3 : tensor<64x128xf8E4M3FN, #blocked4> -> tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + %11 = tt.fp_to_fp %10 : tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + %12 = tt.dot %9, %11, %6 : tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<32x128xf32, #mma> + // CHECK: tt.dot + // CHECK-NOT: ttg.convert_layout + // CHECK: scf.yield + %13 = ttg.convert_layout %12 : tensor<32x128xf32, #mma> -> tensor<32x128xf32, #blocked5> + %14 = tt.trans %13 {order = array} : tensor<32x128xf32, #blocked5> -> tensor<128x32xf32, #blocked1> + scf.yield %14 : tensor<128x32xf32, #blocked1> + } + // CHECK: arith.truncf + // CHECK-NEXT: ttg.convert_layout + // CHECK-NEXT: tt.store + %1 = arith.truncf %0 : tensor<128x32xf32, #blocked1> to tensor<128x32xbf16, #blocked1> + %2 = ttg.convert_layout %1 : tensor<128x32xbf16, #blocked1> -> tensor<128x32xbf16, #blocked3> + tt.store %arg2, %2 : tensor<128x32x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +tt.func public @reshape_slice_dot_enc(%arg0: tensor<4x16xi32, #blocked>) -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> { + %0 = tt.reshape %arg0 : tensor<4x16xi32, #blocked> -> tensor<64xi32, #blocked2> + %1 = ttg.convert_layout %0 : tensor<64xi32, #blocked2> -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi32, #blocked3> + %3 = ttg.convert_layout %2 : tensor<64x1xi32, #blocked3> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> + tt.return %3 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> +} + +} +#Cv2 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#Av2k1 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=1}> +#Bv2k1 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=1}> +#Av2k2 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}> +#Bv2k2 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}> +#Av2k4 = #ttg.dot_op<{opIdx = 0, parent = #Cv2, kWidth=4}> +#Bv2k4 = #ttg.dot_op<{opIdx = 1, parent = #Cv2, kWidth=4}> +#ALR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}> +#BLR = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#BLC = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + +// CHECK: tt.func @push_elementwise +// CHECK: %[[ALOAD:.*]] = tt.load %arg0 +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] {{.*}} #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: %[[BCVT:.*]] = ttg.convert_layout %{{.*}} : {{.*}} tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> +// CHECK: %[[C:.*]] = tt.dot %[[AF16]], %[[BCVT]] +// CHECK-SAME: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma> +// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> +tt.func @push_elementwise( + %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ + %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> + %af8 = tt.bitcast %ai8: tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> + %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> + %dota = ttg.convert_layout %a : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> + %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> + %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + tt.return %newc : tensor<16x16xf32, #Cv2> +} + + +// CHECK: tt.func @succeeds_if_arg_is_not_convert_layout +// CHECK: %[[ALOAD:.*]] = tt.load %arg0 +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] +// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] +// CHECK: %[[C:.*]] = tt.dot %[[AF16]] +// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> +tt.func @succeeds_if_arg_is_not_convert_layout( + %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #BLC> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ + %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> + %dotai8 = ttg.convert_layout %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xi8, #Av2k4> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #BLC> + %dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4> + %dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4> + %dotb = ttg.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> + %newc = tt.dot %dota, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + tt.return %newc : tensor<16x16xf32, #Cv2> +} + +// CHECK: tt.func @push_inline_asm_op +// CHECK: %[[ALOAD:.*]] = tt.load %arg0 +// CHECK: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] +// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] +// CHECK: %[[AF16:.*]] = tt.elementwise_inline_asm {{.*}} %[[AF8E5]] +// CHECK: %[[C:.*]] = tt.dot %[[AF16]] +// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma> +tt.func @push_inline_asm_op( + %pa: tensor<16x16x!tt.ptr, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %dotb: tensor<16x16xf16, #Bv2k4>, + %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{ + %ai8 = tt.load %pa : tensor<16x16x!tt.ptr, #ALR> + %dotaf8 = tt.bitcast %ai8 : tensor<16x16xi8, #ALR> -> tensor<16x16xf8E5M2, #ALR> + %dota = tt.elementwise_inline_asm "{ cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; }" {constraints = "=r,r", packed_element = 2 : i32, pure = true} %dotaf8 : tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> + %dota_cvt = ttg.convert_layout %dota : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> + %newc = tt.dot %dota_cvt, %dotb, %c : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + tt.return %newc : tensor<16x16xf32, #Cv2> +} +} + +// ----- + +#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + +// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> + +// CHECK: tt.func @push_convert_both_operands +// CHECK-DAG: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> +// CHECK-DAG: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> +// CHECK-DAG: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +tt.func @push_convert_both_operands( + %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + %a = tt.load %pa : tensor<16x16x!tt.ptr, #blockedA> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> + %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> + %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> + %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %bl = ttg.convert_layout %be : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> +} + +} + +// ----- + +#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + +// CHECK: #[[BA:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[BB:.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +// CHECK: #[[MMA:.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = []}> + +// CHECK: tt.func @update_kwidth_slice +// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[ALOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BA]]> +// CHECK-DAG: %[[BLOAD:.*]] = tt.load %{{.*}} : tensor<16x16x!tt.ptr, #[[BB]]> +// CHECK-DAG: %[[ACVT:.*]] = ttg.convert_layout %[[ALOAD]] : tensor<16x16xf16, #[[BA]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BCVT:.*]] = ttg.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> +// CHECK-DAG: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +tt.func @update_kwidth_slice( + %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + %cst = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blockedB> + %a = tt.load %pa : tensor<16x16x!tt.ptr, #blockedA> + %b = tt.load %pb : tensor<16x16x!tt.ptr, #blockedB> + %ae = arith.extf %a : tensor<16x16xf16, #blockedA> to tensor<16x16xf32, #blockedA> + %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> + %add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB> + %al = ttg.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %bl = ttg.convert_layout %add : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %r = tt.dot %al, %bl, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> +} +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: tt.func @propagate_dot_op_to_constant() + // CHECK: arith.constant dense<1.000000e+00> : tensor<64x32xf32, #mma> + tt.func @propagate_dot_op_to_constant() -> tensor<64x32xf32, #mma> { + %cst = arith.constant dense<1.000000e+00> : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %cst1 = arith.constant dense<1.000000e+00> : tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %cst2 = arith.constant dense<1.000000e+00> : tensor<64x32xf32, #mma> + %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %1 = ttg.convert_layout %0 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %2 = tt.dot %cst1, %1, %cst2 : tensor<64x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<128x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + tt.return %2 : tensor<64x32xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: tt.func @propagate_dot_op_to_constant_above_for() + // CHECK: arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + tt.func @propagate_dot_op_to_constant_above_for() -> tensor<32x128xf32, #mma> { + %cst = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %loop:1 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst_1) -> (tensor<32x128xf32, #mma>) : i32 { + %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %3 = tt.dot %2, %1, %arg0, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> + scf.yield %3 : tensor<32x128xf32, #mma> + } + tt.return %loop#0 : tensor<32x128xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // We currently don't propagate through block arguments on hoistDotOperand + // that being said, https://github.com/triton-lang/triton/pull/5350 + // allowed to lift DotOperand(opIdx=1), which might be alright + + // CHECK: tt.func @do_not_propagate_through_block_arguments() + // CHECK: %[[THROUGH_FOR_OP:.*]] = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[THROUGH_FOR_OP]], + tt.func @do_not_propagate_through_block_arguments() -> tensor<32x128xf32, #mma> { + %cst = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %loop:2 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst, %arg1 = %cst_1) -> (tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<32x128xf32, #mma>) : i32 { + %0 = arith.addf %cst, %arg0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %3 = tt.dot %2, %1, %arg1, inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> + scf.yield %0, %3 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<32x128xf32, #mma> + } + tt.return %loop#1 : tensor<32x128xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice( + %pa: tensor<16x16x!tt.ptr, #blocked> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %b: tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + // CHECK: tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice + // This checks that we propagate dot op layout given the following: + // initializer -> unsupported op -> initializer -> supported ops -> convert, + // where initializers can be constants or loads. + // CHECK: %[[LOAD1:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD1]] + %offset = arith.constant dense<16> : tensor<16x1xi32, #blocked> + %broadcast = tt.broadcast %offset : tensor<16x1xi32, #blocked> -> tensor<16x16xi32, #blocked> + %pa2 = tt.addptr %pa, %broadcast : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> + %a = tt.load %pa2 : tensor<16x16x!tt.ptr, #blocked> + %ae = arith.extf %a : tensor<16x16xf16, #blocked> to tensor<16x16xf32, #blocked> + %ac = ttg.convert_layout %ae : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = tt.dot %ac, %b, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_push_elementwise +// CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> +// CHECK: %[[A_DOTOP:.*]] = ttg.convert_layout %[[A_BLOCK]] : tensor<128x64xbf16, #blocked> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_DOTOP]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = ttng.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !ttg.memdesc<64x64xf16, #shared, #smem>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %a_bf16 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> + %a = tt.fp_to_fp %a_bf16 : tensor<128x64xbf16, #blocked> -> tensor<128x64xf16, #blocked> + %dota = ttg.convert_layout %a: tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_push_elementwise_chained +// CHECK: %[[CST_DOTOP:.*]] = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> +// CHECK: %[[A_DOTOP:.*]] = ttg.convert_layout %[[A_BLOCK]] : tensor<128x64xi8, #blocked> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = arith.sitofp %[[A_DOTOP]] : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_SCALED:.*]] = arith.mulf %[[A_CASTED]], %[[CST_DOTOP]] : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_NEGATED:.*]] = arith.negf %[[A_SCALED]] : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = ttng.warp_group_dot %[[A_NEGATED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise_chained(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !ttg.memdesc<64x64xf16, #shared, #smem>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked> + %a_i8 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> + %a_f16 = arith.sitofp %a_i8 : tensor<128x64xi8, #blocked> to tensor<128x64xf16, #blocked> + %a_scaled = arith.mulf %a_f16, %cst : tensor<128x64xf16, #blocked> + %a_negated = arith.negf %a_scaled : tensor<128x64xf16, #blocked> + %dota = ttg.convert_layout %a_negated: tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice( + %pa1: tensor<16x1x!tt.ptr, #blocked> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %pa2: tensor<16x16x!tt.ptr, #blocked> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %b: tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, + %c: tensor<16x16xf32, #mma>) -> tensor<16x16xf32, #mma>{ + // CHECK: tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice + // Confirm that both loads feed directly into a convert_layout. + // CHECK: %[[LOAD1:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD1]] + // CHECK: %[[LOAD2:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD2]] + %a1 = tt.load %pa1 : tensor<16x1x!tt.ptr, #blocked> + %a2 = tt.load %pa2 : tensor<16x16x!tt.ptr, #blocked> + %ab = tt.broadcast %a1 : tensor<16x1xf16, #blocked> -> tensor<16x16xf16, #blocked> + %aa = arith.addf %ab, %a2 : tensor<16x16xf16, #blocked> + %ae = arith.extf %aa : tensor<16x16xf16, #blocked> to tensor<16x16xf32, #blocked> + %ac = ttg.convert_layout %ae : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = tt.dot %ac, %b, %c : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + tt.return %r : tensor<16x16xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked6 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked7 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [8, 4, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[0, 32], [0, 64], [1, 0], [2, 0], [4, 0]], warp = [[8, 0], [16, 0]], block = []}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + // CHECK: @remove_layout_dot_scaled + // CHECK: %[[LOAD1:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD1]] + // CHECK: %[[LOAD2:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD2]] + // CHECK: %[[LOAD3:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD3]] + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.dot + // CHECK-NOT: ttg.convert_layout + // CHECK: %[[STORE:.*]] = ttg.convert_layout + // CHECK: tt.store %[[PTR:.+]], %[[STORE]] + tt.func @remove_layout_dot_scaled(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0x7FC0> : tensor<32x128xbf16, #blocked> + %cst_0 = arith.constant dense<-1> : tensor<32x4xi8, #blocked1> + %cst_1 = arith.constant dense<7> : tensor<32x4xi16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked2> + %cst_3 = arith.constant dense<32> : tensor<32x1xi32, #blocked3> + %cst_4 = arith.constant dense<4> : tensor<32x1xi32, #blocked1> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %3 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<32x1xi32, #blocked4> + %4 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %5 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1xi32, #blocked3> + %6 = tt.splat %arg1 : i32 -> tensor<32x1xi32, #blocked4> + %7 = arith.muli %3, %6 : tensor<32x1xi32, #blocked4> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked4> + %9 = tt.addptr %8, %7 : tensor<32x1x!tt.ptr, #blocked4>, tensor<32x1xi32, #blocked4> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked4}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x64xi32, #blocked4> + %12 = tt.broadcast %9 : tensor<32x1x!tt.ptr, #blocked4> -> tensor<32x64x!tt.ptr, #blocked4> + %13 = tt.broadcast %11 : tensor<1x64xi32, #blocked4> -> tensor<32x64xi32, #blocked4> + %14 = tt.addptr %12, %13 : tensor<32x64x!tt.ptr, #blocked4>, tensor<32x64xi32, #blocked4> + %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked5}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked5}>> -> tensor<128x1xi32, #blocked5> + %17 = tt.splat %arg4 : i32 -> tensor<128x1xi32, #blocked5> + %18 = arith.muli %16, %17 : tensor<128x1xi32, #blocked5> + %19 = tt.splat %arg3 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked5> + %20 = tt.addptr %19, %18 : tensor<128x1x!tt.ptr, #blocked5>, tensor<128x1xi32, #blocked5> + %21 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked5}>> + %22 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %23 = tt.expand_dims %21 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked5}>> -> tensor<1x32xi32, #blocked5> + %24 = tt.expand_dims %22 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %25 = tt.broadcast %20 : tensor<128x1x!tt.ptr, #blocked5> -> tensor<128x32x!tt.ptr, #blocked5> + %26 = tt.broadcast %23 : tensor<1x32xi32, #blocked5> -> tensor<128x32xi32, #blocked5> + %27 = tt.addptr %25, %26 : tensor<128x32x!tt.ptr, #blocked5>, tensor<128x32xi32, #blocked5> + %28 = tt.load %14 : tensor<32x64x!tt.ptr, #blocked4> + %29 = ttg.convert_layout %28 : tensor<32x64xi8, #blocked4> -> tensor<32x64xi8, #blocked6> + %30 = tt.load %27 : tensor<128x32x!tt.ptr, #blocked5> + %31 = arith.muli %4, %cst_4 : tensor<32x1xi32, #blocked1> + %32 = tt.splat %arg2 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked1> + %33 = tt.addptr %32, %31 : tensor<32x1x!tt.ptr, #blocked1>, tensor<32x1xi32, #blocked1> + %34 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %35 = tt.expand_dims %34 {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x4xi32, #blocked1> + %36 = tt.broadcast %33 : tensor<32x1x!tt.ptr, #blocked1> -> tensor<32x4x!tt.ptr, #blocked1> + %37 = tt.broadcast %35 : tensor<1x4xi32, #blocked1> -> tensor<32x4xi32, #blocked1> + %38 = tt.addptr %36, %37 : tensor<32x4x!tt.ptr, #blocked1>, tensor<32x4xi32, #blocked1> + %39 = tt.load %38 : tensor<32x4x!tt.ptr, #blocked1> + %40 = tt.bitcast %30 : tensor<128x32xi8, #blocked5> -> tensor<128x32xf8E4M3FN, #blocked5> + %41 = ttg.convert_layout %40 : tensor<128x32xf8E4M3FN, #blocked5> -> tensor<128x32xf8E4M3FN, #blocked2> + %42 = ttg.fp4_to_fp %29 {axis = 1 : i32} : tensor<32x64xi8, #blocked6> -> tensor<32x128xbf16, #blocked> + %43 = arith.extui %39 : tensor<32x4xi8, #blocked1> to tensor<32x4xi16, #blocked1> + %44 = arith.shli %43, %cst_1 : tensor<32x4xi16, #blocked1> + %45 = tt.bitcast %44 : tensor<32x4xi16, #blocked1> -> tensor<32x4xbf16, #blocked1> + %46 = ttg.convert_layout %45 : tensor<32x4xbf16, #blocked1> -> tensor<32x4xbf16, #ttg.slice<{dim = 2, parent = #blocked7}>> + %47 = tt.expand_dims %46 {axis = 2 : i32} : tensor<32x4xbf16, #ttg.slice<{dim = 2, parent = #blocked7}>> -> tensor<32x4x1xbf16, #blocked7> + %48 = tt.broadcast %47 : tensor<32x4x1xbf16, #blocked7> -> tensor<32x4x32xbf16, #blocked7> + %49 = tt.reshape %48 : tensor<32x4x32xbf16, #blocked7> -> tensor<32x128xbf16, #linear> + %50 = ttg.convert_layout %49 : tensor<32x128xbf16, #linear> -> tensor<32x128xbf16, #blocked> + %51 = arith.mulf %42, %50 : tensor<32x128xbf16, #blocked> + %52 = arith.cmpi eq, %39, %cst_0 : tensor<32x4xi8, #blocked1> + %53 = ttg.convert_layout %52 : tensor<32x4xi1, #blocked1> -> tensor<32x4xi1, #ttg.slice<{dim = 2, parent = #blocked7}>> + %54 = tt.expand_dims %53 {axis = 2 : i32} : tensor<32x4xi1, #ttg.slice<{dim = 2, parent = #blocked7}>> -> tensor<32x4x1xi1, #blocked7> + %55 = tt.broadcast %54 : tensor<32x4x1xi1, #blocked7> -> tensor<32x4x32xi1, #blocked7> + %56 = tt.reshape %55 : tensor<32x4x32xi1, #blocked7> -> tensor<32x128xi1, #linear> + %57 = ttg.convert_layout %56 : tensor<32x128xi1, #linear> -> tensor<32x128xi1, #blocked> + %58 = arith.select %57, %cst, %51 : tensor<32x128xi1, #blocked>, tensor<32x128xbf16, #blocked> + %59 = ttg.convert_layout %58 : tensor<32x128xbf16, #blocked> -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> + %60 = tt.fp_to_fp %41 : tensor<128x32xf8E4M3FN, #blocked2> -> tensor<128x32xbf16, #blocked2> + %61 = ttg.convert_layout %60 : tensor<128x32xbf16, #blocked2> -> tensor<128x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>> + %62 = ttg.convert_layout %cst_2 : tensor<32x32xf32, #blocked2> -> tensor<32x32xf32, #mma> + %63 = ttg.convert_layout %59 : tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %64 = ttg.convert_layout %61 : tensor<128x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + %65 = tt.dot %63, %64, %62 : tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<128x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<32x32xf32, #mma> + %66 = ttg.convert_layout %65 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked2> + %67 = ttg.convert_layout %66 : tensor<32x32xf32, #blocked2> -> tensor<32x32xf32, #blocked2> + %68 = arith.muli %5, %cst_3 : tensor<32x1xi32, #blocked3> + %69 = tt.splat %arg5 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked3> + %70 = tt.addptr %69, %68 : tensor<32x1x!tt.ptr, #blocked3>, tensor<32x1xi32, #blocked3> + %71 = tt.broadcast %70 : tensor<32x1x!tt.ptr, #blocked3> -> tensor<32x32x!tt.ptr, #blocked3> + %72 = tt.broadcast %24 : tensor<1x32xi32, #blocked3> -> tensor<32x32xi32, #blocked3> + %73 = tt.addptr %71, %72 : tensor<32x32x!tt.ptr, #blocked3>, tensor<32x32xi32, #blocked3> + %74 = arith.truncf %67 : tensor<32x32xf32, #blocked2> to tensor<32x32xbf16, #blocked2> + %75 = ttg.convert_layout %74 : tensor<32x32xbf16, #blocked2> -> tensor<32x32xbf16, #blocked3> + tt.store %73, %75 : tensor<32x32x!tt.ptr, #blocked3> + tt.return + } +} + + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#linear = #ttg.linear<{register = [[0, 0, 1], [0, 0, 2], [0, 0, 4], [8, 0, 0], [0, 1, 0], [0, 2, 0]], lane = [[0, 0, 8], [0, 0, 16], [1, 0, 0], [2, 0, 0], [4, 0, 0]], warp = [[0, 0, 0], [16, 0, 0]], block = []}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + // Check that the remove-layout-conversions pass is idempotent + // in that it keeps the convert_layout ops next to the loads + // CHECK: tt.func @remove_layout_is_idempotent + tt.func @remove_layout_is_idempotent(%14: tensor<32x64x!tt.ptr, #blocked2>, %39: tensor<32x4x!tt.ptr, #blocked>, %27: tensor<128x32x!tt.ptr, #blocked3>) -> tensor<32x32xf32, #mma> { + %cst = arith.constant dense<0x7FC0> : tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_3 = arith.constant dense<7> : tensor<32x4xi16, #ttg.slice<{dim = 2, parent = #linear}>> + %cst_4 = arith.constant dense<-1> : tensor<32x4xi8, #ttg.slice<{dim = 2, parent = #linear}>> + // CHECK: %[[LOAD1:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD1]] + // CHECK: %[[LOAD2:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD2]] + // CHECK: %[[LOAD3:.*]] = tt.load + // CHECK: ttg.convert_layout %[[LOAD3]] + %28 = tt.load %14 : tensor<32x64x!tt.ptr, #blocked2> + %29 = ttg.convert_layout %28 : tensor<32x64xi8, #blocked2> -> tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %30 = tt.load %27 : tensor<128x32x!tt.ptr, #blocked3> + %31 = ttg.convert_layout %30 : tensor<128x32xi8, #blocked3> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + %40 = tt.load %39 : tensor<32x4x!tt.ptr, #blocked> + %41 = ttg.convert_layout %40 : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #ttg.slice<{dim = 2, parent = #linear}>> + %42 = tt.bitcast %31 : tensor<128x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + %43 = ttg.fp4_to_fp %29 {axis = 1 : i32} : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %44 = arith.extui %41 : tensor<32x4xi8, #ttg.slice<{dim = 2, parent = #linear}>> to tensor<32x4xi16, #ttg.slice<{dim = 2, parent = #linear}>> + %45 = arith.shli %44, %cst_3 : tensor<32x4xi16, #ttg.slice<{dim = 2, parent = #linear}>> + %46 = tt.bitcast %45 : tensor<32x4xi16, #ttg.slice<{dim = 2, parent = #linear}>> -> tensor<32x4xbf16, #ttg.slice<{dim = 2, parent = #linear}>> + %47 = tt.expand_dims %46 {axis = 2 : i32} : tensor<32x4xbf16, #ttg.slice<{dim = 2, parent = #linear}>> -> tensor<32x4x1xbf16, #linear> + %48 = tt.broadcast %47 : tensor<32x4x1xbf16, #linear> -> tensor<32x4x32xbf16, #linear> + %49 = tt.reshape %48 : tensor<32x4x32xbf16, #linear> -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %50 = arith.mulf %43, %49 : tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %51 = arith.cmpi eq, %41, %cst_4 : tensor<32x4xi8, #ttg.slice<{dim = 2, parent = #linear}>> + %52 = tt.expand_dims %51 {axis = 2 : i32} : tensor<32x4xi1, #ttg.slice<{dim = 2, parent = #linear}>> -> tensor<32x4x1xi1, #linear> + %53 = tt.broadcast %52 : tensor<32x4x1xi1, #linear> -> tensor<32x4x32xi1, #linear> + %54 = tt.reshape %53 : tensor<32x4x32xi1, #linear> -> tensor<32x128xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %55 = arith.select %54, %cst, %50 : tensor<32x128xi1, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>, tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %56 = tt.fp_to_fp %42 : tensor<128x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> + %57 = tt.dot %55, %56, %cst_0 : tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<128x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<32x32xf32, #mma> + tt.return %57 : tensor<32x32xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 16, 2], threadsPerWarp = [16, 2, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked6 = #ttg.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + tt.func @join_reshape_dot(%112: tensor<128x32x!tt.ptr, #blocked2>, %117: tensor<128x32xi1, #blocked2>, %128: tensor<16x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>) -> tensor<16x128xf32, #mma> { + %cst = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #blocked> + // CHECK: %[[LOAD_I8:.*]] = tt.load {{.*}} tensor<128x32x!tt.ptr + // CHECK: ttg.convert_layout %[[LOAD_I8]] {{.*}} #linear + %118 = tt.load %112, %117 : tensor<128x32x!tt.ptr, #blocked2> + %121:2 = tt.elementwise_inline_asm "" {constraints = "=r,=r,=r,=r,r", packed_element = 4 : i32, pure = true} %118 : tensor<128x32xi8, #blocked2> -> tensor<128x32xbf16, #blocked2>, tensor<128x32xbf16, #blocked2> + %122 = tt.join %121#0, %121#1 : tensor<128x32xbf16, #blocked2> -> tensor<128x32x2xbf16, #blocked4> + %123 = tt.reshape %122 : tensor<128x32x2xbf16, #blocked4> -> tensor<128x64xbf16, #blocked5> + %124 = tt.trans %123 {order = array} : tensor<128x64xbf16, #blocked5> -> tensor<64x128xbf16, #blocked6> + %126 = ttg.convert_layout %124 : tensor<64x128xbf16, #blocked6> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %127 = ttg.convert_layout %cst : tensor<16x128xf32, #blocked> -> tensor<16x128xf32, #mma> + %129 = ttg.convert_layout %126 : tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %130 = tt.dot %128, %129, %127, inputPrecision = tf32 : tensor<16x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x128xf32, #mma> + tt.return %130 : tensor<16x128xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [2, 16, 1], warpsPerCTA = [1, 1, 1], order = [2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 1], order = [1, 0]}> +#linear = #ttg.linear<{register = [], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0]], warp = [], block = []}> +module attributes {"ttg.num-warps" = 1 : i32, ttg.target = "cuda:80"} { + // CHECK-LABEL: join_forward + tt.func @join_forward(%arg0: tensor<2x16xf32, #linear>) -> tensor<2x16x2xf32, #blocked> { + // CHECK-LABEL: tt.join + // CHECK-LABEL: ttg.convert_layout + %0 = ttg.convert_layout %arg0 : tensor<2x16xf32, #linear> -> tensor<2x16xf32, #blocked1> + %1 = tt.join %0, %0 : tensor<2x16xf32, #blocked1> -> tensor<2x16x2xf32, #blocked> + tt.return %1 : tensor<2x16x2xf32, #blocked> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/dot-operands.mlir b/third_party/enflame/include/triton/test/TritonGPU/dot-operands.mlir new file mode 100644 index 000000000..9a148f0d9 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/dot-operands.mlir @@ -0,0 +1,138 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -canonicalize | FileCheck %s + + +#blockedA = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @a_impl +// CHECK-NOT: %[[SELECT:.*]] = arith.select {{.*}} : tensor<128x128xi1, #ttg.dot_op<{{.*}}>, tensor<128x128xf16, #ttg.dot_op<{{.*}}> + tt.func @a_impl(%pa: tensor<128x128x!tt.ptr, #blocked>) -> tensor<128x128xf32, #mma> { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_3 = arith.constant dense<5> : tensor<128x1xi32, #blocked> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #blocked> + %tl = tt.load %pa : tensor<128x128x!tt.ptr, #blocked> + %tr = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %te = tt.expand_dims %tr {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %tc = arith.cmpi slt, %te, %cst_3 : tensor<128x1xi32, #blocked> + %tb = tt.broadcast %tc : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> + %ts = arith.select %tb, %tl, %cst_4 : tensor<128x128xi1, #blocked>, tensor<128x128xf16, #blocked> + %conv = ttg.convert_layout %ts : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %td = tt.dot %cst_0, %conv, %cst : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + tt.return %td : tensor<128x128xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: mma_reorder_transpose +// CHECK: ttg.local_alloc +// CHECK: ttg.memdesc_trans +// CHECK: ttng.warp_group_dot + tt.func @mma_reorder_transpose(%t: tensor<64x128xf16, #blocked1>, %dotb: !ttg.memdesc<64x64xf16, #shared, #smem>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %a = tt.trans %t {order = array} : tensor<64x128xf16, #blocked1> -> tensor<128x64xf16, #blocked> + %dota = ttg.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: mmav2_reorder_transpose +// CHECK: ttg.local_alloc +// CHECK: ttg.memdesc_trans +// CHECK: ttg.local_load +// CHECK: tt.dot + tt.func @mmav2_reorder_transpose(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %a = tt.trans %t {order = array} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked> + %cv = ttg.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: mmav2_transpose_indirect +// CHECK: tt.trans +// CHECK: ttg.convert_layout +// CHECK: arith.addf +// CHECK: tt.dot + tt.func @mmav2_transpose_indirect(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %a = tt.trans %t {order = array} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked> + %cv = ttg.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %add = arith.addf %cv, %cst : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %r = tt.dot %add, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +#blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked8 = #ttg.blocked<{sizePerThread = [1, 1, 1, 2, 4], threadsPerWarp = [1, 1, 16, 2, 1], warpsPerCTA = [2, 1, 2, 1, 1], order = [4, 3, 2, 1, 0]}> +#blocked9 = #ttg.blocked<{sizePerThread = [1, 2, 1, 1, 4], threadsPerWarp = [1, 2, 16, 1, 1], warpsPerCTA = [2, 1, 2, 1, 1], order = [4, 1, 2, 3, 0]}> +#blocked10 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}> +#blocked11 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#tmem = #ttng.tensor_memory_encoding +#tmem_scales = #ttng.tensor_memory_scales_encoding<> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @scales_in_shmem + // CHECK: %[[A_LA:.*]] = ttg.local_alloc + // CHECK: %[[B_LA:.*]] = ttg.local_alloc + // CHECK: ttng.tc_gen5_mma_scaled {{.*}}, %[[A_LA]], %[[B_LA]], + + tt.func public @scales_in_shmem( + %scale: tensor<2x512x!tt.ptr, #blocked4> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, + %A_sh: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, + %B_sh: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, + %acc_tm: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> + ) attributes {noinline = false} { + %true = arith.constant true + %A_la = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable> + %B_la = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable> + %A_ll = ttg.local_load %A_la : !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512> -> tensor<2x512xi8, #blocked4> + %B_ll = ttg.local_load %B_la : !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512> -> tensor<2x512xi8, #blocked4> + %A_r = tt.reshape %A_ll : tensor<2x512xi8, #blocked4> -> tensor<2x1x32x4x4xi8, #blocked8> + %B_r = tt.reshape %B_ll : tensor<2x512xi8, #blocked4> -> tensor<2x1x32x4x4xi8, #blocked8> + %A_tr = tt.trans %A_r {order = array} : tensor<2x1x32x4x4xi8, #blocked8> -> tensor<2x4x32x1x4xi8, #blocked9> + %B_tr = tt.trans %B_r {order = array} : tensor<2x1x32x4x4xi8, #blocked8> -> tensor<2x4x32x1x4xi8, #blocked9> + %A_cv = ttg.convert_layout %A_tr : tensor<2x4x32x1x4xi8, #blocked9> -> tensor<2x4x32x1x4xi8, #blocked10> + %B_cv = ttg.convert_layout %B_tr : tensor<2x4x32x1x4xi8, #blocked9> -> tensor<2x4x32x1x4xi8, #blocked10> + %A_r2 = tt.reshape %A_cv : tensor<2x4x32x1x4xi8, #blocked10> -> tensor<256x4xi8, #blocked11> + %B_r2 = tt.reshape %B_cv : tensor<2x4x32x1x4xi8, #blocked10> -> tensor<256x4xi8, #blocked11> + %A_tm = ttng.tmem_alloc %A_r2 : (tensor<256x4xi8, #blocked11>) -> !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory> + %B_tm = ttng.tmem_alloc %B_r2 : (tensor<256x4xi8, #blocked11>) -> !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory> + ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm, %A_tm, %B_tm, %true, %true lhs = e5m2 rhs = e5m2 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (!ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>, !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>, i1, i1) -> () + tt.return +} +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/fence-inserstion.mlir b/third_party/enflame/include/triton/test/TritonGPU/fence-inserstion.mlir new file mode 100644 index 000000000..739fa2853 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/fence-inserstion.mlir @@ -0,0 +1,49 @@ +// RUN: triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: matmul_like_fence + tt.func public @matmul_like_fence(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked2>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %0 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %1 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + // CHECK: ttng.fence_async_shared + %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared, #smem> * !ttg.memdesc<128x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: fence_outside_loop + tt.func public @fence_outside_loop(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %0 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %1 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + // CHECK: ttng.fence_async_shared + // CHECK: scf.for + // CHECK-NOT: ttng.fence_async_shared + // CHECK: ttng.warp_group_dot + scf.for %iv0 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { + scf.for %iv1 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { + %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared, #smem> * !ttg.memdesc<128x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> + } + } + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/fuse-nested-loops.mlir b/third_party/enflame/include/triton/test/TritonGPU/fuse-nested-loops.mlir new file mode 100644 index 000000000..b46af95a3 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/fuse-nested-loops.mlir @@ -0,0 +1,569 @@ +// RUN: triton-opt %s --allow-unregistered-dialect --tritongpu-fuse-nested-loops -cse | FileCheck %s + +// CHECK-LABEL: @empty_function +tt.func @empty_function() { + tt.return +} + +// CHECK-LABEL: @no_fusion +tt.func @no_fusion(%lb: index, %ub: index, %step: index) -> index { + %c0 = arith.constant 0 : index + // CHECK: before.loop + "before.loop"() : () -> () + // CHECK-NEXT: scf.for + %0 = scf.for %i = %lb to %ub step %step iter_args(%k = %c0) -> index { + // CHECK-NEXT: body + %1 = "body"(%i, %k) : (index, index) -> index + // CHECK-NEXT: yield + scf.yield %1 : index + // CHECK-NEXT: } + } {"ttg.always-fuse"} + // CHECK-NEXT: after.loop + "after.loop"() : () -> () + tt.return %0 : index +} + +// CHECK-LABEL: @fuse_one_level_simple +// CHECK-SAME: [[LBI:%.*]]: i64, [[UBI:%.*]]: i64, [[STEPI:%.*]]: i64, [[LBJ:%.*]]: i64, [[UBJ:%.*]]: i64, [[STEPJ:%.*]]: i64 +tt.func @fuse_one_level_simple(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ubj: i64, %stepj: i64) { + // len_i = len(range(lbi, ubi, stepi)) + // + // CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]] + // CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]] + + // len_j = len(range(lbj0, ubj0, stepj0)) + // + // CHECK-NEXT: [[DIFF_J:%.*]] = arith.subi [[UBJ]], [[LBJ]] + // CHECK-NEXT: [[LEN_J:%.*]] = arith.ceildivsi [[DIFF_J]], [[STEPJ]] + + // inner_len = max(1, len_j0) + // + // CHECK-NEXT: [[PLEN0:%.*]] = arith.constant 0 : i64 + // CHECK: [[LEN_J_CLAMP:%.*]] = arith.maxsi %c1_i64, [[LEN_J]] + // CHECK-NEXT: [[PLEN1:%.*]] = arith.addi [[PLEN0]], [[LEN_J_CLAMP]] + // CHECK-NEXT: [[INNER_LEN:%.*]] = arith.subi [[PLEN1]], %c0_i64 + + // total_iters = len_i * max(1, inner_len) + // + // CHECK: [[TOTAL_ITERS:%.*]] = arith.muli [[LEN_I]], [[INNER_LEN]] + + // T = -1 + // i = lbi - stepi + // j = None + // for _ in range(total_iters): + // + // CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]] + // CHECK: scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS]] step %c1_i64 iter_args( + // CHECK-SAME: [[T_ARG:%.*]] = %c-1_i64, [[I_ARG:%.*]] = [[I_INIT]], [[J_ARG:%.*]] = %c0_i64) -> (i64, i64, i64) : i64 { + scf.for %i = %lbi to %ubi step %stepi : i64 { + // T = 0 if T == (inner_len - 1) else T + 1 + // + // CHECK: [[T_PLUS_1:%.*]] = arith.addi [[T_ARG]], %c1_i64 + // CHECK-NEXT: [[T_END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64 + // CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[T_ARG]], [[T_END]] + // CHECK-NEXT: [[T:%.*]] = arith.select [[ROLLOVER]], %c0_i64, [[T_PLUS_1]] + + // if T == 0: + // i += stepi + // prologue(i) + // j = lbj + // + // CHECK: [[START:%.*]] = arith.subi %c0_i64, %c0_i64 : i64 + // CHECK-NEXT: [[PROLOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[START]] + // CHECK-NEXT: [[JI:%.*]]:2 = scf.if [[PROLOGUE_COND]] -> (i64, i64) { + // CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]] + // CHECK-NEXT: "prologue"([[I]]) : (i64) -> () + // CHECK-NEXT: yield [[LBJ]], [[I]] + // CHECK-NEXT: } else { + // CHECK-NEXT: yield [[J_ARG]], [[I_ARG]] + // CHECK-NEXT: } + "prologue"(%i) : (i64) -> () + + // if T >= 0 and T < len_j: + // body(i, j) + // j += stepj + // + // CHECK: [[END:%.*]] = arith.addi [[START]], [[LEN_J]] + // CHECK-NEXT: [[GE:%.*]] = arith.cmpi sge, [[T]], [[START]] + // CHECK-NEXT: [[LT:%.*]] = arith.cmpi slt, [[T]], [[END]] + // CHECK-NEXT: [[COND:%.*]] = arith.andi [[GE]], [[LT]] + // CHECK-NEXT: [[J_NEXT:%.*]] = scf.if [[COND]] -> (i64) { + // CHECK-NEXT: "body"([[JI]]#1, [[JI]]#0) : (i64, i64) -> () + // CHECK-NEXT: [[J_INCR:%.*]] = arith.addi [[JI]]#0, [[STEPJ]] + // CHECK-NEXT: yield [[J_INCR]] + // CHECK-NEXT: } else { + // CHECK-NEXT: yield [[JI]]#0 + // CHECK-NEXT: } + scf.for %j = %lbj to %ubj step %stepj : i64 { + "body"(%i, %j) : (i64, i64) -> () + } + + // if T == max(1, len_j) - 1: + // epilogue(i) + // i += stepi + // + // CHECK-NEXT: [[EPILOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[T_END]] + // CHECK-NEXT: scf.if [[EPILOGUE_COND]] { + // CHECK-NEXT: "epilogue"([[JI]]#1) : (i64) -> () + // CHECK-NEXT: } else { + // CHECK-NEXT: } + "epilogue"(%i) : (i64) -> () + + // CHECK-NEXT: yield [[T]], [[JI]]#1, [[J_NEXT]] : i64, i64, i64 + } {"ttg.always-fuse"} + tt.return +} + +// CHECK-LABEL: @fuse_one_level_inouts +// CHECK-SAME: [[LBI:%.*]]: i64, [[UBI:%.*]]: i64, [[STEPI:%.*]]: i64, [[LBJ:%.*]]: i64, [[UBJ:%.*]]: i64, [[STEPJ:%.*]]: i64 +// CHECK-SAME: [[INOUT:%.*]]: index +tt.func @fuse_one_level_inouts(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ubj: i64, %stepj: i64, %inout: index) -> index { + // CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]] + // CHECK: [[OUTER_OUTS:%.*]]:6 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS:%.*]] step %c1_i64 iter_args( + // CHECK-SAME: [[T_ARG:%arg[0-9]+]] = %c-1_i64, + // CHECK-SAME: [[I_ARG:%arg[0-9]+]] = [[I_INIT]] + // CHECK-SAME: [[M:%arg[0-9]+]] = [[INOUT]] + // CHECK-SAME: [[J_ARG:%arg[0-9]+]] = %c0_i64 + // CHECK-SAME: [[K_ARG:%arg[0-9]+]] = %c0 + // CHECK-SAME: [[PROLOGUE_OUT_ARG:%arg[0-9]+]] = %c0 + // CHECK-SAME: ) -> (i64, i64, index, i64, index, index) : i64 { + %outer_out = scf.for %i = %lbi to %ubi step %stepi iter_args(%m = %inout) -> index : i64 { + // if T == 0: + // i += stepi + // prologue(i) + // j = lbj + // + // CHECK: [[PROLOGUE_OUTS:%.*]]:4 = scf.if %{{[0-9]+}} -> (i64, index, index, i64) { + // CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]] + // CHECK-NEXT: [[PROLOGUE_RES:%.*]] = "prologue"([[I]], [[INOUT]], [[M]]) : (i64, index, index) -> index + // CHECK-NEXT: yield [[LBJ]], [[PROLOGUE_RES]], [[M]], [[I]] + // CHECK-NEXT: } else { + // CHECK-NEXT: yield [[J_ARG]], [[PROLOGUE_OUT_ARG]], [[K_ARG]], [[I_ARG]] + // CHECK-NEXT: } + // + // J := [[PROLOGUE_OUTS]]#0 + // PROLOGUE_OUT := [[PROLOGUE_OUTS]]#1 + // K := [[PROLOGUE_OUTS]]#2 + // I := [[PROLOGUE_OUTS]]#3 + %prologue_out = "prologue"(%i, %inout, %m) : (i64, index, index) -> index + + // if T >= 0 and T < len_j: + // body(i, j) + // j += stepj + // + // CHECK: [[BODY_OUTS:%.*]]:2 = scf.if {{.*}} -> (i64, index) { + // CHECK-NEXT: [[BODY_OUT:%.*]] = "body"([[PROLOGUE_OUTS]]#3, [[PROLOGUE_OUTS]]#0, [[PROLOGUE_OUTS]]#2, [[PROLOGUE_OUTS]]#1, [[M]]) : (i64, i64, index, index, index) -> index + // CHECK-NEXT: [[J_INCR:%.*]] = arith.addi [[PROLOGUE_OUTS]]#0, [[STEPJ]] + // CHECK-NEXT: yield [[J_INCR]], [[BODY_OUT]] + // CHECK-NEXT: } else { + // CHECK-NEXT: yield [[PROLOGUE_OUTS]]#0, [[K_ARG]] + // CHECK-NEXT: } + %inner_out = scf.for %j = %lbj to %ubj step %stepj iter_args(%k = %m) -> index : i64 { + %body_out = "body"(%i, %j, %k, %prologue_out, %m) : (i64, i64, index, index, index) -> index + scf.yield %body_out : index + } + + // if T == max(1, len_j) - 1: + // epilogue(i) + // i += stepi + // + // CHECK: [[EPILOGUE_OUTS:%.*]] = scf.if {{.*}} -> (index) { + // CHECK-NEXT: [[EPILOGUE_OUT:%.*]] = "epilogue"([[PROLOGUE_OUTS]]#3, [[PROLOGUE_OUTS]]#1, [[BODY_OUTS]]#1, [[M]]) : (i64, index, index, index) -> index + // CHECK-NEXT: yield [[EPILOGUE_OUT]] + // CHECK-NEXT: } else { + // CHECK-NEXT: yield [[M]] + // CHECK-NEXT: } + %epilogue_out = "epilogue"(%i, %prologue_out, %inner_out, %m) : (i64, index, index, index) -> index + + // CHECK-NEXT: yield %{{.*}}, [[PROLOGUE_OUTS]]#3, [[EPILOGUE_OUTS]], [[BODY_OUTS]]#0, [[BODY_OUTS]]#1, [[PROLOGUE_OUTS]]#1 : i64, i64, index, i64, index, index + scf.yield %epilogue_out : index + } {"ttg.always-fuse"} + // CHECK: return [[OUTER_OUTS]]#2 + tt.return %outer_out : index +} + +// CHECK-LABEL: @multiple_loops +tt.func @multiple_loops( + // CHECK-SAME: [[LBI:%arg[0-9]+]]: i64, [[UBI:%arg[0-9]+]]: i64, [[STEPI:%arg[0-9]+]]: i64, + // CHECK-SAME: [[LBJ0:%arg[0-9]+]]: i64, [[UBJ0:%arg[0-9]+]]: i64, [[STEPJ0:%arg[0-9]+]]: i64, + // CHECK-SAME: [[LBJ1:%arg[0-9]+]]: i64, [[UBJ1:%arg[0-9]+]]: i64, [[STEPJ1:%arg[0-9]+]]: i64, + // CHECK-SAME: [[LBJ2:%arg[0-9]+]]: i64, [[UBJ2:%arg[0-9]+]]: i64, [[STEPJ2:%arg[0-9]+]]: i64, + // CHECK-SAME: [[M0:%arg[0-9]+]]: f32 + %lbi: i64, %ubi: i64, %stepi: i64, + %lbj0: i64, %ubj0: i64, %stepj0: i64, + %lbj1: i64, %ubj1: i64, %stepj1: i64, + %lbj2: i64, %ubj2: i64, %stepj2: i64, + %m0: f32) -> f32 { + // CHECK: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]] + // CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]] + // CHECK-NEXT: [[DIFF_J0:%.*]] = arith.subi [[UBJ0]], [[LBJ0]] + // CHECK-NEXT: [[LEN_J0:%.*]] = arith.ceildivsi [[DIFF_J0]], [[STEPJ0]] + // CHECK-NEXT: [[DIFF_J1:%.*]] = arith.subi [[UBJ1]], [[LBJ1]] + // CHECK-NEXT: [[LEN_J1:%.*]] = arith.ceildivsi [[DIFF_J1]], [[STEPJ1]] + // CHECK-NEXT: [[DIFF_J2:%.*]] = arith.subi [[UBJ2]], [[LBJ2]] + // CHECK-NEXT: [[LEN_J2:%.*]] = arith.ceildivsi [[DIFF_J2]], [[STEPJ2]] + + // CHECK: [[PLEN0:%.*]] = arith.constant 0 : i64 + // CHECK: [[LEN_J0_CLAMP:%.*]] = arith.maxsi %c1_i64, [[LEN_J0]] + // CHECK-NEXT: [[PLEN1:%.*]] = arith.addi [[PLEN0]], [[LEN_J0_CLAMP]] + // CHECK-NEXT: [[LEN_J1_CLAMP:%.*]] = arith.maxsi %c1_i64, [[LEN_J1]] + // CHECK-NEXT: [[PLEN2:%.*]] = arith.addi [[PLEN1]], [[LEN_J1_CLAMP]] + // CHECK-NEXT: [[LEN_J2_CLAMP:%.*]] = arith.maxsi %c1_i64, [[LEN_J2]] + // CHECK-NEXT: [[PLEN3:%.*]] = arith.addi [[PLEN2]], [[LEN_J2_CLAMP]] + // CHECK: [[INNER_LEN:%.*]] = arith.subi [[PLEN3]], %c2_i64 + // CHECK-NEXT: [[TOTAL_ITERS:%.*]] = arith.muli [[LEN_I]], [[INNER_LEN]] + + // CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]] + // CHECK: [[OUTS:%.*]]:12 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS]] step %c1_i64 iter_args( + // CHECK-SAME: [[T_ARG:%arg[0-9]+]] = %c-1_i64, + // CHECK-SAME: [[I_ARG:%arg[0-9]+]] = [[I_INIT]], + // CHECK-SAME: [[M:%arg[0-9]+]] = [[M0]], + // CHECK-SAME: [[J0_ARG:%arg[0-9]+]] = %c0_i64, + // CHECK-SAME: [[J1_ARG:%arg[0-9]+]] = %c0_i64, + // CHECK-SAME: [[J2_ARG:%arg[0-9]+]] = %c0_i64, + // CHECK-SAME: [[BODY0_ARG:%arg[0-9]+]] = %cst, + // CHECK-SAME: [[BODY1_ARG:%arg[0-9]+]] = %cst, + // CHECK-SAME: [[BODY2_ARG:%arg[0-9]+]] = %cst, + // CHECK-SAME: [[PROLOGUE0_ARG:%arg[0-9]+]] = %cst, + // CHECK-SAME: [[PROLOGUE1_ARG:%arg[0-9]+]] = %cst, + // CHECK-SAME: [[PROLOGUE2_ARG:%arg[0-9]+]] = %cst) + %mN = scf.for %i = %lbi to %ubi step %stepi iter_args(%m = %m0) -> f32 : i64 { + + // CHECK: [[T_PLUS_1:%.*]] = arith.addi [[T_ARG]], %c1_i64 + // CHECK-NEXT: [[T_END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64 + // CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[T_ARG]], [[T_END]] + // CHECK-NEXT: [[T:%.*]] = arith.select [[ROLLOVER]], %c0_i64, [[T_PLUS_1]] + + // CHECK: [[START0:%.*]] = arith.subi [[PLEN0]], %c0_i64 + // CHECK-NEXT: [[PROLOGUE_COND0:%.*]] = arith.cmpi eq, [[T]], [[START0]] + // CHECK-NEXT: [[PROLOGUE0_OUTS:%.*]]:4 = scf.if [[PROLOGUE_COND0]] + // CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]] + // CHECK-NEXT: [[RES:%.*]] = "prologue0"([[I]], [[M]]) + // CHECK-NEXT: yield [[LBJ0]], [[RES]], [[RES]], [[I]] + // CHECK-NEXT: else + // CHECK-NEXT: yield [[J0_ARG]], [[PROLOGUE0_ARG]], [[BODY0_ARG]], [[I_ARG]] + %k00 = "prologue0"(%i, %m) : (i64, f32) -> f32 + + // CHECK: [[END0:%.*]] = arith.addi [[START0]], [[LEN_J0]] + // CHECK-NEXT: [[GE0:%.*]] = arith.cmpi sge, [[T]], [[START0]] + // CHECK-NEXT: [[LT0:%.*]] = arith.cmpi slt, [[T]], [[END0]] + // CHECK-NEXT: [[BODY_COND0:%.*]] = arith.andi [[GE0]], [[LT0]] + // CHECK-NEXT: [[BODY0_OUTS:%.*]]:2 = scf.if [[BODY_COND0]] + // CHECK-NEXT: [[RES:%.*]] = "body0"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE0_OUTS]]#0, [[PROLOGUE0_OUTS]]#2) + // CHECK-NEXT: [[NEXT_J0:%.*]] = arith.addi [[PROLOGUE0_OUTS]]#0, [[STEPJ0]] + // CHECK-NEXT: yield [[NEXT_J0]], [[RES]] + // CHECK-NEXT: else + // CHECK-NEXT: yield [[PROLOGUE0_OUTS]]#0, [[BODY0_ARG]] + %k0N = scf.for %j0 = %lbj0 to %ubj0 step %stepj0 iter_args(%k0 = %k00) -> f32 : i64 { + %res = "body0"(%i, %j0, %k0) : (i64, i64, f32) -> f32 + scf.yield %res : f32 + } + + // CHECK: [[START1:%.*]] = arith.subi [[PLEN1]], %c1_i64 + // CHECK-NEXT: [[PROLOGUE_COND1:%.*]] = arith.cmpi eq, [[T]], [[START1]] + // CHECK-NEXT: [[PROLOGUE1_OUTS:%.*]]:3 = scf.if [[PROLOGUE_COND1]] + // CHECK-NEXT: [[RES:%.*]] = "prologue1"([[PROLOGUE0_OUTS]]#3, [[BODY0_OUTS]]#1) + // CHECK-NEXT: yield [[LBJ1]], [[RES]], [[RES]] + // CHECK-NEXT: else + // CHECK-NEXT: yield [[J1_ARG]], [[PROLOGUE1_ARG]], [[BODY1_ARG]] + %k10 = "prologue1"(%i, %k0N) : (i64, f32) -> f32 + + // CHECK: [[END1:%.*]] = arith.addi [[START1]], [[LEN_J1]] + // CHECK-NEXT: [[GE1:%.*]] = arith.cmpi sge, [[T]], [[START1]] + // CHECK-NEXT: [[LT1:%.*]] = arith.cmpi slt, [[T]], [[END1]] + // CHECK-NEXT: [[BODY_COND1:%.*]] = arith.andi [[GE1]], [[LT1]] + // CHECK-NEXT: [[BODY1_OUTS:%.*]]:2 = scf.if [[BODY_COND1]] + // CHECK-NEXT: [[RES:%.*]] = "body1"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE1_OUTS]]#0, [[PROLOGUE1_OUTS]]#2) + // CHECK-NEXT: [[NEXT_J1:%.*]] = arith.addi [[PROLOGUE1_OUTS]]#0, [[STEPJ1]] + // CHECK-NEXT: yield [[NEXT_J1]], [[RES]] + // CHECK-NEXT: else + // CHECK-NEXT: yield [[PROLOGUE1_OUTS]]#0, [[BODY1_ARG]] + %k1N = scf.for %j1 = %lbj1 to %ubj1 step %stepj1 iter_args(%k1 = %k10) -> f32 : i64 { + %res = "body1"(%i, %j1, %k1) : (i64, i64, f32) -> f32 + scf.yield %res : f32 + } + + // CHECK: [[START2:%.*]] = arith.subi [[PLEN2]], %c2_i64 + // CHECK-NEXT: [[PROLOGUE_COND2:%.*]] = arith.cmpi eq, [[T]], [[START2]] + // CHECK-NEXT: [[PROLOGUE2_OUTS:%.*]]:3 = scf.if [[PROLOGUE_COND2]] + // CHECK-NEXT: [[RES:%.*]] = "prologue2"([[PROLOGUE0_OUTS]]#3, [[BODY1_OUTS]]#1) + // CHECK-NEXT: yield [[LBJ2]], [[RES]], [[RES]] + // CHECK-NEXT: else + // CHECK-NEXT: yield [[J2_ARG]], [[PROLOGUE2_ARG]], [[BODY2_ARG]] + %k20 = "prologue2"(%i, %k1N) : (i64, f32) -> f32 + + // CHECK: [[END2:%.*]] = arith.addi [[START2]], [[LEN_J2]] + // CHECK-NEXT: [[GE2:%.*]] = arith.cmpi sge, [[T]], [[START2]] + // CHECK-NEXT: [[LT2:%.*]] = arith.cmpi slt, [[T]], [[END2]] + // CHECK-NEXT: [[BODY_COND2:%.*]] = arith.andi [[GE2]], [[LT2]] + // CHECK-NEXT: [[BODY2_OUTS:%.*]]:2 = scf.if [[BODY_COND2]] + // CHECK-NEXT: [[RES:%.*]] = "body2"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE2_OUTS]]#0, [[PROLOGUE2_OUTS]]#2) + // CHECK-NEXT: [[NEXT_J2:%.*]] = arith.addi [[PROLOGUE2_OUTS]]#0, [[STEPJ2]] + // CHECK-NEXT: yield [[NEXT_J2]], [[RES]] + // CHECK-NEXT: else + // CHECK-NEXT: yield [[PROLOGUE2_OUTS]]#0, [[BODY2_ARG]] + %k2N = scf.for %j2 = %lbj2 to %ubj2 step %stepj2 iter_args(%k2 = %k20) -> f32 : i64 { + %res = "body2"(%i, %j2, %k2) : (i64, i64, f32) -> f32 + scf.yield %res : f32 + } + + // CHECK: [[EPILOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[T_END]] + // CHECK-NEXT: [[EPILOGUE_OUTS:%.*]] = scf.if [[EPILOGUE_COND]] + // CHECK-NEXT: [[RES:%.*]] = "epilogue"([[PROLOGUE0_OUTS]]#3, [[BODY2_OUTS]]#1) + // CHECK-NEXT: yield [[RES]] + // CHECK-NEXT: else + // CHECK-NEXT: yield [[M]] + %out = "epilogue"(%i, %k2N) : (i64, f32) -> f32 + + // CHECK: scf.yield [[T]], [[PROLOGUE0_OUTS]]#3, [[EPILOGUE_OUTS]], + // CHECK-SAME: [[BODY0_OUTS]]#0, [[BODY1_OUTS]]#0, [[BODY2_OUTS]]#0, + // CHECK-SAME: [[PROLOGUE0_OUTS]]#1, [[PROLOGUE1_OUTS]]#1, [[PROLOGUE2_OUTS]]#1 : + scf.yield %out : f32 + } {"ttg.always-fuse"} + // CHECK: return [[OUTS]]#2 + tt.return %mN : f32 +} + +// CHECK-LABEL: @two_loop_nests +tt.func @two_loop_nests(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ubj: i64, %stepj: i64) { + // CHECK-COUNT-2: scf.for + scf.for %i = %lbi to %ubi step %stepi : i64 { + scf.for %j = %lbj to %ubj step %stepj : i64 { + "body"(%i, %j) : (i64, i64) -> () + } + } {"ttg.always-fuse"} + scf.for %i = %lbi to %ubi step %stepi : i64 { + scf.for %j = %lbj to %ubj step %stepj : i64 { + "body"(%i, %j) : (i64, i64) -> () + } + } {"ttg.always-fuse"} + // CHECK-NOT: scf.for + // CHECK: tt.return + tt.return +} + +// CHECK-LABEL: @hoist_loop_bound_computations +// CHECK-SAME: [[LBI:%.*]]: i64, [[UBI:%.*]]: i64, [[STEPI:%.*]]: i64 +tt.func @hoist_loop_bound_computations(%lbi: i64, %ubi: i64, %stepi: i64) { + // CHECK-NEXT: [[LBJ:%.*]] = arith.addi [[LBI]], [[STEPI]] + // CHECK-NEXT: [[UBJ:%.*]] = arith.addi [[UBI]], [[STEPI]] + // CHECK-NEXT: [[STEPJ:%.*]] = arith.addi [[STEPI]], [[STEPI]] + + // CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]] + // CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]] + // CHECK-NEXT: [[DIFF_J:%.*]] = arith.subi [[UBJ]], [[LBJ]] + // CHECK-NEXT: [[LEN_J:%.*]] = arith.ceildivsi [[DIFF_J]], [[STEPJ]] + + // CHECK: scf.for + scf.for %i = %lbi to %ubi step %stepi : i64 { + %lbj = arith.addi %lbi, %stepi : i64 + %ubj = arith.addi %ubi, %stepi : i64 + %stepj = arith.addi %stepi, %stepi : i64 + // CHECK: [[J:%.*]]:2 = scf.if + // CHECK: yield [[LBJ]] + + // CHECK: scf.if + // CHECK-NEXT: "body" + // CHECK-NEXT: arith.addi [[J]]#0, [[STEPJ]] + scf.for %j = %lbj to %ubj step %stepj : i64 { + "body"(%i, %j) : (i64, i64) -> () + } + } {"ttg.always-fuse"} + tt.return +} + +// CHECK-LABEL: @cannot_fuse +tt.func @cannot_fuse(%lbi: i64, %ubi: i64, %stepi: i64) { + // CHECK-COUNT-2: scf.for + scf.for %i = %lbi to %ubi step %stepi : i64 { + %lbj = arith.addi %lbi, %stepi : i64 + %ubj = arith.addi %ubi, %i : i64 + %stepj = arith.addi %stepi, %stepi : i64 + scf.for %j = %lbj to %ubj step %stepj : i64 { + "body"(%i, %j) : (i64, i64) -> () + } + } {"ttg.always-fuse"} + // CHECK-NOT: scf.for + tt.return +} + +// CHECK-LABEL: @upcast_i16_to_i32 +// CHECK-SAME: [[LBI:%.*]]: i32, [[UBI:%.*]]: i32, [[STEPI:%.*]]: i32, [[LBJ:%.*]]: i16, [[UBJ:%.*]]: i16, [[STEPJ:%.*]]: i16 +tt.func @upcast_i16_to_i32(%lbi: i32, %ubi: i32, %stepi: i32, %lbj: i16, %ubj: i16, %stepj: i16) { + // CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]] : i32 + // CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]] : i32 + // CHECK-NEXT: [[DIFF_J:%.*]] = arith.subi [[UBJ]], [[LBJ]] : i16 + // CHECK-NEXT: [[LEN_J:%.*]] = arith.ceildivsi [[DIFF_J]], [[STEPJ]] : i16 + + // CHECK: arith.extsi [[LEN_J]] : i16 to i32 + scf.for %i = %lbi to %ubi step %stepi : i32 { + scf.for %j = %lbj to %ubj step %stepj : i16 { + "body"(%i, %j) : (i32, i16) -> () + } + } {"ttg.always-fuse"} + tt.return +} + +// CHECK-LABEL: @upcast_index_to_i64 +// CHECK-SAME: [[LBI:%.*]]: index, [[UBI:%.*]]: index, [[STEPI:%.*]]: index, [[LBJ:%.*]]: index, [[UBJ:%.*]]: index, [[STEPJ:%.*]]: index +tt.func @upcast_index_to_i64(%lbi: index, %ubi: index, %stepi: index, %lbj: index, %ubj: index, %stepj: index) { + // CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]] : index + // CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]] : index + // CHECK-NEXT: [[DIFF_J:%.*]] = arith.subi [[UBJ]], [[LBJ]] : index + // CHECK-NEXT: [[LEN_J:%.*]] = arith.ceildivsi [[DIFF_J]], [[STEPJ]] : index + + // CHECK: arith.index_cast [[LEN_J]] : index to i64 + // CHECK: arith.index_cast [[LEN_I]] : index to i64 + scf.for %i = %lbi to %ubi step %stepi { + scf.for %j = %lbj to %ubj step %stepj { + "body"(%i, %j) : (index, index) -> () + } + } {"ttg.always-fuse"} + tt.return +} + +// CHECK-LABEL: @triple_loop_nest +tt.func @triple_loop_nest( + %lbi: i64, %ubi: i64, %stepi: i64, + %lbj: i64, %ubj: i64, %stepj: i64, + %lbk: i64, %ubk: i64, %stepk: i64) { + // CHECK-COUNT-1: scf.for + scf.for %i = %lbi to %ubi step %stepi : i64 { + scf.for %j = %lbj to %ubj step %stepj : i64 { + scf.for %k = %lbk to %ubk step %stepk : i64 { + "body"(%i, %j, %k) : (i64, i64, i64) -> () + } + } + } {"ttg.always-fuse"} + // CHECK-NOT: scf.for + // CHECK: tt.return + tt.return +} + +// CHECK-LABEL: @preserve_stage_count +tt.func @preserve_stage_count(%lb: i32, %ub: i32) { + %c1_i32 = arith.constant 1 : i32 + + // CHECK-COUNT-1: scf.for + scf.for %i = %lb to %ub step %c1_i32 : i32 { + scf.for %j = %lb to %ub step %c1_i32 : i32 { + "body"(%j) : (i32) -> () + scf.yield + } {tt.num_stages = 4 : i32} + scf.for %j = %lb to %ub step %c1_i32 : i32 { + "body"(%j) : (i32) -> () + scf.yield + } {tt.num_stages = 6 : i32} + } {"ttg.always-fuse"} + // CHECK: tt.num_stages = 6 : i32 + // CHECK-NOT: scf.for + tt.return +} + +// CHECK-LABEL: @fuse_attr_speculate +// CHECK-SAME: [[LB:%.*]]: i32, [[UB:%.*]]: i32 +tt.func @fuse_attr_speculate(%lb: i32, %ub: i32) { + %c1_i32 = arith.constant 1 : i32 + + // CHECK: [[DIFF:%.*]] = arith.subi [[UB]], [[LB]] + // CHECK: [[LEN:%.*]] = arith.ceildivsi [[DIFF]], %c1_i32 + // CHECK: [[IS_ZERO:%.*]] = arith.cmpi eq, [[LEN]], %c0_i32 + + // CHECK: scf.if [[IS_ZERO]] + // CHECK-NEXT: scf.for %{{.*}} = [[LB]] to [[UB]] step %c1_i32 + // CHECK-NEXT: "prologue" + // CHECK-NXET: } + + // CHECK: else + // CHECK-COUNT-1: scf.for + // CHECK-NOT: scf.for + scf.for %i = %lb to %ub step %c1_i32 : i32 { + // CHECK: "prologue" + "prologue"(%i) : (i32) -> () + // CHECK: scf.if %true + scf.for %j = %lb to %ub step %c1_i32 : i32 { + // CHECK-NEXT: "body" + "body"(%i, %j) : (i32, i32) -> () + scf.yield + } + } {tt.flatten} + tt.return +} + +// CHECK-LABEL: @speculate_hoist +// CHECK-SAME: [[LB:%.*]]: i32, [[UB:%.*]]: i32 +tt.func @speculate_hoist(%lb: i32, %ub: i32) { + %c1_i32 = arith.constant 1 : i32 + + // CHECK: [[UBJ:%.*]] = arith.addi [[LB]], [[UB]] + // CHECK: [[DIFF:%.*]] = arith.subi [[UBJ]], [[LB]] + // CHECK: [[LEN:%.*]] = arith.ceildivsi [[DIFF]], %c1_i32 + // CHECK: [[IS_ZERO:%.*]] = arith.cmpi eq, [[LEN]], %c0_i32 + + // CHECK: scf.if [[IS_ZERO]] + scf.for %i = %lb to %ub step %c1_i32 : i32 { + "prologue"(%i) : (i32) -> () + %ubj = arith.addi %lb, %ub : i32 + scf.for %j = %lb to %ubj step %c1_i32 : i32 { + "body"(%i, %j) : (i32, i32) -> () + scf.yield + } + } {tt.flatten} + tt.return +} + +// CHECK-LABEL: @sink_prologue_to_epilogue +// CHECK-SAME: [[UB:%.*]]: i32 +tt.func @sink_prologue_to_epilogue(%ub: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + // CHECK: else + // CHECK: scf.for + %0 = scf.for %i = %c0_i32 to %ub step %c1_i32 iter_args(%k = %c0_i32) -> i32 : i32 { + // CHECK: [[PROLOGUE_OUTS:%.*]]:2 = scf.if + %0 = arith.addi %i, %ub : i32 + // CHECK: scf.if %true + // CHECK-NEXT: "body" + scf.for %j = %c0_i32 to %ub step %c1_i32 : i32 { + "body"(%i, %j) : (i32, i32) -> () + scf.yield + } + // CHECK: scf.if + // CHECK-NEXT: [[V0:%.*]] = arith.addi [[PROLOGUE_OUTS]]#1, [[UB]] + // CHECK-NEXT: [[V1:%.*]] = arith.addi [[V0]], [[UB]] + %1 = arith.addi %0, %ub : i32 + // CHECK-NEXT: "epilogue"([[V1]]) + "epilogue"(%1) : (i32) -> () + scf.yield %0 : i32 + } {tt.flatten} + + tt.return +} + +// ----- + +// CHECK-LABEL: @prologue_output +tt.func @prologue_output(%ub: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + // CHECK: scf.for + %0 = scf.for %i = %c0_i32 to %ub step %c1_i32 iter_args(%k = %c0_i32) -> i32 : i32 { + // CHECK: scf.if + // CHECK: {increment} + %next = arith.addi %k, %c1_i32 {increment} : i32 + // CHECK: scf.if + scf.for %j = %c0_i32 to %ub step %c1_i32 : i32 { + // CHECK-NEXT: "body" + "body"(%i, %j) : (i32, i32) -> () + } + // CHECK: scf.if {{%[0-9]+}} { + // CHECK-NEXT: "epilogue" + "epilogue"(%i) : (i32) -> () + // CHECK-NEXT: } else { + scf.yield %next : i32 + } {"ttg.always-fuse"} + + tt.return +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/global_scratch_alloc.mlir b/third_party/enflame/include/triton/test/TritonGPU/global_scratch_alloc.mlir new file mode 100644 index 000000000..1c4d5bb2e --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/global_scratch_alloc.mlir @@ -0,0 +1,34 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-global-scratch-memory-allocation | FileCheck %s + +// CHECK: module attributes {ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32{{.*}}} +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +// CHECK: @test_alloc{{.*}}ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32 + tt.func public @test_alloc() -> (!tt.ptr, !tt.ptr) { + // CHECK: ttg.global_scratch_memory_offset = 0 + %0 = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : !tt.ptr + // CHECK: ttg.global_scratch_memory_offset = 128 + %1 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr + tt.return %0, %1 : !tt.ptr, !tt.ptr + } +} + +// ----- + +// CHECK: module attributes {ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32{{.*}}} +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +// CHECK: @helper1{{.*}}ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 128 : i32 + tt.func private @helper1() -> (!tt.ptr) { + // CHECK: ttg.global_scratch_memory_offset = 0 + %0 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr + tt.return %0 : !tt.ptr + } + +// CHECK: @test_function{{.*}}ttg.global_scratch_memory_alignment = 128 : i32, ttg.global_scratch_memory_size = 256 : i32 + tt.func public @test_function() -> (!tt.ptr, !tt.ptr) { + // CHECK: ttg.global_scratch_memory_offset = 0 + %0 = ttg.global_scratch_alloc {alignment = 8 : i32, nbytes = 100 : i32} : !tt.ptr + // CHECK: ttg.global_scratch_memory_offset = 128 + %1 = tt.call @helper1() : () -> (!tt.ptr) + tt.return %0, %1 : !tt.ptr, !tt.ptr + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/invalid-attributes.mlir b/third_party/enflame/include/triton/test/TritonGPU/invalid-attributes.mlir new file mode 100644 index 000000000..df693a6ea --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/invalid-attributes.mlir @@ -0,0 +1,78 @@ +// RUN: triton-opt %s -split-input-file -verify-diagnostics + +// expected-error@+2 {{ttg.dot_op opIdx parameter can be 0 or 1, got: 2}} +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#dot_op = #ttg.dot_op<{opIdx = 2, parent = #blocked, kWidth = 2}> + +// ----- + +// expected-error@+2 {{ttg.dot_op kWidth parameter is not supported when the parent is a blocked layout}} +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #blocked, kWidth = 8}> + +// ----- + +// expected-error@+2 {{ttg.dot_op kWidth parameter can only be non-zero for Ampere or Hopper MMA parent}} +#mma = #ttg.nvidia_mma<{versionMajor = 1, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> + +// ----- + +// expected-error@+2 {{ttg.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 0, parent = #mma}> + +// ----- + +// expected-error@+2 {{ttg.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} +#mma = #ttg.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 0, parent = #mma}> + +// ----- + +// expected-error@+2 {{ttg.dot_op opIdx parameter must be 0 for Hopper MMA parent, since Hopper WGMMA only allows first operand to be in registers}} +#mma = #ttg.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> + +// ----- + +// expected-error@+2 {{ttg.dot_op kWidth parameter is mandatory for MFMA parent}} +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #mfma}> + +// ----- + +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma}> + +// ----- + +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 8}> + +// ----- +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 32}> + +// ----- +// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}} +#wmma = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 4}> + +// ----- + +// expected-error@+1 {{major version must be in the [0, 4] range}} +#mfma = #ttg.amd_mfma<{versionMajor = 10, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> + +// ----- + +// expected-error@+1 {{minor version must be 0}} +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 5, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> + +// ----- + +// expected-error@+1 {{(M, N) cases other than (32, 32) or (16, 16) unimplemented}} +#mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [16, 8], isTransposed = false}> diff --git a/third_party/enflame/include/triton/test/TritonGPU/invalid.mlir b/third_party/enflame/include/triton/test/TritonGPU/invalid.mlir new file mode 100644 index 000000000..74df97ba2 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/invalid.mlir @@ -0,0 +1,347 @@ +// RUN: triton-opt --split-input-file %s --verify-diagnostics + +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @miss_encoding(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{,}} + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc<8x16xf16> + tt.return +} + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @miss_memory_space(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{,}} + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared> -> !ttg.memdesc<8x16xf16> + tt.return +} + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @subview_element_ty(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{element type}} + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x16xf16, #shared, #smem> + tt.return +} + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @too_many_offsets(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{offsets}} + %a = ttg.memdesc_subview %arg0[%zero, %zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc + tt.return +} + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @too_few_offsets(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{offsets}} + %a = ttg.memdesc_subview %arg0[%zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc + tt.return +} + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @result_rank_too_large(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{result rank}} + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<3x8x16xf32, #shared, #smem> + tt.return +} + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @result_dim_too_large(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{result shape}} + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<32xf32, #shared, #smem> + tt.return +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { + tt.func @convert_dot(%A: tensor<16x16xf32, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { + // expected-error@+1 {{element types of operands A and B must have same bit width}} + %D = tt.dot %A, %B, %C : tensor<16x16xf32, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { + tt.func @convert_dot(%A: tensor<16x16xf16>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { + // expected-error@+1 {{mismatching encoding between A and B operands}} + %D = tt.dot %A, %B, %C : tensor<16x16xf16> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { + tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32>) { + // expected-error@+1 {{miss encoding of C operand}} + %D = tt.dot %A, %B, %C : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32> + tt.return + } +} + +// ----- + +#mma0 = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=1}> +#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"ttg.num-warps" = 1 : i32} { + tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { + // expected-error@+1 {{mismatching kWidth between A and B operands}} + %D = tt.dot %A, %B, %C : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> + tt.return + } +} + +// ----- + +tt.func @warp_specialize_no_holder() { + // expected-error @below {{'ttg.warp_specialize' op expected to find only a `ttg.warp_specialize.partitions` op inside its second region}} + "ttg.warp_specialize"() ({ + "ttg.warp_yield"() : () -> () + }, { + "ttg.warp_yield"() : () -> () + }) {partitionNumWarps = array} : () -> () + tt.return +} + +// ----- + +tt.func @warp_specialize_mismatch_partition_count() { + // expected-error @below {{'ttg.warp_specialize' op has 0 partitions but `partitionNumWarps` has 1 elements}} + "ttg.warp_specialize"() ({ + "ttg.warp_yield"() : () -> () + }, { + "ttg.warp_specialize.partitions"() : () -> () + }) {partitionNumWarps = array} : () -> () +} + +// ----- + +tt.func @not_power_of_2() { + // expected-error @below {{'ttg.warp_specialize' op partition #0 number of warps (3) must be a power of 2}} + ttg.warp_specialize() + default { + ttg.warp_yield + } + partition0() num_warps(3) { + ttg.warp_return + } : () -> () + tt.return +} + +// ----- + +tt.func @bad_argument_count() { + // expected-error @below {{'ttg.warp_specialize' op partition region #0 has 1 arguments but expected 0}} + ttg.warp_specialize() + default { + ttg.warp_yield + } + partition0(%arg0: i32) num_warps(4) { + ttg.warp_return + } : () -> () + tt.return +} + +// ----- + +tt.func @bad_argument_type(%arg0: i32) { + // expected-error @below {{'ttg.warp_specialize' op partition region #0 argument #0 has type 'i64' but corresponding capture has type 'i32'}} + ttg.warp_specialize(%arg0) + default { + ttg.warp_yield + } + partition0(%arg1: i64) num_warps(4) { + ttg.warp_return + } : (i32) -> () + tt.return +} + +// ----- + +tt.func @bad_default_yields(%arg0: i32) { + ttg.warp_specialize() + default { + // expected-error @below {{'ttg.warp_yield' op has 0 operands but parent op expected 1}} + ttg.warp_yield + } : () -> i32 + tt.return +} + +// ----- + +tt.func @bad_default_yields(%arg0: i32, %arg1: i64) { + ttg.warp_specialize() + default { + // expected-error @below {{'ttg.warp_yield' op operand #0 has type 'i64' but parent op expected 'i32'}} + ttg.warp_yield %arg1 : i64 + } : () -> i32 + tt.return +} + +// ----- + +#blocked_4_warps = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-warps" = 4 : i32} { + +tt.func @function_scope() attributes {"ttg.num-warps" = 8 : i32} { + // expected-error @below {{Layout has a total of 4 warps per CTA, but the context requires 8 warps per CTA}} + tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked_4_warps> + tt.return +} + +} + +// ----- + +#blocked_1_warps = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-warps" = 4 : i32} { + +tt.func @function_no_scope() { + // expected-error @below {{Layout has a total of 1 warps per CTA, but the context requires 4 warps per CTA}} + tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked_1_warps> + tt.return +} + +} + +// ----- + +#blocked_8_warps = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> + +module attributes {"ttg.num-warps" = 4 : i32} { + +tt.func @function_no_scope() { + ttg.warp_specialize() + default { + ttg.warp_yield + } + partition0() num_warps(2) { + // expected-error @below {{Layout has a total of 8 warps per CTA, but the context requires 2 warps per CTA}} + tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked_8_warps> + ttg.warp_return + } : () -> () + tt.return +} + +} + +// ----- + +#blocked_2_warps = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> + +module attributes {"ttg.num-warps" = 4 : i32} { + +tt.func @function_no_scope() { + ttg.warp_specialize() + default { + ttg.warp_yield + } + partition0() num_warps(2) { + ttg.warp_return + } + partition1() num_warps(1) { + // expected-error @below {{Layout has a total of 2 warps per CTA, but the context requires 1 warps per CTA}} + tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked_2_warps> + ttg.warp_return + } : () -> () + tt.return +} + +} + +// ----- + +tt.func @illegal_ws_nest() { + ttg.warp_specialize() + default { + // expected-error @below {{'ttg.warp_specialize' op cannot be nested inside another `ttg.warp_specialize` op}} + ttg.warp_specialize() + default { + ttg.warp_yield + } : () -> () + ttg.warp_yield + } : () -> () + tt.return +} + +// ----- + +tt.func @invalid_start_ids() { + // expected-error @below {{'ttg.warp_specialize' op has 1 warp group start IDs but expected 2}} + ttg.warp_specialize() attributes {warpGroupStartIds = array} + default { + ttg.warp_yield + } + partition0() num_warps(2) { + ttg.warp_return + } + partition1() num_warps(1) { + ttg.warp_return + } : () -> () + tt.return +} + +// ----- + +tt.func @partition_no_terminator() { + ttg.warp_specialize() + default { + ttg.warp_yield + } + // expected-error @below {{region with at least 1 blocks}} + partition0() num_warps(2) { + } : () -> () + tt.return +} + +// ----- + +tt.func @partition_no_terminator() { + ttg.warp_specialize() + default { + ttg.warp_yield + } + partition0() num_warps(2) { + // expected-error @below {{block with no terminator}} + %c1_i32 = arith.constant 1 : i32 + } : () -> () + tt.return +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-async-latencies.mlir b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-async-latencies.mlir new file mode 100644 index 000000000..aba9ed556 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-async-latencies.mlir @@ -0,0 +1,144 @@ +// RUN: triton-opt %s --tritongpu-pipeline="num-stages=3" -canonicalize -cse | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: matmul_kernel_tma_persistent +tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg1: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg2: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = arith.subi %arg3, %c2_i32 : i32 + + %1 = tt.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc> + %2 = tt.reinterpret_tensor_descriptor %arg1 : !tt.ptr to !tt.tensordesc> + + // CHECK: [[LHS_BUFFERS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x64xf16, + // CHECK: [[RHS_BUFFERS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<4x256x64xf16, + + // CHECK: [[LHS_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64, + // CHECK-NEXT: [[LHS_BAR0:%.*]] = ttg.memdesc_subview [[LHS_BARS]][%c0_i32] + // CHECK-NEXT: ttng.init_barrier [[LHS_BAR0]] + // CHECK-NEXT: [[LHS_BAR1:%.*]] = ttg.memdesc_subview [[LHS_BARS]][%c1_i32] + // CHECK-NEXT: ttng.init_barrier [[LHS_BAR1]] + + // CHECK: [[RHS_BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<4xi64, + // CHECK-NEXT: [[RHS_BAR0:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c0_i32] + // CHECK-NEXT: ttng.init_barrier [[RHS_BAR0]] + // CHECK-NEXT: [[RHS_BAR1:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c1_i32] + // CHECK-NEXT: ttng.init_barrier [[RHS_BAR1]] + // CHECK-NEXT: [[RHS_BAR2:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c2_i32] + // CHECK-NEXT: ttng.init_barrier [[RHS_BAR2]] + // CHECK-NEXT: [[RHS_BAR3:%.*]] = ttg.memdesc_subview [[RHS_BARS]][%c3_i32] + // CHECK-NEXT: ttng.init_barrier [[RHS_BAR3]] + + // CHECK: [[MASK0:%.*]] = arith.cmpi sgt, %arg3, %c0_i32 + // CHECK-NEXT: ttng.barrier_expect [[RHS_BAR0]], 32768, [[MASK0]] + // CHECK-NEXT: [[RHS_BUF0:%.*]] = ttg.memdesc_subview [[RHS_BUFFERS]][%c0_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg1[%c0_i32, %c0_i32] [[RHS_BUF0]], [[RHS_BAR0]], [[MASK0]] + + // CHECK: [[MASK1:%.*]] = arith.cmpi sgt, %arg3, %c1_i32 + // CHECK-NEXT: ttng.barrier_expect [[RHS_BAR1]], 32768, [[MASK1]] + // CHECK-NEXT: [[RHS_BUF1:%.*]] = ttg.memdesc_subview [[RHS_BUFFERS]][%c1_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg1[%c0_i32, %c1_i32] [[RHS_BUF1]], [[RHS_BAR1]], [[MASK1]] + + // CHECK: [[MASK2:%.*]] = arith.cmpi sgt, %arg3, %c2_i32 + + // CHECK-NEXT: ttng.barrier_expect [[LHS_BAR0]], 16384, [[MASK0]] + // CHECK-NEXT: [[LHS_BUF0:%.*]] = ttg.memdesc_subview [[LHS_BUFFERS]][%c0_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32] [[LHS_BUF0]], [[LHS_BAR0]], [[MASK0]] + + // CHECK: ttng.barrier_expect [[RHS_BAR2]], 32768, [[MASK2]] + // CHECK-NEXT: [[RHS_BUF2:%.*]] = ttg.memdesc_subview [[RHS_BUFFERS]][%c2_i32, %c0_i32, %c0_i32] + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg1[%c0_i32, %c2_i32] [[RHS_BUF2]], [[RHS_BAR2]], [[MASK2]] + + %true = arith.constant true + %false = arith.constant false + + // CHECK: scf.for [[I:%.*]] = %c0_i32 to + // CHECK-SAME: iter_args([[ACCUM:%arg[0-9]+]] = %cst + + // CHECK-SAME: [[NEXT_LHS_BUF_IDX:%arg[0-9]+]] = %c0_i32 + // CHECK-SAME: [[LHS_BUF_IDX:%arg[0-9]+]] = %c-1_i32 + // CHECK-SAME: [[LHS_PHASE_ARG:%arg[0-9]+]] = %c0_i32 + + // CHECK-SAME: [[NEXT_RHS_BUF_IDX:%arg[0-9]+]] = %c2_i32 + // CHECK-SAME: [[RHS_BUF_IDX:%arg[0-9]+]] = %c-1_i32 + // CHECK-SAME: [[RHS_PHASE_ARG:%arg[0-9]+]] = %c0_i32 + %3 = scf.for %arg6 = %c0_i32 to %arg3 step %c1_i32 iter_args(%arg7 = %cst) -> (tensor<128x256xf32, #mma>) : i32 { + // CHECK: [[RHS_MAX_ITER:%.*]] = arith.subi %arg3, %c3_i32 + // CHECK-NEXT: [[RHS_MASK:%.*]] = arith.cmpi slt, [[I]], [[RHS_MAX_ITER]] + // CHECK: [[LHS_MAX_ITER:%.*]] = arith.subi %arg3, %c1_i32 + // CHECK-NEXT: [[LHS_MASK:%.*]] = arith.cmpi slt, [[I]], [[LHS_MAX_ITER]] + + // Compute RHS buffer index modulo 4. + // CHECK: [[V0:%.*]] = arith.addi [[RHS_BUF_IDX]], %c1_i32 + // CHECK-NEXT: [[V1:%.*]] = arith.cmpi slt, [[V0]], %c4_i32 + // CHECK-NEXT: [[RHS_BUF_IDX:%.*]] = arith.select [[V1]], [[V0]], %c0_i32 + + // Compute RHS phase index modulo 4. + // CHECK: [[V0:%.*]] = arith.xori [[RHS_PHASE_ARG]], %c1_i32 + // CHECK-NEXT: [[RHS_PHASE:%.*]] = arith.select [[V1]], [[RHS_PHASE_ARG]], [[V0]] + + // Compute LHS buffer index modulo 2. + // CHECK: [[V0:%.*]] = arith.addi [[LHS_BUF_IDX]], %c1_i32 + // CHECK-NEXT: [[V1:%.*]] = arith.cmpi slt, [[V0]], %c2_i32 + // CHECK-NEXT: [[LHS_BUF_IDX:%.*]] = arith.select [[V1]], [[V0]], %c0_i32 + + // Compute LHS phase index modulo 2. + // CHECK: [[V0:%.*]] = arith.xori [[LHS_PHASE_ARG]], %c1_i32 + // CHECK-NEXT: [[LHS_PHASE:%.*]] = arith.select [[V1]], [[LHS_PHASE_ARG]], [[V0]] + + // CHECK: [[LHS_MBAR:%.*]] = ttg.memdesc_subview [[LHS_BARS]][[[LHS_BUF_IDX]]] + // CHECK-NEXT: ttng.wait_barrier [[LHS_MBAR]], [[LHS_PHASE]] + + // CHECK: [[RHS_MBAR:%.*]] = ttg.memdesc_subview [[RHS_BARS]][[[RHS_BUF_IDX]]] + // CHECK-NEXT: ttng.wait_barrier [[RHS_MBAR]], [[RHS_PHASE]] + + %4 = tt.experimental_descriptor_load %1[%c0_i32, %arg6] {tt_latency = 1 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + %5 = ttg.local_alloc %4 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %6 = tt.experimental_descriptor_load %2[%c0_i32, %arg6] {tt_latency = 3 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #blocked> + %7 = ttg.local_alloc %6 : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #shared, #smem> + %8 = ttg.memdesc_trans %7 {order = array} : !ttg.memdesc<256x64xf16, #shared, #smem> -> !ttg.memdesc<64x256xf16, #shared1, #smem> + %9 = ttng.warp_group_dot %5, %8, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma> + + // CHECK: [[V0:%.*]] = arith.addi [[NEXT_LHS_BUF_IDX]], %c1_i32 + // CHECK-NEXT: [[V1:%.*]] = arith.cmpi slt, [[V0]], %c2_i32 + // CHECK-NEXT: [[NEXT_LHS_BUF_IDX:%.*]] = arith.select [[V1]], [[V0]], %c0_i32 + // CHECK-NEXT: [[NEXT_LHS_BAR:%.*]] = ttg.memdesc_subview [[LHS_BARS]][[[NEXT_LHS_BUF_IDX]]] + // CHECK-NEXT: ttng.barrier_expect [[NEXT_LHS_BAR]], 16384, [[LHS_MASK]] + + // CHECK-NEXT: [[NEXT_LHS_BUF:%.*]] = ttg.memdesc_subview [[LHS_BUFFERS]][[[NEXT_LHS_BUF_IDX]], %c0_i32, %c0_i32] + // CHECK-NEXT: [[NEXT_LHS_IDX:%.*]] = arith.addi [[I]], %c1_i32 + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg0[%c0_i32, [[NEXT_LHS_IDX]]] [[NEXT_LHS_BUF]], [[NEXT_LHS_BAR]], [[LHS_MASK]] + + // CHECK: [[V0:%.*]] = arith.addi [[NEXT_RHS_BUF_IDX]], %c1_i32 + // CHECK-NEXT: [[V1:%.*]] = arith.cmpi slt, [[V0]], %c4_i32 + // CHECK-NEXT: [[NEXT_RHS_BUF_IDX:%.*]] = arith.select [[V1]], [[V0]], %c0_i32 + // CHECK-NEXT: [[NEXT_RHS_BAR:%.*]] = ttg.memdesc_subview [[RHS_BARS]][[[NEXT_RHS_BUF_IDX]]] + // CHECK-NEXT: ttng.barrier_expect [[NEXT_RHS_BAR]], 32768, [[RHS_MASK]] + + // CHECK-NEXT: [[NEXT_RHS_BUF:%.*]] = ttg.memdesc_subview [[RHS_BUFFERS]][[[NEXT_RHS_BUF_IDX]], %c0_i32, %c0_i32] + // CHECK-NEXT: [[NEXT_RHS_IDX:%.*]] = arith.addi [[I]], %c3_i32 + // CHECK-NEXT: ttng.async_tma_copy_global_to_local %arg1[%c0_i32, [[NEXT_RHS_IDX]]] [[NEXT_RHS_BUF]], [[NEXT_RHS_BAR]], [[RHS_MASK]] + + %10 = arith.cmpi eq, %arg3, %0 : i32 + scf.if %10 { + %11 = arith.truncf %9 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %12 = ttg.convert_layout %11 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + %13 = tt.reinterpret_tensor_descriptor %arg2 : !tt.ptr to !tt.tensordesc> + tt.experimental_descriptor_store %13[%c0_i32, %c0_i32], %12 : !tt.tensordesc>, tensor<128x256xf16, #blocked1> + } + // CHECK: yield %{{.*}}, [[NEXT_LHS_BUF_IDX]], [[LHS_BUF_IDX]], [[LHS_PHASE]], [[NEXT_RHS_BUF_IDX]], [[RHS_BUF_IDX]], [[RHS_PHASE]] + scf.yield %9 : tensor<128x256xf32, #mma> + } {tt.num_stages = 4 : i32} + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-blackwell.mlir b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-blackwell.mlir new file mode 100644 index 000000000..fff8c9bc2 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-blackwell.mlir @@ -0,0 +1,402 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=CHECK + +#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @chained_dot_scaled_acc + // CHECK-DAG: %[[C0_F:.+]] = arith.constant dense<0.000000e+00> + // CHECK-DAG: %[[C2_F:.+]] = arith.constant dense<2.000000e+00> + // CHECK-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-DAG: %[[FALSE:.+]] = arith.constant false + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK: %[[TMEM_BUF:.+]] = ttng.tmem_alloc %[[C0_F]] + // CHECK: %[[BAR_BUF:.+]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64 + // CHECK: ttng.init_barrier %[[BAR_BUF]], 1 + // CHECK: %[[FOR_RET:.+]]:2 = scf.for {{.*}} iter_args(%[[PHASE:.+]] = %[[C0]], %[[NOT_0_ITER:.+]] = %[[FALSE]]) + // CHECK: ttng.wait_barrier %[[BAR_BUF]], %[[PHASE]], %[[NOT_0_ITER]] + // CHECK: %[[NOT_0_ITER_I32:.+]] = arith.extui %[[NOT_0_ITER]] : i1 to i32 + // CHECK: %[[PHASE_NEXT:.+]] = arith.xori %[[PHASE]], %[[NOT_0_ITER_I32]] + // CHECK: %[[ACC:.+]] = ttng.tmem_load %[[TMEM_BUF]] + // CHECK: %[[ACC2:.+]] = arith.mulf %[[ACC]], %[[C2_F]] + // CHECK: ttng.tmem_store %[[ACC2]], %[[TMEM_BUF]], %[[TRUE]] + // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM_BUF]], %[[TRUE]], %[[TRUE]], %[[BAR_BUF]] + // CHECK: scf.yield %[[PHASE_NEXT]], %[[TRUE]] + // CHECK: ttng.wait_barrier %[[BAR_BUF]], %[[FOR_RET]]#0, %[[FOR_RET]]#1 + // CHECK: ttng.tmem_load %[[TMEM_BUF]] + // CHECK: ttng.inval_barrier %[[BAR_BUF]] + // CHECK: ttg.local_dealloc %[[BAR_BUF]] + tt.func public @chained_dot_scaled_acc(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %arg3: i32) -> tensor<128x128xf16, #blocked> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %cst2 = arith.constant dense<2.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked>) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %sacc = arith.mulf %acc, %cst2 : tensor<128x128xf32, #blocked> + %acc_tm = ttng.tmem_alloc %sacc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + scf.yield %acc_res : tensor<128x128xf32, #blocked> + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + tt.return %res_f16 : tensor<128x128xf16, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @chained_scale_after_dot + // CHECK: ttng.tmem_alloc + // CHECK: scf.for + // CHECK: ttng.tc_gen5_mma + // CHECK: ttng.wait_barrier + // CHECK: ttng.tmem_load + // CHECK: arith.mulf + // CHECK: ttng.tmem_store + tt.func public @chained_scale_after_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %arg3: i32) -> tensor<128x128xf16, #blocked> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %cst2 = arith.constant dense<2.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked>) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + %sacc = arith.mulf %acc_res, %cst2 : tensor<128x128xf32, #blocked> + scf.yield %sacc : tensor<128x128xf32, #blocked> + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + tt.return %res_f16 : tensor<128x128xf16, #blocked> + } +} + +// ----- + +// 4 warps +// matmul: 128x32 @ 32x128 -> 128x128 +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func @matmul_loop_cast_load(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { +// CHECK-LABEL: tt.func @matmul_loop_cast_load +// CHECK-NOT: ttng.init_barrier +// CHECK-NOT: ttng.wait_barrier + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + %true = arith.constant true + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf8E4M3FN, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a___ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a__ = tt.fp_to_fp %a___ : tensor<128x32xf8E4M3FN, #AL> -> tensor<128x32xf16, #AL> + %a_ = ttg.convert_layout %a__ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %b___ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b__ = tt.fp_to_fp %b___ : tensor<32x128xf8E4M3FN, #BL> -> tensor<32x128xf16, #BL> + %b_ = ttg.convert_layout %b__ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %a = ttg.local_alloc %a_ {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> !ttg.memdesc<128x32xf16, #shared, #smem> + %b = ttg.local_alloc %b_ {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<32x128xf16, #B>) -> !ttg.memdesc<32x128xf16, #shared, #smem> + %acc_tm = ttng.tmem_alloc %prev_c : (tensor<128x128xf32, #C>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %a, %b, %acc_tm, %true, %true : (!ttg.memdesc<128x32xf16, #shared, #smem>, !ttg.memdesc<32x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %c = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { + +// CHECK-LABEL: @pipelined_gather +// CHECK-SAME: [[LHS_DESC:%arg[0-9]+]]: +// CHECK-SAME: [[RHS_DESC:%arg[0-9]+]]: +// CHECK-SAME: [[LHS_X:%arg[0-9]+]]: +// CHECK-SAME: [[RHS_X:%arg[0-9]+]]: +tt.func private @pipelined_gather( + %lhs_desc: !tt.tensordesc>, + %rhs_desc: !tt.tensordesc>, + %lhs_x_offsets: tensor<32xi32, #blocked1>, + %rhs_x_offsets: tensor<128xi32, #blocked1>) -> tensor<32x32xf32, #blocked> { + %c0_i32 = arith.constant 0 : i32 + %c128_i32 = arith.constant 128 : i32 + %c1024_i32 = arith.constant 1024 : i32 + + %c0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + + // CHECK: [[LHS_BUFS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x32x128xbf16, + // CHECK: [[RHS_BUFS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x32xbf16, + // CHECK: [[BARS:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64, + + // CHECK-COUNT-2: ttng.init_barrier + + // CHECK: [[BAR0:%.*]] = ttg.memdesc_subview [[BARS]][%c0_i32] + // CHECK: ttng.barrier_expect [[BAR0]], 16384 + // CHECK: [[LHS_BUF0:%.*]] = ttg.memdesc_subview [[LHS_BUFS]][%c0_i32, + // CHECK: [[LHS_PTR:%.*]] = ttng.tensor_desc_to_tma_ptr [[LHS_DESC]] + // CHECK: ttng.async_tma_gather [[LHS_PTR]][[[LHS_X]], %c0_i32] [[LHS_BUF0]], [[BAR0]], %true + // CHECK: [[RHS_BUF0:%.*]] = ttg.memdesc_subview [[RHS_BUFS]][%c0_i32, + // CHECK: [[RHS_PTR:%.*]] = ttng.tensor_desc_to_tma_ptr [[RHS_DESC]] + // CHECK: ttng.async_tma_gather [[RHS_PTR]][[[RHS_X]], %c0_i32] [[RHS_BUF0]], [[BAR0]], %true + + // CHECK: [[BAR1:%.*]] = ttg.memdesc_subview [[BARS]][%c1_i32] + // CHECK: ttng.barrier_expect [[BAR1]], 16384 + // CHECK: [[LHS_BUF1:%.*]] = ttg.memdesc_subview [[LHS_BUFS]][%c1_i32, + // CHECK: [[LHS_PTR:%.*]] = ttng.tensor_desc_to_tma_ptr [[LHS_DESC]] + // CHECK: ttng.async_tma_gather [[LHS_PTR]][[[LHS_X]], %c128_i32] [[LHS_BUF1]], [[BAR1]], %true + // CHECK: [[RHS_BUF1:%.*]] = ttg.memdesc_subview [[RHS_BUFS]][%c1_i32, + // CHECK: [[RHS_PTR:%.*]] = ttng.tensor_desc_to_tma_ptr [[RHS_DESC]] + // CHECK: ttng.async_tma_gather [[RHS_PTR]][[[RHS_X]], %c128_i32] [[RHS_BUF1]], [[BAR1]], %true + + // CHECK: scf.for + %out = scf.for %y = %c0_i32 to %c1024_i32 step %c128_i32 iter_args(%acc = %c0) -> (tensor<32x32xf32, #mma>) : i32 { + // CHECK: ttng.wait_barrier + // CHECK: [[RHS_VIEW:%.*]] = ttg.memdesc_subview [[RHS_BUFS]] + // CHECK: [[RHS:%.*]] = ttg.local_load [[RHS_VIEW]] + // CHECK: [[LHS_VIEW:%.*]] = ttg.memdesc_subview [[LHS_BUFS]] + // CHECK: [[LHS:%.*]] = ttg.local_load [[LHS_VIEW]] + // CHECK: tt.dot [[LHS]], [[RHS]] + %lhs = tt.experimental_descriptor_gather %lhs_desc[%lhs_x_offsets, %y] : (!tt.tensordesc>, tensor<32xi32, #blocked1>, i32) -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %rhs = tt.experimental_descriptor_gather %rhs_desc[%rhs_x_offsets, %y] : (!tt.tensordesc>, tensor<128xi32, #blocked1>, i32) -> tensor<128x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %next = tt.dot %lhs, %rhs, %acc : tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * + tensor<128x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + -> tensor<32x32xf32, #mma> + + + // CHECK-COUNT-2: async_tma_gather + scf.yield %next : tensor<32x32xf32, #mma> + } + %out_cvt = ttg.convert_layout %out : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.return %out_cvt : tensor<32x32xf32, #blocked> +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 8, 4, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 4, 8, 1, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 1, 2, 3, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4]], lane = [[32, 0], [64, 0], [1, 0], [2, 0], [4, 0]], warp = [[8, 0], [16, 0]], block = []}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}> +#tmem = #ttng.tensor_memory_encoding +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @block_scale_mxfp_matmul(%lb : index, %ub : index, %step : index, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #blocked4> { + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x128x256xf8E5M2 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x256x128xf8E5M2 + // Do not multibuffer the scale loads, as we cannot pipeline the mma due to tmem.cp not being used + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<2x1x2x32x4x4xi8 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<2x1x2x32x4x4xi8 + + %true = arith.constant true + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked4> + %incr_A = arith.constant dense<4> : tensor<128x256xi32, #blocked> + %incr_B = arith.constant dense<4> : tensor<256x128xi32, #blocked1> + %incr_scale = arith.constant dense<4> : tensor<1x2x32x4x4xi32, #blocked2> + + %arg0_splat = tt.splat %arg0: !tt.ptr -> tensor<128x256x!tt.ptr, #blocked> + %arg1_splat = tt.splat %arg1: !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %arg3_splat = tt.splat %arg3: !tt.ptr -> tensor<1x2x32x4x4x!tt.ptr, #blocked2> + %arg4_splat = tt.splat %arg4: !tt.ptr -> tensor<1x2x32x4x4x!tt.ptr, #blocked2> + + %76 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %77 = tt.expand_dims %76 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %79 = tt.broadcast %77 : tensor<1x256xi32, #blocked> -> tensor<128x256xi32, #blocked> + %arg0_init = tt.addptr %arg0_splat, %79 : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> + + %83 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %88 = tt.broadcast %84 : tensor<1x128xi32, #blocked1> -> tensor<256x128xi32, #blocked1> + %arg1_init = tt.addptr %arg1_splat, %88 : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + + %44 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>}>}>> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>}>}>> -> tensor<1x4xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>}>> + %48 = tt.expand_dims %46 {axis = 1 : i32} : tensor<1x4xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>}>> -> tensor<1x1x4xi32, #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>> + %50 = tt.expand_dims %48 {axis = 2 : i32} : tensor<1x1x4xi32, #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>> -> tensor<1x1x1x4xi32, #ttg.slice<{dim = 3, parent = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 8, 4, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}>}>> + %56 = tt.expand_dims %50 {axis = 3 : i32} : tensor<1x1x1x4xi32, #ttg.slice<{dim = 3, parent = #blocked2}>> -> tensor<1x1x1x1x4xi32, #blocked2> + %57 = tt.broadcast %56 : tensor<1x1x1x1x4xi32, #blocked2> -> tensor<1x2x32x4x4xi32, #blocked2> + + %arg3_init = tt.addptr %arg3_splat, %57 : tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2> + %arg4_init = tt.addptr %arg4_splat, %57 : tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2> + + %99:5 = scf.for %iv = %lb to %ub step %step iter_args(%arg15 = %cst_1, %arg16 = %arg0_init, %arg17 = %arg1_init, %arg18 = %arg3_init, %arg19 = %arg4_init) -> (tensor<128x128xf32, #blocked4>, tensor<128x256x!tt.ptr, #blocked>, tensor<256x128x!tt.ptr, #blocked1>, tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4x!tt.ptr, #blocked2>) { + %117 = tt.load %arg16 : tensor<128x256x!tt.ptr, #blocked> + %118 = ttg.local_alloc %117 : (tensor<128x256xf8E5M2, #blocked>) -> !ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory> + %119 = tt.load %arg17 : tensor<256x128x!tt.ptr, #blocked1> + %120 = ttg.local_alloc %119 : (tensor<256x128xf8E5M2, #blocked1>) -> !ttg.memdesc<256x128xf8E5M2, #shared, #ttg.shared_memory> + %121 = tt.load %arg18 : tensor<1x2x32x4x4x!tt.ptr, #blocked2> + %122 = tt.load %arg19 : tensor<1x2x32x4x4x!tt.ptr, #blocked2> + + %137 = ttg.local_alloc %121 : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem> + %138 = ttg.local_load %137 : !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem> -> tensor<1x2x32x4x4xi8, #blocked2> + %123 = tt.trans %138 {order = array} : tensor<1x2x32x4x4xi8, #blocked2> -> tensor<1x4x32x2x4xi8, #blocked3> + %124 = tt.reshape %123 : tensor<1x4x32x2x4xi8, #blocked3> -> tensor<128x8xi8, #linear> + + %139 = ttg.local_alloc %122 : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem> + %140 = ttg.local_load %139 : !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem> -> tensor<1x2x32x4x4xi8, #blocked2> + %125 = tt.trans %140 {order = array} : tensor<1x2x32x4x4xi8, #blocked2> -> tensor<1x4x32x2x4xi8, #blocked3> + %126 = tt.reshape %125 : tensor<1x4x32x2x4xi8, #blocked3> -> tensor<128x8xi8, #linear> + + %127 = ttng.tmem_alloc %arg15 : (tensor<128x128xf32, #blocked4>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %128 = ttg.convert_layout %124 : tensor<128x8xi8, #linear> -> tensor<128x8xi8, #blocked5> + %129 = ttg.convert_layout %126 : tensor<128x8xi8, #linear> -> tensor<128x8xi8, #blocked5> + %130 = ttng.tmem_alloc %128 : (tensor<128x8xi8, #blocked5>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory> + %131 = ttng.tmem_alloc %129 : (tensor<128x8xi8, #blocked5>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory> + ttng.tc_gen5_mma_scaled %118, %120, %127, %130, %131, %true, %true lhs = e5m2 rhs = e5m2 : (!ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<256x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>, i1, i1) -> () + %132 = ttng.tmem_load %127 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked4> + + %133 = tt.addptr %arg16, %incr_A : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> + %134 = tt.addptr %arg17, %incr_B : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + %135 = tt.addptr %arg18, %incr_scale : tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2> + %136 = tt.addptr %arg19, %incr_scale : tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2> + scf.yield %132, %133, %134, %135, %136 : tensor<128x128xf32, #blocked4>, tensor<128x256x!tt.ptr, #blocked>, tensor<256x128x!tt.ptr, #blocked1>, tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4x!tt.ptr, #blocked2> + } {tt.num_stages = 3 : i32} + tt.return %99#0 : tensor<128x128xf32, #blocked4> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 8, 4, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4]], lane = [[32, 0], [64, 0], [1, 0], [2, 0], [4, 0]], warp = [[8, 0], [16, 0]], block = []}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}> +#tmem = #ttng.tensor_memory_encoding +#tmem_scales = #ttng.tensor_memory_scales_encoding<> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @block_scale_mxfp_matmul_tmem_copy(%lb : index, %ub : index, %step : index, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #blocked4> { + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x128x256xf8E5M2 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x256x128xf8E5M2 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x1x2x32x4x4xi8 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x1x2x32x4x4xi8 + + %true = arith.constant true + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked4> + %incr_A = arith.constant dense<4> : tensor<128x256xi32, #blocked> + %incr_B = arith.constant dense<4> : tensor<256x128xi32, #blocked1> + %incr_scale = arith.constant dense<4> : tensor<1x2x32x4x4xi32, #blocked2> + + %arg0_splat = tt.splat %arg0: !tt.ptr -> tensor<128x256x!tt.ptr, #blocked> + %arg1_splat = tt.splat %arg1: !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %arg3_splat = tt.splat %arg3: !tt.ptr -> tensor<1x2x32x4x4x!tt.ptr, #blocked2> + %arg4_splat = tt.splat %arg4: !tt.ptr -> tensor<1x2x32x4x4x!tt.ptr, #blocked2> + + %76 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %77 = tt.expand_dims %76 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %79 = tt.broadcast %77 : tensor<1x256xi32, #blocked> -> tensor<128x256xi32, #blocked> + %arg0_init = tt.addptr %arg0_splat, %79 : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> + + %83 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %88 = tt.broadcast %84 : tensor<1x128xi32, #blocked1> -> tensor<256x128xi32, #blocked1> + %arg1_init = tt.addptr %arg1_splat, %88 : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + + %44 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>}>}>> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>}>}>> -> tensor<1x4xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>}>> + %48 = tt.expand_dims %46 {axis = 1 : i32} : tensor<1x4xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>}>> -> tensor<1x1x4xi32, #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>> + %50 = tt.expand_dims %48 {axis = 2 : i32} : tensor<1x1x4xi32, #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked2}>}>> -> tensor<1x1x1x4xi32, #ttg.slice<{dim = 3, parent = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 8, 4, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}>}>> + %56 = tt.expand_dims %50 {axis = 3 : i32} : tensor<1x1x1x4xi32, #ttg.slice<{dim = 3, parent = #blocked2}>> -> tensor<1x1x1x1x4xi32, #blocked2> + %57 = tt.broadcast %56 : tensor<1x1x1x1x4xi32, #blocked2> -> tensor<1x2x32x4x4xi32, #blocked2> + + %arg3_init = tt.addptr %arg3_splat, %57 : tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2> + %arg4_init = tt.addptr %arg4_splat, %57 : tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2> + + %99:5 = scf.for %iv = %lb to %ub step %step iter_args(%arg15 = %cst_1, %arg16 = %arg0_init, %arg17 = %arg1_init, %arg18 = %arg3_init, %arg19 = %arg4_init) -> (tensor<128x128xf32, #blocked4>, tensor<128x256x!tt.ptr, #blocked>, tensor<256x128x!tt.ptr, #blocked1>, tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4x!tt.ptr, #blocked2>) { + %117 = tt.load %arg16 : tensor<128x256x!tt.ptr, #blocked> + %118 = ttg.local_alloc %117 : (tensor<128x256xf8E5M2, #blocked>) -> !ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory> + %119 = tt.load %arg17 : tensor<256x128x!tt.ptr, #blocked1> + %120 = ttg.local_alloc %119 : (tensor<256x128xf8E5M2, #blocked1>) -> !ttg.memdesc<256x128xf8E5M2, #shared, #ttg.shared_memory> + %121 = tt.load %arg18 : tensor<1x2x32x4x4x!tt.ptr, #blocked2> + %122 = tt.load %arg19 : tensor<1x2x32x4x4x!tt.ptr, #blocked2> + + %137 = ttg.local_alloc %121 : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem> + %139 = ttg.local_alloc %122 : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem> + + %127 = ttng.tmem_alloc %arg15 : (tensor<128x128xf32, #blocked4>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + + ttng.tc_gen5_mma_scaled %118, %120, %127, %137, %139, %true, %true lhs = e5m2 rhs = e5m2 : (!ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<256x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, i1, i1) -> () + %132 = ttng.tmem_load %127 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked4> + + %133 = tt.addptr %arg16, %incr_A : tensor<128x256x!tt.ptr, #blocked>, tensor<128x256xi32, #blocked> + %134 = tt.addptr %arg17, %incr_B : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + %135 = tt.addptr %arg18, %incr_scale : tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2> + %136 = tt.addptr %arg19, %incr_scale : tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2> + scf.yield %132, %133, %134, %135, %136 : tensor<128x128xf32, #blocked4>, tensor<128x256x!tt.ptr, #blocked>, tensor<256x128x!tt.ptr, #blocked1>, tensor<1x2x32x4x4x!tt.ptr, #blocked2>, tensor<1x2x32x4x4x!tt.ptr, #blocked2> + } {tt.num_stages = 3 : i32} + tt.return %99#0 : tensor<128x128xf32, #blocked4> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-cuda.mlir b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-cuda.mlir new file mode 100644 index 000000000..809d7471d --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-cuda.mlir @@ -0,0 +1,201 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +// CHECK-LABEL: tt.func @load_two_users + tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { + %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: scf.for + // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} + // CHECK: tt.dot + // CHECK: tt.dot + // CHECK: ttg.async_copy_global_to_local + // CHECK: scf.yield + // CHECK: ttg.async_wait {num = 0 : i32} + + %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { + %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> + %19 = ttg.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = ttg.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> + %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #smem> -> !ttg.memdesc<16x64xf16, #shared1, #smem> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } + tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } +} + +// ----- + +// CHECK-NOT: ttg.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.get_program_id y : i32 + %3 = tt.load %arg3 : !tt.ptr + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = tt.splat %1 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %6 = arith.addi %5, %4 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %8 = tt.splat %3 : i64 -> tensor<64x1xi64, #blocked> + %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> + %11 = arith.extsi %arg5 : i32 to i64 + %12 = tt.splat %11 : i64 -> tensor<64x1xi64, #blocked> + %13 = arith.muli %10, %12 : tensor<64x1xi64, #blocked> + %14 = arith.muli %2, %arg5 : i32 + %15 = arith.extsi %14 : i32 to i64 + %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> + %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> + %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> + %25 = arith.muli %21, %23 : tensor<1x64xi32, #blocked1> + %26 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x64xi64, #blocked> + %27 = arith.extsi %24 : tensor<1x64xi32, #blocked> to tensor<1x64xi64, #blocked> + %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> + %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> + %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> + %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> + %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> + %35 = arith.addi %33, %34 : tensor<32x1xi64, #blocked1> + %36 = tt.splat %11 : i64 -> tensor<32x1xi64, #blocked1> + %37 = arith.muli %35, %36 : tensor<32x1xi64, #blocked1> + %38 = tt.splat %15 : i64 -> tensor<32x1xi64, #blocked1> + %39 = arith.addi %37, %38 : tensor<32x1xi64, #blocked1> + %40 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> + %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> + %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> + %50 = arith.muli %46, %48 : tensor<1x32xi32, #blocked> + %51 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %52 = arith.extsi %49 : tensor<1x32xi32, #blocked1> to tensor<1x32xi64, #blocked1> + %53 = arith.extsi %50 : tensor<1x32xi32, #blocked> to tensor<1x32xi64, #blocked> + %54 = tt.broadcast %52 : tensor<1x32xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %55 = arith.addi %51, %54 : tensor<32x32xi64, #blocked1> + %56 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %57 = tt.addptr %56, %30 : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi64, #blocked> + %58 = tt.splat %arg1 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked1> + %59 = tt.addptr %58, %42 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi64, #blocked1> + %60 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %61 = tt.addptr %60, %55 : tensor<32x32x!tt.ptr, #blocked1>, tensor<32x32xi64, #blocked1> + %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> + %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { + %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> + %71 = ttg.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %72 = ttg.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !ttg.memdesc<32x64xf32, #shared, #smem> + %73 = ttg.memdesc_trans %72 {order=array} : !ttg.memdesc<32x64xf32, #shared, #smem> -> !ttg.memdesc<64x32xf32, #shared1, #smem> + %74 = ttg.local_load %73 : !ttg.memdesc<64x32xf32, #shared1, #smem> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> + %77 = ttg.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %78 = ttg.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + scf.yield %79 : tensor<64x32xf32, #mma> + } + %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> + %65 = tt.broadcast %53 : tensor<1x32xi64, #blocked> -> tensor<64x32xi64, #blocked> + %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> + %67 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> + %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> + %69 = ttg.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> + tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> + tt.return + } +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +// CHECK-LABEL: @matmul_tma +// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #{{.+}}, #smem, mutable> +// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #{{.+}}, #smem, mutable> +// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3xi64, #{{.+}}, #smem, mutable> +// CHECK-COUNT-3: ttng.init_barrier +// CHECK-COUNT-4: ttng.async_tma_copy_global_to_local +// CHECK: scf.for +// CHECK: ttng.wait_barrier +// CHECK-NOT: ttng.wait_barrier +// CHECK-COUNT-2: ttng.async_tma_copy_global_to_local +// CHECK: scf.yield + tt.func public @matmul_tma(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>) -> tensor<128x256xf32, #mma> { + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0:2 = scf.for %arg3 = %c0_i32 to %c256_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { + %1 = tt.experimental_descriptor_load %arg0[%c0_i32, %arg5] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + %2 = ttg.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %3 = tt.experimental_descriptor_load %arg1[%arg5, %c0_i32] : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> + %4 = ttg.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem> + %5 = ttng.warp_group_dot %2, %4, %arg4 { inputPrecision = 0 : i32 } : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> + %6 = arith.addi %arg5, %c64_i32 : i32 + scf.yield %5, %6 : tensor<128x256xf32, #mma>, i32 + } + tt.return %0#0 : tensor<128x256xf32, #mma> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-expand.mlir b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-expand.mlir new file mode 100644 index 000000000..42d5d4d47 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-expand.mlir @@ -0,0 +1,34 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline | FileCheck %s --check-prefixes=CHECK + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 8]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @pipeline_load_mmav3 + tt.func public @pipeline_load_mmav3(%arg0: tensor<256x128xf32, #mma>, %arg1: tensor<256x32x!tt.ptr, #blocked>, %arg2: tensor<32x128x!tt.ptr, #blocked1>, %arg3: tensor<256x32xi32, #blocked>, %arg4: tensor<32x128xi32, #blocked1>) -> (tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr, #blocked>, tensor<32x128x!tt.ptr, #blocked1>) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c128_i32 = arith.constant 128 : i32 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<4x256x32xf32 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<4x32x128xf32 + %0:3 = scf.for %arg5 = %c0_i32 to %c128_i32 step %c1_i32 iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2) -> (tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr, #blocked>, tensor<32x128x!tt.ptr, #blocked1>) : i32 { + // CHECK: ttg.memdesc_subview {{.*}} : !ttg.memdesc<4x256x32xf32 + // CHECK: ttg.async_wait {{.*}} {num = 4 : i32} + // CHECK: ttg.memdesc_subview {{.*}} : !ttg.memdesc<4x32x128xf32 + // CHECK: ttng.warp_group_dot {{.*}} {inputPrecision = 0 : i32, isAsync = true} + // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 1 : i32} + %1 = tt.load %arg7 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<256x32x!tt.ptr, #blocked> + %2 = ttg.local_alloc %1 {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<256x32xf32, #blocked>) -> !ttg.memdesc<256x32xf32, #shared, #smem> + %3 = tt.load %arg8 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<32x128x!tt.ptr, #blocked1> + %4 = ttg.local_alloc %3 {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<32x128xf32, #blocked1>) -> !ttg.memdesc<32x128xf32, #shared1, #smem> + %5 = ttng.warp_group_dot %2, %4, %arg6 {inputPrecision = 0 : i32, loop.cluster = 0 : i32, loop.stage = 3 : i32} : !ttg.memdesc<256x32xf32, #shared, #smem> * !ttg.memdesc<32x128xf32, #shared1, #smem> -> tensor<256x128xf32, #mma> + %6 = tt.addptr %arg7, %arg3 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<256x32x!tt.ptr, #blocked>, tensor<256x32xi32, #blocked> + %7 = tt.addptr %arg8, %arg4 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<32x128x!tt.ptr, #blocked1>, tensor<32x128xi32, #blocked1> + scf.yield %5, %6, %7 : tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr, #blocked>, tensor<32x128x!tt.ptr, #blocked1> + } {tt.num_stages = 4 : i32} + tt.return %0#0, %0#1, %0#2 : tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr, #blocked>, tensor<32x128x!tt.ptr, #blocked1> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-hip.mlir b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-hip.mlir new file mode 100644 index 000000000..aa439317c --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-hip.mlir @@ -0,0 +1,336 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @load_two_users + tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { + %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: ttg.local_store + // CHECK: scf.for + // CHECK: tt.load + // CHECK: tt.dot + // CHECK: tt.dot + // CHECK: ttg.local_store + // CHECK: scf.yield + %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { + %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> + %19 = ttg.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = ttg.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> + %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<16x64xf16, #shared1, #smem, mutable> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #smem, mutable> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } + tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } +} + +// ----- + +// CHECK-LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de +// CHECK-NOT: ttg.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.get_program_id y : i32 + %3 = tt.load %arg3 : !tt.ptr + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = tt.splat %1 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %6 = arith.addi %5, %4 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %8 = tt.splat %3 : i64 -> tensor<64x1xi64, #blocked> + %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> + %11 = arith.extsi %arg5 : i32 to i64 + %12 = tt.splat %11 : i64 -> tensor<64x1xi64, #blocked> + %13 = arith.muli %10, %12 : tensor<64x1xi64, #blocked> + %14 = arith.muli %2, %arg5 : i32 + %15 = arith.extsi %14 : i32 to i64 + %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> + %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> + %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> + %25 = arith.muli %21, %23 : tensor<1x64xi32, #blocked1> + %26 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x64xi64, #blocked> + %27 = arith.extsi %24 : tensor<1x64xi32, #blocked> to tensor<1x64xi64, #blocked> + %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> + %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> + %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> + %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> + %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> + %35 = arith.addi %33, %34 : tensor<32x1xi64, #blocked1> + %36 = tt.splat %11 : i64 -> tensor<32x1xi64, #blocked1> + %37 = arith.muli %35, %36 : tensor<32x1xi64, #blocked1> + %38 = tt.splat %15 : i64 -> tensor<32x1xi64, #blocked1> + %39 = arith.addi %37, %38 : tensor<32x1xi64, #blocked1> + %40 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> + %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> + %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> + %50 = arith.muli %46, %48 : tensor<1x32xi32, #blocked> + %51 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %52 = arith.extsi %49 : tensor<1x32xi32, #blocked1> to tensor<1x32xi64, #blocked1> + %53 = arith.extsi %50 : tensor<1x32xi32, #blocked> to tensor<1x32xi64, #blocked> + %54 = tt.broadcast %52 : tensor<1x32xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %55 = arith.addi %51, %54 : tensor<32x32xi64, #blocked1> + %56 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %57 = tt.addptr %56, %30 : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi64, #blocked> + %58 = tt.splat %arg1 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked1> + %59 = tt.addptr %58, %42 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi64, #blocked1> + %60 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %61 = tt.addptr %60, %55 : tensor<32x32x!tt.ptr, #blocked1>, tensor<32x32xi64, #blocked1> + %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> + %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { + %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> + %71 = ttg.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %72 = ttg.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !ttg.memdesc<32x64xf32, #shared, #smem, mutable> + %73 = ttg.memdesc_trans %72 {order=array} : !ttg.memdesc<32x64xf32, #shared, #smem, mutable> -> !ttg.memdesc<64x32xf32, #shared1, #smem, mutable> + %74 = ttg.local_load %73 : !ttg.memdesc<64x32xf32, #shared1, #smem, mutable> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> + %77 = ttg.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %78 = ttg.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + scf.yield %79 : tensor<64x32xf32, #mma> + } + %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> + %65 = tt.broadcast %53 : tensor<1x32xi64, #blocked> -> tensor<64x32xi64, #blocked> + %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> + %67 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> + %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> + %69 = ttg.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> + tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> + tt.return + } +} // end module + +// ----- + +// Disable pipelining for loops that contain barrier. +// Barriers are problematic since they are not chained to any other operation. +// CHECK-LABEL: tt.func public @add_barrier_kernel +// CHECK: scf.for +// CHECK: tt.load +// CHECK: gpu.barrier +// CHECK: tt.store +// CHECK-NOT: gpu.barrier +// CHECK: tt.return + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func public @add_barrier_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %cval_f32 = arith.constant dense<0.3> : tensor<1024xf32, #blocked> + %c1016800_i32 = arith.constant 1016800 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + scf.for %arg4 = %c0_i32 to %arg3 step %c1024_i32 : i32 { + %7 = arith.addi %1, %arg4 : i32 + %8 = tt.splat %7 : i32 -> tensor<1024xi32, #blocked> + %9 = arith.addi %8, %2 : tensor<1024xi32, #blocked> + %11 = tt.addptr %4, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> + gpu.barrier + %16 = tt.addptr %6, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %15 = arith.addf %12, %cval_f32 : tensor<1024xf32, #blocked> + tt.store %16, %15 : tensor<1024x!tt.ptr, #blocked> + } {tt.num_stages = 2 : i32} + tt.return + } +} // end module + +// ----- + +// CHECK-NOT: #ttg.swizzled_shared<{{.*}} order = [2, 0, 1] +// CHECK: #ttg.swizzled_shared<{{.*}} order = [2, 1, 0] +// CHECK-NOT: #ttg.swizzled_shared<{{.*}} order = [2, 0, 1] + +// CHECK-LABEL: tt.func public @slowest_dim_is_batch +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @slowest_dim_is_batch(%arg0: tensor<1x512x!tt.ptr, #blocked2>, %arg1: tensor<64x8x32x!tt.ptr, #blocked1>, %arg2: tensor<64x1x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x1x32xf32, #blocked> + %cst_0 = arith.constant dense<512> : tensor<1x512xi32, #blocked2> + %cst_1 = arith.constant dense<128> : tensor<64x8x32xi32, #blocked1> + %c1_i32 = arith.constant 1 : i32 + %c5_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %33:3 = scf.for %arg7 = %c0_i32 to %c5_i32 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %arg0, %arg10 = %arg1) -> (tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr, #blocked2>, tensor<64x8x32x!tt.ptr, #blocked1>) : i32 { + %39 = tt.load %arg9 : tensor<1x512x!tt.ptr, #blocked2> + %40 = tt.load %arg10 : tensor<64x8x32x!tt.ptr, #blocked1> + %41 = tt.reshape %39 allow_reorder : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5> + %43 = ttg.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %44 = ttg.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked> + %46 = tt.addptr %arg9, %cst_0 : tensor<1x512x!tt.ptr, #blocked2>, tensor<1x512xi32, #blocked2> + %47 = tt.addptr %arg10, %cst_1 : tensor<64x8x32x!tt.ptr, #blocked1>, tensor<64x8x32xi32, #blocked1> + scf.yield %45, %46, %47 : tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr, #blocked2>, tensor<64x8x32x!tt.ptr, #blocked1> + } + tt.store %arg2, %33#0 : tensor<64x1x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// Check that the stream pipeliner updates the resulting memory layout of transpose ops to mutable if immutable local buffers are replaced +// CHECK-LABEL: loop_with_dot_and_transpose +// CHECK: ttg.local_alloc {{.*}}, mutable> +// CHECK: ttg.memdesc_trans {{.*}}, mutable> -> {{.*}}, mutable> + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1201", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @loop_with_dot_and_transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: i32, %arg4: tensor<32x32x!tt.ptr, #blocked1>, %arg5: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + %0 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %cst) -> (tensor<32x32xf32, #blocked>) : i32 { + %2 = tt.load %arg4 : tensor<32x32x!tt.ptr, #blocked1> + %3 = ttg.local_alloc %2 : (tensor<32x32xf32, #blocked1>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %4 = ttg.memdesc_trans %3 {order = array} : !ttg.memdesc<32x32xf32, #shared, #smem> -> !ttg.memdesc<32x32xf32, #shared1, #smem> + %5 = ttg.local_load %4 : !ttg.memdesc<32x32xf32, #shared1, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %6 = ttg.convert_layout %2 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %7 = tt.dot %6, %5, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf32, #blocked> + scf.yield %7 : tensor<32x32xf32, #blocked> + } + tt.store %arg5, %0 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// Check that the stream pipeliner updates atomic op in the k-loop correctly +// CHECK-LABEL: _triton_gemm_kernel_atomic_rmw +// CHECK: scf.for +// CHECK: tt.atomic_rmw fadd, acq_rel, gpu +// CHECK: tt.dot +// CHECK: scf.yield + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @_triton_gemm_kernel_atomic_rmw(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg3: i32 {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant dense<32> : tensor<32x32xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c31_i32 = arith.constant 31 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %2 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<32x1xi32, #blocked> + %4 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %6 = tt.broadcast %3 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked> + %7 = tt.broadcast %5 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %8 = arith.addi %6, %7 : tensor<32x32xi32, #blocked> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %10 = tt.addptr %9, %8 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %12 = tt.addptr %11, %8 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %13 = tt.splat %arg2 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %14 = tt.addptr %13, %3 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %15 = tt.broadcast %14 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %16 = tt.addptr %15, %7 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %17 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked> + %18 = arith.cmpi slt, %1, %17 : tensor<32x1xi32, #blocked> + %19 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #blocked> + %20 = arith.cmpi slt, %5, %19 : tensor<1x32xi32, #blocked> + %21 = tt.broadcast %18 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + %22 = tt.broadcast %20 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked> + %23 = arith.andi %21, %22 : tensor<32x32xi1, #blocked> + %24 = arith.addi %arg3, %c31_i32 : i32 + %25 = arith.divsi %24, %c32_i32 : i32 + %26 = arith.muli %arg4, %c32_i32 : i32 + %27 = tt.splat %26 : i32 -> tensor<32x32xi32, #blocked> + %28:3 = scf.for %arg5 = %c0_i32 to %25 step %c1_i32 iter_args(%arg6 = %cst_0, %arg7 = %10, %arg8 = %12) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked>) : i32 { + %32 = tt.load %arg7 : tensor<32x32x!tt.ptr, #blocked> + %33 = tt.load %arg8 : tensor<32x32x!tt.ptr, #blocked> + %34 = ttg.convert_layout %32 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %35 = ttg.convert_layout %33 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %36 = tt.dot %34, %35, %arg6 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma> + %37 = tt.addptr %arg7, %cst : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %38 = tt.addptr %arg8, %27 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %39 = arith.truncf %36 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %40 = ttg.convert_layout %39 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked> + %41 = tt.atomic_rmw fadd, acq_rel, gpu, %16, %40, %23 : (tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked> + scf.yield %36, %37, %38 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked> + } + %29 = arith.truncf %28#0 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %30 = ttg.convert_layout %16 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #mma> + %31 = ttg.convert_layout %23 : tensor<32x32xi1, #blocked> -> tensor<32x32xi1, #mma> + tt.store %30, %29, %31 : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-hopper-remove-wait.mlir b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-hopper-remove-wait.mlir new file mode 100644 index 000000000..c35d65414 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-hopper-remove-wait.mlir @@ -0,0 +1,168 @@ +// RUN: triton-opt %s -split-input-file -canonicalize -tritongpu-pipeline -canonicalize | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: two_dependent_dot + tt.func public @two_dependent_dot(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst_1 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %cst_4 = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = arith.divsi %2, %arg8 : i32 + %4 = arith.extsi %arg21 : i32 to i64 + %5 = arith.extsi %arg11 : i32 to i64 + %6 = arith.extsi %c0_i32 : i32 to i64 + %7 = arith.extsi %3 : i32 to i64 + %8 = arith.extsi %arg14 : i32 to i64 + %9 = arith.extsi %3 : i32 to i64 + %10 = arith.extsi %c0_i32 : i32 to i64 + %11 = arith.muli %0, %c128_i32 : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1> + %15 = tt.splat %11 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.splat %11 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %17 = tt.splat %11 : i32 -> tensor<128xi32, #blocked1> + %18 = arith.addi %15, %12 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %19 = arith.addi %16, %13 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %20 = arith.addi %17, %14 : tensor<128xi32, #blocked1> + %21 = arith.mulf %arg3, %cst_4 : f32 + %22 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %23 = tt.expand_dims %18 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %24 = tt.expand_dims %19 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> + %25 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked> + %26 = arith.muli %23, %25 : tensor<128x1xi32, #blocked> + %27 = tt.splat %22 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %28 = tt.addptr %27, %26 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %29 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %28 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> + %32 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> + %33 = tt.addptr %31, %32 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> + %34 = tt.load %33 : tensor<128x128x!tt.ptr, #blocked> + %35 = tt.splat %21 : f32 -> tensor<128x128xf32, #blocked> + %36 = arith.extf %34 : tensor<128x128xf16, #blocked> to tensor<128x128xf32, #blocked> + %37 = arith.mulf %36, %35 : tensor<128x128xf32, #blocked> + %38 = arith.truncf %37 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + %39 = arith.addi %0, %c1_i32 : i32 + %40 = arith.muli %39, %c128_i32 : i32 + %41:7 = scf.for %arg22 = %c0_i32 to %40 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %6, %arg27 = %7, %arg28 = %9, %arg29 = %10) -> (tensor<128x128xf32, #mma1>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, i64, i64, i64, i64) : i32 { + %69 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked2> + %70 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %71 = arith.extsi %70 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> + %72 = tt.splat %arg26 : i64 -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> + %73 = arith.addi %71, %72 : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> + %74 = tt.expand_dims %73 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi64, #blocked2> + %75 = tt.broadcast %74 : tensor<128x1xi64, #blocked2> -> tensor<128x64xi64, #blocked2> + %76 = tt.splat %c1_i64 : i64 -> tensor<128x64xi64, #blocked2> + %77 = arith.muli %75, %76 : tensor<128x64xi64, #blocked2> + %78 = tt.broadcast %77 : tensor<128x64xi64, #blocked2> -> tensor<128x64xi64, #blocked2> + %79 = tt.addptr %69, %78 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi64, #blocked2> + %80 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %81 = arith.extsi %80 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> + %82 = tt.splat %arg27 : i64 -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> + %83 = arith.addi %81, %82 : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> + %84 = tt.expand_dims %83 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2> + %85 = tt.broadcast %84 : tensor<1x64xi64, #blocked2> -> tensor<128x64xi64, #blocked2> + %86 = tt.splat %5 : i64 -> tensor<128x64xi64, #blocked2> + %87 = arith.muli %85, %86 : tensor<128x64xi64, #blocked2> + %88 = tt.broadcast %87 : tensor<128x64xi64, #blocked2> -> tensor<128x64xi64, #blocked2> + %89 = tt.addptr %79, %88 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi64, #blocked2> + %90 = tt.load %89 : tensor<128x64x!tt.ptr, #blocked2> + %91 = tt.splat %arg2 : !tt.ptr -> tensor<64x128x!tt.ptr, #blocked> + %92 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %93 = arith.extsi %92 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %94 = tt.splat %arg28 : i64 -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %95 = arith.addi %93, %94 : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %96 = tt.expand_dims %95 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi64, #blocked> + %97 = tt.broadcast %96 : tensor<64x1xi64, #blocked> -> tensor<64x128xi64, #blocked> + %98 = tt.splat %8 : i64 -> tensor<64x128xi64, #blocked> + %99 = arith.muli %97, %98 : tensor<64x128xi64, #blocked> + %100 = tt.broadcast %99 : tensor<64x128xi64, #blocked> -> tensor<64x128xi64, #blocked> + %101 = tt.addptr %91, %100 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi64, #blocked> + %102 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %103 = arith.extsi %102 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %104 = tt.splat %arg29 : i64 -> tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %105 = arith.addi %103, %104 : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %106 = tt.expand_dims %105 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi64, #blocked> + %107 = tt.broadcast %106 : tensor<1x128xi64, #blocked> -> tensor<64x128xi64, #blocked> + %108 = tt.splat %c1_i64 : i64 -> tensor<64x128xi64, #blocked> + %109 = arith.muli %107, %108 : tensor<64x128xi64, #blocked> + %110 = tt.broadcast %109 : tensor<64x128xi64, #blocked> -> tensor<64x128xi64, #blocked> + %111 = tt.addptr %101, %110 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi64, #blocked> + %112 = tt.load %111 : tensor<64x128x!tt.ptr, #blocked> + %113 = ttg.local_alloc %38 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %114 = ttg.local_alloc %90 : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + %115 = ttng.warp_group_dot %113, %114, %cst :!ttg.memdesc<128x128xf16, #shared, #smem> * !ttg.memdesc<128x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> + %116 = arith.truncf %115 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + %117 = ttg.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared, #smem> + %118 = ttg.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + // The first dot gets converted to dot-async + wait. The second one + // doesn't have a wait because the first wait is sufficient. + // CHECK: ttng.warp_group_dot + // CHECK: ttng.warp_group_dot_wait {{.*}}, {{.*}} {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot + // CHECK-NOT: ttng.warp_group_dot_wait + // CHECK: scf.yield + %119 = ttng.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xf16, #shared, #smem> -> tensor<128x128xf32, #mma1> + %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %121 = arith.addf %120, %arg25 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %122 = arith.extsi %c0_i32 : i32 to i64 + %123 = arith.addi %arg26, %122 : i64 + %124 = arith.extsi %c64_i32 : i32 to i64 + %125 = arith.addi %arg27, %124 : i64 + %126 = arith.extsi %c64_i32 : i32 to i64 + %127 = arith.addi %arg28, %126 : i64 + %128 = arith.extsi %c0_i32 : i32 to i64 + %129 = arith.addi %arg29, %128 : i64 + scf.yield %119, %121, %arg25, %123, %125, %127, %129 : tensor<128x128xf32, #mma1>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, i64, i64, i64, i64 + } + %42 = arith.addi %3, %11 : i32 + %43 = arith.extsi %arg17 : i32 to i64 + %44 = arith.extsi %42 : i32 to i64 + %45 = arith.extsi %c0_i32 : i32 to i64 + %46 = arith.truncf %41#0 : tensor<128x128xf32, #mma1> to tensor<128x128xf16, #mma1> + %47 = ttg.convert_layout %46 : tensor<128x128xf16, #mma1> -> tensor<128x128xf16, #blocked> + %48 = tt.splat %arg5 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked> + %49 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %50 = arith.extsi %49 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %51 = tt.splat %44 : i64 -> tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %52 = arith.addi %50, %51 : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> + %53 = tt.expand_dims %52 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> + %54 = tt.broadcast %53 : tensor<128x1xi64, #blocked> -> tensor<128x128xi64, #blocked> + %55 = tt.splat %43 : i64 -> tensor<128x128xi64, #blocked> + %56 = arith.muli %54, %55 : tensor<128x128xi64, #blocked> + %57 = tt.broadcast %56 : tensor<128x128xi64, #blocked> -> tensor<128x128xi64, #blocked> + %58 = tt.addptr %48, %57 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi64, #blocked> + %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %60 = arith.extsi %59 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %61 = tt.splat %45 : i64 -> tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %62 = arith.addi %60, %61 : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %63 = tt.expand_dims %62 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi64, #blocked> + %64 = tt.broadcast %63 : tensor<1x128xi64, #blocked> -> tensor<128x128xi64, #blocked> + %65 = tt.splat %c1_i64 : i64 -> tensor<128x128xi64, #blocked> + %66 = arith.muli %64, %65 : tensor<128x128xi64, #blocked> + %67 = tt.broadcast %66 : tensor<128x128xi64, #blocked> -> tensor<128x128xi64, #blocked> + %68 = tt.addptr %58, %67 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi64, #blocked> + tt.store %68, %47 : tensor<128x128x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-hopper.mlir b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-hopper.mlir new file mode 100644 index 000000000..01c531b3a --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-hopper.mlir @@ -0,0 +1,828 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 | FileCheck %s --check-prefix=CHECK-NOCANON + +// 4 warps +// matmul: 128x32 @ 32x128 -> 128x128 +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#smem = #ttg.shared_memory + +// CHECK-LABEL: tt.func @matmul_loop +// CHECK-DAG: %[[CONSTANT_NEG1:.*]] = arith.constant -1 : i32 +// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] +// CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] +// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 2x128x32> +// CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, #smem, mutable, 2x128x32> +// CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] +// CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #smem, mutable, 2x32x128> +// CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] +// CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] +// CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] +// CHECK-DAG: %[[ASUB1:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_A1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] +// CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] +// CHECK-DAG: %[[BSUB1:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_NEG1]] +// CHECK: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] +// CHECK: ttg.async_wait {{.*}} {num = 2 : i32} +// CHECK: %[[A:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[A]] +// CHECK: %[[B:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[B]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[ASUB3:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[ASUB3]] +// CHECK: %[[BSUB3:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[BSUB3]] +// CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]] +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +tt.func @matmul_loop(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return +} +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: dot_chained_single_load + tt.func @dot_chained_single_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x64xf32, #mma> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: scf.for + // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group + // CHECK: scf.yield + %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>) : i32 { + %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %21 = ttng.warp_group_dot %19, %20, %cst_2 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> + %23 = ttg.memdesc_trans %20 {order=array} : !ttg.memdesc<64x16xf16, #shared1, #smem> -> !ttg.memdesc<16x64xf16, #shared, #smem> + %24 = ttg.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> + %25 = ttng.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> + } + tt.return %17#0 : tensor<128x64xf32, #mma> + } + + // Check that we are able to perform WGMMA pipelining if the accumulator is conditionally being modified + // CHECK-LABEL: dot_acc_cond_modified + tt.func @dot_acc_cond_modified(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext : i32) -> tensor<128x16xf32, #mma1> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %cst2 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: scf.for + // CHECK: ttg.async_wait {{.*}} {num = 2 : i32} + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 1 : i32} + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group + // CHECK: scf.if + // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: arith.mulf + // CHECK: scf.yield + // CHECK: scf.yield + // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { + %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> + %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> + scf.yield %acc_zero : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + %22 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %23 = tt.addptr %arg6, %cst2 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + scf.yield %acc_, %22, %23 : tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1> + } + tt.return %17#0 : tensor<128x16xf32, #mma1> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: two_accumulator_escape + tt.func @two_accumulator_escape(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + // CHECK: %[[ALLOC1:.+]] = ttg.local_alloc + // CHECK: %[[ALLOC2:.+]] = ttg.local_alloc + // CHECK: %[[R:.+]]:{{.+}} = scf.for + // CHECK: %[[DOT1:.+]] = ttng.warp_group_dot{{.*}} + // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} + // CHECK: %[[TRANS:.+]] = ttg.memdesc_trans{{.*}} : !ttg.memdesc + // CHECK: %[[DOT2:.+]] = ttng.warp_group_dot{{.*}} %[[TRANS]] + // CHECK: ttng.warp_group_dot_wait %[[DOT1]], %[[DOT2]], %[[ALLOC1]], %[[ALLOC2]], %[[TRANS]] {pendings = 2 : i32} + // CHECK: scf.yield + // CHECK: %{{.*}}:2 = ttng.warp_group_dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16, %arg6 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>) : i32 { + %21 = ttng.warp_group_dot %19, %20, %arg6 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> + %c = ttg.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %23 = ttg.memdesc_trans %c {order=array} : !ttg.memdesc<64x16xf16, #shared1, #smem> -> !ttg.memdesc<16x64xf16, #shared, #smem> + %25 = ttng.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> + } + tt.return %17#0, %17#2 : tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory + +// Make sure that if one of the load dot operand is not pipelined (and therefore not double buffered) we won't use +// async dot. +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: no_wgmma_pipeline + tt.func public @no_wgmma_pipeline(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %cst_0 = arith.constant dense<512> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %cst_1 = arith.constant dense<512> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %cst_2 = arith.constant dense<512> : tensor<128x1xi32, #blocked> + %cst_3 = arith.constant dense<512> : tensor<128x1xi32, #blocked1> + %cst_4 = arith.constant dense<512> : tensor<64x1xi32, #blocked1> + %cst_5 = arith.constant dense<32768> : tensor<64x256xi32, #blocked1> + %cst_6 = arith.constant dense<64> : tensor<128x64xi32, #blocked> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = arith.remsi %0, %cst_0 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %3 = arith.remsi %2, %cst_1 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %4 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %5 = arith.muli %4, %cst_2 : tensor<128x1xi32, #blocked> + %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %8 = tt.broadcast %5 : tensor<128x1xi32, #blocked> -> tensor<128x64xi32, #blocked> + %9 = tt.broadcast %7 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> + %10 = arith.addi %8, %9 : tensor<128x64xi32, #blocked> + %11 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> + %12 = tt.addptr %11, %10 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %14 = tt.expand_dims %13 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %15 = arith.muli %14, %cst_4 : tensor<64x1xi32, #blocked1> + %16 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %17 = tt.broadcast %15 : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %18 = tt.broadcast %16 : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %19 = arith.addi %17, %18 : tensor<64x256xi32, #blocked1> + %20 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked1> + %21 = tt.addptr %20, %19 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + %22:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %12, %arg6 = %21) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1>) : i32 { + %35 = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked> + %36 = tt.load %arg6 : tensor<64x256x!tt.ptr, #blocked1> + %37 = ttg.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !ttg.memdesc<128x64xf8E5M2, #shared, #smem> + %38 = ttg.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !ttg.memdesc<64x256xf8E5M2, #shared1, #smem> + // CHECK: ttg.local_alloc + // CHECK: scf.for + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait + %39 = ttng.warp_group_dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x64xf8E5M2, #shared, #smem> * !ttg.memdesc<64x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> + %40 = tt.addptr %arg5, %cst_6 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %41 = tt.addptr %arg6, %cst_5 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + scf.yield %39, %40, %41 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> + } + %23 = arith.truncf %22#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %25 = tt.expand_dims %24 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %26 = arith.muli %25, %cst_3 : tensor<128x1xi32, #blocked1> + %27 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %28 = tt.addptr %27, %26 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %29 = tt.expand_dims %2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %30 = tt.broadcast %28 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x256x!tt.ptr, #blocked1> + %31 = tt.broadcast %29 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %32 = tt.addptr %30, %31 : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %33 = tt.fp_to_fp %23 {rounding = 1 : i32} : tensor<128x256xf16, #mma> -> tensor<128x256xf8E5M2, #mma> + %34 = ttg.convert_layout %33 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked1> + tt.store %32, %34 : tensor<128x256x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// A dot can be properly async if all its uses follow a synchronous MMAv3 dot. +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: async_following_sync + tt.func @async_following_sync(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) { + %cst = arith.constant dense<64> : tensor<64x16xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + + // Add a "dummy" early return here to test that we don't crash in the + // presence of unstructured control flow. + %cond = arith.constant 0 : i1 + cf.cond_br %cond, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + %zero = arith.constant 0.0 : f32 + %t1 = tt.splat %zero : f32 -> tensor<128x64xf32, #mma> + %t2 = tt.splat %zero : f32 -> tensor<128x16xf32, #mma1> + tt.return %t1, %t2 : tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1> + ^bb2: // pred: ^bb0 + + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + // CHECK: %[[LOOP:[^ :]+]]{{.*}} scf.for {{.*}} iter_args(%[[PREV_DOT2:[^ ]+]] + // CHECK-NOT: ttng.warp_group_dot_wait + // CHECK: %[[DOT0:.+]] = ttng.warp_group_dot + // CHECK-NOT: ttng.warp_group_dot_wait + // CHECK: %[[DOT1:.+]] = ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait + // CHECK-DAG-SAME: %[[DOT0]] + // CHECK-DAG-SAME: %[[DOT1]] + // CHECK-DAG-SAME: %[[PREV_DOT2]] + // CHECK-SAME: {pendings = 0 : i32} + // CHECK: %[[DOT2:.+]] = ttng.warp_group_dot + // CHECK-NOT: ttng.warp_group_dot_wait + // CHECK: scf.yield %[[DOT2]] + // CHECK: ttng.warp_group_dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} + %17:4 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%prev_dot2 = %cst_3, %arg5 = %16, %prev_dot1 = %cst_2, %prev_dot0 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { + // This one can be async. + %dot0 = ttng.warp_group_dot %19, %20, %prev_dot1 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + // This can't be async because its result is modified before it's yielded. + %dot1 = ttng.warp_group_dot %19, %20, %prev_dot1 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %dot1.1 = arith.addf %dot1, %dot1 : tensor<128x16xf32, #mma1> + %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> + %c = ttg.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %23 = ttg.memdesc_trans %c {order=array} : !ttg.memdesc<64x16xf16, #shared1, #smem> -> !ttg.memdesc<16x64xf16, #shared, #smem> + // This dot can be async even though %prev_dot2 is not used directly by an + // async dot, because that use follows the synchronous dot above. + %prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> + %dot2 = ttng.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> + } + tt.return %17#0, %17#2 : tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1> + } +} + +// ----- +// Test pipelining of experimental_descriptor_store +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tma_store_pipeline + tt.func public @tma_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { + %1 = arith.divsi %arg4, %arg2 : i32 + // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: ttg.local_store + // CHECK-NEXT: ttng.fence_async_shared + // CHECK-NEXT: ttng.tensor_desc_to_tma_ptr + // CHECK-NEXT: ttng.async_tma_copy_local_to_global + tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> + } + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tma_store_pipeline + tt.func public @tma_store_pipeline(%arg0: tensor<8x128xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { + %1 = arith.divsi %arg4, %arg2 : i32 + %2 = tt.splat %1 : i32 -> tensor<8xi32, #blocked1> + // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: ttg.local_store + // CHECK-NEXT: ttng.fence_async_shared + // CHECK-NEXT: ttng.tensor_desc_to_tma_ptr + // CHECK-NEXT: ttng.async_tma_scatter + tt.experimental_descriptor_scatter %arg1[%2, %1], %arg0 : !tt.tensordesc>, tensor<8xi32, #blocked1>, i32, tensor<8x128xf32, #blocked> + } + tt.return + } +} + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tma_multiple_store_pipeline + tt.func public @tma_multiple_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + // CHECK: %[[ALLOC:.+]] = ttg.local_alloc : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + // CHECK: scf.for + scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { + %1 = arith.divsi %arg4, %arg2 : i32 + %2 = arith.divsi %arg2, %arg4 : i32 + // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: ttg.local_store %{{.+}}, %[[ALLOC]] + // CHECK-NEXT: ttng.fence_async_shared + // CHECK-NEXT: ttng.tensor_desc_to_tma_ptr + // CHECK-NEXT: ttng.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] + // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: ttg.local_store %{{.+}}, %[[ALLOC]] + // CHECK-NEXT: ttng.fence_async_shared + // CHECK-NEXT: ttng.tensor_desc_to_tma_ptr + // CHECK-NEXT: ttng.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] + tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> + tt.experimental_descriptor_store %arg1[%2], %arg0 : !tt.tensordesc>, tensor<1xf32, #blocked> + } + tt.return + } +} + + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: _kernel_matmul_dependency + tt.func public @_kernel_matmul_dependency(%arg0: tensor<128x128x!tt.ptr, #blocked>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %cst_0 = arith.constant 1.000000e+00 : f32 + %c8_i32 = arith.constant 8 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked1> + %2:4 = scf.for %arg6 = %c8_i32 to %arg3 step %c8_i32 iter_args(%arg7 = %c8_i32, %arg8 = %c8_i32, %arg9 = %cst_1, %arg10 = %arg5) -> (i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) : i32 { + %3 = arith.addi %arg7, %c8_i32 : i32 + %4 = arith.cmpi eq, %3, %c8_i32 : i32 + %5:2 = scf.if %4 -> (i32, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) { + %21 = arith.addi %arg8, %c8_i32 : i32 + scf.yield %21, %arg5 : i32, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + } else { + scf.yield %arg8, %arg10 : i32, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + } + %6 = arith.cmpi eq, %3, %c8_i32 : i32 + %7 = scf.if %6 -> (f32) { + scf.yield %cst_0 : f32 + } else { + %21 = tt.load %arg4 : !tt.ptr + scf.yield %21 : f32 + } + %8 = tt.splat %3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %9 = arith.addi %8, %0 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %11 = tt.broadcast %10 : tensor<128x1xi32, #blocked1> -> tensor<128x128xi32, #blocked1> + %12 = tt.addptr %1, %11 : tensor<128x128x!tt.ptr, #blocked1>, tensor<128x128xi32, #blocked1> + %13 = tt.load %arg0 : tensor<128x128x!tt.ptr, #blocked> + %14 = ttg.local_alloc %13 : (tensor<128x128xf8E4M3FNUZ, #blocked>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared, #smem> + %15 = tt.load %12 : tensor<128x128x!tt.ptr, #blocked1> + %16 = ttg.local_alloc %15 : (tensor<128x128xf8E4M3FNUZ, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1, #smem> + %17 = ttng.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x128xf8E4M3FNUZ, #shared, #smem> * !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1, #smem> -> tensor<128x128xf32, #mma> + %18 = tt.splat %7 : f32 -> tensor<128x128xf32, #mma> + %19 = arith.mulf %17, %18 : tensor<128x128xf32, #mma> + %20 = scf.if %6 -> (tensor<128x128xf32, #mma>) { + scf.yield %cst_1 : tensor<128x128xf32, #mma> + } else { + scf.yield %19 : tensor<128x128xf32, #mma> + } + scf.yield %3, %5#0, %20, %5#1 : i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + } + tt.return + } +} + +// ----- + +// Pipeline the if ops at the beginning and the end of the loop +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // COMMON-LABEL: dot_prologue_epilogue + // COMMON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} + tt.func @dot_prologue_epilogue(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %cst2 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // COMMON: %[[C0:.*]] = arith.constant 0 : i32 + // COMMON: scf.for %[[IND_VAR:.*]] = %[[C0]] + // COMMON-NOT: load + // COMMON: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] + // COMMON: scf.if %[[CND]] + // COMMON: dot + // COMMON: scf.if %[[CND]] + // COMMON: arith.mulf + // COMMON: scf.yield + // COMMON-NOT: tt.addptr + // COMMON: scf.yield + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { + %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %inc_ptr = scf.if %cnd -> tensor<64x16x!tt.ptr, #blocked> { + %ptr = tt.addptr %arg5, %inc : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %ptr : tensor<64x16x!tt.ptr, #blocked> + } else { + scf.yield %arg5 : tensor<64x16x!tt.ptr, #blocked> + } + %18 = tt.load %inc_ptr : tensor<64x16x!tt.ptr, #blocked> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> + scf.yield %acc_zero : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + %22 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %23 = tt.addptr %arg6, %cst2 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + scf.yield %acc_, %22, %23 : tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1> + } + tt.return %17#0 : tensor<128x16xf32, #mma1> + } +} + +// ----- + +// Verify that uses of the ops scheduled in partucular place of the loop (like epilogue if) are correctly scheduled too. +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-NOCANON-LABEL: pipeline_downstream_dependencies + // CHECK-NOCANON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} + tt.func @pipeline_downstream_dependencies(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %cst1 = arith.constant dense<1> : tensor<64x16xi32, #blocked> + %cst2 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK-NOCANON: %[[C0:.*]] = arith.constant 0 : i32 + // CHECK-NOCANON: scf.for %[[IND_VAR:.*]] = %[[C0]] + // CHECK-NOCANON-NOT load + // CHECK-NOCANON: dot + // CHECK-NOCANON: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] + // CHECK-NOCANON: %[[IFRET:.*]]:2 = scf.if %[[CND]] + // CHECK-NOCANON: arith.mulf + // CHECK-NOCANON: scf.yield + // CHECK-NOCANON: tt.addptr {{.*}}, %[[IFRET]]#1 + // CHECK-NOCANON: scf.yield + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { + %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> + %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %if_ret:2 = scf.if %cnd -> (tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked>) { + %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> + scf.yield %acc_zero, %cst : tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked> + } else { + scf.yield %acc, %cst1 : tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked> + } + %22 = tt.addptr %arg5, %if_ret#1 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %23 = tt.addptr %arg6, %cst2 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + scf.yield %if_ret#0, %22, %23 : tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1> + } + tt.return %17#0 : tensor<128x16xf32, #mma1> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: dot_lhs_registers + tt.func @dot_lhs_registers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_3 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %cst_4 = arith.constant dense<2.0> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: scf.for + // CHECK: ttg.async_wait {{.*}} {num = 2 : i32} + // CHECK: ttg.local_load + // CHECK: ttng.warp_group_dot + // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group + // CHECK: ttg.async_copy_global_to_local + // CHECK: ttg.async_commit_group + // CHECK: scf.yield + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %8, %arg6 = %16) -> (tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, + tensor<64x16x!tt.ptr, #blocked>) : i32 { + %a_block = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked1> + %b_block = tt.load %arg6 : tensor<64x16x!tt.ptr, #blocked> + %a_dotop = ttg.convert_layout %a_block : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %a_dotop_mul = arith.mulf %a_dotop, %cst_4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %b_smem = ttg.local_alloc %b_block : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %21 = ttng.warp_group_dot %a_dotop_mul, %b_smem, %arg4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma> + %25 = tt.addptr %arg5, %cst_3 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %26 = tt.addptr %arg6, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %21, %25, %26 : tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x16x!tt.ptr, #blocked> + } + tt.return %17#0 : tensor<128x16xf32, #mma> + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @mmav3_fp8_row_major_rhs(%arg0: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg1: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg2: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + // CHECK-LABEL: mmav3_fp8_row_major_rhs + // The col-major RHS SMEM encoding in the input, created by accelerate-matmul, should be overwritten by the row-major TMA layout. + // Note that this "overwriting" makes the program invalid after SWP, since warp_group_dot does not support row-major fp8 RHS. + // In this case, the TMA load on B should not be pipelined. When this bug is fixed, this test should be rewritten to verify that. + // CHECK-NOT: order = [0, 1] + %c128_i32 = arith.constant 128 : i32 + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.remsi %0, %2 : i32 + %4 = arith.divsi %0, %2 : i32 + %5 = arith.muli %3, %c128_i32 : i32 + %6 = arith.muli %4, %c64_i32 : i32 + %7 = arith.addi %arg5, %c63_i32 : i32 + %8 = arith.divsi %7, %c64_i32 : i32 + %9 = tt.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc> + %10 = tt.reinterpret_tensor_descriptor %arg1 : !tt.ptr to !tt.tensordesc> + %true = arith.constant true + %false = arith.constant false + %11:2 = scf.for %arg6 = %c0_i32 to %8 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %c0_i32) -> (tensor<128x64xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 32]}>>, i32) : i32 { + %14 = tt.experimental_descriptor_load %9[%5, %arg8] : !tt.tensordesc> -> tensor<128x64xf8E4M3FN, #blocked> + %15 = ttg.local_alloc %14 : (tensor<128x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared, #ttg.shared_memory> + %16 = tt.experimental_descriptor_load %10[%arg8, %6] : !tt.tensordesc> -> tensor<64x64xf8E4M3FN, #blocked> + %17 = ttg.local_alloc %16 : (tensor<64x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<64x64xf8E4M3FN, #shared1, #ttg.shared_memory> + %18 = ttng.warp_group_dot %15, %17, %arg7 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x64xf8E4M3FN, #shared, #ttg.shared_memory> * !ttg.memdesc<64x64xf8E4M3FN, #shared1, #ttg.shared_memory> -> tensor<128x64xf32, #mma> + %19 = arith.addi %arg8, %c64_i32 : i32 + scf.yield %18, %19 : tensor<128x64xf32, #mma>, i32 + } + %12 = ttg.convert_layout %11#0 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked> + %13 = tt.reinterpret_tensor_descriptor %arg2 : !tt.ptr to !tt.tensordesc> + tt.experimental_descriptor_store %13[%5, %6], %12 : !tt.tensordesc>, tensor<128x64xf32, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-indirect-load.mlir b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-indirect-load.mlir new file mode 100644 index 000000000..c02a3094f --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline-indirect-load.mlir @@ -0,0 +1,90 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=2 | FileCheck %s +// CHECK-LABEL: @indirect_load_two_stages +// CHECK: scf.for +// CHECK: tt.dot +// CHECK: tt.load +// CHECK: async_copy_global_to_local +// CHECK: async_copy_global_to_local + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @indirect_load_two_stages(%arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #blocked> + + %0 = tt.get_program_id y : i32 + %1 = tt.addptr %arg3, %0 : !tt.ptr, i32 + %2 = tt.load %1 : !tt.ptr + + %7 = tt.get_program_id x : i32 + %8 = arith.muli %7, %c16_i32 : i32 + %10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %15 = tt.splat %8 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %18 = arith.addi %15, %10 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + + %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %22 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %34 = arith.extsi %arg12 : i32 to i64 + %35 = arith.muli %2, %34 : i64 + %36 = tt.addptr %arg2, %35 : !tt.ptr, i64 + + %47 = tt.splat %arg4 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>> + %48 = tt.addptr %47, %20 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + + %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> + %61 = arith.extsi %59 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked3}>> + %63 = tt.expand_dims %61 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi64, #blocked3> + + %85 = arith.extsi %22 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> to tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> + %107 = tt.splat %36 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked3> + %108 = tt.splat %34 : i64 -> tensor<32x1xi64, #blocked3> + %109 = tt.broadcast %63 : tensor<1x128xi64, #blocked3> -> tensor<32x128xi64, #blocked3> + + %101 = tt.splat %arg5 : !tt.ptr -> tensor<16x32x!tt.ptr, #blocked1> + %111:1 = scf.for %arg28 = %arg18 to %arg19 step %c32_i32 iter_args(%arg29 = %cst) -> (tensor<16x128xf32, #blocked>) : i32 { + %129 = tt.splat %arg28 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %160 = tt.addptr %48, %129 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %161 = tt.load %160 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked1}>> + %162 = tt.expand_dims %161 {axis = 0 : i32} : tensor<32xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi64, #blocked1> + %163 = tt.broadcast %162 : tensor<1x32xi64, #blocked1> -> tensor<16x32xi64, #blocked1> + %182 = tt.addptr %101, %163 : tensor<16x32x!tt.ptr, #blocked1>, tensor<16x32xi64, #blocked1> + %183 = tt.load %182 : tensor<16x32x!tt.ptr, #blocked1> + + %197 = arith.extsi %arg28 : i32 to i64 + %198 = tt.splat %197 : i64 -> tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> + %199 = arith.addi %198, %85 : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> + %200 = tt.expand_dims %199 {axis = 1 : i32} : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1xi64, #blocked3> + %201 = arith.muli %200, %108 : tensor<32x1xi64, #blocked3> + %202 = tt.broadcast %201 : tensor<32x1xi64, #blocked3> -> tensor<32x128xi64, #blocked3> + %203 = arith.addi %202, %109 : tensor<32x128xi64, #blocked3> + %204 = tt.addptr %107, %203 : tensor<32x128x!tt.ptr, #blocked3>, tensor<32x128xi64, #blocked3> + %209 = tt.load %204 : tensor<32x128x!tt.ptr, #blocked3> + + %210 = ttg.convert_layout %183 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %211 = ttg.convert_layout %209 : tensor<32x128xf32, #blocked3> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %212 = tt.dot %210, %211, %arg29 : tensor<16x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x128xf32, #blocked> + scf.yield %212 : tensor<16x128xf32, #blocked> + } + %112 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<16x1xi32, #blocked3> + %113 = tt.splat %2 : i64 -> tensor<16x1xi64, #blocked3> + %114 = arith.extsi %112 : tensor<16x1xi32, #blocked3> to tensor<16x1xi64, #blocked3> + %115 = arith.addi %113, %114 : tensor<16x1xi64, #blocked3> + %116 = arith.extsi %arg17 : i32 to i64 + %117 = tt.splat %116 : i64 -> tensor<16x1xi64, #blocked3> + %118 = arith.muli %115, %117 : tensor<16x1xi64, #blocked3> + %119 = tt.expand_dims %59 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> + %120 = tt.broadcast %118 : tensor<16x1xi64, #blocked3> -> tensor<16x128xi64, #blocked3> + %121 = arith.extsi %119 : tensor<1x128xi32, #blocked3> to tensor<1x128xi64, #blocked3> + %122 = tt.broadcast %121 : tensor<1x128xi64, #blocked3> -> tensor<16x128xi64, #blocked3> + %123 = arith.addi %120, %122 : tensor<16x128xi64, #blocked3> + %124 = tt.splat %arg7 : !tt.ptr -> tensor<16x128x!tt.ptr, #blocked3> + %125 = tt.addptr %124, %123 : tensor<16x128x!tt.ptr, #blocked3>, tensor<16x128xi64, #blocked3> + %128 = ttg.convert_layout %111#0 : tensor<16x128xf32, #blocked> -> tensor<16x128xf32, #blocked3> + tt.store %125, %128 : tensor<16x128x!tt.ptr, #blocked3> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline.mlir b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline.mlir new file mode 100644 index 000000000..af4f60c07 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/loop-pipeline.mlir @@ -0,0 +1,1698 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=COMMON,CHECK +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=3 global_prefetch=1 local_prefetch=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD_PREFETCH + +// 4 warps +// matmul: 128x32 @ 32x128 -> 128x128 +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#smem = #ttg.shared_memory + +// CHECK-LABEL: tt.func @matmul_loop +// CHECK-DAG: %[[CONSTANT_NEG1:.*]] = arith.constant -1 : i32 +// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] +// CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] +// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] +// CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] +// CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} +// CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] +// CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] +// CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] +// CHECK-DAG: %[[ASUB1:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_A1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB1]] mask %[[LOOP_COND_1_SPLAT_A]] +// CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] +// CHECK-DAG: %[[BSUB1:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[T_B1:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB1]] mask %[[LOOP_COND_1_SPLAT_B]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_NEG1]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] +// CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} +// CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[A0]] +// CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[B0]] +// CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[ASUB3:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[ASUB3]] +// CHECK: %[[BSUB3:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = ttg.async_copy_global_to_local {{.*}}, %[[BSUB3]] +// CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]] + +// AMD-LABEL: tt.func @matmul_loop +// AMD-DAG: %[[CM1:.*]] = arith.constant -1 : index +// AMD-DAG: %[[C1:.*]] = arith.constant 1 : index +// AMD-DAG: %[[C0:.*]] = arith.constant 0 : index +// AMD: %[[UB1:.*]] = arith.subi %[[UB:.*]], %arg2 : index +// AMD: %[[FOR:.*]]:6 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UB1]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) +// AMD: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[ADDPTR_35:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// AMD: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]] +// AMD: %[[LOCAL_LOAD_37:.*]] = ttg.local_load %[[ARG10]] +// AMD: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_35]] +// AMD: %[[LOCAL_LOAD_39:.*]] = ttg.local_load %[[ARG11]] +// AMD: %[[MULF_40:.*]] = arith.mulf %[[LOCAL_LOAD_39]], %{{.*}} +// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[MULF_40]], %[[ARG8]] +// AMD: %[[ADDI_42:.*]] = arith.addi %[[ARG9]], %{{.*}} +// AMD: %[[CMPI_43:.*]] = arith.cmpi slt, %[[ADDI_42]], %{{.*}} +// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_43]], %[[ADDI_42]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_45]] +// AMD: %[[MEMDESC_SUBVIEW_46:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_38]], %[[MEMDESC_SUBVIEW_46]] +// AMD: scf.yield %[[ADDPTR_34]], %[[ADDPTR_35]], %[[DOT_41]], %[[SELECT_44]], %[[MEMDESC_SUBVIEW_45]], %[[MEMDESC_SUBVIEW_46]] +// AMD: } +// AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]] +// AMD: %[[SELECT_22:.*]] = arith.select %[[CMPI_21]], %[[C1]], %[[CM1]] +// AMD: %[[SUBI_23:.*]] = arith.subi %[[UB]], %[[LB]] +// AMD: %[[ADDI_24:.*]] = arith.addi %[[SUBI_23]], %[[STEP]] +// AMD: %[[ADDI_25:.*]] = arith.addi %[[ADDI_24]], %[[SELECT_22]] +// AMD: %[[DIVSI_26:.*]] = arith.divsi %[[ADDI_25]], %[[STEP]] +// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %[[DIVSI_26]], %{{.*}} +// AMD: %[[LOCAL_LOAD_28:.*]] = ttg.local_load %{{.*}}#4 +// AMD: %[[LOCAL_LOAD_29:.*]] = ttg.local_load %{{.*}}#5 +// AMD: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}} +// AMD: %[[IF_31:.*]] = scf.if %[[CMPI_27]] +// AMD: %[[DOT_33:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %{{.*}}#2 +// AMD: scf.yield %[[DOT_33]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#2 +// AMD: } +// AMD: %[[SELECT_32:.*]] = arith.select %[[CMPI_27]], %[[IF_31]], %{{.*}}#2 +// AMD: ttg.local_dealloc %{{.*}} +// AMD: ttg.local_dealloc %{{.*}} + +// Prefetch pipelining adds another stage in between global load and compute. +// This stage will local_store, then local_load, creating a prefetch from shared +// memory into a register buffer for compute. +// +// AMD_PREFETCH-LABEL: tt.func @matmul_loop +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.return + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func @matmul_loop(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %b_scale = arith.constant dense<4.> : tensor<32x128xf16, #B> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %b__ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b_ = ttg.convert_layout %b__ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: tt.func @matmul_loop_nested +// CHECK-DAG: %[[CONSTANT_NEG1:.*]] = arith.constant -1 : i32 +// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK: scf.for +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_NEG1]]{{.*}} +// CHECK: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] +// CHECK: ttg.async_wait {{.*}} {num = 2 : i32} +// CHECK: %[[A:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[EXT_IDX_3]], +// CHECK: %[[arg_a0_dot_op:.*]] = ttg.local_load %[[A]] +// CHECK: %[[B:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]], +// CHECK: %[[arg_b0_dot_op_0:.*]] = ttg.local_load %[[B]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]] +// CHECK: ttg.async_wait {num = 0 : i32} +// CHECK scf.yield + +// AMD-LABEL: tt.func @matmul_loop_nested +// AMD: scf.for +// AMD-COUNT-2: ttg.local_alloc +// AMD-COUNT-2: tt.load +// AMD: %[[SUBVIEW0:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: %[[FOR:.*]]:6 = scf.for +// AMD-COUNT-2: tt.addptr +// AMD: tt.load +// AMD: ttg.local_load +// AMD: tt.load +// AMD: ttg.local_load +// AMD: tt.dot +// AMD: %[[SUBVIEW0:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = ttg.memdesc_subview +// AMD: ttg.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: scf.yield +// AMD-COUNT-2: ttg.local_load +// AMD: %[[IF1:.*]] = scf.if +// AMD: %[[DOT1:.*]] = tt.dot +// AMD: scf.yield %[[DOT1]] +// AMD: %[[SEL1:.*]] = arith.select %{{.*}}, %[[IF1]], %[[FOR]]#2 +// AMD-COUNT-2: ttg.local_dealloc +// AMD: scf.yield %[[SEL1]] + +// AMD_PREFETCH-LABEL: tt.func @matmul_loop_nested + +tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ + + %c_start = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %loop1:1 = scf.for %iv0 = %lb to %ub step %step iter_args(%c_init = %c_start) -> (tensor<128x128xf32, #C>) { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop2:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + + scf.yield %loop2#2 : tensor<128x128xf32, #C> + } + tt.return %loop1#0 : tensor<128x128xf32, #C> +} + +// CHECK-LABEL: tt.func @matmul_loop_single_pipeline +// CHECK-DAG: %[[CONSTANT_NEG1:.*]] = arith.constant -1 : i32 +// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_NEG1]] +// CHECK: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] +// CHECK: ttg.async_wait {{.*}} {num = 1 : i32} +// CHECK: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[EXT_IDX_3]] +// CHECK: %[[arg_b0_dot_op:.*]] = ttg.local_load %[[B0]] +// CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[INS_IDX_3]], %[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]] + +// AMD-LABEL: tt.func @matmul_loop_single_pipeline +// AMD: %[[LOAD_10:.*]] = tt.load %{{.*}} +// AMD: %[[CONVERT_LAYOUT_11:.*]] = ttg.convert_layout %[[LOAD_10]] +// AMD: %[[LOCAL_ALLOC_12:.*]] = ttg.local_alloc +// AMD: %[[CMPI_13:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_13]] +// AMD: %[[LOAD_15:.*]] = tt.load %{{.*}}, %[[SPLAT_14]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_16:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_12]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] +// AMD: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:4 = scf.for %[[ARG5:.*]] = %{{.*}} to %[[SUBI_17]] step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[MEMDESC_SUBVIEW_16]]) +// AMD: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_32]] +// AMD: %[[LOCAL_LOAD_30:.*]] = ttg.local_load %[[ARG9]] +// AMD: %[[DOT_31:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[LOCAL_LOAD_30]], %[[ARG7]] +// AMD: %[[ADDI_34:.*]] = arith.addi %[[ARG8]], %{{.*}} +// AMD: %[[CMPI_35:.*]] = arith.cmpi slt, %[[ADDI_34]], %{{.*}} +// AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_35]], %[[ADDI_34]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_12]][%[[SELECT_36]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_33]], %[[MEMDESC_SUBVIEW_37]] +// AMD: scf.yield %[[ADDPTR_32]], %[[DOT_31]], %[[SELECT_36]], %[[MEMDESC_SUBVIEW_37]] +// AMD: ttg.local_dealloc %[[LOCAL_ALLOC_12]] + +// AMD_PREFETCH-LABEL: tt.func @matmul_loop_single_pipeline +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.return + +tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#1 : tensor<128x128xf32, #C> +} + +// CHECK-LABEL: tt.func @indirect_bmm_scalar +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group +// CHECK: scf.for +// CHECK: ttg.async_wait {{.*}} {num = 2 : i32} +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK: %[[IND_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} +// CHECK: %[[IND_BUFFER_1:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_0]] +// CHECK: %[[IND_BUFFER_2:.*]] = tt.splat %[[IND_BUFFER_1]] +// CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_2]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] + +// AMD-LABEL: tt.func @indirect_bmm_scalar +// AMD: %[[LOCAL_ALLOC_0:.*]] = ttg.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = ttg.local_alloc +// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] {OpIdx = #amdgpu.OpIdx<0>} +// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] +// AMD: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]] +// AMD: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]] +// AMD: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] {OpIdx = #amdgpu.OpIdx<1>} +// AMD: %[[MEMDESC_SUBVIEW_11:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_11]] {OpIdx = #amdgpu.OpIdx<0>} +// AMD: %[[MEMDESC_SUBVIEW_12:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_12]] {OpIdx = #amdgpu.OpIdx<1>} +// AMD: %[[CMPI_13:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_15:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_16:.*]] = tt.splat %[[CMPI_13]] +// AMD: %[[LOAD_17:.*]] = tt.load %[[ADDPTR_14]], %[[SPLAT_16]] {OpIdx = #amdgpu.OpIdx<0>} +// AMD: %[[LOAD_18:.*]] = tt.load %[[ADDPTR_15]], %[[CMPI_13]] +// AMD: %[[MULI_19:.*]] = arith.muli %{{.*}}, %[[LOAD_18]] +// AMD: %[[SPLAT_20:.*]] = tt.splat %[[MULI_19]] +// AMD: %[[ADDPTR_21:.*]] = tt.addptr %{{.*}}, %[[SPLAT_20]] +// AMD: %[[SPLAT_22:.*]] = tt.splat %[[CMPI_13]] +// AMD: %[[LOAD_23:.*]] = tt.load %[[ADDPTR_21]], %[[SPLAT_22]] {OpIdx = #amdgpu.OpIdx<1>} +// AMD: %[[MEMDESC_SUBVIEW_24:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_17]], %[[MEMDESC_SUBVIEW_24]] {OpIdx = #amdgpu.OpIdx<0>} +// AMD: %[[MEMDESC_SUBVIEW_25:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_23]], %[[MEMDESC_SUBVIEW_25]] {OpIdx = #amdgpu.OpIdx<1>} +// AMD: %[[SUBI_26:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_26]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_14]], %[[ARG9:.*]] = %[[ADDPTR_15]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_11]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_24]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_12]], %[[ARG14:.*]] = %[[MEMDESC_SUBVIEW_25]]) +// AMD: %[[ADDPTR_38:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_39:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_40:.*]] = tt.load %[[ADDPTR_38]] {OpIdx = #amdgpu.OpIdx<0>} +// AMD: %[[LOCAL_LOAD_41:.*]] = ttg.local_load %[[ARG11]] +// AMD: %[[LOAD_42:.*]] = tt.load %[[ADDPTR_39]] +// AMD: %[[MULI_43:.*]] = arith.muli %{{.*}}, %[[LOAD_42]] +// AMD: %[[SPLAT_44:.*]] = tt.splat %[[MULI_43]] +// AMD: %[[ADDPTR_45:.*]] = tt.addptr %{{.*}}, %[[SPLAT_44]] +// AMD: %[[LOAD_46:.*]] = tt.load %[[ADDPTR_45]] {OpIdx = #amdgpu.OpIdx<1>} +// AMD: %[[LOCAL_LOAD_47:.*]] = ttg.local_load %[[ARG13]] +// AMD: %[[DOT_48:.*]] = tt.dot %[[LOCAL_LOAD_41]], %[[LOCAL_LOAD_47]], %[[ARG7]] +// AMD: %[[ADDI_49:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_50:.*]] = arith.cmpi slt, %[[ADDI_49]], %{{.*}} +// AMD: %[[SELECT_51:.*]] = arith.select %[[CMPI_50]], %[[ADDI_49]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_52:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_51]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_40]], %[[MEMDESC_SUBVIEW_52]] {OpIdx = #amdgpu.OpIdx<0>} +// AMD: %[[MEMDESC_SUBVIEW_53:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_51]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_46]], %[[MEMDESC_SUBVIEW_53]] {OpIdx = #amdgpu.OpIdx<1>} +// AMD: scf.yield %[[DOT_48]], %[[ADDPTR_38]], %[[ADDPTR_39]], %[[SELECT_51]], %[[ARG12]], %[[MEMDESC_SUBVIEW_52]], %[[ARG14]], %[[MEMDESC_SUBVIEW_53]] +// AMD: } {tt.num_stages = 3 +// AMD: %[[CMPI_28:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[CMPI_29:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[LOCAL_LOAD_30:.*]] = ttg.local_load %{{.*}}#4 +// AMD: %[[LOCAL_LOAD_31:.*]] = ttg.local_load %{{.*}}#6 +// AMD: %[[IF_32:.*]] = scf.if %[[CMPI_28]] +// AMD: %[[DOT_38:.*]] = tt.dot %[[LOCAL_LOAD_30]], %[[LOCAL_LOAD_31]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_38]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } +// AMD: %[[SELECT_33:.*]] = arith.select %[[CMPI_28]], %[[IF_32]], %{{.*}}#0 +// AMD: %[[LOCAL_LOAD_34:.*]] = ttg.local_load %{{.*}}#5 +// AMD: %[[LOCAL_LOAD_35:.*]] = ttg.local_load %{{.*}}#7 +// AMD: %[[IF_36:.*]] = scf.if %[[CMPI_29]] +// AMD: %[[DOT_38:.*]] = tt.dot %[[LOCAL_LOAD_34]], %[[LOCAL_LOAD_35]], %[[SELECT_33]] +// AMD: scf.yield %[[DOT_38]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_33]] +// AMD: } +// AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_29]], %[[IF_36]], %[[SELECT_33]] +// AMD: ttg.local_dealloc %[[LOCAL_ALLOC_0]] +// AMD: ttg.local_dealloc %[[LOCAL_ALLOC_1]] + +// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_scalar +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.return + +tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: !tt.ptr, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr) { + %82 = tt.load %arg20 : tensor<16x16x!tt.ptr, #AL> + %83 = tt.load %arg21 : !tt.ptr + %84 = arith.muli %77, %83 : i64 + %85 = tt.splat %84 : i64 -> tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 + scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr + } {tt.num_stages = 3 : i32} + tt.return %79#0 : tensor<16x16xf32, #C> +} + +// CHECK-LABEL: tt.func @indirect_bmm_scalar_dist_one +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group +// CHECK: scf.for %{{.*}} iter_args(%{{[^,]*}}, %{{[^,]*}}, %{{[^,]*}}, %[[IND_BUFFER_PREV:[^,]*]] = {{[^,]*}} +// CHECK: ttg.async_wait {{.*}} {num = 2 : i32} +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK: %[[IND_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} +// CHECK: %[[IND_BUFFER_1:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_PREV]] +// CHECK: %[[IND_BUFFER_2:.*]] = tt.splat %[[IND_BUFFER_1]] +// CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_2]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[IND_BUFFER_0]] + +// AMD-LABEL: tt.func @indirect_bmm_scalar_dist_one +// AMD-COUNT-4: tt.load +// AMD: scf.for +// AMD: tt.load +// AMD: tt.dot +// AMD: ttg.local_store +// AMD: scf.yield + +// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_scalar_dist_one + +tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: !tt.ptr, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %50 = tt.load %75 : !tt.ptr + %51 = tt.addptr %75, %c1_i32 : !tt.ptr, i32 + %79:4 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %51, %arg22 = %50) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr, i64) { + %82 = tt.load %arg20 : tensor<16x16x!tt.ptr, #AL> + %83 = tt.load %arg21 : !tt.ptr + %84 = arith.muli %77, %arg22 : i64 + %85 = tt.splat %84 : i64 -> tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 + scf.yield %90, %91, %92, %83 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr, i64 + } + tt.return %79#0 : tensor<16x16xf32, #C> +} + +// CHECK-LABEL: tt.func @indirect_bmm_vector +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group +// CHECK: scf.for +// CHECK: ttg.async_wait {{.*}} {num = 1 : i32} +// CHECK: tt.dot +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK-DAG: %[[IND_BUFFER_WAIT_TOKEN:.*]] = ttg.async_wait {{.*}} {num = 1 : i32} +// CHECK-DAG: %[[IND_BUFFER_0:.*]] = ttg.memdesc_subview +// CHECK: %[[IND_BUFFER_1:.*]] = ttg.local_load %[[IND_BUFFER_0]] token %[[IND_BUFFER_WAIT_TOKEN]] +// CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} +// CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] +// CHECK: %[[IND_BUFFER_4:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_3]] +// CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_4]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] +// CHECK: scf.yield + +// AMD-LABEL: tt.func @indirect_bmm_vector +// AMD: %[[LOCAL_ALLOC_0:.*]] = ttg.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = ttg.local_alloc +// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// AMD: %[[CMPI_5:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_6:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_7:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_8:.*]] = tt.load %{{.*}}, %[[SPLAT_7]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_5]] +// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_6]], %[[SPLAT_9]] +// AMD: %[[EXPAND_DIMS_11:.*]] = tt.expand_dims %[[LOAD_4]] {axis = 1 : i32} +// AMD: %[[BROADCAST_12:.*]] = tt.broadcast %[[EXPAND_DIMS_11]] +// AMD: %[[MULI_13:.*]] = arith.muli %{{.*}}, %[[BROADCAST_12]] +// AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %[[MULI_13]] +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_14]], %[[SPLAT_15]] +// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] +// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_16]], %[[MEMDESC_SUBVIEW_18]] +// AMD: %[[SUBI_19:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[LOAD_10]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_18]]) +// AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_48:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] +// AMD: %[[LOCAL_LOAD_50:.*]] = ttg.local_load %[[ARG11]] +// AMD: %[[LOAD_51:.*]] = tt.load %[[ADDPTR_48]] +// AMD: %[[EXPAND_DIMS_52:.*]] = tt.expand_dims %[[ARG12]] {axis = 1 : i32} +// AMD: %[[BROADCAST_53:.*]] = tt.broadcast %[[EXPAND_DIMS_52]] +// AMD: %[[MULI_54:.*]] = arith.muli %{{.*}}, %[[BROADCAST_53]] +// AMD: %[[ADDPTR_55:.*]] = tt.addptr %{{.*}}, %[[MULI_54]] +// AMD: %[[LOAD_56:.*]] = tt.load %[[ADDPTR_55]] +// AMD: %[[LOCAL_LOAD_57:.*]] = ttg.local_load %[[ARG13]] +// AMD: %[[DOT_58:.*]] = tt.dot %[[LOCAL_LOAD_50]], %[[LOCAL_LOAD_57]], %[[ARG7]] +// AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} +// AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = ttg.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] +// AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] + +// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_vector + +tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: tensor<16x!tt.ptr, #BLs1>, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %c1_i32_splat = tt.splat %c1_i32 : i32 -> tensor<16xi32, #BLs1> + %79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1>) { + %82 = tt.load %arg20 : tensor<16x16x!tt.ptr, #AL> + %83 = tt.load %arg21 : tensor<16x!tt.ptr, #BLs1> + %84 = tt.expand_dims %83 {axis=1: i32}: tensor<16xi64, #BLs1> -> tensor<16x1xi64, #BL> + %850 = tt.broadcast %84 : tensor<16x1xi64, #BL> -> tensor<16x16xi64, #BL> + %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> + scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> + } {tt.num_stages = 3 : i32} + tt.return %79#0 : tensor<16x16xf32, #C> +} + +// COMMON-LABEL: tt.func @post_load_inv +// COMMON: scf.for +// COMMON-DAG: %[[IV:.*]] = arith.index_cast +// COMMON: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32 +// COMMON: arith.index_cast +// COMMON-NOT: arith.addi %[[NEXT_IV]] +tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}, + %arg4: i32 {tt.divisibility = 16 : i32}, + %arg5: i32 {tt.divisibility = 16 : i32}, + %arg6: i32 {tt.divisibility = 16 : i32}, + %arg7: i32 {tt.divisibility = 16 : i32}, + %arg8: i32 {tt.divisibility = 16 : i32}) -> tensor<32x32xf32, #C> { + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + %84 = arith.constant 900 : index + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #C> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #AL> + %50 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #AL> + %59 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %81 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %66 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #AL> + %60 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %82 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %85:3 = scf.for %arg9 = %c0_index to %84 step %c1_index iter_args(%arg10 = %cst, %arg11 = %59, %arg12 = %81) -> (tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>) { + %130 = arith.index_cast %arg9 : index to i32 + %107 = arith.muli %130, %c32_i32 : i32 + %108 = arith.subi %arg5, %107 : i32 + %109 = tt.splat %108 : i32 -> tensor<1x32xi32, #AL> + %110 = arith.cmpi "slt", %50, %109 : tensor<1x32xi32, #AL> + %111 = tt.broadcast %110 : tensor<1x32xi1, #AL> -> tensor<32x32xi1, #AL> + %112 = tt.load %arg11, %111, %cst_0 : tensor<32x32x!tt.ptr, #AL> + %113 = tt.splat %108 : i32 -> tensor<32x1xi32, #AL> + %114 = arith.cmpi "slt", %66, %113 : tensor<32x1xi32, #AL> + %115 = tt.broadcast %114 : tensor<32x1xi1, #AL> -> tensor<32x32xi1, #AL> + %116 = tt.load %arg12, %115, %cst_0 : tensor<32x32x!tt.ptr, #AL> + %117 = ttg.convert_layout %112 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> + %118 = ttg.convert_layout %116 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> + %119 = tt.dot %117, %118, %arg10 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %131 = arith.index_cast %arg9 : index to i32 + %120 = arith.addi %131, %c1_i32 : i32 + %121 = arith.muli %120, %c32_i32 : i32 + %122 = tt.splat %121 : i32 -> tensor<32x32xi32, #AL> + %123 = tt.addptr %60, %122 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + %124 = arith.muli %121, %arg7 : i32 + %125 = tt.splat %124 : i32 -> tensor<32x32xi32, #AL> + %126 = tt.addptr %82, %125 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + scf.yield %119, %123, %126 : tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL> + } + tt.return %85#0 : tensor<32x32xf32, #C> +} + +// COMMON-LABEL: tt.func @cross_iter_dep +// TODO: enable pipelining with distance of 2 +// COMMON-NOT: ttg.async_commit_group +// COMMON: scf.for +// COMMON: scf.yield + +tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}, + %arg4: i32 {tt.divisibility = 16 : i32}, + %arg5: i32 {tt.divisibility = 16 : i32}, + %arg6: i32 {tt.divisibility = 16 : i32}, + %arg7: i32 {tt.divisibility = 16 : i32}, + %arg8: i32 {tt.divisibility = 16 : i32}) -> tensor<32x32xf32, #C> { + %c0_i32 = arith.constant 0 : index + %118 = arith.constant 32 : index + %c1_i32 = arith.constant 1 : index + %c2_i32 = arith.constant 2 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #C> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #AL> + %78 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %110 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %112 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %113 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %116 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %65 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #AL> + %88 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #AL> + %80 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> + %119:5 = scf.for %arg9 = %c0_i32 to %118 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %78, %arg12 = %110, %arg13 = %113, %arg14 = %116) -> (tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>) { + %161 = arith.index_cast %arg9 : index to i32 + %141 = arith.muli %161, %c32_i32 : i32 + %142 = arith.subi %arg5, %141 : i32 + %143 = tt.splat %142 : i32 -> tensor<1x32xi32, #AL> + %144 = arith.cmpi "slt", %65, %143 : tensor<1x32xi32, #AL> + %145 = tt.broadcast %144 : tensor<1x32xi1, #AL> -> tensor<32x32xi1, #AL> + %146 = tt.load %arg11, %145, %cst_1 : tensor<32x32x!tt.ptr, #AL> + %147 = tt.splat %142 : i32 -> tensor<32x1xi32, #AL> + %148 = arith.cmpi "slt", %88, %147 : tensor<32x1xi32, #AL> + %149 = tt.broadcast %148 : tensor<32x1xi1, #AL> -> tensor<32x32xi1, #AL> + %150 = tt.load %arg12, %149, %cst_1 : tensor<32x32x!tt.ptr, #AL> + %151 = ttg.convert_layout %146 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> + %152 = ttg.convert_layout %150 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> + %153 = tt.dot %151, %152, %arg10 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %162 = arith.index_cast %arg9 : index to i32 + %154 = arith.addi %162, %c2_i32 : i32 + %155 = arith.muli %154, %c32_i32 : i32 + %156 = tt.splat %155 : i32 -> tensor<32x32xi32, #AL> + %157 = tt.addptr %80, %156 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + %158 = arith.muli %155, %arg7 : i32 + %159 = tt.splat %158 : i32 -> tensor<32x32xi32, #AL> + %160 = tt.addptr %112, %159 : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> + scf.yield %153, %arg13, %arg14, %157, %160 : tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL>, tensor<32x32x!tt.ptr, #AL> + } + tt.return %119#0 : tensor<32x32xf32, #C> +} + +// COMMON-LABEL: tt.func @dep_arg_two_uses +// COMMON: tt.expand_dims +// COMMON: tt.expand_dims +// COMMON: tt.expand_dims %arg5 +// COMMON: %[[PTR0:.*]] = tt.splat %arg6 +// COMMON: %[[PTR1:.*]] = tt.addptr %[[PTR0]] +// COMMON-NEXT: tt.load %[[PTR1]] +tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { + %23 = arith.constant 100 : index + %c64 = arith.constant 64 : i64 + %56 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %57 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %58 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #BL}>> + %83 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %85 = tt.splat %c64 : i64 -> tensor<1x32xi64, #AL> + %86 = tt.splat %c64 : i64 -> tensor<1x32xi64, #AL> + %68 = tt.splat %arg0 : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %c32_index = arith.constant 32 : index + %c32_i32 = arith.index_cast %c32_index : index to i32 + %80 = tt.splat %arg2 : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %cst_6 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #BL> + %88 = arith.truncf %cst_6 : tensor<32x128xf32, #BL> to tensor<32x128xf16, #BL> + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #C> + %90 = tt.splat %c64 : i64 -> tensor<32x128xi64, #BL> + %92 = tt.addptr %arg1, %c32_i32 : !tt.ptr, i32 + %c0_index = arith.constant 0 : index + %91:5 = scf.for %arg19 = %c0_index to %23 step %c32_index iter_args(%arg20 = %68, %arg21 = %83, %arg22 = %92, %arg23 = %cst, %arg24 = %80) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL>) { + %1750 = arith.subi %23, %arg19 : index + %175 = arith.index_cast %1750 : index to i32 + %176 = tt.splat %175 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %177 = tt.splat %175 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #BL}>> + %178 = arith.cmpi "slt", %57, %176 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %179 = arith.cmpi "slt", %58, %177 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #BL}>> + %180 = tt.expand_dims %178 {axis = 0 : i32} : tensor<32xi1, #ttg.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi1, #AL> + %181 = tt.expand_dims %179 {axis = 1 : i32} : tensor<32xi1, #ttg.slice<{dim = 1, parent = #BL}>> -> tensor<32x1xi1, #BL> + %182 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> + %183 = tt.expand_dims %arg21 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> -> tensor<1x32xi32, #AL> + %184 = arith.extsi %182 : tensor<1x32xi32, #AL> to tensor<1x32xi64, #AL> + %185 = arith.extsi %183 : tensor<1x32xi32, #AL> to tensor<1x32xi64, #AL> + %186 = arith.muli %184, %85 : tensor<1x32xi64, #AL> + %187 = arith.muli %185, %86 : tensor<1x32xi64, #AL> + %188 = tt.broadcast %186 : tensor<1x32xi64, #AL> -> tensor<128x32xi64, #AL> + %189 = tt.broadcast %187 : tensor<1x32xi64, #AL> -> tensor<128x32xi64, #AL> + %190 = tt.addptr %arg20, %188 : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi64, #AL> + %191 = tt.addptr %arg20, %189 : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi64, #AL> + %192 = tt.broadcast %180 : tensor<1x32xi1, #AL> -> tensor<128x32xi1, #AL> + %193 = tt.load %191, %192 : tensor<128x32x!tt.ptr, #AL> + %194 = tt.splat %arg22 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #AL}>> + %195 = tt.addptr %194, %56 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #AL}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>> + %196 = tt.load %195 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #AL}>> + %197 = tt.addptr %arg22, %c32_i32 : !tt.ptr, i32 + %198 = tt.broadcast %181 : tensor<32x1xi1, #BL> -> tensor<32x128xi1, #BL> + %199 = tt.load %arg24, %198, %88 : tensor<32x128x!tt.ptr, #BL> + %200 = ttg.convert_layout %193 : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> + %201 = ttg.convert_layout %199 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> + %202 = tt.dot %200, %201, %arg23 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> + %203 = tt.addptr %arg24, %90 : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi64, #BL> + scf.yield %190, %196, %197, %202, %203 : tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL> + } + tt.return %91#3 : tensor<128x128xf32, #C> +} +} // end module + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [0, 1]}> +#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +// COMMON-LABEL: tt.func @load_two_users_incompatible_layouts + tt.func @load_two_users_incompatible_layouts(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { + %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // check that the load didn't get pipelined. + // COMMON-NOT: alloc + // COMMON: scf.for + %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { + %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> + %19 = ttg.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = ttg.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> + %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #smem> -> !ttg.memdesc<16x64xf16, #shared1, #smem> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } + tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } +} + +// ----- + +// CHECK-LABEL: nested_loops +// CHECK: scf.for +// CHECK: ttg.local_alloc +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_commit_group +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: ttg.async_wait {num = 0 : i32} + +// AMD-LABEL: tt.func public @nested_loops +// AMD: scf.for +// AMD: ttg.local_alloc +// AMD-NOT: ttg.local_alloc +// AMD: scf.for +// AMD: scf.yield +// AMD-DIS: scf.yield + +// +// The following code has the structure: +// +// ``` +// for { +// %a = load() +// for { +// %b = load() +// dot(%a, %b) +// } +// } +// ``` +// +// For CUDA, we pipeline the inner loop first then pipeline the outer +// loop to prefetch the async copy after the inner loop. +// For HIP, we only pipeline the inner loop for now. +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<320> : tensor<32x1xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + %c10_i32 = arith.constant 10 : i32 + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %3 = arith.muli %2, %cst_0 : tensor<32x1xi32, #blocked> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %5 = tt.addptr %4, %3 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %6 = tt.broadcast %5 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %8 = tt.splat %arg3 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + scf.for %arg4 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { + %9 = arith.muli %arg4, %c32_i32 : i32 + %10 = tt.splat %9 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %11 = tt.splat %9 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %12 = arith.addi %10, %0 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %13 = arith.addi %11, %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %14 = tt.expand_dims %12 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %15 = tt.broadcast %14 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %16 = tt.addptr %6, %15 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %17 = tt.load %16 : tensor<32x32x!tt.ptr, #blocked> + %18 = tt.expand_dims %13 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %19 = arith.muli %18, %cst_0 : tensor<32x1xi32, #blocked> + %20 = tt.addptr %7, %19 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %21 = tt.broadcast %20 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %22 = tt.addptr %8, %19 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %23 = tt.broadcast %22 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + scf.for %arg5 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { + %24 = arith.muli %arg5, %c32_i32 : i32 + %25 = tt.splat %24 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %26 = arith.addi %25, %0 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %27 = tt.expand_dims %26 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %28 = tt.broadcast %27 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %29 = tt.addptr %21, %28 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %30 = tt.load %29 : tensor<32x32x!tt.ptr, #blocked> + %31 = ttg.convert_layout %30 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %32 = ttg.convert_layout %17 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %33 = tt.dot %31, %32, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %34 = tt.addptr %23, %28 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %35 = ttg.convert_layout %33 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %34, %35 : tensor<32x32x!tt.ptr, #blocked> + } + } + tt.return + } +} // end module + + +// ----- +// CHECK: #[[$SHARED_LAYOUT:shared.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +// CHECK-LABEL: tt.func @indirect_load_shared_layout +// CHECK: scf.for +// CHECK: ttg.async_wait {{.*}} {num = 1 : i32} +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] +// CHECK: %[[IND_BUFFER_0:.*]] = ttg.memdesc_subview {{.*}} : !ttg.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #smem, mutable> -> !ttg.memdesc<16xi64, #[[$SHARED_LAYOUT]], #smem, mutable, 1x16> +// CHECK: %[[IND_BUFFER_1:.*]] = ttg.local_load %[[IND_BUFFER_0]] +// CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} +// CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] +// CHECK: %[[IND_BUFFER_4:.*]] = arith.muli {{.*}}, %[[IND_BUFFER_3]] +// CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_4]] +// CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_0]] + +// AMD-DIS: #[[$SHARED_LAYOUT:shared.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +// AMD-LABEL: tt.func @indirect_load_shared_layout +// AMD: %[[LOCAL_ALLOC_0:.*]] = ttg.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = ttg.local_alloc +// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) +// AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_48:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] +// AMD: %[[LOCAL_LOAD_50:.*]] = ttg.local_load %[[ARG11]] +// AMD: %[[LOAD_51:.*]] = tt.load %[[ADDPTR_48]] +// AMD: %[[EXPAND_DIMS_52:.*]] = tt.expand_dims %[[ARG12]] {axis = 1 : i32} +// AMD: %[[BROADCAST_53:.*]] = tt.broadcast %[[EXPAND_DIMS_52]] +// AMD: %[[MULI_54:.*]] = arith.muli %{{.*}}, %[[BROADCAST_53]] +// AMD: %[[ADDPTR_55:.*]] = tt.addptr %{{.*}}, %[[MULI_54]] +// AMD: %[[LOAD_56:.*]] = tt.load %[[ADDPTR_55]] +// AMD: %[[LOCAL_LOAD_57:.*]] = ttg.local_load %[[ARG13]] +// AMD: %[[DOT_58:.*]] = tt.dot %[[LOCAL_LOAD_50]], %[[LOCAL_LOAD_57]], %[[ARG7]] +// AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} +// AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] +// AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] +// AMD: } +// AMD: %[[CMPI_21:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[CMPI_22:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_23:.*]] = tt.addptr %{{.*}}#1, %{{.*}} +// AMD: %[[SPLAT_24:.*]] = tt.splat %[[CMPI_22]] +// AMD: %[[LOAD_25:.*]] = tt.load %[[ADDPTR_23]], %[[SPLAT_24]] +// AMD: %[[LOCAL_LOAD_26:.*]] = ttg.local_load %{{.*}}#4 +// AMD: %[[EXPAND_DIMS_27:.*]] = tt.expand_dims %{{.*}}#5 {axis = 1 : i32} +// AMD: %[[BROADCAST_28:.*]] = tt.broadcast %[[EXPAND_DIMS_27]] +// AMD: %[[MULI_29:.*]] = arith.muli %{{.*}}, %[[BROADCAST_28]] +// AMD: %[[ADDPTR_30:.*]] = tt.addptr %{{.*}}, %[[MULI_29]] +// AMD: %[[SPLAT_31:.*]] = tt.splat %[[CMPI_22]] +// AMD: %[[LOAD_32:.*]] = tt.load %[[ADDPTR_30]], %[[SPLAT_31]] +// AMD: %[[LOCAL_LOAD_33:.*]] = ttg.local_load %{{.*}}#6 +// AMD: %[[IF_34:.*]] = scf.if %[[CMPI_21]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_26]], %[[LOCAL_LOAD_33]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_45]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } +// AMD: %[[ADDI_35:.*]] = arith.addi %{{.*}}#3, %{{.*}} +// AMD: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} +// AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_25]], %[[MEMDESC_SUBVIEW_38]] +// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: ttg.local_store %[[LOAD_32]], %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_21]], %[[IF_34]], %{{.*}}#0 +// AMD: %[[LOCAL_LOAD_41:.*]] = ttg.local_load %[[MEMDESC_SUBVIEW_38]] +// AMD: %[[LOCAL_LOAD_42:.*]] = ttg.local_load %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[IF_43:.*]] = scf.if %[[CMPI_22]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_41]], %[[LOCAL_LOAD_42]], %[[SELECT_40]] +// AMD: scf.yield %[[DOT_45]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_40]] +// AMD: } +// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_22]], %[[IF_43]], %[[SELECT_40]] +// AMD: ttg.local_dealloc %{{.*}} +// AMD: ttg.local_dealloc %{{.*}} + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: tensor<16x!tt.ptr, #BLs1>, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %c1_i32_splat = tt.splat %c1_i32 : i32 -> tensor<16xi32, #BLs1> + %79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1>) { + %82 = tt.load %arg20 : tensor<16x16x!tt.ptr, #AL> + %83 = tt.load %arg21 : tensor<16x!tt.ptr, #BLs1> + %84 = tt.expand_dims %83 {axis=1: i32}: tensor<16xi64, #BLs1> -> tensor<16x1xi64, #BL> + %850 = tt.broadcast %84 : tensor<16x1xi64, #BL> -> tensor<16x16xi64, #BL> + %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> + scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> + } {tt.num_stages = 3 : i32} + tt.return %79#0 : tensor<16x16xf32, #C> +} +} + + +// ----- + +// CHECK-LABEL: @kernel_yield_constant +// CHECK: ttg.async_copy_global_to_local +// CHECK: scf.for +// CHECK: ttg.memdesc_subview +// CHECK: ttg.async_copy_global_to_local +// CHECK: tt.return + +// AMD-LABEL: @kernel_yield_constant +// AMD: tt.load +// AMD: ttg.memdesc_subview +// AMD: ttg.local_store +// AMD: scf.for +// AMD: tt.load +// AMD: ttg.memdesc_subview +// AMD: ttg.local_store +// AMD: tt.return +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func public @kernel_yield_constant(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst1 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + %c32_i32 = arith.constant 32 : i32 + %c31_i32 = arith.constant 31 : i32 + %cst_1 = arith.constant dense<2.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %0 = tt.get_program_id x : i32 + %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %12 = arith.addi %arg4, %c31_i32 : i32 + %13 = arith.divsi %12, %c32_i32 : i32 + %14 = tt.expand_dims %7 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %22 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %34 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %42 = scf.for %arg7 = %c0_i32 to %13 step %c1_i32 iter_args(%arg8 = %cst) -> (tensor<32x32xf32, #mma>) : i32 { + %43 = arith.muli %arg7, %c32_i32 : i32 + %44 = arith.muli %43, %arg5 : i32 + %45 = tt.splat %44 : i32 -> tensor<32x32xi32, #blocked> + %46 = tt.addptr %22, %45 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %47 = arith.subi %arg4, %43 : i32 + %48 = tt.splat %47 : i32 -> tensor<32x1xi32, #blocked> + %49 = arith.cmpi slt, %14, %48 : tensor<32x1xi32, #blocked> + %50 = tt.broadcast %49 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + %51 = tt.load %46, %50, %cst_0 : tensor<32x32x!tt.ptr, #blocked> + %52 = ttg.convert_layout %51 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %53 = tt.dot %cst_1, %52, %arg8 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %54 = ttg.convert_layout %53 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %34, %54 : tensor<32x32x!tt.ptr, #blocked> + scf.yield %cst1 : tensor<32x32xf32, #mma> + } + tt.return + } +} + + +// ----- + +// CHECK-LABEL: @add_kernel +// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +// CHECK: %[[ABUFFER:.*]] = ttg.local_alloc +// CHECK: %[[BBUFFER:.*]] = ttg.local_alloc +// CHECK: %[[A0BUFFER:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[A0BUFFER]] +// CHECK: %[[B0BUFFER:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[B0BUFFER]] +// CHECK: %[[A1BUFFER:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[A1BUFFER]] +// CHECK: %[[B1BUFFER:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] +// CHECK: ttg.async_copy_global_to_local {{.*}}, %[[B1BUFFER]] +// CHECK: scf.for + +// AMD-LABEL: tt.func public @add_kernel +// AMD: %[[LOAD_11:.*]] = tt.load %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[LOAD_13:.*]] = tt.load %[[ADDPTR_12]], %{{.*}} +// AMD: %[[ADDI_14:.*]] = arith.addi %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[ADDI_14]] +// AMD: %[[ADDI_16:.*]] = arith.addi %[[SPLAT_15]], %{{.*}} +// AMD: %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_16]], %{{.*}} +// AMD: %[[ADDPTR_18:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] +// AMD: %[[LOAD_19:.*]] = tt.load %[[ADDPTR_18]], %[[CMPI_17]] +// AMD: %[[ADDPTR_20:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] +// AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_20]], %[[CMPI_17]] +// AMD: scf.for +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1016800_i32 = arith.constant 1016800 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1016800_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + scf.for %arg4 = %c0_i32 to %c1016800_i32 step %c1024_i32 : i32 { + %7 = arith.addi %1, %arg4 : i32 + %8 = tt.splat %7 : i32 -> tensor<1024xi32, #blocked> + %9 = arith.addi %8, %2 : tensor<1024xi32, #blocked> + %10 = arith.cmpi slt, %9, %3 : tensor<1024xi32, #blocked> + %11 = tt.addptr %4, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %10 : tensor<1024x!tt.ptr, #blocked> + %13 = tt.addptr %5, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %14 = tt.load %13, %10 : tensor<1024x!tt.ptr, #blocked> + %15 = arith.addf %12, %14 : tensor<1024xf32, #blocked> + %16 = tt.addptr %6, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %16, %15, %10 : tensor<1024x!tt.ptr, #blocked> + } {tt.num_stages = 3 : i32} + tt.return + } +} + + +// ----- + +// CHECK-LABEL: @nested_loops +// CHECK: tt.addptr %{{.*}}, {{.*}} +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} +// CHECK: scf.for +// CHECK: %[[LOAD_1:.*]] = tt.load %[[NEXT_BUFFER_1]] +// CHECK: %[[BUFFER_2:.*]] = ttg.local_alloc %[[LOAD_1]] +// CHECK: %[[TRANS:.*]] = ttg.memdesc_trans %[[BUFFER_2]] +// CHECK: %[[LOCAL_LOAD_1:.*]] = ttg.local_load %[[TRANS]] +// CHECK: %[[BUFFER_1:.*]] = ttg.local_alloc : () +// CHECK: %[[SUBVIEW_1:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_1:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_1]] +// CHECK: ttg.async_commit_group %[[ASYNC_COPY_1]] +// CHECK: %[[SUBVIEW_2:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_2:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_2]] +// CHECK: ttg.async_commit_group %[[ASYNC_COPY_2]] +// CHECK: scf.for +// CHECK: ttg.async_wait +// CHECK: ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[LOCAL_LOAD_2:.*]] = ttg.local_load +// CHECK: %[[DOT:.*]] = tt.dot %[[LOCAL_LOAD_2]], %[[LOCAL_LOAD_1]] +// CHECK: %[[CONVERT_LAYOUT_3:.*]] = ttg.convert_layout %[[DOT]] +// CHECK: %[[SUBVIEW_4:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_3:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_4]] +// CHECK: ttg.async_commit_group %[[ASYNC_COPY_3]] +// CHECK: ttg.local_dealloc %[[BUFFER_1]] + +// AMD-LABEL: tt.func public @nested_loops +// AMD-NOT: ttg.local_alloc +// AMD: scf.for +// AMD: ttg.local_alloc +// AMD: scf.for +// AMD: ttg.local_load +// AMD: tt.dot +// AMD: ttg.local_store +// AMD: scf.yield +// AMD: ttg.local_dealloc + +// AMD_PREFETCH-LABEL: tt.func public @nested_loops +// AMD_PREFETCH-NOT: ttg.local_alloc +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: ttg.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: ttg.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: ttg.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: ttg.local_dealloc + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> +#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { + tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<16> : tensor<16x1xi32, #blocked> + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %2 = arith.muli %1, %cst_0 : tensor<16x1xi32, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<16x1x!tt.ptr, #blocked>, tensor<16x1xi32, #blocked> + %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %7 = tt.broadcast %4 : tensor<16x1x!tt.ptr, #blocked> -> tensor<16x16x!tt.ptr, #blocked> + %8 = tt.broadcast %6 : tensor<1x16xi32, #blocked> -> tensor<16x16xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> + scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { + %10 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> + %11 = ttg.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !ttg.memdesc<16x16xf32, #shared, #smem> + %12 = ttg.memdesc_trans %11 {order = array} : !ttg.memdesc<16x16xf32, #shared, #smem> -> !ttg.memdesc<16x16xf32, #shared1, #smem> + %13 = ttg.local_load %12 : !ttg.memdesc<16x16xf32, #shared1, #smem> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { + %14 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> + %15 = ttg.convert_layout %14 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %16 = tt.dot %15, %13, %cst : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma> + %17 = ttg.convert_layout %16 : tensor<16x16xf32, #mma> -> tensor<16x16xf32, #blocked> + tt.store %9, %17 : tensor<16x16x!tt.ptr, #blocked> + } + } + tt.return + } +} + +// ----- + + // CHECK-LABEL: @int4_matmul_ampere +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [16, 1, 2], threadsPerWarp = [4, 8, 1], warpsPerCTA = [1, 8, 1], order = [2, 0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [16, 2, 1], threadsPerWarp = [4, 1, 8], warpsPerCTA = [1, 1, 8], order = [1, 0, 2]}> +#blocked5 = #ttg.blocked<{sizePerThread = [32, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + tt.func public @int4_matmul_ampere( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) -> tensor<16x256xf32, #mma> attributes {noinline = false} { + %cst = arith.constant dense<64> : tensor<64x256xi32, #blocked> + %cst_0 = arith.constant dense<128> : tensor<16x128xi32, #blocked1> + %c256_i32 = arith.constant 256 : i32 + %c16_i32 = arith.constant 16 : i32 + %c128_i32 = arith.constant 128 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<16x128xf16, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c255_i32 = arith.constant 255 : i32 + %c15_i32 = arith.constant 15 : i32 + %cst_2 = arith.constant dense<4> : tensor<64x256xi8, #blocked> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<16x256xf32, #mma> + + %35 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %36 = tt.expand_dims %35 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %38 = tt.broadcast %36 : tensor<1x128xi32, #blocked1> -> tensor<16x128xi32, #blocked1> + %40 = tt.splat %arg0 : !tt.ptr -> tensor<16x128x!tt.ptr, #blocked1> + %41 = tt.addptr %40, %38 : tensor<16x128x!tt.ptr, #blocked1>, tensor<16x128xi32, #blocked1> + + %42 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %43 = tt.expand_dims %42 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %47 = tt.broadcast %43 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> + %50 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked> + %51 = tt.addptr %50, %47 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + + // Check that both loads in the loop are pipelined. + // CHECK: scf.for + // CHECK-NOT: tt.load + // CHECK: ttg.async_copy_global_to_local + // CHECK-NOT: tt.load + // CHECK: ttg.async_copy_global_to_local + // CHECK-NOT: tt.load + // CHECK: scf.yield + %54:3 = scf.for %arg9 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg10 = %cst_3, %arg11 = %41, %arg12 = %51) -> (tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>) : i32 { + %78 = tt.load %arg11 : tensor<16x128x!tt.ptr, #blocked1> + %79 = tt.load %arg12 : tensor<64x256x!tt.ptr, #blocked> + %80 = arith.shli %79, %cst_2 : tensor<64x256xi8, #blocked> + %81 = arith.shrsi %80, %cst_2 : tensor<64x256xi8, #blocked> + %82 = arith.shrsi %79, %cst_2 : tensor<64x256xi8, #blocked> + %83 = arith.sitofp %81 : tensor<64x256xi8, #blocked> to tensor<64x256xf16, #blocked> + %84 = arith.sitofp %82 : tensor<64x256xi8, #blocked> to tensor<64x256xf16, #blocked> + %85 = tt.join %83, %84 : tensor<64x256xf16, #blocked> -> tensor<64x256x2xf16, #blocked3> + %86 = tt.trans %85 {order = array} : tensor<64x256x2xf16, #blocked3> -> tensor<64x2x256xf16, #blocked4> + %87 = tt.reshape %86 : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> + %88 = ttg.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %89 = ttg.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %90 = tt.dot %88, %89, %arg10 : tensor<16x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma> + %91 = tt.addptr %arg11, %cst_0 : tensor<16x128x!tt.ptr, #blocked1>, tensor<16x128xi32, #blocked1> + %92 = tt.addptr %arg12, %cst : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> + scf.yield %90, %91, %92 : tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked> + } + tt.return %54#0 : tensor<16x256xf32, #mma> + } +} + + +// ----- + +// This test triggered some failure in the verifier, so we only +// included a simple check for the kernel name. +// COMMON-LABEL: @load_convert_layout +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, + %76: index, + %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, + %75: tensor<16x!tt.ptr, #BLs1>, + %78: tensor<16x16xi32, #AL> {tt.constancy=16: i32, tt.divisibility=16: i32}, + %60: tensor<16x16x!tt.ptr, #BL> {tt.divisibility=16: i32, tt.contiguity=16 : i32}) -> tensor<16x16xf32, #C>{ + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #BLs1> + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #C> + %cst_0 = arith.constant dense<2> : tensor<16xi32, #BLs1> + %c4_i32 = arith.constant 4 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %c1_i32_splat = tt.splat %c1_i32 : i32 -> tensor<16xi32, #BLs1> + %15 = arith.cmpi slt, %1, %cst_0 : tensor<16xi32, #BLs1> + %79:3 = scf.for %arg18 = %c0 to %76 step %c1 iter_args(%arg19 = %cst, %arg20 = %49, %arg21 = %75) -> (tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1>) { + %82 = tt.load %arg20 : tensor<16x16x!tt.ptr, #AL> + %83 = tt.load %arg21, %15 : tensor<16x!tt.ptr, #BLs1> + %84 = tt.expand_dims %83 {axis=1: i32}: tensor<16xi64, #BLs1> -> tensor<16x1xi64, #BL> + %850 = tt.broadcast %84 : tensor<16x1xi64, #BL> -> tensor<16x16xi64, #BL> + %85 = arith.muli %77, %850 : tensor<16x16xi64, #BL> + %86 = tt.addptr %60, %85 : tensor<16x16x!tt.ptr, #BL>, tensor<16x16xi64, #BL> + %87 = tt.load %86 : tensor<16x16x!tt.ptr, #BL> + %88 = ttg.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> + %89 = ttg.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> + %90 = tt.dot %88, %89, %arg19 : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> + %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> + scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> + } {tt.num_stages = 3 : i32} + tt.return %79#0 : tensor<16x16xf32, #C> +} +} + + +// ----- + +// This test captured some ICE in MatmulLoopPipeline pass, so we only +// included a simple check for the kernel name. +// COMMON-LABEL: @matmul_indirect_pipeline +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { + tt.func public @matmul_indirect_pipeline(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %3 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %4 = tt.broadcast %2 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked> + %5 = tt.broadcast %3 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %6 = arith.addi %4, %5 : tensor<32x32xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %8 = tt.addptr %7, %6 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %9 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> + %10 = tt.splat %arg3 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %11 = tt.addptr %10, %6 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> + %13 = tt.addptr %12, %0 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> + scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { + %15 = tt.load %13 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> + %16 = tt.addptr %14, %15 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<32xi64, #ttg.slice<{dim = 0, parent = #blocked}>> + %17 = tt.load %16 : tensor<32x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> + %18 = tt.expand_dims %17 {axis = 0 : i32} : tensor<32xf32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xf32, #blocked> + %19 = tt.broadcast %18 : tensor<1x32xf32, #blocked> -> tensor<32x32xf32, #blocked> + %20 = arith.addf %9, %19 : tensor<32x32xf32, #blocked> + %21 = ttg.convert_layout %9 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %22 = ttg.convert_layout %20 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %23 = tt.dot %21, %22, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %24 = ttg.convert_layout %23 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %11, %24 : tensor<32x32x!tt.ptr, #blocked> + } {tt.num_stages = 3 : i32} + tt.return + } +} + +// ----- + +// COMMON-LABEL: @dont_pipeline_128x1 +// AMD-NOT: local_load{{.*}}128x1 +// CHECK: local_load{{.*}}128x1 +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func public @dont_pipeline_128x1(%arg6: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst_4 = arith.constant dense<-1.000000e+30> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + + %99:1 = scf.for %arg25 = %c0_i32 to %c128_i32 step %c64_i32 iter_args(%arg31 = %cst_4) -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) : i32 { + %94 = tt.splat %arg6 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %151 = tt.load %94 : tensor<128x1x!tt.ptr, #blocked> + %161 = ttg.convert_layout %151 : tensor<128x1xi32, #blocked> -> tensor<128x1xi32, #mma> + %162 = tt.broadcast %161 : tensor<128x1xi32, #mma> -> tensor<128x64xi32, #mma> + %170 = arith.sitofp %162 : tensor<128x64xi32, #mma> to tensor<128x64xf32, #mma> + + %173 = "tt.reduce"(%170) <{axis = 1 : i32}> ({ + ^bb0(%arg33: f32, %arg34: f32): + %207 = arith.maxnumf %arg33, %arg34 : f32 + tt.reduce.return %207 : f32 + }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %175 = arith.maxnumf %arg31, %173 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + + %201 = arith.truncf %170 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + %202 = ttg.convert_layout %201 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + + %192 = arith.constant dense<0.> : tensor<128x64xf32, #mma> + %203 = arith.constant dense<0.> : tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %204 = tt.dot %202, %203, %192 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + + scf.yield %175 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } + tt.return + } +} + +// ----- + +// Check that the dependencies across ops of different nesting does not cause crash or +// incorrect schedule that fails to pipeline. +// COMMON-LABEL: @matmul_nested_ops +// COMMON: ttg.local_load + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#BLs1 = #ttg.slice<{parent=#BL, dim=1}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}, + %ext : index) -> tensor<128x128xf32, #C> { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + + %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x128xf32, #C>) { + %cnd = arith.cmpi slt, %iv, %ext : index + %inc_a_ptr = scf.if %cnd -> (tensor<128x32x!tt.ptr, #AL>) { + %a_ptr_ = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + scf.yield %a_ptr_ : tensor<128x32x!tt.ptr, #AL> + } else { + scf.yield %a_ptr : tensor<128x32x!tt.ptr, #AL> + } + %a_ = tt.load %inc_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %inc_a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + scf.yield %next_a_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<128x128xf32, #C> + } + tt.return %loop#1: tensor<128x128xf32, #C> +} +} + +// ----- + +// CHECK-LABEL: @masked_add_kernel +// CHECK: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> +// CHECK: scf.for +// CHECK: %[[A:.*]] = ttg.local_load +// CHECK: arith.select {{.*}}, %[[A]], %[[CONSTANT]] +// CHECK: %[[B:.*]] = ttg.local_load +// CHECK: arith.select {{.*}}, %[[B]], %[[CONSTANT]] + +// AMD-LABEL: @masked_add_kernel +// AMD: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: scf.for +// AMD: arith.select +// AMD: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: arith.addf +// AMD: tt.store +// AMD: scf.yield +// AMD: tt.store +// AMD: tt.store + +// AMD_PREFETCH-LABEL: @masked_add_kernel +// AMD_PREFETCH: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> +// AMD_PREFETCH-COUNT-4: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: arith.select +// AMD_PREFETCH: arith.addf +// AMD_PREFETCH: tt.store +// AMD_PREFETCH: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD_PREFETCH: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.store +// AMD_PREFETCH: tt.store + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func public @masked_add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1016800_i32 = arith.constant 1016800 : i32 + %cst = arith.constant dense<0xFF800000> : tensor<1024xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1016800_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + scf.for %arg4 = %c0_i32 to %c1016800_i32 step %c1024_i32 : i32 { + %7 = arith.addi %1, %arg4 : i32 + %8 = tt.splat %7 : i32 -> tensor<1024xi32, #blocked> + %9 = arith.addi %8, %2 : tensor<1024xi32, #blocked> + %10 = arith.cmpi slt, %9, %3 : tensor<1024xi32, #blocked> + %11 = tt.addptr %4, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %10, %cst : tensor<1024x!tt.ptr, #blocked> + %13 = tt.addptr %5, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %14 = tt.load %13, %10, %cst : tensor<1024x!tt.ptr, #blocked> + %15 = arith.addf %12, %14 : tensor<1024xf32, #blocked> + %16 = tt.addptr %6, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %16, %15, %10 : tensor<1024x!tt.ptr, #blocked> + }{tt.num_stages = 3 : i32} + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/loop-schedule.mlir b/third_party/enflame/include/triton/test/TritonGPU/loop-schedule.mlir new file mode 100644 index 000000000..1b198b669 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/loop-schedule.mlir @@ -0,0 +1,228 @@ +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-test-pipeline-assign-latencies=num-stages=3 -tritongpu-test-pipeline-schedule-loop | FileCheck %s + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#ALs0 = #ttg.slice<{parent=#AL, dim=0}> +#BLs0 = #ttg.slice<{parent=#BL, dim=0}> +#CLs0 = #ttg.slice<{parent=#C, dim=0}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABLE: @matmul_loop_load_acc +// CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} +// CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} +// CHECK: tt.load %{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +tt.func @matmul_loop_load_acc(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}, + %C : !tt.ptr {tt.divisibility = 16 : i32}, + %c_init: tensor<128x128xf32, #C>) -> tensor<128x128xf32, #C> { + + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // C ptrs + %c_ptr_splat = tt.splat %C : !tt.ptr -> tensor<128x128x!tt.ptr, #C> + %c_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #CLs0> + %c_tmp1 = tt.expand_dims %c_tmp0 {axis = 0 : i32} : tensor<128xi32, #CLs0> -> tensor<1x128xi32, #C> + %c_offs = tt.broadcast %c_tmp1 : tensor<1x128xi32, #C> -> tensor<128x128xi32, #C> + %c_ptr_init = tt.addptr %c_ptr_splat, %c_offs : tensor<128x128x!tt.ptr, #C>, tensor<128x128xi32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + %c_off = arith.constant dense<4> : tensor<128x128xi32, #C> + + %loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %c_ptr = %c_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128x!tt.ptr, #C>, tensor<128x128xf32, #C>) { + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + %c_ = tt.load %c_ptr : tensor<128x128x!tt.ptr, #C> + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_c_ptr = tt.addptr %c_ptr, %c_off : tensor<128x128x!tt.ptr, #C>, tensor<128x128xi32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_c_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128x!tt.ptr, #C>, tensor<128x128xf32, #C> + } + tt.return %loop#3: tensor<128x128xf32, #C> +} +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: @fused_loop +tt.func public @fused_loop(%arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}) { + %c10_i32 = arith.constant 10 : i32 + %false = arith.constant false + %0 = ub.poison : !tt.tensordesc> + %cst = arith.constant dense<0> : tensor<128x1xi64, #blocked> + %c-1_i32 = arith.constant -1 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + + %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %3 = arith.extsi %arg7 : i32 to i64 + %4 = tt.make_tensor_descriptor %arg5, [%arg7, %arg7], [%3, %c1_i64] : , > + %5 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> + %7 = tt.splat %3 : i64 -> tensor<128x1xi64, #blocked> + + // CHECK: scf.for + %8:9 = scf.for %arg29 = %c0_i32 to %arg7 step %c1_i32 iter_args(%arg30 = %c-1_i32, %arg31 = %4, %arg32 = %c0_i32, %arg33 = %arg5, %arg34 = %cst_0, %arg35 = %c0_i32, %arg36 = %cst, %arg37 = %0, %arg38 = %false) -> (i32, !tt.tensordesc>, i32, !tt.ptr, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc>, i1) : i32 { + %9 = arith.addi %arg30, %c1_i32 : i32 + %10 = arith.cmpi eq, %arg30, %c10_i32 : i32 + %11 = arith.select %10, %c0_i32, %9 : i32 + %12 = arith.cmpi eq, %11, %c0_i32 : i32 + + // This op is a distance 1 dependency of itself. + // CHECK: {_test_marker_0, loop.cluster = 4 : i32, loop.stage = 0 : i32} + %13 = arith.select %12, %c0_i32, %arg32 {_test_marker_0} : i32 + + %14 = arith.select %12, %arg31, %arg37 : !tt.tensordesc> + %15 = arith.select %12, %c10_i32, %arg35 : i32 + %16 = scf.if %12 -> (tensor<128x1xi64, #blocked>) { + %32 = arith.muli %cst, %7 : tensor<128x1xi64, #blocked> + scf.yield %32 : tensor<128x1xi64, #blocked> + } else { + scf.yield %arg36 : tensor<128x1xi64, #blocked> + } + %17 = tt.splat %arg33 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %18 = tt.addptr %17, %16 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi64, #blocked> + %19 = tt.broadcast %18 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %20 = tt.addptr %19, %5 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %21 = tt.addptr %arg33, %c64_i32 : !tt.ptr, i32 + %22 = tt.load %20 : tensor<128x64x!tt.ptr, #blocked> + %23 = ttg.local_alloc %22 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %24 = arith.muli %13, %c64_i32 : i32 + %25 = tt.experimental_descriptor_load %14[%24, %15] : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> + %26 = ttg.local_alloc %25 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem> + %27 = ttng.warp_group_dot %23, %26, %arg34, %arg38 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> + %28 = arith.addi %13, %c1_i32 : i32 + + // This op is in the backward slice of `_test_marker_2` and the epilogue. + // CHECK: {_test_marker_1, loop.cluster = 3 : i32, loop.stage = 1 : i32} + %29 = arith.cmpi eq, %11, %c10_i32 {_test_marker_1} : i32 + + // CHECK: {_test_marker_2, loop.cluster = 3 : i32, loop.stage = 1 : i32} + %30 = arith.select %29, %arg5, %21 {_test_marker_2} : !tt.ptr + + %31 = arith.cmpi ne, %11, %c10_i32 : i32 + + scf.if %29 { + "use"(%27) : (tensor<128x256xf32, #mma>) -> () + // CHECK: {_test_marker_3, loop.cluster = 5 : i32, loop.stage = 2 : i32} + } {_test_marker_3} + scf.yield %11, %14, %28, %30, %27, %15, %16, %14, %31 : i32, !tt.tensordesc>, i32, !tt.ptr, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc>, i1 + } + tt.return +} + +} + +// ----- + +// CHECK-LABEL: @prologue_backward_slice +tt.func @prologue_backward_slice(%ub: i32, %cond: i1) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + // CHECK: scf.for + scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 { + // CHECK: scf.if + %0 = scf.if %cond -> i32 { + scf.yield %c0_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + // CHECK: loop.cluster = 0 : i32, loop.stage = 0 : i32 + + // CHECK: op.with_region + %1 = "op.with_region"() ({ + "use"(%0) : (i32) -> () + }) : () -> i32 + // CHECK: loop.cluster = 1 : i32, loop.stage = 0 : i32 + + // CHECK: op.with_region + "op.with_region"() ({ + "use"(%1) : (i32) -> () + }) {tt_latency = 2 : i32} : () -> () + // CHECK: loop.cluster = 1 : i32, loop.stage = 0 : i32 + + } {tt.num_stages = 3 : i32} + + tt.return +} + +// ----- + +// CHECK-LABEL: @epilogue_forward_slice +tt.func @epilogue_forward_slice(%ub: i32, %cond: i1) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + // CHECK: scf.for + scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 { + // CHECK: "latency.op"() {loop.cluster = 3 : i32, loop.stage = 0 : i32 + %0 = "latency.op"() {tt_latency = 2 : i32} : () -> i32 + // CHECK: scf.if + %1 = scf.if %cond -> i32 { + scf.yield %0 : i32 + } else { + scf.yield %c0_i32 : i32 + } + // CHECK: {loop.cluster = 1 : i32, loop.stage = 2 : i32} + + // CHECK: "use"(%{{.*}}) {loop.cluster = 1 : i32, loop.stage = 2 : i32} + "use"(%1) : (i32) -> () + + } {tt.num_stages = 3 : i32} + + tt.return +} + +// ----- + +// CHECK-LABEL: @prologue_latency +tt.func @prologue_latency(%ub: i32, %cond: i1) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + // CHECK: scf.for + scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 { + // CHECK: "some.op"() {loop.cluster = 0 : i32, loop.stage = 0 : i32} + %0 = "some.op"() : () -> i32 + // CHECK: scf.if + %1 = scf.if %cond -> i32 { + scf.yield %0 : i32 + } else { + scf.yield %c0_i32 : i32 + } {tt_latency = 2 : i32} + // CHECK: loop.cluster = 0 : i32, loop.stage = 0 : i32 + + } {tt.num_stages = 3 : i32} + + tt.return +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/matmul-loop-pipeline.mlir b/third_party/enflame/include/triton/test/TritonGPU/matmul-loop-pipeline.mlir new file mode 100644 index 000000000..8416bf739 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/matmul-loop-pipeline.mlir @@ -0,0 +1,79 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: @softmax_kernel +tt.func public @softmax_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0xFF800000> : tensor<128xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = tt.get_num_programs x : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> + %3 = tt.splat %arg5 : i32 -> tensor<128xi32, #blocked> + // CHECK: [[MASK:%.*]] = arith.cmpi slt, {{.*}} tensor<128xi32, + %4 = arith.cmpi slt, %2, %3 : tensor<128xi32, #blocked> + // CHECK: scf.for + scf.for %arg6 = %0 to %arg4 step %1 : i32 { + %5 = tt.splat %arg1 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.ptr -> tensor<128x!tt.ptr, #blocked> + %6 = tt.addptr %5, %2 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr, #blocked>, tensor<128xi32, #blocked> + // CHECK: [[RESULT:%.*]] = ttg.local_load + // CHECK-NEXT: arith.select [[MASK]], [[RESULT]], %cst + %7 = tt.load %6, %4, %cst {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr, #blocked> + %8 = tt.splat %arg0 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : !tt.ptr -> tensor<128x!tt.ptr, #blocked> + %9 = tt.addptr %8, %2 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : tensor<128x!tt.ptr, #blocked>, tensor<128xi32, #blocked> + tt.store %9, %7, %4 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : tensor<128x!tt.ptr, #blocked> + } {tt.num_stages = 2 : i32} + tt.return +} + +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90"} { + +// CHECK-LABEL: @scalar_load +tt.func public @scalar_load(%arg0: !tt.ptr, %arg1: i32, %arg2: i32, %arg3: f32) -> f32 { + %c1_i32 = arith.constant 1 : i32 + %2 = scf.for %i = %arg1 to %arg2 step %c1_i32 iter_args(%k = %arg3) -> f32 : i32 { + // CHECK: tt.load %arg0 + %0 = tt.load %arg0 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : !tt.ptr + %1 = arith.addf %0, %k {loop.cluster = 1 : i32, loop.stage = 0 : i32} : f32 + %2 = arith.addf %1, %k {loop.cluster = 0 : i32, loop.stage = 1 : i32} : f32 + scf.yield %2 : f32 + } {num_stages = 2 : i32} + tt.return %2 : f32 +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90"} { + +// CHECK-LABEL: @make_tensor_desc_epilogue +tt.func public @make_tensor_desc_epilogue(%arg0: i32, %arg1: !tt.ptr, %arg2: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c1_i64 = arith.constant 1 : i64 + // CHECK: scf.for + scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 : i32 { + %1 = tt.splat %arg1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked> + %2 = tt.load %1 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : tensor<128x256x!tt.ptr, #blocked> + %3 = arith.addf %2, %2 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : tensor<128x256xf32, #blocked> + %4 = arith.cmpi eq, %arg3, %c1_i32 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i32 + // CHECK: scf.if + scf.if %4 { + // CHECK-NOT: tt.make_tensor_descriptor + // CHECK: tt.experimental_tensormap_create + // CHECK-NEXT: tt.experimental_tensormap_fenceproxy_acquire + %5 = tt.make_tensor_descriptor %arg1, [%arg2, %arg2], [%c1_i64, %c1_i64] : , > + } {loop.cluster = 5 : i32, loop.stage = 2 : i32} + } {tt.num_stages = 3 : i32} + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/matmul.mlir b/third_party/enflame/include/triton/test/TritonGPU/matmul.mlir new file mode 100644 index 000000000..c4d9cf056 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/matmul.mlir @@ -0,0 +1,105 @@ +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=target=cuda:80 -tritongpu-remove-layout-conversions -tritongpu-pipeline=num-stages=3 -canonicalize -test-print-allocation 2>&1 | FileCheck %s + +// CHECK: offset = 0, size = 32768 +// CHECK: offset = 32768, size = 32768 +// CHECK: size = 65536 +module { +tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) { + %cst = arith.constant dense : tensor<64x64xi1> + %c64 = arith.constant 64 : i32 + %c0 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32> + %c64_i32 = arith.constant 64 : i32 + %c63_i32 = arith.constant 63 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.addi %arg4, %c63_i32 : i32 + %4 = arith.divsi %3, %c64_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.cmpi slt, %8, %c8_i32 : i32 + %10 = arith.select %9, %8, %c8_i32 : i32 + %11 = arith.remsi %0, %10 : i32 + %12 = arith.addi %7, %11 : i32 + %13 = arith.remsi %0, %5 : i32 + %14 = arith.divsi %13, %10 : i32 + %15 = arith.muli %12, %c64_i32 : i32 + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %17 = tt.splat %15 : i32 -> tensor<64xi32> + %18 = arith.addi %17, %16 : tensor<64xi32> + %19 = arith.muli %14, %c64_i32 : i32 + %20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %21 = tt.splat %19 : i32 -> tensor<64xi32> + %22 = arith.addi %21, %20 : tensor<64xi32> + %23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %24 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %25 = tt.splat %arg6 : i32 -> tensor<64x1xi32> + %26 = arith.muli %24, %25 : tensor<64x1xi32> + %27 = tt.expand_dims %23 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %28 = tt.splat %arg7 : i32 -> tensor<1x64xi32> + %29 = arith.muli %27, %28 : tensor<1x64xi32> + %30 = tt.broadcast %26 : tensor<64x1xi32> -> tensor<64x64xi32> + %31 = tt.broadcast %29 : tensor<1x64xi32> -> tensor<64x64xi32> + %32 = arith.addi %30, %31 : tensor<64x64xi32> + %33 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr> + %34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> + %35 = tt.expand_dims %23 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %36 = tt.splat %arg8 : i32 -> tensor<64x1xi32> + %37 = arith.muli %35, %36 : tensor<64x1xi32> + %38 = tt.expand_dims %22 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %39 = tt.splat %arg9 : i32 -> tensor<1x64xi32> + %40 = arith.muli %38, %39 : tensor<1x64xi32> + %41 = tt.broadcast %37 : tensor<64x1xi32> -> tensor<64x64xi32> + %42 = tt.broadcast %40 : tensor<1x64xi32> -> tensor<64x64xi32> + %43 = arith.addi %41, %42 : tensor<64x64xi32> + %44 = tt.splat %arg1 : !tt.ptr -> tensor<64x64x!tt.ptr> + %45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> + %47:3 = scf.for %arg12 = %c0 to %arg5 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>) : i32 { + %76 = tt.load %arg14, %cst, %cst_0 : tensor<64x64x!tt.ptr> + %77 = tt.load %arg15, %cst, %cst_0 : tensor<64x64x!tt.ptr> + %78 = tt.dot %76, %77, %cst_0 : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32> + %79 = arith.addf %arg13, %78 : tensor<64x64xf32> + %80 = arith.muli %arg7, %c64_i32 : i32 + %81 = tt.splat %80 : i32 -> tensor<64x64xi32> + %82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> + %83 = arith.muli %arg8, %c64_i32 : i32 + %84 = tt.splat %83 : i32 -> tensor<64x64xi32> + %85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> + scf.yield %79, %82, %85 : tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr> + } + %48 = arith.muli %12, %c64_i32 : i32 + %49 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %50 = tt.splat %48 : i32 -> tensor<64xi32> + %51 = arith.addi %50, %49 : tensor<64xi32> + %52 = arith.muli %14, %c64_i32 : i32 + %53 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %54 = tt.splat %52 : i32 -> tensor<64xi32> + %55 = arith.addi %54, %53 : tensor<64xi32> + %56 = tt.expand_dims %51 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %57 = tt.splat %arg10 : i32 -> tensor<64x1xi32> + %58 = arith.muli %57, %56 : tensor<64x1xi32> + %59 = tt.expand_dims %55 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %60 = tt.splat %arg11 : i32 -> tensor<1x64xi32> + %61 = arith.muli %59, %60 : tensor<1x64xi32> + %62 = tt.broadcast %58 : tensor<64x1xi32> -> tensor<64x64xi32> + %63 = tt.broadcast %61 : tensor<1x64xi32> -> tensor<64x64xi32> + %64 = arith.addi %62, %63 : tensor<64x64xi32> + %65 = tt.splat %arg2 : !tt.ptr -> tensor<64x64x!tt.ptr> + %66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> + %67 = tt.expand_dims %51 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %68 = tt.splat %arg3 : i32 -> tensor<64x1xi32> + %69 = arith.cmpi slt, %67, %68 : tensor<64x1xi32> + %70 = tt.expand_dims %55 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %71 = tt.splat %arg4 : i32 -> tensor<1x64xi32> + %72 = arith.cmpi slt, %70, %71 : tensor<1x64xi32> + %73 = tt.broadcast %69 : tensor<64x1xi1> -> tensor<64x64xi1> + %74 = tt.broadcast %72 : tensor<1x64xi1> -> tensor<64x64xi1> + %75 = arith.andi %73, %74 : tensor<64x64xi1> + tt.store %66, %47#0, %75 : tensor<64x64x!tt.ptr> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/mma-pipeline-blackwell.mlir b/third_party/enflame/include/triton/test/TritonGPU/mma-pipeline-blackwell.mlir new file mode 100644 index 000000000..794cb7e6e --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/mma-pipeline-blackwell.mlir @@ -0,0 +1,1027 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-tc05mma-pipeline=disable-expander=true -canonicalize | FileCheck --dump-input-context=50 %s --check-prefix=CHECK-LOWER +// RUN: triton-opt %s -split-input-file -tritongpu-tc05mma-pipeline -canonicalize | FileCheck --dump-input-context=50 %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#tmem = #ttng.tensor_memory_encoding +#tmem1 = #ttng.tensor_memory_scales_encoding<> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LOWER-LABEL: @chained_dot_no_multibuf_acc + // CHECK-LOWER-DAG: %[[C0_F:.+]] = arith.constant dense<0.000000e+00> + // CHECK-LOWER-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-LOWER-DAG: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK-LOWER-DAG: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK-LOWER-DAG: %[[C2:.+]] = arith.constant 2 : i32 + // CHECK-LOWER: %[[TMEM_BUF:.+]] = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32 + // CHECK-LOWER: ttng.tmem_store %[[C0_F]], %[[TMEM_BUF]] + // CHECK-LOWER: %[[BAR_BUF:.+]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE0]], 1 + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE1]], 1 + // CHECK-LOWER: scf.for {{.*}} iter_args(%[[PHASE:.+]] = %[[C0]], %[[BAR_IDX:.+]] = %[[C0]]) + // CHECK-LOWER: %[[BAR_SLICE:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[BAR_IDX]]] + // CHECK-LOWER: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM_BUF]], %[[TRUE]], %[[TRUE]], %[[BAR_SLICE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[BAR_IDX_P1:.+]] = arith.addi %[[BAR_IDX]], %[[C1]] + // CHECK-LOWER: %[[BAR_WRAP:.+]] = arith.cmpi eq, %[[BAR_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[BAR_IDX_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[C0]], %[[BAR_IDX_P1]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[PHASE_XOR:.+]] = arith.xori %[[PHASE]], %[[C1]] + // CHECK-LOWER: %[[PHASE_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[PHASE_XOR]], %[[PHASE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: scf.yield %[[PHASE_NEXT]], %[[BAR_IDX_NEXT]] + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE0]] + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE1]] + // CHECK-LOWER: ttg.local_dealloc %[[BAR_BUF]] + // CHECK-LOWER: ttng.tmem_load %[[TMEM_BUF]] + + // CHECK-LABEL: @chained_dot_no_multibuf_acc + tt.func public @chained_dot_no_multibuf_acc(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %arg3: i32) -> tensor<128x128xf16, #blocked> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked>) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + scf.yield %acc_res : tensor<128x128xf32, #blocked> + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + tt.return %res_f16 : tensor<128x128xf16, #blocked> + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LOWER-LABEL: @chained_dot_wait_before_store + // CHECK-LOWER-DAG: %[[C0_F:.+]] = arith.constant dense<0.000000e+00> + // CHECK-LOWER-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-LOWER-DAG: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK-LOWER-DAG: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK-LOWER-DAG: %[[C2:.+]] = arith.constant 2 : i32 + // CHECK-LOWER: %[[TMEM_BUF:.+]] = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32 + // CHECK-LOWER: ttng.tmem_store %[[C0_F]], %[[TMEM_BUF]] + // CHECK-LOWER: %[[BAR_BUF:.+]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE0]], 1 + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE1]], 1 + // CHECK-LOWER: scf.for {{.*}} iter_args(%[[PHASE:.+]] = %[[C0]], %[[BAR_IDX:.+]] = %[[C0]]) + // CHECK-LOWER: %[[BAR_SLICE:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[BAR_IDX]]] + // CHECK-LOWER: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM_BUF]], %[[TRUE]], %[[TRUE]], %[[BAR_SLICE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[BAR_IDX_P1:.+]] = arith.addi %[[BAR_IDX]], %[[C1]] + // CHECK-LOWER: %[[BAR_WRAP:.+]] = arith.cmpi eq, %[[BAR_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[BAR_IDX_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[C0]], %[[BAR_IDX_P1]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[PHASE_XOR:.+]] = arith.xori %[[PHASE]], %[[C1]] + // CHECK-LOWER: %[[PHASE_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[PHASE_XOR]], %[[PHASE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: scf.if + // CHECK-LOWER: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] + // CHECK-LOWER: %[[ACC_RES:.+]] = ttng.tmem_load %[[TMEM_BUF]] + // CHECK-LOWER: tt.store %{{.*}}, %[[ACC_RES]] + // CHECK-LOWER: } {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE0]] + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE1]] + // CHECK-LOWER: ttg.local_dealloc %[[BAR_BUF]] + // CHECK-LOWER: ttng.tmem_load %[[TMEM_BUF]] + + // CHECK-LABEL: @chained_dot_wait_before_store + tt.func public @chained_dot_wait_before_store(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %arg3: i32, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %cnd: i1) -> tensor<128x128xf16, #blocked> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked>) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + scf.if %cnd { + tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> + } + scf.yield %acc_res : tensor<128x128xf32, #blocked> + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + tt.return %res_f16 : tensor<128x128xf16, #blocked> + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // Verify that we still can pipeline the mma when the subview is in the previous iteration. + // CHECK-LOWER-LABEL: @subview_dist_1 + // CHECK-LOWER: ttng.tmem_alloc + // CHECK-LOWER: ttng.tmem_store + // CHECK-LOWER: scf.for + // CHECK-LOWER: ttng.tc_gen5_mma + // CHECK-LOWER: scf.yield + // CHECK-LOWER: ttng.tmem_load + + // CHECK-LABEL: @subview_dist_1 + tt.func public @subview_dist_1(%arg3: i32) -> tensor<128x128xf16, #blocked> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %A_sh0 = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %res, %_ = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst, %A_sh_arg = %A_sh0) -> (tensor<128x128xf32, #blocked>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>) : i32 { + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh_arg, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + %A_sh1 = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + scf.yield %acc_res, %A_sh1 : tensor<128x128xf32, #blocked>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + tt.return %res_f16 : tensor<128x128xf16, #blocked> + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LOWER-LABEL: @multibuf_acc + // CHECK-LOWER-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-LOWER-DAG: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK-LOWER-DAG: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK-LOWER-DAG: %[[C2:.+]] = arith.constant 2 : i32 + // CHECK-LOWER: %[[TMEM_BUF:.+]] = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32 + // CHECK-LOWER: %[[BAR_BUF:.+]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE0]], 1 + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE1]], 1 + // CHECK-LOWER: scf.for {{.*}} iter_args(%[[PHASE:.+]] = %[[C0]], %[[BAR_IDX:.+]] = %[[C0]], %[[ACC_INS_IDX:.+]] = %[[C0]], %[[ACC_EXT_IDX:.+]] = %[[C0]]) + // CHECK-LOWER: %[[ACC_INS_IDX_P1:.+]] = arith.addi %[[ACC_INS_IDX]], %[[C1]] + // CHECK-LOWER: %[[ACC_INS_WRAP:.+]] = arith.cmpi eq, %[[ACC_INS_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[ACC_INS_NEXT:.+]] = arith.select %[[ACC_INS_WRAP]], %[[C0]], %[[ACC_INS_IDX_P1]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[ACC_EXT_IDX_P1:.+]] = arith.addi %[[ACC_EXT_IDX]], %[[C1]] + // CHECK-LOWER: %[[ACC_EXT_WRAP:.+]] = arith.cmpi eq, %[[ACC_EXT_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[ACC_EXT_NEXT:.+]] = arith.select %[[ACC_EXT_WRAP]], %[[C0]], %[[ACC_EXT_IDX_P1]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[TMEM_INS_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_INS_NEXT]], + // CHECK-LOWER: ttng.tmem_store {{.*}}, %[[TMEM_INS_SLICE]], %[[TRUE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[TMEM_INS_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_INS_NEXT]], + // CHECK-LOWER: %[[BAR_SLICE:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[BAR_IDX]]] + // CHECK-LOWER: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM_INS_SLICE]], %[[TRUE]], %[[TRUE]], %[[BAR_SLICE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[BAR_IDX_P1:.+]] = arith.addi %[[BAR_IDX]], %[[C1]] + // CHECK-LOWER: %[[BAR_WRAP:.+]] = arith.cmpi eq, %[[BAR_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[BAR_IDX_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[C0]], %[[BAR_IDX_P1]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[PHASE_XOR:.+]] = arith.xori %[[PHASE]], %[[C1]] + // CHECK-LOWER: %[[PHASE_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[PHASE_XOR]], %[[PHASE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[TMEM_EXT_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_EXT_NEXT]], + // CHECK-LOWER: %[[ACC_RES:.+]] = ttng.tmem_load %[[TMEM_EXT_SLICE]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: tt.store {{.*}}, %[[ACC_RES]] + // CHECK-LOWER: scf.yield %[[PHASE_NEXT]], %[[BAR_IDX_NEXT]], %[[ACC_INS_NEXT]], %[[ACC_EXT_NEXT]] + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE0]] + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE1]] + // CHECK-LOWER: ttg.local_dealloc %[[BAR_BUF]] + + // CHECK-LABEL: @multibuf_acc + tt.func public @multibuf_acc(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %acc_ptr: tensor<128x128x!tt.ptr, #blocked>, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %arg3: i32) attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + scf.for %i = %c0_i32 to %arg3 step %c1_i32 : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc = tt.load %acc_ptr : tensor<128x128x!tt.ptr, #blocked> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // Do not pipeline the mma, as multibuffering is disabled, and would need to wait in the + // every iteration of the loop anyway. + // CHECK-LOWER-LABEL: @disable_multibuf_acc + // CHECK-LOWER-NOT: ttng.wait_barrier + // CHECK-LABEL: @disable_multibuf_acc + tt.func public @disable_multibuf_acc(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %acc_ptr: tensor<128x128x!tt.ptr, #blocked>, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %arg3: i32) attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + scf.for %i = %c0_i32 to %arg3 step %c1_i32 : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc = tt.load %acc_ptr : tensor<128x128x!tt.ptr, #blocked> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> + } {tt.disallow_acc_multi_buffer} + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LOWER-LABEL: @do_not_pipeline_two_dots + // CHECK-LOWER-NOT: triton.pipeline_stage + + // CHECK-LABEL: @do_not_pipeline_two_dots + tt.func public @do_not_pipeline_two_dots(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %acc_ptr: tensor<128x128x!tt.ptr, #blocked>, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %arg3: i32) attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + scf.for %i = %c0_i32 to %arg3 step %c1_i32 : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc = tt.load %acc_ptr : tensor<128x128x!tt.ptr, #blocked> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + %acc_tm2 = ttng.tmem_alloc %acc_res : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm2, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res2 = ttng.tmem_load %acc_tm2 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + tt.store %res_ptr, %acc_res2 : tensor<128x128x!tt.ptr, #blocked> + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LOWER-LABEL: @multibuf_acc_sel_override + // CHECK-LOWER-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-LOWER-DAG: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK-LOWER-DAG: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK-LOWER-DAG: %[[C2:.+]] = arith.constant 2 : i32 + // CHECK-LOWER: %[[TMEM_BUF:.+]] = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32 + // CHECK-LOWER: %[[BAR_BUF:.+]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE0]], 1 + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE1]], 1 + // CHECK-LOWER: scf.for {{.*}} iter_args(%[[PHASE:.+]] = %[[C0]], %[[BAR_IDX:.+]] = %[[C0]], %[[ACC_INS_IDX:.+]] = %[[C0]], %[[ACC_EXT_IDX:.+]] = %[[C0]]) + // CHECK-LOWER: %[[ACC_INS_IDX_P1:.+]] = arith.addi %[[ACC_INS_IDX]], %[[C1]] + // CHECK-LOWER: %[[ACC_INS_WRAP:.+]] = arith.cmpi eq, %[[ACC_INS_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[ACC_INS_NEXT:.+]] = arith.select %[[ACC_INS_WRAP]], %[[C0]], %[[ACC_INS_IDX_P1]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[ACC_EXT_IDX_P1:.+]] = arith.addi %[[ACC_EXT_IDX]], %[[C1]] + // CHECK-LOWER: %[[ACC_EXT_WRAP:.+]] = arith.cmpi eq, %[[ACC_EXT_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[ACC_EXT_NEXT:.+]] = arith.select %[[ACC_EXT_WRAP]], %[[C0]], %[[ACC_EXT_IDX_P1]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[TMEM_INS_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_INS_NEXT]], + // CHECK-LOWER: ttng.tmem_store {{.*}}, %[[TMEM_INS_SLICE]], %[[TRUE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[TMEM_INS_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_INS_NEXT]], + // CHECK-LOWER: %[[BAR_SLICE:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[BAR_IDX]]] + // CHECK-LOWER: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM_INS_SLICE]], %[[TRUE]], %[[TRUE]], %[[BAR_SLICE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[BAR_IDX_P1:.+]] = arith.addi %[[BAR_IDX]], %[[C1]] + // CHECK-LOWER: %[[BAR_WRAP:.+]] = arith.cmpi eq, %[[BAR_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[BAR_IDX_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[C0]], %[[BAR_IDX_P1]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[PHASE_XOR:.+]] = arith.xori %[[PHASE]], %[[C1]] + // CHECK-LOWER: %[[PHASE_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[PHASE_XOR]], %[[PHASE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[TMEM_EXT_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_EXT_NEXT]], + // CHECK-LOWER: %[[ACC_RES:.+]] = ttng.tmem_load %[[TMEM_EXT_SLICE]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: tt.store {{.*}}, %[[ACC_RES]] + // CHECK-LOWER: scf.yield %[[PHASE_NEXT]], %[[BAR_IDX_NEXT]], %[[ACC_INS_NEXT]], %[[ACC_EXT_NEXT]] + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE0]] + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE1]] + // CHECK-LOWER: ttg.local_dealloc %[[BAR_BUF]] + + // CHECK-LABEL: @multibuf_acc + tt.func public @multibuf_acc_sel_override(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %arg3: i32, %cnd: i1) attributes {noinline = false} { + %true = arith.constant true + %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %cst1 = arith.constant dense<1.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + scf.for %i = %c0_i32 to %arg3 step %c1_i32 : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + // %acc = tt.load %acc_ptr : tensor<128x128x!tt.ptr, #blocked> + %acc = arith.select %cnd, %cst1, %cst0 : tensor<128x128xf32, #blocked> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LOWER-LABEL: @multibuf_acc_unused + // CHECK-LOWER-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-LOWER-DAG: %[[FALSE:.+]] = arith.constant false + // CHECK-LOWER-DAG: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK-LOWER-DAG: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK-LOWER-DAG: %[[C2:.+]] = arith.constant 2 : i32 + // CHECK-LOWER: %[[TMEM_BUF:.+]] = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32 + // CHECK-LOWER: %[[BAR_BUF:.+]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE0]], 1 + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE1]], 1 + // CHECK-LOWER: scf.for {{.*}} iter_args(%[[PHASE:.+]] = %[[C0]], %[[BAR_IDX:.+]] = %[[C0]], %[[ACC_INS_IDX:.+]] = %[[C0]], %[[ACC_EXT_IDX:.+]] = %[[C0]]) + // CHECK-LOWER: %[[ACC_INS_IDX_P1:.+]] = arith.addi %[[ACC_INS_IDX]], %[[C1]] + // CHECK-LOWER: %[[ACC_INS_WRAP:.+]] = arith.cmpi eq, %[[ACC_INS_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[ACC_INS_NEXT:.+]] = arith.select %[[ACC_INS_WRAP]], %[[C0]], %[[ACC_INS_IDX_P1]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[ACC_EXT_IDX_P1:.+]] = arith.addi %[[ACC_EXT_IDX]], %[[C1]] + // CHECK-LOWER: %[[ACC_EXT_WRAP:.+]] = arith.cmpi eq, %[[ACC_EXT_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[ACC_EXT_NEXT:.+]] = arith.select %[[ACC_EXT_WRAP]], %[[C0]], %[[ACC_EXT_IDX_P1]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[TMEM_INS_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_INS_NEXT]], + // CHECK-LOWER: %[[BAR_SLICE:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[BAR_IDX]]] + // CHECK-LOWER: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM_INS_SLICE]], %[[FALSE]], %[[TRUE]], %[[BAR_SLICE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[BAR_IDX_P1:.+]] = arith.addi %[[BAR_IDX]], %[[C1]] + // CHECK-LOWER: %[[BAR_WRAP:.+]] = arith.cmpi eq, %[[BAR_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[BAR_IDX_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[C0]], %[[BAR_IDX_P1]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[PHASE_XOR:.+]] = arith.xori %[[PHASE]], %[[C1]] + // CHECK-LOWER: %[[PHASE_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[PHASE_XOR]], %[[PHASE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[TMEM_EXT_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_EXT_NEXT]], + // CHECK-LOWER: %[[ACC_RES:.+]] = ttng.tmem_load %[[TMEM_EXT_SLICE]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: tt.store {{.*}}, %[[ACC_RES]] + // CHECK-LOWER: scf.yield %[[PHASE_NEXT]], %[[BAR_IDX_NEXT]], %[[ACC_INS_NEXT]], %[[ACC_EXT_NEXT]] + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE0]] + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE1]] + // CHECK-LOWER: ttg.local_dealloc %[[BAR_BUF]] + + // CHECK-LABEL: @multibuf_acc_unused + tt.func public @multibuf_acc_unused(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %acc_ptr: tensor<128x128x!tt.ptr, #blocked>, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %arg3: i32) attributes {noinline = false} { + %true = arith.constant true + %false = arith.constant false + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + scf.for %i = %c0_i32 to %arg3 step %c1_i32 : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %false, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LOWER-LABEL: @acc_reinit_under_sel + // CHECK-LOWER-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-LOWER-DAG: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK-LOWER-DAG: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK-LOWER-DAG: %[[C2:.+]] = arith.constant 2 : i32 + // CHECK-LOWER-DAG: %[[C0_F:.+]] = arith.constant dense<0.000000e+00> + // CHECK-LOWER: %[[TMEM_BUF:.+]] = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32 + // CHECK-LOWER: %[[TMEM_INS_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[C0]] + // CHECK-LOWER: ttng.tmem_store %[[C0_F]], %[[TMEM_INS_SLICE]] + // CHECK-LOWER: %[[BAR_BUF:.+]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE0]], 1 + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE1]], 1 + // CHECK-LOWER: scf.for {{.*}} iter_args(%[[PHASE:.+]] = %[[C0]], %[[BAR_IDX:.+]] = %[[C0]], %[[ACC_INS_IDX:.+]] = %[[C0]], %[[ACC_EXT_IDX:.+]] = %[[C0]]) + // CHECK-LOWER: %[[TMEM_INS_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_INS_IDX]], + // CHECK-LOWER: %[[BAR_SLICE:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[BAR_IDX]] + // CHECK-LOWER: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM_INS_SLICE]], %[[TRUE]], %[[TRUE]], %[[BAR_SLICE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[BAR_IDX_P1:.+]] = arith.addi %[[BAR_IDX]], %[[C1]] + // CHECK-LOWER: %[[BAR_WRAP:.+]] = arith.cmpi eq, %[[BAR_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[BAR_IDX_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[C0]], %[[BAR_IDX_P1]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[PHASE_XOR:.+]] = arith.xori %[[PHASE]], %[[C1]] + // CHECK-LOWER: %[[PHASE_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[PHASE_XOR]], %[[PHASE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[ACC_INS_IDX_P1:.+]] = arith.addi %[[ACC_INS_IDX]], %[[C1]] + // CHECK-LOWER: %[[ACC_INS_WRAP:.+]] = arith.cmpi eq, %[[ACC_INS_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[ACC_INS_NEXT:.+]] = arith.select %[[ACC_INS_WRAP]], %[[C0]], %[[ACC_INS_IDX_P1]] + // CHECK-LOWER: %[[ACC_INS_NEXT_PRED:.+]] = arith.select %[[CND:.+]], %[[ACC_INS_NEXT]], %[[ACC_INS_IDX]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[ACC_EXT_IDX_P1:.+]] = arith.addi %[[ACC_EXT_IDX]], %[[C1]] + // CHECK-LOWER: %[[ACC_EXT_WRAP:.+]] = arith.cmpi eq, %[[ACC_EXT_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[ACC_EXT_NEXT:.+]] = arith.select %[[ACC_EXT_WRAP]], %[[C0]], %[[ACC_EXT_IDX_P1]] + // CHECK-LOWER: %[[ACC_EXT_NEXT_PRED:.+]] = arith.select %[[CND]], %[[ACC_EXT_NEXT]], %[[ACC_EXT_IDX]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[TMEM_INS_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_INS_NEXT_PRED]], + // CHECK-LOWER: ttng.tmem_store %[[C0_F]], %[[TMEM_INS_SLICE]], %[[CND]] + // CHECK-LOWER: scf.if %[[CND]] + // CHECK-LOWER: %[[TMEM_EXT_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_EXT_IDX]], + // CHECK-LOWER: %[[ACC_RES:.+]] = ttng.tmem_load %[[TMEM_EXT_SLICE]] + // CHECK-LOWER: tt.store {{.*}}, %[[ACC_RES]] + // CHECK-LOWER: } {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: scf.yield %[[PHASE_NEXT]], %[[BAR_IDX_NEXT]], %[[ACC_INS_NEXT_PRED]], %[[ACC_EXT_NEXT_PRED]] + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE0]] + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE1]] + // CHECK-LOWER: ttg.local_dealloc %[[BAR_BUF]] + + // CHECK-LABEL: @acc_reinit_under_sel + tt.func public @acc_reinit_under_sel(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %arg3: i32, %cnd: i1) attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked>) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + %new_acc = arith.select %cnd, %cst, %acc_res : tensor<128x128xf32, #blocked> + scf.if %cnd { + tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> + } + scf.yield %new_acc : tensor<128x128xf32, #blocked> + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // Do not pipeline if multibufferring is disallowed and we are physilcally override accumulator + // in the loop body. + // CHECK-LOWER-LABEL: @acc_reinit_under_sel_disallow_multibuffer + // CHECK-LOWER: ttng.tc_gen5_mma + // CHECK-LOWER-NOT: ttng.wait_barrier + + // CHECK-LABEL: @acc_reinit_under_sel_disallow_multibuffer + tt.func public @acc_reinit_under_sel_disallow_multibuffer(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %arg3: i32, %cnd: i1) attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked>) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + %new_acc = arith.select %cnd, %cst, %acc_res : tensor<128x128xf32, #blocked> + scf.if %cnd { + tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> + } + scf.yield %new_acc : tensor<128x128xf32, #blocked> + } {tt.disallow_acc_multi_buffer} + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LOWER-LABEL: @acc_reinit_under_if_acc_flag + // CHECK-LOWER-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-LOWER-DAG: %[[FALSE:.+]] = arith.constant false + // CHECK-LOWER-DAG: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK-LOWER-DAG: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK-LOWER-DAG: %[[C2:.+]] = arith.constant 2 : i32 + // CHECK-LOWER-DAG: %[[C0_F:.+]] = arith.constant dense<0.000000e+00> + // CHECK-LOWER: %[[TMEM_BUF:.+]] = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32 + // CHECK-LOWER: %[[TMEM_INS_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[C0]] + // CHECK-LOWER: ttng.tmem_store %[[C0_F]], %[[TMEM_INS_SLICE]] + // CHECK-LOWER: %[[BAR_BUF:.+]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE0]], 1 + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE1]], 1 + // CHECK-LOWER: scf.for {{.*}} iter_args(%[[ACC_USE:.+]] = %[[FALSE]], %[[PHASE:.+]] = %[[C0]], %[[BAR_IDX:.+]] = %[[C0]], %[[ACC_INS_IDX:.+]] = %[[C0]], %[[ACC_EXT_IDX:.+]] = %[[C0]]) + // CHECK-LOWER: %[[TMEM_INS_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_INS_IDX]], + // CHECK-LOWER: %[[BAR_SLICE:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[BAR_IDX]] + // CHECK-LOWER: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM_INS_SLICE]], %[[ACC_USE]], %[[TRUE]], %[[BAR_SLICE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[BAR_IDX_P1:.+]] = arith.addi %[[BAR_IDX]], %[[C1]] + // CHECK-LOWER: %[[BAR_WRAP:.+]] = arith.cmpi eq, %[[BAR_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[BAR_IDX_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[C0]], %[[BAR_IDX_P1]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[PHASE_XOR:.+]] = arith.xori %[[PHASE]], %[[C1]] + // CHECK-LOWER: %[[PHASE_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[PHASE_XOR]], %[[PHASE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[ACC_USE_NEXT:.+]] = arith.xori %[[CND:.+]], %[[TRUE]] + // CHECK-LOWER: %[[ACC_INS_IDX_P1:.+]] = arith.addi %[[ACC_INS_IDX]], %[[C1]] + // CHECK-LOWER: %[[ACC_INS_WRAP:.+]] = arith.cmpi eq, %[[ACC_INS_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[ACC_INS_NEXT:.+]] = arith.select %[[ACC_INS_WRAP]], %[[C0]], %[[ACC_INS_IDX_P1]] + // CHECK-LOWER: %[[ACC_INS_NEXT_PRED:.+]] = arith.select %[[CND]], %[[ACC_INS_NEXT]], %[[ACC_INS_IDX]] + // CHECK-LOWER: %[[ACC_EXT_IDX_P1:.+]] = arith.addi %[[ACC_EXT_IDX]], %[[C1]] + // CHECK-LOWER: %[[ACC_EXT_WRAP:.+]] = arith.cmpi eq, %[[ACC_EXT_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[ACC_EXT_NEXT:.+]] = arith.select %[[ACC_EXT_WRAP]], %[[C0]], %[[ACC_EXT_IDX_P1]] + // CHECK-LOWER: %[[ACC_EXT_NEXT_PRED:.+]] = arith.select %[[CND]], %[[ACC_EXT_NEXT]], %[[ACC_EXT_IDX]] + // CHECK-LOWER: scf.if %[[CND]] + // CHECK-LOWER: %[[TMEM_EXT_SLICE:.+]] = ttg.memdesc_subview %[[TMEM_BUF]][%[[ACC_EXT_IDX]], + // CHECK-LOWER: %[[ACC_RES:.+]] = ttng.tmem_load %[[TMEM_EXT_SLICE]] + // CHECK-LOWER: tt.store {{.*}}, %[[ACC_RES]] + // CHECK-LOWER: } {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: scf.yield %[[ACC_USE_NEXT]], %[[PHASE_NEXT]], %[[BAR_IDX_NEXT]], %[[ACC_INS_NEXT_PRED]], %[[ACC_EXT_NEXT_PRED]] + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE0]] + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE1]] + // CHECK-LOWER: ttg.local_dealloc %[[BAR_BUF]] + + // CHECK-LABEL: @acc_reinit_under_if_acc_flag + tt.func public @acc_reinit_under_if_acc_flag(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %arg3: i32, %cnd: i1) attributes {noinline = false} { + %true = arith.constant true + %false = arith.constant false + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res:2 = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst, %accUse = %false) -> (tensor<128x128xf32, #blocked>, i1) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %accUse, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + %new_accUse = arith.select %cnd, %false, %true : i1 + scf.if %cnd { + tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> + } + scf.yield %acc_res, %new_accUse : tensor<128x128xf32, #blocked>, i1 + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LOWER-LABEL: @acc_reinit_under_if_acc_flag_disallow_multibuffer + // CHECK-LOWER-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-LOWER-DAG: %[[FALSE:.+]] = arith.constant false + // CHECK-LOWER-DAG: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK-LOWER-DAG: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK-LOWER-DAG: %[[C2:.+]] = arith.constant 2 : i32 + // CHECK-LOWER-DAG: %[[C0_F:.+]] = arith.constant dense<0.000000e+00> + // CHECK-LOWER: %[[TMEM_BUF:.+]] = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32 + // CHECK-LOWER: ttng.tmem_store %[[C0_F]], %[[TMEM_BUF]] + // CHECK-LOWER: %[[BAR_BUF:.+]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE0]], 1 + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.init_barrier %[[BAR_SLICE1]], 1 + // CHECK-LOWER: scf.for {{.*}} iter_args(%[[ACC_USE:.+]] = %[[FALSE]], %[[PHASE:.+]] = %[[C0]], %[[BAR_IDX:.+]] = %[[C0]]) + // CHECK-LOWER: %[[BAR_SLICE:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[BAR_IDX]] + // CHECK-LOWER: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM_BUF]], %[[ACC_USE]], %[[TRUE]], %[[BAR_SLICE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] {triton.pipeline_stage = 1 : i32} + // CHECK-LOWER: %[[BAR_IDX_P1:.+]] = arith.addi %[[BAR_IDX]], %[[C1]] + // CHECK-LOWER: %[[BAR_WRAP:.+]] = arith.cmpi eq, %[[BAR_IDX_P1]], %[[C2]] + // CHECK-LOWER: %[[BAR_IDX_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[C0]], %[[BAR_IDX_P1]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[PHASE_XOR:.+]] = arith.xori %[[PHASE]], %[[C1]] + // CHECK-LOWER: %[[PHASE_NEXT:.+]] = arith.select %[[BAR_WRAP]], %[[PHASE_XOR]], %[[PHASE]] {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: %[[ACC_USE_NEXT:.+]] = arith.xori %[[CND:.+]], %[[TRUE]] + // CHECK-LOWER: scf.if %{{.*}} + // CHECK-LOWER: ttng.wait_barrier %[[BAR_SLICE]], %[[PHASE]] + // CHECK-LOWER: %[[ACC_RES:.+]] = ttng.tmem_load %[[TMEM_BUF]] + // CHECK-LOWER: tt.store {{.*}}, %[[ACC_RES]] + // CHECK-LOWER: } {triton.pipeline_stage = 0 : i32} + // CHECK-LOWER: scf.yield %[[ACC_USE_NEXT]], %[[PHASE_NEXT]], %[[BAR_IDX_NEXT]] + // CHECK-LOWER: %[[BAR_SLICE0:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C0]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE0]] + // CHECK-LOWER: %[[BAR_SLICE1:.+]] = ttg.memdesc_subview %[[BAR_BUF]][%[[C1]]] + // CHECK-LOWER: ttng.inval_barrier %[[BAR_SLICE1]] + // CHECK-LOWER: ttg.local_dealloc %[[BAR_BUF]] + + // CHECK-LABEL: @acc_reinit_under_if_acc_flag_disallow_multibuffer + tt.func public @acc_reinit_under_if_acc_flag_disallow_multibuffer(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %arg3: i32, %cnd: i1) attributes {noinline = false} { + %true = arith.constant true + %false = arith.constant false + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res:2 = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst, %accUse = %false) -> (tensor<128x128xf32, #blocked>, i1) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %accUse, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + %new_accUse = arith.select %cnd, %false, %true : i1 + scf.if %cnd { + tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> + } + scf.yield %acc_res, %new_accUse : tensor<128x128xf32, #blocked>, i1 + } {tt.disallow_acc_multi_buffer} + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LOWER-LABEL: @acc_used_if_else + // CHECK-LOWER: ttng.tmem_alloc + // CHECK-LOWER: ttng.tmem_store + // CHECK-LOWER: scf.for + // CHECK-LOWER: ttng.tc_gen5_mma + // CHECK-LOWER: %[[EXT_SLICE:.+]] = ttg.memdesc_subview + // CHECK-LOWER: %[[ACC_RES:.+]] = ttng.tmem_load %[[EXT_SLICE]] + // CHECK-LOWER: scf.if + // CHECK-LOWER: tt.store {{.*}}, %[[ACC_RES]] + // CHECK-LOWER: } else { + // CHECK-LOWER: arith.addf %[[ACC_RES]], %[[ACC_RES]] + + // CHECK-LABEL: @acc_used_if_else + tt.func public @acc_used_if_else(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %arg3: i32, %cnd: i1) attributes {noinline = false} { + %true = arith.constant true + %false = arith.constant false + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res:2 = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst, %accUse = %false) -> (tensor<128x128xf32, #blocked>, i1) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %accUse, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + %new_accUse = arith.select %cnd, %false, %true : i1 + scf.if %cnd { + tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> + } else { + %acc_res2 = arith.addf %acc_res, %acc_res : tensor<128x128xf32, #blocked> + tt.store %res_ptr, %acc_res2 : tensor<128x128x!tt.ptr, #blocked> + } + scf.yield %acc_res, %new_accUse : tensor<128x128xf32, #blocked>, i1 + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LOWER-LABEL: @acc_used_if_and_outside + // CHECK-LOWER: ttng.tmem_alloc + // CHECK-LOWER: ttng.tmem_store + // CHECK-LOWER: scf.for + // CHECK-LOWER: ttng.tc_gen5_mma + // CHECK-LOWER: %[[EXT_SLICE:.+]] = ttg.memdesc_subview + // CHECK-LOWER: %[[ACC_RES:.+]] = ttng.tmem_load %[[EXT_SLICE]] + // CHECK-LOWER: %[[ACC_RES2:.+]] = arith.addf %[[ACC_RES]], %[[ACC_RES]] + // CHECK-LOWER: scf.if + // CHECK-LOWER: tt.store {{.*}}, %[[ACC_RES]] + // CHECK-LOWER: } else { + // CHECK-LOWER: tt.store {{.*}}, %[[ACC_RES2]] + + // CHECK-LABEL: @acc_used_if_and_outside + tt.func public @acc_used_if_and_outside(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %res_ptr: tensor<128x128x!tt.ptr, #blocked>, %arg3: i32, %cnd: i1) attributes {noinline = false} { + %true = arith.constant true + %false = arith.constant false + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res:2 = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst, %accUse = %false) -> (tensor<128x128xf32, #blocked>, i1) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %accUse, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + %new_accUse = arith.select %cnd, %false, %true : i1 + %acc_res2 = arith.addf %acc_res, %acc_res : tensor<128x128xf32, #blocked> + scf.if %cnd { + tt.store %res_ptr, %acc_res : tensor<128x128x!tt.ptr, #blocked> + } else { + tt.store %res_ptr, %acc_res2 : tensor<128x128x!tt.ptr, #blocked> + } + scf.yield %acc_res, %new_accUse : tensor<128x128xf32, #blocked>, i1 + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @pipeline_tc05mma + // CHECK: (%[[UB_ARG:[0-9a-z]+]]: i32, + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32 + // CHECK-DAG: %[[TRUE:.*]] = arith.constant true + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> + // CHECK: %[[A_SH:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf16 + // CHECK: %[[B_SH:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf16 + + // CHECK: %[[TMEM:.*]] = ttng.tmem_alloc + // CHECK: ttng.tmem_store %[[CST]], %[[TMEM]], %[[TRUE]] + + // Barrier allocation: + // CHECK: %[[BAR_SH:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 + // CHECK: %[[BAR_SLICE0:.*]] = ttg.memdesc_subview %[[BAR_SH]][%[[C0]]] + // CHECK: ttng.init_barrier %[[BAR_SLICE0]], 1 + // CHECK: %[[BAR_SLICE1:.*]] = ttg.memdesc_subview %[[BAR_SH]][%[[C1]]] + // CHECK: ttng.init_barrier %[[BAR_SLICE1]], 1 + + // Peeled prologue: + // CHECK-DAG: %[[I0_PRED:.*]] = arith.cmpi sgt, %[[UB_ARG]], %[[C0]] : i32 + // CHECK-DAG: %[[A_SLICE0:.*]] = ttg.memdesc_subview %[[A_SH]][%[[C0]], %[[C0]], %[[C0]]] + // CHECK-DAG: %[[B_SLICE0:.*]] = ttg.memdesc_subview %[[B_SH]][%[[C0]], %[[C0]], %[[C0]]] + // CHECK-DAG: %[[BAR_SLICE0_1:.*]] = ttg.memdesc_subview %[[BAR_SH]][%[[C0]]] + // CHECK: ttng.tc_gen5_mma %[[A_SLICE0]], %[[B_SLICE0]], %[[TMEM]], %[[TRUE]], %[[I0_PRED]], %[[BAR_SLICE0_1]] + + // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[UB_ARG]] step %[[C1]] iter_args({{.*}} %[[PHASE:.[^,]+]] = %[[C0]], %[[BAR_IDX:[^,]+]] = %[[C1]], %[[BAR_SLICE_PREV:.[^,]+]] = %[[BAR_SLICE0_1]], %[[PHASE_PREV:.[^,]+]] = %[[C0]] + // CHECK: %[[UB_M1:.*]] = arith.subi %[[UB_ARG]], %[[C1]] + // CHECK: %[[IN_PRED:.*]] = arith.cmpi slt, %[[IV]], %[[UB_M1]] + // CHECK: %[[BAR_SLICE:.*]] = ttg.memdesc_subview %[[BAR_SH]][%[[BAR_IDX]]] + // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, %[[TMEM]], %[[TRUE]], %[[IN_PRED]], %[[BAR_SLICE]] + // CHECK: ttng.wait_barrier %[[BAR_SLICE_PREV]], %[[PHASE_PREV]] + + // CHECK: %[[BAR_IDX_P1:.*]] = arith.addi %[[BAR_IDX]], %[[C1]] + // CHECK: %[[BAR_IDX_WRAP:.*]] = arith.cmpi eq, %[[BAR_IDX_P1]], %[[C2]] + // CHECK: %[[BAR_IDX_NEXT:.*]] = arith.select %[[BAR_IDX_WRAP]], %[[C0]], %[[BAR_IDX_P1]] + + // CHECK: %[[XOR:.*]] = arith.xori %[[PHASE]], %[[C1]] + // CHECK: %[[NEXT_PHASE:.*]] = arith.select %[[BAR_IDX_WRAP]], %[[XOR]], %[[PHASE]] + // CHECK: scf.yield {{.*}}, %[[NEXT_PHASE]], %[[BAR_IDX_NEXT]], %[[BAR_SLICE]], %[[PHASE]] + // CHECK: %[[BAR_SLICE0:.*]] = ttg.memdesc_subview %[[BAR_SH]][%[[C0]]] + // CHECK: ttng.inval_barrier %[[BAR_SLICE0]] + // CHECK: %[[BAR_SLICE1:.*]] = ttg.memdesc_subview %[[BAR_SH]][%[[C1]]] + // CHECK: ttng.inval_barrier %[[BAR_SLICE1]] + // CHECK: ttg.local_dealloc %[[BAR_SH]] + + tt.func public @pipeline_tc05mma(%arg0: i32, %arg1: tensor<128x128x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg2: tensor<128x128x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #blocked1> attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %1 = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + %2 = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + %3 = arith.cmpi sgt, %arg0, %c0_i32 : i32 + %4 = ttg.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %5 = tt.splat %3 : i1 -> tensor<128x128xi1, #blocked> + %6 = ttg.async_copy_global_to_local %arg1, %4 mask %5 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %7 = ttg.async_commit_group %6 + %8 = ttg.memdesc_subview %2[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %9 = tt.splat %3 : i1 -> tensor<128x128xi1, #blocked> + %10 = ttg.async_copy_global_to_local %arg2, %8 mask %9 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %11 = ttg.async_commit_group %10 + %12:5 = scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg4 = %c0_i32, %arg5 = %c-1_i32, %arg6 = %7, %arg7 = %11, %acc = %cst) -> (i32, i32, !ttg.async.token, !ttg.async.token, tensor<128x128xf32, #blocked1>) : i32 { + %15 = arith.subi %arg0, %c1_i32 : i32 + %16 = arith.cmpi slt, %arg3, %15 : i32 + %17 = arith.addi %arg5, %c1_i32 : i32 + %18 = arith.cmpi slt, %17, %c2_i32 : i32 + %19 = arith.select %18, %17, %c0_i32 : i32 + %20 = ttg.memdesc_subview %1[%19, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %21 = ttg.async_wait %arg7 {num = 0 : i32} + %22 = ttg.memdesc_subview %2[%19, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %tmem = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %20, %22, %tmem, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %tmem : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> + %23 = arith.addi %arg4, %c1_i32 : i32 + %24 = arith.cmpi slt, %23, %c2_i32 : i32 + %25 = arith.select %24, %23, %c0_i32 : i32 + %26 = ttg.memdesc_subview %1[%25, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %27 = tt.splat %16 : i1 -> tensor<128x128xi1, #blocked> + %28 = ttg.async_copy_global_to_local %arg1, %26 mask %27 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %29 = ttg.async_commit_group %28 + %30 = ttg.memdesc_subview %2[%25, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %31 = tt.splat %16 : i1 -> tensor<128x128xi1, #blocked> + %32 = ttg.async_copy_global_to_local %arg2, %30 mask %31 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %33 = ttg.async_commit_group %32 + scf.yield %25, %19, %29, %33, %acc_res : i32, i32, !ttg.async.token, !ttg.async.token, tensor<128x128xf32, #blocked1> + } + %13 = ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %1 : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %2 : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + tt.return %12#4 : tensor<128x128xf32, #blocked1> + } +} + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @pipeline_tc05mma_scaled + // CHECK: ttng.tc_gen5_mma_scaled + // MMA pipeline should not apply since scales are not passed in shmem + // CHECK-NOT: ttng.wait_barrier + + tt.func public @pipeline_tc05mma_scaled(%arg0: i32, %arg1: tensor<128x128x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg2: tensor<128x128x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %scale_A: tensor<128x4x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %scale_B: tensor<128x4x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #blocked1> attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %1 = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + %2 = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + %scale_A_SMEM = ttg.local_alloc : () -> !ttg.memdesc<2x128x4xi8, #shared1, #ttg.shared_memory, mutable> + %scale_B_SMEM = ttg.local_alloc : () -> !ttg.memdesc<2x128x4xi8, #shared1, #ttg.shared_memory, mutable> + %3 = arith.cmpi sgt, %arg0, %c0_i32 : i32 + %4 = ttg.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %5 = tt.splat %3 : i1 -> tensor<128x128xi1, #blocked> + %6 = ttg.async_copy_global_to_local %arg1, %4 mask %5 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %7 = ttg.async_commit_group %6 + %8 = ttg.memdesc_subview %2[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %9 = tt.splat %3 : i1 -> tensor<128x128xi1, #blocked> + %10 = ttg.async_copy_global_to_local %arg2, %8 mask %9 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %11 = ttg.async_commit_group %10 + %43 = ttg.memdesc_subview %scale_A_SMEM[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x4xi8, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x4xi8, #shared1, #ttg.shared_memory, mutable> + %44 = tt.splat %3 : i1 -> tensor<128x4xi1, #blocked> + %45 = ttg.async_copy_global_to_local %scale_A, %43 mask %44 : tensor<128x4x!tt.ptr, #blocked> -> <128x4xi8, #shared1, #ttg.shared_memory, mutable> + %46 = ttg.async_commit_group %45 + %47 = ttg.memdesc_subview %scale_B_SMEM[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x4xi8, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x4xi8, #shared1, #ttg.shared_memory, mutable> + %48 = tt.splat %3 : i1 -> tensor<128x4xi1, #blocked> + %49 = ttg.async_copy_global_to_local %scale_B, %47 mask %48 : tensor<128x4x!tt.ptr, #blocked> -> <128x4xi8, #shared1, #ttg.shared_memory, mutable> + %50 = ttg.async_commit_group %49 + %51 = ttg.async_wait %50 {num = 0 : i32} + %12:9 = scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg4 = %c0_i32, %arg5 = %c-1_i32, %arg6 = %7, %arg7 = %11, %arg8 = %43, %arg8_token = %51, %arg9 = %47, %arg9_token = %51, %acc = %cst) -> (i32, i32, !ttg.async.token, !ttg.async.token, !ttg.memdesc<128x4xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.async.token, !ttg.memdesc<128x4xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.async.token, tensor<128x128xf32, #blocked1>) : i32 { + %15 = arith.subi %arg0, %c1_i32 : i32 + %16 = arith.cmpi slt, %arg3, %15 : i32 + %17 = arith.addi %arg5, %c1_i32 : i32 + %18 = arith.cmpi slt, %17, %c2_i32 : i32 + %19 = arith.select %18, %17, %c0_i32 : i32 + %20 = ttg.memdesc_subview %1[%19, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %21 = ttg.async_wait %arg7 {num = 0 : i32} + %22 = ttg.memdesc_subview %2[%19, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %122 = ttg.local_load %arg8 token %arg8_token : !ttg.memdesc<128x4xi8, #shared1, #ttg.shared_memory, mutable> -> tensor<128x4xi8, #blocked2> + %123 = ttg.local_load %arg9 token %arg9_token : !ttg.memdesc<128x4xi8, #shared1, #ttg.shared_memory, mutable> -> tensor<128x4xi8, #blocked2> + %125 = ttg.convert_layout %122 : tensor<128x4xi8, #blocked2> -> tensor<128x4xi8, #blocked3> + %126 = ttg.convert_layout %123 : tensor<128x4xi8, #blocked2> -> tensor<128x4xi8, #blocked3> + %127 = ttng.tmem_alloc %125 : (tensor<128x4xi8, #blocked3>) -> !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory> + %128 = ttng.tmem_alloc %126 : (tensor<128x4xi8, #blocked3>) -> !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory> + %tmem = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma_scaled %20, %22, %tmem, %127, %128, %true, %true lhs = e5m2 rhs = e5m2: (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory>, !ttg.memdesc<128x4xi8, #tmem1, #ttng.tensor_memory>, i1, i1) -> () + %acc_res = ttng.tmem_load %tmem : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> + %23 = arith.addi %arg4, %c1_i32 : i32 + %24 = arith.cmpi slt, %23, %c2_i32 : i32 + %25 = arith.select %24, %23, %c0_i32 : i32 + %26 = ttg.memdesc_subview %1[%25, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %27 = tt.splat %16 : i1 -> tensor<128x128xi1, #blocked> + %28 = ttg.async_copy_global_to_local %arg1, %26 mask %27 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %29 = ttg.async_commit_group %28 + %30 = ttg.memdesc_subview %2[%25, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %31 = tt.splat %16 : i1 -> tensor<128x128xi1, #blocked> + %32 = ttg.async_copy_global_to_local %arg2, %30 mask %31 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %33 = ttg.async_commit_group %32 + %34 = ttg.memdesc_subview %scale_A_SMEM[%25, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x4xi8, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x4xi8, #shared1, #ttg.shared_memory, mutable> + %35 = tt.splat %16 : i1 -> tensor<128x4xi1, #blocked> + %36 = ttg.async_copy_global_to_local %scale_A, %34 mask %35 : tensor<128x4x!tt.ptr, #blocked> -> <128x4xi8, #shared1, #ttg.shared_memory, mutable> + %37 = ttg.async_commit_group %36 + %38 = ttg.memdesc_subview %scale_B_SMEM[%25, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x4xi8, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x4xi8, #shared1, #ttg.shared_memory, mutable> + %39 = tt.splat %16 : i1 -> tensor<128x4xi1, #blocked> + %40 = ttg.async_copy_global_to_local %scale_B, %38 mask %39 : tensor<128x4x!tt.ptr, #blocked> -> <128x4xi8, #shared1, #ttg.shared_memory, mutable> + %41 = ttg.async_commit_group %40 + %42 = ttg.async_wait %41 {num = 0 : i32} + scf.yield %25, %19, %29, %33, %34, %42, %38, %42, %acc_res : i32, i32, !ttg.async.token, !ttg.async.token, !ttg.memdesc<128x4xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.async.token, !ttg.memdesc<128x4xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.async.token, tensor<128x128xf32, #blocked1> + } + %13 = ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %1 : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %2 : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %scale_A_SMEM : !ttg.memdesc<2x128x4xi8, #shared1, #ttg.shared_memory, mutable> + ttg.local_dealloc %scale_B_SMEM : !ttg.memdesc<2x128x4xi8, #shared1, #ttg.shared_memory, mutable> + tt.return %12#8 : tensor<128x128xf32, #blocked1> + } +} + + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @pipeline_tc05mma_scaled_shmem + // CHECK: ttng.wait_barrier + + tt.func public @pipeline_tc05mma_scaled_shmem(%arg0: i32, %arg1: tensor<128x128x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg2: tensor<128x128x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %scale_A: tensor<1x512x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %scale_B: tensor<1x512x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #blocked1> attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %1 = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + %2 = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + %scale_A_SMEM = ttg.local_alloc : () -> !ttg.memdesc<2x1x512xi8, #shared1, #ttg.shared_memory, mutable> + %scale_B_SMEM = ttg.local_alloc : () -> !ttg.memdesc<2x1x512xi8, #shared1, #ttg.shared_memory, mutable> + %3 = arith.cmpi sgt, %arg0, %c0_i32 : i32 + %4 = ttg.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %5 = tt.splat %3 : i1 -> tensor<128x128xi1, #blocked> + %6 = ttg.async_copy_global_to_local %arg1, %4 mask %5 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %7 = ttg.async_commit_group %6 + %8 = ttg.memdesc_subview %2[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %9 = tt.splat %3 : i1 -> tensor<128x128xi1, #blocked> + %10 = ttg.async_copy_global_to_local %arg2, %8 mask %9 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %11 = ttg.async_commit_group %10 + + %43 = ttg.memdesc_subview %scale_A_SMEM[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x1x512xi8, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable> + %44 = tt.splat %3 : i1 -> tensor<1x512xi1, #blocked> + %45 = ttg.async_copy_global_to_local %scale_A, %43 mask %44 : tensor<1x512x!tt.ptr, #blocked> -> <1x512xi8, #shared1, #ttg.shared_memory, mutable> + %46 = ttg.async_commit_group %45 + %47 = ttg.memdesc_subview %scale_B_SMEM[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x1x512xi8, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable> + %48 = tt.splat %3 : i1 -> tensor<1x512xi1, #blocked> + %49 = ttg.async_copy_global_to_local %scale_B, %47 mask %48 : tensor<1x512x!tt.ptr, #blocked> -> <1x512xi8, #shared1, #ttg.shared_memory, mutable> + %50 = ttg.async_commit_group %49 + %51 = ttg.async_wait %50 {num = 0 : i32} + %12:9 = scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg4 = %c0_i32, %arg5 = %c-1_i32, %arg6 = %7, %arg7 = %11, %arg8 = %43, %arg8_token = %51, %arg9 = %47, %arg9_token = %51, %acc = %cst) -> (i32, i32, !ttg.async.token, !ttg.async.token, !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.async.token, !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.async.token, tensor<128x128xf32, #blocked1>) : i32 { + %15 = arith.subi %arg0, %c1_i32 : i32 + %16 = arith.cmpi slt, %arg3, %15 : i32 + %17 = arith.addi %arg5, %c1_i32 : i32 + %18 = arith.cmpi slt, %17, %c2_i32 : i32 + %19 = arith.select %18, %17, %c0_i32 : i32 + %20 = ttg.memdesc_subview %1[%19, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %21 = ttg.async_wait %arg7 {num = 0 : i32} + %22 = ttg.memdesc_subview %2[%19, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + + %tmem = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + + ttng.tc_gen5_mma_scaled %20, %22, %tmem, %arg8, %arg9, %true, %true lhs = e5m2 rhs = e5m2: (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %tmem : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> + %23 = arith.addi %arg4, %c1_i32 : i32 + %24 = arith.cmpi slt, %23, %c2_i32 : i32 + %25 = arith.select %24, %23, %c0_i32 : i32 + %26 = ttg.memdesc_subview %1[%25, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %27 = tt.splat %16 : i1 -> tensor<128x128xi1, #blocked> + %28 = ttg.async_copy_global_to_local %arg1, %26 mask %27 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %29 = ttg.async_commit_group %28 + %30 = ttg.memdesc_subview %2[%25, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %31 = tt.splat %16 : i1 -> tensor<128x128xi1, #blocked> + %32 = ttg.async_copy_global_to_local %arg2, %30 mask %31 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf16, #shared, #ttg.shared_memory, mutable> + %33 = ttg.async_commit_group %32 + %34 = ttg.memdesc_subview %scale_A_SMEM[%25, %c0_i32, %c0_i32] : !ttg.memdesc<2x1x512xi8, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable> + %35 = tt.splat %16 : i1 -> tensor<1x512xi1, #blocked> + %36 = ttg.async_copy_global_to_local %scale_A, %34 mask %35 : tensor<1x512x!tt.ptr, #blocked> -> <1x512xi8, #shared1, #ttg.shared_memory, mutable> + %37 = ttg.async_commit_group %36 + %38 = ttg.memdesc_subview %scale_B_SMEM[%25, %c0_i32, %c0_i32] : !ttg.memdesc<2x1x512xi8, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable> + %39 = tt.splat %16 : i1 -> tensor<1x512xi1, #blocked> + %40 = ttg.async_copy_global_to_local %scale_B, %38 mask %39 : tensor<1x512x!tt.ptr, #blocked> -> <1x512xi8, #shared1, #ttg.shared_memory, mutable> + %41 = ttg.async_commit_group %40 + %42 = ttg.async_wait %41 {num = 0 : i32} + scf.yield %25, %19, %29, %33, %34, %42, %38, %42, %acc_res : i32, i32, !ttg.async.token, !ttg.async.token, !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.async.token, !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.async.token, tensor<128x128xf32, #blocked1> + } + %13 = ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %1 : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %2 : !ttg.memdesc<2x128x128xf16, #shared, #ttg.shared_memory, mutable> + ttg.local_dealloc %scale_A_SMEM : !ttg.memdesc<2x1x512xi8, #shared1, #ttg.shared_memory, mutable> + ttg.local_dealloc %scale_B_SMEM : !ttg.memdesc<2x1x512xi8, #shared1, #ttg.shared_memory, mutable> + tt.return %12#8 : tensor<128x128xf32, #blocked1> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/ops.mlir b/third_party/enflame/include/triton/test/TritonGPU/ops.mlir new file mode 100644 index 000000000..b300c0595 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/ops.mlir @@ -0,0 +1,209 @@ +// RUN: triton-opt --split-input-file %s | FileCheck %s + +// CHECK: #[[$WMMA_GEN1:.*]] = #ttg.amd_wmma<{{.*}}version = 1{{.*}}> +// CHECK: #[[$WMMA_GEN2:.*]] = #ttg.amd_wmma<{{.*}}version = 2{{.*}}> +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> + +module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_layout + tt.func @wmma_layout(%0: tensor<16x16xf16, #blocked>) { + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 1]}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN1]]> + tt.return + } + + // CHECK-LABEL: wmma_dot_op_layout + tt.func @wmma_dot_op_layout(%0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) { + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 1]}>, kWidth = 16}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN1]], kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: wmma_gen2_layout + tt.func @wmma_gen2_layout(%0: tensor<16x16xf16, #blocked>) { + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN2]]> + tt.return + } + + // CHECK-LABEL: wmma_gen2_dot_op_layout + tt.func @wmma_gen2_dot_op_layout(%0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) { + %1 = ttg.convert_layout %0 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>, kWidth = 8}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN2]], kWidth = 8}>> + tt.return + } +} +// ----- + +#blocked= #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[$LINEAR:.*]] = #ttg.linear<{{.*}}> + +module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @blocked_to_linear + tt.func @blocked_to_linear(%input: tensor<32x4xi8, #blocked>) { + // The layout is the basic layout generated by DecomposeScaledBlocked + %output = ttg.convert_layout %input {allocation.offset = 0 : i32} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #ttg.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}>> + // CHECK: %{{.+}} = ttg.convert_layout %{{.+}} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #[[$LINEAR]]> + tt.return + } +} + +// ----- + +#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: memdesc + // CHECK-SAME: !ttg.memdesc<1x64x16xf16, #{{.+}}> + tt.func @memdesc(%d : !ttg.memdesc<1x64x16xf16, #shared0, #smem>) { + tt.return + } +} + +// ----- + +// CHECK-LABEL: @warp_specialize_nothing +tt.func @warp_specialize_nothing() { + // CHECK-NEXT: ttg.warp_specialize() + ttg.warp_specialize() + // CHECK-NEXT: default { + default { + // CHECK-NEXT: ttg.warp_yield + ttg.warp_yield + // CHECK-NEXT: } : () -> () + } : () -> () + tt.return +} + +// CHECK-LABEL: @warp_specialize_no_partitions +tt.func @warp_specialize_no_partitions(%arg0: i32, %arg1: i64) -> i64 { + // CHECK-NEXT: %0 = ttg.warp_specialize(%arg0) + %0 = ttg.warp_specialize(%arg0) + // CHECK-NEXT: default { + default { + // CHECK-NEXT: ttg.warp_yield %arg1 : i64 + ttg.warp_yield %arg1 : i64 + // CHECK-NEXT: } : (i32) -> i64 + } : (i32) -> i64 + tt.return %0 : i64 +} + +// CHECK-LABEL: @warp_specialize_partitions +tt.func @warp_specialize_partitions(%arg0: i32, %arg1: i64) -> i64 { + // CHECK-NEXT: %0 = ttg.warp_specialize(%arg0) + %0 = ttg.warp_specialize(%arg0) + // CHECK-NEXT: default { + default { + // CHECK-NEXT: ttg.warp_yield %arg1 : i64 + ttg.warp_yield %arg1 : i64 + // CHECK-NEXT: } + } + // CHECK-NEXT: partition0(%arg2: i32) num_warps(4) { + partition0(%arg2: i32) num_warps(4) { + // CHECK-NEXT: arith.addi %arg2, %arg2 : i32 + %1 = arith.addi %arg2, %arg2 : i32 + // CHECK-NEXT: ttg.warp_return + ttg.warp_return + // CHECK-NEXT: } + } + // CHECK-NEXT: partition1(%arg2: i32) num_warps(1) { + partition1(%arg2: i32) num_warps(1) { + // CHECK-NEXT: ttg.warp_return + ttg.warp_return + // CHECK-NEXT: } + } + // CHECK-NEXT: partition2(%arg2: i32) num_warps(8) { + partition2(%arg2: i32) num_warps(8) { + // CHECK-NEXT: arith.muli + %1 = arith.muli %arg2, %arg2 : i32 + // CHECK-NEXT: ttg.warp_return + ttg.warp_return + // CHECK-NEXT: } : (i32) -> i64 + } : (i32) -> i64 + tt.return %0 : i64 +} + +// CHECK-LABEL: @warp_specialize_multiple_args +tt.func @warp_specialize_multiple_args_res(%arg0: i32, %arg1: i32) -> (i32, i32) { + // CHECK-NEXT: %0:2 = ttg.warp_specialize(%arg0, %arg1) + %0:2 = ttg.warp_specialize(%arg0, %arg1) + // CHECK-NEXT: default { + default { + // CHECK-NEXT: ttg.warp_yield %arg0, %arg1 : i32, i32 + ttg.warp_yield %arg0, %arg1 : i32, i32 + // CHECK-NEXT: } + } + // CHECK-NEXT: partition0(%arg2: i32, %arg3: i32) num_warps(4) { + partition0(%arg2: i32, %arg3: i32) num_warps(4) { + // CHECK-NEXT: arith.addi %arg2, %arg3 : i32 + %1 = arith.addi %arg2, %arg3 : i32 + // CHECK-NEXT: ttg.warp_return + ttg.warp_return + // CHECK-NEXT: } : (i32, i32) -> (i32, i32) + } : (i32, i32) -> (i32, i32) + tt.return %0#0, %0#1 : i32, i32 +} + +// ----- + +// CHECK-DAG: [[BLOCKED_1_WARPS:#.*]] = #ttg.blocked{{.*}} warpsPerCTA = [1] +#blocked_1_warps = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +// CHECK-DAG: [[BLOCKED_2_WARPS:#.*]] = #ttg.blocked{{.*}} warpsPerCTA = [2] +#blocked_2_warps = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> +// CHECK-DAG: [[BLOCKED_4_WARPS:#.*]] = #ttg.blocked{{.*}} warpsPerCTA = [4] +#blocked_4_warps = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK-DAG: [[BLOCKED_8_WARPS:#.*]] = #ttg.blocked{{.*}} warpsPerCTA = [8] +#blocked_8_warps = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> + +module attributes {"ttg.num-warps" = 4 : i32} { + +// CHECK: @function_scope +tt.func @function_scope() attributes {"ttg.num-warps" = 8 : i32} { + // CHECK-NEXT: tt.make_range {{.*}} tensor<128xi32, [[BLOCKED_8_WARPS]]> + tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked_8_warps> + tt.return +} + +// CHECK: @function_no_scope +tt.func @function_no_scope() { + // CHECK-NEXT: tt.make_range {{.*}} tensor<128xi32, [[BLOCKED_4_WARPS]]> + tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked_4_warps> + // CHECK-NEXT: ttg.warp_specialize() + ttg.warp_specialize() + default { + ttg.warp_yield + } + // CHECK: partition0() num_warps(2) + partition0() num_warps(2) { + // CHECK-NEXT: tt.make_range {{.*}} tensor<128xi32, [[BLOCKED_2_WARPS]]> + tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked_2_warps> + ttg.warp_return + } + // CHECK: partition1() num_warps(1) + partition1() num_warps(1) { + // CHECK-NEXT: tt.make_range {{.*}} tensor<128xi32, [[BLOCKED_1_WARPS]]> + tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked_1_warps> + ttg.warp_return + } : () -> () + tt.return +} + +} + +// ----- + +// CHECK-DAG: [[$BLOCKED:#.*]] = #ttg.blocked +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-DAG: [[$LINEAR:#.*]] = #ttg.linear +#linear = #ttg.linear<{register = [[0, 1], [16, 0], [32, 0], [64, 0]], lane = [[0, 0], [0, 0], [0, 0], [1, 0], [2, 0]], warp = [[4, 0], [8, 0]], block = []}> + +module attributes {"ttg.num-warps" = 4 : i32} { +// CHECK-LABEL: @split_join_linear_mix +tt.func @split_join_linear_mix(%arg: tensor<128x2xf32, #linear>) attributes {"ttg.num-warps" = 8 : i32} { + // CHECK-NEXT: tt.split %{{.*}} : tensor<128x2xf32, [[$LINEAR]]> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = [[$BLOCKED]]}>> + %lhs, %rhs = tt.split %arg : tensor<128x2xf32, #linear> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + // CHECK-NEXT: tt.join %{{.*}}, %{{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = [[$BLOCKED]]}>> -> tensor<128x2xf32, [[$LINEAR]]> + %j = tt.join %lhs, %rhs : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x2xf32, #linear> + tt.return +} +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/optimize-locality.mlir b/third_party/enflame/include/triton/test/TritonGPU/optimize-locality.mlir new file mode 100644 index 000000000..1214ed9c1 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/optimize-locality.mlir @@ -0,0 +1,771 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-optimize-thread-locality -canonicalize | FileCheck %s + +// CHECK-LABEL: negative_zero_accumulator +// CHECK: %[[INIT_ARG:.*]] = arith.constant dense<0.000000e+00> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[INIT_ARG]]) -> {{.*}} +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.addf +// CHECK: arith.addf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] +// CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @negative_zero_accumulator( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<-0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs y : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: positive_zero_accumulator +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> +// CHECK-NEXT: %[[CST1:.*]] = arith.constant dense<0.000000e+00> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST1]]) -> {{.*}} +// CHECK: tt.load +// CHECK: tt.reshape +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.addf +// CHECK: arith.addf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] +// CHECK: arith.addf %[[CVT_OUTPUT]], %[[CST]] +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @positive_zero_accumulator( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs y : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: slice_layout +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK-NEXT: "tt.reduce"(%[[LOAD]]) <{axis = 1 : i32}> +// CHECK: arith.addf +// CHECK: arith.addf +// CHECK-NEXT: scf.yield +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[LOOP_OUTPUT]] +#blocked3d = #ttg.blocked<{sizePerThread = [1, 4, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#slice2d = #ttg.slice<{dim = 2, parent = #blocked3d}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @slice_layout( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #slice2d> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs y : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #slice2d}>> -> tensor<1x128xi32, #slice2d> + %31 = tt.broadcast %30 : tensor<1x128xi32, #slice2d> -> tensor<32x128xi32, #slice2d> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #slice2d>, tensor<32x128xi32, #slice2d> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #slice2d> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #slice2d>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> + } + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #slice2d}>> -> tensor<32xf32, #blocked1> + tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: mma_layout +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK-NEXT: "tt.reduce"(%[[LOAD]]) <{axis = 1 : i32}> +// CHECK: arith.addf +// CHECK: arith.addf +// CHECK-NEXT: scf.yield +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[LOOP_OUTPUT]] +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @mma_layout( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #mma> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs y : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x128xi32, #mma> + %31 = tt.broadcast %30 : tensor<1x128xi32, #mma> -> tensor<32x128xi32, #mma> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #mma>, tensor<32x128xi32, #mma> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #mma> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32xf32, #blocked1> + tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: max_reduce +// CHECK: %[[INIT_ARG:.*]] = arith.constant dense<0xFF800000> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[INIT_ARG]]) -> {{.*}} +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.maximumf +// CHECK: arith.maximumf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.maximumf +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] +// CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @max_reduce( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0xFF800000> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs y : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.maximumf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: max_reduce_zero_int_accumulator +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> +// CHECK-NEXT: %[[CST1:.*]] = arith.constant dense<0xFF800000> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST1]]) -> {{.*}} +// CHECK: tt.load +// CHECK: tt.reshape +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.maximumf +// CHECK: arith.maximumf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.maximumf +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] +// CHECK: arith.maximumf %[[CVT_OUTPUT]], %[[CST]] +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @max_reduce_zero_int_accumulator( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs y : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.maximumf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: min_reduce +// CHECK: %[[CST:.*]] = arith.constant dense<0x7F800000> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.minimumf +// CHECK: arith.minimumf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.minimumf +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] +// CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @min_reduce( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0x7F800000> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs y : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.minimumf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: min_reduce_zero_int_accumulator +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> +// CHECK-NEXT: %[[CST1:.*]] = arith.constant dense<0x7F800000> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST1]]) -> {{.*}} +// CHECK: tt.load +// CHECK: tt.reshape +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.minimumf +// CHECK: arith.minimumf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.minimumf +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] +// CHECK: arith.minimumf %[[CVT_OUTPUT]], %[[CST]] +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @min_reduce_zero_int_accumulator( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs y : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.minimumf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: mul_reduce +// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.mulf +// CHECK: arith.mulf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.mulf +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] +// CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @mul_reduce( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs y : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.mulf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.mulf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: mul_reduce_zero_int_accumulator +// CHECK: %[[CST:.*]] = arith.constant dense +// CHECK-NEXT: %[[CST1:.*]] = arith.constant dense<1.000000e+00> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST1]]) -> {{.*}} +// CHECK: tt.load +// CHECK: tt.reshape +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.mulf +// CHECK: arith.mulf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.mulf +// CHECK: %[[CVT_OUTPUT:.*]] = ttg.convert_layout %[[FINAL_REDUCE]] +// CHECK: arith.mulf %[[CVT_OUTPUT]], %[[CST]] +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @mul_reduce_zero_int_accumulator( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs y : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.mulf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.mulf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + + +// ----- + +// CHECK-LABEL: remains_unchanged +// CHECK: %[[CST:.*]] = arith.constant dense +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK: %[[MULF:.*]] = arith.mulf %[[LOAD]], %[[LOAD]] +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"(%[[MULF]]) <{axis = 1 : i32}> +// CHECK: arith.maximumf +// CHECK: arith.maximumf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @remains_unchanged( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs y : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 : tensor<32x128x!tt.ptr, #blocked> + %333 = arith.mulf %33, %33: tensor<32x128xf32, #blocked> + %34 = "tt.reduce"(%333) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.maximumf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } + %26 = ttg.convert_layout %19 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xf32, #blocked1> + tt.store %25, %26 : tensor<32x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-DAG: #[[$BLOCK0:.+]] = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCK1:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCK2:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +// CHECK-LABEL: optimize_view_layout +// CHECK: %[[R:.+]] = tt.reshape {{.*}} allow_reorder efficient_layout : tensor<8x128xf32, #[[$BLOCK0]]> -> tensor<64x16xf32, #[[$BLOCK2]]> +// CHECK: %[[C:.+]] = ttg.convert_layout %[[R]] : tensor<64x16xf32, #[[$BLOCK2]]> -> tensor<64x16xf32, #[[$BLOCK1]]> +// CHECK: "tt.reduce"(%[[C]]) +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @optimize_view_layout(%arg0: tensor<8x128xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> { + %0 = tt.reshape %arg0 allow_reorder : tensor<8x128xf32, #blocked> -> tensor<64x16xf32, #blocked1> + %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %2 = arith.maximumf %arg1, %arg2 : f32 + tt.reduce.return %2 : f32 + }) : (tensor<64x16xf32, #blocked1>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + tt.return %1 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> + } +} + +// ----- + + +// CHECK-DAG: #[[$BLOCK0:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> +// CHECK-DAG: #[[$BLOCK1:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +// CHECK-LABEL: optimize_view_layout_same_shape +// CHECK: %[[R:.+]] = tt.reshape {{.*}} allow_reorder efficient_layout : tensor<64x16xf32, #[[$BLOCK0]]> -> tensor<64x16xf32, #[[$BLOCK1]]> +// CHECK: %[[C:.+]] = ttg.convert_layout %[[R]] : tensor<64x16xf32, #[[$BLOCK1]]> -> tensor<64x16xf32, #[[$BLOCK0]]> +// CHECK: "tt.reduce"(%[[C]]) +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @optimize_view_layout_same_shape(%arg0: tensor<64x16xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> { + %0 = tt.reshape %arg0 allow_reorder : tensor<64x16xf32, #blocked> -> tensor<64x16xf32, #blocked> + %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %2 = arith.maximumf %arg1, %arg2 : f32 + tt.reduce.return %2 : f32 + }) : (tensor<64x16xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + tt.return %1 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + } +} + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#slice = #ttg.slice<{dim = 1, parent = #blocked}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { + tt.func public @reduce_for_arg(%arg: tensor<64x128xf32, #blocked>, %arg1: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %c128_i32 = arith.constant 128 : i32 + %c4096_i32 = arith.constant 4096 : i32 + %cst_1 = arith.constant dense<1.000000e+00> : tensor<64x128xf32, #blocked> + %64:1 = scf.for %arg22 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg29 = %arg) -> (tensor<64x128xf32, #blocked>) : i32 { + %129 = "tt.reduce"(%arg29) <{axis = 1 : i32}> ({ + ^bb0(%arg31: f32, %arg32: f32): + %160 = arith.maxnumf %arg31, %arg32 : f32 + tt.reduce.return %160 : f32 + }) : (tensor<64x128xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %75 = ttg.convert_layout %129 : tensor<64xf32, #slice> -> tensor<64xf32, #blocked1> + %79 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked1> + %80 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> + %81 = tt.addptr %80, %79 : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> + tt.store %81, %75 : tensor<64x!tt.ptr, #blocked1> + %141 = arith.addf %arg29, %cst_1 : tensor<64x128xf32, #blocked> + scf.yield %141 : tensor<64x128xf32, #blocked> + } + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [1, 0]}> + +// CHECK: [[LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK: set_warp_shuffle_layout_square_axis_0 +tt.func @set_warp_shuffle_layout_square_axis_0(%arg0: tensor<64x64xf32, #blocked>, %arg1: tensor<64x64xi32, #blocked>) -> tensor<64x64xf32, #blocked> { + // CHECK-NEXT: [[SRC:%.*]] = ttg.convert_layout %arg0 + // CHECK-NEXT: [[IDX:%.*]] = ttg.convert_layout %arg1 + // CHECK-NEXT: [[OUT:%.*]] = tt.gather [[SRC]][[[IDX]]] {axis = 0 : i32, efficient_layout} : (tensor<64x64xf32, [[LAYOUT]]>, tensor<64x64xi32, [[LAYOUT]]>) -> tensor<64x64xf32, [[LAYOUT]]> + %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<64x64xf32, #blocked>, tensor<64x64xi32, #blocked>) -> tensor<64x64xf32, #blocked> + // CHECK-NEXT: [[RES:%.*]] = ttg.convert_layout [[OUT]] + // CHECK-NEXT: return [[RES]] + tt.return %0 : tensor<64x64xf32, #blocked> +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [1, 0]}> + +// CHECK: [[LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK: set_warp_shuffle_layout_square_axis_1 +tt.func @set_warp_shuffle_layout_square_axis_1(%arg0: tensor<64x64xf32, #blocked>, %arg1: tensor<64x64xi32, #blocked>) -> tensor<64x64xf32, #blocked> { + // CHECK: tt.gather {{.*}} (tensor<64x64xf32, [[LAYOUT]]>, tensor<64x64xi32, [[LAYOUT]]>) -> tensor<64x64xf32, [[LAYOUT]]> + %0 = tt.gather %arg0[%arg1] {axis = 1 : i32} : (tensor<64x64xf32, #blocked>, tensor<64x64xi32, #blocked>) -> tensor<64x64xf32, #blocked> + tt.return %0 : tensor<64x64xf32, #blocked> +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [1, 0]}> + +// CHECK: [[LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK: set_warp_shuffle_layout_warp_broadcast +tt.func @set_warp_shuffle_layout_warp_broadcast(%arg0: tensor<64x64xf32, #blocked>, %arg1: tensor<64x1xi32, #blocked>) -> tensor<64x1xf32, #blocked> { + // CHECK: tt.gather {{.*}} [[LAYOUT]]> + %0 = tt.gather %arg0[%arg1] {axis = 1 : i32} : (tensor<64x64xf32, #blocked>, tensor<64x1xi32, #blocked>) -> tensor<64x1xf32, #blocked> + tt.return %0 : tensor<64x1xf32, #blocked> +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2, 1], threadsPerWarp = [16, 2, 1], warpsPerCTA = [2, 1, 2], order = [1, 0, 2]}> + +// CHECK: [[LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK: set_warp_shuffle_layout_3d_warp +tt.func @set_warp_shuffle_layout_3d_warp(%arg0: tensor<32x2x32xf32, #blocked>, %arg1: tensor<32x2x2xi32, #blocked>) -> tensor<32x2x2xf32, #blocked> { + // CHECK: tt.gather {{.*}} [[LAYOUT]]> + %0 = tt.gather %arg0[%arg1] {axis = 2 : i32} : (tensor<32x2x32xf32, #blocked>, tensor<32x2x2xi32, #blocked>) -> tensor<32x2x2xf32, #blocked> + tt.return %0 : tensor<32x2x2xf32, #blocked> +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2, 1], threadsPerWarp = [16, 2, 1], warpsPerCTA = [2, 1, 2], order = [1, 0, 2]}> + +// CHECK: [[LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK: set_warp_shuffle_layout_3d_warp_thread_split +tt.func @set_warp_shuffle_layout_3d_warp_thread_split(%arg0: tensor<32x4x16xf32, #blocked>, %arg1: tensor<32x4x2xi32, #blocked>) -> tensor<32x4x2xf32, #blocked> { + // CHECK: tt.gather {{.*}} [[LAYOUT]]> + %0 = tt.gather %arg0[%arg1] {axis = 2 : i32} : (tensor<32x4x16xf32, #blocked>, tensor<32x4x2xi32, #blocked>) -> tensor<32x4x2xf32, #blocked> + tt.return %0 : tensor<32x4x2xf32, #blocked> +} + +} + + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [1, 0]}> + +// CHECK: [[LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK: set_warp_shuffle_layout_thread_broadcast +tt.func @set_warp_shuffle_layout_thread_broadcast(%arg0: tensor<16x64xf32, #blocked>, %arg1: tensor<16x1xi32, #blocked>) -> tensor<16x1xf32, #blocked> { + // CHECK: tt.gather {{.*}} [[LAYOUT]]> + %0 = tt.gather %arg0[%arg1] {axis = 1 : i32} : (tensor<16x64xf32, #blocked>, tensor<16x1xi32, #blocked>) -> tensor<16x1xf32, #blocked> + tt.return %0 : tensor<16x1xf32, #blocked> +} + +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [1, 0]}> + +// CHECK: [[LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + +// CHECK: set_warp_shuffle_layout_large_source +tt.func @set_warp_shuffle_layout_large_source(%arg0: tensor<256x256xf32, #blocked>, %arg1: tensor<256x8xi32, #blocked>) -> tensor<256x8xf32, #blocked> { + // CHECK: tt.gather {{.*}} [[LAYOUT]]> + %0 = tt.gather %arg0[%arg1] {axis = 1 : i32} : (tensor<256x256xf32, #blocked>, tensor<256x8xi32, #blocked>) -> tensor<256x8xf32, #blocked> + tt.return %0 : tensor<256x8xf32, #blocked> +} + +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/optimize_epilogue.mlir b/third_party/enflame/include/triton/test/TritonGPU/optimize_epilogue.mlir new file mode 100644 index 000000000..142ec762f --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/optimize_epilogue.mlir @@ -0,0 +1,32 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-optimize-epilogue | FileCheck --check-prefixes=GCN %s + +#mfma = #ttg.amd_mfma<{warpsPerCTA=[1,1], instrShape=[32,32], isTranspose=false}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 16], warpsPerCTA = [1, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + // GCN-LABEL: mfma_epilogue_simple + // CHECK-LABEL: mfma_epilogue_simple + tt.func public @mfma_epilogue_simple(%data: tensor<64x64xf16, #mfma>, %ptr: tensor<64x64x!tt.ptr, #blocked>) { + // GCN: [[PTR:%[a-z0-9]+]] = ttg.convert_layout {{.*}} : tensor<{{.*}}, #blocked> -> tensor<{{.*}}, #mma> + // GCN: tt.store [[PTR]], {{.*}} : tensor<{{.*}}, #mma> + %converted_data = ttg.convert_layout %data : tensor<64x64xf16, #mfma> -> tensor<64x64xf16, #blocked> + tt.store %ptr, %converted_data : tensor<64x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{warpsPerCTA=[1,1], instrShape=[32,32], isTranspose=false}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 16], warpsPerCTA = [1, 1], order = [1, 0]}> +module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + // GCN-LABEL: mfma_epilogue_chained_elementwise + // CHECK-LABEL: mfma_epilogue_chained_elementwise + tt.func public @mfma_epilogue_chained_elementwise(%data: tensor<64x64xf32, #mfma>, %ptr: tensor<64x64x!tt.ptr, #blocked>) { + // GCN: [[PTR:%[a-z0-9]+]] = ttg.convert_layout {{.*}} : tensor<{{.*}}, #blocked> -> tensor<{{.*}}, #mma> + // GCN: tt.store [[PTR]], {{.*}} : tensor<{{.*}}, #mma> + %converted_data = ttg.convert_layout %data : tensor<64x64xf32, #mfma> -> tensor<64x64xf32, #blocked> + %trunked = arith.truncf %converted_data : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> + tt.store %ptr, %trunked : tensor<64x64x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/pipeline-assign-latencies.mlir b/third_party/enflame/include/triton/test/TritonGPU/pipeline-assign-latencies.mlir new file mode 100644 index 000000000..a42943f66 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/pipeline-assign-latencies.mlir @@ -0,0 +1,377 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-test-pipeline-assign-latencies=num-stages=3 -canonicalize | FileCheck %s + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}> +#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 16}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @default_stages +tt.func @default_stages(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @small_load +// We should *not* assign latency to the load of b_ptr. +tt.func @small_load(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} + // CHECK-NOT: tt.latency + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @load_into_shared +tt.func @load_into_shared(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #mma> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #mma> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.local_alloc %a_ : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> + + %c = ttng.warp_group_dot %a, %b, %prev_c {maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory> -> tensor<128x128xf32, #mma> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma> + } + tt.return %loop#2: tensor<128x128xf32, #mma> +} + +// CHECK-LABEL: @load_into_lt_4b +tt.func @load_into_lt_4b(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #mma> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #mma> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.local_alloc %a_ : (tensor<128x32xf16, #AL>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> + // Do not pipeline if cp.async would read less than 4 consecutive bytes + // CHECK: tt.load + // CHECK-NOT: {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #shared2, #ttg.shared_memory> + + %c = ttng.warp_group_dot %a, %b, %prev_c {maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<32x128xf16, #shared2, #ttg.shared_memory> -> tensor<128x128xf32, #mma> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #mma> + } + tt.return %loop#2: tensor<128x128xf32, #mma> +} + +// CHECK-LABEL: @intermediate_use +tt.func @intermediate_use(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + %c2 = arith.constant dense<2.00> : tensor<32x128xf16, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b_2 = arith.mulf %b_ , %c2 : tensor<32x128xf16, #BL> + %b = ttg.convert_layout %b_2 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @indirect_load +tt.func @indirect_load(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ind_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_ind_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %b_ind_ptr = %b_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr, #AL> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %b_off = tt.load %b_ind_ptr : tensor<32x128x!tt.ptr, #BL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ind_ptr = tt.addptr %b_ind_ptr, %b_ind_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off {tt.divisibility = dense<16> : tensor<32x128xi32>, tt.contiguity = dense<32> : tensor<32x128xi32>, tt.constancy = dense<1> : tensor<32x128xi32>} : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_b_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#4: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @mixed_loads +tt.func @mixed_loads(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr, #AL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#3: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @per_loop_stages +tt.func @per_loop_stages(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> (tensor<128x128xf32, #C>, tensor<128x128xf32, #C>) { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop_cust_stages:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 3 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 3 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } {tt.num_stages = 4 : i32} + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop_cust_stages#2, %loop#2: tensor<128x128xf32, #C>, tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @indirect_load_cust_stages +tt.func @indirect_load_cust_stages(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ind_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_ind_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %b_ind_ptr = %b_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr, #AL> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_off = tt.load %b_ind_ptr : tensor<32x128x!tt.ptr, #BL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ind_ptr = tt.addptr %b_ind_ptr, %b_ind_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off {tt.divisibility = dense<16> : tensor<32x128xi32>, tt.contiguity = dense<32> : tensor<32x128xi32>, tt.constancy = dense<1> : tensor<32x128xi32>} : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_b_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } {tt.num_stages = 5 : i32} + tt.return %loop#4: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @indirect_load_few_stages +tt.func @indirect_load_few_stages(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ind_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_ind_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %b_ind_ptr = %b_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load + // CHECK-NOT: tt.latency + %a_off = tt.load %a_ind_ptr : tensor<128x32x!tt.ptr, #AL> + // CHECK: tt.load + // CHECK-NOT: tt.latency + %b_off = tt.load %b_ind_ptr : tensor<32x128x!tt.ptr, #BL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ind_ptr = tt.addptr %b_ind_ptr, %b_ind_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_a_ptr = tt.addptr %a_ptr, %a_off {tt.divisibility = dense<16> : tensor<128x32xi32>, tt.contiguity = dense<32> : tensor<128x32xi32>, tt.constancy = dense<1> : tensor<128x32xi32>} : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off {tt.divisibility = dense<16> : tensor<32x128xi32>, tt.contiguity = dense<32> : tensor<32x128xi32>, tt.constancy = dense<1> : tensor<32x128xi32>} : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %a_ = tt.load %next_a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 1 : i32} + %b_ = tt.load %next_b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_b_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } {tt.num_stages = 2 : i32} + tt.return %loop#4: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @non_dot_pipeline +tt.func @non_dot_pipeline(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x32xf16, #A> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + + %c = arith.addf %a, %prev_c : tensor<128x32xf16, #A> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + scf.yield %next_a_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xf16, #A> + } {tt.num_stages = 3 : i32} + tt.return %loop#1: tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @no_pipeline +tt.func @no_pipeline(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x32xf16, #A> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x32xf16, #A>) { + // CHECK: tt.load + // CHECK-NOT: tt.latency + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + + %c = arith.addf %a, %prev_c : tensor<128x32xf16, #A> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + scf.yield %next_a_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xf16, #A> + } + tt.return %loop#1: tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @intermediate_use +tt.func @intermediate_use_cust_stages(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL> {tt.divisibility = 16 : i32, tt.contiguity = 32 : i32}) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + %c2 = arith.constant dense<2.00> : tensor<32x128xf16, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {tt.latency = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b_2 = arith.mulf %b_ , %c2 : tensor<32x128xf16, #BL> + %b = ttg.convert_layout %b_2 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } {tt.num_stages = 3 : i32} + tt.return %loop#2: tensor<128x128xf32, #C> +} + +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/pipeline-loop-nest.mlir b/third_party/enflame/include/triton/test/TritonGPU/pipeline-loop-nest.mlir new file mode 100644 index 000000000..e3497799c --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/pipeline-loop-nest.mlir @@ -0,0 +1,81 @@ +// RUN: triton-opt %s -pass-pipeline='builtin.module(convert-triton-to-tritongpu{num-warps=4 target=cuda:100},tritongpu-coalesce,tritongpu-accelerate-matmul,tritongpu-remove-layout-conversions,tritongpu-optimize-dot-operands,cse,tritongpu-fuse-nested-loops,canonicalize,tritongpu-optimize-accumulator-init,tritongpu-pipeline,canonicalize)' | FileCheck %s --check-prefix=BLACKWELL +// RUN: triton-opt %s -pass-pipeline='builtin.module(convert-triton-to-tritongpu{num-warps=4 target=cuda:90 },tritongpu-coalesce,tritongpu-accelerate-matmul,tritongpu-remove-layout-conversions,tritongpu-optimize-dot-operands,cse,tritongpu-fuse-nested-loops,canonicalize,tritongpu-optimize-accumulator-init,canonicalize,tritongpu-combine-tensor-select-and-if,tritongpu-pipeline,canonicalize)' | FileCheck %s --check-prefix=HOPPER + +// BLACKWELL-LABEL: @matmul_kernel_tma_persistent +// HOPPER-LABEL: @matmul_kernel_tma_persistent +tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32> + %c63_i32 = arith.constant 63 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %c132_i32 = arith.constant 132 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c127_i32 : i32 + %4 = arith.divsi %3, %c128_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.subi %0, %c132_i32 : i32 + %9 = arith.muli %4, %c8_i32 : i32 + + // BLACKWELL: [[ACC_BUFS:%.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem, + // BLACKWELL: ttg.memdesc_trans + // BLACKWELL: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]] + // BLACKWELL: ttng.tc_gen5_mma {{%[0-9]+}}, {{%[0-9]+}}, [[ACC_BUF]], %false + + // BLACKWELL: scf.for + %10 = scf.for %arg6 = %0 to %7 step %c132_i32 iter_args(%arg7 = %8) -> (i32) : i32 { + %11 = arith.divsi %arg6, %9 : i32 + %12 = arith.muli %11, %c8_i32 : i32 + %13 = arith.subi %2, %12 : i32 + %14 = arith.minsi %13, %c8_i32 : i32 + %15 = arith.remsi %arg6, %14 : i32 + %16 = arith.addi %12, %15 : i32 + %17 = arith.remsi %arg6, %9 : i32 + %18 = arith.divsi %17, %14 : i32 + %19 = arith.muli %16, %c128_i32 : i32 + %20 = arith.muli %18, %c128_i32 : i32 + %21 = scf.for %arg8 = %c0_i32 to %6 step %c1_i32 iter_args(%arg9 = %cst) -> (tensor<128x128xf32>) : i32 { + %35 = arith.muli %arg8, %c64_i32 : i32 + %36 = tt.reinterpret_tensor_descriptor %arg0 : !tt.ptr to !tt.tensordesc> + %37 = tt.experimental_descriptor_load %36[%19, %35] : !tt.tensordesc> -> tensor<128x64xf16> + %38 = tt.reinterpret_tensor_descriptor %arg1 : !tt.ptr to !tt.tensordesc> + %39 = tt.experimental_descriptor_load %38[%20, %35] : !tt.tensordesc> -> tensor<128x64xf16> + // BLACKWELL: ttg.memdesc_trans + // BLACKWELL: [[ACC_BUF:%.*]] = ttg.memdesc_subview [[ACC_BUFS]] + // BLACKWELL: ttng.tc_gen5_mma {{%[0-9]+}}, {{%[0-9]+}}, [[ACC_BUF]], %arg + + // HOPPER: [[RESULT:%.*]] = ttng.warp_group_dot {{.*}} isAsync = true + // HOPPER-NEXT: ttng.warp_group_dot_wait [[RESULT]], {{.*}} {pendings = 1 : i32} + %40 = tt.trans %39 {order = array} : tensor<128x64xf16> -> tensor<64x128xf16> + %41 = tt.dot %37, %40, %arg9, inputPrecision = tf32 : tensor<128x64xf16> * tensor<64x128xf16> -> tensor<128x128xf32> + scf.yield %41 : tensor<128x128xf32> + } + // BLACKWELL-COUNT-1: ttng.tmem_load + // BLACKWELL-NOT: ttng.tmem_load + + // HOPPER: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + %22 = arith.addi %arg7, %c132_i32 : i32 + %23 = arith.divsi %22, %9 : i32 + %24 = arith.muli %23, %c8_i32 : i32 + %25 = arith.subi %2, %24 : i32 + %26 = arith.minsi %25, %c8_i32 : i32 + %27 = arith.remsi %22, %26 : i32 + %28 = arith.addi %24, %27 : i32 + %29 = arith.remsi %22, %9 : i32 + %30 = arith.divsi %29, %26 : i32 + %31 = arith.muli %28, %c128_i32 : i32 + %32 = arith.muli %30, %c128_i32 : i32 + %33 = arith.truncf %21 : tensor<128x128xf32> to tensor<128x128xf16> + %34 = tt.reinterpret_tensor_descriptor %arg2 : !tt.ptr to !tt.tensordesc> + tt.experimental_descriptor_store %34[%31, %32], %33 : !tt.tensordesc>, tensor<128x128xf16> + scf.yield %22 : i32 + } {tt.flatten} + tt.return +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/pipeline-lower-loop.mlir b/third_party/enflame/include/triton/test/TritonGPU/pipeline-lower-loop.mlir new file mode 100644 index 000000000..a0f717925 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/pipeline-lower-loop.mlir @@ -0,0 +1,885 @@ +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-test-pipeline-lower-loop -canonicalize | FileCheck %s + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @unscheduled_loop +// CHECK: scf.for +// CHECK: tt.load +// CHECK: "use" +tt.func @unscheduled_loop(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> () { + scf.for %iv = %lb to %ub step %step : index { + %a = tt.load %a_ptr_init : tensor<128x32x!tt.ptr, #A> + "use"(%a) : (tensor<128x32xf16, #A>) -> () + } + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @one_dep_async +// CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 +// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 +// CHECK-DAG: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x32 +// CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant {{.*}} 2 : i32 +// CHECK: scf.for {{.*}} iter_args(%[[INS:.*]] = %[[MINUS_ONE]], %[[EXT:.*]] = %[[MINUS_ONE]]) +// CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 +// CHECK: %[[INS_CMP:.*]] = arith.cmpi slt, %[[INS_P1]], %[[NUM_BUFS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 +// CHECK: %[[INS_NEXT:.*]] = arith.select %[[INS_CMP]], %[[INS_P1]], %[[ZERO]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 +// CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 +// CHECK: %[[EXT_CMP:.*]] = arith.cmpi slt, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 +// CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[EXT_P1]], %[[ZERO]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 +// CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} +// CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[A_VAL:.*]] = ttg.local_load %[[A_EXT]] token %[[A_TOK3]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: "use"(%[[A_VAL]]) {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: scf.yield %[[INS_NEXT]], %[[EXT_NEXT]] +// CHECK-DAG: ttg.local_dealloc %[[A]] +// CHECK-DAG: ttg.async_wait {num = 0 : i32} + +tt.func @one_dep_async(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> () { + scf.for %iv = %lb to %ub step %step : index { + %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () + } {tt.scheduled_max_stage = 2 : i32} + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @different_use_stages +// CHECK: scf.for +// CHECK: ttg.async_copy_global_to_local %{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: ttg.async_wait {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} +// CHECK: ttg.memdesc_subview {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[A_VAL:.*]] = ttg.local_load {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: "use1"(%[[A_VAL]]) {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: "use2"(%[[A_VAL]]) {loop.cluster = 0 : i32, loop.stage = 3 : i32} +tt.func @different_use_stages(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> () { + scf.for %iv = %lb to %ub step %step : index { + %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + "use1"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () + "use2"(%a) {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<128x32xf16, #A>) -> () + } {tt.scheduled_max_stage = 3 : i32} + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @used_by_if_yield +// CHECK-DAG: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x32 +// CHECK: scf.for +// CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} +// CHECK: ttg.local_load {{.*}} token %[[A_TOK3]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: "use"{{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32} + +tt.func @used_by_if_yield(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>, + %init_a : tensor<128x32xf16, #A>, + %cnd : i1) -> () { + scf.for %iv = %lb to %ub step %step : index { + %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + %a_if = scf.if %cnd -> tensor<128x32xf16, #A> { + scf.yield %a : tensor<128x32xf16, #A> + } else { + scf.yield %init_a : tensor<128x32xf16, #A> + } {loop.cluster = 0 : i32, loop.stage = 2 : i32} + "use"(%a_if) {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<128x32xf16, #A>) -> () + } {tt.scheduled_max_stage = 3 : i32} + tt.return +} +} +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @dist1_load +tt.func @dist1_load(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>, + %init_a : tensor<128x32xf16, #A>) -> () { + %_ = scf.for %iv = %lb to %ub step %step iter_args(%prev_a = %init_a) -> (tensor<128x32xf16, #A>) : index { + "use_next_iter"(%prev_a) {loop.cluster = 2 : i32, loop.stage = 0 : i32} : (tensor<128x32xf16, #A>) -> () + %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () + scf.yield %a : tensor<128x32xf16, #A> + } {tt.scheduled_max_stage = 2 : i32} + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @one_dep_sync +// CHECK: scf.for +// CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} +tt.func @one_dep_sync(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<1x!tt.ptr, #A>) -> () { + scf.for %iv = %lb to %ub step %step : index { + %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<1x!tt.ptr, #A> + "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<1xf16, #A>) -> () + } {tt.scheduled_max_stage = 2 : i32} + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK: #[[SHARED:.*]] = #ttg.swizzled_shared +// CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 +// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 +// CHECK-DAG: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x32 +// CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant {{.*}} 2 : i32 +// CHECK: scf.for {{.*}} iter_args(%[[INS:.*]] = %[[MINUS_ONE]], %[[EXT:.*]] = %[[MINUS_ONE]]) +// CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 +// CHECK: %[[INS_CMP:.*]] = arith.cmpi slt, %[[INS_P1]], %[[NUM_BUFS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 +// CHECK: %[[INS_NEXT:.*]] = arith.select %[[INS_CMP]], %[[INS_P1]], %[[ZERO]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 +// CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 +// CHECK: %[[EXT_CMP:.*]] = arith.cmpi slt, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 +// CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[EXT_P1]], %[[ZERO]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 +// CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} +// CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[A_VAL:.*]] = ttg.local_load %[[A_EXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x32xf16, #[[SHARED]], # +// CHECK: "use"(%[[A_VAL]]) {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: scf.yield %[[INS_NEXT]], %[[EXT_NEXT]] +// CHECK-DAG: ttg.local_dealloc %[[A]] +// CHECK-DAG: ttg.async_wait {num = 0 : i32} +tt.func @one_dep_local_alloc(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> () { + scf.for %iv = %lb to %ub step %step : index { + %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + %a_alloc = ttg.local_alloc %a {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> + %a_load = ttg.local_load %a_alloc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x32xf16, #A> + "use"(%a_load) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () + } {tt.scheduled_max_stage = 2 : i32} + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @one_load_group +tt.func @one_load_group(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>, + %b_ptr_init : tensor<128x32x!tt.ptr, #A>) -> () { + // CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 + // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 + // CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant {{.*}} 2 : i32 + // Only one insert and extract index is used. + // CHECK: scf.for {{.*}} iter_args(%[[INS:.*]] = %[[MINUS_ONE]], %[[EXT:.*]] = %[[MINUS_ONE]]) -> + scf.for %iv = %lb to %ub step %step : index { + // CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] + // CHECK: %[[INS_CMP:.*]] = arith.cmpi slt, %[[INS_P1]], %[[NUM_BUFS]] + // CHECK: %[[INS_NEXT:.*]] = arith.select %[[INS_CMP]], %[[INS_P1]], %[[ZERO]] + // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] + // CHECK: %[[EXT_CMP:.*]] = arith.cmpi slt, %[[EXT_P1]], %[[NUM_BUFS]] + // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[EXT_P1]], %[[ZERO]] + %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + %b = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + "use1"(%a){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> () + "use2"(%b){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> () + } {tt.scheduled_max_stage = 2 : i32} + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @two_load_groups +tt.func @two_load_groups(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>, + %b_ptr_init : tensor<128x32x!tt.ptr, #A>, + %c_ptr_init : tensor<128x32x!tt.ptr, #A>) -> () { + // CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 + // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 + // CHECK-DAG: %[[NUM_BUFS2:.*]] = arith.constant {{.*}} 2 : i32 + // CHECK-DAG: %[[NUM_BUFS3:.*]] = arith.constant {{.*}} 3 : i32 + // Two insert and extract indices are used. + // CHECK: scf.for {{.*}} iter_args(%[[INS2:.*]] = %[[MINUS_ONE]], %[[EXT2:.*]] = %[[MINUS_ONE]], %[[INS3:.*]] = %[[MINUS_ONE]], %[[EXT3:.*]] = %[[MINUS_ONE]]) -> + scf.for %iv = %lb to %ub step %step : index { + // CHECK-DAG: %[[INS3_P1:.*]] = arith.addi %[[INS3]], %[[ONE]] + // CHECK-DAG: %[[INS3_CMP:.*]] = arith.cmpi slt, %[[INS3_P1]], %[[NUM_BUFS3]] + // CHECK-DAG: %[[INS3_NEXT:.*]] = arith.select %[[INS3_CMP]], %[[INS3_P1]], %[[ZERO]] + // CHECK-DAG: %[[EXT3_P1:.*]] = arith.addi %[[EXT3]], %[[ONE]] + // CHECK-DAG: %[[EXT3_CMP:.*]] = arith.cmpi slt, %[[EXT3_P1]], %[[NUM_BUFS3]] + // CHECK-DAG: %[[EXT3_NEXT:.*]] = arith.select %[[EXT3_CMP]], %[[EXT3_P1]], %[[ZERO]] + // CHECK-DAG: %[[INS2_P1:.*]] = arith.addi %[[INS2]], %[[ONE]] + // CHECK-DAG: %[[INS2_CMP:.*]] = arith.cmpi slt, %[[INS2_P1]], %[[NUM_BUFS2]] + // CHECK-DAG: %[[INS2_NEXT:.*]] = arith.select %[[INS2_CMP]], %[[INS2_P1]], %[[ZERO]] + // CHECK-DAG: %[[EXT2_P1:.*]] = arith.addi %[[EXT2]], %[[ONE]] + // CHECK-DAG: %[[EXT2_CMP:.*]] = arith.cmpi slt, %[[EXT2_P1]], %[[NUM_BUFS2]] + // CHECK-DAG: %[[EXT2_NEXT:.*]] = arith.select %[[EXT2_CMP]], %[[EXT2_P1]], %[[ZERO]] + %a = tt.load %a_ptr_init {loop.cluster = 3 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + %b = tt.load %a_ptr_init {loop.cluster = 3 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + %c = tt.load %a_ptr_init {loop.cluster = 3 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + "use1"(%a){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> () + "use2"(%b){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> () + "use3"(%c){loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<128x32xf32, #A>) -> () + } {tt.scheduled_max_stage = 3 : i32} + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @dependent_loads +tt.func @dependent_loads(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> () { + // CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 + // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 + // CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant {{.*}} 2 : i32 + // CHECK: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x32xf32 + // CHECK: %[[C:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x32xf32 + // CHECK: scf.for {{.*}} iter_args(%[[INS:.*]] = %[[MINUS_ONE]], %[[EXT:.*]] = %[[MINUS_ONE]]) -> + // CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK: %[[INS_CMP:.*]] = arith.cmpi slt, %[[INS_P1]], %[[NUM_BUFS]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK: %[[INS_NEXT:.*]] = arith.select %[[INS_CMP]], %[[INS_P1]], %[[ZERO]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[EXT_CMP:.*]] = arith.cmpi slt, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[EXT_P1]], %[[ZERO]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 2 : i32, loop.stage = 2 : i32, num = 0 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[A_VAL:.*]] = ttg.local_load %[[A_EXT]] token %[[A_TOK3]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[B:.*]] = "pointerize"(%[[A_VAL]]) {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[C_INS:.*]] = ttg.memdesc_subview %[[C]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[C_TOK:.*]] = ttg.async_copy_global_to_local %[[B]], %[[C_INS]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[C_TOK2:.*]] = ttg.async_commit_group %[[C_TOK]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[C_TOK3:.*]] = ttg.async_wait %[[C_TOK2]] {loop.cluster = 0 : i32, loop.stage = 4 : i32, num = 0 : i32} + // CHECK: %[[C_EXT:.*]] = ttg.memdesc_subview %[[C]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 4 : i32} + // CHECK: %[[C_VAL:.*]] = ttg.local_load %[[C_EXT]] token %[[C_TOK3]] {loop.cluster = 0 : i32, loop.stage = 4 : i32} + // CHECK: "use1"(%[[C_VAL]]) {loop.cluster = 0 : i32, loop.stage = 4 : i32} + // CHECK: scf.yield + // CHECK-DAG: ttg.local_dealloc %[[A]] + // CHECK-DAG: ttg.local_dealloc %[[C]] + // CHECK-DAG: ttg.async_wait {num = 0 : i32} + scf.for %iv = %lb to %ub step %step : index { + %a = tt.load %a_ptr_init {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + %b = "pointerize"(%a) {loop.cluster = 2 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> tensor<128x32x!tt.ptr, #A> + %c = tt.load %b {loop.cluster = 2 : i32, loop.stage = 2 : i32} : tensor<128x32x!tt.ptr, #A> + "use1"(%c){loop.cluster = 0 : i32, loop.stage = 4 : i32} : (tensor<128x32xf32, #A>) -> () + } {tt.scheduled_max_stage = 4 : i32} + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @dependent_loads_asymmetric +// Loads have different latencies, should create two load groups. +tt.func @dependent_loads_asymmetric(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> () { + // CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 + // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 + // CHECK-DAG: %[[NUM_BUFS2:.*]] = arith.constant {{.*}} 2 : i32 + // CHECK-DAG: %[[NUM_BUFS3:.*]] = arith.constant {{.*}} 3 : i32 + // CHECK: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x32xf32 + // CHECK: %[[C:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x32xf32 + // CHECK: scf.for {{.*}} iter_args(%[[INS2:.*]] = %[[MINUS_ONE]], %[[EXT2:.*]] = %[[MINUS_ONE]], %[[INS3:.*]] = %[[MINUS_ONE]], %[[EXT3:.*]] = %[[MINUS_ONE]]) -> + // CHECK-DAG: %[[INS3_P1:.*]] = arith.addi %[[INS3]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK-DAG: %[[INS3_CMP:.*]] = arith.cmpi slt, %[[INS3_P1]], %[[NUM_BUFS3]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK-DAG: %[[INS3_NEXT:.*]] = arith.select %[[INS3_CMP]], %[[INS3_P1]], %[[ZERO]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK-DAG: %[[EXT3_P1:.*]] = arith.addi %[[EXT3]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 5 : i32} + // CHECK-DAG: %[[EXT3_CMP:.*]] = arith.cmpi slt, %[[EXT3_P1]], %[[NUM_BUFS3]] {loop.cluster = 0 : i32, loop.stage = 5 : i32} + // CHECK-DAG: %[[EXT3_NEXT:.*]] = arith.select %[[EXT3_CMP]], %[[EXT3_P1]], %[[ZERO]] {loop.cluster = 0 : i32, loop.stage = 5 : i32} + // CHECK-DAG: %[[INS2_P1:.*]] = arith.addi %[[INS2]], %[[ONE]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK-DAG: %[[INS2_CMP:.*]] = arith.cmpi slt, %[[INS2_P1]], %[[NUM_BUFS2]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK-DAG: %[[INS2_NEXT:.*]] = arith.select %[[INS2_CMP]], %[[INS2_P1]], %[[ZERO]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK-DAG: %[[EXT2_P1:.*]] = arith.addi %[[EXT2]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK-DAG: %[[EXT2_CMP:.*]] = arith.cmpi slt, %[[EXT2_P1]], %[[NUM_BUFS2]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK-DAG: %[[EXT2_NEXT:.*]] = arith.select %[[EXT2_CMP]], %[[EXT2_P1]], %[[ZERO]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS2_NEXT]]{{.*}} {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 4 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 2 : i32, loop.stage = 2 : i32, num = 0 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT2_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[A_VAL:.*]] = ttg.local_load %[[A_EXT]] token %[[A_TOK3]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[B:.*]] = "pointerize"(%[[A_VAL]]) {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[C_INS:.*]] = ttg.memdesc_subview %[[C]][%[[INS3_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[C_TOK:.*]] = ttg.async_copy_global_to_local %[[B]], %[[C_INS]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[C_TOK2:.*]] = ttg.async_commit_group %[[C_TOK]] {loop.cluster = 2 : i32, loop.stage = 2 : i32} + // CHECK: %[[C_TOK3:.*]] = ttg.async_wait %[[C_TOK2]] {loop.cluster = 0 : i32, loop.stage = 5 : i32, num = 0 : i32} + // CHECK: %[[C_EXT:.*]] = ttg.memdesc_subview %[[C]][%[[EXT3_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 5 : i32} + // CHECK: %[[C_VAL:.*]] = ttg.local_load %[[C_EXT]] token %[[C_TOK3]] {loop.cluster = 0 : i32, loop.stage = 5 : i32} + // CHECK: "use1"(%[[C_VAL]]) {loop.cluster = 0 : i32, loop.stage = 5 : i32} + // CHECK: scf.yield + // CHECK-DAG: ttg.local_dealloc %[[A]] + // CHECK-DAG: ttg.local_dealloc %[[C]] + // CHECK-DAG: ttg.async_wait {num = 0 : i32} + scf.for %iv = %lb to %ub step %step : index { + %a = tt.load %a_ptr_init {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A> + %b = "pointerize"(%a) {loop.cluster = 2 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> tensor<128x32x!tt.ptr, #A> + %c = tt.load %b {loop.cluster = 2 : i32, loop.stage = 2 : i32} : tensor<128x32x!tt.ptr, #A> + "use1"(%c){loop.cluster = 0 : i32, loop.stage = 5 : i32} : (tensor<128x32xf32, #A>) -> () + } {tt.scheduled_max_stage = 5 : i32} + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @unused_load +tt.func @unused_load(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> () { + // CHECK: scf.for + scf.for %iv = %lb to %ub step %step : index { + // CHECK: dummy + %a = tt.load %a_ptr_init {loop.cluster = 0 : i32, loop.stage = 1 : i32} : tensor<128x32x!tt.ptr, #A> + "dummy"() : () -> () + } {tt.scheduled_max_stage = 1 : i32} + tt.return +} +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, instrShape = [16, 16, 16]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @shmem_pipelining_mmav3 + // CHECK-DAG: %[[INIT:.*]] = arith.constant dense<0.000000e+00> + // CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 + // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 + // CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant 3 : i32 + // CHECK: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x128 + // CHECK: %[[B:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x128 + // CHECK: scf.for {{.*}} iter_args(%[[ACC:.*]] = %[[INIT]], %[[INS:.*]] = %[[MINUS_ONE]], %[[EXT:.*]] = %[[MINUS_ONE]]) + // CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[INS_CMP:.*]] = arith.cmpi slt, %[[INS_P1]], %[[NUM_BUFS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[INS_NEXT:.*]] = arith.select %[[INS_CMP]], %[[INS_P1]], %[[ZERO]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[EXT_CMP:.*]] = arith.cmpi slt, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[EXT_P1]], %[[ZERO]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[B_INS:.*]] = ttg.memdesc_subview %[[B]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[B_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[B_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[B_TOK2:.*]] = ttg.async_commit_group %[[B_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[B_TOK3:.*]] = ttg.async_wait %[[B_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} + // CHECK: %[[B_EXT:.*]] = ttg.memdesc_subview %[[B]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: ttng.warp_group_dot %[[A_EXT]], %[[B_EXT]], %{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: scf.yield {{.*}}, %[[INS_NEXT]], %[[EXT_NEXT]] + // CHECK-DAG: ttg.local_dealloc %[[A]] + // CHECK-DAG: ttg.local_dealloc %[[B]] + // CHECK-DAG: ttg.async_wait {num = 0 : i32} + tt.func public @shmem_pipelining_mmav3(%lb : index, %ub : index, %step : index, + %A_ptr: tensor<128x128x!tt.ptr, #blocked1>, + %B_ptr: tensor<128x128x!tt.ptr, #blocked1>) -> tensor<128x128xf16, #mma> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %res = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> (tensor<128x128xf32, #mma>) : index { + %A = tt.load %A_ptr {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr, #blocked1> + %B = tt.load %B_ptr {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.local_alloc %A {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma> + scf.yield %acc_res : tensor<128x128xf32, #mma> + } {tt.scheduled_max_stage = 2 : i32} + %res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + tt.return %res_f16 : tensor<128x128xf16, #mma> + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, instrShape = [16, 16, 16]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // The combination of blocked and shared layouts for operand B would result in cp.async with less than 4 bytes size. + // We can't pipeline that using shared memory buffer. + // CHECK-LABEL: @no_shmem_pipelining_incompat_layout + // CHECK-DAG: %[[INIT:.*]] = arith.constant dense<0.000000e+00> + // CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 + // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 + // CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant 3 : i32 + // CHECK: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x128 + // CHECK: scf.for {{.*}} iter_args(%[[ACC:.*]] = %[[INIT]], %[[INS:.*]] = %[[MINUS_ONE]], %[[EXT:.*]] = %[[MINUS_ONE]]) + // CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[INS_CMP:.*]] = arith.cmpi slt, %[[INS_P1]], %[[NUM_BUFS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[INS_NEXT:.*]] = arith.select %[[INS_CMP]], %[[INS_P1]], %[[ZERO]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[EXT_CMP:.*]] = arith.cmpi slt, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[EXT_P1]], %[[ZERO]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[B:.*]] = tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[B_SH:.*]] = ttg.local_alloc %[[B]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: ttng.warp_group_dot %[[A_EXT]], %[[B_SH]], %{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: scf.yield {{.*}}, %[[INS_NEXT]], %[[EXT_NEXT]] + // CHECK-DAG: ttg.local_dealloc %[[A]] + // CHECK-DAG: ttg.async_wait {num = 0 : i32} + tt.func public @no_shmem_pipelining_incompat_layout( + %lb : index, %ub : index, %step : index, + %A_ptr: tensor<128x128x!tt.ptr, #blocked1>, + %B_ptr: tensor<128x128x!tt.ptr, #blocked1>) -> tensor<128x128xf32, #mma> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %res = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> (tensor<128x128xf32, #mma>) : index { + %A = tt.load %A_ptr {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr, #blocked1> + %B = tt.load %B_ptr {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.local_alloc %A {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory> + %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory> -> tensor<128x128xf32, #mma> + scf.yield %acc_res : tensor<128x128xf32, #mma> + } {tt.scheduled_max_stage = 2 : i32} + tt.return %res : tensor<128x128xf32, #mma> + } +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, instrShape = [16, 16, 16]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // non-zero "other" value is used in the load, while cp.async does not support it. + // We can't feed the shared memory values directly to mma, we need other values being filled in the registers. + // CHECK-LABEL: @no_shmem_pipelining_other_used + // CHECK-DAG: %[[INIT:.*]] = arith.constant dense<0.000000e+00> + // CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 + // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 + // CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant {{.*}} 2 : i32 + // CHECK: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x128 + // CHECK: %[[B:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x128 + // CHECK: scf.for {{.*}} iter_args(%[[ACC:[^,]*]] = %[[INIT]], %[[INS:[^,]*]] = %[[MINUS_ONE]], %[[EXT:[^,]*]] = %[[MINUS_ONE]]) + // CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[INS_CMP:.*]] = arith.cmpi slt, %[[INS_P1]], %[[NUM_BUFS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[INS_NEXT:.*]] = arith.select %[[INS_CMP]], %[[INS_P1]], %[[ZERO]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[EXT_CMP:.*]] = arith.cmpi slt, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[EXT_P1]], %[[ZERO]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[A_LOAD:.*]] = ttg.local_load %[[A_EXT]] {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[A_MASKED:.*]] = arith.select {{.*}}, %[[A_LOAD]], {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[B_INS:.*]] = ttg.memdesc_subview %[[B]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[B_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[B_INS]] {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[B_TOK2:.*]] = ttg.async_commit_group %[[B_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[B_TOK3:.*]] = ttg.async_wait %[[B_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} + // CHECK: %[[B_EXT:.*]] = ttg.memdesc_subview %[[B]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[B_LOAD:.*]] = ttg.local_load %[[B_EXT]] {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[B_MASKED:.*]] = arith.select {{.*}}, %[[B_LOAD]], {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[A_SH:.*]] = ttg.local_alloc %[[A_MASKED]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[B_SH:.*]] = ttg.local_alloc %[[B_MASKED]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: ttng.warp_group_dot %[[A_SH]], %[[B_SH]], %{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: scf.yield {{.*}}, %[[INS_NEXT]], %[[EXT_NEXT]] + // CHECK-DAG: ttg.local_dealloc %[[A]] + // CHECK-DAG: ttg.local_dealloc %[[B]] + // CHECK-DAG: ttg.async_wait {num = 0 : i32} + tt.func public @no_shmem_pipelining_other_used( + %lb : index, %ub : index, %step : index, + %A_ptr: tensor<128x128x!tt.ptr, #blocked1>, + %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, + %mask: tensor<128x128xi1, #blocked1>, + %other: tensor<128x128xf16, #blocked1>) -> tensor<128x128xf16, #mma> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %res = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> (tensor<128x128xf32, #mma>) : index { + %A = tt.load %A_ptr, %mask, %other {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr, #blocked1> + %B = tt.load %B_ptr, %mask, %other {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.local_alloc %A {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma> + scf.yield %acc_res : tensor<128x128xf32, #mma> + } {tt.scheduled_max_stage = 2 : i32} + %res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + tt.return %res_f16 : tensor<128x128xf16, #mma> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#tmem = #ttng.tensor_memory_encoding + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @shmem_pipelining_mmav5 + // CHECK-DAG: %[[INIT:.*]] = arith.constant dense<0.000000e+00> + // CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 + // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 + // CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant 3 : i32 + // CHECK: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x128 + // CHECK: %[[B:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x128 + // CHECK: scf.for {{.*}} iter_args(%[[ACC:.*]] = %[[INIT]], %[[INS:.*]] = %[[MINUS_ONE]], %[[EXT:.*]] = %[[MINUS_ONE]]) + // CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[INS_CMP:.*]] = arith.cmpi slt, %[[INS_P1]], %[[NUM_BUFS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[INS_NEXT:.*]] = arith.select %[[INS_CMP]], %[[INS_P1]], %[[ZERO]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + // CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[EXT_CMP:.*]] = arith.cmpi slt, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[EXT_P1]], %[[ZERO]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32 + // CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[A_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK2:.*]] = ttg.async_commit_group %[[A_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[A_TOK3:.*]] = ttg.async_wait %[[A_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} + // CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: %[[B_INS:.*]] = ttg.memdesc_subview %[[B]][%[[INS_NEXT]]{{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[B_TOK:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[B_INS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[B_TOK2:.*]] = ttg.async_commit_group %[[B_TOK]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} + // CHECK: %[[B_TOK3:.*]] = ttg.async_wait %[[B_TOK2]] {loop.cluster = 0 : i32, loop.stage = 2 : i32, num = 0 : i32} + // CHECK: %[[B_EXT:.*]] = ttg.memdesc_subview %[[B]][%[[EXT_NEXT]]{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: ttng.tc_gen5_mma %[[A_EXT]], %[[B_EXT]], %{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + // CHECK: scf.yield {{.*}}, %[[INS_NEXT]], %[[EXT_NEXT]] + // CHECK-DAG: ttg.local_dealloc %[[A]] + // CHECK-DAG: ttg.local_dealloc %[[B]] + // CHECK-DAG: ttg.async_wait {num = 0 : i32} + tt.func public @shmem_pipelining_mmav5(%lb : index, %ub : index, %step : index, + %A_ptr: tensor<128x128x!tt.ptr, #blocked1>, + %B_ptr: tensor<128x128x!tt.ptr, #blocked1>) -> tensor<128x128xf16, #blocked> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %res = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked>) : index { + %A = tt.load %A_ptr {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr, #blocked1> + %B = tt.load %B_ptr {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.local_alloc %A {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %acc_tm = ttng.tmem_alloc %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked> + scf.yield %acc_res : tensor<128x128xf32, #blocked> + } {tt.scheduled_max_stage = 2 : i32} + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + tt.return %res_f16 : tensor<128x128xf16, #blocked> + } +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @tma_load_lowering +// CHECK-DAG: %[[TRUE:.*]] = arith.constant {{.*}} true +// CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 : i32 +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant {{.*}} 2 : i32 +// CHECK-DAG: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x128x32 +// CHECK-DAG: %[[BARRIER:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 +// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ZERO]]] +// CHECK: ttng.init_barrier %[[BAR1_VIEW]], 1 +// CHECK: %[[BAR2_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ONE]]] +// CHECK: ttng.init_barrier %[[BAR2_VIEW]], 1 +// CHECK: scf.for {{.*}} iter_args(%[[INS:.*]] = %[[MINUS_ONE]], %[[EXT:.*]] = %[[MINUS_ONE]], %[[PHASE:.*]] = %[[ZERO]]) +// CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[INS_CMP:.*]] = arith.cmpi slt, %[[INS_P1]], %[[NUM_BUFS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[INS_NEXT:.*]] = arith.select %[[INS_CMP]], %[[INS_P1]], %[[ZERO]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[EXT_CMP:.*]] = arith.cmpi slt, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[EXT_P1]], %[[ZERO]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[PHASE_XOR:.*]] = arith.xori %[[PHASE]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[PHASE_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[PHASE]], %[[PHASE_XOR]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[BAR_INS:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[INS_NEXT]]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: ttng.barrier_expect %[[BAR_INS]], 8192 {loop.cluster = 2 : i32, loop.stage = 0 : i32}, %[[TRUE]] +// CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]], %[[ZERO]], %[[ZERO]]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[TMA_PTR:.*]] = ttng.tensor_desc_to_tma_ptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: ttng.async_tma_copy_global_to_local %[[TMA_PTR]][{{.*}}] %[[A_INS]], %[[BAR_INS]], %[[TRUE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[BAR_EXT:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[EXT_NEXT]]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: ttng.wait_barrier %[[BAR_EXT]], %[[PHASE_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]], %[[ZERO]], %[[ZERO]]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[A_LOAD:.*]] = ttg.local_load %[[A_EXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: "use"(%[[A_LOAD]]) {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: scf.yield %[[INS_NEXT]], %[[EXT_NEXT]], %[[PHASE_NEXT]] : i32, i32, i32 +// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ZERO]]] +// CHECK: ttng.inval_barrier %[[BAR1_VIEW]] +// CHECK: %[[BAR2_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ONE]]] +// CHECK: ttng.inval_barrier %[[BAR2_VIEW]] +// CHECK: ttg.local_dealloc %[[BARRIER]] +// CHECK: ttg.local_dealloc %[[A]] +tt.func @tma_load_lowering(%lb : index, %ub : index, %step : index, + %desc : !tt.tensordesc>, + %offs : i32) -> () { + scf.for %iv = %lb to %ub step %step : index { + %a = tt.experimental_descriptor_load %desc[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x32xf16, #A> + "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () + } {tt.scheduled_max_stage = 2 : i32} + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#offsets = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @tma_gather_lowering +// CHECK-DAG: %[[TRUE:.*]] = arith.constant {{.*}} true +// CHECK-DAG: %[[MINUS_ONE:.*]] = arith.constant -1 : i32 +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[NUM_BUFS:.*]] = arith.constant {{.*}} 2 : i32 +// CHECK-DAG: %[[A:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x32x128 +// CHECK-DAG: %[[BARRIER:.*]] = ttg.local_alloc : () -> !ttg.memdesc<2xi64 +// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ZERO]]] +// CHECK: ttng.init_barrier %[[BAR1_VIEW]], 1 +// CHECK: %[[BAR2_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ONE]]] +// CHECK: ttng.init_barrier %[[BAR2_VIEW]], 1 +// CHECK: scf.for {{.*}} iter_args(%[[INS:.*]] = %[[MINUS_ONE]], %[[EXT:.*]] = %[[MINUS_ONE]], %[[PHASE:.*]] = %[[ZERO]]) +// CHECK: %[[INS_P1:.*]] = arith.addi %[[INS]], %[[ONE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[INS_CMP:.*]] = arith.cmpi slt, %[[INS_P1]], %[[NUM_BUFS]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[INS_NEXT:.*]] = arith.select %[[INS_CMP]], %[[INS_P1]], %[[ZERO]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[EXT_P1:.*]] = arith.addi %[[EXT]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[EXT_CMP:.*]] = arith.cmpi slt, %[[EXT_P1]], %[[NUM_BUFS]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[EXT_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[EXT_P1]], %[[ZERO]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[PHASE_XOR:.*]] = arith.xori %[[PHASE]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[PHASE_NEXT:.*]] = arith.select %[[EXT_CMP]], %[[PHASE]], %[[PHASE_XOR]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[BAR_INS:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[INS_NEXT]]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: ttng.barrier_expect %[[BAR_INS]], 16384 {loop.cluster = 2 : i32, loop.stage = 0 : i32}, %[[TRUE]] +// CHECK: %[[A_INS:.*]] = ttg.memdesc_subview %[[A]][%[[INS_NEXT]], %[[ZERO]], %[[ZERO]]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[TMA_PTR:.*]] = ttng.tensor_desc_to_tma_ptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: ttng.async_tma_gather %[[TMA_PTR]][{{.*}}] %[[A_INS]], %[[BAR_INS]], %[[TRUE]] {loop.cluster = 2 : i32, loop.stage = 0 : i32} +// CHECK: %[[BAR_EXT:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[EXT_NEXT]]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: ttng.wait_barrier %[[BAR_EXT]], %[[PHASE_NEXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[A_EXT:.*]] = ttg.memdesc_subview %[[A]][%[[EXT_NEXT]], %[[ZERO]], %[[ZERO]]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: %[[A_LOAD:.*]] = ttg.local_load %[[A_EXT]] {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: "use"(%[[A_LOAD]]) {loop.cluster = 0 : i32, loop.stage = 2 : i32} +// CHECK: scf.yield %[[INS_NEXT]], %[[EXT_NEXT]], %[[PHASE_NEXT]] : i32, i32, i32 +// CHECK: %[[BAR1_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ZERO]]] +// CHECK: ttng.inval_barrier %[[BAR1_VIEW]] +// CHECK: %[[BAR2_VIEW:.*]] = ttg.memdesc_subview %[[BARRIER]][%[[ONE]]] +// CHECK: ttng.inval_barrier %[[BAR2_VIEW]] +// CHECK-DAG: ttg.local_dealloc %[[BARRIER]] +// CHECK-DAG: ttg.local_dealloc %[[A]] +tt.func @tma_gather_lowering(%lb : index, %ub : index, %step : index, + %desc : !tt.tensordesc>, + %x : tensor<32xi32, #offsets>, + %y : i32) -> () { + scf.for %iv = %lb to %ub step %step : index { + %a = tt.experimental_descriptor_gather %desc[%x, %y] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : (!tt.tensordesc>, tensor<32xi32, #offsets>, i32) -> tensor<32x128xf32, #A> + "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<32x128xf32, #A>) -> () + } {tt.scheduled_max_stage = 2 : i32} + tt.return +} +} + +// ----- + +#A = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @tma_reuse_barrier +// CHECK: scf.for +// CHECK: ttng.barrier_expect {{.*}}, 16384 +// CHECK: ttng.async_tma_copy_global_to_local +// CHECK-NOT: ttng.wait_barrier +// CHECK: ttng.async_tma_copy_global_to_local +// CHECK: ttng.wait_barrier +// CHECK: "use1" +// CHECK: "use2" +// CHECK: ttng.barrier_expect {{.*}}, 8192 +// CHECK: ttng.async_tma_copy_global_to_local +// CHECK: ttng.wait_barrier +// CHECK: "use3" +tt.func @tma_reuse_barrier(%lb : index, %ub : index, %step : index, + %descA : !tt.tensordesc>, + %descB : !tt.tensordesc>, + %descC : !tt.tensordesc>, + %offs : i32) -> () { + scf.for %iv = %lb to %ub step %step : index { + %a = tt.experimental_descriptor_load %descA[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x32xf16, #A> + %b = tt.experimental_descriptor_load %descB[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x32xf16, #A> + "use1"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () + "use2"(%b) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () + %c = tt.experimental_descriptor_load %descC[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x32xf16, #A> + "use3"(%c) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () + } {tt.scheduled_max_stage = 2 : i32} + tt.return +} +} + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, instrShape = [16, 16, 16]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @tma_pipelining_mmav3 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x128x128 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x128x128 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3xi64 + // CHECK: scf.for + // CHECK: ttng.barrier_expect + // CHECK: ttng.async_tma_copy_global_to_local + // CHECK-NOT: ttng.wait_barrier + // CHECK: ttng.async_tma_copy_global_to_local + // CHECK: ttng.wait_barrier + // CHECK-NOT: ttg.local_alloc + // CHECK: ttng.warp_group_dot + tt.func public @tma_pipelining_mmav3(%lb : index, %ub : index, %step : index, + %descA : !tt.tensordesc>, + %descB : !tt.tensordesc>, + %offs : i32) -> tensor<128x128xf16, #mma> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %res = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> (tensor<128x128xf32, #mma>) : index { + %A = tt.experimental_descriptor_load %descA[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x128xf16, #blocked1> + %A_sh = ttg.local_alloc %A {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %B = tt.experimental_descriptor_load %descB[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x128xf16, #blocked1> + %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> + %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma> + scf.yield %acc_res : tensor<128x128xf32, #mma> + } {tt.scheduled_max_stage = 2 : i32} + %res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + tt.return %res_f16 : tensor<128x128xf16, #mma> + } +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @tensor_descriptor_lowering + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 + // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : i32 + // CHECK-DAG: %[[NUM_STAGES:.*]] = arith.constant 2 : i32 + // CHECK-DAG: %[[_128:.*]] = arith.constant{{.*}} 128 : i32 + // CHECK: %[[GLOBAL_ALLOC:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 256 : i32} : !tt.ptr + // CHECK: scf.for {{.*}} iter_args(%[[IDX:.*]] = %[[ZERO]]) + // CHECK: %[[OFFS:.*]] = arith.muli %[[IDX]], %[[_128]] {loop.cluster = 0 : i32, loop.stage = 1 : i32} + // CHECK: %[[DESC_PTR:.*]] = tt.addptr %[[GLOBAL_ALLOC]], %[[OFFS]] {loop.cluster = 0 : i32, loop.stage = 1 : i32} + // CHECK: tt.experimental_tensormap_create %[[DESC_PTR]]{{.*}} loop.cluster = 0 : i32, loop.stage = 1 : i32 + // CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[DESC_PTR]] {loop.cluster = 0 : i32, loop.stage = 1 : i32} + // CHECK: %[[DESC:.*]] = tt.reinterpret_tensor_descriptor %[[DESC_PTR]] {loop.cluster = 0 : i32, loop.stage = 1 : i32} + // CHECK: %[[IDX_P1:.*]] = arith.addi %[[IDX]], %[[ONE]] {loop.cluster = 0 : i32, loop.stage = 1 : i32} + // CHECK: %[[IDX_CMP:.*]] = arith.cmpi slt, %[[IDX_P1]], %[[NUM_STAGES]] {loop.cluster = 0 : i32, loop.stage = 1 : i32} + // CHECK: %[[IDX_NEXT:.*]] = arith.select %[[IDX_CMP]], %[[IDX_P1]], %[[ZERO]] {loop.cluster = 0 : i32, loop.stage = 1 : i32} + // CHECK: "use"(%[[DESC]]) {loop.cluster = 0 : i32, loop.stage = 1 : i32} + tt.func @tensor_descriptor_lowering( + %lb : index, %ub : index, %step : index, + %A: !tt.ptr, + %shape_x: i32, + %shape_y: i32, + %strides_x: i64, + %strides_y: i64) -> (){ + scf.for %iv = %lb to %ub step %step : index { + %desc = tt.make_tensor_descriptor %A, [%shape_x, %shape_y], [%strides_x, %strides_y] {loop.cluster = 0 : i32, loop.stage = 1 : i32} : , > + "use"(%desc) {loop.cluster = 0 : i32, loop.stage = 1 : i32} : (!tt.tensordesc>) -> () + } {tt.scheduled_max_stage = 1 : i32} + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 8, 4, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}> +#tmem = #ttng.tensor_memory_encoding +#tmem_scales = #ttng.tensor_memory_scales_encoding<> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @pipelining_mmav5_scaled + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x128x128xf8E5M2 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x128x128xf8E5M2 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x1x2x32x4x4xi8 + // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x1x2x32x4x4xi8 + tt.func public @pipelining_mmav5_scaled(%lb : index, %ub : index, %step : index, + %A_ptr: tensor<128x128x!tt.ptr, #blocked1>, + %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, + %A_sc_ptr: tensor<1x2x32x4x4x!tt.ptr, #blocked2>, + %B_sc_ptr: tensor<1x2x32x4x4x!tt.ptr, #blocked2>) -> tensor<128x128xf32, #blocked> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %res = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked>) : index { + %A = tt.load %A_ptr {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr, #blocked1> + %B = tt.load %B_ptr {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.local_alloc %A {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf8E5M2, #blocked1>) -> !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory> + %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf8E5M2, #blocked1>) -> !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory> + + %A_sc = tt.load %A_sc_ptr {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<1x2x32x4x4x!tt.ptr, #blocked2> + %A_sc_sh = ttg.local_alloc %A_sc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem> + + %B_sc = tt.load %B_sc_ptr {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<1x2x32x4x4x!tt.ptr, #blocked2> + %B_sc_sh = ttg.local_alloc %B_sc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem> + + %acc_tm = ttng.tmem_alloc %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> + ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm, %A_sc_sh, %B_sc_sh, %true, %true lhs = e5m2 rhs = e5m2 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (!ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked> + scf.yield %acc_res : tensor<128x128xf32, #blocked> + } {tt.scheduled_max_stage = 2 : i32} + tt.return %res : tensor<128x128xf32, #blocked> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/pipeline-schedule-loop.mlir b/third_party/enflame/include/triton/test/TritonGPU/pipeline-schedule-loop.mlir new file mode 100644 index 000000000..ad46c3871 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/pipeline-schedule-loop.mlir @@ -0,0 +1,353 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-test-pipeline-schedule-loop -canonicalize | FileCheck %s + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 16}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { +// CHECK-LABEL: @one_dep +tt.func @one_dep(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> tensor<128x32xf16, #A> { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %loop = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + scf.yield %res : tensor<128x32xf16, #A> + } + // CHECK: tt.scheduled_max_stage + tt.return %loop#0 : tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @parallel_deps +tt.func @parallel_deps(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>, + %b_ptr_init : tensor<128x32x!tt.ptr, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc_a = %init, %acc_b = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %b = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_a = arith.addf %acc_a, %a : tensor<128x32xf16, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_b = arith.addf %acc_b, %b : tensor<128x32xf16, #A> + scf.yield %res_a, %res_b : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @parallel_deps_uneven1 +tt.func @parallel_deps_uneven1(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>, + %b_ptr_init : tensor<128x32x!tt.ptr, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc_a = %init, %acc_b = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 1 : i32, loop.stage = 1 : i32} + %b = tt.load %a_ptr_init {tt.latency = 1 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_a = arith.addf %acc_a, %a : tensor<128x32xf16, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_b = arith.addf %acc_b, %b : tensor<128x32xf16, #A> + scf.yield %res_a, %res_b : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @parallel_deps_uneven2 +tt.func @parallel_deps_uneven2(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>, + %b_ptr_init : tensor<128x32x!tt.ptr, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc_a = %init, %acc_b = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 1 : i32, loop.stage = 1 : i32} + %a = tt.load %a_ptr_init {tt.latency = 1 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %b = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_a = arith.addf %acc_a, %a : tensor<128x32xf16, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res_b = arith.addf %acc_b, %b : tensor<128x32xf16, #A> + scf.yield %res_a, %res_b : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @direct_deps +tt.func @direct_deps(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> tensor<128x32xf16, #A> { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init, %a_ptr = %a_ptr_init) -> (tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr, #A>) { + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a_ptr_next = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #A>, tensor<128x32xi32, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_next {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + scf.yield %res, %a_ptr_next : tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr, #A> + } + tt.return %loop#0 : tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @dist1_deps +tt.func @dist1_deps(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> tensor<128x32xf16, #A> { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init, %a_ptr = %a_ptr_init) -> (tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %a_ptr_next = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #A>, tensor<128x32xi32, #A> + scf.yield %res, %a_ptr_next : tensor<128x32xf16, #A>, tensor<128x32x!tt.ptr, #A> + } + tt.return %loop#0 : tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @prologue_if +tt.func @prologue_if(%lb : index, %ub : index, %step : index, %cnd : i1, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> tensor<128x32xf16, #A> { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #A> + %loop = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128x32xf16, #A>) { + // CHECK: scf.if + // CHECK: {loop.cluster = 0 : i32, loop.stage = 0 : i32} + %a_ptr = scf.if %cnd -> tensor<128x32x!tt.ptr, #A> { + %a_ptr_ret = tt.addptr %a_ptr_init, %a_off : tensor<128x32x!tt.ptr, #A>, tensor<128x32xi32, #A> + scf.yield %a_ptr_ret : tensor<128x32x!tt.ptr, #A> + } else { + scf.yield %a_ptr_init : tensor<128x32x!tt.ptr, #A> + } + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + scf.yield %res : tensor<128x32xf16, #A> + } + tt.return %loop#0 : tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @independent_epilogue_if +tt.func @independent_epilogue_if(%lb : index, %ub : index, %step : index, %cnd : i1, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> tensor<128x32xf16, #A> { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #A> + %loop = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + // CHECK: scf.if + // CHECK: {loop.cluster = 4 : i32, loop.stage = 2 : i32} + scf.if %cnd { + tt.store %a_ptr_init, %init : tensor<128x32x!tt.ptr, #A> + } + scf.yield %res : tensor<128x32xf16, #A> + } + tt.return %loop#0 : tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @independent_last_stage +tt.func @independent_last_stage(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init, %acc2 = %init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 0 : i32} + %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + // CHECK: arith.addf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %res2 = arith.addf %acc2, %init : tensor<128x32xf16, #A> + scf.yield %res, %res2 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + } + tt.return %loop#0, %loop#1 : tensor<128x32xf16, #A>, tensor<128x32xf16, #A> +} + +// CHECK-LABEL: @basic_pipeline +tt.func @basic_pipeline(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #AL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %b_ = tt.load %b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr, #BL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @unpipelined_load +tt.func @unpipelined_load(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #AL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // load below should be in the same stage as tt.dot (not pipelined) + // CHECK: tt.load {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // addptr below should be scheduled to the last stage + // CHECK: tt.addptr {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @epilogue_if +tt.func @epilogue_if(%lb : index, %ub : index, %step : index, %cnd : i1, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>, + %c_ptr_store : tensor<128x128x!tt.ptr, #C>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #AL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %b_ = tt.load %b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr, #BL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + // CHECK: scf.if + // CHECK: {loop.cluster = 4 : i32, loop.stage = 2 : i32} + scf.if %cnd { + tt.store %c_ptr_store, %c : tensor<128x128x!tt.ptr, #C> + } + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @intermediate_use +tt.func @intermediate_use(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + %c2 = arith.constant dense<2.00> : tensor<32x128xf16, #BL> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %a_ = tt.load %a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #AL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32} + %b_ = tt.load %b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr, #BL> + // CHECK: arith.mulf {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b_2 = arith.mulf %b_ , %c2 : tensor<32x128xf16, #BL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %b = ttg.convert_layout %b_2 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32} + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#2: tensor<128x128xf32, #C> +} + +// CHECK-LABEL: @indirect_load +tt.func @indirect_load(%lb : index, %ub : index, %step : index, + %a_ind_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %a_ptr_init : tensor<128x32x!tt.ptr, #AL>, + %b_ptr_init : tensor<32x128x!tt.ptr, #BL>) -> tensor<128x128xf32, #C> { + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_ind_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ind_ptr = %a_ind_ptr_init, %a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { + // CHECK: tt.load {{.*}} {loop.cluster = 4 : i32, loop.stage = 0 : i32} + %a_off = tt.load %a_ind_ptr {tt.latency = 1 : i32} : tensor<128x32x!tt.ptr, #AL> + %next_a_ind_ptr = tt.addptr %a_ind_ptr, %a_ind_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + // addptr below scheduled by scheduleDependencies to the same stage as tt.load that is using it + // CHECK: tt.addptr {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %a_ = tt.load %next_a_ptr {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #AL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32} + %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A> + // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 1 : i32} + %b_ = tt.load %next_b_ptr {tt.latency = 2 : i32} : tensor<32x128x!tt.ptr, #BL> + // CHECK: ttg.convert_layout {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32} + %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> + + // CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32} + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + scf.yield %next_a_ind_ptr, %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> + } + tt.return %loop#3: tensor<128x128xf32, #C> +} + +// Verify that we don't schedule/pipeline loops with gpu.barrier +// CHECK-LABEL: @gpu_barrier +tt.func @gpu_barrier(%lb : index, %ub : index, %step : index, + %a_ptr_init : tensor<128x32x!tt.ptr, #A>) -> tensor<128x32xf16, #A> { + %init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %loop = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128x32xf16, #A>) { + // CHECK-NOT: loop.cluster + %a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr, #A> + %res = arith.addf %acc, %a : tensor<128x32xf16, #A> + gpu.barrier + scf.yield %res : tensor<128x32xf16, #A> + } + tt.return %loop#0 : tensor<128x32xf16, #A> +} +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/prefetch.mlir b/third_party/enflame/include/triton/test/TritonGPU/prefetch.mlir new file mode 100644 index 000000000..136b7f751 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/prefetch.mlir @@ -0,0 +1,368 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-prefetch -canonicalize | FileCheck %s + +// 4 warps +// matmul: 128x32 @ 32x128 -> 128x128 +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_OP = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_OP = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#smem = #ttg.shared_memory + +// CHECK: tt.func @matmul_loop_mixed +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32 +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] +// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] +// CHECK-DAG: %[[A_REM:.*]] = ttg.local_load %[[A_REM_SMEM]] +// CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] +// CHECK-DAG: %[[B_REM:.*]] = ttg.local_load %[[B_REM_SMEM]] +// CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] +module attributes { "ttg.num-warps" = 4 : i32 } { +tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf8E5M2, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %a_init = ttg.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> + %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b_init = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x32xf8E5M2, #A, #smem> -> tensor<128x32xf8E5M2, #A_OP> + %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<32x128xf16, #B, #smem> -> tensor<32x128xf16, #B_OP> + %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %next_a = ttg.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> + %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %next_b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> + + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C> + } + tt.return %loop#4 : tensor<128x128xf32, #C> +} +} // end module + +// 4 warps +// matmul: 128x16 @ 16x128 -> 128x128 +// CHECK: tt.func @matmul_loop_mixed +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK: tt.dot %[[a0_prefetch]], %[[b0_prefetch]], {{.*}} +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] +module attributes { "ttg.num-warps" = 4 : i32 } { +tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x16x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<16x128x!tt.ptr, #BL> + + %a_mask = arith.constant dense : tensor<128x16xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x16xf8E5M2, #AL> + %b_mask = arith.constant dense : tensor<16x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<16x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x16xi32, #AL> + %b_off = arith.constant dense<4> : tensor<16x128xi32, #BL> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x16x!tt.ptr, #AL> + %a_init = ttg.local_alloc %a_ : (tensor<128x16xf8E5M2, #AL>) -> !ttg.memdesc<128x16xf8E5M2, #A, #smem> + %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<16x128x!tt.ptr, #BL> + %b_init = ttg.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !ttg.memdesc<16x128xf16, #B, #smem> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !ttg.memdesc<128x16xf8E5M2, #A, #smem>, !ttg.memdesc<16x128xf16, #B, #smem>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x16xf8E5M2, #A, #smem> -> tensor<128x16xf8E5M2, #A_OP> + %a_op = tt.fp_to_fp %a_op_ : tensor<128x16xf8E5M2, #A_OP> -> tensor<128x16xf16, #A_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<16x128xf16, #B, #smem> -> tensor<16x128xf16, #B_OP> + %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x16xf16, #A_OP> * tensor<16x128xf16, #B_OP> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x16x!tt.ptr, #AL>, tensor<128x16xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<16x128x!tt.ptr, #BL>, tensor<16x128xi32, #BL> + %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x16x!tt.ptr, #AL> + %next_a = ttg.local_alloc %next_a_ : (tensor<128x16xf8E5M2, #AL>) -> !ttg.memdesc<128x16xf8E5M2, #A, #smem> + %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<16x128x!tt.ptr, #BL> + %next_b = ttg.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !ttg.memdesc<16x128xf16, #B, #smem> + + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !ttg.memdesc<128x16xf8E5M2, #A, #smem>, !ttg.memdesc<16x128xf16, #B, #smem>, tensor<128x128xf32, #C> + } + tt.return %loop#4 : tensor<128x128xf32, #C> +} +} // end module + +#AL_3D = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [2, 4, 4], warpsPerCTA = [1, 4, 1], order = [2, 0, 1]}> +#BL_3D = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [2, 4, 4], warpsPerCTA = [1, 4, 1], order = [2, 0, 1]}> +#A_3D = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [2, 0, 1]}> +#B_3D = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [2, 0, 1]}> +#C_3D = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 4, 1]}> +#A_OP_3D = #ttg.dot_op<{opIdx = 0, parent = #C_3D, kWidth = 2}> +#B_OP_3D = #ttg.dot_op<{opIdx = 1, parent = #C_3D, kWidth = 2}> + +// matmul: 8x128x16 @ 8x16x128 -> 8x128x128 +// CHECK: tt.func @matmul_3D_loop_mixed +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK: tt.dot %[[a0_prefetch]], %[[b0_prefetch]], {{.*}} +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] +module attributes { "ttg.num-warps" = 4 : i32 } { +tt.func @matmul_3D_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<8x128x128xf32, #C_3D>{ + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<8x128x16x!tt.ptr, #AL_3D> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<8x16x128x!tt.ptr, #BL_3D> + + %a_mask = arith.constant dense : tensor<8x128x16xi1, #AL_3D> + %a_other = arith.constant dense<0.00e+00> : tensor<8x128x16xf8E5M2, #AL_3D> + %b_mask = arith.constant dense : tensor<8x16x128xi1, #BL_3D> + %b_other = arith.constant dense<0.00e+00> : tensor<8x16x128xf16, #BL_3D> + %c_init = arith.constant dense<0.00e+00> : tensor<8x128x128xf32, #C_3D> + + %a_off = arith.constant dense<4> : tensor<8x128x16xi32, #AL_3D> + %b_off = arith.constant dense<4> : tensor<8x16x128xi32, #BL_3D> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<8x128x16x!tt.ptr, #AL_3D> + %a_init = ttg.local_alloc %a_ : (tensor<8x128x16xf8E5M2, #AL_3D>) -> !ttg.memdesc<8x128x16xf8E5M2, #A_3D, #smem> + %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<8x16x128x!tt.ptr, #BL_3D> + %b_init = ttg.local_alloc %b_ : (tensor<8x16x128xf16, #BL_3D>) -> !ttg.memdesc<8x16x128xf16, #B_3D, #smem> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<8x128x16x!tt.ptr, #AL_3D>, tensor<8x16x128x!tt.ptr, #BL_3D>, !ttg.memdesc<8x128x16xf8E5M2, #A_3D, #smem>, !ttg.memdesc<8x16x128xf16, #B_3D, #smem>, tensor<8x128x128xf32, #C_3D>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<8x128x16xf8E5M2, #A_3D, #smem> -> tensor<8x128x16xf8E5M2, #A_OP_3D> + %a_op = tt.fp_to_fp %a_op_ : tensor<8x128x16xf8E5M2, #A_OP_3D> -> tensor<8x128x16xf16, #A_OP_3D> + %b_op = ttg.local_load %b : !ttg.memdesc<8x16x128xf16, #B_3D, #smem> -> tensor<8x16x128xf16, #B_OP_3D> + %c = tt.dot %a_op, %b_op, %prev_c : tensor<8x128x16xf16, #A_OP_3D> * tensor<8x16x128xf16, #B_OP_3D> -> tensor<8x128x128xf32, #C_3D> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<8x128x16x!tt.ptr, #AL_3D>, tensor<8x128x16xi32, #AL_3D> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<8x16x128x!tt.ptr, #BL_3D>, tensor<8x16x128xi32, #BL_3D> + %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<8x128x16x!tt.ptr, #AL_3D> + %next_a = ttg.local_alloc %next_a_ : (tensor<8x128x16xf8E5M2, #AL_3D>) -> !ttg.memdesc<8x128x16xf8E5M2, #A_3D, #smem> + %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<8x16x128x!tt.ptr, #BL_3D> + %next_b = ttg.local_alloc %b_ : (tensor<8x16x128xf16, #BL_3D>) -> !ttg.memdesc<8x16x128xf16, #B_3D, #smem> + + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<8x128x16x!tt.ptr, #AL_3D>, tensor<8x16x128x!tt.ptr, #BL_3D>, !ttg.memdesc<8x128x16xf8E5M2, #A_3D, #smem>, !ttg.memdesc<8x16x128xf16, #B_3D, #smem>, tensor<8x128x128xf32, #C_3D> + } + tt.return %loop#4 : tensor<8x128x128xf32, #C_3D> +} +} // end module + +// matmul: 8x128x32 @ 8x32x128 -> 8x128x128 +// CHECK: tt.func @matmul_3D_loop_mixed2 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32 +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] +// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_a0]][%[[C0]], %[[C0]], %[[C16]]] +// CHECK-DAG: %[[A_REM:.*]] = ttg.local_load %[[A_REM_SMEM]] +// CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_b0]][%[[C0]], %[[C16]], %[[C0]]] +// CHECK-DAG: %[[B_REM:.*]] = ttg.local_load %[[B_REM_SMEM]] +// CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] +module attributes { "ttg.num-warps" = 4 : i32 } { +tt.func @matmul_3D_loop_mixed2(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<8x128x128xf32, #C_3D>{ + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<8x128x32x!tt.ptr, #AL_3D> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<8x32x128x!tt.ptr, #BL_3D> + + %a_mask = arith.constant dense : tensor<8x128x32xi1, #AL_3D> + %a_other = arith.constant dense<0.00e+00> : tensor<8x128x32xf8E5M2, #AL_3D> + %b_mask = arith.constant dense : tensor<8x32x128xi1, #BL_3D> + %b_other = arith.constant dense<0.00e+00> : tensor<8x32x128xf16, #BL_3D> + %c_init = arith.constant dense<0.00e+00> : tensor<8x128x128xf32, #C_3D> + + %a_off = arith.constant dense<4> : tensor<8x128x32xi32, #AL_3D> + %b_off = arith.constant dense<4> : tensor<8x32x128xi32, #BL_3D> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<8x128x32x!tt.ptr, #AL_3D> + %a_init = ttg.local_alloc %a_ : (tensor<8x128x32xf8E5M2, #AL_3D>) -> !ttg.memdesc<8x128x32xf8E5M2, #A_3D, #smem> + %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<8x32x128x!tt.ptr, #BL_3D> + %b_init = ttg.local_alloc %b_ : (tensor<8x32x128xf16, #BL_3D>) -> !ttg.memdesc<8x32x128xf16, #B_3D, #smem> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<8x128x32x!tt.ptr, #AL_3D>, tensor<8x32x128x!tt.ptr, #BL_3D>, !ttg.memdesc<8x128x32xf8E5M2, #A_3D, #smem>, !ttg.memdesc<8x32x128xf16, #B_3D, #smem>, tensor<8x128x128xf32, #C_3D>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<8x128x32xf8E5M2, #A_3D, #smem> -> tensor<8x128x32xf8E5M2, #A_OP_3D> + %a_op = tt.fp_to_fp %a_op_ : tensor<8x128x32xf8E5M2, #A_OP_3D> -> tensor<8x128x32xf16, #A_OP_3D> + %b_op = ttg.local_load %b : !ttg.memdesc<8x32x128xf16, #B_3D, #smem> -> tensor<8x32x128xf16, #B_OP_3D> + %c = tt.dot %a_op, %b_op, %prev_c : tensor<8x128x32xf16, #A_OP_3D> * tensor<8x32x128xf16, #B_OP_3D> -> tensor<8x128x128xf32, #C_3D> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<8x128x32x!tt.ptr, #AL_3D>, tensor<8x128x32xi32, #AL_3D> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<8x32x128x!tt.ptr, #BL_3D>, tensor<8x32x128xi32, #BL_3D> + %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<8x128x32x!tt.ptr, #AL_3D> + %next_a = ttg.local_alloc %next_a_ : (tensor<8x128x32xf8E5M2, #AL_3D>) -> !ttg.memdesc<8x128x32xf8E5M2, #A_3D, #smem> + %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<8x32x128x!tt.ptr, #BL_3D> + %next_b = ttg.local_alloc %b_ : (tensor<8x32x128xf16, #BL_3D>) -> !ttg.memdesc<8x32x128xf16, #B_3D, #smem> + + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<8x128x32x!tt.ptr, #AL_3D>, tensor<8x32x128x!tt.ptr, #BL_3D>, !ttg.memdesc<8x128x32xf8E5M2, #A_3D, #smem>, !ttg.memdesc<8x32x128xf16, #B_3D, #smem>, tensor<8x128x128xf32, #C_3D> + } + tt.return %loop#4 : tensor<8x128x128xf32, #C_3D> +} +} // end module + +// CHECK: tt.func @matmul_loop_yield_no_operand +// CHECK: scf.for +// CHECK: scf.if +// CHECK: tt.store +// CHECK-NOT: scf.yield +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:86", "ttg.threads-per-warp" = 32 : i32} { + tt.func @matmul_loop_yield_no_operand(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %c32_i32 = arith.constant 32 : i32 + %c31_i32 = arith.constant 31 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = arith.muli %arg9, %arg10 : i32 + %1 = arith.addi %arg8, %c31_i32 : i32 + %2 = arith.divsi %1, %c32_i32 : i32 + %3 = arith.addi %0, %c31_i32 : i32 + %4 = arith.divsi %3, %c32_i32 : i32 + %5 = arith.muli %1, %4 : i32 + %6 = tt.get_program_id x : i32 + %7 = tt.get_num_programs x : i32 + %8 = tt.splat %arg3 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + scf.for %arg11 = %6 to %5 step %7 : i32 { + %9 = arith.divsi %arg11, %4 : i32 + %10 = arith.remsi %9, %2 : i32 + %11 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> + %12 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> + %13 = ttg.convert_layout %12 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %14 = ttg.convert_layout %11 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %15 = tt.dot %13, %14, %cst, inputPrecision = tf32 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %16 = arith.cmpi sgt, %10, %c0_i32 : i32 + %17 = scf.if %16 -> (tensor<32x32xf32, #mma>) { + %21 = tt.dot %13, %14, %15, inputPrecision = tf32 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + scf.yield %21 : tensor<32x32xf32, #mma> + } else { + scf.yield %15 : tensor<32x32xf32, #mma> + } + %18 = tt.splat %arg5 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %19 = arith.truncf %17 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %20 = ttg.convert_layout %19 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked1> + tt.store %18, %20 : tensor<32x32x!tt.ptr, #blocked1> + } + tt.return + } +} + +// ----- + +#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = false}> +#A_OP = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_OP = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#smem = #ttg.shared_memory + +// CHECK: tt.func @matmul_loop_mixed_amd +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32 +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = ttg.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = ttg.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = ttg.local_load %[[B0_PREFETCH_SMEM]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] +// CHECK-DAG: %[[A_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] +// CHECK-DAG: %[[A_REM:.*]] = ttg.local_load %[[A_REM_SMEM]] +// CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = ttg.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] +// CHECK-DAG: %[[B_REM:.*]] = ttg.local_load %[[B_REM_SMEM]] +// CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = ttg.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = ttg.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = ttg.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] +module attributes { "ttg.num-warps" = 4 : i32 } { +tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf8E5M2, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %a_init = ttg.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> + %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b_init = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x32xf8E5M2, #A, #smem> -> tensor<128x32xf8E5M2, #A_OP> + %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<32x128xf16, #B, #smem> -> tensor<32x128xf16, #B_OP> + %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %next_a = ttg.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> + %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %next_b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> + + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C> + } + tt.return %loop#4 : tensor<128x128xf32, #C> +} +} // end module diff --git a/third_party/enflame/include/triton/test/TritonGPU/promote-lhs-to-tmem.mlir b/third_party/enflame/include/triton/test/TritonGPU/promote-lhs-to-tmem.mlir new file mode 100644 index 000000000..46907599d --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/promote-lhs-to-tmem.mlir @@ -0,0 +1,137 @@ +// RUN: env ENABLE_LHS_TO_TMEM=1 triton-opt %s -split-input-file -tritongpu-promote-lhs-to-tmem | FileCheck --dump-input-context=50 %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#tmem = #ttng.tensor_memory_encoding +#tmem_scales = #ttng.tensor_memory_scales_encoding<> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @promote_lhs + // CHECK: scf.for + // CHECK: %[[A:.+]] = tt.load + // CHECK: %[[A_TMEM:.+]] = ttng.tmem_alloc %[[A]] + // CHECK: ttng.tc_gen5_mma %[[A_TMEM]] + tt.func public @promote_lhs(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %arg3: i32) -> tensor<128x128xf16, #blocked1> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked1>) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> + scf.yield %acc_res : tensor<128x128xf32, #blocked1> + } + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked1> to tensor<128x128xf16, #blocked1> + tt.return %res_f16 : tensor<128x128xf16, #blocked1> + } + + // CHECK-LABEL: @promote_lhs_mxfp + // CHECK: scf.for + // CHECK: %[[A:.+]] = tt.load + // CHECK: %[[A_TMEM:.+]] = ttng.tmem_alloc %[[A]] + // CHECK: ttng.tc_gen5_mma_scaled %[[A_TMEM]] + tt.func public @promote_lhs_mxfp(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %arg3: i32, %a_scale: tensor<128x1xi8, #blocked2>, %b_scale: tensor<64x1xi8, #blocked2>) -> tensor<128x128xf16, #blocked1> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked1>) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %a_scale_tm = ttng.tmem_alloc %a_scale : (tensor<128x1xi8, #blocked2>) -> !ttg.memdesc<128x1xi8, #tmem_scales, #ttng.tensor_memory> + %b_scale_tm = ttng.tmem_alloc %b_scale : (tensor<64x1xi8, #blocked2>) -> !ttg.memdesc<64x1xi8, #tmem_scales, #ttng.tensor_memory> + ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm, %a_scale_tm, %b_scale_tm, %true, %true lhs = e5m2 rhs = e5m2 : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x1xi8, #tmem_scales, #ttng.tensor_memory>, !ttg.memdesc<64x1xi8, #tmem_scales, #ttng.tensor_memory>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> + scf.yield %acc_res : tensor<128x128xf32, #blocked1> + } + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked1> to tensor<128x128xf16, #blocked1> + tt.return %res_f16 : tensor<128x128xf16, #blocked1> + } + + // CHECK-LABEL: @dont_promote_rhs + // CHECK: scf.for + // CHECK: %[[B:.+]] = tt.load + // CHECK: %[[B_TMEM:.+]] = ttg.local_alloc %[[B]] + // CHECK: ttng.tc_gen5_mma %{{.+}}, %[[B_TMEM]], %{{.+}}, {{.+}} + tt.func public @dont_promote_rhs(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %arg3: i32) -> tensor<128x128xf16, #blocked1> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %A_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked1>) : i32 { + %B = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %A_sh = ttg.memdesc_subview %A_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> + scf.yield %acc_res : tensor<128x128xf32, #blocked1> + } + ttg.local_dealloc %A_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked1> to tensor<128x128xf16, #blocked1> + tt.return %res_f16 : tensor<128x128xf16, #blocked1> + } + + // CHECK-LABEL: @dont_promote_long_lr + // CHECK: %[[A:.+]] = tt.load + // CHECK: %[[A_SMEM:.+]] = ttg.local_alloc %[[A]] + // CHECK: scf.for + // CHECK: ttng.tc_gen5_mma %[[A_SMEM]] + tt.func public @dont_promote_long_lr(%A_ptr: tensor<128x128x!tt.ptr, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %arg3: i32) -> tensor<128x128xf16, #blocked1> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked1> + %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked1>) : i32 { + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> + scf.yield %acc_res : tensor<128x128xf32, #blocked1> + } + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked1> to tensor<128x128xf16, #blocked1> + tt.return %res_f16 : tensor<128x128xf16, #blocked1> + } + + // CHECK-LABEL: @dont_convert_layout + // CHECK: scf.for + // CHECK: %[[A:.+]] = tt.load + // CHECK: %[[A_SMEM:.+]] = ttg.local_alloc %[[A]] + // CHECK: ttng.tc_gen5_mma %[[A_SMEM]] + tt.func public @dont_convert_layout(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked1>, %arg3: i32) -> tensor<128x128xf16, #blocked1> attributes {noinline = false} { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %B_multibuf = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res = scf.for %i = %c0_i32 to %arg3 step %c1_i32 iter_args(%acc = %cst) -> (tensor<128x128xf32, #blocked1>) : i32 { + %A = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> + %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %B_sh = ttg.memdesc_subview %B_multibuf[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> + %acc_tm = ttng.tmem_alloc %acc : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1) -> () + %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1> + scf.yield %acc_res : tensor<128x128xf32, #blocked1> + } + ttg.local_dealloc %B_multibuf : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory, mutable> + %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked1> to tensor<128x128xf16, #blocked1> + tt.return %res_f16 : tensor<128x128xf16, #blocked1> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/reduce-data-duplication.mlir b/third_party/enflame/include/triton/test/TritonGPU/reduce-data-duplication.mlir new file mode 100644 index 000000000..83a1b1afc --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/reduce-data-duplication.mlir @@ -0,0 +1,42 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s + +// CHECK: #[[$SHARED:.*]] = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]} +// CHECK-LABEL: apply_swizzle +// CHECK: %{{.*}} = ttg.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !ttg.memdesc<16x256xf16, #[[$SHARED]], #smem> + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @apply_swizzle(%arg0: tensor<16x256xf16, #blocked>) { + %0 = ttg.convert_layout %arg0 : tensor<16x256xf16, #blocked> -> tensor<16x256xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: conversion_shortcut_blocked_dotop_warp32 +// CHECK-NOT: ttg.local_alloc +// CHECK: ttg.convert_layout +// CHECK-NOT: ttg.local_alloc +#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @conversion_shortcut_blocked_dotop_warp32(%arg0: tensor<64x64xf16, #blocked>) { + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: conversion_shortcut_blocked_dotop_warp64 +// CHECK-NOT: ttg.local_alloc +// CHECK: ttg.convert_layout +// CHECK-NOT: ttg.local_alloc +#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @conversion_shortcut_blocked_dotop_warp64(%arg0: tensor<64x64xf16, #blocked>) { + %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/reorder-instructions.mlir b/third_party/enflame/include/triton/test/TritonGPU/reorder-instructions.mlir new file mode 100644 index 000000000..dcce92702 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/reorder-instructions.mlir @@ -0,0 +1,106 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-reorder-instructions | FileCheck %s + +// check that we don't hoist convert_layout above its operand definition. +// CHECK-LABEL: convert_cannot_hoist +// CHECK: %[[CVTS:.+]] = ttg.local_alloc +// CHECK: ttg.local_load %[[CVTS]] +// CHECK: tt.dot +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @convert_cannot_hoist(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %10 = ttg.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %11 = ttg.local_load %10 : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK-LABEL: sink_convert_dealloc +// CHECK: ttg.async_wait {num = 0 : i32} +// CHECK: ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> +// CHECK: ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> +// CHECK: %3 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} { + %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + %2 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> + ttg.async_wait {num = 0 : i32} + ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + %3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: sink_convert_idx_1 +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> +// CHECK: tt.dot +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @sink_convert_idx_1(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %B = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %A = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %AS = ttg.local_alloc %A : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %AD = ttg.local_load %AS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %12 = tt.dot %AD, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// check that we don't sink convert_layout if it has multi users +// CHECK-LABEL: convert_cannot_sink +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: tt.dot +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: tt.dot +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @convert_cannot_sink(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %B = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %A0 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %AS0 = ttg.local_alloc %A0 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %AD0 = ttg.local_load %AS0 : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %12 = tt.dot %AD0, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %A1 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %AS1 = ttg.local_alloc %A1 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %AD1 = ttg.local_load %AS1 : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %13 = tt.dot %AD1, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir b/third_party/enflame/include/triton/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir new file mode 100644 index 000000000..050183a98 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir @@ -0,0 +1,178 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// The script is designed to make adding checks to +// a test case fast, it is *not* designed to be authoritative +// about what constitutes a good test! The CHECK should be +// minimized and named to reflect the test intent. + +// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +// CHECK: #[[$ATTR_1:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +// CHECK: #[[$ATTR_2:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +// CHECK: #[[$ATTR_3:.+]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +// CHECK: #[[$ATTR_4:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +// CHECK: #[[$ATTR_5:.+]] = #ttg.shared_memory +// To regenerate this test case, run `make golden-samples` in the triton root directory +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=51 %s +// CHECK-LABEL: tt.func public @matmul_kernel_with_descriptors( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_4:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_7:.*]] = arith.constant 3 : i32 +// CHECK: %[[VAL_6:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_8:.*]] = arith.constant -1 : i32 +// CHECK: %[[VAL_9:.*]] = arith.constant 8 : i32 +// CHECK: %[[VAL_10:.*]] = arith.constant 128 : i32 +// CHECK: %[[VAL_11:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_12:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_13:.*]] = arith.constant 64 : i32 +// CHECK: %[[VAL_14:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_15:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_16:.*]] = arith.constant 127 : i32 +// CHECK: %[[VAL_17:.*]] = arith.constant 255 : i32 +// CHECK: %[[VAL_18:.*]] = arith.constant 63 : i32 +// CHECK: %[[VAL_19:.*]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_20:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_3]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_22:.*]] = arith.divsi %[[VAL_21]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_4]], %[[VAL_17]] : i32 +// CHECK: %[[VAL_24:.*]] = arith.divsi %[[VAL_23]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_24]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_26:.*]] = arith.divsi %[[VAL_20]], %[[VAL_25]] : i32 +// CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_26]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_28:.*]] = arith.subi %[[VAL_22]], %[[VAL_27]] : i32 +// CHECK: %[[VAL_29:.*]] = arith.minsi %[[VAL_28]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_30:.*]] = arith.remsi %[[VAL_20]], %[[VAL_29]] : i32 +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_27]], %[[VAL_30]] : i32 +// CHECK: %[[VAL_32:.*]] = arith.remsi %[[VAL_20]], %[[VAL_25]] : i32 +// CHECK: %[[VAL_33:.*]] = arith.divsi %[[VAL_32]], %[[VAL_29]] : i32 +// CHECK: %[[VAL_34:.*]] = arith.extsi %[[VAL_5]] : i32 to i64 +// CHECK: %[[VAL_35:.*]] = tt.make_tensor_descriptor %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_5]]], {{\[}}%[[VAL_34]], %[[VAL_14]]] : , > +// CHECK: %[[VAL_36:.*]] = tt.make_tensor_descriptor %[[VAL_1]], {{\[}}%[[VAL_4]], %[[VAL_5]]], {{\[}}%[[VAL_34]], %[[VAL_14]]] : , > +// CHECK: %[[VAL_37:.*]] = arith.extsi %[[VAL_4]] : i32 to i64 +// CHECK: %[[VAL_38:.*]] = tt.make_tensor_descriptor %[[VAL_2]], {{\[}}%[[VAL_3]], %[[VAL_4]]], {{\[}}%[[VAL_37]], %[[VAL_14]]] : , > +// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_31]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_40:.*]] = arith.muli %[[VAL_33]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_5]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_42:.*]] = arith.divsi %[[VAL_41]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_43:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_44:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_45:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_46:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_46]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_47:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_47]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_48:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_6]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_48]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_49:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_50:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_50]], 49152, %[[VAL_49]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_51:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_52:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_52]]{{\[}}%[[VAL_39]], %[[VAL_12]]] %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_53:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_54:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_54]]{{\[}}%[[VAL_40]], %[[VAL_12]]] %[[VAL_53]], %[[VAL_50]], %[[VAL_49]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_55:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_56:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_56]], 49152, %[[VAL_55]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_57:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_58:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_58]]{{\[}}%[[VAL_39]], %[[VAL_13]]] %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_59:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_60:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_60]]{{\[}}%[[VAL_40]], %[[VAL_13]]] %[[VAL_59]], %[[VAL_56]], %[[VAL_55]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_61:.*]]:5 = scf.for %[[VAL_62:.*]] = %[[VAL_12]] to %[[VAL_42]] step %[[VAL_15]] iter_args(%[[VAL_63:.*]] = %[[VAL_19]], %[[VAL_64:.*]] = %[[VAL_13]], %[[VAL_65:.*]] = %[[VAL_15]], %[[VAL_66:.*]] = %[[VAL_8]], %[[VAL_67:.*]] = %[[VAL_12]]) -> (tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32) : i32 { +// CHECK: %[[VAL_68:.*]] = arith.subi %[[VAL_42]], %[[VAL_6]] : i32 +// CHECK: %[[VAL_69:.*]] = arith.cmpi slt, %[[VAL_62]], %[[VAL_68]] : i32 +// CHECK: %[[VAL_70:.*]] = arith.addi %[[VAL_66]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_71:.*]] = arith.cmpi slt, %[[VAL_70]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_72:.*]] = arith.select %[[VAL_71]], %[[VAL_70]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_73:.*]] = arith.xori %[[VAL_67]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_71]], %[[VAL_67]], %[[VAL_73]] : i32 +// CHECK: %[[VAL_75:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_72]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.wait_barrier %[[VAL_75]], %[[VAL_74]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_76:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_72]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_77:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_72]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_78:.*]] = ttg.memdesc_trans %[[VAL_76]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> -> !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_79:.*]] = ttng.warp_group_dot %[[VAL_77]], %[[VAL_78]], %[[VAL_63]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> -> tensor<128x256xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_80:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_79]], %[[VAL_77]], %[[VAL_78]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_1]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_81:.*]] = arith.addi %[[VAL_64]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_82:.*]] = arith.addi %[[VAL_65]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_83:.*]] = arith.cmpi slt, %[[VAL_82]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_83]], %[[VAL_82]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_85:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_84]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_85]], 49152, %[[VAL_69]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_86:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_84]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_87:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_87]]{{\[}}%[[VAL_39]], %[[VAL_81]]] %[[VAL_86]], %[[VAL_85]], %[[VAL_69]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_88:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_84]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_89:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_89]]{{\[}}%[[VAL_40]], %[[VAL_81]]] %[[VAL_88]], %[[VAL_85]], %[[VAL_69]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: scf.yield %[[VAL_80]]#0, %[[VAL_81]], %[[VAL_84]], %[[VAL_72]], %[[VAL_74]] : tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_90:.*]] = ttng.warp_group_dot_wait %[[VAL_91:.*]]#0 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_92:.*]] = ttg.async_wait {num = 0 : i32} +// CHECK: %[[VAL_93:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_93]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_94:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_94]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_95:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_6]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_95]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttg.local_dealloc %[[VAL_45]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: ttg.local_dealloc %[[VAL_44]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttg.local_dealloc %[[VAL_43]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_96:.*]] = arith.truncf %[[VAL_90]] : tensor<128x256xf32, #[[$ATTR_1]]> to tensor<128x256xf16, #[[$ATTR_1]]> +// CHECK: %[[VAL_97:.*]] = ttg.convert_layout %[[VAL_96]] : tensor<128x256xf16, #[[$ATTR_1]]> -> tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: tt.experimental_descriptor_store %[[VAL_38]]{{\[}}%[[VAL_39]], %[[VAL_40]]], %[[VAL_97]] : !tt.tensordesc>, tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: tt.return +// CHECK: } +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_with_descriptors(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %10 = arith.remsi %0, %9 : i32 + %11 = arith.addi %7, %10 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %14 = arith.extsi %arg5 : i32 to i64 + %15 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%14, %c1_i64] : , > + %16 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%14, %c1_i64] : , > + %17 = arith.extsi %arg4 : i32 to i64 + %18 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%17, %c1_i64] : , > + %19 = arith.muli %11, %c128_i32 : i32 + %20 = arith.muli %13, %c256_i32 : i32 + %21 = arith.addi %arg5, %c63_i32 : i32 + %22 = arith.divsi %21, %c64_i32 : i32 + %23:2 = scf.for %arg6 = %c0_i32 to %22 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %c0_i32) -> (tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32) : i32 { + %26 = tt.experimental_descriptor_load %15[%19, %arg8] : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %27 = ttg.local_alloc %26 : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> + %28 = tt.experimental_descriptor_load %16[%20, %arg8] : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %29 = ttg.local_alloc %28 : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> + %30 = ttg.memdesc_trans %29 {order = array} : !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> + %31 = ttng.warp_group_dot %27, %30, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %32 = arith.addi %arg8, %c64_i32 : i32 + scf.yield %31, %32 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32 + } + %24 = arith.truncf %23#0 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %25 = ttg.convert_layout %24 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %18[%19, %20], %25 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in b/third_party/enflame/include/triton/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in new file mode 100644 index 000000000..6e679e647 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in @@ -0,0 +1,54 @@ +// To regenerate this test case, run `make golden-samples` in the triton root directory +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=51 %s +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_with_descriptors(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %10 = arith.remsi %0, %9 : i32 + %11 = arith.addi %7, %10 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %14 = arith.extsi %arg5 : i32 to i64 + %15 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%14, %c1_i64] : , > + %16 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%14, %c1_i64] : , > + %17 = arith.extsi %arg4 : i32 to i64 + %18 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%17, %c1_i64] : , > + %19 = arith.muli %11, %c128_i32 : i32 + %20 = arith.muli %13, %c256_i32 : i32 + %21 = arith.addi %arg5, %c63_i32 : i32 + %22 = arith.divsi %21, %c64_i32 : i32 + %23:2 = scf.for %arg6 = %c0_i32 to %22 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %c0_i32) -> (tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32) : i32 { + %26 = tt.experimental_descriptor_load %15[%19, %arg8] : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %27 = ttg.local_alloc %26 : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> + %28 = tt.experimental_descriptor_load %16[%20, %arg8] : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %29 = ttg.local_alloc %28 : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> + %30 = ttg.memdesc_trans %29 {order = array} : !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> + %31 = ttng.warp_group_dot %27, %30, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %32 = arith.addi %arg8, %c64_i32 : i32 + scf.yield %31, %32 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32 + } + %24 = arith.truncf %23#0 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %25 = ttg.convert_layout %24 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %18[%19, %20], %25 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/samples/simulated-grouped-gemm.mlir b/third_party/enflame/include/triton/test/TritonGPU/samples/simulated-grouped-gemm.mlir new file mode 100644 index 000000000..374a83c0b --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/samples/simulated-grouped-gemm.mlir @@ -0,0 +1,368 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// The script is designed to make adding checks to +// a test case fast, it is *not* designed to be authoritative +// about what constitutes a good test! The CHECK should be +// minimized and named to reflect the test intent. + +// CHECK: #[[$ATTR_0:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +// CHECK: #[[$ATTR_1:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +// CHECK: #[[$ATTR_2:.+]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +// CHECK: #[[$ATTR_3:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +// CHECK: #[[$ATTR_4:.+]] = #ttg.shared_memory +// To regenerate this test case, run `make golden-samples` in the triton root directory +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +// CHECK-LABEL: tt.func public @matmul_kernel_descriptor_persistent( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_4:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_6:.*]] = arith.constant 2 : i64 +// CHECK: %[[VAL_8:.*]] = arith.constant 3 : i32 +// CHECK: %[[VAL_7:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_9:.*]] = arith.constant false +// CHECK: %[[VAL_10:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_11:.*]] = arith.constant 132 : i32 +// CHECK: %[[VAL_12:.*]] = arith.constant -1 : i32 +// CHECK: %[[VAL_13:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_14:.*]] = arith.constant 8 : i32 +// CHECK: %[[VAL_15:.*]] = arith.constant 128 : i32 +// CHECK: %[[VAL_16:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_17:.*]] = arith.constant 64 : i32 +// CHECK: %[[VAL_18:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_19:.*]] = arith.constant 127 : i32 +// CHECK: %[[VAL_20:.*]] = arith.constant 255 : i32 +// CHECK: %[[VAL_21:.*]] = arith.constant 63 : i32 +// CHECK: %[[VAL_22:.*]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_23:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_3]], %[[VAL_19]] : i32 +// CHECK: %[[VAL_25:.*]] = arith.divsi %[[VAL_24]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_4]], %[[VAL_20]] : i32 +// CHECK: %[[VAL_27:.*]] = arith.divsi %[[VAL_26]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_5]], %[[VAL_21]] : i32 +// CHECK: %[[VAL_29:.*]] = arith.divsi %[[VAL_28]], %[[VAL_17]] : i32 +// CHECK: %[[VAL_30:.*]] = arith.muli %[[VAL_25]], %[[VAL_27]] : i32 +// CHECK: %[[VAL_31:.*]] = arith.extsi %[[VAL_5]] : i32 to i64 +// CHECK: %[[VAL_32:.*]] = tt.make_tensor_descriptor %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_5]]], {{\[}}%[[VAL_31]], %[[VAL_18]]] : , > +// CHECK: %[[VAL_33:.*]] = tt.make_tensor_descriptor %[[VAL_1]], {{\[}}%[[VAL_4]], %[[VAL_5]]], {{\[}}%[[VAL_31]], %[[VAL_18]]] : , > +// CHECK: %[[VAL_34:.*]] = arith.extsi %[[VAL_4]] : i32 to i64 +// CHECK: %[[VAL_35:.*]] = tt.make_tensor_descriptor %[[VAL_2]], {{\[}}%[[VAL_3]], %[[VAL_4]]], {{\[}}%[[VAL_34]], %[[VAL_18]]] : , > +// CHECK: %[[VAL_36:.*]] = arith.divsi %[[VAL_30]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_37:.*]] = arith.remsi %[[VAL_30]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_38:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_37]] : i32 +// CHECK: %[[VAL_39:.*]] = scf.if %[[VAL_38]] -> (i32) { +// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_36]], %[[VAL_10]] : i32 +// CHECK: scf.yield %[[VAL_40]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_36]] : i32 +// CHECK: } +// CHECK: %[[VAL_41:.*]] = arith.subi %[[VAL_23]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_42:.*]] = arith.muli %[[VAL_27]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_43:.*]] = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : i32, pure = true} -> i32 +// CHECK: %[[VAL_44:.*]] = arith.muli %[[VAL_29]], %[[VAL_39]] : i32 +// CHECK: %[[VAL_45:.*]] = arith.subi %[[VAL_29]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_49:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_50:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_51:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_52:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_13]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_52]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_53:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_10]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_53]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_54:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_7]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_54]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_46:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_47:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_48:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_55:.*]] = arith.cmpi sgt, %[[VAL_44]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_56:.*]] = arith.select %[[VAL_55]], %[[VAL_23]], %[[VAL_41]] : i32 +// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_55]], %[[VAL_13]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_58:.*]]:2 = scf.if %[[VAL_55]] -> (i32, i32) { +// CHECK: %[[VAL_59:.*]] = arith.divsi %[[VAL_23]], %[[VAL_42]] : i32 +// CHECK: %[[VAL_60:.*]] = arith.muli %[[VAL_59]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_61:.*]] = arith.subi %[[VAL_25]], %[[VAL_60]] : i32 +// CHECK: %[[VAL_62:.*]] = arith.minsi %[[VAL_61]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_63:.*]] = arith.remsi %[[VAL_23]], %[[VAL_62]] : i32 +// CHECK: %[[VAL_64:.*]] = arith.addi %[[VAL_60]], %[[VAL_63]] : i32 +// CHECK: %[[VAL_65:.*]] = arith.remsi %[[VAL_23]], %[[VAL_42]] : i32 +// CHECK: %[[VAL_66:.*]] = arith.divsi %[[VAL_65]], %[[VAL_62]] : i32 +// CHECK: %[[VAL_67:.*]] = arith.muli %[[VAL_64]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_68:.*]] = arith.muli %[[VAL_66]], %[[VAL_16]] : i32 +// CHECK: scf.yield %[[VAL_67]], %[[VAL_68]] : i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_13]], %[[VAL_13]] : i32, i32 +// CHECK: } +// CHECK: %[[VAL_69:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_13]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_69]], 49152, %[[VAL_55]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_70:.*]] = ttg.memdesc_subview %[[VAL_49]]{{\[}}%[[VAL_13]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_71:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_32]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_71]]{{\[}}%[[VAL_72:.*]]#0, %[[VAL_13]]] %[[VAL_70]], %[[VAL_69]], %[[VAL_55]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_73:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_13]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_74:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_33]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_74]]{{\[}}%[[VAL_72]]#1, %[[VAL_13]]] %[[VAL_73]], %[[VAL_69]], %[[VAL_55]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_75:.*]] = arith.cmpi sgt, %[[VAL_44]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_76:.*]] = arith.cmpi ne, %[[VAL_45]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_77:.*]] = arith.extui %[[VAL_76]] : i1 to i32 +// CHECK: %[[VAL_78:.*]] = arith.cmpi eq, %[[VAL_77]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_79:.*]] = arith.andi %[[VAL_75]], %[[VAL_78]] : i1 +// CHECK: %[[VAL_80:.*]]:10 = scf.if %[[VAL_79]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32) { +// CHECK: %[[VAL_81:.*]] = arith.addi %[[VAL_57]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_82:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_83:.*]] = arith.select %[[VAL_82]], %[[VAL_13]], %[[VAL_81]] : i32 +// CHECK: %[[VAL_84:.*]] = arith.extui %[[VAL_82]] : i1 to i32 +// CHECK: %[[VAL_85:.*]] = arith.extui %[[VAL_82]] : i1 to i32 +// CHECK: %[[VAL_86:.*]] = arith.extui %[[VAL_82]] : i1 to i32 +// CHECK: %[[VAL_87:.*]]:3 = scf.if %[[VAL_82]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>) { +// CHECK: %[[VAL_88:.*]] = tt.addptr %[[VAL_0]], %[[VAL_43]] : !tt.ptr, i32 +// CHECK: %[[VAL_89:.*]] = arith.muli %[[VAL_31]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_46]], %[[VAL_88]], {{\[}}%[[VAL_17]], %[[VAL_15]]], {{\[}}%[[VAL_5]], %[[VAL_3]]], {{\[}}%[[VAL_89]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_46]] : !tt.ptr +// CHECK: %[[VAL_90:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_46]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_91:.*]] = tt.addptr %[[VAL_1]], %[[VAL_43]] : !tt.ptr, i32 +// CHECK: %[[VAL_92:.*]] = arith.muli %[[VAL_31]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_47]], %[[VAL_91]], {{\[}}%[[VAL_17]], %[[VAL_16]]], {{\[}}%[[VAL_5]], %[[VAL_4]]], {{\[}}%[[VAL_92]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_47]] : !tt.ptr +// CHECK: %[[VAL_93:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_47]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_94:.*]] = tt.addptr %[[VAL_2]], %[[VAL_43]] : !tt.ptr, i32 +// CHECK: %[[VAL_95:.*]] = arith.muli %[[VAL_34]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_48]], %[[VAL_94]], {{\[}}%[[VAL_17]], %[[VAL_15]]], {{\[}}%[[VAL_4]], %[[VAL_3]]], {{\[}}%[[VAL_95]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_48]] : !tt.ptr +// CHECK: %[[VAL_96:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_48]] : !tt.ptr to !tt.tensordesc> +// CHECK: scf.yield %[[VAL_90]], %[[VAL_93]], %[[VAL_96]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc> +// CHECK: } else { +// CHECK: scf.yield %[[VAL_32]], %[[VAL_33]], %[[VAL_35]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc> +// CHECK: } +// CHECK: %[[VAL_97:.*]] = arith.addi %[[VAL_56]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_98:.*]] = arith.divsi %[[VAL_97]], %[[VAL_42]] : i32 +// CHECK: %[[VAL_99:.*]] = arith.muli %[[VAL_98]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_100:.*]] = arith.subi %[[VAL_25]], %[[VAL_99]] : i32 +// CHECK: %[[VAL_101:.*]] = arith.minsi %[[VAL_100]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_102:.*]] = arith.remsi %[[VAL_97]], %[[VAL_101]] : i32 +// CHECK: %[[VAL_103:.*]] = arith.addi %[[VAL_99]], %[[VAL_102]] : i32 +// CHECK: %[[VAL_104:.*]] = arith.remsi %[[VAL_97]], %[[VAL_42]] : i32 +// CHECK: %[[VAL_105:.*]] = arith.divsi %[[VAL_104]], %[[VAL_101]] : i32 +// CHECK: %[[VAL_106:.*]] = arith.muli %[[VAL_103]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_107:.*]] = arith.muli %[[VAL_105]], %[[VAL_16]] : i32 +// CHECK: scf.yield %[[VAL_108:.*]]#0, %[[VAL_108]]#1, %[[VAL_108]]#2, %[[VAL_97]], %[[VAL_83]], %[[VAL_106]], %[[VAL_107]], %[[VAL_84]], %[[VAL_85]], %[[VAL_86]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_32]], %[[VAL_33]], %[[VAL_35]], %[[VAL_56]], %[[VAL_57]], %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_13]], %[[VAL_13]], %[[VAL_13]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_109:.*]] = arith.muli %[[VAL_77]], %[[VAL_17]] : i32 +// CHECK: %[[VAL_110:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_10]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_110]], 49152, %[[VAL_75]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_111:.*]] = ttg.memdesc_subview %[[VAL_49]]{{\[}}%[[VAL_10]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_112:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_113:.*]]#0 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_112]]{{\[}}%[[VAL_113]]#5, %[[VAL_109]]] %[[VAL_111]], %[[VAL_110]], %[[VAL_75]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_114:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_10]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_115:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_113]]#1 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_115]]{{\[}}%[[VAL_113]]#6, %[[VAL_109]]] %[[VAL_114]], %[[VAL_110]], %[[VAL_75]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_116:.*]] = ttg.local_alloc : () -> !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_117:.*]]:20 = scf.for %[[VAL_118:.*]] = %[[VAL_13]] to %[[VAL_44]] step %[[VAL_10]] iter_args(%[[VAL_119:.*]] = %[[VAL_77]], %[[VAL_120:.*]] = %[[VAL_113]]#0, %[[VAL_121:.*]] = %[[VAL_113]]#1, %[[VAL_122:.*]] = %[[VAL_113]]#2, %[[VAL_123:.*]] = %[[VAL_113]]#3, %[[VAL_124:.*]] = %[[VAL_113]]#4, %[[VAL_125:.*]] = %[[VAL_113]]#5, %[[VAL_126:.*]] = %[[VAL_113]]#6, %[[VAL_127:.*]] = %[[VAL_22]], %[[VAL_128:.*]] = %[[VAL_9]], %[[VAL_129:.*]] = %[[VAL_10]], %[[VAL_130:.*]] = %[[VAL_12]], %[[VAL_131:.*]] = %[[VAL_13]], %[[VAL_132:.*]] = %[[VAL_113]]#7, %[[VAL_133:.*]] = %[[VAL_113]]#8, %[[VAL_134:.*]] = %[[VAL_113]]#9, %[[VAL_135:.*]] = %[[VAL_13]], %[[VAL_136:.*]] = %[[VAL_35]], %[[VAL_137:.*]] = %[[VAL_72]]#0, %[[VAL_138:.*]] = %[[VAL_72]]#1) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, i32, i32) : i32 { +// CHECK: %[[VAL_139:.*]] = arith.subi %[[VAL_44]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_140:.*]] = arith.cmpi slt, %[[VAL_118]], %[[VAL_139]] : i32 +// CHECK: %[[VAL_141:.*]] = arith.cmpi eq, %[[VAL_119]], %[[VAL_45]] : i32 +// CHECK: %[[VAL_142:.*]] = arith.addi %[[VAL_119]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_143:.*]] = arith.select %[[VAL_141]], %[[VAL_13]], %[[VAL_142]] : i32 +// CHECK: %[[VAL_144:.*]] = arith.cmpi eq, %[[VAL_143]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_145:.*]] = arith.andi %[[VAL_140]], %[[VAL_144]] : i1 +// CHECK: %[[VAL_146:.*]]:10 = scf.if %[[VAL_145]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32) { +// CHECK: %[[VAL_147:.*]] = arith.addi %[[VAL_124]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_148:.*]] = arith.cmpi eq, %[[VAL_147]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_149:.*]] = arith.select %[[VAL_148]], %[[VAL_13]], %[[VAL_147]] : i32 +// CHECK: %[[VAL_150:.*]]:6 = scf.if %[[VAL_148]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32) { +// CHECK: %[[VAL_151:.*]] = tt.addptr %[[VAL_0]], %[[VAL_43]] : !tt.ptr, i32 +// CHECK: %[[VAL_152:.*]] = arith.muli %[[VAL_132]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_153:.*]] = tt.addptr %[[VAL_46]], %[[VAL_152]] : !tt.ptr, i32 +// CHECK: %[[VAL_154:.*]] = arith.muli %[[VAL_31]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_153]], %[[VAL_151]], {{\[}}%[[VAL_17]], %[[VAL_15]]], {{\[}}%[[VAL_5]], %[[VAL_3]]], {{\[}}%[[VAL_154]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_153]] : !tt.ptr +// CHECK: %[[VAL_155:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_153]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_156:.*]] = arith.addi %[[VAL_132]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_157:.*]] = arith.cmpi slt, %[[VAL_156]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_158:.*]] = arith.select %[[VAL_157]], %[[VAL_156]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_159:.*]] = tt.addptr %[[VAL_1]], %[[VAL_43]] : !tt.ptr, i32 +// CHECK: %[[VAL_160:.*]] = arith.muli %[[VAL_133]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_161:.*]] = tt.addptr %[[VAL_47]], %[[VAL_160]] : !tt.ptr, i32 +// CHECK: %[[VAL_162:.*]] = arith.muli %[[VAL_31]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_161]], %[[VAL_159]], {{\[}}%[[VAL_17]], %[[VAL_16]]], {{\[}}%[[VAL_5]], %[[VAL_4]]], {{\[}}%[[VAL_162]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_161]] : !tt.ptr +// CHECK: %[[VAL_163:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_161]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_164:.*]] = arith.addi %[[VAL_133]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_165:.*]] = arith.cmpi slt, %[[VAL_164]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_166:.*]] = arith.select %[[VAL_165]], %[[VAL_164]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_167:.*]] = tt.addptr %[[VAL_2]], %[[VAL_43]] : !tt.ptr, i32 +// CHECK: %[[VAL_168:.*]] = arith.muli %[[VAL_134]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_169:.*]] = tt.addptr %[[VAL_48]], %[[VAL_168]] : !tt.ptr, i32 +// CHECK: %[[VAL_170:.*]] = arith.muli %[[VAL_34]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_169]], %[[VAL_167]], {{\[}}%[[VAL_17]], %[[VAL_15]]], {{\[}}%[[VAL_4]], %[[VAL_3]]], {{\[}}%[[VAL_170]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_169]] : !tt.ptr +// CHECK: %[[VAL_171:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_169]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_172:.*]] = arith.addi %[[VAL_134]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_173:.*]] = arith.cmpi slt, %[[VAL_172]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_174:.*]] = arith.select %[[VAL_173]], %[[VAL_172]], %[[VAL_13]] : i32 +// CHECK: scf.yield %[[VAL_155]], %[[VAL_163]], %[[VAL_171]], %[[VAL_158]], %[[VAL_166]], %[[VAL_174]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_120]], %[[VAL_121]], %[[VAL_122]], %[[VAL_132]], %[[VAL_133]], %[[VAL_134]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_175:.*]] = arith.addi %[[VAL_123]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_176:.*]] = arith.divsi %[[VAL_175]], %[[VAL_42]] : i32 +// CHECK: %[[VAL_177:.*]] = arith.muli %[[VAL_176]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_178:.*]] = arith.subi %[[VAL_25]], %[[VAL_177]] : i32 +// CHECK: %[[VAL_179:.*]] = arith.minsi %[[VAL_178]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_180:.*]] = arith.remsi %[[VAL_175]], %[[VAL_179]] : i32 +// CHECK: %[[VAL_181:.*]] = arith.addi %[[VAL_177]], %[[VAL_180]] : i32 +// CHECK: %[[VAL_182:.*]] = arith.remsi %[[VAL_175]], %[[VAL_42]] : i32 +// CHECK: %[[VAL_183:.*]] = arith.divsi %[[VAL_182]], %[[VAL_179]] : i32 +// CHECK: %[[VAL_184:.*]] = arith.muli %[[VAL_181]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_185:.*]] = arith.muli %[[VAL_183]], %[[VAL_16]] : i32 +// CHECK: scf.yield %[[VAL_186:.*]]#0, %[[VAL_186]]#1, %[[VAL_186]]#2, %[[VAL_175]], %[[VAL_149]], %[[VAL_184]], %[[VAL_185]], %[[VAL_186]]#3, %[[VAL_186]]#4, %[[VAL_186]]#5 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_120]], %[[VAL_121]], %[[VAL_122]], %[[VAL_123]], %[[VAL_124]], %[[VAL_125]], %[[VAL_126]], %[[VAL_132]], %[[VAL_133]], %[[VAL_134]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_187:.*]] = arith.addi %[[VAL_130]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_188:.*]] = arith.cmpi slt, %[[VAL_187]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_189:.*]] = arith.select %[[VAL_188]], %[[VAL_187]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_190:.*]] = arith.xori %[[VAL_131]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_191:.*]] = arith.select %[[VAL_188]], %[[VAL_131]], %[[VAL_190]] : i32 +// CHECK: %[[VAL_192:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_189]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.wait_barrier %[[VAL_192]], %[[VAL_191]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_193:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_189]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_194:.*]] = ttg.memdesc_subview %[[VAL_49]]{{\[}}%[[VAL_189]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_195:.*]] = ttg.memdesc_trans %[[VAL_193]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> -> !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_196:.*]] = ttng.warp_group_dot %[[VAL_194]], %[[VAL_195]], %[[VAL_127]], %[[VAL_128]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> -> tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_197:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_196]], %[[VAL_194]], %[[VAL_195]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_198:.*]] = arith.addi %[[VAL_129]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_199:.*]] = arith.cmpi slt, %[[VAL_198]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_200:.*]] = arith.select %[[VAL_199]], %[[VAL_198]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_201:.*]] = arith.muli %[[VAL_143]], %[[VAL_17]] : i32 +// CHECK: %[[VAL_202:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_200]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_202]], 49152, %[[VAL_140]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_203:.*]] = ttg.memdesc_subview %[[VAL_49]]{{\[}}%[[VAL_200]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_204:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_205:.*]]#0 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_204]]{{\[}}%[[VAL_205]]#5, %[[VAL_201]]] %[[VAL_203]], %[[VAL_202]], %[[VAL_140]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_206:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_200]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_207:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_205]]#1 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_207]]{{\[}}%[[VAL_205]]#6, %[[VAL_201]]] %[[VAL_206]], %[[VAL_202]], %[[VAL_140]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_208:.*]] = arith.cmpi eq, %[[VAL_135]], %[[VAL_45]] : i32 +// CHECK: %[[VAL_209:.*]] = arith.cmpi ne, %[[VAL_135]], %[[VAL_45]] : i32 +// CHECK: scf.if %[[VAL_208]] { +// CHECK: %[[VAL_210:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_197]]#0, %[[VAL_194]], %[[VAL_195]] {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_211:.*]] = arith.truncf %[[VAL_210]]#0 : tensor<128x256xf32, #[[$ATTR_0]]> to tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} +// CHECK: ttg.local_store %[[VAL_211]], %[[VAL_116]] : tensor<128x256xf16, #[[$ATTR_0]]> -> !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: ttng.fence_async_shared {bCluster = false} +// CHECK: %[[VAL_212:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_136]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_local_to_global %[[VAL_212]]{{\[}}%[[VAL_137]], %[[VAL_138]]] %[[VAL_116]] : !tt.ptr, !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: } +// CHECK: scf.yield %[[VAL_143]], %[[VAL_205]]#0, %[[VAL_205]]#1, %[[VAL_205]]#2, %[[VAL_205]]#3, %[[VAL_205]]#4, %[[VAL_205]]#5, %[[VAL_205]]#6, %[[VAL_197]]#0, %[[VAL_209]], %[[VAL_200]], %[[VAL_189]], %[[VAL_191]], %[[VAL_205]]#7, %[[VAL_205]]#8, %[[VAL_205]]#9, %[[VAL_119]], %[[VAL_122]], %[[VAL_125]], %[[VAL_126]] : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, i32, i32 +// CHECK: } +// CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} +// CHECK: ttg.local_dealloc %[[VAL_116]] : !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_213:.*]] = ttng.warp_group_dot_wait %[[VAL_214:.*]]#8 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_215:.*]] = ttg.async_wait {num = 0 : i32} +// CHECK: %[[VAL_216:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_13]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_216]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_217:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_10]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_217]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_218:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_7]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_218]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttg.local_dealloc %[[VAL_51]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> +// CHECK: ttg.local_dealloc %[[VAL_50]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: ttg.local_dealloc %[[VAL_49]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: tt.return +// CHECK: } +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_descriptor_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c132_i32 = arith.constant 132 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , > + %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , > + %11 = arith.extsi %arg4 : i32 to i64 + %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , > + %13 = arith.divsi %7, %c132_i32 : i32 + %14 = arith.remsi %7, %c132_i32 : i32 + %15 = arith.cmpi slt, %0, %14 : i32 + %16 = scf.if %15 -> (i32) { + %23 = arith.addi %13, %c1_i32 : i32 + scf.yield %23 : i32 + } else { + scf.yield %13 : i32 + } + %17 = arith.subi %0, %c132_i32 : i32 + %18 = arith.muli %4, %c8_i32 : i32 + %19 = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : i32, pure = true} -> i32 + %20 = arith.muli %6, %16 : i32 + %21 = arith.subi %6, %c1_i32 : i32 + %true = arith.constant true + %false = arith.constant false + %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = %12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { + %23 = arith.cmpi eq, %arg7, %21 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %24 = arith.addi %arg7, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %25 = arith.select %23, %c0_i32, %24 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %26 = arith.cmpi eq, %25, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %27:7 = scf.if %26 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) { + %37 = arith.addi %arg12, %c1_i32 : i32 + %38 = arith.cmpi eq, %37, %c1_i32 : i32 + %39:4 = scf.if %38 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32) { + %51 = tt.addptr %arg0, %19 : !tt.ptr, i32 + %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , > + %53 = tt.addptr %arg1, %19 : !tt.ptr, i32 + %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , > + %55 = tt.addptr %arg2, %19 : !tt.ptr, i32 + %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , > + scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } + %40 = arith.addi %arg11, %c132_i32 : i32 + %41 = arith.divsi %40, %18 : i32 + %42 = arith.muli %41, %c8_i32 : i32 + %43 = arith.subi %2, %42 : i32 + %44 = arith.minsi %43, %c8_i32 : i32 + %45 = arith.remsi %40, %44 : i32 + %46 = arith.addi %42, %45 : i32 + %47 = arith.remsi %40, %18 : i32 + %48 = arith.divsi %47, %44 : i32 + %49 = arith.muli %46, %c128_i32 : i32 + %50 = arith.muli %48, %c256_i32 : i32 + scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } {loop.cluster = 0 : i32, loop.stage = 0 : i32} + %28 = arith.muli %25, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + %29 = tt.experimental_descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %30 = ttg.local_alloc %29 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> + %31 = tt.experimental_descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %32 = ttg.local_alloc %31 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> + %33 = ttg.memdesc_trans %32 {loop.cluster = 1 : i32, loop.stage = 2 : i32, order = array} : !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> + %34 = ttng.warp_group_dot %30, %33, %arg15, %arg16 {inputPrecision = 0 : i32, loop.cluster = 1 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %35 = arith.cmpi eq, %25, %21 {loop.cluster = 3 : i32, loop.stage = 2 : i32} : i32 + %36 = scf.if %35 -> (i1) { + %37 = arith.truncf %34 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %38 = ttg.convert_layout %37 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + scf.yield %false : i1 + } else { + scf.yield %true : i1 + } {loop.cluster = 3 : i32, loop.stage = 2 : i32} + scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1 + } + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in b/third_party/enflame/include/triton/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in new file mode 100644 index 000000000..857819af6 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in @@ -0,0 +1,101 @@ +// To regenerate this test case, run `make golden-samples` in the triton root directory +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_descriptor_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c132_i32 = arith.constant 132 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , > + %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , > + %11 = arith.extsi %arg4 : i32 to i64 + %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , > + %13 = arith.divsi %7, %c132_i32 : i32 + %14 = arith.remsi %7, %c132_i32 : i32 + %15 = arith.cmpi slt, %0, %14 : i32 + %16 = scf.if %15 -> (i32) { + %23 = arith.addi %13, %c1_i32 : i32 + scf.yield %23 : i32 + } else { + scf.yield %13 : i32 + } + %17 = arith.subi %0, %c132_i32 : i32 + %18 = arith.muli %4, %c8_i32 : i32 + %19 = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : i32, pure = true} -> i32 + %20 = arith.muli %6, %16 : i32 + %21 = arith.subi %6, %c1_i32 : i32 + %true = arith.constant true + %false = arith.constant false + %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = %12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { + %23 = arith.cmpi eq, %arg7, %21 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %24 = arith.addi %arg7, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %25 = arith.select %23, %c0_i32, %24 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %26 = arith.cmpi eq, %25, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %27:7 = scf.if %26 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) { + %37 = arith.addi %arg12, %c1_i32 : i32 + %38 = arith.cmpi eq, %37, %c1_i32 : i32 + %39:4 = scf.if %38 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32) { + %51 = tt.addptr %arg0, %19 : !tt.ptr, i32 + %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , > + %53 = tt.addptr %arg1, %19 : !tt.ptr, i32 + %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , > + %55 = tt.addptr %arg2, %19 : !tt.ptr, i32 + %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , > + scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } + %40 = arith.addi %arg11, %c132_i32 : i32 + %41 = arith.divsi %40, %18 : i32 + %42 = arith.muli %41, %c8_i32 : i32 + %43 = arith.subi %2, %42 : i32 + %44 = arith.minsi %43, %c8_i32 : i32 + %45 = arith.remsi %40, %44 : i32 + %46 = arith.addi %42, %45 : i32 + %47 = arith.remsi %40, %18 : i32 + %48 = arith.divsi %47, %44 : i32 + %49 = arith.muli %46, %c128_i32 : i32 + %50 = arith.muli %48, %c256_i32 : i32 + scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } {loop.cluster = 0 : i32, loop.stage = 0 : i32} + %28 = arith.muli %25, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + %29 = tt.experimental_descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %30 = ttg.local_alloc %29 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> + %31 = tt.experimental_descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %32 = ttg.local_alloc %31 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> + %33 = ttg.memdesc_trans %32 {loop.cluster = 1 : i32, loop.stage = 2 : i32, order = array} : !ttg.memdesc<256x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> + %34 = ttng.warp_group_dot %30, %33, %arg15, %arg16 {inputPrecision = 0 : i32, loop.cluster = 1 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %35 = arith.cmpi eq, %25, %21 {loop.cluster = 3 : i32, loop.stage = 2 : i32} : i32 + %36 = scf.if %35 -> (i1) { + %37 = arith.truncf %34 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %38 = ttg.convert_layout %37 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + scf.yield %false : i1 + } else { + scf.yield %true : i1 + } {loop.cluster = 3 : i32, loop.stage = 2 : i32} + scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1 + } + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/tf32x3-matmul.mlir b/third_party/enflame/include/triton/test/TritonGPU/tf32x3-matmul.mlir new file mode 100644 index 000000000..180a5c633 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/tf32x3-matmul.mlir @@ -0,0 +1,14 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-F32DotTC -canonicalize | FileCheck %s --check-prefixes=CHECK + +// CHECK: %[[DOT1:.*]] = tt.dot %[[LHS_LOW:.*]], %[[RHS_HIGH:.*]], %cst, inputPrecision = tf32 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> +// CHECK: %[[DOT2:.*]] = tt.dot %[[LHS_HIGH:.*]], %[[RHS_LOW:.*]], %[[DOT1]], inputPrecision = tf32 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> +// CHECK: %[[CMP:.*]] = arith.cmpf uno, %[[DOT2]], %[[DOT2]] : tensor<16x16xf32> +// CHECK: %[[MASKED:.*]] = arith.select %[[CMP]], %cst, %[[DOT2]] : tensor<16x16xi1>, tensor<16x16xf32> +// CHECK: %[[RESULT:.*]] = tt.dot %[[LHS_HIGH]], %[[RHS_HIGH]], %[[MASKED]], inputPrecision = tf32 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> + +module { + tt.func @dot_test(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { + %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = tf32x3 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> + tt.return %4 : tensor<16x16xf32> + } +} diff --git a/third_party/enflame/include/triton/test/TritonGPU/verify-blocked-layout.mlir b/third_party/enflame/include/triton/test/TritonGPU/verify-blocked-layout.mlir new file mode 100644 index 000000000..3c1d016cd --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonGPU/verify-blocked-layout.mlir @@ -0,0 +1,115 @@ +// RUN: triton-opt --split-input-file %s --verify-diagnostics + +#blocked = #ttg.blocked<{ + sizePerThread=[1, 1], + threadsPerWarp=[16, 1], + warpsPerCTA=[4, 1], + order=[0, 1], + CTAsPerCGA=[2, 1], + CTASplitNum=[1, 1], + CTAOrder=[0, 1] +}> +module attributes { + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 +} { + tt.func public @fn(%arg0: !tt.ptr) { + // expected-error @+1 {{threads per warp}} + %t = tt.splat %arg0 : !tt.ptr -> tensor<8x1x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{ + sizePerThread=[1, 1], + threadsPerWarp=[32, 1], + warpsPerCTA=[4, 2], + order=[0, 1], + CTAsPerCGA=[2, 1], + CTASplitNum=[1, 1], + CTAOrder=[0, 1] +}> +module attributes { + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 +} { + tt.func public @fn(%arg0: !tt.ptr) { + // expected-error @+1 {{warps per CTA}} + %t = tt.splat %arg0 : !tt.ptr -> tensor<8x1x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{ + sizePerThread=[1, 1], + threadsPerWarp=[32, 1], + warpsPerCTA=[4, 1], + order=[0, 1], + CTAsPerCGA=[1, 1], + CTASplitNum=[1, 1], + CTAOrder=[0, 1] +}> +module attributes { + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 +} { + tt.func public @fn(%arg0: !tt.ptr) { + // expected-error @+1 {{CTAs per CGA}} + %t = tt.splat %arg0 : !tt.ptr -> tensor<8x1x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{ + sizePerThread=[1, 1], + threadsPerWarp=[32, 1], + warpsPerCTA=[4, 1], + order=[0, 1], + CTAsPerCGA=[1, 2], + CTASplitNum=[1, 1], + CTAOrder=[0, 1] +}> +module attributes { + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 +} { + tt.func public @fn(%arg0: !tt.ptr) { + // Note it's a 3d tensor here, but #blocked is 2D. + // expected-error @+1 {{rank}} + %t = tt.splat %arg0 : !tt.ptr -> tensor<8x1x1x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{ + sizePerThread=[1, 1], + threadsPerWarp=[32, 1], + warpsPerCTA=[4, 1], + order=[0, 1], + CTAsPerCGA=[1, 2], + CTASplitNum=[1, 1], + CTAOrder=[0, 1] +}> +module attributes { + "ttg.num-warps" = 4 : i32, + "ttg.num-ctas" = 2 : i32, + "ttg.threads-per-warp" = 32 : i32 +} { + tt.func public @fn(%arg0: tensor<8xf32, #blocked>) { + // expected-error @+1 {{rank}} + %t = tt.expand_dims %arg0 {axis = 0 : i32} : tensor<8xf32, #blocked> -> tensor<8x1xf32, #blocked> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir new file mode 100644 index 000000000..1cca80d21 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir @@ -0,0 +1,63 @@ +// RUN: triton-opt %s -split-input-file --triton-gpu-taskid-propagate=num-consumer-groups=1 | FileCheck %s + +// CHECK-LABEL: @async_kernel +// CHECK: %0 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 +// CHECK: %5 = tt.splat %arg2 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1024xi32> +// CHECK: %9 = tt.load %8, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> +// CHECK: %10 = tt.splat %arg1 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: tt.store %11, %9 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + +module { + tt.func public @async_kernel(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg2 : i32 -> tensor<1024xi32> + %6 = arith.cmpi slt, %4, %5 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %4 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %4 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %11, %9 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK-LABEL: @two_consumers +// CHECK: tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 +// CHECK: tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.splat %arg1 {async_task_id = dense<[1, 2]> : vector<2xi32>} +// CHECK: tt.store {{.*}} {async_task_id = dense<1> : vector<1xi32>} +// CHECK: tt.store {{.*}} {async_task_id = dense<2> : vector<1xi32>} + +module { + tt.func public @two_consumers(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.make_range {end = 2048 : i32, start = 1024 : i32} : tensor<1024xi32> + %4 = tt.splat %1 : i32 -> tensor<1024xi32> + %5 = arith.addi %4, %2 : tensor<1024xi32> + %6 = arith.addi %4, %3 : tensor<1024xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %5 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.addptr %7, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %8 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %11 = tt.load %9 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %13 = tt.addptr %12, %5 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %14 = tt.addptr %12, %6 {async_task_id = dense<2> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %13, %10 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.store %14, %11 {async_task_id = dense<2> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir new file mode 100644 index 000000000..51bd5f5f6 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir @@ -0,0 +1,857 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-code-partition=num-buffers=1 | FileCheck %s + +// CHECK-LABEL: @matmul_kernel_one_consumer +// CHECK: %[[#TASKID:]] = ttng.get_async_task_id : i32 +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[#WG0:]] = arith.cmpi eq, %[[#TASKID]], %c0_i32 : i32 +// CHECK: scf.if %[[#WG0]] +// CHECK: ttng.reg_dealloc 40 +// CHECK: scf.for +// CHECK: ttng.producer_acquire +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttng.producer_commit +// CHECK: %c1_i32 = arith.constant 1 : i32 +// CHECK: %[[#WG1:]] = arith.cmpi eq, %[[#TASKID]], %c1_i32 : i32 +// CHECK: scf.if %[[#WG1]] +// CHECK: ttng.reg_alloc 232 +// CHECK: ttng.consumer_wait +// CHECK: ttg.local_load +// CHECK: ttg.local_load +// CHECK: tt.dot +// CHECK: ttng.consumer_release + + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_one_consumer(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked1> + %cst_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked2> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked2> + %0 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %1 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %2 = arith.divsi %1, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %3 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %4 = arith.divsi %3, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %5 = arith.muli %4, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %6 = arith.divsi %0, %5 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.muli %6, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.subi %2, %7 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.minsi %8, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.remsi %0, %5 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.remsi %10, %9 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.addi %7, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.divsi %10, %9 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.muli %12, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %16 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %17 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %18 = tt.splat %14 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %19 = tt.splat %14 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %20 = arith.addi %18, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %21 = arith.addi %19, %16 {async_task_id = dense<1> : vector<1xi32>} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %22 = tt.splat %arg3 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %23 = arith.remsi %20, %22 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %24 = arith.muli %13, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %25 = tt.splat %24 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %26 = arith.addi %25, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %27 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %28 = arith.remsi %26, %27 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %29 = tt.expand_dims %23 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %30 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked2> + %31 = arith.muli %29, %30 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked2> + %32 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %33 = tt.expand_dims %32 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %34 = tt.broadcast %31 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %35 = tt.broadcast %33 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %36 = arith.addi %34, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256xi32, #blocked2> + %37 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked2> + %38 = tt.addptr %37, %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %39 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %40 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %41 = tt.expand_dims %39 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %42 = tt.expand_dims %40 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %43 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked1> + %44 = arith.muli %41, %43 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> + %45 = tt.expand_dims %28 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %46 = tt.broadcast %44 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> -> tensor<256x128xi32, #blocked1> + %47 = tt.broadcast %45 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<256x128xi32, #blocked1> + %48 = arith.addi %46, %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128xi32, #blocked1> + %49 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %50 = tt.addptr %49, %48 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + %51 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %52 = arith.divsi %51, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %53 = arith.muli %arg7, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %54 = tt.splat %53 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x128xi32, #blocked1> + %55:3 = scf.for %arg9 = %c0_i32 to %52 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %38, %arg12 = %50) -> (tensor<128x128xf32, #blocked>, tensor<128x256x!tt.ptr, #blocked2>, tensor<256x128x!tt.ptr, #blocked1>) : i32 { + %74 = arith.muli %arg9, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %75 = arith.subi %arg5, %74 {async_task_id = dense<0> : vector<1xi32>} : i32 + %76 = tt.splat %75 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x256xi32, #blocked2> + %77 = arith.cmpi slt, %33, %76 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked2> + %78 = tt.broadcast %77 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %79 = tt.load %arg11, %78, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2> + %80 = tt.splat %75 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked1> + %81 = arith.cmpi slt, %42, %80 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> + %82 = tt.broadcast %81 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi1, #blocked1> -> tensor<256x128xi1, #blocked1> + %83 = tt.load %arg12, %82, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1> + %84 = ttg.convert_layout %79 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #blocked2> -> tensor<128x256xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %85 = ttg.convert_layout %83 {async_task_id = dense<1> : vector<1xi32>} : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %86 = tt.dot %84, %85, %arg10, inputPrecision = tf32 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + %87 = tt.addptr %arg11, %cst_2 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %88 = tt.addptr %arg12, %54 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1]> : vector<2xi32>} %86, %87, %88 : tensor<128x128xf32, #blocked>, tensor<128x256x!tt.ptr, #blocked2>, tensor<256x128x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1]> : vector<2xi32>} + %56 = arith.truncf %55#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + %57 = tt.expand_dims %21 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %58 = tt.splat %arg8 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %59 = arith.muli %58, %57 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %60 = tt.splat %arg2 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %61 = tt.addptr %60, %59 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %62 = tt.expand_dims %26 {async_task_id = dense<1> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %63 = tt.broadcast %61 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x128x!tt.ptr, #blocked1> + %64 = tt.broadcast %62 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<128x128xi32, #blocked1> + %65 = tt.addptr %63, %64 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked1>, tensor<128x128xi32, #blocked1> + %66 = tt.splat %arg3 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %67 = arith.cmpi slt, %57, %66 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %68 = tt.splat %arg4 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<1x128xi32, #blocked1> + %69 = arith.cmpi slt, %62, %68 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked1> + %70 = tt.broadcast %67 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi1, #blocked1> -> tensor<128x128xi1, #blocked1> + %71 = tt.broadcast %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked1> -> tensor<128x128xi1, #blocked1> + %72 = arith.andi %70, %71 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked1> + %73 = ttg.convert_layout %56 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #blocked1> + tt.store %65, %73, %72 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + + +// CHECK-LABEL: @matmul_kernel_two_consumers +// CHECK: scf.if +// CHECK: ttng.reg_dealloc 40 +// CHECK: scf.for +// CHECK: ttng.producer_acquire +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttng.producer_commit +// CHECK: ttng.producer_acquire +// CHECK: ttng.producer_acquire +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttng.producer_commit +// CHECK: ttng.producer_commit +// CHECK: scf.if +// CHECK: ttng.reg_alloc 232 +// CHECK: ttng.consumer_wait +// CHECK: ttng.consumer_wait +// CHECK: ttng.warp_group_dot +// CHECK: ttng.consumer_release +// CHECK: ttng.consumer_release +// CHECK: scf.if +// CHECK: ttng.reg_alloc 232 +// CHECK: ttng.consumer_wait +// CHECK: ttng.consumer_wait +// CHECK: ttng.warp_group_dot +// CHECK: ttng.consumer_release +// CHECK: ttng.consumer_release + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_two_consumers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<64> : tensor<64x64xi32, #blocked> + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c8_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 8 : i32 + %cst_0 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x64xf16, #blocked> + %cst_1 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x128xf16, #blocked1> + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 127 : i32 + %c63_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i32 + %cst_2 = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} dense<0.000000e+00> : tensor<64x128xf32, #mma> + %0 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.divsi %1, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = arith.divsi %3, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %5 = arith.muli %4, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %6 = arith.divsi %0, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %7 = arith.muli %6, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %8 = arith.subi %2, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %9 = arith.minsi %8, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %10 = arith.remsi %0, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %11 = arith.remsi %10, %9 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %12 = arith.addi %7, %11 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %13 = arith.divsi %10, %9 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %14 = arith.muli %12, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %15 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %17 = tt.splat %14 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %18 = tt.splat %14 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %19 = arith.addi %17, %15 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %20 = arith.addi %18, %16 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %21 = tt.splat %arg3 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %22 = arith.remsi %19, %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %23 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %24 = tt.make_range {async_task_id = dense<2> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %25 = arith.addi %17, %23 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %26 = arith.addi %18, %24 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.remsi %25, %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %28 = arith.muli %13, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %29 = tt.make_range {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %30 = tt.splat %28 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %31 = arith.addi %30, %29 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %32 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %33 = arith.remsi %31, %32 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %34 = tt.expand_dims %22 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %35 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %36 = arith.muli %34, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %37 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %38 = tt.expand_dims %37 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %39 = tt.broadcast %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %40 = tt.broadcast %38 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked> + %41 = arith.addi %39, %40 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64xi32, #blocked> + %42 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %43 = tt.addptr %42, %41 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %44 = tt.expand_dims %27 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %45 = arith.muli %44, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %46 = tt.broadcast %45 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %47 = arith.addi %46, %40 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64xi32, #blocked> + %48 = tt.addptr %42, %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %49 = tt.expand_dims %16 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %50 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %51 = arith.muli %49, %50 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %52 = tt.expand_dims %33 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %53 = tt.broadcast %51 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %54 = tt.broadcast %52 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %55 = arith.addi %53, %54 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128xi32, #blocked1> + %56 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128x!tt.ptr, #blocked1> + %57 = tt.addptr %56, %55 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %58 = arith.addi %arg5, %c63_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %59 = arith.divsi %58, %c64_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %60 = tt.expand_dims %37 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %61 = tt.expand_dims %16 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %62 = arith.muli %arg7, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %63 = tt.splat %62 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x128xi32, #blocked1> + %true = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %true_3 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false_4 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %64:5 = scf.for %arg9 = %c0_i32 to %59 step %c1_i32 iter_args(%arg10 = %cst_2, %arg11 = %cst_2, %arg12 = %43, %arg13 = %57, %arg14 = %48) -> (tensor<64x128xf32, #mma>, tensor<64x128xf32, #mma>, tensor<64x64x!tt.ptr, #blocked>, tensor<64x128x!tt.ptr, #blocked1>, tensor<64x64x!tt.ptr, #blocked>) : i32 { + %93 = arith.muli %arg9, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %94 = arith.subi %arg5, %93 {async_task_id = dense<0> : vector<1xi32>} : i32 + %95 = tt.splat %94 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x64xi32, #blocked> + %96 = arith.cmpi slt, %60, %95 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> + %97 = tt.broadcast %96 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked> + %98 = tt.load %arg12, %97, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked> + %99 = ttg.local_alloc %98 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #ttg.shared_memory> + %100 = tt.splat %94 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %101 = arith.cmpi slt, %61, %100 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %102 = tt.broadcast %101 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %103 = tt.load %arg13, %102, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + %104 = ttg.local_alloc %103 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #ttg.shared_memory> + %105 = tt.load %arg14, %97, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked> + %106 = ttg.local_alloc %105 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #ttg.shared_memory> + %107 = ttng.warp_group_dot %99, %104, %arg10 {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<64x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x128xf16, #shared, #ttg.shared_memory> -> tensor<64x128xf32, #mma> + %108 = ttng.warp_group_dot %106, %104, %arg11 {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<64x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x128xf16, #shared, #ttg.shared_memory> -> tensor<64x128xf32, #mma> + %109 = tt.addptr %arg12, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %110 = tt.addptr %arg14, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %111 = tt.addptr %arg13, %63 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %107, %108, %109, %111, %110 : tensor<64x128xf32, #mma>, tensor<64x128xf32, #mma>, tensor<64x64x!tt.ptr, #blocked>, tensor<64x128x!tt.ptr, #blocked1>, tensor<64x64x!tt.ptr, #blocked> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %65 = arith.truncf %64#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma> to tensor<64x128xf16, #mma> + %66 = arith.truncf %64#1 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma> to tensor<64x128xf16, #mma> + %67 = tt.expand_dims %20 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %68 = tt.splat %arg8 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %69 = arith.muli %68, %67 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %70 = tt.splat %arg2 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %71 = tt.addptr %70, %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %72 = tt.expand_dims %31 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %73 = tt.broadcast %71 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x128x!tt.ptr, #blocked1> + %74 = tt.broadcast %72 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %75 = tt.addptr %73, %74 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %76 = tt.expand_dims %26 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %77 = arith.muli %68, %76 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %78 = tt.addptr %70, %77 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %79 = tt.broadcast %78 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x128x!tt.ptr, #blocked1> + %80 = tt.addptr %79, %74 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %81 = tt.splat %arg3 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %82 = arith.cmpi slt, %67, %81 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %83 = tt.splat %arg4 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<1x128xi32, #blocked1> + %84 = arith.cmpi slt, %72, %83 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked1> + %85 = tt.broadcast %82 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %86 = tt.broadcast %84 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %87 = arith.andi %85, %86 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xi1, #blocked1> + %88 = arith.cmpi slt, %76, %81 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %89 = tt.broadcast %88 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %90 = arith.andi %89, %86 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xi1, #blocked1> + %91 = ttg.convert_layout %65 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf16, #mma> -> tensor<64x128xf16, #blocked1> + tt.store %75, %91, %87 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + %92 = ttg.convert_layout %66 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf16, #mma> -> tensor<64x128xf16, #blocked1> + tt.store %80, %92, %90 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + tt.return + } +} + + +// ----- + +// CHECK-LABEL: @_matmul_layernorm_persistent_one_producer_one_consumer_one_epilog +// CHECK: %[[#TASKID:]] = ttng.get_async_task_id : i32 +// CHECK: %c0_i32_0 = arith.constant 0 : i32 +// CHECK: %[[#WG0:]] = arith.cmpi eq, %[[#TASKID]], %c0_i32_0 : i32 +// CHECK: scf.if %[[#WG0]] +// CHECK: ttng.reg_dealloc 40 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ttng.producer_acquire +// CHECK: ttng.barrier_expect +// CHECK: ttng.async_tma_copy_global_to_local +// CHECK: ttng.async_tma_copy_global_to_local +// CHECK: %c1_i32 = arith.constant 1 : i32 +// CHECK: %[[#WG1:]] = arith.cmpi eq, %[[#TASKID]], %c1_i32 : i32 +// CHECK: scf.if %[[#WG1]] +// CHECK: ttng.reg_alloc 232 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ttng.wait_barrier +// CHECK: ttg.local_load +// CHECK: ttg.local_load +// CHECK: ttng.warp_group_dot +// CHECK: ttng.consumer_release +// CHECK: ttng.producer_acquire +// CHECK: ttg.local_store +// CHECK: ttng.producer_commit +// CHECK: %c2_i32 = arith.constant 2 : i32 +// CHECK: %[[#WG2:]] = arith.cmpi eq, %[[#TASKID]], %c2_i32 : i32 +// CHECK: scf.if %[[#WG2]] +// CHECK: ttng.reg_alloc 232 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ttng.consumer_wait +// CHECK: ttg.local_load +// CHECK: ttng.consumer_release +// CHECK: tt.experimental_descriptor_store + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @_matmul_layernorm_persistent_one_producer_one_consumer_one_epilog(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: !tt.tensordesc>, %arg4: !tt.tensordesc>, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: f32) attributes {noinline = false} { + %c63_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c132_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 132 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 127 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 256 : i32 + %c255_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 255 : i32 + %cst = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} dense<0.000000e+00> : tensor<128x256xf32, #mma> + %cst_0 = arith.constant {async_task_id = dense<2> : vector<1xi32>} dense<1.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %0 = arith.addi %arg7, %c63_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.divsi %0, %c64_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.addi %arg5, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = arith.divsi %2, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = arith.addi %arg6, %c255_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %5 = arith.divsi %4, %c256_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %6 = arith.muli %3, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %7 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %8 = arith.sitofp %arg6 {async_task_id = dense<2> : vector<1xi32>} : i32 to f32 + %9 = tt.splat %8 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %10 = tt.splat %arg11 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + scf.for %arg12 = %7 to %6 step %c132_i32 : i32 { + %11 = arith.muli %arg12, %c128_i32 {async_task_id = dense<[0, 2]> : vector<2xi32>} : i32 + %true = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %12 = scf.for %arg13 = %c0_i32 to %1 step %c1_i32 iter_args(%arg14 = %cst) -> (tensor<128x256xf32, #mma>) : i32 { + %45 = arith.muli %arg13, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %46 = tt.experimental_descriptor_load %arg0[%11, %45] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + %47 = ttg.local_alloc %46 {async_task_id = dense<1> : vector<1xi32>} : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %48 = tt.experimental_descriptor_load %arg1[%45, %c0_i32] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> + %49 = ttg.local_alloc %48 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> + %50 = ttng.warp_group_dot %47, %49, %arg14 {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> -> tensor<128x256xf32, #mma> + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %50 : tensor<128x256xf32, #mma> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %13 = "tt.reduce"(%12) <{axis = 1 : i32}> ({ + ^bb0(%arg13: f32, %arg14: f32): + %45 = arith.addf %arg13, %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 + tt.reduce.return %45 {async_task_id = dense<2> : vector<1xi32>} : f32 + }) {async_task_id = dense<2> : vector<1xi32>} : (tensor<128x256xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %14 = arith.divf %13, %9 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %15 = tt.expand_dims %14 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %16 = tt.broadcast %15 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x1xf32, #mma> -> tensor<128x256xf32, #mma> + %17 = arith.subf %12, %16 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %18 = arith.mulf %17, %17 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %19 = "tt.reduce"(%18) <{axis = 1 : i32}> ({ + ^bb0(%arg13: f32, %arg14: f32): + %45 = arith.addf %arg13, %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 + tt.reduce.return %45 {async_task_id = dense<2> : vector<1xi32>} : f32 + }) {async_task_id = dense<2> : vector<1xi32>} : (tensor<128x256xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %20 = arith.divf %19, %9 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %21 = arith.addf %20, %10 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %22 = math.sqrt %21 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %23 = arith.divf %cst_0, %22 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %24 = tt.experimental_descriptor_load %arg3[%c0_i32] {async_task_id = dense<2> : vector<1xi32>} : !tt.tensordesc> -> tensor<256xf16, #blocked2> + %25 = tt.experimental_descriptor_load %arg4[%c0_i32] {async_task_id = dense<2> : vector<1xi32>} : !tt.tensordesc> -> tensor<256xf16, #blocked2> + %26 = tt.expand_dims %23 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %27 = tt.broadcast %26 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x1xf32, #mma> -> tensor<128x256xf32, #mma> + %28 = arith.mulf %17, %27 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %29 = ttg.convert_layout %24 {async_task_id = dense<2> : vector<1xi32>} : tensor<256xf16, #blocked2> -> tensor<256xf16, #ttg.slice<{dim = 0, parent = #blocked1}>> + %30 = tt.expand_dims %29 {async_task_id = dense<2> : vector<1xi32>, axis = 0 : i32} : tensor<256xf16, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xf16, #blocked1> + %31 = ttg.convert_layout %30 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked1> -> tensor<1x256xf16, #blocked3> + %32 = arith.extf %31 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked3> to tensor<1x256xf32, #blocked3> + %33 = ttg.convert_layout %32 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #blocked3> -> tensor<1x256xf32, #mma> + %34 = tt.broadcast %33 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #mma> -> tensor<128x256xf32, #mma> + %35 = arith.mulf %28, %34 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %36 = ttg.convert_layout %25 {async_task_id = dense<2> : vector<1xi32>} : tensor<256xf16, #blocked2> -> tensor<256xf16, #ttg.slice<{dim = 0, parent = #blocked1}>> + %37 = tt.expand_dims %36 {async_task_id = dense<2> : vector<1xi32>, axis = 0 : i32} : tensor<256xf16, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xf16, #blocked1> + %38 = ttg.convert_layout %37 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked1> -> tensor<1x256xf16, #blocked3> + %39 = arith.extf %38 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked3> to tensor<1x256xf32, #blocked3> + %40 = ttg.convert_layout %39 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #blocked3> -> tensor<1x256xf32, #mma> + %41 = tt.broadcast %40 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #mma> -> tensor<128x256xf32, #mma> + %42 = arith.addf %35, %41 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %43 = arith.truncf %42 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %44 = ttg.convert_layout %43 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + tt.experimental_descriptor_store %arg2[%11, %c0_i32], %44 {async_task_id = dense<2> : vector<1xi32>} : !tt.tensordesc>, tensor<128x256xf16, #blocked1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + tt.return + } +} + +// ----- + +// Verify that we can reuse buffers between two for loops +// CHECK-LABEL: @_attn_bwd_ws +// CHECK-DAG: ttg.local_alloc {allocation.shareGroup = 0 : i32} : () -> !ttg.memdesc<2x64x128xbf16 +// CHECK-DAG: ttg.local_alloc {allocation.shareGroup = 1 : i32} : () -> !ttg.memdesc<2x64x128xbf16 +// CHECK-DAG: ttg.local_alloc {allocation.shareGroup = 0 : i32} : () -> !ttg.memdesc<2x64x128xbf16 +// CHECK-DAG: ttg.local_alloc {allocation.shareGroup = 1 : i32} : () -> !ttg.memdesc<2x64x128xbf16 + +// CHECK: %[[TID:.*]] = ttng.get_async_task_id : i32 +// CHECK: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK: %[[TWG0:.*]] = arith.cmpi eq, %[[TID]], %[[ZERO]] : i32 +// CHECK: scf.if %[[TWG0]] +// CHECK: ttng.reg_dealloc 40 +// CHECK: scf.if +// CHECK: scf.yield + +// CHECK: %[[IF_IDX:.*]] = scf.if +// CHECK: arith.divui %c0{{.*}} +// CHECK: arith.subi %c0{{.*}} +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX:.*]] = arith.addi %c0 +// CHECK: scf.yield {{.*}} %[[NEW_IDX]] +// CHECK: scf.yield {{.*}} %c0_ + +// CHECK: scf.if +// CHECK: arith.divui %[[IF_IDX]] +// CHECK: arith.subi %[[IF_IDX]] +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX2:.*]] = arith.addi %[[IF_IDX]] +// CHECK: scf.yield {{.*}} %[[NEW_IDX2]] +// CHECK: scf.yield {{.*}} %[[IF_IDX]] + +// CHECK: %[[ONE:.*]] = arith.constant 1 : i32 +// CHECK: %[[TWG1:.*]] = arith.cmpi eq, %[[TID]], %[[ONE]] : i32 +// CHECK: scf.if %[[TWG1]] +// CHECK: ttng.reg_alloc 232 +// CHECK: scf.if +// CHECK: scf.yield + +// CHECK: %[[IF_IDX_WG1:.*]] = scf.if +// CHECK: arith.divui %c0{{.*}} +// CHECK: arith.subi %c0{{.*}} +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX_WG1:.*]] = arith.addi %c0 +// CHECK: scf.yield {{.*}} %[[NEW_IDX_WG1]] +// CHECK: scf.yield {{.*}} %c0_ + +// CHECK: scf.if +// CHECK: arith.divui %[[IF_IDX_WG1]] +// CHECK: arith.subi %[[IF_IDX_WG1]] +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX2_WG1:.*]] = arith.addi %[[IF_IDX_WG1]] +// CHECK: scf.yield {{.*}} %[[NEW_IDX2_WG1]] +// CHECK: scf.yield {{.*}} %[[IF_IDX_WG1]] + +// CHECK: %[[TWO:.*]] = arith.constant 2 : i32 +// CHECK: %[[TWG2:.*]] = arith.cmpi eq, %[[TID]], %[[TWO]] : i32 +// CHECK: scf.if %[[TWG2]] +// CHECK: ttng.reg_alloc 232 +// CHECK: scf.if +// CHECK: scf.yield + +// CHECK: %[[IF_IDX_WG2:.*]] = scf.if +// CHECK: arith.divui %c0{{.*}} +// CHECK: arith.subi %c0{{.*}} +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX_WG2:.*]] = arith.addi %c0 +// CHECK: scf.yield {{.*}} %[[NEW_IDX_WG2]] +// CHECK: scf.yield {{.*}} %c0_ + +// CHECK: scf.if +// CHECK: arith.divui %[[IF_IDX_WG2]] +// CHECK: arith.subi %[[IF_IDX_WG2]] +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX2_WG2:.*]] = arith.addi %[[IF_IDX_WG2]] +// CHECK: scf.yield {{.*}} %[[NEW_IDX2_WG2]] +// CHECK: scf.yield {{.*}} %[[IF_IDX_WG2]] + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @_attn_bwd_ws(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.tensordesc>, %arg6: !tt.tensordesc>, %arg7: !tt.tensordesc>, %arg8: !tt.tensordesc>, %arg9: !tt.tensordesc>, %arg10: !tt.tensordesc>, %arg11: !tt.tensordesc>, %arg12: !tt.tensordesc>, %arg14: f32, %arg15: !tt.ptr {tt.divisibility = 16 : i32}, %arg16: !tt.ptr {tt.divisibility = 16 : i32}, %arg17: !tt.ptr {tt.divisibility = 16 : i32}, %arg18: !tt.ptr {tt.divisibility = 16 : i32}, %arg19: !tt.ptr {tt.divisibility = 16 : i32}, %arg20: !tt.ptr {tt.divisibility = 16 : i32}, %arg21: !tt.ptr {tt.divisibility = 16 : i32}, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32 {tt.divisibility = 16 : i32}, %arg25: i32 {tt.divisibility = 16 : i32}, %arg26: i32 {tt.divisibility = 16 : i32}, %arg27: i32 {tt.divisibility = 16 : i32}, %arg28: i32 {tt.divisibility = 16 : i32}, %arg29: i32, %arg30: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %false = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} false + %cst = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} dense<0.000000e+00> : tensor<64x64xf32, #mma> + %cst_0 = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} dense<128> : tensor<1x128xi32, #blocked> + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c64_i64 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i64 + %c63_i64 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i64 + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c0_i64 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i64 + %c1_i64 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i64 + %cst_1 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} dense<0.000000e+00> : tensor<64x128xf32, #mma1> + %cst_2 = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} dense<0.693147182> : tensor<64x128xf32, #mma1> + %0 = tt.get_program_id z {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.divsi %0, %arg29 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.remsi %0, %arg29 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = tt.addptr %arg1, %1 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i32 + %5 = tt.load %4 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr + %6 = tt.addptr %4, %c1_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i32 + %7 = tt.load %6 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr + %8 = arith.subi %7, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %9 = tt.addptr %arg3, %1 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i32 + %10 = tt.load %9 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr + %11 = tt.addptr %9, %c1_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i32 + %12 = tt.load %11 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr + %13 = arith.subi %12, %10 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %14 = arith.muli %3, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %15 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.make_range {async_task_id = dense<2> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %17 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked1> + %18 = tt.make_range {async_task_id = dense<2> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #blocked1> + %19 = arith.extsi %14 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %20 = arith.cmpi sle, %19, %13 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %21 = arith.cmpi sle, %19, %8 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %22 = arith.ori %20, %21 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i1 + %23:5 = scf.if %22 -> (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr) { + %27 = tt.addptr %arg16, %1 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i32 + %28 = tt.load %27 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr + %29 = arith.extsi %2 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %30 = arith.extsi %arg26 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %31 = arith.muli %29, %30 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %32 = arith.addi %31, %28 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %33 = arith.extsi %arg24 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %34 = arith.muli %29, %33 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %35 = arith.extsi %arg22 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %36 = arith.muli %5, %35 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %37 = arith.addi %34, %36 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %38 = arith.muli %2, %arg25 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %39 = arith.extsi %arg23 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %40 = arith.muli %10, %39 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %41 = arith.extsi %38 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %42 = arith.addi %41, %40 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %43 = tt.addptr %arg17, %37 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i64 + %44 = tt.addptr %arg18, %42 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i64 + %45 = tt.addptr %arg19, %42 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i64 + %46 = tt.addptr %arg20, %32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i64 + %47 = tt.addptr %arg21, %32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i64 + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %43, %44, %45, %46, %47 : !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr + } else { + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %arg17, %arg18, %arg19, %arg20, %arg21 : !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %24 = arith.extsi %14 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %25 = arith.cmpi slt, %24, %13 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + scf.if %25 { + %27 = tt.splat %14 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %28 = tt.splat %14 {async_task_id = dense<2> : vector<1xi32>} : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %29 = arith.addi %27, %15 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %30 = arith.addi %28, %16 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %31 = arith.extsi %14 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %32 = arith.addi %10, %31 {async_task_id = dense<0> : vector<1xi32>} : i64 + %33 = arith.trunci %32 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %34 = arith.addi %33, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %35 = arith.addi %33, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %36 = arith.muli %2, %arg25 {async_task_id = dense<0> : vector<1xi32>} : i32 + %37 = tt.experimental_descriptor_load %arg6[%33, %36] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %38 = tt.experimental_descriptor_load %arg6[%35, %36] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %39 = ttg.local_alloc %37 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %40 = ttg.local_alloc %38 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %41 = tt.experimental_descriptor_load %arg7[%33, %36] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %42 = tt.experimental_descriptor_load %arg7[%34, %36] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %43 = ttg.local_alloc %41 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %44 = ttg.local_alloc %42 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %45 = arith.addi %8, %c63_i64 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %46 = arith.divsi %45, %c64_i64 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %47 = arith.extsi %2 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %48 = tt.make_range {async_task_id = dense<[1, 2]> : vector<2xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %49 = arith.addi %5, %c64_i64 {async_task_id = dense<0> : vector<1xi32>} : i64 + %50 = arith.trunci %49 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %51 = arith.extsi %arg24 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %52 = arith.muli %47, %51 {async_task_id = dense<0> : vector<1xi32>} : i64 + %53 = arith.trunci %52 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %54 = tt.splat %8 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i64 -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #mma}>> + %55 = tt.splat %23#3 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> + %56 = tt.splat %arg14 {async_task_id = dense<1> : vector<1xi32>} : f32 -> tensor<64x64xf32, #mma> + %57 = tt.splat %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<64x64xf32, #mma> + %58 = tt.splat %23#4 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> + %59:3 = scf.for %arg31 = %c0_i64 to %46 step %c1_i64 iter_args(%arg32 = %c0_i32, %arg33 = %cst_1, %arg35 = %cst_1) -> (i32, tensor<64x128xf32, #mma1>, tensor<64x128xf32, #mma1>) : i64 { + %111 = tt.splat %arg32 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %112 = arith.addi %111, %48 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %113 = tt.experimental_descriptor_load %arg5[%50, %53] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %114 = ttg.local_alloc %113 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %115 = ttg.memdesc_trans %114 {async_task_id = dense<[1, 2]> : vector<2xi32>, order = array} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> -> !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> + %116 = arith.extsi %112 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #mma}>> + %117 = arith.cmpi slt, %116, %54 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #mma}>> + %118 = tt.addptr %55, %112 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %119 = tt.load %118, %117 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> + %120 = ttng.warp_group_dot %39, %115, %cst, %false {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> -> tensor<64x64xf32, #mma> + %121 = ttng.warp_group_dot %40, %115, %cst, %false {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> -> tensor<64x64xf32, #mma> + %122 = arith.mulf %120, %56 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %123 = arith.mulf %121, %57 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %124 = tt.experimental_descriptor_load %arg8[%50, %53] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %125 = ttg.local_alloc %124 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %126 = ttg.memdesc_trans %125 {async_task_id = dense<[1, 2]> : vector<2xi32>, order = array} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> -> !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> + %127 = ttng.warp_group_dot %43, %126, %cst, %false {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> -> tensor<64x64xf32, #mma> + %128 = ttng.warp_group_dot %44, %126, %cst, %false {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> -> tensor<64x64xf32, #mma> + %129 = tt.expand_dims %119 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> + %130 = tt.broadcast %129 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> + %131 = tt.broadcast %129 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> + %132 = arith.subf %122, %130 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %133 = arith.subf %123, %131 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %134 = math.exp2 %132 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %135 = math.exp2 %133 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %136 = arith.truncf %134 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> to tensor<64x64xbf16, #mma> + %137 = arith.truncf %135 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> to tensor<64x64xbf16, #mma> + %138 = tt.addptr %58, %112 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %139 = tt.load %138, %117 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> + %140 = ttg.convert_layout %136 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xbf16, #mma> -> tensor<64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %141 = ttg.convert_layout %137 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xbf16, #mma> -> tensor<64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %142 = ttng.warp_group_dot %140, %125, %arg33 {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : tensor<64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> -> tensor<64x128xf32, #mma1> + %143 = ttng.warp_group_dot %141, %125, %arg35 {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : tensor<64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> -> tensor<64x128xf32, #mma1> + %157 = arith.addi %arg32, %c64_i32 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 + scf.yield {async_task_id = dense<[1, 2]> : vector<2xi32>} %157, %142, %143 : i32, tensor<64x128xf32, #mma1>, tensor<64x128xf32, #mma1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, tt.num_stages = 2 : i32} + %60 = tt.make_range {async_task_id = dense<[1, 2]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %61 = tt.expand_dims %60 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %62 = arith.cmpi slt, %61, %cst_0 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked> + %63 = tt.expand_dims %29 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %64 = tt.expand_dims %30 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %65 = arith.extsi %63 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %66 = arith.extsi %64 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %67 = tt.splat %13 {async_task_id = dense<1> : vector<1xi32>} : i64 -> tensor<64x1xi64, #blocked> + %68 = tt.splat %13 {async_task_id = dense<2> : vector<1xi32>} : i64 -> tensor<64x1xi64, #blocked> + %69 = arith.cmpi slt, %65, %67 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi64, #blocked> + %70 = arith.cmpi slt, %66, %68 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi64, #blocked> + %71 = tt.broadcast %62 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<64x128xi1, #blocked> + %72 = tt.broadcast %62 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<64x128xi1, #blocked> + %73 = tt.broadcast %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> + %74 = tt.broadcast %70 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> + %75 = arith.andi %71, %73 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xi1, #blocked> + %76 = arith.andi %72, %74 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xi1, #blocked> + %77 = tt.splat %arg23 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %78 = tt.splat %arg23 {async_task_id = dense<2> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %79 = arith.muli %63, %77 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %80 = arith.muli %64, %78 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %81 = tt.splat %23#2 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %82 = tt.splat %23#2 {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %83 = tt.addptr %81, %79 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %84 = tt.addptr %82, %80 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %85 = tt.broadcast %83 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %86 = tt.broadcast %84 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %87 = tt.broadcast %61 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> + %88 = tt.broadcast %61 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> + %89 = tt.addptr %85, %87 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %90 = tt.addptr %86, %88 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %91 = arith.truncf %59#1 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma1> to tensor<64x128xbf16, #mma1> + %92 = arith.truncf %59#2 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma1> to tensor<64x128xbf16, #mma1> + %93 = ttg.convert_layout %91 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xbf16, #mma1> -> tensor<64x128xbf16, #blocked> + %94 = ttg.convert_layout %92 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xbf16, #mma1> -> tensor<64x128xbf16, #blocked> + tt.store %89, %93, %75 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked> + tt.store %90, %94, %76 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %26 = arith.cmpi slt, %24, %8 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + scf.if %26 { + %27 = tt.splat %14 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %28 = tt.splat %14 {async_task_id = dense<2> : vector<1xi32>} : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %29 = tt.splat %14 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<64xi32, #blocked1> + %30 = tt.splat %14 {async_task_id = dense<2> : vector<1xi32>} : i32 -> tensor<64xi32, #blocked1> + %31 = arith.addi %27, %15 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %32 = arith.addi %28, %16 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %33 = arith.addi %29, %17 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #blocked1> + %34 = arith.addi %30, %18 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #blocked1> + %35 = arith.extsi %2 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %36 = arith.extsi %14 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %37 = arith.addi %5, %36 {async_task_id = dense<0> : vector<1xi32>} : i64 + %38 = arith.trunci %37 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %39 = arith.addi %38, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %40 = arith.addi %38, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %41 = arith.extsi %arg24 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %42 = arith.muli %35, %41 {async_task_id = dense<0> : vector<1xi32>} : i64 + %43 = arith.trunci %42 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %44 = tt.experimental_descriptor_load %arg9[%38, %43] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %45 = tt.experimental_descriptor_load %arg9[%40, %43] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %46 = ttg.local_alloc %44 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %47 = ttg.local_alloc %45 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %48 = arith.extsi %arg28 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %49 = arith.muli %35, %48 {async_task_id = dense<0> : vector<1xi32>} : i64 + %50 = arith.trunci %49 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %51 = tt.experimental_descriptor_load %arg12[%38, %50] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %52 = tt.experimental_descriptor_load %arg12[%39, %50] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %53 = ttg.local_alloc %51 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %54 = ttg.local_alloc %52 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %55 = arith.extsi %33 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #blocked1> to tensor<64xi64, #blocked1> + %56 = arith.extsi %34 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #blocked1> to tensor<64xi64, #blocked1> + %57 = tt.splat %8 {async_task_id = dense<1> : vector<1xi32>} : i64 -> tensor<64xi64, #blocked1> + %58 = tt.splat %8 {async_task_id = dense<2> : vector<1xi32>} : i64 -> tensor<64xi64, #blocked1> + %59 = arith.cmpi slt, %55, %57 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi64, #blocked1> + %60 = arith.cmpi slt, %56, %58 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi64, #blocked1> + %61 = tt.splat %23#3 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> + %62 = tt.splat %23#3 {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> + %63 = tt.addptr %61, %33 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> + %64 = tt.addptr %62, %34 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> + %65 = tt.load %63, %59 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1> + %66 = tt.load %64, %60 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1> + %67 = ttg.convert_layout %65 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xf32, #blocked1> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %68 = ttg.convert_layout %66 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xf32, #blocked1> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %69 = tt.expand_dims %67 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xf32, #blocked3> + %70 = tt.expand_dims %68 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xf32, #blocked3> + %71 = arith.addi %13, %c63_i64 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %72 = arith.divsi %71, %c64_i64 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %73 = tt.splat %23#4 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> + %74 = tt.splat %23#4 {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> + %75 = tt.addptr %73, %33 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> + %76 = tt.addptr %74, %34 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> + %77 = tt.load %75, %59 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1> + %78 = tt.load %76, %60 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1> + %79 = arith.trunci %10 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %80 = arith.muli %2, %arg25 {async_task_id = dense<0> : vector<1xi32>} : i32 + %81 = tt.splat %arg14 {async_task_id = dense<1> : vector<1xi32>} : f32 -> tensor<64x64xf32, #mma> + %82 = tt.splat %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<64x64xf32, #mma> + %83 = tt.broadcast %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xf32, #blocked3> -> tensor<64x64xf32, #blocked3> + %84 = tt.broadcast %70 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xf32, #blocked3> -> tensor<64x64xf32, #blocked3> + %85 = ttg.convert_layout %83 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #mma> + %86 = ttg.convert_layout %84 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #mma> + %87 = ttg.convert_layout %77 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xf32, #blocked1> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %88 = ttg.convert_layout %78 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xf32, #blocked1> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %89 = tt.expand_dims %87 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xf32, #blocked3> + %90 = tt.expand_dims %88 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xf32, #blocked3> + %91 = tt.broadcast %89 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xf32, #blocked3> -> tensor<64x64xf32, #blocked3> + %92 = tt.broadcast %90 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xf32, #blocked3> -> tensor<64x64xf32, #blocked3> + %93 = ttg.convert_layout %91 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #mma> + %94 = ttg.convert_layout %92 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #mma> + %95 = tt.splat %arg14 {async_task_id = dense<1> : vector<1xi32>} : f32 -> tensor<64x128xf32, #mma1> + %96 = tt.splat %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<64x128xf32, #mma1> + %97:2 = scf.for %arg31 = %c0_i64 to %72 step %c1_i64 iter_args(%arg32 = %cst_1, %arg33 = %cst_1) -> (tensor<64x128xf32, #mma1>, tensor<64x128xf32, #mma1>) : i64 { + %135 = tt.experimental_descriptor_load %arg10[%79, %80] {async_task_id = dense<0> : vector<1xi32>} : !tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %136 = ttg.local_alloc %135 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %137 = ttg.memdesc_trans %136 {async_task_id = dense<[1, 2]> : vector<2xi32>, order = array} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> -> !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> + %138 = tt.experimental_descriptor_load %arg11[%79, %80] {async_task_id = dense<0> : vector<1xi32>} :!tt.tensordesc> -> tensor<64x128xbf16, #blocked2> + %139 = ttg.local_alloc %138 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xbf16, #blocked2>) -> !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> + %140 = ttg.memdesc_trans %139 {async_task_id = dense<[1, 2]> : vector<2xi32>, order = array} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> -> !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> + %141 = ttng.warp_group_dot %46, %137, %cst, %false {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> -> tensor<64x64xf32, #mma> + %142 = ttng.warp_group_dot %47, %137, %cst, %false {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> -> tensor<64x64xf32, #mma> + %143 = arith.mulf %141, %81 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %144 = arith.mulf %142, %82 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %145 = ttng.warp_group_dot %53, %140, %cst, %false {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> -> tensor<64x64xf32, #mma> + %146 = ttng.warp_group_dot %54, %140, %cst, %false {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x64xbf16, #shared1, #ttg.shared_memory> -> tensor<64x64xf32, #mma> + %147 = arith.subf %143, %85 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %148 = arith.subf %144, %86 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %149 = math.exp2 %147 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %150 = math.exp2 %148 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %151 = arith.subf %145, %93 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %152 = arith.subf %146, %94 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %153 = arith.mulf %149, %151 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %154 = arith.mulf %150, %152 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %155 = arith.truncf %153 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> to tensor<64x64xbf16, #mma> + %156 = arith.truncf %154 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> to tensor<64x64xbf16, #mma> + %157 = ttg.convert_layout %155 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xbf16, #mma> -> tensor<64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %158 = ttg.convert_layout %156 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xbf16, #mma> -> tensor<64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %159 = ttng.warp_group_dot %157, %136, %cst_1, %false {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : tensor<64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> -> tensor<64x128xf32, #mma1> + %160 = ttng.warp_group_dot %158, %136, %cst_1, %false {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : tensor<64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #ttg.shared_memory> -> tensor<64x128xf32, #mma1> + %161 = arith.mulf %159, %95 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma1> + %162 = arith.mulf %160, %96 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma1> + %163 = arith.addf %arg32, %161 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma1> + %164 = arith.addf %arg33, %162 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma1> + scf.yield {async_task_id = dense<[1, 2]> : vector<2xi32>} %163, %164 : tensor<64x128xf32, #mma1>, tensor<64x128xf32, #mma1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, tt.num_stages = 2 : i32} + %98 = tt.make_range {async_task_id = dense<[1, 2]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %99 = tt.expand_dims %98 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %100 = arith.cmpi slt, %99, %cst_0 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked> + %101 = tt.expand_dims %31 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %102 = tt.expand_dims %32 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %103 = arith.extsi %101 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %104 = arith.extsi %102 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %105 = tt.splat %8 {async_task_id = dense<1> : vector<1xi32>} : i64 -> tensor<64x1xi64, #blocked> + %106 = tt.splat %8 {async_task_id = dense<2> : vector<1xi32>} : i64 -> tensor<64x1xi64, #blocked> + %107 = arith.cmpi slt, %103, %105 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi64, #blocked> + %108 = arith.cmpi slt, %104, %106 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi64, #blocked> + %109 = tt.broadcast %100 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<64x128xi1, #blocked> + %110 = tt.broadcast %100 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<64x128xi1, #blocked> + %111 = tt.broadcast %107 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> + %112 = tt.broadcast %108 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> + %113 = arith.andi %109, %111 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xi1, #blocked> + %114 = arith.andi %110, %112 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xi1, #blocked> + %115 = tt.splat %arg22 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %116 = tt.splat %arg22 {async_task_id = dense<2> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %117 = arith.muli %101, %115 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %118 = arith.muli %102, %116 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %119 = tt.splat %23#0 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %120 = tt.splat %23#0 {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %121 = tt.addptr %119, %117 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %122 = tt.addptr %120, %118 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %123 = tt.broadcast %121 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %124 = tt.broadcast %122 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %125 = tt.broadcast %99 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> + %126 = tt.broadcast %99 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> + %127 = tt.addptr %123, %125 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %128 = tt.addptr %124, %126 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %129 = arith.mulf %97#0, %cst_2 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma1> + %130 = arith.mulf %97#1, %cst_2 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma1> + %131 = arith.truncf %129 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma1> to tensor<64x128xbf16, #mma1> + %132 = arith.truncf %130 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma1> to tensor<64x128xbf16, #mma1> + %133 = ttg.convert_layout %131 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xbf16, #mma1> -> tensor<64x128xbf16, #blocked> + %134 = ttg.convert_layout %132 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xbf16, #mma1> -> tensor<64x128xbf16, #blocked> + tt.store %127, %133, %113 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked> + tt.store %128, %134, %114 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir new file mode 100644 index 000000000..d05f14365 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir @@ -0,0 +1,136 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-data-partition=num-consumer-groups=2 | FileCheck %s + +// CHECK-LABEL: @matmul_persistent_ws_cooperative_kernel +// CHECK: %[[#GA1:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr +// CHECK: %[[#GA2:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr +// CHECK: %[[#LA1:]] = ttg.local_alloc %[[#GA1]] +// CHECK: %[[#LA2:]] = ttg.local_alloc %[[#GA2]] +// CHECK: %[[#GB:]] = tt.load {{.*}} : tensor<64x256x!tt.ptr +// CHECK: %[[#LB:]] = ttg.local_alloc %[[#GB]] +// CHECK: %[[#C1:]] = ttng.warp_group_dot %[[#LA1]], %[[#LB]], {{.*}} : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<64x256xf32, #mma> +// CHECK: %[[#C2:]] = ttng.warp_group_dot %[[#LA2]], %[[#LB]], {{.*}} : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<64x256xf32, #mma> +// CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr, #blocked1> +// CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr, #blocked1> + + + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_persistent_ws_cooperative_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<64> : tensor<128x64xi32, #blocked> + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c255_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 255 : i32 + %c63_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i32 + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 256 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c8_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 8 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 127 : i32 + %cst_0 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<128x64xf16, #blocked> + %cst_1 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x256xf16, #blocked1> + %cst_2 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.divsi %0, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.addi %arg4, %c255_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = arith.divsi %2, %c256_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = arith.muli %1, %3 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %5 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %6 = tt.get_num_programs x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %7 = arith.muli %3, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %8 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %9 = tt.make_range {async_task_id = dense<[1, 2]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %10 = tt.splat %arg3 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %11 = tt.make_range {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %12 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %13 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked> + %14 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %15 = tt.expand_dims %14 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %16 = tt.broadcast %15 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> + %17 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> + %18 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %19 = tt.expand_dims %18 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %20 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %21 = arith.muli %19, %20 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %22 = tt.broadcast %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %23 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked1> + %24 = arith.addi %arg5, %c63_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %25 = arith.divsi %24, %c64_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %26 = tt.expand_dims %14 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %27 = tt.expand_dims %18 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %28 = arith.muli %arg7, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %29 = tt.splat %28 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x256xi32, #blocked1> + %30 = tt.splat %arg8 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %31 = tt.splat %arg2 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %32 = tt.splat %arg3 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %33 = tt.splat %arg4 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<1x256xi32, #blocked1> + scf.for %arg9 = %5 to %4 step %6 : i32 { + %34 = arith.divsi %arg9, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %35 = arith.muli %34, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %36 = arith.subi %1, %35 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %37 = arith.minsi %36, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %38 = arith.remsi %arg9, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %39 = arith.remsi %38, %37 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %40 = arith.addi %35, %39 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %41 = arith.divsi %38, %37 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %42 = arith.muli %40, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %43 = tt.splat %42 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %44 = tt.splat %42 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %45 = arith.addi %43, %8 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %46 = arith.addi %44, %9 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %47 = arith.remsi %45, %10 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %48 = arith.muli %41, %c256_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %49 = tt.splat %48 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %50 = arith.addi %49, %11 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %51 = arith.remsi %50, %12 {async_task_id = dense<0> : vector<1xi32>} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %52 = tt.expand_dims %47 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %53 = arith.muli %52, %13 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked> + %54 = tt.broadcast %53 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked> -> tensor<128x64xi32, #blocked> + %55 = arith.addi %54, %16 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64xi32, #blocked> + %56 = tt.addptr %17, %55 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %57 = tt.expand_dims %51 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %58 = tt.broadcast %57 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %59 = arith.addi %22, %58 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256xi32, #blocked1> + %60 = tt.addptr %23, %59 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + %true = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %61:3 = scf.for %arg10 = %c0_i32 to %25 step %c1_i32 iter_args(%arg11 = %cst_2, %arg12 = %56, %arg13 = %60) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1>) : i32 { + %76 = arith.muli %arg10, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %77 = arith.subi %arg5, %76 {async_task_id = dense<0> : vector<1xi32>} : i32 + %78 = tt.splat %77 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x64xi32, #blocked> + %79 = arith.cmpi slt, %26, %78 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> + %80 = tt.broadcast %79 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi1, #blocked> -> tensor<128x64xi1, #blocked> + %81 = tt.load %arg12, %80, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked> + %82 = ttg.local_alloc %81 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %83 = tt.splat %77 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %84 = arith.cmpi slt, %27, %83 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %85 = tt.broadcast %84 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> + %86 = tt.load %arg13, %85, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1> + %87 = ttg.local_alloc %86 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> + %88 = ttng.warp_group_dot %82, %87, %arg11 {async_task_id = dense<[1, 2]> : vector<2xi32>, inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> -> tensor<128x256xf32, #mma> + %89 = tt.addptr %arg12, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %90 = tt.addptr %arg13, %29 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %88, %89, %90 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %62 = arith.truncf %61#0 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %63 = tt.expand_dims %46 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %64 = arith.muli %30, %63 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi32, #blocked1> + %65 = tt.addptr %31, %64 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %66 = tt.expand_dims %50 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %67 = tt.broadcast %65 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x256x!tt.ptr, #blocked1> + %68 = tt.broadcast %66 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %69 = tt.addptr %67, %68 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %70 = arith.cmpi slt, %63, %32 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi32, #blocked1> + %71 = arith.cmpi slt, %66, %33 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi32, #blocked1> + %72 = tt.broadcast %70 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %73 = tt.broadcast %71 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %74 = arith.andi %72, %73 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xi1, #blocked1> + %75 = ttg.convert_layout %62 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + tt.store %69, %75, %74 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir new file mode 100644 index 000000000..b40703f66 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir @@ -0,0 +1,237 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-lowering=num-consumer-groups=1 | FileCheck %s + +// CHECK: %[[#PBARRIER:]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64 +// CHECK: %[[#CBARRIER:]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64 +// CHECK: %[[#]] = ttg.memdesc_subview %[[#PBARRIER]][%c0_i32] +// CHECK: ttng.init_barrier %[[#]], 128 +// CHECK: %[[#]] = ttg.memdesc_subview %[[#CBARRIER]][%c0_i32] +// CHECK: ttng.init_barrier %[[#]], 1 +// CHECK: scf.for +// CHECK: %[[#]] = ttg.memdesc_subview %[[#CBARRIER]] +// CHECK: ttng.wait_barrier %[[#]] +// CHECK: ttg.async_copy_global_to_local +// CHECK: ttg.async_copy_global_to_local +// CHECK: %[[#]] = ttg.memdesc_subview %[[#PBARRIER]] +// CHECK: ttng.mbarrier_arrive %[[#]] +// CHECK: scf.for +// CHECK: %[[#]] = ttg.memdesc_subview %[[#PBARRIER]] +// CHECK: ttng.wait_barrier %[[#]] +// CHECK: ttg.local_load +// CHECK: ttg.local_load +// CHECK: tt.dot +// CHECK: %[[#]] = ttg.memdesc_subview %[[#CBARRIER]] +// CHECK: ttng.mbarrier_arrive %[[#]] + + + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = ttg.local_alloc : () -> !ttg.memdesc<1x128x256xf16, #shared, #ttg.shared_memory, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<1x256x128xf16, #shared, #ttg.shared_memory, mutable> + %2 = ttng.create_token {loadType = 1 : i32, num = 1 : i32} : tensor<1x!ttng.token> + %3 = ttng.get_async_task_id : i32 + %c0_i32 = arith.constant 0 : i32 + %4 = arith.cmpi eq, %3, %c0_i32 : i32 + scf.if %4 { + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked> + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked1> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_3 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked1> + %6 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.divsi %7, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.muli %10, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.divsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.muli %12, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.subi %8, %13 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = arith.minsi %14, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %16 = arith.remsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %17 = arith.remsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %18 = arith.addi %13, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %19 = arith.divsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %20 = arith.muli %18, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %21 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %22 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %23 = tt.splat %20 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %24 = arith.addi %23, %21 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %25 = tt.splat %arg3 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %26 = arith.remsi %24, %25 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.muli %19, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %28 = tt.splat %27 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %22 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %30 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %31 = arith.remsi %29, %30 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %32 = tt.expand_dims %26 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %33 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %34 = arith.muli %32, %33 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %35 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %36 = tt.expand_dims %35 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %37 = tt.broadcast %34 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %38 = tt.broadcast %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %39 = arith.addi %37, %38 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256xi32, #blocked1> + %40 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked1> + %41 = tt.addptr %40, %39 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %42 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %43 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %44 = tt.expand_dims %42 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %45 = tt.expand_dims %43 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %46 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked> + %47 = arith.muli %44, %46 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked> + %48 = tt.expand_dims %31 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %49 = tt.broadcast %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked> -> tensor<256x128xi32, #blocked> + %50 = tt.broadcast %48 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<256x128xi32, #blocked> + %51 = arith.addi %49, %50 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128xi32, #blocked> + %52 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked> + %53 = tt.addptr %52, %51 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked>, tensor<256x128xi32, #blocked> + %54 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %55 = arith.divsi %54, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %56 = arith.muli %arg7, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %57 = tt.splat %56 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x128xi32, #blocked> + %c1_i32_4 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_5 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %false = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} false + %58:4 = scf.for %arg9 = %c0_i32_1 to %55 step %c1_i32_0 iter_args(%arg10 = %41, %arg11 = %53, %arg12 = %false, %arg13 = %c0_i32_5) -> (tensor<128x256x!tt.ptr, #blocked1>, tensor<256x128x!tt.ptr, #blocked>, i1, i32) : i32 { + %59 = arith.muli %arg9, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %60 = arith.subi %arg5, %59 {async_task_id = dense<0> : vector<1xi32>} : i32 + %61 = tt.splat %60 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x256xi32, #blocked1> + %62 = arith.cmpi slt, %36, %61 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> + %63 = tt.broadcast %62 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + ttng.producer_acquire %2, %arg13, %false {async_task_id = dense<0> : vector<1xi32>} : tensor<1x!ttng.token>, i32, i1 + %c0_i32_6 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %c1_i32_7 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %64 = ttg.memdesc_subview %0[%arg13, %c0_i32_6, %c0_i32_6] {async_task_id = dense<0> : vector<1xi32>} : !ttg.memdesc<1x128x256xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x256xf16, #shared, #ttg.shared_memory, mutable> + %65 = ttg.async_copy_global_to_local %arg10, %64 mask %63 other %cst_2 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1> -> <128x256xf16, #shared, #ttg.shared_memory, mutable> + %66 = tt.splat %60 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked> + %67 = arith.cmpi slt, %45, %66 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked> + %68 = tt.broadcast %67 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi1, #blocked> -> tensor<256x128xi1, #blocked> + %c0_i32_8 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %c1_i32_9 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %69 = ttg.memdesc_subview %1[%arg13, %c0_i32_8, %c0_i32_8] {async_task_id = dense<0> : vector<1xi32>} : !ttg.memdesc<1x256x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable> + %70 = ttg.async_copy_global_to_local %arg11, %69 mask %68 other %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked> -> <256x128xf16, #shared, #ttg.shared_memory, mutable> + ttng.producer_commit %2, %arg13 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x!ttng.token>, i32 + %71 = tt.addptr %arg10, %cst_3 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %72 = tt.addptr %arg11, %57 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked>, tensor<256x128xi32, #blocked> + %c1_i32_10 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %c0_i32_11 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %true = arith.constant {async_task_id = dense<0> : vector<1xi32>} true + %73 = arith.addi %arg13, %c1_i32_10 {async_task_id = dense<0> : vector<1xi32>} : i32 + %74 = arith.cmpi uge, %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %75 = arith.cmpi ult, %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %76 = arith.subi %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %77 = arith.select %74, %76, %73 {async_task_id = dense<0> : vector<1xi32>} : i32 + %78 = arith.xori %arg12, %true {async_task_id = dense<0> : vector<1xi32>} : i1 + %79 = arith.andi %74, %78 {async_task_id = dense<0> : vector<1xi32>} : i1 + %80 = arith.andi %75, %arg12 {async_task_id = dense<0> : vector<1xi32>} : i1 + %81 = arith.ori %79, %80 {async_task_id = dense<0> : vector<1xi32>} : i1 + scf.yield {async_task_id = dense<0> : vector<1xi32>} %71, %72, %81, %77 : tensor<128x256x!tt.ptr, #blocked1>, tensor<256x128x!tt.ptr, #blocked>, i1, i32 + } {async_task_id = dense<0> : vector<1xi32>} + } {async_task_id = dense<0> : vector<1xi32>} + %c1_i32 = arith.constant 1 : i32 + %5 = arith.cmpi eq, %3, %c1_i32 : i32 + scf.if %5 { + %cst = arith.constant {async_task_id = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #blocked2> + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked> + %cst_3 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked1> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_4 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked1> + %6 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.divsi %7, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.muli %10, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.divsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.muli %12, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.subi %8, %13 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = arith.minsi %14, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %16 = arith.remsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %17 = arith.remsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %18 = arith.addi %13, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %19 = arith.divsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %20 = arith.muli %18, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %21 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %22 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %23 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %24 = tt.splat %20 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %25 = tt.splat %20 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %26 = arith.addi %24, %21 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.addi %25, %22 {async_task_id = dense<1> : vector<1xi32>} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %28 = tt.splat %arg3 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %29 = arith.remsi %26, %28 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %30 = arith.muli %19, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %31 = tt.splat %30 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %32 = arith.addi %31, %23 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %33 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %34 = arith.divsi %33, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %c1_i32_5 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_6 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %false = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} false + %35:3 = scf.for %arg9 = %c0_i32_1 to %34 step %c1_i32_0 iter_args(%arg10 = %cst, %arg11 = %false, %arg12 = %c0_i32_6) -> (tensor<128x128xf32, #blocked2>, i1, i32) : i32 { + %c0_i32_7 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32 + %c1_i32_8 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32 + %c0_i32_9 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32 + %c1_i32_10 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32 + ttng.consumer_wait %2, %arg12, %false {async_task_id = dense<1> : vector<1xi32>} : tensor<1x!ttng.token>, i32, i1 + %54 = ttg.memdesc_subview %0[%arg12, %c0_i32_7, %c0_i32_7] {async_task_id = dense<1> : vector<1xi32>} : !ttg.memdesc<1x128x256xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x256xf16, #shared, #ttg.shared_memory, mutable> + %55 = ttg.local_load %54 {async_task_id = dense<1> : vector<1xi32>} : !ttg.memdesc<128x256xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x256xf16, #blocked1> + %56 = ttg.convert_layout %55 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #blocked1> -> tensor<128x256xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> + %57 = ttg.memdesc_subview %1[%arg12, %c0_i32_9, %c0_i32_9] {async_task_id = dense<1> : vector<1xi32>} : !ttg.memdesc<1x256x128xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable> + %58 = ttg.local_load %57 {async_task_id = dense<1> : vector<1xi32>} : !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x128xf16, #blocked> + %59 = ttg.convert_layout %58 {async_task_id = dense<1> : vector<1xi32>} : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>> + %60 = tt.dot %56, %59, %arg10, inputPrecision = tf32 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<256x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2> + ttng.consumer_release %2, %arg12 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x!ttng.token>, i32 + %c1_i32_11 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32 + %c0_i32_12 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32 + %true = arith.constant {async_task_id = dense<1> : vector<1xi32>} true + %61 = arith.addi %arg12, %c1_i32_11 {async_task_id = dense<1> : vector<1xi32>} : i32 + %62 = arith.cmpi uge, %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32 + %63 = arith.cmpi ult, %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32 + %64 = arith.subi %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32 + %65 = arith.select %62, %64, %61 {async_task_id = dense<1> : vector<1xi32>} : i32 + %66 = arith.xori %arg11, %true {async_task_id = dense<1> : vector<1xi32>} : i1 + %67 = arith.andi %62, %66 {async_task_id = dense<1> : vector<1xi32>} : i1 + %68 = arith.andi %63, %arg11 {async_task_id = dense<1> : vector<1xi32>} : i1 + %69 = arith.ori %67, %68 {async_task_id = dense<1> : vector<1xi32>} : i1 + scf.yield {async_task_id = dense<1> : vector<1xi32>} %60, %69, %65 : tensor<128x128xf32, #blocked2>, i1, i32 + } {async_task_id = dense<1> : vector<1xi32>} + %36 = arith.truncf %35#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf32, #blocked2> to tensor<128x128xf16, #blocked2> + %37 = tt.expand_dims %27 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %38 = tt.splat %arg8 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked> + %39 = arith.muli %38, %37 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked> + %40 = tt.splat %arg2 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %41 = tt.addptr %40, %39 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %42 = tt.expand_dims %32 {async_task_id = dense<1> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %43 = tt.broadcast %41 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> + %44 = tt.broadcast %42 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> + %45 = tt.addptr %43, %44 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> + %46 = tt.splat %arg3 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked> + %47 = arith.cmpi slt, %37, %46 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked> + %48 = tt.splat %arg4 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<1x128xi32, #blocked> + %49 = arith.cmpi slt, %42, %48 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked> + %50 = tt.broadcast %47 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> + %51 = tt.broadcast %49 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<128x128xi1, #blocked> + %52 = arith.andi %50, %51 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked> + %53 = ttg.convert_layout %36 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf16, #blocked2> -> tensor<128x128xf16, #blocked> + tt.store %45, %53, %52 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked> + } {async_task_id = dense<1> : vector<1xi32>} + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_task_partition.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_task_partition.mlir new file mode 100644 index 000000000..d8d7b09ad --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/WarpSpecialization/ws_task_partition.mlir @@ -0,0 +1,64 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-task-partition=num-consumer-groups=2 | FileCheck %s + +// CHECK-LABEL: @matmul_persistent_tma_ws_cooperative_kernel +// CHECK: %[[#GA:]] = tt.experimental_descriptor_load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: %[[#LA:]] = ttg.local_alloc %[[#GA]] +// CHECK: %[[#GB:]] = tt.experimental_descriptor_load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: %[[#LB:]] = ttg.local_alloc %[[#GB]] +// CHECK: %[[#C:]] = ttng.warp_group_dot %[[#LA]], %[[#LB]], {{.*}} {async_task_id = dense<[1, 2]> : vector<2xi32> + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_persistent_tma_ws_cooperative_kernel(%arg0: !tt.tensordesc>, %arg1: !tt.tensordesc>, %arg2: !tt.tensordesc>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c127_i32 = arith.constant 127 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = arith.addi %arg3, %c127_i32 : i32 + %1 = arith.divsi %0, %c128_i32 : i32 + %2 = arith.addi %arg4, %c255_i32 : i32 + %3 = arith.divsi %2, %c256_i32 : i32 + %4 = arith.muli %1, %3 : i32 + %5 = tt.get_program_id x : i32 + %6 = tt.get_num_programs x : i32 + %7 = arith.muli %3, %c8_i32 : i32 + %8 = arith.addi %arg5, %c63_i32 : i32 + %9 = arith.divsi %8, %c64_i32 : i32 + scf.for %arg6 = %5 to %4 step %6 : i32 { + %10 = arith.divsi %arg6, %7 : i32 + %11 = arith.muli %10, %c8_i32 : i32 + %12 = arith.subi %1, %11 : i32 + %13 = arith.minsi %12, %c8_i32 : i32 + %14 = arith.remsi %arg6, %7 : i32 + %15 = arith.remsi %14, %13 : i32 + %16 = arith.addi %11, %15 : i32 + %17 = arith.divsi %14, %13 : i32 + %18 = arith.muli %16, %c128_i32 : i32 + %19 = arith.muli %17, %c256_i32 : i32 + %true = arith.constant true + %false = arith.constant false + %20:2 = scf.for %arg7 = %c0_i32 to %9 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { + %23 = tt.experimental_descriptor_load %arg0[%18, %arg9] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + %24 = ttg.local_alloc %23 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %25 = tt.experimental_descriptor_load %arg1[%arg9, %19] : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> + %26 = ttg.local_alloc %25 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> + %27 = ttng.warp_group_dot %24, %26, %arg8 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> -> tensor<128x256xf32, #mma> + %28 = arith.addi %arg9, %c64_i32 : i32 + scf.yield %27, %28 : tensor<128x256xf32, #mma>, i32 + } + %21 = arith.truncf %20#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %22 = ttg.convert_layout %21 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + tt.experimental_descriptor_store %arg2[%18, %19], %22 : !tt.tensordesc>, tensor<128x256xf16, #blocked1> + } + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/canonicalize.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/canonicalize.mlir new file mode 100644 index 000000000..b20583546 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/canonicalize.mlir @@ -0,0 +1,13 @@ +// RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s + +// CHECK-LABEL: @test_dce_tmem_alloc +// CHECK-NOT: ttng.tmem_alloc +// CHECK: tt.return +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> +#tmem_scales = #ttng.tensor_memory_scales_encoding<> +module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} { +tt.func @test_dce_tmem_alloc(%arg: tensor<128x4xi8, #linear>) { + %a = ttng.tmem_alloc %arg : (tensor<128x4xi8, #linear>) -> !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory> + tt.return +} +} // end module diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/invalid.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/invalid.mlir new file mode 100644 index 000000000..8e808b488 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/invalid.mlir @@ -0,0 +1,89 @@ +// RUN: triton-opt --split-input-file %s --verify-diagnostics + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @alloc_tensor_memory(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + // expected-error @+1 {{op should use tensor memory encoding.}} + %0 = ttng.tmem_alloc %cst : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #shared, #ttng.tensor_memory, mutable> + tt.return + } +} + +// ----- + +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @alloc_tensor_memory() { + // expected-error @+1 {{uninitialized alloc must have a mutable memdesc type}} + %0 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @alloc_tensor_memory() { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %true = arith.constant true + %0 = ttng.tmem_alloc %cst : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> + // expected-error @+1 {{Cannot store into an immutable alloc}} + ttng.tmem_store %cst, %0, %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> + tt.return + } +} + +// ----- + +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#tmem = #ttng.tensor_memory_scales_encoding<> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @alloc_tensor_memory(%arg: !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>) { + %cst = arith.constant dense<0> : tensor<128x4xi8, #blocked> + %0 = ttng.tmem_alloc %cst : (tensor<128x4xi8, #blocked>) -> !ttg.memdesc<128x4xi8, #tmem, #ttng.tensor_memory> + // expected-error @+1 {{Cannot copy into an immutable alloc}} + ttng.tmem_copy %arg, %0, : (!ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem, #ttng.tensor_memory>) -> () + tt.return + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +tt.func @async_tma_gather(%desc: !tt.ptr, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, + %bar: !ttg.memdesc<2xi32, #shared1, #ttg.shared_memory, mutable>, + %result: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, + %pred: i1) { + // expected-error @below {{barrier allocation must be a descriptor of 1xi64 type}} + ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.ptr, tensor<32xi32, #blocked>, i32, !ttg.memdesc<2xi32, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, i1 + tt.return +} +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { +tt.func @async_tma_gather(%desc: !tt.ptr, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, + %bar: !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>, + %result: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory>, + %pred: i1) { + // expected-error @below {{cannot store into immutable memory}} + ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.ptr, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory>, i1 + tt.return +} +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/membar.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/membar.mlir new file mode 100644 index 000000000..fb9206f3d --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/membar.mlir @@ -0,0 +1,122 @@ +// RUN: triton-opt %s -split-input-file --triton-nvidia-tma-lowering --allocate-shared-memory -test-print-membar | FileCheck %s + +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: init_barrier + // CHECK: local_alloc + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: init_barrier + tt.func @init_barrier() { + %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + tt.return + } +} + +// ----- + +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: inval_barrier + // CHECK: local_alloc + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: init_barrier + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: inval_barrier + tt.func @inval_barrier() { + %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.inval_barrier %alloc : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + tt.return + } +} + +// ----- + +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: barrier_expect + // CHECK: local_alloc + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: init_barrier + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: barrier_expect + tt.func @barrier_expect(%pred : i1) { + %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.barrier_expect %alloc, 16384, %pred : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + tt.return + } +} + +// ----- + +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: wait_barrier + // CHECK: local_alloc + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: init_barrier + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: wait_barrier + tt.func @wait_barrier(%phase : i32) { + %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.wait_barrier %alloc, %phase : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + tt.return + } +} + +// ----- + + +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tma_load(%arg0: !tt.tensordesc>, %arg1: i32) -> tensor<128x64xf16, #blocked0> { + // CHECK-LABEL: tma_load + // CHECK: local_dealloc + // CHECK-NEXT: local_alloc + // CHECK-NEXT: local_alloc + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: init_barrier + %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> + %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> + ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> + %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc> -> tensor<128x64xf16, #blocked0> + tt.return %l : tensor<128x64xf16, #blocked0> + } +} + +// ----- + +#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: tma_store +// CHECK: ttg.local_alloc +// CHECK-NEXT: ttg.local_dealloc +// CHECK-NEXT: gpu.barrier +// CHECK-NEXT: ttg.local_alloc + tt.func public @tma_store(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) { + %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> + %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> + ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> + tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc>, tensor<128x256xf32, #blocked0> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/mma_lowering.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/mma_lowering.mlir new file mode 100644 index 000000000..728cc1ed5 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/mma_lowering.mlir @@ -0,0 +1,59 @@ +// RUN: triton-opt %s -split-input-file --triton-nvidia-mma-lowering | FileCheck %s + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#tmem = #ttng.tensor_memory_encoding +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: gen5_mma_scaled_shmem_to_tmem + tt.func public @gen5_mma_scaled_shmem_to_tmem( + %A_sh: !ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory>, + %B_sh: !ttg.memdesc<256x64xf8E5M2, #shared, #ttg.shared_memory>, + %C_tmem: !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>, + %A_scale_sh: !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, + %B_scale_sh: !ttg.memdesc<1x2x16x4x4xi8, #shared1, #smem>, + %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) { + + %true = arith.constant true + // Verify that the scale in tmem has the shape of (LHS) BlockM x BlockK / 32, (RHS) BlockN x BlockK / 32 + // CHECK: %[[A_SC_TMEM:.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_copy {{.*}}, %[[A_SC_TMEM]] + // CHECK: %[[B_SC_TMEM:.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<64x8xi8, #tmem_scales, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_copy {{.*}}, %[[B_SC_TMEM]] + // CHECK: ttng.tc_gen5_mma_scaled {{.*}}, %[[A_SC_TMEM]], %[[B_SC_TMEM]] + ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %C_tmem, %A_scale_sh, %B_scale_sh, %true, %true lhs = e5m2 rhs = e5m2, %barrier : (!ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<256x64xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<1x2x16x4x4xi8, #shared1, #smem>, i1, i1, !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) -> () + tt.return + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#tmem = #ttng.tensor_memory_encoding +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: gen5_mma_scaled_shmem_to_tmem + tt.func public @gen5_mma_scaled_shmem_to_tmem( + %A_sh: !ttg.memdesc<128x256xi8, #shared, #ttg.shared_memory>, + %B_sh: !ttg.memdesc<256x64xi8, #shared, #ttg.shared_memory>, + %C_tmem: !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>, + %A_scale_sh: !ttg.memdesc<1x2x32x4x4xf8E4M3FN, #shared1, #smem>, + %B_scale_sh: !ttg.memdesc<1x2x16x4x4xf8E4M3FN, #shared1, #smem>, + %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) { + + %true = arith.constant true + // Verify that the scale in tmem has the shape of (LHS) BlockM x BlockK / 32, (RHS) BlockN x BlockK / 32 + // CHECK: %[[A_SC_TMEM:.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xf8E4M3FN, #tmem_scales, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_copy {{.*}}, %[[A_SC_TMEM]] + // CHECK: %[[B_SC_TMEM:.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<64x8xf8E4M3FN, #tmem_scales, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_copy {{.*}}, %[[B_SC_TMEM]] + // CHECK: ttng.tc_gen5_mma_scaled {{.*}}, %[[A_SC_TMEM]], %[[B_SC_TMEM]] + ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %C_tmem, %A_scale_sh, %B_scale_sh, %true, %true lhs = e2m1 rhs = e2m1, %barrier : (!ttg.memdesc<128x256xi8, #shared, #ttg.shared_memory>, !ttg.memdesc<256x64xi8, #shared, #ttg.shared_memory>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1x2x32x4x4xf8E4M3FN, #shared1, #smem>, !ttg.memdesc<1x2x16x4x4xf8E4M3FN, #shared1, #smem>, i1, i1, !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) -> () + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/ops.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/ops.mlir new file mode 100644 index 000000000..fc24f4676 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/ops.mlir @@ -0,0 +1,82 @@ +// RUN: triton-opt --split-input-file %s | FileCheck %s + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + + // CHECK-LABEL: @tcgen5 + // CHECK: ttng.tc_gen5_mma + // CHECK: ttng.tc_gen5_mma + tt.func @tcgen5(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, + %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>, + %c: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>, + %accUse: i1, + %pred: i1, + %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) { + ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier: + (!ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, + !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>, + !ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>, + i1, i1, + !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) -> () + + ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred: + (!ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, + !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>, + !ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>, + i1, i1) -> () + tt.return + } + + // CHECK-LABEL: @async_tma_gather + // CHECK-SAME: [[DESC:%arg[0-9]+]]: + // CHECK-SAME: [[X_OFFSETS:%arg[0-9]+]]: + // CHECK-SAME: [[Y_OFFSET:%arg[0-9]+]]: + // CHECK-SAME: [[BAR:%arg[0-9]+]]: + // CHECK-SAME: [[RESULT:%arg[0-9]+]]: + // CHECK-SAME: [[PRED:%arg[0-9]+]]: + tt.func @async_tma_gather(%desc: !tt.ptr, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, + %bar: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, + %result: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, + %pred: i1) { + // CHECK-NEXT: ttng.async_tma_gather [[DESC]][[[X_OFFSETS]], [[Y_OFFSET]]] [[RESULT]], [[BAR]], [[PRED]] : !tt.ptr, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared2, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared, #smem, mutable>, i1 + ttng.async_tma_gather %desc[%x_offsets, %y_offset] %result, %bar, %pred : !tt.ptr, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>, i1 + tt.return + } + + // CHECK-LABEL: @async_tma_scatter + // CHECK-SAME: [[DESC:%arg[0-9]+]]: + // CHECK-SAME: [[X_OFFSETS:%arg[0-9]+]]: + // CHECK-SAME: [[Y_OFFSET:%arg[0-9]+]]: + // CHECK-SAME: [[SRC:%arg[0-9]+]]: + tt.func @async_tma_scatter(%desc: !tt.ptr, %x_offsets: tensor<32xi32, #blocked>, %y_offset: i32, + %src: !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable>) { + // CHECK-NEXT: ttng.async_tma_scatter [[DESC]][[[X_OFFSETS]], [[Y_OFFSET]]] [[SRC]] : !tt.ptr, tensor<32xi32, #blocked>, i32, !ttg.memdesc<32x128xbf16, #shared, #smem, mutable> + ttng.async_tma_scatter %desc[%x_offsets, %y_offset] %src : !tt.ptr, tensor<32xi32, #blocked>, i32, !ttg.memdesc<32x128xbf16, #shared, #ttg.shared_memory, mutable> + tt.return + } + + // CHECK-LABEL: @wait_barrier + // CHECK-SAME: [[ALLOC:%arg[0-9]+]]: + // CHECK-SAME: [[PHASE:%arg[0-9]+]]: + tt.func @wait_barrier(%alloc: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, %phase: i32) { + // CHECK-NEXT: ttng.wait_barrier [[ALLOC]], [[PHASE]] : !ttg.memdesc<1xi64, #shared2, #smem, mutable> + ttng.wait_barrier %alloc, %phase : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable> + tt.return + } + + // CHECK-LABEL: @wait_barrier + // CHECK-SAME: [[ALLOC:%arg[0-9]+]]: + // CHECK-SAME: [[PHASE:%arg[0-9]+]]: + // CHECK-SAME: [[DEP1:%arg[0-9]+]]: + // CHECK-SAME: [[DEP2:%arg[0-9]+]]: + tt.func @wait_barrier_deps(%alloc: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, %phase: i32, %dep1: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, %dep2: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory, mutable>) { + // CHECK-NEXT: ttng.wait_barrier [[ALLOC]], [[PHASE]] deps [[DEP1]], [[DEP2]] : !ttg.memdesc<1xi64, #shared2, #smem, mutable>, !ttg.memdesc<1xi64, #shared2, #smem, mutable>, !ttg.memdesc<128x128xf8E5M2, #shared, #smem, mutable> + ttng.wait_barrier %alloc, %phase deps %dep1, %dep2 : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>, !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory, mutable> + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/test_promotion_to_tensor_memory.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/test_promotion_to_tensor_memory.mlir new file mode 100644 index 000000000..a54248341 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/test_promotion_to_tensor_memory.mlir @@ -0,0 +1,71 @@ +// RUN: env ENABLE_LHS_TO_TMEM=1 triton-opt %s -split-input-file -tritongpu-promote-lhs-to-tmem | FileCheck %s + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 32}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> +// Incompatible access layout for tmem; tmem access requires one thread per datapath +#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: @no_tmem_promotion + tt.func public @no_tmem_promotion( + %lhs: tensor<128x32xf32, #blocked1>, + %rhs: tensor<32x256xf32, #blocked2> + ) { + %true = arith.constant true + %cst = arith.constant dense<0.0> : tensor<128x256xf32, #blocked> + // CHECK: ttng.tmem_alloc %[[CST:.*]] : (tensor<128x256xf32, #[[BLOCKED:blocked[0-9]*]]>) -> !ttg.memdesc<128x256xf32, #tmem + %tmem = ttng.tmem_alloc %cst : + (tensor<128x256xf32, #blocked>) -> + !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> + // CHECK-NOT: ttng.tmem_alloc %[[ARG0:.*]] : (tensor<128x32xf32, #[[BLOCKED:blocked[0-9]*]]>) -> !ttg.memdesc<128x32xf32, #[[TMEM:tmem[0-9]*]] + %lhs_shared = ttg.local_alloc %lhs : (tensor<128x32xf32, #blocked1>) -> !ttg.memdesc<128x32xf32, #shared, #ttg.shared_memory> + %rhs_shared = ttg.local_alloc %rhs : (tensor<32x256xf32, #blocked2>) -> !ttg.memdesc<32x256xf32, #shared1, #ttg.shared_memory> + + ttng.tc_gen5_mma %lhs_shared, %rhs_shared, %tmem, %true, %true : + (!ttg.memdesc<128x32xf32, #shared, #ttg.shared_memory>, + !ttg.memdesc<32x256xf32, #shared1, #ttg.shared_memory>, + !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>, + i1, i1) -> () + + tt.return + } +} + +// ----- + +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 32}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 32}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +// Compatible layout for tmem access +#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}> +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: @promote_lhs_to_tmem + tt.func public @promote_lhs_to_tmem( + %lhs: tensor<128x32xf32, #blocked3>, + %rhs: tensor<32x256xf32, #blocked2> + ) { + %true = arith.constant true + %cst = arith.constant dense<0.0> : tensor<128x256xf32, #blocked> + // CHECK: ttng.tmem_alloc %[[CST:.*]] : (tensor<128x256xf32, #[[BLOCKED:blocked[0-9]*]]>) -> !ttg.memdesc<128x256xf32, #tmem + %tmem = ttng.tmem_alloc %cst : + (tensor<128x256xf32, #blocked>) -> + !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_alloc %[[ARG0:.*]] : (tensor<128x32xf32, #[[BLOCKED:blocked[0-9]*]]>) -> !ttg.memdesc<128x32xf32, #[[TMEM:tmem[0-9]*]] + %lhs_shared = ttg.local_alloc %lhs : (tensor<128x32xf32, #blocked3>) -> !ttg.memdesc<128x32xf32, #shared, #ttg.shared_memory> + %rhs_shared = ttg.local_alloc %rhs : (tensor<32x256xf32, #blocked2>) -> !ttg.memdesc<32x256xf32, #shared1, #ttg.shared_memory> + + ttng.tc_gen5_mma %lhs_shared, %rhs_shared, %tmem, %true, %true : + (!ttg.memdesc<128x32xf32, #shared, #ttg.shared_memory>, + !ttg.memdesc<32x256xf32, #shared1, #ttg.shared_memory>, + !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>, + i1, i1) -> () + + tt.return + } +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir new file mode 100644 index 000000000..6384bffa1 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir @@ -0,0 +1,176 @@ +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -triton-tensor-memory-allocation | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#tmem = #ttng.tensor_memory_encoding +#tmem1 = #ttng.tensor_memory_encoding +#tmem2 = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: ttg.tensor_memory_size = 512 + // CHECK: alloc_tensor_memory + tt.func public @alloc_tensor_memory(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { + %true = arith.constant true + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #blocked> + %cst1 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #blocked> + %cst2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked> + + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} + %0 = ttng.tmem_alloc %cst : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32} + %1 = ttng.tmem_alloc %cst0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 256 : i32, tensor_memory_row_offset = 0 : i32} + %2 = ttng.tmem_alloc %cst1 : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #tmem1, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 320 : i32, tensor_memory_row_offset = 0 : i32} + %3 = ttng.tmem_alloc %cst : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + + ttng.tmem_store %cst, %0, %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst0, %1, %true : tensor<128x128xf16, #blocked> -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst1, %2, %true : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #tmem1, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst, %3, %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} + %4 = ttng.tmem_alloc %cst2 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #tmem2, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 16 : i32} + %5 = ttng.tmem_alloc %cst2 : (tensor<64x256xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #tmem2, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32} + %6 = ttng.tmem_alloc %cst : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + + ttng.tmem_store %cst2, %4, %true : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x128xf16, #tmem2, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst2, %5, %true : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x128xf16, #tmem2, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst, %6, %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#tmem = #ttng.tensor_memory_encoding +#tmem1 = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: ttg.tensor_memory_size = 512 + // CHECK: alloc_tensor_memory + tt.func public @alloc_tensor_memory_re_use(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { + %true = arith.constant true + %c1 = arith.constant 1 : i32 + %c0 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> + %cst1 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #blocked> + %cst2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked> + + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} + %a = ttng.tmem_alloc %cst0 : (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> + + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} + %0 = ttng.tmem_alloc %cst : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} + %1 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 64 : i32, tensor_memory_row_offset = 0 : i32} + %2 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst2, %1, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst2, %2, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + + // Test that the 2 allocations above are re-used. + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} + %3 = ttng.tmem_alloc %cst0 : (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> + + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} + %4 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 64 : i32, tensor_memory_row_offset = 0 : i32} + %5 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst2, %4, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32} + %6 = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %s = ttg.memdesc_subview %6[%c1, %c0, %c0] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} + %7 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 384 : i32, tensor_memory_row_offset = 0 : i32} + %8 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + + ttng.tmem_store %cst, %s, %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst2, %7, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst2, %5, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#tmem = #ttng.tensor_memory_encoding +#tmem1 = #ttng.tensor_memory_encoding +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32, ttg.shared = 65536 : i32} { + // CHECK-LABEL: multi_ctas + tt.func public @multi_ctas() { + %true = arith.constant true + %cst0 = arith.constant dense<0.000000e+00> : tensor<256x128xf16, #blocked> + + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} + %0 = ttng.tmem_alloc : () -> !ttg.memdesc<256x128xf16, #tmem, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32} + %1 = ttng.tmem_alloc : () -> !ttg.memdesc<256x128xf16, #tmem1, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 256 : i32, tensor_memory_row_offset = 0 : i32} + %2 = ttng.tmem_alloc : () -> !ttg.memdesc<256x128xf16, #tmem, #ttng.tensor_memory, mutable> + + ttng.tmem_store %cst0, %0, %true : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #tmem, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst0, %1, %true : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #tmem1, #ttng.tensor_memory, mutable> + ttng.tmem_store %cst0, %2, %true : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #tmem, #ttng.tensor_memory, mutable> + tt.return + } +} + +// ----- + +#layout = #ttng.tensor_memory_encoding +#tmem = #ttng.tensor_memory + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100"} { + +// CHECK-LABEL: @alloc_warp_specialize +tt.func @alloc_warp_specialize() { + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} + %0 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #layout, #tmem, mutable> + ttg.warp_specialize() + default { + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32} + %1 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #layout, #tmem, mutable> + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32} + %2 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #layout, #tmem, mutable> + ttg.warp_yield + } + partition0() num_warps(1) { + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 256 : i32, tensor_memory_row_offset = 0 : i32} + %1 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #layout, #tmem, mutable> + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 384 : i32, tensor_memory_row_offset = 0 : i32} + %2 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #layout, #tmem, mutable> + "use"(%1) : (!ttg.memdesc<128x128xf32, #layout, #tmem, mutable>) -> () + ttg.warp_return + } : () -> () + "use"(%0) : (!ttg.memdesc<128x128xf32, #layout, #tmem, mutable>) -> () + tt.return +} + +// CHECK-LABEL: @alloc_warp_specialize_explicit_capture +tt.func @alloc_warp_specialize_explicit_capture() { + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} + %0 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #layout, #tmem, mutable> + ttg.warp_specialize(%0) + default { + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32} + %1 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #layout, #tmem, mutable> + ttg.warp_yield + } + partition0(%arg0: !ttg.memdesc<128x128xf32, #layout, #tmem, mutable>) num_warps(1) { + // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 256 : i32, tensor_memory_row_offset = 0 : i32} + %1 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #layout, #tmem, mutable> + ttg.warp_return + } : (!ttg.memdesc<128x128xf32, #layout, #tmem, mutable>) -> () + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/TritonNvidiaGPU/tma_lowering.mlir b/third_party/enflame/include/triton/test/TritonNvidiaGPU/tma_lowering.mlir new file mode 100644 index 000000000..9ba91af19 --- /dev/null +++ b/third_party/enflame/include/triton/test/TritonNvidiaGPU/tma_lowering.mlir @@ -0,0 +1,88 @@ +// RUN: triton-opt %s -split-input-file --triton-nvidia-tma-lowering | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: tma_load +// CHECK: ttg.local_alloc : () +// CHECK: ttg.local_alloc : () +// CHECK: ttng.init_barrier +// CHECK: ttng.tensor_desc_to_tma_ptr +// CHECK: ttng.async_tma_copy_global_to_local +// CHECK: ttng.wait_barrier +// CHECK: ttng.inval_barrier +// CHECK: ttg.local_load + tt.func public @tma_load(%arg0: !tt.tensordesc>, %arg1: i32) -> tensor<128x64xf16, #blocked> { + %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> + tt.return %l : tensor<128x64xf16, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: tma_store +// CHECK: ttg.local_alloc +// CHECK: ttng.fence_async_shared {bCluster = false} +// CHECK: ttng.tensor_desc_to_tma_ptr +// CHECK: ttng.async_tma_copy_local_to_global + tt.func public @tma_store(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked>) { + tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc>, tensor<128x256xf32, #blocked> + tt.return + } +} + +// ----- + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: make_tensor_descriptor + // CHECK: %0 = arith.extsi %arg2 : i32 to i64 + // CHECK: %1 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr + // CHECK: tt.experimental_tensormap_create %1, %arg0, [%c32_i32, %c8_i32], [%arg2, %arg1], [%0], [%c1_i32, %c1_i32] {elem_type = 0 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 1 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () + // CHECK: tt.experimental_tensormap_fenceproxy_acquire %1 : !tt.ptr + // CHECK: tt.reinterpret_tensor_descriptor %1 : !tt.ptr to !tt.tensordesc> + tt.func public @make_tensor_descriptor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32} ) -> !tt.tensordesc> { + %c1_i64 = arith.constant 1 : i64 + %cst = arith.constant dense<32> : tensor<8x1xi32> + %c64_i32 = arith.constant 64 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = arith.extsi %arg2 : i32 to i64 + %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : !tt.ptr, !tt.tensordesc> + tt.return %1 : !tt.tensordesc> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { + +// CHECK-LABEL: @tma_gather +tt.func @tma_gather(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32, #blocked>, %arg2: i32) -> tensor<32x128xbf16, #blocked1> { + // CHECK: [[RESULT:%.*]] = ttg.local_alloc + // CHECK: [[BARRIER:%.*]] = ttg.local_alloc + // CHECK: ttng.init_barrier [[BARRIER]] + // CHECK: [[DESC_PTR:%.*]] = ttng.tensor_desc_to_tma_ptr %arg0 + // CHECK: ttng.async_tma_gather [[DESC_PTR]][%arg1, %arg2] [[RESULT]], [[BARRIER]], %true + // CHECK: ttng.wait_barrier [[BARRIER]] + // CHECK: ttng.inval_barrier [[BARRIER]] + // CHECK: [[OUT:%.*]] = ttg.local_load [[RESULT]] + %0 = tt.experimental_descriptor_gather %arg0[%arg1, %arg2] : (!tt.tensordesc>, tensor<32xi32, #blocked>, i32) -> tensor<32x128xbf16, #blocked1> + // CHECK: return [[OUT]] + tt.return %0 : tensor<32x128xbf16, #blocked1> +} + +// CHECK-LABEL: @tma_scatter +tt.func @tma_scatter(%arg0: !tt.tensordesc>, %arg1: tensor<32xi32, #blocked>, %arg2: i32, %arg3: tensor<32x128xbf16, #blocked1>) { + // CHECK-NEXT: [[SRC:%.*]] = ttg.local_alloc %arg3 + // CHECK-NEXT: ttng.fence_async_shared {bCluster = false} + // CHECK-NEXT: [[PTR:%.*]] = ttng.tensor_desc_to_tma_ptr %arg0 + // CHECK-NEXT: ttng.async_tma_scatter [[PTR]][%arg1, %arg2] [[SRC]] + // CHECK-NEXT: ttng.async_tma_store_wait + tt.experimental_descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc>, tensor<32xi32, #blocked>, i32, tensor<32x128xbf16, #blocked1> + tt.return +} + +} diff --git a/third_party/enflame/include/triton/test/lib/Analysis/CMakeLists.txt b/third_party/enflame/include/triton/test/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..1bf9d8470 --- /dev/null +++ b/third_party/enflame/include/triton/test/lib/Analysis/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_library(TritonTestAnalysis + TestAlias.cpp + TestAxisInfo.cpp + TestAllocation.cpp + TestMembar.cpp + + LINK_LIBS PUBLIC + MLIRPass + ${triton_libs} +) diff --git a/third_party/enflame/include/triton/test/lib/Analysis/TestAlias.cpp b/third_party/enflame/include/triton/test/lib/Analysis/TestAlias.cpp new file mode 100644 index 000000000..038467aac --- /dev/null +++ b/third_party/enflame/include/triton/test/lib/Analysis/TestAlias.cpp @@ -0,0 +1,108 @@ +#include "mlir/IR/AsmState.h" +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Alias.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; + +namespace { + +struct TestAliasPass + : public PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass); + + static std::string getValueOperandName(Value value, AsmState &state) { + std::string opName; + llvm::raw_string_ostream ss(opName); + value.printAsOperand(ss, state); + return opName; + } + + static void emit(Location loc, StringRef name, + SmallVector &vals) { + if (vals.empty()) + return; + InFlightDiagnostic diag = mlir::emitRemark(loc); + diag << name << " -> "; + size_t i = 0; + for (auto val : vals) { + if (i != 0) + diag << ","; + diag << val; + ++i; + } + } + + StringRef getArgument() const final { return "test-print-alias"; } + StringRef getDescription() const final { + return "print the result of the alias analysis pass"; + } + + void runOnOperation() override { + Operation *operation = getOperation(); + + std::unique_ptr solver = createDataFlowSolver(); + SharedMemoryAliasAnalysis *analysis = + solver->load(); + if (failed(solver->initializeAndRun(operation))) + return signalPassFailure(); + + AsmState state(operation->getParentOfType()); + // Get operation ids of value's aliases + auto getLocalAllocOpNames = [&](Value value) { + dataflow::Lattice *latticeElement = + analysis->getLatticeElement(value); + SmallVector opNames; + if (latticeElement) { + auto &info = latticeElement->getValue(); + for (auto &alias : info.getAllocs()) { + auto opName = + getValueOperandName(alias.getDefiningOp()->getResult(0), state); + opNames.push_back(std::move(opName)); + } + } + // Ensure deterministic output + std::sort(opNames.begin(), opNames.end()); + return opNames; + }; + + operation->walk([&](Operation *op) { + if (op->getNumResults() < 1) { + // cond br, br + if (auto branch = dyn_cast(op)) { + auto *block = branch->getBlock(); + for (auto arg : llvm::enumerate(block->getArguments())) { + auto operand = block->getArgument(arg.index()); + auto opNames = getLocalAllocOpNames(operand); + auto argName = getValueOperandName(arg.value(), state); + emit(op->getLoc(), argName, opNames); + } + } + return; + } + if (auto forOp = dyn_cast(op)) { + for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) { + auto operand = forOp.getTiedLoopInit(arg.value())->get(); + auto opNames = getLocalAllocOpNames(operand); + auto argName = getValueOperandName(arg.value(), state); + emit(op->getLoc(), argName, opNames); + } + } + for (auto result : llvm::enumerate(op->getResults())) { + auto opNames = getLocalAllocOpNames(result.value()); + auto resultName = getValueOperandName(result.value(), state); + emit(op->getLoc(), resultName, opNames); + } + }); + } +}; + +} // namespace + +namespace mlir { +namespace test { +void registerTestAliasPass() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/third_party/enflame/include/triton/test/lib/Analysis/TestAllocation.cpp b/third_party/enflame/include/triton/test/lib/Analysis/TestAllocation.cpp new file mode 100644 index 000000000..9c22d4985 --- /dev/null +++ b/third_party/enflame/include/triton/test/lib/Analysis/TestAllocation.cpp @@ -0,0 +1,90 @@ +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Allocation.h" + +using namespace mlir; + +namespace { + +unsigned getScratchSize128(Operation *) { return 128; } + +enum class GetScratchSizeFunction { + None, + ValidConstant, +}; + +struct TestAllocationPass + : public PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass); + + TestAllocationPass() = default; + TestAllocationPass(const TestAllocationPass &other) + : PassWrapper>(other) {} + + StringRef getArgument() const final { return "test-print-allocation"; } + StringRef getDescription() const final { + return "print the result of the allocation pass"; + } + + ModuleAllocation getModuleAllocation() { + switch (getScratchSizeFunction) { + case GetScratchSizeFunction::None: + return {getOperation()}; + case GetScratchSizeFunction::ValidConstant: + return {getOperation(), getScratchSize128}; + } + llvm_unreachable("Unhandled case"); + } + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + // Convert to std::string can remove quotes from opName + ModuleAllocation moduleAllocation = getModuleAllocation(); + moduleOp.walk([&](triton::FuncOp funcOp) { + auto opName = SymbolTable::getSymbolName(funcOp).getValue().str(); + mlir::emitRemark(funcOp.getLoc(), opName); + auto *allocation = moduleAllocation.getFuncData(funcOp); + funcOp.walk([&](Operation *op) { + auto scratchBufferId = allocation->getBufferId(op); + if (scratchBufferId != Allocation::InvalidBufferId) { + size_t offset = allocation->getOffset(scratchBufferId); + size_t size = allocation->getAllocatedSize(scratchBufferId); + mlir::emitRemark(op->getLoc()) + << (allocation->isVirtualBuffer(scratchBufferId) ? "virtual" + : "scratch") + << " offset = " << offset << ", size = " << size; + } + if (op->getNumResults() < 1) + return; + for (Value result : op->getResults()) { + auto bufferId = allocation->getBufferId(result); + if (bufferId != Allocation::InvalidBufferId) { + size_t offset = allocation->getOffset(bufferId); + size_t size = allocation->getAllocatedSize(bufferId); + mlir::emitRemark(op->getLoc()) + << "offset = " << offset << ", size = " << size; + } + } + }); + mlir::emitRemark(funcOp.getLoc()) + << "size = " << allocation->getSharedMemorySize(); + }); + } + + Option getScratchSizeFunction{ + *this, "get-scratch-size-function", + llvm::cl::desc("Custom scratch size function to use"), + llvm::cl::init(GetScratchSizeFunction::None), + llvm::cl::values( + clEnumValN(GetScratchSizeFunction::None, "None", "None (default)"), + clEnumValN(GetScratchSizeFunction::ValidConstant, "ValidConstant", + "ValidConstant"))}; +}; + +} // namespace + +namespace mlir { +namespace test { +void registerTestAllocationPass() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/third_party/enflame/include/triton/test/lib/Analysis/TestAxisInfo.cpp b/third_party/enflame/include/triton/test/lib/Analysis/TestAxisInfo.cpp new file mode 100644 index 000000000..54663c36b --- /dev/null +++ b/third_party/enflame/include/triton/test/lib/Analysis/TestAxisInfo.cpp @@ -0,0 +1,52 @@ +#include "mlir/IR/Diagnostics.h" +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +struct TestAxisInfoPass + : public PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass); + + StringRef getArgument() const final { return "test-print-alignment"; } + StringRef getDescription() const final { + return "print the result of the alignment analysis pass"; + } + + void runOnOperation() override { + Operation *operation = getOperation(); + ModuleOp moduleOp = cast(operation); + ModuleAxisInfoAnalysis moduleAxisInfoAnalysis(moduleOp); + moduleOp.walk([&](FuncOp funcOp) { + funcOp.walk([&](Operation *op) { + if (op->getNumResults() < 1) + return; + for (Value result : op->getResults()) { + InFlightDiagnostic diag = mlir::emitRemark(op->getLoc()); + diag << result; + diag << " => "; + auto *axisInfo = moduleAxisInfoAnalysis.getAxisInfo(result); + if (axisInfo) { + std::string str; + llvm::raw_string_ostream os(str); + axisInfo->print(os); + diag << str; + } + } + }); + }); + } +}; + +} // namespace + +namespace mlir { +namespace test { +void registerTestAlignmentPass() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/third_party/enflame/include/triton/test/lib/Analysis/TestMembar.cpp b/third_party/enflame/include/triton/test/lib/Analysis/TestMembar.cpp new file mode 100644 index 000000000..25e8e2d19 --- /dev/null +++ b/third_party/enflame/include/triton/test/lib/Analysis/TestMembar.cpp @@ -0,0 +1,41 @@ +#include "../third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" + +using namespace mlir; + +namespace { + +struct TestMembarPass + : public PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass); + + StringRef getArgument() const final { return "test-print-membar"; } + StringRef getDescription() const final { + return "print the result of the allocation pass"; + } + + void runOnOperation() override { + Operation *operation = getOperation(); + ModuleOp moduleOp = cast(operation); + // Print all ops after membar pass + ModuleAllocation allocation(moduleOp); + ModuleMembarAnalysis membarPass(&allocation, + mlir::triton::NVIDIA::canSkipBarSync); + membarPass.run(); + } +}; + +} // namespace + +namespace mlir { +namespace test { +void registerTestMembarPass() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/third_party/enflame/include/triton/test/lib/CMakeLists.txt b/third_party/enflame/include/triton/test/lib/CMakeLists.txt new file mode 100644 index 000000000..5dd06592b --- /dev/null +++ b/third_party/enflame/include/triton/test/lib/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Analysis) +add_subdirectory(Dialect) +add_subdirectory(Instrumentation) diff --git a/third_party/enflame/include/triton/test/lib/Dialect/CMakeLists.txt b/third_party/enflame/include/triton/test/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..eba47a67c --- /dev/null +++ b/third_party/enflame/include/triton/test/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonGPU) diff --git a/third_party/enflame/include/triton/test/lib/Dialect/TritonGPU/CMakeLists.txt b/third_party/enflame/include/triton/test/lib/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..63f2d1d6b --- /dev/null +++ b/third_party/enflame/include/triton/test/lib/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_library(TritonTestDialectTritonGPU + TestTC05MMAPipeline.cpp + + DEPENDS + TritonGPUTransformsIncGen + + LINK_LIBS PUBLIC + MLIRPass + ${triton_libs} +) diff --git a/third_party/enflame/include/triton/test/lib/Dialect/TritonGPU/TestTC05MMAPipeline.cpp b/third_party/enflame/include/triton/test/lib/Dialect/TritonGPU/TestTC05MMAPipeline.cpp new file mode 100644 index 000000000..259558dab --- /dev/null +++ b/third_party/enflame/include/triton/test/lib/Dialect/TritonGPU/TestTC05MMAPipeline.cpp @@ -0,0 +1,30 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTC05MMAPIPELINE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct TC05MMAPipelinePass + : public impl::TritonGPUTC05MMAPipelineBase { + + using impl::TritonGPUTC05MMAPipelineBase< + TC05MMAPipelinePass>::TritonGPUTC05MMAPipelineBase; + + void runOnOperation() override { + ModuleOp m = getOperation(); + + pipelineTC05MMALoops(m, /*numStages=*/2, disableExpander); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/test/lib/Instrumentation/CMakeLists.txt b/third_party/enflame/include/triton/test/lib/Instrumentation/CMakeLists.txt new file mode 100644 index 000000000..6ab340ff5 --- /dev/null +++ b/third_party/enflame/include/triton/test/lib/Instrumentation/CMakeLists.txt @@ -0,0 +1,40 @@ +set(GPU_INSTRUMENTATION_PASSES + GPUInstrumentationTestLib + ) + +set(GPUInstrumentationTestLib_SOURCES + GPUHello.cpp + ) + + +foreach( plugin ${GPU_INSTRUMENTATION_PASSES} ) + add_library( + ${plugin} + SHARED + ${${plugin}_SOURCES} + ) + + target_link_libraries( + ${plugin} + PRIVATE + LLVMCore + "$<$:-undefined dynamic_lookup>" + ) + # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python + # build. It is empty if building directly from the root + # CMakeLists.txt file. Therefore if not building from Python just + # use the default CMake shared lib path otherwise this causes a hard + # build error + if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + set_target_properties(${plugin} PROPERTIES + LIBRARY_OUTPUT_DIRECTORY + "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation") + endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + + # This is set to -fvisibility=hidden in the top level CMake file + # which causes the llvmGetPassPluginInfo symbol to be hidden and + # an "entry point not found" error. Reset it just for this target + if(NOT MSVC) + target_compile_options(${plugin} PRIVATE -fvisibility=default) + endif() +endforeach() diff --git a/third_party/enflame/include/triton/test/lib/Instrumentation/GPUHello.cpp b/third_party/enflame/include/triton/test/lib/Instrumentation/GPUHello.cpp new file mode 100644 index 000000000..5c71857c8 --- /dev/null +++ b/third_party/enflame/include/triton/test/lib/Instrumentation/GPUHello.cpp @@ -0,0 +1,76 @@ +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +using namespace llvm; +using namespace std; + +namespace { + +struct GpuHello : public PassInfoMixin { + PreservedAnalyses run(Module &module, ModuleAnalysisManager &) { + bool modifiedCodeGen = runOnModule(module); + + return (modifiedCodeGen ? llvm::PreservedAnalyses::none() + : llvm::PreservedAnalyses::all()); + } + bool runOnModule(llvm::Module &module); + // isRequired being set to true keeps this pass from being skipped + // if it has the optnone LLVM attribute + static bool isRequired() { return true; } +}; + +} // end anonymous namespace + +bool GpuHello::runOnModule(Module &module) { + bool modifiedCodeGen = false; + + for (auto &function : module) { + if (function.isIntrinsic()) + continue; + StringRef functionName = function.getName(); + if (function.getCallingConv() == CallingConv::AMDGPU_KERNEL || + function.getCallingConv() == CallingConv::PTX_Kernel || + functionName.contains("kernel")) { + for (Function::iterator basicBlock = function.begin(); + basicBlock != function.end(); basicBlock++) { + for (BasicBlock::iterator inst = basicBlock->begin(); + inst != basicBlock->end(); inst++) { + DILocation *debugLocation = + dyn_cast(inst)->getDebugLoc(); + std::string sourceInfo = + (function.getName() + "\t" + debugLocation->getFilename() + ":" + + Twine(debugLocation->getLine()) + ":" + + Twine(debugLocation->getColumn())) + .str(); + + errs() << "Hello From First Instruction of GPU Kernel: " << sourceInfo + << "\n"; + return modifiedCodeGen; + } + } + } + } + return modifiedCodeGen; +} + +PassPluginLibraryInfo getPassPluginInfo() { + const auto callback = [](PassBuilder &pb) { + pb.registerOptimizerLastEPCallback([&](ModulePassManager &mpm, auto, auto) { + mpm.addPass(GpuHello()); + return true; + }); + }; + + return {LLVM_PLUGIN_API_VERSION, "gpu-hello", LLVM_VERSION_STRING, callback}; +}; + +extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo +llvmGetPassPluginInfo() { + return getPassPluginInfo(); +} diff --git a/third_party/enflame/include/triton/test/lit.cfg.py b/third_party/enflame/include/triton/test/lit.cfg.py new file mode 100644 index 000000000..4aec81e34 --- /dev/null +++ b/third_party/enflame/include/triton/test/lit.cfg.py @@ -0,0 +1,69 @@ +# -*- Python -*- +# ruff: noqa: F821 + +import os + +import lit.formats +import lit.util +from lit.llvm import llvm_config +from lit.llvm.subst import ToolSubst + +# Configuration file for the 'lit' test runner + +# (config is an instance of TestingConfig created when discovering tests) +# name: The name of this test suite +config.name = 'TRITON' + +config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) + +# suffixes: A list of file extensions to treat as test files. +config.suffixes = ['.mlir', '.ll'] + +# test_source_root: The root path where tests are located. +config.test_source_root = os.path.dirname(__file__) + +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.path.join(config.triton_obj_root, 'test') +config.substitutions.append(('%PATH%', config.environment['PATH'])) +config.substitutions.append(("%shlibdir", config.llvm_shlib_dir)) +config.substitutions.append(("%shlibext", config.llvm_shlib_ext)) + +llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) + +# llvm_config.use_default_substitutions() + +# excludes: A list of directories to exclude from the testsuite. The 'Inputs' +# subdirectories contain auxiliary inputs for various tests in their parent +# directories. +config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt'] + +# test_source_root: The root path where tests are located. +config.test_source_root = os.path.dirname(__file__) + +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.path.join(config.triton_obj_root, 'test') +config.triton_tools_dir = os.path.join(config.triton_obj_root, 'bin') +config.filecheck_dir = os.path.join(config.triton_obj_root, 'bin', 'FileCheck') + +# FileCheck -enable-var-scope is enabled by default in MLIR test +# This option avoids to accidentally reuse variable across -LABEL match, +# it can be explicitly opted-in by prefixing the variable name with $ +config.environment["FILECHECK_OPTS"] = "--enable-var-scope" + +tool_dirs = [config.triton_tools_dir, config.llvm_tools_dir, config.filecheck_dir] + +# Tweak the PATH to include the tools dir. +for d in tool_dirs: + llvm_config.with_environment('PATH', d, append_path=True) +tools = [ + 'triton-opt', + 'triton-llvm-opt', + ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'), +] + +llvm_config.add_tool_substitutions(tools, tool_dirs) + +# TODO: what's this? +llvm_config.with_environment('PYTHONPATH', [ + os.path.join(config.mlir_binary_dir, 'python_packages', 'triton'), +], append_path=True) diff --git a/third_party/enflame/include/triton/test/lit.site.cfg.py.in b/third_party/enflame/include/triton/test/lit.site.cfg.py.in new file mode 100644 index 000000000..fd1fb486d --- /dev/null +++ b/third_party/enflame/include/triton/test/lit.site.cfg.py.in @@ -0,0 +1,23 @@ +@LIT_SITE_CFG_IN_HEADER@ + +import sys + +config.triton_obj_root = "@triton_BINARY_DIR@" +config.llvm_src_root = "@LLVM_SOURCE_DIR@" +config.llvm_obj_root = "@LLVM_BINARY_DIR@" +config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" +config.llvm_lib_dir = "@LLVM_LIBS_DIR@" +config.llvm_shlib_dir = "@CMAKE_LIBRARY_OUTPUT_DIRECTORY@" +config.llvm_shlib_ext = "@CMAKE_SHARED_LIBRARY_SUFFIX@" +config.llvm_exe_ext = "@EXEEXT@" +config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" +config.mlir_binary_dir = "@MLIR_BINARY_DIR@" +config.python_executable = "@Python3_EXECUTABLE@" +config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@ + + +import lit.llvm +lit.llvm.initialize(lit_config, config) + +# Let the main config do the real work +lit_config.load_config(config, "@triton_SOURCE_DIR@/test/lit.cfg.py") diff --git a/third_party/enflame/include/triton/third_party/amd/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/CMakeLists.txt new file mode 100644 index 000000000..b030dbbd1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/CMakeLists.txt @@ -0,0 +1,12 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms TritonAMDGPUDialectToLLVM) + target_link_libraries(TritonAMD PRIVATE Python3::Module pybind11::headers) +endif() +if(TRITON_BUILD_UT) + add_subdirectory(unittest) +endif() +add_subdirectory(test) diff --git a/third_party/enflame/include/triton/third_party/amd/backend/compiler.py b/third_party/enflame/include/triton/third_party/amd/backend/compiler.py new file mode 100644 index 000000000..cf2687a5f --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/compiler.py @@ -0,0 +1,419 @@ +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes, llvm, amd +from dataclasses import dataclass +from typing import Any, Dict, Tuple +from types import ModuleType +import hashlib +import tempfile +import os +import re +import subprocess +import functools +from pathlib import Path + + +def min_dot_size(target: GPUTarget): + # If some given configuration is not supported in hardware we fallback to FMA and cast arguments + return lambda lhsType, rhsType: (1, 1, 1) + + +def is_pingpong_enabled(arch): + default = "1" if arch == "gfx942" else "0" + return os.getenv("TRITON_HIP_USE_BLOCK_PINGPONG", default) == "1" + + +@dataclass(frozen=True) +class HIPOptions: + num_warps: int = 4 + waves_per_eu: int = 1 + num_stages: int = 2 + num_ctas: int = 1 + num_buffers_warp_spec: int = 0 + num_consumer_groups: int = 0 + reg_dec_producer: int = 0 + reg_inc_consumer: int = 0 + extern_libs: dict = None + cluster_dims: tuple = (1, 1, 1) + debug: bool = False + sanitize_overflow: bool = True + arch: str = None + supported_fp8_dtypes: Tuple[str] = ("fp8e5", ) + deprecated_fp8_dtypes: Tuple[str] = () + default_dot_input_precision: str = "ieee" + allowed_dot_input_precisions: Tuple[str] = ("ieee", ) + enable_fp_fusion: bool = True + launch_cooperative_grid: bool = False + matrix_instr_nonkdim: int = 0 + kpack: int = 1 + allow_flush_denorm: bool = False + max_num_imprecise_acc_default: int = 0 + backend_name: str = 'hip' + + # The following option provides hints to the AMDGPU backend regarding instruction scheduling + # for all `tt.dot` operations in a kernel. The "none" variant preserves the default + # instruction scheduling of the AMDGPU backend which aims at maximizing occupancy. + # The option is experimental and may change at any time regarding its semantics and/or may + # be gone entirely anytime. + # + # Current experimental scheduling variants: + # + # llvm-iglp-0: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `0` to the GEMM's + # k-loop; i.e., "interleave DS and MFMA instructions for small GEMM kernels". + # llvm-iglp-1: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `1` to the GEMM's + # k-loop; i.e., "interleave DS and MFMA instructions for single wave small + # GEMM kernels.". + # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable + # Kernel library. Note, this variant requires the use of buffer load/store ops + # and a special software pipelining style - i.e., 1x LDS and 1x register + # prefetch buffers for each GEMM tile. + instruction_sched_variant: str = 'none' + + def __post_init__(self): + default_libdir = Path(__file__).parent / 'lib' + extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) + # Ignore user-defined warp size for gfx9 + warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64 + object.__setattr__(self, 'warp_size', warp_size) + # Only kpack=1 is supported on gfx950 + kpack = 1 if self.arch == 'gfx950' else self.kpack + object.__setattr__(self, 'kpack', kpack) + libs = ["ocml", "ockl"] + for lib in libs: + extern_libs[lib] = str(default_libdir / f'{lib}.bc') + object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" + + def hash(self): + key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class HIPBackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == 'hip' + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + assert isinstance(target.arch, str) + self.binary_ext = "hsaco" + + def parse_options(self, opts) -> Any: + args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", self.target.arch)} + + # Enable XF32 (TF32) for CDNA3 GPUs + if self.target.arch in ('gfx940', 'gfx941', 'gfx942'): + allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions) + allowed_dot_input_precisions.update({'tf32'}) + args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions)) + + if "supported_fp8_dtypes" not in opts: + supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes) + if self.target.arch in ('gfx940', 'gfx941', 'gfx942'): + supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'}) + elif self.target.arch in ('gfx950'): + supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'}) + args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) + + if "enable_fp_fusion" not in opts: + args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1" + args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts and opts[k] is not None}) + return HIPOptions(**args) + + def pack_metadata(self, metadata): + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + ) + + def get_codegen_implementation(self, options): + codegen_fns = {"min_dot_size": min_dot_size(self.target)} + return codegen_fns + + def get_module_map(self) -> Dict[str, ModuleType]: + from triton.language.extra.hip import libdevice + + return {"triton.language.extra.libdevice": libdevice} + + def load_dialects(self, ctx): + amd.load_dialects(ctx) + + @staticmethod + @functools.lru_cache() + def use_buffer_ops(): + return os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1" + + @staticmethod + def is_within_2gb(arg): + import torch + + MAX_INT_32 = 2**31 - 1 + if hasattr(arg, "ptr_range"): + return arg.ptr_range() <= MAX_INT_32 + if isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"): + return arg.untyped_storage().size() <= MAX_INT_32 + return False + + @staticmethod + def parse_attr(desc): + ret = BaseBackend.parse_attr(desc) + if "S" in desc: + ret += [["tt.pointer_range", 32]] + return ret + + @staticmethod + def get_arg_specialization(arg, ty, **kwargs): + ret = BaseBackend.get_arg_specialization(arg, ty, **kwargs) + # Only attempt to do buffer ops specialization if buffer ops are enabled. + # Otherwise the is_within_2gb check is unnecessary overhead. + if HIPBackend.use_buffer_ops() and ty == "tensor" and HIPBackend.is_within_2gb(arg): + ret += "S" + return ret + + @staticmethod + def path_to_rocm_lld(): + # Check env path for ld.lld + lld_env_path = os.getenv("TRITON_HIP_LLD_PATH") + if lld_env_path is not None: + lld = Path(lld_env_path) + if lld.is_file(): + return lld + # Check backend for ld.lld (used for pytorch wheels) + lld = Path(__file__).parent / "llvm/bin/ld.lld" + if lld.is_file(): + return lld + lld = Path("/opt/rocm/llvm/bin/ld.lld") + if lld.is_file(): + return lld + lld = Path("/usr/bin/ld.lld") + if lld.is_file(): + return lld + raise Exception("ROCm linker /opt/rocm/llvm/bin/ld.lld not found. Set 'TRITON_HIP_LLD_PATH' to its path.") + + @staticmethod + def make_ttir(mod, metadata, options): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_combine(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + passes.ttir.add_loop_unroll(pm) + pm.run(mod) + return mod + + @staticmethod + def make_ttgir(mod, metadata, options): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttir.add_convert_to_ttgpuir(pm, f"hip:{options.arch}", options.num_warps, options.warp_size, + options.num_ctas) + pm.run(mod) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttgpuir.add_coalesce(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_thread_locality(pm) + amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack) + passes.ttgpuir.add_remove_layout_conversions(pm) + amd.passes.ttgpuir.add_optimize_epilogue(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, True) + amd.passes.ttgpuir.add_hoist_layout_conversions(pm) + + global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0")) + local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0")) + + # The `local-prefetch` scheduling variant requires turning on buffer ops. + if options.instruction_sched_variant == "local-prefetch": + global_prefetch = local_prefetch = 1 + + if amd.has_matrix_core_feature(options.arch): + assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. " + "We used to trigger software pipelining with " + "num_stages == 0. Now it will not happen anymore; " + "please update to use num_stages == 2 for " + "equivalent behavior in the past.") + amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch) + passes.common.add_canonicalizer(pm) + if options.instruction_sched_variant.lower() != "none": + amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant) + passes.ttgpuir.add_optimize_dot_operands(pm, True) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_reduce_data_duplication(pm) + if amd.has_matrix_core_feature(options.arch): + amd.passes.ttgpuir.add_reorder_instructions(pm) + use_block_pingpong = is_pingpong_enabled(options.arch) + if use_block_pingpong and options.num_stages == 2: + amd.passes.ttgpuir.add_block_pingpong(pm) + + if HIPBackend.use_buffer_ops(): + amd.passes.ttgpuir.add_canonicalize_pointers(pm) + passes.common.add_canonicalizer(pm) + amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + return mod + + @staticmethod + def make_llir(src, metadata, options): + mod = src + # TritonGPU -> LLVM-IR (MLIR) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch) + # custom_lds_size is an experimental parameter that defines amount of LDS available + # for one thread block. Measured in bytes. + # + # If custom_lds_size = 0, pass will consider all LDS is available for one threads block, + # LDS size is determined by provided arch name. + custom_lds_size = 0 + amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size) + passes.convert.add_scf_to_cf(pm) + passes.convert.add_index_to_llvmir(pm) + + passes.ttgpuir.add_allocate_shared_memory(pm) + ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows: + ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless + ## of the value of kernel arg `allow_flush_denorm`. + ## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output + ## depends on the value of kernel arg `allow_flush_denorm`. + ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument. + ## For now it is used as a controller for developers only. + __HIP_FTZ = True + amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + + passes.convert.add_cf_to_llvmir(pm) + passes.convert.add_arith_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if options.instruction_sched_variant.lower() != "none": + amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages) + if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": + passes.llvmir.add_di_scope(pm) + amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ) + pm.run(mod) + + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) + llvm.init_targets() + context = llvm.context() + llvm_mod = llvm.to_module(mod, context) + amd.attach_target_triple(llvm_mod) + target_features = '' + if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1": + target_features = '+xnack' + llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features) + + # Set various control constants on the LLVM module so that device + # libraries can resolve references to them. + amd.set_isa_version(llvm_mod, options.arch) + amd.set_abi_version(llvm_mod, 500) + amd.set_bool_control_constant(llvm_mod, "__oclc_finite_only_opt", False) + amd.set_bool_control_constant(llvm_mod, "__oclc_correctly_rounded_sqrt32", True) + amd.set_bool_control_constant(llvm_mod, "__oclc_unsafe_math_opt", False) + amd.set_bool_control_constant(llvm_mod, "__oclc_wavefrontsize64", options.warp_size == 64) + + # Set kernel attributes first given this may affect later optimizations. + fns = [fn for fn in llvm_mod.get_functions() if not fn.is_declaration()] + # The public kernel should be kernel 0. + fns[0].set_calling_conv(amd.CALLING_CONV_AMDGPU_KERNEL) + fns[0].add_fn_attr("amdgpu-flat-work-group-size", f"1,{options.num_warps*options.warp_size}") + # LLVM AMDGPU backend supports the attribute "amdgpu-waves-per-eu"="[, ]". + # This attribute may be attached to a kernel function definition and is an optimization hint. + # parameter specifies the requested minimum number of waves per EU, and optional parameter + # specifies the requested maximum number of waves per EU (must be greater than if specified). + # If is omitted, then there is no restriction on the maximum number of waves per EU other than + # the one dictated by the hardware for which the kernel is compiled. Passing 0, 0 as , + # implies the default behavior (no limits). + fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}") + denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee" + fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode) + if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1": + fns[0].add_fn_target_feature("+xnack") + fns[0].add_fn_asan_attr() + + # Hint the compiler that we'd like the firmware to set the kernel arguments + # to user SGPRs so that the kernel does not need to s_load its arguments + # from memory. + amd.set_all_fn_arg_inreg(fns[0]) + + if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1": + default_libdir = Path(__file__).parent / 'lib' + paths = [ + str(default_libdir / 'asanrtl.bc'), + str(default_libdir / "ocml.bc"), + str(default_libdir / "ockl.bc") + ] + llvm.link_extern_libs(llvm_mod, paths) + elif options.extern_libs: + paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)] + llvm.link_extern_libs(llvm_mod, paths) + + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion) + + # Get some metadata + metadata["shared"] = src.get_int_attr("ttg.shared") + + amd.cleanup_bitcode_metadata(llvm_mod) + # Disable inlining of print related functions, + # because inlining of these function could slow down compilation significantly + amd.disable_print_inline(llvm_mod) + return str(llvm_mod) + + @staticmethod + def make_amdgcn(src, metadata, options): + # Find kernel names (there should only be one) + # We get the name at the last possible step to accomodate `triton.compile` + # on user-provided LLVM + names = re.findall(r"define amdgpu_kernel void @([a-zA-Z_][a-zA-Z0-9_]*)", src) + assert len(names) == 1 + metadata["name"] = names[0] + # llvm -> hsaco + amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', [], options.enable_fp_fusion, False) + if os.environ.get("AMDGCN_ENABLE_DUMP", "0") == "1": + print("// -----// AMDGCN Dump //----- //") + print(amdgcn) + return amdgcn + + @staticmethod + def make_hsaco(src, metadata, options): + target_features = '' + if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1": + target_features = '+xnack' + hsaco = amd.assemble_amdgcn(src, options.arch, target_features) + + rocm_path = HIPBackend.path_to_rocm_lld() + with tempfile.NamedTemporaryFile() as tmp_out: + with tempfile.NamedTemporaryFile() as tmp_in: + with open(tmp_in.name, 'wb') as fd_in: + fd_in.write(hsaco) + subprocess.check_call([rocm_path, '-flavor', 'gnu', '-shared', tmp_in.name, '-o', tmp_out.name]) + with open(tmp_out.name, 'rb') as fd_out: + ret = fd_out.read() + return ret + + def add_stages(self, stages, options): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) + stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options) + stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options) + + @functools.lru_cache() + def hash(self): + version = subprocess.check_output([HIPBackend.path_to_rocm_lld(), "--version"], encoding='utf-8') + return f'{version}-{self.target}' diff --git a/third_party/enflame/include/triton/third_party/amd/backend/driver.c b/third_party/enflame/include/triton/third_party/amd/backend/driver.c new file mode 100644 index 000000000..62eee09e7 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/driver.c @@ -0,0 +1,211 @@ +#define __HIP_PLATFORM_AMD__ +// clang-format off +// hip_depreated.h needs definitions from hip_runtime.h. +#include +#include +// clang-format on +#define PY_SSIZE_T_CLEAN +#include +#include +#include +#include + +// The list of paths to search for the HIP runtime library. The caller Python +// code should substitute the search path placeholder. +static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"}; + +// The list of HIP dynamic library symbols and their signature we are interested +// in this file. +// |FOR_EACH_ERR_FN| is a macro to process APIs that return hipError_t; +// |FOR_EACH_STR_FN| is a macro to process APIs that return const char *. +// +// HIP 6.0 introduced an updated hipGetDeviceProperties API under a new symbol, +// hipGetDevicePropertiesR0600. However, the associated hipDeviceProp_t was +// directly updated with breaking changes to match hipGetDevicePropertiesR0600 +// in the header file. We include the header file from HIP 6.0. So here if we +// use hipGetDeviceProperties together with hipDeviceProp_t we will use the +// old API with a new struct definition and mess up the interpretation. +// +// This is a known issue: https://github.com/ROCm/ROCm/issues/2728. +// +// For now explicitly defer to the old hipDeviceProp_t struct. This should work +// for both 5.x and 6.x. In the long term we need to switch to use +// hipGetProcAddress once available: +// https://github.com/ROCm/clr/commit/0479cdb3dd30ef58718cad44e424bd793c394cc0 +#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \ + FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \ + FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_tR0000 *prop, \ + int deviceId) \ + FOR_EACH_ERR_FN(hipModuleLoadDataEx, hipModule_t *module, const void *image, \ + unsigned int numOptions, hipJitOption *options, \ + void **optionValues) \ + FOR_EACH_ERR_FN(hipModuleGetFunction, hipFunction_t *function, \ + hipModule_t module, const char *kname) \ + FOR_EACH_ERR_FN(hipFuncGetAttribute, int *, hipFunction_attribute attr, \ + hipFunction_t function) + +// The HIP symbol table for holding resolved dynamic library symbols. +struct HIPSymbolTable { +#define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \ + hipError_t (*hipSymbolName)(__VA_ARGS__); +#define DEFINE_EACH_STR_FIELD(hipSymbolName, ...) \ + const char *(*hipSymbolName)(__VA_ARGS__); + + HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD) +}; + +static struct HIPSymbolTable hipSymbolTable; + +bool initSymbolTable() { + // Use the HIP runtime library loaded into the existing process if it exits. + void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD); + if (lib) { + // printf("[triton] chosen loaded libamdhip64.so in the process\n"); + } + + // Otherwise, go through the list of search paths to dlopen the first HIP + // driver library. + if (!lib) { + int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]); + for (int i = 0; i < n; ++i) { + void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL); + if (handle) { + lib = handle; + // printf("[triton] chosen %s\n", hipLibSearchPaths[i]); + } + } + } + if (!lib) { + PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so"); + return false; + } + + // Resolve all symbols we are interested in. + dlerror(); // Clear existing errors + const char *error = NULL; +#define QUERY_EACH_FN(hipSymbolName, ...) \ + *(void **)&hipSymbolTable.hipSymbolName = dlsym(lib, #hipSymbolName); \ + error = dlerror(); \ + if (error) { \ + PyErr_SetString(PyExc_RuntimeError, \ + "cannot query " #hipSymbolName " from libamdhip64.so"); \ + dlclose(lib); \ + return false; \ + } + + HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN) + + return true; +} + +static inline void gpuAssert(hipError_t code, const char *file, int line) { + { + if (code != HIP_SUCCESS) { + { + const char *prefix = "Triton Error [HIP]: "; + const char *str = hipSymbolTable.hipGetErrorString(code); + char err[1024] = {0}; + snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + } + } + } +} + +#define HIP_CHECK(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + if (PyErr_Occurred()) \ + return NULL; \ + } + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + + hipDeviceProp_tR0000 props; + HIP_CHECK(hipSymbolTable.hipGetDeviceProperties(&props, device_id)); + + // create a struct to hold device properties + return Py_BuildValue( + "{s:i, s:i, s:i, s:i, s:i, s:i, s:s, s:i, s:i}", "max_shared_mem", + props.sharedMemPerBlock, "max_num_regs", props.regsPerBlock, + "multiprocessor_count", props.multiProcessorCount, "sm_clock_rate", + props.clockRate, "mem_clock_rate", props.memoryClockRate, "mem_bus_width", + props.memoryBusWidth, "arch", props.gcnArchName, "warpSize", + props.warpSize, "max_threads_per_sm", props.maxThreadsPerMultiProcessor); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + + // set HIP options + hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, + hipJitOptionErrorLogBuffer, + hipJitOptionInfoLogBufferSizeBytes, + hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose}; + const unsigned int errbufsize = 8192; + const unsigned int logbufsize = 8192; + char _err[errbufsize]; + char _log[logbufsize]; + void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err, + (void *)(uintptr_t)logbufsize, (void *)_log, (void *)1}; + + // launch HIP Binary + hipModule_t mod; + hipFunction_t fun; + HIP_CHECK(hipSymbolTable.hipModuleLoadDataEx(&mod, data, 5, opt, optval)) + HIP_CHECK(hipSymbolTable.hipModuleGetFunction(&fun, mod, name)); + + // get allocated registers and spilled registers from the function + int n_regs = 0; + int n_spills = 0; + hipSymbolTable.hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, fun); + hipSymbolTable.hipFuncGetAttribute(&n_spills, + HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun); + n_spills /= 4; + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided hsaco into HIP driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_hip_utils(void) { + if (!initSymbolTable()) { + return NULL; + } + + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + + return m; +} diff --git a/third_party/enflame/include/triton/third_party/amd/backend/driver.py b/third_party/enflame/include/triton/third_party/amd/backend/driver.py new file mode 100644 index 000000000..b99ff86c8 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/driver.py @@ -0,0 +1,546 @@ +import functools +import os +import hashlib +import subprocess +import tempfile +from pathlib import Path +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.compiler import GPUTarget +from triton.backends.driver import GPUDriver + +dirname = os.path.dirname(os.path.realpath(__file__)) +include_dir = [os.path.join(dirname, "include")] + + +def _find_already_mmapped_dylib_on_linux(lib_name): + import platform + if platform.system() != 'Linux': + return None + + # Use dl_iterate_phdr to walk through the list of shared libraries at runtime. + # See https://www.man7.org/linux/man-pages/man3/dl_iterate_phdr.3.html for details. + + import ctypes + from ctypes import c_char, c_int, c_size_t, c_void_p, c_char_p, POINTER + + class DlPhdrInfo(ctypes.Structure): + _fields_ = [ + ('dlpi_addr', c_void_p), + ('dlpi_name', c_char_p), + # We don't care about the remaining fields. + ] + + # callback_t must use POINTER(c_char) to avoid copying. + callback_t = ctypes.CFUNCTYPE(c_int, POINTER(DlPhdrInfo), POINTER(c_size_t), POINTER(c_char)) + + # Load libc and get the dl_iterate_phdr symbol. + try: + dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr + except Exception: + return None + # argtypes must use c_char_p to accept create_string_buffer. + dl_iterate_phdr.argtypes = [callback_t, c_char_p] + dl_iterate_phdr.restype = c_int + + max_path_length = 4096 + path = ctypes.create_string_buffer(max_path_length + 1) + + # Define callback to get the loaded dylib path. + def callback(info, size, data): + dlpi_name = info.contents.dlpi_name + p = Path(os.fsdecode(dlpi_name)) + if lib_name in p.name: + # Found the dylib; get its path. + ctypes.memmove(data, dlpi_name, min(max_path_length, len(dlpi_name))) + return 1 + return 0 + + if dl_iterate_phdr(callback_t(callback), path): + return os.fsdecode(ctypes.string_at(path)) + return None + + +@functools.lru_cache() +def _get_path_to_hip_runtime_dylib(): + lib_name = "libamdhip64.so" + + # If we are told explicitly what HIP runtime dynamic library to use, obey that. + env_libhip_path = os.getenv("TRITON_LIBHIP_PATH") + if env_libhip_path: + if env_libhip_path.endswith(lib_name) and os.path.exists(env_libhip_path): + return env_libhip_path + raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}") + + # If the shared object is already mmapped to address space, use it. + mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name) + if mmapped_path: + if os.path.exists(mmapped_path): + return mmapped_path + raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}") + + paths = [] + + import site + # First search the HIP runtime dynamic library packaged with PyTorch. It's very likely + # that we run Triton together with PyTorch. This makes sure we use the same dynamic + # library to avoid version mismatch. + site_packages = site.getsitepackages() + user_site = site.getusersitepackages() + if site.ENABLE_USER_SITE: # ENABLE_USER_SITE is initialized in getusersitepackages() + site_packages = [user_site] + site_packages + for path in site_packages: + path = os.path.join(path, "torch", "lib", lib_name) + if os.path.exists(path): + return path + paths.append(path) + + # Then try to see if developer provides a HIP runtime dynamic library using LD_LIBARAY_PATH. + env_ld_library_path = os.getenv("LD_LIBRARY_PATH") + if env_ld_library_path: + for d in env_ld_library_path.split(":"): + f = os.path.join(d, lib_name) + if os.path.exists(f): + return f + paths.append(f) + + # Afterwards try to search the loader dynamic library resolution paths. + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() + # each line looks like the following: + # libamdhip64.so.6 (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so.6 + # libamdhip64.so (libc6,x86-64) => /opt/rocm-6.0.2/lib/libamdhip64.so + locs = [line.split()[-1] for line in libs.splitlines() if line.strip().endswith(lib_name)] + for loc in locs: + if os.path.exists(loc): + return loc + paths.append(loc) + + # As a last resort, guess if we have it in some common installation path. + common_install_path = os.path.join('/opt/rocm/lib/', lib_name) + if os.path.exists(common_install_path): + return common_install_path + paths.append(common_install_path) + + raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}") + + +def compile_module_from_src(src, name): + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir, [], include_dir, []) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +class HIPUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(HIPUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + libhip_path = _get_path_to_hip_runtime_dylib() + src = Path(os.path.join(dirname, "driver.c")).read_text() + # Just do a simple search and replace here instead of templates or format strings. + # This way we don't need to escape-quote C code curly brackets and we can replace + # exactly once. + src = src.replace('/*py_libhip_search_path*/', libhip_path, 1) + mod = compile_module_from_src(src, "hip_utils") + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + + +# -------------------- Launcher ---------------------------- +def ty_to_cpp(ty): + if ty[0] == '*': + return "hipDeviceptr_t" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def make_launcher(constants, signature, warp_size): + + def _serialize_signature(sig): + if isinstance(sig, tuple): + return ','.join(map(_serialize_signature, sig)) + return sig + + def _extracted_type(ty): + if isinstance(ty, tuple): + val = ','.join(map(_extracted_type, ty)) + return f"[{val}]" + if ty[0] == '*': + return "PyObject*" + if ty in ("constexpr"): + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + if isinstance(ty, tuple): + val = ''.join(map(format_of, ty)) + return f"({val})" + if ty[0] == '*': + return "O" + if ty in ("constexpr"): + return "O" + return { + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "L", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty_to_cpp(ty)] + + args_format = ''.join([format_of(ty) for ty in signature.values()]) + format = "piiiKKOOOO" + args_format + signature = ','.join(map(_serialize_signature, signature.values())) + signature = list(filter(bool, signature.split(','))) + signature = {i: s for i, s in enumerate(signature)} + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr") + internal_args_list = [] + for i, ty in signature.items(): + if ty[0] == "*": + internal_args_list.append(f"ptr_info{i}.dev_ptr") + elif ty != "constexpr": + internal_args_list.append(f"_arg{i}") + libhip_path = _get_path_to_hip_runtime_dylib() + + # generate glue code + params = list(range(len(signature))) + params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"] + params.append("&global_scratch") + src = f""" +#define __HIP_PLATFORM_AMD__ +#include +#include +#include +#include +#include + +// The list of paths to search for the HIP runtime library. The caller Python +// code should substitute the search path placeholder. +static const char *hipLibSearchPaths[] = {{"{libhip_path}"}}; + +// The list of HIP dynamic library symbols and their signature we are interested +// in this file. +#define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \\ + FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \\ + FOR_EACH_ERR_FN(hipModuleLaunchKernel, hipFunction_t f, \\ + unsigned int gridDimX, unsigned int gridDimY, \\ + unsigned int gridDimZ, unsigned int blockDimX, \\ + unsigned int blockDimY, unsigned int blockDimZ, \\ + unsigned int sharedMemBytes, hipStream_t stream, \\ + void **kernelParams, void **extra) \\ + FOR_EACH_ERR_FN(hipModuleLaunchCooperativeKernel, hipFunction_t f, \\ + unsigned int gridDimX, unsigned int gridDimY, \\ + unsigned int gridDimZ, unsigned int blockDimX, \\ + unsigned int blockDimY, unsigned int blockDimZ, \\ + unsigned int sharedMemBytes, hipStream_t stream, \\ + void **kernelParams, void **extra) \\ + FOR_EACH_ERR_FN(hipPointerGetAttribute, void *data, \\ + hipPointer_attribute attribute, hipDeviceptr_t ptr) + +// The HIP symbol table for holding resolved dynamic library symbols. +struct HIPSymbolTable {{ +#define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \\ + hipError_t (*hipSymbolName)(__VA_ARGS__); +#define DEFINE_EACH_STR_FIELD(hipSymbolName, ...) \\ + const char *(*hipSymbolName)(__VA_ARGS__); + + HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD) +}}; + +static struct HIPSymbolTable hipSymbolTable; + +bool initSymbolTable() {{ + // Use the HIP runtime library loaded into the existing process if it exits. + void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD); + if (lib) {{ + // printf("[triton] chosen loaded libamdhip64.so in the process\\n"); + }} + + // Otherwise, go through the list of search paths to dlopen the first HIP + // driver library. + if (!lib) {{ + int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]); + for (int i = 0; i < n; ++i) {{ + void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL); + if (handle) {{ + lib = handle; + // printf("[triton] chosen %s\\n", hipLibSearchPaths[i]); + }} + }} + }} + if (!lib) {{ + PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so"); + return false; + }} + + // Resolve all symbols we are interested in. + dlerror(); // Clear existing errors + const char *error = NULL; +#define QUERY_EACH_FN(hipSymbolName, ...) \\ + *(void **)&hipSymbolTable.hipSymbolName = dlsym(lib, #hipSymbolName); \\ + error = dlerror(); \\ + if (error) {{ \\ + PyErr_SetString(PyExc_RuntimeError, \\ + "cannot query " #hipSymbolName " from libamdhip64.so"); \\ + dlclose(lib); \\ + return false; \\ + }} + + HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN) + + return true; +}} + +static inline void gpuAssert(hipError_t code, const char *file, int line) +{{ + if (code != HIP_SUCCESS) + {{ + const char* prefix = "Triton Error [HIP]: "; + const char* str = hipSymbolTable.hipGetErrorString(code); + char err[1024] = {{0}}; + snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str ); + PyErr_SetString(PyExc_RuntimeError, err); + }} +}} + +#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + // printf("_launch hip kernel\\n"); + hipDeviceptr_t global_scratch = 0; + void *params[] = {{ {', '.join(params)} }}; + if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{ + HIP_CHECK(hipSymbolTable.hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0)); + return; + }} + if (gridX*gridY*gridZ > 0) {{ + HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0)); + }} +}} + +typedef struct _DevicePtrInfo {{ + hipDeviceptr_t dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) + return ptr_info; + uint64_t dev_ptr; + hipError_t status = hipSymbolTable.hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == hipErrorInvalidValue) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + }} + ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr; + Py_DECREF(ret); + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + // printf("launch\\n"); + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + int launch_cooperative_grid; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &launch_cooperative_grid, + &gridX, &gridY, &gridZ, &_stream, &_function, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {args_list})) {{ + return NULL; + }} + + // extract kernel metadata + int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + return NULL; + }} + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + if(PyErr_Occurred()) {{ + return NULL; + }} + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + if (!initSymbolTable()) {{ + return NULL; + }} + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src + + +class HIPLauncher(object): + + def __init__(self, src, metadata): + constants = src.constants if hasattr(src, "constants") else dict() + arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x + constants = {arg_idx(idx): value for idx, value in constants.items()} + signature = {idx: value for idx, value in src.signature.items()} + src = make_launcher(constants, signature, metadata.warp_size) + mod = compile_module_from_src(src, "__triton_launcher") + self.launch = mod.launch + self.launch_cooperative_grid = metadata.launch_cooperative_grid + + def __call__(self, *args): + self.launch(self.launch_cooperative_grid, *args) + + +class HIPDriver(GPUDriver): + + def __init__(self): + super().__init__() + self.utils = HIPUtils() + self.launcher_cls = HIPLauncher + + def get_device_interface(self): + import torch + return torch.cuda + + @staticmethod + def is_active(): + try: + import torch + return torch.version.hip is not None + except ImportError: + return False + + def get_current_target(self): + device = self.get_current_device() + device_properties = self.utils.get_device_properties(device) + arch = device_properties['arch'] + warp_size = device_properties['warpSize'] + return GPUTarget("hip", arch.split(':')[0], warp_size) + + def get_active_torch_device(self): + import torch + # when using hip devices, the device string in pytorch is "cuda" + return torch.device("cuda", self.get_current_device()) + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench + + def get_empty_cache_for_benchmark(self): + import torch + + # It's the same as the Nvidia backend. + cache_size = 256 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') + + def clear_cache(self, cache): + cache.zero_() diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_channel_descriptor.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_channel_descriptor.h new file mode 100644 index 000000000..6313bbbd0 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_channel_descriptor.h @@ -0,0 +1,312 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_CHANNEL_DESCRIPTOR_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_CHANNEL_DESCRIPTOR_H + +#if !defined(__HIPCC_RTC__) +#include +#include +#include +#endif + +#ifdef __cplusplus + +extern "C" HIP_PUBLIC_API hipChannelFormatDesc +hipCreateChannelDesc(int x, int y, int z, int w, hipChannelFormatKind f); + +static inline hipChannelFormatDesc hipCreateChannelDescHalf() { + int e = (int)sizeof(unsigned short) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindFloat); +} + +static inline hipChannelFormatDesc hipCreateChannelDescHalf1() { + int e = (int)sizeof(unsigned short) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindFloat); +} + +static inline hipChannelFormatDesc hipCreateChannelDescHalf2() { + int e = (int)sizeof(unsigned short) * 8; + return hipCreateChannelDesc(e, e, 0, 0, hipChannelFormatKindFloat); +} + +static inline hipChannelFormatDesc hipCreateChannelDescHalf4() { + int e = (int)sizeof(unsigned short) * 8; + return hipCreateChannelDesc(e, e, e, e, hipChannelFormatKindFloat); +} + +template +static inline hipChannelFormatDesc hipCreateChannelDesc() { + return hipCreateChannelDesc(0, 0, 0, 0, hipChannelFormatKindNone); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(char) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed char) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned char) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned char) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed char) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned char) * 8; + return hipCreateChannelDesc(e, e, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed char) * 8; + return hipCreateChannelDesc(e, e, 0, 0, hipChannelFormatKindSigned); +} + +#ifndef __GNUC__ // vector3 is the same as vector4 +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned char) * 8; + return hipCreateChannelDesc(e, e, e, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed char) * 8; + return hipCreateChannelDesc(e, e, e, 0, hipChannelFormatKindSigned); +} +#endif + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned char) * 8; + return hipCreateChannelDesc(e, e, e, e, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed char) * 8; + return hipCreateChannelDesc(e, e, e, e, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned short) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed short) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned short) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed short) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned short) * 8; + return hipCreateChannelDesc(e, e, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed short) * 8; + return hipCreateChannelDesc(e, e, 0, 0, hipChannelFormatKindSigned); +} + +#ifndef __GNUC__ +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned short) * 8; + return hipCreateChannelDesc(e, e, e, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed short) * 8; + return hipCreateChannelDesc(e, e, e, 0, hipChannelFormatKindSigned); +} +#endif + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned short) * 8; + return hipCreateChannelDesc(e, e, e, e, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed short) * 8; + return hipCreateChannelDesc(e, e, e, e, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned int) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed int) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned int) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed int) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned int) * 8; + return hipCreateChannelDesc(e, e, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed int) * 8; + return hipCreateChannelDesc(e, e, 0, 0, hipChannelFormatKindSigned); +} + +#ifndef __GNUC__ +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned int) * 8; + return hipCreateChannelDesc(e, e, e, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed int) * 8; + return hipCreateChannelDesc(e, e, e, 0, hipChannelFormatKindSigned); +} +#endif + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned int) * 8; + return hipCreateChannelDesc(e, e, e, e, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed int) * 8; + return hipCreateChannelDesc(e, e, e, e, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(float) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindFloat); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(float) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindFloat); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(float) * 8; + return hipCreateChannelDesc(e, e, 0, 0, hipChannelFormatKindFloat); +} + +#ifndef __GNUC__ +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(float) * 8; + return hipCreateChannelDesc(e, e, e, 0, hipChannelFormatKindFloat); +} +#endif + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(float) * 8; + return hipCreateChannelDesc(e, e, e, e, hipChannelFormatKindFloat); +} + +#if !defined(__LP64__) + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned long) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed long) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned long) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed long) * 8; + return hipCreateChannelDesc(e, 0, 0, 0, hipChannelFormatKindSigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned long) * 8; + return hipCreateChannelDesc(e, e, 0, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed long) * 8; + return hipCreateChannelDesc(e, e, 0, 0, hipChannelFormatKindSigned); +} + +#ifndef __GNUC__ +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned long) * 8; + return hipCreateChannelDesc(e, e, e, 0, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed long) * 8; + return hipCreateChannelDesc(e, e, e, 0, hipChannelFormatKindSigned); +} +#endif + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(unsigned long) * 8; + return hipCreateChannelDesc(e, e, e, e, hipChannelFormatKindUnsigned); +} + +template <> inline hipChannelFormatDesc hipCreateChannelDesc() { + int e = (int)sizeof(signed long) * 8; + return hipCreateChannelDesc(e, e, e, e, hipChannelFormatKindSigned); +} +#endif /* !__LP64__ */ + +#else + +struct hipChannelFormatDesc hipCreateChannelDesc(int x, int y, int z, int w, + enum hipChannelFormatKind f); + +#endif /* __cplusplus */ + +#endif /* !HIP_INCLUDE_HIP_AMD_DETAIL_CHANNEL_DESCRIPTOR_H */ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_device_functions.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_device_functions.h new file mode 100644 index 000000000..824633ade --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_device_functions.h @@ -0,0 +1,1009 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_DEVICE_FUNCTIONS_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_DEVICE_FUNCTIONS_H + +#if !defined(__HIPCC_RTC__) +#include "host_defines.h" +#include "math_fwd.h" +#include +#include +#include +#include +#include +#include +#endif // !defined(__HIPCC_RTC__) + +#if defined(__clang__) && defined(__HIP__) +extern "C" __device__ int printf(const char *fmt, ...); +#else +template +static inline __device__ void printf(const char *format, All... all) {} +#endif + +extern "C" __device__ unsigned long long __ockl_steadyctr_u64(); + +/* +Integer Intrinsics +*/ + +// integer intrinsic function __poc __clz __ffs __brev +__device__ static inline unsigned int __popc(unsigned int input) { + return __builtin_popcount(input); +} +__device__ static inline unsigned int __popcll(unsigned long long int input) { + return __builtin_popcountll(input); +} + +__device__ static inline int __clz(int input) { + return __ockl_clz_u32((uint)input); +} + +__device__ static inline int __clzll(long long int input) { + return __ockl_clz_u64((uint64_t)input); +} + +__device__ static inline unsigned int __ffs(unsigned int input) { + return (input == 0 ? -1 : __builtin_ctz(input)) + 1; +} + +__device__ static inline unsigned int __ffsll(unsigned long long int input) { + return (input == 0 ? -1 : __builtin_ctzll(input)) + 1; +} + +__device__ static inline unsigned int __ffs(int input) { + return (input == 0 ? -1 : __builtin_ctz(input)) + 1; +} + +__device__ static inline unsigned int __ffsll(long long int input) { + return (input == 0 ? -1 : __builtin_ctzll(input)) + 1; +} + +// Given a 32/64-bit value exec mask and an integer value base (between 0 and +// WAVEFRONT_SIZE), find the n-th (given by offset) set bit in the exec mask +// from the base bit, and return the bit position. If not found, return -1. +__device__ static int32_t __fns64(uint64_t mask, uint32_t base, + int32_t offset) { + uint64_t temp_mask = mask; + int32_t temp_offset = offset; + + if (offset == 0) { + temp_mask &= (1 << base); + temp_offset = 1; + } else if (offset < 0) { + temp_mask = __builtin_bitreverse64(mask); + base = 63 - base; + temp_offset = -offset; + } + + temp_mask = temp_mask & ((~0ULL) << base); + if (__builtin_popcountll(temp_mask) < temp_offset) + return -1; + int32_t total = 0; + for (int i = 0x20; i > 0; i >>= 1) { + uint64_t temp_mask_lo = temp_mask & ((1ULL << i) - 1); + int32_t pcnt = __builtin_popcountll(temp_mask_lo); + if (pcnt < temp_offset) { + temp_mask = temp_mask >> i; + temp_offset -= pcnt; + total += i; + } else { + temp_mask = temp_mask_lo; + } + } + if (offset < 0) + return 63 - total; + else + return total; +} + +__device__ static int32_t __fns32(uint64_t mask, uint32_t base, + int32_t offset) { + uint64_t temp_mask = mask; + int32_t temp_offset = offset; + if (offset == 0) { + temp_mask &= (1 << base); + temp_offset = 1; + } else if (offset < 0) { + temp_mask = __builtin_bitreverse64(mask); + base = 63 - base; + temp_offset = -offset; + } + temp_mask = temp_mask & ((~0ULL) << base); + if (__builtin_popcountll(temp_mask) < temp_offset) + return -1; + int32_t total = 0; + for (int i = 0x20; i > 0; i >>= 1) { + uint64_t temp_mask_lo = temp_mask & ((1ULL << i) - 1); + int32_t pcnt = __builtin_popcountll(temp_mask_lo); + if (pcnt < temp_offset) { + temp_mask = temp_mask >> i; + temp_offset -= pcnt; + total += i; + } else { + temp_mask = temp_mask_lo; + } + } + if (offset < 0) + return 63 - total; + else + return total; +} +__device__ static inline unsigned int __brev(unsigned int input) { + return __builtin_bitreverse32(input); +} + +__device__ static inline unsigned long long int +__brevll(unsigned long long int input) { + return __builtin_bitreverse64(input); +} + +__device__ static inline unsigned int __lastbit_u32_u64(uint64_t input) { + return input == 0 ? -1 : __builtin_ctzl(input); +} + +__device__ static inline unsigned int +__bitextract_u32(unsigned int src0, unsigned int src1, unsigned int src2) { + uint32_t offset = src1 & 31; + uint32_t width = src2 & 31; + return width == 0 ? 0 : (src0 << (32 - offset - width)) >> (32 - width); +} + +__device__ static inline uint64_t +__bitextract_u64(uint64_t src0, unsigned int src1, unsigned int src2) { + uint64_t offset = src1 & 63; + uint64_t width = src2 & 63; + return width == 0 ? 0 : (src0 << (64 - offset - width)) >> (64 - width); +} + +__device__ static inline unsigned int __bitinsert_u32(unsigned int src0, + unsigned int src1, + unsigned int src2, + unsigned int src3) { + uint32_t offset = src2 & 31; + uint32_t width = src3 & 31; + uint32_t mask = (1 << width) - 1; + return ((src0 & ~(mask << offset)) | ((src1 & mask) << offset)); +} + +__device__ static inline uint64_t __bitinsert_u64(uint64_t src0, uint64_t src1, + unsigned int src2, + unsigned int src3) { + uint64_t offset = src2 & 63; + uint64_t width = src3 & 63; + uint64_t mask = (1ULL << width) - 1; + return ((src0 & ~(mask << offset)) | ((src1 & mask) << offset)); +} + +__device__ inline unsigned int __funnelshift_l(unsigned int lo, unsigned int hi, + unsigned int shift) { + uint32_t mask_shift = shift & 31; + return mask_shift == 0 ? hi + : __builtin_amdgcn_alignbit(hi, lo, 32 - mask_shift); +} + +__device__ inline unsigned int +__funnelshift_lc(unsigned int lo, unsigned int hi, unsigned int shift) { + uint32_t min_shift = shift >= 32 ? 32 : shift; + return min_shift == 0 ? hi + : __builtin_amdgcn_alignbit(hi, lo, 32 - min_shift); +} + +__device__ inline unsigned int __funnelshift_r(unsigned int lo, unsigned int hi, + unsigned int shift) { + return __builtin_amdgcn_alignbit(hi, lo, shift); +} + +__device__ inline unsigned int +__funnelshift_rc(unsigned int lo, unsigned int hi, unsigned int shift) { + return shift >= 32 ? hi : __builtin_amdgcn_alignbit(hi, lo, shift); +} + +__device__ static unsigned int __byte_perm(unsigned int x, unsigned int y, + unsigned int s); +__device__ static unsigned int __hadd(int x, int y); +__device__ static int __mul24(int x, int y); +__device__ static long long int __mul64hi(long long int x, long long int y); +__device__ static int __mulhi(int x, int y); +__device__ static int __rhadd(int x, int y); +__device__ static unsigned int __sad(int x, int y, unsigned int z); +__device__ static unsigned int __uhadd(unsigned int x, unsigned int y); +__device__ static int __umul24(unsigned int x, unsigned int y); +__device__ static unsigned long long int __umul64hi(unsigned long long int x, + unsigned long long int y); +__device__ static unsigned int __umulhi(unsigned int x, unsigned int y); +__device__ static unsigned int __urhadd(unsigned int x, unsigned int y); +__device__ static unsigned int __usad(unsigned int x, unsigned int y, + unsigned int z); + +struct ucharHolder { + union { + unsigned char c[4]; + unsigned int ui; + }; +} __attribute__((aligned(4))); + +struct uchar2Holder { + union { + unsigned int ui[2]; + unsigned char c[8]; + }; +} __attribute__((aligned(8))); + +__device__ static inline unsigned int +__byte_perm(unsigned int x, unsigned int y, unsigned int s) { + struct uchar2Holder cHoldVal; + struct ucharHolder cHoldKey; + cHoldKey.ui = s; + cHoldVal.ui[0] = x; + cHoldVal.ui[1] = y; + unsigned int result; + result = cHoldVal.c[cHoldKey.c[0] & 0x07]; + result += (cHoldVal.c[(cHoldKey.c[0] & 0x70) >> 4] << 8); + result += (cHoldVal.c[cHoldKey.c[1] & 0x07] << 16); + result += (cHoldVal.c[(cHoldKey.c[1] & 0x70) >> 4] << 24); + return result; +} + +__device__ static inline unsigned int __hadd(int x, int y) { + int z = x + y; + int sign = z & 0x8000000; + int value = z & 0x7FFFFFFF; + return ((value) >> 1 || sign); +} + +__device__ static inline int __mul24(int x, int y) { + return __ockl_mul24_i32(x, y); +} + +__device__ static inline long long __mul64hi(long long int x, long long int y) { + unsigned long long x0 = (unsigned long long)x & 0xffffffffUL; + long long x1 = x >> 32; + unsigned long long y0 = (unsigned long long)y & 0xffffffffUL; + long long y1 = y >> 32; + unsigned long long z0 = x0 * y0; + long long t = x1 * y0 + (z0 >> 32); + long long z1 = t & 0xffffffffL; + long long z2 = t >> 32; + z1 = x0 * y1 + z1; + return x1 * y1 + z2 + (z1 >> 32); +} + +__device__ static inline int __mulhi(int x, int y) { + return __ockl_mul_hi_i32(x, y); +} + +__device__ static inline int __rhadd(int x, int y) { + int z = x + y + 1; + int sign = z & 0x8000000; + int value = z & 0x7FFFFFFF; + return ((value) >> 1 || sign); +} +__device__ static inline unsigned int __sad(int x, int y, unsigned int z) { + return x > y ? x - y + z : y - x + z; +} +__device__ static inline unsigned int __uhadd(unsigned int x, unsigned int y) { + return (x + y) >> 1; +} +__device__ static inline int __umul24(unsigned int x, unsigned int y) { + return __ockl_mul24_u32(x, y); +} + +__device__ static inline unsigned long long +__umul64hi(unsigned long long int x, unsigned long long int y) { + unsigned long long x0 = x & 0xffffffffUL; + unsigned long long x1 = x >> 32; + unsigned long long y0 = y & 0xffffffffUL; + unsigned long long y1 = y >> 32; + unsigned long long z0 = x0 * y0; + unsigned long long t = x1 * y0 + (z0 >> 32); + unsigned long long z1 = t & 0xffffffffUL; + unsigned long long z2 = t >> 32; + z1 = x0 * y1 + z1; + return x1 * y1 + z2 + (z1 >> 32); +} + +__device__ static inline unsigned int __umulhi(unsigned int x, unsigned int y) { + return __ockl_mul_hi_u32(x, y); +} +__device__ static inline unsigned int __urhadd(unsigned int x, unsigned int y) { + return (x + y + 1) >> 1; +} +__device__ static inline unsigned int __usad(unsigned int x, unsigned int y, + unsigned int z) { + return __ockl_sadd_u32(x, y, z); +} + +__device__ static inline unsigned int __mbcnt_lo(unsigned int x, + unsigned int y) { + return __builtin_amdgcn_mbcnt_lo(x, y); +}; + +__device__ static inline unsigned int __mbcnt_hi(unsigned int x, + unsigned int y) { + return __builtin_amdgcn_mbcnt_hi(x, y); +}; + +/* +HIP specific device functions +*/ + +#if !defined(__HIPCC_RTC__) +#include "amd_warp_functions.h" +#include "amd_warp_sync_functions.h" +#endif + +#define MASK1 0x00ff00ff +#define MASK2 0xff00ff00 + +__device__ static inline char4 __hip_hc_add8pk(char4 in1, char4 in2) { + char4 out; + unsigned one1 = in1.w & MASK1; + unsigned one2 = in2.w & MASK1; + out.w = (one1 + one2) & MASK1; + one1 = in1.w & MASK2; + one2 = in2.w & MASK2; + out.w = out.w | ((one1 + one2) & MASK2); + return out; +} + +__device__ static inline char4 __hip_hc_sub8pk(char4 in1, char4 in2) { + char4 out; + unsigned one1 = in1.w & MASK1; + unsigned one2 = in2.w & MASK1; + out.w = (one1 - one2) & MASK1; + one1 = in1.w & MASK2; + one2 = in2.w & MASK2; + out.w = out.w | ((one1 - one2) & MASK2); + return out; +} + +__device__ static inline char4 __hip_hc_mul8pk(char4 in1, char4 in2) { + char4 out; + unsigned one1 = in1.w & MASK1; + unsigned one2 = in2.w & MASK1; + out.w = (one1 * one2) & MASK1; + one1 = in1.w & MASK2; + one2 = in2.w & MASK2; + out.w = out.w | ((one1 * one2) & MASK2); + return out; +} + +__device__ static inline float __double2float_rd(double x) { + return __ocml_cvtrtn_f32_f64(x); +} +__device__ static inline float __double2float_rn(double x) { return x; } +__device__ static inline float __double2float_ru(double x) { + return __ocml_cvtrtp_f32_f64(x); +} +__device__ static inline float __double2float_rz(double x) { + return __ocml_cvtrtz_f32_f64(x); +} + +__device__ static inline int __double2hiint(double x) { + static_assert(sizeof(double) == 2 * sizeof(int), ""); + + int tmp[2]; + __builtin_memcpy(tmp, &x, sizeof(tmp)); + + return tmp[1]; +} +__device__ static inline int __double2loint(double x) { + static_assert(sizeof(double) == 2 * sizeof(int), ""); + + int tmp[2]; + __builtin_memcpy(tmp, &x, sizeof(tmp)); + + return tmp[0]; +} + +__device__ static inline int __double2int_rd(double x) { + return (int)__ocml_floor_f64(x); +} +__device__ static inline int __double2int_rn(double x) { + return (int)__ocml_rint_f64(x); +} +__device__ static inline int __double2int_ru(double x) { + return (int)__ocml_ceil_f64(x); +} +__device__ static inline int __double2int_rz(double x) { return (int)x; } + +__device__ static inline long long int __double2ll_rd(double x) { + return (long long)__ocml_floor_f64(x); +} +__device__ static inline long long int __double2ll_rn(double x) { + return (long long)__ocml_rint_f64(x); +} +__device__ static inline long long int __double2ll_ru(double x) { + return (long long)__ocml_ceil_f64(x); +} +__device__ static inline long long int __double2ll_rz(double x) { + return (long long)x; +} + +__device__ static inline unsigned int __double2uint_rd(double x) { + return (unsigned int)__ocml_floor_f64(x); +} +__device__ static inline unsigned int __double2uint_rn(double x) { + return (unsigned int)__ocml_rint_f64(x); +} +__device__ static inline unsigned int __double2uint_ru(double x) { + return (unsigned int)__ocml_ceil_f64(x); +} +__device__ static inline unsigned int __double2uint_rz(double x) { + return (unsigned int)x; +} + +__device__ static inline unsigned long long int __double2ull_rd(double x) { + return (unsigned long long int)__ocml_floor_f64(x); +} +__device__ static inline unsigned long long int __double2ull_rn(double x) { + return (unsigned long long int)__ocml_rint_f64(x); +} +__device__ static inline unsigned long long int __double2ull_ru(double x) { + return (unsigned long long int)__ocml_ceil_f64(x); +} +__device__ static inline unsigned long long int __double2ull_rz(double x) { + return (unsigned long long int)x; +} +__device__ static inline long long int __double_as_longlong(double x) { + static_assert(sizeof(long long) == sizeof(double), ""); + + long long tmp; + __builtin_memcpy(&tmp, &x, sizeof(tmp)); + + return tmp; +} + +/* +__device__ unsigned short __float2half_rn(float x); +__device__ float __half2float(unsigned short); + +The above device function are not a valid . +Use +__device__ __half __float2half_rn(float x); +__device__ float __half2float(__half); +from hip_fp16.h + +CUDA implements half as unsigned short whereas, HIP doesn't. + +*/ + +__device__ static inline int __float2int_rd(float x) { + return (int)__ocml_floor_f32(x); +} +__device__ static inline int __float2int_rn(float x) { + return (int)__ocml_rint_f32(x); +} +__device__ static inline int __float2int_ru(float x) { + return (int)__ocml_ceil_f32(x); +} +__device__ static inline int __float2int_rz(float x) { + return (int)__ocml_trunc_f32(x); +} + +__device__ static inline long long int __float2ll_rd(float x) { + return (long long int)__ocml_floor_f32(x); +} +__device__ static inline long long int __float2ll_rn(float x) { + return (long long int)__ocml_rint_f32(x); +} +__device__ static inline long long int __float2ll_ru(float x) { + return (long long int)__ocml_ceil_f32(x); +} +__device__ static inline long long int __float2ll_rz(float x) { + return (long long int)x; +} + +__device__ static inline unsigned int __float2uint_rd(float x) { + return (unsigned int)__ocml_floor_f32(x); +} +__device__ static inline unsigned int __float2uint_rn(float x) { + return (unsigned int)__ocml_rint_f32(x); +} +__device__ static inline unsigned int __float2uint_ru(float x) { + return (unsigned int)__ocml_ceil_f32(x); +} +__device__ static inline unsigned int __float2uint_rz(float x) { + return (unsigned int)x; +} + +__device__ static inline unsigned long long int __float2ull_rd(float x) { + return (unsigned long long int)__ocml_floor_f32(x); +} +__device__ static inline unsigned long long int __float2ull_rn(float x) { + return (unsigned long long int)__ocml_rint_f32(x); +} +__device__ static inline unsigned long long int __float2ull_ru(float x) { + return (unsigned long long int)__ocml_ceil_f32(x); +} +__device__ static inline unsigned long long int __float2ull_rz(float x) { + return (unsigned long long int)x; +} + +__device__ static inline int __float_as_int(float x) { + static_assert(sizeof(int) == sizeof(float), ""); + + int tmp; + __builtin_memcpy(&tmp, &x, sizeof(tmp)); + + return tmp; +} + +__device__ static inline unsigned int __float_as_uint(float x) { + static_assert(sizeof(unsigned int) == sizeof(float), ""); + + unsigned int tmp; + __builtin_memcpy(&tmp, &x, sizeof(tmp)); + + return tmp; +} + +__device__ static inline double __hiloint2double(int hi, int lo) { + static_assert(sizeof(double) == sizeof(uint64_t), ""); + + uint64_t tmp0 = + (static_cast(hi) << 32ull) | static_cast(lo); + double tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + + return tmp1; +} + +__device__ static inline double __int2double_rn(int x) { return (double)x; } + +__device__ static inline float __int2float_rd(int x) { + return __ocml_cvtrtn_f32_s32(x); +} +__device__ static inline float __int2float_rn(int x) { return (float)x; } +__device__ static inline float __int2float_ru(int x) { + return __ocml_cvtrtp_f32_s32(x); +} +__device__ static inline float __int2float_rz(int x) { + return __ocml_cvtrtz_f32_s32(x); +} + +__device__ static inline float __int_as_float(int x) { + static_assert(sizeof(float) == sizeof(int), ""); + + float tmp; + __builtin_memcpy(&tmp, &x, sizeof(tmp)); + + return tmp; +} + +__device__ static inline double __ll2double_rd(long long int x) { + return __ocml_cvtrtn_f64_s64(x); +} +__device__ static inline double __ll2double_rn(long long int x) { + return (double)x; +} +__device__ static inline double __ll2double_ru(long long int x) { + return __ocml_cvtrtp_f64_s64(x); +} +__device__ static inline double __ll2double_rz(long long int x) { + return __ocml_cvtrtz_f64_s64(x); +} + +__device__ static inline float __ll2float_rd(long long int x) { + return __ocml_cvtrtn_f32_s64(x); +} +__device__ static inline float __ll2float_rn(long long int x) { + return (float)x; +} +__device__ static inline float __ll2float_ru(long long int x) { + return __ocml_cvtrtp_f32_s64(x); +} +__device__ static inline float __ll2float_rz(long long int x) { + return __ocml_cvtrtz_f32_s64(x); +} + +__device__ static inline double __longlong_as_double(long long int x) { + static_assert(sizeof(double) == sizeof(long long), ""); + + double tmp; + __builtin_memcpy(&tmp, &x, sizeof(tmp)); + + return tmp; +} + +__device__ static inline double __uint2double_rn(unsigned int x) { + return (double)x; +} + +__device__ static inline float __uint2float_rd(unsigned int x) { + return __ocml_cvtrtn_f32_u32(x); +} +__device__ static inline float __uint2float_rn(unsigned int x) { + return (float)x; +} +__device__ static inline float __uint2float_ru(unsigned int x) { + return __ocml_cvtrtp_f32_u32(x); +} +__device__ static inline float __uint2float_rz(unsigned int x) { + return __ocml_cvtrtz_f32_u32(x); +} + +__device__ static inline float __uint_as_float(unsigned int x) { + static_assert(sizeof(float) == sizeof(unsigned int), ""); + + float tmp; + __builtin_memcpy(&tmp, &x, sizeof(tmp)); + + return tmp; +} + +__device__ static inline double __ull2double_rd(unsigned long long int x) { + return __ocml_cvtrtn_f64_u64(x); +} +__device__ static inline double __ull2double_rn(unsigned long long int x) { + return (double)x; +} +__device__ static inline double __ull2double_ru(unsigned long long int x) { + return __ocml_cvtrtp_f64_u64(x); +} +__device__ static inline double __ull2double_rz(unsigned long long int x) { + return __ocml_cvtrtz_f64_u64(x); +} + +__device__ static inline float __ull2float_rd(unsigned long long int x) { + return __ocml_cvtrtn_f32_u64(x); +} +__device__ static inline float __ull2float_rn(unsigned long long int x) { + return (float)x; +} +__device__ static inline float __ull2float_ru(unsigned long long int x) { + return __ocml_cvtrtp_f32_u64(x); +} +__device__ static inline float __ull2float_rz(unsigned long long int x) { + return __ocml_cvtrtz_f32_u64(x); +} + +#if defined(__clang__) && defined(__HIP__) + +// Clock functions +__device__ long long int __clock64(); +__device__ long long int __clock(); +__device__ long long int clock64(); +__device__ long long int clock(); +__device__ long long int wall_clock64(); +// hip.amdgcn.bc - named sync +__device__ void __named_sync(); + +#ifdef __HIP_DEVICE_COMPILE__ + +// Clock function to return GPU core cycle count. +// GPU can change its core clock frequency at runtime. The maximum frequency can +// be queried through hipDeviceAttributeClockRate attribute. +__device__ inline __attribute((always_inline)) long long int __clock64() { +#if __has_builtin(__builtin_amdgcn_s_memtime) + // Exists on gfx8, gfx9, gfx10.1, gfx10.2, gfx10.3 + return (long long int)__builtin_amdgcn_s_memtime(); +#else + // Subject to change when better solution available + return (long long int)__builtin_readcyclecounter(); +#endif +} + +__device__ inline __attribute((always_inline)) long long int __clock() { + return __clock64(); +} + +// Clock function to return wall clock count at a constant frequency that can be +// queried through hipDeviceAttributeWallClockRate attribute. +__device__ inline __attribute__((always_inline)) long long int wall_clock64() { + return (long long int)__ockl_steadyctr_u64(); +} + +__device__ inline __attribute__((always_inline)) long long int clock64() { + return __clock64(); +} + +__device__ inline __attribute__((always_inline)) long long int clock() { + return __clock(); +} + +// hip.amdgcn.bc - named sync +__device__ inline void __named_sync() { __builtin_amdgcn_s_barrier(); } + +#endif // __HIP_DEVICE_COMPILE__ + +// hip.amdgcn.bc - lanemask +__device__ inline uint64_t __lanemask_gt() { + uint32_t lane = __ockl_lane_u32(); + if (lane == 63) + return 0; + uint64_t ballot = __ballot64(1); + uint64_t mask = (~((uint64_t)0)) << (lane + 1); + return mask & ballot; +} + +__device__ inline uint64_t __lanemask_lt() { + uint32_t lane = __ockl_lane_u32(); + int64_t ballot = __ballot64(1); + uint64_t mask = ((uint64_t)1 << lane) - (uint64_t)1; + return mask & ballot; +} + +__device__ inline uint64_t __lanemask_eq() { + uint32_t lane = __ockl_lane_u32(); + int64_t mask = ((uint64_t)1 << lane); + return mask; +} + +__device__ inline void *__local_to_generic(void *p) { return p; } + +#ifdef __HIP_DEVICE_COMPILE__ +__device__ inline void *__get_dynamicgroupbaseptr() { + // Get group segment base pointer. + return (char *)__local_to_generic( + (void *)__to_local(__builtin_amdgcn_groupstaticsize())); +} +#else +__device__ void *__get_dynamicgroupbaseptr(); +#endif // __HIP_DEVICE_COMPILE__ + +__device__ inline void *__amdgcn_get_dynamicgroupbaseptr() { + return __get_dynamicgroupbaseptr(); +} + +// Memory Fence Functions +__device__ inline static void __threadfence() { + __builtin_amdgcn_fence(__ATOMIC_SEQ_CST, "agent"); +} + +__device__ inline static void __threadfence_block() { + __builtin_amdgcn_fence(__ATOMIC_SEQ_CST, "workgroup"); +} + +__device__ inline static void __threadfence_system() { + __builtin_amdgcn_fence(__ATOMIC_SEQ_CST, ""); +} +__device__ inline static void __work_group_barrier(__cl_mem_fence_flags flags) { + if (flags) { + __builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup"); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup"); + } else { + __builtin_amdgcn_s_barrier(); + } +} + +__device__ inline static void __barrier(int n) { + __work_group_barrier((__cl_mem_fence_flags)n); +} + +__device__ inline __attribute__((convergent)) void __syncthreads() { + __barrier(__CLK_LOCAL_MEM_FENCE); +} + +__device__ inline __attribute__((convergent)) int +__syncthreads_count(int predicate) { + return __ockl_wgred_add_i32(!!predicate); +} + +__device__ inline __attribute__((convergent)) int +__syncthreads_and(int predicate) { + return __ockl_wgred_and_i32(!!predicate); +} + +__device__ inline __attribute__((convergent)) int +__syncthreads_or(int predicate) { + return __ockl_wgred_or_i32(!!predicate); +} + +// hip.amdgcn.bc - device routine +/* + HW_ID Register bit structure for RDNA2 & RDNA3 + WAVE_ID 4:0 Wave id within the SIMD. + SIMD_ID 9:8 SIMD_ID within the WGP: [0] = row, [1] = column. + WGP_ID 13:10 Physical WGP ID. + SA_ID 16 Shader Array ID + SE_ID 20:18 Shader Engine the wave is assigned to for gfx11 + SE_ID 19:18 Shader Engine the wave is assigned to for gfx10 + DP_RATE 31:29 Number of double-precision float units per SIMD + + HW_ID Register bit structure for GCN and CDNA + WAVE_ID 3:0 Wave buffer slot number. 0-9. + SIMD_ID 5:4 SIMD which the wave is assigned to within the CU. + PIPE_ID 7:6 Pipeline from which the wave was dispatched. + CU_ID 11:8 Compute Unit the wave is assigned to. + SH_ID 12 Shader Array (within an SE) the wave is assigned to. + SE_ID 15:13 Shader Engine the wave is assigned to for gfx908, gfx90a, + gfx940-942 14:13 Shader Engine the wave is assigned to for Vega. TG_ID 19:16 + Thread-group ID VM_ID 23:20 Virtual Memory ID QUEUE_ID 26:24 Queue + from which this wave was dispatched. STATE_ID 29:27 State ID (graphics + only, not compute). ME_ID 31:30 Micro-engine ID. + + XCC_ID Register bit structure for gfx940 + XCC_ID 3:0 XCC the wave is assigned to. + */ + +#if (defined(__GFX10__) || defined(__GFX11__)) +#define HW_ID 23 +#else +#define HW_ID 4 +#endif + +#if (defined(__GFX10__) || defined(__GFX11__)) +#define HW_ID_WGP_ID_SIZE 4 +#define HW_ID_WGP_ID_OFFSET 10 +#if (defined(__AMDGCN_CUMODE__)) +#define HW_ID_CU_ID_SIZE 1 +#define HW_ID_CU_ID_OFFSET 8 +#endif +#else +#define HW_ID_CU_ID_SIZE 4 +#define HW_ID_CU_ID_OFFSET 8 +#endif + +#if (defined(__gfx908__) || defined(__gfx90a__) || defined(__GFX11__)) +#define HW_ID_SE_ID_SIZE 3 +#else // 4 SEs/XCC for gfx940-942 +#define HW_ID_SE_ID_SIZE 2 +#endif +#if (defined(__GFX10__) || defined(__GFX11__)) +#define HW_ID_SE_ID_OFFSET 18 +#define HW_ID_SA_ID_OFFSET 16 +#define HW_ID_SA_ID_SIZE 1 +#else +#define HW_ID_SE_ID_OFFSET 13 +#endif + +#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#define XCC_ID 20 +#define XCC_ID_XCC_ID_SIZE 4 +#define XCC_ID_XCC_ID_OFFSET 0 +#endif + +#if (!defined(__HIP_NO_IMAGE_SUPPORT) && \ + (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))) +#define __HIP_NO_IMAGE_SUPPORT 1 +#endif + +/* + Encoding of parameter bitmask + HW_ID 5:0 HW_ID + OFFSET 10:6 Range: 0..31 + SIZE 15:11 Range: 1..32 + */ + +#define GETREG_IMMED(SZ, OFF, REG) (((SZ) << 11) | ((OFF) << 6) | (REG)) + +/* + __smid returns the wave's assigned Compute Unit and Shader Engine. + The Compute Unit, CU_ID returned in bits 3:0, and Shader Engine, SE_ID in bits + 5:4. Note: the results vary over time. SZ minus 1 since SIZE is 1-based. +*/ +__device__ inline unsigned __smid(void) { + unsigned se_id = __builtin_amdgcn_s_getreg( + GETREG_IMMED(HW_ID_SE_ID_SIZE - 1, HW_ID_SE_ID_OFFSET, HW_ID)); +#if (defined(__GFX10__) || defined(__GFX11__)) + unsigned wgp_id = __builtin_amdgcn_s_getreg( + GETREG_IMMED(HW_ID_WGP_ID_SIZE - 1, HW_ID_WGP_ID_OFFSET, HW_ID)); + unsigned sa_id = __builtin_amdgcn_s_getreg( + GETREG_IMMED(HW_ID_SA_ID_SIZE - 1, HW_ID_SA_ID_OFFSET, HW_ID)); +#if (defined(__AMDGCN_CUMODE__)) + unsigned cu_id = __builtin_amdgcn_s_getreg( + GETREG_IMMED(HW_ID_CU_ID_SIZE - 1, HW_ID_CU_ID_OFFSET, HW_ID)); +#endif +#else +#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + unsigned xcc_id = __builtin_amdgcn_s_getreg( + GETREG_IMMED(XCC_ID_XCC_ID_SIZE - 1, XCC_ID_XCC_ID_OFFSET, XCC_ID)); +#endif + unsigned cu_id = __builtin_amdgcn_s_getreg( + GETREG_IMMED(HW_ID_CU_ID_SIZE - 1, HW_ID_CU_ID_OFFSET, HW_ID)); +#endif +#if (defined(__GFX10__) || defined(__GFX11__)) + unsigned temp = se_id; + temp = (temp << HW_ID_SA_ID_SIZE) | sa_id; + temp = (temp << HW_ID_WGP_ID_SIZE) | wgp_id; +#if (defined(__AMDGCN_CUMODE__)) + temp = (temp << HW_ID_CU_ID_SIZE) | cu_id; +#endif + return temp; + // TODO : CU Mode impl +#elif (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + unsigned temp = xcc_id; + temp = (temp << HW_ID_SE_ID_SIZE) | se_id; + temp = (temp << HW_ID_CU_ID_SIZE) | cu_id; + return temp; +#else + return (se_id << HW_ID_CU_ID_SIZE) + cu_id; +#endif +} + +/** + * Map HIP_DYNAMIC_SHARED to "extern __shared__" for compatibility with old HIP + * applications To be removed in a future release. + */ +#define HIP_DYNAMIC_SHARED(type, var) extern __shared__ type var[]; +#define HIP_DYNAMIC_SHARED_ATTRIBUTE + +#endif // defined(__clang__) && defined(__HIP__) + +// loop unrolling +static inline __device__ void *__hip_hc_memcpy(void *dst, const void *src, + size_t size) { + auto dstPtr = static_cast(dst); + auto srcPtr = static_cast(src); + + while (size >= 4u) { + dstPtr[0] = srcPtr[0]; + dstPtr[1] = srcPtr[1]; + dstPtr[2] = srcPtr[2]; + dstPtr[3] = srcPtr[3]; + + size -= 4u; + srcPtr += 4u; + dstPtr += 4u; + } + switch (size) { + case 3: + dstPtr[2] = srcPtr[2]; + case 2: + dstPtr[1] = srcPtr[1]; + case 1: + dstPtr[0] = srcPtr[0]; + } + + return dst; +} + +static inline __device__ void *__hip_hc_memset(void *dst, unsigned char val, + size_t size) { + auto dstPtr = static_cast(dst); + + while (size >= 4u) { + dstPtr[0] = val; + dstPtr[1] = val; + dstPtr[2] = val; + dstPtr[3] = val; + + size -= 4u; + dstPtr += 4u; + } + switch (size) { + case 3: + dstPtr[2] = val; + case 2: + dstPtr[1] = val; + case 1: + dstPtr[0] = val; + } + + return dst; +} +#ifndef __OPENMP_AMDGCN__ +static inline __device__ void *memcpy(void *dst, const void *src, size_t size) { + return __hip_hc_memcpy(dst, src, size); +} + +static inline __device__ void *memset(void *ptr, int val, size_t size) { + unsigned char val8 = static_cast(val); + return __hip_hc_memset(ptr, val8, size); +} +#endif // !__OPENMP_AMDGCN__ + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_atomic.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_atomic.h new file mode 100644 index 000000000..db4f2b189 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_atomic.h @@ -0,0 +1,1518 @@ +/* +Copyright (c) 2015 - Present Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#if !defined(__HIPCC_RTC__) +#include "amd_device_functions.h" +#endif + +#if __has_builtin(__hip_atomic_compare_exchange_strong) + +template struct Cond_t; + +template struct Cond_t { + using type = T; +}; +template struct Cond_t { + using type = F; +}; + +#if !__HIP_DEVICE_COMPILE__ +// TODO: Remove this after compiler pre-defines the following Macros. +#define __HIP_MEMORY_SCOPE_SINGLETHREAD 1 +#define __HIP_MEMORY_SCOPE_WAVEFRONT 2 +#define __HIP_MEMORY_SCOPE_WORKGROUP 3 +#define __HIP_MEMORY_SCOPE_AGENT 4 +#define __HIP_MEMORY_SCOPE_SYSTEM 5 +#endif + +#if !defined(__HIPCC_RTC__) +#include "amd_hip_unsafe_atomics.h" +#endif + +// Atomic expanders +template +inline __attribute__((always_inline, device)) T hip_cas_expander(T *p, T x, + Op op, + F f) noexcept { + using FP = __attribute__((address_space(0))) const void *; + + __device__ extern bool is_shared_workaround(FP) asm("llvm.amdgcn.is.shared"); + + if (is_shared_workaround((FP)p)) + return f(); + + using U = typename Cond_t::type; + + auto q = reinterpret_cast(p); + + U tmp0{__hip_atomic_load(q, mem_order, mem_scope)}; + U tmp1; + do { + tmp1 = tmp0; + + op(reinterpret_cast(tmp1), x); + } while (!__hip_atomic_compare_exchange_strong(q, &tmp0, tmp1, mem_order, + mem_order, mem_scope)); + + return reinterpret_cast(tmp0); +} + +template +inline __attribute__((always_inline, device)) T +hip_cas_extrema_expander(T *p, T x, Cmp cmp, F f) noexcept { + using FP = __attribute__((address_space(0))) const void *; + + __device__ extern bool is_shared_workaround(FP) asm("llvm.amdgcn.is.shared"); + + if (is_shared_workaround((FP)p)) + return f(); + + using U = typename Cond_t::type; + + auto q = reinterpret_cast(p); + + U tmp{__hip_atomic_load(q, mem_order, mem_scope)}; + while (cmp(x, reinterpret_cast(tmp)) && + !__hip_atomic_compare_exchange_strong(q, &tmp, x, mem_order, mem_order, + mem_scope)) + ; + + return reinterpret_cast(tmp); +} + +__device__ inline int atomicCAS(int *address, int compare, int val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + return compare; +} + +__device__ inline int atomicCAS_system(int *address, int compare, int val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + return compare; +} + +__device__ inline unsigned int +atomicCAS(unsigned int *address, unsigned int compare, unsigned int val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + return compare; +} + +__device__ inline unsigned int atomicCAS_system(unsigned int *address, + unsigned int compare, + unsigned int val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + return compare; +} + +__device__ inline unsigned long +atomicCAS(unsigned long *address, unsigned long compare, unsigned long val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + return compare; +} + +__device__ inline unsigned long atomicCAS_system(unsigned long *address, + unsigned long compare, + unsigned long val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + return compare; +} + +__device__ inline unsigned long long atomicCAS(unsigned long long *address, + unsigned long long compare, + unsigned long long val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + return compare; +} + +__device__ inline unsigned long long +atomicCAS_system(unsigned long long *address, unsigned long long compare, + unsigned long long val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + return compare; +} + +__device__ inline float atomicCAS(float *address, float compare, float val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + return compare; +} + +__device__ inline float atomicCAS_system(float *address, float compare, + float val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + return compare; +} + +__device__ inline double atomicCAS(double *address, double compare, + double val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + return compare; +} + +__device__ inline double atomicCAS_system(double *address, double compare, + double val) { + __hip_atomic_compare_exchange_strong(address, &compare, val, __ATOMIC_RELAXED, + __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + return compare; +} + +__device__ inline int atomicAdd(int *address, int val) { + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline int atomicAdd_system(int *address, int val) { + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline unsigned int atomicAdd(unsigned int *address, + unsigned int val) { + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline unsigned int atomicAdd_system(unsigned int *address, + unsigned int val) { + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline unsigned long atomicAdd(unsigned long *address, + unsigned long val) { + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline unsigned long atomicAdd_system(unsigned long *address, + unsigned long val) { + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline unsigned long long atomicAdd(unsigned long long *address, + unsigned long long val) { + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline unsigned long long +atomicAdd_system(unsigned long long *address, unsigned long long val) { + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline float atomicAdd(float *address, float val) { +#if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) + return unsafeAtomicAdd(address, val); +#else + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif +} + +__device__ inline float atomicAdd_system(float *address, float val) { + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +#if !defined(__HIPCC_RTC__) +DEPRECATED("use atomicAdd instead") +#endif // !defined(__HIPCC_RTC__) +__device__ inline void atomicAddNoRet(float *address, float val) { + __ockl_atomic_add_noret_f32(address, val); +} + +__device__ inline double atomicAdd(double *address, double val) { +#if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) + return unsafeAtomicAdd(address, val); +#else + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif +} + +__device__ inline double atomicAdd_system(double *address, double val) { + return __hip_atomic_fetch_add(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline int atomicSub(int *address, int val) { + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline int atomicSub_system(int *address, int val) { + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline unsigned int atomicSub(unsigned int *address, + unsigned int val) { + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline unsigned int atomicSub_system(unsigned int *address, + unsigned int val) { + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline unsigned long atomicSub(unsigned long *address, + unsigned long val) { + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline unsigned long atomicSub_system(unsigned long *address, + unsigned long val) { + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline unsigned long long atomicSub(unsigned long long *address, + unsigned long long val) { + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline unsigned long long +atomicSub_system(unsigned long long *address, unsigned long long val) { + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline float atomicSub(float *address, float val) { +#if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) + return unsafeAtomicAdd(address, -val); +#else + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif +} + +__device__ inline float atomicSub_system(float *address, float val) { + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline double atomicSub(double *address, double val) { +#if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) + return unsafeAtomicAdd(address, -val); +#else + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif +} + +__device__ inline double atomicSub_system(double *address, double val) { + return __hip_atomic_fetch_add(address, -val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline int atomicExch(int *address, int val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline int atomicExch_system(int *address, int val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline unsigned int atomicExch(unsigned int *address, + unsigned int val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline unsigned int atomicExch_system(unsigned int *address, + unsigned int val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline unsigned long atomicExch(unsigned long *address, + unsigned long val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline unsigned long atomicExch_system(unsigned long *address, + unsigned long val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline unsigned long long atomicExch(unsigned long long *address, + unsigned long long val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline unsigned long long +atomicExch_system(unsigned long long *address, unsigned long long val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline float atomicExch(float *address, float val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline float atomicExch_system(float *address, float val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline double atomicExch(double *address, double val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ inline double atomicExch_system(double *address, double val) { + return __hip_atomic_exchange(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +__device__ inline int atomicMin(int *address, int val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](int x, int y) { return x < y; }, + [=]() { + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline int atomicMin_system(int *address, int val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](int x, int y) { return x < y; }, + [=]() { + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned int atomicMin(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned int x, unsigned int y) { return x < y; }, + [=]() { + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned int atomicMin_system(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned int x, unsigned int y) { return x < y; }, + [=]() { + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned long long atomicMin(unsigned long *address, + unsigned long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned long x, unsigned long y) { return x < y; }, + [=]() { + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned long atomicMin_system(unsigned long *address, + unsigned long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned long x, unsigned long y) { return x < y; }, + [=]() { + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned long long atomicMin(unsigned long long *address, + unsigned long long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, + [](unsigned long long x, unsigned long long y) { return x < y; }, + [=]() { + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned long long +atomicMin_system(unsigned long long *address, unsigned long long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, + [](unsigned long long x, unsigned long long y) { return x < y; }, + [=]() { + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline long long atomicMin(long long *address, long long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](long long x, long long y) { return x < y; }, + [=]() { + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline long long atomicMin_system(long long *address, + long long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](long long x, long long y) { return x < y; }, + [=]() { + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_min(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline float atomicMin(float *addr, float val) { +#if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) + return unsafeAtomicMin(addr, val); +#else + typedef union u_hold { + float a; + unsigned int b; + } u_hold_t; + u_hold_t u{val}; + bool neg_zero = 0x80000000U == u.b; +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + float value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && (value > val || (neg_zero && value == 0.0f))) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned int *uaddr = (unsigned int *)addr; + unsigned int value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && (__uint_as_float(value) > val || + (neg_zero && __uint_as_float(value) == 0.0f))) { + done = + __atomic_compare_exchange_n(uaddr, &value, __float_as_uint(val), false, + __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __uint_as_float(value); +#endif +#endif +} + +__device__ inline float atomicMin_system(float *address, float val) { + unsigned int *uaddr{reinterpret_cast(address)}; +#if __has_builtin(__hip_atomic_load) + unsigned int tmp{ + __hip_atomic_load(uaddr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM)}; +#else + unsigned int tmp{__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; +#endif + float value = __uint_as_float(tmp); + + while (val < value) { + value = atomicCAS_system(address, value, val); + } + + return value; +} + +__device__ inline double atomicMin(double *addr, double val) { +#if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) + return unsafeAtomicMin(addr, val); +#else + typedef union u_hold { + double a; + unsigned long long b; + } u_hold_t; + u_hold_t u{val}; + bool neg_zero = 0x8000000000000000ULL == u.b; +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + double value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && (value > val || (neg_zero && value == 0.0))) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned long long *uaddr = (unsigned long long *)addr; + unsigned long long value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && (__longlong_as_double(value) > val || + (neg_zero && __longlong_as_double(value) == 0.0))) { + done = + __atomic_compare_exchange_n(uaddr, &value, __double_as_longlong(val), + false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __longlong_as_double(value); +#endif +#endif +} + +__device__ inline double atomicMin_system(double *address, double val) { + unsigned long long *uaddr{reinterpret_cast(address)}; +#if __has_builtin(__hip_atomic_load) + unsigned long long tmp{ + __hip_atomic_load(uaddr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM)}; +#else + unsigned long long tmp{__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; +#endif + double value = __longlong_as_double(tmp); + + while (val < value) { + value = atomicCAS_system(address, value, val); + } + + return value; +} + +__device__ inline int atomicMax(int *address, int val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](int x, int y) { return y < x; }, + [=]() { + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline int atomicMax_system(int *address, int val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](int x, int y) { return y < x; }, + [=]() { + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned int atomicMax(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned int x, unsigned int y) { return y < x; }, + [=]() { + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned int atomicMax_system(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned int x, unsigned int y) { return y < x; }, + [=]() { + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned long atomicMax(unsigned long *address, + unsigned long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned long x, unsigned long y) { return y < x; }, + [=]() { + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned long atomicMax_system(unsigned long *address, + unsigned long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned long x, unsigned long y) { return y < x; }, + [=]() { + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned long long atomicMax(unsigned long long *address, + unsigned long long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, + [](unsigned long long x, unsigned long long y) { return y < x; }, + [=]() { + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned long long +atomicMax_system(unsigned long long *address, unsigned long long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, + [](unsigned long long x, unsigned long long y) { return y < x; }, + [=]() { + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline long long atomicMax(long long *address, long long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](long long x, long long y) { return y < x; }, + [=]() { + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline long long atomicMax_system(long long *address, + long long val) { +#if defined(__gfx941__) + return hip_cas_extrema_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](long long x, long long y) { return y < x; }, + [=]() { + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_max(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline float atomicMax(float *addr, float val) { +#if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) + return unsafeAtomicMax(addr, val); +#else + typedef union u_hold { + float a; + unsigned int b; + } u_hold_t; + u_hold_t u{val}; + bool neg_zero = 0x80000000U == u.b; +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + float value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && (value < val || (neg_zero && value == 0.0f))) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned int *uaddr = (unsigned int *)addr; + unsigned int value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && (__uint_as_float(value) < val || + (neg_zero && __uint_as_float(value) == 0.0f))) { + done = + __atomic_compare_exchange_n(uaddr, &value, __float_as_uint(val), false, + __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __uint_as_float(value); +#endif +#endif +} + +__device__ inline float atomicMax_system(float *address, float val) { + unsigned int *uaddr{reinterpret_cast(address)}; +#if __has_builtin(__hip_atomic_load) + unsigned int tmp{ + __hip_atomic_load(uaddr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM)}; +#else + unsigned int tmp{__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; +#endif + float value = __uint_as_float(tmp); + + while (value < val) { + value = atomicCAS_system(address, value, val); + } + + return value; +} + +__device__ inline double atomicMax(double *addr, double val) { +#if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) + return unsafeAtomicMax(addr, val); +#else + typedef union u_hold { + double a; + unsigned long long b; + } u_hold_t; + u_hold_t u{val}; + bool neg_zero = 0x8000000000000000ULL == u.b; +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + double value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && (value < val || (neg_zero && value == 0.0))) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned long long *uaddr = (unsigned long long *)addr; + unsigned long long value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && (__longlong_as_double(value) < val || + (neg_zero && __longlong_as_double(value) == 0.0))) { + done = + __atomic_compare_exchange_n(uaddr, &value, __double_as_longlong(val), + false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __longlong_as_double(value); +#endif +#endif +} + +__device__ inline double atomicMax_system(double *address, double val) { + unsigned long long *uaddr{reinterpret_cast(address)}; +#if __has_builtin(__hip_atomic_load) + unsigned long long tmp{ + __hip_atomic_load(uaddr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM)}; +#else + unsigned long long tmp{__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; +#endif + double value = __longlong_as_double(tmp); + + while (value < val) { + value = atomicCAS_system(address, value, val); + } + + return value; +} + +__device__ inline unsigned int atomicInc(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, + [](unsigned int &x, unsigned int y) { x = (x >= y) ? 0 : (x + 1); }, + [=]() { + return __builtin_amdgcn_atomic_inc32(address, val, __ATOMIC_RELAXED, + "agent"); + }); +#else + return __builtin_amdgcn_atomic_inc32(address, val, __ATOMIC_RELAXED, "agent"); +#endif // __gfx941__ +} + +__device__ inline unsigned int atomicDec(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, + [](unsigned int &x, unsigned int y) { x = (!x || x > y) ? y : (x - 1); }, + [=]() { + return __builtin_amdgcn_atomic_dec32(address, val, __ATOMIC_RELAXED, + "agent"); + }); +#else + return __builtin_amdgcn_atomic_dec32(address, val, __ATOMIC_RELAXED, "agent"); +#endif // __gfx941__ +} + +__device__ inline int atomicAnd(int *address, int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](int &x, int y) { x &= y; }, + [=]() { + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline int atomicAnd_system(int *address, int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](int &x, int y) { x &= y; }, + [=]() { + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned int atomicAnd(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned int &x, unsigned int y) { x &= y; }, + [=]() { + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned int atomicAnd_system(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned int &x, unsigned int y) { x &= y; }, + [=]() { + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned long atomicAnd(unsigned long *address, + unsigned long val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned long &x, unsigned long y) { x &= y; }, + [=]() { + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned long atomicAnd_system(unsigned long *address, + unsigned long val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned long &x, unsigned long y) { x &= y; }, + [=]() { + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned long long atomicAnd(unsigned long long *address, + unsigned long long val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned long long &x, unsigned long long y) { x &= y; }, + [=]() { + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned long long +atomicAnd_system(unsigned long long *address, unsigned long long val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned long long &x, unsigned long long y) { x &= y; }, + [=]() { + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_and(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline int atomicOr(int *address, int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](int &x, int y) { x |= y; }, + [=]() { + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline int atomicOr_system(int *address, int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](int &x, int y) { x |= y; }, + [=]() { + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned int atomicOr(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned int &x, unsigned int y) { x |= y; }, + [=]() { + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned int atomicOr_system(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned int &x, unsigned int y) { x |= y; }, + [=]() { + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned long atomicOr(unsigned long *address, + unsigned long val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned long &x, unsigned long y) { x |= y; }, + [=]() { + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned long atomicOr_system(unsigned long *address, + unsigned long val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned long &x, unsigned long y) { x |= y; }, + [=]() { + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned long long atomicOr(unsigned long long *address, + unsigned long long val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned long long &x, unsigned long long y) { x |= y; }, + [=]() { + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned long long +atomicOr_system(unsigned long long *address, unsigned long long val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned long long &x, unsigned long long y) { x |= y; }, + [=]() { + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_or(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline int atomicXor(int *address, int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](int &x, int y) { x ^= y; }, + [=]() { + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline int atomicXor_system(int *address, int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](int &x, int y) { x ^= y; }, + [=]() { + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned int atomicXor(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned int &x, unsigned int y) { x ^= y; }, + [=]() { + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned int atomicXor_system(unsigned int *address, + unsigned int val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned int &x, unsigned int y) { x ^= y; }, + [=]() { + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned long atomicXor(unsigned long *address, + unsigned long val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned long &x, unsigned long y) { x ^= y; }, + [=]() { + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned long atomicXor_system(unsigned long *address, + unsigned long val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM>( + address, val, [](unsigned long &x, unsigned long y) { x ^= y; }, + [=]() { + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); + }); +#else + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#endif // __gfx941__ +} + +__device__ inline unsigned long long atomicXor(unsigned long long *address, + unsigned long long val) { +#if defined(__gfx941__) + return hip_cas_expander<__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT>( + address, val, [](unsigned long long &x, unsigned long long y) { x ^= y; }, + [=]() { + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + }); +#else + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#endif // __gfx941__ +} + +__device__ inline unsigned long long +atomicXor_system(unsigned long long *address, unsigned long long val) { + return __hip_atomic_fetch_xor(address, val, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +} + +#else // __hip_atomic_compare_exchange_strong + +__device__ inline int atomicCAS(int *address, int compare, int val) { + __atomic_compare_exchange_n(address, &compare, val, false, __ATOMIC_RELAXED, + __ATOMIC_RELAXED); + + return compare; +} +__device__ inline unsigned int +atomicCAS(unsigned int *address, unsigned int compare, unsigned int val) { + __atomic_compare_exchange_n(address, &compare, val, false, __ATOMIC_RELAXED, + __ATOMIC_RELAXED); + + return compare; +} +__device__ inline unsigned long long atomicCAS(unsigned long long *address, + unsigned long long compare, + unsigned long long val) { + __atomic_compare_exchange_n(address, &compare, val, false, __ATOMIC_RELAXED, + __ATOMIC_RELAXED); + + return compare; +} + +__device__ inline int atomicAdd(int *address, int val) { + return __atomic_fetch_add(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned int atomicAdd(unsigned int *address, + unsigned int val) { + return __atomic_fetch_add(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned long long atomicAdd(unsigned long long *address, + unsigned long long val) { + return __atomic_fetch_add(address, val, __ATOMIC_RELAXED); +} +__device__ inline float atomicAdd(float *address, float val) { +#if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) + return unsafeAtomicAdd(address, val); +#else + return __atomic_fetch_add(address, val, __ATOMIC_RELAXED); +#endif +} + +#if !defined(__HIPCC_RTC__) +DEPRECATED("use atomicAdd instead") +#endif // !defined(__HIPCC_RTC__) +__device__ inline void atomicAddNoRet(float *address, float val) { + __ockl_atomic_add_noret_f32(address, val); +} + +__device__ inline double atomicAdd(double *address, double val) { +#if defined(__AMDGCN_UNSAFE_FP_ATOMICS__) + return unsafeAtomicAdd(address, val); +#else + return __atomic_fetch_add(address, val, __ATOMIC_RELAXED); +#endif +} + +__device__ inline int atomicSub(int *address, int val) { + return __atomic_fetch_sub(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned int atomicSub(unsigned int *address, + unsigned int val) { + return __atomic_fetch_sub(address, val, __ATOMIC_RELAXED); +} + +__device__ inline int atomicExch(int *address, int val) { + return __atomic_exchange_n(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned int atomicExch(unsigned int *address, + unsigned int val) { + return __atomic_exchange_n(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned long long atomicExch(unsigned long long *address, + unsigned long long val) { + return __atomic_exchange_n(address, val, __ATOMIC_RELAXED); +} +__device__ inline float atomicExch(float *address, float val) { + return __uint_as_float( + __atomic_exchange_n(reinterpret_cast(address), + __float_as_uint(val), __ATOMIC_RELAXED)); +} + +__device__ inline int atomicMin(int *address, int val) { + return __atomic_fetch_min(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned int atomicMin(unsigned int *address, + unsigned int val) { + return __atomic_fetch_min(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned long long atomicMin(unsigned long long *address, + unsigned long long val) { + unsigned long long tmp{__atomic_load_n(address, __ATOMIC_RELAXED)}; + while (val < tmp) { + const auto tmp1 = __atomic_load_n(address, __ATOMIC_RELAXED); + + if (tmp1 != tmp) { + tmp = tmp1; + continue; + } + + tmp = atomicCAS(address, tmp, val); + } + + return tmp; +} +__device__ inline long long atomicMin(long long *address, long long val) { + long long tmp{__atomic_load_n(address, __ATOMIC_RELAXED)}; + while (val < tmp) { + const auto tmp1 = __atomic_load_n(address, __ATOMIC_RELAXED); + + if (tmp1 != tmp) { + tmp = tmp1; + continue; + } + + tmp = atomicCAS(address, tmp, val); + } + return tmp; +} + +__device__ inline int atomicMax(int *address, int val) { + return __atomic_fetch_max(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned int atomicMax(unsigned int *address, + unsigned int val) { + return __atomic_fetch_max(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned long long atomicMax(unsigned long long *address, + unsigned long long val) { + unsigned long long tmp{__atomic_load_n(address, __ATOMIC_RELAXED)}; + while (tmp < val) { + const auto tmp1 = __atomic_load_n(address, __ATOMIC_RELAXED); + + if (tmp1 != tmp) { + tmp = tmp1; + continue; + } + + tmp = atomicCAS(address, tmp, val); + } + + return tmp; +} +__device__ inline long long atomicMax(long long *address, long long val) { + long long tmp{__atomic_load_n(address, __ATOMIC_RELAXED)}; + while (tmp < val) { + const auto tmp1 = __atomic_load_n(address, __ATOMIC_RELAXED); + + if (tmp1 != tmp) { + tmp = tmp1; + continue; + } + + tmp = atomicCAS(address, tmp, val); + } + return tmp; +} + +__device__ inline unsigned int atomicInc(unsigned int *address, + unsigned int val) { + return __builtin_amdgcn_atomic_inc32(address, val, __ATOMIC_RELAXED, "agent"); +} + +__device__ inline unsigned int atomicDec(unsigned int *address, + unsigned int val) { + return __builtin_amdgcn_atomic_dec32(address, val, __ATOMIC_RELAXED, "agent"); +} + +__device__ inline int atomicAnd(int *address, int val) { + return __atomic_fetch_and(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned int atomicAnd(unsigned int *address, + unsigned int val) { + return __atomic_fetch_and(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned long long atomicAnd(unsigned long long *address, + unsigned long long val) { + return __atomic_fetch_and(address, val, __ATOMIC_RELAXED); +} + +__device__ inline int atomicOr(int *address, int val) { + return __atomic_fetch_or(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned int atomicOr(unsigned int *address, + unsigned int val) { + return __atomic_fetch_or(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned long long atomicOr(unsigned long long *address, + unsigned long long val) { + return __atomic_fetch_or(address, val, __ATOMIC_RELAXED); +} + +__device__ inline int atomicXor(int *address, int val) { + return __atomic_fetch_xor(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned int atomicXor(unsigned int *address, + unsigned int val) { + return __atomic_fetch_xor(address, val, __ATOMIC_RELAXED); +} +__device__ inline unsigned long long atomicXor(unsigned long long *address, + unsigned long long val) { + return __atomic_fetch_xor(address, val, __ATOMIC_RELAXED); +} + +#endif // __hip_atomic_compare_exchange_strong diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_bf16.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_bf16.h new file mode 100644 index 000000000..734b8c73f --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_bf16.h @@ -0,0 +1,1968 @@ +/** + * MIT License + * + * Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * \file + * \brief hip_bf16.h provides struct for __hip_bfloat16 types + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT16 bfloat16 Precision Intrinsics + * This section describes hip_bfloat16 precision intrinsic functions. + * To use these functions, include the header file \p hip_bf16.h in your + * program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT16_ARITH Bfloat16 Arithmetic Functions + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your + * program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT16_COMP Bfloat16 Comparision Functions + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your + * program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT162_COMP Bfloat162 Comparision Functions + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your + * program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT162_ARITH Bfloat162 Arithmetic Functions + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your + * program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT16_CONV Bfloat16 Conversion Functions + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your + * program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT162_CONV Bfloat162 Conversion Functions + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your + * program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT16_MATH Bfloat16 Math Functions + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your + * program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT162_MATH Bfloat162 Math Functions + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your + * program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT16_RAW Bfloat16 Raw Struct + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your + * program. + */ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT162_RAW Bfloat162 Raw Struct + * \ingroup HIP_INTRINSIC_BFLOAT16 + * To use these functions, include the header file \p hip_bf16.h in your + * program. + */ + +#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_ +#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_ + +#if !defined(__HIPCC_RTC__) +#include +#endif // !defined(__HIPCC_RTC__) + +#include "amd_hip_vector_types.h" // float2 etc +#include "device_library_decls.h" // ocml conversion functions +#include "math_fwd.h" // ocml device functions + +#define __BF16_DEVICE__ __device__ +#if defined(__HIPCC_RTC__) +#define __BF16_HOST_DEVICE__ __BF16_DEVICE__ +#else +#include +#include +#include +#define __BF16_HOST_DEVICE__ __host__ __BF16_DEVICE__ +#endif +#define __BF16_DEVICE_STATIC__ __BF16_DEVICE__ static inline +#define __BF16_HOST_DEVICE_STATIC__ __BF16_HOST_DEVICE__ static inline + +#if defined(__AVX512VL__) and defined(__AVX512BF16__) and \ + not defined(__HIP_DEVICE_COMPILE__) +// Enable with -mavx512vl -mavx512bf16 +#if defined(__MINGW64__) +#include +#else +#include +#endif +#define HIP_BF16_AVX512_OP 1 +static_assert(sizeof(__bf16) == sizeof(unsigned short), + "sizeof __bf16 should match sizeof unsigned short"); +#else +#define HIP_BF16_AVX512_OP 0 +#endif + +#define HIPRT_ONE_BF16 __float2bfloat16(1.0f) +#define HIPRT_ZERO_BF16 __float2bfloat16(0.0f) +#define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) +#define HIPRT_MAX_NORMAL_BF16 __ushort_as_bfloat16((unsigned short)0x7F7FU) +#define HIPRT_MIN_DENORM_BF16 __ushort_as_bfloat16((unsigned short)0x0001U) +#define HIPRT_NAN_BF16 __ushort_as_bfloat16((unsigned short)0x7FFFU) +#define HIPRT_NEG_ZERO_BF16 __ushort_as_bfloat16((unsigned short)0x8000U) + +// Since we are using unsigned short to represent data in bfloat16, it can be of +// different sizes on different machines. These naive checks should prevent some +// undefined behavior on systems which have different sizes for basic types. +#if !defined(__HIPCC_RTC__) +static_assert(CHAR_BIT == 8, "byte size should be of 8 bits"); +#endif +static_assert(sizeof(unsigned short) == 2, + "size of unsigned short should be 2 bytes"); + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_RAW + * \brief represents raw bfloat16 type + */ +typedef struct __attribute__((aligned(2))) { + unsigned short x; +} __hip_bfloat16_raw; + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_RAW + * \brief represents raw bfloat16x2 vector type + */ +typedef struct __attribute__((aligned(4))) { + unsigned short x; + unsigned short y; +} __hip_bfloat162_raw; + +/** + * \defgroup HIP_INTRINSIC_BFLOAT16_STRUCT + * \ingroup HIP_INTRINSIC_BFLOAT16 + * \brief Struct to represent a 16 bit brain floating point number. + * @{ + */ +struct __attribute__((aligned(2))) __hip_bfloat16 { +private: + __BF16_HOST_DEVICE_STATIC__ float bfloatraw_2_float(unsigned short val) { +#if HIP_BF16_AVX512_OP + union { + unsigned short us; + __bf16 bf16; + } u = {val}; + return _mm_cvtsbh_ss(u.bf16); +#else + unsigned int uval = val << 16; + union { + unsigned int u32; + float fp32; + } u = {uval}; + return u.fp32; +#endif + } + __BF16_HOST_DEVICE_STATIC__ unsigned short float_2_bfloatraw(float f) { +#if HIP_BF16_AVX512_OP + union { + __bf16 bf16; + unsigned short us; + } u = {_mm_cvtness_sbh(f)}; + return u.us; +#else + union { + float fp32; + unsigned int u32; + } u = {f}; + if (~u.u32 & 0x7f800000) { + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // Round to nearest, round to even + } else if (u.u32 & 0xffff) { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. + u.u32 |= 0x10000; // Preserve signaling NaN + } + return static_cast(u.u32 >> 16); +#endif + } + + __BF16_HOST_DEVICE_STATIC__ unsigned short double_2_bfloatraw(double d_in) { + union { + float fp32; + unsigned int u32; + } u = {static_cast(d_in)}; + double d = u.fp32; + + // Round to odd + if ((d_in > 0.0 && d > d_in) || (d_in < 0.0 && d < d_in)) { + u.u32--; + u.u32 |= 1; + } + + return float_2_bfloatraw(u.fp32); + } + +protected: + /*! \brief raw representation of bfloat16 */ + unsigned short __x; + +public: + // TODO: SWDEV-452411 + // Need to add constructor of __hip_bfloat16 from + // unsigned long long + // long long + // long + // unsigned long + // Casting directly to double might lead to double rounding. + + /*! \brief create __hip_bfloat16 from an unsigned int */ + __BF16_HOST_DEVICE__ __hip_bfloat16(unsigned int val) + : __x(double_2_bfloatraw(static_cast(val))) {} + + /*! \brief create __hip_bfloat16 from a int */ + __BF16_HOST_DEVICE__ __hip_bfloat16(int val) + : __x(double_2_bfloatraw(static_cast(val))) {} + + /*! \brief create __hip_bfloat16 from an unsigned short */ + __BF16_HOST_DEVICE__ __hip_bfloat16(unsigned short val) + : __x(float_2_bfloatraw(static_cast(val))) {} + + /*! \brief create __hip_bfloat16 from a short */ + __BF16_HOST_DEVICE__ __hip_bfloat16(short val) + : __x(float_2_bfloatraw(static_cast(val))) {} + + /*! \brief create __hip_bfloat16 from a double */ + __BF16_HOST_DEVICE__ __hip_bfloat16(const double val) + : __x(double_2_bfloatraw(val)) {} + + /*! \brief create __hip_bfloat16 from a float */ + __BF16_HOST_DEVICE__ __hip_bfloat16(const float val) + : __x(float_2_bfloatraw(val)) {} + + /*! \brief create __hip_bfloat16 from a __hip_bfloat16_raw */ + __BF16_HOST_DEVICE__ __hip_bfloat16(const __hip_bfloat16_raw &val) + : __x(val.x) {} + + /*! \brief default constructor */ + __BF16_HOST_DEVICE__ __hip_bfloat16() = default; + + /*! \brief return a __hip_bfloat16_raw */ + __BF16_HOST_DEVICE__ operator __hip_bfloat16_raw() const { + return __hip_bfloat16_raw{__x}; + } + + /*! \brief return a __hip_bfloat16_raw cv qualifier */ + __BF16_HOST_DEVICE__ operator __hip_bfloat16_raw() const volatile { + return __hip_bfloat16_raw{__x}; + } + + /*! \brief return false if bfloat value is +0.0 or -0.0, returns true + * otherwise */ + __BF16_HOST_DEVICE__ operator bool() const { + auto val = bfloatraw_2_float(__x); + return val != 0.0f && val != -0.0f; + } + + /*! \brief return a casted char from underlying float val */ + __BF16_HOST_DEVICE__ operator char() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a float */ + __BF16_HOST_DEVICE__ operator float() const { return bfloatraw_2_float(__x); } + + /*! \brief return a casted int casted from float of underlying bfloat16 value + */ + __BF16_HOST_DEVICE__ operator int() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted long casted from float of underlying bfloat16 value + */ + __BF16_HOST_DEVICE__ operator long() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted long long casted from float of underlying bfloat16 + * value */ + __BF16_HOST_DEVICE__ operator long long() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted short casted from float of underlying bfloat16 + * value */ + __BF16_HOST_DEVICE__ operator short() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted signed char from float of underlying bfloat16 value + */ + __BF16_HOST_DEVICE__ operator signed char() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted unsigned char casted from float of underlying + * bfloat16 value */ + __BF16_HOST_DEVICE__ operator unsigned char() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted unsigned int casted from float of underlying + * bfloat16 value */ + __BF16_HOST_DEVICE__ operator unsigned int() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted unsigned from float of underlying bfloat16 value */ + __BF16_HOST_DEVICE__ operator unsigned long() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted unsigned long long from float of underlying + * bfloat16 value */ + __BF16_HOST_DEVICE__ operator unsigned long long() const { + return static_cast(bfloatraw_2_float(__x)); + } + + /*! \brief return a casted unsigned short from float of underlying bfloat16 + * value */ + __BF16_HOST_DEVICE__ operator unsigned short() const { + return static_cast(bfloatraw_2_float(__x)); + } + + // TODO: SWDEV-452411 add operator which converts unsigned long long and long + // long to bfloat + + /*! \brief assign value from an unsigned int */ + __BF16_HOST_DEVICE__ __hip_bfloat16 &operator=(unsigned int val) { + __x = float_2_bfloatraw(static_cast(val)); + return *this; + } + + /*! \brief assign value from a int */ + __BF16_HOST_DEVICE__ __hip_bfloat16 &operator=(int val) { + __x = float_2_bfloatraw(static_cast(val)); + return *this; + } + + /*! \brief assign value from an unsigned short */ + __BF16_HOST_DEVICE__ __hip_bfloat16 &operator=(unsigned short val) { + __x = float_2_bfloatraw(static_cast(val)); + return *this; + } + + /*! \brief assign value from a short int */ + __BF16_HOST_DEVICE__ __hip_bfloat16 &operator=(short val) { + __x = float_2_bfloatraw(static_cast(val)); + return *this; + } + + /*! \brief assign value from a double */ + __BF16_HOST_DEVICE__ __hip_bfloat16 &operator=(const double f) { + __x = float_2_bfloatraw(static_cast(f)); + return *this; + } + + /*! \brief assign value from a float */ + __BF16_HOST_DEVICE__ __hip_bfloat16 &operator=(const float f) { + __x = float_2_bfloatraw(f); + return *this; + } + + /*! \brief assign value from a __hip_bfloat16_raw */ + __BF16_HOST_DEVICE__ __hip_bfloat16 &operator=(const __hip_bfloat16_raw &hr) { + __x = hr.x; + return *this; + } + + /*! \brief assign value from a __hip_bfloat16_raw volatile */ + __BF16_HOST_DEVICE__ volatile __hip_bfloat16 & + operator=(const __hip_bfloat16_raw &hr) volatile { + __x = hr.x; + return *this; + } + + /*! \brief assign value from a __hip_bfloat16_raw cv qualifier */ + __BF16_HOST_DEVICE__ volatile __hip_bfloat16 & + operator=(const volatile __hip_bfloat16_raw &hr) volatile { + __x = hr.x; + return *this; + } +}; +/**@}*/ + +/** + * \defgroup HIP_INTRINSIC_BFLOAT162_STRUCT + * \ingroup HIP_INTRINSIC_BFLOAT16 + * \brief Struct to represent a two 16 bit brain floating point number. + * @{ + */ +struct __attribute__((aligned(4))) __hip_bfloat162 { +public: + __hip_bfloat16 x; /*! \brief raw representation of bfloat16 */ + __hip_bfloat16 y; /*! \brief raw representation of bfloat16 */ + +public: + /*! \brief create __hip_bfloat162 from __hip_bfloat162_raw */ + __BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162_raw &h2r) + : x(__hip_bfloat16(__hip_bfloat16_raw{h2r.x})), + y(__hip_bfloat16(__hip_bfloat16_raw{h2r.y})) {} + + /*! \brief copy constructor of __hip_bfloat162 */ + __BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat162 &val) { + __hip_bfloat162_raw hr = val; + x = __hip_bfloat16_raw{hr.x}; + y = __hip_bfloat16_raw{hr.y}; + } + + /*! \brief create __hip_bfloat162 from two __hip_bfloat16 */ + __BF16_HOST_DEVICE__ __hip_bfloat162(const __hip_bfloat16 &a, + const __hip_bfloat16 &b) + : x(a), y(b) {} + + /*! \brief default constructor of __hip_bfloat162 */ + __BF16_HOST_DEVICE__ __hip_bfloat162() = default; + + /*! \brief return a __hip_bfloat162_raw */ + __BF16_HOST_DEVICE__ operator __hip_bfloat162_raw() const { + __hip_bfloat16_raw l = x; + __hip_bfloat16_raw r = y; + return __hip_bfloat162_raw{l.x, r.x}; + } + + /*! \brief return a float2 */ + __BF16_HOST_DEVICE__ operator float2() const { +#if HIP_BF16_AVX512_OP + union { + __hip_bfloat162_raw raw2; + __bf16 bf162[2]; + static_assert(sizeof(__bf16[2]) == sizeof(__hip_bfloat162_raw)); + } u; + u.raw2 = *this; + __m128bh pbf16{u.bf162[0], u.bf162[1], 0, 0}; + __m128 pf32 = _mm_cvtpbh_ps(pbf16); + float2 ret(pf32[0], pf32[1]); +#else + float2 ret(x, y); +#endif + return ret; + } + + /*! \brief assign value from __hip_bfloat162_raw */ + __BF16_HOST_DEVICE__ __hip_bfloat162 & + operator=(const __hip_bfloat162_raw &h2r) { + x = __hip_bfloat16(__hip_bfloat16_raw{h2r.x}); + y = __hip_bfloat16(__hip_bfloat16_raw{h2r.y}); + return *this; + } + + /*! \brief assign value from __hip_bfloat162 */ + __BF16_HOST_DEVICE__ __hip_bfloat162 &operator=(const __hip_bfloat162 &src) { + __hip_bfloat162_raw hr = src; + x = __hip_bfloat16(__hip_bfloat16_raw{hr.x}); + y = __hip_bfloat16(__hip_bfloat16_raw{hr.y}); + return *this; + } +}; +/**@}*/ + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_CONV + * \brief Converts bfloat16 to float + */ +__BF16_HOST_DEVICE_STATIC__ float __bfloat162float(__hip_bfloat16 a) { + float ret = a; + return ret; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_CONV + * \brief Converts float to bfloat16 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __float2bfloat16(float f) { + __hip_bfloat16 ret{f}; + return ret; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Converts and moves bfloat162 to float2 + */ +__BF16_HOST_DEVICE_STATIC__ float2 __bfloat1622float2(const __hip_bfloat162 a) { + float2 ret = a; + return ret; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Moves bfloat16 value to bfloat162 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +__bfloat162bfloat162(const __hip_bfloat16 a) { + return __hip_bfloat162(a, a); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets bits in a __hip_bfloat16 as a signed short integer + */ +__BF16_HOST_DEVICE_STATIC__ short int +__bfloat16_as_short(const __hip_bfloat16 h) { + short ret = h; + return ret; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets bits in a __hip_bfloat16 as an unsigned signed short + * integer + */ +__BF16_HOST_DEVICE_STATIC__ unsigned short int +__bfloat16_as_ushort(const __hip_bfloat16 h) { + unsigned short ret = h; + return ret; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Convert double to __hip_bfloat16 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __double2bfloat16(const double a) { + __hip_bfloat16 ret{a}; + return ret; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Convert float2 to __hip_bfloat162 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +__float22bfloat162_rn(const float2 a) { + return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)}; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Combine two __hip_bfloat16 to __hip_bfloat162 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +__halves2bfloat162(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __hip_bfloat162(a, b); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns high 16 bits of __hip_bfloat162 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 +__high2bfloat16(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __hip_bfloat16(__hip_bfloat16_raw{hr.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns high 16 bits of __hip_bfloat162 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +__high2bfloat162(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __hip_bfloat162(__hip_bfloat16_raw{hr.y}, __hip_bfloat16_raw{hr.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Converts high 16 bits of __hip_bfloat162 to float and returns the + * result + */ +__BF16_HOST_DEVICE_STATIC__ float __high2float(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Extracts high 16 bits from each and combines them + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +__highs2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162(__hip_bfloat162_raw{hr_a.y, hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns low 16 bits of __hip_bfloat162 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 +__low2bfloat16(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __hip_bfloat16(hr.x); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns low 16 bits of __hip_bfloat162 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +__low2bfloat162(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __hip_bfloat162(hr.x, hr.x); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Converts low 16 bits of __hip_bfloat162 to float and returns the + * result + */ +__BF16_HOST_DEVICE_STATIC__ float __low2float(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __bfloat162float(__hip_bfloat16(__hip_bfloat16_raw{hr.x})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Swaps both halves + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +__lowhigh2highlow(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr = a; + return __hip_bfloat162(__hip_bfloat162_raw{hr.y, hr.x}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Extracts low 16 bits from each and combines them + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +__lows2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162(__hip_bfloat162_raw{hr_a.x, hr_b.x}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets short int into a bfloat16 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 +__short_as_bfloat16(const short int a) { + return __hip_bfloat16(a); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets unsigned short int into a bfloat16 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 +__ushort_as_bfloat16(const unsigned short int a) { + return __hip_bfloat16(a); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Adds two bfloat16 values + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Subtracts two bfloat16 values + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Divides two bfloat16 values + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) / __bfloat162float(b)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Performs FMA of given bfloat16 values + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 __hfma(const __hip_bfloat16 a, + const __hip_bfloat16 b, + const __hip_bfloat16 c) { + return __float2bfloat16(__ocml_fma_f32( + __bfloat162float(a), __bfloat162float(b), __bfloat162float(c))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Multiplies two bfloat16 values + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Negate a bfloat16 value + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) { + __hip_bfloat16_raw hr = a; + hr.x ^= 0x8000; + return __hip_bfloat16(hr); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Returns absolute of a bfloat16 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __habs(const __hip_bfloat16 a) { + __hip_bfloat16_raw hr = a; + hr.x &= 0x7FFF; + return __hip_bfloat16(hr); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Divides bfloat162 values + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162( + __float2bfloat16(__bfloat162float(__hip_bfloat16_raw{hr_a.x}) / + __bfloat162float(__hip_bfloat16_raw{hr_b.x})), + __float2bfloat16(__bfloat162float(__hip_bfloat16_raw{hr_a.y}) / + __bfloat162float(__hip_bfloat16_raw{hr_b.y}))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Returns absolute of a bfloat162 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr_a = a; + return __hip_bfloat162(__habs(__hip_bfloat16_raw{hr_a.x}), + __habs(__hip_bfloat16_raw{hr_a.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Adds two bfloat162 values + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162( + __hadd(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}), + __hadd(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Performs FMA of given bfloat162 values + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 __hfma2(const __hip_bfloat162 a, + const __hip_bfloat162 b, + const __hip_bfloat162 c) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + __hip_bfloat162_raw hr_c = c; + return __hip_bfloat162( + __hfma(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}, + __hip_bfloat16_raw{hr_c.x}), + __hfma(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}, + __hip_bfloat16_raw{hr_c.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Multiplies two bfloat162 values + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162( + __hmul(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}), + __hmul(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Converts a bfloat162 into negative + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr_a = a; + return __hip_bfloat162(__hneg(__hip_bfloat16_raw{hr_a.x}), + __hneg(__hip_bfloat16_raw{hr_a.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Subtracts two bfloat162 values + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162( + __hsub(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}), + __hsub(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to multiply two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator*(const __hip_bfloat16 &l, + const __hip_bfloat16 &r) { + return __hmul(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to multiply-assign two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & +operator*=(__hip_bfloat16 &l, const __hip_bfloat16 &r) { + l = __hmul(l, r); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to unary+ on a __hip_bfloat16 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator+(const __hip_bfloat16 &l) { + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to add two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator+(const __hip_bfloat16 &l, + const __hip_bfloat16 &r) { + return __hadd(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to negate a __hip_bfloat16 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator-(const __hip_bfloat16 &l) { + return __hneg(l); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to subtract two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator-(const __hip_bfloat16 &l, + const __hip_bfloat16 &r) { + return __hsub(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to post increment a __hip_bfloat16 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator++(__hip_bfloat16 &l, + const int) { + auto ret = l; + l = __hadd(l, HIPRT_ONE_BF16); + return ret; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to pre increment a __hip_bfloat16 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 &operator++(__hip_bfloat16 &l) { + l = __hadd(l, HIPRT_ONE_BF16); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to post decrement a __hip_bfloat16 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator--(__hip_bfloat16 &l, + const int) { + auto ret = l; + l = __hsub(l, HIPRT_ONE_BF16); + return ret; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to pre decrement a __hip_bfloat16 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 &operator--(__hip_bfloat16 &l) { + l = __hsub(l, HIPRT_ONE_BF16); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to add-assign two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & +operator+=(__hip_bfloat16 &l, const __hip_bfloat16 &r) { + l = __hadd(l, r); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to subtract-assign two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & +operator-=(__hip_bfloat16 &l, const __hip_bfloat16 &r) { + l = __hsub(l, r); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to divide two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator/(const __hip_bfloat16 &l, + const __hip_bfloat16 &r) { + return __hdiv(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Operator to divide-assign two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & +operator/=(__hip_bfloat16 &l, const __hip_bfloat16 &r) { + l = __hdiv(l, r); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to multiply two __hip_bfloat162 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +operator*(const __hip_bfloat162 &l, const __hip_bfloat162 &r) { + return __hmul2(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to multiply-assign two __hip_bfloat162 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 & +operator*=(__hip_bfloat162 &l, const __hip_bfloat162 &r) { + l = __hmul2(l, r); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to unary+ on a __hip_bfloat162 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +operator+(const __hip_bfloat162 &l) { + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to add two __hip_bfloat162 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +operator+(const __hip_bfloat162 &l, const __hip_bfloat162 &r) { + return __hadd2(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to negate a __hip_bfloat162 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +operator-(const __hip_bfloat162 &l) { + return __hneg2(l); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to subtract two __hip_bfloat162 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +operator-(const __hip_bfloat162 &l, const __hip_bfloat162 &r) { + return __hsub2(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to post increment a __hip_bfloat162 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator++(__hip_bfloat162 &l, + const int) { + auto ret = l; + l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); + return ret; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to pre increment a __hip_bfloat162 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 &operator++(__hip_bfloat162 &l) { + l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to post decrement a __hip_bfloat162 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 operator--(__hip_bfloat162 &l, + const int) { + auto ret = l; + l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); + return ret; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to pre decrement a __hip_bfloat162 number + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 &operator--(__hip_bfloat162 &l) { + l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to add-assign two __hip_bfloat162 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 & +operator+=(__hip_bfloat162 &l, const __hip_bfloat162 &r) { + l = __hadd2(l, r); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to subtract-assign two __hip_bfloat162 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 & +operator-=(__hip_bfloat162 &l, const __hip_bfloat162 &r) { + l = __hsub2(l, r); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to divide two __hip_bfloat162 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +operator/(const __hip_bfloat162 &l, const __hip_bfloat162 &r) { + return __h2div(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Operator to divide-assign two __hip_bfloat162 numbers + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 & +operator/=(__hip_bfloat162 &l, const __hip_bfloat162 &r) { + l = __h2div(l, r); + return l; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values + */ +__BF16_HOST_DEVICE_STATIC__ bool __heq(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return __bfloat162float(a) == __bfloat162float(b); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered equal + */ +__BF16_HOST_DEVICE_STATIC__ bool __hequ(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return !(__bfloat162float(a) < __bfloat162float(b)) && + !(__bfloat162float(a) > __bfloat162float(b)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - greater than + */ +__BF16_HOST_DEVICE_STATIC__ bool __hgt(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return __bfloat162float(a) > __bfloat162float(b); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered greater than + */ +__BF16_HOST_DEVICE_STATIC__ bool __hgtu(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return !(__bfloat162float(a) <= __bfloat162float(b)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - greater than equal + */ +__BF16_HOST_DEVICE_STATIC__ bool __hge(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return __bfloat162float(a) >= __bfloat162float(b); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered greater than equal + */ +__BF16_HOST_DEVICE_STATIC__ bool __hgeu(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return !(__bfloat162float(a) < __bfloat162float(b)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - not equal + */ +__BF16_HOST_DEVICE_STATIC__ bool __hne(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return __bfloat162float(a) != __bfloat162float(b); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered not equal + */ +__BF16_HOST_DEVICE_STATIC__ bool __hneu(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return !(__bfloat162float(a) == __bfloat162float(b)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - return max + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, + const __hip_bfloat16 b) { +#if __HIP_DEVICE_COMPILE__ + return __float2bfloat16( + __ocml_fmax_f32(__bfloat162float(a), __bfloat162float(b))); +#else + return __float2bfloat16(std::max(__bfloat162float(a), __bfloat162float(b))); +#endif +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - return min + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, + const __hip_bfloat16 b) { +#if __HIP_DEVICE_COMPILE__ + return __float2bfloat16( + __ocml_fmin_f32(__bfloat162float(a), __bfloat162float(b))); +#else + return __float2bfloat16(std::min(__bfloat162float(a), __bfloat162float(b))); +#endif +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - less than operator + */ +__BF16_HOST_DEVICE_STATIC__ bool __hlt(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return __bfloat162float(a) < __bfloat162float(b); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered less than + */ +__BF16_HOST_DEVICE_STATIC__ bool __hltu(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return !(__bfloat162float(a) >= __bfloat162float(b)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - less than equal + */ +__BF16_HOST_DEVICE_STATIC__ bool __hle(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return __bfloat162float(a) <= __bfloat162float(b); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered less than equal + */ +__BF16_HOST_DEVICE_STATIC__ bool __hleu(const __hip_bfloat16 a, + const __hip_bfloat16 b) { + return !(__bfloat162float(a) > __bfloat162float(b)); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Checks if number is inf + */ +__BF16_HOST_DEVICE_STATIC__ int __hisinf(const __hip_bfloat16 a) { + __hip_bfloat16_raw hr = a; + return !(~hr.x & 0x7f80) && !(hr.x & 0x7f); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Checks if number is nan + */ +__BF16_HOST_DEVICE_STATIC__ bool __hisnan(const __hip_bfloat16 a) { + __hip_bfloat16_raw hr = a; + return !(~hr.x & 0x7f80) && +(hr.x & 0x7f); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Checks if two numbers are equal + */ +__BF16_HOST_DEVICE_STATIC__ bool __hbeq2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __heq(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __heq(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Checks if two numbers are equal - unordered + */ +__BF16_HOST_DEVICE_STATIC__ bool __hbequ2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hequ(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hequ(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a >= b + */ +__BF16_HOST_DEVICE_STATIC__ bool __hbge2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hge(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hge(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a >= b - unordered + */ +__BF16_HOST_DEVICE_STATIC__ bool __hbgeu2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hgeu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hgeu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a > b + */ +__BF16_HOST_DEVICE_STATIC__ bool __hbgt2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hgt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hgt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a > b - unordered + */ +__BF16_HOST_DEVICE_STATIC__ bool __hbgtu2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hgtu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hgtu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a <= b + */ +__BF16_HOST_DEVICE_STATIC__ bool __hble2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hle(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hle(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a <= b - unordered + */ +__BF16_HOST_DEVICE_STATIC__ bool __hbleu2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hleu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hleu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a < b + */ +__BF16_HOST_DEVICE_STATIC__ bool __hblt2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hlt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hlt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a < b - unordered + */ +__BF16_HOST_DEVICE_STATIC__ bool __hbltu2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hltu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) && + __hltu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a != b + */ +__BF16_HOST_DEVICE_STATIC__ bool __hbne2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hne(__hip_bfloat16(__hip_bfloat16_raw{hr_a.x}), + __hip_bfloat16(__hip_bfloat16_raw{hr_b.x})) && + __hne(__hip_bfloat16(__hip_bfloat16_raw{hr_a.y}), + __hip_bfloat16(__hip_bfloat16_raw{hr_b.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a != b + */ +__BF16_HOST_DEVICE_STATIC__ bool __hbneu2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hneu(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) || + __hneu(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a != b, returns 1.0 if equal, otherwise 0.0 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__heq(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) + ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__heq(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) + ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}}; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a >= b, returns 1.0 if greater than equal, otherwise 0.0 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__hge(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) + ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__hge(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) + ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}}; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a > b, returns 1.0 if greater than equal, otherwise 0.0 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__hgt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) + ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__hgt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) + ? HIPRT_ONE_BF16 + : HIPRT_ONE_BF16}}; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a is NaN, returns 1.0 if NaN, otherwise 0.0 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) { + __hip_bfloat162_raw hr_a = a; + return __hip_bfloat162{ + {__hisnan(__hip_bfloat16_raw{hr_a.x}) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, + {__hisnan(__hip_bfloat16_raw{hr_a.y}) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}}; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a <= b, returns 1.0 if greater than equal, otherwise 0.0 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__hle(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) + ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__hle(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) + ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}}; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a < b, returns 1.0 if greater than equal, otherwise 0.0 + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__hlt(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) + ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__hlt(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) + ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}}; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Returns max of two elements + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162( + __hmax(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}), + __hmax(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Returns min of two elements + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162( + __hmin(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}), + __hmin(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Checks for not equal to + */ +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, + const __hip_bfloat162 b) { + __hip_bfloat162_raw hr_a = a; + __hip_bfloat162_raw hr_b = b; + return __hip_bfloat162{ + {__hne(__hip_bfloat16_raw{hr_a.x}, __hip_bfloat16_raw{hr_b.x}) + ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}, + {__hne(__hip_bfloat16_raw{hr_a.y}, __hip_bfloat16_raw{hr_b.y}) + ? HIPRT_ONE_BF16 + : HIPRT_ZERO_BF16}}; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator==(const __hip_bfloat16 &l, + const __hip_bfloat16 &r) { + return __heq(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Operator to perform a not equal on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator!=(const __hip_bfloat16 &l, + const __hip_bfloat16 &r) { + return __hne(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Operator to perform a less than on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator<(const __hip_bfloat16 &l, + const __hip_bfloat16 &r) { + return __hlt(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator<=(const __hip_bfloat16 &l, + const __hip_bfloat16 &r) { + return __hle(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Operator to perform a greater than on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator>(const __hip_bfloat16 &l, + const __hip_bfloat16 &r) { + return __hgt(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat16 &l, + const __hip_bfloat16 &r) { + return __hge(l, r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator==(const __hip_bfloat162 &l, + const __hip_bfloat162 &r) { + float2 ret = __heq2(l, r); + return ret.x != 0.0f && ret.y != 0.0f; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Operator to perform a not equal on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator!=(const __hip_bfloat162 &l, + const __hip_bfloat162 &r) { + return !(l == r); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Operator to perform a less than on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator<(const __hip_bfloat162 &l, + const __hip_bfloat162 &r) { + float2 fl = l, fr = r; + return fl.x < fr.x && fl.x < fr.y; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator<=(const __hip_bfloat162 &l, + const __hip_bfloat162 &r) { + float2 fl = l, fr = r; + return fl.x <= fr.x && fl.x <= fr.y; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Operator to perform a greater than on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator>(const __hip_bfloat162 &l, + const __hip_bfloat162 &r) { + float2 fl = l, fr = r; + return fl.x > fr.x && fl.x > fr.y; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers + */ +__BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat162 &l, + const __hip_bfloat162 &r) { + float2 fl = l, fr = r; + return fl.x >= fr.x && fl.x >= fr.y; +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate ceil of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hceil(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_ceil_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate cosine of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hcos(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_cos_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate exponential of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hexp(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_exp_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate exponential 10 of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hexp10(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_exp10_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate exponential 2 of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hexp2(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_exp2_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate floor of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hfloor(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_floor_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate natural log of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hlog(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_log_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate log 10 of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hlog10(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_log10_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate log 2 of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hlog2(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_log2_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate reciprocal + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hrcp(const __hip_bfloat16 h) { + return __float2bfloat16(1.0f / (__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Round to nearest int + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hrint(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_rint_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Reciprocal square root + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hrsqrt(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_rsqrt_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate sin of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hsin(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_sin_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate sqrt of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 hsqrt(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_sqrt_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT16_MATH + * \brief Calculate truncate of bfloat16 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat16 htrunc(const __hip_bfloat16 h) { + return __float2bfloat16(__ocml_trunc_f32(__bfloat162float(h))); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate ceil of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2ceil(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hceil(__hip_bfloat16_raw{hr.x}), + hceil(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate cosine of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2cos(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hcos(__hip_bfloat16_raw{hr.x}), + hcos(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate exponential of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hexp(__hip_bfloat16_raw{hr.x}), + hexp(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate exponential 10 of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp10(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hexp10(__hip_bfloat16_raw{hr.x}), + hexp10(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate exponential 2 of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2exp2(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hexp2(__hip_bfloat16_raw{hr.x}), + hexp2(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate floor of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2floor(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hfloor(__hip_bfloat16_raw{hr.x}), + hfloor(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate natural log of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2log(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hlog(__hip_bfloat16_raw{hr.x}), + hlog(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate log 10 of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2log10(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hlog10(__hip_bfloat16_raw{hr.x}), + hlog10(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate log 2 of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2log2(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hlog2(__hip_bfloat16_raw{hr.x}), + hlog2(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate vector reciprocal + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2rcp(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hrcp(__hip_bfloat16_raw{hr.x}), + hrcp(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate vector round to nearest int + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2rint(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hrint(__hip_bfloat16_raw{hr.x}), + hrint(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate vector reciprocal square root + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2rsqrt(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hrsqrt(__hip_bfloat16_raw{hr.x}), + hrsqrt(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate sin of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2sin(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hsin(__hip_bfloat16_raw{hr.x}), + hsin(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate sqrt of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2sqrt(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(hsqrt(__hip_bfloat16_raw{hr.x}), + hsqrt(__hip_bfloat16_raw{hr.y})); +} + +/** + * \ingroup HIP_INTRINSIC_BFLOAT162_MATH + * \brief Calculate truncate of bfloat162 + */ +__BF16_DEVICE_STATIC__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) { + __hip_bfloat162_raw hr = h; + return __hip_bfloat162(htrunc(__hip_bfloat16_raw{hr.x}), + htrunc(__hip_bfloat16_raw{hr.y})); +} +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_bfloat16.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_bfloat16.h new file mode 100644 index 000000000..a99609f6e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_bfloat16.h @@ -0,0 +1,250 @@ +/** + * MIT License + * + * Copyright (c) 2019 - 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/*!\file + * \brief hip_bfloat16.h provides struct for hip_bfloat16 typedef + */ + +#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_ +#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_ + +#include "host_defines.h" +#if defined(__HIPCC_RTC__) +#define __HOST_DEVICE__ __device__ +#else +#define __HOST_DEVICE__ __host__ __device__ +#endif + +#if __cplusplus < 201103L || !defined(__HIPCC__) + +// If this is a C compiler, C++ compiler below C++11, or a host-only compiler, +// we only include a minimal definition of hip_bfloat16 + +#include +/*! \brief Struct to represent a 16 bit brain floating point number. */ +typedef struct { + uint16_t data; +} hip_bfloat16; + +#else // __cplusplus < 201103L || !defined(__HIPCC__) + +#include + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wshadow" +struct hip_bfloat16 { + __hip_uint16_t data; + + enum truncate_t { truncate }; + + __HOST_DEVICE__ hip_bfloat16() = default; + + // round upper 16 bits of IEEE float to convert to bfloat16 + explicit __HOST_DEVICE__ hip_bfloat16(float f) : data(float_to_bfloat16(f)) {} + + explicit __HOST_DEVICE__ hip_bfloat16(float f, truncate_t) + : data(truncate_float_to_bfloat16(f)) {} + + // zero extend lower 16 bits of bfloat16 to convert to IEEE float + __HOST_DEVICE__ operator float() const { + union { + uint32_t int32; + float fp32; + } u = {uint32_t(data) << 16}; + return u.fp32; + } + + __HOST_DEVICE__ hip_bfloat16 &operator=(const float &f) { + data = float_to_bfloat16(f); + return *this; + } + + static __HOST_DEVICE__ hip_bfloat16 round_to_bfloat16(float f) { + hip_bfloat16 output; + output.data = float_to_bfloat16(f); + return output; + } + + static __HOST_DEVICE__ hip_bfloat16 round_to_bfloat16(float f, truncate_t) { + hip_bfloat16 output; + output.data = truncate_float_to_bfloat16(f); + return output; + } + +private: + static __HOST_DEVICE__ __hip_uint16_t float_to_bfloat16(float f) { + union { + float fp32; + uint32_t int32; + } u = {f}; + if (~u.int32 & 0x7f800000) { + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.int32 += + 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even + } else if (u.int32 & 0xffff) { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. + u.int32 |= 0x10000; // Preserve signaling NaN + } + return __hip_uint16_t(u.int32 >> 16); + } + + // Truncate instead of rounding, preserving SNaN + static __HOST_DEVICE__ __hip_uint16_t truncate_float_to_bfloat16(float f) { + union { + float fp32; + uint32_t int32; + } u = {f}; + return __hip_uint16_t(u.int32 >> 16) | + (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff)); + } +}; +#pragma clang diagnostic pop + +typedef struct { + __hip_uint16_t data; +} hip_bfloat16_public; + +static_assert(__hip_internal::is_standard_layout{}, + "hip_bfloat16 is not a standard layout type, and thus is " + "incompatible with C."); + +static_assert(__hip_internal::is_trivial{}, + "hip_bfloat16 is not a trivial type, and thus is " + "incompatible with C."); +#if !defined(__HIPCC_RTC__) +static_assert(sizeof(hip_bfloat16) == sizeof(hip_bfloat16_public) && + offsetof(hip_bfloat16, data) == + offsetof(hip_bfloat16_public, data), + "internal hip_bfloat16 does not match public hip_bfloat16"); + +inline std::ostream &operator<<(std::ostream &os, const hip_bfloat16 &bf16) { + return os << float(bf16); +} +#endif + +inline __HOST_DEVICE__ hip_bfloat16 operator+(hip_bfloat16 a) { return a; } +inline __HOST_DEVICE__ hip_bfloat16 operator-(hip_bfloat16 a) { + a.data ^= 0x8000; + return a; +} +inline __HOST_DEVICE__ hip_bfloat16 operator+(hip_bfloat16 a, hip_bfloat16 b) { + return hip_bfloat16(float(a) + float(b)); +} +inline __HOST_DEVICE__ hip_bfloat16 operator-(hip_bfloat16 a, hip_bfloat16 b) { + return hip_bfloat16(float(a) - float(b)); +} +inline __HOST_DEVICE__ hip_bfloat16 operator*(hip_bfloat16 a, hip_bfloat16 b) { + return hip_bfloat16(float(a) * float(b)); +} +inline __HOST_DEVICE__ hip_bfloat16 operator/(hip_bfloat16 a, hip_bfloat16 b) { + return hip_bfloat16(float(a) / float(b)); +} +inline __HOST_DEVICE__ bool operator<(hip_bfloat16 a, hip_bfloat16 b) { + return float(a) < float(b); +} +inline __HOST_DEVICE__ bool operator==(hip_bfloat16 a, hip_bfloat16 b) { + return float(a) == float(b); +} +inline __HOST_DEVICE__ bool operator>(hip_bfloat16 a, hip_bfloat16 b) { + return b < a; +} +inline __HOST_DEVICE__ bool operator<=(hip_bfloat16 a, hip_bfloat16 b) { + return !(a > b); +} +inline __HOST_DEVICE__ bool operator!=(hip_bfloat16 a, hip_bfloat16 b) { + return !(a == b); +} +inline __HOST_DEVICE__ bool operator>=(hip_bfloat16 a, hip_bfloat16 b) { + return !(a < b); +} +inline __HOST_DEVICE__ hip_bfloat16 &operator+=(hip_bfloat16 &a, + hip_bfloat16 b) { + return a = a + b; +} +inline __HOST_DEVICE__ hip_bfloat16 &operator-=(hip_bfloat16 &a, + hip_bfloat16 b) { + return a = a - b; +} +inline __HOST_DEVICE__ hip_bfloat16 &operator*=(hip_bfloat16 &a, + hip_bfloat16 b) { + return a = a * b; +} +inline __HOST_DEVICE__ hip_bfloat16 &operator/=(hip_bfloat16 &a, + hip_bfloat16 b) { + return a = a / b; +} +inline __HOST_DEVICE__ hip_bfloat16 &operator++(hip_bfloat16 &a) { + return a += hip_bfloat16(1.0f); +} +inline __HOST_DEVICE__ hip_bfloat16 &operator--(hip_bfloat16 &a) { + return a -= hip_bfloat16(1.0f); +} +inline __HOST_DEVICE__ hip_bfloat16 operator++(hip_bfloat16 &a, int) { + hip_bfloat16 orig = a; + ++a; + return orig; +} +inline __HOST_DEVICE__ hip_bfloat16 operator--(hip_bfloat16 &a, int) { + hip_bfloat16 orig = a; + --a; + return orig; +} + +namespace std { +constexpr __HOST_DEVICE__ bool isinf(hip_bfloat16 a) { + return !(~a.data & 0x7f80) && !(a.data & 0x7f); +} +constexpr __HOST_DEVICE__ bool isnan(hip_bfloat16 a) { + return !(~a.data & 0x7f80) && +(a.data & 0x7f); +} +constexpr __HOST_DEVICE__ bool iszero(hip_bfloat16 a) { + return !(a.data & 0x7fff); +} +} // namespace std + +#endif // __cplusplus < 201103L || !defined(__HIPCC__) + +#endif // _HIP_BFLOAT16_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_common.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_common.h new file mode 100644 index 000000000..0c7dc51b5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_common.h @@ -0,0 +1,32 @@ +/* +Copyright (c) 2019 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COMMON_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COMMON_H + +#if defined(__clang__) && defined(__HIP__) +#define __HIP_CLANG_ONLY__ 1 +#else +#define __HIP_CLANG_ONLY__ 0 +#endif + +#endif // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COMMON_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_complex.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_complex.h new file mode 100644 index 000000000..bfca20277 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_complex.h @@ -0,0 +1,194 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +/* The header defines complex numbers and related functions*/ + +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COMPLEX_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COMPLEX_H + +#if !defined(__HIPCC_RTC__) +#include "hip/amd_detail/amd_hip_vector_types.h" +#endif + +#if defined(__HIPCC_RTC__) +#define __HOST_DEVICE__ __device__ +#else +#define __HOST_DEVICE__ __host__ __device__ +// TODO: Clang has a bug which allows device functions to call std functions +// when std functions are introduced into default namespace by using statement. +// math.h may be included after this bug is fixed. +#if __cplusplus +#include +#else +#include "math.h" +#endif +#endif // !defined(__HIPCC_RTC__) + +typedef float2 hipFloatComplex; + +__HOST_DEVICE__ static inline float hipCrealf(hipFloatComplex z) { return z.x; } + +__HOST_DEVICE__ static inline float hipCimagf(hipFloatComplex z) { return z.y; } + +__HOST_DEVICE__ static inline hipFloatComplex make_hipFloatComplex(float a, + float b) { + hipFloatComplex z; + z.x = a; + z.y = b; + return z; +} + +__HOST_DEVICE__ static inline hipFloatComplex hipConjf(hipFloatComplex z) { + hipFloatComplex ret; + ret.x = z.x; + ret.y = -z.y; + return ret; +} + +__HOST_DEVICE__ static inline float hipCsqabsf(hipFloatComplex z) { + return z.x * z.x + z.y * z.y; +} + +__HOST_DEVICE__ static inline hipFloatComplex hipCaddf(hipFloatComplex p, + hipFloatComplex q) { + return make_hipFloatComplex(p.x + q.x, p.y + q.y); +} + +__HOST_DEVICE__ static inline hipFloatComplex hipCsubf(hipFloatComplex p, + hipFloatComplex q) { + return make_hipFloatComplex(p.x - q.x, p.y - q.y); +} + +__HOST_DEVICE__ static inline hipFloatComplex hipCmulf(hipFloatComplex p, + hipFloatComplex q) { + return make_hipFloatComplex(p.x * q.x - p.y * q.y, p.y * q.x + p.x * q.y); +} + +__HOST_DEVICE__ static inline hipFloatComplex hipCdivf(hipFloatComplex p, + hipFloatComplex q) { + float sqabs = hipCsqabsf(q); + hipFloatComplex ret; + ret.x = (p.x * q.x + p.y * q.y) / sqabs; + ret.y = (p.y * q.x - p.x * q.y) / sqabs; + return ret; +} + +__HOST_DEVICE__ static inline float hipCabsf(hipFloatComplex z) { + return sqrtf(hipCsqabsf(z)); +} + +typedef double2 hipDoubleComplex; + +__HOST_DEVICE__ static inline double hipCreal(hipDoubleComplex z) { + return z.x; +} + +__HOST_DEVICE__ static inline double hipCimag(hipDoubleComplex z) { + return z.y; +} + +__HOST_DEVICE__ static inline hipDoubleComplex make_hipDoubleComplex(double a, + double b) { + hipDoubleComplex z; + z.x = a; + z.y = b; + return z; +} + +__HOST_DEVICE__ static inline hipDoubleComplex hipConj(hipDoubleComplex z) { + hipDoubleComplex ret; + ret.x = z.x; + ret.y = -z.y; + return ret; +} + +__HOST_DEVICE__ static inline double hipCsqabs(hipDoubleComplex z) { + return z.x * z.x + z.y * z.y; +} + +__HOST_DEVICE__ static inline hipDoubleComplex hipCadd(hipDoubleComplex p, + hipDoubleComplex q) { + return make_hipDoubleComplex(p.x + q.x, p.y + q.y); +} + +__HOST_DEVICE__ static inline hipDoubleComplex hipCsub(hipDoubleComplex p, + hipDoubleComplex q) { + return make_hipDoubleComplex(p.x - q.x, p.y - q.y); +} + +__HOST_DEVICE__ static inline hipDoubleComplex hipCmul(hipDoubleComplex p, + hipDoubleComplex q) { + return make_hipDoubleComplex(p.x * q.x - p.y * q.y, p.y * q.x + p.x * q.y); +} + +__HOST_DEVICE__ static inline hipDoubleComplex hipCdiv(hipDoubleComplex p, + hipDoubleComplex q) { + double sqabs = hipCsqabs(q); + hipDoubleComplex ret; + ret.x = (p.x * q.x + p.y * q.y) / sqabs; + ret.y = (p.y * q.x - p.x * q.y) / sqabs; + return ret; +} + +__HOST_DEVICE__ static inline double hipCabs(hipDoubleComplex z) { + return sqrt(hipCsqabs(z)); +} + +typedef hipFloatComplex hipComplex; + +__HOST_DEVICE__ static inline hipComplex make_hipComplex(float x, float y) { + return make_hipFloatComplex(x, y); +} + +__HOST_DEVICE__ static inline hipFloatComplex +hipComplexDoubleToFloat(hipDoubleComplex z) { + return make_hipFloatComplex((float)z.x, (float)z.y); +} + +__HOST_DEVICE__ static inline hipDoubleComplex +hipComplexFloatToDouble(hipFloatComplex z) { + return make_hipDoubleComplex((double)z.x, (double)z.y); +} + +__HOST_DEVICE__ static inline hipComplex hipCfmaf(hipComplex p, hipComplex q, + hipComplex r) { + float real = (p.x * q.x) + r.x; + float imag = (q.x * p.y) + r.y; + + real = -(p.y * q.y) + real; + imag = (p.x * q.y) + imag; + + return make_hipComplex(real, imag); +} + +__HOST_DEVICE__ static inline hipDoubleComplex +hipCfma(hipDoubleComplex p, hipDoubleComplex q, hipDoubleComplex r) { + double real = (p.x * q.x) + r.x; + double imag = (q.x * p.y) + r.y; + + real = -(p.y * q.y) + real; + imag = (p.x * q.y) + imag; + + return make_hipDoubleComplex(real, imag); +} + +#endif // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COMPLEX_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_cooperative_groups.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_cooperative_groups.h new file mode 100644 index 000000000..ee8da62b4 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_cooperative_groups.h @@ -0,0 +1,889 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +/** + * @file amd_detail/hip_cooperative_groups.h + * + * @brief Device side implementation of `Cooperative Group` feature. + * + * Defines new types and device API wrappers related to `Cooperative Group` + * feature, which the programmer can directly use in his kernel(s) in order to + * make use of this feature. + */ +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H + +#if __cplusplus +#if !defined(__HIPCC_RTC__) +#include +#endif + +namespace cooperative_groups { + +/** @brief The base type of all cooperative group types + * + * \details Holds the key properties of a constructed cooperative group types + * object, like the group type, its size, etc + * + * @note Cooperative groups feature is implemented on Linux, under + * developement on Windows. + */ +class thread_group { +protected: + uint32_t _type; // thread_group type + uint32_t _size; // total number of threads in the tread_group + uint64_t _mask; // Lanemask for coalesced and tiled partitioned group types, + // LSB represents lane 0, and MSB represents lane 63 + + // Construct a thread group, and set thread group type and other essential + // thread group properties. This generic thread group is directly constructed + // only when the group is supposed to contain only the calling the thread + // (throurh the API - `this_thread()`), and in all other cases, this thread + // group object is a sub-object of some other derived thread group object + __CG_QUALIFIER__ thread_group(internal::group_type type, + uint32_t size = static_cast(0), + uint64_t mask = static_cast(0)) { + _type = type; + _size = size; + _mask = mask; + } + + struct _tiled_info { + bool is_tiled; + unsigned int size; + unsigned int meta_group_rank; + unsigned int meta_group_size; + }; + + struct _coalesced_info { + lane_mask member_mask; + unsigned int size; + struct _tiled_info tiled_info; + } coalesced_info; + + friend __CG_QUALIFIER__ thread_group + tiled_partition(const thread_group &parent, unsigned int tile_size); + friend class thread_block; + +public: + // Total number of threads in the thread group, and this serves the purpose + // for all derived cooperative group types since their `size` is directly + // saved during the construction + __CG_QUALIFIER__ uint32_t size() const { return _size; } + __CG_QUALIFIER__ unsigned int cg_type() const { return _type; } + // Rank of the calling thread within [0, size()) + __CG_QUALIFIER__ uint32_t thread_rank() const; + // Is this cooperative group type valid? + __CG_QUALIFIER__ bool is_valid() const; + // synchronize the threads in the thread group + __CG_QUALIFIER__ void sync() const; +}; +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup CooperativeG Cooperative Groups + * @ingroup API + * @{ + * This section describes the cooperative groups functions of HIP runtime API. + * + * The cooperative groups provides flexible thread parallel programming + *algorithms, threads cooperate and share data to perform collective + *computations. + * + * @note Cooperative groups feature is implemented on Linux, under + *developement on Windows. + * + */ +/** \brief The multi-grid cooperative group type + * + * \details Represents an inter-device cooperative group type where the + * participating threads within the group spans across multple + * devices, running the (same) kernel on these devices + * @note The multi-grid cooperative group type is implemented on Linux, under + * developement on Windows. + */ +class multi_grid_group : public thread_group { + // Only these friend functions are allowed to construct an object of this + // class and access its resources + friend __CG_QUALIFIER__ multi_grid_group this_multi_grid(); + +protected: + // Construct mutli-grid thread group (through the API this_multi_grid()) + explicit __CG_QUALIFIER__ multi_grid_group(uint32_t size) + : thread_group(internal::cg_multi_grid, size) {} + +public: + // Number of invocations participating in this multi-grid group. In other + // words, the number of GPUs + __CG_QUALIFIER__ uint32_t num_grids() { + return internal::multi_grid::num_grids(); + } + // Rank of this invocation. In other words, an ID number within the range + // [0, num_grids()) of the GPU, this kernel is running on + __CG_QUALIFIER__ uint32_t grid_rank() { + return internal::multi_grid::grid_rank(); + } + __CG_QUALIFIER__ uint32_t thread_rank() const { + return internal::multi_grid::thread_rank(); + } + __CG_QUALIFIER__ bool is_valid() const { + return internal::multi_grid::is_valid(); + } + __CG_QUALIFIER__ void sync() const { internal::multi_grid::sync(); } +}; + +/** @brief User exposed API interface to construct multi-grid cooperative + * group type object - `multi_grid_group` + * + * \details User is not allowed to directly construct an object of type + * `multi_grid_group`. Instead, he should construct it through this + * API function + * @note This multi-grid cooperative API type is implemented on Linux, under + * developement on Windows. + */ +__CG_QUALIFIER__ multi_grid_group this_multi_grid() { + return multi_grid_group(internal::multi_grid::size()); +} + +/** @brief The grid cooperative group type + * + * \details Represents an inter-workgroup cooperative group type where the + * participating threads within the group spans across multiple + * workgroups running the (same) kernel on the same device + * @note This is implemented on Linux, under developement + * on Windows. + */ +class grid_group : public thread_group { + // Only these friend functions are allowed to construct an object of this + // class and access its resources + friend __CG_QUALIFIER__ grid_group this_grid(); + +protected: + // Construct grid thread group (through the API this_grid()) + explicit __CG_QUALIFIER__ grid_group(uint32_t size) + : thread_group(internal::cg_grid, size) {} + +public: + __CG_QUALIFIER__ uint32_t thread_rank() const { + return internal::grid::thread_rank(); + } + __CG_QUALIFIER__ bool is_valid() const { return internal::grid::is_valid(); } + __CG_QUALIFIER__ void sync() const { internal::grid::sync(); } +}; + +/** @brief User exposed API interface to construct grid cooperative group type + * object - `grid_group` + * + * \details User is not allowed to directly construct an object of type + * `multi_grid_group`. Instead, he should construct it through this + * API function + * @note This function is implemented on Linux, under developement + * on Windows. + */ +__CG_QUALIFIER__ grid_group this_grid() { + return grid_group(internal::grid::size()); +} + +/** @brief The workgroup (thread-block in CUDA terminology) cooperative group + * type + * + * \details Represents an intra-workgroup cooperative group type where the + * participating threads within the group are exactly the same threads + * which are participated in the currently executing `workgroup` + * @note This is implemented on Linux, under developement + * on Windows. + */ +class thread_block : public thread_group { + // Only these friend functions are allowed to construct an object of thi + // class and access its resources + friend __CG_QUALIFIER__ thread_block this_thread_block(); + friend __CG_QUALIFIER__ thread_group + tiled_partition(const thread_group &parent, unsigned int tile_size); + friend __CG_QUALIFIER__ thread_group + tiled_partition(const thread_block &parent, unsigned int tile_size); + +protected: + // Construct a workgroup thread group (through the API this_thread_block()) + explicit __CG_QUALIFIER__ thread_block(uint32_t size) + : thread_group(internal::cg_workgroup, size) {} + + __CG_QUALIFIER__ thread_group new_tiled_group(unsigned int tile_size) const { + const bool pow2 = ((tile_size & (tile_size - 1)) == 0); + // Invalid tile size, assert + if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) { + __hip_assert(false && "invalid tile size"); + } + + auto block_size = size(); + auto rank = thread_rank(); + auto partitions = (block_size + tile_size - 1) / tile_size; + auto tail = (partitions * tile_size) - block_size; + auto partition_size = + tile_size - tail * (rank >= (partitions - 1) * tile_size); + thread_group tiledGroup = + thread_group(internal::cg_tiled_group, partition_size); + + tiledGroup.coalesced_info.tiled_info.size = tile_size; + tiledGroup.coalesced_info.tiled_info.is_tiled = true; + tiledGroup.coalesced_info.tiled_info.meta_group_rank = rank / tile_size; + tiledGroup.coalesced_info.tiled_info.meta_group_size = partitions; + return tiledGroup; + } + +public: + // 3-dimensional block index within the grid + __CG_STATIC_QUALIFIER__ dim3 group_index() { + return internal::workgroup::group_index(); + } + // 3-dimensional thread index within the block + __CG_STATIC_QUALIFIER__ dim3 thread_index() { + return internal::workgroup::thread_index(); + } + __CG_STATIC_QUALIFIER__ uint32_t thread_rank() { + return internal::workgroup::thread_rank(); + } + __CG_STATIC_QUALIFIER__ uint32_t size() { + return internal::workgroup::size(); + } + __CG_STATIC_QUALIFIER__ bool is_valid() { + return internal::workgroup::is_valid(); + } + __CG_STATIC_QUALIFIER__ void sync() { internal::workgroup::sync(); } + __CG_QUALIFIER__ dim3 group_dim() { return internal::workgroup::block_dim(); } +}; + +/** \brief User exposed API interface to construct workgroup cooperative + * group type object - `thread_block`. + * + * \details User is not allowed to directly construct an object of type + * `thread_block`. Instead, he should construct it through this API + * function. + * @note This function is implemented on Linux, under developement + * on Windows. + */ +__CG_QUALIFIER__ thread_block this_thread_block() { + return thread_block(internal::workgroup::size()); +} + +/** \brief The tiled_group cooperative group type + * + * \details Represents one tiled thread group in a wavefront. + * This group type also supports sub-wave level intrinsics. + * @note This is implemented on Linux, under developement + * on Windows. + */ + +class tiled_group : public thread_group { +private: + friend __CG_QUALIFIER__ thread_group + tiled_partition(const thread_group &parent, unsigned int tile_size); + friend __CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group &parent, + unsigned int tile_size); + + __CG_QUALIFIER__ tiled_group new_tiled_group(unsigned int tile_size) const { + const bool pow2 = ((tile_size & (tile_size - 1)) == 0); + + if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) { + __hip_assert(false && "invalid tile size"); + } + + if (size() <= tile_size) { + return *this; + } + + tiled_group tiledGroup = tiled_group(tile_size); + tiledGroup.coalesced_info.tiled_info.is_tiled = true; + return tiledGroup; + } + +protected: + explicit __CG_QUALIFIER__ tiled_group(unsigned int tileSize) + : thread_group(internal::cg_tiled_group, tileSize) { + coalesced_info.tiled_info.size = tileSize; + coalesced_info.tiled_info.is_tiled = true; + } + +public: + __CG_QUALIFIER__ unsigned int size() const { + return (coalesced_info.tiled_info.size); + } + + __CG_QUALIFIER__ unsigned int thread_rank() const { + return (internal::workgroup::thread_rank() & + (coalesced_info.tiled_info.size - 1)); + } + + __CG_QUALIFIER__ void sync() const { internal::tiled_group::sync(); } +}; + +/** \brief The coalesced_group cooperative group type + * + * \details Represents a active thread group in a wavefront. + * This group type also supports sub-wave level intrinsics. + * @note This is implemented on Linux, under developement + * on Windows. + */ +class coalesced_group : public thread_group { +private: + friend __CG_QUALIFIER__ coalesced_group coalesced_threads(); + friend __CG_QUALIFIER__ thread_group + tiled_partition(const thread_group &parent, unsigned int tile_size); + friend __CG_QUALIFIER__ coalesced_group + tiled_partition(const coalesced_group &parent, unsigned int tile_size); + + __CG_QUALIFIER__ coalesced_group + new_tiled_group(unsigned int tile_size) const { + const bool pow2 = ((tile_size & (tile_size - 1)) == 0); + + if (!tile_size || (tile_size > size()) || !pow2) { + return coalesced_group(0); + } + + // If a tiled group is passed to be partitioned further into a + // coalesced_group. prepare a mask for further partitioning it so that it + // stays coalesced. + if (coalesced_info.tiled_info.is_tiled) { + unsigned int base_offset = (thread_rank() & (~(tile_size - 1))); + unsigned int masklength = + min(static_cast(size()) - base_offset, tile_size); + lane_mask member_mask = + static_cast(-1) >> (__AMDGCN_WAVEFRONT_SIZE - masklength); + + member_mask <<= (__lane_id() & ~(tile_size - 1)); + coalesced_group coalesced_tile = coalesced_group(member_mask); + coalesced_tile.coalesced_info.tiled_info.is_tiled = true; + coalesced_tile.coalesced_info.tiled_info.meta_group_rank = + thread_rank() / tile_size; + coalesced_tile.coalesced_info.tiled_info.meta_group_size = + size() / tile_size; + return coalesced_tile; + } + // Here the parent coalesced_group is not partitioned. + else { + lane_mask member_mask = 0; + unsigned int tile_rank = 0; + int lanes_to_skip = ((thread_rank()) / tile_size) * tile_size; + + for (unsigned int i = 0; i < __AMDGCN_WAVEFRONT_SIZE; i++) { + lane_mask active = coalesced_info.member_mask & (1 << i); + // Make sure the lane is active + if (active) { + if (lanes_to_skip <= 0 && tile_rank < tile_size) { + // Prepare a member_mask that is appropriate for a tile + member_mask |= active; + tile_rank++; + } + lanes_to_skip--; + } + } + coalesced_group coalesced_tile = coalesced_group(member_mask); + coalesced_tile.coalesced_info.tiled_info.meta_group_rank = + thread_rank() / tile_size; + coalesced_tile.coalesced_info.tiled_info.meta_group_size = + (size() + tile_size - 1) / tile_size; + return coalesced_tile; + } + return coalesced_group(0); + } + +protected: + // Constructor + explicit __CG_QUALIFIER__ coalesced_group(lane_mask member_mask) + : thread_group(internal::cg_coalesced_group) { + coalesced_info.member_mask = member_mask; // Which threads are active + coalesced_info.size = + __popcll(coalesced_info.member_mask); // How many threads are active + coalesced_info.tiled_info.is_tiled = false; // Not a partitioned group + coalesced_info.tiled_info.meta_group_rank = 0; + coalesced_info.tiled_info.meta_group_size = 1; + } + +public: + __CG_QUALIFIER__ unsigned int size() const { return coalesced_info.size; } + + __CG_QUALIFIER__ unsigned int thread_rank() const { + return internal::coalesced_group::masked_bit_count( + coalesced_info.member_mask); + } + + __CG_QUALIFIER__ void sync() const { internal::coalesced_group::sync(); } + + __CG_QUALIFIER__ unsigned int meta_group_rank() const { + return coalesced_info.tiled_info.meta_group_rank; + } + + __CG_QUALIFIER__ unsigned int meta_group_size() const { + return coalesced_info.tiled_info.meta_group_size; + } + + template __CG_QUALIFIER__ T shfl(T var, int srcRank) const { + static_assert(is_valid_type::value, "Neither an integer or float type."); + + srcRank = srcRank % static_cast(size()); + + int lane = (size() == __AMDGCN_WAVEFRONT_SIZE) ? srcRank + : (__AMDGCN_WAVEFRONT_SIZE == 64) + ? __fns64(coalesced_info.member_mask, 0, (srcRank + 1)) + : __fns32(coalesced_info.member_mask, 0, (srcRank + 1)); + + return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE); + } + + template + __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const { + static_assert(is_valid_type::value, "Neither an integer or float type."); + + // Note: The cuda implementation appears to use the remainder of lane_delta + // and WARP_SIZE as the shift value rather than lane_delta itself. + // This is not described in the documentation and is not done here. + + if (size() == __AMDGCN_WAVEFRONT_SIZE) { + return __shfl_down(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE); + } + + int lane; + if (__AMDGCN_WAVEFRONT_SIZE == 64) { + lane = __fns64(coalesced_info.member_mask, __lane_id(), lane_delta + 1); + } else { + lane = __fns32(coalesced_info.member_mask, __lane_id(), lane_delta + 1); + } + + if (lane == -1) { + lane = __lane_id(); + } + + return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE); + } + + template + __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const { + static_assert(is_valid_type::value, "Neither an integer or float type."); + + // Note: The cuda implementation appears to use the remainder of lane_delta + // and WARP_SIZE as the shift value rather than lane_delta itself. + // This is not described in the documentation and is not done here. + + if (size() == __AMDGCN_WAVEFRONT_SIZE) { + return __shfl_up(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE); + } + + int lane; + if (__AMDGCN_WAVEFRONT_SIZE == 64) { + lane = + __fns64(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1)); + } else if (__AMDGCN_WAVEFRONT_SIZE == 32) { + lane = + __fns32(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1)); + } + + if (lane == -1) { + lane = __lane_id(); + } + + return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE); + } +}; + +/** \brief User exposed API to create coalesced groups. + * + * \details A collective operation that groups all active lanes into a new + * thread group. + * @note This function is implemented on Linux, under developement + * on Windows. + */ + +__CG_QUALIFIER__ coalesced_group coalesced_threads() { + return cooperative_groups::coalesced_group(__builtin_amdgcn_read_exec()); +} + +/** + * Implemenation of all publicly exposed base class APIs + * @note This function is implemented on Linux, under developement + * on Windows. + */ +__CG_QUALIFIER__ uint32_t thread_group::thread_rank() const { + switch (this->_type) { + case internal::cg_multi_grid: { + return (static_cast(this)->thread_rank()); + } + case internal::cg_grid: { + return (static_cast(this)->thread_rank()); + } + case internal::cg_workgroup: { + return (static_cast(this)->thread_rank()); + } + case internal::cg_tiled_group: { + return (static_cast(this)->thread_rank()); + } + case internal::cg_coalesced_group: { + return (static_cast(this)->thread_rank()); + } + default: { + __hip_assert(false && "invalid cooperative group type"); + return -1; + } + } +} +/** + * Implemenation of all publicly exposed thread group API + * @note This function is implemented on Linux, under developement + * on Windows. + */ +__CG_QUALIFIER__ bool thread_group::is_valid() const { + switch (this->_type) { + case internal::cg_multi_grid: { + return (static_cast(this)->is_valid()); + } + case internal::cg_grid: { + return (static_cast(this)->is_valid()); + } + case internal::cg_workgroup: { + return (static_cast(this)->is_valid()); + } + case internal::cg_tiled_group: { + return (static_cast(this)->is_valid()); + } + case internal::cg_coalesced_group: { + return (static_cast(this)->is_valid()); + } + default: { + __hip_assert(false && "invalid cooperative group type"); + return false; + } + } +} +/** + * Implemenation of all publicly exposed thread group sync API + * @note This function is implemented on Linux, under developement + * on Windows. + */ +__CG_QUALIFIER__ void thread_group::sync() const { + switch (this->_type) { + case internal::cg_multi_grid: { + static_cast(this)->sync(); + break; + } + case internal::cg_grid: { + static_cast(this)->sync(); + break; + } + case internal::cg_workgroup: { + static_cast(this)->sync(); + break; + } + case internal::cg_tiled_group: { + static_cast(this)->sync(); + break; + } + case internal::cg_coalesced_group: { + static_cast(this)->sync(); + break; + } + default: { + __hip_assert(false && "invalid cooperative group type"); + } + } +} + +/** + * Implemenation of publicly exposed `wrapper` API on top of basic cooperative + * group type APIs + * @note This function is implemented on Linux, under developement + * on Windows. + */ +template __CG_QUALIFIER__ uint32_t group_size(CGTy const &g) { + return g.size(); +} +/** + * Implemenation of publicly exposed `wrapper` API on top of basic cooperative + * group type APIs + * @note This function is implemented on Linux, under developement + * on Windows. + */ +template __CG_QUALIFIER__ uint32_t thread_rank(CGTy const &g) { + return g.thread_rank(); +} +/** + * Implemenation of publicly exposed `wrapper` API on top of basic cooperative + * group type APIs + * @note This function is implemented on Linux, under developement + * on Windows. + */ +template __CG_QUALIFIER__ bool is_valid(CGTy const &g) { + return g.is_valid(); +} +/** + * Implemenation of publicly exposed `wrapper` API on top of basic cooperative + * group type APIs + * @note This function is implemented on Linux, under developement + * on Windows. + */ +template __CG_QUALIFIER__ void sync(CGTy const &g) { g.sync(); } +/** + * template class tile_base + * @note This class is implemented on Linux, under developement + * on Windows. + */ +template class tile_base { +protected: + _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize; + +public: + // Rank of the thread within this tile + _CG_STATIC_CONST_DECL_ unsigned int thread_rank() { + return (internal::workgroup::thread_rank() & (numThreads - 1)); + } + + // Number of threads within this tile + __CG_STATIC_QUALIFIER__ unsigned int size() { return numThreads; } +}; +/** + * template class thread_block_tile_base + * @note This class is implemented on Linux, under developement + * on Windows. + */ +template +class thread_block_tile_base : public tile_base { + static_assert(is_valid_tile_size::value, + "Tile size is either not a power of 2 or greater than the " + "wavefront size"); + using tile_base::numThreads; + +public: + __CG_STATIC_QUALIFIER__ void sync() { internal::tiled_group::sync(); } + + template __CG_QUALIFIER__ T shfl(T var, int srcRank) const { + static_assert(is_valid_type::value, "Neither an integer or float type."); + return (__shfl(var, srcRank, numThreads)); + } + + template + __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const { + static_assert(is_valid_type::value, "Neither an integer or float type."); + return (__shfl_down(var, lane_delta, numThreads)); + } + + template + __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const { + static_assert(is_valid_type::value, "Neither an integer or float type."); + return (__shfl_up(var, lane_delta, numThreads)); + } + + template + __CG_QUALIFIER__ T shfl_xor(T var, unsigned int laneMask) const { + static_assert(is_valid_type::value, "Neither an integer or float type."); + return (__shfl_xor(var, laneMask, numThreads)); + } +}; +/** \brief User exposed API that captures the state of the parent group + * pre-partition + */ +template class parent_group_info { +public: + // Returns the linear rank of the group within the set of tiles partitioned + // from a parent group (bounded by meta_group_size) + __CG_STATIC_QUALIFIER__ unsigned int meta_group_rank() { + return ParentCGTy::thread_rank() / tileSize; + } + + // Returns the number of groups created when the parent group was partitioned. + __CG_STATIC_QUALIFIER__ unsigned int meta_group_size() { + return (ParentCGTy::size() + tileSize - 1) / tileSize; + } +}; + +/** \brief Group type - thread_block_tile + * + * \details Represents one tile of thread group. + * @note This type is implemented on Linux, under developement + * on Windows. + */ +template +class thread_block_tile_type : public thread_block_tile_base, + public tiled_group, + public parent_group_info { + _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize; + typedef thread_block_tile_base tbtBase; + +protected: + __CG_QUALIFIER__ thread_block_tile_type() : tiled_group(numThreads) { + coalesced_info.tiled_info.size = numThreads; + coalesced_info.tiled_info.is_tiled = true; + } + +public: + using tbtBase::size; + using tbtBase::sync; + using tbtBase::thread_rank; +}; + +// Partial template specialization +template +class thread_block_tile_type + : public thread_block_tile_base, public tiled_group { + _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize; + + typedef thread_block_tile_base tbtBase; + +protected: + __CG_QUALIFIER__ thread_block_tile_type(unsigned int meta_group_rank, + unsigned int meta_group_size) + : tiled_group(numThreads) { + coalesced_info.tiled_info.size = numThreads; + coalesced_info.tiled_info.is_tiled = true; + coalesced_info.tiled_info.meta_group_rank = meta_group_rank; + coalesced_info.tiled_info.meta_group_size = meta_group_size; + } + +public: + using tbtBase::size; + using tbtBase::sync; + using tbtBase::thread_rank; + + __CG_QUALIFIER__ unsigned int meta_group_rank() const { + return coalesced_info.tiled_info.meta_group_rank; + } + + __CG_QUALIFIER__ unsigned int meta_group_size() const { + return coalesced_info.tiled_info.meta_group_size; + } + // end of operative group + /** + * @} + */ +}; + +/** \brief User exposed API to partition groups. + * + * \details A collective operation that partitions the parent group into a + * one-dimensional, row-major, tiling of subgroups. + */ + +__CG_QUALIFIER__ thread_group tiled_partition(const thread_group &parent, + unsigned int tile_size) { + if (parent.cg_type() == internal::cg_tiled_group) { + const tiled_group *cg = static_cast(&parent); + return cg->new_tiled_group(tile_size); + } else if (parent.cg_type() == internal::cg_coalesced_group) { + const coalesced_group *cg = static_cast(&parent); + return cg->new_tiled_group(tile_size); + } else { + const thread_block *tb = static_cast(&parent); + return tb->new_tiled_group(tile_size); + } +} + +// Thread block type overload +__CG_QUALIFIER__ thread_group tiled_partition(const thread_block &parent, + unsigned int tile_size) { + return (parent.new_tiled_group(tile_size)); +} + +__CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group &parent, + unsigned int tile_size) { + return (parent.new_tiled_group(tile_size)); +} + +// If a coalesced group is passed to be partitioned, it should remain coalesced +__CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group &parent, + unsigned int tile_size) { + return (parent.new_tiled_group(tile_size)); +} + +template class thread_block_tile; + +namespace impl { +template class thread_block_tile_internal; + +template +class thread_block_tile_internal + : public thread_block_tile_type { +protected: + template + __CG_QUALIFIER__ thread_block_tile_internal( + const thread_block_tile_internal &g) + : thread_block_tile_type(g.meta_group_rank(), + g.meta_group_size()) {} + + __CG_QUALIFIER__ thread_block_tile_internal(const thread_block &g) + : thread_block_tile_type() {} +}; +} // namespace impl + +template +class thread_block_tile + : public impl::thread_block_tile_internal { +protected: + __CG_QUALIFIER__ thread_block_tile(const ParentCGTy &g) + : impl::thread_block_tile_internal(g) {} + +public: + __CG_QUALIFIER__ operator thread_block_tile() const { + return thread_block_tile(*this); + } +}; + +template +class thread_block_tile + : public impl::thread_block_tile_internal { + template friend class thread_block_tile; + +protected: +public: + template + __CG_QUALIFIER__ + thread_block_tile(const thread_block_tile &g) + : impl::thread_block_tile_internal(g) {} +}; + +template class thread_block_tile; + +namespace impl { +template struct tiled_partition_internal; + +template +struct tiled_partition_internal + : public thread_block_tile { + __CG_QUALIFIER__ tiled_partition_internal(const thread_block &g) + : thread_block_tile(g) {} +}; + +} // namespace impl + +/** \brief User exposed API to partition groups. + * + * \details This constructs a templated class derieved from thread_group. + * The template defines tile size of the new thread group at compile + * time. + */ +template +__CG_QUALIFIER__ thread_block_tile +tiled_partition(const ParentCGTy &g) { + static_assert( + is_valid_tile_size::value, + "Tiled partition with size > wavefront size. Currently not supported "); + return impl::tiled_partition_internal(g); +} +} // namespace cooperative_groups + +#endif // __cplusplus +#endif // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_fp16.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_fp16.h new file mode 100644 index 000000000..7e3d42e16 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_fp16.h @@ -0,0 +1,1171 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H + +#if defined(__HIPCC_RTC__) +#define __HOST_DEVICE__ __device__ +#else +#define __HOST_DEVICE__ __host__ __device__ +#include "hip/amd_detail/host_defines.h" +#include +#include +#if defined(__cplusplus) +#include +#include +#include +#endif +#endif // !defined(__HIPCC_RTC__) + +#if defined(__clang__) && defined(__HIP__) +typedef _Float16 _Float16_2 __attribute__((ext_vector_type(2))); + +struct __half_raw { + union { + static_assert(sizeof(_Float16) == sizeof(unsigned short), ""); + + _Float16 data; + unsigned short x; + }; +}; + +struct __half2_raw { + union { + static_assert(sizeof(_Float16_2) == sizeof(unsigned short[2]), ""); + + struct { + __half_raw x; + __half_raw y; + }; + _Float16_2 data; + }; +}; + +#if defined(__cplusplus) +#if !defined(__HIPCC_RTC__) +#include "amd_device_functions.h" +#include "amd_hip_vector_types.h" +#include "amd_warp_functions.h" +#include "hip_fp16_math_fwd.h" +#include "host_defines.h" +#endif +namespace std { +template <> struct is_floating_point<_Float16> : std::true_type {}; +} // namespace std + +template +using Enable_if_t = typename std::enable_if::type; + +// BEGIN STRUCT __HALF +struct __half { +protected: + union { + static_assert(sizeof(_Float16) == sizeof(unsigned short), ""); + + _Float16 data; + unsigned short __x; + }; + +public: + // CREATORS + __HOST_DEVICE__ + __half() = default; + __HOST_DEVICE__ + __half(const __half_raw &x) : data{x.data} {} +#if !defined(__HIP_NO_HALF_CONVERSIONS__) + __HOST_DEVICE__ + __half(decltype(data) x) : data{x} {} + template {}> * = nullptr> + __HOST_DEVICE__ __half(T x) : data{static_cast<_Float16>(x)} {} +#endif + __HOST_DEVICE__ + __half(const __half &) = default; + __HOST_DEVICE__ + __half(__half &&) = default; + __HOST_DEVICE__ + ~__half() = default; + +// CREATORS - DEVICE ONLY +#if !defined(__HIP_NO_HALF_CONVERSIONS__) + template {}> * = nullptr> + __HOST_DEVICE__ __half(T x) : data{static_cast<_Float16>(x)} {} +#endif + + // MANIPULATORS + __HOST_DEVICE__ + __half &operator=(const __half &) = default; + __HOST_DEVICE__ + __half &operator=(__half &&) = default; + __HOST_DEVICE__ + __half &operator=(const __half_raw &x) { + data = x.data; + return *this; + } + __HOST_DEVICE__ + volatile __half &operator=(const __half_raw &x) volatile { + data = x.data; + return *this; + } + volatile __half &operator=(const volatile __half_raw &x) volatile { + data = x.data; + return *this; + } + __half &operator=(__half_raw &&x) { + data = x.data; + return *this; + } + volatile __half &operator=(__half_raw &&x) volatile { + data = x.data; + return *this; + } + volatile __half &operator=(volatile __half_raw &&x) volatile { + data = x.data; + return *this; + } +#if !defined(__HIP_NO_HALF_CONVERSIONS__) + template {}> * = nullptr> + __HOST_DEVICE__ __half &operator=(T x) { + data = static_cast<_Float16>(x); + return *this; + } +#endif + +// MANIPULATORS - DEVICE ONLY +#if !defined(__HIP_NO_HALF_CONVERSIONS__) + template {}> * = nullptr> + __device__ __half &operator=(T x) { + data = static_cast<_Float16>(x); + return *this; + } +#endif + +#if !defined(__HIP_NO_HALF_OPERATORS__) + __device__ __half &operator+=(const __half &x) { + data += x.data; + return *this; + } + __device__ __half &operator-=(const __half &x) { + data -= x.data; + return *this; + } + __device__ __half &operator*=(const __half &x) { + data *= x.data; + return *this; + } + __device__ __half &operator/=(const __half &x) { + data /= x.data; + return *this; + } + __device__ __half &operator++() { + ++data; + return *this; + } + __device__ __half operator++(int) { + __half tmp{*this}; + ++*this; + return tmp; + } + __device__ __half &operator--() { + --data; + return *this; + } + __device__ __half operator--(int) { + __half tmp{*this}; + --*this; + return tmp; + } +#endif + +// ACCESSORS +#if !defined(__HIP_NO_HALF_CONVERSIONS__) + template {}> * = nullptr> + __HOST_DEVICE__ operator T() const { + return data; + } +#endif + __HOST_DEVICE__ + operator __half_raw() const { return __half_raw{data}; } + __HOST_DEVICE__ + operator __half_raw() const volatile { return __half_raw{data}; } + +#if !defined(__HIP_NO_HALF_CONVERSIONS__) + template {}> * = nullptr> + __HOST_DEVICE__ operator T() const { + return data; + } +#endif + +#if !defined(__HIP_NO_HALF_OPERATORS__) + __device__ __half operator+() const { return *this; } + __device__ __half operator-() const { + __half tmp{*this}; + tmp.data = -tmp.data; + return tmp; + } +#endif + +// FRIENDS +#if !defined(__HIP_NO_HALF_OPERATORS__) + friend inline __device__ __half operator+(const __half &x, const __half &y) { + return __half{x} += y; + } + friend inline __device__ __half operator-(const __half &x, const __half &y) { + return __half{x} -= y; + } + friend inline __device__ __half operator*(const __half &x, const __half &y) { + return __half{x} *= y; + } + friend inline __device__ __half operator/(const __half &x, const __half &y) { + return __half{x} /= y; + } + friend inline __device__ bool operator==(const __half &x, const __half &y) { + return x.data == y.data; + } + friend inline __device__ bool operator!=(const __half &x, const __half &y) { + return !(x == y); + } + friend inline __device__ bool operator<(const __half &x, const __half &y) { + return x.data < y.data; + } + friend inline __device__ bool operator>(const __half &x, const __half &y) { + return y.data < x.data; + } + friend inline __device__ bool operator<=(const __half &x, const __half &y) { + return !(y < x); + } + friend inline __device__ bool operator>=(const __half &x, const __half &y) { + return !(x < y); + } +#endif // !defined(__HIP_NO_HALF_OPERATORS__) +}; +// END STRUCT __HALF + +// BEGIN STRUCT __HALF2 +struct __half2 { +public: + union { + static_assert(sizeof(_Float16_2) == sizeof(unsigned short[2]), ""); + + struct { + __half x; + __half y; + }; + _Float16_2 data; + }; + + // CREATORS + __HOST_DEVICE__ + __half2() = default; + __HOST_DEVICE__ + __half2(const __half2_raw &xx) : data{xx.data} {} + __HOST_DEVICE__ + __half2(decltype(data) xx) : data{xx} {} + __HOST_DEVICE__ + __half2(const __half &xx, const __half &yy) + : data{static_cast<__half_raw>(xx).data, + static_cast<__half_raw>(yy).data} {} + __HOST_DEVICE__ + __half2(const __half2 &) = default; + __HOST_DEVICE__ + __half2(__half2 &&) = default; + __HOST_DEVICE__ + ~__half2() = default; + + // MANIPULATORS + __HOST_DEVICE__ + __half2 &operator=(const __half2 &) = default; + __HOST_DEVICE__ + __half2 &operator=(__half2 &&) = default; + __HOST_DEVICE__ + __half2 &operator=(const __half2_raw &xx) { + data = xx.data; + return *this; + } + +// MANIPULATORS - DEVICE ONLY +#if !defined(__HIP_NO_HALF_OPERATORS__) + __device__ __half2 &operator+=(const __half2 &xx) { + data += xx.data; + return *this; + } + __device__ __half2 &operator-=(const __half2 &xx) { + data -= xx.data; + return *this; + } + __device__ __half2 &operator*=(const __half2 &xx) { + data *= xx.data; + return *this; + } + __device__ __half2 &operator/=(const __half2 &xx) { + data /= xx.data; + return *this; + } + __device__ __half2 &operator++() { return *this += _Float16_2{1, 1}; } + __device__ __half2 operator++(int) { + __half2 tmp{*this}; + ++*this; + return tmp; + } + __device__ __half2 &operator--() { return *this -= _Float16_2{1, 1}; } + __device__ __half2 operator--(int) { + __half2 tmp{*this}; + --*this; + return tmp; + } +#endif + + // ACCESSORS + __HOST_DEVICE__ + operator decltype(data)() const { return data; } + __HOST_DEVICE__ + operator __half2_raw() const { + __half2_raw r; + r.data = data; + return r; + } + +// ACCESSORS - DEVICE ONLY +#if !defined(__HIP_NO_HALF_OPERATORS__) + __device__ __half2 operator+() const { return *this; } + __device__ __half2 operator-() const { + __half2 tmp{*this}; + tmp.data = -tmp.data; + return tmp; + } +#endif + +// FRIENDS +#if !defined(__HIP_NO_HALF_OPERATORS__) + friend inline __device__ __half2 operator+(const __half2 &xx, + const __half2 &yy) { + return __half2{xx} += yy; + } + friend inline __device__ __half2 operator-(const __half2 &xx, + const __half2 &yy) { + return __half2{xx} -= yy; + } + friend inline __device__ __half2 operator*(const __half2 &xx, + const __half2 &yy) { + return __half2{xx} *= yy; + } + friend inline __device__ __half2 operator/(const __half2 &xx, + const __half2 &yy) { + return __half2{xx} /= yy; + } + friend inline __device__ bool operator==(const __half2 &xx, + const __half2 &yy) { + auto r = xx.data == yy.data; + return r.x != 0 && r.y != 0; + } + friend inline __device__ bool operator!=(const __half2 &xx, + const __half2 &yy) { + return !(xx == yy); + } + friend inline __device__ bool operator<(const __half2 &xx, + const __half2 &yy) { + auto r = xx.data < yy.data; + return r.x != 0 && r.y != 0; + } + friend inline __device__ bool operator>(const __half2 &xx, + const __half2 &yy) { + return yy < xx; + } + friend inline __device__ bool operator<=(const __half2 &xx, + const __half2 &yy) { + return !(yy < xx); + } + friend inline __device__ bool operator>=(const __half2 &xx, + const __half2 &yy) { + return !(xx < yy); + } +#endif // !defined(__HIP_NO_HALF_OPERATORS__) +}; +// END STRUCT __HALF2 + +namespace { +inline __HOST_DEVICE__ __half2 make_half2(__half x, __half y) { + return __half2{x, y}; +} + +inline __HOST_DEVICE__ __half __low2half(__half2 x) { + return __half{__half_raw{static_cast<__half2_raw>(x).data.x}}; +} + +inline __HOST_DEVICE__ __half __high2half(__half2 x) { + return __half{__half_raw{static_cast<__half2_raw>(x).data.y}}; +} + +inline __HOST_DEVICE__ __half2 __half2half2(__half x) { return __half2{x, x}; } + +inline __HOST_DEVICE__ __half2 __halves2half2(__half x, __half y) { + return __half2{x, y}; +} + +inline __HOST_DEVICE__ __half2 __low2half2(__half2 x) { + return __half2{_Float16_2{static_cast<__half2_raw>(x).data.x, + static_cast<__half2_raw>(x).data.x}}; +} + +inline __HOST_DEVICE__ __half2 __high2half2(__half2 x) { + return __half2{_Float16_2{static_cast<__half2_raw>(x).data.y, + static_cast<__half2_raw>(x).data.y}}; +} + +inline __HOST_DEVICE__ __half2 __lows2half2(__half2 x, __half2 y) { + return __half2{_Float16_2{static_cast<__half2_raw>(x).data.x, + static_cast<__half2_raw>(y).data.x}}; +} + +inline __HOST_DEVICE__ __half2 __highs2half2(__half2 x, __half2 y) { + return __half2{_Float16_2{static_cast<__half2_raw>(x).data.y, + static_cast<__half2_raw>(y).data.y}}; +} + +inline __HOST_DEVICE__ __half2 __lowhigh2highlow(__half2 x) { + return __half2{_Float16_2{static_cast<__half2_raw>(x).data.y, + static_cast<__half2_raw>(x).data.x}}; +} + +// Bitcasts +inline __device__ short __half_as_short(__half x) { + return static_cast<__half_raw>(x).x; +} + +inline __device__ unsigned short __half_as_ushort(__half x) { + return static_cast<__half_raw>(x).x; +} + +inline __device__ __half __short_as_half(short x) { + __half_raw r; + r.x = x; + return r; +} + +inline __device__ __half __ushort_as_half(unsigned short x) { + __half_raw r; + r.x = x; + return r; +} + +// float -> half | half2 +inline __HOST_DEVICE__ __half __float2half(float x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __HOST_DEVICE__ __half __float2half_rn(float x) { + return __half_raw{static_cast<_Float16>(x)}; +} +#if !defined(__HIPCC_RTC__) +// TODO: rounding behaviour is not correct for host functions. +inline __host__ __half __float2half_rz(float x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __host__ __half __float2half_rd(float x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __host__ __half __float2half_ru(float x) { + return __half_raw{static_cast<_Float16>(x)}; +} +#endif +inline __device__ __half __float2half_rz(float x) { + return __half_raw{__ocml_cvtrtz_f16_f32(x)}; +} +inline __device__ __half __float2half_rd(float x) { + return __half_raw{__ocml_cvtrtn_f16_f32(x)}; +} +inline __device__ __half __float2half_ru(float x) { + return __half_raw{__ocml_cvtrtp_f16_f32(x)}; +} +inline __HOST_DEVICE__ __half2 __float2half2_rn(float x) { + return __half2{ + _Float16_2{static_cast<_Float16>(x), static_cast<_Float16>(x)}}; +} +inline __HOST_DEVICE__ __half2 __floats2half2_rn(float x, float y) { + return __half2{ + _Float16_2{static_cast<_Float16>(x), static_cast<_Float16>(y)}}; +} +inline __HOST_DEVICE__ __half2 __float22half2_rn(float2 x) { + return __floats2half2_rn(x.x, x.y); +} + +// half | half2 -> float +inline __HOST_DEVICE__ float __half2float(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __HOST_DEVICE__ float __low2float(__half2 x) { + return static_cast<__half2_raw>(x).data.x; +} +inline __HOST_DEVICE__ float __high2float(__half2 x) { + return static_cast<__half2_raw>(x).data.y; +} +inline __HOST_DEVICE__ float2 __half22float2(__half2 x) { + return make_float2(static_cast<__half2_raw>(x).data.x, + static_cast<__half2_raw>(x).data.y); +} + +// half -> int +inline __device__ int __half2int_rn(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ int __half2int_rz(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ int __half2int_rd(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ int __half2int_ru(__half x) { + return static_cast<__half_raw>(x).data; +} + +// int -> half +inline __device__ __half __int2half_rn(int x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __int2half_rz(int x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __int2half_rd(int x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __int2half_ru(int x) { + return __half_raw{static_cast<_Float16>(x)}; +} + +// half -> short +inline __device__ short __half2short_rn(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ short __half2short_rz(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ short __half2short_rd(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ short __half2short_ru(__half x) { + return static_cast<__half_raw>(x).data; +} + +// short -> half +inline __device__ __half __short2half_rn(short x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __short2half_rz(short x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __short2half_rd(short x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __short2half_ru(short x) { + return __half_raw{static_cast<_Float16>(x)}; +} + +// half -> long long +inline __device__ long long __half2ll_rn(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ long long __half2ll_rz(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ long long __half2ll_rd(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ long long __half2ll_ru(__half x) { + return static_cast<__half_raw>(x).data; +} + +// long long -> half +inline __device__ __half __ll2half_rn(long long x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __ll2half_rz(long long x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __ll2half_rd(long long x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __ll2half_ru(long long x) { + return __half_raw{static_cast<_Float16>(x)}; +} + +// half -> unsigned int +inline __device__ unsigned int __half2uint_rn(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ unsigned int __half2uint_rz(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ unsigned int __half2uint_rd(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ unsigned int __half2uint_ru(__half x) { + return static_cast<__half_raw>(x).data; +} + +// unsigned int -> half +inline __device__ __half __uint2half_rn(unsigned int x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __uint2half_rz(unsigned int x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __uint2half_rd(unsigned int x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __uint2half_ru(unsigned int x) { + return __half_raw{static_cast<_Float16>(x)}; +} + +// half -> unsigned short +inline __device__ unsigned short __half2ushort_rn(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ unsigned short __half2ushort_rz(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ unsigned short __half2ushort_rd(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ unsigned short __half2ushort_ru(__half x) { + return static_cast<__half_raw>(x).data; +} + +// unsigned short -> half +inline __device__ __half __ushort2half_rn(unsigned short x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __ushort2half_rz(unsigned short x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __ushort2half_rd(unsigned short x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __ushort2half_ru(unsigned short x) { + return __half_raw{static_cast<_Float16>(x)}; +} + +// half -> unsigned long long +inline __device__ unsigned long long __half2ull_rn(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ unsigned long long __half2ull_rz(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ unsigned long long __half2ull_rd(__half x) { + return static_cast<__half_raw>(x).data; +} +inline __device__ unsigned long long __half2ull_ru(__half x) { + return static_cast<__half_raw>(x).data; +} + +// unsigned long long -> half +inline __device__ __half __ull2half_rn(unsigned long long x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __ull2half_rz(unsigned long long x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __ull2half_rd(unsigned long long x) { + return __half_raw{static_cast<_Float16>(x)}; +} +inline __device__ __half __ull2half_ru(unsigned long long x) { + return __half_raw{static_cast<_Float16>(x)}; +} + +// Load primitives +inline __device__ __half __ldg(const __half *ptr) { return *ptr; } +inline __device__ __half __ldcg(const __half *ptr) { return *ptr; } +inline __device__ __half __ldca(const __half *ptr) { return *ptr; } +inline __device__ __half __ldcs(const __half *ptr) { return *ptr; } + +inline __HOST_DEVICE__ __half2 __ldg(const __half2 *ptr) { return *ptr; } +inline __HOST_DEVICE__ __half2 __ldcg(const __half2 *ptr) { return *ptr; } +inline __HOST_DEVICE__ __half2 __ldca(const __half2 *ptr) { return *ptr; } +inline __HOST_DEVICE__ __half2 __ldcs(const __half2 *ptr) { return *ptr; } + +// Relations +inline __device__ bool __heq(__half x, __half y) { + return static_cast<__half_raw>(x).data == static_cast<__half_raw>(y).data; +} +inline __device__ bool __hne(__half x, __half y) { + return static_cast<__half_raw>(x).data != static_cast<__half_raw>(y).data; +} +inline __device__ bool __hle(__half x, __half y) { + return static_cast<__half_raw>(x).data <= static_cast<__half_raw>(y).data; +} +inline __device__ bool __hge(__half x, __half y) { + return static_cast<__half_raw>(x).data >= static_cast<__half_raw>(y).data; +} +inline __device__ bool __hlt(__half x, __half y) { + return static_cast<__half_raw>(x).data < static_cast<__half_raw>(y).data; +} +inline __device__ bool __hgt(__half x, __half y) { + return static_cast<__half_raw>(x).data > static_cast<__half_raw>(y).data; +} +inline __device__ bool __hequ(__half x, __half y) { + return !(static_cast<__half_raw>(x).data < static_cast<__half_raw>(y).data) && + !(static_cast<__half_raw>(x).data > static_cast<__half_raw>(y).data); +} +inline __device__ bool __hneu(__half x, __half y) { + return !(static_cast<__half_raw>(x).data == static_cast<__half_raw>(y).data); +} +inline __device__ bool __hleu(__half x, __half y) { + return !(static_cast<__half_raw>(x).data > static_cast<__half_raw>(y).data); +} +inline __device__ bool __hgeu(__half x, __half y) { + return !(static_cast<__half_raw>(x).data < static_cast<__half_raw>(y).data); +} +inline __device__ bool __hltu(__half x, __half y) { + return !(static_cast<__half_raw>(x).data >= static_cast<__half_raw>(y).data); +} +inline __device__ bool __hgtu(__half x, __half y) { + return !(static_cast<__half_raw>(x).data <= static_cast<__half_raw>(y).data); +} + +inline __HOST_DEVICE__ __half2 __heq2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(x).data == static_cast<__half2_raw>(y).data; + return __builtin_convertvector(-r, _Float16_2); +} +inline __HOST_DEVICE__ __half2 __hne2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(x).data != static_cast<__half2_raw>(y).data; + return __builtin_convertvector(-r, _Float16_2); +} +inline __HOST_DEVICE__ __half2 __hle2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(x).data <= static_cast<__half2_raw>(y).data; + return __builtin_convertvector(-r, _Float16_2); +} +inline __HOST_DEVICE__ __half2 __hge2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(x).data >= static_cast<__half2_raw>(y).data; + return __builtin_convertvector(-r, _Float16_2); +} +inline __HOST_DEVICE__ __half2 __hlt2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(x).data < static_cast<__half2_raw>(y).data; + return __builtin_convertvector(-r, _Float16_2); +} +inline __HOST_DEVICE__ __half2 __hgt2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(x).data > static_cast<__half2_raw>(y).data; + return __builtin_convertvector(-r, _Float16_2); +} +inline __HOST_DEVICE__ __half2 __hequ2(__half2 x, __half2 y) { + auto r = + !(static_cast<__half2_raw>(x).data < static_cast<__half2_raw>(y).data) && + !(static_cast<__half2_raw>(x).data > static_cast<__half2_raw>(y).data); + return __builtin_convertvector(-r, _Float16_2); +} +inline __HOST_DEVICE__ __half2 __hneu2(__half2 x, __half2 y) { + auto r = + !(static_cast<__half2_raw>(x).data == static_cast<__half2_raw>(y).data); + return __builtin_convertvector(-r, _Float16_2); +} +inline __HOST_DEVICE__ __half2 __hleu2(__half2 x, __half2 y) { + auto r = + !(static_cast<__half2_raw>(x).data > static_cast<__half2_raw>(y).data); + return __builtin_convertvector(-r, _Float16_2); +} +inline __HOST_DEVICE__ __half2 __hgeu2(__half2 x, __half2 y) { + auto r = + !(static_cast<__half2_raw>(x).data < static_cast<__half2_raw>(y).data); + return __builtin_convertvector(-r, _Float16_2); +} +inline __HOST_DEVICE__ __half2 __hltu2(__half2 x, __half2 y) { + auto r = + !(static_cast<__half2_raw>(x).data >= static_cast<__half2_raw>(y).data); + return __builtin_convertvector(-r, _Float16_2); +} +inline __HOST_DEVICE__ __half2 __hgtu2(__half2 x, __half2 y) { + auto r = + !(static_cast<__half2_raw>(x).data <= static_cast<__half2_raw>(y).data); + return __builtin_convertvector(-r, _Float16_2); +} + +inline __HOST_DEVICE__ bool __hbeq2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(__heq2(x, y)); + return r.data.x != 0 && r.data.y != 0; +} +inline __HOST_DEVICE__ bool __hbne2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(__hne2(x, y)); + return r.data.x != 0 && r.data.y != 0; +} +inline __HOST_DEVICE__ bool __hble2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(__hle2(x, y)); + return r.data.x != 0 && r.data.y != 0; +} +inline __HOST_DEVICE__ bool __hbge2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(__hge2(x, y)); + return r.data.x != 0 && r.data.y != 0; +} +inline __HOST_DEVICE__ bool __hblt2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(__hlt2(x, y)); + return r.data.x != 0 && r.data.y != 0; +} +inline __HOST_DEVICE__ bool __hbgt2(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(__hgt2(x, y)); + return r.data.x != 0 && r.data.y != 0; +} +inline __HOST_DEVICE__ bool __hbequ2(__half2 x, __half2 y) { + return __hbeq2(x, y); +} +inline __HOST_DEVICE__ bool __hbneu2(__half2 x, __half2 y) { + return __hbne2(x, y); +} +inline __HOST_DEVICE__ bool __hbleu2(__half2 x, __half2 y) { + return __hble2(x, y); +} +inline __HOST_DEVICE__ bool __hbgeu2(__half2 x, __half2 y) { + return __hbge2(x, y); +} +inline __HOST_DEVICE__ bool __hbltu2(__half2 x, __half2 y) { + return __hblt2(x, y); +} +inline __HOST_DEVICE__ bool __hbgtu2(__half2 x, __half2 y) { + return __hbgt2(x, y); +} +inline __device__ __half __hmax(const __half x, const __half y) { + return __half_raw{__ocml_fmax_f16(static_cast<__half_raw>(x).data, + static_cast<__half_raw>(y).data)}; +} +inline __device__ __half __hmax_nan(const __half x, const __half y) { + if (__ocml_isnan_f16(static_cast<__half_raw>(x).data)) { + return x; + } else if (__ocml_isnan_f16(static_cast<__half_raw>(y).data)) { + return y; + } + return __hmax(x, y); +} +inline __device__ __half __hmin(const __half x, const __half y) { + return __half_raw{__ocml_fmin_f16(static_cast<__half_raw>(x).data, + static_cast<__half_raw>(y).data)}; +} +inline __device__ __half __hmin_nan(const __half x, const __half y) { + if (__ocml_isnan_f16(static_cast<__half_raw>(x).data)) { + return x; + } else if (__ocml_isnan_f16(static_cast<__half_raw>(y).data)) { + return y; + } + return __hmin(x, y); +} + +// Arithmetic +inline __device__ __half __clamp_01(__half x) { + auto r = static_cast<__half_raw>(x); + + if (__hlt(x, __half_raw{0})) + return __half_raw{0}; + if (__hlt(__half_raw{1}, x)) + return __half_raw{1}; + return r; +} + +inline __device__ __half __hadd(__half x, __half y) { + return __half_raw{static_cast<__half_raw>(x).data + + static_cast<__half_raw>(y).data}; +} +inline __device__ __half __habs(__half x) { + return __half_raw{__ocml_fabs_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half __hsub(__half x, __half y) { + return __half_raw{static_cast<__half_raw>(x).data - + static_cast<__half_raw>(y).data}; +} +inline __device__ __half __hmul(__half x, __half y) { + return __half_raw{static_cast<__half_raw>(x).data * + static_cast<__half_raw>(y).data}; +} +inline __device__ __half __hadd_sat(__half x, __half y) { + return __clamp_01(__hadd(x, y)); +} +inline __device__ __half __hsub_sat(__half x, __half y) { + return __clamp_01(__hsub(x, y)); +} +inline __device__ __half __hmul_sat(__half x, __half y) { + return __clamp_01(__hmul(x, y)); +} +inline __device__ __half __hfma(__half x, __half y, __half z) { + return __half_raw{__ocml_fma_f16(static_cast<__half_raw>(x).data, + static_cast<__half_raw>(y).data, + static_cast<__half_raw>(z).data)}; +} +inline __device__ __half __hfma_sat(__half x, __half y, __half z) { + return __clamp_01(__hfma(x, y, z)); +} +inline __device__ __half __hdiv(__half x, __half y) { + return __half_raw{static_cast<__half_raw>(x).data / + static_cast<__half_raw>(y).data}; +} + +inline __HOST_DEVICE__ __half2 __hadd2(__half2 x, __half2 y) { + return __half2{static_cast<__half2_raw>(x).data + + static_cast<__half2_raw>(y).data}; +} +inline __HOST_DEVICE__ __half2 __habs2(__half2 x) { + return __half2{__ocml_fabs_2f16(static_cast<__half2_raw>(x).data)}; +} +inline __HOST_DEVICE__ __half2 __hsub2(__half2 x, __half2 y) { + return __half2{static_cast<__half2_raw>(x).data - + static_cast<__half2_raw>(y).data}; +} +inline __HOST_DEVICE__ __half2 __hmul2(__half2 x, __half2 y) { + return __half2{static_cast<__half2_raw>(x).data * + static_cast<__half2_raw>(y).data}; +} +inline __HOST_DEVICE__ __half2 __hadd2_sat(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(__hadd2(x, y)); + return __half2{__clamp_01(__half_raw{r.data.x}), + __clamp_01(__half_raw{r.data.y})}; +} +inline __HOST_DEVICE__ __half2 __hsub2_sat(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(__hsub2(x, y)); + return __half2{__clamp_01(__half_raw{r.data.x}), + __clamp_01(__half_raw{r.data.y})}; +} +inline __HOST_DEVICE__ __half2 __hmul2_sat(__half2 x, __half2 y) { + auto r = static_cast<__half2_raw>(__hmul2(x, y)); + return __half2{__clamp_01(__half_raw{r.data.x}), + __clamp_01(__half_raw{r.data.y})}; +} +inline __HOST_DEVICE__ __half2 __hfma2(__half2 x, __half2 y, __half2 z) { + return __half2{__ocml_fma_2f16(x, y, z)}; +} +inline __HOST_DEVICE__ __half2 __hfma2_sat(__half2 x, __half2 y, __half2 z) { + auto r = static_cast<__half2_raw>(__hfma2(x, y, z)); + return __half2{__clamp_01(__half_raw{r.data.x}), + __clamp_01(__half_raw{r.data.y})}; +} +inline __HOST_DEVICE__ __half2 __h2div(__half2 x, __half2 y) { + return __half2{static_cast<__half2_raw>(x).data / + static_cast<__half2_raw>(y).data}; +} + +// Math functions +#if defined(__clang__) && defined(__HIP__) +inline __device__ float amd_mixed_dot(__half2 a, __half2 b, float c, + bool saturate) { + return __ockl_fdot2(static_cast<__half2_raw>(a).data, + static_cast<__half2_raw>(b).data, c, saturate); +} +#endif +inline __device__ __half htrunc(__half x) { + return __half_raw{__ocml_trunc_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hceil(__half x) { + return __half_raw{__ocml_ceil_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hfloor(__half x) { + return __half_raw{__ocml_floor_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hrint(__half x) { + return __half_raw{__ocml_rint_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hsin(__half x) { + return __half_raw{__ocml_sin_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hcos(__half x) { + return __half_raw{__ocml_cos_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hexp(__half x) { + return __half_raw{__ocml_exp_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hexp2(__half x) { + return __half_raw{__ocml_exp2_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hexp10(__half x) { + return __half_raw{__ocml_exp10_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hlog2(__half x) { + return __half_raw{__ocml_log2_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hlog(__half x) { + return __half_raw{__ocml_log_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hlog10(__half x) { + return __half_raw{__ocml_log10_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hrcp(__half x) { + return __half_raw{static_cast<_Float16>(1.0f) / + static_cast<__half_raw>(x).data}; +} +inline __device__ __half hrsqrt(__half x) { + return __half_raw{__ocml_rsqrt_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ __half hsqrt(__half x) { + return __half_raw{__ocml_sqrt_f16(static_cast<__half_raw>(x).data)}; +} +inline __device__ bool __hisinf(__half x) { + return __ocml_isinf_f16(static_cast<__half_raw>(x).data); +} +inline __device__ bool __hisnan(__half x) { + return __ocml_isnan_f16(static_cast<__half_raw>(x).data); +} +inline __device__ __half __hneg(__half x) { + return __half_raw{-static_cast<__half_raw>(x).data}; +} + +inline __HOST_DEVICE__ __half2 h2trunc(__half2 x) { + return __half2{__ocml_trunc_2f16(x)}; +} +inline __HOST_DEVICE__ __half2 h2ceil(__half2 x) { + return __half2{__ocml_ceil_2f16(x)}; +} +inline __HOST_DEVICE__ __half2 h2floor(__half2 x) { + return __half2{__ocml_floor_2f16(x)}; +} +inline __HOST_DEVICE__ __half2 h2rint(__half2 x) { + return __half2{__ocml_rint_2f16(x)}; +} +inline __HOST_DEVICE__ __half2 h2sin(__half2 x) { + return __half2{__ocml_sin_2f16(x)}; +} +inline __HOST_DEVICE__ __half2 h2cos(__half2 x) { + return __half2{__ocml_cos_2f16(x)}; +} +inline __HOST_DEVICE__ __half2 h2exp(__half2 x) { + return __half2{__ocml_exp_2f16(x)}; +} +inline __HOST_DEVICE__ __half2 h2exp2(__half2 x) { + return __half2{__ocml_exp2_2f16(x)}; +} +inline __HOST_DEVICE__ __half2 h2exp10(__half2 x) { + return __half2{__ocml_exp10_2f16(x)}; +} +inline __HOST_DEVICE__ __half2 h2log2(__half2 x) { + return __half2{__ocml_log2_2f16(x)}; +} +inline __HOST_DEVICE__ __half2 h2log(__half2 x) { return __ocml_log_2f16(x); } +inline __HOST_DEVICE__ __half2 h2log10(__half2 x) { + return __ocml_log10_2f16(x); +} +inline __HOST_DEVICE__ __half2 h2rcp(__half2 x) { + return _Float16_2{ + _Float16_2{static_cast<_Float16>(1.0f), static_cast<_Float16>(1.0f)} / + x.data}; +} +inline __HOST_DEVICE__ __half2 h2rsqrt(__half2 x) { + return __ocml_rsqrt_2f16(x); +} +inline __HOST_DEVICE__ __half2 h2sqrt(__half2 x) { return __ocml_sqrt_2f16(x); } +inline __HOST_DEVICE__ __half2 __hisinf2(__half2 x) { + auto r = __ocml_isinf_2f16(x); + return __half2{ + _Float16_2{static_cast<_Float16>(r.x), static_cast<_Float16>(r.y)}}; +} +inline __HOST_DEVICE__ __half2 __hisnan2(__half2 x) { + auto r = __ocml_isnan_2f16(x); + return __half2{ + _Float16_2{static_cast<_Float16>(r.x), static_cast<_Float16>(r.y)}}; +} +inline __HOST_DEVICE__ __half2 __hneg2(__half2 x) { + return __half2{-static_cast<__half2_raw>(x).data}; +} +} // Anonymous namespace. + +#if !defined(HIP_NO_HALF) +using half = __half; +using half2 = __half2; +#endif +__device__ inline __half __shfl(__half var, int src_lane, + int width = warpSize) { + union { + int i; + __half h; + } tmp; + tmp.h = var; + tmp.i = __shfl(tmp.i, src_lane, width); + return tmp.h; +} +__device__ inline __half2 __shfl(__half2 var, int src_lane, + int width = warpSize) { + union { + int i; + __half2 h; + } tmp; + tmp.h = var; + tmp.i = __shfl(tmp.i, src_lane, width); + return tmp.h; +} +__device__ inline __half __shfl_up(__half var, unsigned int lane_delta, + int width = warpSize) { + union { + int i; + __half h; + } tmp; + tmp.h = var; + tmp.i = __shfl_up(tmp.i, lane_delta, width); + return tmp.h; +} +__device__ inline __half2 __shfl_up(__half2 var, unsigned int lane_delta, + int width = warpSize) { + union { + int i; + __half2 h; + } tmp; + tmp.h = var; + tmp.i = __shfl_up(tmp.i, lane_delta, width); + return tmp.h; +} +__device__ inline __half __shfl_down(__half var, unsigned int lane_delta, + int width = warpSize) { + union { + int i; + __half h; + } tmp; + tmp.h = var; + tmp.i = __shfl_down(tmp.i, lane_delta, width); + return tmp.h; +} +__device__ inline __half2 __shfl_down(__half2 var, unsigned int lane_delta, + int width = warpSize) { + union { + int i; + __half2 h; + } tmp; + tmp.h = var; + tmp.i = __shfl_down(tmp.i, lane_delta, width); + return tmp.h; +} +__device__ inline __half __shfl_xor(__half var, int lane_mask, + int width = warpSize) { + union { + int i; + __half h; + } tmp; + tmp.h = var; + tmp.i = __shfl_xor(tmp.i, lane_mask, width); + return tmp.h; +} +__device__ inline __half2 __shfl_xor(__half2 var, int lane_mask, + int width = warpSize) { + union { + int i; + __half2 h; + } tmp; + tmp.h = var; + tmp.i = __shfl_xor(tmp.i, lane_mask, width); + return tmp.h; +} +#endif // defined(__cplusplus) +#elif defined(__GNUC__) +#if !defined(__HIPCC_RTC__) +#include "hip_fp16_gcc.h" +#endif +#endif // !defined(__clang__) && defined(__GNUC__) + +#endif // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_fp8.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_fp8.h new file mode 100644 index 000000000..6df377686 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_fp8.h @@ -0,0 +1,1487 @@ +/** + * MIT License + * + * Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * \file + * \brief amd_hip_fp8.h header, for AMD fp8 data types + */ + +#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_ +#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_ + +#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \ + __HIP_DEVICE_COMPILE__ +#define HIP_FP8_CVT_FAST_PATH 1 +#else +#define HIP_FP8_CVT_FAST_PATH 0 +#endif + +#if !defined(__HIPCC_RTC__) +#include +#include + +#include "amd_hip_bf16.h" // bf16 +#include "amd_hip_fp16.h" // __half_raw +#include "amd_hip_vector_types.h" // float2 etc +#include "host_defines.h" // __hip_internal:: +#include "math_fwd.h" // ocml device functions +#endif // !defined(__HIPCC_RTC__) + +#if defined(__HIPCC_RTC__) +#define __FP8_HOST_DEVICE__ __device__ +#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static +#else +#define __FP8_HOST_DEVICE__ __host__ __device__ +#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static inline +#endif // __HIPCC_RTC__ + +#if !defined(__HIPCC_RTC__) +static_assert(CHAR_BIT == 8, "byte size should be of 8 bits"); +#endif +static_assert(sizeof(unsigned char) == 1); +static_assert(sizeof(unsigned short int) == 2); +static_assert(sizeof(unsigned int) == 4); + +/** + * \brief Describes FP8 interpretation + */ +enum __hip_fp8_interpretation_t { + __HIP_E4M3_FNUZ = 0, /**< Standard FP8 */ + __HIP_E5M2_FNUZ = 1, /**< BF8 */ +}; + +/** + * \brief Describes saturation behavior + */ +enum __hip_saturation_t { + __HIP_NOSAT = 0, /**< No saturation */ + __HIP_SATFINITE = 1, /**< Saturate to finite */ +}; + +/** \typedef __hip_fp8_storage_t + * + * \brief type to store single fp8 number + */ +typedef unsigned char __hip_fp8_storage_t; + +/** \typedef __hip_fp8x2_storage_t + * + * \brief type to store two fp8 numbers + */ +typedef unsigned short int __hip_fp8x2_storage_t; + +/** \typedef __hip_fp8x4_storage_t + * + * \brief type to store four fp8 numbers + */ +typedef unsigned int __hip_fp8x4_storage_t; + +namespace internal { +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39 +// This has been modified to add double types conversion as well +template +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t +cast_to_f8(T _x, int wm, int we, bool clip = false, bool stoch = false, + unsigned int rng = 0) { + constexpr bool is_half = __hip_internal::is_same::value; + constexpr bool is_float = __hip_internal::is_same::value; + constexpr bool is_double = __hip_internal::is_same::value; + static_assert(is_half || is_float || is_double, + "Only half, float and double can be cast to f8"); + + const int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10); + unsigned long long x; + + if (sizeof(T) == 8) + x = reinterpret_cast(_x); + else if (sizeof(T) == 4) + x = reinterpret_cast(_x); + else + x = reinterpret_cast(_x); + + unsigned long long head, mantissa; + int exponent, bias; + unsigned int sign; + + if (sizeof(T) == 8) { + head = x & 0xFFF0000000000000ull; + mantissa = x & 0xFFFFFFFFFFFFFull; + exponent = (head >> 52) & 0x7FF; + sign = head >> 63; + bias = 1023; + } else if (sizeof(T) == 4) { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } else { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + unsigned int signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if (negative_zero_nan) { + if (sizeof(T) == 8) { + if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull) + return 0x80; + } else if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) + return 0x80; + } else { + if ((x & 0x7C00) == 0x7C00) + return 0x80; + } + } else { + if (sizeof(T) == 8) { + if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull) + return signed_inf + (mantissa != 0 ? 1 : 0); + } else if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) + return signed_inf + (mantissa != 0 ? 1 : 0); + } else { + if ((x & 0x7C00) == 0x7C00) + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } + + if (x == 0) { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implict 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_denormal_act_exponent = + 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if (exponent == 0) { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we +mostly concern fp16 here. In this case, f8 is usually in denormal. But there +could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has +exponent bias 16. It means that there are some numbers in fp16 denormal but they +are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 +(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = + f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } else { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if (act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal +range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 +actual exponent is -7, it is actually larger due to the implict 1, +Therefore it needs to be adjust to -6 and mantissa shift right by 1. +So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } else { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no difference + // for this case, act_exponent could be larger. Just + // that it does not need shift mantissa + } + mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) == + (1ull << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be +done before we shift right as shift right could rip off some residual part and +make something not midpoint look like midpoint. For example, the fp16 number +0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right +by 4 bits, it would look like midpoint. +*/ + + if (exponent_diff > 0) + mantissa >>= exponent_diff; + else if (exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1ull << mfmt); + // if there is no implict 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1; + bool odd = + mantissa & + (1ull << (mfmt - + wm)); // if the least significant bit that is not truncated is 1 + mantissa += + (stoch ? rng + : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & + drop_mask; + + // Now we deal with overflow + if (f8_exponent == 0) { + if ((1ull << mfmt) & mantissa) { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } else { + if ((1ull << (mfmt + 1)) & mantissa) { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if (f8_exponent > max_exp) { + if (clip) { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } else { + return signed_inf; + } + } + + if (f8_exponent == 0 && mantissa == 0) + return negative_zero_nan ? 0 : (sign << 7); + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220 +// This has been modified to handle double types as well +template +__FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, + int we) { + constexpr bool is_half = __hip_internal::is_same::value; + constexpr bool is_float = __hip_internal::is_same::value; + constexpr bool is_double = __hip_internal::is_same::value; + static_assert(is_half || is_float || is_double, + "only half, float and double are supported"); + + constexpr int weo = is_half ? 5 : (is_float ? 8 : 11); + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52); + + T fInf, fNegInf, fNaN, fNeg0; + if (is_half) { + const unsigned short int ihInf = 0x7C00; + const unsigned short int ihNegInf = 0xFC00; + const unsigned short int ihNaN = 0x7C01; + const unsigned short int ihNeg0 = 0x8000; + fInf = reinterpret_cast(ihInf); + fNegInf = reinterpret_cast(ihNegInf); + fNaN = reinterpret_cast(ihNaN); + fNeg0 = reinterpret_cast(ihNeg0); + } else if (is_float) { + const unsigned int ifInf = 0x7F800000; + const unsigned int ifNegInf = 0xFF800000; + const unsigned int ifNaN = 0x7F800001; + const unsigned int ifNeg0 = 0x80000000; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } else if (is_double) { + const unsigned long long ifInf = 0x7FF0000000000000ull; + const unsigned long long ifNegInf = 0xFFF0000000000000ull; + const unsigned long long ifNaN = 0x7FF0000000000001ull; + const unsigned long long ifNeg0 = 0x8000000000000000ull; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } + + if (x == 0) { + return 0; + } + + unsigned long long sign = x >> 7; + unsigned long long mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if (negative_zero_nan) { + if (x == 0x80) + return fNaN; + } else { + if (x == 0x80) + return fNeg0; + if (exponent == ((1 << we) - 1)) + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + } + + typename __hip_internal::conditional< + sizeof(T) == 2, unsigned short int, + typename __hip_internal::conditional< + sizeof(T) == 4, unsigned int, unsigned long long>::type>::type retval; + + if (we == 5 && is_half && !negative_zero_nan) { + retval = x << 8; + return reinterpret_cast(retval); + } + + const int exp_low_cutoff = + (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + // subnormal input + if (exponent == 0) { +#if __HIP_DEVICE_COMPILE__ + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + __clz(mantissa) - (32 - wm); +#else + int sh = 1 + __builtin_clz(mantissa) - (32 - wm); +#endif + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1ull << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if (exponent <= 0) { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if (sizeof(T) == 2) + retval = (sign << 15) | (exponent << 10) | mantissa; + else if (sizeof(T) == 4) + retval = (sign << 31) | (exponent << 23) | mantissa; + else + retval = (sign << 63) | (static_cast(exponent) << 52) | + mantissa; + return reinterpret_cast(retval); +} + +#if HIP_FP8_CVT_FAST_PATH +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79 +template +static __device__ __hip_fp8_storage_t cast_to_f8_from_f32( + float v, bool saturate, __hip_fp8_interpretation_t interpret, + unsigned int rng = 0) { + __hip_fp8_storage_t i8data; + union { + float fval; + unsigned int i32val; + unsigned char i8val[4]; // NOTE: not endian independent + } val; + + unsigned int ival = 0; + val.fval = v; + + if (saturate) { + if (interpret == __HIP_E4M3_FNUZ) { + if ((val.i32val & 0x7F800000) != + 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + } else { + if ((val.i32val & 0x7F800000) != + 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0); + } + } + } + + if (stochastic_rounding) { + ival = + interpret == __HIP_E4M3_FNUZ + ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0) + : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + i8data = val.i8val[0]; // little endian + } else { // RNE CVT + ival = + interpret == __HIP_E4M3_FNUZ + ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false) + : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + } + return i8data; +} + +static __device__ __hip_fp8x2_storage_t cast_to_f8x2_from_f32x2( + float2 v, bool saturate, __hip_fp8_interpretation_t interpret) { + union { + static_assert(sizeof(float2) == sizeof(unsigned int[2])); + static_assert(sizeof(float2) == sizeof(unsigned short[4])); + float2 fval; + unsigned int i32val[2]; + unsigned short i16val[4]; + } f2val; + + f2val.fval = v; + + if (saturate) { /// propagate NAN/INF, no clipping + if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) { + f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0); + } + if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) { + f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0); + } + } + + f2val.i32val[0] = interpret == __HIP_E4M3_FNUZ + ? __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.y, 0, false) + : __builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.y, 0, false); + + return static_cast<__hip_fp8x2_storage_t>(f2val.i16val[0]); +} + +static __device__ float +cast_to_f32_from_f8(__hip_fp8_storage_t v, + __hip_fp8_interpretation_t interpret) { + union { + unsigned int i32val; + unsigned char i8val[4]; + } val; + val.i8val[0] = v; + + float fval = interpret == __HIP_E4M3_FNUZ + ? __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0) + : __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0); + return fval; +} + +static __device__ float2 cast_to_f32x2_from_f8x2( + __hip_fp8x2_storage_t v, __hip_fp8_interpretation_t interpret) { + union { + unsigned int i32val; + unsigned short i16val[2]; + } val; + val.i16val[0] = v; + + auto f2 = interpret == __HIP_E4M3_FNUZ + ? __builtin_amdgcn_cvt_pk_f32_fp8(val.i32val, false) + : __builtin_amdgcn_cvt_pk_f32_bf8(val.i32val, false); + return float2{f2[0], f2[1]}; +} +#endif // HIP_FP8_CVT_FAST_PATH + +/* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned. +Inf are not supported. This gives us one additional number to represent. +NaN are represented by 1-0000-000 or 1-00000-00 */ +__FP8_HOST_DEVICE_STATIC__ bool hip_fp8_fnuz_is_nan(__hip_fp8_storage_t a) { + return static_cast(a) == 0x80; +} +} // namespace internal + +/** + * \brief convert float to @p __hip_fp8_storage_t + * + * \param f float number + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t +__hip_cvt_float_to_fp8(const float f, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t type) { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f8_from_f32(f, sat == __HIP_SATFINITE, type); +#else // HIP_FP8_CVT_FAST_PATH + int we = type == __HIP_E4M3_FNUZ ? 4 : 5; + int wm = type == __HIP_E4M3_FNUZ ? 3 : 2; + return internal::cast_to_f8(f, wm, we, sat == __HIP_SATFINITE); +#endif // HIP_FP8_CVT_FAST_PATH +} + +/** + * \brief convert float2 to @p __hip_fp8x2_storage_t + * + * \param f2 float2 number + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8x2_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t +__hip_cvt_float2_to_fp8x2(const float2 f2, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t type) { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f8x2_from_f32x2(f2, sat == __HIP_SATFINITE, type); +#else + return static_cast<__hip_fp8x2_storage_t>( + static_cast(__hip_cvt_float_to_fp8(f2.y, sat, type)) + << 8 | + static_cast(__hip_cvt_float_to_fp8(f2.x, sat, type))); +#endif +} + +/** + * \brief convert double to @p __hip_fp8_storage_t + * + * \param d double val + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t +__hip_cvt_double_to_fp8(const double d, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t type) { + int we = type == __HIP_E4M3_FNUZ ? 4 : 5; + int wm = type == __HIP_E4M3_FNUZ ? 3 : 2; + return internal::cast_to_f8(d, wm, we, sat == __HIP_SATFINITE); +} + +/** + * \brief convert double2 to @p __hip_fp8x2_storage_t + * + * \param d2 double2 val + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8x2_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t +__hip_cvt_double2_to_fp8x2(const double2 d2, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t type) { + return static_cast<__hip_fp8x2_storage_t>( + static_cast(__hip_cvt_double_to_fp8(d2.y, sat, type)) + << 8 | + static_cast( + __hip_cvt_double_to_fp8(d2.x, sat, type))); +} + +/** + * \brief convert __hip_bfloat16_raw to @p __hip_fp8_storage_t + * + * \param hr __hip_bfloat16_raw val + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_bfloat16raw_to_fp8( + const __hip_bfloat16_raw hr, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t type) { + float fval = __hip_bfloat16(hr); + return __hip_cvt_float_to_fp8(fval, sat, type); +} + +/** + * \brief convert double2 to @p __hip_fp8x2_storage_t + * + * \param hr __hip_bfloat162_raw value + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8x2_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t +__hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, + const __hip_saturation_t sat, + const __hip_fp8_interpretation_t type) { + float2 f2 = __hip_bfloat162(hr); + return __hip_cvt_float2_to_fp8x2(f2, sat, type); +} + +/** + * \brief convert @p __hip_fp8_storage_t to __half_raw + * + * \param x __hip_fp8_storage_t val + * \param type interpretation of fp8 + * \return __half_raw + */ +__FP8_HOST_DEVICE_STATIC__ __half_raw __hip_cvt_fp8_to_halfraw( + const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t type) { + unsigned int we = type == __HIP_E4M3_FNUZ ? 4 : 5; + unsigned int wm = type == __HIP_E4M3_FNUZ ? 3 : 2; + return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)}; +} + +/** + * \brief convert @p __hip_fp8x2_storage_t to __half2_raw + * + * \param x __hip_fp8x2_storage_t val + * \param type interpretation of fp8 + * \return __half2_raw + */ +__FP8_HOST_DEVICE_STATIC__ __half2_raw __hip_cvt_fp8x2_to_halfraw2( + const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t type) { + __half2 ret(static_cast<__half>(__hip_cvt_fp8_to_halfraw( + static_cast<__hip_fp8_storage_t>(x & 0xFF), type)), + static_cast<__half>(__hip_cvt_fp8_to_halfraw( + static_cast<__hip_fp8_storage_t>(x >> 8), type))); + return static_cast<__half2_raw>(ret); +} + +/** + * \brief convert __half_raw to @p __hip_fp8_storage_t + * + * \param x __half_raw value + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t +__hip_cvt_halfraw_to_fp8(const __half_raw x, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t type) { + return __hip_cvt_float_to_fp8(__half2float(__half(x)), sat, type); +} + +/** + * \brief convert __half2_raw to @p __hip_fp8x2_storage_t + * + * \param x __half2_raw value + * \param sat saturation of fp8 + * \param type interpretation of fp8 + * \return __hip_fp8x2_storage_t + */ +__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t +__hip_cvt_halfraw2_to_fp8x2(const __half2_raw x, const __hip_saturation_t sat, + const __hip_fp8_interpretation_t type) { + return __hip_cvt_float2_to_fp8x2(__half22float2(__half2(x)), sat, type); +} + +/** + * \brief struct representing single fp8 number with e4m3 interpretation + * + */ +struct __hip_fp8_e4m3_fnuz { + __hip_fp8_storage_t __x; //! raw storage of fp8 number + constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE; + constexpr static __hip_fp8_interpretation_t __default_interpret = + __HIP_E4M3_FNUZ; + constexpr static unsigned int __we = 4; + constexpr static unsigned int __wm = 3; + + // TODO: SWDEV-452411 + // Add cast from unsigned long long, long long to fp8 + + /*! create fp8 e4m3 from long */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const long int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e4m3 from int */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e4m3 from short int */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const short int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e4m3 from unsigned long */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned long int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e4m3 from unsigned int */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e4m3 from unsigned short */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned short int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e4m3 from double */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const double f) + : __x(__hip_cvt_double_to_fp8(f, __default_saturation, + __default_interpret)) {} + + /*! create fp8 e4m3 from float */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const float f) + : __x(__hip_cvt_float_to_fp8(f, __default_saturation, + __default_interpret)) {} + + /*! create fp8 e4m3 from __hip_bfloat16 */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f) + : __x(__hip_cvt_float_to_fp8(static_cast(f), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e4m3 from __half */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __half f) + : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), + __default_saturation, + __default_interpret)) {} + + /*! default construct fp8 e4m3 */ + __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz() = default; + + /*! convert fp8 e4m3 to __half */ + __FP8_HOST_DEVICE__ operator __half() const { + return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret)); + } + + /*! convert fp8 e4m3 to __hip_bfloat16 */ + __FP8_HOST_DEVICE__ operator __hip_bfloat16() const { + float f = *this; + return __hip_bfloat16(f); + } + + /*! convert fp8 e4m3 to bool, return false if value is 0, true otherwise */ + __FP8_HOST_DEVICE__ operator bool() const { + // it can be 0x00 (+0.0) since 0x80 will be nan + return !(static_cast(__x) == 0); + } + + /*! convert fp8 e4m3 to char, clamp number to CHAR_MIN/CHAR_MAX if its out of + * range */ + __FP8_HOST_DEVICE__ operator char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + auto fval = internal::cast_from_f8(__x, __wm, __we); + auto llval = static_cast(fval); + if (llval <= CHAR_MIN) { + return CHAR_MIN; + } else if (llval >= CHAR_MAX) { + return CHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to double */ + __FP8_HOST_DEVICE__ operator double() const { + return internal::cast_from_f8(__x, __wm, __we); + } + + /*! convert fp8 e4m3 to float */ + __FP8_HOST_DEVICE__ operator float() const { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f32_from_f8(__x, __default_interpret); +#else + return internal::cast_from_f8(__x, __wm, __we); +#endif + } + + /*! convert fp8 e4m3 to int, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e4m3 to long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e4m3 to long long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator long long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e4m3 to short int, clamp out of bound values, return 0 if + * value is NaN */ + __FP8_HOST_DEVICE__ operator short int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= SHRT_MIN) { + return SHRT_MIN; + } else if (llval >= SHRT_MAX) { + return SHRT_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to signed char, clamp out of bound values, return 0 if + * value is NaN */ + __FP8_HOST_DEVICE__ operator signed char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= SCHAR_MIN) { + return SCHAR_MIN; + } else if (llval >= SCHAR_MAX) { + return SCHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to unsigned char, clamp out of bound values, return 0 if + * value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } else if (llval >= UCHAR_MAX) { + return UCHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to unsigned int, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to unsigned long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to long long int, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned long long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e4m3 to unsigned short, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned short int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } +}; + +/** + * \brief struct representing two fp8 numbers with e4m3 interpretation + * + */ +struct __hip_fp8x2_e4m3_fnuz { + __hip_fp8x2_storage_t __x; //! raw storage of two fp8 numbers + static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE; + static constexpr __hip_fp8_interpretation_t __default_interpret = + __HIP_E4M3_FNUZ; + static constexpr unsigned int __we = 4; + static constexpr unsigned int __wm = 3; + + /*! create fp8x2 e4m3 type from double2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const double2 val) + : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, + __default_interpret)) {} + + /*! create fp8x2 e4m3 type from float2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const float2 val) + : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, + __default_interpret)) {} + + /*! create fp8x2 e4m3 type from __hip_bfloat162 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val) + : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, + __default_interpret)) {} + + /*! create fp8x2 e4m3 type from __half2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __half2 val) + : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, + __default_interpret)) {} + + /*! Default construct of fp8x2 e4m3 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz() = default; + + /*! convert fp8x2 e4m3 to __half2 */ + __FP8_HOST_DEVICE__ operator __half2() const { + return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret)); + } + + /*! convert fp8x2 e4m3 to float2 */ + __FP8_HOST_DEVICE__ operator float2() const { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret); +#else + return float2(internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(__x & 0xFF), __wm, __we), + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(__x >> 8), __wm, __we)); +#endif + } +}; + +/** + * \brief struct representing four fp8 numbers with e4m3 interpretation + * + */ +struct __hip_fp8x4_e4m3_fnuz { + __hip_fp8x4_storage_t __x; //! raw storage of four fp8 numbers + static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE; + static constexpr __hip_fp8_interpretation_t __default_interpret = + __HIP_E4M3_FNUZ; + static constexpr unsigned int __we = 4; + static constexpr unsigned int __wm = 3; + + /*! create fp8x4 e4m3 type from double4 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const double4 val) + : __x{reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( + reinterpret_cast(__hip_cvt_double_to_fp8( + val.x, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.y, __default_saturation, __default_interpret)) + << 8 | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.z, __default_saturation, __default_interpret)) + << 16 | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.w, __default_saturation, __default_interpret)) + << 24))} {} + + /*! create fp8x4 e4m3 type from float4 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const float4 val) + : __x{reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( + reinterpret_cast(__hip_cvt_float_to_fp8( + val.x, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.y, __default_saturation, __default_interpret)) + << 8 | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.z, __default_saturation, __default_interpret)) + << 16 | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.w, __default_saturation, __default_interpret)) + << 24))} {} + + /*! create fp8x4 e4m3 type from two __hip_bfloat162 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, + const __hip_bfloat162 high) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( + reinterpret_cast(__hip_cvt_bfloat16raw2_to_fp8x2( + high, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_bfloat16raw2_to_fp8x2( + low, __default_saturation, __default_interpret)) + << 16))) {} + + /*! create fp8x4 e4m3 type from two __half2 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __half2 low, + const __half2 high) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( + reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( + high, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( + low, __default_saturation, __default_interpret)) + << 16))) {} + + /*! Default construct fp8x4 e4m3 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz() = default; + + /*! convert fp8x4 e4m3 to float4 */ + __FP8_HOST_DEVICE__ operator float4() const { + auto x = __x; // bypass const + auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t *>(&x); // Little E + auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t *>(&x) + 1); +#if HIP_FP8_CVT_FAST_PATH + float2 high = + internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret); + float2 low = + internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret); +#else + float2 high = float2( + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, + __we), + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we)); + float2 low = float2( + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, + __we), + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we)); +#endif + return float4(low.x, low.y, high.x, high.y); + } +}; + +/** + * \brief struct representing one fp8 number with e5m2 interpretation + * + */ +struct __hip_fp8_e5m2_fnuz { + __hip_fp8_storage_t __x; //! raw storage of one fp8 numbers + static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE; + static constexpr __hip_fp8_interpretation_t __default_interpret = + __HIP_E5M2_FNUZ; + static constexpr unsigned int __we = 5; + static constexpr unsigned int __wm = 2; + + // TODO: SWDEV-452411 + // Add cast from unsigned long long, long long to fp8 + + /*! create fp8 e5m2 type from long */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const long int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e5m2 type from int */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e5m2 type from short int */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const short int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e5m2 type from unsigned long */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned long int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e5m2 type from unsigned int */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e5m2 type from unsigned short */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned short int val) + : __x(__hip_cvt_float_to_fp8(static_cast(val), + __default_saturation, __default_interpret)) { + } + + /*! create fp8 e5m2 type from double */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const double f) + : __x(__hip_cvt_double_to_fp8(f, __default_saturation, + __default_interpret)) {} + + /*! create fp8 e5m2 type from float */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const float f) + : __x(__hip_cvt_float_to_fp8(f, __default_saturation, + __default_interpret)) {} + + /*! create fp8 e5m2 type from __hip_bfloat16 */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f) + : __x(__hip_cvt_float_to_fp8(static_cast(f), __default_saturation, + __default_interpret)) {} + + /*! create fp8 e5m2 type from __hip_bfloat16 */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __half f) + : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), + __default_saturation, + __default_interpret)) {} + + /*! default construct fp8 e5m2 */ + __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz() = default; + + /*! convert fp8 e5m2 to float */ + __FP8_HOST_DEVICE__ operator float() const { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f32_from_f8(__x, __default_interpret); +#else + return internal::cast_from_f8(__x, __wm, __we); +#endif + } + + /*! convert fp8 e5m2 to __half */ + __FP8_HOST_DEVICE__ operator __half() const { + return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret)); + } + + /*! convert fp8 e5m2 to __hip_bfloat16 */ + __FP8_HOST_DEVICE__ operator __hip_bfloat16() const { + float f = *this; + return __hip_bfloat16(f); + } + + /*! convert fp8 e4m3 to bool, return false if value is 0, true otherwise */ + __FP8_HOST_DEVICE__ operator bool() const { + // it can be 0x00 (+0.0) since 0x80 will be nan + return !(static_cast(__x) == 0); + } + + /*! convert fp8 e5m2 to char, clamp out of bound values, return 0 if value is + * NaN */ + __FP8_HOST_DEVICE__ operator char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= CHAR_MIN) { + return CHAR_MIN; + } else if (llval >= CHAR_MAX) { + return CHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to double */ + __FP8_HOST_DEVICE__ operator double() const { + return internal::cast_from_f8(__x, __wm, __we); + } + + /*! convert fp8 e5m2 to int, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e5m2 to long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e5m2 to long long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator long long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + return static_cast(fval); + } + + /*! convert fp8 e5m2 to short, clamp out of bound values, return 0 if value is + * NaN */ + __FP8_HOST_DEVICE__ operator short int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= SHRT_MIN) { + return SHRT_MIN; + } else if (llval >= SHRT_MAX) { + return SHRT_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to signed char, clamp out of bound values, return 0 if + * value is NaN */ + __FP8_HOST_DEVICE__ operator signed char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= SCHAR_MIN) { + return SCHAR_MIN; + } else if (llval >= SCHAR_MAX) { + return SCHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to unsigned char, clamp out of bound values, return 0 if + * value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned char() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } else if (llval >= UCHAR_MAX) { + return UCHAR_MAX; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to unsigned int, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to unsigned long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to unsigned long long, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned long long int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } + + /*! convert fp8 e5m2 to unsigned short, return 0 if value is NaN */ + __FP8_HOST_DEVICE__ operator unsigned short int() const { + if (internal::hip_fp8_fnuz_is_nan(__x)) { + return 0; + } + + float fval = *this; + auto llval = static_cast(fval); + if (llval <= 0) { + return 0; + } + return static_cast(fval); + } +}; + +/** + * \brief struct representing two fp8 numbers with e5m2 interpretation + * + */ +struct __hip_fp8x2_e5m2_fnuz { + __hip_fp8x2_storage_t __x; //! raw storage of two fp8 numbers + static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE; + static constexpr __hip_fp8_interpretation_t __default_interpret = + __HIP_E5M2_FNUZ; + static constexpr unsigned int __we = 5; + static constexpr unsigned int __wm = 2; + + /*! create fp8x2 e5m2 type from double2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const double2 val) + : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, + __default_interpret)) {} + + /*! create fp8x2 e5m2 type from float2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const float2 val) + : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, + __default_interpret)) {} + + /*! create fp8x2 e5m2 type from __hip_bfloat162 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val) + : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, + __default_interpret)) {} + + /*! create fp8x2 e5m2 type from __half2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __half2 val) + : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, + __default_interpret)) {} + + /*! default construct fp8x2 e5m2 */ + __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz() = default; + + /*! convert fp8x2 e5m2 to __half2 */ + __FP8_HOST_DEVICE__ operator __half2() const { + return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret)); + } + + /*! convert fp8x2 e5m2 to float2 */ + __FP8_HOST_DEVICE__ operator float2() const { +#if HIP_FP8_CVT_FAST_PATH + return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret); +#else + return float2(internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(__x & 0xFF), __wm, __we), + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(__x >> 8), __wm, __we)); +#endif + } +}; + +/** + * \brief struct representing four fp8 numbers with e5m2 interpretation + * + */ +struct __hip_fp8x4_e5m2_fnuz { + __hip_fp8x4_storage_t __x; //! raw storage of four fp8 numbers + static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE; + static constexpr __hip_fp8_interpretation_t __default_interpret = + __HIP_E5M2_FNUZ; + static constexpr unsigned int __we = 5; + static constexpr unsigned int __wm = 2; + + /*! create fp8x4 e5m2 type from double4 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const double4 val) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( + reinterpret_cast(__hip_cvt_double_to_fp8( + val.x, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.y, __default_saturation, __default_interpret)) + << 8 | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.z, __default_saturation, __default_interpret)) + << 16 | + reinterpret_cast(__hip_cvt_double_to_fp8( + val.w, __default_saturation, __default_interpret)) + << 24))) {} + + /*! create fp8x4 e5m2 type from float4 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const float4 val) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( + reinterpret_cast(__hip_cvt_float_to_fp8( + val.x, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.y, __default_saturation, __default_interpret)) + << 8 | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.z, __default_saturation, __default_interpret)) + << 16 | + reinterpret_cast(__hip_cvt_float_to_fp8( + val.w, __default_saturation, __default_interpret)) + << 24))) {} + + /*! create fp8x4 e5m2 type from two __hip_bfloat162 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, + const __hip_bfloat162 high) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( + reinterpret_cast(__hip_cvt_bfloat16raw2_to_fp8x2( + high, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_bfloat16raw2_to_fp8x2( + low, __default_saturation, __default_interpret)) + << 16))) {} + + /*! create fp8x4 e5m2 type from two __half2 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __half2 low, + const __half2 high) + : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast( + reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( + high, __default_saturation, __default_interpret)) | + reinterpret_cast(__hip_cvt_halfraw2_to_fp8x2( + low, __default_saturation, __default_interpret)) + << 16))) {} + + /* default construct fp8x4 e5m2 */ + __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz() = default; + + /*! convert fp8x4 e5m2 to float4 */ + __FP8_HOST_DEVICE__ operator float4() const { + auto x = __x; // bypass const + auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t *>(&x); // Little E + auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t *>(&x) + 1); +#if HIP_FP8_CVT_FAST_PATH + float2 high = + internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret); + float2 low = + internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret); +#else + float2 high = float2( + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, + __we), + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we)); + float2 low = float2( + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, + __we), + internal::cast_from_f8( + static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we)); +#endif + return float4(low.x, low.y, high.x, high.y); + } +}; + +#endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_gl_interop.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_gl_interop.h new file mode 100644 index 000000000..5a96e72cb --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_gl_interop.h @@ -0,0 +1,115 @@ +/* +Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_AMD_HIP_GL_INTEROP_H +#define HIP_INCLUDE_AMD_HIP_GL_INTEROP_H + +#if defined(__cplusplus) +extern "C" { +#endif + +/** + * + * @addtogroup GlobalDefs + * @{ + * + */ + +/** + * HIP Devices used by current OpenGL Context. + */ +typedef enum hipGLDeviceList { + hipGLDeviceListAll = 1, ///< All hip devices used by current OpenGL context. + hipGLDeviceListCurrentFrame = 2, ///< Hip devices used by current OpenGL + ///< context in current frame + hipGLDeviceListNextFrame = 3 ///< Hip devices used by current OpenGL context + ///< in next frame. +} hipGLDeviceList; + +/** GLuint as uint.*/ +typedef unsigned int GLuint; +/** GLenum as uint.*/ +typedef unsigned int GLenum; +/** + * @} + */ + +/** + * @ingroup GL + * @{ + * + */ +/** + * @brief Queries devices associated with the current OpenGL context. + * + * @param [out] pHipDeviceCount - Pointer of number of devices on the current GL + * context. + * @param [out] pHipDevices - Pointer of devices on the current OpenGL context. + * @param [in] hipDeviceCount - Size of device. + * @param [in] deviceList - The setting of devices. It could be either + * hipGLDeviceListCurrentFrame for the devices used to render the current frame, + * or hipGLDeviceListAll for all devices. The default setting is Invalid + * deviceList value. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + */ +hipError_t hipGLGetDevices(unsigned int *pHipDeviceCount, int *pHipDevices, + unsigned int hipDeviceCount, + hipGLDeviceList deviceList); +/** + * @brief Registers a GL Buffer for interop and returns corresponding graphics + * resource. + * + * @param [out] resource - Returns pointer of graphics resource. + * @param [in] buffer - Buffer to be registered. + * @param [in] flags - Register flags. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorUnknown, + * #hipErrorInvalidResourceHandle + * + */ +hipError_t hipGraphicsGLRegisterBuffer(hipGraphicsResource **resource, + GLuint buffer, unsigned int flags); +/** + * @brief Register a GL Image for interop and returns the corresponding graphic + * resource. + * + * @param [out] resource - Returns pointer of graphics resource. + * @param [in] image - Image to be registered. + * @param [in] target - Valid target value Id. + * @param [in] flags - Register flags. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorUnknown, + * #hipErrorInvalidResourceHandle + * + */ +hipError_t hipGraphicsGLRegisterImage(hipGraphicsResource **resource, + GLuint image, GLenum target, + unsigned int flags); +/** + * @} + */ +#if defined(__cplusplus) +} +#endif /* __cplusplus */ +#endif /* HIP_INCLUDE_AMD_HIP_GL_INTEROP_H */ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_math_constants.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_math_constants.h new file mode 100644 index 000000000..221d48993 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_math_constants.h @@ -0,0 +1,124 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#ifndef AMD_HIP_MATH_CONSTANTS_H +#define AMD_HIP_MATH_CONSTANTS_H + +// single precision constants +#define HIP_INF_F __int_as_float(0x7f800000U) +#define HIP_NAN_F __int_as_float(0x7fffffffU) +#define HIP_MIN_DENORM_F __int_as_float(0x00000001U) +#define HIP_MAX_NORMAL_F __int_as_float(0x7f7fffffU) +#define HIP_NEG_ZERO_F __int_as_float(0x80000000U) +#define HIP_ZERO_F 0.0F +#define HIP_ONE_F 1.0F +#define HIP_SQRT_HALF_F 0.707106781F +#define HIP_SQRT_HALF_HI_F 0.707106781F +#define HIP_SQRT_HALF_LO_F 1.210161749e-08F +#define HIP_SQRT_TWO_F 1.414213562F +#define HIP_THIRD_F 0.333333333F +#define HIP_PIO4_F 0.785398163F +#define HIP_PIO2_F 1.570796327F +#define HIP_3PIO4_F 2.356194490F +#define HIP_2_OVER_PI_F 0.636619772F +#define HIP_SQRT_2_OVER_PI_F 0.797884561F +#define HIP_PI_F 3.141592654F +#define HIP_L2E_F 1.442695041F +#define HIP_L2T_F 3.321928094F +#define HIP_LG2_F 0.301029996F +#define HIP_LGE_F 0.434294482F +#define HIP_LN2_F 0.693147181F +#define HIP_LNT_F 2.302585093F +#define HIP_LNPI_F 1.144729886F +#define HIP_TWO_TO_M126_F 1.175494351e-38F +#define HIP_TWO_TO_126_F 8.507059173e37F +#define HIP_NORM_HUGE_F 3.402823466e38F +#define HIP_TWO_TO_23_F 8388608.0F +#define HIP_TWO_TO_24_F 16777216.0F +#define HIP_TWO_TO_31_F 2147483648.0F +#define HIP_TWO_TO_32_F 4294967296.0F +#define HIP_REMQUO_BITS_F 3U +#define HIP_REMQUO_MASK_F (~((~0U) << HIP_REMQUO_BITS_F)) +#define HIP_TRIG_PLOSS_F 105615.0F + +// double precision constants +#define HIP_INF __longlong_as_double(0x7ff0000000000000ULL) +#define HIP_NAN __longlong_as_double(0xfff8000000000000ULL) +#define HIP_NEG_ZERO __longlong_as_double(0x8000000000000000ULL) +#define HIP_MIN_DENORM __longlong_as_double(0x0000000000000001ULL) +#define HIP_ZERO 0.0 +#define HIP_ONE 1.0 +#define HIP_SQRT_TWO 1.4142135623730951e+0 +#define HIP_SQRT_HALF 7.0710678118654757e-1 +#define HIP_SQRT_HALF_HI 7.0710678118654757e-1 +#define HIP_SQRT_HALF_LO (-4.8336466567264567e-17) +#define HIP_THIRD 3.3333333333333333e-1 +#define HIP_TWOTHIRD 6.6666666666666667e-1 +#define HIP_PIO4 7.8539816339744828e-1 +#define HIP_PIO4_HI 7.8539816339744828e-1 +#define HIP_PIO4_LO 3.0616169978683830e-17 +#define HIP_PIO2 1.5707963267948966e+0 +#define HIP_PIO2_HI 1.5707963267948966e+0 +#define HIP_PIO2_LO 6.1232339957367660e-17 +#define HIP_3PIO4 2.3561944901923448e+0 +#define HIP_2_OVER_PI 6.3661977236758138e-1 +#define HIP_PI 3.1415926535897931e+0 +#define HIP_PI_HI 3.1415926535897931e+0 +#define HIP_PI_LO 1.2246467991473532e-16 +#define HIP_SQRT_2PI 2.5066282746310007e+0 +#define HIP_SQRT_2PI_HI 2.5066282746310007e+0 +#define HIP_SQRT_2PI_LO (-1.8328579980459167e-16) +#define HIP_SQRT_PIO2 1.2533141373155003e+0 +#define HIP_SQRT_PIO2_HI 1.2533141373155003e+0 +#define HIP_SQRT_PIO2_LO (-9.1642899902295834e-17) +#define HIP_SQRT_2OPI 7.9788456080286536e-1 +#define HIP_L2E 1.4426950408889634e+0 +#define HIP_L2E_HI 1.4426950408889634e+0 +#define HIP_L2E_LO 2.0355273740931033e-17 +#define HIP_L2T 3.3219280948873622e+0 +#define HIP_LG2 3.0102999566398120e-1 +#define HIP_LG2_HI 3.0102999566398120e-1 +#define HIP_LG2_LO (-2.8037281277851704e-18) +#define HIP_LGE 4.3429448190325182e-1 +#define HIP_LGE_HI 4.3429448190325182e-1 +#define HIP_LGE_LO 1.09831965021676510e-17 +#define HIP_LN2 6.9314718055994529e-1 +#define HIP_LN2_HI 6.9314718055994529e-1 +#define HIP_LN2_LO 2.3190468138462996e-17 +#define HIP_LNT 2.3025850929940459e+0 +#define HIP_LNT_HI 2.3025850929940459e+0 +#define HIP_LNT_LO (-2.1707562233822494e-16) +#define HIP_LNPI 1.1447298858494002e+0 +#define HIP_LN2_X_1024 7.0978271289338397e+2 +#define HIP_LN2_X_1025 7.1047586007394398e+2 +#define HIP_LN2_X_1075 7.4513321910194122e+2 +#define HIP_LG2_X_1024 3.0825471555991675e+2 +#define HIP_LG2_X_1075 3.2360724533877976e+2 +#define HIP_TWO_TO_23 8388608.0 +#define HIP_TWO_TO_52 4503599627370496.0 +#define HIP_TWO_TO_53 9007199254740992.0 +#define HIP_TWO_TO_54 18014398509481984.0 +#define HIP_TWO_TO_M54 5.5511151231257827e-17 +#define HIP_TWO_TO_M1022 2.22507385850720140e-308 +#define HIP_TRIG_PLOSS 2147483648.0 +#define HIP_DBL2INT_CVT 6755399441055744.0 + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_runtime.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_runtime.h new file mode 100644 index 000000000..8fde7eaad --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_runtime.h @@ -0,0 +1,444 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +/** + * @file amd_detail/hip_runtime.h + * @brief Contains definitions of APIs for HIP runtime. + */ + +// #pragma once +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_RUNTIME_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_RUNTIME_H + +#include + +#if !defined(__HIPCC_RTC__) +#ifdef __cplusplus +#include +#else +#include +#endif // __cplusplus +#endif // !defined(__HIPCC_RTC__) + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Query the installed library build name. + * + * This function can be used even when the library is not initialized. + * + * @returns Returns a string describing the build version of the library. The + * string is owned by the library. + */ +const char *amd_dbgapi_get_build_name(); + +/** + * @brief Query the installed library git hash. + * + * This function can be used even when the library is not initialized. + * + * @returns Returns git hash of the library. + */ +const char *amd_dbgapi_get_git_hash(); + +/** + * @brief Query the installed library build ID. + * + * This function can be used even when the library is not initialized. + * + * @returns Returns build ID of the library. + */ +size_t amd_dbgapi_get_build_id(); + +#ifdef __cplusplus +} /* extern "c" */ +#endif + +//--- +// Top part of file can be compiled with any compiler + +#if !defined(__HIPCC_RTC__) +#ifdef __cplusplus +#include +#include +#include +#else +#include +#include +#endif // __cplusplus +#else +#if !__HIP_NO_STD_DEFS__ +typedef unsigned int uint32_t; +typedef unsigned long long uint64_t; +typedef signed int int32_t; +typedef signed long long int64_t; +namespace std { +using ::int32_t; +using ::int64_t; +using ::uint32_t; +using ::uint64_t; +} // namespace std +#endif // __HIP_NO_STD_DEFS__ +#endif // !defined(__HIPCC_RTC__) + +#if __HIP_CLANG_ONLY__ + +#if !defined(__align__) +#define __align__(x) __attribute__((aligned(x))) +#endif + +#define CUDA_SUCCESS hipSuccess + +#if !defined(__HIPCC_RTC__) +#include +#include +#include +#include +#include +#include +extern int HIP_TRACE_API; +#endif // !defined(__HIPCC_RTC__) + +#ifdef __cplusplus +#include +#endif + +#include + +// TODO-HCC remove old definitions ; ~1602 hcc supports __HCC_ACCELERATOR__ +// define. +#if defined(__KALMAR_ACCELERATOR__) && !defined(__HCC_ACCELERATOR__) +#define __HCC_ACCELERATOR__ __KALMAR_ACCELERATOR__ +#endif + +// Feature tests: +#if (defined(__HCC_ACCELERATOR__) && (__HCC_ACCELERATOR__ != 0)) || \ + __HIP_DEVICE_COMPILE__ +// Device compile and not host compile: + +// 32-bit Atomics: +#define __HIP_ARCH_HAS_GLOBAL_INT32_ATOMICS__ (1) +#define __HIP_ARCH_HAS_GLOBAL_FLOAT_ATOMIC_EXCH__ (1) +#define __HIP_ARCH_HAS_SHARED_INT32_ATOMICS__ (1) +#define __HIP_ARCH_HAS_SHARED_FLOAT_ATOMIC_EXCH__ (1) +#define __HIP_ARCH_HAS_FLOAT_ATOMIC_ADD__ (1) + +// 64-bit Atomics: +#define __HIP_ARCH_HAS_GLOBAL_INT64_ATOMICS__ (1) +#define __HIP_ARCH_HAS_SHARED_INT64_ATOMICS__ (1) + +// Doubles +#define __HIP_ARCH_HAS_DOUBLES__ (1) + +// warp cross-lane operations: +#define __HIP_ARCH_HAS_WARP_VOTE__ (1) +#define __HIP_ARCH_HAS_WARP_BALLOT__ (1) +#define __HIP_ARCH_HAS_WARP_SHUFFLE__ (1) +#define __HIP_ARCH_HAS_WARP_FUNNEL_SHIFT__ (0) + +// sync +#define __HIP_ARCH_HAS_THREAD_FENCE_SYSTEM__ (1) +#define __HIP_ARCH_HAS_SYNC_THREAD_EXT__ (0) + +// misc +#define __HIP_ARCH_HAS_SURFACE_FUNCS__ (0) +#define __HIP_ARCH_HAS_3DGRID__ (1) +#define __HIP_ARCH_HAS_DYNAMIC_PARALLEL__ (0) + +#endif /* Device feature flags */ + +#define launch_bounds_impl0(requiredMaxThreadsPerBlock) \ + __attribute__((amdgpu_flat_work_group_size(1, requiredMaxThreadsPerBlock))) +#define launch_bounds_impl1(requiredMaxThreadsPerBlock, \ + minBlocksPerMultiprocessor) \ + __attribute__((amdgpu_flat_work_group_size(1, requiredMaxThreadsPerBlock), \ + amdgpu_waves_per_eu(minBlocksPerMultiprocessor))) +#define select_impl_(_1, _2, impl_, ...) impl_ +#define __launch_bounds__(...) \ + select_impl_(__VA_ARGS__, launch_bounds_impl1, \ + launch_bounds_impl0, )(__VA_ARGS__) + +#if !defined(__HIPCC_RTC__) +__host__ inline void *__get_dynamicgroupbaseptr() { return nullptr; } +#endif // !defined(__HIPCC_RTC__) + +// End doxygen API: +/** + * @} + */ + +// +// hip-clang functions +// +#if !defined(__HIPCC_RTC__) +#define HIP_KERNEL_NAME(...) __VA_ARGS__ +#define HIP_SYMBOL(X) X + +typedef int hipLaunchParm; + +template ::type * = nullptr> +void pArgs(const std::tuple &, void *) {} + +template ::type * = nullptr> +void pArgs(const std::tuple &formals, void **_vargs) { + using T = typename std::tuple_element>::type; + + static_assert(!std::is_reference{}, + "A __global__ function cannot have a reference as one of its " + "arguments."); +#if defined(HIP_STRICT) + static_assert(std::is_trivially_copyable{}, + "Only TriviallyCopyable types can be arguments to a __global__ " + "function"); +#endif + _vargs[n] = + const_cast(reinterpret_cast(&std::get(formals))); + return pArgs(formals, _vargs); +} + +template +std::tuple validateArgsCountType(void (*kernel)(Formals...), + std::tuple(actuals)) { + static_assert(sizeof...(Formals) == sizeof...(Actuals), + "Argument Count Mismatch"); + std::tuple to_formals{std::move(actuals)}; + return to_formals; +} + +#if defined(HIP_TEMPLATE_KERNEL_LAUNCH) +template +void hipLaunchKernelGGL(F kernel, const dim3 &numBlocks, const dim3 &dimBlocks, + std::uint32_t sharedMemBytes, hipStream_t stream, + Args... args) { + constexpr size_t count = sizeof...(Args); + auto tup_ = std::tuple{args...}; + auto tup = validateArgsCountType(kernel, tup_); + void *_Args[count]; + pArgs<0>(tup, _Args); + + auto k = reinterpret_cast(kernel); + hipLaunchKernel(k, numBlocks, dimBlocks, _Args, sharedMemBytes, stream); +} +#else +#define hipLaunchKernelGGLInternal(kernelName, numBlocks, numThreads, \ + memPerBlock, streamId, ...) \ + do { \ + kernelName<<<(numBlocks), (numThreads), (memPerBlock), (streamId)>>>( \ + __VA_ARGS__); \ + } while (0) + +#define hipLaunchKernelGGL(kernelName, ...) \ + hipLaunchKernelGGLInternal((kernelName), __VA_ARGS__) +#endif + +#include +#endif // !defined(__HIPCC_RTC__) + +#if defined(__HIPCC_RTC__) +typedef struct dim3 { + uint32_t x; ///< x + uint32_t y; ///< y + uint32_t z; ///< z +#ifdef __cplusplus + constexpr __device__ dim3(uint32_t _x = 1, uint32_t _y = 1, uint32_t _z = 1) + : x(_x), y(_y), z(_z) {}; +#endif +} dim3; +#endif // !defined(__HIPCC_RTC__) + +#pragma push_macro("__DEVICE__") +#define __DEVICE__ static __device__ __forceinline__ + +extern "C" __device__ __attribute__((const)) size_t +__ockl_get_local_id(unsigned int); +__DEVICE__ unsigned int __hip_get_thread_idx_x() { + return __ockl_get_local_id(0); +} +__DEVICE__ unsigned int __hip_get_thread_idx_y() { + return __ockl_get_local_id(1); +} +__DEVICE__ unsigned int __hip_get_thread_idx_z() { + return __ockl_get_local_id(2); +} + +extern "C" __device__ __attribute__((const)) size_t +__ockl_get_group_id(unsigned int); +__DEVICE__ unsigned int __hip_get_block_idx_x() { + return __ockl_get_group_id(0); +} +__DEVICE__ unsigned int __hip_get_block_idx_y() { + return __ockl_get_group_id(1); +} +__DEVICE__ unsigned int __hip_get_block_idx_z() { + return __ockl_get_group_id(2); +} + +extern "C" __device__ __attribute__((const)) size_t +__ockl_get_local_size(unsigned int); +__DEVICE__ unsigned int __hip_get_block_dim_x() { + return __ockl_get_local_size(0); +} +__DEVICE__ unsigned int __hip_get_block_dim_y() { + return __ockl_get_local_size(1); +} +__DEVICE__ unsigned int __hip_get_block_dim_z() { + return __ockl_get_local_size(2); +} + +extern "C" __device__ __attribute__((const)) size_t +__ockl_get_num_groups(unsigned int); +__DEVICE__ unsigned int __hip_get_grid_dim_x() { + return __ockl_get_num_groups(0); +} +__DEVICE__ unsigned int __hip_get_grid_dim_y() { + return __ockl_get_num_groups(1); +} +__DEVICE__ unsigned int __hip_get_grid_dim_z() { + return __ockl_get_num_groups(2); +} + +#define __HIP_DEVICE_BUILTIN(DIMENSION, FUNCTION) \ + __declspec(property(get = __get_##DIMENSION)) unsigned int DIMENSION; \ + __DEVICE__ unsigned int __get_##DIMENSION(void) { return FUNCTION; } + +struct __hip_builtin_threadIdx_t { + __HIP_DEVICE_BUILTIN(x, __hip_get_thread_idx_x()); + __HIP_DEVICE_BUILTIN(y, __hip_get_thread_idx_y()); + __HIP_DEVICE_BUILTIN(z, __hip_get_thread_idx_z()); +#ifdef __cplusplus + __device__ operator dim3() const { return dim3(x, y, z); } +#endif +}; + +struct __hip_builtin_blockIdx_t { + __HIP_DEVICE_BUILTIN(x, __hip_get_block_idx_x()); + __HIP_DEVICE_BUILTIN(y, __hip_get_block_idx_y()); + __HIP_DEVICE_BUILTIN(z, __hip_get_block_idx_z()); +#ifdef __cplusplus + __device__ operator dim3() const { return dim3(x, y, z); } +#endif +}; + +struct __hip_builtin_blockDim_t { + __HIP_DEVICE_BUILTIN(x, __hip_get_block_dim_x()); + __HIP_DEVICE_BUILTIN(y, __hip_get_block_dim_y()); + __HIP_DEVICE_BUILTIN(z, __hip_get_block_dim_z()); +#ifdef __cplusplus + __device__ operator dim3() const { return dim3(x, y, z); } +#endif +}; + +struct __hip_builtin_gridDim_t { + __HIP_DEVICE_BUILTIN(x, __hip_get_grid_dim_x()); + __HIP_DEVICE_BUILTIN(y, __hip_get_grid_dim_y()); + __HIP_DEVICE_BUILTIN(z, __hip_get_grid_dim_z()); +#ifdef __cplusplus + __device__ operator dim3() const { return dim3(x, y, z); } +#endif +}; + +#undef __HIP_DEVICE_BUILTIN +#pragma pop_macro("__DEVICE__") + +extern const __device__ + __attribute__((weak)) __hip_builtin_threadIdx_t threadIdx; +extern const __device__ __attribute__((weak)) __hip_builtin_blockIdx_t blockIdx; +extern const __device__ __attribute__((weak)) __hip_builtin_blockDim_t blockDim; +extern const __device__ __attribute__((weak)) __hip_builtin_gridDim_t gridDim; + +#define hipThreadIdx_x threadIdx.x +#define hipThreadIdx_y threadIdx.y +#define hipThreadIdx_z threadIdx.z + +#define hipBlockIdx_x blockIdx.x +#define hipBlockIdx_y blockIdx.y +#define hipBlockIdx_z blockIdx.z + +#define hipBlockDim_x blockDim.x +#define hipBlockDim_y blockDim.y +#define hipBlockDim_z blockDim.z + +#define hipGridDim_x gridDim.x +#define hipGridDim_y gridDim.y +#define hipGridDim_z gridDim.z + +#if !defined(__HIPCC_RTC__) +#include +#endif + +#if __HIP_HCC_COMPAT_MODE__ +// Define HCC work item functions in terms of HIP builtin variables. +#pragma push_macro("__DEFINE_HCC_FUNC") +#define __DEFINE_HCC_FUNC(hc_fun, hip_var) \ + inline __device__ __attribute__(( \ + always_inline)) unsigned int hc_get_##hc_fun(unsigned int i) { \ + if (i == 0) \ + return hip_var.x; \ + else if (i == 1) \ + return hip_var.y; \ + else \ + return hip_var.z; \ + } + +__DEFINE_HCC_FUNC(workitem_id, threadIdx) +__DEFINE_HCC_FUNC(group_id, blockIdx) +__DEFINE_HCC_FUNC(group_size, blockDim) +__DEFINE_HCC_FUNC(num_groups, gridDim) +#pragma pop_macro("__DEFINE_HCC_FUNC") + +extern "C" __device__ __attribute__((const)) size_t +__ockl_get_global_id(unsigned int); +inline __device__ __attribute__((always_inline)) unsigned int +hc_get_workitem_absolute_id(int dim) { + return (unsigned int)__ockl_get_global_id(dim); +} + +#endif + +#if !__CLANG_HIP_RUNTIME_WRAPPER_INCLUDED__ +#if !defined(__HIPCC_RTC__) +// Support std::complex. +#if !_OPENMP || __HIP_ENABLE_CUDA_WRAPPER_FOR_OPENMP__ +#pragma push_macro("__CUDA__") +#define __CUDA__ +#include <__clang_cuda_complex_builtins.h> +#include <__clang_cuda_math_forward_declares.h> +// Workaround for using libc++ with HIP-Clang. +// The following headers requires clang include path before standard C++ include +// path. However libc++ include path requires to be before clang include path. +// To workaround this, we pass -isystem with the parent directory of clang +// include path instead of the clang include path itself. +#include +#include +#include +#undef __CUDA__ +#pragma pop_macro("__CUDA__") +#endif // !_OPENMP || __HIP_ENABLE_CUDA_WRAPPER_FOR_OPENMP__ +#endif // !defined(__HIPCC_RTC__) +#endif // !__CLANG_HIP_RUNTIME_WRAPPER_INCLUDED__ +#endif // __HIP_CLANG_ONLY__ + +#endif // HIP_AMD_DETAIL_RUNTIME_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_runtime_pt_api.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_runtime_pt_api.h new file mode 100644 index 000000000..a58670016 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_runtime_pt_api.h @@ -0,0 +1,222 @@ +/* +Copyright (c) 2022 - Present Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#ifndef HIP_INCLUDE_HIP_HIP_RUNTIME_PT_API_H +#define HIP_INCLUDE_HIP_HIP_RUNTIME_PT_API_H + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) + +/// hipStreamPerThread implementation +#if defined(HIP_API_PER_THREAD_DEFAULT_STREAM) +#define __HIP_STREAM_PER_THREAD +#define __HIP_API_SPT(api) api##_spt +#else +#define __HIP_API_SPT(api) api +#endif + +#if defined(__HIP_STREAM_PER_THREAD) +// Memory APIs +#define hipMemcpy __HIP_API_SPT(hipMemcpy) +#define hipMemcpyToSymbol __HIP_API_SPT(hipMemcpyToSymbol) +#define hipMemcpyFromSymbol __HIP_API_SPT(hipMemcpyFromSymbol) +#define hipMemcpy2D __HIP_API_SPT(hipMemcpy2D) +#define hipMemcpy2DFromArray __HIP_API_SPT(hipMemcpy2DFromArray) +#define hipMemcpy3D __HIP_API_SPT(hipMemcpy3D) +#define hipMemset __HIP_API_SPT(hipMemset) +#define hipMemset2D __HIP_API_SPT(hipMemset2D) +#define hipMemset3D __HIP_API_SPT(hipMemset3D) +#define hipMemcpyAsync __HIP_API_SPT(hipMemcpyAsync) +#define hipMemset3DAsync __HIP_API_SPT(hipMemset3DAsync) +#define hipMemset2DAsync __HIP_API_SPT(hipMemset2DAsync) +#define hipMemsetAsync __HIP_API_SPT(hipMemsetAsync) +#define hipMemcpy3DAsync __HIP_API_SPT(hipMemcpy3DAsync) +#define hipMemcpy2DAsync __HIP_API_SPT(hipMemcpy2DAsync) +#define hipMemcpyFromSymbolAsync __HIP_API_SPT(hipMemcpyFromSymbolAsync) +#define hipMemcpyToSymbolAsync __HIP_API_SPT(hipMemcpyToSymbolAsync) +#define hipMemcpyFromArray __HIP_API_SPT(hipMemcpyFromArray) +#define hipMemcpy2DToArray __HIP_API_SPT(hipMemcpy2DToArray) +#define hipMemcpy2DFromArrayAsync __HIP_API_SPT(hipMemcpy2DFromArrayAsync) +#define hipMemcpy2DToArrayAsync __HIP_API_SPT(hipMemcpy2DToArrayAsync) + +// Stream APIs +#define hipStreamSynchronize __HIP_API_SPT(hipStreamSynchronize) +#define hipStreamQuery __HIP_API_SPT(hipStreamQuery) +#define hipStreamGetFlags __HIP_API_SPT(hipStreamGetFlags) +#define hipStreamGetPriority __HIP_API_SPT(hipStreamGetPriority) +#define hipStreamWaitEvent __HIP_API_SPT(hipStreamWaitEvent) +#define hipStreamAddCallback __HIP_API_SPT(hipStreamAddCallback) +#define hipLaunchHostFunc __HIP_API_SPT(hipLaunchHostFunc) + +// Event APIs +#define hipEventRecord __HIP_API_SPT(hipEventRecord) + +// Launch APIs +#define hipLaunchKernel __HIP_API_SPT(hipLaunchKernel) +#define hipLaunchCooperativeKernel __HIP_API_SPT(hipLaunchCooperativeKernel) + +// Graph APIs +#define hipGraphLaunch __HIP_API_SPT(hipGraphLaunch) +#define hipStreamBeginCapture __HIP_API_SPT(hipStreamBeginCapture) +#define hipStreamEndCapture __HIP_API_SPT(hipStreamEndCapture) +#define hipStreamIsCapturing __HIP_API_SPT(hipStreamIsCapturing) +#define hipStreamGetCaptureInfo __HIP_API_SPT(hipStreamGetCaptureInfo) +#define hipStreamGetCaptureInfo_v2 __HIP_API_SPT(hipStreamGetCaptureInfo_v2) +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +hipError_t hipMemcpy_spt(void *dst, const void *src, size_t sizeBytes, + hipMemcpyKind kind); + +hipError_t +hipMemcpyToSymbol_spt(const void *symbol, const void *src, size_t sizeBytes, + size_t offset __dparm(0), + hipMemcpyKind kind __dparm(hipMemcpyHostToDevice)); + +hipError_t +hipMemcpyFromSymbol_spt(void *dst, const void *symbol, size_t sizeBytes, + size_t offset __dparm(0), + hipMemcpyKind kind __dparm(hipMemcpyDeviceToHost)); + +hipError_t hipMemcpy2D_spt(void *dst, size_t dpitch, const void *src, + size_t spitch, size_t width, size_t height, + hipMemcpyKind kind); + +hipError_t hipMemcpy2DFromArray_spt(void *dst, size_t dpitch, + hipArray_const_t src, size_t wOffset, + size_t hOffset, size_t width, size_t height, + hipMemcpyKind kind); + +hipError_t hipMemcpy3D_spt(const struct hipMemcpy3DParms *p); + +hipError_t hipMemset_spt(void *dst, int value, size_t sizeBytes); + +hipError_t hipMemsetAsync_spt(void *dst, int value, size_t sizeBytes, + hipStream_t stream); + +hipError_t hipMemset2D_spt(void *dst, size_t pitch, int value, size_t width, + size_t height); + +hipError_t hipMemset2DAsync_spt(void *dst, size_t pitch, int value, + size_t width, size_t height, + hipStream_t stream); + +hipError_t hipMemset3DAsync_spt(hipPitchedPtr pitchedDevPtr, int value, + hipExtent extent, hipStream_t stream); + +hipError_t hipMemset3D_spt(hipPitchedPtr pitchedDevPtr, int value, + hipExtent extent); + +hipError_t hipMemcpyAsync_spt(void *dst, const void *src, size_t sizeBytes, + hipMemcpyKind kind, hipStream_t stream); + +hipError_t hipMemcpy3DAsync_spt(const hipMemcpy3DParms *p, hipStream_t stream); + +hipError_t hipMemcpy2DAsync_spt(void *dst, size_t dpitch, const void *src, + size_t spitch, size_t width, size_t height, + hipMemcpyKind kind, hipStream_t stream); + +hipError_t hipMemcpyFromSymbolAsync_spt(void *dst, const void *symbol, + size_t sizeBytes, size_t offset, + hipMemcpyKind kind, hipStream_t stream); + +hipError_t hipMemcpyToSymbolAsync_spt(const void *symbol, const void *src, + size_t sizeBytes, size_t offset, + hipMemcpyKind kind, hipStream_t stream); + +hipError_t hipMemcpyFromArray_spt(void *dst, hipArray_const_t src, + size_t wOffsetSrc, size_t hOffset, + size_t count, hipMemcpyKind kind); + +hipError_t hipMemcpy2DToArray_spt(hipArray_t dst, size_t wOffset, + size_t hOffset, const void *src, + size_t spitch, size_t width, size_t height, + hipMemcpyKind kind); + +hipError_t hipMemcpy2DFromArrayAsync_spt(void *dst, size_t dpitch, + hipArray_const_t src, + size_t wOffsetSrc, size_t hOffsetSrc, + size_t width, size_t height, + hipMemcpyKind kind, + hipStream_t stream); + +hipError_t hipMemcpy2DToArrayAsync_spt(hipArray_t dst, size_t wOffset, + size_t hOffset, const void *src, + size_t spitch, size_t width, + size_t height, hipMemcpyKind kind, + hipStream_t stream); + +hipError_t hipStreamQuery_spt(hipStream_t stream); + +hipError_t hipStreamSynchronize_spt(hipStream_t stream); + +hipError_t hipStreamGetPriority_spt(hipStream_t stream, int *priority); + +hipError_t hipStreamWaitEvent_spt(hipStream_t stream, hipEvent_t event, + unsigned int flags __dparm(0)); + +hipError_t hipStreamGetFlags_spt(hipStream_t stream, unsigned int *flags); + +hipError_t hipStreamAddCallback_spt(hipStream_t stream, + hipStreamCallback_t callback, + void *userData, unsigned int flags); +#ifdef __cplusplus +hipError_t hipEventRecord_spt(hipEvent_t event, hipStream_t stream = NULL); +#else +hipError_t hipEventRecord_spt(hipEvent_t event, hipStream_t stream); +#endif + +hipError_t hipLaunchCooperativeKernel_spt(const void *f, dim3 gridDim, + dim3 blockDim, void **kernelParams, + uint32_t sharedMemBytes, + hipStream_t hStream); + +hipError_t hipLaunchKernel_spt(const void *function_address, dim3 numBlocks, + dim3 dimBlocks, void **args, + size_t sharedMemBytes, hipStream_t stream); + +hipError_t hipGraphLaunch_spt(hipGraphExec_t graphExec, hipStream_t stream); +hipError_t hipStreamBeginCapture_spt(hipStream_t stream, + hipStreamCaptureMode mode); +hipError_t hipStreamEndCapture_spt(hipStream_t stream, hipGraph_t *pGraph); +hipError_t hipStreamIsCapturing_spt(hipStream_t stream, + hipStreamCaptureStatus *pCaptureStatus); +hipError_t hipStreamGetCaptureInfo_spt(hipStream_t stream, + hipStreamCaptureStatus *pCaptureStatus, + unsigned long long *pId); +hipError_t hipStreamGetCaptureInfo_v2_spt( + hipStream_t stream, hipStreamCaptureStatus *captureStatus_out, + unsigned long long *id_out, hipGraph_t *graph_out, + const hipGraphNode_t **dependencies_out, size_t *numDependencies_out); +hipError_t hipLaunchHostFunc_spt(hipStream_t stream, hipHostFn_t fn, + void *userData); + +#ifdef __cplusplus +} +#endif // extern "C" + +#endif // defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#endif // HIP_INCLUDE_HIP_HIP_RUNTIME_PT_API_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_unsafe_atomics.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_unsafe_atomics.h new file mode 100644 index 000000000..abb3495f3 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_unsafe_atomics.h @@ -0,0 +1,605 @@ +/* +Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#ifdef __cplusplus + +/** + * @brief Unsafe floating point rmw atomic add. + * + * Performs a relaxed read-modify-write floating point atomic add with + * device memory scope. Original value at \p addr is returned and + * the value of \p addr is updated to have the original value plus \p value + * + * @note This operation currently only performs different operations for + * the gfx90a target. Other devices continue to use safe atomics. + * + * It can be used to generate code that uses fast hardware floating point atomic + * operations which may handle rounding and subnormal values differently than + * non-atomic floating point operations. + * + * The operation is not always safe and can have undefined behavior unless + * following condition are met: + * + * - \p addr is at least 4 bytes aligned + * - If \p addr is a global segment address, it is in a coarse grain allocation. + * Passing in global segment addresses in fine grain allocations will result in + * undefined behavior and is not supported. + * + * @param [in,out] addr Pointer to value to be increment by \p value. + * @param [in] value Value by \p addr is to be incremented. + * @return Original value contained in \p addr. + */ +__device__ inline float unsafeAtomicAdd(float *addr, float value) { +#if defined(__gfx90a__) && __has_builtin(__builtin_amdgcn_is_shared) && \ + __has_builtin(__builtin_amdgcn_is_private) && \ + __has_builtin(__builtin_amdgcn_ds_atomic_fadd_f32) && \ + __has_builtin(__builtin_amdgcn_global_atomic_fadd_f32) + if (__builtin_amdgcn_is_shared( + (const __attribute__((address_space(0))) void *)addr)) + return __builtin_amdgcn_ds_atomic_fadd_f32(addr, value); + else if (__builtin_amdgcn_is_private( + (const __attribute__((address_space(0))) void *)addr)) { + float temp = *addr; + *addr = temp + value; + return temp; + } else + return __builtin_amdgcn_global_atomic_fadd_f32(addr, value); +#elif __has_builtin(__hip_atomic_fetch_add) + return __hip_atomic_fetch_add(addr, value, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#else + return __atomic_fetch_add(addr, value, __ATOMIC_RELAXED); +#endif +} + +/** + * @brief Unsafe floating point rmw atomic max. + * + * Performs a relaxed read-modify-write floating point atomic max with + * device memory scope. The original value at \p addr is returned and + * the value at \p addr is replaced by \p val if greater. + * + * @note This operation is currently identical to that performed by + * atomicMax and is included for completeness. + * + * @param [in,out] addr Pointer to value to be updated + * @param [in] val Value used to update the value at \p addr. + * @return Original value contained in \p addr. + */ +__device__ inline float unsafeAtomicMax(float *addr, float val) { +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + float value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && value < val) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned int *uaddr = (unsigned int *)addr; + unsigned int value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && __uint_as_float(value) < val) { + done = + __atomic_compare_exchange_n(uaddr, &value, __float_as_uint(val), false, + __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __uint_as_float(value); +#endif +} + +/** + * @brief Unsafe floating point rmw atomic min. + * + * Performs a relaxed read-modify-write floating point atomic min with + * device memory scope. The original value at \p addr is returned and + * the value at \p addr is replaced by \p val if lesser. + * + * @note This operation is currently identical to that performed by + * atomicMin and is included for completeness. + * + * @param [in,out] addr Pointer to value to be updated + * @param [in] val Value used to update the value at \p addr. + * @return Original value contained in \p addr. + */ +__device__ inline float unsafeAtomicMin(float *addr, float val) { +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + float value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && value > val) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned int *uaddr = (unsigned int *)addr; + unsigned int value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && __uint_as_float(value) > val) { + done = + __atomic_compare_exchange_n(uaddr, &value, __float_as_uint(val), false, + __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __uint_as_float(value); +#endif +} + +/** + * @brief Unsafe double precision rmw atomic add. + * + * Performs a relaxed read-modify-write double precision atomic add with + * device memory scope. Original value at \p addr is returned and + * the value of \p addr is updated to have the original value plus \p value + * + * @note This operation currently only performs different operations for + * the gfx90a target. Other devices continue to use safe atomics. + * + * It can be used to generate code that uses fast hardware floating point atomic + * operations which may handle rounding and subnormal values differently than + * non-atomic floating point operations. + * + * The operation is not always safe and can have undefined behavior unless + * following condition are met: + * + * - \p addr is at least 8 byte aligned + * - If \p addr is a global segment address, it is in a coarse grain allocation. + * Passing in global segment addresses in fine grain allocations will result in + * undefined behavior and are not supported. + * + * @param [in,out] addr Pointer to value to be updated. + * @param [in] value Value by \p addr is to be incremented. + * @return Original value contained in \p addr. + */ +__device__ inline double unsafeAtomicAdd(double *addr, double value) { +#if defined(__gfx90a__) && __has_builtin(__builtin_amdgcn_flat_atomic_fadd_f64) + return __builtin_amdgcn_flat_atomic_fadd_f64(addr, value); +#elif defined(__hip_atomic_fetch_add) + return __hip_atomic_fetch_add(addr, value, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#else + return __atomic_fetch_add(addr, value, __ATOMIC_RELAXED); +#endif +} + +/** + * @brief Unsafe double precision rmw atomic max. + * + * Performs a relaxed read-modify-write double precision atomic max with + * device memory scope. Original value at \p addr is returned and + * the value of \p addr is updated with \p val if greater. + * + * @note This operation currently only performs different operations for + * the gfx90a target. Other devices continue to use safe atomics. + * + * It can be used to generate code that uses fast hardware floating point atomic + * operations which may handle rounding and subnormal values differently than + * non-atomic floating point operations. + * + * The operation is not always safe and can have undefined behavior unless + * following condition are met: + * + * - \p addr is at least 8 byte aligned + * - If \p addr is a global segment address, it is in a coarse grain allocation. + * Passing in global segment addresses in fine grain allocations will result in + * undefined behavior and are not supported. + * + * @param [in,out] addr Pointer to value to be updated. + * @param [in] val Value used to updated the contents at \p addr + * @return Original value contained at \p addr. + */ +__device__ inline double unsafeAtomicMax(double *addr, double val) { +#if (defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fmax_f64) + return __builtin_amdgcn_flat_atomic_fmax_f64(addr, val); +#else +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + double value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && value < val) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned long long *uaddr = (unsigned long long *)addr; + unsigned long long value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && __longlong_as_double(value) < val) { + done = + __atomic_compare_exchange_n(uaddr, &value, __double_as_longlong(val), + false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __longlong_as_double(value); +#endif +#endif +} + +/** + * @brief Unsafe double precision rmw atomic min. + * + * Performs a relaxed read-modify-write double precision atomic min with + * device memory scope. Original value at \p addr is returned and + * the value of \p addr is updated with \p val if lesser. + * + * @note This operation currently only performs different operations for + * the gfx90a target. Other devices continue to use safe atomics. + * + * It can be used to generate code that uses fast hardware floating point atomic + * operations which may handle rounding and subnormal values differently than + * non-atomic floating point operations. + * + * The operation is not always safe and can have undefined behavior unless + * following condition are met: + * + * - \p addr is at least 8 byte aligned + * - If \p addr is a global segment address, it is in a coarse grain allocation. + * Passing in global segment addresses in fine grain allocations will result in + * undefined behavior and are not supported. + * + * @param [in,out] addr Pointer to value to be updated. + * @param [in] val Value used to updated the contents at \p addr + * @return Original value contained at \p addr. + */ +__device__ inline double unsafeAtomicMin(double *addr, double val) { +#if (defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fmin_f64) + return __builtin_amdgcn_flat_atomic_fmin_f64(addr, val); +#else +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + double value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && value > val) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned long long *uaddr = (unsigned long long *)addr; + unsigned long long value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && __longlong_as_double(value) > val) { + done = + __atomic_compare_exchange_n(uaddr, &value, __double_as_longlong(val), + false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __longlong_as_double(value); +#endif +#endif +} + +/** + * @brief Safe floating point rmw atomic add. + * + * Performs a relaxed read-modify-write floating point atomic add with + * device memory scope. Original value at \p addr is returned and + * the value of \p addr is updated to have the original value plus \p value + * + * @note This operation ensures that, on all targets, we produce safe atomics. + * This will be the case even when -munsafe-fp-atomics is passed into the + * compiler. + * + * @param [in,out] addr Pointer to value to be increment by \p value. + * @param [in] value Value by \p addr is to be incremented. + * @return Original value contained in \p addr. + */ +__device__ inline float safeAtomicAdd(float *addr, float value) { +#if defined(__gfx908__) || defined(__gfx941__) || \ + ((defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx942__)) && \ + !__has_builtin(__hip_atomic_fetch_add)) + // On gfx908, we can generate unsafe FP32 atomic add that does not follow all + // IEEE rules when -munsafe-fp-atomics is passed. Do a CAS loop emulation + // instead. On gfx941, we can generate unsafe FP32 atomic add that may not + // always happen atomically, so we need to force a CAS loop emulation to + // ensure safety. On gfx90a, gfx940 and gfx942 if we do not have the + // __hip_atomic_fetch_add builtin, we need to force a CAS loop here. + float old_val; +#if __has_builtin(__hip_atomic_load) + old_val = __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); +#else // !__has_builtin(__hip_atomic_load) + old_val = __uint_as_float(__atomic_load_n( + reinterpret_cast(addr), __ATOMIC_RELAXED)); +#endif // __has_builtin(__hip_atomic_load) + float expected, temp; + do { + temp = expected = old_val; +#if __has_builtin(__hip_atomic_compare_exchange_strong) + __hip_atomic_compare_exchange_strong(addr, &expected, old_val + value, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#else // !__has_builtin(__hip_atomic_compare_exchange_strong) + __atomic_compare_exchange_n(addr, &expected, old_val + value, false, + __ATOMIC_RELAXED, __ATOMIC_RELAXED); +#endif // __has_builtin(__hip_atomic_compare_exchange_strong) + old_val = expected; + } while (__float_as_uint(temp) != __float_as_uint(old_val)); + return old_val; +#elif defined(__gfx90a__) + // On gfx90a, with the __hip_atomic_fetch_add builtin, relaxed system-scope + // atomics will produce safe CAS loops, but are otherwise not different than + // agent-scope atomics. This logic is only applicable for gfx90a, and should + // not be assumed on other architectures. + return __hip_atomic_fetch_add(addr, value, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#elif __has_builtin(__hip_atomic_fetch_add) + return __hip_atomic_fetch_add(addr, value, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#else + return __atomic_fetch_add(addr, value, __ATOMIC_RELAXED); +#endif +} + +/** + * @brief Safe floating point rmw atomic max. + * + * Performs a relaxed read-modify-write floating point atomic max with + * device memory scope. The original value at \p addr is returned and + * the value at \p addr is replaced by \p val if greater. + * + * @note This operation ensures that, on all targets, we produce safe atomics. + * This will be the case even when -munsafe-fp-atomics is passed into the + * compiler. + * + * @param [in,out] addr Pointer to value to be updated + * @param [in] val Value used to update the value at \p addr. + * @return Original value contained in \p addr. + */ +__device__ inline float safeAtomicMax(float *addr, float val) { +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + float value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && value < val) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned int *uaddr = (unsigned int *)addr; + unsigned int value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && __uint_as_float(value) < val) { + done = + __atomic_compare_exchange_n(uaddr, &value, __float_as_uint(val), false, + __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __uint_as_float(value); +#endif +} + +/** + * @brief Safe floating point rmw atomic min. + * + * Performs a relaxed read-modify-write floating point atomic min with + * device memory scope. The original value at \p addr is returned and + * the value at \p addr is replaced by \p val if lesser. + * + * @note This operation ensures that, on all targets, we produce safe atomics. + * This will be the case even when -munsafe-fp-atomics is passed into the + * compiler. + * + * @param [in,out] addr Pointer to value to be updated + * @param [in] val Value used to update the value at \p addr. + * @return Original value contained in \p addr. + */ +__device__ inline float safeAtomicMin(float *addr, float val) { +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + float value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && value > val) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned int *uaddr = (unsigned int *)addr; + unsigned int value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && __uint_as_float(value) > val) { + done = + __atomic_compare_exchange_n(uaddr, &value, __float_as_uint(val), false, + __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __uint_as_float(value); +#endif +} + +/** + * @brief Safe double precision rmw atomic add. + * + * Performs a relaxed read-modify-write double precision atomic add with + * device memory scope. Original value at \p addr is returned and + * the value of \p addr is updated to have the original value plus \p value + * + * @note This operation ensures that, on all targets, we produce safe atomics. + * This will be the case even when -munsafe-fp-atomics is passed into the + * compiler. + * + * @param [in,out] addr Pointer to value to be increment by \p value. + * @param [in] value Value by \p addr is to be incremented. + * @return Original value contained in \p addr. + */ +__device__ inline double safeAtomicAdd(double *addr, double value) { +#if defined(__gfx90a__) && __has_builtin(__hip_atomic_fetch_add) + // On gfx90a, with the __hip_atomic_fetch_add builtin, relaxed system-scope + // atomics will produce safe CAS loops, but are otherwise not different than + // agent-scope atomics. This logic is only applicable for gfx90a, and should + // not be assumed on other architectures. + return __hip_atomic_fetch_add(addr, value, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); +#elif defined(__gfx90a__) + // On gfx90a, if we do not have the __hip_atomic_fetch_add builtin, we need to + // force a CAS loop here. + double old_val; +#if __has_builtin(__hip_atomic_load) + old_val = __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); +#else // !__has_builtin(__hip_atomic_load) + old_val = __longlong_as_double(__atomic_load_n( + reinterpret_cast(addr), __ATOMIC_RELAXED)); +#endif // __has_builtin(__hip_atomic_load) + double expected, temp; + do { + temp = expected = old_val; +#if __has_builtin(__hip_atomic_compare_exchange_strong) + __hip_atomic_compare_exchange_strong(addr, &expected, old_val + value, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#else // !__has_builtin(__hip_atomic_compare_exchange_strong) + __atomic_compare_exchange_n(addr, &expected, old_val + value, false, + __ATOMIC_RELAXED, __ATOMIC_RELAXED); +#endif // __has_builtin(__hip_atomic_compare_exchange_strong) + old_val = expected; + } while (__double_as_longlong(temp) != __double_as_longlong(old_val)); + return old_val; +#else // !defined(__gfx90a__) +#if __has_builtin(__hip_atomic_fetch_add) + return __hip_atomic_fetch_add(addr, value, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); +#else // !__has_builtin(__hip_atomic_fetch_add) + return __atomic_fetch_add(addr, value, __ATOMIC_RELAXED); +#endif // __has_builtin(__hip_atomic_fetch_add) +#endif +} + +/** + * @brief Safe double precision rmw atomic max. + * + * Performs a relaxed read-modify-write double precision atomic max with + * device memory scope. Original value at \p addr is returned and + * the value of \p addr is updated with \p val if greater. + * + * @note This operation ensures that, on all targets, we produce safe atomics. + * This will be the case even when -munsafe-fp-atomics is passed into the + * compiler. + * + * @param [in,out] addr Pointer to value to be updated. + * @param [in] val Value used to updated the contents at \p addr + * @return Original value contained at \p addr. + */ +__device__ inline double safeAtomicMax(double *addr, double val) { +#if __has_builtin(__builtin_amdgcn_is_private) + if (__builtin_amdgcn_is_private( + (const __attribute__((address_space(0))) void *)addr)) { + double old = *addr; + *addr = __builtin_fmax(old, val); + return old; + } else { +#endif +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + double value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && value < val) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned long long *uaddr = (unsigned long long *)addr; + unsigned long long value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && __longlong_as_double(value) < val) { + done = + __atomic_compare_exchange_n(uaddr, &value, __double_as_longlong(val), + false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __longlong_as_double(value); +#endif +#if __has_builtin(__builtin_amdgcn_is_private) + } +#endif +} + +/** + * @brief Safe double precision rmw atomic min. + * + * Performs a relaxed read-modify-write double precision atomic min with + * device memory scope. Original value at \p addr is returned and + * the value of \p addr is updated with \p val if lesser. + * + * @note This operation ensures that, on all targets, we produce safe atomics. + * This will be the case even when -munsafe-fp-atomics is passed into the + * compiler. + * + * @param [in,out] addr Pointer to value to be updated. + * @param [in] val Value used to updated the contents at \p addr + * @return Original value contained at \p addr. + */ +__device__ inline double safeAtomicMin(double *addr, double val) { +#if __has_builtin(__builtin_amdgcn_is_private) + if (__builtin_amdgcn_is_private( + (const __attribute__((address_space(0))) void *)addr)) { + double old = *addr; + *addr = __builtin_fmin(old, val); + return old; + } else { +#endif +#if __has_builtin(__hip_atomic_load) && \ + __has_builtin(__hip_atomic_compare_exchange_strong) + double value = + __hip_atomic_load(addr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + bool done = false; + while (!done && value > val) { + done = __hip_atomic_compare_exchange_strong( + addr, &value, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + return value; +#else + unsigned long long *uaddr = (unsigned long long *)addr; + unsigned long long value = __atomic_load_n(uaddr, __ATOMIC_RELAXED); + bool done = false; + while (!done && __longlong_as_double(value) > val) { + done = + __atomic_compare_exchange_n(uaddr, &value, __double_as_longlong(val), + false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); + } + return __longlong_as_double(value); +#endif +#if __has_builtin(__builtin_amdgcn_is_private) + } +#endif +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_vector_types.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_vector_types.h new file mode 100644 index 000000000..6782c105b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_hip_vector_types.h @@ -0,0 +1,1975 @@ +/* +Copyright (c) 2015 - 2022 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +/** + * @file amd_detail/hip_vector_types.h + * @brief Defines the different newt vector types for HIP runtime. + */ + +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_VECTOR_TYPES_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_VECTOR_TYPES_H + +#include "hip/amd_detail/host_defines.h" + +#if defined(__HIPCC_RTC__) +#define __HOST_DEVICE__ __device__ +#else +#define __HOST_DEVICE__ __host__ __device__ +#endif + +#if defined(__has_attribute) +#if __has_attribute(ext_vector_type) +#define __HIP_USE_NATIVE_VECTOR__ 1 +#define __NATIVE_VECTOR__(n, T) T __attribute__((ext_vector_type(n))) +#else +#define __NATIVE_VECTOR__(n, T) T[n] +#endif + +#if defined(__cplusplus) +#if !defined(__HIPCC_RTC__) +#include +#include +#include +#else +namespace std { +using ::size_t; + +template struct integral_constant { + static constexpr const _Tp value = __v; + typedef _Tp value_type; + typedef integral_constant type; + constexpr operator value_type() const { return value; } + constexpr value_type operator()() const { return value; } +}; +template +constexpr const _Tp integral_constant<_Tp, __v>::value; + +typedef integral_constant true_type; +typedef integral_constant false_type; + +template using bool_constant = integral_constant; +typedef bool_constant true_type; +typedef bool_constant false_type; + +template struct enable_if {}; +template struct enable_if { + typedef __T type; +}; + +template struct true_or_false_type : public false_type {}; +template <> struct true_or_false_type : public true_type {}; + +template struct is_integral : public false_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; + +template struct is_arithmetic : public false_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; + +template struct is_floating_point : public false_type {}; +template <> struct is_floating_point : public true_type {}; +template <> struct is_floating_point : public true_type {}; +template <> struct is_floating_point : public true_type {}; + +template struct is_same : public false_type {}; +template struct is_same<__T, __T> : public true_type {}; + +template ::value> +struct is_signed : public false_type {}; +template +struct is_signed<_Tp, true> : public true_or_false_type<_Tp(-1) < _Tp(0)> {}; + +template +struct is_convertible + : public true_or_false_type<__is_convertible_to(_T1, _T2)> {}; + +template struct char_traits; +template > +class basic_istream; +template > +class basic_ostream; +typedef basic_istream istream; +typedef basic_ostream ostream; + +template +struct is_scalar : public integral_constant {}; +} // Namespace std. +#endif // defined(__HIPCC_RTC__) + +namespace hip_impl { +inline constexpr unsigned int next_pot(unsigned int x) { + // Precondition: x > 1. + return 1u << (32u - __builtin_clz(x - 1u)); +} +} // Namespace hip_impl. + +template struct HIP_vector_base; + +template struct HIP_vector_base { + using Native_vec_ = __NATIVE_VECTOR__(1, T); + + union { + Native_vec_ data; + struct { + T x; + }; + }; + + using value_type = T; + + __HOST_DEVICE__ + HIP_vector_base() = default; + __HOST_DEVICE__ + explicit constexpr HIP_vector_base(T x_) noexcept : data{x_} {} + __HOST_DEVICE__ + constexpr HIP_vector_base(const HIP_vector_base &) = default; + __HOST_DEVICE__ + constexpr HIP_vector_base(HIP_vector_base &&) = default; + __HOST_DEVICE__ + ~HIP_vector_base() = default; + __HOST_DEVICE__ + HIP_vector_base &operator=(const HIP_vector_base &) = default; +}; + +template struct HIP_vector_base { + using Native_vec_ = __NATIVE_VECTOR__(2, T); + + union +#if !__has_attribute(ext_vector_type) + alignas(hip_impl::next_pot(2 * sizeof(T))) +#endif + { + Native_vec_ data; + struct { + T x; + T y; + }; + }; + + using value_type = T; + + __HOST_DEVICE__ + HIP_vector_base() = default; + __HOST_DEVICE__ + explicit constexpr HIP_vector_base(T x_) noexcept : data{x_, x_} {} + __HOST_DEVICE__ + constexpr HIP_vector_base(T x_, T y_) noexcept : data{x_, y_} {} + __HOST_DEVICE__ + constexpr HIP_vector_base(const HIP_vector_base &) = default; + __HOST_DEVICE__ + constexpr HIP_vector_base(HIP_vector_base &&) = default; + __HOST_DEVICE__ + ~HIP_vector_base() = default; + __HOST_DEVICE__ + HIP_vector_base &operator=(const HIP_vector_base &) = default; +}; + +template struct HIP_vector_base { + struct Native_vec_ { + T d[3]; + + __HOST_DEVICE__ + Native_vec_() = default; + + __HOST_DEVICE__ + explicit constexpr Native_vec_(T x_) noexcept : d{x_, x_, x_} {} + __HOST_DEVICE__ + constexpr Native_vec_(T x_, T y_, T z_) noexcept : d{x_, y_, z_} {} + __HOST_DEVICE__ + constexpr Native_vec_(const Native_vec_ &) = default; + __HOST_DEVICE__ + constexpr Native_vec_(Native_vec_ &&) = default; + __HOST_DEVICE__ + ~Native_vec_() = default; + + __HOST_DEVICE__ + Native_vec_ &operator=(const Native_vec_ &) = default; + __HOST_DEVICE__ + Native_vec_ &operator=(Native_vec_ &&) = default; + + __HOST_DEVICE__ + T &operator[](unsigned int idx) noexcept { return d[idx]; } + __HOST_DEVICE__ + T operator[](unsigned int idx) const noexcept { return d[idx]; } + + __HOST_DEVICE__ + Native_vec_ &operator+=(const Native_vec_ &x_) noexcept { + for (auto i = 0u; i != 3u; ++i) + d[i] += x_.d[i]; + return *this; + } + __HOST_DEVICE__ + Native_vec_ &operator-=(const Native_vec_ &x_) noexcept { + for (auto i = 0u; i != 3u; ++i) + d[i] -= x_.d[i]; + return *this; + } + + __HOST_DEVICE__ + Native_vec_ &operator*=(const Native_vec_ &x_) noexcept { + for (auto i = 0u; i != 3u; ++i) + d[i] *= x_.d[i]; + return *this; + } + __HOST_DEVICE__ + Native_vec_ &operator/=(const Native_vec_ &x_) noexcept { + for (auto i = 0u; i != 3u; ++i) + d[i] /= x_.d[i]; + return *this; + } + + template {}>::type * = nullptr> + __HOST_DEVICE__ Native_vec_ operator-() const noexcept { + auto r{*this}; + for (auto &&x : r.d) + x = -x; + return r; + } + + template {}>::type * = nullptr> + __HOST_DEVICE__ Native_vec_ operator~() const noexcept { + auto r{*this}; + for (auto &&x : r.d) + x = ~x; + return r; + } + template {}>::type * = nullptr> + __HOST_DEVICE__ Native_vec_ &operator%=(const Native_vec_ &x_) noexcept { + for (auto i = 0u; i != 3u; ++i) + d[i] %= x_.d[i]; + return *this; + } + template {}>::type * = nullptr> + __HOST_DEVICE__ Native_vec_ &operator^=(const Native_vec_ &x_) noexcept { + for (auto i = 0u; i != 3u; ++i) + d[i] ^= x_.d[i]; + return *this; + } + template {}>::type * = nullptr> + __HOST_DEVICE__ Native_vec_ &operator|=(const Native_vec_ &x_) noexcept { + for (auto i = 0u; i != 3u; ++i) + d[i] |= x_.d[i]; + return *this; + } + template {}>::type * = nullptr> + __HOST_DEVICE__ Native_vec_ &operator&=(const Native_vec_ &x_) noexcept { + for (auto i = 0u; i != 3u; ++i) + d[i] &= x_.d[i]; + return *this; + } + template {}>::type * = nullptr> + __HOST_DEVICE__ Native_vec_ &operator>>=(const Native_vec_ &x_) noexcept { + for (auto i = 0u; i != 3u; ++i) + d[i] >>= x_.d[i]; + return *this; + } + template {}>::type * = nullptr> + __HOST_DEVICE__ Native_vec_ &operator<<=(const Native_vec_ &x_) noexcept { + for (auto i = 0u; i != 3u; ++i) + d[i] <<= x_.d[i]; + return *this; + } +#if defined(__INTEL_COMPILER) + typedef struct { + int values[4]; + } _Vec3_cmp; + using Vec3_cmp = _Vec3_cmp; +#else + using Vec3_cmp = int __attribute__((vector_size(4 * sizeof(int)))); +#endif // INTEL + __HOST_DEVICE__ + Vec3_cmp operator==(const Native_vec_ &x_) const noexcept { + return Vec3_cmp{d[0] == x_.d[0], d[1] == x_.d[1], d[2] == x_.d[2]}; + } + }; + + union { + Native_vec_ data; + struct { + T x; + T y; + T z; + }; + }; + + using value_type = T; + + __HOST_DEVICE__ + HIP_vector_base() = default; + __HOST_DEVICE__ + explicit constexpr HIP_vector_base(T x_) noexcept : data{x_, x_, x_} {} + __HOST_DEVICE__ + constexpr HIP_vector_base(T x_, T y_, T z_) noexcept : data{x_, y_, z_} {} + __HOST_DEVICE__ + constexpr HIP_vector_base(const HIP_vector_base &) = default; + __HOST_DEVICE__ + constexpr HIP_vector_base(HIP_vector_base &&) = default; + __HOST_DEVICE__ + ~HIP_vector_base() = default; + + __HOST_DEVICE__ + HIP_vector_base &operator=(const HIP_vector_base &) = default; + __HOST_DEVICE__ + HIP_vector_base &operator=(HIP_vector_base &&) = default; +}; + +template struct HIP_vector_base { + using Native_vec_ = __NATIVE_VECTOR__(4, T); + + union +#if !__has_attribute(ext_vector_type) + alignas(hip_impl::next_pot(4 * sizeof(T))) +#endif + { + Native_vec_ data; + struct { + T x; + T y; + T z; + T w; + }; + }; + + using value_type = T; + + __HOST_DEVICE__ + HIP_vector_base() = default; + __HOST_DEVICE__ + explicit constexpr HIP_vector_base(T x_) noexcept : data{x_, x_, x_, x_} {} + __HOST_DEVICE__ + constexpr HIP_vector_base(T x_, T y_, T z_, T w_) noexcept + : data{x_, y_, z_, w_} {} + __HOST_DEVICE__ + constexpr HIP_vector_base(const HIP_vector_base &) = default; + __HOST_DEVICE__ + constexpr HIP_vector_base(HIP_vector_base &&) = default; + __HOST_DEVICE__ + ~HIP_vector_base() = default; + __HOST_DEVICE__ + HIP_vector_base &operator=(const HIP_vector_base &) = default; +}; + +template +struct HIP_vector_type : public HIP_vector_base { + using HIP_vector_base::data; + using typename HIP_vector_base::Native_vec_; + + __HOST_DEVICE__ + HIP_vector_type() = default; + template ::value>::type * = nullptr> + __HOST_DEVICE__ explicit constexpr HIP_vector_type(U x_) noexcept + : HIP_vector_base{static_cast(x_)} {} + template < // TODO: constrain based on type as well. + typename... Us, + typename std::enable_if<(rank > 1) && sizeof...(Us) == rank>::type * = + nullptr> + __HOST_DEVICE__ constexpr HIP_vector_type(Us... xs) noexcept + : HIP_vector_base{static_cast(xs)...} {} + __HOST_DEVICE__ + constexpr HIP_vector_type(const HIP_vector_type &) = default; + __HOST_DEVICE__ + constexpr HIP_vector_type(HIP_vector_type &&) = default; + __HOST_DEVICE__ + ~HIP_vector_type() = default; + + __HOST_DEVICE__ + HIP_vector_type &operator=(const HIP_vector_type &) = default; + __HOST_DEVICE__ + HIP_vector_type &operator=(HIP_vector_type &&) = default; + + // Operators + __HOST_DEVICE__ + HIP_vector_type &operator++() noexcept { return *this += HIP_vector_type{1}; } + __HOST_DEVICE__ + HIP_vector_type operator++(int) noexcept { + auto tmp(*this); + ++*this; + return tmp; + } + + __HOST_DEVICE__ + HIP_vector_type &operator--() noexcept { return *this -= HIP_vector_type{1}; } + __HOST_DEVICE__ + HIP_vector_type operator--(int) noexcept { + auto tmp(*this); + --*this; + return tmp; + } + + __HOST_DEVICE__ + HIP_vector_type &operator+=(const HIP_vector_type &x) noexcept { +#if __HIP_USE_NATIVE_VECTOR__ + data += x.data; +#else + for (auto i = 0u; i != rank; ++i) + data[i] += x.data[i]; +#endif + return *this; + } + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type &operator+=(U x) noexcept { + return *this += HIP_vector_type{x}; + } + + __HOST_DEVICE__ + HIP_vector_type &operator-=(const HIP_vector_type &x) noexcept { +#if __HIP_USE_NATIVE_VECTOR__ + data -= x.data; +#else + for (auto i = 0u; i != rank; ++i) + data[i] -= x.data[i]; +#endif + return *this; + } + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type &operator-=(U x) noexcept { + return *this -= HIP_vector_type{x}; + } + + __HOST_DEVICE__ + HIP_vector_type &operator*=(const HIP_vector_type &x) noexcept { +#if __HIP_USE_NATIVE_VECTOR__ + data *= x.data; +#else + for (auto i = 0u; i != rank; ++i) + data[i] *= x.data[i]; +#endif + return *this; + } + + friend __HOST_DEVICE__ inline constexpr HIP_vector_type + operator*(HIP_vector_type x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} *= y; + } + + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type &operator*=(U x) noexcept { + return *this *= HIP_vector_type{x}; + } + + friend __HOST_DEVICE__ inline constexpr HIP_vector_type + operator/(HIP_vector_type x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} /= y; + } + + __HOST_DEVICE__ + HIP_vector_type &operator/=(const HIP_vector_type &x) noexcept { +#if __HIP_USE_NATIVE_VECTOR__ + data /= x.data; +#else + for (auto i = 0u; i != rank; ++i) + data[i] /= x.data[i]; +#endif + return *this; + } + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type &operator/=(U x) noexcept { + return *this /= HIP_vector_type{x}; + } + + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type operator-() const noexcept { + auto tmp(*this); +#if __HIP_USE_NATIVE_VECTOR__ + tmp.data = -tmp.data; +#else + for (auto i = 0u; i != rank; ++i) + tmp.data[i] = -tmp.data[i]; +#endif + return tmp; + } + + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type operator~() const noexcept { + HIP_vector_type r{*this}; +#if __HIP_USE_NATIVE_VECTOR__ + r.data = ~r.data; +#else + for (auto i = 0u; i != rank; ++i) + r.data[i] = ~r.data[i]; +#endif + return r; + } + + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type & + operator%=(const HIP_vector_type &x) noexcept { +#if __HIP_USE_NATIVE_VECTOR__ + data %= x.data; +#else + for (auto i = 0u; i != rank; ++i) + data[i] %= x.data[i]; +#endif + return *this; + } + + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type & + operator^=(const HIP_vector_type &x) noexcept { +#if __HIP_USE_NATIVE_VECTOR__ + data ^= x.data; +#else + for (auto i = 0u; i != rank; ++i) + data[i] ^= x.data[i]; +#endif + return *this; + } + + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type & + operator|=(const HIP_vector_type &x) noexcept { +#if __HIP_USE_NATIVE_VECTOR__ + data |= x.data; +#else + for (auto i = 0u; i != rank; ++i) + data[i] |= x.data[i]; +#endif + return *this; + } + + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type & + operator&=(const HIP_vector_type &x) noexcept { +#if __HIP_USE_NATIVE_VECTOR__ + data &= x.data; +#else + for (auto i = 0u; i != rank; ++i) + data[i] &= x.data[i]; +#endif + return *this; + } + + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type & + operator>>=(const HIP_vector_type &x) noexcept { +#if __HIP_USE_NATIVE_VECTOR__ + data >>= x.data; +#else + for (auto i = 0u; i != rank; ++i) + data[i] >>= x.data[i]; +#endif + return *this; + } + + template {}>::type * = nullptr> + __HOST_DEVICE__ HIP_vector_type & + operator<<=(const HIP_vector_type &x) noexcept { +#if __HIP_USE_NATIVE_VECTOR__ + data <<= x.data; +#else + for (auto i = 0u; i != rank; ++i) + data[i] <<= x.data[i]; +#endif + return *this; + } +}; + +template +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator+(const HIP_vector_type &x, + const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} += y; +} +template +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator+(const HIP_vector_type &x, U y) noexcept { + return HIP_vector_type{x} += HIP_vector_type{y}; +} +template +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator+(U x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} += y; +} + +template +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator-(const HIP_vector_type &x, + const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} -= y; +} +template +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator-(const HIP_vector_type &x, U y) noexcept { + return HIP_vector_type{x} -= HIP_vector_type{y}; +} +template +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator-(U x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} -= y; +} + +template +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator*(const HIP_vector_type &x, U y) noexcept { + return HIP_vector_type{x} *= HIP_vector_type{y}; +} +template +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator*(U x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} *= y; +} + +template +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator/(const HIP_vector_type &x, U y) noexcept { + return HIP_vector_type{x} /= HIP_vector_type{y}; +} +template +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator/(U x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} /= y; +} + +template +__HOST_DEVICE__ inline constexpr bool _hip_compare(const V &x, const V &y, + int n) noexcept { + return (n == -1) ? true + : ((x[n] != y[n]) ? false : _hip_compare(x, y, n - 1)); +} + +template +__HOST_DEVICE__ inline constexpr bool +operator==(const HIP_vector_type &x, + const HIP_vector_type &y) noexcept { + return _hip_compare(x.data, y.data, n - 1); +} +template +__HOST_DEVICE__ inline constexpr bool operator==(const HIP_vector_type &x, + U y) noexcept { + return x == HIP_vector_type{y}; +} +template +__HOST_DEVICE__ inline constexpr bool +operator==(U x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} == y; +} + +template +__HOST_DEVICE__ inline constexpr bool +operator!=(const HIP_vector_type &x, + const HIP_vector_type &y) noexcept { + return !(x == y); +} +template +__HOST_DEVICE__ inline constexpr bool operator!=(const HIP_vector_type &x, + U y) noexcept { + return !(x == y); +} +template +__HOST_DEVICE__ inline constexpr bool +operator!=(U x, const HIP_vector_type &y) noexcept { + return !(x == y); +} + +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator%(const HIP_vector_type &x, + const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} %= y; +} +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator%(const HIP_vector_type &x, U y) noexcept { + return HIP_vector_type{x} %= HIP_vector_type{y}; +} +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator%(U x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} %= y; +} + +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator^(const HIP_vector_type &x, + const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} ^= y; +} +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator^(const HIP_vector_type &x, U y) noexcept { + return HIP_vector_type{x} ^= HIP_vector_type{y}; +} +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator^(U x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} ^= y; +} + +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator|(const HIP_vector_type &x, + const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} |= y; +} +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator|(const HIP_vector_type &x, U y) noexcept { + return HIP_vector_type{x} |= HIP_vector_type{y}; +} +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator|(U x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} |= y; +} + +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator&(const HIP_vector_type &x, + const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} &= y; +} +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator&(const HIP_vector_type &x, U y) noexcept { + return HIP_vector_type{x} &= HIP_vector_type{y}; +} +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator&(U x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} &= y; +} + +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator>>(const HIP_vector_type &x, + const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} >>= y; +} +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator>>(const HIP_vector_type &x, U y) noexcept { + return HIP_vector_type{x} >>= HIP_vector_type{y}; +} +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator>>(U x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} >>= y; +} + +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator<<(const HIP_vector_type &x, + const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} <<= y; +} +template {}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator<<(const HIP_vector_type &x, U y) noexcept { + return HIP_vector_type{x} <<= HIP_vector_type{y}; +} +template ::value>::type, + typename std::enable_if{}> * = nullptr> +__HOST_DEVICE__ inline constexpr HIP_vector_type +operator<<(U x, const HIP_vector_type &y) noexcept { + return HIP_vector_type{x} <<= y; +} + +/* + * Map HIP_vector_type to HIP_vector_type + */ +template +__forceinline__ __HOST_DEVICE__ + typename std::enable_if<(rankT == 1 && rankU >= 1), + const HIP_vector_type>::type + __hipMapVector(const HIP_vector_type &u) { + return HIP_vector_type(static_cast(u.x)); +}; + +template +__forceinline__ __HOST_DEVICE__ + typename std::enable_if<(rankT == 2 && rankU == 1), + const HIP_vector_type>::type + __hipMapVector(const HIP_vector_type &u) { + return HIP_vector_type(static_cast(u.x), static_cast(0)); +}; + +template +__forceinline__ __HOST_DEVICE__ + typename std::enable_if<(rankT == 2 && rankU >= 2), + const HIP_vector_type>::type + __hipMapVector(const HIP_vector_type &u) { + return HIP_vector_type(static_cast(u.x), static_cast(u.y)); +}; + +template +__forceinline__ __HOST_DEVICE__ + typename std::enable_if<(rankT == 4 && rankU == 1), + const HIP_vector_type>::type + __hipMapVector(const HIP_vector_type &u) { + return HIP_vector_type(static_cast(u.x), static_cast(0), + static_cast(0), static_cast(0)); +}; + +template +__forceinline__ __HOST_DEVICE__ + typename std::enable_if<(rankT == 4 && rankU == 2), + const HIP_vector_type>::type + __hipMapVector(const HIP_vector_type &u) { + return HIP_vector_type(static_cast(u.x), static_cast(u.y), + static_cast(0), static_cast(0)); +}; + +template +__forceinline__ __HOST_DEVICE__ + typename std::enable_if<(rankT == 4 && rankU == 4), + const HIP_vector_type>::type + __hipMapVector(const HIP_vector_type &u) { + return HIP_vector_type(static_cast(u.x), static_cast(u.y), + static_cast(u.z), static_cast(u.w)); +}; + +#define __MAKE_VECTOR_TYPE__(CUDA_name, T) \ + using CUDA_name##1 = HIP_vector_type; \ + using CUDA_name##2 = HIP_vector_type; \ + using CUDA_name##3 = HIP_vector_type; \ + using CUDA_name##4 = HIP_vector_type; +#else +#define __MAKE_VECTOR_TYPE__(CUDA_name, T) \ + typedef struct { \ + T x; \ + } CUDA_name##1; \ + typedef struct { \ + T x; \ + T y; \ + } CUDA_name##2; \ + typedef struct { \ + T x; \ + T y; \ + T z; \ + } CUDA_name##3; \ + typedef struct { \ + T x; \ + T y; \ + T z; \ + T w; \ + } CUDA_name##4; +#endif + +__MAKE_VECTOR_TYPE__(uchar, unsigned char); +__MAKE_VECTOR_TYPE__(char, char); +__MAKE_VECTOR_TYPE__(ushort, unsigned short); +__MAKE_VECTOR_TYPE__(short, short); +__MAKE_VECTOR_TYPE__(uint, unsigned int); +__MAKE_VECTOR_TYPE__(int, int); +__MAKE_VECTOR_TYPE__(ulong, unsigned long); +__MAKE_VECTOR_TYPE__(long, long); +__MAKE_VECTOR_TYPE__(ulonglong, unsigned long long); +__MAKE_VECTOR_TYPE__(longlong, long long); +__MAKE_VECTOR_TYPE__(float, float); +__MAKE_VECTOR_TYPE__(double, double); + +#else // !defined(__has_attribute) + +#if defined(_MSC_VER) +#include +#include +#include +#include + +/* +this is for compatibility with CUDA as CUDA allows accessing vector components +in C++ program with MSVC +*/ +typedef union { + struct { + char x; + }; + char data; +} char1; +typedef union { + struct { + char x; + char y; + }; + char data[2]; +} char2; +typedef union { + struct { + char x; + char y; + char z; + char w; + }; + char data[4]; +} char4; +typedef union { + struct { + char x; + char y; + char z; + }; + char data[3]; +} char3; +typedef union { + __m64 data; +} char8; +typedef union { + __m128i data; +} char16; + +typedef union { + struct { + unsigned char x; + }; + unsigned char data; +} uchar1; +typedef union { + struct { + unsigned char x; + unsigned char y; + }; + unsigned char data[2]; +} uchar2; +typedef union { + struct { + unsigned char x; + unsigned char y; + unsigned char z; + unsigned char w; + }; + unsigned char data[4]; +} uchar4; +typedef union { + struct { + unsigned char x; + unsigned char y; + unsigned char z; + }; + unsigned char data[3]; +} uchar3; +typedef union { + __m64 data; +} uchar8; +typedef union { + __m128i data; +} uchar16; + +typedef union { + struct { + short x; + }; + short data; +} short1; +typedef union { + struct { + short x; + short y; + }; + short data[2]; +} short2; +typedef union { + struct { + short x; + short y; + short z; + short w; + }; + __m64 data; +} short4; +typedef union { + struct { + short x; + short y; + short z; + }; + short data[3]; +} short3; +typedef union { + __m128i data; +} short8; +typedef union { + __m128i data[2]; +} short16; + +typedef union { + struct { + unsigned short x; + }; + unsigned short data; +} ushort1; +typedef union { + struct { + unsigned short x; + unsigned short y; + }; + unsigned short data[2]; +} ushort2; +typedef union { + struct { + unsigned short x; + unsigned short y; + unsigned short z; + unsigned short w; + }; + __m64 data; +} ushort4; +typedef union { + struct { + unsigned short x; + unsigned short y; + unsigned short z; + }; + unsigned short data[3]; +} ushort3; +typedef union { + __m128i data; +} ushort8; +typedef union { + __m128i data[2]; +} ushort16; + +typedef union { + struct { + int x; + }; + int data; +} int1; +typedef union { + struct { + int x; + int y; + }; + __m64 data; +} int2; +typedef union { + struct { + int x; + int y; + int z; + int w; + }; + __m128i data; +} int4; +typedef union { + struct { + int x; + int y; + int z; + }; + int data[3]; +} int3; +typedef union { + __m128i data[2]; +} int8; +typedef union { + __m128i data[4]; +} int16; + +typedef union { + struct { + unsigned int x; + }; + unsigned int data; +} uint1; +typedef union { + struct { + unsigned int x; + unsigned int y; + }; + __m64 data; +} uint2; +typedef union { + struct { + unsigned int x; + unsigned int y; + unsigned int z; + unsigned int w; + }; + __m128i data; +} uint4; +typedef union { + struct { + unsigned int x; + unsigned int y; + unsigned int z; + }; + unsigned int data[3]; +} uint3; +typedef union { + __m128i data[2]; +} uint8; +typedef union { + __m128i data[4]; +} uint16; + +typedef union { + struct { + int x; + }; + int data; +} long1; +typedef union { + struct { + int x; + int y; + }; + __m64 data; +} long2; +typedef union { + struct { + int x; + int y; + int z; + int w; + }; + __m128i data; +} long4; +typedef union { + struct { + int x; + int y; + int z; + }; + int data[3]; +} long3; +typedef union { + __m128i data[2]; +} long8; +typedef union { + __m128i data[4]; +} long16; + +typedef union { + struct { + unsigned int x; + }; + unsigned int data; +} ulong1; +typedef union { + struct { + unsigned int x; + unsigned int y; + }; + __m64 data; +} ulong2; +typedef union { + struct { + unsigned int x; + unsigned int y; + unsigned int z; + unsigned int w; + }; + __m128i data; +} ulong4; +typedef union { + struct { + unsigned int x; + unsigned int y; + unsigned int z; + }; + unsigned int data[3]; +} ulong3; +typedef union { + __m128i data[2]; +} ulong8; +typedef union { + __m128i data[4]; +} ulong16; + +typedef union { + struct { + long long x; + }; + __m64 data; +} longlong1; +typedef union { + struct { + long long x; + long long y; + }; + __m128i data; +} longlong2; +typedef union { + struct { + long long x; + long long y; + long long z; + long long w; + }; + __m128i data[2]; +} longlong4; +typedef union { + struct { + long long x; + long long y; + long long z; + }; + __m64 data[3]; +} longlong3; +typedef union { + __m128i data[4]; +} longlong8; +typedef union { + __m128i data[8]; +} longlong16; + +typedef union { + struct { + __m64 x; + }; + __m64 data; +} ulonglong1; +typedef union { + struct { + __m64 x; + __m64 y; + }; + __m128i data; +} ulonglong2; +typedef union { + struct { + __m64 x; + __m64 y; + __m64 z; + __m64 w; + }; + __m128i data[2]; +} ulonglong4; +typedef union { + struct { + __m64 x; + __m64 y; + __m64 z; + }; + __m64 data[3]; +} ulonglong3; +typedef union { + __m128i data[4]; +} ulonglong8; +typedef union { + __m128i data[8]; +} ulonglong16; + +typedef union { + struct { + float x; + }; + float data; +} float1; +typedef union { + struct { + float x; + float y; + }; + __m64 data; +} float2; +typedef union { + struct { + float x; + float y; + float z; + float w; + }; + __m128 data; +} float4; +typedef union { + struct { + float x; + float y; + float z; + }; + float data[3]; +} float3; +typedef union { + __m256 data; +} float8; +typedef union { + __m256 data[2]; +} float16; + +typedef union { + struct { + double x; + }; + double data; +} double1; +typedef union { + struct { + double x; + double y; + }; + __m128d data; +} double2; +typedef union { + struct { + double x; + double y; + double z; + double w; + }; + __m256d data; +} double4; +typedef union { + struct { + double x; + double y; + double z; + }; + double data[3]; +} double3; +typedef union { + __m256d data[2]; +} double8; +typedef union { + __m256d data[4]; +} double16; + +#else // !defined(_MSC_VER) + +/* +this is for compatibility with CUDA as CUDA allows accessing vector components +in C++ program with MSVC +*/ +typedef union { + struct { + char x; + }; + char data; +} char1; +typedef union { + struct { + char x; + char y; + }; + char data[2]; +} char2; +typedef union { + struct { + char x; + char y; + char z; + char w; + }; + char data[4]; +} char4; +typedef union { + char data[8]; +} char8; +typedef union { + char data[16]; +} char16; +typedef union { + struct { + char x; + char y; + char z; + }; + char data[3]; +} char3; + +typedef union { + struct { + unsigned char x; + }; + unsigned char data; +} uchar1; +typedef union { + struct { + unsigned char x; + unsigned char y; + }; + unsigned char data[2]; +} uchar2; +typedef union { + struct { + unsigned char x; + unsigned char y; + unsigned char z; + unsigned char w; + }; + unsigned char data[4]; +} uchar4; +typedef union { + unsigned char data[8]; +} uchar8; +typedef union { + unsigned char data[16]; +} uchar16; +typedef union { + struct { + unsigned char x; + unsigned char y; + unsigned char z; + }; + unsigned char data[3]; +} uchar3; + +typedef union { + struct { + short x; + }; + short data; +} short1; +typedef union { + struct { + short x; + short y; + }; + short data[2]; +} short2; +typedef union { + struct { + short x; + short y; + short z; + short w; + }; + short data[4]; +} short4; +typedef union { + short data[8]; +} short8; +typedef union { + short data[16]; +} short16; +typedef union { + struct { + short x; + short y; + short z; + }; + short data[3]; +} short3; + +typedef union { + struct { + unsigned short x; + }; + unsigned short data; +} ushort1; +typedef union { + struct { + unsigned short x; + unsigned short y; + }; + unsigned short data[2]; +} ushort2; +typedef union { + struct { + unsigned short x; + unsigned short y; + unsigned short z; + unsigned short w; + }; + unsigned short data[4]; +} ushort4; +typedef union { + unsigned short data[8]; +} ushort8; +typedef union { + unsigned short data[16]; +} ushort16; +typedef union { + struct { + unsigned short x; + unsigned short y; + unsigned short z; + }; + unsigned short data[3]; +} ushort3; + +typedef union { + struct { + int x; + }; + int data; +} int1; +typedef union { + struct { + int x; + int y; + }; + int data[2]; +} int2; +typedef union { + struct { + int x; + int y; + int z; + int w; + }; + int data[4]; +} int4; +typedef union { + int data[8]; +} int8; +typedef union { + int data[16]; +} int16; +typedef union { + struct { + int x; + int y; + int z; + }; + int data[3]; +} int3; + +typedef union { + struct { + unsigned int x; + }; + unsigned int data; +} uint1; +typedef union { + struct { + unsigned int x; + unsigned int y; + }; + unsigned int data[2]; +} uint2; +typedef union { + struct { + unsigned int x; + unsigned int y; + unsigned int z; + unsigned int w; + }; + unsigned int data[4]; +} uint4; +typedef union { + unsigned int data[8]; +} uint8; +typedef union { + unsigned int data[16]; +} uint16; +typedef union { + struct { + unsigned int x; + unsigned int y; + unsigned int z; + }; + unsigned int data[3]; +} uint3; + +typedef union { + struct { + long x; + }; + long data; +} long1; +typedef union { + struct { + long x; + long y; + }; + long data[2]; +} long2; +typedef union { + struct { + long x; + long y; + long z; + long w; + }; + long data[4]; +} long4; +typedef union { + long data[8]; +} long8; +typedef union { + long data[16]; +} long16; +typedef union { + struct { + long x; + long y; + long z; + }; + long data[3]; +} long3; + +typedef union { + struct { + unsigned long x; + }; + unsigned long data; +} ulong1; +typedef union { + struct { + unsigned long x; + unsigned long y; + }; + unsigned long data[2]; +} ulong2; +typedef union { + struct { + unsigned long x; + unsigned long y; + unsigned long z; + unsigned long w; + }; + unsigned long data[4]; +} ulong4; +typedef union { + unsigned long data[8]; +} ulong8; +typedef union { + unsigned long data[16]; +} ulong16; +typedef union { + struct { + unsigned long x; + unsigned long y; + unsigned long z; + }; + unsigned long data[3]; +} ulong3; + +typedef union { + struct { + long long x; + }; + long long data; +} longlong1; +typedef union { + struct { + long long x; + long long y; + }; + long long data[2]; +} longlong2; +typedef union { + struct { + long long x; + long long y; + long long z; + long long w; + }; + long long data[4]; +} longlong4; +typedef union { + long long data[8]; +} longlong8; +typedef union { + long long data[16]; +} longlong16; +typedef union { + struct { + long long x; + long long y; + long long z; + }; + long long data[3]; +} longlong3; + +typedef union { + struct { + unsigned long long x; + }; + unsigned long long data; +} ulonglong1; +typedef union { + struct { + unsigned long long x; + unsigned long long y; + }; + unsigned long long data[2]; +} ulonglong2; +typedef union { + struct { + unsigned long long x; + unsigned long long y; + unsigned long long z; + unsigned long long w; + }; + unsigned long long data[4]; +} ulonglong4; +typedef union { + unsigned long long data[8]; +} ulonglong8; +typedef union { + unsigned long long data[16]; +} ulonglong16; +typedef union { + struct { + unsigned long long x; + unsigned long long y; + unsigned long long z; + }; + unsigned long long data[3]; +} ulonglong3; + +typedef union { + struct { + float x; + }; + float data; +} float1; +typedef union { + struct { + float x; + float y; + }; + float data[2]; +} float2; +typedef union { + struct { + float x; + float y; + float z; + float w; + }; + float data[4]; +} float4; +typedef union { + float data[8]; +} float8; +typedef union { + float data[16]; +} float16; +typedef union { + struct { + float x; + float y; + float z; + }; + float data[3]; +} float3; + +typedef union { + struct { + double x; + }; + double data; +} double1; +typedef union { + struct { + double x; + double y; + }; + double data[2]; +} double2; +typedef union { + struct { + double x; + double y; + double z; + double w; + }; + double data[4]; +} double4; +typedef union { + double data[8]; +} double8; +typedef union { + double data[16]; +} double16; +typedef union { + struct { + double x; + double y; + double z; + }; + double data[3]; +} double3; + +#endif // defined(_MSC_VER) +#endif // defined(__has_attribute) + +#ifdef __cplusplus +#define DECLOP_MAKE_ONE_COMPONENT(comp, type) \ + static inline __HOST_DEVICE__ type make_##type(comp x) { \ + type r{x}; \ + return r; \ + } + +#define DECLOP_MAKE_TWO_COMPONENT(comp, type) \ + static inline __HOST_DEVICE__ type make_##type(comp x, comp y) { \ + type r{x, y}; \ + return r; \ + } + +#define DECLOP_MAKE_THREE_COMPONENT(comp, type) \ + static inline __HOST_DEVICE__ type make_##type(comp x, comp y, comp z) { \ + type r{x, y, z}; \ + return r; \ + } + +#define DECLOP_MAKE_FOUR_COMPONENT(comp, type) \ + static inline __HOST_DEVICE__ type make_##type(comp x, comp y, comp z, \ + comp w) { \ + type r{x, y, z, w}; \ + return r; \ + } +#else +#define DECLOP_MAKE_ONE_COMPONENT(comp, type) \ + static inline __HOST_DEVICE__ type make_##type(comp x) { \ + type r; \ + r.x = x; \ + return r; \ + } + +#define DECLOP_MAKE_TWO_COMPONENT(comp, type) \ + static inline __HOST_DEVICE__ type make_##type(comp x, comp y) { \ + type r; \ + r.x = x; \ + r.y = y; \ + return r; \ + } + +#define DECLOP_MAKE_THREE_COMPONENT(comp, type) \ + static inline __HOST_DEVICE__ type make_##type(comp x, comp y, comp z) { \ + type r; \ + r.x = x; \ + r.y = y; \ + r.z = z; \ + return r; \ + } + +#define DECLOP_MAKE_FOUR_COMPONENT(comp, type) \ + static inline __HOST_DEVICE__ type make_##type(comp x, comp y, comp z, \ + comp w) { \ + type r; \ + r.x = x; \ + r.y = y; \ + r.z = z; \ + r.w = w; \ + return r; \ + } +#endif + +DECLOP_MAKE_ONE_COMPONENT(unsigned char, uchar1); +DECLOP_MAKE_TWO_COMPONENT(unsigned char, uchar2); +DECLOP_MAKE_THREE_COMPONENT(unsigned char, uchar3); +DECLOP_MAKE_FOUR_COMPONENT(unsigned char, uchar4); + +DECLOP_MAKE_ONE_COMPONENT(signed char, char1); +DECLOP_MAKE_TWO_COMPONENT(signed char, char2); +DECLOP_MAKE_THREE_COMPONENT(signed char, char3); +DECLOP_MAKE_FOUR_COMPONENT(signed char, char4); + +DECLOP_MAKE_ONE_COMPONENT(unsigned short, ushort1); +DECLOP_MAKE_TWO_COMPONENT(unsigned short, ushort2); +DECLOP_MAKE_THREE_COMPONENT(unsigned short, ushort3); +DECLOP_MAKE_FOUR_COMPONENT(unsigned short, ushort4); + +DECLOP_MAKE_ONE_COMPONENT(signed short, short1); +DECLOP_MAKE_TWO_COMPONENT(signed short, short2); +DECLOP_MAKE_THREE_COMPONENT(signed short, short3); +DECLOP_MAKE_FOUR_COMPONENT(signed short, short4); + +DECLOP_MAKE_ONE_COMPONENT(unsigned int, uint1); +DECLOP_MAKE_TWO_COMPONENT(unsigned int, uint2); +DECLOP_MAKE_THREE_COMPONENT(unsigned int, uint3); +DECLOP_MAKE_FOUR_COMPONENT(unsigned int, uint4); + +DECLOP_MAKE_ONE_COMPONENT(signed int, int1); +DECLOP_MAKE_TWO_COMPONENT(signed int, int2); +DECLOP_MAKE_THREE_COMPONENT(signed int, int3); +DECLOP_MAKE_FOUR_COMPONENT(signed int, int4); + +DECLOP_MAKE_ONE_COMPONENT(float, float1); +DECLOP_MAKE_TWO_COMPONENT(float, float2); +DECLOP_MAKE_THREE_COMPONENT(float, float3); +DECLOP_MAKE_FOUR_COMPONENT(float, float4); + +DECLOP_MAKE_ONE_COMPONENT(double, double1); +DECLOP_MAKE_TWO_COMPONENT(double, double2); +DECLOP_MAKE_THREE_COMPONENT(double, double3); +DECLOP_MAKE_FOUR_COMPONENT(double, double4); + +DECLOP_MAKE_ONE_COMPONENT(unsigned long, ulong1); +DECLOP_MAKE_TWO_COMPONENT(unsigned long, ulong2); +DECLOP_MAKE_THREE_COMPONENT(unsigned long, ulong3); +DECLOP_MAKE_FOUR_COMPONENT(unsigned long, ulong4); + +DECLOP_MAKE_ONE_COMPONENT(signed long, long1); +DECLOP_MAKE_TWO_COMPONENT(signed long, long2); +DECLOP_MAKE_THREE_COMPONENT(signed long, long3); +DECLOP_MAKE_FOUR_COMPONENT(signed long, long4); + +DECLOP_MAKE_ONE_COMPONENT(unsigned long long, ulonglong1); +DECLOP_MAKE_TWO_COMPONENT(unsigned long long, ulonglong2); +DECLOP_MAKE_THREE_COMPONENT(unsigned long long, ulonglong3); +DECLOP_MAKE_FOUR_COMPONENT(unsigned long long, ulonglong4); + +DECLOP_MAKE_ONE_COMPONENT(signed long long, longlong1); +DECLOP_MAKE_TWO_COMPONENT(signed long long, longlong2); +DECLOP_MAKE_THREE_COMPONENT(signed long long, longlong3); +DECLOP_MAKE_FOUR_COMPONENT(signed long long, longlong4); + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_math_functions.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_math_functions.h new file mode 100644 index 000000000..97830b288 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_math_functions.h @@ -0,0 +1,96 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#if !defined(__HIPCC_RTC__) +#include "amd_hip_vector_types.h" +#include "hip_fp16_math_fwd.h" +#include "math_fwd.h" + +#include + +#include +// assert.h is only for the host version of assert. +// The device version of assert is implemented in hip/amd_detail/hip_runtime.h. +// Users should include hip_runtime.h for the device version of assert. +#if !__HIP_DEVICE_COMPILE__ +#include +#endif +#include +#include +#include +#endif // !defined(__HIPCC_RTC__) + +#if _LIBCPP_VERSION && __HIP__ +namespace std { +template <> struct __numeric_type<_Float16> { + static _Float16 __test(_Float16); + + typedef _Float16 type; + static const bool value = true; +}; +} // namespace std +#endif // _LIBCPP_VERSION + +#pragma push_macro("__DEVICE__") +#pragma push_macro("__RETURN_TYPE") + +#define __DEVICE__ static __device__ +#define __RETURN_TYPE bool + +// DOT FUNCTIONS +#if defined(__clang__) && defined(__HIP__) +__DEVICE__ +inline int amd_mixed_dot(short2 a, short2 b, int c, bool saturate) { + return __ockl_sdot2(a.data, b.data, c, saturate); +} +__DEVICE__ +inline uint amd_mixed_dot(ushort2 a, ushort2 b, uint c, bool saturate) { + return __ockl_udot2(a.data, b.data, c, saturate); +} +__DEVICE__ +inline int amd_mixed_dot(char4 a, char4 b, int c, bool saturate) { + return __ockl_sdot4(a.data, b.data, c, saturate); +} +__DEVICE__ +inline uint amd_mixed_dot(uchar4 a, uchar4 b, uint c, bool saturate) { + return __ockl_udot4(a.data, b.data, c, saturate); +} +__DEVICE__ +inline int amd_mixed_dot(int a, int b, int c, bool saturate) { + return __ockl_sdot8(a, b, c, saturate); +} +__DEVICE__ +inline uint amd_mixed_dot(uint a, uint b, uint c, bool saturate) { + return __ockl_udot8(a, b, c, saturate); +} +#endif + +#pragma pop_macro("__DEVICE__") +#pragma pop_macro("__RETURN_TYPE") +// For backward compatibility. +// There are HIP applications e.g. TensorFlow, expecting __HIP_ARCH_* macros +// defined after including math_functions.h. +#if !defined(__HIPCC_RTC__) +#include +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_surface_functions.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_surface_functions.h new file mode 100644 index 000000000..5ca955171 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_surface_functions.h @@ -0,0 +1,264 @@ +/* +Copyright (c) 2018 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_SURFACE_FUNCTIONS_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_SURFACE_FUNCTIONS_H + +#if defined(__cplusplus) + +#if !defined(__HIPCC_RTC__) +#include +#include +#include +#include +#endif + +#if defined(__HIPCC_RTC__) +#define __HOST_DEVICE__ __device__ +#else +#define __HOST_DEVICE__ __host__ __device__ +#endif + +#define __HIP_SURFACE_OBJECT_PARAMETERS_INIT \ + unsigned int ADDRESS_SPACE_CONSTANT *i = \ + (unsigned int ADDRESS_SPACE_CONSTANT *)surfObj; + +// CUDA is using byte address, need map to pixel address for HIP +static __HOST_DEVICE__ __forceinline__ int __hipGetPixelAddr(int x, int format, + int order) { + /* + * use below format index to generate format LUT + typedef enum { + HSA_EXT_IMAGE_CHANNEL_TYPE_SNORM_INT8 = 0, + HSA_EXT_IMAGE_CHANNEL_TYPE_SNORM_INT16 = 1, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_INT8 = 2, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_INT16 = 3, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_INT24 = 4, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_SHORT_555 = 5, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_SHORT_565 = 6, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_SHORT_101010 = 7, + HSA_EXT_IMAGE_CHANNEL_TYPE_SIGNED_INT8 = 8, + HSA_EXT_IMAGE_CHANNEL_TYPE_SIGNED_INT16 = 9, + HSA_EXT_IMAGE_CHANNEL_TYPE_SIGNED_INT32 = 10, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8 = 11, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16 = 12, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32 = 13, + HSA_EXT_IMAGE_CHANNEL_TYPE_HALF_FLOAT = 14, + HSA_EXT_IMAGE_CHANNEL_TYPE_FLOAT = 15 + } hsa_ext_image_channel_type_t; + */ + static const int FormatLUT[] = {0, 1, 0, 1, 3, 1, 1, 1, + 0, 1, 2, 0, 1, 2, 1, 2}; + x = FormatLUT[format] == 3 ? x / FormatLUT[format] : x >> FormatLUT[format]; + + /* + * use below order index to generate order LUT + typedef enum { + HSA_EXT_IMAGE_CHANNEL_ORDER_A = 0, + HSA_EXT_IMAGE_CHANNEL_ORDER_R = 1, + HSA_EXT_IMAGE_CHANNEL_ORDER_RX = 2, + HSA_EXT_IMAGE_CHANNEL_ORDER_RG = 3, + HSA_EXT_IMAGE_CHANNEL_ORDER_RGX = 4, + HSA_EXT_IMAGE_CHANNEL_ORDER_RA = 5, + HSA_EXT_IMAGE_CHANNEL_ORDER_RGB = 6, + HSA_EXT_IMAGE_CHANNEL_ORDER_RGBX = 7, + HSA_EXT_IMAGE_CHANNEL_ORDER_RGBA = 8, + HSA_EXT_IMAGE_CHANNEL_ORDER_BGRA = 9, + HSA_EXT_IMAGE_CHANNEL_ORDER_ARGB = 10, + HSA_EXT_IMAGE_CHANNEL_ORDER_ABGR = 11, + HSA_EXT_IMAGE_CHANNEL_ORDER_SRGB = 12, + HSA_EXT_IMAGE_CHANNEL_ORDER_SRGBX = 13, + HSA_EXT_IMAGE_CHANNEL_ORDER_SRGBA = 14, + HSA_EXT_IMAGE_CHANNEL_ORDER_SBGRA = 15, + HSA_EXT_IMAGE_CHANNEL_ORDER_INTENSITY = 16, + HSA_EXT_IMAGE_CHANNEL_ORDER_LUMINANCE = 17, + HSA_EXT_IMAGE_CHANNEL_ORDER_DEPTH = 18, + HSA_EXT_IMAGE_CHANNEL_ORDER_DEPTH_STENCIL = 19 + } hsa_ext_image_channel_order_t; + */ + static const int OrderLUT[] = {0, 0, 1, 1, 3, 1, 3, 2, 2, 2, + 2, 2, 3, 2, 2, 2, 0, 0, 0, 0}; + return x = OrderLUT[order] == 3 ? x / OrderLUT[order] : x >> OrderLUT[order]; +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surf1Dread(T *data, hipSurfaceObject_t surfObj, int x, + int boundaryMode = hipBoundaryModeZero) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_1D(i), + __ockl_image_channel_order_1D(i)); + auto tmp = __ockl_image_load_1D(i, x); + *data = __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surf1Dwrite(T data, hipSurfaceObject_t surfObj, int x) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_1D(i), + __ockl_image_channel_order_1D(i)); + auto tmp = __hipMapTo(data); + __ockl_image_store_1D(i, x, tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surf2Dread(T *data, hipSurfaceObject_t surfObj, int x, int y) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_2D(i), + __ockl_image_channel_order_2D(i)); + auto tmp = __ockl_image_load_2D(i, int2(x, y).data); + *data = __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surf2Dwrite(T data, hipSurfaceObject_t surfObj, int x, int y) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_2D(i), + __ockl_image_channel_order_2D(i)); + auto tmp = __hipMapTo(data); + __ockl_image_store_2D(i, int2(x, y).data, tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surf3Dread(T *data, hipSurfaceObject_t surfObj, int x, int y, int z) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_3D(i), + __ockl_image_channel_order_3D(i)); + auto tmp = __ockl_image_load_3D(i, int4(x, y, z, 0).data); + *data = __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surf3Dwrite(T data, hipSurfaceObject_t surfObj, int x, int y, int z) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_3D(i), + __ockl_image_channel_order_3D(i)); + auto tmp = __hipMapTo(data); + __ockl_image_store_3D(i, int4(x, y, z, 0).data, tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surf1DLayeredread(T *data, hipSurfaceObject_t surfObj, int x, int layer) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_1D(i), + __ockl_image_channel_order_1D(i)); + auto tmp = __ockl_image_load_lod_1D(i, x, layer); + *data = __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surf1DLayeredwrite(T data, hipSurfaceObject_t surfObj, int x, int layer) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_1D(i), + __ockl_image_channel_order_1D(i)); + auto tmp = __hipMapTo(data); + __ockl_image_store_lod_1D(i, x, layer, tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surf2DLayeredread(T *data, hipSurfaceObject_t surfObj, int x, int y, + int layer) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_2D(i), + __ockl_image_channel_order_2D(i)); + auto tmp = __ockl_image_load_lod_2D(i, int2(x, y).data, layer); + *data = __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surf2DLayeredwrite(T data, hipSurfaceObject_t surfObj, int x, int y, + int layer) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_2D(i), + __ockl_image_channel_order_2D(i)); + auto tmp = __hipMapTo(data); + __ockl_image_store_lod_2D(i, int2(x, y).data, layer, tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surfCubemapread(T *data, hipSurfaceObject_t surfObj, int x, int y, int face) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_2D(i), + __ockl_image_channel_order_2D(i)); + auto tmp = __ockl_image_load_CM(i, int2(x, y).data, face); + *data = __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surfCubemapwrite(T data, hipSurfaceObject_t surfObj, int x, int y, int face) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_2D(i), + __ockl_image_channel_order_2D(i)); + auto tmp = __hipMapTo(data); + __ockl_image_store_CM(i, int2(x, y).data, face, tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surfCubemapLayeredread(T *data, hipSurfaceObject_t surfObj, int x, int y, + int face, int layer) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_2D(i), + __ockl_image_channel_order_2D(i)); + auto tmp = __ockl_image_load_lod_CM(i, int2(x, y).data, face, layer); + *data = __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +surfCubemapLayeredwrite(T *data, hipSurfaceObject_t surfObj, int x, int y, + int face, int layer) { + __HIP_SURFACE_OBJECT_PARAMETERS_INIT + x = __hipGetPixelAddr(x, __ockl_image_channel_data_type_2D(i), + __ockl_image_channel_order_2D(i)); + auto tmp = __hipMapTo(data); + __ockl_image_store_lod_CM(i, int2(x, y).data, face, layer, tmp); +} + +#endif + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_warp_functions.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_warp_functions.h new file mode 100644 index 000000000..c065ae92b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_warp_functions.h @@ -0,0 +1,612 @@ +/* +Copyright (c) 2022 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_WARP_FUNCTIONS_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_WARP_FUNCTIONS_H + +__device__ static inline unsigned __hip_ds_bpermute(int index, unsigned src) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.u = src; + tmp.i = __builtin_amdgcn_ds_bpermute(index, tmp.i); + return tmp.u; +} + +__device__ static inline float __hip_ds_bpermutef(int index, float src) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.f = src; + tmp.i = __builtin_amdgcn_ds_bpermute(index, tmp.i); + return tmp.f; +} + +__device__ static inline unsigned __hip_ds_permute(int index, unsigned src) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.u = src; + tmp.i = __builtin_amdgcn_ds_permute(index, tmp.i); + return tmp.u; +} + +__device__ static inline float __hip_ds_permutef(int index, float src) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.f = src; + tmp.i = __builtin_amdgcn_ds_permute(index, tmp.i); + return tmp.f; +} + +#define __hip_ds_swizzle(src, pattern) __hip_ds_swizzle_N<(pattern)>((src)) +#define __hip_ds_swizzlef(src, pattern) __hip_ds_swizzlef_N<(pattern)>((src)) + +template +__device__ static inline unsigned __hip_ds_swizzle_N(unsigned int src) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.u = src; + tmp.i = __builtin_amdgcn_ds_swizzle(tmp.i, pattern); + return tmp.u; +} + +template +__device__ static inline float __hip_ds_swizzlef_N(float src) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.f = src; + tmp.i = __builtin_amdgcn_ds_swizzle(tmp.i, pattern); + return tmp.f; +} + +#define __hip_move_dpp(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \ + __hip_move_dpp_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src)) + +template +__device__ static inline int __hip_move_dpp_N(int src) { + return __builtin_amdgcn_mov_dpp(src, dpp_ctrl, row_mask, bank_mask, + bound_ctrl); +} + +static constexpr int warpSize = __AMDGCN_WAVEFRONT_SIZE; + +// warp vote function __all __any __ballot +__device__ inline int __all(int predicate) { + return __ockl_wfall_i32(predicate); +} + +__device__ inline int __any(int predicate) { + return __ockl_wfany_i32(predicate); +} + +// XXX from llvm/include/llvm/IR/InstrTypes.h +#define ICMP_NE 33 + +__device__ inline unsigned long long int __ballot(int predicate) { + return __builtin_amdgcn_uicmp(predicate, 0, ICMP_NE); +} + +__device__ inline unsigned long long int __ballot64(int predicate) { + return __builtin_amdgcn_uicmp(predicate, 0, ICMP_NE); +} + +// See amd_warp_sync_functions.h for an explanation of this preprocessor flag. +#ifdef HIP_ENABLE_WARP_SYNC_BUILTINS +// Since threads in a wave do not make independent progress, __activemask() +// always returns the exact active mask, i.e, all active threads in the wave. +__device__ inline unsigned long long __activemask() { return __ballot(true); } +#endif // HIP_ENABLE_WARP_SYNC_BUILTINS + +__device__ static inline unsigned int __lane_id() { + return __builtin_amdgcn_mbcnt_hi(-1, __builtin_amdgcn_mbcnt_lo(-1, 0)); +} + +__device__ inline int __shfl(int var, int src_lane, int width = warpSize) { + int self = __lane_id(); + int index = (src_lane & (width - 1)) + (self & ~(width - 1)); + return __builtin_amdgcn_ds_bpermute(index << 2, var); +} +__device__ inline unsigned int __shfl(unsigned int var, int src_lane, + int width = warpSize) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.u = var; + tmp.i = __shfl(tmp.i, src_lane, width); + return tmp.u; +} +__device__ inline float __shfl(float var, int src_lane, int width = warpSize) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.f = var; + tmp.i = __shfl(tmp.i, src_lane, width); + return tmp.f; +} +__device__ inline double __shfl(double var, int src_lane, + int width = warpSize) { + static_assert(sizeof(double) == 2 * sizeof(int), ""); + static_assert(sizeof(double) == sizeof(uint64_t), ""); + + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl(tmp[0], src_lane, width); + tmp[1] = __shfl(tmp[1], src_lane, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + double tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} +__device__ inline long __shfl(long var, int src_lane, int width = warpSize) { +#ifndef _MSC_VER + static_assert(sizeof(long) == 2 * sizeof(int), ""); + static_assert(sizeof(long) == sizeof(uint64_t), ""); + + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl(tmp[0], src_lane, width); + tmp[1] = __shfl(tmp[1], src_lane, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +#else + static_assert(sizeof(long) == sizeof(int), ""); + return static_cast(__shfl(static_cast(var), src_lane, width)); +#endif +} +__device__ inline unsigned long __shfl(unsigned long var, int src_lane, + int width = warpSize) { +#ifndef _MSC_VER + static_assert(sizeof(unsigned long) == 2 * sizeof(unsigned int), ""); + static_assert(sizeof(unsigned long) == sizeof(uint64_t), ""); + + unsigned int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl(tmp[0], src_lane, width); + tmp[1] = __shfl(tmp[1], src_lane, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + unsigned long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +#else + static_assert(sizeof(unsigned long) == sizeof(unsigned int), ""); + return static_cast( + __shfl(static_cast(var), src_lane, width)); +#endif +} +__device__ inline long long __shfl(long long var, int src_lane, + int width = warpSize) { + static_assert(sizeof(long long) == 2 * sizeof(int), ""); + static_assert(sizeof(long long) == sizeof(uint64_t), ""); + + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl(tmp[0], src_lane, width); + tmp[1] = __shfl(tmp[1], src_lane, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + long long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} +__device__ inline unsigned long long +__shfl(unsigned long long var, int src_lane, int width = warpSize) { + static_assert(sizeof(unsigned long long) == 2 * sizeof(unsigned int), ""); + static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); + + unsigned int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl(tmp[0], src_lane, width); + tmp[1] = __shfl(tmp[1], src_lane, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + unsigned long long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} + +__device__ inline int __shfl_up(int var, unsigned int lane_delta, + int width = warpSize) { + int self = __lane_id(); + int index = self - lane_delta; + index = (index < (self & ~(width - 1))) ? self : index; + return __builtin_amdgcn_ds_bpermute(index << 2, var); +} +__device__ inline unsigned int +__shfl_up(unsigned int var, unsigned int lane_delta, int width = warpSize) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.u = var; + tmp.i = __shfl_up(tmp.i, lane_delta, width); + return tmp.u; +} +__device__ inline float __shfl_up(float var, unsigned int lane_delta, + int width = warpSize) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.f = var; + tmp.i = __shfl_up(tmp.i, lane_delta, width); + return tmp.f; +} +__device__ inline double __shfl_up(double var, unsigned int lane_delta, + int width = warpSize) { + static_assert(sizeof(double) == 2 * sizeof(int), ""); + static_assert(sizeof(double) == sizeof(uint64_t), ""); + + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_up(tmp[0], lane_delta, width); + tmp[1] = __shfl_up(tmp[1], lane_delta, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + double tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} +__device__ inline long __shfl_up(long var, unsigned int lane_delta, + int width = warpSize) { +#ifndef _MSC_VER + static_assert(sizeof(long) == 2 * sizeof(int), ""); + static_assert(sizeof(long) == sizeof(uint64_t), ""); + + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_up(tmp[0], lane_delta, width); + tmp[1] = __shfl_up(tmp[1], lane_delta, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +#else + static_assert(sizeof(long) == sizeof(int), ""); + return static_cast(__shfl_up(static_cast(var), lane_delta, width)); +#endif +} + +__device__ inline unsigned long +__shfl_up(unsigned long var, unsigned int lane_delta, int width = warpSize) { +#ifndef _MSC_VER + static_assert(sizeof(unsigned long) == 2 * sizeof(unsigned int), ""); + static_assert(sizeof(unsigned long) == sizeof(uint64_t), ""); + + unsigned int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_up(tmp[0], lane_delta, width); + tmp[1] = __shfl_up(tmp[1], lane_delta, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + unsigned long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +#else + static_assert(sizeof(unsigned long) == sizeof(unsigned int), ""); + return static_cast( + __shfl_up(static_cast(var), lane_delta, width)); +#endif +} + +__device__ inline long long __shfl_up(long long var, unsigned int lane_delta, + int width = warpSize) { + static_assert(sizeof(long long) == 2 * sizeof(int), ""); + static_assert(sizeof(long long) == sizeof(uint64_t), ""); + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_up(tmp[0], lane_delta, width); + tmp[1] = __shfl_up(tmp[1], lane_delta, width); + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + long long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} + +__device__ inline unsigned long long __shfl_up(unsigned long long var, + unsigned int lane_delta, + int width = warpSize) { + static_assert(sizeof(unsigned long long) == 2 * sizeof(unsigned int), ""); + static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); + unsigned int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_up(tmp[0], lane_delta, width); + tmp[1] = __shfl_up(tmp[1], lane_delta, width); + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + unsigned long long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} + +__device__ inline int __shfl_down(int var, unsigned int lane_delta, + int width = warpSize) { + int self = __lane_id(); + int index = self + lane_delta; + index = (int)((self & (width - 1)) + lane_delta) >= width ? self : index; + return __builtin_amdgcn_ds_bpermute(index << 2, var); +} +__device__ inline unsigned int +__shfl_down(unsigned int var, unsigned int lane_delta, int width = warpSize) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.u = var; + tmp.i = __shfl_down(tmp.i, lane_delta, width); + return tmp.u; +} +__device__ inline float __shfl_down(float var, unsigned int lane_delta, + int width = warpSize) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.f = var; + tmp.i = __shfl_down(tmp.i, lane_delta, width); + return tmp.f; +} +__device__ inline double __shfl_down(double var, unsigned int lane_delta, + int width = warpSize) { + static_assert(sizeof(double) == 2 * sizeof(int), ""); + static_assert(sizeof(double) == sizeof(uint64_t), ""); + + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_down(tmp[0], lane_delta, width); + tmp[1] = __shfl_down(tmp[1], lane_delta, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + double tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} +__device__ inline long __shfl_down(long var, unsigned int lane_delta, + int width = warpSize) { +#ifndef _MSC_VER + static_assert(sizeof(long) == 2 * sizeof(int), ""); + static_assert(sizeof(long) == sizeof(uint64_t), ""); + + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_down(tmp[0], lane_delta, width); + tmp[1] = __shfl_down(tmp[1], lane_delta, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +#else + static_assert(sizeof(long) == sizeof(int), ""); + return static_cast( + __shfl_down(static_cast(var), lane_delta, width)); +#endif +} +__device__ inline unsigned long +__shfl_down(unsigned long var, unsigned int lane_delta, int width = warpSize) { +#ifndef _MSC_VER + static_assert(sizeof(unsigned long) == 2 * sizeof(unsigned int), ""); + static_assert(sizeof(unsigned long) == sizeof(uint64_t), ""); + + unsigned int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_down(tmp[0], lane_delta, width); + tmp[1] = __shfl_down(tmp[1], lane_delta, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + unsigned long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +#else + static_assert(sizeof(unsigned long) == sizeof(unsigned int), ""); + return static_cast( + __shfl_down(static_cast(var), lane_delta, width)); +#endif +} +__device__ inline long long __shfl_down(long long var, unsigned int lane_delta, + int width = warpSize) { + static_assert(sizeof(long long) == 2 * sizeof(int), ""); + static_assert(sizeof(long long) == sizeof(uint64_t), ""); + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_down(tmp[0], lane_delta, width); + tmp[1] = __shfl_down(tmp[1], lane_delta, width); + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + long long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} +__device__ inline unsigned long long __shfl_down(unsigned long long var, + unsigned int lane_delta, + int width = warpSize) { + static_assert(sizeof(unsigned long long) == 2 * sizeof(unsigned int), ""); + static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); + unsigned int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_down(tmp[0], lane_delta, width); + tmp[1] = __shfl_down(tmp[1], lane_delta, width); + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + unsigned long long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} + +__device__ inline int __shfl_xor(int var, int lane_mask, int width = warpSize) { + int self = __lane_id(); + int index = self ^ lane_mask; + index = index >= ((self + width) & ~(width - 1)) ? self : index; + return __builtin_amdgcn_ds_bpermute(index << 2, var); +} +__device__ inline unsigned int __shfl_xor(unsigned int var, int lane_mask, + int width = warpSize) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.u = var; + tmp.i = __shfl_xor(tmp.i, lane_mask, width); + return tmp.u; +} +__device__ inline float __shfl_xor(float var, int lane_mask, + int width = warpSize) { + union { + int i; + unsigned u; + float f; + } tmp; + tmp.f = var; + tmp.i = __shfl_xor(tmp.i, lane_mask, width); + return tmp.f; +} +__device__ inline double __shfl_xor(double var, int lane_mask, + int width = warpSize) { + static_assert(sizeof(double) == 2 * sizeof(int), ""); + static_assert(sizeof(double) == sizeof(uint64_t), ""); + + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_xor(tmp[0], lane_mask, width); + tmp[1] = __shfl_xor(tmp[1], lane_mask, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + double tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} +__device__ inline long __shfl_xor(long var, int lane_mask, + int width = warpSize) { +#ifndef _MSC_VER + static_assert(sizeof(long) == 2 * sizeof(int), ""); + static_assert(sizeof(long) == sizeof(uint64_t), ""); + + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_xor(tmp[0], lane_mask, width); + tmp[1] = __shfl_xor(tmp[1], lane_mask, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +#else + static_assert(sizeof(long) == sizeof(int), ""); + return static_cast(__shfl_xor(static_cast(var), lane_mask, width)); +#endif +} +__device__ inline unsigned long __shfl_xor(unsigned long var, int lane_mask, + int width = warpSize) { +#ifndef _MSC_VER + static_assert(sizeof(unsigned long) == 2 * sizeof(unsigned int), ""); + static_assert(sizeof(unsigned long) == sizeof(uint64_t), ""); + + unsigned int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_xor(tmp[0], lane_mask, width); + tmp[1] = __shfl_xor(tmp[1], lane_mask, width); + + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + unsigned long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +#else + static_assert(sizeof(unsigned long) == sizeof(unsigned int), ""); + return static_cast( + __shfl_xor(static_cast(var), lane_mask, width)); +#endif +} +__device__ inline long long __shfl_xor(long long var, int lane_mask, + int width = warpSize) { + static_assert(sizeof(long long) == 2 * sizeof(int), ""); + static_assert(sizeof(long long) == sizeof(uint64_t), ""); + int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_xor(tmp[0], lane_mask, width); + tmp[1] = __shfl_xor(tmp[1], lane_mask, width); + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + long long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} +__device__ inline unsigned long long +__shfl_xor(unsigned long long var, int lane_mask, int width = warpSize) { + static_assert(sizeof(unsigned long long) == 2 * sizeof(unsigned int), ""); + static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); + unsigned int tmp[2]; + __builtin_memcpy(tmp, &var, sizeof(tmp)); + tmp[0] = __shfl_xor(tmp[0], lane_mask, width); + tmp[1] = __shfl_xor(tmp[1], lane_mask, width); + uint64_t tmp0 = + (static_cast(tmp[1]) << 32ull) | static_cast(tmp[0]); + unsigned long long tmp1; + __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0)); + return tmp1; +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_warp_sync_functions.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_warp_sync_functions.h new file mode 100644 index 000000000..4032bf64f --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/amd_warp_sync_functions.h @@ -0,0 +1,279 @@ +/* +Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +// Warp sync builtins (with explicit mask argument) introduced in ROCm 6.2 as a +// preview to allow end-users to adapt to the new interface involving 64-bit +// masks. These are disabled by default, and can be enabled by setting the macro +// below. The builtins will be enabled unconditionally in ROCm 6.3. +// +// This arrangement also applies to the __activemask() builtin defined in +// amd_warp_functions.h. +#ifdef HIP_ENABLE_WARP_SYNC_BUILTINS + +#if !defined(__HIPCC_RTC__) +#include "amd_warp_functions.h" +#include "hip_assert.h" +#endif + +template __device__ inline T __hip_readfirstlane(T val) { + // In theory, behaviour is undefined when reading from a union member other + // than the member that was last assigned to, but it works in practice because + // we rely on the compiler to do the reasonable thing. + union { + unsigned long long l; + T d; + } u; + u.d = val; + // NOTE: The builtin returns int, so we first cast it to unsigned int and only + // then extend it to 64 bits. + unsigned long long lower = (unsigned)__builtin_amdgcn_readfirstlane(u.l); + unsigned long long upper = + (unsigned)__builtin_amdgcn_readfirstlane(u.l >> 32); + u.l = (upper << 32) | lower; + return u.d; +} + +// When compiling for wave32 mode, ignore the upper half of the 64-bit mask. +#define __hip_adjust_mask_for_wave32(MASK) \ + do { \ + if (warpSize == 32) \ + MASK &= 0xFFFFFFFF; \ + } while (0) + +// We use a macro to expand each builtin into a waterfall that implements the +// mask semantics: +// +// 1. The mask argument may be divergent. +// 2. Each active thread must have its own bit set in its own mask value. +// 3. For a given mask value, all threads that are mentioned in the mask must +// execute the same static instance of the builtin with the same mask. +// 4. The union of all mask values supplied at a static instance must be equal +// to the activemask at the program point. +// +// Thus, the mask argument partitions the set of currently active threads in the +// wave into disjoint subsets that cover all active threads. +// +// Implementation notes: +// --------------------- +// +// We implement this as a waterfall loop that executes the builtin for each +// subset separately. The return value is a divergent value across the active +// threads. The value for inactive threads is defined by each builtin +// separately. +// +// As long as every mask value is non-zero, we don't need to check if a lane +// specifies itself in the mask; that is done by the later assertion where all +// chosen lanes must be in the chosen mask. + +#define __hip_check_mask(MASK) \ + do { \ + __hip_assert(MASK && "mask must be non-zero"); \ + bool done = false; \ + while (__any(!done)) { \ + if (!done) { \ + auto chosen_mask = __hip_readfirstlane(MASK); \ + if (MASK == chosen_mask) { \ + __hip_assert(MASK == __ballot(true) && \ + "all threads specified in the mask" \ + " must execute the same operation with the same mask"); \ + done = true; \ + } \ + } \ + } \ + } while (0) + +#define __hip_do_sync(RETVAL, FUNC, MASK, ...) \ + do { \ + __hip_assert(MASK && "mask must be non-zero"); \ + bool done = false; \ + while (__any(!done)) { \ + if (!done) { \ + auto chosen_mask = __hip_readfirstlane(MASK); \ + if (MASK == chosen_mask) { \ + __hip_assert(MASK == __ballot(true) && \ + "all threads specified in the mask" \ + " must execute the same operation with the same mask"); \ + RETVAL = FUNC(__VA_ARGS__); \ + done = true; \ + } \ + } \ + } \ + } while (0) + +// __all_sync, __any_sync, __ballot_sync + +template +__device__ inline unsigned long long __ballot_sync(MaskT mask, int predicate) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __ballot(predicate) & mask; +} + +template +__device__ inline int __all_sync(MaskT mask, int predicate) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + return __ballot_sync(mask, predicate) == mask; +} + +template +__device__ inline int __any_sync(MaskT mask, int predicate) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + return __ballot_sync(mask, predicate) != 0; +} + +// __match_any, __match_all and sync variants + +template +__device__ inline unsigned long long __match_any(T value) { + static_assert( + (__hip_internal::is_integral::value || + __hip_internal::is_floating_point::value) && + (sizeof(T) == 4 || sizeof(T) == 8), + "T can be int, unsigned int, long, unsigned long, long long, unsigned " + "long long, float or double."); + bool done = false; + unsigned long long retval = 0; + + while (__any(!done)) { + if (!done) { + T chosen = __hip_readfirstlane(value); + if (chosen == value) { + retval = __activemask(); + done = true; + } + } + } + + return retval; +} + +template +__device__ inline unsigned long long __match_any_sync(MaskT mask, T value) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __match_any(value) & mask; +} + +template +__device__ inline unsigned long long __match_all(T value, int *pred) { + static_assert( + (__hip_internal::is_integral::value || + __hip_internal::is_floating_point::value) && + (sizeof(T) == 4 || sizeof(T) == 8), + "T can be int, unsigned int, long, unsigned long, long long, unsigned " + "long long, float or double."); + T first = __hip_readfirstlane(value); + if (__all(first == value)) { + *pred = true; + return __activemask(); + } else { + *pred = false; + return 0; + } +} + +template +__device__ inline unsigned long long __match_all_sync(MaskT mask, T value, + int *pred) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + MaskT retval = 0; + __hip_adjust_mask_for_wave32(mask); + __hip_do_sync(retval, __match_all, mask, value, pred); + return retval; +} + +// various variants of shfl + +template +__device__ inline T __shfl_sync(MaskT mask, T var, int srcLane, + int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __shfl(var, srcLane, width); +} + +template +__device__ inline T __shfl_up_sync(MaskT mask, T var, unsigned int delta, + int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __shfl_up(var, delta, width); +} + +template +__device__ inline T __shfl_down_sync(MaskT mask, T var, unsigned int delta, + int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __shfl_down(var, delta, width); +} + +template +__device__ inline T __shfl_xor_sync(MaskT mask, T var, int laneMask, + int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __shfl_xor(var, laneMask, width); +} + +#undef __hip_do_sync +#undef __hip_check_mask +#undef __hip_adjust_mask_for_wave32 + +#endif // HIP_ENABLE_WARP_SYNC_BUILTINS diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/concepts.hpp b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/concepts.hpp new file mode 100644 index 000000000..8d6488b0d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/concepts.hpp @@ -0,0 +1,30 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +namespace hip_impl // Documentation only. +{ +#define requires(...) + +#define FunctionalProcedure typename +} // namespace hip_impl diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/device_library_decls.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/device_library_decls.h new file mode 100644 index 000000000..208a9c425 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/device_library_decls.h @@ -0,0 +1,170 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +/** + * @file amd_detail/device_library_decls.h + * @brief Contains declarations for types and functions in device library. + * Uses int64_t and uint64_t instead of long, long long, unsigned + * long and unsigned long long types for device library API + * declarations. + */ + +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_DEVICE_LIBRARY_DECLS_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_DEVICE_LIBRARY_DECLS_H + +#if !defined(__HIPCC_RTC__) +#include "hip/amd_detail/host_defines.h" +#endif + +typedef unsigned char uchar; +typedef unsigned short ushort; +typedef unsigned int uint; +typedef unsigned long ulong; +typedef unsigned long long ullong; + +extern "C" __device__ __attribute__((const)) bool __ockl_wfany_i32(int); +extern "C" __device__ __attribute__((const)) bool __ockl_wfall_i32(int); +extern "C" __device__ uint __ockl_activelane_u32(void); + +extern "C" __device__ __attribute__((const)) uint __ockl_mul24_u32(uint, uint); +extern "C" __device__ __attribute__((const)) int __ockl_mul24_i32(int, int); +extern "C" __device__ __attribute__((const)) uint __ockl_mul_hi_u32(uint, uint); +extern "C" __device__ __attribute__((const)) int __ockl_mul_hi_i32(int, int); +extern "C" __device__ + __attribute__((const)) uint __ockl_sadd_u32(uint, uint, uint); + +extern "C" __device__ __attribute__((const)) uchar __ockl_clz_u8(uchar); +extern "C" __device__ __attribute__((const)) ushort __ockl_clz_u16(ushort); +extern "C" __device__ __attribute__((const)) uint __ockl_clz_u32(uint); +extern "C" __device__ __attribute__((const)) uint64_t __ockl_clz_u64(uint64_t); + +extern "C" __device__ __attribute__((const)) float __ocml_floor_f32(float); +extern "C" __device__ __attribute__((const)) float __ocml_rint_f32(float); +extern "C" __device__ __attribute__((const)) float __ocml_ceil_f32(float); +extern "C" __device__ __attribute__((const)) float __ocml_trunc_f32(float); + +extern "C" __device__ __attribute__((const)) float __ocml_fmin_f32(float, + float); +extern "C" __device__ __attribute__((const)) float __ocml_fmax_f32(float, + float); + +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtn_f32_f64(double); +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtp_f32_f64(double); +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtz_f32_f64(double); + +extern "C" __device__ __attribute__((const)) _Float16 +__ocml_cvtrtn_f16_f32(float); +extern "C" __device__ __attribute__((const)) _Float16 +__ocml_cvtrtp_f16_f32(float); +extern "C" __device__ __attribute__((const)) _Float16 +__ocml_cvtrtz_f16_f32(float); + +extern "C" __device__ __attribute__((const)) float __ocml_cvtrtn_f32_s32(int); +extern "C" __device__ __attribute__((const)) float __ocml_cvtrtp_f32_s32(int); +extern "C" __device__ __attribute__((const)) float __ocml_cvtrtz_f32_s32(int); +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtn_f32_u32(uint32_t); +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtp_f32_u32(uint32_t); +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtz_f32_u32(uint32_t); +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtn_f32_s64(int64_t); +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtp_f32_s64(int64_t); +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtz_f32_s64(int64_t); +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtn_f32_u64(uint64_t); +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtp_f32_u64(uint64_t); +extern "C" __device__ __attribute__((const)) float +__ocml_cvtrtz_f32_u64(uint64_t); +extern "C" __device__ __attribute__((const)) double +__ocml_cvtrtn_f64_s64(int64_t); +extern "C" __device__ __attribute__((const)) double +__ocml_cvtrtp_f64_s64(int64_t); +extern "C" __device__ __attribute__((const)) double +__ocml_cvtrtz_f64_s64(int64_t); +extern "C" __device__ __attribute__((const)) double +__ocml_cvtrtn_f64_u64(uint64_t); +extern "C" __device__ __attribute__((const)) double +__ocml_cvtrtp_f64_u64(uint64_t); +extern "C" __device__ __attribute__((const)) double +__ocml_cvtrtz_f64_u64(uint64_t); + +extern "C" __device__ __attribute__((convergent)) void +__ockl_gws_init(uint nwm1, uint rid); +extern "C" __device__ __attribute__((convergent)) void +__ockl_gws_barrier(uint nwm1, uint rid); + +extern "C" __device__ __attribute__((const)) uint32_t __ockl_lane_u32(); +extern "C" __device__ __attribute__((const)) int __ockl_grid_is_valid(void); +extern "C" __device__ __attribute__((convergent)) void __ockl_grid_sync(void); +extern "C" __device__ __attribute__((const)) uint +__ockl_multi_grid_num_grids(void); +extern "C" __device__ __attribute__((const)) uint +__ockl_multi_grid_grid_rank(void); +extern "C" __device__ __attribute__((const)) uint __ockl_multi_grid_size(void); +extern "C" __device__ __attribute__((const)) uint +__ockl_multi_grid_thread_rank(void); +extern "C" __device__ __attribute__((const)) int +__ockl_multi_grid_is_valid(void); +extern "C" __device__ __attribute__((convergent)) void +__ockl_multi_grid_sync(void); + +extern "C" __device__ void __ockl_atomic_add_noret_f32(float *, float); + +extern "C" __device__ __attribute__((convergent)) int +__ockl_wgred_add_i32(int a); +extern "C" __device__ __attribute__((convergent)) int +__ockl_wgred_and_i32(int a); +extern "C" __device__ __attribute__((convergent)) int +__ockl_wgred_or_i32(int a); + +extern "C" __device__ uint64_t __ockl_fprintf_stderr_begin(); +extern "C" __device__ uint64_t __ockl_fprintf_append_args( + uint64_t msg_desc, uint32_t num_args, uint64_t value0, uint64_t value1, + uint64_t value2, uint64_t value3, uint64_t value4, uint64_t value5, + uint64_t value6, uint32_t is_last); +extern "C" __device__ uint64_t __ockl_fprintf_append_string_n(uint64_t msg_desc, + const char *data, + uint64_t length, + uint32_t is_last); + +// Introduce local address space +#define __local __attribute__((address_space(3))) + +#ifdef __HIP_DEVICE_COMPILE__ +__device__ inline static __local void *__to_local(unsigned x) { + return (__local void *)x; +} +#endif //__HIP_DEVICE_COMPILE__ + +// Using hip.amdgcn.bc - sync threads +#define __CLK_LOCAL_MEM_FENCE 0x01 +typedef unsigned __cl_mem_fence_flags; + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/functional_grid_launch.hpp b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/functional_grid_launch.hpp new file mode 100644 index 000000000..92165eedd --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/functional_grid_launch.hpp @@ -0,0 +1,204 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#include "concepts.hpp" +#include "helpers.hpp" +#include "hip_runtime_api.h" +#include "program_state.hpp" + +#include +#include +#include +#include +#include +#include + +hipError_t +ihipExtLaunchMultiKernelMultiDevice(hipLaunchParams *launchParamsList, + int numDevices, unsigned int flags, + hip_impl::program_state &ps); + +hipError_t hipLaunchCooperativeKernel(const void *f, dim3 gridDim, + dim3 blockDim, void **args, + size_t sharedMem, hipStream_t stream, + hip_impl::program_state &ps); + +hipError_t +hipLaunchCooperativeKernelMultiDevice(hipLaunchParams *launchParamsList, + int numDevices, unsigned int flags, + hip_impl::program_state &ps); + +#pragma GCC visibility push(hidden) + +namespace hip_impl { +template {}>::type * = nullptr> +inline T round_up_to_next_multiple_nonnegative(T x, T y) { + T tmp = x + y - 1; + return tmp - tmp % y; +} + +template ::type * = nullptr> +inline hip_impl::kernarg make_kernarg(const std::tuple &, + const kernargs_size_align &, + hip_impl::kernarg kernarg) { + return kernarg; +} + +template ::type * = nullptr> +inline hip_impl::kernarg make_kernarg(const std::tuple &formals, + const kernargs_size_align &size_align, + hip_impl::kernarg kernarg) { + using T = typename std::tuple_element>::type; + + static_assert(!std::is_reference{}, + "A __global__ function cannot have a reference as one of its " + "arguments."); +#if defined(HIP_STRICT) + static_assert(std::is_trivially_copyable{}, + "Only TriviallyCopyable types can be arguments to a __global__ " + "function"); +#endif + + kernarg.resize(round_up_to_next_multiple_nonnegative( + kernarg.size(), size_align.alignment(n)) + + size_align.size(n)); + + std::memcpy(kernarg.data() + kernarg.size() - size_align.size(n), + &std::get(formals), size_align.size(n)); + return make_kernarg(formals, size_align, std::move(kernarg)); +} + +template +inline hip_impl::kernarg make_kernarg(void (*kernel)(Formals...), + std::tuple actuals) { + static_assert( + sizeof...(Formals) == sizeof...(Actuals), + "The count of formal arguments must match the count of actuals."); + + if (sizeof...(Formals) == 0) + return {}; + + std::tuple to_formals{std::move(actuals)}; + hip_impl::kernarg kernarg; + kernarg.reserve(sizeof(to_formals)); + + auto &ps = hip_impl::get_program_state(); + return make_kernarg<0>( + to_formals, + ps.get_kernargs_size_align(reinterpret_cast(kernel)), + std::move(kernarg)); +} + +HIP_INTERNAL_EXPORTED_API hsa_agent_t target_agent(hipStream_t stream); + +inline __attribute__((visibility("hidden"))) void +hipLaunchKernelGGLImpl(std::uintptr_t function_address, const dim3 &numBlocks, + const dim3 &dimBlocks, std::uint32_t sharedMemBytes, + hipStream_t stream, void **kernarg) { + + const auto &kd = hip_impl::get_program_state().kernel_descriptor( + function_address, target_agent(stream)); + + hipModuleLaunchKernel(kd, numBlocks.x, numBlocks.y, numBlocks.z, dimBlocks.x, + dimBlocks.y, dimBlocks.z, sharedMemBytes, stream, + nullptr, kernarg); +} +} // Namespace hip_impl. + +template +inline hipError_t +hipOccupancyMaxPotentialBlockSize(int *gridSize, int *blockSize, T kernel, + size_t dynSharedMemPerBlk = 0, + int blockSizeLimit = 0) { + + using namespace hip_impl; + + hip_impl::hip_init(); + auto f = get_program_state().kernel_descriptor( + reinterpret_cast(kernel), target_agent(0)); + + return hipModuleOccupancyMaxPotentialBlockSize( + gridSize, blockSize, f, dynSharedMemPerBlk, blockSizeLimit); +} + +template +inline hipError_t hipOccupancyMaxPotentialBlockSizeWithFlags( + int *gridSize, int *blockSize, T kernel, size_t dynSharedMemPerBlk = 0, + int blockSizeLimit = 0, unsigned int flags = 0) { + + using namespace hip_impl; + + hip_impl::hip_init(); + if (flags != hipOccupancyDefault) + return hipErrorNotSupported; + auto f = get_program_state().kernel_descriptor( + reinterpret_cast(kernel), target_agent(0)); + + return hipModuleOccupancyMaxPotentialBlockSize( + gridSize, blockSize, f, dynSharedMemPerBlk, blockSizeLimit); +} + +template +inline void hipLaunchKernelGGL(F kernel, const dim3 &numBlocks, + const dim3 &dimBlocks, + std::uint32_t sharedMemBytes, hipStream_t stream, + Args... args) { + hip_impl::hip_init(); + auto kernarg = + hip_impl::make_kernarg(kernel, std::tuple{std::move(args)...}); + std::size_t kernarg_size = kernarg.size(); + + void *config[]{HIP_LAUNCH_PARAM_BUFFER_POINTER, kernarg.data(), + HIP_LAUNCH_PARAM_BUFFER_SIZE, &kernarg_size, + HIP_LAUNCH_PARAM_END}; + + hip_impl::hipLaunchKernelGGLImpl(reinterpret_cast(kernel), + numBlocks, dimBlocks, sharedMemBytes, stream, + &config[0]); +} + +template +inline __attribute__((visibility("hidden"))) hipError_t +hipLaunchCooperativeKernel(F f, dim3 gridDim, dim3 blockDim, void **args, + size_t sharedMem, hipStream_t stream) { + hip_impl::hip_init(); + auto &ps = hip_impl::get_program_state(); + return hipLaunchCooperativeKernel(reinterpret_cast(f), gridDim, + blockDim, args, sharedMem, stream, ps); +} + +inline __attribute__((visibility("hidden"))) hipError_t +hipLaunchCooperativeKernelMultiDevice(hipLaunchParams *launchParamsList, + int numDevices, unsigned int flags) { + + hip_impl::hip_init(); + auto &ps = hip_impl::get_program_state(); + return hipLaunchCooperativeKernelMultiDevice(launchParamsList, numDevices, + flags, ps); +} + +#pragma GCC visibility pop diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/grid_launch.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/grid_launch.h new file mode 100644 index 000000000..9e896485a --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/grid_launch.h @@ -0,0 +1,63 @@ +#pragma once + +#include + +#include + +#define GRID_LAUNCH_VERSION 20 + +// Extern definitions +namespace hc { +class completion_future; +class accelerator_view; +} // namespace hc + +// 3 dim structure for groups and grids. +typedef struct gl_dim3 { + int x, y, z; + gl_dim3(uint32_t _x = 1, uint32_t _y = 1, uint32_t _z = 1) + : x(_x), y(_y), z(_z) {}; +} gl_dim3; + +typedef enum gl_barrier_bit { + barrier_bit_queue_default, + barrier_bit_none, + barrier_bit_wait, +} gl_barrier_bit; + +// grid_launch_parm contains information used to launch the kernel. +typedef struct grid_launch_parm { + //! Grid dimensions + gl_dim3 grid_dim; + + //! Group dimensions + gl_dim3 group_dim; + + //! Amount of dynamic group memory to use with the kernel launch. + //! This memory is in addition to the amount used statically in the kernel. + unsigned int dynamic_group_mem_bytes; + + //! Control setting of barrier bit on per-packet basis: + //! See gl_barrier_bit description. + //! Placeholder, is not used to control packet dispatch yet + enum gl_barrier_bit barrier_bit; + + //! Value of packet fences to apply to launch. + //! The correspond to the value of bits 9:14 in the AQL packet, + //! see HSA_PACKET_HEADER_ACQUIRE_FENCE_SCOPE and hsa_fence_scope_t. + unsigned int launch_fence; + + //! Pointer to the accelerator_view where the kernel should execute. + //! If NULL, the default view on the default accelerator is used. + hc::accelerator_view *av; + + //! Pointer to the completion_future used to track the status of the command. + //! If NULL, the command does not write status. In this case, + //! synchronization can be enforced with queue-level waits or + //! waiting on younger commands. + hc::completion_future *cf; + + grid_launch_parm() = default; +} grid_launch_parm; + +extern void init_grid_launch(grid_launch_parm *gl); diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/grid_launch.hpp b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/grid_launch.hpp new file mode 100644 index 000000000..29e1f6838 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/grid_launch.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include "grid_launch.h" +#include "hc.hpp" + +class grid_launch_parm_cxx : public grid_launch_parm { +public: + grid_launch_parm_cxx() = default; + + // customized serialization: don't need av and cf in kernel + __attribute__((annotate("serialize"))) void + __cxxamp_serialize(Kalmar::Serialize &s) const { + s.Append(sizeof(int), &grid_dim.x); + s.Append(sizeof(int), &grid_dim.y); + s.Append(sizeof(int), &grid_dim.z); + s.Append(sizeof(int), &group_dim.x); + s.Append(sizeof(int), &group_dim.y); + s.Append(sizeof(int), &group_dim.z); + } + + __attribute__((annotate("user_deserialize"))) + grid_launch_parm_cxx(int grid_dim_x, int grid_dim_y, int grid_dim_z, + int group_dim_x, int group_dim_y, int group_dim_z) { + grid_dim.x = grid_dim_x; + grid_dim.y = grid_dim_y; + grid_dim.z = grid_dim_z; + group_dim.x = group_dim_x; + group_dim.y = group_dim_y; + group_dim.z = group_dim_z; + } +}; + +extern inline void grid_launch_init(grid_launch_parm *lp) { + lp->grid_dim.x = lp->grid_dim.y = lp->grid_dim.z = 1; + + lp->group_dim.x = lp->group_dim.y = lp->group_dim.z = 1; + + lp->dynamic_group_mem_bytes = 0; + + lp->barrier_bit = barrier_bit_queue_default; + lp->launch_fence = -1; + + // TODO - set to NULL? + static hc::accelerator_view av = hc::accelerator().get_default_view(); + lp->av = &av; + lp->cf = NULL; +} diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/grid_launch_GGL.hpp b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/grid_launch_GGL.hpp new file mode 100644 index 000000000..42188e907 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/grid_launch_GGL.hpp @@ -0,0 +1,26 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#pragma once + +#if GENERIC_GRID_LAUNCH == 1 +#include "macro_based_grid_launch.hpp" +#endif // GENERIC_GRID_LAUNCH diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/helpers.hpp b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/helpers.hpp new file mode 100644 index 000000000..c4c36a9cf --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/helpers.hpp @@ -0,0 +1,142 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once +#include "concepts.hpp" + +#include // For std::conditional, std::decay, std::enable_if, + // std::false_type, std result_of and std::true_type. +#include // For std::declval. + +#ifdef __has_include // Check if __has_include is present +#if __has_include() // Check for version header +#include +#if defined(__cpp_lib_is_invocable) && !defined(HIP_HAS_INVOCABLE) +#define HIP_HAS_INVOCABLE __cpp_lib_is_invocable +#endif +#if defined(__cpp_lib_result_of_sfinae) && !defined(HIP_HAS_RESULT_OF_SFINAE) +#define HIP_HAS_RESULT_OF_SFINAE __cpp_lib_result_of_sfinae +#endif +#endif +#endif + +#ifndef HIP_HAS_INVOCABLE +#define HIP_HAS_INVOCABLE 0 +#endif + +#ifndef HIP_HAS_RESULT_OF_SFINAE +#define HIP_HAS_RESULT_OF_SFINAE 0 +#endif + +namespace std { // TODO: these should be removed as soon as possible. +#if (__cplusplus < 201406L) +#if (__cplusplus < 201402L) +template +using enable_if_t = typename enable_if::type; +template +using conditional_t = typename conditional::type; +template using decay_t = typename decay::type; +template +using result_of_t = typename result_of::type; +template +using remove_reference_t = typename remove_reference::type; +#endif +#endif +} // namespace std + +namespace hip_impl { +template using void_t_ = void; + +#if HIP_HAS_INVOCABLE +template struct is_callable_impl; + +template +struct is_callable_impl : std::is_invocable {}; +#elif HIP_HAS_RESULT_OF_SFINAE +template +struct is_callable_impl : std::false_type {}; + +template +struct is_callable_impl::type>> + : std::true_type {}; +#else +template +auto simple_invoke(T Base::*pmd, Derived &&ref) + -> decltype(static_cast(ref).*pmd); + +template +auto simple_invoke(PMD &&pmd, Pointer &&ptr) + -> decltype((*static_cast(ptr)).*static_cast(pmd)); + +template +auto simple_invoke(T Base::*pmd, const std::reference_wrapper &ref) + -> decltype(ref.get().*pmd); + +template +auto simple_invoke(T Base::*pmf, Derived &&ref, Args &&...args) + -> decltype((static_cast(ref).* + pmf)(static_cast(args)...)); + +template +auto simple_invoke(PMF &&pmf, Pointer &&ptr, Args &&...args) + -> decltype(((*static_cast(ptr)).* + static_cast(pmf))(static_cast(args)...)); + +template +auto simple_invoke(T Base::*pmf, const std::reference_wrapper &ref, + Args &&...args) + -> decltype((ref.get().*pmf)(static_cast(args)...)); + +template +auto simple_invoke(F &&f, Ts &&...xs) -> decltype(f(static_cast(xs)...)); + +template +struct is_callable_impl : std::false_type {}; + +template +struct is_callable_impl< + F(Ts...), + void_t_(), std::declval()...))>> + : std::true_type {}; + +#endif + +template struct is_callable : is_callable_impl {}; + +#define count_macro_args_impl_hip_(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, \ + _10, _11, _12, _13, _14, _15, _16, _17, \ + _18, _19, _20, _21, _22, _23, _24, _25, \ + _26, _27, _28, _29, _30, _31, _n, ...) \ + _n +#define count_macro_args_hip_(...) \ + count_macro_args_impl_hip_(, ##__VA_ARGS__, 31, 30, 29, 28, 27, 26, 25, 24, \ + 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, \ + 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) + +#define overloaded_macro_expand_hip_(macro, arg_cnt) macro##arg_cnt +#define overload_macro_impl_hip_(macro, arg_cnt) \ + overloaded_macro_expand_hip_(macro, arg_cnt) +#define overload_macro_hip_(macro, ...) \ + overload_macro_impl_hip_(macro, \ + count_macro_args_hip_(__VA_ARGS__))(__VA_ARGS__) +} // namespace hip_impl diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_api_trace.hpp b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_api_trace.hpp new file mode 100644 index 000000000..e84c67b9c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_api_trace.hpp @@ -0,0 +1,1723 @@ +/* + Copyright (c) 2023 - 2024 Advanced Micro Devices, Inc. All rights reserved. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to + deal in the Software without restriction, including without limitation the + rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + sell copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + IN THE SOFTWARE. + */ +#pragma once + +#include + +// Define some version macros for the API table. Use similar naming conventions +// to HSA-runtime (MAJOR and STEP versions). Three groups at this time: +// +// (A) HIP_API_TABLE_* defines for versioning for API table structure +// (B) HIP_RUNTIME_API_TABLE_* defines for versioning the HipDispatchTable +// struct (C) HIP_COMPILER_API_TABLE_* defines for versioning the +// HipCompilerDispatchTable struct +// +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! IMPORTANT +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// +// 1. When new functions are added to the API table, always add the new +// function pointer to the +// end of the table and increment the dispatch table's step version +// number. NEVER re-arrange the order of the member variables in a +// dispatch table. This will break the ABI. +// 2. In dire circumstances, if the type of an existing member variable in a +// dispatch +// table has be changed because a data type has been changed/removed, +// increment the dispatch table's major version number. If the function +// pointer type can no longer be declared, DO NOT REMOVE IT! Make the +// function pointer type void* and have it always be set to a nullptr. +// +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// +// The major version number should (ideally) never need to be incremented. +// - Increment the HIP_API_TABLE_MAJOR_VERSION for fundamental changes to the +// API table structs. +// - Increment the HIP_RUNTIME_API_TABLE_MAJOR_VERSION for fundamental changes +// to the +// HipDispatchTable struct, such as a *change* to type/name an existing member +// variable. DO NOT REMOVE IT. +// - Increment the HIP_COMPILER_API_TABLE_MAJOR_VERSION for fundamental changes +// to the +// HipCompilerDispatchTable struct, such as a *change* to type/name an +// existing member variable. DO NOT REMOVE IT. +#define HIP_API_TABLE_MAJOR_VERSION 0 +#define HIP_COMPILER_API_TABLE_MAJOR_VERSION 0 +#define HIP_RUNTIME_API_TABLE_MAJOR_VERSION 0 + +// The step version number should be changed whenever the size of the API table +// struct(s) change. +// - Increment the HIP_API_TABLE_STEP_VERSION when/if new API table structs are +// added +// - Increment the HIP_RUNTIME_API_TABLE_STEP_VERSION when new runtime API +// functions are added +// - Increment the HIP_COMPILER_API_TABLE_STEP_VERSION when new compiler API +// functions are added +// - Reset any of the *_STEP_VERSION defines to zero if the corresponding +// *_MAJOR_VERSION increases +#define HIP_API_TABLE_STEP_VERSION 0 +#define HIP_COMPILER_API_TABLE_STEP_VERSION 0 +#define HIP_RUNTIME_API_TABLE_STEP_VERSION 3 + +// HIP API interface +typedef hipError_t (*t___hipPopCallConfiguration)(dim3 *gridDim, dim3 *blockDim, + size_t *sharedMem, + hipStream_t *stream); +typedef hipError_t (*t___hipPushCallConfiguration)(dim3 gridDim, dim3 blockDim, + size_t sharedMem, + hipStream_t stream); +typedef void **(*t___hipRegisterFatBinary)(const void *data); +typedef void (*t___hipRegisterFunction)( + void **modules, const void *hostFunction, char *deviceFunction, + const char *deviceName, unsigned int threadLimit, uint3 *tid, uint3 *bid, + dim3 *blockDim, dim3 *gridDim, int *wSize); +typedef void (*t___hipRegisterManagedVar)(void *hipModule, void **pointer, + void *init_value, const char *name, + size_t size, unsigned align); +typedef void (*t___hipRegisterSurface)(void **modules, void *var, char *hostVar, + char *deviceVar, int type, int ext); +typedef void (*t___hipRegisterTexture)(void **modules, void *var, char *hostVar, + char *deviceVar, int type, int norm, + int ext); +typedef void (*t___hipRegisterVar)(void **modules, void *var, char *hostVar, + char *deviceVar, int ext, size_t size, + int constant, int global); +typedef void (*t___hipUnregisterFatBinary)(void **modules); + +typedef const char *(*t_hipApiName)(uint32_t id); +typedef hipError_t (*t_hipArray3DCreate)( + hipArray_t *array, const HIP_ARRAY3D_DESCRIPTOR *pAllocateArray); +typedef hipError_t (*t_hipArray3DGetDescriptor)( + HIP_ARRAY3D_DESCRIPTOR *pArrayDescriptor, hipArray_t array); +typedef hipError_t (*t_hipArrayCreate)( + hipArray_t *pHandle, const HIP_ARRAY_DESCRIPTOR *pAllocateArray); +typedef hipError_t (*t_hipArrayDestroy)(hipArray_t array); +typedef hipError_t (*t_hipArrayGetDescriptor)( + HIP_ARRAY_DESCRIPTOR *pArrayDescriptor, hipArray_t array); +typedef hipError_t (*t_hipArrayGetInfo)(hipChannelFormatDesc *desc, + hipExtent *extent, unsigned int *flags, + hipArray_t array); +typedef hipError_t (*t_hipBindTexture)(size_t *offset, + const textureReference *tex, + const void *devPtr, + const hipChannelFormatDesc *desc, + size_t size); +typedef hipError_t (*t_hipBindTexture2D)(size_t *offset, + const textureReference *tex, + const void *devPtr, + const hipChannelFormatDesc *desc, + size_t width, size_t height, + size_t pitch); +typedef hipError_t (*t_hipBindTextureToArray)(const textureReference *tex, + hipArray_const_t array, + const hipChannelFormatDesc *desc); +typedef hipError_t (*t_hipBindTextureToMipmappedArray)( + const textureReference *tex, hipMipmappedArray_const_t mipmappedArray, + const hipChannelFormatDesc *desc); +typedef hipError_t (*t_hipChooseDevice)(int *device, + const hipDeviceProp_t *prop); +typedef hipError_t (*t_hipChooseDeviceR0000)( + int *device, const hipDeviceProp_tR0000 *properties); +typedef hipError_t (*t_hipConfigureCall)(dim3 gridDim, dim3 blockDim, + size_t sharedMem, hipStream_t stream); +typedef hipError_t (*t_hipCreateSurfaceObject)(hipSurfaceObject_t *pSurfObject, + const hipResourceDesc *pResDesc); +typedef hipError_t (*t_hipCreateTextureObject)( + hipTextureObject_t *pTexObject, const hipResourceDesc *pResDesc, + const hipTextureDesc *pTexDesc, + const struct hipResourceViewDesc *pResViewDesc); +typedef hipError_t (*t_hipCtxCreate)(hipCtx_t *ctx, unsigned int flags, + hipDevice_t device); +typedef hipError_t (*t_hipCtxDestroy)(hipCtx_t ctx); +typedef hipError_t (*t_hipCtxDisablePeerAccess)(hipCtx_t peerCtx); +typedef hipError_t (*t_hipCtxEnablePeerAccess)(hipCtx_t peerCtx, + unsigned int flags); +typedef hipError_t (*t_hipCtxGetApiVersion)(hipCtx_t ctx, int *apiVersion); +typedef hipError_t (*t_hipCtxGetCacheConfig)(hipFuncCache_t *cacheConfig); +typedef hipError_t (*t_hipCtxGetCurrent)(hipCtx_t *ctx); +typedef hipError_t (*t_hipCtxGetDevice)(hipDevice_t *device); +typedef hipError_t (*t_hipCtxGetFlags)(unsigned int *flags); +typedef hipError_t (*t_hipCtxGetSharedMemConfig)(hipSharedMemConfig *pConfig); +typedef hipError_t (*t_hipCtxPopCurrent)(hipCtx_t *ctx); +typedef hipError_t (*t_hipCtxPushCurrent)(hipCtx_t ctx); +typedef hipError_t (*t_hipCtxSetCacheConfig)(hipFuncCache_t cacheConfig); +typedef hipError_t (*t_hipCtxSetCurrent)(hipCtx_t ctx); +typedef hipError_t (*t_hipCtxSetSharedMemConfig)(hipSharedMemConfig config); +typedef hipError_t (*t_hipCtxSynchronize)(void); +typedef hipError_t (*t_hipDestroyExternalMemory)(hipExternalMemory_t extMem); +typedef hipError_t (*t_hipDestroyExternalSemaphore)( + hipExternalSemaphore_t extSem); +typedef hipError_t (*t_hipDestroySurfaceObject)( + hipSurfaceObject_t surfaceObject); +typedef hipError_t (*t_hipDestroyTextureObject)( + hipTextureObject_t textureObject); +typedef hipError_t (*t_hipDeviceCanAccessPeer)(int *canAccessPeer, int deviceId, + int peerDeviceId); +typedef hipError_t (*t_hipDeviceComputeCapability)(int *major, int *minor, + hipDevice_t device); +typedef hipError_t (*t_hipDeviceDisablePeerAccess)(int peerDeviceId); +typedef hipError_t (*t_hipDeviceEnablePeerAccess)(int peerDeviceId, + unsigned int flags); +typedef hipError_t (*t_hipDeviceGet)(hipDevice_t *device, int ordinal); +typedef hipError_t (*t_hipDeviceGetAttribute)(int *pi, + hipDeviceAttribute_t attr, + int deviceId); +typedef hipError_t (*t_hipDeviceGetByPCIBusId)(int *device, + const char *pciBusId); +typedef hipError_t (*t_hipDeviceGetCacheConfig)(hipFuncCache_t *cacheConfig); +typedef hipError_t (*t_hipDeviceGetDefaultMemPool)(hipMemPool_t *mem_pool, + int device); +typedef hipError_t (*t_hipDeviceGetGraphMemAttribute)( + int device, hipGraphMemAttributeType attr, void *value); +typedef hipError_t (*t_hipDeviceGetLimit)(size_t *pValue, + enum hipLimit_t limit); +typedef hipError_t (*t_hipDeviceGetMemPool)(hipMemPool_t *mem_pool, int device); +typedef hipError_t (*t_hipDeviceGetName)(char *name, int len, + hipDevice_t device); +typedef hipError_t (*t_hipDeviceGetP2PAttribute)(int *value, + hipDeviceP2PAttr attr, + int srcDevice, int dstDevice); +typedef hipError_t (*t_hipDeviceGetPCIBusId)(char *pciBusId, int len, + int device); +typedef hipError_t (*t_hipDeviceGetSharedMemConfig)( + hipSharedMemConfig *pConfig); +typedef hipError_t (*t_hipDeviceGetStreamPriorityRange)(int *leastPriority, + int *greatestPriority); +typedef hipError_t (*t_hipDeviceGetUuid)(hipUUID *uuid, hipDevice_t device); +typedef hipError_t (*t_hipDeviceGraphMemTrim)(int device); +typedef hipError_t (*t_hipDevicePrimaryCtxGetState)(hipDevice_t dev, + unsigned int *flags, + int *active); +typedef hipError_t (*t_hipDevicePrimaryCtxRelease)(hipDevice_t dev); +typedef hipError_t (*t_hipDevicePrimaryCtxReset)(hipDevice_t dev); +typedef hipError_t (*t_hipDevicePrimaryCtxRetain)(hipCtx_t *pctx, + hipDevice_t dev); +typedef hipError_t (*t_hipDevicePrimaryCtxSetFlags)(hipDevice_t dev, + unsigned int flags); +typedef hipError_t (*t_hipDeviceReset)(void); +typedef hipError_t (*t_hipDeviceSetCacheConfig)(hipFuncCache_t cacheConfig); +typedef hipError_t (*t_hipDeviceSetGraphMemAttribute)( + int device, hipGraphMemAttributeType attr, void *value); +typedef hipError_t (*t_hipDeviceSetLimit)(enum hipLimit_t limit, size_t value); +typedef hipError_t (*t_hipDeviceSetMemPool)(int device, hipMemPool_t mem_pool); +typedef hipError_t (*t_hipDeviceSetSharedMemConfig)(hipSharedMemConfig config); +typedef hipError_t (*t_hipDeviceSynchronize)(void); +typedef hipError_t (*t_hipDeviceTotalMem)(size_t *bytes, hipDevice_t device); +typedef hipError_t (*t_hipDriverGetVersion)(int *driverVersion); +typedef hipError_t (*t_hipDrvGetErrorName)(hipError_t hipError, + const char **errorString); +typedef hipError_t (*t_hipDrvGetErrorString)(hipError_t hipError, + const char **errorString); +typedef hipError_t (*t_hipDrvGraphAddMemcpyNode)( + hipGraphNode_t *phGraphNode, hipGraph_t hGraph, + const hipGraphNode_t *dependencies, size_t numDependencies, + const HIP_MEMCPY3D *copyParams, hipCtx_t ctx); +typedef hipError_t (*t_hipDrvMemcpy2DUnaligned)(const hip_Memcpy2D *pCopy); +typedef hipError_t (*t_hipDrvMemcpy3D)(const HIP_MEMCPY3D *pCopy); +typedef hipError_t (*t_hipDrvMemcpy3DAsync)(const HIP_MEMCPY3D *pCopy, + hipStream_t stream); +typedef hipError_t (*t_hipDrvPointerGetAttributes)( + unsigned int numAttributes, hipPointer_attribute *attributes, void **data, + hipDeviceptr_t ptr); +typedef hipError_t (*t_hipEventCreate)(hipEvent_t *event); +typedef hipError_t (*t_hipEventCreateWithFlags)(hipEvent_t *event, + unsigned flags); +typedef hipError_t (*t_hipEventDestroy)(hipEvent_t event); +typedef hipError_t (*t_hipEventElapsedTime)(float *ms, hipEvent_t start, + hipEvent_t stop); +typedef hipError_t (*t_hipEventQuery)(hipEvent_t event); +typedef hipError_t (*t_hipEventRecord)(hipEvent_t event, hipStream_t stream); +typedef hipError_t (*t_hipEventSynchronize)(hipEvent_t event); +typedef hipError_t (*t_hipExtGetLinkTypeAndHopCount)(int device1, int device2, + uint32_t *linktype, + uint32_t *hopcount); +typedef hipError_t (*t_hipExtLaunchKernel)(const void *function_address, + dim3 numBlocks, dim3 dimBlocks, + void **args, size_t sharedMemBytes, + hipStream_t stream, + hipEvent_t startEvent, + hipEvent_t stopEvent, int flags); +typedef hipError_t (*t_hipExtLaunchMultiKernelMultiDevice)( + hipLaunchParams *launchParamsList, int numDevices, unsigned int flags); +typedef hipError_t (*t_hipExtMallocWithFlags)(void **ptr, size_t sizeBytes, + unsigned int flags); +typedef hipError_t (*t_hipExtStreamCreateWithCUMask)(hipStream_t *stream, + uint32_t cuMaskSize, + const uint32_t *cuMask); +typedef hipError_t (*t_hipExtStreamGetCUMask)(hipStream_t stream, + uint32_t cuMaskSize, + uint32_t *cuMask); +typedef hipError_t (*t_hipExternalMemoryGetMappedBuffer)( + void **devPtr, hipExternalMemory_t extMem, + const hipExternalMemoryBufferDesc *bufferDesc); +typedef hipError_t (*t_hipFree)(void *ptr); +typedef hipError_t (*t_hipFreeArray)(hipArray_t array); +typedef hipError_t (*t_hipFreeAsync)(void *dev_ptr, hipStream_t stream); +typedef hipError_t (*t_hipFreeHost)(void *ptr); +typedef hipError_t (*t_hipFreeMipmappedArray)( + hipMipmappedArray_t mipmappedArray); +typedef hipError_t (*t_hipFuncGetAttribute)(int *value, + hipFunction_attribute attrib, + hipFunction_t hfunc); +typedef hipError_t (*t_hipFuncGetAttributes)(struct hipFuncAttributes *attr, + const void *func); +typedef hipError_t (*t_hipFuncSetAttribute)(const void *func, + hipFuncAttribute attr, int value); +typedef hipError_t (*t_hipFuncSetCacheConfig)(const void *func, + hipFuncCache_t config); +typedef hipError_t (*t_hipFuncSetSharedMemConfig)(const void *func, + hipSharedMemConfig config); +typedef hipError_t (*t_hipGLGetDevices)(unsigned int *pHipDeviceCount, + int *pHipDevices, + unsigned int hipDeviceCount, + hipGLDeviceList deviceList); +typedef hipError_t (*t_hipGetChannelDesc)(hipChannelFormatDesc *desc, + hipArray_const_t array); +typedef hipError_t (*t_hipGetDevice)(int *deviceId); +typedef hipError_t (*t_hipGetDeviceCount)(int *count); +typedef hipError_t (*t_hipGetDeviceFlags)(unsigned int *flags); +typedef hipError_t (*t_hipGetDevicePropertiesR0600)(hipDeviceProp_tR0600 *prop, + int device); +typedef hipError_t (*t_hipGetDevicePropertiesR0000)(hipDeviceProp_tR0000 *prop, + int device); +typedef const char *(*t_hipGetErrorName)(hipError_t hip_error); +typedef const char *(*t_hipGetErrorString)(hipError_t hipError); +typedef hipError_t (*t_hipGetLastError)(void); +typedef hipError_t (*t_hipGetMipmappedArrayLevel)( + hipArray_t *levelArray, hipMipmappedArray_const_t mipmappedArray, + unsigned int level); +typedef hipError_t (*t_hipGetSymbolAddress)(void **devPtr, const void *symbol); +typedef hipError_t (*t_hipGetSymbolSize)(size_t *size, const void *symbol); +typedef hipError_t (*t_hipGetTextureAlignmentOffset)( + size_t *offset, const textureReference *texref); +typedef hipError_t (*t_hipGetTextureObjectResourceDesc)( + hipResourceDesc *pResDesc, hipTextureObject_t textureObject); +typedef hipError_t (*t_hipGetTextureObjectResourceViewDesc)( + struct hipResourceViewDesc *pResViewDesc, hipTextureObject_t textureObject); +typedef hipError_t (*t_hipGetTextureObjectTextureDesc)( + hipTextureDesc *pTexDesc, hipTextureObject_t textureObject); +typedef hipError_t (*t_hipGetTextureReference)(const textureReference **texref, + const void *symbol); +typedef hipError_t (*t_hipGraphAddChildGraphNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + hipGraph_t childGraph); +typedef hipError_t (*t_hipGraphAddDependencies)(hipGraph_t graph, + const hipGraphNode_t *from, + const hipGraphNode_t *to, + size_t numDependencies); +typedef hipError_t (*t_hipGraphAddEmptyNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies); +typedef hipError_t (*t_hipGraphAddEventRecordNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + hipEvent_t event); +typedef hipError_t (*t_hipGraphAddEventWaitNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + hipEvent_t event); +typedef hipError_t (*t_hipGraphAddHostNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + const hipHostNodeParams *pNodeParams); +typedef hipError_t (*t_hipGraphAddKernelNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + const hipKernelNodeParams *pNodeParams); +typedef hipError_t (*t_hipGraphAddMemAllocNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + hipMemAllocNodeParams *pNodeParams); +typedef hipError_t (*t_hipGraphAddMemFreeNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, void *dev_ptr); +typedef hipError_t (*t_hipGraphAddMemcpyNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + const hipMemcpy3DParms *pCopyParams); +typedef hipError_t (*t_hipGraphAddMemcpyNode1D)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, void *dst, + const void *src, size_t count, hipMemcpyKind kind); +typedef hipError_t (*t_hipGraphAddMemcpyNodeFromSymbol)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, void *dst, + const void *symbol, size_t count, size_t offset, hipMemcpyKind kind); +typedef hipError_t (*t_hipGraphAddMemcpyNodeToSymbol)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + const void *symbol, const void *src, size_t count, size_t offset, + hipMemcpyKind kind); +typedef hipError_t (*t_hipGraphAddMemsetNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + const hipMemsetParams *pMemsetParams); + +typedef hipError_t (*t_hipGraphChildGraphNodeGetGraph)(hipGraphNode_t node, + hipGraph_t *pGraph); +typedef hipError_t (*t_hipGraphClone)(hipGraph_t *pGraphClone, + hipGraph_t originalGraph); +typedef hipError_t (*t_hipGraphCreate)(hipGraph_t *pGraph, unsigned int flags); +typedef hipError_t (*t_hipGraphDebugDotPrint)(hipGraph_t graph, + const char *path, + unsigned int flags); +typedef hipError_t (*t_hipGraphDestroy)(hipGraph_t graph); +typedef hipError_t (*t_hipGraphDestroyNode)(hipGraphNode_t node); +typedef hipError_t (*t_hipGraphEventRecordNodeGetEvent)(hipGraphNode_t node, + hipEvent_t *event_out); +typedef hipError_t (*t_hipGraphEventRecordNodeSetEvent)(hipGraphNode_t node, + hipEvent_t event); +typedef hipError_t (*t_hipGraphEventWaitNodeGetEvent)(hipGraphNode_t node, + hipEvent_t *event_out); +typedef hipError_t (*t_hipGraphEventWaitNodeSetEvent)(hipGraphNode_t node, + hipEvent_t event); +typedef hipError_t (*t_hipGraphExecChildGraphNodeSetParams)( + hipGraphExec_t hGraphExec, hipGraphNode_t node, hipGraph_t childGraph); +typedef hipError_t (*t_hipGraphExecDestroy)(hipGraphExec_t graphExec); +typedef hipError_t (*t_hipGraphExecEventRecordNodeSetEvent)( + hipGraphExec_t hGraphExec, hipGraphNode_t hNode, hipEvent_t event); +typedef hipError_t (*t_hipGraphExecEventWaitNodeSetEvent)( + hipGraphExec_t hGraphExec, hipGraphNode_t hNode, hipEvent_t event); +typedef hipError_t (*t_hipGraphExecHostNodeSetParams)( + hipGraphExec_t hGraphExec, hipGraphNode_t node, + const hipHostNodeParams *pNodeParams); +typedef hipError_t (*t_hipGraphExecKernelNodeSetParams)( + hipGraphExec_t hGraphExec, hipGraphNode_t node, + const hipKernelNodeParams *pNodeParams); +typedef hipError_t (*t_hipGraphExecMemcpyNodeSetParams)( + hipGraphExec_t hGraphExec, hipGraphNode_t node, + hipMemcpy3DParms *pNodeParams); +typedef hipError_t (*t_hipGraphExecMemcpyNodeSetParams1D)( + hipGraphExec_t hGraphExec, hipGraphNode_t node, void *dst, const void *src, + size_t count, hipMemcpyKind kind); +typedef hipError_t (*t_hipGraphExecMemcpyNodeSetParamsFromSymbol)( + hipGraphExec_t hGraphExec, hipGraphNode_t node, void *dst, + const void *symbol, size_t count, size_t offset, hipMemcpyKind kind); +typedef hipError_t (*t_hipGraphExecMemcpyNodeSetParamsToSymbol)( + hipGraphExec_t hGraphExec, hipGraphNode_t node, const void *symbol, + const void *src, size_t count, size_t offset, hipMemcpyKind kind); +typedef hipError_t (*t_hipGraphExecMemsetNodeSetParams)( + hipGraphExec_t hGraphExec, hipGraphNode_t node, + const hipMemsetParams *pNodeParams); +typedef hipError_t (*t_hipGraphExecUpdate)( + hipGraphExec_t hGraphExec, hipGraph_t hGraph, + hipGraphNode_t *hErrorNode_out, hipGraphExecUpdateResult *updateResult_out); +typedef hipError_t (*t_hipGraphGetEdges)(hipGraph_t graph, hipGraphNode_t *from, + hipGraphNode_t *to, size_t *numEdges); +typedef hipError_t (*t_hipGraphGetNodes)(hipGraph_t graph, + hipGraphNode_t *nodes, + size_t *numNodes); +typedef hipError_t (*t_hipGraphGetRootNodes)(hipGraph_t graph, + hipGraphNode_t *pRootNodes, + size_t *pNumRootNodes); +typedef hipError_t (*t_hipGraphHostNodeGetParams)( + hipGraphNode_t node, hipHostNodeParams *pNodeParams); +typedef hipError_t (*t_hipGraphHostNodeSetParams)( + hipGraphNode_t node, const hipHostNodeParams *pNodeParams); +typedef hipError_t (*t_hipGraphInstantiate)(hipGraphExec_t *pGraphExec, + hipGraph_t graph, + hipGraphNode_t *pErrorNode, + char *pLogBuffer, + size_t bufferSize); +typedef hipError_t (*t_hipGraphInstantiateWithFlags)(hipGraphExec_t *pGraphExec, + hipGraph_t graph, + unsigned long long flags); +typedef hipError_t (*t_hipGraphKernelNodeCopyAttributes)(hipGraphNode_t hSrc, + hipGraphNode_t hDst); +typedef hipError_t (*t_hipGraphKernelNodeGetAttribute)( + hipGraphNode_t hNode, hipKernelNodeAttrID attr, + hipKernelNodeAttrValue *value); +typedef hipError_t (*t_hipGraphKernelNodeGetParams)( + hipGraphNode_t node, hipKernelNodeParams *pNodeParams); +typedef hipError_t (*t_hipGraphKernelNodeSetAttribute)( + hipGraphNode_t hNode, hipKernelNodeAttrID attr, + const hipKernelNodeAttrValue *value); +typedef hipError_t (*t_hipGraphKernelNodeSetParams)( + hipGraphNode_t node, const hipKernelNodeParams *pNodeParams); +typedef hipError_t (*t_hipGraphLaunch)(hipGraphExec_t graphExec, + hipStream_t stream); +typedef hipError_t (*t_hipGraphMemAllocNodeGetParams)( + hipGraphNode_t node, hipMemAllocNodeParams *pNodeParams); +typedef hipError_t (*t_hipGraphMemFreeNodeGetParams)(hipGraphNode_t node, + void *dev_ptr); +typedef hipError_t (*t_hipGraphMemcpyNodeGetParams)( + hipGraphNode_t node, hipMemcpy3DParms *pNodeParams); +typedef hipError_t (*t_hipGraphMemcpyNodeSetParams)( + hipGraphNode_t node, const hipMemcpy3DParms *pNodeParams); +typedef hipError_t (*t_hipGraphMemcpyNodeSetParams1D)(hipGraphNode_t node, + void *dst, + const void *src, + size_t count, + hipMemcpyKind kind); +typedef hipError_t (*t_hipGraphMemcpyNodeSetParamsFromSymbol)( + hipGraphNode_t node, void *dst, const void *symbol, size_t count, + size_t offset, hipMemcpyKind kind); +typedef hipError_t (*t_hipGraphMemcpyNodeSetParamsToSymbol)( + hipGraphNode_t node, const void *symbol, const void *src, size_t count, + size_t offset, hipMemcpyKind kind); +typedef hipError_t (*t_hipGraphMemsetNodeGetParams)( + hipGraphNode_t node, hipMemsetParams *pNodeParams); +typedef hipError_t (*t_hipGraphMemsetNodeSetParams)( + hipGraphNode_t node, const hipMemsetParams *pNodeParams); +typedef hipError_t (*t_hipGraphNodeFindInClone)(hipGraphNode_t *pNode, + hipGraphNode_t originalNode, + hipGraph_t clonedGraph); +typedef hipError_t (*t_hipGraphNodeGetDependencies)( + hipGraphNode_t node, hipGraphNode_t *pDependencies, + size_t *pNumDependencies); +typedef hipError_t (*t_hipGraphNodeGetDependentNodes)( + hipGraphNode_t node, hipGraphNode_t *pDependentNodes, + size_t *pNumDependentNodes); +typedef hipError_t (*t_hipGraphNodeGetEnabled)(hipGraphExec_t hGraphExec, + hipGraphNode_t hNode, + unsigned int *isEnabled); +typedef hipError_t (*t_hipGraphNodeGetType)(hipGraphNode_t node, + hipGraphNodeType *pType); +typedef hipError_t (*t_hipGraphNodeSetEnabled)(hipGraphExec_t hGraphExec, + hipGraphNode_t hNode, + unsigned int isEnabled); +typedef hipError_t (*t_hipGraphReleaseUserObject)(hipGraph_t graph, + hipUserObject_t object, + unsigned int count); +typedef hipError_t (*t_hipGraphRemoveDependencies)(hipGraph_t graph, + const hipGraphNode_t *from, + const hipGraphNode_t *to, + size_t numDependencies); +typedef hipError_t (*t_hipGraphRetainUserObject)(hipGraph_t graph, + hipUserObject_t object, + unsigned int count, + unsigned int flags); +typedef hipError_t (*t_hipGraphUpload)(hipGraphExec_t graphExec, + hipStream_t stream); +typedef hipError_t (*t_hipGraphicsGLRegisterBuffer)( + hipGraphicsResource **resource, GLuint buffer, unsigned int flags); +typedef hipError_t (*t_hipGraphicsGLRegisterImage)( + hipGraphicsResource **resource, GLuint image, GLenum target, + unsigned int flags); +typedef hipError_t (*t_hipGraphicsMapResources)( + int count, hipGraphicsResource_t *resources, hipStream_t stream); +typedef hipError_t (*t_hipGraphicsResourceGetMappedPointer)( + void **devPtr, size_t *size, hipGraphicsResource_t resource); +typedef hipError_t (*t_hipGraphicsSubResourceGetMappedArray)( + hipArray_t *array, hipGraphicsResource_t resource, unsigned int arrayIndex, + unsigned int mipLevel); +typedef hipError_t (*t_hipGraphicsUnmapResources)( + int count, hipGraphicsResource_t *resources, hipStream_t stream); +typedef hipError_t (*t_hipGraphicsUnregisterResource)( + hipGraphicsResource_t resource); +typedef hipError_t (*t_hipHostAlloc)(void **ptr, size_t size, + unsigned int flags); +typedef hipError_t (*t_hipHostFree)(void *ptr); +typedef hipError_t (*t_hipHostGetDevicePointer)(void **devPtr, void *hstPtr, + unsigned int flags); +typedef hipError_t (*t_hipHostGetFlags)(unsigned int *flagsPtr, void *hostPtr); +typedef hipError_t (*t_hipHostMalloc)(void **ptr, size_t size, + unsigned int flags); +typedef hipError_t (*t_hipHostRegister)(void *hostPtr, size_t sizeBytes, + unsigned int flags); +typedef hipError_t (*t_hipHostUnregister)(void *hostPtr); +typedef hipError_t (*t_hipImportExternalMemory)( + hipExternalMemory_t *extMem_out, + const hipExternalMemoryHandleDesc *memHandleDesc); +typedef hipError_t (*t_hipImportExternalSemaphore)( + hipExternalSemaphore_t *extSem_out, + const hipExternalSemaphoreHandleDesc *semHandleDesc); +typedef hipError_t (*t_hipInit)(unsigned int flags); +typedef hipError_t (*t_hipIpcCloseMemHandle)(void *devPtr); +typedef hipError_t (*t_hipIpcGetEventHandle)(hipIpcEventHandle_t *handle, + hipEvent_t event); +typedef hipError_t (*t_hipIpcGetMemHandle)(hipIpcMemHandle_t *handle, + void *devPtr); +typedef hipError_t (*t_hipIpcOpenEventHandle)(hipEvent_t *event, + hipIpcEventHandle_t handle); +typedef hipError_t (*t_hipIpcOpenMemHandle)(void **devPtr, + hipIpcMemHandle_t handle, + unsigned int flags); +typedef const char *(*t_hipKernelNameRef)(const hipFunction_t f); +typedef const char *(*t_hipKernelNameRefByPtr)(const void *hostFunction, + hipStream_t stream); +typedef hipError_t (*t_hipLaunchByPtr)(const void *func); +typedef hipError_t (*t_hipLaunchCooperativeKernel)(const void *f, dim3 gridDim, + dim3 blockDimX, + void **kernelParams, + unsigned int sharedMemBytes, + hipStream_t stream); +typedef hipError_t (*t_hipLaunchCooperativeKernelMultiDevice)( + hipLaunchParams *launchParamsList, int numDevices, unsigned int flags); +typedef hipError_t (*t_hipLaunchHostFunc)(hipStream_t stream, hipHostFn_t fn, + void *userData); +typedef hipError_t (*t_hipLaunchKernel)(const void *function_address, + dim3 numBlocks, dim3 dimBlocks, + void **args, size_t sharedMemBytes, + hipStream_t stream); +typedef hipError_t (*t_hipMalloc)(void **ptr, size_t size); +typedef hipError_t (*t_hipMalloc3D)(hipPitchedPtr *pitchedDevPtr, + hipExtent extent); +typedef hipError_t (*t_hipMalloc3DArray)( + hipArray_t *array, const struct hipChannelFormatDesc *desc, + struct hipExtent extent, unsigned int flags); +typedef hipError_t (*t_hipMallocArray)(hipArray_t *array, + const hipChannelFormatDesc *desc, + size_t width, size_t height, + unsigned int flags); +typedef hipError_t (*t_hipMallocAsync)(void **dev_ptr, size_t size, + hipStream_t stream); +typedef hipError_t (*t_hipMallocFromPoolAsync)(void **dev_ptr, size_t size, + hipMemPool_t mem_pool, + hipStream_t stream); +typedef hipError_t (*t_hipMallocHost)(void **ptr, size_t size); +typedef hipError_t (*t_hipMallocManaged)(void **dev_ptr, size_t size, + unsigned int flags); +typedef hipError_t (*t_hipMallocMipmappedArray)( + hipMipmappedArray_t *mipmappedArray, + const struct hipChannelFormatDesc *desc, struct hipExtent extent, + unsigned int numLevels, unsigned int flags); +typedef hipError_t (*t_hipMallocPitch)(void **ptr, size_t *pitch, size_t width, + size_t height); +typedef hipError_t (*t_hipMemAddressFree)(void *devPtr, size_t size); +typedef hipError_t (*t_hipMemAddressReserve)(void **ptr, size_t size, + size_t alignment, void *addr, + unsigned long long flags); +typedef hipError_t (*t_hipMemAdvise)(const void *dev_ptr, size_t count, + hipMemoryAdvise advice, int device); +typedef hipError_t (*t_hipMemAllocHost)(void **ptr, size_t size); +typedef hipError_t (*t_hipMemAllocPitch)(hipDeviceptr_t *dptr, size_t *pitch, + size_t widthInBytes, size_t height, + unsigned int elementSizeBytes); +typedef hipError_t (*t_hipMemCreate)(hipMemGenericAllocationHandle_t *handle, + size_t size, + const hipMemAllocationProp *prop, + unsigned long long flags); +typedef hipError_t (*t_hipMemExportToShareableHandle)( + void *shareableHandle, hipMemGenericAllocationHandle_t handle, + hipMemAllocationHandleType handleType, unsigned long long flags); +typedef hipError_t (*t_hipMemGetAccess)(unsigned long long *flags, + const hipMemLocation *location, + void *ptr); +typedef hipError_t (*t_hipMemGetAddressRange)(hipDeviceptr_t *pbase, + size_t *psize, + hipDeviceptr_t dptr); +typedef hipError_t (*t_hipMemGetAllocationGranularity)( + size_t *granularity, const hipMemAllocationProp *prop, + hipMemAllocationGranularity_flags option); +typedef hipError_t (*t_hipMemGetAllocationPropertiesFromHandle)( + hipMemAllocationProp *prop, hipMemGenericAllocationHandle_t handle); +typedef hipError_t (*t_hipMemGetInfo)(size_t *free, size_t *total); +typedef hipError_t (*t_hipMemImportFromShareableHandle)( + hipMemGenericAllocationHandle_t *handle, void *osHandle, + hipMemAllocationHandleType shHandleType); +typedef hipError_t (*t_hipMemMap)(void *ptr, size_t size, size_t offset, + hipMemGenericAllocationHandle_t handle, + unsigned long long flags); +typedef hipError_t (*t_hipMemMapArrayAsync)(hipArrayMapInfo *mapInfoList, + unsigned int count, + hipStream_t stream); +typedef hipError_t (*t_hipMemPoolCreate)(hipMemPool_t *mem_pool, + const hipMemPoolProps *pool_props); +typedef hipError_t (*t_hipMemPoolDestroy)(hipMemPool_t mem_pool); +typedef hipError_t (*t_hipMemPoolExportPointer)( + hipMemPoolPtrExportData *export_data, void *dev_ptr); +typedef hipError_t (*t_hipMemPoolExportToShareableHandle)( + void *shared_handle, hipMemPool_t mem_pool, + hipMemAllocationHandleType handle_type, unsigned int flags); +typedef hipError_t (*t_hipMemPoolGetAccess)(hipMemAccessFlags *flags, + hipMemPool_t mem_pool, + hipMemLocation *location); +typedef hipError_t (*t_hipMemPoolGetAttribute)(hipMemPool_t mem_pool, + hipMemPoolAttr attr, + void *value); +typedef hipError_t (*t_hipMemPoolImportFromShareableHandle)( + hipMemPool_t *mem_pool, void *shared_handle, + hipMemAllocationHandleType handle_type, unsigned int flags); +typedef hipError_t (*t_hipMemPoolImportPointer)( + void **dev_ptr, hipMemPool_t mem_pool, + hipMemPoolPtrExportData *export_data); +typedef hipError_t (*t_hipMemPoolSetAccess)(hipMemPool_t mem_pool, + const hipMemAccessDesc *desc_list, + size_t count); +typedef hipError_t (*t_hipMemPoolSetAttribute)(hipMemPool_t mem_pool, + hipMemPoolAttr attr, + void *value); +typedef hipError_t (*t_hipMemPoolTrimTo)(hipMemPool_t mem_pool, + size_t min_bytes_to_hold); +typedef hipError_t (*t_hipMemPrefetchAsync)(const void *dev_ptr, size_t count, + int device, hipStream_t stream); +typedef hipError_t (*t_hipMemPtrGetInfo)(void *ptr, size_t *size); +typedef hipError_t (*t_hipMemRangeGetAttribute)(void *data, size_t data_size, + hipMemRangeAttribute attribute, + const void *dev_ptr, + size_t count); +typedef hipError_t (*t_hipMemRangeGetAttributes)( + void **data, size_t *data_sizes, hipMemRangeAttribute *attributes, + size_t num_attributes, const void *dev_ptr, size_t count); +typedef hipError_t (*t_hipMemRelease)(hipMemGenericAllocationHandle_t handle); +typedef hipError_t (*t_hipMemRetainAllocationHandle)( + hipMemGenericAllocationHandle_t *handle, void *addr); +typedef hipError_t (*t_hipMemSetAccess)(void *ptr, size_t size, + const hipMemAccessDesc *desc, + size_t count); +typedef hipError_t (*t_hipMemUnmap)(void *ptr, size_t size); +typedef hipError_t (*t_hipMemcpy)(void *dst, const void *src, size_t sizeBytes, + hipMemcpyKind kind); +typedef hipError_t (*t_hipMemcpy2D)(void *dst, size_t dpitch, const void *src, + size_t spitch, size_t width, size_t height, + hipMemcpyKind kind); +typedef hipError_t (*t_hipMemcpy2DAsync)(void *dst, size_t dpitch, + const void *src, size_t spitch, + size_t width, size_t height, + hipMemcpyKind kind, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpy2DFromArray)(void *dst, size_t dpitch, + hipArray_const_t src, + size_t wOffset, size_t hOffset, + size_t width, size_t height, + hipMemcpyKind kind); +typedef hipError_t (*t_hipMemcpy2DFromArrayAsync)( + void *dst, size_t dpitch, hipArray_const_t src, size_t wOffset, + size_t hOffset, size_t width, size_t height, hipMemcpyKind kind, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpy2DToArray)(hipArray_t dst, size_t wOffset, + size_t hOffset, const void *src, + size_t spitch, size_t width, + size_t height, hipMemcpyKind kind); +typedef hipError_t (*t_hipMemcpy2DToArrayAsync)(hipArray_t dst, size_t wOffset, + size_t hOffset, const void *src, + size_t spitch, size_t width, + size_t height, + hipMemcpyKind kind, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpy3D)(const struct hipMemcpy3DParms *p); +typedef hipError_t (*t_hipMemcpy3DAsync)(const struct hipMemcpy3DParms *p, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpyAsync)(void *dst, const void *src, + size_t sizeBytes, hipMemcpyKind kind, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpyAtoH)(void *dst, hipArray_t srcArray, + size_t srcOffset, size_t count); +typedef hipError_t (*t_hipMemcpyDtoD)(hipDeviceptr_t dst, hipDeviceptr_t src, + size_t sizeBytes); +typedef hipError_t (*t_hipMemcpyDtoDAsync)(hipDeviceptr_t dst, + hipDeviceptr_t src, size_t sizeBytes, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpyDtoH)(void *dst, hipDeviceptr_t src, + size_t sizeBytes); +typedef hipError_t (*t_hipMemcpyDtoHAsync)(void *dst, hipDeviceptr_t src, + size_t sizeBytes, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpyFromArray)(void *dst, hipArray_const_t srcArray, + size_t wOffset, size_t hOffset, + size_t count, hipMemcpyKind kind); +typedef hipError_t (*t_hipMemcpyFromSymbol)(void *dst, const void *symbol, + size_t sizeBytes, size_t offset, + hipMemcpyKind kind); +typedef hipError_t (*t_hipMemcpyFromSymbolAsync)(void *dst, const void *symbol, + size_t sizeBytes, + size_t offset, + hipMemcpyKind kind, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpyHtoA)(hipArray_t dstArray, size_t dstOffset, + const void *srcHost, size_t count); +typedef hipError_t (*t_hipMemcpyHtoD)(hipDeviceptr_t dst, void *src, + size_t sizeBytes); +typedef hipError_t (*t_hipMemcpyHtoDAsync)(hipDeviceptr_t dst, void *src, + size_t sizeBytes, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpyParam2D)(const hip_Memcpy2D *pCopy); +typedef hipError_t (*t_hipMemcpyParam2DAsync)(const hip_Memcpy2D *pCopy, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpyPeer)(void *dst, int dstDeviceId, + const void *src, int srcDeviceId, + size_t sizeBytes); +typedef hipError_t (*t_hipMemcpyPeerAsync)(void *dst, int dstDeviceId, + const void *src, int srcDevice, + size_t sizeBytes, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpyToArray)(hipArray_t dst, size_t wOffset, + size_t hOffset, const void *src, + size_t count, hipMemcpyKind kind); +typedef hipError_t (*t_hipMemcpyToSymbol)(const void *symbol, const void *src, + size_t sizeBytes, size_t offset, + hipMemcpyKind kind); +typedef hipError_t (*t_hipMemcpyToSymbolAsync)(const void *symbol, + const void *src, + size_t sizeBytes, size_t offset, + hipMemcpyKind kind, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpyWithStream)(void *dst, const void *src, + size_t sizeBytes, + hipMemcpyKind kind, + hipStream_t stream); +typedef hipError_t (*t_hipMemset)(void *dst, int value, size_t sizeBytes); +typedef hipError_t (*t_hipMemset2D)(void *dst, size_t pitch, int value, + size_t width, size_t height); +typedef hipError_t (*t_hipMemset2DAsync)(void *dst, size_t pitch, int value, + size_t width, size_t height, + hipStream_t stream); +typedef hipError_t (*t_hipMemset3D)(hipPitchedPtr pitchedDevPtr, int value, + hipExtent extent); +typedef hipError_t (*t_hipMemset3DAsync)(hipPitchedPtr pitchedDevPtr, int value, + hipExtent extent, hipStream_t stream); +typedef hipError_t (*t_hipMemsetAsync)(void *dst, int value, size_t sizeBytes, + hipStream_t stream); +typedef hipError_t (*t_hipMemsetD16)(hipDeviceptr_t dest, unsigned short value, + size_t count); +typedef hipError_t (*t_hipMemsetD16Async)(hipDeviceptr_t dest, + unsigned short value, size_t count, + hipStream_t stream); +typedef hipError_t (*t_hipMemsetD32)(hipDeviceptr_t dest, int value, + size_t count); +typedef hipError_t (*t_hipMemsetD32Async)(hipDeviceptr_t dst, int value, + size_t count, hipStream_t stream); +typedef hipError_t (*t_hipMemsetD8)(hipDeviceptr_t dest, unsigned char value, + size_t count); +typedef hipError_t (*t_hipMemsetD8Async)(hipDeviceptr_t dest, + unsigned char value, size_t count, + hipStream_t stream); +typedef hipError_t (*t_hipMipmappedArrayCreate)( + hipMipmappedArray_t *pHandle, HIP_ARRAY3D_DESCRIPTOR *pMipmappedArrayDesc, + unsigned int numMipmapLevels); +typedef hipError_t (*t_hipMipmappedArrayDestroy)( + hipMipmappedArray_t hMipmappedArray); +typedef hipError_t (*t_hipMipmappedArrayGetLevel)( + hipArray_t *pLevelArray, hipMipmappedArray_t hMipMappedArray, + unsigned int level); +typedef hipError_t (*t_hipModuleGetFunction)(hipFunction_t *function, + hipModule_t module, + const char *kname); +typedef hipError_t (*t_hipModuleGetGlobal)(hipDeviceptr_t *dptr, size_t *bytes, + hipModule_t hmod, const char *name); +typedef hipError_t (*t_hipModuleGetTexRef)(textureReference **texRef, + hipModule_t hmod, const char *name); +typedef hipError_t (*t_hipModuleLaunchCooperativeKernel)( + hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t stream, + void **kernelParams); +typedef hipError_t (*t_hipModuleLaunchCooperativeKernelMultiDevice)( + hipFunctionLaunchParams *launchParamsList, unsigned int numDevices, + unsigned int flags); +typedef hipError_t (*t_hipModuleLaunchKernel)( + hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t stream, + void **kernelParams, void **extra); +typedef hipError_t (*t_hipModuleLoad)(hipModule_t *module, const char *fname); +typedef hipError_t (*t_hipModuleLoadData)(hipModule_t *module, + const void *image); +typedef hipError_t (*t_hipModuleLoadDataEx)(hipModule_t *module, + const void *image, + unsigned int numOptions, + hipJitOption *options, + void **optionValues); +typedef hipError_t (*t_hipModuleOccupancyMaxActiveBlocksPerMultiprocessor)( + int *numBlocks, hipFunction_t f, int blockSize, size_t dynSharedMemPerBlk); +typedef hipError_t ( + *t_hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags)( + int *numBlocks, hipFunction_t f, int blockSize, size_t dynSharedMemPerBlk, + unsigned int flags); +typedef hipError_t (*t_hipModuleOccupancyMaxPotentialBlockSize)( + int *gridSize, int *blockSize, hipFunction_t f, size_t dynSharedMemPerBlk, + int blockSizeLimit); +typedef hipError_t (*t_hipModuleOccupancyMaxPotentialBlockSizeWithFlags)( + int *gridSize, int *blockSize, hipFunction_t f, size_t dynSharedMemPerBlk, + int blockSizeLimit, unsigned int flags); +typedef hipError_t (*t_hipModuleUnload)(hipModule_t module); +typedef hipError_t (*t_hipOccupancyMaxActiveBlocksPerMultiprocessor)( + int *numBlocks, const void *f, int blockSize, size_t dynSharedMemPerBlk); +typedef hipError_t (*t_hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags)( + int *numBlocks, const void *f, int blockSize, size_t dynSharedMemPerBlk, + unsigned int flags); +typedef hipError_t (*t_hipOccupancyMaxPotentialBlockSize)( + int *gridSize, int *blockSize, const void *f, size_t dynSharedMemPerBlk, + int blockSizeLimit); +typedef hipError_t (*t_hipPeekAtLastError)(void); +typedef hipError_t (*t_hipPointerGetAttribute)(void *data, + hipPointer_attribute attribute, + hipDeviceptr_t ptr); +typedef hipError_t (*t_hipPointerGetAttributes)( + hipPointerAttribute_t *attributes, const void *ptr); +typedef hipError_t (*t_hipPointerSetAttribute)(const void *value, + hipPointer_attribute attribute, + hipDeviceptr_t ptr); +typedef hipError_t (*t_hipProfilerStart)(); +typedef hipError_t (*t_hipProfilerStop)(); +typedef hipError_t (*t_hipRuntimeGetVersion)(int *runtimeVersion); +typedef hipError_t (*t_hipSetDevice)(int deviceId); +typedef hipError_t (*t_hipSetDeviceFlags)(unsigned flags); +typedef hipError_t (*t_hipSetupArgument)(const void *arg, size_t size, + size_t offset); +typedef hipError_t (*t_hipSignalExternalSemaphoresAsync)( + const hipExternalSemaphore_t *extSemArray, + const hipExternalSemaphoreSignalParams *paramsArray, + unsigned int numExtSems, hipStream_t stream); +typedef hipError_t (*t_hipStreamAddCallback)(hipStream_t stream, + hipStreamCallback_t callback, + void *userData, + unsigned int flags); +typedef hipError_t (*t_hipStreamAttachMemAsync)(hipStream_t stream, + void *dev_ptr, size_t length, + unsigned int flags); +typedef hipError_t (*t_hipStreamBeginCapture)(hipStream_t stream, + hipStreamCaptureMode mode); +typedef hipError_t (*t_hipStreamCreate)(hipStream_t *stream); +typedef hipError_t (*t_hipStreamCreateWithFlags)(hipStream_t *stream, + unsigned int flags); +typedef hipError_t (*t_hipStreamCreateWithPriority)(hipStream_t *stream, + unsigned int flags, + int priority); +typedef hipError_t (*t_hipStreamDestroy)(hipStream_t stream); +typedef hipError_t (*t_hipStreamEndCapture)(hipStream_t stream, + hipGraph_t *pGraph); +typedef hipError_t (*t_hipStreamGetCaptureInfo)( + hipStream_t stream, hipStreamCaptureStatus *pCaptureStatus, + unsigned long long *pId); +typedef hipError_t (*t_hipStreamGetCaptureInfo_v2)( + hipStream_t stream, hipStreamCaptureStatus *captureStatus_out, + unsigned long long *id_out, hipGraph_t *graph_out, + const hipGraphNode_t **dependencies_out, size_t *numDependencies_out); +typedef hipError_t (*t_hipStreamGetDevice)(hipStream_t stream, + hipDevice_t *device); +typedef hipError_t (*t_hipStreamGetFlags)(hipStream_t stream, + unsigned int *flags); +typedef hipError_t (*t_hipStreamGetPriority)(hipStream_t stream, int *priority); +typedef hipError_t (*t_hipStreamIsCapturing)( + hipStream_t stream, hipStreamCaptureStatus *pCaptureStatus); +typedef hipError_t (*t_hipStreamQuery)(hipStream_t stream); +typedef hipError_t (*t_hipStreamSynchronize)(hipStream_t stream); +typedef hipError_t (*t_hipStreamUpdateCaptureDependencies)( + hipStream_t stream, hipGraphNode_t *dependencies, size_t numDependencies, + unsigned int flags); +typedef hipError_t (*t_hipStreamWaitEvent)(hipStream_t stream, hipEvent_t event, + unsigned int flags); +typedef hipError_t (*t_hipStreamWaitValue32)(hipStream_t stream, void *ptr, + uint32_t value, unsigned int flags, + uint32_t mask); +typedef hipError_t (*t_hipStreamWaitValue64)(hipStream_t stream, void *ptr, + uint64_t value, unsigned int flags, + uint64_t mask); +typedef hipError_t (*t_hipStreamWriteValue32)(hipStream_t stream, void *ptr, + uint32_t value, + unsigned int flags); +typedef hipError_t (*t_hipStreamWriteValue64)(hipStream_t stream, void *ptr, + uint64_t value, + unsigned int flags); +typedef hipError_t (*t_hipTexObjectCreate)( + hipTextureObject_t *pTexObject, const HIP_RESOURCE_DESC *pResDesc, + const HIP_TEXTURE_DESC *pTexDesc, + const HIP_RESOURCE_VIEW_DESC *pResViewDesc); +typedef hipError_t (*t_hipTexObjectDestroy)(hipTextureObject_t texObject); +typedef hipError_t (*t_hipTexObjectGetResourceDesc)( + HIP_RESOURCE_DESC *pResDesc, hipTextureObject_t texObject); +typedef hipError_t (*t_hipTexObjectGetResourceViewDesc)( + HIP_RESOURCE_VIEW_DESC *pResViewDesc, hipTextureObject_t texObject); +typedef hipError_t (*t_hipTexObjectGetTextureDesc)( + HIP_TEXTURE_DESC *pTexDesc, hipTextureObject_t texObject); +typedef hipError_t (*t_hipTexRefGetAddress)(hipDeviceptr_t *dev_ptr, + const textureReference *texRef); +typedef hipError_t (*t_hipTexRefGetAddressMode)(enum hipTextureAddressMode *pam, + const textureReference *texRef, + int dim); +typedef hipError_t (*t_hipTexRefGetFilterMode)(enum hipTextureFilterMode *pfm, + const textureReference *texRef); +typedef hipError_t (*t_hipTexRefGetFlags)(unsigned int *pFlags, + const textureReference *texRef); +typedef hipError_t (*t_hipTexRefGetFormat)(hipArray_Format *pFormat, + int *pNumChannels, + const textureReference *texRef); +typedef hipError_t (*t_hipTexRefGetMaxAnisotropy)( + int *pmaxAnsio, const textureReference *texRef); +typedef hipError_t (*t_hipTexRefGetMipMappedArray)( + hipMipmappedArray_t *pArray, const textureReference *texRef); +typedef hipError_t (*t_hipTexRefGetMipmapFilterMode)( + enum hipTextureFilterMode *pfm, const textureReference *texRef); +typedef hipError_t (*t_hipTexRefGetMipmapLevelBias)( + float *pbias, const textureReference *texRef); +typedef hipError_t (*t_hipTexRefGetMipmapLevelClamp)( + float *pminMipmapLevelClamp, float *pmaxMipmapLevelClamp, + const textureReference *texRef); +typedef hipError_t (*t_hipTexRefSetAddress)(size_t *ByteOffset, + textureReference *texRef, + hipDeviceptr_t dptr, size_t bytes); +typedef hipError_t (*t_hipTexRefSetAddress2D)(textureReference *texRef, + const HIP_ARRAY_DESCRIPTOR *desc, + hipDeviceptr_t dptr, + size_t Pitch); +typedef hipError_t (*t_hipTexRefSetAddressMode)(textureReference *texRef, + int dim, + enum hipTextureAddressMode am); +typedef hipError_t (*t_hipTexRefSetArray)(textureReference *tex, + hipArray_const_t array, + unsigned int flags); +typedef hipError_t (*t_hipTexRefSetBorderColor)(textureReference *texRef, + float *pBorderColor); +typedef hipError_t (*t_hipTexRefSetFilterMode)(textureReference *texRef, + enum hipTextureFilterMode fm); +typedef hipError_t (*t_hipTexRefSetFlags)(textureReference *texRef, + unsigned int Flags); +typedef hipError_t (*t_hipTexRefSetFormat)(textureReference *texRef, + hipArray_Format fmt, + int NumPackedComponents); +typedef hipError_t (*t_hipTexRefSetMaxAnisotropy)(textureReference *texRef, + unsigned int maxAniso); +typedef hipError_t (*t_hipTexRefSetMipmapFilterMode)( + textureReference *texRef, enum hipTextureFilterMode fm); +typedef hipError_t (*t_hipTexRefSetMipmapLevelBias)(textureReference *texRef, + float bias); +typedef hipError_t (*t_hipTexRefSetMipmapLevelClamp)(textureReference *texRef, + float minMipMapLevelClamp, + float maxMipMapLevelClamp); +typedef hipError_t (*t_hipTexRefSetMipmappedArray)( + textureReference *texRef, struct hipMipmappedArray *mipmappedArray, + unsigned int Flags); +typedef hipError_t (*t_hipThreadExchangeStreamCaptureMode)( + hipStreamCaptureMode *mode); +typedef hipError_t (*t_hipUnbindTexture)(const textureReference *tex); +typedef hipError_t (*t_hipUserObjectCreate)(hipUserObject_t *object_out, + void *ptr, hipHostFn_t destroy, + unsigned int initialRefcount, + unsigned int flags); +typedef hipError_t (*t_hipUserObjectRelease)(hipUserObject_t object, + unsigned int count); +typedef hipError_t (*t_hipUserObjectRetain)(hipUserObject_t object, + unsigned int count); +typedef hipError_t (*t_hipWaitExternalSemaphoresAsync)( + const hipExternalSemaphore_t *extSemArray, + const hipExternalSemaphoreWaitParams *paramsArray, unsigned int numExtSems, + hipStream_t stream); + +typedef hipError_t (*t_hipMemcpy_spt)(void *dst, const void *src, + size_t sizeBytes, hipMemcpyKind kind); + +typedef hipError_t (*t_hipMemcpyToSymbol_spt)(const void *symbol, + const void *src, size_t sizeBytes, + size_t offset, + hipMemcpyKind kind); + +typedef hipError_t (*t_hipMemcpyFromSymbol_spt)(void *dst, const void *symbol, + size_t sizeBytes, size_t offset, + hipMemcpyKind kind); + +typedef hipError_t (*t_hipMemcpy2D_spt)(void *dst, size_t dpitch, + const void *src, size_t spitch, + size_t width, size_t height, + hipMemcpyKind kind); + +typedef hipError_t (*t_hipMemcpy2DFromArray_spt)(void *dst, size_t dpitch, + hipArray_const_t src, + size_t wOffset, size_t hOffset, + size_t width, size_t height, + hipMemcpyKind kind); + +typedef hipError_t (*t_hipMemcpy3D_spt)(const struct hipMemcpy3DParms *p); + +typedef hipError_t (*t_hipMemset_spt)(void *dst, int value, size_t sizeBytes); + +typedef hipError_t (*t_hipMemsetAsync_spt)(void *dst, int value, + size_t sizeBytes, + hipStream_t stream); + +typedef hipError_t (*t_hipMemset2D_spt)(void *dst, size_t pitch, int value, + size_t width, size_t height); + +typedef hipError_t (*t_hipMemset2DAsync_spt)(void *dst, size_t pitch, int value, + size_t width, size_t height, + hipStream_t stream); + +typedef hipError_t (*t_hipMemset3DAsync_spt)(hipPitchedPtr pitchedDevPtr, + int value, hipExtent extent, + hipStream_t stream); + +typedef hipError_t (*t_hipMemset3D_spt)(hipPitchedPtr pitchedDevPtr, int value, + hipExtent extent); + +typedef hipError_t (*t_hipMemcpyAsync_spt)(void *dst, const void *src, + size_t sizeBytes, hipMemcpyKind kind, + hipStream_t stream); + +typedef hipError_t (*t_hipMemcpy3DAsync_spt)(const hipMemcpy3DParms *p, + hipStream_t stream); + +typedef hipError_t (*t_hipMemcpy2DAsync_spt)(void *dst, size_t dpitch, + const void *src, size_t spitch, + size_t width, size_t height, + hipMemcpyKind kind, + hipStream_t stream); + +typedef hipError_t (*t_hipMemcpyFromSymbolAsync_spt)( + void *dst, const void *symbol, size_t sizeBytes, size_t offset, + hipMemcpyKind kind, hipStream_t stream); + +typedef hipError_t (*t_hipMemcpyToSymbolAsync_spt)( + const void *symbol, const void *src, size_t sizeBytes, size_t offset, + hipMemcpyKind kind, hipStream_t stream); + +typedef hipError_t (*t_hipMemcpyFromArray_spt)(void *dst, hipArray_const_t src, + size_t wOffsetSrc, + size_t hOffset, size_t count, + hipMemcpyKind kind); + +typedef hipError_t (*t_hipMemcpy2DToArray_spt)(hipArray_t dst, size_t wOffset, + size_t hOffset, const void *src, + size_t spitch, size_t width, + size_t height, + hipMemcpyKind kind); + +typedef hipError_t (*t_hipMemcpy2DFromArrayAsync_spt)( + void *dst, size_t dpitch, hipArray_const_t src, size_t wOffsetSrc, + size_t hOffsetSrc, size_t width, size_t height, hipMemcpyKind kind, + hipStream_t stream); + +typedef hipError_t (*t_hipMemcpy2DToArrayAsync_spt)( + hipArray_t dst, size_t wOffset, size_t hOffset, const void *src, + size_t spitch, size_t width, size_t height, hipMemcpyKind kind, + hipStream_t stream); + +typedef hipError_t (*t_hipStreamQuery_spt)(hipStream_t stream); + +typedef hipError_t (*t_hipStreamSynchronize_spt)(hipStream_t stream); + +typedef hipError_t (*t_hipStreamGetPriority_spt)(hipStream_t stream, + int *priority); + +typedef hipError_t (*t_hipStreamWaitEvent_spt)(hipStream_t stream, + hipEvent_t event, + unsigned int flags); + +typedef hipError_t (*t_hipStreamGetFlags_spt)(hipStream_t stream, + unsigned int *flags); + +typedef hipError_t (*t_hipStreamAddCallback_spt)(hipStream_t stream, + hipStreamCallback_t callback, + void *userData, + unsigned int flags); +typedef hipError_t (*t_hipEventRecord_spt)(hipEvent_t event, + hipStream_t stream); +typedef hipError_t (*t_hipLaunchCooperativeKernel_spt)( + const void *f, dim3 gridDim, dim3 blockDim, void **kernelParams, + uint32_t sharedMemBytes, hipStream_t hStream); + +typedef hipError_t (*t_hipLaunchKernel_spt)(const void *function_address, + dim3 numBlocks, dim3 dimBlocks, + void **args, size_t sharedMemBytes, + hipStream_t stream); + +typedef hipError_t (*t_hipGraphLaunch_spt)(hipGraphExec_t graphExec, + hipStream_t stream); +typedef hipError_t (*t_hipStreamBeginCapture_spt)(hipStream_t stream, + hipStreamCaptureMode mode); +typedef hipError_t (*t_hipStreamEndCapture_spt)(hipStream_t stream, + hipGraph_t *pGraph); +typedef hipError_t (*t_hipStreamIsCapturing_spt)( + hipStream_t stream, hipStreamCaptureStatus *pCaptureStatus); +typedef hipError_t (*t_hipStreamGetCaptureInfo_spt)( + hipStream_t stream, hipStreamCaptureStatus *pCaptureStatus, + unsigned long long *pId); +typedef hipError_t (*t_hipStreamGetCaptureInfo_v2_spt)( + hipStream_t stream, hipStreamCaptureStatus *captureStatus_out, + unsigned long long *id_out, hipGraph_t *graph_out, + const hipGraphNode_t **dependencies_out, size_t *numDependencies_out); +typedef hipError_t (*t_hipLaunchHostFunc_spt)(hipStream_t stream, + hipHostFn_t fn, void *userData); +typedef hipChannelFormatDesc (*t_hipCreateChannelDesc)(int x, int y, int z, + int w, + hipChannelFormatKind f); +typedef hipError_t (*t_hipExtModuleLaunchKernel)( + hipFunction_t f, uint32_t globalWorkSizeX, uint32_t globalWorkSizeY, + uint32_t globalWorkSizeZ, uint32_t localWorkSizeX, uint32_t localWorkSizeY, + uint32_t localWorkSizeZ, size_t sharedMemBytes, hipStream_t hStream, + void **kernelParams, void **extra, hipEvent_t startEvent, + hipEvent_t stopEvent, uint32_t flags); +typedef hipError_t (*t_hipHccModuleLaunchKernel)( + hipFunction_t f, uint32_t globalWorkSizeX, uint32_t globalWorkSizeY, + uint32_t globalWorkSizeZ, uint32_t localWorkSizeX, uint32_t localWorkSizeY, + uint32_t localWorkSizeZ, size_t sharedMemBytes, hipStream_t hStream, + void **kernelParams, void **extra, hipEvent_t startEvent, + hipEvent_t stopEvent); +typedef int (*t_hipGetStreamDeviceId)(hipStream_t stream); +typedef hipError_t (*t_hipDrvGraphAddMemsetNode)( + hipGraphNode_t *phGraphNode, hipGraph_t hGraph, + const hipGraphNode_t *dependencies, size_t numDependencies, + const HIP_MEMSET_NODE_PARAMS *memsetParams, hipCtx_t ctx); +typedef hipError_t (*t_hipGraphAddExternalSemaphoresWaitNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + const hipExternalSemaphoreWaitNodeParams *nodeParams); +typedef hipError_t (*t_hipGraphAddExternalSemaphoresSignalNode)( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + const hipExternalSemaphoreSignalNodeParams *nodeParams); +typedef hipError_t (*t_hipGraphExternalSemaphoresSignalNodeSetParams)( + hipGraphNode_t hNode, + const hipExternalSemaphoreSignalNodeParams *nodeParams); +typedef hipError_t (*t_hipGraphExternalSemaphoresWaitNodeSetParams)( + hipGraphNode_t hNode, const hipExternalSemaphoreWaitNodeParams *nodeParams); +typedef hipError_t (*t_hipGraphExternalSemaphoresSignalNodeGetParams)( + hipGraphNode_t hNode, hipExternalSemaphoreSignalNodeParams *params_out); +typedef hipError_t (*t_hipGraphExternalSemaphoresWaitNodeGetParams)( + hipGraphNode_t hNode, hipExternalSemaphoreWaitNodeParams *params_out); +typedef hipError_t (*t_hipGraphExecExternalSemaphoresSignalNodeSetParams)( + hipGraphExec_t hGraphExec, hipGraphNode_t hNode, + const hipExternalSemaphoreSignalNodeParams *nodeParams); +typedef hipError_t (*t_hipGraphExecExternalSemaphoresWaitNodeSetParams)( + hipGraphExec_t hGraphExec, hipGraphNode_t hNode, + const hipExternalSemaphoreWaitNodeParams *nodeParams); +typedef hipError_t (*t_hipGraphAddNode)(hipGraphNode_t *pGraphNode, + hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, + hipGraphNodeParams *nodeParams); +typedef hipError_t (*t_hipGraphInstantiateWithParams)( + hipGraphExec_t *pGraphExec, hipGraph_t graph, + hipGraphInstantiateParams *instantiateParams); +typedef hipError_t (*t_hipExtGetLastError)(); +typedef hipError_t (*t_hipTexRefGetBorderColor)(float *pBorderColor, + const textureReference *texRef); +typedef hipError_t (*t_hipTexRefGetArray)(hipArray_t *pArray, + const textureReference *texRef); + +typedef hipError_t (*t_hipTexRefGetBorderColor)(float *pBorderColor, + const textureReference *texRef); +typedef hipError_t (*t_hipTexRefGetArray)(hipArray_t *pArray, + const textureReference *texRef); +typedef hipError_t (*t_hipGetProcAddress)( + const char *symbol, void **pfn, int hipVersion, uint64_t flags, + hipDriverProcAddressQueryResult *symbolStatus); +typedef hipError_t (*t_hipStreamBeginCaptureToGraph)( + hipStream_t stream, hipGraph_t graph, const hipGraphNode_t *dependencies, + const hipGraphEdgeData *dependencyData, size_t numDependencies, + hipStreamCaptureMode mode); +typedef hipError_t (*t_hipGetFuncBySymbol)(hipFunction_t *functionPtr, + const void *symbolPtr); +typedef hipError_t (*t_hipSetValidDevices)(int *device_arr, int len); +typedef hipError_t (*t_hipMemcpyAtoD)(hipDeviceptr_t dstDevice, + hipArray_t srcArray, size_t srcOffset, + size_t ByteCount); +typedef hipError_t (*t_hipMemcpyDtoA)(hipArray_t dstArray, size_t dstOffset, + hipDeviceptr_t srcDevice, + size_t ByteCount); +typedef hipError_t (*t_hipMemcpyAtoA)(hipArray_t dstArray, size_t dstOffset, + hipArray_t srcArray, size_t srcOffset, + size_t ByteCount); +typedef hipError_t (*t_hipMemcpyAtoHAsync)(void *dstHost, hipArray_t srcArray, + size_t srcOffset, size_t ByteCount, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpyHtoAAsync)(hipArray_t dstArray, + size_t dstOffset, + const void *srcHost, + size_t ByteCount, + hipStream_t stream); +typedef hipError_t (*t_hipMemcpy2DArrayToArray)( + hipArray_t dst, size_t wOffsetDst, size_t hOffsetDst, hipArray_const_t src, + size_t wOffsetSrc, size_t hOffsetSrc, size_t width, size_t height, + hipMemcpyKind kind); + +// HIP Compiler dispatch table +struct HipCompilerDispatchTable { + size_t size; + t___hipPopCallConfiguration __hipPopCallConfiguration_fn; + t___hipPushCallConfiguration __hipPushCallConfiguration_fn; + t___hipRegisterFatBinary __hipRegisterFatBinary_fn; + t___hipRegisterFunction __hipRegisterFunction_fn; + t___hipRegisterManagedVar __hipRegisterManagedVar_fn; + t___hipRegisterSurface __hipRegisterSurface_fn; + t___hipRegisterTexture __hipRegisterTexture_fn; + t___hipRegisterVar __hipRegisterVar_fn; + t___hipUnregisterFatBinary __hipUnregisterFatBinary_fn; +}; + +// HIP API dispatch table +struct HipDispatchTable { + size_t size; + t_hipApiName hipApiName_fn; + t_hipArray3DCreate hipArray3DCreate_fn; + t_hipArray3DGetDescriptor hipArray3DGetDescriptor_fn; + t_hipArrayCreate hipArrayCreate_fn; + t_hipArrayDestroy hipArrayDestroy_fn; + t_hipArrayGetDescriptor hipArrayGetDescriptor_fn; + t_hipArrayGetInfo hipArrayGetInfo_fn; + t_hipBindTexture hipBindTexture_fn; + t_hipBindTexture2D hipBindTexture2D_fn; + t_hipBindTextureToArray hipBindTextureToArray_fn; + t_hipBindTextureToMipmappedArray hipBindTextureToMipmappedArray_fn; + t_hipChooseDevice hipChooseDevice_fn; + t_hipChooseDeviceR0000 hipChooseDeviceR0000_fn; + t_hipConfigureCall hipConfigureCall_fn; + t_hipCreateSurfaceObject hipCreateSurfaceObject_fn; + t_hipCreateTextureObject hipCreateTextureObject_fn; + t_hipCtxCreate hipCtxCreate_fn; + t_hipCtxDestroy hipCtxDestroy_fn; + t_hipCtxDisablePeerAccess hipCtxDisablePeerAccess_fn; + t_hipCtxEnablePeerAccess hipCtxEnablePeerAccess_fn; + t_hipCtxGetApiVersion hipCtxGetApiVersion_fn; + t_hipCtxGetCacheConfig hipCtxGetCacheConfig_fn; + t_hipCtxGetCurrent hipCtxGetCurrent_fn; + t_hipCtxGetDevice hipCtxGetDevice_fn; + t_hipCtxGetFlags hipCtxGetFlags_fn; + t_hipCtxGetSharedMemConfig hipCtxGetSharedMemConfig_fn; + t_hipCtxPopCurrent hipCtxPopCurrent_fn; + t_hipCtxPushCurrent hipCtxPushCurrent_fn; + t_hipCtxSetCacheConfig hipCtxSetCacheConfig_fn; + t_hipCtxSetCurrent hipCtxSetCurrent_fn; + t_hipCtxSetSharedMemConfig hipCtxSetSharedMemConfig_fn; + t_hipCtxSynchronize hipCtxSynchronize_fn; + t_hipDestroyExternalMemory hipDestroyExternalMemory_fn; + t_hipDestroyExternalSemaphore hipDestroyExternalSemaphore_fn; + t_hipDestroySurfaceObject hipDestroySurfaceObject_fn; + t_hipDestroyTextureObject hipDestroyTextureObject_fn; + t_hipDeviceCanAccessPeer hipDeviceCanAccessPeer_fn; + t_hipDeviceComputeCapability hipDeviceComputeCapability_fn; + t_hipDeviceDisablePeerAccess hipDeviceDisablePeerAccess_fn; + t_hipDeviceEnablePeerAccess hipDeviceEnablePeerAccess_fn; + t_hipDeviceGet hipDeviceGet_fn; + t_hipDeviceGetAttribute hipDeviceGetAttribute_fn; + t_hipDeviceGetByPCIBusId hipDeviceGetByPCIBusId_fn; + t_hipDeviceGetCacheConfig hipDeviceGetCacheConfig_fn; + t_hipDeviceGetDefaultMemPool hipDeviceGetDefaultMemPool_fn; + t_hipDeviceGetGraphMemAttribute hipDeviceGetGraphMemAttribute_fn; + t_hipDeviceGetLimit hipDeviceGetLimit_fn; + t_hipDeviceGetMemPool hipDeviceGetMemPool_fn; + t_hipDeviceGetName hipDeviceGetName_fn; + t_hipDeviceGetP2PAttribute hipDeviceGetP2PAttribute_fn; + t_hipDeviceGetPCIBusId hipDeviceGetPCIBusId_fn; + t_hipDeviceGetSharedMemConfig hipDeviceGetSharedMemConfig_fn; + t_hipDeviceGetStreamPriorityRange hipDeviceGetStreamPriorityRange_fn; + t_hipDeviceGetUuid hipDeviceGetUuid_fn; + t_hipDeviceGraphMemTrim hipDeviceGraphMemTrim_fn; + t_hipDevicePrimaryCtxGetState hipDevicePrimaryCtxGetState_fn; + t_hipDevicePrimaryCtxRelease hipDevicePrimaryCtxRelease_fn; + t_hipDevicePrimaryCtxReset hipDevicePrimaryCtxReset_fn; + t_hipDevicePrimaryCtxRetain hipDevicePrimaryCtxRetain_fn; + t_hipDevicePrimaryCtxSetFlags hipDevicePrimaryCtxSetFlags_fn; + t_hipDeviceReset hipDeviceReset_fn; + t_hipDeviceSetCacheConfig hipDeviceSetCacheConfig_fn; + t_hipDeviceSetGraphMemAttribute hipDeviceSetGraphMemAttribute_fn; + t_hipDeviceSetLimit hipDeviceSetLimit_fn; + t_hipDeviceSetMemPool hipDeviceSetMemPool_fn; + t_hipDeviceSetSharedMemConfig hipDeviceSetSharedMemConfig_fn; + t_hipDeviceSynchronize hipDeviceSynchronize_fn; + t_hipDeviceTotalMem hipDeviceTotalMem_fn; + t_hipDriverGetVersion hipDriverGetVersion_fn; + t_hipDrvGetErrorName hipDrvGetErrorName_fn; + t_hipDrvGetErrorString hipDrvGetErrorString_fn; + t_hipDrvGraphAddMemcpyNode hipDrvGraphAddMemcpyNode_fn; + t_hipDrvMemcpy2DUnaligned hipDrvMemcpy2DUnaligned_fn; + t_hipDrvMemcpy3D hipDrvMemcpy3D_fn; + t_hipDrvMemcpy3DAsync hipDrvMemcpy3DAsync_fn; + t_hipDrvPointerGetAttributes hipDrvPointerGetAttributes_fn; + t_hipEventCreate hipEventCreate_fn; + t_hipEventCreateWithFlags hipEventCreateWithFlags_fn; + t_hipEventDestroy hipEventDestroy_fn; + t_hipEventElapsedTime hipEventElapsedTime_fn; + t_hipEventQuery hipEventQuery_fn; + t_hipEventRecord hipEventRecord_fn; + t_hipEventSynchronize hipEventSynchronize_fn; + t_hipExtGetLinkTypeAndHopCount hipExtGetLinkTypeAndHopCount_fn; + t_hipExtLaunchKernel hipExtLaunchKernel_fn; + t_hipExtLaunchMultiKernelMultiDevice hipExtLaunchMultiKernelMultiDevice_fn; + t_hipExtMallocWithFlags hipExtMallocWithFlags_fn; + t_hipExtStreamCreateWithCUMask hipExtStreamCreateWithCUMask_fn; + t_hipExtStreamGetCUMask hipExtStreamGetCUMask_fn; + t_hipExternalMemoryGetMappedBuffer hipExternalMemoryGetMappedBuffer_fn; + t_hipFree hipFree_fn; + t_hipFreeArray hipFreeArray_fn; + t_hipFreeAsync hipFreeAsync_fn; + t_hipFreeHost hipFreeHost_fn; + t_hipFreeMipmappedArray hipFreeMipmappedArray_fn; + t_hipFuncGetAttribute hipFuncGetAttribute_fn; + t_hipFuncGetAttributes hipFuncGetAttributes_fn; + t_hipFuncSetAttribute hipFuncSetAttribute_fn; + t_hipFuncSetCacheConfig hipFuncSetCacheConfig_fn; + t_hipFuncSetSharedMemConfig hipFuncSetSharedMemConfig_fn; + t_hipGLGetDevices hipGLGetDevices_fn; + t_hipGetChannelDesc hipGetChannelDesc_fn; + t_hipGetDevice hipGetDevice_fn; + t_hipGetDeviceCount hipGetDeviceCount_fn; + t_hipGetDeviceFlags hipGetDeviceFlags_fn; + t_hipGetDevicePropertiesR0600 hipGetDevicePropertiesR0600_fn; + t_hipGetDevicePropertiesR0000 hipGetDevicePropertiesR0000_fn; + t_hipGetErrorName hipGetErrorName_fn; + t_hipGetErrorString hipGetErrorString_fn; + t_hipGetLastError hipGetLastError_fn; + t_hipGetMipmappedArrayLevel hipGetMipmappedArrayLevel_fn; + t_hipGetSymbolAddress hipGetSymbolAddress_fn; + t_hipGetSymbolSize hipGetSymbolSize_fn; + t_hipGetTextureAlignmentOffset hipGetTextureAlignmentOffset_fn; + t_hipGetTextureObjectResourceDesc hipGetTextureObjectResourceDesc_fn; + t_hipGetTextureObjectResourceViewDesc hipGetTextureObjectResourceViewDesc_fn; + t_hipGetTextureObjectTextureDesc hipGetTextureObjectTextureDesc_fn; + t_hipGetTextureReference hipGetTextureReference_fn; + t_hipGraphAddChildGraphNode hipGraphAddChildGraphNode_fn; + t_hipGraphAddDependencies hipGraphAddDependencies_fn; + t_hipGraphAddEmptyNode hipGraphAddEmptyNode_fn; + t_hipGraphAddEventRecordNode hipGraphAddEventRecordNode_fn; + t_hipGraphAddEventWaitNode hipGraphAddEventWaitNode_fn; + t_hipGraphAddHostNode hipGraphAddHostNode_fn; + t_hipGraphAddKernelNode hipGraphAddKernelNode_fn; + t_hipGraphAddMemAllocNode hipGraphAddMemAllocNode_fn; + t_hipGraphAddMemFreeNode hipGraphAddMemFreeNode_fn; + t_hipGraphAddMemcpyNode hipGraphAddMemcpyNode_fn; + t_hipGraphAddMemcpyNode1D hipGraphAddMemcpyNode1D_fn; + t_hipGraphAddMemcpyNodeFromSymbol hipGraphAddMemcpyNodeFromSymbol_fn; + t_hipGraphAddMemcpyNodeToSymbol hipGraphAddMemcpyNodeToSymbol_fn; + t_hipGraphAddMemsetNode hipGraphAddMemsetNode_fn; + t_hipGraphChildGraphNodeGetGraph hipGraphChildGraphNodeGetGraph_fn; + t_hipGraphClone hipGraphClone_fn; + t_hipGraphCreate hipGraphCreate_fn; + t_hipGraphDebugDotPrint hipGraphDebugDotPrint_fn; + t_hipGraphDestroy hipGraphDestroy_fn; + t_hipGraphDestroyNode hipGraphDestroyNode_fn; + t_hipGraphEventRecordNodeGetEvent hipGraphEventRecordNodeGetEvent_fn; + t_hipGraphEventRecordNodeSetEvent hipGraphEventRecordNodeSetEvent_fn; + t_hipGraphEventWaitNodeGetEvent hipGraphEventWaitNodeGetEvent_fn; + t_hipGraphEventWaitNodeSetEvent hipGraphEventWaitNodeSetEvent_fn; + t_hipGraphExecChildGraphNodeSetParams hipGraphExecChildGraphNodeSetParams_fn; + t_hipGraphExecDestroy hipGraphExecDestroy_fn; + t_hipGraphExecEventRecordNodeSetEvent hipGraphExecEventRecordNodeSetEvent_fn; + t_hipGraphExecEventWaitNodeSetEvent hipGraphExecEventWaitNodeSetEvent_fn; + t_hipGraphExecHostNodeSetParams hipGraphExecHostNodeSetParams_fn; + t_hipGraphExecKernelNodeSetParams hipGraphExecKernelNodeSetParams_fn; + t_hipGraphExecMemcpyNodeSetParams hipGraphExecMemcpyNodeSetParams_fn; + t_hipGraphExecMemcpyNodeSetParams1D hipGraphExecMemcpyNodeSetParams1D_fn; + t_hipGraphExecMemcpyNodeSetParamsFromSymbol + hipGraphExecMemcpyNodeSetParamsFromSymbol_fn; + t_hipGraphExecMemcpyNodeSetParamsToSymbol + hipGraphExecMemcpyNodeSetParamsToSymbol_fn; + t_hipGraphExecMemsetNodeSetParams hipGraphExecMemsetNodeSetParams_fn; + t_hipGraphExecUpdate hipGraphExecUpdate_fn; + t_hipGraphGetEdges hipGraphGetEdges_fn; + t_hipGraphGetNodes hipGraphGetNodes_fn; + t_hipGraphGetRootNodes hipGraphGetRootNodes_fn; + t_hipGraphHostNodeGetParams hipGraphHostNodeGetParams_fn; + t_hipGraphHostNodeSetParams hipGraphHostNodeSetParams_fn; + t_hipGraphInstantiate hipGraphInstantiate_fn; + t_hipGraphInstantiateWithFlags hipGraphInstantiateWithFlags_fn; + t_hipGraphKernelNodeCopyAttributes hipGraphKernelNodeCopyAttributes_fn; + t_hipGraphKernelNodeGetAttribute hipGraphKernelNodeGetAttribute_fn; + t_hipGraphKernelNodeGetParams hipGraphKernelNodeGetParams_fn; + t_hipGraphKernelNodeSetAttribute hipGraphKernelNodeSetAttribute_fn; + t_hipGraphKernelNodeSetParams hipGraphKernelNodeSetParams_fn; + t_hipGraphLaunch hipGraphLaunch_fn; + t_hipGraphMemAllocNodeGetParams hipGraphMemAllocNodeGetParams_fn; + t_hipGraphMemFreeNodeGetParams hipGraphMemFreeNodeGetParams_fn; + t_hipGraphMemcpyNodeGetParams hipGraphMemcpyNodeGetParams_fn; + t_hipGraphMemcpyNodeSetParams hipGraphMemcpyNodeSetParams_fn; + t_hipGraphMemcpyNodeSetParams1D hipGraphMemcpyNodeSetParams1D_fn; + t_hipGraphMemcpyNodeSetParamsFromSymbol + hipGraphMemcpyNodeSetParamsFromSymbol_fn; + t_hipGraphMemcpyNodeSetParamsToSymbol hipGraphMemcpyNodeSetParamsToSymbol_fn; + t_hipGraphMemsetNodeGetParams hipGraphMemsetNodeGetParams_fn; + t_hipGraphMemsetNodeSetParams hipGraphMemsetNodeSetParams_fn; + t_hipGraphNodeFindInClone hipGraphNodeFindInClone_fn; + t_hipGraphNodeGetDependencies hipGraphNodeGetDependencies_fn; + t_hipGraphNodeGetDependentNodes hipGraphNodeGetDependentNodes_fn; + t_hipGraphNodeGetEnabled hipGraphNodeGetEnabled_fn; + t_hipGraphNodeGetType hipGraphNodeGetType_fn; + t_hipGraphNodeSetEnabled hipGraphNodeSetEnabled_fn; + t_hipGraphReleaseUserObject hipGraphReleaseUserObject_fn; + t_hipGraphRemoveDependencies hipGraphRemoveDependencies_fn; + t_hipGraphRetainUserObject hipGraphRetainUserObject_fn; + t_hipGraphUpload hipGraphUpload_fn; + t_hipGraphicsGLRegisterBuffer hipGraphicsGLRegisterBuffer_fn; + t_hipGraphicsGLRegisterImage hipGraphicsGLRegisterImage_fn; + t_hipGraphicsMapResources hipGraphicsMapResources_fn; + t_hipGraphicsResourceGetMappedPointer hipGraphicsResourceGetMappedPointer_fn; + t_hipGraphicsSubResourceGetMappedArray + hipGraphicsSubResourceGetMappedArray_fn; + t_hipGraphicsUnmapResources hipGraphicsUnmapResources_fn; + t_hipGraphicsUnregisterResource hipGraphicsUnregisterResource_fn; + t_hipHostAlloc hipHostAlloc_fn; + t_hipHostFree hipHostFree_fn; + t_hipHostGetDevicePointer hipHostGetDevicePointer_fn; + t_hipHostGetFlags hipHostGetFlags_fn; + t_hipHostMalloc hipHostMalloc_fn; + t_hipHostRegister hipHostRegister_fn; + t_hipHostUnregister hipHostUnregister_fn; + t_hipImportExternalMemory hipImportExternalMemory_fn; + t_hipImportExternalSemaphore hipImportExternalSemaphore_fn; + t_hipInit hipInit_fn; + t_hipIpcCloseMemHandle hipIpcCloseMemHandle_fn; + t_hipIpcGetEventHandle hipIpcGetEventHandle_fn; + t_hipIpcGetMemHandle hipIpcGetMemHandle_fn; + t_hipIpcOpenEventHandle hipIpcOpenEventHandle_fn; + t_hipIpcOpenMemHandle hipIpcOpenMemHandle_fn; + t_hipKernelNameRef hipKernelNameRef_fn; + t_hipKernelNameRefByPtr hipKernelNameRefByPtr_fn; + t_hipLaunchByPtr hipLaunchByPtr_fn; + t_hipLaunchCooperativeKernel hipLaunchCooperativeKernel_fn; + t_hipLaunchCooperativeKernelMultiDevice + hipLaunchCooperativeKernelMultiDevice_fn; + t_hipLaunchHostFunc hipLaunchHostFunc_fn; + t_hipLaunchKernel hipLaunchKernel_fn; + t_hipMalloc hipMalloc_fn; + t_hipMalloc3D hipMalloc3D_fn; + t_hipMalloc3DArray hipMalloc3DArray_fn; + t_hipMallocArray hipMallocArray_fn; + t_hipMallocAsync hipMallocAsync_fn; + t_hipMallocFromPoolAsync hipMallocFromPoolAsync_fn; + t_hipMallocHost hipMallocHost_fn; + t_hipMallocManaged hipMallocManaged_fn; + t_hipMallocMipmappedArray hipMallocMipmappedArray_fn; + t_hipMallocPitch hipMallocPitch_fn; + t_hipMemAddressFree hipMemAddressFree_fn; + t_hipMemAddressReserve hipMemAddressReserve_fn; + t_hipMemAdvise hipMemAdvise_fn; + t_hipMemAllocHost hipMemAllocHost_fn; + t_hipMemAllocPitch hipMemAllocPitch_fn; + t_hipMemCreate hipMemCreate_fn; + t_hipMemExportToShareableHandle hipMemExportToShareableHandle_fn; + t_hipMemGetAccess hipMemGetAccess_fn; + t_hipMemGetAddressRange hipMemGetAddressRange_fn; + t_hipMemGetAllocationGranularity hipMemGetAllocationGranularity_fn; + t_hipMemGetAllocationPropertiesFromHandle + hipMemGetAllocationPropertiesFromHandle_fn; + t_hipMemGetInfo hipMemGetInfo_fn; + t_hipMemImportFromShareableHandle hipMemImportFromShareableHandle_fn; + t_hipMemMap hipMemMap_fn; + t_hipMemMapArrayAsync hipMemMapArrayAsync_fn; + t_hipMemPoolCreate hipMemPoolCreate_fn; + t_hipMemPoolDestroy hipMemPoolDestroy_fn; + t_hipMemPoolExportPointer hipMemPoolExportPointer_fn; + t_hipMemPoolExportToShareableHandle hipMemPoolExportToShareableHandle_fn; + t_hipMemPoolGetAccess hipMemPoolGetAccess_fn; + t_hipMemPoolGetAttribute hipMemPoolGetAttribute_fn; + t_hipMemPoolImportFromShareableHandle hipMemPoolImportFromShareableHandle_fn; + t_hipMemPoolImportPointer hipMemPoolImportPointer_fn; + t_hipMemPoolSetAccess hipMemPoolSetAccess_fn; + t_hipMemPoolSetAttribute hipMemPoolSetAttribute_fn; + t_hipMemPoolTrimTo hipMemPoolTrimTo_fn; + t_hipMemPrefetchAsync hipMemPrefetchAsync_fn; + t_hipMemPtrGetInfo hipMemPtrGetInfo_fn; + t_hipMemRangeGetAttribute hipMemRangeGetAttribute_fn; + t_hipMemRangeGetAttributes hipMemRangeGetAttributes_fn; + t_hipMemRelease hipMemRelease_fn; + t_hipMemRetainAllocationHandle hipMemRetainAllocationHandle_fn; + t_hipMemSetAccess hipMemSetAccess_fn; + t_hipMemUnmap hipMemUnmap_fn; + t_hipMemcpy hipMemcpy_fn; + t_hipMemcpy2D hipMemcpy2D_fn; + t_hipMemcpy2DAsync hipMemcpy2DAsync_fn; + t_hipMemcpy2DFromArray hipMemcpy2DFromArray_fn; + t_hipMemcpy2DFromArrayAsync hipMemcpy2DFromArrayAsync_fn; + t_hipMemcpy2DToArray hipMemcpy2DToArray_fn; + t_hipMemcpy2DToArrayAsync hipMemcpy2DToArrayAsync_fn; + t_hipMemcpy3D hipMemcpy3D_fn; + t_hipMemcpy3DAsync hipMemcpy3DAsync_fn; + t_hipMemcpyAsync hipMemcpyAsync_fn; + t_hipMemcpyAtoH hipMemcpyAtoH_fn; + t_hipMemcpyDtoD hipMemcpyDtoD_fn; + t_hipMemcpyDtoDAsync hipMemcpyDtoDAsync_fn; + t_hipMemcpyDtoH hipMemcpyDtoH_fn; + t_hipMemcpyDtoHAsync hipMemcpyDtoHAsync_fn; + t_hipMemcpyFromArray hipMemcpyFromArray_fn; + t_hipMemcpyFromSymbol hipMemcpyFromSymbol_fn; + t_hipMemcpyFromSymbolAsync hipMemcpyFromSymbolAsync_fn; + t_hipMemcpyHtoA hipMemcpyHtoA_fn; + t_hipMemcpyHtoD hipMemcpyHtoD_fn; + t_hipMemcpyHtoDAsync hipMemcpyHtoDAsync_fn; + t_hipMemcpyParam2D hipMemcpyParam2D_fn; + t_hipMemcpyParam2DAsync hipMemcpyParam2DAsync_fn; + t_hipMemcpyPeer hipMemcpyPeer_fn; + t_hipMemcpyPeerAsync hipMemcpyPeerAsync_fn; + t_hipMemcpyToArray hipMemcpyToArray_fn; + t_hipMemcpyToSymbol hipMemcpyToSymbol_fn; + t_hipMemcpyToSymbolAsync hipMemcpyToSymbolAsync_fn; + t_hipMemcpyWithStream hipMemcpyWithStream_fn; + t_hipMemset hipMemset_fn; + t_hipMemset2D hipMemset2D_fn; + t_hipMemset2DAsync hipMemset2DAsync_fn; + t_hipMemset3D hipMemset3D_fn; + t_hipMemset3DAsync hipMemset3DAsync_fn; + t_hipMemsetAsync hipMemsetAsync_fn; + t_hipMemsetD16 hipMemsetD16_fn; + t_hipMemsetD16Async hipMemsetD16Async_fn; + t_hipMemsetD32 hipMemsetD32_fn; + t_hipMemsetD32Async hipMemsetD32Async_fn; + t_hipMemsetD8 hipMemsetD8_fn; + t_hipMemsetD8Async hipMemsetD8Async_fn; + t_hipMipmappedArrayCreate hipMipmappedArrayCreate_fn; + t_hipMipmappedArrayDestroy hipMipmappedArrayDestroy_fn; + t_hipMipmappedArrayGetLevel hipMipmappedArrayGetLevel_fn; + t_hipModuleGetFunction hipModuleGetFunction_fn; + t_hipModuleGetGlobal hipModuleGetGlobal_fn; + t_hipModuleGetTexRef hipModuleGetTexRef_fn; + t_hipModuleLaunchCooperativeKernel hipModuleLaunchCooperativeKernel_fn; + t_hipModuleLaunchCooperativeKernelMultiDevice + hipModuleLaunchCooperativeKernelMultiDevice_fn; + t_hipModuleLaunchKernel hipModuleLaunchKernel_fn; + t_hipModuleLoad hipModuleLoad_fn; + t_hipModuleLoadData hipModuleLoadData_fn; + t_hipModuleLoadDataEx hipModuleLoadDataEx_fn; + t_hipModuleOccupancyMaxActiveBlocksPerMultiprocessor + hipModuleOccupancyMaxActiveBlocksPerMultiprocessor_fn; + t_hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags_fn; + t_hipModuleOccupancyMaxPotentialBlockSize + hipModuleOccupancyMaxPotentialBlockSize_fn; + t_hipModuleOccupancyMaxPotentialBlockSizeWithFlags + hipModuleOccupancyMaxPotentialBlockSizeWithFlags_fn; + t_hipModuleUnload hipModuleUnload_fn; + t_hipOccupancyMaxActiveBlocksPerMultiprocessor + hipOccupancyMaxActiveBlocksPerMultiprocessor_fn; + t_hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags_fn; + t_hipOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize_fn; + t_hipPeekAtLastError hipPeekAtLastError_fn; + t_hipPointerGetAttribute hipPointerGetAttribute_fn; + t_hipPointerGetAttributes hipPointerGetAttributes_fn; + t_hipPointerSetAttribute hipPointerSetAttribute_fn; + t_hipProfilerStart hipProfilerStart_fn; + t_hipProfilerStop hipProfilerStop_fn; + t_hipRuntimeGetVersion hipRuntimeGetVersion_fn; + t_hipSetDevice hipSetDevice_fn; + t_hipSetDeviceFlags hipSetDeviceFlags_fn; + t_hipSetupArgument hipSetupArgument_fn; + t_hipSignalExternalSemaphoresAsync hipSignalExternalSemaphoresAsync_fn; + t_hipStreamAddCallback hipStreamAddCallback_fn; + t_hipStreamAttachMemAsync hipStreamAttachMemAsync_fn; + t_hipStreamBeginCapture hipStreamBeginCapture_fn; + t_hipStreamCreate hipStreamCreate_fn; + t_hipStreamCreateWithFlags hipStreamCreateWithFlags_fn; + t_hipStreamCreateWithPriority hipStreamCreateWithPriority_fn; + t_hipStreamDestroy hipStreamDestroy_fn; + t_hipStreamEndCapture hipStreamEndCapture_fn; + t_hipStreamGetCaptureInfo hipStreamGetCaptureInfo_fn; + t_hipStreamGetCaptureInfo_v2 hipStreamGetCaptureInfo_v2_fn; + t_hipStreamGetDevice hipStreamGetDevice_fn; + t_hipStreamGetFlags hipStreamGetFlags_fn; + t_hipStreamGetPriority hipStreamGetPriority_fn; + t_hipStreamIsCapturing hipStreamIsCapturing_fn; + t_hipStreamQuery hipStreamQuery_fn; + t_hipStreamSynchronize hipStreamSynchronize_fn; + t_hipStreamUpdateCaptureDependencies hipStreamUpdateCaptureDependencies_fn; + t_hipStreamWaitEvent hipStreamWaitEvent_fn; + t_hipStreamWaitValue32 hipStreamWaitValue32_fn; + t_hipStreamWaitValue64 hipStreamWaitValue64_fn; + t_hipStreamWriteValue32 hipStreamWriteValue32_fn; + t_hipStreamWriteValue64 hipStreamWriteValue64_fn; + t_hipTexObjectCreate hipTexObjectCreate_fn; + t_hipTexObjectDestroy hipTexObjectDestroy_fn; + t_hipTexObjectGetResourceDesc hipTexObjectGetResourceDesc_fn; + t_hipTexObjectGetResourceViewDesc hipTexObjectGetResourceViewDesc_fn; + t_hipTexObjectGetTextureDesc hipTexObjectGetTextureDesc_fn; + t_hipTexRefGetAddress hipTexRefGetAddress_fn; + t_hipTexRefGetAddressMode hipTexRefGetAddressMode_fn; + t_hipTexRefGetFilterMode hipTexRefGetFilterMode_fn; + t_hipTexRefGetFlags hipTexRefGetFlags_fn; + t_hipTexRefGetFormat hipTexRefGetFormat_fn; + t_hipTexRefGetMaxAnisotropy hipTexRefGetMaxAnisotropy_fn; + t_hipTexRefGetMipMappedArray hipTexRefGetMipMappedArray_fn; + t_hipTexRefGetMipmapFilterMode hipTexRefGetMipmapFilterMode_fn; + t_hipTexRefGetMipmapLevelBias hipTexRefGetMipmapLevelBias_fn; + t_hipTexRefGetMipmapLevelClamp hipTexRefGetMipmapLevelClamp_fn; + t_hipTexRefSetAddress hipTexRefSetAddress_fn; + t_hipTexRefSetAddress2D hipTexRefSetAddress2D_fn; + t_hipTexRefSetAddressMode hipTexRefSetAddressMode_fn; + t_hipTexRefSetArray hipTexRefSetArray_fn; + t_hipTexRefSetBorderColor hipTexRefSetBorderColor_fn; + t_hipTexRefSetFilterMode hipTexRefSetFilterMode_fn; + t_hipTexRefSetFlags hipTexRefSetFlags_fn; + t_hipTexRefSetFormat hipTexRefSetFormat_fn; + t_hipTexRefSetMaxAnisotropy hipTexRefSetMaxAnisotropy_fn; + t_hipTexRefSetMipmapFilterMode hipTexRefSetMipmapFilterMode_fn; + t_hipTexRefSetMipmapLevelBias hipTexRefSetMipmapLevelBias_fn; + t_hipTexRefSetMipmapLevelClamp hipTexRefSetMipmapLevelClamp_fn; + t_hipTexRefSetMipmappedArray hipTexRefSetMipmappedArray_fn; + t_hipThreadExchangeStreamCaptureMode hipThreadExchangeStreamCaptureMode_fn; + t_hipUnbindTexture hipUnbindTexture_fn; + t_hipUserObjectCreate hipUserObjectCreate_fn; + t_hipUserObjectRelease hipUserObjectRelease_fn; + t_hipUserObjectRetain hipUserObjectRetain_fn; + t_hipWaitExternalSemaphoresAsync hipWaitExternalSemaphoresAsync_fn; + t_hipCreateChannelDesc hipCreateChannelDesc_fn; + t_hipExtModuleLaunchKernel hipExtModuleLaunchKernel_fn; + t_hipHccModuleLaunchKernel hipHccModuleLaunchKernel_fn; + t_hipMemcpy_spt hipMemcpy_spt_fn; + t_hipMemcpyToSymbol_spt hipMemcpyToSymbol_spt_fn; + t_hipMemcpyFromSymbol_spt hipMemcpyFromSymbol_spt_fn; + t_hipMemcpy2D_spt hipMemcpy2D_spt_fn; + t_hipMemcpy2DFromArray_spt hipMemcpy2DFromArray_spt_fn; + t_hipMemcpy3D_spt hipMemcpy3D_spt_fn; + t_hipMemset_spt hipMemset_spt_fn; + t_hipMemsetAsync_spt hipMemsetAsync_spt_fn; + t_hipMemset2D_spt hipMemset2D_spt_fn; + t_hipMemset2DAsync_spt hipMemset2DAsync_spt_fn; + t_hipMemset3DAsync_spt hipMemset3DAsync_spt_fn; + t_hipMemset3D_spt hipMemset3D_spt_fn; + t_hipMemcpyAsync_spt hipMemcpyAsync_spt_fn; + t_hipMemcpy3DAsync_spt hipMemcpy3DAsync_spt_fn; + t_hipMemcpy2DAsync_spt hipMemcpy2DAsync_spt_fn; + t_hipMemcpyFromSymbolAsync_spt hipMemcpyFromSymbolAsync_spt_fn; + t_hipMemcpyToSymbolAsync_spt hipMemcpyToSymbolAsync_spt_fn; + t_hipMemcpyFromArray_spt hipMemcpyFromArray_spt_fn; + t_hipMemcpy2DToArray_spt hipMemcpy2DToArray_spt_fn; + t_hipMemcpy2DFromArrayAsync_spt hipMemcpy2DFromArrayAsync_spt_fn; + t_hipMemcpy2DToArrayAsync_spt hipMemcpy2DToArrayAsync_spt_fn; + t_hipStreamQuery_spt hipStreamQuery_spt_fn; + t_hipStreamSynchronize_spt hipStreamSynchronize_spt_fn; + t_hipStreamGetPriority_spt hipStreamGetPriority_spt_fn; + t_hipStreamWaitEvent_spt hipStreamWaitEvent_spt_fn; + t_hipStreamGetFlags_spt hipStreamGetFlags_spt_fn; + t_hipStreamAddCallback_spt hipStreamAddCallback_spt_fn; + t_hipEventRecord_spt hipEventRecord_spt_fn; + t_hipLaunchCooperativeKernel_spt hipLaunchCooperativeKernel_spt_fn; + t_hipLaunchKernel_spt hipLaunchKernel_spt_fn; + t_hipGraphLaunch_spt hipGraphLaunch_spt_fn; + t_hipStreamBeginCapture_spt hipStreamBeginCapture_spt_fn; + t_hipStreamEndCapture_spt hipStreamEndCapture_spt_fn; + t_hipStreamIsCapturing_spt hipStreamIsCapturing_spt_fn; + t_hipStreamGetCaptureInfo_spt hipStreamGetCaptureInfo_spt_fn; + t_hipStreamGetCaptureInfo_v2_spt hipStreamGetCaptureInfo_v2_spt_fn; + t_hipLaunchHostFunc_spt hipLaunchHostFunc_spt_fn; + t_hipGetStreamDeviceId hipGetStreamDeviceId_fn; + t_hipDrvGraphAddMemsetNode hipDrvGraphAddMemsetNode_fn; + t_hipGraphAddExternalSemaphoresWaitNode + hipGraphAddExternalSemaphoresWaitNode_fn; + t_hipGraphAddExternalSemaphoresSignalNode + hipGraphAddExternalSemaphoresSignalNode_fn; + t_hipGraphExternalSemaphoresSignalNodeSetParams + hipGraphExternalSemaphoresSignalNodeSetParams_fn; + t_hipGraphExternalSemaphoresWaitNodeSetParams + hipGraphExternalSemaphoresWaitNodeSetParams_fn; + t_hipGraphExternalSemaphoresSignalNodeGetParams + hipGraphExternalSemaphoresSignalNodeGetParams_fn; + t_hipGraphExternalSemaphoresWaitNodeGetParams + hipGraphExternalSemaphoresWaitNodeGetParams_fn; + t_hipGraphExecExternalSemaphoresSignalNodeSetParams + hipGraphExecExternalSemaphoresSignalNodeSetParams_fn; + t_hipGraphExecExternalSemaphoresWaitNodeSetParams + hipGraphExecExternalSemaphoresWaitNodeSetParams_fn; + t_hipGraphAddNode hipGraphAddNode_fn; + t_hipGraphInstantiateWithParams hipGraphInstantiateWithParams_fn; + t_hipExtGetLastError hipExtGetLastError_fn; + t_hipTexRefGetBorderColor hipTexRefGetBorderColor_fn; + t_hipTexRefGetArray hipTexRefGetArray_fn; + t_hipGetProcAddress hipGetProcAddress_fn; + t_hipStreamBeginCaptureToGraph hipStreamBeginCaptureToGraph_fn; + t_hipGetFuncBySymbol hipGetFuncBySymbol_fn; + t_hipSetValidDevices hipSetValidDevices_fn; + t_hipMemcpyAtoD hipMemcpyAtoD_fn; + t_hipMemcpyDtoA hipMemcpyDtoA_fn; + t_hipMemcpyAtoA hipMemcpyAtoA_fn; + t_hipMemcpyAtoHAsync hipMemcpyAtoHAsync_fn; + t_hipMemcpyHtoAAsync hipMemcpyHtoAAsync_fn; + t_hipMemcpy2DArrayToArray hipMemcpy2DArrayToArray_fn; +}; diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_assert.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_assert.h new file mode 100644 index 000000000..f7fc93055 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_assert.h @@ -0,0 +1,97 @@ +/* +Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +// abort +extern "C" __device__ inline __attribute__((weak)) void abort() { + __builtin_trap(); +} + +// The noinline attribute helps encapsulate the printf expansion, +// which otherwise has a performance impact just by increasing the +// size of the calling function. Additionally, the weak attribute +// allows the function to exist as a global although its definition is +// included in every compilation unit. +#if defined(_WIN32) || defined(_WIN64) +extern "C" __device__ __attribute__((noinline)) __attribute__((weak)) void +_wassert(const wchar_t *_msg, const wchar_t *_file, unsigned _line) { + // FIXME: Need `wchar_t` support to generate assertion message. + __builtin_trap(); +} +#else /* defined(_WIN32) || defined(_WIN64) */ +extern "C" __device__ __attribute__((noinline)) __attribute__((weak)) void +__assert_fail(const char *assertion, const char *file, unsigned int line, + const char *function) { + const char fmt[] = "%s:%u: %s: Device-side assertion `%s' failed.\n"; + + // strlen is not available as a built-in yet, so we create our own + // loop in a macro. With a string literal argument, the compiler + // usually manages to replace the loop with a constant. + // + // The macro does not check for null pointer, since all the string + // arguments are defined to be constant literals when called from + // the assert() macro. + // + // NOTE: The loop below includes the null terminator in the length + // as required by append_string_n(). +#define __hip_get_string_length(LEN, STR) \ + do { \ + const char *tmp = STR; \ + while (*tmp++) \ + ; \ + LEN = tmp - STR; \ + } while (0) + + auto msg = __ockl_fprintf_stderr_begin(); + int len = 0; + __hip_get_string_length(len, fmt); + msg = __ockl_fprintf_append_string_n(msg, fmt, len, 0); + __hip_get_string_length(len, file); + msg = __ockl_fprintf_append_string_n(msg, file, len, 0); + msg = __ockl_fprintf_append_args(msg, 1, line, 0, 0, 0, 0, 0, 0, 0); + __hip_get_string_length(len, function); + msg = __ockl_fprintf_append_string_n(msg, function, len, 0); + __hip_get_string_length(len, assertion); + __ockl_fprintf_append_string_n(msg, assertion, len, /* is_last = */ 1); + +#undef __hip_get_string_length + + __builtin_trap(); +} + +extern "C" __device__ __attribute__((noinline)) __attribute__((weak)) void +__assertfail() { + // ignore all the args for now. + __builtin_trap(); +} +#endif /* defined(_WIN32) || defined(_WIN64) */ + +#if defined(NDEBUG) +#define __hip_assert(COND) +#else +#define __hip_assert(COND) \ + do { \ + if (!(COND)) \ + __builtin_trap(); \ + } while (0) +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_cooperative_groups_helper.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_cooperative_groups_helper.h new file mode 100644 index 000000000..18978c2aa --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_cooperative_groups_helper.h @@ -0,0 +1,266 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +/** + * @file amd_detail/hip_cooperative_groups_helper.h + * + * @brief Device side implementation of cooperative group feature. + * + * Defines helper constructs and APIs which aid the types and device API + * wrappers defined within `amd_detail/hip_cooperative_groups.h`. + */ +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_HELPER_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_HELPER_H + +#if __cplusplus +#if !defined(__HIPCC_RTC__) +#include +#include // threadId, blockId +#endif +#if !defined(__align__) +#define __align__(x) __attribute__((aligned(x))) +#endif + +#if !defined(__CG_QUALIFIER__) +#define __CG_QUALIFIER__ __device__ __forceinline__ +#endif + +#if !defined(__CG_STATIC_QUALIFIER__) +#define __CG_STATIC_QUALIFIER__ __device__ static __forceinline__ +#endif + +#if !defined(_CG_STATIC_CONST_DECL_) +#define _CG_STATIC_CONST_DECL_ static constexpr +#endif + +#if __AMDGCN_WAVEFRONT_SIZE == 32 +using lane_mask = unsigned int; +#else +using lane_mask = unsigned long long int; +#endif + +namespace cooperative_groups { + +/* Global scope */ +template +using is_power_of_2 = std::integral_constant; + +template +using is_valid_wavefront = + std::integral_constant; + +template +using is_valid_tile_size = + std::integral_constant::value && + is_valid_wavefront::value>; + +template +using is_valid_type = + std::integral_constant::value || + std::is_floating_point::value>; + +namespace internal { + +/** + * @brief Enums representing different cooperative group types + * @note This enum is only applicable on Linux. + * + */ +typedef enum { + cg_invalid, + cg_multi_grid, + cg_grid, + cg_workgroup, + cg_tiled_group, + cg_coalesced_group +} group_type; +/** + * @ingroup CooperativeG + * @{ + * This section describes the cooperative groups functions of HIP runtime API. + * + * The cooperative groups provides flexible thread parallel programming + * algorithms, threads cooperate and share data to perform collective + * computations. + * + * @note Cooperative groups feature is implemented on Linux, under + * developement on Windows. + * + */ +/** + * + * @brief Functionalities related to multi-grid cooperative group type + * @note The following cooperative groups functions are only applicable on + * Linux. + * + */ +namespace multi_grid { + +__CG_STATIC_QUALIFIER__ uint32_t num_grids() { + return static_cast(__ockl_multi_grid_num_grids()); +} + +__CG_STATIC_QUALIFIER__ uint32_t grid_rank() { + return static_cast(__ockl_multi_grid_grid_rank()); +} + +__CG_STATIC_QUALIFIER__ uint32_t size() { + return static_cast(__ockl_multi_grid_size()); +} + +__CG_STATIC_QUALIFIER__ uint32_t thread_rank() { + return static_cast(__ockl_multi_grid_thread_rank()); +} + +__CG_STATIC_QUALIFIER__ bool is_valid() { + return static_cast(__ockl_multi_grid_is_valid()); +} + +__CG_STATIC_QUALIFIER__ void sync() { __ockl_multi_grid_sync(); } + +} // namespace multi_grid + +/** + * @brief Functionalities related to grid cooperative group type + * @note The following cooperative groups functions are only applicable on + * Linux. + */ +namespace grid { + +__CG_STATIC_QUALIFIER__ uint32_t size() { + return static_cast((blockDim.z * gridDim.z) * + (blockDim.y * gridDim.y) * + (blockDim.x * gridDim.x)); +} + +__CG_STATIC_QUALIFIER__ uint32_t thread_rank() { + // Compute global id of the workgroup to which the current thread belongs to + uint32_t blkIdx = + static_cast((blockIdx.z * gridDim.y * gridDim.x) + + (blockIdx.y * gridDim.x) + (blockIdx.x)); + + // Compute total number of threads being passed to reach current workgroup + // within grid + uint32_t num_threads_till_current_workgroup = + static_cast(blkIdx * (blockDim.x * blockDim.y * blockDim.z)); + + // Compute thread local rank within current workgroup + uint32_t local_thread_rank = + static_cast((threadIdx.z * blockDim.y * blockDim.x) + + (threadIdx.y * blockDim.x) + (threadIdx.x)); + + return (num_threads_till_current_workgroup + local_thread_rank); +} + +__CG_STATIC_QUALIFIER__ bool is_valid() { + return static_cast(__ockl_grid_is_valid()); +} + +__CG_STATIC_QUALIFIER__ void sync() { __ockl_grid_sync(); } + +} // namespace grid + +/** + * @brief Functionalities related to `workgroup` (thread_block in CUDA + * terminology) cooperative group type + * @note The following cooperative groups functions are only applicable on + * Linux. + */ +namespace workgroup { + +__CG_STATIC_QUALIFIER__ dim3 group_index() { + return (dim3(static_cast(blockIdx.x), + static_cast(blockIdx.y), + static_cast(blockIdx.z))); +} + +__CG_STATIC_QUALIFIER__ dim3 thread_index() { + return (dim3(static_cast(threadIdx.x), + static_cast(threadIdx.y), + static_cast(threadIdx.z))); +} + +__CG_STATIC_QUALIFIER__ uint32_t size() { + return (static_cast(blockDim.x * blockDim.y * blockDim.z)); +} + +__CG_STATIC_QUALIFIER__ uint32_t thread_rank() { + return (static_cast((threadIdx.z * blockDim.y * blockDim.x) + + (threadIdx.y * blockDim.x) + (threadIdx.x))); +} + +__CG_STATIC_QUALIFIER__ bool is_valid() { return true; } + +__CG_STATIC_QUALIFIER__ void sync() { __syncthreads(); } + +__CG_STATIC_QUALIFIER__ dim3 block_dim() { + return (dim3(static_cast(blockDim.x), + static_cast(blockDim.y), + static_cast(blockDim.z))); +} + +} // namespace workgroup + +namespace tiled_group { + +// enforce ordering for memory intructions +__CG_STATIC_QUALIFIER__ void sync() { + __builtin_amdgcn_fence(__ATOMIC_ACQ_REL, "agent"); +} + +} // namespace tiled_group + +namespace coalesced_group { + +// enforce ordering for memory intructions +__CG_STATIC_QUALIFIER__ void sync() { + __builtin_amdgcn_fence(__ATOMIC_ACQ_REL, "agent"); +} + +// Masked bit count +// +// For each thread, this function returns the number of active threads which +// have i-th bit of x set and come before the current thread. +__CG_STATIC_QUALIFIER__ unsigned int masked_bit_count(lane_mask x, + unsigned int add = 0) { + unsigned int counter = 0; +#if __AMDGCN_WAVEFRONT_SIZE == 32 + counter = __builtin_amdgcn_mbcnt_lo(x, add); +#else + counter = __builtin_amdgcn_mbcnt_lo(static_cast(x), add); + counter = __builtin_amdgcn_mbcnt_hi(static_cast(x >> 32), counter); +#endif + + return counter; +} + +} // namespace coalesced_group + +} // namespace internal + +} // namespace cooperative_groups +/** + * @} + */ + +#endif // __cplusplus +#endif // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_HELPER_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_fp16_gcc.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_fp16_gcc.h new file mode 100644 index 000000000..adc0f99d3 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_fp16_gcc.h @@ -0,0 +1,229 @@ +#pragma once + +#if defined(__cplusplus) +#include +#endif + +struct __half_raw { + unsigned short x; +}; + +struct __half2_raw { + unsigned short x; + unsigned short y; +}; + +#if defined(__cplusplus) +struct __half; + +__half __float2half(float); +float __half2float(__half); + +// BEGIN STRUCT __HALF +struct __half { +protected: + unsigned short __x; + +public: + // CREATORS + __half() = default; + __half(const __half_raw &x) : __x{x.x} {} +#if !defined(__HIP_NO_HALF_CONVERSIONS__) + __half(float x) : __x{__float2half(x).__x} {} + __half(double x) : __x{__float2half(x).__x} {} +#endif + __half(const __half &) = default; + __half(__half &&) = default; + ~__half() = default; + + // MANIPULATORS + __half &operator=(const __half &) = default; + __half &operator=(__half &&) = default; + __half &operator=(const __half_raw &x) { + __x = x.x; + return *this; + } +#if !defined(__HIP_NO_HALF_CONVERSIONS__) + __half &operator=(float x) { + __x = __float2half(x).__x; + return *this; + } + __half &operator=(double x) { return *this = static_cast(x); } +#endif + + // ACCESSORS + operator float() const { return __half2float(*this); } + operator __half_raw() const { return __half_raw{__x}; } +}; +// END STRUCT __HALF + +// BEGIN STRUCT __HALF2 +struct __half2 { +public: + __half x; + __half y; + + // CREATORS + __half2() = default; + __half2(const __half2_raw &ix) + : x{reinterpret_cast(ix.x)}, + y{reinterpret_cast(ix.y)} {} + __half2(const __half &ix, const __half &iy) : x{ix}, y{iy} {} + __half2(const __half2 &) = default; + __half2(__half2 &&) = default; + ~__half2() = default; + + // MANIPULATORS + __half2 &operator=(const __half2 &) = default; + __half2 &operator=(__half2 &&) = default; + __half2 &operator=(const __half2_raw &ix) { + x = reinterpret_cast(ix.x); + y = reinterpret_cast(ix.y); + return *this; + } + + // ACCESSORS + operator __half2_raw() const { + return __half2_raw{reinterpret_cast(x), + reinterpret_cast(y)}; + } +}; +// END STRUCT __HALF2 + +inline unsigned short __internal_float2half(float flt, unsigned int &sgn, + unsigned int &rem) { + unsigned int x{}; + std::memcpy(&x, &flt, sizeof(flt)); + + unsigned int u = (x & 0x7fffffffU); + sgn = ((x >> 16) & 0x8000U); + + // NaN/+Inf/-Inf + if (u >= 0x7f800000U) { + rem = 0; + return static_cast((u == 0x7f800000U) ? (sgn | 0x7c00U) + : 0x7fffU); + } + // Overflows + if (u > 0x477fefffU) { + rem = 0x80000000U; + return static_cast(sgn | 0x7bffU); + } + // Normal numbers + if (u >= 0x38800000U) { + rem = u << 19; + u -= 0x38000000U; + return static_cast(sgn | (u >> 13)); + } + // +0/-0 + if (u < 0x33000001U) { + rem = u; + return static_cast(sgn); + } + // Denormal numbers + unsigned int exponent = u >> 23; + unsigned int mantissa = (u & 0x7fffffU); + unsigned int shift = 0x7eU - exponent; + mantissa |= 0x800000U; + rem = mantissa << (32 - shift); + return static_cast(sgn | (mantissa >> shift)); +} + +inline __half __float2half(float x) { + __half_raw r; + unsigned int sgn{}; + unsigned int rem{}; + r.x = __internal_float2half(x, sgn, rem); + if (rem > 0x80000000U || (rem == 0x80000000U && (r.x & 0x1))) + ++r.x; + + return r; +} + +inline __half __float2half_rn(float x) { return __float2half(x); } + +inline __half __float2half_rz(float x) { + __half_raw r; + unsigned int sgn{}; + unsigned int rem{}; + r.x = __internal_float2half(x, sgn, rem); + + return r; +} + +inline __half __float2half_rd(float x) { + __half_raw r; + unsigned int sgn{}; + unsigned int rem{}; + r.x = __internal_float2half(x, sgn, rem); + if (rem && sgn) + ++r.x; + + return r; +} + +inline __half __float2half_ru(float x) { + __half_raw r; + unsigned int sgn{}; + unsigned int rem{}; + r.x = __internal_float2half(x, sgn, rem); + if (rem && !sgn) + ++r.x; + + return r; +} + +inline __half2 __float2half2_rn(float x) { + return __half2{__float2half_rn(x), __float2half_rn(x)}; +} + +inline __half2 __floats2half2_rn(float x, float y) { + return __half2{__float2half_rn(x), __float2half_rn(y)}; +} + +inline float __internal_half2float(unsigned short x) { + unsigned int sign = ((x >> 15) & 1); + unsigned int exponent = ((x >> 10) & 0x1f); + unsigned int mantissa = ((x & 0x3ff) << 13); + + if (exponent == 0x1fU) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffffU) : 0); + exponent = 0xffU; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71U; + do { + msb = (mantissa & 0x400000U); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffffU; /* 1.mantissa is implicit */ + } + } else { + exponent += 0x70U; + } + unsigned int u = ((sign << 31) | (exponent << 23) | mantissa); + float f; + memcpy(&f, &u, sizeof(u)); + + return f; +} + +inline float __half2float(__half x) { + return __internal_half2float(static_cast<__half_raw>(x).x); +} + +inline float __low2float(__half2 x) { + return __internal_half2float(static_cast<__half2_raw>(x).x); +} + +inline float __high2float(__half2 x) { + return __internal_half2float(static_cast<__half2_raw>(x).y); +} + +#if !defined(HIP_NO_HALF) +using half = __half; +using half2 = __half2; +#endif +#endif // defined(__cplusplus) diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_fp16_math_fwd.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_fp16_math_fwd.h new file mode 100644 index 000000000..80597d0ec --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_fp16_math_fwd.h @@ -0,0 +1,97 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +// /* +// Half Math Functions +// */ +#if !defined(__HIPCC_RTC__) +#include "host_defines.h" +#endif +#ifndef __CLANG_HIP_RUNTIME_WRAPPER_INCLUDED__ +extern "C" { +__device__ __attribute__((const)) _Float16 __ocml_ceil_f16(_Float16); +__device__ _Float16 __ocml_cos_f16(_Float16); +__device__ __attribute__((pure)) _Float16 __ocml_exp_f16(_Float16); +__device__ __attribute__((pure)) _Float16 __ocml_exp10_f16(_Float16); +__device__ __attribute__((pure)) _Float16 __ocml_exp2_f16(_Float16); +__device__ __attribute__((const)) _Float16 __ocml_floor_f16(_Float16); +__device__ __attribute__((const)) _Float16 __ocml_fma_f16(_Float16, _Float16, + _Float16); +__device__ __attribute__((const)) _Float16 __ocml_fabs_f16(_Float16); +__device__ __attribute__((const)) int __ocml_isinf_f16(_Float16); +__device__ __attribute__((const)) int __ocml_isnan_f16(_Float16); +__device__ __attribute__((pure)) _Float16 __ocml_log_f16(_Float16); +__device__ __attribute__((pure)) _Float16 __ocml_log10_f16(_Float16); +__device__ __attribute__((pure)) _Float16 __ocml_log2_f16(_Float16); +__device__ __attribute__((pure)) _Float16 __ocml_pown_f16(_Float16, int); +__device__ __attribute__((const)) _Float16 __ocml_rint_f16(_Float16); +__device__ __attribute__((const)) _Float16 __ocml_rsqrt_f16(_Float16); +__device__ _Float16 __ocml_sin_f16(_Float16); +__device__ __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16); +__device__ __attribute__((const)) _Float16 __ocml_trunc_f16(_Float16); +__device__ __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16); +__device__ __attribute__((const)) _Float16 __ocml_fmin_f16(_Float16, _Float16); + +typedef _Float16 __2f16 __attribute__((ext_vector_type(2))); +typedef short __2i16 __attribute__((ext_vector_type(2))); + +#if defined(__clang__) && defined(__HIP__) +__device__ __attribute__((const)) float __ockl_fdot2(__2f16 a, __2f16 b, + float c, bool s); +#endif + +__device__ __attribute__((const)) __2f16 __ocml_ceil_2f16(__2f16); +__device__ __attribute__((const)) __2f16 __ocml_fabs_2f16(__2f16); +__device__ __2f16 __ocml_cos_2f16(__2f16); +__device__ __attribute__((pure)) __2f16 __ocml_exp_2f16(__2f16); +__device__ __attribute__((pure)) __2f16 __ocml_exp10_2f16(__2f16); +__device__ __attribute__((pure)) __2f16 __ocml_exp2_2f16(__2f16); +__device__ __attribute__((const)) __2f16 __ocml_floor_2f16(__2f16); +__device__ + __attribute__((const)) __2f16 __ocml_fma_2f16(__2f16, __2f16, __2f16); +__device__ __attribute__((const)) __2i16 __ocml_isinf_2f16(__2f16); +__device__ __attribute__((const)) __2i16 __ocml_isnan_2f16(__2f16); +__device__ __attribute__((pure)) __2f16 __ocml_log_2f16(__2f16); +__device__ __attribute__((pure)) __2f16 __ocml_log10_2f16(__2f16); +__device__ __attribute__((pure)) __2f16 __ocml_log2_2f16(__2f16); +__device__ __attribute__((const)) __2f16 __ocml_rint_2f16(__2f16); +__device__ __attribute__((const)) __2f16 __ocml_rsqrt_2f16(__2f16); +__device__ __2f16 __ocml_sin_2f16(__2f16); +__device__ __attribute__((const)) __2f16 __ocml_sqrt_2f16(__2f16); +__device__ __attribute__((const)) __2f16 __ocml_trunc_2f16(__2f16); + +__device__ __attribute__((const)) _Float16 __ocml_cvtrtn_f16_f32(float); +__device__ __attribute__((const)) _Float16 __ocml_cvtrtp_f16_f32(float); +__device__ __attribute__((const)) _Float16 __ocml_cvtrtz_f16_f32(float); +} +#endif // !__CLANG_HIP_RUNTIME_WRAPPER_INCLUDED__ +// TODO: remove these after they get into clang header +// __clang_hip_libdevice_declares.h' +extern "C" { +__device__ __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16); +__device__ __attribute__((const)) _Float16 __ocml_fmin_f16(_Float16, _Float16); +__device__ __attribute__((const)) _Float16 __ocml_cvtrtn_f16_f32(float); +__device__ __attribute__((const)) _Float16 __ocml_cvtrtp_f16_f32(float); +__device__ __attribute__((const)) _Float16 __ocml_cvtrtz_f16_f32(float); +} diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_ldg.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_ldg.h new file mode 100644 index 000000000..ec9e6242c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_ldg.h @@ -0,0 +1,109 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_LDG_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_LDG_H + +#if __HIP_CLANG_ONLY__ +#include "amd_hip_vector_types.h" +#include "host_defines.h" + +__device__ inline static char __ldg(const char *ptr) { return *ptr; } + +__device__ inline static char2 __ldg(const char2 *ptr) { return *ptr; } + +__device__ inline static char4 __ldg(const char4 *ptr) { return *ptr; } + +__device__ inline static signed char __ldg(const signed char *ptr) { + return ptr[0]; +} + +__device__ inline static unsigned char __ldg(const unsigned char *ptr) { + return ptr[0]; +} + +__device__ inline static short __ldg(const short *ptr) { return ptr[0]; } + +__device__ inline static short2 __ldg(const short2 *ptr) { return ptr[0]; } + +__device__ inline static short4 __ldg(const short4 *ptr) { return ptr[0]; } + +__device__ inline static unsigned short __ldg(const unsigned short *ptr) { + return ptr[0]; +} + +__device__ inline static int __ldg(const int *ptr) { return ptr[0]; } + +__device__ inline static int2 __ldg(const int2 *ptr) { return ptr[0]; } + +__device__ inline static int4 __ldg(const int4 *ptr) { return ptr[0]; } + +__device__ inline static unsigned int __ldg(const unsigned int *ptr) { + return ptr[0]; +} + +__device__ inline static long __ldg(const long *ptr) { return ptr[0]; } + +__device__ inline static unsigned long __ldg(const unsigned long *ptr) { + return ptr[0]; +} + +__device__ inline static long long __ldg(const long long *ptr) { + return ptr[0]; +} + +__device__ inline static longlong2 __ldg(const longlong2 *ptr) { + return ptr[0]; +} + +__device__ inline static unsigned long long +__ldg(const unsigned long long *ptr) { + return ptr[0]; +} + +__device__ inline static uchar2 __ldg(const uchar2 *ptr) { return ptr[0]; } + +__device__ inline static uchar4 __ldg(const uchar4 *ptr) { return ptr[0]; } + +__device__ inline static ushort2 __ldg(const ushort2 *ptr) { return ptr[0]; } + +__device__ inline static uint2 __ldg(const uint2 *ptr) { return ptr[0]; } + +__device__ inline static uint4 __ldg(const uint4 *ptr) { return ptr[0]; } + +__device__ inline static ulonglong2 __ldg(const ulonglong2 *ptr) { + return ptr[0]; +} + +__device__ inline static float __ldg(const float *ptr) { return ptr[0]; } + +__device__ inline static float2 __ldg(const float2 *ptr) { return ptr[0]; } + +__device__ inline static float4 __ldg(const float4 *ptr) { return ptr[0]; } + +__device__ inline static double __ldg(const double *ptr) { return ptr[0]; } + +__device__ inline static double2 __ldg(const double2 *ptr) { return ptr[0]; } + +#endif // __HIP_CLANG_ONLY__ + +#endif // HIP_LDG_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_prof_str.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_prof_str.h new file mode 100644 index 000000000..06639f466 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_prof_str.h @@ -0,0 +1,17068 @@ +// Generated file. DO NOT EDIT. +// +// This file is automatically generated by the hip_prof_gen.py script. +// If changes are required, run the script and commit the updated file. + +#ifndef _HIP_PROF_STR_H +#define _HIP_PROF_STR_H +#define HIP_PROF_VER 1 + +#include "amd_hip_gl_interop.h" +#include +#include + +#define HIP_API_ID_CONCAT_HELPER(a, b) a##b +#define HIP_API_ID_CONCAT(a, b) HIP_API_ID_CONCAT_HELPER(a, b) + +// HIP API callbacks ID enumeration +enum hip_api_id_t { + HIP_API_ID_NONE = 0, + HIP_API_ID_FIRST = 1, + HIP_API_ID___hipPopCallConfiguration = 1, + HIP_API_ID___hipPushCallConfiguration = 2, + HIP_API_ID_hipArray3DCreate = 3, + HIP_API_ID_hipArrayCreate = 4, + HIP_API_ID_hipArrayDestroy = 5, + HIP_API_ID_hipChooseDeviceR0000 = 6, + HIP_API_ID_hipConfigureCall = 7, + HIP_API_ID_hipCtxCreate = 8, + HIP_API_ID_hipCtxDestroy = 9, + HIP_API_ID_hipCtxDisablePeerAccess = 10, + HIP_API_ID_hipCtxEnablePeerAccess = 11, + HIP_API_ID_hipCtxGetApiVersion = 12, + HIP_API_ID_hipCtxGetCacheConfig = 13, + HIP_API_ID_hipCtxGetCurrent = 14, + HIP_API_ID_hipCtxGetDevice = 15, + HIP_API_ID_hipCtxGetFlags = 16, + HIP_API_ID_hipCtxGetSharedMemConfig = 17, + HIP_API_ID_hipCtxPopCurrent = 18, + HIP_API_ID_hipCtxPushCurrent = 19, + HIP_API_ID_hipCtxSetCacheConfig = 20, + HIP_API_ID_hipCtxSetCurrent = 21, + HIP_API_ID_hipCtxSetSharedMemConfig = 22, + HIP_API_ID_hipCtxSynchronize = 23, + HIP_API_ID_hipDestroyExternalMemory = 24, + HIP_API_ID_hipDestroyExternalSemaphore = 25, + HIP_API_ID_hipDeviceCanAccessPeer = 26, + HIP_API_ID_hipDeviceComputeCapability = 27, + HIP_API_ID_hipDeviceDisablePeerAccess = 28, + HIP_API_ID_hipDeviceEnablePeerAccess = 29, + HIP_API_ID_hipDeviceGet = 30, + HIP_API_ID_hipDeviceGetAttribute = 31, + HIP_API_ID_hipDeviceGetByPCIBusId = 32, + HIP_API_ID_hipDeviceGetCacheConfig = 33, + HIP_API_ID_hipDeviceGetLimit = 34, + HIP_API_ID_hipDeviceGetName = 35, + HIP_API_ID_hipDeviceGetP2PAttribute = 36, + HIP_API_ID_hipDeviceGetPCIBusId = 37, + HIP_API_ID_hipDeviceGetSharedMemConfig = 38, + HIP_API_ID_hipDeviceGetStreamPriorityRange = 39, + HIP_API_ID_hipDevicePrimaryCtxGetState = 40, + HIP_API_ID_hipDevicePrimaryCtxRelease = 41, + HIP_API_ID_hipDevicePrimaryCtxReset = 42, + HIP_API_ID_hipDevicePrimaryCtxRetain = 43, + HIP_API_ID_hipDevicePrimaryCtxSetFlags = 44, + HIP_API_ID_hipDeviceReset = 45, + HIP_API_ID_hipDeviceSetCacheConfig = 46, + HIP_API_ID_hipDeviceSetSharedMemConfig = 47, + HIP_API_ID_hipDeviceSynchronize = 48, + HIP_API_ID_hipDeviceTotalMem = 49, + HIP_API_ID_RESERVED_50 = 50, + HIP_API_ID_hipDrvMemcpy2DUnaligned = 51, + HIP_API_ID_hipDrvMemcpy3D = 52, + HIP_API_ID_hipDrvMemcpy3DAsync = 53, + HIP_API_ID_hipEventCreate = 54, + HIP_API_ID_hipEventCreateWithFlags = 55, + HIP_API_ID_hipEventDestroy = 56, + HIP_API_ID_hipEventElapsedTime = 57, + HIP_API_ID_hipEventQuery = 58, + HIP_API_ID_hipEventRecord = 59, + HIP_API_ID_hipEventSynchronize = 60, + HIP_API_ID_hipExtGetLinkTypeAndHopCount = 61, + HIP_API_ID_hipExtLaunchKernel = 62, + HIP_API_ID_hipExtLaunchMultiKernelMultiDevice = 63, + HIP_API_ID_hipExtMallocWithFlags = 64, + HIP_API_ID_hipExtModuleLaunchKernel = 65, + HIP_API_ID_hipExtStreamCreateWithCUMask = 66, + HIP_API_ID_hipExtStreamGetCUMask = 67, + HIP_API_ID_hipExternalMemoryGetMappedBuffer = 68, + HIP_API_ID_hipFree = 69, + HIP_API_ID_hipFreeArray = 70, + HIP_API_ID_hipFreeHost = 71, + HIP_API_ID_hipFreeMipmappedArray = 72, + HIP_API_ID_hipFuncGetAttribute = 73, + HIP_API_ID_hipFuncGetAttributes = 74, + HIP_API_ID_hipFuncSetAttribute = 75, + HIP_API_ID_hipFuncSetCacheConfig = 76, + HIP_API_ID_hipFuncSetSharedMemConfig = 77, + HIP_API_ID_hipGetDevice = 78, + HIP_API_ID_hipGetDeviceCount = 79, + HIP_API_ID_hipGetDeviceFlags = 80, + HIP_API_ID_hipGetDevicePropertiesR0000 = 81, + HIP_API_ID_RESERVED_82 = 82, + HIP_API_ID_hipGetErrorString = 83, + HIP_API_ID_hipGetLastError = 84, + HIP_API_ID_hipGetMipmappedArrayLevel = 85, + HIP_API_ID_hipGetSymbolAddress = 86, + HIP_API_ID_hipGetSymbolSize = 87, + HIP_API_ID_hipHccModuleLaunchKernel = 88, + HIP_API_ID_hipHostAlloc = 89, + HIP_API_ID_hipHostFree = 90, + HIP_API_ID_hipHostGetDevicePointer = 91, + HIP_API_ID_hipHostGetFlags = 92, + HIP_API_ID_hipHostMalloc = 93, + HIP_API_ID_hipHostRegister = 94, + HIP_API_ID_hipHostUnregister = 95, + HIP_API_ID_hipImportExternalMemory = 96, + HIP_API_ID_hipImportExternalSemaphore = 97, + HIP_API_ID_hipInit = 98, + HIP_API_ID_hipIpcCloseMemHandle = 99, + HIP_API_ID_hipIpcGetEventHandle = 100, + HIP_API_ID_hipIpcGetMemHandle = 101, + HIP_API_ID_hipIpcOpenEventHandle = 102, + HIP_API_ID_hipIpcOpenMemHandle = 103, + HIP_API_ID_hipLaunchByPtr = 104, + HIP_API_ID_hipLaunchCooperativeKernel = 105, + HIP_API_ID_hipLaunchCooperativeKernelMultiDevice = 106, + HIP_API_ID_hipLaunchKernel = 107, + HIP_API_ID_hipMalloc = 108, + HIP_API_ID_hipMalloc3D = 109, + HIP_API_ID_hipMalloc3DArray = 110, + HIP_API_ID_hipMallocArray = 111, + HIP_API_ID_hipMallocHost = 112, + HIP_API_ID_hipMallocManaged = 113, + HIP_API_ID_hipMallocMipmappedArray = 114, + HIP_API_ID_hipMallocPitch = 115, + HIP_API_ID_hipMemAdvise = 116, + HIP_API_ID_hipMemAllocHost = 117, + HIP_API_ID_hipMemAllocPitch = 118, + HIP_API_ID_hipMemGetAddressRange = 119, + HIP_API_ID_hipMemGetInfo = 120, + HIP_API_ID_hipMemPrefetchAsync = 121, + HIP_API_ID_hipMemPtrGetInfo = 122, + HIP_API_ID_hipMemRangeGetAttribute = 123, + HIP_API_ID_hipMemRangeGetAttributes = 124, + HIP_API_ID_hipMemcpy = 125, + HIP_API_ID_hipMemcpy2D = 126, + HIP_API_ID_hipMemcpy2DAsync = 127, + HIP_API_ID_hipMemcpy2DFromArray = 128, + HIP_API_ID_hipMemcpy2DFromArrayAsync = 129, + HIP_API_ID_hipMemcpy2DToArray = 130, + HIP_API_ID_hipMemcpy2DToArrayAsync = 131, + HIP_API_ID_hipMemcpy3D = 132, + HIP_API_ID_hipMemcpy3DAsync = 133, + HIP_API_ID_hipMemcpyAsync = 134, + HIP_API_ID_hipMemcpyAtoH = 135, + HIP_API_ID_hipMemcpyDtoD = 136, + HIP_API_ID_hipMemcpyDtoDAsync = 137, + HIP_API_ID_hipMemcpyDtoH = 138, + HIP_API_ID_hipMemcpyDtoHAsync = 139, + HIP_API_ID_hipMemcpyFromArray = 140, + HIP_API_ID_hipMemcpyFromSymbol = 141, + HIP_API_ID_hipMemcpyFromSymbolAsync = 142, + HIP_API_ID_hipMemcpyHtoA = 143, + HIP_API_ID_hipMemcpyHtoD = 144, + HIP_API_ID_hipMemcpyHtoDAsync = 145, + HIP_API_ID_hipMemcpyParam2D = 146, + HIP_API_ID_hipMemcpyParam2DAsync = 147, + HIP_API_ID_hipMemcpyPeer = 148, + HIP_API_ID_hipMemcpyPeerAsync = 149, + HIP_API_ID_hipMemcpyToArray = 150, + HIP_API_ID_hipMemcpyToSymbol = 151, + HIP_API_ID_hipMemcpyToSymbolAsync = 152, + HIP_API_ID_hipMemcpyWithStream = 153, + HIP_API_ID_hipMemset = 154, + HIP_API_ID_hipMemset2D = 155, + HIP_API_ID_hipMemset2DAsync = 156, + HIP_API_ID_hipMemset3D = 157, + HIP_API_ID_hipMemset3DAsync = 158, + HIP_API_ID_hipMemsetAsync = 159, + HIP_API_ID_hipMemsetD16 = 160, + HIP_API_ID_hipMemsetD16Async = 161, + HIP_API_ID_hipMemsetD32 = 162, + HIP_API_ID_hipMemsetD32Async = 163, + HIP_API_ID_hipMemsetD8 = 164, + HIP_API_ID_hipMemsetD8Async = 165, + HIP_API_ID_hipModuleGetFunction = 166, + HIP_API_ID_hipModuleGetGlobal = 167, + HIP_API_ID_hipModuleGetTexRef = 168, + HIP_API_ID_hipModuleLaunchKernel = 169, + HIP_API_ID_hipModuleLoad = 170, + HIP_API_ID_hipModuleLoadData = 171, + HIP_API_ID_hipModuleLoadDataEx = 172, + HIP_API_ID_hipModuleOccupancyMaxActiveBlocksPerMultiprocessor = 173, + HIP_API_ID_hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags = 174, + HIP_API_ID_hipModuleOccupancyMaxPotentialBlockSize = 175, + HIP_API_ID_hipModuleOccupancyMaxPotentialBlockSizeWithFlags = 176, + HIP_API_ID_hipModuleUnload = 177, + HIP_API_ID_hipOccupancyMaxActiveBlocksPerMultiprocessor = 178, + HIP_API_ID_hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags = 179, + HIP_API_ID_hipOccupancyMaxPotentialBlockSize = 180, + HIP_API_ID_hipPeekAtLastError = 181, + HIP_API_ID_hipPointerGetAttributes = 182, + HIP_API_ID_hipProfilerStart = 183, + HIP_API_ID_hipProfilerStop = 184, + HIP_API_ID_RESERVED_185 = 185, + HIP_API_ID_hipSetDevice = 186, + HIP_API_ID_hipSetDeviceFlags = 187, + HIP_API_ID_hipSetupArgument = 188, + HIP_API_ID_hipSignalExternalSemaphoresAsync = 189, + HIP_API_ID_hipStreamAddCallback = 190, + HIP_API_ID_hipStreamAttachMemAsync = 191, + HIP_API_ID_hipStreamCreate = 192, + HIP_API_ID_hipStreamCreateWithFlags = 193, + HIP_API_ID_hipStreamCreateWithPriority = 194, + HIP_API_ID_hipStreamDestroy = 195, + HIP_API_ID_hipStreamGetFlags = 196, + HIP_API_ID_hipStreamGetPriority = 197, + HIP_API_ID_hipStreamQuery = 198, + HIP_API_ID_hipStreamSynchronize = 199, + HIP_API_ID_hipStreamWaitEvent = 200, + HIP_API_ID_hipStreamWaitValue32 = 201, + HIP_API_ID_hipStreamWaitValue64 = 202, + HIP_API_ID_hipStreamWriteValue32 = 203, + HIP_API_ID_hipStreamWriteValue64 = 204, + HIP_API_ID_hipWaitExternalSemaphoresAsync = 205, + HIP_API_ID_hipCreateSurfaceObject = 206, + HIP_API_ID_hipDestroySurfaceObject = 207, + HIP_API_ID_hipGraphAddKernelNode = 208, + HIP_API_ID_hipGraphAddMemcpyNode = 209, + HIP_API_ID_hipGraphAddMemsetNode = 210, + HIP_API_ID_hipGraphCreate = 211, + HIP_API_ID_hipGraphDestroy = 212, + HIP_API_ID_hipGraphExecDestroy = 213, + HIP_API_ID_hipGraphInstantiate = 214, + HIP_API_ID_hipGraphLaunch = 215, + HIP_API_ID_hipMipmappedArrayCreate = 216, + HIP_API_ID_hipMipmappedArrayDestroy = 217, + HIP_API_ID_hipMipmappedArrayGetLevel = 218, + HIP_API_ID_hipStreamBeginCapture = 219, + HIP_API_ID_hipStreamEndCapture = 220, + HIP_API_ID_hipTexRefGetAddress = 221, + HIP_API_ID_hipTexRefGetFlags = 222, + HIP_API_ID_hipTexRefGetFormat = 223, + HIP_API_ID_hipTexRefGetMaxAnisotropy = 224, + HIP_API_ID_hipTexRefGetMipMappedArray = 225, + HIP_API_ID_hipTexRefGetMipmapLevelBias = 226, + HIP_API_ID_hipTexRefGetMipmapLevelClamp = 227, + HIP_API_ID_hipTexRefSetAddress = 228, + HIP_API_ID_hipTexRefSetAddress2D = 229, + HIP_API_ID_hipTexRefSetBorderColor = 230, + HIP_API_ID_hipTexRefSetFormat = 231, + HIP_API_ID_hipTexRefSetMaxAnisotropy = 232, + HIP_API_ID_hipTexRefSetMipmapLevelClamp = 233, + HIP_API_ID_hipTexRefSetMipmappedArray = 234, + HIP_API_ID_hipGLGetDevices = 235, + HIP_API_ID_hipGraphAddDependencies = 236, + HIP_API_ID_hipGraphAddEmptyNode = 237, + HIP_API_ID_hipGraphExecKernelNodeSetParams = 238, + HIP_API_ID_hipGraphGetNodes = 239, + HIP_API_ID_hipGraphGetRootNodes = 240, + HIP_API_ID_hipGraphKernelNodeGetParams = 241, + HIP_API_ID_hipGraphKernelNodeSetParams = 242, + HIP_API_ID_hipGraphMemcpyNodeGetParams = 243, + HIP_API_ID_hipGraphMemcpyNodeSetParams = 244, + HIP_API_ID_hipGraphMemsetNodeGetParams = 245, + HIP_API_ID_hipGraphMemsetNodeSetParams = 246, + HIP_API_ID_hipGraphicsGLRegisterBuffer = 247, + HIP_API_ID_hipGraphicsMapResources = 248, + HIP_API_ID_hipGraphicsResourceGetMappedPointer = 249, + HIP_API_ID_hipGraphicsUnmapResources = 250, + HIP_API_ID_hipGraphicsUnregisterResource = 251, + HIP_API_ID_hipGraphAddChildGraphNode = 252, + HIP_API_ID_hipGraphAddEventRecordNode = 253, + HIP_API_ID_hipGraphAddEventWaitNode = 254, + HIP_API_ID_hipGraphAddHostNode = 255, + HIP_API_ID_hipGraphAddMemcpyNode1D = 256, + HIP_API_ID_hipGraphAddMemcpyNodeFromSymbol = 257, + HIP_API_ID_hipGraphAddMemcpyNodeToSymbol = 258, + HIP_API_ID_hipGraphChildGraphNodeGetGraph = 259, + HIP_API_ID_hipGraphClone = 260, + HIP_API_ID_hipGraphDestroyNode = 261, + HIP_API_ID_hipGraphEventRecordNodeGetEvent = 262, + HIP_API_ID_hipGraphEventRecordNodeSetEvent = 263, + HIP_API_ID_hipGraphEventWaitNodeGetEvent = 264, + HIP_API_ID_hipGraphEventWaitNodeSetEvent = 265, + HIP_API_ID_hipGraphExecChildGraphNodeSetParams = 266, + HIP_API_ID_hipGraphExecEventRecordNodeSetEvent = 267, + HIP_API_ID_hipGraphExecEventWaitNodeSetEvent = 268, + HIP_API_ID_hipGraphExecHostNodeSetParams = 269, + HIP_API_ID_hipGraphExecMemcpyNodeSetParams = 270, + HIP_API_ID_hipGraphExecMemcpyNodeSetParams1D = 271, + HIP_API_ID_hipGraphExecMemcpyNodeSetParamsFromSymbol = 272, + HIP_API_ID_hipGraphExecMemcpyNodeSetParamsToSymbol = 273, + HIP_API_ID_hipGraphExecMemsetNodeSetParams = 274, + HIP_API_ID_hipGraphExecUpdate = 275, + HIP_API_ID_hipGraphGetEdges = 276, + HIP_API_ID_hipGraphHostNodeGetParams = 277, + HIP_API_ID_hipGraphHostNodeSetParams = 278, + HIP_API_ID_hipGraphInstantiateWithFlags = 279, + HIP_API_ID_hipGraphMemcpyNodeSetParams1D = 280, + HIP_API_ID_hipGraphMemcpyNodeSetParamsFromSymbol = 281, + HIP_API_ID_hipGraphMemcpyNodeSetParamsToSymbol = 282, + HIP_API_ID_hipGraphNodeFindInClone = 283, + HIP_API_ID_hipGraphNodeGetDependencies = 284, + HIP_API_ID_hipGraphNodeGetDependentNodes = 285, + HIP_API_ID_hipGraphNodeGetType = 286, + HIP_API_ID_hipGraphRemoveDependencies = 287, + HIP_API_ID_hipStreamGetCaptureInfo = 288, + HIP_API_ID_hipStreamGetCaptureInfo_v2 = 289, + HIP_API_ID_hipStreamIsCapturing = 290, + HIP_API_ID_hipStreamUpdateCaptureDependencies = 291, + HIP_API_ID_hipDrvPointerGetAttributes = 292, + HIP_API_ID_hipGraphicsGLRegisterImage = 293, + HIP_API_ID_hipGraphicsSubResourceGetMappedArray = 294, + HIP_API_ID_hipPointerGetAttribute = 295, + HIP_API_ID_RESERVED_296 = 296, + HIP_API_ID_hipThreadExchangeStreamCaptureMode = 297, + HIP_API_ID_hipDeviceGetUuid = 298, + HIP_API_ID_hipGetChannelDesc = 299, + HIP_API_ID_hipGraphKernelNodeGetAttribute = 300, + HIP_API_ID_hipGraphKernelNodeSetAttribute = 301, + HIP_API_ID_hipLaunchHostFunc = 302, + HIP_API_ID_hipDeviceGetDefaultMemPool = 303, + HIP_API_ID_hipDeviceGetMemPool = 304, + HIP_API_ID_hipDeviceSetMemPool = 305, + HIP_API_ID_hipFreeAsync = 306, + HIP_API_ID_hipMallocAsync = 307, + HIP_API_ID_hipMallocFromPoolAsync = 308, + HIP_API_ID_hipMemPoolCreate = 309, + HIP_API_ID_hipMemPoolDestroy = 310, + HIP_API_ID_hipMemPoolExportPointer = 311, + HIP_API_ID_hipMemPoolExportToShareableHandle = 312, + HIP_API_ID_hipMemPoolGetAccess = 313, + HIP_API_ID_hipMemPoolGetAttribute = 314, + HIP_API_ID_hipMemPoolImportFromShareableHandle = 315, + HIP_API_ID_hipMemPoolImportPointer = 316, + HIP_API_ID_hipMemPoolSetAccess = 317, + HIP_API_ID_hipMemPoolSetAttribute = 318, + HIP_API_ID_hipMemPoolTrimTo = 319, + HIP_API_ID_hipMemAddressFree = 320, + HIP_API_ID_hipMemAddressReserve = 321, + HIP_API_ID_hipMemCreate = 322, + HIP_API_ID_hipMemExportToShareableHandle = 323, + HIP_API_ID_hipMemGetAccess = 324, + HIP_API_ID_hipMemGetAllocationGranularity = 325, + HIP_API_ID_hipMemGetAllocationPropertiesFromHandle = 326, + HIP_API_ID_hipMemImportFromShareableHandle = 327, + HIP_API_ID_hipMemMap = 328, + HIP_API_ID_hipMemMapArrayAsync = 329, + HIP_API_ID_hipMemRelease = 330, + HIP_API_ID_hipMemRetainAllocationHandle = 331, + HIP_API_ID_hipMemSetAccess = 332, + HIP_API_ID_hipMemUnmap = 333, + HIP_API_ID_hipDeviceSetGraphMemAttribute = 334, + HIP_API_ID_hipDeviceGetGraphMemAttribute = 335, + HIP_API_ID_hipDeviceGraphMemTrim = 336, + HIP_API_ID_hipDeviceSetLimit = 337, + HIP_API_ID_hipTexRefSetArray = 338, + HIP_API_ID_hipTexRefSetFlags = 339, + HIP_API_ID_hipTexRefSetMipmapLevelBias = 340, + HIP_API_ID_hipDriverGetVersion = 341, + HIP_API_ID_hipGraphUpload = 342, + HIP_API_ID_hipRuntimeGetVersion = 343, + HIP_API_ID_hipUserObjectCreate = 344, + HIP_API_ID_hipUserObjectRelease = 345, + HIP_API_ID_hipUserObjectRetain = 346, + HIP_API_ID_hipGraphRetainUserObject = 347, + HIP_API_ID_hipGraphReleaseUserObject = 348, + HIP_API_ID_hipGraphDebugDotPrint = 349, + HIP_API_ID_hipGraphKernelNodeCopyAttributes = 350, + HIP_API_ID_hipGraphNodeGetEnabled = 351, + HIP_API_ID_hipGraphNodeSetEnabled = 352, + HIP_API_ID_hipPointerSetAttribute = 353, + HIP_API_ID_hipGraphAddMemAllocNode = 354, + HIP_API_ID_hipGraphAddMemFreeNode = 355, + HIP_API_ID_hipGraphMemAllocNodeGetParams = 356, + HIP_API_ID_hipGraphMemFreeNodeGetParams = 357, + HIP_API_ID_hipModuleLaunchCooperativeKernel = 358, + HIP_API_ID_hipModuleLaunchCooperativeKernelMultiDevice = 359, + HIP_API_ID_hipArray3DGetDescriptor = 360, + HIP_API_ID_hipArrayGetDescriptor = 361, + HIP_API_ID_hipArrayGetInfo = 362, + HIP_API_ID_hipStreamGetDevice = 363, + HIP_API_ID_hipExternalMemoryGetMappedMipmappedArray = 364, + HIP_API_ID_hipChooseDeviceR0600 = 365, + HIP_API_ID_hipDrvGraphAddMemcpyNode = 366, + HIP_API_ID_hipDrvGraphAddMemsetNode = 367, + HIP_API_ID_RESERVED_368 = 368, + HIP_API_ID_RESERVED_369 = 369, + HIP_API_ID_hipGetDevicePropertiesR0600 = 370, + HIP_API_ID_hipGraphAddExternalSemaphoresSignalNode = 371, + HIP_API_ID_hipGraphAddExternalSemaphoresWaitNode = 372, + HIP_API_ID_hipGraphExecExternalSemaphoresSignalNodeSetParams = 373, + HIP_API_ID_hipGraphExecExternalSemaphoresWaitNodeSetParams = 374, + HIP_API_ID_hipGraphExternalSemaphoresSignalNodeGetParams = 375, + HIP_API_ID_hipGraphExternalSemaphoresSignalNodeSetParams = 376, + HIP_API_ID_hipGraphExternalSemaphoresWaitNodeGetParams = 377, + HIP_API_ID_hipGraphExternalSemaphoresWaitNodeSetParams = 378, + HIP_API_ID_hipExtGetLastError = 379, + HIP_API_ID_hipGraphAddNode = 380, + HIP_API_ID_hipGetProcAddress = 381, + HIP_API_ID_RESERVED_382 = 382, + HIP_API_ID_RESERVED_383 = 383, + HIP_API_ID_hipGraphInstantiateWithParams = 384, + HIP_API_ID_RESERVED_385 = 385, + HIP_API_ID_RESERVED_386 = 386, + HIP_API_ID_RESERVED_387 = 387, + HIP_API_ID_RESERVED_388 = 388, + HIP_API_ID_hipTexRefGetArray = 389, + HIP_API_ID_hipTexRefGetBorderColor = 390, + HIP_API_ID_hipStreamBeginCaptureToGraph = 391, + HIP_API_ID_hipGetFuncBySymbol = 392, + HIP_API_ID_hipMemcpy2DArrayToArray = 393, + HIP_API_ID_hipMemcpyAtoA = 394, + HIP_API_ID_hipMemcpyAtoD = 395, + HIP_API_ID_hipMemcpyAtoHAsync = 396, + HIP_API_ID_hipMemcpyDtoA = 397, + HIP_API_ID_hipMemcpyHtoAAsync = 398, + HIP_API_ID_hipSetValidDevices = 399, + HIP_API_ID_LAST = 399, + + HIP_API_ID_hipChooseDevice = HIP_API_ID_CONCAT(HIP_API_ID_, hipChooseDevice), + HIP_API_ID_hipGetDeviceProperties = + HIP_API_ID_CONCAT(HIP_API_ID_, hipGetDeviceProperties), + + HIP_API_ID_hipBindTexture = HIP_API_ID_NONE, + HIP_API_ID_hipBindTexture2D = HIP_API_ID_NONE, + HIP_API_ID_hipBindTextureToArray = HIP_API_ID_NONE, + HIP_API_ID_hipBindTextureToMipmappedArray = HIP_API_ID_NONE, + HIP_API_ID_hipCreateTextureObject = HIP_API_ID_NONE, + HIP_API_ID_hipDestroyTextureObject = HIP_API_ID_NONE, + HIP_API_ID_hipDeviceGetCount = HIP_API_ID_NONE, + HIP_API_ID_hipGetTextureAlignmentOffset = HIP_API_ID_NONE, + HIP_API_ID_hipGetTextureObjectResourceDesc = HIP_API_ID_NONE, + HIP_API_ID_hipGetTextureObjectResourceViewDesc = HIP_API_ID_NONE, + HIP_API_ID_hipGetTextureObjectTextureDesc = HIP_API_ID_NONE, + HIP_API_ID_hipGetTextureReference = HIP_API_ID_NONE, + HIP_API_ID_hipTexObjectCreate = HIP_API_ID_NONE, + HIP_API_ID_hipTexObjectDestroy = HIP_API_ID_NONE, + HIP_API_ID_hipTexObjectGetResourceDesc = HIP_API_ID_NONE, + HIP_API_ID_hipTexObjectGetResourceViewDesc = HIP_API_ID_NONE, + HIP_API_ID_hipTexObjectGetTextureDesc = HIP_API_ID_NONE, + HIP_API_ID_hipTexRefGetAddressMode = HIP_API_ID_NONE, + HIP_API_ID_hipTexRefGetFilterMode = HIP_API_ID_NONE, + HIP_API_ID_hipTexRefGetMipmapFilterMode = HIP_API_ID_NONE, + HIP_API_ID_hipTexRefSetAddressMode = HIP_API_ID_NONE, + HIP_API_ID_hipTexRefSetFilterMode = HIP_API_ID_NONE, + HIP_API_ID_hipTexRefSetMipmapFilterMode = HIP_API_ID_NONE, + HIP_API_ID_hipUnbindTexture = HIP_API_ID_NONE, +}; + +#undef HIP_API_ID_CONCAT_HELPER +#undef HIP_API_ID_CONCAT + +// Return the HIP API string for a given callback ID +static inline const char *hip_api_name(const uint32_t id) { + switch (id) { + case HIP_API_ID___hipPopCallConfiguration: + return "__hipPopCallConfiguration"; + case HIP_API_ID___hipPushCallConfiguration: + return "__hipPushCallConfiguration"; + case HIP_API_ID_hipArray3DCreate: + return "hipArray3DCreate"; + case HIP_API_ID_hipArray3DGetDescriptor: + return "hipArray3DGetDescriptor"; + case HIP_API_ID_hipArrayCreate: + return "hipArrayCreate"; + case HIP_API_ID_hipArrayDestroy: + return "hipArrayDestroy"; + case HIP_API_ID_hipArrayGetDescriptor: + return "hipArrayGetDescriptor"; + case HIP_API_ID_hipArrayGetInfo: + return "hipArrayGetInfo"; + case HIP_API_ID_hipChooseDeviceR0000: + return "hipChooseDeviceR0000"; + case HIP_API_ID_hipChooseDeviceR0600: + return "hipChooseDeviceR0600"; + case HIP_API_ID_hipConfigureCall: + return "hipConfigureCall"; + case HIP_API_ID_hipCreateSurfaceObject: + return "hipCreateSurfaceObject"; + case HIP_API_ID_hipCtxCreate: + return "hipCtxCreate"; + case HIP_API_ID_hipCtxDestroy: + return "hipCtxDestroy"; + case HIP_API_ID_hipCtxDisablePeerAccess: + return "hipCtxDisablePeerAccess"; + case HIP_API_ID_hipCtxEnablePeerAccess: + return "hipCtxEnablePeerAccess"; + case HIP_API_ID_hipCtxGetApiVersion: + return "hipCtxGetApiVersion"; + case HIP_API_ID_hipCtxGetCacheConfig: + return "hipCtxGetCacheConfig"; + case HIP_API_ID_hipCtxGetCurrent: + return "hipCtxGetCurrent"; + case HIP_API_ID_hipCtxGetDevice: + return "hipCtxGetDevice"; + case HIP_API_ID_hipCtxGetFlags: + return "hipCtxGetFlags"; + case HIP_API_ID_hipCtxGetSharedMemConfig: + return "hipCtxGetSharedMemConfig"; + case HIP_API_ID_hipCtxPopCurrent: + return "hipCtxPopCurrent"; + case HIP_API_ID_hipCtxPushCurrent: + return "hipCtxPushCurrent"; + case HIP_API_ID_hipCtxSetCacheConfig: + return "hipCtxSetCacheConfig"; + case HIP_API_ID_hipCtxSetCurrent: + return "hipCtxSetCurrent"; + case HIP_API_ID_hipCtxSetSharedMemConfig: + return "hipCtxSetSharedMemConfig"; + case HIP_API_ID_hipCtxSynchronize: + return "hipCtxSynchronize"; + case HIP_API_ID_hipDestroyExternalMemory: + return "hipDestroyExternalMemory"; + case HIP_API_ID_hipDestroyExternalSemaphore: + return "hipDestroyExternalSemaphore"; + case HIP_API_ID_hipDestroySurfaceObject: + return "hipDestroySurfaceObject"; + case HIP_API_ID_hipDeviceCanAccessPeer: + return "hipDeviceCanAccessPeer"; + case HIP_API_ID_hipDeviceComputeCapability: + return "hipDeviceComputeCapability"; + case HIP_API_ID_hipDeviceDisablePeerAccess: + return "hipDeviceDisablePeerAccess"; + case HIP_API_ID_hipDeviceEnablePeerAccess: + return "hipDeviceEnablePeerAccess"; + case HIP_API_ID_hipDeviceGet: + return "hipDeviceGet"; + case HIP_API_ID_hipDeviceGetAttribute: + return "hipDeviceGetAttribute"; + case HIP_API_ID_hipDeviceGetByPCIBusId: + return "hipDeviceGetByPCIBusId"; + case HIP_API_ID_hipDeviceGetCacheConfig: + return "hipDeviceGetCacheConfig"; + case HIP_API_ID_hipDeviceGetDefaultMemPool: + return "hipDeviceGetDefaultMemPool"; + case HIP_API_ID_hipDeviceGetGraphMemAttribute: + return "hipDeviceGetGraphMemAttribute"; + case HIP_API_ID_hipDeviceGetLimit: + return "hipDeviceGetLimit"; + case HIP_API_ID_hipDeviceGetMemPool: + return "hipDeviceGetMemPool"; + case HIP_API_ID_hipDeviceGetName: + return "hipDeviceGetName"; + case HIP_API_ID_hipDeviceGetP2PAttribute: + return "hipDeviceGetP2PAttribute"; + case HIP_API_ID_hipDeviceGetPCIBusId: + return "hipDeviceGetPCIBusId"; + case HIP_API_ID_hipDeviceGetSharedMemConfig: + return "hipDeviceGetSharedMemConfig"; + case HIP_API_ID_hipDeviceGetStreamPriorityRange: + return "hipDeviceGetStreamPriorityRange"; + case HIP_API_ID_hipDeviceGetUuid: + return "hipDeviceGetUuid"; + case HIP_API_ID_hipDeviceGraphMemTrim: + return "hipDeviceGraphMemTrim"; + case HIP_API_ID_hipDevicePrimaryCtxGetState: + return "hipDevicePrimaryCtxGetState"; + case HIP_API_ID_hipDevicePrimaryCtxRelease: + return "hipDevicePrimaryCtxRelease"; + case HIP_API_ID_hipDevicePrimaryCtxReset: + return "hipDevicePrimaryCtxReset"; + case HIP_API_ID_hipDevicePrimaryCtxRetain: + return "hipDevicePrimaryCtxRetain"; + case HIP_API_ID_hipDevicePrimaryCtxSetFlags: + return "hipDevicePrimaryCtxSetFlags"; + case HIP_API_ID_hipDeviceReset: + return "hipDeviceReset"; + case HIP_API_ID_hipDeviceSetCacheConfig: + return "hipDeviceSetCacheConfig"; + case HIP_API_ID_hipDeviceSetGraphMemAttribute: + return "hipDeviceSetGraphMemAttribute"; + case HIP_API_ID_hipDeviceSetLimit: + return "hipDeviceSetLimit"; + case HIP_API_ID_hipDeviceSetMemPool: + return "hipDeviceSetMemPool"; + case HIP_API_ID_hipDeviceSetSharedMemConfig: + return "hipDeviceSetSharedMemConfig"; + case HIP_API_ID_hipDeviceSynchronize: + return "hipDeviceSynchronize"; + case HIP_API_ID_hipDeviceTotalMem: + return "hipDeviceTotalMem"; + case HIP_API_ID_hipDriverGetVersion: + return "hipDriverGetVersion"; + case HIP_API_ID_hipDrvGraphAddMemcpyNode: + return "hipDrvGraphAddMemcpyNode"; + case HIP_API_ID_hipDrvGraphAddMemsetNode: + return "hipDrvGraphAddMemsetNode"; + case HIP_API_ID_hipDrvMemcpy2DUnaligned: + return "hipDrvMemcpy2DUnaligned"; + case HIP_API_ID_hipDrvMemcpy3D: + return "hipDrvMemcpy3D"; + case HIP_API_ID_hipDrvMemcpy3DAsync: + return "hipDrvMemcpy3DAsync"; + case HIP_API_ID_hipDrvPointerGetAttributes: + return "hipDrvPointerGetAttributes"; + case HIP_API_ID_hipEventCreate: + return "hipEventCreate"; + case HIP_API_ID_hipEventCreateWithFlags: + return "hipEventCreateWithFlags"; + case HIP_API_ID_hipEventDestroy: + return "hipEventDestroy"; + case HIP_API_ID_hipEventElapsedTime: + return "hipEventElapsedTime"; + case HIP_API_ID_hipEventQuery: + return "hipEventQuery"; + case HIP_API_ID_hipEventRecord: + return "hipEventRecord"; + case HIP_API_ID_hipEventSynchronize: + return "hipEventSynchronize"; + case HIP_API_ID_hipExtGetLastError: + return "hipExtGetLastError"; + case HIP_API_ID_hipExtGetLinkTypeAndHopCount: + return "hipExtGetLinkTypeAndHopCount"; + case HIP_API_ID_hipExtLaunchKernel: + return "hipExtLaunchKernel"; + case HIP_API_ID_hipExtLaunchMultiKernelMultiDevice: + return "hipExtLaunchMultiKernelMultiDevice"; + case HIP_API_ID_hipExtMallocWithFlags: + return "hipExtMallocWithFlags"; + case HIP_API_ID_hipExtModuleLaunchKernel: + return "hipExtModuleLaunchKernel"; + case HIP_API_ID_hipExtStreamCreateWithCUMask: + return "hipExtStreamCreateWithCUMask"; + case HIP_API_ID_hipExtStreamGetCUMask: + return "hipExtStreamGetCUMask"; + case HIP_API_ID_hipExternalMemoryGetMappedBuffer: + return "hipExternalMemoryGetMappedBuffer"; + case HIP_API_ID_hipExternalMemoryGetMappedMipmappedArray: + return "hipExternalMemoryGetMappedMipmappedArray"; + case HIP_API_ID_hipFree: + return "hipFree"; + case HIP_API_ID_hipFreeArray: + return "hipFreeArray"; + case HIP_API_ID_hipFreeAsync: + return "hipFreeAsync"; + case HIP_API_ID_hipFreeHost: + return "hipFreeHost"; + case HIP_API_ID_hipFreeMipmappedArray: + return "hipFreeMipmappedArray"; + case HIP_API_ID_hipFuncGetAttribute: + return "hipFuncGetAttribute"; + case HIP_API_ID_hipFuncGetAttributes: + return "hipFuncGetAttributes"; + case HIP_API_ID_hipFuncSetAttribute: + return "hipFuncSetAttribute"; + case HIP_API_ID_hipFuncSetCacheConfig: + return "hipFuncSetCacheConfig"; + case HIP_API_ID_hipFuncSetSharedMemConfig: + return "hipFuncSetSharedMemConfig"; + case HIP_API_ID_hipGLGetDevices: + return "hipGLGetDevices"; + case HIP_API_ID_hipGetChannelDesc: + return "hipGetChannelDesc"; + case HIP_API_ID_hipGetDevice: + return "hipGetDevice"; + case HIP_API_ID_hipGetDeviceCount: + return "hipGetDeviceCount"; + case HIP_API_ID_hipGetDeviceFlags: + return "hipGetDeviceFlags"; + case HIP_API_ID_hipGetDevicePropertiesR0000: + return "hipGetDevicePropertiesR0000"; + case HIP_API_ID_hipGetDevicePropertiesR0600: + return "hipGetDevicePropertiesR0600"; + case HIP_API_ID_hipGetErrorString: + return "hipGetErrorString"; + case HIP_API_ID_hipGetFuncBySymbol: + return "hipGetFuncBySymbol"; + case HIP_API_ID_hipGetLastError: + return "hipGetLastError"; + case HIP_API_ID_hipGetMipmappedArrayLevel: + return "hipGetMipmappedArrayLevel"; + case HIP_API_ID_hipGetProcAddress: + return "hipGetProcAddress"; + case HIP_API_ID_hipGetSymbolAddress: + return "hipGetSymbolAddress"; + case HIP_API_ID_hipGetSymbolSize: + return "hipGetSymbolSize"; + case HIP_API_ID_hipGraphAddChildGraphNode: + return "hipGraphAddChildGraphNode"; + case HIP_API_ID_hipGraphAddDependencies: + return "hipGraphAddDependencies"; + case HIP_API_ID_hipGraphAddEmptyNode: + return "hipGraphAddEmptyNode"; + case HIP_API_ID_hipGraphAddEventRecordNode: + return "hipGraphAddEventRecordNode"; + case HIP_API_ID_hipGraphAddEventWaitNode: + return "hipGraphAddEventWaitNode"; + case HIP_API_ID_hipGraphAddExternalSemaphoresSignalNode: + return "hipGraphAddExternalSemaphoresSignalNode"; + case HIP_API_ID_hipGraphAddExternalSemaphoresWaitNode: + return "hipGraphAddExternalSemaphoresWaitNode"; + case HIP_API_ID_hipGraphAddHostNode: + return "hipGraphAddHostNode"; + case HIP_API_ID_hipGraphAddKernelNode: + return "hipGraphAddKernelNode"; + case HIP_API_ID_hipGraphAddMemAllocNode: + return "hipGraphAddMemAllocNode"; + case HIP_API_ID_hipGraphAddMemFreeNode: + return "hipGraphAddMemFreeNode"; + case HIP_API_ID_hipGraphAddMemcpyNode: + return "hipGraphAddMemcpyNode"; + case HIP_API_ID_hipGraphAddMemcpyNode1D: + return "hipGraphAddMemcpyNode1D"; + case HIP_API_ID_hipGraphAddMemcpyNodeFromSymbol: + return "hipGraphAddMemcpyNodeFromSymbol"; + case HIP_API_ID_hipGraphAddMemcpyNodeToSymbol: + return "hipGraphAddMemcpyNodeToSymbol"; + case HIP_API_ID_hipGraphAddMemsetNode: + return "hipGraphAddMemsetNode"; + case HIP_API_ID_hipGraphAddNode: + return "hipGraphAddNode"; + case HIP_API_ID_hipGraphChildGraphNodeGetGraph: + return "hipGraphChildGraphNodeGetGraph"; + case HIP_API_ID_hipGraphClone: + return "hipGraphClone"; + case HIP_API_ID_hipGraphCreate: + return "hipGraphCreate"; + case HIP_API_ID_hipGraphDebugDotPrint: + return "hipGraphDebugDotPrint"; + case HIP_API_ID_hipGraphDestroy: + return "hipGraphDestroy"; + case HIP_API_ID_hipGraphDestroyNode: + return "hipGraphDestroyNode"; + case HIP_API_ID_hipGraphEventRecordNodeGetEvent: + return "hipGraphEventRecordNodeGetEvent"; + case HIP_API_ID_hipGraphEventRecordNodeSetEvent: + return "hipGraphEventRecordNodeSetEvent"; + case HIP_API_ID_hipGraphEventWaitNodeGetEvent: + return "hipGraphEventWaitNodeGetEvent"; + case HIP_API_ID_hipGraphEventWaitNodeSetEvent: + return "hipGraphEventWaitNodeSetEvent"; + case HIP_API_ID_hipGraphExecChildGraphNodeSetParams: + return "hipGraphExecChildGraphNodeSetParams"; + case HIP_API_ID_hipGraphExecDestroy: + return "hipGraphExecDestroy"; + case HIP_API_ID_hipGraphExecEventRecordNodeSetEvent: + return "hipGraphExecEventRecordNodeSetEvent"; + case HIP_API_ID_hipGraphExecEventWaitNodeSetEvent: + return "hipGraphExecEventWaitNodeSetEvent"; + case HIP_API_ID_hipGraphExecExternalSemaphoresSignalNodeSetParams: + return "hipGraphExecExternalSemaphoresSignalNodeSetParams"; + case HIP_API_ID_hipGraphExecExternalSemaphoresWaitNodeSetParams: + return "hipGraphExecExternalSemaphoresWaitNodeSetParams"; + case HIP_API_ID_hipGraphExecHostNodeSetParams: + return "hipGraphExecHostNodeSetParams"; + case HIP_API_ID_hipGraphExecKernelNodeSetParams: + return "hipGraphExecKernelNodeSetParams"; + case HIP_API_ID_hipGraphExecMemcpyNodeSetParams: + return "hipGraphExecMemcpyNodeSetParams"; + case HIP_API_ID_hipGraphExecMemcpyNodeSetParams1D: + return "hipGraphExecMemcpyNodeSetParams1D"; + case HIP_API_ID_hipGraphExecMemcpyNodeSetParamsFromSymbol: + return "hipGraphExecMemcpyNodeSetParamsFromSymbol"; + case HIP_API_ID_hipGraphExecMemcpyNodeSetParamsToSymbol: + return "hipGraphExecMemcpyNodeSetParamsToSymbol"; + case HIP_API_ID_hipGraphExecMemsetNodeSetParams: + return "hipGraphExecMemsetNodeSetParams"; + case HIP_API_ID_hipGraphExecUpdate: + return "hipGraphExecUpdate"; + case HIP_API_ID_hipGraphExternalSemaphoresSignalNodeGetParams: + return "hipGraphExternalSemaphoresSignalNodeGetParams"; + case HIP_API_ID_hipGraphExternalSemaphoresSignalNodeSetParams: + return "hipGraphExternalSemaphoresSignalNodeSetParams"; + case HIP_API_ID_hipGraphExternalSemaphoresWaitNodeGetParams: + return "hipGraphExternalSemaphoresWaitNodeGetParams"; + case HIP_API_ID_hipGraphExternalSemaphoresWaitNodeSetParams: + return "hipGraphExternalSemaphoresWaitNodeSetParams"; + case HIP_API_ID_hipGraphGetEdges: + return "hipGraphGetEdges"; + case HIP_API_ID_hipGraphGetNodes: + return "hipGraphGetNodes"; + case HIP_API_ID_hipGraphGetRootNodes: + return "hipGraphGetRootNodes"; + case HIP_API_ID_hipGraphHostNodeGetParams: + return "hipGraphHostNodeGetParams"; + case HIP_API_ID_hipGraphHostNodeSetParams: + return "hipGraphHostNodeSetParams"; + case HIP_API_ID_hipGraphInstantiate: + return "hipGraphInstantiate"; + case HIP_API_ID_hipGraphInstantiateWithFlags: + return "hipGraphInstantiateWithFlags"; + case HIP_API_ID_hipGraphInstantiateWithParams: + return "hipGraphInstantiateWithParams"; + case HIP_API_ID_hipGraphKernelNodeCopyAttributes: + return "hipGraphKernelNodeCopyAttributes"; + case HIP_API_ID_hipGraphKernelNodeGetAttribute: + return "hipGraphKernelNodeGetAttribute"; + case HIP_API_ID_hipGraphKernelNodeGetParams: + return "hipGraphKernelNodeGetParams"; + case HIP_API_ID_hipGraphKernelNodeSetAttribute: + return "hipGraphKernelNodeSetAttribute"; + case HIP_API_ID_hipGraphKernelNodeSetParams: + return "hipGraphKernelNodeSetParams"; + case HIP_API_ID_hipGraphLaunch: + return "hipGraphLaunch"; + case HIP_API_ID_hipGraphMemAllocNodeGetParams: + return "hipGraphMemAllocNodeGetParams"; + case HIP_API_ID_hipGraphMemFreeNodeGetParams: + return "hipGraphMemFreeNodeGetParams"; + case HIP_API_ID_hipGraphMemcpyNodeGetParams: + return "hipGraphMemcpyNodeGetParams"; + case HIP_API_ID_hipGraphMemcpyNodeSetParams: + return "hipGraphMemcpyNodeSetParams"; + case HIP_API_ID_hipGraphMemcpyNodeSetParams1D: + return "hipGraphMemcpyNodeSetParams1D"; + case HIP_API_ID_hipGraphMemcpyNodeSetParamsFromSymbol: + return "hipGraphMemcpyNodeSetParamsFromSymbol"; + case HIP_API_ID_hipGraphMemcpyNodeSetParamsToSymbol: + return "hipGraphMemcpyNodeSetParamsToSymbol"; + case HIP_API_ID_hipGraphMemsetNodeGetParams: + return "hipGraphMemsetNodeGetParams"; + case HIP_API_ID_hipGraphMemsetNodeSetParams: + return "hipGraphMemsetNodeSetParams"; + case HIP_API_ID_hipGraphNodeFindInClone: + return "hipGraphNodeFindInClone"; + case HIP_API_ID_hipGraphNodeGetDependencies: + return "hipGraphNodeGetDependencies"; + case HIP_API_ID_hipGraphNodeGetDependentNodes: + return "hipGraphNodeGetDependentNodes"; + case HIP_API_ID_hipGraphNodeGetEnabled: + return "hipGraphNodeGetEnabled"; + case HIP_API_ID_hipGraphNodeGetType: + return "hipGraphNodeGetType"; + case HIP_API_ID_hipGraphNodeSetEnabled: + return "hipGraphNodeSetEnabled"; + case HIP_API_ID_hipGraphReleaseUserObject: + return "hipGraphReleaseUserObject"; + case HIP_API_ID_hipGraphRemoveDependencies: + return "hipGraphRemoveDependencies"; + case HIP_API_ID_hipGraphRetainUserObject: + return "hipGraphRetainUserObject"; + case HIP_API_ID_hipGraphUpload: + return "hipGraphUpload"; + case HIP_API_ID_hipGraphicsGLRegisterBuffer: + return "hipGraphicsGLRegisterBuffer"; + case HIP_API_ID_hipGraphicsGLRegisterImage: + return "hipGraphicsGLRegisterImage"; + case HIP_API_ID_hipGraphicsMapResources: + return "hipGraphicsMapResources"; + case HIP_API_ID_hipGraphicsResourceGetMappedPointer: + return "hipGraphicsResourceGetMappedPointer"; + case HIP_API_ID_hipGraphicsSubResourceGetMappedArray: + return "hipGraphicsSubResourceGetMappedArray"; + case HIP_API_ID_hipGraphicsUnmapResources: + return "hipGraphicsUnmapResources"; + case HIP_API_ID_hipGraphicsUnregisterResource: + return "hipGraphicsUnregisterResource"; + case HIP_API_ID_hipHccModuleLaunchKernel: + return "hipHccModuleLaunchKernel"; + case HIP_API_ID_hipHostAlloc: + return "hipHostAlloc"; + case HIP_API_ID_hipHostFree: + return "hipHostFree"; + case HIP_API_ID_hipHostGetDevicePointer: + return "hipHostGetDevicePointer"; + case HIP_API_ID_hipHostGetFlags: + return "hipHostGetFlags"; + case HIP_API_ID_hipHostMalloc: + return "hipHostMalloc"; + case HIP_API_ID_hipHostRegister: + return "hipHostRegister"; + case HIP_API_ID_hipHostUnregister: + return "hipHostUnregister"; + case HIP_API_ID_hipImportExternalMemory: + return "hipImportExternalMemory"; + case HIP_API_ID_hipImportExternalSemaphore: + return "hipImportExternalSemaphore"; + case HIP_API_ID_hipInit: + return "hipInit"; + case HIP_API_ID_hipIpcCloseMemHandle: + return "hipIpcCloseMemHandle"; + case HIP_API_ID_hipIpcGetEventHandle: + return "hipIpcGetEventHandle"; + case HIP_API_ID_hipIpcGetMemHandle: + return "hipIpcGetMemHandle"; + case HIP_API_ID_hipIpcOpenEventHandle: + return "hipIpcOpenEventHandle"; + case HIP_API_ID_hipIpcOpenMemHandle: + return "hipIpcOpenMemHandle"; + case HIP_API_ID_hipLaunchByPtr: + return "hipLaunchByPtr"; + case HIP_API_ID_hipLaunchCooperativeKernel: + return "hipLaunchCooperativeKernel"; + case HIP_API_ID_hipLaunchCooperativeKernelMultiDevice: + return "hipLaunchCooperativeKernelMultiDevice"; + case HIP_API_ID_hipLaunchHostFunc: + return "hipLaunchHostFunc"; + case HIP_API_ID_hipLaunchKernel: + return "hipLaunchKernel"; + case HIP_API_ID_hipMalloc: + return "hipMalloc"; + case HIP_API_ID_hipMalloc3D: + return "hipMalloc3D"; + case HIP_API_ID_hipMalloc3DArray: + return "hipMalloc3DArray"; + case HIP_API_ID_hipMallocArray: + return "hipMallocArray"; + case HIP_API_ID_hipMallocAsync: + return "hipMallocAsync"; + case HIP_API_ID_hipMallocFromPoolAsync: + return "hipMallocFromPoolAsync"; + case HIP_API_ID_hipMallocHost: + return "hipMallocHost"; + case HIP_API_ID_hipMallocManaged: + return "hipMallocManaged"; + case HIP_API_ID_hipMallocMipmappedArray: + return "hipMallocMipmappedArray"; + case HIP_API_ID_hipMallocPitch: + return "hipMallocPitch"; + case HIP_API_ID_hipMemAddressFree: + return "hipMemAddressFree"; + case HIP_API_ID_hipMemAddressReserve: + return "hipMemAddressReserve"; + case HIP_API_ID_hipMemAdvise: + return "hipMemAdvise"; + case HIP_API_ID_hipMemAllocHost: + return "hipMemAllocHost"; + case HIP_API_ID_hipMemAllocPitch: + return "hipMemAllocPitch"; + case HIP_API_ID_hipMemCreate: + return "hipMemCreate"; + case HIP_API_ID_hipMemExportToShareableHandle: + return "hipMemExportToShareableHandle"; + case HIP_API_ID_hipMemGetAccess: + return "hipMemGetAccess"; + case HIP_API_ID_hipMemGetAddressRange: + return "hipMemGetAddressRange"; + case HIP_API_ID_hipMemGetAllocationGranularity: + return "hipMemGetAllocationGranularity"; + case HIP_API_ID_hipMemGetAllocationPropertiesFromHandle: + return "hipMemGetAllocationPropertiesFromHandle"; + case HIP_API_ID_hipMemGetInfo: + return "hipMemGetInfo"; + case HIP_API_ID_hipMemImportFromShareableHandle: + return "hipMemImportFromShareableHandle"; + case HIP_API_ID_hipMemMap: + return "hipMemMap"; + case HIP_API_ID_hipMemMapArrayAsync: + return "hipMemMapArrayAsync"; + case HIP_API_ID_hipMemPoolCreate: + return "hipMemPoolCreate"; + case HIP_API_ID_hipMemPoolDestroy: + return "hipMemPoolDestroy"; + case HIP_API_ID_hipMemPoolExportPointer: + return "hipMemPoolExportPointer"; + case HIP_API_ID_hipMemPoolExportToShareableHandle: + return "hipMemPoolExportToShareableHandle"; + case HIP_API_ID_hipMemPoolGetAccess: + return "hipMemPoolGetAccess"; + case HIP_API_ID_hipMemPoolGetAttribute: + return "hipMemPoolGetAttribute"; + case HIP_API_ID_hipMemPoolImportFromShareableHandle: + return "hipMemPoolImportFromShareableHandle"; + case HIP_API_ID_hipMemPoolImportPointer: + return "hipMemPoolImportPointer"; + case HIP_API_ID_hipMemPoolSetAccess: + return "hipMemPoolSetAccess"; + case HIP_API_ID_hipMemPoolSetAttribute: + return "hipMemPoolSetAttribute"; + case HIP_API_ID_hipMemPoolTrimTo: + return "hipMemPoolTrimTo"; + case HIP_API_ID_hipMemPrefetchAsync: + return "hipMemPrefetchAsync"; + case HIP_API_ID_hipMemPtrGetInfo: + return "hipMemPtrGetInfo"; + case HIP_API_ID_hipMemRangeGetAttribute: + return "hipMemRangeGetAttribute"; + case HIP_API_ID_hipMemRangeGetAttributes: + return "hipMemRangeGetAttributes"; + case HIP_API_ID_hipMemRelease: + return "hipMemRelease"; + case HIP_API_ID_hipMemRetainAllocationHandle: + return "hipMemRetainAllocationHandle"; + case HIP_API_ID_hipMemSetAccess: + return "hipMemSetAccess"; + case HIP_API_ID_hipMemUnmap: + return "hipMemUnmap"; + case HIP_API_ID_hipMemcpy: + return "hipMemcpy"; + case HIP_API_ID_hipMemcpy2D: + return "hipMemcpy2D"; + case HIP_API_ID_hipMemcpy2DArrayToArray: + return "hipMemcpy2DArrayToArray"; + case HIP_API_ID_hipMemcpy2DAsync: + return "hipMemcpy2DAsync"; + case HIP_API_ID_hipMemcpy2DFromArray: + return "hipMemcpy2DFromArray"; + case HIP_API_ID_hipMemcpy2DFromArrayAsync: + return "hipMemcpy2DFromArrayAsync"; + case HIP_API_ID_hipMemcpy2DToArray: + return "hipMemcpy2DToArray"; + case HIP_API_ID_hipMemcpy2DToArrayAsync: + return "hipMemcpy2DToArrayAsync"; + case HIP_API_ID_hipMemcpy3D: + return "hipMemcpy3D"; + case HIP_API_ID_hipMemcpy3DAsync: + return "hipMemcpy3DAsync"; + case HIP_API_ID_hipMemcpyAsync: + return "hipMemcpyAsync"; + case HIP_API_ID_hipMemcpyAtoA: + return "hipMemcpyAtoA"; + case HIP_API_ID_hipMemcpyAtoD: + return "hipMemcpyAtoD"; + case HIP_API_ID_hipMemcpyAtoH: + return "hipMemcpyAtoH"; + case HIP_API_ID_hipMemcpyAtoHAsync: + return "hipMemcpyAtoHAsync"; + case HIP_API_ID_hipMemcpyDtoA: + return "hipMemcpyDtoA"; + case HIP_API_ID_hipMemcpyDtoD: + return "hipMemcpyDtoD"; + case HIP_API_ID_hipMemcpyDtoDAsync: + return "hipMemcpyDtoDAsync"; + case HIP_API_ID_hipMemcpyDtoH: + return "hipMemcpyDtoH"; + case HIP_API_ID_hipMemcpyDtoHAsync: + return "hipMemcpyDtoHAsync"; + case HIP_API_ID_hipMemcpyFromArray: + return "hipMemcpyFromArray"; + case HIP_API_ID_hipMemcpyFromSymbol: + return "hipMemcpyFromSymbol"; + case HIP_API_ID_hipMemcpyFromSymbolAsync: + return "hipMemcpyFromSymbolAsync"; + case HIP_API_ID_hipMemcpyHtoA: + return "hipMemcpyHtoA"; + case HIP_API_ID_hipMemcpyHtoAAsync: + return "hipMemcpyHtoAAsync"; + case HIP_API_ID_hipMemcpyHtoD: + return "hipMemcpyHtoD"; + case HIP_API_ID_hipMemcpyHtoDAsync: + return "hipMemcpyHtoDAsync"; + case HIP_API_ID_hipMemcpyParam2D: + return "hipMemcpyParam2D"; + case HIP_API_ID_hipMemcpyParam2DAsync: + return "hipMemcpyParam2DAsync"; + case HIP_API_ID_hipMemcpyPeer: + return "hipMemcpyPeer"; + case HIP_API_ID_hipMemcpyPeerAsync: + return "hipMemcpyPeerAsync"; + case HIP_API_ID_hipMemcpyToArray: + return "hipMemcpyToArray"; + case HIP_API_ID_hipMemcpyToSymbol: + return "hipMemcpyToSymbol"; + case HIP_API_ID_hipMemcpyToSymbolAsync: + return "hipMemcpyToSymbolAsync"; + case HIP_API_ID_hipMemcpyWithStream: + return "hipMemcpyWithStream"; + case HIP_API_ID_hipMemset: + return "hipMemset"; + case HIP_API_ID_hipMemset2D: + return "hipMemset2D"; + case HIP_API_ID_hipMemset2DAsync: + return "hipMemset2DAsync"; + case HIP_API_ID_hipMemset3D: + return "hipMemset3D"; + case HIP_API_ID_hipMemset3DAsync: + return "hipMemset3DAsync"; + case HIP_API_ID_hipMemsetAsync: + return "hipMemsetAsync"; + case HIP_API_ID_hipMemsetD16: + return "hipMemsetD16"; + case HIP_API_ID_hipMemsetD16Async: + return "hipMemsetD16Async"; + case HIP_API_ID_hipMemsetD32: + return "hipMemsetD32"; + case HIP_API_ID_hipMemsetD32Async: + return "hipMemsetD32Async"; + case HIP_API_ID_hipMemsetD8: + return "hipMemsetD8"; + case HIP_API_ID_hipMemsetD8Async: + return "hipMemsetD8Async"; + case HIP_API_ID_hipMipmappedArrayCreate: + return "hipMipmappedArrayCreate"; + case HIP_API_ID_hipMipmappedArrayDestroy: + return "hipMipmappedArrayDestroy"; + case HIP_API_ID_hipMipmappedArrayGetLevel: + return "hipMipmappedArrayGetLevel"; + case HIP_API_ID_hipModuleGetFunction: + return "hipModuleGetFunction"; + case HIP_API_ID_hipModuleGetGlobal: + return "hipModuleGetGlobal"; + case HIP_API_ID_hipModuleGetTexRef: + return "hipModuleGetTexRef"; + case HIP_API_ID_hipModuleLaunchCooperativeKernel: + return "hipModuleLaunchCooperativeKernel"; + case HIP_API_ID_hipModuleLaunchCooperativeKernelMultiDevice: + return "hipModuleLaunchCooperativeKernelMultiDevice"; + case HIP_API_ID_hipModuleLaunchKernel: + return "hipModuleLaunchKernel"; + case HIP_API_ID_hipModuleLoad: + return "hipModuleLoad"; + case HIP_API_ID_hipModuleLoadData: + return "hipModuleLoadData"; + case HIP_API_ID_hipModuleLoadDataEx: + return "hipModuleLoadDataEx"; + case HIP_API_ID_hipModuleOccupancyMaxActiveBlocksPerMultiprocessor: + return "hipModuleOccupancyMaxActiveBlocksPerMultiprocessor"; + case HIP_API_ID_hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: + return "hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"; + case HIP_API_ID_hipModuleOccupancyMaxPotentialBlockSize: + return "hipModuleOccupancyMaxPotentialBlockSize"; + case HIP_API_ID_hipModuleOccupancyMaxPotentialBlockSizeWithFlags: + return "hipModuleOccupancyMaxPotentialBlockSizeWithFlags"; + case HIP_API_ID_hipModuleUnload: + return "hipModuleUnload"; + case HIP_API_ID_hipOccupancyMaxActiveBlocksPerMultiprocessor: + return "hipOccupancyMaxActiveBlocksPerMultiprocessor"; + case HIP_API_ID_hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: + return "hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"; + case HIP_API_ID_hipOccupancyMaxPotentialBlockSize: + return "hipOccupancyMaxPotentialBlockSize"; + case HIP_API_ID_hipPeekAtLastError: + return "hipPeekAtLastError"; + case HIP_API_ID_hipPointerGetAttribute: + return "hipPointerGetAttribute"; + case HIP_API_ID_hipPointerGetAttributes: + return "hipPointerGetAttributes"; + case HIP_API_ID_hipPointerSetAttribute: + return "hipPointerSetAttribute"; + case HIP_API_ID_hipProfilerStart: + return "hipProfilerStart"; + case HIP_API_ID_hipProfilerStop: + return "hipProfilerStop"; + case HIP_API_ID_hipRuntimeGetVersion: + return "hipRuntimeGetVersion"; + case HIP_API_ID_hipSetDevice: + return "hipSetDevice"; + case HIP_API_ID_hipSetDeviceFlags: + return "hipSetDeviceFlags"; + case HIP_API_ID_hipSetValidDevices: + return "hipSetValidDevices"; + case HIP_API_ID_hipSetupArgument: + return "hipSetupArgument"; + case HIP_API_ID_hipSignalExternalSemaphoresAsync: + return "hipSignalExternalSemaphoresAsync"; + case HIP_API_ID_hipStreamAddCallback: + return "hipStreamAddCallback"; + case HIP_API_ID_hipStreamAttachMemAsync: + return "hipStreamAttachMemAsync"; + case HIP_API_ID_hipStreamBeginCapture: + return "hipStreamBeginCapture"; + case HIP_API_ID_hipStreamBeginCaptureToGraph: + return "hipStreamBeginCaptureToGraph"; + case HIP_API_ID_hipStreamCreate: + return "hipStreamCreate"; + case HIP_API_ID_hipStreamCreateWithFlags: + return "hipStreamCreateWithFlags"; + case HIP_API_ID_hipStreamCreateWithPriority: + return "hipStreamCreateWithPriority"; + case HIP_API_ID_hipStreamDestroy: + return "hipStreamDestroy"; + case HIP_API_ID_hipStreamEndCapture: + return "hipStreamEndCapture"; + case HIP_API_ID_hipStreamGetCaptureInfo: + return "hipStreamGetCaptureInfo"; + case HIP_API_ID_hipStreamGetCaptureInfo_v2: + return "hipStreamGetCaptureInfo_v2"; + case HIP_API_ID_hipStreamGetDevice: + return "hipStreamGetDevice"; + case HIP_API_ID_hipStreamGetFlags: + return "hipStreamGetFlags"; + case HIP_API_ID_hipStreamGetPriority: + return "hipStreamGetPriority"; + case HIP_API_ID_hipStreamIsCapturing: + return "hipStreamIsCapturing"; + case HIP_API_ID_hipStreamQuery: + return "hipStreamQuery"; + case HIP_API_ID_hipStreamSynchronize: + return "hipStreamSynchronize"; + case HIP_API_ID_hipStreamUpdateCaptureDependencies: + return "hipStreamUpdateCaptureDependencies"; + case HIP_API_ID_hipStreamWaitEvent: + return "hipStreamWaitEvent"; + case HIP_API_ID_hipStreamWaitValue32: + return "hipStreamWaitValue32"; + case HIP_API_ID_hipStreamWaitValue64: + return "hipStreamWaitValue64"; + case HIP_API_ID_hipStreamWriteValue32: + return "hipStreamWriteValue32"; + case HIP_API_ID_hipStreamWriteValue64: + return "hipStreamWriteValue64"; + case HIP_API_ID_hipTexRefGetAddress: + return "hipTexRefGetAddress"; + case HIP_API_ID_hipTexRefGetArray: + return "hipTexRefGetArray"; + case HIP_API_ID_hipTexRefGetBorderColor: + return "hipTexRefGetBorderColor"; + case HIP_API_ID_hipTexRefGetFlags: + return "hipTexRefGetFlags"; + case HIP_API_ID_hipTexRefGetFormat: + return "hipTexRefGetFormat"; + case HIP_API_ID_hipTexRefGetMaxAnisotropy: + return "hipTexRefGetMaxAnisotropy"; + case HIP_API_ID_hipTexRefGetMipMappedArray: + return "hipTexRefGetMipMappedArray"; + case HIP_API_ID_hipTexRefGetMipmapLevelBias: + return "hipTexRefGetMipmapLevelBias"; + case HIP_API_ID_hipTexRefGetMipmapLevelClamp: + return "hipTexRefGetMipmapLevelClamp"; + case HIP_API_ID_hipTexRefSetAddress: + return "hipTexRefSetAddress"; + case HIP_API_ID_hipTexRefSetAddress2D: + return "hipTexRefSetAddress2D"; + case HIP_API_ID_hipTexRefSetArray: + return "hipTexRefSetArray"; + case HIP_API_ID_hipTexRefSetBorderColor: + return "hipTexRefSetBorderColor"; + case HIP_API_ID_hipTexRefSetFlags: + return "hipTexRefSetFlags"; + case HIP_API_ID_hipTexRefSetFormat: + return "hipTexRefSetFormat"; + case HIP_API_ID_hipTexRefSetMaxAnisotropy: + return "hipTexRefSetMaxAnisotropy"; + case HIP_API_ID_hipTexRefSetMipmapLevelBias: + return "hipTexRefSetMipmapLevelBias"; + case HIP_API_ID_hipTexRefSetMipmapLevelClamp: + return "hipTexRefSetMipmapLevelClamp"; + case HIP_API_ID_hipTexRefSetMipmappedArray: + return "hipTexRefSetMipmappedArray"; + case HIP_API_ID_hipThreadExchangeStreamCaptureMode: + return "hipThreadExchangeStreamCaptureMode"; + case HIP_API_ID_hipUserObjectCreate: + return "hipUserObjectCreate"; + case HIP_API_ID_hipUserObjectRelease: + return "hipUserObjectRelease"; + case HIP_API_ID_hipUserObjectRetain: + return "hipUserObjectRetain"; + case HIP_API_ID_hipWaitExternalSemaphoresAsync: + return "hipWaitExternalSemaphoresAsync"; + }; + return "unknown"; +}; + +#include +// Return the HIP API callback ID for a given name +static inline uint32_t hipApiIdByName(const char *name) { + if (strcmp("__hipPopCallConfiguration", name) == 0) + return HIP_API_ID___hipPopCallConfiguration; + if (strcmp("__hipPushCallConfiguration", name) == 0) + return HIP_API_ID___hipPushCallConfiguration; + if (strcmp("hipArray3DCreate", name) == 0) + return HIP_API_ID_hipArray3DCreate; + if (strcmp("hipArray3DGetDescriptor", name) == 0) + return HIP_API_ID_hipArray3DGetDescriptor; + if (strcmp("hipArrayCreate", name) == 0) + return HIP_API_ID_hipArrayCreate; + if (strcmp("hipArrayDestroy", name) == 0) + return HIP_API_ID_hipArrayDestroy; + if (strcmp("hipArrayGetDescriptor", name) == 0) + return HIP_API_ID_hipArrayGetDescriptor; + if (strcmp("hipArrayGetInfo", name) == 0) + return HIP_API_ID_hipArrayGetInfo; + if (strcmp("hipChooseDeviceR0000", name) == 0) + return HIP_API_ID_hipChooseDeviceR0000; + if (strcmp("hipChooseDeviceR0600", name) == 0) + return HIP_API_ID_hipChooseDeviceR0600; + if (strcmp("hipConfigureCall", name) == 0) + return HIP_API_ID_hipConfigureCall; + if (strcmp("hipCreateSurfaceObject", name) == 0) + return HIP_API_ID_hipCreateSurfaceObject; + if (strcmp("hipCtxCreate", name) == 0) + return HIP_API_ID_hipCtxCreate; + if (strcmp("hipCtxDestroy", name) == 0) + return HIP_API_ID_hipCtxDestroy; + if (strcmp("hipCtxDisablePeerAccess", name) == 0) + return HIP_API_ID_hipCtxDisablePeerAccess; + if (strcmp("hipCtxEnablePeerAccess", name) == 0) + return HIP_API_ID_hipCtxEnablePeerAccess; + if (strcmp("hipCtxGetApiVersion", name) == 0) + return HIP_API_ID_hipCtxGetApiVersion; + if (strcmp("hipCtxGetCacheConfig", name) == 0) + return HIP_API_ID_hipCtxGetCacheConfig; + if (strcmp("hipCtxGetCurrent", name) == 0) + return HIP_API_ID_hipCtxGetCurrent; + if (strcmp("hipCtxGetDevice", name) == 0) + return HIP_API_ID_hipCtxGetDevice; + if (strcmp("hipCtxGetFlags", name) == 0) + return HIP_API_ID_hipCtxGetFlags; + if (strcmp("hipCtxGetSharedMemConfig", name) == 0) + return HIP_API_ID_hipCtxGetSharedMemConfig; + if (strcmp("hipCtxPopCurrent", name) == 0) + return HIP_API_ID_hipCtxPopCurrent; + if (strcmp("hipCtxPushCurrent", name) == 0) + return HIP_API_ID_hipCtxPushCurrent; + if (strcmp("hipCtxSetCacheConfig", name) == 0) + return HIP_API_ID_hipCtxSetCacheConfig; + if (strcmp("hipCtxSetCurrent", name) == 0) + return HIP_API_ID_hipCtxSetCurrent; + if (strcmp("hipCtxSetSharedMemConfig", name) == 0) + return HIP_API_ID_hipCtxSetSharedMemConfig; + if (strcmp("hipCtxSynchronize", name) == 0) + return HIP_API_ID_hipCtxSynchronize; + if (strcmp("hipDestroyExternalMemory", name) == 0) + return HIP_API_ID_hipDestroyExternalMemory; + if (strcmp("hipDestroyExternalSemaphore", name) == 0) + return HIP_API_ID_hipDestroyExternalSemaphore; + if (strcmp("hipDestroySurfaceObject", name) == 0) + return HIP_API_ID_hipDestroySurfaceObject; + if (strcmp("hipDeviceCanAccessPeer", name) == 0) + return HIP_API_ID_hipDeviceCanAccessPeer; + if (strcmp("hipDeviceComputeCapability", name) == 0) + return HIP_API_ID_hipDeviceComputeCapability; + if (strcmp("hipDeviceDisablePeerAccess", name) == 0) + return HIP_API_ID_hipDeviceDisablePeerAccess; + if (strcmp("hipDeviceEnablePeerAccess", name) == 0) + return HIP_API_ID_hipDeviceEnablePeerAccess; + if (strcmp("hipDeviceGet", name) == 0) + return HIP_API_ID_hipDeviceGet; + if (strcmp("hipDeviceGetAttribute", name) == 0) + return HIP_API_ID_hipDeviceGetAttribute; + if (strcmp("hipDeviceGetByPCIBusId", name) == 0) + return HIP_API_ID_hipDeviceGetByPCIBusId; + if (strcmp("hipDeviceGetCacheConfig", name) == 0) + return HIP_API_ID_hipDeviceGetCacheConfig; + if (strcmp("hipDeviceGetDefaultMemPool", name) == 0) + return HIP_API_ID_hipDeviceGetDefaultMemPool; + if (strcmp("hipDeviceGetGraphMemAttribute", name) == 0) + return HIP_API_ID_hipDeviceGetGraphMemAttribute; + if (strcmp("hipDeviceGetLimit", name) == 0) + return HIP_API_ID_hipDeviceGetLimit; + if (strcmp("hipDeviceGetMemPool", name) == 0) + return HIP_API_ID_hipDeviceGetMemPool; + if (strcmp("hipDeviceGetName", name) == 0) + return HIP_API_ID_hipDeviceGetName; + if (strcmp("hipDeviceGetP2PAttribute", name) == 0) + return HIP_API_ID_hipDeviceGetP2PAttribute; + if (strcmp("hipDeviceGetPCIBusId", name) == 0) + return HIP_API_ID_hipDeviceGetPCIBusId; + if (strcmp("hipDeviceGetSharedMemConfig", name) == 0) + return HIP_API_ID_hipDeviceGetSharedMemConfig; + if (strcmp("hipDeviceGetStreamPriorityRange", name) == 0) + return HIP_API_ID_hipDeviceGetStreamPriorityRange; + if (strcmp("hipDeviceGetUuid", name) == 0) + return HIP_API_ID_hipDeviceGetUuid; + if (strcmp("hipDeviceGraphMemTrim", name) == 0) + return HIP_API_ID_hipDeviceGraphMemTrim; + if (strcmp("hipDevicePrimaryCtxGetState", name) == 0) + return HIP_API_ID_hipDevicePrimaryCtxGetState; + if (strcmp("hipDevicePrimaryCtxRelease", name) == 0) + return HIP_API_ID_hipDevicePrimaryCtxRelease; + if (strcmp("hipDevicePrimaryCtxReset", name) == 0) + return HIP_API_ID_hipDevicePrimaryCtxReset; + if (strcmp("hipDevicePrimaryCtxRetain", name) == 0) + return HIP_API_ID_hipDevicePrimaryCtxRetain; + if (strcmp("hipDevicePrimaryCtxSetFlags", name) == 0) + return HIP_API_ID_hipDevicePrimaryCtxSetFlags; + if (strcmp("hipDeviceReset", name) == 0) + return HIP_API_ID_hipDeviceReset; + if (strcmp("hipDeviceSetCacheConfig", name) == 0) + return HIP_API_ID_hipDeviceSetCacheConfig; + if (strcmp("hipDeviceSetGraphMemAttribute", name) == 0) + return HIP_API_ID_hipDeviceSetGraphMemAttribute; + if (strcmp("hipDeviceSetLimit", name) == 0) + return HIP_API_ID_hipDeviceSetLimit; + if (strcmp("hipDeviceSetMemPool", name) == 0) + return HIP_API_ID_hipDeviceSetMemPool; + if (strcmp("hipDeviceSetSharedMemConfig", name) == 0) + return HIP_API_ID_hipDeviceSetSharedMemConfig; + if (strcmp("hipDeviceSynchronize", name) == 0) + return HIP_API_ID_hipDeviceSynchronize; + if (strcmp("hipDeviceTotalMem", name) == 0) + return HIP_API_ID_hipDeviceTotalMem; + if (strcmp("hipDriverGetVersion", name) == 0) + return HIP_API_ID_hipDriverGetVersion; + if (strcmp("hipDrvGraphAddMemcpyNode", name) == 0) + return HIP_API_ID_hipDrvGraphAddMemcpyNode; + if (strcmp("hipDrvGraphAddMemsetNode", name) == 0) + return HIP_API_ID_hipDrvGraphAddMemsetNode; + if (strcmp("hipDrvMemcpy2DUnaligned", name) == 0) + return HIP_API_ID_hipDrvMemcpy2DUnaligned; + if (strcmp("hipDrvMemcpy3D", name) == 0) + return HIP_API_ID_hipDrvMemcpy3D; + if (strcmp("hipDrvMemcpy3DAsync", name) == 0) + return HIP_API_ID_hipDrvMemcpy3DAsync; + if (strcmp("hipDrvPointerGetAttributes", name) == 0) + return HIP_API_ID_hipDrvPointerGetAttributes; + if (strcmp("hipEventCreate", name) == 0) + return HIP_API_ID_hipEventCreate; + if (strcmp("hipEventCreateWithFlags", name) == 0) + return HIP_API_ID_hipEventCreateWithFlags; + if (strcmp("hipEventDestroy", name) == 0) + return HIP_API_ID_hipEventDestroy; + if (strcmp("hipEventElapsedTime", name) == 0) + return HIP_API_ID_hipEventElapsedTime; + if (strcmp("hipEventQuery", name) == 0) + return HIP_API_ID_hipEventQuery; + if (strcmp("hipEventRecord", name) == 0) + return HIP_API_ID_hipEventRecord; + if (strcmp("hipEventSynchronize", name) == 0) + return HIP_API_ID_hipEventSynchronize; + if (strcmp("hipExtGetLastError", name) == 0) + return HIP_API_ID_hipExtGetLastError; + if (strcmp("hipExtGetLinkTypeAndHopCount", name) == 0) + return HIP_API_ID_hipExtGetLinkTypeAndHopCount; + if (strcmp("hipExtLaunchKernel", name) == 0) + return HIP_API_ID_hipExtLaunchKernel; + if (strcmp("hipExtLaunchMultiKernelMultiDevice", name) == 0) + return HIP_API_ID_hipExtLaunchMultiKernelMultiDevice; + if (strcmp("hipExtMallocWithFlags", name) == 0) + return HIP_API_ID_hipExtMallocWithFlags; + if (strcmp("hipExtModuleLaunchKernel", name) == 0) + return HIP_API_ID_hipExtModuleLaunchKernel; + if (strcmp("hipExtStreamCreateWithCUMask", name) == 0) + return HIP_API_ID_hipExtStreamCreateWithCUMask; + if (strcmp("hipExtStreamGetCUMask", name) == 0) + return HIP_API_ID_hipExtStreamGetCUMask; + if (strcmp("hipExternalMemoryGetMappedBuffer", name) == 0) + return HIP_API_ID_hipExternalMemoryGetMappedBuffer; + if (strcmp("hipExternalMemoryGetMappedMipmappedArray", name) == 0) + return HIP_API_ID_hipExternalMemoryGetMappedMipmappedArray; + if (strcmp("hipFree", name) == 0) + return HIP_API_ID_hipFree; + if (strcmp("hipFreeArray", name) == 0) + return HIP_API_ID_hipFreeArray; + if (strcmp("hipFreeAsync", name) == 0) + return HIP_API_ID_hipFreeAsync; + if (strcmp("hipFreeHost", name) == 0) + return HIP_API_ID_hipFreeHost; + if (strcmp("hipFreeMipmappedArray", name) == 0) + return HIP_API_ID_hipFreeMipmappedArray; + if (strcmp("hipFuncGetAttribute", name) == 0) + return HIP_API_ID_hipFuncGetAttribute; + if (strcmp("hipFuncGetAttributes", name) == 0) + return HIP_API_ID_hipFuncGetAttributes; + if (strcmp("hipFuncSetAttribute", name) == 0) + return HIP_API_ID_hipFuncSetAttribute; + if (strcmp("hipFuncSetCacheConfig", name) == 0) + return HIP_API_ID_hipFuncSetCacheConfig; + if (strcmp("hipFuncSetSharedMemConfig", name) == 0) + return HIP_API_ID_hipFuncSetSharedMemConfig; + if (strcmp("hipGLGetDevices", name) == 0) + return HIP_API_ID_hipGLGetDevices; + if (strcmp("hipGetChannelDesc", name) == 0) + return HIP_API_ID_hipGetChannelDesc; + if (strcmp("hipGetDevice", name) == 0) + return HIP_API_ID_hipGetDevice; + if (strcmp("hipGetDeviceCount", name) == 0) + return HIP_API_ID_hipGetDeviceCount; + if (strcmp("hipGetDeviceFlags", name) == 0) + return HIP_API_ID_hipGetDeviceFlags; + if (strcmp("hipGetDevicePropertiesR0000", name) == 0) + return HIP_API_ID_hipGetDevicePropertiesR0000; + if (strcmp("hipGetDevicePropertiesR0600", name) == 0) + return HIP_API_ID_hipGetDevicePropertiesR0600; + if (strcmp("hipGetErrorString", name) == 0) + return HIP_API_ID_hipGetErrorString; + if (strcmp("hipGetFuncBySymbol", name) == 0) + return HIP_API_ID_hipGetFuncBySymbol; + if (strcmp("hipGetLastError", name) == 0) + return HIP_API_ID_hipGetLastError; + if (strcmp("hipGetMipmappedArrayLevel", name) == 0) + return HIP_API_ID_hipGetMipmappedArrayLevel; + if (strcmp("hipGetProcAddress", name) == 0) + return HIP_API_ID_hipGetProcAddress; + if (strcmp("hipGetSymbolAddress", name) == 0) + return HIP_API_ID_hipGetSymbolAddress; + if (strcmp("hipGetSymbolSize", name) == 0) + return HIP_API_ID_hipGetSymbolSize; + if (strcmp("hipGraphAddChildGraphNode", name) == 0) + return HIP_API_ID_hipGraphAddChildGraphNode; + if (strcmp("hipGraphAddDependencies", name) == 0) + return HIP_API_ID_hipGraphAddDependencies; + if (strcmp("hipGraphAddEmptyNode", name) == 0) + return HIP_API_ID_hipGraphAddEmptyNode; + if (strcmp("hipGraphAddEventRecordNode", name) == 0) + return HIP_API_ID_hipGraphAddEventRecordNode; + if (strcmp("hipGraphAddEventWaitNode", name) == 0) + return HIP_API_ID_hipGraphAddEventWaitNode; + if (strcmp("hipGraphAddExternalSemaphoresSignalNode", name) == 0) + return HIP_API_ID_hipGraphAddExternalSemaphoresSignalNode; + if (strcmp("hipGraphAddExternalSemaphoresWaitNode", name) == 0) + return HIP_API_ID_hipGraphAddExternalSemaphoresWaitNode; + if (strcmp("hipGraphAddHostNode", name) == 0) + return HIP_API_ID_hipGraphAddHostNode; + if (strcmp("hipGraphAddKernelNode", name) == 0) + return HIP_API_ID_hipGraphAddKernelNode; + if (strcmp("hipGraphAddMemAllocNode", name) == 0) + return HIP_API_ID_hipGraphAddMemAllocNode; + if (strcmp("hipGraphAddMemFreeNode", name) == 0) + return HIP_API_ID_hipGraphAddMemFreeNode; + if (strcmp("hipGraphAddMemcpyNode", name) == 0) + return HIP_API_ID_hipGraphAddMemcpyNode; + if (strcmp("hipGraphAddMemcpyNode1D", name) == 0) + return HIP_API_ID_hipGraphAddMemcpyNode1D; + if (strcmp("hipGraphAddMemcpyNodeFromSymbol", name) == 0) + return HIP_API_ID_hipGraphAddMemcpyNodeFromSymbol; + if (strcmp("hipGraphAddMemcpyNodeToSymbol", name) == 0) + return HIP_API_ID_hipGraphAddMemcpyNodeToSymbol; + if (strcmp("hipGraphAddMemsetNode", name) == 0) + return HIP_API_ID_hipGraphAddMemsetNode; + if (strcmp("hipGraphAddNode", name) == 0) + return HIP_API_ID_hipGraphAddNode; + if (strcmp("hipGraphChildGraphNodeGetGraph", name) == 0) + return HIP_API_ID_hipGraphChildGraphNodeGetGraph; + if (strcmp("hipGraphClone", name) == 0) + return HIP_API_ID_hipGraphClone; + if (strcmp("hipGraphCreate", name) == 0) + return HIP_API_ID_hipGraphCreate; + if (strcmp("hipGraphDebugDotPrint", name) == 0) + return HIP_API_ID_hipGraphDebugDotPrint; + if (strcmp("hipGraphDestroy", name) == 0) + return HIP_API_ID_hipGraphDestroy; + if (strcmp("hipGraphDestroyNode", name) == 0) + return HIP_API_ID_hipGraphDestroyNode; + if (strcmp("hipGraphEventRecordNodeGetEvent", name) == 0) + return HIP_API_ID_hipGraphEventRecordNodeGetEvent; + if (strcmp("hipGraphEventRecordNodeSetEvent", name) == 0) + return HIP_API_ID_hipGraphEventRecordNodeSetEvent; + if (strcmp("hipGraphEventWaitNodeGetEvent", name) == 0) + return HIP_API_ID_hipGraphEventWaitNodeGetEvent; + if (strcmp("hipGraphEventWaitNodeSetEvent", name) == 0) + return HIP_API_ID_hipGraphEventWaitNodeSetEvent; + if (strcmp("hipGraphExecChildGraphNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphExecChildGraphNodeSetParams; + if (strcmp("hipGraphExecDestroy", name) == 0) + return HIP_API_ID_hipGraphExecDestroy; + if (strcmp("hipGraphExecEventRecordNodeSetEvent", name) == 0) + return HIP_API_ID_hipGraphExecEventRecordNodeSetEvent; + if (strcmp("hipGraphExecEventWaitNodeSetEvent", name) == 0) + return HIP_API_ID_hipGraphExecEventWaitNodeSetEvent; + if (strcmp("hipGraphExecExternalSemaphoresSignalNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphExecExternalSemaphoresSignalNodeSetParams; + if (strcmp("hipGraphExecExternalSemaphoresWaitNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphExecExternalSemaphoresWaitNodeSetParams; + if (strcmp("hipGraphExecHostNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphExecHostNodeSetParams; + if (strcmp("hipGraphExecKernelNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphExecKernelNodeSetParams; + if (strcmp("hipGraphExecMemcpyNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphExecMemcpyNodeSetParams; + if (strcmp("hipGraphExecMemcpyNodeSetParams1D", name) == 0) + return HIP_API_ID_hipGraphExecMemcpyNodeSetParams1D; + if (strcmp("hipGraphExecMemcpyNodeSetParamsFromSymbol", name) == 0) + return HIP_API_ID_hipGraphExecMemcpyNodeSetParamsFromSymbol; + if (strcmp("hipGraphExecMemcpyNodeSetParamsToSymbol", name) == 0) + return HIP_API_ID_hipGraphExecMemcpyNodeSetParamsToSymbol; + if (strcmp("hipGraphExecMemsetNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphExecMemsetNodeSetParams; + if (strcmp("hipGraphExecUpdate", name) == 0) + return HIP_API_ID_hipGraphExecUpdate; + if (strcmp("hipGraphExternalSemaphoresSignalNodeGetParams", name) == 0) + return HIP_API_ID_hipGraphExternalSemaphoresSignalNodeGetParams; + if (strcmp("hipGraphExternalSemaphoresSignalNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphExternalSemaphoresSignalNodeSetParams; + if (strcmp("hipGraphExternalSemaphoresWaitNodeGetParams", name) == 0) + return HIP_API_ID_hipGraphExternalSemaphoresWaitNodeGetParams; + if (strcmp("hipGraphExternalSemaphoresWaitNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphExternalSemaphoresWaitNodeSetParams; + if (strcmp("hipGraphGetEdges", name) == 0) + return HIP_API_ID_hipGraphGetEdges; + if (strcmp("hipGraphGetNodes", name) == 0) + return HIP_API_ID_hipGraphGetNodes; + if (strcmp("hipGraphGetRootNodes", name) == 0) + return HIP_API_ID_hipGraphGetRootNodes; + if (strcmp("hipGraphHostNodeGetParams", name) == 0) + return HIP_API_ID_hipGraphHostNodeGetParams; + if (strcmp("hipGraphHostNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphHostNodeSetParams; + if (strcmp("hipGraphInstantiate", name) == 0) + return HIP_API_ID_hipGraphInstantiate; + if (strcmp("hipGraphInstantiateWithFlags", name) == 0) + return HIP_API_ID_hipGraphInstantiateWithFlags; + if (strcmp("hipGraphInstantiateWithParams", name) == 0) + return HIP_API_ID_hipGraphInstantiateWithParams; + if (strcmp("hipGraphKernelNodeCopyAttributes", name) == 0) + return HIP_API_ID_hipGraphKernelNodeCopyAttributes; + if (strcmp("hipGraphKernelNodeGetAttribute", name) == 0) + return HIP_API_ID_hipGraphKernelNodeGetAttribute; + if (strcmp("hipGraphKernelNodeGetParams", name) == 0) + return HIP_API_ID_hipGraphKernelNodeGetParams; + if (strcmp("hipGraphKernelNodeSetAttribute", name) == 0) + return HIP_API_ID_hipGraphKernelNodeSetAttribute; + if (strcmp("hipGraphKernelNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphKernelNodeSetParams; + if (strcmp("hipGraphLaunch", name) == 0) + return HIP_API_ID_hipGraphLaunch; + if (strcmp("hipGraphMemAllocNodeGetParams", name) == 0) + return HIP_API_ID_hipGraphMemAllocNodeGetParams; + if (strcmp("hipGraphMemFreeNodeGetParams", name) == 0) + return HIP_API_ID_hipGraphMemFreeNodeGetParams; + if (strcmp("hipGraphMemcpyNodeGetParams", name) == 0) + return HIP_API_ID_hipGraphMemcpyNodeGetParams; + if (strcmp("hipGraphMemcpyNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphMemcpyNodeSetParams; + if (strcmp("hipGraphMemcpyNodeSetParams1D", name) == 0) + return HIP_API_ID_hipGraphMemcpyNodeSetParams1D; + if (strcmp("hipGraphMemcpyNodeSetParamsFromSymbol", name) == 0) + return HIP_API_ID_hipGraphMemcpyNodeSetParamsFromSymbol; + if (strcmp("hipGraphMemcpyNodeSetParamsToSymbol", name) == 0) + return HIP_API_ID_hipGraphMemcpyNodeSetParamsToSymbol; + if (strcmp("hipGraphMemsetNodeGetParams", name) == 0) + return HIP_API_ID_hipGraphMemsetNodeGetParams; + if (strcmp("hipGraphMemsetNodeSetParams", name) == 0) + return HIP_API_ID_hipGraphMemsetNodeSetParams; + if (strcmp("hipGraphNodeFindInClone", name) == 0) + return HIP_API_ID_hipGraphNodeFindInClone; + if (strcmp("hipGraphNodeGetDependencies", name) == 0) + return HIP_API_ID_hipGraphNodeGetDependencies; + if (strcmp("hipGraphNodeGetDependentNodes", name) == 0) + return HIP_API_ID_hipGraphNodeGetDependentNodes; + if (strcmp("hipGraphNodeGetEnabled", name) == 0) + return HIP_API_ID_hipGraphNodeGetEnabled; + if (strcmp("hipGraphNodeGetType", name) == 0) + return HIP_API_ID_hipGraphNodeGetType; + if (strcmp("hipGraphNodeSetEnabled", name) == 0) + return HIP_API_ID_hipGraphNodeSetEnabled; + if (strcmp("hipGraphReleaseUserObject", name) == 0) + return HIP_API_ID_hipGraphReleaseUserObject; + if (strcmp("hipGraphRemoveDependencies", name) == 0) + return HIP_API_ID_hipGraphRemoveDependencies; + if (strcmp("hipGraphRetainUserObject", name) == 0) + return HIP_API_ID_hipGraphRetainUserObject; + if (strcmp("hipGraphUpload", name) == 0) + return HIP_API_ID_hipGraphUpload; + if (strcmp("hipGraphicsGLRegisterBuffer", name) == 0) + return HIP_API_ID_hipGraphicsGLRegisterBuffer; + if (strcmp("hipGraphicsGLRegisterImage", name) == 0) + return HIP_API_ID_hipGraphicsGLRegisterImage; + if (strcmp("hipGraphicsMapResources", name) == 0) + return HIP_API_ID_hipGraphicsMapResources; + if (strcmp("hipGraphicsResourceGetMappedPointer", name) == 0) + return HIP_API_ID_hipGraphicsResourceGetMappedPointer; + if (strcmp("hipGraphicsSubResourceGetMappedArray", name) == 0) + return HIP_API_ID_hipGraphicsSubResourceGetMappedArray; + if (strcmp("hipGraphicsUnmapResources", name) == 0) + return HIP_API_ID_hipGraphicsUnmapResources; + if (strcmp("hipGraphicsUnregisterResource", name) == 0) + return HIP_API_ID_hipGraphicsUnregisterResource; + if (strcmp("hipHccModuleLaunchKernel", name) == 0) + return HIP_API_ID_hipHccModuleLaunchKernel; + if (strcmp("hipHostAlloc", name) == 0) + return HIP_API_ID_hipHostAlloc; + if (strcmp("hipHostFree", name) == 0) + return HIP_API_ID_hipHostFree; + if (strcmp("hipHostGetDevicePointer", name) == 0) + return HIP_API_ID_hipHostGetDevicePointer; + if (strcmp("hipHostGetFlags", name) == 0) + return HIP_API_ID_hipHostGetFlags; + if (strcmp("hipHostMalloc", name) == 0) + return HIP_API_ID_hipHostMalloc; + if (strcmp("hipHostRegister", name) == 0) + return HIP_API_ID_hipHostRegister; + if (strcmp("hipHostUnregister", name) == 0) + return HIP_API_ID_hipHostUnregister; + if (strcmp("hipImportExternalMemory", name) == 0) + return HIP_API_ID_hipImportExternalMemory; + if (strcmp("hipImportExternalSemaphore", name) == 0) + return HIP_API_ID_hipImportExternalSemaphore; + if (strcmp("hipInit", name) == 0) + return HIP_API_ID_hipInit; + if (strcmp("hipIpcCloseMemHandle", name) == 0) + return HIP_API_ID_hipIpcCloseMemHandle; + if (strcmp("hipIpcGetEventHandle", name) == 0) + return HIP_API_ID_hipIpcGetEventHandle; + if (strcmp("hipIpcGetMemHandle", name) == 0) + return HIP_API_ID_hipIpcGetMemHandle; + if (strcmp("hipIpcOpenEventHandle", name) == 0) + return HIP_API_ID_hipIpcOpenEventHandle; + if (strcmp("hipIpcOpenMemHandle", name) == 0) + return HIP_API_ID_hipIpcOpenMemHandle; + if (strcmp("hipLaunchByPtr", name) == 0) + return HIP_API_ID_hipLaunchByPtr; + if (strcmp("hipLaunchCooperativeKernel", name) == 0) + return HIP_API_ID_hipLaunchCooperativeKernel; + if (strcmp("hipLaunchCooperativeKernelMultiDevice", name) == 0) + return HIP_API_ID_hipLaunchCooperativeKernelMultiDevice; + if (strcmp("hipLaunchHostFunc", name) == 0) + return HIP_API_ID_hipLaunchHostFunc; + if (strcmp("hipLaunchKernel", name) == 0) + return HIP_API_ID_hipLaunchKernel; + if (strcmp("hipMalloc", name) == 0) + return HIP_API_ID_hipMalloc; + if (strcmp("hipMalloc3D", name) == 0) + return HIP_API_ID_hipMalloc3D; + if (strcmp("hipMalloc3DArray", name) == 0) + return HIP_API_ID_hipMalloc3DArray; + if (strcmp("hipMallocArray", name) == 0) + return HIP_API_ID_hipMallocArray; + if (strcmp("hipMallocAsync", name) == 0) + return HIP_API_ID_hipMallocAsync; + if (strcmp("hipMallocFromPoolAsync", name) == 0) + return HIP_API_ID_hipMallocFromPoolAsync; + if (strcmp("hipMallocHost", name) == 0) + return HIP_API_ID_hipMallocHost; + if (strcmp("hipMallocManaged", name) == 0) + return HIP_API_ID_hipMallocManaged; + if (strcmp("hipMallocMipmappedArray", name) == 0) + return HIP_API_ID_hipMallocMipmappedArray; + if (strcmp("hipMallocPitch", name) == 0) + return HIP_API_ID_hipMallocPitch; + if (strcmp("hipMemAddressFree", name) == 0) + return HIP_API_ID_hipMemAddressFree; + if (strcmp("hipMemAddressReserve", name) == 0) + return HIP_API_ID_hipMemAddressReserve; + if (strcmp("hipMemAdvise", name) == 0) + return HIP_API_ID_hipMemAdvise; + if (strcmp("hipMemAllocHost", name) == 0) + return HIP_API_ID_hipMemAllocHost; + if (strcmp("hipMemAllocPitch", name) == 0) + return HIP_API_ID_hipMemAllocPitch; + if (strcmp("hipMemCreate", name) == 0) + return HIP_API_ID_hipMemCreate; + if (strcmp("hipMemExportToShareableHandle", name) == 0) + return HIP_API_ID_hipMemExportToShareableHandle; + if (strcmp("hipMemGetAccess", name) == 0) + return HIP_API_ID_hipMemGetAccess; + if (strcmp("hipMemGetAddressRange", name) == 0) + return HIP_API_ID_hipMemGetAddressRange; + if (strcmp("hipMemGetAllocationGranularity", name) == 0) + return HIP_API_ID_hipMemGetAllocationGranularity; + if (strcmp("hipMemGetAllocationPropertiesFromHandle", name) == 0) + return HIP_API_ID_hipMemGetAllocationPropertiesFromHandle; + if (strcmp("hipMemGetInfo", name) == 0) + return HIP_API_ID_hipMemGetInfo; + if (strcmp("hipMemImportFromShareableHandle", name) == 0) + return HIP_API_ID_hipMemImportFromShareableHandle; + if (strcmp("hipMemMap", name) == 0) + return HIP_API_ID_hipMemMap; + if (strcmp("hipMemMapArrayAsync", name) == 0) + return HIP_API_ID_hipMemMapArrayAsync; + if (strcmp("hipMemPoolCreate", name) == 0) + return HIP_API_ID_hipMemPoolCreate; + if (strcmp("hipMemPoolDestroy", name) == 0) + return HIP_API_ID_hipMemPoolDestroy; + if (strcmp("hipMemPoolExportPointer", name) == 0) + return HIP_API_ID_hipMemPoolExportPointer; + if (strcmp("hipMemPoolExportToShareableHandle", name) == 0) + return HIP_API_ID_hipMemPoolExportToShareableHandle; + if (strcmp("hipMemPoolGetAccess", name) == 0) + return HIP_API_ID_hipMemPoolGetAccess; + if (strcmp("hipMemPoolGetAttribute", name) == 0) + return HIP_API_ID_hipMemPoolGetAttribute; + if (strcmp("hipMemPoolImportFromShareableHandle", name) == 0) + return HIP_API_ID_hipMemPoolImportFromShareableHandle; + if (strcmp("hipMemPoolImportPointer", name) == 0) + return HIP_API_ID_hipMemPoolImportPointer; + if (strcmp("hipMemPoolSetAccess", name) == 0) + return HIP_API_ID_hipMemPoolSetAccess; + if (strcmp("hipMemPoolSetAttribute", name) == 0) + return HIP_API_ID_hipMemPoolSetAttribute; + if (strcmp("hipMemPoolTrimTo", name) == 0) + return HIP_API_ID_hipMemPoolTrimTo; + if (strcmp("hipMemPrefetchAsync", name) == 0) + return HIP_API_ID_hipMemPrefetchAsync; + if (strcmp("hipMemPtrGetInfo", name) == 0) + return HIP_API_ID_hipMemPtrGetInfo; + if (strcmp("hipMemRangeGetAttribute", name) == 0) + return HIP_API_ID_hipMemRangeGetAttribute; + if (strcmp("hipMemRangeGetAttributes", name) == 0) + return HIP_API_ID_hipMemRangeGetAttributes; + if (strcmp("hipMemRelease", name) == 0) + return HIP_API_ID_hipMemRelease; + if (strcmp("hipMemRetainAllocationHandle", name) == 0) + return HIP_API_ID_hipMemRetainAllocationHandle; + if (strcmp("hipMemSetAccess", name) == 0) + return HIP_API_ID_hipMemSetAccess; + if (strcmp("hipMemUnmap", name) == 0) + return HIP_API_ID_hipMemUnmap; + if (strcmp("hipMemcpy", name) == 0) + return HIP_API_ID_hipMemcpy; + if (strcmp("hipMemcpy2D", name) == 0) + return HIP_API_ID_hipMemcpy2D; + if (strcmp("hipMemcpy2DArrayToArray", name) == 0) + return HIP_API_ID_hipMemcpy2DArrayToArray; + if (strcmp("hipMemcpy2DAsync", name) == 0) + return HIP_API_ID_hipMemcpy2DAsync; + if (strcmp("hipMemcpy2DFromArray", name) == 0) + return HIP_API_ID_hipMemcpy2DFromArray; + if (strcmp("hipMemcpy2DFromArrayAsync", name) == 0) + return HIP_API_ID_hipMemcpy2DFromArrayAsync; + if (strcmp("hipMemcpy2DToArray", name) == 0) + return HIP_API_ID_hipMemcpy2DToArray; + if (strcmp("hipMemcpy2DToArrayAsync", name) == 0) + return HIP_API_ID_hipMemcpy2DToArrayAsync; + if (strcmp("hipMemcpy3D", name) == 0) + return HIP_API_ID_hipMemcpy3D; + if (strcmp("hipMemcpy3DAsync", name) == 0) + return HIP_API_ID_hipMemcpy3DAsync; + if (strcmp("hipMemcpyAsync", name) == 0) + return HIP_API_ID_hipMemcpyAsync; + if (strcmp("hipMemcpyAtoA", name) == 0) + return HIP_API_ID_hipMemcpyAtoA; + if (strcmp("hipMemcpyAtoD", name) == 0) + return HIP_API_ID_hipMemcpyAtoD; + if (strcmp("hipMemcpyAtoH", name) == 0) + return HIP_API_ID_hipMemcpyAtoH; + if (strcmp("hipMemcpyAtoHAsync", name) == 0) + return HIP_API_ID_hipMemcpyAtoHAsync; + if (strcmp("hipMemcpyDtoA", name) == 0) + return HIP_API_ID_hipMemcpyDtoA; + if (strcmp("hipMemcpyDtoD", name) == 0) + return HIP_API_ID_hipMemcpyDtoD; + if (strcmp("hipMemcpyDtoDAsync", name) == 0) + return HIP_API_ID_hipMemcpyDtoDAsync; + if (strcmp("hipMemcpyDtoH", name) == 0) + return HIP_API_ID_hipMemcpyDtoH; + if (strcmp("hipMemcpyDtoHAsync", name) == 0) + return HIP_API_ID_hipMemcpyDtoHAsync; + if (strcmp("hipMemcpyFromArray", name) == 0) + return HIP_API_ID_hipMemcpyFromArray; + if (strcmp("hipMemcpyFromSymbol", name) == 0) + return HIP_API_ID_hipMemcpyFromSymbol; + if (strcmp("hipMemcpyFromSymbolAsync", name) == 0) + return HIP_API_ID_hipMemcpyFromSymbolAsync; + if (strcmp("hipMemcpyHtoA", name) == 0) + return HIP_API_ID_hipMemcpyHtoA; + if (strcmp("hipMemcpyHtoAAsync", name) == 0) + return HIP_API_ID_hipMemcpyHtoAAsync; + if (strcmp("hipMemcpyHtoD", name) == 0) + return HIP_API_ID_hipMemcpyHtoD; + if (strcmp("hipMemcpyHtoDAsync", name) == 0) + return HIP_API_ID_hipMemcpyHtoDAsync; + if (strcmp("hipMemcpyParam2D", name) == 0) + return HIP_API_ID_hipMemcpyParam2D; + if (strcmp("hipMemcpyParam2DAsync", name) == 0) + return HIP_API_ID_hipMemcpyParam2DAsync; + if (strcmp("hipMemcpyPeer", name) == 0) + return HIP_API_ID_hipMemcpyPeer; + if (strcmp("hipMemcpyPeerAsync", name) == 0) + return HIP_API_ID_hipMemcpyPeerAsync; + if (strcmp("hipMemcpyToArray", name) == 0) + return HIP_API_ID_hipMemcpyToArray; + if (strcmp("hipMemcpyToSymbol", name) == 0) + return HIP_API_ID_hipMemcpyToSymbol; + if (strcmp("hipMemcpyToSymbolAsync", name) == 0) + return HIP_API_ID_hipMemcpyToSymbolAsync; + if (strcmp("hipMemcpyWithStream", name) == 0) + return HIP_API_ID_hipMemcpyWithStream; + if (strcmp("hipMemset", name) == 0) + return HIP_API_ID_hipMemset; + if (strcmp("hipMemset2D", name) == 0) + return HIP_API_ID_hipMemset2D; + if (strcmp("hipMemset2DAsync", name) == 0) + return HIP_API_ID_hipMemset2DAsync; + if (strcmp("hipMemset3D", name) == 0) + return HIP_API_ID_hipMemset3D; + if (strcmp("hipMemset3DAsync", name) == 0) + return HIP_API_ID_hipMemset3DAsync; + if (strcmp("hipMemsetAsync", name) == 0) + return HIP_API_ID_hipMemsetAsync; + if (strcmp("hipMemsetD16", name) == 0) + return HIP_API_ID_hipMemsetD16; + if (strcmp("hipMemsetD16Async", name) == 0) + return HIP_API_ID_hipMemsetD16Async; + if (strcmp("hipMemsetD32", name) == 0) + return HIP_API_ID_hipMemsetD32; + if (strcmp("hipMemsetD32Async", name) == 0) + return HIP_API_ID_hipMemsetD32Async; + if (strcmp("hipMemsetD8", name) == 0) + return HIP_API_ID_hipMemsetD8; + if (strcmp("hipMemsetD8Async", name) == 0) + return HIP_API_ID_hipMemsetD8Async; + if (strcmp("hipMipmappedArrayCreate", name) == 0) + return HIP_API_ID_hipMipmappedArrayCreate; + if (strcmp("hipMipmappedArrayDestroy", name) == 0) + return HIP_API_ID_hipMipmappedArrayDestroy; + if (strcmp("hipMipmappedArrayGetLevel", name) == 0) + return HIP_API_ID_hipMipmappedArrayGetLevel; + if (strcmp("hipModuleGetFunction", name) == 0) + return HIP_API_ID_hipModuleGetFunction; + if (strcmp("hipModuleGetGlobal", name) == 0) + return HIP_API_ID_hipModuleGetGlobal; + if (strcmp("hipModuleGetTexRef", name) == 0) + return HIP_API_ID_hipModuleGetTexRef; + if (strcmp("hipModuleLaunchCooperativeKernel", name) == 0) + return HIP_API_ID_hipModuleLaunchCooperativeKernel; + if (strcmp("hipModuleLaunchCooperativeKernelMultiDevice", name) == 0) + return HIP_API_ID_hipModuleLaunchCooperativeKernelMultiDevice; + if (strcmp("hipModuleLaunchKernel", name) == 0) + return HIP_API_ID_hipModuleLaunchKernel; + if (strcmp("hipModuleLoad", name) == 0) + return HIP_API_ID_hipModuleLoad; + if (strcmp("hipModuleLoadData", name) == 0) + return HIP_API_ID_hipModuleLoadData; + if (strcmp("hipModuleLoadDataEx", name) == 0) + return HIP_API_ID_hipModuleLoadDataEx; + if (strcmp("hipModuleOccupancyMaxActiveBlocksPerMultiprocessor", name) == 0) + return HIP_API_ID_hipModuleOccupancyMaxActiveBlocksPerMultiprocessor; + if (strcmp("hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + name) == 0) + return HIP_API_ID_hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags; + if (strcmp("hipModuleOccupancyMaxPotentialBlockSize", name) == 0) + return HIP_API_ID_hipModuleOccupancyMaxPotentialBlockSize; + if (strcmp("hipModuleOccupancyMaxPotentialBlockSizeWithFlags", name) == 0) + return HIP_API_ID_hipModuleOccupancyMaxPotentialBlockSizeWithFlags; + if (strcmp("hipModuleUnload", name) == 0) + return HIP_API_ID_hipModuleUnload; + if (strcmp("hipOccupancyMaxActiveBlocksPerMultiprocessor", name) == 0) + return HIP_API_ID_hipOccupancyMaxActiveBlocksPerMultiprocessor; + if (strcmp("hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", name) == + 0) + return HIP_API_ID_hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags; + if (strcmp("hipOccupancyMaxPotentialBlockSize", name) == 0) + return HIP_API_ID_hipOccupancyMaxPotentialBlockSize; + if (strcmp("hipPeekAtLastError", name) == 0) + return HIP_API_ID_hipPeekAtLastError; + if (strcmp("hipPointerGetAttribute", name) == 0) + return HIP_API_ID_hipPointerGetAttribute; + if (strcmp("hipPointerGetAttributes", name) == 0) + return HIP_API_ID_hipPointerGetAttributes; + if (strcmp("hipPointerSetAttribute", name) == 0) + return HIP_API_ID_hipPointerSetAttribute; + if (strcmp("hipProfilerStart", name) == 0) + return HIP_API_ID_hipProfilerStart; + if (strcmp("hipProfilerStop", name) == 0) + return HIP_API_ID_hipProfilerStop; + if (strcmp("hipRuntimeGetVersion", name) == 0) + return HIP_API_ID_hipRuntimeGetVersion; + if (strcmp("hipSetDevice", name) == 0) + return HIP_API_ID_hipSetDevice; + if (strcmp("hipSetDeviceFlags", name) == 0) + return HIP_API_ID_hipSetDeviceFlags; + if (strcmp("hipSetValidDevices", name) == 0) + return HIP_API_ID_hipSetValidDevices; + if (strcmp("hipSetupArgument", name) == 0) + return HIP_API_ID_hipSetupArgument; + if (strcmp("hipSignalExternalSemaphoresAsync", name) == 0) + return HIP_API_ID_hipSignalExternalSemaphoresAsync; + if (strcmp("hipStreamAddCallback", name) == 0) + return HIP_API_ID_hipStreamAddCallback; + if (strcmp("hipStreamAttachMemAsync", name) == 0) + return HIP_API_ID_hipStreamAttachMemAsync; + if (strcmp("hipStreamBeginCapture", name) == 0) + return HIP_API_ID_hipStreamBeginCapture; + if (strcmp("hipStreamBeginCaptureToGraph", name) == 0) + return HIP_API_ID_hipStreamBeginCaptureToGraph; + if (strcmp("hipStreamCreate", name) == 0) + return HIP_API_ID_hipStreamCreate; + if (strcmp("hipStreamCreateWithFlags", name) == 0) + return HIP_API_ID_hipStreamCreateWithFlags; + if (strcmp("hipStreamCreateWithPriority", name) == 0) + return HIP_API_ID_hipStreamCreateWithPriority; + if (strcmp("hipStreamDestroy", name) == 0) + return HIP_API_ID_hipStreamDestroy; + if (strcmp("hipStreamEndCapture", name) == 0) + return HIP_API_ID_hipStreamEndCapture; + if (strcmp("hipStreamGetCaptureInfo", name) == 0) + return HIP_API_ID_hipStreamGetCaptureInfo; + if (strcmp("hipStreamGetCaptureInfo_v2", name) == 0) + return HIP_API_ID_hipStreamGetCaptureInfo_v2; + if (strcmp("hipStreamGetDevice", name) == 0) + return HIP_API_ID_hipStreamGetDevice; + if (strcmp("hipStreamGetFlags", name) == 0) + return HIP_API_ID_hipStreamGetFlags; + if (strcmp("hipStreamGetPriority", name) == 0) + return HIP_API_ID_hipStreamGetPriority; + if (strcmp("hipStreamIsCapturing", name) == 0) + return HIP_API_ID_hipStreamIsCapturing; + if (strcmp("hipStreamQuery", name) == 0) + return HIP_API_ID_hipStreamQuery; + if (strcmp("hipStreamSynchronize", name) == 0) + return HIP_API_ID_hipStreamSynchronize; + if (strcmp("hipStreamUpdateCaptureDependencies", name) == 0) + return HIP_API_ID_hipStreamUpdateCaptureDependencies; + if (strcmp("hipStreamWaitEvent", name) == 0) + return HIP_API_ID_hipStreamWaitEvent; + if (strcmp("hipStreamWaitValue32", name) == 0) + return HIP_API_ID_hipStreamWaitValue32; + if (strcmp("hipStreamWaitValue64", name) == 0) + return HIP_API_ID_hipStreamWaitValue64; + if (strcmp("hipStreamWriteValue32", name) == 0) + return HIP_API_ID_hipStreamWriteValue32; + if (strcmp("hipStreamWriteValue64", name) == 0) + return HIP_API_ID_hipStreamWriteValue64; + if (strcmp("hipTexRefGetAddress", name) == 0) + return HIP_API_ID_hipTexRefGetAddress; + if (strcmp("hipTexRefGetArray", name) == 0) + return HIP_API_ID_hipTexRefGetArray; + if (strcmp("hipTexRefGetBorderColor", name) == 0) + return HIP_API_ID_hipTexRefGetBorderColor; + if (strcmp("hipTexRefGetFlags", name) == 0) + return HIP_API_ID_hipTexRefGetFlags; + if (strcmp("hipTexRefGetFormat", name) == 0) + return HIP_API_ID_hipTexRefGetFormat; + if (strcmp("hipTexRefGetMaxAnisotropy", name) == 0) + return HIP_API_ID_hipTexRefGetMaxAnisotropy; + if (strcmp("hipTexRefGetMipMappedArray", name) == 0) + return HIP_API_ID_hipTexRefGetMipMappedArray; + if (strcmp("hipTexRefGetMipmapLevelBias", name) == 0) + return HIP_API_ID_hipTexRefGetMipmapLevelBias; + if (strcmp("hipTexRefGetMipmapLevelClamp", name) == 0) + return HIP_API_ID_hipTexRefGetMipmapLevelClamp; + if (strcmp("hipTexRefSetAddress", name) == 0) + return HIP_API_ID_hipTexRefSetAddress; + if (strcmp("hipTexRefSetAddress2D", name) == 0) + return HIP_API_ID_hipTexRefSetAddress2D; + if (strcmp("hipTexRefSetArray", name) == 0) + return HIP_API_ID_hipTexRefSetArray; + if (strcmp("hipTexRefSetBorderColor", name) == 0) + return HIP_API_ID_hipTexRefSetBorderColor; + if (strcmp("hipTexRefSetFlags", name) == 0) + return HIP_API_ID_hipTexRefSetFlags; + if (strcmp("hipTexRefSetFormat", name) == 0) + return HIP_API_ID_hipTexRefSetFormat; + if (strcmp("hipTexRefSetMaxAnisotropy", name) == 0) + return HIP_API_ID_hipTexRefSetMaxAnisotropy; + if (strcmp("hipTexRefSetMipmapLevelBias", name) == 0) + return HIP_API_ID_hipTexRefSetMipmapLevelBias; + if (strcmp("hipTexRefSetMipmapLevelClamp", name) == 0) + return HIP_API_ID_hipTexRefSetMipmapLevelClamp; + if (strcmp("hipTexRefSetMipmappedArray", name) == 0) + return HIP_API_ID_hipTexRefSetMipmappedArray; + if (strcmp("hipThreadExchangeStreamCaptureMode", name) == 0) + return HIP_API_ID_hipThreadExchangeStreamCaptureMode; + if (strcmp("hipUserObjectCreate", name) == 0) + return HIP_API_ID_hipUserObjectCreate; + if (strcmp("hipUserObjectRelease", name) == 0) + return HIP_API_ID_hipUserObjectRelease; + if (strcmp("hipUserObjectRetain", name) == 0) + return HIP_API_ID_hipUserObjectRetain; + if (strcmp("hipWaitExternalSemaphoresAsync", name) == 0) + return HIP_API_ID_hipWaitExternalSemaphoresAsync; + return HIP_API_ID_NONE; +} + +// HIP API callbacks data structures +typedef struct hip_api_data_s { + uint64_t correlation_id; + uint32_t phase; + union { + struct { + dim3 *gridDim; + dim3 gridDim__val; + dim3 *blockDim; + dim3 blockDim__val; + size_t *sharedMem; + size_t sharedMem__val; + hipStream_t *stream; + hipStream_t stream__val; + } __hipPopCallConfiguration; + struct { + dim3 gridDim; + dim3 blockDim; + size_t sharedMem; + hipStream_t stream; + } __hipPushCallConfiguration; + struct { + hipArray_t *array; + hipArray_t array__val; + const HIP_ARRAY3D_DESCRIPTOR *pAllocateArray; + HIP_ARRAY3D_DESCRIPTOR pAllocateArray__val; + } hipArray3DCreate; + struct { + HIP_ARRAY3D_DESCRIPTOR *pArrayDescriptor; + HIP_ARRAY3D_DESCRIPTOR pArrayDescriptor__val; + hipArray_t array; + } hipArray3DGetDescriptor; + struct { + hipArray_t *pHandle; + hipArray_t pHandle__val; + const HIP_ARRAY_DESCRIPTOR *pAllocateArray; + HIP_ARRAY_DESCRIPTOR pAllocateArray__val; + } hipArrayCreate; + struct { + hipArray_t array; + } hipArrayDestroy; + struct { + HIP_ARRAY_DESCRIPTOR *pArrayDescriptor; + HIP_ARRAY_DESCRIPTOR pArrayDescriptor__val; + hipArray_t array; + } hipArrayGetDescriptor; + struct { + hipChannelFormatDesc *desc; + hipChannelFormatDesc desc__val; + hipExtent *extent; + hipExtent extent__val; + unsigned int *flags; + unsigned int flags__val; + hipArray_t array; + } hipArrayGetInfo; + struct { + int *device; + int device__val; + const hipDeviceProp_tR0000 *prop; + hipDeviceProp_tR0000 prop__val; + } hipChooseDeviceR0000; + struct { + int *device; + int device__val; + const hipDeviceProp_tR0600 *prop; + hipDeviceProp_tR0600 prop__val; + } hipChooseDeviceR0600; + struct { + dim3 gridDim; + dim3 blockDim; + size_t sharedMem; + hipStream_t stream; + } hipConfigureCall; + struct { + hipSurfaceObject_t *pSurfObject; + hipSurfaceObject_t pSurfObject__val; + const hipResourceDesc *pResDesc; + hipResourceDesc pResDesc__val; + } hipCreateSurfaceObject; + struct { + hipCtx_t *ctx; + hipCtx_t ctx__val; + unsigned int flags; + hipDevice_t device; + } hipCtxCreate; + struct { + hipCtx_t ctx; + } hipCtxDestroy; + struct { + hipCtx_t peerCtx; + } hipCtxDisablePeerAccess; + struct { + hipCtx_t peerCtx; + unsigned int flags; + } hipCtxEnablePeerAccess; + struct { + hipCtx_t ctx; + int *apiVersion; + int apiVersion__val; + } hipCtxGetApiVersion; + struct { + hipFuncCache_t *cacheConfig; + hipFuncCache_t cacheConfig__val; + } hipCtxGetCacheConfig; + struct { + hipCtx_t *ctx; + hipCtx_t ctx__val; + } hipCtxGetCurrent; + struct { + hipDevice_t *device; + hipDevice_t device__val; + } hipCtxGetDevice; + struct { + unsigned int *flags; + unsigned int flags__val; + } hipCtxGetFlags; + struct { + hipSharedMemConfig *pConfig; + hipSharedMemConfig pConfig__val; + } hipCtxGetSharedMemConfig; + struct { + hipCtx_t *ctx; + hipCtx_t ctx__val; + } hipCtxPopCurrent; + struct { + hipCtx_t ctx; + } hipCtxPushCurrent; + struct { + hipFuncCache_t cacheConfig; + } hipCtxSetCacheConfig; + struct { + hipCtx_t ctx; + } hipCtxSetCurrent; + struct { + hipSharedMemConfig config; + } hipCtxSetSharedMemConfig; + struct { + hipExternalMemory_t extMem; + } hipDestroyExternalMemory; + struct { + hipExternalSemaphore_t extSem; + } hipDestroyExternalSemaphore; + struct { + hipSurfaceObject_t surfaceObject; + } hipDestroySurfaceObject; + struct { + int *canAccessPeer; + int canAccessPeer__val; + int deviceId; + int peerDeviceId; + } hipDeviceCanAccessPeer; + struct { + int *major; + int major__val; + int *minor; + int minor__val; + hipDevice_t device; + } hipDeviceComputeCapability; + struct { + int peerDeviceId; + } hipDeviceDisablePeerAccess; + struct { + int peerDeviceId; + unsigned int flags; + } hipDeviceEnablePeerAccess; + struct { + hipDevice_t *device; + hipDevice_t device__val; + int ordinal; + } hipDeviceGet; + struct { + int *pi; + int pi__val; + hipDeviceAttribute_t attr; + int deviceId; + } hipDeviceGetAttribute; + struct { + int *device; + int device__val; + const char *pciBusId; + char pciBusId__val; + } hipDeviceGetByPCIBusId; + struct { + hipFuncCache_t *cacheConfig; + hipFuncCache_t cacheConfig__val; + } hipDeviceGetCacheConfig; + struct { + hipMemPool_t *mem_pool; + hipMemPool_t mem_pool__val; + int device; + } hipDeviceGetDefaultMemPool; + struct { + int device; + hipGraphMemAttributeType attr; + void *value; + } hipDeviceGetGraphMemAttribute; + struct { + size_t *pValue; + size_t pValue__val; + enum hipLimit_t limit; + } hipDeviceGetLimit; + struct { + hipMemPool_t *mem_pool; + hipMemPool_t mem_pool__val; + int device; + } hipDeviceGetMemPool; + struct { + char *name; + char name__val; + int len; + hipDevice_t device; + } hipDeviceGetName; + struct { + int *value; + int value__val; + hipDeviceP2PAttr attr; + int srcDevice; + int dstDevice; + } hipDeviceGetP2PAttribute; + struct { + char *pciBusId; + char pciBusId__val; + int len; + int device; + } hipDeviceGetPCIBusId; + struct { + hipSharedMemConfig *pConfig; + hipSharedMemConfig pConfig__val; + } hipDeviceGetSharedMemConfig; + struct { + int *leastPriority; + int leastPriority__val; + int *greatestPriority; + int greatestPriority__val; + } hipDeviceGetStreamPriorityRange; + struct { + hipUUID *uuid; + hipUUID uuid__val; + hipDevice_t device; + } hipDeviceGetUuid; + struct { + int device; + } hipDeviceGraphMemTrim; + struct { + hipDevice_t dev; + unsigned int *flags; + unsigned int flags__val; + int *active; + int active__val; + } hipDevicePrimaryCtxGetState; + struct { + hipDevice_t dev; + } hipDevicePrimaryCtxRelease; + struct { + hipDevice_t dev; + } hipDevicePrimaryCtxReset; + struct { + hipCtx_t *pctx; + hipCtx_t pctx__val; + hipDevice_t dev; + } hipDevicePrimaryCtxRetain; + struct { + hipDevice_t dev; + unsigned int flags; + } hipDevicePrimaryCtxSetFlags; + struct { + hipFuncCache_t cacheConfig; + } hipDeviceSetCacheConfig; + struct { + int device; + hipGraphMemAttributeType attr; + void *value; + } hipDeviceSetGraphMemAttribute; + struct { + enum hipLimit_t limit; + size_t value; + } hipDeviceSetLimit; + struct { + int device; + hipMemPool_t mem_pool; + } hipDeviceSetMemPool; + struct { + hipSharedMemConfig config; + } hipDeviceSetSharedMemConfig; + struct { + size_t *bytes; + size_t bytes__val; + hipDevice_t device; + } hipDeviceTotalMem; + struct { + int *driverVersion; + int driverVersion__val; + } hipDriverGetVersion; + struct { + hipGraphNode_t *phGraphNode; + hipGraphNode_t phGraphNode__val; + hipGraph_t hGraph; + const hipGraphNode_t *dependencies; + hipGraphNode_t dependencies__val; + size_t numDependencies; + const HIP_MEMCPY3D *copyParams; + HIP_MEMCPY3D copyParams__val; + hipCtx_t ctx; + } hipDrvGraphAddMemcpyNode; + struct { + hipGraphNode_t *phGraphNode; + hipGraphNode_t phGraphNode__val; + hipGraph_t hGraph; + const hipGraphNode_t *dependencies; + hipGraphNode_t dependencies__val; + size_t numDependencies; + const HIP_MEMSET_NODE_PARAMS *memsetParams; + HIP_MEMSET_NODE_PARAMS memsetParams__val; + hipCtx_t ctx; + } hipDrvGraphAddMemsetNode; + struct { + const hip_Memcpy2D *pCopy; + hip_Memcpy2D pCopy__val; + } hipDrvMemcpy2DUnaligned; + struct { + const HIP_MEMCPY3D *pCopy; + HIP_MEMCPY3D pCopy__val; + } hipDrvMemcpy3D; + struct { + const HIP_MEMCPY3D *pCopy; + HIP_MEMCPY3D pCopy__val; + hipStream_t stream; + } hipDrvMemcpy3DAsync; + struct { + unsigned int numAttributes; + hipPointer_attribute *attributes; + hipPointer_attribute attributes__val; + void **data; + void *data__val; + hipDeviceptr_t ptr; + } hipDrvPointerGetAttributes; + struct { + hipEvent_t *event; + hipEvent_t event__val; + } hipEventCreate; + struct { + hipEvent_t *event; + hipEvent_t event__val; + unsigned int flags; + } hipEventCreateWithFlags; + struct { + hipEvent_t event; + } hipEventDestroy; + struct { + float *ms; + float ms__val; + hipEvent_t start; + hipEvent_t stop; + } hipEventElapsedTime; + struct { + hipEvent_t event; + } hipEventQuery; + struct { + hipEvent_t event; + hipStream_t stream; + } hipEventRecord; + struct { + hipEvent_t event; + } hipEventSynchronize; + struct { + int device1; + int device2; + unsigned int *linktype; + unsigned int linktype__val; + unsigned int *hopcount; + unsigned int hopcount__val; + } hipExtGetLinkTypeAndHopCount; + struct { + const void *function_address; + dim3 numBlocks; + dim3 dimBlocks; + void **args; + void *args__val; + size_t sharedMemBytes; + hipStream_t stream; + hipEvent_t startEvent; + hipEvent_t stopEvent; + int flags; + } hipExtLaunchKernel; + struct { + hipLaunchParams *launchParamsList; + hipLaunchParams launchParamsList__val; + int numDevices; + unsigned int flags; + } hipExtLaunchMultiKernelMultiDevice; + struct { + void **ptr; + void *ptr__val; + size_t sizeBytes; + unsigned int flags; + } hipExtMallocWithFlags; + struct { + hipFunction_t f; + unsigned int globalWorkSizeX; + unsigned int globalWorkSizeY; + unsigned int globalWorkSizeZ; + unsigned int localWorkSizeX; + unsigned int localWorkSizeY; + unsigned int localWorkSizeZ; + size_t sharedMemBytes; + hipStream_t hStream; + void **kernelParams; + void *kernelParams__val; + void **extra; + void *extra__val; + hipEvent_t startEvent; + hipEvent_t stopEvent; + unsigned int flags; + } hipExtModuleLaunchKernel; + struct { + hipStream_t *stream; + hipStream_t stream__val; + unsigned int cuMaskSize; + const unsigned int *cuMask; + unsigned int cuMask__val; + } hipExtStreamCreateWithCUMask; + struct { + hipStream_t stream; + unsigned int cuMaskSize; + unsigned int *cuMask; + unsigned int cuMask__val; + } hipExtStreamGetCUMask; + struct { + void **devPtr; + void *devPtr__val; + hipExternalMemory_t extMem; + const hipExternalMemoryBufferDesc *bufferDesc; + hipExternalMemoryBufferDesc bufferDesc__val; + } hipExternalMemoryGetMappedBuffer; + struct { + hipMipmappedArray_t *mipmap; + hipMipmappedArray_t mipmap__val; + hipExternalMemory_t extMem; + const hipExternalMemoryMipmappedArrayDesc *mipmapDesc; + hipExternalMemoryMipmappedArrayDesc mipmapDesc__val; + } hipExternalMemoryGetMappedMipmappedArray; + struct { + void *ptr; + } hipFree; + struct { + hipArray_t array; + } hipFreeArray; + struct { + void *dev_ptr; + hipStream_t stream; + } hipFreeAsync; + struct { + void *ptr; + } hipFreeHost; + struct { + hipMipmappedArray_t mipmappedArray; + } hipFreeMipmappedArray; + struct { + int *value; + int value__val; + hipFunction_attribute attrib; + hipFunction_t hfunc; + } hipFuncGetAttribute; + struct { + hipFuncAttributes *attr; + hipFuncAttributes attr__val; + const void *func; + } hipFuncGetAttributes; + struct { + const void *func; + hipFuncAttribute attr; + int value; + } hipFuncSetAttribute; + struct { + const void *func; + hipFuncCache_t config; + } hipFuncSetCacheConfig; + struct { + const void *func; + hipSharedMemConfig config; + } hipFuncSetSharedMemConfig; + struct { + unsigned int *pHipDeviceCount; + unsigned int pHipDeviceCount__val; + int *pHipDevices; + int pHipDevices__val; + unsigned int hipDeviceCount; + hipGLDeviceList deviceList; + } hipGLGetDevices; + struct { + hipChannelFormatDesc *desc; + hipChannelFormatDesc desc__val; + hipArray_const_t array; + } hipGetChannelDesc; + struct { + int *deviceId; + int deviceId__val; + } hipGetDevice; + struct { + int *count; + int count__val; + } hipGetDeviceCount; + struct { + unsigned int *flags; + unsigned int flags__val; + } hipGetDeviceFlags; + struct { + hipDeviceProp_tR0000 *prop; + hipDeviceProp_tR0000 prop__val; + int device; + } hipGetDevicePropertiesR0000; + struct { + hipDeviceProp_tR0600 *prop; + hipDeviceProp_tR0600 prop__val; + int deviceId; + } hipGetDevicePropertiesR0600; + struct { + hipFunction_t *functionPtr; + hipFunction_t functionPtr__val; + const void *symbolPtr; + } hipGetFuncBySymbol; + struct { + hipArray_t *levelArray; + hipArray_t levelArray__val; + hipMipmappedArray_const_t mipmappedArray; + unsigned int level; + } hipGetMipmappedArrayLevel; + struct { + const char *symbol; + char symbol__val; + void **pfn; + void *pfn__val; + int hipVersion; + uint64_t flags; + hipDriverProcAddressQueryResult *symbolStatus; + hipDriverProcAddressQueryResult symbolStatus__val; + } hipGetProcAddress; + struct { + void **devPtr; + void *devPtr__val; + const void *symbol; + } hipGetSymbolAddress; + struct { + size_t *size; + size_t size__val; + const void *symbol; + } hipGetSymbolSize; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + hipGraph_t childGraph; + } hipGraphAddChildGraphNode; + struct { + hipGraph_t graph; + const hipGraphNode_t *from; + hipGraphNode_t from__val; + const hipGraphNode_t *to; + hipGraphNode_t to__val; + size_t numDependencies; + } hipGraphAddDependencies; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + } hipGraphAddEmptyNode; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + hipEvent_t event; + } hipGraphAddEventRecordNode; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + hipEvent_t event; + } hipGraphAddEventWaitNode; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + const hipExternalSemaphoreSignalNodeParams *nodeParams; + hipExternalSemaphoreSignalNodeParams nodeParams__val; + } hipGraphAddExternalSemaphoresSignalNode; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + const hipExternalSemaphoreWaitNodeParams *nodeParams; + hipExternalSemaphoreWaitNodeParams nodeParams__val; + } hipGraphAddExternalSemaphoresWaitNode; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + const hipHostNodeParams *pNodeParams; + hipHostNodeParams pNodeParams__val; + } hipGraphAddHostNode; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + const hipKernelNodeParams *pNodeParams; + hipKernelNodeParams pNodeParams__val; + } hipGraphAddKernelNode; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + hipMemAllocNodeParams *pNodeParams; + hipMemAllocNodeParams pNodeParams__val; + } hipGraphAddMemAllocNode; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + void *dev_ptr; + } hipGraphAddMemFreeNode; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + const hipMemcpy3DParms *pCopyParams; + hipMemcpy3DParms pCopyParams__val; + } hipGraphAddMemcpyNode; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + void *dst; + const void *src; + size_t count; + hipMemcpyKind kind; + } hipGraphAddMemcpyNode1D; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + void *dst; + const void *symbol; + size_t count; + size_t offset; + hipMemcpyKind kind; + } hipGraphAddMemcpyNodeFromSymbol; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + const void *symbol; + const void *src; + size_t count; + size_t offset; + hipMemcpyKind kind; + } hipGraphAddMemcpyNodeToSymbol; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + const hipMemsetParams *pMemsetParams; + hipMemsetParams pMemsetParams__val; + } hipGraphAddMemsetNode; + struct { + hipGraphNode_t *pGraphNode; + hipGraphNode_t pGraphNode__val; + hipGraph_t graph; + const hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t numDependencies; + hipGraphNodeParams *nodeParams; + hipGraphNodeParams nodeParams__val; + } hipGraphAddNode; + struct { + hipGraphNode_t node; + hipGraph_t *pGraph; + hipGraph_t pGraph__val; + } hipGraphChildGraphNodeGetGraph; + struct { + hipGraph_t *pGraphClone; + hipGraph_t pGraphClone__val; + hipGraph_t originalGraph; + } hipGraphClone; + struct { + hipGraph_t *pGraph; + hipGraph_t pGraph__val; + unsigned int flags; + } hipGraphCreate; + struct { + hipGraph_t graph; + const char *path; + char path__val; + unsigned int flags; + } hipGraphDebugDotPrint; + struct { + hipGraph_t graph; + } hipGraphDestroy; + struct { + hipGraphNode_t node; + } hipGraphDestroyNode; + struct { + hipGraphNode_t node; + hipEvent_t *event_out; + hipEvent_t event_out__val; + } hipGraphEventRecordNodeGetEvent; + struct { + hipGraphNode_t node; + hipEvent_t event; + } hipGraphEventRecordNodeSetEvent; + struct { + hipGraphNode_t node; + hipEvent_t *event_out; + hipEvent_t event_out__val; + } hipGraphEventWaitNodeGetEvent; + struct { + hipGraphNode_t node; + hipEvent_t event; + } hipGraphEventWaitNodeSetEvent; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t node; + hipGraph_t childGraph; + } hipGraphExecChildGraphNodeSetParams; + struct { + hipGraphExec_t graphExec; + } hipGraphExecDestroy; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t hNode; + hipEvent_t event; + } hipGraphExecEventRecordNodeSetEvent; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t hNode; + hipEvent_t event; + } hipGraphExecEventWaitNodeSetEvent; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t hNode; + const hipExternalSemaphoreSignalNodeParams *nodeParams; + hipExternalSemaphoreSignalNodeParams nodeParams__val; + } hipGraphExecExternalSemaphoresSignalNodeSetParams; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t hNode; + const hipExternalSemaphoreWaitNodeParams *nodeParams; + hipExternalSemaphoreWaitNodeParams nodeParams__val; + } hipGraphExecExternalSemaphoresWaitNodeSetParams; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t node; + const hipHostNodeParams *pNodeParams; + hipHostNodeParams pNodeParams__val; + } hipGraphExecHostNodeSetParams; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t node; + const hipKernelNodeParams *pNodeParams; + hipKernelNodeParams pNodeParams__val; + } hipGraphExecKernelNodeSetParams; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t node; + hipMemcpy3DParms *pNodeParams; + hipMemcpy3DParms pNodeParams__val; + } hipGraphExecMemcpyNodeSetParams; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t node; + void *dst; + const void *src; + size_t count; + hipMemcpyKind kind; + } hipGraphExecMemcpyNodeSetParams1D; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t node; + void *dst; + const void *symbol; + size_t count; + size_t offset; + hipMemcpyKind kind; + } hipGraphExecMemcpyNodeSetParamsFromSymbol; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t node; + const void *symbol; + const void *src; + size_t count; + size_t offset; + hipMemcpyKind kind; + } hipGraphExecMemcpyNodeSetParamsToSymbol; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t node; + const hipMemsetParams *pNodeParams; + hipMemsetParams pNodeParams__val; + } hipGraphExecMemsetNodeSetParams; + struct { + hipGraphExec_t hGraphExec; + hipGraph_t hGraph; + hipGraphNode_t *hErrorNode_out; + hipGraphNode_t hErrorNode_out__val; + hipGraphExecUpdateResult *updateResult_out; + hipGraphExecUpdateResult updateResult_out__val; + } hipGraphExecUpdate; + struct { + hipGraphNode_t hNode; + hipExternalSemaphoreSignalNodeParams *params_out; + hipExternalSemaphoreSignalNodeParams params_out__val; + } hipGraphExternalSemaphoresSignalNodeGetParams; + struct { + hipGraphNode_t hNode; + const hipExternalSemaphoreSignalNodeParams *nodeParams; + hipExternalSemaphoreSignalNodeParams nodeParams__val; + } hipGraphExternalSemaphoresSignalNodeSetParams; + struct { + hipGraphNode_t hNode; + hipExternalSemaphoreWaitNodeParams *params_out; + hipExternalSemaphoreWaitNodeParams params_out__val; + } hipGraphExternalSemaphoresWaitNodeGetParams; + struct { + hipGraphNode_t hNode; + const hipExternalSemaphoreWaitNodeParams *nodeParams; + hipExternalSemaphoreWaitNodeParams nodeParams__val; + } hipGraphExternalSemaphoresWaitNodeSetParams; + struct { + hipGraph_t graph; + hipGraphNode_t *from; + hipGraphNode_t from__val; + hipGraphNode_t *to; + hipGraphNode_t to__val; + size_t *numEdges; + size_t numEdges__val; + } hipGraphGetEdges; + struct { + hipGraph_t graph; + hipGraphNode_t *nodes; + hipGraphNode_t nodes__val; + size_t *numNodes; + size_t numNodes__val; + } hipGraphGetNodes; + struct { + hipGraph_t graph; + hipGraphNode_t *pRootNodes; + hipGraphNode_t pRootNodes__val; + size_t *pNumRootNodes; + size_t pNumRootNodes__val; + } hipGraphGetRootNodes; + struct { + hipGraphNode_t node; + hipHostNodeParams *pNodeParams; + hipHostNodeParams pNodeParams__val; + } hipGraphHostNodeGetParams; + struct { + hipGraphNode_t node; + const hipHostNodeParams *pNodeParams; + hipHostNodeParams pNodeParams__val; + } hipGraphHostNodeSetParams; + struct { + hipGraphExec_t *pGraphExec; + hipGraphExec_t pGraphExec__val; + hipGraph_t graph; + hipGraphNode_t *pErrorNode; + hipGraphNode_t pErrorNode__val; + char *pLogBuffer; + char pLogBuffer__val; + size_t bufferSize; + } hipGraphInstantiate; + struct { + hipGraphExec_t *pGraphExec; + hipGraphExec_t pGraphExec__val; + hipGraph_t graph; + unsigned long long flags; + } hipGraphInstantiateWithFlags; + struct { + hipGraphExec_t *pGraphExec; + hipGraphExec_t pGraphExec__val; + hipGraph_t graph; + hipGraphInstantiateParams *instantiateParams; + hipGraphInstantiateParams instantiateParams__val; + } hipGraphInstantiateWithParams; + struct { + hipGraphNode_t hSrc; + hipGraphNode_t hDst; + } hipGraphKernelNodeCopyAttributes; + struct { + hipGraphNode_t hNode; + hipLaunchAttributeID attr; + hipLaunchAttributeValue *value; + hipLaunchAttributeValue value__val; + } hipGraphKernelNodeGetAttribute; + struct { + hipGraphNode_t node; + hipKernelNodeParams *pNodeParams; + hipKernelNodeParams pNodeParams__val; + } hipGraphKernelNodeGetParams; + struct { + hipGraphNode_t hNode; + hipLaunchAttributeID attr; + const hipLaunchAttributeValue *value; + hipLaunchAttributeValue value__val; + } hipGraphKernelNodeSetAttribute; + struct { + hipGraphNode_t node; + const hipKernelNodeParams *pNodeParams; + hipKernelNodeParams pNodeParams__val; + } hipGraphKernelNodeSetParams; + struct { + hipGraphExec_t graphExec; + hipStream_t stream; + } hipGraphLaunch; + struct { + hipGraphNode_t node; + hipMemAllocNodeParams *pNodeParams; + hipMemAllocNodeParams pNodeParams__val; + } hipGraphMemAllocNodeGetParams; + struct { + hipGraphNode_t node; + void *dev_ptr; + } hipGraphMemFreeNodeGetParams; + struct { + hipGraphNode_t node; + hipMemcpy3DParms *pNodeParams; + hipMemcpy3DParms pNodeParams__val; + } hipGraphMemcpyNodeGetParams; + struct { + hipGraphNode_t node; + const hipMemcpy3DParms *pNodeParams; + hipMemcpy3DParms pNodeParams__val; + } hipGraphMemcpyNodeSetParams; + struct { + hipGraphNode_t node; + void *dst; + const void *src; + size_t count; + hipMemcpyKind kind; + } hipGraphMemcpyNodeSetParams1D; + struct { + hipGraphNode_t node; + void *dst; + const void *symbol; + size_t count; + size_t offset; + hipMemcpyKind kind; + } hipGraphMemcpyNodeSetParamsFromSymbol; + struct { + hipGraphNode_t node; + const void *symbol; + const void *src; + size_t count; + size_t offset; + hipMemcpyKind kind; + } hipGraphMemcpyNodeSetParamsToSymbol; + struct { + hipGraphNode_t node; + hipMemsetParams *pNodeParams; + hipMemsetParams pNodeParams__val; + } hipGraphMemsetNodeGetParams; + struct { + hipGraphNode_t node; + const hipMemsetParams *pNodeParams; + hipMemsetParams pNodeParams__val; + } hipGraphMemsetNodeSetParams; + struct { + hipGraphNode_t *pNode; + hipGraphNode_t pNode__val; + hipGraphNode_t originalNode; + hipGraph_t clonedGraph; + } hipGraphNodeFindInClone; + struct { + hipGraphNode_t node; + hipGraphNode_t *pDependencies; + hipGraphNode_t pDependencies__val; + size_t *pNumDependencies; + size_t pNumDependencies__val; + } hipGraphNodeGetDependencies; + struct { + hipGraphNode_t node; + hipGraphNode_t *pDependentNodes; + hipGraphNode_t pDependentNodes__val; + size_t *pNumDependentNodes; + size_t pNumDependentNodes__val; + } hipGraphNodeGetDependentNodes; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t hNode; + unsigned int *isEnabled; + unsigned int isEnabled__val; + } hipGraphNodeGetEnabled; + struct { + hipGraphNode_t node; + hipGraphNodeType *pType; + hipGraphNodeType pType__val; + } hipGraphNodeGetType; + struct { + hipGraphExec_t hGraphExec; + hipGraphNode_t hNode; + unsigned int isEnabled; + } hipGraphNodeSetEnabled; + struct { + hipGraph_t graph; + hipUserObject_t object; + unsigned int count; + } hipGraphReleaseUserObject; + struct { + hipGraph_t graph; + const hipGraphNode_t *from; + hipGraphNode_t from__val; + const hipGraphNode_t *to; + hipGraphNode_t to__val; + size_t numDependencies; + } hipGraphRemoveDependencies; + struct { + hipGraph_t graph; + hipUserObject_t object; + unsigned int count; + unsigned int flags; + } hipGraphRetainUserObject; + struct { + hipGraphExec_t graphExec; + hipStream_t stream; + } hipGraphUpload; + struct { + hipGraphicsResource **resource; + hipGraphicsResource *resource__val; + GLuint buffer; + unsigned int flags; + } hipGraphicsGLRegisterBuffer; + struct { + hipGraphicsResource **resource; + hipGraphicsResource *resource__val; + GLuint image; + GLenum target; + unsigned int flags; + } hipGraphicsGLRegisterImage; + struct { + int count; + hipGraphicsResource_t *resources; + hipGraphicsResource_t resources__val; + hipStream_t stream; + } hipGraphicsMapResources; + struct { + void **devPtr; + void *devPtr__val; + size_t *size; + size_t size__val; + hipGraphicsResource_t resource; + } hipGraphicsResourceGetMappedPointer; + struct { + hipArray_t *array; + hipArray_t array__val; + hipGraphicsResource_t resource; + unsigned int arrayIndex; + unsigned int mipLevel; + } hipGraphicsSubResourceGetMappedArray; + struct { + int count; + hipGraphicsResource_t *resources; + hipGraphicsResource_t resources__val; + hipStream_t stream; + } hipGraphicsUnmapResources; + struct { + hipGraphicsResource_t resource; + } hipGraphicsUnregisterResource; + struct { + hipFunction_t f; + unsigned int globalWorkSizeX; + unsigned int globalWorkSizeY; + unsigned int globalWorkSizeZ; + unsigned int blockDimX; + unsigned int blockDimY; + unsigned int blockDimZ; + size_t sharedMemBytes; + hipStream_t hStream; + void **kernelParams; + void *kernelParams__val; + void **extra; + void *extra__val; + hipEvent_t startEvent; + hipEvent_t stopEvent; + } hipHccModuleLaunchKernel; + struct { + void **ptr; + void *ptr__val; + size_t size; + unsigned int flags; + } hipHostAlloc; + struct { + void *ptr; + } hipHostFree; + struct { + void **devPtr; + void *devPtr__val; + void *hstPtr; + unsigned int flags; + } hipHostGetDevicePointer; + struct { + unsigned int *flagsPtr; + unsigned int flagsPtr__val; + void *hostPtr; + } hipHostGetFlags; + struct { + void **ptr; + void *ptr__val; + size_t size; + unsigned int flags; + } hipHostMalloc; + struct { + void *hostPtr; + size_t sizeBytes; + unsigned int flags; + } hipHostRegister; + struct { + void *hostPtr; + } hipHostUnregister; + struct { + hipExternalMemory_t *extMem_out; + hipExternalMemory_t extMem_out__val; + const hipExternalMemoryHandleDesc *memHandleDesc; + hipExternalMemoryHandleDesc memHandleDesc__val; + } hipImportExternalMemory; + struct { + hipExternalSemaphore_t *extSem_out; + hipExternalSemaphore_t extSem_out__val; + const hipExternalSemaphoreHandleDesc *semHandleDesc; + hipExternalSemaphoreHandleDesc semHandleDesc__val; + } hipImportExternalSemaphore; + struct { + unsigned int flags; + } hipInit; + struct { + void *devPtr; + } hipIpcCloseMemHandle; + struct { + hipIpcEventHandle_t *handle; + hipIpcEventHandle_t handle__val; + hipEvent_t event; + } hipIpcGetEventHandle; + struct { + hipIpcMemHandle_t *handle; + hipIpcMemHandle_t handle__val; + void *devPtr; + } hipIpcGetMemHandle; + struct { + hipEvent_t *event; + hipEvent_t event__val; + hipIpcEventHandle_t handle; + } hipIpcOpenEventHandle; + struct { + void **devPtr; + void *devPtr__val; + hipIpcMemHandle_t handle; + unsigned int flags; + } hipIpcOpenMemHandle; + struct { + const void *hostFunction; + } hipLaunchByPtr; + struct { + const void *f; + dim3 gridDim; + dim3 blockDimX; + void **kernelParams; + void *kernelParams__val; + unsigned int sharedMemBytes; + hipStream_t stream; + } hipLaunchCooperativeKernel; + struct { + hipLaunchParams *launchParamsList; + hipLaunchParams launchParamsList__val; + int numDevices; + unsigned int flags; + } hipLaunchCooperativeKernelMultiDevice; + struct { + hipStream_t stream; + hipHostFn_t fn; + void *userData; + } hipLaunchHostFunc; + struct { + const void *function_address; + dim3 numBlocks; + dim3 dimBlocks; + void **args; + void *args__val; + size_t sharedMemBytes; + hipStream_t stream; + } hipLaunchKernel; + struct { + void **ptr; + void *ptr__val; + size_t size; + } hipMalloc; + struct { + hipPitchedPtr *pitchedDevPtr; + hipPitchedPtr pitchedDevPtr__val; + hipExtent extent; + } hipMalloc3D; + struct { + hipArray_t *array; + hipArray_t array__val; + const hipChannelFormatDesc *desc; + hipChannelFormatDesc desc__val; + hipExtent extent; + unsigned int flags; + } hipMalloc3DArray; + struct { + hipArray_t *array; + hipArray_t array__val; + const hipChannelFormatDesc *desc; + hipChannelFormatDesc desc__val; + size_t width; + size_t height; + unsigned int flags; + } hipMallocArray; + struct { + void **dev_ptr; + void *dev_ptr__val; + size_t size; + hipStream_t stream; + } hipMallocAsync; + struct { + void **dev_ptr; + void *dev_ptr__val; + size_t size; + hipMemPool_t mem_pool; + hipStream_t stream; + } hipMallocFromPoolAsync; + struct { + void **ptr; + void *ptr__val; + size_t size; + } hipMallocHost; + struct { + void **dev_ptr; + void *dev_ptr__val; + size_t size; + unsigned int flags; + } hipMallocManaged; + struct { + hipMipmappedArray_t *mipmappedArray; + hipMipmappedArray_t mipmappedArray__val; + const hipChannelFormatDesc *desc; + hipChannelFormatDesc desc__val; + hipExtent extent; + unsigned int numLevels; + unsigned int flags; + } hipMallocMipmappedArray; + struct { + void **ptr; + void *ptr__val; + size_t *pitch; + size_t pitch__val; + size_t width; + size_t height; + } hipMallocPitch; + struct { + void *devPtr; + size_t size; + } hipMemAddressFree; + struct { + void **ptr; + void *ptr__val; + size_t size; + size_t alignment; + void *addr; + unsigned long long flags; + } hipMemAddressReserve; + struct { + const void *dev_ptr; + size_t count; + hipMemoryAdvise advice; + int device; + } hipMemAdvise; + struct { + void **ptr; + void *ptr__val; + size_t size; + } hipMemAllocHost; + struct { + hipDeviceptr_t *dptr; + hipDeviceptr_t dptr__val; + size_t *pitch; + size_t pitch__val; + size_t widthInBytes; + size_t height; + unsigned int elementSizeBytes; + } hipMemAllocPitch; + struct { + hipMemGenericAllocationHandle_t *handle; + hipMemGenericAllocationHandle_t handle__val; + size_t size; + const hipMemAllocationProp *prop; + hipMemAllocationProp prop__val; + unsigned long long flags; + } hipMemCreate; + struct { + void *shareableHandle; + hipMemGenericAllocationHandle_t handle; + hipMemAllocationHandleType handleType; + unsigned long long flags; + } hipMemExportToShareableHandle; + struct { + unsigned long long *flags; + unsigned long long flags__val; + const hipMemLocation *location; + hipMemLocation location__val; + void *ptr; + } hipMemGetAccess; + struct { + hipDeviceptr_t *pbase; + hipDeviceptr_t pbase__val; + size_t *psize; + size_t psize__val; + hipDeviceptr_t dptr; + } hipMemGetAddressRange; + struct { + size_t *granularity; + size_t granularity__val; + const hipMemAllocationProp *prop; + hipMemAllocationProp prop__val; + hipMemAllocationGranularity_flags option; + } hipMemGetAllocationGranularity; + struct { + hipMemAllocationProp *prop; + hipMemAllocationProp prop__val; + hipMemGenericAllocationHandle_t handle; + } hipMemGetAllocationPropertiesFromHandle; + struct { + size_t *free; + size_t free__val; + size_t *total; + size_t total__val; + } hipMemGetInfo; + struct { + hipMemGenericAllocationHandle_t *handle; + hipMemGenericAllocationHandle_t handle__val; + void *osHandle; + hipMemAllocationHandleType shHandleType; + } hipMemImportFromShareableHandle; + struct { + void *ptr; + size_t size; + size_t offset; + hipMemGenericAllocationHandle_t handle; + unsigned long long flags; + } hipMemMap; + struct { + hipArrayMapInfo *mapInfoList; + hipArrayMapInfo mapInfoList__val; + unsigned int count; + hipStream_t stream; + } hipMemMapArrayAsync; + struct { + hipMemPool_t *mem_pool; + hipMemPool_t mem_pool__val; + const hipMemPoolProps *pool_props; + hipMemPoolProps pool_props__val; + } hipMemPoolCreate; + struct { + hipMemPool_t mem_pool; + } hipMemPoolDestroy; + struct { + hipMemPoolPtrExportData *export_data; + hipMemPoolPtrExportData export_data__val; + void *dev_ptr; + } hipMemPoolExportPointer; + struct { + void *shared_handle; + hipMemPool_t mem_pool; + hipMemAllocationHandleType handle_type; + unsigned int flags; + } hipMemPoolExportToShareableHandle; + struct { + hipMemAccessFlags *flags; + hipMemAccessFlags flags__val; + hipMemPool_t mem_pool; + hipMemLocation *location; + hipMemLocation location__val; + } hipMemPoolGetAccess; + struct { + hipMemPool_t mem_pool; + hipMemPoolAttr attr; + void *value; + } hipMemPoolGetAttribute; + struct { + hipMemPool_t *mem_pool; + hipMemPool_t mem_pool__val; + void *shared_handle; + hipMemAllocationHandleType handle_type; + unsigned int flags; + } hipMemPoolImportFromShareableHandle; + struct { + void **dev_ptr; + void *dev_ptr__val; + hipMemPool_t mem_pool; + hipMemPoolPtrExportData *export_data; + hipMemPoolPtrExportData export_data__val; + } hipMemPoolImportPointer; + struct { + hipMemPool_t mem_pool; + const hipMemAccessDesc *desc_list; + hipMemAccessDesc desc_list__val; + size_t count; + } hipMemPoolSetAccess; + struct { + hipMemPool_t mem_pool; + hipMemPoolAttr attr; + void *value; + } hipMemPoolSetAttribute; + struct { + hipMemPool_t mem_pool; + size_t min_bytes_to_hold; + } hipMemPoolTrimTo; + struct { + const void *dev_ptr; + size_t count; + int device; + hipStream_t stream; + } hipMemPrefetchAsync; + struct { + void *ptr; + size_t *size; + size_t size__val; + } hipMemPtrGetInfo; + struct { + void *data; + size_t data_size; + hipMemRangeAttribute attribute; + const void *dev_ptr; + size_t count; + } hipMemRangeGetAttribute; + struct { + void **data; + void *data__val; + size_t *data_sizes; + size_t data_sizes__val; + hipMemRangeAttribute *attributes; + hipMemRangeAttribute attributes__val; + size_t num_attributes; + const void *dev_ptr; + size_t count; + } hipMemRangeGetAttributes; + struct { + hipMemGenericAllocationHandle_t handle; + } hipMemRelease; + struct { + hipMemGenericAllocationHandle_t *handle; + hipMemGenericAllocationHandle_t handle__val; + void *addr; + } hipMemRetainAllocationHandle; + struct { + void *ptr; + size_t size; + const hipMemAccessDesc *desc; + hipMemAccessDesc desc__val; + size_t count; + } hipMemSetAccess; + struct { + void *ptr; + size_t size; + } hipMemUnmap; + struct { + void *dst; + const void *src; + size_t sizeBytes; + hipMemcpyKind kind; + } hipMemcpy; + struct { + void *dst; + size_t dpitch; + const void *src; + size_t spitch; + size_t width; + size_t height; + hipMemcpyKind kind; + } hipMemcpy2D; + struct { + hipArray_t dst; + size_t wOffsetDst; + size_t hOffsetDst; + hipArray_const_t src; + size_t wOffsetSrc; + size_t hOffsetSrc; + size_t width; + size_t height; + hipMemcpyKind kind; + } hipMemcpy2DArrayToArray; + struct { + void *dst; + size_t dpitch; + const void *src; + size_t spitch; + size_t width; + size_t height; + hipMemcpyKind kind; + hipStream_t stream; + } hipMemcpy2DAsync; + struct { + void *dst; + size_t dpitch; + hipArray_const_t src; + size_t wOffset; + size_t hOffset; + size_t width; + size_t height; + hipMemcpyKind kind; + } hipMemcpy2DFromArray; + struct { + void *dst; + size_t dpitch; + hipArray_const_t src; + size_t wOffset; + size_t hOffset; + size_t width; + size_t height; + hipMemcpyKind kind; + hipStream_t stream; + } hipMemcpy2DFromArrayAsync; + struct { + hipArray_t dst; + size_t wOffset; + size_t hOffset; + const void *src; + size_t spitch; + size_t width; + size_t height; + hipMemcpyKind kind; + } hipMemcpy2DToArray; + struct { + hipArray_t dst; + size_t wOffset; + size_t hOffset; + const void *src; + size_t spitch; + size_t width; + size_t height; + hipMemcpyKind kind; + hipStream_t stream; + } hipMemcpy2DToArrayAsync; + struct { + const hipMemcpy3DParms *p; + hipMemcpy3DParms p__val; + } hipMemcpy3D; + struct { + const hipMemcpy3DParms *p; + hipMemcpy3DParms p__val; + hipStream_t stream; + } hipMemcpy3DAsync; + struct { + void *dst; + const void *src; + size_t sizeBytes; + hipMemcpyKind kind; + hipStream_t stream; + } hipMemcpyAsync; + struct { + hipArray_t dstArray; + size_t dstOffset; + hipArray_t srcArray; + size_t srcOffset; + size_t ByteCount; + } hipMemcpyAtoA; + struct { + hipDeviceptr_t dstDevice; + hipArray_t srcArray; + size_t srcOffset; + size_t ByteCount; + } hipMemcpyAtoD; + struct { + void *dst; + hipArray_t srcArray; + size_t srcOffset; + size_t count; + } hipMemcpyAtoH; + struct { + void *dstHost; + hipArray_t srcArray; + size_t srcOffset; + size_t ByteCount; + hipStream_t stream; + } hipMemcpyAtoHAsync; + struct { + hipArray_t dstArray; + size_t dstOffset; + hipDeviceptr_t srcDevice; + size_t ByteCount; + } hipMemcpyDtoA; + struct { + hipDeviceptr_t dst; + hipDeviceptr_t src; + size_t sizeBytes; + } hipMemcpyDtoD; + struct { + hipDeviceptr_t dst; + hipDeviceptr_t src; + size_t sizeBytes; + hipStream_t stream; + } hipMemcpyDtoDAsync; + struct { + void *dst; + hipDeviceptr_t src; + size_t sizeBytes; + } hipMemcpyDtoH; + struct { + void *dst; + hipDeviceptr_t src; + size_t sizeBytes; + hipStream_t stream; + } hipMemcpyDtoHAsync; + struct { + void *dst; + hipArray_const_t srcArray; + size_t wOffset; + size_t hOffset; + size_t count; + hipMemcpyKind kind; + } hipMemcpyFromArray; + struct { + void *dst; + const void *symbol; + size_t sizeBytes; + size_t offset; + hipMemcpyKind kind; + } hipMemcpyFromSymbol; + struct { + void *dst; + const void *symbol; + size_t sizeBytes; + size_t offset; + hipMemcpyKind kind; + hipStream_t stream; + } hipMemcpyFromSymbolAsync; + struct { + hipArray_t dstArray; + size_t dstOffset; + const void *srcHost; + size_t count; + } hipMemcpyHtoA; + struct { + hipArray_t dstArray; + size_t dstOffset; + const void *srcHost; + size_t ByteCount; + hipStream_t stream; + } hipMemcpyHtoAAsync; + struct { + hipDeviceptr_t dst; + void *src; + size_t sizeBytes; + } hipMemcpyHtoD; + struct { + hipDeviceptr_t dst; + void *src; + size_t sizeBytes; + hipStream_t stream; + } hipMemcpyHtoDAsync; + struct { + const hip_Memcpy2D *pCopy; + hip_Memcpy2D pCopy__val; + } hipMemcpyParam2D; + struct { + const hip_Memcpy2D *pCopy; + hip_Memcpy2D pCopy__val; + hipStream_t stream; + } hipMemcpyParam2DAsync; + struct { + void *dst; + int dstDeviceId; + const void *src; + int srcDeviceId; + size_t sizeBytes; + } hipMemcpyPeer; + struct { + void *dst; + int dstDeviceId; + const void *src; + int srcDevice; + size_t sizeBytes; + hipStream_t stream; + } hipMemcpyPeerAsync; + struct { + hipArray_t dst; + size_t wOffset; + size_t hOffset; + const void *src; + size_t count; + hipMemcpyKind kind; + } hipMemcpyToArray; + struct { + const void *symbol; + const void *src; + size_t sizeBytes; + size_t offset; + hipMemcpyKind kind; + } hipMemcpyToSymbol; + struct { + const void *symbol; + const void *src; + size_t sizeBytes; + size_t offset; + hipMemcpyKind kind; + hipStream_t stream; + } hipMemcpyToSymbolAsync; + struct { + void *dst; + const void *src; + size_t sizeBytes; + hipMemcpyKind kind; + hipStream_t stream; + } hipMemcpyWithStream; + struct { + void *dst; + int value; + size_t sizeBytes; + } hipMemset; + struct { + void *dst; + size_t pitch; + int value; + size_t width; + size_t height; + } hipMemset2D; + struct { + void *dst; + size_t pitch; + int value; + size_t width; + size_t height; + hipStream_t stream; + } hipMemset2DAsync; + struct { + hipPitchedPtr pitchedDevPtr; + int value; + hipExtent extent; + } hipMemset3D; + struct { + hipPitchedPtr pitchedDevPtr; + int value; + hipExtent extent; + hipStream_t stream; + } hipMemset3DAsync; + struct { + void *dst; + int value; + size_t sizeBytes; + hipStream_t stream; + } hipMemsetAsync; + struct { + hipDeviceptr_t dest; + unsigned short value; + size_t count; + } hipMemsetD16; + struct { + hipDeviceptr_t dest; + unsigned short value; + size_t count; + hipStream_t stream; + } hipMemsetD16Async; + struct { + hipDeviceptr_t dest; + int value; + size_t count; + } hipMemsetD32; + struct { + hipDeviceptr_t dst; + int value; + size_t count; + hipStream_t stream; + } hipMemsetD32Async; + struct { + hipDeviceptr_t dest; + unsigned char value; + size_t count; + } hipMemsetD8; + struct { + hipDeviceptr_t dest; + unsigned char value; + size_t count; + hipStream_t stream; + } hipMemsetD8Async; + struct { + hipMipmappedArray_t *pHandle; + hipMipmappedArray_t pHandle__val; + HIP_ARRAY3D_DESCRIPTOR *pMipmappedArrayDesc; + HIP_ARRAY3D_DESCRIPTOR pMipmappedArrayDesc__val; + unsigned int numMipmapLevels; + } hipMipmappedArrayCreate; + struct { + hipMipmappedArray_t hMipmappedArray; + } hipMipmappedArrayDestroy; + struct { + hipArray_t *pLevelArray; + hipArray_t pLevelArray__val; + hipMipmappedArray_t hMipMappedArray; + unsigned int level; + } hipMipmappedArrayGetLevel; + struct { + hipFunction_t *function; + hipFunction_t function__val; + hipModule_t module; + const char *kname; + char kname__val; + } hipModuleGetFunction; + struct { + hipDeviceptr_t *dptr; + hipDeviceptr_t dptr__val; + size_t *bytes; + size_t bytes__val; + hipModule_t hmod; + const char *name; + char name__val; + } hipModuleGetGlobal; + struct { + textureReference **texRef; + textureReference *texRef__val; + hipModule_t hmod; + const char *name; + char name__val; + } hipModuleGetTexRef; + struct { + hipFunction_t f; + unsigned int gridDimX; + unsigned int gridDimY; + unsigned int gridDimZ; + unsigned int blockDimX; + unsigned int blockDimY; + unsigned int blockDimZ; + unsigned int sharedMemBytes; + hipStream_t stream; + void **kernelParams; + void *kernelParams__val; + } hipModuleLaunchCooperativeKernel; + struct { + hipFunctionLaunchParams *launchParamsList; + hipFunctionLaunchParams launchParamsList__val; + unsigned int numDevices; + unsigned int flags; + } hipModuleLaunchCooperativeKernelMultiDevice; + struct { + hipFunction_t f; + unsigned int gridDimX; + unsigned int gridDimY; + unsigned int gridDimZ; + unsigned int blockDimX; + unsigned int blockDimY; + unsigned int blockDimZ; + unsigned int sharedMemBytes; + hipStream_t stream; + void **kernelParams; + void *kernelParams__val; + void **extra; + void *extra__val; + } hipModuleLaunchKernel; + struct { + hipModule_t *module; + hipModule_t module__val; + const char *fname; + char fname__val; + } hipModuleLoad; + struct { + hipModule_t *module; + hipModule_t module__val; + const void *image; + } hipModuleLoadData; + struct { + hipModule_t *module; + hipModule_t module__val; + const void *image; + unsigned int numOptions; + hipJitOption *options; + hipJitOption options__val; + void **optionsValues; + void *optionsValues__val; + } hipModuleLoadDataEx; + struct { + int *numBlocks; + int numBlocks__val; + hipFunction_t f; + int blockSize; + size_t dynSharedMemPerBlk; + } hipModuleOccupancyMaxActiveBlocksPerMultiprocessor; + struct { + int *numBlocks; + int numBlocks__val; + hipFunction_t f; + int blockSize; + size_t dynSharedMemPerBlk; + unsigned int flags; + } hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags; + struct { + int *gridSize; + int gridSize__val; + int *blockSize; + int blockSize__val; + hipFunction_t f; + size_t dynSharedMemPerBlk; + int blockSizeLimit; + } hipModuleOccupancyMaxPotentialBlockSize; + struct { + int *gridSize; + int gridSize__val; + int *blockSize; + int blockSize__val; + hipFunction_t f; + size_t dynSharedMemPerBlk; + int blockSizeLimit; + unsigned int flags; + } hipModuleOccupancyMaxPotentialBlockSizeWithFlags; + struct { + hipModule_t module; + } hipModuleUnload; + struct { + int *numBlocks; + int numBlocks__val; + const void *f; + int blockSize; + size_t dynamicSMemSize; + } hipOccupancyMaxActiveBlocksPerMultiprocessor; + struct { + int *numBlocks; + int numBlocks__val; + const void *f; + int blockSize; + size_t dynamicSMemSize; + unsigned int flags; + } hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags; + struct { + int *gridSize; + int gridSize__val; + int *blockSize; + int blockSize__val; + const void *f; + size_t dynSharedMemPerBlk; + int blockSizeLimit; + } hipOccupancyMaxPotentialBlockSize; + struct { + void *data; + hipPointer_attribute attribute; + hipDeviceptr_t ptr; + } hipPointerGetAttribute; + struct { + hipPointerAttribute_t *attributes; + hipPointerAttribute_t attributes__val; + const void *ptr; + } hipPointerGetAttributes; + struct { + const void *value; + hipPointer_attribute attribute; + hipDeviceptr_t ptr; + } hipPointerSetAttribute; + struct { + int *runtimeVersion; + int runtimeVersion__val; + } hipRuntimeGetVersion; + struct { + int deviceId; + } hipSetDevice; + struct { + unsigned int flags; + } hipSetDeviceFlags; + struct { + int *device_arr; + int device_arr__val; + int len; + } hipSetValidDevices; + struct { + const void *arg; + size_t size; + size_t offset; + } hipSetupArgument; + struct { + const hipExternalSemaphore_t *extSemArray; + hipExternalSemaphore_t extSemArray__val; + const hipExternalSemaphoreSignalParams *paramsArray; + hipExternalSemaphoreSignalParams paramsArray__val; + unsigned int numExtSems; + hipStream_t stream; + } hipSignalExternalSemaphoresAsync; + struct { + hipStream_t stream; + hipStreamCallback_t callback; + void *userData; + unsigned int flags; + } hipStreamAddCallback; + struct { + hipStream_t stream; + void *dev_ptr; + size_t length; + unsigned int flags; + } hipStreamAttachMemAsync; + struct { + hipStream_t stream; + hipStreamCaptureMode mode; + } hipStreamBeginCapture; + struct { + hipStream_t stream; + hipGraph_t graph; + const hipGraphNode_t *dependencies; + hipGraphNode_t dependencies__val; + const hipGraphEdgeData *dependencyData; + hipGraphEdgeData dependencyData__val; + size_t numDependencies; + hipStreamCaptureMode mode; + } hipStreamBeginCaptureToGraph; + struct { + hipStream_t *stream; + hipStream_t stream__val; + } hipStreamCreate; + struct { + hipStream_t *stream; + hipStream_t stream__val; + unsigned int flags; + } hipStreamCreateWithFlags; + struct { + hipStream_t *stream; + hipStream_t stream__val; + unsigned int flags; + int priority; + } hipStreamCreateWithPriority; + struct { + hipStream_t stream; + } hipStreamDestroy; + struct { + hipStream_t stream; + hipGraph_t *pGraph; + hipGraph_t pGraph__val; + } hipStreamEndCapture; + struct { + hipStream_t stream; + hipStreamCaptureStatus *pCaptureStatus; + hipStreamCaptureStatus pCaptureStatus__val; + unsigned long long *pId; + unsigned long long pId__val; + } hipStreamGetCaptureInfo; + struct { + hipStream_t stream; + hipStreamCaptureStatus *captureStatus_out; + hipStreamCaptureStatus captureStatus_out__val; + unsigned long long *id_out; + unsigned long long id_out__val; + hipGraph_t *graph_out; + hipGraph_t graph_out__val; + const hipGraphNode_t **dependencies_out; + const hipGraphNode_t *dependencies_out__val; + size_t *numDependencies_out; + size_t numDependencies_out__val; + } hipStreamGetCaptureInfo_v2; + struct { + hipStream_t stream; + hipDevice_t *device; + hipDevice_t device__val; + } hipStreamGetDevice; + struct { + hipStream_t stream; + unsigned int *flags; + unsigned int flags__val; + } hipStreamGetFlags; + struct { + hipStream_t stream; + int *priority; + int priority__val; + } hipStreamGetPriority; + struct { + hipStream_t stream; + hipStreamCaptureStatus *pCaptureStatus; + hipStreamCaptureStatus pCaptureStatus__val; + } hipStreamIsCapturing; + struct { + hipStream_t stream; + } hipStreamQuery; + struct { + hipStream_t stream; + } hipStreamSynchronize; + struct { + hipStream_t stream; + hipGraphNode_t *dependencies; + hipGraphNode_t dependencies__val; + size_t numDependencies; + unsigned int flags; + } hipStreamUpdateCaptureDependencies; + struct { + hipStream_t stream; + hipEvent_t event; + unsigned int flags; + } hipStreamWaitEvent; + struct { + hipStream_t stream; + void *ptr; + unsigned int value; + unsigned int flags; + unsigned int mask; + } hipStreamWaitValue32; + struct { + hipStream_t stream; + void *ptr; + uint64_t value; + unsigned int flags; + uint64_t mask; + } hipStreamWaitValue64; + struct { + hipStream_t stream; + void *ptr; + unsigned int value; + unsigned int flags; + } hipStreamWriteValue32; + struct { + hipStream_t stream; + void *ptr; + uint64_t value; + unsigned int flags; + } hipStreamWriteValue64; + struct { + hipDeviceptr_t *dev_ptr; + hipDeviceptr_t dev_ptr__val; + const textureReference *texRef; + textureReference texRef__val; + } hipTexRefGetAddress; + struct { + hipArray_t *pArray; + hipArray_t pArray__val; + const textureReference *texRef; + textureReference texRef__val; + } hipTexRefGetArray; + struct { + float *pBorderColor; + float pBorderColor__val; + const textureReference *texRef; + textureReference texRef__val; + } hipTexRefGetBorderColor; + struct { + unsigned int *pFlags; + unsigned int pFlags__val; + const textureReference *texRef; + textureReference texRef__val; + } hipTexRefGetFlags; + struct { + hipArray_Format *pFormat; + hipArray_Format pFormat__val; + int *pNumChannels; + int pNumChannels__val; + const textureReference *texRef; + textureReference texRef__val; + } hipTexRefGetFormat; + struct { + int *pmaxAnsio; + int pmaxAnsio__val; + const textureReference *texRef; + textureReference texRef__val; + } hipTexRefGetMaxAnisotropy; + struct { + hipMipmappedArray_t *pArray; + hipMipmappedArray_t pArray__val; + const textureReference *texRef; + textureReference texRef__val; + } hipTexRefGetMipMappedArray; + struct { + float *pbias; + float pbias__val; + const textureReference *texRef; + textureReference texRef__val; + } hipTexRefGetMipmapLevelBias; + struct { + float *pminMipmapLevelClamp; + float pminMipmapLevelClamp__val; + float *pmaxMipmapLevelClamp; + float pmaxMipmapLevelClamp__val; + const textureReference *texRef; + textureReference texRef__val; + } hipTexRefGetMipmapLevelClamp; + struct { + size_t *ByteOffset; + size_t ByteOffset__val; + textureReference *texRef; + textureReference texRef__val; + hipDeviceptr_t dptr; + size_t bytes; + } hipTexRefSetAddress; + struct { + textureReference *texRef; + textureReference texRef__val; + const HIP_ARRAY_DESCRIPTOR *desc; + HIP_ARRAY_DESCRIPTOR desc__val; + hipDeviceptr_t dptr; + size_t Pitch; + } hipTexRefSetAddress2D; + struct { + textureReference *tex; + textureReference tex__val; + hipArray_const_t array; + unsigned int flags; + } hipTexRefSetArray; + struct { + textureReference *texRef; + textureReference texRef__val; + float *pBorderColor; + float pBorderColor__val; + } hipTexRefSetBorderColor; + struct { + textureReference *texRef; + textureReference texRef__val; + unsigned int Flags; + } hipTexRefSetFlags; + struct { + textureReference *texRef; + textureReference texRef__val; + hipArray_Format fmt; + int NumPackedComponents; + } hipTexRefSetFormat; + struct { + textureReference *texRef; + textureReference texRef__val; + unsigned int maxAniso; + } hipTexRefSetMaxAnisotropy; + struct { + textureReference *texRef; + textureReference texRef__val; + float bias; + } hipTexRefSetMipmapLevelBias; + struct { + textureReference *texRef; + textureReference texRef__val; + float minMipMapLevelClamp; + float maxMipMapLevelClamp; + } hipTexRefSetMipmapLevelClamp; + struct { + textureReference *texRef; + textureReference texRef__val; + hipMipmappedArray *mipmappedArray; + hipMipmappedArray mipmappedArray__val; + unsigned int Flags; + } hipTexRefSetMipmappedArray; + struct { + hipStreamCaptureMode *mode; + hipStreamCaptureMode mode__val; + } hipThreadExchangeStreamCaptureMode; + struct { + hipUserObject_t *object_out; + hipUserObject_t object_out__val; + void *ptr; + hipHostFn_t destroy; + unsigned int initialRefcount; + unsigned int flags; + } hipUserObjectCreate; + struct { + hipUserObject_t object; + unsigned int count; + } hipUserObjectRelease; + struct { + hipUserObject_t object; + unsigned int count; + } hipUserObjectRetain; + struct { + const hipExternalSemaphore_t *extSemArray; + hipExternalSemaphore_t extSemArray__val; + const hipExternalSemaphoreWaitParams *paramsArray; + hipExternalSemaphoreWaitParams paramsArray__val; + unsigned int numExtSems; + hipStream_t stream; + } hipWaitExternalSemaphoresAsync; + } args; + uint64_t *phase_data; +} hip_api_data_t; + +// HIP API callbacks args data filling macros +// __hipPopCallConfiguration[('dim3*', 'gridDim'), ('dim3*', 'blockDim'), +// ('size_t*', 'sharedMem'), ('hipStream_t*', 'stream')] +#define INIT___hipPopCallConfiguration_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.__hipPopCallConfiguration.gridDim = (dim3 *)gridDim; \ + cb_data.args.__hipPopCallConfiguration.blockDim = (dim3 *)blockDim; \ + cb_data.args.__hipPopCallConfiguration.sharedMem = (size_t *)sharedMem; \ + cb_data.args.__hipPopCallConfiguration.stream = (hipStream_t *)stream; \ + }; +// __hipPushCallConfiguration[('dim3', 'gridDim'), ('dim3', 'blockDim'), +// ('size_t', 'sharedMem'), ('hipStream_t', 'stream')] +#define INIT___hipPushCallConfiguration_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.__hipPushCallConfiguration.gridDim = (dim3)gridDim; \ + cb_data.args.__hipPushCallConfiguration.blockDim = (dim3)blockDim; \ + cb_data.args.__hipPushCallConfiguration.sharedMem = (size_t)sharedMem; \ + cb_data.args.__hipPushCallConfiguration.stream = (hipStream_t)stream; \ + }; +// hipArray3DCreate[('hipArray_t*', 'array'), ('const HIP_ARRAY3D_DESCRIPTOR*', +// 'pAllocateArray')] +#define INIT_hipArray3DCreate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipArray3DCreate.array = (hipArray_t *)array; \ + cb_data.args.hipArray3DCreate.pAllocateArray = \ + (const HIP_ARRAY3D_DESCRIPTOR *)pAllocateArray; \ + }; +// hipArray3DGetDescriptor[('HIP_ARRAY3D_DESCRIPTOR*', 'pArrayDescriptor'), +// ('hipArray_t', 'array')] +#define INIT_hipArray3DGetDescriptor_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipArray3DGetDescriptor.pArrayDescriptor = \ + (HIP_ARRAY3D_DESCRIPTOR *)pArrayDescriptor; \ + cb_data.args.hipArray3DGetDescriptor.array = (hipArray_t)array; \ + }; +// hipArrayCreate[('hipArray_t*', 'pHandle'), ('const HIP_ARRAY_DESCRIPTOR*', +// 'pAllocateArray')] +#define INIT_hipArrayCreate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipArrayCreate.pHandle = (hipArray_t *)array; \ + cb_data.args.hipArrayCreate.pAllocateArray = \ + (const HIP_ARRAY_DESCRIPTOR *)pAllocateArray; \ + }; +// hipArrayDestroy[('hipArray_t', 'array')] +#define INIT_hipArrayDestroy_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipArrayDestroy.array = (hipArray_t)array; \ + }; +// hipArrayGetDescriptor[('HIP_ARRAY_DESCRIPTOR*', 'pArrayDescriptor'), +// ('hipArray_t', 'array')] +#define INIT_hipArrayGetDescriptor_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipArrayGetDescriptor.pArrayDescriptor = \ + (HIP_ARRAY_DESCRIPTOR *)pArrayDescriptor; \ + cb_data.args.hipArrayGetDescriptor.array = (hipArray_t)array; \ + }; +// hipArrayGetInfo[('hipChannelFormatDesc*', 'desc'), ('hipExtent*', 'extent'), +// ('unsigned int*', 'flags'), ('hipArray_t', 'array')] +#define INIT_hipArrayGetInfo_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipArrayGetInfo.desc = (hipChannelFormatDesc *)desc; \ + cb_data.args.hipArrayGetInfo.extent = (hipExtent *)extent; \ + cb_data.args.hipArrayGetInfo.flags = (unsigned int *)flags; \ + cb_data.args.hipArrayGetInfo.array = (hipArray_t)array; \ + }; +// hipChooseDeviceR0000[('int*', 'device'), ('const hipDeviceProp_tR0000*', +// 'prop')] +#define INIT_hipChooseDeviceR0000_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipChooseDeviceR0000.device = (int *)device; \ + cb_data.args.hipChooseDeviceR0000.prop = \ + (const hipDeviceProp_tR0000 *)properties; \ + }; +// hipChooseDeviceR0600[('int*', 'device'), ('const hipDeviceProp_tR0600*', +// 'prop')] +#define INIT_hipChooseDeviceR0600_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipChooseDeviceR0600.device = (int *)device; \ + cb_data.args.hipChooseDeviceR0600.prop = \ + (const hipDeviceProp_tR0600 *)properties; \ + }; +// hipConfigureCall[('dim3', 'gridDim'), ('dim3', 'blockDim'), ('size_t', +// 'sharedMem'), ('hipStream_t', 'stream')] +#define INIT_hipConfigureCall_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipConfigureCall.gridDim = (dim3)gridDim; \ + cb_data.args.hipConfigureCall.blockDim = (dim3)blockDim; \ + cb_data.args.hipConfigureCall.sharedMem = (size_t)sharedMem; \ + cb_data.args.hipConfigureCall.stream = (hipStream_t)stream; \ + }; +// hipCreateSurfaceObject[('hipSurfaceObject_t*', 'pSurfObject'), ('const +// hipResourceDesc*', 'pResDesc')] +#define INIT_hipCreateSurfaceObject_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCreateSurfaceObject.pSurfObject = \ + (hipSurfaceObject_t *)pSurfObject; \ + cb_data.args.hipCreateSurfaceObject.pResDesc = \ + (const hipResourceDesc *)pResDesc; \ + }; +// hipCtxCreate[('hipCtx_t*', 'ctx'), ('unsigned int', 'flags'), ('hipDevice_t', +// 'device')] +#define INIT_hipCtxCreate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxCreate.ctx = (hipCtx_t *)ctx; \ + cb_data.args.hipCtxCreate.flags = (unsigned int)flags; \ + cb_data.args.hipCtxCreate.device = (hipDevice_t)device; \ + }; +// hipCtxDestroy[('hipCtx_t', 'ctx')] +#define INIT_hipCtxDestroy_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxDestroy.ctx = (hipCtx_t)ctx; \ + }; +// hipCtxDisablePeerAccess[('hipCtx_t', 'peerCtx')] +#define INIT_hipCtxDisablePeerAccess_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxDisablePeerAccess.peerCtx = (hipCtx_t)peerCtx; \ + }; +// hipCtxEnablePeerAccess[('hipCtx_t', 'peerCtx'), ('unsigned int', 'flags')] +#define INIT_hipCtxEnablePeerAccess_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxEnablePeerAccess.peerCtx = (hipCtx_t)peerCtx; \ + cb_data.args.hipCtxEnablePeerAccess.flags = (unsigned int)flags; \ + }; +// hipCtxGetApiVersion[('hipCtx_t', 'ctx'), ('int*', 'apiVersion')] +#define INIT_hipCtxGetApiVersion_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxGetApiVersion.ctx = (hipCtx_t)ctx; \ + cb_data.args.hipCtxGetApiVersion.apiVersion = (int *)apiVersion; \ + }; +// hipCtxGetCacheConfig[('hipFuncCache_t*', 'cacheConfig')] +#define INIT_hipCtxGetCacheConfig_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxGetCacheConfig.cacheConfig = \ + (hipFuncCache_t *)cacheConfig; \ + }; +// hipCtxGetCurrent[('hipCtx_t*', 'ctx')] +#define INIT_hipCtxGetCurrent_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxGetCurrent.ctx = (hipCtx_t *)ctx; \ + }; +// hipCtxGetDevice[('hipDevice_t*', 'device')] +#define INIT_hipCtxGetDevice_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxGetDevice.device = (hipDevice_t *)device; \ + }; +// hipCtxGetFlags[('unsigned int*', 'flags')] +#define INIT_hipCtxGetFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxGetFlags.flags = (unsigned int *)flags; \ + }; +// hipCtxGetSharedMemConfig[('hipSharedMemConfig*', 'pConfig')] +#define INIT_hipCtxGetSharedMemConfig_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxGetSharedMemConfig.pConfig = \ + (hipSharedMemConfig *)pConfig; \ + }; +// hipCtxPopCurrent[('hipCtx_t*', 'ctx')] +#define INIT_hipCtxPopCurrent_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxPopCurrent.ctx = (hipCtx_t *)ctx; \ + }; +// hipCtxPushCurrent[('hipCtx_t', 'ctx')] +#define INIT_hipCtxPushCurrent_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxPushCurrent.ctx = (hipCtx_t)ctx; \ + }; +// hipCtxSetCacheConfig[('hipFuncCache_t', 'cacheConfig')] +#define INIT_hipCtxSetCacheConfig_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxSetCacheConfig.cacheConfig = \ + (hipFuncCache_t)cacheConfig; \ + }; +// hipCtxSetCurrent[('hipCtx_t', 'ctx')] +#define INIT_hipCtxSetCurrent_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxSetCurrent.ctx = (hipCtx_t)ctx; \ + }; +// hipCtxSetSharedMemConfig[('hipSharedMemConfig', 'config')] +#define INIT_hipCtxSetSharedMemConfig_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipCtxSetSharedMemConfig.config = (hipSharedMemConfig)config; \ + }; +// hipCtxSynchronize[] +#define INIT_hipCtxSynchronize_CB_ARGS_DATA(cb_data) {}; +// hipDestroyExternalMemory[('hipExternalMemory_t', 'extMem')] +#define INIT_hipDestroyExternalMemory_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDestroyExternalMemory.extMem = \ + (hipExternalMemory_t)extMem; \ + }; +// hipDestroyExternalSemaphore[('hipExternalSemaphore_t', 'extSem')] +#define INIT_hipDestroyExternalSemaphore_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDestroyExternalSemaphore.extSem = \ + (hipExternalSemaphore_t)extSem; \ + }; +// hipDestroySurfaceObject[('hipSurfaceObject_t', 'surfaceObject')] +#define INIT_hipDestroySurfaceObject_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDestroySurfaceObject.surfaceObject = \ + (hipSurfaceObject_t)surfaceObject; \ + }; +// hipDeviceCanAccessPeer[('int*', 'canAccessPeer'), ('int', 'deviceId'), +// ('int', 'peerDeviceId')] +#define INIT_hipDeviceCanAccessPeer_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceCanAccessPeer.canAccessPeer = (int *)canAccess; \ + cb_data.args.hipDeviceCanAccessPeer.deviceId = (int)deviceId; \ + cb_data.args.hipDeviceCanAccessPeer.peerDeviceId = (int)peerDeviceId; \ + }; +// hipDeviceComputeCapability[('int*', 'major'), ('int*', 'minor'), +// ('hipDevice_t', 'device')] +#define INIT_hipDeviceComputeCapability_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceComputeCapability.major = (int *)major; \ + cb_data.args.hipDeviceComputeCapability.minor = (int *)minor; \ + cb_data.args.hipDeviceComputeCapability.device = (hipDevice_t)device; \ + }; +// hipDeviceDisablePeerAccess[('int', 'peerDeviceId')] +#define INIT_hipDeviceDisablePeerAccess_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceDisablePeerAccess.peerDeviceId = (int)peerDeviceId; \ + }; +// hipDeviceEnablePeerAccess[('int', 'peerDeviceId'), ('unsigned int', 'flags')] +#define INIT_hipDeviceEnablePeerAccess_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceEnablePeerAccess.peerDeviceId = (int)peerDeviceId; \ + cb_data.args.hipDeviceEnablePeerAccess.flags = (unsigned int)flags; \ + }; +// hipDeviceGet[('hipDevice_t*', 'device'), ('int', 'ordinal')] +#define INIT_hipDeviceGet_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGet.device = (hipDevice_t *)device; \ + cb_data.args.hipDeviceGet.ordinal = (int)deviceId; \ + }; +// hipDeviceGetAttribute[('int*', 'pi'), ('hipDeviceAttribute_t', 'attr'), +// ('int', 'deviceId')] +#define INIT_hipDeviceGetAttribute_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetAttribute.pi = (int *)pi; \ + cb_data.args.hipDeviceGetAttribute.attr = (hipDeviceAttribute_t)attr; \ + cb_data.args.hipDeviceGetAttribute.deviceId = (int)device; \ + }; +// hipDeviceGetByPCIBusId[('int*', 'device'), ('const char*', 'pciBusId')] +#define INIT_hipDeviceGetByPCIBusId_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetByPCIBusId.device = (int *)device; \ + cb_data.args.hipDeviceGetByPCIBusId.pciBusId = \ + (pciBusIdstr) ? strdup(pciBusIdstr) : NULL; \ + }; +// hipDeviceGetCacheConfig[('hipFuncCache_t*', 'cacheConfig')] +#define INIT_hipDeviceGetCacheConfig_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetCacheConfig.cacheConfig = \ + (hipFuncCache_t *)cacheConfig; \ + }; +// hipDeviceGetDefaultMemPool[('hipMemPool_t*', 'mem_pool'), ('int', 'device')] +#define INIT_hipDeviceGetDefaultMemPool_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetDefaultMemPool.mem_pool = \ + (hipMemPool_t *)mem_pool; \ + cb_data.args.hipDeviceGetDefaultMemPool.device = (int)device; \ + }; +// hipDeviceGetGraphMemAttribute[('int', 'device'), ('hipGraphMemAttributeType', +// 'attr'), ('void*', 'value')] +#define INIT_hipDeviceGetGraphMemAttribute_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetGraphMemAttribute.device = (int)device; \ + cb_data.args.hipDeviceGetGraphMemAttribute.attr = \ + (hipGraphMemAttributeType)attr; \ + cb_data.args.hipDeviceGetGraphMemAttribute.value = (void *)value; \ + }; +// hipDeviceGetLimit[('size_t*', 'pValue'), ('hipLimit_t', 'limit')] +#define INIT_hipDeviceGetLimit_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetLimit.pValue = (size_t *)pValue; \ + cb_data.args.hipDeviceGetLimit.limit = (hipLimit_t)limit; \ + }; +// hipDeviceGetMemPool[('hipMemPool_t*', 'mem_pool'), ('int', 'device')] +#define INIT_hipDeviceGetMemPool_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetMemPool.mem_pool = (hipMemPool_t *)mem_pool; \ + cb_data.args.hipDeviceGetMemPool.device = (int)device; \ + }; +// hipDeviceGetName[('char*', 'name'), ('int', 'len'), ('hipDevice_t', +// 'device')] +#define INIT_hipDeviceGetName_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetName.name = (char *)name; \ + cb_data.args.hipDeviceGetName.len = (int)len; \ + cb_data.args.hipDeviceGetName.device = (hipDevice_t)device; \ + }; +// hipDeviceGetP2PAttribute[('int*', 'value'), ('hipDeviceP2PAttr', 'attr'), +// ('int', 'srcDevice'), ('int', 'dstDevice')] +#define INIT_hipDeviceGetP2PAttribute_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetP2PAttribute.value = (int *)value; \ + cb_data.args.hipDeviceGetP2PAttribute.attr = (hipDeviceP2PAttr)attr; \ + cb_data.args.hipDeviceGetP2PAttribute.srcDevice = (int)srcDevice; \ + cb_data.args.hipDeviceGetP2PAttribute.dstDevice = (int)dstDevice; \ + }; +// hipDeviceGetPCIBusId[('char*', 'pciBusId'), ('int', 'len'), ('int', +// 'device')] +#define INIT_hipDeviceGetPCIBusId_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetPCIBusId.pciBusId = (char *)pciBusId; \ + cb_data.args.hipDeviceGetPCIBusId.len = (int)len; \ + cb_data.args.hipDeviceGetPCIBusId.device = (int)device; \ + }; +// hipDeviceGetSharedMemConfig[('hipSharedMemConfig*', 'pConfig')] +#define INIT_hipDeviceGetSharedMemConfig_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetSharedMemConfig.pConfig = \ + (hipSharedMemConfig *)pConfig; \ + }; +// hipDeviceGetStreamPriorityRange[('int*', 'leastPriority'), ('int*', +// 'greatestPriority')] +#define INIT_hipDeviceGetStreamPriorityRange_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetStreamPriorityRange.leastPriority = \ + (int *)leastPriority; \ + cb_data.args.hipDeviceGetStreamPriorityRange.greatestPriority = \ + (int *)greatestPriority; \ + }; +// hipDeviceGetUuid[('hipUUID*', 'uuid'), ('hipDevice_t', 'device')] +#define INIT_hipDeviceGetUuid_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGetUuid.uuid = (hipUUID *)uuid; \ + cb_data.args.hipDeviceGetUuid.device = (hipDevice_t)device; \ + }; +// hipDeviceGraphMemTrim[('int', 'device')] +#define INIT_hipDeviceGraphMemTrim_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceGraphMemTrim.device = (int)device; \ + }; +// hipDevicePrimaryCtxGetState[('hipDevice_t', 'dev'), ('unsigned int*', +// 'flags'), ('int*', 'active')] +#define INIT_hipDevicePrimaryCtxGetState_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDevicePrimaryCtxGetState.dev = (hipDevice_t)dev; \ + cb_data.args.hipDevicePrimaryCtxGetState.flags = (unsigned int *)flags; \ + cb_data.args.hipDevicePrimaryCtxGetState.active = (int *)active; \ + }; +// hipDevicePrimaryCtxRelease[('hipDevice_t', 'dev')] +#define INIT_hipDevicePrimaryCtxRelease_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDevicePrimaryCtxRelease.dev = (hipDevice_t)dev; \ + }; +// hipDevicePrimaryCtxReset[('hipDevice_t', 'dev')] +#define INIT_hipDevicePrimaryCtxReset_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDevicePrimaryCtxReset.dev = (hipDevice_t)dev; \ + }; +// hipDevicePrimaryCtxRetain[('hipCtx_t*', 'pctx'), ('hipDevice_t', 'dev')] +#define INIT_hipDevicePrimaryCtxRetain_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDevicePrimaryCtxRetain.pctx = (hipCtx_t *)pctx; \ + cb_data.args.hipDevicePrimaryCtxRetain.dev = (hipDevice_t)dev; \ + }; +// hipDevicePrimaryCtxSetFlags[('hipDevice_t', 'dev'), ('unsigned int', +// 'flags')] +#define INIT_hipDevicePrimaryCtxSetFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDevicePrimaryCtxSetFlags.dev = (hipDevice_t)dev; \ + cb_data.args.hipDevicePrimaryCtxSetFlags.flags = (unsigned int)flags; \ + }; +// hipDeviceReset[] +#define INIT_hipDeviceReset_CB_ARGS_DATA(cb_data) {}; +// hipDeviceSetCacheConfig[('hipFuncCache_t', 'cacheConfig')] +#define INIT_hipDeviceSetCacheConfig_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceSetCacheConfig.cacheConfig = \ + (hipFuncCache_t)cacheConfig; \ + }; +// hipDeviceSetGraphMemAttribute[('int', 'device'), ('hipGraphMemAttributeType', +// 'attr'), ('void*', 'value')] +#define INIT_hipDeviceSetGraphMemAttribute_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceSetGraphMemAttribute.device = (int)device; \ + cb_data.args.hipDeviceSetGraphMemAttribute.attr = \ + (hipGraphMemAttributeType)attr; \ + cb_data.args.hipDeviceSetGraphMemAttribute.value = (void *)value; \ + }; +// hipDeviceSetLimit[('hipLimit_t', 'limit'), ('size_t', 'value')] +#define INIT_hipDeviceSetLimit_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceSetLimit.limit = (hipLimit_t)limit; \ + cb_data.args.hipDeviceSetLimit.value = (size_t)value; \ + }; +// hipDeviceSetMemPool[('int', 'device'), ('hipMemPool_t', 'mem_pool')] +#define INIT_hipDeviceSetMemPool_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceSetMemPool.device = (int)device; \ + cb_data.args.hipDeviceSetMemPool.mem_pool = (hipMemPool_t)mem_pool; \ + }; +// hipDeviceSetSharedMemConfig[('hipSharedMemConfig', 'config')] +#define INIT_hipDeviceSetSharedMemConfig_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceSetSharedMemConfig.config = \ + (hipSharedMemConfig)config; \ + }; +// hipDeviceSynchronize[] +#define INIT_hipDeviceSynchronize_CB_ARGS_DATA(cb_data) {}; +// hipDeviceTotalMem[('size_t*', 'bytes'), ('hipDevice_t', 'device')] +#define INIT_hipDeviceTotalMem_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDeviceTotalMem.bytes = (size_t *)bytes; \ + cb_data.args.hipDeviceTotalMem.device = (hipDevice_t)device; \ + }; +// hipDriverGetVersion[('int*', 'driverVersion')] +#define INIT_hipDriverGetVersion_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDriverGetVersion.driverVersion = (int *)driverVersion; \ + }; +// hipDrvGraphAddMemcpyNode[('hipGraphNode_t*', 'phGraphNode'), ('hipGraph_t', +// 'hGraph'), ('const hipGraphNode_t*', 'dependencies'), ('size_t', +// 'numDependencies'), ('const HIP_MEMCPY3D*', 'copyParams'), ('hipCtx_t', +// 'ctx')] +#define INIT_hipDrvGraphAddMemcpyNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDrvGraphAddMemcpyNode.phGraphNode = \ + (hipGraphNode_t *)phGraphNode; \ + cb_data.args.hipDrvGraphAddMemcpyNode.hGraph = (hipGraph_t)hGraph; \ + cb_data.args.hipDrvGraphAddMemcpyNode.dependencies = \ + (const hipGraphNode_t *)dependencies; \ + cb_data.args.hipDrvGraphAddMemcpyNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipDrvGraphAddMemcpyNode.copyParams = \ + (const HIP_MEMCPY3D *)copyParams; \ + cb_data.args.hipDrvGraphAddMemcpyNode.ctx = (hipCtx_t)ctx; \ + }; +// hipDrvGraphAddMemsetNode[('hipGraphNode_t*', 'phGraphNode'), ('hipGraph_t', +// 'hGraph'), ('const hipGraphNode_t*', 'dependencies'), ('size_t', +// 'numDependencies'), ('const HIP_MEMSET_NODE_PARAMS*', 'memsetParams'), +// ('hipCtx_t', 'ctx')] +#define INIT_hipDrvGraphAddMemsetNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDrvGraphAddMemsetNode.phGraphNode = \ + (hipGraphNode_t *)phGraphNode; \ + cb_data.args.hipDrvGraphAddMemsetNode.hGraph = (hipGraph_t)hGraph; \ + cb_data.args.hipDrvGraphAddMemsetNode.dependencies = \ + (const hipGraphNode_t *)dependencies; \ + cb_data.args.hipDrvGraphAddMemsetNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipDrvGraphAddMemsetNode.memsetParams = \ + (const HIP_MEMSET_NODE_PARAMS *)memsetParams; \ + cb_data.args.hipDrvGraphAddMemsetNode.ctx = (hipCtx_t)ctx; \ + }; +// hipDrvMemcpy2DUnaligned[('const hip_Memcpy2D*', 'pCopy')] +#define INIT_hipDrvMemcpy2DUnaligned_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDrvMemcpy2DUnaligned.pCopy = (const hip_Memcpy2D *)pCopy; \ + }; +// hipDrvMemcpy3D[('const HIP_MEMCPY3D*', 'pCopy')] +#define INIT_hipDrvMemcpy3D_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDrvMemcpy3D.pCopy = (const HIP_MEMCPY3D *)pCopy; \ + }; +// hipDrvMemcpy3DAsync[('const HIP_MEMCPY3D*', 'pCopy'), ('hipStream_t', +// 'stream')] +#define INIT_hipDrvMemcpy3DAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDrvMemcpy3DAsync.pCopy = (const HIP_MEMCPY3D *)pCopy; \ + cb_data.args.hipDrvMemcpy3DAsync.stream = (hipStream_t)stream; \ + }; +// hipDrvPointerGetAttributes[('unsigned int', 'numAttributes'), +// ('hipPointer_attribute*', 'attributes'), ('void**', 'data'), +// ('hipDeviceptr_t', 'ptr')] +#define INIT_hipDrvPointerGetAttributes_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipDrvPointerGetAttributes.numAttributes = \ + (unsigned int)numAttributes; \ + cb_data.args.hipDrvPointerGetAttributes.attributes = \ + (hipPointer_attribute *)attributes; \ + cb_data.args.hipDrvPointerGetAttributes.data = (void **)data; \ + cb_data.args.hipDrvPointerGetAttributes.ptr = (hipDeviceptr_t)ptr; \ + }; +// hipEventCreate[('hipEvent_t*', 'event')] +#define INIT_hipEventCreate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipEventCreate.event = (hipEvent_t *)event; \ + }; +// hipEventCreateWithFlags[('hipEvent_t*', 'event'), ('unsigned int', 'flags')] +#define INIT_hipEventCreateWithFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipEventCreateWithFlags.event = (hipEvent_t *)event; \ + cb_data.args.hipEventCreateWithFlags.flags = (unsigned int)flags; \ + }; +// hipEventDestroy[('hipEvent_t', 'event')] +#define INIT_hipEventDestroy_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipEventDestroy.event = (hipEvent_t)event; \ + }; +// hipEventElapsedTime[('float*', 'ms'), ('hipEvent_t', 'start'), ('hipEvent_t', +// 'stop')] +#define INIT_hipEventElapsedTime_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipEventElapsedTime.ms = (float *)ms; \ + cb_data.args.hipEventElapsedTime.start = (hipEvent_t)start; \ + cb_data.args.hipEventElapsedTime.stop = (hipEvent_t)stop; \ + }; +// hipEventQuery[('hipEvent_t', 'event')] +#define INIT_hipEventQuery_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipEventQuery.event = (hipEvent_t)event; \ + }; +// hipEventRecord[('hipEvent_t', 'event'), ('hipStream_t', 'stream')] +#define INIT_hipEventRecord_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipEventRecord.event = (hipEvent_t)event; \ + cb_data.args.hipEventRecord.stream = (hipStream_t)stream; \ + }; +// hipEventSynchronize[('hipEvent_t', 'event')] +#define INIT_hipEventSynchronize_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipEventSynchronize.event = (hipEvent_t)event; \ + }; +// hipExtGetLastError[] +#define INIT_hipExtGetLastError_CB_ARGS_DATA(cb_data) {}; +// hipExtGetLinkTypeAndHopCount[('int', 'device1'), ('int', 'device2'), +// ('unsigned int*', 'linktype'), ('unsigned int*', 'hopcount')] +#define INIT_hipExtGetLinkTypeAndHopCount_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipExtGetLinkTypeAndHopCount.device1 = (int)device1; \ + cb_data.args.hipExtGetLinkTypeAndHopCount.device2 = (int)device2; \ + cb_data.args.hipExtGetLinkTypeAndHopCount.linktype = \ + (unsigned int *)linktype; \ + cb_data.args.hipExtGetLinkTypeAndHopCount.hopcount = \ + (unsigned int *)hopcount; \ + }; +// hipExtLaunchKernel[('const void*', 'function_address'), ('dim3', +// 'numBlocks'), ('dim3', 'dimBlocks'), ('void**', 'args'), ('size_t', +// 'sharedMemBytes'), ('hipStream_t', 'stream'), ('hipEvent_t', 'startEvent'), +// ('hipEvent_t', 'stopEvent'), ('int', 'flags')] +#define INIT_hipExtLaunchKernel_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipExtLaunchKernel.function_address = \ + (const void *)hostFunction; \ + cb_data.args.hipExtLaunchKernel.numBlocks = (dim3)gridDim; \ + cb_data.args.hipExtLaunchKernel.dimBlocks = (dim3)blockDim; \ + cb_data.args.hipExtLaunchKernel.args = (void **)args; \ + cb_data.args.hipExtLaunchKernel.sharedMemBytes = (size_t)sharedMemBytes; \ + cb_data.args.hipExtLaunchKernel.stream = (hipStream_t)stream; \ + cb_data.args.hipExtLaunchKernel.startEvent = (hipEvent_t)startEvent; \ + cb_data.args.hipExtLaunchKernel.stopEvent = (hipEvent_t)stopEvent; \ + cb_data.args.hipExtLaunchKernel.flags = (int)flags; \ + }; +// hipExtLaunchMultiKernelMultiDevice[('hipLaunchParams*', 'launchParamsList'), +// ('int', 'numDevices'), ('unsigned int', 'flags')] +#define INIT_hipExtLaunchMultiKernelMultiDevice_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipExtLaunchMultiKernelMultiDevice.launchParamsList = \ + (hipLaunchParams *)launchParamsList; \ + cb_data.args.hipExtLaunchMultiKernelMultiDevice.numDevices = \ + (int)numDevices; \ + cb_data.args.hipExtLaunchMultiKernelMultiDevice.flags = \ + (unsigned int)flags; \ + }; +// hipExtMallocWithFlags[('void**', 'ptr'), ('size_t', 'sizeBytes'), ('unsigned +// int', 'flags')] +#define INIT_hipExtMallocWithFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipExtMallocWithFlags.ptr = (void **)ptr; \ + cb_data.args.hipExtMallocWithFlags.sizeBytes = (size_t)sizeBytes; \ + cb_data.args.hipExtMallocWithFlags.flags = (unsigned int)flags; \ + }; +// hipExtModuleLaunchKernel[('hipFunction_t', 'f'), ('unsigned int', +// 'globalWorkSizeX'), ('unsigned int', 'globalWorkSizeY'), ('unsigned int', +// 'globalWorkSizeZ'), ('unsigned int', 'localWorkSizeX'), ('unsigned int', +// 'localWorkSizeY'), ('unsigned int', 'localWorkSizeZ'), ('size_t', +// 'sharedMemBytes'), ('hipStream_t', 'hStream'), ('void**', 'kernelParams'), +// ('void**', 'extra'), ('hipEvent_t', 'startEvent'), ('hipEvent_t', +// 'stopEvent'), ('unsigned int', 'flags')] +#define INIT_hipExtModuleLaunchKernel_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipExtModuleLaunchKernel.f = (hipFunction_t)f; \ + cb_data.args.hipExtModuleLaunchKernel.globalWorkSizeX = \ + (unsigned int)globalWorkSizeX; \ + cb_data.args.hipExtModuleLaunchKernel.globalWorkSizeY = \ + (unsigned int)globalWorkSizeY; \ + cb_data.args.hipExtModuleLaunchKernel.globalWorkSizeZ = \ + (unsigned int)globalWorkSizeZ; \ + cb_data.args.hipExtModuleLaunchKernel.localWorkSizeX = \ + (unsigned int)localWorkSizeX; \ + cb_data.args.hipExtModuleLaunchKernel.localWorkSizeY = \ + (unsigned int)localWorkSizeY; \ + cb_data.args.hipExtModuleLaunchKernel.localWorkSizeZ = \ + (unsigned int)localWorkSizeZ; \ + cb_data.args.hipExtModuleLaunchKernel.sharedMemBytes = \ + (size_t)sharedMemBytes; \ + cb_data.args.hipExtModuleLaunchKernel.hStream = (hipStream_t)hStream; \ + cb_data.args.hipExtModuleLaunchKernel.kernelParams = \ + (void **)kernelParams; \ + cb_data.args.hipExtModuleLaunchKernel.extra = (void **)extra; \ + cb_data.args.hipExtModuleLaunchKernel.startEvent = (hipEvent_t)startEvent; \ + cb_data.args.hipExtModuleLaunchKernel.stopEvent = (hipEvent_t)stopEvent; \ + cb_data.args.hipExtModuleLaunchKernel.flags = (unsigned int)flags; \ + }; +// hipExtStreamCreateWithCUMask[('hipStream_t*', 'stream'), ('unsigned int', +// 'cuMaskSize'), ('const unsigned int*', 'cuMask')] +#define INIT_hipExtStreamCreateWithCUMask_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipExtStreamCreateWithCUMask.stream = (hipStream_t *)stream; \ + cb_data.args.hipExtStreamCreateWithCUMask.cuMaskSize = \ + (unsigned int)cuMaskSize; \ + cb_data.args.hipExtStreamCreateWithCUMask.cuMask = \ + (const unsigned int *)cuMask; \ + }; +// hipExtStreamGetCUMask[('hipStream_t', 'stream'), ('unsigned int', +// 'cuMaskSize'), ('unsigned int*', 'cuMask')] +#define INIT_hipExtStreamGetCUMask_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipExtStreamGetCUMask.stream = (hipStream_t)stream; \ + cb_data.args.hipExtStreamGetCUMask.cuMaskSize = (unsigned int)cuMaskSize; \ + cb_data.args.hipExtStreamGetCUMask.cuMask = (unsigned int *)cuMask; \ + }; +// hipExternalMemoryGetMappedBuffer[('void**', 'devPtr'), +// ('hipExternalMemory_t', 'extMem'), ('const hipExternalMemoryBufferDesc*', +// 'bufferDesc')] +#define INIT_hipExternalMemoryGetMappedBuffer_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipExternalMemoryGetMappedBuffer.devPtr = (void **)devPtr; \ + cb_data.args.hipExternalMemoryGetMappedBuffer.extMem = \ + (hipExternalMemory_t)extMem; \ + cb_data.args.hipExternalMemoryGetMappedBuffer.bufferDesc = \ + (const hipExternalMemoryBufferDesc *)bufferDesc; \ + }; +// hipExternalMemoryGetMappedMipmappedArray[('hipMipmappedArray_t*', 'mipmap'), +// ('hipExternalMemory_t', 'extMem'), ('const +// hipExternalMemoryMipmappedArrayDesc*', 'mipmapDesc')] +#define INIT_hipExternalMemoryGetMappedMipmappedArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipExternalMemoryGetMappedMipmappedArray.mipmap = \ + (hipMipmappedArray_t *)mipmap; \ + cb_data.args.hipExternalMemoryGetMappedMipmappedArray.extMem = \ + (hipExternalMemory_t)extMem; \ + cb_data.args.hipExternalMemoryGetMappedMipmappedArray.mipmapDesc = \ + (const hipExternalMemoryMipmappedArrayDesc *)mipmapDesc; \ + }; +// hipFree[('void*', 'ptr')] +#define INIT_hipFree_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipFree.ptr = (void *)ptr; \ + }; +// hipFreeArray[('hipArray_t', 'array')] +#define INIT_hipFreeArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipFreeArray.array = (hipArray_t)array; \ + }; +// hipFreeAsync[('void*', 'dev_ptr'), ('hipStream_t', 'stream')] +#define INIT_hipFreeAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipFreeAsync.dev_ptr = (void *)dev_ptr; \ + cb_data.args.hipFreeAsync.stream = (hipStream_t)stream; \ + }; +// hipFreeHost[('void*', 'ptr')] +#define INIT_hipFreeHost_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipFreeHost.ptr = (void *)ptr; \ + }; +// hipFreeMipmappedArray[('hipMipmappedArray_t', 'mipmappedArray')] +#define INIT_hipFreeMipmappedArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipFreeMipmappedArray.mipmappedArray = \ + (hipMipmappedArray_t)mipmappedArray; \ + }; +// hipFuncGetAttribute[('int*', 'value'), ('hipFunction_attribute', 'attrib'), +// ('hipFunction_t', 'hfunc')] +#define INIT_hipFuncGetAttribute_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipFuncGetAttribute.value = (int *)value; \ + cb_data.args.hipFuncGetAttribute.attrib = (hipFunction_attribute)attrib; \ + cb_data.args.hipFuncGetAttribute.hfunc = (hipFunction_t)hfunc; \ + }; +// hipFuncGetAttributes[('hipFuncAttributes*', 'attr'), ('const void*', 'func')] +#define INIT_hipFuncGetAttributes_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipFuncGetAttributes.attr = (hipFuncAttributes *)attr; \ + cb_data.args.hipFuncGetAttributes.func = (const void *)func; \ + }; +// hipFuncSetAttribute[('const void*', 'func'), ('hipFuncAttribute', 'attr'), +// ('int', 'value')] +#define INIT_hipFuncSetAttribute_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipFuncSetAttribute.func = (const void *)func; \ + cb_data.args.hipFuncSetAttribute.attr = (hipFuncAttribute)attr; \ + cb_data.args.hipFuncSetAttribute.value = (int)value; \ + }; +// hipFuncSetCacheConfig[('const void*', 'func'), ('hipFuncCache_t', 'config')] +#define INIT_hipFuncSetCacheConfig_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipFuncSetCacheConfig.func = (const void *)func; \ + cb_data.args.hipFuncSetCacheConfig.config = (hipFuncCache_t)cacheConfig; \ + }; +// hipFuncSetSharedMemConfig[('const void*', 'func'), ('hipSharedMemConfig', +// 'config')] +#define INIT_hipFuncSetSharedMemConfig_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipFuncSetSharedMemConfig.func = (const void *)func; \ + cb_data.args.hipFuncSetSharedMemConfig.config = \ + (hipSharedMemConfig)config; \ + }; +// hipGLGetDevices[('unsigned int*', 'pHipDeviceCount'), ('int*', +// 'pHipDevices'), ('unsigned int', 'hipDeviceCount'), ('hipGLDeviceList', +// 'deviceList')] +#define INIT_hipGLGetDevices_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGLGetDevices.pHipDeviceCount = \ + (unsigned int *)pHipDeviceCount; \ + cb_data.args.hipGLGetDevices.pHipDevices = (int *)pHipDevices; \ + cb_data.args.hipGLGetDevices.hipDeviceCount = \ + (unsigned int)hipDeviceCount; \ + cb_data.args.hipGLGetDevices.deviceList = (hipGLDeviceList)deviceList; \ + }; +// hipGetChannelDesc[('hipChannelFormatDesc*', 'desc'), ('hipArray_const_t', +// 'array')] +#define INIT_hipGetChannelDesc_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGetChannelDesc.desc = (hipChannelFormatDesc *)desc; \ + cb_data.args.hipGetChannelDesc.array = (hipArray_const_t)array; \ + }; +// hipGetDevice[('int*', 'deviceId')] +#define INIT_hipGetDevice_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGetDevice.deviceId = (int *)deviceId; \ + }; +// hipGetDeviceCount[('int*', 'count')] +#define INIT_hipGetDeviceCount_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGetDeviceCount.count = (int *)count; \ + }; +// hipGetDeviceFlags[('unsigned int*', 'flags')] +#define INIT_hipGetDeviceFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGetDeviceFlags.flags = (unsigned int *)flags; \ + }; +// hipGetDevicePropertiesR0000[('hipDeviceProp_tR0000*', 'prop'), ('int', +// 'device')] +#define INIT_hipGetDevicePropertiesR0000_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGetDevicePropertiesR0000.prop = \ + (hipDeviceProp_tR0000 *)prop; \ + cb_data.args.hipGetDevicePropertiesR0000.device = (int)device; \ + }; +// hipGetDevicePropertiesR0600[('hipDeviceProp_tR0600*', 'prop'), ('int', +// 'deviceId')] +#define INIT_hipGetDevicePropertiesR0600_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGetDevicePropertiesR0600.prop = \ + (hipDeviceProp_tR0600 *)prop; \ + cb_data.args.hipGetDevicePropertiesR0600.deviceId = (int)device; \ + }; +// hipGetErrorString[] +#define INIT_hipGetErrorString_CB_ARGS_DATA(cb_data) {}; +// hipGetFuncBySymbol[('hipFunction_t*', 'functionPtr'), ('const void*', +// 'symbolPtr')] +#define INIT_hipGetFuncBySymbol_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGetFuncBySymbol.functionPtr = \ + (hipFunction_t *)functionPtr; \ + cb_data.args.hipGetFuncBySymbol.symbolPtr = (const void *)symbolPtr; \ + }; +// hipGetLastError[] +#define INIT_hipGetLastError_CB_ARGS_DATA(cb_data) {}; +// hipGetMipmappedArrayLevel[('hipArray_t*', 'levelArray'), +// ('hipMipmappedArray_const_t', 'mipmappedArray'), ('unsigned int', 'level')] +#define INIT_hipGetMipmappedArrayLevel_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGetMipmappedArrayLevel.levelArray = \ + (hipArray_t *)levelArray; \ + cb_data.args.hipGetMipmappedArrayLevel.mipmappedArray = \ + (hipMipmappedArray_const_t)mipmappedArray; \ + cb_data.args.hipGetMipmappedArrayLevel.level = (unsigned int)level; \ + }; +// hipGetProcAddress[('const char*', 'symbol'), ('void**', 'pfn'), ('int', +// 'hipVersion'), ('uint64_t', 'flags'), ('hipDriverProcAddressQueryResult*', +// 'symbolStatus')] +#define INIT_hipGetProcAddress_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGetProcAddress.symbol = (symbol) ? strdup(symbol) : NULL; \ + cb_data.args.hipGetProcAddress.pfn = (void **)pfn; \ + cb_data.args.hipGetProcAddress.hipVersion = (int)hipVersion; \ + cb_data.args.hipGetProcAddress.flags = (uint64_t)flags; \ + cb_data.args.hipGetProcAddress.symbolStatus = \ + (hipDriverProcAddressQueryResult *)symbolStatus; \ + }; +// hipGetSymbolAddress[('void**', 'devPtr'), ('const void*', 'symbol')] +#define INIT_hipGetSymbolAddress_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGetSymbolAddress.devPtr = (void **)devPtr; \ + cb_data.args.hipGetSymbolAddress.symbol = (const void *)symbol; \ + }; +// hipGetSymbolSize[('size_t*', 'size'), ('const void*', 'symbol')] +#define INIT_hipGetSymbolSize_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGetSymbolSize.size = (size_t *)sizePtr; \ + cb_data.args.hipGetSymbolSize.symbol = (const void *)symbol; \ + }; +// hipGraphAddChildGraphNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', +// 'numDependencies'), ('hipGraph_t', 'childGraph')] +#define INIT_hipGraphAddChildGraphNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddChildGraphNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddChildGraphNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddChildGraphNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddChildGraphNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddChildGraphNode.childGraph = \ + (hipGraph_t)childGraph; \ + }; +// hipGraphAddDependencies[('hipGraph_t', 'graph'), ('const hipGraphNode_t*', +// 'from'), ('const hipGraphNode_t*', 'to'), ('size_t', 'numDependencies')] +#define INIT_hipGraphAddDependencies_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddDependencies.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddDependencies.from = (const hipGraphNode_t *)from; \ + cb_data.args.hipGraphAddDependencies.to = (const hipGraphNode_t *)to; \ + cb_data.args.hipGraphAddDependencies.numDependencies = \ + (size_t)numDependencies; \ + }; +// hipGraphAddEmptyNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', +// 'numDependencies')] +#define INIT_hipGraphAddEmptyNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddEmptyNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddEmptyNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddEmptyNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddEmptyNode.numDependencies = \ + (size_t)numDependencies; \ + }; +// hipGraphAddEventRecordNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', +// 'numDependencies'), ('hipEvent_t', 'event')] +#define INIT_hipGraphAddEventRecordNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddEventRecordNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddEventRecordNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddEventRecordNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddEventRecordNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddEventRecordNode.event = (hipEvent_t)event; \ + }; +// hipGraphAddEventWaitNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', +// 'numDependencies'), ('hipEvent_t', 'event')] +#define INIT_hipGraphAddEventWaitNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddEventWaitNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddEventWaitNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddEventWaitNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddEventWaitNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddEventWaitNode.event = (hipEvent_t)event; \ + }; +// hipGraphAddExternalSemaphoresSignalNode[('hipGraphNode_t*', 'pGraphNode'), +// ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), +// ('size_t', 'numDependencies'), ('const +// hipExternalSemaphoreSignalNodeParams*', 'nodeParams')] +#define INIT_hipGraphAddExternalSemaphoresSignalNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddExternalSemaphoresSignalNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddExternalSemaphoresSignalNode.graph = \ + (hipGraph_t)graph; \ + cb_data.args.hipGraphAddExternalSemaphoresSignalNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddExternalSemaphoresSignalNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddExternalSemaphoresSignalNode.nodeParams = \ + (const hipExternalSemaphoreSignalNodeParams *)nodeParams; \ + }; +// hipGraphAddExternalSemaphoresWaitNode[('hipGraphNode_t*', 'pGraphNode'), +// ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), +// ('size_t', 'numDependencies'), ('const hipExternalSemaphoreWaitNodeParams*', +// 'nodeParams')] +#define INIT_hipGraphAddExternalSemaphoresWaitNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddExternalSemaphoresWaitNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddExternalSemaphoresWaitNode.graph = \ + (hipGraph_t)graph; \ + cb_data.args.hipGraphAddExternalSemaphoresWaitNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddExternalSemaphoresWaitNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddExternalSemaphoresWaitNode.nodeParams = \ + (const hipExternalSemaphoreWaitNodeParams *)nodeParams; \ + }; +// hipGraphAddHostNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', +// 'numDependencies'), ('const hipHostNodeParams*', 'pNodeParams')] +#define INIT_hipGraphAddHostNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddHostNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddHostNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddHostNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddHostNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddHostNode.pNodeParams = \ + (const hipHostNodeParams *)pNodeParams; \ + }; +// hipGraphAddKernelNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', +// 'numDependencies'), ('const hipKernelNodeParams*', 'pNodeParams')] +#define INIT_hipGraphAddKernelNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddKernelNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddKernelNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddKernelNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddKernelNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddKernelNode.pNodeParams = \ + (const hipKernelNodeParams *)pNodeParams; \ + }; +// hipGraphAddMemAllocNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', +// 'numDependencies'), ('hipMemAllocNodeParams*', 'pNodeParams')] +#define INIT_hipGraphAddMemAllocNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddMemAllocNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddMemAllocNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddMemAllocNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddMemAllocNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddMemAllocNode.pNodeParams = \ + (hipMemAllocNodeParams *)pNodeParams; \ + }; +// hipGraphAddMemFreeNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', +// 'numDependencies'), ('void*', 'dev_ptr')] +#define INIT_hipGraphAddMemFreeNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddMemFreeNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddMemFreeNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddMemFreeNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddMemFreeNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddMemFreeNode.dev_ptr = (void *)dev_ptr; \ + }; +// hipGraphAddMemcpyNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', +// 'numDependencies'), ('const hipMemcpy3DParms*', 'pCopyParams')] +#define INIT_hipGraphAddMemcpyNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddMemcpyNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddMemcpyNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddMemcpyNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddMemcpyNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddMemcpyNode.pCopyParams = \ + (const hipMemcpy3DParms *)pCopyParams; \ + }; +// hipGraphAddMemcpyNode1D[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', +// 'numDependencies'), ('void*', 'dst'), ('const void*', 'src'), ('size_t', +// 'count'), ('hipMemcpyKind', 'kind')] +#define INIT_hipGraphAddMemcpyNode1D_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddMemcpyNode1D.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddMemcpyNode1D.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddMemcpyNode1D.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddMemcpyNode1D.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddMemcpyNode1D.dst = (void *)dst; \ + cb_data.args.hipGraphAddMemcpyNode1D.src = (const void *)src; \ + cb_data.args.hipGraphAddMemcpyNode1D.count = (size_t)count; \ + cb_data.args.hipGraphAddMemcpyNode1D.kind = (hipMemcpyKind)kind; \ + }; +// hipGraphAddMemcpyNodeFromSymbol[('hipGraphNode_t*', 'pGraphNode'), +// ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), +// ('size_t', 'numDependencies'), ('void*', 'dst'), ('const void*', 'symbol'), +// ('size_t', 'count'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind')] +#define INIT_hipGraphAddMemcpyNodeFromSymbol_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddMemcpyNodeFromSymbol.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddMemcpyNodeFromSymbol.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddMemcpyNodeFromSymbol.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddMemcpyNodeFromSymbol.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddMemcpyNodeFromSymbol.dst = (void *)dst; \ + cb_data.args.hipGraphAddMemcpyNodeFromSymbol.symbol = \ + (const void *)symbol; \ + cb_data.args.hipGraphAddMemcpyNodeFromSymbol.count = (size_t)count; \ + cb_data.args.hipGraphAddMemcpyNodeFromSymbol.offset = (size_t)offset; \ + cb_data.args.hipGraphAddMemcpyNodeFromSymbol.kind = (hipMemcpyKind)kind; \ + }; +// hipGraphAddMemcpyNodeToSymbol[('hipGraphNode_t*', 'pGraphNode'), +// ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), +// ('size_t', 'numDependencies'), ('const void*', 'symbol'), ('const void*', +// 'src'), ('size_t', 'count'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind')] +#define INIT_hipGraphAddMemcpyNodeToSymbol_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddMemcpyNodeToSymbol.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddMemcpyNodeToSymbol.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddMemcpyNodeToSymbol.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddMemcpyNodeToSymbol.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddMemcpyNodeToSymbol.symbol = (const void *)symbol; \ + cb_data.args.hipGraphAddMemcpyNodeToSymbol.src = (const void *)src; \ + cb_data.args.hipGraphAddMemcpyNodeToSymbol.count = (size_t)count; \ + cb_data.args.hipGraphAddMemcpyNodeToSymbol.offset = (size_t)offset; \ + cb_data.args.hipGraphAddMemcpyNodeToSymbol.kind = (hipMemcpyKind)kind; \ + }; +// hipGraphAddMemsetNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', +// 'numDependencies'), ('const hipMemsetParams*', 'pMemsetParams')] +#define INIT_hipGraphAddMemsetNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddMemsetNode.pGraphNode = \ + (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddMemsetNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddMemsetNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddMemsetNode.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipGraphAddMemsetNode.pMemsetParams = \ + (const hipMemsetParams *)pMemsetParams; \ + }; +// hipGraphAddNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', 'graph'), +// ('const hipGraphNode_t*', 'pDependencies'), ('size_t', 'numDependencies'), +// ('hipGraphNodeParams*', 'nodeParams')] +#define INIT_hipGraphAddNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphAddNode.pGraphNode = (hipGraphNode_t *)pGraphNode; \ + cb_data.args.hipGraphAddNode.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphAddNode.pDependencies = \ + (const hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphAddNode.numDependencies = (size_t)numDependencies; \ + cb_data.args.hipGraphAddNode.nodeParams = \ + (hipGraphNodeParams *)nodeParams; \ + }; +// hipGraphChildGraphNodeGetGraph[('hipGraphNode_t', 'node'), ('hipGraph_t*', +// 'pGraph')] +#define INIT_hipGraphChildGraphNodeGetGraph_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphChildGraphNodeGetGraph.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphChildGraphNodeGetGraph.pGraph = (hipGraph_t *)pGraph; \ + }; +// hipGraphClone[('hipGraph_t*', 'pGraphClone'), ('hipGraph_t', +// 'originalGraph')] +#define INIT_hipGraphClone_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphClone.pGraphClone = (hipGraph_t *)pGraphClone; \ + cb_data.args.hipGraphClone.originalGraph = (hipGraph_t)originalGraph; \ + }; +// hipGraphCreate[('hipGraph_t*', 'pGraph'), ('unsigned int', 'flags')] +#define INIT_hipGraphCreate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphCreate.pGraph = (hipGraph_t *)pGraph; \ + cb_data.args.hipGraphCreate.flags = (unsigned int)flags; \ + }; +// hipGraphDebugDotPrint[('hipGraph_t', 'graph'), ('const char*', 'path'), +// ('unsigned int', 'flags')] +#define INIT_hipGraphDebugDotPrint_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphDebugDotPrint.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphDebugDotPrint.path = (path) ? strdup(path) : NULL; \ + cb_data.args.hipGraphDebugDotPrint.flags = (unsigned int)flags; \ + }; +// hipGraphDestroy[('hipGraph_t', 'graph')] +#define INIT_hipGraphDestroy_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphDestroy.graph = (hipGraph_t)graph; \ + }; +// hipGraphDestroyNode[('hipGraphNode_t', 'node')] +#define INIT_hipGraphDestroyNode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphDestroyNode.node = (hipGraphNode_t)node; \ + }; +// hipGraphEventRecordNodeGetEvent[('hipGraphNode_t', 'node'), ('hipEvent_t*', +// 'event_out')] +#define INIT_hipGraphEventRecordNodeGetEvent_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphEventRecordNodeGetEvent.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphEventRecordNodeGetEvent.event_out = \ + (hipEvent_t *)event_out; \ + }; +// hipGraphEventRecordNodeSetEvent[('hipGraphNode_t', 'node'), ('hipEvent_t', +// 'event')] +#define INIT_hipGraphEventRecordNodeSetEvent_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphEventRecordNodeSetEvent.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphEventRecordNodeSetEvent.event = (hipEvent_t)event; \ + }; +// hipGraphEventWaitNodeGetEvent[('hipGraphNode_t', 'node'), ('hipEvent_t*', +// 'event_out')] +#define INIT_hipGraphEventWaitNodeGetEvent_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphEventWaitNodeGetEvent.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphEventWaitNodeGetEvent.event_out = \ + (hipEvent_t *)event_out; \ + }; +// hipGraphEventWaitNodeSetEvent[('hipGraphNode_t', 'node'), ('hipEvent_t', +// 'event')] +#define INIT_hipGraphEventWaitNodeSetEvent_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphEventWaitNodeSetEvent.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphEventWaitNodeSetEvent.event = (hipEvent_t)event; \ + }; +// hipGraphExecChildGraphNodeSetParams[('hipGraphExec_t', 'hGraphExec'), +// ('hipGraphNode_t', 'node'), ('hipGraph_t', 'childGraph')] +#define INIT_hipGraphExecChildGraphNodeSetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecChildGraphNodeSetParams.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecChildGraphNodeSetParams.node = \ + (hipGraphNode_t)node; \ + cb_data.args.hipGraphExecChildGraphNodeSetParams.childGraph = \ + (hipGraph_t)childGraph; \ + }; +// hipGraphExecDestroy[('hipGraphExec_t', 'graphExec')] +#define INIT_hipGraphExecDestroy_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecDestroy.graphExec = (hipGraphExec_t)pGraphExec; \ + }; +// hipGraphExecEventRecordNodeSetEvent[('hipGraphExec_t', 'hGraphExec'), +// ('hipGraphNode_t', 'hNode'), ('hipEvent_t', 'event')] +#define INIT_hipGraphExecEventRecordNodeSetEvent_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecEventRecordNodeSetEvent.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecEventRecordNodeSetEvent.hNode = \ + (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExecEventRecordNodeSetEvent.event = \ + (hipEvent_t)event; \ + }; +// hipGraphExecEventWaitNodeSetEvent[('hipGraphExec_t', 'hGraphExec'), +// ('hipGraphNode_t', 'hNode'), ('hipEvent_t', 'event')] +#define INIT_hipGraphExecEventWaitNodeSetEvent_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecEventWaitNodeSetEvent.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecEventWaitNodeSetEvent.hNode = \ + (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExecEventWaitNodeSetEvent.event = (hipEvent_t)event; \ + }; +// hipGraphExecExternalSemaphoresSignalNodeSetParams[('hipGraphExec_t', +// 'hGraphExec'), ('hipGraphNode_t', 'hNode'), ('const +// hipExternalSemaphoreSignalNodeParams*', 'nodeParams')] +#define INIT_hipGraphExecExternalSemaphoresSignalNodeSetParams_CB_ARGS_DATA( \ + cb_data) \ + { \ + cb_data.args.hipGraphExecExternalSemaphoresSignalNodeSetParams \ + .hGraphExec = (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecExternalSemaphoresSignalNodeSetParams.hNode = \ + (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExecExternalSemaphoresSignalNodeSetParams \ + .nodeParams = \ + (const hipExternalSemaphoreSignalNodeParams *)nodeParams; \ + }; +// hipGraphExecExternalSemaphoresWaitNodeSetParams[('hipGraphExec_t', +// 'hGraphExec'), ('hipGraphNode_t', 'hNode'), ('const +// hipExternalSemaphoreWaitNodeParams*', 'nodeParams')] +#define INIT_hipGraphExecExternalSemaphoresWaitNodeSetParams_CB_ARGS_DATA( \ + cb_data) \ + { \ + cb_data.args.hipGraphExecExternalSemaphoresWaitNodeSetParams.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecExternalSemaphoresWaitNodeSetParams.hNode = \ + (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExecExternalSemaphoresWaitNodeSetParams.nodeParams = \ + (const hipExternalSemaphoreWaitNodeParams *)nodeParams; \ + }; +// hipGraphExecHostNodeSetParams[('hipGraphExec_t', 'hGraphExec'), +// ('hipGraphNode_t', 'node'), ('const hipHostNodeParams*', 'pNodeParams')] +#define INIT_hipGraphExecHostNodeSetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecHostNodeSetParams.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecHostNodeSetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphExecHostNodeSetParams.pNodeParams = \ + (const hipHostNodeParams *)pNodeParams; \ + }; +// hipGraphExecKernelNodeSetParams[('hipGraphExec_t', 'hGraphExec'), +// ('hipGraphNode_t', 'node'), ('const hipKernelNodeParams*', 'pNodeParams')] +#define INIT_hipGraphExecKernelNodeSetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecKernelNodeSetParams.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecKernelNodeSetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphExecKernelNodeSetParams.pNodeParams = \ + (const hipKernelNodeParams *)pNodeParams; \ + }; +// hipGraphExecMemcpyNodeSetParams[('hipGraphExec_t', 'hGraphExec'), +// ('hipGraphNode_t', 'node'), ('hipMemcpy3DParms*', 'pNodeParams')] +#define INIT_hipGraphExecMemcpyNodeSetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecMemcpyNodeSetParams.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecMemcpyNodeSetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphExecMemcpyNodeSetParams.pNodeParams = \ + (hipMemcpy3DParms *)pNodeParams; \ + }; +// hipGraphExecMemcpyNodeSetParams1D[('hipGraphExec_t', 'hGraphExec'), +// ('hipGraphNode_t', 'node'), ('void*', 'dst'), ('const void*', 'src'), +// ('size_t', 'count'), ('hipMemcpyKind', 'kind')] +#define INIT_hipGraphExecMemcpyNodeSetParams1D_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecMemcpyNodeSetParams1D.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecMemcpyNodeSetParams1D.node = \ + (hipGraphNode_t)node; \ + cb_data.args.hipGraphExecMemcpyNodeSetParams1D.dst = (void *)dst; \ + cb_data.args.hipGraphExecMemcpyNodeSetParams1D.src = (const void *)src; \ + cb_data.args.hipGraphExecMemcpyNodeSetParams1D.count = (size_t)count; \ + cb_data.args.hipGraphExecMemcpyNodeSetParams1D.kind = (hipMemcpyKind)kind; \ + }; +// hipGraphExecMemcpyNodeSetParamsFromSymbol[('hipGraphExec_t', 'hGraphExec'), +// ('hipGraphNode_t', 'node'), ('void*', 'dst'), ('const void*', 'symbol'), +// ('size_t', 'count'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind')] +#define INIT_hipGraphExecMemcpyNodeSetParamsFromSymbol_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsFromSymbol.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsFromSymbol.node = \ + (hipGraphNode_t)node; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsFromSymbol.dst = (void *)dst; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsFromSymbol.symbol = \ + (const void *)symbol; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsFromSymbol.count = \ + (size_t)count; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsFromSymbol.offset = \ + (size_t)offset; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsFromSymbol.kind = \ + (hipMemcpyKind)kind; \ + }; +// hipGraphExecMemcpyNodeSetParamsToSymbol[('hipGraphExec_t', 'hGraphExec'), +// ('hipGraphNode_t', 'node'), ('const void*', 'symbol'), ('const void*', +// 'src'), ('size_t', 'count'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind')] +#define INIT_hipGraphExecMemcpyNodeSetParamsToSymbol_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsToSymbol.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsToSymbol.node = \ + (hipGraphNode_t)node; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsToSymbol.symbol = \ + (const void *)symbol; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsToSymbol.src = \ + (const void *)src; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsToSymbol.count = \ + (size_t)count; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsToSymbol.offset = \ + (size_t)offset; \ + cb_data.args.hipGraphExecMemcpyNodeSetParamsToSymbol.kind = \ + (hipMemcpyKind)kind; \ + }; +// hipGraphExecMemsetNodeSetParams[('hipGraphExec_t', 'hGraphExec'), +// ('hipGraphNode_t', 'node'), ('const hipMemsetParams*', 'pNodeParams')] +#define INIT_hipGraphExecMemsetNodeSetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecMemsetNodeSetParams.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecMemsetNodeSetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphExecMemsetNodeSetParams.pNodeParams = \ + (const hipMemsetParams *)pNodeParams; \ + }; +// hipGraphExecUpdate[('hipGraphExec_t', 'hGraphExec'), ('hipGraph_t', +// 'hGraph'), ('hipGraphNode_t*', 'hErrorNode_out'), +// ('hipGraphExecUpdateResult*', 'updateResult_out')] +#define INIT_hipGraphExecUpdate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExecUpdate.hGraphExec = (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphExecUpdate.hGraph = (hipGraph_t)hGraph; \ + cb_data.args.hipGraphExecUpdate.hErrorNode_out = \ + (hipGraphNode_t *)hErrorNode_out; \ + cb_data.args.hipGraphExecUpdate.updateResult_out = \ + (hipGraphExecUpdateResult *)updateResult_out; \ + }; +// hipGraphExternalSemaphoresSignalNodeGetParams[('hipGraphNode_t', 'hNode'), +// ('hipExternalSemaphoreSignalNodeParams*', 'params_out')] +#define INIT_hipGraphExternalSemaphoresSignalNodeGetParams_CB_ARGS_DATA( \ + cb_data) \ + { \ + cb_data.args.hipGraphExternalSemaphoresSignalNodeGetParams.hNode = \ + (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExternalSemaphoresSignalNodeGetParams.params_out = \ + (hipExternalSemaphoreSignalNodeParams *)params_out; \ + }; +// hipGraphExternalSemaphoresSignalNodeSetParams[('hipGraphNode_t', 'hNode'), +// ('const hipExternalSemaphoreSignalNodeParams*', 'nodeParams')] +#define INIT_hipGraphExternalSemaphoresSignalNodeSetParams_CB_ARGS_DATA( \ + cb_data) \ + { \ + cb_data.args.hipGraphExternalSemaphoresSignalNodeSetParams.hNode = \ + (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExternalSemaphoresSignalNodeSetParams.nodeParams = \ + (const hipExternalSemaphoreSignalNodeParams *)nodeParams; \ + }; +// hipGraphExternalSemaphoresWaitNodeGetParams[('hipGraphNode_t', 'hNode'), +// ('hipExternalSemaphoreWaitNodeParams*', 'params_out')] +#define INIT_hipGraphExternalSemaphoresWaitNodeGetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExternalSemaphoresWaitNodeGetParams.hNode = \ + (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExternalSemaphoresWaitNodeGetParams.params_out = \ + (hipExternalSemaphoreWaitNodeParams *)params_out; \ + }; +// hipGraphExternalSemaphoresWaitNodeSetParams[('hipGraphNode_t', 'hNode'), +// ('const hipExternalSemaphoreWaitNodeParams*', 'nodeParams')] +#define INIT_hipGraphExternalSemaphoresWaitNodeSetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphExternalSemaphoresWaitNodeSetParams.hNode = \ + (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphExternalSemaphoresWaitNodeSetParams.nodeParams = \ + (const hipExternalSemaphoreWaitNodeParams *)nodeParams; \ + }; +// hipGraphGetEdges[('hipGraph_t', 'graph'), ('hipGraphNode_t*', 'from'), +// ('hipGraphNode_t*', 'to'), ('size_t*', 'numEdges')] +#define INIT_hipGraphGetEdges_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphGetEdges.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphGetEdges.from = (hipGraphNode_t *)from; \ + cb_data.args.hipGraphGetEdges.to = (hipGraphNode_t *)to; \ + cb_data.args.hipGraphGetEdges.numEdges = (size_t *)numEdges; \ + }; +// hipGraphGetNodes[('hipGraph_t', 'graph'), ('hipGraphNode_t*', 'nodes'), +// ('size_t*', 'numNodes')] +#define INIT_hipGraphGetNodes_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphGetNodes.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphGetNodes.nodes = (hipGraphNode_t *)nodes; \ + cb_data.args.hipGraphGetNodes.numNodes = (size_t *)numNodes; \ + }; +// hipGraphGetRootNodes[('hipGraph_t', 'graph'), ('hipGraphNode_t*', +// 'pRootNodes'), ('size_t*', 'pNumRootNodes')] +#define INIT_hipGraphGetRootNodes_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphGetRootNodes.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphGetRootNodes.pRootNodes = \ + (hipGraphNode_t *)pRootNodes; \ + cb_data.args.hipGraphGetRootNodes.pNumRootNodes = (size_t *)pNumRootNodes; \ + }; +// hipGraphHostNodeGetParams[('hipGraphNode_t', 'node'), ('hipHostNodeParams*', +// 'pNodeParams')] +#define INIT_hipGraphHostNodeGetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphHostNodeGetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphHostNodeGetParams.pNodeParams = \ + (hipHostNodeParams *)pNodeParams; \ + }; +// hipGraphHostNodeSetParams[('hipGraphNode_t', 'node'), ('const +// hipHostNodeParams*', 'pNodeParams')] +#define INIT_hipGraphHostNodeSetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphHostNodeSetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphHostNodeSetParams.pNodeParams = \ + (const hipHostNodeParams *)pNodeParams; \ + }; +// hipGraphInstantiate[('hipGraphExec_t*', 'pGraphExec'), ('hipGraph_t', +// 'graph'), ('hipGraphNode_t*', 'pErrorNode'), ('char*', 'pLogBuffer'), +// ('size_t', 'bufferSize')] +#define INIT_hipGraphInstantiate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphInstantiate.pGraphExec = \ + (hipGraphExec_t *)pGraphExec; \ + cb_data.args.hipGraphInstantiate.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphInstantiate.pErrorNode = \ + (hipGraphNode_t *)pErrorNode; \ + cb_data.args.hipGraphInstantiate.pLogBuffer = (char *)pLogBuffer; \ + cb_data.args.hipGraphInstantiate.bufferSize = (size_t)bufferSize; \ + }; +// hipGraphInstantiateWithFlags[('hipGraphExec_t*', 'pGraphExec'), +// ('hipGraph_t', 'graph'), ('unsigned long long', 'flags')] +#define INIT_hipGraphInstantiateWithFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphInstantiateWithFlags.pGraphExec = \ + (hipGraphExec_t *)pGraphExec; \ + cb_data.args.hipGraphInstantiateWithFlags.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphInstantiateWithFlags.flags = \ + (unsigned long long)flags; \ + }; +// hipGraphInstantiateWithParams[('hipGraphExec_t*', 'pGraphExec'), +// ('hipGraph_t', 'graph'), ('hipGraphInstantiateParams*', 'instantiateParams')] +#define INIT_hipGraphInstantiateWithParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphInstantiateWithParams.pGraphExec = \ + (hipGraphExec_t *)pGraphExec; \ + cb_data.args.hipGraphInstantiateWithParams.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphInstantiateWithParams.instantiateParams = \ + (hipGraphInstantiateParams *)instantiateParams; \ + }; +// hipGraphKernelNodeCopyAttributes[('hipGraphNode_t', 'hSrc'), +// ('hipGraphNode_t', 'hDst')] +#define INIT_hipGraphKernelNodeCopyAttributes_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphKernelNodeCopyAttributes.hSrc = (hipGraphNode_t)hSrc; \ + cb_data.args.hipGraphKernelNodeCopyAttributes.hDst = (hipGraphNode_t)hDst; \ + }; +// hipGraphKernelNodeGetAttribute[('hipGraphNode_t', 'hNode'), +// ('hipLaunchAttributeID', 'attr'), ('hipLaunchAttributeValue*', 'value')] +#define INIT_hipGraphKernelNodeGetAttribute_CB_ARGS_DATA(cb_data) {}; +// hipGraphKernelNodeGetParams[('hipGraphNode_t', 'node'), +// ('hipKernelNodeParams*', 'pNodeParams')] +#define INIT_hipGraphKernelNodeGetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphKernelNodeGetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphKernelNodeGetParams.pNodeParams = \ + (hipKernelNodeParams *)pNodeParams; \ + }; +// hipGraphKernelNodeSetAttribute[('hipGraphNode_t', 'hNode'), +// ('hipLaunchAttributeID', 'attr'), ('const hipLaunchAttributeValue*', +// 'value')] +#define INIT_hipGraphKernelNodeSetAttribute_CB_ARGS_DATA(cb_data) {}; +// hipGraphKernelNodeSetParams[('hipGraphNode_t', 'node'), ('const +// hipKernelNodeParams*', 'pNodeParams')] +#define INIT_hipGraphKernelNodeSetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphKernelNodeSetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphKernelNodeSetParams.pNodeParams = \ + (const hipKernelNodeParams *)pNodeParams; \ + }; +// hipGraphLaunch[('hipGraphExec_t', 'graphExec'), ('hipStream_t', 'stream')] +#define INIT_hipGraphLaunch_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphLaunch.graphExec = (hipGraphExec_t)graphExec; \ + cb_data.args.hipGraphLaunch.stream = (hipStream_t)stream; \ + }; +// hipGraphMemAllocNodeGetParams[('hipGraphNode_t', 'node'), +// ('hipMemAllocNodeParams*', 'pNodeParams')] +#define INIT_hipGraphMemAllocNodeGetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphMemAllocNodeGetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphMemAllocNodeGetParams.pNodeParams = \ + (hipMemAllocNodeParams *)pNodeParams; \ + }; +// hipGraphMemFreeNodeGetParams[('hipGraphNode_t', 'node'), ('void*', +// 'dev_ptr')] +#define INIT_hipGraphMemFreeNodeGetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphMemFreeNodeGetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphMemFreeNodeGetParams.dev_ptr = (void *)dev_ptr; \ + }; +// hipGraphMemcpyNodeGetParams[('hipGraphNode_t', 'node'), ('hipMemcpy3DParms*', +// 'pNodeParams')] +#define INIT_hipGraphMemcpyNodeGetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphMemcpyNodeGetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphMemcpyNodeGetParams.pNodeParams = \ + (hipMemcpy3DParms *)pNodeParams; \ + }; +// hipGraphMemcpyNodeSetParams[('hipGraphNode_t', 'node'), ('const +// hipMemcpy3DParms*', 'pNodeParams')] +#define INIT_hipGraphMemcpyNodeSetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphMemcpyNodeSetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphMemcpyNodeSetParams.pNodeParams = \ + (const hipMemcpy3DParms *)pNodeParams; \ + }; +// hipGraphMemcpyNodeSetParams1D[('hipGraphNode_t', 'node'), ('void*', 'dst'), +// ('const void*', 'src'), ('size_t', 'count'), ('hipMemcpyKind', 'kind')] +#define INIT_hipGraphMemcpyNodeSetParams1D_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphMemcpyNodeSetParams1D.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphMemcpyNodeSetParams1D.dst = (void *)dst; \ + cb_data.args.hipGraphMemcpyNodeSetParams1D.src = (const void *)src; \ + cb_data.args.hipGraphMemcpyNodeSetParams1D.count = (size_t)count; \ + cb_data.args.hipGraphMemcpyNodeSetParams1D.kind = (hipMemcpyKind)kind; \ + }; +// hipGraphMemcpyNodeSetParamsFromSymbol[('hipGraphNode_t', 'node'), ('void*', +// 'dst'), ('const void*', 'symbol'), ('size_t', 'count'), ('size_t', 'offset'), +// ('hipMemcpyKind', 'kind')] +#define INIT_hipGraphMemcpyNodeSetParamsFromSymbol_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphMemcpyNodeSetParamsFromSymbol.node = \ + (hipGraphNode_t)node; \ + cb_data.args.hipGraphMemcpyNodeSetParamsFromSymbol.dst = (void *)dst; \ + cb_data.args.hipGraphMemcpyNodeSetParamsFromSymbol.symbol = \ + (const void *)symbol; \ + cb_data.args.hipGraphMemcpyNodeSetParamsFromSymbol.count = (size_t)count; \ + cb_data.args.hipGraphMemcpyNodeSetParamsFromSymbol.offset = \ + (size_t)offset; \ + cb_data.args.hipGraphMemcpyNodeSetParamsFromSymbol.kind = \ + (hipMemcpyKind)kind; \ + }; +// hipGraphMemcpyNodeSetParamsToSymbol[('hipGraphNode_t', 'node'), ('const +// void*', 'symbol'), ('const void*', 'src'), ('size_t', 'count'), ('size_t', +// 'offset'), ('hipMemcpyKind', 'kind')] +#define INIT_hipGraphMemcpyNodeSetParamsToSymbol_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphMemcpyNodeSetParamsToSymbol.node = \ + (hipGraphNode_t)node; \ + cb_data.args.hipGraphMemcpyNodeSetParamsToSymbol.symbol = \ + (const void *)symbol; \ + cb_data.args.hipGraphMemcpyNodeSetParamsToSymbol.src = (const void *)src; \ + cb_data.args.hipGraphMemcpyNodeSetParamsToSymbol.count = (size_t)count; \ + cb_data.args.hipGraphMemcpyNodeSetParamsToSymbol.offset = (size_t)offset; \ + cb_data.args.hipGraphMemcpyNodeSetParamsToSymbol.kind = \ + (hipMemcpyKind)kind; \ + }; +// hipGraphMemsetNodeGetParams[('hipGraphNode_t', 'node'), ('hipMemsetParams*', +// 'pNodeParams')] +#define INIT_hipGraphMemsetNodeGetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphMemsetNodeGetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphMemsetNodeGetParams.pNodeParams = \ + (hipMemsetParams *)pNodeParams; \ + }; +// hipGraphMemsetNodeSetParams[('hipGraphNode_t', 'node'), ('const +// hipMemsetParams*', 'pNodeParams')] +#define INIT_hipGraphMemsetNodeSetParams_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphMemsetNodeSetParams.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphMemsetNodeSetParams.pNodeParams = \ + (const hipMemsetParams *)pNodeParams; \ + }; +// hipGraphNodeFindInClone[('hipGraphNode_t*', 'pNode'), ('hipGraphNode_t', +// 'originalNode'), ('hipGraph_t', 'clonedGraph')] +#define INIT_hipGraphNodeFindInClone_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphNodeFindInClone.pNode = (hipGraphNode_t *)pNode; \ + cb_data.args.hipGraphNodeFindInClone.originalNode = \ + (hipGraphNode_t)originalNode; \ + cb_data.args.hipGraphNodeFindInClone.clonedGraph = \ + (hipGraph_t)clonedGraph; \ + }; +// hipGraphNodeGetDependencies[('hipGraphNode_t', 'node'), ('hipGraphNode_t*', +// 'pDependencies'), ('size_t*', 'pNumDependencies')] +#define INIT_hipGraphNodeGetDependencies_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphNodeGetDependencies.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphNodeGetDependencies.pDependencies = \ + (hipGraphNode_t *)pDependencies; \ + cb_data.args.hipGraphNodeGetDependencies.pNumDependencies = \ + (size_t *)pNumDependencies; \ + }; +// hipGraphNodeGetDependentNodes[('hipGraphNode_t', 'node'), ('hipGraphNode_t*', +// 'pDependentNodes'), ('size_t*', 'pNumDependentNodes')] +#define INIT_hipGraphNodeGetDependentNodes_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphNodeGetDependentNodes.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphNodeGetDependentNodes.pDependentNodes = \ + (hipGraphNode_t *)pDependentNodes; \ + cb_data.args.hipGraphNodeGetDependentNodes.pNumDependentNodes = \ + (size_t *)pNumDependentNodes; \ + }; +// hipGraphNodeGetEnabled[('hipGraphExec_t', 'hGraphExec'), ('hipGraphNode_t', +// 'hNode'), ('unsigned int*', 'isEnabled')] +#define INIT_hipGraphNodeGetEnabled_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphNodeGetEnabled.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphNodeGetEnabled.hNode = (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphNodeGetEnabled.isEnabled = (unsigned int *)isEnabled; \ + }; +// hipGraphNodeGetType[('hipGraphNode_t', 'node'), ('hipGraphNodeType*', +// 'pType')] +#define INIT_hipGraphNodeGetType_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphNodeGetType.node = (hipGraphNode_t)node; \ + cb_data.args.hipGraphNodeGetType.pType = (hipGraphNodeType *)pType; \ + }; +// hipGraphNodeSetEnabled[('hipGraphExec_t', 'hGraphExec'), ('hipGraphNode_t', +// 'hNode'), ('unsigned int', 'isEnabled')] +#define INIT_hipGraphNodeSetEnabled_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphNodeSetEnabled.hGraphExec = \ + (hipGraphExec_t)hGraphExec; \ + cb_data.args.hipGraphNodeSetEnabled.hNode = (hipGraphNode_t)hNode; \ + cb_data.args.hipGraphNodeSetEnabled.isEnabled = (unsigned int)isEnabled; \ + }; +// hipGraphReleaseUserObject[('hipGraph_t', 'graph'), ('hipUserObject_t', +// 'object'), ('unsigned int', 'count')] +#define INIT_hipGraphReleaseUserObject_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphReleaseUserObject.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphReleaseUserObject.object = (hipUserObject_t)object; \ + cb_data.args.hipGraphReleaseUserObject.count = (unsigned int)count; \ + }; +// hipGraphRemoveDependencies[('hipGraph_t', 'graph'), ('const hipGraphNode_t*', +// 'from'), ('const hipGraphNode_t*', 'to'), ('size_t', 'numDependencies')] +#define INIT_hipGraphRemoveDependencies_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphRemoveDependencies.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphRemoveDependencies.from = \ + (const hipGraphNode_t *)from; \ + cb_data.args.hipGraphRemoveDependencies.to = (const hipGraphNode_t *)to; \ + cb_data.args.hipGraphRemoveDependencies.numDependencies = \ + (size_t)numDependencies; \ + }; +// hipGraphRetainUserObject[('hipGraph_t', 'graph'), ('hipUserObject_t', +// 'object'), ('unsigned int', 'count'), ('unsigned int', 'flags')] +#define INIT_hipGraphRetainUserObject_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphRetainUserObject.graph = (hipGraph_t)graph; \ + cb_data.args.hipGraphRetainUserObject.object = (hipUserObject_t)object; \ + cb_data.args.hipGraphRetainUserObject.count = (unsigned int)count; \ + cb_data.args.hipGraphRetainUserObject.flags = (unsigned int)flags; \ + }; +// hipGraphUpload[('hipGraphExec_t', 'graphExec'), ('hipStream_t', 'stream')] +#define INIT_hipGraphUpload_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphUpload.graphExec = (hipGraphExec_t)graphExec; \ + cb_data.args.hipGraphUpload.stream = (hipStream_t)stream; \ + }; +// hipGraphicsGLRegisterBuffer[('hipGraphicsResource**', 'resource'), ('GLuint', +// 'buffer'), ('unsigned int', 'flags')] +#define INIT_hipGraphicsGLRegisterBuffer_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphicsGLRegisterBuffer.resource = \ + (hipGraphicsResource **)resource; \ + cb_data.args.hipGraphicsGLRegisterBuffer.buffer = (GLuint)buffer; \ + cb_data.args.hipGraphicsGLRegisterBuffer.flags = (unsigned int)flags; \ + }; +// hipGraphicsGLRegisterImage[('hipGraphicsResource**', 'resource'), ('GLuint', +// 'image'), ('GLenum', 'target'), ('unsigned int', 'flags')] +#define INIT_hipGraphicsGLRegisterImage_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphicsGLRegisterImage.resource = \ + (hipGraphicsResource **)resource; \ + cb_data.args.hipGraphicsGLRegisterImage.image = (GLuint)image; \ + cb_data.args.hipGraphicsGLRegisterImage.target = (GLenum)target; \ + cb_data.args.hipGraphicsGLRegisterImage.flags = (unsigned int)flags; \ + }; +// hipGraphicsMapResources[('int', 'count'), ('hipGraphicsResource_t*', +// 'resources'), ('hipStream_t', 'stream')] +#define INIT_hipGraphicsMapResources_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphicsMapResources.count = (int)count; \ + cb_data.args.hipGraphicsMapResources.resources = \ + (hipGraphicsResource_t *)resources; \ + cb_data.args.hipGraphicsMapResources.stream = (hipStream_t)stream; \ + }; +// hipGraphicsResourceGetMappedPointer[('void**', 'devPtr'), ('size_t*', +// 'size'), ('hipGraphicsResource_t', 'resource')] +#define INIT_hipGraphicsResourceGetMappedPointer_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphicsResourceGetMappedPointer.devPtr = (void **)devPtr; \ + cb_data.args.hipGraphicsResourceGetMappedPointer.size = (size_t *)size; \ + cb_data.args.hipGraphicsResourceGetMappedPointer.resource = \ + (hipGraphicsResource_t)resource; \ + }; +// hipGraphicsSubResourceGetMappedArray[('hipArray_t*', 'array'), +// ('hipGraphicsResource_t', 'resource'), ('unsigned int', 'arrayIndex'), +// ('unsigned int', 'mipLevel')] +#define INIT_hipGraphicsSubResourceGetMappedArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphicsSubResourceGetMappedArray.array = \ + (hipArray_t *)array; \ + cb_data.args.hipGraphicsSubResourceGetMappedArray.resource = \ + (hipGraphicsResource_t)resource; \ + cb_data.args.hipGraphicsSubResourceGetMappedArray.arrayIndex = \ + (unsigned int)arrayIndex; \ + cb_data.args.hipGraphicsSubResourceGetMappedArray.mipLevel = \ + (unsigned int)mipLevel; \ + }; +// hipGraphicsUnmapResources[('int', 'count'), ('hipGraphicsResource_t*', +// 'resources'), ('hipStream_t', 'stream')] +#define INIT_hipGraphicsUnmapResources_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphicsUnmapResources.count = (int)count; \ + cb_data.args.hipGraphicsUnmapResources.resources = \ + (hipGraphicsResource_t *)resources; \ + cb_data.args.hipGraphicsUnmapResources.stream = (hipStream_t)stream; \ + }; +// hipGraphicsUnregisterResource[('hipGraphicsResource_t', 'resource')] +#define INIT_hipGraphicsUnregisterResource_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipGraphicsUnregisterResource.resource = \ + (hipGraphicsResource_t)resource; \ + }; +// hipHccModuleLaunchKernel[('hipFunction_t', 'f'), ('unsigned int', +// 'globalWorkSizeX'), ('unsigned int', 'globalWorkSizeY'), ('unsigned int', +// 'globalWorkSizeZ'), ('unsigned int', 'blockDimX'), ('unsigned int', +// 'blockDimY'), ('unsigned int', 'blockDimZ'), ('size_t', 'sharedMemBytes'), +// ('hipStream_t', 'hStream'), ('void**', 'kernelParams'), ('void**', 'extra'), +// ('hipEvent_t', 'startEvent'), ('hipEvent_t', 'stopEvent')] +#define INIT_hipHccModuleLaunchKernel_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipHccModuleLaunchKernel.f = (hipFunction_t)f; \ + cb_data.args.hipHccModuleLaunchKernel.globalWorkSizeX = \ + (unsigned int)globalWorkSizeX; \ + cb_data.args.hipHccModuleLaunchKernel.globalWorkSizeY = \ + (unsigned int)globalWorkSizeY; \ + cb_data.args.hipHccModuleLaunchKernel.globalWorkSizeZ = \ + (unsigned int)globalWorkSizeZ; \ + cb_data.args.hipHccModuleLaunchKernel.blockDimX = (unsigned int)blockDimX; \ + cb_data.args.hipHccModuleLaunchKernel.blockDimY = (unsigned int)blockDimY; \ + cb_data.args.hipHccModuleLaunchKernel.blockDimZ = (unsigned int)blockDimZ; \ + cb_data.args.hipHccModuleLaunchKernel.sharedMemBytes = \ + (size_t)sharedMemBytes; \ + cb_data.args.hipHccModuleLaunchKernel.hStream = (hipStream_t)hStream; \ + cb_data.args.hipHccModuleLaunchKernel.kernelParams = \ + (void **)kernelParams; \ + cb_data.args.hipHccModuleLaunchKernel.extra = (void **)extra; \ + cb_data.args.hipHccModuleLaunchKernel.startEvent = (hipEvent_t)startEvent; \ + cb_data.args.hipHccModuleLaunchKernel.stopEvent = (hipEvent_t)stopEvent; \ + }; +// hipHostAlloc[('void**', 'ptr'), ('size_t', 'size'), ('unsigned int', +// 'flags')] +#define INIT_hipHostAlloc_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipHostAlloc.ptr = (void **)ptr; \ + cb_data.args.hipHostAlloc.size = (size_t)sizeBytes; \ + cb_data.args.hipHostAlloc.flags = (unsigned int)flags; \ + }; +// hipHostFree[('void*', 'ptr')] +#define INIT_hipHostFree_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipHostFree.ptr = (void *)ptr; \ + }; +// hipHostGetDevicePointer[('void**', 'devPtr'), ('void*', 'hstPtr'), ('unsigned +// int', 'flags')] +#define INIT_hipHostGetDevicePointer_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipHostGetDevicePointer.devPtr = (void **)devicePointer; \ + cb_data.args.hipHostGetDevicePointer.hstPtr = (void *)hostPointer; \ + cb_data.args.hipHostGetDevicePointer.flags = (unsigned int)flags; \ + }; +// hipHostGetFlags[('unsigned int*', 'flagsPtr'), ('void*', 'hostPtr')] +#define INIT_hipHostGetFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipHostGetFlags.flagsPtr = (unsigned int *)flagsPtr; \ + cb_data.args.hipHostGetFlags.hostPtr = (void *)hostPtr; \ + }; +// hipHostMalloc[('void**', 'ptr'), ('size_t', 'size'), ('unsigned int', +// 'flags')] +#define INIT_hipHostMalloc_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipHostMalloc.ptr = (void **)ptr; \ + cb_data.args.hipHostMalloc.size = (size_t)sizeBytes; \ + cb_data.args.hipHostMalloc.flags = (unsigned int)flags; \ + }; +// hipHostRegister[('void*', 'hostPtr'), ('size_t', 'sizeBytes'), ('unsigned +// int', 'flags')] +#define INIT_hipHostRegister_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipHostRegister.hostPtr = (void *)hostPtr; \ + cb_data.args.hipHostRegister.sizeBytes = (size_t)sizeBytes; \ + cb_data.args.hipHostRegister.flags = (unsigned int)flags; \ + }; +// hipHostUnregister[('void*', 'hostPtr')] +#define INIT_hipHostUnregister_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipHostUnregister.hostPtr = (void *)hostPtr; \ + }; +// hipImportExternalMemory[('hipExternalMemory_t*', 'extMem_out'), ('const +// hipExternalMemoryHandleDesc*', 'memHandleDesc')] +#define INIT_hipImportExternalMemory_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipImportExternalMemory.extMem_out = \ + (hipExternalMemory_t *)extMem_out; \ + cb_data.args.hipImportExternalMemory.memHandleDesc = \ + (const hipExternalMemoryHandleDesc *)memHandleDesc; \ + }; +// hipImportExternalSemaphore[('hipExternalSemaphore_t*', 'extSem_out'), ('const +// hipExternalSemaphoreHandleDesc*', 'semHandleDesc')] +#define INIT_hipImportExternalSemaphore_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipImportExternalSemaphore.extSem_out = \ + (hipExternalSemaphore_t *)extSem_out; \ + cb_data.args.hipImportExternalSemaphore.semHandleDesc = \ + (const hipExternalSemaphoreHandleDesc *)semHandleDesc; \ + }; +// hipInit[('unsigned int', 'flags')] +#define INIT_hipInit_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipInit.flags = (unsigned int)flags; \ + }; +// hipIpcCloseMemHandle[('void*', 'devPtr')] +#define INIT_hipIpcCloseMemHandle_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipIpcCloseMemHandle.devPtr = (void *)dev_ptr; \ + }; +// hipIpcGetEventHandle[('hipIpcEventHandle_t*', 'handle'), ('hipEvent_t', +// 'event')] +#define INIT_hipIpcGetEventHandle_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipIpcGetEventHandle.handle = (hipIpcEventHandle_t *)handle; \ + cb_data.args.hipIpcGetEventHandle.event = (hipEvent_t)event; \ + }; +// hipIpcGetMemHandle[('hipIpcMemHandle_t*', 'handle'), ('void*', 'devPtr')] +#define INIT_hipIpcGetMemHandle_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipIpcGetMemHandle.handle = (hipIpcMemHandle_t *)handle; \ + cb_data.args.hipIpcGetMemHandle.devPtr = (void *)dev_ptr; \ + }; +// hipIpcOpenEventHandle[('hipEvent_t*', 'event'), ('hipIpcEventHandle_t', +// 'handle')] +#define INIT_hipIpcOpenEventHandle_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipIpcOpenEventHandle.event = (hipEvent_t *)event; \ + cb_data.args.hipIpcOpenEventHandle.handle = (hipIpcEventHandle_t)handle; \ + }; +// hipIpcOpenMemHandle[('void**', 'devPtr'), ('hipIpcMemHandle_t', 'handle'), +// ('unsigned int', 'flags')] +#define INIT_hipIpcOpenMemHandle_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipIpcOpenMemHandle.devPtr = (void **)dev_ptr; \ + cb_data.args.hipIpcOpenMemHandle.handle = (hipIpcMemHandle_t)handle; \ + cb_data.args.hipIpcOpenMemHandle.flags = (unsigned int)flags; \ + }; +// hipLaunchByPtr[('const void*', 'hostFunction')] +#define INIT_hipLaunchByPtr_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipLaunchByPtr.hostFunction = (const void *)hostFunction; \ + }; +// hipLaunchCooperativeKernel[('const void*', 'f'), ('dim3', 'gridDim'), +// ('dim3', 'blockDimX'), ('void**', 'kernelParams'), ('unsigned int', +// 'sharedMemBytes'), ('hipStream_t', 'stream')] +#define INIT_hipLaunchCooperativeKernel_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipLaunchCooperativeKernel.f = (const void *)f; \ + cb_data.args.hipLaunchCooperativeKernel.gridDim = (dim3)gridDim; \ + cb_data.args.hipLaunchCooperativeKernel.blockDimX = (dim3)blockDim; \ + cb_data.args.hipLaunchCooperativeKernel.kernelParams = \ + (void **)kernelParams; \ + cb_data.args.hipLaunchCooperativeKernel.sharedMemBytes = \ + (unsigned int)sharedMemBytes; \ + cb_data.args.hipLaunchCooperativeKernel.stream = (hipStream_t)hStream; \ + }; +// hipLaunchCooperativeKernelMultiDevice[('hipLaunchParams*', +// 'launchParamsList'), ('int', 'numDevices'), ('unsigned int', 'flags')] +#define INIT_hipLaunchCooperativeKernelMultiDevice_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipLaunchCooperativeKernelMultiDevice.launchParamsList = \ + (hipLaunchParams *)launchParamsList; \ + cb_data.args.hipLaunchCooperativeKernelMultiDevice.numDevices = \ + (int)numDevices; \ + cb_data.args.hipLaunchCooperativeKernelMultiDevice.flags = \ + (unsigned int)flags; \ + }; +// hipLaunchHostFunc[('hipStream_t', 'stream'), ('hipHostFn_t', 'fn'), ('void*', +// 'userData')] +#define INIT_hipLaunchHostFunc_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipLaunchHostFunc.stream = (hipStream_t)stream; \ + cb_data.args.hipLaunchHostFunc.fn = (hipHostFn_t)fn; \ + cb_data.args.hipLaunchHostFunc.userData = (void *)userData; \ + }; +// hipLaunchKernel[('const void*', 'function_address'), ('dim3', 'numBlocks'), +// ('dim3', 'dimBlocks'), ('void**', 'args'), ('size_t', 'sharedMemBytes'), +// ('hipStream_t', 'stream')] +#define INIT_hipLaunchKernel_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipLaunchKernel.function_address = \ + (const void *)hostFunction; \ + cb_data.args.hipLaunchKernel.numBlocks = (dim3)gridDim; \ + cb_data.args.hipLaunchKernel.dimBlocks = (dim3)blockDim; \ + cb_data.args.hipLaunchKernel.args = (void **)args; \ + cb_data.args.hipLaunchKernel.sharedMemBytes = (size_t)sharedMemBytes; \ + cb_data.args.hipLaunchKernel.stream = (hipStream_t)stream; \ + }; +// hipMalloc[('void**', 'ptr'), ('size_t', 'size')] +#define INIT_hipMalloc_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMalloc.ptr = (void **)ptr; \ + cb_data.args.hipMalloc.size = (size_t)sizeBytes; \ + }; +// hipMalloc3D[('hipPitchedPtr*', 'pitchedDevPtr'), ('hipExtent', 'extent')] +#define INIT_hipMalloc3D_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMalloc3D.pitchedDevPtr = (hipPitchedPtr *)pitchedDevPtr; \ + cb_data.args.hipMalloc3D.extent = (hipExtent)extent; \ + }; +// hipMalloc3DArray[('hipArray_t*', 'array'), ('const hipChannelFormatDesc*', +// 'desc'), ('hipExtent', 'extent'), ('unsigned int', 'flags')] +#define INIT_hipMalloc3DArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMalloc3DArray.array = (hipArray_t *)array; \ + cb_data.args.hipMalloc3DArray.desc = (const hipChannelFormatDesc *)desc; \ + cb_data.args.hipMalloc3DArray.extent = (hipExtent)extent; \ + cb_data.args.hipMalloc3DArray.flags = (unsigned int)flags; \ + }; +// hipMallocArray[('hipArray_t*', 'array'), ('const hipChannelFormatDesc*', +// 'desc'), ('size_t', 'width'), ('size_t', 'height'), ('unsigned int', +// 'flags')] +#define INIT_hipMallocArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMallocArray.array = (hipArray_t *)array; \ + cb_data.args.hipMallocArray.desc = (const hipChannelFormatDesc *)desc; \ + cb_data.args.hipMallocArray.width = (size_t)width; \ + cb_data.args.hipMallocArray.height = (size_t)height; \ + cb_data.args.hipMallocArray.flags = (unsigned int)flags; \ + }; +// hipMallocAsync[('void**', 'dev_ptr'), ('size_t', 'size'), ('hipStream_t', +// 'stream')] +#define INIT_hipMallocAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMallocAsync.dev_ptr = (void **)dev_ptr; \ + cb_data.args.hipMallocAsync.size = (size_t)size; \ + cb_data.args.hipMallocAsync.stream = (hipStream_t)stream; \ + }; +// hipMallocFromPoolAsync[('void**', 'dev_ptr'), ('size_t', 'size'), +// ('hipMemPool_t', 'mem_pool'), ('hipStream_t', 'stream')] +#define INIT_hipMallocFromPoolAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMallocFromPoolAsync.dev_ptr = (void **)dev_ptr; \ + cb_data.args.hipMallocFromPoolAsync.size = (size_t)size; \ + cb_data.args.hipMallocFromPoolAsync.mem_pool = (hipMemPool_t)mem_pool; \ + cb_data.args.hipMallocFromPoolAsync.stream = (hipStream_t)stream; \ + }; +// hipMallocHost[('void**', 'ptr'), ('size_t', 'size')] +#define INIT_hipMallocHost_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMallocHost.ptr = (void **)ptr; \ + cb_data.args.hipMallocHost.size = (size_t)size; \ + }; +// hipMallocManaged[('void**', 'dev_ptr'), ('size_t', 'size'), ('unsigned int', +// 'flags')] +#define INIT_hipMallocManaged_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMallocManaged.dev_ptr = (void **)dev_ptr; \ + cb_data.args.hipMallocManaged.size = (size_t)size; \ + cb_data.args.hipMallocManaged.flags = (unsigned int)flags; \ + }; +// hipMallocMipmappedArray[('hipMipmappedArray_t*', 'mipmappedArray'), ('const +// hipChannelFormatDesc*', 'desc'), ('hipExtent', 'extent'), ('unsigned int', +// 'numLevels'), ('unsigned int', 'flags')] +#define INIT_hipMallocMipmappedArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMallocMipmappedArray.mipmappedArray = \ + (hipMipmappedArray_t *)mipmappedArray; \ + cb_data.args.hipMallocMipmappedArray.desc = \ + (const hipChannelFormatDesc *)desc; \ + cb_data.args.hipMallocMipmappedArray.extent = (hipExtent)extent; \ + cb_data.args.hipMallocMipmappedArray.numLevels = (unsigned int)numLevels; \ + cb_data.args.hipMallocMipmappedArray.flags = (unsigned int)flags; \ + }; +// hipMallocPitch[('void**', 'ptr'), ('size_t*', 'pitch'), ('size_t', 'width'), +// ('size_t', 'height')] +#define INIT_hipMallocPitch_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMallocPitch.ptr = (void **)ptr; \ + cb_data.args.hipMallocPitch.pitch = (size_t *)pitch; \ + cb_data.args.hipMallocPitch.width = (size_t)width; \ + cb_data.args.hipMallocPitch.height = (size_t)height; \ + }; +// hipMemAddressFree[('void*', 'devPtr'), ('size_t', 'size')] +#define INIT_hipMemAddressFree_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemAddressFree.devPtr = (void *)devPtr; \ + cb_data.args.hipMemAddressFree.size = (size_t)size; \ + }; +// hipMemAddressReserve[('void**', 'ptr'), ('size_t', 'size'), ('size_t', +// 'alignment'), ('void*', 'addr'), ('unsigned long long', 'flags')] +#define INIT_hipMemAddressReserve_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemAddressReserve.ptr = (void **)ptr; \ + cb_data.args.hipMemAddressReserve.size = (size_t)size; \ + cb_data.args.hipMemAddressReserve.alignment = (size_t)alignment; \ + cb_data.args.hipMemAddressReserve.addr = (void *)addr; \ + cb_data.args.hipMemAddressReserve.flags = (unsigned long long)flags; \ + }; +// hipMemAdvise[('const void*', 'dev_ptr'), ('size_t', 'count'), +// ('hipMemoryAdvise', 'advice'), ('int', 'device')] +#define INIT_hipMemAdvise_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemAdvise.dev_ptr = (const void *)dev_ptr; \ + cb_data.args.hipMemAdvise.count = (size_t)count; \ + cb_data.args.hipMemAdvise.advice = (hipMemoryAdvise)advice; \ + cb_data.args.hipMemAdvise.device = (int)device; \ + }; +// hipMemAllocHost[('void**', 'ptr'), ('size_t', 'size')] +#define INIT_hipMemAllocHost_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemAllocHost.ptr = (void **)ptr; \ + cb_data.args.hipMemAllocHost.size = (size_t)size; \ + }; +// hipMemAllocPitch[('hipDeviceptr_t*', 'dptr'), ('size_t*', 'pitch'), +// ('size_t', 'widthInBytes'), ('size_t', 'height'), ('unsigned int', +// 'elementSizeBytes')] +#define INIT_hipMemAllocPitch_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemAllocPitch.dptr = (hipDeviceptr_t *)dptr; \ + cb_data.args.hipMemAllocPitch.pitch = (size_t *)pitch; \ + cb_data.args.hipMemAllocPitch.widthInBytes = (size_t)widthInBytes; \ + cb_data.args.hipMemAllocPitch.height = (size_t)height; \ + cb_data.args.hipMemAllocPitch.elementSizeBytes = \ + (unsigned int)elementSizeBytes; \ + }; +// hipMemCreate[('hipMemGenericAllocationHandle_t*', 'handle'), ('size_t', +// 'size'), ('const hipMemAllocationProp*', 'prop'), ('unsigned long long', +// 'flags')] +#define INIT_hipMemCreate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemCreate.handle = \ + (hipMemGenericAllocationHandle_t *)handle; \ + cb_data.args.hipMemCreate.size = (size_t)size; \ + cb_data.args.hipMemCreate.prop = (const hipMemAllocationProp *)prop; \ + cb_data.args.hipMemCreate.flags = (unsigned long long)flags; \ + }; +// hipMemExportToShareableHandle[('void*', 'shareableHandle'), +// ('hipMemGenericAllocationHandle_t', 'handle'), ('hipMemAllocationHandleType', +// 'handleType'), ('unsigned long long', 'flags')] +#define INIT_hipMemExportToShareableHandle_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemExportToShareableHandle.shareableHandle = \ + (void *)shareableHandle; \ + cb_data.args.hipMemExportToShareableHandle.handle = \ + (hipMemGenericAllocationHandle_t)handle; \ + cb_data.args.hipMemExportToShareableHandle.handleType = \ + (hipMemAllocationHandleType)handleType; \ + cb_data.args.hipMemExportToShareableHandle.flags = \ + (unsigned long long)flags; \ + }; +// hipMemGetAccess[('unsigned long long*', 'flags'), ('const hipMemLocation*', +// 'location'), ('void*', 'ptr')] +#define INIT_hipMemGetAccess_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemGetAccess.flags = (unsigned long long *)flags; \ + cb_data.args.hipMemGetAccess.location = (const hipMemLocation *)location; \ + cb_data.args.hipMemGetAccess.ptr = (void *)ptr; \ + }; +// hipMemGetAddressRange[('hipDeviceptr_t*', 'pbase'), ('size_t*', 'psize'), +// ('hipDeviceptr_t', 'dptr')] +#define INIT_hipMemGetAddressRange_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemGetAddressRange.pbase = (hipDeviceptr_t *)pbase; \ + cb_data.args.hipMemGetAddressRange.psize = (size_t *)psize; \ + cb_data.args.hipMemGetAddressRange.dptr = (hipDeviceptr_t)dptr; \ + }; +// hipMemGetAllocationGranularity[('size_t*', 'granularity'), ('const +// hipMemAllocationProp*', 'prop'), ('hipMemAllocationGranularity_flags', +// 'option')] +#define INIT_hipMemGetAllocationGranularity_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemGetAllocationGranularity.granularity = \ + (size_t *)granularity; \ + cb_data.args.hipMemGetAllocationGranularity.prop = \ + (const hipMemAllocationProp *)prop; \ + cb_data.args.hipMemGetAllocationGranularity.option = \ + (hipMemAllocationGranularity_flags)option; \ + }; +// hipMemGetAllocationPropertiesFromHandle[('hipMemAllocationProp*', 'prop'), +// ('hipMemGenericAllocationHandle_t', 'handle')] +#define INIT_hipMemGetAllocationPropertiesFromHandle_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemGetAllocationPropertiesFromHandle.prop = \ + (hipMemAllocationProp *)prop; \ + cb_data.args.hipMemGetAllocationPropertiesFromHandle.handle = \ + (hipMemGenericAllocationHandle_t)handle; \ + }; +// hipMemGetInfo[('size_t*', 'free'), ('size_t*', 'total')] +#define INIT_hipMemGetInfo_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemGetInfo.free = (size_t *)free; \ + cb_data.args.hipMemGetInfo.total = (size_t *)total; \ + }; +// hipMemImportFromShareableHandle[('hipMemGenericAllocationHandle_t*', +// 'handle'), ('void*', 'osHandle'), ('hipMemAllocationHandleType', +// 'shHandleType')] +#define INIT_hipMemImportFromShareableHandle_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemImportFromShareableHandle.handle = \ + (hipMemGenericAllocationHandle_t *)handle; \ + cb_data.args.hipMemImportFromShareableHandle.osHandle = (void *)osHandle; \ + cb_data.args.hipMemImportFromShareableHandle.shHandleType = \ + (hipMemAllocationHandleType)shHandleType; \ + }; +// hipMemMap[('void*', 'ptr'), ('size_t', 'size'), ('size_t', 'offset'), +// ('hipMemGenericAllocationHandle_t', 'handle'), ('unsigned long long', +// 'flags')] +#define INIT_hipMemMap_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemMap.ptr = (void *)ptr; \ + cb_data.args.hipMemMap.size = (size_t)size; \ + cb_data.args.hipMemMap.offset = (size_t)offset; \ + cb_data.args.hipMemMap.handle = (hipMemGenericAllocationHandle_t)handle; \ + cb_data.args.hipMemMap.flags = (unsigned long long)flags; \ + }; +// hipMemMapArrayAsync[('hipArrayMapInfo*', 'mapInfoList'), ('unsigned int', +// 'count'), ('hipStream_t', 'stream')] +#define INIT_hipMemMapArrayAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemMapArrayAsync.mapInfoList = \ + (hipArrayMapInfo *)mapInfoList; \ + cb_data.args.hipMemMapArrayAsync.count = (unsigned int)count; \ + cb_data.args.hipMemMapArrayAsync.stream = (hipStream_t)stream; \ + }; +// hipMemPoolCreate[('hipMemPool_t*', 'mem_pool'), ('const hipMemPoolProps*', +// 'pool_props')] +#define INIT_hipMemPoolCreate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPoolCreate.mem_pool = (hipMemPool_t *)mem_pool; \ + cb_data.args.hipMemPoolCreate.pool_props = \ + (const hipMemPoolProps *)pool_props; \ + }; +// hipMemPoolDestroy[('hipMemPool_t', 'mem_pool')] +#define INIT_hipMemPoolDestroy_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPoolDestroy.mem_pool = (hipMemPool_t)mem_pool; \ + }; +// hipMemPoolExportPointer[('hipMemPoolPtrExportData*', 'export_data'), +// ('void*', 'dev_ptr')] +#define INIT_hipMemPoolExportPointer_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPoolExportPointer.export_data = \ + (hipMemPoolPtrExportData *)export_data; \ + cb_data.args.hipMemPoolExportPointer.dev_ptr = (void *)ptr; \ + }; +// hipMemPoolExportToShareableHandle[('void*', 'shared_handle'), +// ('hipMemPool_t', 'mem_pool'), ('hipMemAllocationHandleType', 'handle_type'), +// ('unsigned int', 'flags')] +#define INIT_hipMemPoolExportToShareableHandle_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPoolExportToShareableHandle.shared_handle = \ + (void *)shared_handle; \ + cb_data.args.hipMemPoolExportToShareableHandle.mem_pool = \ + (hipMemPool_t)mem_pool; \ + cb_data.args.hipMemPoolExportToShareableHandle.handle_type = \ + (hipMemAllocationHandleType)handle_type; \ + cb_data.args.hipMemPoolExportToShareableHandle.flags = \ + (unsigned int)flags; \ + }; +// hipMemPoolGetAccess[('hipMemAccessFlags*', 'flags'), ('hipMemPool_t', +// 'mem_pool'), ('hipMemLocation*', 'location')] +#define INIT_hipMemPoolGetAccess_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPoolGetAccess.flags = (hipMemAccessFlags *)flags; \ + cb_data.args.hipMemPoolGetAccess.mem_pool = (hipMemPool_t)mem_pool; \ + cb_data.args.hipMemPoolGetAccess.location = (hipMemLocation *)location; \ + }; +// hipMemPoolGetAttribute[('hipMemPool_t', 'mem_pool'), ('hipMemPoolAttr', +// 'attr'), ('void*', 'value')] +#define INIT_hipMemPoolGetAttribute_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPoolGetAttribute.mem_pool = (hipMemPool_t)mem_pool; \ + cb_data.args.hipMemPoolGetAttribute.attr = (hipMemPoolAttr)attr; \ + cb_data.args.hipMemPoolGetAttribute.value = (void *)value; \ + }; +// hipMemPoolImportFromShareableHandle[('hipMemPool_t*', 'mem_pool'), ('void*', +// 'shared_handle'), ('hipMemAllocationHandleType', 'handle_type'), ('unsigned +// int', 'flags')] +#define INIT_hipMemPoolImportFromShareableHandle_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPoolImportFromShareableHandle.mem_pool = \ + (hipMemPool_t *)mem_pool; \ + cb_data.args.hipMemPoolImportFromShareableHandle.shared_handle = \ + (void *)shared_handle; \ + cb_data.args.hipMemPoolImportFromShareableHandle.handle_type = \ + (hipMemAllocationHandleType)handle_type; \ + cb_data.args.hipMemPoolImportFromShareableHandle.flags = \ + (unsigned int)flags; \ + }; +// hipMemPoolImportPointer[('void**', 'dev_ptr'), ('hipMemPool_t', 'mem_pool'), +// ('hipMemPoolPtrExportData*', 'export_data')] +#define INIT_hipMemPoolImportPointer_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPoolImportPointer.dev_ptr = (void **)ptr; \ + cb_data.args.hipMemPoolImportPointer.mem_pool = (hipMemPool_t)mem_pool; \ + cb_data.args.hipMemPoolImportPointer.export_data = \ + (hipMemPoolPtrExportData *)export_data; \ + }; +// hipMemPoolSetAccess[('hipMemPool_t', 'mem_pool'), ('const hipMemAccessDesc*', +// 'desc_list'), ('size_t', 'count')] +#define INIT_hipMemPoolSetAccess_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPoolSetAccess.mem_pool = (hipMemPool_t)mem_pool; \ + cb_data.args.hipMemPoolSetAccess.desc_list = \ + (const hipMemAccessDesc *)desc_list; \ + cb_data.args.hipMemPoolSetAccess.count = (size_t)count; \ + }; +// hipMemPoolSetAttribute[('hipMemPool_t', 'mem_pool'), ('hipMemPoolAttr', +// 'attr'), ('void*', 'value')] +#define INIT_hipMemPoolSetAttribute_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPoolSetAttribute.mem_pool = (hipMemPool_t)mem_pool; \ + cb_data.args.hipMemPoolSetAttribute.attr = (hipMemPoolAttr)attr; \ + cb_data.args.hipMemPoolSetAttribute.value = (void *)value; \ + }; +// hipMemPoolTrimTo[('hipMemPool_t', 'mem_pool'), ('size_t', +// 'min_bytes_to_hold')] +#define INIT_hipMemPoolTrimTo_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPoolTrimTo.mem_pool = (hipMemPool_t)mem_pool; \ + cb_data.args.hipMemPoolTrimTo.min_bytes_to_hold = \ + (size_t)min_bytes_to_hold; \ + }; +// hipMemPrefetchAsync[('const void*', 'dev_ptr'), ('size_t', 'count'), ('int', +// 'device'), ('hipStream_t', 'stream')] +#define INIT_hipMemPrefetchAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPrefetchAsync.dev_ptr = (const void *)dev_ptr; \ + cb_data.args.hipMemPrefetchAsync.count = (size_t)count; \ + cb_data.args.hipMemPrefetchAsync.device = (int)device; \ + cb_data.args.hipMemPrefetchAsync.stream = (hipStream_t)stream; \ + }; +// hipMemPtrGetInfo[('void*', 'ptr'), ('size_t*', 'size')] +#define INIT_hipMemPtrGetInfo_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemPtrGetInfo.ptr = (void *)ptr; \ + cb_data.args.hipMemPtrGetInfo.size = (size_t *)size; \ + }; +// hipMemRangeGetAttribute[('void*', 'data'), ('size_t', 'data_size'), +// ('hipMemRangeAttribute', 'attribute'), ('const void*', 'dev_ptr'), ('size_t', +// 'count')] +#define INIT_hipMemRangeGetAttribute_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemRangeGetAttribute.data = (void *)data; \ + cb_data.args.hipMemRangeGetAttribute.data_size = (size_t)data_size; \ + cb_data.args.hipMemRangeGetAttribute.attribute = \ + (hipMemRangeAttribute)attribute; \ + cb_data.args.hipMemRangeGetAttribute.dev_ptr = (const void *)dev_ptr; \ + cb_data.args.hipMemRangeGetAttribute.count = (size_t)count; \ + }; +// hipMemRangeGetAttributes[('void**', 'data'), ('size_t*', 'data_sizes'), +// ('hipMemRangeAttribute*', 'attributes'), ('size_t', 'num_attributes'), +// ('const void*', 'dev_ptr'), ('size_t', 'count')] +#define INIT_hipMemRangeGetAttributes_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemRangeGetAttributes.data = (void **)data; \ + cb_data.args.hipMemRangeGetAttributes.data_sizes = (size_t *)data_sizes; \ + cb_data.args.hipMemRangeGetAttributes.attributes = \ + (hipMemRangeAttribute *)attributes; \ + cb_data.args.hipMemRangeGetAttributes.num_attributes = \ + (size_t)num_attributes; \ + cb_data.args.hipMemRangeGetAttributes.dev_ptr = (const void *)dev_ptr; \ + cb_data.args.hipMemRangeGetAttributes.count = (size_t)count; \ + }; +// hipMemRelease[('hipMemGenericAllocationHandle_t', 'handle')] +#define INIT_hipMemRelease_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemRelease.handle = \ + (hipMemGenericAllocationHandle_t)handle; \ + }; +// hipMemRetainAllocationHandle[('hipMemGenericAllocationHandle_t*', 'handle'), +// ('void*', 'addr')] +#define INIT_hipMemRetainAllocationHandle_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemRetainAllocationHandle.handle = \ + (hipMemGenericAllocationHandle_t *)handle; \ + cb_data.args.hipMemRetainAllocationHandle.addr = (void *)addr; \ + }; +// hipMemSetAccess[('void*', 'ptr'), ('size_t', 'size'), ('const +// hipMemAccessDesc*', 'desc'), ('size_t', 'count')] +#define INIT_hipMemSetAccess_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemSetAccess.ptr = (void *)ptr; \ + cb_data.args.hipMemSetAccess.size = (size_t)size; \ + cb_data.args.hipMemSetAccess.desc = (const hipMemAccessDesc *)desc; \ + cb_data.args.hipMemSetAccess.count = (size_t)count; \ + }; +// hipMemUnmap[('void*', 'ptr'), ('size_t', 'size')] +#define INIT_hipMemUnmap_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemUnmap.ptr = (void *)ptr; \ + cb_data.args.hipMemUnmap.size = (size_t)size; \ + }; +// hipMemcpy[('void*', 'dst'), ('const void*', 'src'), ('size_t', 'sizeBytes'), +// ('hipMemcpyKind', 'kind')] +#define INIT_hipMemcpy_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpy.dst = (void *)dst; \ + cb_data.args.hipMemcpy.src = (const void *)src; \ + cb_data.args.hipMemcpy.sizeBytes = (size_t)sizeBytes; \ + cb_data.args.hipMemcpy.kind = (hipMemcpyKind)kind; \ + }; +// hipMemcpy2D[('void*', 'dst'), ('size_t', 'dpitch'), ('const void*', 'src'), +// ('size_t', 'spitch'), ('size_t', 'width'), ('size_t', 'height'), +// ('hipMemcpyKind', 'kind')] +#define INIT_hipMemcpy2D_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpy2D.dst = (void *)dst; \ + cb_data.args.hipMemcpy2D.dpitch = (size_t)dpitch; \ + cb_data.args.hipMemcpy2D.src = (const void *)src; \ + cb_data.args.hipMemcpy2D.spitch = (size_t)spitch; \ + cb_data.args.hipMemcpy2D.width = (size_t)width; \ + cb_data.args.hipMemcpy2D.height = (size_t)height; \ + cb_data.args.hipMemcpy2D.kind = (hipMemcpyKind)kind; \ + }; +// hipMemcpy2DArrayToArray[('hipArray_t', 'dst'), ('size_t', 'wOffsetDst'), +// ('size_t', 'hOffsetDst'), ('hipArray_const_t', 'src'), ('size_t', +// 'wOffsetSrc'), ('size_t', 'hOffsetSrc'), ('size_t', 'width'), ('size_t', +// 'height'), ('hipMemcpyKind', 'kind')] +#define INIT_hipMemcpy2DArrayToArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpy2DArrayToArray.dst = (hipArray_t)dst; \ + cb_data.args.hipMemcpy2DArrayToArray.wOffsetDst = (size_t)wOffsetDst; \ + cb_data.args.hipMemcpy2DArrayToArray.hOffsetDst = (size_t)hOffsetDst; \ + cb_data.args.hipMemcpy2DArrayToArray.src = (hipArray_const_t)src; \ + cb_data.args.hipMemcpy2DArrayToArray.wOffsetSrc = (size_t)wOffsetSrc; \ + cb_data.args.hipMemcpy2DArrayToArray.hOffsetSrc = (size_t)hOffsetSrc; \ + cb_data.args.hipMemcpy2DArrayToArray.width = (size_t)width; \ + cb_data.args.hipMemcpy2DArrayToArray.height = (size_t)height; \ + cb_data.args.hipMemcpy2DArrayToArray.kind = (hipMemcpyKind)kind; \ + }; +// hipMemcpy2DAsync[('void*', 'dst'), ('size_t', 'dpitch'), ('const void*', +// 'src'), ('size_t', 'spitch'), ('size_t', 'width'), ('size_t', 'height'), +// ('hipMemcpyKind', 'kind'), ('hipStream_t', 'stream')] +#define INIT_hipMemcpy2DAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpy2DAsync.dst = (void *)dst; \ + cb_data.args.hipMemcpy2DAsync.dpitch = (size_t)dpitch; \ + cb_data.args.hipMemcpy2DAsync.src = (const void *)src; \ + cb_data.args.hipMemcpy2DAsync.spitch = (size_t)spitch; \ + cb_data.args.hipMemcpy2DAsync.width = (size_t)width; \ + cb_data.args.hipMemcpy2DAsync.height = (size_t)height; \ + cb_data.args.hipMemcpy2DAsync.kind = (hipMemcpyKind)kind; \ + cb_data.args.hipMemcpy2DAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpy2DFromArray[('void*', 'dst'), ('size_t', 'dpitch'), +// ('hipArray_const_t', 'src'), ('size_t', 'wOffset'), ('size_t', 'hOffset'), +// ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', 'kind')] +#define INIT_hipMemcpy2DFromArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpy2DFromArray.dst = (void *)dst; \ + cb_data.args.hipMemcpy2DFromArray.dpitch = (size_t)dpitch; \ + cb_data.args.hipMemcpy2DFromArray.src = (hipArray_const_t)src; \ + cb_data.args.hipMemcpy2DFromArray.wOffset = (size_t)wOffsetSrc; \ + cb_data.args.hipMemcpy2DFromArray.hOffset = (size_t)hOffset; \ + cb_data.args.hipMemcpy2DFromArray.width = (size_t)width; \ + cb_data.args.hipMemcpy2DFromArray.height = (size_t)height; \ + cb_data.args.hipMemcpy2DFromArray.kind = (hipMemcpyKind)kind; \ + }; +// hipMemcpy2DFromArrayAsync[('void*', 'dst'), ('size_t', 'dpitch'), +// ('hipArray_const_t', 'src'), ('size_t', 'wOffset'), ('size_t', 'hOffset'), +// ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', 'kind'), +// ('hipStream_t', 'stream')] +#define INIT_hipMemcpy2DFromArrayAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpy2DFromArrayAsync.dst = (void *)dst; \ + cb_data.args.hipMemcpy2DFromArrayAsync.dpitch = (size_t)dpitch; \ + cb_data.args.hipMemcpy2DFromArrayAsync.src = (hipArray_const_t)src; \ + cb_data.args.hipMemcpy2DFromArrayAsync.wOffset = (size_t)wOffsetSrc; \ + cb_data.args.hipMemcpy2DFromArrayAsync.hOffset = (size_t)hOffsetSrc; \ + cb_data.args.hipMemcpy2DFromArrayAsync.width = (size_t)width; \ + cb_data.args.hipMemcpy2DFromArrayAsync.height = (size_t)height; \ + cb_data.args.hipMemcpy2DFromArrayAsync.kind = (hipMemcpyKind)kind; \ + cb_data.args.hipMemcpy2DFromArrayAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpy2DToArray[('hipArray_t', 'dst'), ('size_t', 'wOffset'), ('size_t', +// 'hOffset'), ('const void*', 'src'), ('size_t', 'spitch'), ('size_t', +// 'width'), ('size_t', 'height'), ('hipMemcpyKind', 'kind')] +#define INIT_hipMemcpy2DToArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpy2DToArray.dst = (hipArray_t)dst; \ + cb_data.args.hipMemcpy2DToArray.wOffset = (size_t)wOffset; \ + cb_data.args.hipMemcpy2DToArray.hOffset = (size_t)hOffset; \ + cb_data.args.hipMemcpy2DToArray.src = (const void *)src; \ + cb_data.args.hipMemcpy2DToArray.spitch = (size_t)spitch; \ + cb_data.args.hipMemcpy2DToArray.width = (size_t)width; \ + cb_data.args.hipMemcpy2DToArray.height = (size_t)height; \ + cb_data.args.hipMemcpy2DToArray.kind = (hipMemcpyKind)kind; \ + }; +// hipMemcpy2DToArrayAsync[('hipArray_t', 'dst'), ('size_t', 'wOffset'), +// ('size_t', 'hOffset'), ('const void*', 'src'), ('size_t', 'spitch'), +// ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', 'kind'), +// ('hipStream_t', 'stream')] +#define INIT_hipMemcpy2DToArrayAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpy2DToArrayAsync.dst = (hipArray_t)dst; \ + cb_data.args.hipMemcpy2DToArrayAsync.wOffset = (size_t)wOffset; \ + cb_data.args.hipMemcpy2DToArrayAsync.hOffset = (size_t)hOffset; \ + cb_data.args.hipMemcpy2DToArrayAsync.src = (const void *)src; \ + cb_data.args.hipMemcpy2DToArrayAsync.spitch = (size_t)spitch; \ + cb_data.args.hipMemcpy2DToArrayAsync.width = (size_t)width; \ + cb_data.args.hipMemcpy2DToArrayAsync.height = (size_t)height; \ + cb_data.args.hipMemcpy2DToArrayAsync.kind = (hipMemcpyKind)kind; \ + cb_data.args.hipMemcpy2DToArrayAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpy3D[('const hipMemcpy3DParms*', 'p')] +#define INIT_hipMemcpy3D_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpy3D.p = (const hipMemcpy3DParms *)p; \ + }; +// hipMemcpy3DAsync[('const hipMemcpy3DParms*', 'p'), ('hipStream_t', 'stream')] +#define INIT_hipMemcpy3DAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpy3DAsync.p = (const hipMemcpy3DParms *)p; \ + cb_data.args.hipMemcpy3DAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpyAsync[('void*', 'dst'), ('const void*', 'src'), ('size_t', +// 'sizeBytes'), ('hipMemcpyKind', 'kind'), ('hipStream_t', 'stream')] +#define INIT_hipMemcpyAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyAsync.dst = (void *)dst; \ + cb_data.args.hipMemcpyAsync.src = (const void *)src; \ + cb_data.args.hipMemcpyAsync.sizeBytes = (size_t)sizeBytes; \ + cb_data.args.hipMemcpyAsync.kind = (hipMemcpyKind)kind; \ + cb_data.args.hipMemcpyAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpyAtoA[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), +// ('hipArray_t', 'srcArray'), ('size_t', 'srcOffset'), ('size_t', 'ByteCount')] +#define INIT_hipMemcpyAtoA_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyAtoA.dstArray = (hipArray_t)dstArray; \ + cb_data.args.hipMemcpyAtoA.dstOffset = (size_t)dstOffset; \ + cb_data.args.hipMemcpyAtoA.srcArray = (hipArray_t)srcArray; \ + cb_data.args.hipMemcpyAtoA.srcOffset = (size_t)srcOffset; \ + cb_data.args.hipMemcpyAtoA.ByteCount = (size_t)ByteCount; \ + }; +// hipMemcpyAtoD[('hipDeviceptr_t', 'dstDevice'), ('hipArray_t', 'srcArray'), +// ('size_t', 'srcOffset'), ('size_t', 'ByteCount')] +#define INIT_hipMemcpyAtoD_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyAtoD.dstDevice = (hipDeviceptr_t)dstDevice; \ + cb_data.args.hipMemcpyAtoD.srcArray = (hipArray_t)srcArray; \ + cb_data.args.hipMemcpyAtoD.srcOffset = (size_t)srcOffset; \ + cb_data.args.hipMemcpyAtoD.ByteCount = (size_t)ByteCount; \ + }; +// hipMemcpyAtoH[('void*', 'dst'), ('hipArray_t', 'srcArray'), ('size_t', +// 'srcOffset'), ('size_t', 'count')] +#define INIT_hipMemcpyAtoH_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyAtoH.dst = (void *)dstHost; \ + cb_data.args.hipMemcpyAtoH.srcArray = (hipArray_t)srcArray; \ + cb_data.args.hipMemcpyAtoH.srcOffset = (size_t)srcOffset; \ + cb_data.args.hipMemcpyAtoH.count = (size_t)ByteCount; \ + }; +// hipMemcpyAtoHAsync[('void*', 'dstHost'), ('hipArray_t', 'srcArray'), +// ('size_t', 'srcOffset'), ('size_t', 'ByteCount'), ('hipStream_t', 'stream')] +#define INIT_hipMemcpyAtoHAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyAtoHAsync.dstHost = (void *)dstHost; \ + cb_data.args.hipMemcpyAtoHAsync.srcArray = (hipArray_t)srcArray; \ + cb_data.args.hipMemcpyAtoHAsync.srcOffset = (size_t)srcOffset; \ + cb_data.args.hipMemcpyAtoHAsync.ByteCount = (size_t)ByteCount; \ + cb_data.args.hipMemcpyAtoHAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpyDtoA[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), +// ('hipDeviceptr_t', 'srcDevice'), ('size_t', 'ByteCount')] +#define INIT_hipMemcpyDtoA_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyDtoA.dstArray = (hipArray_t)dstArray; \ + cb_data.args.hipMemcpyDtoA.dstOffset = (size_t)dstOffset; \ + cb_data.args.hipMemcpyDtoA.srcDevice = (hipDeviceptr_t)srcDevice; \ + cb_data.args.hipMemcpyDtoA.ByteCount = (size_t)ByteCount; \ + }; +// hipMemcpyDtoD[('hipDeviceptr_t', 'dst'), ('hipDeviceptr_t', 'src'), +// ('size_t', 'sizeBytes')] +#define INIT_hipMemcpyDtoD_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyDtoD.dst = (hipDeviceptr_t)dstDevice; \ + cb_data.args.hipMemcpyDtoD.src = (hipDeviceptr_t)srcDevice; \ + cb_data.args.hipMemcpyDtoD.sizeBytes = (size_t)ByteCount; \ + }; +// hipMemcpyDtoDAsync[('hipDeviceptr_t', 'dst'), ('hipDeviceptr_t', 'src'), +// ('size_t', 'sizeBytes'), ('hipStream_t', 'stream')] +#define INIT_hipMemcpyDtoDAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyDtoDAsync.dst = (hipDeviceptr_t)dstDevice; \ + cb_data.args.hipMemcpyDtoDAsync.src = (hipDeviceptr_t)srcDevice; \ + cb_data.args.hipMemcpyDtoDAsync.sizeBytes = (size_t)ByteCount; \ + cb_data.args.hipMemcpyDtoDAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpyDtoH[('void*', 'dst'), ('hipDeviceptr_t', 'src'), ('size_t', +// 'sizeBytes')] +#define INIT_hipMemcpyDtoH_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyDtoH.dst = (void *)dstHost; \ + cb_data.args.hipMemcpyDtoH.src = (hipDeviceptr_t)srcDevice; \ + cb_data.args.hipMemcpyDtoH.sizeBytes = (size_t)ByteCount; \ + }; +// hipMemcpyDtoHAsync[('void*', 'dst'), ('hipDeviceptr_t', 'src'), ('size_t', +// 'sizeBytes'), ('hipStream_t', 'stream')] +#define INIT_hipMemcpyDtoHAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyDtoHAsync.dst = (void *)dstHost; \ + cb_data.args.hipMemcpyDtoHAsync.src = (hipDeviceptr_t)srcDevice; \ + cb_data.args.hipMemcpyDtoHAsync.sizeBytes = (size_t)ByteCount; \ + cb_data.args.hipMemcpyDtoHAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpyFromArray[('void*', 'dst'), ('hipArray_const_t', 'srcArray'), +// ('size_t', 'wOffset'), ('size_t', 'hOffset'), ('size_t', 'count'), +// ('hipMemcpyKind', 'kind')] +#define INIT_hipMemcpyFromArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyFromArray.dst = (void *)dst; \ + cb_data.args.hipMemcpyFromArray.srcArray = (hipArray_const_t)src; \ + cb_data.args.hipMemcpyFromArray.wOffset = (size_t)wOffsetSrc; \ + cb_data.args.hipMemcpyFromArray.hOffset = (size_t)hOffset; \ + cb_data.args.hipMemcpyFromArray.count = (size_t)count; \ + cb_data.args.hipMemcpyFromArray.kind = (hipMemcpyKind)kind; \ + }; +// hipMemcpyFromSymbol[('void*', 'dst'), ('const void*', 'symbol'), ('size_t', +// 'sizeBytes'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind')] +#define INIT_hipMemcpyFromSymbol_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyFromSymbol.dst = (void *)dst; \ + cb_data.args.hipMemcpyFromSymbol.symbol = (const void *)symbol; \ + cb_data.args.hipMemcpyFromSymbol.sizeBytes = (size_t)sizeBytes; \ + cb_data.args.hipMemcpyFromSymbol.offset = (size_t)offset; \ + cb_data.args.hipMemcpyFromSymbol.kind = (hipMemcpyKind)kind; \ + }; +// hipMemcpyFromSymbolAsync[('void*', 'dst'), ('const void*', 'symbol'), +// ('size_t', 'sizeBytes'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind'), +// ('hipStream_t', 'stream')] +#define INIT_hipMemcpyFromSymbolAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyFromSymbolAsync.dst = (void *)dst; \ + cb_data.args.hipMemcpyFromSymbolAsync.symbol = (const void *)symbol; \ + cb_data.args.hipMemcpyFromSymbolAsync.sizeBytes = (size_t)sizeBytes; \ + cb_data.args.hipMemcpyFromSymbolAsync.offset = (size_t)offset; \ + cb_data.args.hipMemcpyFromSymbolAsync.kind = (hipMemcpyKind)kind; \ + cb_data.args.hipMemcpyFromSymbolAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpyHtoA[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), ('const +// void*', 'srcHost'), ('size_t', 'count')] +#define INIT_hipMemcpyHtoA_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyHtoA.dstArray = (hipArray_t)dstArray; \ + cb_data.args.hipMemcpyHtoA.dstOffset = (size_t)dstOffset; \ + cb_data.args.hipMemcpyHtoA.srcHost = (const void *)srcHost; \ + cb_data.args.hipMemcpyHtoA.count = (size_t)ByteCount; \ + }; +// hipMemcpyHtoAAsync[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), +// ('const void*', 'srcHost'), ('size_t', 'ByteCount'), ('hipStream_t', +// 'stream')] +#define INIT_hipMemcpyHtoAAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyHtoAAsync.dstArray = (hipArray_t)dstArray; \ + cb_data.args.hipMemcpyHtoAAsync.dstOffset = (size_t)dstOffset; \ + cb_data.args.hipMemcpyHtoAAsync.srcHost = (const void *)srcHost; \ + cb_data.args.hipMemcpyHtoAAsync.ByteCount = (size_t)ByteCount; \ + cb_data.args.hipMemcpyHtoAAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpyHtoD[('hipDeviceptr_t', 'dst'), ('void*', 'src'), ('size_t', +// 'sizeBytes')] +#define INIT_hipMemcpyHtoD_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyHtoD.dst = (hipDeviceptr_t)dstDevice; \ + cb_data.args.hipMemcpyHtoD.src = (void *)srcHost; \ + cb_data.args.hipMemcpyHtoD.sizeBytes = (size_t)ByteCount; \ + }; +// hipMemcpyHtoDAsync[('hipDeviceptr_t', 'dst'), ('void*', 'src'), ('size_t', +// 'sizeBytes'), ('hipStream_t', 'stream')] +#define INIT_hipMemcpyHtoDAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyHtoDAsync.dst = (hipDeviceptr_t)dstDevice; \ + cb_data.args.hipMemcpyHtoDAsync.src = (void *)srcHost; \ + cb_data.args.hipMemcpyHtoDAsync.sizeBytes = (size_t)ByteCount; \ + cb_data.args.hipMemcpyHtoDAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpyParam2D[('const hip_Memcpy2D*', 'pCopy')] +#define INIT_hipMemcpyParam2D_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyParam2D.pCopy = (const hip_Memcpy2D *)pCopy; \ + }; +// hipMemcpyParam2DAsync[('const hip_Memcpy2D*', 'pCopy'), ('hipStream_t', +// 'stream')] +#define INIT_hipMemcpyParam2DAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyParam2DAsync.pCopy = (const hip_Memcpy2D *)pCopy; \ + cb_data.args.hipMemcpyParam2DAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpyPeer[('void*', 'dst'), ('int', 'dstDeviceId'), ('const void*', +// 'src'), ('int', 'srcDeviceId'), ('size_t', 'sizeBytes')] +#define INIT_hipMemcpyPeer_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyPeer.dst = (void *)dst; \ + cb_data.args.hipMemcpyPeer.dstDeviceId = (int)dstDevice; \ + cb_data.args.hipMemcpyPeer.src = (const void *)src; \ + cb_data.args.hipMemcpyPeer.srcDeviceId = (int)srcDevice; \ + cb_data.args.hipMemcpyPeer.sizeBytes = (size_t)sizeBytes; \ + }; +// hipMemcpyPeerAsync[('void*', 'dst'), ('int', 'dstDeviceId'), ('const void*', +// 'src'), ('int', 'srcDevice'), ('size_t', 'sizeBytes'), ('hipStream_t', +// 'stream')] +#define INIT_hipMemcpyPeerAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyPeerAsync.dst = (void *)dst; \ + cb_data.args.hipMemcpyPeerAsync.dstDeviceId = (int)dstDevice; \ + cb_data.args.hipMemcpyPeerAsync.src = (const void *)src; \ + cb_data.args.hipMemcpyPeerAsync.srcDevice = (int)srcDevice; \ + cb_data.args.hipMemcpyPeerAsync.sizeBytes = (size_t)sizeBytes; \ + cb_data.args.hipMemcpyPeerAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpyToArray[('hipArray_t', 'dst'), ('size_t', 'wOffset'), ('size_t', +// 'hOffset'), ('const void*', 'src'), ('size_t', 'count'), ('hipMemcpyKind', +// 'kind')] +#define INIT_hipMemcpyToArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyToArray.dst = (hipArray_t)dst; \ + cb_data.args.hipMemcpyToArray.wOffset = (size_t)wOffset; \ + cb_data.args.hipMemcpyToArray.hOffset = (size_t)hOffset; \ + cb_data.args.hipMemcpyToArray.src = (const void *)src; \ + cb_data.args.hipMemcpyToArray.count = (size_t)count; \ + cb_data.args.hipMemcpyToArray.kind = (hipMemcpyKind)kind; \ + }; +// hipMemcpyToSymbol[('const void*', 'symbol'), ('const void*', 'src'), +// ('size_t', 'sizeBytes'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind')] +#define INIT_hipMemcpyToSymbol_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyToSymbol.symbol = (const void *)symbol; \ + cb_data.args.hipMemcpyToSymbol.src = (const void *)src; \ + cb_data.args.hipMemcpyToSymbol.sizeBytes = (size_t)sizeBytes; \ + cb_data.args.hipMemcpyToSymbol.offset = (size_t)offset; \ + cb_data.args.hipMemcpyToSymbol.kind = (hipMemcpyKind)kind; \ + }; +// hipMemcpyToSymbolAsync[('const void*', 'symbol'), ('const void*', 'src'), +// ('size_t', 'sizeBytes'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind'), +// ('hipStream_t', 'stream')] +#define INIT_hipMemcpyToSymbolAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyToSymbolAsync.symbol = (const void *)symbol; \ + cb_data.args.hipMemcpyToSymbolAsync.src = (const void *)src; \ + cb_data.args.hipMemcpyToSymbolAsync.sizeBytes = (size_t)sizeBytes; \ + cb_data.args.hipMemcpyToSymbolAsync.offset = (size_t)offset; \ + cb_data.args.hipMemcpyToSymbolAsync.kind = (hipMemcpyKind)kind; \ + cb_data.args.hipMemcpyToSymbolAsync.stream = (hipStream_t)stream; \ + }; +// hipMemcpyWithStream[('void*', 'dst'), ('const void*', 'src'), ('size_t', +// 'sizeBytes'), ('hipMemcpyKind', 'kind'), ('hipStream_t', 'stream')] +#define INIT_hipMemcpyWithStream_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemcpyWithStream.dst = (void *)dst; \ + cb_data.args.hipMemcpyWithStream.src = (const void *)src; \ + cb_data.args.hipMemcpyWithStream.sizeBytes = (size_t)sizeBytes; \ + cb_data.args.hipMemcpyWithStream.kind = (hipMemcpyKind)kind; \ + cb_data.args.hipMemcpyWithStream.stream = (hipStream_t)stream; \ + }; +// hipMemset[('void*', 'dst'), ('int', 'value'), ('size_t', 'sizeBytes')] +#define INIT_hipMemset_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemset.dst = (void *)dst; \ + cb_data.args.hipMemset.value = (int)value; \ + cb_data.args.hipMemset.sizeBytes = (size_t)sizeBytes; \ + }; +// hipMemset2D[('void*', 'dst'), ('size_t', 'pitch'), ('int', 'value'), +// ('size_t', 'width'), ('size_t', 'height')] +#define INIT_hipMemset2D_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemset2D.dst = (void *)dst; \ + cb_data.args.hipMemset2D.pitch = (size_t)pitch; \ + cb_data.args.hipMemset2D.value = (int)value; \ + cb_data.args.hipMemset2D.width = (size_t)width; \ + cb_data.args.hipMemset2D.height = (size_t)height; \ + }; +// hipMemset2DAsync[('void*', 'dst'), ('size_t', 'pitch'), ('int', 'value'), +// ('size_t', 'width'), ('size_t', 'height'), ('hipStream_t', 'stream')] +#define INIT_hipMemset2DAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemset2DAsync.dst = (void *)dst; \ + cb_data.args.hipMemset2DAsync.pitch = (size_t)pitch; \ + cb_data.args.hipMemset2DAsync.value = (int)value; \ + cb_data.args.hipMemset2DAsync.width = (size_t)width; \ + cb_data.args.hipMemset2DAsync.height = (size_t)height; \ + cb_data.args.hipMemset2DAsync.stream = (hipStream_t)stream; \ + }; +// hipMemset3D[('hipPitchedPtr', 'pitchedDevPtr'), ('int', 'value'), +// ('hipExtent', 'extent')] +#define INIT_hipMemset3D_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemset3D.pitchedDevPtr = (hipPitchedPtr)pitchedDevPtr; \ + cb_data.args.hipMemset3D.value = (int)value; \ + cb_data.args.hipMemset3D.extent = (hipExtent)extent; \ + }; +// hipMemset3DAsync[('hipPitchedPtr', 'pitchedDevPtr'), ('int', 'value'), +// ('hipExtent', 'extent'), ('hipStream_t', 'stream')] +#define INIT_hipMemset3DAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemset3DAsync.pitchedDevPtr = \ + (hipPitchedPtr)pitchedDevPtr; \ + cb_data.args.hipMemset3DAsync.value = (int)value; \ + cb_data.args.hipMemset3DAsync.extent = (hipExtent)extent; \ + cb_data.args.hipMemset3DAsync.stream = (hipStream_t)stream; \ + }; +// hipMemsetAsync[('void*', 'dst'), ('int', 'value'), ('size_t', 'sizeBytes'), +// ('hipStream_t', 'stream')] +#define INIT_hipMemsetAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemsetAsync.dst = (void *)dst; \ + cb_data.args.hipMemsetAsync.value = (int)value; \ + cb_data.args.hipMemsetAsync.sizeBytes = (size_t)sizeBytes; \ + cb_data.args.hipMemsetAsync.stream = (hipStream_t)stream; \ + }; +// hipMemsetD16[('hipDeviceptr_t', 'dest'), ('unsigned short', 'value'), +// ('size_t', 'count')] +#define INIT_hipMemsetD16_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemsetD16.dest = (hipDeviceptr_t)dst; \ + cb_data.args.hipMemsetD16.value = (unsigned short)value; \ + cb_data.args.hipMemsetD16.count = (size_t)count; \ + }; +// hipMemsetD16Async[('hipDeviceptr_t', 'dest'), ('unsigned short', 'value'), +// ('size_t', 'count'), ('hipStream_t', 'stream')] +#define INIT_hipMemsetD16Async_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemsetD16Async.dest = (hipDeviceptr_t)dst; \ + cb_data.args.hipMemsetD16Async.value = (unsigned short)value; \ + cb_data.args.hipMemsetD16Async.count = (size_t)count; \ + cb_data.args.hipMemsetD16Async.stream = (hipStream_t)stream; \ + }; +// hipMemsetD32[('hipDeviceptr_t', 'dest'), ('int', 'value'), ('size_t', +// 'count')] +#define INIT_hipMemsetD32_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemsetD32.dest = (hipDeviceptr_t)dst; \ + cb_data.args.hipMemsetD32.value = (int)value; \ + cb_data.args.hipMemsetD32.count = (size_t)count; \ + }; +// hipMemsetD32Async[('hipDeviceptr_t', 'dst'), ('int', 'value'), ('size_t', +// 'count'), ('hipStream_t', 'stream')] +#define INIT_hipMemsetD32Async_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemsetD32Async.dst = (hipDeviceptr_t)dst; \ + cb_data.args.hipMemsetD32Async.value = (int)value; \ + cb_data.args.hipMemsetD32Async.count = (size_t)count; \ + cb_data.args.hipMemsetD32Async.stream = (hipStream_t)stream; \ + }; +// hipMemsetD8[('hipDeviceptr_t', 'dest'), ('unsigned char', 'value'), +// ('size_t', 'count')] +#define INIT_hipMemsetD8_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemsetD8.dest = (hipDeviceptr_t)dst; \ + cb_data.args.hipMemsetD8.value = (unsigned char)value; \ + cb_data.args.hipMemsetD8.count = (size_t)count; \ + }; +// hipMemsetD8Async[('hipDeviceptr_t', 'dest'), ('unsigned char', 'value'), +// ('size_t', 'count'), ('hipStream_t', 'stream')] +#define INIT_hipMemsetD8Async_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMemsetD8Async.dest = (hipDeviceptr_t)dst; \ + cb_data.args.hipMemsetD8Async.value = (unsigned char)value; \ + cb_data.args.hipMemsetD8Async.count = (size_t)count; \ + cb_data.args.hipMemsetD8Async.stream = (hipStream_t)stream; \ + }; +// hipMipmappedArrayCreate[('hipMipmappedArray_t*', 'pHandle'), +// ('HIP_ARRAY3D_DESCRIPTOR*', 'pMipmappedArrayDesc'), ('unsigned int', +// 'numMipmapLevels')] +#define INIT_hipMipmappedArrayCreate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMipmappedArrayCreate.pHandle = \ + (hipMipmappedArray_t *)mipmapped_array_pptr; \ + cb_data.args.hipMipmappedArrayCreate.pMipmappedArrayDesc = \ + (HIP_ARRAY3D_DESCRIPTOR *)mipmapped_array_desc_ptr; \ + cb_data.args.hipMipmappedArrayCreate.numMipmapLevels = \ + (unsigned int)num_mipmap_levels; \ + }; +// hipMipmappedArrayDestroy[('hipMipmappedArray_t', 'hMipmappedArray')] +#define INIT_hipMipmappedArrayDestroy_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMipmappedArrayDestroy.hMipmappedArray = \ + (hipMipmappedArray_t)mipmapped_array_ptr; \ + }; +// hipMipmappedArrayGetLevel[('hipArray_t*', 'pLevelArray'), +// ('hipMipmappedArray_t', 'hMipMappedArray'), ('unsigned int', 'level')] +#define INIT_hipMipmappedArrayGetLevel_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipMipmappedArrayGetLevel.pLevelArray = \ + (hipArray_t *)level_array_pptr; \ + cb_data.args.hipMipmappedArrayGetLevel.hMipMappedArray = \ + (hipMipmappedArray_t)mipmapped_array_ptr; \ + cb_data.args.hipMipmappedArrayGetLevel.level = (unsigned int)mip_level; \ + }; +// hipModuleGetFunction[('hipFunction_t*', 'function'), ('hipModule_t', +// 'module'), ('const char*', 'kname')] +#define INIT_hipModuleGetFunction_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipModuleGetFunction.function = (hipFunction_t *)hfunc; \ + cb_data.args.hipModuleGetFunction.module = (hipModule_t)hmod; \ + cb_data.args.hipModuleGetFunction.kname = (name) ? strdup(name) : NULL; \ + }; +// hipModuleGetGlobal[('hipDeviceptr_t*', 'dptr'), ('size_t*', 'bytes'), +// ('hipModule_t', 'hmod'), ('const char*', 'name')] +#define INIT_hipModuleGetGlobal_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipModuleGetGlobal.dptr = (hipDeviceptr_t *)dptr; \ + cb_data.args.hipModuleGetGlobal.bytes = (size_t *)bytes; \ + cb_data.args.hipModuleGetGlobal.hmod = (hipModule_t)hmod; \ + cb_data.args.hipModuleGetGlobal.name = (name) ? strdup(name) : NULL; \ + }; +// hipModuleGetTexRef[('textureReference**', 'texRef'), ('hipModule_t', 'hmod'), +// ('const char*', 'name')] +#define INIT_hipModuleGetTexRef_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipModuleGetTexRef.texRef = (textureReference **)texRef; \ + cb_data.args.hipModuleGetTexRef.hmod = (hipModule_t)hmod; \ + cb_data.args.hipModuleGetTexRef.name = (name) ? strdup(name) : NULL; \ + }; +// hipModuleLaunchCooperativeKernel[('hipFunction_t', 'f'), ('unsigned int', +// 'gridDimX'), ('unsigned int', 'gridDimY'), ('unsigned int', 'gridDimZ'), +// ('unsigned int', 'blockDimX'), ('unsigned int', 'blockDimY'), ('unsigned +// int', 'blockDimZ'), ('unsigned int', 'sharedMemBytes'), ('hipStream_t', +// 'stream'), ('void**', 'kernelParams')] +#define INIT_hipModuleLaunchCooperativeKernel_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipModuleLaunchCooperativeKernel.f = (hipFunction_t)f; \ + cb_data.args.hipModuleLaunchCooperativeKernel.gridDimX = \ + (unsigned int)gridDimX; \ + cb_data.args.hipModuleLaunchCooperativeKernel.gridDimY = \ + (unsigned int)gridDimY; \ + cb_data.args.hipModuleLaunchCooperativeKernel.gridDimZ = \ + (unsigned int)gridDimZ; \ + cb_data.args.hipModuleLaunchCooperativeKernel.blockDimX = \ + (unsigned int)blockDimX; \ + cb_data.args.hipModuleLaunchCooperativeKernel.blockDimY = \ + (unsigned int)blockDimY; \ + cb_data.args.hipModuleLaunchCooperativeKernel.blockDimZ = \ + (unsigned int)blockDimZ; \ + cb_data.args.hipModuleLaunchCooperativeKernel.sharedMemBytes = \ + (unsigned int)sharedMemBytes; \ + cb_data.args.hipModuleLaunchCooperativeKernel.stream = \ + (hipStream_t)stream; \ + cb_data.args.hipModuleLaunchCooperativeKernel.kernelParams = \ + (void **)kernelParams; \ + }; +// hipModuleLaunchCooperativeKernelMultiDevice[('hipFunctionLaunchParams*', +// 'launchParamsList'), ('unsigned int', 'numDevices'), ('unsigned int', +// 'flags')] +#define INIT_hipModuleLaunchCooperativeKernelMultiDevice_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipModuleLaunchCooperativeKernelMultiDevice \ + .launchParamsList = (hipFunctionLaunchParams *)launchParamsList; \ + cb_data.args.hipModuleLaunchCooperativeKernelMultiDevice.numDevices = \ + (unsigned int)numDevices; \ + cb_data.args.hipModuleLaunchCooperativeKernelMultiDevice.flags = \ + (unsigned int)flags; \ + }; +// hipModuleLaunchKernel[('hipFunction_t', 'f'), ('unsigned int', 'gridDimX'), +// ('unsigned int', 'gridDimY'), ('unsigned int', 'gridDimZ'), ('unsigned int', +// 'blockDimX'), ('unsigned int', 'blockDimY'), ('unsigned int', 'blockDimZ'), +// ('unsigned int', 'sharedMemBytes'), ('hipStream_t', 'stream'), ('void**', +// 'kernelParams'), ('void**', 'extra')] +#define INIT_hipModuleLaunchKernel_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipModuleLaunchKernel.f = (hipFunction_t)f; \ + cb_data.args.hipModuleLaunchKernel.gridDimX = (unsigned int)gridDimX; \ + cb_data.args.hipModuleLaunchKernel.gridDimY = (unsigned int)gridDimY; \ + cb_data.args.hipModuleLaunchKernel.gridDimZ = (unsigned int)gridDimZ; \ + cb_data.args.hipModuleLaunchKernel.blockDimX = (unsigned int)blockDimX; \ + cb_data.args.hipModuleLaunchKernel.blockDimY = (unsigned int)blockDimY; \ + cb_data.args.hipModuleLaunchKernel.blockDimZ = (unsigned int)blockDimZ; \ + cb_data.args.hipModuleLaunchKernel.sharedMemBytes = \ + (unsigned int)sharedMemBytes; \ + cb_data.args.hipModuleLaunchKernel.stream = (hipStream_t)hStream; \ + cb_data.args.hipModuleLaunchKernel.kernelParams = (void **)kernelParams; \ + cb_data.args.hipModuleLaunchKernel.extra = (void **)extra; \ + }; +// hipModuleLoad[('hipModule_t*', 'module'), ('const char*', 'fname')] +#define INIT_hipModuleLoad_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipModuleLoad.module = (hipModule_t *)module; \ + cb_data.args.hipModuleLoad.fname = (fname) ? strdup(fname) : NULL; \ + }; +// hipModuleLoadData[('hipModule_t*', 'module'), ('const void*', 'image')] +#define INIT_hipModuleLoadData_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipModuleLoadData.module = (hipModule_t *)module; \ + cb_data.args.hipModuleLoadData.image = (const void *)image; \ + }; +// hipModuleLoadDataEx[('hipModule_t*', 'module'), ('const void*', 'image'), +// ('unsigned int', 'numOptions'), ('hipJitOption*', 'options'), ('void**', +// 'optionsValues')] +#define INIT_hipModuleLoadDataEx_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipModuleLoadDataEx.module = (hipModule_t *)module; \ + cb_data.args.hipModuleLoadDataEx.image = (const void *)image; \ + cb_data.args.hipModuleLoadDataEx.numOptions = (unsigned int)numOptions; \ + cb_data.args.hipModuleLoadDataEx.options = (hipJitOption *)options; \ + cb_data.args.hipModuleLoadDataEx.optionsValues = (void **)optionsValues; \ + }; +// hipModuleOccupancyMaxActiveBlocksPerMultiprocessor[('int*', 'numBlocks'), +// ('hipFunction_t', 'f'), ('int', 'blockSize'), ('size_t', +// 'dynSharedMemPerBlk')] +#define INIT_hipModuleOccupancyMaxActiveBlocksPerMultiprocessor_CB_ARGS_DATA( \ + cb_data) \ + { \ + cb_data.args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor \ + .numBlocks = (int *)numBlocks; \ + cb_data.args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor.f = \ + (hipFunction_t)f; \ + cb_data.args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor \ + .blockSize = (int)blockSize; \ + cb_data.args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor \ + .dynSharedMemPerBlk = (size_t)dynSharedMemPerBlk; \ + }; +// hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags[('int*', +// 'numBlocks'), ('hipFunction_t', 'f'), ('int', 'blockSize'), ('size_t', +// 'dynSharedMemPerBlk'), ('unsigned int', 'flags')] +#define INIT_hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags_CB_ARGS_DATA( \ + cb_data) \ + { \ + cb_data.args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags \ + .numBlocks = (int *)numBlocks; \ + cb_data.args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags \ + .f = (hipFunction_t)f; \ + cb_data.args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags \ + .blockSize = (int)blockSize; \ + cb_data.args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags \ + .dynSharedMemPerBlk = (size_t)dynSharedMemPerBlk; \ + cb_data.args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags \ + .flags = (unsigned int)flags; \ + }; +// hipModuleOccupancyMaxPotentialBlockSize[('int*', 'gridSize'), ('int*', +// 'blockSize'), ('hipFunction_t', 'f'), ('size_t', 'dynSharedMemPerBlk'), +// ('int', 'blockSizeLimit')] +#define INIT_hipModuleOccupancyMaxPotentialBlockSize_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipModuleOccupancyMaxPotentialBlockSize.gridSize = \ + (int *)gridSize; \ + cb_data.args.hipModuleOccupancyMaxPotentialBlockSize.blockSize = \ + (int *)blockSize; \ + cb_data.args.hipModuleOccupancyMaxPotentialBlockSize.f = (hipFunction_t)f; \ + cb_data.args.hipModuleOccupancyMaxPotentialBlockSize.dynSharedMemPerBlk = \ + (size_t)dynSharedMemPerBlk; \ + cb_data.args.hipModuleOccupancyMaxPotentialBlockSize.blockSizeLimit = \ + (int)blockSizeLimit; \ + }; +// hipModuleOccupancyMaxPotentialBlockSizeWithFlags[('int*', 'gridSize'), +// ('int*', 'blockSize'), ('hipFunction_t', 'f'), ('size_t', +// 'dynSharedMemPerBlk'), ('int', 'blockSizeLimit'), ('unsigned int', 'flags')] +#define INIT_hipModuleOccupancyMaxPotentialBlockSizeWithFlags_CB_ARGS_DATA( \ + cb_data) \ + { \ + cb_data.args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags.gridSize = \ + (int *)gridSize; \ + cb_data.args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags.blockSize = \ + (int *)blockSize; \ + cb_data.args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags.f = \ + (hipFunction_t)f; \ + cb_data.args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags \ + .dynSharedMemPerBlk = (size_t)dynSharedMemPerBlk; \ + cb_data.args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags \ + .blockSizeLimit = (int)blockSizeLimit; \ + cb_data.args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags.flags = \ + (unsigned int)flags; \ + }; +// hipModuleUnload[('hipModule_t', 'module')] +#define INIT_hipModuleUnload_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipModuleUnload.module = (hipModule_t)hmod; \ + }; +// hipOccupancyMaxActiveBlocksPerMultiprocessor[('int*', 'numBlocks'), ('const +// void*', 'f'), ('int', 'blockSize'), ('size_t', 'dynamicSMemSize')] +#define INIT_hipOccupancyMaxActiveBlocksPerMultiprocessor_CB_ARGS_DATA( \ + cb_data) \ + { \ + cb_data.args.hipOccupancyMaxActiveBlocksPerMultiprocessor.numBlocks = \ + (int *)numBlocks; \ + cb_data.args.hipOccupancyMaxActiveBlocksPerMultiprocessor.f = \ + (const void *)f; \ + cb_data.args.hipOccupancyMaxActiveBlocksPerMultiprocessor.blockSize = \ + (int)blockSize; \ + cb_data.args.hipOccupancyMaxActiveBlocksPerMultiprocessor \ + .dynamicSMemSize = (size_t)dynamicSMemSize; \ + }; +// hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags[('int*', 'numBlocks'), +// ('const void*', 'f'), ('int', 'blockSize'), ('size_t', 'dynamicSMemSize'), +// ('unsigned int', 'flags')] +#define INIT_hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags_CB_ARGS_DATA( \ + cb_data) \ + { \ + cb_data.args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags \ + .numBlocks = (int *)numBlocks; \ + cb_data.args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags.f = \ + (const void *)f; \ + cb_data.args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags \ + .blockSize = (int)blockSize; \ + cb_data.args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags \ + .dynamicSMemSize = (size_t)dynamicSMemSize; \ + cb_data.args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags.flags = \ + (unsigned int)flags; \ + }; +// hipOccupancyMaxPotentialBlockSize[('int*', 'gridSize'), ('int*', +// 'blockSize'), ('const void*', 'f'), ('size_t', 'dynSharedMemPerBlk'), ('int', +// 'blockSizeLimit')] +#define INIT_hipOccupancyMaxPotentialBlockSize_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipOccupancyMaxPotentialBlockSize.gridSize = (int *)gridSize; \ + cb_data.args.hipOccupancyMaxPotentialBlockSize.blockSize = \ + (int *)blockSize; \ + cb_data.args.hipOccupancyMaxPotentialBlockSize.f = (const void *)f; \ + cb_data.args.hipOccupancyMaxPotentialBlockSize.dynSharedMemPerBlk = \ + (size_t)dynSharedMemPerBlk; \ + cb_data.args.hipOccupancyMaxPotentialBlockSize.blockSizeLimit = \ + (int)blockSizeLimit; \ + }; +// hipPeekAtLastError[] +#define INIT_hipPeekAtLastError_CB_ARGS_DATA(cb_data) {}; +// hipPointerGetAttribute[('void*', 'data'), ('hipPointer_attribute', +// 'attribute'), ('hipDeviceptr_t', 'ptr')] +#define INIT_hipPointerGetAttribute_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipPointerGetAttribute.data = (void *)data; \ + cb_data.args.hipPointerGetAttribute.attribute = \ + (hipPointer_attribute)attribute; \ + cb_data.args.hipPointerGetAttribute.ptr = (hipDeviceptr_t)ptr; \ + }; +// hipPointerGetAttributes[('hipPointerAttribute_t*', 'attributes'), ('const +// void*', 'ptr')] +#define INIT_hipPointerGetAttributes_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipPointerGetAttributes.attributes = \ + (hipPointerAttribute_t *)attributes; \ + cb_data.args.hipPointerGetAttributes.ptr = (const void *)ptr; \ + }; +// hipPointerSetAttribute[('const void*', 'value'), ('hipPointer_attribute', +// 'attribute'), ('hipDeviceptr_t', 'ptr')] +#define INIT_hipPointerSetAttribute_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipPointerSetAttribute.value = (const void *)value; \ + cb_data.args.hipPointerSetAttribute.attribute = \ + (hipPointer_attribute)attribute; \ + cb_data.args.hipPointerSetAttribute.ptr = (hipDeviceptr_t)ptr; \ + }; +// hipProfilerStart[] +#define INIT_hipProfilerStart_CB_ARGS_DATA(cb_data) {}; +// hipProfilerStop[] +#define INIT_hipProfilerStop_CB_ARGS_DATA(cb_data) {}; +// hipRuntimeGetVersion[('int*', 'runtimeVersion')] +#define INIT_hipRuntimeGetVersion_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipRuntimeGetVersion.runtimeVersion = (int *)runtimeVersion; \ + }; +// hipSetDevice[('int', 'deviceId')] +#define INIT_hipSetDevice_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipSetDevice.deviceId = (int)device; \ + }; +// hipSetDeviceFlags[('unsigned int', 'flags')] +#define INIT_hipSetDeviceFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipSetDeviceFlags.flags = (unsigned int)flags; \ + }; +// hipSetValidDevices[('int*', 'device_arr'), ('int', 'len')] +#define INIT_hipSetValidDevices_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipSetValidDevices.device_arr = (int *)device_arr; \ + cb_data.args.hipSetValidDevices.len = (int)len; \ + }; +// hipSetupArgument[('const void*', 'arg'), ('size_t', 'size'), ('size_t', +// 'offset')] +#define INIT_hipSetupArgument_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipSetupArgument.arg = (const void *)arg; \ + cb_data.args.hipSetupArgument.size = (size_t)size; \ + cb_data.args.hipSetupArgument.offset = (size_t)offset; \ + }; +// hipSignalExternalSemaphoresAsync[('const hipExternalSemaphore_t*', +// 'extSemArray'), ('const hipExternalSemaphoreSignalParams*', 'paramsArray'), +// ('unsigned int', 'numExtSems'), ('hipStream_t', 'stream')] +#define INIT_hipSignalExternalSemaphoresAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipSignalExternalSemaphoresAsync.extSemArray = \ + (const hipExternalSemaphore_t *)extSemArray; \ + cb_data.args.hipSignalExternalSemaphoresAsync.paramsArray = \ + (const hipExternalSemaphoreSignalParams *)paramsArray; \ + cb_data.args.hipSignalExternalSemaphoresAsync.numExtSems = \ + (unsigned int)numExtSems; \ + cb_data.args.hipSignalExternalSemaphoresAsync.stream = \ + (hipStream_t)stream; \ + }; +// hipStreamAddCallback[('hipStream_t', 'stream'), ('hipStreamCallback_t', +// 'callback'), ('void*', 'userData'), ('unsigned int', 'flags')] +#define INIT_hipStreamAddCallback_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamAddCallback.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamAddCallback.callback = \ + (hipStreamCallback_t)callback; \ + cb_data.args.hipStreamAddCallback.userData = (void *)userData; \ + cb_data.args.hipStreamAddCallback.flags = (unsigned int)flags; \ + }; +// hipStreamAttachMemAsync[('hipStream_t', 'stream'), ('void*', 'dev_ptr'), +// ('size_t', 'length'), ('unsigned int', 'flags')] +#define INIT_hipStreamAttachMemAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamAttachMemAsync.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamAttachMemAsync.dev_ptr = (void *)dev_ptr; \ + cb_data.args.hipStreamAttachMemAsync.length = (size_t)length; \ + cb_data.args.hipStreamAttachMemAsync.flags = (unsigned int)flags; \ + }; +// hipStreamBeginCapture[('hipStream_t', 'stream'), ('hipStreamCaptureMode', +// 'mode')] +#define INIT_hipStreamBeginCapture_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamBeginCapture.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamBeginCapture.mode = (hipStreamCaptureMode)mode; \ + }; +// hipStreamBeginCaptureToGraph[('hipStream_t', 'stream'), ('hipGraph_t', +// 'graph'), ('const hipGraphNode_t*', 'dependencies'), ('const +// hipGraphEdgeData*', 'dependencyData'), ('size_t', 'numDependencies'), +// ('hipStreamCaptureMode', 'mode')] +#define INIT_hipStreamBeginCaptureToGraph_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamBeginCaptureToGraph.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamBeginCaptureToGraph.graph = (hipGraph_t)graph; \ + cb_data.args.hipStreamBeginCaptureToGraph.dependencies = \ + (const hipGraphNode_t *)dependencies; \ + cb_data.args.hipStreamBeginCaptureToGraph.dependencyData = \ + (const hipGraphEdgeData *)dependencyData; \ + cb_data.args.hipStreamBeginCaptureToGraph.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipStreamBeginCaptureToGraph.mode = \ + (hipStreamCaptureMode)mode; \ + }; +// hipStreamCreate[('hipStream_t*', 'stream')] +#define INIT_hipStreamCreate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamCreate.stream = (hipStream_t *)stream; \ + }; +// hipStreamCreateWithFlags[('hipStream_t*', 'stream'), ('unsigned int', +// 'flags')] +#define INIT_hipStreamCreateWithFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamCreateWithFlags.stream = (hipStream_t *)stream; \ + cb_data.args.hipStreamCreateWithFlags.flags = (unsigned int)flags; \ + }; +// hipStreamCreateWithPriority[('hipStream_t*', 'stream'), ('unsigned int', +// 'flags'), ('int', 'priority')] +#define INIT_hipStreamCreateWithPriority_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamCreateWithPriority.stream = (hipStream_t *)stream; \ + cb_data.args.hipStreamCreateWithPriority.flags = (unsigned int)flags; \ + cb_data.args.hipStreamCreateWithPriority.priority = (int)priority; \ + }; +// hipStreamDestroy[('hipStream_t', 'stream')] +#define INIT_hipStreamDestroy_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamDestroy.stream = (hipStream_t)stream; \ + }; +// hipStreamEndCapture[('hipStream_t', 'stream'), ('hipGraph_t*', 'pGraph')] +#define INIT_hipStreamEndCapture_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamEndCapture.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamEndCapture.pGraph = (hipGraph_t *)pGraph; \ + }; +// hipStreamGetCaptureInfo[('hipStream_t', 'stream'), +// ('hipStreamCaptureStatus*', 'pCaptureStatus'), ('unsigned long long*', +// 'pId')] +#define INIT_hipStreamGetCaptureInfo_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamGetCaptureInfo.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamGetCaptureInfo.pCaptureStatus = \ + (hipStreamCaptureStatus *)pCaptureStatus; \ + cb_data.args.hipStreamGetCaptureInfo.pId = (unsigned long long *)pId; \ + }; +// hipStreamGetCaptureInfo_v2[('hipStream_t', 'stream'), +// ('hipStreamCaptureStatus*', 'captureStatus_out'), ('unsigned long long*', +// 'id_out'), ('hipGraph_t*', 'graph_out'), ('const hipGraphNode_t**', +// 'dependencies_out'), ('size_t*', 'numDependencies_out')] +#define INIT_hipStreamGetCaptureInfo_v2_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamGetCaptureInfo_v2.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamGetCaptureInfo_v2.captureStatus_out = \ + (hipStreamCaptureStatus *)captureStatus_out; \ + cb_data.args.hipStreamGetCaptureInfo_v2.id_out = \ + (unsigned long long *)id_out; \ + cb_data.args.hipStreamGetCaptureInfo_v2.graph_out = \ + (hipGraph_t *)graph_out; \ + cb_data.args.hipStreamGetCaptureInfo_v2.dependencies_out = \ + (const hipGraphNode_t **)dependencies_out; \ + cb_data.args.hipStreamGetCaptureInfo_v2.numDependencies_out = \ + (size_t *)numDependencies_out; \ + }; +// hipStreamGetDevice[('hipStream_t', 'stream'), ('hipDevice_t*', 'device')] +#define INIT_hipStreamGetDevice_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamGetDevice.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamGetDevice.device = (hipDevice_t *)device; \ + }; +// hipStreamGetFlags[('hipStream_t', 'stream'), ('unsigned int*', 'flags')] +#define INIT_hipStreamGetFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamGetFlags.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamGetFlags.flags = (unsigned int *)flags; \ + }; +// hipStreamGetPriority[('hipStream_t', 'stream'), ('int*', 'priority')] +#define INIT_hipStreamGetPriority_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamGetPriority.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamGetPriority.priority = (int *)priority; \ + }; +// hipStreamIsCapturing[('hipStream_t', 'stream'), ('hipStreamCaptureStatus*', +// 'pCaptureStatus')] +#define INIT_hipStreamIsCapturing_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamIsCapturing.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamIsCapturing.pCaptureStatus = \ + (hipStreamCaptureStatus *)pCaptureStatus; \ + }; +// hipStreamQuery[('hipStream_t', 'stream')] +#define INIT_hipStreamQuery_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamQuery.stream = (hipStream_t)stream; \ + }; +// hipStreamSynchronize[('hipStream_t', 'stream')] +#define INIT_hipStreamSynchronize_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamSynchronize.stream = (hipStream_t)stream; \ + }; +// hipStreamUpdateCaptureDependencies[('hipStream_t', 'stream'), +// ('hipGraphNode_t*', 'dependencies'), ('size_t', 'numDependencies'), +// ('unsigned int', 'flags')] +#define INIT_hipStreamUpdateCaptureDependencies_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamUpdateCaptureDependencies.stream = \ + (hipStream_t)stream; \ + cb_data.args.hipStreamUpdateCaptureDependencies.dependencies = \ + (hipGraphNode_t *)dependencies; \ + cb_data.args.hipStreamUpdateCaptureDependencies.numDependencies = \ + (size_t)numDependencies; \ + cb_data.args.hipStreamUpdateCaptureDependencies.flags = \ + (unsigned int)flags; \ + }; +// hipStreamWaitEvent[('hipStream_t', 'stream'), ('hipEvent_t', 'event'), +// ('unsigned int', 'flags')] +#define INIT_hipStreamWaitEvent_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamWaitEvent.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamWaitEvent.event = (hipEvent_t)event; \ + cb_data.args.hipStreamWaitEvent.flags = (unsigned int)flags; \ + }; +// hipStreamWaitValue32[('hipStream_t', 'stream'), ('void*', 'ptr'), ('unsigned +// int', 'value'), ('unsigned int', 'flags'), ('unsigned int', 'mask')] +#define INIT_hipStreamWaitValue32_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamWaitValue32.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamWaitValue32.ptr = (void *)ptr; \ + cb_data.args.hipStreamWaitValue32.value = (unsigned int)value; \ + cb_data.args.hipStreamWaitValue32.flags = (unsigned int)flags; \ + cb_data.args.hipStreamWaitValue32.mask = (unsigned int)mask; \ + }; +// hipStreamWaitValue64[('hipStream_t', 'stream'), ('void*', 'ptr'), +// ('uint64_t', 'value'), ('unsigned int', 'flags'), ('uint64_t', 'mask')] +#define INIT_hipStreamWaitValue64_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamWaitValue64.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamWaitValue64.ptr = (void *)ptr; \ + cb_data.args.hipStreamWaitValue64.value = (uint64_t)value; \ + cb_data.args.hipStreamWaitValue64.flags = (unsigned int)flags; \ + cb_data.args.hipStreamWaitValue64.mask = (uint64_t)mask; \ + }; +// hipStreamWriteValue32[('hipStream_t', 'stream'), ('void*', 'ptr'), ('unsigned +// int', 'value'), ('unsigned int', 'flags')] +#define INIT_hipStreamWriteValue32_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamWriteValue32.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamWriteValue32.ptr = (void *)ptr; \ + cb_data.args.hipStreamWriteValue32.value = (unsigned int)value; \ + cb_data.args.hipStreamWriteValue32.flags = (unsigned int)flags; \ + }; +// hipStreamWriteValue64[('hipStream_t', 'stream'), ('void*', 'ptr'), +// ('uint64_t', 'value'), ('unsigned int', 'flags')] +#define INIT_hipStreamWriteValue64_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipStreamWriteValue64.stream = (hipStream_t)stream; \ + cb_data.args.hipStreamWriteValue64.ptr = (void *)ptr; \ + cb_data.args.hipStreamWriteValue64.value = (uint64_t)value; \ + cb_data.args.hipStreamWriteValue64.flags = (unsigned int)flags; \ + }; +// hipTexRefGetAddress[('hipDeviceptr_t*', 'dev_ptr'), ('const +// textureReference*', 'texRef')] +#define INIT_hipTexRefGetAddress_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefGetAddress.dev_ptr = (hipDeviceptr_t *)dptr; \ + cb_data.args.hipTexRefGetAddress.texRef = \ + (const textureReference *)texRef; \ + }; +// hipTexRefGetArray[('hipArray_t*', 'pArray'), ('const textureReference*', +// 'texRef')] +#define INIT_hipTexRefGetArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefGetArray.pArray = (hipArray_t *)pArray; \ + cb_data.args.hipTexRefGetArray.texRef = (const textureReference *)texRef; \ + }; +// hipTexRefGetBorderColor[('float*', 'pBorderColor'), ('const +// textureReference*', 'texRef')] +#define INIT_hipTexRefGetBorderColor_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefGetBorderColor.pBorderColor = (float *)pBorderColor; \ + cb_data.args.hipTexRefGetBorderColor.texRef = \ + (const textureReference *)texRef; \ + }; +// hipTexRefGetFlags[('unsigned int*', 'pFlags'), ('const textureReference*', +// 'texRef')] +#define INIT_hipTexRefGetFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefGetFlags.pFlags = (unsigned int *)pFlags; \ + cb_data.args.hipTexRefGetFlags.texRef = (const textureReference *)texRef; \ + }; +// hipTexRefGetFormat[('hipArray_Format*', 'pFormat'), ('int*', 'pNumChannels'), +// ('const textureReference*', 'texRef')] +#define INIT_hipTexRefGetFormat_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefGetFormat.pFormat = (hipArray_Format *)pFormat; \ + cb_data.args.hipTexRefGetFormat.pNumChannels = (int *)pNumChannels; \ + cb_data.args.hipTexRefGetFormat.texRef = (const textureReference *)texRef; \ + }; +// hipTexRefGetMaxAnisotropy[('int*', 'pmaxAnsio'), ('const textureReference*', +// 'texRef')] +#define INIT_hipTexRefGetMaxAnisotropy_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefGetMaxAnisotropy.pmaxAnsio = (int *)pmaxAnsio; \ + cb_data.args.hipTexRefGetMaxAnisotropy.texRef = \ + (const textureReference *)texRef; \ + }; +// hipTexRefGetMipMappedArray[('hipMipmappedArray_t*', 'pArray'), ('const +// textureReference*', 'texRef')] +#define INIT_hipTexRefGetMipMappedArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefGetMipMappedArray.pArray = \ + (hipMipmappedArray_t *)pArray; \ + cb_data.args.hipTexRefGetMipMappedArray.texRef = \ + (const textureReference *)texRef; \ + }; +// hipTexRefGetMipmapLevelBias[('float*', 'pbias'), ('const textureReference*', +// 'texRef')] +#define INIT_hipTexRefGetMipmapLevelBias_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefGetMipmapLevelBias.pbias = (float *)pbias; \ + cb_data.args.hipTexRefGetMipmapLevelBias.texRef = \ + (const textureReference *)texRef; \ + }; +// hipTexRefGetMipmapLevelClamp[('float*', 'pminMipmapLevelClamp'), ('float*', +// 'pmaxMipmapLevelClamp'), ('const textureReference*', 'texRef')] +#define INIT_hipTexRefGetMipmapLevelClamp_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefGetMipmapLevelClamp.pminMipmapLevelClamp = \ + (float *)pminMipmapLevelClamp; \ + cb_data.args.hipTexRefGetMipmapLevelClamp.pmaxMipmapLevelClamp = \ + (float *)pmaxMipmapLevelClamp; \ + cb_data.args.hipTexRefGetMipmapLevelClamp.texRef = \ + (const textureReference *)texRef; \ + }; +// hipTexRefSetAddress[('size_t*', 'ByteOffset'), ('textureReference*', +// 'texRef'), ('hipDeviceptr_t', 'dptr'), ('size_t', 'bytes')] +#define INIT_hipTexRefSetAddress_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefSetAddress.ByteOffset = (size_t *)ByteOffset; \ + cb_data.args.hipTexRefSetAddress.texRef = (textureReference *)texRef; \ + cb_data.args.hipTexRefSetAddress.dptr = (hipDeviceptr_t)dptr; \ + cb_data.args.hipTexRefSetAddress.bytes = (size_t)bytes; \ + }; +// hipTexRefSetAddress2D[('textureReference*', 'texRef'), ('const +// HIP_ARRAY_DESCRIPTOR*', 'desc'), ('hipDeviceptr_t', 'dptr'), ('size_t', +// 'Pitch')] +#define INIT_hipTexRefSetAddress2D_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefSetAddress2D.texRef = (textureReference *)texRef; \ + cb_data.args.hipTexRefSetAddress2D.desc = \ + (const HIP_ARRAY_DESCRIPTOR *)desc; \ + cb_data.args.hipTexRefSetAddress2D.dptr = (hipDeviceptr_t)dptr; \ + cb_data.args.hipTexRefSetAddress2D.Pitch = (size_t)Pitch; \ + }; +// hipTexRefSetArray[('textureReference*', 'tex'), ('hipArray_const_t', +// 'array'), ('unsigned int', 'flags')] +#define INIT_hipTexRefSetArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefSetArray.tex = (textureReference *)texRef; \ + cb_data.args.hipTexRefSetArray.array = (hipArray_const_t)array; \ + cb_data.args.hipTexRefSetArray.flags = (unsigned int)flags; \ + }; +// hipTexRefSetBorderColor[('textureReference*', 'texRef'), ('float*', +// 'pBorderColor')] +#define INIT_hipTexRefSetBorderColor_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefSetBorderColor.texRef = (textureReference *)texRef; \ + cb_data.args.hipTexRefSetBorderColor.pBorderColor = (float *)pBorderColor; \ + }; +// hipTexRefSetFlags[('textureReference*', 'texRef'), ('unsigned int', 'Flags')] +#define INIT_hipTexRefSetFlags_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefSetFlags.texRef = (textureReference *)texRef; \ + cb_data.args.hipTexRefSetFlags.Flags = (unsigned int)Flags; \ + }; +// hipTexRefSetFormat[('textureReference*', 'texRef'), ('hipArray_Format', +// 'fmt'), ('int', 'NumPackedComponents')] +#define INIT_hipTexRefSetFormat_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefSetFormat.texRef = (textureReference *)texRef; \ + cb_data.args.hipTexRefSetFormat.fmt = (hipArray_Format)fmt; \ + cb_data.args.hipTexRefSetFormat.NumPackedComponents = \ + (int)NumPackedComponents; \ + }; +// hipTexRefSetMaxAnisotropy[('textureReference*', 'texRef'), ('unsigned int', +// 'maxAniso')] +#define INIT_hipTexRefSetMaxAnisotropy_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefSetMaxAnisotropy.texRef = \ + (textureReference *)texRef; \ + cb_data.args.hipTexRefSetMaxAnisotropy.maxAniso = (unsigned int)maxAniso; \ + }; +// hipTexRefSetMipmapLevelBias[('textureReference*', 'texRef'), ('float', +// 'bias')] +#define INIT_hipTexRefSetMipmapLevelBias_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefSetMipmapLevelBias.texRef = \ + (textureReference *)texRef; \ + cb_data.args.hipTexRefSetMipmapLevelBias.bias = (float)bias; \ + }; +// hipTexRefSetMipmapLevelClamp[('textureReference*', 'texRef'), ('float', +// 'minMipMapLevelClamp'), ('float', 'maxMipMapLevelClamp')] +#define INIT_hipTexRefSetMipmapLevelClamp_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefSetMipmapLevelClamp.texRef = \ + (textureReference *)texRef; \ + cb_data.args.hipTexRefSetMipmapLevelClamp.minMipMapLevelClamp = \ + (float)minMipMapLevelClamp; \ + cb_data.args.hipTexRefSetMipmapLevelClamp.maxMipMapLevelClamp = \ + (float)maxMipMapLevelClamp; \ + }; +// hipTexRefSetMipmappedArray[('textureReference*', 'texRef'), +// ('hipMipmappedArray*', 'mipmappedArray'), ('unsigned int', 'Flags')] +#define INIT_hipTexRefSetMipmappedArray_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipTexRefSetMipmappedArray.texRef = \ + (textureReference *)texRef; \ + cb_data.args.hipTexRefSetMipmappedArray.mipmappedArray = \ + (hipMipmappedArray *)mipmappedArray; \ + cb_data.args.hipTexRefSetMipmappedArray.Flags = (unsigned int)Flags; \ + }; +// hipThreadExchangeStreamCaptureMode[('hipStreamCaptureMode*', 'mode')] +#define INIT_hipThreadExchangeStreamCaptureMode_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipThreadExchangeStreamCaptureMode.mode = \ + (hipStreamCaptureMode *)mode; \ + }; +// hipUserObjectCreate[('hipUserObject_t*', 'object_out'), ('void*', 'ptr'), +// ('hipHostFn_t', 'destroy'), ('unsigned int', 'initialRefcount'), ('unsigned +// int', 'flags')] +#define INIT_hipUserObjectCreate_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipUserObjectCreate.object_out = \ + (hipUserObject_t *)object_out; \ + cb_data.args.hipUserObjectCreate.ptr = (void *)ptr; \ + cb_data.args.hipUserObjectCreate.destroy = (hipHostFn_t)destroy; \ + cb_data.args.hipUserObjectCreate.initialRefcount = \ + (unsigned int)initialRefcount; \ + cb_data.args.hipUserObjectCreate.flags = (unsigned int)flags; \ + }; +// hipUserObjectRelease[('hipUserObject_t', 'object'), ('unsigned int', +// 'count')] +#define INIT_hipUserObjectRelease_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipUserObjectRelease.object = (hipUserObject_t)object; \ + cb_data.args.hipUserObjectRelease.count = (unsigned int)count; \ + }; +// hipUserObjectRetain[('hipUserObject_t', 'object'), ('unsigned int', 'count')] +#define INIT_hipUserObjectRetain_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipUserObjectRetain.object = (hipUserObject_t)object; \ + cb_data.args.hipUserObjectRetain.count = (unsigned int)count; \ + }; +// hipWaitExternalSemaphoresAsync[('const hipExternalSemaphore_t*', +// 'extSemArray'), ('const hipExternalSemaphoreWaitParams*', 'paramsArray'), +// ('unsigned int', 'numExtSems'), ('hipStream_t', 'stream')] +#define INIT_hipWaitExternalSemaphoresAsync_CB_ARGS_DATA(cb_data) \ + { \ + cb_data.args.hipWaitExternalSemaphoresAsync.extSemArray = \ + (const hipExternalSemaphore_t *)extSemArray; \ + cb_data.args.hipWaitExternalSemaphoresAsync.paramsArray = \ + (const hipExternalSemaphoreWaitParams *)paramsArray; \ + cb_data.args.hipWaitExternalSemaphoresAsync.numExtSems = \ + (unsigned int)numExtSems; \ + cb_data.args.hipWaitExternalSemaphoresAsync.stream = (hipStream_t)stream; \ + }; +#define INIT_CB_ARGS_DATA(cb_id, cb_data) INIT_##cb_id##_CB_ARGS_DATA(cb_data) + +// Macros for non-public API primitives +// hipBindTexture() +#define INIT_hipBindTexture_CB_ARGS_DATA(cb_data) {}; +// hipBindTexture2D() +#define INIT_hipBindTexture2D_CB_ARGS_DATA(cb_data) {}; +// hipBindTextureToArray() +#define INIT_hipBindTextureToArray_CB_ARGS_DATA(cb_data) {}; +// hipBindTextureToMipmappedArray() +#define INIT_hipBindTextureToMipmappedArray_CB_ARGS_DATA(cb_data) {}; +// hipCreateTextureObject() +#define INIT_hipCreateTextureObject_CB_ARGS_DATA(cb_data) {}; +// hipDestroyTextureObject() +#define INIT_hipDestroyTextureObject_CB_ARGS_DATA(cb_data) {}; +// hipDeviceGetCount() +#define INIT_hipDeviceGetCount_CB_ARGS_DATA(cb_data) {}; +// hipGetTextureAlignmentOffset() +#define INIT_hipGetTextureAlignmentOffset_CB_ARGS_DATA(cb_data) {}; +// hipGetTextureObjectResourceDesc() +#define INIT_hipGetTextureObjectResourceDesc_CB_ARGS_DATA(cb_data) {}; +// hipGetTextureObjectResourceViewDesc() +#define INIT_hipGetTextureObjectResourceViewDesc_CB_ARGS_DATA(cb_data) {}; +// hipGetTextureObjectTextureDesc() +#define INIT_hipGetTextureObjectTextureDesc_CB_ARGS_DATA(cb_data) {}; +// hipGetTextureReference() +#define INIT_hipGetTextureReference_CB_ARGS_DATA(cb_data) {}; +// hipTexObjectCreate() +#define INIT_hipTexObjectCreate_CB_ARGS_DATA(cb_data) {}; +// hipTexObjectDestroy() +#define INIT_hipTexObjectDestroy_CB_ARGS_DATA(cb_data) {}; +// hipTexObjectGetResourceDesc() +#define INIT_hipTexObjectGetResourceDesc_CB_ARGS_DATA(cb_data) {}; +// hipTexObjectGetResourceViewDesc() +#define INIT_hipTexObjectGetResourceViewDesc_CB_ARGS_DATA(cb_data) {}; +// hipTexObjectGetTextureDesc() +#define INIT_hipTexObjectGetTextureDesc_CB_ARGS_DATA(cb_data) {}; +// hipTexRefGetAddressMode() +#define INIT_hipTexRefGetAddressMode_CB_ARGS_DATA(cb_data) {}; +// hipTexRefGetFilterMode() +#define INIT_hipTexRefGetFilterMode_CB_ARGS_DATA(cb_data) {}; +// hipTexRefGetMipmapFilterMode() +#define INIT_hipTexRefGetMipmapFilterMode_CB_ARGS_DATA(cb_data) {}; +// hipTexRefSetAddressMode() +#define INIT_hipTexRefSetAddressMode_CB_ARGS_DATA(cb_data) {}; +// hipTexRefSetFilterMode() +#define INIT_hipTexRefSetFilterMode_CB_ARGS_DATA(cb_data) {}; +// hipTexRefSetMipmapFilterMode() +#define INIT_hipTexRefSetMipmapFilterMode_CB_ARGS_DATA(cb_data) {}; +// hipUnbindTexture() +#define INIT_hipUnbindTexture_CB_ARGS_DATA(cb_data) {}; + +#define INIT_NONE_CB_ARGS_DATA(cb_data) {}; + +#if HIP_PROF_HIP_API_STRING +// HIP API args filling helper +static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t *data) { + switch (id) { + // __hipPopCallConfiguration[('dim3*', 'gridDim'), ('dim3*', 'blockDim'), + // ('size_t*', 'sharedMem'), ('hipStream_t*', 'stream')] + case HIP_API_ID___hipPopCallConfiguration: + if (data->args.__hipPopCallConfiguration.gridDim) + data->args.__hipPopCallConfiguration.gridDim__val = + *(data->args.__hipPopCallConfiguration.gridDim); + if (data->args.__hipPopCallConfiguration.blockDim) + data->args.__hipPopCallConfiguration.blockDim__val = + *(data->args.__hipPopCallConfiguration.blockDim); + if (data->args.__hipPopCallConfiguration.sharedMem) + data->args.__hipPopCallConfiguration.sharedMem__val = + *(data->args.__hipPopCallConfiguration.sharedMem); + if (data->args.__hipPopCallConfiguration.stream) + data->args.__hipPopCallConfiguration.stream__val = + *(data->args.__hipPopCallConfiguration.stream); + break; + // __hipPushCallConfiguration[('dim3', 'gridDim'), ('dim3', 'blockDim'), + // ('size_t', 'sharedMem'), ('hipStream_t', 'stream')] + case HIP_API_ID___hipPushCallConfiguration: + break; + // hipArray3DCreate[('hipArray_t*', 'array'), ('const + // HIP_ARRAY3D_DESCRIPTOR*', 'pAllocateArray')] + case HIP_API_ID_hipArray3DCreate: + if (data->args.hipArray3DCreate.array) + data->args.hipArray3DCreate.array__val = + *(data->args.hipArray3DCreate.array); + if (data->args.hipArray3DCreate.pAllocateArray) + data->args.hipArray3DCreate.pAllocateArray__val = + *(data->args.hipArray3DCreate.pAllocateArray); + break; + // hipArray3DGetDescriptor[('HIP_ARRAY3D_DESCRIPTOR*', 'pArrayDescriptor'), + // ('hipArray_t', 'array')] + case HIP_API_ID_hipArray3DGetDescriptor: + if (data->args.hipArray3DGetDescriptor.pArrayDescriptor) + data->args.hipArray3DGetDescriptor.pArrayDescriptor__val = + *(data->args.hipArray3DGetDescriptor.pArrayDescriptor); + break; + // hipArrayCreate[('hipArray_t*', 'pHandle'), ('const + // HIP_ARRAY_DESCRIPTOR*', 'pAllocateArray')] + case HIP_API_ID_hipArrayCreate: + if (data->args.hipArrayCreate.pHandle) + data->args.hipArrayCreate.pHandle__val = + *(data->args.hipArrayCreate.pHandle); + if (data->args.hipArrayCreate.pAllocateArray) + data->args.hipArrayCreate.pAllocateArray__val = + *(data->args.hipArrayCreate.pAllocateArray); + break; + // hipArrayDestroy[('hipArray_t', 'array')] + case HIP_API_ID_hipArrayDestroy: + break; + // hipArrayGetDescriptor[('HIP_ARRAY_DESCRIPTOR*', 'pArrayDescriptor'), + // ('hipArray_t', 'array')] + case HIP_API_ID_hipArrayGetDescriptor: + if (data->args.hipArrayGetDescriptor.pArrayDescriptor) + data->args.hipArrayGetDescriptor.pArrayDescriptor__val = + *(data->args.hipArrayGetDescriptor.pArrayDescriptor); + break; + // hipArrayGetInfo[('hipChannelFormatDesc*', 'desc'), ('hipExtent*', + // 'extent'), ('unsigned int*', 'flags'), ('hipArray_t', 'array')] + case HIP_API_ID_hipArrayGetInfo: + if (data->args.hipArrayGetInfo.desc) + data->args.hipArrayGetInfo.desc__val = *(data->args.hipArrayGetInfo.desc); + if (data->args.hipArrayGetInfo.extent) + data->args.hipArrayGetInfo.extent__val = + *(data->args.hipArrayGetInfo.extent); + if (data->args.hipArrayGetInfo.flags) + data->args.hipArrayGetInfo.flags__val = + *(data->args.hipArrayGetInfo.flags); + break; + // hipChooseDeviceR0000[('int*', 'device'), ('const hipDeviceProp_tR0000*', + // 'prop')] + case HIP_API_ID_hipChooseDeviceR0000: + if (data->args.hipChooseDeviceR0000.device) + data->args.hipChooseDeviceR0000.device__val = + *(data->args.hipChooseDeviceR0000.device); + if (data->args.hipChooseDeviceR0000.prop) + data->args.hipChooseDeviceR0000.prop__val = + *(data->args.hipChooseDeviceR0000.prop); + break; + // hipChooseDeviceR0600[('int*', 'device'), ('const hipDeviceProp_tR0600*', + // 'prop')] + case HIP_API_ID_hipChooseDeviceR0600: + if (data->args.hipChooseDeviceR0600.device) + data->args.hipChooseDeviceR0600.device__val = + *(data->args.hipChooseDeviceR0600.device); + if (data->args.hipChooseDeviceR0600.prop) + data->args.hipChooseDeviceR0600.prop__val = + *(data->args.hipChooseDeviceR0600.prop); + break; + // hipConfigureCall[('dim3', 'gridDim'), ('dim3', 'blockDim'), ('size_t', + // 'sharedMem'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipConfigureCall: + break; + // hipCreateSurfaceObject[('hipSurfaceObject_t*', 'pSurfObject'), ('const + // hipResourceDesc*', 'pResDesc')] + case HIP_API_ID_hipCreateSurfaceObject: + if (data->args.hipCreateSurfaceObject.pSurfObject) + data->args.hipCreateSurfaceObject.pSurfObject__val = + *(data->args.hipCreateSurfaceObject.pSurfObject); + if (data->args.hipCreateSurfaceObject.pResDesc) + data->args.hipCreateSurfaceObject.pResDesc__val = + *(data->args.hipCreateSurfaceObject.pResDesc); + break; + // hipCtxCreate[('hipCtx_t*', 'ctx'), ('unsigned int', 'flags'), + // ('hipDevice_t', 'device')] + case HIP_API_ID_hipCtxCreate: + if (data->args.hipCtxCreate.ctx) + data->args.hipCtxCreate.ctx__val = *(data->args.hipCtxCreate.ctx); + break; + // hipCtxDestroy[('hipCtx_t', 'ctx')] + case HIP_API_ID_hipCtxDestroy: + break; + // hipCtxDisablePeerAccess[('hipCtx_t', 'peerCtx')] + case HIP_API_ID_hipCtxDisablePeerAccess: + break; + // hipCtxEnablePeerAccess[('hipCtx_t', 'peerCtx'), ('unsigned int', + // 'flags')] + case HIP_API_ID_hipCtxEnablePeerAccess: + break; + // hipCtxGetApiVersion[('hipCtx_t', 'ctx'), ('int*', 'apiVersion')] + case HIP_API_ID_hipCtxGetApiVersion: + if (data->args.hipCtxGetApiVersion.apiVersion) + data->args.hipCtxGetApiVersion.apiVersion__val = + *(data->args.hipCtxGetApiVersion.apiVersion); + break; + // hipCtxGetCacheConfig[('hipFuncCache_t*', 'cacheConfig')] + case HIP_API_ID_hipCtxGetCacheConfig: + if (data->args.hipCtxGetCacheConfig.cacheConfig) + data->args.hipCtxGetCacheConfig.cacheConfig__val = + *(data->args.hipCtxGetCacheConfig.cacheConfig); + break; + // hipCtxGetCurrent[('hipCtx_t*', 'ctx')] + case HIP_API_ID_hipCtxGetCurrent: + if (data->args.hipCtxGetCurrent.ctx) + data->args.hipCtxGetCurrent.ctx__val = *(data->args.hipCtxGetCurrent.ctx); + break; + // hipCtxGetDevice[('hipDevice_t*', 'device')] + case HIP_API_ID_hipCtxGetDevice: + if (data->args.hipCtxGetDevice.device) + data->args.hipCtxGetDevice.device__val = + *(data->args.hipCtxGetDevice.device); + break; + // hipCtxGetFlags[('unsigned int*', 'flags')] + case HIP_API_ID_hipCtxGetFlags: + if (data->args.hipCtxGetFlags.flags) + data->args.hipCtxGetFlags.flags__val = *(data->args.hipCtxGetFlags.flags); + break; + // hipCtxGetSharedMemConfig[('hipSharedMemConfig*', 'pConfig')] + case HIP_API_ID_hipCtxGetSharedMemConfig: + if (data->args.hipCtxGetSharedMemConfig.pConfig) + data->args.hipCtxGetSharedMemConfig.pConfig__val = + *(data->args.hipCtxGetSharedMemConfig.pConfig); + break; + // hipCtxPopCurrent[('hipCtx_t*', 'ctx')] + case HIP_API_ID_hipCtxPopCurrent: + if (data->args.hipCtxPopCurrent.ctx) + data->args.hipCtxPopCurrent.ctx__val = *(data->args.hipCtxPopCurrent.ctx); + break; + // hipCtxPushCurrent[('hipCtx_t', 'ctx')] + case HIP_API_ID_hipCtxPushCurrent: + break; + // hipCtxSetCacheConfig[('hipFuncCache_t', 'cacheConfig')] + case HIP_API_ID_hipCtxSetCacheConfig: + break; + // hipCtxSetCurrent[('hipCtx_t', 'ctx')] + case HIP_API_ID_hipCtxSetCurrent: + break; + // hipCtxSetSharedMemConfig[('hipSharedMemConfig', 'config')] + case HIP_API_ID_hipCtxSetSharedMemConfig: + break; + // hipCtxSynchronize[] + case HIP_API_ID_hipCtxSynchronize: + break; + // hipDestroyExternalMemory[('hipExternalMemory_t', 'extMem')] + case HIP_API_ID_hipDestroyExternalMemory: + break; + // hipDestroyExternalSemaphore[('hipExternalSemaphore_t', 'extSem')] + case HIP_API_ID_hipDestroyExternalSemaphore: + break; + // hipDestroySurfaceObject[('hipSurfaceObject_t', 'surfaceObject')] + case HIP_API_ID_hipDestroySurfaceObject: + break; + // hipDeviceCanAccessPeer[('int*', 'canAccessPeer'), ('int', 'deviceId'), + // ('int', 'peerDeviceId')] + case HIP_API_ID_hipDeviceCanAccessPeer: + if (data->args.hipDeviceCanAccessPeer.canAccessPeer) + data->args.hipDeviceCanAccessPeer.canAccessPeer__val = + *(data->args.hipDeviceCanAccessPeer.canAccessPeer); + break; + // hipDeviceComputeCapability[('int*', 'major'), ('int*', 'minor'), + // ('hipDevice_t', 'device')] + case HIP_API_ID_hipDeviceComputeCapability: + if (data->args.hipDeviceComputeCapability.major) + data->args.hipDeviceComputeCapability.major__val = + *(data->args.hipDeviceComputeCapability.major); + if (data->args.hipDeviceComputeCapability.minor) + data->args.hipDeviceComputeCapability.minor__val = + *(data->args.hipDeviceComputeCapability.minor); + break; + // hipDeviceDisablePeerAccess[('int', 'peerDeviceId')] + case HIP_API_ID_hipDeviceDisablePeerAccess: + break; + // hipDeviceEnablePeerAccess[('int', 'peerDeviceId'), ('unsigned int', + // 'flags')] + case HIP_API_ID_hipDeviceEnablePeerAccess: + break; + // hipDeviceGet[('hipDevice_t*', 'device'), ('int', 'ordinal')] + case HIP_API_ID_hipDeviceGet: + if (data->args.hipDeviceGet.device) + data->args.hipDeviceGet.device__val = *(data->args.hipDeviceGet.device); + break; + // hipDeviceGetAttribute[('int*', 'pi'), ('hipDeviceAttribute_t', 'attr'), + // ('int', 'deviceId')] + case HIP_API_ID_hipDeviceGetAttribute: + if (data->args.hipDeviceGetAttribute.pi) + data->args.hipDeviceGetAttribute.pi__val = + *(data->args.hipDeviceGetAttribute.pi); + break; + // hipDeviceGetByPCIBusId[('int*', 'device'), ('const char*', 'pciBusId')] + case HIP_API_ID_hipDeviceGetByPCIBusId: + if (data->args.hipDeviceGetByPCIBusId.device) + data->args.hipDeviceGetByPCIBusId.device__val = + *(data->args.hipDeviceGetByPCIBusId.device); + if (data->args.hipDeviceGetByPCIBusId.pciBusId) + data->args.hipDeviceGetByPCIBusId.pciBusId__val = + *(data->args.hipDeviceGetByPCIBusId.pciBusId); + break; + // hipDeviceGetCacheConfig[('hipFuncCache_t*', 'cacheConfig')] + case HIP_API_ID_hipDeviceGetCacheConfig: + if (data->args.hipDeviceGetCacheConfig.cacheConfig) + data->args.hipDeviceGetCacheConfig.cacheConfig__val = + *(data->args.hipDeviceGetCacheConfig.cacheConfig); + break; + // hipDeviceGetDefaultMemPool[('hipMemPool_t*', 'mem_pool'), ('int', + // 'device')] + case HIP_API_ID_hipDeviceGetDefaultMemPool: + if (data->args.hipDeviceGetDefaultMemPool.mem_pool) + data->args.hipDeviceGetDefaultMemPool.mem_pool__val = + *(data->args.hipDeviceGetDefaultMemPool.mem_pool); + break; + // hipDeviceGetGraphMemAttribute[('int', 'device'), + // ('hipGraphMemAttributeType', 'attr'), ('void*', 'value')] + case HIP_API_ID_hipDeviceGetGraphMemAttribute: + break; + // hipDeviceGetLimit[('size_t*', 'pValue'), ('hipLimit_t', 'limit')] + case HIP_API_ID_hipDeviceGetLimit: + if (data->args.hipDeviceGetLimit.pValue) + data->args.hipDeviceGetLimit.pValue__val = + *(data->args.hipDeviceGetLimit.pValue); + break; + // hipDeviceGetMemPool[('hipMemPool_t*', 'mem_pool'), ('int', 'device')] + case HIP_API_ID_hipDeviceGetMemPool: + if (data->args.hipDeviceGetMemPool.mem_pool) + data->args.hipDeviceGetMemPool.mem_pool__val = + *(data->args.hipDeviceGetMemPool.mem_pool); + break; + // hipDeviceGetName[('char*', 'name'), ('int', 'len'), ('hipDevice_t', + // 'device')] + case HIP_API_ID_hipDeviceGetName: + data->args.hipDeviceGetName.name = + (data->args.hipDeviceGetName.name) + ? strdup(data->args.hipDeviceGetName.name) + : NULL; + break; + // hipDeviceGetP2PAttribute[('int*', 'value'), ('hipDeviceP2PAttr', 'attr'), + // ('int', 'srcDevice'), ('int', 'dstDevice')] + case HIP_API_ID_hipDeviceGetP2PAttribute: + if (data->args.hipDeviceGetP2PAttribute.value) + data->args.hipDeviceGetP2PAttribute.value__val = + *(data->args.hipDeviceGetP2PAttribute.value); + break; + // hipDeviceGetPCIBusId[('char*', 'pciBusId'), ('int', 'len'), ('int', + // 'device')] + case HIP_API_ID_hipDeviceGetPCIBusId: + data->args.hipDeviceGetPCIBusId.pciBusId = + (data->args.hipDeviceGetPCIBusId.pciBusId) + ? strdup(data->args.hipDeviceGetPCIBusId.pciBusId) + : NULL; + break; + // hipDeviceGetSharedMemConfig[('hipSharedMemConfig*', 'pConfig')] + case HIP_API_ID_hipDeviceGetSharedMemConfig: + if (data->args.hipDeviceGetSharedMemConfig.pConfig) + data->args.hipDeviceGetSharedMemConfig.pConfig__val = + *(data->args.hipDeviceGetSharedMemConfig.pConfig); + break; + // hipDeviceGetStreamPriorityRange[('int*', 'leastPriority'), ('int*', + // 'greatestPriority')] + case HIP_API_ID_hipDeviceGetStreamPriorityRange: + if (data->args.hipDeviceGetStreamPriorityRange.leastPriority) + data->args.hipDeviceGetStreamPriorityRange.leastPriority__val = + *(data->args.hipDeviceGetStreamPriorityRange.leastPriority); + if (data->args.hipDeviceGetStreamPriorityRange.greatestPriority) + data->args.hipDeviceGetStreamPriorityRange.greatestPriority__val = + *(data->args.hipDeviceGetStreamPriorityRange.greatestPriority); + break; + // hipDeviceGetUuid[('hipUUID*', 'uuid'), ('hipDevice_t', 'device')] + case HIP_API_ID_hipDeviceGetUuid: + if (data->args.hipDeviceGetUuid.uuid) + data->args.hipDeviceGetUuid.uuid__val = + *(data->args.hipDeviceGetUuid.uuid); + break; + // hipDeviceGraphMemTrim[('int', 'device')] + case HIP_API_ID_hipDeviceGraphMemTrim: + break; + // hipDevicePrimaryCtxGetState[('hipDevice_t', 'dev'), ('unsigned int*', + // 'flags'), ('int*', 'active')] + case HIP_API_ID_hipDevicePrimaryCtxGetState: + if (data->args.hipDevicePrimaryCtxGetState.flags) + data->args.hipDevicePrimaryCtxGetState.flags__val = + *(data->args.hipDevicePrimaryCtxGetState.flags); + if (data->args.hipDevicePrimaryCtxGetState.active) + data->args.hipDevicePrimaryCtxGetState.active__val = + *(data->args.hipDevicePrimaryCtxGetState.active); + break; + // hipDevicePrimaryCtxRelease[('hipDevice_t', 'dev')] + case HIP_API_ID_hipDevicePrimaryCtxRelease: + break; + // hipDevicePrimaryCtxReset[('hipDevice_t', 'dev')] + case HIP_API_ID_hipDevicePrimaryCtxReset: + break; + // hipDevicePrimaryCtxRetain[('hipCtx_t*', 'pctx'), ('hipDevice_t', 'dev')] + case HIP_API_ID_hipDevicePrimaryCtxRetain: + if (data->args.hipDevicePrimaryCtxRetain.pctx) + data->args.hipDevicePrimaryCtxRetain.pctx__val = + *(data->args.hipDevicePrimaryCtxRetain.pctx); + break; + // hipDevicePrimaryCtxSetFlags[('hipDevice_t', 'dev'), ('unsigned int', + // 'flags')] + case HIP_API_ID_hipDevicePrimaryCtxSetFlags: + break; + // hipDeviceReset[] + case HIP_API_ID_hipDeviceReset: + break; + // hipDeviceSetCacheConfig[('hipFuncCache_t', 'cacheConfig')] + case HIP_API_ID_hipDeviceSetCacheConfig: + break; + // hipDeviceSetGraphMemAttribute[('int', 'device'), + // ('hipGraphMemAttributeType', 'attr'), ('void*', 'value')] + case HIP_API_ID_hipDeviceSetGraphMemAttribute: + break; + // hipDeviceSetLimit[('hipLimit_t', 'limit'), ('size_t', 'value')] + case HIP_API_ID_hipDeviceSetLimit: + break; + // hipDeviceSetMemPool[('int', 'device'), ('hipMemPool_t', 'mem_pool')] + case HIP_API_ID_hipDeviceSetMemPool: + break; + // hipDeviceSetSharedMemConfig[('hipSharedMemConfig', 'config')] + case HIP_API_ID_hipDeviceSetSharedMemConfig: + break; + // hipDeviceSynchronize[] + case HIP_API_ID_hipDeviceSynchronize: + break; + // hipDeviceTotalMem[('size_t*', 'bytes'), ('hipDevice_t', 'device')] + case HIP_API_ID_hipDeviceTotalMem: + if (data->args.hipDeviceTotalMem.bytes) + data->args.hipDeviceTotalMem.bytes__val = + *(data->args.hipDeviceTotalMem.bytes); + break; + // hipDriverGetVersion[('int*', 'driverVersion')] + case HIP_API_ID_hipDriverGetVersion: + if (data->args.hipDriverGetVersion.driverVersion) + data->args.hipDriverGetVersion.driverVersion__val = + *(data->args.hipDriverGetVersion.driverVersion); + break; + // hipDrvGraphAddMemcpyNode[('hipGraphNode_t*', 'phGraphNode'), + // ('hipGraph_t', 'hGraph'), ('const hipGraphNode_t*', 'dependencies'), + // ('size_t', 'numDependencies'), ('const HIP_MEMCPY3D*', 'copyParams'), + // ('hipCtx_t', 'ctx')] + case HIP_API_ID_hipDrvGraphAddMemcpyNode: + if (data->args.hipDrvGraphAddMemcpyNode.phGraphNode) + data->args.hipDrvGraphAddMemcpyNode.phGraphNode__val = + *(data->args.hipDrvGraphAddMemcpyNode.phGraphNode); + if (data->args.hipDrvGraphAddMemcpyNode.dependencies) + data->args.hipDrvGraphAddMemcpyNode.dependencies__val = + *(data->args.hipDrvGraphAddMemcpyNode.dependencies); + if (data->args.hipDrvGraphAddMemcpyNode.copyParams) + data->args.hipDrvGraphAddMemcpyNode.copyParams__val = + *(data->args.hipDrvGraphAddMemcpyNode.copyParams); + break; + // hipDrvGraphAddMemsetNode[('hipGraphNode_t*', 'phGraphNode'), + // ('hipGraph_t', 'hGraph'), ('const hipGraphNode_t*', 'dependencies'), + // ('size_t', 'numDependencies'), ('const HIP_MEMSET_NODE_PARAMS*', + // 'memsetParams'), ('hipCtx_t', 'ctx')] + case HIP_API_ID_hipDrvGraphAddMemsetNode: + if (data->args.hipDrvGraphAddMemsetNode.phGraphNode) + data->args.hipDrvGraphAddMemsetNode.phGraphNode__val = + *(data->args.hipDrvGraphAddMemsetNode.phGraphNode); + if (data->args.hipDrvGraphAddMemsetNode.dependencies) + data->args.hipDrvGraphAddMemsetNode.dependencies__val = + *(data->args.hipDrvGraphAddMemsetNode.dependencies); + if (data->args.hipDrvGraphAddMemsetNode.memsetParams) + data->args.hipDrvGraphAddMemsetNode.memsetParams__val = + *(data->args.hipDrvGraphAddMemsetNode.memsetParams); + break; + // hipDrvMemcpy2DUnaligned[('const hip_Memcpy2D*', 'pCopy')] + case HIP_API_ID_hipDrvMemcpy2DUnaligned: + if (data->args.hipDrvMemcpy2DUnaligned.pCopy) + data->args.hipDrvMemcpy2DUnaligned.pCopy__val = + *(data->args.hipDrvMemcpy2DUnaligned.pCopy); + break; + // hipDrvMemcpy3D[('const HIP_MEMCPY3D*', 'pCopy')] + case HIP_API_ID_hipDrvMemcpy3D: + if (data->args.hipDrvMemcpy3D.pCopy) + data->args.hipDrvMemcpy3D.pCopy__val = *(data->args.hipDrvMemcpy3D.pCopy); + break; + // hipDrvMemcpy3DAsync[('const HIP_MEMCPY3D*', 'pCopy'), ('hipStream_t', + // 'stream')] + case HIP_API_ID_hipDrvMemcpy3DAsync: + if (data->args.hipDrvMemcpy3DAsync.pCopy) + data->args.hipDrvMemcpy3DAsync.pCopy__val = + *(data->args.hipDrvMemcpy3DAsync.pCopy); + break; + // hipDrvPointerGetAttributes[('unsigned int', 'numAttributes'), + // ('hipPointer_attribute*', 'attributes'), ('void**', 'data'), + // ('hipDeviceptr_t', 'ptr')] + case HIP_API_ID_hipDrvPointerGetAttributes: + if (data->args.hipDrvPointerGetAttributes.attributes) + data->args.hipDrvPointerGetAttributes.attributes__val = + *(data->args.hipDrvPointerGetAttributes.attributes); + if (data->args.hipDrvPointerGetAttributes.data) + data->args.hipDrvPointerGetAttributes.data__val = + *(data->args.hipDrvPointerGetAttributes.data); + break; + // hipEventCreate[('hipEvent_t*', 'event')] + case HIP_API_ID_hipEventCreate: + if (data->args.hipEventCreate.event) + data->args.hipEventCreate.event__val = *(data->args.hipEventCreate.event); + break; + // hipEventCreateWithFlags[('hipEvent_t*', 'event'), ('unsigned int', + // 'flags')] + case HIP_API_ID_hipEventCreateWithFlags: + if (data->args.hipEventCreateWithFlags.event) + data->args.hipEventCreateWithFlags.event__val = + *(data->args.hipEventCreateWithFlags.event); + break; + // hipEventDestroy[('hipEvent_t', 'event')] + case HIP_API_ID_hipEventDestroy: + break; + // hipEventElapsedTime[('float*', 'ms'), ('hipEvent_t', 'start'), + // ('hipEvent_t', 'stop')] + case HIP_API_ID_hipEventElapsedTime: + if (data->args.hipEventElapsedTime.ms) + data->args.hipEventElapsedTime.ms__val = + *(data->args.hipEventElapsedTime.ms); + break; + // hipEventQuery[('hipEvent_t', 'event')] + case HIP_API_ID_hipEventQuery: + break; + // hipEventRecord[('hipEvent_t', 'event'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipEventRecord: + break; + // hipEventSynchronize[('hipEvent_t', 'event')] + case HIP_API_ID_hipEventSynchronize: + break; + // hipExtGetLastError[] + case HIP_API_ID_hipExtGetLastError: + break; + // hipExtGetLinkTypeAndHopCount[('int', 'device1'), ('int', 'device2'), + // ('unsigned int*', 'linktype'), ('unsigned int*', 'hopcount')] + case HIP_API_ID_hipExtGetLinkTypeAndHopCount: + if (data->args.hipExtGetLinkTypeAndHopCount.linktype) + data->args.hipExtGetLinkTypeAndHopCount.linktype__val = + *(data->args.hipExtGetLinkTypeAndHopCount.linktype); + if (data->args.hipExtGetLinkTypeAndHopCount.hopcount) + data->args.hipExtGetLinkTypeAndHopCount.hopcount__val = + *(data->args.hipExtGetLinkTypeAndHopCount.hopcount); + break; + // hipExtLaunchKernel[('const void*', 'function_address'), ('dim3', + // 'numBlocks'), ('dim3', 'dimBlocks'), ('void**', 'args'), ('size_t', + // 'sharedMemBytes'), ('hipStream_t', 'stream'), ('hipEvent_t', + // 'startEvent'), ('hipEvent_t', 'stopEvent'), ('int', 'flags')] + case HIP_API_ID_hipExtLaunchKernel: + if (data->args.hipExtLaunchKernel.args) + data->args.hipExtLaunchKernel.args__val = + *(data->args.hipExtLaunchKernel.args); + break; + // hipExtLaunchMultiKernelMultiDevice[('hipLaunchParams*', + // 'launchParamsList'), ('int', 'numDevices'), ('unsigned int', 'flags')] + case HIP_API_ID_hipExtLaunchMultiKernelMultiDevice: + if (data->args.hipExtLaunchMultiKernelMultiDevice.launchParamsList) + data->args.hipExtLaunchMultiKernelMultiDevice.launchParamsList__val = + *(data->args.hipExtLaunchMultiKernelMultiDevice.launchParamsList); + break; + // hipExtMallocWithFlags[('void**', 'ptr'), ('size_t', 'sizeBytes'), + // ('unsigned int', 'flags')] + case HIP_API_ID_hipExtMallocWithFlags: + if (data->args.hipExtMallocWithFlags.ptr) + data->args.hipExtMallocWithFlags.ptr__val = + *(data->args.hipExtMallocWithFlags.ptr); + break; + // hipExtModuleLaunchKernel[('hipFunction_t', 'f'), ('unsigned int', + // 'globalWorkSizeX'), ('unsigned int', 'globalWorkSizeY'), ('unsigned int', + // 'globalWorkSizeZ'), ('unsigned int', 'localWorkSizeX'), ('unsigned int', + // 'localWorkSizeY'), ('unsigned int', 'localWorkSizeZ'), ('size_t', + // 'sharedMemBytes'), ('hipStream_t', 'hStream'), ('void**', + // 'kernelParams'), ('void**', 'extra'), ('hipEvent_t', 'startEvent'), + // ('hipEvent_t', 'stopEvent'), ('unsigned int', 'flags')] + case HIP_API_ID_hipExtModuleLaunchKernel: + if (data->args.hipExtModuleLaunchKernel.kernelParams) + data->args.hipExtModuleLaunchKernel.kernelParams__val = + *(data->args.hipExtModuleLaunchKernel.kernelParams); + if (data->args.hipExtModuleLaunchKernel.extra) + data->args.hipExtModuleLaunchKernel.extra__val = + *(data->args.hipExtModuleLaunchKernel.extra); + break; + // hipExtStreamCreateWithCUMask[('hipStream_t*', 'stream'), ('unsigned int', + // 'cuMaskSize'), ('const unsigned int*', 'cuMask')] + case HIP_API_ID_hipExtStreamCreateWithCUMask: + if (data->args.hipExtStreamCreateWithCUMask.stream) + data->args.hipExtStreamCreateWithCUMask.stream__val = + *(data->args.hipExtStreamCreateWithCUMask.stream); + if (data->args.hipExtStreamCreateWithCUMask.cuMask) + data->args.hipExtStreamCreateWithCUMask.cuMask__val = + *(data->args.hipExtStreamCreateWithCUMask.cuMask); + break; + // hipExtStreamGetCUMask[('hipStream_t', 'stream'), ('unsigned int', + // 'cuMaskSize'), ('unsigned int*', 'cuMask')] + case HIP_API_ID_hipExtStreamGetCUMask: + if (data->args.hipExtStreamGetCUMask.cuMask) + data->args.hipExtStreamGetCUMask.cuMask__val = + *(data->args.hipExtStreamGetCUMask.cuMask); + break; + // hipExternalMemoryGetMappedBuffer[('void**', 'devPtr'), + // ('hipExternalMemory_t', 'extMem'), ('const hipExternalMemoryBufferDesc*', + // 'bufferDesc')] + case HIP_API_ID_hipExternalMemoryGetMappedBuffer: + if (data->args.hipExternalMemoryGetMappedBuffer.devPtr) + data->args.hipExternalMemoryGetMappedBuffer.devPtr__val = + *(data->args.hipExternalMemoryGetMappedBuffer.devPtr); + if (data->args.hipExternalMemoryGetMappedBuffer.bufferDesc) + data->args.hipExternalMemoryGetMappedBuffer.bufferDesc__val = + *(data->args.hipExternalMemoryGetMappedBuffer.bufferDesc); + break; + // hipExternalMemoryGetMappedMipmappedArray[('hipMipmappedArray_t*', + // 'mipmap'), ('hipExternalMemory_t', 'extMem'), ('const + // hipExternalMemoryMipmappedArrayDesc*', 'mipmapDesc')] + case HIP_API_ID_hipExternalMemoryGetMappedMipmappedArray: + if (data->args.hipExternalMemoryGetMappedMipmappedArray.mipmap) + data->args.hipExternalMemoryGetMappedMipmappedArray.mipmap__val = + *(data->args.hipExternalMemoryGetMappedMipmappedArray.mipmap); + if (data->args.hipExternalMemoryGetMappedMipmappedArray.mipmapDesc) + data->args.hipExternalMemoryGetMappedMipmappedArray.mipmapDesc__val = + *(data->args.hipExternalMemoryGetMappedMipmappedArray.mipmapDesc); + break; + // hipFree[('void*', 'ptr')] + case HIP_API_ID_hipFree: + break; + // hipFreeArray[('hipArray_t', 'array')] + case HIP_API_ID_hipFreeArray: + break; + // hipFreeAsync[('void*', 'dev_ptr'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipFreeAsync: + break; + // hipFreeHost[('void*', 'ptr')] + case HIP_API_ID_hipFreeHost: + break; + // hipFreeMipmappedArray[('hipMipmappedArray_t', 'mipmappedArray')] + case HIP_API_ID_hipFreeMipmappedArray: + break; + // hipFuncGetAttribute[('int*', 'value'), ('hipFunction_attribute', + // 'attrib'), ('hipFunction_t', 'hfunc')] + case HIP_API_ID_hipFuncGetAttribute: + if (data->args.hipFuncGetAttribute.value) + data->args.hipFuncGetAttribute.value__val = + *(data->args.hipFuncGetAttribute.value); + break; + // hipFuncGetAttributes[('hipFuncAttributes*', 'attr'), ('const void*', + // 'func')] + case HIP_API_ID_hipFuncGetAttributes: + if (data->args.hipFuncGetAttributes.attr) + data->args.hipFuncGetAttributes.attr__val = + *(data->args.hipFuncGetAttributes.attr); + break; + // hipFuncSetAttribute[('const void*', 'func'), ('hipFuncAttribute', + // 'attr'), ('int', 'value')] + case HIP_API_ID_hipFuncSetAttribute: + break; + // hipFuncSetCacheConfig[('const void*', 'func'), ('hipFuncCache_t', + // 'config')] + case HIP_API_ID_hipFuncSetCacheConfig: + break; + // hipFuncSetSharedMemConfig[('const void*', 'func'), ('hipSharedMemConfig', + // 'config')] + case HIP_API_ID_hipFuncSetSharedMemConfig: + break; + // hipGLGetDevices[('unsigned int*', 'pHipDeviceCount'), ('int*', + // 'pHipDevices'), ('unsigned int', 'hipDeviceCount'), ('hipGLDeviceList', + // 'deviceList')] + case HIP_API_ID_hipGLGetDevices: + if (data->args.hipGLGetDevices.pHipDeviceCount) + data->args.hipGLGetDevices.pHipDeviceCount__val = + *(data->args.hipGLGetDevices.pHipDeviceCount); + if (data->args.hipGLGetDevices.pHipDevices) + data->args.hipGLGetDevices.pHipDevices__val = + *(data->args.hipGLGetDevices.pHipDevices); + break; + // hipGetChannelDesc[('hipChannelFormatDesc*', 'desc'), ('hipArray_const_t', + // 'array')] + case HIP_API_ID_hipGetChannelDesc: + if (data->args.hipGetChannelDesc.desc) + data->args.hipGetChannelDesc.desc__val = + *(data->args.hipGetChannelDesc.desc); + break; + // hipGetDevice[('int*', 'deviceId')] + case HIP_API_ID_hipGetDevice: + if (data->args.hipGetDevice.deviceId) + data->args.hipGetDevice.deviceId__val = + *(data->args.hipGetDevice.deviceId); + break; + // hipGetDeviceCount[('int*', 'count')] + case HIP_API_ID_hipGetDeviceCount: + if (data->args.hipGetDeviceCount.count) + data->args.hipGetDeviceCount.count__val = + *(data->args.hipGetDeviceCount.count); + break; + // hipGetDeviceFlags[('unsigned int*', 'flags')] + case HIP_API_ID_hipGetDeviceFlags: + if (data->args.hipGetDeviceFlags.flags) + data->args.hipGetDeviceFlags.flags__val = + *(data->args.hipGetDeviceFlags.flags); + break; + // hipGetDevicePropertiesR0000[('hipDeviceProp_tR0000*', 'prop'), ('int', + // 'device')] + case HIP_API_ID_hipGetDevicePropertiesR0000: + if (data->args.hipGetDevicePropertiesR0000.prop) + data->args.hipGetDevicePropertiesR0000.prop__val = + *(data->args.hipGetDevicePropertiesR0000.prop); + break; + // hipGetDevicePropertiesR0600[('hipDeviceProp_tR0600*', 'prop'), ('int', + // 'deviceId')] + case HIP_API_ID_hipGetDevicePropertiesR0600: + if (data->args.hipGetDevicePropertiesR0600.prop) + data->args.hipGetDevicePropertiesR0600.prop__val = + *(data->args.hipGetDevicePropertiesR0600.prop); + break; + // hipGetErrorString[] + case HIP_API_ID_hipGetErrorString: + break; + // hipGetFuncBySymbol[('hipFunction_t*', 'functionPtr'), ('const void*', + // 'symbolPtr')] + case HIP_API_ID_hipGetFuncBySymbol: + if (data->args.hipGetFuncBySymbol.functionPtr) + data->args.hipGetFuncBySymbol.functionPtr__val = + *(data->args.hipGetFuncBySymbol.functionPtr); + break; + // hipGetLastError[] + case HIP_API_ID_hipGetLastError: + break; + // hipGetMipmappedArrayLevel[('hipArray_t*', 'levelArray'), + // ('hipMipmappedArray_const_t', 'mipmappedArray'), ('unsigned int', + // 'level')] + case HIP_API_ID_hipGetMipmappedArrayLevel: + if (data->args.hipGetMipmappedArrayLevel.levelArray) + data->args.hipGetMipmappedArrayLevel.levelArray__val = + *(data->args.hipGetMipmappedArrayLevel.levelArray); + break; + // hipGetProcAddress[('const char*', 'symbol'), ('void**', 'pfn'), ('int', + // 'hipVersion'), ('uint64_t', 'flags'), + // ('hipDriverProcAddressQueryResult*', 'symbolStatus')] + case HIP_API_ID_hipGetProcAddress: + if (data->args.hipGetProcAddress.symbol) + data->args.hipGetProcAddress.symbol__val = + *(data->args.hipGetProcAddress.symbol); + if (data->args.hipGetProcAddress.pfn) + data->args.hipGetProcAddress.pfn__val = + *(data->args.hipGetProcAddress.pfn); + if (data->args.hipGetProcAddress.symbolStatus) + data->args.hipGetProcAddress.symbolStatus__val = + *(data->args.hipGetProcAddress.symbolStatus); + break; + // hipGetSymbolAddress[('void**', 'devPtr'), ('const void*', 'symbol')] + case HIP_API_ID_hipGetSymbolAddress: + if (data->args.hipGetSymbolAddress.devPtr) + data->args.hipGetSymbolAddress.devPtr__val = + *(data->args.hipGetSymbolAddress.devPtr); + break; + // hipGetSymbolSize[('size_t*', 'size'), ('const void*', 'symbol')] + case HIP_API_ID_hipGetSymbolSize: + if (data->args.hipGetSymbolSize.size) + data->args.hipGetSymbolSize.size__val = + *(data->args.hipGetSymbolSize.size); + break; + // hipGraphAddChildGraphNode[('hipGraphNode_t*', 'pGraphNode'), + // ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), + // ('size_t', 'numDependencies'), ('hipGraph_t', 'childGraph')] + case HIP_API_ID_hipGraphAddChildGraphNode: + if (data->args.hipGraphAddChildGraphNode.pGraphNode) + data->args.hipGraphAddChildGraphNode.pGraphNode__val = + *(data->args.hipGraphAddChildGraphNode.pGraphNode); + if (data->args.hipGraphAddChildGraphNode.pDependencies) + data->args.hipGraphAddChildGraphNode.pDependencies__val = + *(data->args.hipGraphAddChildGraphNode.pDependencies); + break; + // hipGraphAddDependencies[('hipGraph_t', 'graph'), ('const + // hipGraphNode_t*', 'from'), ('const hipGraphNode_t*', 'to'), ('size_t', + // 'numDependencies')] + case HIP_API_ID_hipGraphAddDependencies: + if (data->args.hipGraphAddDependencies.from) + data->args.hipGraphAddDependencies.from__val = + *(data->args.hipGraphAddDependencies.from); + if (data->args.hipGraphAddDependencies.to) + data->args.hipGraphAddDependencies.to__val = + *(data->args.hipGraphAddDependencies.to); + break; + // hipGraphAddEmptyNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', + // 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', + // 'numDependencies')] + case HIP_API_ID_hipGraphAddEmptyNode: + if (data->args.hipGraphAddEmptyNode.pGraphNode) + data->args.hipGraphAddEmptyNode.pGraphNode__val = + *(data->args.hipGraphAddEmptyNode.pGraphNode); + if (data->args.hipGraphAddEmptyNode.pDependencies) + data->args.hipGraphAddEmptyNode.pDependencies__val = + *(data->args.hipGraphAddEmptyNode.pDependencies); + break; + // hipGraphAddEventRecordNode[('hipGraphNode_t*', 'pGraphNode'), + // ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), + // ('size_t', 'numDependencies'), ('hipEvent_t', 'event')] + case HIP_API_ID_hipGraphAddEventRecordNode: + if (data->args.hipGraphAddEventRecordNode.pGraphNode) + data->args.hipGraphAddEventRecordNode.pGraphNode__val = + *(data->args.hipGraphAddEventRecordNode.pGraphNode); + if (data->args.hipGraphAddEventRecordNode.pDependencies) + data->args.hipGraphAddEventRecordNode.pDependencies__val = + *(data->args.hipGraphAddEventRecordNode.pDependencies); + break; + // hipGraphAddEventWaitNode[('hipGraphNode_t*', 'pGraphNode'), + // ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), + // ('size_t', 'numDependencies'), ('hipEvent_t', 'event')] + case HIP_API_ID_hipGraphAddEventWaitNode: + if (data->args.hipGraphAddEventWaitNode.pGraphNode) + data->args.hipGraphAddEventWaitNode.pGraphNode__val = + *(data->args.hipGraphAddEventWaitNode.pGraphNode); + if (data->args.hipGraphAddEventWaitNode.pDependencies) + data->args.hipGraphAddEventWaitNode.pDependencies__val = + *(data->args.hipGraphAddEventWaitNode.pDependencies); + break; + // hipGraphAddExternalSemaphoresSignalNode[('hipGraphNode_t*', + // 'pGraphNode'), ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', + // 'pDependencies'), ('size_t', 'numDependencies'), ('const + // hipExternalSemaphoreSignalNodeParams*', 'nodeParams')] + case HIP_API_ID_hipGraphAddExternalSemaphoresSignalNode: + if (data->args.hipGraphAddExternalSemaphoresSignalNode.pGraphNode) + data->args.hipGraphAddExternalSemaphoresSignalNode.pGraphNode__val = + *(data->args.hipGraphAddExternalSemaphoresSignalNode.pGraphNode); + if (data->args.hipGraphAddExternalSemaphoresSignalNode.pDependencies) + data->args.hipGraphAddExternalSemaphoresSignalNode.pDependencies__val = + *(data->args.hipGraphAddExternalSemaphoresSignalNode.pDependencies); + if (data->args.hipGraphAddExternalSemaphoresSignalNode.nodeParams) + data->args.hipGraphAddExternalSemaphoresSignalNode.nodeParams__val = + *(data->args.hipGraphAddExternalSemaphoresSignalNode.nodeParams); + break; + // hipGraphAddExternalSemaphoresWaitNode[('hipGraphNode_t*', 'pGraphNode'), + // ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), + // ('size_t', 'numDependencies'), ('const + // hipExternalSemaphoreWaitNodeParams*', 'nodeParams')] + case HIP_API_ID_hipGraphAddExternalSemaphoresWaitNode: + if (data->args.hipGraphAddExternalSemaphoresWaitNode.pGraphNode) + data->args.hipGraphAddExternalSemaphoresWaitNode.pGraphNode__val = + *(data->args.hipGraphAddExternalSemaphoresWaitNode.pGraphNode); + if (data->args.hipGraphAddExternalSemaphoresWaitNode.pDependencies) + data->args.hipGraphAddExternalSemaphoresWaitNode.pDependencies__val = + *(data->args.hipGraphAddExternalSemaphoresWaitNode.pDependencies); + if (data->args.hipGraphAddExternalSemaphoresWaitNode.nodeParams) + data->args.hipGraphAddExternalSemaphoresWaitNode.nodeParams__val = + *(data->args.hipGraphAddExternalSemaphoresWaitNode.nodeParams); + break; + // hipGraphAddHostNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', + // 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', + // 'numDependencies'), ('const hipHostNodeParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphAddHostNode: + if (data->args.hipGraphAddHostNode.pGraphNode) + data->args.hipGraphAddHostNode.pGraphNode__val = + *(data->args.hipGraphAddHostNode.pGraphNode); + if (data->args.hipGraphAddHostNode.pDependencies) + data->args.hipGraphAddHostNode.pDependencies__val = + *(data->args.hipGraphAddHostNode.pDependencies); + if (data->args.hipGraphAddHostNode.pNodeParams) + data->args.hipGraphAddHostNode.pNodeParams__val = + *(data->args.hipGraphAddHostNode.pNodeParams); + break; + // hipGraphAddKernelNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', + // 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', + // 'numDependencies'), ('const hipKernelNodeParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphAddKernelNode: + if (data->args.hipGraphAddKernelNode.pGraphNode) + data->args.hipGraphAddKernelNode.pGraphNode__val = + *(data->args.hipGraphAddKernelNode.pGraphNode); + if (data->args.hipGraphAddKernelNode.pDependencies) + data->args.hipGraphAddKernelNode.pDependencies__val = + *(data->args.hipGraphAddKernelNode.pDependencies); + if (data->args.hipGraphAddKernelNode.pNodeParams) + data->args.hipGraphAddKernelNode.pNodeParams__val = + *(data->args.hipGraphAddKernelNode.pNodeParams); + break; + // hipGraphAddMemAllocNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', + // 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', + // 'numDependencies'), ('hipMemAllocNodeParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphAddMemAllocNode: + if (data->args.hipGraphAddMemAllocNode.pGraphNode) + data->args.hipGraphAddMemAllocNode.pGraphNode__val = + *(data->args.hipGraphAddMemAllocNode.pGraphNode); + if (data->args.hipGraphAddMemAllocNode.pDependencies) + data->args.hipGraphAddMemAllocNode.pDependencies__val = + *(data->args.hipGraphAddMemAllocNode.pDependencies); + if (data->args.hipGraphAddMemAllocNode.pNodeParams) + data->args.hipGraphAddMemAllocNode.pNodeParams__val = + *(data->args.hipGraphAddMemAllocNode.pNodeParams); + break; + // hipGraphAddMemFreeNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', + // 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', + // 'numDependencies'), ('void*', 'dev_ptr')] + case HIP_API_ID_hipGraphAddMemFreeNode: + if (data->args.hipGraphAddMemFreeNode.pGraphNode) + data->args.hipGraphAddMemFreeNode.pGraphNode__val = + *(data->args.hipGraphAddMemFreeNode.pGraphNode); + if (data->args.hipGraphAddMemFreeNode.pDependencies) + data->args.hipGraphAddMemFreeNode.pDependencies__val = + *(data->args.hipGraphAddMemFreeNode.pDependencies); + break; + // hipGraphAddMemcpyNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', + // 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', + // 'numDependencies'), ('const hipMemcpy3DParms*', 'pCopyParams')] + case HIP_API_ID_hipGraphAddMemcpyNode: + if (data->args.hipGraphAddMemcpyNode.pGraphNode) + data->args.hipGraphAddMemcpyNode.pGraphNode__val = + *(data->args.hipGraphAddMemcpyNode.pGraphNode); + if (data->args.hipGraphAddMemcpyNode.pDependencies) + data->args.hipGraphAddMemcpyNode.pDependencies__val = + *(data->args.hipGraphAddMemcpyNode.pDependencies); + if (data->args.hipGraphAddMemcpyNode.pCopyParams) + data->args.hipGraphAddMemcpyNode.pCopyParams__val = + *(data->args.hipGraphAddMemcpyNode.pCopyParams); + break; + // hipGraphAddMemcpyNode1D[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', + // 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', + // 'numDependencies'), ('void*', 'dst'), ('const void*', 'src'), ('size_t', + // 'count'), ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipGraphAddMemcpyNode1D: + if (data->args.hipGraphAddMemcpyNode1D.pGraphNode) + data->args.hipGraphAddMemcpyNode1D.pGraphNode__val = + *(data->args.hipGraphAddMemcpyNode1D.pGraphNode); + if (data->args.hipGraphAddMemcpyNode1D.pDependencies) + data->args.hipGraphAddMemcpyNode1D.pDependencies__val = + *(data->args.hipGraphAddMemcpyNode1D.pDependencies); + break; + // hipGraphAddMemcpyNodeFromSymbol[('hipGraphNode_t*', 'pGraphNode'), + // ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), + // ('size_t', 'numDependencies'), ('void*', 'dst'), ('const void*', + // 'symbol'), ('size_t', 'count'), ('size_t', 'offset'), ('hipMemcpyKind', + // 'kind')] + case HIP_API_ID_hipGraphAddMemcpyNodeFromSymbol: + if (data->args.hipGraphAddMemcpyNodeFromSymbol.pGraphNode) + data->args.hipGraphAddMemcpyNodeFromSymbol.pGraphNode__val = + *(data->args.hipGraphAddMemcpyNodeFromSymbol.pGraphNode); + if (data->args.hipGraphAddMemcpyNodeFromSymbol.pDependencies) + data->args.hipGraphAddMemcpyNodeFromSymbol.pDependencies__val = + *(data->args.hipGraphAddMemcpyNodeFromSymbol.pDependencies); + break; + // hipGraphAddMemcpyNodeToSymbol[('hipGraphNode_t*', 'pGraphNode'), + // ('hipGraph_t', 'graph'), ('const hipGraphNode_t*', 'pDependencies'), + // ('size_t', 'numDependencies'), ('const void*', 'symbol'), ('const void*', + // 'src'), ('size_t', 'count'), ('size_t', 'offset'), ('hipMemcpyKind', + // 'kind')] + case HIP_API_ID_hipGraphAddMemcpyNodeToSymbol: + if (data->args.hipGraphAddMemcpyNodeToSymbol.pGraphNode) + data->args.hipGraphAddMemcpyNodeToSymbol.pGraphNode__val = + *(data->args.hipGraphAddMemcpyNodeToSymbol.pGraphNode); + if (data->args.hipGraphAddMemcpyNodeToSymbol.pDependencies) + data->args.hipGraphAddMemcpyNodeToSymbol.pDependencies__val = + *(data->args.hipGraphAddMemcpyNodeToSymbol.pDependencies); + break; + // hipGraphAddMemsetNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', + // 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', + // 'numDependencies'), ('const hipMemsetParams*', 'pMemsetParams')] + case HIP_API_ID_hipGraphAddMemsetNode: + if (data->args.hipGraphAddMemsetNode.pGraphNode) + data->args.hipGraphAddMemsetNode.pGraphNode__val = + *(data->args.hipGraphAddMemsetNode.pGraphNode); + if (data->args.hipGraphAddMemsetNode.pDependencies) + data->args.hipGraphAddMemsetNode.pDependencies__val = + *(data->args.hipGraphAddMemsetNode.pDependencies); + if (data->args.hipGraphAddMemsetNode.pMemsetParams) + data->args.hipGraphAddMemsetNode.pMemsetParams__val = + *(data->args.hipGraphAddMemsetNode.pMemsetParams); + break; + // hipGraphAddNode[('hipGraphNode_t*', 'pGraphNode'), ('hipGraph_t', + // 'graph'), ('const hipGraphNode_t*', 'pDependencies'), ('size_t', + // 'numDependencies'), ('hipGraphNodeParams*', 'nodeParams')] + case HIP_API_ID_hipGraphAddNode: + if (data->args.hipGraphAddNode.pGraphNode) + data->args.hipGraphAddNode.pGraphNode__val = + *(data->args.hipGraphAddNode.pGraphNode); + if (data->args.hipGraphAddNode.pDependencies) + data->args.hipGraphAddNode.pDependencies__val = + *(data->args.hipGraphAddNode.pDependencies); + if (data->args.hipGraphAddNode.nodeParams) + data->args.hipGraphAddNode.nodeParams__val = + *(data->args.hipGraphAddNode.nodeParams); + break; + // hipGraphChildGraphNodeGetGraph[('hipGraphNode_t', 'node'), + // ('hipGraph_t*', 'pGraph')] + case HIP_API_ID_hipGraphChildGraphNodeGetGraph: + if (data->args.hipGraphChildGraphNodeGetGraph.pGraph) + data->args.hipGraphChildGraphNodeGetGraph.pGraph__val = + *(data->args.hipGraphChildGraphNodeGetGraph.pGraph); + break; + // hipGraphClone[('hipGraph_t*', 'pGraphClone'), ('hipGraph_t', + // 'originalGraph')] + case HIP_API_ID_hipGraphClone: + if (data->args.hipGraphClone.pGraphClone) + data->args.hipGraphClone.pGraphClone__val = + *(data->args.hipGraphClone.pGraphClone); + break; + // hipGraphCreate[('hipGraph_t*', 'pGraph'), ('unsigned int', 'flags')] + case HIP_API_ID_hipGraphCreate: + if (data->args.hipGraphCreate.pGraph) + data->args.hipGraphCreate.pGraph__val = + *(data->args.hipGraphCreate.pGraph); + break; + // hipGraphDebugDotPrint[('hipGraph_t', 'graph'), ('const char*', 'path'), + // ('unsigned int', 'flags')] + case HIP_API_ID_hipGraphDebugDotPrint: + if (data->args.hipGraphDebugDotPrint.path) + data->args.hipGraphDebugDotPrint.path__val = + *(data->args.hipGraphDebugDotPrint.path); + break; + // hipGraphDestroy[('hipGraph_t', 'graph')] + case HIP_API_ID_hipGraphDestroy: + break; + // hipGraphDestroyNode[('hipGraphNode_t', 'node')] + case HIP_API_ID_hipGraphDestroyNode: + break; + // hipGraphEventRecordNodeGetEvent[('hipGraphNode_t', 'node'), + // ('hipEvent_t*', 'event_out')] + case HIP_API_ID_hipGraphEventRecordNodeGetEvent: + if (data->args.hipGraphEventRecordNodeGetEvent.event_out) + data->args.hipGraphEventRecordNodeGetEvent.event_out__val = + *(data->args.hipGraphEventRecordNodeGetEvent.event_out); + break; + // hipGraphEventRecordNodeSetEvent[('hipGraphNode_t', 'node'), + // ('hipEvent_t', 'event')] + case HIP_API_ID_hipGraphEventRecordNodeSetEvent: + break; + // hipGraphEventWaitNodeGetEvent[('hipGraphNode_t', 'node'), ('hipEvent_t*', + // 'event_out')] + case HIP_API_ID_hipGraphEventWaitNodeGetEvent: + if (data->args.hipGraphEventWaitNodeGetEvent.event_out) + data->args.hipGraphEventWaitNodeGetEvent.event_out__val = + *(data->args.hipGraphEventWaitNodeGetEvent.event_out); + break; + // hipGraphEventWaitNodeSetEvent[('hipGraphNode_t', 'node'), ('hipEvent_t', + // 'event')] + case HIP_API_ID_hipGraphEventWaitNodeSetEvent: + break; + // hipGraphExecChildGraphNodeSetParams[('hipGraphExec_t', 'hGraphExec'), + // ('hipGraphNode_t', 'node'), ('hipGraph_t', 'childGraph')] + case HIP_API_ID_hipGraphExecChildGraphNodeSetParams: + break; + // hipGraphExecDestroy[('hipGraphExec_t', 'graphExec')] + case HIP_API_ID_hipGraphExecDestroy: + break; + // hipGraphExecEventRecordNodeSetEvent[('hipGraphExec_t', 'hGraphExec'), + // ('hipGraphNode_t', 'hNode'), ('hipEvent_t', 'event')] + case HIP_API_ID_hipGraphExecEventRecordNodeSetEvent: + break; + // hipGraphExecEventWaitNodeSetEvent[('hipGraphExec_t', 'hGraphExec'), + // ('hipGraphNode_t', 'hNode'), ('hipEvent_t', 'event')] + case HIP_API_ID_hipGraphExecEventWaitNodeSetEvent: + break; + // hipGraphExecExternalSemaphoresSignalNodeSetParams[('hipGraphExec_t', + // 'hGraphExec'), ('hipGraphNode_t', 'hNode'), ('const + // hipExternalSemaphoreSignalNodeParams*', 'nodeParams')] + case HIP_API_ID_hipGraphExecExternalSemaphoresSignalNodeSetParams: + if (data->args.hipGraphExecExternalSemaphoresSignalNodeSetParams.nodeParams) + data->args.hipGraphExecExternalSemaphoresSignalNodeSetParams + .nodeParams__val = + *(data->args.hipGraphExecExternalSemaphoresSignalNodeSetParams + .nodeParams); + break; + // hipGraphExecExternalSemaphoresWaitNodeSetParams[('hipGraphExec_t', + // 'hGraphExec'), ('hipGraphNode_t', 'hNode'), ('const + // hipExternalSemaphoreWaitNodeParams*', 'nodeParams')] + case HIP_API_ID_hipGraphExecExternalSemaphoresWaitNodeSetParams: + if (data->args.hipGraphExecExternalSemaphoresWaitNodeSetParams.nodeParams) + data->args.hipGraphExecExternalSemaphoresWaitNodeSetParams + .nodeParams__val = + *(data->args.hipGraphExecExternalSemaphoresWaitNodeSetParams + .nodeParams); + break; + // hipGraphExecHostNodeSetParams[('hipGraphExec_t', 'hGraphExec'), + // ('hipGraphNode_t', 'node'), ('const hipHostNodeParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphExecHostNodeSetParams: + if (data->args.hipGraphExecHostNodeSetParams.pNodeParams) + data->args.hipGraphExecHostNodeSetParams.pNodeParams__val = + *(data->args.hipGraphExecHostNodeSetParams.pNodeParams); + break; + // hipGraphExecKernelNodeSetParams[('hipGraphExec_t', 'hGraphExec'), + // ('hipGraphNode_t', 'node'), ('const hipKernelNodeParams*', + // 'pNodeParams')] + case HIP_API_ID_hipGraphExecKernelNodeSetParams: + if (data->args.hipGraphExecKernelNodeSetParams.pNodeParams) + data->args.hipGraphExecKernelNodeSetParams.pNodeParams__val = + *(data->args.hipGraphExecKernelNodeSetParams.pNodeParams); + break; + // hipGraphExecMemcpyNodeSetParams[('hipGraphExec_t', 'hGraphExec'), + // ('hipGraphNode_t', 'node'), ('hipMemcpy3DParms*', 'pNodeParams')] + case HIP_API_ID_hipGraphExecMemcpyNodeSetParams: + if (data->args.hipGraphExecMemcpyNodeSetParams.pNodeParams) + data->args.hipGraphExecMemcpyNodeSetParams.pNodeParams__val = + *(data->args.hipGraphExecMemcpyNodeSetParams.pNodeParams); + break; + // hipGraphExecMemcpyNodeSetParams1D[('hipGraphExec_t', 'hGraphExec'), + // ('hipGraphNode_t', 'node'), ('void*', 'dst'), ('const void*', 'src'), + // ('size_t', 'count'), ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipGraphExecMemcpyNodeSetParams1D: + break; + // hipGraphExecMemcpyNodeSetParamsFromSymbol[('hipGraphExec_t', + // 'hGraphExec'), ('hipGraphNode_t', 'node'), ('void*', 'dst'), ('const + // void*', 'symbol'), ('size_t', 'count'), ('size_t', 'offset'), + // ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipGraphExecMemcpyNodeSetParamsFromSymbol: + break; + // hipGraphExecMemcpyNodeSetParamsToSymbol[('hipGraphExec_t', 'hGraphExec'), + // ('hipGraphNode_t', 'node'), ('const void*', 'symbol'), ('const void*', + // 'src'), ('size_t', 'count'), ('size_t', 'offset'), ('hipMemcpyKind', + // 'kind')] + case HIP_API_ID_hipGraphExecMemcpyNodeSetParamsToSymbol: + break; + // hipGraphExecMemsetNodeSetParams[('hipGraphExec_t', 'hGraphExec'), + // ('hipGraphNode_t', 'node'), ('const hipMemsetParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphExecMemsetNodeSetParams: + if (data->args.hipGraphExecMemsetNodeSetParams.pNodeParams) + data->args.hipGraphExecMemsetNodeSetParams.pNodeParams__val = + *(data->args.hipGraphExecMemsetNodeSetParams.pNodeParams); + break; + // hipGraphExecUpdate[('hipGraphExec_t', 'hGraphExec'), ('hipGraph_t', + // 'hGraph'), ('hipGraphNode_t*', 'hErrorNode_out'), + // ('hipGraphExecUpdateResult*', 'updateResult_out')] + case HIP_API_ID_hipGraphExecUpdate: + if (data->args.hipGraphExecUpdate.hErrorNode_out) + data->args.hipGraphExecUpdate.hErrorNode_out__val = + *(data->args.hipGraphExecUpdate.hErrorNode_out); + if (data->args.hipGraphExecUpdate.updateResult_out) + data->args.hipGraphExecUpdate.updateResult_out__val = + *(data->args.hipGraphExecUpdate.updateResult_out); + break; + // hipGraphExternalSemaphoresSignalNodeGetParams[('hipGraphNode_t', + // 'hNode'), ('hipExternalSemaphoreSignalNodeParams*', 'params_out')] + case HIP_API_ID_hipGraphExternalSemaphoresSignalNodeGetParams: + if (data->args.hipGraphExternalSemaphoresSignalNodeGetParams.params_out) + data->args.hipGraphExternalSemaphoresSignalNodeGetParams.params_out__val = + *(data->args.hipGraphExternalSemaphoresSignalNodeGetParams + .params_out); + break; + // hipGraphExternalSemaphoresSignalNodeSetParams[('hipGraphNode_t', + // 'hNode'), ('const hipExternalSemaphoreSignalNodeParams*', 'nodeParams')] + case HIP_API_ID_hipGraphExternalSemaphoresSignalNodeSetParams: + if (data->args.hipGraphExternalSemaphoresSignalNodeSetParams.nodeParams) + data->args.hipGraphExternalSemaphoresSignalNodeSetParams.nodeParams__val = + *(data->args.hipGraphExternalSemaphoresSignalNodeSetParams + .nodeParams); + break; + // hipGraphExternalSemaphoresWaitNodeGetParams[('hipGraphNode_t', 'hNode'), + // ('hipExternalSemaphoreWaitNodeParams*', 'params_out')] + case HIP_API_ID_hipGraphExternalSemaphoresWaitNodeGetParams: + if (data->args.hipGraphExternalSemaphoresWaitNodeGetParams.params_out) + data->args.hipGraphExternalSemaphoresWaitNodeGetParams.params_out__val = + *(data->args.hipGraphExternalSemaphoresWaitNodeGetParams.params_out); + break; + // hipGraphExternalSemaphoresWaitNodeSetParams[('hipGraphNode_t', 'hNode'), + // ('const hipExternalSemaphoreWaitNodeParams*', 'nodeParams')] + case HIP_API_ID_hipGraphExternalSemaphoresWaitNodeSetParams: + if (data->args.hipGraphExternalSemaphoresWaitNodeSetParams.nodeParams) + data->args.hipGraphExternalSemaphoresWaitNodeSetParams.nodeParams__val = + *(data->args.hipGraphExternalSemaphoresWaitNodeSetParams.nodeParams); + break; + // hipGraphGetEdges[('hipGraph_t', 'graph'), ('hipGraphNode_t*', 'from'), + // ('hipGraphNode_t*', 'to'), ('size_t*', 'numEdges')] + case HIP_API_ID_hipGraphGetEdges: + if (data->args.hipGraphGetEdges.from) + data->args.hipGraphGetEdges.from__val = + *(data->args.hipGraphGetEdges.from); + if (data->args.hipGraphGetEdges.to) + data->args.hipGraphGetEdges.to__val = *(data->args.hipGraphGetEdges.to); + if (data->args.hipGraphGetEdges.numEdges) + data->args.hipGraphGetEdges.numEdges__val = + *(data->args.hipGraphGetEdges.numEdges); + break; + // hipGraphGetNodes[('hipGraph_t', 'graph'), ('hipGraphNode_t*', 'nodes'), + // ('size_t*', 'numNodes')] + case HIP_API_ID_hipGraphGetNodes: + if (data->args.hipGraphGetNodes.nodes) + data->args.hipGraphGetNodes.nodes__val = + *(data->args.hipGraphGetNodes.nodes); + if (data->args.hipGraphGetNodes.numNodes) + data->args.hipGraphGetNodes.numNodes__val = + *(data->args.hipGraphGetNodes.numNodes); + break; + // hipGraphGetRootNodes[('hipGraph_t', 'graph'), ('hipGraphNode_t*', + // 'pRootNodes'), ('size_t*', 'pNumRootNodes')] + case HIP_API_ID_hipGraphGetRootNodes: + if (data->args.hipGraphGetRootNodes.pRootNodes) + data->args.hipGraphGetRootNodes.pRootNodes__val = + *(data->args.hipGraphGetRootNodes.pRootNodes); + if (data->args.hipGraphGetRootNodes.pNumRootNodes) + data->args.hipGraphGetRootNodes.pNumRootNodes__val = + *(data->args.hipGraphGetRootNodes.pNumRootNodes); + break; + // hipGraphHostNodeGetParams[('hipGraphNode_t', 'node'), + // ('hipHostNodeParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphHostNodeGetParams: + if (data->args.hipGraphHostNodeGetParams.pNodeParams) + data->args.hipGraphHostNodeGetParams.pNodeParams__val = + *(data->args.hipGraphHostNodeGetParams.pNodeParams); + break; + // hipGraphHostNodeSetParams[('hipGraphNode_t', 'node'), ('const + // hipHostNodeParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphHostNodeSetParams: + if (data->args.hipGraphHostNodeSetParams.pNodeParams) + data->args.hipGraphHostNodeSetParams.pNodeParams__val = + *(data->args.hipGraphHostNodeSetParams.pNodeParams); + break; + // hipGraphInstantiate[('hipGraphExec_t*', 'pGraphExec'), ('hipGraph_t', + // 'graph'), ('hipGraphNode_t*', 'pErrorNode'), ('char*', 'pLogBuffer'), + // ('size_t', 'bufferSize')] + case HIP_API_ID_hipGraphInstantiate: + if (data->args.hipGraphInstantiate.pGraphExec) + data->args.hipGraphInstantiate.pGraphExec__val = + *(data->args.hipGraphInstantiate.pGraphExec); + if (data->args.hipGraphInstantiate.pErrorNode) + data->args.hipGraphInstantiate.pErrorNode__val = + *(data->args.hipGraphInstantiate.pErrorNode); + data->args.hipGraphInstantiate.pLogBuffer = + (data->args.hipGraphInstantiate.pLogBuffer) + ? strdup(data->args.hipGraphInstantiate.pLogBuffer) + : NULL; + break; + // hipGraphInstantiateWithFlags[('hipGraphExec_t*', 'pGraphExec'), + // ('hipGraph_t', 'graph'), ('unsigned long long', 'flags')] + case HIP_API_ID_hipGraphInstantiateWithFlags: + if (data->args.hipGraphInstantiateWithFlags.pGraphExec) + data->args.hipGraphInstantiateWithFlags.pGraphExec__val = + *(data->args.hipGraphInstantiateWithFlags.pGraphExec); + break; + // hipGraphInstantiateWithParams[('hipGraphExec_t*', 'pGraphExec'), + // ('hipGraph_t', 'graph'), ('hipGraphInstantiateParams*', + // 'instantiateParams')] + case HIP_API_ID_hipGraphInstantiateWithParams: + if (data->args.hipGraphInstantiateWithParams.pGraphExec) + data->args.hipGraphInstantiateWithParams.pGraphExec__val = + *(data->args.hipGraphInstantiateWithParams.pGraphExec); + if (data->args.hipGraphInstantiateWithParams.instantiateParams) + data->args.hipGraphInstantiateWithParams.instantiateParams__val = + *(data->args.hipGraphInstantiateWithParams.instantiateParams); + break; + // hipGraphKernelNodeCopyAttributes[('hipGraphNode_t', 'hSrc'), + // ('hipGraphNode_t', 'hDst')] + case HIP_API_ID_hipGraphKernelNodeCopyAttributes: + break; + // hipGraphKernelNodeGetAttribute[('hipGraphNode_t', 'hNode'), + // ('hipLaunchAttributeID', 'attr'), ('hipLaunchAttributeValue*', 'value')] + case HIP_API_ID_hipGraphKernelNodeGetAttribute: + if (data->args.hipGraphKernelNodeGetAttribute.value) + data->args.hipGraphKernelNodeGetAttribute.value__val = + *(data->args.hipGraphKernelNodeGetAttribute.value); + break; + // hipGraphKernelNodeGetParams[('hipGraphNode_t', 'node'), + // ('hipKernelNodeParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphKernelNodeGetParams: + if (data->args.hipGraphKernelNodeGetParams.pNodeParams) + data->args.hipGraphKernelNodeGetParams.pNodeParams__val = + *(data->args.hipGraphKernelNodeGetParams.pNodeParams); + break; + // hipGraphKernelNodeSetAttribute[('hipGraphNode_t', 'hNode'), + // ('hipLaunchAttributeID', 'attr'), ('const hipLaunchAttributeValue*', + // 'value')] + case HIP_API_ID_hipGraphKernelNodeSetAttribute: + if (data->args.hipGraphKernelNodeSetAttribute.value) + data->args.hipGraphKernelNodeSetAttribute.value__val = + *(data->args.hipGraphKernelNodeSetAttribute.value); + break; + // hipGraphKernelNodeSetParams[('hipGraphNode_t', 'node'), ('const + // hipKernelNodeParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphKernelNodeSetParams: + if (data->args.hipGraphKernelNodeSetParams.pNodeParams) + data->args.hipGraphKernelNodeSetParams.pNodeParams__val = + *(data->args.hipGraphKernelNodeSetParams.pNodeParams); + break; + // hipGraphLaunch[('hipGraphExec_t', 'graphExec'), ('hipStream_t', + // 'stream')] + case HIP_API_ID_hipGraphLaunch: + break; + // hipGraphMemAllocNodeGetParams[('hipGraphNode_t', 'node'), + // ('hipMemAllocNodeParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphMemAllocNodeGetParams: + if (data->args.hipGraphMemAllocNodeGetParams.pNodeParams) + data->args.hipGraphMemAllocNodeGetParams.pNodeParams__val = + *(data->args.hipGraphMemAllocNodeGetParams.pNodeParams); + break; + // hipGraphMemFreeNodeGetParams[('hipGraphNode_t', 'node'), ('void*', + // 'dev_ptr')] + case HIP_API_ID_hipGraphMemFreeNodeGetParams: + break; + // hipGraphMemcpyNodeGetParams[('hipGraphNode_t', 'node'), + // ('hipMemcpy3DParms*', 'pNodeParams')] + case HIP_API_ID_hipGraphMemcpyNodeGetParams: + if (data->args.hipGraphMemcpyNodeGetParams.pNodeParams) + data->args.hipGraphMemcpyNodeGetParams.pNodeParams__val = + *(data->args.hipGraphMemcpyNodeGetParams.pNodeParams); + break; + // hipGraphMemcpyNodeSetParams[('hipGraphNode_t', 'node'), ('const + // hipMemcpy3DParms*', 'pNodeParams')] + case HIP_API_ID_hipGraphMemcpyNodeSetParams: + if (data->args.hipGraphMemcpyNodeSetParams.pNodeParams) + data->args.hipGraphMemcpyNodeSetParams.pNodeParams__val = + *(data->args.hipGraphMemcpyNodeSetParams.pNodeParams); + break; + // hipGraphMemcpyNodeSetParams1D[('hipGraphNode_t', 'node'), ('void*', + // 'dst'), ('const void*', 'src'), ('size_t', 'count'), ('hipMemcpyKind', + // 'kind')] + case HIP_API_ID_hipGraphMemcpyNodeSetParams1D: + break; + // hipGraphMemcpyNodeSetParamsFromSymbol[('hipGraphNode_t', 'node'), + // ('void*', 'dst'), ('const void*', 'symbol'), ('size_t', 'count'), + // ('size_t', 'offset'), ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipGraphMemcpyNodeSetParamsFromSymbol: + break; + // hipGraphMemcpyNodeSetParamsToSymbol[('hipGraphNode_t', 'node'), ('const + // void*', 'symbol'), ('const void*', 'src'), ('size_t', 'count'), + // ('size_t', 'offset'), ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipGraphMemcpyNodeSetParamsToSymbol: + break; + // hipGraphMemsetNodeGetParams[('hipGraphNode_t', 'node'), + // ('hipMemsetParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphMemsetNodeGetParams: + if (data->args.hipGraphMemsetNodeGetParams.pNodeParams) + data->args.hipGraphMemsetNodeGetParams.pNodeParams__val = + *(data->args.hipGraphMemsetNodeGetParams.pNodeParams); + break; + // hipGraphMemsetNodeSetParams[('hipGraphNode_t', 'node'), ('const + // hipMemsetParams*', 'pNodeParams')] + case HIP_API_ID_hipGraphMemsetNodeSetParams: + if (data->args.hipGraphMemsetNodeSetParams.pNodeParams) + data->args.hipGraphMemsetNodeSetParams.pNodeParams__val = + *(data->args.hipGraphMemsetNodeSetParams.pNodeParams); + break; + // hipGraphNodeFindInClone[('hipGraphNode_t*', 'pNode'), ('hipGraphNode_t', + // 'originalNode'), ('hipGraph_t', 'clonedGraph')] + case HIP_API_ID_hipGraphNodeFindInClone: + if (data->args.hipGraphNodeFindInClone.pNode) + data->args.hipGraphNodeFindInClone.pNode__val = + *(data->args.hipGraphNodeFindInClone.pNode); + break; + // hipGraphNodeGetDependencies[('hipGraphNode_t', 'node'), + // ('hipGraphNode_t*', 'pDependencies'), ('size_t*', 'pNumDependencies')] + case HIP_API_ID_hipGraphNodeGetDependencies: + if (data->args.hipGraphNodeGetDependencies.pDependencies) + data->args.hipGraphNodeGetDependencies.pDependencies__val = + *(data->args.hipGraphNodeGetDependencies.pDependencies); + if (data->args.hipGraphNodeGetDependencies.pNumDependencies) + data->args.hipGraphNodeGetDependencies.pNumDependencies__val = + *(data->args.hipGraphNodeGetDependencies.pNumDependencies); + break; + // hipGraphNodeGetDependentNodes[('hipGraphNode_t', 'node'), + // ('hipGraphNode_t*', 'pDependentNodes'), ('size_t*', + // 'pNumDependentNodes')] + case HIP_API_ID_hipGraphNodeGetDependentNodes: + if (data->args.hipGraphNodeGetDependentNodes.pDependentNodes) + data->args.hipGraphNodeGetDependentNodes.pDependentNodes__val = + *(data->args.hipGraphNodeGetDependentNodes.pDependentNodes); + if (data->args.hipGraphNodeGetDependentNodes.pNumDependentNodes) + data->args.hipGraphNodeGetDependentNodes.pNumDependentNodes__val = + *(data->args.hipGraphNodeGetDependentNodes.pNumDependentNodes); + break; + // hipGraphNodeGetEnabled[('hipGraphExec_t', 'hGraphExec'), + // ('hipGraphNode_t', 'hNode'), ('unsigned int*', 'isEnabled')] + case HIP_API_ID_hipGraphNodeGetEnabled: + if (data->args.hipGraphNodeGetEnabled.isEnabled) + data->args.hipGraphNodeGetEnabled.isEnabled__val = + *(data->args.hipGraphNodeGetEnabled.isEnabled); + break; + // hipGraphNodeGetType[('hipGraphNode_t', 'node'), ('hipGraphNodeType*', + // 'pType')] + case HIP_API_ID_hipGraphNodeGetType: + if (data->args.hipGraphNodeGetType.pType) + data->args.hipGraphNodeGetType.pType__val = + *(data->args.hipGraphNodeGetType.pType); + break; + // hipGraphNodeSetEnabled[('hipGraphExec_t', 'hGraphExec'), + // ('hipGraphNode_t', 'hNode'), ('unsigned int', 'isEnabled')] + case HIP_API_ID_hipGraphNodeSetEnabled: + break; + // hipGraphReleaseUserObject[('hipGraph_t', 'graph'), ('hipUserObject_t', + // 'object'), ('unsigned int', 'count')] + case HIP_API_ID_hipGraphReleaseUserObject: + break; + // hipGraphRemoveDependencies[('hipGraph_t', 'graph'), ('const + // hipGraphNode_t*', 'from'), ('const hipGraphNode_t*', 'to'), ('size_t', + // 'numDependencies')] + case HIP_API_ID_hipGraphRemoveDependencies: + if (data->args.hipGraphRemoveDependencies.from) + data->args.hipGraphRemoveDependencies.from__val = + *(data->args.hipGraphRemoveDependencies.from); + if (data->args.hipGraphRemoveDependencies.to) + data->args.hipGraphRemoveDependencies.to__val = + *(data->args.hipGraphRemoveDependencies.to); + break; + // hipGraphRetainUserObject[('hipGraph_t', 'graph'), ('hipUserObject_t', + // 'object'), ('unsigned int', 'count'), ('unsigned int', 'flags')] + case HIP_API_ID_hipGraphRetainUserObject: + break; + // hipGraphUpload[('hipGraphExec_t', 'graphExec'), ('hipStream_t', + // 'stream')] + case HIP_API_ID_hipGraphUpload: + break; + // hipGraphicsGLRegisterBuffer[('hipGraphicsResource**', 'resource'), + // ('GLuint', 'buffer'), ('unsigned int', 'flags')] + case HIP_API_ID_hipGraphicsGLRegisterBuffer: + if (data->args.hipGraphicsGLRegisterBuffer.resource) + data->args.hipGraphicsGLRegisterBuffer.resource__val = + *(data->args.hipGraphicsGLRegisterBuffer.resource); + break; + // hipGraphicsGLRegisterImage[('hipGraphicsResource**', 'resource'), + // ('GLuint', 'image'), ('GLenum', 'target'), ('unsigned int', 'flags')] + case HIP_API_ID_hipGraphicsGLRegisterImage: + if (data->args.hipGraphicsGLRegisterImage.resource) + data->args.hipGraphicsGLRegisterImage.resource__val = + *(data->args.hipGraphicsGLRegisterImage.resource); + break; + // hipGraphicsMapResources[('int', 'count'), ('hipGraphicsResource_t*', + // 'resources'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipGraphicsMapResources: + if (data->args.hipGraphicsMapResources.resources) + data->args.hipGraphicsMapResources.resources__val = + *(data->args.hipGraphicsMapResources.resources); + break; + // hipGraphicsResourceGetMappedPointer[('void**', 'devPtr'), ('size_t*', + // 'size'), ('hipGraphicsResource_t', 'resource')] + case HIP_API_ID_hipGraphicsResourceGetMappedPointer: + if (data->args.hipGraphicsResourceGetMappedPointer.devPtr) + data->args.hipGraphicsResourceGetMappedPointer.devPtr__val = + *(data->args.hipGraphicsResourceGetMappedPointer.devPtr); + if (data->args.hipGraphicsResourceGetMappedPointer.size) + data->args.hipGraphicsResourceGetMappedPointer.size__val = + *(data->args.hipGraphicsResourceGetMappedPointer.size); + break; + // hipGraphicsSubResourceGetMappedArray[('hipArray_t*', 'array'), + // ('hipGraphicsResource_t', 'resource'), ('unsigned int', 'arrayIndex'), + // ('unsigned int', 'mipLevel')] + case HIP_API_ID_hipGraphicsSubResourceGetMappedArray: + if (data->args.hipGraphicsSubResourceGetMappedArray.array) + data->args.hipGraphicsSubResourceGetMappedArray.array__val = + *(data->args.hipGraphicsSubResourceGetMappedArray.array); + break; + // hipGraphicsUnmapResources[('int', 'count'), ('hipGraphicsResource_t*', + // 'resources'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipGraphicsUnmapResources: + if (data->args.hipGraphicsUnmapResources.resources) + data->args.hipGraphicsUnmapResources.resources__val = + *(data->args.hipGraphicsUnmapResources.resources); + break; + // hipGraphicsUnregisterResource[('hipGraphicsResource_t', 'resource')] + case HIP_API_ID_hipGraphicsUnregisterResource: + break; + // hipHccModuleLaunchKernel[('hipFunction_t', 'f'), ('unsigned int', + // 'globalWorkSizeX'), ('unsigned int', 'globalWorkSizeY'), ('unsigned int', + // 'globalWorkSizeZ'), ('unsigned int', 'blockDimX'), ('unsigned int', + // 'blockDimY'), ('unsigned int', 'blockDimZ'), ('size_t', + // 'sharedMemBytes'), ('hipStream_t', 'hStream'), ('void**', + // 'kernelParams'), ('void**', 'extra'), ('hipEvent_t', 'startEvent'), + // ('hipEvent_t', 'stopEvent')] + case HIP_API_ID_hipHccModuleLaunchKernel: + if (data->args.hipHccModuleLaunchKernel.kernelParams) + data->args.hipHccModuleLaunchKernel.kernelParams__val = + *(data->args.hipHccModuleLaunchKernel.kernelParams); + if (data->args.hipHccModuleLaunchKernel.extra) + data->args.hipHccModuleLaunchKernel.extra__val = + *(data->args.hipHccModuleLaunchKernel.extra); + break; + // hipHostAlloc[('void**', 'ptr'), ('size_t', 'size'), ('unsigned int', + // 'flags')] + case HIP_API_ID_hipHostAlloc: + if (data->args.hipHostAlloc.ptr) + data->args.hipHostAlloc.ptr__val = *(data->args.hipHostAlloc.ptr); + break; + // hipHostFree[('void*', 'ptr')] + case HIP_API_ID_hipHostFree: + break; + // hipHostGetDevicePointer[('void**', 'devPtr'), ('void*', 'hstPtr'), + // ('unsigned int', 'flags')] + case HIP_API_ID_hipHostGetDevicePointer: + if (data->args.hipHostGetDevicePointer.devPtr) + data->args.hipHostGetDevicePointer.devPtr__val = + *(data->args.hipHostGetDevicePointer.devPtr); + break; + // hipHostGetFlags[('unsigned int*', 'flagsPtr'), ('void*', 'hostPtr')] + case HIP_API_ID_hipHostGetFlags: + if (data->args.hipHostGetFlags.flagsPtr) + data->args.hipHostGetFlags.flagsPtr__val = + *(data->args.hipHostGetFlags.flagsPtr); + break; + // hipHostMalloc[('void**', 'ptr'), ('size_t', 'size'), ('unsigned int', + // 'flags')] + case HIP_API_ID_hipHostMalloc: + if (data->args.hipHostMalloc.ptr) + data->args.hipHostMalloc.ptr__val = *(data->args.hipHostMalloc.ptr); + break; + // hipHostRegister[('void*', 'hostPtr'), ('size_t', 'sizeBytes'), ('unsigned + // int', 'flags')] + case HIP_API_ID_hipHostRegister: + break; + // hipHostUnregister[('void*', 'hostPtr')] + case HIP_API_ID_hipHostUnregister: + break; + // hipImportExternalMemory[('hipExternalMemory_t*', 'extMem_out'), ('const + // hipExternalMemoryHandleDesc*', 'memHandleDesc')] + case HIP_API_ID_hipImportExternalMemory: + if (data->args.hipImportExternalMemory.extMem_out) + data->args.hipImportExternalMemory.extMem_out__val = + *(data->args.hipImportExternalMemory.extMem_out); + if (data->args.hipImportExternalMemory.memHandleDesc) + data->args.hipImportExternalMemory.memHandleDesc__val = + *(data->args.hipImportExternalMemory.memHandleDesc); + break; + // hipImportExternalSemaphore[('hipExternalSemaphore_t*', 'extSem_out'), + // ('const hipExternalSemaphoreHandleDesc*', 'semHandleDesc')] + case HIP_API_ID_hipImportExternalSemaphore: + if (data->args.hipImportExternalSemaphore.extSem_out) + data->args.hipImportExternalSemaphore.extSem_out__val = + *(data->args.hipImportExternalSemaphore.extSem_out); + if (data->args.hipImportExternalSemaphore.semHandleDesc) + data->args.hipImportExternalSemaphore.semHandleDesc__val = + *(data->args.hipImportExternalSemaphore.semHandleDesc); + break; + // hipInit[('unsigned int', 'flags')] + case HIP_API_ID_hipInit: + break; + // hipIpcCloseMemHandle[('void*', 'devPtr')] + case HIP_API_ID_hipIpcCloseMemHandle: + break; + // hipIpcGetEventHandle[('hipIpcEventHandle_t*', 'handle'), ('hipEvent_t', + // 'event')] + case HIP_API_ID_hipIpcGetEventHandle: + if (data->args.hipIpcGetEventHandle.handle) + data->args.hipIpcGetEventHandle.handle__val = + *(data->args.hipIpcGetEventHandle.handle); + break; + // hipIpcGetMemHandle[('hipIpcMemHandle_t*', 'handle'), ('void*', 'devPtr')] + case HIP_API_ID_hipIpcGetMemHandle: + if (data->args.hipIpcGetMemHandle.handle) + data->args.hipIpcGetMemHandle.handle__val = + *(data->args.hipIpcGetMemHandle.handle); + break; + // hipIpcOpenEventHandle[('hipEvent_t*', 'event'), ('hipIpcEventHandle_t', + // 'handle')] + case HIP_API_ID_hipIpcOpenEventHandle: + if (data->args.hipIpcOpenEventHandle.event) + data->args.hipIpcOpenEventHandle.event__val = + *(data->args.hipIpcOpenEventHandle.event); + break; + // hipIpcOpenMemHandle[('void**', 'devPtr'), ('hipIpcMemHandle_t', + // 'handle'), ('unsigned int', 'flags')] + case HIP_API_ID_hipIpcOpenMemHandle: + if (data->args.hipIpcOpenMemHandle.devPtr) + data->args.hipIpcOpenMemHandle.devPtr__val = + *(data->args.hipIpcOpenMemHandle.devPtr); + break; + // hipLaunchByPtr[('const void*', 'hostFunction')] + case HIP_API_ID_hipLaunchByPtr: + break; + // hipLaunchCooperativeKernel[('const void*', 'f'), ('dim3', 'gridDim'), + // ('dim3', 'blockDimX'), ('void**', 'kernelParams'), ('unsigned int', + // 'sharedMemBytes'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipLaunchCooperativeKernel: + if (data->args.hipLaunchCooperativeKernel.kernelParams) + data->args.hipLaunchCooperativeKernel.kernelParams__val = + *(data->args.hipLaunchCooperativeKernel.kernelParams); + break; + // hipLaunchCooperativeKernelMultiDevice[('hipLaunchParams*', + // 'launchParamsList'), ('int', 'numDevices'), ('unsigned int', 'flags')] + case HIP_API_ID_hipLaunchCooperativeKernelMultiDevice: + if (data->args.hipLaunchCooperativeKernelMultiDevice.launchParamsList) + data->args.hipLaunchCooperativeKernelMultiDevice.launchParamsList__val = + *(data->args.hipLaunchCooperativeKernelMultiDevice.launchParamsList); + break; + // hipLaunchHostFunc[('hipStream_t', 'stream'), ('hipHostFn_t', 'fn'), + // ('void*', 'userData')] + case HIP_API_ID_hipLaunchHostFunc: + break; + // hipLaunchKernel[('const void*', 'function_address'), ('dim3', + // 'numBlocks'), ('dim3', 'dimBlocks'), ('void**', 'args'), ('size_t', + // 'sharedMemBytes'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipLaunchKernel: + if (data->args.hipLaunchKernel.args) + data->args.hipLaunchKernel.args__val = *(data->args.hipLaunchKernel.args); + break; + // hipMalloc[('void**', 'ptr'), ('size_t', 'size')] + case HIP_API_ID_hipMalloc: + if (data->args.hipMalloc.ptr) + data->args.hipMalloc.ptr__val = *(data->args.hipMalloc.ptr); + break; + // hipMalloc3D[('hipPitchedPtr*', 'pitchedDevPtr'), ('hipExtent', 'extent')] + case HIP_API_ID_hipMalloc3D: + if (data->args.hipMalloc3D.pitchedDevPtr) + data->args.hipMalloc3D.pitchedDevPtr__val = + *(data->args.hipMalloc3D.pitchedDevPtr); + break; + // hipMalloc3DArray[('hipArray_t*', 'array'), ('const + // hipChannelFormatDesc*', 'desc'), ('hipExtent', 'extent'), ('unsigned + // int', 'flags')] + case HIP_API_ID_hipMalloc3DArray: + if (data->args.hipMalloc3DArray.array) + data->args.hipMalloc3DArray.array__val = + *(data->args.hipMalloc3DArray.array); + if (data->args.hipMalloc3DArray.desc) + data->args.hipMalloc3DArray.desc__val = + *(data->args.hipMalloc3DArray.desc); + break; + // hipMallocArray[('hipArray_t*', 'array'), ('const hipChannelFormatDesc*', + // 'desc'), ('size_t', 'width'), ('size_t', 'height'), ('unsigned int', + // 'flags')] + case HIP_API_ID_hipMallocArray: + if (data->args.hipMallocArray.array) + data->args.hipMallocArray.array__val = *(data->args.hipMallocArray.array); + if (data->args.hipMallocArray.desc) + data->args.hipMallocArray.desc__val = *(data->args.hipMallocArray.desc); + break; + // hipMallocAsync[('void**', 'dev_ptr'), ('size_t', 'size'), ('hipStream_t', + // 'stream')] + case HIP_API_ID_hipMallocAsync: + if (data->args.hipMallocAsync.dev_ptr) + data->args.hipMallocAsync.dev_ptr__val = + *(data->args.hipMallocAsync.dev_ptr); + break; + // hipMallocFromPoolAsync[('void**', 'dev_ptr'), ('size_t', 'size'), + // ('hipMemPool_t', 'mem_pool'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMallocFromPoolAsync: + if (data->args.hipMallocFromPoolAsync.dev_ptr) + data->args.hipMallocFromPoolAsync.dev_ptr__val = + *(data->args.hipMallocFromPoolAsync.dev_ptr); + break; + // hipMallocHost[('void**', 'ptr'), ('size_t', 'size')] + case HIP_API_ID_hipMallocHost: + if (data->args.hipMallocHost.ptr) + data->args.hipMallocHost.ptr__val = *(data->args.hipMallocHost.ptr); + break; + // hipMallocManaged[('void**', 'dev_ptr'), ('size_t', 'size'), ('unsigned + // int', 'flags')] + case HIP_API_ID_hipMallocManaged: + if (data->args.hipMallocManaged.dev_ptr) + data->args.hipMallocManaged.dev_ptr__val = + *(data->args.hipMallocManaged.dev_ptr); + break; + // hipMallocMipmappedArray[('hipMipmappedArray_t*', 'mipmappedArray'), + // ('const hipChannelFormatDesc*', 'desc'), ('hipExtent', 'extent'), + // ('unsigned int', 'numLevels'), ('unsigned int', 'flags')] + case HIP_API_ID_hipMallocMipmappedArray: + if (data->args.hipMallocMipmappedArray.mipmappedArray) + data->args.hipMallocMipmappedArray.mipmappedArray__val = + *(data->args.hipMallocMipmappedArray.mipmappedArray); + if (data->args.hipMallocMipmappedArray.desc) + data->args.hipMallocMipmappedArray.desc__val = + *(data->args.hipMallocMipmappedArray.desc); + break; + // hipMallocPitch[('void**', 'ptr'), ('size_t*', 'pitch'), ('size_t', + // 'width'), ('size_t', 'height')] + case HIP_API_ID_hipMallocPitch: + if (data->args.hipMallocPitch.ptr) + data->args.hipMallocPitch.ptr__val = *(data->args.hipMallocPitch.ptr); + if (data->args.hipMallocPitch.pitch) + data->args.hipMallocPitch.pitch__val = *(data->args.hipMallocPitch.pitch); + break; + // hipMemAddressFree[('void*', 'devPtr'), ('size_t', 'size')] + case HIP_API_ID_hipMemAddressFree: + break; + // hipMemAddressReserve[('void**', 'ptr'), ('size_t', 'size'), ('size_t', + // 'alignment'), ('void*', 'addr'), ('unsigned long long', 'flags')] + case HIP_API_ID_hipMemAddressReserve: + if (data->args.hipMemAddressReserve.ptr) + data->args.hipMemAddressReserve.ptr__val = + *(data->args.hipMemAddressReserve.ptr); + break; + // hipMemAdvise[('const void*', 'dev_ptr'), ('size_t', 'count'), + // ('hipMemoryAdvise', 'advice'), ('int', 'device')] + case HIP_API_ID_hipMemAdvise: + break; + // hipMemAllocHost[('void**', 'ptr'), ('size_t', 'size')] + case HIP_API_ID_hipMemAllocHost: + if (data->args.hipMemAllocHost.ptr) + data->args.hipMemAllocHost.ptr__val = *(data->args.hipMemAllocHost.ptr); + break; + // hipMemAllocPitch[('hipDeviceptr_t*', 'dptr'), ('size_t*', 'pitch'), + // ('size_t', 'widthInBytes'), ('size_t', 'height'), ('unsigned int', + // 'elementSizeBytes')] + case HIP_API_ID_hipMemAllocPitch: + if (data->args.hipMemAllocPitch.dptr) + data->args.hipMemAllocPitch.dptr__val = + *(data->args.hipMemAllocPitch.dptr); + if (data->args.hipMemAllocPitch.pitch) + data->args.hipMemAllocPitch.pitch__val = + *(data->args.hipMemAllocPitch.pitch); + break; + // hipMemCreate[('hipMemGenericAllocationHandle_t*', 'handle'), ('size_t', + // 'size'), ('const hipMemAllocationProp*', 'prop'), ('unsigned long long', + // 'flags')] + case HIP_API_ID_hipMemCreate: + if (data->args.hipMemCreate.handle) + data->args.hipMemCreate.handle__val = *(data->args.hipMemCreate.handle); + if (data->args.hipMemCreate.prop) + data->args.hipMemCreate.prop__val = *(data->args.hipMemCreate.prop); + break; + // hipMemExportToShareableHandle[('void*', 'shareableHandle'), + // ('hipMemGenericAllocationHandle_t', 'handle'), + // ('hipMemAllocationHandleType', 'handleType'), ('unsigned long long', + // 'flags')] + case HIP_API_ID_hipMemExportToShareableHandle: + break; + // hipMemGetAccess[('unsigned long long*', 'flags'), ('const + // hipMemLocation*', 'location'), ('void*', 'ptr')] + case HIP_API_ID_hipMemGetAccess: + if (data->args.hipMemGetAccess.flags) + data->args.hipMemGetAccess.flags__val = + *(data->args.hipMemGetAccess.flags); + if (data->args.hipMemGetAccess.location) + data->args.hipMemGetAccess.location__val = + *(data->args.hipMemGetAccess.location); + break; + // hipMemGetAddressRange[('hipDeviceptr_t*', 'pbase'), ('size_t*', 'psize'), + // ('hipDeviceptr_t', 'dptr')] + case HIP_API_ID_hipMemGetAddressRange: + if (data->args.hipMemGetAddressRange.pbase) + data->args.hipMemGetAddressRange.pbase__val = + *(data->args.hipMemGetAddressRange.pbase); + if (data->args.hipMemGetAddressRange.psize) + data->args.hipMemGetAddressRange.psize__val = + *(data->args.hipMemGetAddressRange.psize); + break; + // hipMemGetAllocationGranularity[('size_t*', 'granularity'), ('const + // hipMemAllocationProp*', 'prop'), ('hipMemAllocationGranularity_flags', + // 'option')] + case HIP_API_ID_hipMemGetAllocationGranularity: + if (data->args.hipMemGetAllocationGranularity.granularity) + data->args.hipMemGetAllocationGranularity.granularity__val = + *(data->args.hipMemGetAllocationGranularity.granularity); + if (data->args.hipMemGetAllocationGranularity.prop) + data->args.hipMemGetAllocationGranularity.prop__val = + *(data->args.hipMemGetAllocationGranularity.prop); + break; + // hipMemGetAllocationPropertiesFromHandle[('hipMemAllocationProp*', + // 'prop'), ('hipMemGenericAllocationHandle_t', 'handle')] + case HIP_API_ID_hipMemGetAllocationPropertiesFromHandle: + if (data->args.hipMemGetAllocationPropertiesFromHandle.prop) + data->args.hipMemGetAllocationPropertiesFromHandle.prop__val = + *(data->args.hipMemGetAllocationPropertiesFromHandle.prop); + break; + // hipMemGetInfo[('size_t*', 'free'), ('size_t*', 'total')] + case HIP_API_ID_hipMemGetInfo: + if (data->args.hipMemGetInfo.free) + data->args.hipMemGetInfo.free__val = *(data->args.hipMemGetInfo.free); + if (data->args.hipMemGetInfo.total) + data->args.hipMemGetInfo.total__val = *(data->args.hipMemGetInfo.total); + break; + // hipMemImportFromShareableHandle[('hipMemGenericAllocationHandle_t*', + // 'handle'), ('void*', 'osHandle'), ('hipMemAllocationHandleType', + // 'shHandleType')] + case HIP_API_ID_hipMemImportFromShareableHandle: + if (data->args.hipMemImportFromShareableHandle.handle) + data->args.hipMemImportFromShareableHandle.handle__val = + *(data->args.hipMemImportFromShareableHandle.handle); + break; + // hipMemMap[('void*', 'ptr'), ('size_t', 'size'), ('size_t', 'offset'), + // ('hipMemGenericAllocationHandle_t', 'handle'), ('unsigned long long', + // 'flags')] + case HIP_API_ID_hipMemMap: + break; + // hipMemMapArrayAsync[('hipArrayMapInfo*', 'mapInfoList'), ('unsigned int', + // 'count'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemMapArrayAsync: + if (data->args.hipMemMapArrayAsync.mapInfoList) + data->args.hipMemMapArrayAsync.mapInfoList__val = + *(data->args.hipMemMapArrayAsync.mapInfoList); + break; + // hipMemPoolCreate[('hipMemPool_t*', 'mem_pool'), ('const + // hipMemPoolProps*', 'pool_props')] + case HIP_API_ID_hipMemPoolCreate: + if (data->args.hipMemPoolCreate.mem_pool) + data->args.hipMemPoolCreate.mem_pool__val = + *(data->args.hipMemPoolCreate.mem_pool); + if (data->args.hipMemPoolCreate.pool_props) + data->args.hipMemPoolCreate.pool_props__val = + *(data->args.hipMemPoolCreate.pool_props); + break; + // hipMemPoolDestroy[('hipMemPool_t', 'mem_pool')] + case HIP_API_ID_hipMemPoolDestroy: + break; + // hipMemPoolExportPointer[('hipMemPoolPtrExportData*', 'export_data'), + // ('void*', 'dev_ptr')] + case HIP_API_ID_hipMemPoolExportPointer: + if (data->args.hipMemPoolExportPointer.export_data) + data->args.hipMemPoolExportPointer.export_data__val = + *(data->args.hipMemPoolExportPointer.export_data); + break; + // hipMemPoolExportToShareableHandle[('void*', 'shared_handle'), + // ('hipMemPool_t', 'mem_pool'), ('hipMemAllocationHandleType', + // 'handle_type'), ('unsigned int', 'flags')] + case HIP_API_ID_hipMemPoolExportToShareableHandle: + break; + // hipMemPoolGetAccess[('hipMemAccessFlags*', 'flags'), ('hipMemPool_t', + // 'mem_pool'), ('hipMemLocation*', 'location')] + case HIP_API_ID_hipMemPoolGetAccess: + if (data->args.hipMemPoolGetAccess.flags) + data->args.hipMemPoolGetAccess.flags__val = + *(data->args.hipMemPoolGetAccess.flags); + if (data->args.hipMemPoolGetAccess.location) + data->args.hipMemPoolGetAccess.location__val = + *(data->args.hipMemPoolGetAccess.location); + break; + // hipMemPoolGetAttribute[('hipMemPool_t', 'mem_pool'), ('hipMemPoolAttr', + // 'attr'), ('void*', 'value')] + case HIP_API_ID_hipMemPoolGetAttribute: + break; + // hipMemPoolImportFromShareableHandle[('hipMemPool_t*', 'mem_pool'), + // ('void*', 'shared_handle'), ('hipMemAllocationHandleType', + // 'handle_type'), ('unsigned int', 'flags')] + case HIP_API_ID_hipMemPoolImportFromShareableHandle: + if (data->args.hipMemPoolImportFromShareableHandle.mem_pool) + data->args.hipMemPoolImportFromShareableHandle.mem_pool__val = + *(data->args.hipMemPoolImportFromShareableHandle.mem_pool); + break; + // hipMemPoolImportPointer[('void**', 'dev_ptr'), ('hipMemPool_t', + // 'mem_pool'), ('hipMemPoolPtrExportData*', 'export_data')] + case HIP_API_ID_hipMemPoolImportPointer: + if (data->args.hipMemPoolImportPointer.dev_ptr) + data->args.hipMemPoolImportPointer.dev_ptr__val = + *(data->args.hipMemPoolImportPointer.dev_ptr); + if (data->args.hipMemPoolImportPointer.export_data) + data->args.hipMemPoolImportPointer.export_data__val = + *(data->args.hipMemPoolImportPointer.export_data); + break; + // hipMemPoolSetAccess[('hipMemPool_t', 'mem_pool'), ('const + // hipMemAccessDesc*', 'desc_list'), ('size_t', 'count')] + case HIP_API_ID_hipMemPoolSetAccess: + if (data->args.hipMemPoolSetAccess.desc_list) + data->args.hipMemPoolSetAccess.desc_list__val = + *(data->args.hipMemPoolSetAccess.desc_list); + break; + // hipMemPoolSetAttribute[('hipMemPool_t', 'mem_pool'), ('hipMemPoolAttr', + // 'attr'), ('void*', 'value')] + case HIP_API_ID_hipMemPoolSetAttribute: + break; + // hipMemPoolTrimTo[('hipMemPool_t', 'mem_pool'), ('size_t', + // 'min_bytes_to_hold')] + case HIP_API_ID_hipMemPoolTrimTo: + break; + // hipMemPrefetchAsync[('const void*', 'dev_ptr'), ('size_t', 'count'), + // ('int', 'device'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemPrefetchAsync: + break; + // hipMemPtrGetInfo[('void*', 'ptr'), ('size_t*', 'size')] + case HIP_API_ID_hipMemPtrGetInfo: + if (data->args.hipMemPtrGetInfo.size) + data->args.hipMemPtrGetInfo.size__val = + *(data->args.hipMemPtrGetInfo.size); + break; + // hipMemRangeGetAttribute[('void*', 'data'), ('size_t', 'data_size'), + // ('hipMemRangeAttribute', 'attribute'), ('const void*', 'dev_ptr'), + // ('size_t', 'count')] + case HIP_API_ID_hipMemRangeGetAttribute: + break; + // hipMemRangeGetAttributes[('void**', 'data'), ('size_t*', 'data_sizes'), + // ('hipMemRangeAttribute*', 'attributes'), ('size_t', 'num_attributes'), + // ('const void*', 'dev_ptr'), ('size_t', 'count')] + case HIP_API_ID_hipMemRangeGetAttributes: + if (data->args.hipMemRangeGetAttributes.data) + data->args.hipMemRangeGetAttributes.data__val = + *(data->args.hipMemRangeGetAttributes.data); + if (data->args.hipMemRangeGetAttributes.data_sizes) + data->args.hipMemRangeGetAttributes.data_sizes__val = + *(data->args.hipMemRangeGetAttributes.data_sizes); + if (data->args.hipMemRangeGetAttributes.attributes) + data->args.hipMemRangeGetAttributes.attributes__val = + *(data->args.hipMemRangeGetAttributes.attributes); + break; + // hipMemRelease[('hipMemGenericAllocationHandle_t', 'handle')] + case HIP_API_ID_hipMemRelease: + break; + // hipMemRetainAllocationHandle[('hipMemGenericAllocationHandle_t*', + // 'handle'), ('void*', 'addr')] + case HIP_API_ID_hipMemRetainAllocationHandle: + if (data->args.hipMemRetainAllocationHandle.handle) + data->args.hipMemRetainAllocationHandle.handle__val = + *(data->args.hipMemRetainAllocationHandle.handle); + break; + // hipMemSetAccess[('void*', 'ptr'), ('size_t', 'size'), ('const + // hipMemAccessDesc*', 'desc'), ('size_t', 'count')] + case HIP_API_ID_hipMemSetAccess: + if (data->args.hipMemSetAccess.desc) + data->args.hipMemSetAccess.desc__val = *(data->args.hipMemSetAccess.desc); + break; + // hipMemUnmap[('void*', 'ptr'), ('size_t', 'size')] + case HIP_API_ID_hipMemUnmap: + break; + // hipMemcpy[('void*', 'dst'), ('const void*', 'src'), ('size_t', + // 'sizeBytes'), ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipMemcpy: + break; + // hipMemcpy2D[('void*', 'dst'), ('size_t', 'dpitch'), ('const void*', + // 'src'), ('size_t', 'spitch'), ('size_t', 'width'), ('size_t', 'height'), + // ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipMemcpy2D: + break; + // hipMemcpy2DArrayToArray[('hipArray_t', 'dst'), ('size_t', 'wOffsetDst'), + // ('size_t', 'hOffsetDst'), ('hipArray_const_t', 'src'), ('size_t', + // 'wOffsetSrc'), ('size_t', 'hOffsetSrc'), ('size_t', 'width'), ('size_t', + // 'height'), ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipMemcpy2DArrayToArray: + break; + // hipMemcpy2DAsync[('void*', 'dst'), ('size_t', 'dpitch'), ('const void*', + // 'src'), ('size_t', 'spitch'), ('size_t', 'width'), ('size_t', 'height'), + // ('hipMemcpyKind', 'kind'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpy2DAsync: + break; + // hipMemcpy2DFromArray[('void*', 'dst'), ('size_t', 'dpitch'), + // ('hipArray_const_t', 'src'), ('size_t', 'wOffset'), ('size_t', + // 'hOffset'), ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', + // 'kind')] + case HIP_API_ID_hipMemcpy2DFromArray: + break; + // hipMemcpy2DFromArrayAsync[('void*', 'dst'), ('size_t', 'dpitch'), + // ('hipArray_const_t', 'src'), ('size_t', 'wOffset'), ('size_t', + // 'hOffset'), ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', + // 'kind'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpy2DFromArrayAsync: + break; + // hipMemcpy2DToArray[('hipArray_t', 'dst'), ('size_t', 'wOffset'), + // ('size_t', 'hOffset'), ('const void*', 'src'), ('size_t', 'spitch'), + // ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipMemcpy2DToArray: + break; + // hipMemcpy2DToArrayAsync[('hipArray_t', 'dst'), ('size_t', 'wOffset'), + // ('size_t', 'hOffset'), ('const void*', 'src'), ('size_t', 'spitch'), + // ('size_t', 'width'), ('size_t', 'height'), ('hipMemcpyKind', 'kind'), + // ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpy2DToArrayAsync: + break; + // hipMemcpy3D[('const hipMemcpy3DParms*', 'p')] + case HIP_API_ID_hipMemcpy3D: + if (data->args.hipMemcpy3D.p) + data->args.hipMemcpy3D.p__val = *(data->args.hipMemcpy3D.p); + break; + // hipMemcpy3DAsync[('const hipMemcpy3DParms*', 'p'), ('hipStream_t', + // 'stream')] + case HIP_API_ID_hipMemcpy3DAsync: + if (data->args.hipMemcpy3DAsync.p) + data->args.hipMemcpy3DAsync.p__val = *(data->args.hipMemcpy3DAsync.p); + break; + // hipMemcpyAsync[('void*', 'dst'), ('const void*', 'src'), ('size_t', + // 'sizeBytes'), ('hipMemcpyKind', 'kind'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpyAsync: + break; + // hipMemcpyAtoA[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), + // ('hipArray_t', 'srcArray'), ('size_t', 'srcOffset'), ('size_t', + // 'ByteCount')] + case HIP_API_ID_hipMemcpyAtoA: + break; + // hipMemcpyAtoD[('hipDeviceptr_t', 'dstDevice'), ('hipArray_t', + // 'srcArray'), ('size_t', 'srcOffset'), ('size_t', 'ByteCount')] + case HIP_API_ID_hipMemcpyAtoD: + break; + // hipMemcpyAtoH[('void*', 'dst'), ('hipArray_t', 'srcArray'), ('size_t', + // 'srcOffset'), ('size_t', 'count')] + case HIP_API_ID_hipMemcpyAtoH: + break; + // hipMemcpyAtoHAsync[('void*', 'dstHost'), ('hipArray_t', 'srcArray'), + // ('size_t', 'srcOffset'), ('size_t', 'ByteCount'), ('hipStream_t', + // 'stream')] + case HIP_API_ID_hipMemcpyAtoHAsync: + break; + // hipMemcpyDtoA[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), + // ('hipDeviceptr_t', 'srcDevice'), ('size_t', 'ByteCount')] + case HIP_API_ID_hipMemcpyDtoA: + break; + // hipMemcpyDtoD[('hipDeviceptr_t', 'dst'), ('hipDeviceptr_t', 'src'), + // ('size_t', 'sizeBytes')] + case HIP_API_ID_hipMemcpyDtoD: + break; + // hipMemcpyDtoDAsync[('hipDeviceptr_t', 'dst'), ('hipDeviceptr_t', 'src'), + // ('size_t', 'sizeBytes'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpyDtoDAsync: + break; + // hipMemcpyDtoH[('void*', 'dst'), ('hipDeviceptr_t', 'src'), ('size_t', + // 'sizeBytes')] + case HIP_API_ID_hipMemcpyDtoH: + break; + // hipMemcpyDtoHAsync[('void*', 'dst'), ('hipDeviceptr_t', 'src'), + // ('size_t', 'sizeBytes'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpyDtoHAsync: + break; + // hipMemcpyFromArray[('void*', 'dst'), ('hipArray_const_t', 'srcArray'), + // ('size_t', 'wOffset'), ('size_t', 'hOffset'), ('size_t', 'count'), + // ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipMemcpyFromArray: + break; + // hipMemcpyFromSymbol[('void*', 'dst'), ('const void*', 'symbol'), + // ('size_t', 'sizeBytes'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipMemcpyFromSymbol: + break; + // hipMemcpyFromSymbolAsync[('void*', 'dst'), ('const void*', 'symbol'), + // ('size_t', 'sizeBytes'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind'), + // ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpyFromSymbolAsync: + break; + // hipMemcpyHtoA[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), + // ('const void*', 'srcHost'), ('size_t', 'count')] + case HIP_API_ID_hipMemcpyHtoA: + break; + // hipMemcpyHtoAAsync[('hipArray_t', 'dstArray'), ('size_t', 'dstOffset'), + // ('const void*', 'srcHost'), ('size_t', 'ByteCount'), ('hipStream_t', + // 'stream')] + case HIP_API_ID_hipMemcpyHtoAAsync: + break; + // hipMemcpyHtoD[('hipDeviceptr_t', 'dst'), ('void*', 'src'), ('size_t', + // 'sizeBytes')] + case HIP_API_ID_hipMemcpyHtoD: + break; + // hipMemcpyHtoDAsync[('hipDeviceptr_t', 'dst'), ('void*', 'src'), + // ('size_t', 'sizeBytes'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpyHtoDAsync: + break; + // hipMemcpyParam2D[('const hip_Memcpy2D*', 'pCopy')] + case HIP_API_ID_hipMemcpyParam2D: + if (data->args.hipMemcpyParam2D.pCopy) + data->args.hipMemcpyParam2D.pCopy__val = + *(data->args.hipMemcpyParam2D.pCopy); + break; + // hipMemcpyParam2DAsync[('const hip_Memcpy2D*', 'pCopy'), ('hipStream_t', + // 'stream')] + case HIP_API_ID_hipMemcpyParam2DAsync: + if (data->args.hipMemcpyParam2DAsync.pCopy) + data->args.hipMemcpyParam2DAsync.pCopy__val = + *(data->args.hipMemcpyParam2DAsync.pCopy); + break; + // hipMemcpyPeer[('void*', 'dst'), ('int', 'dstDeviceId'), ('const void*', + // 'src'), ('int', 'srcDeviceId'), ('size_t', 'sizeBytes')] + case HIP_API_ID_hipMemcpyPeer: + break; + // hipMemcpyPeerAsync[('void*', 'dst'), ('int', 'dstDeviceId'), ('const + // void*', 'src'), ('int', 'srcDevice'), ('size_t', 'sizeBytes'), + // ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpyPeerAsync: + break; + // hipMemcpyToArray[('hipArray_t', 'dst'), ('size_t', 'wOffset'), ('size_t', + // 'hOffset'), ('const void*', 'src'), ('size_t', 'count'), + // ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipMemcpyToArray: + break; + // hipMemcpyToSymbol[('const void*', 'symbol'), ('const void*', 'src'), + // ('size_t', 'sizeBytes'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind')] + case HIP_API_ID_hipMemcpyToSymbol: + break; + // hipMemcpyToSymbolAsync[('const void*', 'symbol'), ('const void*', 'src'), + // ('size_t', 'sizeBytes'), ('size_t', 'offset'), ('hipMemcpyKind', 'kind'), + // ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpyToSymbolAsync: + break; + // hipMemcpyWithStream[('void*', 'dst'), ('const void*', 'src'), ('size_t', + // 'sizeBytes'), ('hipMemcpyKind', 'kind'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemcpyWithStream: + break; + // hipMemset[('void*', 'dst'), ('int', 'value'), ('size_t', 'sizeBytes')] + case HIP_API_ID_hipMemset: + break; + // hipMemset2D[('void*', 'dst'), ('size_t', 'pitch'), ('int', 'value'), + // ('size_t', 'width'), ('size_t', 'height')] + case HIP_API_ID_hipMemset2D: + break; + // hipMemset2DAsync[('void*', 'dst'), ('size_t', 'pitch'), ('int', 'value'), + // ('size_t', 'width'), ('size_t', 'height'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemset2DAsync: + break; + // hipMemset3D[('hipPitchedPtr', 'pitchedDevPtr'), ('int', 'value'), + // ('hipExtent', 'extent')] + case HIP_API_ID_hipMemset3D: + break; + // hipMemset3DAsync[('hipPitchedPtr', 'pitchedDevPtr'), ('int', 'value'), + // ('hipExtent', 'extent'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemset3DAsync: + break; + // hipMemsetAsync[('void*', 'dst'), ('int', 'value'), ('size_t', + // 'sizeBytes'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemsetAsync: + break; + // hipMemsetD16[('hipDeviceptr_t', 'dest'), ('unsigned short', 'value'), + // ('size_t', 'count')] + case HIP_API_ID_hipMemsetD16: + break; + // hipMemsetD16Async[('hipDeviceptr_t', 'dest'), ('unsigned short', + // 'value'), ('size_t', 'count'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemsetD16Async: + break; + // hipMemsetD32[('hipDeviceptr_t', 'dest'), ('int', 'value'), ('size_t', + // 'count')] + case HIP_API_ID_hipMemsetD32: + break; + // hipMemsetD32Async[('hipDeviceptr_t', 'dst'), ('int', 'value'), ('size_t', + // 'count'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemsetD32Async: + break; + // hipMemsetD8[('hipDeviceptr_t', 'dest'), ('unsigned char', 'value'), + // ('size_t', 'count')] + case HIP_API_ID_hipMemsetD8: + break; + // hipMemsetD8Async[('hipDeviceptr_t', 'dest'), ('unsigned char', 'value'), + // ('size_t', 'count'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipMemsetD8Async: + break; + // hipMipmappedArrayCreate[('hipMipmappedArray_t*', 'pHandle'), + // ('HIP_ARRAY3D_DESCRIPTOR*', 'pMipmappedArrayDesc'), ('unsigned int', + // 'numMipmapLevels')] + case HIP_API_ID_hipMipmappedArrayCreate: + if (data->args.hipMipmappedArrayCreate.pHandle) + data->args.hipMipmappedArrayCreate.pHandle__val = + *(data->args.hipMipmappedArrayCreate.pHandle); + if (data->args.hipMipmappedArrayCreate.pMipmappedArrayDesc) + data->args.hipMipmappedArrayCreate.pMipmappedArrayDesc__val = + *(data->args.hipMipmappedArrayCreate.pMipmappedArrayDesc); + break; + // hipMipmappedArrayDestroy[('hipMipmappedArray_t', 'hMipmappedArray')] + case HIP_API_ID_hipMipmappedArrayDestroy: + break; + // hipMipmappedArrayGetLevel[('hipArray_t*', 'pLevelArray'), + // ('hipMipmappedArray_t', 'hMipMappedArray'), ('unsigned int', 'level')] + case HIP_API_ID_hipMipmappedArrayGetLevel: + if (data->args.hipMipmappedArrayGetLevel.pLevelArray) + data->args.hipMipmappedArrayGetLevel.pLevelArray__val = + *(data->args.hipMipmappedArrayGetLevel.pLevelArray); + break; + // hipModuleGetFunction[('hipFunction_t*', 'function'), ('hipModule_t', + // 'module'), ('const char*', 'kname')] + case HIP_API_ID_hipModuleGetFunction: + if (data->args.hipModuleGetFunction.function) + data->args.hipModuleGetFunction.function__val = + *(data->args.hipModuleGetFunction.function); + if (data->args.hipModuleGetFunction.kname) + data->args.hipModuleGetFunction.kname__val = + *(data->args.hipModuleGetFunction.kname); + break; + // hipModuleGetGlobal[('hipDeviceptr_t*', 'dptr'), ('size_t*', 'bytes'), + // ('hipModule_t', 'hmod'), ('const char*', 'name')] + case HIP_API_ID_hipModuleGetGlobal: + if (data->args.hipModuleGetGlobal.dptr) + data->args.hipModuleGetGlobal.dptr__val = + *(data->args.hipModuleGetGlobal.dptr); + if (data->args.hipModuleGetGlobal.bytes) + data->args.hipModuleGetGlobal.bytes__val = + *(data->args.hipModuleGetGlobal.bytes); + if (data->args.hipModuleGetGlobal.name) + data->args.hipModuleGetGlobal.name__val = + *(data->args.hipModuleGetGlobal.name); + break; + // hipModuleGetTexRef[('textureReference**', 'texRef'), ('hipModule_t', + // 'hmod'), ('const char*', 'name')] + case HIP_API_ID_hipModuleGetTexRef: + if (data->args.hipModuleGetTexRef.texRef) + data->args.hipModuleGetTexRef.texRef__val = + *(data->args.hipModuleGetTexRef.texRef); + if (data->args.hipModuleGetTexRef.name) + data->args.hipModuleGetTexRef.name__val = + *(data->args.hipModuleGetTexRef.name); + break; + // hipModuleLaunchCooperativeKernel[('hipFunction_t', 'f'), ('unsigned int', + // 'gridDimX'), ('unsigned int', 'gridDimY'), ('unsigned int', 'gridDimZ'), + // ('unsigned int', 'blockDimX'), ('unsigned int', 'blockDimY'), ('unsigned + // int', 'blockDimZ'), ('unsigned int', 'sharedMemBytes'), ('hipStream_t', + // 'stream'), ('void**', 'kernelParams')] + case HIP_API_ID_hipModuleLaunchCooperativeKernel: + if (data->args.hipModuleLaunchCooperativeKernel.kernelParams) + data->args.hipModuleLaunchCooperativeKernel.kernelParams__val = + *(data->args.hipModuleLaunchCooperativeKernel.kernelParams); + break; + // hipModuleLaunchCooperativeKernelMultiDevice[('hipFunctionLaunchParams*', + // 'launchParamsList'), ('unsigned int', 'numDevices'), ('unsigned int', + // 'flags')] + case HIP_API_ID_hipModuleLaunchCooperativeKernelMultiDevice: + if (data->args.hipModuleLaunchCooperativeKernelMultiDevice.launchParamsList) + data->args.hipModuleLaunchCooperativeKernelMultiDevice + .launchParamsList__val = + *(data->args.hipModuleLaunchCooperativeKernelMultiDevice + .launchParamsList); + break; + // hipModuleLaunchKernel[('hipFunction_t', 'f'), ('unsigned int', + // 'gridDimX'), ('unsigned int', 'gridDimY'), ('unsigned int', 'gridDimZ'), + // ('unsigned int', 'blockDimX'), ('unsigned int', 'blockDimY'), ('unsigned + // int', 'blockDimZ'), ('unsigned int', 'sharedMemBytes'), ('hipStream_t', + // 'stream'), ('void**', 'kernelParams'), ('void**', 'extra')] + case HIP_API_ID_hipModuleLaunchKernel: + if (data->args.hipModuleLaunchKernel.kernelParams) + data->args.hipModuleLaunchKernel.kernelParams__val = + *(data->args.hipModuleLaunchKernel.kernelParams); + if (data->args.hipModuleLaunchKernel.extra) + data->args.hipModuleLaunchKernel.extra__val = + *(data->args.hipModuleLaunchKernel.extra); + break; + // hipModuleLoad[('hipModule_t*', 'module'), ('const char*', 'fname')] + case HIP_API_ID_hipModuleLoad: + if (data->args.hipModuleLoad.module) + data->args.hipModuleLoad.module__val = *(data->args.hipModuleLoad.module); + if (data->args.hipModuleLoad.fname) + data->args.hipModuleLoad.fname__val = *(data->args.hipModuleLoad.fname); + break; + // hipModuleLoadData[('hipModule_t*', 'module'), ('const void*', 'image')] + case HIP_API_ID_hipModuleLoadData: + if (data->args.hipModuleLoadData.module) + data->args.hipModuleLoadData.module__val = + *(data->args.hipModuleLoadData.module); + break; + // hipModuleLoadDataEx[('hipModule_t*', 'module'), ('const void*', 'image'), + // ('unsigned int', 'numOptions'), ('hipJitOption*', 'options'), ('void**', + // 'optionsValues')] + case HIP_API_ID_hipModuleLoadDataEx: + if (data->args.hipModuleLoadDataEx.module) + data->args.hipModuleLoadDataEx.module__val = + *(data->args.hipModuleLoadDataEx.module); + if (data->args.hipModuleLoadDataEx.options) + data->args.hipModuleLoadDataEx.options__val = + *(data->args.hipModuleLoadDataEx.options); + if (data->args.hipModuleLoadDataEx.optionsValues) + data->args.hipModuleLoadDataEx.optionsValues__val = + *(data->args.hipModuleLoadDataEx.optionsValues); + break; + // hipModuleOccupancyMaxActiveBlocksPerMultiprocessor[('int*', 'numBlocks'), + // ('hipFunction_t', 'f'), ('int', 'blockSize'), ('size_t', + // 'dynSharedMemPerBlk')] + case HIP_API_ID_hipModuleOccupancyMaxActiveBlocksPerMultiprocessor: + if (data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor.numBlocks) + data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor + .numBlocks__val = + *(data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor + .numBlocks); + break; + // hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags[('int*', + // 'numBlocks'), ('hipFunction_t', 'f'), ('int', 'blockSize'), ('size_t', + // 'dynSharedMemPerBlk'), ('unsigned int', 'flags')] + case HIP_API_ID_hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: + if (data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .numBlocks) + data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .numBlocks__val = *( + data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .numBlocks); + break; + // hipModuleOccupancyMaxPotentialBlockSize[('int*', 'gridSize'), ('int*', + // 'blockSize'), ('hipFunction_t', 'f'), ('size_t', 'dynSharedMemPerBlk'), + // ('int', 'blockSizeLimit')] + case HIP_API_ID_hipModuleOccupancyMaxPotentialBlockSize: + if (data->args.hipModuleOccupancyMaxPotentialBlockSize.gridSize) + data->args.hipModuleOccupancyMaxPotentialBlockSize.gridSize__val = + *(data->args.hipModuleOccupancyMaxPotentialBlockSize.gridSize); + if (data->args.hipModuleOccupancyMaxPotentialBlockSize.blockSize) + data->args.hipModuleOccupancyMaxPotentialBlockSize.blockSize__val = + *(data->args.hipModuleOccupancyMaxPotentialBlockSize.blockSize); + break; + // hipModuleOccupancyMaxPotentialBlockSizeWithFlags[('int*', 'gridSize'), + // ('int*', 'blockSize'), ('hipFunction_t', 'f'), ('size_t', + // 'dynSharedMemPerBlk'), ('int', 'blockSizeLimit'), ('unsigned int', + // 'flags')] + case HIP_API_ID_hipModuleOccupancyMaxPotentialBlockSizeWithFlags: + if (data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags.gridSize) + data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags + .gridSize__val = *( + data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags.gridSize); + if (data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags.blockSize) + data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags + .blockSize__val = + *(data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags + .blockSize); + break; + // hipModuleUnload[('hipModule_t', 'module')] + case HIP_API_ID_hipModuleUnload: + break; + // hipOccupancyMaxActiveBlocksPerMultiprocessor[('int*', 'numBlocks'), + // ('const void*', 'f'), ('int', 'blockSize'), ('size_t', + // 'dynamicSMemSize')] + case HIP_API_ID_hipOccupancyMaxActiveBlocksPerMultiprocessor: + if (data->args.hipOccupancyMaxActiveBlocksPerMultiprocessor.numBlocks) + data->args.hipOccupancyMaxActiveBlocksPerMultiprocessor.numBlocks__val = + *(data->args.hipOccupancyMaxActiveBlocksPerMultiprocessor.numBlocks); + break; + // hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags[('int*', + // 'numBlocks'), ('const void*', 'f'), ('int', 'blockSize'), ('size_t', + // 'dynamicSMemSize'), ('unsigned int', 'flags')] + case HIP_API_ID_hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: + if (data->args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .numBlocks) + data->args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .numBlocks__val = + *(data->args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .numBlocks); + break; + // hipOccupancyMaxPotentialBlockSize[('int*', 'gridSize'), ('int*', + // 'blockSize'), ('const void*', 'f'), ('size_t', 'dynSharedMemPerBlk'), + // ('int', 'blockSizeLimit')] + case HIP_API_ID_hipOccupancyMaxPotentialBlockSize: + if (data->args.hipOccupancyMaxPotentialBlockSize.gridSize) + data->args.hipOccupancyMaxPotentialBlockSize.gridSize__val = + *(data->args.hipOccupancyMaxPotentialBlockSize.gridSize); + if (data->args.hipOccupancyMaxPotentialBlockSize.blockSize) + data->args.hipOccupancyMaxPotentialBlockSize.blockSize__val = + *(data->args.hipOccupancyMaxPotentialBlockSize.blockSize); + break; + // hipPeekAtLastError[] + case HIP_API_ID_hipPeekAtLastError: + break; + // hipPointerGetAttribute[('void*', 'data'), ('hipPointer_attribute', + // 'attribute'), ('hipDeviceptr_t', 'ptr')] + case HIP_API_ID_hipPointerGetAttribute: + break; + // hipPointerGetAttributes[('hipPointerAttribute_t*', 'attributes'), ('const + // void*', 'ptr')] + case HIP_API_ID_hipPointerGetAttributes: + if (data->args.hipPointerGetAttributes.attributes) + data->args.hipPointerGetAttributes.attributes__val = + *(data->args.hipPointerGetAttributes.attributes); + break; + // hipPointerSetAttribute[('const void*', 'value'), ('hipPointer_attribute', + // 'attribute'), ('hipDeviceptr_t', 'ptr')] + case HIP_API_ID_hipPointerSetAttribute: + break; + // hipProfilerStart[] + case HIP_API_ID_hipProfilerStart: + break; + // hipProfilerStop[] + case HIP_API_ID_hipProfilerStop: + break; + // hipRuntimeGetVersion[('int*', 'runtimeVersion')] + case HIP_API_ID_hipRuntimeGetVersion: + if (data->args.hipRuntimeGetVersion.runtimeVersion) + data->args.hipRuntimeGetVersion.runtimeVersion__val = + *(data->args.hipRuntimeGetVersion.runtimeVersion); + break; + // hipSetDevice[('int', 'deviceId')] + case HIP_API_ID_hipSetDevice: + break; + // hipSetDeviceFlags[('unsigned int', 'flags')] + case HIP_API_ID_hipSetDeviceFlags: + break; + // hipSetValidDevices[('int*', 'device_arr'), ('int', 'len')] + case HIP_API_ID_hipSetValidDevices: + if (data->args.hipSetValidDevices.device_arr) + data->args.hipSetValidDevices.device_arr__val = + *(data->args.hipSetValidDevices.device_arr); + break; + // hipSetupArgument[('const void*', 'arg'), ('size_t', 'size'), ('size_t', + // 'offset')] + case HIP_API_ID_hipSetupArgument: + break; + // hipSignalExternalSemaphoresAsync[('const hipExternalSemaphore_t*', + // 'extSemArray'), ('const hipExternalSemaphoreSignalParams*', + // 'paramsArray'), ('unsigned int', 'numExtSems'), ('hipStream_t', + // 'stream')] + case HIP_API_ID_hipSignalExternalSemaphoresAsync: + if (data->args.hipSignalExternalSemaphoresAsync.extSemArray) + data->args.hipSignalExternalSemaphoresAsync.extSemArray__val = + *(data->args.hipSignalExternalSemaphoresAsync.extSemArray); + if (data->args.hipSignalExternalSemaphoresAsync.paramsArray) + data->args.hipSignalExternalSemaphoresAsync.paramsArray__val = + *(data->args.hipSignalExternalSemaphoresAsync.paramsArray); + break; + // hipStreamAddCallback[('hipStream_t', 'stream'), ('hipStreamCallback_t', + // 'callback'), ('void*', 'userData'), ('unsigned int', 'flags')] + case HIP_API_ID_hipStreamAddCallback: + break; + // hipStreamAttachMemAsync[('hipStream_t', 'stream'), ('void*', 'dev_ptr'), + // ('size_t', 'length'), ('unsigned int', 'flags')] + case HIP_API_ID_hipStreamAttachMemAsync: + break; + // hipStreamBeginCapture[('hipStream_t', 'stream'), ('hipStreamCaptureMode', + // 'mode')] + case HIP_API_ID_hipStreamBeginCapture: + break; + // hipStreamBeginCaptureToGraph[('hipStream_t', 'stream'), ('hipGraph_t', + // 'graph'), ('const hipGraphNode_t*', 'dependencies'), ('const + // hipGraphEdgeData*', 'dependencyData'), ('size_t', 'numDependencies'), + // ('hipStreamCaptureMode', 'mode')] + case HIP_API_ID_hipStreamBeginCaptureToGraph: + if (data->args.hipStreamBeginCaptureToGraph.dependencies) + data->args.hipStreamBeginCaptureToGraph.dependencies__val = + *(data->args.hipStreamBeginCaptureToGraph.dependencies); + if (data->args.hipStreamBeginCaptureToGraph.dependencyData) + data->args.hipStreamBeginCaptureToGraph.dependencyData__val = + *(data->args.hipStreamBeginCaptureToGraph.dependencyData); + break; + // hipStreamCreate[('hipStream_t*', 'stream')] + case HIP_API_ID_hipStreamCreate: + if (data->args.hipStreamCreate.stream) + data->args.hipStreamCreate.stream__val = + *(data->args.hipStreamCreate.stream); + break; + // hipStreamCreateWithFlags[('hipStream_t*', 'stream'), ('unsigned int', + // 'flags')] + case HIP_API_ID_hipStreamCreateWithFlags: + if (data->args.hipStreamCreateWithFlags.stream) + data->args.hipStreamCreateWithFlags.stream__val = + *(data->args.hipStreamCreateWithFlags.stream); + break; + // hipStreamCreateWithPriority[('hipStream_t*', 'stream'), ('unsigned int', + // 'flags'), ('int', 'priority')] + case HIP_API_ID_hipStreamCreateWithPriority: + if (data->args.hipStreamCreateWithPriority.stream) + data->args.hipStreamCreateWithPriority.stream__val = + *(data->args.hipStreamCreateWithPriority.stream); + break; + // hipStreamDestroy[('hipStream_t', 'stream')] + case HIP_API_ID_hipStreamDestroy: + break; + // hipStreamEndCapture[('hipStream_t', 'stream'), ('hipGraph_t*', 'pGraph')] + case HIP_API_ID_hipStreamEndCapture: + if (data->args.hipStreamEndCapture.pGraph) + data->args.hipStreamEndCapture.pGraph__val = + *(data->args.hipStreamEndCapture.pGraph); + break; + // hipStreamGetCaptureInfo[('hipStream_t', 'stream'), + // ('hipStreamCaptureStatus*', 'pCaptureStatus'), ('unsigned long long*', + // 'pId')] + case HIP_API_ID_hipStreamGetCaptureInfo: + if (data->args.hipStreamGetCaptureInfo.pCaptureStatus) + data->args.hipStreamGetCaptureInfo.pCaptureStatus__val = + *(data->args.hipStreamGetCaptureInfo.pCaptureStatus); + if (data->args.hipStreamGetCaptureInfo.pId) + data->args.hipStreamGetCaptureInfo.pId__val = + *(data->args.hipStreamGetCaptureInfo.pId); + break; + // hipStreamGetCaptureInfo_v2[('hipStream_t', 'stream'), + // ('hipStreamCaptureStatus*', 'captureStatus_out'), ('unsigned long long*', + // 'id_out'), ('hipGraph_t*', 'graph_out'), ('const hipGraphNode_t**', + // 'dependencies_out'), ('size_t*', 'numDependencies_out')] + case HIP_API_ID_hipStreamGetCaptureInfo_v2: + if (data->args.hipStreamGetCaptureInfo_v2.captureStatus_out) + data->args.hipStreamGetCaptureInfo_v2.captureStatus_out__val = + *(data->args.hipStreamGetCaptureInfo_v2.captureStatus_out); + if (data->args.hipStreamGetCaptureInfo_v2.id_out) + data->args.hipStreamGetCaptureInfo_v2.id_out__val = + *(data->args.hipStreamGetCaptureInfo_v2.id_out); + if (data->args.hipStreamGetCaptureInfo_v2.graph_out) + data->args.hipStreamGetCaptureInfo_v2.graph_out__val = + *(data->args.hipStreamGetCaptureInfo_v2.graph_out); + if (data->args.hipStreamGetCaptureInfo_v2.dependencies_out) + data->args.hipStreamGetCaptureInfo_v2.dependencies_out__val = + *(data->args.hipStreamGetCaptureInfo_v2.dependencies_out); + if (data->args.hipStreamGetCaptureInfo_v2.numDependencies_out) + data->args.hipStreamGetCaptureInfo_v2.numDependencies_out__val = + *(data->args.hipStreamGetCaptureInfo_v2.numDependencies_out); + break; + // hipStreamGetDevice[('hipStream_t', 'stream'), ('hipDevice_t*', 'device')] + case HIP_API_ID_hipStreamGetDevice: + if (data->args.hipStreamGetDevice.device) + data->args.hipStreamGetDevice.device__val = + *(data->args.hipStreamGetDevice.device); + break; + // hipStreamGetFlags[('hipStream_t', 'stream'), ('unsigned int*', 'flags')] + case HIP_API_ID_hipStreamGetFlags: + if (data->args.hipStreamGetFlags.flags) + data->args.hipStreamGetFlags.flags__val = + *(data->args.hipStreamGetFlags.flags); + break; + // hipStreamGetPriority[('hipStream_t', 'stream'), ('int*', 'priority')] + case HIP_API_ID_hipStreamGetPriority: + if (data->args.hipStreamGetPriority.priority) + data->args.hipStreamGetPriority.priority__val = + *(data->args.hipStreamGetPriority.priority); + break; + // hipStreamIsCapturing[('hipStream_t', 'stream'), + // ('hipStreamCaptureStatus*', 'pCaptureStatus')] + case HIP_API_ID_hipStreamIsCapturing: + if (data->args.hipStreamIsCapturing.pCaptureStatus) + data->args.hipStreamIsCapturing.pCaptureStatus__val = + *(data->args.hipStreamIsCapturing.pCaptureStatus); + break; + // hipStreamQuery[('hipStream_t', 'stream')] + case HIP_API_ID_hipStreamQuery: + break; + // hipStreamSynchronize[('hipStream_t', 'stream')] + case HIP_API_ID_hipStreamSynchronize: + break; + // hipStreamUpdateCaptureDependencies[('hipStream_t', 'stream'), + // ('hipGraphNode_t*', 'dependencies'), ('size_t', 'numDependencies'), + // ('unsigned int', 'flags')] + case HIP_API_ID_hipStreamUpdateCaptureDependencies: + if (data->args.hipStreamUpdateCaptureDependencies.dependencies) + data->args.hipStreamUpdateCaptureDependencies.dependencies__val = + *(data->args.hipStreamUpdateCaptureDependencies.dependencies); + break; + // hipStreamWaitEvent[('hipStream_t', 'stream'), ('hipEvent_t', 'event'), + // ('unsigned int', 'flags')] + case HIP_API_ID_hipStreamWaitEvent: + break; + // hipStreamWaitValue32[('hipStream_t', 'stream'), ('void*', 'ptr'), + // ('unsigned int', 'value'), ('unsigned int', 'flags'), ('unsigned int', + // 'mask')] + case HIP_API_ID_hipStreamWaitValue32: + break; + // hipStreamWaitValue64[('hipStream_t', 'stream'), ('void*', 'ptr'), + // ('uint64_t', 'value'), ('unsigned int', 'flags'), ('uint64_t', 'mask')] + case HIP_API_ID_hipStreamWaitValue64: + break; + // hipStreamWriteValue32[('hipStream_t', 'stream'), ('void*', 'ptr'), + // ('unsigned int', 'value'), ('unsigned int', 'flags')] + case HIP_API_ID_hipStreamWriteValue32: + break; + // hipStreamWriteValue64[('hipStream_t', 'stream'), ('void*', 'ptr'), + // ('uint64_t', 'value'), ('unsigned int', 'flags')] + case HIP_API_ID_hipStreamWriteValue64: + break; + // hipTexRefGetAddress[('hipDeviceptr_t*', 'dev_ptr'), ('const + // textureReference*', 'texRef')] + case HIP_API_ID_hipTexRefGetAddress: + if (data->args.hipTexRefGetAddress.dev_ptr) + data->args.hipTexRefGetAddress.dev_ptr__val = + *(data->args.hipTexRefGetAddress.dev_ptr); + if (data->args.hipTexRefGetAddress.texRef) + data->args.hipTexRefGetAddress.texRef__val = + *(data->args.hipTexRefGetAddress.texRef); + break; + // hipTexRefGetArray[('hipArray_t*', 'pArray'), ('const textureReference*', + // 'texRef')] + case HIP_API_ID_hipTexRefGetArray: + if (data->args.hipTexRefGetArray.pArray) + data->args.hipTexRefGetArray.pArray__val = + *(data->args.hipTexRefGetArray.pArray); + if (data->args.hipTexRefGetArray.texRef) + data->args.hipTexRefGetArray.texRef__val = + *(data->args.hipTexRefGetArray.texRef); + break; + // hipTexRefGetBorderColor[('float*', 'pBorderColor'), ('const + // textureReference*', 'texRef')] + case HIP_API_ID_hipTexRefGetBorderColor: + if (data->args.hipTexRefGetBorderColor.pBorderColor) + data->args.hipTexRefGetBorderColor.pBorderColor__val = + *(data->args.hipTexRefGetBorderColor.pBorderColor); + if (data->args.hipTexRefGetBorderColor.texRef) + data->args.hipTexRefGetBorderColor.texRef__val = + *(data->args.hipTexRefGetBorderColor.texRef); + break; + // hipTexRefGetFlags[('unsigned int*', 'pFlags'), ('const + // textureReference*', 'texRef')] + case HIP_API_ID_hipTexRefGetFlags: + if (data->args.hipTexRefGetFlags.pFlags) + data->args.hipTexRefGetFlags.pFlags__val = + *(data->args.hipTexRefGetFlags.pFlags); + if (data->args.hipTexRefGetFlags.texRef) + data->args.hipTexRefGetFlags.texRef__val = + *(data->args.hipTexRefGetFlags.texRef); + break; + // hipTexRefGetFormat[('hipArray_Format*', 'pFormat'), ('int*', + // 'pNumChannels'), ('const textureReference*', 'texRef')] + case HIP_API_ID_hipTexRefGetFormat: + if (data->args.hipTexRefGetFormat.pFormat) + data->args.hipTexRefGetFormat.pFormat__val = + *(data->args.hipTexRefGetFormat.pFormat); + if (data->args.hipTexRefGetFormat.pNumChannels) + data->args.hipTexRefGetFormat.pNumChannels__val = + *(data->args.hipTexRefGetFormat.pNumChannels); + if (data->args.hipTexRefGetFormat.texRef) + data->args.hipTexRefGetFormat.texRef__val = + *(data->args.hipTexRefGetFormat.texRef); + break; + // hipTexRefGetMaxAnisotropy[('int*', 'pmaxAnsio'), ('const + // textureReference*', 'texRef')] + case HIP_API_ID_hipTexRefGetMaxAnisotropy: + if (data->args.hipTexRefGetMaxAnisotropy.pmaxAnsio) + data->args.hipTexRefGetMaxAnisotropy.pmaxAnsio__val = + *(data->args.hipTexRefGetMaxAnisotropy.pmaxAnsio); + if (data->args.hipTexRefGetMaxAnisotropy.texRef) + data->args.hipTexRefGetMaxAnisotropy.texRef__val = + *(data->args.hipTexRefGetMaxAnisotropy.texRef); + break; + // hipTexRefGetMipMappedArray[('hipMipmappedArray_t*', 'pArray'), ('const + // textureReference*', 'texRef')] + case HIP_API_ID_hipTexRefGetMipMappedArray: + if (data->args.hipTexRefGetMipMappedArray.pArray) + data->args.hipTexRefGetMipMappedArray.pArray__val = + *(data->args.hipTexRefGetMipMappedArray.pArray); + if (data->args.hipTexRefGetMipMappedArray.texRef) + data->args.hipTexRefGetMipMappedArray.texRef__val = + *(data->args.hipTexRefGetMipMappedArray.texRef); + break; + // hipTexRefGetMipmapLevelBias[('float*', 'pbias'), ('const + // textureReference*', 'texRef')] + case HIP_API_ID_hipTexRefGetMipmapLevelBias: + if (data->args.hipTexRefGetMipmapLevelBias.pbias) + data->args.hipTexRefGetMipmapLevelBias.pbias__val = + *(data->args.hipTexRefGetMipmapLevelBias.pbias); + if (data->args.hipTexRefGetMipmapLevelBias.texRef) + data->args.hipTexRefGetMipmapLevelBias.texRef__val = + *(data->args.hipTexRefGetMipmapLevelBias.texRef); + break; + // hipTexRefGetMipmapLevelClamp[('float*', 'pminMipmapLevelClamp'), + // ('float*', 'pmaxMipmapLevelClamp'), ('const textureReference*', + // 'texRef')] + case HIP_API_ID_hipTexRefGetMipmapLevelClamp: + if (data->args.hipTexRefGetMipmapLevelClamp.pminMipmapLevelClamp) + data->args.hipTexRefGetMipmapLevelClamp.pminMipmapLevelClamp__val = + *(data->args.hipTexRefGetMipmapLevelClamp.pminMipmapLevelClamp); + if (data->args.hipTexRefGetMipmapLevelClamp.pmaxMipmapLevelClamp) + data->args.hipTexRefGetMipmapLevelClamp.pmaxMipmapLevelClamp__val = + *(data->args.hipTexRefGetMipmapLevelClamp.pmaxMipmapLevelClamp); + if (data->args.hipTexRefGetMipmapLevelClamp.texRef) + data->args.hipTexRefGetMipmapLevelClamp.texRef__val = + *(data->args.hipTexRefGetMipmapLevelClamp.texRef); + break; + // hipTexRefSetAddress[('size_t*', 'ByteOffset'), ('textureReference*', + // 'texRef'), ('hipDeviceptr_t', 'dptr'), ('size_t', 'bytes')] + case HIP_API_ID_hipTexRefSetAddress: + if (data->args.hipTexRefSetAddress.ByteOffset) + data->args.hipTexRefSetAddress.ByteOffset__val = + *(data->args.hipTexRefSetAddress.ByteOffset); + if (data->args.hipTexRefSetAddress.texRef) + data->args.hipTexRefSetAddress.texRef__val = + *(data->args.hipTexRefSetAddress.texRef); + break; + // hipTexRefSetAddress2D[('textureReference*', 'texRef'), ('const + // HIP_ARRAY_DESCRIPTOR*', 'desc'), ('hipDeviceptr_t', 'dptr'), ('size_t', + // 'Pitch')] + case HIP_API_ID_hipTexRefSetAddress2D: + if (data->args.hipTexRefSetAddress2D.texRef) + data->args.hipTexRefSetAddress2D.texRef__val = + *(data->args.hipTexRefSetAddress2D.texRef); + if (data->args.hipTexRefSetAddress2D.desc) + data->args.hipTexRefSetAddress2D.desc__val = + *(data->args.hipTexRefSetAddress2D.desc); + break; + // hipTexRefSetArray[('textureReference*', 'tex'), ('hipArray_const_t', + // 'array'), ('unsigned int', 'flags')] + case HIP_API_ID_hipTexRefSetArray: + if (data->args.hipTexRefSetArray.tex) + data->args.hipTexRefSetArray.tex__val = + *(data->args.hipTexRefSetArray.tex); + break; + // hipTexRefSetBorderColor[('textureReference*', 'texRef'), ('float*', + // 'pBorderColor')] + case HIP_API_ID_hipTexRefSetBorderColor: + if (data->args.hipTexRefSetBorderColor.texRef) + data->args.hipTexRefSetBorderColor.texRef__val = + *(data->args.hipTexRefSetBorderColor.texRef); + if (data->args.hipTexRefSetBorderColor.pBorderColor) + data->args.hipTexRefSetBorderColor.pBorderColor__val = + *(data->args.hipTexRefSetBorderColor.pBorderColor); + break; + // hipTexRefSetFlags[('textureReference*', 'texRef'), ('unsigned int', + // 'Flags')] + case HIP_API_ID_hipTexRefSetFlags: + if (data->args.hipTexRefSetFlags.texRef) + data->args.hipTexRefSetFlags.texRef__val = + *(data->args.hipTexRefSetFlags.texRef); + break; + // hipTexRefSetFormat[('textureReference*', 'texRef'), ('hipArray_Format', + // 'fmt'), ('int', 'NumPackedComponents')] + case HIP_API_ID_hipTexRefSetFormat: + if (data->args.hipTexRefSetFormat.texRef) + data->args.hipTexRefSetFormat.texRef__val = + *(data->args.hipTexRefSetFormat.texRef); + break; + // hipTexRefSetMaxAnisotropy[('textureReference*', 'texRef'), ('unsigned + // int', 'maxAniso')] + case HIP_API_ID_hipTexRefSetMaxAnisotropy: + if (data->args.hipTexRefSetMaxAnisotropy.texRef) + data->args.hipTexRefSetMaxAnisotropy.texRef__val = + *(data->args.hipTexRefSetMaxAnisotropy.texRef); + break; + // hipTexRefSetMipmapLevelBias[('textureReference*', 'texRef'), ('float', + // 'bias')] + case HIP_API_ID_hipTexRefSetMipmapLevelBias: + if (data->args.hipTexRefSetMipmapLevelBias.texRef) + data->args.hipTexRefSetMipmapLevelBias.texRef__val = + *(data->args.hipTexRefSetMipmapLevelBias.texRef); + break; + // hipTexRefSetMipmapLevelClamp[('textureReference*', 'texRef'), ('float', + // 'minMipMapLevelClamp'), ('float', 'maxMipMapLevelClamp')] + case HIP_API_ID_hipTexRefSetMipmapLevelClamp: + if (data->args.hipTexRefSetMipmapLevelClamp.texRef) + data->args.hipTexRefSetMipmapLevelClamp.texRef__val = + *(data->args.hipTexRefSetMipmapLevelClamp.texRef); + break; + // hipTexRefSetMipmappedArray[('textureReference*', 'texRef'), + // ('hipMipmappedArray*', 'mipmappedArray'), ('unsigned int', 'Flags')] + case HIP_API_ID_hipTexRefSetMipmappedArray: + if (data->args.hipTexRefSetMipmappedArray.texRef) + data->args.hipTexRefSetMipmappedArray.texRef__val = + *(data->args.hipTexRefSetMipmappedArray.texRef); + if (data->args.hipTexRefSetMipmappedArray.mipmappedArray) + data->args.hipTexRefSetMipmappedArray.mipmappedArray__val = + *(data->args.hipTexRefSetMipmappedArray.mipmappedArray); + break; + // hipThreadExchangeStreamCaptureMode[('hipStreamCaptureMode*', 'mode')] + case HIP_API_ID_hipThreadExchangeStreamCaptureMode: + if (data->args.hipThreadExchangeStreamCaptureMode.mode) + data->args.hipThreadExchangeStreamCaptureMode.mode__val = + *(data->args.hipThreadExchangeStreamCaptureMode.mode); + break; + // hipUserObjectCreate[('hipUserObject_t*', 'object_out'), ('void*', 'ptr'), + // ('hipHostFn_t', 'destroy'), ('unsigned int', 'initialRefcount'), + // ('unsigned int', 'flags')] + case HIP_API_ID_hipUserObjectCreate: + if (data->args.hipUserObjectCreate.object_out) + data->args.hipUserObjectCreate.object_out__val = + *(data->args.hipUserObjectCreate.object_out); + break; + // hipUserObjectRelease[('hipUserObject_t', 'object'), ('unsigned int', + // 'count')] + case HIP_API_ID_hipUserObjectRelease: + break; + // hipUserObjectRetain[('hipUserObject_t', 'object'), ('unsigned int', + // 'count')] + case HIP_API_ID_hipUserObjectRetain: + break; + // hipWaitExternalSemaphoresAsync[('const hipExternalSemaphore_t*', + // 'extSemArray'), ('const hipExternalSemaphoreWaitParams*', 'paramsArray'), + // ('unsigned int', 'numExtSems'), ('hipStream_t', 'stream')] + case HIP_API_ID_hipWaitExternalSemaphoresAsync: + if (data->args.hipWaitExternalSemaphoresAsync.extSemArray) + data->args.hipWaitExternalSemaphoresAsync.extSemArray__val = + *(data->args.hipWaitExternalSemaphoresAsync.extSemArray); + if (data->args.hipWaitExternalSemaphoresAsync.paramsArray) + data->args.hipWaitExternalSemaphoresAsync.paramsArray__val = + *(data->args.hipWaitExternalSemaphoresAsync.paramsArray); + break; + default: + break; + }; +} + +#include +#include +// HIP API string method, method name and parameters +static inline const char *hipApiString(hip_api_id_t id, + const hip_api_data_t *data) { + std::ostringstream oss; + switch (id) { + case HIP_API_ID___hipPopCallConfiguration: + oss << "__hipPopCallConfiguration("; + if (data->args.__hipPopCallConfiguration.gridDim == NULL) + oss << "gridDim=NULL"; + else { + oss << "gridDim="; + roctracer::hip_support::detail::operator<<( + oss, data->args.__hipPopCallConfiguration.gridDim__val); + } + if (data->args.__hipPopCallConfiguration.blockDim == NULL) + oss << ", blockDim=NULL"; + else { + oss << ", blockDim="; + roctracer::hip_support::detail::operator<<( + oss, data->args.__hipPopCallConfiguration.blockDim__val); + } + if (data->args.__hipPopCallConfiguration.sharedMem == NULL) + oss << ", sharedMem=NULL"; + else { + oss << ", sharedMem="; + roctracer::hip_support::detail::operator<<( + oss, data->args.__hipPopCallConfiguration.sharedMem__val); + } + if (data->args.__hipPopCallConfiguration.stream == NULL) + oss << ", stream=NULL"; + else { + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.__hipPopCallConfiguration.stream__val); + } + oss << ")"; + break; + case HIP_API_ID___hipPushCallConfiguration: + oss << "__hipPushCallConfiguration("; + oss << "gridDim="; + roctracer::hip_support::detail::operator<<( + oss, data->args.__hipPushCallConfiguration.gridDim); + oss << ", blockDim="; + roctracer::hip_support::detail::operator<<( + oss, data->args.__hipPushCallConfiguration.blockDim); + oss << ", sharedMem="; + roctracer::hip_support::detail::operator<<( + oss, data->args.__hipPushCallConfiguration.sharedMem); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.__hipPushCallConfiguration.stream); + oss << ")"; + break; + case HIP_API_ID_hipArray3DCreate: + oss << "hipArray3DCreate("; + if (data->args.hipArray3DCreate.array == NULL) + oss << "array=NULL"; + else { + oss << "array="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArray3DCreate.array__val); + } + if (data->args.hipArray3DCreate.pAllocateArray == NULL) + oss << ", pAllocateArray=NULL"; + else { + oss << ", pAllocateArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArray3DCreate.pAllocateArray__val); + } + oss << ")"; + break; + case HIP_API_ID_hipArray3DGetDescriptor: + oss << "hipArray3DGetDescriptor("; + if (data->args.hipArray3DGetDescriptor.pArrayDescriptor == NULL) + oss << "pArrayDescriptor=NULL"; + else { + oss << "pArrayDescriptor="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArray3DGetDescriptor.pArrayDescriptor__val); + } + oss << ", array="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArray3DGetDescriptor.array); + oss << ")"; + break; + case HIP_API_ID_hipArrayCreate: + oss << "hipArrayCreate("; + if (data->args.hipArrayCreate.pHandle == NULL) + oss << "pHandle=NULL"; + else { + oss << "pHandle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArrayCreate.pHandle__val); + } + if (data->args.hipArrayCreate.pAllocateArray == NULL) + oss << ", pAllocateArray=NULL"; + else { + oss << ", pAllocateArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArrayCreate.pAllocateArray__val); + } + oss << ")"; + break; + case HIP_API_ID_hipArrayDestroy: + oss << "hipArrayDestroy("; + oss << "array="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArrayDestroy.array); + oss << ")"; + break; + case HIP_API_ID_hipArrayGetDescriptor: + oss << "hipArrayGetDescriptor("; + if (data->args.hipArrayGetDescriptor.pArrayDescriptor == NULL) + oss << "pArrayDescriptor=NULL"; + else { + oss << "pArrayDescriptor="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArrayGetDescriptor.pArrayDescriptor__val); + } + oss << ", array="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArrayGetDescriptor.array); + oss << ")"; + break; + case HIP_API_ID_hipArrayGetInfo: + oss << "hipArrayGetInfo("; + if (data->args.hipArrayGetInfo.desc == NULL) + oss << "desc=NULL"; + else { + oss << "desc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArrayGetInfo.desc__val); + } + if (data->args.hipArrayGetInfo.extent == NULL) + oss << ", extent=NULL"; + else { + oss << ", extent="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArrayGetInfo.extent__val); + } + if (data->args.hipArrayGetInfo.flags == NULL) + oss << ", flags=NULL"; + else { + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArrayGetInfo.flags__val); + } + oss << ", array="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipArrayGetInfo.array); + oss << ")"; + break; + case HIP_API_ID_hipChooseDeviceR0000: + oss << "hipChooseDeviceR0000("; + if (data->args.hipChooseDeviceR0000.device == NULL) + oss << "device=NULL"; + else { + oss << "device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipChooseDeviceR0000.device__val); + } + if (data->args.hipChooseDeviceR0000.prop == NULL) + oss << ", prop=NULL"; + else { + oss << ", prop="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipChooseDeviceR0000.prop__val); + } + oss << ")"; + break; + case HIP_API_ID_hipChooseDeviceR0600: + oss << "hipChooseDeviceR0600("; + if (data->args.hipChooseDeviceR0600.device == NULL) + oss << "device=NULL"; + else { + oss << "device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipChooseDeviceR0600.device__val); + } + if (data->args.hipChooseDeviceR0600.prop == NULL) + oss << ", prop=NULL"; + else { + oss << ", prop="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipChooseDeviceR0600.prop__val); + } + oss << ")"; + break; + case HIP_API_ID_hipConfigureCall: + oss << "hipConfigureCall("; + oss << "gridDim="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipConfigureCall.gridDim); + oss << ", blockDim="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipConfigureCall.blockDim); + oss << ", sharedMem="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipConfigureCall.sharedMem); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipConfigureCall.stream); + oss << ")"; + break; + case HIP_API_ID_hipCreateSurfaceObject: + oss << "hipCreateSurfaceObject("; + if (data->args.hipCreateSurfaceObject.pSurfObject == NULL) + oss << "pSurfObject=NULL"; + else { + oss << "pSurfObject="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCreateSurfaceObject.pSurfObject__val); + } + if (data->args.hipCreateSurfaceObject.pResDesc == NULL) + oss << ", pResDesc=NULL"; + else { + oss << ", pResDesc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCreateSurfaceObject.pResDesc__val); + } + oss << ")"; + break; + case HIP_API_ID_hipCtxCreate: + oss << "hipCtxCreate("; + if (data->args.hipCtxCreate.ctx == NULL) + oss << "ctx=NULL"; + else { + oss << "ctx="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxCreate.ctx__val); + } + oss << ", flags="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipCtxCreate.flags); + oss << ", device="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipCtxCreate.device); + oss << ")"; + break; + case HIP_API_ID_hipCtxDestroy: + oss << "hipCtxDestroy("; + oss << "ctx="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipCtxDestroy.ctx); + oss << ")"; + break; + case HIP_API_ID_hipCtxDisablePeerAccess: + oss << "hipCtxDisablePeerAccess("; + oss << "peerCtx="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxDisablePeerAccess.peerCtx); + oss << ")"; + break; + case HIP_API_ID_hipCtxEnablePeerAccess: + oss << "hipCtxEnablePeerAccess("; + oss << "peerCtx="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxEnablePeerAccess.peerCtx); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxEnablePeerAccess.flags); + oss << ")"; + break; + case HIP_API_ID_hipCtxGetApiVersion: + oss << "hipCtxGetApiVersion("; + oss << "ctx="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxGetApiVersion.ctx); + if (data->args.hipCtxGetApiVersion.apiVersion == NULL) + oss << ", apiVersion=NULL"; + else { + oss << ", apiVersion="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxGetApiVersion.apiVersion__val); + } + oss << ")"; + break; + case HIP_API_ID_hipCtxGetCacheConfig: + oss << "hipCtxGetCacheConfig("; + if (data->args.hipCtxGetCacheConfig.cacheConfig == NULL) + oss << "cacheConfig=NULL"; + else { + oss << "cacheConfig="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxGetCacheConfig.cacheConfig__val); + } + oss << ")"; + break; + case HIP_API_ID_hipCtxGetCurrent: + oss << "hipCtxGetCurrent("; + if (data->args.hipCtxGetCurrent.ctx == NULL) + oss << "ctx=NULL"; + else { + oss << "ctx="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxGetCurrent.ctx__val); + } + oss << ")"; + break; + case HIP_API_ID_hipCtxGetDevice: + oss << "hipCtxGetDevice("; + if (data->args.hipCtxGetDevice.device == NULL) + oss << "device=NULL"; + else { + oss << "device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxGetDevice.device__val); + } + oss << ")"; + break; + case HIP_API_ID_hipCtxGetFlags: + oss << "hipCtxGetFlags("; + if (data->args.hipCtxGetFlags.flags == NULL) + oss << "flags=NULL"; + else { + oss << "flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxGetFlags.flags__val); + } + oss << ")"; + break; + case HIP_API_ID_hipCtxGetSharedMemConfig: + oss << "hipCtxGetSharedMemConfig("; + if (data->args.hipCtxGetSharedMemConfig.pConfig == NULL) + oss << "pConfig=NULL"; + else { + oss << "pConfig="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxGetSharedMemConfig.pConfig__val); + } + oss << ")"; + break; + case HIP_API_ID_hipCtxPopCurrent: + oss << "hipCtxPopCurrent("; + if (data->args.hipCtxPopCurrent.ctx == NULL) + oss << "ctx=NULL"; + else { + oss << "ctx="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxPopCurrent.ctx__val); + } + oss << ")"; + break; + case HIP_API_ID_hipCtxPushCurrent: + oss << "hipCtxPushCurrent("; + oss << "ctx="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxPushCurrent.ctx); + oss << ")"; + break; + case HIP_API_ID_hipCtxSetCacheConfig: + oss << "hipCtxSetCacheConfig("; + oss << "cacheConfig="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxSetCacheConfig.cacheConfig); + oss << ")"; + break; + case HIP_API_ID_hipCtxSetCurrent: + oss << "hipCtxSetCurrent("; + oss << "ctx="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipCtxSetCurrent.ctx); + oss << ")"; + break; + case HIP_API_ID_hipCtxSetSharedMemConfig: + oss << "hipCtxSetSharedMemConfig("; + oss << "config="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipCtxSetSharedMemConfig.config); + oss << ")"; + break; + case HIP_API_ID_hipCtxSynchronize: + oss << "hipCtxSynchronize("; + oss << ")"; + break; + case HIP_API_ID_hipDestroyExternalMemory: + oss << "hipDestroyExternalMemory("; + oss << "extMem="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDestroyExternalMemory.extMem); + oss << ")"; + break; + case HIP_API_ID_hipDestroyExternalSemaphore: + oss << "hipDestroyExternalSemaphore("; + oss << "extSem="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDestroyExternalSemaphore.extSem); + oss << ")"; + break; + case HIP_API_ID_hipDestroySurfaceObject: + oss << "hipDestroySurfaceObject("; + oss << "surfaceObject="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDestroySurfaceObject.surfaceObject); + oss << ")"; + break; + case HIP_API_ID_hipDeviceCanAccessPeer: + oss << "hipDeviceCanAccessPeer("; + if (data->args.hipDeviceCanAccessPeer.canAccessPeer == NULL) + oss << "canAccessPeer=NULL"; + else { + oss << "canAccessPeer="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceCanAccessPeer.canAccessPeer__val); + } + oss << ", deviceId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceCanAccessPeer.deviceId); + oss << ", peerDeviceId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceCanAccessPeer.peerDeviceId); + oss << ")"; + break; + case HIP_API_ID_hipDeviceComputeCapability: + oss << "hipDeviceComputeCapability("; + if (data->args.hipDeviceComputeCapability.major == NULL) + oss << "major=NULL"; + else { + oss << "major="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceComputeCapability.major__val); + } + if (data->args.hipDeviceComputeCapability.minor == NULL) + oss << ", minor=NULL"; + else { + oss << ", minor="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceComputeCapability.minor__val); + } + oss << ", device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceComputeCapability.device); + oss << ")"; + break; + case HIP_API_ID_hipDeviceDisablePeerAccess: + oss << "hipDeviceDisablePeerAccess("; + oss << "peerDeviceId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceDisablePeerAccess.peerDeviceId); + oss << ")"; + break; + case HIP_API_ID_hipDeviceEnablePeerAccess: + oss << "hipDeviceEnablePeerAccess("; + oss << "peerDeviceId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceEnablePeerAccess.peerDeviceId); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceEnablePeerAccess.flags); + oss << ")"; + break; + case HIP_API_ID_hipDeviceGet: + oss << "hipDeviceGet("; + if (data->args.hipDeviceGet.device == NULL) + oss << "device=NULL"; + else { + oss << "device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGet.device__val); + } + oss << ", ordinal="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipDeviceGet.ordinal); + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetAttribute: + oss << "hipDeviceGetAttribute("; + if (data->args.hipDeviceGetAttribute.pi == NULL) + oss << "pi=NULL"; + else { + oss << "pi="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetAttribute.pi__val); + } + oss << ", attr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetAttribute.attr); + oss << ", deviceId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetAttribute.deviceId); + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetByPCIBusId: + oss << "hipDeviceGetByPCIBusId("; + if (data->args.hipDeviceGetByPCIBusId.device == NULL) + oss << "device=NULL"; + else { + oss << "device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetByPCIBusId.device__val); + } + if (data->args.hipDeviceGetByPCIBusId.pciBusId == NULL) + oss << ", pciBusId=NULL"; + else { + oss << ", pciBusId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetByPCIBusId.pciBusId__val); + } + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetCacheConfig: + oss << "hipDeviceGetCacheConfig("; + if (data->args.hipDeviceGetCacheConfig.cacheConfig == NULL) + oss << "cacheConfig=NULL"; + else { + oss << "cacheConfig="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetCacheConfig.cacheConfig__val); + } + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetDefaultMemPool: + oss << "hipDeviceGetDefaultMemPool("; + if (data->args.hipDeviceGetDefaultMemPool.mem_pool == NULL) + oss << "mem_pool=NULL"; + else { + oss << "mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetDefaultMemPool.mem_pool__val); + } + oss << ", device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetDefaultMemPool.device); + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetGraphMemAttribute: + oss << "hipDeviceGetGraphMemAttribute("; + oss << "device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetGraphMemAttribute.device); + oss << ", attr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetGraphMemAttribute.attr); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetGraphMemAttribute.value); + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetLimit: + oss << "hipDeviceGetLimit("; + if (data->args.hipDeviceGetLimit.pValue == NULL) + oss << "pValue=NULL"; + else { + oss << "pValue="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetLimit.pValue__val); + } + oss << ", limit="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetLimit.limit); + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetMemPool: + oss << "hipDeviceGetMemPool("; + if (data->args.hipDeviceGetMemPool.mem_pool == NULL) + oss << "mem_pool=NULL"; + else { + oss << "mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetMemPool.mem_pool__val); + } + oss << ", device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetMemPool.device); + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetName: + oss << "hipDeviceGetName("; + if (data->args.hipDeviceGetName.name == NULL) + oss << "name=NULL"; + else { + oss << "name="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetName.name__val); + } + oss << ", len="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipDeviceGetName.len); + oss << ", device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetName.device); + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetP2PAttribute: + oss << "hipDeviceGetP2PAttribute("; + if (data->args.hipDeviceGetP2PAttribute.value == NULL) + oss << "value=NULL"; + else { + oss << "value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetP2PAttribute.value__val); + } + oss << ", attr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetP2PAttribute.attr); + oss << ", srcDevice="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetP2PAttribute.srcDevice); + oss << ", dstDevice="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetP2PAttribute.dstDevice); + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetPCIBusId: + oss << "hipDeviceGetPCIBusId("; + if (data->args.hipDeviceGetPCIBusId.pciBusId == NULL) + oss << "pciBusId=NULL"; + else { + oss << "pciBusId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetPCIBusId.pciBusId__val); + } + oss << ", len="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetPCIBusId.len); + oss << ", device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetPCIBusId.device); + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetSharedMemConfig: + oss << "hipDeviceGetSharedMemConfig("; + if (data->args.hipDeviceGetSharedMemConfig.pConfig == NULL) + oss << "pConfig=NULL"; + else { + oss << "pConfig="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetSharedMemConfig.pConfig__val); + } + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetStreamPriorityRange: + oss << "hipDeviceGetStreamPriorityRange("; + if (data->args.hipDeviceGetStreamPriorityRange.leastPriority == NULL) + oss << "leastPriority=NULL"; + else { + oss << "leastPriority="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetStreamPriorityRange.leastPriority__val); + } + if (data->args.hipDeviceGetStreamPriorityRange.greatestPriority == NULL) + oss << ", greatestPriority=NULL"; + else { + oss << ", greatestPriority="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipDeviceGetStreamPriorityRange.greatestPriority__val); + } + oss << ")"; + break; + case HIP_API_ID_hipDeviceGetUuid: + oss << "hipDeviceGetUuid("; + if (data->args.hipDeviceGetUuid.uuid == NULL) + oss << "uuid=NULL"; + else { + oss << "uuid="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetUuid.uuid__val); + } + oss << ", device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGetUuid.device); + oss << ")"; + break; + case HIP_API_ID_hipDeviceGraphMemTrim: + oss << "hipDeviceGraphMemTrim("; + oss << "device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceGraphMemTrim.device); + oss << ")"; + break; + case HIP_API_ID_hipDevicePrimaryCtxGetState: + oss << "hipDevicePrimaryCtxGetState("; + oss << "dev="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDevicePrimaryCtxGetState.dev); + if (data->args.hipDevicePrimaryCtxGetState.flags == NULL) + oss << ", flags=NULL"; + else { + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDevicePrimaryCtxGetState.flags__val); + } + if (data->args.hipDevicePrimaryCtxGetState.active == NULL) + oss << ", active=NULL"; + else { + oss << ", active="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDevicePrimaryCtxGetState.active__val); + } + oss << ")"; + break; + case HIP_API_ID_hipDevicePrimaryCtxRelease: + oss << "hipDevicePrimaryCtxRelease("; + oss << "dev="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDevicePrimaryCtxRelease.dev); + oss << ")"; + break; + case HIP_API_ID_hipDevicePrimaryCtxReset: + oss << "hipDevicePrimaryCtxReset("; + oss << "dev="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDevicePrimaryCtxReset.dev); + oss << ")"; + break; + case HIP_API_ID_hipDevicePrimaryCtxRetain: + oss << "hipDevicePrimaryCtxRetain("; + if (data->args.hipDevicePrimaryCtxRetain.pctx == NULL) + oss << "pctx=NULL"; + else { + oss << "pctx="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDevicePrimaryCtxRetain.pctx__val); + } + oss << ", dev="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDevicePrimaryCtxRetain.dev); + oss << ")"; + break; + case HIP_API_ID_hipDevicePrimaryCtxSetFlags: + oss << "hipDevicePrimaryCtxSetFlags("; + oss << "dev="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDevicePrimaryCtxSetFlags.dev); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDevicePrimaryCtxSetFlags.flags); + oss << ")"; + break; + case HIP_API_ID_hipDeviceReset: + oss << "hipDeviceReset("; + oss << ")"; + break; + case HIP_API_ID_hipDeviceSetCacheConfig: + oss << "hipDeviceSetCacheConfig("; + oss << "cacheConfig="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceSetCacheConfig.cacheConfig); + oss << ")"; + break; + case HIP_API_ID_hipDeviceSetGraphMemAttribute: + oss << "hipDeviceSetGraphMemAttribute("; + oss << "device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceSetGraphMemAttribute.device); + oss << ", attr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceSetGraphMemAttribute.attr); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceSetGraphMemAttribute.value); + oss << ")"; + break; + case HIP_API_ID_hipDeviceSetLimit: + oss << "hipDeviceSetLimit("; + oss << "limit="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceSetLimit.limit); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceSetLimit.value); + oss << ")"; + break; + case HIP_API_ID_hipDeviceSetMemPool: + oss << "hipDeviceSetMemPool("; + oss << "device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceSetMemPool.device); + oss << ", mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceSetMemPool.mem_pool); + oss << ")"; + break; + case HIP_API_ID_hipDeviceSetSharedMemConfig: + oss << "hipDeviceSetSharedMemConfig("; + oss << "config="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceSetSharedMemConfig.config); + oss << ")"; + break; + case HIP_API_ID_hipDeviceSynchronize: + oss << "hipDeviceSynchronize("; + oss << ")"; + break; + case HIP_API_ID_hipDeviceTotalMem: + oss << "hipDeviceTotalMem("; + if (data->args.hipDeviceTotalMem.bytes == NULL) + oss << "bytes=NULL"; + else { + oss << "bytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceTotalMem.bytes__val); + } + oss << ", device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDeviceTotalMem.device); + oss << ")"; + break; + case HIP_API_ID_hipDriverGetVersion: + oss << "hipDriverGetVersion("; + if (data->args.hipDriverGetVersion.driverVersion == NULL) + oss << "driverVersion=NULL"; + else { + oss << "driverVersion="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDriverGetVersion.driverVersion__val); + } + oss << ")"; + break; + case HIP_API_ID_hipDrvGraphAddMemcpyNode: + oss << "hipDrvGraphAddMemcpyNode("; + if (data->args.hipDrvGraphAddMemcpyNode.phGraphNode == NULL) + oss << "phGraphNode=NULL"; + else { + oss << "phGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemcpyNode.phGraphNode__val); + } + oss << ", hGraph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemcpyNode.hGraph); + if (data->args.hipDrvGraphAddMemcpyNode.dependencies == NULL) + oss << ", dependencies=NULL"; + else { + oss << ", dependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemcpyNode.dependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemcpyNode.numDependencies); + if (data->args.hipDrvGraphAddMemcpyNode.copyParams == NULL) + oss << ", copyParams=NULL"; + else { + oss << ", copyParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemcpyNode.copyParams__val); + } + oss << ", ctx="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemcpyNode.ctx); + oss << ")"; + break; + case HIP_API_ID_hipDrvGraphAddMemsetNode: + oss << "hipDrvGraphAddMemsetNode("; + if (data->args.hipDrvGraphAddMemsetNode.phGraphNode == NULL) + oss << "phGraphNode=NULL"; + else { + oss << "phGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemsetNode.phGraphNode__val); + } + oss << ", hGraph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemsetNode.hGraph); + if (data->args.hipDrvGraphAddMemsetNode.dependencies == NULL) + oss << ", dependencies=NULL"; + else { + oss << ", dependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemsetNode.dependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemsetNode.numDependencies); + if (data->args.hipDrvGraphAddMemsetNode.memsetParams == NULL) + oss << ", memsetParams=NULL"; + else { + oss << ", memsetParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemsetNode.memsetParams__val); + } + oss << ", ctx="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvGraphAddMemsetNode.ctx); + oss << ")"; + break; + case HIP_API_ID_hipDrvMemcpy2DUnaligned: + oss << "hipDrvMemcpy2DUnaligned("; + if (data->args.hipDrvMemcpy2DUnaligned.pCopy == NULL) + oss << "pCopy=NULL"; + else { + oss << "pCopy="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvMemcpy2DUnaligned.pCopy__val); + } + oss << ")"; + break; + case HIP_API_ID_hipDrvMemcpy3D: + oss << "hipDrvMemcpy3D("; + if (data->args.hipDrvMemcpy3D.pCopy == NULL) + oss << "pCopy=NULL"; + else { + oss << "pCopy="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvMemcpy3D.pCopy__val); + } + oss << ")"; + break; + case HIP_API_ID_hipDrvMemcpy3DAsync: + oss << "hipDrvMemcpy3DAsync("; + if (data->args.hipDrvMemcpy3DAsync.pCopy == NULL) + oss << "pCopy=NULL"; + else { + oss << "pCopy="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvMemcpy3DAsync.pCopy__val); + } + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvMemcpy3DAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipDrvPointerGetAttributes: + oss << "hipDrvPointerGetAttributes("; + oss << "numAttributes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvPointerGetAttributes.numAttributes); + if (data->args.hipDrvPointerGetAttributes.attributes == NULL) + oss << ", attributes=NULL"; + else { + oss << ", attributes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvPointerGetAttributes.attributes__val); + } + if (data->args.hipDrvPointerGetAttributes.data == NULL) + oss << ", data=NULL"; + else { + oss << ", data="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvPointerGetAttributes.data__val); + } + oss << ", ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipDrvPointerGetAttributes.ptr); + oss << ")"; + break; + case HIP_API_ID_hipEventCreate: + oss << "hipEventCreate("; + if (data->args.hipEventCreate.event == NULL) + oss << "event=NULL"; + else { + oss << "event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipEventCreate.event__val); + } + oss << ")"; + break; + case HIP_API_ID_hipEventCreateWithFlags: + oss << "hipEventCreateWithFlags("; + if (data->args.hipEventCreateWithFlags.event == NULL) + oss << "event=NULL"; + else { + oss << "event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipEventCreateWithFlags.event__val); + } + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipEventCreateWithFlags.flags); + oss << ")"; + break; + case HIP_API_ID_hipEventDestroy: + oss << "hipEventDestroy("; + oss << "event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipEventDestroy.event); + oss << ")"; + break; + case HIP_API_ID_hipEventElapsedTime: + oss << "hipEventElapsedTime("; + if (data->args.hipEventElapsedTime.ms == NULL) + oss << "ms=NULL"; + else { + oss << "ms="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipEventElapsedTime.ms__val); + } + oss << ", start="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipEventElapsedTime.start); + oss << ", stop="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipEventElapsedTime.stop); + oss << ")"; + break; + case HIP_API_ID_hipEventQuery: + oss << "hipEventQuery("; + oss << "event="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipEventQuery.event); + oss << ")"; + break; + case HIP_API_ID_hipEventRecord: + oss << "hipEventRecord("; + oss << "event="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipEventRecord.event); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipEventRecord.stream); + oss << ")"; + break; + case HIP_API_ID_hipEventSynchronize: + oss << "hipEventSynchronize("; + oss << "event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipEventSynchronize.event); + oss << ")"; + break; + case HIP_API_ID_hipExtGetLastError: + oss << "hipExtGetLastError("; + oss << ")"; + break; + case HIP_API_ID_hipExtGetLinkTypeAndHopCount: + oss << "hipExtGetLinkTypeAndHopCount("; + oss << "device1="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtGetLinkTypeAndHopCount.device1); + oss << ", device2="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtGetLinkTypeAndHopCount.device2); + if (data->args.hipExtGetLinkTypeAndHopCount.linktype == NULL) + oss << ", linktype=NULL"; + else { + oss << ", linktype="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtGetLinkTypeAndHopCount.linktype__val); + } + if (data->args.hipExtGetLinkTypeAndHopCount.hopcount == NULL) + oss << ", hopcount=NULL"; + else { + oss << ", hopcount="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtGetLinkTypeAndHopCount.hopcount__val); + } + oss << ")"; + break; + case HIP_API_ID_hipExtLaunchKernel: + oss << "hipExtLaunchKernel("; + oss << "function_address="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtLaunchKernel.function_address); + oss << ", numBlocks="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtLaunchKernel.numBlocks); + oss << ", dimBlocks="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtLaunchKernel.dimBlocks); + if (data->args.hipExtLaunchKernel.args == NULL) + oss << ", args=NULL"; + else { + oss << ", args="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtLaunchKernel.args__val); + } + oss << ", sharedMemBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtLaunchKernel.sharedMemBytes); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtLaunchKernel.stream); + oss << ", startEvent="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtLaunchKernel.startEvent); + oss << ", stopEvent="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtLaunchKernel.stopEvent); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtLaunchKernel.flags); + oss << ")"; + break; + case HIP_API_ID_hipExtLaunchMultiKernelMultiDevice: + oss << "hipExtLaunchMultiKernelMultiDevice("; + if (data->args.hipExtLaunchMultiKernelMultiDevice.launchParamsList == NULL) + oss << "launchParamsList=NULL"; + else { + oss << "launchParamsList="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipExtLaunchMultiKernelMultiDevice.launchParamsList__val); + } + oss << ", numDevices="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtLaunchMultiKernelMultiDevice.numDevices); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtLaunchMultiKernelMultiDevice.flags); + oss << ")"; + break; + case HIP_API_ID_hipExtMallocWithFlags: + oss << "hipExtMallocWithFlags("; + if (data->args.hipExtMallocWithFlags.ptr == NULL) + oss << "ptr=NULL"; + else { + oss << "ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtMallocWithFlags.ptr__val); + } + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtMallocWithFlags.sizeBytes); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtMallocWithFlags.flags); + oss << ")"; + break; + case HIP_API_ID_hipExtModuleLaunchKernel: + oss << "hipExtModuleLaunchKernel("; + oss << "f="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.f); + oss << ", globalWorkSizeX="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.globalWorkSizeX); + oss << ", globalWorkSizeY="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.globalWorkSizeY); + oss << ", globalWorkSizeZ="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.globalWorkSizeZ); + oss << ", localWorkSizeX="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.localWorkSizeX); + oss << ", localWorkSizeY="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.localWorkSizeY); + oss << ", localWorkSizeZ="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.localWorkSizeZ); + oss << ", sharedMemBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.sharedMemBytes); + oss << ", hStream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.hStream); + if (data->args.hipExtModuleLaunchKernel.kernelParams == NULL) + oss << ", kernelParams=NULL"; + else { + oss << ", kernelParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.kernelParams__val); + } + if (data->args.hipExtModuleLaunchKernel.extra == NULL) + oss << ", extra=NULL"; + else { + oss << ", extra="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.extra__val); + } + oss << ", startEvent="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.startEvent); + oss << ", stopEvent="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.stopEvent); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtModuleLaunchKernel.flags); + oss << ")"; + break; + case HIP_API_ID_hipExtStreamCreateWithCUMask: + oss << "hipExtStreamCreateWithCUMask("; + if (data->args.hipExtStreamCreateWithCUMask.stream == NULL) + oss << "stream=NULL"; + else { + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtStreamCreateWithCUMask.stream__val); + } + oss << ", cuMaskSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtStreamCreateWithCUMask.cuMaskSize); + if (data->args.hipExtStreamCreateWithCUMask.cuMask == NULL) + oss << ", cuMask=NULL"; + else { + oss << ", cuMask="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtStreamCreateWithCUMask.cuMask__val); + } + oss << ")"; + break; + case HIP_API_ID_hipExtStreamGetCUMask: + oss << "hipExtStreamGetCUMask("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtStreamGetCUMask.stream); + oss << ", cuMaskSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtStreamGetCUMask.cuMaskSize); + if (data->args.hipExtStreamGetCUMask.cuMask == NULL) + oss << ", cuMask=NULL"; + else { + oss << ", cuMask="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExtStreamGetCUMask.cuMask__val); + } + oss << ")"; + break; + case HIP_API_ID_hipExternalMemoryGetMappedBuffer: + oss << "hipExternalMemoryGetMappedBuffer("; + if (data->args.hipExternalMemoryGetMappedBuffer.devPtr == NULL) + oss << "devPtr=NULL"; + else { + oss << "devPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExternalMemoryGetMappedBuffer.devPtr__val); + } + oss << ", extMem="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExternalMemoryGetMappedBuffer.extMem); + if (data->args.hipExternalMemoryGetMappedBuffer.bufferDesc == NULL) + oss << ", bufferDesc=NULL"; + else { + oss << ", bufferDesc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExternalMemoryGetMappedBuffer.bufferDesc__val); + } + oss << ")"; + break; + case HIP_API_ID_hipExternalMemoryGetMappedMipmappedArray: + oss << "hipExternalMemoryGetMappedMipmappedArray("; + if (data->args.hipExternalMemoryGetMappedMipmappedArray.mipmap == NULL) + oss << "mipmap=NULL"; + else { + oss << "mipmap="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExternalMemoryGetMappedMipmappedArray.mipmap__val); + } + oss << ", extMem="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipExternalMemoryGetMappedMipmappedArray.extMem); + if (data->args.hipExternalMemoryGetMappedMipmappedArray.mipmapDesc == NULL) + oss << ", mipmapDesc=NULL"; + else { + oss << ", mipmapDesc="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipExternalMemoryGetMappedMipmappedArray.mipmapDesc__val); + } + oss << ")"; + break; + case HIP_API_ID_hipFree: + oss << "hipFree("; + oss << "ptr="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipFree.ptr); + oss << ")"; + break; + case HIP_API_ID_hipFreeArray: + oss << "hipFreeArray("; + oss << "array="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipFreeArray.array); + oss << ")"; + break; + case HIP_API_ID_hipFreeAsync: + oss << "hipFreeAsync("; + oss << "dev_ptr="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipFreeAsync.dev_ptr); + oss << ", stream="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipFreeAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipFreeHost: + oss << "hipFreeHost("; + oss << "ptr="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipFreeHost.ptr); + oss << ")"; + break; + case HIP_API_ID_hipFreeMipmappedArray: + oss << "hipFreeMipmappedArray("; + oss << "mipmappedArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFreeMipmappedArray.mipmappedArray); + oss << ")"; + break; + case HIP_API_ID_hipFuncGetAttribute: + oss << "hipFuncGetAttribute("; + if (data->args.hipFuncGetAttribute.value == NULL) + oss << "value=NULL"; + else { + oss << "value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncGetAttribute.value__val); + } + oss << ", attrib="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncGetAttribute.attrib); + oss << ", hfunc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncGetAttribute.hfunc); + oss << ")"; + break; + case HIP_API_ID_hipFuncGetAttributes: + oss << "hipFuncGetAttributes("; + if (data->args.hipFuncGetAttributes.attr == NULL) + oss << "attr=NULL"; + else { + oss << "attr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncGetAttributes.attr__val); + } + oss << ", func="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncGetAttributes.func); + oss << ")"; + break; + case HIP_API_ID_hipFuncSetAttribute: + oss << "hipFuncSetAttribute("; + oss << "func="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncSetAttribute.func); + oss << ", attr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncSetAttribute.attr); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncSetAttribute.value); + oss << ")"; + break; + case HIP_API_ID_hipFuncSetCacheConfig: + oss << "hipFuncSetCacheConfig("; + oss << "func="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncSetCacheConfig.func); + oss << ", config="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncSetCacheConfig.config); + oss << ")"; + break; + case HIP_API_ID_hipFuncSetSharedMemConfig: + oss << "hipFuncSetSharedMemConfig("; + oss << "func="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncSetSharedMemConfig.func); + oss << ", config="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipFuncSetSharedMemConfig.config); + oss << ")"; + break; + case HIP_API_ID_hipGLGetDevices: + oss << "hipGLGetDevices("; + if (data->args.hipGLGetDevices.pHipDeviceCount == NULL) + oss << "pHipDeviceCount=NULL"; + else { + oss << "pHipDeviceCount="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGLGetDevices.pHipDeviceCount__val); + } + if (data->args.hipGLGetDevices.pHipDevices == NULL) + oss << ", pHipDevices=NULL"; + else { + oss << ", pHipDevices="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGLGetDevices.pHipDevices__val); + } + oss << ", hipDeviceCount="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGLGetDevices.hipDeviceCount); + oss << ", deviceList="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGLGetDevices.deviceList); + oss << ")"; + break; + case HIP_API_ID_hipGetChannelDesc: + oss << "hipGetChannelDesc("; + if (data->args.hipGetChannelDesc.desc == NULL) + oss << "desc=NULL"; + else { + oss << "desc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetChannelDesc.desc__val); + } + oss << ", array="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetChannelDesc.array); + oss << ")"; + break; + case HIP_API_ID_hipGetDevice: + oss << "hipGetDevice("; + if (data->args.hipGetDevice.deviceId == NULL) + oss << "deviceId=NULL"; + else { + oss << "deviceId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetDevice.deviceId__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGetDeviceCount: + oss << "hipGetDeviceCount("; + if (data->args.hipGetDeviceCount.count == NULL) + oss << "count=NULL"; + else { + oss << "count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetDeviceCount.count__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGetDeviceFlags: + oss << "hipGetDeviceFlags("; + if (data->args.hipGetDeviceFlags.flags == NULL) + oss << "flags=NULL"; + else { + oss << "flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetDeviceFlags.flags__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGetDevicePropertiesR0000: + oss << "hipGetDevicePropertiesR0000("; + if (data->args.hipGetDevicePropertiesR0000.prop == NULL) + oss << "prop=NULL"; + else { + oss << "prop="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetDevicePropertiesR0000.prop__val); + } + oss << ", device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetDevicePropertiesR0000.device); + oss << ")"; + break; + case HIP_API_ID_hipGetDevicePropertiesR0600: + oss << "hipGetDevicePropertiesR0600("; + if (data->args.hipGetDevicePropertiesR0600.prop == NULL) + oss << "prop=NULL"; + else { + oss << "prop="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetDevicePropertiesR0600.prop__val); + } + oss << ", deviceId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetDevicePropertiesR0600.deviceId); + oss << ")"; + break; + case HIP_API_ID_hipGetErrorString: + oss << "hipGetErrorString("; + oss << ")"; + break; + case HIP_API_ID_hipGetFuncBySymbol: + oss << "hipGetFuncBySymbol("; + if (data->args.hipGetFuncBySymbol.functionPtr == NULL) + oss << "functionPtr=NULL"; + else { + oss << "functionPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetFuncBySymbol.functionPtr__val); + } + oss << ", symbolPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetFuncBySymbol.symbolPtr); + oss << ")"; + break; + case HIP_API_ID_hipGetLastError: + oss << "hipGetLastError("; + oss << ")"; + break; + case HIP_API_ID_hipGetMipmappedArrayLevel: + oss << "hipGetMipmappedArrayLevel("; + if (data->args.hipGetMipmappedArrayLevel.levelArray == NULL) + oss << "levelArray=NULL"; + else { + oss << "levelArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetMipmappedArrayLevel.levelArray__val); + } + oss << ", mipmappedArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetMipmappedArrayLevel.mipmappedArray); + oss << ", level="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetMipmappedArrayLevel.level); + oss << ")"; + break; + case HIP_API_ID_hipGetProcAddress: + oss << "hipGetProcAddress("; + if (data->args.hipGetProcAddress.symbol == NULL) + oss << "symbol=NULL"; + else { + oss << "symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetProcAddress.symbol__val); + } + if (data->args.hipGetProcAddress.pfn == NULL) + oss << ", pfn=NULL"; + else { + oss << ", pfn="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetProcAddress.pfn__val); + } + oss << ", hipVersion="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetProcAddress.hipVersion); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetProcAddress.flags); + if (data->args.hipGetProcAddress.symbolStatus == NULL) + oss << ", symbolStatus=NULL"; + else { + oss << ", symbolStatus="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetProcAddress.symbolStatus__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGetSymbolAddress: + oss << "hipGetSymbolAddress("; + if (data->args.hipGetSymbolAddress.devPtr == NULL) + oss << "devPtr=NULL"; + else { + oss << "devPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetSymbolAddress.devPtr__val); + } + oss << ", symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetSymbolAddress.symbol); + oss << ")"; + break; + case HIP_API_ID_hipGetSymbolSize: + oss << "hipGetSymbolSize("; + if (data->args.hipGetSymbolSize.size == NULL) + oss << "size=NULL"; + else { + oss << "size="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetSymbolSize.size__val); + } + oss << ", symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGetSymbolSize.symbol); + oss << ")"; + break; + case HIP_API_ID_hipGraphAddChildGraphNode: + oss << "hipGraphAddChildGraphNode("; + if (data->args.hipGraphAddChildGraphNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddChildGraphNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddChildGraphNode.graph); + if (data->args.hipGraphAddChildGraphNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddChildGraphNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddChildGraphNode.numDependencies); + oss << ", childGraph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddChildGraphNode.childGraph); + oss << ")"; + break; + case HIP_API_ID_hipGraphAddDependencies: + oss << "hipGraphAddDependencies("; + oss << "graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddDependencies.graph); + if (data->args.hipGraphAddDependencies.from == NULL) + oss << ", from=NULL"; + else { + oss << ", from="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddDependencies.from__val); + } + if (data->args.hipGraphAddDependencies.to == NULL) + oss << ", to=NULL"; + else { + oss << ", to="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddDependencies.to__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddDependencies.numDependencies); + oss << ")"; + break; + case HIP_API_ID_hipGraphAddEmptyNode: + oss << "hipGraphAddEmptyNode("; + if (data->args.hipGraphAddEmptyNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEmptyNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEmptyNode.graph); + if (data->args.hipGraphAddEmptyNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEmptyNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEmptyNode.numDependencies); + oss << ")"; + break; + case HIP_API_ID_hipGraphAddEventRecordNode: + oss << "hipGraphAddEventRecordNode("; + if (data->args.hipGraphAddEventRecordNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEventRecordNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEventRecordNode.graph); + if (data->args.hipGraphAddEventRecordNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEventRecordNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEventRecordNode.numDependencies); + oss << ", event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEventRecordNode.event); + oss << ")"; + break; + case HIP_API_ID_hipGraphAddEventWaitNode: + oss << "hipGraphAddEventWaitNode("; + if (data->args.hipGraphAddEventWaitNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEventWaitNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEventWaitNode.graph); + if (data->args.hipGraphAddEventWaitNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEventWaitNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEventWaitNode.numDependencies); + oss << ", event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddEventWaitNode.event); + oss << ")"; + break; + case HIP_API_ID_hipGraphAddExternalSemaphoresSignalNode: + oss << "hipGraphAddExternalSemaphoresSignalNode("; + if (data->args.hipGraphAddExternalSemaphoresSignalNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipGraphAddExternalSemaphoresSignalNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddExternalSemaphoresSignalNode.graph); + if (data->args.hipGraphAddExternalSemaphoresSignalNode.pDependencies == + NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddExternalSemaphoresSignalNode + .pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipGraphAddExternalSemaphoresSignalNode.numDependencies); + if (data->args.hipGraphAddExternalSemaphoresSignalNode.nodeParams == NULL) + oss << ", nodeParams=NULL"; + else { + oss << ", nodeParams="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipGraphAddExternalSemaphoresSignalNode.nodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphAddExternalSemaphoresWaitNode: + oss << "hipGraphAddExternalSemaphoresWaitNode("; + if (data->args.hipGraphAddExternalSemaphoresWaitNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipGraphAddExternalSemaphoresWaitNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddExternalSemaphoresWaitNode.graph); + if (data->args.hipGraphAddExternalSemaphoresWaitNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipGraphAddExternalSemaphoresWaitNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddExternalSemaphoresWaitNode.numDependencies); + if (data->args.hipGraphAddExternalSemaphoresWaitNode.nodeParams == NULL) + oss << ", nodeParams=NULL"; + else { + oss << ", nodeParams="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipGraphAddExternalSemaphoresWaitNode.nodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphAddHostNode: + oss << "hipGraphAddHostNode("; + if (data->args.hipGraphAddHostNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddHostNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddHostNode.graph); + if (data->args.hipGraphAddHostNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddHostNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddHostNode.numDependencies); + if (data->args.hipGraphAddHostNode.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddHostNode.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphAddKernelNode: + oss << "hipGraphAddKernelNode("; + if (data->args.hipGraphAddKernelNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddKernelNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddKernelNode.graph); + if (data->args.hipGraphAddKernelNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddKernelNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddKernelNode.numDependencies); + if (data->args.hipGraphAddKernelNode.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddKernelNode.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphAddMemAllocNode: + oss << "hipGraphAddMemAllocNode("; + if (data->args.hipGraphAddMemAllocNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemAllocNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemAllocNode.graph); + if (data->args.hipGraphAddMemAllocNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemAllocNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemAllocNode.numDependencies); + if (data->args.hipGraphAddMemAllocNode.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemAllocNode.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphAddMemFreeNode: + oss << "hipGraphAddMemFreeNode("; + if (data->args.hipGraphAddMemFreeNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemFreeNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemFreeNode.graph); + if (data->args.hipGraphAddMemFreeNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemFreeNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemFreeNode.numDependencies); + oss << ", dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemFreeNode.dev_ptr); + oss << ")"; + break; + case HIP_API_ID_hipGraphAddMemcpyNode: + oss << "hipGraphAddMemcpyNode("; + if (data->args.hipGraphAddMemcpyNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode.graph); + if (data->args.hipGraphAddMemcpyNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode.numDependencies); + if (data->args.hipGraphAddMemcpyNode.pCopyParams == NULL) + oss << ", pCopyParams=NULL"; + else { + oss << ", pCopyParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode.pCopyParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphAddMemcpyNode1D: + oss << "hipGraphAddMemcpyNode1D("; + if (data->args.hipGraphAddMemcpyNode1D.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode1D.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode1D.graph); + if (data->args.hipGraphAddMemcpyNode1D.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode1D.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode1D.numDependencies); + oss << ", dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode1D.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode1D.src); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode1D.count); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNode1D.kind); + oss << ")"; + break; + case HIP_API_ID_hipGraphAddMemcpyNodeFromSymbol: + oss << "hipGraphAddMemcpyNodeFromSymbol("; + if (data->args.hipGraphAddMemcpyNodeFromSymbol.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeFromSymbol.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeFromSymbol.graph); + if (data->args.hipGraphAddMemcpyNodeFromSymbol.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeFromSymbol.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeFromSymbol.numDependencies); + oss << ", dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeFromSymbol.dst); + oss << ", symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeFromSymbol.symbol); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeFromSymbol.count); + oss << ", offset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeFromSymbol.offset); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeFromSymbol.kind); + oss << ")"; + break; + case HIP_API_ID_hipGraphAddMemcpyNodeToSymbol: + oss << "hipGraphAddMemcpyNodeToSymbol("; + if (data->args.hipGraphAddMemcpyNodeToSymbol.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeToSymbol.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeToSymbol.graph); + if (data->args.hipGraphAddMemcpyNodeToSymbol.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeToSymbol.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeToSymbol.numDependencies); + oss << ", symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeToSymbol.symbol); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeToSymbol.src); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeToSymbol.count); + oss << ", offset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeToSymbol.offset); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemcpyNodeToSymbol.kind); + oss << ")"; + break; + case HIP_API_ID_hipGraphAddMemsetNode: + oss << "hipGraphAddMemsetNode("; + if (data->args.hipGraphAddMemsetNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemsetNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemsetNode.graph); + if (data->args.hipGraphAddMemsetNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemsetNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemsetNode.numDependencies); + if (data->args.hipGraphAddMemsetNode.pMemsetParams == NULL) + oss << ", pMemsetParams=NULL"; + else { + oss << ", pMemsetParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddMemsetNode.pMemsetParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphAddNode: + oss << "hipGraphAddNode("; + if (data->args.hipGraphAddNode.pGraphNode == NULL) + oss << "pGraphNode=NULL"; + else { + oss << "pGraphNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddNode.pGraphNode__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddNode.graph); + if (data->args.hipGraphAddNode.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddNode.pDependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddNode.numDependencies); + if (data->args.hipGraphAddNode.nodeParams == NULL) + oss << ", nodeParams=NULL"; + else { + oss << ", nodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphAddNode.nodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphChildGraphNodeGetGraph: + oss << "hipGraphChildGraphNodeGetGraph("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphChildGraphNodeGetGraph.node); + if (data->args.hipGraphChildGraphNodeGetGraph.pGraph == NULL) + oss << ", pGraph=NULL"; + else { + oss << ", pGraph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphChildGraphNodeGetGraph.pGraph__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphClone: + oss << "hipGraphClone("; + if (data->args.hipGraphClone.pGraphClone == NULL) + oss << "pGraphClone=NULL"; + else { + oss << "pGraphClone="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphClone.pGraphClone__val); + } + oss << ", originalGraph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphClone.originalGraph); + oss << ")"; + break; + case HIP_API_ID_hipGraphCreate: + oss << "hipGraphCreate("; + if (data->args.hipGraphCreate.pGraph == NULL) + oss << "pGraph=NULL"; + else { + oss << "pGraph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphCreate.pGraph__val); + } + oss << ", flags="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipGraphCreate.flags); + oss << ")"; + break; + case HIP_API_ID_hipGraphDebugDotPrint: + oss << "hipGraphDebugDotPrint("; + oss << "graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphDebugDotPrint.graph); + if (data->args.hipGraphDebugDotPrint.path == NULL) + oss << ", path=NULL"; + else { + oss << ", path="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphDebugDotPrint.path__val); + } + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphDebugDotPrint.flags); + oss << ")"; + break; + case HIP_API_ID_hipGraphDestroy: + oss << "hipGraphDestroy("; + oss << "graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphDestroy.graph); + oss << ")"; + break; + case HIP_API_ID_hipGraphDestroyNode: + oss << "hipGraphDestroyNode("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphDestroyNode.node); + oss << ")"; + break; + case HIP_API_ID_hipGraphEventRecordNodeGetEvent: + oss << "hipGraphEventRecordNodeGetEvent("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphEventRecordNodeGetEvent.node); + if (data->args.hipGraphEventRecordNodeGetEvent.event_out == NULL) + oss << ", event_out=NULL"; + else { + oss << ", event_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphEventRecordNodeGetEvent.event_out__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphEventRecordNodeSetEvent: + oss << "hipGraphEventRecordNodeSetEvent("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphEventRecordNodeSetEvent.node); + oss << ", event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphEventRecordNodeSetEvent.event); + oss << ")"; + break; + case HIP_API_ID_hipGraphEventWaitNodeGetEvent: + oss << "hipGraphEventWaitNodeGetEvent("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphEventWaitNodeGetEvent.node); + if (data->args.hipGraphEventWaitNodeGetEvent.event_out == NULL) + oss << ", event_out=NULL"; + else { + oss << ", event_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphEventWaitNodeGetEvent.event_out__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphEventWaitNodeSetEvent: + oss << "hipGraphEventWaitNodeSetEvent("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphEventWaitNodeSetEvent.node); + oss << ", event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphEventWaitNodeSetEvent.event); + oss << ")"; + break; + case HIP_API_ID_hipGraphExecChildGraphNodeSetParams: + oss << "hipGraphExecChildGraphNodeSetParams("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecChildGraphNodeSetParams.hGraphExec); + oss << ", node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecChildGraphNodeSetParams.node); + oss << ", childGraph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecChildGraphNodeSetParams.childGraph); + oss << ")"; + break; + case HIP_API_ID_hipGraphExecDestroy: + oss << "hipGraphExecDestroy("; + oss << "graphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecDestroy.graphExec); + oss << ")"; + break; + case HIP_API_ID_hipGraphExecEventRecordNodeSetEvent: + oss << "hipGraphExecEventRecordNodeSetEvent("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecEventRecordNodeSetEvent.hGraphExec); + oss << ", hNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecEventRecordNodeSetEvent.hNode); + oss << ", event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecEventRecordNodeSetEvent.event); + oss << ")"; + break; + case HIP_API_ID_hipGraphExecEventWaitNodeSetEvent: + oss << "hipGraphExecEventWaitNodeSetEvent("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecEventWaitNodeSetEvent.hGraphExec); + oss << ", hNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecEventWaitNodeSetEvent.hNode); + oss << ", event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecEventWaitNodeSetEvent.event); + oss << ")"; + break; + case HIP_API_ID_hipGraphExecExternalSemaphoresSignalNodeSetParams: + oss << "hipGraphExecExternalSemaphoresSignalNodeSetParams("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecExternalSemaphoresSignalNodeSetParams + .hGraphExec); + oss << ", hNode="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipGraphExecExternalSemaphoresSignalNodeSetParams.hNode); + if (data->args.hipGraphExecExternalSemaphoresSignalNodeSetParams + .nodeParams == NULL) + oss << ", nodeParams=NULL"; + else { + oss << ", nodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecExternalSemaphoresSignalNodeSetParams + .nodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphExecExternalSemaphoresWaitNodeSetParams: + oss << "hipGraphExecExternalSemaphoresWaitNodeSetParams("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipGraphExecExternalSemaphoresWaitNodeSetParams.hGraphExec); + oss << ", hNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecExternalSemaphoresWaitNodeSetParams.hNode); + if (data->args.hipGraphExecExternalSemaphoresWaitNodeSetParams.nodeParams == + NULL) + oss << ", nodeParams=NULL"; + else { + oss << ", nodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecExternalSemaphoresWaitNodeSetParams + .nodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphExecHostNodeSetParams: + oss << "hipGraphExecHostNodeSetParams("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecHostNodeSetParams.hGraphExec); + oss << ", node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecHostNodeSetParams.node); + if (data->args.hipGraphExecHostNodeSetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecHostNodeSetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphExecKernelNodeSetParams: + oss << "hipGraphExecKernelNodeSetParams("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecKernelNodeSetParams.hGraphExec); + oss << ", node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecKernelNodeSetParams.node); + if (data->args.hipGraphExecKernelNodeSetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecKernelNodeSetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphExecMemcpyNodeSetParams: + oss << "hipGraphExecMemcpyNodeSetParams("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParams.hGraphExec); + oss << ", node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParams.node); + if (data->args.hipGraphExecMemcpyNodeSetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphExecMemcpyNodeSetParams1D: + oss << "hipGraphExecMemcpyNodeSetParams1D("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParams1D.hGraphExec); + oss << ", node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParams1D.node); + oss << ", dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParams1D.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParams1D.src); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParams1D.count); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParams1D.kind); + oss << ")"; + break; + case HIP_API_ID_hipGraphExecMemcpyNodeSetParamsFromSymbol: + oss << "hipGraphExecMemcpyNodeSetParamsFromSymbol("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsFromSymbol.hGraphExec); + oss << ", node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsFromSymbol.node); + oss << ", dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsFromSymbol.dst); + oss << ", symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsFromSymbol.symbol); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsFromSymbol.count); + oss << ", offset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsFromSymbol.offset); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsFromSymbol.kind); + oss << ")"; + break; + case HIP_API_ID_hipGraphExecMemcpyNodeSetParamsToSymbol: + oss << "hipGraphExecMemcpyNodeSetParamsToSymbol("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsToSymbol.hGraphExec); + oss << ", node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsToSymbol.node); + oss << ", symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsToSymbol.symbol); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsToSymbol.src); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsToSymbol.count); + oss << ", offset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsToSymbol.offset); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemcpyNodeSetParamsToSymbol.kind); + oss << ")"; + break; + case HIP_API_ID_hipGraphExecMemsetNodeSetParams: + oss << "hipGraphExecMemsetNodeSetParams("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemsetNodeSetParams.hGraphExec); + oss << ", node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemsetNodeSetParams.node); + if (data->args.hipGraphExecMemsetNodeSetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecMemsetNodeSetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphExecUpdate: + oss << "hipGraphExecUpdate("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecUpdate.hGraphExec); + oss << ", hGraph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecUpdate.hGraph); + if (data->args.hipGraphExecUpdate.hErrorNode_out == NULL) + oss << ", hErrorNode_out=NULL"; + else { + oss << ", hErrorNode_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecUpdate.hErrorNode_out__val); + } + if (data->args.hipGraphExecUpdate.updateResult_out == NULL) + oss << ", updateResult_out=NULL"; + else { + oss << ", updateResult_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExecUpdate.updateResult_out__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphExternalSemaphoresSignalNodeGetParams: + oss << "hipGraphExternalSemaphoresSignalNodeGetParams("; + oss << "hNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExternalSemaphoresSignalNodeGetParams.hNode); + if (data->args.hipGraphExternalSemaphoresSignalNodeGetParams.params_out == + NULL) + oss << ", params_out=NULL"; + else { + oss << ", params_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExternalSemaphoresSignalNodeGetParams + .params_out__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphExternalSemaphoresSignalNodeSetParams: + oss << "hipGraphExternalSemaphoresSignalNodeSetParams("; + oss << "hNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExternalSemaphoresSignalNodeSetParams.hNode); + if (data->args.hipGraphExternalSemaphoresSignalNodeSetParams.nodeParams == + NULL) + oss << ", nodeParams=NULL"; + else { + oss << ", nodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExternalSemaphoresSignalNodeSetParams + .nodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphExternalSemaphoresWaitNodeGetParams: + oss << "hipGraphExternalSemaphoresWaitNodeGetParams("; + oss << "hNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExternalSemaphoresWaitNodeGetParams.hNode); + if (data->args.hipGraphExternalSemaphoresWaitNodeGetParams.params_out == + NULL) + oss << ", params_out=NULL"; + else { + oss << ", params_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExternalSemaphoresWaitNodeGetParams + .params_out__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphExternalSemaphoresWaitNodeSetParams: + oss << "hipGraphExternalSemaphoresWaitNodeSetParams("; + oss << "hNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExternalSemaphoresWaitNodeSetParams.hNode); + if (data->args.hipGraphExternalSemaphoresWaitNodeSetParams.nodeParams == + NULL) + oss << ", nodeParams=NULL"; + else { + oss << ", nodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphExternalSemaphoresWaitNodeSetParams + .nodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphGetEdges: + oss << "hipGraphGetEdges("; + oss << "graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphGetEdges.graph); + if (data->args.hipGraphGetEdges.from == NULL) + oss << ", from=NULL"; + else { + oss << ", from="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphGetEdges.from__val); + } + if (data->args.hipGraphGetEdges.to == NULL) + oss << ", to=NULL"; + else { + oss << ", to="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphGetEdges.to__val); + } + if (data->args.hipGraphGetEdges.numEdges == NULL) + oss << ", numEdges=NULL"; + else { + oss << ", numEdges="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphGetEdges.numEdges__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphGetNodes: + oss << "hipGraphGetNodes("; + oss << "graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphGetNodes.graph); + if (data->args.hipGraphGetNodes.nodes == NULL) + oss << ", nodes=NULL"; + else { + oss << ", nodes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphGetNodes.nodes__val); + } + if (data->args.hipGraphGetNodes.numNodes == NULL) + oss << ", numNodes=NULL"; + else { + oss << ", numNodes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphGetNodes.numNodes__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphGetRootNodes: + oss << "hipGraphGetRootNodes("; + oss << "graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphGetRootNodes.graph); + if (data->args.hipGraphGetRootNodes.pRootNodes == NULL) + oss << ", pRootNodes=NULL"; + else { + oss << ", pRootNodes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphGetRootNodes.pRootNodes__val); + } + if (data->args.hipGraphGetRootNodes.pNumRootNodes == NULL) + oss << ", pNumRootNodes=NULL"; + else { + oss << ", pNumRootNodes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphGetRootNodes.pNumRootNodes__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphHostNodeGetParams: + oss << "hipGraphHostNodeGetParams("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphHostNodeGetParams.node); + if (data->args.hipGraphHostNodeGetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphHostNodeGetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphHostNodeSetParams: + oss << "hipGraphHostNodeSetParams("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphHostNodeSetParams.node); + if (data->args.hipGraphHostNodeSetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphHostNodeSetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphInstantiate: + oss << "hipGraphInstantiate("; + if (data->args.hipGraphInstantiate.pGraphExec == NULL) + oss << "pGraphExec=NULL"; + else { + oss << "pGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphInstantiate.pGraphExec__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphInstantiate.graph); + if (data->args.hipGraphInstantiate.pErrorNode == NULL) + oss << ", pErrorNode=NULL"; + else { + oss << ", pErrorNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphInstantiate.pErrorNode__val); + } + if (data->args.hipGraphInstantiate.pLogBuffer == NULL) + oss << ", pLogBuffer=NULL"; + else { + oss << ", pLogBuffer="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphInstantiate.pLogBuffer__val); + } + oss << ", bufferSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphInstantiate.bufferSize); + oss << ")"; + break; + case HIP_API_ID_hipGraphInstantiateWithFlags: + oss << "hipGraphInstantiateWithFlags("; + if (data->args.hipGraphInstantiateWithFlags.pGraphExec == NULL) + oss << "pGraphExec=NULL"; + else { + oss << "pGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphInstantiateWithFlags.pGraphExec__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphInstantiateWithFlags.graph); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphInstantiateWithFlags.flags); + oss << ")"; + break; + case HIP_API_ID_hipGraphInstantiateWithParams: + oss << "hipGraphInstantiateWithParams("; + if (data->args.hipGraphInstantiateWithParams.pGraphExec == NULL) + oss << "pGraphExec=NULL"; + else { + oss << "pGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphInstantiateWithParams.pGraphExec__val); + } + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphInstantiateWithParams.graph); + if (data->args.hipGraphInstantiateWithParams.instantiateParams == NULL) + oss << ", instantiateParams=NULL"; + else { + oss << ", instantiateParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphInstantiateWithParams.instantiateParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphKernelNodeCopyAttributes: + oss << "hipGraphKernelNodeCopyAttributes("; + oss << "hSrc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeCopyAttributes.hSrc); + oss << ", hDst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeCopyAttributes.hDst); + oss << ")"; + break; + case HIP_API_ID_hipGraphKernelNodeGetAttribute: + oss << "hipGraphKernelNodeGetAttribute("; + oss << "hNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeGetAttribute.hNode); + oss << ", attr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeGetAttribute.attr); + if (data->args.hipGraphKernelNodeGetAttribute.value == NULL) + oss << ", value=NULL"; + else { + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeGetAttribute.value__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphKernelNodeGetParams: + oss << "hipGraphKernelNodeGetParams("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeGetParams.node); + if (data->args.hipGraphKernelNodeGetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeGetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphKernelNodeSetAttribute: + oss << "hipGraphKernelNodeSetAttribute("; + oss << "hNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeSetAttribute.hNode); + oss << ", attr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeSetAttribute.attr); + if (data->args.hipGraphKernelNodeSetAttribute.value == NULL) + oss << ", value=NULL"; + else { + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeSetAttribute.value__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphKernelNodeSetParams: + oss << "hipGraphKernelNodeSetParams("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeSetParams.node); + if (data->args.hipGraphKernelNodeSetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphKernelNodeSetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphLaunch: + oss << "hipGraphLaunch("; + oss << "graphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphLaunch.graphExec); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphLaunch.stream); + oss << ")"; + break; + case HIP_API_ID_hipGraphMemAllocNodeGetParams: + oss << "hipGraphMemAllocNodeGetParams("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemAllocNodeGetParams.node); + if (data->args.hipGraphMemAllocNodeGetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemAllocNodeGetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphMemFreeNodeGetParams: + oss << "hipGraphMemFreeNodeGetParams("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemFreeNodeGetParams.node); + oss << ", dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemFreeNodeGetParams.dev_ptr); + oss << ")"; + break; + case HIP_API_ID_hipGraphMemcpyNodeGetParams: + oss << "hipGraphMemcpyNodeGetParams("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeGetParams.node); + if (data->args.hipGraphMemcpyNodeGetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeGetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphMemcpyNodeSetParams: + oss << "hipGraphMemcpyNodeSetParams("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParams.node); + if (data->args.hipGraphMemcpyNodeSetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphMemcpyNodeSetParams1D: + oss << "hipGraphMemcpyNodeSetParams1D("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParams1D.node); + oss << ", dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParams1D.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParams1D.src); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParams1D.count); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParams1D.kind); + oss << ")"; + break; + case HIP_API_ID_hipGraphMemcpyNodeSetParamsFromSymbol: + oss << "hipGraphMemcpyNodeSetParamsFromSymbol("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsFromSymbol.node); + oss << ", dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsFromSymbol.dst); + oss << ", symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsFromSymbol.symbol); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsFromSymbol.count); + oss << ", offset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsFromSymbol.offset); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsFromSymbol.kind); + oss << ")"; + break; + case HIP_API_ID_hipGraphMemcpyNodeSetParamsToSymbol: + oss << "hipGraphMemcpyNodeSetParamsToSymbol("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsToSymbol.node); + oss << ", symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsToSymbol.symbol); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsToSymbol.src); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsToSymbol.count); + oss << ", offset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsToSymbol.offset); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemcpyNodeSetParamsToSymbol.kind); + oss << ")"; + break; + case HIP_API_ID_hipGraphMemsetNodeGetParams: + oss << "hipGraphMemsetNodeGetParams("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemsetNodeGetParams.node); + if (data->args.hipGraphMemsetNodeGetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemsetNodeGetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphMemsetNodeSetParams: + oss << "hipGraphMemsetNodeSetParams("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemsetNodeSetParams.node); + if (data->args.hipGraphMemsetNodeSetParams.pNodeParams == NULL) + oss << ", pNodeParams=NULL"; + else { + oss << ", pNodeParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphMemsetNodeSetParams.pNodeParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphNodeFindInClone: + oss << "hipGraphNodeFindInClone("; + if (data->args.hipGraphNodeFindInClone.pNode == NULL) + oss << "pNode=NULL"; + else { + oss << "pNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeFindInClone.pNode__val); + } + oss << ", originalNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeFindInClone.originalNode); + oss << ", clonedGraph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeFindInClone.clonedGraph); + oss << ")"; + break; + case HIP_API_ID_hipGraphNodeGetDependencies: + oss << "hipGraphNodeGetDependencies("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeGetDependencies.node); + if (data->args.hipGraphNodeGetDependencies.pDependencies == NULL) + oss << ", pDependencies=NULL"; + else { + oss << ", pDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeGetDependencies.pDependencies__val); + } + if (data->args.hipGraphNodeGetDependencies.pNumDependencies == NULL) + oss << ", pNumDependencies=NULL"; + else { + oss << ", pNumDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeGetDependencies.pNumDependencies__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphNodeGetDependentNodes: + oss << "hipGraphNodeGetDependentNodes("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeGetDependentNodes.node); + if (data->args.hipGraphNodeGetDependentNodes.pDependentNodes == NULL) + oss << ", pDependentNodes=NULL"; + else { + oss << ", pDependentNodes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeGetDependentNodes.pDependentNodes__val); + } + if (data->args.hipGraphNodeGetDependentNodes.pNumDependentNodes == NULL) + oss << ", pNumDependentNodes=NULL"; + else { + oss << ", pNumDependentNodes="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipGraphNodeGetDependentNodes.pNumDependentNodes__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphNodeGetEnabled: + oss << "hipGraphNodeGetEnabled("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeGetEnabled.hGraphExec); + oss << ", hNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeGetEnabled.hNode); + if (data->args.hipGraphNodeGetEnabled.isEnabled == NULL) + oss << ", isEnabled=NULL"; + else { + oss << ", isEnabled="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeGetEnabled.isEnabled__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphNodeGetType: + oss << "hipGraphNodeGetType("; + oss << "node="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeGetType.node); + if (data->args.hipGraphNodeGetType.pType == NULL) + oss << ", pType=NULL"; + else { + oss << ", pType="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeGetType.pType__val); + } + oss << ")"; + break; + case HIP_API_ID_hipGraphNodeSetEnabled: + oss << "hipGraphNodeSetEnabled("; + oss << "hGraphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeSetEnabled.hGraphExec); + oss << ", hNode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeSetEnabled.hNode); + oss << ", isEnabled="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphNodeSetEnabled.isEnabled); + oss << ")"; + break; + case HIP_API_ID_hipGraphReleaseUserObject: + oss << "hipGraphReleaseUserObject("; + oss << "graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphReleaseUserObject.graph); + oss << ", object="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphReleaseUserObject.object); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphReleaseUserObject.count); + oss << ")"; + break; + case HIP_API_ID_hipGraphRemoveDependencies: + oss << "hipGraphRemoveDependencies("; + oss << "graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphRemoveDependencies.graph); + if (data->args.hipGraphRemoveDependencies.from == NULL) + oss << ", from=NULL"; + else { + oss << ", from="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphRemoveDependencies.from__val); + } + if (data->args.hipGraphRemoveDependencies.to == NULL) + oss << ", to=NULL"; + else { + oss << ", to="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphRemoveDependencies.to__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphRemoveDependencies.numDependencies); + oss << ")"; + break; + case HIP_API_ID_hipGraphRetainUserObject: + oss << "hipGraphRetainUserObject("; + oss << "graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphRetainUserObject.graph); + oss << ", object="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphRetainUserObject.object); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphRetainUserObject.count); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphRetainUserObject.flags); + oss << ")"; + break; + case HIP_API_ID_hipGraphUpload: + oss << "hipGraphUpload("; + oss << "graphExec="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphUpload.graphExec); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphUpload.stream); + oss << ")"; + break; + case HIP_API_ID_hipGraphicsGLRegisterBuffer: + oss << "hipGraphicsGLRegisterBuffer("; + if (data->args.hipGraphicsGLRegisterBuffer.resource == NULL) + oss << "resource=NULL"; + else { + oss << "resource="; + roctracer::hip_support::detail::operator<<( + oss, (void *)data->args.hipGraphicsGLRegisterBuffer.resource__val); + } + oss << ", buffer="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsGLRegisterBuffer.buffer); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsGLRegisterBuffer.flags); + oss << ")"; + break; + case HIP_API_ID_hipGraphicsGLRegisterImage: + oss << "hipGraphicsGLRegisterImage("; + if (data->args.hipGraphicsGLRegisterImage.resource == NULL) + oss << "resource=NULL"; + else { + oss << "resource="; + roctracer::hip_support::detail::operator<<( + oss, (void *)data->args.hipGraphicsGLRegisterImage.resource__val); + } + oss << ", image="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsGLRegisterImage.image); + oss << ", target="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsGLRegisterImage.target); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsGLRegisterImage.flags); + oss << ")"; + break; + case HIP_API_ID_hipGraphicsMapResources: + oss << "hipGraphicsMapResources("; + oss << "count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsMapResources.count); + if (data->args.hipGraphicsMapResources.resources == NULL) + oss << ", resources=NULL"; + else { + oss << ", resources="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsMapResources.resources__val); + } + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsMapResources.stream); + oss << ")"; + break; + case HIP_API_ID_hipGraphicsResourceGetMappedPointer: + oss << "hipGraphicsResourceGetMappedPointer("; + if (data->args.hipGraphicsResourceGetMappedPointer.devPtr == NULL) + oss << "devPtr=NULL"; + else { + oss << "devPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsResourceGetMappedPointer.devPtr__val); + } + if (data->args.hipGraphicsResourceGetMappedPointer.size == NULL) + oss << ", size=NULL"; + else { + oss << ", size="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsResourceGetMappedPointer.size__val); + } + oss << ", resource="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsResourceGetMappedPointer.resource); + oss << ")"; + break; + case HIP_API_ID_hipGraphicsSubResourceGetMappedArray: + oss << "hipGraphicsSubResourceGetMappedArray("; + if (data->args.hipGraphicsSubResourceGetMappedArray.array == NULL) + oss << "array=NULL"; + else { + oss << "array="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsSubResourceGetMappedArray.array__val); + } + oss << ", resource="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsSubResourceGetMappedArray.resource); + oss << ", arrayIndex="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsSubResourceGetMappedArray.arrayIndex); + oss << ", mipLevel="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsSubResourceGetMappedArray.mipLevel); + oss << ")"; + break; + case HIP_API_ID_hipGraphicsUnmapResources: + oss << "hipGraphicsUnmapResources("; + oss << "count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsUnmapResources.count); + if (data->args.hipGraphicsUnmapResources.resources == NULL) + oss << ", resources=NULL"; + else { + oss << ", resources="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsUnmapResources.resources__val); + } + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsUnmapResources.stream); + oss << ")"; + break; + case HIP_API_ID_hipGraphicsUnregisterResource: + oss << "hipGraphicsUnregisterResource("; + oss << "resource="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipGraphicsUnregisterResource.resource); + oss << ")"; + break; + case HIP_API_ID_hipHccModuleLaunchKernel: + oss << "hipHccModuleLaunchKernel("; + oss << "f="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.f); + oss << ", globalWorkSizeX="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.globalWorkSizeX); + oss << ", globalWorkSizeY="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.globalWorkSizeY); + oss << ", globalWorkSizeZ="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.globalWorkSizeZ); + oss << ", blockDimX="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.blockDimX); + oss << ", blockDimY="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.blockDimY); + oss << ", blockDimZ="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.blockDimZ); + oss << ", sharedMemBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.sharedMemBytes); + oss << ", hStream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.hStream); + if (data->args.hipHccModuleLaunchKernel.kernelParams == NULL) + oss << ", kernelParams=NULL"; + else { + oss << ", kernelParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.kernelParams__val); + } + if (data->args.hipHccModuleLaunchKernel.extra == NULL) + oss << ", extra=NULL"; + else { + oss << ", extra="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.extra__val); + } + oss << ", startEvent="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.startEvent); + oss << ", stopEvent="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHccModuleLaunchKernel.stopEvent); + oss << ")"; + break; + case HIP_API_ID_hipHostAlloc: + oss << "hipHostAlloc("; + if (data->args.hipHostAlloc.ptr == NULL) + oss << "ptr=NULL"; + else { + oss << "ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHostAlloc.ptr__val); + } + oss << ", size="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipHostAlloc.size); + oss << ", flags="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipHostAlloc.flags); + oss << ")"; + break; + case HIP_API_ID_hipHostFree: + oss << "hipHostFree("; + oss << "ptr="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipHostFree.ptr); + oss << ")"; + break; + case HIP_API_ID_hipHostGetDevicePointer: + oss << "hipHostGetDevicePointer("; + if (data->args.hipHostGetDevicePointer.devPtr == NULL) + oss << "devPtr=NULL"; + else { + oss << "devPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHostGetDevicePointer.devPtr__val); + } + oss << ", hstPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHostGetDevicePointer.hstPtr); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHostGetDevicePointer.flags); + oss << ")"; + break; + case HIP_API_ID_hipHostGetFlags: + oss << "hipHostGetFlags("; + if (data->args.hipHostGetFlags.flagsPtr == NULL) + oss << "flagsPtr=NULL"; + else { + oss << "flagsPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHostGetFlags.flagsPtr__val); + } + oss << ", hostPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHostGetFlags.hostPtr); + oss << ")"; + break; + case HIP_API_ID_hipHostMalloc: + oss << "hipHostMalloc("; + if (data->args.hipHostMalloc.ptr == NULL) + oss << "ptr=NULL"; + else { + oss << "ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHostMalloc.ptr__val); + } + oss << ", size="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipHostMalloc.size); + oss << ", flags="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipHostMalloc.flags); + oss << ")"; + break; + case HIP_API_ID_hipHostRegister: + oss << "hipHostRegister("; + oss << "hostPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHostRegister.hostPtr); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHostRegister.sizeBytes); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHostRegister.flags); + oss << ")"; + break; + case HIP_API_ID_hipHostUnregister: + oss << "hipHostUnregister("; + oss << "hostPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipHostUnregister.hostPtr); + oss << ")"; + break; + case HIP_API_ID_hipImportExternalMemory: + oss << "hipImportExternalMemory("; + if (data->args.hipImportExternalMemory.extMem_out == NULL) + oss << "extMem_out=NULL"; + else { + oss << "extMem_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipImportExternalMemory.extMem_out__val); + } + if (data->args.hipImportExternalMemory.memHandleDesc == NULL) + oss << ", memHandleDesc=NULL"; + else { + oss << ", memHandleDesc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipImportExternalMemory.memHandleDesc__val); + } + oss << ")"; + break; + case HIP_API_ID_hipImportExternalSemaphore: + oss << "hipImportExternalSemaphore("; + if (data->args.hipImportExternalSemaphore.extSem_out == NULL) + oss << "extSem_out=NULL"; + else { + oss << "extSem_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipImportExternalSemaphore.extSem_out__val); + } + if (data->args.hipImportExternalSemaphore.semHandleDesc == NULL) + oss << ", semHandleDesc=NULL"; + else { + oss << ", semHandleDesc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipImportExternalSemaphore.semHandleDesc__val); + } + oss << ")"; + break; + case HIP_API_ID_hipInit: + oss << "hipInit("; + oss << "flags="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipInit.flags); + oss << ")"; + break; + case HIP_API_ID_hipIpcCloseMemHandle: + oss << "hipIpcCloseMemHandle("; + oss << "devPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipIpcCloseMemHandle.devPtr); + oss << ")"; + break; + case HIP_API_ID_hipIpcGetEventHandle: + oss << "hipIpcGetEventHandle("; + if (data->args.hipIpcGetEventHandle.handle == NULL) + oss << "handle=NULL"; + else { + oss << "handle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipIpcGetEventHandle.handle__val); + } + oss << ", event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipIpcGetEventHandle.event); + oss << ")"; + break; + case HIP_API_ID_hipIpcGetMemHandle: + oss << "hipIpcGetMemHandle("; + if (data->args.hipIpcGetMemHandle.handle == NULL) + oss << "handle=NULL"; + else { + oss << "handle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipIpcGetMemHandle.handle__val); + } + oss << ", devPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipIpcGetMemHandle.devPtr); + oss << ")"; + break; + case HIP_API_ID_hipIpcOpenEventHandle: + oss << "hipIpcOpenEventHandle("; + if (data->args.hipIpcOpenEventHandle.event == NULL) + oss << "event=NULL"; + else { + oss << "event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipIpcOpenEventHandle.event__val); + } + oss << ", handle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipIpcOpenEventHandle.handle); + oss << ")"; + break; + case HIP_API_ID_hipIpcOpenMemHandle: + oss << "hipIpcOpenMemHandle("; + if (data->args.hipIpcOpenMemHandle.devPtr == NULL) + oss << "devPtr=NULL"; + else { + oss << "devPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipIpcOpenMemHandle.devPtr__val); + } + oss << ", handle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipIpcOpenMemHandle.handle); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipIpcOpenMemHandle.flags); + oss << ")"; + break; + case HIP_API_ID_hipLaunchByPtr: + oss << "hipLaunchByPtr("; + oss << "hostFunction="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchByPtr.hostFunction); + oss << ")"; + break; + case HIP_API_ID_hipLaunchCooperativeKernel: + oss << "hipLaunchCooperativeKernel("; + oss << "f="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchCooperativeKernel.f); + oss << ", gridDim="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchCooperativeKernel.gridDim); + oss << ", blockDimX="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchCooperativeKernel.blockDimX); + if (data->args.hipLaunchCooperativeKernel.kernelParams == NULL) + oss << ", kernelParams=NULL"; + else { + oss << ", kernelParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchCooperativeKernel.kernelParams__val); + } + oss << ", sharedMemBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchCooperativeKernel.sharedMemBytes); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchCooperativeKernel.stream); + oss << ")"; + break; + case HIP_API_ID_hipLaunchCooperativeKernelMultiDevice: + oss << "hipLaunchCooperativeKernelMultiDevice("; + if (data->args.hipLaunchCooperativeKernelMultiDevice.launchParamsList == + NULL) + oss << "launchParamsList=NULL"; + else { + oss << "launchParamsList="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchCooperativeKernelMultiDevice + .launchParamsList__val); + } + oss << ", numDevices="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchCooperativeKernelMultiDevice.numDevices); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchCooperativeKernelMultiDevice.flags); + oss << ")"; + break; + case HIP_API_ID_hipLaunchHostFunc: + oss << "hipLaunchHostFunc("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchHostFunc.stream); + oss << ", fn="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipLaunchHostFunc.fn); + oss << ", userData="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchHostFunc.userData); + oss << ")"; + break; + case HIP_API_ID_hipLaunchKernel: + oss << "hipLaunchKernel("; + oss << "function_address="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchKernel.function_address); + oss << ", numBlocks="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchKernel.numBlocks); + oss << ", dimBlocks="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchKernel.dimBlocks); + if (data->args.hipLaunchKernel.args == NULL) + oss << ", args=NULL"; + else { + oss << ", args="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchKernel.args__val); + } + oss << ", sharedMemBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchKernel.sharedMemBytes); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipLaunchKernel.stream); + oss << ")"; + break; + case HIP_API_ID_hipMalloc: + oss << "hipMalloc("; + if (data->args.hipMalloc.ptr == NULL) + oss << "ptr=NULL"; + else { + oss << "ptr="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMalloc.ptr__val); + } + oss << ", size="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMalloc.size); + oss << ")"; + break; + case HIP_API_ID_hipMalloc3D: + oss << "hipMalloc3D("; + if (data->args.hipMalloc3D.pitchedDevPtr == NULL) + oss << "pitchedDevPtr=NULL"; + else { + oss << "pitchedDevPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMalloc3D.pitchedDevPtr__val); + } + oss << ", extent="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMalloc3D.extent); + oss << ")"; + break; + case HIP_API_ID_hipMalloc3DArray: + oss << "hipMalloc3DArray("; + if (data->args.hipMalloc3DArray.array == NULL) + oss << "array=NULL"; + else { + oss << "array="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMalloc3DArray.array__val); + } + if (data->args.hipMalloc3DArray.desc == NULL) + oss << ", desc=NULL"; + else { + oss << ", desc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMalloc3DArray.desc__val); + } + oss << ", extent="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMalloc3DArray.extent); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMalloc3DArray.flags); + oss << ")"; + break; + case HIP_API_ID_hipMallocArray: + oss << "hipMallocArray("; + if (data->args.hipMallocArray.array == NULL) + oss << "array=NULL"; + else { + oss << "array="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocArray.array__val); + } + if (data->args.hipMallocArray.desc == NULL) + oss << ", desc=NULL"; + else { + oss << ", desc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocArray.desc__val); + } + oss << ", width="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMallocArray.width); + oss << ", height="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocArray.height); + oss << ", flags="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMallocArray.flags); + oss << ")"; + break; + case HIP_API_ID_hipMallocAsync: + oss << "hipMallocAsync("; + if (data->args.hipMallocAsync.dev_ptr == NULL) + oss << "dev_ptr=NULL"; + else { + oss << "dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocAsync.dev_ptr__val); + } + oss << ", size="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMallocAsync.size); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMallocFromPoolAsync: + oss << "hipMallocFromPoolAsync("; + if (data->args.hipMallocFromPoolAsync.dev_ptr == NULL) + oss << "dev_ptr=NULL"; + else { + oss << "dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocFromPoolAsync.dev_ptr__val); + } + oss << ", size="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocFromPoolAsync.size); + oss << ", mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocFromPoolAsync.mem_pool); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocFromPoolAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMallocHost: + oss << "hipMallocHost("; + if (data->args.hipMallocHost.ptr == NULL) + oss << "ptr=NULL"; + else { + oss << "ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocHost.ptr__val); + } + oss << ", size="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMallocHost.size); + oss << ")"; + break; + case HIP_API_ID_hipMallocManaged: + oss << "hipMallocManaged("; + if (data->args.hipMallocManaged.dev_ptr == NULL) + oss << "dev_ptr=NULL"; + else { + oss << "dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocManaged.dev_ptr__val); + } + oss << ", size="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocManaged.size); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocManaged.flags); + oss << ")"; + break; + case HIP_API_ID_hipMallocMipmappedArray: + oss << "hipMallocMipmappedArray("; + if (data->args.hipMallocMipmappedArray.mipmappedArray == NULL) + oss << "mipmappedArray=NULL"; + else { + oss << "mipmappedArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocMipmappedArray.mipmappedArray__val); + } + if (data->args.hipMallocMipmappedArray.desc == NULL) + oss << ", desc=NULL"; + else { + oss << ", desc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocMipmappedArray.desc__val); + } + oss << ", extent="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocMipmappedArray.extent); + oss << ", numLevels="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocMipmappedArray.numLevels); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocMipmappedArray.flags); + oss << ")"; + break; + case HIP_API_ID_hipMallocPitch: + oss << "hipMallocPitch("; + if (data->args.hipMallocPitch.ptr == NULL) + oss << "ptr=NULL"; + else { + oss << "ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocPitch.ptr__val); + } + if (data->args.hipMallocPitch.pitch == NULL) + oss << ", pitch=NULL"; + else { + oss << ", pitch="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocPitch.pitch__val); + } + oss << ", width="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMallocPitch.width); + oss << ", height="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMallocPitch.height); + oss << ")"; + break; + case HIP_API_ID_hipMemAddressFree: + oss << "hipMemAddressFree("; + oss << "devPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAddressFree.devPtr); + oss << ", size="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAddressFree.size); + oss << ")"; + break; + case HIP_API_ID_hipMemAddressReserve: + oss << "hipMemAddressReserve("; + if (data->args.hipMemAddressReserve.ptr == NULL) + oss << "ptr=NULL"; + else { + oss << "ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAddressReserve.ptr__val); + } + oss << ", size="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAddressReserve.size); + oss << ", alignment="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAddressReserve.alignment); + oss << ", addr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAddressReserve.addr); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAddressReserve.flags); + oss << ")"; + break; + case HIP_API_ID_hipMemAdvise: + oss << "hipMemAdvise("; + oss << "dev_ptr="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemAdvise.dev_ptr); + oss << ", count="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemAdvise.count); + oss << ", advice="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemAdvise.advice); + oss << ", device="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemAdvise.device); + oss << ")"; + break; + case HIP_API_ID_hipMemAllocHost: + oss << "hipMemAllocHost("; + if (data->args.hipMemAllocHost.ptr == NULL) + oss << "ptr=NULL"; + else { + oss << "ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAllocHost.ptr__val); + } + oss << ", size="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemAllocHost.size); + oss << ")"; + break; + case HIP_API_ID_hipMemAllocPitch: + oss << "hipMemAllocPitch("; + if (data->args.hipMemAllocPitch.dptr == NULL) + oss << "dptr=NULL"; + else { + oss << "dptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAllocPitch.dptr__val); + } + if (data->args.hipMemAllocPitch.pitch == NULL) + oss << ", pitch=NULL"; + else { + oss << ", pitch="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAllocPitch.pitch__val); + } + oss << ", widthInBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAllocPitch.widthInBytes); + oss << ", height="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAllocPitch.height); + oss << ", elementSizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemAllocPitch.elementSizeBytes); + oss << ")"; + break; + case HIP_API_ID_hipMemCreate: + oss << "hipMemCreate("; + if (data->args.hipMemCreate.handle == NULL) + oss << "handle=NULL"; + else { + oss << "handle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemCreate.handle__val); + } + oss << ", size="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemCreate.size); + if (data->args.hipMemCreate.prop == NULL) + oss << ", prop=NULL"; + else { + oss << ", prop="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemCreate.prop__val); + } + oss << ", flags="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemCreate.flags); + oss << ")"; + break; + case HIP_API_ID_hipMemExportToShareableHandle: + oss << "hipMemExportToShareableHandle("; + oss << "shareableHandle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemExportToShareableHandle.shareableHandle); + oss << ", handle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemExportToShareableHandle.handle); + oss << ", handleType="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemExportToShareableHandle.handleType); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemExportToShareableHandle.flags); + oss << ")"; + break; + case HIP_API_ID_hipMemGetAccess: + oss << "hipMemGetAccess("; + if (data->args.hipMemGetAccess.flags == NULL) + oss << "flags=NULL"; + else { + oss << "flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetAccess.flags__val); + } + if (data->args.hipMemGetAccess.location == NULL) + oss << ", location=NULL"; + else { + oss << ", location="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetAccess.location__val); + } + oss << ", ptr="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemGetAccess.ptr); + oss << ")"; + break; + case HIP_API_ID_hipMemGetAddressRange: + oss << "hipMemGetAddressRange("; + if (data->args.hipMemGetAddressRange.pbase == NULL) + oss << "pbase=NULL"; + else { + oss << "pbase="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetAddressRange.pbase__val); + } + if (data->args.hipMemGetAddressRange.psize == NULL) + oss << ", psize=NULL"; + else { + oss << ", psize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetAddressRange.psize__val); + } + oss << ", dptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetAddressRange.dptr); + oss << ")"; + break; + case HIP_API_ID_hipMemGetAllocationGranularity: + oss << "hipMemGetAllocationGranularity("; + if (data->args.hipMemGetAllocationGranularity.granularity == NULL) + oss << "granularity=NULL"; + else { + oss << "granularity="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetAllocationGranularity.granularity__val); + } + if (data->args.hipMemGetAllocationGranularity.prop == NULL) + oss << ", prop=NULL"; + else { + oss << ", prop="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetAllocationGranularity.prop__val); + } + oss << ", option="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetAllocationGranularity.option); + oss << ")"; + break; + case HIP_API_ID_hipMemGetAllocationPropertiesFromHandle: + oss << "hipMemGetAllocationPropertiesFromHandle("; + if (data->args.hipMemGetAllocationPropertiesFromHandle.prop == NULL) + oss << "prop=NULL"; + else { + oss << "prop="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetAllocationPropertiesFromHandle.prop__val); + } + oss << ", handle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetAllocationPropertiesFromHandle.handle); + oss << ")"; + break; + case HIP_API_ID_hipMemGetInfo: + oss << "hipMemGetInfo("; + if (data->args.hipMemGetInfo.free == NULL) + oss << "free=NULL"; + else { + oss << "free="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetInfo.free__val); + } + if (data->args.hipMemGetInfo.total == NULL) + oss << ", total=NULL"; + else { + oss << ", total="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemGetInfo.total__val); + } + oss << ")"; + break; + case HIP_API_ID_hipMemImportFromShareableHandle: + oss << "hipMemImportFromShareableHandle("; + if (data->args.hipMemImportFromShareableHandle.handle == NULL) + oss << "handle=NULL"; + else { + oss << "handle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemImportFromShareableHandle.handle__val); + } + oss << ", osHandle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemImportFromShareableHandle.osHandle); + oss << ", shHandleType="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemImportFromShareableHandle.shHandleType); + oss << ")"; + break; + case HIP_API_ID_hipMemMap: + oss << "hipMemMap("; + oss << "ptr="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemMap.ptr); + oss << ", size="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemMap.size); + oss << ", offset="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemMap.offset); + oss << ", handle="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemMap.handle); + oss << ", flags="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemMap.flags); + oss << ")"; + break; + case HIP_API_ID_hipMemMapArrayAsync: + oss << "hipMemMapArrayAsync("; + if (data->args.hipMemMapArrayAsync.mapInfoList == NULL) + oss << "mapInfoList=NULL"; + else { + oss << "mapInfoList="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemMapArrayAsync.mapInfoList__val); + } + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemMapArrayAsync.count); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemMapArrayAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemPoolCreate: + oss << "hipMemPoolCreate("; + if (data->args.hipMemPoolCreate.mem_pool == NULL) + oss << "mem_pool=NULL"; + else { + oss << "mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolCreate.mem_pool__val); + } + if (data->args.hipMemPoolCreate.pool_props == NULL) + oss << ", pool_props=NULL"; + else { + oss << ", pool_props="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolCreate.pool_props__val); + } + oss << ")"; + break; + case HIP_API_ID_hipMemPoolDestroy: + oss << "hipMemPoolDestroy("; + oss << "mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolDestroy.mem_pool); + oss << ")"; + break; + case HIP_API_ID_hipMemPoolExportPointer: + oss << "hipMemPoolExportPointer("; + if (data->args.hipMemPoolExportPointer.export_data == NULL) + oss << "export_data=NULL"; + else { + oss << "export_data="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolExportPointer.export_data__val); + } + oss << ", dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolExportPointer.dev_ptr); + oss << ")"; + break; + case HIP_API_ID_hipMemPoolExportToShareableHandle: + oss << "hipMemPoolExportToShareableHandle("; + oss << "shared_handle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolExportToShareableHandle.shared_handle); + oss << ", mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolExportToShareableHandle.mem_pool); + oss << ", handle_type="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolExportToShareableHandle.handle_type); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolExportToShareableHandle.flags); + oss << ")"; + break; + case HIP_API_ID_hipMemPoolGetAccess: + oss << "hipMemPoolGetAccess("; + if (data->args.hipMemPoolGetAccess.flags == NULL) + oss << "flags=NULL"; + else { + oss << "flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolGetAccess.flags__val); + } + oss << ", mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolGetAccess.mem_pool); + if (data->args.hipMemPoolGetAccess.location == NULL) + oss << ", location=NULL"; + else { + oss << ", location="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolGetAccess.location__val); + } + oss << ")"; + break; + case HIP_API_ID_hipMemPoolGetAttribute: + oss << "hipMemPoolGetAttribute("; + oss << "mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolGetAttribute.mem_pool); + oss << ", attr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolGetAttribute.attr); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolGetAttribute.value); + oss << ")"; + break; + case HIP_API_ID_hipMemPoolImportFromShareableHandle: + oss << "hipMemPoolImportFromShareableHandle("; + if (data->args.hipMemPoolImportFromShareableHandle.mem_pool == NULL) + oss << "mem_pool=NULL"; + else { + oss << "mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolImportFromShareableHandle.mem_pool__val); + } + oss << ", shared_handle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolImportFromShareableHandle.shared_handle); + oss << ", handle_type="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolImportFromShareableHandle.handle_type); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolImportFromShareableHandle.flags); + oss << ")"; + break; + case HIP_API_ID_hipMemPoolImportPointer: + oss << "hipMemPoolImportPointer("; + if (data->args.hipMemPoolImportPointer.dev_ptr == NULL) + oss << "dev_ptr=NULL"; + else { + oss << "dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolImportPointer.dev_ptr__val); + } + oss << ", mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolImportPointer.mem_pool); + if (data->args.hipMemPoolImportPointer.export_data == NULL) + oss << ", export_data=NULL"; + else { + oss << ", export_data="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolImportPointer.export_data__val); + } + oss << ")"; + break; + case HIP_API_ID_hipMemPoolSetAccess: + oss << "hipMemPoolSetAccess("; + oss << "mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolSetAccess.mem_pool); + if (data->args.hipMemPoolSetAccess.desc_list == NULL) + oss << ", desc_list=NULL"; + else { + oss << ", desc_list="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolSetAccess.desc_list__val); + } + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolSetAccess.count); + oss << ")"; + break; + case HIP_API_ID_hipMemPoolSetAttribute: + oss << "hipMemPoolSetAttribute("; + oss << "mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolSetAttribute.mem_pool); + oss << ", attr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolSetAttribute.attr); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolSetAttribute.value); + oss << ")"; + break; + case HIP_API_ID_hipMemPoolTrimTo: + oss << "hipMemPoolTrimTo("; + oss << "mem_pool="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolTrimTo.mem_pool); + oss << ", min_bytes_to_hold="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPoolTrimTo.min_bytes_to_hold); + oss << ")"; + break; + case HIP_API_ID_hipMemPrefetchAsync: + oss << "hipMemPrefetchAsync("; + oss << "dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPrefetchAsync.dev_ptr); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPrefetchAsync.count); + oss << ", device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPrefetchAsync.device); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPrefetchAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemPtrGetInfo: + oss << "hipMemPtrGetInfo("; + oss << "ptr="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemPtrGetInfo.ptr); + if (data->args.hipMemPtrGetInfo.size == NULL) + oss << ", size=NULL"; + else { + oss << ", size="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemPtrGetInfo.size__val); + } + oss << ")"; + break; + case HIP_API_ID_hipMemRangeGetAttribute: + oss << "hipMemRangeGetAttribute("; + oss << "data="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRangeGetAttribute.data); + oss << ", data_size="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRangeGetAttribute.data_size); + oss << ", attribute="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRangeGetAttribute.attribute); + oss << ", dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRangeGetAttribute.dev_ptr); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRangeGetAttribute.count); + oss << ")"; + break; + case HIP_API_ID_hipMemRangeGetAttributes: + oss << "hipMemRangeGetAttributes("; + if (data->args.hipMemRangeGetAttributes.data == NULL) + oss << "data=NULL"; + else { + oss << "data="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRangeGetAttributes.data__val); + } + if (data->args.hipMemRangeGetAttributes.data_sizes == NULL) + oss << ", data_sizes=NULL"; + else { + oss << ", data_sizes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRangeGetAttributes.data_sizes__val); + } + if (data->args.hipMemRangeGetAttributes.attributes == NULL) + oss << ", attributes=NULL"; + else { + oss << ", attributes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRangeGetAttributes.attributes__val); + } + oss << ", num_attributes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRangeGetAttributes.num_attributes); + oss << ", dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRangeGetAttributes.dev_ptr); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRangeGetAttributes.count); + oss << ")"; + break; + case HIP_API_ID_hipMemRelease: + oss << "hipMemRelease("; + oss << "handle="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemRelease.handle); + oss << ")"; + break; + case HIP_API_ID_hipMemRetainAllocationHandle: + oss << "hipMemRetainAllocationHandle("; + if (data->args.hipMemRetainAllocationHandle.handle == NULL) + oss << "handle=NULL"; + else { + oss << "handle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRetainAllocationHandle.handle__val); + } + oss << ", addr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemRetainAllocationHandle.addr); + oss << ")"; + break; + case HIP_API_ID_hipMemSetAccess: + oss << "hipMemSetAccess("; + oss << "ptr="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemSetAccess.ptr); + oss << ", size="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemSetAccess.size); + if (data->args.hipMemSetAccess.desc == NULL) + oss << ", desc=NULL"; + else { + oss << ", desc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemSetAccess.desc__val); + } + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemSetAccess.count); + oss << ")"; + break; + case HIP_API_ID_hipMemUnmap: + oss << "hipMemUnmap("; + oss << "ptr="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemUnmap.ptr); + oss << ", size="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemUnmap.size); + oss << ")"; + break; + case HIP_API_ID_hipMemcpy: + oss << "hipMemcpy("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy.src); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpy.sizeBytes); + oss << ", kind="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy.kind); + oss << ")"; + break; + case HIP_API_ID_hipMemcpy2D: + oss << "hipMemcpy2D("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2D.dst); + oss << ", dpitch="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpy2D.dpitch); + oss << ", src="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemcpy2D.src); + oss << ", spitch="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpy2D.spitch); + oss << ", width="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpy2D.width); + oss << ", height="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpy2D.height); + oss << ", kind="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpy2D.kind); + oss << ")"; + break; + case HIP_API_ID_hipMemcpy2DArrayToArray: + oss << "hipMemcpy2DArrayToArray("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DArrayToArray.dst); + oss << ", wOffsetDst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DArrayToArray.wOffsetDst); + oss << ", hOffsetDst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DArrayToArray.hOffsetDst); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DArrayToArray.src); + oss << ", wOffsetSrc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DArrayToArray.wOffsetSrc); + oss << ", hOffsetSrc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DArrayToArray.hOffsetSrc); + oss << ", width="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DArrayToArray.width); + oss << ", height="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DArrayToArray.height); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DArrayToArray.kind); + oss << ")"; + break; + case HIP_API_ID_hipMemcpy2DAsync: + oss << "hipMemcpy2DAsync("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpy2DAsync.dst); + oss << ", dpitch="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DAsync.dpitch); + oss << ", src="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpy2DAsync.src); + oss << ", spitch="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DAsync.spitch); + oss << ", width="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DAsync.width); + oss << ", height="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DAsync.height); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DAsync.kind); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpy2DFromArray: + oss << "hipMemcpy2DFromArray("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArray.dst); + oss << ", dpitch="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArray.dpitch); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArray.src); + oss << ", wOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArray.wOffset); + oss << ", hOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArray.hOffset); + oss << ", width="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArray.width); + oss << ", height="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArray.height); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArray.kind); + oss << ")"; + break; + case HIP_API_ID_hipMemcpy2DFromArrayAsync: + oss << "hipMemcpy2DFromArrayAsync("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArrayAsync.dst); + oss << ", dpitch="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArrayAsync.dpitch); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArrayAsync.src); + oss << ", wOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArrayAsync.wOffset); + oss << ", hOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArrayAsync.hOffset); + oss << ", width="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArrayAsync.width); + oss << ", height="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArrayAsync.height); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArrayAsync.kind); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DFromArrayAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpy2DToArray: + oss << "hipMemcpy2DToArray("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArray.dst); + oss << ", wOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArray.wOffset); + oss << ", hOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArray.hOffset); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArray.src); + oss << ", spitch="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArray.spitch); + oss << ", width="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArray.width); + oss << ", height="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArray.height); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArray.kind); + oss << ")"; + break; + case HIP_API_ID_hipMemcpy2DToArrayAsync: + oss << "hipMemcpy2DToArrayAsync("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArrayAsync.dst); + oss << ", wOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArrayAsync.wOffset); + oss << ", hOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArrayAsync.hOffset); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArrayAsync.src); + oss << ", spitch="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArrayAsync.spitch); + oss << ", width="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArrayAsync.width); + oss << ", height="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArrayAsync.height); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArrayAsync.kind); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy2DToArrayAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpy3D: + oss << "hipMemcpy3D("; + if (data->args.hipMemcpy3D.p == NULL) + oss << "p=NULL"; + else { + oss << "p="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpy3D.p__val); + } + oss << ")"; + break; + case HIP_API_ID_hipMemcpy3DAsync: + oss << "hipMemcpy3DAsync("; + if (data->args.hipMemcpy3DAsync.p == NULL) + oss << "p=NULL"; + else { + oss << "p="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy3DAsync.p__val); + } + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpy3DAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyAsync: + oss << "hipMemcpyAsync("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyAsync.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyAsync.src); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAsync.sizeBytes); + oss << ", kind="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyAsync.kind); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyAtoA: + oss << "hipMemcpyAtoA("; + oss << "dstArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoA.dstArray); + oss << ", dstOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoA.dstOffset); + oss << ", srcArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoA.srcArray); + oss << ", srcOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoA.srcOffset); + oss << ", ByteCount="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoA.ByteCount); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyAtoD: + oss << "hipMemcpyAtoD("; + oss << "dstDevice="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoD.dstDevice); + oss << ", srcArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoD.srcArray); + oss << ", srcOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoD.srcOffset); + oss << ", ByteCount="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoD.ByteCount); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyAtoH: + oss << "hipMemcpyAtoH("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyAtoH.dst); + oss << ", srcArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoH.srcArray); + oss << ", srcOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoH.srcOffset); + oss << ", count="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyAtoH.count); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyAtoHAsync: + oss << "hipMemcpyAtoHAsync("; + oss << "dstHost="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoHAsync.dstHost); + oss << ", srcArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoHAsync.srcArray); + oss << ", srcOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoHAsync.srcOffset); + oss << ", ByteCount="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoHAsync.ByteCount); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyAtoHAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyDtoA: + oss << "hipMemcpyDtoA("; + oss << "dstArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoA.dstArray); + oss << ", dstOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoA.dstOffset); + oss << ", srcDevice="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoA.srcDevice); + oss << ", ByteCount="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoA.ByteCount); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyDtoD: + oss << "hipMemcpyDtoD("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyDtoD.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyDtoD.src); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoD.sizeBytes); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyDtoDAsync: + oss << "hipMemcpyDtoDAsync("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoDAsync.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoDAsync.src); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoDAsync.sizeBytes); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoDAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyDtoH: + oss << "hipMemcpyDtoH("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyDtoH.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyDtoH.src); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoH.sizeBytes); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyDtoHAsync: + oss << "hipMemcpyDtoHAsync("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoHAsync.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoHAsync.src); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoHAsync.sizeBytes); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyDtoHAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyFromArray: + oss << "hipMemcpyFromArray("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromArray.dst); + oss << ", srcArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromArray.srcArray); + oss << ", wOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromArray.wOffset); + oss << ", hOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromArray.hOffset); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromArray.count); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromArray.kind); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyFromSymbol: + oss << "hipMemcpyFromSymbol("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromSymbol.dst); + oss << ", symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromSymbol.symbol); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromSymbol.sizeBytes); + oss << ", offset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromSymbol.offset); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromSymbol.kind); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyFromSymbolAsync: + oss << "hipMemcpyFromSymbolAsync("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromSymbolAsync.dst); + oss << ", symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromSymbolAsync.symbol); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromSymbolAsync.sizeBytes); + oss << ", offset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromSymbolAsync.offset); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromSymbolAsync.kind); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyFromSymbolAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyHtoA: + oss << "hipMemcpyHtoA("; + oss << "dstArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoA.dstArray); + oss << ", dstOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoA.dstOffset); + oss << ", srcHost="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoA.srcHost); + oss << ", count="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyHtoA.count); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyHtoAAsync: + oss << "hipMemcpyHtoAAsync("; + oss << "dstArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoAAsync.dstArray); + oss << ", dstOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoAAsync.dstOffset); + oss << ", srcHost="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoAAsync.srcHost); + oss << ", ByteCount="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoAAsync.ByteCount); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoAAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyHtoD: + oss << "hipMemcpyHtoD("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyHtoD.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyHtoD.src); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoD.sizeBytes); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyHtoDAsync: + oss << "hipMemcpyHtoDAsync("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoDAsync.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoDAsync.src); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoDAsync.sizeBytes); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyHtoDAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyParam2D: + oss << "hipMemcpyParam2D("; + if (data->args.hipMemcpyParam2D.pCopy == NULL) + oss << "pCopy=NULL"; + else { + oss << "pCopy="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyParam2D.pCopy__val); + } + oss << ")"; + break; + case HIP_API_ID_hipMemcpyParam2DAsync: + oss << "hipMemcpyParam2DAsync("; + if (data->args.hipMemcpyParam2DAsync.pCopy == NULL) + oss << "pCopy=NULL"; + else { + oss << "pCopy="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyParam2DAsync.pCopy__val); + } + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyParam2DAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyPeer: + oss << "hipMemcpyPeer("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyPeer.dst); + oss << ", dstDeviceId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyPeer.dstDeviceId); + oss << ", src="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyPeer.src); + oss << ", srcDeviceId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyPeer.srcDeviceId); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyPeer.sizeBytes); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyPeerAsync: + oss << "hipMemcpyPeerAsync("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyPeerAsync.dst); + oss << ", dstDeviceId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyPeerAsync.dstDeviceId); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyPeerAsync.src); + oss << ", srcDevice="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyPeerAsync.srcDevice); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyPeerAsync.sizeBytes); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyPeerAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyToArray: + oss << "hipMemcpyToArray("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyToArray.dst); + oss << ", wOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToArray.wOffset); + oss << ", hOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToArray.hOffset); + oss << ", src="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemcpyToArray.src); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToArray.count); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToArray.kind); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyToSymbol: + oss << "hipMemcpyToSymbol("; + oss << "symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToSymbol.symbol); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToSymbol.src); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToSymbol.sizeBytes); + oss << ", offset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToSymbol.offset); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToSymbol.kind); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyToSymbolAsync: + oss << "hipMemcpyToSymbolAsync("; + oss << "symbol="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToSymbolAsync.symbol); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToSymbolAsync.src); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToSymbolAsync.sizeBytes); + oss << ", offset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToSymbolAsync.offset); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToSymbolAsync.kind); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyToSymbolAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemcpyWithStream: + oss << "hipMemcpyWithStream("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyWithStream.dst); + oss << ", src="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyWithStream.src); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyWithStream.sizeBytes); + oss << ", kind="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyWithStream.kind); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemcpyWithStream.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemset: + oss << "hipMemset("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemset.dst); + oss << ", value="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemset.value); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemset.sizeBytes); + oss << ")"; + break; + case HIP_API_ID_hipMemset2D: + oss << "hipMemset2D("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, data->args.hipMemset2D.dst); + oss << ", pitch="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemset2D.pitch); + oss << ", value="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemset2D.value); + oss << ", width="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemset2D.width); + oss << ", height="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemset2D.height); + oss << ")"; + break; + case HIP_API_ID_hipMemset2DAsync: + oss << "hipMemset2DAsync("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemset2DAsync.dst); + oss << ", pitch="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemset2DAsync.pitch); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemset2DAsync.value); + oss << ", width="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemset2DAsync.width); + oss << ", height="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemset2DAsync.height); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemset2DAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemset3D: + oss << "hipMemset3D("; + oss << "pitchedDevPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemset3D.pitchedDevPtr); + oss << ", value="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemset3D.value); + oss << ", extent="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemset3D.extent); + oss << ")"; + break; + case HIP_API_ID_hipMemset3DAsync: + oss << "hipMemset3DAsync("; + oss << "pitchedDevPtr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemset3DAsync.pitchedDevPtr); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemset3DAsync.value); + oss << ", extent="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemset3DAsync.extent); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemset3DAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemsetAsync: + oss << "hipMemsetAsync("; + oss << "dst="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemsetAsync.dst); + oss << ", value="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemsetAsync.value); + oss << ", sizeBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetAsync.sizeBytes); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemsetD16: + oss << "hipMemsetD16("; + oss << "dest="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemsetD16.dest); + oss << ", value="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemsetD16.value); + oss << ", count="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemsetD16.count); + oss << ")"; + break; + case HIP_API_ID_hipMemsetD16Async: + oss << "hipMemsetD16Async("; + oss << "dest="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD16Async.dest); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD16Async.value); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD16Async.count); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD16Async.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemsetD32: + oss << "hipMemsetD32("; + oss << "dest="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemsetD32.dest); + oss << ", value="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemsetD32.value); + oss << ", count="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemsetD32.count); + oss << ")"; + break; + case HIP_API_ID_hipMemsetD32Async: + oss << "hipMemsetD32Async("; + oss << "dst="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD32Async.dst); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD32Async.value); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD32Async.count); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD32Async.stream); + oss << ")"; + break; + case HIP_API_ID_hipMemsetD8: + oss << "hipMemsetD8("; + oss << "dest="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemsetD8.dest); + oss << ", value="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemsetD8.value); + oss << ", count="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipMemsetD8.count); + oss << ")"; + break; + case HIP_API_ID_hipMemsetD8Async: + oss << "hipMemsetD8Async("; + oss << "dest="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD8Async.dest); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD8Async.value); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD8Async.count); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMemsetD8Async.stream); + oss << ")"; + break; + case HIP_API_ID_hipMipmappedArrayCreate: + oss << "hipMipmappedArrayCreate("; + if (data->args.hipMipmappedArrayCreate.pHandle == NULL) + oss << "pHandle=NULL"; + else { + oss << "pHandle="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMipmappedArrayCreate.pHandle__val); + } + if (data->args.hipMipmappedArrayCreate.pMipmappedArrayDesc == NULL) + oss << ", pMipmappedArrayDesc=NULL"; + else { + oss << ", pMipmappedArrayDesc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMipmappedArrayCreate.pMipmappedArrayDesc__val); + } + oss << ", numMipmapLevels="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMipmappedArrayCreate.numMipmapLevels); + oss << ")"; + break; + case HIP_API_ID_hipMipmappedArrayDestroy: + oss << "hipMipmappedArrayDestroy("; + oss << "hMipmappedArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMipmappedArrayDestroy.hMipmappedArray); + oss << ")"; + break; + case HIP_API_ID_hipMipmappedArrayGetLevel: + oss << "hipMipmappedArrayGetLevel("; + if (data->args.hipMipmappedArrayGetLevel.pLevelArray == NULL) + oss << "pLevelArray=NULL"; + else { + oss << "pLevelArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMipmappedArrayGetLevel.pLevelArray__val); + } + oss << ", hMipMappedArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMipmappedArrayGetLevel.hMipMappedArray); + oss << ", level="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipMipmappedArrayGetLevel.level); + oss << ")"; + break; + case HIP_API_ID_hipModuleGetFunction: + oss << "hipModuleGetFunction("; + if (data->args.hipModuleGetFunction.function == NULL) + oss << "function=NULL"; + else { + oss << "function="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleGetFunction.function__val); + } + oss << ", module="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleGetFunction.module); + if (data->args.hipModuleGetFunction.kname == NULL) + oss << ", kname=NULL"; + else { + oss << ", kname="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleGetFunction.kname__val); + } + oss << ")"; + break; + case HIP_API_ID_hipModuleGetGlobal: + oss << "hipModuleGetGlobal("; + if (data->args.hipModuleGetGlobal.dptr == NULL) + oss << "dptr=NULL"; + else { + oss << "dptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleGetGlobal.dptr__val); + } + if (data->args.hipModuleGetGlobal.bytes == NULL) + oss << ", bytes=NULL"; + else { + oss << ", bytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleGetGlobal.bytes__val); + } + oss << ", hmod="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleGetGlobal.hmod); + if (data->args.hipModuleGetGlobal.name == NULL) + oss << ", name=NULL"; + else { + oss << ", name="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleGetGlobal.name__val); + } + oss << ")"; + break; + case HIP_API_ID_hipModuleGetTexRef: + oss << "hipModuleGetTexRef("; + if (data->args.hipModuleGetTexRef.texRef == NULL) + oss << "texRef=NULL"; + else { + oss << "texRef="; + roctracer::hip_support::detail::operator<<( + oss, (void *)data->args.hipModuleGetTexRef.texRef__val); + } + oss << ", hmod="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleGetTexRef.hmod); + if (data->args.hipModuleGetTexRef.name == NULL) + oss << ", name=NULL"; + else { + oss << ", name="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleGetTexRef.name__val); + } + oss << ")"; + break; + case HIP_API_ID_hipModuleLaunchCooperativeKernel: + oss << "hipModuleLaunchCooperativeKernel("; + oss << "f="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernel.f); + oss << ", gridDimX="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernel.gridDimX); + oss << ", gridDimY="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernel.gridDimY); + oss << ", gridDimZ="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernel.gridDimZ); + oss << ", blockDimX="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernel.blockDimX); + oss << ", blockDimY="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernel.blockDimY); + oss << ", blockDimZ="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernel.blockDimZ); + oss << ", sharedMemBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernel.sharedMemBytes); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernel.stream); + if (data->args.hipModuleLaunchCooperativeKernel.kernelParams == NULL) + oss << ", kernelParams=NULL"; + else { + oss << ", kernelParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernel.kernelParams__val); + } + oss << ")"; + break; + case HIP_API_ID_hipModuleLaunchCooperativeKernelMultiDevice: + oss << "hipModuleLaunchCooperativeKernelMultiDevice("; + if (data->args.hipModuleLaunchCooperativeKernelMultiDevice + .launchParamsList == NULL) + oss << "launchParamsList=NULL"; + else { + oss << "launchParamsList="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernelMultiDevice + .launchParamsList__val); + } + oss << ", numDevices="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernelMultiDevice.numDevices); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchCooperativeKernelMultiDevice.flags); + oss << ")"; + break; + case HIP_API_ID_hipModuleLaunchKernel: + oss << "hipModuleLaunchKernel("; + oss << "f="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchKernel.f); + oss << ", gridDimX="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchKernel.gridDimX); + oss << ", gridDimY="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchKernel.gridDimY); + oss << ", gridDimZ="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchKernel.gridDimZ); + oss << ", blockDimX="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchKernel.blockDimX); + oss << ", blockDimY="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchKernel.blockDimY); + oss << ", blockDimZ="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchKernel.blockDimZ); + oss << ", sharedMemBytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchKernel.sharedMemBytes); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchKernel.stream); + if (data->args.hipModuleLaunchKernel.kernelParams == NULL) + oss << ", kernelParams=NULL"; + else { + oss << ", kernelParams="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchKernel.kernelParams__val); + } + if (data->args.hipModuleLaunchKernel.extra == NULL) + oss << ", extra=NULL"; + else { + oss << ", extra="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLaunchKernel.extra__val); + } + oss << ")"; + break; + case HIP_API_ID_hipModuleLoad: + oss << "hipModuleLoad("; + if (data->args.hipModuleLoad.module == NULL) + oss << "module=NULL"; + else { + oss << "module="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLoad.module__val); + } + if (data->args.hipModuleLoad.fname == NULL) + oss << ", fname=NULL"; + else { + oss << ", fname="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLoad.fname__val); + } + oss << ")"; + break; + case HIP_API_ID_hipModuleLoadData: + oss << "hipModuleLoadData("; + if (data->args.hipModuleLoadData.module == NULL) + oss << "module=NULL"; + else { + oss << "module="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLoadData.module__val); + } + oss << ", image="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLoadData.image); + oss << ")"; + break; + case HIP_API_ID_hipModuleLoadDataEx: + oss << "hipModuleLoadDataEx("; + if (data->args.hipModuleLoadDataEx.module == NULL) + oss << "module=NULL"; + else { + oss << "module="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLoadDataEx.module__val); + } + oss << ", image="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLoadDataEx.image); + oss << ", numOptions="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLoadDataEx.numOptions); + if (data->args.hipModuleLoadDataEx.options == NULL) + oss << ", options=NULL"; + else { + oss << ", options="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLoadDataEx.options__val); + } + if (data->args.hipModuleLoadDataEx.optionsValues == NULL) + oss << ", optionsValues=NULL"; + else { + oss << ", optionsValues="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleLoadDataEx.optionsValues__val); + } + oss << ")"; + break; + case HIP_API_ID_hipModuleOccupancyMaxActiveBlocksPerMultiprocessor: + oss << "hipModuleOccupancyMaxActiveBlocksPerMultiprocessor("; + if (data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor + .numBlocks == NULL) + oss << "numBlocks=NULL"; + else { + oss << "numBlocks="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor + .numBlocks__val); + } + oss << ", f="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor.f); + oss << ", blockSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor + .blockSize); + oss << ", dynSharedMemPerBlk="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessor + .dynSharedMemPerBlk); + oss << ")"; + break; + case HIP_API_ID_hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: + oss << "hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags("; + if (data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .numBlocks == NULL) + oss << "numBlocks=NULL"; + else { + oss << "numBlocks="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .numBlocks__val); + } + oss << ", f="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .f); + oss << ", blockSize="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .blockSize); + oss << ", dynSharedMemPerBlk="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .dynSharedMemPerBlk); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .flags); + oss << ")"; + break; + case HIP_API_ID_hipModuleOccupancyMaxPotentialBlockSize: + oss << "hipModuleOccupancyMaxPotentialBlockSize("; + if (data->args.hipModuleOccupancyMaxPotentialBlockSize.gridSize == NULL) + oss << "gridSize=NULL"; + else { + oss << "gridSize="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipModuleOccupancyMaxPotentialBlockSize.gridSize__val); + } + if (data->args.hipModuleOccupancyMaxPotentialBlockSize.blockSize == NULL) + oss << ", blockSize=NULL"; + else { + oss << ", blockSize="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipModuleOccupancyMaxPotentialBlockSize.blockSize__val); + } + oss << ", f="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxPotentialBlockSize.f); + oss << ", dynSharedMemPerBlk="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipModuleOccupancyMaxPotentialBlockSize.dynSharedMemPerBlk); + oss << ", blockSizeLimit="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxPotentialBlockSize.blockSizeLimit); + oss << ")"; + break; + case HIP_API_ID_hipModuleOccupancyMaxPotentialBlockSizeWithFlags: + oss << "hipModuleOccupancyMaxPotentialBlockSizeWithFlags("; + if (data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags.gridSize == + NULL) + oss << "gridSize=NULL"; + else { + oss << "gridSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags + .gridSize__val); + } + if (data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags.blockSize == + NULL) + oss << ", blockSize=NULL"; + else { + oss << ", blockSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags + .blockSize__val); + } + oss << ", f="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags.f); + oss << ", dynSharedMemPerBlk="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags + .dynSharedMemPerBlk); + oss << ", blockSizeLimit="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags + .blockSizeLimit); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleOccupancyMaxPotentialBlockSizeWithFlags.flags); + oss << ")"; + break; + case HIP_API_ID_hipModuleUnload: + oss << "hipModuleUnload("; + oss << "module="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipModuleUnload.module); + oss << ")"; + break; + case HIP_API_ID_hipOccupancyMaxActiveBlocksPerMultiprocessor: + oss << "hipOccupancyMaxActiveBlocksPerMultiprocessor("; + if (data->args.hipOccupancyMaxActiveBlocksPerMultiprocessor.numBlocks == + NULL) + oss << "numBlocks=NULL"; + else { + oss << "numBlocks="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxActiveBlocksPerMultiprocessor + .numBlocks__val); + } + oss << ", f="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxActiveBlocksPerMultiprocessor.f); + oss << ", blockSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxActiveBlocksPerMultiprocessor.blockSize); + oss << ", dynamicSMemSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxActiveBlocksPerMultiprocessor + .dynamicSMemSize); + oss << ")"; + break; + case HIP_API_ID_hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: + oss << "hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags("; + if (data->args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .numBlocks == NULL) + oss << "numBlocks=NULL"; + else { + oss << "numBlocks="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .numBlocks__val); + } + oss << ", f="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags.f); + oss << ", blockSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .blockSize); + oss << ", dynamicSMemSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags + .dynamicSMemSize); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags.flags); + oss << ")"; + break; + case HIP_API_ID_hipOccupancyMaxPotentialBlockSize: + oss << "hipOccupancyMaxPotentialBlockSize("; + if (data->args.hipOccupancyMaxPotentialBlockSize.gridSize == NULL) + oss << "gridSize=NULL"; + else { + oss << "gridSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxPotentialBlockSize.gridSize__val); + } + if (data->args.hipOccupancyMaxPotentialBlockSize.blockSize == NULL) + oss << ", blockSize=NULL"; + else { + oss << ", blockSize="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxPotentialBlockSize.blockSize__val); + } + oss << ", f="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxPotentialBlockSize.f); + oss << ", dynSharedMemPerBlk="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxPotentialBlockSize.dynSharedMemPerBlk); + oss << ", blockSizeLimit="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipOccupancyMaxPotentialBlockSize.blockSizeLimit); + oss << ")"; + break; + case HIP_API_ID_hipPeekAtLastError: + oss << "hipPeekAtLastError("; + oss << ")"; + break; + case HIP_API_ID_hipPointerGetAttribute: + oss << "hipPointerGetAttribute("; + oss << "data="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipPointerGetAttribute.data); + oss << ", attribute="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipPointerGetAttribute.attribute); + oss << ", ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipPointerGetAttribute.ptr); + oss << ")"; + break; + case HIP_API_ID_hipPointerGetAttributes: + oss << "hipPointerGetAttributes("; + if (data->args.hipPointerGetAttributes.attributes == NULL) + oss << "attributes=NULL"; + else { + oss << "attributes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipPointerGetAttributes.attributes__val); + } + oss << ", ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipPointerGetAttributes.ptr); + oss << ")"; + break; + case HIP_API_ID_hipPointerSetAttribute: + oss << "hipPointerSetAttribute("; + oss << "value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipPointerSetAttribute.value); + oss << ", attribute="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipPointerSetAttribute.attribute); + oss << ", ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipPointerSetAttribute.ptr); + oss << ")"; + break; + case HIP_API_ID_hipProfilerStart: + oss << "hipProfilerStart("; + oss << ")"; + break; + case HIP_API_ID_hipProfilerStop: + oss << "hipProfilerStop("; + oss << ")"; + break; + case HIP_API_ID_hipRuntimeGetVersion: + oss << "hipRuntimeGetVersion("; + if (data->args.hipRuntimeGetVersion.runtimeVersion == NULL) + oss << "runtimeVersion=NULL"; + else { + oss << "runtimeVersion="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipRuntimeGetVersion.runtimeVersion__val); + } + oss << ")"; + break; + case HIP_API_ID_hipSetDevice: + oss << "hipSetDevice("; + oss << "deviceId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipSetDevice.deviceId); + oss << ")"; + break; + case HIP_API_ID_hipSetDeviceFlags: + oss << "hipSetDeviceFlags("; + oss << "flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipSetDeviceFlags.flags); + oss << ")"; + break; + case HIP_API_ID_hipSetValidDevices: + oss << "hipSetValidDevices("; + if (data->args.hipSetValidDevices.device_arr == NULL) + oss << "device_arr=NULL"; + else { + oss << "device_arr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipSetValidDevices.device_arr__val); + } + oss << ", len="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipSetValidDevices.len); + oss << ")"; + break; + case HIP_API_ID_hipSetupArgument: + oss << "hipSetupArgument("; + oss << "arg="; + roctracer::hip_support::detail::operator<<(oss, + data->args.hipSetupArgument.arg); + oss << ", size="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipSetupArgument.size); + oss << ", offset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipSetupArgument.offset); + oss << ")"; + break; + case HIP_API_ID_hipSignalExternalSemaphoresAsync: + oss << "hipSignalExternalSemaphoresAsync("; + if (data->args.hipSignalExternalSemaphoresAsync.extSemArray == NULL) + oss << "extSemArray=NULL"; + else { + oss << "extSemArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipSignalExternalSemaphoresAsync.extSemArray__val); + } + if (data->args.hipSignalExternalSemaphoresAsync.paramsArray == NULL) + oss << ", paramsArray=NULL"; + else { + oss << ", paramsArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipSignalExternalSemaphoresAsync.paramsArray__val); + } + oss << ", numExtSems="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipSignalExternalSemaphoresAsync.numExtSems); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipSignalExternalSemaphoresAsync.stream); + oss << ")"; + break; + case HIP_API_ID_hipStreamAddCallback: + oss << "hipStreamAddCallback("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamAddCallback.stream); + oss << ", callback="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamAddCallback.callback); + oss << ", userData="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamAddCallback.userData); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamAddCallback.flags); + oss << ")"; + break; + case HIP_API_ID_hipStreamAttachMemAsync: + oss << "hipStreamAttachMemAsync("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamAttachMemAsync.stream); + oss << ", dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamAttachMemAsync.dev_ptr); + oss << ", length="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamAttachMemAsync.length); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamAttachMemAsync.flags); + oss << ")"; + break; + case HIP_API_ID_hipStreamBeginCapture: + oss << "hipStreamBeginCapture("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamBeginCapture.stream); + oss << ", mode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamBeginCapture.mode); + oss << ")"; + break; + case HIP_API_ID_hipStreamBeginCaptureToGraph: + oss << "hipStreamBeginCaptureToGraph("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamBeginCaptureToGraph.stream); + oss << ", graph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamBeginCaptureToGraph.graph); + if (data->args.hipStreamBeginCaptureToGraph.dependencies == NULL) + oss << ", dependencies=NULL"; + else { + oss << ", dependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamBeginCaptureToGraph.dependencies__val); + } + if (data->args.hipStreamBeginCaptureToGraph.dependencyData == NULL) + oss << ", dependencyData=NULL"; + else { + oss << ", dependencyData="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamBeginCaptureToGraph.dependencyData__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamBeginCaptureToGraph.numDependencies); + oss << ", mode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamBeginCaptureToGraph.mode); + oss << ")"; + break; + case HIP_API_ID_hipStreamCreate: + oss << "hipStreamCreate("; + if (data->args.hipStreamCreate.stream == NULL) + oss << "stream=NULL"; + else { + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamCreate.stream__val); + } + oss << ")"; + break; + case HIP_API_ID_hipStreamCreateWithFlags: + oss << "hipStreamCreateWithFlags("; + if (data->args.hipStreamCreateWithFlags.stream == NULL) + oss << "stream=NULL"; + else { + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamCreateWithFlags.stream__val); + } + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamCreateWithFlags.flags); + oss << ")"; + break; + case HIP_API_ID_hipStreamCreateWithPriority: + oss << "hipStreamCreateWithPriority("; + if (data->args.hipStreamCreateWithPriority.stream == NULL) + oss << "stream=NULL"; + else { + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamCreateWithPriority.stream__val); + } + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamCreateWithPriority.flags); + oss << ", priority="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamCreateWithPriority.priority); + oss << ")"; + break; + case HIP_API_ID_hipStreamDestroy: + oss << "hipStreamDestroy("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamDestroy.stream); + oss << ")"; + break; + case HIP_API_ID_hipStreamEndCapture: + oss << "hipStreamEndCapture("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamEndCapture.stream); + if (data->args.hipStreamEndCapture.pGraph == NULL) + oss << ", pGraph=NULL"; + else { + oss << ", pGraph="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamEndCapture.pGraph__val); + } + oss << ")"; + break; + case HIP_API_ID_hipStreamGetCaptureInfo: + oss << "hipStreamGetCaptureInfo("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetCaptureInfo.stream); + if (data->args.hipStreamGetCaptureInfo.pCaptureStatus == NULL) + oss << ", pCaptureStatus=NULL"; + else { + oss << ", pCaptureStatus="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetCaptureInfo.pCaptureStatus__val); + } + if (data->args.hipStreamGetCaptureInfo.pId == NULL) + oss << ", pId=NULL"; + else { + oss << ", pId="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetCaptureInfo.pId__val); + } + oss << ")"; + break; + case HIP_API_ID_hipStreamGetCaptureInfo_v2: + oss << "hipStreamGetCaptureInfo_v2("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetCaptureInfo_v2.stream); + if (data->args.hipStreamGetCaptureInfo_v2.captureStatus_out == NULL) + oss << ", captureStatus_out=NULL"; + else { + oss << ", captureStatus_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetCaptureInfo_v2.captureStatus_out__val); + } + if (data->args.hipStreamGetCaptureInfo_v2.id_out == NULL) + oss << ", id_out=NULL"; + else { + oss << ", id_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetCaptureInfo_v2.id_out__val); + } + if (data->args.hipStreamGetCaptureInfo_v2.graph_out == NULL) + oss << ", graph_out=NULL"; + else { + oss << ", graph_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetCaptureInfo_v2.graph_out__val); + } + if (data->args.hipStreamGetCaptureInfo_v2.dependencies_out == NULL) + oss << ", dependencies_out=NULL"; + else { + oss << ", dependencies_out="; + roctracer::hip_support::detail::operator<<( + oss, + (void *)data->args.hipStreamGetCaptureInfo_v2.dependencies_out__val); + } + if (data->args.hipStreamGetCaptureInfo_v2.numDependencies_out == NULL) + oss << ", numDependencies_out=NULL"; + else { + oss << ", numDependencies_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetCaptureInfo_v2.numDependencies_out__val); + } + oss << ")"; + break; + case HIP_API_ID_hipStreamGetDevice: + oss << "hipStreamGetDevice("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetDevice.stream); + if (data->args.hipStreamGetDevice.device == NULL) + oss << ", device=NULL"; + else { + oss << ", device="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetDevice.device__val); + } + oss << ")"; + break; + case HIP_API_ID_hipStreamGetFlags: + oss << "hipStreamGetFlags("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetFlags.stream); + if (data->args.hipStreamGetFlags.flags == NULL) + oss << ", flags=NULL"; + else { + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetFlags.flags__val); + } + oss << ")"; + break; + case HIP_API_ID_hipStreamGetPriority: + oss << "hipStreamGetPriority("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetPriority.stream); + if (data->args.hipStreamGetPriority.priority == NULL) + oss << ", priority=NULL"; + else { + oss << ", priority="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamGetPriority.priority__val); + } + oss << ")"; + break; + case HIP_API_ID_hipStreamIsCapturing: + oss << "hipStreamIsCapturing("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamIsCapturing.stream); + if (data->args.hipStreamIsCapturing.pCaptureStatus == NULL) + oss << ", pCaptureStatus=NULL"; + else { + oss << ", pCaptureStatus="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamIsCapturing.pCaptureStatus__val); + } + oss << ")"; + break; + case HIP_API_ID_hipStreamQuery: + oss << "hipStreamQuery("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamQuery.stream); + oss << ")"; + break; + case HIP_API_ID_hipStreamSynchronize: + oss << "hipStreamSynchronize("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamSynchronize.stream); + oss << ")"; + break; + case HIP_API_ID_hipStreamUpdateCaptureDependencies: + oss << "hipStreamUpdateCaptureDependencies("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamUpdateCaptureDependencies.stream); + if (data->args.hipStreamUpdateCaptureDependencies.dependencies == NULL) + oss << ", dependencies=NULL"; + else { + oss << ", dependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamUpdateCaptureDependencies.dependencies__val); + } + oss << ", numDependencies="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamUpdateCaptureDependencies.numDependencies); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamUpdateCaptureDependencies.flags); + oss << ")"; + break; + case HIP_API_ID_hipStreamWaitEvent: + oss << "hipStreamWaitEvent("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitEvent.stream); + oss << ", event="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitEvent.event); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitEvent.flags); + oss << ")"; + break; + case HIP_API_ID_hipStreamWaitValue32: + oss << "hipStreamWaitValue32("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitValue32.stream); + oss << ", ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitValue32.ptr); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitValue32.value); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitValue32.flags); + oss << ", mask="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitValue32.mask); + oss << ")"; + break; + case HIP_API_ID_hipStreamWaitValue64: + oss << "hipStreamWaitValue64("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitValue64.stream); + oss << ", ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitValue64.ptr); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitValue64.value); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitValue64.flags); + oss << ", mask="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWaitValue64.mask); + oss << ")"; + break; + case HIP_API_ID_hipStreamWriteValue32: + oss << "hipStreamWriteValue32("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWriteValue32.stream); + oss << ", ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWriteValue32.ptr); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWriteValue32.value); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWriteValue32.flags); + oss << ")"; + break; + case HIP_API_ID_hipStreamWriteValue64: + oss << "hipStreamWriteValue64("; + oss << "stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWriteValue64.stream); + oss << ", ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWriteValue64.ptr); + oss << ", value="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWriteValue64.value); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipStreamWriteValue64.flags); + oss << ")"; + break; + case HIP_API_ID_hipTexRefGetAddress: + oss << "hipTexRefGetAddress("; + if (data->args.hipTexRefGetAddress.dev_ptr == NULL) + oss << "dev_ptr=NULL"; + else { + oss << "dev_ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetAddress.dev_ptr__val); + } + if (data->args.hipTexRefGetAddress.texRef == NULL) + oss << ", texRef=NULL"; + else { + oss << ", texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetAddress.texRef__val); + } + oss << ")"; + break; + case HIP_API_ID_hipTexRefGetArray: + oss << "hipTexRefGetArray("; + if (data->args.hipTexRefGetArray.pArray == NULL) + oss << "pArray=NULL"; + else { + oss << "pArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetArray.pArray__val); + } + if (data->args.hipTexRefGetArray.texRef == NULL) + oss << ", texRef=NULL"; + else { + oss << ", texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetArray.texRef__val); + } + oss << ")"; + break; + case HIP_API_ID_hipTexRefGetBorderColor: + oss << "hipTexRefGetBorderColor("; + if (data->args.hipTexRefGetBorderColor.pBorderColor == NULL) + oss << "pBorderColor=NULL"; + else { + oss << "pBorderColor="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetBorderColor.pBorderColor__val); + } + if (data->args.hipTexRefGetBorderColor.texRef == NULL) + oss << ", texRef=NULL"; + else { + oss << ", texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetBorderColor.texRef__val); + } + oss << ")"; + break; + case HIP_API_ID_hipTexRefGetFlags: + oss << "hipTexRefGetFlags("; + if (data->args.hipTexRefGetFlags.pFlags == NULL) + oss << "pFlags=NULL"; + else { + oss << "pFlags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetFlags.pFlags__val); + } + if (data->args.hipTexRefGetFlags.texRef == NULL) + oss << ", texRef=NULL"; + else { + oss << ", texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetFlags.texRef__val); + } + oss << ")"; + break; + case HIP_API_ID_hipTexRefGetFormat: + oss << "hipTexRefGetFormat("; + if (data->args.hipTexRefGetFormat.pFormat == NULL) + oss << "pFormat=NULL"; + else { + oss << "pFormat="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetFormat.pFormat__val); + } + if (data->args.hipTexRefGetFormat.pNumChannels == NULL) + oss << ", pNumChannels=NULL"; + else { + oss << ", pNumChannels="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetFormat.pNumChannels__val); + } + if (data->args.hipTexRefGetFormat.texRef == NULL) + oss << ", texRef=NULL"; + else { + oss << ", texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetFormat.texRef__val); + } + oss << ")"; + break; + case HIP_API_ID_hipTexRefGetMaxAnisotropy: + oss << "hipTexRefGetMaxAnisotropy("; + if (data->args.hipTexRefGetMaxAnisotropy.pmaxAnsio == NULL) + oss << "pmaxAnsio=NULL"; + else { + oss << "pmaxAnsio="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetMaxAnisotropy.pmaxAnsio__val); + } + if (data->args.hipTexRefGetMaxAnisotropy.texRef == NULL) + oss << ", texRef=NULL"; + else { + oss << ", texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetMaxAnisotropy.texRef__val); + } + oss << ")"; + break; + case HIP_API_ID_hipTexRefGetMipMappedArray: + oss << "hipTexRefGetMipMappedArray("; + if (data->args.hipTexRefGetMipMappedArray.pArray == NULL) + oss << "pArray=NULL"; + else { + oss << "pArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetMipMappedArray.pArray__val); + } + if (data->args.hipTexRefGetMipMappedArray.texRef == NULL) + oss << ", texRef=NULL"; + else { + oss << ", texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetMipMappedArray.texRef__val); + } + oss << ")"; + break; + case HIP_API_ID_hipTexRefGetMipmapLevelBias: + oss << "hipTexRefGetMipmapLevelBias("; + if (data->args.hipTexRefGetMipmapLevelBias.pbias == NULL) + oss << "pbias=NULL"; + else { + oss << "pbias="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetMipmapLevelBias.pbias__val); + } + if (data->args.hipTexRefGetMipmapLevelBias.texRef == NULL) + oss << ", texRef=NULL"; + else { + oss << ", texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetMipmapLevelBias.texRef__val); + } + oss << ")"; + break; + case HIP_API_ID_hipTexRefGetMipmapLevelClamp: + oss << "hipTexRefGetMipmapLevelClamp("; + if (data->args.hipTexRefGetMipmapLevelClamp.pminMipmapLevelClamp == NULL) + oss << "pminMipmapLevelClamp=NULL"; + else { + oss << "pminMipmapLevelClamp="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipTexRefGetMipmapLevelClamp.pminMipmapLevelClamp__val); + } + if (data->args.hipTexRefGetMipmapLevelClamp.pmaxMipmapLevelClamp == NULL) + oss << ", pmaxMipmapLevelClamp=NULL"; + else { + oss << ", pmaxMipmapLevelClamp="; + roctracer::hip_support::detail::operator<<( + oss, + data->args.hipTexRefGetMipmapLevelClamp.pmaxMipmapLevelClamp__val); + } + if (data->args.hipTexRefGetMipmapLevelClamp.texRef == NULL) + oss << ", texRef=NULL"; + else { + oss << ", texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefGetMipmapLevelClamp.texRef__val); + } + oss << ")"; + break; + case HIP_API_ID_hipTexRefSetAddress: + oss << "hipTexRefSetAddress("; + if (data->args.hipTexRefSetAddress.ByteOffset == NULL) + oss << "ByteOffset=NULL"; + else { + oss << "ByteOffset="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetAddress.ByteOffset__val); + } + if (data->args.hipTexRefSetAddress.texRef == NULL) + oss << ", texRef=NULL"; + else { + oss << ", texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetAddress.texRef__val); + } + oss << ", dptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetAddress.dptr); + oss << ", bytes="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetAddress.bytes); + oss << ")"; + break; + case HIP_API_ID_hipTexRefSetAddress2D: + oss << "hipTexRefSetAddress2D("; + if (data->args.hipTexRefSetAddress2D.texRef == NULL) + oss << "texRef=NULL"; + else { + oss << "texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetAddress2D.texRef__val); + } + if (data->args.hipTexRefSetAddress2D.desc == NULL) + oss << ", desc=NULL"; + else { + oss << ", desc="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetAddress2D.desc__val); + } + oss << ", dptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetAddress2D.dptr); + oss << ", Pitch="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetAddress2D.Pitch); + oss << ")"; + break; + case HIP_API_ID_hipTexRefSetArray: + oss << "hipTexRefSetArray("; + if (data->args.hipTexRefSetArray.tex == NULL) + oss << "tex=NULL"; + else { + oss << "tex="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetArray.tex__val); + } + oss << ", array="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetArray.array); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetArray.flags); + oss << ")"; + break; + case HIP_API_ID_hipTexRefSetBorderColor: + oss << "hipTexRefSetBorderColor("; + if (data->args.hipTexRefSetBorderColor.texRef == NULL) + oss << "texRef=NULL"; + else { + oss << "texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetBorderColor.texRef__val); + } + if (data->args.hipTexRefSetBorderColor.pBorderColor == NULL) + oss << ", pBorderColor=NULL"; + else { + oss << ", pBorderColor="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetBorderColor.pBorderColor__val); + } + oss << ")"; + break; + case HIP_API_ID_hipTexRefSetFlags: + oss << "hipTexRefSetFlags("; + if (data->args.hipTexRefSetFlags.texRef == NULL) + oss << "texRef=NULL"; + else { + oss << "texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetFlags.texRef__val); + } + oss << ", Flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetFlags.Flags); + oss << ")"; + break; + case HIP_API_ID_hipTexRefSetFormat: + oss << "hipTexRefSetFormat("; + if (data->args.hipTexRefSetFormat.texRef == NULL) + oss << "texRef=NULL"; + else { + oss << "texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetFormat.texRef__val); + } + oss << ", fmt="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetFormat.fmt); + oss << ", NumPackedComponents="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetFormat.NumPackedComponents); + oss << ")"; + break; + case HIP_API_ID_hipTexRefSetMaxAnisotropy: + oss << "hipTexRefSetMaxAnisotropy("; + if (data->args.hipTexRefSetMaxAnisotropy.texRef == NULL) + oss << "texRef=NULL"; + else { + oss << "texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetMaxAnisotropy.texRef__val); + } + oss << ", maxAniso="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetMaxAnisotropy.maxAniso); + oss << ")"; + break; + case HIP_API_ID_hipTexRefSetMipmapLevelBias: + oss << "hipTexRefSetMipmapLevelBias("; + if (data->args.hipTexRefSetMipmapLevelBias.texRef == NULL) + oss << "texRef=NULL"; + else { + oss << "texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetMipmapLevelBias.texRef__val); + } + oss << ", bias="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetMipmapLevelBias.bias); + oss << ")"; + break; + case HIP_API_ID_hipTexRefSetMipmapLevelClamp: + oss << "hipTexRefSetMipmapLevelClamp("; + if (data->args.hipTexRefSetMipmapLevelClamp.texRef == NULL) + oss << "texRef=NULL"; + else { + oss << "texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetMipmapLevelClamp.texRef__val); + } + oss << ", minMipMapLevelClamp="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetMipmapLevelClamp.minMipMapLevelClamp); + oss << ", maxMipMapLevelClamp="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetMipmapLevelClamp.maxMipMapLevelClamp); + oss << ")"; + break; + case HIP_API_ID_hipTexRefSetMipmappedArray: + oss << "hipTexRefSetMipmappedArray("; + if (data->args.hipTexRefSetMipmappedArray.texRef == NULL) + oss << "texRef=NULL"; + else { + oss << "texRef="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetMipmappedArray.texRef__val); + } + if (data->args.hipTexRefSetMipmappedArray.mipmappedArray == NULL) + oss << ", mipmappedArray=NULL"; + else { + oss << ", mipmappedArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetMipmappedArray.mipmappedArray__val); + } + oss << ", Flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipTexRefSetMipmappedArray.Flags); + oss << ")"; + break; + case HIP_API_ID_hipThreadExchangeStreamCaptureMode: + oss << "hipThreadExchangeStreamCaptureMode("; + if (data->args.hipThreadExchangeStreamCaptureMode.mode == NULL) + oss << "mode=NULL"; + else { + oss << "mode="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipThreadExchangeStreamCaptureMode.mode__val); + } + oss << ")"; + break; + case HIP_API_ID_hipUserObjectCreate: + oss << "hipUserObjectCreate("; + if (data->args.hipUserObjectCreate.object_out == NULL) + oss << "object_out=NULL"; + else { + oss << "object_out="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipUserObjectCreate.object_out__val); + } + oss << ", ptr="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipUserObjectCreate.ptr); + oss << ", destroy="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipUserObjectCreate.destroy); + oss << ", initialRefcount="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipUserObjectCreate.initialRefcount); + oss << ", flags="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipUserObjectCreate.flags); + oss << ")"; + break; + case HIP_API_ID_hipUserObjectRelease: + oss << "hipUserObjectRelease("; + oss << "object="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipUserObjectRelease.object); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipUserObjectRelease.count); + oss << ")"; + break; + case HIP_API_ID_hipUserObjectRetain: + oss << "hipUserObjectRetain("; + oss << "object="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipUserObjectRetain.object); + oss << ", count="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipUserObjectRetain.count); + oss << ")"; + break; + case HIP_API_ID_hipWaitExternalSemaphoresAsync: + oss << "hipWaitExternalSemaphoresAsync("; + if (data->args.hipWaitExternalSemaphoresAsync.extSemArray == NULL) + oss << "extSemArray=NULL"; + else { + oss << "extSemArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipWaitExternalSemaphoresAsync.extSemArray__val); + } + if (data->args.hipWaitExternalSemaphoresAsync.paramsArray == NULL) + oss << ", paramsArray=NULL"; + else { + oss << ", paramsArray="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipWaitExternalSemaphoresAsync.paramsArray__val); + } + oss << ", numExtSems="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipWaitExternalSemaphoresAsync.numExtSems); + oss << ", stream="; + roctracer::hip_support::detail::operator<<( + oss, data->args.hipWaitExternalSemaphoresAsync.stream); + oss << ")"; + break; + default: + oss << "unknown"; + }; + return strdup(oss.str().c_str()); +} +#endif // HIP_PROF_HIP_API_STRING +#endif // _HIP_PROF_STR_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_runtime_prof.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_runtime_prof.h new file mode 100644 index 000000000..5e7899099 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hip_runtime_prof.h @@ -0,0 +1,78 @@ +/* +Copyright (c) 2019 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_RUNTIME_PROF_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_RUNTIME_PROF_H + +// HIP ROCclr Op IDs enumeration +enum HipVdiOpId { + kHipVdiOpIdDispatch = 0, + kHipVdiOpIdCopy = 1, + kHipVdiOpIdBarrier = 2, + kHipVdiOpIdNumber = 3 +}; + +// Types of ROCclr commands +enum HipVdiCommandKind { + kHipVdiCommandKernel = 0x11F0, + kHipVdiCommandTask = 0x11F1, + kHipVdiMemcpyDeviceToHost = 0x11F3, + kHipHipVdiMemcpyHostToDevice = 0x11F4, + kHipVdiMemcpyDeviceToDevice = 0x11F5, + kHipVidMemcpyDeviceToHostRect = 0x1201, + kHipVdiMemcpyHostToDeviceRect = 0x1202, + kHipVdiMemcpyDeviceToDeviceRect = 0x1203, + kHipVdiFillMemory = 0x1207, +}; + +/** + * @brief Initializes activity callback + * + * @param [input] id_callback Event ID callback function + * @param [input] op_callback Event operation callback function + * @param [input] arg Arguments passed into callback + * + * @returns None + */ +void hipInitActivityCallback(void *id_callback, void *op_callback, void *arg); + +/** + * @brief Enables activity callback + * + * @param [input] op Operation, which will trigger a callback (@see + * HipVdiOpId) + * @param [input] enable Enable state for the callback + * + * @returns True if successful + */ +bool hipEnableActivityCallback(uint32_t op, bool enable); + +/** + * @brief Returns the description string for the operation kind + * + * @param [input] id Command kind id (@see HipVdiCommandKind) + * + * @returns A pointer to a const string with the command description + */ +const char *hipGetCmdName(uint32_t id); + +#endif // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_RUNTIME_PROF_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/host_defines.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/host_defines.h new file mode 100644 index 000000000..5b6ec68f5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/host_defines.h @@ -0,0 +1,191 @@ +/* +Copyright (c) 2015 - 2022 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +/** + * @file amd_detail/host_defines.h + * @brief TODO-doc + */ + +#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HOST_DEFINES_H +#define HIP_INCLUDE_HIP_AMD_DETAIL_HOST_DEFINES_H + +// Add guard to Generic Grid Launch method +#ifndef GENERIC_GRID_LAUNCH +#define GENERIC_GRID_LAUNCH 1 +#endif + +#if defined(__clang__) && defined(__HIP__) + +namespace __hip_internal { +typedef unsigned char uint8_t; +typedef unsigned short uint16_t; +typedef unsigned int uint32_t; +typedef unsigned long long uint64_t; +typedef signed char int8_t; +typedef signed short int16_t; +typedef signed int int32_t; +typedef signed long long int64_t; + +template struct integral_constant { + static constexpr const _Tp value = __v; + typedef _Tp value_type; + typedef integral_constant type; + constexpr operator value_type() const { return value; } + constexpr value_type operator()() const { return value; } +}; +template +constexpr const _Tp integral_constant<_Tp, __v>::value; + +typedef integral_constant true_type; +typedef integral_constant false_type; + +template using bool_constant = integral_constant; +typedef bool_constant true_type; +typedef bool_constant false_type; + +template struct enable_if {}; +template struct enable_if { + typedef __T type; +}; + +template struct true_or_false_type : public false_type {}; +template <> struct true_or_false_type : public true_type {}; + +template struct is_integral : public false_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; +template <> struct is_integral : public true_type {}; + +template struct is_arithmetic : public false_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; +template <> struct is_arithmetic : public true_type {}; + +template struct is_floating_point : public false_type {}; +template <> struct is_floating_point : public true_type {}; +template <> struct is_floating_point : public true_type {}; +template <> struct is_floating_point : public true_type {}; + +template struct is_same : public false_type {}; +template struct is_same<__T, __T> : public true_type {}; + +template ::value> +struct is_signed : public false_type {}; +template +struct is_signed<_Tp, true> : public true_or_false_type<_Tp(-1) < _Tp(0)> {}; + +template struct char_traits; +template > +class basic_istream; +template > +class basic_ostream; +typedef basic_istream istream; +typedef basic_ostream ostream; + +template +struct is_standard_layout + : public integral_constant {}; + +template +struct is_trivial : public integral_constant {}; + +template struct conditional { + using type = T; +}; +template struct conditional { + using type = F; +}; +} // namespace __hip_internal +typedef __hip_internal::uint8_t __hip_uint8_t; +typedef __hip_internal::uint16_t __hip_uint16_t; +typedef __hip_internal::uint32_t __hip_uint32_t; +typedef __hip_internal::uint64_t __hip_uint64_t; +typedef __hip_internal::int8_t __hip_int8_t; +typedef __hip_internal::int16_t __hip_int16_t; +typedef __hip_internal::int32_t __hip_int32_t; +typedef __hip_internal::int64_t __hip_int64_t; + +#if !__CLANG_HIP_RUNTIME_WRAPPER_INCLUDED__ +#define __host__ __attribute__((host)) +#define __device__ __attribute__((device)) +#define __global__ __attribute__((global)) +#define __shared__ __attribute__((shared)) +#define __constant__ __attribute__((constant)) +#endif // !__CLANG_HIP_RUNTIME_WRAPPER_INCLUDED__ + +#if !defined(__has_feature) || !__has_feature(cuda_noinline_keyword) +#define __noinline__ __attribute__((noinline)) +#endif + +#define __forceinline__ inline __attribute__((always_inline)) + +#if __HIP_NO_IMAGE_SUPPORT +#define __hip_img_chk__ \ + __attribute__(( \ + unavailable("The image/texture API not supported on the device"))) +#else +#define __hip_img_chk__ +#endif + +#else + +// Non-HCC compiler +/** + * Function and kernel markers + */ +#define __host__ +#define __device__ + +#define __global__ + +#define __noinline__ +#define __forceinline__ inline + +#define __shared__ +#define __constant__ + +#define __hip_img_chk__ +#endif + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hsa_helpers.hpp b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hsa_helpers.hpp new file mode 100644 index 000000000..52d385d5c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/hsa_helpers.hpp @@ -0,0 +1,110 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#pragma once + +#include + +#include +#include +#include + +namespace hip_impl { +inline void *address(hsa_executable_symbol_t x) { + void *r = nullptr; + hsa_executable_symbol_get_info(x, HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ADDRESS, + &r); + + return r; +} + +inline hsa_agent_t agent(hsa_executable_symbol_t x) { + hsa_agent_t r = {}; + hsa_executable_symbol_get_info(x, HSA_EXECUTABLE_SYMBOL_INFO_AGENT, &r); + + return r; +} + +inline std::uint32_t group_size(hsa_executable_symbol_t x) { + std::uint32_t r = 0u; + hsa_executable_symbol_get_info( + x, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE, &r); + + return r; +} + +inline hsa_isa_t isa(hsa_agent_t x) { + hsa_isa_t r = {}; + hsa_agent_iterate_isas( + x, + [](hsa_isa_t i, void *o) { + *static_cast(o) = i; // Pick the first. + + return HSA_STATUS_INFO_BREAK; + }, + &r); + + return r; +} + +inline std::uint64_t kernel_object(hsa_executable_symbol_t x) { + std::uint64_t r = 0u; + hsa_executable_symbol_get_info(x, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT, + &r); + + return r; +} + +inline std::string name(hsa_executable_symbol_t x) { + std::uint32_t sz = 0u; + hsa_executable_symbol_get_info(x, HSA_EXECUTABLE_SYMBOL_INFO_NAME_LENGTH, + &sz); + + std::string r(sz, '\0'); + hsa_executable_symbol_get_info(x, HSA_EXECUTABLE_SYMBOL_INFO_NAME, + &r.front()); + + return r; +} + +inline std::uint32_t private_size(hsa_executable_symbol_t x) { + std::uint32_t r = 0u; + hsa_executable_symbol_get_info( + x, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE, &r); + + return r; +} + +inline std::uint32_t size(hsa_executable_symbol_t x) { + std::uint32_t r = 0; + hsa_executable_symbol_get_info(x, HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_SIZE, + &r); + + return r; +} + +inline hsa_symbol_kind_t type(hsa_executable_symbol_t x) { + hsa_symbol_kind_t r = {}; + hsa_executable_symbol_get_info(x, HSA_EXECUTABLE_SYMBOL_INFO_TYPE, &r); + + return r; +} +} // namespace hip_impl diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/macro_based_grid_launch.hpp b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/macro_based_grid_launch.hpp new file mode 100644 index 000000000..58f38e204 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/macro_based_grid_launch.hpp @@ -0,0 +1,850 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#include "concepts.hpp" +#include "helpers.hpp" + +#include "hc.hpp" +#include "hip/hip_ext.h" +#include "hip_runtime.h" + +#include +#include +#include +#include +#include + +namespace hip_impl { +namespace { +struct New_grid_launch_tag {}; +struct Old_grid_launch_tag {}; + +template class RAII_guard { + D dtor_; + +public: + RAII_guard() = default; + + RAII_guard(const C &ctor, D dtor) : dtor_{std::move(dtor)} { ctor(); } + + RAII_guard(const RAII_guard &) = default; + RAII_guard(RAII_guard &&) = default; + + RAII_guard &operator=(const RAII_guard &) = default; + RAII_guard &operator=(RAII_guard &&) = default; + + ~RAII_guard() { dtor_(); } +}; + +template +RAII_guard make_RAII_guard(const C &ctor, D dtor) { + return RAII_guard{ctor, std::move(dtor)}; +} + +template +using is_new_grid_launch_t = + typename std::conditional{}, New_grid_launch_tag, + Old_grid_launch_tag>::type; +} // namespace + +// TODO: - dispatch rank should be derived from the domain dimensions passed +// in, and not always assumed to be 3; + +template + requires(Domain == {Ts...}) +inline void grid_launch_hip_impl_(New_grid_launch_tag, dim3 num_blocks, + dim3 dim_blocks, int group_mem_bytes, + const hc::accelerator_view &acc_v, K k) { + const auto d = + hc::extent<3>{num_blocks.z * dim_blocks.z, num_blocks.y * dim_blocks.y, + num_blocks.x * dim_blocks.x} + .tile_with_dynamic(dim_blocks.z, dim_blocks.y, dim_blocks.x, + group_mem_bytes); + + try { + hc::parallel_for_each(acc_v, d, k); + } catch (std::exception &ex) { + std::cerr << "Failed in " << __func__ << ", with exception: " << ex.what() + << std::endl; + hip_throw(ex); + } +} + +// TODO: these are workarounds, they should be removed. + +hc::accelerator_view lock_stream_hip_(hipStream_t &, void *&); +void print_prelaunch_trace_(const char *, dim3, dim3, int, hipStream_t); +void unlock_stream_hip_(hipStream_t, void *, const char *, + hc::accelerator_view *); + +template + requires(Domain == {Ts...}) +inline void grid_launch_hip_impl_(New_grid_launch_tag, dim3 num_blocks, + dim3 dim_blocks, int group_mem_bytes, + hipStream_t stream, const char *kernel_name, + K k) { + void *lck_stream = nullptr; + auto acc_v = lock_stream_hip_(stream, lck_stream); + auto stream_guard = make_RAII_guard( + std::bind(print_prelaunch_trace_, kernel_name, num_blocks, dim_blocks, + group_mem_bytes, stream), + std::bind(unlock_stream_hip_, stream, lck_stream, kernel_name, &acc_v)); + + try { + grid_launch_hip_impl_(New_grid_launch_tag{}, std::move(num_blocks), + std::move(dim_blocks), group_mem_bytes, acc_v, + std::move(k)); + } catch (std::exception &ex) { + std::cerr << "Failed in " << __func__ << ", with exception: " << ex.what() + << std::endl; + hip_throw(ex); + } +} + +template + requires(Domain == {hipLaunchParm, Ts...}) +inline void grid_launch_hip_impl_(Old_grid_launch_tag, dim3 num_blocks, + dim3 dim_blocks, int group_mem_bytes, + hipStream_t stream, K k) { + grid_launch_hip_impl_(New_grid_launch_tag{}, std::move(num_blocks), + std::move(dim_blocks), group_mem_bytes, + std::move(stream), std::move(k)); +} + +template + requires(Domain == {hipLaunchParm, Ts...}) +inline void grid_launch_hip_impl_(Old_grid_launch_tag, dim3 num_blocks, + dim3 dim_blocks, int group_mem_bytes, + hipStream_t stream, const char *kernel_name, + K k) { + grid_launch_hip_impl_(New_grid_launch_tag{}, std::move(num_blocks), + std::move(dim_blocks), group_mem_bytes, + std::move(stream), kernel_name, std::move(k)); +} + +template + requires(Domain == {Ts...}) +inline std::enable_if_t::value> +grid_launch_hip_(dim3 num_blocks, dim3 dim_blocks, int group_mem_bytes, + hipStream_t stream, const char *kernel_name, K k) { + grid_launch_hip_impl_(is_new_grid_launch_t{}, std::move(num_blocks), + std::move(dim_blocks), group_mem_bytes, + std::move(stream), kernel_name, std::move(k)); +} + +template + requires(Domain == {Ts...}) +inline std::enable_if_t::value> +grid_launch_hip_(dim3 num_blocks, dim3 dim_blocks, int group_mem_bytes, + hipStream_t stream, K k) { + grid_launch_hip_impl_(is_new_grid_launch_t{}, std::move(num_blocks), + std::move(dim_blocks), group_mem_bytes, + std::move(stream), std::move(k)); +} + +// TODO: these are temporary and purposefully noisy and disruptive. +#define make_kernel_name_hip(k, n) \ + HIP_kernel_functor_name_begin##_##k##_##HIP_kernel_functor_name_end##_##n + +#define make_kernel_functor_hip_30(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, \ + p14, p15, p16, p17, p18, p19, p20, p21, \ + p22, p23, p24, p25, p26, p27) \ + struct make_kernel_name_hip(function_name, 28) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + std::decay_t _p17_; \ + std::decay_t _p18_; \ + std::decay_t _p19_; \ + std::decay_t _p20_; \ + std::decay_t _p21_; \ + std::decay_t _p22_; \ + std::decay_t _p23_; \ + std::decay_t _p24_; \ + std::decay_t _p25_; \ + std::decay_t _p26_; \ + std::decay_t _p27_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, \ + _p18_, _p19_, _p20_, _p21_, _p22_, _p23_, _p24_, _p25_, \ + _p26_, _p27_); \ + } \ + } +#define make_kernel_functor_hip_29(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, \ + p14, p15, p16, p17, p18, p19, p20, p21, \ + p22, p23, p24, p25, p26) \ + struct make_kernel_name_hip(function_name, 27) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + std::decay_t _p17_; \ + std::decay_t _p18_; \ + std::decay_t _p19_; \ + std::decay_t _p20_; \ + std::decay_t _p21_; \ + std::decay_t _p22_; \ + std::decay_t _p23_; \ + std::decay_t _p24_; \ + std::decay_t _p25_; \ + std::decay_t _p26_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, \ + _p18_, _p19_, _p20_, _p21_, _p22_, _p23_, _p24_, _p25_, \ + _p26_); \ + } \ + } +#define make_kernel_functor_hip_28( \ + function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, \ + p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24, p25) \ + struct make_kernel_name_hip(function_name, 26) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + std::decay_t _p17_; \ + std::decay_t _p18_; \ + std::decay_t _p19_; \ + std::decay_t _p20_; \ + std::decay_t _p21_; \ + std::decay_t _p22_; \ + std::decay_t _p23_; \ + std::decay_t _p24_; \ + std::decay_t _p25_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, \ + _p18_, _p19_, _p20_, _p21_, _p22_, _p23_, _p24_, _p25_); \ + } \ + } +#define make_kernel_functor_hip_27( \ + function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, \ + p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24) \ + struct make_kernel_name_hip(function_name, 25) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + std::decay_t _p17_; \ + std::decay_t _p18_; \ + std::decay_t _p19_; \ + std::decay_t _p20_; \ + std::decay_t _p21_; \ + std::decay_t _p22_; \ + std::decay_t _p23_; \ + std::decay_t _p24_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, \ + _p18_, _p19_, _p20_, _p21_, _p22_, _p23_, _p24_); \ + } \ + } +#define make_kernel_functor_hip_26( \ + function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, \ + p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23) \ + struct make_kernel_name_hip(function_name, 24) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + std::decay_t _p17_; \ + std::decay_t _p18_; \ + std::decay_t _p19_; \ + std::decay_t _p20_; \ + std::decay_t _p21_; \ + std::decay_t _p22_; \ + std::decay_t _p23_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, \ + _p18_, _p19_, _p20_, _p21_, _p22_, _p23_); \ + } \ + } +#define make_kernel_functor_hip_25( \ + function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, \ + p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22) \ + struct make_kernel_name_hip(function_name, 23) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + std::decay_t _p17_; \ + std::decay_t _p18_; \ + std::decay_t _p19_; \ + std::decay_t _p20_; \ + std::decay_t _p21_; \ + std::decay_t _p22_; \ + __attribute__((used, flatten)) void \ + operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, \ + _p18_, _p19_, _p20_, _p21_, _p22_); \ + } \ + } +#define make_kernel_functor_hip_24(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, \ + p14, p15, p16, p17, p18, p19, p20, p21) \ + struct make_kernel_name_hip(function_name, 22) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + std::decay_t _p17_; \ + std::decay_t _p18_; \ + std::decay_t _p19_; \ + std::decay_t _p20_; \ + std::decay_t _p21_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, \ + _p18_, _p19_, _p20_, _p21_); \ + } \ + } +#define make_kernel_functor_hip_23(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, \ + p14, p15, p16, p17, p18, p19, p20) \ + struct make_kernel_name_hip(function_name, 21) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + std::decay_t _p17_; \ + std::decay_t _p18_; \ + std::decay_t _p19_; \ + std::decay_t _p20_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, \ + _p18_, _p19_, _p20_); \ + } \ + } +#define make_kernel_functor_hip_22(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, \ + p14, p15, p16, p17, p18, p19) \ + struct make_kernel_name_hip(function_name, 20) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + std::decay_t _p17_; \ + std::decay_t _p18_; \ + std::decay_t _p19_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, \ + _p18_, _p19_); \ + } \ + } +#define make_kernel_functor_hip_21(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, \ + p14, p15, p16, p17, p18) \ + struct make_kernel_name_hip(function_name, 19) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + std::decay_t _p17_; \ + std::decay_t _p18_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, \ + _p18_); \ + } \ + } +#define make_kernel_functor_hip_20(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, \ + p14, p15, p16, p17) \ + struct make_kernel_name_hip(function_name, 18) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + std::decay_t _p17_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_, _p17_); \ + } \ + } +#define make_kernel_functor_hip_19(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, \ + p14, p15, p16) \ + struct make_kernel_name_hip(function_name, 17) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + std::decay_t _p16_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_, _p16_); \ + } \ + } +#define make_kernel_functor_hip_18(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, \ + p14, p15) \ + struct make_kernel_name_hip(function_name, 16) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + std::decay_t _p15_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_, _p15_); \ + } \ + } +#define make_kernel_functor_hip_17(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, \ + p14) \ + struct make_kernel_name_hip(function_name, 15) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + std::decay_t _p14_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_, _p14_); \ + } \ + } +#define make_kernel_functor_hip_16(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12, p13) \ + struct make_kernel_name_hip(function_name, 14) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + std::decay_t _p13_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_, _p13_); \ + } \ + } +#define make_kernel_functor_hip_15(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11, p12) \ + struct make_kernel_name_hip(function_name, 13) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + std::decay_t _p12_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_, _p12_); \ + } \ + } +#define make_kernel_functor_hip_14(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10, p11) \ + struct make_kernel_name_hip(function_name, 12) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + std::decay_t _p11_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_, _p11_); \ + } \ + } +#define make_kernel_functor_hip_13(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9, p10) \ + struct make_kernel_name_hip(function_name, 11) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + std::decay_t _p10_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, \ + _p10_); \ + } \ + } +#define make_kernel_functor_hip_12(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8, p9) \ + struct make_kernel_name_hip(function_name, 10) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + std::decay_t _p9_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_); \ + } \ + } +#define make_kernel_functor_hip_11(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7, p8) \ + struct make_kernel_name_hip(function_name, 9) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + std::decay_t _p8_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_); \ + } \ + } +#define make_kernel_functor_hip_10(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6, p7) \ + struct make_kernel_name_hip(function_name, 8) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + std::decay_t _p7_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_); \ + } \ + } +#define make_kernel_functor_hip_9(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5, p6) \ + struct make_kernel_name_hip(function_name, 7) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + std::decay_t _p6_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_); \ + } \ + } +#define make_kernel_functor_hip_8(function_name, kernel_name, p0, p1, p2, p3, \ + p4, p5) \ + struct make_kernel_name_hip(function_name, 6) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + std::decay_t _p5_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_); \ + } \ + } +#define make_kernel_functor_hip_7(function_name, kernel_name, p0, p1, p2, p3, \ + p4) \ + struct make_kernel_name_hip(function_name, 5) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + std::decay_t _p4_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_); \ + } \ + } +#define make_kernel_functor_hip_6(function_name, kernel_name, p0, p1, p2, p3) \ + struct make_kernel_name_hip(function_name, 4) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + std::decay_t _p3_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_, _p3_); \ + } \ + } +#define make_kernel_functor_hip_5(function_name, kernel_name, p0, p1, p2) \ + struct make_kernel_name_hip(function_name, 3) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + std::decay_t _p2_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_, _p2_); \ + } \ + } +#define make_kernel_functor_hip_4(function_name, kernel_name, p0, p1) \ + struct make_kernel_name_hip(function_name, 2) { \ + std::decay_t _p0_; \ + std::decay_t _p1_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_, _p1_); \ + } \ + } +#define fofo(f, n) kernel_prefix_hip##f##kernel_suffix_hip##n +#define make_kernel_functor_hip_3(function_name, kernel_name, p0) \ + struct make_kernel_name_hip(function_name, 1) { \ + std::decay_t _p0_; \ + void operator()(const hc::tiled_index<3> &) const [[hc]] { \ + kernel_name(_p0_); \ + } \ + } +#define make_kernel_functor_hip_2(function_name, kernel_name) \ + struct make_kernel_name_hip(function_name, 0) { \ + void operator()(const hc::tiled_index<3> &) [[hc]] { \ + return kernel_name(hipLaunchParm{}); \ + } \ + } +#define make_kernel_functor_hip_1(...) +#define make_kernel_functor_hip_0(...) +#define make_kernel_functor_hip_(...) \ + overload_macro_hip_(make_kernel_functor_hip_, __VA_ARGS__) + +#define hipLaunchNamedKernelGGL(function_name, kernel_name, num_blocks, \ + dim_blocks, group_mem_bytes, stream, ...) \ + do { \ + make_kernel_functor_hip_(function_name, kernel_name, __VA_ARGS__) \ + hip_kernel_functor_impl_{__VA_ARGS__}; \ + hip_impl::grid_launch_hip_(num_blocks, dim_blocks, group_mem_bytes, \ + stream, #kernel_name, \ + hip_kernel_functor_impl_); \ + } while (0) + +#define hipLaunchKernelGGL(kernel_name, num_blocks, dim_blocks, \ + group_mem_bytes, stream, ...) \ + do { \ + hipLaunchNamedKernelGGL(unnamed, kernel_name, num_blocks, dim_blocks, \ + group_mem_bytes, stream, ##__VA_ARGS__); \ + } while (0) + +#define hipLaunchKernel(kernel_name, num_blocks, dim_blocks, group_mem_bytes, \ + stream, ...) \ + do { \ + hipLaunchKernelGGL(kernel_name, num_blocks, dim_blocks, group_mem_bytes, \ + stream, hipLaunchParm{}, ##__VA_ARGS__); \ + } while (0) +} // namespace hip_impl diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/math_fwd.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/math_fwd.h new file mode 100644 index 000000000..4b201d469 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/math_fwd.h @@ -0,0 +1,313 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#if !defined(__HIPCC_RTC__) +#include "amd_hip_vector_types.h" // For Native_vec_ +#include "host_defines.h" +#endif + +#if defined(__cplusplus) +extern "C" { +#endif + +// DOT FUNCTIONS +#if defined(__clang__) && defined(__HIP__) +__device__ __attribute__((const)) int +__ockl_sdot2(HIP_vector_base::Native_vec_, + HIP_vector_base::Native_vec_, int, bool); + +__device__ __attribute__((const)) unsigned int +__ockl_udot2(HIP_vector_base::Native_vec_, + HIP_vector_base::Native_vec_, unsigned int, + bool); + +__device__ __attribute__((const)) int +__ockl_sdot4(HIP_vector_base::Native_vec_, + HIP_vector_base::Native_vec_, int, bool); + +__device__ __attribute__((const)) unsigned int +__ockl_udot4(HIP_vector_base::Native_vec_, + HIP_vector_base::Native_vec_, unsigned int, + bool); + +__device__ __attribute__((const)) int __ockl_sdot8(int, int, int, bool); + +__device__ __attribute__((const)) unsigned int +__ockl_udot8(unsigned int, unsigned int, unsigned int, bool); +#endif + +#if !__CLANG_HIP_RUNTIME_WRAPPER_INCLUDED__ +// BEGIN FLOAT +__device__ __attribute__((const)) float __ocml_acos_f32(float); +__device__ __attribute__((pure)) float __ocml_acosh_f32(float); +__device__ __attribute__((const)) float __ocml_asin_f32(float); +__device__ __attribute__((pure)) float __ocml_asinh_f32(float); +__device__ __attribute__((const)) float __ocml_atan2_f32(float, float); +__device__ __attribute__((const)) float __ocml_atan_f32(float); +__device__ __attribute__((pure)) float __ocml_atanh_f32(float); +__device__ __attribute__((pure)) float __ocml_cbrt_f32(float); +__device__ __attribute__((const)) float __ocml_ceil_f32(float); +__device__ __attribute__((const)) __device__ float __ocml_copysign_f32(float, + float); +__device__ float __ocml_cos_f32(float); +__device__ float __ocml_native_cos_f32(float); +__device__ __attribute__((pure)) __device__ float __ocml_cosh_f32(float); +__device__ float __ocml_cospi_f32(float); +__device__ float __ocml_i0_f32(float); +__device__ float __ocml_i1_f32(float); +__device__ __attribute__((pure)) float __ocml_erfc_f32(float); +__device__ __attribute__((pure)) float __ocml_erfcinv_f32(float); +__device__ __attribute__((pure)) float __ocml_erfcx_f32(float); +__device__ __attribute__((pure)) float __ocml_erf_f32(float); +__device__ __attribute__((pure)) float __ocml_erfinv_f32(float); +__device__ __attribute__((pure)) float __ocml_exp10_f32(float); +__device__ __attribute__((pure)) float __ocml_native_exp10_f32(float); +__device__ __attribute__((pure)) float __ocml_exp2_f32(float); +__device__ __attribute__((pure)) float __ocml_exp_f32(float); +__device__ __attribute__((pure)) float __ocml_native_exp_f32(float); +__device__ __attribute__((pure)) float __ocml_expm1_f32(float); +__device__ __attribute__((const)) float __ocml_fabs_f32(float); +__device__ __attribute__((const)) float __ocml_fdim_f32(float, float); +__device__ __attribute__((const)) float __ocml_floor_f32(float); +__device__ __attribute__((const)) float __ocml_fma_f32(float, float, float); +__device__ __attribute__((const)) float __ocml_fmax_f32(float, float); +__device__ __attribute__((const)) float __ocml_fmin_f32(float, float); +__device__ __attribute__((const)) __device__ float __ocml_fmod_f32(float, + float); +__device__ float __ocml_frexp_f32(float, + __attribute__((address_space(5))) int *); +__device__ __attribute__((const)) float __ocml_hypot_f32(float, float); +__device__ __attribute__((const)) int __ocml_ilogb_f32(float); +__device__ __attribute__((const)) int __ocml_isfinite_f32(float); +__device__ __attribute__((const)) int __ocml_isinf_f32(float); +__device__ __attribute__((const)) int __ocml_isnan_f32(float); +__device__ float __ocml_j0_f32(float); +__device__ float __ocml_j1_f32(float); +__device__ __attribute__((const)) float __ocml_ldexp_f32(float, int); +__device__ float __ocml_lgamma_f32(float); +__device__ __attribute__((pure)) float __ocml_log10_f32(float); +__device__ __attribute__((pure)) float __ocml_native_log10_f32(float); +__device__ __attribute__((pure)) float __ocml_log1p_f32(float); +__device__ __attribute__((pure)) float __ocml_log2_f32(float); +__device__ __attribute__((pure)) float __ocml_native_log2_f32(float); +__device__ __attribute__((const)) float __ocml_logb_f32(float); +__device__ __attribute__((pure)) float __ocml_log_f32(float); +__device__ __attribute__((pure)) float __ocml_native_log_f32(float); +__device__ float __ocml_modf_f32(float, + __attribute__((address_space(5))) float *); +__device__ __attribute__((const)) float __ocml_nearbyint_f32(float); +__device__ __attribute__((const)) float __ocml_nextafter_f32(float, float); +__device__ __attribute__((const)) float __ocml_len3_f32(float, float, float); +__device__ __attribute__((const)) float __ocml_len4_f32(float, float, float, + float); +__device__ __attribute__((pure)) float __ocml_ncdf_f32(float); +__device__ __attribute__((pure)) float __ocml_ncdfinv_f32(float); +__device__ __attribute__((pure)) float __ocml_pow_f32(float, float); +__device__ __attribute__((pure)) float __ocml_pown_f32(float, int); +__device__ __attribute__((pure)) float __ocml_rcbrt_f32(float); +__device__ __attribute__((const)) float __ocml_remainder_f32(float, float); +__device__ float __ocml_remquo_f32(float, float, + __attribute__((address_space(5))) int *); +__device__ __attribute__((const)) float __ocml_rhypot_f32(float, float); +__device__ __attribute__((const)) float __ocml_rint_f32(float); +__device__ __attribute__((const)) float __ocml_rlen3_f32(float, float, float); +__device__ __attribute__((const)) float __ocml_rlen4_f32(float, float, float, + float); +__device__ __attribute__((const)) float __ocml_round_f32(float); +__device__ __attribute__((pure)) float __ocml_rsqrt_f32(float); +__device__ __attribute__((const)) float __ocml_scalb_f32(float, float); +__device__ __attribute__((const)) float __ocml_scalbn_f32(float, int); +__device__ __attribute__((const)) int __ocml_signbit_f32(float); +__device__ float __ocml_sincos_f32(float, + __attribute__((address_space(5))) float *); +__device__ float __ocml_sincospi_f32(float, + __attribute__((address_space(5))) float *); +__device__ float __ocml_sin_f32(float); +__device__ float __ocml_native_sin_f32(float); +__device__ __attribute__((pure)) float __ocml_sinh_f32(float); +__device__ float __ocml_sinpi_f32(float); +__device__ __attribute__((const)) float __ocml_sqrt_f32(float); +__device__ __attribute__((const)) float __ocml_native_sqrt_f32(float); +__device__ float __ocml_tan_f32(float); +__device__ __attribute__((pure)) float __ocml_tanh_f32(float); +__device__ float __ocml_tgamma_f32(float); +__device__ __attribute__((const)) float __ocml_trunc_f32(float); +__device__ float __ocml_y0_f32(float); +__device__ float __ocml_y1_f32(float); + +// BEGIN INTRINSICS +__device__ __attribute__((const)) float __ocml_add_rte_f32(float, float); +__device__ __attribute__((const)) float __ocml_add_rtn_f32(float, float); +__device__ __attribute__((const)) float __ocml_add_rtp_f32(float, float); +__device__ __attribute__((const)) float __ocml_add_rtz_f32(float, float); +__device__ __attribute__((const)) float __ocml_sub_rte_f32(float, float); +__device__ __attribute__((const)) float __ocml_sub_rtn_f32(float, float); +__device__ __attribute__((const)) float __ocml_sub_rtp_f32(float, float); +__device__ __attribute__((const)) float __ocml_sub_rtz_f32(float, float); +__device__ __attribute__((const)) float __ocml_mul_rte_f32(float, float); +__device__ __attribute__((const)) float __ocml_mul_rtn_f32(float, float); +__device__ __attribute__((const)) float __ocml_mul_rtp_f32(float, float); +__device__ __attribute__((const)) float __ocml_mul_rtz_f32(float, float); +__device__ __attribute__((const)) float __ocml_div_rte_f32(float, float); +__device__ __attribute__((const)) float __ocml_div_rtn_f32(float, float); +__device__ __attribute__((const)) float __ocml_div_rtp_f32(float, float); +__device__ __attribute__((const)) float __ocml_div_rtz_f32(float, float); +__device__ __attribute__((const)) float __ocml_sqrt_rte_f32(float); +__device__ __attribute__((const)) float __ocml_sqrt_rtn_f32(float); +__device__ __attribute__((const)) float __ocml_sqrt_rtp_f32(float); +__device__ __attribute__((const)) float __ocml_sqrt_rtz_f32(float); +__device__ __attribute__((const)) float __ocml_fma_rte_f32(float, float, float); +__device__ __attribute__((const)) float __ocml_fma_rtn_f32(float, float, float); +__device__ __attribute__((const)) float __ocml_fma_rtp_f32(float, float, float); +__device__ __attribute__((const)) float __ocml_fma_rtz_f32(float, float, float); +// END INTRINSICS +// END FLOAT + +// BEGIN DOUBLE +__device__ __attribute__((const)) double __ocml_acos_f64(double); +__device__ __attribute__((pure)) double __ocml_acosh_f64(double); +__device__ __attribute__((const)) double __ocml_asin_f64(double); +__device__ __attribute__((pure)) double __ocml_asinh_f64(double); +__device__ __attribute__((const)) double __ocml_atan2_f64(double, double); +__device__ __attribute__((const)) double __ocml_atan_f64(double); +__device__ __attribute__((pure)) double __ocml_atanh_f64(double); +__device__ __attribute__((pure)) double __ocml_cbrt_f64(double); +__device__ __attribute__((const)) double __ocml_ceil_f64(double); +__device__ __attribute__((const)) double __ocml_copysign_f64(double, double); +__device__ double __ocml_cos_f64(double); +__device__ __attribute__((pure)) double __ocml_cosh_f64(double); +__device__ double __ocml_cospi_f64(double); +__device__ double __ocml_i0_f64(double); +__device__ double __ocml_i1_f64(double); +__device__ __attribute__((pure)) double __ocml_erfc_f64(double); +__device__ __attribute__((pure)) double __ocml_erfcinv_f64(double); +__device__ __attribute__((pure)) double __ocml_erfcx_f64(double); +__device__ __attribute__((pure)) double __ocml_erf_f64(double); +__device__ __attribute__((pure)) double __ocml_erfinv_f64(double); +__device__ __attribute__((pure)) double __ocml_exp10_f64(double); +__device__ __attribute__((pure)) double __ocml_exp2_f64(double); +__device__ __attribute__((pure)) double __ocml_exp_f64(double); +__device__ __attribute__((pure)) double __ocml_expm1_f64(double); +__device__ __attribute__((const)) double __ocml_fabs_f64(double); +__device__ __attribute__((const)) double __ocml_fdim_f64(double, double); +__device__ __attribute__((const)) double __ocml_floor_f64(double); +__device__ __attribute__((const)) double __ocml_fma_f64(double, double, double); +__device__ __attribute__((const)) double __ocml_fmax_f64(double, double); +__device__ __attribute__((const)) double __ocml_fmin_f64(double, double); +__device__ __attribute__((const)) double __ocml_fmod_f64(double, double); +__device__ double __ocml_frexp_f64(double, + __attribute__((address_space(5))) int *); +__device__ __attribute__((const)) double __ocml_hypot_f64(double, double); +__device__ __attribute__((const)) int __ocml_ilogb_f64(double); +__device__ __attribute__((const)) int __ocml_isfinite_f64(double); +__device__ __attribute__((const)) int __ocml_isinf_f64(double); +__device__ __attribute__((const)) int __ocml_isnan_f64(double); +__device__ double __ocml_j0_f64(double); +__device__ double __ocml_j1_f64(double); +__device__ __attribute__((const)) double __ocml_ldexp_f64(double, int); +__device__ double __ocml_lgamma_f64(double); +__device__ __attribute__((pure)) double __ocml_log10_f64(double); +__device__ __attribute__((pure)) double __ocml_log1p_f64(double); +__device__ __attribute__((pure)) double __ocml_log2_f64(double); +__device__ __attribute__((const)) double __ocml_logb_f64(double); +__device__ __attribute__((pure)) double __ocml_log_f64(double); +__device__ double __ocml_modf_f64(double, + __attribute__((address_space(5))) double *); +__device__ __attribute__((const)) double __ocml_nearbyint_f64(double); +__device__ __attribute__((const)) double __ocml_nextafter_f64(double, double); +__device__ __attribute__((const)) double __ocml_len3_f64(double, double, + double); +__device__ __attribute__((const)) double __ocml_len4_f64(double, double, double, + double); +__device__ __attribute__((pure)) double __ocml_ncdf_f64(double); +__device__ __attribute__((pure)) double __ocml_ncdfinv_f64(double); +__device__ __attribute__((pure)) double __ocml_pow_f64(double, double); +__device__ __attribute__((pure)) double __ocml_pown_f64(double, int); +__device__ __attribute__((pure)) double __ocml_rcbrt_f64(double); +__device__ __attribute__((const)) double __ocml_remainder_f64(double, double); +__device__ double __ocml_remquo_f64(double, double, + __attribute__((address_space(5))) int *); +__device__ __attribute__((const)) double __ocml_rhypot_f64(double, double); +__device__ __attribute__((const)) double __ocml_rint_f64(double); +__device__ __attribute__((const)) double __ocml_rlen3_f64(double, double, + double); +__device__ __attribute__((const)) double __ocml_rlen4_f64(double, double, + double, double); +__device__ __attribute__((const)) double __ocml_round_f64(double); +__device__ __attribute__((pure)) double __ocml_rsqrt_f64(double); +__device__ __attribute__((const)) double __ocml_scalb_f64(double, double); +__device__ __attribute__((const)) double __ocml_scalbn_f64(double, int); +__device__ __attribute__((const)) int __ocml_signbit_f64(double); +__device__ double __ocml_sincos_f64(double, + __attribute__((address_space(5))) double *); +__device__ double +__ocml_sincospi_f64(double, __attribute__((address_space(5))) double *); +__device__ double __ocml_sin_f64(double); +__device__ __attribute__((pure)) double __ocml_sinh_f64(double); +__device__ double __ocml_sinpi_f64(double); +__device__ __attribute__((const)) double __ocml_sqrt_f64(double); +__device__ double __ocml_tan_f64(double); +__device__ __attribute__((pure)) double __ocml_tanh_f64(double); +__device__ double __ocml_tgamma_f64(double); +__device__ __attribute__((const)) double __ocml_trunc_f64(double); +__device__ double __ocml_y0_f64(double); +__device__ double __ocml_y1_f64(double); + +// BEGIN INTRINSICS +__device__ __attribute__((const)) double __ocml_add_rte_f64(double, double); +__device__ __attribute__((const)) double __ocml_add_rtn_f64(double, double); +__device__ __attribute__((const)) double __ocml_add_rtp_f64(double, double); +__device__ __attribute__((const)) double __ocml_add_rtz_f64(double, double); +__device__ __attribute__((const)) double __ocml_sub_rte_f64(double, double); +__device__ __attribute__((const)) double __ocml_sub_rtn_f64(double, double); +__device__ __attribute__((const)) double __ocml_sub_rtp_f64(double, double); +__device__ __attribute__((const)) double __ocml_sub_rtz_f64(double, double); +__device__ __attribute__((const)) double __ocml_mul_rte_f64(double, double); +__device__ __attribute__((const)) double __ocml_mul_rtn_f64(double, double); +__device__ __attribute__((const)) double __ocml_mul_rtp_f64(double, double); +__device__ __attribute__((const)) double __ocml_mul_rtz_f64(double, double); +__device__ __attribute__((const)) double __ocml_div_rte_f64(double, double); +__device__ __attribute__((const)) double __ocml_div_rtn_f64(double, double); +__device__ __attribute__((const)) double __ocml_div_rtp_f64(double, double); +__device__ __attribute__((const)) double __ocml_div_rtz_f64(double, double); +__device__ __attribute__((const)) double __ocml_sqrt_rte_f64(double); +__device__ __attribute__((const)) double __ocml_sqrt_rtn_f64(double); +__device__ __attribute__((const)) double __ocml_sqrt_rtp_f64(double); +__device__ __attribute__((const)) double __ocml_sqrt_rtz_f64(double); +__device__ __attribute__((const)) double __ocml_fma_rte_f64(double, double, + double); +__device__ __attribute__((const)) double __ocml_fma_rtn_f64(double, double, + double); +__device__ __attribute__((const)) double __ocml_fma_rtp_f64(double, double, + double); +__device__ __attribute__((const)) double __ocml_fma_rtz_f64(double, double, + double); +// END INTRINSICS +// END DOUBLE + +#endif // !__CLANG_HIP_RUNTIME_WRAPPER_INCLUDED__ + +#if defined(__cplusplus) +} // extern "C" +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/ockl_image.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/ockl_image.h new file mode 100644 index 000000000..d332968a5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/ockl_image.h @@ -0,0 +1,323 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#if !defined(__HIPCC_RTC__) +#include +#endif + +extern "C" { + +#define ADDRESS_SPACE_CONSTANT __attribute__((address_space(4))) + +__device__ float4::Native_vec_ +__ockl_image_load_1D(unsigned int ADDRESS_SPACE_CONSTANT *i, int c); + +__device__ float4::Native_vec_ +__ockl_image_load_1Db(unsigned int ADDRESS_SPACE_CONSTANT *i, int c); + +__device__ float4::Native_vec_ +__ockl_image_load_1Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_load_2D(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_load_2Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_load_3D(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_load_CM(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c, int f); + +__device__ float4::Native_vec_ +__ockl_image_load_CMa(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c, int f); + +__device__ float4::Native_vec_ +__ockl_image_load_lod_1D(unsigned int ADDRESS_SPACE_CONSTANT *i, int c, int l); + +__device__ float4::Native_vec_ +__ockl_image_load_lod_1Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c, int l); + +__device__ float4::Native_vec_ +__ockl_image_load_lod_2D(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c, int l); + +__device__ float4::Native_vec_ +__ockl_image_load_lod_2Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c, int l); + +__device__ float4::Native_vec_ +__ockl_image_load_lod_3D(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c, int l); + +__device__ float4::Native_vec_ +__ockl_image_load_lod_CM(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c, int f, int l); + +__device__ float4::Native_vec_ +__ockl_image_load_lod_CMa(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c, int f, int l); + +__device__ void __ockl_image_store_1D(unsigned int ADDRESS_SPACE_CONSTANT *i, + int c, float4::Native_vec_ p); + +__device__ void __ockl_image_store_1Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c, + float4::Native_vec_ p); + +__device__ void __ockl_image_store_2D(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c, + float4::Native_vec_ p); + +__device__ void __ockl_image_store_2Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c, + float4::Native_vec_ p); + +__device__ void __ockl_image_store_3D(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c, + float4::Native_vec_ p); + +__device__ void __ockl_image_store_CM(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c, int f, + float4::Native_vec_ p); + +__device__ void __ockl_image_store_CMa(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c, int f, + float4::Native_vec_ p); + +__device__ void +__ockl_image_store_lod_1D(unsigned int ADDRESS_SPACE_CONSTANT *i, int c, int l, + float4::Native_vec_ p); + +__device__ void +__ockl_image_store_lod_1Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c, int l, float4::Native_vec_ p); + +__device__ void +__ockl_image_store_lod_2D(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c, int l, float4::Native_vec_ p); + +__device__ void +__ockl_image_store_lod_2Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c, int l, float4::Native_vec_ p); + +__device__ void +__ockl_image_store_lod_3D(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c, int l, float4::Native_vec_ p); + +__device__ void +__ockl_image_store_lod_CM(unsigned int ADDRESS_SPACE_CONSTANT *i, + int2::Native_vec_ c, int f, int l, + float4::Native_vec_ p); + +__device__ void +__ockl_image_store_lod_CMa(unsigned int ADDRESS_SPACE_CONSTANT *i, + int4::Native_vec_ c, int f, int l, + float4::Native_vec_ p); + +__device__ float4::Native_vec_ +__ockl_image_sample_1D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, float c); + +__device__ float4::Native_vec_ +__ockl_image_sample_1Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float2::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_sample_2D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float2::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_sample_2Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float4::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_sample_3D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float4::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_sample_CM(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float4::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_sample_CMa(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float4::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_sample_grad_1D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, float c, + float dx, float dy); + +__device__ float4::Native_vec_ +__ockl_image_sample_grad_1Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float2::Native_vec_ c, float dx, float dy); + +__device__ float4::Native_vec_ +__ockl_image_sample_grad_2D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float2::Native_vec_ c, float2::Native_vec_ dx, + float2::Native_vec_ dy); + +__device__ float4::Native_vec_ +__ockl_image_sample_grad_2Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float4::Native_vec_ c, float2::Native_vec_ dx, + float2::Native_vec_ dy); + +__device__ float4::Native_vec_ +__ockl_image_sample_grad_3D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float4::Native_vec_ c, float4::Native_vec_ dx, + float4::Native_vec_ dy); + +__device__ float4::Native_vec_ +__ockl_image_sample_lod_1D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, float c, + float l); + +__device__ float4::Native_vec_ +__ockl_image_sample_lod_1Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float2::Native_vec_ c, float l); + +__device__ float4::Native_vec_ +__ockl_image_sample_lod_2D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float2::Native_vec_ c, float l); + +__device__ float4::Native_vec_ +__ockl_image_sample_lod_2Da(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float4::Native_vec_ c, float l); + +__device__ float4::Native_vec_ +__ockl_image_sample_lod_3D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float4::Native_vec_ c, float l); + +__device__ float4::Native_vec_ +__ockl_image_sample_lod_CM(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float4::Native_vec_ c, float l); + +__device__ float4::Native_vec_ +__ockl_image_sample_lod_CMa(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float4::Native_vec_ c, float l); + +__device__ float4::Native_vec_ +__ockl_image_gather4r_2D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float2::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_gather4g_2D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float2::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_gather4b_2D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float2::Native_vec_ c); + +__device__ float4::Native_vec_ +__ockl_image_gather4a_2D(unsigned int ADDRESS_SPACE_CONSTANT *i, + unsigned int ADDRESS_SPACE_CONSTANT *s, + float2::Native_vec_ c); + +__device__ int +__ockl_image_channel_data_type_1D(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_data_type_1Da(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_data_type_1Db(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_data_type_2D(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_data_type_2Da(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_data_type_2Dad(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_data_type_2Dd(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_data_type_3D(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_data_type_CM(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_data_type_CMa(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_order_1D(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_order_1Da(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_order_1Db(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_order_2D(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_order_2Da(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_order_2Dad(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_order_2Dd(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_order_3D(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_order_CM(unsigned int ADDRESS_SPACE_CONSTANT *i); + +__device__ int +__ockl_image_channel_order_CMa(unsigned int ADDRESS_SPACE_CONSTANT *i); +} diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/program_state.hpp b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/program_state.hpp new file mode 100644 index 000000000..31b0ab614 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/program_state.hpp @@ -0,0 +1,106 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +#include + +struct ihipModuleSymbol_t; +using hipFunction_t = ihipModuleSymbol_t *; + +namespace hip_impl { + +// This section contains internal APIs that +// needs to be exported +#ifdef __GNUC__ +#pragma GCC visibility push(default) +#endif + +struct kernarg_impl; +class kernarg { +public: + kernarg(); + kernarg(kernarg &&); + ~kernarg(); + std::uint8_t *data(); + std::size_t size(); + void reserve(std::size_t); + void resize(std::size_t); + +private: + kernarg_impl *impl; +}; + +class kernargs_size_align; +class program_state_impl; +class program_state { +public: + program_state(); + ~program_state(); + program_state(const program_state &) = delete; + + hipFunction_t kernel_descriptor(std::uintptr_t, hsa_agent_t); + + kernargs_size_align get_kernargs_size_align(std::uintptr_t); + hsa_executable_t load_executable(const char *, const size_t, hsa_executable_t, + hsa_agent_t); + hsa_executable_t load_executable_no_copy(const char *, const size_t, + hsa_executable_t, hsa_agent_t); + + void *global_addr_by_name(const char *name); + +private: + friend class agent_globals_impl; + program_state_impl *impl; +}; + +class kernargs_size_align { +public: + std::size_t size(std::size_t n) const; + std::size_t alignment(std::size_t n) const; + const void *getHandle() const { return handle; }; + +private: + const void *handle; + friend kernargs_size_align + program_state::get_kernargs_size_align(std::uintptr_t); +}; + +#ifdef __GNUC__ +#pragma GCC visibility pop +#endif + +inline __attribute__((visibility("hidden"))) program_state & +get_program_state() { + static program_state ps; + return ps; +} +} // Namespace hip_impl. diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/texture_fetch_functions.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/texture_fetch_functions.h new file mode 100644 index 000000000..220c119c7 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/texture_fetch_functions.h @@ -0,0 +1,485 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#if defined(__cplusplus) + +#if !defined(__HIPCC_RTC__) +#include +#include +#include +#include +#endif // !defined(__HIPCC_RTC__) + +#define TEXTURE_PARAMETERS_INIT \ + unsigned int ADDRESS_SPACE_CONSTANT *i = \ + (unsigned int ADDRESS_SPACE_CONSTANT *)t.textureObject; \ + unsigned int ADDRESS_SPACE_CONSTANT *s = i + HIP_SAMPLER_OBJECT_OFFSET_DWORD; + +template struct __hip_is_tex_surf_scalar_channel_type { + static constexpr bool value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value; +}; + +template struct __hip_is_tex_surf_channel_type { + static constexpr bool value = __hip_is_tex_surf_scalar_channel_type::value; +}; + +template +struct __hip_is_tex_surf_channel_type> { + static constexpr bool value = + __hip_is_tex_surf_scalar_channel_type::value && + ((rank == 1) || (rank == 2) || (rank == 4)); +}; + +template struct __hip_is_tex_normalized_channel_type { + static constexpr bool value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value; +}; + +template +struct __hip_is_tex_normalized_channel_type> { + static constexpr bool value = + __hip_is_tex_normalized_channel_type::value && + ((rank == 1) || (rank == 2) || (rank == 4)); +}; + +template +struct __hip_tex_ret { + static_assert(std::is_same::value, "Invalid channel type!"); +}; + +/* + * Map from device function return U to scalar texture type T + */ +template +__forceinline__ __device__ + typename std::enable_if<__hip_is_tex_surf_scalar_channel_type::value, + const T>::type + __hipMapFrom(const U &u) { + if constexpr (sizeof(T) < sizeof(float)) { + union { + U u; + int i; + } d = {u}; + return static_cast(d.i); + } else { // sizeof(T) == sizeof(float) + union { + U u; + T t; + } d = {u}; + return d.t; + } +} + +/* + * Map from device function return U to vector texture type T + */ +template +__forceinline__ __device__ typename std::enable_if< + __hip_is_tex_surf_scalar_channel_type::value, + const T>::type +__hipMapFrom(const U &u) { + if constexpr (sizeof(typename T::value_type) < sizeof(float)) { + union { + U u; + int4 i4; + } d = {u}; + return __hipMapVector(d.i4); + } else { // sizeof(typename T::value_type) == sizeof(float) + union { + U u; + T t; + } d = {u}; + return d.t; + } +} + +/* + * Map from scalar texture type T to device function input U + */ +template +__forceinline__ __device__ + typename std::enable_if<__hip_is_tex_surf_scalar_channel_type::value, + const U>::type + __hipMapTo(const T &t) { + if constexpr (sizeof(T) < sizeof(float)) { + union { + U u; + int i; + } d = {0}; + d.i = static_cast(t); + return d.u; + } else { // sizeof(T) == sizeof(float) + union { + U u; + T t; + } d = {0}; + d.t = t; + return d.u; + } +} + +/* + * Map from vector texture type T to device function input U + */ +template +__forceinline__ __device__ typename std::enable_if< + __hip_is_tex_surf_scalar_channel_type::value, + const U>::type +__hipMapTo(const T &t) { + if constexpr (sizeof(typename T::value_type) < sizeof(float)) { + union { + U u; + int4 i4; + } d = {0}; + d.i4 = __hipMapVector(t); + return d.u; + } else { // sizeof(typename T::value_type) == sizeof(float) + union { + U u; + T t; + } d = {0}; + d.t = t; + return d.u; + } +} + +template +using __hip_tex_ret_t = typename __hip_tex_ret::type; + +template +struct __hip_tex_ret< + T, hipReadModeElementType, + typename std::enable_if<__hip_is_tex_surf_channel_type::value, + bool>::type> { + using type = T; +}; + +template +struct __hip_tex_ret< + HIP_vector_type, hipReadModeElementType, + typename std::enable_if< + __hip_is_tex_surf_channel_type>::value, + bool>::type> { + using type = + HIP_vector_type<__hip_tex_ret_t, rank>; +}; + +template +struct __hip_tex_ret< + T, hipReadModeNormalizedFloat, + typename std::enable_if<__hip_is_tex_normalized_channel_type::value, + bool>::type> { + using type = float; +}; + +template +struct __hip_tex_ret< + HIP_vector_type, hipReadModeNormalizedFloat, + typename std::enable_if< + __hip_is_tex_normalized_channel_type>::value, + bool>::type> { + using type = + HIP_vector_type<__hip_tex_ret_t, rank>; +}; + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex1Dfetch(texture t, int x) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_load_1Db(i, x); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex1D(texture t, float x) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_1D(i, s, x); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex2D(texture t, float x, float y) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_2D(i, s, float2(x, y).data); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex1DLayered(texture t, float x, + int layer) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_1Da(i, s, float2(x, layer).data); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex2DLayered(texture t, float x, float y, + int layer) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_2Da(i, s, float4(x, y, layer, 0.0f).data); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex3D(texture t, float x, float y, float z) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_3D(i, s, float4(x, y, z, 0.0f).data); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +texCubemap(texture t, float x, float y, + float z) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_CM(i, s, float4(x, y, z, 0.0f).data); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex1DLod(texture t, float x, float level) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_lod_1D(i, s, x, level); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex2DLod(texture t, float x, float y, + float level) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_lod_2D(i, s, float2(x, y).data, level); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex1DLayeredLod(texture t, float x, + int layer, float level) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_lod_1Da(i, s, float2(x, layer).data, level); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex2DLayeredLod(texture t, float x, + float y, int layer, float level) { + TEXTURE_PARAMETERS_INIT; + auto tmp = + __ockl_image_sample_lod_2Da(i, s, float4(x, y, layer, 0.0f).data, level); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex3DLod(texture t, float x, float y, float z, + float level) { + TEXTURE_PARAMETERS_INIT; + auto tmp = + __ockl_image_sample_lod_3D(i, s, float4(x, y, z, 0.0f).data, level); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +texCubemapLod(texture t, float x, float y, + float z, float level) { + TEXTURE_PARAMETERS_INIT; + auto tmp = + __ockl_image_sample_lod_CM(i, s, float4(x, y, z, 0.0f).data, level); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +texCubemapLayered(texture t, float x, + float y, float z, int layer) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_CMa(i, s, float4(x, y, z, layer).data); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +texCubemapLayeredLod(texture t, + float x, float y, float z, int layer, float level) { + TEXTURE_PARAMETERS_INIT; + auto tmp = + __ockl_image_sample_lod_CMa(i, s, float4(x, y, z, layer).data, level); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +texCubemapGrad(texture t, float x, float y, + float z, float4 dPdx, float4 dPdy) { + TEXTURE_PARAMETERS_INIT; + // TODO missing in device libs. + // auto tmp = __ockl_image_sample_grad_CM(i, s, float4(x, y, z, 0.0f).data, + // float4(dPdx.x, dPdx.y, dPdx.z, 0.0f).data, float4(dPdy.x, dPdy.y, dPdy.z, + // 0.0f).data); return __hipMapFrom<__hip_tex_ret_t>(tmp); + return {}; +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +texCubemapLayeredGrad(texture t, + float x, float y, float z, int layer, float4 dPdx, + float4 dPdy) { + TEXTURE_PARAMETERS_INIT; + // TODO missing in device libs. + // auto tmp = __ockl_image_sample_grad_CMa(i, s, float4(x, y, z, layer).data, + // float4(dPdx.x, dPdx.y, dPdx.z, 0.0f).data, float4(dPdy.x, dPdy.y, dPdy.z, + // 0.0f).data); return __hipMapFrom<__hip_tex_ret_t>(tmp); + return {}; +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex1DGrad(texture t, float x, float dPdx, + float dPdy) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_grad_1D(i, s, x, dPdx, dPdy); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex2DGrad(texture t, float x, float y, + float2 dPdx, float2 dPdy) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_grad_2D(i, s, float2(x, y).data, + float2(dPdx.x, dPdx.y).data, + float2(dPdy.x, dPdy.y).data); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex1DLayeredGrad(texture t, float x, + int layer, float dPdx, float dPdy) { + TEXTURE_PARAMETERS_INIT; + auto tmp = + __ockl_image_sample_grad_1Da(i, s, float2(x, layer).data, dPdx, dPdy); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex2DLayeredGrad(texture t, float x, + float y, int layer, float2 dPdx, float2 dPdy) { + TEXTURE_PARAMETERS_INIT; + auto tmp = __ockl_image_sample_grad_2Da(i, s, float4(x, y, layer, 0.0f).data, + float2(dPdx.x, dPdx.y).data, + float2(dPdy.x, dPdy.y).data); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +static __forceinline__ __device__ __hip_img_chk__ __hip_tex_ret_t +tex3DGrad(texture t, float x, float y, float z, + float4 dPdx, float4 dPdy) { + TEXTURE_PARAMETERS_INIT; + auto tmp = + __ockl_image_sample_grad_3D(i, s, float4(x, y, z, 0.0f).data, + float4(dPdx.x, dPdx.y, dPdx.z, 0.0f).data, + float4(dPdy.x, dPdy.y, dPdy.z, 0.0f).data); + return __hipMapFrom<__hip_tex_ret_t>(tmp); +} + +template +struct __hip_tex2dgather_ret { + static_assert(std::is_same::value, "Invalid channel type!"); +}; + +template +using __hip_tex2dgather_ret_t = + typename __hip_tex2dgather_ret::type; + +template +struct __hip_tex2dgather_ret< + T, hipReadModeElementType, + typename std::enable_if<__hip_is_tex_surf_channel_type::value, + bool>::type> { + using type = HIP_vector_type; +}; + +template +struct __hip_tex2dgather_ret< + HIP_vector_type, hipReadModeElementType, + typename std::enable_if< + __hip_is_tex_surf_channel_type>::value, + bool>::type> { + using type = HIP_vector_type; +}; + +template +struct __hip_tex2dgather_ret< + T, hipReadModeNormalizedFloat, + typename std::enable_if<__hip_is_tex_normalized_channel_type::value, + bool>::type> { + using type = float4; +}; + +template +static __forceinline__ + __device__ __hip_img_chk__ __hip_tex2dgather_ret_t + tex2Dgather(texture t, float x, float y, + int comp = 0) { + TEXTURE_PARAMETERS_INIT; + switch (comp) { + case 1: { + auto tmp = __ockl_image_gather4g_2D(i, s, float2(x, y).data); + return __hipMapFrom<__hip_tex2dgather_ret_t>(tmp); + } + case 2: { + auto tmp = __ockl_image_gather4b_2D(i, s, float2(x, y).data); + return __hipMapFrom<__hip_tex2dgather_ret_t>(tmp); + } + case 3: { + auto tmp = __ockl_image_gather4a_2D(i, s, float2(x, y).data); + return __hipMapFrom<__hip_tex2dgather_ret_t>(tmp); + } + default: { + auto tmp = __ockl_image_gather4r_2D(i, s, float2(x, y).data); + return __hipMapFrom<__hip_tex2dgather_ret_t>(tmp); + } + } + return {}; +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/texture_indirect_functions.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/texture_indirect_functions.h new file mode 100644 index 000000000..af32e6849 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/amd_detail/texture_indirect_functions.h @@ -0,0 +1,472 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +#if defined(__cplusplus) + +#if !defined(__HIPCC_RTC__) +#include +#include +#include +#include +#include +#endif // !defined(__HIPCC_RTC__) + +#define TEXTURE_OBJECT_PARAMETERS_INIT \ + unsigned int ADDRESS_SPACE_CONSTANT *i = \ + (unsigned int ADDRESS_SPACE_CONSTANT *)textureObject; \ + unsigned int ADDRESS_SPACE_CONSTANT *s = i + HIP_SAMPLER_OBJECT_OFFSET_DWORD; + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T tex1Dfetch(hipTextureObject_t textureObject, + int x) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_load_1Db(i, x); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex1Dfetch(T *ptr, hipTextureObject_t textureObject, int x) { + *ptr = tex1Dfetch(textureObject, x); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T tex1D(hipTextureObject_t textureObject, + float x) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_1D(i, s, x); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex1D(T *ptr, hipTextureObject_t textureObject, float x) { + *ptr = tex1D(textureObject, x); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T tex2D(hipTextureObject_t textureObject, + float x, float y) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_2D(i, s, float2(x, y).data); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex2D(T *ptr, hipTextureObject_t textureObject, float x, float y) { + *ptr = tex2D(textureObject, x, y); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T tex3D(hipTextureObject_t textureObject, + float x, float y, float z) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_3D(i, s, float4(x, y, z, 0.0f).data); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex3D(T *ptr, hipTextureObject_t textureObject, float x, float y, float z) { + *ptr = tex3D(textureObject, x, y, z); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T +tex1DLayered(hipTextureObject_t textureObject, float x, int layer) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_1Da(i, s, float2(x, layer).data); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex1DLayered(T *ptr, hipTextureObject_t textureObject, float x, int layer) { + *ptr = tex1DLayered(textureObject, x, layer); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T +tex2DLayered(hipTextureObject_t textureObject, float x, float y, int layer) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_2Da(i, s, float4(x, y, layer, 0.0f).data); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex2DLayered(T *ptr, hipTextureObject_t textureObject, float x, float y, + int layer) { + *ptr = tex1DLayered(textureObject, x, y, layer); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T texCubemap(hipTextureObject_t textureObject, + float x, float y, float z) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_CM(i, s, float4(x, y, z, 0.0f).data); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +texCubemap(T *ptr, hipTextureObject_t textureObject, float x, float y, + float z) { + *ptr = texCubemap(textureObject, x, y, z); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T texCubemapLayered( + hipTextureObject_t textureObject, float x, float y, float z, int layer) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_CMa(i, s, float4(x, y, z, layer).data); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +texCubemapLayered(T *ptr, hipTextureObject_t textureObject, float x, float y, + float z, int layer) { + *ptr = texCubemapLayered(textureObject, x, y, z, layer); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T +tex2Dgather(hipTextureObject_t textureObject, float x, float y, int comp = 0) { + TEXTURE_OBJECT_PARAMETERS_INIT + switch (comp) { + case 1: { + auto tmp = __ockl_image_gather4r_2D(i, s, float2(x, y).data); + return __hipMapFrom(tmp); + break; + } + case 2: { + auto tmp = __ockl_image_gather4g_2D(i, s, float2(x, y).data); + return __hipMapFrom(tmp); + break; + } + case 3: { + auto tmp = __ockl_image_gather4b_2D(i, s, float2(x, y).data); + return __hipMapFrom(tmp); + break; + } + default: { + auto tmp = __ockl_image_gather4a_2D(i, s, float2(x, y).data); + return __hipMapFrom(tmp); + break; + } + } + return {}; +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex2Dgather(T *ptr, hipTextureObject_t textureObject, float x, float y, + int comp = 0) { + *ptr = texCubemapLayered(textureObject, x, y, comp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T tex1DLod(hipTextureObject_t textureObject, + float x, float level) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_lod_1D(i, s, x, level); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex1DLod(T *ptr, hipTextureObject_t textureObject, float x, float level) { + *ptr = tex1DLod(textureObject, x, level); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T tex2DLod(hipTextureObject_t textureObject, + float x, float y, float level) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_lod_2D(i, s, float2(x, y).data, level); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex2DLod(T *ptr, hipTextureObject_t textureObject, float x, float y, + float level) { + *ptr = tex2DLod(textureObject, x, y, level); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T tex3DLod(hipTextureObject_t textureObject, + float x, float y, float z, + float level) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = + __ockl_image_sample_lod_3D(i, s, float4(x, y, z, 0.0f).data, level); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex3DLod(T *ptr, hipTextureObject_t textureObject, float x, float y, float z, + float level) { + *ptr = tex3DLod(textureObject, x, y, z, level); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T tex1DLayeredLod( + hipTextureObject_t textureObject, float x, int layer, float level) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_1Da(i, s, float2(x, layer).data); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex1DLayeredLod(T *ptr, hipTextureObject_t textureObject, float x, int layer, + float level) { + *ptr = tex1DLayeredLod(textureObject, x, layer, level); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T +tex2DLayeredLod(hipTextureObject_t textureObject, float x, float y, int layer, + float level) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_2Da(i, s, float4(x, y, layer, 0.0f).data); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex2DLayeredLod(T *ptr, hipTextureObject_t textureObject, float x, float y, + int layer, float level) { + *ptr = tex2DLayeredLod(textureObject, x, y, layer, level); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T texCubemapLod( + hipTextureObject_t textureObject, float x, float y, float z, float level) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = + __ockl_image_sample_lod_CM(i, s, float4(x, y, z, 0.0f).data, level); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +texCubemapLod(T *ptr, hipTextureObject_t textureObject, float x, float y, + float z, float level) { + *ptr = texCubemapLod(textureObject, x, y, z, level); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T +texCubemapGrad(hipTextureObject_t textureObject, float x, float y, float z, + float4 dPdx, float4 dPdy) { + TEXTURE_OBJECT_PARAMETERS_INIT + // TODO missing in device libs. + // auto tmp = __ockl_image_sample_grad_CM(i, s, float4(x, y, z, 0.0f).data, + // float4(dPdx.x, dPdx.y, dPdx.z, 0.0f).data, float4(dPdy.x, dPdy.y, dPdy.z, + // 0.0f).data); return __hipMapFrom(tmp); + return {}; +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +texCubemapGrad(T *ptr, hipTextureObject_t textureObject, float x, float y, + float z, float4 dPdx, float4 dPdy) { + *ptr = texCubemapGrad(textureObject, x, y, z, dPdx, dPdy); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T +texCubemapLayeredLod(hipTextureObject_t textureObject, float x, float y, + float z, int layer, float level) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = + __ockl_image_sample_lod_CMa(i, s, float4(x, y, z, layer).data, level); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +texCubemapLayeredLod(T *ptr, hipTextureObject_t textureObject, float x, float y, + float z, int layer, float level) { + *ptr = texCubemapLayeredLod(textureObject, x, y, z, layer, level); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T tex1DGrad(hipTextureObject_t textureObject, + float x, float dPdx, float dPdy) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_grad_1D(i, s, x, dPdx, dPdy); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex1DGrad(T *ptr, hipTextureObject_t textureObject, float x, float dPdx, + float dPdy) { + *ptr = tex1DGrad(textureObject, x, dPdx, dPdy); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T tex2DGrad(hipTextureObject_t textureObject, + float x, float y, float2 dPdx, + float2 dPdy) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_grad_2D(i, s, float2(x, y).data, + float2(dPdx.x, dPdx.y).data, + float2(dPdy.x, dPdy.y).data); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex2DGrad(T *ptr, hipTextureObject_t textureObject, float x, float y, + float2 dPdx, float2 dPdy) { + *ptr = tex2DGrad(textureObject, x, y, dPdx, dPdy); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T tex3DGrad(hipTextureObject_t textureObject, + float x, float y, float z, + float4 dPdx, float4 dPdy) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = + __ockl_image_sample_grad_3D(i, s, float4(x, y, z, 0.0f).data, + float4(dPdx.x, dPdx.y, dPdx.z, 0.0f).data, + float4(dPdy.x, dPdy.y, dPdy.z, 0.0f).data); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex3DGrad(T *ptr, hipTextureObject_t textureObject, float x, float y, float z, + float4 dPdx, float4 dPdy) { + *ptr = tex3DGrad(textureObject, x, y, z, dPdx, dPdy); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T +tex1DLayeredGrad(hipTextureObject_t textureObject, float x, int layer, + float dPdx, float dPdy) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = + __ockl_image_sample_grad_1Da(i, s, float2(x, layer).data, dPdx, dPdy); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex1DLayeredGrad(T *ptr, hipTextureObject_t textureObject, float x, int layer, + float dPdx, float dPdy) { + *ptr = tex1DLayeredGrad(textureObject, x, layer, dPdx, dPdy); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T +tex2DLayeredGrad(hipTextureObject_t textureObject, float x, float y, int layer, + float2 dPdx, float2 dPdy) { + TEXTURE_OBJECT_PARAMETERS_INIT + auto tmp = __ockl_image_sample_grad_2Da(i, s, float4(x, y, layer, 0.0f).data, + float2(dPdx.x, dPdx.y).data, + float2(dPdy.x, dPdy.y).data); + return __hipMapFrom(tmp); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +tex2DLayeredGrad(T *ptr, hipTextureObject_t textureObject, float x, float y, + int layer, float2 dPdx, float2 dPdy) { + *ptr = tex2DLayeredGrad(textureObject, x, y, layer, dPdx, dPdy); +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ T +texCubemapLayeredGrad(hipTextureObject_t textureObject, float x, float y, + float z, int layer, float4 dPdx, float4 dPdy) { + TEXTURE_OBJECT_PARAMETERS_INIT + // TODO missing in device libs. + // auto tmp = __ockl_image_sample_grad_CMa(i, s, float4(x, y, z, layer).data, + // float4(dPdx.x, dPdx.y, dPdx.z, 0.0f).data, float4(dPdy.x, dPdy.y, dPdy.z, + // 0.0f).data); return __hipMapFrom(tmp); + return {}; +} + +template ::value>::type * = nullptr> +static __device__ __hip_img_chk__ void +texCubemapLayeredGrad(T *ptr, hipTextureObject_t textureObject, float x, + float y, float z, int layer, float4 dPdx, float4 dPdy) { + *ptr = texCubemapLayeredGrad(textureObject, x, y, z, layer, dPdx, dPdy); +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/channel_descriptor.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/channel_descriptor.h new file mode 100644 index 000000000..02fad2b0f --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/channel_descriptor.h @@ -0,0 +1,38 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_CHANNEL_DESCRIPTOR_H +#define HIP_INCLUDE_HIP_CHANNEL_DESCRIPTOR_H + +// Some standard header files, these are included by hc.hpp and so want to make +// them avail on both paths to provide a consistent include env and avoid +// "missing symbol" errors that only appears on NVCC path: + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#include +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/device_functions.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/device_functions.h new file mode 100644 index 000000000..9a86171e4 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/device_functions.h @@ -0,0 +1,38 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_DEVICE_FUNCTIONS_H +#define HIP_INCLUDE_HIP_DEVICE_FUNCTIONS_H + +#if !defined(__HIPCC_RTC__) +#include +#endif + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#include +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/driver_types.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/driver_types.h new file mode 100644 index 000000000..88b735d05 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/driver_types.h @@ -0,0 +1,497 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_DRIVER_TYPES_H +#define HIP_INCLUDE_HIP_DRIVER_TYPES_H + +#if !defined(__HIPCC_RTC__) +#include +#endif + +#if !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include "driver_types.h" +#elif defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) + +#if !defined(__HIPCC_RTC__) +#ifndef __cplusplus +#include +#endif +#endif // !defined(__HIPCC_RTC__) +typedef void *hipDeviceptr_t; +typedef enum hipChannelFormatKind { + hipChannelFormatKindSigned = 0, + hipChannelFormatKindUnsigned = 1, + hipChannelFormatKindFloat = 2, + hipChannelFormatKindNone = 3 +} hipChannelFormatKind; +typedef struct hipChannelFormatDesc { + int x; + int y; + int z; + int w; + enum hipChannelFormatKind f; +} hipChannelFormatDesc; +#define HIP_TRSA_OVERRIDE_FORMAT 0x01 +#define HIP_TRSF_READ_AS_INTEGER 0x01 +#define HIP_TRSF_NORMALIZED_COORDINATES 0x02 +#define HIP_TRSF_SRGB 0x10 + +typedef struct hipArray *hipArray_t; +typedef const struct hipArray *hipArray_const_t; +typedef enum hipArray_Format { + HIP_AD_FORMAT_UNSIGNED_INT8 = 0x01, + HIP_AD_FORMAT_UNSIGNED_INT16 = 0x02, + HIP_AD_FORMAT_UNSIGNED_INT32 = 0x03, + HIP_AD_FORMAT_SIGNED_INT8 = 0x08, + HIP_AD_FORMAT_SIGNED_INT16 = 0x09, + HIP_AD_FORMAT_SIGNED_INT32 = 0x0a, + HIP_AD_FORMAT_HALF = 0x10, + HIP_AD_FORMAT_FLOAT = 0x20 +} hipArray_Format; +typedef struct HIP_ARRAY_DESCRIPTOR { + size_t Width; + size_t Height; + enum hipArray_Format Format; + unsigned int NumChannels; +} HIP_ARRAY_DESCRIPTOR; +typedef struct HIP_ARRAY3D_DESCRIPTOR { + size_t Width; + size_t Height; + size_t Depth; + enum hipArray_Format Format; + unsigned int NumChannels; + unsigned int Flags; +} HIP_ARRAY3D_DESCRIPTOR; +#if !defined(__HIPCC_RTC__) +typedef struct hip_Memcpy2D { + size_t srcXInBytes; + size_t srcY; + hipMemoryType srcMemoryType; + const void *srcHost; + hipDeviceptr_t srcDevice; + hipArray_t srcArray; + size_t srcPitch; + size_t dstXInBytes; + size_t dstY; + hipMemoryType dstMemoryType; + void *dstHost; + hipDeviceptr_t dstDevice; + hipArray_t dstArray; + size_t dstPitch; + size_t WidthInBytes; + size_t Height; +} hip_Memcpy2D; +#endif // !defined(__HIPCC_RTC__) +typedef struct hipMipmappedArray { + void *data; + struct hipChannelFormatDesc desc; + unsigned int type; + unsigned int width; + unsigned int height; + unsigned int depth; + unsigned int min_mipmap_level; + unsigned int max_mipmap_level; + unsigned int flags; + enum hipArray_Format format; + unsigned int num_channels; +} hipMipmappedArray; +typedef struct hipMipmappedArray *hipMipmappedArray_t; +typedef hipMipmappedArray_t hipmipmappedArray; +typedef const struct hipMipmappedArray *hipMipmappedArray_const_t; +/** + * hip resource types + */ +typedef enum hipResourceType { + hipResourceTypeArray = 0x00, + hipResourceTypeMipmappedArray = 0x01, + hipResourceTypeLinear = 0x02, + hipResourceTypePitch2D = 0x03 +} hipResourceType; +typedef enum HIPresourcetype_enum { + HIP_RESOURCE_TYPE_ARRAY = 0x00, /**< Array resoure */ + HIP_RESOURCE_TYPE_MIPMAPPED_ARRAY = 0x01, /**< Mipmapped array resource */ + HIP_RESOURCE_TYPE_LINEAR = 0x02, /**< Linear resource */ + HIP_RESOURCE_TYPE_PITCH2D = 0x03 /**< Pitch 2D resource */ +} HIPresourcetype, + hipResourcetype; +/** + * hip address modes + */ +typedef enum HIPaddress_mode_enum { + HIP_TR_ADDRESS_MODE_WRAP = 0, + HIP_TR_ADDRESS_MODE_CLAMP = 1, + HIP_TR_ADDRESS_MODE_MIRROR = 2, + HIP_TR_ADDRESS_MODE_BORDER = 3 +} HIPaddress_mode; +/** + * hip filter modes + */ +typedef enum HIPfilter_mode_enum { + HIP_TR_FILTER_MODE_POINT = 0, + HIP_TR_FILTER_MODE_LINEAR = 1 +} HIPfilter_mode; +/** + * Texture descriptor + */ +typedef struct HIP_TEXTURE_DESC_st { + HIPaddress_mode addressMode[3]; /**< Address modes */ + HIPfilter_mode filterMode; /**< Filter mode */ + unsigned int flags; /**< Flags */ + unsigned int maxAnisotropy; /**< Maximum anisotropy ratio */ + HIPfilter_mode mipmapFilterMode; /**< Mipmap filter mode */ + float mipmapLevelBias; /**< Mipmap level bias */ + float minMipmapLevelClamp; /**< Mipmap minimum level clamp */ + float maxMipmapLevelClamp; /**< Mipmap maximum level clamp */ + float borderColor[4]; /**< Border Color */ + int reserved[12]; +} HIP_TEXTURE_DESC; +/** + * hip texture resource view formats + */ +typedef enum hipResourceViewFormat { + hipResViewFormatNone = 0x00, + hipResViewFormatUnsignedChar1 = 0x01, + hipResViewFormatUnsignedChar2 = 0x02, + hipResViewFormatUnsignedChar4 = 0x03, + hipResViewFormatSignedChar1 = 0x04, + hipResViewFormatSignedChar2 = 0x05, + hipResViewFormatSignedChar4 = 0x06, + hipResViewFormatUnsignedShort1 = 0x07, + hipResViewFormatUnsignedShort2 = 0x08, + hipResViewFormatUnsignedShort4 = 0x09, + hipResViewFormatSignedShort1 = 0x0a, + hipResViewFormatSignedShort2 = 0x0b, + hipResViewFormatSignedShort4 = 0x0c, + hipResViewFormatUnsignedInt1 = 0x0d, + hipResViewFormatUnsignedInt2 = 0x0e, + hipResViewFormatUnsignedInt4 = 0x0f, + hipResViewFormatSignedInt1 = 0x10, + hipResViewFormatSignedInt2 = 0x11, + hipResViewFormatSignedInt4 = 0x12, + hipResViewFormatHalf1 = 0x13, + hipResViewFormatHalf2 = 0x14, + hipResViewFormatHalf4 = 0x15, + hipResViewFormatFloat1 = 0x16, + hipResViewFormatFloat2 = 0x17, + hipResViewFormatFloat4 = 0x18, + hipResViewFormatUnsignedBlockCompressed1 = 0x19, + hipResViewFormatUnsignedBlockCompressed2 = 0x1a, + hipResViewFormatUnsignedBlockCompressed3 = 0x1b, + hipResViewFormatUnsignedBlockCompressed4 = 0x1c, + hipResViewFormatSignedBlockCompressed4 = 0x1d, + hipResViewFormatUnsignedBlockCompressed5 = 0x1e, + hipResViewFormatSignedBlockCompressed5 = 0x1f, + hipResViewFormatUnsignedBlockCompressed6H = 0x20, + hipResViewFormatSignedBlockCompressed6H = 0x21, + hipResViewFormatUnsignedBlockCompressed7 = 0x22 +} hipResourceViewFormat; +typedef enum HIPresourceViewFormat_enum { + HIP_RES_VIEW_FORMAT_NONE = + 0x00, /**< No resource view format (use underlying resource format) */ + HIP_RES_VIEW_FORMAT_UINT_1X8 = 0x01, /**< 1 channel unsigned 8-bit integers */ + HIP_RES_VIEW_FORMAT_UINT_2X8 = 0x02, /**< 2 channel unsigned 8-bit integers */ + HIP_RES_VIEW_FORMAT_UINT_4X8 = 0x03, /**< 4 channel unsigned 8-bit integers */ + HIP_RES_VIEW_FORMAT_SINT_1X8 = 0x04, /**< 1 channel signed 8-bit integers */ + HIP_RES_VIEW_FORMAT_SINT_2X8 = 0x05, /**< 2 channel signed 8-bit integers */ + HIP_RES_VIEW_FORMAT_SINT_4X8 = 0x06, /**< 4 channel signed 8-bit integers */ + HIP_RES_VIEW_FORMAT_UINT_1X16 = + 0x07, /**< 1 channel unsigned 16-bit integers */ + HIP_RES_VIEW_FORMAT_UINT_2X16 = + 0x08, /**< 2 channel unsigned 16-bit integers */ + HIP_RES_VIEW_FORMAT_UINT_4X16 = + 0x09, /**< 4 channel unsigned 16-bit integers */ + HIP_RES_VIEW_FORMAT_SINT_1X16 = 0x0a, /**< 1 channel signed 16-bit integers */ + HIP_RES_VIEW_FORMAT_SINT_2X16 = 0x0b, /**< 2 channel signed 16-bit integers */ + HIP_RES_VIEW_FORMAT_SINT_4X16 = 0x0c, /**< 4 channel signed 16-bit integers */ + HIP_RES_VIEW_FORMAT_UINT_1X32 = + 0x0d, /**< 1 channel unsigned 32-bit integers */ + HIP_RES_VIEW_FORMAT_UINT_2X32 = + 0x0e, /**< 2 channel unsigned 32-bit integers */ + HIP_RES_VIEW_FORMAT_UINT_4X32 = + 0x0f, /**< 4 channel unsigned 32-bit integers */ + HIP_RES_VIEW_FORMAT_SINT_1X32 = 0x10, /**< 1 channel signed 32-bit integers */ + HIP_RES_VIEW_FORMAT_SINT_2X32 = 0x11, /**< 2 channel signed 32-bit integers */ + HIP_RES_VIEW_FORMAT_SINT_4X32 = 0x12, /**< 4 channel signed 32-bit integers */ + HIP_RES_VIEW_FORMAT_FLOAT_1X16 = 0x13, /**< 1 channel 16-bit floating point */ + HIP_RES_VIEW_FORMAT_FLOAT_2X16 = 0x14, /**< 2 channel 16-bit floating point */ + HIP_RES_VIEW_FORMAT_FLOAT_4X16 = 0x15, /**< 4 channel 16-bit floating point */ + HIP_RES_VIEW_FORMAT_FLOAT_1X32 = 0x16, /**< 1 channel 32-bit floating point */ + HIP_RES_VIEW_FORMAT_FLOAT_2X32 = 0x17, /**< 2 channel 32-bit floating point */ + HIP_RES_VIEW_FORMAT_FLOAT_4X32 = 0x18, /**< 4 channel 32-bit floating point */ + HIP_RES_VIEW_FORMAT_UNSIGNED_BC1 = 0x19, /**< Block compressed 1 */ + HIP_RES_VIEW_FORMAT_UNSIGNED_BC2 = 0x1a, /**< Block compressed 2 */ + HIP_RES_VIEW_FORMAT_UNSIGNED_BC3 = 0x1b, /**< Block compressed 3 */ + HIP_RES_VIEW_FORMAT_UNSIGNED_BC4 = 0x1c, /**< Block compressed 4 unsigned */ + HIP_RES_VIEW_FORMAT_SIGNED_BC4 = 0x1d, /**< Block compressed 4 signed */ + HIP_RES_VIEW_FORMAT_UNSIGNED_BC5 = 0x1e, /**< Block compressed 5 unsigned */ + HIP_RES_VIEW_FORMAT_SIGNED_BC5 = 0x1f, /**< Block compressed 5 signed */ + HIP_RES_VIEW_FORMAT_UNSIGNED_BC6H = + 0x20, /**< Block compressed 6 unsigned half-float */ + HIP_RES_VIEW_FORMAT_SIGNED_BC6H = + 0x21, /**< Block compressed 6 signed half-float */ + HIP_RES_VIEW_FORMAT_UNSIGNED_BC7 = 0x22 /**< Block compressed 7 */ +} HIPresourceViewFormat; +/** + * HIP resource descriptor + */ +typedef struct hipResourceDesc { + enum hipResourceType resType; + union { + struct { + hipArray_t array; + } array; + struct { + hipMipmappedArray_t mipmap; + } mipmap; + struct { + void *devPtr; + struct hipChannelFormatDesc desc; + size_t sizeInBytes; + } linear; + struct { + void *devPtr; + struct hipChannelFormatDesc desc; + size_t width; + size_t height; + size_t pitchInBytes; + } pitch2D; + } res; +} hipResourceDesc; +typedef struct HIP_RESOURCE_DESC_st { + HIPresourcetype resType; /**< Resource type */ + union { + struct { + hipArray_t hArray; /**< HIP array */ + } array; + struct { + hipMipmappedArray_t hMipmappedArray; /**< HIP mipmapped array */ + } mipmap; + struct { + hipDeviceptr_t devPtr; /**< Device pointer */ + hipArray_Format format; /**< Array format */ + unsigned int numChannels; /**< Channels per array element */ + size_t sizeInBytes; /**< Size in bytes */ + } linear; + struct { + hipDeviceptr_t devPtr; /**< Device pointer */ + hipArray_Format format; /**< Array format */ + unsigned int numChannels; /**< Channels per array element */ + size_t width; /**< Width of the array in elements */ + size_t height; /**< Height of the array in elements */ + size_t pitchInBytes; /**< Pitch between two rows in bytes */ + } pitch2D; + struct { + int reserved[32]; + } reserved; + } res; + unsigned int flags; /**< Flags (must be zero) */ +} HIP_RESOURCE_DESC; +/** + * hip resource view descriptor + */ +struct hipResourceViewDesc { + enum hipResourceViewFormat format; + size_t width; + size_t height; + size_t depth; + unsigned int firstMipmapLevel; + unsigned int lastMipmapLevel; + unsigned int firstLayer; + unsigned int lastLayer; +}; +/** + * Resource view descriptor + */ +typedef struct HIP_RESOURCE_VIEW_DESC_st { + HIPresourceViewFormat format; /**< Resource view format */ + size_t width; /**< Width of the resource view */ + size_t height; /**< Height of the resource view */ + size_t depth; /**< Depth of the resource view */ + unsigned int firstMipmapLevel; /**< First defined mipmap level */ + unsigned int lastMipmapLevel; /**< Last defined mipmap level */ + unsigned int firstLayer; /**< First layer index */ + unsigned int lastLayer; /**< Last layer index */ + unsigned int reserved[16]; +} HIP_RESOURCE_VIEW_DESC; +/** + * Memory copy types + * + */ +#if !defined(__HIPCC_RTC__) +typedef enum hipMemcpyKind { + hipMemcpyHostToHost = 0, ///< Host-to-Host Copy + hipMemcpyHostToDevice = 1, ///< Host-to-Device Copy + hipMemcpyDeviceToHost = 2, ///< Device-to-Host Copy + hipMemcpyDeviceToDevice = 3, ///< Device-to-Device Copy + hipMemcpyDefault = 4, ///< Runtime will automatically determine + ///< copy-kind based on virtual addresses. + hipMemcpyDeviceToDeviceNoCU = + 1024 ///< Device-to-Device Copy without using compute units +} hipMemcpyKind; +typedef struct hipPitchedPtr { + void *ptr; + size_t pitch; + size_t xsize; + size_t ysize; +} hipPitchedPtr; +typedef struct hipExtent { + size_t width; // Width in elements when referring to array memory, in bytes + // when referring to linear memory + size_t height; + size_t depth; +} hipExtent; +typedef struct hipPos { + size_t x; + size_t y; + size_t z; +} hipPos; +typedef struct hipMemcpy3DParms { + hipArray_t srcArray; + struct hipPos srcPos; + struct hipPitchedPtr srcPtr; + hipArray_t dstArray; + struct hipPos dstPos; + struct hipPitchedPtr dstPtr; + struct hipExtent extent; + enum hipMemcpyKind kind; +} hipMemcpy3DParms; +typedef struct HIP_MEMCPY3D { + size_t srcXInBytes; + size_t srcY; + size_t srcZ; + size_t srcLOD; + hipMemoryType srcMemoryType; + const void *srcHost; + hipDeviceptr_t srcDevice; + hipArray_t srcArray; + size_t srcPitch; + size_t srcHeight; + size_t dstXInBytes; + size_t dstY; + size_t dstZ; + size_t dstLOD; + hipMemoryType dstMemoryType; + void *dstHost; + hipDeviceptr_t dstDevice; + hipArray_t dstArray; + size_t dstPitch; + size_t dstHeight; + size_t WidthInBytes; + size_t Height; + size_t Depth; +} HIP_MEMCPY3D; +static inline struct hipPitchedPtr make_hipPitchedPtr(void *d, size_t p, + size_t xsz, size_t ysz) { + struct hipPitchedPtr s; + s.ptr = d; + s.pitch = p; + s.xsize = xsz; + s.ysize = ysz; + return s; +} +static inline struct hipPos make_hipPos(size_t x, size_t y, size_t z) { + struct hipPos p; + p.x = x; + p.y = y; + p.z = z; + return p; +} +static inline struct hipExtent make_hipExtent(size_t w, size_t h, size_t d) { + struct hipExtent e; + e.width = w; + e.height = h; + e.depth = d; + return e; +} +typedef enum hipFunction_attribute { + HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, + HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, + HIP_FUNC_ATTRIBUTE_CONST_SIZE_BYTES, + HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, + HIP_FUNC_ATTRIBUTE_NUM_REGS, + HIP_FUNC_ATTRIBUTE_PTX_VERSION, + HIP_FUNC_ATTRIBUTE_BINARY_VERSION, + HIP_FUNC_ATTRIBUTE_CACHE_MODE_CA, + HIP_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + HIP_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT, + HIP_FUNC_ATTRIBUTE_MAX +} hipFunction_attribute; + +typedef enum hipPointer_attribute { + HIP_POINTER_ATTRIBUTE_CONTEXT = + 1, ///< The context on which a pointer was allocated + ///< @warning - not supported in HIP + HIP_POINTER_ATTRIBUTE_MEMORY_TYPE, ///< memory type describing location of a + ///< pointer + HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ///< address at which the pointer is + ///< allocated on device + HIP_POINTER_ATTRIBUTE_HOST_POINTER, ///< address at which the pointer is + ///< allocated on host + HIP_POINTER_ATTRIBUTE_P2P_TOKENS, ///< A pair of tokens for use with linux + ///< kernel interface + ///< @warning - not supported in HIP + HIP_POINTER_ATTRIBUTE_SYNC_MEMOPS, ///< Synchronize every synchronous memory + ///< operation initiated on this region + HIP_POINTER_ATTRIBUTE_BUFFER_ID, ///< Unique ID for an allocated memory region + HIP_POINTER_ATTRIBUTE_IS_MANAGED, ///< Indicates if the pointer points to + ///< managed memory + HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, ///< device ordinal of a device on which + ///< a pointer was allocated or + ///< registered + HIP_POINTER_ATTRIBUTE_IS_LEGACY_HIP_IPC_CAPABLE, ///< if this pointer maps to + ///< an allocation that is + ///< suitable for + ///< hipIpcGetMemHandle + ///< @warning - not supported + ///< in HIP + HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR, ///< Starting address for this + ///< requested pointer + HIP_POINTER_ATTRIBUTE_RANGE_SIZE, ///< Size of the address range for this + ///< requested pointer + HIP_POINTER_ATTRIBUTE_MAPPED, ///< tells if this pointer is in a valid address + ///< range that is mapped to a backing + ///< allocation + HIP_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES, ///< Bitmask of allowed + ///< hipmemAllocationHandleType + ///< for this allocation @warning + ///< - not supported in HIP + HIP_POINTER_ATTRIBUTE_IS_GPU_DIRECT_RDMA_CAPABLE, ///< returns if the memory + ///< referenced by this + ///< pointer can be used + ///< with the GPUDirect RDMA + ///< API + ///< @warning - not + ///< supported in HIP + HIP_POINTER_ATTRIBUTE_ACCESS_FLAGS, ///< Returns the access flags the device + ///< associated with for the corresponding + ///< memory referenced by the ptr + HIP_POINTER_ATTRIBUTE_MEMPOOL_HANDLE ///< Returns the mempool handle for the + ///< allocation if it was allocated from + ///< a mempool + ///< @warning - not supported in HIP +} hipPointer_attribute; + +#endif // !defined(__HIPCC_RTC__) +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_bf16.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_bf16.h new file mode 100644 index 000000000..12351710b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_bf16.h @@ -0,0 +1,36 @@ +/* +Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_HIP_BF16_H +#define HIP_INCLUDE_HIP_HIP_BF16_H + +#include + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#include +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#endif // HIP_INCLUDE_HIP_HIP_BF16_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_bfloat16.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_bfloat16.h new file mode 100644 index 000000000..53aed89aa --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_bfloat16.h @@ -0,0 +1,44 @@ +/** + * MIT License + * + * Copyright (c) 2019 - 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/*!\file + * \brief hip_bfloat16.h provides struct for hip_bfloat16 typedef + */ + +#ifndef _HIP_BFLOAT16_H_ +#define _HIP_BFLOAT16_H_ + +#if !defined(__HIPCC_RTC__) +#include +#endif + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#include +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#warning "hip_bfloat16.h is not supported on nvidia platform" +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#endif // _HIP_BFLOAT16_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_common.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_common.h new file mode 100644 index 000000000..ae9d9d7fc --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_common.h @@ -0,0 +1,101 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_HIP_COMMON_H +#define HIP_INCLUDE_HIP_HIP_COMMON_H + +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wreserved-macro-identifier" +#endif +// Common code included at start of every hip file. +// Auto enable __HIP_PLATFORM_AMD__ if compiling on AMD platform +// Other compiler (GCC,ICC,etc) need to set one of these macros explicitly +#if defined(__clang__) && defined(__HIP__) +#ifndef __HIP_PLATFORM_AMD__ +#define __HIP_PLATFORM_AMD__ +#endif +#endif // defined(__clang__) && defined(__HIP__) + +// Auto enable __HIP_PLATFORM_NVIDIA__ if compiling with NVIDIA platform +#if defined(__NVCC__) || \ + (defined(__clang__) && defined(__CUDA__) && !defined(__HIP__)) +#ifndef __HIP_PLATFORM_NVIDIA__ +#define __HIP_PLATFORM_NVIDIA__ +#endif + +#ifdef __CUDACC__ +#define __HIPCC__ +#endif + +#endif //__NVCC__ + +// Auto enable __HIP_DEVICE_COMPILE__ if compiled in HCC or NVCC device path +#if (defined(__HCC_ACCELERATOR__) && __HCC_ACCELERATOR__ != 0) || \ + (defined(__CUDA_ARCH__) && __CUDA_ARCH__ != 0) +#define __HIP_DEVICE_COMPILE__ 1 +#endif + +#ifdef __GNUC__ +#define HIP_PUBLIC_API __attribute__((visibility("default"))) +#define HIP_INTERNAL_EXPORTED_API __attribute__((visibility("default"))) +#else +#define HIP_PUBLIC_API +#define HIP_INTERNAL_EXPORTED_API +#endif + +#if __HIP_DEVICE_COMPILE__ == 0 +// 32-bit Atomics +#define __HIP_ARCH_HAS_GLOBAL_INT32_ATOMICS__ (0) +#define __HIP_ARCH_HAS_GLOBAL_FLOAT_ATOMIC_EXCH__ (0) +#define __HIP_ARCH_HAS_SHARED_INT32_ATOMICS__ (0) +#define __HIP_ARCH_HAS_SHARED_FLOAT_ATOMIC_EXCH__ (0) +#define __HIP_ARCH_HAS_FLOAT_ATOMIC_ADD__ (0) + +// 64-bit Atomics +#define __HIP_ARCH_HAS_GLOBAL_INT64_ATOMICS__ (0) +#define __HIP_ARCH_HAS_SHARED_INT64_ATOMICS__ (0) + +// Doubles +#define __HIP_ARCH_HAS_DOUBLES__ (0) + +// Warp cross-lane operations +#define __HIP_ARCH_HAS_WARP_VOTE__ (0) +#define __HIP_ARCH_HAS_WARP_BALLOT__ (0) +#define __HIP_ARCH_HAS_WARP_SHUFFLE__ (0) +#define __HIP_ARCH_HAS_WARP_FUNNEL_SHIFT__ (0) + +// Sync +#define __HIP_ARCH_HAS_THREAD_FENCE_SYSTEM__ (0) +#define __HIP_ARCH_HAS_SYNC_THREAD_EXT__ (0) + +// Misc +#define __HIP_ARCH_HAS_SURFACE_FUNCS__ (0) +#define __HIP_ARCH_HAS_3DGRID__ (0) +#define __HIP_ARCH_HAS_DYNAMIC_PARALLEL__ (0) +#endif + +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_complex.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_complex.h new file mode 100644 index 000000000..a8d1b6ee6 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_complex.h @@ -0,0 +1,38 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_HIP_COMPLEX_H +#define HIP_INCLUDE_HIP_HIP_COMPLEX_H + +#if !defined(__HIPCC_RTC__) +#include +#endif + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#include +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_cooperative_groups.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_cooperative_groups.h new file mode 100644 index 000000000..ca6621a2c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_cooperative_groups.h @@ -0,0 +1,46 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +/** + * @file hip_cooperative_groups.h + * + * @brief Defines new types and device API wrappers for `Cooperative Group` + * feature. + */ + +#ifndef HIP_INCLUDE_HIP_HIP_COOPERATIVE_GROUP_H +#define HIP_INCLUDE_HIP_HIP_COOPERATIVE_GROUP_H + +#include +#include + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#if __cplusplus && defined(__clang__) && defined(__HIP__) +#include +#endif +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#endif // HIP_INCLUDE_HIP_HIP_COOPERATIVE_GROUP_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_deprecated.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_deprecated.h new file mode 100644 index 000000000..eafa668d2 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_deprecated.h @@ -0,0 +1,114 @@ +#pragma once + +// This file will add older hip functions used in the versioning system +// Find the deprecated functions and structs in hip_device.cpp + +// This struct is also kept in hip_device.cpp +typedef struct hipDeviceProp_tR0000 { + char name[256]; ///< Device name. + size_t totalGlobalMem; ///< Size of global memory region (in bytes). + size_t sharedMemPerBlock; ///< Size of shared memory region (in bytes). + int regsPerBlock; ///< Registers per block. + int warpSize; ///< Warp size. + int maxThreadsPerBlock; ///< Max work items per work group or workgroup max + ///< size. + int maxThreadsDim[3]; ///< Max number of threads in each dimension (XYZ) of a + ///< block. + int maxGridSize[3]; ///< Max grid dimensions (XYZ). + int clockRate; ///< Max clock frequency of the multiProcessors in khz. + int memoryClockRate; ///< Max global memory clock frequency in khz. + int memoryBusWidth; ///< Global memory bus width in bits. + size_t totalConstMem; ///< Size of shared memory region (in bytes). + int major; ///< Major compute capability. On HCC, this is an approximation + ///< and features may differ from CUDA CC. See the arch feature + ///< flags for portable ways to query feature caps. + int minor; ///< Minor compute capability. On HCC, this is an approximation + ///< and features may differ from CUDA CC. See the arch feature + ///< flags for portable ways to query feature caps. + int multiProcessorCount; ///< Number of multi-processors (compute units). + int l2CacheSize; ///< L2 cache size. + int maxThreadsPerMultiProcessor; ///< Maximum resident threads per + ///< multi-processor. + int computeMode; ///< Compute mode. + int clockInstructionRate; ///< Frequency in khz of the timer used by the + ///< device-side "clock*" instructions. New for + ///< HIP. + hipDeviceArch_t arch; ///< Architectural feature flags. New for HIP. + int concurrentKernels; ///< Device can possibly execute multiple kernels + ///< concurrently. + int pciDomainID; ///< PCI Domain ID + int pciBusID; ///< PCI Bus ID. + int pciDeviceID; ///< PCI Device ID. + size_t maxSharedMemoryPerMultiProcessor; ///< Maximum Shared Memory Per + ///< Multiprocessor. + int isMultiGpuBoard; ///< 1 if device is on a multi-GPU board, 0 if not. + int canMapHostMemory; ///< Check whether HIP can map host memory + int gcnArch; ///< DEPRECATED: use gcnArchName instead + char gcnArchName[256]; ///< AMD GCN Arch Name. + int integrated; ///< APU vs dGPU + int cooperativeLaunch; ///< HIP device supports cooperative launch + int cooperativeMultiDeviceLaunch; ///< HIP device supports cooperative launch + ///< on multiple devices + int maxTexture1DLinear; ///< Maximum size for 1D textures bound to linear + ///< memory + int maxTexture1D; ///< Maximum number of elements in 1D images + int maxTexture2D[2]; ///< Maximum dimensions (width, height) of 2D images, in + ///< image elements + int maxTexture3D[3]; ///< Maximum dimensions (width, height, depth) of 3D + ///< images, in image elements + unsigned int + *hdpMemFlushCntl; ///< Addres of HDP_MEM_COHERENCY_FLUSH_CNTL register + unsigned int + *hdpRegFlushCntl; ///< Addres of HDP_REG_COHERENCY_FLUSH_CNTL register + size_t memPitch; ///< Maximum pitch in bytes allowed by memory copies + size_t textureAlignment; ///< Alignment requirement for textures + size_t texturePitchAlignment; ///< Pitch alignment requirement for texture + ///< references bound to pitched memory + int kernelExecTimeoutEnabled; ///< Run time limit for kernels executed on the + ///< device + int ECCEnabled; ///< Device has ECC support enabled + int tccDriver; ///< 1:If device is Tesla device using TCC driver, else 0 + int cooperativeMultiDeviceUnmatchedFunc; ///< HIP device supports cooperative + ///< launch on multiple + /// devices with unmatched functions + int cooperativeMultiDeviceUnmatchedGridDim; ///< HIP device supports + ///< cooperative launch on + ///< multiple + /// devices with unmatched grid + /// dimensions + int cooperativeMultiDeviceUnmatchedBlockDim; ///< HIP device supports + ///< cooperative launch on + ///< multiple + /// devices with unmatched block + /// dimensions + int cooperativeMultiDeviceUnmatchedSharedMem; ///< HIP device supports + ///< cooperative launch on + ///< multiple + /// devices with unmatched + /// shared memories + int isLargeBar; ///< 1: if it is a large PCI bar device, else 0 + int asicRevision; ///< Revision of the GPU in this device + int managedMemory; ///< Device supports allocating managed memory on this + ///< system + int directManagedMemAccessFromHost; ///< Host can directly access managed + ///< memory on the device without + ///< migration + int concurrentManagedAccess; ///< Device can coherently access managed memory + ///< concurrently with the CPU + int pageableMemoryAccess; ///< Device supports coherently accessing pageable + ///< memory without calling hipHostRegister on it + int pageableMemoryAccessUsesHostPageTables; ///< Device accesses pageable + ///< memory via the host's page + ///< tables +} hipDeviceProp_tR0000; + +#ifdef __cplusplus +extern "C" { +#endif + +hipError_t hipGetDevicePropertiesR0000(hipDeviceProp_tR0000 *prop, int device); +hipError_t hipChooseDeviceR0000(int *device, const hipDeviceProp_tR0000 *prop); + +#ifdef __cplusplus +} +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_ext.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_ext.h new file mode 100644 index 000000000..045745011 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_ext.h @@ -0,0 +1,178 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_HIP_EXT_H +#define HIP_INCLUDE_HIP_HIP_EXT_H +#include "hip/hip_runtime.h" +#if defined(__cplusplus) +#include +#include +#endif +/** @addtogroup Module Module Management + * @{ + */ + +/** + * @brief Launches kernel with parameters and shared memory on stream with + * arguments passed to kernel params or extra arguments. + * + * @param [in] f Kernel to launch. + * @param [in] globalWorkSizeX X grid dimension specified in work-items. + * @param [in] globalWorkSizeY Y grid dimension specified in work-items. + * @param [in] globalWorkSizeZ Z grid dimension specified in work-items. + * @param [in] localWorkSizeX X block dimension specified in work-items. + * @param [in] localWorkSizeY Y block dimension specified in work-items. + * @param [in] localWorkSizeZ Z block dimension specified in work-items. + * @param [in] sharedMemBytes Amount of dynamic shared memory to allocate for + * this kernel. HIP-Clang compiler provides support for extern shared + * declarations. + * @param [in] hStream Stream where the kernel should be dispatched. + * May be 0, in which case the default stream is used with associated + * synchronization rules. + * @param [in] kernelParams pointer to kernel parameters. + * @param [in] extra Pointer to kernel arguments. These are passed directly to + * the kernel and must be in the memory layout and alignment expected by the + * kernel. All passed arguments must be naturally aligned according to their + * type. The memory address of each argument should be a multiple of its size in + * bytes. Please refer to hip_porting_driver_api.md for sample usage. + * @param [in] startEvent If non-null, specified event will be updated to track + * the start time of the kernel launch. The event must be created before calling + * this API. + * @param [in] stopEvent If non-null, specified event will be updated to track + * the stop time of the kernel launch. The event must be created before calling + * this API. + * @param [in] flags The value of hipExtAnyOrderLaunch, signifies if kernel can + * be launched in any order. + * @returns #hipSuccess, #hipInvalidDeviceId, #hipErrorNotInitialized, + * #hipErrorInvalidValue. + * + * HIP/ROCm actually updates the start event when the associated kernel + * completes. Currently, timing between startEvent and stopEvent does not + * include the time it takes to perform a system scope release/cache flush - + * only the time it takes to issues writes to cache. + * + * @note For this HIP API, the flag 'hipExtAnyOrderLaunch' is not supported on + * AMD GFX9xx boards. + * + */ +HIP_PUBLIC_API +extern "C" hipError_t hipExtModuleLaunchKernel( + hipFunction_t f, uint32_t globalWorkSizeX, uint32_t globalWorkSizeY, + uint32_t globalWorkSizeZ, uint32_t localWorkSizeX, uint32_t localWorkSizeY, + uint32_t localWorkSizeZ, size_t sharedMemBytes, hipStream_t hStream, + void **kernelParams, void **extra, hipEvent_t startEvent __dparm(NULL), + hipEvent_t stopEvent __dparm(NULL), uint32_t flags __dparm(0)); +/** + * @brief This HIP API is deprecated, please use hipExtModuleLaunchKernel() + * instead. + * + */ +DEPRECATED("use hipExtModuleLaunchKernel instead") +HIP_PUBLIC_API +extern "C" hipError_t hipHccModuleLaunchKernel( + hipFunction_t f, uint32_t globalWorkSizeX, uint32_t globalWorkSizeY, + uint32_t globalWorkSizeZ, uint32_t localWorkSizeX, uint32_t localWorkSizeY, + uint32_t localWorkSizeZ, size_t sharedMemBytes, hipStream_t hStream, + void **kernelParams, void **extra, hipEvent_t startEvent __dparm(NULL), + hipEvent_t stopEvent __dparm(NULL)); + +#if defined(__cplusplus) + +/** + * @brief Launches kernel from the pointer address, with arguments and shared + * memory on stream. + * + * @param [in] function_address pointer to the Kernel to launch. + * @param [in] numBlocks number of blocks. + * @param [in] dimBlocks dimension of a block. + * @param [in] args pointer to kernel arguments. + * @param [in] sharedMemBytes Amount of dynamic shared memory to allocate for + * this kernel. HIP-Clang compiler provides support for extern shared + * declarations. + * @param [in] stream Stream where the kernel should be dispatched. + * May be 0, in which case the default stream is used with associated + * synchronization rules. + * @param [in] startEvent If non-null, specified event will be updated to track + * the start time of the kernel launch. The event must be created before calling + * this API. + * @param [in] stopEvent If non-null, specified event will be updated to track + * the stop time of the kernel launch. The event must be created before calling + * this API. + * @param [in] flags The value of hipExtAnyOrderLaunch, signifies if kernel can + * be launched in any order. + * @returns #hipSuccess, #hipInvalidDeviceId, #hipErrorNotInitialized, + * #hipErrorInvalidValue. + * + */ +extern "C" hipError_t +hipExtLaunchKernel(const void *function_address, dim3 numBlocks, dim3 dimBlocks, + void **args, size_t sharedMemBytes, hipStream_t stream, + hipEvent_t startEvent, hipEvent_t stopEvent, int flags); + +/** + * @brief Launches kernel with dimention parameters and shared memory on stream + * with templated kernel and arguments. + * + * @param [in] kernel Kernel to launch. + * @param [in] numBlocks const number of blocks. + * @param [in] dimBlocks const dimension of a block. + * @param [in] sharedMemBytes Amount of dynamic shared memory to allocate for + * this kernel. HIP-Clang compiler provides support for extern shared + * declarations. + * @param [in] stream Stream where the kernel should be dispatched. + * May be 0, in which case the default stream is used with associated + * synchronization rules. + * @param [in] startEvent If non-null, specified event will be updated to track + * the start time of the kernel launch. The event must be created before calling + * this API. + * @param [in] stopEvent If non-null, specified event will be updated to track + * the stop time of the kernel launch. The event must be created before calling + * this API. + * @param [in] flags The value of hipExtAnyOrderLaunch, signifies if kernel can + * be launched in any order. + * @param [in] args templated kernel arguments. + * + */ +template +inline void +hipExtLaunchKernelGGL(F kernel, const dim3 &numBlocks, const dim3 &dimBlocks, + std::uint32_t sharedMemBytes, hipStream_t stream, + hipEvent_t startEvent, hipEvent_t stopEvent, + std::uint32_t flags, Args... args) { + constexpr size_t count = sizeof...(Args); + auto tup_ = std::tuple{args...}; + auto tup = validateArgsCountType(kernel, tup_); + void *_Args[count]; + pArgs<0>(tup, _Args); + + auto k = reinterpret_cast(kernel); + hipExtLaunchKernel(k, numBlocks, dimBlocks, _Args, sharedMemBytes, stream, + startEvent, stopEvent, (int)flags); +} + +#endif // defined(__cplusplus) + +// doxygen end AMD-specific features +/** + * @} + */ +#endif // #iidef HIP_INCLUDE_HIP_HIP_EXT_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_fp16.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_fp16.h new file mode 100644 index 000000000..67def4ee1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_fp16.h @@ -0,0 +1,36 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_HIP_FP16_H +#define HIP_INCLUDE_HIP_HIP_FP16_H + +#include + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#include +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include "cuda_fp16.h" +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_fp8.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_fp8.h new file mode 100644 index 000000000..c300d3326 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_fp8.h @@ -0,0 +1,33 @@ +/* +Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_HIP_FP8_H +#define HIP_INCLUDE_HIP_HIP_FP8_H + +#include + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +// We only have fnuz defs for now, which are not supported by other platforms +#include +#endif + +#endif // HIP_INCLUDE_HIP_HIP_FP8_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_gl_interop.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_gl_interop.h new file mode 100644 index 000000000..8af6ec32d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_gl_interop.h @@ -0,0 +1,32 @@ +/* +Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#ifndef HIP_GL_INTEROP_H +#define HIP_GL_INTEROP_H + +#include + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#include "hip/amd_detail/amd_hip_gl_interop.h" +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include "hip/nvidia_detail/nvidia_hip_gl_interop.h" +#endif +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_hcc.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_hcc.h new file mode 100644 index 000000000..97b76fc30 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_hcc.h @@ -0,0 +1,24 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_HIP_HCC_H +#define HIP_INCLUDE_HIP_HIP_HCC_H +#warning "hip/hip_hcc.h is deprecated, please use hip/hip_ext.h" +#include "hip/hip_ext.h" +#endif // #ifdef HIP_INCLUDE_HIP_HIP_HCC_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_math_constants.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_math_constants.h new file mode 100644 index 000000000..4cce93c0f --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_math_constants.h @@ -0,0 +1,36 @@ +/* +Copyright (c) 2015 - 2022 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#ifndef HIP_MATH_CONSTANTS_H +#define HIP_MATH_CONSTANTS_H + +#if !defined(__HIPCC_RTC__) +#include +#endif + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#include "hip/amd_detail/amd_hip_math_constants.h" +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include "hip/nvidia_detail/nvidia_hip_math_constants.h" +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_profile.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_profile.h new file mode 100644 index 000000000..4fef521d1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_profile.h @@ -0,0 +1,27 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_HIP_PROFILE_H +#define HIP_INCLUDE_HIP_HIP_PROFILE_H + +#define HIP_SCOPED_MARKER(markerName, group) +#define HIP_BEGIN_MARKER(markerName, group) +#define HIP_END_MARKER() + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_runtime.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_runtime.h new file mode 100644 index 000000000..1fa994a89 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_runtime.h @@ -0,0 +1,77 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +//! HIP = Heterogeneous-compute Interface for Portability +//! +//! Define a extremely thin runtime layer that allows source code to be compiled +//! unmodified through either AMD CLANG or NVCC. Key features tend to be in +//! the spirit and terminology of CUDA, but with a portable path to other +//! accelerators as well: +// +//! Both paths support rich C++ features including classes, templates, lambdas, +//! etc. Runtime API is C Memory management is based on pure pointers and +//! resembles malloc/free/copy. +// +//! hip_runtime.h : includes everything in hip_api.h, plus math builtins and +//! kernel launch macros. hip_runtime_api.h : Defines HIP API. This is a C +//! header file and does not use any C++ features. + +#ifndef HIP_INCLUDE_HIP_HIP_RUNTIME_H +#define HIP_INCLUDE_HIP_HIP_RUNTIME_H + +#if __HIP_DEVICE_COMPILE__ && !__GFX7__ && !__GFX8__ && !__GFX9__ && \ + __AMDGCN_WAVEFRONT_SIZE == 64 +#error HIP is not supported on the specified GPU ARCH with wavefront size 64 +#endif + +#if !defined(__HIPCC_RTC__) +// Some standard header files, these are included by hc.hpp and so want to make +// them avail on both paths to provide a consistent include env and avoid +// "missing symbol" errors that only appears on NVCC path: +#include +#include +#include +#include + +#if __cplusplus > 199711L +#include +#endif +#endif // !defined(__HIPCC_RTC__) + +#include +#include + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#include +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#if !defined(__HIPCC_RTC__) +#include +#include +#endif // !defined(__HIPCC_RTC__) +#include + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_runtime_api.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_runtime_api.h new file mode 100644 index 000000000..b0d242178 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_runtime_api.h @@ -0,0 +1,10099 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +/** + * @file hip_runtime_api.h + * + * @brief Defines the API signatures for HIP runtime. + * This file can be compiled with a standard compiler. + */ + +#ifndef HIP_INCLUDE_HIP_HIP_RUNTIME_API_H +#define HIP_INCLUDE_HIP_HIP_RUNTIME_API_H + +#include +#include +#include // for getDeviceProp + +enum { + HIP_SUCCESS = 0, + HIP_ERROR_INVALID_VALUE, + HIP_ERROR_NOT_INITIALIZED, + HIP_ERROR_LAUNCH_OUT_OF_RESOURCES +}; +// hack to get these to show up in Doxygen: +/** + * @defgroup GlobalDefs Global enum and defines + * @{ + * + */ +/** + * hipDeviceArch_t + * + */ +typedef struct { + // 32-bit Atomics + unsigned + hasGlobalInt32Atomics : 1; ///< 32-bit integer atomics for global memory. + unsigned hasGlobalFloatAtomicExch : 1; ///< 32-bit float atomic exch for + ///< global memory. + unsigned + hasSharedInt32Atomics : 1; ///< 32-bit integer atomics for shared memory. + unsigned hasSharedFloatAtomicExch : 1; ///< 32-bit float atomic exch for + ///< shared memory. + unsigned hasFloatAtomicAdd : 1; ///< 32-bit float atomic add in global and + ///< shared memory. + + // 64-bit Atomics + unsigned + hasGlobalInt64Atomics : 1; ///< 64-bit integer atomics for global memory. + unsigned + hasSharedInt64Atomics : 1; ///< 64-bit integer atomics for shared memory. + + // Doubles + unsigned hasDoubles : 1; ///< Double-precision floating point. + + // Warp cross-lane operations + unsigned hasWarpVote : 1; ///< Warp vote instructions (__any, __all). + unsigned hasWarpBallot : 1; ///< Warp ballot instructions (__ballot). + unsigned hasWarpShuffle : 1; ///< Warp shuffle operations. (__shfl_*). + unsigned + hasFunnelShift : 1; ///< Funnel two words into one with shift&mask caps. + + // Sync + unsigned hasThreadFenceSystem : 1; ///< __threadfence_system. + unsigned hasSyncThreadsExt : 1; ///< __syncthreads_count, syncthreads_and, + ///< syncthreads_or. + + // Misc + unsigned hasSurfaceFuncs : 1; ///< Surface functions. + unsigned has3dGrid : 1; ///< Grid and group dims are 3D (rather than 2D). + unsigned hasDynamicParallelism : 1; ///< Dynamic parallelism. +} hipDeviceArch_t; + +typedef struct hipUUID_t { + char bytes[16]; +} hipUUID; + +//--- +// Common headers for both NVCC and HCC paths: + +#define hipGetDeviceProperties hipGetDevicePropertiesR0600 +#define hipDeviceProp_t hipDeviceProp_tR0600 +#define hipChooseDevice hipChooseDeviceR0600 + +/** + * hipDeviceProp + * + */ +typedef struct hipDeviceProp_t { + char name[256]; ///< Device name. + hipUUID uuid; ///< UUID of a device + char luid[8]; ///< 8-byte unique identifier. Only valid on windows + unsigned int luidDeviceNodeMask; ///< LUID node mask + size_t totalGlobalMem; ///< Size of global memory region (in bytes). + size_t sharedMemPerBlock; ///< Size of shared memory per block (in bytes). + int regsPerBlock; ///< Registers per block. + int warpSize; ///< Warp size. + size_t memPitch; ///< Maximum pitch in bytes allowed by memory copies + ///< pitched memory + int maxThreadsPerBlock; ///< Max work items per work group or workgroup max + ///< size. + int maxThreadsDim[3]; ///< Max number of threads in each dimension (XYZ) of a + ///< block. + int maxGridSize[3]; ///< Max grid dimensions (XYZ). + int clockRate; ///< Max clock frequency of the multiProcessors in khz. + size_t totalConstMem; ///< Size of shared constant memory region on the device + ///< (in bytes). + int major; ///< Major compute capability. On HCC, this is an approximation + ///< and features may differ from CUDA CC. See the arch feature + ///< flags for portable ways to query feature caps. + int minor; ///< Minor compute capability. On HCC, this is an approximation + ///< and features may differ from CUDA CC. See the arch feature + ///< flags for portable ways to query feature caps. + size_t textureAlignment; ///< Alignment requirement for textures + size_t texturePitchAlignment; ///< Pitch alignment requirement for texture + ///< references bound to + int deviceOverlap; ///< Deprecated. Use asyncEngineCount instead + int multiProcessorCount; ///< Number of multi-processors (compute units). + int kernelExecTimeoutEnabled; ///< Run time limit for kernels executed on the + ///< device + int integrated; ///< APU vs dGPU + int canMapHostMemory; ///< Check whether HIP can map host memory + int computeMode; ///< Compute mode. + int maxTexture1D; ///< Maximum number of elements in 1D images + int maxTexture1DMipmap; ///< Maximum 1D mipmap texture size + int maxTexture1DLinear; ///< Maximum size for 1D textures bound to linear + ///< memory + int maxTexture2D[2]; ///< Maximum dimensions (width, height) of 2D images, in + ///< image elements + int maxTexture2DMipmap[2]; ///< Maximum number of elements in 2D array mipmap + ///< of images + int maxTexture2DLinear[3]; ///< Maximum 2D tex dimensions if tex are bound to + ///< pitched memory + int maxTexture2DGather[2]; ///< Maximum 2D tex dimensions if gather has to be + ///< performed + int maxTexture3D[3]; ///< Maximum dimensions (width, height, depth) of 3D + ///< images, in image elements + int maxTexture3DAlt[3]; ///< Maximum alternate 3D texture dims + int maxTextureCubemap; ///< Maximum cubemap texture dims + int maxTexture1DLayered[2]; ///< Maximum number of elements in 1D array images + int maxTexture2DLayered[3]; ///< Maximum number of elements in 2D array images + int maxTextureCubemapLayered[2]; ///< Maximum cubemaps layered texture dims + int maxSurface1D; ///< Maximum 1D surface size + int maxSurface2D[2]; ///< Maximum 2D surface size + int maxSurface3D[3]; ///< Maximum 3D surface size + int maxSurface1DLayered[2]; ///< Maximum 1D layered surface size + int maxSurface2DLayered[3]; ///< Maximum 2D layared surface size + int maxSurfaceCubemap; ///< Maximum cubemap surface size + int maxSurfaceCubemapLayered[2]; ///< Maximum cubemap layered surface size + size_t surfaceAlignment; ///< Alignment requirement for surface + int concurrentKernels; ///< Device can possibly execute multiple kernels + ///< concurrently. + int ECCEnabled; ///< Device has ECC support enabled + int pciBusID; ///< PCI Bus ID. + int pciDeviceID; ///< PCI Device ID. + int pciDomainID; ///< PCI Domain ID + int tccDriver; ///< 1:If device is Tesla device using TCC driver, else 0 + int asyncEngineCount; ///< Number of async engines + int unifiedAddressing; ///< Does device and host share unified address space + int memoryClockRate; ///< Max global memory clock frequency in khz. + int memoryBusWidth; ///< Global memory bus width in bits. + int l2CacheSize; ///< L2 cache size. + int persistingL2CacheMaxSize; ///< Device's max L2 persisting lines in bytes + int maxThreadsPerMultiProcessor; ///< Maximum resident threads per + ///< multi-processor. + int streamPrioritiesSupported; ///< Device supports stream priority + int globalL1CacheSupported; ///< Indicates globals are cached in L1 + int localL1CacheSupported; ///< Locals are cahced in L1 + size_t sharedMemPerMultiprocessor; ///< Amount of shared memory available per + ///< multiprocessor. + int regsPerMultiprocessor; ///< registers available per multiprocessor + int managedMemory; ///< Device supports allocating managed memory on this + ///< system + int isMultiGpuBoard; ///< 1 if device is on a multi-GPU board, 0 if not. + int multiGpuBoardGroupID; ///< Unique identifier for a group of devices on + ///< same multiboard GPU + int hostNativeAtomicSupported; ///< Link between host and device supports + ///< native atomics + int singleToDoublePrecisionPerfRatio; ///< Deprecated. CUDA only. + int pageableMemoryAccess; ///< Device supports coherently accessing pageable + ///< memory without calling hipHostRegister on it + int concurrentManagedAccess; ///< Device can coherently access managed memory + ///< concurrently with the CPU + int computePreemptionSupported; ///< Is compute preemption supported on the + ///< device + int canUseHostPointerForRegisteredMem; ///< Device can access host registered + ///< memory with same address as the + ///< host + int cooperativeLaunch; ///< HIP device supports cooperative launch + int cooperativeMultiDeviceLaunch; ///< HIP device supports cooperative launch + ///< on multiple devices + size_t sharedMemPerBlockOptin; ///< Per device m ax shared mem per block + ///< usable by special opt in + int pageableMemoryAccessUsesHostPageTables; ///< Device accesses pageable + ///< memory via the host's page + ///< tables + int directManagedMemAccessFromHost; ///< Host can directly access managed + ///< memory on the device without + ///< migration + int maxBlocksPerMultiProcessor; ///< Max number of blocks on CU + int accessPolicyMaxWindowSize; ///< Max value of access policy window + size_t + reservedSharedMemPerBlock; ///< Shared memory reserved by driver per block + int hostRegisterSupported; ///< Device supports hipHostRegister + int sparseHipArraySupported; ///< Indicates if device supports sparse hip + ///< arrays + int hostRegisterReadOnlySupported; ///< Device supports using the + ///< hipHostRegisterReadOnly flag with + ///< hipHostRegistger + int timelineSemaphoreInteropSupported; ///< Indicates external timeline + ///< semaphore support + int memoryPoolsSupported; ///< Indicates if device supports hipMallocAsync and + ///< hipMemPool APIs + int gpuDirectRDMASupported; ///< Indicates device support of RDMA APIs + unsigned int + gpuDirectRDMAFlushWritesOptions; ///< Bitmask to be interpreted according + ///< to + ///< hipFlushGPUDirectRDMAWritesOptions + int gpuDirectRDMAWritesOrdering; ///< value of hipGPUDirectRDMAWritesOrdering + unsigned int + memoryPoolSupportedHandleTypes; ///< Bitmask of handle types support with + ///< mempool based IPC + int deferredMappingHipArraySupported; ///< Device supports deferred mapping + ///< HIP arrays and HIP mipmapped arrays + int ipcEventSupported; ///< Device supports IPC events + int clusterLaunch; ///< Device supports cluster launch + int unifiedFunctionPointers; ///< Indicates device supports unified function + ///< pointers + int reserved[63]; ///< CUDA Reserved. + + int hipReserved[32]; ///< Reserved for adding new entries for HIP/CUDA. + + /* HIP Only struct members */ + char gcnArchName[256]; ///< AMD GCN Arch Name. HIP Only. + size_t maxSharedMemoryPerMultiProcessor; ///< Maximum Shared Memory Per CU. + ///< HIP Only. + int clockInstructionRate; ///< Frequency in khz of the timer used by the + ///< device-side "clock*" instructions. New for + ///< HIP. + hipDeviceArch_t arch; ///< Architectural feature flags. New for HIP. + unsigned int + *hdpMemFlushCntl; ///< Addres of HDP_MEM_COHERENCY_FLUSH_CNTL register + unsigned int + *hdpRegFlushCntl; ///< Addres of HDP_REG_COHERENCY_FLUSH_CNTL register + int cooperativeMultiDeviceUnmatchedFunc; ///< HIP device supports cooperative + ///< launch on multiple + /// devices with unmatched functions + int cooperativeMultiDeviceUnmatchedGridDim; ///< HIP device supports + ///< cooperative launch on + ///< multiple + /// devices with unmatched grid + /// dimensions + int cooperativeMultiDeviceUnmatchedBlockDim; ///< HIP device supports + ///< cooperative launch on + ///< multiple + /// devices with unmatched block + /// dimensions + int cooperativeMultiDeviceUnmatchedSharedMem; ///< HIP device supports + ///< cooperative launch on + ///< multiple + /// devices with unmatched + /// shared memories + int isLargeBar; ///< 1: if it is a large PCI bar device, else 0 + int asicRevision; ///< Revision of the GPU in this device +} hipDeviceProp_t; + +/** + * hipMemoryType (for pointer attributes) + * + * @note hipMemoryType enum values are combination of cudaMemoryType and + * cuMemoryType and AMD specific enum values. + * + */ +typedef enum hipMemoryType { + hipMemoryTypeUnregistered = 0, ///< Unregistered memory + hipMemoryTypeHost = 1, ///< Memory is physically located on host + hipMemoryTypeDevice = 2, ///< Memory is physically located on device. (see + ///< deviceId for specific device) + hipMemoryTypeManaged = + 3, ///< Managed memory, automaticallly managed by the unified + ///< memory system + ///< place holder for new values. + hipMemoryTypeArray = 10, ///< Array memory, physically located on device. (see + ///< deviceId for specific device) + hipMemoryTypeUnified = 11 ///< unified address space + +} hipMemoryType; + +/** + * Pointer attributes + */ +typedef struct hipPointerAttribute_t { + enum hipMemoryType type; + int device; + void *devicePointer; + void *hostPointer; + int isManaged; + unsigned allocationFlags; /* flags specified when memory was allocated*/ + /* peers? */ +} hipPointerAttribute_t; + +// Ignoring error-code return values from hip APIs is discouraged. On C++17, +// we can make that yield a warning +#if __cplusplus >= 201703L +#define __HIP_NODISCARD [[nodiscard]] +#else +#define __HIP_NODISCARD +#endif + +/** + * HIP error type + * + */ +// Developer note - when updating these, update the hipErrorName and +// hipErrorString functions in NVCC and HCC paths Also update the +// hipCUDAErrorTohipError function in NVCC path. + +typedef enum __HIP_NODISCARD hipError_t { + hipSuccess = 0, ///< Successful completion. + hipErrorInvalidValue = 1, ///< One or more of the parameters passed to the API + ///< call is NULL or not in an acceptable range. + hipErrorOutOfMemory = 2, ///< out of memory range. + // Deprecated + hipErrorMemoryAllocation = 2, ///< Memory allocation error. + hipErrorNotInitialized = 3, ///< Invalid not initialized + // Deprecated + hipErrorInitializationError = 3, + hipErrorDeinitialized = 4, ///< Deinitialized + hipErrorProfilerDisabled = 5, + hipErrorProfilerNotInitialized = 6, + hipErrorProfilerAlreadyStarted = 7, + hipErrorProfilerAlreadyStopped = 8, + hipErrorInvalidConfiguration = 9, ///< Invalide configuration + hipErrorInvalidPitchValue = 12, ///< Invalid pitch value + hipErrorInvalidSymbol = 13, ///< Invalid symbol + hipErrorInvalidDevicePointer = 17, ///< Invalid Device Pointer + hipErrorInvalidMemcpyDirection = 21, ///< Invalid memory copy direction + hipErrorInsufficientDriver = 35, + hipErrorMissingConfiguration = 52, + hipErrorPriorLaunchFailure = 53, + hipErrorInvalidDeviceFunction = 98, ///< Invalid device function + hipErrorNoDevice = 100, ///< Call to hipGetDeviceCount returned 0 devices + hipErrorInvalidDevice = + 101, ///< DeviceID must be in range from 0 to compute-devices. + hipErrorInvalidImage = 200, ///< Invalid image + hipErrorInvalidContext = 201, ///< Produced when input context is invalid. + hipErrorContextAlreadyCurrent = 202, + hipErrorMapFailed = 205, + // Deprecated + hipErrorMapBufferObjectFailed = + 205, ///< Produced when the IPC memory attach failed from ROCr. + hipErrorUnmapFailed = 206, + hipErrorArrayIsMapped = 207, + hipErrorAlreadyMapped = 208, + hipErrorNoBinaryForGpu = 209, + hipErrorAlreadyAcquired = 210, + hipErrorNotMapped = 211, + hipErrorNotMappedAsArray = 212, + hipErrorNotMappedAsPointer = 213, + hipErrorECCNotCorrectable = 214, + hipErrorUnsupportedLimit = 215, ///< Unsupported limit + hipErrorContextAlreadyInUse = 216, ///< The context is already in use + hipErrorPeerAccessUnsupported = 217, + hipErrorInvalidKernelFile = + 218, ///< In CUDA DRV, it is CUDA_ERROR_INVALID_PTX + hipErrorInvalidGraphicsContext = 219, + hipErrorInvalidSource = 300, ///< Invalid source. + hipErrorFileNotFound = 301, ///< the file is not found. + hipErrorSharedObjectSymbolNotFound = 302, + hipErrorSharedObjectInitFailed = 303, ///< Failed to initialize shared object. + hipErrorOperatingSystem = 304, ///< Not the correct operating system + hipErrorInvalidHandle = 400, ///< Invalide handle + // Deprecated + hipErrorInvalidResourceHandle = + 400, ///< Resource handle (hipEvent_t or hipStream_t) invalid. + hipErrorIllegalState = + 401, ///< Resource required is not in a valid state to perform operation. + hipErrorNotFound = 500, ///< Not found + hipErrorNotReady = + 600, ///< Indicates that asynchronous operations enqueued earlier are not + ///< ready. This is not actually an error, but is used to + ///< distinguish from hipSuccess (which indicates completion). APIs + ///< that return this error include hipEventQuery and hipStreamQuery. + hipErrorIllegalAddress = 700, + hipErrorLaunchOutOfResources = 701, ///< Out of resources error. + hipErrorLaunchTimeOut = 702, ///< Timeout for the launch. + hipErrorPeerAccessAlreadyEnabled = 704, ///< Peer access was already enabled + ///< from the current device. + hipErrorPeerAccessNotEnabled = + 705, ///< Peer access was never enabled from the current device. + hipErrorSetOnActiveProcess = 708, ///< The process is active. + hipErrorContextIsDestroyed = 709, ///< The context is already destroyed + hipErrorAssert = 710, ///< Produced when the kernel calls assert. + hipErrorHostMemoryAlreadyRegistered = 712, ///< Produced when trying to lock a + ///< page-locked memory. + hipErrorHostMemoryNotRegistered = 713, ///< Produced when trying to unlock a + ///< non-page-locked memory. + hipErrorLaunchFailure = + 719, ///< An exception occurred on the device while executing a kernel. + hipErrorCooperativeLaunchTooLarge = + 720, ///< This error indicates that the number of blocks + ///< launched per grid for a kernel that was launched + ///< via cooperative launch APIs exceeds the maximum + ///< number of allowed blocks for the current device. + hipErrorNotSupported = + 801, ///< Produced when the hip API is not supported/implemented + hipErrorStreamCaptureUnsupported = 900, ///< The operation is not permitted + ///< when the stream is capturing. + hipErrorStreamCaptureInvalidated = + 901, ///< The current capture sequence on the stream + ///< has been invalidated due to a previous error. + hipErrorStreamCaptureMerge = + 902, ///< The operation would have resulted in a merge of + ///< two independent capture sequences. + hipErrorStreamCaptureUnmatched = + 903, ///< The capture was not initiated in this stream. + hipErrorStreamCaptureUnjoined = + 904, ///< The capture sequence contains a fork that was not + ///< joined to the primary stream. + hipErrorStreamCaptureIsolation = + 905, ///< A dependency would have been created which crosses + ///< the capture sequence boundary. Only implicit + ///< in-stream ordering dependencies are allowed + ///< to cross the boundary + hipErrorStreamCaptureImplicit = + 906, ///< The operation would have resulted in a disallowed + ///< implicit dependency on a current capture sequence + ///< from hipStreamLegacy. + hipErrorCapturedEvent = + 907, ///< The operation is not permitted on an event which was last + ///< recorded in a capturing stream. + hipErrorStreamCaptureWrongThread = + 908, ///< A stream capture sequence not initiated with + ///< the hipStreamCaptureModeRelaxed argument to + ///< hipStreamBeginCapture was passed to + ///< hipStreamEndCapture in a different thread. + hipErrorGraphExecUpdateFailure = + 910, ///< This error indicates that the graph update + ///< not performed because it included changes which + ///< violated constraintsspecific to instantiated graph + ///< update. + hipErrorUnknown = 999, ///< Unknown error. + // HSA Runtime Error Codes start here. + hipErrorRuntimeMemory = 1052, ///< HSA runtime memory call returned error. + ///< Typically not seen in production systems. + hipErrorRuntimeOther = + 1053, ///< HSA runtime call other than memory returned error. Typically + ///< not seen in production systems. + hipErrorTbd ///< Marker that more error codes are needed. +} hipError_t; + +#undef __HIP_NODISCARD + +/** + * hipDeviceAttribute_t + * hipDeviceAttributeUnused number: 5 + */ +typedef enum hipDeviceAttribute_t { + hipDeviceAttributeCudaCompatibleBegin = 0, + + hipDeviceAttributeEccEnabled = + hipDeviceAttributeCudaCompatibleBegin, ///< Whether ECC support is + ///< enabled. + hipDeviceAttributeAccessPolicyMaxWindowSize, ///< Cuda only. The maximum size + ///< of the window policy in + ///< bytes. + hipDeviceAttributeAsyncEngineCount, ///< Asynchronous engines number. + hipDeviceAttributeCanMapHostMemory, ///< Whether host memory can be mapped + ///< into device address space + hipDeviceAttributeCanUseHostPointerForRegisteredMem, ///< Device can access + ///< host registered + ///< memory at the same + ///< virtual address as + ///< the CPU + hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz. + hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in. + hipDeviceAttributeComputePreemptionSupported, ///< Device supports Compute + ///< Preemption. + hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple + ///< kernels concurrently. + hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access + ///< managed memory concurrently + ///< with the CPU + hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch + hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative + ///< launch on multiple + ///< devices + hipDeviceAttributeDeviceOverlap, ///< Device can concurrently copy memory and + ///< execute a kernel. Deprecated. Use + ///< instead asyncEngineCount. + hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly + ///< access managed memory + ///< on the device without + ///< migration + hipDeviceAttributeGlobalL1CacheSupported, ///< Device supports caching globals + ///< in L1 + hipDeviceAttributeHostNativeAtomicSupported, ///< Link between the device and + ///< the host supports native + ///< atomic operations + hipDeviceAttributeIntegrated, ///< Device is integrated GPU + hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices. + hipDeviceAttributeKernelExecTimeout, ///< Run time limit for kernels executed + ///< on the device + hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device + ///< doesn't have L2 cache. + hipDeviceAttributeLocalL1CacheSupported, ///< caching locals in L1 is + ///< supported + hipDeviceAttributeLuid, ///< 8-byte locally unique identifier in 8 bytes. + ///< Undefined on TCC and non-Windows platforms + hipDeviceAttributeLuidDeviceNodeMask, ///< Luid device node mask. Undefined on + ///< TCC and non-Windows platforms + hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability + ///< version number. + hipDeviceAttributeManagedMemory, ///< Device supports allocating managed + ///< memory on this system + hipDeviceAttributeMaxBlocksPerMultiProcessor, ///< Max block size per + ///< multiprocessor + hipDeviceAttributeMaxBlockDimX, ///< Max block size in width. + hipDeviceAttributeMaxBlockDimY, ///< Max block size in height. + hipDeviceAttributeMaxBlockDimZ, ///< Max block size in depth. + hipDeviceAttributeMaxGridDimX, ///< Max grid size in width. + hipDeviceAttributeMaxGridDimY, ///< Max grid size in height. + hipDeviceAttributeMaxGridDimZ, ///< Max grid size in depth. + hipDeviceAttributeMaxSurface1D, ///< Maximum size of 1D surface. + hipDeviceAttributeMaxSurface1DLayered, ///< Cuda only. Maximum dimensions of + ///< 1D layered surface. + hipDeviceAttributeMaxSurface2D, ///< Maximum dimension (width, height) of 2D + ///< surface. + hipDeviceAttributeMaxSurface2DLayered, ///< Cuda only. Maximum dimensions of + ///< 2D layered surface. + hipDeviceAttributeMaxSurface3D, ///< Maximum dimension (width, height, depth) + ///< of 3D surface. + hipDeviceAttributeMaxSurfaceCubemap, ///< Cuda only. Maximum dimensions of + ///< Cubemap surface. + hipDeviceAttributeMaxSurfaceCubemapLayered, ///< Cuda only. Maximum dimension + ///< of Cubemap layered surface. + hipDeviceAttributeMaxTexture1DWidth, ///< Maximum size of 1D texture. + hipDeviceAttributeMaxTexture1DLayered, ///< Maximum dimensions of 1D layered + ///< texture. + hipDeviceAttributeMaxTexture1DLinear, ///< Maximum number of elements + ///< allocatable in a 1D linear texture. + ///< Use + ///< cudaDeviceGetTexture1DLinearMaxWidth() + ///< instead on Cuda. + hipDeviceAttributeMaxTexture1DMipmap, ///< Maximum size of 1D mipmapped + ///< texture. + hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D + ///< texture. + hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension hight of 2D + ///< texture. + hipDeviceAttributeMaxTexture2DGather, ///< Maximum dimensions of 2D texture if + ///< gather operations performed. + hipDeviceAttributeMaxTexture2DLayered, ///< Maximum dimensions of 2D layered + ///< texture. + hipDeviceAttributeMaxTexture2DLinear, ///< Maximum dimensions (width, height, + ///< pitch) of 2D textures bound to + ///< pitched memory. + hipDeviceAttributeMaxTexture2DMipmap, ///< Maximum dimensions of 2D mipmapped + ///< texture. + hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D + ///< texture. + hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimension height of 3D + ///< texture. + hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimension depth of 3D + ///< texture. + hipDeviceAttributeMaxTexture3DAlt, ///< Maximum dimensions of alternate 3D + ///< texture. + hipDeviceAttributeMaxTextureCubemap, ///< Maximum dimensions of Cubemap + ///< texture + hipDeviceAttributeMaxTextureCubemapLayered, ///< Maximum dimensions of Cubemap + ///< layered texture. + hipDeviceAttributeMaxThreadsDim, ///< Maximum dimension of a block + hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per + ///< block. + hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads + ///< per multiprocessor. + hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory + ///< copies + hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits. + hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in + ///< kilohertz. + hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability + ///< version number. + hipDeviceAttributeMultiGpuBoardGroupID, ///< Unique ID of device group on the + ///< same multi-GPU board + hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the + ///< device. + hipDeviceAttributeUnused1, ///< Previously hipDeviceAttributeName + hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently + ///< accessing pageable memory without + ///< calling hipHostRegister on it + hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses + ///< pageable memory + ///< via the host's + ///< page tables + hipDeviceAttributePciBusId, ///< PCI Bus ID. + hipDeviceAttributePciDeviceId, ///< PCI Device ID. + hipDeviceAttributePciDomainID, ///< PCI Domain ID. + hipDeviceAttributePersistingL2CacheMaxSize, ///< Maximum l2 persisting lines + ///< capacity in bytes + hipDeviceAttributeMaxRegistersPerBlock, ///< 32-bit registers available to a + ///< thread block. This number is + ///< shared by all thread blocks + ///< simultaneously resident on a + ///< multiprocessor. + hipDeviceAttributeMaxRegistersPerMultiprocessor, ///< 32-bit registers + ///< available per block. + hipDeviceAttributeReservedSharedMemPerBlock, ///< Shared memory reserved by + ///< CUDA driver per block. + hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory + ///< available per block in bytes. + hipDeviceAttributeSharedMemPerBlockOptin, ///< Maximum shared memory per block + ///< usable by special opt in. + hipDeviceAttributeSharedMemPerMultiprocessor, ///< Shared memory available per + ///< multiprocessor. + hipDeviceAttributeSingleToDoublePrecisionPerfRatio, ///< Cuda only. + ///< Performance ratio of + ///< single precision to + ///< double precision. + hipDeviceAttributeStreamPrioritiesSupported, ///< Whether to support stream + ///< priorities. + hipDeviceAttributeSurfaceAlignment, ///< Alignment requirement for surfaces + hipDeviceAttributeTccDriver, ///< Cuda only. Whether device is a Tesla device + ///< using TCC driver + hipDeviceAttributeTextureAlignment, ///< Alignment requirement for textures + hipDeviceAttributeTexturePitchAlignment, ///< Pitch alignment requirement for + ///< 2D texture references bound to + ///< pitched memory; + hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes. + hipDeviceAttributeTotalGlobalMem, ///< Global memory available on devicice. + hipDeviceAttributeUnifiedAddressing, ///< Cuda only. An unified address space + ///< shared with the host. + hipDeviceAttributeUnused2, ///< Previously hipDeviceAttributeUuid + hipDeviceAttributeWarpSize, ///< Warp size in threads. + hipDeviceAttributeMemoryPoolsSupported, ///< Device supports HIP Stream + ///< Ordered Memory Allocator + hipDeviceAttributeVirtualMemoryManagementSupported, ///< Device supports HIP + ///< virtual memory + ///< management + hipDeviceAttributeHostRegisterSupported, ///< Can device support host memory + ///< registration via hipHostRegister + hipDeviceAttributeMemoryPoolSupportedHandleTypes, ///< Supported handle mask + ///< for HIP Stream Ordered + ///< Memory Allocator + + hipDeviceAttributeCudaCompatibleEnd = 9999, + hipDeviceAttributeAmdSpecificBegin = 10000, + + hipDeviceAttributeClockInstructionRate = + hipDeviceAttributeAmdSpecificBegin, ///< Frequency in khz of the timer + ///< used by the device-side "clock*" + hipDeviceAttributeUnused3, ///< Previously hipDeviceAttributeArch + hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory + ///< PerMultiprocessor. + hipDeviceAttributeUnused4, ///< Previously hipDeviceAttributeGcnArch + hipDeviceAttributeUnused5, ///< Previously hipDeviceAttributeGcnArchName + hipDeviceAttributeHdpMemFlushCntl, ///< Address of the + ///< HDP_MEM_COHERENCY_FLUSH_CNTL register + hipDeviceAttributeHdpRegFlushCntl, ///< Address of the + ///< HDP_REG_COHERENCY_FLUSH_CNTL register + hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports + ///< cooperative launch + ///< on multiple + ///< devices with + ///< unmatched + ///< functions + hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports + ///< cooperative + ///< launch on + ///< multiple + ///< devices with + ///< unmatched grid + ///< dimensions + hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports + ///< cooperative + ///< launch on + ///< multiple + ///< devices with + ///< unmatched + ///< block + ///< dimensions + hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports + ///< cooperative + ///< launch on + ///< multiple + ///< devices with + ///< unmatched + ///< shared + ///< memories + hipDeviceAttributeIsLargeBar, ///< Whether it is LargeBar + hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device + hipDeviceAttributeCanUseStreamWaitValue, ///< '1' if Device supports + ///< hipStreamWaitValue32() and + ///< hipStreamWaitValue64(), '0' + ///< otherwise. + hipDeviceAttributeImageSupport, ///< '1' if Device supports image, '0' + ///< otherwise. + hipDeviceAttributePhysicalMultiProcessorCount, ///< All available physical + ///< compute units for the + ///< device + hipDeviceAttributeFineGrainSupport, ///< '1' if Device supports fine grain, + ///< '0' otherwise + hipDeviceAttributeWallClockRate, ///< Constant frequency of wall clock in + ///< kilohertz. + + hipDeviceAttributeAmdSpecificEnd = 19999, + hipDeviceAttributeVendorSpecificBegin = 20000, + // Extended attributes for vendors +} hipDeviceAttribute_t; + +typedef enum hipDriverProcAddressQueryResult { + HIP_GET_PROC_ADDRESS_SUCCESS = 0, + HIP_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND = 1, + HIP_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT = 2 +} hipDriverProcAddressQueryResult; + +enum hipComputeMode { + hipComputeModeDefault = 0, + hipComputeModeExclusive = 1, + hipComputeModeProhibited = 2, + hipComputeModeExclusiveProcess = 3 +}; + +enum hipFlushGPUDirectRDMAWritesOptions { + hipFlushGPUDirectRDMAWritesOptionHost = 1 << 0, + hipFlushGPUDirectRDMAWritesOptionMemOps = 1 << 1 +}; + +enum hipGPUDirectRDMAWritesOrdering { + hipGPUDirectRDMAWritesOrderingNone = 0, + hipGPUDirectRDMAWritesOrderingOwner = 100, + hipGPUDirectRDMAWritesOrderingAllDevices = 200 +}; + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) + +#include +#include +#ifndef GENERIC_GRID_LAUNCH +#define GENERIC_GRID_LAUNCH 1 +#endif +#include +#include +#include +#include +#if defined(_MSC_VER) +#define DEPRECATED(msg) __declspec(deprecated(msg)) +#else // !defined(_MSC_VER) +#define DEPRECATED(msg) __attribute__((deprecated(msg))) +#endif // !defined(_MSC_VER) +#define DEPRECATED_MSG \ + "This API is marked as deprecated and may not be supported in future " \ + "releases. For more details please refer " \ + "https://github.com/ROCm/HIP/blob/develop/docs/reference/" \ + "deprecated_api_list.md" +#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void *)0x01) +#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void *)0x02) +#define HIP_LAUNCH_PARAM_END ((void *)0x03) +#ifdef __cplusplus +#define __dparm(x) = x +#else +#define __dparm(x) +#endif +#ifdef __GNUC__ +#pragma GCC visibility push(default) +#endif +#ifdef __cplusplus +namespace hip_impl { +hipError_t hip_init(); +} // namespace hip_impl +#endif +// Structure definitions: +#ifdef __cplusplus +extern "C" { +#endif +//--- +// API-visible structures +typedef struct ihipCtx_t *hipCtx_t; +// Note many APIs also use integer deviceIds as an alternative to the device +// pointer: +typedef int hipDevice_t; +typedef enum hipDeviceP2PAttr { + hipDevP2PAttrPerformanceRank = 0, + hipDevP2PAttrAccessSupported, + hipDevP2PAttrNativeAtomicSupported, + hipDevP2PAttrHipArrayAccessSupported +} hipDeviceP2PAttr; +typedef struct ihipStream_t *hipStream_t; +#define hipIpcMemLazyEnablePeerAccess 0x01 +#define HIP_IPC_HANDLE_SIZE 64 +typedef struct hipIpcMemHandle_st { + char reserved[HIP_IPC_HANDLE_SIZE]; +} hipIpcMemHandle_t; +typedef struct hipIpcEventHandle_st { + char reserved[HIP_IPC_HANDLE_SIZE]; +} hipIpcEventHandle_t; +typedef struct ihipModule_t *hipModule_t; +typedef struct ihipModuleSymbol_t *hipFunction_t; +/** + * HIP memory pool + */ +typedef struct ihipMemPoolHandle_t *hipMemPool_t; + +typedef struct hipFuncAttributes { + int binaryVersion; + int cacheModeCA; + size_t constSizeBytes; + size_t localSizeBytes; + int maxDynamicSharedSizeBytes; + int maxThreadsPerBlock; + int numRegs; + int preferredShmemCarveout; + int ptxVersion; + size_t sharedSizeBytes; +} hipFuncAttributes; +typedef struct ihipEvent_t *hipEvent_t; + +/** + * hipLimit + * + * @note In HIP device limit-related APIs, any input limit value other than + * those defined in the enum is treated as "UnsupportedLimit" by default. + */ +enum hipLimit_t { + hipLimitStackSize = 0x0, ///< Limit of stack size in bytes on the current + ///< device, per thread. The size is in units of 256 + ///< dwords, up to the limit of (128K - 16) + hipLimitPrintfFifoSize = + 0x01, ///< Size limit in bytes of fifo used by printf call on the + ///< device. Currently not supported + hipLimitMallocHeapSize = + 0x02, ///< Limit of heap size in bytes on the current device, should + ///< be less than the global memory size on the device + hipLimitRange ///< Supported limit range +}; +/** + * Flags that can be used with hipStreamCreateWithFlags. + */ +// Flags that can be used with hipStreamCreateWithFlags. +/** Default stream creation flags. These are used with hipStreamCreate().*/ +#define hipStreamDefault 0x00 + +/** Stream does not implicitly synchronize with null stream.*/ +#define hipStreamNonBlocking 0x01 + +// Flags that can be used with hipEventCreateWithFlags. +/** Default flags.*/ +#define hipEventDefault 0x0 + +/** Waiting will yield CPU. Power-friendly and usage-friendly but may increase + * latency.*/ +#define hipEventBlockingSync 0x1 + +/** Disable event's capability to record timing information. May improve + * performance.*/ +#define hipEventDisableTiming 0x2 + +/** Event can support IPC. hipEventDisableTiming also must be set.*/ +#define hipEventInterprocess 0x4 + +/** Disable performing a system scope sequentially consistent memory fence when + * the event transitions from recording to recorded. This can be used for + * events that are only being used to measure timing, and do not require the + * event inspection operations (see ::hipEventSynchronize, ::hipEventQuery, and + * ::hipEventElapsedTime) to synchronize-with the work on which the recorded + * event (see ::hipEventRecord) is waiting. On some AMD GPU devices this can + * improve the accuracy of timing measurements by avoiding the cost of cache + * writeback and invalidation, and the performance impact of those actions on + * the execution of following work. */ +#define hipEventDisableSystemFence 0x20000000 + +/** Use a device-scope release when recording this event. This flag is useful to + * obtain more precise timings of commands between events. The flag is a no-op + * on CUDA platforms.*/ +#define hipEventReleaseToDevice 0x40000000 + +/** Use a system-scope release when recording this event. This flag is useful to + * make non-coherent host memory visible to the host. The flag is a no-op on + * CUDA platforms.*/ +#define hipEventReleaseToSystem 0x80000000 + +// Flags that can be used with hipHostMalloc. +/** Default pinned memory allocation on the host.*/ +#define hipHostMallocDefault 0x0 + +/** Memory is considered allocated by all contexts.*/ +#define hipHostMallocPortable 0x1 + +/** Map the allocation into the address space for the current device. The device + * pointer can be obtained with #hipHostGetDevicePointer.*/ +#define hipHostMallocMapped 0x2 + +/** Allocates the memory as write-combined. On some system configurations, + * write-combined allocation may be transferred faster across the PCI Express + * bus, however, could have low read efficiency by most CPUs. It's a good option + * for data tranfer from host to device via mapped pinned memory.*/ +#define hipHostMallocWriteCombined 0x4 + +/** + * Host memory allocation will follow numa policy set by user. + * @note This numa allocation flag is applicable on Linux, under development on + * Windows. + */ +#define hipHostMallocNumaUser 0x20000000 + +/** Allocate coherent memory. Overrides HIP_COHERENT_HOST_ALLOC for specific + * allocation.*/ +#define hipHostMallocCoherent 0x40000000 + +/** Allocate non-coherent memory. Overrides HIP_COHERENT_HOST_ALLOC for specific + * allocation.*/ +#define hipHostMallocNonCoherent 0x80000000 + +/** Memory can be accessed by any stream on any device*/ +#define hipMemAttachGlobal 0x01 + +/** Memory cannot be accessed by any stream on any device.*/ +#define hipMemAttachHost 0x02 + +/** Memory can only be accessed by a single stream on the associated device.*/ +#define hipMemAttachSingle 0x04 + +#define hipDeviceMallocDefault 0x0 + +/** Memory is allocated in fine grained region of device.*/ +#define hipDeviceMallocFinegrained 0x1 + +/** Memory represents a HSA signal.*/ +#define hipMallocSignalMemory 0x2 + +/** Memory allocated will be uncached. */ +#define hipDeviceMallocUncached 0x3 + +/** Memory allocated will be contiguous. */ +#define hipDeviceMallocContiguous 0x4 + +// Flags that can be used with hipHostRegister. +/** Memory is Mapped and Portable.*/ +#define hipHostRegisterDefault 0x0 + +/** Memory is considered registered by all contexts.*/ +#define hipHostRegisterPortable 0x1 + +/** Map the allocation into the address space for the current device. The device + * pointer can be obtained with #hipHostGetDevicePointer.*/ +#define hipHostRegisterMapped 0x2 + +/** Not supported.*/ +#define hipHostRegisterIoMemory 0x4 + +/** This flag is ignored On AMD devices.*/ +#define hipHostRegisterReadOnly 0x08 + +/** Coarse Grained host memory lock.*/ +#define hipExtHostRegisterCoarseGrained 0x8 + +/** Automatically select between Spin and Yield.*/ +#define hipDeviceScheduleAuto 0x0 + +/** Dedicate a CPU core to spin-wait. Provides lowest latency, but burns a CPU + * core and may consume more power.*/ +#define hipDeviceScheduleSpin 0x1 + +/** Yield the CPU to the operating system when waiting. May increase latency, + * but lowers power and is friendlier to other threads in the system.*/ +#define hipDeviceScheduleYield 0x2 +#define hipDeviceScheduleBlockingSync 0x4 +#define hipDeviceScheduleMask 0x7 +#define hipDeviceMapHost 0x8 +#define hipDeviceLmemResizeToMax 0x10 +/** Default HIP array allocation flag.*/ +#define hipArrayDefault 0x00 +#define hipArrayLayered 0x01 +#define hipArraySurfaceLoadStore 0x02 +#define hipArrayCubemap 0x04 +#define hipArrayTextureGather 0x08 +#define hipOccupancyDefault 0x00 +#define hipOccupancyDisableCachingOverride 0x01 +#define hipCooperativeLaunchMultiDeviceNoPreSync 0x01 +#define hipCooperativeLaunchMultiDeviceNoPostSync 0x02 +#define hipCpuDeviceId ((int)-1) +#define hipInvalidDeviceId ((int)-2) +// Flags that can be used with hipExtLaunch Set of APIs. +/** AnyOrderLaunch of kernels.*/ +#define hipExtAnyOrderLaunch 0x01 +// Flags to be used with hipStreamWaitValue32 and hipStreamWaitValue64. +#define hipStreamWaitValueGte 0x0 +#define hipStreamWaitValueEq 0x1 +#define hipStreamWaitValueAnd 0x2 +#define hipStreamWaitValueNor 0x3 +// Stream per thread +/** Implicit stream per application thread.*/ +#define hipStreamPerThread ((hipStream_t)2) + +#define hipStreamLegacy ((hipStream_t)1) + +// Indicates that the external memory object is a dedicated resource +#define hipExternalMemoryDedicated 0x1 +/** + * HIP Memory Advise values + * + * @note This memory advise enumeration is used on Linux, not Windows. + */ +typedef enum hipMemoryAdvise { + hipMemAdviseSetReadMostly = 1, ///< Data will mostly be read and only + ///< occassionally be written to + hipMemAdviseUnsetReadMostly = + 2, ///< Undo the effect of hipMemAdviseSetReadMostly + hipMemAdviseSetPreferredLocation = 3, ///< Set the preferred location for the + ///< data as the specified device + hipMemAdviseUnsetPreferredLocation = + 4, ///< Clear the preferred location for the data + hipMemAdviseSetAccessedBy = + 5, ///< Data will be accessed by the specified device + ///< so prevent page faults as much as possible + hipMemAdviseUnsetAccessedBy = 6, ///< Let HIP to decide on the page faulting + ///< policy for the specified device + hipMemAdviseSetCoarseGrain = + 100, ///< The default memory model is fine-grain. That allows + ///< coherent operations between host and device, while + ///< executing kernels. The coarse-grain can be used + ///< for data that only needs to be coherent at dispatch + ///< boundaries for better performance + hipMemAdviseUnsetCoarseGrain = + 101 ///< Restores cache coherency policy back to fine-grain +} hipMemoryAdvise; +/** + * HIP Coherency Mode + */ +typedef enum hipMemRangeCoherencyMode { + hipMemRangeCoherencyModeFineGrain = + 0, ///< Updates to memory with this attribute can be + ///< done coherently from all devices + hipMemRangeCoherencyModeCoarseGrain = + 1, ///< Writes to memory with this attribute can be + ///< performed by a single device at a time + hipMemRangeCoherencyModeIndeterminate = + 2 ///< Memory region queried contains subregions with + ///< both hipMemRangeCoherencyModeFineGrain and + ///< hipMemRangeCoherencyModeCoarseGrain attributes +} hipMemRangeCoherencyMode; +/** + * HIP range attributes + */ +typedef enum hipMemRangeAttribute { + hipMemRangeAttributeReadMostly = 1, ///< Whether the range will mostly be read + ///< and only occassionally be written to + hipMemRangeAttributePreferredLocation = + 2, ///< The preferred location of the range + hipMemRangeAttributeAccessedBy = + 3, ///< Memory range has hipMemAdviseSetAccessedBy + ///< set for the specified device + hipMemRangeAttributeLastPrefetchLocation = 4, ///< The last location to where + ///< the range was prefetched + hipMemRangeAttributeCoherencyMode = + 100, ///< Returns coherency mode + ///< @ref hipMemRangeCoherencyMode for the range +} hipMemRangeAttribute; + +/** + * HIP memory pool attributes + */ +typedef enum hipMemPoolAttr { + /** + * (value type = int) + * Allow @p hipMemAllocAsync to use memory asynchronously freed + * in another streams as long as a stream ordering dependency + * of the allocating stream on the free action exists. + * hip events and null stream interactions can create the required + * stream ordered dependencies. (default enabled) + */ + hipMemPoolReuseFollowEventDependencies = 0x1, + /** + * (value type = int) + * Allow reuse of already completed frees when there is no dependency + * between the free and allocation. (default enabled) + */ + hipMemPoolReuseAllowOpportunistic = 0x2, + /** + * (value type = int) + * Allow @p hipMemAllocAsync to insert new stream dependencies + * in order to establish the stream ordering required to reuse + * a piece of memory released by cuFreeAsync (default enabled). + */ + hipMemPoolReuseAllowInternalDependencies = 0x3, + /** + * (value type = uint64_t) + * Amount of reserved memory in bytes to hold onto before trying + * to release memory back to the OS. When more than the release + * threshold bytes of memory are held by the memory pool, the + * allocator will try to release memory back to the OS on the + * next call to stream, event or context synchronize. (default 0) + */ + hipMemPoolAttrReleaseThreshold = 0x4, + /** + * (value type = uint64_t) + * Amount of backing memory currently allocated for the mempool. + */ + hipMemPoolAttrReservedMemCurrent = 0x5, + /** + * (value type = uint64_t) + * High watermark of backing memory allocated for the mempool since the + * last time it was reset. High watermark can only be reset to zero. + */ + hipMemPoolAttrReservedMemHigh = 0x6, + /** + * (value type = uint64_t) + * Amount of memory from the pool that is currently in use by the application. + */ + hipMemPoolAttrUsedMemCurrent = 0x7, + /** + * (value type = uint64_t) + * High watermark of the amount of memory from the pool that was in use by the + * application since the last time it was reset. High watermark can only be + * reset to zero. + */ + hipMemPoolAttrUsedMemHigh = 0x8 +} hipMemPoolAttr; +/** + * Specifies the type of location + */ +typedef enum hipMemLocationType { + hipMemLocationTypeInvalid = 0, + hipMemLocationTypeDevice = 1 ///< Device location, thus it's HIP device ID +} hipMemLocationType; +/** + * Specifies a memory location. + * + * To specify a gpu, set type = @p hipMemLocationTypeDevice and set id = the + * gpu's device ID + */ +typedef struct hipMemLocation { + hipMemLocationType + type; ///< Specifies the location type, which describes the meaning of id + int id; ///< Identifier for the provided location type @p hipMemLocationType +} hipMemLocation; +/** + * Specifies the memory protection flags for mapping + * + */ +typedef enum hipMemAccessFlags { + hipMemAccessFlagsProtNone = + 0, ///< Default, make the address range not accessible + hipMemAccessFlagsProtRead = 1, ///< Set the address range read accessible + hipMemAccessFlagsProtReadWrite = + 3 ///< Set the address range read-write accessible +} hipMemAccessFlags; +/** + * Memory access descriptor + */ +typedef struct hipMemAccessDesc { + hipMemLocation + location; ///< Location on which the accessibility has to change + hipMemAccessFlags flags; ///< Accessibility flags to set +} hipMemAccessDesc; +/** + * Defines the allocation types + */ +typedef enum hipMemAllocationType { + hipMemAllocationTypeInvalid = 0x0, + /** This allocation type is 'pinned', i.e. cannot migrate from its current + * location while the application is actively using it + */ + hipMemAllocationTypePinned = 0x1, + hipMemAllocationTypeMax = 0x7FFFFFFF +} hipMemAllocationType; +/** + * Flags for specifying handle types for memory pool allocations + * + */ +typedef enum hipMemAllocationHandleType { + hipMemHandleTypeNone = 0x0, ///< Does not allow any export mechanism + hipMemHandleTypePosixFileDescriptor = + 0x1, ///< Allows a file descriptor for exporting. Permitted only on POSIX + ///< systems + hipMemHandleTypeWin32 = + 0x2, ///< Allows a Win32 NT handle for exporting. (HANDLE) + hipMemHandleTypeWin32Kmt = + 0x4 ///< Allows a Win32 KMT handle for exporting. (D3DKMT_HANDLE) +} hipMemAllocationHandleType; +/** + * Specifies the properties of allocations made from the pool. + */ +typedef struct hipMemPoolProps { + hipMemAllocationType + allocType; ///< Allocation type. Currently must be specified as @p + ///< hipMemAllocationTypePinned + hipMemAllocationHandleType + handleTypes; ///< Handle types that will be supported by allocations from + ///< the pool + hipMemLocation location; ///< Location where allocations should reside + /** + * Windows-specific LPSECURITYATTRIBUTES required when @p + * hipMemHandleTypeWin32 is specified + */ + void *win32SecurityAttributes; + size_t maxSize; ///< Maximum pool size. When set to 0, defaults to a system + ///< dependent value + unsigned char reserved[56]; ///< Reserved for future use, must be 0 +} hipMemPoolProps; +/** + * Opaque data structure for exporting a pool allocation + */ +typedef struct hipMemPoolPtrExportData { + unsigned char reserved[64]; +} hipMemPoolPtrExportData; + +/** + * hipJitOption + */ +typedef enum hipJitOption { + hipJitOptionMaxRegisters = 0, + hipJitOptionThreadsPerBlock, + hipJitOptionWallTime, + hipJitOptionInfoLogBuffer, + hipJitOptionInfoLogBufferSizeBytes, + hipJitOptionErrorLogBuffer, + hipJitOptionErrorLogBufferSizeBytes, + hipJitOptionOptimizationLevel, + hipJitOptionTargetFromContext, + hipJitOptionTarget, + hipJitOptionFallbackStrategy, + hipJitOptionGenerateDebugInfo, + hipJitOptionLogVerbose, + hipJitOptionGenerateLineInfo, + hipJitOptionCacheMode, + hipJitOptionSm3xOpt, + hipJitOptionFastCompile, + hipJitOptionNumOptions +} hipJitOption; +/** + * @warning On AMD devices and some Nvidia devices, these hints and controls are + * ignored. + */ +typedef enum hipFuncAttribute { + hipFuncAttributeMaxDynamicSharedMemorySize = 8, + hipFuncAttributePreferredSharedMemoryCarveout = 9, + hipFuncAttributeMax +} hipFuncAttribute; +/** + * @warning On AMD devices and some Nvidia devices, these hints and controls are + * ignored. + */ +typedef enum hipFuncCache_t { + hipFuncCachePreferNone, ///< no preference for shared memory or L1 (default) + hipFuncCachePreferShared, ///< prefer larger shared memory and smaller L1 + ///< cache + hipFuncCachePreferL1, ///< prefer larger L1 cache and smaller shared memory + hipFuncCachePreferEqual, ///< prefer equal size L1 cache and shared memory +} hipFuncCache_t; +/** + * @warning On AMD devices and some Nvidia devices, these hints and controls are + * ignored. + */ +typedef enum hipSharedMemConfig { + hipSharedMemBankSizeDefault, ///< The compiler selects a device-specific value + ///< for the banking. + hipSharedMemBankSizeFourByte, ///< Shared mem is banked at 4-bytes intervals + ///< and performs best when adjacent threads + ///< access data 4 bytes apart. + hipSharedMemBankSizeEightByte ///< Shared mem is banked at 8-byte intervals + ///< and performs best when adjacent threads + ///< access data 4 bytes apart. +} hipSharedMemConfig; +/** + * Struct for data in 3D + */ +typedef struct dim3 { + uint32_t x; ///< x + uint32_t y; ///< y + uint32_t z; ///< z +#ifdef __cplusplus + constexpr __host__ __device__ dim3(uint32_t _x = 1, uint32_t _y = 1, + uint32_t _z = 1) + : x(_x), y(_y), z(_z) {}; +#endif +} dim3; +/** + * struct hipLaunchParams_t + */ +typedef struct hipLaunchParams_t { + void *func; ///< Device function symbol + dim3 gridDim; ///< Grid dimentions + dim3 blockDim; ///< Block dimentions + void **args; ///< Arguments + size_t sharedMem; ///< Shared memory + hipStream_t stream; ///< Stream identifier +} hipLaunchParams; +/** + * struct hipFunctionLaunchParams_t + */ +typedef struct hipFunctionLaunchParams_t { + hipFunction_t function; ///< Kernel to launch + unsigned int gridDimX; ///< Width(X) of grid in blocks + unsigned int gridDimY; ///< Height(Y) of grid in blocks + unsigned int gridDimZ; ///< Depth(Z) of grid in blocks + unsigned int blockDimX; ///< X dimension of each thread block + unsigned int blockDimY; ///< Y dimension of each thread block + unsigned int blockDimZ; ///< Z dimension of each thread block + unsigned int sharedMemBytes; ///< Shared memory + hipStream_t hStream; ///< Stream identifier + void **kernelParams; ///< Kernel parameters +} hipFunctionLaunchParams; +typedef enum hipExternalMemoryHandleType_enum { + hipExternalMemoryHandleTypeOpaqueFd = 1, + hipExternalMemoryHandleTypeOpaqueWin32 = 2, + hipExternalMemoryHandleTypeOpaqueWin32Kmt = 3, + hipExternalMemoryHandleTypeD3D12Heap = 4, + hipExternalMemoryHandleTypeD3D12Resource = 5, + hipExternalMemoryHandleTypeD3D11Resource = 6, + hipExternalMemoryHandleTypeD3D11ResourceKmt = 7, + hipExternalMemoryHandleTypeNvSciBuf = 8 +} hipExternalMemoryHandleType; +typedef struct hipExternalMemoryHandleDesc_st { + hipExternalMemoryHandleType type; + union { + int fd; + struct { + void *handle; + const void *name; + } win32; + const void *nvSciBufObject; + } handle; + unsigned long long size; + unsigned int flags; + unsigned int reserved[16]; +} hipExternalMemoryHandleDesc; +typedef struct hipExternalMemoryBufferDesc_st { + unsigned long long offset; + unsigned long long size; + unsigned int flags; + unsigned int reserved[16]; +} hipExternalMemoryBufferDesc; +typedef struct hipExternalMemoryMipmappedArrayDesc_st { + unsigned long long offset; + hipChannelFormatDesc formatDesc; + hipExtent extent; + unsigned int flags; + unsigned int numLevels; +} hipExternalMemoryMipmappedArrayDesc; +typedef void *hipExternalMemory_t; +typedef enum hipExternalSemaphoreHandleType_enum { + hipExternalSemaphoreHandleTypeOpaqueFd = 1, + hipExternalSemaphoreHandleTypeOpaqueWin32 = 2, + hipExternalSemaphoreHandleTypeOpaqueWin32Kmt = 3, + hipExternalSemaphoreHandleTypeD3D12Fence = 4, + hipExternalSemaphoreHandleTypeD3D11Fence = 5, + hipExternalSemaphoreHandleTypeNvSciSync = 6, + hipExternalSemaphoreHandleTypeKeyedMutex = 7, + hipExternalSemaphoreHandleTypeKeyedMutexKmt = 8, + hipExternalSemaphoreHandleTypeTimelineSemaphoreFd = 9, + hipExternalSemaphoreHandleTypeTimelineSemaphoreWin32 = 10 +} hipExternalSemaphoreHandleType; +typedef struct hipExternalSemaphoreHandleDesc_st { + hipExternalSemaphoreHandleType type; + union { + int fd; + struct { + void *handle; + const void *name; + } win32; + const void *NvSciSyncObj; + } handle; + unsigned int flags; + unsigned int reserved[16]; +} hipExternalSemaphoreHandleDesc; +typedef void *hipExternalSemaphore_t; +typedef struct hipExternalSemaphoreSignalParams_st { + struct { + struct { + unsigned long long value; + } fence; + union { + void *fence; + unsigned long long reserved; + } nvSciSync; + struct { + unsigned long long key; + } keyedMutex; + unsigned int reserved[12]; + } params; + unsigned int flags; + unsigned int reserved[16]; +} hipExternalSemaphoreSignalParams; +/** + * External semaphore wait parameters, compatible with driver type + */ +typedef struct hipExternalSemaphoreWaitParams_st { + struct { + struct { + unsigned long long value; + } fence; + union { + void *fence; + unsigned long long reserved; + } nvSciSync; + struct { + unsigned long long key; + unsigned int timeoutMs; + } keyedMutex; + unsigned int reserved[10]; + } params; + unsigned int flags; + unsigned int reserved[16]; +} hipExternalSemaphoreWaitParams; + +#if __HIP_HAS_GET_PCH +/** + * Internal use only. This API may change in the future + * Pre-Compiled header for online compilation + */ +void __hipGetPCH(const char **pch, unsigned int *size); +#endif + +/** + * HIP Access falgs for Interop resources. + */ +typedef enum hipGraphicsRegisterFlags { + hipGraphicsRegisterFlagsNone = 0, + hipGraphicsRegisterFlagsReadOnly = + 1, ///< HIP will not write to this registered resource + hipGraphicsRegisterFlagsWriteDiscard = + 2, ///< HIP will only write and will not read from this registered + ///< resource + hipGraphicsRegisterFlagsSurfaceLoadStore = + 4, ///< HIP will bind this resource to a surface + hipGraphicsRegisterFlagsTextureGather = + 8 ///< HIP will perform texture gather operations on this registered + ///< resource +} hipGraphicsRegisterFlags; + +typedef struct _hipGraphicsResource hipGraphicsResource; + +typedef hipGraphicsResource *hipGraphicsResource_t; + +/** + * An opaque value that represents a hip graph + */ +typedef struct ihipGraph *hipGraph_t; +/** + * An opaque value that represents a hip graph node + */ +typedef struct hipGraphNode *hipGraphNode_t; +/** + * An opaque value that represents a hip graph Exec + */ +typedef struct hipGraphExec *hipGraphExec_t; + +/** + * An opaque value that represents a user obj + */ +typedef struct hipUserObject *hipUserObject_t; + +/** + * hipGraphNodeType + */ +typedef enum hipGraphNodeType { + hipGraphNodeTypeKernel = 0, ///< GPU kernel node + hipGraphNodeTypeMemcpy = 1, ///< Memcpy node + hipGraphNodeTypeMemset = 2, ///< Memset node + hipGraphNodeTypeHost = 3, ///< Host (executable) node + hipGraphNodeTypeGraph = 4, ///< Node which executes an embedded graph + hipGraphNodeTypeEmpty = 5, ///< Empty (no-op) node + hipGraphNodeTypeWaitEvent = 6, ///< External event wait node + hipGraphNodeTypeEventRecord = 7, ///< External event record node + hipGraphNodeTypeExtSemaphoreSignal = 8, ///< External Semaphore signal node + hipGraphNodeTypeExtSemaphoreWait = 9, ///< External Semaphore wait node + hipGraphNodeTypeMemAlloc = 10, ///< Memory alloc node + hipGraphNodeTypeMemFree = 11, ///< Memory free node + hipGraphNodeTypeMemcpyFromSymbol = 12, ///< MemcpyFromSymbol node + hipGraphNodeTypeMemcpyToSymbol = 13, ///< MemcpyToSymbol node + hipGraphNodeTypeCount +} hipGraphNodeType; + +typedef void (*hipHostFn_t)(void *userData); +typedef struct hipHostNodeParams { + hipHostFn_t fn; + void *userData; +} hipHostNodeParams; +typedef struct hipKernelNodeParams { + dim3 blockDim; + void **extra; + void *func; + dim3 gridDim; + void **kernelParams; + unsigned int sharedMemBytes; +} hipKernelNodeParams; +typedef struct hipMemsetParams { + void *dst; + unsigned int elementSize; + size_t height; + size_t pitch; + unsigned int value; + size_t width; +} hipMemsetParams; + +typedef struct hipMemAllocNodeParams { + hipMemPoolProps poolProps; ///< Pool properties, which contain where + ///< the location should reside + const hipMemAccessDesc + *accessDescs; ///< The number of memory access descriptors. + ///< Must not be bigger than the number of GPUs + size_t accessDescCount; ///< The number of access descriptors + size_t bytesize; ///< The size of the requested allocation in bytes + void *dptr; ///< Returned device address of the allocation +} hipMemAllocNodeParams; + +typedef enum hipAccessProperty { + hipAccessPropertyNormal = 0, + hipAccessPropertyStreaming = 1, + hipAccessPropertyPersisting = 2, +} hipAccessProperty; +typedef struct hipAccessPolicyWindow { + void *base_ptr; + hipAccessProperty hitProp; + float hitRatio; + hipAccessProperty missProp; + size_t num_bytes; +} hipAccessPolicyWindow; + +/** + * Launch Attribute ID + */ +typedef enum hipLaunchAttributeID { + hipLaunchAttributeAccessPolicyWindow = + 1, /**< Valid for Streams, graph nodes, launches*/ + hipLaunchAttributeCooperative = 2, /**< Valid for graph nodes, launches */ + hipLaunchAttributePriority = + 8, /**< Valid for graph node, streams, launches */ +} hipLaunchAttributeID; + +/** + * Launch Attribute Value + */ +typedef union hipLaunchAttributeValue { + hipAccessPolicyWindow accessPolicyWindow; /**< Value of launch attribute:: + hipLaunchAttributePolicyWindow. */ + int cooperative; /**< Value of launch attribute + ::hipLaunchAttributeCooperative */ + int priority; /**< Value of launch attribute :: hipLaunchAttributePriority. + Execution priority of kernel. */ +} hipLaunchAttributeValue; + +/** + * Kernel node attributeID + */ +#define hipKernelNodeAttrID hipLaunchAttributeID +#define hipKernelNodeAttributeAccessPolicyWindow \ + hipLaunchAttributeAccessPolicyWindow +#define hipKernelNodeAttributeCooperative hipLaunchAttributeCooperative +#define hipKernelNodeAttributePriority hipLaunchAttributePriority + +/** + * Kernel node attribute value + */ +#define hipKernelNodeAttrValue hipLaunchAttributeValue + +/** + * Memset node params + */ +typedef struct HIP_MEMSET_NODE_PARAMS { + hipDeviceptr_t dst; ///< Destination pointer on device + size_t pitch; ///< Destination device pointer pitch. Unused if height equals 1 + unsigned int value; ///< Value of memset to be set + unsigned int elementSize; ///< Element in bytes. Must be 1, 2, or 4. + size_t width; ///< Width of a row + size_t height; ///< Number of rows +} HIP_MEMSET_NODE_PARAMS; + +/** + * Graph execution update result + */ +typedef enum hipGraphExecUpdateResult { + hipGraphExecUpdateSuccess = 0x0, ///< The update succeeded + hipGraphExecUpdateError = + 0x1, ///< The update failed for an unexpected reason which is described + ///< in the return value of the function + hipGraphExecUpdateErrorTopologyChanged = + 0x2, ///< The update failed because the topology changed + hipGraphExecUpdateErrorNodeTypeChanged = + 0x3, ///< The update failed because a node type changed + hipGraphExecUpdateErrorFunctionChanged = + 0x4, ///< The update failed because the function of a kernel node changed + hipGraphExecUpdateErrorParametersChanged = + 0x5, ///< The update failed because the parameters changed in a way that + ///< is not supported + hipGraphExecUpdateErrorNotSupported = + 0x6, ///< The update failed because something about the node is not + ///< supported + hipGraphExecUpdateErrorUnsupportedFunctionChange = 0x7 +} hipGraphExecUpdateResult; + +typedef enum hipStreamCaptureMode { + hipStreamCaptureModeGlobal = 0, + hipStreamCaptureModeThreadLocal, + hipStreamCaptureModeRelaxed +} hipStreamCaptureMode; +typedef enum hipStreamCaptureStatus { + hipStreamCaptureStatusNone = 0, ///< Stream is not capturing + hipStreamCaptureStatusActive, ///< Stream is actively capturing + hipStreamCaptureStatusInvalidated ///< Stream is part of a capture sequence + ///< that has been invalidated, but not + ///< terminated +} hipStreamCaptureStatus; + +typedef enum hipStreamUpdateCaptureDependenciesFlags { + hipStreamAddCaptureDependencies = 0, ///< Add new nodes to the dependency set + hipStreamSetCaptureDependencies, ///< Replace the dependency set with the new + ///< nodes +} hipStreamUpdateCaptureDependenciesFlags; + +typedef enum hipGraphMemAttributeType { + hipGraphMemAttrUsedMemCurrent = + 0, ///< Amount of memory, in bytes, currently associated with graphs + hipGraphMemAttrUsedMemHigh, ///< High watermark of memory, in bytes, + ///< associated with graphs since the last time. + hipGraphMemAttrReservedMemCurrent, ///< Amount of memory, in bytes, currently + ///< allocated for graphs. + hipGraphMemAttrReservedMemHigh, ///< High watermark of memory, in bytes, + ///< currently allocated for graphs +} hipGraphMemAttributeType; +typedef enum hipUserObjectFlags { + hipUserObjectNoDestructorSync = + 0x1, ///< Destructor execution is not synchronized. +} hipUserObjectFlags; + +typedef enum hipUserObjectRetainFlags { + hipGraphUserObjectMove = 0x1, ///< Add new reference or retain. +} hipUserObjectRetainFlags; + +typedef enum hipGraphInstantiateFlags { + hipGraphInstantiateFlagAutoFreeOnLaunch = + 1, ///< Automatically free memory allocated in a graph before relaunching. + hipGraphInstantiateFlagUpload = + 2, ///< Automatically upload the graph after instantiaton. + hipGraphInstantiateFlagDeviceLaunch = + 4, ///< Instantiate the graph to be launchable from the device. + hipGraphInstantiateFlagUseNodePriority = + 8, ///< Run the graph using the per-node priority attributes rather than + ///< the priority of the stream it is launched into. +} hipGraphInstantiateFlags; + +enum hipGraphDebugDotFlags { + hipGraphDebugDotFlagsVerbose = + 1 << 0, /**< Output all debug data as if every debug flag is enabled */ + hipGraphDebugDotFlagsKernelNodeParams = + 1 << 2, /**< Adds hipKernelNodeParams to output */ + hipGraphDebugDotFlagsMemcpyNodeParams = + 1 << 3, /**< Adds hipMemcpy3DParms to output */ + hipGraphDebugDotFlagsMemsetNodeParams = + 1 << 4, /**< Adds hipMemsetParams to output */ + hipGraphDebugDotFlagsHostNodeParams = + 1 << 5, /**< Adds hipHostNodeParams to output */ + hipGraphDebugDotFlagsEventNodeParams = + 1 + << 6, /**< Adds hipEvent_t handle from record and wait nodes to output */ + hipGraphDebugDotFlagsExtSemasSignalNodeParams = + 1 << 7, /**< Adds hipExternalSemaphoreSignalNodeParams values to output */ + hipGraphDebugDotFlagsExtSemasWaitNodeParams = + 1 << 8, /**< Adds hipExternalSemaphoreWaitNodeParams to output */ + hipGraphDebugDotFlagsKernelNodeAttributes = + 1 << 9, /**< Adds hipKernelNodeAttrID values to output */ + hipGraphDebugDotFlagsHandles = + 1 + << 10 /**< Adds node handles and every kernel function handle to output */ +}; + +/** + * hipGraphInstantiateWithParams results + */ +typedef enum hipGraphInstantiateResult { + hipGraphInstantiateSuccess = 0, /**< Instantiation Success */ + hipGraphInstantiateError = 1, /**< Instantiation failed for an + unexpected reason which is described in the return value of the function */ + hipGraphInstantiateInvalidStructure = 2, /**< Instantiation failed due + to invalid structure, such as cycles */ + hipGraphInstantiateNodeOperationNotSupported = 3, /**< Instantiation for + device launch failed because the graph contained an unsupported operation */ + hipGraphInstantiateMultipleDevicesNotSupported = 4, /**< Instantiation for + device launch failed due to the nodes belonging to different contexts */ +} hipGraphInstantiateResult; + +/** + * Graph Instantiation parameters + */ +typedef struct hipGraphInstantiateParams { + hipGraphNode_t + errNode_out; /**< The node which caused instantiation to fail, if any*/ + unsigned long long flags; /**< Instantiation flags */ + hipGraphInstantiateResult result_out; /**< Whether instantiation was + successful. If it failed, the reason why */ + hipStream_t uploadStream; /**< Upload stream */ +} hipGraphInstantiateParams; + +/** + * Memory allocation properties + */ +typedef struct hipMemAllocationProp { + hipMemAllocationType type; ///< Memory allocation type + hipMemAllocationHandleType requestedHandleType; ///< Requested handle type + hipMemLocation location; ///< Memory location + void *win32HandleMetaData; ///< Metadata for Win32 handles + struct { + unsigned char compressionType; ///< Compression type + unsigned char gpuDirectRDMACapable; ///< RDMA capable + unsigned short usage; ///< Usage + } allocFlags; +} hipMemAllocationProp; + +/** + * External semaphore signal node parameters + */ +typedef struct hipExternalSemaphoreSignalNodeParams { + ///< Array containing external semaphore handles. + hipExternalSemaphore_t *extSemArray; + ///< Array containing parameters of external signal semaphore. + const hipExternalSemaphoreSignalParams *paramsArray; + ///< Total number of handles and parameters contained in extSemArray and + ///< paramsArray. + unsigned int numExtSems; +} hipExternalSemaphoreSignalNodeParams; + +/** + * External semaphore wait node parameters + */ +typedef struct hipExternalSemaphoreWaitNodeParams { + ///< Array containing external semaphore handles. + hipExternalSemaphore_t *extSemArray; + ///< Array containing parameters of external wait semaphore. + const hipExternalSemaphoreWaitParams *paramsArray; + ///< Total number of handles and parameters contained in extSemArray and + ///< paramsArray. + unsigned int numExtSems; +} hipExternalSemaphoreWaitNodeParams; + +/** + * Generic handle for memory allocation + */ +typedef struct ihipMemGenericAllocationHandle *hipMemGenericAllocationHandle_t; + +/** + * Flags for granularity + */ +typedef enum hipMemAllocationGranularity_flags { + hipMemAllocationGranularityMinimum = 0x0, ///< Minimum granularity + hipMemAllocationGranularityRecommended = + 0x1 ///< Recommended granularity for performance +} hipMemAllocationGranularity_flags; + +/** + * Memory handle type + */ +typedef enum hipMemHandleType { + hipMemHandleTypeGeneric = 0x0 ///< Generic handle type +} hipMemHandleType; + +/** + * Memory operation types + */ +typedef enum hipMemOperationType { + hipMemOperationTypeMap = 0x1, ///< Map operation + hipMemOperationTypeUnmap = 0x2 ///< Unmap operation +} hipMemOperationType; + +/** + * Subresource types for sparse arrays + */ +typedef enum hipArraySparseSubresourceType { + hipArraySparseSubresourceTypeSparseLevel = 0x0, ///< Sparse level + hipArraySparseSubresourceTypeMiptail = 0x1 ///< Miptail +} hipArraySparseSubresourceType; + +/** + * Map info for arrays + */ +typedef struct hipArrayMapInfo { + hipResourceType resourceType; ///< Resource type + union { + hipMipmappedArray mipmap; + hipArray_t array; + } resource; + hipArraySparseSubresourceType subresourceType; ///< Sparse subresource type + union { + struct { + unsigned int level; ///< For mipmapped arrays must be a valid mipmap + ///< level. For arrays must be zero + unsigned int layer; ///< For layered arrays must be a valid layer index. + ///< Otherwise, must be zero + unsigned int offsetX; ///< X offset in elements + unsigned int offsetY; ///< Y offset in elements + unsigned int offsetZ; ///< Z offset in elements + unsigned int extentWidth; ///< Width in elements + unsigned int extentHeight; ///< Height in elements + unsigned int extentDepth; ///< Depth in elements + } sparseLevel; + struct { + unsigned int layer; ///< For layered arrays must be a valid layer index. + ///< Otherwise, must be zero + unsigned long long offset; ///< Offset within mip tail + unsigned long long size; ///< Extent in bytes + } miptail; + } subresource; + hipMemOperationType memOperationType; ///< Memory operation type + hipMemHandleType memHandleType; ///< Memory handle type + union { + hipMemGenericAllocationHandle_t memHandle; + } memHandle; + unsigned long long offset; ///< Offset within the memory + unsigned int deviceBitMask; ///< Device ordinal bit mask + unsigned int flags; ///< flags for future use, must be zero now. + unsigned int reserved[2]; ///< Reserved for future use, must be zero now. +} hipArrayMapInfo; + +/** + * Memcpy node params + */ +typedef struct hipMemcpyNodeParams { + int flags; ///< Must be zero. + int reserved[3]; ///< Must be zero. + hipMemcpy3DParms copyParams; ///< Params set for the memory copy. +} hipMemcpyNodeParams; + +/** + * Child graph node params + */ +typedef struct hipChildGraphNodeParams { + hipGraph_t + graph; ///< Either the child graph to clone into the node, or + ///< a handle to the graph possesed by the node used during query +} hipChildGraphNodeParams; + +/** + * Event record node params + */ +typedef struct hipEventWaitNodeParams { + hipEvent_t event; ///< Event to wait on +} hipEventWaitNodeParams; + +/** + * Event record node params + */ +typedef struct hipEventRecordNodeParams { + hipEvent_t event; ///< The event to be recorded when node executes +} hipEventRecordNodeParams; + +/** + * Memory free node params + */ +typedef struct hipMemFreeNodeParams { + void *dptr; ///< the pointer to be freed +} hipMemFreeNodeParams; + +/** + * Params for different graph nodes + */ +typedef struct hipGraphNodeParams { + hipGraphNodeType type; + int reserved0[3]; + union { + long long reserved1[29]; + hipKernelNodeParams kernel; + hipMemcpyNodeParams memcpy; + hipMemsetParams memset; + hipHostNodeParams host; + hipChildGraphNodeParams graph; + hipEventWaitNodeParams eventWait; + hipEventRecordNodeParams eventRecord; + hipExternalSemaphoreSignalNodeParams extSemSignal; + hipExternalSemaphoreWaitNodeParams extSemWait; + hipMemAllocNodeParams alloc; + hipMemFreeNodeParams free; + }; + + long long reserved2; +} hipGraphNodeParams; + +/** + * This port activates when the kernel has finished executing. + */ +#define hipGraphKernelNodePortDefault 0 + +/** + * This port activates when all blocks of the kernel have begun execution. + */ +#define hipGraphKernelNodePortLaunchCompletion 2 + +/** + * This port activates when all blocks of the kernel have performed + * hipTriggerProgrammaticLaunchCompletion() or have terminated. + * It must be used with edge type hipGraphDependencyTypeProgrammatic. + */ +#define hipGraphKernelNodePortProgrammatic 1 + +typedef enum hipGraphDependencyType { + hipGraphDependencyTypeDefault = 0, + hipGraphDependencyTypeProgrammatic = 1 +} hipGraphDependencyType; + +typedef struct hipGraphEdgeData { + unsigned char + from_port; ///< This indicates when the dependency is triggered from the + ///< upstream node on the edge. The meaning is specfic to the + ///< node type. A value of 0 in all cases means full completion + ///< of the upstream node, with memory visibility to the + ///< downstream node or portion thereof (indicated by to_port). + ///< Only kernel nodes define non-zero ports. A kernel node can + ///< use the following output port types: + ///< hipGraphKernelNodePortDefault, + ///< hipGraphKernelNodePortProgrammatic, or + ///< hipGraphKernelNodePortLaunchCompletion. + unsigned char reserved[5]; ///< These bytes are unused and must be zeroed + unsigned char to_port; ///< Currently no node types define non-zero ports. + ///< This field must be set to zero. + unsigned char type; ///< This should be populated with a value from + ///< hipGraphDependencyType +} hipGraphEdgeData; + +// Doxygen end group GlobalDefs +/** + * @} + */ +/** + * @defgroup API HIP API + * @{ + * + * Defines the HIP API. See the individual sections for more information. + */ +/** + * @defgroup Driver Initialization and Version + * @{ + * This section describes the initializtion and version functions of HIP + * runtime API. + * + */ +/** + * @brief Explicitly initializes the HIP runtime. + * + * @param [in] flags Initialization flag, should be zero. + * + * Most HIP APIs implicitly initialize the HIP runtime. + * This API provides control over the timing of the initialization. + * + * @returns #hipSuccess, #hipErrorInvalidValue + */ +// TODO-ctx - more description on error codes. +hipError_t hipInit(unsigned int flags); + +/** + * @brief Returns the approximate HIP driver version. + * + * @param [out] driverVersion driver version + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning The HIP feature set does not correspond to an exact CUDA SDK driver + * revision. This function always set *driverVersion to 4 as an approximation + * though HIP supports some features which were introduced in later CUDA SDK + * revisions. HIP apps code should not rely on the driver revision number here + * and should use arch feature flags to test device capabilities or conditional + * compilation. + * + * @see hipRuntimeGetVersion + */ +hipError_t hipDriverGetVersion(int *driverVersion); +/** + * @brief Returns the approximate HIP Runtime version. + * + * @param [out] runtimeVersion HIP runtime version + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning The version definition of HIP runtime is different from CUDA. + * On AMD platform, the function returns HIP runtime version, + * while on NVIDIA platform, it returns CUDA runtime version. + * And there is no mapping/correlation between HIP version and CUDA version. + * + * @see hipDriverGetVersion + */ +hipError_t hipRuntimeGetVersion(int *runtimeVersion); +/** + * @brief Returns a handle to a compute device + * @param [out] device Handle of device + * @param [in] ordinal Device ordinal + * + * @returns #hipSuccess, #hipErrorInvalidDevice + */ +hipError_t hipDeviceGet(hipDevice_t *device, int ordinal); + +/** + * @brief Returns the compute capability of the device + * @param [out] major Major compute capability version number + * @param [out] minor Minor compute capability version number + * @param [in] device Device ordinal + * + * @returns #hipSuccess, #hipErrorInvalidDevice + */ +hipError_t hipDeviceComputeCapability(int *major, int *minor, + hipDevice_t device); +/** + * @brief Returns an identifer string for the device. + * @param [out] name String of the device name + * @param [in] len Maximum length of string to store in device name + * @param [in] device Device ordinal + * + * @returns #hipSuccess, #hipErrorInvalidDevice + */ +hipError_t hipDeviceGetName(char *name, int len, hipDevice_t device); +/** + * @brief Returns an UUID for the device.[BETA] + * @param [out] uuid UUID for the device + * @param [in] device device ordinal + * + * @warning This API is marked as beta, meaning, while this is feature complete, + * it is still open to changes and may have outstanding issues. + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue, + * #hipErrorNotInitialized, #hipErrorDeinitialized + */ +hipError_t hipDeviceGetUuid(hipUUID *uuid, hipDevice_t device); +/** + * @brief Returns a value for attribute of link between two devices + * @param [out] value Pointer of the value for the attrubute + * @param [in] attr enum of hipDeviceP2PAttr to query + * @param [in] srcDevice The source device of the link + * @param [in] dstDevice The destination device of the link + * + * @returns #hipSuccess, #hipErrorInvalidDevice + */ +hipError_t hipDeviceGetP2PAttribute(int *value, hipDeviceP2PAttr attr, + int srcDevice, int dstDevice); +/** + * @brief Returns a PCI Bus Id string for the device, overloaded to take int + * device ID. + * @param [out] pciBusId The string of PCI Bus Id format for the device + * @param [in] len Maximum length of string + * @param [in] device The device ordinal + * + * @returns #hipSuccess, #hipErrorInvalidDevice + */ +hipError_t hipDeviceGetPCIBusId(char *pciBusId, int len, int device); +/** + * @brief Returns a handle to a compute device. + * @param [out] device The handle of the device + * @param [in] pciBusId The string of PCI Bus Id for the device + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + */ +hipError_t hipDeviceGetByPCIBusId(int *device, const char *pciBusId); +/** + * @brief Returns the total amount of memory on the device. + * @param [out] bytes The size of memory in bytes, on the device + * @param [in] device The ordinal of the device + * + * @returns #hipSuccess, #hipErrorInvalidDevice + */ +hipError_t hipDeviceTotalMem(size_t *bytes, hipDevice_t device); +// doxygen end initialization +/** + * @} + */ +/** + * @defgroup Device Device Management + * @{ + * This section describes the device management functions of HIP runtime API. + */ +/** + * @brief Waits on all active streams on current device + * + * When this command is invoked, the host thread gets blocked until all the + * commands associated with streams associated with the device. HIP does not + * support multiple blocking modes (yet!). + * + * @returns #hipSuccess + * + * @see hipSetDevice, hipDeviceReset + */ +hipError_t hipDeviceSynchronize(void); +/** + * @brief The state of current device is discarded and updated to a fresh state. + * + * Calling this function deletes all streams created, memory allocated, kernels + * running, events created. Make sure that no other thread is using the device + * or streams, memory, kernels, events associated with the current device. + * + * @returns #hipSuccess + * + * @see hipDeviceSynchronize + */ +hipError_t hipDeviceReset(void); +/** + * @brief Set default device to be used for subsequent hip API calls from this + * thread. + * + * @param[in] deviceId Valid device in range 0...hipGetDeviceCount(). + * + * Sets @p device as the default device for the calling host thread. Valid + * device id's are 0... (hipGetDeviceCount()-1). + * + * Many HIP APIs implicitly use the "default device" : + * + * - Any device memory subsequently allocated from this host thread (using + * hipMalloc) will be allocated on device. + * - Any streams or events created from this host thread will be associated with + * device. + * - Any kernels launched from this host thread (using hipLaunchKernel) will be + * executed on device (unless a specific stream is specified, in which case the + * device associated with that stream will be used). + * + * This function may be called from any host thread. Multiple host threads may + * use the same device. This function does no synchronization with the previous + * or new device, and has very little runtime overhead. Applications can use + * hipSetDevice to quickly switch the default device before making a HIP runtime + * call which uses the default device. + * + * The default device is stored in thread-local-storage for each thread. + * Thread-pool implementations may inherit the default device of the previous + * thread. A good practice is to always call hipSetDevice at the start of HIP + * coding sequency to establish a known standard device. + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorNoDevice + * + * @see #hipGetDevice, #hipGetDeviceCount + */ +hipError_t hipSetDevice(int deviceId); +/** + * @brief Set a list of devices that can be used. + * + * @param[in] device_arr List of devices to try + * @param[in] len Number of devices in specified list + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @see #hipGetDevice, #hipGetDeviceCount. #hipSetDevice. + * #hipGetDeviceProperties. #hipSetDeviceFlags. #hipChooseDevice + * + * */ +hipError_t hipSetValidDevices(int *device_arr, int len); +/** + * @brief Return the default device id for the calling host thread. + * + * @param [out] deviceId *device is written with the default device + * + * HIP maintains an default device for each thread using thread-local-storage. + * This device is used implicitly for HIP runtime APIs called by this thread. + * hipGetDevice returns in * @p device the default device for the calling host + * thread. + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @see hipSetDevice, hipGetDevicesizeBytes + */ +hipError_t hipGetDevice(int *deviceId); +/** + * @brief Return number of compute-capable devices. + * + * @param [out] count Returns number of compute-capable devices. + * + * @returns #hipSuccess, #hipErrorNoDevice + * + * + * Returns in @p *count the number of devices that have ability to run compute + * commands. If there are no such devices, then @ref hipGetDeviceCount will + * return #hipErrorNoDevice. If 1 or more devices can be found, then + * hipGetDeviceCount returns #hipSuccess. + */ +hipError_t hipGetDeviceCount(int *count); +/** + * @brief Query for a specific device attribute. + * + * @param [out] pi pointer to value to return + * @param [in] attr attribute to query + * @param [in] deviceId which device to query for information + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + */ +hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attr, + int deviceId); +/** + * @brief Returns the default memory pool of the specified device + * + * @param [out] mem_pool Default memory pool to return + * @param [in] device Device index for query the default memory pool + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue, + * #hipErrorNotSupported + * + * @see hipDeviceGetDefaultMemPool, hipMallocAsync, hipMemPoolTrimTo, + * hipMemPoolGetAttribute, hipDeviceSetMemPool, hipMemPoolSetAttribute, + * hipMemPoolSetAccess, hipMemPoolGetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipDeviceGetDefaultMemPool(hipMemPool_t *mem_pool, int device); +/** + * @brief Sets the current memory pool of a device + * + * The memory pool must be local to the specified device. + * @p hipMallocAsync allocates from the current mempool of the provided stream's + * device. By default, a device's current memory pool is its default memory + * pool. + * + * @note Use @p hipMallocFromPoolAsync for asynchronous memory allocations from + * a device different than the one the stream runs on. + * + * @param [in] device Device index for the update + * @param [in] mem_pool Memory pool for update as the current on the specified + * device + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidDevice, + * #hipErrorNotSupported + * + * @see hipDeviceGetDefaultMemPool, hipMallocAsync, hipMemPoolTrimTo, + * hipMemPoolGetAttribute, hipDeviceSetMemPool, hipMemPoolSetAttribute, + * hipMemPoolSetAccess, hipMemPoolGetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipDeviceSetMemPool(int device, hipMemPool_t mem_pool); +/** + * @brief Gets the current memory pool for the specified device + * + * Returns the last pool provided to @p hipDeviceSetMemPool for this device + * or the device's default memory pool if @p hipDeviceSetMemPool has never been + * called. By default the current mempool is the default mempool for a device, + * otherwise the returned pool must have been set with @p hipDeviceSetMemPool. + * + * @param [out] mem_pool Current memory pool on the specified device + * @param [in] device Device index to query the current memory pool + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @see hipDeviceGetDefaultMemPool, hipMallocAsync, hipMemPoolTrimTo, + * hipMemPoolGetAttribute, hipDeviceSetMemPool, hipMemPoolSetAttribute, + * hipMemPoolSetAccess, hipMemPoolGetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipDeviceGetMemPool(hipMemPool_t *mem_pool, int device); +/** + * @brief Returns device properties. + * + * @param [out] prop written with device properties + * @param [in] deviceId which device to query for information + * + * @return #hipSuccess, #hipErrorInvalidDevice + * @bug HCC always returns 0 for maxThreadsPerMultiProcessor + * @bug HCC always returns 0 for regsPerBlock + * @bug HCC always returns 0 for l2CacheSize + * + * Populates hipGetDeviceProperties with information for the specified device. + */ +hipError_t hipGetDeviceProperties(hipDeviceProp_t *prop, int deviceId); +/** + * @brief Set L1/Shared cache partition. + * + * @param [in] cacheConfig Cache configuration + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorNotSupported + * + * Note: AMD devices do not support reconfigurable cache. This API is not + * implemented on AMD platform. If the function is called, it will return + * hipErrorNotSupported. + * + */ +hipError_t hipDeviceSetCacheConfig(hipFuncCache_t cacheConfig); +/** + * @brief Get Cache configuration for a specific Device + * + * @param [out] cacheConfig Pointer of cache configuration + * + * @returns #hipSuccess, #hipErrorNotInitialized + * Note: AMD devices do not support reconfigurable cache. This hint is ignored + * on these architectures. + * + */ +hipError_t hipDeviceGetCacheConfig(hipFuncCache_t *cacheConfig); +/** + * @brief Gets resource limits of current device + * + * The function queries the size of limit value, as required by the input enum + * value hipLimit_t, which can be either #hipLimitStackSize, or + * #hipLimitMallocHeapSize. Any other input as default, the function will return + * #hipErrorUnsupportedLimit. + * + * @param [out] pValue Returns the size of the limit in bytes + * @param [in] limit The limit to query + * + * @returns #hipSuccess, #hipErrorUnsupportedLimit, #hipErrorInvalidValue + * + */ +hipError_t hipDeviceGetLimit(size_t *pValue, enum hipLimit_t limit); +/** + * @brief Sets resource limits of current device. + * + * As the input enum limit, + * #hipLimitStackSize sets the limit value of the stack size on the current GPU + * device, per thread. The limit size can get via hipDeviceGetLimit. The size is + * in units of 256 dwords, up to the limit (128K - 16). + * + * #hipLimitMallocHeapSize sets the limit value of the heap used by the + * malloc()/free() calls. For limit size, use the #hipDeviceGetLimit API. + * + * Any other input as default, the funtion will return hipErrorUnsupportedLimit. + * + * @param [in] limit Enum of hipLimit_t to set + * @param [in] value The size of limit value in bytes + * + * @returns #hipSuccess, #hipErrorUnsupportedLimit, #hipErrorInvalidValue + * + */ +hipError_t hipDeviceSetLimit(enum hipLimit_t limit, size_t value); +/** + * @brief Returns bank width of shared memory for current device + * + * @param [out] pConfig The pointer of the bank width for shared memory + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized + * + * Note: AMD devices and some Nvidia GPUS do not support shared cache banking, + * and the hint is ignored on those architectures. + * + */ +hipError_t hipDeviceGetSharedMemConfig(hipSharedMemConfig *pConfig); +/** + * @brief Gets the flags set for current device + * + * @param [out] flags Pointer of the flags + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + */ +hipError_t hipGetDeviceFlags(unsigned int *flags); +/** + * @brief The bank width of shared memory on current device is set + * + * @param [in] config Configuration for the bank width of shared memory + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized + * + * Note: AMD devices and some Nvidia GPUS do not support shared cache banking, + * and the hint is ignored on those architectures. + * + */ +hipError_t hipDeviceSetSharedMemConfig(hipSharedMemConfig config); +/** + * @brief The current device behavior is changed according the flags passed. + * + * @param [in] flags Flag to set on the current device + * + * The schedule flags impact how HIP waits for the completion of a command + * running on a device. hipDeviceScheduleSpin : HIP runtime will + * actively spin in the thread which submitted the work until the command + * completes. This offers the lowest latency, but will consume a CPU core and + * may increase power. hipDeviceScheduleYield : The HIP runtime will + * yield the CPU to system so that other tasks can use it. This may increase + * latency to detect the completion but will consume less power and is + * friendlier to other tasks in the system. hipDeviceScheduleBlockingSync : On + * ROCm platform, this is a synonym for hipDeviceScheduleYield. + * hipDeviceScheduleAuto : Use a hueristic to select between Spin and + * Yield modes. If the number of HIP contexts is greater than the number of + * logical processors in the system, use Spin scheduling. Else use Yield + * scheduling. + * + * + * hipDeviceMapHost : Allow mapping host memory. On ROCM, this is + * always allowed and the flag is ignored. hipDeviceLmemResizeToMax : + * @warning ROCm silently ignores this flag. + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorSetOnActiveProcess + * + * + */ +hipError_t hipSetDeviceFlags(unsigned flags); +/** + * @brief Device which matches hipDeviceProp_t is returned + * + * @param [out] device Pointer of the device + * @param [in] prop Pointer of the properties + * + * @returns #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipChooseDevice(int *device, const hipDeviceProp_t *prop); +/** + * @brief Returns the link type and hop count between two devices + * + * @param [in] device1 Ordinal for device1 + * @param [in] device2 Ordinal for device2 + * @param [out] linktype Returns the link type (See hsa_amd_link_info_type_t) + * between the two devices + * @param [out] hopcount Returns the hop count between the two devices + * + * Queries and returns the HSA link type and the hop count between the two + * specified devices. + * + * @returns #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipExtGetLinkTypeAndHopCount(int device1, int device2, + uint32_t *linktype, uint32_t *hopcount); +// TODO: implement IPC apis +/** + * @brief Gets an interprocess memory handle for an existing device memory + * allocation + * + * Takes a pointer to the base of an existing device memory allocation created + * with hipMalloc and exports it for use in another process. This is a + * lightweight operation and may be called multiple times on an allocation + * without adverse effects. + * + * If a region of memory is freed with hipFree and a subsequent call + * to hipMalloc returns memory with the same device address, + * hipIpcGetMemHandle will return a unique handle for the + * new memory. + * + * @param handle - Pointer to user allocated hipIpcMemHandle to return + * the handle in. + * @param devPtr - Base pointer to previously allocated device memory + * + * @returns #hipSuccess, #hipErrorInvalidHandle, #hipErrorOutOfMemory, + * #hipErrorMapFailed + * + * @note This IPC memory related feature API on Windows may behave differently + * from Linux. + * + */ +hipError_t hipIpcGetMemHandle(hipIpcMemHandle_t *handle, void *devPtr); +/** + * @brief Opens an interprocess memory handle exported from another process + * and returns a device pointer usable in the local process. + * + * Maps memory exported from another process with hipIpcGetMemHandle into + * the current device address space. For contexts on different devices + * hipIpcOpenMemHandle can attempt to enable peer access between the + * devices as if the user called hipDeviceEnablePeerAccess. This behavior is + * controlled by the hipIpcMemLazyEnablePeerAccess flag. + * hipDeviceCanAccessPeer can determine if a mapping is possible. + * + * Contexts that may open hipIpcMemHandles are restricted in the following way. + * hipIpcMemHandles from each device in a given process may only be opened + * by one context per device per other process. + * + * Memory returned from hipIpcOpenMemHandle must be freed with + * hipIpcCloseMemHandle. + * + * Calling hipFree on an exported memory region before calling + * hipIpcCloseMemHandle in the importing context will result in undefined + * behavior. + * + * @param devPtr - Returned device pointer + * @param handle - hipIpcMemHandle to open + * @param flags - Flags for this operation. Must be specified as + * hipIpcMemLazyEnablePeerAccess + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidContext, + * #hipErrorInvalidDevicePointer + * + * @note During multiple processes, using the same memory handle opened by the + * current context, there is no guarantee that the same device poiter will be + * returned in @p *devPtr. This is diffrent from CUDA. + * @note This IPC memory related feature API on Windows may behave differently + * from Linux. + * + */ +hipError_t hipIpcOpenMemHandle(void **devPtr, hipIpcMemHandle_t handle, + unsigned int flags); +/** + * @brief Close memory mapped with hipIpcOpenMemHandle + * + * Unmaps memory returnd by hipIpcOpenMemHandle. The original allocation + * in the exporting process as well as imported mappings in other processes + * will be unaffected. + * + * Any resources used to enable peer access will be freed if this is the + * last mapping using them. + * + * @param devPtr - Device pointer returned by hipIpcOpenMemHandle + * + * @returns #hipSuccess, #hipErrorMapFailed, #hipErrorInvalidHandle + * + * @note This IPC memory related feature API on Windows may behave differently + * from Linux. + * + */ +hipError_t hipIpcCloseMemHandle(void *devPtr); + +/** + * @brief Gets an opaque interprocess handle for an event. + * + * This opaque handle may be copied into other processes and opened with + * hipIpcOpenEventHandle. Then hipEventRecord, hipEventSynchronize, + * hipStreamWaitEvent and hipEventQuery may be used in either process. + * Operations on the imported event after the exported event has been freed with + * hipEventDestroy will result in undefined behavior. + * + * @param[out] handle Pointer to hipIpcEventHandle to return the opaque event + * handle + * @param[in] event Event allocated with hipEventInterprocess and + * hipEventDisableTiming flags + * + * @returns #hipSuccess, #hipErrorInvalidConfiguration, #hipErrorInvalidValue + * + * @note This IPC event related feature API is currently applicable on Linux. + * + */ +hipError_t hipIpcGetEventHandle(hipIpcEventHandle_t *handle, hipEvent_t event); + +/** + * @brief Opens an interprocess event handles. + * + * Opens an interprocess event handle exported from another process with + * hipIpcGetEventHandle. The returned hipEvent_t behaves like a locally created + * event with the hipEventDisableTiming flag specified. This event need be freed + * with hipEventDestroy. Operations on the imported event after the exported + * event has been freed with hipEventDestroy will result in undefined behavior. + * If the function is called within the same process where handle is returned by + * hipIpcGetEventHandle, it will return hipErrorInvalidContext. + * + * @param[out] event Pointer to hipEvent_t to return the event + * @param[in] handle The opaque interprocess handle to open + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidContext + * + * @note This IPC event related feature API is currently applicable on Linux. + * + */ +hipError_t hipIpcOpenEventHandle(hipEvent_t *event, hipIpcEventHandle_t handle); + +// end doxygen Device +/** + * @} + */ +/** + * + * @defgroup Execution Execution Control + * @{ + * This section describes the execution control functions of HIP runtime API. + * + */ +/** + * @brief Set attribute for a specific function + * + * @param [in] func Pointer of the function + * @param [in] attr Attribute to set + * @param [in] value Value to set + * + * @returns #hipSuccess, #hipErrorInvalidDeviceFunction, #hipErrorInvalidValue + * + * Note: AMD devices and some Nvidia GPUS do not support shared cache banking, + * and the hint is ignored on those architectures. + * + */ +hipError_t hipFuncSetAttribute(const void *func, hipFuncAttribute attr, + int value); +/** + * @brief Set Cache configuration for a specific function + * + * @param [in] func Pointer of the function. + * @param [in] config Configuration to set. + * + * @returns #hipSuccess, #hipErrorNotInitialized + * Note: AMD devices and some Nvidia GPUS do not support reconfigurable cache. + * This hint is ignored on those architectures. + * + */ +hipError_t hipFuncSetCacheConfig(const void *func, hipFuncCache_t config); +/** + * @brief Set shared memory configuation for a specific function + * + * @param [in] func Pointer of the function + * @param [in] config Configuration + * + * @returns #hipSuccess, #hipErrorInvalidDeviceFunction, #hipErrorInvalidValue + * + * Note: AMD devices and some Nvidia GPUS do not support shared cache banking, + * and the hint is ignored on those architectures. + * + */ +hipError_t hipFuncSetSharedMemConfig(const void *func, + hipSharedMemConfig config); +// doxygen end execution +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Error Error Handling + * @{ + * This section describes the error handling functions of HIP runtime API. + */ +/** + * @brief Return last error returned by any HIP runtime API call and resets the + * stored error code to #hipSuccess + * + * @returns return code from last HIP called from the active host thread + * + * Returns the last error that has been returned by any of the runtime calls in + * the same host thread, and then resets the saved error to #hipSuccess. + * + * @see hipGetErrorString, hipGetLastError, hipPeakAtLastError, hipError_t + */ +hipError_t hipGetLastError(void); + +/** + * @brief Return last error returned by any HIP runtime API call and resets the + * stored error code to #hipSuccess + * + * @returns return code from last HIP called from the active host thread + * + * Returns the last error that has been returned by any of the runtime calls in + * the same host thread, and then resets the saved error to #hipSuccess. + * + * @see hipGetErrorString, hipGetLastError, hipPeakAtLastError, hipError_t + */ +hipError_t hipExtGetLastError(void); + +/** + * @brief Return last error returned by any HIP runtime API call. + * + * @return #hipSuccess + * + * Returns the last error that has been returned by any of the runtime calls in + * the same host thread. Unlike hipGetLastError, this function does not reset + * the saved error code. + * + * @see hipGetErrorString, hipGetLastError, hipPeakAtLastError, hipError_t + */ +hipError_t hipPeekAtLastError(void); +/** + * @brief Return hip error as text string form. + * + * @param hip_error Error code to convert to name. + * @return const char pointer to the NULL-terminated error name + * + * @see hipGetErrorString, hipGetLastError, hipPeakAtLastError, hipError_t + */ +const char *hipGetErrorName(hipError_t hip_error); +/** + * @brief Return handy text string message to explain the error which occurred + * + * @param hipError Error code to convert to string. + * @return const char pointer to the NULL-terminated error string + * + * @see hipGetErrorName, hipGetLastError, hipPeakAtLastError, hipError_t + */ +const char *hipGetErrorString(hipError_t hipError); +/** + * @brief Return hip error as text string form. + * + * @param [in] hipError Error code to convert to string. + * @param [out] errorString char pointer to the NULL-terminated error string + * @return #hipSuccess, #hipErrorInvalidValue + * + * @see hipGetErrorName, hipGetLastError, hipPeakAtLastError, hipError_t + */ +hipError_t hipDrvGetErrorName(hipError_t hipError, const char **errorString); +/** + * @brief Return handy text string message to explain the error which occurred + * + * @param [in] hipError Error code to convert to string. + * @param [out] errorString char pointer to the NULL-terminated error string + * @return #hipSuccess, #hipErrorInvalidValue + * + * @see hipGetErrorName, hipGetLastError, hipPeakAtLastError, hipError_t + */ +hipError_t hipDrvGetErrorString(hipError_t hipError, const char **errorString); +// end doxygen Error +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Stream Stream Management + * @{ + * This section describes the stream management functions of HIP runtime API. + * The following Stream APIs are not (yet) supported in HIP: + * - hipStreamAttachMemAsync is a nop + */ + +/** + * @brief Create an asynchronous stream. + * + * @param[in, out] stream Valid pointer to hipStream_t. This function writes + * the memory with the newly created stream. + * @return #hipSuccess, #hipErrorInvalidValue + * + * Create a new asynchronous stream. @p stream returns an opaque handle that + * can be used to reference the newly created stream in subsequent hipStream* + * commands. The stream is allocated on the heap and will remain allocated even + * if the handle goes out-of-scope. To release the memory used by the stream, + * application must call hipStreamDestroy. + * + * @return #hipSuccess, #hipErrorInvalidValue + * + * @see hipStreamCreateWithFlags, hipStreamCreateWithPriority, + * hipStreamSynchronize, hipStreamWaitEvent, hipStreamDestroy + */ +hipError_t hipStreamCreate(hipStream_t *stream); +/** + * @brief Create an asynchronous stream. + * + * @param[in, out] stream Pointer to new stream + * @param[in ] flags to control stream creation. + * @return #hipSuccess, #hipErrorInvalidValue + * + * Create a new asynchronous stream. @p stream returns an opaque handle that + * can be used to reference the newly created stream in subsequent hipStream* + * commands. The stream is allocated on the heap and will remain allocated even + * if the handle goes out-of-scope. To release the memory used by the stream, + * application must call hipStreamDestroy. Flags controls behavior of the + * stream. See #hipStreamDefault, #hipStreamNonBlocking. + * + * + * @see hipStreamCreate, hipStreamCreateWithPriority, hipStreamSynchronize, + * hipStreamWaitEvent, hipStreamDestroy + */ +hipError_t hipStreamCreateWithFlags(hipStream_t *stream, unsigned int flags); +/** + * @brief Create an asynchronous stream with the specified priority. + * + * @param[in, out] stream Pointer to new stream + * @param[in ] flags to control stream creation. + * @param[in ] priority of the stream. Lower numbers represent higher + * priorities. + * @return #hipSuccess, #hipErrorInvalidValue + * + * Create a new asynchronous stream with the specified priority. @p stream + * returns an opaque handle that can be used to reference the newly created + * stream in subsequent hipStream* commands. The stream is allocated on the + * heap and will remain allocated even if the handle goes out-of-scope. To + * release the memory used by the stream, application must call + * hipStreamDestroy. Flags controls behavior of the stream. See + * #hipStreamDefault, #hipStreamNonBlocking. + * + * + * @see hipStreamCreate, hipStreamSynchronize, hipStreamWaitEvent, + * hipStreamDestroy + */ +hipError_t hipStreamCreateWithPriority(hipStream_t *stream, unsigned int flags, + int priority); +/** + * @brief Returns numerical values that correspond to the least and greatest + * stream priority. + * + * @param[in, out] leastPriority pointer in which value corresponding to least + * priority is returned. + * @param[in, out] greatestPriority pointer in which value corresponding to + * greatest priority is returned. + * @returns #hipSuccess + * + * Returns in *leastPriority and *greatestPriority the numerical values that + * correspond to the least and greatest stream priority respectively. Stream + * priorities follow a convention where lower numbers imply greater priorities. + * The range of meaningful stream priorities is given by + * [*greatestPriority, *leastPriority]. If the user attempts to create a stream + * with a priority value that is outside the meaningful range as specified by + * this API, the priority is automatically clamped to within the valid range. + */ +hipError_t hipDeviceGetStreamPriorityRange(int *leastPriority, + int *greatestPriority); +/** + * @brief Destroys the specified stream. + * + * @param[in] stream stream identifier. + * @return #hipSuccess #hipErrorInvalidHandle + * + * Destroys the specified stream. + * + * If commands are still executing on the specified stream, some may complete + * execution before the queue is deleted. + * + * The queue may be destroyed while some commands are still inflight, or may + * wait for all commands queued to the stream before destroying it. + * + * @see hipStreamCreate, hipStreamCreateWithFlags, hipStreamCreateWithPriority, + * hipStreamQuery, hipStreamWaitEvent, hipStreamSynchronize + */ +hipError_t hipStreamDestroy(hipStream_t stream); +/** + * @brief Return #hipSuccess if all of the operations in the specified @p stream + * have completed, or #hipErrorNotReady if not. + * + * @param[in] stream stream to query + * + * @return #hipSuccess, #hipErrorNotReady, #hipErrorInvalidHandle + * + * This is thread-safe and returns a snapshot of the current state of the queue. + * However, if other host threads are sending work to the stream, the status may + * change immediately after the function is called. It is typically used for + * debug. + * + * @see hipStreamCreate, hipStreamCreateWithFlags, hipStreamCreateWithPriority, + * hipStreamWaitEvent, hipStreamSynchronize, hipStreamDestroy + */ +hipError_t hipStreamQuery(hipStream_t stream); +/** + * @brief Wait for all commands in stream to complete. + * + * @param[in] stream stream identifier. + * + * @return #hipSuccess, #hipErrorInvalidHandle + * + * This command is host-synchronous : the host will block until the specified + * stream is empty. + * + * This command follows standard null-stream semantics. Specifically, + * specifying the null stream will cause the command to wait for other streams + * on the same device to complete all pending operations. + * + * This command honors the hipDeviceLaunchBlocking flag, which controls whether + * the wait is active or blocking. + * + * @see hipStreamCreate, hipStreamCreateWithFlags, hipStreamCreateWithPriority, + * hipStreamWaitEvent, hipStreamDestroy + * + */ +hipError_t hipStreamSynchronize(hipStream_t stream); +/** + * @brief Make the specified compute stream wait for an event + * + * @param[in] stream stream to make wait. + * @param[in] event event to wait on + * @param[in] flags control operation [must be 0] + * + * @return #hipSuccess, #hipErrorInvalidHandle + * + * This function inserts a wait operation into the specified stream. + * All future work submitted to @p stream will wait until @p event reports + * completion before beginning execution. + * + * This function only waits for commands in the current stream to complete. + * Notably, this function does not implicitly wait for commands in the default + * stream to complete, even if the specified stream is created with + * hipStreamNonBlocking = 0. + * + * @see hipStreamCreate, hipStreamCreateWithFlags, hipStreamCreateWithPriority, + * hipStreamSynchronize, hipStreamDestroy + */ +hipError_t hipStreamWaitEvent(hipStream_t stream, hipEvent_t event, + unsigned int flags __dparm(0)); +/** + * @brief Return flags associated with this stream. + * + * @param[in] stream stream to be queried + * @param[in,out] flags Pointer to an unsigned integer in which the stream's + * flags are returned + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidHandle + * + * @returns #hipSuccess #hipErrorInvalidValue #hipErrorInvalidHandle + * + * Return flags associated with this stream in *@p flags. + * + * @see hipStreamCreateWithFlags + */ +hipError_t hipStreamGetFlags(hipStream_t stream, unsigned int *flags); +/** + * @brief Query the priority of a stream. + * + * @param[in] stream stream to be queried + * @param[in,out] priority Pointer to an unsigned integer in which the stream's + * priority is returned + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidHandle + * + * @returns #hipSuccess #hipErrorInvalidValue #hipErrorInvalidHandle + * + * Query the priority of a stream. The priority is returned in in priority. + * + * @see hipStreamCreateWithFlags + */ +hipError_t hipStreamGetPriority(hipStream_t stream, int *priority); +/** + * @brief Get the device assocaited with the stream + * + * @param[in] stream stream to be queried + * @param[out] device device associated with the stream + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorContextIsDestroyed, + * #hipErrorInvalidHandle, #hipErrorNotInitialized, #hipErrorDeinitialized, + * #hipErrorInvalidContext + * + * @see hipStreamCreate, hipStreamDestroy, hipDeviceGetStreamPriorityRange + */ +hipError_t hipStreamGetDevice(hipStream_t stream, hipDevice_t *device); +/** + * @brief Create an asynchronous stream with the specified CU mask. + * + * @param[in, out] stream Pointer to new stream + * @param[in ] cuMaskSize Size of CU mask bit array passed in. + * @param[in ] cuMask Bit-vector representing the CU mask. Each active bit + * represents using one CU. The first 32 bits represent the first 32 CUs, and so + * on. If its size is greater than physical CU number (i.e., multiProcessorCount + * member of hipDeviceProp_t), the extra elements are ignored. It is user's + * responsibility to make sure the input is meaningful. + * @return #hipSuccess, #hipErrorInvalidHandle, #hipErrorInvalidValue + * + * Create a new asynchronous stream with the specified CU mask. @p stream + * returns an opaque handle that can be used to reference the newly created + * stream in subsequent hipStream* commands. The stream is allocated on the + * heap and will remain allocated even if the handle goes out-of-scope. To + * release the memory used by the stream, application must call + * hipStreamDestroy. + * + * + * @see hipStreamCreate, hipStreamSynchronize, hipStreamWaitEvent, + * hipStreamDestroy + */ +hipError_t hipExtStreamCreateWithCUMask(hipStream_t *stream, + uint32_t cuMaskSize, + const uint32_t *cuMask); +/** + * @brief Get CU mask associated with an asynchronous stream + * + * @param[in] stream stream to be queried + * @param[in] cuMaskSize number of the block of memories (uint32_t *) allocated + * by user + * @param[out] cuMask Pointer to a pre-allocated block of memories (uint32_t *) + * in which the stream's CU mask is returned. The CU mask is returned in a + * chunck of 32 bits where each active bit represents one active CU + * @return #hipSuccess, #hipErrorInvalidHandle, #hipErrorInvalidValue + * + * @see hipStreamCreate, hipStreamSynchronize, hipStreamWaitEvent, + * hipStreamDestroy + */ +hipError_t hipExtStreamGetCUMask(hipStream_t stream, uint32_t cuMaskSize, + uint32_t *cuMask); +/** + * Stream CallBack struct + */ +typedef void (*hipStreamCallback_t)(hipStream_t stream, hipError_t status, + void *userData); +/** + * @brief Adds a callback to be called on the host after all currently enqueued + * items in the stream have completed. For each + * hipStreamAddCallback call, a callback will be executed exactly once. + * The callback will block later work in the stream until it is finished. + * @param[in] stream - Stream to add callback to + * @param[in] callback - The function to call once preceding stream operations + * are complete + * @param[in] userData - User specified data to be passed to the callback + * function + * @param[in] flags - Reserved for future use, must be 0 + * @return #hipSuccess, #hipErrorInvalidHandle, #hipErrorNotSupported + * + * @see hipStreamCreate, hipStreamCreateWithFlags, hipStreamQuery, + * hipStreamSynchronize, hipStreamWaitEvent, hipStreamDestroy, + * hipStreamCreateWithPriority + * + */ +hipError_t hipStreamAddCallback(hipStream_t stream, + hipStreamCallback_t callback, void *userData, + unsigned int flags); +// end doxygen Stream +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup StreamM Stream Memory Operations + * @{ + * This section describes Stream Memory Wait and Write functions of HIP runtime + *API. + */ +/** + * @brief Enqueues a wait command to the stream.[BETA] + * + * @param [in] stream - Stream identifier + * @param [in] ptr - Pointer to memory object allocated using + * 'hipMallocSignalMemory' flag + * @param [in] value - Value to be used in compare operation + * @param [in] flags - Defines the compare operation, supported values are + * hipStreamWaitValueGte hipStreamWaitValueEq, hipStreamWaitValueAnd and + * hipStreamWaitValueNor + * @param [in] mask - Mask to be applied on value at memory before it is + * compared with value, default value is set to enable every bit + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * Enqueues a wait command to the stream, all operations enqueued on this + * stream after this, will not execute until the defined wait condition is true. + * + * hipStreamWaitValueGte: waits until *ptr&mask >= value + * hipStreamWaitValueEq : waits until *ptr&mask == value + * hipStreamWaitValueAnd: waits until ((*ptr&mask) & value) != 0 + * hipStreamWaitValueNor: waits until ~((*ptr&mask) | (value&mask)) != 0 + * + * @note when using 'hipStreamWaitValueNor', mask is applied on both 'value' and + * '*ptr'. + * + * @note Support for hipStreamWaitValue32 can be queried using + * 'hipDeviceGetAttribute()' and 'hipDeviceAttributeCanUseStreamWaitValue' flag. + * + * @warning This API is marked as beta, meaning, while this is feature complete, + * it is still open to changes and may have outstanding issues. + * + * @see hipExtMallocWithFlags, hipFree, hipStreamWaitValue64, + * hipStreamWriteValue64, hipStreamWriteValue32, hipDeviceGetAttribute + */ +hipError_t hipStreamWaitValue32(hipStream_t stream, void *ptr, uint32_t value, + unsigned int flags, + uint32_t mask __dparm(0xFFFFFFFF)); +/** + * @brief Enqueues a wait command to the stream.[BETA] + * + * @param [in] stream - Stream identifier + * @param [in] ptr - Pointer to memory object allocated using + * 'hipMallocSignalMemory' flag + * @param [in] value - Value to be used in compare operation + * @param [in] flags - Defines the compare operation, supported values are + * hipStreamWaitValueGte hipStreamWaitValueEq, hipStreamWaitValueAnd and + * hipStreamWaitValueNor. + * @param [in] mask - Mask to be applied on value at memory before it is + * compared with value default value is set to enable every bit + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * Enqueues a wait command to the stream, all operations enqueued on this + * stream after this, will not execute until the defined wait condition is true. + * + * hipStreamWaitValueGte: waits until *ptr&mask >= value + * hipStreamWaitValueEq : waits until *ptr&mask == value + * hipStreamWaitValueAnd: waits until ((*ptr&mask) & value) != 0 + * hipStreamWaitValueNor: waits until ~((*ptr&mask) | (value&mask)) != 0 + * + * @note when using 'hipStreamWaitValueNor', mask is applied on both 'value' and + * '*ptr'. + * + * @note Support for hipStreamWaitValue64 can be queried using + * 'hipDeviceGetAttribute()' and 'hipDeviceAttributeCanUseStreamWaitValue' flag. + * + * @warning This API is marked as beta, meaning, while this is feature complete, + * it is still open to changes and may have outstanding issues. + * + * @see hipExtMallocWithFlags, hipFree, hipStreamWaitValue32, + * hipStreamWriteValue64, hipStreamWriteValue32, hipDeviceGetAttribute + */ +hipError_t hipStreamWaitValue64(hipStream_t stream, void *ptr, uint64_t value, + unsigned int flags, + uint64_t mask __dparm(0xFFFFFFFFFFFFFFFF)); +/** + * @brief Enqueues a write command to the stream.[BETA] + * + * @param [in] stream - Stream identifier + * @param [in] ptr - Pointer to a GPU accessible memory object + * @param [in] value - Value to be written + * @param [in] flags - reserved, ignored for now, will be used in future + * releases + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * Enqueues a write command to the stream, write operation is performed after + * all earlier commands on this stream have completed the execution. + * + * @warning This API is marked as beta, meaning, while this is feature complete, + * it is still open to changes and may have outstanding issues. + * + * @see hipExtMallocWithFlags, hipFree, hipStreamWriteValue32, + * hipStreamWaitValue32, hipStreamWaitValue64 + */ +hipError_t hipStreamWriteValue32(hipStream_t stream, void *ptr, uint32_t value, + unsigned int flags); +/** + * @brief Enqueues a write command to the stream.[BETA] + * + * @param [in] stream - Stream identifier + * @param [in] ptr - Pointer to a GPU accessible memory object + * @param [in] value - Value to be written + * @param [in] flags - reserved, ignored for now, will be used in future + * releases + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * Enqueues a write command to the stream, write operation is performed after + * all earlier commands on this stream have completed the execution. + * + * @warning This API is marked as beta, meaning, while this is feature complete, + * it is still open to changes and may have outstanding issues. + * + * @see hipExtMallocWithFlags, hipFree, hipStreamWriteValue32, + * hipStreamWaitValue32, hipStreamWaitValue64 + */ +hipError_t hipStreamWriteValue64(hipStream_t stream, void *ptr, uint64_t value, + unsigned int flags); +// end doxygen Stream Memory Operations +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Event Event Management + * @{ + * This section describes the event management functions of HIP runtime API. + */ +/** + * @brief Create an event with the specified flags + * + * @param[in,out] event Returns the newly created event. + * @param[in] flags Flags to control event behavior. Valid values are + #hipEventDefault, #hipEventBlockingSync, #hipEventDisableTiming, + #hipEventInterprocess + * #hipEventDefault : Default flag. The event will use active synchronization + and will support timing. Blocking synchronization provides lowest possible + latency at the expense of dedicating a CPU to poll on the event. + * #hipEventBlockingSync : The event will use blocking synchronization : if + hipEventSynchronize is called on this event, the thread will block until the + event completes. This can increase latency for the synchroniation but can + result in lower power and more resources for other CPU threads. + * #hipEventDisableTiming : Disable recording of timing information. Events + created with this flag would not record profiling data and provide best + performance if used for synchronization. + * #hipEventInterprocess : The event can be used as an interprocess event. + hipEventDisableTiming flag also must be set when hipEventInterprocess flag is + set. + * #hipEventDisableSystemFence : Disable acquire and release system scope fence. + This may improve performance but device memory may not be visible to the host + and other devices if this flag is set. + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue, + #hipErrorLaunchFailure, #hipErrorOutOfMemory + * + * @see hipEventCreate, hipEventSynchronize, hipEventDestroy, + hipEventElapsedTime + */ +hipError_t hipEventCreateWithFlags(hipEvent_t *event, unsigned flags); +/** + * Create an event + * + * @param[in,out] event Returns the newly created event. + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue, + * #hipErrorLaunchFailure, #hipErrorOutOfMemory + * + * @see hipEventCreateWithFlags, hipEventRecord, hipEventQuery, + * hipEventSynchronize, hipEventDestroy, hipEventElapsedTime + */ +hipError_t hipEventCreate(hipEvent_t *event); +/** + * @brief Record an event in the specified stream. + * + * @param[in] event event to record. + * @param[in] stream stream in which to record event. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized, + * #hipErrorInvalidHandle, #hipErrorLaunchFailure + * + * hipEventQuery() or hipEventSynchronize() must be used to determine when the + * event transitions from "recording" (after hipEventRecord() is called) to + * "recorded" (when timestamps are set, if requested). + * + * Events which are recorded in a non-NULL stream will transition to + * from recording to "recorded" state when they reach the head of + * the specified stream, after all previous + * commands in that stream have completed executing. + * + * If hipEventRecord() has been previously called on this event, then this call + * will overwrite any existing state in event. + * + * If this function is called on an event that is currently being recorded, + * results are undefined + * - either outstanding recording may save state into the event, and the order + * is not guaranteed. + * + * @note: If this function is not called before use hipEventQuery() or + * hipEventSynchronize(), #hipSuccess is returned, meaning no pending event in + * the stream. + * + * @see hipEventCreate, hipEventCreateWithFlags, hipEventQuery, + * hipEventSynchronize, hipEventDestroy, hipEventElapsedTime + * + */ +#ifdef __cplusplus +hipError_t hipEventRecord(hipEvent_t event, hipStream_t stream = NULL); +#else +hipError_t hipEventRecord(hipEvent_t event, hipStream_t stream); +#endif +/** + * @brief Destroy the specified event. + * + * @param[in] event Event to destroy. + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue, + * #hipErrorLaunchFailure + * + * Releases memory associated with the event. If the event is recording but + * has not completed recording when hipEventDestroy() is called, the function + * will return immediately and the completion_future resources will be released + * later, when the hipDevice is synchronized. + * + * @see hipEventCreate, hipEventCreateWithFlags, hipEventQuery, + * hipEventSynchronize, hipEventRecord, hipEventElapsedTime + * + * @returns #hipSuccess + */ +hipError_t hipEventDestroy(hipEvent_t event); +/** + * @brief Wait for an event to complete. + * + * This function will block until the event is ready, waiting for all previous + * work in the stream specified when event was recorded with hipEventRecord(). + * + * If hipEventRecord() has not been called on @p event, this function returns + * #hipSuccess when no event is captured. + * + * + * @param[in] event Event on which to wait. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized, + * #hipErrorInvalidHandle, #hipErrorLaunchFailure + * + * @see hipEventCreate, hipEventCreateWithFlags, hipEventQuery, + * hipEventDestroy, hipEventRecord, hipEventElapsedTime + */ +hipError_t hipEventSynchronize(hipEvent_t event); +/** + * @brief Return the elapsed time between two events. + * + * @param[out] ms : Return time between start and stop in ms. + * @param[in] start : Start event. + * @param[in] stop : Stop event. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotReady, + * #hipErrorInvalidHandle, #hipErrorNotInitialized, #hipErrorLaunchFailure + * + * Computes the elapsed time between two events. Time is computed in ms, with + * a resolution of approximately 1 us. + * + * Events which are recorded in a NULL stream will block until all commands + * on all other streams complete execution, and then record the timestamp. + * + * Events which are recorded in a non-NULL stream will record their timestamp + * when they reach the head of the specified stream, after all previous + * commands in that stream have completed executing. Thus the time that + * the event recorded may be significantly after the host calls + * hipEventRecord(). + * + * If hipEventRecord() has not been called on either event, then + * #hipErrorInvalidHandle is returned. If hipEventRecord() has been called on + * both events, but the timestamp has not yet been recorded on one or both + * events (that is, hipEventQuery() would return #hipErrorNotReady on at least + * one of the events), then #hipErrorNotReady is returned. + * + * @see hipEventCreate, hipEventCreateWithFlags, hipEventQuery, hipEventDestroy, + * hipEventRecord, hipEventSynchronize + */ +hipError_t hipEventElapsedTime(float *ms, hipEvent_t start, hipEvent_t stop); +/** + * @brief Query event status + * + * @param[in] event Event to query. + * @returns #hipSuccess, #hipErrorNotReady, #hipErrorInvalidHandle, + * #hipErrorInvalidValue, #hipErrorNotInitialized, #hipErrorLaunchFailure + * + * Query the status of the specified event. This function will return + * #hipSuccess if all commands in the appropriate stream (specified to + * hipEventRecord()) have completed. If any execution has not completed, then + * #hipErrorNotReady is returned. + * + * @note: This API returns #hipSuccess, if hipEventRecord() is not called before + * this API. + * + * @see hipEventCreate, hipEventCreateWithFlags, hipEventRecord, + * hipEventDestroy, hipEventSynchronize, hipEventElapsedTime + */ +hipError_t hipEventQuery(hipEvent_t event); +// end doxygen Events +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Memory Memory Management + * @{ + * This section describes the memory management functions of HIP runtime API. + * The following CUDA APIs are not currently supported: + * - cudaMalloc3D + * - cudaMalloc3DArray + * - TODO - more 2D, 3D, array APIs here. + * + * + */ + +/** + * @brief Sets information on the specified pointer.[BETA] + * + * @param [in] value Sets pointer attribute value + * @param [in] attribute Attribute to set + * @param [in] ptr Pointer to set attributes for + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @warning This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipPointerSetAttribute(const void *value, + hipPointer_attribute attribute, + hipDeviceptr_t ptr); + +/** + * @brief Returns attributes for the specified pointer + * + * @param [out] attributes attributes for the specified pointer + * @param [in] ptr pointer to get attributes for + * + * The output parameter 'attributes' has a member named 'type' that describes + * what memory the pointer is associated with, such as device memory, host + * memory, managed memory, and others. Otherwise, the API cannot handle the + * pointer and returns #hipErrorInvalidValue. + * + * @note The unrecognized memory type is unsupported to keep the HIP + * functionality backward compatibility due to #hipMemoryType enum values. + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @note The current behavior of this HIP API corresponds to the CUDA API + * before version 11.0. + * + * @see hipPointerGetAttribute + */ +hipError_t hipPointerGetAttributes(hipPointerAttribute_t *attributes, + const void *ptr); +/** + * @brief Returns information about the specified pointer.[BETA] + * + * @param [in, out] data Returned pointer attribute value + * @param [in] attribute Attribute to query for + * @param [in] ptr Pointer to get attributes for + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @warning This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @see hipPointerGetAttributes + */ +hipError_t hipPointerGetAttribute(void *data, hipPointer_attribute attribute, + hipDeviceptr_t ptr); +/** + * @brief Returns information about the specified pointer.[BETA] + * + * @param [in] numAttributes number of attributes to query for + * @param [in] attributes attributes to query for + * @param [in, out] data a two-dimensional containing pointers to memory + * locations where the result of each attribute query will be written to + * @param [in] ptr pointer to get attributes for + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @warning This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @see hipPointerGetAttribute + */ +hipError_t hipDrvPointerGetAttributes(unsigned int numAttributes, + hipPointer_attribute *attributes, + void **data, hipDeviceptr_t ptr); +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup External External Resource Interoperability + * @{ + * @ingroup API + * + * This section describes the external resource interoperability functions of + *HIP runtime API. + * + */ +/** + * @brief Imports an external semaphore. + * + * @param[out] extSem_out External semaphores to be waited on + * @param[in] semHandleDesc Semaphore import handle descriptor + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @see + */ +hipError_t +hipImportExternalSemaphore(hipExternalSemaphore_t *extSem_out, + const hipExternalSemaphoreHandleDesc *semHandleDesc); +/** + * @brief Signals a set of external semaphore objects. + * + * @param[in] extSemArray External semaphores to be waited on + * @param[in] paramsArray Array of semaphore parameters + * @param[in] numExtSems Number of semaphores to wait on + * @param[in] stream Stream to enqueue the wait operations in + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @see + */ +hipError_t hipSignalExternalSemaphoresAsync( + const hipExternalSemaphore_t *extSemArray, + const hipExternalSemaphoreSignalParams *paramsArray, + unsigned int numExtSems, hipStream_t stream); +/** + * @brief Waits on a set of external semaphore objects + * + * @param[in] extSemArray External semaphores to be waited on + * @param[in] paramsArray Array of semaphore parameters + * @param[in] numExtSems Number of semaphores to wait on + * @param[in] stream Stream to enqueue the wait operations in + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @see + */ +hipError_t hipWaitExternalSemaphoresAsync( + const hipExternalSemaphore_t *extSemArray, + const hipExternalSemaphoreWaitParams *paramsArray, unsigned int numExtSems, + hipStream_t stream); +/** + * @brief Destroys an external semaphore object and releases any references to + * the underlying resource. Any outstanding signals or waits must have completed + * before the semaphore is destroyed. + * + * @param[in] extSem handle to an external memory object + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @see + */ +hipError_t hipDestroyExternalSemaphore(hipExternalSemaphore_t extSem); + +/** + * @brief Imports an external memory object. + * + * @param[out] extMem_out Returned handle to an external memory object + * @param[in] memHandleDesc Memory import handle descriptor + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @see + */ +hipError_t +hipImportExternalMemory(hipExternalMemory_t *extMem_out, + const hipExternalMemoryHandleDesc *memHandleDesc); +/** + * @brief Maps a buffer onto an imported memory object. + * + * @param[out] devPtr Returned device pointer to buffer + * @param[in] extMem Handle to external memory object + * @param[in] bufferDesc Buffer descriptor + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @see + */ +hipError_t +hipExternalMemoryGetMappedBuffer(void **devPtr, hipExternalMemory_t extMem, + const hipExternalMemoryBufferDesc *bufferDesc); +/** + * @brief Destroys an external memory object. + * + * @param[in] extMem External memory object to be destroyed + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @see + */ +hipError_t hipDestroyExternalMemory(hipExternalMemory_t extMem); +/** + * @brief Maps a mipmapped array onto an external memory object. + * + * @param[out] mipmap mipmapped array to return + * @param[in] extMem external memory object handle + * @param[in] mipmapDesc external mipmapped array descriptor + * + * Returned mipmapped array must be freed using hipFreeMipmappedArray. + * + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidResourceHandle + * + * @see hipImportExternalMemory, hipDestroyExternalMemory, + * hipExternalMemoryGetMappedBuffer, hipFreeMipmappedArray + */ +hipError_t hipExternalMemoryGetMappedMipmappedArray( + hipMipmappedArray_t *mipmap, hipExternalMemory_t extMem, + const hipExternalMemoryMipmappedArrayDesc *mipmapDesc); +// end of external resource +/** + * @} + */ +/** + * @brief Allocate memory on the default accelerator + * + * @param[out] ptr Pointer to the allocated memory + * @param[in] size Requested memory size + * + * If size is 0, no memory is allocated, *ptr returns nullptr, and hipSuccess + * is returned. + * + * @return #hipSuccess, #hipErrorOutOfMemory, #hipErrorInvalidValue (bad + * context, null *ptr) + * + * @see hipMallocPitch, hipFree, hipMallocArray, hipFreeArray, hipMalloc3D, + * hipMalloc3DArray, hipHostFree, hipHostMalloc + */ +hipError_t hipMalloc(void **ptr, size_t size); +/** + * @brief Allocate memory on the default accelerator + * + * @param[out] ptr Pointer to the allocated memory + * @param[in] sizeBytes Requested memory size + * @param[in] flags Type of memory allocation + * + * If requested memory size is 0, no memory is allocated, *ptr returns nullptr, + * and #hipSuccess is returned. + * + * The memory allocation flag should be either #hipDeviceMallocDefault, + * #hipDeviceMallocFinegrained, #hipDeviceMallocUncached, or + * #hipMallocSignalMemory. If the flag is any other value, the API returns + * #hipErrorInvalidValue. + * + * @return #hipSuccess, #hipErrorOutOfMemory, #hipErrorInvalidValue (bad + * context, null *ptr) + * + * @see hipMallocPitch, hipFree, hipMallocArray, hipFreeArray, hipMalloc3D, + * hipMalloc3DArray, hipHostFree, hipHostMalloc + */ +hipError_t hipExtMallocWithFlags(void **ptr, size_t sizeBytes, + unsigned int flags); +/** + * @brief Allocate pinned host memory [Deprecated] + * + * @param[out] ptr Pointer to the allocated host pinned memory + * @param[in] size Requested memory size + * + * If size is 0, no memory is allocated, *ptr returns nullptr, and hipSuccess + * is returned. + * + * @return #hipSuccess, #hipErrorOutOfMemory + * + * @warning This API is deprecated, use hipHostMalloc() instead + */ +DEPRECATED("use hipHostMalloc instead") +hipError_t hipMallocHost(void **ptr, size_t size); +/** + * @brief Allocate pinned host memory [Deprecated] + * + * @param[out] ptr Pointer to the allocated host pinned memory + * @param[in] size Requested memory size + * + * If size is 0, no memory is allocated, *ptr returns nullptr, and hipSuccess + * is returned. + * + * @return #hipSuccess, #hipErrorOutOfMemory + * + * @warning This API is deprecated, use hipHostMalloc() instead + */ +DEPRECATED("use hipHostMalloc instead") +hipError_t hipMemAllocHost(void **ptr, size_t size); +/** + * @brief Allocates device accessible page locked (pinned) host memory + * + * This API allocates pinned host memory which is mapped into the address space + * of all GPUs in the system, the memory can be accessed directly by the GPU + * device, and can be read or written with much higher bandwidth than pageable + * memory obtained with functions such as malloc(). + * + * Using the pinned host memory, applications can implement faster data + * transfers for HostToDevice and DeviceToHost. The runtime tracks the + * hipHostMalloc allocations and can avoid some of the setup required for + * regular unpinned memory. + * + * When the memory accesses are infrequent, zero-copy memory can be a good + * choice, for coherent allocation. GPU can directly access the host memory over + * the CPU/GPU interconnect, without need to copy the data. + * + * Currently the allocation granularity is 4KB for the API. + * + * Developers need to choose proper allocation flag with consideration of + * synchronization. + * + * @param[out] ptr Pointer to the allocated host pinned memory + * @param[in] size Requested memory size in bytes + * If size is 0, no memory is allocated, *ptr returns nullptr, and hipSuccess + * is returned. + * @param[in] flags Type of host memory allocation + * + * If no input for flags, it will be the default pinned memory allocation on + * the host. + * + * @return #hipSuccess, #hipErrorOutOfMemory + * + * @see hipSetDeviceFlags, hipHostFree + */ +hipError_t hipHostMalloc(void **ptr, size_t size, unsigned int flags); +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup MemoryM Managed Memory + * + * @ingroup Memory + * @{ + * This section describes the managed memory management functions of HIP + *runtime API. + * + * @note The managed memory management APIs are implemented on Linux, under + *developement on Windows. + * + */ +/** + * @brief Allocates memory that will be automatically managed by HIP. + * + * This API is used for managed memory, allows data be shared and accessible to + * both CPU and GPU using a single pointer. + * + * The API returns the allocation pointer, managed by HMM, can be used further + * to execute kernels on device and fetch data between the host and device as + * needed. + * + * @note It is recommend to do the capability check before call this API. + * + * @param [out] dev_ptr - pointer to allocated device memory + * @param [in] size - requested allocation size in bytes, it should be + * granularity of 4KB + * @param [in] flags - must be either hipMemAttachGlobal or hipMemAttachHost + * (defaults to hipMemAttachGlobal) + * + * @returns #hipSuccess, #hipErrorMemoryAllocation, #hipErrorNotSupported, + * #hipErrorInvalidValue + * + */ +hipError_t hipMallocManaged(void **dev_ptr, size_t size, + unsigned int flags __dparm(hipMemAttachGlobal)); +/** + * @brief Prefetches memory to the specified destination device using HIP. + * + * @param [in] dev_ptr pointer to be prefetched + * @param [in] count size in bytes for prefetching + * @param [in] device destination device to prefetch to + * @param [in] stream stream to enqueue prefetch operation + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemPrefetchAsync(const void *dev_ptr, size_t count, int device, + hipStream_t stream __dparm(0)); +/** + * @brief Advise about the usage of a given memory range to HIP. + * + * @param [in] dev_ptr pointer to memory to set the advice for + * @param [in] count size in bytes of the memory range, it should be CPU page + * size alligned. + * @param [in] advice advice to be applied for the specified memory range + * @param [in] device device to apply the advice for + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * This HIP API advises about the usage to be applied on unified memory + * allocation in the range starting from the pointer address devPtr, with the + * size of count bytes. The memory range must refer to managed memory allocated + * via the API hipMallocManaged, and the range will be handled with proper round + * down and round up respectively in the driver to be aligned to CPU page size, + * the same way as corresponding CUDA API behaves in CUDA version 8.0 and + * afterwards. + * + * @note This API is implemented on Linux and is under development on Windows. + */ +hipError_t hipMemAdvise(const void *dev_ptr, size_t count, + hipMemoryAdvise advice, int device); +/** + * @brief Query an attribute of a given memory range in HIP. + * + * @param [in,out] data a pointer to a memory location where the result of + * each attribute query will be written to + * @param [in] data_size the size of data + * @param [in] attribute the attribute to query + * @param [in] dev_ptr start of the range to query + * @param [in] count size of the range to query + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemRangeGetAttribute(void *data, size_t data_size, + hipMemRangeAttribute attribute, + const void *dev_ptr, size_t count); +/** + * @brief Query attributes of a given memory range in HIP. + * + * @param [in,out] data a two-dimensional array containing pointers to + * memory locations where the result of each attribute query will be written to + * @param [in] data_sizes an array, containing the sizes of each result + * @param [in] attributes the attribute to query + * @param [in] num_attributes an array of attributes to query (numAttributes + * and the number of attributes in this array should match) + * @param [in] dev_ptr start of the range to query + * @param [in] count size of the range to query + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemRangeGetAttributes(void **data, size_t *data_sizes, + hipMemRangeAttribute *attributes, + size_t num_attributes, const void *dev_ptr, + size_t count); +/** + * @brief Attach memory to a stream asynchronously in HIP. + * + * @param [in] stream - stream in which to enqueue the attach operation + * @param [in] dev_ptr - pointer to memory (must be a pointer to managed + * memory or to a valid host-accessible region of system-allocated memory) + * @param [in] length - length of memory (defaults to zero) + * @param [in] flags - must be one of hipMemAttachGlobal, hipMemAttachHost + * or hipMemAttachSingle (defaults to hipMemAttachSingle) + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t +hipStreamAttachMemAsync(hipStream_t stream, void *dev_ptr, + size_t length __dparm(0), + unsigned int flags __dparm(hipMemAttachSingle)); +// end doxygen Managed Memory +/** + * @} + */ + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup StreamO Stream Ordered Memory Allocator + * @{ + * @ingroup Memory + * This section describes Stream Ordered Memory Allocator functions of HIP + *runtime API. + * + * The asynchronous allocator allows the user to allocate and free in stream + *order. All asynchronous accesses of the allocation must happen between the + *stream executions of the allocation and the free. If the memory is accessed + *outside of the promised stream order, a use before allocation / use after free + *error will cause undefined behavior. + * + * The allocator is free to reallocate the memory as long as it can guarantee + *that compliant memory accesses will not overlap temporally. The allocator may + *refer to internal stream ordering as well as inter-stream dependencies (such + *as HIP events and null stream dependencies) when establishing the temporal + *guarantee. The allocator may also insert inter-stream dependencies to + *establish the temporal guarantee. Whether or not a device supports the + *integrated stream ordered memory allocator may be queried by calling @p + *hipDeviceGetAttribute with the device attribute + * @p hipDeviceAttributeMemoryPoolsSupported + * + * @note APIs in this section are implemented on Linux, under development on + *Windows. + */ + +/** + * @brief Allocates memory with stream ordered semantics + * + * Inserts a memory allocation operation into @p stream. + * A pointer to the allocated memory is returned immediately in *dptr. + * The allocation must not be accessed until the allocation operation completes. + * The allocation comes from the memory pool associated with the stream's + * device. + * + * @note The default memory pool of a device contains device memory from that + * device. + * @note Basic stream ordering allows future work submitted into the same stream + * to use the allocation. Stream query, stream synchronize, and HIP events can + * be used to guarantee that the allocation operation completes before work + * submitted in a separate stream runs. + * @note During stream capture, this function results in the creation of an + * allocation node. In this case, the allocation is owned by the graph instead + * of the memory pool. The memory pool's properties are used to set the node's + * creation parameters. + * + * @param [out] dev_ptr Returned device pointer of memory allocation + * @param [in] size Number of bytes to allocate + * @param [in] stream The stream establishing the stream ordering contract + * and the memory pool to allocate from + * + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported, + * #hipErrorOutOfMemory + * + * @see hipMallocFromPoolAsync, hipFreeAsync, hipMemPoolTrimTo, + * hipMemPoolGetAttribute, hipDeviceSetMemPool, hipMemPoolSetAttribute, + * hipMemPoolSetAccess, hipMemPoolGetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMallocAsync(void **dev_ptr, size_t size, hipStream_t stream); +/** + * @brief Frees memory with stream ordered semantics + * + * Inserts a free operation into @p stream. + * The allocation must not be used after stream execution reaches the free. + * After this API returns, accessing the memory from any subsequent work + * launched on the GPU or querying its pointer attributes results in undefined + * behavior. + * + * @note During stream capture, this function results in the creation of a free + * node and must therefore be passed the address of a graph allocation. + * + * @param [in] dev_ptr Pointer to device memory to free + * @param [in] stream The stream, where the destruciton will occur according to + * the execution order + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @see hipMallocFromPoolAsync, hipMallocAsync, hipMemPoolTrimTo, + * hipMemPoolGetAttribute, hipDeviceSetMemPool, hipMemPoolSetAttribute, + * hipMemPoolSetAccess, hipMemPoolGetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipFreeAsync(void *dev_ptr, hipStream_t stream); +/** + * @brief Releases freed memory back to the OS + * + * Releases memory back to the OS until the pool contains fewer than @p + * min_bytes_to_keep reserved bytes, or there is no more memory that the + * allocator can safely release. The allocator cannot release OS allocations + * that back outstanding asynchronous allocations. The OS allocations may happen + * at different granularity from the user allocations. + * + * @note: Allocations that have not been freed count as outstanding. + * @note: Allocations that have been asynchronously freed but whose completion + * has not been observed on the host (eg. by a synchronize) can count as + * outstanding. + * + * @param[in] mem_pool The memory pool to trim allocations + * @param[in] min_bytes_to_hold If the pool has less than min_bytes_to_hold + * reserved, then the TrimTo operation is a no-op. Otherwise the memory pool + * will contain at least min_bytes_to_hold bytes reserved after the operation. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @see hipMallocFromPoolAsync, hipMallocAsync, hipFreeAsync, + * hipMemPoolGetAttribute, hipDeviceSetMemPool, hipMemPoolSetAttribute, + * hipMemPoolSetAccess, hipMemPoolGetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemPoolTrimTo(hipMemPool_t mem_pool, size_t min_bytes_to_hold); +/** + * @brief Sets attributes of a memory pool + * + * Supported attributes are: + * - @p hipMemPoolAttrReleaseThreshold: (value type = cuuint64_t) + * Amount of reserved memory in bytes to hold + * onto before trying to release memory back to the OS. When more than the + * release threshold bytes of memory are held by the memory pool, the allocator + * will try to release memory back to the OS on the next call to stream, event + * or context synchronize. (default 0) + * - @p hipMemPoolReuseFollowEventDependencies: (value type = int) + * Allow @p hipMallocAsync to use memory + * asynchronously freed in another stream as long as a stream ordering + * dependency of the allocating stream on the free action exists. HIP events and + * null stream interactions can create the required stream ordered dependencies. + * (default enabled) + * - @p hipMemPoolReuseAllowOpportunistic: (value type = int) + * Allow reuse of already completed frees when + * there is no dependency between the free and allocation. (default enabled) + * - @p hipMemPoolReuseAllowInternalDependencies: (value type = int) + * Allow @p hipMallocAsync to insert new stream + * dependencies in order to establish the stream ordering required to reuse a + * piece of memory released by @p hipFreeAsync (default enabled). + * + * @param [in] mem_pool The memory pool to modify + * @param [in] attr The attribute to modify + * @param [in] value Pointer to the value to assign + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @see hipMallocFromPoolAsync, hipMallocAsync, hipFreeAsync, + * hipMemPoolGetAttribute, hipMemPoolTrimTo, hipDeviceSetMemPool, + * hipMemPoolSetAccess, hipMemPoolGetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemPoolSetAttribute(hipMemPool_t mem_pool, hipMemPoolAttr attr, + void *value); +/** + * @brief Gets attributes of a memory pool + * + * Supported attributes are: + * - @p hipMemPoolAttrReleaseThreshold: (value type = cuuint64_t) + * Amount of reserved memory in bytes to hold + * onto before trying to release memory back to the OS. When more than the + * release threshold bytes of memory are held by the memory pool, the allocator + * will try to release memory back to the OS on the next call to stream, event + * or context synchronize. (default 0) + * - @p hipMemPoolReuseFollowEventDependencies: (value type = int) + * Allow @p hipMallocAsync to use memory + * asynchronously freed in another stream as long as a stream ordering + * dependency of the allocating stream on the free action exists. HIP events and + * null stream interactions can create the required stream ordered dependencies. + * (default enabled) + * - @p hipMemPoolReuseAllowOpportunistic: (value type = int) + * Allow reuse of already completed frees when + * there is no dependency between the free and allocation. (default enabled) + * - @p hipMemPoolReuseAllowInternalDependencies: (value type = int) + * Allow @p hipMallocAsync to insert new stream + * dependencies in order to establish the stream ordering required to reuse a + * piece of memory released by @p hipFreeAsync (default enabled). + * + * @param [in] mem_pool The memory pool to get attributes of + * @param [in] attr The attribute to get + * @param [in] value Retrieved value + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @see hipMallocFromPoolAsync, hipMallocAsync, hipFreeAsync, + * hipMemPoolTrimTo, hipDeviceSetMemPool, hipMemPoolSetAttribute, + * hipMemPoolSetAccess, hipMemPoolGetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemPoolGetAttribute(hipMemPool_t mem_pool, hipMemPoolAttr attr, + void *value); +/** + * @brief Controls visibility of the specified pool between devices + * + * @param [in] mem_pool Memory pool for acccess change + * @param [in] desc_list Array of access descriptors. Each descriptor instructs + * the access to enable for a single gpu + * @param [in] count Number of descriptors in the map array. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @see hipMallocFromPoolAsync, hipMallocAsync, hipFreeAsync, + * hipMemPoolGetAttribute, hipMemPoolTrimTo, hipDeviceSetMemPool, + * hipMemPoolSetAttribute, hipMemPoolGetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemPoolSetAccess(hipMemPool_t mem_pool, + const hipMemAccessDesc *desc_list, size_t count); +/** + * @brief Returns the accessibility of a pool from a device + * + * Returns the accessibility of the pool's memory from the specified location. + * + * @param [out] flags Accessibility of the memory pool from the specified + * location/device + * @param [in] mem_pool Memory pool being queried + * @param [in] location Location/device for memory pool access + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @see hipMallocFromPoolAsync, hipMallocAsync, hipFreeAsync, + * hipMemPoolGetAttribute, hipMemPoolTrimTo, hipDeviceSetMemPool, + * hipMemPoolSetAttribute, hipMemPoolSetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemPoolGetAccess(hipMemAccessFlags *flags, hipMemPool_t mem_pool, + hipMemLocation *location); +/** + * @brief Creates a memory pool + * + * Creates a HIP memory pool and returns the handle in @p mem_pool. The @p + * pool_props determines the properties of the pool such as the backing device + * and IPC capabilities. + * + * By default, the memory pool will be accessible from the device it is + * allocated on. + * + * @param [out] mem_pool Contains createed memory pool + * @param [in] pool_props Memory pool properties + * + * @note Specifying hipMemHandleTypeNone creates a memory pool that will not + * support IPC. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @see hipMallocFromPoolAsync, hipMallocAsync, hipFreeAsync, + * hipMemPoolGetAttribute, hipMemPoolDestroy, hipMemPoolTrimTo, + * hipDeviceSetMemPool, hipMemPoolSetAttribute, hipMemPoolSetAccess, + * hipMemPoolGetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemPoolCreate(hipMemPool_t *mem_pool, + const hipMemPoolProps *pool_props); +/** + * @brief Destroys the specified memory pool + * + * If any pointers obtained from this pool haven't been freed or + * the pool has free operations that haven't completed + * when @p hipMemPoolDestroy is invoked, the function will return immediately + * and the resources associated with the pool will be released automatically + * once there are no more outstanding allocations. + * + * Destroying the current mempool of a device sets the default mempool of + * that device as the current mempool for that device. + * + * @param [in] mem_pool Memory pool for destruction + * + * @note A device's default memory pool cannot be destroyed. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @see hipMallocFromPoolAsync, hipMallocAsync, hipFreeAsync, + * hipMemPoolGetAttribute, hipMemPoolCreate hipMemPoolTrimTo, + * hipDeviceSetMemPool, hipMemPoolSetAttribute, hipMemPoolSetAccess, + * hipMemPoolGetAccess + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemPoolDestroy(hipMemPool_t mem_pool); +/** + * @brief Allocates memory from a specified pool with stream ordered semantics. + * + * Inserts an allocation operation into @p stream. + * A pointer to the allocated memory is returned immediately in @p dev_ptr. + * The allocation must not be accessed until the allocation operation completes. + * The allocation comes from the specified memory pool. + * + * @note The specified memory pool may be from a device different than that of + * the specified @p stream. + * + * Basic stream ordering allows future work submitted into the same stream to + * use the allocation. Stream query, stream synchronize, and HIP events can be + * used to guarantee that the allocation operation completes before work + * submitted in a separate stream runs. + * + * @note During stream capture, this function results in the creation of an + * allocation node. In this case, the allocation is owned by the graph instead + * of the memory pool. The memory pool's properties are used to set the node's + * creation parameters. + * + * @param [out] dev_ptr Returned device pointer + * @param [in] size Number of bytes to allocate + * @param [in] mem_pool The pool to allocate from + * @param [in] stream The stream establishing the stream ordering semantic + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported, + * #hipErrorOutOfMemory + * + * @see hipMallocAsync, hipFreeAsync, hipMemPoolGetAttribute, hipMemPoolCreate + * hipMemPoolTrimTo, hipDeviceSetMemPool, hipMemPoolSetAttribute, + * hipMemPoolSetAccess, hipMemPoolGetAccess, + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMallocFromPoolAsync(void **dev_ptr, size_t size, + hipMemPool_t mem_pool, hipStream_t stream); +/** + * @brief Exports a memory pool to the requested handle type. + * + * Given an IPC capable mempool, create an OS handle to share the pool with + * another process. A recipient process can convert the shareable handle into a + * mempool with @p hipMemPoolImportFromShareableHandle. Individual pointers can + * then be shared with the @p hipMemPoolExportPointer and @p + * hipMemPoolImportPointer APIs. The implementation of what the shareable handle + * is and how it can be transferred is defined by the requested handle type. + * + * @note: To create an IPC capable mempool, create a mempool with a @p + * hipMemAllocationHandleType other than @p hipMemHandleTypeNone. + * + * @param [out] shared_handle Pointer to the location in which to store the + * requested handle + * @param [in] mem_pool Pool to export + * @param [in] handle_type The type of handle to create + * @param [in] flags Must be 0 + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorOutOfMemory + * + * @see hipMemPoolImportFromShareableHandle + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t +hipMemPoolExportToShareableHandle(void *shared_handle, hipMemPool_t mem_pool, + hipMemAllocationHandleType handle_type, + unsigned int flags); +/** + * @brief Imports a memory pool from a shared handle. + * + * Specific allocations can be imported from the imported pool with @p + * hipMemPoolImportPointer. + * + * @note Imported memory pools do not support creating new allocations. + * As such imported memory pools may not be used in @p hipDeviceSetMemPool + * or @p hipMallocFromPoolAsync calls. + * + * @param [out] mem_pool Returned memory pool + * @param [in] shared_handle OS handle of the pool to open + * @param [in] handle_type The type of handle being imported + * @param [in] flags Must be 0 + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorOutOfMemory + * + * @see hipMemPoolExportToShareableHandle + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t +hipMemPoolImportFromShareableHandle(hipMemPool_t *mem_pool, void *shared_handle, + hipMemAllocationHandleType handle_type, + unsigned int flags); +/** + * @brief Export data to share a memory pool allocation between processes. + * + * Constructs @p export_data for sharing a specific allocation from an already + * shared memory pool. The recipient process can import the allocation with the + * @p hipMemPoolImportPointer api. The data is not a handle and may be shared + * through any IPC mechanism. + * + * @param[out] export_data Returned export data + * @param[in] dev_ptr Pointer to memory being exported + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorOutOfMemory + * + * @see hipMemPoolImportPointer + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemPoolExportPointer(hipMemPoolPtrExportData *export_data, + void *dev_ptr); +/** + * @brief Import a memory pool allocation from another process. + * + * Returns in @p dev_ptr a pointer to the imported memory. + * The imported memory must not be accessed before the allocation operation + * completes in the exporting process. The imported memory must be freed from + * all importing processes before being freed in the exporting process. The + * pointer may be freed with @p hipFree or @p hipFreeAsync. If @p hipFreeAsync + * is used, the free must be completed on the importing process before the free + * operation on the exporting process. + * + * @note The @p hipFreeAsync api may be used in the exporting process before + * the @p hipFreeAsync operation completes in its stream as long as the + * @p hipFreeAsync in the exporting process specifies a stream with + * a stream dependency on the importing process's @p hipFreeAsync. + * + * @param [out] dev_ptr Pointer to imported memory + * @param [in] mem_pool Memory pool from which to import a pointer + * @param [in] export_data Data specifying the memory to import + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized, + * #hipErrorOutOfMemory + * + * @see hipMemPoolExportPointer + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemPoolImportPointer(void **dev_ptr, hipMemPool_t mem_pool, + hipMemPoolPtrExportData *export_data); +// Doxygen end of ordered memory allocator +/** + * @} + */ + +/** + * @brief Allocate device accessible page locked host memory [Deprecated] + * + * @param[out] ptr Pointer to the allocated host pinned memory + * @param[in] size Requested memory size in bytes + * @param[in] flags Type of host memory allocation + * + * If size is 0, no memory is allocated, *ptr returns nullptr, and hipSuccess + * is returned. + * + * @return #hipSuccess, #hipErrorOutOfMemory + * + * @warning This API is deprecated, use hipHostMalloc() instead + */ +DEPRECATED("use hipHostMalloc instead") +hipError_t hipHostAlloc(void **ptr, size_t size, unsigned int flags); +/** + * @brief Get Device pointer from Host Pointer allocated through hipHostMalloc + * + * @param[out] devPtr Device Pointer mapped to passed host pointer + * @param[in] hstPtr Host Pointer allocated through hipHostMalloc + * @param[in] flags Flags to be passed for extension + * + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorOutOfMemory + * + * @see hipSetDeviceFlags, hipHostMalloc + */ +hipError_t hipHostGetDevicePointer(void **devPtr, void *hstPtr, + unsigned int flags); +/** + * @brief Return flags associated with host pointer + * + * @param[out] flagsPtr Memory location to store flags + * @param[in] hostPtr Host Pointer allocated through hipHostMalloc + * @return #hipSuccess, #hipErrorInvalidValue + * + * @see hipHostMalloc + */ +hipError_t hipHostGetFlags(unsigned int *flagsPtr, void *hostPtr); +/** + * @brief Register host memory so it can be accessed from the current device. + * + * @param[out] hostPtr Pointer to host memory to be registered. + * @param[in] sizeBytes Size of the host memory + * @param[in] flags See below. + * + * Flags: + * - #hipHostRegisterDefault Memory is Mapped and Portable + * - #hipHostRegisterPortable Memory is considered registered by all contexts. + * HIP only supports one context so this is always assumed true. + * - #hipHostRegisterMapped Map the allocation into the address space for + * the current device. The device pointer can be obtained with + * #hipHostGetDevicePointer. + * + * + * After registering the memory, use #hipHostGetDevicePointer to obtain the + * mapped device pointer. On many systems, the mapped device pointer will have a + * different value than the mapped host pointer. Applications must use the + * device pointer in device code, and the host pointer in device code. + * + * On some systems, registered memory is pinned. On some systems, registered + * memory may not be actually be pinned but uses OS or hardware facilities to + * all GPU access to the host memory. + * + * Developers are strongly encouraged to register memory blocks which are + * aligned to the host cache-line size. (typically 64-bytes but can be obtains + * from the CPUID instruction). + * + * If registering non-aligned pointers, the application must take care when + * register pointers from the same cache line on different devices. HIP's + * coarse-grained synchronization model does not guarantee correct results if + * different devices write to different parts of the same cache block - + * typically one of the writes will "win" and overwrite data from the other + * registered memory region. + * + * @return #hipSuccess, #hipErrorOutOfMemory + * + * @see hipHostUnregister, hipHostGetFlags, hipHostGetDevicePointer + */ +hipError_t hipHostRegister(void *hostPtr, size_t sizeBytes, unsigned int flags); +/** + * @brief Un-register host pointer + * + * @param[in] hostPtr Host pointer previously registered with #hipHostRegister + * @return Error code + * + * @see hipHostRegister + */ +hipError_t hipHostUnregister(void *hostPtr); +/** + * Allocates at least width (in bytes) * height bytes of linear memory + * Padding may occur to ensure alighnment requirements are met for the given + * row The change in width size due to padding will be returned in *pitch. + * Currently the alignment is set to 128 bytes + * + * @param[out] ptr Pointer to the allocated device memory + * @param[out] pitch Pitch for allocation (in bytes) + * @param[in] width Requested pitched allocation width (in bytes) + * @param[in] height Requested pitched allocation height + * + * If size is 0, no memory is allocated, *ptr returns nullptr, and hipSuccess + * is returned. + * + * @return Error code + * + * @see hipMalloc, hipFree, hipMallocArray, hipFreeArray, hipHostFree, + * hipMalloc3D, hipMalloc3DArray, hipHostMalloc + */ +hipError_t hipMallocPitch(void **ptr, size_t *pitch, size_t width, + size_t height); +/** + * Allocates at least width (in bytes) * height bytes of linear memory + * Padding may occur to ensure alighnment requirements are met for the given + * row The change in width size due to padding will be returned in *pitch. + * Currently the alignment is set to 128 bytes + * + * @param[out] dptr Pointer to the allocated device memory + * @param[out] pitch Pitch for allocation (in bytes) + * @param[in] widthInBytes Requested pitched allocation width (in bytes) + * @param[in] height Requested pitched allocation height + * @param[in] elementSizeBytes The size of element bytes, should be 4, 8 or + * 16 + * + * If size is 0, no memory is allocated, *ptr returns nullptr, and hipSuccess + * is returned. The intended usage of pitch is as a separate parameter of the + * allocation, used to compute addresses within the 2D array. Given the row and + * column of an array element of type T, the address is computed as: T* pElement + * = (T*)((char*)BaseAddress + Row * Pitch) + Column; + * + * @return Error code + * + * @see hipMalloc, hipFree, hipMallocArray, hipFreeArray, hipHostFree, + * hipMalloc3D, hipMalloc3DArray, hipHostMalloc + */ +hipError_t hipMemAllocPitch(hipDeviceptr_t *dptr, size_t *pitch, + size_t widthInBytes, size_t height, + unsigned int elementSizeBytes); +/** + * @brief Free memory allocated by the hcc hip memory allocation API. + * This API performs an implicit hipDeviceSynchronize() call. + * If pointer is NULL, the hip runtime is initialized and hipSuccess is + * returned. + * + * @param[in] ptr Pointer to memory to be freed + * @return #hipSuccess + * @return #hipErrorInvalidDevicePointer (if pointer is invalid, including host + * pointers allocated with hipHostMalloc) + * + * @see hipMalloc, hipMallocPitch, hipMallocArray, hipFreeArray, hipHostFree, + * hipMalloc3D, hipMalloc3DArray, hipHostMalloc + */ +hipError_t hipFree(void *ptr); +/** + * @brief Free memory allocated by the hcc hip host memory allocation API + * [Deprecated] + * + * @param[in] ptr Pointer to memory to be freed + * @return #hipSuccess, + * #hipErrorInvalidValue (if pointer is invalid, including device + * pointers allocated with hipMalloc) + * + * @warning This API is deprecated, use hipHostFree() instead + */ +DEPRECATED("use hipHostFree instead") +hipError_t hipFreeHost(void *ptr); +/** + * @brief Free memory allocated by the hcc hip host memory allocation API + * This API performs an implicit hipDeviceSynchronize() call. + * If pointer is NULL, the hip runtime is initialized and hipSuccess is + * returned. + * + * @param[in] ptr Pointer to memory to be freed + * @return #hipSuccess, + * #hipErrorInvalidValue (if pointer is invalid, including device + * pointers allocated with hipMalloc) + * + * @see hipMalloc, hipMallocPitch, hipFree, hipMallocArray, hipFreeArray, + * hipMalloc3D, hipMalloc3DArray, hipHostMalloc + */ +hipError_t hipHostFree(void *ptr); +/** + * @brief Copy data from src to dst. + * + * It supports memory from host to device, + * device to host, device to device and host to host + * The src and dst must not overlap. + * + * For hipMemcpy, the copy is always performed by the current device (set by + * hipSetDevice). For multi-gpu or peer-to-peer configurations, it is + * recommended to set the current device to the device where the src data is + * physically located. For optimal peer-to-peer copies, the copy device must be + * able to access the src and dst pointers (by calling hipDeviceEnablePeerAccess + * with copy agent as the current device and src/dest as the peerDevice + * argument. if this is not done, the hipMemcpy will still work, but will + * perform the copy using a staging buffer on the host. Calling hipMemcpy with + * dst and src pointers that do not match the hipMemcpyKind results in undefined + * behavior. + * + * @param[out] dst Data being copy to + * @param[in] src Data being copy from + * @param[in] sizeBytes Data size in bytes + * @param[in] kind Kind of transfer + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorUnknown + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpy(void *dst, const void *src, size_t sizeBytes, + hipMemcpyKind kind); +/** + * @brief Memory copy on the stream. + * It allows single or multiple devices to do memory copy on single or multiple + * streams. + * + * @param[out] dst Data being copy to + * @param[in] src Data being copy from + * @param[in] sizeBytes Data size in bytes + * @param[in] kind Kind of transfer + * @param[in] stream Valid stream + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorUnknown, + * #hipErrorContextIsDestroyed + * + * @see hipMemcpy, hipStreamCreate, hipStreamSynchronize, hipStreamDestroy, + * hipSetDevice, hipLaunchKernelGGL + * + */ +hipError_t hipMemcpyWithStream(void *dst, const void *src, size_t sizeBytes, + hipMemcpyKind kind, hipStream_t stream); +/** + * @brief Copy data from Host to Device + * + * @param[out] dst Data being copy to + * @param[in] src Data being copy from + * @param[in] sizeBytes Data size in bytes + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyHtoD(hipDeviceptr_t dst, void *src, size_t sizeBytes); +/** + * @brief Copy data from Device to Host + * + * @param[out] dst Data being copy to + * @param[in] src Data being copy from + * @param[in] sizeBytes Data size in bytes + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyDtoH(void *dst, hipDeviceptr_t src, size_t sizeBytes); +/** + * @brief Copy data from Device to Device + * + * @param[out] dst Data being copy to + * @param[in] src Data being copy from + * @param[in] sizeBytes Data size in bytes + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyDtoD(hipDeviceptr_t dst, hipDeviceptr_t src, + size_t sizeBytes); +/** + * @brief Copies from one 1D array to device memory. + * + * @param[out] dstDevice Destination device pointer + * @param[in] srcArray Source array + * @param[in] srcOffset Offset in bytes of source array + * @param[in] ByteCount Size of memory copy in bytes + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyAtoD(hipDeviceptr_t dstDevice, hipArray_t srcArray, + size_t srcOffset, size_t ByteCount); +/** + * @brief Copies from device memory to a 1D array. + * + * @param[out] dstArray Destination array + * @param[in] dstOffset Offset in bytes of destination array + * @param[in] srcDevice Source device pointer + * @param[in] ByteCount Size of memory copy in bytes + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyDtoA(hipArray_t dstArray, size_t dstOffset, + hipDeviceptr_t srcDevice, size_t ByteCount); + +/** + * @brief Copies from one 1D array to another. + * + * @param[out] dstArray Destination array + * @param[in] dstOffset Offset in bytes of destination array + * @param[in] srcArray Source array + * @param[in] srcOffset Offset in bytes of source array + * @param[in] ByteCount Size of memory copy in bytes + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyAtoA(hipArray_t dstArray, size_t dstOffset, + hipArray_t srcArray, size_t srcOffset, + size_t ByteCount); +/** + * @brief Copy data from Host to Device asynchronously + * + * @param[out] dst Data being copy to + * @param[in] src Data being copy from + * @param[in] sizeBytes Data size in bytes + * @param[in] stream Stream identifier + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dst, void *src, size_t sizeBytes, + hipStream_t stream); +/** + * @brief Copy data from Device to Host asynchronously + * + * @param[out] dst Data being copy to + * @param[in] src Data being copy from + * @param[in] sizeBytes Data size in bytes + * @param[in] stream Stream identifier + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyDtoHAsync(void *dst, hipDeviceptr_t src, size_t sizeBytes, + hipStream_t stream); +/** + * @brief Copy data from Device to Device asynchronously + * + * @param[out] dst Data being copy to + * @param[in] src Data being copy from + * @param[in] sizeBytes Data size in bytes + * @param[in] stream Stream identifier + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyDtoDAsync(hipDeviceptr_t dst, hipDeviceptr_t src, + size_t sizeBytes, hipStream_t stream); +/** + * @brief Copies from one 1D array to host memory. + * + * @param[out] dstHost Destination pointer + * @param[in] srcArray Source array + * @param[in] srcOffset Offset in bytes of source array + * @param[in] ByteCount Size of memory copy in bytes + * @param[in] stream Stream identifier + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyAtoHAsync(void *dstHost, hipArray_t srcArray, + size_t srcOffset, size_t ByteCount, + hipStream_t stream); +/** + * @brief Copies from host memory to a 1D array. + * + * @param[out] dstArray Destination array + * @param[in] dstOffset Offset in bytes of destination array + * @param[in] srcHost Source host pointer + * @param[in] ByteCount Size of memory copy in bytes + * @param[in] stream Stream identifier + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipArrayGetDescriptor, hipMemAlloc, + * hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, hipMemcpy2DAsync, + * hipMemcpy2DUnaligned, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer + */ +hipError_t hipMemcpyHtoAAsync(hipArray_t dstArray, size_t dstOffset, + const void *srcHost, size_t ByteCount, + hipStream_t stream); +/** + * @brief Returns a global pointer from a module. + * Returns in *dptr and *bytes the pointer and size of the global of name name + * located in module hmod. If no variable of that name exists, it returns + * hipErrorNotFound. Both parameters dptr and bytes are optional. If one of them + * is NULL, it is ignored and hipSuccess is returned. + * + * @param[out] dptr Returns global device pointer + * @param[out] bytes Returns global size in bytes + * @param[in] hmod Module to retrieve global from + * @param[in] name Name of global to retrieve + * + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotFound, + * #hipErrorInvalidContext + * + */ +hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t *bytes, + hipModule_t hmod, const char *name); + +/** + * @brief Gets device pointer associated with symbol on the device. + * + * @param[out] devPtr pointer to the device associated the symbole + * @param[in] symbol pointer to the symbole of the device + * + * @return #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipGetSymbolAddress(void **devPtr, const void *symbol); + +/** + * @brief Gets the size of the given symbol on the device. + * + * @param[in] symbol pointer to the device symbole + * @param[out] size pointer to the size + * + * @return #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipGetSymbolSize(size_t *size, const void *symbol); + +/** + * @brief Gets the pointer of requested HIP driver function. + * + * @param[in] symbol The Symbol name of the driver function to request. + * @param[out] pfn Output pointer to the requested driver function. + * @param[in] hipVersion The HIP version for the requested driver function + * symbol. HIP version is defined as 100*version_major + version_minor. For + * example, in HIP 6.1, the hipversion is 601, for the symbol function + * "hipGetDeviceProperties", the specified hipVersion 601 is greater or equal to + * the version 600, the symbol function will be handle properly as backend + * compatible function. + * + * @param[in] flags Currently only default flag is suppported. + * @param[out] symbolStatus Optional enumeration for returned status of + * searching for symbol driver function based on the input hipVersion. + * + * Returns hipSuccess if the returned pfn is addressed to the pointer of found + * driver function. + * + * @return #hipSuccess, #hipErrorInvalidValue. + */ +hipError_t hipGetProcAddress(const char *symbol, void **pfn, int hipVersion, + uint64_t flags, + hipDriverProcAddressQueryResult *symbolStatus); + +/** + * @brief Copies data to the given symbol on the device. + * Symbol HIP APIs allow a kernel to define a device-side data symbol which can + * be accessed on the host side. The symbol can be in __constant or device + * space. Note that the symbol name needs to be encased in the HIP_SYMBOL macro. + * This also applies to hipMemcpyFromSymbol, hipGetSymbolAddress, and + * hipGetSymbolSize. For detailed usage, see the memcpyToSymbol + * example in the HIP Porting Guide. + * + * + * @param[out] symbol pointer to the device symbole + * @param[in] src pointer to the source address + * @param[in] sizeBytes size in bytes to copy + * @param[in] offset offset in bytes from start of symbole + * @param[in] kind type of memory transfer + * + * @return #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipMemcpyToSymbol(const void *symbol, const void *src, + size_t sizeBytes, size_t offset __dparm(0), + hipMemcpyKind kind __dparm(hipMemcpyHostToDevice)); + +/** + * @brief Copies data to the given symbol on the device asynchronously. + * + * @param[out] symbol pointer to the device symbole + * @param[in] src pointer to the source address + * @param[in] sizeBytes size in bytes to copy + * @param[in] offset offset in bytes from start of symbole + * @param[in] kind type of memory transfer + * @param[in] stream stream identifier + * + * @return #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipMemcpyToSymbolAsync(const void *symbol, const void *src, + size_t sizeBytes, size_t offset, + hipMemcpyKind kind, + hipStream_t stream __dparm(0)); + +/** + * @brief Copies data from the given symbol on the device. + * + * @param[out] dst Returns pointer to destinition memory address + * @param[in] symbol Pointer to the symbole address on the device + * @param[in] sizeBytes Size in bytes to copy + * @param[in] offset Offset in bytes from the start of symbole + * @param[in] kind Type of memory transfer + * + * @return #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t +hipMemcpyFromSymbol(void *dst, const void *symbol, size_t sizeBytes, + size_t offset __dparm(0), + hipMemcpyKind kind __dparm(hipMemcpyDeviceToHost)); + +/** + * @brief Copies data from the given symbol on the device asynchronously. + * + * @param[out] dst Returns pointer to destinition memory address + * @param[in] symbol pointer to the symbole address on the device + * @param[in] sizeBytes size in bytes to copy + * @param[in] offset offset in bytes from the start of symbole + * @param[in] kind type of memory transfer + * @param[in] stream stream identifier + * + * @return #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipMemcpyFromSymbolAsync(void *dst, const void *symbol, + size_t sizeBytes, size_t offset, + hipMemcpyKind kind, + hipStream_t stream __dparm(0)); +/** + * @brief Copy data from src to dst asynchronously. + * + * @warning If host or dest are not pinned, the memory copy will be performed + * synchronously. For best performance, use hipHostMalloc to allocate host + * memory that is transferred asynchronously. + * + * @warning on HCC hipMemcpyAsync does not support overlapped H2D and D2H + * copies. For hipMemcpy, the copy is always performed by the device associated + * with the specified stream. + * + * For multi-gpu or peer-to-peer configurations, it is recommended to use a + * stream which is a attached to the device where the src data is physically + * located. For optimal peer-to-peer copies, the copy device must be able to + * access the src and dst pointers (by calling hipDeviceEnablePeerAccess with + * copy agent as the current device and src/dest as the peerDevice argument. if + * this is not done, the hipMemcpy will still work, but will perform the copy + * using a staging buffer on the host. + * + * @param[out] dst Data being copy to + * @param[in] src Data being copy from + * @param[in] sizeBytes Data size in bytes + * @param[in] kind Type of memory transfer + * @param[in] stream Stream identifier + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorUnknown + * + * @see hipMemcpy, hipMemcpy2D, hipMemcpyToArray, hipMemcpy2DToArray, + * hipMemcpyFromArray, hipMemcpy2DFromArray, hipMemcpyArrayToArray, + * hipMemcpy2DArrayToArray, hipMemcpyToSymbol, hipMemcpyFromSymbol, + * hipMemcpy2DAsync, hipMemcpyToArrayAsync, hipMemcpy2DToArrayAsync, + * hipMemcpyFromArrayAsync, hipMemcpy2DFromArrayAsync, hipMemcpyToSymbolAsync, + * hipMemcpyFromSymbolAsync + */ +hipError_t hipMemcpyAsync(void *dst, const void *src, size_t sizeBytes, + hipMemcpyKind kind, hipStream_t stream __dparm(0)); +/** + * @brief Fills the first sizeBytes bytes of the memory area pointed to by dest + * with the constant byte value value. + * + * @param[out] dst Data being filled + * @param[in] value Value to be set + * @param[in] sizeBytes Data size in bytes + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized + */ +hipError_t hipMemset(void *dst, int value, size_t sizeBytes); +/** + * @brief Fills the first sizeBytes bytes of the memory area pointed to by dest + * with the constant byte value value. + * + * @param[out] dest Data ptr to be filled + * @param[in] value Value to be set + * @param[in] count Number of values to be set + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized + */ +hipError_t hipMemsetD8(hipDeviceptr_t dest, unsigned char value, size_t count); +/** + * @brief Fills the first sizeBytes bytes of the memory area pointed to by dest + * with the constant byte value value. + * + * hipMemsetD8Async() is asynchronous with respect to the host, so the call may + * return before the memset is complete. The operation can optionally be + * associated to a stream by passing a non-zero stream argument. If stream is + * non-zero, the operation may overlap with operations in other streams. + * + * @param[out] dest Data ptr to be filled + * @param[in] value Constant value to be set + * @param[in] count Number of values to be set + * @param[in] stream Stream identifier + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized + */ +hipError_t hipMemsetD8Async(hipDeviceptr_t dest, unsigned char value, + size_t count, hipStream_t stream __dparm(0)); +/** + * @brief Fills the first sizeBytes bytes of the memory area pointed to by dest + * with the constant short value value. + * + * @param[out] dest Data ptr to be filled + * @param[in] value Constant value to be set + * @param[in] count Number of values to be set + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized + */ +hipError_t hipMemsetD16(hipDeviceptr_t dest, unsigned short value, + size_t count); +/** + * @brief Fills the first sizeBytes bytes of the memory area pointed to by dest + * with the constant short value value. + * + * hipMemsetD16Async() is asynchronous with respect to the host, so the call may + * return before the memset is complete. The operation can optionally be + * associated to a stream by passing a non-zero stream argument. If stream is + * non-zero, the operation may overlap with operations in other streams. + * + * @param[out] dest Data ptr to be filled + * @param[in] value Constant value to be set + * @param[in] count Number of values to be set + * @param[in] stream Stream identifier + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized + */ +hipError_t hipMemsetD16Async(hipDeviceptr_t dest, unsigned short value, + size_t count, hipStream_t stream __dparm(0)); +/** + * @brief Fills the memory area pointed to by dest with the constant integer + * value for specified number of times. + * + * @param[out] dest Data being filled + * @param[in] value Constant value to be set + * @param[in] count Number of values to be set + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized + */ +hipError_t hipMemsetD32(hipDeviceptr_t dest, int value, size_t count); +/** + * @brief Fills the first sizeBytes bytes of the memory area pointed to by dev + * with the constant byte value value. + * + * hipMemsetAsync() is asynchronous with respect to the host, so the call may + * return before the memset is complete. The operation can optionally be + * associated to a stream by passing a non-zero stream argument. If stream is + * non-zero, the operation may overlap with operations in other streams. + * + * @param[out] dst Pointer to device memory + * @param[in] value Value to set for each byte of specified memory + * @param[in] sizeBytes Size in bytes to set + * @param[in] stream Stream identifier + * @return #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipMemsetAsync(void *dst, int value, size_t sizeBytes, + hipStream_t stream __dparm(0)); +/** + * @brief Fills the memory area pointed to by dev with the constant integer + * value for specified number of times. + * + * hipMemsetD32Async() is asynchronous with respect to the host, so the call + * may return before the memset is complete. The operation can optionally be + * associated to a stream by passing a non-zero stream argument. If stream is + * non-zero, the operation may overlap with operations in other streams. + * + * @param[out] dst Pointer to device memory + * @param[in] value Value to set for each byte of specified memory + * @param[in] count Number of values to be set + * @param[in] stream Stream identifier + * @return #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipMemsetD32Async(hipDeviceptr_t dst, int value, size_t count, + hipStream_t stream __dparm(0)); +/** + * @brief Fills the memory area pointed to by dst with the constant value. + * + * @param[out] dst Pointer to device memory + * @param[in] pitch Data size in bytes + * @param[in] value Constant value to be set + * @param[in] width + * @param[in] height + * @return #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipMemset2D(void *dst, size_t pitch, int value, size_t width, + size_t height); +/** + * @brief Fills asynchronously the memory area pointed to by dst with the + * constant value. + * + * @param[in] dst Pointer to 2D device memory + * @param[in] pitch Pitch size in bytes + * @param[in] value Value to be set for each byte of specified memory + * @param[in] width Width of matrix set columns in bytes + * @param[in] height Height of matrix set rows in bytes + * @param[in] stream Stream identifier + * @return #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipMemset2DAsync(void *dst, size_t pitch, int value, size_t width, + size_t height, hipStream_t stream __dparm(0)); +/** + * @brief Fills synchronously the memory area pointed to by pitchedDevPtr with + * the constant value. + * + * @param[in] pitchedDevPtr Pointer to pitched device memory + * @param[in] value Value to set for each byte of specified memory + * @param[in] extent Size parameters for width field in bytes in device + * memory + * @return #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipMemset3D(hipPitchedPtr pitchedDevPtr, int value, + hipExtent extent); +/** + * @brief Fills asynchronously the memory area pointed to by pitchedDevPtr with + * the constant value. + * + * @param[in] pitchedDevPtr Pointer to pitched device memory + * @param[in] value Value to set for each byte of specified memory + * @param[in] extent Size parameters for width field in bytes in device + * memory + * @param[in] stream Stream identifier + * @return #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipMemset3DAsync(hipPitchedPtr pitchedDevPtr, int value, + hipExtent extent, hipStream_t stream __dparm(0)); +/** + * @brief Query memory info. + * + * On ROCM, this function gets the actual free memory left on the current + *device, so supports the cases while running multi-workload (such as multiple + *processes, multiple threads, and multiple GPUs). + * + * @warning On Windows, the free memory only accounts for memory allocated by + *this process and may be optimistic. + * + * @param[out] free Returns free memory on the current device in bytes + * @param[out] total Returns total allocatable memory on the current device in + *bytes + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + **/ +hipError_t hipMemGetInfo(size_t *free, size_t *total); + +/** + * @brief Get allocated memory size via memory pointer. + * + * This function gets the allocated shared virtual memory size from memory + *pointer. + * + * @param[in] ptr Pointer to allocated memory + * @param[out] size Returns the allocated memory size in bytes + * + * @return #hipSuccess, #hipErrorInvalidValue + * + **/ +hipError_t hipMemPtrGetInfo(void *ptr, size_t *size); +/** + * @brief Allocate an array on the device. + * + * @param[out] array Pointer to allocated array in device memory + * @param[in] desc Requested channel format + * @param[in] width Requested array allocation width + * @param[in] height Requested array allocation height + * @param[in] flags Requested properties of allocated array + * @return #hipSuccess, #hipErrorOutOfMemory + * + * @see hipMalloc, hipMallocPitch, hipFree, hipFreeArray, hipHostMalloc, + * hipHostFree + */ +hipError_t hipMallocArray(hipArray_t *array, const hipChannelFormatDesc *desc, + size_t width, size_t height __dparm(0), + unsigned int flags __dparm(hipArrayDefault)); +/** + * @brief Create an array memory pointer on the device. + * + * @param[out] pHandle Pointer to the array memory + * @param[in] pAllocateArray Requested array desciptor + * + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @see hipMallocArray, hipArrayDestroy, hipFreeArray + */ +hipError_t hipArrayCreate(hipArray_t *pHandle, + const HIP_ARRAY_DESCRIPTOR *pAllocateArray); +/** + * @brief Destroy an array memory pointer on the device. + * + * @param[in] array Pointer to the array memory + * + * @return #hipSuccess, #hipErrorInvalidValue + * + * @see hipArrayCreate, hipArrayDestroy, hipFreeArray + */ +hipError_t hipArrayDestroy(hipArray_t array); +/** + * @brief Create a 3D array memory pointer on the device. + * + * @param[out] array Pointer to the 3D array memory + * @param[in] pAllocateArray Requested array desciptor + * + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @see hipMallocArray, hipArrayDestroy, hipFreeArray + */ +hipError_t hipArray3DCreate(hipArray_t *array, + const HIP_ARRAY3D_DESCRIPTOR *pAllocateArray); +/** + * @brief Create a 3D memory pointer on the device. + * + * @param[out] pitchedDevPtr Pointer to the 3D memory + * @param[in] extent Requested extent + * + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @see hipMallocPitch, hipMemGetInfo, hipFree + */ +hipError_t hipMalloc3D(hipPitchedPtr *pitchedDevPtr, hipExtent extent); +/** + * @brief Frees an array on the device. + * + * @param[in] array Pointer to array to free + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorNotInitialized + * + * @see hipMalloc, hipMallocPitch, hipFree, hipMallocArray, hipHostMalloc, + * hipHostFree + */ +hipError_t hipFreeArray(hipArray_t array); +/** + * @brief Allocate an array on the device. + * + * @param[out] array Pointer to allocated array in device memory + * @param[in] desc Requested channel format + * @param[in] extent Requested array allocation width, height and depth + * @param[in] flags Requested properties of allocated array + * @return #hipSuccess, #hipErrorOutOfMemory + * + * @see hipMalloc, hipMallocPitch, hipFree, hipFreeArray, hipHostMalloc, + * hipHostFree + */ +hipError_t hipMalloc3DArray(hipArray_t *array, + const struct hipChannelFormatDesc *desc, + struct hipExtent extent, unsigned int flags); +/** + * @brief Gets info about the specified array + * + * @param[out] desc - Returned array type + * @param[out] extent - Returned array shape. 2D arrays will have depth of zero + * @param[out] flags - Returned array flags + * @param[in] array - The HIP array to get info for + * + * @return #hipSuccess, #hipErrorInvalidValue #hipErrorInvalidHandle + * + * @see hipArrayGetDescriptor, hipArray3DGetDescriptor + */ +hipError_t hipArrayGetInfo(hipChannelFormatDesc *desc, hipExtent *extent, + unsigned int *flags, hipArray_t array); +/** + * @brief Gets a 1D or 2D array descriptor + * + * @param[out] pArrayDescriptor - Returned array descriptor + * @param[in] array - Array to get descriptor of + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue #hipErrorInvalidHandle + * + * @see hipArray3DCreate, hipArray3DGetDescriptor, hipArrayCreate, + * hipArrayDestroy, hipMemAlloc, hipMemAllocHost, hipMemAllocPitch, hipMemcpy2D, + * hipMemcpy2DAsync, hipMemcpy2DUnaligned, hipMemcpy3D, hipMemcpy3DAsync, + * hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, hipMemcpyAtoHAsync, + * hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, hipMemcpyDtoH, + * hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, hipMemcpyHtoD, + * hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, hipMemGetAddressRange, + * hipMemGetInfo, hipMemHostAlloc, hipMemHostGetDevicePointer, hipMemsetD8, + * hipMemsetD16, hipMemsetD32, hipArrayGetInfo + */ +hipError_t hipArrayGetDescriptor(HIP_ARRAY_DESCRIPTOR *pArrayDescriptor, + hipArray_t array); +/** + * @brief Gets a 3D array descriptor + * + * @param[out] pArrayDescriptor - Returned 3D array descriptor + * @param[in] array - 3D array to get descriptor of + * + * @return #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidValue #hipErrorInvalidHandle, + * #hipErrorContextIsDestroyed + * + * @see hipArray3DCreate, hipArrayCreate, hipArrayDestroy, + * hipArrayGetDescriptor, hipMemAlloc, hipMemAllocHost, hipMemAllocPitch, + * hipMemcpy2D, hipMemcpy2DAsync, hipMemcpy2DUnaligned, hipMemcpy3D, + * hipMemcpy3DAsync, hipMemcpyAtoA, hipMemcpyAtoD, hipMemcpyAtoH, + * hipMemcpyAtoHAsync, hipMemcpyDtoA, hipMemcpyDtoD, hipMemcpyDtoDAsync, + * hipMemcpyDtoH, hipMemcpyDtoHAsync, hipMemcpyHtoA, hipMemcpyHtoAAsync, + * hipMemcpyHtoD, hipMemcpyHtoDAsync, hipMemFree, hipMemFreeHost, + * hipMemGetAddressRange, hipMemGetInfo, hipMemHostAlloc, + * hipMemHostGetDevicePointer, hipMemsetD8, hipMemsetD16, hipMemsetD32, + * hipArrayGetInfo + */ +hipError_t hipArray3DGetDescriptor(HIP_ARRAY3D_DESCRIPTOR *pArrayDescriptor, + hipArray_t array); +/** + * @brief Copies data between host and device. + * + * @param[in] dst Destination memory address + * @param[in] dpitch Pitch of destination memory + * @param[in] src Source memory address + * @param[in] spitch Pitch of source memory + * @param[in] width Width of matrix transfer (columns in bytes) + * @param[in] height Height of matrix transfer (rows) + * @param[in] kind Type of transfer + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpyToArray, hipMemcpy2DToArray, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpy2D(void *dst, size_t dpitch, const void *src, size_t spitch, + size_t width, size_t height, hipMemcpyKind kind); +/** + * @brief Copies memory for 2D arrays. + * @param[in] pCopy Parameters for the memory copy + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2D, hipMemcpyToArray, hipMemcpy2DToArray, + * hipMemcpyFromArray, hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpyParam2D(const hip_Memcpy2D *pCopy); +/** + * @brief Copies memory for 2D arrays. + * @param[in] pCopy Parameters for the memory copy + * @param[in] stream Stream to use + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2D, hipMemcpyToArray, hipMemcpy2DToArray, + * hipMemcpyFromArray, hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpyParam2DAsync(const hip_Memcpy2D *pCopy, + hipStream_t stream __dparm(0)); +/** + * @brief Copies data between host and device. + * + * @param[in] dst Destination memory address + * @param[in] dpitch Pitch of destination memory + * @param[in] src Source memory address + * @param[in] spitch Pitch of source memory + * @param[in] width Width of matrix transfer (columns in bytes) + * @param[in] height Height of matrix transfer (rows) + * @param[in] kind Type of transfer + * @param[in] stream Stream to use + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpyToArray, hipMemcpy2DToArray, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpy2DAsync(void *dst, size_t dpitch, const void *src, + size_t spitch, size_t width, size_t height, + hipMemcpyKind kind, hipStream_t stream __dparm(0)); +/** + * @brief Copies data between host and device. + * + * @param[in] dst Destination memory address + * @param[in] wOffset Destination starting X offset + * @param[in] hOffset Destination starting Y offset + * @param[in] src Source memory address + * @param[in] spitch Pitch of source memory + * @param[in] width Width of matrix transfer (columns in bytes) + * @param[in] height Height of matrix transfer (rows) + * @param[in] kind Type of transfer + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpyToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpy2DToArray(hipArray_t dst, size_t wOffset, size_t hOffset, + const void *src, size_t spitch, size_t width, + size_t height, hipMemcpyKind kind); +/** + * @brief Copies data between host and device. + * + * @param[in] dst Destination memory address + * @param[in] wOffset Destination starting X offset + * @param[in] hOffset Destination starting Y offset + * @param[in] src Source memory address + * @param[in] spitch Pitch of source memory + * @param[in] width Width of matrix transfer (columns in bytes) + * @param[in] height Height of matrix transfer (rows) + * @param[in] kind Type of transfer + * @param[in] stream Accelerator view which the copy is being enqueued + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpyToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpy2DToArrayAsync(hipArray_t dst, size_t wOffset, + size_t hOffset, const void *src, + size_t spitch, size_t width, size_t height, + hipMemcpyKind kind, + hipStream_t stream __dparm(0)); +/** + * @brief Copies data between host and device. + * + * @param[in] dst Destination memory address + * @param[in] wOffsetDst Destination starting X offset + * @param[in] hOffsetDst Destination starting Y offset + * @param[in] src Source memory address + * @param[in] wOffsetSrc Source starting X offset + * @param[in] hOffsetSrc Source starting Y offset (columns in bytes) + * @param[in] width Width of matrix transfer (columns in bytes) + * @param[in] height Height of matrix transfer (rows) + * @param[in] kind Type of transfer + * + * @returns #hipSuccess, #hipErrorInvalidValue, + * #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpyToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpy2DArrayToArray(hipArray_t dst, size_t wOffsetDst, + size_t hOffsetDst, hipArray_const_t src, + size_t wOffsetSrc, size_t hOffsetSrc, + size_t width, size_t height, + hipMemcpyKind kind); +/** + * @brief Copies data between host and device. + * + * @param[in] dst Destination memory address + * @param[in] wOffset Destination starting X offset + * @param[in] hOffset Destination starting Y offset + * @param[in] src Source memory address + * @param[in] count size in bytes to copy + * @param[in] kind Type of transfer + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2DToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + * @warning This API is deprecated. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipMemcpyToArray(hipArray_t dst, size_t wOffset, size_t hOffset, + const void *src, size_t count, hipMemcpyKind kind); +/** + * @brief Copies data between host and device. + * + * @param[in] dst Destination memory address + * @param[in] srcArray Source memory address + * @param[in] wOffset Source starting X offset + * @param[in] hOffset Source starting Y offset + * @param[in] count Size in bytes to copy + * @param[in] kind Type of transfer + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2DToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + * @warning This API is deprecated. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipMemcpyFromArray(void *dst, hipArray_const_t srcArray, + size_t wOffset, size_t hOffset, size_t count, + hipMemcpyKind kind); +/** + * @brief Copies data between host and device. + * + * @param[in] dst Destination memory address + * @param[in] dpitch Pitch of destination memory + * @param[in] src Source memory address + * @param[in] wOffset Source starting X offset + * @param[in] hOffset Source starting Y offset + * @param[in] width Width of matrix transfer (columns in bytes) + * @param[in] height Height of matrix transfer (rows) + * @param[in] kind Type of transfer + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2DToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpy2DFromArray(void *dst, size_t dpitch, hipArray_const_t src, + size_t wOffset, size_t hOffset, size_t width, + size_t height, hipMemcpyKind kind); +/** + * @brief Copies data between host and device asynchronously. + * + * @param[in] dst Destination memory address + * @param[in] dpitch Pitch of destination memory + * @param[in] src Source memory address + * @param[in] wOffset Source starting X offset + * @param[in] hOffset Source starting Y offset + * @param[in] width Width of matrix transfer (columns in bytes) + * @param[in] height Height of matrix transfer (rows) + * @param[in] kind Type of transfer + * @param[in] stream Accelerator view which the copy is being enqueued + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2DToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpy2DFromArrayAsync(void *dst, size_t dpitch, + hipArray_const_t src, size_t wOffset, + size_t hOffset, size_t width, + size_t height, hipMemcpyKind kind, + hipStream_t stream __dparm(0)); +/** + * @brief Copies data between host and device. + * + * @param[in] dst Destination memory address + * @param[in] srcArray Source array + * @param[in] srcOffset Offset in bytes of source array + * @param[in] count Size of memory copy in bytes + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2DToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpyAtoH(void *dst, hipArray_t srcArray, size_t srcOffset, + size_t count); +/** + * @brief Copies data between host and device. + * + * @param[in] dstArray Destination memory address + * @param[in] dstOffset Offset in bytes of destination array + * @param[in] srcHost Source host pointer + * @param[in] count Size of memory copy in bytes + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2DToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpyHtoA(hipArray_t dstArray, size_t dstOffset, + const void *srcHost, size_t count); +/** + * @brief Copies data between host and device. + * + * @param[in] p 3D memory copy parameters + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2DToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpy3D(const struct hipMemcpy3DParms *p); +/** + * @brief Copies data between host and device asynchronously. + * + * @param[in] p 3D memory copy parameters + * @param[in] stream Stream to use + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2DToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipMemcpy3DAsync(const struct hipMemcpy3DParms *p, + hipStream_t stream __dparm(0)); +/** + * @brief Copies data between host and device. + * + * @param[in] pCopy 3D memory copy parameters + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2DToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipDrvMemcpy3D(const HIP_MEMCPY3D *pCopy); +/** + * @brief Copies data between host and device asynchronously. + * + * @param[in] pCopy 3D memory copy parameters + * @param[in] stream Stream to use + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidPitchValue, + * #hipErrorInvalidDevicePointer, #hipErrorInvalidMemcpyDirection + * + * @see hipMemcpy, hipMemcpy2DToArray, hipMemcpy2D, hipMemcpyFromArray, + * hipMemcpyToSymbol, hipMemcpyAsync + */ +hipError_t hipDrvMemcpy3DAsync(const HIP_MEMCPY3D *pCopy, hipStream_t stream); +// doxygen end Memory +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup PeerToPeer PeerToPeer Device Memory Access + * @{ + * @warning PeerToPeer support is experimental. + * This section describes the PeerToPeer device memory access functions of HIP + *runtime API. + */ +/** + * @brief Determine if a device can access a peer's memory. + * + * @param [out] canAccessPeer Returns the peer access capability (0 or 1) + * @param [in] deviceId - device from where memory may be accessed. + * @param [in] peerDeviceId - device where memory is physically located + * + * Returns "1" in @p canAccessPeer if the specified @p device is capable + * of directly accessing memory physically located on peerDevice , or "0" if + * not. + * + * Returns "0" in @p canAccessPeer if deviceId == peerDeviceId, and both are + * valid devices : a device is not a peer of itself. + * + * @returns #hipSuccess, + * @returns #hipErrorInvalidDevice if deviceId or peerDeviceId are not valid + * devices + */ +hipError_t hipDeviceCanAccessPeer(int *canAccessPeer, int deviceId, + int peerDeviceId); +/** + * @brief Enable direct access from current device's virtual address space to + * memory allocations physically located on a peer device. + * + * Memory which already allocated on peer device will be mapped into the address + * space of the current device. In addition, all future memory allocations on + * peerDeviceId will be mapped into the address space of the current device when + * the memory is allocated. The peer memory remains accessible from the current + * device until a call to hipDeviceDisablePeerAccess or hipDeviceReset. + * + * + * @param [in] peerDeviceId Peer device to enable direct access to from the + * current device + * @param [in] flags Reserved for future use, must be zero + * + * Returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue, + * @returns #hipErrorPeerAccessAlreadyEnabled if peer access is already enabled + * for this device. + */ +hipError_t hipDeviceEnablePeerAccess(int peerDeviceId, unsigned int flags); +/** + * @brief Disable direct access from current device's virtual address space to + * memory allocations physically located on a peer device. + * + * Returns hipErrorPeerAccessNotEnabled if direct access to memory on peerDevice + * has not yet been enabled from the current device. + * + * @param [in] peerDeviceId Peer device to disable direct access to + * + * @returns #hipSuccess, #hipErrorPeerAccessNotEnabled + */ +hipError_t hipDeviceDisablePeerAccess(int peerDeviceId); +/** + * @brief Get information on memory allocations. + * + * @param [out] pbase - BAse pointer address + * @param [out] psize - Size of allocation + * @param [in] dptr- Device Pointer + * + * @returns #hipSuccess, #hipErrorNotFound + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + */ +hipError_t hipMemGetAddressRange(hipDeviceptr_t *pbase, size_t *psize, + hipDeviceptr_t dptr); +#ifndef USE_PEER_NON_UNIFIED +#define USE_PEER_NON_UNIFIED 1 +#endif +#if USE_PEER_NON_UNIFIED == 1 +/** + * @brief Copies memory from one device to memory on another device. + * + * @param [out] dst - Destination device pointer. + * @param [in] dstDeviceId - Destination device + * @param [in] src - Source device pointer + * @param [in] srcDeviceId - Source device + * @param [in] sizeBytes - Size of memory copy in bytes + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidDevice + */ +hipError_t hipMemcpyPeer(void *dst, int dstDeviceId, const void *src, + int srcDeviceId, size_t sizeBytes); +/** + * @brief Copies memory from one device to memory on another device. + * + * @param [out] dst - Destination device pointer. + * @param [in] dstDeviceId - Destination device + * @param [in] src - Source device pointer + * @param [in] srcDevice - Source device + * @param [in] sizeBytes - Size of memory copy in bytes + * @param [in] stream - Stream identifier + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidDevice + */ +hipError_t hipMemcpyPeerAsync(void *dst, int dstDeviceId, const void *src, + int srcDevice, size_t sizeBytes, + hipStream_t stream __dparm(0)); +#endif +// doxygen end PeerToPeer +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Context Context Management [Deprecated] + * @{ + * This section describes the context management functions of HIP runtime API. + * + * @warning + * + * On the AMD platform, context management APIs are deprecated as there are + *better alternate interfaces, such as using hipSetDevice and stream APIs to + *achieve the required functionality. + * + * On the NVIDIA platform, CUDA supports the driver API that defines "Context" + *and "Devices" as separate entities. Each context contains a single device, + *which can theoretically have multiple contexts. HIP initially added limited + *support for these APIs to facilitate easy porting from existing driver codes. + * + * These APIs are only for equivalent driver APIs on the NVIDIA platform. + * + */ + +/** + * @brief Create a context and set it as current/default context + * + * @param [out] ctx Context to create + * @param [in] flags Context creation flags + * @param [in] device device handle + * + * @return #hipSuccess + * + * @see hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, hipCtxGetCurrent, + * hipCtxPushCurrent, hipCtxSetCacheConfig, hipCtxSynchronize, hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxCreate(hipCtx_t *ctx, unsigned int flags, hipDevice_t device); +/** + * @brief Destroy a HIP context. + * + * @param [in] ctx Context to destroy + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @see hipCtxCreate, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent,hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize , hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxDestroy(hipCtx_t ctx); +/** + * @brief Pop the current/default context and return the popped context. + * + * @param [out] ctx The current context to pop + * + * @returns #hipSuccess, #hipErrorInvalidContext + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxSetCurrent, + * hipCtxGetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, hipCtxSynchronize, + * hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxPopCurrent(hipCtx_t *ctx); +/** + * @brief Push the context to be set as current/ default context + * + * @param [in] ctx The current context to push + * + * @returns #hipSuccess, #hipErrorInvalidContext + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, hipCtxSynchronize + * , hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxPushCurrent(hipCtx_t ctx); +/** + * @brief Set the passed context as current/default + * + * @param [in] ctx The context to set as current + * + * @returns #hipSuccess, #hipErrorInvalidContext + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, hipCtxSynchronize + * , hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxSetCurrent(hipCtx_t ctx); +/** + * @brief Get the handle of the current/ default context + * + * @param [out] ctx The context to get as current + * + * @returns #hipSuccess, #hipErrorInvalidContext + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetDevice, hipCtxGetFlags, + * hipCtxPopCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, hipCtxSynchronize, + * hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxGetCurrent(hipCtx_t *ctx); +/** + * @brief Get the handle of the device associated with current/default context + * + * @param [out] device The device from the current context + * + * @returns #hipSuccess, #hipErrorInvalidContext + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, hipCtxSynchronize + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxGetDevice(hipDevice_t *device); +/** + * @brief Returns the approximate HIP api version. + * + * @param [in] ctx Context to check + * @param [out] apiVersion API version to get + * + * @return #hipSuccess + * + * @warning The HIP feature set does not correspond to an exact CUDA SDK api + * revision. This function always set *apiVersion to 4 as an approximation + * though HIP supports some features which were introduced in later CUDA SDK + * revisions. HIP apps code should not rely on the api revision number here and + * should use arch feature flags to test device capabilities or conditional + * compilation. + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetDevice, hipCtxGetFlags, + * hipCtxPopCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, hipCtxSynchronize, + * hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxGetApiVersion(hipCtx_t ctx, int *apiVersion); +/** + * @brief Get Cache configuration for a specific function + * + * @param [out] cacheConfig Cache configuration + * + * @return #hipSuccess + * + * @warning AMD devices and some Nvidia GPUS do not support reconfigurable + * cache. This hint is ignored on those architectures. + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxGetCacheConfig(hipFuncCache_t *cacheConfig); +/** + * @brief Set L1/Shared cache partition. + * + * @param [in] cacheConfig Cache configuration to set + * + * @return #hipSuccess + * + * @warning AMD devices and some Nvidia GPUS do not support reconfigurable + * cache. This hint is ignored on those architectures. + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxSetCacheConfig(hipFuncCache_t cacheConfig); +/** + * @brief Set Shared memory bank configuration. + * + * @param [in] config Shared memory configuration to set + * + * @return #hipSuccess + * + * @warning AMD devices and some Nvidia GPUS do not support shared cache + * banking, and the hint is ignored on those architectures. + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxSetSharedMemConfig(hipSharedMemConfig config); +/** + * @brief Get Shared memory bank configuration. + * + * @param [out] pConfig Pointer of shared memory configuration + * + * @return #hipSuccess + * + * @warning AMD devices and some Nvidia GPUS do not support shared cache + * banking, and the hint is ignored on those architectures. + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxGetSharedMemConfig(hipSharedMemConfig *pConfig); +/** + * @brief Blocks until the default context has completed all preceding requested + * tasks. + * + * @return #hipSuccess + * + * @warning This function waits for all streams on the default context to + * complete execution, and then returns. + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxSynchronize(void); +/** + * @brief Return flags used for creating default context. + * + * @param [out] flags Pointer of flags + * + * @returns #hipSuccess + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxPopCurrent, hipCtxGetCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxGetFlags(unsigned int *flags); +/** + * @brief Enables direct access to memory allocations in a peer context. + * + * Memory which already allocated on peer device will be mapped into the address + * space of the current device. In addition, all future memory allocations on + * peerDeviceId will be mapped into the address space of the current device when + * the memory is allocated. The peer memory remains accessible from the current + * device until a call to hipDeviceDisablePeerAccess or hipDeviceReset. + * + * + * @param [in] peerCtx Peer context + * @param [in] flags flags, need to set as 0 + * + * @returns #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue, + * #hipErrorPeerAccessAlreadyEnabled + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * @warning PeerToPeer support is experimental. + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxEnablePeerAccess(hipCtx_t peerCtx, unsigned int flags); +/** + * @brief Disable direct access from current context's virtual address space to + * memory allocations physically located on a peer context.Disables direct + * access to memory allocations in a peer context and unregisters any registered + * allocations. + * + * Returns #hipErrorPeerAccessNotEnabled if direct access to memory on + * peerDevice has not yet been enabled from the current device. + * + * @param [in] peerCtx Peer context to be disabled + * + * @returns #hipSuccess, #hipErrorPeerAccessNotEnabled + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * @warning PeerToPeer support is experimental. + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * cuCtx driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipCtxDisablePeerAccess(hipCtx_t peerCtx); + +/** + * @brief Get the state of the primary context. + * + * @param [in] dev Device to get primary context flags for + * @param [out] flags Pointer to store flags + * @param [out] active Pointer to store context state; 0 = inactive, 1 = active + * + * @returns #hipSuccess + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipDevicePrimaryCtxGetState(hipDevice_t dev, unsigned int *flags, + int *active); +/** + * @brief Release the primary context on the GPU. + * + * @param [in] dev Device which primary context is released + * + * @returns #hipSuccess + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * @warning This function return #hipSuccess though doesn't release the + * primaryCtx by design on HIP/HCC path. + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipDevicePrimaryCtxRelease(hipDevice_t dev); +/** + * @brief Retain the primary context on the GPU. + * + * @param [out] pctx Returned context handle of the new context + * @param [in] dev Device which primary context is released + * + * @returns #hipSuccess + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipDevicePrimaryCtxRetain(hipCtx_t *pctx, hipDevice_t dev); +/** + * @brief Resets the primary context on the GPU. + * + * @param [in] dev Device which primary context is reset + * + * @returns #hipSuccess + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipDevicePrimaryCtxReset(hipDevice_t dev); +/** + * @brief Set flags for the primary context. + * + * @param [in] dev Device for which the primary context flags are set + * @param [in] flags New flags for the device + * + * @returns #hipSuccess, #hipErrorContextAlreadyInUse + * + * @see hipCtxCreate, hipCtxDestroy, hipCtxGetFlags, hipCtxPopCurrent, + * hipCtxGetCurrent, hipCtxSetCurrent, hipCtxPushCurrent, hipCtxSetCacheConfig, + * hipCtxSynchronize, hipCtxGetDevice + * + * @warning This API is deprecated on the AMD platform, only for equivalent + * driver API on the NVIDIA platform. + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipDevicePrimaryCtxSetFlags(hipDevice_t dev, unsigned int flags); +// doxygen end Context Management +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * + * @defgroup Module Module Management + * @{ + * @ingroup API + * This section describes the module management functions of HIP runtime API. + * + */ +/** + * @brief Loads code object from file into a module the currrent context. + * + * @param [in] fname Filename of code object to load + + * @param [out] module Module + * + * @warning File/memory resources allocated in this function are released only + in hipModuleUnload. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidContext, + #hipErrorFileNotFound, + * #hipErrorOutOfMemory, #hipErrorSharedObjectInitFailed, + #hipErrorNotInitialized + * + */ +hipError_t hipModuleLoad(hipModule_t *module, const char *fname); +/** + * @brief Frees the module + * + * @param [in] module Module to free + * + * @returns #hipSuccess, #hipErrorInvalidResourceHandle + * + * The module is freed, and the code objects associated with it are destroyed. + */ +hipError_t hipModuleUnload(hipModule_t module); +/** + * @brief Function with kname will be extracted if present in module + * + * @param [in] module Module to get function from + * @param [in] kname Pointer to the name of function + * @param [out] function Pointer to function handle + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidContext, + * #hipErrorNotInitialized, #hipErrorNotFound, + */ +hipError_t hipModuleGetFunction(hipFunction_t *function, hipModule_t module, + const char *kname); +/** + * @brief Find out attributes for a given function. + * + * @param [out] attr Attributes of funtion + * @param [in] func Pointer to the function handle + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidDeviceFunction + */ +hipError_t hipFuncGetAttributes(struct hipFuncAttributes *attr, + const void *func); +/** + * @brief Find out a specific attribute for a given function. + * + * @param [out] value Pointer to the value + * @param [in] attrib Attributes of the given funtion + * @param [in] hfunc Function to get attributes from + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidDeviceFunction + */ +hipError_t hipFuncGetAttribute(int *value, hipFunction_attribute attrib, + hipFunction_t hfunc); +/** + * @brief Gets pointer to device entry function that matches entry function + * symbolPtr. + * + * @param [out] functionPtr Device entry function + * @param [in] symbolPtr Pointer to device entry function to search for + * + * @returns #hipSuccess, #hipErrorInvalidDeviceFunction + * + */ +hipError_t hipGetFuncBySymbol(hipFunction_t *functionPtr, + const void *symbolPtr); +/** + * @brief returns the handle of the texture reference with the name from the + * module. + * + * @param [in] hmod Module + * @param [in] name Pointer of name of texture reference + * @param [out] texRef Pointer of texture reference + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorNotFound, + * #hipErrorInvalidValue + */ +hipError_t hipModuleGetTexRef(textureReference **texRef, hipModule_t hmod, + const char *name); +/** + * @brief builds module from code object which resides in host memory. Image is + * pointer to that location. + * + * @param [in] image The pointer to the location of data + * @param [out] module Retuned module + * + * @returns hipSuccess, hipErrorNotInitialized, hipErrorOutOfMemory, + * hipErrorNotInitialized + */ +hipError_t hipModuleLoadData(hipModule_t *module, const void *image); +/** + * @brief builds module from code object which resides in host memory. Image is + * pointer to that location. Options are not used. hipModuleLoadData is called. + * + * @param [in] image The pointer to the location of data + * @param [out] module Retuned module + * @param [in] numOptions Number of options + * @param [in] options Options for JIT + * @param [in] optionValues Option values for JIT + * + * @returns hipSuccess, hipErrorNotInitialized, hipErrorOutOfMemory, + * hipErrorNotInitialized + */ +hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, + unsigned int numOptions, hipJitOption *options, + void **optionValues); +/** + * @brief launches kernel f with launch parameters and shared memory on stream + * with arguments passed to kernelparams or extra + * + * @param [in] f Kernel to launch. + * @param [in] gridDimX X grid dimension specified as multiple of blockDimX. + * @param [in] gridDimY Y grid dimension specified as multiple of blockDimY. + * @param [in] gridDimZ Z grid dimension specified as multiple of blockDimZ. + * @param [in] blockDimX X block dimensions specified in work-items + * @param [in] blockDimY Y grid dimension specified in work-items + * @param [in] blockDimZ Z grid dimension specified in work-items + * @param [in] sharedMemBytes Amount of dynamic shared memory to allocate for + * this kernel. The HIP-Clang compiler provides support for extern shared + * declarations. + * @param [in] stream Stream where the kernel should be dispatched. May be + * 0, in which case th default stream is used with associated synchronization + * rules. + * @param [in] kernelParams Kernel parameters to launch + * @param [in] extra Pointer to kernel arguments. These are passed + * directly to the kernel and must be in the memory layout and alignment + * expected by the kernel. All passed arguments must be naturally aligned + * according to their type. The memory address of each argument should be a + * multiple of its size in bytes. Please refer to hip_porting_driver_api.md for + * sample usage. + * + * Please note, HIP does not support kernel launch with total work items defined + * in dimension with size gridDim x blockDim >= 2^32. So gridDim.x * blockDim.x, + * gridDim.y * blockDim.y and gridDim.z * blockDim.z are always less than 2^32. + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue + */ +hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, + hipStream_t stream, void **kernelParams, + void **extra); +/** + * @brief launches kernel f with launch parameters and shared memory on stream + * with arguments passed to kernelParams, where thread blocks can cooperate and + * synchronize as they execute + * + * @param [in] f Kernel to launch. + * @param [in] gridDimX X grid dimension specified as multiple of + * blockDimX. + * @param [in] gridDimY Y grid dimension specified as multiple of + * blockDimY. + * @param [in] gridDimZ Z grid dimension specified as multiple of + * blockDimZ. + * @param [in] blockDimX X block dimension specified in work-items. + * @param [in] blockDimY Y block dimension specified in work-items. + * @param [in] blockDimZ Z block dimension specified in work-items. + * @param [in] sharedMemBytes Amount of dynamic shared memory to allocate for + * this kernel. The HIP-Clang compiler provides support for extern shared + * declarations. + * @param [in] stream Stream where the kernel should be dispatched. May + * be 0, in which case the default stream is used with associated + * synchronization rules. + * @param [in] kernelParams A list of kernel arguments. + * + * Please note, HIP does not support kernel launch with total work items defined + * in dimension with size gridDim x blockDim >= 2^32. + * + * @returns #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidHandle, #hipErrorInvalidImage, + * #hipErrorInvalidValue, #hipErrorInvalidConfiguration, #hipErrorLaunchFailure, + * #hipErrorLaunchOutOfResources, #hipErrorLaunchTimeOut, + * #hipErrorCooperativeLaunchTooLarge, #hipErrorSharedObjectInitFailed + */ +hipError_t hipModuleLaunchCooperativeKernel( + hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t stream, + void **kernelParams); +/** + * @brief Launches kernels on multiple devices where thread blocks can cooperate + * and synchronize as they execute. + * + * @param [in] launchParamsList List of launch parameters, one per + * device. + * @param [in] numDevices Size of the launchParamsList array. + * @param [in] flags Flags to control launch behavior. + * + * @returns #hipSuccess, #hipErrorDeinitialized, #hipErrorNotInitialized, + * #hipErrorInvalidContext, #hipErrorInvalidHandle, #hipErrorInvalidImage, + * #hipErrorInvalidValue, #hipErrorInvalidConfiguration, + * #hipErrorInvalidResourceHandle, #hipErrorLaunchFailure, + * #hipErrorLaunchOutOfResources, #hipErrorLaunchTimeOut, + * #hipErrorCooperativeLaunchTooLarge, #hipErrorSharedObjectInitFailed + */ +hipError_t hipModuleLaunchCooperativeKernelMultiDevice( + hipFunctionLaunchParams *launchParamsList, unsigned int numDevices, + unsigned int flags); +/** + * @brief launches kernel f with launch parameters and shared memory on stream + * with arguments passed to kernelparams or extra, where thread blocks can + * cooperate and synchronize as they execute + * + * @param [in] f Kernel to launch. + * @param [in] gridDim Grid dimensions specified as multiple of blockDim. + * @param [in] blockDimX Block dimensions specified in work-items + * @param [in] kernelParams A list of kernel arguments + * @param [in] sharedMemBytes Amount of dynamic shared memory to allocate for + * this kernel. The HIP-Clang compiler provides support for extern shared + * declarations. + * @param [in] stream Stream where the kernel should be dispatched. May be + * 0, in which case th default stream is used with associated synchronization + * rules. + * + * Please note, HIP does not support kernel launch with total work items defined + * in dimension with size gridDim x blockDim >= 2^32. + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue, + * #hipErrorCooperativeLaunchTooLarge + */ +hipError_t hipLaunchCooperativeKernel(const void *f, dim3 gridDim, + dim3 blockDimX, void **kernelParams, + unsigned int sharedMemBytes, + hipStream_t stream); +/** + * @brief Launches kernels on multiple devices where thread blocks can cooperate + * and synchronize as they execute. + * + * @param [in] launchParamsList List of launch parameters, one per + * device. + * @param [in] numDevices Size of the launchParamsList array. + * @param [in] flags Flags to control launch behavior. + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue, + * #hipErrorCooperativeLaunchTooLarge + */ +hipError_t +hipLaunchCooperativeKernelMultiDevice(hipLaunchParams *launchParamsList, + int numDevices, unsigned int flags); +/** + * @brief Launches kernels on multiple devices and guarantees all specified + * kernels are dispatched on respective streams before enqueuing any other work + * on the specified streams from any other threads + * + * + * @param [in] launchParamsList List of launch parameters, one per + * device. + * @param [in] numDevices Size of the launchParamsList array. + * @param [in] flags Flags to control launch behavior. + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue + */ +hipError_t hipExtLaunchMultiKernelMultiDevice(hipLaunchParams *launchParamsList, + int numDevices, + unsigned int flags); +// doxygen end Module +/** + * @} + */ + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Occupancy Occupancy + * @{ + * This section describes the occupancy functions of HIP runtime API. + * + */ +/** + * @brief determine the grid and block sizes to achieves maximum occupancy for a + * kernel + * + * @param [out] gridSize minimum grid size for maximum potential + * occupancy + * @param [out] blockSize block size for maximum potential occupancy + * @param [in] f kernel function for which occupancy is + * calulated + * @param [in] dynSharedMemPerBlk dynamic shared memory usage (in bytes) + * intended for each block + * @param [in] blockSizeLimit the maximum block size for the kernel, use 0 + * for no limit + * + * Please note, HIP does not support kernel launch with total work items defined + * in dimension with size gridDim x blockDim >= 2^32. + * + * @returns #hipSuccess, #hipErrorInvalidValue + */ +// TODO - Match CUoccupancyB2DSize +hipError_t hipModuleOccupancyMaxPotentialBlockSize(int *gridSize, + int *blockSize, + hipFunction_t f, + size_t dynSharedMemPerBlk, + int blockSizeLimit); +/** + * @brief determine the grid and block sizes to achieves maximum occupancy for a + * kernel + * + * @param [out] gridSize minimum grid size for maximum potential + * occupancy + * @param [out] blockSize block size for maximum potential occupancy + * @param [in] f kernel function for which occupancy is + * calulated + * @param [in] dynSharedMemPerBlk dynamic shared memory usage (in bytes) + * intended for each block + * @param [in] blockSizeLimit the maximum block size for the kernel, use 0 + * for no limit + * @param [in] flags Extra flags for occupancy calculation (only + * default supported) + * + * Please note, HIP does not support kernel launch with total work items defined + * in dimension with size gridDim x blockDim >= 2^32. + * + * @returns #hipSuccess, #hipErrorInvalidValue + */ +// TODO - Match CUoccupancyB2DSize +hipError_t hipModuleOccupancyMaxPotentialBlockSizeWithFlags( + int *gridSize, int *blockSize, hipFunction_t f, size_t dynSharedMemPerBlk, + int blockSizeLimit, unsigned int flags); +/** + * @brief Returns occupancy for a device function. + * + * @param [out] numBlocks Returned occupancy + * @param [in] f Kernel function (hipFunction) for which + * occupancy is calulated + * @param [in] blockSize Block size the kernel is intended to be + * launched with + * @param [in] dynSharedMemPerBlk Dynamic shared memory usage (in bytes) + * intended for each block + * @returns #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipModuleOccupancyMaxActiveBlocksPerMultiprocessor( + int *numBlocks, hipFunction_t f, int blockSize, size_t dynSharedMemPerBlk); +/** + * @brief Returns occupancy for a device function. + * + * @param [out] numBlocks Returned occupancy + * @param [in] f Kernel function(hipFunction_t) for which + * occupancy is calulated + * @param [in] blockSize Block size the kernel is intended to be + * launched with + * @param [in] dynSharedMemPerBlk Dynamic shared memory usage (in bytes) + * intended for each block + * @param [in] flags Extra flags for occupancy calculation (only + * default supported) + * @returns #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + int *numBlocks, hipFunction_t f, int blockSize, size_t dynSharedMemPerBlk, + unsigned int flags); +/** + * @brief Returns occupancy for a device function. + * + * @param [out] numBlocks Returned occupancy + * @param [in] f Kernel function for which occupancy is + * calulated + * @param [in] blockSize Block size the kernel is intended to be + * launched with + * @param [in] dynSharedMemPerBlk Dynamic shared memory usage (in bytes) + * intended for each block + * @returns #hipSuccess, #hipErrorInvalidDeviceFunction, #hipErrorInvalidValue + */ +hipError_t hipOccupancyMaxActiveBlocksPerMultiprocessor( + int *numBlocks, const void *f, int blockSize, size_t dynSharedMemPerBlk); +/** + * @brief Returns occupancy for a device function. + * + * @param [out] numBlocks Returned occupancy + * @param [in] f Kernel function for which occupancy is + * calulated + * @param [in] blockSize Block size the kernel is intended to be + * launched with + * @param [in] dynSharedMemPerBlk Dynamic shared memory usage (in bytes) + * intended for each block + * @param [in] flags Extra flags for occupancy calculation + * (currently ignored) + * @returns #hipSuccess, #hipErrorInvalidDeviceFunction, #hipErrorInvalidValue + */ +hipError_t hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + int *numBlocks, const void *f, int blockSize, size_t dynSharedMemPerBlk, + unsigned int flags __dparm(hipOccupancyDefault)); +/** + * @brief determine the grid and block sizes to achieves maximum occupancy for a + * kernel + * + * @param [out] gridSize minimum grid size for maximum potential + * occupancy + * @param [out] blockSize block size for maximum potential occupancy + * @param [in] f kernel function for which occupancy is + * calulated + * @param [in] dynSharedMemPerBlk dynamic shared memory usage (in bytes) + * intended for each block + * @param [in] blockSizeLimit the maximum block size for the kernel, use 0 + * for no limit + * + * Please note, HIP does not support kernel launch with total work items defined + * in dimension with size gridDim x blockDim >= 2^32. + * + * @returns #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipOccupancyMaxPotentialBlockSize(int *gridSize, int *blockSize, + const void *f, + size_t dynSharedMemPerBlk, + int blockSizeLimit); +// doxygen end Occupancy +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Profiler Profiler Control[Deprecated] + * @{ + * This section describes the profiler control functions of HIP runtime API. + * + * @warning The cudaProfilerInitialize API format for "configFile" is not + *supported. + * + */ +// TODO - expand descriptions: +/** + * @brief Start recording of profiling information + * When using this API, start the profiler with profiling disabled. + * (--startdisabled) + * @returns #hipErrorNotSupported + * @warning : hipProfilerStart API is deprecated, use roctracer/rocTX instead. + */ +DEPRECATED("use roctracer/rocTX instead") +hipError_t hipProfilerStart(); +/** + * @brief Stop recording of profiling information. + * When using this API, start the profiler with profiling disabled. + * (--startdisabled) + * @returns #hipErrorNotSupported + * @warning hipProfilerStart API is deprecated, use roctracer/rocTX instead. + */ +DEPRECATED("use roctracer/rocTX instead") +hipError_t hipProfilerStop(); +// doxygen end profiler +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Clang Launch API to support the triple-chevron syntax + * @{ + * This section describes the API to support the triple-chevron syntax. + */ +/** + * @brief Configure a kernel launch. + * + * @param [in] gridDim grid dimension specified as multiple of blockDim. + * @param [in] blockDim block dimensions specified in work-items + * @param [in] sharedMem Amount of dynamic shared memory to allocate for this + * kernel. The HIP-Clang compiler provides support for extern shared + * declarations. + * @param [in] stream Stream where the kernel should be dispatched. May be + * 0, in which case the default stream is used with associated synchronization + * rules. + * + * Please note, HIP does not support kernel launch with total work items defined + * in dimension with size gridDim x blockDim >= 2^32. + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue + * + */ +hipError_t hipConfigureCall(dim3 gridDim, dim3 blockDim, + size_t sharedMem __dparm(0), + hipStream_t stream __dparm(0)); +/** + * @brief Set a kernel argument. + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue + * + * @param [in] arg Pointer the argument in host memory. + * @param [in] size Size of the argument. + * @param [in] offset Offset of the argument on the argument stack. + * + */ +hipError_t hipSetupArgument(const void *arg, size_t size, size_t offset); +/** + * @brief Launch a kernel. + * + * @param [in] func Kernel to launch. + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue + * + */ +hipError_t hipLaunchByPtr(const void *func); +/** + * @brief Push configuration of a kernel launch. + * + * @param [in] gridDim grid dimension specified as multiple of blockDim. + * @param [in] blockDim block dimensions specified in work-items + * @param [in] sharedMem Amount of dynamic shared memory to allocate for this + * kernel. The HIP-Clang compiler provides support for extern shared + * declarations. + * @param [in] stream Stream where the kernel should be dispatched. May be + * 0, in which case the default stream is used with associated synchronization + * rules. + * + * Please note, HIP does not support kernel launch with total work items defined + * in dimension with size gridDim x blockDim >= 2^32. + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue + * + */ +hipError_t __hipPushCallConfiguration(dim3 gridDim, dim3 blockDim, + size_t sharedMem __dparm(0), + hipStream_t stream __dparm(0)); +/** + * @brief Pop configuration of a kernel launch. + * + * @param [out] gridDim grid dimension specified as multiple of blockDim. + * @param [out] blockDim block dimensions specified in work-items + * @param [out] sharedMem Amount of dynamic shared memory to allocate for this + * kernel. The HIP-Clang compiler provides support for extern shared + * declarations. + * @param [out] stream Stream where the kernel should be dispatched. May be + * 0, in which case the default stream is used with associated synchronization + * rules. + * + * Please note, HIP does not support kernel launch with total work items defined + * in dimension with size gridDim x blockDim >= 2^32. + * + * Please note, HIP does not support kernel launch with total work items defined + * in dimension with size gridDim x blockDim >= 2^32. + * + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue + * + */ +hipError_t __hipPopCallConfiguration(dim3 *gridDim, dim3 *blockDim, + size_t *sharedMem, hipStream_t *stream); +/** + * @brief C compliant kernel launch API + * + * @param [in] function_address - kernel stub function pointer. + * @param [in] numBlocks - number of blocks + * @param [in] dimBlocks - dimension of a block + * @param [in] args - kernel arguments + * @param [in] sharedMemBytes - Amount of dynamic shared memory to allocate for + * this kernel. The HIP-Clang compiler provides support for extern shared + * declarations. + * @param [in] stream - Stream where the kernel should be dispatched. May be 0, + * in which case th default stream is used with associated synchronization + * rules. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipLaunchKernel(const void *function_address, dim3 numBlocks, + dim3 dimBlocks, void **args, + size_t sharedMemBytes __dparm(0), + hipStream_t stream __dparm(0)); + +/** + * @brief Enqueues a host function call in a stream. + * + * @param [in] stream - The stream to enqueue work in. + * @param [in] fn - The function to call once enqueued preceeding operations are + * complete. + * @param [in] userData - User-specified data to be passed to the function. + * + * @returns #hipSuccess, #hipErrorInvalidResourceHandle, #hipErrorInvalidValue, + * #hipErrorNotSupported + * + * The host function to call in this API will be executed after the preceding + * operations in the stream are complete. The function is a blocking operation + * that blocks operations in the stream that follow it, until the function is + * returned. Event synchronization and internal callback functions make sure + * enqueued operations will execute in order, in the stream. + * + * The host function must not make any HIP API calls. The host function is + * non-reentrant. It must not perform sychronization with any operation that may + * depend on other processing execution but is not enqueued to run earlier in + * the stream. + * + * Host functions that are enqueued respectively in different non-blocking + * streams can run concurrently. + * + * @warning This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipLaunchHostFunc(hipStream_t stream, hipHostFn_t fn, + void *userData); + +/** + * Copies memory for 2D arrays. + * + * @param pCopy - Parameters for the memory copy + * + * @returns #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipDrvMemcpy2DUnaligned(const hip_Memcpy2D *pCopy); +// TODO: Move this to hip_ext.h +/** + * @brief Launches kernel from the pointer address, with arguments and shared + * memory on stream. + * + * @param [in] function_address pointer to the Kernel to launch. + * @param [in] numBlocks number of blocks. + * @param [in] dimBlocks dimension of a block. + * @param [in] args pointer to kernel arguments. + * @param [in] sharedMemBytes Amount of dynamic shared memory to allocate for + * this kernel. HIP-Clang compiler provides support for extern shared + * declarations. + * @param [in] stream Stream where the kernel should be dispatched. + * May be 0, in which case the default stream is used with associated + * synchronization rules. + * @param [in] startEvent If non-null, specified event will be updated to track + * the start time of the kernel launch. The event must be created before calling + * this API. + * @param [in] stopEvent If non-null, specified event will be updated to track + * the stop time of the kernel launch. The event must be created before calling + * this API. + * @param [in] flags The value of hipExtAnyOrderLaunch, signifies if kernel can + * be launched in any order. + * @returns #hipSuccess, #hipErrorNotInitialized, #hipErrorInvalidValue. + * + */ +hipError_t hipExtLaunchKernel(const void *function_address, dim3 numBlocks, + dim3 dimBlocks, void **args, + size_t sharedMemBytes, hipStream_t stream, + hipEvent_t startEvent, hipEvent_t stopEvent, + int flags); +// doxygen end Clang launch +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Texture Texture Management + * @{ + * This section describes the texture management functions of HIP runtime API. + */ + +/** + * @brief Creates a texture object. + * + * @param [out] pTexObject pointer to the texture object to create + * @param [in] pResDesc pointer to resource descriptor + * @param [in] pTexDesc pointer to texture descriptor + * @param [in] pResViewDesc pointer to resource view descriptor + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported, + * #hipErrorOutOfMemory + * + * @note 3D liner filter isn't supported on GFX90A boards, on which the API @p + * hipCreateTextureObject will return hipErrorNotSupported. + * + */ +hipError_t +hipCreateTextureObject(hipTextureObject_t *pTexObject, + const hipResourceDesc *pResDesc, + const hipTextureDesc *pTexDesc, + const struct hipResourceViewDesc *pResViewDesc); + +/** + * @brief Destroys a texture object. + * + * @param [in] textureObject texture object to destroy + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipDestroyTextureObject(hipTextureObject_t textureObject); + +/** + * @brief Gets the channel descriptor in an array. + * + * @param [in] desc pointer to channel format descriptor + * @param [out] array memory array on the device + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipGetChannelDesc(hipChannelFormatDesc *desc, + hipArray_const_t array); + +/** + * @brief Gets resource descriptor for the texture object. + * + * @param [out] pResDesc pointer to resource descriptor + * @param [in] textureObject texture object + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipGetTextureObjectResourceDesc(hipResourceDesc *pResDesc, + hipTextureObject_t textureObject); + +/** + * @brief Gets resource view descriptor for the texture object. + * + * @param [out] pResViewDesc pointer to resource view descriptor + * @param [in] textureObject texture object + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t +hipGetTextureObjectResourceViewDesc(struct hipResourceViewDesc *pResViewDesc, + hipTextureObject_t textureObject); + +/** + * @brief Gets texture descriptor for the texture object. + * + * @param [out] pTexDesc pointer to texture descriptor + * @param [in] textureObject texture object + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipGetTextureObjectTextureDesc(hipTextureDesc *pTexDesc, + hipTextureObject_t textureObject); + +/** + * @brief Creates a texture object. + * + * @param [out] pTexObject pointer to texture object to create + * @param [in] pResDesc pointer to resource descriptor + * @param [in] pTexDesc pointer to texture descriptor + * @param [in] pResViewDesc pointer to resource view descriptor + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipTexObjectCreate(hipTextureObject_t *pTexObject, + const HIP_RESOURCE_DESC *pResDesc, + const HIP_TEXTURE_DESC *pTexDesc, + const HIP_RESOURCE_VIEW_DESC *pResViewDesc); + +/** + * @brief Destroys a texture object. + * + * @param [in] texObject texture object to destroy + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipTexObjectDestroy(hipTextureObject_t texObject); + +/** + * @brief Gets resource descriptor of a texture object. + * + * @param [out] pResDesc pointer to resource descriptor + * @param [in] texObject texture object + * + * @returns #hipSuccess, #hipErrorNotSupported, #hipErrorInvalidValue + * + */ +hipError_t hipTexObjectGetResourceDesc(HIP_RESOURCE_DESC *pResDesc, + hipTextureObject_t texObject); + +/** + * @brief Gets resource view descriptor of a texture object. + * + * @param [out] pResViewDesc pointer to resource view descriptor + * @param [in] texObject texture object + * + * @returns #hipSuccess, #hipErrorNotSupported, #hipErrorInvalidValue + * + */ +hipError_t hipTexObjectGetResourceViewDesc(HIP_RESOURCE_VIEW_DESC *pResViewDesc, + hipTextureObject_t texObject); + +/** + * @brief Gets texture descriptor of a texture object. + * + * @param [out] pTexDesc pointer to texture descriptor + * @param [in] texObject texture object + * + * @returns #hipSuccess, #hipErrorNotSupported, #hipErrorInvalidValue + * + */ +hipError_t hipTexObjectGetTextureDesc(HIP_TEXTURE_DESC *pTexDesc, + hipTextureObject_t texObject); + +/** + * @brief Allocate a mipmapped array on the device. + * + * @param[out] mipmappedArray - Pointer to allocated mipmapped array in device + * memory + * @param[in] desc - Requested channel format + * @param[in] extent - Requested allocation size (width field in + * elements) + * @param[in] numLevels - Number of mipmap levels to allocate + * @param[in] flags - Flags for extensions + * + * @return #hipSuccess, #hipErrorInvalidValue, #hipErrorMemoryAllocation + * + * @note This API is implemented on Windows, under development on Linux. + * + */ +hipError_t hipMallocMipmappedArray(hipMipmappedArray_t *mipmappedArray, + const struct hipChannelFormatDesc *desc, + struct hipExtent extent, + unsigned int numLevels, + unsigned int flags __dparm(0)); + +/** + * @brief Frees a mipmapped array on the device. + * + * @param[in] mipmappedArray - Pointer to mipmapped array to free + * + * @return #hipSuccess, #hipErrorInvalidValue + * + * @note This API is implemented on Windows, under development on Linux. + * + */ +hipError_t hipFreeMipmappedArray(hipMipmappedArray_t mipmappedArray); + +/** + * @brief Gets a mipmap level of a HIP mipmapped array. + * + * @param[out] levelArray - Returned mipmap level HIP array + * @param[in] mipmappedArray - HIP mipmapped array + * @param[in] level - Mipmap level + * + * @return #hipSuccess, #hipErrorInvalidValue + * + * @note This API is implemented on Windows, under development on Linux. + * + */ +hipError_t hipGetMipmappedArrayLevel(hipArray_t *levelArray, + hipMipmappedArray_const_t mipmappedArray, + unsigned int level); + +/** + * @brief Create a mipmapped array. + * + * @param [out] pHandle pointer to mipmapped array + * @param [in] pMipmappedArrayDesc mipmapped array descriptor + * @param [in] numMipmapLevels mipmap level + * + * @returns #hipSuccess, #hipErrorNotSupported, #hipErrorInvalidValue + * + * @note This API is implemented on Windows, under development on Linux. + */ +hipError_t hipMipmappedArrayCreate(hipMipmappedArray_t *pHandle, + HIP_ARRAY3D_DESCRIPTOR *pMipmappedArrayDesc, + unsigned int numMipmapLevels); + +/** + * @brief Destroy a mipmapped array. + * + * @param [out] hMipmappedArray pointer to mipmapped array to destroy + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @note This API is implemented on Windows, under development on Linux. + * + */ +hipError_t hipMipmappedArrayDestroy(hipMipmappedArray_t hMipmappedArray); + +/** + * @brief Get a mipmapped array on a mipmapped level. + * + * @param [in] pLevelArray Pointer of array + * @param [out] hMipMappedArray Pointer of mipmapped array on the requested + * mipmap level + * @param [out] level Mipmap level + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @note This API is implemented on Windows, under development on Linux. + * + */ +hipError_t hipMipmappedArrayGetLevel(hipArray_t *pLevelArray, + hipMipmappedArray_t hMipMappedArray, + unsigned int level); + +/** + * + * @addtogroup TextureD Texture Management [Deprecated] + * @{ + * @ingroup Texture + * This section describes the deprecated texture management functions of HIP + * runtime API. + */ + +/** + * @brief Binds a mipmapped array to a texture. + * + * @param [in] tex pointer to the texture reference to bind + * @param [in] mipmappedArray memory mipmapped array on the device + * @param [in] desc opointer to the channel format + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t +hipBindTextureToMipmappedArray(const textureReference *tex, + hipMipmappedArray_const_t mipmappedArray, + const hipChannelFormatDesc *desc); + +/** + * @brief Gets the texture reference related with the symbol. + * + * @param [out] texref texture reference + * @param [in] symbol pointer to the symbol related with the texture for the + * reference + * + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipGetTextureReference(const textureReference **texref, + const void *symbol); + +/** + * @brief Gets the border color used by a texture reference. + * + * @param [out] pBorderColor Returned Type and Value of RGBA color. + * @param [in] texRef Texture reference. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetBorderColor(float *pBorderColor, + const textureReference *texRef); + +/** + * @brief Gets the array bound to a texture reference. + + * + * @param [in] pArray Returned array. + * @param [in] texRef texture reference. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetArray(hipArray_t *pArray, + const textureReference *texRef); + +/** + * @brief Sets address mode for a texture reference. + * + * @param [in] texRef texture reference. + * @param [in] dim Dimension of the texture. + * @param [in] am Value of the texture address mode. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetAddressMode(textureReference *texRef, int dim, + enum hipTextureAddressMode am); +/** + * @brief Binds an array as a texture reference. + * + * @param [in] tex Pointer texture reference. + * @param [in] array Array to bind. + * @param [in] flags Flags should be set as HIP_TRSA_OVERRIDE_FORMAT, as a + * valid value. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetArray(textureReference *tex, hipArray_const_t array, + unsigned int flags); +/** + * @brief Set filter mode for a texture reference. + * + * @param [in] texRef Pointer texture reference. + * @param [in] fm Value of texture filter mode. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetFilterMode(textureReference *texRef, + enum hipTextureFilterMode fm); +/** + * @brief Set flags for a texture reference. + * + * @param [in] texRef Pointer texture reference. + * @param [in] Flags Value of flags. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetFlags(textureReference *texRef, unsigned int Flags); +/** + * @brief Set format for a texture reference. + * + * @param [in] texRef Pointer texture reference. + * @param [in] fmt Value of format. + * @param [in] NumPackedComponents Number of components per array. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetFormat(textureReference *texRef, hipArray_Format fmt, + int NumPackedComponents); +/** + * @brief Binds a memory area to a texture. + * + * @param [in] offset Offset in bytes. + * @param [in] tex Texture to bind. + * @param [in] devPtr Pointer of memory on the device. + * @param [in] desc Pointer of channel format descriptor. + * @param [in] size Size of memory in bites. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipBindTexture(size_t *offset, const textureReference *tex, + const void *devPtr, const hipChannelFormatDesc *desc, + size_t size __dparm(UINT_MAX)); +/** + * @brief Binds a 2D memory area to a texture. + * + * @param [in] offset Offset in bytes. + * @param [in] tex Texture to bind. + * @param [in] devPtr Pointer of 2D memory area on the device. + * @param [in] desc Pointer of channel format descriptor. + * @param [in] width Width in texel units. + * @param [in] height Height in texel units. + * @param [in] pitch Pitch in bytes. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipBindTexture2D(size_t *offset, const textureReference *tex, + const void *devPtr, + const hipChannelFormatDesc *desc, size_t width, + size_t height, size_t pitch); +/** + * @brief Binds a memory area to a texture. + * + * @param [in] tex Pointer of texture reference. + * @param [in] array Array to bind. + * @param [in] desc Pointer of channel format descriptor. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipBindTextureToArray(const textureReference *tex, + hipArray_const_t array, + const hipChannelFormatDesc *desc); +/** + * @brief Get the offset of the alignment in a texture. + * + * @param [in] offset Offset in bytes. + * @param [in] texref Pointer of texture reference. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipGetTextureAlignmentOffset(size_t *offset, + const textureReference *texref); +/** + * @brief Unbinds a texture. + * + * @param [in] tex Texture to unbind. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipUnbindTexture(const textureReference *tex); +/** + * @brief Gets the address for a texture reference. + * + * @param [out] dev_ptr Pointer of device address. + * @param [in] texRef Pointer of texture reference. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetAddress(hipDeviceptr_t *dev_ptr, + const textureReference *texRef); +/** + * @brief Gets the address mode for a texture reference. + * + * @param [out] pam Pointer of address mode. + * @param [in] texRef Pointer of texture reference. + * @param [in] dim Dimension. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetAddressMode(enum hipTextureAddressMode *pam, + const textureReference *texRef, int dim); +/** + * @brief Gets filter mode for a texture reference. + * + * @param [out] pfm Pointer of filter mode. + * @param [in] texRef Pointer of texture reference. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetFilterMode(enum hipTextureFilterMode *pfm, + const textureReference *texRef); +/** + * @brief Gets flags for a texture reference. + * + * @param [out] pFlags Pointer of flags. + * @param [in] texRef Pointer of texture reference. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetFlags(unsigned int *pFlags, + const textureReference *texRef); +/** + * @brief Gets texture format for a texture reference. + * + * @param [out] pFormat Pointer of the format. + * @param [out] pNumChannels Pointer of number of channels. + * @param [in] texRef Pointer of texture reference. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetFormat(hipArray_Format *pFormat, int *pNumChannels, + const textureReference *texRef); +/** + * @brief Gets the maximum anisotropy for a texture reference. + * + * @param [out] pmaxAnsio Pointer of the maximum anisotropy. + * @param [in] texRef Pointer of texture reference. + * + * @returns #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetMaxAnisotropy(int *pmaxAnsio, + const textureReference *texRef); +/** + * @brief Gets the mipmap filter mode for a texture reference. + * + * @param [out] pfm Pointer of the mipmap filter mode. + * @param [in] texRef Pointer of texture reference. + * + * @returns #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetMipmapFilterMode(enum hipTextureFilterMode *pfm, + const textureReference *texRef); +/** + * @brief Gets the mipmap level bias for a texture reference. + * + * @param [out] pbias Pointer of the mipmap level bias. + * @param [in] texRef Pointer of texture reference. + * + * @returns #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetMipmapLevelBias(float *pbias, + const textureReference *texRef); +/** + * @brief Gets the minimum and maximum mipmap level clamps for a texture + * reference. + * + * @param [out] pminMipmapLevelClamp Pointer of the minimum mipmap level clamp. + * @param [out] pmaxMipmapLevelClamp Pointer of the maximum mipmap level clamp. + * @param [in] texRef Pointer of texture reference. + * + * @returns #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetMipmapLevelClamp(float *pminMipmapLevelClamp, + float *pmaxMipmapLevelClamp, + const textureReference *texRef); +/** + * @brief Gets the mipmapped array bound to a texture reference. + * + * @param [out] pArray Pointer of the mipmapped array. + * @param [in] texRef Pointer of texture reference. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefGetMipMappedArray(hipMipmappedArray_t *pArray, + const textureReference *texRef); +/** + * @brief Sets an bound address for a texture reference. + * + * @param [out] ByteOffset Pointer of the offset in bytes. + * @param [in] texRef Pointer of texture reference. + * @param [in] dptr Pointer of device address to bind. + * @param [in] bytes Size in bytes. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetAddress(size_t *ByteOffset, textureReference *texRef, + hipDeviceptr_t dptr, size_t bytes); +/** + * @brief Set a bind an address as a 2D texture reference. + * + * @param [in] texRef Pointer of texture reference. + * @param [in] desc Pointer of array descriptor. + * @param [in] dptr Pointer of device address to bind. + * @param [in] Pitch Pitch in bytes. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetAddress2D(textureReference *texRef, + const HIP_ARRAY_DESCRIPTOR *desc, + hipDeviceptr_t dptr, size_t Pitch); +/** + * @brief Sets the maximum anisotropy for a texture reference. + * + * @param [in] texRef Pointer of texture reference. + * @param [out] maxAniso Value of the maximum anisotropy. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetMaxAnisotropy(textureReference *texRef, + unsigned int maxAniso); +/** + * @brief Sets border color for a texture reference. + * + * @param [in] texRef Pointer of texture reference. + * @param [in] pBorderColor Pointer of border color. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetBorderColor(textureReference *texRef, + float *pBorderColor); +/** + * @brief Sets mipmap filter mode for a texture reference. + * + * @param [in] texRef Pointer of texture reference. + * @param [in] fm Value of filter mode. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetMipmapFilterMode(textureReference *texRef, + enum hipTextureFilterMode fm); +/** + * @brief Sets mipmap level bias for a texture reference. + * + * @param [in] texRef Pointer of texture reference. + * @param [in] bias Value of mipmap bias. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetMipmapLevelBias(textureReference *texRef, float bias); +/** + * @brief Sets mipmap level clamp for a texture reference. + * + * @param [in] texRef Pointer of texture reference. + * @param [in] minMipMapLevelClamp Value of minimum mipmap level clamp. + * @param [in] maxMipMapLevelClamp Value of maximum mipmap level clamp. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetMipmapLevelClamp(textureReference *texRef, + float minMipMapLevelClamp, + float maxMipMapLevelClamp); +/** + * @brief Binds mipmapped array to a texture reference. + * + * @param [in] texRef Pointer of texture reference to bind. + * @param [in] mipmappedArray Pointer of mipmapped array to bind. + * @param [in] Flags Flags should be set as HIP_TRSA_OVERRIDE_FORMAT, as a + * valid value. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning This API is deprecated. + * + */ +DEPRECATED(DEPRECATED_MSG) +hipError_t hipTexRefSetMipmappedArray(textureReference *texRef, + struct hipMipmappedArray *mipmappedArray, + unsigned int Flags); + +// doxygen end deprecated texture management +/** + * @} + */ + +// doxygen end Texture management +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Runtime Runtime Compilation + * @{ + * This section describes the runtime compilation functions of HIP runtime API. + * + */ +// This group is for HIPrtc + +// doxygen end Runtime +/** + * @} + */ + +/** + * + * @defgroup Callback Callback Activity APIs + * @{ + * This section describes the callback/Activity of HIP runtime API. + */ +/** + * @brief Returns HIP API name by ID. + * + * @param [in] id ID of HIP API + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +const char *hipApiName(uint32_t id); +/** + * @brief Returns kernel name reference by function name. + * + * @param [in] f Name of function + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +const char *hipKernelNameRef(const hipFunction_t f); +/** + * @brief Retrives kernel for a given host pointer, unless stated otherwise. + * + * @param [in] hostFunction Pointer of host function. + * @param [in] stream Stream the kernel is executed on. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +const char *hipKernelNameRefByPtr(const void *hostFunction, hipStream_t stream); +/** + * @brief Returns device ID on the stream. + * + * @param [in] stream Stream of device executed on. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +int hipGetStreamDeviceId(hipStream_t stream); + +// doxygen end Callback +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Graph Graph Management + * @{ + * This section describes the graph management types & functions of HIP runtime + *API. + */ + +/** + * @brief Begins graph capture on a stream. + * + * @param [in] stream - Stream to initiate capture. + * @param [in] mode - Controls the interaction of this capture sequence with + * other API calls that are not safe. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipStreamBeginCapture(hipStream_t stream, hipStreamCaptureMode mode); + +/** +* @brief Begins graph capture on a stream to an existing graph. +* +* @param [in] stream - Stream to initiate capture. +* @param [in] graph - Graph to capture into. +* @param [in] dependencies - Dependencies of the first node captured in the +stream. Can be NULL if +* numDependencies is 0. +* @param [in] dependencyData - Optional array of data associated with each +dependency. +* @param [in] numDependencies - Number of dependencies. +* @param [in] mode - Controls the interaction of this capture sequence with +other API calls that are not safe. +* +* @returns #hipSuccess, #hipErrorInvalidValue +* +* @warning : param "const hipGraphEdgeData* dependencyData" is currently not +supported and has to passed as nullptr. This API is marked as beta, meaning, +while this is feature complete, it is still open to changes and may have +outstanding issues. +* +*/ +hipError_t hipStreamBeginCaptureToGraph(hipStream_t stream, hipGraph_t graph, + const hipGraphNode_t *dependencies, + const hipGraphEdgeData *dependencyData, + size_t numDependencies, + hipStreamCaptureMode mode); + +/** + * @brief Ends capture on a stream, returning the captured graph. + * + * @param [in] stream - Stream to end capture. + * @param [out] pGraph - returns the graph captured. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipStreamEndCapture(hipStream_t stream, hipGraph_t *pGraph); + +/** + * @brief Get capture status of a stream. + * + * @param [in] stream - Stream under capture. + * @param [out] pCaptureStatus - returns current status of the capture. + * @param [out] pId - unique ID of the capture. + * + * @returns #hipSuccess, #hipErrorStreamCaptureImplicit + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipStreamGetCaptureInfo(hipStream_t stream, + hipStreamCaptureStatus *pCaptureStatus, + unsigned long long *pId); + +/** + * @brief Get stream's capture state + * + * @param [in] stream - Stream under capture. + * @param [out] captureStatus_out - returns current status of the capture. + * @param [out] id_out - unique ID of the capture. + * @param [in] graph_out - returns the graph being captured into. + * @param [out] dependencies_out - returns pointer to an array of nodes. + * @param [out] numDependencies_out - returns size of the array returned in + * dependencies_out. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorStreamCaptureImplicit + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipStreamGetCaptureInfo_v2( + hipStream_t stream, hipStreamCaptureStatus *captureStatus_out, + unsigned long long *id_out __dparm(0), hipGraph_t *graph_out __dparm(0), + const hipGraphNode_t **dependencies_out __dparm(0), + size_t *numDependencies_out __dparm(0)); + +/** + * @brief Get stream's capture state + * + * @param [in] stream - Stream under capture. + * @param [out] pCaptureStatus - returns current status of the capture. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorStreamCaptureImplicit + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipStreamIsCapturing(hipStream_t stream, + hipStreamCaptureStatus *pCaptureStatus); + +/** + * @brief Update the set of dependencies in a capturing stream + * + * @param [in] stream Stream under capture. + * @param [in] dependencies pointer to an array of nodes to Add/Replace. + * @param [in] numDependencies size of the array in dependencies. + * @param [in] flags Flag how to update dependency set. Should be one of value + * in enum #hipStreamUpdateCaptureDependenciesFlags + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorIllegalState + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipStreamUpdateCaptureDependencies(hipStream_t stream, + hipGraphNode_t *dependencies, + size_t numDependencies, + unsigned int flags __dparm(0)); + +/** + * @brief Swaps the stream capture mode of a thread. + * + * @param [in] mode - Pointer to mode value to swap with the current mode + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipThreadExchangeStreamCaptureMode(hipStreamCaptureMode *mode); + +/** + * @brief Creates a graph + * + * @param [out] pGraph - pointer to graph to create. + * @param [in] flags - flags for graph creation, must be 0. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorMemoryAllocation + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphCreate(hipGraph_t *pGraph, unsigned int flags); + +/** + * @brief Destroys a graph + * + * @param [in] graph - instance of graph to destroy. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphDestroy(hipGraph_t graph); + +/** + * @brief Adds dependency edges to a graph. + * + * @param [in] graph - instance of the graph to add dependencies. + * @param [in] from - pointer to the graph nodes with dependenties to add from. + * @param [in] to - pointer to the graph nodes to add dependenties to. + * @param [in] numDependencies - the number of dependencies to add. + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphAddDependencies(hipGraph_t graph, const hipGraphNode_t *from, + const hipGraphNode_t *to, + size_t numDependencies); + +/** + * @brief Removes dependency edges from a graph. + * + * @param [in] graph - instance of the graph to remove dependencies. + * @param [in] from - Array of nodes that provide the dependencies. + * @param [in] to - Array of dependent nodes. + * @param [in] numDependencies - the number of dependencies to remove. + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphRemoveDependencies(hipGraph_t graph, + const hipGraphNode_t *from, + const hipGraphNode_t *to, + size_t numDependencies); + +/** + * @brief Returns a graph's dependency edges. + * + * @param [in] graph - instance of the graph to get the edges from. + * @param [out] from - pointer to the graph nodes to return edge endpoints. + * @param [out] to - pointer to the graph nodes to return edge endpoints. + * @param [out] numEdges - returns number of edges. + * @returns #hipSuccess, #hipErrorInvalidValue + * + * from and to may both be NULL, in which case this function only returns the + * number of edges in numEdges. Otherwise, numEdges entries will be filled in. + * If numEdges is higher than the actual number of edges, the remaining entries + * in from and to will be set to NULL, and the number of edges actually returned + * will be written to numEdges + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphGetEdges(hipGraph_t graph, hipGraphNode_t *from, + hipGraphNode_t *to, size_t *numEdges); + +/** + * @brief Returns graph nodes. + * + * @param [in] graph - instance of graph to get the nodes. + * @param [out] nodes - pointer to return the graph nodes. + * @param [out] numNodes - returns number of graph nodes. + * @returns #hipSuccess, #hipErrorInvalidValue + * + * nodes may be NULL, in which case this function will return the number of + * nodes in numNodes. Otherwise, numNodes entries will be filled in. If numNodes + * is higher than the actual number of nodes, the remaining entries in nodes + * will be set to NULL, and the number of nodes actually obtained will be + * returned in numNodes. + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphGetNodes(hipGraph_t graph, hipGraphNode_t *nodes, + size_t *numNodes); + +/** + * @brief Returns graph's root nodes. + * + * @param [in] graph - instance of the graph to get the nodes. + * @param [out] pRootNodes - pointer to return the graph's root nodes. + * @param [out] pNumRootNodes - returns the number of graph's root nodes. + * @returns #hipSuccess, #hipErrorInvalidValue + * + * pRootNodes may be NULL, in which case this function will return the number of + * root nodes in pNumRootNodes. Otherwise, pNumRootNodes entries will be filled + * in. If pNumRootNodes is higher than the actual number of root nodes, the + * remaining entries in pRootNodes will be set to NULL, and the number of nodes + * actually obtained will be returned in pNumRootNodes. + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphGetRootNodes(hipGraph_t graph, hipGraphNode_t *pRootNodes, + size_t *pNumRootNodes); + +/** + * @brief Returns a node's dependencies. + * + * @param [in] node - graph node to get the dependencies from. + * @param [out] pDependencies - pointer to to return the dependencies. + * @param [out] pNumDependencies - returns the number of graph node + * dependencies. + * @returns #hipSuccess, #hipErrorInvalidValue + * + * pDependencies may be NULL, in which case this function will return the number + * of dependencies in pNumDependencies. Otherwise, pNumDependencies entries will + * be filled in. If pNumDependencies is higher than the actual number of + * dependencies, the remaining entries in pDependencies will be set to NULL, and + * the number of nodes actually obtained will be returned in pNumDependencies. + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphNodeGetDependencies(hipGraphNode_t node, + hipGraphNode_t *pDependencies, + size_t *pNumDependencies); + +/** + * @brief Returns a node's dependent nodes. + * + * @param [in] node - graph node to get the Dependent nodes from. + * @param [out] pDependentNodes - pointer to return the graph dependent nodes. + * @param [out] pNumDependentNodes - returns the number of graph node dependent + * nodes. + * @returns #hipSuccess, #hipErrorInvalidValue + * + * DependentNodes may be NULL, in which case this function will return the + * number of dependent nodes in pNumDependentNodes. Otherwise, + * pNumDependentNodes entries will be filled in. If pNumDependentNodes is higher + * than the actual number of dependent nodes, the remaining entries in + * pDependentNodes will be set to NULL, and the number of nodes actually + * obtained will be returned in pNumDependentNodes. + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphNodeGetDependentNodes(hipGraphNode_t node, + hipGraphNode_t *pDependentNodes, + size_t *pNumDependentNodes); + +/** + * @brief Returns a node's type. + * + * @param [in] node - instance of the graph to add dependencies. + * @param [out] pType - pointer to the return the type + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphNodeGetType(hipGraphNode_t node, hipGraphNodeType *pType); + +/** + * @brief Remove a node from the graph. + * + * @param [in] node - graph node to remove + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphDestroyNode(hipGraphNode_t node); + +/** + * @brief Clones a graph. + * + * @param [out] pGraphClone - Returns newly created cloned graph. + * @param [in] originalGraph - original graph to clone from. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorMemoryAllocation + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphClone(hipGraph_t *pGraphClone, hipGraph_t originalGraph); + +/** + * @brief Finds a cloned version of a node. + * + * @param [out] pNode - Returns the cloned node. + * @param [in] originalNode - original node handle. + * @param [in] clonedGraph - Cloned graph to query. + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphNodeFindInClone(hipGraphNode_t *pNode, + hipGraphNode_t originalNode, + hipGraph_t clonedGraph); + +/** + * @brief Creates an executable graph from a graph + * + * @param [out] pGraphExec - pointer to instantiated executable graph that is + * created. + * @param [in] graph - instance of graph to instantiate. + * @param [out] pErrorNode - pointer to error node in case error occured in + * graph instantiation, it could modify the correponding node. + * @param [out] pLogBuffer - pointer to log buffer. + * @param [out] bufferSize - the size of log buffer. + * + * @returns #hipSuccess, #hipErrorOutOfMemory + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + */ +hipError_t hipGraphInstantiate(hipGraphExec_t *pGraphExec, hipGraph_t graph, + hipGraphNode_t *pErrorNode, char *pLogBuffer, + size_t bufferSize); + +/** + * @brief Creates an executable graph from a graph. + * + * @param [out] pGraphExec - pointer to instantiated executable graph that is + * created. + * @param [in] graph - instance of graph to instantiate. + * @param [in] flags - Flags to control instantiation. + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues.It does + * not support any of flag and is behaving as hipGraphInstantiate. + */ +hipError_t hipGraphInstantiateWithFlags(hipGraphExec_t *pGraphExec, + hipGraph_t graph, + unsigned long long flags); + +/** + * @brief Creates an executable graph from a graph. + * + * @param [out] pGraphExec - pointer to instantiated executable graph that is + * created. + * @param [in] graph - instance of graph to instantiate. + * @param [in] instantiateParams - Graph Instantiate Params + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t +hipGraphInstantiateWithParams(hipGraphExec_t *pGraphExec, hipGraph_t graph, + hipGraphInstantiateParams *instantiateParams); +/** + * @brief launches an executable graph in a stream + * + * @param [in] graphExec - instance of executable graph to launch. + * @param [in] stream - instance of stream in which to launch executable graph. + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphLaunch(hipGraphExec_t graphExec, hipStream_t stream); + +/** + * @brief uploads an executable graph in a stream + * + * @param [in] graphExec - instance of executable graph to launch. + * @param [in] stream - instance of stream in which to launch executable graph. + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphUpload(hipGraphExec_t graphExec, hipStream_t stream); + +/** + * @brief Creates a kernel execution node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to graph node to create. + * @param [in] graph - instance of graph to add the created node. + * @param [in] pDependencies - pointer to the dependencies on the kernel + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] nodeParams - pointer to the parameters for the node. + * @returns #hipSuccess, #hipErrorInvalidValue. + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddNode(hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, + hipGraphNodeParams *nodeParams); + +/** + * @brief Destroys an executable graph + * + * @param [in] graphExec - instance of executable graph to destry. + * + * @returns #hipSuccess. + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecDestroy(hipGraphExec_t graphExec); + +// Check whether an executable graph can be updated with a graph and perform the +// update if possible. +/** + * @brief Check whether an executable graph can be updated with a graph and + * perform the update if * possible. + * + * @param [in] hGraphExec - instance of executable graph to update. + * @param [in] hGraph - graph that contains the updated parameters. + * @param [in] hErrorNode_out - node which caused the permissibility check to + * forbid the update. + * @param [in] updateResult_out - Whether the graph update was permitted. + * @returns #hipSuccess, #hipErrorGraphExecUpdateFailure + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecUpdate(hipGraphExec_t hGraphExec, hipGraph_t hGraph, + hipGraphNode_t *hErrorNode_out, + hipGraphExecUpdateResult *updateResult_out); + +/** + * @brief Creates a kernel execution node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to graph node to create. + * @param [in] graph - instance of graph to add the created node. + * @param [in] pDependencies - pointer to the dependencies on the kernel + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] pNodeParams - pointer to the parameters to the kernel execution + * node on the GPU. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorInvalidDeviceFunction + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddKernelNode(hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, + const hipKernelNodeParams *pNodeParams); + +/** + * @brief Gets kernel node's parameters. + * + * @param [in] node - instance of the node to get parameters from. + * @param [out] pNodeParams - pointer to the parameters + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphKernelNodeGetParams(hipGraphNode_t node, + hipKernelNodeParams *pNodeParams); + +/** + * @brief Sets a kernel node's parameters. + * + * @param [in] node - instance of the node to set parameters to. + * @param [in] pNodeParams - const pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphKernelNodeSetParams(hipGraphNode_t node, + const hipKernelNodeParams *pNodeParams); + +/** + * @brief Sets the parameters for a kernel node in the given graphExec. + * + * @param [in] hGraphExec - instance of the executable graph with the node. + * @param [in] node - instance of the node to set parameters to. + * @param [in] pNodeParams - const pointer to the kernel node parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t +hipGraphExecKernelNodeSetParams(hipGraphExec_t hGraphExec, hipGraphNode_t node, + const hipKernelNodeParams *pNodeParams); + +/** + * @brief Creates a memcpy node and adds it to a graph. + * + * @param [out] phGraphNode - pointer to graph node to create. + * @param [in] hGraph - instance of graph to add the created node. + * @param [in] dependencies - const pointer to the dependencies on the memcpy + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] copyParams - const pointer to the parameters for the memory copy. + * @param [in] ctx - cotext related to current device. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipDrvGraphAddMemcpyNode(hipGraphNode_t *phGraphNode, + hipGraph_t hGraph, + const hipGraphNode_t *dependencies, + size_t numDependencies, + const HIP_MEMCPY3D *copyParams, + hipCtx_t ctx); +/** + * @brief Creates a memcpy node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to graph node to create. + * @param [in] graph - instance of graph to add the created node. + * @param [in] pDependencies - const pointer to the dependencies on the memcpy + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] pCopyParams - const pointer to the parameters for the memory + * copy. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddMemcpyNode(hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, + const hipMemcpy3DParms *pCopyParams); +/** + * @brief Gets a memcpy node's parameters. + * + * @param [in] node - instance of the node to get parameters from. + * @param [out] pNodeParams - pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphMemcpyNodeGetParams(hipGraphNode_t node, + hipMemcpy3DParms *pNodeParams); + +/** + * @brief Sets a memcpy node's parameters. + * + * @param [in] node - instance of the node to set parameters to. + * @param [in] pNodeParams - const pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphMemcpyNodeSetParams(hipGraphNode_t node, + const hipMemcpy3DParms *pNodeParams); + +/** + * @brief Sets a node attribute. + * + * @param [in] hNode - instance of the node to set parameters to. + * @param [in] attr - the attribute node is set to. + * @param [in] value - const pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphKernelNodeSetAttribute(hipGraphNode_t hNode, + hipKernelNodeAttrID attr, + const hipKernelNodeAttrValue *value); +/** + * @brief Gets a node attribute. + * + * @param [in] hNode - instance of the node to set parameters to. + * @param [in] attr - the attribute node is set to. + * @param [in] value - const pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphKernelNodeGetAttribute(hipGraphNode_t hNode, + hipKernelNodeAttrID attr, + hipKernelNodeAttrValue *value); +/** + * @brief Sets the parameters for a memcpy node in the given graphExec. + * + * @param [in] hGraphExec - instance of the executable graph with the node. + * @param [in] node - instance of the node to set parameters to. + * @param [in] pNodeParams - const pointer to the kernel node parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecMemcpyNodeSetParams(hipGraphExec_t hGraphExec, + hipGraphNode_t node, + hipMemcpy3DParms *pNodeParams); + +/** + * @brief Creates a 1D memcpy node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to graph node to create. + * @param [in] graph - instance of graph to add the created node. + * @param [in] pDependencies - const pointer to the dependencies on the memcpy + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] dst - pointer to memory address to the destination. + * @param [in] src - pointer to memory address to the source. + * @param [in] count - the size of the memory to copy. + * @param [in] kind - the type of memory copy. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddMemcpyNode1D(hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, void *dst, + const void *src, size_t count, + hipMemcpyKind kind); + +/** + * @brief Sets a memcpy node's parameters to perform a 1-dimensional copy. + * + * @param [in] node - instance of the node to set parameters to. + * @param [in] dst - pointer to memory address to the destination. + * @param [in] src - pointer to memory address to the source. + * @param [in] count - the size of the memory to copy. + * @param [in] kind - the type of memory copy. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphMemcpyNodeSetParams1D(hipGraphNode_t node, void *dst, + const void *src, size_t count, + hipMemcpyKind kind); + +/** + * @brief Sets the parameters for a memcpy node in the given graphExec to + * perform a 1-dimensional copy. + * + * @param [in] hGraphExec - instance of the executable graph with the node. + * @param [in] node - instance of the node to set parameters to. + * @param [in] dst - pointer to memory address to the destination. + * @param [in] src - pointer to memory address to the source. + * @param [in] count - the size of the memory to copy. + * @param [in] kind - the type of memory copy. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecMemcpyNodeSetParams1D(hipGraphExec_t hGraphExec, + hipGraphNode_t node, void *dst, + const void *src, size_t count, + hipMemcpyKind kind); + +/** + * @brief Creates a memcpy node to copy from a symbol on the device and adds it + * to a graph. + * + * @param [out] pGraphNode - pointer to graph node to create. + * @param [in] graph - instance of graph to add the created node. + * @param [in] pDependencies - const pointer to the dependencies on the memcpy + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] dst - pointer to memory address to the destination. + * @param [in] symbol - Device symbol address. + * @param [in] count - the size of the memory to copy. + * @param [in] offset - Offset from start of symbol in bytes. + * @param [in] kind - the type of memory copy. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddMemcpyNodeFromSymbol(hipGraphNode_t *pGraphNode, + hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, void *dst, + const void *symbol, size_t count, + size_t offset, hipMemcpyKind kind); + +/** + * @brief Sets a memcpy node's parameters to copy from a symbol on the device. + * + * @param [in] node - instance of the node to set parameters to. + * @param [in] dst - pointer to memory address to the destination. + * @param [in] symbol - Device symbol address. + * @param [in] count - the size of the memory to copy. + * @param [in] offset - Offset from start of symbol in bytes. + * @param [in] kind - the type of memory copy. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphMemcpyNodeSetParamsFromSymbol(hipGraphNode_t node, void *dst, + const void *symbol, + size_t count, size_t offset, + hipMemcpyKind kind); + +/** + * @brief Sets the parameters for a memcpy node in the given graphExec to copy + * from a symbol on the + * * device. + * + * @param [in] hGraphExec - instance of the executable graph with the node. + * @param [in] node - instance of the node to set parameters to. + * @param [in] dst - pointer to memory address to the destination. + * @param [in] symbol - Device symbol address. + * @param [in] count - the size of the memory to copy. + * @param [in] offset - Offset from start of symbol in bytes. + * @param [in] kind - the type of memory copy. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecMemcpyNodeSetParamsFromSymbol( + hipGraphExec_t hGraphExec, hipGraphNode_t node, void *dst, + const void *symbol, size_t count, size_t offset, hipMemcpyKind kind); + +/** + * @brief Creates a memcpy node to copy to a symbol on the device and adds it to + * a graph. + * + * @param [out] pGraphNode - pointer to graph node to create. + * @param [in] graph - instance of graph to add the created node. + * @param [in] pDependencies - const pointer to the dependencies on the memcpy + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] symbol - Device symbol address. + * @param [in] src - pointer to memory address of the src. + * @param [in] count - the size of the memory to copy. + * @param [in] offset - Offset from start of symbol in bytes. + * @param [in] kind - the type of memory copy. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddMemcpyNodeToSymbol(hipGraphNode_t *pGraphNode, + hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, + const void *symbol, const void *src, + size_t count, size_t offset, + hipMemcpyKind kind); + +/** + * @brief Sets a memcpy node's parameters to copy to a symbol on the device. + * + * @param [in] node - instance of the node to set parameters to. + * @param [in] symbol - Device symbol address. + * @param [in] src - pointer to memory address of the src. + * @param [in] count - the size of the memory to copy. + * @param [in] offset - Offset from start of symbol in bytes. + * @param [in] kind - the type of memory copy. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphMemcpyNodeSetParamsToSymbol(hipGraphNode_t node, + const void *symbol, + const void *src, size_t count, + size_t offset, + hipMemcpyKind kind); + +/** + * @brief Sets the parameters for a memcpy node in the given graphExec to copy + * to a symbol on the device. + * @param [in] hGraphExec - instance of the executable graph with the node. + * @param [in] node - instance of the node to set parameters to. + * @param [in] symbol - Device symbol address. + * @param [in] src - pointer to memory address of the src. + * @param [in] count - the size of the memory to copy. + * @param [in] offset - Offset from start of symbol in bytes. + * @param [in] kind - the type of memory copy. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecMemcpyNodeSetParamsToSymbol( + hipGraphExec_t hGraphExec, hipGraphNode_t node, const void *symbol, + const void *src, size_t count, size_t offset, hipMemcpyKind kind); + +/** + * @brief Creates a memset node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to the graph node to create. + * @param [in] graph - instance of the graph to add the created node. + * @param [in] pDependencies - const pointer to the dependencies on the memset + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] pMemsetParams - const pointer to the parameters for the memory + * set. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddMemsetNode(hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, + const hipMemsetParams *pMemsetParams); + +/** + * @brief Gets a memset node's parameters. + * + * @param [in] node - instane of the node to get parameters from. + * @param [out] pNodeParams - pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphMemsetNodeGetParams(hipGraphNode_t node, + hipMemsetParams *pNodeParams); + +/** + * @brief Sets a memset node's parameters. + * + * @param [in] node - instance of the node to set parameters to. + * @param [in] pNodeParams - pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphMemsetNodeSetParams(hipGraphNode_t node, + const hipMemsetParams *pNodeParams); + +/** + * @brief Sets the parameters for a memset node in the given graphExec. + * + * @param [in] hGraphExec - instance of the executable graph with the node. + * @param [in] node - instance of the node to set parameters to. + * @param [in] pNodeParams - pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecMemsetNodeSetParams(hipGraphExec_t hGraphExec, + hipGraphNode_t node, + const hipMemsetParams *pNodeParams); + +/** + * @brief Creates a host execution node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to the graph node to create. + * @param [in] graph - instance of the graph to add the created node. + * @param [in] pDependencies - const pointer to the dependencies on the memset + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] pNodeParams -pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddHostNode(hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, + const hipHostNodeParams *pNodeParams); + +/** + * @brief Returns a host node's parameters. + * + * @param [in] node - instane of the node to get parameters from. + * @param [out] pNodeParams - pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphHostNodeGetParams(hipGraphNode_t node, + hipHostNodeParams *pNodeParams); + +/** + * @brief Sets a host node's parameters. + * + * @param [in] node - instance of the node to set parameters to. + * @param [in] pNodeParams - pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphHostNodeSetParams(hipGraphNode_t node, + const hipHostNodeParams *pNodeParams); + +/** + * @brief Sets the parameters for a host node in the given graphExec. + * + * @param [in] hGraphExec - instance of the executable graph with the node. + * @param [in] node - instance of the node to set parameters to. + * @param [in] pNodeParams - pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecHostNodeSetParams(hipGraphExec_t hGraphExec, + hipGraphNode_t node, + const hipHostNodeParams *pNodeParams); + +/** + * @brief Creates a child graph node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to the graph node to create. + * @param [in] graph - instance of the graph to add the created node. + * @param [in] pDependencies - const pointer to the dependencies on the memset + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] childGraph - the graph to clone into this node + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddChildGraphNode(hipGraphNode_t *pGraphNode, + hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, + hipGraph_t childGraph); + +/** + * @brief Gets a handle to the embedded graph of a child graph node. + * + * @param [in] node - instane of the node to get child graph. + * @param [out] pGraph - pointer to get the graph. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphChildGraphNodeGetGraph(hipGraphNode_t node, + hipGraph_t *pGraph); + +/** + * @brief Updates node parameters in the child graph node in the given + * graphExec. + * + * @param [in] hGraphExec - instance of the executable graph with the node. + * @param [in] node - node from the graph which was used to instantiate + * graphExec. + * @param [in] childGraph - child graph with updated parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecChildGraphNodeSetParams(hipGraphExec_t hGraphExec, + hipGraphNode_t node, + hipGraph_t childGraph); + +/** + * @brief Creates an empty node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to the graph node to create and add to the + * graph. + * @param [in] graph - instane of the graph the node is add to. + * @param [in] pDependencies - const pointer to the node dependenties. + * @param [in] numDependencies - the number of dependencies. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddEmptyNode(hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies); + +/** + * @brief Creates an event record node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to the graph node to create and add to the + * graph. + * @param [in] graph - instane of the graph the node to be added. + * @param [in] pDependencies - const pointer to the node dependenties. + * @param [in] numDependencies - the number of dependencies. + * @param [in] event - Event for the node. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddEventRecordNode(hipGraphNode_t *pGraphNode, + hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, hipEvent_t event); + +/** + * @brief Returns the event associated with an event record node. + * + * @param [in] node - instane of the node to get event from. + * @param [out] event_out - Pointer to return the event. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphEventRecordNodeGetEvent(hipGraphNode_t node, + hipEvent_t *event_out); + +/** + * @brief Sets an event record node's event. + * + * @param [in] node - instane of the node to set event to. + * @param [in] event - pointer to the event. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphEventRecordNodeSetEvent(hipGraphNode_t node, + hipEvent_t event); + +/** + * @brief Sets the event for an event record node in the given graphExec. + * + * @param [in] hGraphExec - instance of the executable graph with the node. + * @param [in] hNode - node from the graph which was used to instantiate + * graphExec. + * @param [in] event - pointer to the event. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecEventRecordNodeSetEvent(hipGraphExec_t hGraphExec, + hipGraphNode_t hNode, + hipEvent_t event); + +/** + * @brief Creates an event wait node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to the graph node to create and add to the + * graph. + * @param [in] graph - instane of the graph the node to be added. + * @param [in] pDependencies - const pointer to the node dependenties. + * @param [in] numDependencies - the number of dependencies. + * @param [in] event - Event for the node. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddEventWaitNode(hipGraphNode_t *pGraphNode, + hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, hipEvent_t event); + +/** + * @brief Returns the event associated with an event wait node. + * + * @param [in] node - instane of the node to get event from. + * @param [out] event_out - Pointer to return the event. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphEventWaitNodeGetEvent(hipGraphNode_t node, + hipEvent_t *event_out); + +/** + * @brief Sets an event wait node's event. + * + * @param [in] node - instane of the node to set event to. + * @param [in] event - pointer to the event. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphEventWaitNodeSetEvent(hipGraphNode_t node, hipEvent_t event); + +/** + * @brief Sets the event for an event record node in the given graphExec. + * + * @param [in] hGraphExec - instance of the executable graph with the node. + * @param [in] hNode - node from the graph which was used to instantiate + * graphExec. + * @param [in] event - pointer to the event. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecEventWaitNodeSetEvent(hipGraphExec_t hGraphExec, + hipGraphNode_t hNode, + hipEvent_t event); + +/** + * @brief Creates a memory allocation node and adds it to a graph + * + * @param [out] pGraphNode - Pointer to the graph node to create and add to + * the graph + * @param [in] graph - Instane of the graph the node to be added + * @param [in] pDependencies - Const pointer to the node dependenties + * @param [in] numDependencies - The number of dependencies + * @param [in] pNodeParams - Node parameters for memory allocation + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddMemAllocNode(hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, + hipMemAllocNodeParams *pNodeParams); + +/** + * @brief Returns parameters for memory allocation node + * + * @param [in] node - Memory allocation node for a query + * @param [out] pNodeParams - Parameters for the specified memory allocation + * node + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphMemAllocNodeGetParams(hipGraphNode_t node, + hipMemAllocNodeParams *pNodeParams); + +/** + * @brief Creates a memory free node and adds it to a graph + * + * @param [out] pGraphNode - Pointer to the graph node to create and add to + * the graph + * @param [in] graph - Instane of the graph the node to be added + * @param [in] pDependencies - Const pointer to the node dependenties + * @param [in] numDependencies - The number of dependencies + * @param [in] dev_ptr - Pointer to the memory to be freed + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddMemFreeNode(hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, + size_t numDependencies, void *dev_ptr); + +/** + * @brief Returns parameters for memory free node + * + * @param [in] node - Memory free node for a query + * @param [out] dev_ptr - Device pointer for the specified memory free node + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphMemFreeNodeGetParams(hipGraphNode_t node, void *dev_ptr); + +/** + * @brief Get the mem attribute for graphs. + * + * @param [in] device - device the attr is get for. + * @param [in] attr - attr to get. + * @param [out] value - value for specific attr. + * @returns #hipSuccess, #hipErrorInvalidDevice + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipDeviceGetGraphMemAttribute(int device, + hipGraphMemAttributeType attr, + void *value); + +/** + * @brief Set the mem attribute for graphs. + * + * @param [in] device - device the attr is set for. + * @param [in] attr - attr to set. + * @param [in] value - value for specific attr. + * @returns #hipSuccess, #hipErrorInvalidDevice + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipDeviceSetGraphMemAttribute(int device, + hipGraphMemAttributeType attr, + void *value); + +/** + * @brief Free unused memory on specific device used for graph back to OS. + * + * @param [in] device - device the memory is used for graphs + * @returns #hipSuccess, #hipErrorInvalidDevice + * + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipDeviceGraphMemTrim(int device); + +/** + * @brief Create an instance of userObject to manage lifetime of a resource. + * + * @param [out] object_out - pointer to instace of userobj. + * @param [in] ptr - pointer to pass to destroy function. + * @param [in] destroy - destroy callback to remove resource. + * @param [in] initialRefcount - reference to resource. + * @param [in] flags - flags passed to API. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipUserObjectCreate(hipUserObject_t *object_out, void *ptr, + hipHostFn_t destroy, + unsigned int initialRefcount, + unsigned int flags); + +/** + * @brief Release number of references to resource. + * + * @param [in] object - pointer to instace of userobj. + * @param [in] count - reference to resource to be retained. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipUserObjectRelease(hipUserObject_t object, + unsigned int count __dparm(1)); + +/** + * @brief Retain number of references to resource. + * + * @param [in] object - pointer to instace of userobj. + * @param [in] count - reference to resource to be retained. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipUserObjectRetain(hipUserObject_t object, + unsigned int count __dparm(1)); + +/** + * @brief Retain user object for graphs. + * + * @param [in] graph - pointer to graph to retain the user object for. + * @param [in] object - pointer to instace of userobj. + * @param [in] count - reference to resource to be retained. + * @param [in] flags - flags passed to API. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphRetainUserObject(hipGraph_t graph, hipUserObject_t object, + unsigned int count __dparm(1), + unsigned int flags __dparm(0)); + +/** + * @brief Release user object from graphs. + * + * @param [in] graph - pointer to graph to retain the user object for. + * @param [in] object - pointer to instace of userobj. + * @param [in] count - reference to resource to be retained. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphReleaseUserObject(hipGraph_t graph, hipUserObject_t object, + unsigned int count __dparm(1)); + +/** + * @brief Write a DOT file describing graph structure. + * + * @param [in] graph - graph object for which DOT file has to be generated. + * @param [in] path - path to write the DOT file. + * @param [in] flags - Flags from hipGraphDebugDotFlags to get additional node + * information. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorOperatingSystem + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphDebugDotPrint(hipGraph_t graph, const char *path, + unsigned int flags); + +/** + * @brief Copies attributes from source node to destination node. + * + * Copies attributes from source node to destination node. + * Both node must have the same context. + * + * @param [out] hDst - Destination node. + * @param [in] hSrc - Source node. + * For list of attributes see ::hipKernelNodeAttrID. + * + * @returns #hipSuccess, #hipErrorInvalidContext + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphKernelNodeCopyAttributes(hipGraphNode_t hSrc, + hipGraphNode_t hDst); + +/** + * @brief Enables or disables the specified node in the given graphExec + * + * Sets hNode to be either enabled or disabled. Disabled nodes are functionally + * equivalent to empty nodes until they are reenabled. Existing node parameters + * are not affected by disabling/enabling the node. + * + * The node is identified by the corresponding hNode in the non-executable + * graph, from which the executable graph was instantiated. + * + * hNode must not have been removed from the original graph. + * + * @note Currently only kernel, memset and memcpy nodes are supported. + * + * @param [in] hGraphExec - The executable graph in which to set the specified + * node. + * @param [in] hNode - Node from the graph from which graphExec was + * instantiated. + * @param [in] isEnabled - Node is enabled if != 0, otherwise the node is + * disabled. + * + * @returns #hipSuccess, #hipErrorInvalidValue, + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphNodeSetEnabled(hipGraphExec_t hGraphExec, + hipGraphNode_t hNode, unsigned int isEnabled); +/** + * @brief Query whether a node in the given graphExec is enabled + * + * Sets isEnabled to 1 if hNode is enabled, or 0 if it is disabled. + * + * The node is identified by the corresponding node in the non-executable graph, + * from which the executable graph was instantiated. + * + * hNode must not have been removed from the original graph. + * + * @note Currently only kernel, memset and memcpy nodes are supported. + * + * @param [in] hGraphExec - The executable graph in which to set the specified + * node. + * @param [in] hNode - Node from the graph from which graphExec was + * instantiated. + * @param [out] isEnabled - Location to return the enabled status of the node. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphNodeGetEnabled(hipGraphExec_t hGraphExec, + hipGraphNode_t hNode, + unsigned int *isEnabled); + +/** + * @brief Creates a external semaphor wait node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to the graph node to create. + * @param [in] graph - instance of the graph to add the created node. + * @param [in] pDependencies - const pointer to the dependencies on the memset + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] nodeParams -pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddExternalSemaphoresWaitNode( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + const hipExternalSemaphoreWaitNodeParams *nodeParams); + +/** + * @brief Creates a external semaphor signal node and adds it to a graph. + * + * @param [out] pGraphNode - pointer to the graph node to create. + * @param [in] graph - instance of the graph to add the created node. + * @param [in] pDependencies - const pointer to the dependencies on the memset + * execution node. + * @param [in] numDependencies - the number of the dependencies. + * @param [in] nodeParams -pointer to the parameters. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphAddExternalSemaphoresSignalNode( + hipGraphNode_t *pGraphNode, hipGraph_t graph, + const hipGraphNode_t *pDependencies, size_t numDependencies, + const hipExternalSemaphoreSignalNodeParams *nodeParams); +/** + * @brief Updates node parameters in the external semaphore signal node. + * + * @param [in] hNode - Node from the graph from which graphExec was + * instantiated. + * @param [in] nodeParams - Pointer to the params to be set. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExternalSemaphoresSignalNodeSetParams( + hipGraphNode_t hNode, + const hipExternalSemaphoreSignalNodeParams *nodeParams); +/** + * @brief Updates node parameters in the external semaphore wait node. + * + * @param [in] hNode - Node from the graph from which graphExec was + * instantiated. + * @param [in] nodeParams - Pointer to the params to be set. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExternalSemaphoresWaitNodeSetParams( + hipGraphNode_t hNode, const hipExternalSemaphoreWaitNodeParams *nodeParams); +/** + * @brief Returns external semaphore signal node params. + * + * @param [in] hNode - Node from the graph from which graphExec was + * instantiated. + * @param [out] params_out - Pointer to params. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExternalSemaphoresSignalNodeGetParams( + hipGraphNode_t hNode, hipExternalSemaphoreSignalNodeParams *params_out); +/** + * @brief Returns external semaphore wait node params. + * + * @param [in] hNode - Node from the graph from which graphExec was + * instantiated. + * @param [out] params_out - Pointer to params. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExternalSemaphoresWaitNodeGetParams( + hipGraphNode_t hNode, hipExternalSemaphoreWaitNodeParams *params_out); +/** + * @brief Updates node parameters in the external semaphore signal node in the + * given graphExec. + * + * @param [in] hGraphExec - The executable graph in which to set the specified + * node. + * @param [in] hNode - Node from the graph from which graphExec was + * instantiated. + * @param [in] nodeParams - Pointer to the params to be set. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecExternalSemaphoresSignalNodeSetParams( + hipGraphExec_t hGraphExec, hipGraphNode_t hNode, + const hipExternalSemaphoreSignalNodeParams *nodeParams); +/** + * @brief Updates node parameters in the external semaphore wait node in the + * given graphExec. + * + * @param [in] hGraphExec - The executable graph in which to set the specified + * node. + * @param [in] hNode - Node from the graph from which graphExec was + * instantiated. + * @param [in] nodeParams - Pointer to the params to be set. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipGraphExecExternalSemaphoresWaitNodeSetParams( + hipGraphExec_t hGraphExec, hipGraphNode_t hNode, + const hipExternalSemaphoreWaitNodeParams *nodeParams); + +/** + * @brief Creates a memset node and adds it to a graph. + * + * @param [out] phGraphNode - pointer to graph node to create. + * @param [in] hGraph - instance of graph to add the created node to. + * @param [in] dependencies - const pointer to the dependencies on the memset + * execution node. + * @param [in] numDependencies - number of the dependencies. + * @param [in] memsetParams - const pointer to the parameters for the memory + * set. + * @param [in] ctx - cotext related to current device. + * @returns #hipSuccess, #hipErrorInvalidValue + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + */ +hipError_t hipDrvGraphAddMemsetNode(hipGraphNode_t *phGraphNode, + hipGraph_t hGraph, + const hipGraphNode_t *dependencies, + size_t numDependencies, + const HIP_MEMSET_NODE_PARAMS *memsetParams, + hipCtx_t ctx); + +// doxygen end graph API +/** + * @} + */ + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Virtual Virtual Memory Management + * @{ + * This section describes the virtual memory management functions of HIP + *runtime API. + * + * @note Please note, the virtual memory management functions of HIP runtime + *API are implemented on Linux, under development on Windows. + */ + +/** + * @brief Frees an address range reservation made via hipMemAddressReserve + * + * @param [in] devPtr - starting address of the range. + * @param [in] size - size of the range. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemAddressFree(void *devPtr, size_t size); + +/** + * @brief Reserves an address range + * + * @param [out] ptr - starting address of the reserved range. + * @param [in] size - size of the reservation. + * @param [in] alignment - alignment of the address. + * @param [in] addr - requested starting address of the range. + * @param [in] flags - currently unused, must be zero. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemAddressReserve(void **ptr, size_t size, size_t alignment, + void *addr, unsigned long long flags); + +/** + * @brief Creates a memory allocation described by the properties and size + * + * @param [out] handle - value of the returned handle. + * @param [in] size - size of the allocation. + * @param [in] prop - properties of the allocation. + * @param [in] flags - currently unused, must be zero. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemCreate(hipMemGenericAllocationHandle_t *handle, size_t size, + const hipMemAllocationProp *prop, + unsigned long long flags); + +/** + * @brief Exports an allocation to a requested shareable handle type. + * + * @param [out] shareableHandle - value of the returned handle. + * @param [in] handle - handle to share. + * @param [in] handleType - type of the shareable handle. + * @param [in] flags - currently unused, must be zero. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemExportToShareableHandle(void *shareableHandle, + hipMemGenericAllocationHandle_t handle, + hipMemAllocationHandleType handleType, + unsigned long long flags); + +/** + * @brief Get the access flags set for the given location and ptr. + * + * @param [out] flags - flags for this location. + * @param [in] location - target location. + * @param [in] ptr - address to check the access flags. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemGetAccess(unsigned long long *flags, + const hipMemLocation *location, void *ptr); + +/** + * @brief Calculates either the minimal or recommended granularity. + * + * @param [out] granularity - returned granularity. + * @param [in] prop - location properties. + * @param [in] option - determines which granularity to return. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + * + */ +hipError_t +hipMemGetAllocationGranularity(size_t *granularity, + const hipMemAllocationProp *prop, + hipMemAllocationGranularity_flags option); + +/** + * @brief Retrieve the property structure of the given handle. + * + * @param [out] prop - properties of the given handle. + * @param [in] handle - handle to perform the query on. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux under development on Windows. + */ +hipError_t +hipMemGetAllocationPropertiesFromHandle(hipMemAllocationProp *prop, + hipMemGenericAllocationHandle_t handle); + +/** + * @brief Imports an allocation from a requested shareable handle type. + * + * @param [out] handle - returned value. + * @param [in] osHandle - shareable handle representing the memory allocation. + * @param [in] shHandleType - handle type. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t +hipMemImportFromShareableHandle(hipMemGenericAllocationHandle_t *handle, + void *osHandle, + hipMemAllocationHandleType shHandleType); + +/** + * @brief Maps an allocation handle to a reserved virtual address range. + * + * @param [in] ptr - address where the memory will be mapped. + * @param [in] size - size of the mapping. + * @param [in] offset - offset into the memory, currently must be zero. + * @param [in] handle - memory allocation to be mapped. + * @param [in] flags - currently unused, must be zero. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemMap(void *ptr, size_t size, size_t offset, + hipMemGenericAllocationHandle_t handle, + unsigned long long flags); + +/** + * @brief Maps or unmaps subregions of sparse HIP arrays and sparse HIP + * mipmapped arrays. + * + * @param [in] mapInfoList - list of hipArrayMapInfo. + * @param [in] count - number of hipArrayMapInfo in mapInfoList. + * @param [in] stream - stream identifier for the stream to use for map or unmap + * operations. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemMapArrayAsync(hipArrayMapInfo *mapInfoList, unsigned int count, + hipStream_t stream); + +/** + * @brief Release a memory handle representing a memory allocation which was + * previously allocated through hipMemCreate. + * + * @param [in] handle - handle of the memory allocation. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemRelease(hipMemGenericAllocationHandle_t handle); + +/** + * @brief Returns the allocation handle of the backing memory allocation given + * the address. + * + * @param [out] handle - handle representing addr. + * @param [in] addr - address to look up. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemRetainAllocationHandle(hipMemGenericAllocationHandle_t *handle, + void *addr); + +/** + * @brief Set the access flags for each location specified in desc for the given + * virtual address range. + * + * @param [in] ptr - starting address of the virtual address range. + * @param [in] size - size of the range. + * @param [in] desc - array of hipMemAccessDesc. + * @param [in] count - number of hipMemAccessDesc in desc. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemSetAccess(void *ptr, size_t size, const hipMemAccessDesc *desc, + size_t count); + +/** + * @brief Unmap memory allocation of a given address range. + * + * @param [in] ptr - starting address of the range to unmap. + * @param [in] size - size of the virtual address range. + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorNotSupported + * @warning : This API is marked as beta, meaning, while this is feature + * complete, it is still open to changes and may have outstanding issues. + * + * @note This API is implemented on Linux, under development on Windows. + */ +hipError_t hipMemUnmap(void *ptr, size_t size); + +// doxygen end virtual memory management API +/** + * @} + */ +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup GL OpenGL Interop + * @{ + * This section describes the OpenGL and graphics interoperability functions of + *HIP runtime API. + */ + +/** + * @brief Maps a graphics resource for access. + * + * @param [in] count - Number of resources to map. + * @param [in] resources - Pointer of resources to map. + * @param [in] stream - Stream for synchronization. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorUnknown, + * #hipErrorInvalidResourceHandle + * + */ +hipError_t hipGraphicsMapResources(int count, hipGraphicsResource_t *resources, + hipStream_t stream __dparm(0)); +/** + * @brief Get an array through which to access a subresource of a mapped + * graphics resource. + * + * @param [out] array - Pointer of array through which a subresource of resource + * may be accessed. + * @param [in] resource - Mapped resource to access. + * @param [in] arrayIndex - Array index for the subresource to access. + * @param [in] mipLevel - Mipmap level for the subresource to access. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + * @note In this API, the value of arrayIndex higher than zero is currently not + * supported. + * + */ +hipError_t hipGraphicsSubResourceGetMappedArray(hipArray_t *array, + hipGraphicsResource_t resource, + unsigned int arrayIndex, + unsigned int mipLevel); +/** + * @brief Gets device accessible address of a graphics resource. + * + * @param [out] devPtr - Pointer of device through which graphic resource may be + * accessed. + * @param [out] size - Size of the buffer accessible from devPtr. + * @param [in] resource - Mapped resource to access. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipGraphicsResourceGetMappedPointer(void **devPtr, size_t *size, + hipGraphicsResource_t resource); +/** + * @brief Unmaps graphics resources. + * + * @param [in] count - Number of resources to unmap. + * @param [in] resources - Pointer of resources to unmap. + * @param [in] stream - Stream for synchronization. + * + * @returns #hipSuccess, #hipErrorInvalidValue, #hipErrorUnknown, + * #hipErrorContextIsDestroyed + * + */ +hipError_t hipGraphicsUnmapResources(int count, + hipGraphicsResource_t *resources, + hipStream_t stream __dparm(0)); +/** + * @brief Unregisters a graphics resource. + * + * @param [in] resource - Graphics resources to unregister. + * + * @returns #hipSuccess + * + */ +hipError_t hipGraphicsUnregisterResource(hipGraphicsResource_t resource); +// doxygen end GL Interop +/** + * @} + */ + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Surface Surface Object + * @{ + * + * This section describes surface object functions of HIP runtime API. + * + * @note APIs in this section are under development. + * + */ + +/** + * @brief Create a surface object. + * + * @param [out] pSurfObject Pointer of surface object to be created. + * @param [in] pResDesc Pointer of suface object descriptor. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +hipError_t hipCreateSurfaceObject(hipSurfaceObject_t *pSurfObject, + const hipResourceDesc *pResDesc); +/** + * @brief Destroy a surface object. + * + * @param [in] surfaceObject Surface object to be destroyed. + * + * @returns #hipSuccess, #hipErrorInvalidValue + */ +hipError_t hipDestroySurfaceObject(hipSurfaceObject_t surfaceObject); +// end of surface +/** + * @} + */ +#ifdef __cplusplus +} /* extern "c" */ +#endif +#ifdef __cplusplus +#if defined(__clang__) && defined(__HIP__) +template +static hipError_t __host__ inline hipOccupancyMaxPotentialBlockSize( + int *gridSize, int *blockSize, T f, size_t dynSharedMemPerBlk = 0, + int blockSizeLimit = 0) { + return hipOccupancyMaxPotentialBlockSize(gridSize, blockSize, + reinterpret_cast(f), + dynSharedMemPerBlk, blockSizeLimit); +} +template +static hipError_t __host__ inline hipOccupancyMaxPotentialBlockSizeWithFlags( + int *gridSize, int *blockSize, T f, size_t dynSharedMemPerBlk = 0, + int blockSizeLimit = 0, unsigned int flags = 0) { + return hipOccupancyMaxPotentialBlockSize(gridSize, blockSize, + reinterpret_cast(f), + dynSharedMemPerBlk, blockSizeLimit); +} +#endif // defined(__clang__) && defined(__HIP__) + +/** + * @brief Gets the address of a symbol. + * @ingroup Memory + * @param [out] devPtr - Returns device pointer associated with symbol. + * @param [in] symbol - Device symbol. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +template +hipError_t hipGetSymbolAddress(void **devPtr, const T &symbol) { + return ::hipGetSymbolAddress(devPtr, (const void *)&symbol); +} +/** + * @ingroup Memory + * @brief Gets the size of a symbol. + * + * @param [out] size - Returns the size of a symbol. + * @param [in] symbol - Device symbol address. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +template +hipError_t hipGetSymbolSize(size_t *size, const T &symbol) { + return ::hipGetSymbolSize(size, (const void *)&symbol); +} + +/** + * @ingroup Memory + * @brief Copies data to the given symbol on the device. + * + * @returns #hipSuccess, #hipErrorInvalidMemcpyDirection, #hipErrorInvalidValue + * + * @see hipMemcpyToSymbol + */ +template +hipError_t +hipMemcpyToSymbol(const T &symbol, const void *src, size_t sizeBytes, + size_t offset __dparm(0), + hipMemcpyKind kind __dparm(hipMemcpyHostToDevice)) { + return ::hipMemcpyToSymbol((const void *)&symbol, src, sizeBytes, offset, + kind); +} +/** + * @ingroup Memory + * @brief Copies data to the given symbol on the device asynchronously on the + * stream. + * + * @returns #hipSuccess, #hipErrorInvalidMemcpyDirection, #hipErrorInvalidValue + * + * @see hipMemcpyToSymbolAsync + */ +template +hipError_t hipMemcpyToSymbolAsync(const T &symbol, const void *src, + size_t sizeBytes, size_t offset, + hipMemcpyKind kind, + hipStream_t stream __dparm(0)) { + return ::hipMemcpyToSymbolAsync((const void *)&symbol, src, sizeBytes, offset, + kind, stream); +} +/** + * @brief Copies data from the given symbol on the device. + * @ingroup Memory + * @returns #hipSuccess, #hipErrorInvalidMemcpyDirection, #hipErrorInvalidValue + * + * @see hipMemcpyFromSymbol + */ +template +hipError_t +hipMemcpyFromSymbol(void *dst, const T &symbol, size_t sizeBytes, + size_t offset __dparm(0), + hipMemcpyKind kind __dparm(hipMemcpyDeviceToHost)) { + return ::hipMemcpyFromSymbol(dst, (const void *)&symbol, sizeBytes, offset, + kind); +} +/** + * @brief Copies data from the given symbol on the device asynchronously on the + * stream. + * @ingroup Memory + * @returns #hipSuccess, #hipErrorInvalidMemcpyDirection, #hipErrorInvalidValue + * + * @see hipMemcpyFromSymbolAsync + */ +template +hipError_t hipMemcpyFromSymbolAsync(void *dst, const T &symbol, + size_t sizeBytes, size_t offset, + hipMemcpyKind kind, + hipStream_t stream __dparm(0)) { + return ::hipMemcpyFromSymbolAsync(dst, (const void *)&symbol, sizeBytes, + offset, kind, stream); +} + +/** + * @brief Returns occupancy for a kernel function. + * @ingroup Occupancy + * @param [out] numBlocks - Pointer of occupancy in number of blocks. + * @param [in] f - The kernel function to launch on the device. + * @param [in] blockSize - The block size as kernel launched. + * @param [in] dynSharedMemPerBlk - Dynamic shared memory in bytes per block. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +template +inline hipError_t +hipOccupancyMaxActiveBlocksPerMultiprocessor(int *numBlocks, T f, int blockSize, + size_t dynSharedMemPerBlk) { + return hipOccupancyMaxActiveBlocksPerMultiprocessor( + numBlocks, reinterpret_cast(f), blockSize, + dynSharedMemPerBlk); +} +/** + * @brief Returns occupancy for a device function with the specified flags. + * + * @ingroup Occupancy + * @param [out] numBlocks - Pointer of occupancy in number of blocks. + * @param [in] f - The kernel function to launch on the device. + * @param [in] blockSize - The block size as kernel launched. + * @param [in] dynSharedMemPerBlk - Dynamic shared memory in bytes per block. + * @param [in] flags - Flag to handle the behavior for the occupancy calculator. + * + * @returns #hipSuccess, #hipErrorInvalidValue + * + */ +template +inline hipError_t hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + int *numBlocks, T f, int blockSize, size_t dynSharedMemPerBlk, + unsigned int flags) { + return hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + numBlocks, reinterpret_cast(f), blockSize, + dynSharedMemPerBlk, flags); +} +/** + * @brief Returns grid and block size that achieves maximum potential occupancy + * for a device function + * + * @ingroup Occupancy + * Returns in \p *min_grid_size and \p *block_size a suggested grid / + * block size pair that achieves the best potential occupancy + * (i.e. the maximum number of active warps on the current device with the + * smallest number of blocks for a particular function). + * + * @param [out] min_grid_size minimum grid size needed to achieve the best + * potential occupancy + * @param [out] block_size block size required for the best potential + * occupancy + * @param [in] func device function symbol + * @param [in] block_size_to_dynamic_smem_size - a unary function/functor that + * takes block size, and returns the size, in bytes, of dynamic shared memory + * needed for a block + * @param [in] block_size_limit the maximum block size \p func is designed to + * work with. 0 means no limit. + * @param [in] flags reserved + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidDeviceFunction, + * #hipErrorInvalidValue, #hipErrorUnknown + */ +template +static hipError_t + __host__ inline hipOccupancyMaxPotentialBlockSizeVariableSMemWithFlags( + int *min_grid_size, int *block_size, T func, + UnaryFunction block_size_to_dynamic_smem_size, int block_size_limit = 0, + unsigned int flags = 0) { + if (min_grid_size == nullptr || block_size == nullptr || + reinterpret_cast(func) == nullptr) { + return hipErrorInvalidValue; + } + + int dev; + hipError_t status; + if ((status = hipGetDevice(&dev)) != hipSuccess) { + return status; + } + + int max_threads_per_cu; + if ((status = hipDeviceGetAttribute( + &max_threads_per_cu, hipDeviceAttributeMaxThreadsPerMultiProcessor, + dev)) != hipSuccess) { + return status; + } + + int warp_size; + if ((status = hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, + dev)) != hipSuccess) { + return status; + } + + int max_cu_count; + if ((status = hipDeviceGetAttribute(&max_cu_count, + hipDeviceAttributeMultiprocessorCount, + dev)) != hipSuccess) { + return status; + } + + struct hipFuncAttributes attr; + if ((status = hipFuncGetAttributes( + &attr, reinterpret_cast(func))) != hipSuccess) { + return status; + } + + // Initial limits for the execution + const int func_max_threads_per_block = attr.maxThreadsPerBlock; + if (block_size_limit == 0) { + block_size_limit = func_max_threads_per_block; + } + + if (func_max_threads_per_block < block_size_limit) { + block_size_limit = func_max_threads_per_block; + } + + const int block_size_limit_aligned = + ((block_size_limit + (warp_size - 1)) / warp_size) * warp_size; + + // For maximum search + int max_threads = 0; + int max_block_size{}; + int max_num_blocks{}; + for (int block_size_check_aligned = block_size_limit_aligned; + block_size_check_aligned > 0; block_size_check_aligned -= warp_size) { + // Make sure the logic uses the requested limit and not aligned + int block_size_check = (block_size_limit < block_size_check_aligned) + ? block_size_limit + : block_size_check_aligned; + + size_t dyn_smem_size = block_size_to_dynamic_smem_size(block_size_check); + int optimal_blocks; + if ((status = hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + &optimal_blocks, func, block_size_check, dyn_smem_size, flags)) != + hipSuccess) { + return status; + } + + int total_threads = block_size_check * optimal_blocks; + if (total_threads > max_threads) { + max_block_size = block_size_check; + max_num_blocks = optimal_blocks; + max_threads = total_threads; + } + + // Break if the logic reached possible maximum + if (max_threads_per_cu == max_threads) { + break; + } + } + + // Grid size is the number of blocks per CU * CU count + *min_grid_size = max_num_blocks * max_cu_count; + *block_size = max_block_size; + + return status; +} + +/** + * @brief Returns grid and block size that achieves maximum potential occupancy + * for a device function + * + * @ingroup Occupancy + * Returns in \p *min_grid_size and \p *block_size a suggested grid / + * block size pair that achieves the best potential occupancy + * (i.e. the maximum number of active warps on the current device with the + * smallest number of blocks for a particular function). + * + * @param [out] min_grid_size minimum grid size needed to achieve the best + * potential occupancy + * @param [out] block_size block size required for the best potential + * occupancy + * @param [in] func device function symbol + * @param [in] block_size_to_dynamic_smem_size - a unary function/functor that + * takes block size, and returns the size, in bytes, of dynamic shared memory + * needed for a block + * @param [in] block_size_limit the maximum block size \p func is designed to + * work with. 0 means no limit. + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidDeviceFunction, + * #hipErrorInvalidValue, #hipErrorUnknown + */ +template +static hipError_t __host__ inline hipOccupancyMaxPotentialBlockSizeVariableSMem( + int *min_grid_size, int *block_size, T func, + UnaryFunction block_size_to_dynamic_smem_size, int block_size_limit = 0) { + return hipOccupancyMaxPotentialBlockSizeVariableSMemWithFlags( + min_grid_size, block_size, func, block_size_to_dynamic_smem_size, + block_size_limit); +} +/** + * @brief Returns grid and block size that achieves maximum potential occupancy + * for a device function + * + * @ingroup Occupancy + * + * Returns in \p *min_grid_size and \p *block_size a suggested grid / + * block size pair that achieves the best potential occupancy + * (i.e. the maximum number of active warps on the current device with the + * smallest number of blocks for a particular function). + * + * @return #hipSuccess, #hipErrorInvalidDevice, #hipErrorInvalidValue + * + * @see hipOccupancyMaxPotentialBlockSize + */ +template +inline hipError_t hipOccupancyMaxPotentialBlockSize(int *gridSize, + int *blockSize, F kernel, + size_t dynSharedMemPerBlk, + uint32_t blockSizeLimit) { + return hipOccupancyMaxPotentialBlockSize(gridSize, blockSize, + (hipFunction_t)kernel, + dynSharedMemPerBlk, blockSizeLimit); +} +/** + * @brief Launches a device function + * + * @ingroup Execution + * + * @param [in] f device function symbol + * @param [in] gridDim grid dimentions + * @param [in] blockDim block dimentions + * @param [in] kernelParams kernel parameters + * @param [in] sharedMemBytes shared memory in bytes + * @param [in] stream stream on which kernel launched + * + * @return #hipSuccess, #hipErrorLaunchFailure, #hipErrorInvalidValue, + * #hipErrorInvalidResourceHandle + * + */ +template +inline hipError_t hipLaunchCooperativeKernel(T f, dim3 gridDim, dim3 blockDim, + void **kernelParams, + unsigned int sharedMemBytes, + hipStream_t stream) { + return hipLaunchCooperativeKernel(reinterpret_cast(f), gridDim, + blockDim, kernelParams, sharedMemBytes, + stream); +} +/** + * @brief Launches device function on multiple devices where thread blocks can + * cooperate and synchronize on execution. + * + * @ingroup Execution + * + * @param [in] launchParamsList list of kernel launch parameters, one per + * device + * @param [in] numDevices size of launchParamsList array + * @param [in] flags flag to handle launch behavior + * + * @return #hipSuccess, #hipErrorLaunchFailure, #hipErrorInvalidValue, + * #hipErrorInvalidResourceHandle + * + */ +template +inline hipError_t +hipLaunchCooperativeKernelMultiDevice(hipLaunchParams *launchParamsList, + unsigned int numDevices, + unsigned int flags = 0) { + return hipLaunchCooperativeKernelMultiDevice(launchParamsList, numDevices, + flags); +} +/** + * + * @ingroup Module + * + * @brief Launches kernels on multiple devices and guarantees all specified + * kernels are dispatched on respective streams before enqueuing any other work + * on the specified streams from any other threads + * + * + * @param [in] launchParamsList List of launch parameters, one per + * device. + * @param [in] numDevices Size of the launchParamsList array. + * @param [in] flags Flags to control launch behavior. + * + * @returns #hipSuccess, #hipErrorInvalidValue + */ +template +inline hipError_t +hipExtLaunchMultiKernelMultiDevice(hipLaunchParams *launchParamsList, + unsigned int numDevices, + unsigned int flags = 0) { + return hipExtLaunchMultiKernelMultiDevice(launchParamsList, numDevices, + flags); +} +/** + * @brief Binds a memory area to a texture. + * + * @ingroup TextureD + * + * @param [in] offset Offset in bytes. + * @param [in] tex Texture to bind. + * @param [in] devPtr Pointer of memory on the device. + * @param [in] size Size of memory in bites. + * + * @warning This API is deprecated. + * + */ +template +DEPRECATED(DEPRECATED_MSG) +static inline hipError_t + hipBindTexture(size_t *offset, const struct texture &tex, + const void *devPtr, size_t size = UINT_MAX) { + return hipBindTexture(offset, &tex, devPtr, &tex.channelDesc, size); +} +/** + * @brief Binds a memory area to a texture. + * + * @ingroup TextureD + * + * @param [in] offset Offset in bytes. + * @param [in] tex Texture to bind. + * @param [in] devPtr Pointer of memory on the device. + * @param [in] desc Texture channel format. + * @param [in] size Size of memory in bites. + * + * @warning This API is deprecated. + * + */ +template +DEPRECATED(DEPRECATED_MSG) +static inline hipError_t + hipBindTexture(size_t *offset, const struct texture &tex, + const void *devPtr, const struct hipChannelFormatDesc &desc, + size_t size = UINT_MAX) { + return hipBindTexture(offset, &tex, devPtr, &desc, size); +} +/** + * @brief Binds a 2D memory area to a texture. + * + * @ingroup TextureD + * + * @param [in] offset Offset in bytes. + * @param [in] tex Texture to bind. + * @param [in] devPtr Pointer of 2D memory area on the device. + * @param [in] width Width in texel units. + * @param [in] height Height in texel units. + * @param [in] pitch Pitch in bytes. + * + * @warning This API is deprecated. + * + */ +template +DEPRECATED(DEPRECATED_MSG) +static inline hipError_t + hipBindTexture2D(size_t *offset, + const struct texture &tex, + const void *devPtr, size_t width, size_t height, + size_t pitch) { + return hipBindTexture2D(offset, &tex, devPtr, &tex.channelDesc, width, height, + pitch); +} +/** + * @brief Binds a 2D memory area to a texture. + * + * @ingroup TextureD + * + * @param [in] offset Offset in bytes. + * @param [in] tex Texture to bind. + * @param [in] devPtr Pointer of 2D memory area on the device. + * @param [in] desc Texture channel format. + * @param [in] width Width in texel units. + * @param [in] height Height in texel units. + * @param [in] pitch Pitch in bytes. + * + * @warning This API is deprecated. + * + */ +template +DEPRECATED(DEPRECATED_MSG) +static inline hipError_t + hipBindTexture2D(size_t *offset, + const struct texture &tex, + const void *devPtr, + const struct hipChannelFormatDesc &desc, size_t width, + size_t height, size_t pitch) { + return hipBindTexture2D(offset, &tex, devPtr, &desc, width, height, pitch); +} +/** + * @brief Binds an array to a texture. + * + * @ingroup TextureD + * + * @param [in] tex Texture to bind. + * @param [in] array Array of memory on the device. + * + * @warning This API is deprecated. + * + */ +template +DEPRECATED(DEPRECATED_MSG) +static inline hipError_t + hipBindTextureToArray(const struct texture &tex, + hipArray_const_t array) { + struct hipChannelFormatDesc desc; + hipError_t err = hipGetChannelDesc(&desc, array); + return (err == hipSuccess) ? hipBindTextureToArray(&tex, array, &desc) : err; +} +/** + * @brief Binds an array to a texture. + * + * @ingroup TextureD + * + * @param [in] tex Texture to bind. + * @param [in] array Array of memory on the device. + * @param [in] desc Texture channel format. + * + * @warning This API is deprecated. + * + */ +template +DEPRECATED(DEPRECATED_MSG) +static inline hipError_t + hipBindTextureToArray(const struct texture &tex, + hipArray_const_t array, + const struct hipChannelFormatDesc &desc) { + return hipBindTextureToArray(&tex, array, &desc); +} +/** + * @brief Binds a mipmapped array to a texture. + * + * @ingroup TextureD + * + * @param [in] tex Texture to bind. + * @param [in] mipmappedArray Mipmapped Array of memory on the device. + * + * @warning This API is deprecated. + * + */ +template +DEPRECATED(DEPRECATED_MSG) +static inline hipError_t + hipBindTextureToMipmappedArray(const struct texture &tex, + hipMipmappedArray_const_t mipmappedArray) { + struct hipChannelFormatDesc desc; + hipArray_t levelArray; + hipError_t err = hipGetMipmappedArrayLevel(&levelArray, mipmappedArray, 0); + if (err != hipSuccess) { + return err; + } + err = hipGetChannelDesc(&desc, levelArray); + return (err == hipSuccess) + ? hipBindTextureToMipmappedArray(&tex, mipmappedArray, &desc) + : err; +} +/** + * @brief Binds a mipmapped array to a texture. + * + * @ingroup TextureD + * + * @param [in] tex Texture to bind. + * @param [in] mipmappedArray Mipmapped Array of memory on the device. + * @param [in] desc Texture channel format. + * + * @warning This API is deprecated. + * + */ +template +DEPRECATED(DEPRECATED_MSG) +static inline hipError_t + hipBindTextureToMipmappedArray(const struct texture &tex, + hipMipmappedArray_const_t mipmappedArray, + const struct hipChannelFormatDesc &desc) { + return hipBindTextureToMipmappedArray(&tex, mipmappedArray, &desc); +} +/** + * @brief Unbinds a texture. + * + * @ingroup TextureD + * + * @param [in] tex Texture to unbind. + * + * @warning This API is deprecated. + * + */ +template +DEPRECATED(DEPRECATED_MSG) +static inline hipError_t + hipUnbindTexture(const struct texture &tex) { + return hipUnbindTexture(&tex); +} +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @ingroup StreamO + * @{ + * + * This section describes wrappers for stream Ordered allocation from memory + *pool functions of HIP runtime API. + * + * @note APIs in this section are implemented on Linux, under development on + *Windows. + * + */ + +/** + * @brief C++ wrappers for allocations from a memory pool + * + * This is an alternate C++ calls for @p hipMallocFromPoolAsync made available + * through function overloading. + * + * @see hipMallocFromPoolAsync + * + * @note This API is implemented on Linux, under development on Windows. + */ +static inline hipError_t hipMallocAsync(void **dev_ptr, size_t size, + hipMemPool_t mem_pool, + hipStream_t stream) { + return hipMallocFromPoolAsync(dev_ptr, size, mem_pool, stream); +} +/** + * @brief C++ wrappers for allocations from a memory pool on the stream + * + * This is an alternate C++ calls for @p hipMallocFromPoolAsync made available + * through function overloading. + * + * @see hipMallocFromPoolAsync + * + * @note This API is implemented on Linux, under development on Windows. + */ +template +static inline hipError_t hipMallocAsync(T **dev_ptr, size_t size, + hipMemPool_t mem_pool, + hipStream_t stream) { + return hipMallocFromPoolAsync(reinterpret_cast(dev_ptr), size, + mem_pool, stream); +} +/** + * @brief C++ wrappers for allocations from a memory pool + * + * This is an alternate C++ calls for @p hipMallocFromPoolAsync made available + * through function overloading. + * + * @see hipMallocFromPoolAsync + * + * @note This API is implemented on Linux, under development on Windows. + */ +template +static inline hipError_t hipMallocAsync(T **dev_ptr, size_t size, + hipStream_t stream) { + return hipMallocAsync(reinterpret_cast(dev_ptr), size, stream); +} +/** + * @brief C++ wrappers for allocations from a memory pool + * + * This is an alternate C++ calls for @p hipMallocFromPoolAsync made available + * through function overloading. + * + * @see hipMallocFromPoolAsync + * + * @note This API is implemented on Linux, under development on Windows. + */ +template +static inline hipError_t hipMallocFromPoolAsync(T **dev_ptr, size_t size, + hipMemPool_t mem_pool, + hipStream_t stream) { + return hipMallocFromPoolAsync(reinterpret_cast(dev_ptr), size, + mem_pool, stream); +} +/** + * @} + */ + +#endif // __cplusplus + +#ifdef __GNUC__ +#pragma GCC visibility pop +#endif + +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include "hip/nvidia_detail/nvidia_hip_runtime_api.h" +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +/** + * @brief: C++ wrapper for hipMalloc + * @ingroup Memory + * Perform automatic type conversion to eliminate need for excessive typecasting + * (ie void**) + * + * __HIP_DISABLE_CPP_FUNCTIONS__ macro can be defined to suppress these + * wrappers. It is useful for applications which need to obtain decltypes of + * HIP runtime APIs. + * + * @see hipMalloc + */ +#if defined(__cplusplus) && !defined(__HIP_DISABLE_CPP_FUNCTIONS__) +template static inline hipError_t hipMalloc(T **devPtr, size_t size) { + return hipMalloc((void **)devPtr, size); +} +/** + * @brief: C++ wrapper for hipHostMalloc + * @ingroup Memory + * Provide an override to automatically typecast the pointer type from void**, + * and also provide a default for the flags. + * + * __HIP_DISABLE_CPP_FUNCTIONS__ macro can be defined to suppress these + * wrappers. It is useful for applications which need to obtain decltypes of + * HIP runtime APIs. + * + * @see hipHostMalloc + */ +template +static inline hipError_t +hipHostMalloc(T **ptr, size_t size, unsigned int flags = hipHostMallocDefault) { + return hipHostMalloc((void **)ptr, size, flags); +} +/** + * @brief: C++ wrapper for hipMallocManaged + * + * @ingroup MemoryM + * Provide an override to automatically typecast the pointer type from void**, + * and also provide a default for the flags. + * + * __HIP_DISABLE_CPP_FUNCTIONS__ macro can be defined to suppress these + * wrappers. It is useful for applications which need to obtain decltypes of + * HIP runtime APIs. + * + * @see hipMallocManaged + * + */ +template +static inline hipError_t +hipMallocManaged(T **devPtr, size_t size, + unsigned int flags = hipMemAttachGlobal) { + return hipMallocManaged((void **)devPtr, size, flags); +} + +#endif +#endif +// doxygen end HIP API +/** + * @} + */ +#include + +#if USE_PROF_API +#include +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_texture_types.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_texture_types.h new file mode 100644 index 000000000..6a8044086 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_texture_types.h @@ -0,0 +1,28 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_HIP_TEXTURE_TYPES_H +#define HIP_INCLUDE_HIP_HIP_TEXTURE_TYPES_H + +#include + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_vector_types.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_vector_types.h new file mode 100644 index 000000000..b11d04064 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_vector_types.h @@ -0,0 +1,40 @@ +/* +Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +//! hip_vector_types.h : Defines the HIP vector types. + +#ifndef HIP_INCLUDE_HIP_HIP_VECTOR_TYPES_H +#define HIP_INCLUDE_HIP_HIP_VECTOR_TYPES_H + +#include + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#if __cplusplus +#include +#endif +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_version.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_version.h new file mode 100644 index 000000000..22c488e77 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hip_version.h @@ -0,0 +1,18 @@ +// Auto-generated by cmake + +#ifndef HIP_VERSION_H +#define HIP_VERSION_H + +#define HIP_VERSION_MAJOR 6 +#define HIP_VERSION_MINOR 2 +#define HIP_VERSION_PATCH 41134 +#define HIP_VERSION_GITHASH "65d174c3e" +#define HIP_VERSION_BUILD_ID 0 +#define HIP_VERSION_BUILD_NAME "" +#define HIP_VERSION \ + (HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR * 100000 + \ + HIP_VERSION_PATCH) + +#define __HIP_HAS_GET_PCH 1 + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hiprtc.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hiprtc.h new file mode 100644 index 000000000..a5dc0ce28 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/hiprtc.h @@ -0,0 +1,454 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#pragma once + +#include + +#if !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include +#elif defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#include + +#if !defined(_WIN32) +#pragma GCC visibility push(default) +#endif + +/** + * + * @addtogroup GlobalDefs + * @{ + * + */ +/** + * hiprtc error code + */ +typedef enum hiprtcResult { + HIPRTC_SUCCESS = 0, ///< Success + HIPRTC_ERROR_OUT_OF_MEMORY = 1, ///< Out of memory + HIPRTC_ERROR_PROGRAM_CREATION_FAILURE = 2, ///< Failed to create program + HIPRTC_ERROR_INVALID_INPUT = 3, ///< Invalid input + HIPRTC_ERROR_INVALID_PROGRAM = 4, ///< Invalid program + HIPRTC_ERROR_INVALID_OPTION = 5, ///< Invalid option + HIPRTC_ERROR_COMPILATION = 6, ///< Compilation error + HIPRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7, ///< Failed in builtin operation + HIPRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = + 8, ///< No name expression after compilation + HIPRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = + 9, ///< No lowered names before compilation + HIPRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10, ///< Invalid name expression + HIPRTC_ERROR_INTERNAL_ERROR = 11, ///< Internal error + HIPRTC_ERROR_LINKING = 100 ///< Error in linking +} hiprtcResult; + +/** + * hiprtc JIT option + */ + +typedef enum hiprtcJIT_option { + HIPRTC_JIT_MAX_REGISTERS = 0, ///< CUDA Only Maximum registers may be used in + ///< a thread, passed to compiler + HIPRTC_JIT_THREADS_PER_BLOCK, ///< CUDA Only Number of thread per block + HIPRTC_JIT_WALL_TIME, ///< CUDA Only Value for total wall clock time + HIPRTC_JIT_INFO_LOG_BUFFER, ///< CUDA Only Pointer to the buffer with logged + ///< information + HIPRTC_JIT_INFO_LOG_BUFFER_SIZE_BYTES, ///< CUDA Only Size of the buffer in + ///< bytes for logged info + HIPRTC_JIT_ERROR_LOG_BUFFER, ///< CUDA Only Pointer to the buffer with logged + ///< error(s) + HIPRTC_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, ///< CUDA Only Size of the buffer in + ///< bytes for logged error(s) + HIPRTC_JIT_OPTIMIZATION_LEVEL, ///< Value of optimization level for generated + ///< codes, acceptable options -O0, -O1, -O2, + ///< -O3 + HIPRTC_JIT_TARGET_FROM_HIPCONTEXT, ///< CUDA Only The target context, which is + ///< the default + HIPRTC_JIT_TARGET, ///< CUDA Only JIT target + HIPRTC_JIT_FALLBACK_STRATEGY, ///< CUDA Only Fallback strategy + HIPRTC_JIT_GENERATE_DEBUG_INFO, ///< CUDA Only Generate debug information + HIPRTC_JIT_LOG_VERBOSE, ///< CUDA Only Generate log verbose + HIPRTC_JIT_GENERATE_LINE_INFO, ///< CUDA Only Generate line number information + HIPRTC_JIT_CACHE_MODE, ///< CUDA Only Set cache mode + HIPRTC_JIT_NEW_SM3X_OPT, ///< @deprecated CUDA Only New SM3X option. + HIPRTC_JIT_FAST_COMPILE, ///< CUDA Only Set fast compile + HIPRTC_JIT_GLOBAL_SYMBOL_NAMES, ///< CUDA Only Array of device symbol names to + ///< be relocated to the host + HIPRTC_JIT_GLOBAL_SYMBOL_ADDRESS, ///< CUDA Only Array of host addresses to be + ///< relocated to the device + HIPRTC_JIT_GLOBAL_SYMBOL_COUNT, ///< CUDA Only Number of symbol count. + HIPRTC_JIT_LTO, ///< @deprecated CUDA Only Enable link-time optimization for + ///< device code + HIPRTC_JIT_FTZ, ///< @deprecated CUDA Only Set single-precision denormals. + HIPRTC_JIT_PREC_DIV, ///< @deprecated CUDA Only Set single-precision + ///< floating-point division and reciprocals + HIPRTC_JIT_PREC_SQRT, ///< @deprecated CUDA Only Set single-precision + ///< floating-point square root + HIPRTC_JIT_FMA, ///< @deprecated CUDA Only Enable floating-point multiplies + ///< and adds/subtracts operations + HIPRTC_JIT_NUM_OPTIONS, ///< Number of options + HIPRTC_JIT_IR_TO_ISA_OPT_EXT = + 10000, ///< Linker options to be passed on to compiler + /// @note Only supported for the AMD platform. + HIPRTC_JIT_IR_TO_ISA_OPT_COUNT_EXT, ///< Count of linker options to be passed + ///< on to compiler @note Only supported + ///< for the AMD platform +} hiprtcJIT_option; + +/** + * hiprtc JIT input type + */ +typedef enum hiprtcJITInputType { + HIPRTC_JIT_INPUT_CUBIN = 0, ///< Input cubin + HIPRTC_JIT_INPUT_PTX, ///< Input PTX + HIPRTC_JIT_INPUT_FATBINARY, ///< Input fat binary + HIPRTC_JIT_INPUT_OBJECT, ///< Input object + HIPRTC_JIT_INPUT_LIBRARY, ///< Input library + HIPRTC_JIT_INPUT_NVVM, ///< Input NVVM + HIPRTC_JIT_NUM_LEGACY_INPUT_TYPES, ///< Number of legacy input type + HIPRTC_JIT_INPUT_LLVM_BITCODE = 100, ///< LLVM bitcode or IR assembly + HIPRTC_JIT_INPUT_LLVM_BUNDLED_BITCODE = 101, ///< LLVM bundled bitcode + HIPRTC_JIT_INPUT_LLVM_ARCHIVES_OF_BUNDLED_BITCODE = + 102, ///< LLVM archives of boundled bitcode + HIPRTC_JIT_NUM_INPUT_TYPES = (HIPRTC_JIT_NUM_LEGACY_INPUT_TYPES + 3) +} hiprtcJITInputType; +/** + * @} + */ + +/** + * hiprtc link state + * + */ +typedef struct ihiprtcLinkState *hiprtcLinkState; +/** + * @ingroup Runtime + * + * @brief Returns text string message to explain the error which occurred + * + * @param [in] result code to convert to string. + * @returns const char pointer to the NULL-terminated error string + * + * @warning In HIP, this function returns the name of the error, + * if the hiprtc result is defined, it will return "Invalid HIPRTC error code" + * + * @see hiprtcResult + */ +const char *hiprtcGetErrorString(hiprtcResult result); + +/** + * @ingroup Runtime + * @brief Sets the parameters as major and minor version. + * + * @param [out] major HIP Runtime Compilation major version. + * @param [out] minor HIP Runtime Compilation minor version. + * + * @returns #HIPRTC_ERROR_INVALID_INPUT, #HIPRTC_SUCCESS + * + */ +hiprtcResult hiprtcVersion(int *major, int *minor); + +/** + * hiprtc program + * + */ +typedef struct _hiprtcProgram *hiprtcProgram; + +/** + * @ingroup Runtime + * @brief Adds the given name exprssion to the runtime compilation program. + * + * @param [in] prog runtime compilation program instance. + * @param [in] name_expression const char pointer to the name expression. + * @returns #HIPRTC_SUCCESS + * + * If const char pointer is NULL, it will return #HIPRTC_ERROR_INVALID_INPUT. + * + * @see hiprtcResult + */ +hiprtcResult hiprtcAddNameExpression(hiprtcProgram prog, + const char *name_expression); + +/** + * @ingroup Runtime + * @brief Compiles the given runtime compilation program. + * + * @param [in] prog runtime compilation program instance. + * @param [in] numOptions number of compiler options. + * @param [in] options compiler options as const array of strins. + * @returns #HIPRTC_SUCCESS + * + * If the compiler failed to build the runtime compilation program, + * it will return #HIPRTC_ERROR_COMPILATION. + * + * @see hiprtcResult + */ +hiprtcResult hiprtcCompileProgram(hiprtcProgram prog, int numOptions, + const char **options); + +/** + * @ingroup Runtime + * @brief Creates an instance of hiprtcProgram with the given input parameters, + * and sets the output hiprtcProgram prog with it. + * + * @param [in, out] prog runtime compilation program instance. + * @param [in] src const char pointer to the program source. + * @param [in] name const char pointer to the program name. + * @param [in] numHeaders number of headers. + * @param [in] headers array of strings pointing to headers. + * @param [in] includeNames array of strings pointing to names included in + * program source. + * @returns #HIPRTC_SUCCESS + * + * Any invalide input parameter, it will return #HIPRTC_ERROR_INVALID_INPUT + * or #HIPRTC_ERROR_INVALID_PROGRAM. + * + * If failed to create the program, it will return + * #HIPRTC_ERROR_PROGRAM_CREATION_FAILURE. + * + * @see hiprtcResult + */ +hiprtcResult hiprtcCreateProgram(hiprtcProgram *prog, const char *src, + const char *name, int numHeaders, + const char **headers, + const char **includeNames); + +/** + * @brief Destroys an instance of given hiprtcProgram. + * @ingroup Runtime + * @param [in] prog runtime compilation program instance. + * @returns #HIPRTC_SUCCESS + * + * If prog is NULL, it will return #HIPRTC_ERROR_INVALID_INPUT. + * + * @see hiprtcResult + */ +hiprtcResult hiprtcDestroyProgram(hiprtcProgram *prog); + +/** + * @brief Gets the lowered (mangled) name from an instance of hiprtcProgram with + * the given input parameters, and sets the output lowered_name with it. + * @ingroup Runtime + * @param [in] prog runtime compilation program instance. + * @param [in] name_expression const char pointer to the name expression. + * @param [in, out] lowered_name const char array to the lowered (mangled) + * name. + * @returns #HIPRTC_SUCCESS + * + * If any invalide nullptr input parameters, it will return + * #HIPRTC_ERROR_INVALID_INPUT + * + * If name_expression is not found, it will return + * #HIPRTC_ERROR_NAME_EXPRESSION_NOT_VALID + * + * If failed to get lowered_name from the program, it will return + * #HIPRTC_ERROR_COMPILATION. + * + * @see hiprtcResult + */ +hiprtcResult hiprtcGetLoweredName(hiprtcProgram prog, + const char *name_expression, + const char **lowered_name); + +/** + * @brief Gets the log generated by the runtime compilation program instance. + * @ingroup Runtime + * @param [in] prog runtime compilation program instance. + * @param [out] log memory pointer to the generated log. + * @returns #HIPRTC_SUCCESS + * + * @see hiprtcResult + */ +hiprtcResult hiprtcGetProgramLog(hiprtcProgram prog, char *log); + +/** + * @brief Gets the size of log generated by the runtime compilation program + * instance. + * + * @param [in] prog runtime compilation program instance. + * @param [out] logSizeRet size of generated log. + * @returns #HIPRTC_SUCCESS + * + * @see hiprtcResult + */ +hiprtcResult hiprtcGetProgramLogSize(hiprtcProgram prog, size_t *logSizeRet); + +/** + * @brief Gets the pointer of compilation binary by the runtime compilation + * program instance. + * @ingroup Runtime + * @param [in] prog runtime compilation program instance. + * @param [out] code char pointer to binary. + * @returns #HIPRTC_SUCCESS + * + * @see hiprtcResult + */ +hiprtcResult hiprtcGetCode(hiprtcProgram prog, char *code); + +/** + * @brief Gets the size of compilation binary by the runtime compilation program + * instance. + * @ingroup Runtime + * @param [in] prog runtime compilation program instance. + * @param [out] codeSizeRet the size of binary. + * @returns #HIPRTC_SUCCESS + * + * @see hiprtcResult + */ +hiprtcResult hiprtcGetCodeSize(hiprtcProgram prog, size_t *codeSizeRet); + +/** + * @brief Gets the pointer of compiled bitcode by the runtime compilation + * program instance. + * + * @param [in] prog runtime compilation program instance. + * @param [out] bitcode char pointer to bitcode. + * @return HIPRTC_SUCCESS + * + * @see hiprtcResult + */ +hiprtcResult hiprtcGetBitcode(hiprtcProgram prog, char *bitcode); + +/** + * @brief Gets the size of compiled bitcode by the runtime compilation program + * instance. + * @ingroup Runtime + * + * @param [in] prog runtime compilation program instance. + * @param [out] bitcode_size the size of bitcode. + * @returns #HIPRTC_SUCCESS + * + * @see hiprtcResult + */ +hiprtcResult hiprtcGetBitcodeSize(hiprtcProgram prog, size_t *bitcode_size); + +/** + * @brief Creates the link instance via hiprtc APIs. + * @ingroup Runtime + * @param [in] num_options Number of options + * @param [in] option_ptr Array of options + * @param [in] option_vals_pptr Array of option values cast to void* + * @param [out] hip_link_state_ptr hiprtc link state created upon success + * + * @returns #HIPRTC_SUCCESS, #HIPRTC_ERROR_INVALID_INPUT, + * #HIPRTC_ERROR_INVALID_OPTION + * + * @see hiprtcResult + */ +hiprtcResult hiprtcLinkCreate(unsigned int num_options, + hiprtcJIT_option *option_ptr, + void **option_vals_pptr, + hiprtcLinkState *hip_link_state_ptr); + +/** + * @brief Adds a file with bit code to be linked with options + * @ingroup Runtime + * @param [in] hip_link_state hiprtc link state + * @param [in] input_type Type of the input data or bitcode + * @param [in] file_path Path to the input file where bitcode is present + * @param [in] num_options Size of the options + * @param [in] options_ptr Array of options applied to this input + * @param [in] option_values Array of option values cast to void* + * + * @returns #HIPRTC_SUCCESS + * + * If input values are invalid, it will + * @return #HIPRTC_ERROR_INVALID_INPUT + * + * @see hiprtcResult + */ + +hiprtcResult hiprtcLinkAddFile(hiprtcLinkState hip_link_state, + hiprtcJITInputType input_type, + const char *file_path, unsigned int num_options, + hiprtcJIT_option *options_ptr, + void **option_values); + +/** + * @brief Completes the linking of the given program. + * @ingroup Runtime + * @param [in] hip_link_state hiprtc link state + * @param [in] input_type Type of the input data or bitcode + * @param [in] image Input data which is null terminated + * @param [in] image_size Size of the input data + * @param [in] name Optional name for this input + * @param [in] num_options Size of the options + * @param [in] options_ptr Array of options applied to this input + * @param [in] option_values Array of option values cast to void* + * + * @returns #HIPRTC_SUCCESS, #HIPRTC_ERROR_INVALID_INPUT + * + * If adding the file fails, it will + * @return #HIPRTC_ERROR_PROGRAM_CREATION_FAILURE + * + * @see hiprtcResult + */ + +hiprtcResult hiprtcLinkAddData(hiprtcLinkState hip_link_state, + hiprtcJITInputType input_type, void *image, + size_t image_size, const char *name, + unsigned int num_options, + hiprtcJIT_option *options_ptr, + void **option_values); + +/** + * @brief Completes the linking of the given program. + * @ingroup Runtime + * @param [in] hip_link_state hiprtc link state + * @param [out] bin_out Upon success, points to the output binary + * @param [out] size_out Size of the binary is stored (optional) + * + * @returns #HIPRTC_SUCCESS + * + * If adding the data fails, it will + * @return #HIPRTC_ERROR_LINKING + * + * @see hiprtcResult + */ +hiprtcResult hiprtcLinkComplete(hiprtcLinkState hip_link_state, void **bin_out, + size_t *size_out); + +/** + * @brief Deletes the link instance via hiprtc APIs. + * @ingroup Runtime + * @param [in] hip_link_state link state instance + * + * @returns #HIPRTC_SUCCESS + * + * @see hiprtcResult + */ +hiprtcResult hiprtcLinkDestroy(hiprtcLinkState hip_link_state); + +#if !defined(_WIN32) +#pragma GCC visibility pop +#endif + +#ifdef __cplusplus +} +#endif /* __cplusplus */ + +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/library_types.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/library_types.h new file mode 100644 index 000000000..fb41014c6 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/library_types.h @@ -0,0 +1,78 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_LIBRARY_TYPES_H +#define HIP_INCLUDE_HIP_LIBRARY_TYPES_H + +#if !defined(__HIPCC_RTC__) +#include +#endif + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) + +typedef enum hipDataType { + HIP_R_32F = 0, + HIP_R_64F = 1, + HIP_R_16F = 2, + HIP_R_8I = 3, + HIP_C_32F = 4, + HIP_C_64F = 5, + HIP_C_16F = 6, + HIP_C_8I = 7, + HIP_R_8U = 8, + HIP_C_8U = 9, + HIP_R_32I = 10, + HIP_C_32I = 11, + HIP_R_32U = 12, + HIP_C_32U = 13, + HIP_R_16BF = 14, + HIP_C_16BF = 15, + HIP_R_4I = 16, + HIP_C_4I = 17, + HIP_R_4U = 18, + HIP_C_4U = 19, + HIP_R_16I = 20, + HIP_C_16I = 21, + HIP_R_16U = 22, + HIP_C_16U = 23, + HIP_R_64I = 24, + HIP_C_64I = 25, + HIP_R_64U = 26, + HIP_C_64U = 27, + // HIP specific Data Types + HIP_R_8F_E4M3_FNUZ = 1000, + HIP_R_8F_E5M2_FNUZ = 1001 +} hipDataType; + +typedef enum hipLibraryPropertyType { + HIP_LIBRARY_MAJOR_VERSION, + HIP_LIBRARY_MINOR_VERSION, + HIP_LIBRARY_PATCH_LEVEL +} hipLibraryPropertyType; + +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include "library_types.h" +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/math_functions.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/math_functions.h new file mode 100644 index 000000000..81d0c11d2 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/math_functions.h @@ -0,0 +1,42 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_MATH_FUNCTIONS_H +#define HIP_INCLUDE_HIP_MATH_FUNCTIONS_H + +// Some standard header files, these are included by hc.hpp and so want to make +// them avail on both paths to provide a consistent include env and avoid +// "missing symbol" errors that only appears on NVCC path: + +#if !defined(__HIPCC_RTC__) +#include +#endif + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +#include +#elif !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +// #include +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/surface_types.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/surface_types.h new file mode 100644 index 000000000..232e6ce5d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/surface_types.h @@ -0,0 +1,63 @@ +/* +Copyright (c) 2022 - 2023 Advanced Micro Devices, Inc. All rights reserved. +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +/** + * @file surface_types.h + * @brief Defines surface types for HIP runtime. + */ + +#ifndef HIP_INCLUDE_HIP_SURFACE_TYPES_H +#define HIP_INCLUDE_HIP_SURFACE_TYPES_H + +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wreserved-identifier" +#endif + +#if !defined(__HIPCC_RTC__) +#include +#endif + +/** + * An opaque value that represents a hip surface object + */ +struct __hip_surface; +typedef struct __hip_surface *hipSurfaceObject_t; + +/** + * hip surface reference + */ +struct surfaceReference { + hipSurfaceObject_t surfaceObject; +}; + +/** + * hip surface boundary modes + */ +enum hipSurfaceBoundaryMode { + hipBoundaryModeZero = 0, + hipBoundaryModeTrap = 1, + hipBoundaryModeClamp = 2 +}; + +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + +#endif /* !HIP_INCLUDE_HIP_SURFACE_TYPES_H */ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hip/texture_types.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/texture_types.h new file mode 100644 index 000000000..0ec389496 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hip/texture_types.h @@ -0,0 +1,200 @@ +/* +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef HIP_INCLUDE_HIP_TEXTURE_TYPES_H +#define HIP_INCLUDE_HIP_TEXTURE_TYPES_H + +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wreserved-identifier" +#pragma clang diagnostic ignored "-Wreserved-macro-identifier" +#pragma clang diagnostic ignored "-Wc++98-compat" +#endif + +#if !defined(__HIPCC_RTC__) +#include +#endif + +#if !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#include "texture_types.h" +#elif defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_PLATFORM_NVIDIA__) +/******************************************************************************* + * * + * * + * * + *******************************************************************************/ +#if !defined(__HIPCC_RTC__) +#include +#include +#include +#endif // !defined(__HIPCC_RTC__) + +#define hipTextureType1D 0x01 +#define hipTextureType2D 0x02 +#define hipTextureType3D 0x03 +#define hipTextureTypeCubemap 0x0C +#define hipTextureType1DLayered 0xF1 +#define hipTextureType2DLayered 0xF2 +#define hipTextureTypeCubemapLayered 0xFC + +/** + * Should be same as HSA_IMAGE_OBJECT_SIZE_DWORD/HSA_SAMPLER_OBJECT_SIZE_DWORD + */ +#define HIP_IMAGE_OBJECT_SIZE_DWORD 12 +#define HIP_SAMPLER_OBJECT_SIZE_DWORD 8 +#define HIP_SAMPLER_OBJECT_OFFSET_DWORD HIP_IMAGE_OBJECT_SIZE_DWORD +#define HIP_TEXTURE_OBJECT_SIZE_DWORD \ + (HIP_IMAGE_OBJECT_SIZE_DWORD + HIP_SAMPLER_OBJECT_SIZE_DWORD) + +/** + * An opaque value that represents a hip texture object + */ +struct __hip_texture; +typedef struct __hip_texture *hipTextureObject_t; + +/** + * hip texture address modes + */ +enum hipTextureAddressMode { + hipAddressModeWrap = 0, + hipAddressModeClamp = 1, + hipAddressModeMirror = 2, + hipAddressModeBorder = 3 +}; + +/** + * hip texture filter modes + */ +enum hipTextureFilterMode { hipFilterModePoint = 0, hipFilterModeLinear = 1 }; + +/** + * hip texture read modes + */ +enum hipTextureReadMode { + hipReadModeElementType = 0, + hipReadModeNormalizedFloat = 1 +}; + +/** + * hip texture reference + */ +typedef struct textureReference { + int normalized; + enum hipTextureReadMode readMode; // used only for driver API's + enum hipTextureFilterMode filterMode; + enum hipTextureAddressMode + addressMode[3]; // Texture address mode for up to 3 dimensions + struct hipChannelFormatDesc channelDesc; + int sRGB; // Perform sRGB->linear conversion during texture read + unsigned int maxAnisotropy; // Limit to the anisotropy ratio + enum hipTextureFilterMode mipmapFilterMode; + float mipmapLevelBias; + float minMipmapLevelClamp; + float maxMipmapLevelClamp; + + hipTextureObject_t textureObject; + int numChannels; + enum hipArray_Format format; +} textureReference; + +/** + * hip texture descriptor + */ +typedef struct hipTextureDesc { + enum hipTextureAddressMode + addressMode[3]; // Texture address mode for up to 3 dimensions + enum hipTextureFilterMode filterMode; + enum hipTextureReadMode readMode; + int sRGB; // Perform sRGB->linear conversion during texture read + float borderColor[4]; + int normalizedCoords; + unsigned int maxAnisotropy; + enum hipTextureFilterMode mipmapFilterMode; + float mipmapLevelBias; + float minMipmapLevelClamp; + float maxMipmapLevelClamp; +} hipTextureDesc; + +#if __cplusplus + +/******************************************************************************* + * * + * * + * * + *******************************************************************************/ +#if __HIP__ +#define __HIP_TEXTURE_ATTRIB __attribute__((device_builtin_texture_type)) +#else +#define __HIP_TEXTURE_ATTRIB +#endif + +typedef textureReference *hipTexRef; + +template +struct __HIP_TEXTURE_ATTRIB texture : public textureReference { + texture(int norm = 0, enum hipTextureFilterMode fMode = hipFilterModePoint, + enum hipTextureAddressMode aMode = hipAddressModeClamp) { + normalized = norm; + readMode = mode; + filterMode = fMode; + addressMode[0] = aMode; + addressMode[1] = aMode; + addressMode[2] = aMode; + channelDesc = hipCreateChannelDesc(); + sRGB = 0; + textureObject = nullptr; + maxAnisotropy = 0; + mipmapLevelBias = 0; + minMipmapLevelClamp = 0; + maxMipmapLevelClamp = 0; + } + + texture(int norm, enum hipTextureFilterMode fMode, + enum hipTextureAddressMode aMode, struct hipChannelFormatDesc desc) { + normalized = norm; + readMode = mode; + filterMode = fMode; + addressMode[0] = aMode; + addressMode[1] = aMode; + addressMode[2] = aMode; + channelDesc = desc; + sRGB = 0; + textureObject = nullptr; + maxAnisotropy = 0; + mipmapLevelBias = 0; + minMipmapLevelClamp = 0; + maxMipmapLevelClamp = 0; + } +}; + +#endif /* __cplusplus */ + +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/Brig.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/Brig.h new file mode 100644 index 000000000..848b5f38c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/Brig.h @@ -0,0 +1,1127 @@ +// University of Illinois/NCSA +// Open Source License +// +// Copyright (c) 2013-2015, Advanced Micro Devices, Inc. +// All rights reserved. +// +// Developed by: +// +// HSA Team +// +// Advanced Micro Devices, Inc +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// with the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +// sell copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// +// * Redistributions in binary form must reproduce the above copyright +// notice, +// this list of conditions and the following disclaimers in the +// documentation and/or other materials provided with the distribution. +// +// * Neither the names of the LLVM Team, University of Illinois at +// Urbana-Champaign, nor the names of its contributors may be used to +// endorse or promote products derived from this Software without specific +// prior written permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH +// THE SOFTWARE. + +#ifndef INCLUDED_BRIG_H +#define INCLUDED_BRIG_H + +#include /* size_t */ +#include /* uintXX_t */ + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/*========================================================================================*/ +/* =======================================================================================*/ +/* =======================================================================================*/ +/* =======================================================================================*/ + +typedef uint32_t BrigCodeOffset32_t; +typedef uint32_t BrigOperandOffset32_t; +typedef uint32_t BrigDataOffset32_t; + +typedef BrigDataOffset32_t BrigDataOffsetCodeList32_t; +typedef BrigDataOffset32_t BrigDataOffsetOperandList32_t; +typedef BrigDataOffset32_t BrigDataOffsetString32_t; + +typedef uint32_t BrigVersion32_t; +enum BrigVersion { + BRIG_VERSION_HSAIL_MAJOR = 1, + BRIG_VERSION_HSAIL_MINOR = 0, + BRIG_VERSION_BRIG_MAJOR = 1, + BRIG_VERSION_BRIG_MINOR = 0 +}; + +typedef uint16_t BrigKind16_t; +enum BrigKind { + BRIG_KIND_NONE = 0x0000, + + BRIG_KIND_DIRECTIVE_BEGIN = 0x1000, + BRIG_KIND_DIRECTIVE_ARG_BLOCK_END = 0x1000, + BRIG_KIND_DIRECTIVE_ARG_BLOCK_START = 0x1001, + BRIG_KIND_DIRECTIVE_COMMENT = 0x1002, + BRIG_KIND_DIRECTIVE_CONTROL = 0x1003, + BRIG_KIND_DIRECTIVE_EXTENSION = 0x1004, + BRIG_KIND_DIRECTIVE_FBARRIER = 0x1005, + BRIG_KIND_DIRECTIVE_FUNCTION = 0x1006, + BRIG_KIND_DIRECTIVE_INDIRECT_FUNCTION = 0x1007, + BRIG_KIND_DIRECTIVE_KERNEL = 0x1008, + BRIG_KIND_DIRECTIVE_LABEL = 0x1009, + BRIG_KIND_DIRECTIVE_LOC = 0x100a, + BRIG_KIND_DIRECTIVE_MODULE = 0x100b, + BRIG_KIND_DIRECTIVE_PRAGMA = 0x100c, + BRIG_KIND_DIRECTIVE_SIGNATURE = 0x100d, + BRIG_KIND_DIRECTIVE_VARIABLE = 0x100e, + BRIG_KIND_DIRECTIVE_END = 0x100f, + + BRIG_KIND_INST_BEGIN = 0x2000, + BRIG_KIND_INST_ADDR = 0x2000, + BRIG_KIND_INST_ATOMIC = 0x2001, + BRIG_KIND_INST_BASIC = 0x2002, + BRIG_KIND_INST_BR = 0x2003, + BRIG_KIND_INST_CMP = 0x2004, + BRIG_KIND_INST_CVT = 0x2005, + BRIG_KIND_INST_IMAGE = 0x2006, + BRIG_KIND_INST_LANE = 0x2007, + BRIG_KIND_INST_MEM = 0x2008, + BRIG_KIND_INST_MEM_FENCE = 0x2009, + BRIG_KIND_INST_MOD = 0x200a, + BRIG_KIND_INST_QUERY_IMAGE = 0x200b, + BRIG_KIND_INST_QUERY_SAMPLER = 0x200c, + BRIG_KIND_INST_QUEUE = 0x200d, + BRIG_KIND_INST_SEG = 0x200e, + BRIG_KIND_INST_SEG_CVT = 0x200f, + BRIG_KIND_INST_SIGNAL = 0x2010, + BRIG_KIND_INST_SOURCE_TYPE = 0x2011, + BRIG_KIND_INST_END = 0x2012, + + BRIG_KIND_OPERAND_BEGIN = 0x3000, + BRIG_KIND_OPERAND_ADDRESS = 0x3000, + BRIG_KIND_OPERAND_ALIGN = 0x3001, + BRIG_KIND_OPERAND_CODE_LIST = 0x3002, + BRIG_KIND_OPERAND_CODE_REF = 0x3003, + BRIG_KIND_OPERAND_CONSTANT_BYTES = 0x3004, + BRIG_KIND_OPERAND_RESERVED = 0x3005, + BRIG_KIND_OPERAND_CONSTANT_IMAGE = 0x3006, + BRIG_KIND_OPERAND_CONSTANT_OPERAND_LIST = 0x3007, + BRIG_KIND_OPERAND_CONSTANT_SAMPLER = 0x3008, + BRIG_KIND_OPERAND_OPERAND_LIST = 0x3009, + BRIG_KIND_OPERAND_REGISTER = 0x300a, + BRIG_KIND_OPERAND_STRING = 0x300b, + BRIG_KIND_OPERAND_WAVESIZE = 0x300c, + BRIG_KIND_OPERAND_END = 0x300d +}; + +typedef uint8_t BrigAlignment8_t; +enum BrigAlignment { + BRIG_ALIGNMENT_NONE = 0, + BRIG_ALIGNMENT_1 = 1, + BRIG_ALIGNMENT_2 = 2, + BRIG_ALIGNMENT_4 = 3, + BRIG_ALIGNMENT_8 = 4, + BRIG_ALIGNMENT_16 = 5, + BRIG_ALIGNMENT_32 = 6, + BRIG_ALIGNMENT_64 = 7, + BRIG_ALIGNMENT_128 = 8, + BRIG_ALIGNMENT_256 = 9, + BRIG_ALIGNMENT_MAX = BRIG_ALIGNMENT_256 +}; + +typedef uint8_t BrigAllocation8_t; +enum BrigAllocation { + BRIG_ALLOCATION_NONE = 0, + BRIG_ALLOCATION_PROGRAM = 1, + BRIG_ALLOCATION_AGENT = 2, + BRIG_ALLOCATION_AUTOMATIC = 3 +}; + +typedef uint8_t BrigAluModifier8_t; +enum BrigAluModifierMask { BRIG_ALU_FTZ = 1 }; + +typedef uint8_t BrigAtomicOperation8_t; +enum BrigAtomicOperation { + BRIG_ATOMIC_ADD = 0, + BRIG_ATOMIC_AND = 1, + BRIG_ATOMIC_CAS = 2, + BRIG_ATOMIC_EXCH = 3, + BRIG_ATOMIC_LD = 4, + BRIG_ATOMIC_MAX = 5, + BRIG_ATOMIC_MIN = 6, + BRIG_ATOMIC_OR = 7, + BRIG_ATOMIC_ST = 8, + BRIG_ATOMIC_SUB = 9, + BRIG_ATOMIC_WRAPDEC = 10, + BRIG_ATOMIC_WRAPINC = 11, + BRIG_ATOMIC_XOR = 12, + BRIG_ATOMIC_WAIT_EQ = 13, + BRIG_ATOMIC_WAIT_NE = 14, + BRIG_ATOMIC_WAIT_LT = 15, + BRIG_ATOMIC_WAIT_GTE = 16, + BRIG_ATOMIC_WAITTIMEOUT_EQ = 17, + BRIG_ATOMIC_WAITTIMEOUT_NE = 18, + BRIG_ATOMIC_WAITTIMEOUT_LT = 19, + BRIG_ATOMIC_WAITTIMEOUT_GTE = 20 +}; + +typedef uint8_t BrigCompareOperation8_t; +enum BrigCompareOperation { + BRIG_COMPARE_EQ = 0, + BRIG_COMPARE_NE = 1, + BRIG_COMPARE_LT = 2, + BRIG_COMPARE_LE = 3, + BRIG_COMPARE_GT = 4, + BRIG_COMPARE_GE = 5, + BRIG_COMPARE_EQU = 6, + BRIG_COMPARE_NEU = 7, + BRIG_COMPARE_LTU = 8, + BRIG_COMPARE_LEU = 9, + BRIG_COMPARE_GTU = 10, + BRIG_COMPARE_GEU = 11, + BRIG_COMPARE_NUM = 12, + BRIG_COMPARE_NAN = 13, + BRIG_COMPARE_SEQ = 14, + BRIG_COMPARE_SNE = 15, + BRIG_COMPARE_SLT = 16, + BRIG_COMPARE_SLE = 17, + BRIG_COMPARE_SGT = 18, + BRIG_COMPARE_SGE = 19, + BRIG_COMPARE_SGEU = 20, + BRIG_COMPARE_SEQU = 21, + BRIG_COMPARE_SNEU = 22, + BRIG_COMPARE_SLTU = 23, + BRIG_COMPARE_SLEU = 24, + BRIG_COMPARE_SNUM = 25, + BRIG_COMPARE_SNAN = 26, + BRIG_COMPARE_SGTU = 27 +}; + +typedef uint16_t BrigControlDirective16_t; +enum BrigControlDirective { + BRIG_CONTROL_NONE = 0, + BRIG_CONTROL_ENABLEBREAKEXCEPTIONS = 1, + BRIG_CONTROL_ENABLEDETECTEXCEPTIONS = 2, + BRIG_CONTROL_MAXDYNAMICGROUPSIZE = 3, + BRIG_CONTROL_MAXFLATGRIDSIZE = 4, + BRIG_CONTROL_MAXFLATWORKGROUPSIZE = 5, + BRIG_CONTROL_REQUIREDDIM = 6, + BRIG_CONTROL_REQUIREDGRIDSIZE = 7, + BRIG_CONTROL_REQUIREDWORKGROUPSIZE = 8, + BRIG_CONTROL_REQUIRENOPARTIALWORKGROUPS = 9 +}; + +typedef uint8_t BrigExecutableModifier8_t; +enum BrigExecutableModifierMask { BRIG_EXECUTABLE_DEFINITION = 1 }; + +typedef uint8_t BrigImageChannelOrder8_t; +enum BrigImageChannelOrder { + BRIG_CHANNEL_ORDER_A = 0, + BRIG_CHANNEL_ORDER_R = 1, + BRIG_CHANNEL_ORDER_RX = 2, + BRIG_CHANNEL_ORDER_RG = 3, + BRIG_CHANNEL_ORDER_RGX = 4, + BRIG_CHANNEL_ORDER_RA = 5, + BRIG_CHANNEL_ORDER_RGB = 6, + BRIG_CHANNEL_ORDER_RGBX = 7, + BRIG_CHANNEL_ORDER_RGBA = 8, + BRIG_CHANNEL_ORDER_BGRA = 9, + BRIG_CHANNEL_ORDER_ARGB = 10, + BRIG_CHANNEL_ORDER_ABGR = 11, + BRIG_CHANNEL_ORDER_SRGB = 12, + BRIG_CHANNEL_ORDER_SRGBX = 13, + BRIG_CHANNEL_ORDER_SRGBA = 14, + BRIG_CHANNEL_ORDER_SBGRA = 15, + BRIG_CHANNEL_ORDER_INTENSITY = 16, + BRIG_CHANNEL_ORDER_LUMINANCE = 17, + BRIG_CHANNEL_ORDER_DEPTH = 18, + BRIG_CHANNEL_ORDER_DEPTH_STENCIL = 19, + + BRIG_CHANNEL_ORDER_FIRST_USER_DEFINED = 128 +}; + +typedef uint8_t BrigImageChannelType8_t; +enum BrigImageChannelType { + BRIG_CHANNEL_TYPE_SNORM_INT8 = 0, + BRIG_CHANNEL_TYPE_SNORM_INT16 = 1, + BRIG_CHANNEL_TYPE_UNORM_INT8 = 2, + BRIG_CHANNEL_TYPE_UNORM_INT16 = 3, + BRIG_CHANNEL_TYPE_UNORM_INT24 = 4, + BRIG_CHANNEL_TYPE_UNORM_SHORT_555 = 5, + BRIG_CHANNEL_TYPE_UNORM_SHORT_565 = 6, + BRIG_CHANNEL_TYPE_UNORM_INT_101010 = 7, + BRIG_CHANNEL_TYPE_SIGNED_INT8 = 8, + BRIG_CHANNEL_TYPE_SIGNED_INT16 = 9, + BRIG_CHANNEL_TYPE_SIGNED_INT32 = 10, + BRIG_CHANNEL_TYPE_UNSIGNED_INT8 = 11, + BRIG_CHANNEL_TYPE_UNSIGNED_INT16 = 12, + BRIG_CHANNEL_TYPE_UNSIGNED_INT32 = 13, + BRIG_CHANNEL_TYPE_HALF_FLOAT = 14, + BRIG_CHANNEL_TYPE_FLOAT = 15, + + BRIG_CHANNEL_TYPE_FIRST_USER_DEFINED = 128 +}; + +typedef uint8_t BrigImageGeometry8_t; +enum BrigImageGeometry { + BRIG_GEOMETRY_1D = 0, + BRIG_GEOMETRY_2D = 1, + BRIG_GEOMETRY_3D = 2, + BRIG_GEOMETRY_1DA = 3, + BRIG_GEOMETRY_2DA = 4, + BRIG_GEOMETRY_1DB = 5, + BRIG_GEOMETRY_2DDEPTH = 6, + BRIG_GEOMETRY_2DADEPTH = 7, + + BRIG_GEOMETRY_FIRST_USER_DEFINED = 128 +}; + +typedef uint8_t BrigImageQuery8_t; +enum BrigImageQuery { + BRIG_IMAGE_QUERY_WIDTH = 0, + BRIG_IMAGE_QUERY_HEIGHT = 1, + BRIG_IMAGE_QUERY_DEPTH = 2, + BRIG_IMAGE_QUERY_ARRAY = 3, + BRIG_IMAGE_QUERY_CHANNELORDER = 4, + BRIG_IMAGE_QUERY_CHANNELTYPE = 5, + + BRIG_IMAGE_QUERY_FIRST_USER_DEFINED = 6 +}; + +typedef uint8_t BrigLinkage8_t; +enum BrigLinkage { + BRIG_LINKAGE_NONE = 0, + BRIG_LINKAGE_PROGRAM = 1, + BRIG_LINKAGE_MODULE = 2, + BRIG_LINKAGE_FUNCTION = 3, + BRIG_LINKAGE_ARG = 4 +}; + +typedef uint8_t BrigMachineModel8_t; +enum BrigMachineModel { + BRIG_MACHINE_SMALL = 0, + BRIG_MACHINE_LARGE = 1, +}; + +typedef uint8_t BrigMemoryModifier8_t; +enum BrigMemoryModifierMask { BRIG_MEMORY_CONST = 1 }; + +typedef uint8_t BrigMemoryOrder8_t; +enum BrigMemoryOrder { + BRIG_MEMORY_ORDER_NONE = 0, + BRIG_MEMORY_ORDER_RELAXED = 1, + BRIG_MEMORY_ORDER_SC_ACQUIRE = 2, + BRIG_MEMORY_ORDER_SC_RELEASE = 3, + BRIG_MEMORY_ORDER_SC_ACQUIRE_RELEASE = 4, +}; + +typedef uint8_t BrigMemoryScope8_t; +enum BrigMemoryScope { + BRIG_MEMORY_SCOPE_NONE = 0, + BRIG_MEMORY_SCOPE_WORKITEM = 1, + BRIG_MEMORY_SCOPE_WAVEFRONT = 2, + BRIG_MEMORY_SCOPE_WORKGROUP = 3, + BRIG_MEMORY_SCOPE_AGENT = 4, + BRIG_MEMORY_SCOPE_SYSTEM = 5, +}; + +typedef uint16_t BrigOpcode16_t; +enum BrigOpcode { + BRIG_OPCODE_NOP = 0, + BRIG_OPCODE_ABS = 1, + BRIG_OPCODE_ADD = 2, + BRIG_OPCODE_BORROW = 3, + BRIG_OPCODE_CARRY = 4, + BRIG_OPCODE_CEIL = 5, + BRIG_OPCODE_COPYSIGN = 6, + BRIG_OPCODE_DIV = 7, + BRIG_OPCODE_FLOOR = 8, + BRIG_OPCODE_FMA = 9, + BRIG_OPCODE_FRACT = 10, + BRIG_OPCODE_MAD = 11, + BRIG_OPCODE_MAX = 12, + BRIG_OPCODE_MIN = 13, + BRIG_OPCODE_MUL = 14, + BRIG_OPCODE_MULHI = 15, + BRIG_OPCODE_NEG = 16, + BRIG_OPCODE_REM = 17, + BRIG_OPCODE_RINT = 18, + BRIG_OPCODE_SQRT = 19, + BRIG_OPCODE_SUB = 20, + BRIG_OPCODE_TRUNC = 21, + BRIG_OPCODE_MAD24 = 22, + BRIG_OPCODE_MAD24HI = 23, + BRIG_OPCODE_MUL24 = 24, + BRIG_OPCODE_MUL24HI = 25, + BRIG_OPCODE_SHL = 26, + BRIG_OPCODE_SHR = 27, + BRIG_OPCODE_AND = 28, + BRIG_OPCODE_NOT = 29, + BRIG_OPCODE_OR = 30, + BRIG_OPCODE_POPCOUNT = 31, + BRIG_OPCODE_XOR = 32, + BRIG_OPCODE_BITEXTRACT = 33, + BRIG_OPCODE_BITINSERT = 34, + BRIG_OPCODE_BITMASK = 35, + BRIG_OPCODE_BITREV = 36, + BRIG_OPCODE_BITSELECT = 37, + BRIG_OPCODE_FIRSTBIT = 38, + BRIG_OPCODE_LASTBIT = 39, + BRIG_OPCODE_COMBINE = 40, + BRIG_OPCODE_EXPAND = 41, + BRIG_OPCODE_LDA = 42, + BRIG_OPCODE_MOV = 43, + BRIG_OPCODE_SHUFFLE = 44, + BRIG_OPCODE_UNPACKHI = 45, + BRIG_OPCODE_UNPACKLO = 46, + BRIG_OPCODE_PACK = 47, + BRIG_OPCODE_UNPACK = 48, + BRIG_OPCODE_CMOV = 49, + BRIG_OPCODE_CLASS = 50, + BRIG_OPCODE_NCOS = 51, + BRIG_OPCODE_NEXP2 = 52, + BRIG_OPCODE_NFMA = 53, + BRIG_OPCODE_NLOG2 = 54, + BRIG_OPCODE_NRCP = 55, + BRIG_OPCODE_NRSQRT = 56, + BRIG_OPCODE_NSIN = 57, + BRIG_OPCODE_NSQRT = 58, + BRIG_OPCODE_BITALIGN = 59, + BRIG_OPCODE_BYTEALIGN = 60, + BRIG_OPCODE_PACKCVT = 61, + BRIG_OPCODE_UNPACKCVT = 62, + BRIG_OPCODE_LERP = 63, + BRIG_OPCODE_SAD = 64, + BRIG_OPCODE_SADHI = 65, + BRIG_OPCODE_SEGMENTP = 66, + BRIG_OPCODE_FTOS = 67, + BRIG_OPCODE_STOF = 68, + BRIG_OPCODE_CMP = 69, + BRIG_OPCODE_CVT = 70, + BRIG_OPCODE_LD = 71, + BRIG_OPCODE_ST = 72, + BRIG_OPCODE_ATOMIC = 73, + BRIG_OPCODE_ATOMICNORET = 74, + BRIG_OPCODE_SIGNAL = 75, + BRIG_OPCODE_SIGNALNORET = 76, + BRIG_OPCODE_MEMFENCE = 77, + BRIG_OPCODE_RDIMAGE = 78, + BRIG_OPCODE_LDIMAGE = 79, + BRIG_OPCODE_STIMAGE = 80, + BRIG_OPCODE_IMAGEFENCE = 81, + BRIG_OPCODE_QUERYIMAGE = 82, + BRIG_OPCODE_QUERYSAMPLER = 83, + BRIG_OPCODE_CBR = 84, + BRIG_OPCODE_BR = 85, + BRIG_OPCODE_SBR = 86, + BRIG_OPCODE_BARRIER = 87, + BRIG_OPCODE_WAVEBARRIER = 88, + BRIG_OPCODE_ARRIVEFBAR = 89, + BRIG_OPCODE_INITFBAR = 90, + BRIG_OPCODE_JOINFBAR = 91, + BRIG_OPCODE_LEAVEFBAR = 92, + BRIG_OPCODE_RELEASEFBAR = 93, + BRIG_OPCODE_WAITFBAR = 94, + BRIG_OPCODE_LDF = 95, + BRIG_OPCODE_ACTIVELANECOUNT = 96, + BRIG_OPCODE_ACTIVELANEID = 97, + BRIG_OPCODE_ACTIVELANEMASK = 98, + BRIG_OPCODE_ACTIVELANEPERMUTE = 99, + BRIG_OPCODE_CALL = 100, + BRIG_OPCODE_SCALL = 101, + BRIG_OPCODE_ICALL = 102, + BRIG_OPCODE_RET = 103, + BRIG_OPCODE_ALLOCA = 104, + BRIG_OPCODE_CURRENTWORKGROUPSIZE = 105, + BRIG_OPCODE_CURRENTWORKITEMFLATID = 106, + BRIG_OPCODE_DIM = 107, + BRIG_OPCODE_GRIDGROUPS = 108, + BRIG_OPCODE_GRIDSIZE = 109, + BRIG_OPCODE_PACKETCOMPLETIONSIG = 110, + BRIG_OPCODE_PACKETID = 111, + BRIG_OPCODE_WORKGROUPID = 112, + BRIG_OPCODE_WORKGROUPSIZE = 113, + BRIG_OPCODE_WORKITEMABSID = 114, + BRIG_OPCODE_WORKITEMFLATABSID = 115, + BRIG_OPCODE_WORKITEMFLATID = 116, + BRIG_OPCODE_WORKITEMID = 117, + BRIG_OPCODE_CLEARDETECTEXCEPT = 118, + BRIG_OPCODE_GETDETECTEXCEPT = 119, + BRIG_OPCODE_SETDETECTEXCEPT = 120, + BRIG_OPCODE_ADDQUEUEWRITEINDEX = 121, + BRIG_OPCODE_CASQUEUEWRITEINDEX = 122, + BRIG_OPCODE_LDQUEUEREADINDEX = 123, + BRIG_OPCODE_LDQUEUEWRITEINDEX = 124, + BRIG_OPCODE_STQUEUEREADINDEX = 125, + BRIG_OPCODE_STQUEUEWRITEINDEX = 126, + BRIG_OPCODE_CLOCK = 127, + BRIG_OPCODE_CUID = 128, + BRIG_OPCODE_DEBUGTRAP = 129, + BRIG_OPCODE_GROUPBASEPTR = 130, + BRIG_OPCODE_KERNARGBASEPTR = 131, + BRIG_OPCODE_LANEID = 132, + BRIG_OPCODE_MAXCUID = 133, + BRIG_OPCODE_MAXWAVEID = 134, + BRIG_OPCODE_NULLPTR = 135, + BRIG_OPCODE_WAVEID = 136, + + BRIG_OPCODE_FIRST_USER_DEFINED = 32768, +}; + +typedef uint8_t BrigPack8_t; +enum BrigPack { + BRIG_PACK_NONE = 0, + BRIG_PACK_PP = 1, + BRIG_PACK_PS = 2, + BRIG_PACK_SP = 3, + BRIG_PACK_SS = 4, + BRIG_PACK_S = 5, + BRIG_PACK_P = 6, + BRIG_PACK_PPSAT = 7, + BRIG_PACK_PSSAT = 8, + BRIG_PACK_SPSAT = 9, + BRIG_PACK_SSSAT = 10, + BRIG_PACK_SSAT = 11, + BRIG_PACK_PSAT = 12 +}; + +typedef uint8_t BrigProfile8_t; +enum BrigProfile { + BRIG_PROFILE_BASE = 0, + BRIG_PROFILE_FULL = 1, +}; + +typedef uint16_t BrigRegisterKind16_t; +enum BrigRegisterKind { + BRIG_REGISTER_KIND_CONTROL = 0, + BRIG_REGISTER_KIND_SINGLE = 1, + BRIG_REGISTER_KIND_DOUBLE = 2, + BRIG_REGISTER_KIND_QUAD = 3 +}; + +typedef uint8_t BrigRound8_t; +enum BrigRound { + BRIG_ROUND_NONE = 0, + BRIG_ROUND_FLOAT_DEFAULT = 1, + BRIG_ROUND_FLOAT_NEAR_EVEN = 2, + BRIG_ROUND_FLOAT_ZERO = 3, + BRIG_ROUND_FLOAT_PLUS_INFINITY = 4, + BRIG_ROUND_FLOAT_MINUS_INFINITY = 5, + BRIG_ROUND_INTEGER_NEAR_EVEN = 6, + BRIG_ROUND_INTEGER_ZERO = 7, + BRIG_ROUND_INTEGER_PLUS_INFINITY = 8, + BRIG_ROUND_INTEGER_MINUS_INFINITY = 9, + BRIG_ROUND_INTEGER_NEAR_EVEN_SAT = 10, + BRIG_ROUND_INTEGER_ZERO_SAT = 11, + BRIG_ROUND_INTEGER_PLUS_INFINITY_SAT = 12, + BRIG_ROUND_INTEGER_MINUS_INFINITY_SAT = 13, + BRIG_ROUND_INTEGER_SIGNALING_NEAR_EVEN = 14, + BRIG_ROUND_INTEGER_SIGNALING_ZERO = 15, + BRIG_ROUND_INTEGER_SIGNALING_PLUS_INFINITY = 16, + BRIG_ROUND_INTEGER_SIGNALING_MINUS_INFINITY = 17, + BRIG_ROUND_INTEGER_SIGNALING_NEAR_EVEN_SAT = 18, + BRIG_ROUND_INTEGER_SIGNALING_ZERO_SAT = 19, + BRIG_ROUND_INTEGER_SIGNALING_PLUS_INFINITY_SAT = 20, + BRIG_ROUND_INTEGER_SIGNALING_MINUS_INFINITY_SAT = 21 +}; + +typedef uint8_t BrigSamplerAddressing8_t; +enum BrigSamplerAddressing { + BRIG_ADDRESSING_UNDEFINED = 0, + BRIG_ADDRESSING_CLAMP_TO_EDGE = 1, + BRIG_ADDRESSING_CLAMP_TO_BORDER = 2, + BRIG_ADDRESSING_REPEAT = 3, + BRIG_ADDRESSING_MIRRORED_REPEAT = 4, + + BRIG_ADDRESSING_FIRST_USER_DEFINED = 128 +}; + +typedef uint8_t BrigSamplerCoordNormalization8_t; +enum BrigSamplerCoordNormalization { + BRIG_COORD_UNNORMALIZED = 0, + BRIG_COORD_NORMALIZED = 1 +}; + +typedef uint8_t BrigSamplerFilter8_t; +enum BrigSamplerFilter { + BRIG_FILTER_NEAREST = 0, + BRIG_FILTER_LINEAR = 1, + + BRIG_FILTER_FIRST_USER_DEFINED = 128 +}; + +typedef uint8_t BrigSamplerQuery8_t; +enum BrigSamplerQuery { + BRIG_SAMPLER_QUERY_ADDRESSING = 0, + BRIG_SAMPLER_QUERY_COORD = 1, + BRIG_SAMPLER_QUERY_FILTER = 2 +}; + +typedef uint32_t BrigSectionIndex32_t; +enum BrigSectionIndex { + BRIG_SECTION_INDEX_DATA = 0, + BRIG_SECTION_INDEX_CODE = 1, + BRIG_SECTION_INDEX_OPERAND = 2, + + BRIG_SECTION_INDEX_BEGIN_IMPLEMENTATION_DEFINED = 3, +}; + +typedef uint8_t BrigSegCvtModifier8_t; +enum BrigSegCvtModifierMask { BRIG_SEG_CVT_NONULL = 1 }; + +typedef uint8_t BrigSegment8_t; +enum BrigSegment { + BRIG_SEGMENT_NONE = 0, + BRIG_SEGMENT_FLAT = 1, + BRIG_SEGMENT_GLOBAL = 2, + BRIG_SEGMENT_READONLY = 3, + BRIG_SEGMENT_KERNARG = 4, + BRIG_SEGMENT_GROUP = 5, + BRIG_SEGMENT_PRIVATE = 6, + BRIG_SEGMENT_SPILL = 7, + BRIG_SEGMENT_ARG = 8, + + BRIG_SEGMENT_FIRST_USER_DEFINED = 128 +}; + +enum { + BRIG_TYPE_BASE_SIZE = 5, + BRIG_TYPE_PACK_SIZE = 2, + BRIG_TYPE_ARRAY_SIZE = 1, + + BRIG_TYPE_BASE_SHIFT = 0, + BRIG_TYPE_PACK_SHIFT = BRIG_TYPE_BASE_SHIFT + BRIG_TYPE_BASE_SIZE, + BRIG_TYPE_ARRAY_SHIFT = BRIG_TYPE_PACK_SHIFT + BRIG_TYPE_PACK_SIZE, + + BRIG_TYPE_BASE_MASK = ((1 << BRIG_TYPE_BASE_SIZE) - 1) + << BRIG_TYPE_BASE_SHIFT, + BRIG_TYPE_PACK_MASK = ((1 << BRIG_TYPE_PACK_SIZE) - 1) + << BRIG_TYPE_PACK_SHIFT, + BRIG_TYPE_ARRAY_MASK = ((1 << BRIG_TYPE_ARRAY_SIZE) - 1) + << BRIG_TYPE_ARRAY_SHIFT, + + BRIG_TYPE_PACK_NONE = 0 << BRIG_TYPE_PACK_SHIFT, + BRIG_TYPE_PACK_32 = 1 << BRIG_TYPE_PACK_SHIFT, + BRIG_TYPE_PACK_64 = 2 << BRIG_TYPE_PACK_SHIFT, + BRIG_TYPE_PACK_128 = 3 << BRIG_TYPE_PACK_SHIFT, + + BRIG_TYPE_ARRAY = 1 << BRIG_TYPE_ARRAY_SHIFT +}; + +typedef uint16_t BrigType16_t; +enum BrigType { + BRIG_TYPE_NONE = 0, + BRIG_TYPE_U8 = 1, + BRIG_TYPE_U16 = 2, + BRIG_TYPE_U32 = 3, + BRIG_TYPE_U64 = 4, + BRIG_TYPE_S8 = 5, + BRIG_TYPE_S16 = 6, + BRIG_TYPE_S32 = 7, + BRIG_TYPE_S64 = 8, + BRIG_TYPE_F16 = 9, + BRIG_TYPE_F32 = 10, + BRIG_TYPE_F64 = 11, + BRIG_TYPE_B1 = 12, + BRIG_TYPE_B8 = 13, + BRIG_TYPE_B16 = 14, + BRIG_TYPE_B32 = 15, + BRIG_TYPE_B64 = 16, + BRIG_TYPE_B128 = 17, + BRIG_TYPE_SAMP = 18, + BRIG_TYPE_ROIMG = 19, + BRIG_TYPE_WOIMG = 20, + BRIG_TYPE_RWIMG = 21, + BRIG_TYPE_SIG32 = 22, + BRIG_TYPE_SIG64 = 23, + + BRIG_TYPE_U8X4 = BRIG_TYPE_U8 | BRIG_TYPE_PACK_32, + BRIG_TYPE_U8X8 = BRIG_TYPE_U8 | BRIG_TYPE_PACK_64, + BRIG_TYPE_U8X16 = BRIG_TYPE_U8 | BRIG_TYPE_PACK_128, + BRIG_TYPE_U16X2 = BRIG_TYPE_U16 | BRIG_TYPE_PACK_32, + BRIG_TYPE_U16X4 = BRIG_TYPE_U16 | BRIG_TYPE_PACK_64, + BRIG_TYPE_U16X8 = BRIG_TYPE_U16 | BRIG_TYPE_PACK_128, + BRIG_TYPE_U32X2 = BRIG_TYPE_U32 | BRIG_TYPE_PACK_64, + BRIG_TYPE_U32X4 = BRIG_TYPE_U32 | BRIG_TYPE_PACK_128, + BRIG_TYPE_U64X2 = BRIG_TYPE_U64 | BRIG_TYPE_PACK_128, + BRIG_TYPE_S8X4 = BRIG_TYPE_S8 | BRIG_TYPE_PACK_32, + BRIG_TYPE_S8X8 = BRIG_TYPE_S8 | BRIG_TYPE_PACK_64, + BRIG_TYPE_S8X16 = BRIG_TYPE_S8 | BRIG_TYPE_PACK_128, + BRIG_TYPE_S16X2 = BRIG_TYPE_S16 | BRIG_TYPE_PACK_32, + BRIG_TYPE_S16X4 = BRIG_TYPE_S16 | BRIG_TYPE_PACK_64, + BRIG_TYPE_S16X8 = BRIG_TYPE_S16 | BRIG_TYPE_PACK_128, + BRIG_TYPE_S32X2 = BRIG_TYPE_S32 | BRIG_TYPE_PACK_64, + BRIG_TYPE_S32X4 = BRIG_TYPE_S32 | BRIG_TYPE_PACK_128, + BRIG_TYPE_S64X2 = BRIG_TYPE_S64 | BRIG_TYPE_PACK_128, + BRIG_TYPE_F16X2 = BRIG_TYPE_F16 | BRIG_TYPE_PACK_32, + BRIG_TYPE_F16X4 = BRIG_TYPE_F16 | BRIG_TYPE_PACK_64, + BRIG_TYPE_F16X8 = BRIG_TYPE_F16 | BRIG_TYPE_PACK_128, + BRIG_TYPE_F32X2 = BRIG_TYPE_F32 | BRIG_TYPE_PACK_64, + BRIG_TYPE_F32X4 = BRIG_TYPE_F32 | BRIG_TYPE_PACK_128, + BRIG_TYPE_F64X2 = BRIG_TYPE_F64 | BRIG_TYPE_PACK_128, + + BRIG_TYPE_U8_ARRAY = BRIG_TYPE_U8 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U16_ARRAY = BRIG_TYPE_U16 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U32_ARRAY = BRIG_TYPE_U32 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U64_ARRAY = BRIG_TYPE_U64 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S8_ARRAY = BRIG_TYPE_S8 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S16_ARRAY = BRIG_TYPE_S16 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S32_ARRAY = BRIG_TYPE_S32 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S64_ARRAY = BRIG_TYPE_S64 | BRIG_TYPE_ARRAY, + BRIG_TYPE_F16_ARRAY = BRIG_TYPE_F16 | BRIG_TYPE_ARRAY, + BRIG_TYPE_F32_ARRAY = BRIG_TYPE_F32 | BRIG_TYPE_ARRAY, + BRIG_TYPE_F64_ARRAY = BRIG_TYPE_F64 | BRIG_TYPE_ARRAY, + BRIG_TYPE_B8_ARRAY = BRIG_TYPE_B8 | BRIG_TYPE_ARRAY, + BRIG_TYPE_B16_ARRAY = BRIG_TYPE_B16 | BRIG_TYPE_ARRAY, + BRIG_TYPE_B32_ARRAY = BRIG_TYPE_B32 | BRIG_TYPE_ARRAY, + BRIG_TYPE_B64_ARRAY = BRIG_TYPE_B64 | BRIG_TYPE_ARRAY, + BRIG_TYPE_B128_ARRAY = BRIG_TYPE_B128 | BRIG_TYPE_ARRAY, + BRIG_TYPE_SAMP_ARRAY = BRIG_TYPE_SAMP | BRIG_TYPE_ARRAY, + BRIG_TYPE_ROIMG_ARRAY = BRIG_TYPE_ROIMG | BRIG_TYPE_ARRAY, + BRIG_TYPE_WOIMG_ARRAY = BRIG_TYPE_WOIMG | BRIG_TYPE_ARRAY, + BRIG_TYPE_RWIMG_ARRAY = BRIG_TYPE_RWIMG | BRIG_TYPE_ARRAY, + BRIG_TYPE_SIG32_ARRAY = BRIG_TYPE_SIG32 | BRIG_TYPE_ARRAY, + BRIG_TYPE_SIG64_ARRAY = BRIG_TYPE_SIG64 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U8X4_ARRAY = BRIG_TYPE_U8X4 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U8X8_ARRAY = BRIG_TYPE_U8X8 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U8X16_ARRAY = BRIG_TYPE_U8X16 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U16X2_ARRAY = BRIG_TYPE_U16X2 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U16X4_ARRAY = BRIG_TYPE_U16X4 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U16X8_ARRAY = BRIG_TYPE_U16X8 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U32X2_ARRAY = BRIG_TYPE_U32X2 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U32X4_ARRAY = BRIG_TYPE_U32X4 | BRIG_TYPE_ARRAY, + BRIG_TYPE_U64X2_ARRAY = BRIG_TYPE_U64X2 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S8X4_ARRAY = BRIG_TYPE_S8X4 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S8X8_ARRAY = BRIG_TYPE_S8X8 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S8X16_ARRAY = BRIG_TYPE_S8X16 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S16X2_ARRAY = BRIG_TYPE_S16X2 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S16X4_ARRAY = BRIG_TYPE_S16X4 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S16X8_ARRAY = BRIG_TYPE_S16X8 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S32X2_ARRAY = BRIG_TYPE_S32X2 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S32X4_ARRAY = BRIG_TYPE_S32X4 | BRIG_TYPE_ARRAY, + BRIG_TYPE_S64X2_ARRAY = BRIG_TYPE_S64X2 | BRIG_TYPE_ARRAY, + BRIG_TYPE_F16X2_ARRAY = BRIG_TYPE_F16X2 | BRIG_TYPE_ARRAY, + BRIG_TYPE_F16X4_ARRAY = BRIG_TYPE_F16X4 | BRIG_TYPE_ARRAY, + BRIG_TYPE_F16X8_ARRAY = BRIG_TYPE_F16X8 | BRIG_TYPE_ARRAY, + BRIG_TYPE_F32X2_ARRAY = BRIG_TYPE_F32X2 | BRIG_TYPE_ARRAY, + BRIG_TYPE_F32X4_ARRAY = BRIG_TYPE_F32X4 | BRIG_TYPE_ARRAY, + BRIG_TYPE_F64X2_ARRAY = BRIG_TYPE_F64X2 | BRIG_TYPE_ARRAY, +}; + +typedef uint8_t BrigVariableModifier8_t; +enum BrigVariableModifierMask { + BRIG_VARIABLE_DEFINITION = 1, + BRIG_VARIABLE_CONST = 2 +}; + +typedef uint8_t BrigWidth8_t; +enum BrigWidth { + BRIG_WIDTH_NONE = 0, + BRIG_WIDTH_1 = 1, + BRIG_WIDTH_2 = 2, + BRIG_WIDTH_4 = 3, + BRIG_WIDTH_8 = 4, + BRIG_WIDTH_16 = 5, + BRIG_WIDTH_32 = 6, + BRIG_WIDTH_64 = 7, + BRIG_WIDTH_128 = 8, + BRIG_WIDTH_256 = 9, + BRIG_WIDTH_512 = 10, + BRIG_WIDTH_1024 = 11, + BRIG_WIDTH_2048 = 12, + BRIG_WIDTH_4096 = 13, + BRIG_WIDTH_8192 = 14, + BRIG_WIDTH_16384 = 15, + BRIG_WIDTH_32768 = 16, + BRIG_WIDTH_65536 = 17, + BRIG_WIDTH_131072 = 18, + BRIG_WIDTH_262144 = 19, + BRIG_WIDTH_524288 = 20, + BRIG_WIDTH_1048576 = 21, + BRIG_WIDTH_2097152 = 22, + BRIG_WIDTH_4194304 = 23, + BRIG_WIDTH_8388608 = 24, + BRIG_WIDTH_16777216 = 25, + BRIG_WIDTH_33554432 = 26, + BRIG_WIDTH_67108864 = 27, + BRIG_WIDTH_134217728 = 28, + BRIG_WIDTH_268435456 = 29, + BRIG_WIDTH_536870912 = 30, + BRIG_WIDTH_1073741824 = 31, + BRIG_WIDTH_2147483648 = 32, + BRIG_WIDTH_WAVESIZE = 33, + BRIG_WIDTH_ALL = 34, +}; + +struct BrigUInt64 { + uint32_t lo; + uint32_t hi; +}; + +struct BrigBase { + uint16_t byteCount; + BrigKind16_t kind; +}; + +struct BrigData { + uint32_t byteCount; + uint8_t bytes[1]; +}; + +struct BrigDirectiveArgBlock { + BrigBase base; +}; + +struct BrigDirectiveComment { + BrigBase base; + BrigDataOffsetString32_t name; +}; + +struct BrigDirectiveControl { + BrigBase base; + BrigControlDirective16_t control; + uint16_t reserved; + BrigDataOffsetOperandList32_t operands; +}; + +struct BrigDirectiveExecutable { + BrigBase base; + BrigDataOffsetString32_t name; + uint16_t outArgCount; + uint16_t inArgCount; + BrigCodeOffset32_t firstInArg; + BrigCodeOffset32_t firstCodeBlockEntry; + BrigCodeOffset32_t nextModuleEntry; + BrigExecutableModifier8_t modifier; + BrigLinkage8_t linkage; + uint16_t reserved; +}; + +struct BrigDirectiveExtension { + BrigBase base; + BrigDataOffsetString32_t name; +}; + +struct BrigDirectiveFbarrier { + BrigBase base; + BrigDataOffsetString32_t name; + BrigVariableModifier8_t modifier; + BrigLinkage8_t linkage; + uint16_t reserved; +}; + +struct BrigDirectiveLabel { + BrigBase base; + BrigDataOffsetString32_t name; +}; + +struct BrigDirectiveLoc { + BrigBase base; + BrigDataOffsetString32_t filename; + uint32_t line; + uint32_t column; +}; + +struct BrigDirectiveNone { + BrigBase base; +}; + +struct BrigDirectivePragma { + BrigBase base; + BrigDataOffsetOperandList32_t operands; +}; + +struct BrigDirectiveVariable { + BrigBase base; + BrigDataOffsetString32_t name; + BrigOperandOffset32_t init; + BrigType16_t type; + BrigSegment8_t segment; + BrigAlignment8_t align; + BrigUInt64 dim; + BrigVariableModifier8_t modifier; + BrigLinkage8_t linkage; + BrigAllocation8_t allocation; + uint8_t reserved; +}; + +struct BrigDirectiveModule { + BrigBase base; + BrigDataOffsetString32_t name; + BrigVersion32_t hsailMajor; + BrigVersion32_t hsailMinor; + BrigProfile8_t profile; + BrigMachineModel8_t machineModel; + BrigRound8_t defaultFloatRound; + uint8_t reserved; +}; + +struct BrigInstBase { + BrigBase base; + BrigOpcode16_t opcode; + BrigType16_t type; + BrigDataOffsetOperandList32_t operands; +}; + +struct BrigInstAddr { + BrigInstBase base; + BrigSegment8_t segment; + uint8_t reserved[3]; +}; + +struct BrigInstAtomic { + BrigInstBase base; + BrigSegment8_t segment; + BrigMemoryOrder8_t memoryOrder; + BrigMemoryScope8_t memoryScope; + BrigAtomicOperation8_t atomicOperation; + uint8_t equivClass; + uint8_t reserved[3]; +}; + +struct BrigInstBasic { + BrigInstBase base; +}; + +struct BrigInstBr { + BrigInstBase base; + BrigWidth8_t width; + uint8_t reserved[3]; +}; + +struct BrigInstCmp { + BrigInstBase base; + BrigType16_t sourceType; + BrigAluModifier8_t modifier; + BrigCompareOperation8_t compare; + BrigPack8_t pack; + uint8_t reserved[3]; +}; + +struct BrigInstCvt { + BrigInstBase base; + BrigType16_t sourceType; + BrigAluModifier8_t modifier; + BrigRound8_t round; +}; + +struct BrigInstImage { + BrigInstBase base; + BrigType16_t imageType; + BrigType16_t coordType; + BrigImageGeometry8_t geometry; + uint8_t equivClass; + uint16_t reserved; +}; + +struct BrigInstLane { + BrigInstBase base; + BrigType16_t sourceType; + BrigWidth8_t width; + uint8_t reserved; +}; + +struct BrigInstMem { + BrigInstBase base; + BrigSegment8_t segment; + BrigAlignment8_t align; + uint8_t equivClass; + BrigWidth8_t width; + BrigMemoryModifier8_t modifier; + uint8_t reserved[3]; +}; + +struct BrigInstMemFence { + BrigInstBase base; + BrigMemoryOrder8_t memoryOrder; + BrigMemoryScope8_t globalSegmentMemoryScope; + BrigMemoryScope8_t groupSegmentMemoryScope; + BrigMemoryScope8_t imageSegmentMemoryScope; +}; + +struct BrigInstMod { + BrigInstBase base; + BrigAluModifier8_t modifier; + BrigRound8_t round; + BrigPack8_t pack; + uint8_t reserved; +}; + +struct BrigInstQueryImage { + BrigInstBase base; + BrigType16_t imageType; + BrigImageGeometry8_t geometry; + BrigImageQuery8_t query; +}; + +struct BrigInstQuerySampler { + BrigInstBase base; + BrigSamplerQuery8_t query; + uint8_t reserved[3]; +}; + +struct BrigInstQueue { + BrigInstBase base; + BrigSegment8_t segment; + BrigMemoryOrder8_t memoryOrder; + uint16_t reserved; +}; + +struct BrigInstSeg { + BrigInstBase base; + BrigSegment8_t segment; + uint8_t reserved[3]; +}; + +struct BrigInstSegCvt { + BrigInstBase base; + BrigType16_t sourceType; + BrigSegment8_t segment; + BrigSegCvtModifier8_t modifier; +}; + +struct BrigInstSignal { + BrigInstBase base; + BrigType16_t signalType; + BrigMemoryOrder8_t memoryOrder; + BrigAtomicOperation8_t signalOperation; +}; + +struct BrigInstSourceType { + BrigInstBase base; + BrigType16_t sourceType; + uint16_t reserved; +}; + +struct BrigOperandAddress { + BrigBase base; + BrigCodeOffset32_t symbol; + BrigOperandOffset32_t reg; + BrigUInt64 offset; +}; + +struct BrigOperandAlign { + BrigBase base; + BrigAlignment8_t align; + uint8_t reserved[3]; +}; + +struct BrigOperandCodeList { + BrigBase base; + BrigDataOffsetCodeList32_t elements; +}; + +struct BrigOperandCodeRef { + BrigBase base; + BrigCodeOffset32_t ref; +}; + +struct BrigOperandConstantBytes { + BrigBase base; + BrigType16_t type; + uint16_t reserved; + BrigDataOffsetString32_t bytes; +}; + +struct BrigOperandConstantOperandList { + BrigBase base; + BrigType16_t type; + uint16_t reserved; + BrigDataOffsetOperandList32_t elements; +}; + +struct BrigOperandConstantImage { + BrigBase base; + BrigType16_t type; + BrigImageGeometry8_t geometry; + BrigImageChannelOrder8_t channelOrder; + BrigImageChannelType8_t channelType; + uint8_t reserved[3]; + BrigUInt64 width; + BrigUInt64 height; + BrigUInt64 depth; + BrigUInt64 array; +}; + +struct BrigOperandOperandList { + BrigBase base; + BrigDataOffsetOperandList32_t elements; +}; + +struct BrigOperandRegister { + BrigBase base; + BrigRegisterKind16_t regKind; + uint16_t regNum; +}; + +struct BrigOperandConstantSampler { + BrigBase base; + BrigType16_t type; + BrigSamplerCoordNormalization8_t coord; + BrigSamplerFilter8_t filter; + BrigSamplerAddressing8_t addressing; + uint8_t reserved[3]; +}; + +struct BrigOperandString { + BrigBase base; + BrigDataOffsetString32_t string; +}; + +struct BrigOperandWavesize { + BrigBase base; +}; + +typedef uint32_t BrigExceptions32_t; +enum BrigExceptionsMask { + BRIG_EXCEPTIONS_INVALID_OPERATION = 1 << 0, + BRIG_EXCEPTIONS_DIVIDE_BY_ZERO = 1 << 1, + BRIG_EXCEPTIONS_OVERFLOW = 1 << 2, + BRIG_EXCEPTIONS_UNDERFLOW = 1 << 3, + BRIG_EXCEPTIONS_INEXACT = 1 << 4, + + BRIG_EXCEPTIONS_FIRST_USER_DEFINED = 1 << 16 +}; + +struct BrigSectionHeader { + uint64_t byteCount; + uint32_t headerByteCount; + uint32_t nameLength; + uint8_t name[1]; +}; + +struct BrigModuleHeader { + char identification[8]; + BrigVersion32_t brigMajor; + BrigVersion32_t brigMinor; + uint64_t byteCount; + uint8_t hash[64]; + uint32_t reserved; + uint32_t sectionCount; + uint64_t sectionIndex; +}; + +typedef BrigModuleHeader *BrigModule_t; + +#ifdef __cplusplus +} +#endif /*__cplusplus*/ + +#endif // defined(INCLUDED_BRIG_H) diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_common.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_common.h new file mode 100644 index 000000000..f3fb7b8da --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_common.h @@ -0,0 +1,89 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +// The following set of header files provides definitions for AMD GPU +// Architecture: +// - amd_hsa_common.h +// - amd_hsa_elf.h +// - amd_hsa_kernel_code.h +// - amd_hsa_queue.h +// - amd_hsa_signal.h +// +// Refer to "HSA Application Binary Interface: AMD GPU Architecture" for more +// information. + +#ifndef AMD_HSA_COMMON_H +#define AMD_HSA_COMMON_H + +#include +#include + +// Descriptive version of the HSA Application Binary Interface. +#define AMD_HSA_ABI_VERSION "AMD GPU Architecture v0.35 (June 25, 2015)" + +// Alignment attribute that specifies a minimum alignment (in bytes) for +// variables of the specified type. +#if defined(__GNUC__) +#define __ALIGNED__(x) __attribute__((aligned(x))) +#elif defined(_MSC_VER) +#define __ALIGNED__(x) __declspec(align(x)) +#elif defined(RC_INVOKED) +#define __ALIGNED__(x) +#else +#error +#endif + +// Creates enumeration entries for packed types. Enumeration entries include +// bit shift amount, bit width, and bit mask. +#define AMD_HSA_BITS_CREATE_ENUM_ENTRIES(name, shift, width) \ + name##_SHIFT = (shift), name##_WIDTH = (width), \ + name = (((1 << (width)) - 1) << (shift)) + +// Gets bits for specified mask from specified src packed instance. +#define AMD_HSA_BITS_GET(src, mask) ((src & mask) >> mask##_SHIFT) + +// Sets val bits for specified mask in specified dst packed instance. +#define AMD_HSA_BITS_SET(dst, mask, val) \ + dst &= (~(1 << mask##_SHIFT) & ~mask); \ + dst |= (((val) << mask##_SHIFT) & mask) + +#endif // AMD_HSA_COMMON_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_elf.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_elf.h new file mode 100644 index 000000000..22db60ceb --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_elf.h @@ -0,0 +1,468 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +// Undefine the macro in case it is defined in the system elf.h. +#undef EM_AMDGPU + +#ifndef AMD_HSA_ELF_H +#define AMD_HSA_ELF_H + +// AMD GPU Specific ELF Header Enumeration Values. +// +// Values are copied from LLVM BinaryFormat/ELF.h . This file also contains +// code object V1 defintions which are not part of the LLVM header. Code object +// V1 was only supported by the Finalizer which is now deprecated and removed. +// +// TODO: Deprecate and remove V1 support and replace this header with using the +// LLVM header. +namespace ELF { + +// Machine architectures +// See current registered ELF machine architectures at: +// http://www.uxsglobal.com/developers/gabi/latest/ch4.eheader.html +enum { + EM_AMDGPU = 224, // AMD GPU architecture +}; + +// OS ABI identification. +enum { + ELFOSABI_AMDGPU_HSA = 64, // AMD HSA runtime +}; + +// AMDGPU OS ABI Version identification. +enum { + // ELFABIVERSION_AMDGPU_HSA_V1 does not exist because OS ABI identification + // was never defined for V1. + ELFABIVERSION_AMDGPU_HSA_V2 = 0, + ELFABIVERSION_AMDGPU_HSA_V3 = 1, + ELFABIVERSION_AMDGPU_HSA_V4 = 2, + ELFABIVERSION_AMDGPU_HSA_V5 = 3, + ELFABIVERSION_AMDGPU_HSA_V6 = 4, +}; + +// AMDGPU specific e_flags. +enum : unsigned { + // Processor selection mask for EF_AMDGPU_MACH_* values. + EF_AMDGPU_MACH = 0x0ff, + + // Not specified processor. + EF_AMDGPU_MACH_NONE = 0x000, + + // AMDGCN-based processors. + // clang-format off + EF_AMDGPU_MACH_AMDGCN_GFX600 = 0x020, + EF_AMDGPU_MACH_AMDGCN_GFX601 = 0x021, + EF_AMDGPU_MACH_AMDGCN_GFX700 = 0x022, + EF_AMDGPU_MACH_AMDGCN_GFX701 = 0x023, + EF_AMDGPU_MACH_AMDGCN_GFX702 = 0x024, + EF_AMDGPU_MACH_AMDGCN_GFX703 = 0x025, + EF_AMDGPU_MACH_AMDGCN_GFX704 = 0x026, + EF_AMDGPU_MACH_AMDGCN_RESERVED_0X27 = 0x027, + EF_AMDGPU_MACH_AMDGCN_GFX801 = 0x028, + EF_AMDGPU_MACH_AMDGCN_GFX802 = 0x029, + EF_AMDGPU_MACH_AMDGCN_GFX803 = 0x02a, + EF_AMDGPU_MACH_AMDGCN_GFX810 = 0x02b, + EF_AMDGPU_MACH_AMDGCN_GFX900 = 0x02c, + EF_AMDGPU_MACH_AMDGCN_GFX902 = 0x02d, + EF_AMDGPU_MACH_AMDGCN_GFX904 = 0x02e, + EF_AMDGPU_MACH_AMDGCN_GFX906 = 0x02f, + EF_AMDGPU_MACH_AMDGCN_GFX908 = 0x030, + EF_AMDGPU_MACH_AMDGCN_GFX909 = 0x031, + EF_AMDGPU_MACH_AMDGCN_GFX90C = 0x032, + EF_AMDGPU_MACH_AMDGCN_GFX1010 = 0x033, + EF_AMDGPU_MACH_AMDGCN_GFX1011 = 0x034, + EF_AMDGPU_MACH_AMDGCN_GFX1012 = 0x035, + EF_AMDGPU_MACH_AMDGCN_GFX1030 = 0x036, + EF_AMDGPU_MACH_AMDGCN_GFX1031 = 0x037, + EF_AMDGPU_MACH_AMDGCN_GFX1032 = 0x038, + EF_AMDGPU_MACH_AMDGCN_GFX1033 = 0x039, + EF_AMDGPU_MACH_AMDGCN_GFX602 = 0x03a, + EF_AMDGPU_MACH_AMDGCN_GFX705 = 0x03b, + EF_AMDGPU_MACH_AMDGCN_GFX805 = 0x03c, + EF_AMDGPU_MACH_AMDGCN_GFX1035 = 0x03d, + EF_AMDGPU_MACH_AMDGCN_GFX1034 = 0x03e, + EF_AMDGPU_MACH_AMDGCN_GFX90A = 0x03f, + EF_AMDGPU_MACH_AMDGCN_GFX940 = 0x040, + EF_AMDGPU_MACH_AMDGCN_GFX1100 = 0x041, + EF_AMDGPU_MACH_AMDGCN_GFX1013 = 0x042, + EF_AMDGPU_MACH_AMDGCN_GFX1150 = 0x043, + EF_AMDGPU_MACH_AMDGCN_GFX1103 = 0x044, + EF_AMDGPU_MACH_AMDGCN_GFX1036 = 0x045, + EF_AMDGPU_MACH_AMDGCN_GFX1101 = 0x046, + EF_AMDGPU_MACH_AMDGCN_GFX1102 = 0x047, + EF_AMDGPU_MACH_AMDGCN_GFX1200 = 0x048, + EF_AMDGPU_MACH_AMDGCN_RESERVED_0X49 = 0x049, + EF_AMDGPU_MACH_AMDGCN_GFX1151 = 0x04a, + EF_AMDGPU_MACH_AMDGCN_GFX941 = 0x04b, + EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c, + EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4D = 0x04d, + EF_AMDGPU_MACH_AMDGCN_GFX1201 = 0x04e, + EF_AMDGPU_MACH_AMDGCN_GFX950 = 0x04f, + EF_AMDGPU_MACH_AMDGCN_RESERVED_0X50 = 0x050, + EF_AMDGPU_MACH_AMDGCN_GFX9_GENERIC = 0x051, + EF_AMDGPU_MACH_AMDGCN_GFX10_1_GENERIC = 0x052, + EF_AMDGPU_MACH_AMDGCN_GFX10_3_GENERIC = 0x053, + EF_AMDGPU_MACH_AMDGCN_GFX11_GENERIC = 0x054, + EF_AMDGPU_MACH_AMDGCN_RESERVED_0X55 = 0x055, + // clang-format on + + // First/last AMDGCN-based processors. + EF_AMDGPU_MACH_AMDGCN_FIRST = EF_AMDGPU_MACH_AMDGCN_GFX600, + EF_AMDGPU_MACH_AMDGCN_LAST = EF_AMDGPU_MACH_AMDGCN_GFX11_GENERIC, + + // Indicates if the "xnack" target feature is enabled for all code contained + // in the object. + // + // Only valid for ELFOSABI_AMDGPU_HSA and ELFABIVERSION_AMDGPU_HSA_V2. + EF_AMDGPU_FEATURE_XNACK_V2 = 0x01, + // Indicates if the trap handler is enabled for all code contained + // in the object. + // + // Only valid for ELFOSABI_AMDGPU_HSA and ELFABIVERSION_AMDGPU_HSA_V2. + EF_AMDGPU_FEATURE_TRAP_HANDLER_V2 = 0x02, + + // Indicates if the "xnack" target feature is enabled for all code contained + // in the object. + // + // Only valid for ELFOSABI_AMDGPU_HSA and ELFABIVERSION_AMDGPU_HSA_V3. + EF_AMDGPU_FEATURE_XNACK_V3 = 0x100, + // Indicates if the "sramecc" target feature is enabled for all code + // contained in the object. + // + // Only valid for ELFOSABI_AMDGPU_HSA and ELFABIVERSION_AMDGPU_HSA_V3. + EF_AMDGPU_FEATURE_SRAMECC_V3 = 0x200, + + // XNACK selection mask for EF_AMDGPU_FEATURE_XNACK_* values. + // + // Only valid for ELFOSABI_AMDGPU_HSA and ELFABIVERSION_AMDGPU_HSA_V4. + EF_AMDGPU_FEATURE_XNACK_V4 = 0x300, + // XNACK is not supported. + EF_AMDGPU_FEATURE_XNACK_UNSUPPORTED_V4 = 0x000, + // XNACK is any/default/unspecified. + EF_AMDGPU_FEATURE_XNACK_ANY_V4 = 0x100, + // XNACK is off. + EF_AMDGPU_FEATURE_XNACK_OFF_V4 = 0x200, + // XNACK is on. + EF_AMDGPU_FEATURE_XNACK_ON_V4 = 0x300, + + // SRAMECC selection mask for EF_AMDGPU_FEATURE_SRAMECC_* values. + // + // Only valid for ELFOSABI_AMDGPU_HSA and ELFABIVERSION_AMDGPU_HSA_V4. + EF_AMDGPU_FEATURE_SRAMECC_V4 = 0xc00, + // SRAMECC is not supported. + EF_AMDGPU_FEATURE_SRAMECC_UNSUPPORTED_V4 = 0x000, + // SRAMECC is any/default/unspecified. + EF_AMDGPU_FEATURE_SRAMECC_ANY_V4 = 0x400, + // SRAMECC is off. + EF_AMDGPU_FEATURE_SRAMECC_OFF_V4 = 0x800, + // SRAMECC is on. + EF_AMDGPU_FEATURE_SRAMECC_ON_V4 = 0xc00, + + // Generic target versioning. This is contained in the list byte of EFLAGS. + EF_AMDGPU_GENERIC_VERSION = 0xff000000, + EF_AMDGPU_GENERIC_VERSION_OFFSET = 24, + EF_AMDGPU_GENERIC_VERSION_MIN = 1, + EF_AMDGPU_GENERIC_VERSION_MAX = 0xff, +}; + +// ELF Relocation types for AMDGPU. +enum : unsigned { + R_AMDGPU_ABS32_LO = 1, + R_AMDGPU_ABS32_HI = 2, + R_AMDGPU_ABS64 = 3, + R_AMDGPU_ABS32 = 6, + R_AMDGPU_RELATIVE64 = 13, +}; + +} // end namespace ELF + +// ELF Section Header Flag Enumeration Values. +#define SHF_AMDGPU_HSA_GLOBAL (0x00100000 & SHF_MASKOS) +#define SHF_AMDGPU_HSA_READONLY (0x00200000 & SHF_MASKOS) +#define SHF_AMDGPU_HSA_CODE (0x00400000 & SHF_MASKOS) +#define SHF_AMDGPU_HSA_AGENT (0x00800000 & SHF_MASKOS) + +// +typedef enum { + AMDGPU_HSA_SEGMENT_GLOBAL_PROGRAM = 0, + AMDGPU_HSA_SEGMENT_GLOBAL_AGENT = 1, + AMDGPU_HSA_SEGMENT_READONLY_AGENT = 2, + AMDGPU_HSA_SEGMENT_CODE_AGENT = 3, + AMDGPU_HSA_SEGMENT_LAST, +} amdgpu_hsa_elf_segment_t; + +// ELF Program Header Type Enumeration Values. +#define PT_AMDGPU_HSA_LOAD_GLOBAL_PROGRAM \ + (PT_LOOS + AMDGPU_HSA_SEGMENT_GLOBAL_PROGRAM) +#define PT_AMDGPU_HSA_LOAD_GLOBAL_AGENT \ + (PT_LOOS + AMDGPU_HSA_SEGMENT_GLOBAL_AGENT) +#define PT_AMDGPU_HSA_LOAD_READONLY_AGENT \ + (PT_LOOS + AMDGPU_HSA_SEGMENT_READONLY_AGENT) +#define PT_AMDGPU_HSA_LOAD_CODE_AGENT (PT_LOOS + AMDGPU_HSA_SEGMENT_CODE_AGENT) + +// ELF Symbol Type Enumeration Values. +#define STT_AMDGPU_HSA_KERNEL (STT_LOOS + 0) +#define STT_AMDGPU_HSA_INDIRECT_FUNCTION (STT_LOOS + 1) +#define STT_AMDGPU_HSA_METADATA (STT_LOOS + 2) + +// ELF Symbol Binding Enumeration Values. +#define STB_AMDGPU_HSA_EXTERNAL (STB_LOOS + 0) + +// ELF Symbol Other Information Creation/Retrieval. +#define ELF64_ST_AMDGPU_ALLOCATION(o) (((o) >> 2) & 0x3) +#define ELF64_ST_AMDGPU_FLAGS(o) ((o) >> 4) +#define ELF64_ST_AMDGPU_OTHER(f, a, v) \ + (((f) << 4) + (((a) & 0x3) << 2) + ((v) & 0x3)) + +typedef enum { + AMDGPU_HSA_SYMBOL_ALLOCATION_DEFAULT = 0, + AMDGPU_HSA_SYMBOL_ALLOCATION_GLOBAL_PROGRAM = 1, + AMDGPU_HSA_SYMBOL_ALLOCATION_GLOBAL_AGENT = 2, + AMDGPU_HSA_SYMBOL_ALLOCATION_READONLY_AGENT = 3, + AMDGPU_HSA_SYMBOL_ALLOCATION_LAST, +} amdgpu_hsa_symbol_allocation_t; + +// ELF Symbol Allocation Enumeration Values. +#define STA_AMDGPU_HSA_DEFAULT AMDGPU_HSA_SYMBOL_ALLOCATION_DEFAULT +#define STA_AMDGPU_HSA_GLOBAL_PROGRAM \ + AMDGPU_HSA_SYMBOL_ALLOCATION_GLOBAL_PROGRAM +#define STA_AMDGPU_HSA_GLOBAL_AGENT AMDGPU_HSA_SYMBOL_ALLOCATION_GLOBAL_AGENT +#define STA_AMDGPU_HSA_READONLY_AGENT \ + AMDGPU_HSA_SYMBOL_ALLOCATION_READONLY_AGENT + +typedef enum { + AMDGPU_HSA_SYMBOL_FLAG_DEFAULT = 0, + AMDGPU_HSA_SYMBOL_FLAG_CONST = 1, + AMDGPU_HSA_SYMBOL_FLAG_LAST, +} amdgpu_hsa_symbol_flag_t; + +// ELF Symbol Flag Enumeration Values. +#define STF_AMDGPU_HSA_CONST AMDGPU_HSA_SYMBOL_FLAG_CONST + +// Legacy/V1 AMD GPU Relocation Type Enumeration Values. +#define R_AMDGPU_V1_NONE 0 +#define R_AMDGPU_V1_32_LOW 1 +#define R_AMDGPU_V1_32_HIGH 2 +#define R_AMDGPU_V1_64 3 +#define R_AMDGPU_V1_INIT_SAMPLER 4 +#define R_AMDGPU_V1_INIT_IMAGE 5 +#define R_AMDGPU_V1_RELATIVE64 13 + +// AMD GPU Note Type Enumeration Values. +#define NT_AMD_HSA_CODE_OBJECT_VERSION 1 +#define NT_AMD_HSA_HSAIL 2 +#define NT_AMD_HSA_ISA_VERSION 3 +#define NT_AMD_HSA_PRODUCER 4 +#define NT_AMD_HSA_PRODUCER_OPTIONS 5 +#define NT_AMD_HSA_EXTENSION 6 +#define NT_AMD_HSA_ISA_NAME 11 +/* AMDGPU snapshots of runtime, agent and queues state for use in core dump */ +#define NT_AMDGPU_CORE_STATE 33 +#define NT_AMD_HSA_HLDEBUG_DEBUG 101 +#define NT_AMD_HSA_HLDEBUG_TARGET 102 + +// AMD GPU Metadata Kind Enumeration Values. +typedef uint16_t amdgpu_hsa_metadata_kind16_t; +typedef enum { + AMDGPU_HSA_METADATA_KIND_NONE = 0, + AMDGPU_HSA_METADATA_KIND_INIT_SAMP = 1, + AMDGPU_HSA_METADATA_KIND_INIT_ROIMG = 2, + AMDGPU_HSA_METADATA_KIND_INIT_WOIMG = 3, + AMDGPU_HSA_METADATA_KIND_INIT_RWIMG = 4 +} amdgpu_hsa_metadata_kind_t; + +// AMD GPU Sampler Coordinate Normalization Enumeration Values. +typedef uint8_t amdgpu_hsa_sampler_coord8_t; +typedef enum { + AMDGPU_HSA_SAMPLER_COORD_UNNORMALIZED = 0, + AMDGPU_HSA_SAMPLER_COORD_NORMALIZED = 1 +} amdgpu_hsa_sampler_coord_t; + +// AMD GPU Sampler Filter Enumeration Values. +typedef uint8_t amdgpu_hsa_sampler_filter8_t; +typedef enum { + AMDGPU_HSA_SAMPLER_FILTER_NEAREST = 0, + AMDGPU_HSA_SAMPLER_FILTER_LINEAR = 1 +} amdgpu_hsa_sampler_filter_t; + +// AMD GPU Sampler Addressing Enumeration Values. +typedef uint8_t amdgpu_hsa_sampler_addressing8_t; +typedef enum { + AMDGPU_HSA_SAMPLER_ADDRESSING_UNDEFINED = 0, + AMDGPU_HSA_SAMPLER_ADDRESSING_CLAMP_TO_EDGE = 1, + AMDGPU_HSA_SAMPLER_ADDRESSING_CLAMP_TO_BORDER = 2, + AMDGPU_HSA_SAMPLER_ADDRESSING_REPEAT = 3, + AMDGPU_HSA_SAMPLER_ADDRESSING_MIRRORED_REPEAT = 4 +} amdgpu_hsa_sampler_addressing_t; + +// AMD GPU Sampler Descriptor. +typedef struct amdgpu_hsa_sampler_descriptor_s { + uint16_t size; + amdgpu_hsa_metadata_kind16_t kind; + amdgpu_hsa_sampler_coord8_t coord; + amdgpu_hsa_sampler_filter8_t filter; + amdgpu_hsa_sampler_addressing8_t addressing; + uint8_t reserved1; +} amdgpu_hsa_sampler_descriptor_t; + +// AMD GPU Image Geometry Enumeration Values. +typedef uint8_t amdgpu_hsa_image_geometry8_t; +typedef enum { + AMDGPU_HSA_IMAGE_GEOMETRY_1D = 0, + AMDGPU_HSA_IMAGE_GEOMETRY_2D = 1, + AMDGPU_HSA_IMAGE_GEOMETRY_3D = 2, + AMDGPU_HSA_IMAGE_GEOMETRY_1DA = 3, + AMDGPU_HSA_IMAGE_GEOMETRY_2DA = 4, + AMDGPU_HSA_IMAGE_GEOMETRY_1DB = 5, + AMDGPU_HSA_IMAGE_GEOMETRY_2DDEPTH = 6, + AMDGPU_HSA_IMAGE_GEOMETRY_2DADEPTH = 7 +} amdgpu_hsa_image_geometry_t; + +// AMD GPU Image Channel Order Enumeration Values. +typedef uint8_t amdgpu_hsa_image_channel_order8_t; +typedef enum { + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_A = 0, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_R = 1, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_RX = 2, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_RG = 3, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_RGX = 4, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_RA = 5, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_RGB = 6, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_RGBX = 7, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_RGBA = 8, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_BGRA = 9, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_ARGB = 10, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_ABGR = 11, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_SRGB = 12, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_SRGBX = 13, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_SRGBA = 14, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_SBGRA = 15, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_INTENSITY = 16, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_LUMINANCE = 17, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_DEPTH = 18, + AMDGPU_HSA_IMAGE_CHANNEL_ORDER_DEPTH_STENCIL = 19 +} amdgpu_hsa_image_channel_order_t; + +// AMD GPU Image Channel Type Enumeration Values. +typedef uint8_t amdgpu_hsa_image_channel_type8_t; +typedef enum { + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_SNORM_INT8 = 0, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_SNORM_INT16 = 1, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_UNORM_INT8 = 2, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_UNORM_INT16 = 3, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_UNORM_INT24 = 4, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_SHORT_555 = 5, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_SHORT_565 = 6, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_INT_101010 = 7, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_SIGNED_INT8 = 8, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_SIGNED_INT16 = 9, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_SIGNED_INT32 = 10, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8 = 11, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16 = 12, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32 = 13, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_HALF_FLOAT = 14, + AMDGPU_HSA_IMAGE_CHANNEL_TYPE_FLOAT = 15 +} amdgpu_hsa_image_channel_type_t; + +// AMD GPU Image Descriptor. +typedef struct amdgpu_hsa_image_descriptor_s { + uint16_t size; + amdgpu_hsa_metadata_kind16_t kind; + amdgpu_hsa_image_geometry8_t geometry; + amdgpu_hsa_image_channel_order8_t channel_order; + amdgpu_hsa_image_channel_type8_t channel_type; + uint8_t reserved1; + uint64_t width; + uint64_t height; + uint64_t depth; + uint64_t array; +} amdgpu_hsa_image_descriptor_t; + +typedef struct amdgpu_hsa_note_code_object_version_s { + uint32_t major_version; + uint32_t minor_version; +} amdgpu_hsa_note_code_object_version_t; + +typedef struct amdgpu_hsa_note_hsail_s { + uint32_t hsail_major_version; + uint32_t hsail_minor_version; + uint8_t profile; + uint8_t machine_model; + uint8_t default_float_round; +} amdgpu_hsa_note_hsail_t; + +typedef struct amdgpu_hsa_note_isa_s { + uint16_t vendor_name_size; + uint16_t architecture_name_size; + uint32_t major; + uint32_t minor; + uint32_t stepping; + char vendor_and_architecture_name[1]; +} amdgpu_hsa_note_isa_t; + +typedef struct amdgpu_hsa_note_producer_s { + uint16_t producer_name_size; + uint16_t reserved; + uint32_t producer_major_version; + uint32_t producer_minor_version; + char producer_name[1]; +} amdgpu_hsa_note_producer_t; + +typedef struct amdgpu_hsa_note_producer_options_s { + uint16_t producer_options_size; + char producer_options[1]; +} amdgpu_hsa_note_producer_options_t; + +typedef enum { + AMDGPU_HSA_RODATA_GLOBAL_PROGRAM = 0, + AMDGPU_HSA_RODATA_GLOBAL_AGENT, + AMDGPU_HSA_RODATA_READONLY_AGENT, + AMDGPU_HSA_DATA_GLOBAL_PROGRAM, + AMDGPU_HSA_DATA_GLOBAL_AGENT, + AMDGPU_HSA_DATA_READONLY_AGENT, + AMDGPU_HSA_BSS_GLOBAL_PROGRAM, + AMDGPU_HSA_BSS_GLOBAL_AGENT, + AMDGPU_HSA_BSS_READONLY_AGENT, + AMDGPU_HSA_SECTION_LAST, +} amdgpu_hsa_elf_section_t; + +#endif // AMD_HSA_ELF_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_kernel_code.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_kernel_code.h new file mode 100644 index 000000000..15b63def4 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_kernel_code.h @@ -0,0 +1,314 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef AMD_HSA_KERNEL_CODE_H +#define AMD_HSA_KERNEL_CODE_H + +#include "amd_hsa_common.h" +#include "hsa.h" + +// AMD Kernel Code Version Enumeration Values. +typedef uint32_t amd_kernel_code_version32_t; +enum amd_kernel_code_version_t { + AMD_KERNEL_CODE_VERSION_MAJOR = 1, + AMD_KERNEL_CODE_VERSION_MINOR = 1 +}; + +// AMD Machine Kind Enumeration Values. +typedef uint16_t amd_machine_kind16_t; +enum amd_machine_kind_t { + AMD_MACHINE_KIND_UNDEFINED = 0, + AMD_MACHINE_KIND_AMDGPU = 1 +}; + +// AMD Machine Version. +typedef uint16_t amd_machine_version16_t; + +// AMD Float Round Mode Enumeration Values. +enum amd_float_round_mode_t { + AMD_FLOAT_ROUND_MODE_NEAREST_EVEN = 0, + AMD_FLOAT_ROUND_MODE_PLUS_INFINITY = 1, + AMD_FLOAT_ROUND_MODE_MINUS_INFINITY = 2, + AMD_FLOAT_ROUND_MODE_ZERO = 3 +}; + +// AMD Float Denorm Mode Enumeration Values. +enum amd_float_denorm_mode_t { + AMD_FLOAT_DENORM_MODE_FLUSH_SOURCE_OUTPUT = 0, + AMD_FLOAT_DENORM_MODE_FLUSH_OUTPUT = 1, + AMD_FLOAT_DENORM_MODE_FLUSH_SOURCE = 2, + AMD_FLOAT_DENORM_MODE_NO_FLUSH = 3 +}; + +// AMD Compute Program Resource Register One. +typedef uint32_t amd_compute_pgm_rsrc_one32_t; +enum amd_compute_pgm_rsrc_one_t { + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_ONE_GRANULATED_WORKITEM_VGPR_COUNT, 0, 6), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_ONE_GRANULATED_WAVEFRONT_SGPR_COUNT, 6, 4), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_ONE_PRIORITY, 10, 2), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_ONE_FLOAT_ROUND_MODE_32, + 12, 2), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_ONE_FLOAT_ROUND_MODE_16_64, 14, 2), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_ONE_FLOAT_DENORM_MODE_32, 16, 2), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_ONE_FLOAT_DENORM_MODE_16_64, 18, 2), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_ONE_PRIV, 20, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_ONE_ENABLE_DX10_CLAMP, + 21, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_ONE_DEBUG_MODE, 22, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_ONE_ENABLE_IEEE_MODE, + 23, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_ONE_BULKY, 24, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_ONE_CDBG_USER, 25, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_ONE_RESERVED1, 26, 6) +}; + +// AMD System VGPR Workitem ID Enumeration Values. +enum amd_system_vgpr_workitem_id_t { + AMD_SYSTEM_VGPR_WORKITEM_ID_X = 0, + AMD_SYSTEM_VGPR_WORKITEM_ID_X_Y = 1, + AMD_SYSTEM_VGPR_WORKITEM_ID_X_Y_Z = 2, + AMD_SYSTEM_VGPR_WORKITEM_ID_UNDEFINED = 3 +}; + +// AMD Compute Program Resource Register Two. +typedef uint32_t amd_compute_pgm_rsrc_two32_t; +enum amd_compute_pgm_rsrc_two_t { + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_PRIVATE_SEGMENT_WAVE_BYTE_OFFSET, 0, + 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_TWO_USER_SGPR_COUNT, 1, + 5), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_TRAP_HANDLER, + 6, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_X, 7, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Y, 8, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z, 9, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_INFO, 10, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_VGPR_WORKITEM_ID, 11, 2), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_EXCEPTION_ADDRESS_WATCH, 13, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_EXCEPTION_MEMORY_VIOLATION, 14, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE, + 15, 9), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_EXCEPTION_IEEE_754_FP_INVALID_OPERATION, + 24, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_EXCEPTION_FP_DENORMAL_SOURCE, 25, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_EXCEPTION_IEEE_754_FP_DIVISION_BY_ZERO, + 26, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_EXCEPTION_IEEE_754_FP_OVERFLOW, 27, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_EXCEPTION_IEEE_754_FP_UNDERFLOW, 28, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_EXCEPTION_IEEE_754_FP_INEXACT, 29, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_EXCEPTION_INT_DIVISION_BY_ZERO, 30, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_COMPUTE_PGM_RSRC_TWO_RESERVED1, 31, 1) +}; + +// AMD Element Byte Size Enumeration Values. +enum amd_element_byte_size_t { + AMD_ELEMENT_BYTE_SIZE_2 = 0, + AMD_ELEMENT_BYTE_SIZE_4 = 1, + AMD_ELEMENT_BYTE_SIZE_8 = 2, + AMD_ELEMENT_BYTE_SIZE_16 = 3 +}; + +// AMD Kernel Code Properties. +typedef uint32_t amd_kernel_code_properties32_t; +enum amd_kernel_code_properties_t { + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_PRIVATE_SEGMENT_BUFFER, 0, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_DISPATCH_PTR, 1, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_QUEUE_PTR, 2, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_KERNARG_SEGMENT_PTR, 3, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_DISPATCH_ID, 4, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_FLAT_SCRATCH_INIT, 5, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_PRIVATE_SEGMENT_SIZE, 6, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_GRID_WORKGROUP_COUNT_X, 7, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_GRID_WORKGROUP_COUNT_Y, 8, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_GRID_WORKGROUP_COUNT_Z, 9, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_KERNEL_CODE_PROPERTIES_RESERVED1, 10, 6), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_ENABLE_ORDERED_APPEND_GDS, 16, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_PRIVATE_ELEMENT_SIZE, 17, 2), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_KERNEL_CODE_PROPERTIES_IS_PTR64, 19, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_KERNEL_CODE_PROPERTIES_IS_DYNAMIC_CALLSTACK, 20, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_KERNEL_CODE_PROPERTIES_IS_DEBUG_ENABLED, + 21, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_KERNEL_CODE_PROPERTIES_IS_XNACK_ENABLED, + 22, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_KERNEL_CODE_PROPERTIES_RESERVED2, 23, 9) +}; + +// AMD Power Of Two Enumeration Values. +typedef uint8_t amd_powertwo8_t; +enum amd_powertwo_t { + AMD_POWERTWO_1 = 0, + AMD_POWERTWO_2 = 1, + AMD_POWERTWO_4 = 2, + AMD_POWERTWO_8 = 3, + AMD_POWERTWO_16 = 4, + AMD_POWERTWO_32 = 5, + AMD_POWERTWO_64 = 6, + AMD_POWERTWO_128 = 7, + AMD_POWERTWO_256 = 8 +}; + +// AMD Enabled Control Directive Enumeration Values. +typedef uint64_t amd_enabled_control_directive64_t; +enum amd_enabled_control_directive_t { + AMD_ENABLED_CONTROL_DIRECTIVE_ENABLE_BREAK_EXCEPTIONS = 1, + AMD_ENABLED_CONTROL_DIRECTIVE_ENABLE_DETECT_EXCEPTIONS = 2, + AMD_ENABLED_CONTROL_DIRECTIVE_MAX_DYNAMIC_GROUP_SIZE = 4, + AMD_ENABLED_CONTROL_DIRECTIVE_MAX_FLAT_GRID_SIZE = 8, + AMD_ENABLED_CONTROL_DIRECTIVE_MAX_FLAT_WORKGROUP_SIZE = 16, + AMD_ENABLED_CONTROL_DIRECTIVE_REQUIRED_DIM = 32, + AMD_ENABLED_CONTROL_DIRECTIVE_REQUIRED_GRID_SIZE = 64, + AMD_ENABLED_CONTROL_DIRECTIVE_REQUIRED_WORKGROUP_SIZE = 128, + AMD_ENABLED_CONTROL_DIRECTIVE_REQUIRE_NO_PARTIAL_WORKGROUPS = 256 +}; + +// AMD Exception Kind Enumeration Values. +typedef uint16_t amd_exception_kind16_t; +enum amd_exception_kind_t { + AMD_EXCEPTION_KIND_INVALID_OPERATION = 1, + AMD_EXCEPTION_KIND_DIVISION_BY_ZERO = 2, + AMD_EXCEPTION_KIND_OVERFLOW = 4, + AMD_EXCEPTION_KIND_UNDERFLOW = 8, + AMD_EXCEPTION_KIND_INEXACT = 16 +}; + +// AMD Control Directives. +#define AMD_CONTROL_DIRECTIVES_ALIGN_BYTES 64 +#define AMD_CONTROL_DIRECTIVES_ALIGN \ + __ALIGNED__(AMD_CONTROL_DIRECTIVES_ALIGN_BYTES) +typedef AMD_CONTROL_DIRECTIVES_ALIGN struct amd_control_directives_s { + amd_enabled_control_directive64_t enabled_control_directives; + uint16_t enable_break_exceptions; + uint16_t enable_detect_exceptions; + uint32_t max_dynamic_group_size; + uint64_t max_flat_grid_size; + uint32_t max_flat_workgroup_size; + uint8_t required_dim; + uint8_t reserved1[3]; + uint64_t required_grid_size[3]; + uint32_t required_workgroup_size[3]; + uint8_t reserved2[60]; +} amd_control_directives_t; + +// AMD Kernel Code. +#define AMD_ISA_ALIGN_BYTES 256 +#define AMD_KERNEL_CODE_ALIGN_BYTES 64 +#define AMD_KERNEL_CODE_ALIGN __ALIGNED__(AMD_KERNEL_CODE_ALIGN_BYTES) +typedef AMD_KERNEL_CODE_ALIGN struct amd_kernel_code_s { + amd_kernel_code_version32_t amd_kernel_code_version_major; + amd_kernel_code_version32_t amd_kernel_code_version_minor; + amd_machine_kind16_t amd_machine_kind; + amd_machine_version16_t amd_machine_version_major; + amd_machine_version16_t amd_machine_version_minor; + amd_machine_version16_t amd_machine_version_stepping; + int64_t kernel_code_entry_byte_offset; + int64_t kernel_code_prefetch_byte_offset; + uint64_t kernel_code_prefetch_byte_size; + uint64_t max_scratch_backing_memory_byte_size; + amd_compute_pgm_rsrc_one32_t compute_pgm_rsrc1; + amd_compute_pgm_rsrc_two32_t compute_pgm_rsrc2; + amd_kernel_code_properties32_t kernel_code_properties; + uint32_t workitem_private_segment_byte_size; + uint32_t workgroup_group_segment_byte_size; + uint32_t gds_segment_byte_size; + uint64_t kernarg_segment_byte_size; + uint32_t workgroup_fbarrier_count; + uint16_t wavefront_sgpr_count; + uint16_t workitem_vgpr_count; + uint16_t reserved_vgpr_first; + uint16_t reserved_vgpr_count; + uint16_t reserved_sgpr_first; + uint16_t reserved_sgpr_count; + uint16_t debug_wavefront_private_segment_offset_sgpr; + uint16_t debug_private_segment_buffer_sgpr; + amd_powertwo8_t kernarg_segment_alignment; + amd_powertwo8_t group_segment_alignment; + amd_powertwo8_t private_segment_alignment; + amd_powertwo8_t wavefront_size; + int32_t call_convention; + uint8_t reserved1[12]; + uint64_t runtime_loader_kernel_symbol; + amd_control_directives_t control_directives; +} amd_kernel_code_t; + +// TODO: this struct should be completely gone once debugger designs/implements +// Debugger APIs. +typedef struct amd_runtime_loader_debug_info_s { + const void *elf_raw; + size_t elf_size; + const char *kernel_name; + const void *owning_segment; +} amd_runtime_loader_debug_info_t; + +#endif // AMD_HSA_KERNEL_CODE_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_queue.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_queue.h new file mode 100644 index 000000000..f7760af54 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_queue.h @@ -0,0 +1,111 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef AMD_HSA_QUEUE_H +#define AMD_HSA_QUEUE_H + +#include "amd_hsa_common.h" +#include "hsa.h" + +// AMD Queue Properties. +typedef uint32_t amd_queue_properties32_t; +enum amd_queue_properties_t { + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_QUEUE_PROPERTIES_ENABLE_TRAP_HANDLER, 0, + 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_QUEUE_PROPERTIES_IS_PTR64, 1, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES( + AMD_QUEUE_PROPERTIES_ENABLE_TRAP_HANDLER_DEBUG_SGPRS, 2, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_QUEUE_PROPERTIES_ENABLE_PROFILING, 3, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_QUEUE_PROPERTIES_USE_SCRATCH_ONCE, 4, 1), + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_QUEUE_PROPERTIES_RESERVED1, 5, 27) +}; + +// AMD Queue. +#define AMD_QUEUE_ALIGN_BYTES 64 +#define AMD_QUEUE_ALIGN __ALIGNED__(AMD_QUEUE_ALIGN_BYTES) + +// AMD Queue Capabilities. +typedef uint32_t amd_queue_capabilities32_t; +enum amd_queue_capabilities_t { + /* Whether this CP queue supports dual-scratch and async-reclaim */ + AMD_HSA_BITS_CREATE_ENUM_ENTRIES(AMD_QUEUE_CAPS_ASYNC_RECLAIM, 0, 1), +}; + +// Members tagged with "async-reclaim" are ignored by CP FW's that do not +// support AMD_QUEUE_CAPS_ASYNC_RECLAIM. CP FW's that support async-reclaim also +// support dual-scratch (alternate scratch). + +typedef struct AMD_QUEUE_ALIGN amd_queue_s { + hsa_queue_t hsa_queue; + uint32_t caps; + uint32_t reserved1[3]; + volatile uint64_t write_dispatch_id; + uint32_t group_segment_aperture_base_hi; + uint32_t private_segment_aperture_base_hi; + uint32_t max_cu_id; + uint32_t max_wave_id; + volatile uint64_t max_legacy_doorbell_dispatch_id_plus_1; + volatile uint32_t legacy_doorbell_lock; + uint32_t reserved2[9]; + volatile uint64_t read_dispatch_id; + uint32_t read_dispatch_id_field_base_byte_offset; + uint32_t compute_tmpring_size; + uint32_t scratch_resource_descriptor[4]; + uint64_t scratch_backing_memory_location; + uint64_t scratch_backing_memory_byte_size; + uint32_t scratch_wave64_lane_byte_size; + amd_queue_properties32_t queue_properties; + volatile uint64_t scratch_last_used_index; /* async-reclaim */ + hsa_signal_t queue_inactive_signal; + uint32_t reserved4[2]; + volatile uint64_t alt_scratch_last_used_index; /* async-reclaim */ + uint64_t alt_scratch_backing_memory_location; /* async-reclaim */ + uint64_t alt_scratch_backing_memory_byte_size; /* async-reclaim */ + uint32_t alt_scratch_dispatch_limit_x; /* async-reclaim */ + uint32_t alt_scratch_dispatch_limit_y; /* async-reclaim */ + uint32_t alt_scratch_dispatch_limit_z; /* async-reclaim */ + uint32_t alt_scratch_wave64_lane_byte_size; /* async-reclaim */ + uint32_t alt_compute_tmpring_size; /* async-reclaim */ + uint32_t reserved5; +} amd_queue_t; + +#endif // AMD_HSA_QUEUE_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_signal.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_signal.h new file mode 100644 index 000000000..53be36ff5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/amd_hsa_signal.h @@ -0,0 +1,80 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef AMD_HSA_SIGNAL_H +#define AMD_HSA_SIGNAL_H + +#include "amd_hsa_common.h" +#include "amd_hsa_queue.h" + +// AMD Signal Kind Enumeration Values. +typedef int64_t amd_signal_kind64_t; +enum amd_signal_kind_t { + AMD_SIGNAL_KIND_INVALID = 0, + AMD_SIGNAL_KIND_USER = 1, + AMD_SIGNAL_KIND_DOORBELL = -1, + AMD_SIGNAL_KIND_LEGACY_DOORBELL = -2 +}; + +// AMD Signal. +#define AMD_SIGNAL_ALIGN_BYTES 64 +#define AMD_SIGNAL_ALIGN __ALIGNED__(AMD_SIGNAL_ALIGN_BYTES) +typedef struct AMD_SIGNAL_ALIGN amd_signal_s { + amd_signal_kind64_t kind; + union { + volatile int64_t value; + volatile uint32_t *legacy_hardware_doorbell_ptr; + volatile uint64_t *hardware_doorbell_ptr; + }; + uint64_t event_mailbox_ptr; + uint32_t event_id; + uint32_t reserved1; + uint64_t start_ts; + uint64_t end_ts; + union { + amd_queue_t *queue_ptr; + uint64_t reserved2; + }; + uint32_t reserved3[2]; +} amd_signal_t; + +#endif // AMD_HSA_SIGNAL_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa.h new file mode 100644 index 000000000..7d3142203 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa.h @@ -0,0 +1,5522 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef HSA_RUNTIME_INC_HSA_H_ +#define HSA_RUNTIME_INC_HSA_H_ + +#include /* size_t */ +#include /* uintXX_t */ + +#ifndef __cplusplus +#include /* bool */ +#endif /* __cplusplus */ + +// Placeholder for calling convention and import/export macros +#ifndef HSA_CALL +#define HSA_CALL +#endif + +#ifndef HSA_EXPORT_DECORATOR +#ifdef __GNUC__ +#define HSA_EXPORT_DECORATOR __attribute__((visibility("default"))) +#else +#define HSA_EXPORT_DECORATOR +#endif +#endif +#define HSA_API_EXPORT HSA_EXPORT_DECORATOR HSA_CALL +#define HSA_API_IMPORT HSA_CALL + +#if !defined(HSA_API) && defined(HSA_EXPORT) +#define HSA_API HSA_API_EXPORT +#else +#define HSA_API HSA_API_IMPORT +#endif + +// Detect and set large model builds. +#undef HSA_LARGE_MODEL +#if defined(__LP64__) || defined(_M_X64) +#define HSA_LARGE_MODEL +#endif + +// Try to detect CPU endianness +#if !defined(LITTLEENDIAN_CPU) && !defined(BIGENDIAN_CPU) +#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) +#define LITTLEENDIAN_CPU +#elif defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) +#define BIGENDIAN_CPU +#elif defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || \ + defined(_M_X64) || defined(__loongarch64) || defined(__riscv) +#define LITTLEENDIAN_CPU +#endif +#endif + +#undef HSA_LITTLE_ENDIAN +#if defined(LITTLEENDIAN_CPU) +#define HSA_LITTLE_ENDIAN +#elif defined(BIGENDIAN_CPU) +#else +#error "BIGENDIAN_CPU or LITTLEENDIAN_CPU must be defined" +#endif + +#ifndef HSA_DEPRECATED +#define HSA_DEPRECATED +// #ifdef __GNUC__ +// #define HSA_DEPRECATED __attribute__((deprecated)) +// #else +// #define HSA_DEPRECATED __declspec(deprecated) +// #endif +#endif + +#define HSA_VERSION_1_0 1 + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/** \defgroup status Runtime Notifications + * @{ + */ + +/** + * @brief Status codes. + */ +typedef enum { + /** + * The function has been executed successfully. + */ + HSA_STATUS_SUCCESS = 0x0, + /** + * A traversal over a list of elements has been interrupted by the + * application before completing. + */ + HSA_STATUS_INFO_BREAK = 0x1, + /** + * A generic error has occurred. + */ + HSA_STATUS_ERROR = 0x1000, + /** + * One of the actual arguments does not meet a precondition stated in the + * documentation of the corresponding formal argument. + */ + HSA_STATUS_ERROR_INVALID_ARGUMENT = 0x1001, + /** + * The requested queue creation is not valid. + */ + HSA_STATUS_ERROR_INVALID_QUEUE_CREATION = 0x1002, + /** + * The requested allocation is not valid. + */ + HSA_STATUS_ERROR_INVALID_ALLOCATION = 0x1003, + /** + * The agent is invalid. + */ + HSA_STATUS_ERROR_INVALID_AGENT = 0x1004, + /** + * The memory region is invalid. + */ + HSA_STATUS_ERROR_INVALID_REGION = 0x1005, + /** + * The signal is invalid. + */ + HSA_STATUS_ERROR_INVALID_SIGNAL = 0x1006, + /** + * The queue is invalid. + */ + HSA_STATUS_ERROR_INVALID_QUEUE = 0x1007, + /** + * The HSA runtime failed to allocate the necessary resources. This error + * may also occur when the HSA runtime needs to spawn threads or create + * internal OS-specific events. + */ + HSA_STATUS_ERROR_OUT_OF_RESOURCES = 0x1008, + /** + * The AQL packet is malformed. + */ + HSA_STATUS_ERROR_INVALID_PACKET_FORMAT = 0x1009, + /** + * An error has been detected while releasing a resource. + */ + HSA_STATUS_ERROR_RESOURCE_FREE = 0x100A, + /** + * An API other than ::hsa_init has been invoked while the reference count + * of the HSA runtime is 0. + */ + HSA_STATUS_ERROR_NOT_INITIALIZED = 0x100B, + /** + * The maximum reference count for the object has been reached. + */ + HSA_STATUS_ERROR_REFCOUNT_OVERFLOW = 0x100C, + /** + * The arguments passed to a functions are not compatible. + */ + HSA_STATUS_ERROR_INCOMPATIBLE_ARGUMENTS = 0x100D, + /** + * The index is invalid. + */ + HSA_STATUS_ERROR_INVALID_INDEX = 0x100E, + /** + * The instruction set architecture is invalid. + */ + HSA_STATUS_ERROR_INVALID_ISA = 0x100F, + /** + * The instruction set architecture name is invalid. + */ + HSA_STATUS_ERROR_INVALID_ISA_NAME = 0x1017, + /** + * The code object is invalid. + */ + HSA_STATUS_ERROR_INVALID_CODE_OBJECT = 0x1010, + /** + * The executable is invalid. + */ + HSA_STATUS_ERROR_INVALID_EXECUTABLE = 0x1011, + /** + * The executable is frozen. + */ + HSA_STATUS_ERROR_FROZEN_EXECUTABLE = 0x1012, + /** + * There is no symbol with the given name. + */ + HSA_STATUS_ERROR_INVALID_SYMBOL_NAME = 0x1013, + /** + * The variable is already defined. + */ + HSA_STATUS_ERROR_VARIABLE_ALREADY_DEFINED = 0x1014, + /** + * The variable is undefined. + */ + HSA_STATUS_ERROR_VARIABLE_UNDEFINED = 0x1015, + /** + * An HSAIL operation resulted in a hardware exception. + */ + HSA_STATUS_ERROR_EXCEPTION = 0x1016, + /** + * The code object symbol is invalid. + */ + HSA_STATUS_ERROR_INVALID_CODE_SYMBOL = 0x1018, + /** + * The executable symbol is invalid. + */ + HSA_STATUS_ERROR_INVALID_EXECUTABLE_SYMBOL = 0x1019, + /** + * The file descriptor is invalid. + */ + HSA_STATUS_ERROR_INVALID_FILE = 0x1020, + /** + * The code object reader is invalid. + */ + HSA_STATUS_ERROR_INVALID_CODE_OBJECT_READER = 0x1021, + /** + * The cache is invalid. + */ + HSA_STATUS_ERROR_INVALID_CACHE = 0x1022, + /** + * The wavefront is invalid. + */ + HSA_STATUS_ERROR_INVALID_WAVEFRONT = 0x1023, + /** + * The signal group is invalid. + */ + HSA_STATUS_ERROR_INVALID_SIGNAL_GROUP = 0x1024, + /** + * The HSA runtime is not in the configuration state. + */ + HSA_STATUS_ERROR_INVALID_RUNTIME_STATE = 0x1025, + /** + * The queue received an error that may require process termination. + */ + HSA_STATUS_ERROR_FATAL = 0x1026 +} hsa_status_t; + +/** + * @brief Query additional information about a status code. + * + * @param[in] status Status code. + * + * @param[out] status_string A NUL-terminated string that describes the error + * status. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p status is an invalid + * status code, or @p status_string is NULL. + */ +hsa_status_t HSA_API hsa_status_string(hsa_status_t status, + const char **status_string); + +/** @} */ + +/** \defgroup common Common Definitions + * @{ + */ + +/** + * @brief Three-dimensional coordinate. + */ +typedef struct hsa_dim3_s { + /** + * X dimension. + */ + uint32_t x; + + /** + * Y dimension. + */ + uint32_t y; + + /** + * Z dimension. + */ + uint32_t z; +} hsa_dim3_t; + +/** + * @brief Access permissions. + */ +typedef enum { + /** + * Used to remove existing access + */ + HSA_ACCESS_PERMISSION_NONE = 0, + /** + * Read-only access. + */ + HSA_ACCESS_PERMISSION_RO = 1, + /** + * Write-only access. + */ + HSA_ACCESS_PERMISSION_WO = 2, + /** + * Read and write access. + */ + HSA_ACCESS_PERMISSION_RW = 3 +} hsa_access_permission_t; + +/** + * @brief POSIX file descriptor. + */ +typedef int hsa_file_t; + +/** @} **/ + +/** \defgroup initshutdown Initialization and Shut Down + * @{ + */ + +/** + * @brief Initialize the HSA runtime. + * + * @details Initializes the HSA runtime if it is not already initialized, and + * increases the reference counter associated with the HSA runtime for the + * current process. Invocation of any HSA function other than ::hsa_init results + * in undefined behavior if the current HSA runtime reference counter is less + * than one. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_REFCOUNT_OVERFLOW The HSA runtime reference + * count reaches INT32_MAX. + */ +hsa_status_t HSA_API hsa_init(); + +/** + * @brief Shut down the HSA runtime. + * + * @details Decreases the reference count of the HSA runtime instance. When the + * reference count reaches 0, the HSA runtime is no longer considered valid + * but the application might call ::hsa_init to initialize the HSA runtime + * again. + * + * Once the reference count of the HSA runtime reaches 0, all the resources + * associated with it (queues, signals, agent information, etc.) are + * considered invalid and any attempt to reference them in subsequent API calls + * results in undefined behavior. When the reference count reaches 0, the HSA + * runtime may release resources associated with it. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + */ +hsa_status_t HSA_API hsa_shut_down(); + +/** @} **/ + +/** \defgroup agentinfo System and Agent Information + * @{ + */ + +/** + * @brief Endianness. A convention used to interpret the bytes making up a data + * word. + */ +typedef enum { + /** + * The least significant byte is stored in the smallest address. + */ + HSA_ENDIANNESS_LITTLE = 0, + /** + * The most significant byte is stored in the smallest address. + */ + HSA_ENDIANNESS_BIG = 1 +} hsa_endianness_t; + +/** + * @brief Machine model. A machine model determines the size of certain data + * types in HSA runtime and an agent. + */ +typedef enum { + /** + * Small machine model. Addresses use 32 bits. + */ + HSA_MACHINE_MODEL_SMALL = 0, + /** + * Large machine model. Addresses use 64 bits. + */ + HSA_MACHINE_MODEL_LARGE = 1 +} hsa_machine_model_t; + +/** + * @brief Profile. A profile indicates a particular level of feature + * support. For example, in the base profile the application must use the HSA + * runtime allocator to reserve shared virtual memory, while in the full profile + * any host pointer can be shared across all the agents. + */ +typedef enum { + /** + * Base profile. + */ + HSA_PROFILE_BASE = 0, + /** + * Full profile. + */ + HSA_PROFILE_FULL = 1 +} hsa_profile_t; + +/** + * @brief System attributes. + */ +typedef enum { + /** + * Major version of the HSA runtime specification supported by the + * implementation. The type of this attribute is uint16_t. + */ + HSA_SYSTEM_INFO_VERSION_MAJOR = 0, + /** + * Minor version of the HSA runtime specification supported by the + * implementation. The type of this attribute is uint16_t. + */ + HSA_SYSTEM_INFO_VERSION_MINOR = 1, + /** + * Current timestamp. The value of this attribute monotonically increases at a + * constant rate. The type of this attribute is uint64_t. + */ + HSA_SYSTEM_INFO_TIMESTAMP = 2, + /** + * Timestamp value increase rate, in Hz. The timestamp (clock) frequency is + * in the range 1-400MHz. The type of this attribute is uint64_t. + */ + HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY = 3, + /** + * Maximum duration of a signal wait operation. Expressed as a count based on + * the timestamp frequency. The type of this attribute is uint64_t. + */ + HSA_SYSTEM_INFO_SIGNAL_MAX_WAIT = 4, + /** + * Endianness of the system. The type of this attribute is ::hsa_endianness_t. + */ + HSA_SYSTEM_INFO_ENDIANNESS = 5, + /** + * Machine model supported by the HSA runtime. The type of this attribute is + * ::hsa_machine_model_t. + */ + HSA_SYSTEM_INFO_MACHINE_MODEL = 6, + /** + * Bit-mask indicating which extensions are supported by the + * implementation. An extension with an ID of @p i is supported if the bit at + * position @p i is set. The type of this attribute is uint8_t[128]. + */ + HSA_SYSTEM_INFO_EXTENSIONS = 7, + /** + * String containing the ROCr build identifier. + */ + HSA_AMD_SYSTEM_INFO_BUILD_VERSION = 0x200, + /** + * Returns true if hsa_amd_svm_* APIs are supported by the driver. The type + * of this attribute is bool. + */ + HSA_AMD_SYSTEM_INFO_SVM_SUPPORTED = 0x201, + // TODO: Should this be per Agent? + /** + * Returns true if all Agents have access to system allocated memory (such as + * that allocated by mmap, malloc, or new) by default. + * If false then system allocated memory may only be made SVM accessible to + * an Agent by declaration of accessibility with hsa_amd_svm_set_attributes. + * The type of this attribute is bool. + */ + HSA_AMD_SYSTEM_INFO_SVM_ACCESSIBLE_BY_DEFAULT = 0x202, + /** + * Returns true if mwaitx is enabled on this system + * The type of this attribute is bool. + */ + HSA_AMD_SYSTEM_INFO_MWAITX_ENABLED = 0x203, + /** + * Returns true if DMABUF APIs are supported by the driver. The type of + * this attribute is bool. + */ + HSA_AMD_SYSTEM_INFO_DMABUF_SUPPORTED = 0x204, + /** + * Returns true if Virtual Memory APIs are supported by the driver. The type + * of this attribute is bool. + */ + HSA_AMD_SYSTEM_INFO_VIRTUAL_MEM_API_SUPPORTED = 0x205, + /** + * Returns true if XNACK is enabled on this system. The type of + * this attribute is bool. + */ + HSA_AMD_SYSTEM_INFO_XNACK_ENABLED = 0x206, + /** + * Major version of the HSA runtime extension specification supported by the + * implementation. The type of this attribute is uint16_t. + */ + HSA_AMD_SYSTEM_INFO_EXT_VERSION_MAJOR = 0x207, + /** + * Minor version of the HSA runtime extension specification supported by the + * implementation. The type of this attribute is uint16_t. + */ + HSA_AMD_SYSTEM_INFO_EXT_VERSION_MINOR = 0x208, +} hsa_system_info_t; + +/** + * @brief Get the current value of a system attribute. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * system attribute, or @p value is NULL. + */ +hsa_status_t HSA_API hsa_system_get_info(hsa_system_info_t attribute, + void *value); + +/** + * @brief HSA extensions. + */ +typedef enum { + /** + * Finalizer extension. + */ + HSA_EXTENSION_FINALIZER = 0, + /** + * Images extension. + */ + HSA_EXTENSION_IMAGES = 1, + + /** + * Performance counter extension. + */ + HSA_EXTENSION_PERFORMANCE_COUNTERS = 2, + + /** + * Profiling events extension. + */ + HSA_EXTENSION_PROFILING_EVENTS = 3, + /** + * Extension count. + */ + HSA_EXTENSION_STD_LAST = 3, + /** + * First AMD extension number. + */ + HSA_AMD_FIRST_EXTENSION = 0x200, + /** + * Profiler extension. + */ + HSA_EXTENSION_AMD_PROFILER = 0x200, + /** + * Loader extension. + */ + HSA_EXTENSION_AMD_LOADER = 0x201, + /** + * AqlProfile extension. + */ + HSA_EXTENSION_AMD_AQLPROFILE = 0x202, + /** + * PC Sampling extension. + */ + HSA_EXTENSION_AMD_PC_SAMPLING = 0x203, + /** + * Last AMD extension. + */ + HSA_AMD_LAST_EXTENSION = 0x203 +} hsa_extension_t; + +/** + * @brief Query the name of a given extension. + * + * @param[in] extension Extension identifier. If the extension is not supported + * by the implementation (see ::HSA_SYSTEM_INFO_EXTENSIONS), the behavior + * is undefined. + * + * @param[out] name Pointer to a memory location where the HSA runtime stores + * the extension name. The extension name is a NUL-terminated string. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p extension is not a valid + * extension, or @p name is NULL. + */ +hsa_status_t HSA_API hsa_extension_get_name(uint16_t extension, + const char **name); + +/** + * @deprecated + * + * @brief Query if a given version of an extension is supported by the HSA + * implementation. + * + * @param[in] extension Extension identifier. + * + * @param[in] version_major Major version number. + * + * @param[in] version_minor Minor version number. + * + * @param[out] result Pointer to a memory location where the HSA runtime stores + * the result of the check. The result is true if the specified version of the + * extension is supported, and false otherwise. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p extension is not a valid + * extension, or @p result is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED +hsa_system_extension_supported(uint16_t extension, uint16_t version_major, + uint16_t version_minor, bool *result); + +/** + * @brief Query if a given version of an extension is supported by the HSA + * implementation. All minor versions from 0 up to the returned @p version_minor + * must be supported by the implementation. + * + * @param[in] extension Extension identifier. + * + * @param[in] version_major Major version number. + * + * @param[out] version_minor Minor version number. + * + * @param[out] result Pointer to a memory location where the HSA runtime stores + * the result of the check. The result is true if the specified version of the + * extension is supported, and false otherwise. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p extension is not a valid + * extension, or @p version_minor is NULL, or @p result is NULL. + */ +hsa_status_t HSA_API +hsa_system_major_extension_supported(uint16_t extension, uint16_t version_major, + uint16_t *version_minor, bool *result); + +/** + * @deprecated + * + * @brief Retrieve the function pointers corresponding to a given version of an + * extension. Portable applications are expected to invoke the extension API + * using the returned function pointers + * + * @details The application is responsible for verifying that the given version + * of the extension is supported by the HSA implementation (see + * ::hsa_system_extension_supported). If the given combination of extension, + * major version, and minor version is not supported by the implementation, the + * behavior is undefined. + * + * @param[in] extension Extension identifier. + * + * @param[in] version_major Major version number for which to retrieve the + * function pointer table. + * + * @param[in] version_minor Minor version number for which to retrieve the + * function pointer table. + * + * @param[out] table Pointer to an application-allocated function pointer table + * that is populated by the HSA runtime. Must not be NULL. The memory associated + * with table can be reused or freed after the function returns. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p extension is not a valid + * extension, or @p table is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED +hsa_system_get_extension_table(uint16_t extension, uint16_t version_major, + uint16_t version_minor, void *table); + +/** + * @brief Retrieve the function pointers corresponding to a given major version + * of an extension. Portable applications are expected to invoke the extension + * API using the returned function pointers. + * + * @details The application is responsible for verifying that the given major + * version of the extension is supported by the HSA implementation (see + * ::hsa_system_major_extension_supported). If the given combination of + * extension and major version is not supported by the implementation, the + * behavior is undefined. Additionally if the length doesn't allow space for a + * full minor version, it is implementation defined if only some of the function + * pointers for that minor version get written. + * + * @param[in] extension Extension identifier. + * + * @param[in] version_major Major version number for which to retrieve the + * function pointer table. + * + * @param[in] table_length Size in bytes of the function pointer table to be + * populated. The implementation will not write more than this many bytes to the + * table. + * + * @param[out] table Pointer to an application-allocated function pointer table + * that is populated by the HSA runtime. Must not be NULL. The memory associated + * with table can be reused or freed after the function returns. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p extension is not a valid + * extension, or @p table is NULL. + */ +hsa_status_t HSA_API +hsa_system_get_major_extension_table(uint16_t extension, uint16_t version_major, + size_t table_length, void *table); + +/** + * @brief Struct containing an opaque handle to an agent, a device that + * participates in the HSA memory model. An agent can submit AQL packets for + * execution, and may also accept AQL packets for execution (agent dispatch + * packets or kernel dispatch packets launching HSAIL-derived binaries). + */ +typedef struct hsa_agent_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_agent_t; + +/** + * @brief Agent features. + */ +typedef enum { + /** + * The agent supports AQL packets of kernel dispatch type. If this + * feature is enabled, the agent is also a kernel agent. + */ + HSA_AGENT_FEATURE_KERNEL_DISPATCH = 1, + /** + * The agent supports AQL packets of agent dispatch type. + */ + HSA_AGENT_FEATURE_AGENT_DISPATCH = 2 +} hsa_agent_feature_t; + +/** + * @brief Hardware device type. + */ +typedef enum { + /** + * CPU device. + */ + HSA_DEVICE_TYPE_CPU = 0, + /** + * GPU device. + */ + HSA_DEVICE_TYPE_GPU = 1, + /** + * DSP device. + */ + HSA_DEVICE_TYPE_DSP = 2 +} hsa_device_type_t; + +/** + * @brief Default floating-point rounding mode. + */ +typedef enum { + /** + * Use a default floating-point rounding mode specified elsewhere. + */ + HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT = 0, + /** + * Operations that specify the default floating-point mode are rounded to zero + * by default. + */ + HSA_DEFAULT_FLOAT_ROUNDING_MODE_ZERO = 1, + /** + * Operations that specify the default floating-point mode are rounded to the + * nearest representable number and that ties should be broken by selecting + * the value with an even least significant bit. + */ + HSA_DEFAULT_FLOAT_ROUNDING_MODE_NEAR = 2 +} hsa_default_float_rounding_mode_t; + +/** + * @brief Agent attributes. + */ +typedef enum { + /** + * Agent name. The type of this attribute is a NUL-terminated char[64]. The + * name must be at most 63 characters long (not including the NUL terminator) + * and all array elements not used for the name must be NUL. + */ + HSA_AGENT_INFO_NAME = 0, + /** + * Name of vendor. The type of this attribute is a NUL-terminated char[64]. + * The name must be at most 63 characters long (not including the NUL + * terminator) and all array elements not used for the name must be NUL. + */ + HSA_AGENT_INFO_VENDOR_NAME = 1, + /** + * Agent capability. The type of this attribute is ::hsa_agent_feature_t. + */ + HSA_AGENT_INFO_FEATURE = 2, + /** + * @deprecated Query ::HSA_ISA_INFO_MACHINE_MODELS for a given intruction set + * architecture supported by the agent instead. If more than one ISA is + * supported by the agent, the returned value corresponds to the first ISA + * enumerated by ::hsa_agent_iterate_isas. + * + * Machine model supported by the agent. The type of this attribute is + * ::hsa_machine_model_t. + */ + HSA_AGENT_INFO_MACHINE_MODEL = 3, + /** + * @deprecated Query ::HSA_ISA_INFO_PROFILES for a given intruction set + * architecture supported by the agent instead. If more than one ISA is + * supported by the agent, the returned value corresponds to the first ISA + * enumerated by ::hsa_agent_iterate_isas. + * + * Profile supported by the agent. The type of this attribute is + * ::hsa_profile_t. + */ + HSA_AGENT_INFO_PROFILE = 4, + /** + * @deprecated Query ::HSA_ISA_INFO_DEFAULT_FLOAT_ROUNDING_MODES for a given + * intruction set architecture supported by the agent instead. If more than + * one ISA is supported by the agent, the returned value corresponds to the + * first ISA enumerated by ::hsa_agent_iterate_isas. + * + * Default floating-point rounding mode. The type of this attribute is + * ::hsa_default_float_rounding_mode_t, but the value + * ::HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT is not allowed. + */ + HSA_AGENT_INFO_DEFAULT_FLOAT_ROUNDING_MODE = 5, + /** + * @deprecated Query ::HSA_ISA_INFO_BASE_PROFILE_DEFAULT_FLOAT_ROUNDING_MODES + * for a given intruction set architecture supported by the agent instead. If + * more than one ISA is supported by the agent, the returned value corresponds + * to the first ISA enumerated by ::hsa_agent_iterate_isas. + * + * A bit-mask of ::hsa_default_float_rounding_mode_t values, representing the + * default floating-point rounding modes supported by the agent in the Base + * profile. The type of this attribute is uint32_t. The default floating-point + * rounding mode (::HSA_AGENT_INFO_DEFAULT_FLOAT_ROUNDING_MODE) bit must not + * be set. + */ + HSA_AGENT_INFO_BASE_PROFILE_DEFAULT_FLOAT_ROUNDING_MODES = 23, + /** + * @deprecated Query ::HSA_ISA_INFO_FAST_F16_OPERATION for a given intruction + * set architecture supported by the agent instead. If more than one ISA is + * supported by the agent, the returned value corresponds to the first ISA + * enumerated by ::hsa_agent_iterate_isas. + * + * Flag indicating that the f16 HSAIL operation is at least as fast as the + * f32 operation in the current agent. The value of this attribute is + * undefined if the agent is not a kernel agent. The type of this + * attribute is bool. + */ + HSA_AGENT_INFO_FAST_F16_OPERATION = 24, + /** + * @deprecated Query ::HSA_WAVEFRONT_INFO_SIZE for a given wavefront and + * intruction set architecture supported by the agent instead. If more than + * one ISA is supported by the agent, the returned value corresponds to the + * first ISA enumerated by ::hsa_agent_iterate_isas and the first wavefront + * enumerated by ::hsa_isa_iterate_wavefronts for that ISA. + * + * Number of work-items in a wavefront. Must be a power of 2 in the range + * [1,256]. The value of this attribute is undefined if the agent is not + * a kernel agent. The type of this attribute is uint32_t. + */ + HSA_AGENT_INFO_WAVEFRONT_SIZE = 6, + /** + * @deprecated Query ::HSA_ISA_INFO_WORKGROUP_MAX_DIM for a given intruction + * set architecture supported by the agent instead. If more than one ISA is + * supported by the agent, the returned value corresponds to the first ISA + * enumerated by ::hsa_agent_iterate_isas. + * + * Maximum number of work-items of each dimension of a work-group. Each + * maximum must be greater than 0. No maximum can exceed the value of + * ::HSA_AGENT_INFO_WORKGROUP_MAX_SIZE. The value of this attribute is + * undefined if the agent is not a kernel agent. The type of this + * attribute is uint16_t[3]. + */ + HSA_AGENT_INFO_WORKGROUP_MAX_DIM = 7, + /** + * @deprecated Query ::HSA_ISA_INFO_WORKGROUP_MAX_SIZE for a given intruction + * set architecture supported by the agent instead. If more than one ISA is + * supported by the agent, the returned value corresponds to the first ISA + * enumerated by ::hsa_agent_iterate_isas. + * + * Maximum total number of work-items in a work-group. The value of this + * attribute is undefined if the agent is not a kernel agent. The type + * of this attribute is uint32_t. + */ + HSA_AGENT_INFO_WORKGROUP_MAX_SIZE = 8, + /** + * @deprecated Query ::HSA_ISA_INFO_GRID_MAX_DIM for a given intruction set + * architecture supported by the agent instead. + * + * Maximum number of work-items of each dimension of a grid. Each maximum must + * be greater than 0, and must not be smaller than the corresponding value in + * ::HSA_AGENT_INFO_WORKGROUP_MAX_DIM. No maximum can exceed the value of + * ::HSA_AGENT_INFO_GRID_MAX_SIZE. The value of this attribute is undefined + * if the agent is not a kernel agent. The type of this attribute is + * ::hsa_dim3_t. + */ + HSA_AGENT_INFO_GRID_MAX_DIM = 9, + /** + * @deprecated Query ::HSA_ISA_INFO_GRID_MAX_SIZE for a given intruction set + * architecture supported by the agent instead. If more than one ISA is + * supported by the agent, the returned value corresponds to the first ISA + * enumerated by ::hsa_agent_iterate_isas. + * + * Maximum total number of work-items in a grid. The value of this attribute + * is undefined if the agent is not a kernel agent. The type of this + * attribute is uint32_t. + */ + HSA_AGENT_INFO_GRID_MAX_SIZE = 10, + /** + * @deprecated Query ::HSA_ISA_INFO_FBARRIER_MAX_SIZE for a given intruction + * set architecture supported by the agent instead. If more than one ISA is + * supported by the agent, the returned value corresponds to the first ISA + * enumerated by ::hsa_agent_iterate_isas. + * + * Maximum number of fbarriers per work-group. Must be at least 32. The value + * of this attribute is undefined if the agent is not a kernel agent. The + * type of this attribute is uint32_t. + */ + HSA_AGENT_INFO_FBARRIER_MAX_SIZE = 11, + /** + * @deprecated The maximum number of queues is not statically determined. + * + * Maximum number of queues that can be active (created but not destroyed) at + * one time in the agent. The type of this attribute is uint32_t. + */ + HSA_AGENT_INFO_QUEUES_MAX = 12, + /** + * Minimum number of packets that a queue created in the agent + * can hold. Must be a power of 2 greater than 0. Must not exceed + * the value of ::HSA_AGENT_INFO_QUEUE_MAX_SIZE. The type of this + * attribute is uint32_t. + */ + HSA_AGENT_INFO_QUEUE_MIN_SIZE = 13, + /** + * Maximum number of packets that a queue created in the agent can + * hold. Must be a power of 2 greater than 0. The type of this attribute + * is uint32_t. + */ + HSA_AGENT_INFO_QUEUE_MAX_SIZE = 14, + /** + * Type of a queue created in the agent. The type of this attribute is + * ::hsa_queue_type32_t. + */ + HSA_AGENT_INFO_QUEUE_TYPE = 15, + /** + * @deprecated NUMA information is not exposed anywhere else in the API. + * + * Identifier of the NUMA node associated with the agent. The type of this + * attribute is uint32_t. + */ + HSA_AGENT_INFO_NODE = 16, + /** + * Type of hardware device associated with the agent. The type of this + * attribute is ::hsa_device_type_t. + */ + HSA_AGENT_INFO_DEVICE = 17, + /** + * @deprecated Query ::hsa_agent_iterate_caches to retrieve information about + * the caches present in a given agent. + * + * Array of data cache sizes (L1..L4). Each size is expressed in bytes. A size + * of 0 for a particular level indicates that there is no cache information + * for that level. The type of this attribute is uint32_t[4]. + */ + HSA_AGENT_INFO_CACHE_SIZE = 18, + /** + * @deprecated An agent may support multiple instruction set + * architectures. See ::hsa_agent_iterate_isas. If more than one ISA is + * supported by the agent, the returned value corresponds to the first ISA + * enumerated by ::hsa_agent_iterate_isas. + * + * Instruction set architecture of the agent. The type of this attribute + * is ::hsa_isa_t. + */ + HSA_AGENT_INFO_ISA = 19, + /** + * Bit-mask indicating which extensions are supported by the agent. An + * extension with an ID of @p i is supported if the bit at position @p i is + * set. The type of this attribute is uint8_t[128]. + */ + HSA_AGENT_INFO_EXTENSIONS = 20, + /** + * Major version of the HSA runtime specification supported by the + * agent. The type of this attribute is uint16_t. + */ + HSA_AGENT_INFO_VERSION_MAJOR = 21, + /** + * Minor version of the HSA runtime specification supported by the + * agent. The type of this attribute is uint16_t. + */ + HSA_AGENT_INFO_VERSION_MINOR = 22, + /** + * This enum does not have a fixed underlying type, thus in C++ post D2338: + * If the enumeration type does not have a fixed underlying type, the value is + * unchanged if the original value is within the range of the enumeration + * values (9.7.1 [dcl.enum]), and otherwise, the behavior is + * undefined. + * Thus increase the range of this enum to encompass vendor extensions. + */ + HSA_AGENT_INFO_LAST = INT32_MAX +} hsa_agent_info_t; + +/** + * @brief Get the current value of an attribute for a given agent. + * + * @param[in] agent A valid agent. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * agent attribute, or @p value is NULL. + */ +hsa_status_t HSA_API hsa_agent_get_info(hsa_agent_t agent, + hsa_agent_info_t attribute, + void *value); + +/** + * @brief Iterate over the available agents, and invoke an + * application-defined callback on every iteration. + * + * @param[in] callback Callback to be invoked once per agent. The HSA + * runtime passes two arguments to the callback: the agent and the + * application data. If @p callback returns a status other than + * ::HSA_STATUS_SUCCESS for a particular iteration, the traversal stops and + * ::hsa_iterate_agents returns that status value. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t HSA_API hsa_iterate_agents( + hsa_status_t (*callback)(hsa_agent_t agent, void *data), void *data); + +/* + +// If we do not know the size of an attribute, we need to query it first +// Note: this API will not be in the spec unless needed +hsa_status_t HSA_API hsa_agent_get_info_size( + hsa_agent_t agent, + hsa_agent_info_t attribute, + size_t* size); + +// Set the value of an agents attribute +// Note: this API will not be in the spec unless needed +hsa_status_t HSA_API hsa_agent_set_info( + hsa_agent_t agent, + hsa_agent_info_t attribute, + void* value); + +*/ + +/** + * @brief Exception policies applied in the presence of hardware exceptions. + */ +typedef enum { + /** + * If a hardware exception is detected, a work-item signals an exception. + */ + HSA_EXCEPTION_POLICY_BREAK = 1, + /** + * If a hardware exception is detected, a hardware status bit is set. + */ + HSA_EXCEPTION_POLICY_DETECT = 2 +} hsa_exception_policy_t; + +/** + * @deprecated Use ::hsa_isa_get_exception_policies for a given intruction set + * architecture supported by the agent instead. If more than one ISA is + * supported by the agent, this function uses the first value returned by + * ::hsa_agent_iterate_isas. + * + * @brief Retrieve the exception policy support for a given combination of + * agent and profile + * + * @param[in] agent Agent. + * + * @param[in] profile Profile. + * + * @param[out] mask Pointer to a memory location where the HSA runtime stores a + * mask of ::hsa_exception_policy_t values. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p profile is not a valid + * profile, or @p mask is NULL. + * + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_agent_get_exception_policies( + hsa_agent_t agent, hsa_profile_t profile, uint16_t *mask); + +/** + * @brief Cache handle. + */ +typedef struct hsa_cache_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_cache_t; + +/** + * @brief Cache attributes. + */ +typedef enum { + /** + * The length of the cache name in bytes, not including the NUL terminator. + * The type of this attribute is uint32_t. + */ + HSA_CACHE_INFO_NAME_LENGTH = 0, + /** + * Human-readable description. The type of this attribute is a NUL-terminated + * character array with the length equal to the value of + * ::HSA_CACHE_INFO_NAME_LENGTH attribute. + */ + HSA_CACHE_INFO_NAME = 1, + /** + * Cache level. A L1 cache must return a value of 1, a L2 must return a value + * of 2, and so on. The type of this attribute is uint8_t. + */ + HSA_CACHE_INFO_LEVEL = 2, + /** + * Cache size, in bytes. A value of 0 indicates that there is no size + * information available. The type of this attribute is uint32_t. + */ + HSA_CACHE_INFO_SIZE = 3 +} hsa_cache_info_t; + +/** + * @brief Get the current value of an attribute for a given cache object. + * + * @param[in] cache Cache. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CACHE The cache is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * instruction set architecture attribute, or @p value is + * NULL. + */ +hsa_status_t HSA_API hsa_cache_get_info(hsa_cache_t cache, + hsa_cache_info_t attribute, + void *value); + +/** + * @brief Iterate over the memory caches of a given agent, and + * invoke an application-defined callback on every iteration. + * + * @details Caches are visited in ascending order according to the value of the + * ::HSA_CACHE_INFO_LEVEL attribute. + * + * @param[in] agent A valid agent. + * + * @param[in] callback Callback to be invoked once per cache that is present in + * the agent. The HSA runtime passes two arguments to the callback: the cache + * and the application data. If @p callback returns a status other than + * ::HSA_STATUS_SUCCESS for a particular iteration, the traversal stops and + * that value is returned. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t HSA_API hsa_agent_iterate_caches( + hsa_agent_t agent, hsa_status_t (*callback)(hsa_cache_t cache, void *data), + void *data); + +/** + * @deprecated + * + * @brief Query if a given version of an extension is supported by an agent + * + * @param[in] extension Extension identifier. + * + * @param[in] agent Agent. + * + * @param[in] version_major Major version number. + * + * @param[in] version_minor Minor version number. + * + * @param[out] result Pointer to a memory location where the HSA runtime stores + * the result of the check. The result is true if the specified version of the + * extension is supported, and false otherwise. The result must be false if + * ::hsa_system_extension_supported returns false for the same extension + * version. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p extension is not a valid + * extension, or @p result is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_agent_extension_supported( + uint16_t extension, hsa_agent_t agent, uint16_t version_major, + uint16_t version_minor, bool *result); + +/** + * @brief Query if a given version of an extension is supported by an agent. All + * minor versions from 0 up to the returned @p version_minor must be supported. + * + * @param[in] extension Extension identifier. + * + * @param[in] agent Agent. + * + * @param[in] version_major Major version number. + * + * @param[out] version_minor Minor version number. + * + * @param[out] result Pointer to a memory location where the HSA runtime stores + * the result of the check. The result is true if the specified version of the + * extension is supported, and false otherwise. The result must be false if + * ::hsa_system_extension_supported returns false for the same extension + * version. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p extension is not a valid + * extension, or @p version_minor is NULL, or @p result is NULL. + */ +hsa_status_t HSA_API hsa_agent_major_extension_supported( + uint16_t extension, hsa_agent_t agent, uint16_t version_major, + uint16_t *version_minor, bool *result); + +/** @} */ + +/** \defgroup signals Signals + * @{ + */ + +/** + * @brief Signal handle. + */ +typedef struct hsa_signal_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. The value 0 is reserved. + */ + uint64_t handle; +} hsa_signal_t; + +/** + * @brief Signal value. The value occupies 32 bits in small machine mode, and 64 + * bits in large machine mode. + */ +#ifdef HSA_LARGE_MODEL +typedef int64_t hsa_signal_value_t; +#else +typedef int32_t hsa_signal_value_t; +#endif + +/** + * @brief Create a signal. + * + * @param[in] initial_value Initial value of the signal. + * + * @param[in] num_consumers Size of @p consumers. A value of 0 indicates that + * any agent might wait on the signal. + * + * @param[in] consumers List of agents that might consume (wait on) the + * signal. If @p num_consumers is 0, this argument is ignored; otherwise, the + * HSA runtime might use the list to optimize the handling of the signal + * object. If an agent not listed in @p consumers waits on the returned + * signal, the behavior is undefined. The memory associated with @p consumers + * can be reused or freed after the function returns. + * + * @param[out] signal Pointer to a memory location where the HSA runtime will + * store the newly created signal handle. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p signal is NULL, @p + * num_consumers is greater than 0 but @p consumers is NULL, or @p consumers + * contains duplicates. + */ +hsa_status_t HSA_API hsa_signal_create(hsa_signal_value_t initial_value, + uint32_t num_consumers, + const hsa_agent_t *consumers, + hsa_signal_t *signal); + +/** + * @brief Destroy a signal previous created by ::hsa_signal_create. + * + * @param[in] signal Signal. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SIGNAL @p signal is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT The handle in @p signal is 0. + */ +hsa_status_t HSA_API hsa_signal_destroy(hsa_signal_t signal); + +/** + * @brief Atomically read the current value of a signal. + * + * @param[in] signal Signal. + * + * @return Value of the signal. + */ +hsa_signal_value_t HSA_API hsa_signal_load_scacquire(hsa_signal_t signal); + +/** + * @copydoc hsa_signal_load_scacquire + */ +hsa_signal_value_t HSA_API hsa_signal_load_relaxed(hsa_signal_t signal); + +/** + * @deprecated Renamed as ::hsa_signal_load_scacquire. + * + * @copydoc hsa_signal_load_scacquire + */ +hsa_signal_value_t HSA_API HSA_DEPRECATED +hsa_signal_load_acquire(hsa_signal_t signal); + +/** + * @brief Atomically set the value of a signal. + * + * @details If the value of the signal is changed, all the agents waiting + * on @p signal for which @p value satisfies their wait condition are awakened. + * + * @param[in] signal Signal. + * + * @param[in] value New signal value. + */ +void HSA_API hsa_signal_store_relaxed(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_store_relaxed + */ +void HSA_API hsa_signal_store_screlease(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_store_screlease. + * + * @copydoc hsa_signal_store_screlease + */ +void HSA_API HSA_DEPRECATED hsa_signal_store_release(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @brief Atomically set the value of a signal without necessarily notifying the + * the agents waiting on it. + * + * @details The agents waiting on @p signal may not wake up even when the new + * value satisfies their wait condition. If the application wants to update the + * signal and there is no need to notify any agent, invoking this function can + * be more efficient than calling the non-silent counterpart. + * + * @param[in] signal Signal. + * + * @param[in] value New signal value. + */ +void HSA_API hsa_signal_silent_store_relaxed(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_silent_store_relaxed + */ +void HSA_API hsa_signal_silent_store_screlease(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @brief Atomically set the value of a signal and return its previous value. + * + * @details If the value of the signal is changed, all the agents waiting + * on @p signal for which @p value satisfies their wait condition are awakened. + * + * @param[in] signal Signal. If @p signal is a queue doorbell signal, the + * behavior is undefined. + * + * @param[in] value New value. + * + * @return Value of the signal prior to the exchange. + * + */ +hsa_signal_value_t HSA_API +hsa_signal_exchange_scacq_screl(hsa_signal_t signal, hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_exchange_scacq_screl. + * + * @copydoc hsa_signal_exchange_scacq_screl + */ +hsa_signal_value_t HSA_API HSA_DEPRECATED +hsa_signal_exchange_acq_rel(hsa_signal_t signal, hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_exchange_scacq_screl + */ +hsa_signal_value_t HSA_API +hsa_signal_exchange_scacquire(hsa_signal_t signal, hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_exchange_scacquire. + * + * @copydoc hsa_signal_exchange_scacquire + */ +hsa_signal_value_t HSA_API HSA_DEPRECATED +hsa_signal_exchange_acquire(hsa_signal_t signal, hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_exchange_scacq_screl + */ +hsa_signal_value_t HSA_API +hsa_signal_exchange_relaxed(hsa_signal_t signal, hsa_signal_value_t value); +/** + * @copydoc hsa_signal_exchange_scacq_screl + */ +hsa_signal_value_t HSA_API +hsa_signal_exchange_screlease(hsa_signal_t signal, hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_exchange_screlease. + * + * @copydoc hsa_signal_exchange_screlease + */ +hsa_signal_value_t HSA_API HSA_DEPRECATED +hsa_signal_exchange_release(hsa_signal_t signal, hsa_signal_value_t value); + +/** + * @brief Atomically set the value of a signal if the observed value is equal to + * the expected value. The observed value is returned regardless of whether the + * replacement was done. + * + * @details If the value of the signal is changed, all the agents waiting + * on @p signal for which @p value satisfies their wait condition are awakened. + * + * @param[in] signal Signal. If @p signal is a queue + * doorbell signal, the behavior is undefined. + * + * @param[in] expected Value to compare with. + * + * @param[in] value New value. + * + * @return Observed value of the signal. + * + */ +hsa_signal_value_t HSA_API hsa_signal_cas_scacq_screl( + hsa_signal_t signal, hsa_signal_value_t expected, hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_cas_scacq_screl. + * + * @copydoc hsa_signal_cas_scacq_screl + */ +hsa_signal_value_t HSA_API HSA_DEPRECATED hsa_signal_cas_acq_rel( + hsa_signal_t signal, hsa_signal_value_t expected, hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_cas_scacq_screl + */ +hsa_signal_value_t HSA_API hsa_signal_cas_scacquire(hsa_signal_t signal, + hsa_signal_value_t expected, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_cas_scacquire. + * + * @copydoc hsa_signal_cas_scacquire + */ +hsa_signal_value_t HSA_API HSA_DEPRECATED hsa_signal_cas_acquire( + hsa_signal_t signal, hsa_signal_value_t expected, hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_cas_scacq_screl + */ +hsa_signal_value_t HSA_API hsa_signal_cas_relaxed(hsa_signal_t signal, + hsa_signal_value_t expected, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_cas_scacq_screl + */ +hsa_signal_value_t HSA_API hsa_signal_cas_screlease(hsa_signal_t signal, + hsa_signal_value_t expected, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_cas_screlease. + * + * @copydoc hsa_signal_cas_screlease + */ +hsa_signal_value_t HSA_API HSA_DEPRECATED hsa_signal_cas_release( + hsa_signal_t signal, hsa_signal_value_t expected, hsa_signal_value_t value); + +/** + * @brief Atomically increment the value of a signal by a given amount. + * + * @details If the value of the signal is changed, all the agents waiting on + * @p signal for which @p value satisfies their wait condition are awakened. + * + * @param[in] signal Signal. If @p signal is a queue doorbell signal, the + * behavior is undefined. + * + * @param[in] value Value to add to the value of the signal. + * + */ +void HSA_API hsa_signal_add_scacq_screl(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_add_scacq_screl. + * + * @copydoc hsa_signal_add_scacq_screl + */ +void HSA_API HSA_DEPRECATED hsa_signal_add_acq_rel(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_add_scacq_screl + */ +void HSA_API hsa_signal_add_scacquire(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_add_scacquire. + * + * @copydoc hsa_signal_add_scacquire + */ +void HSA_API HSA_DEPRECATED hsa_signal_add_acquire(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_add_scacq_screl + */ +void HSA_API hsa_signal_add_relaxed(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_add_scacq_screl + */ +void HSA_API hsa_signal_add_screlease(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_add_screlease. + * + * @copydoc hsa_signal_add_screlease + */ +void HSA_API HSA_DEPRECATED hsa_signal_add_release(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @brief Atomically decrement the value of a signal by a given amount. + * + * @details If the value of the signal is changed, all the agents waiting on + * @p signal for which @p value satisfies their wait condition are awakened. + * + * @param[in] signal Signal. If @p signal is a queue doorbell signal, the + * behavior is undefined. + * + * @param[in] value Value to subtract from the value of the signal. + * + */ +void HSA_API hsa_signal_subtract_scacq_screl(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_subtract_scacq_screl. + * + * @copydoc hsa_signal_subtract_scacq_screl + */ +void HSA_API HSA_DEPRECATED +hsa_signal_subtract_acq_rel(hsa_signal_t signal, hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_subtract_scacq_screl + */ +void HSA_API hsa_signal_subtract_scacquire(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_subtract_scacquire. + * + * @copydoc hsa_signal_subtract_scacquire + */ +void HSA_API HSA_DEPRECATED +hsa_signal_subtract_acquire(hsa_signal_t signal, hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_subtract_scacq_screl + */ +void HSA_API hsa_signal_subtract_relaxed(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_subtract_scacq_screl + */ +void HSA_API hsa_signal_subtract_screlease(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_subtract_screlease. + * + * @copydoc hsa_signal_subtract_screlease + */ +void HSA_API HSA_DEPRECATED +hsa_signal_subtract_release(hsa_signal_t signal, hsa_signal_value_t value); + +/** + * @brief Atomically perform a bitwise AND operation between the value of a + * signal and a given value. + * + * @details If the value of the signal is changed, all the agents waiting on + * @p signal for which @p value satisfies their wait condition are awakened. + * + * @param[in] signal Signal. If @p signal is a queue doorbell signal, the + * behavior is undefined. + * + * @param[in] value Value to AND with the value of the signal. + * + */ +void HSA_API hsa_signal_and_scacq_screl(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_and_scacq_screl. + * + * @copydoc hsa_signal_and_scacq_screl + */ +void HSA_API HSA_DEPRECATED hsa_signal_and_acq_rel(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_and_scacq_screl + */ +void HSA_API hsa_signal_and_scacquire(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_and_scacquire. + * + * @copydoc hsa_signal_and_scacquire + */ +void HSA_API HSA_DEPRECATED hsa_signal_and_acquire(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_and_scacq_screl + */ +void HSA_API hsa_signal_and_relaxed(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_and_scacq_screl + */ +void HSA_API hsa_signal_and_screlease(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_and_screlease. + * + * @copydoc hsa_signal_and_screlease + */ +void HSA_API HSA_DEPRECATED hsa_signal_and_release(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @brief Atomically perform a bitwise OR operation between the value of a + * signal and a given value. + * + * @details If the value of the signal is changed, all the agents waiting on + * @p signal for which @p value satisfies their wait condition are awakened. + * + * @param[in] signal Signal. If @p signal is a queue doorbell signal, the + * behavior is undefined. + * + * @param[in] value Value to OR with the value of the signal. + */ +void HSA_API hsa_signal_or_scacq_screl(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_or_scacq_screl. + * + * @copydoc hsa_signal_or_scacq_screl + */ +void HSA_API HSA_DEPRECATED hsa_signal_or_acq_rel(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_or_scacq_screl + */ +void HSA_API hsa_signal_or_scacquire(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_or_scacquire. + * + * @copydoc hsa_signal_or_scacquire + */ +void HSA_API HSA_DEPRECATED hsa_signal_or_acquire(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_or_scacq_screl + */ +void HSA_API hsa_signal_or_relaxed(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_or_scacq_screl + */ +void HSA_API hsa_signal_or_screlease(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_or_screlease. + * + * @copydoc hsa_signal_or_screlease + */ +void HSA_API HSA_DEPRECATED hsa_signal_or_release(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @brief Atomically perform a bitwise XOR operation between the value of a + * signal and a given value. + * + * @details If the value of the signal is changed, all the agents waiting on + * @p signal for which @p value satisfies their wait condition are awakened. + * + * @param[in] signal Signal. If @p signal is a queue doorbell signal, the + * behavior is undefined. + * + * @param[in] value Value to XOR with the value of the signal. + * + */ +void HSA_API hsa_signal_xor_scacq_screl(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_xor_scacq_screl. + * + * @copydoc hsa_signal_xor_scacq_screl + */ +void HSA_API HSA_DEPRECATED hsa_signal_xor_acq_rel(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_xor_scacq_screl + */ +void HSA_API hsa_signal_xor_scacquire(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_xor_scacquire. + * + * @copydoc hsa_signal_xor_scacquire + */ +void HSA_API HSA_DEPRECATED hsa_signal_xor_acquire(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_xor_scacq_screl + */ +void HSA_API hsa_signal_xor_relaxed(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @copydoc hsa_signal_xor_scacq_screl + */ +void HSA_API hsa_signal_xor_screlease(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @deprecated Renamed as ::hsa_signal_xor_screlease. + * + * @copydoc hsa_signal_xor_screlease + */ +void HSA_API HSA_DEPRECATED hsa_signal_xor_release(hsa_signal_t signal, + hsa_signal_value_t value); + +/** + * @brief Wait condition operator. + */ +typedef enum { + /** + * The two operands are equal. + */ + HSA_SIGNAL_CONDITION_EQ = 0, + /** + * The two operands are not equal. + */ + HSA_SIGNAL_CONDITION_NE = 1, + /** + * The first operand is less than the second operand. + */ + HSA_SIGNAL_CONDITION_LT = 2, + /** + * The first operand is greater than or equal to the second operand. + */ + HSA_SIGNAL_CONDITION_GTE = 3 +} hsa_signal_condition_t; + +/** + * @brief State of the application thread during a signal wait. + */ +typedef enum { + /** + * The application thread may be rescheduled while waiting on the signal. + */ + HSA_WAIT_STATE_BLOCKED = 0, + /** + * The application thread stays active while waiting on a signal. + */ + HSA_WAIT_STATE_ACTIVE = 1 +} hsa_wait_state_t; + +/** + * @brief Wait until a signal value satisfies a specified condition, or a + * certain amount of time has elapsed. + * + * @details A wait operation can spuriously resume at any time sooner than the + * timeout (for example, due to system or other external factors) even when the + * condition has not been met. + * + * The function is guaranteed to return if the signal value satisfies the + * condition at some point in time during the wait, but the value returned to + * the application might not satisfy the condition. The application must ensure + * that signals are used in such way that wait wakeup conditions are not + * invalidated before dependent threads have woken up. + * + * When the wait operation internally loads the value of the passed signal, it + * uses the memory order indicated in the function name. + * + * @param[in] signal Signal. + * + * @param[in] condition Condition used to compare the signal value with @p + * compare_value. + * + * @param[in] compare_value Value to compare with. + * + * @param[in] timeout_hint Maximum duration of the wait. Specified in the same + * unit as the system timestamp. The operation might block for a shorter or + * longer time even if the condition is not met. A value of UINT64_MAX indicates + * no maximum. + * + * @param[in] wait_state_hint Hint used by the application to indicate the + * preferred waiting state. The actual waiting state is ultimately decided by + * HSA runtime and may not match the provided hint. A value of + * ::HSA_WAIT_STATE_ACTIVE may improve the latency of response to a signal + * update by avoiding rescheduling overhead. + * + * @return Observed value of the signal, which might not satisfy the specified + * condition. + * + */ +hsa_signal_value_t HSA_API hsa_signal_wait_scacquire( + hsa_signal_t signal, hsa_signal_condition_t condition, + hsa_signal_value_t compare_value, uint64_t timeout_hint, + hsa_wait_state_t wait_state_hint); + +/** + * @copydoc hsa_signal_wait_scacquire + */ +hsa_signal_value_t HSA_API +hsa_signal_wait_relaxed(hsa_signal_t signal, hsa_signal_condition_t condition, + hsa_signal_value_t compare_value, uint64_t timeout_hint, + hsa_wait_state_t wait_state_hint); + +/** + * @deprecated Renamed as ::hsa_signal_wait_scacquire. + * + * @copydoc hsa_signal_wait_scacquire + */ +hsa_signal_value_t HSA_API HSA_DEPRECATED +hsa_signal_wait_acquire(hsa_signal_t signal, hsa_signal_condition_t condition, + hsa_signal_value_t compare_value, uint64_t timeout_hint, + hsa_wait_state_t wait_state_hint); + +/** + * @brief Group of signals. + */ +typedef struct hsa_signal_group_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_signal_group_t; + +/** + * @brief Create a signal group. + * + * @param[in] num_signals Number of elements in @p signals. Must not be 0. + * + * @param[in] signals List of signals in the group. The list must not contain + * any repeated elements. Must not be NULL. + * + * @param[in] num_consumers Number of elements in @p consumers. Must not be 0. + * + * @param[in] consumers List of agents that might consume (wait on) the signal + * group. The list must not contain repeated elements, and must be a subset of + * the set of agents that are allowed to wait on all the signals in the + * group. If an agent not listed in @p consumers waits on the returned group, + * the behavior is undefined. The memory associated with @p consumers can be + * reused or freed after the function returns. Must not be NULL. + * + * @param[out] signal_group Pointer to newly created signal group. Must not be + * NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p num_signals is 0, @p signals + * is NULL, @p num_consumers is 0, @p consumers is NULL, or @p signal_group is + * NULL. + */ +hsa_status_t HSA_API hsa_signal_group_create(uint32_t num_signals, + const hsa_signal_t *signals, + uint32_t num_consumers, + const hsa_agent_t *consumers, + hsa_signal_group_t *signal_group); + +/** + * @brief Destroy a signal group previous created by ::hsa_signal_group_create. + * + * @param[in] signal_group Signal group. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SIGNAL_GROUP @p signal_group is invalid. + */ +hsa_status_t HSA_API hsa_signal_group_destroy(hsa_signal_group_t signal_group); + +/** + * @brief Wait until the value of at least one of the signals in a signal group + * satisfies its associated condition. + * + * @details The function is guaranteed to return if the value of at least one of + * the signals in the group satisfies its associated condition at some point in + * time during the wait, but the signal value returned to the application may no + * longer satisfy the condition. The application must ensure that signals in the + * group are used in such way that wait wakeup conditions are not invalidated + * before dependent threads have woken up. + * + * When this operation internally loads the value of the passed signal, it uses + * the memory order indicated in the function name. + * + * @param[in] signal_group Signal group. + * + * @param[in] conditions List of conditions. Each condition, and the value at + * the same index in @p compare_values, is used to compare the value of the + * signal at that index in @p signal_group (the signal passed by the application + * to ::hsa_signal_group_create at that particular index). The size of @p + * conditions must not be smaller than the number of signals in @p signal_group; + * any extra elements are ignored. Must not be NULL. + * + * @param[in] compare_values List of comparison values. The size of @p + * compare_values must not be smaller than the number of signals in @p + * signal_group; any extra elements are ignored. Must not be NULL. + * + * @param[in] wait_state_hint Hint used by the application to indicate the + * preferred waiting state. The actual waiting state is decided by the HSA + * runtime and may not match the provided hint. A value of + * ::HSA_WAIT_STATE_ACTIVE may improve the latency of response to a signal + * update by avoiding rescheduling overhead. + * + * @param[out] signal Signal in the group that satisfied the associated + * condition. If several signals satisfied their condition, the function can + * return any of those signals. Must not be NULL. + * + * @param[out] value Observed value for @p signal, which might no longer satisfy + * the specified condition. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SIGNAL_GROUP @p signal_group is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p conditions is NULL, @p + * compare_values is NULL, @p signal is NULL, or @p value is NULL. + */ +hsa_status_t HSA_API hsa_signal_group_wait_any_scacquire( + hsa_signal_group_t signal_group, const hsa_signal_condition_t *conditions, + const hsa_signal_value_t *compare_values, hsa_wait_state_t wait_state_hint, + hsa_signal_t *signal, hsa_signal_value_t *value); + +/** + * @copydoc hsa_signal_group_wait_any_scacquire + */ +hsa_status_t HSA_API hsa_signal_group_wait_any_relaxed( + hsa_signal_group_t signal_group, const hsa_signal_condition_t *conditions, + const hsa_signal_value_t *compare_values, hsa_wait_state_t wait_state_hint, + hsa_signal_t *signal, hsa_signal_value_t *value); + +/** @} */ + +/** \defgroup memory Memory + * @{ + */ + +/** + * @brief A memory region represents a block of virtual memory with certain + * properties. For example, the HSA runtime represents fine-grained memory in + * the global segment using a region. A region might be associated with more + * than one agent. + */ +typedef struct hsa_region_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_region_t; + +/** @} */ + +/** \defgroup queue Queues + * @{ + */ + +/** + * @brief Queue type. Intended to be used for dynamic queue protocol + * determination. + */ +typedef enum { + /** + * Queue supports multiple producers. Use of multiproducer queue mechanics is + * required. + */ + HSA_QUEUE_TYPE_MULTI = 0, + /** + * Queue only supports a single producer. In some scenarios, the application + * may want to limit the submission of AQL packets to a single agent. Queues + * that support a single producer may be more efficient than queues supporting + * multiple producers. Use of multiproducer queue mechanics is not supported. + */ + HSA_QUEUE_TYPE_SINGLE = 1, + /** + * Queue supports multiple producers and cooperative dispatches. Cooperative + * dispatches are able to use GWS synchronization. Queues of this type may be + * limited in number. The runtime may return the same queue to serve multiple + * ::hsa_queue_create calls when this type is given. Callers must inspect the + * returned queue to discover queue size. Queues of this type are reference + * counted and require a matching number of ::hsa_queue_destroy calls to + * release. Use of multiproducer queue mechanics is required. See + * ::HSA_AMD_AGENT_INFO_COOPERATIVE_QUEUES to query agent support for this + * type. + */ + HSA_QUEUE_TYPE_COOPERATIVE = 2 +} hsa_queue_type_t; + +/** + * @brief A fixed-size type used to represent ::hsa_queue_type_t constants. + */ +typedef uint32_t hsa_queue_type32_t; + +/** + * @brief Queue features. + */ +typedef enum { + /** + * Queue supports kernel dispatch packets. + */ + HSA_QUEUE_FEATURE_KERNEL_DISPATCH = 1, + + /** + * Queue supports agent dispatch packets. + */ + HSA_QUEUE_FEATURE_AGENT_DISPATCH = 2 +} hsa_queue_feature_t; + +/** + * @brief User mode queue. + * + * @details The queue structure is read-only and allocated by the HSA runtime, + * but agents can directly modify the contents of the buffer pointed by @a + * base_address, or use HSA runtime APIs to access the doorbell signal. + * + */ +typedef struct hsa_queue_s { + /** + * Queue type. + */ + hsa_queue_type32_t type; + + /** + * Queue features mask. This is a bit-field of ::hsa_queue_feature_t + * values. Applications should ignore any unknown set bits. + */ + uint32_t features; + +#ifdef HSA_LARGE_MODEL + void *base_address; +#elif defined HSA_LITTLE_ENDIAN + /** + * Starting address of the HSA runtime-allocated buffer used to store the AQL + * packets. Must be aligned to the size of an AQL packet. + */ + void *base_address; + /** + * Reserved. Must be 0. + */ + uint32_t reserved0; +#else + uint32_t reserved0; + void *base_address; +#endif + + /** + * Signal object used by the application to indicate the ID of a packet that + * is ready to be processed. The HSA runtime manages the doorbell signal. If + * the application tries to replace or destroy this signal, the behavior is + * undefined. + * + * If @a type is ::HSA_QUEUE_TYPE_SINGLE, the doorbell signal value must be + * updated in a monotonically increasing fashion. If @a type is + * ::HSA_QUEUE_TYPE_MULTI, the doorbell signal value can be updated with any + * value. + */ + hsa_signal_t doorbell_signal; + + /** + * Maximum number of packets the queue can hold. Must be a power of 2. + */ + uint32_t size; + /** + * Reserved. Must be 0. + */ + uint32_t reserved1; + /** + * Queue identifier, which is unique over the lifetime of the application. + */ + uint64_t id; + +} hsa_queue_t; + +/** + * @brief Create a user mode queue. + * + * @details The HSA runtime creates the queue structure, the underlying packet + * buffer, the completion signal, and the write and read indexes. The initial + * value of the write and read indexes is 0. The type of every packet in the + * buffer is initialized to ::HSA_PACKET_TYPE_INVALID. + * + * The application should only rely on the error code returned to determine if + * the queue is valid. + * + * @param[in] agent Agent where to create the queue. + * + * @param[in] size Number of packets the queue is expected to + * hold. Must be a power of 2 between 1 and the value of + * ::HSA_AGENT_INFO_QUEUE_MAX_SIZE in @p agent. The size of the newly + * created queue is the maximum of @p size and the value of + * ::HSA_AGENT_INFO_QUEUE_MIN_SIZE in @p agent. + * + * @param[in] type Type of the queue, a bitwise OR of hsa_queue_type_t values. + * If the value of ::HSA_AGENT_INFO_QUEUE_TYPE in @p agent is + * ::HSA_QUEUE_TYPE_SINGLE, then @p type must also be ::HSA_QUEUE_TYPE_SINGLE. + * + * @param[in] callback Callback invoked by the HSA runtime for every + * asynchronous event related to the newly created queue. May be NULL. The HSA + * runtime passes three arguments to the callback: a code identifying the event + * that triggered the invocation, a pointer to the queue where the event + * originated, and the application data. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @param[in] private_segment_size Hint indicating the maximum + * expected private segment usage per work-item, in bytes. There may + * be performance degradation if the application places a kernel + * dispatch packet in the queue and the corresponding private segment + * usage exceeds @p private_segment_size. If the application does not + * want to specify any particular value for this argument, @p + * private_segment_size must be UINT32_MAX. If the queue does not + * support kernel dispatch packets, this argument is ignored. + * + * @param[in] group_segment_size Hint indicating the maximum expected + * group segment usage per work-group, in bytes. There may be + * performance degradation if the application places a kernel dispatch + * packet in the queue and the corresponding group segment usage + * exceeds @p group_segment_size. If the application does not want to + * specify any particular value for this argument, @p + * group_segment_size must be UINT32_MAX. If the queue does not + * support kernel dispatch packets, this argument is ignored. + * + * @param[out] queue Memory location where the HSA runtime stores a pointer to + * the newly created queue. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_QUEUE_CREATION @p agent does not + * support queues of the given type. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p size is not a power of two, + * @p size is 0, @p type is an invalid queue type, or @p queue is NULL. + * + */ +hsa_status_t HSA_API hsa_queue_create( + hsa_agent_t agent, uint32_t size, hsa_queue_type32_t type, + void (*callback)(hsa_status_t status, hsa_queue_t *source, void *data), + void *data, uint32_t private_segment_size, uint32_t group_segment_size, + hsa_queue_t **queue); + +/** + * @brief Create a queue for which the application or a kernel is responsible + * for processing the AQL packets. + * + * @details The application can use this function to create queues where AQL + * packets are not parsed by the packet processor associated with an agent, + * but rather by a unit of execution running on that agent (for example, a + * thread in the host application). + * + * The application is responsible for ensuring that all the producers and + * consumers of the resulting queue can access the provided doorbell signal + * and memory region. The application is also responsible for ensuring that the + * unit of execution processing the queue packets supports the indicated + * features (AQL packet types). + * + * When the queue is created, the HSA runtime allocates the packet buffer using + * @p region, and the write and read indexes. The initial value of the write and + * read indexes is 0, and the type of every packet in the buffer is initialized + * to ::HSA_PACKET_TYPE_INVALID. The value of the @e size, @e type, @e features, + * and @e doorbell_signal fields in the returned queue match the values passed + * by the application. + * + * @param[in] region Memory region that the HSA runtime should use to allocate + * the AQL packet buffer and any other queue metadata. + * + * @param[in] size Number of packets the queue is expected to hold. Must be a + * power of 2 greater than 0. + * + * @param[in] type Queue type. + * + * @param[in] features Supported queue features. This is a bit-field of + * ::hsa_queue_feature_t values. + * + * @param[in] doorbell_signal Doorbell signal that the HSA runtime must + * associate with the returned queue. The signal handle must not be 0. + * + * @param[out] queue Memory location where the HSA runtime stores a pointer to + * the newly created queue. The application should not rely on the value + * returned for this argument but only in the status code to determine if the + * queue is valid. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p size is not a power of two, @p + * size is 0, @p type is an invalid queue type, the doorbell signal handle is + * 0, or @p queue is NULL. + * + */ +hsa_status_t HSA_API hsa_soft_queue_create(hsa_region_t region, uint32_t size, + hsa_queue_type32_t type, + uint32_t features, + hsa_signal_t doorbell_signal, + hsa_queue_t **queue); + +/** + * @brief Destroy a user mode queue. + * + * @details When a queue is destroyed, the state of the AQL packets that have + * not been yet fully processed (their completion phase has not finished) + * becomes undefined. It is the responsibility of the application to ensure that + * all pending queue operations are finished if their results are required. + * + * The resources allocated by the HSA runtime during queue creation (queue + * structure, ring buffer, doorbell signal) are released. The queue should not + * be accessed after being destroyed. + * + * @param[in] queue Pointer to a queue created using ::hsa_queue_create. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_QUEUE The queue is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p queue is NULL. + */ +hsa_status_t HSA_API hsa_queue_destroy(hsa_queue_t *queue); + +/** + * @brief Inactivate a queue. + * + * @details Inactivating the queue aborts any pending executions and prevent any + * new packets from being processed. Any more packets written to the queue once + * it is inactivated will be ignored by the packet processor. + * + * @param[in] queue Pointer to a queue. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_QUEUE The queue is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p queue is NULL. + */ +hsa_status_t HSA_API hsa_queue_inactivate(hsa_queue_t *queue); + +/** + * @deprecated Renamed as ::hsa_queue_load_read_index_scacquire. + * + * @copydoc hsa_queue_load_read_index_scacquire + */ +uint64_t HSA_API HSA_DEPRECATED +hsa_queue_load_read_index_acquire(const hsa_queue_t *queue); + +/** + * @brief Atomically load the read index of a queue. + * + * @param[in] queue Pointer to a queue. + * + * @return Read index of the queue pointed by @p queue. + */ +uint64_t HSA_API hsa_queue_load_read_index_scacquire(const hsa_queue_t *queue); + +/** + * @copydoc hsa_queue_load_read_index_scacquire + */ +uint64_t HSA_API hsa_queue_load_read_index_relaxed(const hsa_queue_t *queue); + +/** + * @deprecated Renamed as ::hsa_queue_load_write_index_scacquire. + * + * @copydoc hsa_queue_load_write_index_scacquire + */ +uint64_t HSA_API HSA_DEPRECATED +hsa_queue_load_write_index_acquire(const hsa_queue_t *queue); + +/** + * @brief Atomically load the write index of a queue. + * + * @param[in] queue Pointer to a queue. + * + * @return Write index of the queue pointed by @p queue. + */ +uint64_t HSA_API hsa_queue_load_write_index_scacquire(const hsa_queue_t *queue); + +/** + * @copydoc hsa_queue_load_write_index_scacquire + */ +uint64_t HSA_API hsa_queue_load_write_index_relaxed(const hsa_queue_t *queue); + +/** + * @brief Atomically set the write index of a queue. + * + * @details It is recommended that the application uses this function to update + * the write index when there is a single agent submitting work to the queue + * (the queue type is ::HSA_QUEUE_TYPE_SINGLE). + * + * @param[in] queue Pointer to a queue. + * + * @param[in] value Value to assign to the write index. + * + */ +void HSA_API hsa_queue_store_write_index_relaxed(const hsa_queue_t *queue, + uint64_t value); + +/** + * @deprecated Renamed as ::hsa_queue_store_write_index_screlease. + * + * @copydoc hsa_queue_store_write_index_screlease + */ +void HSA_API HSA_DEPRECATED +hsa_queue_store_write_index_release(const hsa_queue_t *queue, uint64_t value); + +/** + * @copydoc hsa_queue_store_write_index_relaxed + */ +void HSA_API hsa_queue_store_write_index_screlease(const hsa_queue_t *queue, + uint64_t value); + +/** + * @deprecated Renamed as ::hsa_queue_cas_write_index_scacq_screl. + * + * @copydoc hsa_queue_cas_write_index_scacq_screl + */ +uint64_t HSA_API HSA_DEPRECATED hsa_queue_cas_write_index_acq_rel( + const hsa_queue_t *queue, uint64_t expected, uint64_t value); + +/** + * @brief Atomically set the write index of a queue if the observed value is + * equal to the expected value. The application can inspect the returned value + * to determine if the replacement was done. + * + * @param[in] queue Pointer to a queue. + * + * @param[in] expected Expected value. + * + * @param[in] value Value to assign to the write index if @p expected matches + * the observed write index. Must be greater than @p expected. + * + * @return Previous value of the write index. + */ +uint64_t HSA_API hsa_queue_cas_write_index_scacq_screl(const hsa_queue_t *queue, + uint64_t expected, + uint64_t value); + +/** + * @deprecated Renamed as ::hsa_queue_cas_write_index_scacquire. + * + * @copydoc hsa_queue_cas_write_index_scacquire + */ +uint64_t HSA_API HSA_DEPRECATED hsa_queue_cas_write_index_acquire( + const hsa_queue_t *queue, uint64_t expected, uint64_t value); + +/** + * @copydoc hsa_queue_cas_write_index_scacq_screl + */ +uint64_t HSA_API hsa_queue_cas_write_index_scacquire(const hsa_queue_t *queue, + uint64_t expected, + uint64_t value); + +/** + * @copydoc hsa_queue_cas_write_index_scacq_screl + */ +uint64_t HSA_API hsa_queue_cas_write_index_relaxed(const hsa_queue_t *queue, + uint64_t expected, + uint64_t value); + +/** + * @deprecated Renamed as ::hsa_queue_cas_write_index_screlease. + * + * @copydoc hsa_queue_cas_write_index_screlease + */ +uint64_t HSA_API HSA_DEPRECATED hsa_queue_cas_write_index_release( + const hsa_queue_t *queue, uint64_t expected, uint64_t value); + +/** + * @copydoc hsa_queue_cas_write_index_scacq_screl + */ +uint64_t HSA_API hsa_queue_cas_write_index_screlease(const hsa_queue_t *queue, + uint64_t expected, + uint64_t value); + +/** + * @deprecated Renamed as ::hsa_queue_add_write_index_scacq_screl. + * + * @copydoc hsa_queue_add_write_index_scacq_screl + */ +uint64_t HSA_API HSA_DEPRECATED +hsa_queue_add_write_index_acq_rel(const hsa_queue_t *queue, uint64_t value); + +/** + * @brief Atomically increment the write index of a queue by an offset. + * + * @param[in] queue Pointer to a queue. + * + * @param[in] value Value to add to the write index. + * + * @return Previous value of the write index. + */ +uint64_t HSA_API hsa_queue_add_write_index_scacq_screl(const hsa_queue_t *queue, + uint64_t value); + +/** + * @deprecated Renamed as ::hsa_queue_add_write_index_scacquire. + * + * @copydoc hsa_queue_add_write_index_scacquire + */ +uint64_t HSA_API HSA_DEPRECATED +hsa_queue_add_write_index_acquire(const hsa_queue_t *queue, uint64_t value); + +/** + * @copydoc hsa_queue_add_write_index_scacq_screl + */ +uint64_t HSA_API hsa_queue_add_write_index_scacquire(const hsa_queue_t *queue, + uint64_t value); + +/** + * @copydoc hsa_queue_add_write_index_scacq_screl + */ +uint64_t HSA_API hsa_queue_add_write_index_relaxed(const hsa_queue_t *queue, + uint64_t value); + +/** + * @deprecated Renamed as ::hsa_queue_add_write_index_screlease. + * + * @copydoc hsa_queue_add_write_index_screlease + */ +uint64_t HSA_API HSA_DEPRECATED +hsa_queue_add_write_index_release(const hsa_queue_t *queue, uint64_t value); + +/** + * @copydoc hsa_queue_add_write_index_scacq_screl + */ +uint64_t HSA_API hsa_queue_add_write_index_screlease(const hsa_queue_t *queue, + uint64_t value); + +/** + * @brief Atomically set the read index of a queue. + * + * @details Modifications of the read index are not allowed and result in + * undefined behavior if the queue is associated with an agent for which + * only the corresponding packet processor is permitted to update the read + * index. + * + * @param[in] queue Pointer to a queue. + * + * @param[in] value Value to assign to the read index. + * + */ +void HSA_API hsa_queue_store_read_index_relaxed(const hsa_queue_t *queue, + uint64_t value); + +/** + * @deprecated Renamed as ::hsa_queue_store_read_index_screlease. + * + * @copydoc hsa_queue_store_read_index_screlease + */ +void HSA_API HSA_DEPRECATED +hsa_queue_store_read_index_release(const hsa_queue_t *queue, uint64_t value); + +/** + * @copydoc hsa_queue_store_read_index_relaxed + */ +void HSA_API hsa_queue_store_read_index_screlease(const hsa_queue_t *queue, + uint64_t value); +/** @} */ + +/** \defgroup aql Architected Queuing Language + * @{ + */ + +/** + * @brief Packet type. + */ +typedef enum { + /** + * Vendor-specific packet. + */ + HSA_PACKET_TYPE_VENDOR_SPECIFIC = 0, + /** + * The packet has been processed in the past, but has not been reassigned to + * the packet processor. A packet processor must not process a packet of this + * type. All queues support this packet type. + */ + HSA_PACKET_TYPE_INVALID = 1, + /** + * Packet used by agents for dispatching jobs to kernel agents. Not all + * queues support packets of this type (see ::hsa_queue_feature_t). + */ + HSA_PACKET_TYPE_KERNEL_DISPATCH = 2, + /** + * Packet used by agents to delay processing of subsequent packets, and to + * express complex dependencies between multiple packets. All queues support + * this packet type. + */ + HSA_PACKET_TYPE_BARRIER_AND = 3, + /** + * Packet used by agents for dispatching jobs to agents. Not all + * queues support packets of this type (see ::hsa_queue_feature_t). + */ + HSA_PACKET_TYPE_AGENT_DISPATCH = 4, + /** + * Packet used by agents to delay processing of subsequent packets, and to + * express complex dependencies between multiple packets. All queues support + * this packet type. + */ + HSA_PACKET_TYPE_BARRIER_OR = 5 +} hsa_packet_type_t; + +/** + * @brief Scope of the memory fence operation associated with a packet. + */ +typedef enum { + /** + * No scope (no fence is applied). The packet relies on external fences to + * ensure visibility of memory updates. + */ + HSA_FENCE_SCOPE_NONE = 0, + /** + * The fence is applied with agent scope for the global segment. + */ + HSA_FENCE_SCOPE_AGENT = 1, + /** + * The fence is applied across both agent and system scope for the global + * segment. + */ + HSA_FENCE_SCOPE_SYSTEM = 2 +} hsa_fence_scope_t; + +/** + * @brief Sub-fields of the @a header field that is present in any AQL + * packet. The offset (with respect to the address of @a header) of a sub-field + * is identical to its enumeration constant. The width of each sub-field is + * determined by the corresponding value in ::hsa_packet_header_width_t. The + * offset and the width are expressed in bits. + */ +typedef enum { + /** + * Packet type. The value of this sub-field must be one of + * ::hsa_packet_type_t. If the type is ::HSA_PACKET_TYPE_VENDOR_SPECIFIC, the + * packet layout is vendor-specific. + */ + HSA_PACKET_HEADER_TYPE = 0, + /** + * Barrier bit. If the barrier bit is set, the processing of the current + * packet only launches when all preceding packets (within the same queue) are + * complete. + */ + HSA_PACKET_HEADER_BARRIER = 8, + /** + * Acquire fence scope. The value of this sub-field determines the scope and + * type of the memory fence operation applied before the packet enters the + * active phase. An acquire fence ensures that any subsequent global segment + * or image loads by any unit of execution that belongs to a dispatch that has + * not yet entered the active phase on any queue of the same kernel agent, + * sees any data previously released at the scopes specified by the acquire + * fence. The value of this sub-field must be one of ::hsa_fence_scope_t. + */ + HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE = 9, + /** + * @deprecated Renamed as ::HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE. + */ + HSA_PACKET_HEADER_ACQUIRE_FENCE_SCOPE = 9, + /** + * Release fence scope, The value of this sub-field determines the scope and + * type of the memory fence operation applied after kernel completion but + * before the packet is completed. A release fence makes any global segment or + * image data that was stored by any unit of execution that belonged to a + * dispatch that has completed the active phase on any queue of the same + * kernel agent visible in all the scopes specified by the release fence. The + * value of this sub-field must be one of ::hsa_fence_scope_t. + */ + HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE = 11, + /** + * @deprecated Renamed as ::HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE. + */ + HSA_PACKET_HEADER_RELEASE_FENCE_SCOPE = 11 +} hsa_packet_header_t; + +/** + * @brief Width (in bits) of the sub-fields in ::hsa_packet_header_t. + */ +typedef enum { + HSA_PACKET_HEADER_WIDTH_TYPE = 8, + HSA_PACKET_HEADER_WIDTH_BARRIER = 1, + HSA_PACKET_HEADER_WIDTH_SCACQUIRE_FENCE_SCOPE = 2, + /** + * @deprecated Use HSA_PACKET_HEADER_WIDTH_SCACQUIRE_FENCE_SCOPE. + */ + HSA_PACKET_HEADER_WIDTH_ACQUIRE_FENCE_SCOPE = 2, + HSA_PACKET_HEADER_WIDTH_SCRELEASE_FENCE_SCOPE = 2, + /** + * @deprecated Use HSA_PACKET_HEADER_WIDTH_SCRELEASE_FENCE_SCOPE. + */ + HSA_PACKET_HEADER_WIDTH_RELEASE_FENCE_SCOPE = 2 +} hsa_packet_header_width_t; + +/** + * @brief Sub-fields of the kernel dispatch packet @a setup field. The offset + * (with respect to the address of @a setup) of a sub-field is identical to its + * enumeration constant. The width of each sub-field is determined by the + * corresponding value in ::hsa_kernel_dispatch_packet_setup_width_t. The + * offset and the width are expressed in bits. + */ +typedef enum { + /** + * Number of dimensions of the grid. Valid values are 1, 2, or 3. + * + */ + HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS = 0 +} hsa_kernel_dispatch_packet_setup_t; + +/** + * @brief Width (in bits) of the sub-fields in + * ::hsa_kernel_dispatch_packet_setup_t. + */ +typedef enum { + HSA_KERNEL_DISPATCH_PACKET_SETUP_WIDTH_DIMENSIONS = 2 +} hsa_kernel_dispatch_packet_setup_width_t; + +/** + * @brief AQL kernel dispatch packet + */ +typedef struct hsa_kernel_dispatch_packet_s { + /** + * Packet header. Used to configure multiple packet parameters such as the + * packet type. The parameters are described by ::hsa_packet_header_t. + */ + uint16_t header; + + /** + * Dispatch setup parameters. Used to configure kernel dispatch parameters + * such as the number of dimensions in the grid. The parameters are described + * by ::hsa_kernel_dispatch_packet_setup_t. + */ + uint16_t setup; + + /** + * X dimension of work-group, in work-items. Must be greater than 0. + */ + uint16_t workgroup_size_x; + + /** + * Y dimension of work-group, in work-items. Must be greater than + * 0. If the grid has 1 dimension, the only valid value is 1. + */ + uint16_t workgroup_size_y; + + /** + * Z dimension of work-group, in work-items. Must be greater than + * 0. If the grid has 1 or 2 dimensions, the only valid value is 1. + */ + uint16_t workgroup_size_z; + + /** + * Reserved. Must be 0. + */ + uint16_t reserved0; + + /** + * X dimension of grid, in work-items. Must be greater than 0. Must + * not be smaller than @a workgroup_size_x. + */ + uint32_t grid_size_x; + + /** + * Y dimension of grid, in work-items. Must be greater than 0. If the grid has + * 1 dimension, the only valid value is 1. Must not be smaller than @a + * workgroup_size_y. + */ + uint32_t grid_size_y; + + /** + * Z dimension of grid, in work-items. Must be greater than 0. If the grid has + * 1 or 2 dimensions, the only valid value is 1. Must not be smaller than @a + * workgroup_size_z. + */ + uint32_t grid_size_z; + + /** + * Size in bytes of private memory allocation request (per work-item). + */ + uint32_t private_segment_size; + + /** + * Size in bytes of group memory allocation request (per work-group). Must not + * be less than the sum of the group memory used by the kernel (and the + * functions it calls directly or indirectly) and the dynamically allocated + * group segment variables. + */ + uint32_t group_segment_size; + + /** + * Opaque handle to a code object that includes an implementation-defined + * executable code for the kernel. + */ + uint64_t kernel_object; + +#ifdef HSA_LARGE_MODEL + void *kernarg_address; +#elif defined HSA_LITTLE_ENDIAN + /** + * Pointer to a buffer containing the kernel arguments. May be NULL. + * + * The buffer must be allocated using ::hsa_memory_allocate, and must not be + * modified once the kernel dispatch packet is enqueued until the dispatch has + * completed execution. + */ + void *kernarg_address; + /** + * Reserved. Must be 0. + */ + uint32_t reserved1; +#else + uint32_t reserved1; + void *kernarg_address; +#endif + + /** + * Reserved. Must be 0. + */ + uint64_t reserved2; + + /** + * Signal used to indicate completion of the job. The application can use the + * special signal handle 0 to indicate that no signal is used. + */ + hsa_signal_t completion_signal; + +} hsa_kernel_dispatch_packet_t; + +/** + * @brief Agent dispatch packet. + */ +typedef struct hsa_agent_dispatch_packet_s { + /** + * Packet header. Used to configure multiple packet parameters such as the + * packet type. The parameters are described by ::hsa_packet_header_t. + */ + uint16_t header; + + /** + * Application-defined function to be performed by the destination agent. + */ + uint16_t type; + + /** + * Reserved. Must be 0. + */ + uint32_t reserved0; + +#ifdef HSA_LARGE_MODEL + void *return_address; +#elif defined HSA_LITTLE_ENDIAN + /** + * Address where to store the function return values, if any. + */ + void *return_address; + /** + * Reserved. Must be 0. + */ + uint32_t reserved1; +#else + uint32_t reserved1; + void *return_address; +#endif + + /** + * Function arguments. + */ + uint64_t arg[4]; + + /** + * Reserved. Must be 0. + */ + uint64_t reserved2; + + /** + * Signal used to indicate completion of the job. The application can use the + * special signal handle 0 to indicate that no signal is used. + */ + hsa_signal_t completion_signal; + +} hsa_agent_dispatch_packet_t; + +/** + * @brief Barrier-AND packet. + */ +typedef struct hsa_barrier_and_packet_s { + /** + * Packet header. Used to configure multiple packet parameters such as the + * packet type. The parameters are described by ::hsa_packet_header_t. + */ + uint16_t header; + + /** + * Reserved. Must be 0. + */ + uint16_t reserved0; + + /** + * Reserved. Must be 0. + */ + uint32_t reserved1; + + /** + * Array of dependent signal objects. Signals with a handle value of 0 are + * allowed and are interpreted by the packet processor as satisfied + * dependencies. + */ + hsa_signal_t dep_signal[5]; + + /** + * Reserved. Must be 0. + */ + uint64_t reserved2; + + /** + * Signal used to indicate completion of the job. The application can use the + * special signal handle 0 to indicate that no signal is used. + */ + hsa_signal_t completion_signal; + +} hsa_barrier_and_packet_t; + +/** + * @brief Barrier-OR packet. + */ +typedef struct hsa_barrier_or_packet_s { + /** + * Packet header. Used to configure multiple packet parameters such as the + * packet type. The parameters are described by ::hsa_packet_header_t. + */ + uint16_t header; + + /** + * Reserved. Must be 0. + */ + uint16_t reserved0; + + /** + * Reserved. Must be 0. + */ + uint32_t reserved1; + + /** + * Array of dependent signal objects. Signals with a handle value of 0 are + * allowed and are interpreted by the packet processor as dependencies not + * satisfied. + */ + hsa_signal_t dep_signal[5]; + + /** + * Reserved. Must be 0. + */ + uint64_t reserved2; + + /** + * Signal used to indicate completion of the job. The application can use the + * special signal handle 0 to indicate that no signal is used. + */ + hsa_signal_t completion_signal; + +} hsa_barrier_or_packet_t; + +/** @} */ + +/** \addtogroup memory Memory + * @{ + */ + +/** + * @brief Memory segments associated with a region. + */ +typedef enum { + /** + * Global segment. Used to hold data that is shared by all agents. + */ + HSA_REGION_SEGMENT_GLOBAL = 0, + /** + * Read-only segment. Used to hold data that remains constant during the + * execution of a kernel. + */ + HSA_REGION_SEGMENT_READONLY = 1, + /** + * Private segment. Used to hold data that is local to a single work-item. + */ + HSA_REGION_SEGMENT_PRIVATE = 2, + /** + * Group segment. Used to hold data that is shared by the work-items of a + * work-group. + */ + HSA_REGION_SEGMENT_GROUP = 3, + /** + * Kernarg segment. Used to store kernel arguments. + */ + HSA_REGION_SEGMENT_KERNARG = 4 +} hsa_region_segment_t; + +/** + * @brief Global region flags. + */ +typedef enum { + /** + * The application can use memory in the region to store kernel arguments, and + * provide the values for the kernarg segment of a kernel dispatch. If this + * flag is set, then ::HSA_REGION_GLOBAL_FLAG_FINE_GRAINED must be set. + */ + HSA_REGION_GLOBAL_FLAG_KERNARG = 1, + /** + * Updates to memory in this region are immediately visible to all the + * agents under the terms of the HSA memory model. If this + * flag is set, then ::HSA_REGION_GLOBAL_FLAG_COARSE_GRAINED must not be set. + */ + HSA_REGION_GLOBAL_FLAG_FINE_GRAINED = 2, + /** + * Updates to memory in this region can be performed by a single agent at + * a time. If a different agent in the system is allowed to access the + * region, the application must explicitely invoke ::hsa_memory_assign_agent + * in order to transfer ownership to that agent for a particular buffer. + */ + HSA_REGION_GLOBAL_FLAG_COARSE_GRAINED = 4, + + /** + * Updates to memory in this region have extended scope, where the + * device-scope atomics to this memory type act as system-scope with respect + * to all variables located in memory regions of this type. Note: On + * non-compliant systems, the application may still be responsible for + * performing device-specific actions necessary to achieve system-scope + * coherence. + */ + HSA_REGION_GLOBAL_FLAG_EXTENDED_SCOPE_FINE_GRAINED = 8 +} hsa_region_global_flag_t; + +/** + * @brief Attributes of a memory region. + */ +typedef enum { + /** + * Segment where memory in the region can be used. The type of this + * attribute is ::hsa_region_segment_t. + */ + HSA_REGION_INFO_SEGMENT = 0, + /** + * Flag mask. The value of this attribute is undefined if the value of + * ::HSA_REGION_INFO_SEGMENT is not ::HSA_REGION_SEGMENT_GLOBAL. The type of + * this attribute is uint32_t, a bit-field of ::hsa_region_global_flag_t + * values. + */ + HSA_REGION_INFO_GLOBAL_FLAGS = 1, + /** + * Size of this region, in bytes. The type of this attribute is size_t. + */ + HSA_REGION_INFO_SIZE = 2, + /** + * Maximum allocation size in this region, in bytes. Must not exceed the value + * of ::HSA_REGION_INFO_SIZE. The type of this attribute is size_t. + * + * If the region is in the global or readonly segments, this is the maximum + * size that the application can pass to ::hsa_memory_allocate. + * + * If the region is in the group segment, this is the maximum size (per + * work-group) that can be requested for a given kernel dispatch. If the + * region is in the private segment, this is the maximum size (per work-item) + * that can be requested for a specific kernel dispatch, and must be at least + * 256 bytes. + */ + HSA_REGION_INFO_ALLOC_MAX_SIZE = 4, + /** + * Maximum size (per work-group) of private memory that can be requested for a + * specific kernel dispatch. Must be at least 65536 bytes. The type of this + * attribute is uint32_t. The value of this attribute is undefined if the + * region is not in the private segment. + */ + HSA_REGION_INFO_ALLOC_MAX_PRIVATE_WORKGROUP_SIZE = 8, + /** + * Indicates whether memory in this region can be allocated using + * ::hsa_memory_allocate. The type of this attribute is bool. + * + * The value of this flag is always false for regions in the group and private + * segments. + */ + HSA_REGION_INFO_RUNTIME_ALLOC_ALLOWED = 5, + /** + * Allocation granularity of buffers allocated by ::hsa_memory_allocate in + * this region. The size of a buffer allocated in this region is a multiple of + * the value of this attribute. The value of this attribute is only defined if + * ::HSA_REGION_INFO_RUNTIME_ALLOC_ALLOWED is true for this region. The type + * of this attribute is size_t. + */ + HSA_REGION_INFO_RUNTIME_ALLOC_GRANULE = 6, + /** + * Alignment of buffers allocated by ::hsa_memory_allocate in this region. The + * value of this attribute is only defined if + * ::HSA_REGION_INFO_RUNTIME_ALLOC_ALLOWED is true for this region, and must + * be a power of 2. The type of this attribute is size_t. + */ + HSA_REGION_INFO_RUNTIME_ALLOC_ALIGNMENT = 7 +} hsa_region_info_t; + +/** + * @brief Get the current value of an attribute of a region. + * + * @param[in] region A valid region. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to a application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_REGION The region is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * region attribute, or @p value is NULL. + */ +hsa_status_t HSA_API hsa_region_get_info(hsa_region_t region, + hsa_region_info_t attribute, + void *value); + +/** + * @brief Iterate over the memory regions associated with a given agent, and + * invoke an application-defined callback on every iteration. + * + * @param[in] agent A valid agent. + * + * @param[in] callback Callback to be invoked once per region that is + * accessible from the agent. The HSA runtime passes two arguments to the + * callback, the region and the application data. If @p callback returns a + * status other than ::HSA_STATUS_SUCCESS for a particular iteration, the + * traversal stops and ::hsa_agent_iterate_regions returns that status value. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t HSA_API hsa_agent_iterate_regions( + hsa_agent_t agent, + hsa_status_t (*callback)(hsa_region_t region, void *data), void *data); + +/** + * @brief Allocate a block of memory in a given region. + * + * @param[in] region Region where to allocate memory from. The region must have + * the ::HSA_REGION_INFO_RUNTIME_ALLOC_ALLOWED flag set. + * + * @param[in] size Allocation size, in bytes. Must not be zero. This value is + * rounded up to the nearest multiple of ::HSA_REGION_INFO_RUNTIME_ALLOC_GRANULE + * in @p region. + * + * @param[out] ptr Pointer to the location where to store the base address of + * the allocated block. The returned base address is aligned to the value of + * ::HSA_REGION_INFO_RUNTIME_ALLOC_ALIGNMENT in @p region. If the allocation + * fails, the returned value is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_REGION The region is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION The host is not allowed to + * allocate memory in @p region, or @p size is greater than the value of + * HSA_REGION_INFO_ALLOC_MAX_SIZE in @p region. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p ptr is NULL, or @p size is 0. + */ +hsa_status_t HSA_API hsa_memory_allocate(hsa_region_t region, size_t size, + void **ptr); + +/** + * @brief Deallocate a block of memory previously allocated using + * ::hsa_memory_allocate. + * + * @param[in] ptr Pointer to a memory block. If @p ptr does not match a value + * previously returned by ::hsa_memory_allocate, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + */ +hsa_status_t HSA_API hsa_memory_free(void *ptr); + +/** + * @brief Copy a block of memory from the location pointed to by @p src to the + * memory block pointed to by @p dst. + * + * @param[out] dst Buffer where the content is to be copied. If @p dst is in + * coarse-grained memory, the copied data is only visible to the agent currently + * assigned (::hsa_memory_assign_agent) to @p dst. + * + * @param[in] src A valid pointer to the source of data to be copied. The source + * buffer must not overlap with the destination buffer. If the source buffer is + * in coarse-grained memory then it must be assigned to an agent, from which the + * data will be retrieved. + * + * @param[in] size Number of bytes to copy. If @p size is 0, no copy is + * performed and the function returns success. Copying a number of bytes larger + * than the size of the buffers pointed by @p dst or @p src results in undefined + * behavior. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT The source or destination + * pointers are NULL. + */ +hsa_status_t HSA_API hsa_memory_copy(void *dst, const void *src, size_t size); + +/** + * @brief Change the ownership of a global, coarse-grained buffer. + * + * @details The contents of a coarse-grained buffer are visible to an agent + * only after ownership has been explicitely transferred to that agent. Once the + * operation completes, the previous owner cannot longer access the data in the + * buffer. + * + * An implementation of the HSA runtime is allowed, but not required, to change + * the physical location of the buffer when ownership is transferred to a + * different agent. In general the application must not assume this + * behavior. The virtual location (address) of the passed buffer is never + * modified. + * + * @param[in] ptr Base address of a global buffer. The pointer must match an + * address previously returned by ::hsa_memory_allocate. The size of the buffer + * affected by the ownership change is identical to the size of that previous + * allocation. If @p ptr points to a fine-grained global buffer, no operation is + * performed and the function returns success. If @p ptr does not point to + * global memory, the behavior is undefined. + * + * @param[in] agent Agent that becomes the owner of the buffer. The + * application is responsible for ensuring that @p agent has access to the + * region that contains the buffer. It is allowed to change ownership to an + * agent that is already the owner of the buffer, with the same or different + * access permissions. + * + * @param[in] access Access permissions requested for the new owner. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p ptr is NULL, or @p access is + * not a valid access value. + */ +hsa_status_t HSA_API hsa_memory_assign_agent(void *ptr, hsa_agent_t agent, + hsa_access_permission_t access); + +/** + * + * @brief Register a global, fine-grained buffer. + * + * @details Registering a buffer serves as an indication to the HSA runtime that + * the memory might be accessed from a kernel agent other than the + * host. Registration is a performance hint that allows the HSA runtime + * implementation to know which buffers will be accessed by some of the kernel + * agents ahead of time. + * + * Registration is only recommended for buffers in the global segment that have + * not been allocated using the HSA allocator (::hsa_memory_allocate), but an OS + * allocator instead. Registering an OS-allocated buffer in the base profile is + * equivalent to a no-op. + * + * Registrations should not overlap. + * + * @param[in] ptr A buffer in global, fine-grained memory. If a NULL pointer is + * passed, no operation is performed. If the buffer has been allocated using + * ::hsa_memory_allocate, or has already been registered, no operation is + * performed. + * + * @param[in] size Requested registration size in bytes. A size of 0 is + * only allowed if @p ptr is NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p size is 0 but @p ptr + * is not NULL. + */ +hsa_status_t HSA_API hsa_memory_register(void *ptr, size_t size); + +/** + * + * @brief Deregister memory previously registered using ::hsa_memory_register. + * + * @details If the memory interval being deregistered does not match a previous + * registration (start and end addresses), the behavior is undefined. + * + * @param[in] ptr A pointer to the base of the buffer to be deregistered. If + * a NULL pointer is passed, no operation is performed. + * + * @param[in] size Size of the buffer to be deregistered. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + */ +hsa_status_t HSA_API hsa_memory_deregister(void *ptr, size_t size); + +/** @} */ + +/** \defgroup instruction-set-architecture Instruction Set Architecture. + * @{ + */ + +/** + * @brief Instruction set architecture. + */ +typedef struct hsa_isa_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_isa_t; + +/** + * @brief Retrieve a reference to an instruction set architecture handle out of + * a symbolic name. + * + * @param[in] name Vendor-specific name associated with a a particular + * instruction set architecture. @p name must start with the vendor name and a + * colon (for example, "AMD:"). The rest of the name is vendor-specific. Must be + * a NUL-terminated string. + * + * @param[out] isa Memory location where the HSA runtime stores the ISA handle + * corresponding to the given name. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ISA_NAME The given name does not + * correspond to any instruction set architecture. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p name is NULL, or @p isa is + * NULL. + */ +hsa_status_t HSA_API hsa_isa_from_name(const char *name, hsa_isa_t *isa); + +/** + * @brief Iterate over the instruction sets supported by the given agent, and + * invoke an application-defined callback on every iteration. The iterator is + * deterministic: if an agent supports several instruction set architectures, + * they are traversed in the same order in every invocation of this function. + * + * @param[in] agent A valid agent. + * + * @param[in] callback Callback to be invoked once per instruction set + * architecture. The HSA runtime passes two arguments to the callback: the + * ISA and the application data. If @p callback returns a status other than + * ::HSA_STATUS_SUCCESS for a particular iteration, the traversal stops and + * that status value is returned. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t HSA_API hsa_agent_iterate_isas( + hsa_agent_t agent, hsa_status_t (*callback)(hsa_isa_t isa, void *data), + void *data); + +/** + * @brief Instruction set architecture attributes. + */ +typedef enum { + /** + * The length of the ISA name in bytes, not including the NUL terminator. The + * type of this attribute is uint32_t. + */ + HSA_ISA_INFO_NAME_LENGTH = 0, + /** + * Human-readable description. The type of this attribute is character array + * with the length equal to the value of ::HSA_ISA_INFO_NAME_LENGTH attribute. + */ + HSA_ISA_INFO_NAME = 1, + /** + * @deprecated + * + * Number of call conventions supported by the instruction set architecture. + * Must be greater than zero. The type of this attribute is uint32_t. + */ + HSA_ISA_INFO_CALL_CONVENTION_COUNT = 2, + /** + * @deprecated + * + * Number of work-items in a wavefront for a given call convention. Must be a + * power of 2 in the range [1,256]. The type of this attribute is uint32_t. + */ + HSA_ISA_INFO_CALL_CONVENTION_INFO_WAVEFRONT_SIZE = 3, + /** + * @deprecated + * + * Number of wavefronts per compute unit for a given call convention. In + * practice, other factors (for example, the amount of group memory used by a + * work-group) may further limit the number of wavefronts per compute + * unit. The type of this attribute is uint32_t. + */ + HSA_ISA_INFO_CALL_CONVENTION_INFO_WAVEFRONTS_PER_COMPUTE_UNIT = 4, + /** + * Machine models supported by the instruction set architecture. The type of + * this attribute is a bool[2]. If the ISA supports the small machine model, + * the element at index ::HSA_MACHINE_MODEL_SMALL is true. If the ISA supports + * the large model, the element at index ::HSA_MACHINE_MODEL_LARGE is true. + */ + HSA_ISA_INFO_MACHINE_MODELS = 5, + /** + * Profiles supported by the instruction set architecture. The type of this + * attribute is a bool[2]. If the ISA supports the base profile, the element + * at index ::HSA_PROFILE_BASE is true. If the ISA supports the full profile, + * the element at index ::HSA_PROFILE_FULL is true. + */ + HSA_ISA_INFO_PROFILES = 6, + /** + * Default floating-point rounding modes supported by the instruction set + * architecture. The type of this attribute is a bool[3]. The value at a given + * index is true if the corresponding rounding mode in + * ::hsa_default_float_rounding_mode_t is supported. At least one default mode + * has to be supported. + * + * If the default mode is supported, then + * ::HSA_ISA_INFO_BASE_PROFILE_DEFAULT_FLOAT_ROUNDING_MODES must report that + * both the zero and the near roundings modes are supported. + */ + HSA_ISA_INFO_DEFAULT_FLOAT_ROUNDING_MODES = 7, + /** + * Default floating-point rounding modes supported by the instruction set + * architecture in the Base profile. The type of this attribute is a + * bool[3]. The value at a given index is true if the corresponding rounding + * mode in ::hsa_default_float_rounding_mode_t is supported. The value at + * index HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT must be false. At least one + * of the values at indexes ::HSA_DEFAULT_FLOAT_ROUNDING_MODE_ZERO or + * HSA_DEFAULT_FLOAT_ROUNDING_MODE_NEAR must be true. + */ + HSA_ISA_INFO_BASE_PROFILE_DEFAULT_FLOAT_ROUNDING_MODES = 8, + /** + * Flag indicating that the f16 HSAIL operation is at least as fast as the + * f32 operation in the instruction set architecture. The type of this + * attribute is bool. + */ + HSA_ISA_INFO_FAST_F16_OPERATION = 9, + /** + * Maximum number of work-items of each dimension of a work-group. Each + * maximum must be greater than 0. No maximum can exceed the value of + * ::HSA_ISA_INFO_WORKGROUP_MAX_SIZE. The type of this attribute is + * uint16_t[3]. + */ + HSA_ISA_INFO_WORKGROUP_MAX_DIM = 12, + /** + * Maximum total number of work-items in a work-group. The type + * of this attribute is uint32_t. + */ + HSA_ISA_INFO_WORKGROUP_MAX_SIZE = 13, + /** + * Maximum number of work-items of each dimension of a grid. Each maximum must + * be greater than 0, and must not be smaller than the corresponding value in + * ::HSA_ISA_INFO_WORKGROUP_MAX_DIM. No maximum can exceed the value of + * ::HSA_ISA_INFO_GRID_MAX_SIZE. The type of this attribute is + * ::hsa_dim3_t. + */ + HSA_ISA_INFO_GRID_MAX_DIM = 14, + /** + * Maximum total number of work-items in a grid. The type of this + * attribute is uint64_t. + */ + HSA_ISA_INFO_GRID_MAX_SIZE = 16, + /** + * Maximum number of fbarriers per work-group. Must be at least 32. The + * type of this attribute is uint32_t. + */ + HSA_ISA_INFO_FBARRIER_MAX_SIZE = 17 +} hsa_isa_info_t; + +/** + * @deprecated The concept of call convention has been deprecated. If the + * application wants to query the value of an attribute for a given instruction + * set architecture, use ::hsa_isa_get_info_alt instead. If the application + * wants to query an attribute that is specific to a given combination of ISA + * and wavefront, use ::hsa_wavefront_get_info. + * + * @brief Get the current value of an attribute for a given instruction set + * architecture (ISA). + * + * @param[in] isa A valid instruction set architecture. + * + * @param[in] attribute Attribute to query. + * + * @param[in] index Call convention index. Used only for call convention + * attributes, otherwise ignored. Must have a value between 0 (inclusive) and + * the value of the attribute ::HSA_ISA_INFO_CALL_CONVENTION_COUNT (not + * inclusive) in @p isa. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ISA The instruction set architecture is + * invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_INDEX The index is out of range. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * instruction set architecture attribute, or @p value is + * NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_isa_get_info(hsa_isa_t isa, + hsa_isa_info_t attribute, + uint32_t index, + void *value); + +/** + * @brief Get the current value of an attribute for a given instruction set + * architecture (ISA). + * + * @param[in] isa A valid instruction set architecture. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ISA The instruction set architecture is + * invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * instruction set architecture attribute, or @p value is + * NULL. + */ +hsa_status_t HSA_API hsa_isa_get_info_alt(hsa_isa_t isa, + hsa_isa_info_t attribute, + void *value); + +/** + * @brief Retrieve the exception policy support for a given combination of + * instruction set architecture and profile. + * + * @param[in] isa A valid instruction set architecture. + * + * @param[in] profile Profile. + * + * @param[out] mask Pointer to a memory location where the HSA runtime stores a + * mask of ::hsa_exception_policy_t values. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ISA The instruction set architecture is + * invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p profile is not a valid + * profile, or @p mask is NULL. + */ +hsa_status_t HSA_API hsa_isa_get_exception_policies(hsa_isa_t isa, + hsa_profile_t profile, + uint16_t *mask); + +/** + * @brief Floating-point types. + */ +typedef enum { + /** + * 16-bit floating-point type. + */ + HSA_FP_TYPE_16 = 1, + /** + * 32-bit floating-point type. + */ + HSA_FP_TYPE_32 = 2, + /** + * 64-bit floating-point type. + */ + HSA_FP_TYPE_64 = 4 +} hsa_fp_type_t; + +/** + * @brief Flush to zero modes. + */ +typedef enum { + /** + * Flush to zero. + */ + HSA_FLUSH_MODE_FTZ = 1, + /** + * Do not flush to zero. + */ + HSA_FLUSH_MODE_NON_FTZ = 2 +} hsa_flush_mode_t; + +/** + * @brief Round methods. + */ +typedef enum { + /** + * Single round method. + */ + HSA_ROUND_METHOD_SINGLE = 1, + /** + * Double round method. + */ + HSA_ROUND_METHOD_DOUBLE = 2 +} hsa_round_method_t; + +/** + * @brief Retrieve the round method (single or double) used to implement the + * floating-point multiply add instruction (mad) for a given combination of + * instruction set architecture, floating-point type, and flush to zero + * modifier. + * + * @param[in] isa Instruction set architecture. + * + * @param[in] fp_type Floating-point type. + * + * @param[in] flush_mode Flush to zero modifier. + * + * @param[out] round_method Pointer to a memory location where the HSA + * runtime stores the round method used by the implementation. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ISA The instruction set architecture is + * invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p fp_type is not a valid + * floating-point type, or @p flush_mode is not a valid flush to zero modifier, + * or @p round_method is NULL. + */ +hsa_status_t HSA_API hsa_isa_get_round_method(hsa_isa_t isa, + hsa_fp_type_t fp_type, + hsa_flush_mode_t flush_mode, + hsa_round_method_t *round_method); + +/** + * @brief Wavefront handle + */ +typedef struct hsa_wavefront_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_wavefront_t; + +/** + * @brief Wavefront attributes. + */ +typedef enum { + /** + * Number of work-items in the wavefront. Must be a power of 2 in the range + * [1,256]. The type of this attribute is uint32_t. + */ + HSA_WAVEFRONT_INFO_SIZE = 0 +} hsa_wavefront_info_t; + +/** + * @brief Get the current value of a wavefront attribute. + * + * @param[in] wavefront A wavefront. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_WAVEFRONT The wavefront is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * wavefront attribute, or @p value is NULL. + */ +hsa_status_t HSA_API hsa_wavefront_get_info(hsa_wavefront_t wavefront, + hsa_wavefront_info_t attribute, + void *value); + +/** + * @brief Iterate over the different wavefronts supported by an instruction set + * architecture, and invoke an application-defined callback on every iteration. + * + * @param[in] isa Instruction set architecture. + * + * @param[in] callback Callback to be invoked once per wavefront that is + * supported by the agent. The HSA runtime passes two arguments to the callback: + * the wavefront handle and the application data. If @p callback returns a + * status other than ::HSA_STATUS_SUCCESS for a particular iteration, the + * traversal stops and that value is returned. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ISA The instruction set architecture is + * invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t HSA_API hsa_isa_iterate_wavefronts( + hsa_isa_t isa, + hsa_status_t (*callback)(hsa_wavefront_t wavefront, void *data), + void *data); + +/** + * @deprecated Use ::hsa_agent_iterate_isas to query which instructions set + * architectures are supported by a given agent. + * + * @brief Check if the instruction set architecture of a code object can be + * executed on an agent associated with another architecture. + * + * @param[in] code_object_isa Instruction set architecture associated with a + * code object. + * + * @param[in] agent_isa Instruction set architecture associated with an agent. + * + * @param[out] result Pointer to a memory location where the HSA runtime stores + * the result of the check. If the two architectures are compatible, the result + * is true; if they are incompatible, the result is false. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ISA @p code_object_isa or @p agent_isa are + * invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p result is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_isa_compatible( + hsa_isa_t code_object_isa, hsa_isa_t agent_isa, bool *result); + +/** @} */ + +/** \defgroup executable Executable + * @{ + */ + +/** + * @brief Code object reader handle. A code object reader is used to + * load a code object from file (when created using + * ::hsa_code_object_reader_create_from_file), or from memory (if created using + * ::hsa_code_object_reader_create_from_memory). + */ +typedef struct hsa_code_object_reader_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_code_object_reader_t; + +/** + * @brief Create a code object reader to operate on a file. + * + * @param[in] file File descriptor. The file must have been opened by + * application with at least read permissions prior calling this function. The + * file must contain a vendor-specific code object. + * + * The file is owned and managed by the application; the lifetime of the file + * descriptor must exceed that of any associated code object reader. + * + * @param[out] code_object_reader Memory location to store the newly created + * code object reader handle. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_FILE @p file is invalid. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p code_object_reader is NULL. + */ +hsa_status_t HSA_API hsa_code_object_reader_create_from_file( + hsa_file_t file, hsa_code_object_reader_t *code_object_reader); + +/** + * @brief Create a code object reader to operate on memory. + * + * @param[in] code_object Memory buffer that contains a vendor-specific code + * object. The buffer is owned and managed by the application; the lifetime of + * the buffer must exceed that of any associated code object reader. + * + * @param[in] size Size of the buffer pointed to by @p code_object. Must not be + * 0. + * + * @param[out] code_object_reader Memory location to store newly created code + * object reader handle. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p code_object is NULL, @p size + * is zero, or @p code_object_reader is NULL. + */ +hsa_status_t HSA_API hsa_code_object_reader_create_from_memory( + const void *code_object, size_t size, + hsa_code_object_reader_t *code_object_reader); + +/** + * @brief Destroy a code object reader. + * + * @details The code object reader handle becomes invalid after completion of + * this function. Any file or memory used to create the code object read is not + * closed, removed, or deallocated by this function. + * + * @param[in] code_object_reader Code object reader to destroy. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT_READER @p code_object_reader + * is invalid. + */ +hsa_status_t HSA_API +hsa_code_object_reader_destroy(hsa_code_object_reader_t code_object_reader); + +/** + * @brief Struct containing an opaque handle to an executable, which contains + * ISA for finalized kernels and indirect functions together with the allocated + * global or readonly segment variables they reference. + */ +typedef struct hsa_executable_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_executable_t; + +/** + * @brief Executable state. + */ +typedef enum { + /** + * Executable state, which allows the user to load code objects and define + * external variables. Variable addresses, kernel code handles, and + * indirect function code handles are not available in query operations until + * the executable is frozen (zero always returned). + */ + HSA_EXECUTABLE_STATE_UNFROZEN = 0, + /** + * Executable state, which allows the user to query variable addresses, + * kernel code handles, and indirect function code handles using query + * operations. Loading new code objects, as well as defining external + * variables, is not allowed in this state. + */ + HSA_EXECUTABLE_STATE_FROZEN = 1 +} hsa_executable_state_t; + +/** + * @deprecated Use ::hsa_executable_create_alt instead, which allows the + * application to specify the default floating-point rounding mode of the + * executable and assumes an unfrozen initial state. + * + * @brief Create an empty executable. + * + * @param[in] profile Profile used in the executable. + * + * @param[in] executable_state Executable state. If the state is + * ::HSA_EXECUTABLE_STATE_FROZEN, the resulting executable is useless because no + * code objects can be loaded, and no variables can be defined. + * + * @param[in] options Standard and vendor-specific options. Unknown options are + * ignored. A standard option begins with the "-hsa_" prefix. Options beginning + * with the "-hsa_ext__" prefix are reserved for extensions. A + * vendor-specific option begins with the "-_" prefix. Must be a + * NUL-terminated string. May be NULL. + * + * @param[out] executable Memory location where the HSA runtime stores the newly + * created executable handle. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p profile is invalid, or + * @p executable is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_executable_create( + hsa_profile_t profile, hsa_executable_state_t executable_state, + const char *options, hsa_executable_t *executable); + +/** + * @brief Create an empty executable. + * + * @param[in] profile Profile used in the executable. + * + * @param[in] default_float_rounding_mode Default floating-point rounding mode + * used in the executable. Allowed rounding modes are near and zero (default is + * not allowed). + * + * @param[in] options Standard and vendor-specific options. Unknown options are + * ignored. A standard option begins with the "-hsa_" prefix. Options beginning + * with the "-hsa_ext__" prefix are reserved for extensions. A + * vendor-specific option begins with the "-_" prefix. Must be a + * NUL-terminated string. May be NULL. + * + * @param[out] executable Memory location where the HSA runtime stores newly + * created executable handle. The initial state of the executable is + * ::HSA_EXECUTABLE_STATE_UNFROZEN. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p profile is invalid, or + * @p executable is NULL. + */ +hsa_status_t HSA_API hsa_executable_create_alt( + hsa_profile_t profile, + hsa_default_float_rounding_mode_t default_float_rounding_mode, + const char *options, hsa_executable_t *executable); + +/** + * @brief Destroy an executable. + * + * @details An executable handle becomes invalid after the executable has been + * destroyed. Code object handles that were loaded into this executable are + * still valid after the executable has been destroyed, and can be used as + * intended. Resources allocated outside and associated with this executable + * (such as external global or readonly variables) can be released after the + * executable has been destroyed. + * + * Executable should not be destroyed while kernels are in flight. + * + * @param[in] executable Executable. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + */ +hsa_status_t HSA_API hsa_executable_destroy(hsa_executable_t executable); + +/** + * @brief Loaded code object handle. + */ +typedef struct hsa_loaded_code_object_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_loaded_code_object_t; + +/** + * @brief Load a program code object into an executable. + * + * @details A program code object contains information about resources that are + * accessible by all kernel agents that run the executable, and can be loaded + * at most once into an executable. + * + * If the program code object uses extensions, the implementation must support + * them for this operation to return successfully. + * + * @param[in] executable Executable. + * + * @param[in] code_object_reader A code object reader that holds the program + * code object to load. If a code object reader is destroyed before all the + * associated executables are destroyed, the behavior is undefined. + * + * @param[in] options Standard and vendor-specific options. Unknown options are + * ignored. A standard option begins with the "-hsa_" prefix. Options beginning + * with the "-hsa_ext__" prefix are reserved for extensions. A + * vendor-specific option begins with the "-_" prefix. Must be a + * NUL-terminated string. May be NULL. + * + * @param[out] loaded_code_object Pointer to a memory location where the HSA + * runtime stores the loaded code object handle. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_FROZEN_EXECUTABLE The executable is frozen. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT_READER @p code_object_reader + * is invalid. + * + * @retval ::HSA_STATUS_ERROR_INCOMPATIBLE_ARGUMENTS The program code object is + * not compatible with the executable or the implementation (for example, the + * code object uses an extension that is not supported by the implementation). + */ +hsa_status_t HSA_API hsa_executable_load_program_code_object( + hsa_executable_t executable, hsa_code_object_reader_t code_object_reader, + const char *options, hsa_loaded_code_object_t *loaded_code_object); + +/** + * @brief Load an agent code object into an executable. + * + * @details The agent code object contains all defined agent + * allocation variables, functions, indirect functions, and kernels in a given + * program for a given instruction set architecture. + * + * Any module linkage declaration must have been defined either by a define + * variable or by loading a code object that has a symbol with module linkage + * definition. + * + * The default floating-point rounding mode of the code object associated with + * @p code_object_reader must match that of the executable + * (::HSA_EXECUTABLE_INFO_DEFAULT_FLOAT_ROUNDING_MODE), or be default (in which + * case the value of ::HSA_EXECUTABLE_INFO_DEFAULT_FLOAT_ROUNDING_MODE is used). + * If the agent code object uses extensions, the implementation and the agent + * must support them for this operation to return successfully. + * + * @param[in] executable Executable. + * + * @param[in] agent Agent to load code object for. A code object can be loaded + * into an executable at most once for a given agent. The instruction set + * architecture of the code object must be supported by the agent. + * + * @param[in] code_object_reader A code object reader that holds the code object + * to load. If a code object reader is destroyed before all the associated + * executables are destroyed, the behavior is undefined. + * + * @param[in] options Standard and vendor-specific options. Unknown options are + * ignored. A standard option begins with the "-hsa_" prefix. Options beginning + * with the "-hsa_ext__" prefix are reserved for extensions. A + * vendor-specific option begins with the "-_" prefix. Must be a + * NUL-terminated string. May be NULL. + * + * @param[out] loaded_code_object Pointer to a memory location where the HSA + * runtime stores the loaded code object handle. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_FROZEN_EXECUTABLE The executable is frozen. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT_READER @p code_object_reader + * is invalid. + * + * @retval ::HSA_STATUS_ERROR_INCOMPATIBLE_ARGUMENTS The code object read by @p + * code_object_reader is not compatible with the agent (for example, the agent + * does not support the instruction set architecture of the code object), the + * executable (for example, there is a default floating-point mode mismatch + * between the two), or the implementation. + */ +hsa_status_t HSA_API hsa_executable_load_agent_code_object( + hsa_executable_t executable, hsa_agent_t agent, + hsa_code_object_reader_t code_object_reader, const char *options, + hsa_loaded_code_object_t *loaded_code_object); + +/** + * @brief Freeze the executable. + * + * @details No modifications to executable can be made after freezing: no code + * objects can be loaded to the executable, and no external variables can be + * defined. Freezing the executable does not prevent querying the executable's + * attributes. The application must define all the external variables in an + * executable before freezing it. + * + * @param[in] executable Executable. + * + * @param[in] options Standard and vendor-specific options. Unknown options are + * ignored. A standard option begins with the "-hsa_" prefix. Options beginning + * with the "-hsa_ext__" prefix are reserved for extensions. A + * vendor-specific option begins with the "-_" prefix. Must be a + * NUL-terminated string. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_VARIABLE_UNDEFINED One or more variables are + * undefined in the executable. + * + * @retval ::HSA_STATUS_ERROR_FROZEN_EXECUTABLE @p executable is already frozen. + */ +hsa_status_t HSA_API hsa_executable_freeze(hsa_executable_t executable, + const char *options); + +/** + * @brief Executable attributes. + */ +typedef enum { + /** + * Profile this executable is created for. The type of this attribute is + * ::hsa_profile_t. + */ + HSA_EXECUTABLE_INFO_PROFILE = 1, + /** + * Executable state. The type of this attribute is ::hsa_executable_state_t. + */ + HSA_EXECUTABLE_INFO_STATE = 2, + /** + * Default floating-point rounding mode specified when executable was created. + * The type of this attribute is ::hsa_default_float_rounding_mode_t. + */ + HSA_EXECUTABLE_INFO_DEFAULT_FLOAT_ROUNDING_MODE = 3 +} hsa_executable_info_t; + +/** + * @brief Get the current value of an attribute for a given executable. + * + * @param[in] executable Executable. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * executable attribute, or @p value is NULL. + */ +hsa_status_t HSA_API hsa_executable_get_info(hsa_executable_t executable, + hsa_executable_info_t attribute, + void *value); + +/** + * @brief Define an external global variable with program allocation. + * + * @details This function allows the application to provide the definition + * of a variable in the global segment memory with program allocation. The + * variable must be defined before loading a code object into an executable. + * In addition, code objects loaded must not define the variable. + * + * @param[in] executable Executable. Must not be in frozen state. + * + * @param[in] variable_name Name of the variable. The Programmer's Reference + * Manual describes the standard name mangling scheme. + * + * @param[in] address Address where the variable is defined. This address must + * be in global memory and can be read and written by any agent in the + * system. The application cannot deallocate the buffer pointed by @p address + * before @p executable is destroyed. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_VARIABLE_ALREADY_DEFINED The variable is + * already defined. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SYMBOL_NAME There is no variable with the + * @p variable_name. + * + * @retval ::HSA_STATUS_ERROR_FROZEN_EXECUTABLE @p executable is frozen. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p variable_name is NULL. + */ +hsa_status_t HSA_API hsa_executable_global_variable_define( + hsa_executable_t executable, const char *variable_name, void *address); + +/** + * @brief Define an external global variable with agent allocation. + * + * @details This function allows the application to provide the definition + * of a variable in the global segment memory with agent allocation. The + * variable must be defined before loading a code object into an executable. + * In addition, code objects loaded must not define the variable. + * + * @param[in] executable Executable. Must not be in frozen state. + * + * @param[in] agent Agent for which the variable is being defined. + * + * @param[in] variable_name Name of the variable. The Programmer's Reference + * Manual describes the standard name mangling scheme. + * + * @param[in] address Address where the variable is defined. This address must + * have been previously allocated using ::hsa_memory_allocate in a global region + * that is only visible to @p agent. The application cannot deallocate the + * buffer pointed by @p address before @p executable is destroyed. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT @p agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_VARIABLE_ALREADY_DEFINED The variable is + * already defined. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SYMBOL_NAME There is no variable with the + * @p variable_name. + * + * @retval ::HSA_STATUS_ERROR_FROZEN_EXECUTABLE @p executable is frozen. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p variable_name is NULL. + */ +hsa_status_t HSA_API hsa_executable_agent_global_variable_define( + hsa_executable_t executable, hsa_agent_t agent, const char *variable_name, + void *address); + +/** + * @brief Define an external readonly variable. + * + * @details This function allows the application to provide the definition + * of a variable in the readonly segment memory. The variable must be defined + * before loading a code object into an executable. In addition, code objects + * loaded must not define the variable. + * + * @param[in] executable Executable. Must not be in frozen state. + * + * @param[in] agent Agent for which the variable is being defined. + * + * @param[in] variable_name Name of the variable. The Programmer's Reference + * Manual describes the standard name mangling scheme. + * + * @param[in] address Address where the variable is defined. This address must + * have been previously allocated using ::hsa_memory_allocate in a readonly + * region associated with @p agent. The application cannot deallocate the buffer + * pointed by @p address before @p executable is destroyed. + * + * @param[in] address Address where the variable is defined. The buffer pointed + * by @p address is owned by the application, and cannot be deallocated before + * @p executable is destroyed. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE Executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT @p agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_VARIABLE_ALREADY_DEFINED The variable is + * already defined. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SYMBOL_NAME There is no variable with the + * @p variable_name. + * + * @retval ::HSA_STATUS_ERROR_FROZEN_EXECUTABLE @p executable is frozen. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p variable_name is NULL. + */ +hsa_status_t HSA_API hsa_executable_readonly_variable_define( + hsa_executable_t executable, hsa_agent_t agent, const char *variable_name, + void *address); + +/** + * @brief Validate an executable. Checks that all code objects have matching + * machine model, profile, and default floating-point rounding mode. Checks that + * all declarations have definitions. Checks declaration-definition + * compatibility (see the HSA Programming Reference Manual for compatibility + * rules). Invoking this function is equivalent to invoking + * ::hsa_executable_validate_alt with no options. + * + * @param[in] executable Executable. Must be in frozen state. + * + * @param[out] result Memory location where the HSA runtime stores the + * validation result. If the executable passes validation, the result is 0. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE @p executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p result is NULL. + */ +hsa_status_t HSA_API hsa_executable_validate(hsa_executable_t executable, + uint32_t *result); + +/** + * @brief Validate an executable. Checks that all code objects have matching + * machine model, profile, and default floating-point rounding mode. Checks that + * all declarations have definitions. Checks declaration-definition + * compatibility (see the HSA Programming Reference Manual for compatibility + * rules). + * + * @param[in] executable Executable. Must be in frozen state. + * + * @param[in] options Standard and vendor-specific options. Unknown options are + * ignored. A standard option begins with the "-hsa_" prefix. Options beginning + * with the "-hsa_ext__" prefix are reserved for extensions. A + * vendor-specific option begins with the "-_" prefix. Must be a + * NUL-terminated string. May be NULL. + * + * @param[out] result Memory location where the HSA runtime stores the + * validation result. If the executable passes validation, the result is 0. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE @p executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p result is NULL. + */ +hsa_status_t HSA_API hsa_executable_validate_alt(hsa_executable_t executable, + const char *options, + uint32_t *result); + +/** + * @brief Executable symbol handle. + * + * The lifetime of an executable object symbol matches that of the executable + * associated with it. An operation on a symbol whose associated executable has + * been destroyed results in undefined behavior. + */ +typedef struct hsa_executable_symbol_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_executable_symbol_t; + +/** + * @deprecated Use ::hsa_executable_get_symbol_by_name instead. + * + * @brief Get the symbol handle for a given a symbol name. + * + * @param[in] executable Executable. + * + * @param[in] module_name Module name. Must be NULL if the symbol has + * program linkage. + * + * @param[in] symbol_name Symbol name. + * + * @param[in] agent Agent associated with the symbol. If the symbol is + * independent of any agent (for example, a variable with program + * allocation), this argument is ignored. + * + * @param[in] call_convention Call convention associated with the symbol. If the + * symbol does not correspond to an indirect function, this argument is ignored. + * + * @param[out] symbol Memory location where the HSA runtime stores the symbol + * handle. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SYMBOL_NAME There is no symbol with a name + * that matches @p symbol_name. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p symbol_name is NULL, or + * @p symbol is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_executable_get_symbol( + hsa_executable_t executable, const char *module_name, + const char *symbol_name, hsa_agent_t agent, int32_t call_convention, + hsa_executable_symbol_t *symbol); + +/** + * @brief Retrieve the symbol handle corresponding to a given a symbol name. + * + * @param[in] executable Executable. + * + * @param[in] symbol_name Symbol name. Must be a NUL-terminated character + * array. The Programmer's Reference Manual describes the standard name mangling + * scheme. + * + * @param[in] agent Pointer to the agent for which the symbol with the given + * name is defined. If the symbol corresponding to the given name has program + * allocation, @p agent must be NULL. + * + * @param[out] symbol Memory location where the HSA runtime stores the symbol + * handle. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SYMBOL_NAME There is no symbol with a name + * that matches @p symbol_name. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p symbol_name is NULL, or @p + * symbol is NULL. + */ +hsa_status_t HSA_API hsa_executable_get_symbol_by_name( + hsa_executable_t executable, const char *symbol_name, + const hsa_agent_t *agent, hsa_executable_symbol_t *symbol); + +/** + * @brief Symbol type. + */ +typedef enum { + /** + * Variable. + */ + HSA_SYMBOL_KIND_VARIABLE = 0, + /** + * Kernel. + */ + HSA_SYMBOL_KIND_KERNEL = 1, + /** + * Indirect function. + */ + HSA_SYMBOL_KIND_INDIRECT_FUNCTION = 2 +} hsa_symbol_kind_t; + +/** + * @brief Linkage type of a symbol. + */ +typedef enum { + /** + * Module linkage. + */ + HSA_SYMBOL_LINKAGE_MODULE = 0, + /** + * Program linkage. + */ + HSA_SYMBOL_LINKAGE_PROGRAM = 1 +} hsa_symbol_linkage_t; + +/** + * @brief Allocation type of a variable. + */ +typedef enum { + /** + * Agent allocation. + */ + HSA_VARIABLE_ALLOCATION_AGENT = 0, + /** + * Program allocation. + */ + HSA_VARIABLE_ALLOCATION_PROGRAM = 1 +} hsa_variable_allocation_t; + +/** + * @brief Memory segment associated with a variable. + */ +typedef enum { + /** + * Global memory segment. + */ + HSA_VARIABLE_SEGMENT_GLOBAL = 0, + /** + * Readonly memory segment. + */ + HSA_VARIABLE_SEGMENT_READONLY = 1 +} hsa_variable_segment_t; + +/** + * @brief Executable symbol attributes. + */ +typedef enum { + /** + * The kind of the symbol. The type of this attribute is ::hsa_symbol_kind_t. + */ + HSA_EXECUTABLE_SYMBOL_INFO_TYPE = 0, + /** + * The length of the symbol name in bytes, not including the NUL terminator. + * The type of this attribute is uint32_t. + */ + HSA_EXECUTABLE_SYMBOL_INFO_NAME_LENGTH = 1, + /** + * The name of the symbol. The type of this attribute is character array with + * the length equal to the value of ::HSA_EXECUTABLE_SYMBOL_INFO_NAME_LENGTH + * attribute. + */ + HSA_EXECUTABLE_SYMBOL_INFO_NAME = 2, + /** + * @deprecated + * + * The length of the module name in bytes (not including the NUL terminator) + * to which this symbol belongs if this symbol has module linkage, otherwise 0 + * is returned. The type of this attribute is uint32_t. + */ + HSA_EXECUTABLE_SYMBOL_INFO_MODULE_NAME_LENGTH = 3, + /** + * @deprecated + * + * The module name to which this symbol belongs if this symbol has module + * linkage, otherwise an empty string is returned. The type of this attribute + * is character array with the length equal to the value of + * ::HSA_EXECUTABLE_SYMBOL_INFO_MODULE_NAME_LENGTH attribute. + */ + HSA_EXECUTABLE_SYMBOL_INFO_MODULE_NAME = 4, + /** + * @deprecated + * + * Agent associated with this symbol. If the symbol is a variable, the + * value of this attribute is only defined if + * ::HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ALLOCATION is + * ::HSA_VARIABLE_ALLOCATION_AGENT. The type of this attribute is hsa_agent_t. + */ + HSA_EXECUTABLE_SYMBOL_INFO_AGENT = 20, + /** + * The address of the variable. The value of this attribute is undefined if + * the symbol is not a variable. The type of this attribute is uint64_t. + * + * If executable's state is ::HSA_EXECUTABLE_STATE_UNFROZEN, then 0 is + * returned. + */ + HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ADDRESS = 21, + /** + * The linkage kind of the symbol. The type of this attribute is + * ::hsa_symbol_linkage_t. + */ + HSA_EXECUTABLE_SYMBOL_INFO_LINKAGE = 5, + /** + * Indicates whether the symbol corresponds to a definition. The type of this + * attribute is bool. + */ + HSA_EXECUTABLE_SYMBOL_INFO_IS_DEFINITION = 17, + /** + * @deprecated + * + * The allocation kind of the variable. The value of this attribute is + * undefined if the symbol is not a variable. The type of this attribute is + * ::hsa_variable_allocation_t. + */ + HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ALLOCATION = 6, + /** + * @deprecated + * + * The segment kind of the variable. The value of this attribute is undefined + * if the symbol is not a variable. The type of this attribute is + * ::hsa_variable_segment_t. + */ + HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_SEGMENT = 7, + /** + * @deprecated + * + * Alignment of the symbol in memory. The value of this attribute is undefined + * if the symbol is not a variable. The type of this attribute is uint32_t. + * + * The current alignment of the variable in memory may be greater than the + * value specified in the source program variable declaration. + */ + HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ALIGNMENT = 8, + /** + * @deprecated + * + * Size of the variable. The value of this attribute is undefined if + * the symbol is not a variable. The type of this attribute is uint32_t. + * + * A value of 0 is returned if the variable is an external variable and has an + * unknown dimension. + */ + HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_SIZE = 9, + /** + * @deprecated + * + * Indicates whether the variable is constant. The value of this attribute is + * undefined if the symbol is not a variable. The type of this attribute is + * bool. + */ + HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_IS_CONST = 10, + /** + * Kernel object handle, used in the kernel dispatch packet. The value of this + * attribute is undefined if the symbol is not a kernel. The type of this + * attribute is uint64_t. + * + * If the state of the executable is ::HSA_EXECUTABLE_STATE_UNFROZEN, then 0 + * is returned. + */ + HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT = 22, + /** + * Size of kernarg segment memory that is required to hold the values of the + * kernel arguments, in bytes. Must be a multiple of 16. The value of this + * attribute is undefined if the symbol is not a kernel. The type of this + * attribute is uint32_t. + */ + HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE = 11, + /** + * Alignment (in bytes) of the buffer used to pass arguments to the kernel, + * which is the maximum of 16 and the maximum alignment of any of the kernel + * arguments. The value of this attribute is undefined if the symbol is not a + * kernel. The type of this attribute is uint32_t. + */ + HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_ALIGNMENT = 12, + /** + * Size of static group segment memory required by the kernel (per + * work-group), in bytes. The value of this attribute is undefined + * if the symbol is not a kernel. The type of this attribute is uint32_t. + * + * The reported amount does not include any dynamically allocated group + * segment memory that may be requested by the application when a kernel is + * dispatched. + */ + HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE = 13, + /** + * Size of static private, spill, and arg segment memory required by + * this kernel (per work-item), in bytes. The value of this attribute is + * undefined if the symbol is not a kernel. The type of this attribute is + * uint32_t. + * + * If the value of ::HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_DYNAMIC_CALLSTACK is + * true, the kernel may use more private memory than the reported value, and + * the application must add the dynamic call stack usage to @a + * private_segment_size when populating a kernel dispatch packet. + */ + HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE = 14, + /** + * Dynamic callstack flag. The value of this attribute is undefined if the + * symbol is not a kernel. The type of this attribute is bool. + * + * If this flag is set (the value is true), the kernel uses a dynamically + * sized call stack. This can happen if recursive calls, calls to indirect + * functions, or the HSAIL alloca instruction are present in the kernel. + */ + HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_DYNAMIC_CALLSTACK = 15, + /** + * @deprecated + * + * Call convention of the kernel. The value of this attribute is undefined if + * the symbol is not a kernel. The type of this attribute is uint32_t. + */ + HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_CALL_CONVENTION = 18, + /** + * Indirect function object handle. The value of this attribute is undefined + * if the symbol is not an indirect function, or the associated agent does + * not support the Full Profile. The type of this attribute depends on the + * machine model: the type is uint32_t for small machine model, and uint64_t + * for large model. + * + * If the state of the executable is ::HSA_EXECUTABLE_STATE_UNFROZEN, then 0 + * is returned. + */ + HSA_EXECUTABLE_SYMBOL_INFO_INDIRECT_FUNCTION_OBJECT = 23, + /** + * @deprecated + * + * Call convention of the indirect function. The value of this attribute is + * undefined if the symbol is not an indirect function, or the associated + * agent does not support the Full Profile. The type of this attribute is + * uint32_t. + */ + HSA_EXECUTABLE_SYMBOL_INFO_INDIRECT_FUNCTION_CALL_CONVENTION = 16 +} hsa_executable_symbol_info_t; + +/** + * @brief Get the current value of an attribute for a given executable symbol. + * + * @param[in] executable_symbol Executable symbol. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE_SYMBOL The executable symbol is + * invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * executable symbol attribute, or @p value is NULL. + */ +hsa_status_t HSA_API hsa_executable_symbol_get_info( + hsa_executable_symbol_t executable_symbol, + hsa_executable_symbol_info_t attribute, void *value); + +/** + * @deprecated + * + * @brief Iterate over the symbols in a executable, and invoke an + * application-defined callback on every iteration. + * + * @param[in] executable Executable. + * + * @param[in] callback Callback to be invoked once per executable symbol. The + * HSA runtime passes three arguments to the callback: the executable, a symbol, + * and the application data. If @p callback returns a status other than + * ::HSA_STATUS_SUCCESS for a particular iteration, the traversal stops and + * ::hsa_executable_iterate_symbols returns that status value. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_executable_iterate_symbols( + hsa_executable_t executable, + hsa_status_t (*callback)(hsa_executable_t exec, + hsa_executable_symbol_t symbol, void *data), + void *data); + +/** + * @brief Iterate over the kernels, indirect functions, and agent allocation + * variables in an executable for a given agent, and invoke an application- + * defined callback on every iteration. + * + * @param[in] executable Executable. + * + * @param[in] agent Agent. + * + * @param[in] callback Callback to be invoked once per executable symbol. The + * HSA runtime passes three arguments to the callback: the executable, a symbol, + * and the application data. If @p callback returns a status other than + * ::HSA_STATUS_SUCCESS for a particular iteration, the traversal stops and + * ::hsa_executable_iterate_symbols returns that status value. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t HSA_API hsa_executable_iterate_agent_symbols( + hsa_executable_t executable, hsa_agent_t agent, + hsa_status_t (*callback)(hsa_executable_t exec, hsa_agent_t agent, + hsa_executable_symbol_t symbol, void *data), + void *data); + +/** + * @brief Iterate over the program allocation variables in an executable, and + * invoke an application-defined callback on every iteration. + * + * @param[in] executable Executable. + * + * @param[in] callback Callback to be invoked once per executable symbol. The + * HSA runtime passes three arguments to the callback: the executable, a symbol, + * and the application data. If @p callback returns a status other than + * ::HSA_STATUS_SUCCESS for a particular iteration, the traversal stops and + * ::hsa_executable_iterate_symbols returns that status value. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t HSA_API hsa_executable_iterate_program_symbols( + hsa_executable_t executable, + hsa_status_t (*callback)(hsa_executable_t exec, + hsa_executable_symbol_t symbol, void *data), + void *data); + +/** @} */ + +/** \defgroup code-object Code Objects (deprecated). + * @{ + */ + +/** + * @deprecated + * + * @brief Struct containing an opaque handle to a code object, which contains + * ISA for finalized kernels and indirect functions together with information + * about the global or readonly segment variables they reference. + */ +typedef struct hsa_code_object_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_code_object_t; + +/** + * @deprecated + * + * @brief Application data handle that is passed to the serialization + * and deserialization functions. + */ +typedef struct hsa_callback_data_s { + /** + * Opaque handle. + */ + uint64_t handle; +} hsa_callback_data_t; + +/** + * @deprecated + * + * @brief Serialize a code object. Can be used for offline finalization, + * install-time finalization, disk code caching, etc. + * + * @param[in] code_object Code object. + * + * @param[in] alloc_callback Callback function for memory allocation. Must not + * be NULL. The HSA runtime passes three arguments to the callback: the + * allocation size, the application data, and a pointer to a memory location + * where the application stores the allocation result. The HSA runtime invokes + * @p alloc_callback once to allocate a buffer that contains the serialized + * version of @p code_object. If the callback returns a status code other than + * ::HSA_STATUS_SUCCESS, this function returns the same code. + * + * @param[in] callback_data Application data that is passed to @p + * alloc_callback. May be NULL. + * + * @param[in] options Standard and vendor-specific options. Unknown options are + * ignored. A standard option begins with the "-hsa_" prefix. Options beginning + * with the "-hsa_ext__" prefix are reserved for extensions. A + * vendor-specific option begins with the "-_" prefix. Must be a + * NUL-terminated string. May be NULL. + * + * @param[out] serialized_code_object Memory location where the HSA runtime + * stores a pointer to the serialized code object. Must not be NULL. + * + * @param[out] serialized_code_object_size Memory location where the HSA runtime + * stores the size (in bytes) of @p serialized_code_object. The returned value + * matches the allocation size passed by the HSA runtime to @p + * alloc_callback. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT @p code_object is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p alloc_callback, @p + * serialized_code_object, or @p serialized_code_object_size are NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_code_object_serialize( + hsa_code_object_t code_object, + hsa_status_t (*alloc_callback)(size_t size, hsa_callback_data_t data, + void **address), + hsa_callback_data_t callback_data, const char *options, + void **serialized_code_object, size_t *serialized_code_object_size); + +/** + * @deprecated + * + * @brief Deserialize a code object. + * + * @param[in] serialized_code_object A serialized code object. Must not be NULL. + * + * @param[in] serialized_code_object_size The size (in bytes) of @p + * serialized_code_object. Must not be 0. + * + * @param[in] options Standard and vendor-specific options. Unknown options are + * ignored. A standard option begins with the "-hsa_" prefix. Options beginning + * with the "-hsa_ext__" prefix are reserved for extensions. A + * vendor-specific option begins with the "-_" prefix. Must be a + * NUL-terminated string. May be NULL. + * + * @param[out] code_object Memory location where the HSA runtime stores the + * deserialized code object. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p serialized_code_object, or @p + * code_object are NULL, or @p serialized_code_object_size is 0. + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_code_object_deserialize( + void *serialized_code_object, size_t serialized_code_object_size, + const char *options, hsa_code_object_t *code_object); + +/** + * @deprecated + * + * @brief Destroy a code object. + * + * @details The lifetime of a code object must exceed that of any executable + * where it has been loaded. If an executable that loaded @p code_object has not + * been destroyed, the behavior is undefined. + * + * @param[in] code_object Code object. The handle becomes invalid after it has + * been destroyed. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT @p code_object is invalid. + */ +hsa_status_t HSA_API HSA_DEPRECATED +hsa_code_object_destroy(hsa_code_object_t code_object); + +/** + * @deprecated + * + * @brief Code object type. + */ +typedef enum { + /** + * Produces code object that contains ISA for all kernels and indirect + * functions in HSA source. + */ + HSA_CODE_OBJECT_TYPE_PROGRAM = 0 +} hsa_code_object_type_t; + +/** + * @deprecated + * + * @brief Code object attributes. + */ +typedef enum { + /** + * The version of the code object. The type of this attribute is a + * NUL-terminated char[64]. The name must be at most 63 characters long (not + * including the NUL terminator) and all array elements not used for the name + * must be NUL. + */ + HSA_CODE_OBJECT_INFO_VERSION = 0, + /** + * Type of code object. The type of this attribute is + * ::hsa_code_object_type_t. + */ + HSA_CODE_OBJECT_INFO_TYPE = 1, + /** + * Instruction set architecture this code object is produced for. The type of + * this attribute is ::hsa_isa_t. + */ + HSA_CODE_OBJECT_INFO_ISA = 2, + /** + * Machine model this code object is produced for. The type of this attribute + * is ::hsa_machine_model_t. + */ + HSA_CODE_OBJECT_INFO_MACHINE_MODEL = 3, + /** + * Profile this code object is produced for. The type of this attribute is + * ::hsa_profile_t. + */ + HSA_CODE_OBJECT_INFO_PROFILE = 4, + /** + * Default floating-point rounding mode used when the code object is + * produced. The type of this attribute is + * ::hsa_default_float_rounding_mode_t. + */ + HSA_CODE_OBJECT_INFO_DEFAULT_FLOAT_ROUNDING_MODE = 5 +} hsa_code_object_info_t; + +/** + * @deprecated + * + * @brief Get the current value of an attribute for a given code object. + * + * @param[in] code_object Code object. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT @p code_object is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * code object attribute, or @p value is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED +hsa_code_object_get_info(hsa_code_object_t code_object, + hsa_code_object_info_t attribute, void *value); + +/** + * @deprecated + * + * @brief Load code object into the executable. + * + * @details Every global or readonly variable that is external must be defined + * before loading the code object. An internal global or readonly variable is + * allocated once the code object, that is being loaded, references this + * variable and this variable is not allocated. + * + * Any module linkage declaration must have been defined either by a define + * variable or by loading a code object that has a symbol with module linkage + * definition. + * + * @param[in] executable Executable. + * + * @param[in] agent Agent to load code object for. The agent must support the + * default floating-point rounding mode used by @p code_object. + * + * @param[in] code_object Code object to load. The lifetime of the code object + * must exceed that of the executable: if @p code_object is destroyed before @p + * executable, the behavior is undefined. + * + * @param[in] options Standard and vendor-specific options. Unknown options are + * ignored. A standard option begins with the "-hsa_" prefix. Options beginning + * with the "-hsa_ext__" prefix are reserved for extensions. A + * vendor-specific option begins with the "-_" prefix. Must be a + * NUL-terminated string. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT @p code_object is invalid. + * + * @retval ::HSA_STATUS_ERROR_INCOMPATIBLE_ARGUMENTS @p agent is not compatible + * with @p code_object (for example, @p agent does not support the default + * floating-point rounding mode specified by @p code_object), or @p code_object + * is not compatible with @p executable (for example, @p code_object and @p + * executable have different machine models or profiles). + * + * @retval ::HSA_STATUS_ERROR_FROZEN_EXECUTABLE @p executable is frozen. + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_executable_load_code_object( + hsa_executable_t executable, hsa_agent_t agent, + hsa_code_object_t code_object, const char *options); + +/** + * @deprecated + * + * @brief Code object symbol handle. + * + * The lifetime of a code object symbol matches that of the code object + * associated with it. An operation on a symbol whose associated code object has + * been destroyed results in undefined behavior. + */ +typedef struct hsa_code_symbol_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_code_symbol_t; + +/** + * @deprecated + * + * @brief Get the symbol handle within a code object for a given a symbol name. + * + * @param[in] code_object Code object. + * + * @param[in] symbol_name Symbol name. + * + * @param[out] symbol Memory location where the HSA runtime stores the symbol + * handle. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT @p code_object is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SYMBOL_NAME There is no symbol with a name + * that matches @p symbol_name. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p symbol_name is NULL, or + * @p symbol is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED +hsa_code_object_get_symbol(hsa_code_object_t code_object, + const char *symbol_name, hsa_code_symbol_t *symbol); + +/** + * @deprecated + * + * @brief Get the symbol handle within a code object for a given a symbol name. + * + * @param[in] code_object Code object. + * + * @param[in] module_name Module name. Must be NULL if the symbol has + * program linkage. + * + * @param[in] symbol_name Symbol name. + * + * @param[out] symbol Memory location where the HSA runtime stores the symbol + * handle. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT @p code_object is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SYMBOL_NAME There is no symbol with a name + * that matches @p symbol_name. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p symbol_name is NULL, or + * @p symbol is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_code_object_get_symbol_from_name( + hsa_code_object_t code_object, const char *module_name, + const char *symbol_name, hsa_code_symbol_t *symbol); + +/** + * @deprecated + * + * @brief Code object symbol attributes. + */ +typedef enum { + /** + * The type of the symbol. The type of this attribute is ::hsa_symbol_kind_t. + */ + HSA_CODE_SYMBOL_INFO_TYPE = 0, + /** + * The length of the symbol name in bytes, not including the NUL terminator. + * The type of this attribute is uint32_t. + */ + HSA_CODE_SYMBOL_INFO_NAME_LENGTH = 1, + /** + * The name of the symbol. The type of this attribute is character array with + * the length equal to the value of ::HSA_CODE_SYMBOL_INFO_NAME_LENGTH + * attribute. + */ + HSA_CODE_SYMBOL_INFO_NAME = 2, + /** + * The length of the module name in bytes (not including the NUL terminator) + * to which this symbol belongs if this symbol has module linkage, otherwise 0 + * is returned. The type of this attribute is uint32_t. + */ + HSA_CODE_SYMBOL_INFO_MODULE_NAME_LENGTH = 3, + /** + * The module name to which this symbol belongs if this symbol has module + * linkage, otherwise an empty string is returned. The type of this attribute + * is character array with the length equal to the value of + * ::HSA_CODE_SYMBOL_INFO_MODULE_NAME_LENGTH attribute. + */ + HSA_CODE_SYMBOL_INFO_MODULE_NAME = 4, + /** + * The linkage kind of the symbol. The type of this attribute is + * ::hsa_symbol_linkage_t. + */ + HSA_CODE_SYMBOL_INFO_LINKAGE = 5, + /** + * Indicates whether the symbol corresponds to a definition. The type of this + * attribute is bool. + */ + HSA_CODE_SYMBOL_INFO_IS_DEFINITION = 17, + /** + * The allocation kind of the variable. The value of this attribute is + * undefined if the symbol is not a variable. The type of this attribute is + * ::hsa_variable_allocation_t. + */ + HSA_CODE_SYMBOL_INFO_VARIABLE_ALLOCATION = 6, + /** + * The segment kind of the variable. The value of this attribute is + * undefined if the symbol is not a variable. The type of this attribute is + * ::hsa_variable_segment_t. + */ + HSA_CODE_SYMBOL_INFO_VARIABLE_SEGMENT = 7, + /** + * Alignment of the symbol in memory. The value of this attribute is undefined + * if the symbol is not a variable. The type of this attribute is uint32_t. + * + * The current alignment of the variable in memory may be greater than the + * value specified in the source program variable declaration. + */ + HSA_CODE_SYMBOL_INFO_VARIABLE_ALIGNMENT = 8, + /** + * Size of the variable. The value of this attribute is undefined if the + * symbol is not a variable. The type of this attribute is uint32_t. + * + * A size of 0 is returned if the variable is an external variable and has an + * unknown dimension. + */ + HSA_CODE_SYMBOL_INFO_VARIABLE_SIZE = 9, + /** + * Indicates whether the variable is constant. The value of this attribute is + * undefined if the symbol is not a variable. The type of this attribute is + * bool. + */ + HSA_CODE_SYMBOL_INFO_VARIABLE_IS_CONST = 10, + /** + * Size of kernarg segment memory that is required to hold the values of the + * kernel arguments, in bytes. Must be a multiple of 16. The value of this + * attribute is undefined if the symbol is not a kernel. The type of this + * attribute is uint32_t. + */ + HSA_CODE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE = 11, + /** + * Alignment (in bytes) of the buffer used to pass arguments to the kernel, + * which is the maximum of 16 and the maximum alignment of any of the kernel + * arguments. The value of this attribute is undefined if the symbol is not a + * kernel. The type of this attribute is uint32_t. + */ + HSA_CODE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_ALIGNMENT = 12, + /** + * Size of static group segment memory required by the kernel (per + * work-group), in bytes. The value of this attribute is undefined + * if the symbol is not a kernel. The type of this attribute is uint32_t. + * + * The reported amount does not include any dynamically allocated group + * segment memory that may be requested by the application when a kernel is + * dispatched. + */ + HSA_CODE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE = 13, + /** + * Size of static private, spill, and arg segment memory required by + * this kernel (per work-item), in bytes. The value of this attribute is + * undefined if the symbol is not a kernel. The type of this attribute is + * uint32_t. + * + * If the value of ::HSA_CODE_SYMBOL_INFO_KERNEL_DYNAMIC_CALLSTACK is true, + * the kernel may use more private memory than the reported value, and the + * application must add the dynamic call stack usage to @a + * private_segment_size when populating a kernel dispatch packet. + */ + HSA_CODE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE = 14, + /** + * Dynamic callstack flag. The value of this attribute is undefined if the + * symbol is not a kernel. The type of this attribute is bool. + * + * If this flag is set (the value is true), the kernel uses a dynamically + * sized call stack. This can happen if recursive calls, calls to indirect + * functions, or the HSAIL alloca instruction are present in the kernel. + */ + HSA_CODE_SYMBOL_INFO_KERNEL_DYNAMIC_CALLSTACK = 15, + /** + * Call convention of the kernel. The value of this attribute is undefined if + * the symbol is not a kernel. The type of this attribute is uint32_t. + */ + HSA_CODE_SYMBOL_INFO_KERNEL_CALL_CONVENTION = 18, + /** + * Call convention of the indirect function. The value of this attribute is + * undefined if the symbol is not an indirect function. The type of this + * attribute is uint32_t. + */ + HSA_CODE_SYMBOL_INFO_INDIRECT_FUNCTION_CALL_CONVENTION = 16, + /** + * Wavefront size used by the kernel. The value of this attribute is either + * 32 or 64. The type of this attribute is uint32_t. + */ + HSA_CODE_SYMBOL_INFO_KERNEL_WAVEFRONT_SIZE = 19 +} hsa_code_symbol_info_t; + +/** + * @deprecated + * + * @brief Get the current value of an attribute for a given code symbol. + * + * @param[in] code_symbol Code symbol. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_SYMBOL The code symbol is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * code symbol attribute, or @p value is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED +hsa_code_symbol_get_info(hsa_code_symbol_t code_symbol, + hsa_code_symbol_info_t attribute, void *value); + +/** + * @deprecated + * + * @brief Iterate over the symbols in a code object, and invoke an + * application-defined callback on every iteration. + * + * @param[in] code_object Code object. + * + * @param[in] callback Callback to be invoked once per code object symbol. The + * HSA runtime passes three arguments to the callback: the code object, a + * symbol, and the application data. If @p callback returns a status other than + * ::HSA_STATUS_SUCCESS for a particular iteration, the traversal stops and + * ::hsa_code_object_iterate_symbols returns that status value. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT @p code_object is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t HSA_API HSA_DEPRECATED hsa_code_object_iterate_symbols( + hsa_code_object_t code_object, + hsa_status_t (*callback)(hsa_code_object_t code_object, + hsa_code_symbol_t symbol, void *data), + void *data); + +/** @} */ + +#ifdef __cplusplus +} // end extern "C" block +#endif + +#endif // header guard diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_amd_tool.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_amd_tool.h new file mode 100644 index 000000000..8a58965da --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_amd_tool.h @@ -0,0 +1,92 @@ +#ifndef HSA_RUNTIME_AMD_TOOL_EVENTS_H_ +#define HSA_RUNTIME_AMD_TOOL_EVENTS_H_ + +// Insert license header + +#include "hsa.h" +#include +#include + +typedef enum { + HSA_AMD_EVENT_SCRATCH_ALLOC_FLAG_NONE = 0, + HSA_AMD_EVENT_SCRATCH_ALLOC_FLAG_USE_ONCE = + (1 << 0), // This scratch allocation is only valid for 1 dispatch. + HSA_AMD_EVENT_SCRATCH_ALLOC_FLAG_ALT = + (1 << 1), // Used alternate scratch instead of main scratch +} hsa_amd_event_scratch_alloc_flag_t; + +typedef enum { + HSA_AMD_TOOL_EVENT_MIN = 0, + + // Scratch memory tracking + HSA_AMD_TOOL_EVENT_SCRATCH_ALLOC_START, + HSA_AMD_TOOL_EVENT_SCRATCH_ALLOC_END, + HSA_AMD_TOOL_EVENT_SCRATCH_FREE_START, + HSA_AMD_TOOL_EVENT_SCRATCH_FREE_END, + HSA_AMD_TOOL_EVENT_SCRATCH_ASYNC_RECLAIM_START, + HSA_AMD_TOOL_EVENT_SCRATCH_ASYNC_RECLAIM_END, + + // Add new events above ^ + HSA_AMD_TOOL_EVENT_MAX +} hsa_amd_tool_event_kind_t; + +typedef struct { + hsa_amd_tool_event_kind_t kind; +} hsa_amd_tool_event_none_t; + +typedef struct { + hsa_amd_tool_event_kind_t kind; + const hsa_queue_t *queue; + hsa_amd_event_scratch_alloc_flag_t flags; + uint64_t dispatch_id; // Dispatch ID of the AQL packet that needs more scratch + // memory +} hsa_amd_event_scratch_alloc_start_t; + +typedef struct { + hsa_amd_tool_event_kind_t kind; + const hsa_queue_t *queue; + hsa_amd_event_scratch_alloc_flag_t flags; + uint64_t dispatch_id; // Dispatch ID of the AQL packet that needs more scratch + // memory + size_t size; // Amount of scratch allocated - in bytes + size_t num_slots; // limit of number of waves +} hsa_amd_event_scratch_alloc_end_t; + +typedef struct { + hsa_amd_tool_event_kind_t kind; + const hsa_queue_t *queue; + hsa_amd_event_scratch_alloc_flag_t flags; +} hsa_amd_event_scratch_free_start_t; + +typedef struct { + hsa_amd_tool_event_kind_t kind; + const hsa_queue_t *queue; + hsa_amd_event_scratch_alloc_flag_t flags; +} hsa_amd_event_scratch_free_end_t; + +typedef struct { + hsa_amd_tool_event_kind_t kind; + const hsa_queue_t *queue; + hsa_amd_event_scratch_alloc_flag_t flags; +} hsa_amd_event_scratch_async_reclaim_start_t; + +typedef struct { + hsa_amd_tool_event_kind_t kind; + const hsa_queue_t *queue; + hsa_amd_event_scratch_alloc_flag_t flags; +} hsa_amd_event_scratch_async_reclaim_end_t; + +typedef union { + const hsa_amd_tool_event_none_t *none; + const hsa_amd_event_scratch_alloc_start_t *scratch_alloc_start; + const hsa_amd_event_scratch_alloc_end_t *scratch_alloc_end; + const hsa_amd_event_scratch_free_start_t *scratch_free_start; + const hsa_amd_event_scratch_free_end_t *scratch_free_end; + const hsa_amd_event_scratch_async_reclaim_start_t + *scratch_async_reclaim_start; + const hsa_amd_event_scratch_async_reclaim_end_t *scratch_async_reclaim_end; +} hsa_amd_tool_event_t; + +typedef hsa_status_t (*hsa_amd_tool_event)(hsa_amd_tool_event_t); + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_api_trace.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_api_trace.h new file mode 100644 index 000000000..913d3f43d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_api_trace.h @@ -0,0 +1,632 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef HSA_RUNTIME_INC_HSA_API_TRACE_H +#define HSA_RUNTIME_INC_HSA_API_TRACE_H + +#include "hsa.h" +#include "hsa_api_trace_version.h" +#ifdef AMD_INTERNAL_BUILD +#include "hsa_amd_tool.h" +#include "hsa_ext_amd.h" +#include "hsa_ext_finalize.h" +#include "hsa_ext_image.h" +#include "hsa_ven_amd_pc_sampling.h" +#else +#include "inc/hsa_amd_tool.h" +#include "inc/hsa_ext_amd.h" +#include "inc/hsa_ext_finalize.h" +#include "inc/hsa_ext_image.h" +#include "inc/hsa_ven_amd_pc_sampling.h" +#endif + +#include +#include +#include + +// Table MAJOR_VERSION and STEP_VERSION defines have moved to +// hsa_api_trace_version.h + +// Min function used to copy Api Tables +static inline uint32_t Min(const uint32_t a, const uint32_t b) { + return (a > b) ? b : a; +} + +// Declarations of APIs intended for use only by tools. + +// An AQL packet that can be put in an intercept queue to cause a callback to +// be invoked when the packet is about to be submitted to the underlying +// hardware queue. These packets are not copied to the underlying hardware +// queue. These packets should come immediately before the regular AQL packet +// they relate to. This implies that packet rewriters should always keep these +// packets adjacent to the regular AQL packet that follows them. +const uint32_t AMD_AQL_FORMAT_INTERCEPT_MARKER = 0xFE; + +struct amd_aql_intercept_marker_s; + +// When an intercept queue is processing rewritten packets to put them on the +// underlying hardware queue, if it encounters a +// AMD_AQL_FORMAT_INTERCEPT_MARKER vendor AQL packet it will call the following +// handler. packet points to the packet, queue is the underlying hardware +// queue, and packet_id is the packet id of the next packet to be put on the +// underlying hardware queue. The intercept queue does not put these packets +// onto the underlying hardware queue. +typedef void (*amd_intercept_marker_handler)( + const struct amd_aql_intercept_marker_s *packet, hsa_queue_t *queue, + uint64_t packet_id); +// An AQL vendor packet used by the intercept queue to mark the following +// packet. The callback will be invoked to allow a tool to know where in the +// underlying hardware queue the following packet will be placed. user_data can +// be used to hold any data useful to the tool. +typedef struct amd_aql_intercept_marker_s { + uint16_t + header; // Must have a packet type of HSA_PACKET_TYPE_VENDOR_SPECIFIC. + uint8_t format; // Must be AMD_AQL_FORMAT_INTERCEPT_MARKER. + uint8_t reserved[5]; // Must be 0. +#ifdef HSA_LARGE_MODEL + amd_intercept_marker_handler callback; +#elif defined HSA_LITTLE_ENDIAN + amd_intercept_marker_handler callback; + uint32_t reserved1; // Must be 0. +#else + uint32_t reserved1; // Must be 0. + amd_intercept_marker_handler callback; +#endif + uint64_t user_data[6]; +} amd_aql_intercept_marker_t; + +typedef void (*hsa_amd_queue_intercept_packet_writer)(const void *pkts, + uint64_t pkt_count); +typedef void (*hsa_amd_queue_intercept_handler)( + const void *pkts, uint64_t pkt_count, uint64_t user_pkt_index, void *data, + hsa_amd_queue_intercept_packet_writer writer); +hsa_status_t +hsa_amd_queue_intercept_register(hsa_queue_t *queue, + hsa_amd_queue_intercept_handler callback, + void *user_data); +hsa_status_t hsa_amd_queue_intercept_create( + hsa_agent_t agent_handle, uint32_t size, hsa_queue_type32_t type, + void (*callback)(hsa_status_t status, hsa_queue_t *source, void *data), + void *data, uint32_t private_segment_size, uint32_t group_segment_size, + hsa_queue_t **queue); + +typedef void (*hsa_amd_runtime_queue_notifier)(const hsa_queue_t *queue, + hsa_agent_t agent, void *data); +hsa_status_t +hsa_amd_runtime_queue_create_register(hsa_amd_runtime_queue_notifier callback, + void *user_data); + +// Structure of Version used to identify an instance of Api table +// Must be the first member (offsetof == 0) of all API tables. +// This is the root of the table passing ABI. +struct ApiTableVersion { + uint32_t major_id; + uint32_t minor_id; + uint32_t step_id; + uint32_t reserved; +}; + +struct ToolsApiTable { + ApiTableVersion version; + + hsa_amd_tool_event hsa_amd_tool_scratch_event_alloc_start_fn; + hsa_amd_tool_event hsa_amd_tool_scratch_event_alloc_end_fn; + hsa_amd_tool_event hsa_amd_tool_scratch_event_free_start_fn; + hsa_amd_tool_event hsa_amd_tool_scratch_event_free_end_fn; + hsa_amd_tool_event hsa_amd_tool_scratch_event_async_reclaim_start_fn; + hsa_amd_tool_event hsa_amd_tool_scratch_event_async_reclaim_end_fn; +}; + +// Table to export HSA Finalizer Extension Apis +struct FinalizerExtTable { + ApiTableVersion version; + decltype(hsa_ext_program_create) *hsa_ext_program_create_fn; + decltype(hsa_ext_program_destroy) *hsa_ext_program_destroy_fn; + decltype(hsa_ext_program_add_module) *hsa_ext_program_add_module_fn; + decltype(hsa_ext_program_iterate_modules) *hsa_ext_program_iterate_modules_fn; + decltype(hsa_ext_program_get_info) *hsa_ext_program_get_info_fn; + decltype(hsa_ext_program_finalize) *hsa_ext_program_finalize_fn; +}; + +// Table to export HSA Image Extension Apis +struct ImageExtTable { + ApiTableVersion version; + decltype(hsa_ext_image_get_capability) *hsa_ext_image_get_capability_fn; + decltype(hsa_ext_image_data_get_info) *hsa_ext_image_data_get_info_fn; + decltype(hsa_ext_image_create) *hsa_ext_image_create_fn; + decltype(hsa_ext_image_import) *hsa_ext_image_import_fn; + decltype(hsa_ext_image_export) *hsa_ext_image_export_fn; + decltype(hsa_ext_image_copy) *hsa_ext_image_copy_fn; + decltype(hsa_ext_image_clear) *hsa_ext_image_clear_fn; + decltype(hsa_ext_image_destroy) *hsa_ext_image_destroy_fn; + decltype(hsa_ext_sampler_create) *hsa_ext_sampler_create_fn; + decltype(hsa_ext_sampler_destroy) *hsa_ext_sampler_destroy_fn; + decltype(hsa_ext_image_get_capability_with_layout) + *hsa_ext_image_get_capability_with_layout_fn; + decltype(hsa_ext_image_data_get_info_with_layout) + *hsa_ext_image_data_get_info_with_layout_fn; + decltype(hsa_ext_image_create_with_layout) + *hsa_ext_image_create_with_layout_fn; +}; + +// Table to export HSA PC Sampling Extension Apis +struct PcSamplingExtTable { + ApiTableVersion version; + decltype(hsa_ven_amd_pcs_iterate_configuration) + *hsa_ven_amd_pcs_iterate_configuration_fn; + decltype(hsa_ven_amd_pcs_create) *hsa_ven_amd_pcs_create_fn; + decltype(hsa_ven_amd_pcs_create_from_id) *hsa_ven_amd_pcs_create_from_id_fn; + decltype(hsa_ven_amd_pcs_destroy) *hsa_ven_amd_pcs_destroy_fn; + decltype(hsa_ven_amd_pcs_start) *hsa_ven_amd_pcs_start_fn; + decltype(hsa_ven_amd_pcs_stop) *hsa_ven_amd_pcs_stop_fn; + decltype(hsa_ven_amd_pcs_flush) *hsa_ven_amd_pcs_flush_fn; +}; + +// Table to export AMD Extension Apis +struct AmdExtTable { + ApiTableVersion version; + decltype(hsa_amd_coherency_get_type) *hsa_amd_coherency_get_type_fn; + decltype(hsa_amd_coherency_set_type) *hsa_amd_coherency_set_type_fn; + decltype(hsa_amd_profiling_set_profiler_enabled) + *hsa_amd_profiling_set_profiler_enabled_fn; + decltype(hsa_amd_profiling_async_copy_enable) + *hsa_amd_profiling_async_copy_enable_fn; + decltype(hsa_amd_profiling_get_dispatch_time) + *hsa_amd_profiling_get_dispatch_time_fn; + decltype(hsa_amd_profiling_get_async_copy_time) + *hsa_amd_profiling_get_async_copy_time_fn; + decltype(hsa_amd_profiling_convert_tick_to_system_domain) + *hsa_amd_profiling_convert_tick_to_system_domain_fn; + decltype(hsa_amd_signal_async_handler) *hsa_amd_signal_async_handler_fn; + decltype(hsa_amd_async_function) *hsa_amd_async_function_fn; + decltype(hsa_amd_signal_wait_any) *hsa_amd_signal_wait_any_fn; + decltype(hsa_amd_queue_cu_set_mask) *hsa_amd_queue_cu_set_mask_fn; + decltype(hsa_amd_memory_pool_get_info) *hsa_amd_memory_pool_get_info_fn; + decltype(hsa_amd_agent_iterate_memory_pools) + *hsa_amd_agent_iterate_memory_pools_fn; + decltype(hsa_amd_memory_pool_allocate) *hsa_amd_memory_pool_allocate_fn; + decltype(hsa_amd_memory_pool_free) *hsa_amd_memory_pool_free_fn; + decltype(hsa_amd_memory_async_copy) *hsa_amd_memory_async_copy_fn; + decltype(hsa_amd_memory_async_copy_on_engine) + *hsa_amd_memory_async_copy_on_engine_fn; + decltype(hsa_amd_memory_copy_engine_status) + *hsa_amd_memory_copy_engine_status_fn; + decltype(hsa_amd_agent_memory_pool_get_info) + *hsa_amd_agent_memory_pool_get_info_fn; + decltype(hsa_amd_agents_allow_access) *hsa_amd_agents_allow_access_fn; + decltype(hsa_amd_memory_pool_can_migrate) *hsa_amd_memory_pool_can_migrate_fn; + decltype(hsa_amd_memory_migrate) *hsa_amd_memory_migrate_fn; + decltype(hsa_amd_memory_lock) *hsa_amd_memory_lock_fn; + decltype(hsa_amd_memory_unlock) *hsa_amd_memory_unlock_fn; + decltype(hsa_amd_memory_fill) *hsa_amd_memory_fill_fn; + decltype(hsa_amd_interop_map_buffer) *hsa_amd_interop_map_buffer_fn; + decltype(hsa_amd_interop_unmap_buffer) *hsa_amd_interop_unmap_buffer_fn; + decltype(hsa_amd_image_create) *hsa_amd_image_create_fn; + decltype(hsa_amd_pointer_info) *hsa_amd_pointer_info_fn; + decltype(hsa_amd_pointer_info_set_userdata) + *hsa_amd_pointer_info_set_userdata_fn; + decltype(hsa_amd_ipc_memory_create) *hsa_amd_ipc_memory_create_fn; + decltype(hsa_amd_ipc_memory_attach) *hsa_amd_ipc_memory_attach_fn; + decltype(hsa_amd_ipc_memory_detach) *hsa_amd_ipc_memory_detach_fn; + decltype(hsa_amd_signal_create) *hsa_amd_signal_create_fn; + decltype(hsa_amd_ipc_signal_create) *hsa_amd_ipc_signal_create_fn; + decltype(hsa_amd_ipc_signal_attach) *hsa_amd_ipc_signal_attach_fn; + decltype(hsa_amd_register_system_event_handler) + *hsa_amd_register_system_event_handler_fn; + decltype(hsa_amd_queue_intercept_create) *hsa_amd_queue_intercept_create_fn; + decltype(hsa_amd_queue_intercept_register) + *hsa_amd_queue_intercept_register_fn; + decltype(hsa_amd_queue_set_priority) *hsa_amd_queue_set_priority_fn; + decltype(hsa_amd_memory_async_copy_rect) *hsa_amd_memory_async_copy_rect_fn; + decltype(hsa_amd_runtime_queue_create_register) + *hsa_amd_runtime_queue_create_register_fn; + decltype(hsa_amd_memory_lock_to_pool) *hsa_amd_memory_lock_to_pool_fn; + decltype(hsa_amd_register_deallocation_callback) + *hsa_amd_register_deallocation_callback_fn; + decltype(hsa_amd_deregister_deallocation_callback) + *hsa_amd_deregister_deallocation_callback_fn; + decltype(hsa_amd_signal_value_pointer) *hsa_amd_signal_value_pointer_fn; + decltype(hsa_amd_svm_attributes_set) *hsa_amd_svm_attributes_set_fn; + decltype(hsa_amd_svm_attributes_get) *hsa_amd_svm_attributes_get_fn; + decltype(hsa_amd_svm_prefetch_async) *hsa_amd_svm_prefetch_async_fn; + decltype(hsa_amd_spm_acquire) *hsa_amd_spm_acquire_fn; + decltype(hsa_amd_spm_release) *hsa_amd_spm_release_fn; + decltype(hsa_amd_spm_set_dest_buffer) *hsa_amd_spm_set_dest_buffer_fn; + decltype(hsa_amd_queue_cu_get_mask) *hsa_amd_queue_cu_get_mask_fn; + decltype(hsa_amd_portable_export_dmabuf) *hsa_amd_portable_export_dmabuf_fn; + decltype(hsa_amd_portable_close_dmabuf) *hsa_amd_portable_close_dmabuf_fn; + decltype(hsa_amd_vmem_address_reserve) *hsa_amd_vmem_address_reserve_fn; + decltype(hsa_amd_vmem_address_free) *hsa_amd_vmem_address_free_fn; + decltype(hsa_amd_vmem_handle_create) *hsa_amd_vmem_handle_create_fn; + decltype(hsa_amd_vmem_handle_release) *hsa_amd_vmem_handle_release_fn; + decltype(hsa_amd_vmem_map) *hsa_amd_vmem_map_fn; + decltype(hsa_amd_vmem_unmap) *hsa_amd_vmem_unmap_fn; + decltype(hsa_amd_vmem_set_access) *hsa_amd_vmem_set_access_fn; + decltype(hsa_amd_vmem_get_access) *hsa_amd_vmem_get_access_fn; + decltype(hsa_amd_vmem_export_shareable_handle) + *hsa_amd_vmem_export_shareable_handle_fn; + decltype(hsa_amd_vmem_import_shareable_handle) + *hsa_amd_vmem_import_shareable_handle_fn; + decltype(hsa_amd_vmem_retain_alloc_handle) + *hsa_amd_vmem_retain_alloc_handle_fn; + decltype(hsa_amd_vmem_get_alloc_properties_from_handle) + *hsa_amd_vmem_get_alloc_properties_from_handle_fn; + decltype(hsa_amd_agent_set_async_scratch_limit) + *hsa_amd_agent_set_async_scratch_limit_fn; + decltype(hsa_amd_queue_get_info) *hsa_amd_queue_get_info_fn; + decltype(hsa_amd_vmem_address_reserve_align) + *hsa_amd_vmem_address_reserve_align_fn; +}; + +// Table to export HSA Core Runtime Apis +struct CoreApiTable { + ApiTableVersion version; + decltype(hsa_init) *hsa_init_fn; + decltype(hsa_shut_down) *hsa_shut_down_fn; + decltype(hsa_system_get_info) *hsa_system_get_info_fn; + decltype(hsa_system_extension_supported) *hsa_system_extension_supported_fn; + decltype(hsa_system_get_extension_table) *hsa_system_get_extension_table_fn; + decltype(hsa_iterate_agents) *hsa_iterate_agents_fn; + decltype(hsa_agent_get_info) *hsa_agent_get_info_fn; + decltype(hsa_queue_create) *hsa_queue_create_fn; + decltype(hsa_soft_queue_create) *hsa_soft_queue_create_fn; + decltype(hsa_queue_destroy) *hsa_queue_destroy_fn; + decltype(hsa_queue_inactivate) *hsa_queue_inactivate_fn; + decltype(hsa_queue_load_read_index_scacquire) + *hsa_queue_load_read_index_scacquire_fn; + decltype(hsa_queue_load_read_index_relaxed) + *hsa_queue_load_read_index_relaxed_fn; + decltype(hsa_queue_load_write_index_scacquire) + *hsa_queue_load_write_index_scacquire_fn; + decltype(hsa_queue_load_write_index_relaxed) + *hsa_queue_load_write_index_relaxed_fn; + decltype(hsa_queue_store_write_index_relaxed) + *hsa_queue_store_write_index_relaxed_fn; + decltype(hsa_queue_store_write_index_screlease) + *hsa_queue_store_write_index_screlease_fn; + decltype(hsa_queue_cas_write_index_scacq_screl) + *hsa_queue_cas_write_index_scacq_screl_fn; + decltype(hsa_queue_cas_write_index_scacquire) + *hsa_queue_cas_write_index_scacquire_fn; + decltype(hsa_queue_cas_write_index_relaxed) + *hsa_queue_cas_write_index_relaxed_fn; + decltype(hsa_queue_cas_write_index_screlease) + *hsa_queue_cas_write_index_screlease_fn; + decltype(hsa_queue_add_write_index_scacq_screl) + *hsa_queue_add_write_index_scacq_screl_fn; + decltype(hsa_queue_add_write_index_scacquire) + *hsa_queue_add_write_index_scacquire_fn; + decltype(hsa_queue_add_write_index_relaxed) + *hsa_queue_add_write_index_relaxed_fn; + decltype(hsa_queue_add_write_index_screlease) + *hsa_queue_add_write_index_screlease_fn; + decltype(hsa_queue_store_read_index_relaxed) + *hsa_queue_store_read_index_relaxed_fn; + decltype(hsa_queue_store_read_index_screlease) + *hsa_queue_store_read_index_screlease_fn; + decltype(hsa_agent_iterate_regions) *hsa_agent_iterate_regions_fn; + decltype(hsa_region_get_info) *hsa_region_get_info_fn; + decltype(hsa_agent_get_exception_policies) + *hsa_agent_get_exception_policies_fn; + decltype(hsa_agent_extension_supported) *hsa_agent_extension_supported_fn; + decltype(hsa_memory_register) *hsa_memory_register_fn; + decltype(hsa_memory_deregister) *hsa_memory_deregister_fn; + decltype(hsa_memory_allocate) *hsa_memory_allocate_fn; + decltype(hsa_memory_free) *hsa_memory_free_fn; + decltype(hsa_memory_copy) *hsa_memory_copy_fn; + decltype(hsa_memory_assign_agent) *hsa_memory_assign_agent_fn; + decltype(hsa_signal_create) *hsa_signal_create_fn; + decltype(hsa_signal_destroy) *hsa_signal_destroy_fn; + decltype(hsa_signal_load_relaxed) *hsa_signal_load_relaxed_fn; + decltype(hsa_signal_load_scacquire) *hsa_signal_load_scacquire_fn; + decltype(hsa_signal_store_relaxed) *hsa_signal_store_relaxed_fn; + decltype(hsa_signal_store_screlease) *hsa_signal_store_screlease_fn; + decltype(hsa_signal_wait_relaxed) *hsa_signal_wait_relaxed_fn; + decltype(hsa_signal_wait_scacquire) *hsa_signal_wait_scacquire_fn; + decltype(hsa_signal_and_relaxed) *hsa_signal_and_relaxed_fn; + decltype(hsa_signal_and_scacquire) *hsa_signal_and_scacquire_fn; + decltype(hsa_signal_and_screlease) *hsa_signal_and_screlease_fn; + decltype(hsa_signal_and_scacq_screl) *hsa_signal_and_scacq_screl_fn; + decltype(hsa_signal_or_relaxed) *hsa_signal_or_relaxed_fn; + decltype(hsa_signal_or_scacquire) *hsa_signal_or_scacquire_fn; + decltype(hsa_signal_or_screlease) *hsa_signal_or_screlease_fn; + decltype(hsa_signal_or_scacq_screl) *hsa_signal_or_scacq_screl_fn; + decltype(hsa_signal_xor_relaxed) *hsa_signal_xor_relaxed_fn; + decltype(hsa_signal_xor_scacquire) *hsa_signal_xor_scacquire_fn; + decltype(hsa_signal_xor_screlease) *hsa_signal_xor_screlease_fn; + decltype(hsa_signal_xor_scacq_screl) *hsa_signal_xor_scacq_screl_fn; + decltype(hsa_signal_exchange_relaxed) *hsa_signal_exchange_relaxed_fn; + decltype(hsa_signal_exchange_scacquire) *hsa_signal_exchange_scacquire_fn; + decltype(hsa_signal_exchange_screlease) *hsa_signal_exchange_screlease_fn; + decltype(hsa_signal_exchange_scacq_screl) *hsa_signal_exchange_scacq_screl_fn; + decltype(hsa_signal_add_relaxed) *hsa_signal_add_relaxed_fn; + decltype(hsa_signal_add_scacquire) *hsa_signal_add_scacquire_fn; + decltype(hsa_signal_add_screlease) *hsa_signal_add_screlease_fn; + decltype(hsa_signal_add_scacq_screl) *hsa_signal_add_scacq_screl_fn; + decltype(hsa_signal_subtract_relaxed) *hsa_signal_subtract_relaxed_fn; + decltype(hsa_signal_subtract_scacquire) *hsa_signal_subtract_scacquire_fn; + decltype(hsa_signal_subtract_screlease) *hsa_signal_subtract_screlease_fn; + decltype(hsa_signal_subtract_scacq_screl) *hsa_signal_subtract_scacq_screl_fn; + decltype(hsa_signal_cas_relaxed) *hsa_signal_cas_relaxed_fn; + decltype(hsa_signal_cas_scacquire) *hsa_signal_cas_scacquire_fn; + decltype(hsa_signal_cas_screlease) *hsa_signal_cas_screlease_fn; + decltype(hsa_signal_cas_scacq_screl) *hsa_signal_cas_scacq_screl_fn; + + //===--- Instruction Set Architecture -----------------------------------===// + + decltype(hsa_isa_from_name) *hsa_isa_from_name_fn; + // Deprecated since v1.1. + decltype(hsa_isa_get_info) *hsa_isa_get_info_fn; + // Deprecated since v1.1. + decltype(hsa_isa_compatible) *hsa_isa_compatible_fn; + + //===--- Code Objects (deprecated) --------------------------------------===// + + // Deprecated since v1.1. + decltype(hsa_code_object_serialize) *hsa_code_object_serialize_fn; + // Deprecated since v1.1. + decltype(hsa_code_object_deserialize) *hsa_code_object_deserialize_fn; + // Deprecated since v1.1. + decltype(hsa_code_object_destroy) *hsa_code_object_destroy_fn; + // Deprecated since v1.1. + decltype(hsa_code_object_get_info) *hsa_code_object_get_info_fn; + // Deprecated since v1.1. + decltype(hsa_code_object_get_symbol) *hsa_code_object_get_symbol_fn; + // Deprecated since v1.1. + decltype(hsa_code_symbol_get_info) *hsa_code_symbol_get_info_fn; + // Deprecated since v1.1. + decltype(hsa_code_object_iterate_symbols) *hsa_code_object_iterate_symbols_fn; + + //===--- Executable -----------------------------------------------------===// + + // Deprecated since v1.1. + decltype(hsa_executable_create) *hsa_executable_create_fn; + decltype(hsa_executable_destroy) *hsa_executable_destroy_fn; + // Deprecated since v1.1. + decltype(hsa_executable_load_code_object) *hsa_executable_load_code_object_fn; + decltype(hsa_executable_freeze) *hsa_executable_freeze_fn; + decltype(hsa_executable_get_info) *hsa_executable_get_info_fn; + decltype(hsa_executable_global_variable_define) + *hsa_executable_global_variable_define_fn; + decltype(hsa_executable_agent_global_variable_define) + *hsa_executable_agent_global_variable_define_fn; + decltype(hsa_executable_readonly_variable_define) + *hsa_executable_readonly_variable_define_fn; + decltype(hsa_executable_validate) *hsa_executable_validate_fn; + // Deprecated since v1.1. + decltype(hsa_executable_get_symbol) *hsa_executable_get_symbol_fn; + decltype(hsa_executable_symbol_get_info) *hsa_executable_symbol_get_info_fn; + // Deprecated since v1.1. + decltype(hsa_executable_iterate_symbols) *hsa_executable_iterate_symbols_fn; + + //===--- Runtime Notifications ------------------------------------------===// + + decltype(hsa_status_string) *hsa_status_string_fn; + + // Start HSA v1.1 additions + decltype(hsa_extension_get_name) *hsa_extension_get_name_fn; + decltype(hsa_system_major_extension_supported) + *hsa_system_major_extension_supported_fn; + decltype(hsa_system_get_major_extension_table) + *hsa_system_get_major_extension_table_fn; + decltype(hsa_agent_major_extension_supported) + *hsa_agent_major_extension_supported_fn; + decltype(hsa_cache_get_info) *hsa_cache_get_info_fn; + decltype(hsa_agent_iterate_caches) *hsa_agent_iterate_caches_fn; + decltype(hsa_signal_silent_store_relaxed) *hsa_signal_silent_store_relaxed_fn; + decltype(hsa_signal_silent_store_screlease) + *hsa_signal_silent_store_screlease_fn; + decltype(hsa_signal_group_create) *hsa_signal_group_create_fn; + decltype(hsa_signal_group_destroy) *hsa_signal_group_destroy_fn; + decltype(hsa_signal_group_wait_any_scacquire) + *hsa_signal_group_wait_any_scacquire_fn; + decltype(hsa_signal_group_wait_any_relaxed) + *hsa_signal_group_wait_any_relaxed_fn; + + //===--- Instruction Set Architecture - HSA v1.1 additions --------------===// + + decltype(hsa_agent_iterate_isas) *hsa_agent_iterate_isas_fn; + decltype(hsa_isa_get_info_alt) *hsa_isa_get_info_alt_fn; + decltype(hsa_isa_get_exception_policies) *hsa_isa_get_exception_policies_fn; + decltype(hsa_isa_get_round_method) *hsa_isa_get_round_method_fn; + decltype(hsa_wavefront_get_info) *hsa_wavefront_get_info_fn; + decltype(hsa_isa_iterate_wavefronts) *hsa_isa_iterate_wavefronts_fn; + + //===--- Code Objects (deprecated) - HSA v1.1 additions -----------------===// + + // Deprecated since v1.1. + decltype(hsa_code_object_get_symbol_from_name) + *hsa_code_object_get_symbol_from_name_fn; + + //===--- Executable - HSA v1.1 additions --------------------------------===// + + decltype(hsa_code_object_reader_create_from_file) + *hsa_code_object_reader_create_from_file_fn; + decltype(hsa_code_object_reader_create_from_memory) + *hsa_code_object_reader_create_from_memory_fn; + decltype(hsa_code_object_reader_destroy) *hsa_code_object_reader_destroy_fn; + decltype(hsa_executable_create_alt) *hsa_executable_create_alt_fn; + decltype(hsa_executable_load_program_code_object) + *hsa_executable_load_program_code_object_fn; + decltype(hsa_executable_load_agent_code_object) + *hsa_executable_load_agent_code_object_fn; + decltype(hsa_executable_validate_alt) *hsa_executable_validate_alt_fn; + decltype(hsa_executable_get_symbol_by_name) + *hsa_executable_get_symbol_by_name_fn; + decltype(hsa_executable_iterate_agent_symbols) + *hsa_executable_iterate_agent_symbols_fn; + decltype(hsa_executable_iterate_program_symbols) + *hsa_executable_iterate_program_symbols_fn; +}; + +// Table to export HSA Apis from Core Runtime, Amd Extensions +// Finalizer and Images +struct HsaApiTable { + + // Version of Hsa Api Table + ApiTableVersion version; + + // Table of function pointers to HSA Core Runtime + CoreApiTable *core_; + + // Table of function pointers to AMD extensions + AmdExtTable *amd_ext_; + + // Table of function pointers to HSA Finalizer Extension + FinalizerExtTable *finalizer_ext_; + + // Table of function pointers to HSA Image Extension + ImageExtTable *image_ext_; + + // Table of function pointers for tools to use + ToolsApiTable *tools_; + + // Table of function pointers to AMD PC Sampling Extension + PcSamplingExtTable *pc_sampling_ext_; +}; + +// Structure containing instances of different api tables +struct HsaApiTableContainer { + HsaApiTable root; + CoreApiTable core; + AmdExtTable amd_ext; + FinalizerExtTable finalizer_ext; + ImageExtTable image_ext; + ToolsApiTable tools; + PcSamplingExtTable pc_sampling_ext; + + // Default initialization of a container instance + HsaApiTableContainer() { + root.version.major_id = HSA_API_TABLE_MAJOR_VERSION; + root.version.minor_id = sizeof(HsaApiTable); + root.version.step_id = HSA_API_TABLE_STEP_VERSION; + + core.version.major_id = HSA_CORE_API_TABLE_MAJOR_VERSION; + core.version.minor_id = sizeof(CoreApiTable); + core.version.step_id = HSA_CORE_API_TABLE_STEP_VERSION; + root.core_ = &core; + + amd_ext.version.major_id = HSA_AMD_EXT_API_TABLE_MAJOR_VERSION; + amd_ext.version.minor_id = sizeof(AmdExtTable); + amd_ext.version.step_id = HSA_AMD_EXT_API_TABLE_STEP_VERSION; + root.amd_ext_ = &amd_ext; + + finalizer_ext.version.major_id = HSA_FINALIZER_API_TABLE_MAJOR_VERSION; + finalizer_ext.version.minor_id = sizeof(FinalizerExtTable); + finalizer_ext.version.step_id = HSA_FINALIZER_API_TABLE_STEP_VERSION; + root.finalizer_ext_ = &finalizer_ext; + + image_ext.version.major_id = HSA_IMAGE_API_TABLE_MAJOR_VERSION; + image_ext.version.minor_id = sizeof(ImageExtTable); + image_ext.version.step_id = HSA_IMAGE_API_TABLE_STEP_VERSION; + root.image_ext_ = &image_ext; + + tools.version.major_id = HSA_TOOLS_API_TABLE_MAJOR_VERSION; + tools.version.minor_id = sizeof(ToolsApiTable); + tools.version.step_id = HSA_TOOLS_API_TABLE_STEP_VERSION; + root.tools_ = &tools; + + pc_sampling_ext.version.major_id = HSA_PC_SAMPLING_API_TABLE_MAJOR_VERSION; + pc_sampling_ext.version.minor_id = sizeof(PcSamplingExtTable); + pc_sampling_ext.version.step_id = HSA_PC_SAMPLING_API_TABLE_STEP_VERSION; + root.pc_sampling_ext_ = &pc_sampling_ext; + } +}; + +// Api to copy function pointers of a table +static void inline copyApi(void *src, void *dest, size_t size) { + assert(size >= sizeof(ApiTableVersion)); + memcpy((char *)src + sizeof(ApiTableVersion), + (char *)dest + sizeof(ApiTableVersion), + (size - sizeof(ApiTableVersion))); +} + +// Copy Api child tables if valid. +static void inline copyElement(ApiTableVersion *dest, ApiTableVersion *src) { + if (src->major_id && (dest->major_id == src->major_id)) { + dest->step_id = src->step_id; + dest->minor_id = Min(dest->minor_id, src->minor_id); + copyApi(dest, src, dest->minor_id); + } else { + dest->major_id = 0; + dest->minor_id = 0; + dest->step_id = 0; + } +} + +// Copy constructor for all Api tables. The function assumes the +// user has initialized an instance of tables container correctly +// for the Major, Minor and Stepping Ids of Root and Child Api tables. +// The function will overwrite the value of Minor Id by taking the +// minimum of source and destination parameters. It will also overwrite +// the stepping Id with value from source parameter. +static void inline copyTables(const HsaApiTable *src, HsaApiTable *dest) { + // Verify Major Id of source and destination tables match + if (dest->version.major_id != src->version.major_id) { + dest->version.major_id = 0; + dest->version.minor_id = 0; + dest->version.step_id = 0; + return; + } + + // Initialize the stepping id and minor id of root table. For the + // minor id which encodes struct size, take the minimum of source + // and destination parameters + dest->version.step_id = src->version.step_id; + dest->version.minor_id = Min(dest->version.minor_id, src->version.minor_id); + + // Copy child tables if present + if ((offsetof(HsaApiTable, core_) < dest->version.minor_id)) + copyElement(&dest->core_->version, &src->core_->version); + if ((offsetof(HsaApiTable, amd_ext_) < dest->version.minor_id)) + copyElement(&dest->amd_ext_->version, &src->amd_ext_->version); + if ((offsetof(HsaApiTable, finalizer_ext_) < dest->version.minor_id)) + copyElement(&dest->finalizer_ext_->version, &src->finalizer_ext_->version); + if ((offsetof(HsaApiTable, image_ext_) < dest->version.minor_id)) + copyElement(&dest->image_ext_->version, &src->image_ext_->version); + if ((offsetof(HsaApiTable, tools_) < dest->version.minor_id)) + copyElement(&dest->tools_->version, &src->tools_->version); + if ((offsetof(HsaApiTable, pc_sampling_ext_) < dest->version.minor_id)) + copyElement(&dest->pc_sampling_ext_->version, + &src->pc_sampling_ext_->version); +} +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_api_trace_version.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_api_trace_version.h new file mode 100644 index 000000000..db7d82719 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_api_trace_version.h @@ -0,0 +1,68 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2024, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef HSA_RUNTIME_INC_HSA_API_TRACE_VERSION_H +#define HSA_RUNTIME_INC_HSA_API_TRACE_VERSION_H + +// CODE IN THIS FILE **MUST** BE C-COMPATIBLE + +// Major Ids of the Api tables exported by Hsa Core Runtime +#define HSA_API_TABLE_MAJOR_VERSION 0x03 +#define HSA_CORE_API_TABLE_MAJOR_VERSION 0x02 +#define HSA_AMD_EXT_API_TABLE_MAJOR_VERSION 0x02 +#define HSA_FINALIZER_API_TABLE_MAJOR_VERSION 0x02 +#define HSA_IMAGE_API_TABLE_MAJOR_VERSION 0x02 +#define HSA_AQLPROFILE_API_TABLE_MAJOR_VERSION 0x01 +#define HSA_TOOLS_API_TABLE_MAJOR_VERSION 0x01 +#define HSA_PC_SAMPLING_API_TABLE_MAJOR_VERSION 0x01 + +// Step Ids of the Api tables exported by Hsa Core Runtime +#define HSA_API_TABLE_STEP_VERSION 0x01 +#define HSA_CORE_API_TABLE_STEP_VERSION 0x00 +#define HSA_AMD_EXT_API_TABLE_STEP_VERSION 0x03 +#define HSA_FINALIZER_API_TABLE_STEP_VERSION 0x00 +#define HSA_IMAGE_API_TABLE_STEP_VERSION 0x00 +#define HSA_AQLPROFILE_API_TABLE_STEP_VERSION 0x00 +#define HSA_TOOLS_API_TABLE_STEP_VERSION 0x00 +#define HSA_PC_SAMPLING_API_TABLE_STEP_VERSION 0x00 + +#endif // HSA_RUNTIME_INC_HSA_API_TRACE_VERSION_H diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ext_amd.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ext_amd.h new file mode 100644 index 000000000..9568d8d95 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ext_amd.h @@ -0,0 +1,3186 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +// HSA AMD extension. + +#ifndef HSA_RUNTIME_EXT_AMD_H_ +#define HSA_RUNTIME_EXT_AMD_H_ + +#include "hsa.h" +#include "hsa_ext_image.h" +#include "hsa_ven_amd_pc_sampling.h" + +/** + * - 1.0 - initial version + * - 1.1 - dmabuf export + * - 1.2 - hsa_amd_memory_async_copy_on_engine + * - 1.3 - HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_EXTENDED_SCOPE_FINE_GRAINED pool + * - 1.4 - Virtual Memory API + * - 1.5 - hsa_amd_agent_info: HSA_AMD_AGENT_INFO_MEMORY_PROPERTIES + * - 1.6 - Virtual Memory API: hsa_amd_vmem_address_reserve_align + */ +#define HSA_AMD_INTERFACE_VERSION_MAJOR 1 +#define HSA_AMD_INTERFACE_VERSION_MINOR 6 + +#ifdef __cplusplus +extern "C" { +#endif + +/** \addtogroup aql Architected Queuing Language + * @{ + */ + +/** + * @brief Macro to use to determine that a flag is set when querying flags + * within uint8_t[8] types + */ +static __inline__ __attribute__((always_inline)) bool +hsa_flag_isset64(uint8_t *value, uint32_t bit) { + unsigned int index = bit / 8; + unsigned int subBit = bit % 8; + return ((uint8_t *)value)[index] & (1 << subBit); +} + +/** + * @brief A fixed-size type used to represent ::hsa_signal_condition_t + * constants. + */ +typedef uint32_t hsa_signal_condition32_t; + +/** + * @brief AMD vendor specific packet type. + */ +typedef enum { + /** + * Packet used by agents to delay processing of subsequent packets until a + * configurable condition is satisfied by an HSA signal. Only kernel dispatch + * queues created from AMD GPU Agents support this packet. + */ + HSA_AMD_PACKET_TYPE_BARRIER_VALUE = 2, +} hsa_amd_packet_type_t; + +/** + * @brief A fixed-size type used to represent ::hsa_amd_packet_type_t constants. + */ +typedef uint8_t hsa_amd_packet_type8_t; + +/** + * @brief AMD vendor specific AQL packet header + */ +typedef struct hsa_amd_packet_header_s { + /** + * Packet header. Used to configure multiple packet parameters such as the + * packet type. The parameters are described by ::hsa_packet_header_t. + */ + uint16_t header; + + /** + *Format of the vendor specific packet. + */ + hsa_amd_packet_type8_t AmdFormat; + + /** + * Reserved. Must be 0. + */ + uint8_t reserved; +} hsa_amd_vendor_packet_header_t; + +/** + * @brief AMD barrier value packet. Halts packet processing and waits for + * (signal_value & ::mask) ::cond ::value to be satisfied, where signal_value + * is the value of the signal ::signal. + */ +typedef struct hsa_amd_barrier_value_packet_s { + /** + * AMD vendor specific packet header. + */ + hsa_amd_vendor_packet_header_t header; + + /** + * Reserved. Must be 0. + */ + uint32_t reserved0; + + /** + * Dependent signal object. A signal with a handle value of 0 is + * allowed and is interpreted by the packet processor a satisfied + * dependency. + */ + hsa_signal_t signal; + + /** + * Value to compare against. + */ + hsa_signal_value_t value; + + /** + * Bit mask to be combined by bitwise AND with ::signal's value. + */ + hsa_signal_value_t mask; + + /** + * Comparison operation. See ::hsa_signal_condition_t. + */ + hsa_signal_condition32_t cond; + + /** + * Reserved. Must be 0. + */ + uint32_t reserved1; + + /** + * Reserved. Must be 0. + */ + uint64_t reserved2; + + /** + * Reserved. Must be 0. + */ + uint64_t reserved3; + + /** + * Signal used to indicate completion of the job. The application can use the + * special signal handle 0 to indicate that no signal is used. + */ + hsa_signal_t completion_signal; +} hsa_amd_barrier_value_packet_t; + +/** @} */ + +/** + * @brief Enumeration constants added to ::hsa_status_t. + * + * @remark Additions to hsa_status_t + */ +enum { + /** + * The memory pool is invalid. + */ + HSA_STATUS_ERROR_INVALID_MEMORY_POOL = 40, + + /** + * Agent accessed memory beyond the maximum legal address. + */ + HSA_STATUS_ERROR_MEMORY_APERTURE_VIOLATION = 41, + + /** + * Agent executed an invalid shader instruction. + */ + HSA_STATUS_ERROR_ILLEGAL_INSTRUCTION = 42, + + /** + * Agent attempted to access an inaccessible address. + * See hsa_amd_register_system_event_handler and + * HSA_AMD_GPU_MEMORY_FAULT_EVENT for more information on illegal accesses. + */ + HSA_STATUS_ERROR_MEMORY_FAULT = 43, + + /** + * The CU mask was successfully set but the mask attempted to enable a CU + * which was disabled for the process. CUs disabled for the process remain + * disabled. + */ + HSA_STATUS_CU_MASK_REDUCED = 44, + + /** + * Exceeded number of VGPRs available on this agent + */ + HSA_STATUS_ERROR_OUT_OF_REGISTERS = 45, + + /** + * Resource is busy or temporarily unavailable + */ + HSA_STATUS_ERROR_RESOURCE_BUSY = 46, +}; + +/** + * @brief IOMMU version supported + */ +typedef enum { + /** + * IOMMU not supported + */ + HSA_IOMMU_SUPPORT_NONE = 0, + /* IOMMU V1 support is not relevant to user applications, so not reporting it + */ + /** + * IOMMU V2 supported + */ + HSA_IOMMU_SUPPORT_V2 = 1, +} hsa_amd_iommu_version_t; + +/** + * @brief Agent attributes. + */ +typedef enum hsa_amd_agent_info_s { + /** + * Chip identifier. The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_CHIP_ID = 0xA000, + /** + * Size of a cacheline in bytes. The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_CACHELINE_SIZE = 0xA001, + /** + * The number of compute unit available in the agent. The type of this + * attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_COMPUTE_UNIT_COUNT = 0xA002, + /** + * The maximum clock frequency of the agent in MHz. The type of this + * attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_MAX_CLOCK_FREQUENCY = 0xA003, + /** + * Internal driver node identifier. The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_DRIVER_NODE_ID = 0xA004, + /** + * Max number of watch points on memory address ranges to generate exception + * events when the watched addresses are accessed. The type of this + * attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_MAX_ADDRESS_WATCH_POINTS = 0xA005, + /** + * Agent BDF_ID, named LocationID in thunk. The type of this attribute is + * uint32_t. + */ + HSA_AMD_AGENT_INFO_BDFID = 0xA006, + /** + * Memory Interface width, the return value type is uint32_t. + * This attribute is deprecated. + */ + HSA_AMD_AGENT_INFO_MEMORY_WIDTH = 0xA007, + /** + * Max Memory Clock, the return value type is uint32_t. + */ + HSA_AMD_AGENT_INFO_MEMORY_MAX_FREQUENCY = 0xA008, + /** + * Board name of Agent - populated from MarketingName of Kfd Node + * The value is an Ascii string of 64 chars. + */ + HSA_AMD_AGENT_INFO_PRODUCT_NAME = 0xA009, + /** + * Maximum number of waves possible in a Compute Unit. + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_MAX_WAVES_PER_CU = 0xA00A, + /** + * Number of SIMD's per compute unit CU + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_NUM_SIMDS_PER_CU = 0xA00B, + /** + * Number of Shader Engines (SE) in Gpu + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_NUM_SHADER_ENGINES = 0xA00C, + /** + * Number of Shader Arrays Per Shader Engines in Gpu + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_NUM_SHADER_ARRAYS_PER_SE = 0xA00D, + /** + * Address of the HDP flush registers. Use of these registers does not + * conform to the HSA memory model and should be treated with caution. The + * type of this attribute is hsa_amd_hdp_flush_t. + */ + HSA_AMD_AGENT_INFO_HDP_FLUSH = 0xA00E, + /** + * PCIe domain for the agent. Pairs with HSA_AMD_AGENT_INFO_BDFID + * to give the full physical location of the Agent. + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_DOMAIN = 0xA00F, + /** + * Queries for support of cooperative queues. See + * ::HSA_QUEUE_TYPE_COOPERATIVE. The type of this attribute is bool. + */ + HSA_AMD_AGENT_INFO_COOPERATIVE_QUEUES = 0xA010, + /** + * Queries UUID of an agent. The value is an Ascii string with a maximum + * of 21 chars including NUL. The string value consists of two parts: header + * and body. The header identifies device type (GPU, CPU, DSP) while body + * encodes UUID as a 16 digit hex string + * + * Agents that do not support UUID will return the string "GPU-XX" or + * "CPU-XX" or "DSP-XX" depending upon their device type ::hsa_device_type_t + */ + HSA_AMD_AGENT_INFO_UUID = 0xA011, + /** + * Queries for the ASIC revision of an agent. The value is an integer that + * increments for each revision. This can be used by user-level software to + * change how it operates, depending on the hardware version. This allows + * selective workarounds for hardware errata. + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_ASIC_REVISION = 0xA012, + /** + * Queries whether or not the host can directly access SVM memory that is + * physically resident in the agent's local memory. + * The type of this attribute is bool. + */ + HSA_AMD_AGENT_INFO_SVM_DIRECT_HOST_ACCESS = 0xA013, + /** + * Some processors support more CUs than can reliably be used in a cooperative + * dispatch. This queries the count of CUs which are fully enabled for + * cooperative dispatch. + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_COOPERATIVE_COMPUTE_UNIT_COUNT = 0xA014, + /** + * Queries the amount of memory available in bytes accross all global pools + * owned by the agent. + * The type of this attribute is uint64_t. + */ + HSA_AMD_AGENT_INFO_MEMORY_AVAIL = 0xA015, + /** + * Timestamp value increase rate, in Hz. The timestamp (clock) frequency is + * in the range 1-400MHz. + * The type of this attribute is uint64_t. + */ + HSA_AMD_AGENT_INFO_TIMESTAMP_FREQUENCY = 0xA016, + /** + * Queries for the ASIC family ID of an agent. + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_ASIC_FAMILY_ID = 0xA107, + /** + * Queries for the Packet Processor(CP Firmware) ucode version of an agent. + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_UCODE_VERSION = 0xA108, + /** + * Queries for the SDMA engine ucode of an agent. + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_SDMA_UCODE_VERSION = 0xA109, + /** + * Queries the number of SDMA engines. + * If HSA_AMD_AGENT_INFO_NUM_SDMA_XGMI_ENG query returns non-zero, + * this query returns the the number of SDMA engines optimized for + * host to device bidirectional traffic. + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_NUM_SDMA_ENG = 0xA10A, + /** + * Queries the number of additional SDMA engines optimized for D2D xGMI + * copies. The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_NUM_SDMA_XGMI_ENG = 0xA10B, + /** + * Queries for version of IOMMU supported by agent. + * The type of this attribute is hsa_amd_iommu_version_t. + */ + HSA_AMD_AGENT_INFO_IOMMU_SUPPORT = 0xA110, + /** + * Queries for number of XCCs within the agent. + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_NUM_XCC = 0xA111, + /** + * Queries for driver unique identifier. + * The type of this attribute is uint32_t. + */ + HSA_AMD_AGENT_INFO_DRIVER_UID = 0xA112, + /** + * Returns the hsa_agent_t of the nearest CPU agent + * The type of this attribute is hsa_agent_t. + */ + HSA_AMD_AGENT_INFO_NEAREST_CPU = 0xA113, + /** + * Bit-mask indicating memory properties of this agent. A memory property is + * set if the flag bit is set at that position. User may use the + * hsa_flag_isset64 macro to verify whether a flag is set. The type of this + * attribute is uint8_t[8]. + */ + HSA_AMD_AGENT_INFO_MEMORY_PROPERTIES = 0xA114, + /** + * Bit-mask indicating AQL Extensions supported by this agent. An AQL + * extension is set if the flag bit is set at that position. User may use the + * hsa_flag_isset64 macro to verify whether a flag is set. The type of this + * attribute is uint8_t[8]. + */ + HSA_AMD_AGENT_INFO_AQL_EXTENSIONS = 0xA115 /* Not implemented yet */ +} hsa_amd_agent_info_t; + +/** + * @brief Agent memory properties attributes + */ +typedef enum hsa_amd_agent_memory_properties_s { + HSA_AMD_MEMORY_PROPERTY_AGENT_IS_APU = (1 << 0), +} hsa_amd_agent_memory_properties_t; + +/** + * @brief SDMA engine IDs unique by single set bit position. + */ +typedef enum hsa_amd_sdma_engine_id { + HSA_AMD_SDMA_ENGINE_0 = 0x1, + HSA_AMD_SDMA_ENGINE_1 = 0x2, + HSA_AMD_SDMA_ENGINE_2 = 0x4, + HSA_AMD_SDMA_ENGINE_3 = 0x8, + HSA_AMD_SDMA_ENGINE_4 = 0x10, + HSA_AMD_SDMA_ENGINE_5 = 0x20, + HSA_AMD_SDMA_ENGINE_6 = 0x40, + HSA_AMD_SDMA_ENGINE_7 = 0x80, + HSA_AMD_SDMA_ENGINE_8 = 0x100, + HSA_AMD_SDMA_ENGINE_9 = 0x200, + HSA_AMD_SDMA_ENGINE_10 = 0x400, + HSA_AMD_SDMA_ENGINE_11 = 0x800, + HSA_AMD_SDMA_ENGINE_12 = 0x1000, + HSA_AMD_SDMA_ENGINE_13 = 0x2000, + HSA_AMD_SDMA_ENGINE_14 = 0x4000, + HSA_AMD_SDMA_ENGINE_15 = 0x8000 +} hsa_amd_sdma_engine_id_t; + +typedef struct hsa_amd_hdp_flush_s { + uint32_t *HDP_MEM_FLUSH_CNTL; + uint32_t *HDP_REG_FLUSH_CNTL; +} hsa_amd_hdp_flush_t; + +/** + * @brief Region attributes. + */ +typedef enum hsa_amd_region_info_s { + /** + * Determine if host can access the region. The type of this attribute + * is bool. + */ + HSA_AMD_REGION_INFO_HOST_ACCESSIBLE = 0xA000, + /** + * Base address of the region in flat address space. + */ + HSA_AMD_REGION_INFO_BASE = 0xA001, + /** + * Memory Interface width, the return value type is uint32_t. + * This attribute is deprecated. Use HSA_AMD_AGENT_INFO_MEMORY_WIDTH. + */ + HSA_AMD_REGION_INFO_BUS_WIDTH = 0xA002, + /** + * Max Memory Clock, the return value type is uint32_t. + * This attribute is deprecated. Use HSA_AMD_AGENT_INFO_MEMORY_MAX_FREQUENCY. + */ + HSA_AMD_REGION_INFO_MAX_CLOCK_FREQUENCY = 0xA003, +} hsa_amd_region_info_t; + +/** + * @brief Coherency attributes of fine grain region. + */ +typedef enum hsa_amd_coherency_type_s { + /** + * Coherent region. + */ + HSA_AMD_COHERENCY_TYPE_COHERENT = 0, + /** + * Non coherent region. + */ + HSA_AMD_COHERENCY_TYPE_NONCOHERENT = 1 +} hsa_amd_coherency_type_t; + +/** + * @brief Get the coherency type of the fine grain region of an agent. + * + * @param[in] agent A valid agent. + * + * @param[out] type Pointer to a memory location where the HSA runtime will + * store the coherency type of the fine grain region. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p type is NULL. + */ +hsa_status_t HSA_API hsa_amd_coherency_get_type(hsa_agent_t agent, + hsa_amd_coherency_type_t *type); + +/** + * @brief Set the coherency type of the fine grain region of an agent. + * Deprecated. This is supported on KV platforms. For backward compatibility + * other platforms will spuriously succeed. + * + * @param[in] agent A valid agent. + * + * @param[in] type The coherency type to be set. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p type is invalid. + */ +hsa_status_t HSA_API hsa_amd_coherency_set_type(hsa_agent_t agent, + hsa_amd_coherency_type_t type); + +/** + * @brief Structure containing profiling dispatch time information. + * + * Times are reported as ticks in the domain of the HSA system clock. + * The HSA system clock tick and frequency is obtained via hsa_system_get_info. + */ +typedef struct hsa_amd_profiling_dispatch_time_s { + /** + * Dispatch packet processing start time. + */ + uint64_t start; + /** + * Dispatch packet completion time. + */ + uint64_t end; +} hsa_amd_profiling_dispatch_time_t; + +/** + * @brief Structure containing profiling async copy time information. + * + * Times are reported as ticks in the domain of the HSA system clock. + * The HSA system clock tick and frequency is obtained via hsa_system_get_info. + */ +typedef struct hsa_amd_profiling_async_copy_time_s { + /** + * Async copy processing start time. + */ + uint64_t start; + /** + * Async copy completion time. + */ + uint64_t end; +} hsa_amd_profiling_async_copy_time_t; + +/** + * @brief Enable or disable profiling capability of a queue. + * + * @param[in] queue A valid queue. + * + * @param[in] enable 1 to enable profiling. 0 to disable profiling. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_QUEUE The queue is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p queue is NULL. + */ +hsa_status_t HSA_API hsa_amd_profiling_set_profiler_enabled(hsa_queue_t *queue, + int enable); + +/** + * @brief Enable or disable asynchronous memory copy profiling. + * + * @details The runtime will provide the copy processing start timestamp and + * completion timestamp of each call to hsa_amd_memory_async_copy if the + * async copy profiling is enabled prior to the call to + * hsa_amd_memory_async_copy. The completion signal object is used to + * hold the last async copy start and end timestamp. The client can retrieve + * these timestamps via call to hsa_amd_profiling_get_async_copy_time. + * + * @param[in] enable True to enable profiling. False to disable profiling. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Failed on allocating resources + * needed to profile the asynchronous copy. + */ +hsa_status_t HSA_API hsa_amd_profiling_async_copy_enable(bool enable); + +/** + * @brief Retrieve packet processing time stamps. + * + * @param[in] agent The agent with which the signal was last used. For + * instance, if the profiled dispatch packet is dispatched onto queue Q, + * which was created on agent A, then this parameter must be A. + * + * @param[in] signal A signal used as the completion signal of the dispatch + * packet to retrieve time stamps from. This dispatch packet must have been + * issued to a queue with profiling enabled and have already completed. Also + * the signal must not have yet been used in any other packet following the + * completion of the profiled dispatch packet. + * + * @param[out] time Packet processing timestamps in the HSA system clock + * domain. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SIGNAL The signal is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p time is NULL. + */ +hsa_status_t HSA_API +hsa_amd_profiling_get_dispatch_time(hsa_agent_t agent, hsa_signal_t signal, + hsa_amd_profiling_dispatch_time_t *time); + +/** + * @brief Retrieve asynchronous copy timestamps. + * + * @details Async copy profiling is enabled via call to + * hsa_amd_profiling_async_copy_enable. + * + * @param[in] signal A signal used as the completion signal of the call to + * hsa_amd_memory_async_copy. + * + * @param[out] time Async copy processing timestamps in the HSA system clock + * domain. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SIGNAL The signal is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p time is NULL. + */ +hsa_status_t HSA_API hsa_amd_profiling_get_async_copy_time( + hsa_signal_t signal, hsa_amd_profiling_async_copy_time_t *time); + +/** + * @brief Computes the frequency ratio and offset between the agent clock and + * HSA system clock and converts the agent's tick to HSA system domain tick. + * + * @param[in] agent The agent used to retrieve the agent_tick. It is user's + * responsibility to make sure the tick number is from this agent, otherwise, + * the behavior is undefined. + * + * @param[in] agent_tick The tick count retrieved from the specified @p agent. + * + * @param[out] system_tick The translated HSA system domain clock counter tick. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p system_tick is NULL; + */ +hsa_status_t HSA_API hsa_amd_profiling_convert_tick_to_system_domain( + hsa_agent_t agent, uint64_t agent_tick, uint64_t *system_tick); + +/** + * @brief Signal attribute flags. + */ +typedef enum { + /** + * Signal will only be consumed by AMD GPUs. Limits signal consumption to + * AMD GPU agents only. Ignored if @p num_consumers is not zero (all agents). + */ + HSA_AMD_SIGNAL_AMD_GPU_ONLY = 1, + /** + * Signal may be used for interprocess communication. + * IPC signals can be read, written, and waited on from any process. + * Profiling using an IPC enabled signal is only supported in a single process + * at a time. Producing profiling data in one process and consuming it in + * another process is undefined. + */ + HSA_AMD_SIGNAL_IPC = 2, +} hsa_amd_signal_attribute_t; + +/** + * @brief Create a signal with specific attributes. + * + * @param[in] initial_value Initial value of the signal. + * + * @param[in] num_consumers Size of @p consumers. A value of 0 indicates that + * any agent might wait on the signal. + * + * @param[in] consumers List of agents that might consume (wait on) the + * signal. If @p num_consumers is 0, this argument is ignored; otherwise, the + * HSA runtime might use the list to optimize the handling of the signal + * object. If an agent not listed in @p consumers waits on the returned + * signal, the behavior is undefined. The memory associated with @p consumers + * can be reused or freed after the function returns. + * + * @param[in] attributes Requested signal attributes. Multiple signal + * attributes may be requested by combining them with bitwise OR. Requesting no + * attributes + * (@p attributes == 0) results in the same signal as would have been obtained + * via hsa_signal_create. + * + * @param[out] signal Pointer to a memory location where the HSA runtime will + * store the newly created signal handle. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p signal is NULL, @p + * num_consumers is greater than 0 but @p consumers is NULL, or @p consumers + * contains duplicates. + */ +hsa_status_t HSA_API hsa_amd_signal_create(hsa_signal_value_t initial_value, + uint32_t num_consumers, + const hsa_agent_t *consumers, + uint64_t attributes, + hsa_signal_t *signal); + +/** + * @brief Returns a pointer to the value of a signal. + * + * Use of this API does not modify the lifetime of ::signal and any + * hsa_signal_value_t retrieved by this API has lifetime equal to that of + * ::signal. + * + * This API is intended for partial interoperability with non-HSA compatible + * devices and should not be used where HSA interfaces are available. + * + * Use of the signal value must comply with use restritions of ::signal. + * Use may result in data races if the operations performed are not platform + * atomic. Use with HSA_AMD_SIGNAL_AMD_GPU_ONLY or HSA_AMD_SIGNAL_IPC + * attributed signals is required. + * + * @param[in] Signal handle to extract the signal value pointer from. + * + * @param[out] Location where the extracted signal value pointer will be placed. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SIGNAL signal is not a valid hsa_signal_t + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT value_ptr is NULL. + */ +hsa_status_t +hsa_amd_signal_value_pointer(hsa_signal_t signal, + volatile hsa_signal_value_t **value_ptr); + +/** + * @brief Asyncronous signal handler function type. + * + * @details Type definition of callback function to be used with + * hsa_amd_signal_async_handler. This callback is invoked if the associated + * signal and condition are met. The callback receives the value of the signal + * which satisfied the associated wait condition and a user provided value. If + * the callback returns true then the callback will be called again if the + * associated signal and condition are satisfied again. If the callback returns + * false then it will not be called again. + * + * @param[in] value Contains the value of the signal observed by + * hsa_amd_signal_async_handler which caused the signal handler to be invoked. + * + * @param[in] arg Contains the user provided value given when the signal handler + * was registered with hsa_amd_signal_async_handler + * + * @retval true resumes monitoring the signal with this handler (as if calling + * hsa_amd_signal_async_handler again with identical parameters) + * + * @retval false stops monitoring the signal with this handler (handler will + * not be called again for this signal) + * + */ +typedef bool (*hsa_amd_signal_handler)(hsa_signal_value_t value, void *arg); + +/** + * @brief Register asynchronous signal handler function. + * + * @details Allows registering a callback function and user provided value with + * a signal and wait condition. The callback will be invoked if the associated + * signal and wait condition are satisfied. Callbacks will be invoked serially + * but in an arbitrary order so callbacks should be independent of each other. + * After being invoked a callback may continue to wait for its associated signal + * and condition and, possibly, be invoked again. Or the callback may stop + * waiting. If the callback returns true then it will continue waiting and may + * be called again. If false then the callback will not wait again and will not + * be called again for the associated signal and condition. It is possible to + * register the same callback multiple times with the same or different signals + * and/or conditions. Each registration of the callback will be treated entirely + * independently. + * + * @param[in] signal hsa signal to be asynchronously monitored + * + * @param[in] cond condition value to monitor for + * + * @param[in] value signal value used in condition expression + * + * @param[in] handler asynchronous signal handler invoked when signal's + * condition is met + * + * @param[in] arg user provided value which is provided to handler when handler + * is invoked + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SIGNAL signal is not a valid hsa_signal_t + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT handler is invalid (NULL) + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime is out of + * resources or blocking signals are not supported by the HSA driver component. + * + */ +hsa_status_t HSA_API hsa_amd_signal_async_handler( + hsa_signal_t signal, hsa_signal_condition_t cond, hsa_signal_value_t value, + hsa_amd_signal_handler handler, void *arg); + +/** + * @brief Call a function asynchronously + * + * @details Provides access to the runtime's asynchronous event handling thread + * for general asynchronous functions. Functions queued this way are executed + * in the same manner as if they were a signal handler who's signal is + * satisfied. + * + * @param[in] callback asynchronous function to be invoked + * + * @param[in] arg user provided value which is provided to handler when handler + * is invoked + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT handler is invalid (NULL) + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime is out of + * resources or blocking signals are not supported by the HSA driver component. + * + */ +hsa_status_t HSA_API hsa_amd_async_function(void (*callback)(void *arg), + void *arg); + +/** + * @brief Wait for any signal-condition pair to be satisfied. + * + * @details Allows waiting for any of several signal and conditions pairs to be + * satisfied. The function returns the index into the list of signals of the + * first satisfying signal-condition pair. The value of the satisfying signal's + * value is returned in satisfying_value unless satisfying_value is NULL. This + * function provides only relaxed memory semantics. + */ +uint32_t HSA_API hsa_amd_signal_wait_any( + uint32_t signal_count, hsa_signal_t *signals, hsa_signal_condition_t *conds, + hsa_signal_value_t *values, uint64_t timeout_hint, + hsa_wait_state_t wait_hint, hsa_signal_value_t *satisfying_value); + +/** + * @brief Query image limits. + * + * @param[in] agent A valid agent. + * + * @param[in] attribute HSA image info attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_QUEUE @p value is NULL or @p attribute < + * HSA_EXT_AGENT_INFO_IMAGE_1D_MAX_ELEMENTS or @p attribute > + * HSA_EXT_AGENT_INFO_IMAGE_ARRAY_MAX_LAYERS. + * + */ +hsa_status_t HSA_API hsa_amd_image_get_info_max_dim(hsa_agent_t agent, + hsa_agent_info_t attribute, + void *value); + +/** + * @brief Set a queue's CU affinity mask. + * + * @details Enables the queue to run on only selected CUs. The given mask is + * combined by bitwise AND with any device wide mask in HSA_CU_MASK before + * being applied. + * If num_cu_mask_count is 0 then the request is interpreted as a request to + * enable all CUs and no cu_mask array need be given. + * + * @param[in] queue A pointer to HSA queue. + * + * @param[in] num_cu_mask_count Size of CUMask bit array passed in, in bits. + * + * @param[in] cu_mask Bit-vector representing the CU mask. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_CU_MASK_REDUCED The function was successfully executed + * but the given mask attempted to enable a CU which was disabled by + * HSA_CU_MASK. CUs disabled by HSA_CU_MASK remain disabled. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_QUEUE @p queue is NULL or invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p num_cu_mask_count is not + * a multiple of 32 or @p num_cu_mask_count is not 0 and cu_mask is NULL. + * Devices with work group processors must even-index contiguous pairwise + * CU enable e.g. 0x33(b'110011) is valid while 0x5(0x101) and 0x6(b'0110) + * are invalid. + * + */ +hsa_status_t HSA_API hsa_amd_queue_cu_set_mask(const hsa_queue_t *queue, + uint32_t num_cu_mask_count, + const uint32_t *cu_mask); + +/** + * @brief Retrieve a queue's CU affinity mask. + * + * @details Returns the first num_cu_mask_count bits of a queue's CU mask. + * Ensure that num_cu_mask_count is at least as large as + * HSA_AMD_AGENT_INFO_COMPUTE_UNIT_COUNT to retrieve the entire mask. + * + * @param[in] queue A pointer to HSA queue. + * + * @param[in] num_cu_mask_count Size of CUMask bit array passed in, in bits. + * + * @param[out] cu_mask Bit-vector representing the CU mask. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_QUEUE @p queue is NULL or invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p num_cu_mask_count is 0, not + * a multiple of 32 or @p cu_mask is NULL. + * + */ +hsa_status_t HSA_API hsa_amd_queue_cu_get_mask(const hsa_queue_t *queue, + uint32_t num_cu_mask_count, + uint32_t *cu_mask); + +/** + * @brief Memory segments associated with a memory pool. + */ +typedef enum { + /** + * Global segment. Used to hold data that is shared by all agents. + */ + HSA_AMD_SEGMENT_GLOBAL = 0, + /** + * Read-only segment. Used to hold data that remains constant during the + * execution of a kernel. + */ + HSA_AMD_SEGMENT_READONLY = 1, + /** + * Private segment. Used to hold data that is local to a single work-item. + */ + HSA_AMD_SEGMENT_PRIVATE = 2, + /** + * Group segment. Used to hold data that is shared by the work-items of a + * work-group. + */ + HSA_AMD_SEGMENT_GROUP = 3, +} hsa_amd_segment_t; + +/** + * @brief A memory pool encapsulates physical storage on an agent + * along with a memory access model. + * + * @details A memory pool encapsulates a physical partition of an agent's + * memory system along with a memory access model. Division of a single + * memory system into separate pools allows querying each partition's access + * path properties (see ::hsa_amd_agent_memory_pool_get_info). Allocations + * from a pool are preferentially bound to that pool's physical partition. + * Binding to the pool's preferential physical partition may not be + * possible or persistent depending on the system's memory policy + * and/or state which is beyond the scope of HSA APIs. + * + * For example, a multi-node NUMA memory system may be represented by multiple + * pool's with each pool providing size and access path information for the + * partition it represents. Allocations from a pool are preferentially bound + * to the pool's partition (which in this example is a NUMA node) while + * following its memory access model. The actual placement may vary or migrate + * due to the system's NUMA policy and state, which is beyond the scope of + * HSA APIs. + */ +typedef struct hsa_amd_memory_pool_s { + /** + * Opaque handle. + */ + uint64_t handle; +} hsa_amd_memory_pool_t; + +typedef enum hsa_amd_memory_pool_global_flag_s { + /** + * The application can use allocations in the memory pool to store kernel + * arguments, and provide the values for the kernarg segment of + * a kernel dispatch. + */ + HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT = 1, + /** + * Updates to memory in this pool conform to HSA memory consistency model. + * If this flag is set, then ::HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_COARSE_GRAINED + * must not be set. + */ + HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_FINE_GRAINED = 2, + /** + * Writes to memory in this pool can be performed by a single agent at a time. + */ + HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_COARSE_GRAINED = 4, + + /** Updates to memory in this memory pool have extended scope, acting as + * system-scope atomics for variables in memory regions of this type. + * Note: On non-compliant systems, device-specific actions may be required + * for system-scope coherence. */ + HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_EXTENDED_SCOPE_FINE_GRAINED = 8, + +} hsa_amd_memory_pool_global_flag_t; + +typedef enum hsa_amd_memory_pool_location_s { + /** + * This memory pool resides on the host (CPU) + */ + HSA_AMD_MEMORY_POOL_LOCATION_CPU = 0, + /** + * This memory pool resides on a GPU + */ + HSA_AMD_MEMORY_POOL_LOCATION_GPU = 1 +} hsa_amd_memory_pool_location_t; + +/** + * @brief Memory pool features. + */ +typedef enum { + /** + * Segment where the memory pool resides. The type of this attribute is + * ::hsa_amd_segment_t. + */ + HSA_AMD_MEMORY_POOL_INFO_SEGMENT = 0, + /** + * Flag mask. The value of this attribute is undefined if the value of + * ::HSA_AMD_MEMORY_POOL_INFO_SEGMENT is not ::HSA_AMD_SEGMENT_GLOBAL. The + * type of this attribute is uint32_t, a bit-field of + * ::hsa_amd_memory_pool_global_flag_t + * values. + */ + HSA_AMD_MEMORY_POOL_INFO_GLOBAL_FLAGS = 1, + /** + * Size of this pool, in bytes. The type of this attribute is size_t. + */ + HSA_AMD_MEMORY_POOL_INFO_SIZE = 2, + /** + * Indicates whether memory in this pool can be allocated using + * ::hsa_amd_memory_pool_allocate. The type of this attribute is bool. + * + * The value of this flag is always false for memory pools in the group and + * private segments. + */ + HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALLOWED = 5, + /** + * Allocation granularity of buffers allocated by + * ::hsa_amd_memory_pool_allocate + * in this memory pool. The size of a buffer allocated in this pool is a + * multiple of the value of this attribute. While this is the minimum size of + * allocation allowed, it is recommened to use + * HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_REC_GRANULE to obtain the + * recommended allocation granularity size for this pool. The value of this + * attribute is only defined if + * ::HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALLOWED is true for + * this pool. The type of this attribute is size_t. + */ + HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE = 6, + /** + * Alignment of buffers allocated by ::hsa_amd_memory_pool_allocate in this + * pool. The value of this attribute is only defined if + * ::HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALLOWED is true for this pool, and + * must be a power of 2. The type of this attribute is size_t. + */ + HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALIGNMENT = 7, + /** + * This memory_pool can be made directly accessible by all the agents in the + * system (::hsa_amd_agent_memory_pool_get_info does not return + * ::HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED for any agent). The type of this + * attribute is bool. + */ + HSA_AMD_MEMORY_POOL_INFO_ACCESSIBLE_BY_ALL = 15, + /** + * Maximum aggregate allocation size in bytes. The type of this attribute + * is size_t. + */ + HSA_AMD_MEMORY_POOL_INFO_ALLOC_MAX_SIZE = 16, + /** + * Location of this memory pool. The type of this attribute + * is hsa_amd_memory_pool_location_t. + */ + HSA_AMD_MEMORY_POOL_INFO_LOCATION = 17, + /** + * Internal block size for allocations. This would also be the recommended + * granularity size for allocations as this prevents internal fragmentation. + * The value of this attribute is only defined if + * ::HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALLOWED is true for this pool. + * The size of this attribute is size_t. + */ + HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_REC_GRANULE = 18, +} hsa_amd_memory_pool_info_t; + +/** + * @brief Memory pool flag used to specify allocation directives + * + */ +typedef enum hsa_amd_memory_pool_flag_s { + /** + * Allocates memory that conforms to standard HSA memory consistency model + */ + HSA_AMD_MEMORY_POOL_STANDARD_FLAG = 0, + /** + * Allocates fine grain memory type where memory ordering is per point to + * point connection. Atomic memory operations on these memory buffers are not + * guaranteed to be visible at system scope. + */ + HSA_AMD_MEMORY_POOL_PCIE_FLAG = (1 << 0), + /** + * Allocates physically contiguous memory + */ + HSA_AMD_MEMORY_POOL_CONTIGUOUS_FLAG = (1 << 1), + +} hsa_amd_memory_pool_flag_t; + +/** + * @brief Get the current value of an attribute of a memory pool. + * + * @param[in] memory_pool A valid memory pool. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to a application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + */ +hsa_status_t HSA_API +hsa_amd_memory_pool_get_info(hsa_amd_memory_pool_t memory_pool, + hsa_amd_memory_pool_info_t attribute, void *value); + +/** + * @brief Iterate over the memory pools associated with a given agent, and + * invoke an application-defined callback on every iteration. + * + * @details An agent can directly access buffers located in some memory pool, or + * be enabled to access them by the application (see + * ::hsa_amd_agents_allow_access), yet that memory pool may not be returned by + * this function for that given agent. + * + * A memory pool of fine-grained type must be associated only with the host. + * + * @param[in] agent A valid agent. + * + * @param[in] callback Callback to be invoked on the same thread that called + * ::hsa_amd_agent_iterate_memory_pools, serially, once per memory pool that is + * associated with the agent. The HSA runtime passes two arguments to the + * callback: the memory pool, and the application data. If @p callback + * returns a status other than ::HSA_STATUS_SUCCESS for a particular iteration, + * the traversal stops and ::hsa_amd_agent_iterate_memory_pools returns that + * status value. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t HSA_API hsa_amd_agent_iterate_memory_pools( + hsa_agent_t agent, + hsa_status_t (*callback)(hsa_amd_memory_pool_t memory_pool, void *data), + void *data); + +/** + * @brief Allocate a block of memory (or buffer) in the specified pool. + * + * @param[in] memory_pool Memory pool where to allocate memory from. The memory + * pool must have the ::HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALLOWED flag set. + * + * @param[in] size Allocation size, in bytes. Must not be zero. This value is + * rounded up to the nearest multiple of + * ::HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE in @p memory_pool. + * + * @param[in] flags A bit-field that is used to specify allocation + * directives. + * + * @param[out] ptr Pointer to the location where to store the base virtual + * address of + * the allocated block. The returned base address is aligned to the value of + * ::HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_ALIGNMENT in @p memory_pool. If the + * allocation fails, the returned value is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES No memory is available. + * + * @retval ::HSA_STATUS_ERROR_INVALID_MEMORY_POOL The memory pool is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION The host is not allowed to + * allocate memory in @p memory_pool, or @p size is greater than + * the value of HSA_AMD_MEMORY_POOL_INFO_ALLOC_MAX_SIZE in @p memory_pool. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p ptr is NULL, or @p size is 0, + * or flags is not 0. + * + */ +hsa_status_t HSA_API hsa_amd_memory_pool_allocate( + hsa_amd_memory_pool_t memory_pool, size_t size, uint32_t flags, void **ptr); + +/** + * @brief Deallocate a block of memory previously allocated using + * ::hsa_amd_memory_pool_allocate. + * + * @param[in] ptr Pointer to a memory block. If @p ptr does not match a value + * previously returned by ::hsa_amd_memory_pool_allocate, the behavior is + * undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + */ +hsa_status_t HSA_API hsa_amd_memory_pool_free(void *ptr); + +/** + * @brief Asynchronously copy a block of memory from the location pointed to by + * @p src on the @p src_agent to the memory block pointed to by @p dst on the @p + * dst_agent. + * Because the DMA engines used may not be in the same coherency domain, the + * caller must ensure that buffers are system-level coherent. In general this + * requires the sending device to have released the buffer to system scope prior + * to executing the copy API and the receiving device must execute a system + * scope acquire fence prior to use of the destination buffer. + * + * @param[out] dst Buffer where the content is to be copied. + * + * @param[in] dst_agent Agent associated with the @p dst. The agent must be able + * to directly access both the source and destination buffers in their current + * locations. May be zero in which case the runtime will attempt to discover the + * destination agent. Discovery may have variable and/or high latency. + * + * @param[in] src A valid pointer to the source of data to be copied. The source + * buffer must not overlap with the destination buffer, otherwise the copy will + * succeed but contents of @p dst is undefined. + * + * @param[in] src_agent Agent associated with the @p src. The agent must be able + * to directly access both the source and destination buffers in their current + * locations. May be zero in which case the runtime will attempt to discover the + * destination agent. Discovery may have variable and/or high latency. + * + * @param[in] size Number of bytes to copy. If @p size is 0, no copy is + * performed and the function returns success. Copying a number of bytes larger + * than the size of the buffers pointed by @p dst or @p src results in undefined + * behavior. + * + * @param[in] num_dep_signals Number of dependent signals. Can be 0. + * + * @param[in] dep_signals List of signals that must be waited on before the copy + * operation starts. The copy will start after every signal has been observed + * with the value 0. The dependent signal should not include completion signal + * from hsa_amd_memory_async_copy operation to be issued in future as that can + * result in a deadlock. If @p num_dep_signals is 0, this argument is ignored. + * + * @param[in] completion_signal Signal used to indicate completion of the copy + * operation. When the copy operation is finished, the value of the signal is + * decremented. The runtime indicates that an error has occurred during the copy + * operation by setting the value of the completion signal to a negative + * number. The signal handle must not be 0. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. The + * application is responsible for checking for asynchronous error conditions + * (see the description of @p completion_signal). + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT An agent is invalid or no discovered + * agent has access. + * + * @retval ::HSA_STATUS_ERROR_INVALID_SIGNAL @p completion_signal is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT The source or destination + * pointers are NULL, or the completion signal is 0. + */ +hsa_status_t HSA_API hsa_amd_memory_async_copy( + void *dst, hsa_agent_t dst_agent, const void *src, hsa_agent_t src_agent, + size_t size, uint32_t num_dep_signals, const hsa_signal_t *dep_signals, + hsa_signal_t completion_signal); + +/** + * @brief Asynchronously copy a block of memory from the location pointed to by + * @p src on the @p src_agent to the memory block pointed to by @p dst on the @p + * dst_agent on engine_id. + * + * WARNING: Concurrent use of this call with hsa_amd_memory_async_copy can + * result in resource conflicts as HSA runtime will auto assign engines with the + * latter call. Approach using both calls concurrently with caution. + * + * All param definitions are identical to hsa_amd_memory_async_copy with the + * exception of engine_id and force_copy_on_sdma. + * + * @param[in] - engine_id Target engine defined by hsa_amd_sdma_engine_id_t. + * Client should use hsa_amd_memory_copy_engine_status first to get the ID + * availability. + * + * @param[in] - force_copy_on_sdma By default, blit kernel copies are used when + * dst_agent == src_agent. Setting this to true will force the copy over SDMA1. + * + * All return definitions are identical to hsa_amd_memory_async_copy with the + * following ammendments: + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT The source or destination + * pointers are NULL, or the completion signal is 0 or engine_id is improperly + * bounded. + */ +hsa_status_t HSA_API hsa_amd_memory_async_copy_on_engine( + void *dst, hsa_agent_t dst_agent, const void *src, hsa_agent_t src_agent, + size_t size, uint32_t num_dep_signals, const hsa_signal_t *dep_signals, + hsa_signal_t completion_signal, hsa_amd_sdma_engine_id_t engine_id, + bool force_copy_on_sdma); +/** + * @brief Reports the availability of SDMA copy engines. + * + * @param[in] dst_agent Destination agent of copy status direction. + * + * @param[in] src_agent Source agent of copy status direction. + * + * @param[out] engine_ids_mask returns available SDMA engine IDs that can be + * masked with hsa_amd_sdma_engine_id_t. + * + * @retval ::HSA_STATUS_SUCCESS Agent has available SDMA engines. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Agent does not have available + * SDMA engines. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT dst_agent and src_agent are the same + * as dst_agent == src_agent is generally used for shader copies. + */ +hsa_status_t HSA_API hsa_amd_memory_copy_engine_status( + hsa_agent_t dst_agent, hsa_agent_t src_agent, uint32_t *engine_ids_mask); + +/* +[Provisional API] +Pitched memory descriptor. +All elements must be 4 byte aligned. Pitch and slice are in bytes. +*/ +typedef struct hsa_pitched_ptr_s { + void *base; + size_t pitch; + size_t slice; +} hsa_pitched_ptr_t; + +/* +[Provisional API] +Copy direction flag. +*/ +typedef enum { + hsaHostToHost = 0, + hsaHostToDevice = 1, + hsaDeviceToHost = 2, + hsaDeviceToDevice = 3 +} hsa_amd_copy_direction_t; + +/* +[Provisional API] +SDMA 3D memory copy API. The same requirements must be met by src and dst as in +hsa_amd_memory_async_copy. +Both src and dst must be directly accessible to the copy_agent during the copy, +src and dst rects must not overlap. CPU agents are not supported. API requires +SDMA and will return an error if SDMA is not available. Offsets and range carry +x in bytes, y and z in rows and layers. +*/ +hsa_status_t HSA_API hsa_amd_memory_async_copy_rect( + const hsa_pitched_ptr_t *dst, const hsa_dim3_t *dst_offset, + const hsa_pitched_ptr_t *src, const hsa_dim3_t *src_offset, + const hsa_dim3_t *range, hsa_agent_t copy_agent, + hsa_amd_copy_direction_t dir, uint32_t num_dep_signals, + const hsa_signal_t *dep_signals, hsa_signal_t completion_signal); + +/** + * @brief Type of accesses to a memory pool from a given agent. + */ +typedef enum { + /** + * The agent cannot directly access any buffer in the memory pool. + */ + HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED = 0, + /** + * The agent can directly access a buffer located in the pool; the application + * does not need to invoke ::hsa_amd_agents_allow_access. + */ + HSA_AMD_MEMORY_POOL_ACCESS_ALLOWED_BY_DEFAULT = 1, + /** + * The agent can directly access a buffer located in the pool, but only if the + * application has previously requested access to that buffer using + * ::hsa_amd_agents_allow_access. + */ + HSA_AMD_MEMORY_POOL_ACCESS_DISALLOWED_BY_DEFAULT = 2 +} hsa_amd_memory_pool_access_t; + +/** + * @brief Properties of the relationship between an agent a memory pool. + */ +typedef enum { + /** + * Hyper-transport bus type. + */ + HSA_AMD_LINK_INFO_TYPE_HYPERTRANSPORT = 0, + + /** + * QPI bus type. + */ + HSA_AMD_LINK_INFO_TYPE_QPI = 1, + + /** + * PCIe bus type. + */ + HSA_AMD_LINK_INFO_TYPE_PCIE = 2, + + /** + * Infiniband bus type. + */ + HSA_AMD_LINK_INFO_TYPE_INFINBAND = 3, + + /** + * xGMI link type. + */ + HSA_AMD_LINK_INFO_TYPE_XGMI = 4 + +} hsa_amd_link_info_type_t; + +/** + * @brief Link properties when accessing the memory pool from the specified + * agent. + */ +typedef struct hsa_amd_memory_pool_link_info_s { + /** + * Minimum transfer latency (rounded to ns). + */ + uint32_t min_latency; + + /** + * Maximum transfer latency (rounded to ns). + */ + uint32_t max_latency; + + /** + * Minimum link interface bandwidth in MB/s. + */ + uint32_t min_bandwidth; + + /** + * Maximum link interface bandwidth in MB/s. + */ + uint32_t max_bandwidth; + + /** + * Support for 32-bit atomic transactions. + */ + bool atomic_support_32bit; + + /** + * Support for 64-bit atomic transactions. + */ + bool atomic_support_64bit; + + /** + * Support for cache coherent transactions. + */ + bool coherent_support; + + /** + * The type of bus/link. + */ + hsa_amd_link_info_type_t link_type; + + /** + * NUMA distance of memory pool relative to querying agent + */ + uint32_t numa_distance; +} hsa_amd_memory_pool_link_info_t; + +/** + * @brief Properties of the relationship between an agent a memory pool. + */ +typedef enum { + /** + * Access to buffers located in the memory pool. The type of this attribute + * is ::hsa_amd_memory_pool_access_t. + * + * An agent can always directly access buffers currently located in a memory + * pool that is associated (the memory_pool is one of the values returned by + * ::hsa_amd_agent_iterate_memory_pools on the agent) with that agent. If the + * buffer is currently located in a memory pool that is not associated with + * the agent, and the value returned by this function for the given + * combination of agent and memory pool is not + * HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED, the application still needs to + * invoke + * ::hsa_amd_agents_allow_access in order to gain direct access to the buffer. + * + * If the given agent can directly access buffers the pool, the result is not + * HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED. If the memory pool is associated + * with the agent, or it is of fined-grained type, the result must not be + * HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED. If the memory pool is not + * associated with the agent, and does not reside in the global segment, the + * result must be HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED. + */ + HSA_AMD_AGENT_MEMORY_POOL_INFO_ACCESS = 0, + + /** + * Number of links to hop when accessing the memory pool from the specified + * agent. The value of this attribute is zero if the memory pool is associated + * with the agent, or if the access type is + * HSA_AMD_MEMORY_POOL_ACCESS_NEVER_ALLOWED. The type of this attribute is + * uint32_t. + */ + HSA_AMD_AGENT_MEMORY_POOL_INFO_NUM_LINK_HOPS = 1, + + /** + * Details of each link hop when accessing the memory pool starting from the + * specified agent. The type of this attribute is an array size of + * HSA_AMD_AGENT_MEMORY_POOL_INFO_NUM_LINK_HOPS with each element containing + * ::hsa_amd_memory_pool_link_info_t. + */ + HSA_AMD_AGENT_MEMORY_POOL_INFO_LINK_INFO = 2 + +} hsa_amd_agent_memory_pool_info_t; + +/** + * @brief Get the current value of an attribute of the relationship between an + * agent and a memory pool. + * + * @param[in] agent Agent. + * + * @param[in] memory_pool Memory pool. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to a application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + */ +hsa_status_t HSA_API hsa_amd_agent_memory_pool_get_info( + hsa_agent_t agent, hsa_amd_memory_pool_t memory_pool, + hsa_amd_agent_memory_pool_info_t attribute, void *value); + +/** + * @brief Enable direct access to a buffer from a given set of agents. + * + * @details + * + * Upon return, only the listed agents and the agent associated with the + * buffer's memory pool have direct access to the @p ptr. + * + * Any agent that has access to the buffer before and after the call to + * ::hsa_amd_agents_allow_access will also have access while + * ::hsa_amd_agents_allow_access is in progress. + * + * The caller is responsible for ensuring that each agent in the list + * must be able to access the memory pool containing @p ptr + * (using ::hsa_amd_agent_memory_pool_get_info with + * ::HSA_AMD_AGENT_MEMORY_POOL_INFO_ACCESS attribute), otherwise error code is + * returned. + * + * @param[in] num_agents Size of @p agents. + * + * @param[in] agents List of agents. If @p num_agents is 0, this argument is + * ignored. + * + * @param[in] flags A list of bit-field that is used to specify access + * information in a per-agent basis. This is currently reserved and must be + * NULL. + * + * @param[in] ptr A buffer previously allocated using + * ::hsa_amd_memory_pool_allocate. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p num_agents is 0, or @p agents + * is NULL, @p flags is not NULL, or attempting to enable access to agent(s) + * because @p ptr is allocated from an inaccessible pool. + * + */ +hsa_status_t HSA_API hsa_amd_agents_allow_access(uint32_t num_agents, + const hsa_agent_t *agents, + const uint32_t *flags, + const void *ptr); + +/** + * @brief Query if buffers currently located in some memory pool can be + * relocated to a destination memory pool. + * + * @details If the returned value is non-zero, a migration of a buffer to @p + * dst_memory_pool using ::hsa_amd_memory_migrate may nevertheless fail due to + * resource limitations. + * + * @param[in] src_memory_pool Source memory pool. + * + * @param[in] dst_memory_pool Destination memory pool. + * + * @param[out] result Pointer to a memory location where the result of the query + * is stored. Must not be NULL. If buffers currently located in @p + * src_memory_pool can be relocated to @p dst_memory_pool, the result is + * true. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_MEMORY_POOL One of the memory pools is + * invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p result is NULL. + */ +hsa_status_t HSA_API hsa_amd_memory_pool_can_migrate( + hsa_amd_memory_pool_t src_memory_pool, + hsa_amd_memory_pool_t dst_memory_pool, bool *result); + +/** + * @brief Relocate a buffer to a new memory pool. + * + * @details When a buffer is migrated, its virtual address remains the same but + * its physical contents are moved to the indicated memory pool. + * + * After migration, only the agent associated with the destination pool will + * have access. + * + * The caller is also responsible for ensuring that the allocation in the + * source memory pool where the buffer is currently located can be migrated to + * the specified destination memory pool (using + * ::hsa_amd_memory_pool_can_migrate returns a value of true for the source and + * destination memory pools), otherwise behavior is undefined. + * + * The caller must ensure that the buffer is not accessed while it is migrated. + * + * @param[in] ptr Buffer to be relocated. The buffer must have been released to + * system prior to call this API. The buffer will be released to system upon + * completion. + * + * @param[in] memory_pool Memory pool where to place the buffer. + * + * @param[in] flags A bit-field that is used to specify migration + * information. Must be zero. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_MEMORY_POOL The destination memory pool is + * invalid. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES There is a failure in + * allocating the necessary resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p flags is not 0. + */ +hsa_status_t HSA_API hsa_amd_memory_migrate(const void *ptr, + hsa_amd_memory_pool_t memory_pool, + uint32_t flags); + +/** + * + * @brief Pin a host pointer allocated by C/C++ or OS allocator (i.e. ordinary + * system DRAM) and return a new pointer accessible by the @p agents. If the @p + * host_ptr overlaps with previously locked memory, then the overlap area is + * kept locked (i.e multiple mappings are permitted). In this case, the same + * input @p host_ptr may give different locked @p agent_ptr and when it does, + * they are not necessarily coherent (i.e. accessing either @p agent_ptr is not + * equivalent). Accesses to @p agent_ptr are coarse grained. + * + * @param[in] host_ptr A buffer allocated by C/C++ or OS allocator. + * + * @param[in] size The size to be locked. + * + * @param[in] agents Array of agent handle to gain access to the @p host_ptr. + * If this parameter is NULL and the @p num_agent is 0, all agents + * in the platform will gain access to the @p host_ptr. + * + * @param[out] agent_ptr Pointer to the location where to store the new address. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES There is a failure in + * allocating the necessary resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT One or more agent in @p agents is + * invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p size is 0 or @p host_ptr or + * @p agent_ptr is NULL or @p agents not NULL but @p num_agent is 0 or @p agents + * is NULL but @p num_agent is not 0. + */ +hsa_status_t HSA_API hsa_amd_memory_lock(void *host_ptr, size_t size, + hsa_agent_t *agents, int num_agent, + void **agent_ptr); + +/** + * + * @brief Pin a host pointer allocated by C/C++ or OS allocator (i.e. ordinary + * system DRAM) and return a new pointer accessible by the @p agents. If the @p + * host_ptr overlaps with previously locked memory, then the overlap area is + * kept locked (i.e. multiple mappings are permitted). In this case, the same + * input @p host_ptr may give different locked @p agent_ptr and when it does, + * they are not necessarily coherent (i.e. accessing either @p agent_ptr is not + * equivalent). Acesses to the memory via @p agent_ptr have the same access + * properties as memory allocated from + * @p pool as determined by ::hsa_amd_memory_pool_get_info and + * ::hsa_amd_agent_memory_pool_get_info (ex. coarse/fine grain, platform atomic + * support, link info). Physical composition and placement of the memory (ex. + * page size, NUMA binding) is not changed. + * + * @param[in] host_ptr A buffer allocated by C/C++ or OS allocator. + * + * @param[in] size The size to be locked. + * + * @param[in] agents Array of agent handle to gain access to the @p host_ptr. + * If this parameter is NULL and the @p num_agent is 0, all agents + * in the platform will gain access to the @p host_ptr. + * + * @param[in] pool Global memory pool owned by a CPU agent. + * + * @param[in] flags A bit-field that is used to specify allocation + * directives. Reserved parameter, must be 0. + * + * @param[out] agent_ptr Pointer to the location where to store the new address. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES There is a failure in + * allocating the necessary resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT One or more agent in @p agents is + * invalid or can not access @p pool. + * + * @retval ::HSA_STATUS_ERROR_INVALID_MEMORY_POOL @p pool is invalid or not + * owned by a CPU agent. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p size is 0 or @p host_ptr or + * @p agent_ptr is NULL or @p agents not NULL but @p num_agent is 0 or @p agents + * is NULL but @p num_agent is not 0 or flags is not 0. + */ +hsa_status_t HSA_API hsa_amd_memory_lock_to_pool( + void *host_ptr, size_t size, hsa_agent_t *agents, int num_agent, + hsa_amd_memory_pool_t pool, uint32_t flags, void **agent_ptr); + +/** + * + * @brief Unpin the host pointer previously pinned via ::hsa_amd_memory_lock or + * ::hsa_amd_memory_lock_to_pool. + * + * @details The behavior is undefined if the host pointer being unpinned does + * not match previous pinned address or if the host pointer was already + * deallocated. + * + * @param[in] host_ptr A buffer allocated by C/C++ or OS allocator that was + * pinned previously via ::hsa_amd_memory_lock or ::hsa_amd_memory_lock_to_pool. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + */ +hsa_status_t HSA_API hsa_amd_memory_unlock(void *host_ptr); + +/** + * @brief Sets the first @p count of uint32_t of the block of memory pointed by + * @p ptr to the specified @p value. + * + * @param[in] ptr Pointer to the block of memory to fill. + * + * @param[in] value Value to be set. + * + * @param[in] count Number of uint32_t element to be set to the value. + * + * @retval HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT @p ptr is NULL or + * not 4 bytes aligned + * + * @retval HSA_STATUS_ERROR_INVALID_ALLOCATION if the given memory + * region was not allocated with HSA runtime APIs. + * + */ +hsa_status_t HSA_API hsa_amd_memory_fill(void *ptr, uint32_t value, + size_t count); + +/** + * @brief Maps an interop object into the HSA flat address space and establishes + * memory residency. The metadata pointer is valid during the lifetime of the + * map (until hsa_amd_interop_unmap_buffer is called). + * Multiple calls to hsa_amd_interop_map_buffer with the same interop_handle + * result in multiple mappings with potentially different addresses and + * different metadata pointers. Concurrent operations on these addresses are + * not coherent. Memory must be fenced to system scope to ensure consistency, + * between mappings and with any views of this buffer in the originating + * software stack. + * + * @param[in] num_agents Number of agents which require access to the memory + * + * @param[in] agents List of accessing agents. + * + * @param[in] interop_handle Handle of interop buffer (dmabuf handle in Linux) + * + * @param [in] flags Reserved, must be 0 + * + * @param[out] size Size in bytes of the mapped object + * + * @param[out] ptr Base address of the mapped object + * + * @param[out] metadata_size Size of metadata in bytes, may be NULL + * + * @param[out] metadata Pointer to metadata, may be NULL + * + * @retval HSA_STATUS_SUCCESS if successfully mapped + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED if HSA is not initialized + * + * @retval HSA_STATUS_ERROR_OUT_OF_RESOURCES if there is a failure in allocating + * necessary resources + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT all other errors + */ +hsa_status_t HSA_API hsa_amd_interop_map_buffer( + uint32_t num_agents, hsa_agent_t *agents, int interop_handle, + uint32_t flags, size_t *size, void **ptr, size_t *metadata_size, + const void **metadata); + +/** + * @brief Removes a previously mapped interop object from HSA's flat address + * space. Ends lifetime for the mapping's associated metadata pointer. + */ +hsa_status_t HSA_API hsa_amd_interop_unmap_buffer(void *ptr); + +/** + * @brief Encodes an opaque vendor specific image format. The length of data + * depends on the underlying format. This structure must not be copied as its + * true length can not be determined. + */ +typedef struct hsa_amd_image_descriptor_s { + /* + Version number of the descriptor + */ + uint32_t version; + + /* + Vendor and device PCI IDs for the format as VENDOR_ID<<16|DEVICE_ID. + */ + uint32_t deviceID; + + /* + Start of vendor specific data. + */ + uint32_t data[1]; +} hsa_amd_image_descriptor_t; + +/** + * @brief Creates an image from an opaque vendor specific image format. + * Does not modify data at image_data. Intended initially for + * accessing interop images. + * + * @param agent[in] Agent on which to create the image + * + * @param[in] image_descriptor[in] Vendor specific image format + * + * @param[in] image_data Pointer to image backing store + * + * @param[in] access_permission Access permissions for the image object + * + * @param[out] image Created image object. + * + * @retval HSA_STATUS_SUCCESS Image created successfully + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED if HSA is not initialized + * + * @retval HSA_STATUS_ERROR_OUT_OF_RESOURCES if there is a failure in allocating + * necessary resources + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT Bad or mismatched descriptor, + * null image_data, or mismatched access_permission. + */ +hsa_status_t HSA_API hsa_amd_image_create( + hsa_agent_t agent, const hsa_ext_image_descriptor_t *image_descriptor, + const hsa_amd_image_descriptor_t *image_layout, const void *image_data, + hsa_access_permission_t access_permission, hsa_ext_image_t *image); + +/** + * @brief Denotes the type of memory in a pointer info query. + */ +typedef enum { + /* + Memory is not known to the HSA driver. Unallocated or unlocked system memory. + */ + HSA_EXT_POINTER_TYPE_UNKNOWN = 0, + /* + Memory was allocated with an HSA memory allocator. + */ + HSA_EXT_POINTER_TYPE_HSA = 1, + /* + System memory which has been locked for use with an HSA agent. + + Memory of this type is normal malloc'd memory and is always accessible to + the CPU. Pointer info queries may not include CPU agents in the accessible + agents list as the CPU has implicit access. + */ + HSA_EXT_POINTER_TYPE_LOCKED = 2, + /* + Memory originated in a graphics component and is shared with ROCr. + */ + HSA_EXT_POINTER_TYPE_GRAPHICS = 3, + /* + Memory has been shared with the local process via ROCr IPC APIs. + */ + HSA_EXT_POINTER_TYPE_IPC = 4 +} hsa_amd_pointer_type_t; + +/** + * @brief Describes a memory allocation known to ROCr. + * Within a ROCr major version this structure can only grow. + */ +typedef struct hsa_amd_pointer_info_s { + /* + Size in bytes of this structure. Used for version control within a major ROCr + revision. Set to sizeof(hsa_amd_pointer_t) prior to calling + hsa_amd_pointer_info. If the runtime supports an older version of pointer + info then size will be smaller on return. Members starting after the return + value of size will not be updated by hsa_amd_pointer_info. + */ + uint32_t size; + /* + The type of allocation referenced. + */ + hsa_amd_pointer_type_t type; + /* + Base address at which non-host agents may access the allocation. This field is + not meaningful if the type of the allocation is HSA_EXT_POINTER_TYPE_UNKNOWN. + */ + void *agentBaseAddress; + /* + Base address at which the host agent may access the allocation. This field is + not meaningful if the type of the allocation is HSA_EXT_POINTER_TYPE_UNKNOWN. + */ + void *hostBaseAddress; + /* + Size of the allocation. This field is not meaningful if the type of the + allocation is HSA_EXT_POINTER_TYPE_UNKNOWN. + */ + size_t sizeInBytes; + /* + Application provided value. This field is not meaningful if the type of the + allocation is HSA_EXT_POINTER_TYPE_UNKNOWN. + */ + void *userData; + /* + Reports an agent which "owns" (ie has preferred access to) the pool in which + the allocation was made. When multiple agents share equal access to a pool + (ex: multiple CPU agents, or multi-die GPU boards) any such agent may be + returned. This field is not meaningful if the type of the allocation is + HSA_EXT_POINTER_TYPE_UNKNOWN or if this agent is not available in this + process, for e.g if this agent is masked using ROCR_VISIBLE_DEVICES. + */ + hsa_agent_t agentOwner; + /* + Contains a bitfield of hsa_amd_memory_pool_global_flag_t values. + Reports the effective global flags bitmask for the allocation. This field is + not meaningful if the type of the allocation is HSA_EXT_POINTER_TYPE_UNKNOWN. + */ + uint32_t global_flags; +} hsa_amd_pointer_info_t; + +/** + * @brief Retrieves information about the allocation referenced by the given + * pointer. Optionally returns the number and list of agents which can + * directly access the allocation. In case this virtual address is unknown, the + * pointer type returned will be HSA_EXT_POINTER_TYPE_UNKNOWN and the only + * fields that are valid after hsa_amd_pointer_info returns are size and type. + * + * @param[in] ptr Pointer which references the allocation to retrieve info for. + * + * @param[in, out] info Pointer to structure to be filled with allocation info. + * Data member size must be set to the size of the structure prior to calling + * hsa_amd_pointer_info. On return size will be set to the size of the + * pointer info structure supported by the runtime, if smaller. Members + * beyond the returned value of size will not be updated by the API. + * Must not be NULL. + * + * @param[in] alloc Function pointer to an allocator used to allocate the + * @p accessible array. If NULL @p accessible will not be returned. + * + * @param[out] num_agents_accessible Recieves the count of agents in + * @p accessible. If NULL @p accessible will not be returned. + * + * @param[out] accessible Recieves a pointer to the array, allocated by @p + * alloc, holding the list of agents which may directly access the allocation. + * May be NULL. + * + * @retval HSA_STATUS_SUCCESS Info retrieved successfully + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED if HSA is not initialized + * + * @retval HSA_STATUS_ERROR_OUT_OF_RESOURCES if there is a failure in allocating + * necessary resources + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT NULL in @p ptr or @p info. + */ +hsa_status_t HSA_API hsa_amd_pointer_info(const void *ptr, + hsa_amd_pointer_info_t *info, + void *(*alloc)(size_t), + uint32_t *num_agents_accessible, + hsa_agent_t **accessible); + +/** + * @brief Associates an arbitrary pointer with an allocation known to ROCr. + * The pointer can be fetched by hsa_amd_pointer_info in the userData field. + * + * @param[in] ptr Pointer to the first byte of an allocation known to ROCr + * with which to associate @p userdata. + * + * @param[in] userdata Abitrary pointer to associate with the allocation. + * + * @retval HSA_STATUS_SUCCESS @p userdata successfully stored. + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED if HSA is not initialized + * + * @retval HSA_STATUS_ERROR_OUT_OF_RESOURCES if there is a failure in allocating + * necessary resources + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT @p ptr is not known to ROCr. + */ +hsa_status_t HSA_API hsa_amd_pointer_info_set_userdata(const void *ptr, + void *userdata); + +/** + * @brief 256-bit process independent identifier for a ROCr shared memory + * allocation. + */ +typedef struct hsa_amd_ipc_memory_s { + uint32_t handle[8]; +} hsa_amd_ipc_memory_t; + +/** + * @brief Prepares an allocation for interprocess sharing and creates a + * handle of type hsa_amd_ipc_memory_t uniquely identifying the allocation. A + * handle is valid while the allocation it references remains accessible in + * any process. In general applications should confirm that a shared memory + * region has been attached (via hsa_amd_ipc_memory_attach) in the remote + * process prior to releasing that memory in the local process. + * Repeated calls for the same allocation may, but are not required to, return + * unique handles. The allocation needs to be on memory on an agent of type + * HSA_DEVICE_TYPE_GPU. + * + * @param[in] ptr Pointer to device memory allocated via ROCr APIs to prepare + * for sharing. + * + * @param[in] len Length in bytes of the allocation to share. + * + * @param[out] handle Process independent identifier referencing the shared + * allocation. + * + * @retval HSA_STATUS_SUCCESS allocation is prepared for interprocess sharing. + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED if HSA is not initialized + * + * @retval HSA_STATUS_ERROR_OUT_OF_RESOURCES if there is a failure in allocating + * necessary resources + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT @p ptr does not point to the + * first byte of an allocation made through ROCr, or len is not the full length + * of the allocation or handle is NULL. + */ +hsa_status_t HSA_API hsa_amd_ipc_memory_create(void *ptr, size_t len, + hsa_amd_ipc_memory_t *handle); + +/** + * @brief Imports shared memory into the local process and makes it accessible + * by the given agents. If a shared memory handle is attached multiple times + * in a process each attach may return a different address. Each returned + * address is refcounted and requires a matching number of calls to + * hsa_amd_ipc_memory_detach to release the shared memory mapping. + * + * @param[in] handle Pointer to the identifier for the shared memory. + * + * @param[in] len Length of the shared memory to import. + * Reserved. Must be the full length of the shared allocation in this version. + * + * @param[in] num_agents Count of agents in @p mapping_agents. + * May be zero if all agents are to be allowed access. + * + * @param[in] mapping_agents List of agents to access the shared memory. + * Ignored if @p num_agents is zero. + * + * @param[out] mapped_ptr Recieves a process local pointer to the shared memory. + * + * @retval HSA_STATUS_SUCCESS if memory is successfully imported. + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED if HSA is not initialized + * + * @retval HSA_STATUS_ERROR_OUT_OF_RESOURCES if there is a failure in allocating + * necessary resources + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT @p handle is not valid, @p len is + * incorrect, @p mapped_ptr is NULL, or some agent for which access was + * requested can not access the shared memory. + */ +hsa_status_t HSA_API hsa_amd_ipc_memory_attach( + const hsa_amd_ipc_memory_t *handle, size_t len, uint32_t num_agents, + const hsa_agent_t *mapping_agents, void **mapped_ptr); + +/** + * @brief Decrements the reference count for the shared memory mapping and + * releases access to shared memory imported with hsa_amd_ipc_memory_attach. + * + * @param[in] mapped_ptr Pointer to the first byte of a shared allocation + * imported with hsa_amd_ipc_memory_attach. + * + * @retval HSA_STATUS_SUCCESS if @p mapped_ptr was imported with + * hsa_amd_ipc_memory_attach. + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED if HSA is not initialized + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT @p mapped_ptr was not imported + * with hsa_amd_ipc_memory_attach. + */ +hsa_status_t HSA_API hsa_amd_ipc_memory_detach(void *mapped_ptr); + +/** + * @brief 256-bit process independent identifier for a ROCr IPC signal. + */ +typedef hsa_amd_ipc_memory_t hsa_amd_ipc_signal_t; + +/** + * @brief Obtains an interprocess sharing handle for a signal. The handle is + * valid while the signal it references remains valid in any process. In + * general applications should confirm that the signal has been attached (via + * hsa_amd_ipc_signal_attach) in the remote process prior to destroying that + * signal in the local process. + * Repeated calls for the same signal may, but are not required to, return + * unique handles. + * + * @param[in] signal Signal created with attribute HSA_AMD_SIGNAL_IPC. + * + * @param[out] handle Process independent identifier referencing the shared + * signal. + * + * @retval HSA_STATUS_SUCCESS @p handle is ready to use for interprocess + * sharing. + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED if HSA is not initialized + * + * @retval HSA_STATUS_ERROR_OUT_OF_RESOURCES if there is a failure in allocating + * necessary resources + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT @p signal is not a valid signal + * created with attribute HSA_AMD_SIGNAL_IPC or handle is NULL. + */ +hsa_status_t HSA_API hsa_amd_ipc_signal_create(hsa_signal_t signal, + hsa_amd_ipc_signal_t *handle); + +/** + * @brief Imports an IPC capable signal into the local process. If an IPC + * signal handle is attached multiple times in a process each attach may return + * a different signal handle. Each returned signal handle is refcounted and + * requires a matching number of calls to hsa_signal_destroy to release the + * shared signal. + * + * @param[in] handle Pointer to the identifier for the shared signal. + * + * @param[out] signal Recieves a process local signal handle to the shared + * signal. + * + * @retval HSA_STATUS_SUCCESS if the signal is successfully imported. + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED if HSA is not initialized + * + * @retval HSA_STATUS_ERROR_OUT_OF_RESOURCES if there is a failure in allocating + * necessary resources + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT @p handle is not valid. + */ +hsa_status_t HSA_API hsa_amd_ipc_signal_attach( + const hsa_amd_ipc_signal_t *handle, hsa_signal_t *signal); + +/** + * @brief GPU system event type. + */ +typedef enum hsa_amd_event_type_s { + /* + AMD GPU memory fault. + */ + HSA_AMD_GPU_MEMORY_FAULT_EVENT = 0, + /* + AMD GPU HW Exception. + */ + HSA_AMD_GPU_HW_EXCEPTION_EVENT, +} hsa_amd_event_type_t; + +/** + * @brief Flags denoting the cause of a memory fault. + */ +typedef enum { + // Page not present or supervisor privilege. + HSA_AMD_MEMORY_FAULT_PAGE_NOT_PRESENT = 1 << 0, + // Write access to a read-only page. + HSA_AMD_MEMORY_FAULT_READ_ONLY = 1 << 1, + // Execute access to a page marked NX. + HSA_AMD_MEMORY_FAULT_NX = 1 << 2, + // GPU attempted access to a host only page. + HSA_AMD_MEMORY_FAULT_HOST_ONLY = 1 << 3, + // DRAM ECC failure. + HSA_AMD_MEMORY_FAULT_DRAMECC = 1 << 4, + // Can't determine the exact fault address. + HSA_AMD_MEMORY_FAULT_IMPRECISE = 1 << 5, + // SRAM ECC failure (ie registers, no fault address). + HSA_AMD_MEMORY_FAULT_SRAMECC = 1 << 6, + // GPU reset following unspecified hang. + HSA_AMD_MEMORY_FAULT_HANG = 1U << 31 +} hsa_amd_memory_fault_reason_t; + +/** + * @brief AMD GPU memory fault event data. + */ +typedef struct hsa_amd_gpu_memory_fault_info_s { + /* + The agent where the memory fault occurred. + */ + hsa_agent_t agent; + /* + Virtual address accessed. + */ + uint64_t virtual_address; + /* + Bit field encoding the memory access failure reasons. There could be multiple + bits set for one fault. Bits are defined in hsa_amd_memory_fault_reason_t. + */ + uint32_t fault_reason_mask; +} hsa_amd_gpu_memory_fault_info_t; + +/** + * @brief Flags denoting the type of a HW exception + */ +typedef enum { + // Unused for now + HSA_AMD_HW_EXCEPTION_RESET_TYPE_OTHER = 1 << 0, +} hsa_amd_hw_exception_reset_type_t; + +/** + * @brief Flags denoting the cause of a HW exception + */ +typedef enum { + // GPU Hang + HSA_AMD_HW_EXCEPTION_CAUSE_GPU_HANG = 1 << 0, + // SRAM ECC + HSA_AMD_HW_EXCEPTION_CAUSE_ECC = 1 << 1, +} hsa_amd_hw_exception_reset_cause_t; + +/** + * @brief AMD GPU HW Exception event data. + */ +typedef struct hsa_amd_gpu_hw_exception_info_s { + /* + The agent where the HW exception occurred. + */ + hsa_agent_t agent; + hsa_amd_hw_exception_reset_type_t reset_type; + hsa_amd_hw_exception_reset_cause_t reset_cause; +} hsa_amd_gpu_hw_exception_info_t; + +/** + * @brief AMD GPU event data passed to event handler. + */ +typedef struct hsa_amd_event_s { + /* + The event type. + */ + hsa_amd_event_type_t event_type; + union { + /* + The memory fault info, only valid when @p event_type is + HSA_AMD_GPU_MEMORY_FAULT_EVENT. + */ + hsa_amd_gpu_memory_fault_info_t memory_fault; + /* + The memory fault info, only valid when @p event_type is + HSA_AMD_GPU_HW_EXCEPTION_EVENT. + */ + hsa_amd_gpu_hw_exception_info_t hw_exception; + }; +} hsa_amd_event_t; + +typedef hsa_status_t (*hsa_amd_system_event_callback_t)( + const hsa_amd_event_t *event, void *data); + +/** + * @brief Register AMD GPU event handler. + * + * @param[in] callback Callback to be invoked when an event is triggered. + * The HSA runtime passes two arguments to the callback: @p event + * is defined per event by the HSA runtime, and @p data is the user data. + * + * @param[in] data User data that is passed to @p callback. May be NULL. + * + * @retval HSA_STATUS_SUCCESS The handler has been registered successfully. + * + * @retval HSA_STATUS_ERROR An event handler has already been registered. + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT @p event is invalid. + */ +hsa_status_t HSA_API hsa_amd_register_system_event_handler( + hsa_amd_system_event_callback_t callback, void *data); + +/** + * @brief Per-queue dispatch and wavefront scheduling priority. + */ +typedef enum hsa_amd_queue_priority_s { + /* + Below normal/high priority compute and all graphics + */ + HSA_AMD_QUEUE_PRIORITY_LOW = 0, + /* + Above low priority compute, below high priority compute and all graphics + */ + HSA_AMD_QUEUE_PRIORITY_NORMAL = 1, + /* + Above low/normal priority compute and all graphics + */ + HSA_AMD_QUEUE_PRIORITY_HIGH = 2, +} hsa_amd_queue_priority_t; + +/** + * @brief Modifies the dispatch and wavefront scheduling prioirty for a + * given compute queue. The default is HSA_AMD_QUEUE_PRIORITY_NORMAL. + * + * @param[in] queue Compute queue to apply new priority to. + * + * @param[in] priority Priority to associate with queue. + * + * @retval HSA_STATUS_SUCCESS if priority was changed successfully. + * + * @retval HSA_STATUS_ERROR_INVALID_QUEUE if queue is not a valid + * compute queue handle. + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT if priority is not a valid + * value from hsa_amd_queue_priority_t. + */ +hsa_status_t HSA_API hsa_amd_queue_set_priority( + hsa_queue_t *queue, hsa_amd_queue_priority_t priority); + +/** + * @brief Deallocation notifier function type. + */ +typedef void (*hsa_amd_deallocation_callback_t)(void *ptr, void *user_data); + +/** + * @brief Registers a deallocation notifier monitoring for release of agent + * accessible address @p ptr. If successful, @p callback will be invoked when + * @p ptr is removed from accessibility from all agents. + * + * Notification callbacks are automatically deregistered when they are invoked. + * + * Note: The current version supports notifications of address release + * originating from ::hsa_amd_memory_pool_free. Support for other address + * release APIs will follow. + * + * @param[in] ptr Agent accessible address to monitor for deallocation. Passed + * to @p callback. + * + * @param[in] callback Notifier to be invoked when @p ptr is released from + * agent accessibility. + * + * @param[in] user_data User provided value passed to @p callback. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The notifier registered successfully + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION @p ptr does not refer to a + * valid agent accessible address. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL or @p ptr is + * NULL. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES if there is a failure in + * allocating necessary resources + */ +hsa_status_t HSA_API hsa_amd_register_deallocation_callback( + void *ptr, hsa_amd_deallocation_callback_t callback, void *user_data); + +/** + * @brief Removes a deallocation notifier previously registered with + * ::hsa_amd_register_deallocation_callback. Arguments must be identical to + * those given in ::hsa_amd_register_deallocation_callback. + * + * @param[in] ptr Agent accessible address which was monitored for deallocation. + * + * @param[in] callback Notifier to be removed. + * + * @retval ::HSA_STATUS_SUCCESS The notifier has been removed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT The given notifier was not + * registered. + */ +hsa_status_t HSA_API hsa_amd_deregister_deallocation_callback( + void *ptr, hsa_amd_deallocation_callback_t callback); + +typedef enum hsa_amd_svm_model_s { + /** + * Updates to memory with this attribute conform to HSA memory consistency + * model. + */ + HSA_AMD_SVM_GLOBAL_FLAG_FINE_GRAINED = 0, + /** + * Writes to memory with this attribute can be performed by a single agent + * at a time. + */ + HSA_AMD_SVM_GLOBAL_FLAG_COARSE_GRAINED = 1, + /** + * Memory region queried contains subregions with both + * HSA_AMD_SVM_GLOBAL_FLAG_COARSE_GRAINED and + * HSA_AMD_SVM_GLOBAL_FLAG_FINE_GRAINED attributes. + * + * This attribute can not be used in hsa_amd_svm_attributes_set. It is a + * possible return from hsa_amd_svm_attributes_get indicating that the query + * region contains both coarse and fine grained memory. + */ + HSA_AMD_SVM_GLOBAL_FLAG_INDETERMINATE = 2 +} hsa_amd_svm_model_t; + +typedef enum hsa_amd_svm_attribute_s { + // Memory model attribute. + // Type of this attribute is hsa_amd_svm_model_t. + HSA_AMD_SVM_ATTRIB_GLOBAL_FLAG = 0, + // Marks the range read only. This allows multiple physical copies to be + // placed local to each accessing device. + // Type of this attribute is bool. + HSA_AMD_SVM_ATTRIB_READ_ONLY = 1, + // Automatic migrations should attempt to keep the memory within the xgmi hive + // containing accessible agents. + // Type of this attribute is bool. + HSA_AMD_SVM_ATTRIB_HIVE_LOCAL = 2, + // Page granularity to migrate at once. Page granularity is specified as + // log2(page_count). + // Type of this attribute is uint64_t. + HSA_AMD_SVM_ATTRIB_MIGRATION_GRANULARITY = 3, + // Physical location to prefer when automatic migration occurs. + // Set to the null agent handle (handle == 0) to indicate there + // is no preferred location. + // Type of this attribute is hsa_agent_t. + HSA_AMD_SVM_ATTRIB_PREFERRED_LOCATION = 4, + // This attribute can not be used in ::hsa_amd_svm_attributes_set (see + // ::hsa_amd_svm_prefetch_async). + // Queries the physical location of most recent prefetch command. + // If the prefetch location has not been set or is not uniform across the + // address range then returned hsa_agent_t::handle will be 0. + // Querying this attribute will return the destination agent of the most + // recent ::hsa_amd_svm_prefetch_async targeting the address range. If + // multiple async prefetches have been issued targeting the region and the + // most recently issued prefetch has completed then the query will return + // the location of the most recently completed prefetch. + // Type of this attribute is hsa_agent_t. + HSA_AMD_SVM_ATTRIB_PREFETCH_LOCATION = 5, + // Optimizes with the anticipation that the majority of operations to the + // range will be read operations. + // Type of this attribute is bool. + HSA_AMD_SVM_ATTRIB_READ_MOSTLY = 6, + // Allows the execution on GPU. + // Type of this attribute is bool. + HSA_AMD_SVM_ATTRIB_GPU_EXEC = 7, + // This attribute can not be used in ::hsa_amd_svm_attributes_get. + // Enables an agent for access to the range. Access may incur a page fault + // and associated memory migration. Either this or + // HSA_AMD_SVM_ATTRIB_AGENT_ACCESSIBLE_IN_PLACE is required prior to SVM + // access if HSA_AMD_SYSTEM_INFO_SVM_ACCESSIBLE_BY_DEFAULT is false. + // Type of this attribute is hsa_agent_t. + HSA_AMD_SVM_ATTRIB_AGENT_ACCESSIBLE = 0x200, + // This attribute can not be used in ::hsa_amd_svm_attributes_get. + // Enables an agent for access to the range without page faults. Access + // will not incur a page fault and will not cause access based migration. + // and associated memory migration. Either this or + // HSA_AMD_SVM_ATTRIB_AGENT_ACCESSIBLE is required prior to SVM access if + // HSA_AMD_SYSTEM_INFO_SVM_ACCESSIBLE_BY_DEFAULT is false. + // Type of this attribute is hsa_agent_t. + HSA_AMD_SVM_ATTRIB_AGENT_ACCESSIBLE_IN_PLACE = 0x201, + // This attribute can not be used in ::hsa_amd_svm_attributes_get. + // Denies an agent access to the memory range. Access will cause a terminal + // segfault. + // Type of this attribute is hsa_agent_t. + HSA_AMD_SVM_ATTRIB_AGENT_NO_ACCESS = 0x202, + // This attribute can not be used in ::hsa_amd_svm_attributes_set. + // Returns the access attribute associated with the agent. + // The agent to query must be set in the attribute value field. + // The attribute enum will be replaced with the agent's current access + // attribute for the address range. + // TODO: Clarify KFD return value for non-uniform access attribute. + // Type of this attribute is hsa_agent_t. + HSA_AMD_SVM_ATTRIB_ACCESS_QUERY = 0x203, +} hsa_amd_svm_attribute_t; + +// List type for hsa_amd_svm_attributes_set/get. +typedef struct hsa_amd_svm_attribute_pair_s { + // hsa_amd_svm_attribute_t value. + uint64_t attribute; + // Attribute value. Bit values should be interpreted according to the type + // given in the associated attribute description. + uint64_t value; +} hsa_amd_svm_attribute_pair_t; + +/** + * @brief Sets SVM memory attributes. + * + * If HSA_AMD_SYSTEM_INFO_SVM_ACCESSIBLE_BY_DEFAULT returns false then enabling + * access to an Agent via this API (setting HSA_AMD_SVM_ATTRIB_AGENT_ACCESSIBLE + * or HSA_AMD_SVM_ATTRIB_AGENT_ACCESSIBLE_IN_PLACE) is required prior to SVM + * memory access by that Agent. + * + * Attributes HSA_AMD_SVM_ATTRIB_ACCESS_QUERY and + * HSA_AMD_SVM_ATTRIB_PREFETCH_LOCATION may not be used with this API. + * + * @param[in] ptr Will be aligned down to nearest page boundary. + * + * @param[in] size Will be aligned up to nearest page boundary. + * + * @param[in] attribute_list List of attributes to set for the address range. + * + * @param[in] attribute_count Length of @p attribute_list. + */ +hsa_status_t +hsa_amd_svm_attributes_set(void *ptr, size_t size, + hsa_amd_svm_attribute_pair_t *attribute_list, + size_t attribute_count); + +/** + * @brief Gets SVM memory attributes. + * + * Attributes HSA_AMD_SVM_ATTRIB_AGENT_ACCESSIBLE, + * HSA_AMD_SVM_ATTRIB_AGENT_ACCESSIBLE_IN_PLACE and + * HSA_AMD_SVM_ATTRIB_PREFETCH_LOCATION may not be used with this API. + * + * Note that attribute HSA_AMD_SVM_ATTRIB_ACCESS_QUERY takes as input an + * hsa_agent_t and returns the current access type through its attribute field. + * + * @param[in] ptr Will be aligned down to nearest page boundary. + * + * @param[in] size Will be aligned up to nearest page boundary. + * + * @param[in] attribute_list List of attributes to set for the address range. + * + * @param[in] attribute_count Length of @p attribute_list. + */ +hsa_status_t +hsa_amd_svm_attributes_get(void *ptr, size_t size, + hsa_amd_svm_attribute_pair_t *attribute_list, + size_t attribute_count); + +/** + * @brief Asynchronously migrates memory to an agent. + * + * Schedules memory migration to @p agent when @p dep_signals have been observed + * equal to zero. + * @p completion_signal will decrement when the migration is complete. + * + * @param[in] ptr Will be aligned down to nearest page boundary. + * + * @param[in] size Will be aligned up to nearest page boundary. + * + * @param[in] agent Agent to migrate to. + * + * @param[in] num_dep_signals Number of dependent signals. Can be 0. + * + * @param[in] dep_signals List of signals that must be waited on before the + * migration operation starts. The migration will start after every signal has + * been observed with the value 0. If @p num_dep_signals is 0, this argument is + * ignored. + * + * @param[in] completion_signal Signal used to indicate completion of the + * migration operation. When the migration operation is finished, the value of + * the signal is decremented. The runtime indicates that an error has occurred + * during the copy operation by setting the value of the completion signal to a + * negative number. If no completion signal is required this handle may be null. + */ +hsa_status_t hsa_amd_svm_prefetch_async(void *ptr, size_t size, + hsa_agent_t agent, + uint32_t num_dep_signals, + const hsa_signal_t *dep_signals, + hsa_signal_t completion_signal); + +/** + * @brief Acquire Stream Performance Monitor on an agent + * + * Acquire exclusive use of SPM on @p preferred_agent. + * See hsa_amd_spm_set_dest_buffer to provide a destination buffer to KFD to + * start recording and retrieve this data. + * @param[in] preferred_agent Agent on which to acquire SPM + */ +hsa_status_t hsa_amd_spm_acquire(hsa_agent_t preferred_agent); + +/** + * @brief Release Stream Performance Monitor on an agent + * + * Release exclusive use of SPM on @p preferred_agent. This will stop KFD + * writing SPM data. If a destination buffer is set, then data in the + * destination buffer is available to user when this function returns. + * + * @param[in] preferred_agent Agent on which to release SPM + */ +hsa_status_t hsa_amd_spm_release(hsa_agent_t preferred_agent); + +/** + * @brief Set up the current destination user mode buffer for stream + * performance counter data. KFD will start writing SPM data into the + * destination buffer. KFD will continue to copy data into the current + * destination buffer until any of the following functions are called + * - hsa_amd_spm_release + * - hsa_amd_spm_set_dest_buffer with dest set to NULL + * - hsa_amd_spm_set_dest_buffer with dest set to a new buffer + * + * if @p timeout is non-0, the call will wait for up to @p timeout ms for the + * previous buffer to be filled. If previous buffer to be filled before timeout, + * the @p timeout will be updated value with the time remaining. If the timeout + * is exceeded, the function copies any partial data available into the previous + * user buffer and returns success. User should not access destination data + * while KFD is copying data. If the previous destination buffer was full, then + * @p is_data_loss flag is set. + * @p dest is CPU accessible memory. It could be malloc'ed memory or host + * allocated memory + * + * @param[in] preferred_agent Agent on which to set the dest buffer + * + * @param[in] size_in_bytes size of the buffer + * + * @param[in/out] timeout timeout in milliseconds + * + * @param[out] size_copied number of bytes copied + * + * @param[in] dest destination address. Set to NULL to stop copy on previous + * buffer + * + * @param[out] is_data_loss true is data was lost + */ +hsa_status_t hsa_amd_spm_set_dest_buffer(hsa_agent_t preferred_agent, + size_t size_in_bytes, + uint32_t *timeout, + uint32_t *size_copied, void *dest, + bool *is_data_loss); +/** + * @brief Obtains an OS specific, vendor neutral, handle to a memory allocation. + * + * Obtains an OS specific handle to GPU agent memory. The memory must be part + * of a single allocation from an hsa_amd_memory_pool_t exposed by a GPU Agent. + * The handle may be used with other APIs (e.g. Vulkan) to obtain shared access + * to the allocation. + * + * Shared access to the memory is not guaranteed to be fine grain coherent even + * if the allocation exported is from a fine grain pool. The shared memory + * consistency model will be no stronger than the model exported from, consult + * the importing API to determine the final consistency model. + * + * The allocation's memory remains valid as long as the handle and any mapping + * of the handle remains valid. When the handle and all mappings are closed + * the backing memory will be released for reuse. + * + * @param[in] ptr Pointer to the allocation being exported. + * + * @param[in] size Size in bytes to export following @p ptr. The entire range + * being exported must be contained within a single allocation. + * + * @param[out] dmabuf Pointer to a dma-buf file descriptor holding a reference + * to the allocation. Contents will not be altered in the event of failure. + * + * @param[out] offset Offset in bytes into the memory referenced by the dma-buf + * object at which @p ptr resides. Contents will not be altered in the event + * of failure. + * + * @retval ::HSA_STATUS_SUCCESS Export completed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT One or more arguments is NULL. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION The address range described by + * @p ptr and @p size are not contained within a single allocation. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The allocation described by @p ptr + * and @p size was allocated on a device which can not export memory. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The return file descriptor, + * @p dmabuf, could not be created. + */ +hsa_status_t hsa_amd_portable_export_dmabuf(const void *ptr, size_t size, + int *dmabuf, uint64_t *offset); + +/** + * @brief Closes an OS specific, vendor neutral, handle to a memory allocation. + * + * Closes an OS specific handle to GPU agent memory. + * + * Applications should close a handle after imports are complete. The handle + * is not required to remain open for the lifetime of imported mappings. The + * referenced allocation will remain valid until all handles and mappings + * are closed. + * + * @param[in] dmabuf Handle to be closed. + * + * @retval ::HSA_STATUS_SUCCESS Handle closed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_RESOURCE_FREE A generic error was encountered + * when closing the handle. The handle may have been closed already or an + * async IO error may have occured. + */ +hsa_status_t hsa_amd_portable_close_dmabuf(int dmabuf); + +/** + * @brief Allocate a reserved address range + * + * Reserve a virtual address range. The size must be a multiple of the system + * page size. If it is not possible to allocate the address specified by @p + * address, then @p va will be a different address range. Address range should + * be released by calling hsa_amd_vmem_address_free. + * + * @param[out] va virtual address allocated + * @param[in] size of address range requested + * @param[in] address requested + * @param[in] flags currently unsupported + * + * @retval ::HSA_STATUS_SUCCESS Address range allocated successfully + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Insufficient resources to + * allocate an address range of this size. + * + * Note that this API will be deprecated in a future release and replaced by + * hsa_amd_vmem_address_reserve_align + */ +hsa_status_t hsa_amd_vmem_address_reserve(void **va, size_t size, + uint64_t address, uint64_t flags); + +/** + * @brief Allocate a reserved address range + * + * Reserve a virtual address range. The size must be a multiple of the system + * page size. If it is not possible to allocate the address specified by @p + * address, then @p va will be a different address range. Address range should + * be released by calling hsa_amd_vmem_address_free. + * + * @param[out] va virtual address allocated + * @param[in] size of address range requested + * @param[in] address requested + * @param[in] alignment requested. 0 for default. Must be >= page-size and a + * power of 2 + * @param[in] flags currently unsupported + * + * @retval ::HSA_STATUS_SUCCESS Address range allocated successfully + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Insufficient resources to + * allocate an address range of this size. + */ +hsa_status_t hsa_amd_vmem_address_reserve_align(void **va, size_t size, + uint64_t address, + uint64_t alignment, + uint64_t flags); + +/** + * @brief Free a reserved address range + * + * Free a previously allocated address range. The size must match the size of a + * previously allocated address range. + * + * @param[out] va virtual address to be freed + * @param[in] size of address range + * + * @retval ::HSA_STATUS_SUCCESS Address range released successfully + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION Invalid va specified + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT Invalid size specified + * @retval ::HSA_STATUS_ERROR_RESOURCE_FREE Address range is still in use + * @retval ::HSA_STATUS_ERROR Internal unexpected error + */ +hsa_status_t hsa_amd_vmem_address_free(void *va, size_t size); + +/** + * @brief Struct containing an opaque handle to a memory allocation handle + */ +typedef struct hsa_amd_vmem_alloc_handle_s { + /** + * Opaque handle. Two handles reference the same object of the enclosing type + * if and only if they are equal. + */ + uint64_t handle; +} hsa_amd_vmem_alloc_handle_t; + +typedef enum { + MEMORY_TYPE_NONE, + MEMORY_TYPE_PINNED, +} hsa_amd_memory_type_t; + +/** + * @brief Create a virtual memory handle + * + * Create a virtual memory handle within this pool + * @p size must be a aligned to allocation granule size for this memory pool, + * see HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_GRANULE To minimize internal + * memory fragmentation, align the size to the recommended allocation granule + * size, see HSA_AMD_MEMORY_POOL_INFO_RUNTIME_ALLOC_REC_GRANULE + * + * @param[in] pool memory to use + * @param[in] size of the memory allocation + * @param[in] type of memory + * @param[in] flags - currently unsupported + * @param[out] memory_handle - handle for the allocation + * + * @retval ::HSA_STATUS_SUCCESS memory allocated successfully + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT Invalid arguments + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION This memory pool does not + * support allocations + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Insufficient resources to + * allocate this memory + */ +hsa_status_t +hsa_amd_vmem_handle_create(hsa_amd_memory_pool_t pool, size_t size, + hsa_amd_memory_type_t type, uint64_t flags, + hsa_amd_vmem_alloc_handle_t *memory_handle); + +/** + * @brief Release a virtual memory handle + * + * @param[in] memory handle that was previously allocated + * + * @retval ::HSA_STATUS_SUCCESS Address range allocated successfully + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION Invalid memory handle + */ +hsa_status_t +hsa_amd_vmem_handle_release(hsa_amd_vmem_alloc_handle_t memory_handle); + +/** + * @brief Map a virtual memory handle + * + * Map a virtual memory handle to a reserved address range. The virtual address + * requested must be within a previously reserved address range. @p va and (@p + * va + size) must be must be within (va + size) of the previous allocated + * address range. + * @p size must be equal to size of the @p memory_handle + * hsa_amd_vmem_set_access needs to be called to make the memory accessible to + * specific agents + * + * @param[in] va virtual address range where memory will be mapped + * @param[in] size of memory mapping + * @param[in] in_offset offset into memory. Currently unsupported + * @param[in] memory_handle virtual memory handle to be mapped + * @param[in] flags. Currently unsupported + * + * @retval ::HSA_STATUS_SUCCESS Memory mapped successfully + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT va, size or memory_handle are + * invalid + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Insufficient resources + * + * @retval ::HSA_STATUS_ERROR Unexpected internal error + */ +hsa_status_t hsa_amd_vmem_map(void *va, size_t size, size_t in_offset, + hsa_amd_vmem_alloc_handle_t memory_handle, + uint64_t flags); + +/** + * @brief Unmap a virtual memory handle + * + * Unmap previously mapped virtual address range + * + * @param[in] va virtual address range where memory will be mapped + * @param[in] size of memory mapping + * + * @retval ::HSA_STATUS_SUCCESS Memory backing unmapped successfully + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION memory_handle is invalid + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT size is invalid + * + * @retval ::HSA_STATUS_ERROR Unexpected internal error + */ +hsa_status_t hsa_amd_vmem_unmap(void *va, size_t size); + +typedef struct hsa_amd_memory_access_desc_s { + hsa_access_permission_t permissions; + hsa_agent_t agent_handle; +} hsa_amd_memory_access_desc_t; + +/** + * @brief Make a memory mapping accessible + * + * Make previously mapped virtual address accessible to specific agents. @p size + * must be equal to size of previously mapped virtual memory handle. Calling + * hsa_amd_vmem_set_access multiple times on the same @p va will overwrite + * previous permissions for all agents + * + * @param[in] va previously mapped virtual address + * @param[in] size of memory mapping + * @param[in] desc list of access permissions for each agent + * @param[in] desc_cnt number of elements in desc + * + * @retval ::HSA_STATUS_SUCCESS + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT va, size or memory_handle are + * invalid + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION memory_handle is invalid + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Insufficient resources + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT Invalid agent in desc + * + * @retval ::HSA_STATUS_ERROR Unexpected internal error + */ +hsa_status_t hsa_amd_vmem_set_access(void *va, size_t size, + const hsa_amd_memory_access_desc_t *desc, + size_t desc_cnt); + +/** + * @brief Get current access permissions for memory mapping + * + * Get access permissions for memory mapping for specific agent. + * + * @param[in] va previously mapped virtual address + * @param[in] perms current permissions + * @param[in] agent_handle agent + * + * @retval ::HSA_STATUS_SUCCESS + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT Invalid agent + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION va is not mapped or permissions + * never set for this agent + * + * @retval ::HSA_STATUS_ERROR Unexpected internal error + */ +hsa_status_t hsa_amd_vmem_get_access(void *va, hsa_access_permission_t *perms, + hsa_agent_t agent_handle); + +/** + * @brief Get an exportable shareable handle + * + * Get an exportable shareable handle for a memory_handle. This shareabl handle + * can then be used to re-create a virtual memory handle using + * hsa_amd_vmem_import_shareable_handle. The shareable handle can be transferred + * using mechanisms that support posix file descriptors Once all shareable + * handles are closed, the memory_handle is released. + * + * @param[out] dmabuf_fd shareable handle + * @param[in] handle previously allocated virtual memory handle + * @param[in] flags Currently unsupported + * + * @retval ::HSA_STATUS_SUCCESS + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION Invalid memory handle + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Out of resources + * + * @retval ::HSA_STATUS_ERROR Unexpected internal error + */ +hsa_status_t hsa_amd_vmem_export_shareable_handle( + int *dmabuf_fd, hsa_amd_vmem_alloc_handle_t handle, uint64_t flags); +/** + * @brief Import a shareable handle + * + * Import a shareable handle for a memory handle. Importing a shareable handle + * that has been closed and released results in undefined behavior. + * + * @param[in] dmabuf_fd shareable handle exported with + * hsa_amd_vmem_export_shareable_handle + * @param[out] handle virtual memory handle + * + * @retval ::HSA_STATUS_SUCCESS + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION Invalid memory handle + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Out of resources + * + * @retval ::HSA_STATUS_ERROR Unexpected internal error + */ +hsa_status_t +hsa_amd_vmem_import_shareable_handle(int dmabuf_fd, + hsa_amd_vmem_alloc_handle_t *handle); + +/** + * @brief Returns memory handle for mapped memory + * + * Return a memory handle for previously mapped memory. The handle will be the + * same value of handle used to map the memory. The returned handle must be + * released with corresponding number of calls to hsa_amd_vmem_handle_release. + * + * @param[out] memory_handle memory handle for this mapped address + * @param[in] mapped address + * + * @retval ::HSA_STATUS_SUCCESS + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION Invalid address + */ +hsa_status_t +hsa_amd_vmem_retain_alloc_handle(hsa_amd_vmem_alloc_handle_t *memory_handle, + void *addr); + +/** + * @brief Returns the current allocation properties of a handle + * + * Returns the allocation properties of an existing handle + * + * @param[in] memory_handle memory handle to be queried + * @param[out] pool memory pool that owns this handle + * @param[out] memory type + + * @retval ::HSA_STATUS_SUCCESS + * + * @retval ::HSA_STATUS_ERROR_INVALID_ALLOCATION Invalid memory_handle + */ +hsa_status_t hsa_amd_vmem_get_alloc_properties_from_handle( + hsa_amd_vmem_alloc_handle_t memory_handle, hsa_amd_memory_pool_t *pool, + hsa_amd_memory_type_t *type); + +/** + * @brief Set the asynchronous scratch limit threshold on all the queues for + * this agent. Dispatches that are enqueued on HW queues on this agent that are + * smaller than threshold will not result in a scratch use-once method. + * + * Increasing this threshold will only increase the internal limit and not cause + * immediate allocation of additional scratch memory. Decreasing this threshold + * will result in a release in scratch memory on queues where the current amount + * of allocated scratch exceeds the new limit. + * + * This API is only supported on devices that support asynchronous scratch + * reclaim. + * + * @param[in] agent A valid agent. + * + * @param[in] threshold Threshold size in bytes + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT This agent does not support + * asynchronous scratch reclaim + */ +hsa_status_t HSA_API hsa_amd_agent_set_async_scratch_limit(hsa_agent_t agent, + size_t threshold); + +typedef enum { + /* + * Returns the agent that owns the underlying HW queue. + * The type of this attribute is hsa_agent_t. + */ + HSA_AMD_QUEUE_INFO_AGENT, + /* + * Returns the doorbell ID of the completion signal of the queue + * The type of this attribute is uint64_t. + */ + HSA_AMD_QUEUE_INFO_DOORBELL_ID, +} hsa_queue_info_attribute_t; + +hsa_status_t hsa_amd_queue_get_info(hsa_queue_t *queue, + hsa_queue_info_attribute_t attribute, + void *value); + +#ifdef __cplusplus +} // end extern "C" block +#endif + +#endif // header guard diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ext_finalize.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ext_finalize.h new file mode 100644 index 000000000..4bb44cd9e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ext_finalize.h @@ -0,0 +1,522 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef HSA_RUNTIME_INC_HSA_EXT_FINALIZE_H_ +#define HSA_RUNTIME_INC_HSA_EXT_FINALIZE_H_ + +#include "hsa.h" + +#undef HSA_API +#ifdef HSA_EXPORT_FINALIZER +#define HSA_API HSA_API_EXPORT +#else +#define HSA_API HSA_API_IMPORT +#endif + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct BrigModuleHeader; +typedef struct BrigModuleHeader *BrigModule_t; + +/** \defgroup ext-alt-finalizer-extensions Finalization Extensions + * @{ + */ + +/** + * @brief Enumeration constants added to ::hsa_status_t by this extension. + */ +enum { + /** + * The HSAIL program is invalid. + */ + HSA_EXT_STATUS_ERROR_INVALID_PROGRAM = 0x2000, + /** + * The HSAIL module is invalid. + */ + HSA_EXT_STATUS_ERROR_INVALID_MODULE = 0x2001, + /** + * Machine model or profile of the HSAIL module do not match the machine model + * or profile of the HSAIL program. + */ + HSA_EXT_STATUS_ERROR_INCOMPATIBLE_MODULE = 0x2002, + /** + * The HSAIL module is already a part of the HSAIL program. + */ + HSA_EXT_STATUS_ERROR_MODULE_ALREADY_INCLUDED = 0x2003, + /** + * Compatibility mismatch between symbol declaration and symbol definition. + */ + HSA_EXT_STATUS_ERROR_SYMBOL_MISMATCH = 0x2004, + /** + * The finalization encountered an error while finalizing a kernel or + * indirect function. + */ + HSA_EXT_STATUS_ERROR_FINALIZATION_FAILED = 0x2005, + /** + * Mismatch between a directive in the control directive structure and in + * the HSAIL kernel. + */ + HSA_EXT_STATUS_ERROR_DIRECTIVE_MISMATCH = 0x2006 +}; + +/** @} */ + +/** \defgroup ext-alt-finalizer-program Finalization Program + * @{ + */ + +/** + * @brief HSAIL (BRIG) module. The HSA Programmer's Reference Manual contains + * the definition of the BrigModule_t type. + */ +typedef BrigModule_t hsa_ext_module_t; + +/** + * @brief An opaque handle to a HSAIL program, which groups a set of HSAIL + * modules that collectively define functions and variables used by kernels and + * indirect functions. + */ +typedef struct hsa_ext_program_s { + /** + * Opaque handle. + */ + uint64_t handle; +} hsa_ext_program_t; + +/** + * @brief Create an empty HSAIL program. + * + * @param[in] machine_model Machine model used in the HSAIL program. + * + * @param[in] profile Profile used in the HSAIL program. + * + * @param[in] default_float_rounding_mode Default float rounding mode used in + * the HSAIL program. + * + * @param[in] options Vendor-specific options. May be NULL. + * + * @param[out] program Memory location where the HSA runtime stores the newly + * created HSAIL program handle. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES There is a failure to allocate + * resources required for the operation. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p machine_model is invalid, + * @p profile is invalid, @p default_float_rounding_mode is invalid, or + * @p program is NULL. + */ +hsa_status_t HSA_API hsa_ext_program_create( + hsa_machine_model_t machine_model, hsa_profile_t profile, + hsa_default_float_rounding_mode_t default_float_rounding_mode, + const char *options, hsa_ext_program_t *program); + +/** + * @brief Destroy a HSAIL program. + * + * @details The HSAIL program handle becomes invalid after it has been + * destroyed. Code object handles produced by ::hsa_ext_program_finalize are + * still valid after the HSAIL program has been destroyed, and can be used as + * intended. Resources allocated outside and associated with the HSAIL program + * (such as HSAIL modules that are added to the HSAIL program) can be released + * after the finalization program has been destroyed. + * + * @param[in] program HSAIL program. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_EXT_STATUS_ERROR_INVALID_PROGRAM The HSAIL program is + * invalid. + */ +hsa_status_t HSA_API hsa_ext_program_destroy(hsa_ext_program_t program); + +/** + * @brief Add a HSAIL module to an existing HSAIL program. + * + * @details The HSA runtime does not perform a deep copy of the HSAIL module + * upon addition. Instead, it stores a pointer to the HSAIL module. The + * ownership of the HSAIL module belongs to the application, which must ensure + * that @p module is not released before destroying the HSAIL program. + * + * The HSAIL module is successfully added to the HSAIL program if @p module is + * valid, if all the declarations and definitions for the same symbol are + * compatible, and if @p module specify machine model and profile that matches + * the HSAIL program. + * + * @param[in] program HSAIL program. + * + * @param[in] module HSAIL module. The application can add the same HSAIL module + * to @p program at most once. The HSAIL module must specify the same machine + * model and profile as @p program. If the floating-mode rounding mode of @p + * module is not default, then it should match that of @p program. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES There is a failure to allocate + * resources required for the operation. + * + * @retval ::HSA_EXT_STATUS_ERROR_INVALID_PROGRAM The HSAIL program is invalid. + * + * @retval ::HSA_EXT_STATUS_ERROR_INVALID_MODULE The HSAIL module is invalid. + * + * @retval ::HSA_EXT_STATUS_ERROR_INCOMPATIBLE_MODULE The machine model of @p + * module does not match machine model of @p program, or the profile of @p + * module does not match profile of @p program. + * + * @retval ::HSA_EXT_STATUS_ERROR_MODULE_ALREADY_INCLUDED The HSAIL module is + * already a part of the HSAIL program. + * + * @retval ::HSA_EXT_STATUS_ERROR_SYMBOL_MISMATCH Symbol declaration and symbol + * definition compatibility mismatch. See the symbol compatibility rules in the + * HSA Programming Reference Manual. + */ +hsa_status_t HSA_API hsa_ext_program_add_module(hsa_ext_program_t program, + hsa_ext_module_t module); + +/** + * @brief Iterate over the HSAIL modules in a program, and invoke an + * application-defined callback on every iteration. + * + * @param[in] program HSAIL program. + * + * @param[in] callback Callback to be invoked once per HSAIL module in the + * program. The HSA runtime passes three arguments to the callback: the program, + * a HSAIL module, and the application data. If @p callback returns a status + * other than ::HSA_STATUS_SUCCESS for a particular iteration, the traversal + * stops and ::hsa_ext_program_iterate_modules returns that status value. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_EXT_STATUS_ERROR_INVALID_PROGRAM The program is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t HSA_API hsa_ext_program_iterate_modules( + hsa_ext_program_t program, + hsa_status_t (*callback)(hsa_ext_program_t program, hsa_ext_module_t module, + void *data), + void *data); + +/** + * @brief HSAIL program attributes. + */ +typedef enum { + /** + * Machine model specified when the HSAIL program was created. The type + * of this attribute is ::hsa_machine_model_t. + */ + HSA_EXT_PROGRAM_INFO_MACHINE_MODEL = 0, + /** + * Profile specified when the HSAIL program was created. The type of + * this attribute is ::hsa_profile_t. + */ + HSA_EXT_PROGRAM_INFO_PROFILE = 1, + /** + * Default float rounding mode specified when the HSAIL program was + * created. The type of this attribute is ::hsa_default_float_rounding_mode_t. + */ + HSA_EXT_PROGRAM_INFO_DEFAULT_FLOAT_ROUNDING_MODE = 2 +} hsa_ext_program_info_t; + +/** + * @brief Get the current value of an attribute for a given HSAIL program. + * + * @param[in] program HSAIL program. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behaviour is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_EXT_STATUS_ERROR_INVALID_PROGRAM The HSAIL program is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * HSAIL program attribute, or @p value is NULL. + */ +hsa_status_t HSA_API hsa_ext_program_get_info(hsa_ext_program_t program, + hsa_ext_program_info_t attribute, + void *value); + +/** + * @brief Finalizer-determined call convention. + */ +typedef enum { + /** + * Finalizer-determined call convention. + */ + HSA_EXT_FINALIZER_CALL_CONVENTION_AUTO = -1 +} hsa_ext_finalizer_call_convention_t; + +/** + * @brief Control directives specify low-level information about the + * finalization process. + */ +typedef struct hsa_ext_control_directives_s { + /** + * Bitset indicating which control directives are enabled. The bit assigned to + * a control directive is determined by the corresponding value in + * BrigControlDirective. + * + * If a control directive is disabled, its corresponding field value (if any) + * must be 0. Control directives that are only present or absent (such as + * partial workgroups) have no corresponding field as the presence of the bit + * in this mask is sufficient. + */ + uint64_t control_directives_mask; + /** + * Bitset of HSAIL exceptions that must have the BREAK policy enabled. The bit + * assigned to an HSAIL exception is determined by the corresponding value + * in BrigExceptionsMask. If the kernel contains a enablebreakexceptions + * control directive, the finalizer uses the union of the two masks. + */ + uint16_t break_exceptions_mask; + /** + * Bitset of HSAIL exceptions that must have the DETECT policy enabled. The + * bit assigned to an HSAIL exception is determined by the corresponding value + * in BrigExceptionsMask. If the kernel contains a enabledetectexceptions + * control directive, the finalizer uses the union of the two masks. + */ + uint16_t detect_exceptions_mask; + /** + * Maximum size (in bytes) of dynamic group memory that will be allocated by + * the application for any dispatch of the kernel. If the kernel contains a + * maxdynamicsize control directive, the two values should match. + */ + uint32_t max_dynamic_group_size; + /** + * Maximum number of grid work-items that will be used by the application to + * launch the kernel. If the kernel contains a maxflatgridsize control + * directive, the value of @a max_flat_grid_size must not be greater than the + * value of the directive, and takes precedence. + * + * The value specified for maximum absolute grid size must be greater than or + * equal to the product of the values specified by @a required_grid_size. + * + * If the bit at position BRIG_CONTROL_MAXFLATGRIDSIZE is set in @a + * control_directives_mask, this field must be greater than 0. + */ + uint64_t max_flat_grid_size; + /** + * Maximum number of work-group work-items that will be used by the + * application to launch the kernel. If the kernel contains a + * maxflatworkgroupsize control directive, the value of @a + * max_flat_workgroup_size must not be greater than the value of the + * directive, and takes precedence. + * + * The value specified for maximum absolute grid size must be greater than or + * equal to the product of the values specified by @a required_workgroup_size. + * + * If the bit at position BRIG_CONTROL_MAXFLATWORKGROUPSIZE is set in @a + * control_directives_mask, this field must be greater than 0. + */ + uint32_t max_flat_workgroup_size; + /** + * Reserved. Must be 0. + */ + uint32_t reserved1; + /** + * Grid size that will be used by the application in any dispatch of the + * kernel. If the kernel contains a requiredgridsize control directive, the + * dimensions should match. + * + * The specified grid size must be consistent with @a required_workgroup_size + * and @a required_dim. Also, the product of the three dimensions must not + * exceed @a max_flat_grid_size. Note that the listed invariants must hold + * only if all the corresponding control directives are enabled. + * + * If the bit at position BRIG_CONTROL_REQUIREDGRIDSIZE is set in @a + * control_directives_mask, the three dimension values must be greater than 0. + */ + uint64_t required_grid_size[3]; + /** + * Work-group size that will be used by the application in any dispatch of the + * kernel. If the kernel contains a requiredworkgroupsize control directive, + * the dimensions should match. + * + * The specified work-group size must be consistent with @a required_grid_size + * and @a required_dim. Also, the product of the three dimensions must not + * exceed @a max_flat_workgroup_size. Note that the listed invariants must + * hold only if all the corresponding control directives are enabled. + * + * If the bit at position BRIG_CONTROL_REQUIREDWORKGROUPSIZE is set in @a + * control_directives_mask, the three dimension values must be greater than 0. + */ + hsa_dim3_t required_workgroup_size; + /** + * Number of dimensions that will be used by the application to launch the + * kernel. If the kernel contains a requireddim control directive, the two + * values should match. + * + * The specified dimensions must be consistent with @a required_grid_size and + * @a required_workgroup_size. This invariant must hold only if all the + * corresponding control directives are enabled. + * + * If the bit at position BRIG_CONTROL_REQUIREDDIM is set in @a + * control_directives_mask, this field must be 1, 2, or 3. + */ + uint8_t required_dim; + /** + * Reserved. Must be 0. + */ + uint8_t reserved2[75]; +} hsa_ext_control_directives_t; + +/** + * @brief Finalize an HSAIL program for a given instruction set architecture. + * + * @details Finalize all of the kernels and indirect functions that belong to + * the same HSAIL program for a specific instruction set architecture (ISA). The + * transitive closure of all functions specified by call or scall must be + * defined. Kernels and indirect functions that are being finalized must be + * defined. Kernels and indirect functions that are referenced in kernels and + * indirect functions being finalized may or may not be defined, but must be + * declared. All the global/readonly segment variables that are referenced in + * kernels and indirect functions being finalized may or may not be defined, but + * must be declared. + * + * @param[in] program HSAIL program. + * + * @param[in] isa Instruction set architecture to finalize for. + * + * @param[in] call_convention A call convention used in a finalization. Must + * have a value between ::HSA_EXT_FINALIZER_CALL_CONVENTION_AUTO (inclusive) + * and the value of the attribute ::HSA_ISA_INFO_CALL_CONVENTION_COUNT in @p + * isa (not inclusive). + * + * @param[in] control_directives Low-level control directives that influence + * the finalization process. + * + * @param[in] options Vendor-specific options. May be NULL. + * + * @param[in] code_object_type Type of code object to produce. + * + * @param[out] code_object Code object generated by the Finalizer, which + * contains the machine code for the kernels and indirect functions in the HSAIL + * program. The code object is independent of the HSAIL module that was used to + * generate it. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES There is a failure to allocate + * resources required for the operation. + * + * @retval ::HSA_EXT_STATUS_ERROR_INVALID_PROGRAM The HSAIL program is + * invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ISA @p isa is invalid. + * + * @retval ::HSA_EXT_STATUS_ERROR_DIRECTIVE_MISMATCH The directive in + * the control directive structure and in the HSAIL kernel mismatch, or if the + * same directive is used with a different value in one of the functions used by + * this kernel. + * + * @retval ::HSA_EXT_STATUS_ERROR_FINALIZATION_FAILED The Finalizer + * encountered an error while compiling a kernel or an indirect function. + */ +hsa_status_t HSA_API hsa_ext_program_finalize( + hsa_ext_program_t program, hsa_isa_t isa, int32_t call_convention, + hsa_ext_control_directives_t control_directives, const char *options, + hsa_code_object_type_t code_object_type, hsa_code_object_t *code_object); + +/** @} */ + +#define hsa_ext_finalizer_1_00 + +typedef struct hsa_ext_finalizer_1_00_pfn_s { + hsa_status_t (*hsa_ext_program_create)( + hsa_machine_model_t machine_model, hsa_profile_t profile, + hsa_default_float_rounding_mode_t default_float_rounding_mode, + const char *options, hsa_ext_program_t *program); + + hsa_status_t (*hsa_ext_program_destroy)(hsa_ext_program_t program); + + hsa_status_t (*hsa_ext_program_add_module)(hsa_ext_program_t program, + hsa_ext_module_t module); + + hsa_status_t (*hsa_ext_program_iterate_modules)( + hsa_ext_program_t program, + hsa_status_t (*callback)(hsa_ext_program_t program, + hsa_ext_module_t module, void *data), + void *data); + + hsa_status_t (*hsa_ext_program_get_info)(hsa_ext_program_t program, + hsa_ext_program_info_t attribute, + void *value); + + hsa_status_t (*hsa_ext_program_finalize)( + hsa_ext_program_t program, hsa_isa_t isa, int32_t call_convention, + hsa_ext_control_directives_t control_directives, const char *options, + hsa_code_object_type_t code_object_type, hsa_code_object_t *code_object); +} hsa_ext_finalizer_1_00_pfn_t; + +#ifdef __cplusplus +} // extern "C" block +#endif // __cplusplus + +#endif // HSA_RUNTIME_INC_HSA_EXT_FINALIZE_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ext_image.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ext_image.h new file mode 100644 index 000000000..56f86147f --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ext_image.h @@ -0,0 +1,1402 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef HSA_EXT_IMAGE_H +#define HSA_EXT_IMAGE_H + +#include "hsa.h" + +#undef HSA_API +#ifdef HSA_EXPORT_IMAGES +#define HSA_API HSA_API_EXPORT +#else +#define HSA_API HSA_API_IMPORT +#endif + +#ifdef __cplusplus +extern "C" { +#endif /*__cplusplus*/ + +/** \defgroup ext-images Images and Samplers + * @{ + */ + +/** + * @brief Enumeration constants added to ::hsa_status_t by this extension. + * + * @remark Additions to hsa_status_t + */ +enum { + /** + * Image format is not supported. + */ + HSA_EXT_STATUS_ERROR_IMAGE_FORMAT_UNSUPPORTED = 0x3000, + /** + * Image size is not supported. + */ + HSA_EXT_STATUS_ERROR_IMAGE_SIZE_UNSUPPORTED = 0x3001, + /** + * Image pitch is not supported or invalid. + */ + HSA_EXT_STATUS_ERROR_IMAGE_PITCH_UNSUPPORTED = 0x3002, + /** + * Sampler descriptor is not supported or invalid. + */ + HSA_EXT_STATUS_ERROR_SAMPLER_DESCRIPTOR_UNSUPPORTED = 0x3003 +}; + +/** + * @brief Enumeration constants added to ::hsa_agent_info_t by this + * extension. + * + * @remark Additions to hsa_agent_info_t + */ +enum { + /** + * Maximum number of elements in 1D images. Must be at least 16384. The type + * of this attribute is size_t. + */ + HSA_EXT_AGENT_INFO_IMAGE_1D_MAX_ELEMENTS = 0x3000, + /** + * Maximum number of elements in 1DA images. Must be at least 16384. The type + * of this attribute is size_t. + */ + HSA_EXT_AGENT_INFO_IMAGE_1DA_MAX_ELEMENTS = 0x3001, + /** + * Maximum number of elements in 1DB images. Must be at least 65536. The type + * of this attribute is size_t. + */ + HSA_EXT_AGENT_INFO_IMAGE_1DB_MAX_ELEMENTS = 0x3002, + /** + * Maximum dimensions (width, height) of 2D images, in image elements. The X + * and Y maximums must be at least 16384. The type of this attribute is + * size_t[2]. + */ + HSA_EXT_AGENT_INFO_IMAGE_2D_MAX_ELEMENTS = 0x3003, + /** + * Maximum dimensions (width, height) of 2DA images, in image elements. The X + * and Y maximums must be at least 16384. The type of this attribute is + * size_t[2]. + */ + HSA_EXT_AGENT_INFO_IMAGE_2DA_MAX_ELEMENTS = 0x3004, + /** + * Maximum dimensions (width, height) of 2DDEPTH images, in image + * elements. The X and Y maximums must be at least 16384. The type of this + * attribute is size_t[2]. + */ + HSA_EXT_AGENT_INFO_IMAGE_2DDEPTH_MAX_ELEMENTS = 0x3005, + /** + * Maximum dimensions (width, height) of 2DADEPTH images, in image + * elements. The X and Y maximums must be at least 16384. The type of this + * attribute is size_t[2]. + */ + HSA_EXT_AGENT_INFO_IMAGE_2DADEPTH_MAX_ELEMENTS = 0x3006, + /** + * Maximum dimensions (width, height, depth) of 3D images, in image + * elements. The maximum along any dimension must be at least 2048. The type + * of this attribute is size_t[3]. + */ + HSA_EXT_AGENT_INFO_IMAGE_3D_MAX_ELEMENTS = 0x3007, + /** + * Maximum number of image layers in a image array. Must be at least 2048. The + * type of this attribute is size_t. + */ + HSA_EXT_AGENT_INFO_IMAGE_ARRAY_MAX_LAYERS = 0x3008, + /** + * Maximum number of read-only image handles that can be created for an agent + * at any one time. Must be at least 128. The type of this attribute is + * size_t. + */ + HSA_EXT_AGENT_INFO_MAX_IMAGE_RD_HANDLES = 0x3009, + /** + * Maximum number of write-only and read-write image handles (combined) that + * can be created for an agent at any one time. Must be at least 64. The type + * of this attribute is size_t. + */ + HSA_EXT_AGENT_INFO_MAX_IMAGE_RORW_HANDLES = 0x300A, + /** + * Maximum number of sampler handlers that can be created for an agent at any + * one time. Must be at least 16. The type of this attribute is size_t. + */ + HSA_EXT_AGENT_INFO_MAX_SAMPLER_HANDLERS = 0x300B, + /** + * Image pitch alignment. The agent only supports linear image data + * layouts with a row pitch that is a multiple of this value. Must be + * a power of 2. The type of this attribute is size_t. + */ + HSA_EXT_AGENT_INFO_IMAGE_LINEAR_ROW_PITCH_ALIGNMENT = 0x300C +}; + +/** + * @brief Image handle, populated by ::hsa_ext_image_create or + * ::hsa_ext_image_create_with_layout. Image + * handles are only unique within an agent, not across agents. + * + */ +typedef struct hsa_ext_image_s { + /** + * Opaque handle. For a given agent, two handles reference the same object of + * the enclosing type if and only if they are equal. + */ + uint64_t handle; + +} hsa_ext_image_t; + +/** + * @brief Geometry associated with the image. This specifies the + * number of image dimensions and whether the image is an image + * array. See the Image Geometry section in the HSA + * Programming Reference Manual for definitions on each + * geometry. The enumeration values match the BRIG type @p + * hsa_ext_brig_image_geometry_t. + */ +typedef enum { + /** + * One-dimensional image addressed by width coordinate. + */ + HSA_EXT_IMAGE_GEOMETRY_1D = 0, + + /** + * Two-dimensional image addressed by width and height coordinates. + */ + HSA_EXT_IMAGE_GEOMETRY_2D = 1, + + /** + * Three-dimensional image addressed by width, height, and depth coordinates. + */ + HSA_EXT_IMAGE_GEOMETRY_3D = 2, + + /** + * Array of one-dimensional images with the same size and format. 1D arrays + * are addressed by width and index coordinate. + */ + HSA_EXT_IMAGE_GEOMETRY_1DA = 3, + + /** + * Array of two-dimensional images with the same size and format. 2D arrays + * are addressed by width, height, and index coordinates. + */ + HSA_EXT_IMAGE_GEOMETRY_2DA = 4, + + /** + * One-dimensional image addressed by width coordinate. It has + * specific restrictions compared to ::HSA_EXT_IMAGE_GEOMETRY_1D. An + * image with an opaque image data layout will always use a linear + * image data layout, and one with an explicit image data layout + * must specify ::HSA_EXT_IMAGE_DATA_LAYOUT_LINEAR. + */ + HSA_EXT_IMAGE_GEOMETRY_1DB = 5, + + /** + * Two-dimensional depth image addressed by width and height coordinates. + */ + HSA_EXT_IMAGE_GEOMETRY_2DDEPTH = 6, + + /** + * Array of two-dimensional depth images with the same size and format. 2D + * arrays are addressed by width, height, and index coordinates. + */ + HSA_EXT_IMAGE_GEOMETRY_2DADEPTH = 7 +} hsa_ext_image_geometry_t; + +/** + * @brief Channel type associated with the elements of an image. See + * the Channel Type section in the HSA Programming Reference + * Manual for definitions on each channel type. The + * enumeration values and definition match the BRIG type @p + * hsa_ext_brig_image_channel_type_t. + */ +typedef enum { + HSA_EXT_IMAGE_CHANNEL_TYPE_SNORM_INT8 = 0, + HSA_EXT_IMAGE_CHANNEL_TYPE_SNORM_INT16 = 1, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_INT8 = 2, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_INT16 = 3, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_INT24 = 4, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_SHORT_555 = 5, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_SHORT_565 = 6, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNORM_SHORT_101010 = 7, + HSA_EXT_IMAGE_CHANNEL_TYPE_SIGNED_INT8 = 8, + HSA_EXT_IMAGE_CHANNEL_TYPE_SIGNED_INT16 = 9, + HSA_EXT_IMAGE_CHANNEL_TYPE_SIGNED_INT32 = 10, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8 = 11, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16 = 12, + HSA_EXT_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32 = 13, + HSA_EXT_IMAGE_CHANNEL_TYPE_HALF_FLOAT = 14, + HSA_EXT_IMAGE_CHANNEL_TYPE_FLOAT = 15 +} hsa_ext_image_channel_type_t; + +/** + * @brief A fixed-size type used to represent ::hsa_ext_image_channel_type_t + * constants. + */ +typedef uint32_t hsa_ext_image_channel_type32_t; + +/** + * + * @brief Channel order associated with the elements of an image. See + * the Channel Order section in the HSA Programming Reference + * Manual for definitions on each channel order. The + * enumeration values match the BRIG type @p + * hsa_ext_brig_image_channel_order_t. + */ +typedef enum { + HSA_EXT_IMAGE_CHANNEL_ORDER_A = 0, + HSA_EXT_IMAGE_CHANNEL_ORDER_R = 1, + HSA_EXT_IMAGE_CHANNEL_ORDER_RX = 2, + HSA_EXT_IMAGE_CHANNEL_ORDER_RG = 3, + HSA_EXT_IMAGE_CHANNEL_ORDER_RGX = 4, + HSA_EXT_IMAGE_CHANNEL_ORDER_RA = 5, + HSA_EXT_IMAGE_CHANNEL_ORDER_RGB = 6, + HSA_EXT_IMAGE_CHANNEL_ORDER_RGBX = 7, + HSA_EXT_IMAGE_CHANNEL_ORDER_RGBA = 8, + HSA_EXT_IMAGE_CHANNEL_ORDER_BGRA = 9, + HSA_EXT_IMAGE_CHANNEL_ORDER_ARGB = 10, + HSA_EXT_IMAGE_CHANNEL_ORDER_ABGR = 11, + HSA_EXT_IMAGE_CHANNEL_ORDER_SRGB = 12, + HSA_EXT_IMAGE_CHANNEL_ORDER_SRGBX = 13, + HSA_EXT_IMAGE_CHANNEL_ORDER_SRGBA = 14, + HSA_EXT_IMAGE_CHANNEL_ORDER_SBGRA = 15, + HSA_EXT_IMAGE_CHANNEL_ORDER_INTENSITY = 16, + HSA_EXT_IMAGE_CHANNEL_ORDER_LUMINANCE = 17, + HSA_EXT_IMAGE_CHANNEL_ORDER_DEPTH = 18, + HSA_EXT_IMAGE_CHANNEL_ORDER_DEPTH_STENCIL = 19 +} hsa_ext_image_channel_order_t; + +/** + * @brief A fixed-size type used to represent ::hsa_ext_image_channel_order_t + * constants. + */ +typedef uint32_t hsa_ext_image_channel_order32_t; + +/** + * @brief Image format. + */ +typedef struct hsa_ext_image_format_s { + /** + * Channel type. + */ + hsa_ext_image_channel_type32_t channel_type; + + /** + * Channel order. + */ + hsa_ext_image_channel_order32_t channel_order; +} hsa_ext_image_format_t; + +/** + * @brief Implementation independent image descriptor. + */ +typedef struct hsa_ext_image_descriptor_s { + /** + * Image geometry. + */ + hsa_ext_image_geometry_t geometry; + /** + * Width of the image, in components. + */ + size_t width; + /** + * Height of the image, in components. Only used if the geometry is + * ::HSA_EXT_IMAGE_GEOMETRY_2D, ::HSA_EXT_IMAGE_GEOMETRY_3D, + * HSA_EXT_IMAGE_GEOMETRY_2DA, HSA_EXT_IMAGE_GEOMETRY_2DDEPTH, or + * HSA_EXT_IMAGE_GEOMETRY_2DADEPTH, otherwise must be 0. + */ + size_t height; + /** + * Depth of the image, in components. Only used if the geometry is + * ::HSA_EXT_IMAGE_GEOMETRY_3D, otherwise must be 0. + */ + size_t depth; + /** + * Number of image layers in the image array. Only used if the geometry is + * ::HSA_EXT_IMAGE_GEOMETRY_1DA, ::HSA_EXT_IMAGE_GEOMETRY_2DA, or + * HSA_EXT_IMAGE_GEOMETRY_2DADEPTH, otherwise must be 0. + */ + size_t array_size; + /** + * Image format. + */ + hsa_ext_image_format_t format; +} hsa_ext_image_descriptor_t; + +/** + * @brief Image capability. + */ +typedef enum { + /** + * Images of this geometry, format, and layout are not supported by + * the agent. + */ + HSA_EXT_IMAGE_CAPABILITY_NOT_SUPPORTED = 0x0, + /** + * Read-only images of this geometry, format, and layout are + * supported by the agent. + */ + HSA_EXT_IMAGE_CAPABILITY_READ_ONLY = 0x1, + /** + * Write-only images of this geometry, format, and layout are + * supported by the agent. + */ + HSA_EXT_IMAGE_CAPABILITY_WRITE_ONLY = 0x2, + /** + * Read-write images of this geometry, format, and layout are + * supported by the agent. + */ + HSA_EXT_IMAGE_CAPABILITY_READ_WRITE = 0x4, + /** + * @deprecated Images of this geometry, format, and layout can be accessed + * from read-modify-write atomic operations in the agent. + */ + HSA_EXT_IMAGE_CAPABILITY_READ_MODIFY_WRITE = 0x8, + /** + * Images of this geometry, format, and layout are guaranteed to + * have a consistent data layout regardless of how they are + * accessed by the associated agent. + */ + HSA_EXT_IMAGE_CAPABILITY_ACCESS_INVARIANT_DATA_LAYOUT = 0x10 +} hsa_ext_image_capability_t; + +/** + * @brief Image data layout. + * + * @details An image data layout denotes such aspects of image data + * layout as tiling and organization of channels in memory. Some image + * data layouts may only apply to specific image geometries, formats, + * and access permissions. Different agents may support different + * image layout identifiers, including vendor specific layouts. Note + * that an agent may not support the same image data layout for + * different access permissions to images with the same image + * geometry, size, and format. If multiple agents support the same + * image data layout then it is possible to use separate image handles + * for each agent that references the same image data. + */ + +typedef enum { + /** + * An implementation specific opaque image data layout which can + * vary depending on the agent, geometry, image format, image size, + * and access permissions. + */ + HSA_EXT_IMAGE_DATA_LAYOUT_OPAQUE = 0x0, + /** + * The image data layout is specified by the following rules in + * ascending byte address order. For a 3D image, 2DA image array, + * or 1DA image array, the image data is stored as a linear sequence + * of adjacent 2D image slices, 2D images, or 1D images + * respectively, spaced according to the slice pitch. Each 2D image + * is stored as a linear sequence of adjacent image rows, spaced + * according to the row pitch. Each 1D or 1DB image is stored as a + * single image row. Each image row is stored as a linear sequence + * of image elements. Each image element is stored as a linear + * sequence of image components specified by the left to right + * channel order definition. Each image component is stored using + * the memory type specified by the channel type. + * + * The 1DB image geometry always uses the linear image data layout. + */ + HSA_EXT_IMAGE_DATA_LAYOUT_LINEAR = 0x1 +} hsa_ext_image_data_layout_t; + +/** + * @brief Retrieve the supported image capabilities for a given combination of + * agent, geometry, and image format for an image created with an opaque image + * data layout. + * + * @param[in] agent Agent to be associated with the image handle. + * + * @param[in] geometry Geometry. + * + * @param[in] image_format Pointer to an image format. Must not be NULL. + * + * @param[out] capability_mask Pointer to a memory location where the HSA + * runtime stores a bit-mask of supported image capability + * (::hsa_ext_image_capability_t) values. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p image_format is + * NULL, or @p capability_mask is NULL. + */ +hsa_status_t HSA_API hsa_ext_image_get_capability( + hsa_agent_t agent, hsa_ext_image_geometry_t geometry, + const hsa_ext_image_format_t *image_format, uint32_t *capability_mask); + +/** + * @brief Retrieve the supported image capabilities for a given combination of + * agent, geometry, image format, and image layout for an image created with + * an explicit image data layout. + * + * @param[in] agent Agent to be associated with the image handle. + * + * @param[in] geometry Geometry. + * + * @param[in] image_format Pointer to an image format. Must not be NULL. + * + * @param[in] image_data_layout The image data layout. + * It is invalid to use ::HSA_EXT_IMAGE_DATA_LAYOUT_OPAQUE; use + * ::hsa_ext_image_get_capability instead. + * + * @param[out] capability_mask Pointer to a memory location where the HSA + * runtime stores a bit-mask of supported image capability + * (::hsa_ext_image_capability_t) values. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p image_format is + * NULL, @p image_data_layout is ::HSA_EXT_IMAGE_DATA_LAYOUT_OPAQUE, + * or @p capability_mask is NULL. + */ +hsa_status_t HSA_API hsa_ext_image_get_capability_with_layout( + hsa_agent_t agent, hsa_ext_image_geometry_t geometry, + const hsa_ext_image_format_t *image_format, + hsa_ext_image_data_layout_t image_data_layout, uint32_t *capability_mask); + +/** + * @brief Agent specific image size and alignment requirements, populated by + * ::hsa_ext_image_data_get_info and ::hsa_ext_image_data_get_info_with_layout. + */ +typedef struct hsa_ext_image_data_info_s { + /** + * Image data size, in bytes. + */ + size_t size; + + /** + * Image data alignment, in bytes. Must always be a power of 2. + */ + size_t alignment; + +} hsa_ext_image_data_info_t; + +/** + * @brief Retrieve the image data requirements for a given combination of agent, + * image descriptor, and access permission for an image created with an opaque + * image data layout. + * + * @details The optimal image data size and alignment requirements may + * vary depending on the image attributes specified in @p + * image_descriptor, the @p access_permission, and the @p agent. Also, + * different implementations of the HSA runtime may return different + * requirements for the same input values. + * + * The implementation must return the same image data requirements for + * different access permissions with matching image descriptors as long + * as ::hsa_ext_image_get_capability reports + * ::HSA_EXT_IMAGE_CAPABILITY_ACCESS_INVARIANT_DATA_LAYOUT. Image + * descriptors match if they have the same values, with the exception + * that s-form channel orders match the corresponding non-s-form + * channel order and vice versa. + * + * @param[in] agent Agent to be associated with the image handle. + * + * @param[in] image_descriptor Pointer to an image descriptor. Must not be NULL. + * + * @param[in] access_permission Access permission of the image when + * accessed by @p agent. The access permission defines how the agent + * is allowed to access the image and must match the corresponding + * HSAIL image handle type. The @p agent must support the image format + * specified in @p image_descriptor for the given @p + * access_permission. + * + * @param[out] image_data_info Memory location where the runtime stores the + * size and alignment requirements. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_EXT_STATUS_ERROR_IMAGE_FORMAT_UNSUPPORTED The @p + * agent does not support the image format specified by @p + * image_descriptor with the specified @p access_permission. + * + * @retval ::HSA_EXT_STATUS_ERROR_IMAGE_SIZE_UNSUPPORTED The agent + * does not support the image dimensions specified by @p + * image_descriptor with the specified @p access_permission. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p image_descriptor is NULL, @p + * access_permission is not a valid access permission value, or @p + * image_data_info is NULL. + */ +hsa_status_t HSA_API hsa_ext_image_data_get_info( + hsa_agent_t agent, const hsa_ext_image_descriptor_t *image_descriptor, + hsa_access_permission_t access_permission, + hsa_ext_image_data_info_t *image_data_info); + +/** + * @brief Retrieve the image data requirements for a given combination of + * image descriptor, access permission, image data layout, image data row pitch, + * and image data slice pitch for an image created with an explicit image + * data layout. + * + * @details The image data size and alignment requirements may vary + * depending on the image attributes specified in @p image_descriptor, + * the @p access_permission, and the image layout. However, different + * implementations of the HSA runtime will return the same + * requirements for the same input values. + * + * The implementation must return the same image data requirements for + * different access permissions with matching image descriptors and + * matching image layouts as long as ::hsa_ext_image_get_capability + * reports + * ::HSA_EXT_IMAGE_CAPABILITY_ACCESS_INVARIANT_DATA_LAYOUT. Image + * descriptors match if they have the same values, with the exception + * that s-form channel orders match the corresponding non-s-form + * channel order and vice versa. Image layouts match if they are the + * same image data layout and use the same image row and slice pitch + * values. + * + * @param[in] image_descriptor Pointer to an image descriptor. Must not be NULL. + * + * @param[in] access_permission Access permission of the image when + * accessed by an agent. The access permission defines how the agent + * is allowed to access the image and must match the corresponding + * HSAIL image handle type. + * + * @param[in] image_data_layout The image data layout to use. + * It is invalid to use ::HSA_EXT_IMAGE_DATA_LAYOUT_OPAQUE; use + * ::hsa_ext_image_data_get_info instead. + * + * @param[in] image_data_row_pitch The size in bytes for a single row + * of the image in the image data. If 0 is specified then the default + * row pitch value is used: image width * image element byte size. + * The value used must be greater than or equal to the default row + * pitch, and be a multiple of the image element byte size. For the + * linear image layout it must also be a multiple of the image linear + * row pitch alignment for the agents that will access the image data + * using image instructions. + * + * @param[in] image_data_slice_pitch The size in bytes of a single + * slice of a 3D image, or the size in bytes of each image layer in an + * image array in the image data. If 0 is specified then the default + * slice pitch value is used: row pitch * height if geometry is + * ::HSA_EXT_IMAGE_GEOMETRY_3D, ::HSA_EXT_IMAGE_GEOMETRY_2DA, or + * ::HSA_EXT_IMAGE_GEOMETRY_2DADEPTH; row pitch if geometry is + * ::HSA_EXT_IMAGE_GEOMETRY_1DA; and 0 otherwise. The value used must + * be 0 if the default slice pitch is 0, be greater than or equal to + * the default slice pitch, and be a multiple of the row pitch. + * + * @param[out] image_data_info Memory location where the runtime stores the + * size and alignment requirements. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_EXT_STATUS_ERROR_IMAGE_FORMAT_UNSUPPORTED The image + * format specified by @p image_descriptor is not supported for the + * @p access_permission and @p image_data_layout specified. + * + * @retval ::HSA_EXT_STATUS_ERROR_IMAGE_SIZE_UNSUPPORTED The image + * dimensions specified by @p image_descriptor are not supported for + * the @p access_permission and @p image_data_layout specified. + * + * @retval ::HSA_EXT_STATUS_ERROR_IMAGE_PITCH_UNSUPPORTED The row and + * slice pitch specified by @p image_data_row_pitch and @p + * image_data_slice_pitch are invalid or not supported. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p image_descriptor is + * NULL, @p image_data_layout is ::HSA_EXT_IMAGE_DATA_LAYOUT_OPAQUE, + * or @p image_data_info is NULL. + */ +hsa_status_t HSA_API hsa_ext_image_data_get_info_with_layout( + hsa_agent_t agent, const hsa_ext_image_descriptor_t *image_descriptor, + hsa_access_permission_t access_permission, + hsa_ext_image_data_layout_t image_data_layout, size_t image_data_row_pitch, + size_t image_data_slice_pitch, hsa_ext_image_data_info_t *image_data_info); + +/** + * @brief Creates an agent specific image handle to an image with an + * opaque image data layout. + * + * @details Images with an opaque image data layout created with + * different access permissions but matching image descriptors and + * same agent can share the same image data if + * ::HSA_EXT_IMAGE_CAPABILITY_ACCESS_INVARIANT_DATA_LAYOUT is reported + * by ::hsa_ext_image_get_capability for the image format specified in + * the image descriptor. Image descriptors match if they have the same + * values, with the exception that s-form channel orders match the + * corresponding non-s-form channel order and vice versa. + * + * If necessary, an application can use image operations (import, + * export, copy, clear) to prepare the image for the intended use + * regardless of the access permissions. + * + * @param[in] agent agent to be associated with the image handle created. + * + * @param[in] image_descriptor Pointer to an image descriptor. Must not be NULL. + * + * @param[in] image_data Image data buffer that must have been allocated + * according to the size and alignment requirements dictated by + * ::hsa_ext_image_data_get_info. Must not be NULL. + * + * Any previous memory contents are preserved upon creation. The application is + * responsible for ensuring that the lifetime of the image data exceeds that of + * all the associated images. + * + * @param[in] access_permission Access permission of the image when + * accessed by agent. The access permission defines how the agent + * is allowed to access the image using the image handle created and + * must match the corresponding HSAIL image handle type. The agent + * must support the image format specified in @p image_descriptor for + * the given @p access_permission. + * + * @param[out] image Pointer to a memory location where the HSA runtime stores + * the newly created image handle. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_EXT_STATUS_ERROR_IMAGE_FORMAT_UNSUPPORTED The agent + * does not have the capability to support the image format contained + * in @p image_descriptor using the specified @p access_permission. + * + * @retval ::HSA_EXT_STATUS_ERROR_IMAGE_SIZE_UNSUPPORTED The agent + * does not support the image dimensions specified by @p + * image_descriptor using the specified @p access_permission. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * support the creation of more image handles with the given @p + * access_permission). + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p image_descriptor is NULL, @p + * image_data is NULL, @p image_data does not have a valid alignment, + * @p access_permission is not a valid access permission + * value, or @p image is NULL. + */ +hsa_status_t HSA_API hsa_ext_image_create( + hsa_agent_t agent, const hsa_ext_image_descriptor_t *image_descriptor, + const void *image_data, hsa_access_permission_t access_permission, + hsa_ext_image_t *image); + +/** + * @brief Creates an agent specific image handle to an image with an explicit + * image data layout. + * + * @details Images with an explicit image data layout created with + * different access permissions but matching image descriptors and + * matching image layout can share the same image data if + * ::HSA_EXT_IMAGE_CAPABILITY_ACCESS_INVARIANT_DATA_LAYOUT is reported + * by ::hsa_ext_image_get_capability_with_layout for the image format + * specified in the image descriptor and specified image data + * layout. Image descriptors match if they have the same values, with + * the exception that s-form channel orders match the corresponding + * non-s-form channel order and vice versa. Image layouts match if + * they are the same image data layout and use the same image row and + * slice values. + * + * If necessary, an application can use image operations (import, export, copy, + * clear) to prepare the image for the intended use regardless of the access + * permissions. + * + * @param[in] agent agent to be associated with the image handle created. + * + * @param[in] image_descriptor Pointer to an image descriptor. Must not be NULL. + * + * @param[in] image_data Image data buffer that must have been allocated + * according to the size and alignment requirements dictated by + * ::hsa_ext_image_data_get_info_with_layout. Must not be NULL. + * + * Any previous memory contents are preserved upon creation. The application is + * responsible for ensuring that the lifetime of the image data exceeds that of + * all the associated images. + * + * @param[in] access_permission Access permission of the image when + * accessed by the agent. The access permission defines how the agent + * is allowed to access the image and must match the corresponding + * HSAIL image handle type. The agent must support the image format + * specified in @p image_descriptor for the given @p access_permission + * and @p image_data_layout. + * + * @param[in] image_data_layout The image data layout to use for the + * @p image_data. It is invalid to use + * ::HSA_EXT_IMAGE_DATA_LAYOUT_OPAQUE; use ::hsa_ext_image_create + * instead. + * + * @param[in] image_data_row_pitch The size in bytes for a single row + * of the image in the image data. If 0 is specified then the default + * row pitch value is used: image width * image element byte size. + * The value used must be greater than or equal to the default row + * pitch, and be a multiple of the image element byte size. For the + * linear image layout it must also be a multiple of the image linear + * row pitch alignment for the agents that will access the image data + * using image instructions. + * + * @param[in] image_data_slice_pitch The size in bytes of a single + * slice of a 3D image, or the size in bytes of each image layer in an + * image array in the image data. If 0 is specified then the default + * slice pitch value is used: row pitch * height if geometry is + * ::HSA_EXT_IMAGE_GEOMETRY_3D, ::HSA_EXT_IMAGE_GEOMETRY_2DA, or + * ::HSA_EXT_IMAGE_GEOMETRY_2DADEPTH; row pitch if geometry is + * ::HSA_EXT_IMAGE_GEOMETRY_1DA; and 0 otherwise. The value used must + * be 0 if the default slice pitch is 0, be greater than or equal to + * the default slice pitch, and be a multiple of the row pitch. + * + * @param[out] image Pointer to a memory location where the HSA runtime stores + * the newly created image handle. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_EXT_STATUS_ERROR_IMAGE_FORMAT_UNSUPPORTED The agent does + * not have the capability to support the image format contained in the image + * descriptor using the specified @p access_permission and @p image_data_layout. + * + * @retval ::HSA_EXT_STATUS_ERROR_IMAGE_SIZE_UNSUPPORTED The agent + * does not support the image dimensions specified by @p + * image_descriptor using the specified @p access_permission and @p + * image_data_layout. + * + * @retval ::HSA_EXT_STATUS_ERROR_IMAGE_PITCH_UNSUPPORTED The agent does + * not support the row and slice pitch specified by @p image_data_row_pitch + * and @p image_data_slice_pitch, or the values are invalid. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * support the creation of more image handles with the given @p + * access_permission). + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p image_descriptor is NULL, @p + * image_data is NULL, @p image_data does not have a valid alignment, + * @p image_data_layout is ::HSA_EXT_IMAGE_DATA_LAYOUT_OPAQUE, + * or @p image is NULL. + */ +hsa_status_t HSA_API hsa_ext_image_create_with_layout( + hsa_agent_t agent, const hsa_ext_image_descriptor_t *image_descriptor, + const void *image_data, hsa_access_permission_t access_permission, + hsa_ext_image_data_layout_t image_data_layout, size_t image_data_row_pitch, + size_t image_data_slice_pitch, hsa_ext_image_t *image); + +/** + * @brief Destroy an image handle previously created using + * ::hsa_ext_image_create or + * ::hsa_ext_image_create_with_layout. + * + * @details Destroying the image handle does not free the associated image data, + * or modify its contents. The application should not destroy an image handle + * while there are references to it queued for execution or currently being used + * in a kernel dispatch. + * + * @param[in] agent Agent associated with the image handle. + * + * @param[in] image Image handle to destroy. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + */ +hsa_status_t HSA_API hsa_ext_image_destroy(hsa_agent_t agent, + hsa_ext_image_t image); + +/** + * @brief Copies a portion of one image (the source) to another image (the + * destination). + * + * @details The source and destination image formats should be the + * same, with the exception that s-form channel orders match the + * corresponding non-s-form channel order and vice versa. For example, + * it is allowed to copy a source image with a channel order of + * HSA_EXT_IMAGE_CHANNEL_ORDER_SRGB to a destination image with a + * channel order of HSA_EXT_IMAGE_CHANNEL_ORDER_RGB. + * + * The source and destination images do not have to be of the same geometry and + * appropriate scaling is performed by the HSA runtime. It is possible to copy + * subregions between any combinations of source and destination geometries, + * provided that the dimensions of the subregions are the same. For example, it + * is allowed to copy a rectangular region from a 2D image to a slice of a 3D + * image. + * + * If the source and destination image data overlap, or the combination of + * offset and range references an out-out-bounds element in any of the images, + * the behavior is undefined. + * + * @param[in] agent Agent associated with both the source and destination image + * handles. + * + * @param[in] src_image Image handle of source image. The agent associated with + * the source image handle must be identical to that of the destination image. + * + * @param[in] src_offset Pointer to the offset within the source image where to + * copy the data from. Must not be NULL. + * + * @param[in] dst_image Image handle of destination image. + * + * @param[in] dst_offset Pointer to the offset within the destination + * image where to copy the data. Must not be NULL. + * + * @param[in] range Dimensions of the image portion to be copied. The HSA + * runtime computes the size of the image data to be copied using this + * argument. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p src_offset is + * NULL, @p dst_offset is NULL, or @p range is NULL. + */ +hsa_status_t HSA_API hsa_ext_image_copy(hsa_agent_t agent, + hsa_ext_image_t src_image, + const hsa_dim3_t *src_offset, + hsa_ext_image_t dst_image, + const hsa_dim3_t *dst_offset, + const hsa_dim3_t *range); + +/** + * @brief Image region. + */ +typedef struct hsa_ext_image_region_s { + /** + * Offset within an image (in coordinates). + */ + hsa_dim3_t offset; + + /** + * Dimension size of the image range (in coordinates). The x, y, and z + * dimensions correspond to width, height, and depth or index respectively. + */ + hsa_dim3_t range; +} hsa_ext_image_region_t; + +/** + * @brief Import a linearly organized image data from memory directly to an + * image handle. + * + * @details This operation updates the image data referenced by the image handle + * from the source memory. The size of the data imported from memory is + * implicitly derived from the image region. + * + * It is the application's responsibility to avoid out of bounds memory access. + * + * None of the source memory or destination image data memory can + * overlap. Overlapping of any of the source and destination image + * data memory within the import operation produces undefined results. + * + * @param[in] agent Agent associated with the image handle. + * + * @param[in] src_memory Source memory. Must not be NULL. + * + * @param[in] src_row_pitch The size in bytes of a single row of the image in + * the source memory. If the value is smaller than the destination image region + * width * image element byte size, then region width * image element byte + * size is used. + * + * @param[in] src_slice_pitch The size in bytes of a single 2D slice of a 3D + * image, or the size in bytes of each image layer in an image array in the + * source memory. If the geometry is ::HSA_EXT_IMAGE_GEOMETRY_1DA and the value + * is smaller than the value used for @p src_row_pitch, then the value used for + * @p src_row_pitch is used. If the geometry is ::HSA_EXT_IMAGE_GEOMETRY_3D, + * ::HSA_EXT_IMAGE_GEOMETRY_2DA, or HSA_EXT_IMAGE_GEOMETRY_2DADEPTH and the + * value is smaller than the value used for + * @p src_row_pitch * destination image region height, then the value used for + * @p src_row_pitch * destination image region height is used. + * Otherwise, the value is not used. + * + * @param[in] dst_image Image handle of destination image. + * + * @param[in] image_region Pointer to the image region to be updated. Must not + * be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p src_memory is NULL, or @p + * image_region is NULL. + * + */ +hsa_status_t HSA_API hsa_ext_image_import( + hsa_agent_t agent, const void *src_memory, size_t src_row_pitch, + size_t src_slice_pitch, hsa_ext_image_t dst_image, + const hsa_ext_image_region_t *image_region); + +/** + * @brief Export the image data to linearly organized memory. + * + * @details The operation updates the destination memory with the image data of + * @p src_image. The size of the data exported to memory is implicitly derived + * from the image region. + * + * It is the application's responsibility to avoid out of bounds memory access. + * + * None of the destination memory or source image data memory can + * overlap. Overlapping of any of the source and destination image + * data memory within the export operation produces undefined results. + * + * @param[in] agent Agent associated with the image handle. + * + * @param[in] src_image Image handle of source image. + * + * @param[in] dst_memory Destination memory. Must not be NULL. + * + * @param[in] dst_row_pitch The size in bytes of a single row of the image in + * the destination memory. If the value is smaller than the source image region + * width * image element byte size, then region width * image element byte + * size is used. + * + * @param[in] dst_slice_pitch The size in bytes of a single 2D slice of a 3D + * image, or the size in bytes of each image in an image array in the + * destination memory. If the geometry is ::HSA_EXT_IMAGE_GEOMETRY_1DA and the + * value is smaller than the value used for @p dst_row_pitch, then the value + * used for @p dst_row_pitch is used. If the geometry is + * ::HSA_EXT_IMAGE_GEOMETRY_3D, ::HSA_EXT_IMAGE_GEOMETRY_2DA, or + * HSA_EXT_IMAGE_GEOMETRY_2DADEPTH and the value is smaller than the value used + * for + * @p dst_row_pitch * source image region height, then the value used for + * @p dst_row_pitch * source image region height is used. + * Otherwise, the value is not used. + * + * @param[in] image_region Pointer to the image region to be exported. Must not + * be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p dst_memory is NULL, or @p + * image_region is NULL. + */ +hsa_status_t HSA_API hsa_ext_image_export( + hsa_agent_t agent, hsa_ext_image_t src_image, void *dst_memory, + size_t dst_row_pitch, size_t dst_slice_pitch, + const hsa_ext_image_region_t *image_region); + +/** + * @brief Clear a region of an image so that every image element has + * the specified value. + * + * @param[in] agent Agent associated with the image handle. + * + * @param[in] image Image handle for image to be cleared. + * + * @param[in] data The value to which to set each image element being + * cleared. It is specified as an array of image component values. The + * number of array elements must match the number of access components + * for the image channel order. The type of each array element must + * match the image access type of the image channel type. When the + * value is used to set the value of an image element, the conversion + * method corresponding to the image channel type is used. See the + * Channel Order section and Channel Type section in + * the HSA Programming Reference Manual for more + * information. Must not be NULL. + * + * @param[in] image_region Pointer to the image region to clear. Must not be + * NULL. If the region references an out-out-bounds element, the behavior is + * undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p data is NULL, or @p + * image_region is NULL. + */ +hsa_status_t HSA_API +hsa_ext_image_clear(hsa_agent_t agent, hsa_ext_image_t image, const void *data, + const hsa_ext_image_region_t *image_region); + +/** + * @brief Sampler handle. Samplers are populated by + * ::hsa_ext_sampler_create. Sampler handles are only unique within an + * agent, not across agents. + */ +typedef struct hsa_ext_sampler_s { + /** + * Opaque handle. For a given agent, two handles reference the same object of + * the enclosing type if and only if they are equal. + */ + uint64_t handle; +} hsa_ext_sampler_t; + +/** + * @brief Sampler address modes. The sampler address mode describes + * the processing of out-of-range image coordinates. See the + * Addressing Mode section in the HSA Programming Reference + * Manual for definitions on each address mode. The values + * match the BRIG type @p hsa_ext_brig_sampler_addressing_t. + */ +typedef enum { + /** + * Out-of-range coordinates are not handled. + */ + HSA_EXT_SAMPLER_ADDRESSING_MODE_UNDEFINED = 0, + + /** + * Clamp out-of-range coordinates to the image edge. + */ + HSA_EXT_SAMPLER_ADDRESSING_MODE_CLAMP_TO_EDGE = 1, + + /** + * Clamp out-of-range coordinates to the image border color. + */ + HSA_EXT_SAMPLER_ADDRESSING_MODE_CLAMP_TO_BORDER = 2, + + /** + * Wrap out-of-range coordinates back into the valid coordinate + * range so the image appears as repeated tiles. + */ + HSA_EXT_SAMPLER_ADDRESSING_MODE_REPEAT = 3, + + /** + * Mirror out-of-range coordinates back into the valid coordinate + * range so the image appears as repeated tiles with every other + * tile a reflection. + */ + HSA_EXT_SAMPLER_ADDRESSING_MODE_MIRRORED_REPEAT = 4 + +} hsa_ext_sampler_addressing_mode_t; + +/** + * @brief A fixed-size type used to represent + * ::hsa_ext_sampler_addressing_mode_t constants. + */ +typedef uint32_t hsa_ext_sampler_addressing_mode32_t; + +/** + * @brief Sampler coordinate normalization modes. See the + * Coordinate Normalization Mode section in the HSA + * Programming Reference Manual for definitions on each + * coordinate normalization mode. The values match the BRIG type @p + * hsa_ext_brig_sampler_coord_normalization_t. + */ +typedef enum { + + /** + * Coordinates are used to directly address an image element. + */ + HSA_EXT_SAMPLER_COORDINATE_MODE_UNNORMALIZED = 0, + + /** + * Coordinates are scaled by the image dimension size before being + * used to address an image element. + */ + HSA_EXT_SAMPLER_COORDINATE_MODE_NORMALIZED = 1 + +} hsa_ext_sampler_coordinate_mode_t; + +/** + * @brief A fixed-size type used to represent + * ::hsa_ext_sampler_coordinate_mode_t constants. + */ +typedef uint32_t hsa_ext_sampler_coordinate_mode32_t; + +/** + * @brief Sampler filter modes. See the Filter Mode section + * in the HSA Programming Reference Manual for definitions + * on each address mode. The enumeration values match the BRIG type @p + * hsa_ext_brig_sampler_filter_t. + */ +typedef enum { + /** + * Filter to the image element nearest (in Manhattan distance) to the + * specified coordinate. + */ + HSA_EXT_SAMPLER_FILTER_MODE_NEAREST = 0, + + /** + * Filter to the image element calculated by combining the elements in a 2x2 + * square block or 2x2x2 cube block around the specified coordinate. The + * elements are combined using linear interpolation. + */ + HSA_EXT_SAMPLER_FILTER_MODE_LINEAR = 1 + +} hsa_ext_sampler_filter_mode_t; + +/** + * @brief A fixed-size type used to represent ::hsa_ext_sampler_filter_mode_t + * constants. + */ +typedef uint32_t hsa_ext_sampler_filter_mode32_t; + +/** + * @brief Implementation independent sampler descriptor. + */ +typedef struct hsa_ext_sampler_descriptor_s { + /** + * Sampler coordinate mode describes the normalization of image coordinates. + */ + hsa_ext_sampler_coordinate_mode32_t coordinate_mode; + + /** + * Sampler filter type describes the type of sampling performed. + */ + hsa_ext_sampler_filter_mode32_t filter_mode; + + /** + * Sampler address mode describes the processing of out-of-range image + * coordinates. + */ + hsa_ext_sampler_addressing_mode32_t address_mode; + +} hsa_ext_sampler_descriptor_t; + +/** + * @brief Create an agent specific sampler handle for a given agent + * independent sampler descriptor and agent. + * + * @param[in] agent Agent to be associated with the sampler handle created. + * + * @param[in] sampler_descriptor Pointer to a sampler descriptor. Must not be + * NULL. + * + * @param[out] sampler Memory location where the HSA runtime stores the newly + * created sampler handle. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + * + * @retval ::HSA_EXT_STATUS_ERROR_SAMPLER_DESCRIPTOR_UNSUPPORTED The + * @p agent does not have the capability to support the properties + * specified by @p sampler_descriptor or it is invalid. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p sampler_descriptor is NULL, or + * @p sampler is NULL. + */ +hsa_status_t HSA_API hsa_ext_sampler_create( + hsa_agent_t agent, const hsa_ext_sampler_descriptor_t *sampler_descriptor, + hsa_ext_sampler_t *sampler); + +/** + * @brief Destroy a sampler handle previously created using + * ::hsa_ext_sampler_create. + * + * @details The sampler handle should not be destroyed while there are + * references to it queued for execution or currently being used in a + * kernel dispatch. + * + * @param[in] agent Agent associated with the sampler handle. + * + * @param[in] sampler Sampler handle to destroy. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_AGENT The agent is invalid. + */ +hsa_status_t HSA_API hsa_ext_sampler_destroy(hsa_agent_t agent, + hsa_ext_sampler_t sampler); + +#define hsa_ext_images_1_00 + +/** + * @brief The function pointer table for the images v1.00 extension. Can be + * returned by ::hsa_system_get_extension_table or + * ::hsa_system_get_major_extension_table. + */ +typedef struct hsa_ext_images_1_00_pfn_s { + + hsa_status_t (*hsa_ext_image_get_capability)( + hsa_agent_t agent, hsa_ext_image_geometry_t geometry, + const hsa_ext_image_format_t *image_format, uint32_t *capability_mask); + + hsa_status_t (*hsa_ext_image_data_get_info)( + hsa_agent_t agent, const hsa_ext_image_descriptor_t *image_descriptor, + hsa_access_permission_t access_permission, + hsa_ext_image_data_info_t *image_data_info); + + hsa_status_t (*hsa_ext_image_create)( + hsa_agent_t agent, const hsa_ext_image_descriptor_t *image_descriptor, + const void *image_data, hsa_access_permission_t access_permission, + hsa_ext_image_t *image); + + hsa_status_t (*hsa_ext_image_destroy)(hsa_agent_t agent, + hsa_ext_image_t image); + + hsa_status_t (*hsa_ext_image_copy)(hsa_agent_t agent, + hsa_ext_image_t src_image, + const hsa_dim3_t *src_offset, + hsa_ext_image_t dst_image, + const hsa_dim3_t *dst_offset, + const hsa_dim3_t *range); + + hsa_status_t (*hsa_ext_image_import)( + hsa_agent_t agent, const void *src_memory, size_t src_row_pitch, + size_t src_slice_pitch, hsa_ext_image_t dst_image, + const hsa_ext_image_region_t *image_region); + + hsa_status_t (*hsa_ext_image_export)( + hsa_agent_t agent, hsa_ext_image_t src_image, void *dst_memory, + size_t dst_row_pitch, size_t dst_slice_pitch, + const hsa_ext_image_region_t *image_region); + + hsa_status_t (*hsa_ext_image_clear)( + hsa_agent_t agent, hsa_ext_image_t image, const void *data, + const hsa_ext_image_region_t *image_region); + + hsa_status_t (*hsa_ext_sampler_create)( + hsa_agent_t agent, const hsa_ext_sampler_descriptor_t *sampler_descriptor, + hsa_ext_sampler_t *sampler); + + hsa_status_t (*hsa_ext_sampler_destroy)(hsa_agent_t agent, + hsa_ext_sampler_t sampler); + +} hsa_ext_images_1_00_pfn_t; + +#define hsa_ext_images_1 + +/** + * @brief The function pointer table for the images v1 extension. Can be + * returned by ::hsa_system_get_extension_table or + * ::hsa_system_get_major_extension_table. + */ +typedef struct hsa_ext_images_1_pfn_s { + + hsa_status_t (*hsa_ext_image_get_capability)( + hsa_agent_t agent, hsa_ext_image_geometry_t geometry, + const hsa_ext_image_format_t *image_format, uint32_t *capability_mask); + + hsa_status_t (*hsa_ext_image_data_get_info)( + hsa_agent_t agent, const hsa_ext_image_descriptor_t *image_descriptor, + hsa_access_permission_t access_permission, + hsa_ext_image_data_info_t *image_data_info); + + hsa_status_t (*hsa_ext_image_create)( + hsa_agent_t agent, const hsa_ext_image_descriptor_t *image_descriptor, + const void *image_data, hsa_access_permission_t access_permission, + hsa_ext_image_t *image); + + hsa_status_t (*hsa_ext_image_destroy)(hsa_agent_t agent, + hsa_ext_image_t image); + + hsa_status_t (*hsa_ext_image_copy)(hsa_agent_t agent, + hsa_ext_image_t src_image, + const hsa_dim3_t *src_offset, + hsa_ext_image_t dst_image, + const hsa_dim3_t *dst_offset, + const hsa_dim3_t *range); + + hsa_status_t (*hsa_ext_image_import)( + hsa_agent_t agent, const void *src_memory, size_t src_row_pitch, + size_t src_slice_pitch, hsa_ext_image_t dst_image, + const hsa_ext_image_region_t *image_region); + + hsa_status_t (*hsa_ext_image_export)( + hsa_agent_t agent, hsa_ext_image_t src_image, void *dst_memory, + size_t dst_row_pitch, size_t dst_slice_pitch, + const hsa_ext_image_region_t *image_region); + + hsa_status_t (*hsa_ext_image_clear)( + hsa_agent_t agent, hsa_ext_image_t image, const void *data, + const hsa_ext_image_region_t *image_region); + + hsa_status_t (*hsa_ext_sampler_create)( + hsa_agent_t agent, const hsa_ext_sampler_descriptor_t *sampler_descriptor, + hsa_ext_sampler_t *sampler); + + hsa_status_t (*hsa_ext_sampler_destroy)(hsa_agent_t agent, + hsa_ext_sampler_t sampler); + + hsa_status_t (*hsa_ext_image_get_capability_with_layout)( + hsa_agent_t agent, hsa_ext_image_geometry_t geometry, + const hsa_ext_image_format_t *image_format, + hsa_ext_image_data_layout_t image_data_layout, uint32_t *capability_mask); + + hsa_status_t (*hsa_ext_image_data_get_info_with_layout)( + hsa_agent_t agent, const hsa_ext_image_descriptor_t *image_descriptor, + hsa_access_permission_t access_permission, + hsa_ext_image_data_layout_t image_data_layout, + size_t image_data_row_pitch, size_t image_data_slice_pitch, + hsa_ext_image_data_info_t *image_data_info); + + hsa_status_t (*hsa_ext_image_create_with_layout)( + hsa_agent_t agent, const hsa_ext_image_descriptor_t *image_descriptor, + const void *image_data, hsa_access_permission_t access_permission, + hsa_ext_image_data_layout_t image_data_layout, + size_t image_data_row_pitch, size_t image_data_slice_pitch, + hsa_ext_image_t *image); + +} hsa_ext_images_1_pfn_t; +/** @} */ + +#ifdef __cplusplus +} // end extern "C" block +#endif /*__cplusplus*/ + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ven_amd_aqlprofile.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ven_amd_aqlprofile.h new file mode 100644 index 000000000..0e3c8e9a1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ven_amd_aqlprofile.h @@ -0,0 +1,491 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef OPENSRC_HSA_RUNTIME_INC_HSA_VEN_AMD_AQLPROFILE_H_ +#define OPENSRC_HSA_RUNTIME_INC_HSA_VEN_AMD_AQLPROFILE_H_ + +#include "hsa.h" +#include + +#define HSA_AQLPROFILE_VERSION_MAJOR 2 +#define HSA_AQLPROFILE_VERSION_MINOR 0 + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//////////////////////////////////////////////////////////////////////////////// +// Library version +uint32_t hsa_ven_amd_aqlprofile_version_major(); +uint32_t hsa_ven_amd_aqlprofile_version_minor(); + +/////////////////////////////////////////////////////////////////////// +// Library API: +// The library provides helper methods for instantiation of +// the profile context object and for populating of the start +// and stop AQL packets. The profile object contains a profiling +// events list and needed for profiling buffers descriptors, +// a command buffer and an output data buffer. To check if there +// was an error the library methods return a status code. Also +// the library provides methods for querying required buffers +// attributes, to validate the event attributes and to get profiling +// output data. +// +// Returned status: +// hsa_status_t – HSA status codes are used from hsa.h header +// +// Supported profiling features: +// +// Supported profiling events +typedef enum { + HSA_VEN_AMD_AQLPROFILE_EVENT_TYPE_PMC = 0, + HSA_VEN_AMD_AQLPROFILE_EVENT_TYPE_TRACE = 1, +} hsa_ven_amd_aqlprofile_event_type_t; + +// Supported performance counters (PMC) blocks +// The block ID is the same for a block instances set, for example +// each block instance from the TCC block set, TCC0, TCC1, …, TCCN +// will have the same block ID HSA_VEN_AMD_AQLPROFILE_BLOCKS_TCC. +typedef enum { + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_CPC = 0, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_CPF = 1, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_GDS = 2, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_GRBM = 3, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_GRBMSE = 4, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_SPI = 5, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_SQ = 6, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_SQCS = 7, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_SRBM = 8, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_SX = 9, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_TA = 10, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_TCA = 11, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_TCC = 12, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_TCP = 13, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_TD = 14, + // Memory related blocks + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_MCARB = 15, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_MCHUB = 16, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_MCMCBVM = 17, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_MCSEQ = 18, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_MCVML2 = 19, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_MCXBAR = 20, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_ATC = 21, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_ATCL2 = 22, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_GCEA = 23, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_RPB = 24, + // System blocks + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_SDMA = 25, + // GFX10 added blocks + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_GL1A = 26, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_GL1C = 27, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_GL2A = 28, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_GL2C = 29, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_GCR = 30, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_GUS = 31, + + // UMC & MMEA System Blocks + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_UMC = 32, + HSA_VEN_AMD_AQLPROFILE_BLOCK_NAME_MMEA = 33, + + HSA_VEN_AMD_AQLPROFILE_BLOCKS_NUMBER +} hsa_ven_amd_aqlprofile_block_name_t; + +// PMC event object structure +// ‘counter_id’ value is specified in GFXIPs perfcounter user guides +// which is the counters select value, “Performance Counters Selection” +// chapter. +typedef struct { + hsa_ven_amd_aqlprofile_block_name_t block_name; + uint32_t block_index; + uint32_t counter_id; +} hsa_ven_amd_aqlprofile_event_t; + +// Check if event is valid for the specific GPU +hsa_status_t hsa_ven_amd_aqlprofile_validate_event( + hsa_agent_t agent, // HSA handle for the profiling GPU + const hsa_ven_amd_aqlprofile_event_t + *event, // [in] Pointer on validated event + bool *result); // [out] True if the event valid, False otherwise + +// Profiling parameters +// All parameters are generic and if not applicable for a specific +// profile configuration then error status will be returned. +typedef enum { + /** + * Select the target compute unit (wgp) for profiling. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_COMPUTE_UNIT_TARGET = 0, + /** + * VMID Mask + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_VM_ID_MASK = 1, + /** + * Legacy. Deprecated. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_MASK = 2, + /** + * Legacy. Deprecated. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_TOKEN_MASK = 3, + /** + * Legacy. Deprecated. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_TOKEN_MASK2 = 4, + /** + * Shader engine mask for selection. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_SE_MASK = 5, + /** + * Legacy. Deprecated. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_SAMPLE_RATE = 6, + /** + * Legacy. Deprecated. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_K_CONCURRENT = 7, + /** + * Set SIMD Mask (GFX9) or SIMD ID for collection (Navi) + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_SIMD_SELECTION = 8, + /** + * Set true for occupancy collection only. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_OCCUPANCY_MODE = 9, + /** + * ATT collection max data size, in MB. Shared among shader engines. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_ATT_BUFFER_SIZE = 10, + /** + * Mask of which compute units to generate perfcounters. GFX9 only. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_PERFCOUNTER_MASK = 240, + /** + * Select collection period for perfcounters. GFX9 only. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_PERFCOUNTER_CTRL = 241, + /** + * Select perfcounter ID (SQ block) for collection. GFX9 only. + */ + HSA_VEN_AMD_AQLPROFILE_PARAMETER_NAME_PERFCOUNTER_NAME = 242, +} hsa_ven_amd_aqlprofile_parameter_name_t; + +// Profile parameter object +typedef struct { + hsa_ven_amd_aqlprofile_parameter_name_t parameter_name; + uint32_t value; +} hsa_ven_amd_aqlprofile_parameter_t; + +typedef enum { + HSA_VEN_AMD_AQLPROFILE_ATT_CHANNEL_0 = 0, + HSA_VEN_AMD_AQLPROFILE_ATT_CHANNEL_1, + HSA_VEN_AMD_AQLPROFILE_ATT_CHANNEL_2, + HSA_VEN_AMD_AQLPROFILE_ATT_CHANNEL_3 +} hsa_ven_amd_aqlprofile_att_marker_channel_t; + +// +// Profile context object: +// The library provides a profile object structure which contains +// the events array, a buffer for the profiling start/stop commands +// and a buffer for the output data. +// The buffers are specified by the buffer descriptors and allocated +// by the application. The buffers allocation attributes, the command +// buffer size, the PMC output buffer size as well as profiling output +// data can be get using the generic get profile info helper _get_info. +// +// Buffer descriptor +typedef struct { + void *ptr; + uint32_t size; +} hsa_ven_amd_aqlprofile_descriptor_t; + +// Profile context object structure, contains profiling events list and +// needed for profiling buffers descriptors, a command buffer and +// an output data buffer +typedef struct { + hsa_agent_t agent; // GFXIP handle + hsa_ven_amd_aqlprofile_event_type_t type; // Events type + const hsa_ven_amd_aqlprofile_event_t *events; // Events array + uint32_t event_count; // Events count + const hsa_ven_amd_aqlprofile_parameter_t *parameters; // Parameters array + uint32_t parameter_count; // Parameters count + hsa_ven_amd_aqlprofile_descriptor_t output_buffer; // Output buffer + hsa_ven_amd_aqlprofile_descriptor_t command_buffer; // PM4 commands +} hsa_ven_amd_aqlprofile_profile_t; + +// +// AQL packets populating methods: +// The helper methods to populate provided by the application START and +// STOP AQL packets which the application is required to submit before and +// after profiled GPU task packets respectively. +// +// AQL Vendor Specific packet which carries a PM4 command +typedef struct { + uint16_t header; + uint16_t pm4_command[27]; + hsa_signal_t completion_signal; +} hsa_ext_amd_aql_pm4_packet_t; + +// Method to populate the provided AQL packet with profiling start commands +// Only 'pm4_command' fields of the packet are set and the application +// is responsible to set Vendor Specific header type a completion signal +hsa_status_t hsa_ven_amd_aqlprofile_start( + hsa_ven_amd_aqlprofile_profile_t *profile, // [in/out] profile contex object + hsa_ext_amd_aql_pm4_packet_t + *aql_start_packet); // [out] profile start AQL packet + +// Method to populate the provided AQL packet with profiling stop commands +// Only 'pm4_command' fields of the packet are set and the application +// is responsible to set Vendor Specific header type and a completion signal +hsa_status_t hsa_ven_amd_aqlprofile_stop( + const hsa_ven_amd_aqlprofile_profile_t + *profile, // [in] profile contex object + hsa_ext_amd_aql_pm4_packet_t + *aql_stop_packet); // [out] profile stop AQL packet + +// Method to populate the provided AQL packet with profiling read commands +// Only 'pm4_command' fields of the packet are set and the application +// is responsible to set Vendor Specific header type and a completion signal +hsa_status_t hsa_ven_amd_aqlprofile_read( + const hsa_ven_amd_aqlprofile_profile_t + *profile, // [in] profile contex object + hsa_ext_amd_aql_pm4_packet_t + *aql_read_packet); // [out] profile stop AQL packet + +// Legacy devices, PM4 profiling packet size +const unsigned HSA_VEN_AMD_AQLPROFILE_LEGACY_PM4_PACKET_SIZE = 192; +// Legacy devices, converting the profiling AQL packet to PM4 packet blob +hsa_status_t hsa_ven_amd_aqlprofile_legacy_get_pm4( + const hsa_ext_amd_aql_pm4_packet_t *aql_packet, // [in] AQL packet + void *data); // [out] PM4 packet blob + +// Method to add a marker (correlation ID) into the ATT buffer. +hsa_status_t hsa_ven_amd_aqlprofile_att_marker( + hsa_ven_amd_aqlprofile_profile_t *profile, // [in/out] profile contex object + hsa_ext_amd_aql_pm4_packet_t + *aql_marker_packet, // [out] profile marker AQL packet + uint32_t data, // [in] Data to be inserted + hsa_ven_amd_aqlprofile_att_marker_channel_t channel); // [in] Comm channel + +// +// Get profile info: +// Generic method for getting various profile info including profile buffers +// attributes like the command buffer size and the profiling PMC results. +// It’s implied that all counters are 64bit values. +// +// Profile generic output data: +typedef struct { + uint32_t sample_id; // PMC sample or trace buffer index + union { + struct { + hsa_ven_amd_aqlprofile_event_t event; // PMC event + uint64_t result; // PMC result + } pmc_data; + hsa_ven_amd_aqlprofile_descriptor_t + trace_data; // Trace output data descriptor + }; +} hsa_ven_amd_aqlprofile_info_data_t; + +// ID query type +typedef struct { + const char *name; + uint32_t id; + uint32_t instance_count; +} hsa_ven_amd_aqlprofile_id_query_t; + +// Profile attributes +typedef enum { + HSA_VEN_AMD_AQLPROFILE_INFO_COMMAND_BUFFER_SIZE = + 0, // get_info returns uint32_t value + HSA_VEN_AMD_AQLPROFILE_INFO_PMC_DATA_SIZE = + 1, // get_info returns uint32_t value + HSA_VEN_AMD_AQLPROFILE_INFO_PMC_DATA = 2, // get_info returns PMC uint64_t + // value in info_data object + HSA_VEN_AMD_AQLPROFILE_INFO_TRACE_DATA = 3, // get_info returns trace buffer + // ptr/size in info_data object + HSA_VEN_AMD_AQLPROFILE_INFO_BLOCK_COUNTERS = + 4, // get_info returns number of block counter + HSA_VEN_AMD_AQLPROFILE_INFO_BLOCK_ID = + 5, // get_info returns block id, instances + // by name string using _id_query_t + HSA_VEN_AMD_AQLPROFILE_INFO_ENABLE_CMD = + 6, // get_info returns size/pointer for + // counters enable command buffer + HSA_VEN_AMD_AQLPROFILE_INFO_DISABLE_CMD = + 7, // get_info returns size/pointer for + // counters disable command buffer +} hsa_ven_amd_aqlprofile_info_type_t; + +// Definition of output data iterator callback +typedef hsa_status_t (*hsa_ven_amd_aqlprofile_data_callback_t)( + hsa_ven_amd_aqlprofile_info_type_t + info_type, // [in] data type, PMC or trace data + hsa_ven_amd_aqlprofile_info_data_t *info_data, // [in] info_data object + void *callback_data); // [in/out] data passed to the callback + +// Method for getting the profile info +hsa_status_t hsa_ven_amd_aqlprofile_get_info( + const hsa_ven_amd_aqlprofile_profile_t + *profile, // [in] profile context object + hsa_ven_amd_aqlprofile_info_type_t + attribute, // [in] requested profile attribute + void *value); // [in/out] returned value + +// Method for iterating the events output data +hsa_status_t hsa_ven_amd_aqlprofile_iterate_data( + const hsa_ven_amd_aqlprofile_profile_t + *profile, // [in] profile context object + hsa_ven_amd_aqlprofile_data_callback_t + callback, // [in] callback to iterate the output data + void *data); // [in/out] data passed to the callback + +// Return error string +hsa_status_t hsa_ven_amd_aqlprofile_error_string( + const char **str); // [out] pointer on the error string + +/** + * @brief Callback for iteration of all possible event coordinate IDs and + * coordinate names. + */ +typedef hsa_status_t (*hsa_ven_amd_aqlprofile_eventname_callback_t)( + int id, const char *name); +/** + * @brief Iterate over all possible event coordinate IDs and their names. + */ +hsa_status_t hsa_ven_amd_aqlprofile_iterate_event_ids( + hsa_ven_amd_aqlprofile_eventname_callback_t); + +/** + * @brief Iterate over all event coordinates for a given agent_t and event_t. + * @param position A counting sequence indicating callback number. + * @param id Coordinate ID as in _iterate_event_ids. + * @param extent Coordinate extent indicating maximum allowed instances. + * @param coordinate The coordinate, in the range [0,extent-1]. + * @param name Coordinate name as in _iterate_event_ids. + * @param userdata Userdata returned from _iterate_event_coord function. + */ +typedef hsa_status_t (*hsa_ven_amd_aqlprofile_coordinate_callback_t)( + int position, int id, int extent, int coordinate, const char *name, + void *userdata); + +/** + * @brief Iterate over all event coordinates for a given agent_t and event_t. + * @param[in] agent HSA agent. + * @param[in] event The event ID and block ID to iterate for. + * @param[in] sample_id aqlprofile_info_data_t.sample_id returned from + * _aqlprofile_iterate_data. + * @param[in] callback Callback function to return the coordinates. + * @param[in] userdata Arbitrary data pointer to be sent back to the user via + * callback. + */ +hsa_status_t hsa_ven_amd_aqlprofile_iterate_event_coord( + hsa_agent_t agent, hsa_ven_amd_aqlprofile_event_t event, uint32_t sample_id, + hsa_ven_amd_aqlprofile_coordinate_callback_t callback, void *userdata); + +/** + * @brief Extension version. + */ +#define hsa_ven_amd_aqlprofile_VERSION_MAJOR 1 +#define hsa_ven_amd_aqlprofile_LIB(suff) "libhsa-amd-aqlprofile" suff ".so" + +#ifdef HSA_LARGE_MODEL +static const char kAqlProfileLib[] = hsa_ven_amd_aqlprofile_LIB("64"); +#else +static const char kAqlProfileLib[] = hsa_ven_amd_aqlprofile_LIB(""); +#endif + +/** + * @brief Extension function table. + */ +typedef struct hsa_ven_amd_aqlprofile_1_00_pfn_s { + uint32_t (*hsa_ven_amd_aqlprofile_version_major)(); + uint32_t (*hsa_ven_amd_aqlprofile_version_minor)(); + + hsa_status_t (*hsa_ven_amd_aqlprofile_error_string)(const char **str); + + hsa_status_t (*hsa_ven_amd_aqlprofile_validate_event)( + hsa_agent_t agent, const hsa_ven_amd_aqlprofile_event_t *event, + bool *result); + + hsa_status_t (*hsa_ven_amd_aqlprofile_start)( + hsa_ven_amd_aqlprofile_profile_t *profile, + hsa_ext_amd_aql_pm4_packet_t *aql_start_packet); + + hsa_status_t (*hsa_ven_amd_aqlprofile_stop)( + const hsa_ven_amd_aqlprofile_profile_t *profile, + hsa_ext_amd_aql_pm4_packet_t *aql_stop_packet); + + hsa_status_t (*hsa_ven_amd_aqlprofile_read)( + const hsa_ven_amd_aqlprofile_profile_t *profile, + hsa_ext_amd_aql_pm4_packet_t *aql_read_packet); + + hsa_status_t (*hsa_ven_amd_aqlprofile_legacy_get_pm4)( + const hsa_ext_amd_aql_pm4_packet_t *aql_packet, void *data); + + hsa_status_t (*hsa_ven_amd_aqlprofile_get_info)( + const hsa_ven_amd_aqlprofile_profile_t *profile, + hsa_ven_amd_aqlprofile_info_type_t attribute, void *value); + + hsa_status_t (*hsa_ven_amd_aqlprofile_iterate_data)( + const hsa_ven_amd_aqlprofile_profile_t *profile, + hsa_ven_amd_aqlprofile_data_callback_t callback, void *data); + + hsa_status_t (*hsa_ven_amd_aqlprofile_iterate_event_ids)( + hsa_ven_amd_aqlprofile_eventname_callback_t); + + hsa_status_t (*hsa_ven_amd_aqlprofile_iterate_event_coord)( + hsa_agent_t agent, hsa_ven_amd_aqlprofile_event_t event, + uint32_t sample_id, hsa_ven_amd_aqlprofile_coordinate_callback_t callback, + void *userdata); + + hsa_status_t (*hsa_ven_amd_aqlprofile_att_marker)( + hsa_ven_amd_aqlprofile_profile_t *profile, + hsa_ext_amd_aql_pm4_packet_t *aql_packet, uint32_t data, + hsa_ven_amd_aqlprofile_att_marker_channel_t channel); +} hsa_ven_amd_aqlprofile_1_00_pfn_t; + +typedef hsa_ven_amd_aqlprofile_1_00_pfn_t hsa_ven_amd_aqlprofile_pfn_t; + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // OPENSRC_HSA_RUNTIME_INC_HSA_VEN_AMD_AQLPROFILE_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ven_amd_loader.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ven_amd_loader.h new file mode 100644 index 000000000..54ddd79f2 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ven_amd_loader.h @@ -0,0 +1,639 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +// HSA AMD extension for additional loader functionality. + +#ifndef HSA_VEN_AMD_LOADER_H +#define HSA_VEN_AMD_LOADER_H + +#include "hsa.h" + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/** + * @brief Queries equivalent host address for given @p device_address, and + * records it in @p host_address. + * + * + * @details Contents of memory pointed to by @p host_address would be identical + * to contents of memory pointed to by @p device_address. Only difference + * between the two is host accessibility: @p host_address is always accessible + * from host, @p device_address might not be accessible from host. + * + * If @p device_address already points to host accessible memory, then the value + * of @p device_address is simply copied into @p host_address. + * + * The lifetime of @p host_address is the same as the lifetime of @p + * device_address, and both lifetimes are limited by the lifetime of the + * executable that is managing these addresses. + * + * + * @param[in] device_address Device address to query equivalent host address + * for. + * + * @param[out] host_address Pointer to application-allocated buffer to record + * queried equivalent host address in. + * + * + * @retval HSA_STATUS_SUCCESS Function is executed successfully. + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED Runtime is not initialized. + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT @p device_address is invalid or + * null, or @p host_address is null. + */ +hsa_status_t hsa_ven_amd_loader_query_host_address(const void *device_address, + const void **host_address); + +/** + * @brief The storage type of the code object that is backing loaded memory + * segment. + */ +typedef enum { + /** + * Loaded memory segment is not backed by any code object (anonymous), as the + * case would be with BSS (uninitialized data). + */ + HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_NONE = 0, + /** + * Loaded memory segment is backed by the code object that is stored in the + * file. + */ + HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_FILE = 1, + /** + * Loaded memory segment is backed by the code object that is stored in the + * memory. + */ + HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_MEMORY = 2 +} hsa_ven_amd_loader_code_object_storage_type_t; + +/** + * @brief Loaded memory segment descriptor. + * + * + * @details Loaded memory segment descriptor describes underlying loaded memory + * segment. Loaded memory segment is created/allocated by the executable during + * the loading of the code object that is backing underlying memory segment. + * + * The lifetime of underlying memory segment is limited by the lifetime of the + * executable that is managing underlying memory segment. + */ +typedef struct hsa_ven_amd_loader_segment_descriptor_s { + /** + * Agent underlying memory segment is allocated on. If the code object that is + * backing underlying memory segment is program code object, then 0. + */ + hsa_agent_t agent; + /** + * Executable that is managing this underlying memory segment. + */ + hsa_executable_t executable; + /** + * Storage type of the code object that is backing underlying memory segment. + */ + hsa_ven_amd_loader_code_object_storage_type_t code_object_storage_type; + /** + * If the storage type of the code object that is backing underlying memory + * segment is: + * - HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_NONE, then null; + * - HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_FILE, then null-terminated + * filepath to the code object; + * - HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_MEMORY, then host + * accessible pointer to the first byte of the code object. + */ + const void *code_object_storage_base; + /** + * If the storage type of the code object that is backing underlying memory + * segment is: + * - HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_NONE, then 0; + * - HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_FILE, then the length of + * the filepath to the code object (including null-terminating character); + * - HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_MEMORY, then the size, in + * bytes, of the memory occupied by the code object. + */ + size_t code_object_storage_size; + /** + * If the storage type of the code object that is backing underlying memory + * segment is: + * - HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_NONE, then 0; + * - other, then offset, in bytes, from the beginning of the code object to + * the first byte in the code object data is copied from. + */ + size_t code_object_storage_offset; + /** + * Starting address of the underlying memory segment. + */ + const void *segment_base; + /** + * Size, in bytes, of the underlying memory segment. + */ + size_t segment_size; +} hsa_ven_amd_loader_segment_descriptor_t; + +/** + * @brief Either queries loaded memory segment descriptors, or total number of + * loaded memory segment descriptors. + * + * + * @details If @p segment_descriptors is not null and @p num_segment_descriptors + * points to number that exactly matches total number of loaded memory segment + * descriptors, then queries loaded memory segment descriptors, and records them + * in @p segment_descriptors. If @p segment_descriptors is null and @p + * num_segment_descriptors points to zero, then queries total number of loaded + * memory segment descriptors, and records it in @p num_segment_descriptors. In + * all other cases returns appropriate error code (see below). + * + * The caller of this function is responsible for the allocation/deallocation + * and the lifetime of @p segment_descriptors and @p num_segment_descriptors. + * + * The lifetime of loaded memory segments that are described by queried loaded + * memory segment descriptors is limited by the lifetime of the executable that + * is managing loaded memory segments. + * + * Queried loaded memory segment descriptors are always self-consistent: they + * describe a complete set of loaded memory segments that are being backed by + * fully loaded code objects that are present at the time (i.e. this function + * is blocked until all executable manipulations are fully complete). + * + * + * @param[out] segment_descriptors Pointer to application-allocated buffer to + * record queried loaded memory segment descriptors in. Can be null if @p + * num_segment_descriptors points to zero. + * + * @param[in,out] num_segment_descriptors Pointer to application-allocated + * buffer that contains either total number of loaded memory segment descriptors + * or zero. + * + * + * @retval HSA_STATUS_SUCCESS Function is executed successfully. + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED Runtime is not initialized. + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT @p segment_descriptors is null + * while @p num_segment_descriptors points to non-zero number, @p + * segment_descriptors is not null while @p num_segment_descriptors points to + * zero, or @p num_segment_descriptors is null. + * + * @retval HSA_STATUS_ERROR_INCOMPATIBLE_ARGUMENTS @p num_segment_descriptors + * does not point to number that exactly matches total number of loaded memory + * segment descriptors. + */ +hsa_status_t hsa_ven_amd_loader_query_segment_descriptors( + hsa_ven_amd_loader_segment_descriptor_t *segment_descriptors, + size_t *num_segment_descriptors); + +/** + * @brief Obtains the handle of executable to which the device address belongs. + * + * @details This method should not be used to obtain executable handle by using + * a host address. The executable returned is expected to be alive until its + * destroyed by the user. + * + * @retval HSA_STATUS_SUCCESS Function is executed successfully. + * + * @retval HSA_STATUS_ERROR_NOT_INITIALIZED Runtime is not initialized. + * + * @retval HSA_STATUS_ERROR_INVALID_ARGUMENT The input is invalid or there + * is no exectuable found for this kernel code object. + */ +hsa_status_t hsa_ven_amd_loader_query_executable(const void *device_address, + hsa_executable_t *executable); + +//===----------------------------------------------------------------------===// + +/** + * @brief Iterate over the loaded code objects in an executable, and invoke + * an application-defined callback on every iteration. + * + * @param[in] executable Executable. + * + * @param[in] callback Callback to be invoked once per loaded code object. The + * HSA runtime passes three arguments to the callback: the executable, a + * loaded code object, and the application data. If @p callback returns a + * status other than ::HSA_STATUS_SUCCESS for a particular iteration, the + * traversal stops and + * ::hsa_ven_amd_loader_executable_iterate_loaded_code_objects returns that + * status value. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_EXECUTABLE The executable is invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t hsa_ven_amd_loader_executable_iterate_loaded_code_objects( + hsa_executable_t executable, + hsa_status_t (*callback)(hsa_executable_t executable, + hsa_loaded_code_object_t loaded_code_object, + void *data), + void *data); + +/** + * @brief Loaded code object kind. + */ +typedef enum { + /** + * Program code object. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_KIND_PROGRAM = 1, + /** + * Agent code object. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_KIND_AGENT = 2 +} hsa_ven_amd_loader_loaded_code_object_kind_t; + +/** + * @brief Loaded code object attributes. + */ +typedef enum hsa_ven_amd_loader_loaded_code_object_info_e { + /** + * The executable in which this loaded code object is loaded. The + * type of this attribute is ::hsa_executable_t. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_EXECUTABLE = 1, + /** + * The kind of this loaded code object. The type of this attribute is + * ::uint32_t interpreted as ::hsa_ven_amd_loader_loaded_code_object_kind_t. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_KIND = 2, + /** + * The agent on which this loaded code object is loaded. The + * value of this attribute is only defined if + * ::HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_KIND is + * ::HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_KIND_AGENT. The type of this + * attribute is ::hsa_agent_t. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_AGENT = 3, + /** + * The storage type of the code object reader used to load the loaded code + * object. The type of this attribute is ::uint32_t interpreted as a + * ::hsa_ven_amd_loader_code_object_storage_type_t. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_CODE_OBJECT_STORAGE_TYPE = 4, + /** + * The memory address of the first byte of the code object that was loaaded. + * The value of this attribute is only defined if + * ::HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_CODE_OBJECT_STORAGE_TYPE is + * ::HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_MEMORY. The type of this + * attribute is ::uint64_t. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_CODE_OBJECT_STORAGE_MEMORY_BASE = + 5, + /** + * The memory size in bytes of the code object that was loaaded. + * The value of this attribute is only defined if + * ::HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_CODE_OBJECT_STORAGE_TYPE is + * ::HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_MEMORY. The type of this + * attribute is ::uint64_t. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_CODE_OBJECT_STORAGE_MEMORY_SIZE = + 6, + /** + * The file descriptor of the code object that was loaaded. + * The value of this attribute is only defined if + * ::HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_CODE_OBJECT_STORAGE_TYPE is + * ::HSA_VEN_AMD_LOADER_CODE_OBJECT_STORAGE_TYPE_FILE. The type of this + * attribute is ::int. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_CODE_OBJECT_STORAGE_FILE = 7, + /** + * The signed byte address difference of the memory address at which the code + * object is loaded minus the virtual address specified in the code object + * that is loaded. The value of this attribute is only defined if the + * executable in which the code object is loaded is froozen. The type of this + * attribute is ::int64_t. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_LOAD_DELTA = 8, + /** + * The base memory address at which the code object is loaded. This is the + * base address of the allocation for the lowest addressed segment of the code + * object that is loaded. Note that any non-loaded segments before the first + * loaded segment are ignored. The value of this attribute is only defined if + * the executable in which the code object is loaded is froozen. The type of + * this attribute is ::uint64_t. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_LOAD_BASE = 9, + /** + * The byte size of the loaded code objects contiguous memory allocation. The + * value of this attribute is only defined if the executable in which the code + * object is loaded is froozen. The type of this attribute is ::uint64_t. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_LOAD_SIZE = 10, + /** + * The length of the URI in bytes, not including the NUL terminator. The type + * of this attribute is uint32_t. + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_URI_LENGTH = 11, + /** + * The URI name from which the code object was loaded. The type of this + * attribute is a NUL terminated \p char* with the length equal to the value + * of ::HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_URI_LENGTH attribute. + * The URI name syntax is defined by the following BNF syntax: + * + * code_object_uri ::== file_uri | memory_uri + * file_uri ::== "file://" file_path [ range_specifier ] + * memory_uri ::== "memory://" process_id range_specifier + * range_specifier ::== [ "#" | "?" ] "offset=" number "&" "size=" number + * file_path ::== URI_ENCODED_OS_FILE_PATH + * process_id ::== DECIMAL_NUMBER + * number ::== HEX_NUMBER | DECIMAL_NUMBER | OCTAL_NUMBER + * + * ``number`` is a C integral literal where hexadecimal values are prefixed by + * "0x" or "0X", and octal values by "0". + * + * ``file_path`` is the file's path specified as a URI encoded UTF-8 string. + * In URI encoding, every character that is not in the regular expression + * ``[a-zA-Z0-9/_.~-]`` is encoded as two uppercase hexidecimal digits + * proceeded by "%". Directories in the path are separated by "/". + * + * ``offset`` is a 0-based byte offset to the start of the code object. For a + * file URI, it is from the start of the file specified by the ``file_path``, + * and if omitted defaults to 0. For a memory URI, it is the memory address + * and is required. + * + * ``size`` is the number of bytes in the code object. For a file URI, if + * omitted it defaults to the size of the file. It is required for a memory + * URI. + * + * ``process_id`` is the identity of the process owning the memory. For Linux + * it is the C unsigned integral decimal literal for the process ID (PID). + * + * For example: + * + * file:///dir1/dir2/file1 + * file:///dir3/dir4/file2#offset=0x2000&size=3000 + * memory://1234#offset=0x20000&size=3000 + */ + HSA_VEN_AMD_LOADER_LOADED_CODE_OBJECT_INFO_URI = 12, +} hsa_ven_amd_loader_loaded_code_object_info_t; + +/** + * @brief Get the current value of an attribute for a given loaded code + * object. + * + * @param[in] loaded_code_object Loaded code object. + * + * @param[in] attribute Attribute to query. + * + * @param[out] value Pointer to an application-allocated buffer where to store + * the value of the attribute. If the buffer passed by the application is not + * large enough to hold the value of @p attribute, the behavior is undefined. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT The loaded code object is + * invalid. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p attribute is an invalid + * loaded code object attribute, or @p value is NULL. + */ +hsa_status_t hsa_ven_amd_loader_loaded_code_object_get_info( + hsa_loaded_code_object_t loaded_code_object, + hsa_ven_amd_loader_loaded_code_object_info_t attribute, void *value); + +//===----------------------------------------------------------------------===// + +/** + * @brief Create a code object reader to operate on a file with size and offset. + * + * @param[in] file File descriptor. The file must have been opened by + * application with at least read permissions prior calling this function. The + * file must contain a vendor-specific code object. + * + * The file is owned and managed by the application; the lifetime of the file + * descriptor must exceed that of any associated code object reader. + * + * @param[in] size Size of the code object embedded in @p file. + * + * @param[in] offset 0-based offset relative to the beginning of the @p file + * that denotes the beginning of the code object embedded within the @p file. + * + * @param[out] code_object_reader Memory location to store the newly created + * code object reader handle. Must not be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_FILE @p file is not opened with at least + * read permissions. This condition may also be reported as + * ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT_READER by the + * ::hsa_executable_load_agent_code_object function. + * + * @retval ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT The bytes starting at offset + * do not form a valid code object. If file size is 0. Or offset > file size. + * This condition may also be reported as + * ::HSA_STATUS_ERROR_INVALID_CODE_OBJECT by the + * ::hsa_executable_load_agent_code_object function. + * + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES The HSA runtime failed to + * allocate the required resources. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p code_object_reader is NULL. + */ +hsa_status_t +hsa_ven_amd_loader_code_object_reader_create_from_file_with_offset_size( + hsa_file_t file, size_t offset, size_t size, + hsa_code_object_reader_t *code_object_reader); + +//===----------------------------------------------------------------------===// + +/** + * @brief Iterate over the available executables, and invoke an + * application-defined callback on every iteration. While + * ::hsa_ven_amd_loader_iterate_executables is executing any calls to + * ::hsa_executable_create, ::hsa_executable_create_alt, or + * ::hsa_executable_destroy will be blocked. + * + * @param[in] callback Callback to be invoked once per executable. The HSA + * runtime passes two arguments to the callback: the executable and the + * application data. If @p callback returns a status other than + * ::HSA_STATUS_SUCCESS for a particular iteration, the traversal stops and + * ::hsa_ven_amd_loader_iterate_executables returns that status value. If + * @p callback invokes ::hsa_executable_create, ::hsa_executable_create_alt, or + * ::hsa_executable_destroy then the behavior is undefined. + * + * @param[in] data Application data that is passed to @p callback on every + * iteration. May be NULL. + * + * @retval ::HSA_STATUS_SUCCESS The function has been executed successfully. + * + * @retval ::HSA_STATUS_ERROR_NOT_INITIALIZED The HSA runtime has not been + * initialized. + * + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT @p callback is NULL. + */ +hsa_status_t hsa_ven_amd_loader_iterate_executables( + hsa_status_t (*callback)(hsa_executable_t executable, void *data), + void *data); + +//===----------------------------------------------------------------------===// + +/** + * @brief Extension version. + */ +#define hsa_ven_amd_loader 001003 + +/** + * @brief Extension function table version 1.00. + */ +typedef struct hsa_ven_amd_loader_1_00_pfn_s { + hsa_status_t (*hsa_ven_amd_loader_query_host_address)( + const void *device_address, const void **host_address); + + hsa_status_t (*hsa_ven_amd_loader_query_segment_descriptors)( + hsa_ven_amd_loader_segment_descriptor_t *segment_descriptors, + size_t *num_segment_descriptors); + + hsa_status_t (*hsa_ven_amd_loader_query_executable)( + const void *device_address, hsa_executable_t *executable); +} hsa_ven_amd_loader_1_00_pfn_t; + +/** + * @brief Extension function table version 1.01. + */ +typedef struct hsa_ven_amd_loader_1_01_pfn_s { + hsa_status_t (*hsa_ven_amd_loader_query_host_address)( + const void *device_address, const void **host_address); + + hsa_status_t (*hsa_ven_amd_loader_query_segment_descriptors)( + hsa_ven_amd_loader_segment_descriptor_t *segment_descriptors, + size_t *num_segment_descriptors); + + hsa_status_t (*hsa_ven_amd_loader_query_executable)( + const void *device_address, hsa_executable_t *executable); + + hsa_status_t (*hsa_ven_amd_loader_executable_iterate_loaded_code_objects)( + hsa_executable_t executable, + hsa_status_t (*callback)(hsa_executable_t executable, + hsa_loaded_code_object_t loaded_code_object, + void *data), + void *data); + + hsa_status_t (*hsa_ven_amd_loader_loaded_code_object_get_info)( + hsa_loaded_code_object_t loaded_code_object, + hsa_ven_amd_loader_loaded_code_object_info_t attribute, void *value); +} hsa_ven_amd_loader_1_01_pfn_t; + +/** + * @brief Extension function table version 1.02. + */ +typedef struct hsa_ven_amd_loader_1_02_pfn_s { + hsa_status_t (*hsa_ven_amd_loader_query_host_address)( + const void *device_address, const void **host_address); + + hsa_status_t (*hsa_ven_amd_loader_query_segment_descriptors)( + hsa_ven_amd_loader_segment_descriptor_t *segment_descriptors, + size_t *num_segment_descriptors); + + hsa_status_t (*hsa_ven_amd_loader_query_executable)( + const void *device_address, hsa_executable_t *executable); + + hsa_status_t (*hsa_ven_amd_loader_executable_iterate_loaded_code_objects)( + hsa_executable_t executable, + hsa_status_t (*callback)(hsa_executable_t executable, + hsa_loaded_code_object_t loaded_code_object, + void *data), + void *data); + + hsa_status_t (*hsa_ven_amd_loader_loaded_code_object_get_info)( + hsa_loaded_code_object_t loaded_code_object, + hsa_ven_amd_loader_loaded_code_object_info_t attribute, void *value); + + hsa_status_t ( + *hsa_ven_amd_loader_code_object_reader_create_from_file_with_offset_size)( + hsa_file_t file, size_t offset, size_t size, + hsa_code_object_reader_t *code_object_reader); +} hsa_ven_amd_loader_1_02_pfn_t; + +/** + * @brief Extension function table version 1.03. + */ +typedef struct hsa_ven_amd_loader_1_03_pfn_s { + hsa_status_t (*hsa_ven_amd_loader_query_host_address)( + const void *device_address, const void **host_address); + + hsa_status_t (*hsa_ven_amd_loader_query_segment_descriptors)( + hsa_ven_amd_loader_segment_descriptor_t *segment_descriptors, + size_t *num_segment_descriptors); + + hsa_status_t (*hsa_ven_amd_loader_query_executable)( + const void *device_address, hsa_executable_t *executable); + + hsa_status_t (*hsa_ven_amd_loader_executable_iterate_loaded_code_objects)( + hsa_executable_t executable, + hsa_status_t (*callback)(hsa_executable_t executable, + hsa_loaded_code_object_t loaded_code_object, + void *data), + void *data); + + hsa_status_t (*hsa_ven_amd_loader_loaded_code_object_get_info)( + hsa_loaded_code_object_t loaded_code_object, + hsa_ven_amd_loader_loaded_code_object_info_t attribute, void *value); + + hsa_status_t ( + *hsa_ven_amd_loader_code_object_reader_create_from_file_with_offset_size)( + hsa_file_t file, size_t offset, size_t size, + hsa_code_object_reader_t *code_object_reader); + + hsa_status_t (*hsa_ven_amd_loader_iterate_executables)( + hsa_status_t (*callback)(hsa_executable_t executable, void *data), + void *data); +} hsa_ven_amd_loader_1_03_pfn_t; + +#ifdef __cplusplus +} +#endif /* __cplusplus */ + +#endif /* HSA_VEN_AMD_LOADER_H */ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ven_amd_pc_sampling.h b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ven_amd_pc_sampling.h new file mode 100644 index 000000000..bcd5b884e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/hsa/hsa_ven_amd_pc_sampling.h @@ -0,0 +1,443 @@ +//////////////////////////////////////////////////////////////////////////////// +// +// The University of Illinois/NCSA +// Open Source License (NCSA) +// +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// +// Developed by: +// +// AMD Research and AMD HSA Software Development +// +// Advanced Micro Devices, Inc. +// +// www.amd.com +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimers in +// the documentation and/or other materials provided with the distribution. +// - Neither the names of Advanced Micro Devices, Inc, +// nor the names of its contributors may be used to endorse or promote +// products derived from this Software without specific prior written +// permission. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS WITH THE SOFTWARE. +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef HSA_VEN_AMD_PC_SAMPLING_H +#define HSA_VEN_AMD_PC_SAMPLING_H + +#include "hsa.h" + +#ifdef __cplusplus +extern "C" { +#endif /*__cplusplus*/ + +/** + * @brief HSA AMD Vendor PC Sampling APIs + * EXPERIMENTAL: All PC Sampling APIs are currently in an experimental phase and + * the APIs may be modified extensively in the future + */ + +/** + * @brief PC Sampling sample data for hosttrap sampling method + */ +typedef struct { + uint64_t pc; + uint64_t exec_mask; + uint32_t workgroup_id_x; + uint32_t workgroup_id_y; + uint32_t workgroup_id_z; + uint32_t wave_in_wg : 6; + uint32_t chiplet : 3; // Currently not used + uint32_t reserved : 23; + uint32_t hw_id; + uint32_t reserved0; + uint64_t reserved1; + uint64_t timestamp; + uint64_t correlation_id; +} perf_sample_hosttrap_v1_t; + +/** + * @brief PC Sampling sample data for stochastic sampling method + */ +typedef struct { + uint64_t pc; + uint64_t exec_mask; + uint32_t workgroup_id_x; + uint32_t workgroup_id_y; + uint32_t workgroup_id_z; + uint32_t wave_in_wg : 6; + uint32_t chiplet : 3; // Currently not used + uint32_t reserved : 23; + uint32_t hw_id; + uint32_t perf_snapshot_data; + uint32_t perf_snapshot_data1; + uint32_t perf_snapshot_data2; + uint64_t timestamp; + uint64_t correlation_id; +} perf_sample_snapshot_v1_t; + +/** + * @brief PC Sampling method kinds + */ +typedef enum { + HSA_VEN_AMD_PCS_METHOD_HOSTTRAP_V1, + HSA_VEN_AMD_PCS_METHOD_STOCHASTIC_V1 +} hsa_ven_amd_pcs_method_kind_t; + +/** + * @brief PC Sampling interval unit type + */ +typedef enum { + HSA_VEN_AMD_PCS_INTERVAL_UNITS_MICRO_SECONDS, + HSA_VEN_AMD_PCS_INTERVAL_UNITS_CLOCK_CYCLES, + HSA_VEN_AMD_PCS_INTERVAL_UNITS_INSTRUCTIONS +} hsa_ven_amd_pcs_units_t; + +/** + * @brief HSA callback function to perform the copy onto a destination buffer + * + * If data_size is 0, HSA will stop current copy operation and keep remaining + * data in internal buffers. Remaining contents of HSA internal buffers will be + * included in next hsa_ven_amd_pcs_data_ready_callback_t. HSA internal buffers + * can also be drained by calling hsa_ven_amd_pcs_flush. + * + * @param[in] hsa_callback_data private data to pass back to HSA. Provided in + * hsa_ven_amd_pcs_data_ready_callback_t + * + * @param[in] data_size size of destination buffer in bytes. + * @param[in] destination destination buffer + * @retval TBD: but could be used to indicate that there is no more data to + * be read. Or indicate an error and abort of current copy operations + */ +typedef hsa_status_t (*hsa_ven_amd_pcs_data_copy_callback_t)( + void *hsa_callback_data, size_t data_size, void *destination); + +/** + * @brief HSA callback function to to indicate that there is data ready to be + * copied + * + * When the client receives this callback, the client should call back @p + * data_copy_callback for HSA to perform the copy operation into an available + * buffer. @p data_copy_callback can be called back multiple times with smaller + * @p data_size to split the copy operation. + * + * This callback must not call ::hsa_ven_amd_pcs_flush. + * + * @param[in] client_callback_data client private data passed in via + * hsa_ven_amd_pcs_create/hsa_ven_amd_pcs_create_from_id + * @param[in] data_size size of data available to be copied + * @param[in] lost_sample_count number of lost samples since last call to + * hsa_ven_amd_pcs_data_ready_callback_t. + * @param[in] data_copy_callback callback function for HSA to perform the actual + * copy + * @param[in] hsa_callback_data private data to pass back to HSA + */ +typedef void (*hsa_ven_amd_pcs_data_ready_callback_t)( + void *client_callback_data, size_t data_size, size_t lost_sample_count, + hsa_ven_amd_pcs_data_copy_callback_t data_copy_callback, + void *hsa_callback_data); + +/** + * @brief Opaque handle representing a sampling session. + * Two sessions having same handle value represent the same session + */ +typedef struct { + uint64_t handle; +} hsa_ven_amd_pcs_t; + +/** + * @brief PC Sampling configuration flag options + */ +typedef enum { + /* The interval for this sampling method have to be a power of 2 */ + HSA_VEN_AMD_PCS_CONFIGURATION_FLAGS_INTERVAL_POWER_OF_2 = (1 << 0) +} hsa_ven_amd_pcs_configuration_flags_t; + +/** + * @brief PC Sampling method information + * Used to provide client with list of supported PC Sampling methods + */ +typedef struct { + hsa_ven_amd_pcs_method_kind_t method; + hsa_ven_amd_pcs_units_t units; + size_t min_interval; + size_t max_interval; + uint64_t flags; +} hsa_ven_amd_pcs_configuration_t; + +/** + * @brief Callback function to iterate through list of supported PC Sampling + * configurations + * + * @param[in] configuration one entry for supported PC Sampling method and + * configuration options + * @param[in] callback_data client private callback data that was passed in when + * calling hsa_ven_amd_pcs_iterate_configuration + */ +typedef hsa_status_t (*hsa_ven_amd_pcs_iterate_configuration_callback_t)( + const hsa_ven_amd_pcs_configuration_t *configuration, void *callback_data); + +/** + * @brief Iterate through list of current supported PC Sampling configurations + *for this @p agent + * + * HSA will callback @p configuration_callback for each currently available PC + *Sampling configuration. The list of currently available configurations may not + *be the complete list of configurations supported on the @p agent. The list of + *currently available configurations may be reduced if the @p agent is currently + *handling other PC sampling sessions. + * + * @param[in] agent target agent + * @param[in] configuration_callback callback function to iterate through list + *of configurations + * @param[in] callback_data client private callback data + **/ +hsa_status_t hsa_ven_amd_pcs_iterate_configuration( + hsa_agent_t agent, + hsa_ven_amd_pcs_iterate_configuration_callback_t configuration_callback, + void *callback_data); + +/** + * @brief Create a PC Sampling session on @p agent + * + * Allocate the resources required for a PC Sampling session. The @p method, @p + *units, @p interval parameters must be a legal configuration value, as + *described by the hsa_ven_amd_pcs_configuration_t configurations passed to the + *callbacks of hsa_ven_amd_pcs_iterate_configuration for this @p agent. A + *successfull call may restrict the list of possible PC sampling methods + *available to subsequent calls to hsa_ven_amd_pcs_iterate_configuration on the + *same agent as agents have limitations on what types of PC sampling they can + *perform concurrently. For all successful calls, hsa_ven_amd_pcs_destroy should + *be called to free this session. The session will be in a stopped/inactive + *state after this call + * + * @param[in] agent target agent + * @param[in] method method to use + * @param[in] units sampling units + * @param[in] interval sampling interval in @p units + * @param[in] latency expected latency in microseconds for client to provide a + *buffer for the data copy callback once HSA calls @p data_ready_callback. This + *is a performance hint to avoid the buffer filling up before the client is + *notified that data is ready. HSA-runtime will estimate how many samples are + *received within @p latency and call @p data_ready_callback ahead of time so + * that the client has @p latency time to allocate the buffer before the + *HSA-runtime internal buffers are full. The value of latency can be 0. + * @param[in] buffer_size size of client buffer in bytes. @p data_ready_callback + *will be called once HSA-runtime has enough samples to fill @p buffer_size. + *This needs to be a multiple of size of perf_sample_hosttrap_v1_t or size of + *perf_sample_snapshot_v1_t. + * @param[in] data_ready_callback client callback function that will be called + *when: + * 1. There is enough samples fill a buffer with @p buffer_size - estimated + *samples received within @p latency period. OR + * 2. When hsa_ven_amd_pcs_flush is called. + * @param[in] client_callback_data client private data to be provided back when + *data_ready_callback is called. + * @param[out] pc_sampling PC sampling session handle used to reference this + *session when calling hsa_ven_amd_pcs_start, hsa_ven_amd_pcs_stop, + *hsa_ven_amd_pcs_destroy + * + * @retval ::HSA_STATUS_SUCCESS session created successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT invalid parameters + * @retval ::HSA_STATUS_ERROR_RESOURCE_BUSY agent currently handling another PC + *Sampling session and cannot handle the type requested. + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Failed to allocate resources + * @retval ::HSA_STATUS_ERROR Unexpected error + **/ +hsa_status_t hsa_ven_amd_pcs_create( + hsa_agent_t agent, hsa_ven_amd_pcs_method_kind_t method, + hsa_ven_amd_pcs_units_t units, size_t interval, size_t latency, + size_t buffer_size, + hsa_ven_amd_pcs_data_ready_callback_t data_ready_callback, + void *client_callback_data, hsa_ven_amd_pcs_t *pc_sampling); + +/** + * @brief Creates a PC Sampling session on @p agent. Assumes that the caller + *provides the + * @p pcs_id generated by the previous call to the underlying driver that + *reserved PC sampling on the @p agent. + * + * Similar to the @ref hsa_ven_amd_pcs_create with the difference that it + *inherits an existing PC sampling session that was previously created in the + *underlying driver. + * + * Allocate the resources required for a PC Sampling session. The @p method, @p + *units, @p interval parameters must be a legal configuration value, and match + *the parameters that we used to create the underlying PC Sampling session in + *the underlying driver. A successfull call may restrict the list of possible PC + *sampling methods available to subsequent calls to + *hsa_ven_amd_pcs_iterate_configuration on the same agent as agents have + *limitations on what types of PC sampling they can perform concurrently. For + *all successful calls, hsa_ven_amd_pcs_destroy should be called to free this + *session. The session will be in a stopped/inactive state after this call + * + * @param[in] pcs_id ID that uniquely identifies the PC sampling session within + *underlying driver + * @param[in] agent target agent + * @param[in] method method to use + * @param[in] units sampling units + * @param[in] interval sampling interval in @p units + * @param[in] latency expected latency in microseconds for client to provide a + *buffer for the data copy callback once HSA calls @p data_ready_callback. This + *is a performance hint to avoid the buffer filling up before the client is + *notified that data is ready. HSA-runtime will estimate how many samples are + *received within @p latency and call @p data_ready_callback ahead of time so + * that the client has @p latency time to allocate the buffer before the + *HSA-runtime internal buffers are full. The value of latency can be 0. + * @param[in] buffer_size size of client buffer in bytes. @p data_ready_callback + *will be called once HSA-runtime has enough samples to fill @p buffer_size. + *This needs to be a multiple of size of perf_sample_hosttrap_v1_t or size of + *perf_sample_snapshot_v1_t. + * @param[in] data_ready_callback client callback function that will be called + *when: + * 1. There is enough samples fill a buffer with @p buffer_size - estimated + *samples received within @p latency period. OR + * 2. When hsa_ven_amd_pcs_flush is called. + * @param[in] client_callback_data client private data to be provided back when + *data_ready_callback is called. + * @param[out] pc_sampling PC sampling session handle used to reference this + *session when calling hsa_ven_amd_pcs_start, hsa_ven_amd_pcs_stop, + *hsa_ven_amd_pcs_destroy + * + * @retval ::HSA_STATUS_SUCCESS session created successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT invalid parameters + * @retval ::HSA_STATUS_ERROR_RESOURCE_BUSY agent currently handling another PC + *Sampling session and cannot handle the type requested. + * @retval ::HSA_STATUS_ERROR_OUT_OF_RESOURCES Failed to allocate resources + * @retval ::HSA_STATUS_ERROR Unexpected error + **/ +hsa_status_t hsa_ven_amd_pcs_create_from_id( + uint32_t pcs_id, hsa_agent_t agent, hsa_ven_amd_pcs_method_kind_t method, + hsa_ven_amd_pcs_units_t units, size_t interval, size_t latency, + size_t buffer_size, + hsa_ven_amd_pcs_data_ready_callback_t data_ready_callback, + void *client_callback_data, hsa_ven_amd_pcs_t *pc_sampling); + +/** + * @brief Free a PC Sampling session on @p agent + * + * Free all the resources allocated for a PC Sampling session on @p agent + * Internal buffers for this session will be lost. + * If the session was active, the session will be stopped before it is + * destroyed. + * + * @param[in] pc_sampling PC sampling session handle + * + * @retval ::HSA_STATUS_SUCCESS Session destroyed successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT Invalid PC sampling handle + * @retval ::HSA_STATUS_ERROR unexpected error + */ +hsa_status_t hsa_ven_amd_pcs_destroy(hsa_ven_amd_pcs_t pc_sampling); + +/** + * @brief Start a PC Sampling session + * + * Activate a PC Sampling session that was previous created. + * The session with be in a active state after this call + * If the session was already active, this will result in a no-op and will + * return HSA_STATUS_SUCCESS + * + * @param[in] pc_sampling PC sampling session handle + * + * @retval ::HSA_STATUS_SUCCESS Session started successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT Invalid PC sampling handle + * @retval ::HSA_STATUS_ERROR unexpected error + */ +hsa_status_t hsa_ven_amd_pcs_start(hsa_ven_amd_pcs_t pc_sampling); + +/** + * @brief Stop a PC Sampling session + * + * Stop a session that is currently active + * After a session is stopped HSA may still have some PC Sampling data in its + * internal buffers. The internal buffers can be drained using + * hsa_ven_amd_pcs_flush. If the internal buffers are not drained and the + * session is started again, the internal buffers will be available on the next + * data_ready_callback. If the session was already inactive, this will result in + * a no-op and will return HSA_STATUS_SUCCESS + * + * @param[in] pc_sampling PC sampling session handle + * + * @retval ::HSA_STATUS_SUCCESS Session stopped successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT Invalid PC sampling handle + */ +hsa_status_t hsa_ven_amd_pcs_stop(hsa_ven_amd_pcs_t pc_sampling); + +/** + * @brief Flush internal buffers for a PC Sampling session + * + * Drain internal buffers for a PC Sampling session. If internal buffers have + * available data, this trigger a data_ready_callback. + * + * The function blocks until all PC samples associated with the @p pc_sampling + * session generated prior to the function call have been communicated by + * invocations of + * @p data_ready_callback having completed execution. + * + * @param[in] pc_sampling PC sampling session handle + * + * @retval ::HSA_STATUS_SUCCESS Session flushed successfully + * @retval ::HSA_STATUS_ERROR_INVALID_ARGUMENT Invalid PC sampling handle + */ +hsa_status_t hsa_ven_amd_pcs_flush(hsa_ven_amd_pcs_t pc_sampling); + +#define hsa_ven_amd_pc_sampling_1_00 + +/** + * @brief The function pointer table for the PC Sampling v1.00 extension. Can be + * returned by + * ::hsa_system_get_extension_table or ::hsa_system_get_major_extension_table. + */ +typedef struct hsa_ven_amd_pc_sampling_1_00_pfn_t { + hsa_status_t (*hsa_ven_amd_pcs_iterate_configuration)( + hsa_agent_t agent, + hsa_ven_amd_pcs_iterate_configuration_callback_t configuration_callback, + void *callback_data); + + hsa_status_t (*hsa_ven_amd_pcs_create)( + hsa_agent_t agent, hsa_ven_amd_pcs_method_kind_t method, + hsa_ven_amd_pcs_units_t units, size_t interval, size_t latency, + size_t buffer_size, + hsa_ven_amd_pcs_data_ready_callback_t data_ready_callback, + void *client_callback_data, hsa_ven_amd_pcs_t *pc_sampling); + + hsa_status_t (*hsa_ven_amd_pcs_create_from_id)( + uint32_t pcs_id, hsa_agent_t agent, hsa_ven_amd_pcs_method_kind_t method, + hsa_ven_amd_pcs_units_t units, size_t interval, size_t latency, + size_t buffer_size, + hsa_ven_amd_pcs_data_ready_callback_t data_ready_callback, + void *client_callback_data, hsa_ven_amd_pcs_t *pc_sampling); + + hsa_status_t (*hsa_ven_amd_pcs_destroy)(hsa_ven_amd_pcs_t pc_sampling); + + hsa_status_t (*hsa_ven_amd_pcs_start)(hsa_ven_amd_pcs_t pc_sampling); + + hsa_status_t (*hsa_ven_amd_pcs_stop)(hsa_ven_amd_pcs_t pc_sampling); + + hsa_status_t (*hsa_ven_amd_pcs_flush)(hsa_ven_amd_pcs_t pc_sampling); + +} hsa_ven_amd_pc_sampling_1_00_pfn_t; + +#ifdef __cplusplus +} // end extern "C" block +#endif /*__cplusplus*/ + +#endif /* HSA_VEN_AMD_PC_SAMPLING_H */ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/ext/prof_protocol.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/ext/prof_protocol.h new file mode 100644 index 000000000..7d0574c89 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/ext/prof_protocol.h @@ -0,0 +1,108 @@ +/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ + +#ifndef EXT_PROF_PROTOCOL_H_ +#define EXT_PROF_PROTOCOL_H_ + +#include +#include + +/* Traced API domains */ +typedef enum { + ACTIVITY_DOMAIN_HSA_API = 0, /* HSA API domain */ + ACTIVITY_DOMAIN_HSA_OPS = 1, /* HSA async activity domain */ + ACTIVITY_DOMAIN_HIP_OPS = 2, /* HIP async activity domain */ + ACTIVITY_DOMAIN_HCC_OPS = + ACTIVITY_DOMAIN_HIP_OPS, /* HCC async activity domain */ + ACTIVITY_DOMAIN_HIP_VDI = + ACTIVITY_DOMAIN_HIP_OPS, /* HIP VDI async activity domain */ + ACTIVITY_DOMAIN_HIP_API = 3, /* HIP API domain */ + ACTIVITY_DOMAIN_KFD_API = 4, /* KFD API domain */ + ACTIVITY_DOMAIN_EXT_API = 5, /* External ID domain */ + ACTIVITY_DOMAIN_ROCTX = 6, /* ROCTX domain */ + ACTIVITY_DOMAIN_HSA_EVT = 7, /* HSA events */ + ACTIVITY_DOMAIN_NUMBER +} activity_domain_t; + +/* API callback type */ +typedef void (*activity_rtapi_callback_t)(uint32_t domain, uint32_t cid, + const void *data, void *arg); +typedef uint32_t activity_kind_t; +typedef uint32_t activity_op_t; + +/* API callback phase */ +typedef enum { + ACTIVITY_API_PHASE_ENTER = 0, + ACTIVITY_API_PHASE_EXIT = 1 +} activity_api_phase_t; + +/* Trace record types */ + +/* Correlation id */ +typedef uint64_t activity_correlation_id_t; + +/* Timestamp in nanoseconds */ +typedef uint64_t roctracer_timestamp_t; + +/* Activity record type */ +typedef struct activity_record_s { + uint32_t domain; /* activity domain id */ + activity_kind_t kind; /* activity kind */ + activity_op_t op; /* activity op */ + union { + struct { + activity_correlation_id_t correlation_id; /* activity ID */ + roctracer_timestamp_t begin_ns; /* host begin timestamp */ + roctracer_timestamp_t end_ns; /* host end timestamp */ + }; + struct { + uint32_t se; /* sampled SE */ + uint64_t cycle; /* sample cycle */ + uint64_t pc; /* sample PC */ + } pc_sample; + }; + union { + struct { + int device_id; /* device id */ + uint64_t queue_id; /* queue id */ + }; + struct { + uint32_t process_id; /* device id */ + uint32_t thread_id; /* thread id */ + }; + struct { + activity_correlation_id_t external_id; /* external correlation id */ + }; + }; + union { + size_t bytes; /* data size bytes */ + const char *kernel_name; /* kernel name */ + const char *mark_message; + }; +} activity_record_t; + +/* Activity sync callback type */ +typedef void (*activity_sync_callback_t)(uint32_t cid, + activity_record_t *record, + const void *data, void *arg); +/* Activity async callback type */ +typedef void (*activity_async_callback_t)(uint32_t op, void *record, void *arg); + +#endif /* EXT_PROF_PROTOCOL_H_ */ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/hip_ostream_ops.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/hip_ostream_ops.h new file mode 100644 index 000000000..4cbfcf720 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/hip_ostream_ops.h @@ -0,0 +1,4962 @@ +// automatically generated +/* +Copyright (c) 2018 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef INC_HIP_OSTREAM_OPS_H_ +#define INC_HIP_OSTREAM_OPS_H_ + +#include "roctracer.h" +#include +#include + +#ifdef __cplusplus +#include +#include +#include +#include + +namespace roctracer { +namespace hip_support { +static int HIP_depth_max = 1; +static int HIP_depth_max_cnt = 0; +static std::string HIP_structs_regex = ""; +// begin ostream ops for HIP +// basic ostream ops +namespace detail { +inline static void print_escaped_string(std::ostream &out, const char *v, + size_t len) { + out << '"'; + for (size_t i = 0; i < len && v[i]; ++i) { + switch (v[i]) { + case '\"': + out << "\\\""; + break; + case '\\': + out << "\\\\"; + break; + case '\b': + out << "\\\b"; + break; + case '\f': + out << "\\\f"; + break; + case '\n': + out << "\\\n"; + break; + case '\r': + out << "\\\r"; + break; + case '\t': + out << "\\\t"; + break; + default: + if (std::isprint((unsigned char)v[i])) + std::operator<<(out, v[i]); + else { + std::ios_base::fmtflags flags(out.flags()); + out << "\\x" << std::setfill('0') << std::setw(2) << std::hex + << (unsigned int)(unsigned char)v[i]; + out.flags(flags); + } + break; + } + } + out << '"'; +} + +template +inline static std::ostream &operator<<(std::ostream &out, const T &v) { + using std::operator<<; + static bool recursion = false; + if (recursion == false) { + recursion = true; + out << v; + recursion = false; + } + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const unsigned char &v) { + out << (unsigned int)v; + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const char &v) { + out << (unsigned char)v; + return out; +} + +template +inline static std::ostream &operator<<(std::ostream &out, const char (&v)[N]) { + print_escaped_string(out, v, N); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const char *v) { + print_escaped_string(out, v, strlen(v)); + return out; +} +// End of basic ostream ops + +inline static std::ostream &operator<<(std::ostream &out, + const __locale_struct &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("__locale_struct::__names").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "__names="); + roctracer::hip_support::detail::operator<<(out, v.__names); + std::operator<<(out, ", "); + } + if (std::string("__locale_struct::__ctype_toupper") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "__ctype_toupper="); + roctracer::hip_support::detail::operator<<(out, v.__ctype_toupper); + std::operator<<(out, ", "); + } + if (std::string("__locale_struct::__ctype_tolower") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "__ctype_tolower="); + roctracer::hip_support::detail::operator<<(out, v.__ctype_tolower); + std::operator<<(out, ", "); + } + if (std::string("__locale_struct::__ctype_b").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "__ctype_b="); + roctracer::hip_support::detail::operator<<(out, v.__ctype_b); + std::operator<<(out, ", "); + } + if (std::string("__locale_struct::__locales").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "__locales="); + roctracer::hip_support::detail::operator<<(out, v.__locales); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipDeviceArch_t &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipDeviceArch_t::hasDynamicParallelism") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasDynamicParallelism="); + roctracer::hip_support::detail::operator<<(out, v.hasDynamicParallelism); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::has3dGrid").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "has3dGrid="); + roctracer::hip_support::detail::operator<<(out, v.has3dGrid); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasSurfaceFuncs") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasSurfaceFuncs="); + roctracer::hip_support::detail::operator<<(out, v.hasSurfaceFuncs); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasSyncThreadsExt") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasSyncThreadsExt="); + roctracer::hip_support::detail::operator<<(out, v.hasSyncThreadsExt); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasThreadFenceSystem") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasThreadFenceSystem="); + roctracer::hip_support::detail::operator<<(out, v.hasThreadFenceSystem); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasFunnelShift") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasFunnelShift="); + roctracer::hip_support::detail::operator<<(out, v.hasFunnelShift); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasWarpShuffle") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasWarpShuffle="); + roctracer::hip_support::detail::operator<<(out, v.hasWarpShuffle); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasWarpBallot").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "hasWarpBallot="); + roctracer::hip_support::detail::operator<<(out, v.hasWarpBallot); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasWarpVote").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "hasWarpVote="); + roctracer::hip_support::detail::operator<<(out, v.hasWarpVote); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasDoubles").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "hasDoubles="); + roctracer::hip_support::detail::operator<<(out, v.hasDoubles); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasSharedInt64Atomics") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasSharedInt64Atomics="); + roctracer::hip_support::detail::operator<<(out, v.hasSharedInt64Atomics); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasGlobalInt64Atomics") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasGlobalInt64Atomics="); + roctracer::hip_support::detail::operator<<(out, v.hasGlobalInt64Atomics); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasFloatAtomicAdd") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasFloatAtomicAdd="); + roctracer::hip_support::detail::operator<<(out, v.hasFloatAtomicAdd); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasSharedFloatAtomicExch") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasSharedFloatAtomicExch="); + roctracer::hip_support::detail::operator<<(out, + v.hasSharedFloatAtomicExch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasSharedInt32Atomics") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasSharedInt32Atomics="); + roctracer::hip_support::detail::operator<<(out, v.hasSharedInt32Atomics); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasGlobalFloatAtomicExch") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasGlobalFloatAtomicExch="); + roctracer::hip_support::detail::operator<<(out, + v.hasGlobalFloatAtomicExch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceArch_t::hasGlobalInt32Atomics") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hasGlobalInt32Atomics="); + roctracer::hip_support::detail::operator<<(out, v.hasGlobalInt32Atomics); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const hipUUID &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipUUID::bytes").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "bytes="); + roctracer::hip_support::detail::operator<<(out, v.bytes); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipDeviceProp_tR0600 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipDeviceProp_tR0600::asicRevision") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "asicRevision="); + roctracer::hip_support::detail::operator<<(out, v.asicRevision); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::isLargeBar") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "isLargeBar="); + roctracer::hip_support::detail::operator<<(out, v.isLargeBar); + std::operator<<(out, ", "); + } + if (std::string( + "hipDeviceProp_tR0600::cooperativeMultiDeviceUnmatchedSharedMem") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeMultiDeviceUnmatchedSharedMem="); + roctracer::hip_support::detail::operator<<( + out, v.cooperativeMultiDeviceUnmatchedSharedMem); + std::operator<<(out, ", "); + } + if (std::string( + "hipDeviceProp_tR0600::cooperativeMultiDeviceUnmatchedBlockDim") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeMultiDeviceUnmatchedBlockDim="); + roctracer::hip_support::detail::operator<<( + out, v.cooperativeMultiDeviceUnmatchedBlockDim); + std::operator<<(out, ", "); + } + if (std::string( + "hipDeviceProp_tR0600::cooperativeMultiDeviceUnmatchedGridDim") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeMultiDeviceUnmatchedGridDim="); + roctracer::hip_support::detail::operator<<( + out, v.cooperativeMultiDeviceUnmatchedGridDim); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::cooperativeMultiDeviceUnmatchedFunc") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeMultiDeviceUnmatchedFunc="); + roctracer::hip_support::detail::operator<<( + out, v.cooperativeMultiDeviceUnmatchedFunc); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::hdpRegFlushCntl") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hdpRegFlushCntl="); + roctracer::hip_support::detail::operator<<(out, v.hdpRegFlushCntl); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::hdpMemFlushCntl") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hdpMemFlushCntl="); + roctracer::hip_support::detail::operator<<(out, v.hdpMemFlushCntl); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::arch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "arch="); + roctracer::hip_support::detail::operator<<(out, v.arch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::clockInstructionRate") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "clockInstructionRate="); + roctracer::hip_support::detail::operator<<(out, v.clockInstructionRate); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxSharedMemoryPerMultiProcessor") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxSharedMemoryPerMultiProcessor="); + roctracer::hip_support::detail::operator<<( + out, v.maxSharedMemoryPerMultiProcessor); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::gcnArchName") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "gcnArchName="); + roctracer::hip_support::detail::operator<<(out, v.gcnArchName); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::hipReserved") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hipReserved="); + roctracer::hip_support::detail::operator<<(out, v.hipReserved); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::reserved").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::unifiedFunctionPointers") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "unifiedFunctionPointers="); + roctracer::hip_support::detail::operator<<(out, + v.unifiedFunctionPointers); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::clusterLaunch") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "clusterLaunch="); + roctracer::hip_support::detail::operator<<(out, v.clusterLaunch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::ipcEventSupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "ipcEventSupported="); + roctracer::hip_support::detail::operator<<(out, v.ipcEventSupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::deferredMappingHipArraySupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "deferredMappingHipArraySupported="); + roctracer::hip_support::detail::operator<<( + out, v.deferredMappingHipArraySupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::memoryPoolSupportedHandleTypes") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "memoryPoolSupportedHandleTypes="); + roctracer::hip_support::detail::operator<<( + out, v.memoryPoolSupportedHandleTypes); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::gpuDirectRDMAWritesOrdering") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "gpuDirectRDMAWritesOrdering="); + roctracer::hip_support::detail::operator<<(out, + v.gpuDirectRDMAWritesOrdering); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::gpuDirectRDMAFlushWritesOptions") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "gpuDirectRDMAFlushWritesOptions="); + roctracer::hip_support::detail::operator<<( + out, v.gpuDirectRDMAFlushWritesOptions); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::gpuDirectRDMASupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "gpuDirectRDMASupported="); + roctracer::hip_support::detail::operator<<(out, v.gpuDirectRDMASupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::memoryPoolsSupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "memoryPoolsSupported="); + roctracer::hip_support::detail::operator<<(out, v.memoryPoolsSupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::timelineSemaphoreInteropSupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "timelineSemaphoreInteropSupported="); + roctracer::hip_support::detail::operator<<( + out, v.timelineSemaphoreInteropSupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::hostRegisterReadOnlySupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hostRegisterReadOnlySupported="); + roctracer::hip_support::detail::operator<<( + out, v.hostRegisterReadOnlySupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::sparseHipArraySupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "sparseHipArraySupported="); + roctracer::hip_support::detail::operator<<(out, + v.sparseHipArraySupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::hostRegisterSupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hostRegisterSupported="); + roctracer::hip_support::detail::operator<<(out, v.hostRegisterSupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::reservedSharedMemPerBlock") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "reservedSharedMemPerBlock="); + roctracer::hip_support::detail::operator<<(out, + v.reservedSharedMemPerBlock); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::accessPolicyMaxWindowSize") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "accessPolicyMaxWindowSize="); + roctracer::hip_support::detail::operator<<(out, + v.accessPolicyMaxWindowSize); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxBlocksPerMultiProcessor") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxBlocksPerMultiProcessor="); + roctracer::hip_support::detail::operator<<(out, + v.maxBlocksPerMultiProcessor); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::directManagedMemAccessFromHost") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "directManagedMemAccessFromHost="); + roctracer::hip_support::detail::operator<<( + out, v.directManagedMemAccessFromHost); + std::operator<<(out, ", "); + } + if (std::string( + "hipDeviceProp_tR0600::pageableMemoryAccessUsesHostPageTables") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "pageableMemoryAccessUsesHostPageTables="); + roctracer::hip_support::detail::operator<<( + out, v.pageableMemoryAccessUsesHostPageTables); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::sharedMemPerBlockOptin") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "sharedMemPerBlockOptin="); + roctracer::hip_support::detail::operator<<(out, v.sharedMemPerBlockOptin); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::cooperativeMultiDeviceLaunch") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeMultiDeviceLaunch="); + roctracer::hip_support::detail::operator<<( + out, v.cooperativeMultiDeviceLaunch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::cooperativeLaunch") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeLaunch="); + roctracer::hip_support::detail::operator<<(out, v.cooperativeLaunch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::canUseHostPointerForRegisteredMem") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "canUseHostPointerForRegisteredMem="); + roctracer::hip_support::detail::operator<<( + out, v.canUseHostPointerForRegisteredMem); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::computePreemptionSupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "computePreemptionSupported="); + roctracer::hip_support::detail::operator<<(out, + v.computePreemptionSupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::concurrentManagedAccess") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "concurrentManagedAccess="); + roctracer::hip_support::detail::operator<<(out, + v.concurrentManagedAccess); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::pageableMemoryAccess") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "pageableMemoryAccess="); + roctracer::hip_support::detail::operator<<(out, v.pageableMemoryAccess); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::singleToDoublePrecisionPerfRatio") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "singleToDoublePrecisionPerfRatio="); + roctracer::hip_support::detail::operator<<( + out, v.singleToDoublePrecisionPerfRatio); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::hostNativeAtomicSupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hostNativeAtomicSupported="); + roctracer::hip_support::detail::operator<<(out, + v.hostNativeAtomicSupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::multiGpuBoardGroupID") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "multiGpuBoardGroupID="); + roctracer::hip_support::detail::operator<<(out, v.multiGpuBoardGroupID); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::isMultiGpuBoard") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "isMultiGpuBoard="); + roctracer::hip_support::detail::operator<<(out, v.isMultiGpuBoard); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::managedMemory") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "managedMemory="); + roctracer::hip_support::detail::operator<<(out, v.managedMemory); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::regsPerMultiprocessor") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "regsPerMultiprocessor="); + roctracer::hip_support::detail::operator<<(out, v.regsPerMultiprocessor); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::sharedMemPerMultiprocessor") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "sharedMemPerMultiprocessor="); + roctracer::hip_support::detail::operator<<(out, + v.sharedMemPerMultiprocessor); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::localL1CacheSupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "localL1CacheSupported="); + roctracer::hip_support::detail::operator<<(out, v.localL1CacheSupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::globalL1CacheSupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "globalL1CacheSupported="); + roctracer::hip_support::detail::operator<<(out, v.globalL1CacheSupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::streamPrioritiesSupported") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "streamPrioritiesSupported="); + roctracer::hip_support::detail::operator<<(out, + v.streamPrioritiesSupported); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxThreadsPerMultiProcessor") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxThreadsPerMultiProcessor="); + roctracer::hip_support::detail::operator<<(out, + v.maxThreadsPerMultiProcessor); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::persistingL2CacheMaxSize") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "persistingL2CacheMaxSize="); + roctracer::hip_support::detail::operator<<(out, + v.persistingL2CacheMaxSize); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::l2CacheSize") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "l2CacheSize="); + roctracer::hip_support::detail::operator<<(out, v.l2CacheSize); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::memoryBusWidth") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "memoryBusWidth="); + roctracer::hip_support::detail::operator<<(out, v.memoryBusWidth); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::memoryClockRate") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "memoryClockRate="); + roctracer::hip_support::detail::operator<<(out, v.memoryClockRate); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::unifiedAddressing") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "unifiedAddressing="); + roctracer::hip_support::detail::operator<<(out, v.unifiedAddressing); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::asyncEngineCount") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "asyncEngineCount="); + roctracer::hip_support::detail::operator<<(out, v.asyncEngineCount); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::tccDriver") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "tccDriver="); + roctracer::hip_support::detail::operator<<(out, v.tccDriver); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::pciDomainID") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "pciDomainID="); + roctracer::hip_support::detail::operator<<(out, v.pciDomainID); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::pciDeviceID") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "pciDeviceID="); + roctracer::hip_support::detail::operator<<(out, v.pciDeviceID); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::pciBusID").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "pciBusID="); + roctracer::hip_support::detail::operator<<(out, v.pciBusID); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::ECCEnabled") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "ECCEnabled="); + roctracer::hip_support::detail::operator<<(out, v.ECCEnabled); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::concurrentKernels") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "concurrentKernels="); + roctracer::hip_support::detail::operator<<(out, v.concurrentKernels); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::surfaceAlignment") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "surfaceAlignment="); + roctracer::hip_support::detail::operator<<(out, v.surfaceAlignment); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxSurfaceCubemapLayered") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxSurfaceCubemapLayered="); + roctracer::hip_support::detail::operator<<(out, + v.maxSurfaceCubemapLayered); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxSurfaceCubemap") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxSurfaceCubemap="); + roctracer::hip_support::detail::operator<<(out, v.maxSurfaceCubemap); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxSurface2DLayered") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxSurface2DLayered="); + roctracer::hip_support::detail::operator<<(out, v.maxSurface2DLayered); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxSurface1DLayered") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxSurface1DLayered="); + roctracer::hip_support::detail::operator<<(out, v.maxSurface1DLayered); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxSurface3D") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxSurface3D="); + roctracer::hip_support::detail::operator<<(out, v.maxSurface3D); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxSurface2D") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxSurface2D="); + roctracer::hip_support::detail::operator<<(out, v.maxSurface2D); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxSurface1D") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxSurface1D="); + roctracer::hip_support::detail::operator<<(out, v.maxSurface1D); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTextureCubemapLayered") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTextureCubemapLayered="); + roctracer::hip_support::detail::operator<<(out, + v.maxTextureCubemapLayered); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTexture2DLayered") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture2DLayered="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture2DLayered); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTexture1DLayered") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture1DLayered="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture1DLayered); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTextureCubemap") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTextureCubemap="); + roctracer::hip_support::detail::operator<<(out, v.maxTextureCubemap); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTexture3DAlt") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture3DAlt="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture3DAlt); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTexture3D") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture3D="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture3D); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTexture2DGather") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture2DGather="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture2DGather); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTexture2DLinear") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture2DLinear="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture2DLinear); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTexture2DMipmap") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture2DMipmap="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture2DMipmap); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTexture2D") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture2D="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture2D); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTexture1DLinear") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture1DLinear="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture1DLinear); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTexture1DMipmap") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture1DMipmap="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture1DMipmap); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxTexture1D") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture1D="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture1D); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::computeMode") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "computeMode="); + roctracer::hip_support::detail::operator<<(out, v.computeMode); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::canMapHostMemory") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "canMapHostMemory="); + roctracer::hip_support::detail::operator<<(out, v.canMapHostMemory); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::integrated") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "integrated="); + roctracer::hip_support::detail::operator<<(out, v.integrated); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::kernelExecTimeoutEnabled") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "kernelExecTimeoutEnabled="); + roctracer::hip_support::detail::operator<<(out, + v.kernelExecTimeoutEnabled); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::multiProcessorCount") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "multiProcessorCount="); + roctracer::hip_support::detail::operator<<(out, v.multiProcessorCount); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::deviceOverlap") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "deviceOverlap="); + roctracer::hip_support::detail::operator<<(out, v.deviceOverlap); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::texturePitchAlignment") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "texturePitchAlignment="); + roctracer::hip_support::detail::operator<<(out, v.texturePitchAlignment); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::textureAlignment") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "textureAlignment="); + roctracer::hip_support::detail::operator<<(out, v.textureAlignment); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::minor").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "minor="); + roctracer::hip_support::detail::operator<<(out, v.minor); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::major").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "major="); + roctracer::hip_support::detail::operator<<(out, v.major); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::totalConstMem") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "totalConstMem="); + roctracer::hip_support::detail::operator<<(out, v.totalConstMem); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::clockRate") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "clockRate="); + roctracer::hip_support::detail::operator<<(out, v.clockRate); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxGridSize") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxGridSize="); + roctracer::hip_support::detail::operator<<(out, v.maxGridSize); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxThreadsDim") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxThreadsDim="); + roctracer::hip_support::detail::operator<<(out, v.maxThreadsDim); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::maxThreadsPerBlock") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxThreadsPerBlock="); + roctracer::hip_support::detail::operator<<(out, v.maxThreadsPerBlock); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::memPitch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "memPitch="); + roctracer::hip_support::detail::operator<<(out, v.memPitch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::warpSize").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "warpSize="); + roctracer::hip_support::detail::operator<<(out, v.warpSize); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::regsPerBlock") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "regsPerBlock="); + roctracer::hip_support::detail::operator<<(out, v.regsPerBlock); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::sharedMemPerBlock") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "sharedMemPerBlock="); + roctracer::hip_support::detail::operator<<(out, v.sharedMemPerBlock); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::totalGlobalMem") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "totalGlobalMem="); + roctracer::hip_support::detail::operator<<(out, v.totalGlobalMem); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::luidDeviceNodeMask") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "luidDeviceNodeMask="); + roctracer::hip_support::detail::operator<<(out, v.luidDeviceNodeMask); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::luid").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "luid="); + roctracer::hip_support::detail::operator<<(out, v.luid); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::uuid").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "uuid="); + roctracer::hip_support::detail::operator<<(out, v.uuid); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0600::name").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "name="); + roctracer::hip_support::detail::operator<<(out, v.name); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipPointerAttribute_t &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipPointerAttribute_t::allocationFlags") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "allocationFlags="); + roctracer::hip_support::detail::operator<<(out, v.allocationFlags); + std::operator<<(out, ", "); + } + if (std::string("hipPointerAttribute_t::isManaged") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "isManaged="); + roctracer::hip_support::detail::operator<<(out, v.isManaged); + std::operator<<(out, ", "); + } + if (std::string("hipPointerAttribute_t::device").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "device="); + roctracer::hip_support::detail::operator<<(out, v.device); + std::operator<<(out, ", "); + } + if (std::string("hipPointerAttribute_t::type").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "type="); + roctracer::hip_support::detail::operator<<(out, v.type); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipChannelFormatDesc &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipChannelFormatDesc::f").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "f="); + roctracer::hip_support::detail::operator<<(out, v.f); + std::operator<<(out, ", "); + } + if (std::string("hipChannelFormatDesc::w").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("hipChannelFormatDesc::z").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("hipChannelFormatDesc::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("hipChannelFormatDesc::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const HIP_ARRAY_DESCRIPTOR &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("HIP_ARRAY_DESCRIPTOR::NumChannels") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "NumChannels="); + roctracer::hip_support::detail::operator<<(out, v.NumChannels); + std::operator<<(out, ", "); + } + if (std::string("HIP_ARRAY_DESCRIPTOR::Format").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "Format="); + roctracer::hip_support::detail::operator<<(out, v.Format); + std::operator<<(out, ", "); + } + if (std::string("HIP_ARRAY_DESCRIPTOR::Height").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "Height="); + roctracer::hip_support::detail::operator<<(out, v.Height); + std::operator<<(out, ", "); + } + if (std::string("HIP_ARRAY_DESCRIPTOR::Width").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "Width="); + roctracer::hip_support::detail::operator<<(out, v.Width); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const HIP_ARRAY3D_DESCRIPTOR &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("HIP_ARRAY3D_DESCRIPTOR::Flags").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "Flags="); + roctracer::hip_support::detail::operator<<(out, v.Flags); + std::operator<<(out, ", "); + } + if (std::string("HIP_ARRAY3D_DESCRIPTOR::NumChannels") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "NumChannels="); + roctracer::hip_support::detail::operator<<(out, v.NumChannels); + std::operator<<(out, ", "); + } + if (std::string("HIP_ARRAY3D_DESCRIPTOR::Format").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "Format="); + roctracer::hip_support::detail::operator<<(out, v.Format); + std::operator<<(out, ", "); + } + if (std::string("HIP_ARRAY3D_DESCRIPTOR::Depth").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "Depth="); + roctracer::hip_support::detail::operator<<(out, v.Depth); + std::operator<<(out, ", "); + } + if (std::string("HIP_ARRAY3D_DESCRIPTOR::Height").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "Height="); + roctracer::hip_support::detail::operator<<(out, v.Height); + std::operator<<(out, ", "); + } + if (std::string("HIP_ARRAY3D_DESCRIPTOR::Width").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "Width="); + roctracer::hip_support::detail::operator<<(out, v.Width); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hip_Memcpy2D &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hip_Memcpy2D::Height").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "Height="); + roctracer::hip_support::detail::operator<<(out, v.Height); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::WidthInBytes").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "WidthInBytes="); + roctracer::hip_support::detail::operator<<(out, v.WidthInBytes); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::dstPitch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstPitch="); + roctracer::hip_support::detail::operator<<(out, v.dstPitch); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::dstArray").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstArray="); + roctracer::hip_support::detail::operator<<(out, v.dstArray); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::dstDevice").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstDevice="); + roctracer::hip_support::detail::operator<<(out, v.dstDevice); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::dstMemoryType").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstMemoryType="); + roctracer::hip_support::detail::operator<<(out, v.dstMemoryType); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::dstY").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstY="); + roctracer::hip_support::detail::operator<<(out, v.dstY); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::dstXInBytes").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstXInBytes="); + roctracer::hip_support::detail::operator<<(out, v.dstXInBytes); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::srcPitch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcPitch="); + roctracer::hip_support::detail::operator<<(out, v.srcPitch); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::srcArray").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcArray="); + roctracer::hip_support::detail::operator<<(out, v.srcArray); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::srcDevice").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcDevice="); + roctracer::hip_support::detail::operator<<(out, v.srcDevice); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::srcMemoryType").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcMemoryType="); + roctracer::hip_support::detail::operator<<(out, v.srcMemoryType); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::srcY").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcY="); + roctracer::hip_support::detail::operator<<(out, v.srcY); + std::operator<<(out, ", "); + } + if (std::string("hip_Memcpy2D::srcXInBytes").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcXInBytes="); + roctracer::hip_support::detail::operator<<(out, v.srcXInBytes); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipMipmappedArray &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipMipmappedArray::num_channels") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "num_channels="); + roctracer::hip_support::detail::operator<<(out, v.num_channels); + std::operator<<(out, ", "); + } + if (std::string("hipMipmappedArray::format").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "format="); + roctracer::hip_support::detail::operator<<(out, v.format); + std::operator<<(out, ", "); + } + if (std::string("hipMipmappedArray::flags").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("hipMipmappedArray::max_mipmap_level") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "max_mipmap_level="); + roctracer::hip_support::detail::operator<<(out, v.max_mipmap_level); + std::operator<<(out, ", "); + } + if (std::string("hipMipmappedArray::min_mipmap_level") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "min_mipmap_level="); + roctracer::hip_support::detail::operator<<(out, v.min_mipmap_level); + std::operator<<(out, ", "); + } + if (std::string("hipMipmappedArray::depth").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "depth="); + roctracer::hip_support::detail::operator<<(out, v.depth); + std::operator<<(out, ", "); + } + if (std::string("hipMipmappedArray::height").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "height="); + roctracer::hip_support::detail::operator<<(out, v.height); + std::operator<<(out, ", "); + } + if (std::string("hipMipmappedArray::width").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "width="); + roctracer::hip_support::detail::operator<<(out, v.width); + std::operator<<(out, ", "); + } + if (std::string("hipMipmappedArray::type").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "type="); + roctracer::hip_support::detail::operator<<(out, v.type); + std::operator<<(out, ", "); + } + if (std::string("hipMipmappedArray::desc").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "desc="); + roctracer::hip_support::detail::operator<<(out, v.desc); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const HIP_TEXTURE_DESC &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("HIP_TEXTURE_DESC::reserved").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("HIP_TEXTURE_DESC::borderColor").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "borderColor="); + roctracer::hip_support::detail::operator<<(out, v.borderColor); + std::operator<<(out, ", "); + } + if (std::string("HIP_TEXTURE_DESC::maxMipmapLevelClamp") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxMipmapLevelClamp="); + roctracer::hip_support::detail::operator<<(out, v.maxMipmapLevelClamp); + std::operator<<(out, ", "); + } + if (std::string("HIP_TEXTURE_DESC::minMipmapLevelClamp") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "minMipmapLevelClamp="); + roctracer::hip_support::detail::operator<<(out, v.minMipmapLevelClamp); + std::operator<<(out, ", "); + } + if (std::string("HIP_TEXTURE_DESC::mipmapLevelBias") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "mipmapLevelBias="); + roctracer::hip_support::detail::operator<<(out, v.mipmapLevelBias); + std::operator<<(out, ", "); + } + if (std::string("HIP_TEXTURE_DESC::mipmapFilterMode") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "mipmapFilterMode="); + roctracer::hip_support::detail::operator<<(out, v.mipmapFilterMode); + std::operator<<(out, ", "); + } + if (std::string("HIP_TEXTURE_DESC::maxAnisotropy") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxAnisotropy="); + roctracer::hip_support::detail::operator<<(out, v.maxAnisotropy); + std::operator<<(out, ", "); + } + if (std::string("HIP_TEXTURE_DESC::flags").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("HIP_TEXTURE_DESC::filterMode").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "filterMode="); + roctracer::hip_support::detail::operator<<(out, v.filterMode); + std::operator<<(out, ", "); + } + if (std::string("HIP_TEXTURE_DESC::addressMode").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "addressMode="); + roctracer::hip_support::detail::operator<<(out, v.addressMode); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipResourceDesc &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipResourceDesc::resType").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "resType="); + roctracer::hip_support::detail::operator<<(out, v.resType); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const HIP_RESOURCE_DESC &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("HIP_RESOURCE_DESC::flags").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("HIP_RESOURCE_DESC::resType").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "resType="); + roctracer::hip_support::detail::operator<<(out, v.resType); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipResourceViewDesc &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipResourceViewDesc::lastLayer").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "lastLayer="); + roctracer::hip_support::detail::operator<<(out, v.lastLayer); + std::operator<<(out, ", "); + } + if (std::string("hipResourceViewDesc::firstLayer") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "firstLayer="); + roctracer::hip_support::detail::operator<<(out, v.firstLayer); + std::operator<<(out, ", "); + } + if (std::string("hipResourceViewDesc::lastMipmapLevel") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "lastMipmapLevel="); + roctracer::hip_support::detail::operator<<(out, v.lastMipmapLevel); + std::operator<<(out, ", "); + } + if (std::string("hipResourceViewDesc::firstMipmapLevel") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "firstMipmapLevel="); + roctracer::hip_support::detail::operator<<(out, v.firstMipmapLevel); + std::operator<<(out, ", "); + } + if (std::string("hipResourceViewDesc::depth").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "depth="); + roctracer::hip_support::detail::operator<<(out, v.depth); + std::operator<<(out, ", "); + } + if (std::string("hipResourceViewDesc::height").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "height="); + roctracer::hip_support::detail::operator<<(out, v.height); + std::operator<<(out, ", "); + } + if (std::string("hipResourceViewDesc::width").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "width="); + roctracer::hip_support::detail::operator<<(out, v.width); + std::operator<<(out, ", "); + } + if (std::string("hipResourceViewDesc::format").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "format="); + roctracer::hip_support::detail::operator<<(out, v.format); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const HIP_RESOURCE_VIEW_DESC &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("HIP_RESOURCE_VIEW_DESC::reserved") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("HIP_RESOURCE_VIEW_DESC::lastLayer") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "lastLayer="); + roctracer::hip_support::detail::operator<<(out, v.lastLayer); + std::operator<<(out, ", "); + } + if (std::string("HIP_RESOURCE_VIEW_DESC::firstLayer") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "firstLayer="); + roctracer::hip_support::detail::operator<<(out, v.firstLayer); + std::operator<<(out, ", "); + } + if (std::string("HIP_RESOURCE_VIEW_DESC::lastMipmapLevel") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "lastMipmapLevel="); + roctracer::hip_support::detail::operator<<(out, v.lastMipmapLevel); + std::operator<<(out, ", "); + } + if (std::string("HIP_RESOURCE_VIEW_DESC::firstMipmapLevel") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "firstMipmapLevel="); + roctracer::hip_support::detail::operator<<(out, v.firstMipmapLevel); + std::operator<<(out, ", "); + } + if (std::string("HIP_RESOURCE_VIEW_DESC::depth").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "depth="); + roctracer::hip_support::detail::operator<<(out, v.depth); + std::operator<<(out, ", "); + } + if (std::string("HIP_RESOURCE_VIEW_DESC::height").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "height="); + roctracer::hip_support::detail::operator<<(out, v.height); + std::operator<<(out, ", "); + } + if (std::string("HIP_RESOURCE_VIEW_DESC::width").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "width="); + roctracer::hip_support::detail::operator<<(out, v.width); + std::operator<<(out, ", "); + } + if (std::string("HIP_RESOURCE_VIEW_DESC::format").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "format="); + roctracer::hip_support::detail::operator<<(out, v.format); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipPitchedPtr &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipPitchedPtr::ysize").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "ysize="); + roctracer::hip_support::detail::operator<<(out, v.ysize); + std::operator<<(out, ", "); + } + if (std::string("hipPitchedPtr::xsize").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "xsize="); + roctracer::hip_support::detail::operator<<(out, v.xsize); + std::operator<<(out, ", "); + } + if (std::string("hipPitchedPtr::pitch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "pitch="); + roctracer::hip_support::detail::operator<<(out, v.pitch); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const hipExtent &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipExtent::depth").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "depth="); + roctracer::hip_support::detail::operator<<(out, v.depth); + std::operator<<(out, ", "); + } + if (std::string("hipExtent::height").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "height="); + roctracer::hip_support::detail::operator<<(out, v.height); + std::operator<<(out, ", "); + } + if (std::string("hipExtent::width").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "width="); + roctracer::hip_support::detail::operator<<(out, v.width); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const hipPos &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipPos::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("hipPos::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("hipPos::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipMemcpy3DParms &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipMemcpy3DParms::kind").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "kind="); + roctracer::hip_support::detail::operator<<(out, v.kind); + std::operator<<(out, ", "); + } + if (std::string("hipMemcpy3DParms::extent").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "extent="); + roctracer::hip_support::detail::operator<<(out, v.extent); + std::operator<<(out, ", "); + } + if (std::string("hipMemcpy3DParms::dstPtr").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstPtr="); + roctracer::hip_support::detail::operator<<(out, v.dstPtr); + std::operator<<(out, ", "); + } + if (std::string("hipMemcpy3DParms::dstPos").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstPos="); + roctracer::hip_support::detail::operator<<(out, v.dstPos); + std::operator<<(out, ", "); + } + if (std::string("hipMemcpy3DParms::dstArray").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstArray="); + roctracer::hip_support::detail::operator<<(out, v.dstArray); + std::operator<<(out, ", "); + } + if (std::string("hipMemcpy3DParms::srcPtr").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcPtr="); + roctracer::hip_support::detail::operator<<(out, v.srcPtr); + std::operator<<(out, ", "); + } + if (std::string("hipMemcpy3DParms::srcPos").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcPos="); + roctracer::hip_support::detail::operator<<(out, v.srcPos); + std::operator<<(out, ", "); + } + if (std::string("hipMemcpy3DParms::srcArray").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcArray="); + roctracer::hip_support::detail::operator<<(out, v.srcArray); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const HIP_MEMCPY3D &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("HIP_MEMCPY3D::Depth").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "Depth="); + roctracer::hip_support::detail::operator<<(out, v.Depth); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::Height").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "Height="); + roctracer::hip_support::detail::operator<<(out, v.Height); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::WidthInBytes").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "WidthInBytes="); + roctracer::hip_support::detail::operator<<(out, v.WidthInBytes); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::dstHeight").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstHeight="); + roctracer::hip_support::detail::operator<<(out, v.dstHeight); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::dstPitch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstPitch="); + roctracer::hip_support::detail::operator<<(out, v.dstPitch); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::dstArray").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstArray="); + roctracer::hip_support::detail::operator<<(out, v.dstArray); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::dstDevice").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstDevice="); + roctracer::hip_support::detail::operator<<(out, v.dstDevice); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::dstMemoryType").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstMemoryType="); + roctracer::hip_support::detail::operator<<(out, v.dstMemoryType); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::dstLOD").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstLOD="); + roctracer::hip_support::detail::operator<<(out, v.dstLOD); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::dstZ").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstZ="); + roctracer::hip_support::detail::operator<<(out, v.dstZ); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::dstY").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstY="); + roctracer::hip_support::detail::operator<<(out, v.dstY); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::dstXInBytes").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dstXInBytes="); + roctracer::hip_support::detail::operator<<(out, v.dstXInBytes); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::srcHeight").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcHeight="); + roctracer::hip_support::detail::operator<<(out, v.srcHeight); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::srcPitch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcPitch="); + roctracer::hip_support::detail::operator<<(out, v.srcPitch); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::srcArray").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcArray="); + roctracer::hip_support::detail::operator<<(out, v.srcArray); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::srcDevice").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcDevice="); + roctracer::hip_support::detail::operator<<(out, v.srcDevice); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::srcMemoryType").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcMemoryType="); + roctracer::hip_support::detail::operator<<(out, v.srcMemoryType); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::srcLOD").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcLOD="); + roctracer::hip_support::detail::operator<<(out, v.srcLOD); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::srcZ").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcZ="); + roctracer::hip_support::detail::operator<<(out, v.srcZ); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::srcY").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcY="); + roctracer::hip_support::detail::operator<<(out, v.srcY); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMCPY3D::srcXInBytes").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "srcXInBytes="); + roctracer::hip_support::detail::operator<<(out, v.srcXInBytes); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const uchar1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("uchar1::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const uchar2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("uchar2::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("uchar2::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const uchar3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("uchar3::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("uchar3::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("uchar3::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const uchar4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("uchar4::w").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("uchar4::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("uchar4::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("uchar4::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const char1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("char1::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const char2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("char2::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("char2::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const char3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("char3::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("char3::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("char3::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const char4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("char4::w").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("char4::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("char4::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("char4::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ushort1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ushort1::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ushort2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ushort2::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("ushort2::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ushort3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ushort3::z").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("ushort3::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("ushort3::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ushort4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ushort4::w").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("ushort4::z").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("ushort4::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("ushort4::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const short1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("short1::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const short2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("short2::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("short2::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const short3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("short3::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("short3::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("short3::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const short4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("short4::w").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("short4::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("short4::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("short4::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const uint1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("uint1::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const uint2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("uint2::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("uint2::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const uint3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("uint3::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("uint3::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("uint3::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const uint4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("uint4::w").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("uint4::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("uint4::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("uint4::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const int1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("int1::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const int2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("int2::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("int2::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const int3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("int3::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("int3::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("int3::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const int4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("int4::w").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("int4::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("int4::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("int4::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ulong1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ulong1::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ulong2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ulong2::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("ulong2::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ulong3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ulong3::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("ulong3::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("ulong3::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ulong4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ulong4::w").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("ulong4::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("ulong4::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("ulong4::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const long1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("long1::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const long2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("long2::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("long2::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const long3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("long3::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("long3::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("long3::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const long4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("long4::w").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("long4::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("long4::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("long4::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ulonglong1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ulonglong1::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ulonglong2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ulonglong2::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("ulonglong2::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ulonglong3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ulonglong3::z").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("ulonglong3::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("ulonglong3::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const ulonglong4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("ulonglong4::w").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("ulonglong4::z").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("ulonglong4::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("ulonglong4::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const longlong1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("longlong1::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const longlong2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("longlong2::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("longlong2::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const longlong3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("longlong3::z").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("longlong3::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("longlong3::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const longlong4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("longlong4::w").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("longlong4::z").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("longlong4::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("longlong4::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const float1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("float1::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const float2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("float2::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("float2::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const float3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("float3::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("float3::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("float3::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const float4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("float4::w").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("float4::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("float4::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("float4::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const double1 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("double1::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const double2 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("double2::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("double2::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const double3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("double3::z").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("double3::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("double3::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const double4 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("double4::w").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "w="); + roctracer::hip_support::detail::operator<<(out, v.w); + std::operator<<(out, ", "); + } + if (std::string("double4::z").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("double4::y").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("double4::x").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const textureReference &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("textureReference::format").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "format="); + roctracer::hip_support::detail::operator<<(out, v.format); + std::operator<<(out, ", "); + } + if (std::string("textureReference::numChannels").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "numChannels="); + roctracer::hip_support::detail::operator<<(out, v.numChannels); + std::operator<<(out, ", "); + } + if (std::string("textureReference::textureObject") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "textureObject="); + roctracer::hip_support::detail::operator<<(out, v.textureObject); + std::operator<<(out, ", "); + } + if (std::string("textureReference::maxMipmapLevelClamp") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxMipmapLevelClamp="); + roctracer::hip_support::detail::operator<<(out, v.maxMipmapLevelClamp); + std::operator<<(out, ", "); + } + if (std::string("textureReference::minMipmapLevelClamp") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "minMipmapLevelClamp="); + roctracer::hip_support::detail::operator<<(out, v.minMipmapLevelClamp); + std::operator<<(out, ", "); + } + if (std::string("textureReference::mipmapLevelBias") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "mipmapLevelBias="); + roctracer::hip_support::detail::operator<<(out, v.mipmapLevelBias); + std::operator<<(out, ", "); + } + if (std::string("textureReference::mipmapFilterMode") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "mipmapFilterMode="); + roctracer::hip_support::detail::operator<<(out, v.mipmapFilterMode); + std::operator<<(out, ", "); + } + if (std::string("textureReference::maxAnisotropy") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxAnisotropy="); + roctracer::hip_support::detail::operator<<(out, v.maxAnisotropy); + std::operator<<(out, ", "); + } + if (std::string("textureReference::sRGB").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "sRGB="); + roctracer::hip_support::detail::operator<<(out, v.sRGB); + std::operator<<(out, ", "); + } + if (std::string("textureReference::channelDesc").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "channelDesc="); + roctracer::hip_support::detail::operator<<(out, v.channelDesc); + std::operator<<(out, ", "); + } + if (std::string("textureReference::filterMode").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "filterMode="); + roctracer::hip_support::detail::operator<<(out, v.filterMode); + std::operator<<(out, ", "); + } + if (std::string("textureReference::readMode").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "readMode="); + roctracer::hip_support::detail::operator<<(out, v.readMode); + std::operator<<(out, ", "); + } + if (std::string("textureReference::normalized").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "normalized="); + roctracer::hip_support::detail::operator<<(out, v.normalized); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipTextureDesc &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipTextureDesc::maxMipmapLevelClamp") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxMipmapLevelClamp="); + roctracer::hip_support::detail::operator<<(out, v.maxMipmapLevelClamp); + std::operator<<(out, ", "); + } + if (std::string("hipTextureDesc::minMipmapLevelClamp") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "minMipmapLevelClamp="); + roctracer::hip_support::detail::operator<<(out, v.minMipmapLevelClamp); + std::operator<<(out, ", "); + } + if (std::string("hipTextureDesc::mipmapLevelBias") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "mipmapLevelBias="); + roctracer::hip_support::detail::operator<<(out, v.mipmapLevelBias); + std::operator<<(out, ", "); + } + if (std::string("hipTextureDesc::mipmapFilterMode") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "mipmapFilterMode="); + roctracer::hip_support::detail::operator<<(out, v.mipmapFilterMode); + std::operator<<(out, ", "); + } + if (std::string("hipTextureDesc::maxAnisotropy").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "maxAnisotropy="); + roctracer::hip_support::detail::operator<<(out, v.maxAnisotropy); + std::operator<<(out, ", "); + } + if (std::string("hipTextureDesc::normalizedCoords") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "normalizedCoords="); + roctracer::hip_support::detail::operator<<(out, v.normalizedCoords); + std::operator<<(out, ", "); + } + if (std::string("hipTextureDesc::borderColor").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "borderColor="); + roctracer::hip_support::detail::operator<<(out, v.borderColor); + std::operator<<(out, ", "); + } + if (std::string("hipTextureDesc::sRGB").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "sRGB="); + roctracer::hip_support::detail::operator<<(out, v.sRGB); + std::operator<<(out, ", "); + } + if (std::string("hipTextureDesc::readMode").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "readMode="); + roctracer::hip_support::detail::operator<<(out, v.readMode); + std::operator<<(out, ", "); + } + if (std::string("hipTextureDesc::filterMode").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "filterMode="); + roctracer::hip_support::detail::operator<<(out, v.filterMode); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const surfaceReference &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("surfaceReference::surfaceObject") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "surfaceObject="); + roctracer::hip_support::detail::operator<<(out, v.surfaceObject); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipIpcMemHandle_t &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipIpcMemHandle_t::reserved").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipIpcEventHandle_t &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipIpcEventHandle_t::reserved").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipFuncAttributes &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipFuncAttributes::sharedSizeBytes") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "sharedSizeBytes="); + roctracer::hip_support::detail::operator<<(out, v.sharedSizeBytes); + std::operator<<(out, ", "); + } + if (std::string("hipFuncAttributes::ptxVersion").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "ptxVersion="); + roctracer::hip_support::detail::operator<<(out, v.ptxVersion); + std::operator<<(out, ", "); + } + if (std::string("hipFuncAttributes::preferredShmemCarveout") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "preferredShmemCarveout="); + roctracer::hip_support::detail::operator<<(out, v.preferredShmemCarveout); + std::operator<<(out, ", "); + } + if (std::string("hipFuncAttributes::numRegs").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "numRegs="); + roctracer::hip_support::detail::operator<<(out, v.numRegs); + std::operator<<(out, ", "); + } + if (std::string("hipFuncAttributes::maxThreadsPerBlock") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxThreadsPerBlock="); + roctracer::hip_support::detail::operator<<(out, v.maxThreadsPerBlock); + std::operator<<(out, ", "); + } + if (std::string("hipFuncAttributes::maxDynamicSharedSizeBytes") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxDynamicSharedSizeBytes="); + roctracer::hip_support::detail::operator<<(out, + v.maxDynamicSharedSizeBytes); + std::operator<<(out, ", "); + } + if (std::string("hipFuncAttributes::localSizeBytes") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "localSizeBytes="); + roctracer::hip_support::detail::operator<<(out, v.localSizeBytes); + std::operator<<(out, ", "); + } + if (std::string("hipFuncAttributes::constSizeBytes") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "constSizeBytes="); + roctracer::hip_support::detail::operator<<(out, v.constSizeBytes); + std::operator<<(out, ", "); + } + if (std::string("hipFuncAttributes::cacheModeCA").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "cacheModeCA="); + roctracer::hip_support::detail::operator<<(out, v.cacheModeCA); + std::operator<<(out, ", "); + } + if (std::string("hipFuncAttributes::binaryVersion") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "binaryVersion="); + roctracer::hip_support::detail::operator<<(out, v.binaryVersion); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipMemLocation &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipMemLocation::id").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "id="); + roctracer::hip_support::detail::operator<<(out, v.id); + std::operator<<(out, ", "); + } + if (std::string("hipMemLocation::type").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "type="); + roctracer::hip_support::detail::operator<<(out, v.type); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipMemAccessDesc &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipMemAccessDesc::flags").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("hipMemAccessDesc::location").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "location="); + roctracer::hip_support::detail::operator<<(out, v.location); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipMemPoolProps &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipMemPoolProps::reserved").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("hipMemPoolProps::maxSize").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "maxSize="); + roctracer::hip_support::detail::operator<<(out, v.maxSize); + std::operator<<(out, ", "); + } + if (std::string("hipMemPoolProps::location").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "location="); + roctracer::hip_support::detail::operator<<(out, v.location); + std::operator<<(out, ", "); + } + if (std::string("hipMemPoolProps::handleTypes").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "handleTypes="); + roctracer::hip_support::detail::operator<<(out, v.handleTypes); + std::operator<<(out, ", "); + } + if (std::string("hipMemPoolProps::allocType").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "allocType="); + roctracer::hip_support::detail::operator<<(out, v.allocType); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipMemPoolPtrExportData &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipMemPoolPtrExportData::reserved") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const dim3 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("dim3::z").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "z="); + roctracer::hip_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("dim3::y").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "y="); + roctracer::hip_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("dim3::x").find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "x="); + roctracer::hip_support::detail::operator<<(out, v.x); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipLaunchParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipLaunchParams::stream").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "stream="); + roctracer::hip_support::detail::operator<<(out, v.stream); + std::operator<<(out, ", "); + } + if (std::string("hipLaunchParams::sharedMem").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "sharedMem="); + roctracer::hip_support::detail::operator<<(out, v.sharedMem); + std::operator<<(out, ", "); + } + if (std::string("hipLaunchParams::blockDim").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "blockDim="); + roctracer::hip_support::detail::operator<<(out, v.blockDim); + std::operator<<(out, ", "); + } + if (std::string("hipLaunchParams::gridDim").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "gridDim="); + roctracer::hip_support::detail::operator<<(out, v.gridDim); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipFunctionLaunchParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipFunctionLaunchParams::hStream") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hStream="); + roctracer::hip_support::detail::operator<<(out, v.hStream); + std::operator<<(out, ", "); + } + if (std::string("hipFunctionLaunchParams::sharedMemBytes") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "sharedMemBytes="); + roctracer::hip_support::detail::operator<<(out, v.sharedMemBytes); + std::operator<<(out, ", "); + } + if (std::string("hipFunctionLaunchParams::blockDimZ") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "blockDimZ="); + roctracer::hip_support::detail::operator<<(out, v.blockDimZ); + std::operator<<(out, ", "); + } + if (std::string("hipFunctionLaunchParams::blockDimY") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "blockDimY="); + roctracer::hip_support::detail::operator<<(out, v.blockDimY); + std::operator<<(out, ", "); + } + if (std::string("hipFunctionLaunchParams::blockDimX") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "blockDimX="); + roctracer::hip_support::detail::operator<<(out, v.blockDimX); + std::operator<<(out, ", "); + } + if (std::string("hipFunctionLaunchParams::gridDimZ") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "gridDimZ="); + roctracer::hip_support::detail::operator<<(out, v.gridDimZ); + std::operator<<(out, ", "); + } + if (std::string("hipFunctionLaunchParams::gridDimY") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "gridDimY="); + roctracer::hip_support::detail::operator<<(out, v.gridDimY); + std::operator<<(out, ", "); + } + if (std::string("hipFunctionLaunchParams::gridDimX") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "gridDimX="); + roctracer::hip_support::detail::operator<<(out, v.gridDimX); + std::operator<<(out, ", "); + } + if (std::string("hipFunctionLaunchParams::function") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "function="); + roctracer::hip_support::detail::operator<<(out, v.function); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipExternalMemoryHandleDesc &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipExternalMemoryHandleDesc::reserved") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("hipExternalMemoryHandleDesc::flags") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("hipExternalMemoryHandleDesc::size") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "size="); + roctracer::hip_support::detail::operator<<(out, v.size); + std::operator<<(out, ", "); + } + if (std::string("hipExternalMemoryHandleDesc_st::union ::handle.fd") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "handle.fd="); + roctracer::hip_support::detail::operator<<(out, v.handle.fd); + std::operator<<(out, ", "); + } + if (std::string("hipExternalMemoryHandleDesc::type") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "type="); + roctracer::hip_support::detail::operator<<(out, v.type); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipExternalMemoryBufferDesc &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipExternalMemoryBufferDesc::reserved") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("hipExternalMemoryBufferDesc::flags") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("hipExternalMemoryBufferDesc::size") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "size="); + roctracer::hip_support::detail::operator<<(out, v.size); + std::operator<<(out, ", "); + } + if (std::string("hipExternalMemoryBufferDesc::offset") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "offset="); + roctracer::hip_support::detail::operator<<(out, v.offset); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalMemoryMipmappedArrayDesc &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipExternalMemoryMipmappedArrayDesc::numLevels") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "numLevels="); + roctracer::hip_support::detail::operator<<(out, v.numLevels); + std::operator<<(out, ", "); + } + if (std::string("hipExternalMemoryMipmappedArrayDesc::flags") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("hipExternalMemoryMipmappedArrayDesc::extent") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "extent="); + roctracer::hip_support::detail::operator<<(out, v.extent); + std::operator<<(out, ", "); + } + if (std::string("hipExternalMemoryMipmappedArrayDesc::formatDesc") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "formatDesc="); + roctracer::hip_support::detail::operator<<(out, v.formatDesc); + std::operator<<(out, ", "); + } + if (std::string("hipExternalMemoryMipmappedArrayDesc::offset") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "offset="); + roctracer::hip_support::detail::operator<<(out, v.offset); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalSemaphoreHandleDesc &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipExternalSemaphoreHandleDesc::reserved") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("hipExternalSemaphoreHandleDesc::flags") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("hipExternalSemaphoreHandleDesc_st::union ::handle.fd") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "handle.fd="); + roctracer::hip_support::detail::operator<<(out, v.handle.fd); + std::operator<<(out, ", "); + } + if (std::string("hipExternalSemaphoreHandleDesc::type") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "type="); + roctracer::hip_support::detail::operator<<(out, v.type); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalSemaphoreSignalParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipExternalSemaphoreSignalParams::reserved") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("hipExternalSemaphoreSignalParams::flags") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalSemaphoreWaitParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipExternalSemaphoreWaitParams::reserved") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("hipExternalSemaphoreWaitParams::flags") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipHostNodeParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipHostNodeParams::fn").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "fn="); + roctracer::hip_support::detail::operator<<(out, v.fn); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipKernelNodeParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipKernelNodeParams::sharedMemBytes") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "sharedMemBytes="); + roctracer::hip_support::detail::operator<<(out, v.sharedMemBytes); + std::operator<<(out, ", "); + } + if (std::string("hipKernelNodeParams::gridDim").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "gridDim="); + roctracer::hip_support::detail::operator<<(out, v.gridDim); + std::operator<<(out, ", "); + } + if (std::string("hipKernelNodeParams::blockDim").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "blockDim="); + roctracer::hip_support::detail::operator<<(out, v.blockDim); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipMemsetParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipMemsetParams::width").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "width="); + roctracer::hip_support::detail::operator<<(out, v.width); + std::operator<<(out, ", "); + } + if (std::string("hipMemsetParams::value").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "value="); + roctracer::hip_support::detail::operator<<(out, v.value); + std::operator<<(out, ", "); + } + if (std::string("hipMemsetParams::pitch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "pitch="); + roctracer::hip_support::detail::operator<<(out, v.pitch); + std::operator<<(out, ", "); + } + if (std::string("hipMemsetParams::height").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "height="); + roctracer::hip_support::detail::operator<<(out, v.height); + std::operator<<(out, ", "); + } + if (std::string("hipMemsetParams::elementSize").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "elementSize="); + roctracer::hip_support::detail::operator<<(out, v.elementSize); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipMemAllocNodeParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipMemAllocNodeParams::bytesize") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "bytesize="); + roctracer::hip_support::detail::operator<<(out, v.bytesize); + std::operator<<(out, ", "); + } + if (std::string("hipMemAllocNodeParams::accessDescCount") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "accessDescCount="); + roctracer::hip_support::detail::operator<<(out, v.accessDescCount); + std::operator<<(out, ", "); + } + if (std::string("hipMemAllocNodeParams::accessDescs") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "accessDescs="); + roctracer::hip_support::detail::operator<<(out, v.accessDescs); + std::operator<<(out, ", "); + } + if (std::string("hipMemAllocNodeParams::poolProps") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "poolProps="); + roctracer::hip_support::detail::operator<<(out, v.poolProps); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipAccessPolicyWindow &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipAccessPolicyWindow::num_bytes") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "num_bytes="); + roctracer::hip_support::detail::operator<<(out, v.num_bytes); + std::operator<<(out, ", "); + } + if (std::string("hipAccessPolicyWindow::missProp") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "missProp="); + roctracer::hip_support::detail::operator<<(out, v.missProp); + std::operator<<(out, ", "); + } + if (std::string("hipAccessPolicyWindow::hitRatio") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hitRatio="); + roctracer::hip_support::detail::operator<<(out, v.hitRatio); + std::operator<<(out, ", "); + } + if (std::string("hipAccessPolicyWindow::hitProp").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "hitProp="); + roctracer::hip_support::detail::operator<<(out, v.hitProp); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipLaunchAttributeValue &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipLaunchAttributeValue::priority") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "priority="); + roctracer::hip_support::detail::operator<<(out, v.priority); + std::operator<<(out, ", "); + } + if (std::string("hipLaunchAttributeValue::cooperative") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperative="); + roctracer::hip_support::detail::operator<<(out, v.cooperative); + std::operator<<(out, ", "); + } + if (std::string("hipLaunchAttributeValue::accessPolicyWindow") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "accessPolicyWindow="); + roctracer::hip_support::detail::operator<<(out, v.accessPolicyWindow); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const HIP_MEMSET_NODE_PARAMS &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("HIP_MEMSET_NODE_PARAMS::height").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "height="); + roctracer::hip_support::detail::operator<<(out, v.height); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMSET_NODE_PARAMS::width").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "width="); + roctracer::hip_support::detail::operator<<(out, v.width); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMSET_NODE_PARAMS::elementSize") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "elementSize="); + roctracer::hip_support::detail::operator<<(out, v.elementSize); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMSET_NODE_PARAMS::value").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "value="); + roctracer::hip_support::detail::operator<<(out, v.value); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMSET_NODE_PARAMS::pitch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "pitch="); + roctracer::hip_support::detail::operator<<(out, v.pitch); + std::operator<<(out, ", "); + } + if (std::string("HIP_MEMSET_NODE_PARAMS::dst").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "dst="); + roctracer::hip_support::detail::operator<<(out, v.dst); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipGraphInstantiateParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipGraphInstantiateParams::uploadStream") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "uploadStream="); + roctracer::hip_support::detail::operator<<(out, v.uploadStream); + std::operator<<(out, ", "); + } + if (std::string("hipGraphInstantiateParams::result_out") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "result_out="); + roctracer::hip_support::detail::operator<<(out, v.result_out); + std::operator<<(out, ", "); + } + if (std::string("hipGraphInstantiateParams::flags") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("hipGraphInstantiateParams::errNode_out") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "errNode_out="); + roctracer::hip_support::detail::operator<<(out, v.errNode_out); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipMemAllocationProp &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipMemAllocationProp::location").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "location="); + roctracer::hip_support::detail::operator<<(out, v.location); + std::operator<<(out, ", "); + } + if (std::string("hipMemAllocationProp::requestedHandleType") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "requestedHandleType="); + roctracer::hip_support::detail::operator<<(out, v.requestedHandleType); + std::operator<<(out, ", "); + } + if (std::string("hipMemAllocationProp::type").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "type="); + roctracer::hip_support::detail::operator<<(out, v.type); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalSemaphoreSignalNodeParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipExternalSemaphoreSignalNodeParams::numExtSems") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "numExtSems="); + roctracer::hip_support::detail::operator<<(out, v.numExtSems); + std::operator<<(out, ", "); + } + if (std::string("hipExternalSemaphoreSignalNodeParams::paramsArray") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "paramsArray="); + roctracer::hip_support::detail::operator<<(out, v.paramsArray); + std::operator<<(out, ", "); + } + if (std::string("hipExternalSemaphoreSignalNodeParams::extSemArray") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "extSemArray="); + roctracer::hip_support::detail::operator<<(out, v.extSemArray); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalSemaphoreWaitNodeParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipExternalSemaphoreWaitNodeParams::numExtSems") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "numExtSems="); + roctracer::hip_support::detail::operator<<(out, v.numExtSems); + std::operator<<(out, ", "); + } + if (std::string("hipExternalSemaphoreWaitNodeParams::paramsArray") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "paramsArray="); + roctracer::hip_support::detail::operator<<(out, v.paramsArray); + std::operator<<(out, ", "); + } + if (std::string("hipExternalSemaphoreWaitNodeParams::extSemArray") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "extSemArray="); + roctracer::hip_support::detail::operator<<(out, v.extSemArray); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipArrayMapInfo &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipArrayMapInfo::reserved").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("hipArrayMapInfo::flags").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("hipArrayMapInfo::deviceBitMask").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "deviceBitMask="); + roctracer::hip_support::detail::operator<<(out, v.deviceBitMask); + std::operator<<(out, ", "); + } + if (std::string("hipArrayMapInfo::offset").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "offset="); + roctracer::hip_support::detail::operator<<(out, v.offset); + std::operator<<(out, ", "); + } + if (std::string("hipArrayMapInfo::union ::memHandle.memHandle") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "memHandle.memHandle="); + roctracer::hip_support::detail::operator<<(out, v.memHandle.memHandle); + std::operator<<(out, ", "); + } + if (std::string("hipArrayMapInfo::memHandleType").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "memHandleType="); + roctracer::hip_support::detail::operator<<(out, v.memHandleType); + std::operator<<(out, ", "); + } + if (std::string("hipArrayMapInfo::memOperationType") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "memOperationType="); + roctracer::hip_support::detail::operator<<(out, v.memOperationType); + std::operator<<(out, ", "); + } + if (std::string("hipArrayMapInfo::subresourceType") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "subresourceType="); + roctracer::hip_support::detail::operator<<(out, v.subresourceType); + std::operator<<(out, ", "); + } + if (std::string("hipArrayMapInfo::resourceType").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "resourceType="); + roctracer::hip_support::detail::operator<<(out, v.resourceType); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipMemcpyNodeParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipMemcpyNodeParams::copyParams") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "copyParams="); + roctracer::hip_support::detail::operator<<(out, v.copyParams); + std::operator<<(out, ", "); + } + if (std::string("hipMemcpyNodeParams::reserved").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("hipMemcpyNodeParams::flags").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hip_support::detail::operator<<(out, v.flags); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipChildGraphNodeParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipChildGraphNodeParams::graph").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "graph="); + roctracer::hip_support::detail::operator<<(out, v.graph); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipEventWaitNodeParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipEventWaitNodeParams::event").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "event="); + roctracer::hip_support::detail::operator<<(out, v.event); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipEventRecordNodeParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipEventRecordNodeParams::event") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "event="); + roctracer::hip_support::detail::operator<<(out, v.event); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipMemFreeNodeParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipGraphNodeParams &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipGraphNodeParams::reserved2").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "reserved2="); + roctracer::hip_support::detail::operator<<(out, v.reserved2); + std::operator<<(out, ", "); + } + if (std::string("hipGraphNodeParams::reserved0").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "reserved0="); + roctracer::hip_support::detail::operator<<(out, v.reserved0); + std::operator<<(out, ", "); + } + if (std::string("hipGraphNodeParams::type").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "type="); + roctracer::hip_support::detail::operator<<(out, v.type); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipGraphEdgeData &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string("hipGraphEdgeData::type").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "type="); + roctracer::hip_support::detail::operator<<(out, v.type); + std::operator<<(out, ", "); + } + if (std::string("hipGraphEdgeData::to_port").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "to_port="); + roctracer::hip_support::detail::operator<<(out, v.to_port); + std::operator<<(out, ", "); + } + if (std::string("hipGraphEdgeData::reserved").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hip_support::detail::operator<<(out, 0); + std::operator<<(out, ", "); + } + if (std::string("hipGraphEdgeData::from_port").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "from_port="); + roctracer::hip_support::detail::operator<<(out, v.from_port); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hipDeviceProp_tR0000 &v) { + std::operator<<(out, '{'); + HIP_depth_max_cnt++; + if (HIP_depth_max == -1 || HIP_depth_max_cnt <= HIP_depth_max) { + if (std::string( + "hipDeviceProp_tR0000::pageableMemoryAccessUsesHostPageTables") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "pageableMemoryAccessUsesHostPageTables="); + roctracer::hip_support::detail::operator<<( + out, v.pageableMemoryAccessUsesHostPageTables); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::pageableMemoryAccess") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "pageableMemoryAccess="); + roctracer::hip_support::detail::operator<<(out, v.pageableMemoryAccess); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::concurrentManagedAccess") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "concurrentManagedAccess="); + roctracer::hip_support::detail::operator<<(out, + v.concurrentManagedAccess); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::directManagedMemAccessFromHost") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "directManagedMemAccessFromHost="); + roctracer::hip_support::detail::operator<<( + out, v.directManagedMemAccessFromHost); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::managedMemory") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "managedMemory="); + roctracer::hip_support::detail::operator<<(out, v.managedMemory); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::asicRevision") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "asicRevision="); + roctracer::hip_support::detail::operator<<(out, v.asicRevision); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::isLargeBar") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "isLargeBar="); + roctracer::hip_support::detail::operator<<(out, v.isLargeBar); + std::operator<<(out, ", "); + } + if (std::string( + "hipDeviceProp_tR0000::cooperativeMultiDeviceUnmatchedSharedMem") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeMultiDeviceUnmatchedSharedMem="); + roctracer::hip_support::detail::operator<<( + out, v.cooperativeMultiDeviceUnmatchedSharedMem); + std::operator<<(out, ", "); + } + if (std::string( + "hipDeviceProp_tR0000::cooperativeMultiDeviceUnmatchedBlockDim") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeMultiDeviceUnmatchedBlockDim="); + roctracer::hip_support::detail::operator<<( + out, v.cooperativeMultiDeviceUnmatchedBlockDim); + std::operator<<(out, ", "); + } + if (std::string( + "hipDeviceProp_tR0000::cooperativeMultiDeviceUnmatchedGridDim") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeMultiDeviceUnmatchedGridDim="); + roctracer::hip_support::detail::operator<<( + out, v.cooperativeMultiDeviceUnmatchedGridDim); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::cooperativeMultiDeviceUnmatchedFunc") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeMultiDeviceUnmatchedFunc="); + roctracer::hip_support::detail::operator<<( + out, v.cooperativeMultiDeviceUnmatchedFunc); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::tccDriver") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "tccDriver="); + roctracer::hip_support::detail::operator<<(out, v.tccDriver); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::ECCEnabled") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "ECCEnabled="); + roctracer::hip_support::detail::operator<<(out, v.ECCEnabled); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::kernelExecTimeoutEnabled") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "kernelExecTimeoutEnabled="); + roctracer::hip_support::detail::operator<<(out, + v.kernelExecTimeoutEnabled); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::texturePitchAlignment") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "texturePitchAlignment="); + roctracer::hip_support::detail::operator<<(out, v.texturePitchAlignment); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::textureAlignment") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "textureAlignment="); + roctracer::hip_support::detail::operator<<(out, v.textureAlignment); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::memPitch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "memPitch="); + roctracer::hip_support::detail::operator<<(out, v.memPitch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::hdpRegFlushCntl") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hdpRegFlushCntl="); + roctracer::hip_support::detail::operator<<(out, v.hdpRegFlushCntl); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::hdpMemFlushCntl") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "hdpMemFlushCntl="); + roctracer::hip_support::detail::operator<<(out, v.hdpMemFlushCntl); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::maxTexture3D") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture3D="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture3D); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::maxTexture2D") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture2D="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture2D); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::maxTexture1D") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture1D="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture1D); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::maxTexture1DLinear") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxTexture1DLinear="); + roctracer::hip_support::detail::operator<<(out, v.maxTexture1DLinear); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::cooperativeMultiDeviceLaunch") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeMultiDeviceLaunch="); + roctracer::hip_support::detail::operator<<( + out, v.cooperativeMultiDeviceLaunch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::cooperativeLaunch") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "cooperativeLaunch="); + roctracer::hip_support::detail::operator<<(out, v.cooperativeLaunch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::integrated") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "integrated="); + roctracer::hip_support::detail::operator<<(out, v.integrated); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::gcnArchName") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "gcnArchName="); + roctracer::hip_support::detail::operator<<(out, v.gcnArchName); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::gcnArch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "gcnArch="); + roctracer::hip_support::detail::operator<<(out, v.gcnArch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::canMapHostMemory") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "canMapHostMemory="); + roctracer::hip_support::detail::operator<<(out, v.canMapHostMemory); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::isMultiGpuBoard") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "isMultiGpuBoard="); + roctracer::hip_support::detail::operator<<(out, v.isMultiGpuBoard); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::maxSharedMemoryPerMultiProcessor") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxSharedMemoryPerMultiProcessor="); + roctracer::hip_support::detail::operator<<( + out, v.maxSharedMemoryPerMultiProcessor); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::pciDeviceID") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "pciDeviceID="); + roctracer::hip_support::detail::operator<<(out, v.pciDeviceID); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::pciBusID").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "pciBusID="); + roctracer::hip_support::detail::operator<<(out, v.pciBusID); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::pciDomainID") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "pciDomainID="); + roctracer::hip_support::detail::operator<<(out, v.pciDomainID); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::concurrentKernels") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "concurrentKernels="); + roctracer::hip_support::detail::operator<<(out, v.concurrentKernels); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::arch").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "arch="); + roctracer::hip_support::detail::operator<<(out, v.arch); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::clockInstructionRate") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "clockInstructionRate="); + roctracer::hip_support::detail::operator<<(out, v.clockInstructionRate); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::computeMode") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "computeMode="); + roctracer::hip_support::detail::operator<<(out, v.computeMode); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::maxThreadsPerMultiProcessor") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxThreadsPerMultiProcessor="); + roctracer::hip_support::detail::operator<<(out, + v.maxThreadsPerMultiProcessor); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::l2CacheSize") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "l2CacheSize="); + roctracer::hip_support::detail::operator<<(out, v.l2CacheSize); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::multiProcessorCount") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "multiProcessorCount="); + roctracer::hip_support::detail::operator<<(out, v.multiProcessorCount); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::minor").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "minor="); + roctracer::hip_support::detail::operator<<(out, v.minor); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::major").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "major="); + roctracer::hip_support::detail::operator<<(out, v.major); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::totalConstMem") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "totalConstMem="); + roctracer::hip_support::detail::operator<<(out, v.totalConstMem); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::memoryBusWidth") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "memoryBusWidth="); + roctracer::hip_support::detail::operator<<(out, v.memoryBusWidth); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::memoryClockRate") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "memoryClockRate="); + roctracer::hip_support::detail::operator<<(out, v.memoryClockRate); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::clockRate") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "clockRate="); + roctracer::hip_support::detail::operator<<(out, v.clockRate); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::maxGridSize") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxGridSize="); + roctracer::hip_support::detail::operator<<(out, v.maxGridSize); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::maxThreadsDim") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxThreadsDim="); + roctracer::hip_support::detail::operator<<(out, v.maxThreadsDim); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::maxThreadsPerBlock") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "maxThreadsPerBlock="); + roctracer::hip_support::detail::operator<<(out, v.maxThreadsPerBlock); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::warpSize").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "warpSize="); + roctracer::hip_support::detail::operator<<(out, v.warpSize); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::regsPerBlock") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "regsPerBlock="); + roctracer::hip_support::detail::operator<<(out, v.regsPerBlock); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::sharedMemPerBlock") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "sharedMemPerBlock="); + roctracer::hip_support::detail::operator<<(out, v.sharedMemPerBlock); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::totalGlobalMem") + .find(HIP_structs_regex) != std::string::npos) { + std::operator<<(out, "totalGlobalMem="); + roctracer::hip_support::detail::operator<<(out, v.totalGlobalMem); + std::operator<<(out, ", "); + } + if (std::string("hipDeviceProp_tR0000::name").find(HIP_structs_regex) != + std::string::npos) { + std::operator<<(out, "name="); + roctracer::hip_support::detail::operator<<(out, v.name); + } + }; + HIP_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +// end ostream ops for HIP +}; // namespace detail +}; // namespace hip_support +}; // namespace roctracer + +inline static std::ostream &operator<<(std::ostream &out, + const __locale_struct &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipDeviceArch_t &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const hipUUID &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipDeviceProp_tR0600 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipPointerAttribute_t &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipChannelFormatDesc &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const HIP_ARRAY_DESCRIPTOR &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const HIP_ARRAY3D_DESCRIPTOR &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hip_Memcpy2D &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipMipmappedArray &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const HIP_TEXTURE_DESC &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipResourceDesc &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const HIP_RESOURCE_DESC &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipResourceViewDesc &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const HIP_RESOURCE_VIEW_DESC &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipPitchedPtr &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const hipExtent &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const hipPos &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipMemcpy3DParms &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const HIP_MEMCPY3D &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const uchar1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const uchar2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const uchar3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const uchar4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const char1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const char2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const char3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const char4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ushort1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ushort2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ushort3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ushort4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const short1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const short2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const short3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const short4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const uint1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const uint2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const uint3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const uint4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const int1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const int2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const int3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const int4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ulong1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ulong2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ulong3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ulong4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const long1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const long2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const long3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const long4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ulonglong1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ulonglong2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ulonglong3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const ulonglong4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const longlong1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const longlong2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const longlong3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const longlong4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const float1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const float2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const float3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const float4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const double1 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const double2 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const double3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const double4 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const textureReference &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipTextureDesc &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const surfaceReference &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipIpcMemHandle_t &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipIpcEventHandle_t &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipFuncAttributes &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipMemLocation &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipMemAccessDesc &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipMemPoolProps &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipMemPoolPtrExportData &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const dim3 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipLaunchParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipFunctionLaunchParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipExternalMemoryHandleDesc &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipExternalMemoryBufferDesc &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalMemoryMipmappedArrayDesc &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalSemaphoreHandleDesc &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalSemaphoreSignalParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalSemaphoreWaitParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipHostNodeParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipKernelNodeParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipMemsetParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipMemAllocNodeParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipAccessPolicyWindow &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipLaunchAttributeValue &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const HIP_MEMSET_NODE_PARAMS &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipGraphInstantiateParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipMemAllocationProp &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalSemaphoreSignalNodeParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hipExternalSemaphoreWaitNodeParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipArrayMapInfo &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipMemcpyNodeParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipChildGraphNodeParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipEventWaitNodeParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipEventRecordNodeParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipMemFreeNodeParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipGraphNodeParams &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipGraphEdgeData &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hipDeviceProp_tR0000 &v) { + roctracer::hip_support::detail::operator<<(out, v); + return out; +} + +#endif //__cplusplus +#endif // INC_HIP_OSTREAM_OPS_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/hsa_ostream_ops.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/hsa_ostream_ops.h new file mode 100644 index 000000000..fbcdfff5d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/hsa_ostream_ops.h @@ -0,0 +1,1930 @@ +// automatically generated +/* +Copyright (c) 2018 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#ifndef INC_HSA_OSTREAM_OPS_H_ +#define INC_HSA_OSTREAM_OPS_H_ + +#include "roctracer.h" + +#ifdef __cplusplus +#include +#include +#include +#include + +namespace roctracer { +namespace hsa_support { +static int HSA_depth_max = 1; +static int HSA_depth_max_cnt = 0; +static std::string HSA_structs_regex = ""; +// begin ostream ops for HSA +// basic ostream ops +namespace detail { +inline static void print_escaped_string(std::ostream &out, const char *v, + size_t len) { + out << '"'; + for (size_t i = 0; i < len && v[i]; ++i) { + switch (v[i]) { + case '\"': + out << "\\\""; + break; + case '\\': + out << "\\\\"; + break; + case '\b': + out << "\\\b"; + break; + case '\f': + out << "\\\f"; + break; + case '\n': + out << "\\\n"; + break; + case '\r': + out << "\\\r"; + break; + case '\t': + out << "\\\t"; + break; + default: + if (std::isprint((unsigned char)v[i])) + std::operator<<(out, v[i]); + else { + std::ios_base::fmtflags flags(out.flags()); + out << "\\x" << std::setfill('0') << std::setw(2) << std::hex + << (unsigned int)(unsigned char)v[i]; + out.flags(flags); + } + break; + } + } + out << '"'; +} + +template +inline static std::ostream &operator<<(std::ostream &out, const T &v) { + using std::operator<<; + static bool recursion = false; + if (recursion == false) { + recursion = true; + out << v; + recursion = false; + } + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const unsigned char &v) { + out << (unsigned int)v; + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const char &v) { + out << (unsigned char)v; + return out; +} + +template +inline static std::ostream &operator<<(std::ostream &out, const char (&v)[N]) { + print_escaped_string(out, v, N); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const char *v) { + print_escaped_string(out, v, strlen(v)); + return out; +} +// End of basic ostream ops + +inline static std::ostream &operator<<(std::ostream &out, const hsa_dim3_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_dim3_t::z").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "z="); + roctracer::hsa_support::detail::operator<<(out, v.z); + std::operator<<(out, ", "); + } + if (std::string("hsa_dim3_t::y").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "y="); + roctracer::hsa_support::detail::operator<<(out, v.y); + std::operator<<(out, ", "); + } + if (std::string("hsa_dim3_t::x").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "x="); + roctracer::hsa_support::detail::operator<<(out, v.x); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_agent_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_agent_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_cache_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_cache_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_signal_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_signal_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_signal_group_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_signal_group_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_region_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_region_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_queue_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_queue_t::id").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "id="); + roctracer::hsa_support::detail::operator<<(out, v.id); + std::operator<<(out, ", "); + } + if (std::string("hsa_queue_t::reserved1").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "reserved1="); + roctracer::hsa_support::detail::operator<<(out, v.reserved1); + std::operator<<(out, ", "); + } + if (std::string("hsa_queue_t::size").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "size="); + roctracer::hsa_support::detail::operator<<(out, v.size); + std::operator<<(out, ", "); + } + if (std::string("hsa_queue_t::doorbell_signal").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "doorbell_signal="); + roctracer::hsa_support::detail::operator<<(out, v.doorbell_signal); + std::operator<<(out, ", "); + } + if (std::string("hsa_queue_t::features").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "features="); + roctracer::hsa_support::detail::operator<<(out, v.features); + std::operator<<(out, ", "); + } + if (std::string("hsa_queue_t::type").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "type="); + roctracer::hsa_support::detail::operator<<(out, v.type); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_kernel_dispatch_packet_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_kernel_dispatch_packet_t::completion_signal") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "completion_signal="); + roctracer::hsa_support::detail::operator<<(out, v.completion_signal); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::reserved2") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved2="); + roctracer::hsa_support::detail::operator<<(out, v.reserved2); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::kernel_object") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "kernel_object="); + roctracer::hsa_support::detail::operator<<(out, v.kernel_object); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::group_segment_size") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "group_segment_size="); + roctracer::hsa_support::detail::operator<<(out, v.group_segment_size); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::private_segment_size") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "private_segment_size="); + roctracer::hsa_support::detail::operator<<(out, v.private_segment_size); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::grid_size_z") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "grid_size_z="); + roctracer::hsa_support::detail::operator<<(out, v.grid_size_z); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::grid_size_y") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "grid_size_y="); + roctracer::hsa_support::detail::operator<<(out, v.grid_size_y); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::grid_size_x") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "grid_size_x="); + roctracer::hsa_support::detail::operator<<(out, v.grid_size_x); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::reserved0") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved0="); + roctracer::hsa_support::detail::operator<<(out, v.reserved0); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::workgroup_size_z") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_size_z="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_size_z); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::workgroup_size_y") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_size_y="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_size_y); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::workgroup_size_x") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_size_x="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_size_x); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::setup") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "setup="); + roctracer::hsa_support::detail::operator<<(out, v.setup); + std::operator<<(out, ", "); + } + if (std::string("hsa_kernel_dispatch_packet_t::header") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "header="); + roctracer::hsa_support::detail::operator<<(out, v.header); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_agent_dispatch_packet_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_agent_dispatch_packet_t::completion_signal") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "completion_signal="); + roctracer::hsa_support::detail::operator<<(out, v.completion_signal); + std::operator<<(out, ", "); + } + if (std::string("hsa_agent_dispatch_packet_t::reserved2") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved2="); + roctracer::hsa_support::detail::operator<<(out, v.reserved2); + std::operator<<(out, ", "); + } + if (std::string("hsa_agent_dispatch_packet_t::arg") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "arg="); + roctracer::hsa_support::detail::operator<<(out, v.arg); + std::operator<<(out, ", "); + } + if (std::string("hsa_agent_dispatch_packet_t::reserved0") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved0="); + roctracer::hsa_support::detail::operator<<(out, v.reserved0); + std::operator<<(out, ", "); + } + if (std::string("hsa_agent_dispatch_packet_t::type") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "type="); + roctracer::hsa_support::detail::operator<<(out, v.type); + std::operator<<(out, ", "); + } + if (std::string("hsa_agent_dispatch_packet_t::header") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "header="); + roctracer::hsa_support::detail::operator<<(out, v.header); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_barrier_and_packet_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_barrier_and_packet_t::completion_signal") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "completion_signal="); + roctracer::hsa_support::detail::operator<<(out, v.completion_signal); + std::operator<<(out, ", "); + } + if (std::string("hsa_barrier_and_packet_t::reserved2") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved2="); + roctracer::hsa_support::detail::operator<<(out, v.reserved2); + std::operator<<(out, ", "); + } + if (std::string("hsa_barrier_and_packet_t::dep_signal") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "dep_signal="); + roctracer::hsa_support::detail::operator<<(out, v.dep_signal); + std::operator<<(out, ", "); + } + if (std::string("hsa_barrier_and_packet_t::reserved1") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved1="); + roctracer::hsa_support::detail::operator<<(out, v.reserved1); + std::operator<<(out, ", "); + } + if (std::string("hsa_barrier_and_packet_t::reserved0") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved0="); + roctracer::hsa_support::detail::operator<<(out, v.reserved0); + std::operator<<(out, ", "); + } + if (std::string("hsa_barrier_and_packet_t::header") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "header="); + roctracer::hsa_support::detail::operator<<(out, v.header); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_barrier_or_packet_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_barrier_or_packet_t::completion_signal") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "completion_signal="); + roctracer::hsa_support::detail::operator<<(out, v.completion_signal); + std::operator<<(out, ", "); + } + if (std::string("hsa_barrier_or_packet_t::reserved2") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved2="); + roctracer::hsa_support::detail::operator<<(out, v.reserved2); + std::operator<<(out, ", "); + } + if (std::string("hsa_barrier_or_packet_t::dep_signal") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "dep_signal="); + roctracer::hsa_support::detail::operator<<(out, v.dep_signal); + std::operator<<(out, ", "); + } + if (std::string("hsa_barrier_or_packet_t::reserved1") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved1="); + roctracer::hsa_support::detail::operator<<(out, v.reserved1); + std::operator<<(out, ", "); + } + if (std::string("hsa_barrier_or_packet_t::reserved0") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved0="); + roctracer::hsa_support::detail::operator<<(out, v.reserved0); + std::operator<<(out, ", "); + } + if (std::string("hsa_barrier_or_packet_t::header") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "header="); + roctracer::hsa_support::detail::operator<<(out, v.header); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, const hsa_isa_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_isa_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_wavefront_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_wavefront_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_code_object_reader_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_code_object_reader_t::handle") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_executable_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_executable_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_loaded_code_object_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_loaded_code_object_t::handle") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_executable_symbol_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_executable_symbol_t::handle") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_code_object_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_code_object_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_callback_data_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_callback_data_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_code_symbol_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_code_symbol_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_image_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ext_image_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_image_format_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ext_image_format_t::channel_order") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "channel_order="); + roctracer::hsa_support::detail::operator<<(out, v.channel_order); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_image_format_t::channel_type") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "channel_type="); + roctracer::hsa_support::detail::operator<<(out, v.channel_type); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_image_descriptor_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ext_image_descriptor_t::format") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "format="); + roctracer::hsa_support::detail::operator<<(out, v.format); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_image_descriptor_t::array_size") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "array_size="); + roctracer::hsa_support::detail::operator<<(out, v.array_size); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_image_descriptor_t::depth") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "depth="); + roctracer::hsa_support::detail::operator<<(out, v.depth); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_image_descriptor_t::height") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "height="); + roctracer::hsa_support::detail::operator<<(out, v.height); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_image_descriptor_t::width") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "width="); + roctracer::hsa_support::detail::operator<<(out, v.width); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_image_descriptor_t::geometry") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "geometry="); + roctracer::hsa_support::detail::operator<<(out, v.geometry); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_image_data_info_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ext_image_data_info_t::alignment") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "alignment="); + roctracer::hsa_support::detail::operator<<(out, v.alignment); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_image_data_info_t::size") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "size="); + roctracer::hsa_support::detail::operator<<(out, v.size); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_image_region_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ext_image_region_t::range").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "range="); + roctracer::hsa_support::detail::operator<<(out, v.range); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_image_region_t::offset").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "offset="); + roctracer::hsa_support::detail::operator<<(out, v.offset); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_sampler_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ext_sampler_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_sampler_descriptor_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ext_sampler_descriptor_t::address_mode") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "address_mode="); + roctracer::hsa_support::detail::operator<<(out, v.address_mode); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_sampler_descriptor_t::filter_mode") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "filter_mode="); + roctracer::hsa_support::detail::operator<<(out, v.filter_mode); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_sampler_descriptor_t::coordinate_mode") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "coordinate_mode="); + roctracer::hsa_support::detail::operator<<(out, v.coordinate_mode); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_images_1_00_pfn_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ext_images_1_00_pfn_t::hsa_ext_sampler_destroy") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_sampler_destroy="); + roctracer::hsa_support::detail::operator<<(out, + v.hsa_ext_sampler_destroy); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_images_1_00_pfn_t::hsa_ext_sampler_create") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_sampler_create="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ext_sampler_create); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_images_1_00_pfn_t::hsa_ext_image_copy") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_image_copy="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ext_image_copy); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_images_1_00_pfn_t::hsa_ext_image_destroy") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_image_destroy="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ext_image_destroy); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_images_1_00_pfn_t::hsa_ext_image_data_get_info") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_image_data_get_info="); + roctracer::hsa_support::detail::operator<<(out, + v.hsa_ext_image_data_get_info); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_images_1_00_pfn_t::hsa_ext_image_get_capability") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_image_get_capability="); + roctracer::hsa_support::detail::operator<<( + out, v.hsa_ext_image_get_capability); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_images_1_pfn_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string( + "hsa_ext_images_1_pfn_t::hsa_ext_image_data_get_info_with_layout") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_image_data_get_info_with_layout="); + roctracer::hsa_support::detail::operator<<( + out, v.hsa_ext_image_data_get_info_with_layout); + std::operator<<(out, ", "); + } + if (std::string( + "hsa_ext_images_1_pfn_t::hsa_ext_image_get_capability_with_layout") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_image_get_capability_with_layout="); + roctracer::hsa_support::detail::operator<<( + out, v.hsa_ext_image_get_capability_with_layout); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_images_1_pfn_t::hsa_ext_sampler_destroy") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_sampler_destroy="); + roctracer::hsa_support::detail::operator<<(out, + v.hsa_ext_sampler_destroy); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_images_1_pfn_t::hsa_ext_sampler_create") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_sampler_create="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ext_sampler_create); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_images_1_pfn_t::hsa_ext_image_copy") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_image_copy="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ext_image_copy); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_images_1_pfn_t::hsa_ext_image_destroy") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_image_destroy="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ext_image_destroy); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_images_1_pfn_t::hsa_ext_image_data_get_info") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_image_data_get_info="); + roctracer::hsa_support::detail::operator<<(out, + v.hsa_ext_image_data_get_info); + std::operator<<(out, ", "); + } + if (std::string("hsa_ext_images_1_pfn_t::hsa_ext_image_get_capability") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ext_image_get_capability="); + roctracer::hsa_support::detail::operator<<( + out, v.hsa_ext_image_get_capability); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const perf_sample_hosttrap_v1_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("perf_sample_hosttrap_v1_t::correlation_id") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "correlation_id="); + roctracer::hsa_support::detail::operator<<(out, v.correlation_id); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::timestamp") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "timestamp="); + roctracer::hsa_support::detail::operator<<(out, v.timestamp); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::reserved1") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved1="); + roctracer::hsa_support::detail::operator<<(out, v.reserved1); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::reserved0") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved0="); + roctracer::hsa_support::detail::operator<<(out, v.reserved0); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::hw_id") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hw_id="); + roctracer::hsa_support::detail::operator<<(out, v.hw_id); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::reserved") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hsa_support::detail::operator<<(out, v.reserved); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::chiplet") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "chiplet="); + roctracer::hsa_support::detail::operator<<(out, v.chiplet); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::wave_in_wg") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "wave_in_wg="); + roctracer::hsa_support::detail::operator<<(out, v.wave_in_wg); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::workgroup_id_z") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_z="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_z); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::workgroup_id_y") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_y="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_y); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::workgroup_id_x") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_x="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_x); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::exec_mask") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "exec_mask="); + roctracer::hsa_support::detail::operator<<(out, v.exec_mask); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_hosttrap_v1_t::pc").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "pc="); + roctracer::hsa_support::detail::operator<<(out, v.pc); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const perf_sample_snapshot_v1_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("perf_sample_snapshot_v1_t::correlation_id") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "correlation_id="); + roctracer::hsa_support::detail::operator<<(out, v.correlation_id); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::timestamp") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "timestamp="); + roctracer::hsa_support::detail::operator<<(out, v.timestamp); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::perf_snapshot_data2") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "perf_snapshot_data2="); + roctracer::hsa_support::detail::operator<<(out, v.perf_snapshot_data2); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::perf_snapshot_data1") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "perf_snapshot_data1="); + roctracer::hsa_support::detail::operator<<(out, v.perf_snapshot_data1); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::perf_snapshot_data") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "perf_snapshot_data="); + roctracer::hsa_support::detail::operator<<(out, v.perf_snapshot_data); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::hw_id") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hw_id="); + roctracer::hsa_support::detail::operator<<(out, v.hw_id); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::reserved") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hsa_support::detail::operator<<(out, v.reserved); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::chiplet") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "chiplet="); + roctracer::hsa_support::detail::operator<<(out, v.chiplet); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::wave_in_wg") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "wave_in_wg="); + roctracer::hsa_support::detail::operator<<(out, v.wave_in_wg); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::workgroup_id_z") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_z="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_z); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::workgroup_id_y") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_y="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_y); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::workgroup_id_x") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "workgroup_id_x="); + roctracer::hsa_support::detail::operator<<(out, v.workgroup_id_x); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::exec_mask") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "exec_mask="); + roctracer::hsa_support::detail::operator<<(out, v.exec_mask); + std::operator<<(out, ", "); + } + if (std::string("perf_sample_snapshot_v1_t::pc").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "pc="); + roctracer::hsa_support::detail::operator<<(out, v.pc); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ven_amd_pcs_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ven_amd_pcs_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hsa_ven_amd_pcs_configuration_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ven_amd_pcs_configuration_t::flags") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "flags="); + roctracer::hsa_support::detail::operator<<(out, v.flags); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pcs_configuration_t::max_interval") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "max_interval="); + roctracer::hsa_support::detail::operator<<(out, v.max_interval); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pcs_configuration_t::min_interval") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "min_interval="); + roctracer::hsa_support::detail::operator<<(out, v.min_interval); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pcs_configuration_t::units") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "units="); + roctracer::hsa_support::detail::operator<<(out, v.units); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pcs_configuration_t::method") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "method="); + roctracer::hsa_support::detail::operator<<(out, v.method); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hsa_ven_amd_pc_sampling_1_00_pfn_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_ven_amd_pc_sampling_1_00_pfn_t::hsa_ven_amd_pcs_flush") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ven_amd_pcs_flush="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ven_amd_pcs_flush); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pc_sampling_1_00_pfn_t::hsa_ven_amd_pcs_stop") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ven_amd_pcs_stop="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ven_amd_pcs_stop); + std::operator<<(out, ", "); + } + if (std::string("hsa_ven_amd_pc_sampling_1_00_pfn_t::hsa_ven_amd_pcs_start") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ven_amd_pcs_start="); + roctracer::hsa_support::detail::operator<<(out, v.hsa_ven_amd_pcs_start); + std::operator<<(out, ", "); + } + if (std::string( + "hsa_ven_amd_pc_sampling_1_00_pfn_t::hsa_ven_amd_pcs_destroy") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "hsa_ven_amd_pcs_destroy="); + roctracer::hsa_support::detail::operator<<(out, + v.hsa_ven_amd_pcs_destroy); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_vendor_packet_header_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_vendor_packet_header_t::reserved") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved="); + roctracer::hsa_support::detail::operator<<(out, v.reserved); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_vendor_packet_header_t::AmdFormat") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "AmdFormat="); + roctracer::hsa_support::detail::operator<<(out, v.AmdFormat); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_vendor_packet_header_t::header") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "header="); + roctracer::hsa_support::detail::operator<<(out, v.header); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_barrier_value_packet_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_barrier_value_packet_t::completion_signal") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "completion_signal="); + roctracer::hsa_support::detail::operator<<(out, v.completion_signal); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_barrier_value_packet_t::reserved3") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved3="); + roctracer::hsa_support::detail::operator<<(out, v.reserved3); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_barrier_value_packet_t::reserved2") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved2="); + roctracer::hsa_support::detail::operator<<(out, v.reserved2); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_barrier_value_packet_t::reserved1") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved1="); + roctracer::hsa_support::detail::operator<<(out, v.reserved1); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_barrier_value_packet_t::cond") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "cond="); + roctracer::hsa_support::detail::operator<<(out, v.cond); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_barrier_value_packet_t::mask") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "mask="); + roctracer::hsa_support::detail::operator<<(out, v.mask); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_barrier_value_packet_t::value") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "value="); + roctracer::hsa_support::detail::operator<<(out, v.value); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_barrier_value_packet_t::signal") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "signal="); + roctracer::hsa_support::detail::operator<<(out, v.signal); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_barrier_value_packet_t::reserved0") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reserved0="); + roctracer::hsa_support::detail::operator<<(out, v.reserved0); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_barrier_value_packet_t::header") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "header="); + roctracer::hsa_support::detail::operator<<(out, v.header); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_hdp_flush_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_hdp_flush_t::HDP_REG_FLUSH_CNTL") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "HDP_REG_FLUSH_CNTL="); + roctracer::hsa_support::detail::operator<<(out, v.HDP_REG_FLUSH_CNTL); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_hdp_flush_t::HDP_MEM_FLUSH_CNTL") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "HDP_MEM_FLUSH_CNTL="); + roctracer::hsa_support::detail::operator<<(out, v.HDP_MEM_FLUSH_CNTL); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_profiling_dispatch_time_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_profiling_dispatch_time_t::end") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "end="); + roctracer::hsa_support::detail::operator<<(out, v.end); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_profiling_dispatch_time_t::start") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "start="); + roctracer::hsa_support::detail::operator<<(out, v.start); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_profiling_async_copy_time_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_profiling_async_copy_time_t::end") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "end="); + roctracer::hsa_support::detail::operator<<(out, v.end); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_profiling_async_copy_time_t::start") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "start="); + roctracer::hsa_support::detail::operator<<(out, v.start); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_memory_pool_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_memory_pool_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_pitched_ptr_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_pitched_ptr_t::slice").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "slice="); + roctracer::hsa_support::detail::operator<<(out, v.slice); + std::operator<<(out, ", "); + } + if (std::string("hsa_pitched_ptr_t::pitch").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "pitch="); + roctracer::hsa_support::detail::operator<<(out, v.pitch); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_memory_pool_link_info_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_memory_pool_link_info_t::numa_distance") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "numa_distance="); + roctracer::hsa_support::detail::operator<<(out, v.numa_distance); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_memory_pool_link_info_t::link_type") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "link_type="); + roctracer::hsa_support::detail::operator<<(out, v.link_type); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_memory_pool_link_info_t::max_bandwidth") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "max_bandwidth="); + roctracer::hsa_support::detail::operator<<(out, v.max_bandwidth); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_memory_pool_link_info_t::min_bandwidth") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "min_bandwidth="); + roctracer::hsa_support::detail::operator<<(out, v.min_bandwidth); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_memory_pool_link_info_t::max_latency") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "max_latency="); + roctracer::hsa_support::detail::operator<<(out, v.max_latency); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_memory_pool_link_info_t::min_latency") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "min_latency="); + roctracer::hsa_support::detail::operator<<(out, v.min_latency); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_image_descriptor_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_image_descriptor_t::data") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "data="); + roctracer::hsa_support::detail::operator<<(out, v.data); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_image_descriptor_t::deviceID") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "deviceID="); + roctracer::hsa_support::detail::operator<<(out, v.deviceID); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_image_descriptor_t::version") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "version="); + roctracer::hsa_support::detail::operator<<(out, v.version); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_pointer_info_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_pointer_info_t::global_flags") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "global_flags="); + roctracer::hsa_support::detail::operator<<(out, v.global_flags); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_pointer_info_t::agentOwner") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "agentOwner="); + roctracer::hsa_support::detail::operator<<(out, v.agentOwner); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_pointer_info_t::sizeInBytes") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "sizeInBytes="); + roctracer::hsa_support::detail::operator<<(out, v.sizeInBytes); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_pointer_info_t::type").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "type="); + roctracer::hsa_support::detail::operator<<(out, v.type); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_pointer_info_t::size").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "size="); + roctracer::hsa_support::detail::operator<<(out, v.size); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_ipc_memory_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_ipc_memory_t::handle").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_gpu_memory_fault_info_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_gpu_memory_fault_info_t::fault_reason_mask") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "fault_reason_mask="); + roctracer::hsa_support::detail::operator<<(out, v.fault_reason_mask); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_gpu_memory_fault_info_t::virtual_address") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "virtual_address="); + roctracer::hsa_support::detail::operator<<(out, v.virtual_address); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_gpu_memory_fault_info_t::agent") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "agent="); + roctracer::hsa_support::detail::operator<<(out, v.agent); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_gpu_hw_exception_info_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_gpu_hw_exception_info_t::reset_cause") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reset_cause="); + roctracer::hsa_support::detail::operator<<(out, v.reset_cause); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_gpu_hw_exception_info_t::reset_type") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "reset_type="); + roctracer::hsa_support::detail::operator<<(out, v.reset_type); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_gpu_hw_exception_info_t::agent") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "agent="); + roctracer::hsa_support::detail::operator<<(out, v.agent); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_event_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_event_t::event_type").find(HSA_structs_regex) != + std::string::npos) { + std::operator<<(out, "event_type="); + roctracer::hsa_support::detail::operator<<(out, v.event_type); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_svm_attribute_pair_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_svm_attribute_pair_t::value") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "value="); + roctracer::hsa_support::detail::operator<<(out, v.value); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_svm_attribute_pair_t::attribute") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "attribute="); + roctracer::hsa_support::detail::operator<<(out, v.attribute); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_vmem_alloc_handle_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_vmem_alloc_handle_t::handle") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "handle="); + roctracer::hsa_support::detail::operator<<(out, v.handle); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_memory_access_desc_t &v) { + std::operator<<(out, '{'); + HSA_depth_max_cnt++; + if (HSA_depth_max == -1 || HSA_depth_max_cnt <= HSA_depth_max) { + if (std::string("hsa_amd_memory_access_desc_t::agent_handle") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "agent_handle="); + roctracer::hsa_support::detail::operator<<(out, v.agent_handle); + std::operator<<(out, ", "); + } + if (std::string("hsa_amd_memory_access_desc_t::permissions") + .find(HSA_structs_regex) != std::string::npos) { + std::operator<<(out, "permissions="); + roctracer::hsa_support::detail::operator<<(out, v.permissions); + } + }; + HSA_depth_max_cnt--; + std::operator<<(out, '}'); + return out; +} +// end ostream ops for HSA +}; // namespace detail +}; // namespace hsa_support +}; // namespace roctracer + +inline static std::ostream &operator<<(std::ostream &out, const hsa_dim3_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_agent_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_cache_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_signal_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_signal_group_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_region_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_queue_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_kernel_dispatch_packet_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_agent_dispatch_packet_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_barrier_and_packet_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_barrier_or_packet_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, const hsa_isa_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_wavefront_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_code_object_reader_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_executable_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_loaded_code_object_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_executable_symbol_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_code_object_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_callback_data_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_code_symbol_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_image_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_image_format_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_image_descriptor_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_image_data_info_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_image_region_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_sampler_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_sampler_descriptor_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_images_1_00_pfn_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ext_images_1_pfn_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const perf_sample_hosttrap_v1_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const perf_sample_snapshot_v1_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_ven_amd_pcs_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hsa_ven_amd_pcs_configuration_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hsa_ven_amd_pc_sampling_1_00_pfn_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_vendor_packet_header_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_barrier_value_packet_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_hdp_flush_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_profiling_dispatch_time_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_profiling_async_copy_time_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_memory_pool_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_pitched_ptr_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_memory_pool_link_info_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_image_descriptor_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_pointer_info_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_ipc_memory_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_gpu_memory_fault_info_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream & +operator<<(std::ostream &out, const hsa_amd_gpu_hw_exception_info_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_event_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_svm_attribute_pair_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_vmem_alloc_handle_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +inline static std::ostream &operator<<(std::ostream &out, + const hsa_amd_memory_access_desc_t &v) { + roctracer::hsa_support::detail::operator<<(out, v); + return out; +} + +#endif //__cplusplus +#endif // INC_HSA_OSTREAM_OPS_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/hsa_prof_str.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/hsa_prof_str.h new file mode 100644 index 000000000..9874a47c6 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/hsa_prof_str.h @@ -0,0 +1,3173 @@ +/* Generated by hsaap.py */ +/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ + +/* HSA API tracing primitives + 'CoreApi', header 'hsa.h', 125 funcs + 'AmdExt', header 'hsa_ext_amd.h', 70 funcs + 'ImageExt', header 'hsa_ext_image.h', 13 funcs + 'AmdExt', header 'hsa_api_trace.h', 70 funcs + */ + +#ifndef HSA_PROF_STR_H_ +#define HSA_PROF_STR_H_ + +/* section: API ID enumeration */ + +enum hsa_api_id_t { + /* block: CoreApi API */ + HSA_API_ID_hsa_init = 0, + HSA_API_ID_hsa_shut_down = 1, + HSA_API_ID_hsa_system_get_info = 2, + HSA_API_ID_hsa_system_extension_supported = 3, + HSA_API_ID_hsa_system_get_extension_table = 4, + HSA_API_ID_hsa_iterate_agents = 5, + HSA_API_ID_hsa_agent_get_info = 6, + HSA_API_ID_hsa_queue_create = 7, + HSA_API_ID_hsa_soft_queue_create = 8, + HSA_API_ID_hsa_queue_destroy = 9, + HSA_API_ID_hsa_queue_inactivate = 10, + HSA_API_ID_hsa_queue_load_read_index_scacquire = 11, + HSA_API_ID_hsa_queue_load_read_index_relaxed = 12, + HSA_API_ID_hsa_queue_load_write_index_scacquire = 13, + HSA_API_ID_hsa_queue_load_write_index_relaxed = 14, + HSA_API_ID_hsa_queue_store_write_index_relaxed = 15, + HSA_API_ID_hsa_queue_store_write_index_screlease = 16, + HSA_API_ID_hsa_queue_cas_write_index_scacq_screl = 17, + HSA_API_ID_hsa_queue_cas_write_index_scacquire = 18, + HSA_API_ID_hsa_queue_cas_write_index_relaxed = 19, + HSA_API_ID_hsa_queue_cas_write_index_screlease = 20, + HSA_API_ID_hsa_queue_add_write_index_scacq_screl = 21, + HSA_API_ID_hsa_queue_add_write_index_scacquire = 22, + HSA_API_ID_hsa_queue_add_write_index_relaxed = 23, + HSA_API_ID_hsa_queue_add_write_index_screlease = 24, + HSA_API_ID_hsa_queue_store_read_index_relaxed = 25, + HSA_API_ID_hsa_queue_store_read_index_screlease = 26, + HSA_API_ID_hsa_agent_iterate_regions = 27, + HSA_API_ID_hsa_region_get_info = 28, + HSA_API_ID_hsa_agent_get_exception_policies = 29, + HSA_API_ID_hsa_agent_extension_supported = 30, + HSA_API_ID_hsa_memory_register = 31, + HSA_API_ID_hsa_memory_deregister = 32, + HSA_API_ID_hsa_memory_allocate = 33, + HSA_API_ID_hsa_memory_free = 34, + HSA_API_ID_hsa_memory_copy = 35, + HSA_API_ID_hsa_memory_assign_agent = 36, + HSA_API_ID_hsa_signal_create = 37, + HSA_API_ID_hsa_signal_destroy = 38, + HSA_API_ID_hsa_signal_load_relaxed = 39, + HSA_API_ID_hsa_signal_load_scacquire = 40, + HSA_API_ID_hsa_signal_store_relaxed = 41, + HSA_API_ID_hsa_signal_store_screlease = 42, + HSA_API_ID_hsa_signal_wait_relaxed = 43, + HSA_API_ID_hsa_signal_wait_scacquire = 44, + HSA_API_ID_hsa_signal_and_relaxed = 45, + HSA_API_ID_hsa_signal_and_scacquire = 46, + HSA_API_ID_hsa_signal_and_screlease = 47, + HSA_API_ID_hsa_signal_and_scacq_screl = 48, + HSA_API_ID_hsa_signal_or_relaxed = 49, + HSA_API_ID_hsa_signal_or_scacquire = 50, + HSA_API_ID_hsa_signal_or_screlease = 51, + HSA_API_ID_hsa_signal_or_scacq_screl = 52, + HSA_API_ID_hsa_signal_xor_relaxed = 53, + HSA_API_ID_hsa_signal_xor_scacquire = 54, + HSA_API_ID_hsa_signal_xor_screlease = 55, + HSA_API_ID_hsa_signal_xor_scacq_screl = 56, + HSA_API_ID_hsa_signal_exchange_relaxed = 57, + HSA_API_ID_hsa_signal_exchange_scacquire = 58, + HSA_API_ID_hsa_signal_exchange_screlease = 59, + HSA_API_ID_hsa_signal_exchange_scacq_screl = 60, + HSA_API_ID_hsa_signal_add_relaxed = 61, + HSA_API_ID_hsa_signal_add_scacquire = 62, + HSA_API_ID_hsa_signal_add_screlease = 63, + HSA_API_ID_hsa_signal_add_scacq_screl = 64, + HSA_API_ID_hsa_signal_subtract_relaxed = 65, + HSA_API_ID_hsa_signal_subtract_scacquire = 66, + HSA_API_ID_hsa_signal_subtract_screlease = 67, + HSA_API_ID_hsa_signal_subtract_scacq_screl = 68, + HSA_API_ID_hsa_signal_cas_relaxed = 69, + HSA_API_ID_hsa_signal_cas_scacquire = 70, + HSA_API_ID_hsa_signal_cas_screlease = 71, + HSA_API_ID_hsa_signal_cas_scacq_screl = 72, + HSA_API_ID_hsa_isa_from_name = 73, + HSA_API_ID_hsa_isa_get_info = 74, + HSA_API_ID_hsa_isa_compatible = 75, + HSA_API_ID_hsa_code_object_serialize = 76, + HSA_API_ID_hsa_code_object_deserialize = 77, + HSA_API_ID_hsa_code_object_destroy = 78, + HSA_API_ID_hsa_code_object_get_info = 79, + HSA_API_ID_hsa_code_object_get_symbol = 80, + HSA_API_ID_hsa_code_symbol_get_info = 81, + HSA_API_ID_hsa_code_object_iterate_symbols = 82, + HSA_API_ID_hsa_executable_create = 83, + HSA_API_ID_hsa_executable_destroy = 84, + HSA_API_ID_hsa_executable_load_code_object = 85, + HSA_API_ID_hsa_executable_freeze = 86, + HSA_API_ID_hsa_executable_get_info = 87, + HSA_API_ID_hsa_executable_global_variable_define = 88, + HSA_API_ID_hsa_executable_agent_global_variable_define = 89, + HSA_API_ID_hsa_executable_readonly_variable_define = 90, + HSA_API_ID_hsa_executable_validate = 91, + HSA_API_ID_hsa_executable_get_symbol = 92, + HSA_API_ID_hsa_executable_symbol_get_info = 93, + HSA_API_ID_hsa_executable_iterate_symbols = 94, + HSA_API_ID_hsa_status_string = 95, + HSA_API_ID_hsa_extension_get_name = 96, + HSA_API_ID_hsa_system_major_extension_supported = 97, + HSA_API_ID_hsa_system_get_major_extension_table = 98, + HSA_API_ID_hsa_agent_major_extension_supported = 99, + HSA_API_ID_hsa_cache_get_info = 100, + HSA_API_ID_hsa_agent_iterate_caches = 101, + HSA_API_ID_hsa_signal_silent_store_relaxed = 102, + HSA_API_ID_hsa_signal_silent_store_screlease = 103, + HSA_API_ID_hsa_signal_group_create = 104, + HSA_API_ID_hsa_signal_group_destroy = 105, + HSA_API_ID_hsa_signal_group_wait_any_scacquire = 106, + HSA_API_ID_hsa_signal_group_wait_any_relaxed = 107, + HSA_API_ID_hsa_agent_iterate_isas = 108, + HSA_API_ID_hsa_isa_get_info_alt = 109, + HSA_API_ID_hsa_isa_get_exception_policies = 110, + HSA_API_ID_hsa_isa_get_round_method = 111, + HSA_API_ID_hsa_wavefront_get_info = 112, + HSA_API_ID_hsa_isa_iterate_wavefronts = 113, + HSA_API_ID_hsa_code_object_get_symbol_from_name = 114, + HSA_API_ID_hsa_code_object_reader_create_from_file = 115, + HSA_API_ID_hsa_code_object_reader_create_from_memory = 116, + HSA_API_ID_hsa_code_object_reader_destroy = 117, + HSA_API_ID_hsa_executable_create_alt = 118, + HSA_API_ID_hsa_executable_load_program_code_object = 119, + HSA_API_ID_hsa_executable_load_agent_code_object = 120, + HSA_API_ID_hsa_executable_validate_alt = 121, + HSA_API_ID_hsa_executable_get_symbol_by_name = 122, + HSA_API_ID_hsa_executable_iterate_agent_symbols = 123, + HSA_API_ID_hsa_executable_iterate_program_symbols = 124, + + /* block: AmdExt API */ + HSA_API_ID_hsa_amd_coherency_get_type = 125, + HSA_API_ID_hsa_amd_coherency_set_type = 126, + HSA_API_ID_hsa_amd_profiling_set_profiler_enabled = 127, + HSA_API_ID_hsa_amd_profiling_async_copy_enable = 128, + HSA_API_ID_hsa_amd_profiling_get_dispatch_time = 129, + HSA_API_ID_hsa_amd_profiling_get_async_copy_time = 130, + HSA_API_ID_hsa_amd_profiling_convert_tick_to_system_domain = 131, + HSA_API_ID_hsa_amd_signal_async_handler = 132, + HSA_API_ID_hsa_amd_async_function = 133, + HSA_API_ID_hsa_amd_signal_wait_any = 134, + HSA_API_ID_hsa_amd_queue_cu_set_mask = 135, + HSA_API_ID_hsa_amd_memory_pool_get_info = 136, + HSA_API_ID_hsa_amd_agent_iterate_memory_pools = 137, + HSA_API_ID_hsa_amd_memory_pool_allocate = 138, + HSA_API_ID_hsa_amd_memory_pool_free = 139, + HSA_API_ID_hsa_amd_memory_async_copy = 140, + HSA_API_ID_hsa_amd_memory_async_copy_on_engine = 141, + HSA_API_ID_hsa_amd_memory_copy_engine_status = 142, + HSA_API_ID_hsa_amd_agent_memory_pool_get_info = 143, + HSA_API_ID_hsa_amd_agents_allow_access = 144, + HSA_API_ID_hsa_amd_memory_pool_can_migrate = 145, + HSA_API_ID_hsa_amd_memory_migrate = 146, + HSA_API_ID_hsa_amd_memory_lock = 147, + HSA_API_ID_hsa_amd_memory_unlock = 148, + HSA_API_ID_hsa_amd_memory_fill = 149, + HSA_API_ID_hsa_amd_interop_map_buffer = 150, + HSA_API_ID_hsa_amd_interop_unmap_buffer = 151, + HSA_API_ID_hsa_amd_image_create = 152, + HSA_API_ID_hsa_amd_pointer_info = 153, + HSA_API_ID_hsa_amd_pointer_info_set_userdata = 154, + HSA_API_ID_hsa_amd_ipc_memory_create = 155, + HSA_API_ID_hsa_amd_ipc_memory_attach = 156, + HSA_API_ID_hsa_amd_ipc_memory_detach = 157, + HSA_API_ID_hsa_amd_signal_create = 158, + HSA_API_ID_hsa_amd_ipc_signal_create = 159, + HSA_API_ID_hsa_amd_ipc_signal_attach = 160, + HSA_API_ID_hsa_amd_register_system_event_handler = 161, + HSA_API_ID_hsa_amd_queue_intercept_create = 162, + HSA_API_ID_hsa_amd_queue_intercept_register = 163, + HSA_API_ID_hsa_amd_queue_set_priority = 164, + HSA_API_ID_hsa_amd_memory_async_copy_rect = 165, + HSA_API_ID_hsa_amd_runtime_queue_create_register = 166, + HSA_API_ID_hsa_amd_memory_lock_to_pool = 167, + HSA_API_ID_hsa_amd_register_deallocation_callback = 168, + HSA_API_ID_hsa_amd_deregister_deallocation_callback = 169, + HSA_API_ID_hsa_amd_signal_value_pointer = 170, + HSA_API_ID_hsa_amd_svm_attributes_set = 171, + HSA_API_ID_hsa_amd_svm_attributes_get = 172, + HSA_API_ID_hsa_amd_svm_prefetch_async = 173, + HSA_API_ID_hsa_amd_spm_acquire = 174, + HSA_API_ID_hsa_amd_spm_release = 175, + HSA_API_ID_hsa_amd_spm_set_dest_buffer = 176, + HSA_API_ID_hsa_amd_queue_cu_get_mask = 177, + HSA_API_ID_hsa_amd_portable_export_dmabuf = 178, + HSA_API_ID_hsa_amd_portable_close_dmabuf = 179, + HSA_API_ID_hsa_amd_vmem_address_reserve = 180, + HSA_API_ID_hsa_amd_vmem_address_free = 181, + HSA_API_ID_hsa_amd_vmem_handle_create = 182, + HSA_API_ID_hsa_amd_vmem_handle_release = 183, + HSA_API_ID_hsa_amd_vmem_map = 184, + HSA_API_ID_hsa_amd_vmem_unmap = 185, + HSA_API_ID_hsa_amd_vmem_set_access = 186, + HSA_API_ID_hsa_amd_vmem_get_access = 187, + HSA_API_ID_hsa_amd_vmem_export_shareable_handle = 188, + HSA_API_ID_hsa_amd_vmem_import_shareable_handle = 189, + HSA_API_ID_hsa_amd_vmem_retain_alloc_handle = 190, + HSA_API_ID_hsa_amd_vmem_get_alloc_properties_from_handle = 191, + HSA_API_ID_hsa_amd_agent_set_async_scratch_limit = 192, + HSA_API_ID_hsa_amd_queue_get_info = 193, + HSA_API_ID_hsa_amd_vmem_address_reserve_align = 194, + + /* block: ImageExt API */ + HSA_API_ID_hsa_ext_image_get_capability = 195, + HSA_API_ID_hsa_ext_image_data_get_info = 196, + HSA_API_ID_hsa_ext_image_create = 197, + HSA_API_ID_hsa_ext_image_import = 198, + HSA_API_ID_hsa_ext_image_export = 199, + HSA_API_ID_hsa_ext_image_copy = 200, + HSA_API_ID_hsa_ext_image_clear = 201, + HSA_API_ID_hsa_ext_image_destroy = 202, + HSA_API_ID_hsa_ext_sampler_create = 203, + HSA_API_ID_hsa_ext_sampler_destroy = 204, + HSA_API_ID_hsa_ext_image_get_capability_with_layout = 205, + HSA_API_ID_hsa_ext_image_data_get_info_with_layout = 206, + HSA_API_ID_hsa_ext_image_create_with_layout = 207, + + HSA_API_ID_DISPATCH = 208, + HSA_API_ID_NUMBER = 209, +}; +/* Declarations of APIs intended for use only by tools. */ +typedef void (*hsa_amd_queue_intercept_packet_writer)(const void *, uint64_t); +typedef void (*hsa_amd_queue_intercept_handler)( + const void *, uint64_t, uint64_t, void *, + hsa_amd_queue_intercept_packet_writer); +typedef void (*hsa_amd_runtime_queue_notifier)(const hsa_queue_t *, hsa_agent_t, + void *); + +/* section: API arg structure */ + +struct hsa_api_data_t { + uint64_t correlation_id; + uint32_t phase; + union { + uint64_t uint64_t_retval; + hsa_status_t hsa_status_t_retval; + hsa_signal_value_t hsa_signal_value_t_retval; + uint32_t uint32_t_retval; + }; + union { + /* block: CoreApi API */ + struct { + } hsa_init; + struct { + } hsa_shut_down; + struct { + hsa_system_info_t attribute; + void *value; + } hsa_system_get_info; + struct { + uint16_t extension; + uint16_t version_major; + uint16_t version_minor; + bool *result; + } hsa_system_extension_supported; + struct { + uint16_t extension; + uint16_t version_major; + uint16_t version_minor; + void *table; + } hsa_system_get_extension_table; + struct { + hsa_status_t (*callback)(hsa_agent_t agent, void *data); + void *data; + } hsa_iterate_agents; + struct { + hsa_agent_t agent; + hsa_agent_info_t attribute; + void *value; + } hsa_agent_get_info; + struct { + hsa_agent_t agent; + uint32_t size; + hsa_queue_type32_t type; + void (*callback)(hsa_status_t status, hsa_queue_t *source, void *data); + void *data; + uint32_t private_segment_size; + uint32_t group_segment_size; + hsa_queue_t **queue; + } hsa_queue_create; + struct { + hsa_region_t region; + uint32_t size; + hsa_queue_type32_t type; + uint32_t features; + hsa_signal_t doorbell_signal; + hsa_queue_t **queue; + } hsa_soft_queue_create; + struct { + hsa_queue_t *queue; + } hsa_queue_destroy; + struct { + hsa_queue_t *queue; + } hsa_queue_inactivate; + struct { + const hsa_queue_t *queue; + } hsa_queue_load_read_index_scacquire; + struct { + const hsa_queue_t *queue; + } hsa_queue_load_read_index_relaxed; + struct { + const hsa_queue_t *queue; + } hsa_queue_load_write_index_scacquire; + struct { + const hsa_queue_t *queue; + } hsa_queue_load_write_index_relaxed; + struct { + const hsa_queue_t *queue; + uint64_t value; + } hsa_queue_store_write_index_relaxed; + struct { + const hsa_queue_t *queue; + uint64_t value; + } hsa_queue_store_write_index_screlease; + struct { + const hsa_queue_t *queue; + uint64_t expected; + uint64_t value; + } hsa_queue_cas_write_index_scacq_screl; + struct { + const hsa_queue_t *queue; + uint64_t expected; + uint64_t value; + } hsa_queue_cas_write_index_scacquire; + struct { + const hsa_queue_t *queue; + uint64_t expected; + uint64_t value; + } hsa_queue_cas_write_index_relaxed; + struct { + const hsa_queue_t *queue; + uint64_t expected; + uint64_t value; + } hsa_queue_cas_write_index_screlease; + struct { + const hsa_queue_t *queue; + uint64_t value; + } hsa_queue_add_write_index_scacq_screl; + struct { + const hsa_queue_t *queue; + uint64_t value; + } hsa_queue_add_write_index_scacquire; + struct { + const hsa_queue_t *queue; + uint64_t value; + } hsa_queue_add_write_index_relaxed; + struct { + const hsa_queue_t *queue; + uint64_t value; + } hsa_queue_add_write_index_screlease; + struct { + const hsa_queue_t *queue; + uint64_t value; + } hsa_queue_store_read_index_relaxed; + struct { + const hsa_queue_t *queue; + uint64_t value; + } hsa_queue_store_read_index_screlease; + struct { + hsa_agent_t agent; + hsa_status_t (*callback)(hsa_region_t region, void *data); + void *data; + } hsa_agent_iterate_regions; + struct { + hsa_region_t region; + hsa_region_info_t attribute; + void *value; + } hsa_region_get_info; + struct { + hsa_agent_t agent; + hsa_profile_t profile; + uint16_t *mask; + } hsa_agent_get_exception_policies; + struct { + uint16_t extension; + hsa_agent_t agent; + uint16_t version_major; + uint16_t version_minor; + bool *result; + } hsa_agent_extension_supported; + struct { + void *ptr; + size_t size; + } hsa_memory_register; + struct { + void *ptr; + size_t size; + } hsa_memory_deregister; + struct { + hsa_region_t region; + size_t size; + void **ptr; + } hsa_memory_allocate; + struct { + void *ptr; + } hsa_memory_free; + struct { + void *dst; + const void *src; + size_t size; + } hsa_memory_copy; + struct { + void *ptr; + hsa_agent_t agent; + hsa_access_permission_t access; + } hsa_memory_assign_agent; + struct { + hsa_signal_value_t initial_value; + uint32_t num_consumers; + const hsa_agent_t *consumers; + hsa_signal_t *signal; + } hsa_signal_create; + struct { + hsa_signal_t signal; + } hsa_signal_destroy; + struct { + hsa_signal_t signal; + } hsa_signal_load_relaxed; + struct { + hsa_signal_t signal; + } hsa_signal_load_scacquire; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_store_relaxed; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_store_screlease; + struct { + hsa_signal_t signal; + hsa_signal_condition_t condition; + hsa_signal_value_t compare_value; + uint64_t timeout_hint; + hsa_wait_state_t wait_state_hint; + } hsa_signal_wait_relaxed; + struct { + hsa_signal_t signal; + hsa_signal_condition_t condition; + hsa_signal_value_t compare_value; + uint64_t timeout_hint; + hsa_wait_state_t wait_state_hint; + } hsa_signal_wait_scacquire; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_and_relaxed; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_and_scacquire; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_and_screlease; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_and_scacq_screl; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_or_relaxed; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_or_scacquire; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_or_screlease; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_or_scacq_screl; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_xor_relaxed; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_xor_scacquire; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_xor_screlease; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_xor_scacq_screl; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_exchange_relaxed; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_exchange_scacquire; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_exchange_screlease; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_exchange_scacq_screl; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_add_relaxed; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_add_scacquire; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_add_screlease; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_add_scacq_screl; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_subtract_relaxed; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_subtract_scacquire; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_subtract_screlease; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_subtract_scacq_screl; + struct { + hsa_signal_t signal; + hsa_signal_value_t expected; + hsa_signal_value_t value; + } hsa_signal_cas_relaxed; + struct { + hsa_signal_t signal; + hsa_signal_value_t expected; + hsa_signal_value_t value; + } hsa_signal_cas_scacquire; + struct { + hsa_signal_t signal; + hsa_signal_value_t expected; + hsa_signal_value_t value; + } hsa_signal_cas_screlease; + struct { + hsa_signal_t signal; + hsa_signal_value_t expected; + hsa_signal_value_t value; + } hsa_signal_cas_scacq_screl; + struct { + const char *name; + hsa_isa_t *isa; + } hsa_isa_from_name; + struct { + hsa_isa_t isa; + hsa_isa_info_t attribute; + uint32_t index; + void *value; + } hsa_isa_get_info; + struct { + hsa_isa_t code_object_isa; + hsa_isa_t agent_isa; + bool *result; + } hsa_isa_compatible; + struct { + hsa_code_object_t code_object; + hsa_status_t (*alloc_callback)(size_t size, hsa_callback_data_t data, + void **address); + hsa_callback_data_t callback_data; + const char *options; + void **serialized_code_object; + size_t *serialized_code_object_size; + } hsa_code_object_serialize; + struct { + void *serialized_code_object; + size_t serialized_code_object_size; + const char *options; + hsa_code_object_t *code_object; + } hsa_code_object_deserialize; + struct { + hsa_code_object_t code_object; + } hsa_code_object_destroy; + struct { + hsa_code_object_t code_object; + hsa_code_object_info_t attribute; + void *value; + } hsa_code_object_get_info; + struct { + hsa_code_object_t code_object; + const char *symbol_name; + hsa_code_symbol_t *symbol; + } hsa_code_object_get_symbol; + struct { + hsa_code_symbol_t code_symbol; + hsa_code_symbol_info_t attribute; + void *value; + } hsa_code_symbol_get_info; + struct { + hsa_code_object_t code_object; + hsa_status_t (*callback)(hsa_code_object_t code_object, + hsa_code_symbol_t symbol, void *data); + void *data; + } hsa_code_object_iterate_symbols; + struct { + hsa_profile_t profile; + hsa_executable_state_t executable_state; + const char *options; + hsa_executable_t *executable; + } hsa_executable_create; + struct { + hsa_executable_t executable; + } hsa_executable_destroy; + struct { + hsa_executable_t executable; + hsa_agent_t agent; + hsa_code_object_t code_object; + const char *options; + } hsa_executable_load_code_object; + struct { + hsa_executable_t executable; + const char *options; + } hsa_executable_freeze; + struct { + hsa_executable_t executable; + hsa_executable_info_t attribute; + void *value; + } hsa_executable_get_info; + struct { + hsa_executable_t executable; + const char *variable_name; + void *address; + } hsa_executable_global_variable_define; + struct { + hsa_executable_t executable; + hsa_agent_t agent; + const char *variable_name; + void *address; + } hsa_executable_agent_global_variable_define; + struct { + hsa_executable_t executable; + hsa_agent_t agent; + const char *variable_name; + void *address; + } hsa_executable_readonly_variable_define; + struct { + hsa_executable_t executable; + uint32_t *result; + } hsa_executable_validate; + struct { + hsa_executable_t executable; + const char *module_name; + const char *symbol_name; + hsa_agent_t agent; + int32_t call_convention; + hsa_executable_symbol_t *symbol; + } hsa_executable_get_symbol; + struct { + hsa_executable_symbol_t executable_symbol; + hsa_executable_symbol_info_t attribute; + void *value; + } hsa_executable_symbol_get_info; + struct { + hsa_executable_t executable; + hsa_status_t (*callback)(hsa_executable_t exec, + hsa_executable_symbol_t symbol, void *data); + void *data; + } hsa_executable_iterate_symbols; + struct { + hsa_status_t status; + const char **status_string; + } hsa_status_string; + struct { + uint16_t extension; + const char **name; + } hsa_extension_get_name; + struct { + uint16_t extension; + uint16_t version_major; + uint16_t *version_minor; + bool *result; + } hsa_system_major_extension_supported; + struct { + uint16_t extension; + uint16_t version_major; + size_t table_length; + void *table; + } hsa_system_get_major_extension_table; + struct { + uint16_t extension; + hsa_agent_t agent; + uint16_t version_major; + uint16_t *version_minor; + bool *result; + } hsa_agent_major_extension_supported; + struct { + hsa_cache_t cache; + hsa_cache_info_t attribute; + void *value; + } hsa_cache_get_info; + struct { + hsa_agent_t agent; + hsa_status_t (*callback)(hsa_cache_t cache, void *data); + void *data; + } hsa_agent_iterate_caches; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_silent_store_relaxed; + struct { + hsa_signal_t signal; + hsa_signal_value_t value; + } hsa_signal_silent_store_screlease; + struct { + uint32_t num_signals; + const hsa_signal_t *signals; + uint32_t num_consumers; + const hsa_agent_t *consumers; + hsa_signal_group_t *signal_group; + } hsa_signal_group_create; + struct { + hsa_signal_group_t signal_group; + } hsa_signal_group_destroy; + struct { + hsa_signal_group_t signal_group; + const hsa_signal_condition_t *conditions; + const hsa_signal_value_t *compare_values; + hsa_wait_state_t wait_state_hint; + hsa_signal_t *signal; + hsa_signal_value_t *value; + } hsa_signal_group_wait_any_scacquire; + struct { + hsa_signal_group_t signal_group; + const hsa_signal_condition_t *conditions; + const hsa_signal_value_t *compare_values; + hsa_wait_state_t wait_state_hint; + hsa_signal_t *signal; + hsa_signal_value_t *value; + } hsa_signal_group_wait_any_relaxed; + struct { + hsa_agent_t agent; + hsa_status_t (*callback)(hsa_isa_t isa, void *data); + void *data; + } hsa_agent_iterate_isas; + struct { + hsa_isa_t isa; + hsa_isa_info_t attribute; + void *value; + } hsa_isa_get_info_alt; + struct { + hsa_isa_t isa; + hsa_profile_t profile; + uint16_t *mask; + } hsa_isa_get_exception_policies; + struct { + hsa_isa_t isa; + hsa_fp_type_t fp_type; + hsa_flush_mode_t flush_mode; + hsa_round_method_t *round_method; + } hsa_isa_get_round_method; + struct { + hsa_wavefront_t wavefront; + hsa_wavefront_info_t attribute; + void *value; + } hsa_wavefront_get_info; + struct { + hsa_isa_t isa; + hsa_status_t (*callback)(hsa_wavefront_t wavefront, void *data); + void *data; + } hsa_isa_iterate_wavefronts; + struct { + hsa_code_object_t code_object; + const char *module_name; + const char *symbol_name; + hsa_code_symbol_t *symbol; + } hsa_code_object_get_symbol_from_name; + struct { + hsa_file_t file; + hsa_code_object_reader_t *code_object_reader; + } hsa_code_object_reader_create_from_file; + struct { + const void *code_object; + size_t size; + hsa_code_object_reader_t *code_object_reader; + } hsa_code_object_reader_create_from_memory; + struct { + hsa_code_object_reader_t code_object_reader; + } hsa_code_object_reader_destroy; + struct { + hsa_profile_t profile; + hsa_default_float_rounding_mode_t default_float_rounding_mode; + const char *options; + hsa_executable_t *executable; + } hsa_executable_create_alt; + struct { + hsa_executable_t executable; + hsa_code_object_reader_t code_object_reader; + const char *options; + hsa_loaded_code_object_t *loaded_code_object; + } hsa_executable_load_program_code_object; + struct { + hsa_executable_t executable; + hsa_agent_t agent; + hsa_code_object_reader_t code_object_reader; + const char *options; + hsa_loaded_code_object_t *loaded_code_object; + } hsa_executable_load_agent_code_object; + struct { + hsa_executable_t executable; + const char *options; + uint32_t *result; + } hsa_executable_validate_alt; + struct { + hsa_executable_t executable; + const char *symbol_name; + const hsa_agent_t *agent; + hsa_executable_symbol_t *symbol; + } hsa_executable_get_symbol_by_name; + struct { + hsa_executable_t executable; + hsa_agent_t agent; + hsa_status_t (*callback)(hsa_executable_t exec, hsa_agent_t agent, + hsa_executable_symbol_t symbol, void *data); + void *data; + } hsa_executable_iterate_agent_symbols; + struct { + hsa_executable_t executable; + hsa_status_t (*callback)(hsa_executable_t exec, + hsa_executable_symbol_t symbol, void *data); + void *data; + } hsa_executable_iterate_program_symbols; + + /* block: AmdExt API */ + struct { + hsa_agent_t agent; + hsa_amd_coherency_type_t *type; + } hsa_amd_coherency_get_type; + struct { + hsa_agent_t agent; + hsa_amd_coherency_type_t type; + } hsa_amd_coherency_set_type; + struct { + hsa_queue_t *queue; + int enable; + } hsa_amd_profiling_set_profiler_enabled; + struct { + bool enable; + } hsa_amd_profiling_async_copy_enable; + struct { + hsa_agent_t agent; + hsa_signal_t signal; + hsa_amd_profiling_dispatch_time_t *time; + } hsa_amd_profiling_get_dispatch_time; + struct { + hsa_signal_t signal; + hsa_amd_profiling_async_copy_time_t *time; + } hsa_amd_profiling_get_async_copy_time; + struct { + hsa_agent_t agent; + uint64_t agent_tick; + uint64_t *system_tick; + } hsa_amd_profiling_convert_tick_to_system_domain; + struct { + hsa_signal_t signal; + hsa_signal_condition_t cond; + hsa_signal_value_t value; + hsa_amd_signal_handler handler; + void *arg; + } hsa_amd_signal_async_handler; + struct { + void (*callback)(void *arg); + void *arg; + } hsa_amd_async_function; + struct { + uint32_t signal_count; + hsa_signal_t *signals; + hsa_signal_condition_t *conds; + hsa_signal_value_t *values; + uint64_t timeout_hint; + hsa_wait_state_t wait_hint; + hsa_signal_value_t *satisfying_value; + } hsa_amd_signal_wait_any; + struct { + const hsa_queue_t *queue; + uint32_t num_cu_mask_count; + const uint32_t *cu_mask; + } hsa_amd_queue_cu_set_mask; + struct { + hsa_amd_memory_pool_t memory_pool; + hsa_amd_memory_pool_info_t attribute; + void *value; + } hsa_amd_memory_pool_get_info; + struct { + hsa_agent_t agent; + hsa_status_t (*callback)(hsa_amd_memory_pool_t memory_pool, void *data); + void *data; + } hsa_amd_agent_iterate_memory_pools; + struct { + hsa_amd_memory_pool_t memory_pool; + size_t size; + uint32_t flags; + void **ptr; + } hsa_amd_memory_pool_allocate; + struct { + void *ptr; + } hsa_amd_memory_pool_free; + struct { + void *dst; + hsa_agent_t dst_agent; + const void *src; + hsa_agent_t src_agent; + size_t size; + uint32_t num_dep_signals; + const hsa_signal_t *dep_signals; + hsa_signal_t completion_signal; + } hsa_amd_memory_async_copy; + struct { + void *dst; + hsa_agent_t dst_agent; + const void *src; + hsa_agent_t src_agent; + size_t size; + uint32_t num_dep_signals; + const hsa_signal_t *dep_signals; + hsa_signal_t completion_signal; + hsa_amd_sdma_engine_id_t engine_id; + bool force_copy_on_sdma; + } hsa_amd_memory_async_copy_on_engine; + struct { + hsa_agent_t dst_agent; + hsa_agent_t src_agent; + uint32_t *engine_ids_mask; + } hsa_amd_memory_copy_engine_status; + struct { + hsa_agent_t agent; + hsa_amd_memory_pool_t memory_pool; + hsa_amd_agent_memory_pool_info_t attribute; + void *value; + } hsa_amd_agent_memory_pool_get_info; + struct { + uint32_t num_agents; + const hsa_agent_t *agents; + const uint32_t *flags; + const void *ptr; + } hsa_amd_agents_allow_access; + struct { + hsa_amd_memory_pool_t src_memory_pool; + hsa_amd_memory_pool_t dst_memory_pool; + bool *result; + } hsa_amd_memory_pool_can_migrate; + struct { + const void *ptr; + hsa_amd_memory_pool_t memory_pool; + uint32_t flags; + } hsa_amd_memory_migrate; + struct { + void *host_ptr; + size_t size; + hsa_agent_t *agents; + int num_agent; + void **agent_ptr; + } hsa_amd_memory_lock; + struct { + void *host_ptr; + } hsa_amd_memory_unlock; + struct { + void *ptr; + uint32_t value; + size_t count; + } hsa_amd_memory_fill; + struct { + uint32_t num_agents; + hsa_agent_t *agents; + int interop_handle; + uint32_t flags; + size_t *size; + void **ptr; + size_t *metadata_size; + const void **metadata; + } hsa_amd_interop_map_buffer; + struct { + void *ptr; + } hsa_amd_interop_unmap_buffer; + struct { + hsa_agent_t agent; + const hsa_ext_image_descriptor_t *image_descriptor; + const hsa_amd_image_descriptor_t *image_layout; + const void *image_data; + hsa_access_permission_t access_permission; + hsa_ext_image_t *image; + } hsa_amd_image_create; + struct { + const void *ptr; + hsa_amd_pointer_info_t *info; + void *(*alloc)(size_t); + uint32_t *num_agents_accessible; + hsa_agent_t **accessible; + } hsa_amd_pointer_info; + struct { + const void *ptr; + void *userdata; + } hsa_amd_pointer_info_set_userdata; + struct { + void *ptr; + size_t len; + hsa_amd_ipc_memory_t *handle; + } hsa_amd_ipc_memory_create; + struct { + const hsa_amd_ipc_memory_t *handle; + size_t len; + uint32_t num_agents; + const hsa_agent_t *mapping_agents; + void **mapped_ptr; + } hsa_amd_ipc_memory_attach; + struct { + void *mapped_ptr; + } hsa_amd_ipc_memory_detach; + struct { + hsa_signal_value_t initial_value; + uint32_t num_consumers; + const hsa_agent_t *consumers; + uint64_t attributes; + hsa_signal_t *signal; + } hsa_amd_signal_create; + struct { + hsa_signal_t signal; + hsa_amd_ipc_signal_t *handle; + } hsa_amd_ipc_signal_create; + struct { + const hsa_amd_ipc_signal_t *handle; + hsa_signal_t *signal; + } hsa_amd_ipc_signal_attach; + struct { + hsa_amd_system_event_callback_t callback; + void *data; + } hsa_amd_register_system_event_handler; + struct { + hsa_agent_t agent_handle; + uint32_t size; + hsa_queue_type32_t type; + void (*callback)(hsa_status_t status, hsa_queue_t *source, void *data); + void *data; + uint32_t private_segment_size; + uint32_t group_segment_size; + hsa_queue_t **queue; + } hsa_amd_queue_intercept_create; + struct { + hsa_queue_t *queue; + hsa_amd_queue_intercept_handler callback; + void *user_data; + } hsa_amd_queue_intercept_register; + struct { + hsa_queue_t *queue; + hsa_amd_queue_priority_t priority; + } hsa_amd_queue_set_priority; + struct { + const hsa_pitched_ptr_t *dst; + const hsa_dim3_t *dst_offset; + const hsa_pitched_ptr_t *src; + const hsa_dim3_t *src_offset; + const hsa_dim3_t *range; + hsa_dim3_t range__val; + hsa_agent_t copy_agent; + hsa_amd_copy_direction_t dir; + uint32_t num_dep_signals; + const hsa_signal_t *dep_signals; + hsa_signal_t completion_signal; + } hsa_amd_memory_async_copy_rect; + struct { + hsa_amd_runtime_queue_notifier callback; + void *user_data; + } hsa_amd_runtime_queue_create_register; + struct { + void *host_ptr; + size_t size; + hsa_agent_t *agents; + int num_agent; + hsa_amd_memory_pool_t pool; + uint32_t flags; + void **agent_ptr; + } hsa_amd_memory_lock_to_pool; + struct { + void *ptr; + hsa_amd_deallocation_callback_t callback; + void *user_data; + } hsa_amd_register_deallocation_callback; + struct { + void *ptr; + hsa_amd_deallocation_callback_t callback; + } hsa_amd_deregister_deallocation_callback; + struct { + hsa_signal_t signal; + volatile hsa_signal_value_t **value_ptr; + } hsa_amd_signal_value_pointer; + struct { + void *ptr; + size_t size; + hsa_amd_svm_attribute_pair_t *attribute_list; + size_t attribute_count; + } hsa_amd_svm_attributes_set; + struct { + void *ptr; + size_t size; + hsa_amd_svm_attribute_pair_t *attribute_list; + size_t attribute_count; + } hsa_amd_svm_attributes_get; + struct { + void *ptr; + size_t size; + hsa_agent_t agent; + uint32_t num_dep_signals; + const hsa_signal_t *dep_signals; + hsa_signal_t completion_signal; + } hsa_amd_svm_prefetch_async; + struct { + hsa_agent_t preferred_agent; + } hsa_amd_spm_acquire; + struct { + hsa_agent_t preferred_agent; + } hsa_amd_spm_release; + struct { + hsa_agent_t preferred_agent; + size_t size_in_bytes; + uint32_t *timeout; + uint32_t *size_copied; + void *dest; + bool *is_data_loss; + } hsa_amd_spm_set_dest_buffer; + struct { + const hsa_queue_t *queue; + uint32_t num_cu_mask_count; + uint32_t *cu_mask; + } hsa_amd_queue_cu_get_mask; + struct { + const void *ptr; + size_t size; + int *dmabuf; + uint64_t *offset; + } hsa_amd_portable_export_dmabuf; + struct { + int dmabuf; + } hsa_amd_portable_close_dmabuf; + struct { + void **va; + size_t size; + uint64_t address; + uint64_t flags; + } hsa_amd_vmem_address_reserve; + struct { + void *va; + size_t size; + } hsa_amd_vmem_address_free; + struct { + hsa_amd_memory_pool_t pool; + size_t size; + hsa_amd_memory_type_t type; + uint64_t flags; + hsa_amd_vmem_alloc_handle_t *memory_handle; + } hsa_amd_vmem_handle_create; + struct { + hsa_amd_vmem_alloc_handle_t memory_handle; + } hsa_amd_vmem_handle_release; + struct { + void *va; + size_t size; + size_t in_offset; + hsa_amd_vmem_alloc_handle_t memory_handle; + uint64_t flags; + } hsa_amd_vmem_map; + struct { + void *va; + size_t size; + } hsa_amd_vmem_unmap; + struct { + void *va; + size_t size; + const hsa_amd_memory_access_desc_t *desc; + size_t desc_cnt; + } hsa_amd_vmem_set_access; + struct { + void *va; + hsa_access_permission_t *perms; + hsa_agent_t agent_handle; + } hsa_amd_vmem_get_access; + struct { + int *dmabuf_fd; + hsa_amd_vmem_alloc_handle_t handle; + uint64_t flags; + } hsa_amd_vmem_export_shareable_handle; + struct { + int dmabuf_fd; + hsa_amd_vmem_alloc_handle_t *handle; + } hsa_amd_vmem_import_shareable_handle; + struct { + hsa_amd_vmem_alloc_handle_t *memory_handle; + void *addr; + } hsa_amd_vmem_retain_alloc_handle; + struct { + hsa_amd_vmem_alloc_handle_t memory_handle; + hsa_amd_memory_pool_t *pool; + hsa_amd_memory_type_t *type; + } hsa_amd_vmem_get_alloc_properties_from_handle; + struct { + hsa_agent_t agent; + size_t threshold; + } hsa_amd_agent_set_async_scratch_limit; + struct { + hsa_queue_t *queue; + hsa_queue_info_attribute_t attribute; + void *value; + } hsa_amd_queue_get_info; + struct { + void **va; + size_t size; + uint64_t address; + uint64_t alignment; + uint64_t flags; + } hsa_amd_vmem_address_reserve_align; + + /* block: ImageExt API */ + struct { + hsa_agent_t agent; + hsa_ext_image_geometry_t geometry; + const hsa_ext_image_format_t *image_format; + uint32_t *capability_mask; + } hsa_ext_image_get_capability; + struct { + hsa_agent_t agent; + const hsa_ext_image_descriptor_t *image_descriptor; + hsa_access_permission_t access_permission; + hsa_ext_image_data_info_t *image_data_info; + } hsa_ext_image_data_get_info; + struct { + hsa_agent_t agent; + const hsa_ext_image_descriptor_t *image_descriptor; + const void *image_data; + hsa_access_permission_t access_permission; + hsa_ext_image_t *image; + } hsa_ext_image_create; + struct { + hsa_agent_t agent; + const void *src_memory; + size_t src_row_pitch; + size_t src_slice_pitch; + hsa_ext_image_t dst_image; + const hsa_ext_image_region_t *image_region; + } hsa_ext_image_import; + struct { + hsa_agent_t agent; + hsa_ext_image_t src_image; + void *dst_memory; + size_t dst_row_pitch; + size_t dst_slice_pitch; + const hsa_ext_image_region_t *image_region; + } hsa_ext_image_export; + struct { + hsa_agent_t agent; + hsa_ext_image_t src_image; + const hsa_dim3_t *src_offset; + hsa_ext_image_t dst_image; + const hsa_dim3_t *dst_offset; + const hsa_dim3_t *range; + } hsa_ext_image_copy; + struct { + hsa_agent_t agent; + hsa_ext_image_t image; + const void *data; + const hsa_ext_image_region_t *image_region; + } hsa_ext_image_clear; + struct { + hsa_agent_t agent; + hsa_ext_image_t image; + } hsa_ext_image_destroy; + struct { + hsa_agent_t agent; + const hsa_ext_sampler_descriptor_t *sampler_descriptor; + hsa_ext_sampler_t *sampler; + } hsa_ext_sampler_create; + struct { + hsa_agent_t agent; + hsa_ext_sampler_t sampler; + } hsa_ext_sampler_destroy; + struct { + hsa_agent_t agent; + hsa_ext_image_geometry_t geometry; + const hsa_ext_image_format_t *image_format; + hsa_ext_image_data_layout_t image_data_layout; + uint32_t *capability_mask; + } hsa_ext_image_get_capability_with_layout; + struct { + hsa_agent_t agent; + const hsa_ext_image_descriptor_t *image_descriptor; + hsa_access_permission_t access_permission; + hsa_ext_image_data_layout_t image_data_layout; + size_t image_data_row_pitch; + size_t image_data_slice_pitch; + hsa_ext_image_data_info_t *image_data_info; + } hsa_ext_image_data_get_info_with_layout; + struct { + hsa_agent_t agent; + const hsa_ext_image_descriptor_t *image_descriptor; + const void *image_data; + hsa_access_permission_t access_permission; + hsa_ext_image_data_layout_t image_data_layout; + size_t image_data_row_pitch; + size_t image_data_slice_pitch; + hsa_ext_image_t *image; + } hsa_ext_image_create_with_layout; + } args; + uint64_t *phase_data; +}; + +/* section: API output stream */ + +#ifdef __cplusplus +#include "hsa_ostream_ops.h" +typedef std::pair hsa_api_data_pair_t; +inline std::ostream &operator<<(std::ostream &out, + const hsa_api_data_pair_t &data_pair) { + const uint32_t cid = data_pair.first; + const hsa_api_data_t &api_data = data_pair.second; + switch (cid) { + /* block: CoreApi API */ + case HSA_API_ID_hsa_init: { + out << "hsa_init("; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_shut_down: { + out << "hsa_shut_down("; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_system_get_info: { + out << "hsa_system_get_info("; + out << api_data.args.hsa_system_get_info.attribute << ", "; + out << api_data.args.hsa_system_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_system_extension_supported: { + out << "hsa_system_extension_supported("; + out << api_data.args.hsa_system_extension_supported.extension << ", "; + out << api_data.args.hsa_system_extension_supported.version_major << ", "; + out << api_data.args.hsa_system_extension_supported.version_minor << ", "; + out << api_data.args.hsa_system_extension_supported.result; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_system_get_extension_table: { + out << "hsa_system_get_extension_table("; + out << api_data.args.hsa_system_get_extension_table.extension << ", "; + out << api_data.args.hsa_system_get_extension_table.version_major << ", "; + out << api_data.args.hsa_system_get_extension_table.version_minor << ", "; + out << api_data.args.hsa_system_get_extension_table.table; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_iterate_agents: { + out << "hsa_iterate_agents("; + out << api_data.args.hsa_iterate_agents.callback << ", "; + out << api_data.args.hsa_iterate_agents.data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_agent_get_info: { + out << "hsa_agent_get_info("; + out << api_data.args.hsa_agent_get_info.agent << ", "; + out << api_data.args.hsa_agent_get_info.attribute << ", "; + out << api_data.args.hsa_agent_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_queue_create: { + out << "hsa_queue_create("; + out << api_data.args.hsa_queue_create.agent << ", "; + out << api_data.args.hsa_queue_create.size << ", "; + out << api_data.args.hsa_queue_create.type << ", "; + out << api_data.args.hsa_queue_create.callback << ", "; + out << api_data.args.hsa_queue_create.data << ", "; + out << api_data.args.hsa_queue_create.private_segment_size << ", "; + out << api_data.args.hsa_queue_create.group_segment_size << ", "; + out << api_data.args.hsa_queue_create.queue; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_soft_queue_create: { + out << "hsa_soft_queue_create("; + out << api_data.args.hsa_soft_queue_create.region << ", "; + out << api_data.args.hsa_soft_queue_create.size << ", "; + out << api_data.args.hsa_soft_queue_create.type << ", "; + out << api_data.args.hsa_soft_queue_create.features << ", "; + out << api_data.args.hsa_soft_queue_create.doorbell_signal << ", "; + out << api_data.args.hsa_soft_queue_create.queue; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_queue_destroy: { + out << "hsa_queue_destroy("; + out << api_data.args.hsa_queue_destroy.queue; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_queue_inactivate: { + out << "hsa_queue_inactivate("; + out << api_data.args.hsa_queue_inactivate.queue; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_queue_load_read_index_scacquire: { + out << "hsa_queue_load_read_index_scacquire("; + out << api_data.args.hsa_queue_load_read_index_scacquire.queue; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_load_read_index_relaxed: { + out << "hsa_queue_load_read_index_relaxed("; + out << api_data.args.hsa_queue_load_read_index_relaxed.queue; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_load_write_index_scacquire: { + out << "hsa_queue_load_write_index_scacquire("; + out << api_data.args.hsa_queue_load_write_index_scacquire.queue; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_load_write_index_relaxed: { + out << "hsa_queue_load_write_index_relaxed("; + out << api_data.args.hsa_queue_load_write_index_relaxed.queue; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_store_write_index_relaxed: { + out << "hsa_queue_store_write_index_relaxed("; + out << api_data.args.hsa_queue_store_write_index_relaxed.queue << ", "; + out << api_data.args.hsa_queue_store_write_index_relaxed.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_queue_store_write_index_screlease: { + out << "hsa_queue_store_write_index_screlease("; + out << api_data.args.hsa_queue_store_write_index_screlease.queue << ", "; + out << api_data.args.hsa_queue_store_write_index_screlease.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_queue_cas_write_index_scacq_screl: { + out << "hsa_queue_cas_write_index_scacq_screl("; + out << api_data.args.hsa_queue_cas_write_index_scacq_screl.queue << ", "; + out << api_data.args.hsa_queue_cas_write_index_scacq_screl.expected << ", "; + out << api_data.args.hsa_queue_cas_write_index_scacq_screl.value; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_cas_write_index_scacquire: { + out << "hsa_queue_cas_write_index_scacquire("; + out << api_data.args.hsa_queue_cas_write_index_scacquire.queue << ", "; + out << api_data.args.hsa_queue_cas_write_index_scacquire.expected << ", "; + out << api_data.args.hsa_queue_cas_write_index_scacquire.value; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_cas_write_index_relaxed: { + out << "hsa_queue_cas_write_index_relaxed("; + out << api_data.args.hsa_queue_cas_write_index_relaxed.queue << ", "; + out << api_data.args.hsa_queue_cas_write_index_relaxed.expected << ", "; + out << api_data.args.hsa_queue_cas_write_index_relaxed.value; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_cas_write_index_screlease: { + out << "hsa_queue_cas_write_index_screlease("; + out << api_data.args.hsa_queue_cas_write_index_screlease.queue << ", "; + out << api_data.args.hsa_queue_cas_write_index_screlease.expected << ", "; + out << api_data.args.hsa_queue_cas_write_index_screlease.value; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_add_write_index_scacq_screl: { + out << "hsa_queue_add_write_index_scacq_screl("; + out << api_data.args.hsa_queue_add_write_index_scacq_screl.queue << ", "; + out << api_data.args.hsa_queue_add_write_index_scacq_screl.value; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_add_write_index_scacquire: { + out << "hsa_queue_add_write_index_scacquire("; + out << api_data.args.hsa_queue_add_write_index_scacquire.queue << ", "; + out << api_data.args.hsa_queue_add_write_index_scacquire.value; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_add_write_index_relaxed: { + out << "hsa_queue_add_write_index_relaxed("; + out << api_data.args.hsa_queue_add_write_index_relaxed.queue << ", "; + out << api_data.args.hsa_queue_add_write_index_relaxed.value; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_add_write_index_screlease: { + out << "hsa_queue_add_write_index_screlease("; + out << api_data.args.hsa_queue_add_write_index_screlease.queue << ", "; + out << api_data.args.hsa_queue_add_write_index_screlease.value; + out << ") = " << api_data.uint64_t_retval; + break; + } + case HSA_API_ID_hsa_queue_store_read_index_relaxed: { + out << "hsa_queue_store_read_index_relaxed("; + out << api_data.args.hsa_queue_store_read_index_relaxed.queue << ", "; + out << api_data.args.hsa_queue_store_read_index_relaxed.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_queue_store_read_index_screlease: { + out << "hsa_queue_store_read_index_screlease("; + out << api_data.args.hsa_queue_store_read_index_screlease.queue << ", "; + out << api_data.args.hsa_queue_store_read_index_screlease.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_agent_iterate_regions: { + out << "hsa_agent_iterate_regions("; + out << api_data.args.hsa_agent_iterate_regions.agent << ", "; + out << api_data.args.hsa_agent_iterate_regions.callback << ", "; + out << api_data.args.hsa_agent_iterate_regions.data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_region_get_info: { + out << "hsa_region_get_info("; + out << api_data.args.hsa_region_get_info.region << ", "; + out << api_data.args.hsa_region_get_info.attribute << ", "; + out << api_data.args.hsa_region_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_agent_get_exception_policies: { + out << "hsa_agent_get_exception_policies("; + out << api_data.args.hsa_agent_get_exception_policies.agent << ", "; + out << api_data.args.hsa_agent_get_exception_policies.profile << ", "; + out << api_data.args.hsa_agent_get_exception_policies.mask; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_agent_extension_supported: { + out << "hsa_agent_extension_supported("; + out << api_data.args.hsa_agent_extension_supported.extension << ", "; + out << api_data.args.hsa_agent_extension_supported.agent << ", "; + out << api_data.args.hsa_agent_extension_supported.version_major << ", "; + out << api_data.args.hsa_agent_extension_supported.version_minor << ", "; + out << api_data.args.hsa_agent_extension_supported.result; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_memory_register: { + out << "hsa_memory_register("; + out << api_data.args.hsa_memory_register.ptr << ", "; + out << api_data.args.hsa_memory_register.size; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_memory_deregister: { + out << "hsa_memory_deregister("; + out << api_data.args.hsa_memory_deregister.ptr << ", "; + out << api_data.args.hsa_memory_deregister.size; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_memory_allocate: { + out << "hsa_memory_allocate("; + out << api_data.args.hsa_memory_allocate.region << ", "; + out << api_data.args.hsa_memory_allocate.size << ", "; + out << api_data.args.hsa_memory_allocate.ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_memory_free: { + out << "hsa_memory_free("; + out << api_data.args.hsa_memory_free.ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_memory_copy: { + out << "hsa_memory_copy("; + out << api_data.args.hsa_memory_copy.dst << ", "; + out << api_data.args.hsa_memory_copy.src << ", "; + out << api_data.args.hsa_memory_copy.size; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_memory_assign_agent: { + out << "hsa_memory_assign_agent("; + out << api_data.args.hsa_memory_assign_agent.ptr << ", "; + out << api_data.args.hsa_memory_assign_agent.agent << ", "; + out << api_data.args.hsa_memory_assign_agent.access; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_signal_create: { + out << "hsa_signal_create("; + out << api_data.args.hsa_signal_create.initial_value << ", "; + out << api_data.args.hsa_signal_create.num_consumers << ", "; + out << api_data.args.hsa_signal_create.consumers << ", "; + out << api_data.args.hsa_signal_create.signal; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_signal_destroy: { + out << "hsa_signal_destroy("; + out << api_data.args.hsa_signal_destroy.signal; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_signal_load_relaxed: { + out << "hsa_signal_load_relaxed("; + out << api_data.args.hsa_signal_load_relaxed.signal; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_signal_load_scacquire: { + out << "hsa_signal_load_scacquire("; + out << api_data.args.hsa_signal_load_scacquire.signal; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_signal_store_relaxed: { + out << "hsa_signal_store_relaxed("; + out << api_data.args.hsa_signal_store_relaxed.signal << ", "; + out << api_data.args.hsa_signal_store_relaxed.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_store_screlease: { + out << "hsa_signal_store_screlease("; + out << api_data.args.hsa_signal_store_screlease.signal << ", "; + out << api_data.args.hsa_signal_store_screlease.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_wait_relaxed: { + out << "hsa_signal_wait_relaxed("; + out << api_data.args.hsa_signal_wait_relaxed.signal << ", "; + out << api_data.args.hsa_signal_wait_relaxed.condition << ", "; + out << api_data.args.hsa_signal_wait_relaxed.compare_value << ", "; + out << api_data.args.hsa_signal_wait_relaxed.timeout_hint << ", "; + out << api_data.args.hsa_signal_wait_relaxed.wait_state_hint; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_signal_wait_scacquire: { + out << "hsa_signal_wait_scacquire("; + out << api_data.args.hsa_signal_wait_scacquire.signal << ", "; + out << api_data.args.hsa_signal_wait_scacquire.condition << ", "; + out << api_data.args.hsa_signal_wait_scacquire.compare_value << ", "; + out << api_data.args.hsa_signal_wait_scacquire.timeout_hint << ", "; + out << api_data.args.hsa_signal_wait_scacquire.wait_state_hint; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_signal_and_relaxed: { + out << "hsa_signal_and_relaxed("; + out << api_data.args.hsa_signal_and_relaxed.signal << ", "; + out << api_data.args.hsa_signal_and_relaxed.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_and_scacquire: { + out << "hsa_signal_and_scacquire("; + out << api_data.args.hsa_signal_and_scacquire.signal << ", "; + out << api_data.args.hsa_signal_and_scacquire.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_and_screlease: { + out << "hsa_signal_and_screlease("; + out << api_data.args.hsa_signal_and_screlease.signal << ", "; + out << api_data.args.hsa_signal_and_screlease.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_and_scacq_screl: { + out << "hsa_signal_and_scacq_screl("; + out << api_data.args.hsa_signal_and_scacq_screl.signal << ", "; + out << api_data.args.hsa_signal_and_scacq_screl.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_or_relaxed: { + out << "hsa_signal_or_relaxed("; + out << api_data.args.hsa_signal_or_relaxed.signal << ", "; + out << api_data.args.hsa_signal_or_relaxed.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_or_scacquire: { + out << "hsa_signal_or_scacquire("; + out << api_data.args.hsa_signal_or_scacquire.signal << ", "; + out << api_data.args.hsa_signal_or_scacquire.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_or_screlease: { + out << "hsa_signal_or_screlease("; + out << api_data.args.hsa_signal_or_screlease.signal << ", "; + out << api_data.args.hsa_signal_or_screlease.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_or_scacq_screl: { + out << "hsa_signal_or_scacq_screl("; + out << api_data.args.hsa_signal_or_scacq_screl.signal << ", "; + out << api_data.args.hsa_signal_or_scacq_screl.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_xor_relaxed: { + out << "hsa_signal_xor_relaxed("; + out << api_data.args.hsa_signal_xor_relaxed.signal << ", "; + out << api_data.args.hsa_signal_xor_relaxed.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_xor_scacquire: { + out << "hsa_signal_xor_scacquire("; + out << api_data.args.hsa_signal_xor_scacquire.signal << ", "; + out << api_data.args.hsa_signal_xor_scacquire.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_xor_screlease: { + out << "hsa_signal_xor_screlease("; + out << api_data.args.hsa_signal_xor_screlease.signal << ", "; + out << api_data.args.hsa_signal_xor_screlease.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_xor_scacq_screl: { + out << "hsa_signal_xor_scacq_screl("; + out << api_data.args.hsa_signal_xor_scacq_screl.signal << ", "; + out << api_data.args.hsa_signal_xor_scacq_screl.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_exchange_relaxed: { + out << "hsa_signal_exchange_relaxed("; + out << api_data.args.hsa_signal_exchange_relaxed.signal << ", "; + out << api_data.args.hsa_signal_exchange_relaxed.value; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_signal_exchange_scacquire: { + out << "hsa_signal_exchange_scacquire("; + out << api_data.args.hsa_signal_exchange_scacquire.signal << ", "; + out << api_data.args.hsa_signal_exchange_scacquire.value; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_signal_exchange_screlease: { + out << "hsa_signal_exchange_screlease("; + out << api_data.args.hsa_signal_exchange_screlease.signal << ", "; + out << api_data.args.hsa_signal_exchange_screlease.value; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_signal_exchange_scacq_screl: { + out << "hsa_signal_exchange_scacq_screl("; + out << api_data.args.hsa_signal_exchange_scacq_screl.signal << ", "; + out << api_data.args.hsa_signal_exchange_scacq_screl.value; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_signal_add_relaxed: { + out << "hsa_signal_add_relaxed("; + out << api_data.args.hsa_signal_add_relaxed.signal << ", "; + out << api_data.args.hsa_signal_add_relaxed.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_add_scacquire: { + out << "hsa_signal_add_scacquire("; + out << api_data.args.hsa_signal_add_scacquire.signal << ", "; + out << api_data.args.hsa_signal_add_scacquire.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_add_screlease: { + out << "hsa_signal_add_screlease("; + out << api_data.args.hsa_signal_add_screlease.signal << ", "; + out << api_data.args.hsa_signal_add_screlease.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_add_scacq_screl: { + out << "hsa_signal_add_scacq_screl("; + out << api_data.args.hsa_signal_add_scacq_screl.signal << ", "; + out << api_data.args.hsa_signal_add_scacq_screl.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_subtract_relaxed: { + out << "hsa_signal_subtract_relaxed("; + out << api_data.args.hsa_signal_subtract_relaxed.signal << ", "; + out << api_data.args.hsa_signal_subtract_relaxed.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_subtract_scacquire: { + out << "hsa_signal_subtract_scacquire("; + out << api_data.args.hsa_signal_subtract_scacquire.signal << ", "; + out << api_data.args.hsa_signal_subtract_scacquire.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_subtract_screlease: { + out << "hsa_signal_subtract_screlease("; + out << api_data.args.hsa_signal_subtract_screlease.signal << ", "; + out << api_data.args.hsa_signal_subtract_screlease.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_subtract_scacq_screl: { + out << "hsa_signal_subtract_scacq_screl("; + out << api_data.args.hsa_signal_subtract_scacq_screl.signal << ", "; + out << api_data.args.hsa_signal_subtract_scacq_screl.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_cas_relaxed: { + out << "hsa_signal_cas_relaxed("; + out << api_data.args.hsa_signal_cas_relaxed.signal << ", "; + out << api_data.args.hsa_signal_cas_relaxed.expected << ", "; + out << api_data.args.hsa_signal_cas_relaxed.value; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_signal_cas_scacquire: { + out << "hsa_signal_cas_scacquire("; + out << api_data.args.hsa_signal_cas_scacquire.signal << ", "; + out << api_data.args.hsa_signal_cas_scacquire.expected << ", "; + out << api_data.args.hsa_signal_cas_scacquire.value; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_signal_cas_screlease: { + out << "hsa_signal_cas_screlease("; + out << api_data.args.hsa_signal_cas_screlease.signal << ", "; + out << api_data.args.hsa_signal_cas_screlease.expected << ", "; + out << api_data.args.hsa_signal_cas_screlease.value; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_signal_cas_scacq_screl: { + out << "hsa_signal_cas_scacq_screl("; + out << api_data.args.hsa_signal_cas_scacq_screl.signal << ", "; + out << api_data.args.hsa_signal_cas_scacq_screl.expected << ", "; + out << api_data.args.hsa_signal_cas_scacq_screl.value; + out << ") = " << api_data.hsa_signal_value_t_retval; + break; + } + case HSA_API_ID_hsa_isa_from_name: { + out << "hsa_isa_from_name("; + out << "0x" << std::hex << (uint64_t)api_data.args.hsa_isa_from_name.name + << ", "; + out << api_data.args.hsa_isa_from_name.isa; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_isa_get_info: { + out << "hsa_isa_get_info("; + out << api_data.args.hsa_isa_get_info.isa << ", "; + out << api_data.args.hsa_isa_get_info.attribute << ", "; + out << api_data.args.hsa_isa_get_info.index << ", "; + out << api_data.args.hsa_isa_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_isa_compatible: { + out << "hsa_isa_compatible("; + out << api_data.args.hsa_isa_compatible.code_object_isa << ", "; + out << api_data.args.hsa_isa_compatible.agent_isa << ", "; + out << api_data.args.hsa_isa_compatible.result; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_code_object_serialize: { + out << "hsa_code_object_serialize("; + out << api_data.args.hsa_code_object_serialize.code_object << ", "; + out << api_data.args.hsa_code_object_serialize.alloc_callback << ", "; + out << api_data.args.hsa_code_object_serialize.callback_data << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_code_object_serialize.options << ", "; + out << api_data.args.hsa_code_object_serialize.serialized_code_object + << ", "; + out << api_data.args.hsa_code_object_serialize.serialized_code_object_size; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_code_object_deserialize: { + out << "hsa_code_object_deserialize("; + out << api_data.args.hsa_code_object_deserialize.serialized_code_object + << ", "; + out << api_data.args.hsa_code_object_deserialize.serialized_code_object_size + << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_code_object_deserialize.options << ", "; + out << api_data.args.hsa_code_object_deserialize.code_object; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_code_object_destroy: { + out << "hsa_code_object_destroy("; + out << api_data.args.hsa_code_object_destroy.code_object; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_code_object_get_info: { + out << "hsa_code_object_get_info("; + out << api_data.args.hsa_code_object_get_info.code_object << ", "; + out << api_data.args.hsa_code_object_get_info.attribute << ", "; + out << api_data.args.hsa_code_object_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_code_object_get_symbol: { + out << "hsa_code_object_get_symbol("; + out << api_data.args.hsa_code_object_get_symbol.code_object << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_code_object_get_symbol.symbol_name + << ", "; + out << api_data.args.hsa_code_object_get_symbol.symbol; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_code_symbol_get_info: { + out << "hsa_code_symbol_get_info("; + out << api_data.args.hsa_code_symbol_get_info.code_symbol << ", "; + out << api_data.args.hsa_code_symbol_get_info.attribute << ", "; + out << api_data.args.hsa_code_symbol_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_code_object_iterate_symbols: { + out << "hsa_code_object_iterate_symbols("; + out << api_data.args.hsa_code_object_iterate_symbols.code_object << ", "; + out << api_data.args.hsa_code_object_iterate_symbols.callback << ", "; + out << api_data.args.hsa_code_object_iterate_symbols.data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_create: { + out << "hsa_executable_create("; + out << api_data.args.hsa_executable_create.profile << ", "; + out << api_data.args.hsa_executable_create.executable_state << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_executable_create.options << ", "; + out << api_data.args.hsa_executable_create.executable; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_destroy: { + out << "hsa_executable_destroy("; + out << api_data.args.hsa_executable_destroy.executable; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_load_code_object: { + out << "hsa_executable_load_code_object("; + out << api_data.args.hsa_executable_load_code_object.executable << ", "; + out << api_data.args.hsa_executable_load_code_object.agent << ", "; + out << api_data.args.hsa_executable_load_code_object.code_object << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_executable_load_code_object.options; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_freeze: { + out << "hsa_executable_freeze("; + out << api_data.args.hsa_executable_freeze.executable << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_executable_freeze.options; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_get_info: { + out << "hsa_executable_get_info("; + out << api_data.args.hsa_executable_get_info.executable << ", "; + out << api_data.args.hsa_executable_get_info.attribute << ", "; + out << api_data.args.hsa_executable_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_global_variable_define: { + out << "hsa_executable_global_variable_define("; + out << api_data.args.hsa_executable_global_variable_define.executable + << ", "; + out << "0x" << std::hex + << (uint64_t) + api_data.args.hsa_executable_global_variable_define.variable_name + << ", "; + out << api_data.args.hsa_executable_global_variable_define.address; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_agent_global_variable_define: { + out << "hsa_executable_agent_global_variable_define("; + out << api_data.args.hsa_executable_agent_global_variable_define.executable + << ", "; + out << api_data.args.hsa_executable_agent_global_variable_define.agent + << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_executable_agent_global_variable_define + .variable_name + << ", "; + out << api_data.args.hsa_executable_agent_global_variable_define.address; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_readonly_variable_define: { + out << "hsa_executable_readonly_variable_define("; + out << api_data.args.hsa_executable_readonly_variable_define.executable + << ", "; + out << api_data.args.hsa_executable_readonly_variable_define.agent << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_executable_readonly_variable_define + .variable_name + << ", "; + out << api_data.args.hsa_executable_readonly_variable_define.address; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_validate: { + out << "hsa_executable_validate("; + out << api_data.args.hsa_executable_validate.executable << ", "; + out << api_data.args.hsa_executable_validate.result; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_get_symbol: { + out << "hsa_executable_get_symbol("; + out << api_data.args.hsa_executable_get_symbol.executable << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_executable_get_symbol.module_name + << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_executable_get_symbol.symbol_name + << ", "; + out << api_data.args.hsa_executable_get_symbol.agent << ", "; + out << api_data.args.hsa_executable_get_symbol.call_convention << ", "; + out << api_data.args.hsa_executable_get_symbol.symbol; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_symbol_get_info: { + out << "hsa_executable_symbol_get_info("; + out << api_data.args.hsa_executable_symbol_get_info.executable_symbol + << ", "; + out << api_data.args.hsa_executable_symbol_get_info.attribute << ", "; + out << api_data.args.hsa_executable_symbol_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_iterate_symbols: { + out << "hsa_executable_iterate_symbols("; + out << api_data.args.hsa_executable_iterate_symbols.executable << ", "; + out << api_data.args.hsa_executable_iterate_symbols.callback << ", "; + out << api_data.args.hsa_executable_iterate_symbols.data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_status_string: { + out << "hsa_status_string("; + out << api_data.args.hsa_status_string.status << ", "; + out << api_data.args.hsa_status_string.status_string; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_extension_get_name: { + out << "hsa_extension_get_name("; + out << api_data.args.hsa_extension_get_name.extension << ", "; + out << api_data.args.hsa_extension_get_name.name; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_system_major_extension_supported: { + out << "hsa_system_major_extension_supported("; + out << api_data.args.hsa_system_major_extension_supported.extension << ", "; + out << api_data.args.hsa_system_major_extension_supported.version_major + << ", "; + out << api_data.args.hsa_system_major_extension_supported.version_minor + << ", "; + out << api_data.args.hsa_system_major_extension_supported.result; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_system_get_major_extension_table: { + out << "hsa_system_get_major_extension_table("; + out << api_data.args.hsa_system_get_major_extension_table.extension << ", "; + out << api_data.args.hsa_system_get_major_extension_table.version_major + << ", "; + out << api_data.args.hsa_system_get_major_extension_table.table_length + << ", "; + out << api_data.args.hsa_system_get_major_extension_table.table; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_agent_major_extension_supported: { + out << "hsa_agent_major_extension_supported("; + out << api_data.args.hsa_agent_major_extension_supported.extension << ", "; + out << api_data.args.hsa_agent_major_extension_supported.agent << ", "; + out << api_data.args.hsa_agent_major_extension_supported.version_major + << ", "; + out << api_data.args.hsa_agent_major_extension_supported.version_minor + << ", "; + out << api_data.args.hsa_agent_major_extension_supported.result; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_cache_get_info: { + out << "hsa_cache_get_info("; + out << api_data.args.hsa_cache_get_info.cache << ", "; + out << api_data.args.hsa_cache_get_info.attribute << ", "; + out << api_data.args.hsa_cache_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_agent_iterate_caches: { + out << "hsa_agent_iterate_caches("; + out << api_data.args.hsa_agent_iterate_caches.agent << ", "; + out << api_data.args.hsa_agent_iterate_caches.callback << ", "; + out << api_data.args.hsa_agent_iterate_caches.data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_signal_silent_store_relaxed: { + out << "hsa_signal_silent_store_relaxed("; + out << api_data.args.hsa_signal_silent_store_relaxed.signal << ", "; + out << api_data.args.hsa_signal_silent_store_relaxed.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_silent_store_screlease: { + out << "hsa_signal_silent_store_screlease("; + out << api_data.args.hsa_signal_silent_store_screlease.signal << ", "; + out << api_data.args.hsa_signal_silent_store_screlease.value; + out << ") = void"; + break; + } + case HSA_API_ID_hsa_signal_group_create: { + out << "hsa_signal_group_create("; + out << api_data.args.hsa_signal_group_create.num_signals << ", "; + out << api_data.args.hsa_signal_group_create.signals << ", "; + out << api_data.args.hsa_signal_group_create.num_consumers << ", "; + out << api_data.args.hsa_signal_group_create.consumers << ", "; + out << api_data.args.hsa_signal_group_create.signal_group; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_signal_group_destroy: { + out << "hsa_signal_group_destroy("; + out << api_data.args.hsa_signal_group_destroy.signal_group; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_signal_group_wait_any_scacquire: { + out << "hsa_signal_group_wait_any_scacquire("; + out << api_data.args.hsa_signal_group_wait_any_scacquire.signal_group + << ", "; + out << api_data.args.hsa_signal_group_wait_any_scacquire.conditions << ", "; + out << api_data.args.hsa_signal_group_wait_any_scacquire.compare_values + << ", "; + out << api_data.args.hsa_signal_group_wait_any_scacquire.wait_state_hint + << ", "; + out << api_data.args.hsa_signal_group_wait_any_scacquire.signal << ", "; + out << api_data.args.hsa_signal_group_wait_any_scacquire.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_signal_group_wait_any_relaxed: { + out << "hsa_signal_group_wait_any_relaxed("; + out << api_data.args.hsa_signal_group_wait_any_relaxed.signal_group << ", "; + out << api_data.args.hsa_signal_group_wait_any_relaxed.conditions << ", "; + out << api_data.args.hsa_signal_group_wait_any_relaxed.compare_values + << ", "; + out << api_data.args.hsa_signal_group_wait_any_relaxed.wait_state_hint + << ", "; + out << api_data.args.hsa_signal_group_wait_any_relaxed.signal << ", "; + out << api_data.args.hsa_signal_group_wait_any_relaxed.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_agent_iterate_isas: { + out << "hsa_agent_iterate_isas("; + out << api_data.args.hsa_agent_iterate_isas.agent << ", "; + out << api_data.args.hsa_agent_iterate_isas.callback << ", "; + out << api_data.args.hsa_agent_iterate_isas.data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_isa_get_info_alt: { + out << "hsa_isa_get_info_alt("; + out << api_data.args.hsa_isa_get_info_alt.isa << ", "; + out << api_data.args.hsa_isa_get_info_alt.attribute << ", "; + out << api_data.args.hsa_isa_get_info_alt.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_isa_get_exception_policies: { + out << "hsa_isa_get_exception_policies("; + out << api_data.args.hsa_isa_get_exception_policies.isa << ", "; + out << api_data.args.hsa_isa_get_exception_policies.profile << ", "; + out << api_data.args.hsa_isa_get_exception_policies.mask; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_isa_get_round_method: { + out << "hsa_isa_get_round_method("; + out << api_data.args.hsa_isa_get_round_method.isa << ", "; + out << api_data.args.hsa_isa_get_round_method.fp_type << ", "; + out << api_data.args.hsa_isa_get_round_method.flush_mode << ", "; + out << api_data.args.hsa_isa_get_round_method.round_method; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_wavefront_get_info: { + out << "hsa_wavefront_get_info("; + out << api_data.args.hsa_wavefront_get_info.wavefront << ", "; + out << api_data.args.hsa_wavefront_get_info.attribute << ", "; + out << api_data.args.hsa_wavefront_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_isa_iterate_wavefronts: { + out << "hsa_isa_iterate_wavefronts("; + out << api_data.args.hsa_isa_iterate_wavefronts.isa << ", "; + out << api_data.args.hsa_isa_iterate_wavefronts.callback << ", "; + out << api_data.args.hsa_isa_iterate_wavefronts.data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_code_object_get_symbol_from_name: { + out << "hsa_code_object_get_symbol_from_name("; + out << api_data.args.hsa_code_object_get_symbol_from_name.code_object + << ", "; + out << "0x" << std::hex + << (uint64_t) + api_data.args.hsa_code_object_get_symbol_from_name.module_name + << ", "; + out << "0x" << std::hex + << (uint64_t) + api_data.args.hsa_code_object_get_symbol_from_name.symbol_name + << ", "; + out << api_data.args.hsa_code_object_get_symbol_from_name.symbol; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_code_object_reader_create_from_file: { + out << "hsa_code_object_reader_create_from_file("; + out << api_data.args.hsa_code_object_reader_create_from_file.file << ", "; + out << api_data.args.hsa_code_object_reader_create_from_file + .code_object_reader; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_code_object_reader_create_from_memory: { + out << "hsa_code_object_reader_create_from_memory("; + out << api_data.args.hsa_code_object_reader_create_from_memory.code_object + << ", "; + out << api_data.args.hsa_code_object_reader_create_from_memory.size << ", "; + out << api_data.args.hsa_code_object_reader_create_from_memory + .code_object_reader; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_code_object_reader_destroy: { + out << "hsa_code_object_reader_destroy("; + out << api_data.args.hsa_code_object_reader_destroy.code_object_reader; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_create_alt: { + out << "hsa_executable_create_alt("; + out << api_data.args.hsa_executable_create_alt.profile << ", "; + out << api_data.args.hsa_executable_create_alt.default_float_rounding_mode + << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_executable_create_alt.options << ", "; + out << api_data.args.hsa_executable_create_alt.executable; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_load_program_code_object: { + out << "hsa_executable_load_program_code_object("; + out << api_data.args.hsa_executable_load_program_code_object.executable + << ", "; + out << api_data.args.hsa_executable_load_program_code_object + .code_object_reader + << ", "; + out << "0x" << std::hex + << (uint64_t) + api_data.args.hsa_executable_load_program_code_object.options + << ", "; + out << api_data.args.hsa_executable_load_program_code_object + .loaded_code_object; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_load_agent_code_object: { + out << "hsa_executable_load_agent_code_object("; + out << api_data.args.hsa_executable_load_agent_code_object.executable + << ", "; + out << api_data.args.hsa_executable_load_agent_code_object.agent << ", "; + out << api_data.args.hsa_executable_load_agent_code_object + .code_object_reader + << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_executable_load_agent_code_object.options + << ", "; + out << api_data.args.hsa_executable_load_agent_code_object + .loaded_code_object; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_validate_alt: { + out << "hsa_executable_validate_alt("; + out << api_data.args.hsa_executable_validate_alt.executable << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_executable_validate_alt.options << ", "; + out << api_data.args.hsa_executable_validate_alt.result; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_get_symbol_by_name: { + out << "hsa_executable_get_symbol_by_name("; + out << api_data.args.hsa_executable_get_symbol_by_name.executable << ", "; + out << "0x" << std::hex + << (uint64_t)api_data.args.hsa_executable_get_symbol_by_name.symbol_name + << ", "; + out << api_data.args.hsa_executable_get_symbol_by_name.agent << ", "; + out << api_data.args.hsa_executable_get_symbol_by_name.symbol; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_iterate_agent_symbols: { + out << "hsa_executable_iterate_agent_symbols("; + out << api_data.args.hsa_executable_iterate_agent_symbols.executable + << ", "; + out << api_data.args.hsa_executable_iterate_agent_symbols.agent << ", "; + out << api_data.args.hsa_executable_iterate_agent_symbols.callback << ", "; + out << api_data.args.hsa_executable_iterate_agent_symbols.data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_executable_iterate_program_symbols: { + out << "hsa_executable_iterate_program_symbols("; + out << api_data.args.hsa_executable_iterate_program_symbols.executable + << ", "; + out << api_data.args.hsa_executable_iterate_program_symbols.callback + << ", "; + out << api_data.args.hsa_executable_iterate_program_symbols.data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + + /* block: AmdExt API */ + case HSA_API_ID_hsa_amd_coherency_get_type: { + out << "hsa_amd_coherency_get_type("; + out << api_data.args.hsa_amd_coherency_get_type.agent << ", "; + out << api_data.args.hsa_amd_coherency_get_type.type; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_coherency_set_type: { + out << "hsa_amd_coherency_set_type("; + out << api_data.args.hsa_amd_coherency_set_type.agent << ", "; + out << api_data.args.hsa_amd_coherency_set_type.type; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_profiling_set_profiler_enabled: { + out << "hsa_amd_profiling_set_profiler_enabled("; + out << api_data.args.hsa_amd_profiling_set_profiler_enabled.queue << ", "; + out << api_data.args.hsa_amd_profiling_set_profiler_enabled.enable; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_profiling_async_copy_enable: { + out << "hsa_amd_profiling_async_copy_enable("; + out << api_data.args.hsa_amd_profiling_async_copy_enable.enable; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_profiling_get_dispatch_time: { + out << "hsa_amd_profiling_get_dispatch_time("; + out << api_data.args.hsa_amd_profiling_get_dispatch_time.agent << ", "; + out << api_data.args.hsa_amd_profiling_get_dispatch_time.signal << ", "; + out << api_data.args.hsa_amd_profiling_get_dispatch_time.time; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_profiling_get_async_copy_time: { + out << "hsa_amd_profiling_get_async_copy_time("; + out << api_data.args.hsa_amd_profiling_get_async_copy_time.signal << ", "; + out << api_data.args.hsa_amd_profiling_get_async_copy_time.time; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_profiling_convert_tick_to_system_domain: { + out << "hsa_amd_profiling_convert_tick_to_system_domain("; + out << api_data.args.hsa_amd_profiling_convert_tick_to_system_domain.agent + << ", "; + out << api_data.args.hsa_amd_profiling_convert_tick_to_system_domain + .agent_tick + << ", "; + out << api_data.args.hsa_amd_profiling_convert_tick_to_system_domain + .system_tick; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_signal_async_handler: { + out << "hsa_amd_signal_async_handler("; + out << api_data.args.hsa_amd_signal_async_handler.signal << ", "; + out << api_data.args.hsa_amd_signal_async_handler.cond << ", "; + out << api_data.args.hsa_amd_signal_async_handler.value << ", "; + out << api_data.args.hsa_amd_signal_async_handler.handler << ", "; + out << api_data.args.hsa_amd_signal_async_handler.arg; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_async_function: { + out << "hsa_amd_async_function("; + out << api_data.args.hsa_amd_async_function.callback << ", "; + out << api_data.args.hsa_amd_async_function.arg; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_signal_wait_any: { + out << "hsa_amd_signal_wait_any("; + out << api_data.args.hsa_amd_signal_wait_any.signal_count << ", "; + out << api_data.args.hsa_amd_signal_wait_any.signals << ", "; + out << api_data.args.hsa_amd_signal_wait_any.conds << ", "; + out << api_data.args.hsa_amd_signal_wait_any.values << ", "; + out << api_data.args.hsa_amd_signal_wait_any.timeout_hint << ", "; + out << api_data.args.hsa_amd_signal_wait_any.wait_hint << ", "; + out << api_data.args.hsa_amd_signal_wait_any.satisfying_value; + out << ") = " << api_data.uint32_t_retval; + break; + } + case HSA_API_ID_hsa_amd_queue_cu_set_mask: { + out << "hsa_amd_queue_cu_set_mask("; + out << api_data.args.hsa_amd_queue_cu_set_mask.queue << ", "; + out << api_data.args.hsa_amd_queue_cu_set_mask.num_cu_mask_count << ", "; + out << api_data.args.hsa_amd_queue_cu_set_mask.cu_mask; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_pool_get_info: { + out << "hsa_amd_memory_pool_get_info("; + out << api_data.args.hsa_amd_memory_pool_get_info.memory_pool << ", "; + out << api_data.args.hsa_amd_memory_pool_get_info.attribute << ", "; + out << api_data.args.hsa_amd_memory_pool_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_agent_iterate_memory_pools: { + out << "hsa_amd_agent_iterate_memory_pools("; + out << api_data.args.hsa_amd_agent_iterate_memory_pools.agent << ", "; + out << api_data.args.hsa_amd_agent_iterate_memory_pools.callback << ", "; + out << api_data.args.hsa_amd_agent_iterate_memory_pools.data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_pool_allocate: { + out << "hsa_amd_memory_pool_allocate("; + out << api_data.args.hsa_amd_memory_pool_allocate.memory_pool << ", "; + out << api_data.args.hsa_amd_memory_pool_allocate.size << ", "; + out << api_data.args.hsa_amd_memory_pool_allocate.flags << ", "; + out << api_data.args.hsa_amd_memory_pool_allocate.ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_pool_free: { + out << "hsa_amd_memory_pool_free("; + out << api_data.args.hsa_amd_memory_pool_free.ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_async_copy: { + out << "hsa_amd_memory_async_copy("; + out << api_data.args.hsa_amd_memory_async_copy.dst << ", "; + out << api_data.args.hsa_amd_memory_async_copy.dst_agent << ", "; + out << api_data.args.hsa_amd_memory_async_copy.src << ", "; + out << api_data.args.hsa_amd_memory_async_copy.src_agent << ", "; + out << api_data.args.hsa_amd_memory_async_copy.size << ", "; + out << api_data.args.hsa_amd_memory_async_copy.num_dep_signals << ", "; + out << api_data.args.hsa_amd_memory_async_copy.dep_signals << ", "; + out << api_data.args.hsa_amd_memory_async_copy.completion_signal; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_async_copy_on_engine: { + out << "hsa_amd_memory_async_copy_on_engine("; + out << api_data.args.hsa_amd_memory_async_copy_on_engine.dst << ", "; + out << api_data.args.hsa_amd_memory_async_copy_on_engine.dst_agent << ", "; + out << api_data.args.hsa_amd_memory_async_copy_on_engine.src << ", "; + out << api_data.args.hsa_amd_memory_async_copy_on_engine.src_agent << ", "; + out << api_data.args.hsa_amd_memory_async_copy_on_engine.size << ", "; + out << api_data.args.hsa_amd_memory_async_copy_on_engine.num_dep_signals + << ", "; + out << api_data.args.hsa_amd_memory_async_copy_on_engine.dep_signals + << ", "; + out << api_data.args.hsa_amd_memory_async_copy_on_engine.completion_signal + << ", "; + out << api_data.args.hsa_amd_memory_async_copy_on_engine.engine_id << ", "; + out << api_data.args.hsa_amd_memory_async_copy_on_engine.force_copy_on_sdma; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_copy_engine_status: { + out << "hsa_amd_memory_copy_engine_status("; + out << api_data.args.hsa_amd_memory_copy_engine_status.dst_agent << ", "; + out << api_data.args.hsa_amd_memory_copy_engine_status.src_agent << ", "; + out << api_data.args.hsa_amd_memory_copy_engine_status.engine_ids_mask; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_agent_memory_pool_get_info: { + out << "hsa_amd_agent_memory_pool_get_info("; + out << api_data.args.hsa_amd_agent_memory_pool_get_info.agent << ", "; + out << api_data.args.hsa_amd_agent_memory_pool_get_info.memory_pool << ", "; + out << api_data.args.hsa_amd_agent_memory_pool_get_info.attribute << ", "; + out << api_data.args.hsa_amd_agent_memory_pool_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_agents_allow_access: { + out << "hsa_amd_agents_allow_access("; + out << api_data.args.hsa_amd_agents_allow_access.num_agents << ", "; + out << api_data.args.hsa_amd_agents_allow_access.agents << ", "; + out << api_data.args.hsa_amd_agents_allow_access.flags << ", "; + out << api_data.args.hsa_amd_agents_allow_access.ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_pool_can_migrate: { + out << "hsa_amd_memory_pool_can_migrate("; + out << api_data.args.hsa_amd_memory_pool_can_migrate.src_memory_pool + << ", "; + out << api_data.args.hsa_amd_memory_pool_can_migrate.dst_memory_pool + << ", "; + out << api_data.args.hsa_amd_memory_pool_can_migrate.result; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_migrate: { + out << "hsa_amd_memory_migrate("; + out << api_data.args.hsa_amd_memory_migrate.ptr << ", "; + out << api_data.args.hsa_amd_memory_migrate.memory_pool << ", "; + out << api_data.args.hsa_amd_memory_migrate.flags; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_lock: { + out << "hsa_amd_memory_lock("; + out << api_data.args.hsa_amd_memory_lock.host_ptr << ", "; + out << api_data.args.hsa_amd_memory_lock.size << ", "; + out << api_data.args.hsa_amd_memory_lock.agents << ", "; + out << api_data.args.hsa_amd_memory_lock.num_agent << ", "; + out << api_data.args.hsa_amd_memory_lock.agent_ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_unlock: { + out << "hsa_amd_memory_unlock("; + out << api_data.args.hsa_amd_memory_unlock.host_ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_fill: { + out << "hsa_amd_memory_fill("; + out << api_data.args.hsa_amd_memory_fill.ptr << ", "; + out << api_data.args.hsa_amd_memory_fill.value << ", "; + out << api_data.args.hsa_amd_memory_fill.count; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_interop_map_buffer: { + out << "hsa_amd_interop_map_buffer("; + out << api_data.args.hsa_amd_interop_map_buffer.num_agents << ", "; + out << api_data.args.hsa_amd_interop_map_buffer.agents << ", "; + out << api_data.args.hsa_amd_interop_map_buffer.interop_handle << ", "; + out << api_data.args.hsa_amd_interop_map_buffer.flags << ", "; + out << api_data.args.hsa_amd_interop_map_buffer.size << ", "; + out << api_data.args.hsa_amd_interop_map_buffer.ptr << ", "; + out << api_data.args.hsa_amd_interop_map_buffer.metadata_size << ", "; + out << api_data.args.hsa_amd_interop_map_buffer.metadata; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_interop_unmap_buffer: { + out << "hsa_amd_interop_unmap_buffer("; + out << api_data.args.hsa_amd_interop_unmap_buffer.ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_image_create: { + out << "hsa_amd_image_create("; + out << api_data.args.hsa_amd_image_create.agent << ", "; + out << api_data.args.hsa_amd_image_create.image_descriptor << ", "; + out << api_data.args.hsa_amd_image_create.image_layout << ", "; + out << api_data.args.hsa_amd_image_create.image_data << ", "; + out << api_data.args.hsa_amd_image_create.access_permission << ", "; + out << api_data.args.hsa_amd_image_create.image; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_pointer_info: { + out << "hsa_amd_pointer_info("; + out << api_data.args.hsa_amd_pointer_info.ptr << ", "; + out << api_data.args.hsa_amd_pointer_info.info << ", "; + out << api_data.args.hsa_amd_pointer_info.alloc << ", "; + out << api_data.args.hsa_amd_pointer_info.num_agents_accessible << ", "; + out << api_data.args.hsa_amd_pointer_info.accessible; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_pointer_info_set_userdata: { + out << "hsa_amd_pointer_info_set_userdata("; + out << api_data.args.hsa_amd_pointer_info_set_userdata.ptr << ", "; + out << api_data.args.hsa_amd_pointer_info_set_userdata.userdata; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_ipc_memory_create: { + out << "hsa_amd_ipc_memory_create("; + out << api_data.args.hsa_amd_ipc_memory_create.ptr << ", "; + out << api_data.args.hsa_amd_ipc_memory_create.len << ", "; + out << api_data.args.hsa_amd_ipc_memory_create.handle; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_ipc_memory_attach: { + out << "hsa_amd_ipc_memory_attach("; + out << api_data.args.hsa_amd_ipc_memory_attach.handle << ", "; + out << api_data.args.hsa_amd_ipc_memory_attach.len << ", "; + out << api_data.args.hsa_amd_ipc_memory_attach.num_agents << ", "; + out << api_data.args.hsa_amd_ipc_memory_attach.mapping_agents << ", "; + out << api_data.args.hsa_amd_ipc_memory_attach.mapped_ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_ipc_memory_detach: { + out << "hsa_amd_ipc_memory_detach("; + out << api_data.args.hsa_amd_ipc_memory_detach.mapped_ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_signal_create: { + out << "hsa_amd_signal_create("; + out << api_data.args.hsa_amd_signal_create.initial_value << ", "; + out << api_data.args.hsa_amd_signal_create.num_consumers << ", "; + out << api_data.args.hsa_amd_signal_create.consumers << ", "; + out << api_data.args.hsa_amd_signal_create.attributes << ", "; + out << api_data.args.hsa_amd_signal_create.signal; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_ipc_signal_create: { + out << "hsa_amd_ipc_signal_create("; + out << api_data.args.hsa_amd_ipc_signal_create.signal << ", "; + out << api_data.args.hsa_amd_ipc_signal_create.handle; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_ipc_signal_attach: { + out << "hsa_amd_ipc_signal_attach("; + out << api_data.args.hsa_amd_ipc_signal_attach.handle << ", "; + out << api_data.args.hsa_amd_ipc_signal_attach.signal; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_register_system_event_handler: { + out << "hsa_amd_register_system_event_handler("; + out << api_data.args.hsa_amd_register_system_event_handler.callback << ", "; + out << api_data.args.hsa_amd_register_system_event_handler.data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_queue_intercept_create: { + out << "hsa_amd_queue_intercept_create("; + out << api_data.args.hsa_amd_queue_intercept_create.agent_handle << ", "; + out << api_data.args.hsa_amd_queue_intercept_create.size << ", "; + out << api_data.args.hsa_amd_queue_intercept_create.type << ", "; + out << api_data.args.hsa_amd_queue_intercept_create.callback << ", "; + out << api_data.args.hsa_amd_queue_intercept_create.data << ", "; + out << api_data.args.hsa_amd_queue_intercept_create.private_segment_size + << ", "; + out << api_data.args.hsa_amd_queue_intercept_create.group_segment_size + << ", "; + out << api_data.args.hsa_amd_queue_intercept_create.queue; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_queue_intercept_register: { + out << "hsa_amd_queue_intercept_register("; + out << api_data.args.hsa_amd_queue_intercept_register.queue << ", "; + out << api_data.args.hsa_amd_queue_intercept_register.callback << ", "; + out << api_data.args.hsa_amd_queue_intercept_register.user_data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_queue_set_priority: { + out << "hsa_amd_queue_set_priority("; + out << api_data.args.hsa_amd_queue_set_priority.queue << ", "; + out << api_data.args.hsa_amd_queue_set_priority.priority; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_async_copy_rect: { + out << "hsa_amd_memory_async_copy_rect("; + out << api_data.args.hsa_amd_memory_async_copy_rect.dst << ", "; + out << api_data.args.hsa_amd_memory_async_copy_rect.dst_offset << ", "; + out << api_data.args.hsa_amd_memory_async_copy_rect.src << ", "; + out << api_data.args.hsa_amd_memory_async_copy_rect.src_offset << ", "; + out << api_data.args.hsa_amd_memory_async_copy_rect.range << ", "; + out << api_data.args.hsa_amd_memory_async_copy_rect.range__val << ", "; + out << api_data.args.hsa_amd_memory_async_copy_rect.copy_agent << ", "; + out << api_data.args.hsa_amd_memory_async_copy_rect.dir << ", "; + out << api_data.args.hsa_amd_memory_async_copy_rect.num_dep_signals << ", "; + out << api_data.args.hsa_amd_memory_async_copy_rect.dep_signals << ", "; + out << api_data.args.hsa_amd_memory_async_copy_rect.completion_signal; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_runtime_queue_create_register: { + out << "hsa_amd_runtime_queue_create_register("; + out << api_data.args.hsa_amd_runtime_queue_create_register.callback << ", "; + out << api_data.args.hsa_amd_runtime_queue_create_register.user_data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_memory_lock_to_pool: { + out << "hsa_amd_memory_lock_to_pool("; + out << api_data.args.hsa_amd_memory_lock_to_pool.host_ptr << ", "; + out << api_data.args.hsa_amd_memory_lock_to_pool.size << ", "; + out << api_data.args.hsa_amd_memory_lock_to_pool.agents << ", "; + out << api_data.args.hsa_amd_memory_lock_to_pool.num_agent << ", "; + out << api_data.args.hsa_amd_memory_lock_to_pool.pool << ", "; + out << api_data.args.hsa_amd_memory_lock_to_pool.flags << ", "; + out << api_data.args.hsa_amd_memory_lock_to_pool.agent_ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_register_deallocation_callback: { + out << "hsa_amd_register_deallocation_callback("; + out << api_data.args.hsa_amd_register_deallocation_callback.ptr << ", "; + out << api_data.args.hsa_amd_register_deallocation_callback.callback + << ", "; + out << api_data.args.hsa_amd_register_deallocation_callback.user_data; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_deregister_deallocation_callback: { + out << "hsa_amd_deregister_deallocation_callback("; + out << api_data.args.hsa_amd_deregister_deallocation_callback.ptr << ", "; + out << api_data.args.hsa_amd_deregister_deallocation_callback.callback; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_signal_value_pointer: { + out << "hsa_amd_signal_value_pointer("; + out << api_data.args.hsa_amd_signal_value_pointer.signal << ", "; + out << api_data.args.hsa_amd_signal_value_pointer.value_ptr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_svm_attributes_set: { + out << "hsa_amd_svm_attributes_set("; + out << api_data.args.hsa_amd_svm_attributes_set.ptr << ", "; + out << api_data.args.hsa_amd_svm_attributes_set.size << ", "; + out << api_data.args.hsa_amd_svm_attributes_set.attribute_list << ", "; + out << api_data.args.hsa_amd_svm_attributes_set.attribute_count; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_svm_attributes_get: { + out << "hsa_amd_svm_attributes_get("; + out << api_data.args.hsa_amd_svm_attributes_get.ptr << ", "; + out << api_data.args.hsa_amd_svm_attributes_get.size << ", "; + out << api_data.args.hsa_amd_svm_attributes_get.attribute_list << ", "; + out << api_data.args.hsa_amd_svm_attributes_get.attribute_count; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_svm_prefetch_async: { + out << "hsa_amd_svm_prefetch_async("; + out << api_data.args.hsa_amd_svm_prefetch_async.ptr << ", "; + out << api_data.args.hsa_amd_svm_prefetch_async.size << ", "; + out << api_data.args.hsa_amd_svm_prefetch_async.agent << ", "; + out << api_data.args.hsa_amd_svm_prefetch_async.num_dep_signals << ", "; + out << api_data.args.hsa_amd_svm_prefetch_async.dep_signals << ", "; + out << api_data.args.hsa_amd_svm_prefetch_async.completion_signal; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_spm_acquire: { + out << "hsa_amd_spm_acquire("; + out << api_data.args.hsa_amd_spm_acquire.preferred_agent; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_spm_release: { + out << "hsa_amd_spm_release("; + out << api_data.args.hsa_amd_spm_release.preferred_agent; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_spm_set_dest_buffer: { + out << "hsa_amd_spm_set_dest_buffer("; + out << api_data.args.hsa_amd_spm_set_dest_buffer.preferred_agent << ", "; + out << api_data.args.hsa_amd_spm_set_dest_buffer.size_in_bytes << ", "; + out << api_data.args.hsa_amd_spm_set_dest_buffer.timeout << ", "; + out << api_data.args.hsa_amd_spm_set_dest_buffer.size_copied << ", "; + out << api_data.args.hsa_amd_spm_set_dest_buffer.dest << ", "; + out << api_data.args.hsa_amd_spm_set_dest_buffer.is_data_loss; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_queue_cu_get_mask: { + out << "hsa_amd_queue_cu_get_mask("; + out << api_data.args.hsa_amd_queue_cu_get_mask.queue << ", "; + out << api_data.args.hsa_amd_queue_cu_get_mask.num_cu_mask_count << ", "; + out << api_data.args.hsa_amd_queue_cu_get_mask.cu_mask; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_portable_export_dmabuf: { + out << "hsa_amd_portable_export_dmabuf("; + out << api_data.args.hsa_amd_portable_export_dmabuf.ptr << ", "; + out << api_data.args.hsa_amd_portable_export_dmabuf.size << ", "; + out << api_data.args.hsa_amd_portable_export_dmabuf.dmabuf << ", "; + out << api_data.args.hsa_amd_portable_export_dmabuf.offset; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_portable_close_dmabuf: { + out << "hsa_amd_portable_close_dmabuf("; + out << api_data.args.hsa_amd_portable_close_dmabuf.dmabuf; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_address_reserve: { + out << "hsa_amd_vmem_address_reserve("; + out << api_data.args.hsa_amd_vmem_address_reserve.va << ", "; + out << api_data.args.hsa_amd_vmem_address_reserve.size << ", "; + out << api_data.args.hsa_amd_vmem_address_reserve.address << ", "; + out << api_data.args.hsa_amd_vmem_address_reserve.flags; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_address_free: { + out << "hsa_amd_vmem_address_free("; + out << api_data.args.hsa_amd_vmem_address_free.va << ", "; + out << api_data.args.hsa_amd_vmem_address_free.size; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_handle_create: { + out << "hsa_amd_vmem_handle_create("; + out << api_data.args.hsa_amd_vmem_handle_create.pool << ", "; + out << api_data.args.hsa_amd_vmem_handle_create.size << ", "; + out << api_data.args.hsa_amd_vmem_handle_create.type << ", "; + out << api_data.args.hsa_amd_vmem_handle_create.flags << ", "; + out << api_data.args.hsa_amd_vmem_handle_create.memory_handle; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_handle_release: { + out << "hsa_amd_vmem_handle_release("; + out << api_data.args.hsa_amd_vmem_handle_release.memory_handle; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_map: { + out << "hsa_amd_vmem_map("; + out << api_data.args.hsa_amd_vmem_map.va << ", "; + out << api_data.args.hsa_amd_vmem_map.size << ", "; + out << api_data.args.hsa_amd_vmem_map.in_offset << ", "; + out << api_data.args.hsa_amd_vmem_map.memory_handle << ", "; + out << api_data.args.hsa_amd_vmem_map.flags; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_unmap: { + out << "hsa_amd_vmem_unmap("; + out << api_data.args.hsa_amd_vmem_unmap.va << ", "; + out << api_data.args.hsa_amd_vmem_unmap.size; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_set_access: { + out << "hsa_amd_vmem_set_access("; + out << api_data.args.hsa_amd_vmem_set_access.va << ", "; + out << api_data.args.hsa_amd_vmem_set_access.size << ", "; + out << api_data.args.hsa_amd_vmem_set_access.desc << ", "; + out << api_data.args.hsa_amd_vmem_set_access.desc_cnt; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_get_access: { + out << "hsa_amd_vmem_get_access("; + out << api_data.args.hsa_amd_vmem_get_access.va << ", "; + out << api_data.args.hsa_amd_vmem_get_access.perms << ", "; + out << api_data.args.hsa_amd_vmem_get_access.agent_handle; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_export_shareable_handle: { + out << "hsa_amd_vmem_export_shareable_handle("; + out << api_data.args.hsa_amd_vmem_export_shareable_handle.dmabuf_fd << ", "; + out << api_data.args.hsa_amd_vmem_export_shareable_handle.handle << ", "; + out << api_data.args.hsa_amd_vmem_export_shareable_handle.flags; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_import_shareable_handle: { + out << "hsa_amd_vmem_import_shareable_handle("; + out << api_data.args.hsa_amd_vmem_import_shareable_handle.dmabuf_fd << ", "; + out << api_data.args.hsa_amd_vmem_import_shareable_handle.handle; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_retain_alloc_handle: { + out << "hsa_amd_vmem_retain_alloc_handle("; + out << api_data.args.hsa_amd_vmem_retain_alloc_handle.memory_handle << ", "; + out << api_data.args.hsa_amd_vmem_retain_alloc_handle.addr; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_get_alloc_properties_from_handle: { + out << "hsa_amd_vmem_get_alloc_properties_from_handle("; + out << api_data.args.hsa_amd_vmem_get_alloc_properties_from_handle + .memory_handle + << ", "; + out << api_data.args.hsa_amd_vmem_get_alloc_properties_from_handle.pool + << ", "; + out << api_data.args.hsa_amd_vmem_get_alloc_properties_from_handle.type; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_agent_set_async_scratch_limit: { + out << "hsa_amd_agent_set_async_scratch_limit("; + out << api_data.args.hsa_amd_agent_set_async_scratch_limit.agent << ", "; + out << api_data.args.hsa_amd_agent_set_async_scratch_limit.threshold; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_queue_get_info: { + out << "hsa_amd_queue_get_info("; + out << api_data.args.hsa_amd_queue_get_info.queue << ", "; + out << api_data.args.hsa_amd_queue_get_info.attribute << ", "; + out << api_data.args.hsa_amd_queue_get_info.value; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_amd_vmem_address_reserve_align: { + out << "hsa_amd_vmem_address_reserve_align("; + out << api_data.args.hsa_amd_vmem_address_reserve_align.va << ", "; + out << api_data.args.hsa_amd_vmem_address_reserve_align.size << ", "; + out << api_data.args.hsa_amd_vmem_address_reserve_align.address << ", "; + out << api_data.args.hsa_amd_vmem_address_reserve_align.alignment << ", "; + out << api_data.args.hsa_amd_vmem_address_reserve_align.flags; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + + /* block: ImageExt API */ + case HSA_API_ID_hsa_ext_image_get_capability: { + out << "hsa_ext_image_get_capability("; + out << api_data.args.hsa_ext_image_get_capability.agent << ", "; + out << api_data.args.hsa_ext_image_get_capability.geometry << ", "; + out << api_data.args.hsa_ext_image_get_capability.image_format << ", "; + out << api_data.args.hsa_ext_image_get_capability.capability_mask; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_image_data_get_info: { + out << "hsa_ext_image_data_get_info("; + out << api_data.args.hsa_ext_image_data_get_info.agent << ", "; + out << api_data.args.hsa_ext_image_data_get_info.image_descriptor << ", "; + out << api_data.args.hsa_ext_image_data_get_info.access_permission << ", "; + out << api_data.args.hsa_ext_image_data_get_info.image_data_info; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_image_create: { + out << "hsa_ext_image_create("; + out << api_data.args.hsa_ext_image_create.agent << ", "; + out << api_data.args.hsa_ext_image_create.image_descriptor << ", "; + out << api_data.args.hsa_ext_image_create.image_data << ", "; + out << api_data.args.hsa_ext_image_create.access_permission << ", "; + out << api_data.args.hsa_ext_image_create.image; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_image_import: { + out << "hsa_ext_image_import("; + out << api_data.args.hsa_ext_image_import.agent << ", "; + out << api_data.args.hsa_ext_image_import.src_memory << ", "; + out << api_data.args.hsa_ext_image_import.src_row_pitch << ", "; + out << api_data.args.hsa_ext_image_import.src_slice_pitch << ", "; + out << api_data.args.hsa_ext_image_import.dst_image << ", "; + out << api_data.args.hsa_ext_image_import.image_region; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_image_export: { + out << "hsa_ext_image_export("; + out << api_data.args.hsa_ext_image_export.agent << ", "; + out << api_data.args.hsa_ext_image_export.src_image << ", "; + out << api_data.args.hsa_ext_image_export.dst_memory << ", "; + out << api_data.args.hsa_ext_image_export.dst_row_pitch << ", "; + out << api_data.args.hsa_ext_image_export.dst_slice_pitch << ", "; + out << api_data.args.hsa_ext_image_export.image_region; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_image_copy: { + out << "hsa_ext_image_copy("; + out << api_data.args.hsa_ext_image_copy.agent << ", "; + out << api_data.args.hsa_ext_image_copy.src_image << ", "; + out << api_data.args.hsa_ext_image_copy.src_offset << ", "; + out << api_data.args.hsa_ext_image_copy.dst_image << ", "; + out << api_data.args.hsa_ext_image_copy.dst_offset << ", "; + out << api_data.args.hsa_ext_image_copy.range; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_image_clear: { + out << "hsa_ext_image_clear("; + out << api_data.args.hsa_ext_image_clear.agent << ", "; + out << api_data.args.hsa_ext_image_clear.image << ", "; + out << api_data.args.hsa_ext_image_clear.data << ", "; + out << api_data.args.hsa_ext_image_clear.image_region; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_image_destroy: { + out << "hsa_ext_image_destroy("; + out << api_data.args.hsa_ext_image_destroy.agent << ", "; + out << api_data.args.hsa_ext_image_destroy.image; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_sampler_create: { + out << "hsa_ext_sampler_create("; + out << api_data.args.hsa_ext_sampler_create.agent << ", "; + out << api_data.args.hsa_ext_sampler_create.sampler_descriptor << ", "; + out << api_data.args.hsa_ext_sampler_create.sampler; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_sampler_destroy: { + out << "hsa_ext_sampler_destroy("; + out << api_data.args.hsa_ext_sampler_destroy.agent << ", "; + out << api_data.args.hsa_ext_sampler_destroy.sampler; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_image_get_capability_with_layout: { + out << "hsa_ext_image_get_capability_with_layout("; + out << api_data.args.hsa_ext_image_get_capability_with_layout.agent << ", "; + out << api_data.args.hsa_ext_image_get_capability_with_layout.geometry + << ", "; + out << api_data.args.hsa_ext_image_get_capability_with_layout.image_format + << ", "; + out << api_data.args.hsa_ext_image_get_capability_with_layout + .image_data_layout + << ", "; + out << api_data.args.hsa_ext_image_get_capability_with_layout + .capability_mask; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_image_data_get_info_with_layout: { + out << "hsa_ext_image_data_get_info_with_layout("; + out << api_data.args.hsa_ext_image_data_get_info_with_layout.agent << ", "; + out << api_data.args.hsa_ext_image_data_get_info_with_layout + .image_descriptor + << ", "; + out << api_data.args.hsa_ext_image_data_get_info_with_layout + .access_permission + << ", "; + out << api_data.args.hsa_ext_image_data_get_info_with_layout + .image_data_layout + << ", "; + out << api_data.args.hsa_ext_image_data_get_info_with_layout + .image_data_row_pitch + << ", "; + out << api_data.args.hsa_ext_image_data_get_info_with_layout + .image_data_slice_pitch + << ", "; + out << api_data.args.hsa_ext_image_data_get_info_with_layout + .image_data_info; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + case HSA_API_ID_hsa_ext_image_create_with_layout: { + out << "hsa_ext_image_create_with_layout("; + out << api_data.args.hsa_ext_image_create_with_layout.agent << ", "; + out << api_data.args.hsa_ext_image_create_with_layout.image_descriptor + << ", "; + out << api_data.args.hsa_ext_image_create_with_layout.image_data << ", "; + out << api_data.args.hsa_ext_image_create_with_layout.access_permission + << ", "; + out << api_data.args.hsa_ext_image_create_with_layout.image_data_layout + << ", "; + out << api_data.args.hsa_ext_image_create_with_layout.image_data_row_pitch + << ", "; + out << api_data.args.hsa_ext_image_create_with_layout.image_data_slice_pitch + << ", "; + out << api_data.args.hsa_ext_image_create_with_layout.image; + out << ") = " << api_data.hsa_status_t_retval; + break; + } + default: + out << "ERROR: unknown API"; + abort(); + } + return out; +} +#endif +#endif /* HSA_PROF_STR_H_ */ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer.h new file mode 100644 index 000000000..69446676a --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer.h @@ -0,0 +1,780 @@ +/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ + +/** \mainpage ROC Tracer API Specification + * + * \section introduction Introduction + * + * ROCtracer library, Runtimes Generic Callback/Activity APIs. + * + * The goal of the implementation is to provide a generic independent from + * specific runtime profiler to trace API and asynchronous activity. + * + * The API provides functionality for registering the runtimes API callbacks + * and asynchronous activity records pool support. + * + * \section known_limitations Known Limitations and Restrictions + * + * The ROCtracer API library implementation currently has the following + * restrictions. Future releases aim to address these restrictions. + * + * 1. The ACTIVITY_DOMAIN_HSA_OPS operations HSA_OP_ID_DISPATCH, + * HSA_OP_ID_BARRIER, and HSA_OP_ID_RESERVED1 are not currently implemented. + */ + +/** + * \file + * ROCtracer API interface. + */ + +#ifndef ROCTRACER_H_ +#define ROCTRACER_H_ + +/* Placeholder for calling convention and import/export macros */ +#if !defined(ROCTRACER_CALL) +#define ROCTRACER_CALL +#endif /* !defined (ROCTRACER_CALL) */ + +#if !defined(ROCTRACER_EXPORT_DECORATOR) +#if defined(__GNUC__) +#define ROCTRACER_EXPORT_DECORATOR __attribute__((visibility("default"))) +#elif defined(_MSC_VER) +#define ROCTRACER_EXPORT_DECORATOR __declspec(dllexport) +#endif /* defined (_MSC_VER) */ +#endif /* !defined (ROCTRACER_EXPORT_DECORATOR) */ + +#if !defined(ROCTRACER_IMPORT_DECORATOR) +#if defined(__GNUC__) +#define ROCTRACER_IMPORT_DECORATOR +#elif defined(_MSC_VER) +#define ROCTRACER_IMPORT_DECORATOR __declspec(dllimport) +#endif /* defined (_MSC_VER) */ +#endif /* !defined (ROCTRACER_IMPORT_DECORATOR) */ + +#define ROCTRACER_EXPORT ROCTRACER_EXPORT_DECORATOR ROCTRACER_CALL +#define ROCTRACER_IMPORT ROCTRACER_IMPORT_DECORATOR ROCTRACER_CALL + +#if !defined(ROCTRACER) +#if defined(ROCTRACER_EXPORTS) +#define ROCTRACER_API ROCTRACER_EXPORT +#else /* !defined (ROCTRACER_EXPORTS) */ +#define ROCTRACER_API ROCTRACER_IMPORT +#endif /* !defined (ROCTRACER_EXPORTS) */ +#endif /* !defined (ROCTRACER) */ + +#include +#include + +#include "ext/prof_protocol.h" + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/** \defgroup symbol_versions_group Symbol Versions + * + * The names used for the shared library versioned symbols. + * + * Every function is annotated with one of the version macros defined in this + * section. Each macro specifies a corresponding symbol version string. After + * dynamically loading the shared library with \p dlopen, the address of each + * function can be obtained using \p dlvsym with the name of the function and + * its corresponding symbol version string. An error will be reported by \p + * dlvsym if the installed library does not support the version for the + * function specified in this version of the interface. + * + * @{ + */ + +/** + * The function was introduced in version 4.1 of the interface and has the + * symbol version string of ``"ROCTRACER_4.1"``. + */ +#define ROCTRACER_VERSION_4_1 + +/** @} */ + +/** \defgroup versioning_group Versioning + * + * Version information about the interface and the associated installed + * library. + * + * The semantic version of the interface following semver.org rules. A client + * that uses this interface is only compatible with the installed library if + * the major version numbers match and the interface minor version number is + * less than or equal to the installed library minor version number. + * + * @{ + */ + +/** + * The major version of the interface as a macro so it can be used by the + * preprocessor. + */ +#define ROCTRACER_VERSION_MAJOR 4 + +/** + * The minor version of the interface as a macro so it can be used by the + * preprocessor. + */ +#define ROCTRACER_VERSION_MINOR 1 + +/** + * Query the major version of the installed library. + * + * Return the major version of the installed library. This can be used to + * check if it is compatible with this interface version. This function can be + * used even when the library is not initialized. + */ +ROCTRACER_API uint32_t roctracer_version_major() ROCTRACER_VERSION_4_1; + +/** + * Query the minor version of the installed library. + * + * Return the minor version of the installed library. This can be used to + * check if it is compatible with this interface version. This function can be + * used even when the library is not initialized. + */ +ROCTRACER_API uint32_t roctracer_version_minor() ROCTRACER_VERSION_4_1; + +/** @} */ + +/** \defgroup status_codes_group Status Codes + * + * Most operations return a status code to indicate success or error. + * + * @{ + */ + +/** + * ROC Tracer API status codes. + */ +typedef enum { + /** + * The function has executed successfully. + */ + ROCTRACER_STATUS_SUCCESS = 0, + /** + * A generic error has occurred. + */ + ROCTRACER_STATUS_ERROR = -1, + /** + * The domain ID is invalid. + */ + ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID = -2, + /** + * An invalid argument was given to the function. + */ + ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT = -3, + /** + * No default pool is defined. + */ + ROCTRACER_STATUS_ERROR_DEFAULT_POOL_UNDEFINED = -4, + /** + * The default pool is already defined. + */ + ROCTRACER_STATUS_ERROR_DEFAULT_POOL_ALREADY_DEFINED = -5, + /** + * Memory allocation error. + */ + ROCTRACER_STATUS_ERROR_MEMORY_ALLOCATION = -6, + /** + * External correlation ID pop mismatch. + */ + ROCTRACER_STATUS_ERROR_MISMATCHED_EXTERNAL_CORRELATION_ID = -7, + /** + * The operation is not currently implemented. This error may be reported by + * any function. Check the \ref known_limitations section to determine the + * status of the library implementation of the interface. + */ + ROCTRACER_STATUS_ERROR_NOT_IMPLEMENTED = -8, + /** + * Deprecated error code. + */ + ROCTRACER_STATUS_UNINIT = 2, + /** + * Deprecated error code. + */ + ROCTRACER_STATUS_BREAK = 3, + /** + * Deprecated error code. + */ + ROCTRACER_STATUS_BAD_DOMAIN = ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID, + /** + * Deprecated error code. + */ + ROCTRACER_STATUS_BAD_PARAMETER = ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT, + /** + * Deprecated error code. + */ + ROCTRACER_STATUS_HIP_API_ERR = 6, + /** + * Deprecated error code. + */ + ROCTRACER_STATUS_HIP_OPS_ERR = 7, + /** + * Deprecated error code. + */ + ROCTRACER_STATUS_HCC_OPS_ERR = ROCTRACER_STATUS_HIP_OPS_ERR, + /** + * Deprecated error code. + */ + ROCTRACER_STATUS_HSA_ERR = 7, + /** + * Deprecated error code. + */ + ROCTRACER_STATUS_ROCTX_ERR = 8, +} roctracer_status_t; + +/** + * Query the textual description of the last error for the current thread. + * + * Returns a NUL terminated string describing the error of the last ROC Tracer + * API call by the calling thread that did not return success. The empty + * string is returned if there is no previous error. The last error is not + * cleared. + * + * \return Return the error string. The caller owns the returned string and + * should use \p free() to deallocate it. + */ +ROCTRACER_API const char *roctracer_error_string() ROCTRACER_VERSION_4_1; + +/** @} */ + +/** \defgroup domain_group Traced Runtime Domains + * + * The ROC Tracer API can trace multiple runtime libraries. Each library can + * have API operations and asynchronous operations that can be traced. + * + * @{ + */ + +/** + * Enumeration of domains that can be traced. + */ +typedef activity_domain_t roctracer_domain_t; + +/** + * Query textual name of an operation of a domain. + * + * @param[in] domain Domain being queried. + * + * @param[in] op Operation within \p domain. + * + * @param[in] kind \todo Define kind. + * + * @return Returns the NUL terminated string for the operation name, or NULL if + * the domain or operation are invalid. The string is owned by the ROC Tracer + * library. + */ +ROCTRACER_API const char * +roctracer_op_string(uint32_t domain, uint32_t op, + uint32_t kind) ROCTRACER_VERSION_4_1; + +/** + * Query the operation code given a domain and the name of an operation. + * + * @param[in] domain The domain being queried. + * + * @param[in] str The NUL terminated name of the operation name being queried. + * + * @param[out] op The operation code. + * + * @param[out] kind If not NULL then the operation kind code. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. \p op and \p kind have been updated. + * + * @retval ::ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT The \p op is invalid for + * \p domain. + * + * @retval ::ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID The domain is invalid or + * not supported. + */ +ROCTRACER_API roctracer_status_t +roctracer_op_code(uint32_t domain, const char *str, uint32_t *op, + uint32_t *kind) ROCTRACER_VERSION_4_1; + +/** + * Set the properties of a domain. + * + * @param[in] domain The domain. + * + * @param[in] properties The properties. Each domain defines its own type for + * the properties. Some domains require the properties to be set before they + * can be enabled. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + */ +ROCTRACER_API roctracer_status_t roctracer_set_properties( + roctracer_domain_t domain, void *properties) ROCTRACER_VERSION_4_1; + +/** @} */ + +/** \defgroup callback_api_group Callback API + * + * ROC tracer provides support for runtime API callbacks and activity + * records logging. The API callbacks provide the API calls arguments and are + * called on different phases, on enter, on exit, on kernel completion. + * + * @{ + */ + +/** + * Runtime API callback type. + * + * The callback that will be invoked when an enabled runtime API is called. The + * callback is invoked on entry and on exit. + */ +typedef activity_rtapi_callback_t roctracer_rtapi_callback_t; + +/** + * Enable runtime API callback for a specific operation of a domain. + * + * @param domain The domain. + * + * @param op The operation ID in \p domain. + * + * @param callback The callback to invoke each time the operation is performed + * on entry and exit. + * + * @param arg Value to pass as last argument of \p callback. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + * + * @retval ::ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID \p domain is invalid. + * + * @retval ::ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT \p op is invalid for \p + * domain. + */ +ROCTRACER_API roctracer_status_t roctracer_enable_op_callback( + activity_domain_t domain, uint32_t op, activity_rtapi_callback_t callback, + void *arg) ROCTRACER_VERSION_4_1; + +/** + * Enable runtime API callback for all operations of a domain. + * + * @param domain The domain + * + * @param callback The callback to invoke each time the operation is performed + * on entry and exit. + * + * @param arg Value to pass as last argument of \p callback. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + * + * @retval ::ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID \p domain is invalid. + */ +ROCTRACER_API roctracer_status_t roctracer_enable_domain_callback( + activity_domain_t domain, activity_rtapi_callback_t callback, + void *arg) ROCTRACER_VERSION_4_1; + +/** + * Disable runtime API callback for a specific operation of a domain. + * + * @param domain The domain + * + * @param op The operation in \p domain. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + * + * @retval ::ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID \p domain is invalid. + * + * @retval ::ROCTRACER_STATUS_ERROR_INVALID_ARGUMENT \p op is invalid for \p + * domain. + */ +ROCTRACER_API roctracer_status_t roctracer_disable_op_callback( + activity_domain_t domain, uint32_t op) ROCTRACER_VERSION_4_1; + +/** + * Disable runtime API callback for all operations of a domain. + * + * @param domain The domain + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + * + * @retval ::ROCTRACER_STATUS_ERROR_INVALID_DOMAIN_ID \p domain is invalid. + */ +ROCTRACER_API roctracer_status_t roctracer_disable_domain_callback( + activity_domain_t domain) ROCTRACER_VERSION_4_1; + +/** @} */ + +/** \defgroup activity_api_group Activity API + * + * The activity records are asynchronously logged to the pool and can be + * associated with the respective API callbacks using the correlation ID. + * Activity API can be used to enable collecting of the records with + * timestamping data for API calls and the kernel submits. + * + * @{ + */ + +/** + * Activity record. + * + * Asynchronous activity events generate activity records. + */ +typedef activity_record_t roctracer_record_t; + +/** + * Get a pointer to the next activity record. + * + * A memory pool generates buffers that contain multiple activity records. + * This function steps to the next activity record. + * + * @param[in] record Pointer to ac activity record in a memory pool buffer. + * + * @param[out] next Pointer to the following activity record in the memory pool + * buffer. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + */ +ROCTRACER_API roctracer_status_t +roctracer_next_record(const activity_record_t *record, + const activity_record_t **next) ROCTRACER_VERSION_4_1; + +/** + * Memory pool allocator callback. + * + * If \p *ptr is NULL, then allocate memory of \p size bytes and save address + * in \p *ptr. + * + * If \p *ptr is non-NULL and size is non-0, then reallocate the memory at \p + * *ptr with size \p size and save the address in \p *ptr. The memory will have + * been allocated by the same callback. + * + * If \p *ptr is non-NULL and size is 0, then deallocate the memory at \p *ptr. + * The memory will have been allocated by the same callback. + * + * \p size is the size of the memory allocation or reallocation, or 0 if + * deallocating. + * + * \p arg Argument provided in the ::roctracer_properties_t passed to the + * ::roctracer_open_pool function. + */ +typedef void (*roctracer_allocator_t)(char **ptr, size_t size, void *arg); + +/** + * Memory pool buffer callback. + * + * The callback that will be invoked when a memory pool buffer becomes full or + * is flushed. + * + * \p begin pointer to first entry entry in the buffer. + * + * \p end pointer to one past the end entry in the buffer. + * + * \p arg the argument specified when the callback was defined. + */ +typedef void (*roctracer_buffer_callback_t)(const char *begin, const char *end, + void *arg); + +/** + * Memory pool properties. + * + * Defines the properties when a tracer memory pool is created. + */ +typedef struct { + /** + * ROC Tracer mode. + */ + uint32_t mode; + + /** + * Size of buffer in bytes. + */ + size_t buffer_size; + + /** + * The allocator function to use to allocate and deallocate the buffer. If + * NULL then \p malloc, \p realloc, and \p free are used. + */ + roctracer_allocator_t alloc_fun; + + /** + * The argument to pass when invoking the \p alloc_fun allocator. + */ + void *alloc_arg; + + /** + * The function to call when a buffer becomes full or is flushed. + */ + roctracer_buffer_callback_t buffer_callback_fun; + + /** + * The argument to pass when invoking the \p buffer_callback_fun callback. + */ + void *buffer_callback_arg; +} roctracer_properties_t; + +/** + * Tracer memory pool type. + */ +typedef void roctracer_pool_t; + +/** + * Create tracer memory pool. + * + * If \p pool is not NULL, returns the created memory pool. Does not change the + * default memory pool. + * + * If \p pool is NULL, sets the default memory pool to the created pool if not + * already defined. Otherwise, return an error. + * + * @param[in] properties Tracer memory pool properties. + * + * @param[out] pool Tracer memory pool created if not NULL. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + * + * @retval ROCTRACER_STATUS_ERROR_DEFAULT_POOL_ALREADY_DEFINED \p pool is NULL + * and the default pool is already defined. Unable to create the pool. + * + * @retval ROCTRACER_STATUS_ERROR_MEMORY_ALLOCATION Unable to allocate memory + * for the \p pool. Unable to create the pool. + */ +ROCTRACER_API roctracer_status_t +roctracer_open_pool_expl(const roctracer_properties_t *properties, + roctracer_pool_t **pool) ROCTRACER_VERSION_4_1; + +/** + * Create tracer memory pool. + * + * Sets the default memory pool to the created pool if not already defined. + * Otherwise, return an error. + * + * @param[in] properties Tracer memory pool properties. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + * + * @retval ROCTRACER_STATUS_ERROR_DEFAULT_POOL_ALREADY_DEFINED The default pool + * is already defined. Unable to create the pool. + * + * @retval ROCTRACER_STATUS_ERROR_MEMORY_ALLOCATION Unable to allocate memory + * for the \p pool. Unable to create the pool. + */ +ROCTRACER_API roctracer_status_t roctracer_open_pool( + const roctracer_properties_t *properties) ROCTRACER_VERSION_4_1; + +/** + * Close tracer memory pool. + * + * All enabled activities that use the pool must have completed writing to the + * pool, before deleting the pool. Deleting a pool automatically disables any + * activities that specify the pool, and flushes it. + * + * @param[in] pool Memory pool to close. If NULL, the default memory pool is + * closed if defined. The default memory pool is set to undefined if closed. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully or pool was NULL and there is no default pool. + */ +ROCTRACER_API roctracer_status_t +roctracer_close_pool_expl(roctracer_pool_t *pool) ROCTRACER_VERSION_4_1; + +/** + * Close default tracer memory pool, if defined, and set to undefined. + * + * All enabled activities that use the pool must have completed writing to the + * pool, before deleting the pool. Deleting a pool automatically disables any + * activities that specify the pool, and flushes it. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully or there is no default pool. + */ +ROCTRACER_API roctracer_status_t roctracer_close_pool() ROCTRACER_VERSION_4_1; + +/** + * Query and set the default memory pool. + * + * @param[in] pool If not NULL, change the current default pool to \p pool. If + * NULL, the default pool is not changed. + * + * @return Return the current default memory pool before any change, or NULL if + * none is defined. + */ +ROCTRACER_API roctracer_pool_t * +roctracer_default_pool_expl(roctracer_pool_t *pool) ROCTRACER_VERSION_4_1; + +/** + * Query the current default memory pool. + * + * @return Return the current default memory pool, or NULL is none is defined. + */ +ROCTRACER_API roctracer_pool_t *roctracer_default_pool() ROCTRACER_VERSION_4_1; + +/** + * Enable activity record logging for a specified operation of a domain + * providing a memory pool. + * + * @param[in] domain The domain. + * + * @param[in] op The activity operation ID in \p domain. + * + * @param[in] pool The memory pool to write the activity record. If NULL, use + * the default memory pool. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + * + * @retval ROCTRACER_STATUS_ERROR \p pool is NULL and no default pool is + * defined. + */ +ROCTRACER_API roctracer_status_t +roctracer_enable_op_activity_expl(activity_domain_t domain, uint32_t op, + roctracer_pool_t *pool) ROCTRACER_VERSION_4_1; + +/** + * Enable activity record logging for a specified operation of a domain using + * the default memory pool. + * + * @param[in] domain The domain. + * + * @param[in] op The activity operation ID in \p domain. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + * + * @retval ROCTRACER_STATUS_ERROR No default pool is defined. + */ +ROCTRACER_API roctracer_status_t roctracer_enable_op_activity( + activity_domain_t domain, uint32_t op) ROCTRACER_VERSION_4_1; + +/** + * Enable activity record logging for all operations of a domain providing a + * memory pool. + * + * @param[in] domain The domain. + * + * @param[in] pool The memory pool to write the activity record. If NULL, use + * the default memory pool. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + * + * @retval ROCTRACER_STATUS_ERROR \p pool is NULL and no default pool is + * defined. + */ +ROCTRACER_API roctracer_status_t roctracer_enable_domain_activity_expl( + activity_domain_t domain, roctracer_pool_t *pool) ROCTRACER_VERSION_4_1; + +/** + * Enable activity record logging for all operations of a domain using the + * default memory pool. + * + * @param[in] domain The domain. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + * + * @retval ROCTRACER_STATUS_ERROR No default pool is defined. + */ +ROCTRACER_API roctracer_status_t roctracer_enable_domain_activity( + activity_domain_t domain) ROCTRACER_VERSION_4_1; + +/** + * Disable activity record logging for a specified operation of a domain. + * + * @param[in] domain The domain. + * + * @param[in] op The activity operation ID in \p domain. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + */ +ROCTRACER_API roctracer_status_t roctracer_disable_op_activity( + activity_domain_t domain, uint32_t op) ROCTRACER_VERSION_4_1; + +/** + * Disable activity record logging for all operations of a domain. + * + * @param[in] domain The domain. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + */ +ROCTRACER_API roctracer_status_t roctracer_disable_domain_activity( + activity_domain_t domain) ROCTRACER_VERSION_4_1; + +/** + * Flush available activity records for a memory pool. + * + * If flushing encounters an activity record still being written, flushing + * stops. Use a subsequent flush when the record has completed being written to + * resume the flush. + * + * @param[in] pool The memory pool to flush. If NULL, flushes the default + * memory pool. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + */ +ROCTRACER_API roctracer_status_t +roctracer_flush_activity_expl(roctracer_pool_t *pool) ROCTRACER_VERSION_4_1; + +/** + * Flush available activity records for the default memory pool. + * + * If flushing encounters an activity record still being written, flushing + * stops. Use a subsequent flush when the record has completed being written to + * resume the flush. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + */ +ROCTRACER_API roctracer_status_t roctracer_flush_activity() + ROCTRACER_VERSION_4_1; + +/** @} */ + +/** \defgroup timestamp_group Timestamp Operations + * + * + * + * @{ + */ + +/** + * Get the system clock timestamp. + * + * @param[out] timestamp The system clock timestamp in nano seconds. + * + * @retval ::ROCTRACER_STATUS_SUCCESS The function has been executed + * successfully. + */ +ROCTRACER_API roctracer_status_t +roctracer_get_timestamp(roctracer_timestamp_t *timestamp) ROCTRACER_VERSION_4_1; + +/** @} */ + +#ifdef __cplusplus +} /* extern "C" block */ +#endif /* __cplusplus */ + +#endif /* ROCTRACER_H_ */ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_ext.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_ext.h new file mode 100644 index 000000000..66b2972c5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_ext.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ + +//////////////////////////////////////////////////////////////////////////////// +// +// ROC Tracer Extension API +// +// The API provides functionality for application annotation with event and +// external ranges correlation +// +//////////////////////////////////////////////////////////////////////////////// + +#ifndef ROCTRACER_EXT_H_ +#define ROCTRACER_EXT_H_ + +#include "roctracer.h" + +/* Extension API opcodes */ +typedef enum { + ACTIVITY_EXT_OP_MARK = 0, + ACTIVITY_EXT_OP_EXTERN_ID = 1 +} activity_ext_op_t; + +typedef void (*roctracer_start_cb_t)(); +typedef void (*roctracer_stop_cb_t)(); +typedef struct { + roctracer_start_cb_t start_cb; + roctracer_stop_cb_t stop_cb; +} roctracer_ext_properties_t; + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//////////////////////////////////////////////////////////////////////////////// +// Application annotation API + +// Tracing start API +void ROCTRACER_API roctracer_start() ROCTRACER_VERSION_4_1; + +// Tracing stop API +void ROCTRACER_API roctracer_stop() ROCTRACER_VERSION_4_1; + +//////////////////////////////////////////////////////////////////////////////// +// External correlation id API + +// Notifies that the calling thread is entering an external API region. +// Push an external correlation id for the calling thread. +roctracer_status_t ROCTRACER_API +roctracer_activity_push_external_correlation_id(activity_correlation_id_t id) + ROCTRACER_VERSION_4_1; + +// Notifies that the calling thread is leaving an external API region. +// Pop an external correlation id for the calling thread. +// 'lastId' returns the last external correlation if not NULL +roctracer_status_t ROCTRACER_API roctracer_activity_pop_external_correlation_id( + activity_correlation_id_t *last_id) ROCTRACER_VERSION_4_1; + +#ifdef __cplusplus +} // extern "C" block +#endif // __cplusplus + +#endif // ROCTRACER_EXT_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_hcc.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_hcc.h new file mode 100644 index 000000000..702be14f0 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_hcc.h @@ -0,0 +1,24 @@ +/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ + +#pragma message( \ + "This file has been deprecated and marked for removal. Please use roctracer_hip.h instead.") + +#include "roctracer_hip.h" diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_hip.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_hip.h new file mode 100644 index 000000000..74376c734 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_hip.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ + +#ifndef ROCTRACER_HIP_H_ +#define ROCTRACER_HIP_H_ + +#include "roctracer.h" + +#include +#include +#include + +typedef enum { + HIP_OP_ID_DISPATCH = 0, + HIP_OP_ID_COPY = 1, + HIP_OP_ID_BARRIER = 2, + HIP_OP_ID_NUMBER = 3 +} hip_op_id_t; + +#endif // ROCTRACER_HIP_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_hsa.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_hsa.h new file mode 100644 index 000000000..14d4250e0 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_hsa.h @@ -0,0 +1,112 @@ +/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ + +#ifndef ROCTRACER_HSA_H_ +#define ROCTRACER_HSA_H_ + +#include "roctracer.h" + +#include "hsa_ostream_ops.h" +#include "hsa_prof_str.h" +#include +#include + +// HSA OP ID enumeration +enum hsa_op_id_t { + HSA_OP_ID_DISPATCH = 0, + HSA_OP_ID_COPY = 1, + HSA_OP_ID_BARRIER = 2, + HSA_OP_ID_RESERVED1 = 3, + HSA_OP_ID_NUMBER +}; + +// HSA EVT ID enumeration +enum hsa_evt_id_t { + HSA_EVT_ID_ALLOCATE = 0, // Memory allocate callback + HSA_EVT_ID_DEVICE = 1, // Device assign callback + HSA_EVT_ID_MEMCOPY = 2, // Memcopy callback + HSA_EVT_ID_SUBMIT = 3, // Packet submission callback + HSA_EVT_ID_KSYMBOL = 4, // Loading/unloading of kernel symbol + HSA_EVT_ID_CODEOBJ = 5, // Loading/unloading of device code object + HSA_EVT_ID_NUMBER +}; + +struct hsa_ops_properties_t { + void *reserved1[4]; +}; + +// HSA EVT data type +typedef struct { + union { + struct { + const void *ptr; // allocated area ptr + size_t size; // allocated area size, zero size means 'free' callback + hsa_amd_segment_t segment; // allocated area's memory segment type + hsa_amd_memory_pool_global_flag_t + global_flag; // allocated area's memory global flag + int is_code; // equal to 1 if code is allocated + } allocate; + + struct { + hsa_device_type_t type; // type of assigned device + uint32_t id; // id of assigned device + hsa_agent_t agent; // device HSA agent handle + const void *ptr; // ptr the device is assigned to + } device; + + struct { + const void *dst; // memcopy dst ptr + const void *src; // memcopy src ptr + size_t size; // memcopy size bytes + } memcopy; + + struct { + const void *packet; // submitted to GPU packet + const char + *kernel_name; // kernel name, NULL if not a kernel dispatch packet + hsa_queue_t *queue; // HSA queue the packet was submitted to + uint32_t device_type; // type of device the packet is submitted to + uint32_t device_id; // id of device the packet is submitted to + } submit; + + struct { + uint64_t object; // kernel symbol object + const char *name; // kernel symbol name + uint32_t name_length; // kernel symbol name length + int unload; // symbol executable destroy + } ksymbol; + + struct { + uint32_t storage_type; // code object storage type + int storage_file; // origin file descriptor + uint64_t memory_base; // origin memory base + uint64_t memory_size; // origin memory size + uint64_t load_base; // code object load base + uint64_t load_size; // code object load size + uint64_t load_delta; // code object load size + uint32_t uri_length; // URI string length (not including the terminating + // NUL character) + const char *uri; // URI string + int unload; // unload flag + } codeobj; + }; +} hsa_evt_data_t; + +#endif // ROCTRACER_HSA_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_plugin.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_plugin.h new file mode 100644 index 000000000..b4a47d56b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_plugin.h @@ -0,0 +1,140 @@ +/* Copyright (c) 2022 Advanced Micro Devices, Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ + +/** \section roctracer_plugin_api ROCtracer Plugin API + * + * The ROCtracer Plugin API is used by the ROCtracer Tool to output all tracing + * information. Different implementations of the ROCtracer Plugin API can be + * developed that output the tracing data in different formats. + * The ROCtracer Tool can be configured to load a specific library that + * supports the user desired format. + * + * The API is not thread safe. It is the responsibility of the ROCtracer Tool + * to ensure the operations are synchronized and not called concurrently. There + * is no requirement for the ROCtracer Tool to report trace data in any + * specific order. If the format supported by plugin requires specific + * ordering, it is the responsibility of the plugin implementation to perform + * any necessary sorting. + */ + +/** + * \file + * ROCtracer Tool Plugin API interface. + */ + +#ifndef ROCTRACER_PLUGIN_H_ +#define ROCTRACER_PLUGIN_H_ + +#include "roctracer.h" + +#include + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/** \defgroup initialization_group Initialization and Finalization + * + * The ROCtracer Plugin API must be initialized before using any of the + * operations to report trace data, and finalized after the last trace data has + * been reported. + * + * @{ + */ + +/** + * Initialize plugin. + * + * Must be called before any other operation. + * + * @param[in] roctracer_major_version The major version of the ROCtracer API + * being used by the ROCtracer Tool. An error is reported if this does not + * match the major version of the ROCtracer API used to build the plugin + * library. This ensures compatibility of the trace data format. + * + * @param[in] roctracer_minor_version The minor version of the ROCtracer API + * being used by the ROCtracer Tool. An error is reported if the + * \p roctracer_major_version matches and this is greater than the minor + * version of the ROCtracer API used to build the plugin library. This ensures + * compatibility of the trace data format. + * + * @return Returns 0 on success and -1 on error. + */ +ROCTRACER_EXPORT int +roctracer_plugin_initialize(uint32_t roctracer_major_version, + uint32_t roctracer_minor_version); + +/** + * Finalize plugin. + * + * This must be called after ::roctracer_plugin_initialize and after all trace + * data has been reported by ::roctracer_plugin_write_callback_record and + * ::roctracer_plugin_write_activity_records. + */ +ROCTRACER_EXPORT void roctracer_plugin_finalize(); + +/** @} */ + +/** \defgroup trace_record_write_functions Trace data reporting + * + * Operations to output trace data. + * + * @{ + */ + +/** + * Report a single callback trace data. + * + * @param[in] record Primarily domain independent trace data. + * + * @param[in] callback_data Domain specific trace data. The type of this + * argument depends on the values of \p record.domain. + * + * @return Returns 0 on success and -1 on error. + */ +ROCTRACER_EXPORT int +roctracer_plugin_write_callback_record(const roctracer_record_t *record, + const void *callback_data); + +/** + * Report a range of activity trace data. + * + * Reports a range of primarily domain independent trace data. The range is + * specified by a pointer to the first record and a pointer to one past the + * last record. ::roctracer_next_record is used to iterate the range in forward + * order. + * + * @param[in] begin Pointer to the first record. + * + * @param[in] end Pointer to one past the last record. + * + * @return Returns 0 on success and -1 on error. + */ +ROCTRACER_EXPORT int +roctracer_plugin_write_activity_records(const roctracer_record_t *begin, + const roctracer_record_t *end); + +/** @} */ + +#ifdef __cplusplus +} /* extern "C" */ +#endif /* __cplusplus */ + +#endif /* ROCTRACER_PLUGIN_H_ */ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_roctx.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_roctx.h new file mode 100644 index 000000000..b1c2d4dba --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctracer_roctx.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ + +#ifndef ROCTRACER_ROCTX_H_ +#define ROCTRACER_ROCTX_H_ + +#include "roctx.h" + +/** + * ROCTX API ID enumeration + */ +enum roctx_api_id_t { + ROCTX_API_ID_roctxMarkA = 0, + ROCTX_API_ID_roctxRangePushA = 1, + ROCTX_API_ID_roctxRangePop = 2, + ROCTX_API_ID_roctxRangeStartA = 3, + ROCTX_API_ID_roctxRangeStop = 4, + ROCTX_API_ID_NUMBER, +}; + +/** + * ROCTX callbacks data type + */ +typedef struct roctx_api_data_s { + union { + struct { + const char *message; + roctx_range_id_t id; + }; + struct { + const char *message; + } roctxMarkA; + struct { + const char *message; + } roctxRangePushA; + struct { + const char *message; + } roctxRangePop; + struct { + const char *message; + roctx_range_id_t id; + } roctxRangeStartA; + struct { + const char *message; + roctx_range_id_t id; + } roctxRangeStop; + } args; +} roctx_api_data_t; + +#endif /* ROCTRACER_ROCTX_H_ */ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctx.h b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctx.h new file mode 100644 index 000000000..142805794 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/backend/include/roctracer/roctx.h @@ -0,0 +1,229 @@ +/* Copyright (c) 2018-2022 Advanced Micro Devices, Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ + +/** \mainpage ROCTX API Specification + * + * \section introduction Introduction + * ROCTX is a library that implements the AMD code annotation API. It provides + * the support necessary to annotate events and code ranges in applications. + */ + +/** + * \file + * ROCTX API interface. + */ + +#ifndef ROCTX_H_ +#define ROCTX_H_ 1 + +/* Placeholder for calling convention and import/export macros */ +#if !defined(ROCTX_CALL) +#define ROCTX_CALL +#endif /* !defined (ROCTX_CALL) */ + +#if !defined(ROCTX_EXPORT_DECORATOR) +#if defined(__GNUC__) +#define ROCTX_EXPORT_DECORATOR __attribute__((visibility("default"))) +#elif defined(_MSC_VER) +#define ROCTX_EXPORT_DECORATOR __declspec(dllexport) +#endif /* defined (_MSC_VER) */ +#endif /* !defined (ROCTX_EXPORT_DECORATOR) */ + +#if !defined(ROCTX_IMPORT_DECORATOR) +#if defined(__GNUC__) +#define ROCTX_IMPORT_DECORATOR +#elif defined(_MSC_VER) +#define ROCTX_IMPORT_DECORATOR __declspec(dllimport) +#endif /* defined (_MSC_VER) */ +#endif /* !defined (ROCTX_IMPORT_DECORATOR) */ + +#define ROCTX_EXPORT ROCTX_EXPORT_DECORATOR ROCTX_CALL +#define ROCTX_IMPORT ROCTX_IMPORT_DECORATOR ROCTX_CALL + +#if !defined(ROCTX) +#if defined(ROCTX_EXPORTS) +#define ROCTX_API ROCTX_EXPORT +#else /* !defined (ROCTX_EXPORTS) */ +#define ROCTX_API ROCTX_IMPORT +#endif /* !defined (ROCTX_EXPORTS) */ +#endif /* !defined (ROCTX) */ + +#include + +#if defined(__cplusplus) +extern "C" { +#endif /* defined(__cplusplus) */ + +/** \defgroup symbol_versions_group Symbol Versions + * + * The names used for the shared library versioned symbols. + * + * Every function is annotated with one of the version macros defined in this + * section. Each macro specifies a corresponding symbol version string. After + * dynamically loading the shared library with \p dlopen, the address of each + * function can be obtained using \p dlvsym with the name of the function and + * its corresponding symbol version string. An error will be reported by \p + * dlvsym if the installed library does not support the version for the + * function specified in this version of the interface. + * + * @{ + */ + +/** + * The function was introduced in version 4.1 of the interface and has the + * symbol version string of ``"ROCTX_4.1"``. + */ +#define ROCTX_VERSION_4_1 + +/** @} */ + +/** \defgroup versioning_group Versioning + * + * Version information about the interface and the associated installed + * library. + * + * @{ + */ + +/** + * The semantic version of the interface following + * [semver.org][semver] rules. + * + * A client that uses this interface is only compatible with the installed + * library if the major version numbers match and the interface minor version + * number is less than or equal to the installed library minor version number. + */ + +/** + * The major version of the interface as a macro so it can be used by the + * preprocessor. + */ +#define ROCTX_VERSION_MAJOR 4 + +/** + * The minor version of the interface as a macro so it can be used by the + * preprocessor. + */ +#define ROCTX_VERSION_MINOR 1 + +/** + * Query the major version of the installed library. + * + * Return the major version of the installed library. This can be used to check + * if it is compatible with this interface version. + * + * \return Returns the major version number. + */ +ROCTX_API uint32_t roctx_version_major() ROCTX_VERSION_4_1; + +/** + * Query the minor version of the installed library. + * + * Return the minor version of the installed library. This can be used to check + * if it is compatible with this interface version. + * + * \return Returns the minor version number. + */ +ROCTX_API uint32_t roctx_version_minor() ROCTX_VERSION_4_1; + +/** @} */ + +/** \defgroup marker_group ROCTX Markers + * + * Marker annotations are used to describe events in a ROCm application. + * + * @{ + */ + +/** + * Mark an event. + * + * \param[in] message The message associated with the event. + */ +ROCTX_API void roctxMarkA(const char *message) ROCTX_VERSION_4_1; +#define roctxMark(message) roctxMarkA(message) + +/** @} */ + +/** \defgroup range_group ROCTX Ranges + * + * Range annotations are used to describe events in a ROCm application. + * + * @{ + */ + +/** + * Start a new nested range. + * + * Nested ranges are stacked and local to the current CPU thread. + * + * \param[in] message The message associated with this range. + * + * \return Returns the level this nested range is started at. Nested range + * levels are 0 based. + */ +ROCTX_API int roctxRangePushA(const char *message) ROCTX_VERSION_4_1; +#define roctxRangePush(message) roctxRangePushA(message) + +/** + * Stop the current nested range. + * + * Stop the current nested range, and pop it from the stack. If a nested range + * was active before the last one was started, it becomes again the current + * nested range. + * + * \return Returns the level the stopped nested range was started at, or a + * negative value if there was no nested range active. + */ +ROCTX_API int roctxRangePop() ROCTX_VERSION_4_1; + +/** + * ROCTX range ID. + * + * This is the range ID used to identify start/end ranges. + */ +typedef uint64_t roctx_range_id_t; + +/** + * Starts a process range. + * + * Start/stop ranges can be started and stopped in different threads. Each + * timespan is assigned a unique range ID. + * + * \param[in] message The message associated with this range. + * + * \return Returns the ID of the new range. + */ +ROCTX_API roctx_range_id_t roctxRangeStartA(const char *message) + ROCTX_VERSION_4_1; +#define roctxRangeStart(message) roctxRangeStartA(message) + +/** + * Stop a process range. + */ +ROCTX_API void roctxRangeStop(roctx_range_id_t id) ROCTX_VERSION_4_1; + +/** @} */ + +#if defined(__cplusplus) +} /* extern "C" */ +#endif /* defined (__cplusplus) */ + +#endif /* ROCTX_H_ */ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/lib/asanrtl.bc b/third_party/enflame/include/triton/third_party/amd/backend/lib/asanrtl.bc new file mode 100644 index 000000000..eb4432074 Binary files /dev/null and b/third_party/enflame/include/triton/third_party/amd/backend/lib/asanrtl.bc differ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/lib/ockl.bc b/third_party/enflame/include/triton/third_party/amd/backend/lib/ockl.bc new file mode 100644 index 000000000..be7400e97 Binary files /dev/null and b/third_party/enflame/include/triton/third_party/amd/backend/lib/ockl.bc differ diff --git a/third_party/enflame/include/triton/third_party/amd/backend/lib/ocml.bc b/third_party/enflame/include/triton/third_party/amd/backend/lib/ocml.bc new file mode 100644 index 000000000..7c5b5ac9b Binary files /dev/null and b/third_party/enflame/include/triton/third_party/amd/backend/lib/ocml.bc differ diff --git a/third_party/enflame/include/triton/third_party/amd/include/Analysis/RangeAnalysis.h b/third_party/enflame/include/triton/third_party/amd/include/Analysis/RangeAnalysis.h new file mode 100644 index 000000000..5d48b96f8 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/Analysis/RangeAnalysis.h @@ -0,0 +1,125 @@ +#ifndef TRITONAMD_ANALYSIS_RANGE_ANALYSIS_H +#define TRITONAMD_ANALYSIS_RANGE_ANALYSIS_H + +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Interfaces/LoopLikeInterface.h" + +namespace mlir::triton::AMD { +/// This struct (analysis) adapt's upstream's IntegerRangeAnalysis (inferring +/// lower/upperbounds on integer constants) to our needs. +/// Specifically there are 2 points of extension: +/// +/// 1. Support for GetProgramIdOp, MakeRangeOp, SplatOp, ExpandDimsOp. *Note*, +/// upstream already supports range inference for shaped types such as tensors +/// (here we just implement effectively implement the interfaces for our ops). +/// * Upstream's semantics for "range of shape type" is union over ranges of +/// elements. +/// * We do not use tablegen to implement +/// DeclareOpInterfaceMethods +/// in order to keep the entire implementation contained/encapsulated. +/// +/// 2. Support for inference "through loops". Upstream's analysis conservatively +/// inferences [min_int, max_int] for loop carried values (and therefore loop +/// body values). Here we attempt to do better by analysis the loop bounds and +/// "abstractly interpreting" the loop when loop bounds are statically known. +/// See visitRegionSuccessors. +struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis { + using dataflow::IntegerRangeAnalysis::IntegerRangeAnalysis; + TritonIntegerRangeAnalysis( + DataFlowSolver &solver, + const DenseMap> &assumptions) + : dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions) {} + + void setToEntryState(dataflow::IntegerValueRangeLattice *lattice) override; + + LogicalResult visitOperation( + Operation *op, + ArrayRef operands, + ArrayRef resultsLattices) override; + + /// This method (which overloads + /// AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors) + /// implements "abstract interpretation" of loops with statically known bounds + /// in order to infer tight ranges for loop carried values (and therefore loop + /// body values). By "abstract interpretation" we mean lattice states are + /// propagated to all region successors N times, where N is the total trip + /// count of the loop. Recall for scf.for, both the loop itself and the users + /// of the loop successors. Thus, after N propagations both loop body values + /// and users of loop results will have accurate ranges (assuming we have + /// implemented support for range analysis on the ops). + /// *Note*, this implementation is majority similar to + /// AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors + /// (so check there for more explanation/insight) and basically only does two + /// things differently: + /// + /// 1. If the branch op is a loop (LoopLikeOpInterface) then we attempt to + /// compute its total trip count (nested loop trip counts multiply) and + /// initialize a visit count to 0. Note, due to how Dataflow analysis works we + /// have to actually visit the loop N times for each iter_arg (each argument + /// lattice) so we actually track visit count for (loop, arg) not just (loop). + /// + /// 2. Before propagating, we check if we have propagated for (loop, arg) >= N + /// times. If so, we do not propagate (and thus the traversal converges/ends). + /// + /// Note, for loops where the trip count cannot be inferred *and* loops with a + /// total trip count larger than `kDefaultMaxTripCount`, fallback to + /// upstream's conservative inference (i.e., we infer [min_int, max_int]) for + /// the loop operands and all users and all users of the results of the loop. + void visitRegionSuccessors( + ProgramPoint *point, RegionBranchOpInterface branch, + RegionBranchPoint successor, + ArrayRef abstractLattices) override; + + /// Collect all operands that participate in assumptions (see description of + /// `assumptions` field below) under the rootOp. By default, operands that can + /// be folded to constants are excluded. + static DenseMap> + collectAssumptions(Operation *rootOp, bool filterConstants = true); + + /// Construct the tightest/narrowest range possible using all the assumptions + /// that `anchor` participates in. For example, the pattern + /// %assumesltlhs = arith.cmpi sge, %K, %c0 : i32 + /// llvm.intr.assume %assumesltlhs : i1 + /// %assumesltlhs = arith.cmpi slt, %K, %c128 : i32 + /// llvm.intr.assume %assumesltlhs : i1 + /// for %K, will produce a final range + /// [0, 2147483647] ∩ [-2147483648, 128] = [0, 128] + std::optional maybeGetAssumedRange(Value anchor) const; + + /// Trip counts of all loops with static loop bounds contained under the root + /// operation being analyzed. Note, nested loops have trip counts computed as + /// a product of enclosing loops; i.e. for + /// scf.for i = 1 to 10 + /// scf.for j = 1 to 10 + /// the trip count of the outer loop (on i) is 10 but the trip count of the + /// inner loop (on j) is 100. + llvm::SmallDenseMap loopTripCounts; + + /// Visit counts tabulating how many times each lattice has been propagated + /// through each loop. This is used in visitRegionSuccessors to end + /// propagation when loopVisits[loop, lattice] reaches loopTripCounts[loop]. + llvm::SmallDenseMap< + std::pair, + int64_t> + loopVisits; + + /// `assumptions` maps from values to (possibly) any operations that satisfy + /// the pattern + /// %assumesltlhs = arith.cmpi sge, %K, %c0 : i32 + /// llvm.intr.assume %assumesltlhs : i1 + /// %assumesltlhs = arith.cmpi slt, %K, %c128 : i32 + /// llvm.intr.assume %assumesltlhs : i1 + /// If one uses collectAssumptions below then `assumptions` will look like + /// %K -> {arith.cmpi slt..., arith.cmpi sge}. + llvm::DenseMap> assumptions; +}; + +std::optional> +collectRanges(const DataFlowSolver &solver, ValueRange values); +bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp); + +} // namespace mlir::triton::AMD + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/include/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/include/CMakeLists.txt new file mode 100644 index 000000000..08707d601 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Dialect) +add_subdirectory(TritonAMDGPUToLLVM) +add_subdirectory(TritonAMDGPUTransforms) diff --git a/third_party/enflame/include/triton/third_party/amd/include/Dialect/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/include/Dialect/CMakeLists.txt new file mode 100644 index 000000000..4f9163bdf --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonAMDGPU) diff --git a/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..094ecfc7d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonAMDGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=amdgpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=amdgpu) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +add_mlir_doc(TritonAMDGPUDialect TritonAMDGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonAMDGPUOps TritonAMDGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonAMDGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonAMDGPUAttrDefs.td) +mlir_tablegen(TritonAMDGPUEnums.h.inc -gen-enum-decls) +mlir_tablegen(TritonAMDGPUEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(TritonAMDGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonAMDGPUAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(TritonAMDGPUAttrDefsIncGen) diff --git a/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h new file mode 100644 index 000000000..9d91da924 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_IR_DIALECT_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Traits.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// clang-format off +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" +#include "amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.h.inc" +// clang-format on + +#define GET_ATTRDEF_CLASSES +#include "amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "amd/include/Dialect/TritonAMDGPU/IR/Ops.h.inc" + +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_IR_DIALECT_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td new file mode 100644 index 000000000..44ac3c8d3 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_AMDGPU_ATTRDEFS +#define TRITON_AMDGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "TritonAMDGPUDialect.td" +include "mlir/IR/EnumAttr.td" + +class TritonAMDGPU_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +def TritonAMDGPU_OpIdxAttr : TritonAMDGPU_Attr<"OpIdx"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "OpIdx"; + let summary = "An operand index attribute."; + let description = [{ + The attribute is a way to describe which input argument of the target + operation (e.g., `tt.dot`) the result of a given operation belongs to. + }]; + + let parameters = (ins "uint32_t":$value); + let assemblyFormat = "`<` $value `>`"; +} + +def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "InstCounter"; + let summary = "An instruction counter attribute."; + let description = [{ + The attribute holds the number of issued LLVM instructions of a specific kind as well as + the data type. + }]; + + let parameters = (ins "uint32_t":$value, "Type":$type); + let assemblyFormat = "`<` params `>`"; +} + +class TritonAMDGPU_I32Enum cases> + : I32EnumAttr { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::triton::amdgpu"; +} + +class TritonAMDGPU_I32EnumAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; + let cppNamespace = "::mlir::triton::amdgpu"; +} + +def SchedHintCaseNone : I32EnumAttrCase<"none", 0>; +def SchedHintCaseLLVMIglp0 : I32EnumAttrCase<"llvm_iglp_0", 1>; +def SchedHintCaseLLVMIglp1 : I32EnumAttrCase<"llvm_iglp_1", 2>; +def SchedHintCaseLocalPrefetch : I32EnumAttrCase<"local_prefetch", 3>; + +def TritonAMDGPU_SchedHintsEnum : TritonAMDGPU_I32Enum< + "SchedHint", "Instruction Scheduling Hints for AMD GPUs", [ + SchedHintCaseNone, + SchedHintCaseLLVMIglp0, + SchedHintCaseLLVMIglp1, + SchedHintCaseLocalPrefetch, + ]>; + +def TritonAMDGPU_SchedHintVariantAttr : + TritonAMDGPU_I32EnumAttr<"SchedHintVariant", TritonAMDGPU_SchedHintsEnum>; + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td new file mode 100644 index 000000000..91a3d3230 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_AMDGPU_DIALECT +#define TRITON_AMDGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonAMDGPU_Dialect : Dialect { + let name = "amdgpu"; + let cppNamespace = "::mlir::triton::amdgpu"; + + let description = [{ + TritonAMDGPU Dialect hosts AMD specific ops at TritonGPU abstraction level. + }]; + + let dependentDialects = ["triton::TritonDialect"]; + + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td new file mode 100644 index 000000000..2d2187981 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -0,0 +1,397 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + + +#ifndef TRITON_AMDGPU_OPS +#define TRITON_AMDGPU_OPS + +include "mlir/IR/OpBase.td" +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" + +include "mlir/IR/EnumAttr.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "TritonAMDGPUDialect.td" +include "TritonAMDGPUAttrDefs.td" + + +class TT_AMDGPU_Op traits = []> : + Op; + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; + +//===----------------------------------------------------------------------===// +// ExtractSliceOp +//===----------------------------------------------------------------------===// + +def ExtractSliceOp : TT_AMDGPU_Op<"extract_slice", [Pure]> { + let summary = "extract slice operation"; + let description = [{ + The "extract_slice" operation enables extracting a slice of a tensor in + registers. + + The "extract_slice" operation supports the following arguments: + + * source: the base tensor on which to create a view tensor + * offsets: offsets into the base tensor at which to create the view + + Example 1: + + ```mlir + #blocked = #ttg.blocked<{sizePerThread = [1, 8], + threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [0, 1]}> + #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], + threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}> + %1 = ttg.convert_layout %0 : tensor<128x128xf16, #blocked> + -> tensor<128x128xf16, #blocked1> + // create a slice of base tensor %1 with static offsets + %2 = amdgpu.extract_slice %0 [0, 0] : + tensor<128x128xf16, #blocked1> to tensor<128x32xf16, #blocked1> + ``` + + Example 1 shows how "extract_slice" operation may be used. In this example a + new slice of 128x32 is created. "extract_slice" works on tensors with layout + where the desired slice has the same layout as the source tensor. + "%0" cannot be sliced directly as the resulting slice cannot have the same + layout as "%0". Therefore it needs to be converted to a layout suitable + for slicing. "#blocked1" layout is appropriate for this as it keeps the + sizePerThread the same thus keeping coalescing properties the same. + In order to utilize all threads in a warp, "threadsPerWarp" is set to + [16,4] for this new layout. This layout conversion carried out before + using "extract_slice" ensures slicing still uses all threads efficiently. The + size of the slice is determined by the result type. + }]; + + let arguments = (ins + AnyRankedTensor:$source, + DenseI64ArrayAttr:$static_offsets + ); + let results = (outs AnyRankedTensor:$result); + + let builders = [ + // Build a ExtractSliceOp with static offsets and the same result type + OpBuilder<(ins "RankedTensorType":$resultType, + "Value":$source, + "ArrayRef": $static_offsets)>, + ]; + + let extraClassDeclaration = [{ + std::array getArrayAttrMaxRanks() { + unsigned rank = getSource().getType().getRank(); + return {rank, rank, rank}; + } + }]; + + let assemblyFormat = [{ + $source $static_offsets attr-dict `:` type($source) `to` type($result) + }]; + + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// InstructionSchedHint +//===----------------------------------------------------------------------===// + +def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { + let summary = "A placeholder op for instruction scheduling hints within a basic block"; + let description = [{ + A placeholder op for instruction scheduling hints applied to instructions within + a basic block where the placeholder op is located. This op is primarily intended + to be used to adjust instruction scheduling inside the resulting main loop + of a `tt.dot` operation. It's easier to identify dot ops at a high level and, thus, + to mark intended scheduling regions. The hint ops are eventually lowered + into LLVM AMDGPU instruction scheduling primitives, which are meant to control + how different kinds of instructions (valu/mfma, global/shared memory, etc.) should + interleave for better instruction level parallelism. + }]; + + let arguments = (ins + TritonAMDGPU_SchedHintVariantAttr:$variant, + TritonAMDGPU_InstCounter:$numDsReadsA, + TritonAMDGPU_InstCounter:$numDsReadsB, + TritonAMDGPU_InstCounter:$numDsWritesA, + TritonAMDGPU_InstCounter:$numDsWritesB, + TritonAMDGPU_InstCounter:$numGlobalLoadsA, + TritonAMDGPU_InstCounter:$numGlobalLoadsB, + BoolAttr:$isBufferLoadsAEnabled, + BoolAttr:$isBufferLoadsBEnabled, + TritonAMDGPU_InstCounter:$numMMAs + ); + + let builders = [ + OpBuilder<(ins "amdgpu::SchedHint":$variant), [{ + auto ctx = $_state.getContext(); + auto noneType = NoneType::get(ctx); + auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, noneType); + build($_builder, $_state, variant, emptyAttr, emptyAttr, emptyAttr, emptyAttr, + emptyAttr, emptyAttr, false, false, emptyAttr); + }]> + ]; + + let assemblyFormat = [{ attr-dict }]; +} + +//===----------------------------------------------------------------------===// +// CondBarrierOp +//===----------------------------------------------------------------------===// + +def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier"> { + let summary = "Conditionally set barriers to synchronize partial threads in a block"; + + let description = [{ + condBarrierOp sets barrier instruction only when the given argument is true. + This provides a way to synchronize partial threads in a block, deliberately + diverges the execution sequences. However, user should guarantee all threads + converge at the end by calling condBarrierOp(true) with the remaining threads. + Conceptually, this is similar to having an execution barrier inside an if statement. + This op allows us to avoid blocking the whole block when suitable to help scheduling. + NB. This doesn't set any memory fence. + }]; + + let arguments = (ins I1:$pred); + + let assemblyFormat = "$pred attr-dict"; +} + +//===----------------------------------------------------------------------===// +// BufferLoadOp +//===----------------------------------------------------------------------===// + +def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [ + SameLoadStoreOperandsAndResultEncoding, + AttrSizedOperandSegments, + MemoryEffects<[MemRead]>, + TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)", + "(cast($_op).getMask() == nullptr) || std::equal_to<>()">, + TypesMatchWith<"result and other have the same type", "result", "other", "$_self", + "(cast($_op).getOther() == nullptr) || std::equal_to<>()">, +]>{ + let summary = "Load from a scalar base pointer and a tensor offset"; + let description = [{ + AMD Buffer load operation. Buffer store is similar to + a normal store but it accesses global memory via a scalar base pointer + and a tensor of offsets instead of a tensor of pointers. The other fields + are similar to a normal load, i.e., the `mask` is a boolean vector that + determines if a given element should be read from memory, and `other` is the + element that should be returned on lane `i` when `mask[i] == 0`. + Stride is the distance between the beginning of contiguous memory chunks. + When performing a load of a block, the `stride` is the address difference between + the first elements of each row in bytes. Compiler tries to obtain the `stride` + when it converts to the buffer ops because it is important for optimizing + the cache memory access. + }]; + let arguments = (ins + TT_Ptr:$ptr, + I32Tensor:$offsets, + Optional:$stride, + DefaultValuedAttr:$cache, + Optional:$mask, + Optional:$other + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $ptr `[` $offsets `]` (`,` $mask^)? (`,` $other^)? + oilist(`cacheModifier` `=` $cache) + (`stride` `=` $stride^)? + attr-dict `:` type($result) + }]; +} + +//===----------------------------------------------------------------------===// +// BufferLoadToLocalOp +//===----------------------------------------------------------------------===// + +def BufferLoadToLocalOp : TT_AMDGPU_Op<"buffer_load_to_local", [ + AttrSizedOperandSegments, + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"dest element type matches pointee type of ptr", "dest", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"infer mask shape from offsets", + "offsets", "mask", "getI1SameShape($_self)", + "(cast($_op).getMask() == nullptr) || std::equal_to<>()">, + TypesMatchWith<"other matches shape and layout of offsets and the element type matches the pointee type of ptr", + "offsets", "other", "cast($_self).clone(getPointeeType($ptr.getType()))", + "(cast($_op).getOther() == nullptr) || std::equal_to<>()">, +]>{ + let summary = "Load from a scalar base pointer and a tensor offset to shared memory"; + let description = [{ + AMD Buffer load operation. Similar to amdgpu.buffer_load op but directly wirtes to shared memory instead of into registers. }]; + let arguments = (ins + TTG_MemDescType:$dest, + TT_Ptr:$ptr, + I32Tensor:$offsets, + Optional:$mask, + Optional:$other, + Optional:$stride, + DefaultValuedAttr:$cache + ); + let results = (outs TTG_AsyncToken:$token); + + let assemblyFormat = [{ + $ptr `[` $offsets `]` (`mask` `=` $mask^)? (`other` `=` $other^)? (`stride` `=` $stride^)? + oilist(`cacheModifier` `=` $cache) `into` $dest + attr-dict `:` type($ptr) `[` type($offsets) `]` type($other) `->` type($dest) + }]; +} + +//===----------------------------------------------------------------------===// +// BufferAtomicRMWOp +//===----------------------------------------------------------------------===// + +def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [ + AttrSizedOperandSegments, + SameLoadStoreOperandsAndResultEncoding, + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"result element type matches the value type", "result", "value", "$_self">, + TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)", + "(cast($_op).getMask() == nullptr) || std::equal_to<>()">, + TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)", + "(cast($_op).getMask() == nullptr) || std::equal_to<>()">, +]>{ + let summary = "Atomic RMW op which reads, modifies, and writes to a scalar base pointer and a tensor offset"; + let description = [{ + AMD Buffer atomic RMW operation. Buffer atomics are similar to normal atomics, but access global memory via a + scalar base pointer and a tensor of offsets instead of a tensor of pointers. + Similar to other buffer ops, the `mask` is a boolean vector that determines if a given element should be processed with + the atomic RMW op. Elements with `mask[i] == 0` are dropped (i.e., the atomic is not executed). + Similar to TT_AtomicRMWOp: Buffer atomic RMW ops load data at $ptr, do $rmw_op with $val, and store result to $ptr with + the specified memory semantics and scope. Atomic RMW ops return the pre-op value if used, otherwise the value is implicitly dropped. + Stride is the distance between the beginning of contiguous memory chunks. When performing a RMW, the `stride` is + the address difference between the first elements of each row in bytes. Compiler tries to obtain the `stride` + when it converts to the buffer ops because it is important for optimizing the cache memory access. + }]; + let arguments = (ins + TT_AtomicRMWAttr:$atomic_rmw_op, + TT_Ptr:$ptr, + I32Tensor:$offsets, + TT_Tensor:$value, + Optional:$stride, + TT_MemSemanticAttr:$sem, + TT_MemSyncScopeAttr:$scope, + Optional:$mask + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $value `,` $ptr `[` $offsets `]` (`,` $mask^)? + (`stride` `=` $stride^)? + attr-dict `:` type($result) + }]; +} + +//===----------------------------------------------------------------------===// +// BufferStoreOp +//===----------------------------------------------------------------------===// + +def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [ + AttrSizedOperandSegments, + SameLoadStoreOperandsEncoding, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)", + "(cast($_op).getMask() == nullptr) || std::equal_to<>()">, +]>{ + let summary = "Store into scalar base pointer and a tensor offset"; + let description = [{ + AMD Buffer store operation. Buffer store is similar to + normal store but it accesses global memory via a scalar base pointer + and a tensor of offsets instead of a tensor of pointers. The other fields + are similar to a normal store , i.e., the `mask` is a boolean vector that + determines if a given element should be written to memory, and `value` is the + tensor of elements that should be written on lane `i` when `mask[i] == 1`. + Stride is the distance between the beginning of contiguous memory chunks. + When performing a block store, the `stride` is the address difference between + the first elements of each row in bytes. Compiler tries to obtain the `stride` + when it converts to the buffer ops because it is important for optimizing + the cache memory access. + }]; + let arguments = (ins + TT_Tensor:$value, + TT_Ptr:$ptr, + I32Tensor:$offsets, + Optional:$stride, + DefaultValuedAttr:$cache, + Optional:$mask + ); + + let assemblyFormat = [{ + $value `,` $ptr `[` $offsets `]` (`,` $mask^)? + oilist(`cacheModifier` `=` $cache) + (`stride` `=` $stride^)? + attr-dict `:` type($value) + }]; +} + +//===----------------------------------------------------------------------===// +// UpcastMXFPOp +//===----------------------------------------------------------------------===// + +def TTG_UpcastMXFPOp : TT_AMDGPU_Op<"upcast_mxfp", [Pure]> { + let summary = "Convert an mxfp tensor to bf16/fp16"; + + let hasVerifier = 1; + + let description = [{ + Compute the bf16 encoded in the given mxfp number as per + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + }]; + let arguments = ( + ins + TT_Tensor:$src, + TT_Tensor:$scale, + TT_ScaleDotElemTypeAttr:$fp_type, + BoolAttr:$fastMath + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result) + }]; + + let extraClassDeclaration = [{ + static RankedTensorType deduceOutputType( + TypedValue inputTensor, ScaleDotElemType inputElemType, Type outputElemType); + }]; +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h new file mode 100644 index 000000000..5b599f5e7 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h @@ -0,0 +1,11 @@ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_UTILITY_COMMONUTILS_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_UTILITY_COMMONUTILS_H_ + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton::AMD { +SmallVector getLeafForOps(triton::FuncOp funcOp); +} // namespace mlir::triton::AMD + +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_UTILITY_COMMONUTILS_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..ed566b745 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonAMDGPUToLLVM) +add_public_tablegen_target(TritonAMDGPUConversionPassIncGen) diff --git a/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h new file mode 100644 index 000000000..4e19b370c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h @@ -0,0 +1,404 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_GCNASMFORMAT_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_GCNASMFORMAT_H_ + +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir { + +class ConversionPatternRewriter; +class Location; + +} // namespace mlir + +namespace mlir::triton { +using llvm::StringRef; + +class GCNInstr; +class GCNInstrCommon; +class GCNInstrExecution; + +// GCNBuilder helps to manage a GCN asm program consists of one or multiple +// instructions. +// +// A helper for building an ASM program, the objective of GCNBuilder is to give +// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear. +// Currently, several factors are introduced to reduce the need for mixing +// string and C++ if-else code. +// +// Usage: +// To create a multiplcation operation +// +// +// GCNBuilder gcnBuilder; +// unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); +// +// const std::string readConstraint = "v"; +// const std::string writeConstraint = "=v"; +// auto res = gcnBuilder.newOperand(writeConstraint); +// auto lhs = gcnBuilder.newOperand(operands[0], readConstraint); +// auto rhs = gcnBuilder.newOperand(operands[1], readConstraint); +// +// create inst +// auto &mul_inst = +// gcnBuilder.create("v_mul")->float_op_type(bitwidth); +// +// launch insts +// mul_inst(res, lhs, rhs); +// +// return result +// Value ret = gcnBuilder.launch(rewriter, loc, elemTy, false); +// return ret; +// To get the asm code: +// builder.dump() +// +// To get all the mlir::Value used in the GCN code, +// +// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal} +// +// To get the string containing all the constraints with "," separated, +// builder.getConstraints() // get "=v,v,v" +// +// GCNBuilder can build a GCN asm with multiple instructions, sample code: +// +// GCNBuilder builder; +// auto &rcp = gcnBuilder.create("v_rcp")->float_op_type(bitwidth); +// auto &mul_inst = +// gcnBuilder.create("v_mul")->float_op_type(bitwidth); +// +// rcp(...); +// mul_inst(...); +// This will get a GCN code with two instructions. +// +// Similar to a C function, a declared GCNInstr instance can be launched +// multiple times with different operands, e.g. +// +// auto &mul_inst = +// gcnBuilder.create("v_mul")->float_op_type(bitwidth); mul_inst(... +// some operands ...); mul_inst(... some different operands ...); +// +// Finally, we will get a GCN code with two mov instructions. +// +// There are several derived instruction type for typical instructions, for +// example, the GCNIOInstr for ld and st instructions. +struct GCNBuilder { + struct Operand { + std::string constraint; + Value value; + int idx{-1}; + llvm::SmallVector list; + std::function repr; + + // for list + Operand() = default; + Operand(const Operation &) = delete; + Operand(Value value, StringRef constraint) + : constraint(constraint), value(value) {} + + bool isList() const { return !value && constraint.empty(); } + + Operand *listAppend(Operand *arg) { + list.push_back(arg); + return this; + } + + Operand *listGet(size_t nth) const { + assert(nth < list.size()); + return list[nth]; + } + + std::string dump() const; + }; + + struct Modifier { + Value value; + std::string modifier; + std::string arg; + llvm::SmallVector list; + + Modifier() = default; + Modifier(const Operation &) = delete; + Modifier(Value value, StringRef arg) : value(value), arg(arg) {} + + bool isList() const { return !value && modifier.empty(); } + + Modifier *listAppend(Modifier *arg) { + list.push_back(arg); + return this; + } + + Modifier *listGet(size_t index) const { + assert(index < list.size()); + return list[index]; + } + + std::string to_str() const { + std::string str = modifier; + if (!arg.empty()) { + str += ":" + arg; + } + return str; + } + + std::string dump() const; + }; + + template + INSTR *create(Args &&...args) { + instrs.emplace_back(std::make_unique(this, args...)); + return static_cast(instrs.back().get()); + } + + // Create a list of operands. + Operand *newListOperand() { return newOperand(); } + + Operand *newListOperand(ArrayRef> items) { + auto *list = newOperand(); + for (auto &item : items) { + list->listAppend(newOperand(item.first, item.second)); + } + return list; + } + + Operand *newListOperand(unsigned count, mlir::Value val, + const std::string &constraint) { + auto *list = newOperand(); + for (int i = 0; i < count; ++i) { + list->listAppend(newOperand(val, constraint)); + } + return list; + } + + Operand *newListOperand(unsigned count, const std::string &constraint) { + auto *list = newOperand(); + for (int i = 0; i < count; ++i) { + list->listAppend(newOperand(constraint)); + } + return list; + } + + // Create a new operand. It will not add to operand list. + // @value: the MLIR value bind to this operand. + // @constraint: ASM operand constraint, .e.g. "=r" + // @formatter: extra format to represent this operand in ASM code, default is + // "%{0}".format(operand.idx). + Operand *newOperand(mlir::Value value, StringRef constraint, + std::function formatter = nullptr); + + // Create a new operand which is written to, that is, the constraint starts + // with "=", e.g. "=r". + Operand *newOperand(StringRef constraint); + + // Create a constant integer operand. + Operand *newConstantOperand(int v); + // Create a constant operand with explicit code specified. + Operand *newConstantOperand(const std::string &v); + + Operand *newAddrOperand(mlir::Value addr, StringRef constraint); + + Modifier *newModifier(StringRef modifier, StringRef arg); + + llvm::SmallVector getAllArgs() const; + + llvm::SmallVector getAllMLIRArgs() const; + + std::string getConstraints() const; + + std::string dump() const; + + mlir::Value launch(RewriterBase &rewriter, Location loc, Type resTy, + bool hasSideEffect = true, bool isAlignStack = false, + ArrayRef attrs = {}) const; + +private: + Operand *newOperand() { + argArchive.emplace_back(std::make_unique()); + return argArchive.back().get(); + } + + Modifier *newModifier() { + modArchive.emplace_back(std::make_unique()); + return modArchive.back().get(); + } + + friend class GCNInstr; + friend class GCNInstrCommon; + +protected: + llvm::SmallVector, 6> argArchive; + llvm::SmallVector, 2> modArchive; + llvm::SmallVector, 2> instrs; + llvm::SmallVector, 4> executions; + int oprCounter{}; +}; + +// GCN instruction common interface. +// Put the generic logic for all the instructions here. +struct GCNInstrCommon { + explicit GCNInstrCommon(GCNBuilder *builder) : builder(builder) {} + + using Operand = GCNBuilder::Operand; + using Modifier = GCNBuilder::Modifier; + + // clang-format off + GCNInstrExecution& operator()() { return call({}, {}); } + GCNInstrExecution& operator()(Operand* a) { return call({a}, {}); } + GCNInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}, {}); } + GCNInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}, {}); } + GCNInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}, {}); } + GCNInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}, {}); } + GCNInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}, {}); } + GCNInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}, {}); } + // clang-format on + + // Set operands of this instruction. + GCNInstrExecution &operator()(llvm::ArrayRef oprs, + llvm::ArrayRef mods); + +protected: + GCNInstrExecution &call(llvm::ArrayRef oprs, + ArrayRef mods); + + GCNBuilder *builder{}; + llvm::SmallVector instrParts; + + friend class GCNInstrExecution; +}; + +template struct GCNInstrBase : public GCNInstrCommon { + using Operand = GCNBuilder::Operand; + using Modifier = GCNBuilder::Modifier; + + explicit GCNInstrBase(GCNBuilder *builder, const std::string &name) + : GCNInstrCommon(builder) { + o(name); + } + + ConcreteT &o(const std::string &suffix, bool predicate = true) { + if (predicate) + instrParts.push_back(suffix); + return *static_cast(this); + } +}; + +enum VectorWidth { Byte = 8, Short = 16, Dword = 32, Qword = 64 }; + +struct GCNInstr : public GCNInstrBase { + using GCNInstrBase::GCNInstrBase; + + GCNInstr &float_op_type(int width) { + switch (width) { + case Byte: + assert(Byte != width); + break; + case Short: + o("f16"); + break; + case Dword: + o("f32"); + break; + case Qword: + o("f64"); + break; + default: + break; + } + return *this; + } +}; + +struct GCNInstrExecution { + using Operand = GCNBuilder::Operand; + using Modifier = GCNBuilder::Modifier; + + llvm::SmallVector argsInOrder; + llvm::SmallVector mods; + + GCNInstrExecution() = default; + explicit GCNInstrExecution(GCNInstrCommon *instr, + llvm::ArrayRef oprs, + llvm::ArrayRef modifiers) + : argsInOrder(oprs.begin(), oprs.end()), instr(instr), + mods(modifiers.begin(), modifiers.end()) {} + + std::string dump() const; + + SmallVector getArgList() const; + + GCNInstrCommon *instr{}; +}; + +struct GCNMemInstr : public GCNInstrBase { + using GCNInstrBase::GCNInstrBase; + // Add specific type suffix to instruction + + GCNMemInstr &load_type(int width) { + switch (width) { + case Byte: + o("ubyte"); + break; + case Short: + o("ushort"); + break; + case Dword: + o("dword"); + break; + case Qword: + o("dwordx2"); + break; + default: + break; + } + return *this; + } + + GCNMemInstr &store_type(int width) { + switch (width) { + case Byte: + o("byte"); + break; + case Short: + o("short"); + break; + case Dword: + o("dword"); + break; + case Qword: + o("dwordx2"); + break; + default: + break; + } + return *this; + } +}; + +} // namespace mlir::triton + +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_GCNASMFORMAT_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h new file mode 100644 index 000000000..f373ae8d9 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -0,0 +1,55 @@ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PASSES_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PASSES_H_ + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +} // namespace mlir + +namespace mlir::triton { + +#define GEN_PASS_DECL +#include "TritonAMDGPUToLLVM/Passes.h.inc" + +} // namespace mlir::triton + +namespace mlir::triton::AMD { +std::unique_ptr> +createDecomposeUnsupportedConversionsPass(StringRef targetArch); + +/// @brief Creates pass that keep LDS consumption within specified limits. +/// @param arch target architecture name, for example "gfx940" +/// @param customLDSLimit defines LDS size available for one thread block +/// zero value tells pass that whole LDS is available on a device +/// @return created pass +std::unique_ptr> +createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0); +} // namespace mlir::triton::AMD + +namespace mlir::triton { + +std::unique_ptr> +createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz); +std::unique_ptr> +createConvertBuiltinFuncToLLVMPass(bool ftz); +std::unique_ptr> +createTritonAMDGPUInsertInstructionSchedHintsPass(StringRef variant); +std::unique_ptr> +createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch, + int32_t numStages); + +#define GEN_PASS_REGISTRATION +#include "TritonAMDGPUToLLVM/Passes.h.inc" + +} // namespace mlir::triton + +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PASSES_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td new file mode 100644 index 000000000..eb68fe3a5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -0,0 +1,90 @@ +#ifndef TRITONAMDGPU_CONVERSION_PASSES +#define TRITONAMDGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-conversions", "mlir::ModuleOp"> { + let summary = "Decompose conversions that are not supported by TritonGPU -> LLVM"; + let constructor = "mlir::triton::AMD::createDecomposeUnsupportedConversionsPass(\"\")"; + + let options = [ + Option<"arch", "arch", "std::string", /*default*/"\"\"", + "gfx target device architecture, e.g., gfx942">, + ]; +} + +def OptimizeAMDLDSUsage : Pass<"optimize-amd-lds-usage", "mlir::ModuleOp"> { + let summary = "Minimize LDS usage"; + let constructor = "mlir::triton::AMD::createOptimizeLDSUsagePass(\"\")"; + + let options = [ + Option<"targetArch", "target-arch", "std::string", /*default*/"", + "gfx target device architecture, e.g., gfx942">, + Option<"customLDSLimit", "lds-limit", "int", /*default*/"0", + "custom limit of LDS consumption, if not provided, maximum LDS size is used">, + ]; +} + +def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert TritonGPU to LLVM"; + let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::gpu::GPUDialect", + "mlir::scf::SCFDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::ROCDL::ROCDLDialect"]; + + let options = [ + Option<"arch", "arch", "std::string", /*default*/"\"\"", + "gfx target device architecture, e.g., gfx942">, + Option<"ftz", "ftz", "bool", /*default*/"true", + "flush denorms for math functions">, + ]; +} + +def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Builtin Func to LLVM"; + let constructor = "mlir::triton::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true)"; + + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + + let options = [ + Option<"ftz", "ftz", "bool", /*default*/"true", + "flush denorms for math functions">, + ]; +} + +def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> { + let summary = "Insert instruction scheduling hints after the dot ops in the main loop"; + let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass(/*variant=*/\"\")"; + + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; + + let options = [ + Option<"variant", "variant", "std::string", /*default*/"\"none\"", + "instruction scheduling variant">, + ]; +} + +def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { + let summary = "Lower instruction scheduling hints to LLVM intrinsics"; + let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*arch=*/\"\",/*numStages=*/2)"; + + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::ROCDL::ROCDLDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; + + let options = [ + Option<"arch", "arch", "std::string", /*default*/"\"\"", + "gfx target device architecture, e.g., gfx942">, + Option<"numStages", "num_stages", "int32_t", /*default*/"2", + "number of pipeline stages">, + ]; +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h new file mode 100644 index 000000000..cd9407ed2 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h @@ -0,0 +1,14 @@ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PATTERNTRITONAMDGPUTOLLVM_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PATTERNTRITONAMDGPUTOLLVM_H_ + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" + +namespace mlir::triton::AMD { + +void populateExtractSliceOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, + mlir::PatternBenefit benefit); + +} // namespace mlir::triton::AMD + +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PATTERNTRITONAMDGPUTOLLVM_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h new file mode 100644 index 000000000..223eadb2e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h @@ -0,0 +1,39 @@ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_TARGETUTILS_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_TARGETUTILS_H_ + +#include "llvm/ADT/StringRef.h" + +namespace mlir::triton::AMD { + +// A list of ISA families we care about. +enum class ISAFamily { + Unknown, + CDNA1, + CDNA2, + CDNA3, + CDNA4, + RDNA1, + RDNA2, + RDNA3, +}; + +// Deduces the corresponding ISA family for the given target gfx |arch|. +ISAFamily deduceISAFamily(llvm::StringRef arch); + +// Retursn true if given architecture support V_DOT instruction. +bool supportsVDot(llvm::StringRef arch); + +// Here is a partial definition of DppCtrl enums. For the complete definition, +// please check: +// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939 +enum class DppCtrl : uint32_t { + QUAD_PERM_FIRST = 0, + ROW_SHL0 = 0x100, + ROW_SHR0 = 0x110, + BCAST15 = 0x142, + BCAST31 = 0x143 +}; + +} // namespace mlir::triton::AMD + +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_TARGETUTILS_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/CMakeLists.txt new file mode 100644 index 000000000..b8c9325f4 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonAMDGPU) +add_public_tablegen_target(TritonAMDGPUTransformsIncGen) diff --git a/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h new file mode 100644 index 000000000..ef275c88c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h @@ -0,0 +1,49 @@ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_MFMAGROUP_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_MFMAGROUP_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +// Returns true if the given type is an OCP FP8/FP6/FP6 type. +inline bool isF8F6F4(mlir::Type type) { + return llvm::isa(type); +} + +struct MfmaIntrinsic { + // Chooses a suitable mfma instrinsic for the given input case. + static FailureOr selectFor(int version, unsigned mDim, + unsigned nDim, unsigned inputKDim, + Type aElemType, Type bElemType, + bool withScale, bool useTF32); + + MfmaIntrinsic(StringRef symbol, unsigned m, unsigned n, unsigned k, + unsigned kB, Type aET, Type bET) + : name(symbol), mDim(m), nDim(n), kDim(k), kBase(kB), aElementType(aET), + bElementType(bET) {} + MfmaIntrinsic(const MfmaIntrinsic &other) = default; + MfmaIntrinsic(MfmaIntrinsic &&other) = default; + MfmaIntrinsic() = default; + MfmaIntrinsic &operator=(MfmaIntrinsic &&other) = default; + + llvm::StringRef name; + + // m, n, and k refer to the shapes of the two operands of an mfma intrinsic: + // Operand A has shape [m]x[k]; operand B has shape [k]x[n]. + // For mfma32 and mfma16 intrinsics, they are encoded in the instruction + // name, i.e. mfma_DType_[m]x[n]x[k]xABType. + unsigned mDim; + unsigned nDim; + unsigned kDim; + + // kBase is the number of elements each thread holds. + unsigned kBase; + + Type aElementType; + Type bElementType; +}; +} // namespace mlir + +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_MFMAGROUP_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/Passes.h new file mode 100644 index 000000000..f5311723d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -0,0 +1,42 @@ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_PASSES_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_PASSES_H_ + +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Pass/Pass.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir { + +std::unique_ptr +createTritonAMDGPUStreamPipelinePass(int numStages = 2, int globalPrefetch = 0, + int localPrefetch = 0); + +std::unique_ptr +createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(), + int matrixInstructionSize = 0, + int kpack = 1); + +std::unique_ptr createTritonAMDGPUCanonicalizeLoopsPass(); + +std::unique_ptr createTritonAMDGPUReorderInstructionsPass(); + +std::unique_ptr createTritonAMDGPUVerifier(); + +std::unique_ptr createTritonAMDGPUOptimizeEpiloguePass(); + +std::unique_ptr createTritonAMDGPUHoistLayoutConversionsPass(); + +std::unique_ptr createTritonAMDGPUCanonicalizePointersPass(); + +std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass( + std::string archGenName = std::string()); + +std::unique_ptr createTritonAMDGPUBlockPingpongPass(); + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "TritonAMDGPUTransforms/Passes.h.inc" + +} // namespace mlir +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_PASSES_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/Passes.td new file mode 100644 index 000000000..37b257c8b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -0,0 +1,169 @@ +#ifndef TRITONGPU_PASSES +#define TRITONGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::ModuleOp"> { + let summary = "pipeline"; + + let description = [{ + Pipeline global loads through registers to shared memory while computing on previous + tile + }]; + + let constructor = "mlir::createTritonAMDGPUStreamPipelinePass()"; + + let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; + + let options = [ + Option<"numStages", "num_stages", + "int32_t", /*default*/"2", + "Number of Pipeline stages">, + Option<"globalPrefetch", "global_prefetch", + "int32_t", /*default*/"0", + "Set global prefetch stage count">, + Option<"localPrefetch", "local_prefetch", + "int32_t", /*default*/"0", + "Set local prefetch stage count">, + ]; +} + +def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir::ModuleOp"> { + let summary = "accelerate matmul"; + + let description = [{ + Optimize the input/output layout of `dot` instruction to make them compatible hardware accelerators + (e.g., AMD matrix cores) + }]; + + let constructor = "mlir::createTritonAMDGPUAccelerateMatmulPass()"; + + let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; + + let options = [ + Option<"archGenerationName", "arch-generation-name", + "std::string", /*default=*/"std::string{}", + "GFX generation name of target device.">, + Option<"matrixInstructionSize", "matrix-instruction-size", + "int32_t", /*default*/"0", + "enforce matrix instruction MN size">, + Option<"kPack", "kPack", + "int32_t", /*default*/"1", + "KWidth / kBase"> + ]; +} + +def TritonAMDGPUOptimizeEpilogue : Pass<"tritonamdgpu-optimize-epilogue", "mlir::ModuleOp"> { + let summary = "Optimize epilogue: (1) Store accumulators directly without going thorough SMEM in epilogue."; + + let description = [{ + }]; + + let constructor = "mlir::createTritonAMDGPUOptimizeEpiloguePass()"; + + let dependentDialects = []; + +} + +def TritonAMDGPUHoistLayoutConversions : Pass<"tritonamdgpu-hoist-layout-conversions", "mlir::triton::FuncOp"> { + let summary = "Hoist layout conversions out of the loop"; + + let description = [{ + This pass tries to hoist a convert_layout op out of the loop if 1) its dst is a tensor + of dotOperand layout, and 2) its src is defined out of the loop. + The rational is as follows: + 1. When the defining op of the src is out of the loop, it means the src is loop-invariant. + Then we can potentially hoist this convert_layout op, since it's also loop-invariant. + 2. The drawback of this LICM is higher register pressure. However, on AMD GPUs, we have + a larger register file but smaller shared memory. It's beneficial to keep loop-invariant + variables in registers rather than loading them from shared memory in the loop. + }]; + + let constructor = "mlir::createTritonAMDGPUHoistLayoutConversionsPass()"; + +} + +def TritonAMDGPUCanonicalizePointers : Pass<"tritonamdgpu-canonicalize-pointers", "mlir::triton::FuncOp"> { + let summary = "Canonicalize pointers: rewrite pointers passed to load/store operation as a `` pair."; + + let description = [{ + This pass pushes all the constant pointer arithmetic on a scalar basePtr, while all the vector + pointer arithmetic to a vector offset. I.e., if we consider the following IR: + ``` + %v_ptr = tt.splat %s_ptr + %c_offset = tt.splat %s_offset + %v_offset0 = tt.make_range + %v_offset1 = tt.make_range + %v_ptr0 = tt.addptr %v_ptr, %c_offset + %v_ptr1 = tt.addptr %v_ptr0, %v_offset0 + %v_ptr2 = tt.addptr %v_ptr0, %v_offset1 + %data = tt.load(%v_ptr2) + ``` + We transform this into: + ``` + %s_ptr0 = tt.addptr %s_ptr, %s_offset + %v_offset = %zero + %v_offset = arith.addi %v_offset, %v_offset0 + %v_offset = arith.addi %v_offset, %v_offset1 + %c_ptr = tt.splat %s_ptr0 + %v_ptr = tt.addptr %c_ptr, %v_offset + %data = tt.load(%v_ptr) + ``` + In the above IR: + - `v_` means "variable vector across the program" + - `c_` means "constant vector across the program" + - `s_` means "scalar" + So we transform the IR such that the constant updates become scalar updates, and the variable updates happen on the offset. Note that + when we have to load the data, we splat the scalar pointer, add the "variable" offset and then issue the load. + }]; + + let constructor = "mlir::createTritonAMDGPUCanonicalizePointersPass()"; + + let dependentDialects = []; + +} + +def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", "mlir::ModuleOp"> { + let summary = "Reorder instructions"; + + let description = "This pass reorder instructions so as to (1) decrease register pressure (e.g., by moving " + "conversions from shared memory before their first use) and (2) promote LLVM instruction " + "order more friendly to `ptxas`."; + + let constructor = "mlir::createTritonAMDGPUReorderInstructionsPass()"; + + let dependentDialects = []; +} + +def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "mlir::ModuleOp"> { + let summary = "Convert memory operations to buffer operations"; + + let description = "This pass converts memory and atomic operations (e.g., tt.load/tt.store/tt.atomic_rmw) to amdgpu buffer operations, if possible"; + + let constructor = "mlir::createTritonAMDGPUConvertToBufferOpsPass()"; + + let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; + + let options = [ + Option<"archGenerationName", "arch-generation-name", + "std::string", /*default=*/"std::string{}", + "GFX generation name of target device.">, + ]; +} + +def TritonAMDGPUBlockPingpong: Pass<"tritonamdgpu-block-pingpong", "mlir::ModuleOp"> { + let summary = "Interleaving instructions from two warps on the same SIMD to better utilize matrix core"; + + let description = [{ + This pass reorder instructions to interleave instructions from two warps on the same SIMD unit. + We call this a ping-pong scheduling pattern, where two warps run concurrently in the synchronized fashion + This block ping-pong pattern could be beneficial under few conditions including + occupancy and number of warps. + }]; + + let constructor = "mlir::createTritonAMDGPUBlockPingpongPass()"; + + let dependentDialects = ["mlir::ROCDL::ROCDLDialect, mlir::triton::amdgpu::TritonAMDGPUDialect"]; +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/TritonGPUConversion.h b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/TritonGPUConversion.h new file mode 100644 index 000000000..0e8b7a624 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/include/TritonAMDGPUTransforms/TritonGPUConversion.h @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// +// Defines utilities to use while converting to the TritonGPU dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_TRITONGPUCONVERSION_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_TRITONGPUCONVERSION_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +class TritonGPUTypeConverter : public TypeConverter { +public: + TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp, + int numCTAs); + int getNumWarps() const { return numWarps; } + int getThreadsPerWarp() const { return threadsPerWarp; } + int getNumCTAs() const { return numCTAs; } + +private: + MLIRContext *context; + int numWarps; + int threadsPerWarp; + int numCTAs; +}; + +class TritonGPUConversionTarget : public ConversionTarget { + +public: + explicit TritonGPUConversionTarget(MLIRContext &ctx, + TritonGPUTypeConverter &typeConverter); +}; + +} // namespace mlir + +#endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_TRITONGPUCONVERSION_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/language/hip/__init__.py b/third_party/enflame/include/triton/third_party/amd/language/hip/__init__.py new file mode 100644 index 000000000..229b57d87 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/language/hip/__init__.py @@ -0,0 +1,3 @@ +from . import libdevice + +__all__ = ["libdevice"] diff --git a/third_party/enflame/include/triton/third_party/amd/language/hip/libdevice.py b/third_party/enflame/include/triton/third_party/amd/language/hip/libdevice.py new file mode 100644 index 000000000..a69d4406c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/language/hip/libdevice.py @@ -0,0 +1,475 @@ +from triton.language import core + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__triton_hip_iabs", core.dtype("int32")), + (core.dtype("int64"), ): ("__triton_hip_iabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__triton_hip_fabs", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__triton_hip_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_floor_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_floor_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_rsqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_rsqrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_ceil_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_ceil_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_trunc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_trunc_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp2_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_expf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__triton_hip_fast_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__triton_hip_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sqrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__triton_hip_llrint", core.dtype("int64")), + (core.dtype("fp64"), ): ("__triton_hip_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_nearbyint_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_nearbyint_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_isnan_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_isnan_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_signbit_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_signbit_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_copysign_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_copysign_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_isinf_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_isinf_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_nextafter_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_nextafter_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sin_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sin_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_cos_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_cos_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_tan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_tan_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log2_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_cosh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_cosh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sinh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_tanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_tanh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_atan2_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_atan2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_atan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_atan_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_asin_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_asin_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_acos_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_acos_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log10_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log10_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log1p_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log1p_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_acosh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_acosh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_asinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_asinh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_atanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_atanh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_expm1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_expm1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_hypot_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_hypot_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_j0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_j0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_j1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_j1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_y0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_y0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_y1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_y1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_i0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_i0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_i1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_i1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erf_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erf_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfinv_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfinv_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfc_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfcx_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfcx_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_lgamma_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_lgamma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__ocml_ldexp_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__ocml_ldexp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fmod_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fmod_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fma_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__ocml_pown_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__ocml_pown_f64", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_pow_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_pow_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_ilogb_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_ilogb_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/third_party/enflame/include/triton/third_party/amd/lib/Analysis/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..99036acac --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/Analysis/CMakeLists.txt @@ -0,0 +1,12 @@ +add_triton_library(TritonAMDAnalysis + RangeAnalysis.cpp + + DEPENDS + TritonTableGen + + LINK_LIBS PUBLIC + MLIRAnalysis + MLIRLLVMDialect + TritonIR + TritonGPUIR +) diff --git a/third_party/enflame/include/triton/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/enflame/include/triton/third_party/amd/lib/Analysis/RangeAnalysis.cpp new file mode 100644 index 000000000..3e9ca8252 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -0,0 +1,460 @@ +#include "third_party/amd/include/Analysis/RangeAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#include +#include + +#undef DEBUG_TYPE +#define DEBUG_TYPE "tritonamdgpu-range-analysis" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; + +namespace tt = mlir::triton; + +namespace { + +constexpr int64_t kDefaultMaxTripCount = 1024; +constexpr int64_t kDefaultMaxPrograms = 2 << 15; // 65536 + +std::optional maybeGetTripCount(LoopLikeOpInterface loop) { + std::optional lowerBound = loop.getSingleLowerBound(); + std::optional upperBound = loop.getSingleUpperBound(); + std::optional step = loop.getSingleStep(); + if (lowerBound && upperBound && step) + return constantTripCount(*lowerBound, *upperBound, *step); + return {}; +} + +void getEnclosingLoops(Operation &op, SmallVector &ops) { + Operation *currOp = op.getParentOp(); + while (currOp) { + if (isa(currOp)) + ops.push_back(llvm::cast(currOp)); + currOp = currOp->getParentOp(); + } +} + +void inferResultRangesPID(Operation *op, uint64_t max, + SetIntRangeFn setResultRange) { + assert(op->getNumResults() == 1 && "expected op to have one result"); + auto result = op->getResult(0); + assert(llvm::isa(result.getType()) && + "expected result type to be int"); + IntegerType resTy = llvm::cast(result.getType()); + auto bitWidth = mlir::ConstantIntRanges::getStorageBitwidth(resTy); + setResultRange(result, ConstantIntRanges::range( + /*min*/ {/*numBits*/ bitWidth, /*val*/ 0, + /*isSigned*/ resTy.isSigned()}, + /*max*/ + {/*numBits*/ bitWidth, /*val*/ max, + /*isSigned*/ resTy.isSigned()}, + /*isSigned*/ resTy.isSigned())); +} + +void inferResultRanges(tt::MakeRangeOp *op, SetIntRangeFn setResultRange) { + auto result = op->getResult(); + RankedTensorType resTy = result.getType(); + assert(llvm::isa(resTy.getElementType()) && "expected int type"); + IntegerType elTy = llvm::cast(resTy.getElementType()); + auto bitWidth = mlir::ConstantIntRanges::getStorageBitwidth(elTy); + setResultRange(result, + ConstantIntRanges::range( + /*min*/ {/*numBits*/ bitWidth, /*val*/ op->getStart(), + /*isSigned*/ elTy.isSigned()}, + /*max*/ + {/*numBits*/ bitWidth, /*val*/ op->getEnd(), + /*isSigned*/ elTy.isSigned()}, + /*isSigned*/ elTy.isSigned())); +} + +void inferResultRanges(tt::GatherOp *op, ArrayRef argRanges, + SetIntRangeFn setResultRange) { + assert(argRanges.size() == 2 && "expected two arg ranges"); + setResultRange(op->getResult(), argRanges[0]); +} + +void inferResultRangesUnaryOpForwardArgRange( + Operation *op, ArrayRef argRanges, + SetIntRangeFn setResultRange) { + assert(op->getNumResults() == 1 && "expected op to have one result"); + setResultRange(op->getResult(0), argRanges[0]); +} + +void inferResultRangesBinaryOpUnionArgRanges( + Operation *op, ArrayRef argRanges, + SetIntRangeFn setResultRange) { + assert(op->getNumResults() == 1 && "expected op to have one result"); + assert(op->getNumOperands() == 2 && "expected op to have two operands"); + assert(argRanges.size() == 2 && "expected two arg ranges"); + setResultRange(op->getResult(0), argRanges[0].rangeUnion(argRanges[1])); +} + +void inferResultRangesMaxNonNegSigned(Operation *op, + SetIntRangeFn setResultRange) { + for (auto result : op->getResults()) { + auto bitWidth = + mlir::ConstantIntRanges::getStorageBitwidth(result.getType()); + setResultRange(result, ConstantIntRanges::fromSigned( + APInt::getZero(bitWidth).sext(bitWidth), + APInt::getMaxValue(bitWidth).sext(bitWidth))); + } +} + +std::optional maybeGetAssumedRange(Operation *assumption, + Value anchor) { + arith::CmpIOp cmpOp = llvm::dyn_cast(assumption); + if (!cmpOp) { + emitRemark(assumption->getLoc(), "unsupported assumption operation"); + return {}; + } + + switch (cmpOp.getPredicate()) { + case arith::CmpIPredicate::uge: + case arith::CmpIPredicate::ugt: + case arith::CmpIPredicate::ule: + case arith::CmpIPredicate::ult: + emitRemark(assumption->getLoc(), + "unsigned arithmetic not currently supported"); + return {}; + default: + break; + } + bool isSigned = true; + + bool anchorIsLhs = cmpOp.getLhs() == anchor; + auto maybeConstantIntValue = getConstantIntValue( + getAsOpFoldResult(anchorIsLhs ? cmpOp.getRhs() : cmpOp.getLhs())); + if (auto constValue = maybeConstantIntValue) { + unsigned bitWidth = ConstantIntRanges::getStorageBitwidth(anchor.getType()); + assert(bitWidth > 0 && "expected non-zero bitwdith"); + // This is always true in + APInt apVal = {bitWidth, static_cast(*constValue), isSigned}, + min = APInt::getSignedMinValue(bitWidth), + max = APInt::getSignedMaxValue(bitWidth); + + switch (cmpOp.getPredicate()) { + case arith::CmpIPredicate::eq: + return mlir::ConstantIntRanges::constant(apVal); + case arith::CmpIPredicate::sge: { + // K >= apVal implies K ∈ [apVal, max] + if (anchorIsLhs) + return mlir::ConstantIntRanges::range(apVal, max, isSigned); + // apVal >= K implies K ∈ [min, apVal] + return mlir::ConstantIntRanges::range(min, apVal, isSigned); + } + case arith::CmpIPredicate::sgt: { + // K > apVal implies K >= apVal + 1 implies K ∈ [apVal + 1, max] + if (anchorIsLhs) + return mlir::ConstantIntRanges::range(apVal + 1, max, isSigned); + // apVal > K implies apVal - 1 >= K implies K ∈ [min, apVal - 1] + return mlir::ConstantIntRanges::range(min, apVal - 1, isSigned); + } + case arith::CmpIPredicate::sle: { + // K <= apVal implies K ∈ [min, apVal] + if (anchorIsLhs) + return mlir::ConstantIntRanges::range(min, apVal, isSigned); + // apVal <= K implies K ∈ [apVal, max] + return mlir::ConstantIntRanges::range(apVal, max, isSigned); + } + case arith::CmpIPredicate::slt: { + // K < apVal implies K <= apVal -1 implies K ∈ [min, apVal - 1] + if (anchorIsLhs) + return mlir::ConstantIntRanges::range(min, apVal - 1, isSigned); + // apVal < K implies apVal + 1 <= K implies K ∈ [apVal + 1, max] + return mlir::ConstantIntRanges::range(apVal + 1, max, isSigned); + } + default: + emitRemark(cmpOp.getLoc(), "unsupported cmp predicate for assumption"); + return {}; + } + } + return {}; +} + +} // namespace + +namespace mlir::triton::AMD { + +std::optional> +collectRanges(const DataFlowSolver &solver, ValueRange values) { + SmallVector ranges; + for (Value val : values) { + auto *maybeInferredRange = + solver.lookupState(val); + if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) + return {}; + const ConstantIntRanges &inferredRange = + maybeInferredRange->getValue().getValue(); + ranges.push_back(inferredRange); + } + return ranges; +} + +bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp) { + if (auto inputRanges = + collectRanges(solver, ValueRange{cmpOp.getOperands()})) { + intrange::CmpPredicate pred = + static_cast(cmpOp.getPredicate()); + return intrange::evaluatePred(pred, (*inputRanges)[0], (*inputRanges)[1]) + .value_or(false); + } + return false; +} + +std::optional +TritonIntegerRangeAnalysis::maybeGetAssumedRange(Value anchor) const { + auto matchingAssumptions = this->assumptions.lookup(anchor); + if (matchingAssumptions.empty()) + return {}; + + unsigned bitWidth = ConstantIntRanges::getStorageBitwidth(anchor.getType()); + assert(bitWidth > 0 && "expected non-zero bitwidth"); + ConstantIntRanges constIntRange = ConstantIntRanges::maxRange(bitWidth); + for (auto assumption : matchingAssumptions) { + if (auto constIntRange_ = ::maybeGetAssumedRange(assumption, anchor)) + constIntRange = constIntRange.intersection(*constIntRange_); + } + return constIntRange; +} + +void TritonIntegerRangeAnalysis::setToEntryState( + dataflow::IntegerValueRangeLattice *lattice) { + auto anchor = lattice->getAnchor(); + IntegerValueRange range = IntegerValueRange::getMaxRange(anchor); + if (auto maybeRange = maybeGetAssumedRange(anchor)) + range = *maybeRange; + propagateIfChanged(lattice, lattice->join(range)); +} + +LogicalResult TritonIntegerRangeAnalysis::visitOperation( + Operation *op, + ArrayRef operands, + ArrayRef resultsLattices) { + LDBG(" Inferring ranges for " << *op << "\n"); + // This callback is almost exactly like the callback in + // IntegerRangeAnalysis::visitOperation except we do not "short-cicruit" the + // analysis by inferring a maximum range for loop results (instead we + // perform a check based on visit counts in visitRegionSuccessors). + auto joinCallback = [&op, &resultsLattices, + this](Value v, const IntegerValueRange &incomingRange) { + auto result = dyn_cast(v); + if (!result) + return; + assert(llvm::is_contained(op->getResults(), result)); + + LDBG(" Inferred range " << incomingRange << "\n"); + dataflow::IntegerValueRangeLattice *lattice = + resultsLattices[result.getResultNumber()]; + IntegerValueRange oldRange = lattice->getValue(); + ChangeResult changed = lattice->join(incomingRange); + propagateIfChanged(lattice, changed); + }; + + // Ops with fixed/constant ranges. + if (llvm::isa( + op)) { + assert(resultsLattices.size() == 1 && "expected exactly one result"); + auto resultLattice = resultsLattices[0]; + + // No updates necessary. + if (!resultLattice->getValue().isUninitialized()) + return success(); + + // Check if user hinted/assumed (this really only applies to get_pid but + // simpler to keep it out of the typeswitch). + auto anchor = resultLattice->getAnchor(); + if (auto assumptions = this->assumptions.lookup(anchor); + !assumptions.empty()) { + setToEntryState(resultLattice); + return success(); + } + + // else use defaults + llvm::TypeSwitch(op) + .Case([&](auto getPIDOp) { + inferResultRangesPID(getPIDOp, kDefaultMaxPrograms - 1, joinCallback); + }) + .Case([&](auto getPIDOp) { + inferResultRangesPID(getPIDOp, kDefaultMaxPrograms, joinCallback); + }) + .Case([&](MakeRangeOp makeROp) { + inferResultRanges(&makeROp, joinCallback); + }) + .Case([&](HistogramOp histOp) { + return inferResultRangesMaxNonNegSigned(histOp, joinCallback); + }) + .Default([&](auto) { llvm::report_fatal_error("unsupported op"); }); + return success(); + } + + SmallVector argIntValueRanges = llvm::map_to_vector( + operands, [](const dataflow::IntegerValueRangeLattice *lattice) { + return lattice->getValue(); + }); + + // Ops with actually changing/variable input/output ranges. + if (llvm::isa(op)) { + SmallVector argConstIntRanges; + for (const auto &r : argIntValueRanges) { + if (r.isUninitialized()) { + setAllToEntryStates(resultsLattices); + return success(); + } + argConstIntRanges.push_back(r.getValue()); + } + llvm::TypeSwitch(op) + .Case([&](auto) { + return inferResultRangesUnaryOpForwardArgRange(op, argConstIntRanges, + joinCallback); + }) + .Case([&](auto joinOp) { + return inferResultRangesBinaryOpUnionArgRanges( + joinOp, argConstIntRanges, joinCallback); + }) + .Case([&](GatherOp gatherOp) { + return inferResultRanges(&gatherOp, argConstIntRanges, joinCallback); + }) + .Default([&](auto) { llvm::report_fatal_error("unsupported op"); }); + return success(); + } + + if (auto inferrable = dyn_cast(op)) { + inferrable.inferResultRangesFromOptional(argIntValueRanges, joinCallback); + return success(); + } + + setAllToEntryStates(resultsLattices); + return success(); +} + +void TritonIntegerRangeAnalysis::visitRegionSuccessors( + ProgramPoint *point, RegionBranchOpInterface branch, + RegionBranchPoint successor, + ArrayRef abstractLattices) { + SmallVector lattices; + for (auto abstractLat : abstractLattices) { + lattices.push_back( + static_cast(abstractLat)); + } + // Initialize loop trip counts + LoopLikeOpInterface loop = + llvm::dyn_cast(branch.getOperation()); + if (loop) { + if (!loopTripCounts.contains(loop)) { + SmallVector loops{loop}; + getEnclosingLoops(*loop, loops); + int loopTripCount = + std::accumulate(loops.begin(), loops.end(), 1, + [](int accum, LoopLikeOpInterface loop) { + return accum * maybeGetTripCount(loop).value_or( + kDefaultMaxTripCount + 1); + }); + loopTripCounts[loop] = loopTripCount; + } + for (auto argLat : lattices) { + if (!loopVisits.contains({loop, argLat})) { + loopVisits[{loop, argLat}] = 0; + } + } + } + + const auto *predecessors = + getOrCreateFor(point, point); + assert(predecessors->allPredecessorsKnown() && + "unexpected unresolved region successors"); + for (Operation *op : predecessors->getKnownPredecessors()) { + std::optional operands; + if (op == branch) { + operands = branch.getEntrySuccessorOperands(successor); + } else if (auto regionTerminator = + dyn_cast(op)) { + operands = regionTerminator.getSuccessorOperands(successor); + } + if (!operands) + return setAllToEntryStates(lattices); + + ValueRange inputs = predecessors->getSuccessorInputs(op); + assert(inputs.size() == operands->size() && + "expected the same number of successor inputs as operands"); + + unsigned firstIndex = 0; + if (inputs.size() != lattices.size()) { + if (!point->isBlockStart()) { + if (!inputs.empty()) { + firstIndex = cast(inputs.front()).getResultNumber(); + } + visitNonControlFlowArguments(branch, + RegionSuccessor(branch->getResults().slice( + firstIndex, inputs.size())), + lattices, firstIndex); + } else { + if (!inputs.empty()) { + firstIndex = cast(inputs.front()).getArgNumber(); + } + Region *region = point->getBlock()->getParent(); + visitNonControlFlowArguments( + branch, + RegionSuccessor(region, region->getArguments().slice( + firstIndex, inputs.size())), + lattices, firstIndex); + } + } + + for (auto [oper, argLat] : + llvm::zip(*operands, ArrayRef(lattices).drop_front(firstIndex))) { + std::pair loopArgLat = {loop, argLat}; + // If we've "run the loop" #tripcount times, stop propagating. + if (loop && loopVisits[loopArgLat] >= loopTripCounts[loop]) + continue; + ChangeResult changed; + if (loop && loopTripCounts[loop] > kDefaultMaxTripCount) { + // If the loop's tripcount is too large, infer the maximum range for + // the arg lattices. This will have the effect that all users will + // also be inferred to have maximum range and end the analysis will + // end (the maximum range is the "top" of the lattice and thus no + // further changes/updates are possible). + changed = argLat->join(IntegerValueRange::getMaxRange(oper)); + } else { + // Else, propagate pred operands. + changed = argLat->join(*getLatticeElementFor(point, oper)); + } + propagateIfChanged(argLat, changed); + // Only increase the loop visitation count if have actually update the + // lattice because otherwise we will over count the number of visits + // (since not all iter_arg lattices are updated/propagated on each + // visit). + if (loop && changed == ChangeResult::Change) + ++loopVisits[loopArgLat]; + } + } +} + +DenseMap> +TritonIntegerRangeAnalysis::collectAssumptions(Operation *rootOp, + bool filterConstants) { + DenseMap> assumptions; + rootOp->walk([&](LLVM::AssumeOp op) { + auto assump = op.getCond().getDefiningOp(); + for (auto operand : assump->getOperands()) { + if (filterConstants && getConstantIntValue(operand)) + continue; + assumptions[operand].insert(assump); + } + }); + return assumptions; +} + +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/lib/CMakeLists.txt new file mode 100644 index 000000000..0b66ac50c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(Analysis) +add_subdirectory(Dialect) +add_subdirectory(TritonAMDGPUToLLVM) +add_subdirectory(TritonAMDGPUDialectToLLVM) +add_subdirectory(TritonAMDGPUTransforms) diff --git a/third_party/enflame/include/triton/third_party/amd/lib/Dialect/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..4f9163bdf --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonAMDGPU) diff --git a/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/CMakeLists.txt new file mode 100644 index 000000000..b79fc9480 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Utility) diff --git a/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/IR/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..f550b6e20 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/IR/CMakeLists.txt @@ -0,0 +1,12 @@ +add_triton_library(TritonAMDGPUIR + Dialect.cpp + + DEPENDS + TritonAMDGPUTableGen + TritonAMDGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRLLVMDialect + TritonIR + TritonGPUIR +) diff --git a/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp new file mode 100644 index 000000000..8764978bc --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +// clang-format off +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "Dialect/TritonAMDGPU/IR/Dialect.cpp.inc" +// clang-format on + +using namespace mlir; +using namespace mlir::triton::amdgpu; + +void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" + >(); +} + +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc" + +#define GET_OP_CLASSES +#include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" + +namespace mlir::triton::amdgpu { + +LogicalResult ExtractSliceOp::verify() { + auto srcTy = getSource().getType(); + auto srcLayout = srcTy.getEncoding(); + auto srcElementType = getElementTypeOrSelf(srcTy); + auto resultTy = getResult().getType(); + auto resultLayout = resultTy.getEncoding(); + auto resultElementType = getElementTypeOrSelf(resultTy); + + if (srcElementType != resultElementType) { + return emitError("result element type must match source element type"); + } + if (srcLayout != resultLayout) { + return emitError("result layout must match source layout"); + } + if (srcTy.getRank() != resultTy.getRank()) { + return emitError("result rank must be equal to source rank"); + } + if (srcTy.getRank() != 2) { + return emitError("currently only 2D tensors are supported"); + } + + auto srcShape = srcTy.getShape(); + + // ExtractSlice only supports slicing where offsets and sizes are multiples of + // shapePerCTATile. This condition ensures that slice has the same layout as + // the original tensor. + + auto offsets = getStaticOffsets(); + if (offsets.size() != 2) { + return emitError("invalid offset shape ") << offsets; + } + + SmallVector sizes; + for (auto i = 0; i < 2; ++i) { + auto resultDimSize = resultTy.getDimSize(i); + auto srcDimSize = srcTy.getDimSize(i); + if (resultDimSize == 0) { + return emitError("result tensor dimension size zero at dimension ") << i; + } + if (srcDimSize == 0) { + return emitError("source tensor dimension size zero at dimension ") << i; + } + if (resultDimSize > srcDimSize) { + return emitError( + "result shape cannot be larger than input shape at dimension ") + << i; + } + if (offsets[i] + resultDimSize > srcDimSize) { + return emitError("invalid offset ") + << offsets[i] << " at dimension " << i; + } + sizes.push_back(resultDimSize); + } + + auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcTy); + shapePerCTATile[0] = + std::min(static_cast(srcShape[0]), shapePerCTATile[0]); + shapePerCTATile[1] = + std::min(static_cast(srcShape[1]), shapePerCTATile[1]); + if (sizes[0] % shapePerCTATile[0] != 0 || + sizes[1] % shapePerCTATile[1] != 0) { + return emitError() << "sizes [" << sizes + << "] must be a multiple of shapePerCTATile [" + << shapePerCTATile << "]"; + } + + if (offsets[0] % shapePerCTATile[0] != 0 || + offsets[1] % shapePerCTATile[1] != 0) { + return emitError() << "offset [" << offsets + << "] must be a multiple of shapePerCTATile [" + << shapePerCTATile << "]"; + } + + return success(); +} + +LogicalResult UpcastMXFPOp::verify() { + auto fpType = getFpType(); + + auto xTy = getSrc().getType(); + auto scaleTy = getScale().getType(); + Builder b(getContext()); + if (xTy.getElementType() != b.getBF16Type() && + xTy.getElementType() != b.getF16Type() && + xTy.getElementType() != b.getI8Type()) { + return emitOpError( + "element type of the first operand must be bf16/fp16 or i8"); + } + + if (scaleTy.getElementType() != b.getI8Type()) { + return emitOpError("element type of the second operand must be uint8"); + } + + auto xShape = xTy.getShape(); + auto scaleShape = scaleTy.getShape(); + + if (xShape.size() != scaleShape.size() || xShape.size() < 2) { + return emitOpError( + "operands must have the same number of dimensions, at least 2"); + } + + if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 || + fpType == ScaleDotElemType::E5M2)) { + return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2"); + } + + auto layoutX = xTy.getEncoding(); + auto layoutScale = scaleTy.getEncoding(); + if (bool(layoutX) != bool(layoutScale)) { + return emitOpError( + "Expected either both or neither operands to have an encoding"); + } + // Nothing to check if no encoding. This is used to infer the return type in + // AccelerateMatmul.cpp + if (!layoutX) { + return success(); + } + + auto dotEncoding = dyn_cast(layoutX); + if (!dotEncoding) { + return emitOpError("Expected a DotOperandEncodingAttr for values"); + } + if (!isa(layoutScale)) { + return emitOpError( + "Expected a BlockOperandEncoding or LinearOperandEncoding " + "for scales"); + } + + // Change to support fp8 types + const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1; + // Figure out the K dimension for the input A/B. For A/B scale, the K + // dimension is always the last dimension. + const int opIdx = dotEncoding.getOpIdx(); + const bool hasBatch = xShape.size() == 3; + const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch; + + if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) { + return emitOpError("K dimension of first operand must be 16 times " + "larger than last/K dimension of the second operand"); + } + + // Check other dimensions match too. For input A/B, we need to figure out the + // index for the M/N dimension. For scale, it's always {(batch), M/N, K}. + const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch; + if (hasBatch && xShape[0] != scaleShape[0]) + return emitOpError("batch dimension must match between operands"); + if (xShape[mnIdx] != scaleShape[hasBatch]) { + return emitOpError("M/N dimension must match between operands"); + } + + return success(); +} + +RankedTensorType +UpcastMXFPOp::deduceOutputType(TypedValue inputTensor, + ScaleDotElemType inputElemType, + Type outputElemType) { + MLIRContext *ctx = inputTensor.getContext(); + auto xTy = inputTensor.getType(); + if (inputElemType != ScaleDotElemType::E2M1) + return xTy; + + auto xShape = xTy.getShape(); + auto newShape = llvm::to_vector(xShape); + auto encoding = xTy.getEncoding(); + if (!encoding) { + newShape.back() *= 2; + return RankedTensorType::get(xShape, outputElemType); + } + + auto oldEncoding = cast(encoding); + auto newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(), + oldEncoding.getParent(), + oldEncoding.getKWidth() * 2); + // Figure out the K dimension for the input A/B, given that the return + // type is upcasted A/B type so we need to update the proper dim size. + const int opIdx = oldEncoding.getOpIdx(); + const bool hasBatch = xShape.size() == 3; + const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch; + newShape[kIdx] *= 2; + return RankedTensorType::get(newShape, outputElemType, newVEncoding); +} + +} // namespace mlir::triton::amdgpu diff --git a/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/Utility/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/Utility/CMakeLists.txt new file mode 100644 index 000000000..f8d4deb6b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/Utility/CMakeLists.txt @@ -0,0 +1,8 @@ +add_triton_library(TritonAMDUtils + CommonUtils.cpp + + LINK_LIBS PUBLIC + MLIRLLVMDialect + TritonIR + TritonGPUIR +) diff --git a/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/Utility/CommonUtils.cpp b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/Utility/CommonUtils.cpp new file mode 100644 index 000000000..1c0d9743d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/Dialect/TritonAMDGPU/Utility/CommonUtils.cpp @@ -0,0 +1,17 @@ +#include "third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h" + +namespace mlir::triton::AMD { +SmallVector getLeafForOps(triton::FuncOp funcOp) { + SmallVector allOps; + funcOp->walk([&](scf::ForOp forOp) { allOps.push_back(forOp); }); + + SmallVector leafOps; + for (scf::ForOp forOp : allOps) { + auto searchResult = forOp.getBody()->walk( + [](scf::ForOp) { return WalkResult::interrupt(); }); + if (!searchResult.wasInterrupted()) + leafOps.push_back(forOp); + } + return leafOps; +} +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt new file mode 100644 index 000000000..4aebabc0a --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt @@ -0,0 +1,7 @@ +add_triton_library(TritonAMDGPUDialectToLLVM + TritonAMDGPUToLLVMPatterns.cpp + ExtractSliceOpToLLVM.cpp + + DEPENDS + TritonAMDGPUIR +) diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp new file mode 100644 index 000000000..07cf91870 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -0,0 +1,143 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +// clang-format off +//===--------------------------------------------------------------------------------===// +// # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +// # WO # W1 # | # +// # # # | # +// # # # # # | # +// # W2 # W3 # .... | # +// # # # | SkipElems # +// # # # # # | # +// # | # +// # Slice | # +// # . / \ | # +// # . / \ | # +// # . / \| # +// # # # # # # # +// # # W0 # W1 # # +// # # # # # +// # # # # # # tensorStride # +// # # W2 # W3 # --------------------------------# +// # # # # # +// # # # # # # # +// # tensorStride # W0 # W1 # # +// # ---------------------------------- # # # # +// # # # # # # # +// # # W2 # W3 # # +// # # # # # +// # # # # # # ---> lastIdx # +// # . # +// # . # +// # . # +// # # +// # # +// # # +// # # +// # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +//===--------------------------------------------------------------------------------===// +// clang-format on + +namespace { +struct ExtractSliceOpConversion + : public ConvertOpToLLVMPattern { + explicit ExtractSliceOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit) { + } + + LogicalResult processLayout(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + auto srcTy = cast(op.getSource().getType()); + auto srcLayout = srcTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto resultTy = cast(op.getType()); + auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter); + auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy); + auto contigPerThread = triton::gpu::getContigPerThread(srcTy); + auto totalContigPerThread = product(contigPerThread); + auto order = triton::gpu::getOrder(srcTy); + + // Calculate valid total number of workers in each dimension + auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcTy); + shapePerCTATile[0] = + std::min(static_cast(srcShape[0]), shapePerCTATile[0]); + shapePerCTATile[1] = + std::min(static_cast(srcShape[1]), shapePerCTATile[1]); + + // Rank == 2 checked in the verifier + SmallVector sizes; + for (auto i = 0; i < 2; ++i) { + sizes.push_back(resultTy.getDimSize(i)); + } + + auto offsets = op.getStaticOffsets(); + + // Calculate offsets and sizes in terms of CTA units. + std::array CTAOffsets{offsets[0] / shapePerCTATile[0], + offsets[1] / shapePerCTATile[1]}; + std::array CTASizes{sizes[0] / shapePerCTATile[0], + sizes[1] / shapePerCTATile[1]}; + std::array CTAPerShape{srcShape[0] / shapePerCTATile[0], + srcShape[1] / shapePerCTATile[1]}; + + // The diagram above illustrates the graphical representation of the + // skipElems, tensorStride, and lastIdx variables. + auto skipElems = CTAOffsets[order[1]] * (elemsPerThread[order[0]] * + contigPerThread[order[1]]) + + CTAOffsets[order[0]] * totalContigPerThread; + auto tensorStride = + (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalContigPerThread; + auto lastIdx = + (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) * + elemsPerThread[order[0]] * contigPerThread[order[1]] + + (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalContigPerThread; + + assert(lastIdx <= vals.size()); + + SmallVector resultVals; + for (int i = skipElems; i < lastIdx; i += tensorStride) { + for (int j = 0; j < totalContigPerThread * CTASizes[order[0]]; ++j, ++i) { + assert(i < lastIdx); + resultVals.push_back(vals[i]); + } + } + Value ret = packLLElements(loc, this->getTypeConverter(), resultVals, + rewriter, resultTy); + + rewriter.replaceOp(op, ret); + return success(); + } + + LogicalResult + matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTy = op.getSource().getType(); + if (isa( + op.getSource().getType().getEncoding())) { + return processLayout(op, adaptor, rewriter); + } + return failure(); + } +}; +} // namespace + +namespace mlir::triton::AMD { + +void populateExtractSliceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp new file mode 100644 index 000000000..c7c2f56d3 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp @@ -0,0 +1,10 @@ +#include "third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +namespace mlir::triton::AMD { +void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit) { + populateExtractSliceOpToLLVMPatterns(typeConverter, patterns, benefit); +} +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp new file mode 100644 index 000000000..9721eb735 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp @@ -0,0 +1,285 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" +#include "Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "BufferOpsEmitter.h" + +using namespace triton::AMD; + +namespace { + +// Utility function to determine if a scalar/tensor value is zero +bool isZero(Value v) { + if (auto constantOp = v.getDefiningOp()) { + if (auto attr = dyn_cast(constantOp.getValue())) + return attr.getValue().isZero(); + if (auto attr = dyn_cast(constantOp.getValue())) + return attr.getValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + } + return false; +} +} // namespace + +namespace mlir::LLVM::AMD { +BufferEmitter::BufferEmitter(RewriterBase &rw, Location loc, TargetInfo ti) + : rewriter(rw), loc(loc), targetInfo(ti) {} + +Value BufferEmitter::createResourceDescriptor(Value basePtr, + Value blockStride) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // 1. Create the resource descriptor + // bits 0-11: dst sel, ignored by these intrinsics + // bits 12-14: data format (ignored, must be nonzero, 7=float) + // bits 15-18: data format (ignored, must be nonzero, 4=32bit) + // bit 19: In nested heap (0 here) + // bit 20: Behavior on unmap (0 means "return 0 / ignore") + // bits 21-22: Index stride for swizzles (N/A) + // bit 23: Add thread ID (0) + // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) + // bits 25-26: Reserved (0) + // bit 27: Buffer is non-volatile (CDNA only) + // bits 28-29: Out of bounds select (RDNA only) + // (0 = structured, + // 1 = check index, + // 2 = none, + // 3 = either swizzles or testing against offset field) + // bits 30-31: Type (must be 0) + uint32_t flags = (7 << 12) | (4 << 15); + if (targetInfo.getISAFamily() == ISAFamily::RDNA2 || + targetInfo.getISAFamily() == ISAFamily::RDNA3) { + flags |= (1 << 24); + uint32_t oob = 3; + flags |= (oob << 28); + } + + Value stride = b.int_val(16, 0); + if (llvm::is_contained({ISAFamily::CDNA3, ISAFamily::CDNA4}, + targetInfo.getISAFamily())) { + if (blockStride) { + Value enableSwizzle = b.int_val(16, 16384); + Value mask14b = b.int_val(16, 16383); + // Cache swizzle supports only upto 8k stride. Also simply swizzling the + // largest available stride (8k) doesn't help those unsupported large + // stride. Especially better to avoid using the stride which is 2^N when + // N>13, e.g. by add padding to the buffer. + Value stride16b = + rewriter.create(loc, i16_ty, blockStride); + Value strideSat = rewriter.create(loc, stride16b, mask14b); + // stride[13:0] = swizzling stride + // stride[14] = swizzle enabling bit + stride = rewriter.create(loc, enableSwizzle, strideSat); + } + } + + Value flagsConst = b.int_val(32, flags); + Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8); + Value numRecordsByte = b.int_val(32, std::numeric_limits::max() - 1); + + Value resource = rewriter.createOrFold( + loc, rsrcType, basePtr, stride, numRecordsByte, flagsConst); + return resource; +} + +Value BufferEmitter::emitLoad(Type type, Value rsrcDesc, Value offset, + Value pred, Value falseVal, + triton::CacheModifier cm) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector args; + fillCommonArgs(type, rsrcDesc, offset, pred, cm, /*isBufferLoad=*/true, args); + Type bufferType = getBufferOpType(type, false); + Value data = rewriter.create( + loc, bufferType, args, ArrayRef()); + data = b.bitcast(data, type); + if (!isZero(falseVal)) + data = b.select(pred, data, falseVal); + return data; +} + +void BufferEmitter::emitLoadToLds(Type type, Value byteWidth, Value rsrcDesc, + Value offset, Value dst, Value pred, + triton::CacheModifier cm) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector commonArgs; + fillCommonArgs(type, rsrcDesc, offset, pred, cm, /*isBufferLoad=*/true, + commonArgs); + Type bufferType = getBufferOpType(type, false); + rewriter.create( + loc, TypeRange{}, + ValueRange{ + commonArgs[0], // Buffer descriptor + dst, // LDS base ptr + byteWidth, // Instr size + commonArgs[1], // Buffer offset + b.i32_val(0), // LDS offset + commonArgs[2], // Instruction offset + commonArgs[3], // AUX + }, + ArrayRef()); +} + +Value BufferEmitter::emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc, + Value offset, Value data, Value pred, + bool hasUsers) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + VectorType vecTy = cast(data.getType()); + Type bufferType = getBufferOpType(type, true); + if (vecTy != bufferType) + data = b.bitcast(data, bufferType); + + SmallVector args{data}; + fillCommonArgsAtomics(type, rsrcDesc, offset, pred, hasUsers, args); + + // TODO: + // The ops in ROCDL (e.g., RawPtrBufferAtomicFaddOp) have no return value, + // but they lower to instrinsics that can return values. This causes the + // LLVM verifier to fail. When this is fixed, the ROCDL ops should be used + // here. + auto rmwOpStr = stringifyRMWOp(rmwType).str(); + auto instrinsic = "llvm.amdgcn.raw.ptr.buffer.atomic." + rmwOpStr; + auto bufferAtomicRMW = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, instrinsic, bufferType, args); + + return b.bitcast(bufferAtomicRMW.getResult(0), type); +} + +void BufferEmitter::emitStore(Value rsrcDesc, Value offset, Value data, + Value pred, triton::CacheModifier cm) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + VectorType vecTy = cast(data.getType()); + Type bufferType = getBufferOpType(vecTy, false); + if (vecTy != bufferType) + data = b.bitcast(data, bufferType); + SmallVector args{data}; + fillCommonArgs(vecTy, rsrcDesc, offset, pred, cm, /*isBufferLoad=*/false, + args); + rewriter.create(loc, TypeRange{}, args, + ArrayRef()); +} + +Type BufferEmitter::getBufferOpType(Type type, bool atomicsOp) { + int64_t vecSize = 1; + Type elementType = type; + if (auto vecType = dyn_cast(type)) { + vecSize = vecType.getNumElements(); + elementType = vecType.getElementType(); + } + + const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); + const size_t totalWidthBits = valueElemNBits * vecSize; + + Type bufferElementType = elementType; + // We don't want to cast from bf16 if we are emitting buffer atomics + if (elementType.isBF16() && !atomicsOp) { + bufferElementType = rewriter.getI16Type(); + } + + // If we are dealing with a subword type (e.g., i8 or f16) but we + // still need multiple words, then pack the subwords into 32bit integers + // and update the vector length and the type + // We never need to pack for buffer atomics because we ensure + // 1) We can always emit a 32-bit / 64-bit atomics op + // 2) For tensors of 16-bit values that the values are contiguous + int64_t bufferVecSize = vecSize; + if (valueElemNBits < 32 && !atomicsOp) { + if (totalWidthBits > 32) { + bufferElementType = rewriter.getI32Type(); + bufferVecSize = totalWidthBits / 32; + } else { + bufferElementType = rewriter.getIntegerType(totalWidthBits); + bufferVecSize = 1; + } + } + + // This is the buffer type that the buffer operation will use. It + // will be bitcast-able to the original type. So if the types + // ended up different, we simply have to emit a `bitcastOp` to convert + Type bufferType = type; + if (bufferVecSize != vecSize || bufferElementType != elementType) + bufferType = VectorType::get(bufferVecSize, bufferElementType); + if (bufferVecSize == 1) + bufferType = getElementTypeOrSelf(bufferType); + + return bufferType; +} + +void BufferEmitter::fillCommonArgs(Type type, Value rsrcDesc, + Value vOffsetElems, Value pred, + triton::CacheModifier cm, bool isBufferLoad, + SmallVector &args) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // 1. Create the (masked) offset + Type elementType = getElementTypeOrSelf(type); + const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); + const int elementByteWidth = valueElemNBits / 8; + // Please note: the index passed is not in bytes, but in number of elements + // In order to pass the index to the buffer operation, we need to convert in + // bytes (i.e., we need to multiply by `elementByteWidth`) + Value vOffsetOutOfBunds = b.int_val( + 32, static_cast(std::numeric_limits::max() + int64_t(1))); + Value vOffsetBytes = b.mul(b.int_val(32, elementByteWidth), vOffsetElems); + Value maskedOffsetBytes = b.select(pred, vOffsetBytes, vOffsetOutOfBunds); + + // 2. Set the sgprOffset to 0 + Value sgprOffset = b.int_val(32, 0); + + // 3. Create the cache modifiers word + int32_t aux = + getCtrlBitsForCacheModifierOnTarget(cm, isBufferLoad, targetInfo); + Value cacheModifiers = b.int_val(32, aux); + + // 4. Add the arguments + args.push_back(rsrcDesc); + args.push_back(maskedOffsetBytes); + args.push_back(sgprOffset); + args.push_back(cacheModifiers); +} + +void BufferEmitter::fillCommonArgsAtomics(Type type, Value rsrcDesc, + Value vOffsetElems, Value pred, + bool hasUsers, + SmallVector &args) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // 1. Create the (masked) offset + Type elementType = getElementTypeOrSelf(type); + const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); + const int elementByteWidth = valueElemNBits / 8; + // Please note: the index passed is not in bytes, but in number of elements + // In order to pass the index to the buffer operation, we need to convert in + // bytes (i.e., we need to multiply by `elementByteWidth`) + Value vOffsetOutOfBunds = b.int_val( + 32, static_cast(std::numeric_limits::max() + int64_t(1))); + Value vOffsetBytes = b.mul(b.int_val(32, elementByteWidth), vOffsetElems); + Value maskedOffsetBytes = b.select(pred, vOffsetBytes, vOffsetOutOfBunds); + + // 2. Set the sgprOffset to 0 + Value sgprOffset = b.int_val(32, 0); + + // 3. Create the cache modifiers word + int32_t aux = 0; + if (hasUsers) + aux = getCtrlBitsForBufferAtomicsOnGFX_942_950(/*setSC0*/ true, + /*setSC1*/ false, + /*setNT*/ false); + else + aux = getCtrlBitsForBufferAtomicsOnGFX_942_950( + /*setSC0*/ false, /*setSC1*/ false, /*setNT*/ false); + + Value cacheModifiers = b.int_val(32, aux); + + // 4. Add the arguments + args.push_back(rsrcDesc); + args.push_back(maskedOffsetBytes); + args.push_back(sgprOffset); + args.push_back(cacheModifiers); +} + +} // namespace mlir::LLVM::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h new file mode 100644 index 000000000..6bd56742d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h @@ -0,0 +1,109 @@ +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_BUFFEROPSEMITTER_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_BUFFEROPSEMITTER_H_ + +#include "TargetInfo.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include + +namespace mlir::LLVM::AMD { +// Utility class to take care of buffer operation emission. We may add more +// emitters into this as needed. Buffer operations accept a memory descriptor +// and an offset. +// +// The memory descriptor is stored in s_gprs and hence needs to +// be uniform across the wave. It contains two fields (among many others): +// +// - `base_pointer`: represents the (scalar) pointer to the memory area +// - `num_records`: represents the size of the memory region. This is a +// 32 bit unsigned integer +// +// The offset can be non-uniform across the wave (and hence stored in vgprs). +// +// The high level behaviour of a buffer operation can be described as: +// ``` +// def buffer_op(mem_desc, offset): +// address = splat(mem_desc.base_pointer) +// address += offset +// return buffer_op(address) +// ``` +// This means we don't need to store the addresses in vgprs and we need less +// VALU operations to compute the final address. +// +// Also note that buffer operations support out-of-boundary memory access. +// I.e., if offset[i] > mem_desc.num_records the operation is a nop for the i-th +// thread. +// +// This can be exploited to support masked operations, like in the following +// snippet: +// ``` +// def masked_op(base_ptr, offset, pred) +// mem_desc.base_ptr = base_ptr +// mem_desc.num_records = max_int_32 +// oob_offset = max_int_32+1 +// masked_offset = (pred ? offset : oob_offset) +// buffer_op(mem_desc, masked_offset) +// ``` +// To use buffer operations three main requirements need to be met: +// +// 1. The buffer pointer needs to be a scalar, it cannot be non-uniform across +// threads of the given wave +// 2. The offset needs to be expressed in 32 bits +// 3. The offset needs to be non-negative +// +// Failure to meet 1) will result in a scalarized loop (very poor performance). +// Failure to meet 2) and 3) will result in incorrect memory access. +struct BufferEmitter { + BufferEmitter(RewriterBase &rw, Location loc, + mlir::triton::AMD::TargetInfo ti); + + // Create a resource descriptor that points to the area of memory we want to + // load from + Value createResourceDescriptor(Value basePtr, Value inferredStride = nullptr); + + // Emit a predicated rocdl.raw.ptr.buffer.load + Value emitLoad(Type type, Value rsrcDesc, Value offset, Value pred, + Value falseVal, CacheModifier cm); + + // Emit a predicated rocdl.raw.ptr.buffer.load.lds + void emitLoadToLds(Type type, Value byteWidth, Value rsrcDesc, Value offset, + Value dst, Value pred, CacheModifier cm); + + // Emit a predicated rocdl.raw.ptr.buffer.atomic.* RMWOp + Value emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc, Value offset, + Value data, Value pred, bool hasUsers); + + // Emit a predicated rocdl.raw.ptr.buffer.store + void emitStore(Value rsrcDesc, Value offset, Value data, Value pred, + CacheModifier cm); + +private: + // Fill common buffer operation arguments. + void fillCommonArgs(Type type, Value rsrcDesc, Value vOffsetElems, Value pred, + CacheModifier cm, bool isBufferLoad, + SmallVector &args); + + // Fill buffer atomics arguments + void fillCommonArgsAtomics(Type type, Value rsrcDesc, Value vOffsetElems, + Value pred, bool hasUsers, + SmallVector &args); + + // Given a type, the buffer type can be either the same type + // or a packed version. E.g., a vector of 8xfp16 can be bitcasted to + // a vector of 4xi32. This usually makes the life of the backend easier + Type getBufferOpType(Type type, bool atomicsOp); + + // Rewriter utilities + RewriterBase &rewriter; + Location loc; + mlir::triton::AMD::TargetInfo targetInfo; +}; + +} // namespace mlir::LLVM::AMD + +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_BUFFEROPSEMITTER_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp new file mode 100644 index 000000000..8d0702d37 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp @@ -0,0 +1,218 @@ +#include "TritonAMDGPUToLLVM/Passes.h" + +#include "Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir::triton { +#define GEN_PASS_DEF_CONVERTBUILTINFUNCTOLLVM +#include "TritonAMDGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton + +using namespace mlir; + +namespace { + +class CallOpConversion : public OpRewritePattern { +public: + CallOpConversion(mlir::MLIRContext *context, bool ftz) + : OpRewritePattern(context, 1), ftz(ftz) {} + + LogicalResult + matchAndRewrite(LLVM::CallOp callOp, + mlir::PatternRewriter &rewriter) const override { + if (isPredicatedLoad(callOp)) { + return convertPredicatedLoad(callOp, rewriter); + } else if (isPredicatedStore(callOp)) { + return convertPredicatedStore(callOp, rewriter); + } else if (isWrappedLLVMIntrinsic(callOp)) { + return convertToLLVMIntrinsic(callOp, rewriter); + } else { + return failure(); + } + } + +private: + bool isPredicatedLoad(LLVM::CallOp callOp) const { + return callOp.getCallee().value().contains(mlir::LLVM::AMD::predicatedLoad); + } + + bool isPredicatedStore(LLVM::CallOp callOp) const { + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedStore); + } + + bool isWrappedLLVMIntrinsic(LLVM::CallOp callOp) const { + if (std::optional callee = callOp.getCallee()) { + if (callee.value().starts_with("__triton_hip_")) { + return true; + } + } + return false; + } + + LogicalResult convertPredicatedStore(LLVM::CallOp callOp, + mlir::PatternRewriter &rewriter) const { + auto operands = callOp.getOperands(); + + auto loc = callOp.getLoc(); + auto ptr = operands[0]; + auto val = operands[1]; + auto pred = operands[2]; + + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterStore = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *trueBlock = rewriter.createBlock(afterStore); + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, pred, trueBlock, afterStore); + rewriter.setInsertionPointToStart(trueBlock); + // | vialatile | non-tmp | gcn instr gfx94 + // LLVM::StoreOp | 0 | 0 | (cg) global store + // | 0 | 1 | (cs) global store nt + // | 1 | 0/1 | (wt) global store sc0 sc1 + auto [volatileFlag, nonTmpFlag] = + mlir::LLVM::AMD::getCacheModifierFlagsForPredicatedCall(callOp); + auto storeOp = rewriter.create( + loc, val, ptr, /*alignment=*/0, volatileFlag, nonTmpFlag); + rewriter.create(loc, afterStore); + rewriter.setInsertionPointToStart(afterStore); + rewriter.eraseOp(callOp); + return mlir::success(); + } + + LogicalResult convertPredicatedLoad(LLVM::CallOp callOp, + mlir::PatternRewriter &rewriter) const { + auto operands = callOp.getOperands(); + auto result = callOp.getResult(); + + auto loc = callOp.getLoc(); + auto elemTy = result.getType(); + auto ptr = operands[0]; + auto pred = operands[1]; + auto falseVal = operands[2]; + + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterLoad = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + afterLoad->addArgument({elemTy}, {loc}); + Block *trueBlock = rewriter.createBlock(afterLoad); + Block *falseBlock = + rewriter.splitBlock(trueBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, pred, trueBlock, falseBlock); + rewriter.setInsertionPointToStart(trueBlock); + // | vialatile | non-tmp | gcn instr gfx94 + // LLVM::LoadOp | 0 | 0 | (ca) global load + // | 0/1 | 1 | (cg) global load nt + // | 1 | 0 | (cv) flat load sc0 sc1 + auto [volatileFlag, nonTmpFlag] = + mlir::LLVM::AMD::getCacheModifierFlagsForPredicatedCall(callOp); + auto loadOp = rewriter.create( + loc, elemTy, ptr, /*alignment=*/0, volatileFlag, nonTmpFlag); + rewriter.create(loc, loadOp->getResult(0), afterLoad); + rewriter.setInsertionPointToStart(falseBlock); + rewriter.create(loc, falseVal, afterLoad); + rewriter.setInsertionPointToStart(afterLoad); + Value loadVal = afterLoad->getArgument(0); + rewriter.replaceOp(callOp, loadVal); + return mlir::success(); + } + + LogicalResult convertToLLVMIntrinsic(LLVM::CallOp callOp, + mlir::PatternRewriter &rewriter) const { + StringRef calleeName = callOp.getCallee().value(); + + auto operands = callOp.getOperands(); + auto result = callOp.getResult(); + + LLVM::LLVMFunctionType calleeType = callOp.getCalleeFunctionType(); + Type returnType = calleeType.getReturnType(); + + auto loc = callOp.getLoc(); + + Operation *replacementOp = nullptr; + if (calleeName == "__triton_hip_iabs") { + assert(operands.size() == 1); + replacementOp = rewriter.create(loc, returnType, operands[0], + /*is_int_min_poison=*/false); + } else if (calleeName == "__triton_hip_fabs") { + assert(operands.size() == 1); + replacementOp = + rewriter.create(loc, returnType, operands[0]); + } else if (calleeName == "__triton_hip_llrint") { + assert(operands.size() == 1); + // Note, LrintOp and LlrintOp result in a code-gen error + Operation *op = rewriter.create(loc, operands[0].getType(), + operands[0]); + replacementOp = + rewriter.create(loc, returnType, op->getResult(0)); + } else if (calleeName == "__triton_hip_fast_fdividef") { + assert(operands.size() == 2); + const char *intrinsic = "llvm.amdgcn.rcp.f32"; + auto rcpOp = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, + returnType, operands[1]); + + LLVM::FastmathFlagsAttr defaultFlags{}; + replacementOp = rewriter.create( + loc, returnType, operands[0], rcpOp->getResult(0), defaultFlags); + } else if (calleeName == "__triton_hip_fast_expf") { + assert(operands.size() == 1); + assert(operands[0].getType().getIntOrFloatBitWidth() == 32); + const double log2e = 1.4426950408889634; + LLVM::FastmathFlagsAttr defaultFlags{}; + auto mulOp = rewriter.create( + loc, rewriter.getF32Type(), operands[0], + LLVM::createConstantF32(loc, rewriter, log2e), defaultFlags); + const char *intrinsic = ftz ? "llvm.amdgcn.exp2.f32" : "llvm.exp2.f32"; + + replacementOp = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, intrinsic, returnType, mulOp->getResult(0)); + } + + if (replacementOp) { + rewriter.replaceOp(callOp, replacementOp); + return mlir::success(); + } + + return mlir::failure(); + } + +private: + bool ftz; +}; + +struct ConvertBuiltinFuncToLLVM + : public triton::impl::ConvertBuiltinFuncToLLVMBase< + ConvertBuiltinFuncToLLVM> { + explicit ConvertBuiltinFuncToLLVM(bool ftz) { this->ftz = ftz; } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + GreedyRewriteConfig config; + config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; + + RewritePatternSet patterns(context); + patterns.add(context, this->ftz); + + if (mlir::applyPatternsGreedily(mod, std::move(patterns), config) + .failed()) { + signalPassFailure(); + } + } +}; + +} // namespace + +namespace mlir::triton { + +std::unique_ptr> +createConvertBuiltinFuncToLLVMPass(bool ftz) { + return std::make_unique(ftz); +} + +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..558d08037 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -0,0 +1,34 @@ +add_triton_library(TritonAMDGPUToLLVM + BufferOpsEmitter.cpp + ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp + ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp + ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp + ConvertLayoutOpToLLVM.cpp + MemoryOpToLLVM.cpp + DotOpToLLVM/FMA.cpp + DotOpToLLVM/MFMA.cpp + DotOpToLLVM/WMMA.cpp + DotOpToLLVM.cpp + ElementwiseOpToLLVM.cpp + LoadStoreOpToLLVM.cpp + GCNAsmFormat.cpp + TritonGPUToLLVM.cpp + BuiltinFuncToLLVM.cpp + Utility.cpp + TargetInfo.cpp + TargetUtils.cpp + DecomposeUnsupportedConversions.cpp + OptimizeLDSUsage.cpp + OptimizeLDSUtility.cpp + SPMDOpToLLVM.cpp + SchedInstructions.cpp + UpcastMXFPToLLVM.cpp + + DEPENDS + TritonAMDGPUConversionPassIncGen + + LINK_LIBS PUBLIC + TritonGPUToLLVM + TritonAMDGPUIR + TritonProtonToLLVM +) diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp new file mode 100644 index 000000000..a53a5d276 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -0,0 +1,190 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::MemDescType; + +namespace SharedToDotOperandMFMA { +Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, + Location loc, Value tensor, + DotOperandEncodingAttr bEncoding, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, Value thread); +} // namespace SharedToDotOperandMFMA + +namespace SharedToDotOperandWMMA { +Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, + Location loc, Value tensor, + DotOperandEncodingAttr bEncoding, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, Value thread); +} // namespace SharedToDotOperandWMMA + +namespace { + +struct ConvertLayoutOpMFMAToDotOpConversion + : public ConvertOpToLLVMPattern { +public: + explicit ConvertLayoutOpMFMAToDotOpConversion( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(op.getSrc().getType()); + auto dstType = cast(op.getType()); + + if (!matchMFMAAndDotOperandShuffleCase(srcType, dstType)) + return failure(); + + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + SmallVector inVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + if (inVals.empty() || inVals.size() % 8 != 0) + return failure(); + + auto mfmaLayout = dyn_cast(srcType.getEncoding()); + assert((mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32) && + "Expected MFMA size 16 or 32"); + assert(triton::gpu::getWarpSize(mfmaLayout) == 64 && + "Expected warp size 64 for MFMA"); + + auto elemTy = int_ty(8); + auto vecTy = vec_ty(elemTy, 4); + + Value c16 = b.i32_val(16); + Value c32 = b.i32_val(32); + Value c48 = b.i32_val(48); + Value c64 = b.i32_val(64); + + Value threadId = getThreadId(rewriter, loc); + Value laneId = b.urem(threadId, c64); + + Value mask0 = b.icmp_slt(laneId, c32); + Value mask1 = b.icmp_slt(b.urem(laneId, c32), c16); + + Value addrShift16 = b.urem(b.add(laneId, c16), c64); + Value addrShift32 = b.urem(b.add(laneId, c32), c64); + Value addrShift48 = b.urem(b.add(laneId, c48), c64); + + SmallVector outVals; + for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) { + Value vec0 = b.undef(vecTy); + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + vec0 = b.insert_element(vecTy, vec0, inVals[startIdx + vIdx], + b.i32_val(vIdx)); + } + Value vec1 = b.undef(vecTy); + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + vec1 = b.insert_element(vecTy, vec1, inVals[startIdx + vIdx + 4], + b.i32_val(vIdx)); + } + + Value resVec0, resVec1; + if (mfmaLayout.getMDim() == 32) { + /* + Using wave shuffle to convert layouts (32x32x16 case): + 1) Input MMA layout (32x32, fp8, 16 values): + _____________________________________________________________ + |(t0 v0 v1 v2 v3) (t32 v0 v1 v2 v3) ... (t32 v12 v13 v14 v15)| + | ... ... | + |(t31 v0 v1 v2 v3) (t63 v0 v1 v2 v3) ... (t63 v12 v13 v14 v15)| + |_____________________________________________________________| + + 2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each): + ____________________________________________________________ ___ + |(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) || + | ... ... ||... + |(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) || + |____________________________________________________________||___ + */ + + Value shflVec0 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec0, int_ty(32)), + addrShift32), + vecTy); + Value shflVec1 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec1, int_ty(32)), + addrShift32), + vecTy); + + resVec0 = b.select(mask0, vec0, shflVec1); + resVec1 = b.select(mask0, shflVec0, vec1); + } else if (mfmaLayout.getMDim() == 16) { + /* + 16x16x32 case: + 1) Input MMA layout (two 16x16, fp8, 4 values each): + _________________________________________________________ ___________ + |(t0 v0 v1 v2 v3) (t16 v0 v1 v2 v3) ... (t48 v0 v1 v2 v3)||(t0 v4 ... + | ... ... || ... + |(t15 v0 v1 v2 v3) (t31 v0 v1 v2 v3) ... (t63 v0 v1 v2 v3)||(t15 v4 ... + |_________________________________________________________||___________ + + 2) Output Dot operand layout (16x32 tile, fp8, 8 values): + ________________________________________________________________ + |(t0 v0 v1 v2 v3 v4 v5 v6 v7) ... (t48 v0 v1 v2 v3 v4 v5 v6 v7) | + | ... ... | + |(t15 v0 v1 v2 v3 v4 v5 v6 v7) ... (t63 v0 v1 v2 v3 v4 v5 v6 v7) | + |________________________________________________________________| + */ + + Value shflVec0_16 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec0, int_ty(32)), + addrShift16), + vecTy); + Value shflVec0_32 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec0, int_ty(32)), + addrShift32), + vecTy); + Value shflVec1_32 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec1, int_ty(32)), + addrShift32), + vecTy); + Value shflVec1_48 = b.bitcast( + targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec1, int_ty(32)), + addrShift48), + vecTy); + + resVec0 = b.select(mask0, b.select(mask1, vec0, shflVec0_16), + b.select(mask1, shflVec1_32, shflVec1_48)); + resVec1 = b.select(mask0, b.select(mask1, shflVec0_16, shflVec0_32), + b.select(mask1, shflVec1_48, vec1)); + } + + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + outVals.push_back(b.extract_element(elemTy, resVec0, b.i32_val(vIdx))); + } + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + outVals.push_back(b.extract_element(elemTy, resVec1, b.i32_val(vIdx))); + } + } + + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, + benefit); +} diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp new file mode 100644 index 000000000..77028e37e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp @@ -0,0 +1,257 @@ +#include "SharedToDotOperandHelper.h" + +using ::mlir::triton::gpu::SwizzledSharedEncodingAttr; + +namespace mlir::triton::AMD { + +// Get warpId inside block of warps. +Value getWarpIdInBlock(ConversionPatternRewriter &rewriter, Location loc, + Value warpId, const ArrayRef &wpt, + int elemPerInstrNonK, int tensorSizeNonK, int nonKIdx, + const ArrayRef &order) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, wpt, order); + + return b.urem(multiDimWarpId[nonKIdx], + b.i32_val(tensorSizeNonK / elemPerInstrNonK)); +} + +bool isSwizzled(SwizzledSharedEncodingAttr layout) { + return layout.getMaxPhase() != 1; +} + +std::pair +swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row, + Value col, SharedMemoryObject smemObj, + SwizzledSharedEncodingAttr attr) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + (void)smemObj; // unused in current pattern + const auto &order = attr.getOrder(); + auto rank = order.size(); + bool transposed = (order[rank - 2] != 1); + if (transposed) { + // tensor is column-wise, so swapping col and row in computations + std::swap(row, col); + } + auto vec = b.i32_val(attr.getVec()); + auto perPhase = b.i32_val(attr.getPerPhase()); + auto maxPhase = b.i32_val(attr.getMaxPhase()); + + // phase = (row // perPhase) % maxPhase + // colOffSwizzled = ((col // vec) ^ phase) * vec + // colOffOrdered = col % vec + // colOff = colOffSwizzled + colOffOrdered + auto phase = b.urem(b.udiv(row, perPhase), maxPhase); + auto colOffSwizzled = b.mul(b.xor_(b.udiv(col, vec), phase), vec); + auto colOffOrdered = b.urem(col, vec); + auto colOff = b.add(colOffSwizzled, colOffOrdered); + + if (transposed) + return {colOff, row}; + else + return {row, colOff}; +} + +Value computeOffset(ConversionPatternRewriter &rewriter, Location loc, + Value row, Value col, SharedMemoryObject smemObj, + ArrayRef smemStrides, + SwizzledSharedEncodingAttr srcLayout) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto [swizzledRow, swizzledCol] = + swizzleIndexes(rewriter, loc, row, col, smemObj, srcLayout); + auto rank = smemStrides.size(); + assert(rank == 2 || rank == 3); + Value rowOffset = b.mul(swizzledRow, smemStrides[rank - 2]); + Value colOffset = b.mul(swizzledCol, smemStrides[rank - 1]); + return b.add(rowOffset, colOffset); +} + +Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc, + const SharedMemoryObject &smemObj, + ArrayRef smemStrides) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value base = smemObj.getBase(); + Type type = base.getType(); + Type elemType = smemObj.getBaseElemType(); + for (int i = 0; i < smemStrides.size(); ++i) { + Value offset = + b.sub(b.i32_val(0), b.mul(smemObj.getOffsets()[i], smemStrides[i])); + base = b.gep(type, elemType, base, offset); + } + return base; +} + +bool isKContig(llvm::ArrayRef order, int opIdx) { + auto rank = order.size(); + int kdim = opIdx == 0 ? rank - 1 : rank - 2; + return order[0] == kdim; +} + +/// Checks that swizzle pattern fits into one warp block +/// and block size is a multiple of swizzle size along non-K dimension +/// +/// \param sharedLayout +/// \param opIdx operand id 0 or 1 +/// \param reps number of repetitions: [non-k, k] or [batch, non-k, k] +/// \param elemsPerInstr one instruction size +/// \param warpsPerBlockNonK number of warps along non-k Dim +/// \returns bool +bool isSwizzlePatternFitsIntoBlock( + const SwizzledSharedEncodingAttr sharedLayout, int opIdx, + const ArrayRef reps, const ArrayRef elemsPerInstr, + unsigned warpsPerBlockNonK) { + assert(elemsPerInstr.size() == 2); + unsigned mfmaInstrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1]; + unsigned mfmaInstrK = elemsPerInstr[opIdx == 0 ? 1 : 0]; + auto order = sharedLayout.getOrder(); + const auto swizzleFastDimSize = + sharedLayout.getMaxPhase() * sharedLayout.getVec(); + const auto swizzleSlowDimSize = + sharedLayout.getMaxPhase() * sharedLayout.getPerPhase(); + const auto swizzlePatternSizeK = + isKContig(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize; + const auto swizzlePatternSizeNonK = + !isKContig(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize; + + const auto blockSizeK = mfmaInstrK * reps[reps.size() - 1]; + const auto blockSizeNonK = mfmaInstrNonK * warpsPerBlockNonK; + return blockSizeK % swizzlePatternSizeK == 0 && + blockSizeNonK % swizzlePatternSizeNonK == 0; +} + +llvm::SmallVector computeOffsetsAType( + ConversionPatternRewriter &rewriter, Location loc, + computeTensorElemMappingInBlockT fn, const ArrayRef &elemsPerInstr, + Value warpId, Value laneId, int warpsPerBlock, int numOfElems, + ArrayRef reps, SharedMemoryObject smemObj, + ArrayRef smemStrides, SwizzledSharedEncodingAttr srcLayout, + unsigned nonKDim, unsigned kDim) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector offsets = smemObj.getOffsets(); + auto order = srcLayout.getOrder(); + auto rank = offsets.size(); + + int vectorSize = 1; + if (order[0] == rank - 1) { + if (isSwizzled(srcLayout)) + vectorSize = std::min(static_cast(srcLayout.getVec()), numOfElems); + else + vectorSize = numOfElems; + } + + auto mapping = fn(rewriter, loc, elemsPerInstr, warpId, laneId, numOfElems, + reps, offsets, vectorSize, nonKDim, kDim); + const auto numBlocks = reps[reps.size() - 2]; + const auto blockSize = mapping.size(); + llvm::SmallVector aOffsets(blockSize * numBlocks); + + if (!isSwizzlePatternFitsIntoBlock(srcLayout, 0, reps, elemsPerInstr, + warpsPerBlock)) { + for (int block = 0; block < numBlocks; ++block) { + int blockNonKOffset = block * nonKDim * warpsPerBlock; + for (int i = 0; i < blockSize; ++i) { + Value row = b.add(mapping[i][0], b.i32_val(blockNonKOffset)); + Value col = mapping[i][1]; + aOffsets[block * blockSize + i] = computeOffset( + rewriter, loc, row, col, smemObj, smemStrides, srcLayout); + } + } + } else { + // compute inblock offsets once and reuse them for all blocks + llvm::SmallVector inblockOffset(mapping.size()); + for (int i = 0; i < mapping.size(); ++i) { + Value row = mapping[i][0]; + Value col = mapping[i][1]; + inblockOffset[i] = computeOffset(rewriter, loc, row, col, smemObj, + smemStrides, srcLayout); + } + for (int block = 0; block < numBlocks; ++block) { + int blockNonKOffset = block * nonKDim * warpsPerBlock; + Value offAdjust = + b.mul(b.i32_val(blockNonKOffset), smemStrides[rank - 2]); + for (int i = 0; i < blockSize; ++i) + aOffsets[block * blockSize + i] = b.add(offAdjust, inblockOffset[i]); + } + } + return aOffsets; +} + +template +static SmallVector +transposeSpatialDims(const Container &vec) { + auto rank = vec.size(); + assert(rank == 2 || rank == 3); + SmallVector res(rank, vec[0]); + res[rank - 2] = vec[rank - 1]; + res[rank - 1] = vec[rank - 2]; + return res; +} + +llvm::SmallVector computeOffsetsBType( + ConversionPatternRewriter &rewriter, Location loc, + computeTensorElemMappingInBlockT fn, const ArrayRef &elemsPerInstr, + Value warpId, Value laneId, int warpsPerBlock, int numOfElems, + ArrayRef reps, SharedMemoryObject smemObj, + ArrayRef smemStrides, SwizzledSharedEncodingAttr srcLayout, + unsigned nonKDim, unsigned kDim) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // transpose reps and offsets, because operand B has layout equal to + // transposed operand A layout + // this unifies axis order, so non-K dim is 0, k dim is 1 + auto rank = smemObj.getOffsets().size(); + auto order = srcLayout.getOrder(); + SmallVector tElemsPerInstr{elemsPerInstr[1], elemsPerInstr[0]}; + SmallVector tReps = transposeSpatialDims(reps); + SmallVector tOffsets = transposeSpatialDims(smemObj.getOffsets()); + SmallVector tStrides = transposeSpatialDims(smemStrides); + + int vectorSize = 1; + if (order[0] == rank - 2) { + if (isSwizzled(srcLayout)) + vectorSize = std::min(static_cast(srcLayout.getVec()), numOfElems); + else + vectorSize = numOfElems; + } + + auto mapping = fn(rewriter, loc, tElemsPerInstr, warpId, laneId, numOfElems, + tReps, tOffsets, vectorSize, nonKDim, kDim); + const auto numBlocks = tReps[tReps.size() - 2]; + const auto blockSize = mapping.size(); + llvm::SmallVector bOffsets(blockSize * numBlocks); + + if (!isSwizzlePatternFitsIntoBlock(srcLayout, 0, reps, elemsPerInstr, + warpsPerBlock)) { + for (int block = 0; block < numBlocks; ++block) { + int blockNonKOffset = block * nonKDim * warpsPerBlock; + for (int i = 0; i < mapping.size(); ++i) { + // swap row and col, because operand B layout is + // a transposed operand A layout + Value row = mapping[i][1]; + Value col = b.add(mapping[i][0], b.i32_val(blockNonKOffset)); + bOffsets[block * blockSize + i] = computeOffset( + rewriter, loc, row, col, smemObj, smemStrides, srcLayout); + } + } + } else { + // compute inblock offsets once and reuse them for all blocks + llvm::SmallVector inblockOffset(mapping.size()); + for (int i = 0; i < mapping.size(); ++i) { + // swap row and col, because operand B layout is a transposed operand A + // layout + Value row = mapping[i][1]; + Value col = mapping[i][0]; + inblockOffset[i] = computeOffset(rewriter, loc, row, col, smemObj, + smemStrides, srcLayout); + } + for (int block = 0; block < numBlocks; ++block) { + int blockNonKOffset = block * nonKDim * warpsPerBlock; + Value offAdjust = b.mul(b.i32_val(blockNonKOffset), tStrides[rank - 2]); + for (int i = 0; i < mapping.size(); ++i) + bOffsets[block * blockSize + i] = b.add(offAdjust, inblockOffset[i]); + } + } + return bOffsets; +} + +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h new file mode 100644 index 000000000..0055bbd77 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h @@ -0,0 +1,64 @@ +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_CONVERTLAYOUTOPTOLLVM_SHAREDTODOTOPERANDHELPER_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_CONVERTLAYOUTOPTOLLVM_SHAREDTODOTOPERANDHELPER_H_ + +#include "Utility.h" + +namespace mlir::triton::AMD { + +// Get warpId inside block of warps. +Value getWarpIdInBlock(ConversionPatternRewriter &rewriter, Location loc, + Value warpId, const ArrayRef &wpt, + int elemPerInstrNonK, int tensorSizeNonK, int nonKIdx, + const ArrayRef &order); + +bool isSwizzled(gpu::SwizzledSharedEncodingAttr layout); + +/// Swizzling tensor element indexes according pattern encoded in +/// SwizzledSharedEncodingAttr +/// +/// \param rewriter +/// \param loc +/// \param row row of target tensor element related to the start of smemObj +/// \param col col of target tensor element related to the start of smemObj +/// \param smemObj shared memory object, contains info about tensor in LDS +/// \param attr layout attribute, contains swizzling info +/// \returns swizzled row, col indexes in tensor notation +std::pair +swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row, + Value col, SharedMemoryObject smemObj, + gpu::SwizzledSharedEncodingAttr attr); + +Value computeOffset(ConversionPatternRewriter &rewriter, Location loc, + Value row, Value col, SharedMemoryObject smemObj, + ArrayRef strides, + gpu::SwizzledSharedEncodingAttr srcLayout); + +Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc, + const SharedMemoryObject &smemObj, + ArrayRef strides); + +bool isKContig(llvm::ArrayRef order, int opIdx); + +using computeTensorElemMappingInBlockT = + std::function>( + ConversionPatternRewriter &, Location, const ArrayRef &, Value, + Value, int, ArrayRef, ArrayRef, int, unsigned, + unsigned)>; + +llvm::SmallVector computeOffsetsAType( + ConversionPatternRewriter &rewriter, Location loc, + computeTensorElemMappingInBlockT fn, const ArrayRef &elemsPerInstr, + Value warpId, Value laneId, int warpsPerBlock, int numOfElems, + ArrayRef reps, SharedMemoryObject smemObj, ArrayRef strides, + gpu::SwizzledSharedEncodingAttr srcLayout, unsigned nonKDim, unsigned kDim); + +llvm::SmallVector computeOffsetsBType( + ConversionPatternRewriter &rewriter, Location loc, + computeTensorElemMappingInBlockT fn, const ArrayRef &elemsPerInstr, + Value warpId, Value laneId, int warpsPerBlock, int numOfElems, + ArrayRef reps, SharedMemoryObject smemObj, ArrayRef strides, + gpu::SwizzledSharedEncodingAttr srcLayout, unsigned nonKDim, unsigned kDim); + +} // namespace mlir::triton::AMD + +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_CONVERTLAYOUTOPTOLLVM_SHAREDTODOTOPERANDHELPER_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp new file mode 100644 index 000000000..0656061bc --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -0,0 +1,388 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" +#include "SharedToDotOperandHelper.h" +#include "Utility.h" + +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::SwizzledSharedEncodingAttr; + +namespace SharedToDotOperandMFMA { + +/// This function maps particular load of mfma dot operand to element +/// indexes(row, col) +/// +/// Whole tensor is broken into "blocks" of warps along "non-K" axis. +/// One block could be processed by multiple warps. +/// One warp works on a piece of tensor size elemsPerInstr[0] x K. +/// Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x +/// elemsPerInstr[1]. +/// +/// Total offset of element is a sum of following values: +/// 1. Offset of warp-block in tensor +/// 2. Offset of warp inside one warp-block +/// 3. Offset of tile in one warp +/// 4. Offset of one lane data in a tile +/// 5. Offset of particular element of tensor processed by one lane +/// +/// This function computes these offsets for axies independently +/// Note that this function returns the offsets of elements in the first +/// warp-block. The offsets of elements in later warp-blocks can be computed +/// by adding a constant stride to the xor-ed offsets of elements in the +/// first warp-block. +/// +/// \param rewriter +/// \param loc +/// \param elemsPerInstr operand tile shape consumed by one MFMA instruction +/// \param warpId id component of 2d warp grid along non-K axis +/// \param laneId lane id in warp [0..63] +/// \param numOfElems number of elements accessed by thread per repetition +/// \param reps number of instructions repetition to fully cover dot operand +/// \param smemStrides strides in LDS tensor +/// \param loadVecSize number of elements loaded by one operation +/// \param iNonKDim non-K dimension size of one MFMA instruction +/// \param iKDim K dimension size of one MFMA instruction +/// \returns vector (i-th element corresponds to i-th load instruction) of +/// 2-element vectors(tensor row and col). +llvm::SmallVector> computeTensorElemMappingInBlock( + ConversionPatternRewriter &rewriter, Location loc, + const ArrayRef &elemsPerInstr, Value warpId, Value laneId, + int numOfElems, ArrayRef reps, ArrayRef smemOffsets, + int loadVecSize, unsigned iNonKDim, unsigned iKDim) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto numM = reps[1]; + auto numK = reps[2]; + const int loadsPerThread = numOfElems / loadVecSize; + llvm::SmallVector> mapping(numK * loadsPerThread); + + Value _0 = b.i32_val(0); + Value _32 = b.i32_val(32); + Value nonKDim = b.i32_val(iNonKDim); + Value warpVOffset = b.mul(warpId, b.i32_val(elemsPerInstr[0])); + + auto rank = smemOffsets.size(); + + for (int tile = 0; tile < numK; ++tile) { + Value tileVOffset = _0; + Value tileHOffset = b.i32_val(tile * elemsPerInstr[1]); + + Value laneVOffset = b.urem(laneId, nonKDim); + Value laneHOffset; + if (iNonKDim == 32) { + laneHOffset = + b.select(b.icmp_uge(laneId, _32), b.i32_val(numOfElems), _0); + } else { + // In this configuration warp contains 16 copies of same data + if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) { + laneHOffset = b.i32_val(0); + } else { + assert(iKDim * iNonKDim / numOfElems == 64 && + "seems no all threads in warp contain unique elements"); + laneHOffset = b.mul(b.udiv(laneId, nonKDim), b.i32_val(numOfElems)); + } + } + + for (int loadId = 0; loadId < loadsPerThread; ++loadId) { + Value elemVOffset = _0; + Value elemHOffset = b.i32_val(loadId * loadVecSize); + + Value sliceVOffset = b.add( + b.add(b.add(tileVOffset, laneVOffset), elemVOffset), warpVOffset); + Value sliceHOffset = b.add(b.add(tileHOffset, laneHOffset), elemHOffset); + + Value row = b.add(sliceVOffset, smemOffsets[rank - 2]); + Value col = b.add(sliceHOffset, smemOffsets[rank - 1]); + + mapping[loadsPerThread * tile + loadId] = {row, col}; + } + } + return mapping; +} + +bool hasSwizzleEnabled(const SwizzledSharedEncodingAttr &srcEncoding) { + return srcEncoding.getMaxPhase() > 1; +} + +/// Computes offsets for operand B or transposed operand A +/// +/// \param rewriter +/// \param loc +/// \param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA +/// instruction +/// \param warpId warp id for the "non K" axis +/// \param laneId lane id in warp [0..63] +/// \param warpsPerBlock number of warps per horizontal axis +/// \param numOfElems number of elements accessed by threads per repetition +/// \param reps number of instructions repretition to fully cover dot operand +/// \param cSwizzleOffset +llvm::SmallVector +fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, + const ArrayRef &elemsPerInstr, Value warpId, + Value laneId, int warpsPerBlock, int numOfElems, + ArrayRef reps, Value cSwizzleOffset) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto numK = reps[1]; + auto numN = reps[2]; + SmallVector offsets(numK * numN * numOfElems); + + auto iKDim = elemsPerInstr[0]; + auto iNonKDim = elemsPerInstr[1]; + int lineSize = warpsPerBlock * iNonKDim * numN; + Value _nonKDim = b.i32_val(iNonKDim); + Value warpOffset = b.mul(warpId, b.i32_val(iNonKDim)); + Value colOffset = b.urem(laneId, _nonKDim); + + for (int block = 0; block < numN; ++block) { + Value blockOffset = b.i32_val(block * iNonKDim * warpsPerBlock); + for (int tile = 0; tile < numK; ++tile) { + Value tileOffset = b.i32_val(tile * iKDim * lineSize); + for (int elem = 0; elem < numOfElems; ++elem) { + // halfOffset is an offset related to wrapping of warp in the tile. + // for example, mfma 32 case (mapping of tensor elements to lane ids in + // warp): + // + // 0 1 2 3 ... 31 + // 0 1 2 3 ... 31 + // 0 1 2 3 ... 31 + // 0 1 2 3 ... 31 + // 32 33 34 35 ... 63 <- at this point warp is wrapping + // 32 33 34 35 ... 63 + // 32 33 34 35 ... 63 + // 32 33 34 35 ... 63 + Value halfOffset; + if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) + halfOffset = b.i32_val(0); + else + halfOffset = + b.mul(b.udiv(laneId, _nonKDim), b.i32_val(numOfElems * lineSize)); + Value rowOffset = b.add(b.i32_val(elem * lineSize), halfOffset); + Value elemOffset = b.add(rowOffset, colOffset); + Value offset = b.add(b.add(b.add(warpOffset, blockOffset), tileOffset), + elemOffset); + offsets[numK * numOfElems * block + numOfElems * tile + elem] = offset; + } + } + } + return offsets; +} + +bool isColMajor(::llvm::ArrayRef order) { + auto rank = order.size(); + return order[0] == (rank - 2); +} + +Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, + Location loc, Value tensor, DotOperandEncodingAttr encoding, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, Value thread) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + assert((opIdx == 0 || opIdx == 1) && "unexpected operand idx"); + auto aTensorTy = cast(tensor.getType()); + ArrayRef shape = aTensorTy.getShape(); + auto rank = shape.size(); + int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2; + int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1; + + auto mfmaLayout = cast(encoding.getParent()); + auto mDim = mfmaLayout.getMDim(); + auto nDim = mfmaLayout.getNDim(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + + auto sharedLayout = cast(aTensorTy.getEncoding()); + auto order = sharedLayout.getOrder(); + assert((rank == 2 || order[2] == 0) && + "expect batch to be the slowest dimension"); + + auto elemTy = aTensorTy.getElementType(); + auto kWidth = encoding.getKWidth(); + auto elemsPerInstr = mfmaLayout.getInstrShapeForOperand(kWidth, opIdx); + + int64_t mfmaInstrNonK; + int64_t mfmaInstrK; + // TODO(Lixun): make it simpler + // getInstrShapeForOperand always returns a 2D vector + if (rank == 3) { + mfmaInstrNonK = elemsPerInstr[nonKDimIdx - 1]; + mfmaInstrK = elemsPerInstr[kDimIdx - 1]; + } else { + mfmaInstrNonK = elemsPerInstr[nonKDimIdx]; + mfmaInstrK = elemsPerInstr[kDimIdx]; + } + + if (mfmaInstrNonK > shape[nonKDimIdx] || mfmaInstrK > shape[kDimIdx]) { + // This pattern does not support cases tensor shape is smaller than + // one instruction size, it will be processed by LinearLayout converter + return Value(); + } + + auto numReps = mfmaLayout.getRepForOperand(shape, kWidth, opIdx); + auto numRepNonK = numReps[nonKDimIdx]; + auto numRepK = numReps[kDimIdx]; + auto repB = numReps[0]; + // TODO(Lixun): make it simpler + // getRepForOperand always returns a 3D vector + if (rank == 2) { + numRepNonK = numReps[nonKDimIdx + 1]; + numRepK = numReps[kDimIdx + 1]; + } + + unsigned iWarpSize = triton::gpu::getWarpSize(mfmaLayout); + assert(iWarpSize == 64); + Value warpSize = tb.i32_val(iWarpSize); + Value linearWarpId = tb.udiv(thread, warpSize); + Value lane = tb.urem(thread, warpSize); + + Value spatialWarpId = AMD::getWarpIdInBlock( + rewriter, loc, linearWarpId, warpsPerCTA, mfmaInstrNonK, + shape[nonKDimIdx], nonKDimIdx, mfmaLayout.getDefaultOrder()); + + // number of duplicates of elements in warp + // In case of 64x4 x 4x4 multiplication, 4x4 B operand is duplicated 16 times + int numSubBlocks = 1; + if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4) + numSubBlocks = 16; + // numOfElemsPerThreadPerMfmaInstr + int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWarpSize; + assert(numOfElems >= 1); + + unsigned int maxNumWarps = shape[nonKDimIdx] / mfmaInstrNonK; + int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps); + int warpsPerBatch = + rank == 3 ? std::min(shape[0], warpsPerCTA[0]) : 1; + Value warpIdInBatch = tb.urem(linearWarpId, tb.i32_val(warpsPerBatch)); + elemTy = typeConverter->convertType(elemTy); + + SmallVector loadedValues; + SmallVector offsets; + Value smemBase; + auto smemStrides = smemObj.getStrides(aTensorTy, loc, rewriter); + bool isFastPath = + !AMD::isKContig(order, opIdx) && !hasSwizzleEnabled(sharedLayout); + if (isFastPath) { + // fast path handles tensors that are not k-major and have swizzling + // disabled, in which case offsets computation can be simplified + // TODO (zhanglx): later when we enable vector access to LDS for non k-major + // tensors, we'll refactor the scope of fast and normal path + Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); + if (opIdx == 0) { + if (isColMajor(order)) { + SmallVector elemsPerInstr{mfmaInstrK, mfmaInstrNonK}; + SmallVector reps{numReps[0], numReps[2], numReps[1]}; + offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr, + spatialWarpId, lane, warpsPerBlockNonK, + numOfElems, reps, cSwizzleOffset); + } else { + llvm_unreachable( + "row major operand A should be handled in the normal path"); + } + } else { + if (isColMajor(order)) { + llvm_unreachable( + "col major operand B should be handled in the normal path"); + } else { + offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr, + spatialWarpId, lane, warpsPerBlockNonK, + numOfElems, numReps, cSwizzleOffset); + } + } + smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter); + } else { // normal path + // Normal path handles tensors that fall into either of the following three + // cases: + // 1. k-major + swizzling is enabled <-- this should be the most + // performant case + // 2. k-major + swizzling is disabled <-- for testing purpose only + // 3. non k-major + swizzling is enabled <-- for testing purpose only + if (opIdx == 0) { + offsets = AMD::computeOffsetsAType( + rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, + spatialWarpId, lane, warpsPerBlockNonK, numOfElems, numReps, smemObj, + smemStrides, sharedLayout, mDim, mfmaInstrK); + } else { + assert(opIdx == 1); + offsets = AMD::computeOffsetsBType( + rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, + spatialWarpId, lane, warpsPerBlockNonK, numOfElems, numReps, smemObj, + smemStrides, sharedLayout, nDim, mfmaInstrK); + } + smemBase = AMD::computeBasePtr(rewriter, loc, smemObj, smemStrides); + } + + Type resElemTy = typeConverter->convertType(elemTy); + Type smemPtrTy = ptr_ty(rewriter.getContext(), 3); + + int loadsPerThread = offsets.size() / numRepK / numRepNonK; + int elemsPerLoad = numOfElems / loadsPerThread; + assert(numOfElems % loadsPerThread == 0); + + VectorType loadVecTy = vec_ty(elemTy, elemsPerLoad); + for (int b = 0; b < repB; ++b) { + int operandSize = shape[rank - 1] * shape[rank - 2]; + Value batchOffset = + tb.mul(tb.i32_val(operandSize), + tb.add(warpIdInBatch, tb.i32_val(b * warpsPerBatch))); + for (int nonK = 0; nonK < numRepNonK; ++nonK) { + int blockNonKOffset = nonK * mfmaInstrNonK * warpsPerBlockNonK; + Value warpBlockOffAdjust = tb.i32_val(blockNonKOffset * shape[order[0]]); + for (int k = 0; k < numRepK; ++k) { + auto vecTy = vec_ty(resElemTy, numOfElems); + for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { + Value loadOffset; + loadOffset = offsets[nonK * loadsPerThread * numRepK + + k * loadsPerThread + loadId]; + loadOffset = tb.add(loadOffset, batchOffset); + Value loadAddress = tb.gep(smemPtrTy, elemTy, smemBase, loadOffset); + Value loadedValue = tb.load(loadVecTy, loadAddress); + for (int elemId = 0; elemId < elemsPerLoad; ++elemId) { + Value elemVal = + tb.extract_element(elemTy, loadedValue, tb.i32_val(elemId)); + loadedValues.push_back(elemVal); + } + } + } + } + } + + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + + MLIRContext *ctx = mfmaLayout.getContext(); + Type structTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); + auto result = + packLLElements(loc, typeConverter, loadedValues, rewriter, structTy); + return result; +} + +} // namespace SharedToDotOperandMFMA diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp new file mode 100644 index 000000000..9e1e043f1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" +#include "SharedToDotOperandHelper.h" +#include "Utility.h" + +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::SwizzledSharedEncodingAttr; + +namespace SharedToDotOperandWMMA { + +/// Following functions maps particular load of wmma dot operand to +/// element indexes(row, col). For each WMMA generation separate function is +/// used. +/// +/// Whole tensor is broken into "blocks" of warps along "non-K" axis. +/// One block could be processed by multiple warps. +/// One warp works on a piece of tensor size elemsPerInstr[0] x K. +/// Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x +/// elemsPerInstr[1]. +/// +/// Total offset of element is a sum of following values: +/// 1. Offset of warp block in tensor +/// 2. Offset of warp inside one warp block +/// 3. Offset of tile in one warp +/// 4. Offset of one lane data in a tile +/// 5. Offset of particular element of tensor processed by one lane +/// +/// This function computes these offsets for axes independently +/// +/// \param rewriter +/// \param loc +/// \param elemsPerInstr operand tile shape consumed by one WMMA instruction +/// \param warpId id component of 2d warp grid along non-K axis +/// \param laneId lane id in warp [0..63] +/// \param numOfElems number of elements accessed by thread per repetition +/// \param reps number of instructions repetition to fully cover dot operand +/// \param smemStrides strides in LDS tensor +/// \param loadVecSize number of elements loaded by one operation +/// \param iNonKDim non-K dimension of dot operand +/// \returns vector (i-th element corresponds to i-th load instruction) of +/// 2-element vectors(tensor row and col). +llvm::SmallVector> +computeTensorElemMappingInBlockWmma1( + ConversionPatternRewriter &rewriter, Location loc, + const ArrayRef &elemsPerInstr, Value warpId, Value laneId, + int numOfElems, ArrayRef reps, ArrayRef smemOffsets, + int loadVecSize, unsigned iNonKDim, [[maybe_unused]] unsigned iKDim) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(reps.size() == 3); + assert(elemsPerInstr.size() == 2); + auto numK = reps[2]; + const int loadsPerThread = numOfElems / loadVecSize; + llvm::SmallVector> mapping(numK * loadsPerThread); + + Value elemsPerInstrV = b.i32_val(elemsPerInstr[0]); + Value warpVOffset = b.mul(warpId, elemsPerInstrV); + Value sliceVOffset = b.add(b.urem(laneId, elemsPerInstrV), warpVOffset); + auto rank = smemOffsets.size(); + Value row = b.add(sliceVOffset, smemOffsets[rank - 2]); + + for (int tile = 0; tile < numK; ++tile) { + Value tileHOffset = b.i32_val(tile * elemsPerInstr[1]); + + for (int loadId = 0; loadId < loadsPerThread; ++loadId) { + Value elemHOffset = b.i32_val(loadId * loadVecSize); + Value sliceHOffset = b.add(tileHOffset, elemHOffset); + + Value col = b.add(sliceHOffset, smemOffsets[rank - 1]); + mapping[loadsPerThread * tile + loadId] = {row, col}; + } + } + + return mapping; +} + +llvm::SmallVector> +computeTensorElemMappingInBlockWmma2( + ConversionPatternRewriter &rewriter, Location loc, + const ArrayRef &elemsPerInstr, Value warpId, Value laneId, + int numOfElems, ArrayRef reps, ArrayRef smemOffsets, + int loadVecSize, unsigned iNonKDim, [[maybe_unused]] unsigned iKDim) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(reps.size() == 3); + assert(elemsPerInstr.size() == 2); + auto numK = reps[2]; + const int loadsPerThread = numOfElems / loadVecSize; + llvm::SmallVector> mapping(numK * loadsPerThread); + + Value rowsPerInstr = b.i32_val(elemsPerInstr[0]); + Value colsPerInstr = b.i32_val(elemsPerInstr[1]); + Value elemsPerThread = b.i32_val(elemsPerInstr[1] / 2); + Value warpVOffset = b.mul(warpId, rowsPerInstr); + Value sliceVOffset = b.add(b.urem(laneId, rowsPerInstr), warpVOffset); + + auto rank = smemOffsets.size(); + Value row = b.add(sliceVOffset, smemOffsets[rank - 2]); + Value laneHOffset = b.mul(b.udiv(laneId, colsPerInstr), elemsPerThread); + + for (int tile = 0; tile < numK; ++tile) { + Value tileHOffset = b.add(laneHOffset, b.i32_val(tile * elemsPerInstr[1])); + for (int loadId = 0; loadId < loadsPerThread; ++loadId) { + Value elemHOffset = b.i32_val(loadId * loadVecSize); + Value sliceHOffset = b.add(tileHOffset, elemHOffset); + + Value col = b.add(sliceHOffset, smemOffsets[rank - 1]); + + mapping[loadsPerThread * tile + loadId] = {row, col}; + } + } + + return mapping; +} + +Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, + Location loc, Value tensor, DotOperandEncodingAttr encoding, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, Value thread) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + assert((opIdx == 0 || opIdx == 1) && "unexpected operand idx"); + auto rank = smemObj.getOffsets().size(); + int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2; + int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1; + + auto wmmaLayout = cast(encoding.getParent()); + auto computeTensorElemMappingInBlock = + wmmaLayout.getVersion() == 1 ? computeTensorElemMappingInBlockWmma1 + : computeTensorElemMappingInBlockWmma2; + assert(wmmaLayout.getMNKDimPerInstr()[nonKDimIdx] == 16); + auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + + auto aTensorTy = cast(tensor.getType()); + ArrayRef shape = aTensorTy.getShape(); + auto sharedLayout = cast(aTensorTy.getEncoding()); + auto order = sharedLayout.getOrder(); + assert((rank == 2 || order[2] == 0) && + "expect batch to be the slowest dimension"); + + auto elemTy = aTensorTy.getElementType(); + int kWidth = encoding.getKWidth(); + auto elemsPerInstr = wmmaLayout.getElemsPerInstrForOperands(); + auto wmmaInstrK = elemsPerInstr[opIdx == 0 ? 1 : 0]; + auto wmmaInstrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1]; + assert(wmmaInstrNonK == 16); + + auto numReps = wmmaLayout.getRepForOperand(shape, elemTy, kWidth, opIdx); + auto numRepNonK = numReps[opIdx == 0 ? 1 : 2]; + auto numRepK = numReps[opIdx == 0 ? 2 : 1]; + auto repB = numReps[0]; + + unsigned iWaveSize = triton::gpu::getWarpSize(wmmaLayout); + assert(iWaveSize == 32); + Value waveSize = tb.i32_val(iWaveSize); + Value linearWaveId = tb.udiv(thread, waveSize); + + unsigned numElemsPerThreadPerRep = wmmaLayout.getKWidthForOperands(); + + Value lane = tb.urem(thread, waveSize); + unsigned int maxNumWarps = shape[nonKDimIdx] / wmmaInstrNonK; + int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps); + int warpsPerBatch = + rank == 3 ? std::min(shape[0], warpsPerCTA[0]) : 1; + Value waveIdInBatch = tb.urem(linearWaveId, tb.i32_val(warpsPerBatch)); + elemTy = typeConverter->convertType(elemTy); + + SmallVector loadedValues; + SmallVector offsets; + Value smemBase; + auto smemStrides = smemObj.getStrides(aTensorTy, loc, rewriter); + Value spatialWarpId = AMD::getWarpIdInBlock( + rewriter, loc, linearWaveId, warpsPerCTA, elemsPerInstr[0], + shape[nonKDimIdx], nonKDimIdx, wmmaLayout.getDefaultOrder()); + if (opIdx == 0) { + offsets = AMD::computeOffsetsAType( + rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, + spatialWarpId, lane, warpsPerBlockNonK, numElemsPerThreadPerRep, + numReps, smemObj, smemStrides, sharedLayout, wmmaInstrNonK, wmmaInstrK); + } else { + assert(opIdx == 1); + offsets = AMD::computeOffsetsBType( + rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, + spatialWarpId, lane, warpsPerBlockNonK, numElemsPerThreadPerRep, + numReps, smemObj, smemStrides, sharedLayout, wmmaInstrNonK, wmmaInstrK); + } + smemBase = AMD::computeBasePtr(rewriter, loc, smemObj, smemStrides); + + Type resElemTy = typeConverter->convertType(elemTy); + Type smemPtrTy = ptr_ty(rewriter.getContext(), 3); + + int loadsPerThread = offsets.size() / (numRepNonK * numRepK); + int elemsPerLoad = numElemsPerThreadPerRep / loadsPerThread; + assert(numElemsPerThreadPerRep % loadsPerThread == 0); + auto loadVecTy = vec_ty(elemTy, elemsPerLoad); + for (int b = 0; b < repB; ++b) { + int operandSize = shape[rank - 1] * shape[rank - 2]; + Value batchOffset = + tb.mul(tb.i32_val(operandSize), + tb.add(waveIdInBatch, tb.i32_val(b * warpsPerBatch))); + for (int nonK = 0; nonK < numRepNonK; ++nonK) { + for (int k = 0; k < numRepK; ++k) { + auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep); + Value valVec = tb.undef(vecTy); + for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { + Value loadOffset = offsets[nonK * loadsPerThread * numRepK + + k * loadsPerThread + loadId]; + loadOffset = tb.add(loadOffset, batchOffset); + Value loadAddress = tb.gep(smemPtrTy, elemTy, smemBase, loadOffset); + Value loadedValue = tb.load(loadVecTy, loadAddress); + for (int elemId = 0; elemId < elemsPerLoad; ++elemId) { + Value elemVal = + tb.extract_element(elemTy, loadedValue, tb.i32_val(elemId)); + loadedValues.push_back(elemVal); + } + } + } + } + } + + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + + MLIRContext *ctx = wmmaLayout.getContext(); + Type structTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); + auto result = + packLLElements(loc, typeConverter, loadedValues, rewriter, structTy); + return result; +} + +} // namespace SharedToDotOperandWMMA diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp new file mode 100644 index 000000000..78cbdb781 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -0,0 +1,138 @@ +#include "OptimizeLDSUtility.h" +#include "TargetInfo.h" +#include "TritonAMDGPUToLLVM/Passes.h" +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Patterns.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include + +using namespace mlir; +namespace mlir::triton { +#define GEN_PASS_DEF_DECOMPOSEUNSUPPORTEDAMDCONVERSIONS +#include "TritonAMDGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton + +namespace { + +struct DecomposeUnsupportedAMDConversions + : public mlir::triton::impl::DecomposeUnsupportedAMDConversionsBase< + DecomposeUnsupportedAMDConversions> { + explicit DecomposeUnsupportedAMDConversions(StringRef targetArch) { + this->arch = targetArch.str(); + } + + void runOnOperation() override { + triton::AMD::TargetInfo targetInfo(this->arch.getValue()); + int sharedMemoryLimit = targetInfo.getSharedMemorySize(); + + ModuleOp mod = getOperation(); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + + auto shortcutFn = [](RankedTensorType srcTy, RankedTensorType dstTy) { + auto srcWmma = + dyn_cast(srcTy.getEncoding()); + auto dstDotOp = + dyn_cast(dstTy.getEncoding()); + return !cvtNeedsSharedMemory(srcTy, dstTy) && !(srcWmma && dstDotOp); + }; + + triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, shortcutFn); + + // Try to reduce LDS usage of cvt(mfma->blocked) op by changing the shape of + // WarpsPerCta attribute in mfma layout. The implicit LDS usage of + // cvt(mfma->blocked) op depends on the number of warps per CTA that mfma + // layout uses along x dimension and block layout uses across y dimension. + // + // clang-format off + // + // LDS usage of this op is roughly calculated as: + // LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layout)[1] * sizeof(data_type) + // LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C, + // where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type) + // + // clang-format on + // + // When LDS_USAGE exceeds the size of LDS, try to lower LDS usage by + // decomposing cvt(mfma->blocked) op into 2 conversions: cvt(mfma->mfma_tmp) + // and cvt(mfma_tmp->blocked), where mfma_tmp has WarpsPerCta attribute that + // minimizes uses of LDS for these conversions. + mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + + auto srcType = cvtOp.getSrc().getType(); + auto dstType = cvtOp.getType(); + + auto srcEnc = + cast(srcType.getEncoding()); + auto dstBlocked = + dyn_cast(dstType.getEncoding()); + + // TODO: Reduce LDS usage for WMMA dots + if (!isa(srcEnc) || !dstBlocked) { + return; + } + + auto currLDSUsage = triton::AMD::getCvtOpLDSUsage(cvtOp); + if (currLDSUsage <= sharedMemoryLimit) { + return; + } + + unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc); + + // Find all possible shapes of WarpsPerCTA by finding all possible + // factorizations of numWarps. Pick shape for which both conversions in + // decomposition use LDS less than sharedMemoryLimit and for which sum of + // LDS usage is minimal. If no such shape exists, do not decompose. + unsigned minLDSUsage = 2 * sharedMemoryLimit; + int minIdx = -1; + int rank = dstBlocked.getWarpsPerCTA().size(); + auto factorizedNumWarps = + mlir::triton::AMD::factorizePowerOf2(numWarps, rank); + + SmallVector tmpLayouts; + for (int i = 0; i < factorizedNumWarps.size(); i++) { + auto warpsPerCTA = factorizedNumWarps[i]; + tmpLayouts.push_back( + mlir::triton::AMD::createTmpLayout(srcEnc, warpsPerCTA)); + } + + for (int i = 0; i < tmpLayouts.size(); i++) { + auto resources = mlir::triton::AMD::estimateResourcesForReplacement( + builder, cvtOp, tmpLayouts[i]); + if (resources.LDS <= sharedMemoryLimit && resources.LDS < minLDSUsage) { + minLDSUsage = resources.LDS; + minIdx = i; + } + } + + if (minIdx == -1 || minLDSUsage > sharedMemoryLimit) { + return; + } + + assert(minIdx >= 0 && minIdx < tmpLayouts.size()); + auto replacementCvts = mlir::triton::AMD::createNewConvertOps( + builder, cvtOp, tmpLayouts[minIdx]); + + cvtOp.replaceAllUsesWith(replacementCvts.second.getResult()); + cvtOp.erase(); + }); + + triton::gpu::decomposeBlockedToDotLayoutConversion(mod); + } +}; + +} // namespace + +namespace mlir::triton::AMD { + +std::unique_ptr> +createDecomposeUnsupportedConversionsPass(StringRef targetArch) { + return std::make_unique(targetArch); +} + +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp new file mode 100644 index 000000000..641a71b36 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp @@ -0,0 +1,84 @@ +#include "Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +using namespace mlir; + +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; +using ::mlir::triton::gpu::getShapePerCTA; + +namespace mlir::triton::AMD { +LogicalResult convertAMDFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); + +LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); + +LogicalResult convertScaledMFMA(triton::DotScaledOp op, + triton::DotScaledOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); + +LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); +} // namespace mlir::triton::AMD + +namespace { +struct DotOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + // D = A * B + C + Value D = op.getResult(); + + auto dEncoding = cast(D.getType()).getEncoding(); + if (isa(dEncoding)) { + return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter); + } + if (isa(dEncoding)) { + return AMD::convertWMMA(op, adaptor, getTypeConverter(), rewriter); + } + + if (isa( + cast(D.getType()).getEncoding())) + return AMD::convertAMDFMADot(op, adaptor, getTypeConverter(), rewriter); + + llvm::report_fatal_error( + "Unsupported DotOp found when converting TritonGPU to LLVM."); + } +}; + +struct ScaledDotOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + int mfmaVersion; + int nonKDim; + int kPack; + + ScaledDotOpConversion(LLVMTypeConverter &typeConverter, int mfmaVersion, + int nonKDim, int kPack, PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + mfmaVersion(mfmaVersion), nonKDim(nonKDim), kPack(kPack) {} + + LogicalResult + matchAndRewrite(triton::DotScaledOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return AMD::convertScaledMFMA(op, adaptor, getTypeConverter(), rewriter); + } +}; +} // namespace + +namespace mlir::triton::AMD { +void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/FMA.cpp new file mode 100644 index 000000000..95d06d65f --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -0,0 +1,131 @@ +#include "TritonAMDGPUToLLVM/TargetUtils.h" +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace ::mlir::triton::gpu; + +namespace { + +struct DotIntrinsic { + int vectorSize; + Type outElemTy; + StringRef intrinsicName; + SmallVector additionalArgs; +}; + +class AMDFMAVectorMultiplier : public FMAVectorMultiplier { + ConversionPatternRewriter &rewriter; + Location loc; + DotIntrinsic intrinsic; + + DotIntrinsic chooseIntrinsic(DotOp op) { + auto aOpTy = cast(op.getA().getType()); + auto aElemTy = aOpTy.getElementType(); + auto bOpTy = cast(op.getA().getType()); + auto bElemTy = aOpTy.getElementType(); + assert(aElemTy == bElemTy); + auto dOpTy = cast(op.getD().getType()); + auto dElemTy = dOpTy.getElementType(); + auto mod = op->getParentOfType(); + auto arch = getAMDArch(mod); + DotIntrinsic chosenOp; + + bool dotAvailable = AMD::supportsVDot(arch); + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (dotAvailable) { + if (aElemTy.isF16() && dElemTy.isF32()) { + chosenOp.vectorSize = 2; + chosenOp.outElemTy = f32_ty; + chosenOp.intrinsicName = "llvm.amdgcn.fdot2"; + chosenOp.additionalArgs = {b.false_val()}; + return chosenOp; + } + if (aElemTy.isInteger(8) && dElemTy.isInteger(32)) { + chosenOp.vectorSize = 4; + chosenOp.outElemTy = i32_ty; + chosenOp.intrinsicName = "llvm.amdgcn.sdot4"; + chosenOp.additionalArgs = {b.false_val()}; + return chosenOp; + } + } + // choose one of FMA intrinsics + assert(aElemTy.isIntOrFloat() && !aElemTy.isIntOrIndex()); + assert(aElemTy == dElemTy); + assert(cast(op.getA().getType()).getElementType() == + dElemTy); + chosenOp.vectorSize = 1; + chosenOp.outElemTy = aElemTy; + if (aElemTy.isF32()) + chosenOp.intrinsicName = "llvm.fmuladd.f32"; + if (aElemTy.isF16()) + chosenOp.intrinsicName = "llvm.fmuladd.f16"; + chosenOp.additionalArgs = {}; + return chosenOp; + } + + Value packOperand(ArrayRef scalarValues, int firstElemPos, + unsigned vectorSize) { + if (vectorSize == 1) + return scalarValues[firstElemPos]; + auto elemTy = scalarValues[firstElemPos].getType(); + auto vecTy = vec_ty(elemTy, vectorSize); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value vec = b.undef(vecTy); + for (int elem = 0; elem < vectorSize; ++elem) { + int elemPos = firstElemPos + elem; + vec = + b.insert_element(vecTy, vec, scalarValues[elemPos], b.i32_val(elem)); + } + if (elemTy.isInteger(8)) { + assert(vectorSize == 4); + vec = b.bitcast(vec, i32_ty); + } + return vec; + } + + Value generateDotInstr(Value a, Value b, Value c) { + SmallVector args{a, b, c}; + args.append(intrinsic.additionalArgs.begin(), + intrinsic.additionalArgs.end()); + SmallVector argTypes; + for (auto arg : args) + argTypes.push_back(arg.getType()); + auto funcType = LLVM::LLVMFunctionType::get(intrinsic.outElemTy, argTypes); + auto d = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, intrinsic.intrinsicName, intrinsic.outElemTy, args); + return d.getResult(0); + } + +public: + AMDFMAVectorMultiplier(ConversionPatternRewriter &rewriter, DotOp op) + : rewriter(rewriter), loc(op.getLoc()), intrinsic(chooseIntrinsic(op)) {} + + Value multiplyVectors(ArrayRef a, ArrayRef b, + Value c) override { + auto kSize = a.size(); + assert(b.size() == kSize); + Value accum = c; + for (int k = 0; k < kSize; k += intrinsic.vectorSize) { + auto aOp = packOperand(a, k, intrinsic.vectorSize); + auto bOp = packOperand(b, k, intrinsic.vectorSize); + accum = generateDotInstr(aOp, bOp, accum); + } + return accum; + } +}; + +} // namespace + +namespace mlir::triton::AMD { + +LogicalResult convertAMDFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + AMDFMAVectorMultiplier multiplier(rewriter, op); + return parametricConvertFMADot(op, adaptor, typeConverter, rewriter, + multiplier); +} +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp new file mode 100644 index 000000000..2917d8463 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -0,0 +1,816 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" +#include "TritonAMDGPUTransforms/MfmaGroup.h" +#include "Utility.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +using ::mlir::LLVM::AMD::scaleDotElemTypeToMLIRType; +using ::mlir::LLVM::AMD::shuffleXor; +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::LinearEncodingAttr; + +using ValueTable = std::map, Value>; + +/// Get matrix format flag passed through BLGP/CBSZ args in V_MFMA_*_F8F6F4 +/// instructions. +/// +/// Values: +/// - 0: E4M3(FP8) +/// - 1: E5M2(BF8) +/// - 2: E2M3(FP6) +/// - 3: E3M2(BF6) +/// - 4: E2M1(FP4) +static inline int32_t getMfmaF8F6F4MatrixFormat(Type t) { + return llvm::TypeSwitch(t) + .Case([](Type) { return 0; }) + .Case([](Type) { return 1; }) + .Case([](Type) { return 2; }) + .Case([](Type) { return 3; }) + .Case([](Type) { return 4; }) + .Default([](Type) { return -1; }); +} + +struct DotOpMFMAConversionHelper { + AMDMfmaEncodingAttr mfmaLayout; + + ConversionPatternRewriter &rewriter; + const LLVMTypeConverter *typeConverter; + Location loc; + MLIRContext *ctx{}; + + explicit DotOpMFMAConversionHelper(AMDMfmaEncodingAttr mfmaLayout, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter *typeConverter, + Location loc) + : mfmaLayout(mfmaLayout), rewriter(rewriter), + typeConverter(typeConverter), loc(loc), ctx(mfmaLayout.getContext()) {} + + Value generateMFMAOp(StringRef intrinsicName, Value valA, Value valB, + Value valC) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto resType = valC.getType(); + Value zeroFlag = b.i32_val(0); + OperationState loweredOp(loc, intrinsicName); + loweredOp.addTypes(resType); + loweredOp.addOperands({valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + return rewriter.create(loweredOp)->getResult(0); + } + + int getNumSubmatrices(Type elementType, int mDim, int nDim) const { + if ((mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)) + return 1; + assert(mDim == nDim); + switch (mDim) { + case 32: + case 16: + return 1; + break; + case 4: + assert(elementType.getIntOrFloatBitWidth() <= 32 && + "fp64 is not supported yet"); + assert(elementType.getIntOrFloatBitWidth() != 8 || + elementType.isInteger(8) && "fp8 is not supported yet"); + return 16; + break; + default: + llvm::report_fatal_error("unsupported nonKDim in MFMA dot"); + } + return -1; + } + + Value processSubBlocks(int numSubBlocks, Value acc, bool reduceSubBlocks, + bool zeroSubBlocks) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert((numSubBlocks & (numSubBlocks - 1)) == 0 && + "numSubBlocks in not pow 2!"); + if (numSubBlocks == 1) + return acc; + constexpr int warpSize = 64; + int subBlockSize = warpSize / numSubBlocks; + Value laneId = getLaneId(rewriter, loc); + auto vecTy = dyn_cast(acc.getType()); + auto elemType = vecTy.getElementType(); + assert(elemType.getIntOrFloatBitWidth() == 32); + int numScalars = vecTy.getNumElements(); + std::vector accScalar(numScalars); + for (int i = 0; i < numScalars; ++i) + accScalar[i] = b.extract_element(elemType, acc, b.i32_val(i)); + + if (reduceSubBlocks) { + while (subBlockSize < warpSize) { + for (int i = 0; i < numScalars; ++i) { + Value other_acc = + shuffleXor(loc, rewriter, accScalar[i], subBlockSize); + if (elemType.isInteger(32)) + accScalar[i] = b.add(accScalar[i], other_acc); + else + accScalar[i] = b.fadd(accScalar[i], other_acc); + } + subBlockSize *= 2; + } + } + if (zeroSubBlocks) { + Value zero; + if (elemType.isInteger(32)) + zero = b.i32_val(0); + else + zero = b.f32_val(0.0); + auto cond = b.icmp_ult(laneId, b.i32_val(subBlockSize)); + for (int i = 0; i < numScalars; ++i) + accScalar[i] = b.select(cond, accScalar[i], zero); + } + + Value reducedAcc = b.undef(vecTy); + for (int i = 0; i < numScalars; ++i) + reducedAcc = + b.insert_element(vecTy, reducedAcc, accScalar[i], b.i32_val(i)); + return reducedAcc; + } + + /// @brief MFMA 4x4 is computes 16 matrix multiplications, this functions adds + /// these 16 matrices to get final 4x4 matrix + /// @param numSubBlocks + /// @param acc + /// @return + Value reduceSubBlocks(int numSubBlocks, Value acc) const { + return processSubBlocks(numSubBlocks, acc, true, false); + } + + /// @brief Zeroes out redundant values in all sub-blocks except first one + /// + /// Every warp in mfma 4x4 layout holds only 4 unique values(scalar or + /// vectors) in blocks of 4 consecutive threads, There are 16 copies of these + /// 4 values across all threads of the warp. Need to zero out 15 copies to use + /// accumulator between dot operations. + /// @param numSubBlocks + /// @param acc + /// @return + Value zeroAuxiliarBlocks(int numSubBlocks, Value acc) const { + return processSubBlocks(numSubBlocks, acc, false, true); + } + + /// Dot operand layout minimal tile is kDimInstrSize elements across + /// K dimension. If dot operand K dimension is smaller, layout + /// assigns tensor elements to multiple different hardware locations. + /// In this case mfma instruction adds elements in accumulator + /// multiple times. + /// + /// Let say A=[1,2]; B=[3,4], C = A*B = 1*3+2*4 = 11 + /// Consider instruction K size is 4, + /// in this case operands will be duplicated: + /// A' = [1,2,1,2] B' = [3,4,3,4] + /// C' = (1*3+2*4) + (1*3+2*4) = 22 + /// + /// Following code adjusts accumulator values in such cases. + /// If accumulator is integer, shift accumulator right by + /// log2(duplicationRate). If accumulator is float, multiply accum + /// with 1/duplicationRate constant. + void adjustAccForSmallKDim(SmallVector &fc, Value &acc, Type dstElemTy, + int b, int m, int n, int64_t numRepM, + int64_t numRepN, int64_t kDimInstrSize, + int64_t kDimOperandSize, + unsigned elemsPerVec) const { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + for (unsigned v = 0; v < elemsPerVec; ++v) { + Value accElem = tb.extract_element(dstElemTy, acc, tb.i32_val(v)); + if (kDimInstrSize > kDimOperandSize) { + assert(kDimInstrSize % kDimOperandSize == 0); + int duplicationRate = kDimInstrSize / kDimOperandSize; + assert(llvm::isPowerOf2_32(duplicationRate)); + if (dstElemTy.isInteger()) { + auto shiftSize = llvm::Log2_32(duplicationRate); + assert(!accElem.getType().isUnsignedInteger() && + "MFMA uses signed accumulator"); + accElem = tb.ashr(accElem, tb.i32_val(shiftSize)); + } else { + auto multiplierAttr = + rewriter.getFloatAttr(dstElemTy, 1.0 / duplicationRate); + auto multiplierVal = + rewriter.create(loc, dstElemTy, multiplierAttr); + accElem = tb.fmul(accElem, multiplierVal); + } + } + auto linearIdx = b * numRepM * numRepN * elemsPerVec + + m * numRepN * elemsPerVec + n * elemsPerVec + v; + fc[linearIdx] = accElem; + } + } + + template + void packAndReplaceResult(T &op, SmallVector &fc, + const FailureOr &maybeMfmaIntrinsic, + Type dstElemTy, Type elemtTy, + size_t mmaCount) const { + Type structTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(fc.size(), dstElemTy)); + Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + setNumGeneratedMMAs(op, mmaCount, maybeMfmaIntrinsic->mDim, + maybeMfmaIntrinsic->nDim, maybeMfmaIntrinsic->kDim, + elemtTy); + + rewriter.replaceOp(op, res); + } + + // Conduct the Dot conversion. + LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + // Check if this dot has come with priority set by setprio. + auto setPrioOp = dyn_cast_or_null(op->getPrevNode()); + + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + auto mDim = mfmaLayout.getMDim(); + auto nDim = mfmaLayout.getNDim(); + auto mfmaVersion = mfmaLayout.getVersionMajor(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + + Value a = op.getA(); + Value b = op.getB(); + Value d = op.getD(); + auto aTensorTy = cast(a.getType()); + auto bTensorTy = cast(b.getType()); + auto dTensorTy = cast(d.getType()); + auto elemTyA = aTensorTy.getElementType(); + auto elemTyB = bTensorTy.getElementType(); + + const auto kDimOperandSize = aTensorTy.getShape().back(); + + bool allowXF32 = + op.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3; + StringRef intrinsicName; + FailureOr maybeMfmaIntrinsic = MfmaIntrinsic::selectFor( + mfmaVersion, mDim, nDim, kDimOperandSize, elemTyA, elemTyB, + /*withScale=*/false, allowXF32); + if (failed(maybeMfmaIntrinsic)) + llvm::report_fatal_error("No match found in MFMA database\n"); + + intrinsicName = maybeMfmaIntrinsic->name; + unsigned kBase = maybeMfmaIntrinsic->kBase; + + auto aEncoding = cast(aTensorTy.getEncoding()); + auto bEncoding = cast(bTensorTy.getEncoding()); + int kWidth = aEncoding.getKWidth(); + + // If we are using XF32, the kWidth (and kBase) is double that of F32. + if (aTensorTy.getElementType().isF32() && allowXF32) + kWidth *= 2; + + const auto kDimInstrSize = mfmaLayout.getInstrShapeForOperand(kWidth, 0)[1]; + + auto repA = mfmaLayout.getRepForOperand(aTensorTy.getShape(), kWidth, 0); + auto repB = mfmaLayout.getRepForOperand(bTensorTy.getShape(), kWidth, 1); + + assert(repA[2] == repB[1]); + + Value loadedA = adaptor.getA(); + Value loadedB = adaptor.getB(); + Value loadedC = adaptor.getC(); + + auto numRepM = repA[1]; + auto numRepN = repB[2]; + auto numRepK = repA[2]; + auto numRepB = repA[0]; + assert(repA[0] == repB[0]); + + bool preserveBF16 = intrinsicName.contains(".bf16") && mfmaVersion >= 4; + auto operandA = getValuesFromDotOperandLayoutStruct( + loadedA, numRepB, numRepM, numRepK, kWidth, kBase, + aTensorTy.getElementType(), allowXF32, preserveBF16); + auto operandB = getValuesFromDotOperandLayoutStruct( + loadedB, numRepB, numRepN, numRepK, kWidth, kBase, + aTensorTy.getElementType(), allowXF32, preserveBF16); + + auto dstElemTy = dTensorTy.getElementType(); + auto fc = unpackLLElements(loc, loadedC, rewriter); + + unsigned warpSize = triton::gpu::getWarpSize(mfmaLayout); + // compute number of output elements that each thread holds for one MFMA + // instruction. + const int subBlocks = + getNumSubmatrices(aTensorTy.getElementType(), mDim, nDim); + auto elemsPerVec = mDim * nDim * subBlocks / warpSize; + + Value firstMfma; + auto vecTy = vec_ty(dstElemTy, elemsPerVec); + for (int b = 0; b < numRepB; ++b) { + for (int m = 0; m < numRepM; ++m) { + for (int n = 0; n < numRepN; ++n) { + Value acc = tb.undef(vecTy); + for (unsigned v = 0; v < elemsPerVec; ++v) { + acc = tb.insert_element( + vecTy, acc, + fc[b * numRepM * numRepN * elemsPerVec + + m * numRepN * elemsPerVec + n * elemsPerVec + v], + tb.i32_val(v)); + } + acc = zeroAuxiliarBlocks(subBlocks, acc); + for (int k = 0; k < numRepK; k++) { + for (int kPack = 0; kPack < kWidth / kBase; ++kPack) { + acc = mfmaLayout.getIsTransposed() + ? generateMFMAOp(intrinsicName, + operandB[kPack][{b, n, k}], + operandA[kPack][{b, m, k}], acc) + : generateMFMAOp(intrinsicName, + operandA[kPack][{b, m, k}], + operandB[kPack][{b, n, k}], acc); + if (!firstMfma) + firstMfma = acc; + } + } + acc = reduceSubBlocks(subBlocks, acc); + adjustAccForSmallKDim(fc, acc, dstElemTy, b, m, n, numRepM, numRepN, + kDimInstrSize, kDimOperandSize, elemsPerVec); + } + } + } + + // Originally, setprio (high) is set to the high-level dot op. After dot is + // being lowered to the series of mfma operations, it should be moved next + // to the first mfma leaving the first mfma staying at the low priority. In + // this way, incoming warp can be effectively waiting on the first mfma + // instruction (low priority) while the other warp is executing mfma with + // high priority. Otherwise, incoming warp can break the cluster. + if (setPrioOp && firstMfma) + setPrioOp->moveAfter(firstMfma.getDefiningOp()); + + const size_t mmaCount = + numRepB * numRepM * numRepN * numRepK * kWidth / kBase; + packAndReplaceResult(op, fc, maybeMfmaIntrinsic, dstElemTy, elemTyA, + mmaCount); + + return success(); + } + + /// Extract vector from rawElems based on kWidth and kBase + /// rawElems is a vector of kWidth elements. We need to prepare vector(s) of + /// kBase elements for each mfma instruction + SmallVector extractOperands(Value rawElems, int kWidth, int kBase, + Type type, bool preserveBF16, + bool isConstantScale = false) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + int kpack = kWidth / kBase; + SmallVector results; + auto vecTy = vec_ty(type, kBase); + if (type.isBF16() && !preserveBF16) + vecTy = vec_ty(i16_ty, kBase); + for (int k = 0; k < kpack; ++k) { + Value vec = b.undef(vecTy); + for (int elemId = 0; elemId < kBase; ++elemId) { + auto val = + b.extract_element(type, rawElems, b.i32_val(elemId + k * kBase)); + if (type.isBF16() && !preserveBF16) { + // rocdl.mfma.f32.32x32x8bf16.1k calls for input of i16 type + auto cast = b.bitcast(val, i16_ty); + vec = b.insert_element(vecTy, vec, cast, b.i32_val(elemId)); + } else { + vec = b.insert_element(vecTy, vec, val, b.i32_val(elemId)); + } + } + if (type.getIntOrFloatBitWidth() == 8) { + if (1 == kBase) { + // This is only for the scale operands of scaled mfma on MI350 + if (isConstantScale) { + // If the scale is constant(created by arith::ConstantOp), it will + // be put in a sgpr instead of vgpr. In that case, instead of + // vgpr[7:0], the instruction reads sgpr[30:23] as the scale value. + // So we need to manually left shift the scale by 23 bits to meet + // the requirement. + results.push_back(b.shl( + i32_ty, b.zext(i32_ty, b.bitcast(vec, i8_ty)), b.i32_val(23))); + } else { + results.push_back(b.zext(i32_ty, b.bitcast(vec, i8_ty))); + } + } + if (4 == kBase) + // This is for int8 on pre- MI300 GPUs + results.push_back(b.bitcast(vec, i32_ty)); + if (8 == kBase) + results.push_back(b.bitcast(vec, i64_ty)); + if (16 == kBase) + // This is only for the operands of scaled mfma on MI350 + results.push_back(b.bitcast(vec, vec_ty(i32_ty, 4))); + if (32 == kBase) + results.push_back(b.bitcast(vec, vec_ty(i32_ty, 8))); + } else { + results.push_back(vec); + } + } + return results; + } + + /// Converts dot operand structure to value table and converts types + /// appropriate for mfma instructions + virtual SmallVector getValuesFromDotOperandLayoutStruct( + Value value, int batch, int n0, int n1, int kWidth, int kBase, Type type, + bool allowXF32, bool preserveBF16, bool isConstantScale = false) const { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + auto elems = unpackLLElements(loc, value, rewriter); + int kpack = kWidth / kBase; + SmallVector dotOpVals(kpack); + for (int b = 0; b < batch; ++b) { + for (int i = 0; i < n0; i++) { + for (int j = 0; j < n1; j++) { + Type elemTy = typeConverter->convertType(type); + Type ty = vec_ty(elemTy, kWidth); + Value rawElems = tb.undef(ty); + for (int k = 0; k < kWidth; ++k) { + rawElems = tb.insert_element( + ty, rawElems, + elems[kWidth * n1 * n0 * b + kWidth * n1 * i + kWidth * j + k], + tb.i32_val(k)); + } + + Value convertedElems; + if (type.isF32() && !allowXF32) { + for (int k = 0; k < kpack; ++k) + dotOpVals[k][{b, i, j}] = + tb.extract_element(type, rawElems, tb.i32_val(k)); + } else { + SmallVector vals; + if (type.isF32() && allowXF32) { + vals = extractOperands(rawElems, kWidth, kBase, f32_ty, + preserveBF16); + } else if (type.getIntOrFloatBitWidth() == 8) { + vals = extractOperands(rawElems, kWidth, kBase, i8_ty, + preserveBF16, isConstantScale); + } else if (type.isBF16()) { + vals = extractOperands(rawElems, kWidth, kBase, bf16_ty, + preserveBF16); + } else { + assert(type.isF16() && "Unsupported data type"); + vals = extractOperands(rawElems, kWidth, kBase, f16_ty, + preserveBF16); + } + for (int k = 0; k < kpack; ++k) { + dotOpVals[k][{b, i, j}] = vals[k]; + } + } + } + } + } + return dotOpVals; + } +}; + +struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper { + + ScaledDotOpMFMAConversionHelper(AMDMfmaEncodingAttr mfmaLayout, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter *typeConverter, + Location loc) + : DotOpMFMAConversionHelper(mfmaLayout, rewriter, typeConverter, loc) {} + + Value generateScaledMFMAOp(StringRef intrinsicName, Value valA, Value valB, + Value valC, Type elemTypeA, Type elemTypeB) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto resType = valC.getType(); + Value zeroFlag = b.i32_val(0); + OperationState loweredOp(loc, intrinsicName); + int32_t cbsz = getMfmaF8F6F4MatrixFormat(elemTypeA); + int32_t blgp = getMfmaF8F6F4MatrixFormat(elemTypeB); + assert((cbsz != -1) && (blgp != -1)); + loweredOp.addTypes(resType); + // If both scales are constant 0, the LLVM backend will use V_MFMA_*_F8F6F4 + // instructions instead of V_MFMA_SCALE_*_F8F6F4 to reduce memory access. + loweredOp.addOperands({valA, valB, valC, b.i32_val(cbsz), b.i32_val(blgp), + zeroFlag, zeroFlag, zeroFlag, zeroFlag}); + return rewriter.create(loweredOp)->getResult(0); + } + + Value generateScaledMFMAOp(StringRef intrinsicName, Value valA, Value valB, + Value valC, Value valScaleA, Value valScaleB, + Type elemTypeA, Type elemTypeB) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto resType = valC.getType(); + Value zeroFlag = b.i32_val(0); + OperationState loweredOp(loc, intrinsicName); + int32_t cbsz = getMfmaF8F6F4MatrixFormat(elemTypeA); + int32_t blgp = getMfmaF8F6F4MatrixFormat(elemTypeB); + assert((cbsz != -1) && (blgp != -1)); + loweredOp.addTypes(resType); + loweredOp.addOperands({valA, valB, valC, b.i32_val(cbsz), b.i32_val(blgp), + zeroFlag, valScaleA, zeroFlag, valScaleB}); + return rewriter.create(loweredOp)->getResult(0); + } + + LogicalResult convertScaledDot(DotScaledOp op, + DotScaledOpAdaptor adaptor) const { + // Check if this dot has come with priority set by setprio. + auto setPrioOp = dyn_cast_or_null(op->getPrevNode()); + + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + auto mDim = mfmaLayout.getMDim(); + auto nDim = mfmaLayout.getNDim(); + auto mfmaVersion = mfmaLayout.getVersionMajor(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + + Value a = op.getA(); + Value b = op.getB(); + Value aScale = op.getAScale(); + Value bScale = op.getBScale(); + if ((aScale && !bScale) || (!aScale && bScale)) { + llvm::report_fatal_error("Single scale is not supported\n"); + } + + bool existBothScales = aScale && bScale; + bool isAScaleConstant = aScale && aScale.getDefiningOp(); + bool isBScaleConstant = bScale && bScale.getDefiningOp(); + Value d = op.getD(); + auto aTensorTy = cast(a.getType()); + auto bTensorTy = cast(b.getType()); + auto dTensorTy = cast(d.getType()); + auto elemTyA = aTensorTy.getElementType(); + auto elemTyB = bTensorTy.getElementType(); + ScaleDotElemType aElemType = op.getAElemType(); + ScaleDotElemType bElemType = op.getBElemType(); + + auto supportsTypes = [](ScaleDotElemType elemType) { + return elemType == ScaleDotElemType::E2M1 || + elemType == ScaleDotElemType::E4M3 || + elemType == ScaleDotElemType::E5M2; + }; + + if (!supportsTypes(aElemType) || !supportsTypes(bElemType)) { + llvm::report_fatal_error("NYI: mxfp6\n"); + } + + int64_t kDimOperandSize = aTensorTy.getShape().back(); + + auto ctx = op.getContext(); + constexpr bool allowXF32 = false; + FailureOr maybeMfmaIntrinsic = MfmaIntrinsic::selectFor( + mfmaVersion, mDim, nDim, + aElemType == ScaleDotElemType::E2M1 ? kDimOperandSize * 2 + : kDimOperandSize, + scaleDotElemTypeToMLIRType(ctx, aElemType), + scaleDotElemTypeToMLIRType(ctx, bElemType), + /*withScale=*/true, allowXF32); + if (failed(maybeMfmaIntrinsic)) + llvm::report_fatal_error("No match found in MFMA database\n"); + + StringRef intrinsicName = maybeMfmaIntrinsic->name; + unsigned kBase = maybeMfmaIntrinsic->kBase; + // Two fp4 are packed into an uint8. + unsigned aKBase = aElemType == ScaleDotElemType::E2M1 ? kBase / 2 : kBase; + unsigned bKBase = bElemType == ScaleDotElemType::E2M1 ? kBase / 2 : kBase; + + int aKWidth = aKBase; + int bKWidth = bKBase; + + const auto kDimInstrSize = mfmaLayout.getInstrShapeForOperand(aKBase, 0)[1]; + + auto repA = mfmaLayout.getRepForOperand(aTensorTy.getShape(), aKWidth, 0); + auto repB = mfmaLayout.getRepForOperand(bTensorTy.getShape(), bKWidth, 1); + assert(repA[2] == repB[1]); + + // For fp4 scaled mfma, each thread takes 1 element from scale. Will have + // better way to get it when adapting other data types. Similar to + // scaleKBase + constexpr int scaleKWidth = 1; + constexpr int scaleKBase = 1; + + Value loadedA = adaptor.getA(); + Value loadedB = adaptor.getB(); + Value loadedAScale = adaptor.getAScale(); + Value loadedBScale = adaptor.getBScale(); + Value loadedC = adaptor.getC(); + + auto numRepM = repA[1]; + auto numRepN = repB[2]; + auto numRepK = repA[2]; + auto numRepB = repA[0]; + assert(repA[0] == repB[0]); + + auto operandA = getValuesFromDotOperandLayoutStruct( + loadedA, numRepB, numRepM, numRepK, aKWidth, aKBase, + aTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false); + auto operandB = getValuesFromDotOperandLayoutStruct( + loadedB, numRepB, numRepN, numRepK, bKWidth, bKBase, + bTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false); + + // Scales have the same replica distributions as their corresponding + // operands. + SmallVector operandAScale; + SmallVector operandBScale; + if (existBothScales) { + auto aScaleTensorTy = cast(aScale.getType()); + operandAScale = getValuesFromDotOperandLayoutStruct( + loadedAScale, numRepB, numRepM, numRepK, scaleKWidth, scaleKBase, + aScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false, + isAScaleConstant); + + auto bScaleTensorTy = cast(bScale.getType()); + operandBScale = getValuesFromDotOperandLayoutStruct( + loadedBScale, numRepB, numRepN, numRepK, scaleKWidth, scaleKBase, + bScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false, + isBScaleConstant); + } + + auto dstElemTy = dTensorTy.getElementType(); + auto fc = unpackLLElements(loc, loadedC, rewriter); + + unsigned warpSize = triton::gpu::getWarpSize(mfmaLayout); + // compute number of output elements that each thread holds for one MFMA + // instruction. subBlocks + const int subBlocks = + getNumSubmatrices(aTensorTy.getElementType(), mDim, nDim); + auto elemsPerVec = mDim * nDim * subBlocks / warpSize; + + Value firstMfma; + auto tb = TritonLLVMOpBuilder(loc, rewriter); + auto vecTy = vec_ty(dstElemTy, elemsPerVec); + for (int b = 0; b < numRepB; ++b) { + for (int m = 0; m < numRepM; ++m) { + for (int n = 0; n < numRepN; ++n) { + Value acc = tb.undef(vecTy); + for (unsigned v = 0; v < elemsPerVec; ++v) { + acc = tb.insert_element( + vecTy, acc, + fc[b * numRepM * numRepN * elemsPerVec + + m * numRepN * elemsPerVec + n * elemsPerVec + v], + tb.i32_val(v)); + } + acc = zeroAuxiliarBlocks(subBlocks, acc); + for (int k = 0; k < numRepK; k++) { + for (int kPack = 0; kPack < aKWidth / aKBase; ++kPack) { + if (existBothScales) { + if (mfmaLayout.getIsTransposed()) { + acc = generateScaledMFMAOp(intrinsicName, + operandB[kPack][{b, n, k}], + operandA[kPack][{b, m, k}], acc, + operandBScale[kPack][{b, n, k}], + operandAScale[kPack][{b, m, k}], + maybeMfmaIntrinsic->bElementType, + maybeMfmaIntrinsic->aElementType); + } else { + acc = generateScaledMFMAOp(intrinsicName, + operandA[kPack][{b, m, k}], + operandB[kPack][{b, n, k}], acc, + operandAScale[kPack][{b, m, k}], + operandBScale[kPack][{b, n, k}], + maybeMfmaIntrinsic->aElementType, + maybeMfmaIntrinsic->bElementType); + } + } else { + if (mfmaLayout.getIsTransposed()) { + acc = generateScaledMFMAOp(intrinsicName, + operandB[kPack][{b, n, k}], + operandA[kPack][{b, m, k}], acc, + maybeMfmaIntrinsic->bElementType, + maybeMfmaIntrinsic->aElementType); + } else { + acc = generateScaledMFMAOp(intrinsicName, + operandA[kPack][{b, m, k}], + operandB[kPack][{b, n, k}], acc, + maybeMfmaIntrinsic->aElementType, + maybeMfmaIntrinsic->bElementType); + } + } + if (!firstMfma) + firstMfma = acc; + } + } + acc = reduceSubBlocks(subBlocks, acc); + adjustAccForSmallKDim(fc, acc, dstElemTy, b, m, n, numRepM, numRepN, + kDimInstrSize, kDimOperandSize, elemsPerVec); + } + } + } + + // Originally, setprio (high) is set to the high-level dot op. After dot is + // being lowered to the series of mfma operations, it should be moved next + // to the first mfma leaving the first mfma staying at the low priority. In + // this way, incoming warp can be effectively waiting on the first mfma + // instruction (low priority) while the other warp is executing mfma with + // high priority. Otherwise, incoming warp can break the cluster. + if (setPrioOp && firstMfma) + setPrioOp->moveAfter(firstMfma.getDefiningOp()); + + const size_t mmaCount = + numRepB * numRepM * numRepN * numRepK * aKWidth / aKBase; + packAndReplaceResult(op, fc, maybeMfmaIntrinsic, dstElemTy, elemTyA, + mmaCount); + + return success(); + } +}; + +} // namespace + +namespace mlir::triton::AMD { +LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto rankedTType = [](Value tensor) { + return cast(tensor.getType()); + }; + + assert(isa(rankedTType(op.getA()).getEncoding()) && + isa(rankedTType(op.getB()).getEncoding()) && + "Both A and B should be DotOperand layout."); + + auto cTensorTy = rankedTType(op.getC()); + auto dTensorTy = rankedTType(op.getD()); + assert(isa(cTensorTy.getEncoding()) && + "Currently, we only support C with a mfma layout."); + + assert(cTensorTy.getShape()[0] == dTensorTy.getShape()[0] && + cTensorTy.getShape()[1] == dTensorTy.getShape()[1] && + "DotOp's C operand should pass the same number of values as D."); + + auto loc = op.getLoc(); + auto mfmaLayout = cast( + cast(op.getResult().getType()).getEncoding()); + + DotOpMFMAConversionHelper helper(mfmaLayout, rewriter, typeConverter, loc); + + return helper.convertDot(op, adaptor); +} + +LogicalResult convertScaledMFMA(triton::DotScaledOp op, + triton::DotScaledOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + assert(isa(op.getA().getType().getEncoding()) && + isa(op.getB().getType().getEncoding()) && + "Both lhs and rhs should be linear layout."); + + auto aScale = op.getAScale(); + auto bScale = op.getBScale(); + + // If the tt.dot_scaled is transformed from a tt.dot, both scales are None. In + // this case, both scales remain None in this method and we will generate a + // mfma instruction with the scale operand to be 0. Then there's an + // optimization pass in the LLVM backend to convert such V_MFMA_SCALE_*_F8F6F4 + // instruction to V_MFMA_*_F8F6F4 to avoid LD_SCALE. + // + // If the tt.dot_scaled is not from a tt.dot but native, we support 0, 1, 2 + // scales and treat them in different ways: + // + // 1. #scales = 0: Just like those transformed from tt.dot, both scales remain + // None. + // 2. #scales = 1: The upstream transform guarantees to create constant + // scales for the absent. + // 2. #scales = 2: Both scales should exist. + + // Thus in this pass, there shouldn't be a single scale present. + assert(((aScale && bScale) || (!aScale && !bScale)) && + "Single scale is not supported"); + + if (aScale && bScale) { + assert( + isa(aScale.getType().getEncoding()) && + isa(bScale.getType().getEncoding()) && + "If scales exist, both LhsScale and RhsScale should be linear layout."); + } + + auto cTensorTy = op.getC().getType(); + auto dTensorTy = op.getD().getType(); + assert(isa(cTensorTy.getEncoding()) && + "Currently, we only support C with a mfma layout."); + + assert(cTensorTy.getShape()[0] == dTensorTy.getShape()[0] && + cTensorTy.getShape()[1] == dTensorTy.getShape()[1] && + "DotOp's C operand should pass the same number of values as D."); + + auto loc = op.getLoc(); + auto mfmaLayout = cast( + cast(op.getResult().getType()).getEncoding()); + + ScaledDotOpMFMAConversionHelper helper(mfmaLayout, rewriter, typeConverter, + loc); + + return helper.convertScaledDot(op, adaptor); +} +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp new file mode 100644 index 000000000..fa4493101 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -0,0 +1,371 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" +#include "Utility.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir::triton::AMD { +namespace { + +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; + +enum class WMMAInstrType : uint8_t { + // D = AB + C; + // typeof(D) == typeof(C) + // typeof(A) == typeof(B) + // typeof(D), typeof(A): + FP32_FP16, + FP32_BF16, + FP16_FP16, + BF16_BF16, + I32_I8, + I32_I4, + NOT_APPLICABLE, +}; + +using ValueTable = std::map, Value>; + +ValueTable +getValuesFromDotOperandLayoutStruct(ConversionPatternRewriter &rewriter, + const LLVMTypeConverter *typeConverter, + Value value, int batch, int n0, int n1, + int kWidth, Type type, Location loc) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + auto elems = unpackLLElements(loc, value, rewriter); + ValueTable vals; + for (int b = 0; b < batch; b++) { + for (int i = 0; i < n0; i++) { + for (int j = 0; j < n1; j++) { + Type elemTy = typeConverter->convertType(type); + Type ty = vec_ty(elemTy, kWidth); + Value rawElems = tb.undef(ty); + for (int k = 0; k < kWidth; ++k) { + rawElems = tb.insert_element( + ty, rawElems, + elems[n0 * n1 * kWidth * b + kWidth * (n1 * i + j) + k], + tb.i32_val(k)); + } + + Value convertedElems; + if (type.isF16()) { + convertedElems = rawElems; + } else if (type.isBF16()) { + convertedElems = tb.bitcast(rawElems, vec_ty(i16_ty, kWidth)); + } else { + convertedElems = tb.bitcast( + rawElems, vec_ty(i32_ty, kWidth * type.getIntOrFloatBitWidth() / + i32_ty.getIntOrFloatBitWidth())); + } + vals[{b, i, j}] = convertedElems; + } + } + } + return vals; +} + +static WMMAInstrType getWMMAInstrTypeFromDot(DotOp op) { + auto aOperandTy = op.getA().getType(); + auto aTensorTy = cast(aOperandTy); + auto aElemTy = aTensorTy.getElementType(); + auto bOperandTy = op.getB().getType(); + auto bTensorTy = cast(bOperandTy); + auto bElemTy = bTensorTy.getElementType(); + assert(aElemTy == bElemTy); + auto cOperandTy = op.getC().getType(); + auto cTensorTy = cast(cOperandTy); + auto cElemTy = cTensorTy.getElementType(); + auto dOperandTy = op.getD().getType(); + auto dTensorTy = cast(dOperandTy); + auto dElemTy = dTensorTy.getElementType(); + assert(cElemTy == dElemTy); + + if (dElemTy.isF32() && aElemTy.isF16()) + return WMMAInstrType::FP32_FP16; + if (dElemTy.isF32() && aElemTy.isBF16()) + return WMMAInstrType::FP32_BF16; + if (dElemTy.isF16() && aElemTy.isF16()) + return WMMAInstrType::FP16_FP16; + if (dElemTy.isBF16() && aElemTy.isBF16()) + return WMMAInstrType::BF16_BF16; + if (dElemTy.isInteger(32) && aElemTy.isInteger(8)) + return WMMAInstrType::I32_I8; + if (dElemTy.isInteger(32) && aElemTy.isInteger(4)) + return WMMAInstrType::I32_I4; + + return WMMAInstrType::NOT_APPLICABLE; +} + +Value generateROCDLOp(ConversionPatternRewriter &rewriter, Location loc, + WMMAInstrType wmmaType, Value valA, Value valB, + Value valC, Type aElType, Type bElType) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto resType = valC.getType(); + Value falseFlag = b.int_val(1, false); + switch (wmmaType) { + case WMMAInstrType::FP32_FP16: + return rewriter.create( + loc, TypeRange{resType}, ValueRange{valA, valB, valC}); + case WMMAInstrType::FP32_BF16: + return rewriter.create( + loc, TypeRange{resType}, ValueRange{valA, valB, valC}); + case WMMAInstrType::FP16_FP16: + return rewriter.create( + loc, TypeRange{resType}, ValueRange{valA, valB, valC, falseFlag}); + case WMMAInstrType::BF16_BF16: + return rewriter.create( + loc, TypeRange{resType}, ValueRange{valA, valB, valC, falseFlag}); + case WMMAInstrType::I32_I8: + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{b.int_val(1, !aElType.isUnsignedInteger()), valA, + b.int_val(1, !bElType.isUnsignedInteger()), valB, valC, + falseFlag}); + case WMMAInstrType::I32_I4: + return rewriter.create( + loc, TypeRange{resType}, + ValueRange{b.int_val(1, !aElType.isUnsignedInteger()), valA, + b.int_val(1, !bElType.isUnsignedInteger()), valB, valC, + falseFlag}); + default: + llvm::report_fatal_error("WMMA data type not supported"); + } + return Value(); +} + +std::string getTypeStr(Type ty) { + std::string scalarName; + if (ty.isF32()) { + scalarName = "f32"; + } else if (ty.isF16()) { + scalarName = "f16"; + } else if (ty.isBF16()) { + scalarName = "bf16"; + } else if (ty.isInteger(32)) { + scalarName = "i32"; + } else if (ty.isInteger(16)) { + scalarName = "i16"; + } else if (ty.isInteger(8)) { + scalarName = "iu8"; + } else if (ty.isInteger(4)) { + scalarName = "iu4"; + } else if (auto vecTy = dyn_cast(ty)) { + auto elemType = vecTy.getElementType(); + auto numElems = vecTy.getNumElements(); + scalarName = "v" + std::to_string(numElems) + getTypeStr(elemType); + } else { + llvm::report_fatal_error("WMMA data type not supported"); + } + return scalarName; +} + +StringRef getWmmaIntrinsicName(Type aElTy, Type bElTy, Type dElTy, Type valATy, + Type valCTy) { + static llvm::SmallDenseMap intrinsics; + using MapInfo = llvm::DenseMapInfo; + llvm::hash_code h = llvm::hash_combine( + MapInfo::getHashValue(aElTy), MapInfo::getHashValue(bElTy), + MapInfo::getHashValue(dElTy), MapInfo::getHashValue(valATy), + MapInfo::getHashValue(valCTy)); + if (!intrinsics.contains(h)) { + std::string name = "llvm.amdgcn.wmma."; + name += getTypeStr(dElTy); + name += ".16x16x16."; // TODO support 16x16x32 for i4 operands + name += getTypeStr(aElTy); + if (isa(aElTy) && aElTy.getIntOrFloatBitWidth() == 8) + name += '.' + getTypeStr(bElTy); + name += '.' + getTypeStr(valCTy) + "." + getTypeStr(valATy); + intrinsics[h] = name; + } + return intrinsics[h]; +} + +Value generateWMMAIntrinsic(ConversionPatternRewriter &rewriter, Location loc, + WMMAInstrType wmmaType, Value valA, Value valB, + Value valC, Type aElType, Type bElType, + Type dElType) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto name = getWmmaIntrinsicName(aElType, bElType, dElType, valA.getType(), + valC.getType()); + LLVM::FastmathFlagsAttr defaultFlags{}; + SmallVector operands; + if (aElType.isInteger()) + operands.push_back(b.int_val(1, !aElType.isUnsignedInteger())); + operands.push_back(valA); + if (bElType.isInteger()) + operands.push_back(b.int_val(1, !bElType.isUnsignedInteger())); + operands.push_back(valB); + operands.push_back(valC); + // Flag for using low bits in registers. Result could be already packed to + // int32. Set low bits by default for now. + if (32 / dElType.getIntOrFloatBitWidth() > 1 || dElType.isInteger(32)) { + operands.push_back(b.int_val(1, false)); + } + auto wmmaIntrinsic = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, name, valC.getType(), operands); + return wmmaIntrinsic.getResult(0); +} + +Value generateWMMAOp(ConversionPatternRewriter &rewriter, Location loc, + WMMAInstrType wmmaType, Value valA, Value valB, Value valC, + Type aElType, Type bElType, Type dElType, int version) { + if (version == 1) { + return generateROCDLOp(rewriter, loc, wmmaType, valA, valB, valC, aElType, + bElType); + } else { + assert(version == 2); + return generateWMMAIntrinsic(rewriter, loc, wmmaType, valA, valB, valC, + aElType, bElType, dElType); + } +} + +// Conduct the Dot conversion. +LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter *typeConverter) { + auto wmmaLayout = cast( + cast(op.getResult().getType()).getEncoding()); + int wmmaVer = wmmaLayout.getVersion(); + auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerInstr(); + auto wmmaInstrType = getWMMAInstrTypeFromDot(op); + + auto loc = op.getLoc(); + auto tb = TritonLLVMOpBuilder(loc, rewriter); + Value a = op.getA(); + Value b = op.getB(); + Value d = op.getD(); + auto aTensorTy = cast(a.getType()); + auto bTensorTy = cast(b.getType()); + auto dTensorTy = cast(d.getType()); + auto elemTy = aTensorTy.getElementType(); + + auto aEncoding = cast(aTensorTy.getEncoding()); + auto bEncoding = cast(bTensorTy.getEncoding()); + int kWidth = aEncoding.getKWidth(); + + auto repA = + wmmaLayout.getRepForOperand(aTensorTy.getShape(), elemTy, kWidth, 0); + auto repB = + wmmaLayout.getRepForOperand(bTensorTy.getShape(), elemTy, kWidth, 1); + + assert(repA[2] == repB[1]); + + Value loadedA = adaptor.getA(); + Value loadedB = adaptor.getB(); + Value loadedC = adaptor.getC(); + auto numRepM = repA[1]; + auto numRepN = repB[2]; + auto numRepK = repA[2]; + auto numRepB = repA[0]; + + ValueTable ha = getValuesFromDotOperandLayoutStruct( + rewriter, typeConverter, loadedA, numRepB, numRepM, numRepK, kWidth, + aTensorTy.getElementType(), loc); + ValueTable hb = getValuesFromDotOperandLayoutStruct( + rewriter, typeConverter, loadedB, numRepB, numRepN, numRepK, kWidth, + aTensorTy.getElementType(), loc); + auto dstElemTy = dTensorTy.getElementType(); + auto fc = unpackLLElements(loc, loadedC, rewriter); + + unsigned warpSize = triton::gpu::getWarpSize(wmmaLayout); + constexpr unsigned vgprElemBitWidth = 32; + unsigned paddedOutputElemSize = + wmmaVer == 1 ? vgprElemBitWidth / dstElemTy.getIntOrFloatBitWidth() : 1; + // compute number of output elements that each thread holds for one WMMA + // instruction. + auto elemsPerVec = mnkDim[0] * mnkDim[1] * paddedOutputElemSize / warpSize; + auto dElemsToStorePerThread = mnkDim[0] * mnkDim[1] / warpSize; + auto vecTy = vec_ty(dstElemTy, elemsPerVec); + for (int b = 0; b < numRepB; ++b) { + for (int m = 0; m < numRepM; ++m) { + for (int n = 0; n < numRepN; ++n) { + auto batchOffIdx = b * numRepM * numRepN * dElemsToStorePerThread; + auto mRepOffId = m * numRepN * dElemsToStorePerThread; + auto nRepOffId = n * dElemsToStorePerThread; + auto fcThreadOffIdx = batchOffIdx + mRepOffId + nRepOffId; + + Value acc = tb.undef(vecTy); + for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { + acc = tb.insert_element(vecTy, acc, fc[fcThreadOffIdx + v], + tb.i32_val(v * paddedOutputElemSize)); + } + for (size_t k = 0; k < numRepK; k++) { + acc = wmmaLayout.getIsTransposed() + ? generateWMMAOp( + rewriter, loc, wmmaInstrType, hb[{b, n, k}], + ha[{b, m, k}], acc, bTensorTy.getElementType(), + aTensorTy.getElementType(), dstElemTy, wmmaVer) + : generateWMMAOp( + rewriter, loc, wmmaInstrType, ha[{b, m, k}], + hb[{b, n, k}], acc, aTensorTy.getElementType(), + bTensorTy.getElementType(), dstElemTy, wmmaVer); + } + for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { + fc[fcThreadOffIdx + v] = tb.extract_element( + dstElemTy, acc, tb.i32_val(v * paddedOutputElemSize)); + } + } + } + } + + // replace with new packed result + Type structTy = LLVM::LLVMStructType::getLiteral( + wmmaLayout.getContext(), SmallVector(fc.size(), dstElemTy)); + Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + const size_t mmaCount = numRepB * numRepM * numRepN * numRepK; + setNumGeneratedMMAs(op, mmaCount, mnkDim[0], mnkDim[1], mnkDim[2], elemTy); + + rewriter.replaceOp(op, res); + return success(); +} + +} // namespace + +LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto rankedTType = [](Value tensor) { + return cast(tensor.getType()); + }; + + assert(isa(rankedTType(op.getA()).getEncoding()) && + isa(rankedTType(op.getB()).getEncoding()) && + "Both $a and %b should be DotOperand layout."); + + auto cTensorTy = rankedTType(op.getC()); + auto dTensorTy = rankedTType(op.getD()); + assert(isa(cTensorTy.getEncoding()) && + "Currently, we only support $c with a wmma layout."); + + assert(cTensorTy.getShape()[0] == dTensorTy.getShape()[0] && + cTensorTy.getShape()[1] == dTensorTy.getShape()[1] && + "DotOp's $c operand should pass the same number of values as $d"); + + return convertDot(op, adaptor, rewriter, typeConverter); +} +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp new file mode 100644 index 000000000..371f19d9c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -0,0 +1,1568 @@ +#include "TargetInfo.h" +#include "Utility.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; + +using mlir::triton::gpu::appendOrGetExternFuncOp; +using mlir::triton::gpu::ElementwiseOpConversionBase; +using mlir::triton::gpu::getElementType; +using mlir::triton::gpu::getFunctionType; +using mlir::triton::gpu::MultipleOperandsRange; + +using ConverterT = std::function( + Location, ConversionPatternRewriter &, const SmallVector &)>; + +namespace { +//===----------------------------------------------------------------------===// +// Data type conversion utility functions +//===----------------------------------------------------------------------===// + +// FP8E5M2 is the open-compute standard FP8E5M2 format. NVIDIA GPU supports it +// natively but we don't have hardware native support on MI300. +// +// The SW based downcast with RTNE is not fully functional for the denorm +// values. We need rewrite it if we need to emulate this data type on AMDGPU. +static SmallVector +Fp16_to_Fp8E5M2_RTNE(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto fp16x2VecTy = vec_ty(f16_ty, 2); + Value fp16x2Vec0 = b.undef(fp16x2VecTy); + Value fp16x2Vec1 = b.undef(fp16x2VecTy); + fp16x2Vec0 = b.insert_element(fp16x2VecTy, fp16x2Vec0, v[0], b.i32_val(0)); + fp16x2Vec0 = b.insert_element(fp16x2VecTy, fp16x2Vec0, v[1], b.i32_val(1)); + fp16x2Vec1 = b.insert_element(fp16x2VecTy, fp16x2Vec1, v[2], b.i32_val(0)); + fp16x2Vec1 = b.insert_element(fp16x2VecTy, fp16x2Vec1, v[3], b.i32_val(1)); + + Value a0 = b.bitcast(fp16x2Vec0, i32_ty); + Value a1 = b.bitcast(fp16x2Vec1, i32_ty); + + a0 = b.and_(i32_ty, a0, b.i32_val(0xfffefffe)); + a1 = b.and_(i32_ty, a1, b.i32_val(0xfffefffe)); + + a0 = b.add(i32_ty, a0, b.i32_val(0x00800080)); + a1 = b.add(i32_ty, a1, b.i32_val(0x00800080)); + + auto fp8x4VecTy = vec_ty(i8_ty, 4); + a0 = b.bitcast(a0, fp8x4VecTy); + a1 = b.bitcast(a1, fp8x4VecTy); + + return {b.extract_element(i8_ty, a0, b.i32_val(1)), + b.extract_element(i8_ty, a0, b.i32_val(3)), + b.extract_element(i8_ty, a1, b.i32_val(1)), + b.extract_element(i8_ty, a1, b.i32_val(3))}; +} + +static SmallVector +Fp16_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto fp16x2VecTy = vec_ty(f16_ty, 2); + Value fp16x2Vec0 = b.undef(fp16x2VecTy); + Value fp16x2Vec1 = b.undef(fp16x2VecTy); + fp16x2Vec0 = b.insert_element(fp16x2VecTy, fp16x2Vec0, v[0], b.i32_val(0)); + fp16x2Vec0 = b.insert_element(fp16x2VecTy, fp16x2Vec0, v[1], b.i32_val(1)); + fp16x2Vec1 = b.insert_element(fp16x2VecTy, fp16x2Vec1, v[2], b.i32_val(0)); + fp16x2Vec1 = b.insert_element(fp16x2VecTy, fp16x2Vec1, v[3], b.i32_val(1)); + + Value a0 = b.bitcast(fp16x2Vec0, i32_ty); + Value a1 = b.bitcast(fp16x2Vec1, i32_ty); + + auto fp8x4VecTy = vec_ty(i8_ty, 4); + a0 = b.bitcast(a0, fp8x4VecTy); + a1 = b.bitcast(a1, fp8x4VecTy); + + return {b.extract_element(i8_ty, a0, b.i32_val(1)), + b.extract_element(i8_ty, a0, b.i32_val(3)), + b.extract_element(i8_ty, a1, b.i32_val(1)), + b.extract_element(i8_ty, a1, b.i32_val(3))}; +} + +static Value checkIsNan(TritonLLVMOpBuilder &builder, Value v) { + StringRef intrinsic = "llvm.is.fpclass"; + // bits 0 and 1 indicate signaling Nan and quiet Nan, respectively + Location loc = builder.loc; + OpBuilder &rewriter = *builder.builder; + Value nanBits = builder.i32_val(3); + + return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, i1_ty, + ValueRange{v, nanBits}) + ->getResult(0); +} + +// Cast FP16 to FP8E4M3FN in saturation and round-to-nearest-even mode. +// According to +// https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1, +// In saturation mode, inf and out-of-range numbers are converted to the largest +// normal number, i.e. ±448. NaNs are converted to NaNs. +static Value +Fp16_to_Fp8E4M3FN_RTNE_oneValue(Location loc, + ConversionPatternRewriter &rewriter, Value v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value isNaN = checkIsNan(b, v); + // Get sign and absolute value + Value vi16 = b.bitcast(v, i16_ty); + Value sign = + b.trunc(i8_ty, b.lshr(b.and_(vi16, b.i16_val(0x8000)), b.i16_val(8))); + vi16 = b.and_(vi16, b.i16_val(0x7FFF)); + + // Rounding to nearest even + constexpr uint16_t baseRoundingBias = 0x003F; // 1 << (10 - 3 - 1) - 1 + + // S.EEEEE.MMMMMMMMMM => 0.00000.00M0000000 => 0.00000.000000000M + Value remainingMantissaLSB = + b.lshr(b.and_(vi16, b.i16_val(0x0080)), b.i16_val(7)); + Value roundingBias = b.add(remainingMantissaLSB, b.i16_val(baseRoundingBias)); + Value vFp8 = b.add(vi16, roundingBias); + + // Reduce mantissa to 3 bits + vFp8 = b.and_(vFp8, b.i16_val(0xFF80)); // 0xFF80 == 1.11111.1110000000 + + // 0x2400 is the FP16 representation of 2^{-6}, which is the smallest normal + // number in FP8E4M3FN. We round numbers smaller than that to 0x2400 to make + // it easier to handle subnormals + vFp8 = b.umax(vFp8, b.i16_val(0x2400)); + + // Adjust exponent bias + vFp8 = b.sub(vFp8, b.i16_val(0x2000)); // (15 - 7) << 10 + + // Shift right and truncate + vFp8 = b.trunc(i8_ty, b.lshr(vFp8, b.i16_val(7))); // 10 - 3 + + // 0x5F7F == 0.10111.1101111111 is the largest possible normal + // number(including infinity) after rounding in FP8 + // + // In saturation mode, numbers larger than the max normal number(including + // infinity) in FP8 after rounding will be replaced with max_E4M3, i.e. 0x7E + // === 0.1111.110 + Value isOverflowOrInf = b.icmp_ugt(vi16, b.i16_val(0x5F7F)); + vFp8 = b.select(isOverflowOrInf, b.i8_val(0x7E), vFp8); + + // Round subnormals to nearest even. Ref: + // https://github.com/openxla/xla/blob/f20c6fe2/xla/service/elemental_ir_emitter.cc#L272 + constexpr size_t lutSize = 8; + constexpr float halfwayPointsLUT[lutSize] = {0x1400, 0x1A00, 0x1D00, 0x1F00, + 0x2080, 0x2180, 0x2280, 0x2380}; + + for (int i = lutSize - 1; i >= 0; i--) { + Value cmp; + if (i % 2 == 0) { + cmp = b.icmp_ule(vi16, b.i16_val(halfwayPointsLUT[i])); + } else { + cmp = b.icmp_ult(vi16, b.i16_val(halfwayPointsLUT[i])); + } + + vFp8 = b.select(cmp, b.i8_val(i), vFp8); + } + + // NaN remains NaN after conversion + vFp8 = b.select(isNaN, b.i8_val(0x7F), vFp8); + + // Set sign bit + vFp8 = b.or_(vFp8, sign); + + return vFp8; +} + +static SmallVector +Fp16_to_Fp8E4M3FN_RTNE(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + SmallVector result(2); + result[0] = Fp16_to_Fp8E4M3FN_RTNE_oneValue(loc, rewriter, v[0]); + result[1] = Fp16_to_Fp8E4M3FN_RTNE_oneValue(loc, rewriter, v[1]); + return result; +} + +static Value cvtFp16ToFp32(Location loc, ConversionPatternRewriter &rewriter, + const Value &v) { + TritonLLVMOpBuilder b(loc, rewriter); + return b.fpext(f32_ty, v); +} + +// convert fp8 to fp32 +static SmallVector cvtFp8ToFp32(Location loc, + ConversionPatternRewriter &rewriter, + Value v0, Value v1, + const std::string &fp8_format) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(fp8_format == "fp8" || fp8_format == "bf8"); + std::string ins_str = "v_cvt_pk_f32_" + fp8_format; + + auto fp8x4VecTy = vec_ty(i8_ty, 4); + Value fp8x4Vec = b.undef(fp8x4VecTy); + fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v0, b.i32_val(0)); + fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v1, b.i32_val(1)); + auto i32v = b.bitcast(fp8x4Vec, i32_ty); + + GCNBuilder builder1; + auto &cvt = *builder1.create(ins_str); + auto res = builder1.newOperand("=v"); + auto operand = builder1.newOperand(i32v, "v"); + cvt(res, operand); + auto i64v = builder1.launch(rewriter, loc, i64_ty, false); + auto fp32x2VecTy = vec_ty(f32_ty, 2); + auto fp32x2Vec = b.bitcast(i64v, fp32x2VecTy); + + SmallVector ret(2); + ret[0] = b.extract_element(f32_ty, fp32x2Vec, b.i32_val(0)); + ret[1] = b.extract_element(f32_ty, fp32x2Vec, b.i32_val(1)); + + return ret; +} + +// convert fp32 to fp8 +static SmallVector cvtFp32ToFp8(Location loc, + ConversionPatternRewriter &rewriter, + Value v0, Value v1, + const std::string &fp8_format) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(fp8_format == "fp8" || fp8_format == "bf8"); + std::string ins_str = "v_cvt_pk_" + fp8_format + "_f32"; + + GCNBuilder builder; + auto &cvt = *builder.create(ins_str); + auto res = builder.newOperand("=v"); + auto operand0 = builder.newOperand(v0, "v"); + auto operand1 = builder.newOperand(v1, "v"); + cvt(res, operand0, operand1); + auto fp8x4Vec = builder.launch(rewriter, loc, i32_ty, false); + + auto fp8x4VecTy = vec_ty(i8_ty, 4); + auto a1 = b.bitcast(fp8x4Vec, fp8x4VecTy); + + SmallVector ret(2); + ret[0] = b.extract_element(i8_ty, a1, b.i32_val(0)); + ret[1] = b.extract_element(i8_ty, a1, b.i32_val(1)); + + return ret; +} +static SmallVector +convert_val_Fp16_to_Fp8(Location loc, ConversionPatternRewriter &rewriter, + Value v0, Value v1, const std::string &fp8_format) { + assert(fp8_format == "fp8" || fp8_format == "bf8"); + std::string ins_str = "v_cvt_pk_" + fp8_format + "_f32"; + + auto f32_0 = cvtFp16ToFp32(loc, rewriter, v0); + auto f32_1 = cvtFp16ToFp32(loc, rewriter, v1); + + // Convert fp32 to fp8 + return cvtFp32ToFp8(loc, rewriter, f32_0, f32_1, fp8_format); +} + +static SmallVector +convert_val_Fp8_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, + Value v0, Value v1, const std::string &fp8_format) { + + // Convert fp8 to fp32 + SmallVector ret = cvtFp8ToFp32(loc, rewriter, v0, v1, fp8_format); + + // Convert fp32 to fp16 + ret[0] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[0], RoundingMode::RTNE); + ret[1] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[1], RoundingMode::RTNE); + + return ret; +} + +static SmallVector +Fp32_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + assert(v.size() == 2); + return cvtFp32ToFp8(loc, rewriter, v[0], v[1], "bf8"); +} + +static SmallVector +Fp32_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + assert(v.size() == 2); + return cvtFp32ToFp8(loc, rewriter, v[0], v[1], "fp8"); +} + +static SmallVector +Fp8E5M2FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + assert(v.size() == 2); + return cvtFp8ToFp32(loc, rewriter, v[0], v[1], "bf8"); +} + +static SmallVector +Fp8E4M3FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + assert(v.size() == 2); + return cvtFp8ToFp32(loc, rewriter, v[0], v[1], "fp8"); +} + +// Depend on whether we focus more on performance, we may skip +// the processing of submornal values +static Value Fp16_to_Fp8E5M2FNUZ_oneValue(Location loc, + ConversionPatternRewriter &rewriter, + Value v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto vi16 = b.bitcast(v, i16_ty); + auto e = b.and_(i16_ty, vi16, b.int_val(16, 0x7C00)); + auto sign = b.and_(i16_ty, vi16, b.int_val(16, 0x8000)); + + // normal value + auto a = b.and_(i16_ty, vi16, b.int_val(16, 0x7FFFF)); + auto a1 = b.add(i16_ty, a, b.int_val(16, 0x0400)); + auto o1 = b.or_(i16_ty, a1, sign); + + // subnormal value, e is 0 + auto m = b.and_(i16_ty, vi16, b.int_val(16, 0x03FF)); + auto m2 = b.shl(m, b.int_val(16, 1)); + auto o2 = b.or_(i16_ty, sign, b.or_(i16_ty, b.int_val(16, 1), m2)); + + auto e_is_zero = b.icmp_eq(e, b.int_val(16, 0)); + auto e_is_all1 = b.icmp_eq(e, b.int_val(16, 0x7C00)); + + auto ot = b.select(e_is_zero, o2, o1); + auto o = b.select(e_is_all1, vi16, ot); + auto fp8x2VecTy = vec_ty(i8_ty, 2); + auto res = b.bitcast(o, fp8x2VecTy); + + return b.extract_element(i8_ty, res, b.i32_val(1)); +} + +static SmallVector +Fp16_to_Fp8E5M2FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + SmallVector result(2); + result[0] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[0]); + result[1] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[1]); + return result; +} + +static SmallVector +Fp16_to_Fp8E5M2FNUZ_HW(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "bf8"); +} + +ConverterT Fp16_to_Fp8E5M2FNUZ(AMD::ISAFamily isaFamily) { + return isaFamily == AMD::ISAFamily::CDNA3 ? Fp16_to_Fp8E5M2FNUZ_HW + : Fp16_to_Fp8E5M2FNUZ_SW; +} + +static Value Fp8E4M3FN_to_Fp16_oneValue(Location loc, + ConversionPatternRewriter &rewriter, + Value v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto fp8x2VecTy = vec_ty(i8_ty, 2); + Value a = b.undef(fp8x2VecTy); + a = b.insert_element(fp8x2VecTy, a, b.i8_val(0), b.i32_val(0)); + a = b.insert_element(fp8x2VecTy, a, v, b.i32_val(1)); + a = b.bitcast(a, i16_ty); + + // Get sign and absolute value + Value sign = b.and_(a, b.i16_val(0x8000)); + a = b.and_(a, b.i16_val(0x7FFF)); + + // Right shift 1 bit to adjust the positions of exponent and mantissa + a = b.lshr(a, b.i16_val(1)); + + // Adjust exponent, (15 - 7) << 10 === 0x2000 + a = b.add(a, b.i16_val(0x2000)); + + // Check NaN + Value vAbs = b.and_(b.bitcast(v, i8_ty), b.i8_val(0x7F)); + a = b.select(b.icmp_eq(vAbs, b.i8_val(0x7F)), b.i16_val(0x7E00), a); + + // Check denorms and zero + // Here we use a LUT to map S.0000.000 ~ S.0000.111 to its corresponding fp16 + // value + constexpr size_t lutSize = 8; + static constexpr int denormsAndZeroLut[lutSize] = { + 0x0000, 0x1800, 0x1C00, 0x1E00, 0x2000, 0x2100, 0x2200, 0x2300}; + + for (int i = 0; i < lutSize; i++) { + a = b.select(b.icmp_eq(vAbs, b.i8_val(i)), b.i16_val(denormsAndZeroLut[i]), + a); + } + + // Set sign + a = b.or_(a, sign); + a = b.bitcast(a, f16_ty); + + return a; +} + +static SmallVector Fp8E4M3FN_to_Fp16(Location loc, + ConversionPatternRewriter &rewriter, + const SmallVector &values) { + SmallVector results(2); + results[0] = Fp8E4M3FN_to_Fp16_oneValue(loc, rewriter, values[0]); + results[1] = Fp8E4M3FN_to_Fp16_oneValue(loc, rewriter, values[1]); + return results; +} + +static SmallVector Fp8E5M2_to_Fp16(Location loc, + ConversionPatternRewriter &rewriter, + const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto fp8x4VecTy = vec_ty(i8_ty, 4); + Value a0 = b.undef(fp8x4VecTy); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(0)); + a0 = b.insert_element(fp8x4VecTy, a0, v[0], b.i32_val(1)); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(2)); + a0 = b.insert_element(fp8x4VecTy, a0, v[1], b.i32_val(3)); + a0 = b.bitcast(a0, i32_ty); + Value a1 = b.undef(fp8x4VecTy); + a1 = b.insert_element(fp8x4VecTy, a1, b.int_val(8, 0), b.i32_val(0)); + a1 = b.insert_element(fp8x4VecTy, a1, v[2], b.i32_val(1)); + a1 = b.insert_element(fp8x4VecTy, a1, b.int_val(8, 0), b.i32_val(2)); + a1 = b.insert_element(fp8x4VecTy, a1, v[3], b.i32_val(3)); + a1 = b.bitcast(a1, i32_ty); + + auto fp16x2VecTy = vec_ty(f16_ty, 2); + auto fp16x2Vec0 = b.bitcast(a0, fp16x2VecTy); + auto fp16x2Vec1 = b.bitcast(a1, fp16x2VecTy); + + return {b.extract_element(f16_ty, fp16x2Vec0, b.i32_val(0)), + b.extract_element(f16_ty, fp16x2Vec0, b.i32_val(1)), + b.extract_element(f16_ty, fp16x2Vec1, b.i32_val(0)), + b.extract_element(f16_ty, fp16x2Vec1, b.i32_val(1))}; +} + +static Value convertBf16ToFp32(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto as_int16 = b.bitcast(v, i16_ty); + auto as_int32 = b.zext(i32_ty, as_int16); + auto shifted = b.shl(i32_ty, as_int32, b.i32_val(16)); + return b.bitcast(shifted, f32_ty); +} + +static Value convertFp32ToBf16(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v, const RoundingMode rounding) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto as_int32 = b.bitcast(v, i32_ty); + if (rounding == RoundingMode::RTZ) { + auto shifted = b.lshr(i32_ty, as_int32, b.i32_val(16)); + auto truncated = b.trunc(i16_ty, shifted); + return b.bitcast(truncated, bf16_ty); + } + + // This implementation is a faster version for fp32 to bf16 type conversion + // It is from CK: + // https://github.com/cgmillette/composable_kernel/commit/24e75bef6aa5 + // It uses less VGPR and less number of instructions compared to the + // previous implementation + Value isNan = checkIsNan(b, v); + Value v16 = b.i32_val(16); + Value tmp = b.and_(i32_ty, b.lshr(i32_ty, as_int32, v16), b.i32_val(1)); + + Value v7FFF = b.i32_val(0x7FFF); + Value s1 = b.add(as_int32, tmp); + Value s2 = b.add(s1, v7FFF); + + Value vNan = b.i32_val(0x7FFF0000); + Value res = b.select(isNan, vNan, s2); + + Value shifted = b.lshr(i32_ty, res, v16); + Value truncated = b.trunc(i16_ty, shifted); + return b.bitcast(truncated, bf16_ty); +} + +static Value Fp8E5M2FNUZ_to_Fp16_oneValue(Location loc, + ConversionPatternRewriter &rewriter, + Value v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto fp8x2VecTy = vec_ty(i8_ty, 2); + Value a = b.undef(fp8x2VecTy); + a = b.insert_element(fp8x2VecTy, a, b.int_val(8, 0), b.i32_val(0)); + a = b.insert_element(fp8x2VecTy, a, v, b.i32_val(1)); + a = b.bitcast(a, i16_ty); + + auto e = b.and_(i16_ty, a, b.int_val(16, 0x7C00)); + auto m = b.and_(i16_ty, a, b.int_val(16, 0x0300)); + auto sign = b.and_(i16_ty, a, b.int_val(16, 0x8000)); + + // check whether all exponents are zeros + auto e_is_zero = b.icmp_eq(e, b.int_val(16, 0x0)); + + // case 1, e is zero, need to move m right by 1 bit + auto m1 = b.lshr(i16_ty, m, b.int_val(16, 1)); + auto o0 = b.or_(i16_ty, sign, m1); + + // case 2, e is nonzero, sub exponent by 1 + auto e1 = b.sub(i16_ty, e, b.int_val(16, 0x0400)); + + auto e_is_one = b.icmp_eq(e, b.int_val(16, 0x0400)); + auto m2 = b.add(i16_ty, m1, b.int_val(16, 0x0200)); + + auto o1 = b.or_(i16_ty, sign, b.or_(i16_ty, m, e1)); + auto o2 = b.or_(i16_ty, sign, m2); + + auto o12 = b.select(e_is_one, o2, o1); + auto o = b.select(e_is_zero, o0, o12); + + return b.bitcast(o, f16_ty); +} + +static SmallVector +Fp8E5M2FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + SmallVector result(2); + result[0] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]); + result[1] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]); + return result; +} + +static SmallVector +Fp8E5M2FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "bf8"); +} + +ConverterT Fp8E5M2FNUZ_to_Fp16(AMD::ISAFamily isaFamily) { + return isaFamily == AMD::ISAFamily::CDNA3 ? Fp8E5M2FNUZ_to_Fp16_HW + : Fp8E5M2FNUZ_to_Fp16_SW; +} + +static SmallVector Fp8E5M2_to_Bf16(Location loc, + ConversionPatternRewriter &rewriter, + const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto fp8x4VecTy = vec_ty(i8_ty, 4); + Value a0 = b.undef(fp8x4VecTy); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(0)); + a0 = b.insert_element(fp8x4VecTy, a0, v[0], b.i32_val(1)); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(2)); + a0 = b.insert_element(fp8x4VecTy, a0, v[1], b.i32_val(3)); + a0 = b.bitcast(a0, i32_ty); + + Value a1 = b.undef(fp8x4VecTy); + a1 = b.insert_element(fp8x4VecTy, a1, b.int_val(8, 0), b.i32_val(0)); + a1 = b.insert_element(fp8x4VecTy, a1, v[2], b.i32_val(1)); + a1 = b.insert_element(fp8x4VecTy, a1, b.int_val(8, 0), b.i32_val(2)); + a1 = b.insert_element(fp8x4VecTy, a1, v[3], b.i32_val(3)); + a1 = b.bitcast(a1, i32_ty); + + Value b0 = b.and_(i32_ty, a0, b.i32_val(0x7fff7fff)); + Value b1 = b.and_(i32_ty, a1, b.i32_val(0x7fff7fff)); + b0 = b.lshr(i32_ty, b0, b.i32_val(3)); + b1 = b.lshr(i32_ty, b1, b.i32_val(3)); + + Value c0 = b.shl(i32_ty, b0, b.i32_val(16)); + Value c1 = b.and_(i32_ty, b0, b.i32_val(0xFFFF0000)); + Value c2 = b.shl(i32_ty, b1, b.i32_val(16)); + Value c3 = b.and_(i32_ty, b1, b.i32_val(0xFFFF0000)); + + c0 = b.bitcast(c0, f32_ty); + c1 = b.bitcast(c1, f32_ty); + c2 = b.bitcast(c2, f32_ty); + c3 = b.bitcast(c3, f32_ty); + + Value d0 = b.fmul(f32_ty, c0, b.f32_val(0x1p+112)); + Value d1 = b.fmul(f32_ty, c1, b.f32_val(0x1p+112)); + Value d2 = b.fmul(f32_ty, c2, b.f32_val(0x1p+112)); + Value d3 = b.fmul(f32_ty, c3, b.f32_val(0x1p+112)); + + d0 = b.bitcast(d0, i32_ty); + d1 = b.bitcast(d1, i32_ty); + d2 = b.bitcast(d2, i32_ty); + d3 = b.bitcast(d3, i32_ty); + + Value out0 = b.or_(i32_ty, b.lshr(i32_ty, d0, b.i32_val(16)), d1); + Value out1 = b.or_(i32_ty, b.lshr(i32_ty, d2, b.i32_val(16)), d3); + + Value sign0 = b.and_(i32_ty, a0, b.i32_val(0x80008000)); + Value sign1 = b.and_(i32_ty, a1, b.i32_val(0x80008000)); + + out0 = b.or_(i32_ty, out0, sign0); + out1 = b.or_(i32_ty, out1, sign1); + + auto bf16x2VecTy = vec_ty(bf16_ty, 2); + out0 = b.bitcast(out0, bf16x2VecTy); + out1 = b.bitcast(out1, bf16x2VecTy); + + return {b.extract_element(bf16_ty, out0, b.i32_val(0)), + b.extract_element(bf16_ty, out0, b.i32_val(1)), + b.extract_element(bf16_ty, out1, b.i32_val(0)), + b.extract_element(bf16_ty, out1, b.i32_val(1))}; +} + +static SmallVector Bf16_to_Fp8E5M2(Location loc, + ConversionPatternRewriter &rewriter, + const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto bf16x2VecTy = vec_ty(bf16_ty, 2); + Value bf16x2Vec0 = b.undef(bf16x2VecTy); + Value bf16x2Vec1 = b.undef(bf16x2VecTy); + bf16x2Vec0 = b.insert_element(bf16x2VecTy, bf16x2Vec0, v[0], b.i32_val(0)); + bf16x2Vec0 = b.insert_element(bf16x2VecTy, bf16x2Vec0, v[1], b.i32_val(1)); + bf16x2Vec1 = b.insert_element(bf16x2VecTy, bf16x2Vec1, v[2], b.i32_val(0)); + bf16x2Vec1 = b.insert_element(bf16x2VecTy, bf16x2Vec1, v[3], b.i32_val(1)); + bf16x2Vec0 = b.bitcast(bf16x2Vec0, i32_ty); + bf16x2Vec1 = b.bitcast(bf16x2Vec1, i32_ty); + + Value sign0 = b.and_(i32_ty, bf16x2Vec0, b.i32_val(0x80008000)); + Value sign1 = b.and_(i32_ty, bf16x2Vec1, b.i32_val(0x80008000)); + auto fp8x4VecTy = vec_ty(i8_ty, 4); + Value sign = b.undef(fp8x4VecTy); + sign0 = b.bitcast(sign0, fp8x4VecTy); + sign1 = b.bitcast(sign1, fp8x4VecTy); + sign = b.insert_element(fp8x4VecTy, sign, + b.extract_element(i8_ty, sign0, b.i32_val(1)), + b.i32_val(0)); + sign = b.insert_element(fp8x4VecTy, sign, + b.extract_element(i8_ty, sign0, b.i32_val(3)), + b.i32_val(1)); + sign = b.insert_element(fp8x4VecTy, sign, + b.extract_element(i8_ty, sign1, b.i32_val(1)), + b.i32_val(2)); + sign = b.insert_element(fp8x4VecTy, sign, + b.extract_element(i8_ty, sign1, b.i32_val(3)), + b.i32_val(3)); + sign = b.bitcast(sign, i32_ty); + + Value nosign0 = b.and_(i32_ty, bf16x2Vec0, b.i32_val(0x7fff7fff)); + Value nosign1 = b.and_(i32_ty, bf16x2Vec1, b.i32_val(0x7fff7fff)); + + Value nosign_0_0 = b.and_(i32_ty, nosign0, b.i32_val(0xffff0000)); + nosign_0_0 = b.umax(i32_ty, nosign_0_0, b.i32_val(0x38000000)); + nosign_0_0 = b.umin(i32_ty, nosign_0_0, b.i32_val(0x57e00000)); + Value nosign_0_1 = b.and_(i32_ty, nosign0, b.i32_val(0x0000ffff)); + nosign_0_1 = b.umax(i32_ty, nosign_0_1, b.i32_val(0x3800)); + nosign_0_1 = b.umin(i32_ty, nosign_0_1, b.i32_val(0x57e0)); + nosign0 = b.or_(i32_ty, nosign_0_0, nosign_0_1); + + Value nosign_1_0 = b.and_(i32_ty, nosign1, b.i32_val(0xffff0000)); + nosign_1_0 = b.umax(i32_ty, nosign_1_0, b.i32_val(0x38000000)); + nosign_1_0 = b.umin(i32_ty, nosign_1_0, b.i32_val(0x57e00000)); + Value nosign_1_1 = b.and_(i32_ty, nosign1, b.i32_val(0x0000ffff)); + nosign_1_1 = b.umax(i32_ty, nosign_1_1, b.i32_val(0x3800)); + nosign_1_1 = b.umin(i32_ty, nosign_1_1, b.i32_val(0x57e0)); + nosign1 = b.or_(i32_ty, nosign_1_0, nosign_1_1); + + nosign0 = b.add(i32_ty, nosign0, b.i32_val(0x00100010)); + nosign1 = b.add(i32_ty, nosign1, b.i32_val(0x00100010)); + nosign0 = b.sub(i32_ty, nosign0, b.i32_val(0x38003800)); + nosign1 = b.sub(i32_ty, nosign1, b.i32_val(0x38003800)); + nosign0 = b.shl(i32_ty, nosign0, b.i32_val(3)); + nosign1 = b.shl(i32_ty, nosign1, b.i32_val(3)); + + nosign0 = b.bitcast(nosign0, fp8x4VecTy); + nosign1 = b.bitcast(nosign1, fp8x4VecTy); + Value nosign = b.undef(fp8x4VecTy); + nosign = b.insert_element(fp8x4VecTy, nosign, + b.extract_element(i8_ty, nosign0, b.i32_val(1)), + b.i32_val(0)); + nosign = b.insert_element(fp8x4VecTy, nosign, + b.extract_element(i8_ty, nosign0, b.i32_val(3)), + b.i32_val(1)); + nosign = b.insert_element(fp8x4VecTy, nosign, + b.extract_element(i8_ty, nosign1, b.i32_val(1)), + b.i32_val(2)); + nosign = b.insert_element(fp8x4VecTy, nosign, + b.extract_element(i8_ty, nosign1, b.i32_val(3)), + b.i32_val(3)); + nosign = b.bitcast(nosign, i32_ty); + + Value fp8x4Vec = b.or_(i32_ty, nosign, sign); + fp8x4Vec = b.bitcast(fp8x4Vec, fp8x4VecTy); + return {b.extract_element(i8_ty, fp8x4Vec, b.i32_val(0)), + b.extract_element(i8_ty, fp8x4Vec, b.i32_val(1)), + b.extract_element(i8_ty, fp8x4Vec, b.i32_val(2)), + b.extract_element(i8_ty, fp8x4Vec, b.i32_val(3))}; +} + +// fp8e4m3fn to bf16 +static SmallVector Fp8E4M3FN_to_Bf16(Location loc, + ConversionPatternRewriter &rewriter, + const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto fp8x4VecTy = vec_ty(i8_ty, 4); + Value a0 = b.undef(fp8x4VecTy); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(0)); + a0 = b.insert_element(fp8x4VecTy, a0, v[0], b.i32_val(1)); + a0 = b.insert_element(fp8x4VecTy, a0, b.int_val(8, 0), b.i32_val(2)); + a0 = b.insert_element(fp8x4VecTy, a0, v[1], b.i32_val(3)); + a0 = b.bitcast(a0, i32_ty); + + Value b0 = b.and_(i32_ty, a0, b.i32_val(0x7fff7fff)); + b0 = b.lshr(i32_ty, b0, b.i32_val(4)); + + Value c0 = b.shl(i32_ty, b0, b.i32_val(16)); + Value c1 = b.and_(i32_ty, b0, b.i32_val(0xFFFF0000)); + c0 = b.bitcast(c0, f32_ty); + c1 = b.bitcast(c1, f32_ty); + + Value d0 = b.fmul(f32_ty, c0, b.f32_val(0x1p+120)); // bias 2**(127-7) + Value d1 = b.fmul(f32_ty, c1, b.f32_val(0x1p+120)); + d0 = b.bitcast(d0, i32_ty); + d1 = b.bitcast(d1, i32_ty); + + Value out0 = b.or_(i32_ty, b.lshr(i32_ty, d0, b.i32_val(16)), d1); + Value sign0 = b.and_(i32_ty, a0, b.i32_val(0x80008000)); + out0 = b.or_(i32_ty, out0, sign0); + + auto bf16x2VecTy = vec_ty(bf16_ty, 2); + out0 = b.bitcast(out0, bf16x2VecTy); + return {b.extract_element(bf16_ty, out0, b.i32_val(0)), + b.extract_element(bf16_ty, out0, b.i32_val(1))}; +} + +// fp8e4m3fnuz to bf16 +static SmallVector +Fp8E4M3FNUZ_to_Bf16(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + assert(v.size() == 2); + auto ret = cvtFp8ToFp32(loc, rewriter, v[0], v[1], "fp8"); + ret[0] = convertFp32ToBf16(loc, rewriter, ret[0], RoundingMode::RTZ); + ret[1] = convertFp32ToBf16(loc, rewriter, ret[1], RoundingMode::RTZ); + return ret; +} + +// bf16 to fp8e4m3fnuz +static SmallVector +Bf16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + assert(v.size() == 2); + auto v0 = convertBf16ToFp32(loc, rewriter, v[0]); + auto v1 = convertBf16ToFp32(loc, rewriter, v[1]); + return cvtFp32ToFp8(loc, rewriter, v0, v1, "fp8"); +} + +// fp8e5m2fnuz to bf16 +static SmallVector +Fp8E5M2FNUZ_to_Bf16(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + assert(v.size() == 2); + auto ret = cvtFp8ToFp32(loc, rewriter, v[0], v[1], "bf8"); + ret[0] = convertFp32ToBf16(loc, rewriter, ret[0], RoundingMode::RTZ); + ret[1] = convertFp32ToBf16(loc, rewriter, ret[1], RoundingMode::RTZ); + return ret; +} + +// bf16 to fp8e5m2fnuz +static SmallVector +Bf16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + assert(v.size() == 2); + auto v0 = convertBf16ToFp32(loc, rewriter, v[0]); + auto v1 = convertBf16ToFp32(loc, rewriter, v[1]); + return cvtFp32ToFp8(loc, rewriter, v0, v1, "bf8"); +} + +static Value Fp8E4M3FNUZ_to_Fp16_oneValue(Location loc, + ConversionPatternRewriter &rewriter, + Value v) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + auto fp8x2VecTy = vec_ty(i8_ty, 2); + Value a = tb.undef(fp8x2VecTy); + a = tb.insert_element(fp8x2VecTy, a, tb.int_val(8, 0), tb.i32_val(0)); + a = tb.insert_element(fp8x2VecTy, a, v, tb.i32_val(1)); + a = tb.bitcast(a, i16_ty); + + auto e_mask = tb.int_val(16, 0x7A00); + auto e = tb.and_(i16_ty, a, e_mask); + + auto m = tb.and_(i16_ty, a, tb.int_val(16, 0x0700)); + auto sign = tb.and_(i16_ty, a, tb.int_val(16, 0x8000)); + + // check whether all exponents are zeros + auto e_is_zero = tb.icmp_eq(e, tb.int_val(16, 0x0)); + auto b = tb.and_(i16_ty, a, tb.int_val(16, 0x7FFF)); + auto b1 = tb.lshr(i16_ty, b, tb.int_val(16, 1)); + + // case 1, e is nonzero, add exponent by 6 + auto o0v = tb.add(i16_ty, b1, tb.int_val(16, 0x0C00)); + auto o0 = tb.or_(i16_ty, o0v, sign); + + // case 2, e is nonzero, add exponent by 7 + auto o1v = tb.add(i16_ty, b1, tb.int_val(16, 0x1C00)); + auto o1 = tb.or_(i16_ty, o1v, sign); + + auto io = tb.select(e_is_zero, o0, o1); + return tb.bitcast(io, f16_ty); +} + +static SmallVector +Fp8E4M3FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + SmallVector result(2); + result[0] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]); + result[1] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]); + return result; +} + +static SmallVector +Fp8E4M3FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "fp8"); +} + +static ConverterT Fp8E4M3FNUZ_to_Fp16(AMD::ISAFamily isaFamily) { + return isaFamily == AMD::ISAFamily::CDNA3 ? Fp8E4M3FNUZ_to_Fp16_HW + : Fp8E4M3FNUZ_to_Fp16_SW; +} + +// Fp16 -> Fp8E4M3 (packed) +static Value Fp16_to_Fp8E4M3FNUZ_oneValue(Location loc, + ConversionPatternRewriter &rewriter, + Value v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto vi16 = b.bitcast(v, i16_ty); + auto e10 = b.and_(vi16, b.int_val(16, 0x7C00)); + auto e = b.lshr(i16_ty, e10, b.int_val(16, 10)); + + auto s = b.and_(i16_ty, vi16, b.int_val(16, 0x8000)); + + auto m7 = b.and_(i16_ty, vi16, b.int_val(16, 0x0380)); + auto m = b.shl(i16_ty, m7, b.int_val(16, 1)); + + // three cases: + // 1) e > 21 --> e = 1111, + // 2) e <= 7 ---> e = 0, + // 3) others, normal conversion + auto e1 = b.int_val(16, 0x7800); + auto e2 = b.int_val(16, 0x0); + auto e31 = b.sub(i16_ty, e10, b.int_val(16, 0x1C00)); + auto e3 = b.shl(i16_ty, e31, b.int_val(16, 1)); + + auto c13 = b.icmp_sgt(e, b.int_val(16, 21)); + auto e13 = b.select(c13, e1, e3); + auto c23 = b.icmp_sle(e, b.int_val(16, 7)); + auto re = b.select(c23, e2, e13); + + auto r = b.or_(i16_ty, s, b.or_(i16_ty, re, m)); + auto fp8x2VecTy = vec_ty(i8_ty, 2); + auto res = b.bitcast(r, fp8x2VecTy); + + return b.extract_element(i8_ty, res, b.i32_val(1)); +} + +static SmallVector +Fp16_to_Fp8E4M3FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + SmallVector result(2); + result[0] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[0]); + result[1] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[1]); + + return result; +} + +static SmallVector +Fp16_to_Fp8E4M3FNUZ_HW(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "fp8"); +} + +static ConverterT Fp16_to_Fp8E4M3FNUZ(AMD::ISAFamily isaFamily) { + return isaFamily == AMD::ISAFamily::CDNA3 ? Fp16_to_Fp8E4M3FNUZ_HW + : Fp16_to_Fp8E4M3FNUZ_SW; +} + +//===----------------------------------------------------------------------===// +// Data type conversion patterns +//===----------------------------------------------------------------------===// + +template +struct ElementwiseOpConversion + : public ElementwiseOpConversionBase< + SourceOp, ElementwiseOpConversion> { + using Base = ElementwiseOpConversionBase; + using OpAdaptor = typename Base::OpAdaptor; + + using Base::Base; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0], + adaptor.getAttributes().getValue())}; + } +}; + +// Attempts to use vectorized conversions via inline PTX when possible. +struct FpToFpOpConversion + : public ElementwiseOpConversionBase { + explicit FpToFpOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + AMD::ISAFamily isaFamily, + PatternBenefit benefit = patternBenefitDefault) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + isaFamily(isaFamily) {} + + static Value convertFp16ToFp32(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v) { + return cvtFp16ToFp32(loc, rewriter, v); + } + + FailureOr + getConversionFunc(Type srcTy, Type dstTy, + std::optional roundingMode) const { + auto F8E4M3B15TyID = TypeID::get(); + auto F8E4M3FNUZTyID = TypeID::get(); + auto F8E5M2FNUZTyID = TypeID::get(); + auto F8E5M2TyID = TypeID::get(); + auto F8E4M3FNTyID = TypeID::get(); + auto F16TyID = TypeID::get(); + auto BF16TyID = TypeID::get(); + auto F32TyID = TypeID::get(); + auto F64TyID = TypeID::get(); + + auto undefRounding = static_cast(-1); + + static DenseMap, ConverterT> + srcMap = { + // F8 -> F16 + {{F8E4M3FNUZTyID, F16TyID, undefRounding}, + Fp8E4M3FNUZ_to_Fp16(isaFamily)}, + {{F8E4M3FNTyID, F16TyID, undefRounding}, Fp8E4M3FN_to_Fp16}, + {{F8E5M2FNUZTyID, F16TyID, undefRounding}, + Fp8E5M2FNUZ_to_Fp16(isaFamily)}, + {{F8E5M2TyID, F16TyID, undefRounding}, Fp8E5M2_to_Fp16}, + // F16 -> F8 + {{F16TyID, F8E4M3FNTyID, RoundingMode::RTNE}, + Fp16_to_Fp8E4M3FN_RTNE}, + {{F16TyID, F8E5M2FNUZTyID, RoundingMode::RTNE}, + Fp16_to_Fp8E5M2FNUZ(isaFamily)}, + {{F16TyID, F8E4M3FNUZTyID, RoundingMode::RTNE}, + Fp16_to_Fp8E4M3FNUZ(isaFamily)}, + {{F16TyID, F8E5M2TyID, RoundingMode::RTNE}, Fp16_to_Fp8E5M2_RTNE}, + {{F16TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp16_to_Fp8E5M2_RTZ}, + // F8 -> BF16 + {{F8E5M2TyID, BF16TyID, undefRounding}, Fp8E5M2_to_Bf16}, + {{F8E5M2FNUZTyID, BF16TyID, undefRounding}, Fp8E5M2FNUZ_to_Bf16}, + {{F8E4M3FNTyID, BF16TyID, undefRounding}, Fp8E4M3FN_to_Bf16}, + {{F8E4M3FNUZTyID, BF16TyID, undefRounding}, Fp8E4M3FNUZ_to_Bf16}, + // BF16 -> F8 + {{BF16TyID, F8E5M2TyID, RoundingMode::RTNE}, Bf16_to_Fp8E5M2}, + {{BF16TyID, F8E5M2FNUZTyID, RoundingMode::RTNE}, + Bf16_to_Fp8E5M2FNUZ}, + {{BF16TyID, F8E4M3FNUZTyID, RoundingMode::RTNE}, + Bf16_to_Fp8E4M3FNUZ}, + // F32 <-> F8 + {{F32TyID, F8E4M3FNUZTyID, RoundingMode::RTNE}, + Fp32_to_Fp8E4M3FNUZ}, + {{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE}, + Fp32_to_Fp8E5M2FNUZ}, + {{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32}, + {{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32}, + }; + std::tuple key = { + srcTy.getTypeID(), dstTy.getTypeID(), + roundingMode.value_or(undefRounding)}; + if (srcMap.count(key) == 0) { + return failure(); + } + return srcMap.lookup(key); + } + + SmallVector createDestOps(triton::FpToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcElementType = getElementType(op.getSrc()); + auto dstElementType = getElementType(op.getResult()); + auto roundingMode = op.getRounding(); + + if (srcElementType.isF32() && dstElementType.isF16()) { + assert(roundingMode.has_value() && + "rounding mode must be specified for fp32->fp16 conversion"); + SmallVector outVals; + outVals.reserve(operands[0].size()); + for (Value v : operands[0]) { + outVals.push_back( + LLVM::AMD::cvtFp32ToFp16(loc, rewriter, v, roundingMode.value())); + } + return outVals; + } + + if (srcElementType.isF32() && dstElementType.isBF16()) { + assert(roundingMode.has_value() && + "rounding mode must be specified for fp32->bf16 conversion"); + SmallVector outVals; + outVals.reserve(operands[0].size()); + for (Value v : operands[0]) { + outVals.push_back( + convertFp32ToBf16(loc, rewriter, v, roundingMode.value())); + } + return outVals; + } + size_t numElements = 4; + if (llvm::isa( + srcElementType) || + llvm::isa( + dstElementType)) { + numElements = 2; + } + bool useFP16IntermediateSrc = + srcElementType.isF32() && + !(isaFamily == AMD::ISAFamily::CDNA3 && + (llvm::isa(dstElementType))); + bool isDstFP32 = dstElementType.isF32(); + Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType; + Type dstType = isDstFP32 ? f16_ty : dstElementType; + SmallVector inVals; + inVals.reserve(std::min(numElements, operands.size())); + for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) { + inVals.push_back(operands[i][0]); + } + bool isSrcFP16 = srcElementType.isF16(); + bool isSrcBF16 = srcElementType.isBF16(); + + if ((isSrcFP16 || isSrcBF16) && isDstFP32) { + SmallVector outVals; + for (Value &v : inVals) { + if (isSrcFP16) + outVals.push_back(convertFp16ToFp32(loc, rewriter, v)); + else + outVals.push_back(convertBf16ToFp32(loc, rewriter, v)); + } + return outVals; + } + if (useFP16IntermediateSrc) + for (Value &v : inVals) + v = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, v, + roundingMode.value_or(RoundingMode::RTNE)); + inVals.resize(numElements, b.undef(typeConverter->convertType(srcType))); + SmallVector outVals; + if (srcType != dstType) { + auto getCvtFunc = getConversionFunc(srcType, dstType, roundingMode); + if (failed(getCvtFunc)) { + std::string rmError; + if (roundingMode.has_value()) + rmError = std::string(" with rounding mode ") + + stringifyRoundingMode(roundingMode.value()).str(); + op->emitError("Unsupported conversion from ") + << srcType << " to " << dstType << rmError; + return outVals; + } else { + auto cvtFunc = getCvtFunc.value(); + outVals = cvtFunc(loc, rewriter, inVals); + } + } else { + outVals = inVals; + } + + assert(outVals.size() == inVals.size()); + outVals.resize(std::min(numElements, operands.size())); + if (isDstFP32) + for (Value &v : outVals) + v = convertFp16ToFp32(loc, rewriter, v); + // Pack values + return outVals; + } + +private: + AMD::ISAFamily isaFamily; +}; + +template +Value EmitDualBF16ElementwiseOp(Location loc, + ConversionPatternRewriter &rewriter, + MultipleOperandsRange operands) { + auto v0 = convertBf16ToFp32(loc, rewriter, operands[0][0]); + auto v1 = convertBf16ToFp32(loc, rewriter, operands[0][1]); + auto result = rewriter.create(loc, f32_ty, v0, v1); + return convertFp32ToBf16(loc, rewriter, result, RoundingMode::RTNE); +} + +struct FDivOpConversion + : ElementwiseOpConversionBase { + using ElementwiseOpConversionBase::ElementwiseOpConversionBase; + + SmallVector createDestOps(arith::DivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } +}; + +struct FMulOpConversion + : ElementwiseOpConversionBase { + using ElementwiseOpConversionBase::ElementwiseOpConversionBase; + + SmallVector createDestOps(arith::MulFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto lhsElemTy = getElementType(op.getLhs()); + auto rhsElemTy = getElementType(op.getRhs()); + if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { + return {EmitDualBF16ElementwiseOp(loc, rewriter, operands)}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } + } +}; + +struct FAddOpConversion + : ElementwiseOpConversionBase { + using ElementwiseOpConversionBase::ElementwiseOpConversionBase; + + SmallVector createDestOps(arith::AddFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto lhsElemTy = getElementType(op.getLhs()); + auto rhsElemTy = getElementType(op.getRhs()); + if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { + return {EmitDualBF16ElementwiseOp(loc, rewriter, operands)}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } + } +}; + +struct FSubOpConversion + : ElementwiseOpConversionBase { + using ElementwiseOpConversionBase::ElementwiseOpConversionBase; + + SmallVector createDestOps(arith::SubFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto lhsElemTy = getElementType(op.getLhs()); + auto rhsElemTy = getElementType(op.getRhs()); + if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { + return {EmitDualBF16ElementwiseOp(loc, rewriter, operands)}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } + } +}; + +static SmallVector S8_to_Bf16(Location loc, + ConversionPatternRewriter &rewriter, + const SmallVector &v) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector inValues = {v[0], v[1], v[2], v[3]}; + SmallVector outValues = {}; + for (Value inVal : inValues) { + Value i32Val = b.sext(i32_ty, inVal); + + GCNBuilder builder; + auto &cvt = *builder.create("v_cvt_f32_i32"); + auto res = builder.newOperand("=v"); + auto operand = builder.newOperand(i32Val, "v"); + cvt(res, operand); + auto f32Val = builder.launch(rewriter, loc, f32_ty, false); + + f32Val = b.bitcast(f32Val, i32_ty); + auto shifted = b.lshr(i32_ty, f32Val, b.i32_val(16)); + auto truncated = b.trunc(i16_ty, shifted); + outValues.push_back(b.bitcast(truncated, bf16_ty)); + } + return outValues; +} + +// Uses inline ptx to convert s8/u8 to bf16, since the +struct SIToFPOpConversion + : ElementwiseOpConversionBase { + using ElementwiseOpConversionBase::ElementwiseOpConversionBase; + + SmallVector createDestOps(arith::SIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + Type inElemTy = getElementType(op.getIn()); + Type outElemTy = getElementType(op.getOut()); + if (outElemTy.isBF16() && inElemTy.isInteger(8) && operands.size() >= 4) { + SmallVector inVals = {operands[0][0], operands[1][0], + operands[2][0], operands[3][0]}; + auto outVals = S8_to_Bf16(loc, rewriter, inVals); + assert(outVals.size() == 4); + return outVals; + } else if (outElemTy.isBF16()) { + auto value = rewriter.create(loc, f32_ty, operands[0][0]); + return {convertFp32ToBf16(loc, rewriter, value, RoundingMode::RTNE)}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0])}; + } + } +}; + +struct FPToSIOpConversion + : ElementwiseOpConversionBase { + using ElementwiseOpConversionBase::ElementwiseOpConversionBase; + + SmallVector createDestOps(arith::FPToSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto inElemTy = getElementType(op.getIn()); + if (inElemTy.isBF16()) { + auto value = convertBf16ToFp32(loc, rewriter, operands[0][0]); + return {rewriter.create(loc, elemTy, value)}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0])}; + } + } +}; + +struct ExtFOpConversion + : ElementwiseOpConversionBase { + using ElementwiseOpConversionBase::ElementwiseOpConversionBase; + + SmallVector createDestOps(arith::ExtFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto inElemTy = getElementType(op.getIn()); + if (inElemTy.isBF16()) { + auto outElemTy = getElementType(op.getOut()); + assert(outElemTy.isF32() && "unsupported conversion"); + return {convertBf16ToFp32(loc, rewriter, operands[0][0])}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0])}; + } + } +}; + +struct TruncFOpConversion + : ElementwiseOpConversionBase { + using ElementwiseOpConversionBase::ElementwiseOpConversionBase; + + explicit TruncFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + llvm::AMDGPU::GPUKind gpuKind, + PatternBenefit benefit = patternBenefitDefault) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + gpuKind(gpuKind) {} + + SmallVector createDestOps(arith::TruncFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto outElemTy = getElementType(op.getOut()); + if (outElemTy.isBF16() && gpuKind != llvm::AMDGPU::GK_GFX950) { + auto inElemTy = getElementType(op.getIn()); + assert(inElemTy.isF32() && "unsupported conversion"); + return { + convertFp32ToBf16(loc, rewriter, operands[0][0], RoundingMode::RTNE)}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0])}; + } + } + +private: + llvm::AMDGPU::GPUKind gpuKind; +}; + +struct ExpOpConversionApprox + : ElementwiseOpConversionBase { + using ElementwiseOpConversionBase::ElementwiseOpConversionBase; + + SmallVector createDestOps(math::ExpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // For non-FP32 input, call __ocml_exp_f64 for higher-precision calculation + if (elemTy.getIntOrFloatBitWidth() != 32) + return {}; + + const double log2e = 1.4426950408889634; + Value prod = b.fmul(f32_ty, operands[0][0], b.f32_val(log2e)); + + // Here we use llvm.exp2.f32 instead of math::Exp2Op. The latter + // flushes denorms by default, but we want to preserve denorms by default + // for expOp. + StringRef funcName = "llvm.exp2.f32"; + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + + return {LLVM::createLLVMCallOp(rewriter, loc, funcOp, prod).getResult()}; + } +}; + +struct Exp2OpConversion + : ElementwiseOpConversionBase { + explicit Exp2OpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisInfoAnalysis, bool ftz, + PatternBenefit benefit) + : ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit), + ftz(ftz) {} + + SmallVector createDestOps(math::Exp2Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + // For non-FP32 input, call __ocml_exp2_f64 for higher-precision calculation + if (elemTy.getIntOrFloatBitWidth() != 32) + return {}; + + // On AMD backend, both intrinsics are lowered to v_exp_f32 instruction, + // which flushes input and output denorms. `llvm.amdgcn.exp2.f32` provides + // direct access to v_exp_f32. For `llvm.exp2.f32`, the LLVM backend inserts + // instructions to handle denorms iff `allow_flush_denorm` is False. + StringRef funcName = ftz ? "llvm.amdgcn.exp2.f32" : "llvm.exp2.f32"; + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + + return { + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; + } + +private: + bool ftz; +}; + +struct RsqrtOpConversion + : ElementwiseOpConversionBase { + explicit RsqrtOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisInfoAnalysis, bool ftz, + PatternBenefit benefit) + : ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit), + ftz(ftz) {} + + SmallVector createDestOps(math::RsqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + // This pass only deals with FP32 input with ftz configuration. Other cases + // are delegate to MLIR. + // + // For FP16/FP64 input, it's lowered to __ocml_rsqrt_f16/__ocml_rsqrt_f64. + // + // For FP32 input with non-ftz configuration, it's lowered to + // __ocml_rsqrt_f32, which will check the ftz/daz settings in the backend + // dynamically to decide to preserve/flush denorms. + if (elemTy.getIntOrFloatBitWidth() != 32 || !ftz) + return {}; + + // `llvm.amdgcn.rsq.f32` provides direct access to v_rsq_f32_e32. + StringRef funcName = "llvm.amdgcn.rsq.f32"; + + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + + return { + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; + } + +private: + bool ftz; +}; + +static inline std::pair +scaleUpIfDenorm(ConversionPatternRewriter &rewriter, Location loc, + const Value &src, float scaleThreshold, float scaleFactor) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value needScale = b.fcmp_ogt(b.f32_val(scaleThreshold), src); + Value scaledSrc = b.fmul(f32_ty, src, b.f32_val(scaleFactor)); + Value selectedSrc = b.select(needScale, scaledSrc, src); + return {needScale, selectedSrc}; +} + +static inline Value scaleDownIfDenorm(ConversionPatternRewriter &rewriter, + Location loc, const Value &src, + Value needScale, float scaleFactor) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value scaledSrc = b.fmul(f32_ty, src, b.f32_val(scaleFactor)); + return b.select(needScale, scaledSrc, src); +} + +struct SqrtOpConversion + : ElementwiseOpConversionBase { + explicit SqrtOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisInfoAnalysis, bool ftz, + PatternBenefit benefit) + : ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit), + ftz(ftz) {} + + SmallVector createDestOps(math::SqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // This function only handles FP32 inputs. Other data types are lowered to + // LLVM::SqrtOp by MLIR. + // + // On the AMDGPU backend, instructions legalized from LLVM::SqrtOp are + // designed to produce IEEE-compliant results and always preserve denorms. + // But what we actually need is an approximated SQRT. So we need to manually + // lower the op. + // + // Differences in this approach are + // 1. Refinement iterations following llvm.amdgcn.sqrt.f32 are removed to + // improve performance. + // 2. With ftz enabled, the scaling-up-and-down process is bypassed to + // ensure denorms are flushed to zero. + if (elemTy.getIntOrFloatBitWidth() != 32) + return {}; + + Value needScale = b.false_val(); + Value scaledSrc = operands[0][0]; + if (!ftz) { + // For non-ftz cases, if the input value is below 2^{-96}, it needs to be + // scaled up by a factor of 2^{32}, to prevent it from being flushed by + // llvm.amdgcn.sqrt.f32. + // + // The result is then scaled down afterward to get the correct result. + // Reference: + // https://github.com/llvm/llvm-project/blob/0876c11c/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp#L5235-L5314. + std::tie(needScale, scaledSrc) = scaleUpIfDenorm( + rewriter, loc, operands[0][0], 0x1.0p-96f, 0x1.0p+32f); + } + + // llvm.amdgcn.sqrt.f32 provides direct access to v_sqrt_f32, which provides + // 1ULP accuracy and flushs denorms. + StringRef funcName = "llvm.amdgcn.sqrt.f32"; + + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + + Value intrinsicsOutput = + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult(); + + if (!ftz) { + // In case of non-ftz, we need to calibrate the results by scaling down by + // a factor of 2^{-16}. + return {scaleDownIfDenorm(rewriter, loc, intrinsicsOutput, needScale, + 0x1.0p-16f)}; + } else { + return {intrinsicsOutput}; + } + } + +private: + bool ftz; +}; + +struct PreciseSqrtOpConversion + : ElementwiseOpConversionBase { + explicit PreciseSqrtOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool ftz, PatternBenefit benefit) + : ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit), + ftz(ftz) {} + + SmallVector createDestOps(triton::PreciseSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // If the op is neither FP32 nor denorm flushing(ftz), it's directly lowered + // to LLVM::SqrtOp. + if (elemTy.getIntOrFloatBitWidth() != 32 || !ftz) { + return {rewriter.create( + loc, elemTy, operands[0], adaptor.getAttributes().getValue())}; + } + + // On the AMDGPU backend, instructions legalized from LLVM::SqrtOp are + // designed to always preserve denorms, according to + // https://github.com/llvm/llvm-project/blob/3d6b2d49/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp#L5235-L5314. + // + // For f32 inputs with ftz enabled, we need to manually lower the op to + // bypass the scaling-up-and-down process while keeping other parts + // unchanged. To ensure IEEE-compliant results, we approximate `sqrt(x)` + // using `x * rsq(x)` and apply extra refinement iterations to correct the + // result. + StringRef funcName = "llvm.amdgcn.rsq.f32"; + + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + + Value sqrtR = + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult(); + + Value sqrtX = operands[0][0]; + Value sqrtS = b.fmul(f32_ty, sqrtX, sqrtR); + + // Refine the approximation with Newton iteration + Value sqrtH = b.fmul(f32_ty, sqrtR, b.f32_val(0.5f)); + Value sqrtE = b.fma(b.neg(f32_ty, sqrtH), sqrtS, b.f32_val(0.5f)); + sqrtH = b.fma(sqrtH, sqrtE, sqrtH); + sqrtS = b.fma(sqrtS, sqrtE, sqrtS); + Value sqrtD = b.fma(b.neg(f32_ty, sqrtS), sqrtS, sqrtX); + sqrtS = b.fma(sqrtD, sqrtH, sqrtS); + + // Handle +0/-0/+inf + // These flags come from + // https://github.com/llvm/llvm-project/blob/217e0f39/llvm/include/llvm/ADT/FloatingPointMode.h#L239-L265. + const unsigned fcPosInf = 0x0200; + const unsigned fcNegZero = 0x0020; + const unsigned fcPosZero = 0x0040; + const unsigned fcZero = fcNegZero | fcPosZero; + + Value isZeroOrPosInf = + rewriter.create(loc, i1_ty, sqrtX, fcPosInf | fcZero); + return {b.select(isZeroOrPosInf, sqrtX, sqrtS)}; + } + +private: + bool ftz; +}; + +} // namespace + +namespace mlir::triton::AMD { +void populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, + ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, + const TargetInfo &targetInfo, PatternBenefit benefit) { + + // fmin (return NaN if either op is NaN) + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); + // fmax (return NaN if either op is NaN) + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); + + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, + targetInfo.getGPUKind(), benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, + targetInfo.getISAFamily(), benefit); + + // ExpOpConversionApprox will try using __ocml_exp2_f32 if the input type is + // FP32. For other input types, ExpOpConversionApprox will return failure and + // later pass will call __ocml_exp_f64 for higher-precision calculation + patterns.add(typeConverter, axisInfoAnalysis, benefit); + // Exp2OpConversion will use llvm.exp2.f32 or llvm.amdgcn.exp2.f32 + // based on the ftz flag if the input type is FP32. For FP64 input, + // Exp2OpConversion will return failure and later pass will call + // __ocml_exp2_f64 for higher-precision calculation + patterns.add(typeConverter, axisInfoAnalysis, ftz, benefit); + patterns.add(typeConverter, axisInfoAnalysis, ftz, + benefit); + patterns.add(typeConverter, axisInfoAnalysis, ftz, benefit); + patterns.add(typeConverter, axisInfoAnalysis, ftz, + benefit); + triton::populateElementwiseOpToLLVMPatterns( + typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); + triton::populateMinMaxFOpToLLVMPattern( + typeConverter, patterns, axisInfoAnalysis, + /*hwNanPropagationSupported=*/false, benefit); + triton::populateClampFOpToLLVMPattern(typeConverter, patterns, + axisInfoAnalysis, targetInfo, benefit); +} +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp new file mode 100644 index 000000000..2de1c0f3d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp @@ -0,0 +1,188 @@ +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/AsmFormat.h" +#include "llvm/Support/raw_ostream.h" +#include // unify to llvm::raw_string_ostream ? + +namespace mlir::triton { + +GCNInstr::Operand * +GCNBuilder::newOperand(mlir::Value value, StringRef constraint, + std::function formatter) { + argArchive.emplace_back(std::make_unique(value, constraint)); + auto *opr = argArchive.back().get(); + opr->repr = formatter; + opr->idx = oprCounter++; + return opr; +} + +GCNBuilder::Operand *GCNBuilder::newOperand(StringRef constraint) { + // Constraint should be something like "=r" + assert(!constraint.empty() && constraint[0] == '='); + auto *opr = newOperand(); + opr->idx = oprCounter++; + opr->constraint = constraint; + return opr; +} + +GCNBuilder::Modifier *GCNBuilder::newModifier(StringRef modifier, + StringRef arg) { + assert(!modifier.empty()); + auto *mod = newModifier(); + mod->modifier = modifier; + mod->arg = arg; + return mod; +} + +GCNBuilder::Operand *GCNBuilder::newConstantOperand(const std::string &v) { + argArchive.emplace_back(std::make_unique()); + argArchive.back()->repr = [v](int idx) { return v; }; + return argArchive.back().get(); +} + +GCNBuilder::Operand *GCNBuilder::newConstantOperand(int v) { + std::stringstream ss; + ss << "0x" << std::hex << v; + return newConstantOperand(ss.str()); +} + +std::string GCNBuilder::getConstraints() const { + auto args = getAllArgs(); + llvm::SmallVector argReprs; + for (auto arg : args) + argReprs.push_back(arg->constraint); + return strJoin(argReprs, ","); +} + +llvm::SmallVector GCNBuilder::getAllMLIRArgs() const { + llvm::SmallVector res; + for (auto &arg : argArchive) { + if (!arg->isList() && arg->value) + res.push_back(arg->value); + } + return res; +} + +SmallVector GCNBuilder::getAllArgs() const { + llvm::SmallVector res; + for (auto &x : argArchive) + if (!x->isList()) + res.push_back(x.get()); + return res; +} + +mlir::Value GCNBuilder::launch(RewriterBase &rewriter, Location loc, Type resTy, + bool hasSideEffect, bool isAlignStack, + ArrayRef attrs) const { + auto *ctx = rewriter.getContext(); + auto inlineAsm = rewriter.create( + loc, resTy, getAllMLIRArgs(), // operands + dump(), // asm_string + getConstraints(), // constraints + hasSideEffect, // has_side_effects + isAlignStack, // is_align_stack + LLVM::AsmDialectAttr::get(ctx, + LLVM::AsmDialect::AD_ATT), // asm_dialect + ArrayAttr::get(ctx, attrs) // operand_attrs + ); + + return inlineAsm.getRes(); +} + +std::string GCNInstr::Operand::dump() const { + if (repr) + return repr(idx); + if (!isList()) + return "$" + std::to_string(idx); + + llvm::SmallVector oprs; + for (auto *opr : list) + oprs.push_back(opr->dump()); + return strJoin(oprs, ", "); +} + +std::string GCNInstr::Modifier::dump() const { + if (!isList()) + return to_str(); + + llvm::SmallVector mods; + for (auto *mod : list) + mods.push_back(mod->dump()); + return strJoin(mods, " "); +} + +GCNInstr::Operand *GCNBuilder::newAddrOperand(mlir::Value addr, + StringRef constraint) { + auto *opr = newOperand(addr, constraint); + opr->repr = [](int idx) -> std::string { + std::stringstream ss; + ss << "$" << idx; + return ss.str(); + }; + + return opr; +} + +std::string GCNBuilder::dump() const { + llvm::SmallVector lines; + for (auto &exec : executions) { + lines.push_back(exec->dump()); + } + + return strJoin(lines, "\n\t"); +} + +GCNInstrExecution &GCNInstrCommon::call(ArrayRef oprs, + ArrayRef mods) { + builder->executions.emplace_back( + std::make_unique(this, oprs, mods)); + return *builder->executions.back(); +} + +GCNInstrExecution &GCNInstrCommon::operator()(ArrayRef oprs, + ArrayRef mods) { + return call(oprs, mods); +} + +std::string GCNInstrExecution::dump() const { + std::string osStr; + llvm::raw_string_ostream os(osStr); + + std::string instrRepr = strJoin(instr->instrParts, "_"); + + llvm::SmallVector argReprs; + for (auto *arg : argsInOrder) { + argReprs.push_back(arg->dump()); + } + + std::string argsRepr = strJoin(argReprs, ", "); + + llvm::SmallVector modReprs; + for (auto *mod : mods) { + modReprs.push_back(mod->dump()); + } + + std::string modsRepr = strJoin(modReprs, " "); + if (!modsRepr.empty()) { + os << instrRepr << " " << argsRepr << " " << modsRepr; + } else { + os << instrRepr << " " << argsRepr; + } + os.flush(); + return osStr; +} + +SmallVector +GCNInstrExecution::getArgList() const { + SmallVector args; + for (auto *arg : argsInOrder) { + if (arg->isList()) + args.insert(args.end(), arg->list.begin(), arg->list.end()); + else + args.push_back(arg); + } + return args; +} + +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp new file mode 100644 index 000000000..2b4c94b59 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -0,0 +1,1946 @@ +#include "BufferOpsEmitter.h" +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" +#include "TargetInfo.h" +#include "Utility.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +using namespace mlir::triton::gpu; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::getSharedMemoryBase; +using ::mlir::LLVM::AMD::getVectorSize; +using ::mlir::LLVM::AMD::llLoad; +using ::mlir::LLVM::AMD::llStore; +using ::mlir::triton::AMD::ISAFamily; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +namespace { +// Return the mask for the unique data accessed by given tensor type. +// Used to mask out the redundant data accessed by threads. +Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, + Location loc, const AMD::TargetInfo &targetInfo) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto tensorTy = dyn_cast(valueTy); + Value mask = b.true_val(); + auto tid = getThreadId(rewriter, loc); + auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc); + if (tensorTy) { + // To remove this use, port https://github.com/triton-lang/triton/pull/5432 + // to the AMDGPU dialect + auto layout = cast(tensorTy.getEncoding()); + auto shape = tensorTy.getShape(); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + auto kLane = StringAttr::get(rewriter.getContext(), "lane"); + auto kWarp = StringAttr::get(rewriter.getContext(), "warp"); + auto maskLane = + std::get<1>(delinearize(rewriter, loc, layout, shape, kLane, laneId)); + auto maskWarp = + std::get<1>(delinearize(rewriter, loc, layout, shape, kWarp, warpId)); + mask = b.and_(maskLane, maskWarp); + + // Do not write duplicated data when multicast is enabled + if (triton::gpu::getNumCTAs(layout) > 1) { + auto _0 = b.i32_val(0); + auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout); + auto CTASplitNum = triton::gpu::getCTASplitNum(layout); + auto CTAOrder = triton::gpu::getCTAOrder(layout); + + auto multiDimClusterCTAId = + delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + + auto rank = tensorTy.getRank(); + for (unsigned dim = 0; dim < rank; ++dim) { + // Skip when multicast is not enabled in this dimension + if (CTAsPerCGA[dim] == CTASplitNum[dim]) + continue; + unsigned splitNum = std::min(shape[dim], CTASplitNum[dim]); + Value repId = b.udiv(multiDimClusterCTAId[dim], b.i32_val(splitNum)); + // Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]: + // CTA0 and CTA2 holds data of block0, + // CTA1 and CTA3 holds data of block1. + // Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should + // be masked. We add the following mask: + // multiDimClusterCTAId[dim] / splitNum == 0 + // Actually in all existing cases of multicast, splitNum is always 1. + // The mask is equivalent to: + // multiDimClusterCTAId[dim] == 0 + mask = b.and_(mask, b.icmp_eq(repId, _0)); + } + } + } else { + // If the tensor is not ranked, then it is a scalar and only thread 0 of + // CTA0 can write + mask = b.and_(mask, b.icmp_eq(clusterCTAId, b.i32_val(0))); + mask = b.and_(mask, b.icmp_eq(tid, b.i32_val(0))); + } + return mask; +} + +// Contains some helper functions for both Load and Store conversions. +struct LoadStoreConversionBase { + explicit LoadStoreConversionBase(const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass) + : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {} + + // Createa a LLVM vector of type `vecTy` containing all zeros + Value createZeroVector(OpBuilder &builder, Location loc, + VectorType vecTy) const { + mlir::Attribute zeroAttr = builder.getZeroAttr(vecTy.getElementType()); + auto denseValue = + DenseElementsAttr::get(cast(vecTy), zeroAttr); + Value zeroVal = builder.create(loc, vecTy, denseValue); + return zeroVal; + } + + // Given a vector of values `elems` and a starting point `start`, create a + // LLVM vector of length `vec` whose elements are `elems[start, ..., + // elems+vec-1]` + Value packElementRangeIntoVector(ConversionPatternRewriter &rewriter, + const LLVMTypeConverter *typeConverter, + Location loc, VectorType vecTy, + ArrayRef elems, int64_t start) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + int64_t vec = vecTy.getNumElements(); + // If we need to mask the loaded value with other elements + Value v = b.undef(vecTy); + for (size_t s = 0; s < vec; ++s) { + Value otherElem = elems[start + s]; + Value indexVal = + LLVM::createIndexConstant(rewriter, loc, typeConverter, s); + v = b.insert_element(vecTy, v, otherElem, indexVal); + } + return v; + } + + // Return a tensor of pointers with the same type of `basePtr` and the same + // shape of `offset` + Type getPointerTypeWithShape(Value basePtr, Value offset) const { + Type basePtrType = basePtr.getType(); + auto offsetType = cast(offset.getType()); + return offsetType.cloneWith(std::nullopt, basePtrType); + } + + // Unpack the elements contained in a `llvmStruct` into a `SmallVector` of + // `Value`s. While you do that, check also the alignment of the mask and + // update the vector length `vec` accordingly + SmallVector + getMaskElemsAndUpdateVeclen(ConversionPatternRewriter &rewriter, Location loc, + Value llMask, Value mask, unsigned &vec) const { + SmallVector maskElems; + if (llMask) { + vec = std::min(vec, getMaskAlignment(mask)); + maskElems = unpackLLElements(loc, llMask, rewriter); + } + return maskElems; + } + + unsigned getMaskAlignment(Value mask) const { + return axisAnalysisPass.getMaskAlignment(mask); + } + + std::optional getAMDGPUMemScopeStr(MemSyncScope scope) const { + // See: https://llvm.org/docs/AMDGPUUsage.html#memory-scopes + auto scopeStr = ""; + switch (scope) { + case MemSyncScope::SYSTEM: + // The default AMDHSA LLVM Sync Scope is "system", so no string is + // provided here + scopeStr = ""; + break; + case MemSyncScope::GPU: + scopeStr = "agent"; + break; + case MemSyncScope::CTA: + scopeStr = "workgroup"; + break; + default: + return std::nullopt; + } + + return scopeStr; + } + +protected: + const AMD::TargetInfo &targetInfo; + ModuleAxisInfoAnalysis &axisAnalysisPass; +}; + +struct LoadOpConversion : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LoadOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // original values + Value ptr = op.getPtr(); + Value mask = op.getMask(); + Value other = op.getOther(); + + // adaptor values + assert(!isTensorPointerType(ptr.getType()) && + "Cannot convert load with a tensor pointer into LLVM; " + "this case should be transformed to normal load before lowering"); + Value llPtr = adaptor.getPtr(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); + + // Determine the vectorization size + Type valueTy = op.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + unsigned vec = getVectorSize(ptr, axisAnalysisPass); + unsigned numElems = getTotalElemsPerThread(ptr.getType()); + + // Get the LLVM values for pointers + auto ptrElems = unpackLLElements(loc, llPtr, rewriter); + assert(ptrElems.size() == numElems); + + // Get the LLVM values for mask + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + SmallVector otherElems; + if (other) + otherElems = unpackLLElements(loc, llOther, rewriter); + + // vectorized iteration through all the pointer/mask/other elements + const int valueElemNBits = + std::max(8u, valueElemTy.getIntOrFloatBitWidth()); + const size_t valueElemNBytes = valueElemNBits / 8; + const int numVecs = numElems / vec; + + auto cacheMod = op.getCache(); + SmallVector loadedVals; + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + const size_t maxWordWidth = std::max(32, valueElemNBits); + const size_t totalWidth = valueElemNBits * vec; + const size_t width = std::min(totalWidth, maxWordWidth); + const size_t nWords = std::max(1, totalWidth / width); + const size_t wordNElems = width / valueElemNBits; + const size_t movWidth = width < 16 ? 16 : width; + assert(wordNElems * nWords * numVecs == numElems); + + Value pred = mask ? maskElems[vecStart] : b.int_val(1, 1); + Value ptr = ptrElems[vecStart]; + + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); + // If we need to mask the loaded value with other elements + if (otherElems.size() != 0) + falseVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + otherElems, vecStart); + + Value loadVal = + llLoad(rewriter, loc, ptr, vecTy, pred, falseVal, cacheMod); + for (size_t ii = 0; ii < vec; ++ii) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, getTypeConverter()->getIndexType(), ii); + Value loaded = b.extract_element(valueElemTy, loadVal, vecIdx); + loadedVals.push_back(loaded); + } + } // end vec + + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, + rewriter, llvmResultStructTy); + + setNumGeneratedGlobalLoads(op, numVecs, vecTy); + + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct BufferLoadOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern< + triton::amdgpu::BufferLoadOp>::ConvertOpToLLVMPattern; + + BufferLoadOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::amdgpu::BufferLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); + + // original values + Value ptr = op.getPtr(); + Value offset = op.getOffsets(); + Value mask = op.getMask(); + Value other = op.getOther(); + auto cacheMod = op.getCache(); + + // Converted values + Value llPtr = adaptor.getPtr(); + Value llOffset = adaptor.getOffsets(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); + Value llStride = adaptor.getStride(); + + // Determine the vectorization size + Type valueTy = op.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + Type ptrType = getPointerTypeWithShape(ptr, offset); + unsigned numElems = getTotalElemsPerThread(ptrType); + unsigned vec = getVectorSize(ptr, offset, axisAnalysisPass); + + // Get the offset + SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); + assert(offsetElems.size() == numElems); + + // Get the mask + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + // Get the `other` value (if any) + SmallVector otherElems; + if (llOther) + otherElems = unpackLLElements(loc, llOther, rewriter); + + // Create the resource descriptor and then emit the buffer_load intrinsic(s) + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride); + SmallVector loadedVals; + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + Value pred = mask ? maskElems[vecStart] : b.int_val(1, 1); + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); + if (otherElems.size() != 0) + falseVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + otherElems, vecStart); + Value loadVal = bufferEmitter.emitLoad( + vecTy, rsrcDesc, offsetElems[vecStart], pred, falseVal, cacheMod); + for (size_t ii = 0; ii < vec; ++ii) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, getTypeConverter()->getIndexType(), ii); + Value loaded = b.extract_element(valueElemTy, loadVal, vecIdx); + loadedVals.push_back(loaded); + } + } // end vec + + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, + rewriter, llvmResultStructTy); + + const int numVecs = numElems / vec; + setNumGeneratedGlobalLoads(op, numVecs, vecTy); + + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct BufferLoadToLocalOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + BufferLoadToLocalOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::amdgpu::BufferLoadToLocalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); + + // Original values + Value ptr = op.getPtr(); + Value offset = op.getOffsets(); + Value mask = op.getMask(); + + // Converted values + Value llPtr = adaptor.getPtr(); + Value llOffset = adaptor.getOffsets(); + Value llDst = adaptor.getDest(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); + Value llStride = adaptor.getStride(); + + RankedTensorType ptrType = + cast(getPointerTypeWithShape(ptr, offset)); + unsigned numElems = getTotalElemsPerThread(ptrType); + + // We can load N elements at a time if: + // 1. Every group of N source pointers are contiguous. For example, if + // N=2, then the pointers should be [x, x+1, y, y+1, ...]. + // 2. The mask (if present) has "alignment" N, meaning that each group of N + // mask bits are the same. For example if N=2, the mask must be + // [x, x, y, y, ...]. + unsigned vec = getVectorSize(ptr, offset, axisAnalysisPass); + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); + assert(offsetElems.size() == numElems); + + SmallVector otherElems; + if (llOther) + otherElems = unpackLLElements(loc, llOther, rewriter); + + // buffer_load into LDS does not support per lane offsets. + // We need to ensure that we write coalesced into shared memory. + auto dstTy = op.getDest().getType(); + if (!LLVM::AMD::canCoalesceWriteIntoSharedMemory(rewriter, ptrType, dstTy, + vec)) { + return rewriter.notifyMatchFailure(op, + "does not write coalesced into LDS"); + } + + auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType()); + auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct( + loc, llDst, resElemTy, rewriter); + + // First we determine the vector size per load and collect the + // shared addresses. This will only emit the address calculation and not the + // actual loads + VectorType vecTy; + SmallVector shmemAddrs; + bool ok = emitTransferBetweenRegistersAndShared( + ptrType, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo, + [&](VectorType vecTy_, Value shmemAddr) { + vecTy = vecTy_; + shmemAddrs.push_back(shmemAddr); + }); + assert(ok); + + int vecBits = vecTy.getNumElements() * vecTy.getElementTypeBitWidth(); + if (!targetInfo.supportsDirectToLdsLoadBitWidth(vecBits)) { + return rewriter.notifyMatchFailure( + op, "Buffer load to local does not support the required load vector " + "bitwidth" + + std::to_string(vecBits)); + } + + int vecBytes = vecBits / 8; + assert(llvm::isPowerOf2_32(vecBytes)); + Value vecBytesVal = b.i32_val(vecBytes); + + // Create the resource descriptor and then emit the buffer_loads to lds + // based on the collected shared addresses and vector size + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride); + + for (int i = 0; i < shmemAddrs.size(); i++) { + auto srcIdx = i * vec; + auto offsetIn = offsetElems[srcIdx]; + + Value pred = mask ? maskElems[srcIdx] : b.true_val(); + bufferEmitter.emitLoadToLds(vecTy, vecBytesVal, rsrcDesc, offsetIn, + shmemAddrs[i], pred, op.getCache()); + if (!otherElems.empty()) { + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, vecTy, otherElems, srcIdx); + llStore(rewriter, loc, shmemAddrs[i], storeVal, + b.icmp_ne(maskElems[srcIdx], b.true_val()), op.getCache()); + } + } + + // Drop the result token. + Value zero = rewriter.create( + op.getLoc(), IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); + rewriter.replaceOp(op, zero); + return success(); + } +}; + +struct AsyncCopyGlobalToLocalOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + AsyncCopyGlobalToLocalOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto srcTy = op.getSrc().getType(); + auto srcEncoding = srcTy.getEncoding(); + + if (!isa(srcEncoding)) + return rewriter.notifyMatchFailure( + op, "requires Blocked or Slice encoding for src"); + if (srcTy.getShape().size() != 2) + return rewriter.notifyMatchFailure(op, "only supports 2d tensors"); + + auto dstTy = op.getResult().getType(); + auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType()); + + Value llSrc = adaptor.getSrc(); + + auto srcElems = unpackLLElements(loc, llSrc, rewriter); + + Value llDst = adaptor.getResult(); + auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct( + loc, llDst, resElemTy, rewriter); + + // We can load N elements at a time if: + // 1. Every group of N source pointers are contiguous. For example, if + // N=2, then the pointers should be [x, x+1, y, y+1, ...]. + // 2. The mask (if present) has "alignment" N, meaning that each group of N + // mask bits are the same. For example if N=2, the mask must be + // [x, x, y, y, ...]. + unsigned maxVec = + mlir::LLVM::AMD::getContiguity(op.getSrc(), axisAnalysisPass); + auto maskElements = getMaskElemsAndUpdateVeclen( + rewriter, loc, adaptor.getMask(), op.getMask(), maxVec); + + // global.load.lds does not support per lane offsets. + // We need to ensure that we write coalesced into shared memory. This means + // that the kLane dim needs to be contigeous based on the vector size. + if (!LLVM::AMD::canCoalesceWriteIntoSharedMemory(rewriter, srcTy, dstTy, + maxVec)) { + return rewriter.notifyMatchFailure(op, + "does not write coalesced into LDS"); + } + + // First we determine the vector size per load and collect the + // shared addresses. This will only emit the address calculation and not the + // actual loads + VectorType vecTy; + SmallVector shmemAddrs; + bool ok = emitTransferBetweenRegistersAndShared( + srcTy, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo, + [&](VectorType vecTy_, Value shmemAddr) { + vecTy = vecTy_; + shmemAddrs.push_back(shmemAddr); + }); + assert(ok); + + int vecBits = vecTy.getNumElements() * vecTy.getElementTypeBitWidth(); + if (!targetInfo.supportsDirectToLdsLoadBitWidth(vecBits)) { + return rewriter.notifyMatchFailure( + op, "Async copy does not support the required load vector bitwidth" + + std::to_string(vecBits)); + } + + int vecBytes = vecBits / 8; + assert(llvm::isPowerOf2_32(vecBytes)); + Value vecBytesVal = b.i32_val(vecBytes); + + Value cacheModifiers = + b.i32_val(mlir::LLVM::AMD::getCtrlBitsForCacheModifierOnTarget( + op.getCache(), /*isLoad=*/true, targetInfo)); + + Value llMask = adaptor.getMask(); + SmallVector maskElems; + if (llMask) { + maskElems = unpackLLElements(loc, llMask, rewriter); + assert(srcElems.size() == maskElems.size()); + } + + SmallVector otherElems; + if (op.getOther()) { + otherElems = unpackLLElements(loc, adaptor.getOther(), rewriter); + assert(srcElems.size() == otherElems.size()); + } + + // Emit the load to lds based on the collected shared addresses and vector + // size + for (int i = 0; i < shmemAddrs.size(); i++) { + auto srcIdx = i * maxVec; + auto srcPtr = srcElems[srcIdx]; + + if (maskElems.empty()) { + rewriter.create( + loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0), + cacheModifiers); + continue; + } + + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterLoad = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *loadBlock = rewriter.createBlock(afterLoad); + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, maskElems[srcIdx], loadBlock, + afterLoad); + rewriter.setInsertionPointToStart(loadBlock); + rewriter.create( + loc, srcPtr, shmemAddrs[i], vecBytesVal, /*offset=*/b.i32_val(0), + cacheModifiers); + + rewriter.create(loc, afterLoad); + rewriter.setInsertionPointToStart(afterLoad); + if (!otherElems.empty()) { + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, vecTy, otherElems, srcIdx); + llStore(rewriter, loc, shmemAddrs[i], storeVal, + b.icmp_ne(maskElems[srcIdx], b.true_val()), op.getCache()); + } + } + + // Drop the result token. + Value zero = rewriter.create( + op.getLoc(), IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); + rewriter.replaceOp(op, zero); + return success(); + } +}; + +struct StoreOpConversion : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + StoreOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = op.getPtr(); + Value value = op.getValue(); + Value mask = op.getMask(); + + Value llPtr = adaptor.getPtr(); + Value llMask = adaptor.getMask(); + Value llValue = adaptor.getValue(); + + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + + auto valueTy = value.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + + // Determine the vectorization size + unsigned vec = getVectorSize(ptr, axisAnalysisPass); + unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType()); + + auto ptrElems = unpackLLElements(loc, llPtr, rewriter); + auto valueElems = unpackLLElements(loc, llValue, rewriter); + assert(ptrElems.size() == valueElems.size()); + + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + const size_t valueElemNBits = + std::max(8, valueElemTy.getIntOrFloatBitWidth()); + const size_t valueElemNBytes = valueElemNBits / 8; + + auto cacheMod = op.getCache(); + const int numVecs = elemsPerThread / vec; + Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) { + Value pred = mask ? b.and_(maskElems[vecStart], rDataMask) : rDataMask; + auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + + const size_t maxWordWidth = std::max(32, valueElemNBits); + const size_t totalWidth = valueElemNBits * vec; + const size_t width = std::min(totalWidth, maxWordWidth); + const size_t nWords = std::max(1, totalWidth / width); + const size_t wordNElems = width / valueElemNBits; + assert(wordNElems * nWords * numVecs == elemsPerThread); + + SmallVector> asmArgs; + Value elem = valueElems[vecStart]; + Value ptr = ptrElems[vecStart]; + + // Create the store val + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + valueElems, vecStart); + llStore(rewriter, loc, ptr, storeVal, pred, cacheMod); + } // end vec + rewriter.eraseOp(op); + return success(); + } +}; + +static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) { + switch (memOrdering) { + case MemSemantic::RELAXED: + return LLVM::AtomicOrdering::monotonic; + case MemSemantic::ACQUIRE: + return LLVM::AtomicOrdering::acquire; + case MemSemantic::RELEASE: + return LLVM::AtomicOrdering::release; + case MemSemantic::ACQUIRE_RELEASE: + return LLVM::AtomicOrdering::acq_rel; + default: + return LLVM::AtomicOrdering::acq_rel; + } +} + +struct BufferAtomicRMWOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern< + triton::amdgpu::BufferAtomicRMWOp>::ConvertOpToLLVMPattern; + + BufferAtomicRMWOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::amdgpu::BufferAtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); + + // original values + Value ptr = op.getPtr(); + Value offset = op.getOffsets(); + Value mask = op.getMask(); + Value data = op.getValue(); + auto atomicRmwAttr = op.getAtomicRmwOp(); + + Value llPtr = adaptor.getPtr(); + Value llOffset = adaptor.getOffsets(); + Value llMask = adaptor.getMask(); + Value llData = adaptor.getValue(); + Value llStride = adaptor.getStride(); + + // Determine the vectorization size + Type valueTy = data.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + Type ptrType = getPointerTypeWithShape(ptr, offset); + + unsigned numElems = getTotalElemsPerThread(ptrType); + unsigned vec = getVectorSize(ptr, offset, axisAnalysisPass); + + // v4f16 and v4bf16 variants of buffer atomics do not exist. + // only v2f16 and v2bf16. + if (valueElemTy.isBF16() || valueElemTy.isF16()) { + // We clamp to the only supported vectorization width here (2). + // In ConvertToBufferOps we check that we have a large enough vector size + assert(vec >= 2); + vec = 2u; + // The max width of a buffer atomic op is 64-bits + // Some types like F32 don't have a 2x vectorized version + } else if (valueElemTy.isF32() || valueElemTy.isF64() || + valueElemTy.isInteger(32) || valueElemTy.isInteger(64)) { + vec = 1u; + } + + // Get the offsets and value + SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); + SmallVector valueElems = unpackLLElements(loc, llData, rewriter); + + // Get the mask + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + // We need to manually emit memory fences (LLVM doesn't do this for buffer + // ops) see: https://llvm.org/docs/AMDGPUUsage.html#memory-model-gfx942 + auto memOrdering = op.getSem(); + auto atomicMemOrdering = getMemoryOrdering(memOrdering); + auto rel = LLVM::AtomicOrdering::release; + auto acq = LLVM::AtomicOrdering::acquire; + + bool emitReleaseFence = false; + bool emitAcquireFence = false; + switch (memOrdering) { + case MemSemantic::RELAXED: + // In this case, no memory fences are needed + break; + case MemSemantic::RELEASE: + emitReleaseFence = true; + break; + case MemSemantic::ACQUIRE: + emitAcquireFence = true; + break; + case MemSemantic::ACQUIRE_RELEASE: + emitAcquireFence = true; + emitReleaseFence = true; + default: + // default == acq_rel, so we emit the same barriers + emitAcquireFence = true; + emitReleaseFence = true; + } + + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride); + Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + SmallVector loadedVals; + + // set the scope + auto memScope = op.getScope(); + auto scopeStr = ""; + switch (memScope) { + // System scope is not supported yet + case MemSyncScope::SYSTEM: + return rewriter.notifyMatchFailure( + op, "System memory scope is not supported for Buffer Atomic RMW"); + case MemSyncScope::GPU: + scopeStr = "agent"; + break; + case MemSyncScope::CTA: + scopeStr = "workgroup"; + break; + default: + return rewriter.notifyMatchFailure( + op, "Unsupported memory scope for Buffer Atomic RMW"); + } + + StringAttr scope = mlir::StringAttr::get(loc.getContext(), scopeStr); + + if (emitReleaseFence) + rewriter.create(loc, TypeRange{}, rel, scope); + + mlir::Operation *lastRMWOp; + MLIRContext *ctx = rewriter.getContext(); + GCNBuilder waitcntBuilder; + + // Triton supports three scopes for atomic access + // 1. System + // 2. GPU (default) + // 3. CTA (i.e., threadblock or warp-group) + // + // Currently, the AMD backend emits atomics with agent-scope. + // + // The following properties are used to emit proper synchronization + // primitives between sequential buffer atomics See: Memory Model GFX942 + // (MI300 series) + // https://llvm.org/docs/AMDGPUUsage.html#memory-model-gfx942: + // + // buffer/global/flat_load/store/atomic instructions to global memory are + // termed vector memory operations. + // + // 1. Vector memory operations access a single vector L1 cache shared by + // all SIMDs a CU. + // No special action is required for coherence between wavefronts in the + // same work-group since they execute on the same CU. + // + // 2. Each CU has a separate request queue per channel for its associated + // L2. + // Therefore, the vector and scalar memory operations performed by + // wavefronts executing with different L1 caches and the same L2 cache + // can be reordered relative to each other. A `s_waitcnt vmcnt(0)` is + // required to ensure synchronization between vector memory operations of + // different CUs. It ensures a previous vector memory operation has + // completed before executing a subsequent vector memory or LDS operation + // and so can be used to meet the requirements of acquire and release. + // + // 3. Atomic read-modify-write instructions implicitly bypass the L1 cache + // (specific to gfx942) + // Therefore, they do not use the sc0 bit for coherence and instead use + // it to indicate if the instruction returns the original value being + // updated. They do use sc1 to indicate system or agent scope coherence. + // See the cache modifiers word in BufferEmitter::fillCommonArgs for + // more details. + // + // In summary: + // 1. We have to emit memory fences (i.e., acq/rel/acq_rel) before and after + // our buffer atomics. + // 2. Because buffer atomic rmw ops skip the l1 cache, s_waitcnt vmcnt(0) is + // sufficient for synchronization between instructions. + // We don't need to invalidate L1 between these ops on GFX942, just after + // (i.e., we can skip `buffer_wbinvl1_vol`) + // 3. We don't have to explicitly write to the l2 cache because + // `s_waitcnt vmcnt(0)` already does this as-per the MI300/CDNA3 ISA + // docs: "Decremented for reads when the data has been written back to + // the VGPRs, and for writes when the data has been written to the L2 + // cache. Ordering: Memory reads and writes return in the order they were + // issued, including mixing reads and writes" + // 4. We set GLC=1, to return the old value. Atomics in GFX942 execute with + // either device (default) or system scope (controlled by the sc1 flag). + // This is distinct from the memory scope of the atomic (i.e, the memory + // fences which appear before/after the ops). + + if (memScope == MemSyncScope::GPU) { + waitcntBuilder.create<>("s_waitcnt vmcnt(0)")->operator()(); + } else if (memScope == MemSyncScope::CTA) { + // TODO: Within a CTA we can possibly relax this? + waitcntBuilder.create<>("s_waitcnt vmcnt(0)")->operator()(); + } + + // Check if the op has users, if it does we set GLC=1, otherwise GLC=0 + auto opUsers = op.getResult().getUsers(); + auto hasUsers = std::distance(opUsers.begin(), opUsers.end()) > 0; + + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + Value pred = mask ? b.and_(maskElems[vecStart], rDataMask) : rDataMask; + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); + // Create the store val + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + valueElems, vecStart); + + Value loadVal = bufferEmitter.emitAtomicRMW( + atomicRmwAttr, vecTy, rsrcDesc, offsetElems[vecStart], storeVal, pred, + hasUsers); + // Track the last op, so we can emit a fenceop after the loop + lastRMWOp = loadVal.getDefiningOp(); + + // To sync vector memory ops between CUs within an agent, we need an + // s_waitcnt skip doing this on the last iteration of the loop + // In the relaxed memory ordering, we don't need this barrier + if (vecStart < numElems - vec && (emitReleaseFence || emitAcquireFence)) { + Value inst = + waitcntBuilder.launch(rewriter, lastRMWOp->getLoc(), void_ty(ctx)); + lastRMWOp = inst.getDefiningOp(); + } + for (size_t ii = 0; ii < vec; ++ii) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, getTypeConverter()->getIndexType(), ii); + Value loaded = b.extract_element(valueElemTy, loadVal, vecIdx); + loadedVals.push_back(loaded); + } + } // end vec + + // Acquire Fence post-atomic + if (emitAcquireFence) + rewriter.create(lastRMWOp->getLoc(), TypeRange{}, acq, + scope); + + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, + rewriter, llvmResultStructTy); + + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct BufferStoreOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern< + triton::amdgpu::BufferStoreOp>::ConvertOpToLLVMPattern; + + BufferStoreOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::amdgpu::BufferStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); + + // original values + Value ptr = op.getPtr(); + Value offset = op.getOffsets(); + Value mask = op.getMask(); + Value data = op.getValue(); + auto cacheMod = op.getCache(); + + Value llPtr = adaptor.getPtr(); + Value llOffset = adaptor.getOffsets(); + Value llMask = adaptor.getMask(); + Value llData = adaptor.getValue(); + Value llStride = adaptor.getStride(); + + // Determine the vectorization size + Type valueTy = data.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + Type ptrType = getPointerTypeWithShape(ptr, offset); + + unsigned numElems = getTotalElemsPerThread(ptrType); + unsigned vec = getVectorSize(ptr, offset, axisAnalysisPass); + + // Get the offsets and value + SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); + SmallVector valueElems = unpackLLElements(loc, llData, rewriter); + + // Get the mask + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride); + Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + Value pred = mask ? b.and_(maskElems[vecStart], rDataMask) : rDataMask; + // Create the store val + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + valueElems, vecStart); + bufferEmitter.emitStore(rsrcDesc, offsetElems[vecStart], storeVal, pred, + cacheMod); + } // end vec + + rewriter.eraseOp(op); + return success(); + } +}; + +struct AtomicCASOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + AtomicCASOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // extract relevant info from Module + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + Value ptr = op.getPtr(); + + Value llPtr = adaptor.getPtr(); + Value llCmp = adaptor.getCmp(); + Value llVal = adaptor.getVal(); + + // prep data by unpacking to get data ready + auto ptrElements = unpackLLElements(loc, llPtr, rewriter); + auto cmpElements = unpackLLElements(loc, llCmp, rewriter); + auto valElements = unpackLLElements(loc, llVal, rewriter); + + auto memOrdering = op.getSem(); + auto atomicMemOrdering = getMemoryOrdering(memOrdering); + auto scope = op.getScope(); + auto scopeStr = getAMDGPUMemScopeStr(scope); + if (!scopeStr) + return rewriter.notifyMatchFailure(op, "Unknown AMDGPU memory scope"); + + // deal with tensor or scalar + auto valueTy = op.getResult().getType(); + auto TensorTy = dyn_cast(valueTy); + Type valueElemTy = + TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType()) + : valueTy; + auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); + auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); + // vec = 1 for scalar + auto vec = getVectorSize(op.getPtr(), axisAnalysisPass); + // tensor + if (TensorTy) { + auto valTy = cast(op.getVal().getType()); + vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); + } + + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + auto vecTy = vec_ty(valueElemTy, vec); + SmallVector resultVals(elemsPerThread); + + // atomic ops + for (size_t i = 0; i < elemsPerThread; i += vec) { + Value casVal = b.undef(vecTy); + for (int ii = 0; ii < vec; ++ii) { + Value iiVal = createIndexAttrConstant( + rewriter, loc, getTypeConverter()->getIndexType(), ii); + casVal = b.insert_element(vecTy, casVal, valElements[i + ii], iiVal); + } + + Value casPtr = ptrElements[i]; + Value casCmp = cmpElements[i]; + casVal = valElements[i]; + + // use op + if (TensorTy) { // for tensor + auto retType = vec == 1 ? valueElemTy : vecTy; + // TODO: USE ATOMIC CAS OP on Tensor + auto successOrdering = atomicMemOrdering; + auto failureOrdering = LLVM::AtomicOrdering::monotonic; + auto cmpxchg = rewriter.create( + loc, casPtr, casCmp, casVal, successOrdering, failureOrdering, + StringRef(scopeStr.value())); + + // Extract the new_loaded value from the pair. + Value ret = b.extract_val(valueElemTy, cmpxchg, i); + + for (int ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = + vec == 1 ? ret + : b.extract_element(valueElemTy, ret, b.i32_val(ii)); + } + } else { // for scalar + // Build blocks to bypass the atomic instruction for ~rmwMask. + auto *curBlock = rewriter.getInsertionBlock(); + auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint()); + auto *atomicBlock = rewriter.createBlock( + curBlock->getParent(), std::next(Region::iterator(curBlock))); + + // Fill entry block with global memory barrier and conditional branch. + rewriter.setInsertionPointToEnd(curBlock); + auto tid = getThreadId(rewriter, loc); + Value pred = b.icmp_eq(tid, b.i32_val(i)); + rewriter.create(loc, pred, atomicBlock, endBlock); + + // Build main block with atomic_cmpxchg. + rewriter.setInsertionPointToEnd(atomicBlock); + + auto successOrdering = LLVM::AtomicOrdering::acq_rel; + auto failureOrdering = LLVM::AtomicOrdering::monotonic; + auto cmpxchg = rewriter.create( + loc, casPtr, casCmp, casVal, successOrdering, failureOrdering, + StringRef("agent")); + + if (atomicNeedsSharedMemory(op.getResult())) { + // Extract the new_loaded value from the pair. + Value newLoaded = b.extract_val(valueElemTy, cmpxchg, 0); + Value atomPtr = + getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + b.store(newLoaded, atomPtr); + } + + rewriter.create(loc, ValueRange(), endBlock); + + // Build the last block: synced load from shared memory, exit. + rewriter.setInsertionPointToStart(endBlock); + + if (!atomicNeedsSharedMemory(op.getResult())) { + rewriter.eraseOp(op); + return success(); + } + + GCNBuilder BuilderMemfenceLDS; + BuilderMemfenceLDS.create<>("s_waitcnt lgkmcnt(0)")->operator()(); + BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx)); + b.barrier(); + Value atomPtr = + getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + Value ret = b.load(valueElemTy, atomPtr); + rewriter.replaceOp(op, {ret}); + } + } + + // replace op + if (TensorTy) { + Type structTy = getTypeConverter()->convertType(TensorTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, structTy); + rewriter.replaceOp(op, {resultStruct}); + } + return success(); + } +}; + +bool supportsGlobalAtomicF16PackedAndDpp(ISAFamily isaFamily) { + switch (isaFamily) { + case ISAFamily::CDNA1: + case ISAFamily::CDNA2: + case ISAFamily::CDNA3: + case ISAFamily::CDNA4: + return true; + default: + break; + } + return false; +} + +Value generateI32DppMove(PatternRewriter &rewriter, Value val, int dppCtrl) { + assert(val.getType().isInteger(32)); + auto loc = val.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value old = b.i32_val(0); + int rowMask = 0b1111; // enable all rows + int bankMask = 0b1111; // enable all banks + bool boundCtrl = false; + auto dppMovOp = rewriter.create( + loc, i32_ty, old, val, dppCtrl, rowMask, bankMask, boundCtrl); + return dppMovOp.getResult(); +} + +Value shiftLeftI32ByDpp(PatternRewriter &rewriter, Value val) { + return generateI32DppMove(rewriter, val, 0x101); // shift left 1 lane +} + +Value shiftRightI32ByDpp(PatternRewriter &rewriter, Value val) { + return generateI32DppMove(rewriter, val, 0x111); // shift right 1 lane +} + +Value generatePopcount64(PatternRewriter &rewriter, Value val) { + auto loc = val.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value m1 = b.i64_val(0x5555555555555555); // binary: 0101 0101.. + Value m2 = b.i64_val(0x3333333333333333); // binary: 0011 0011.. + Value m4 = b.i64_val(0x0f0f0f0f0f0f0f0f); // binary: 0000 1111.. + // binary: 0000 0001 0000 0001.. + Value h01 = b.i64_val(0x0101010101010101); + // put count of each 2 bits into those 2 bits + val = b.sub(val, b.and_(m1, b.lshr(val, b.i64_val(1)))); + // put count of each 4 bits into those 4 bits + val = b.add(b.and_(val, m2), b.and_(b.lshr(val, b.i64_val(2)), m2)); + // put count of each 8 bits into those 8 bits + val = b.and_(b.add(val, b.lshr(val, b.i64_val(4))), m4); + // left 8 bits of x + (x<<8) + (x<<16) + (x<<24) + ... + return b.lshr(b.mul(val, h01), b.i64_val(56)); +} + +Value genReadFirstLane(PatternRewriter &rewriter, Value v) { + auto loc = v.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + std::string intrinsic = "llvm.amdgcn.readfirstlane"; + return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, i32_ty, v) + ->getResult(0); +} + +Value genPermute(PatternRewriter &rewriter, Value v, Value dst) { + auto loc = v.getLoc(); + std::string intrinsic = "llvm.amdgcn.ds.permute"; + return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, i32_ty, + ValueRange{dst, v}) + ->getResult(0); +} + +Value genBPermute(PatternRewriter &rewriter, Value v, Value dst) { + auto loc = v.getLoc(); + std::string intrinsic = "llvm.amdgcn.ds.bpermute"; + return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, i32_ty, + ValueRange{dst, v}) + ->getResult(0); +} + +template +Value genI32TiledOp(PatternRewriter &rewriter, Generator genCall, + Value argToSplit, Values... args) { + auto loc = argToSplit.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Type ty = argToSplit.getType(); + size_t tySize = ty.getIntOrFloatBitWidth(); + size_t i32Size = i32_ty.getIntOrFloatBitWidth(); + size_t count = tySize / i32Size; + assert(tySize % i32Size == 0 && count > 0 && + "Unalligned types are not supported yet."); + Type i32VecValTy = vec_ty(i32_ty, count); + Value vec = b.undef(i32VecValTy); + Value valCasted = b.bitcast(argToSplit, i32VecValTy); + for (int i = 0; i < count; i++) { + Value subVal = b.extract_element(i32_ty, valCasted, b.i32_val(i)); + Value result = genCall(rewriter, subVal, args...); + vec = b.insert_element(i32VecValTy, vec, result, b.i32_val(i)); + } + return b.bitcast(vec, ty); +} + +Value genPrefixSum(PatternRewriter &rewriter, Value v0) { + auto loc = v0.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value old = b.i32_val(0); + + Value v1 = v0; + // v_add_f32 v1, v0, v0 row_shr:1 bound_ctrl:0 + Value tmp = rewriter.create(loc, i32_ty, old, v0, 0x111, + 0xF, 0xF, false); + v1 = b.add(v1, tmp); + // v_add_f32 v1, v0, v1 row_shr:2 bound_ctrl:0 + tmp = rewriter.create(loc, i32_ty, old, v0, 0x112, 0xF, + 0xF, false); + v1 = b.add(v1, tmp); + // v_add_f32 v1, v0, v1 row_shr:3 bound_ctrl:0 + tmp = rewriter.create(loc, i32_ty, old, v0, 0x113, 0xF, + 0xF, false); + v1 = b.add(v1, tmp); + + // v_add_f32 v1, v1, v1 row_shr:4 bank_mask:0xe + tmp = rewriter.create(loc, i32_ty, old, v1, 0x114, 0xF, + 0xE, true); + v1 = b.add(v1, tmp); + + // v_add_f32 v1, v1, v1 row_shr:8 bank_mask:0xc + tmp = rewriter.create(loc, i32_ty, old, v1, 0x118, 0xF, + 0xC, true); + v1 = b.add(v1, tmp); + + // v_add_f32 v1, v1, v1 row_bcast:15 row_mask:0xa + tmp = rewriter.create(loc, i32_ty, old, v1, 0x142, 0xA, + 0xF, true); + v1 = b.add(v1, tmp); + + // v_add_f32 v1, v1, v1 row_bcast:31 row_mask:0xc + tmp = rewriter.create(loc, i32_ty, old, v1, 0x143, 0xC, + 0xF, true); + v1 = b.add(v1, tmp); + + return v1; +} + +struct AtomicRMWOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + AtomicRMWOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + /// Try to match the mlir::triton::RMWOp to LLVM::AtomicBinOp. + static std::optional matchAtomicOp(RMWOp atomicOp) { + switch (atomicOp) { + case RMWOp::AND: + return LLVM::AtomicBinOp::_and; + case RMWOp::OR: + return LLVM::AtomicBinOp::_or; + case RMWOp::XOR: + return LLVM::AtomicBinOp::_xor; + case RMWOp::ADD: + return LLVM::AtomicBinOp::add; + case RMWOp::FADD: + return LLVM::AtomicBinOp::fadd; + case RMWOp::MAX: + return LLVM::AtomicBinOp::max; + case RMWOp::MIN: + return LLVM::AtomicBinOp::min; + case RMWOp::UMAX: + return LLVM::AtomicBinOp::umax; + case RMWOp::UMIN: + return LLVM::AtomicBinOp::umin; + case RMWOp::XCHG: + return LLVM::AtomicBinOp::xchg; + default: + return std::nullopt; + } + llvm_unreachable("Invalid RMWOp"); + } + + LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + + auto atomicRmwAttr = op.getAtomicRmwOp(); + Value ptr = op.getPtr(); + Value val = op.getVal(); + + Value llPtr = adaptor.getPtr(); + Value llVal = adaptor.getVal(); + Value llMask = adaptor.getMask(); + + auto valElements = unpackLLElements(loc, llVal, rewriter); + auto ptrElements = unpackLLElements(loc, llPtr, rewriter); + SmallVector maskElements; + if (llMask) + maskElements = unpackLLElements(loc, llMask, rewriter); + + Value opResult = op.getResult(); + auto tensorTy = dyn_cast(opResult.getType()); + Type valueElemTy = + tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType()) + : opResult.getType(); + const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth(); + auto elemsPerThread = getTotalElemsPerThread(val.getType()); + // vec = 1, numElements = 1 for scalar + auto vec = getVectorSize(ptr, axisAnalysisPass); + int numElems = 1; + Type packF16Ty = vec_ty(valueElemTy, 2); + + // CDNA3/CDNA4 arch allows to accelerate its atomics with LDS reduction + // algorithm, which is only applicable for atomics with no return. Otherwise + // we have to deal with an additional overhead. + bool enableIntraWaveReduce = + llvm::is_contained({ISAFamily::CDNA3, ISAFamily::CDNA4}, + targetInfo.getISAFamily()) && + tensorTy && opResult.use_empty(); + + // TODO: support data types less than 32 bits + enableIntraWaveReduce &= valueElemNbits >= 32; + + // In the case of unpaired f16 elements utilize dpp instructions to + // accelerate atomics. Here is an algorithm of lowering + // tt::atomicRmwOp(%ptr, %val, %mask): + // 0. Group thread by pairs. Master thread is (tid % 2 == 0); + // 1. All the threads send %val to (tid - 1) thread via dppUpdateOp shl, so + // all the masters receive value from secondary threads; + // 2. Take into account parity in the %mask value, build control flow + // structures according to it; + // 3. Generate llvm::atomicRmwOp in the threads enabled by %mask value; + // 4. All the threads send result of generated operation to (tid + 1) thread + // via dppUpdateOp shl, so all secondary thread also receive their + // result. + // + // This approach enables us to use half the active threads committing atomic + // requests to avoid generating of code providing unified access to f16 + // element and reduce contention. + bool useDppForPackedF16 = false; + // tensor + if (tensorTy) { + auto valTy = cast(val.getType()); + bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16(); + unsigned availableVecSize = isF16Ty ? 2 : 1; + vec = std::min(vec, availableVecSize); + // Force F16 packing in the case it's not coming in as packed, but the + // ISA can support packed atomic instructions. + useDppForPackedF16 = + supportsGlobalAtomicF16PackedAndDpp(targetInfo.getISAFamily()) && + vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD && + !enableIntraWaveReduce; + // mask + numElems = tensorTy.getNumElements(); + } + Value mask = b.true_val(); + auto tid = getThreadId(rewriter, loc); + mask = b.and_(mask, b.icmp_slt(b.mul(tid, b.i32_val(elemsPerThread)), + b.i32_val(numElems))); + + auto memOrdering = op.getSem(); + auto scope = op.getScope(); + auto atomicMemOrdering = getMemoryOrdering(memOrdering); + + std::optional scopeStr = getAMDGPUMemScopeStr(scope); + if (!scopeStr) + return rewriter.notifyMatchFailure(op, "Unknown AMDGPU memory scope"); + + auto vecTy = vec_ty(valueElemTy, vec); + auto retType = vec == 1 ? valueElemTy : vecTy; + retType = useDppForPackedF16 ? packF16Ty : retType; + SmallVector resultVals(elemsPerThread); + for (size_t i = 0; i < elemsPerThread; i += vec) { + Value rmwPtr = ptrElements[i]; + // TODO: in case llMask is zero we can create only one branch for all + // elemsPerThread. + Value rmwMask = llMask ? b.and_(mask, maskElements[i]) : mask; + + Value i64Ones = b.i64_val(~uint64_t(0)); + Value i64Zeros = b.i64_val(0); + Value operand; + Value rightNeighbourPtr; + Value enablePackedOpt; + if (useDppForPackedF16) { + Value isOddI32 = b.urem(tid, b.i32_val(2)); + // First check if odd threads hold adjacent ptrs to even ones. + Value castedAddr = b.ptrtoint(i64_ty, rmwPtr); + // Set casted addr to all ones if the thread is disabled. + castedAddr = b.select(rmwMask, castedAddr, i64Ones); + + // Move %val to left neighbour to proceed packed atomic further. + Value packedVal = b.null(packF16Ty); + packedVal = + b.insert_element(packF16Ty, packedVal, valElements[i], isOddI32); + // Pack to i32 type to simplify transaction. + packedVal = b.bitcast(packedVal, i32_ty); + // Zero operands for disabled threads to make addition no op. + packedVal = b.select(rmwMask, packedVal, b.i32_val(0)); + Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal); + + Value rightNeighbourAddr = + genI32TiledOp(rewriter, shiftLeftI32ByDpp, castedAddr); + + // Packing optimization only supported if following conditions are true: + // 1. address is aligned by 4 bytes + // 2. right neighbour has adjacent address + // 3. both threads are active + Value isAligned = + b.icmp_eq(b.urem(castedAddr, b.i64_val(4)), b.i64_val(0)); + Value neighbourAddrAdjacent = b.icmp_eq( + rightNeighbourAddr, + b.add(castedAddr, + b.i64_val(valueElemTy.getIntOrFloatBitWidth() / 8))); + Value neighbourEnabled = b.icmp_ne(i64Ones, rightNeighbourAddr); + Value bothEnabled = b.and_(neighbourEnabled, rmwMask); + enablePackedOpt = + b.and_(b.and_(isAligned, bothEnabled), neighbourAddrAdjacent); + + // Enable only the even threads. + Value anyEnabled = b.or_(neighbourEnabled, rmwMask); + // If one of the threads is disabled, use the neighbour's addr. + rightNeighbourAddr = + b.select(neighbourEnabled, rightNeighbourAddr, castedAddr); + castedAddr = b.select(rmwMask, castedAddr, rightNeighbourAddr); + + rmwMask = b.and_(anyEnabled, b.icmp_eq(isOddI32, b.i32_val(0))); + + // Unpack results back + rightNeighbourPtr = b.inttoptr(rmwPtr.getType(), rightNeighbourAddr); + rmwPtr = b.inttoptr(rmwPtr.getType(), castedAddr); + operand = b.bitcast(b.or_(packedVal, dppMoveRes), packF16Ty); + } else if (vec == 1) { + operand = valElements[i]; + } else { + operand = b.undef(vecTy); + for (size_t ii = 0; ii < vec; ++ii) + operand = b.insert_element(vecTy, operand, valElements[i + ii], + b.i32_val(ii)); + } + + Value undefVal = b.undef(retType); + // Build blocks to bypass the atomic instruction for ~rmwMask. + auto *curBlock = rewriter.getInsertionBlock(); + auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint()); + auto *atomicBlock = rewriter.createBlock( + curBlock->getParent(), std::next(Region::iterator(curBlock))); + endBlock->addArgument({retType}, {loc}); + + rewriter.setInsertionPointToEnd(curBlock); + // intraWave reduce optimization for atomic ops needs all active threads + // at the beginning of a wave. This is achieved as: + // 1. Compute the prefix sum of the mask, then each active lane gets a + // different value (offset) from its previous lane. + // 2. Multiply the mask and the offset, so only active lanes have a + // non-zero offset, and the offset is different in each active lane + // 3. Sub 1 from offset to get the idx each active lane is moved to + // 4. Call ds_permute to move active lanes to the beginning of a wave + // 5. Update mask of each lane + if (enableIntraWaveReduce) { + Value maskI32 = b.zext(i32_ty, rmwMask); + Value offset = genPrefixSum(rewriter, maskI32); + offset = b.mul(offset, maskI32); + auto layout = tensorTy.getEncoding(); + Value waveSize = b.i32_val(triton::gpu::getWarpSize(layout)); + offset = b.select(b.icmp_eq(offset, b.i32_val(0)), waveSize, offset); + Value idx = b.sub(offset, b.i32_val(1)); + idx = b.mul(idx, b.i32_val(4)); + operand = genI32TiledOp(rewriter, genPermute, operand, idx); + Value castedAddr = b.ptrtoint(i64_ty, rmwPtr); + castedAddr = genI32TiledOp(rewriter, genPermute, castedAddr, idx); + rmwPtr = b.inttoptr(rmwPtr.getType(), castedAddr); + + // update mask + Value maskFlag = targetInfo.ballot(rewriter, loc, i64_ty, rmwMask); + Value numActiveLanes = + b.trunc(i32_ty, generatePopcount64(rewriter, maskFlag)); + + Value laneID = b.urem(tid, waveSize); + rmwMask = b.icmp_ult(laneID, numActiveLanes); + } + rewriter.create(loc, rmwMask, atomicBlock, endBlock, + undefVal); + + rewriter.setInsertionPointToEnd(atomicBlock); + auto maybeKind = matchAtomicOp(atomicRmwAttr); + Value atom; + Value isVecOp; + if (enableIntraWaveReduce) { + atom = atomicIntraWaveReduce(rewriter, rmwPtr, operand, *maybeKind, + atomicMemOrdering, *scopeStr); + } else { + if (useDppForPackedF16) { + // Determine on the runtime what atomic intrinsic to execute: + // packed or regular. + auto *packedBlock = + atomicBlock->splitBlock(rewriter.getInsertionPoint()); + auto *regularBlock = + rewriter.createBlock(atomicBlock->getParent(), + std::next(Region::iterator(atomicBlock))); + rewriter.setInsertionPointToEnd(atomicBlock); + rewriter.create(loc, enablePackedOpt, packedBlock, + regularBlock); + + // Fill out the regular block, where we issue two atomic ops. + rewriter.setInsertionPointToEnd(regularBlock); + Value pairedOperand0 = + b.extract_element(valueElemTy, operand, b.i32_val(0)); + Value pairedOperand1 = + b.extract_element(valueElemTy, operand, b.i32_val(1)); + Value atomNonVec0 = rewriter.create( + loc, *maybeKind, rmwPtr, pairedOperand0, atomicMemOrdering, + *scopeStr); + Value atomNonVec1 = rewriter.create( + loc, *maybeKind, rightNeighbourPtr, pairedOperand1, + atomicMemOrdering, *scopeStr); + Value packedRes = b.undef(packF16Ty); + packedRes = + b.insert_element(packF16Ty, packedRes, atomNonVec0, b.i32_val(0)); + packedRes = + b.insert_element(packF16Ty, packedRes, atomNonVec1, b.i32_val(1)); + rewriter.create(loc, packedRes, endBlock); + + // Start to fill out the packed block. + rewriter.setInsertionPointToEnd(packedBlock); + } + atom = rewriter.create( + loc, *maybeKind, rmwPtr, operand, atomicMemOrdering, *scopeStr); + } + + if (!tensorTy) { + if (atomicNeedsSharedMemory(op.getResult())) { + Value atomPtr = + getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + b.store(atom, atomPtr); + } + } + rewriter.create(loc, atom, endBlock); + + rewriter.setInsertionPointToStart(endBlock); + Value retVal = endBlock->getArgument(0); + if (tensorTy) { + if (useDppForPackedF16) { + // Return packed to i32 result after atomic operation back from master + // lane. + auto packedRet = b.bitcast(retVal, i32_ty); + Value dppMovRes = shiftRightI32ByDpp(rewriter, packedRet); + // Unpack results back + Value unpackedDppRes = b.bitcast(dppMovRes, packF16Ty); + retVal = b.insert_element( + packF16Ty, retVal, + b.extract_element(valueElemTy, unpackedDppRes, b.i32_val(1)), + b.i32_val(1)); + resultVals[i] = + b.extract_element(valueElemTy, retVal, b.urem(tid, b.i32_val(2))); + } else { + for (int ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = + vec == 1 + ? retVal + : b.extract_element(valueElemTy, retVal, b.i32_val(ii)); + } + } + } else { + if (!atomicNeedsSharedMemory(op.getResult())) { + rewriter.eraseOp(op); + return success(); + } + Value atomPtr = + getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + b.barrier(); + Value ret = b.load(valueElemTy, atomPtr); + rewriter.replaceOp(op, {ret}); + } + } + if (tensorTy) { + Type structTy = getTypeConverter()->convertType(tensorTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, structTy); + rewriter.replaceOp(op, {resultStruct}); + } + return success(); + } + +private: + Value atomicIntraWaveReduce(PatternRewriter &rewriter, Value rmwPtr, + Value operand, LLVM::AtomicBinOp opKind, + LLVM::AtomicOrdering memOrdering, + StringRef scope) const { + // This approach minimizes intra-warp thread contention when accessing + // global memory pointers. It is particularly advantageous for certain ISA + // families, such as CDNA3. The algorithm follows these steps: + // 1. Analyze thread groups and their relative positions: + // 1.1. Consider groups of threads sharing identical pointers using + // `readfirstlane` and ballot `intrinsics`. + // 1.2. Compute parameters to form contiguous groups and further optimize + // them. + // 1.3. Disable threads that have already been processed. + // 1.4. If thread was not considered, jump to `1.1.`. + // 2. Form contiguous groups: + // Use `permute` instructions to organize threads within the wavefront + // into continuous groups. + // 4. Reduce Groups to Leader threads: + // Apply `bpermute` and operation-specific arithmetic based on the opKind + // to consolidate group data into leader threads. + // 5. Perform global atomic operations by leader threads. + auto loc = operand.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Type operandElemType = operand.getType(); + Type origPtrType = rmwPtr.getType(); + + rmwPtr = b.ptrtoint(i64_ty, rmwPtr); + + auto *curBlock = rewriter.getInsertionBlock(); + auto *atomicBlock = curBlock->splitBlock(rewriter.getInsertionPoint()); + atomicBlock->addArgument(i64_ty, loc); + atomicBlock->addArgument(operandElemType, loc); + auto *initLoop = rewriter.createBlock( + curBlock->getParent(), std::next(Region::iterator(curBlock))); + + rewriter.setInsertionPointToEnd(curBlock); + + // check how many adjacent address are in the wave + Value rightNeighbourAddr = + genI32TiledOp(rewriter, shiftLeftI32ByDpp, rmwPtr); + Value elemSize = b.i64_val(operandElemType.getIntOrFloatBitWidth() / 8); + Value isNeighbour = b.icmp_eq(rightNeighbourAddr, b.add(rmwPtr, elemSize)); + Value neighbourFlag = targetInfo.ballot(rewriter, loc, i64_ty, isNeighbour); + Value numNeighbours = + b.trunc(i32_ty, generatePopcount64(rewriter, neighbourFlag)); + // Heuristic that atomic_add is optimizated only if the number of + // neighbouring addresses in a wave is less than 32. + // TODO: Calculate actual number of difference addresses in a wave. + Value optAtomic = b.icmp_ult(numNeighbours, b.i32_val(32)); + + rewriter.create(loc, optAtomic, initLoop, atomicBlock, + ValueRange({rmwPtr, operand})); + rewriter.setInsertionPointToEnd(initLoop); + + auto *afterLoopBlock = initLoop->splitBlock(rewriter.getInsertionPoint()); + afterLoopBlock->addArgument(i32_ty, loc); // idx + afterLoopBlock->addArgument(i32_ty, loc); // cnt + afterLoopBlock->addArgument(int_ty(1), loc); // isLeader + + auto *loopBody = rewriter.createBlock( + initLoop->getParent(), std::next(Region::iterator(initLoop))); + loopBody->addArgument(i32_ty, loc); + + rewriter.setInsertionPointToEnd(initLoop); + rewriter.create(loc, b.i32_val(0), loopBody); + + // Greed search of same addr within wavefront. Also collect auxiliary + // information about relative position: + // - idx in a group + base laneId. This param is required to form continuous + // groups further; + // - cnt of remaining threads in a group after current thread; + // - leadership status of the current thread. + rewriter.setInsertionPointToEnd(loopBody); + // `readfirstlane` considers only enabled threads + Value chosen = genI32TiledOp(rewriter, genReadFirstLane, rmwPtr); + // this flag is required to disable thread if we have already checked its + // pointer + Value done = b.icmp_eq(chosen, rmwPtr); + Value mask = targetInfo.ballot(rewriter, loc, i64_ty, done); + Value start = loopBody->getArgument(0); + Value cnt = b.trunc(i32_ty, generatePopcount64(rewriter, mask)); + Value mbcntLoRes = rewriter + .create( + loc, i32_ty, b.trunc(i32_ty, mask), b.i32_val(0)) + ->getResult(0); + Value idx = rewriter.create( + loc, i32_ty, b.trunc(i32_ty, b.lshr(mask, b.i64_val(32))), mbcntLoRes); + Value base = b.add(start, cnt); + Value leader = b.icmp_eq(idx, b.i32_val(0)); + cnt = b.sub(cnt, idx); + idx = b.add(idx, start); + rewriter.create(loc, done, afterLoopBlock, + ValueRange({idx, cnt, leader}), loopBody, + ValueRange({base})); + + rewriter.setInsertionPointToEnd(afterLoopBlock); + + Value idxRes = afterLoopBlock->getArgument(0); + Value cntRes = afterLoopBlock->getArgument(1); + Value leaderRes = afterLoopBlock->getArgument(2); + Value idxScaledForPermute = b.mul(idxRes, b.i32_val(4)); + + // Make groups continuous + rmwPtr = genI32TiledOp(rewriter, genPermute, rmwPtr, idxScaledForPermute); + operand = genI32TiledOp(rewriter, genPermute, operand, idxScaledForPermute); + // Actualize auxiliary info as well + Value packedRoleInfo = genI32TiledOp( + rewriter, genPermute, + b.or_(b.zext(i32_ty, leaderRes), + b.or_(idxScaledForPermute, b.shl(cntRes, b.i32_val(8)))), + idxScaledForPermute); + idxScaledForPermute = packedRoleInfo; + cntRes = b.and_(b.lshr(packedRoleInfo, b.i32_val(8)), b.i32_val(0xff)); + leaderRes = b.icmp_ne(b.and_(packedRoleInfo, b.i32_val(1)), b.i32_val(0)); + + auto *afterRedBlock = + afterLoopBlock->splitBlock(rewriter.getInsertionPoint()); + afterRedBlock->addArgument(operandElemType, loc); + auto *partialReductionBlock = + rewriter.createBlock(afterLoopBlock->getParent(), + std::next(Region::iterator(afterLoopBlock))); + rewriter.setInsertionPointToEnd(afterLoopBlock); + Value reductionCond = + b.icmp_ne(targetInfo.ballot(rewriter, loc, i64_ty, + b.icmp_ne(cntRes, b.i32_val(1))), + b.i64_val(0)); + rewriter.create(loc, reductionCond, partialReductionBlock, + afterRedBlock, operand); + rewriter.setInsertionPointToEnd(partialReductionBlock); + + auto performOpIfCond = [&](Value res, Value v, Value cond) -> Value { + Type ty = v.getType(); + assert(ty == res.getType()); + Value notCond = b.icmp_eq(cond, b.false_val()); + switch (opKind) { + case LLVM::AtomicBinOp::_and: + // res &= cond ? v : 1111.. + return b.and_(res, + b.or_(v, b.sub(b.int_val(ty.getIntOrFloatBitWidth(), 0), + b.zext(ty, notCond)))); + case LLVM::AtomicBinOp::_or: + // res |= cond ? v : 0 + return b.or_(res, b.mul(v, b.zext(ty, cond))); + case LLVM::AtomicBinOp::_xor: + // res ^= cond ? v : 0 + return b.xor_(res, b.mul(v, b.zext(ty, cond))); + case LLVM::AtomicBinOp::add: + // res += cond ? v : 0 + return b.add(res, b.mul(v, b.zext(ty, cond))); + case LLVM::AtomicBinOp::fadd: + // res += cond ? v : 0 + return b.fadd( + res, b.fmul(v, b.inttofloat( + ty, b.zext(int_ty(ty.getIntOrFloatBitWidth()), + cond)))); + case LLVM::AtomicBinOp::max: + case LLVM::AtomicBinOp::umax: + // res = cond ? umax(v, res) : res + return b.or_(b.mul(res, b.zext(ty, notCond)), + b.mul(b.umax(v, res), b.zext(ty, cond))); + case LLVM::AtomicBinOp::min: + case LLVM::AtomicBinOp::umin: + // res = cond ? umin(v, res) : res + return b.or_(b.mul(res, b.zext(ty, notCond)), + b.mul(b.umin(v, res), b.zext(ty, cond))); + case LLVM::AtomicBinOp::xchg: + // res = cond ? v : res + return b.or_(b.mul(res, b.zext(ty, notCond)), + b.mul(v, b.zext(ty, cond))); + default: + llvm_unreachable("Unsupported atomic binary operation."); + } + }; + Value acc = operand; + // Reduce to leader thread + for (int i = 32; i != 0; i /= 2) { + Value tmp = genI32TiledOp(rewriter, genBPermute, acc, + b.add(idxScaledForPermute, b.i32_val(i * 4))); + acc = performOpIfCond(acc, tmp, b.icmp_ult(b.i32_val(i), cntRes)); + } + + rewriter.create(loc, acc, afterRedBlock); + rewriter.setInsertionPointToEnd(afterRedBlock); + + auto *endBlock = afterRedBlock->splitBlock(rewriter.getInsertionPoint()); + endBlock->addArgument(operandElemType, loc); + rewriter.setInsertionPointToEnd(afterRedBlock); + Value leaderCond = leaderRes; + Value defaultRes = b.undef(operandElemType); + rewriter.create( + loc, leaderCond, atomicBlock, + ValueRange({rmwPtr, afterRedBlock->getArgument(0)}), endBlock, + ValueRange({defaultRes})); + rewriter.setInsertionPointToEnd(atomicBlock); + // Utilize global atomic only by leader threads + Value addr = atomicBlock->getArgument(0); + Value atomAddr = b.inttoptr(origPtrType, addr); + Value atom = rewriter.create( + loc, opKind, atomAddr, atomicBlock->getArgument(1), memOrdering, scope); + rewriter.create(loc, atom, endBlock); + rewriter.setInsertionPointToStart(endBlock); + + return endBlock->getArgument(0); + } +}; + +struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern { + AsyncWaitOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + switch (targetInfo.getISAFamily()) { + case ISAFamily::CDNA1: + case ISAFamily::CDNA2: + case ISAFamily::CDNA3: + case ISAFamily::CDNA4: + break; + default: + return rewriter.notifyMatchFailure( + op, "Only supported on CDNA target architecture"); + } + + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // global.load.lds uses vmcnt to synchronize + // The rocdl op stores all available counters in a single int32 value (v). + // The vmcnt (6 bits) is split into a lower 3:0 and higher 5:4 parts. + // The lower part is stored in bits 3:0 of v and the higher part in bits + // 15:14. We have to set all other bits in v to 1 to signal we are not + // interested in those. + + int vmCnt = op.getNum(); + if (vmCnt >= 64) { + return emitError(loc, "AsyncWait does not support values >= 64"); + } + + // Extract low and high bits and combine while setting all other bits to 1 + unsigned lowBits = vmCnt & 0xF; + unsigned highBits = vmCnt >> 4 << 14; + unsigned otherCnts = ~0xC00F; // C00F has bits 15:14 and 3:0 set + unsigned waitValue = lowBits | highBits | otherCnts; + + rewriter.create(loc, waitValue); + + // Drop the result AsyncToken + rewriter.replaceOp(op, b.i32_val(0)); + return success(); + } + +private: + const AMD::TargetInfo &targetInfo; +}; + +struct AsyncCommitGroupOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Drop the result AsyncToken + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + rewriter.replaceOp(op, b.i32_val(0)); + return success(); + } +}; + +} // namespace + +namespace mlir::triton::AMD { +void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) { + patterns.add( + typeConverter, targetInfo, axisInfoAnalysis, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, benefit); +} +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 000000000..b946d72b5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,294 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::MemDescType; + +namespace SharedToDotOperandMFMA { +Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, + Location loc, Value tensor, + DotOperandEncodingAttr bEncoding, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, Value thread); +} // namespace SharedToDotOperandMFMA + +namespace SharedToDotOperandWMMA { +Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, + Location loc, Value tensor, + DotOperandEncodingAttr bEncoding, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, Value thread); +} // namespace SharedToDotOperandWMMA + +namespace { +struct LocalLoadOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + triton::gpu::LocalLoadOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemDescType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (isa(dstLayout) && + isa( + cast(dstLayout).getParent())) { + return lowerSharedToDotOperand(op, adaptor, getTypeConverter(), rewriter); + } + return failure(); + } + +private: + /// Lower ttg.local_load in dot operand layout if the operand parent layout is + /// MFMA or WMMA. + /// + /// \returns value with packed loaded values or empty value if this local_load + /// is not supproted. + Value lowerSharedToDotOperandMMA( + triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + const DotOperandEncodingAttr &dotOperandLayout) const { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value src = op.getSrc(); + Value dst = op.getResult(); + auto llvmElemTy = typeConverter->convertType( + cast(src.getType()).getElementType()); + + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + Value res; + auto dopOpParent = dotOperandLayout.getParent(); + if (isa(dopOpParent)) { + auto sharedToDotConvert = isa(dopOpParent) + ? SharedToDotOperandMFMA::convertLayout + : SharedToDotOperandWMMA::convertLayout; + res = sharedToDotConvert(dotOperandLayout.getOpIdx(), rewriter, loc, src, + dotOperandLayout, smemObj, typeConverter, + getThreadId(rewriter, loc)); + } else { + assert(false && "unsupported layout found"); + } + return res; + } + + // shared -> matrix_core_dot_operand + LogicalResult + lowerSharedToDotOperand(triton::gpu::LocalLoadOp op, + triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + Value dst = op.getResult(); + auto dstTensorTy = cast(dst.getType()); + auto dotOperandLayout = + cast(dstTensorTy.getEncoding()); + + Value res = lowerSharedToDotOperandMMA(op, adaptor, typeConverter, rewriter, + dotOperandLayout); + if (!res) + return failure(); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct TransLocalLoadOpConversion + : public ConvertOpToLLVMPattern { +public: + TransLocalLoadOpConversion(const LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + PatternBenefit benefit = 2) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemDescType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (canUseTransLoad(srcTy, dstTy)) { + assert(checkPerformanceProperties(srcTy, dstTy)); + return lowerSharedToDotOperandTransLL(op, adaptor, getTypeConverter(), + rewriter); + } + return failure(); + } + +private: + bool checkLayoutProperties(MemDescType srcTy, RankedTensorType dstTy) const { + // Verify the layout properties required for using the ds_read_tr + // instruction. This instruction is used to load non-k contiguous tensors + // from shared memory into a dot layout with an MFMA layout parent. + auto dotEnc = llvm::dyn_cast(dstTy.getEncoding()); + if (!dotEnc) { + return false; + } + + auto mfmaEnc = llvm::dyn_cast(dotEnc.getParent()); + if (!mfmaEnc) { + return false; + } + + auto sharedEnc = + dyn_cast(srcTy.getEncoding()); + if (!sharedEnc) + return false; + + int rank = dstTy.getRank(); + const int kDim = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2; + return kDim != sharedEnc.getOrder()[0]; + } + + bool checkPerformanceProperties(MemDescType srcTy, + RankedTensorType dstTy) const { + // Single rate MFMA insts: + // fp16, bf16: mfma32x32x8, mfma16x16x16 + // int8: mfma32x32x16, mfma16x16x32 + // + // Double rate MFMA insts: + // fp16, bf16: mfma32x32x16, mfma16x16x32 + // i8: mfma32x32x32, mfma16x16x64 + // + // Check that double-rate MFMA instructions are used whenever possible. + // Single rate instructions should only be used if the K block size is not + // large enough. + auto dotEnc = llvm::cast(dstTy.getEncoding()); + auto mfmaEnc = llvm::cast(dotEnc.getParent()); + + int rank = dstTy.getRank(); + auto bitwidth = typeConverter->convertType(dstTy.getElementType()) + .getIntOrFloatBitWidth(); + int32_t kWidth = dotEnc.getKWidth(); + const int32_t mDim = mfmaEnc.getMDim(); + assert((mDim == 32 || mDim == 16) && "Invalid MFMA instruction dimension"); + + const int kFactor = 16 / bitwidth; + const int kSizeSingleRateMfma32 = 8 * kFactor; + const int kSizeSingleRateMfma16 = 16 * kFactor; + const int largeTileThreshold = + (mDim == 32) ? kSizeSingleRateMfma32 : kSizeSingleRateMfma16; + const auto shape = dstTy.getShape(); + const int kDim = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2; + + const bool isLargeTile = shape[kDim] > largeTileThreshold; + const int expectedKWidth = (isLargeTile ? 8 : 4) * kFactor; + return kWidth == expectedKWidth; + } + + bool canUseTransLoad(MemDescType srcTy, RankedTensorType dstTy) const { + auto bitwidth = typeConverter->convertType(dstTy.getElementType()) + .getIntOrFloatBitWidth(); + + // 1. Check GPU arch properties. + if (!targetInfo.canUseLDSTransLoad(bitwidth)) { + return false; + } + + // 2. Check layout properties. + if (!checkLayoutProperties(srcTy, dstTy)) { + return false; + } + + // 3. Check current limitations. + if (bitwidth != 16 && + (bitwidth != 8 || !dstTy.getElementType().isInteger())) { + return false; + } + + return true; + } + + LogicalResult + lowerSharedToDotOperandTransLL(triton::gpu::LocalLoadOp op, + triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto ctx = rewriter.getContext(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto dstTy = cast(op.getType()); + auto srcTy = cast(op.getSrc().getType()); + auto dotEnc = cast(dstTy.getEncoding()); + auto shape = dstTy.getShape(); + + auto llvmElemTy = typeConverter->convertType(dstTy.getElementType()); + auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + auto ldsTransLayout = chooseDsReadB64TrLayout(dotEnc, shape, bitwidth); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + SmallVector outVals; + SmallVector elemsI32; + mlir::Type retTy = dstTy; + bool valid = emitTransferBetweenRegistersAndShared( + ldsTransLayout, srcTy, llvmElemTy, + /*maxVecElems=*/std::nullopt, smemObj, loc, rewriter, targetInfo, + [&](VectorType vecTy, Value vecAddr) { + if (bitwidth == 16) { + auto dsReadOp = + rewriter.create(loc, vecTy, vecAddr); + Value vecVal = dsReadOp.getResult(); + for (int v = 0; v < vecTy.getNumElements(); v++) { + outVals.push_back( + b.extract_element(llvmElemTy, vecVal, b.i32_val(v))); + } + } else { + // pack elements in i32 vectors + auto numElems = vecTy.getNumElements(); + auto numElemsI32 = (numElems * bitwidth / 32); + auto i32VecTy = VectorType::get(numElemsI32, i32_ty); + + auto dsReadOp = + rewriter.create(loc, i32VecTy, vecAddr); + Value vecVal = dsReadOp.getResult(); + for (auto i = 0; i < numElemsI32; ++i) { + elemsI32.push_back( + b.extract_element(i32_ty, vecVal, b.i32_val(i))); + } + } + }); + + // unpack i32 vectors and cast to native type + if (bitwidth != 16) { + auto numElemsPerVec = 32 / bitwidth; + auto vecTy = vec_ty(llvmElemTy, numElemsPerVec); + for (int v = 0; v < static_cast(elemsI32.size()); ++v) { + auto vec = b.bitcast(elemsI32[v], vecTy); + for (int i = 0; i < numElemsPerVec; ++i) + outVals.push_back(b.extract_element(llvmElemTy, vec, b.i32_val(i))); + } + + retTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(outVals.size(), llvmElemTy)); + } + assert(valid && "Failed to emit LDS transpose load operations"); + Value result = packLLElements(loc, typeConverter, outVals, rewriter, retTy); + rewriter.replaceOp(op, result); + return success(); + } + +private: + const AMD::TargetInfo &targetInfo; +}; + +} // namespace + +void mlir::triton::AMD::populateMemoryOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfo &targetInfo, PatternBenefit benefit) { + PatternBenefit transBenefit = PatternBenefit(benefit.getBenefit() + 1); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, + transBenefit); +} diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp new file mode 100644 index 000000000..f4b23524f --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp @@ -0,0 +1,263 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "OptimizeLDSUtility.h" +#include "TargetInfo.h" +#include "TritonAMDGPUToLLVM/Passes.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/TritonGPUToLLVM/Patterns.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; + +namespace mlir::triton { +#define GEN_PASS_DEF_OPTIMIZEAMDLDSUSAGE +#include "TritonAMDGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton + +namespace { + +class OptimizeAMDLDSUsage + : public mlir::triton::impl::OptimizeAMDLDSUsageBase { + + int LDSLimit; + + // Try to reduce LDS usage of convert op by adding tmp layout in conversion: + // + // %1 = convert %0 (src layout -> dst layout) + // -> + // %1 = convert %0 (src layout -> tmp) + // %2 = convert %1 (tmp -> dst layout) + // + // The implicit LDS usage of convert op depends on src and dst layouts + // + // Consider mfma->blocked conversion as an example. + // + // tensor shape: [128, 128] + // mfma layout: warpsPerCTA = [1, 4], instrShape = [32, 32] + // blocked layout: sizePerThread = [1, 4], threadsPerWarp = [32, 2], + // warpsPerCTA = [4, 1] + // + // minimal mfma tile is: [1*32, 4*32] = [32, 128] + // minimal blocked tile is: [1*32*4, 4*2*1] = [128, 8] + // + // Roughtly scratch buffer shape for conversion is: + // [max(32, 128), max(128, 16)] = [128, 128]. + // + // This shape could be reduces by introducing intermediate + // layout and replacing old convert operations with two new conversions: + // + // %1 = convert %0 (mfma -> blocked) + // -> + // %1 = convert %0 (mfma -> tmp) + // %2 = convert %1 (tmp -> blocked) + // + // Let's consider tmp as blocked layout: + // sizePerThread = [1, 4], threadsPerWarp = [32, 2], warpsPerCTA = [1, 4] + // Tmp layout scratch buffer has shape: [1*32*1, 4*2*4] = [32, 32] + // + // With intermediate layout we have two scratch buffers: + // + // %1 = convert %0 (mfma -> tmp): [max(32, 32), max(128, 32)] = [32, 128] + // %2 = convert %1 (tmp -> blocked): [max(32, 128), max(32, 32)] = [128, 32] + // + // Both of these buffers are 4x times smaller than original one and their live + // times do not intersect, therefore this transformation lowers LDS + // consumption. + void tryFitCvtIntoLDS(triton::gpu::ConvertLayoutOp cvtOp, int targetLDSSize) { + OpBuilder builder(cvtOp); + + auto srcType = cvtOp.getSrc().getType(); + auto dstType = cvtOp.getType(); + + auto srcEnc = + cast(srcType.getEncoding()); + auto dstEnc = + cast(dstType.getEncoding()); + + auto ctx = srcEnc.getContext(); + auto rank = srcType.getRank(); + + unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc); + auto warpSize = triton::gpu::getWarpSize(srcEnc); + + // Find all possible shapes of WarpsPerCTA by finding all possible + // factorizations of numWarps. Pick shape for which both conversions in + // decomposition use LDS less than LDSLimit and for which sum of LDS usage + // is minimal. If no such shape exists, do not decompose. + auto factorizedNumWarps = + mlir::triton::AMD::factorizePowerOf2(numWarps, rank); + // Create a list of temporary layouts + SmallVector elemsPerThread(rank, 1); + SmallVector threadsPerWarp(rank, 1); + + // Special case for rank == 1 + if (rank == 1) { + threadsPerWarp[0] = warpSize; + } else { + assert(rank > 1); + threadsPerWarp[rank - 1] = warpSize / 8; + threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1]; + } + + auto layoutCTA = triton::gpu::getCTALayout(srcEnc); + auto order = triton::gpu::getOrder(srcType); + SmallVector dummyWarpsPerCTA(rank, 1); + + auto baseFallbackLayout = triton::gpu::BlockedEncodingAttr::get( + ctx, elemsPerThread, threadsPerWarp, dummyWarpsPerCTA, order, + layoutCTA); + SmallVector tmpLayouts; + for (int i = 0; i < factorizedNumWarps.size(); i++) { + auto warpsPerCTA = factorizedNumWarps[i]; + auto pushNotNull = [&](Attribute enc) { + if (enc) + tmpLayouts.push_back(enc); + }; + + pushNotNull(mlir::triton::AMD::createTmpLayout(srcEnc, warpsPerCTA)); + pushNotNull(mlir::triton::AMD::createTmpLayout(dstEnc, warpsPerCTA)); + pushNotNull( + mlir::triton::AMD::createTmpLayout(baseFallbackLayout, warpsPerCTA)); + } + + unsigned minLDSUsage = 2 * LDSLimit; + int minIdx = -1; + for (int i = 0; i < tmpLayouts.size(); i++) { + auto resources = mlir::triton::AMD::estimateResourcesForReplacement( + builder, cvtOp, tmpLayouts[i]); + // TODO analyze performance along with LDS consumption + if (resources.LDS < minLDSUsage) { + minLDSUsage = resources.LDS; + minIdx = i; + } + } + + if (minIdx == -1 || minLDSUsage > targetLDSSize) { + return; + } + + assert(minIdx >= 0 && minIdx < tmpLayouts.size()); + auto tmpLayout = tmpLayouts[minIdx]; + auto replacementCvts = + mlir::triton::AMD::createNewConvertOps(builder, cvtOp, tmpLayout); + + cvtOp.replaceAllUsesWith(replacementCvts.second.getResult()); + cvtOp.erase(); + } + + struct LDSBottleneckOperation { + triton::gpu::ConvertLayoutOp op; + int64_t LDSSizeTarget; + }; + + // Assuming that all buffer above scratch buffer in memory space can be + // shifted down in memory, gives an optimistic estimation of memory space + // available for scratch buffer. + int64_t + computeTargetScratchBufferSize(triton::gpu::ConvertLayoutOp op, + Allocation *allocation, + ArrayRef liveBuffers) { + int totalSize = 0; + auto scratchBufferId = allocation->getBufferId(op.getOperation()); + int64_t scratchBufferSize = allocation->getAllocatedSize(scratchBufferId); + size_t totalLDSConsumption = 0; + for (auto buf : liveBuffers) { + totalLDSConsumption = std::max( + totalLDSConsumption, allocation->getAllocatedInterval(buf).end()); + } + int64_t freeRequired = totalLDSConsumption - LDSLimit; + return std::max(static_cast(0), scratchBufferSize - freeRequired); + } + + SmallVector + findLDSBottleneckLayoutConvert(ModuleAllocation &allocAnalysis, + FunctionOpInterface func) { + SmallVector candidates; + auto funcAnalysis = allocAnalysis.getFuncData(func); + auto liveBuffers = funcAnalysis->getLiveBuffers(); + + func.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + auto srcTy = cvtOp.getSrc().getType(); + auto dstTy = cvtOp.getResult().getType(); + if (!cvtNeedsSharedMemory(srcTy, dstTy)) + return; + auto cvtBuffer = funcAnalysis->getBufferId(cvtOp.getOperation()); + assert(cvtBuffer != Allocation::InvalidBufferId); + + auto targetScratchBufferSize = computeTargetScratchBufferSize( + cvtOp, funcAnalysis, liveBuffers[cvtOp]); + auto currentLDSConsumption = funcAnalysis->getAllocatedSize(cvtBuffer); + if (currentLDSConsumption > targetScratchBufferSize) + candidates.push_back({cvtOp, targetScratchBufferSize}); + }); + return candidates; + } + +public: + OptimizeAMDLDSUsage(StringRef targetArch, int customLDSLimit) + : OptimizeAMDLDSUsageBase() { + this->targetArch = targetArch.str(); + this->customLDSLimit = customLDSLimit; + } + + void runOnOperation() override { + ModuleOp mod = getOperation(); + + if ((this->LDSLimit = this->customLDSLimit) == 0) { + if (this->targetArch.empty()) { + mod->emitError("missing gfx* target for pass ") + << this->getName().str(); + return signalPassFailure(); + } + triton::AMD::TargetInfo targetInfo(this->targetArch.c_str()); + LDSLimit = targetInfo.getSharedMemorySize(); + } + + ModuleAllocation allocAnalysis(mod); + if (allocAnalysis.getSharedMemorySize() <= LDSLimit) + return; + + auto rootFunctions = allocAnalysis.getRoots(); + for (auto rootFunc : rootFunctions) { + // Find operations with peak LDS consumption + auto candidates = findLDSBottleneckLayoutConvert(allocAnalysis, rootFunc); + // Try to transform candidate operations to fit them into LDS + for (auto candidate : candidates) + tryFitCvtIntoLDS(candidate.op, candidate.LDSSizeTarget); + } + } +}; + +} // namespace + +namespace mlir::triton::AMD { + +std::unique_ptr> +createOptimizeLDSUsagePass(StringRef targetArch, int customLDSLimit) { + return std::make_unique(targetArch, customLDSLimit); +} + +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp new file mode 100644 index 000000000..fd3578278 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp @@ -0,0 +1,124 @@ +#include "OptimizeLDSUtility.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/TritonGPUToLLVM/Patterns.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::triton::AMD { + +constexpr int kPtrBitWidth = 64; + +int getCvtOpLDSUsage(RankedTensorType srcTy, RankedTensorType dstTy) { + auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); + unsigned elems = getNumScratchElements(scratchConfig.paddedRepShape); + auto bytes = + isa(srcTy.getElementType()) + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; + + return bytes; +} + +int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp op) { + return getCvtOpLDSUsage(op.getSrc().getType(), op.getType()); +} + +static void stepFactorizationPow2(std::vector> &factors, + SmallVector &curFactor, + int restTwos, int dim) { + if (dim == curFactor.size()) { + if (restTwos == 0) + factors.push_back(curFactor); + return; + } + curFactor[dim] = 1; + for (int i = 0; i <= restTwos; ++i) { + stepFactorizationPow2(factors, curFactor, restTwos - i, dim + 1); + curFactor[dim] *= 2; + } +} + +std::vector> factorizePowerOf2(int n, int rank) { + assert(llvm::isPowerOf2_32(n)); + int x = log2(n); + std::vector> factors; + SmallVector curFactor(rank, 1); + stepFactorizationPow2(factors, curFactor, x, 0); + return factors; +} + +triton::gpu::DistributedEncodingTrait +createTmpLayout(triton::gpu::DistributedEncodingTrait layout, + ArrayRef warpsPerCTA) { + auto ctx = layout.getContext(); + if (auto src = dyn_cast(layout)) + return triton::gpu::AMDMfmaEncodingAttr::get( + ctx, src.getVersionMajor(), src.getVersionMinor(), warpsPerCTA, + src.getMDim(), src.getNDim(), src.getIsTransposed(), + src.getCTALayout()); + if (auto src = dyn_cast(layout)) + return triton::gpu::AMDWmmaEncodingAttr::get( + ctx, src.getVersion(), src.getIsTransposed(), warpsPerCTA, + src.getCTALayout()); + if (auto src = dyn_cast(layout)) + return triton::gpu::BlockedEncodingAttr::get( + ctx, src.getSizePerThread(), src.getThreadsPerWarp(), warpsPerCTA, + src.getOrder(), src.getCTALayout()); + if (auto src = dyn_cast(layout)) { + auto parent = cast(src.getParent()); + return triton::gpu::DotOperandEncodingAttr::get( + ctx, src.getOpIdx(), createTmpLayout(parent, warpsPerCTA), + src.getKWidth()); + } + if (auto src = dyn_cast(layout)) { + // TODO: think of a way to construct slice layouts based on warpsPerCTA + // argument + auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent()); + return triton::gpu::SliceEncodingAttr::get( + ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA)); + } + // TODO: support linear layout if needed. + if (isa(layout)) + return {}; + assert(false && "Encountered unsupported layout"); + return {}; +} + +std::pair +createNewConvertOps(OpBuilder &builder, triton::gpu::ConvertLayoutOp &cvtOp, + Attribute tmpLayout) { + auto srcType = cvtOp.getSrc().getType(); + auto dstType = cvtOp.getType(); + + auto newDstType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), dstType.getEncoding()); + RankedTensorType newSrcType = RankedTensorType::get( + srcType.getShape(), srcType.getElementType(), tmpLayout); + + auto tmpCvt = builder.create( + cvtOp.getLoc(), newSrcType, cvtOp.getSrc()); + auto newEpilogueCvt = builder.create( + cvtOp.getLoc(), newDstType, tmpCvt); + + return std::make_pair(tmpCvt, newEpilogueCvt); +} + +Resources +estimateResourcesForReplacement(OpBuilder builder, + mlir::triton::gpu::ConvertLayoutOp cvtOp, + Attribute tmpLayout) { + Resources res; + RankedTensorType srcTy = cvtOp.getSrc().getType(); + RankedTensorType dstTy = cvtOp.getType(); + RankedTensorType intermediateTy = RankedTensorType::get( + srcTy.getShape(), srcTy.getElementType(), tmpLayout); + + int tmpCvtLDS = mlir::triton::AMD::getCvtOpLDSUsage(srcTy, intermediateTy); + int newCvtLDS = mlir::triton::AMD::getCvtOpLDSUsage(intermediateTy, dstTy); + res.LDS = std::max(tmpCvtLDS, newCvtLDS); + return res; +} + +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h new file mode 100644 index 000000000..f5d2f27db --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h @@ -0,0 +1,49 @@ +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_OPTIMIZELDSUTILITY_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_OPTIMIZELDSUTILITY_H_ + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::AMD { + +int getCvtOpLDSUsage(RankedTensorType srcTy, RankedTensorType dstTy); + +int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp op); + +std::vector> factorizePowerOf2(int n, int rank); + +/// Copy given layout with different warpsPerCTA parameter +/// +/// \param layout original layout +/// \param warpsPerCTA new warpsPerCTA +/// \returns create layout +triton::gpu::DistributedEncodingTrait +createTmpLayout(triton::gpu::DistributedEncodingTrait layout, + ArrayRef warpsPerCTA); + +/// Creates two chained convert layout operations +/// +/// %1 = cvtOp %0 (srcLayout -> dstLayout) // original operation +/// -> +/// %2 = cvtOp %0 (srcLayout -> tmpLayout) // .first +/// %3 = cvtOp %2 (tmpLayout -> dstLayout) // .second +/// +/// \param builder +/// \param cvtOp original operation +/// \param tmpLayout +/// \returns pair of created operations +std::pair +createNewConvertOps(OpBuilder &builder, triton::gpu::ConvertLayoutOp &cvtOp, + Attribute tmpLayout); + +struct Resources { + int LDS; +}; + +Resources +estimateResourcesForReplacement(OpBuilder builder, + mlir::triton::gpu::ConvertLayoutOp cvtOp, + Attribute tmpLayout); + +} // namespace mlir::triton::AMD + +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_OPTIMIZELDSUTILITY_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h new file mode 100644 index 000000000..0c4915cf5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -0,0 +1,48 @@ +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_PATTERNTRITONGPUOPTOLLVM_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_PATTERNTRITONGPUOPTOLLVM_H_ + +#include "TargetInfo.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" + +namespace mlir::triton::AMD { +void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateMemoryOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + +void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit); +void populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, + ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, + const TargetInfo &targetInfo, PatternBenefit benefit); +void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit); + +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); +void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + +} // namespace mlir::triton::AMD + +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_PATTERNTRITONGPUOPTOLLVM_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 000000000..c85c2bf6c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,62 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" + +using namespace mlir; + +namespace { + +struct GetNumProgramsOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::GetNumProgramsOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, + mlir::gpu::Dimension::y, + mlir::gpu::Dimension::z}; + Location loc = op->getLoc(); + assert(op.getAxisAsInt() < 3); + Value blockId = + rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxisAsInt()]); + rewriter.replaceOpWithNewOp(op, i32_ty, blockId); + return success(); + } +}; + +struct CondBarrierOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::amdgpu::CondBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterCondBarBlock = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *trueBlock = rewriter.createBlock(afterCondBarBlock); + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, adaptor.getPred(), trueBlock, + afterCondBarBlock); + + // conditional barrier + rewriter.setInsertionPointToStart(trueBlock); + rewriter.create(loc); + rewriter.create(loc, afterCondBarBlock); + rewriter.eraseOp(op); + return success(); + } +}; + +} // namespace + +void mlir::triton::AMD::populateSPMDOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp new file mode 100644 index 000000000..0164a37b6 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -0,0 +1,552 @@ +#include "SchedInstructions.h" +#include "TritonAMDGPUToLLVM/Passes.h" +#include "TritonAMDGPUToLLVM/TargetUtils.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Pass/Pass.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir::triton { +#define GEN_PASS_DEF_TRITONAMDGPUINSERTINSTRUCTIONSCHEDHINTS +#define GEN_PASS_DEF_TRITONAMDGPULOWERINSTRUCTIONSCHEDHINTS +#include "TritonAMDGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton + +#undef DEBUG_TYPE +#define DEBUG_TYPE "lower-insert-instruction-sched-hints" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; + +// TODO: The following passes/algorithms are applicable only for a single +// `tt.dot` op in a `scf.for` block -i.e., a single schedule hint op per block. +// Note, we need to relax this assumption in the future and extend the current +// implementation. + +namespace mlir::triton { +template +void setNumGeneratedMMAs(DotOpType op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType) { + auto *ctx = op->getContext(); + auto mmaType = RankedTensorType::get({m, n, k}, elementType); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, mmaCount, mmaType); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + schedHint.setNumMMAsAttr(counterAttr); + }); +} + +template void setNumGeneratedMMAs(triton::DotOp op, size_t mmaCount, unsigned m, + unsigned n, unsigned k, Type elementType); +template void setNumGeneratedMMAs(triton::DotScaledOp op, size_t mmaCount, + unsigned m, unsigned n, unsigned k, + Type elementType); + +template +void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount, + Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, globalLoadsCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + if (auto opIdxAttr = op->template getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + assert(opIdxAttr.getValue() < 2); + const bool isBufferLoadOp = + std::is_same_v; + if (opIdxAttr.getValue() == 0) { + schedHint.setNumGlobalLoadsAAttr(counterAttr); + schedHint.setIsBufferLoadsAEnabled(isBufferLoadOp); + } else { + schedHint.setNumGlobalLoadsBAttr(counterAttr); + schedHint.setIsBufferLoadsBEnabled(isBufferLoadOp); + } + } + }); +} +template void setNumGeneratedGlobalLoads(triton::amdgpu::BufferLoadOp op, + size_t globalLoadsCount, Type type); +template void setNumGeneratedGlobalLoads(triton::LoadOp op, + size_t globalLoadsCount, Type type); + +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount, + Type type) { + auto *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, dsReadsCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + Value dst = op.getResult(); + auto dstTensorTy = cast(dst.getType()); + auto dotOperandLayout = + cast(dstTensorTy.getEncoding()); + const size_t opIdx = dotOperandLayout.getOpIdx(); + assert(opIdx < 2); + if (opIdx == 0) + schedHint.setNumDsReadsAAttr(counterAttr); + else + schedHint.setNumDsReadsBAttr(counterAttr); + }); +} + +void storeOpSchedAnnotations(triton::gpu::LocalStoreOp op, + size_t localStoreOpCount, Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + if (auto opIdxAttr = op->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + assert(opIdxAttr.getValue() < 2); + if (opIdxAttr.getValue() == 0) + schedHint.setNumDsWritesAAttr(counterAttr); + else + schedHint.setNumDsWritesBAttr(counterAttr); + } + }); +} + +triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp) { + triton::DotOp dotOp = nullptr; + size_t dotCounter = 0; + forOp->walk( + [&dotOp, &dotCounter](triton::DotOp op) { dotOp = op, ++dotCounter; }); + + return (dotCounter == 1) ? dotOp : nullptr; +} + +// The AMDGPU compiler backend can fold consecutive `ds_read/ds_write` +// instructions into wider variants as a part of its load/store optimization +// during the instruction selection pass. If it happens, then it means that +// we are overestimated these types of instructions at the current level of +// the IR. In this scenario, the inserted `sched.group.barriers` will result +// in "fooling" the scheduling solver which can mess up the final assembly. +// To avoid this, we switch off the backend load/store folding optimization +// which is going to prevent instructions folding. In this case, the +// instruction widths of `ds_read/ds_write` instructions are going to match +// their LLVM representations. This is implemented as follows. +// TODO: The current implementation disables `ds_read/ds_write` folding for +// all basic blocks in the currently processed function. We should try to +// avoid it. The compiler backend team proposed to play we the load/store +// alignment values within the currently processed basic block as an +// alternative solution. +void disableInstructionFolding(triton::amdgpu::InstructionSchedHint schedHint) { + auto funcOp = schedHint->getParentOfType(); + MLIRContext *ctx = schedHint->getContext(); + llvm::SmallVector targetFeatures; + if (auto attr = funcOp.getTargetFeatures()) { + llvm::copy(attr->getFeatures(), std::back_inserter(targetFeatures)); + } + targetFeatures.push_back(str_attr("-load-store-opt")); + funcOp.setTargetFeaturesAttr( + ::mlir::LLVM::TargetFeaturesAttr::get(ctx, targetFeatures)); +} +} // namespace mlir::triton + +namespace { + +// Create an intrinsic to control how different instruction kinds should +// interleave for better ILP. +void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc, + mlir::amdgpu::sched_barrier_opt_enum maskValue, + int sizeValue, int groupIdValue) { + if (sizeValue < 1) + return; + IntegerAttr mask = + rewriter.getI32IntegerAttr(static_cast(maskValue)); + IntegerAttr size = + rewriter.getI32IntegerAttr(static_cast(sizeValue)); + IntegerAttr groupId = + rewriter.getI32IntegerAttr(static_cast(groupIdValue)); + rewriter.create(loc, mask, size, groupId); +} + +// Insert intrinsic that controls the types of instructions that may be +// allowed to cross the intrinsic during instruction scheduling. +Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc, + mlir::amdgpu::sched_barrier_opt_enum maskValue) { + IntegerAttr mask = + rewriter.getI32IntegerAttr(static_cast(maskValue)); + return rewriter.create(loc, mask); +} + +// Insert an experimental intrinsic for instruction group level parallelism. +// The intrinsic takes a value that specifies the strategy. +Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { + IntegerAttr iglpValue = + rewriter.getI32IntegerAttr(static_cast(value)); + return rewriter.create(loc, iglpValue); +} + +// The following structs represent in-source database regarding a target +// machine. It provides instructions execution and issue cycles needed for +// scheduling. +struct MachineDescr { + virtual ~MachineDescr() = default; + virtual uint32_t getDsReadIssueCycle(uint32_t instrWidth) = 0; + virtual FailureOr getMmaExecCycle(llvm::ArrayRef dims) = 0; + virtual uint32_t getMmaIssueCycle() = 0; + virtual uint32_t getNumLdsDataPaths() = 0; + static std::unique_ptr get(StringRef arch); +}; + +template struct MachineDescrImpl : MachineDescr { + uint32_t getDsReadIssueCycle(uint32_t instrWidth) final { + return instrWidth == 16 ? 8 : 4; + } + + FailureOr getMmaExecCycle(llvm::ArrayRef dims) final { + if (dims.size() != 3) + return failure(); + auto it = + Derived::mmaTable.find(std::make_tuple(dims[0], dims[1], dims[2])); + if (it != Derived::mmaTable.end()) + return it->second; + return failure(); + } + + uint32_t getMmaIssueCycle() final { return Derived::mmaIssueCycle; }; + uint32_t getNumLdsDataPaths() final { return Derived::numLdsDataPaths; } + + using MmaTable = + llvm::DenseMap, uint32_t>; +}; + +struct CDNA2Kind : public MachineDescrImpl { + static const inline MmaTable mmaTable{{{32, 32, 8}, 64}, {{16, 16, 16}, 32}}; + static const inline uint32_t mmaIssueCycle{4}; + static const inline uint32_t numLdsDataPaths{2}; +}; + +struct CDNA3Kind : public MachineDescrImpl { + static const inline MmaTable mmaTable{{{32, 32, 8}, 32}, {{16, 16, 16}, 16}}; + static const inline uint32_t mmaIssueCycle{4}; + static const inline uint32_t numLdsDataPaths{2}; +}; + +std::unique_ptr MachineDescr::get(StringRef arch) { + AMD::ISAFamily family = AMD::deduceISAFamily(arch); + switch (family) { + case AMD::ISAFamily::CDNA3: { + return std::make_unique>(); + } + case AMD::ISAFamily::CDNA2: { + return std::make_unique>(); + } + default: { + return nullptr; + } + } + return nullptr; +} + +struct InstructionSchedHintsRewriter + : public OpRewritePattern { + + InstructionSchedHintsRewriter(MLIRContext *ctx, StringRef arch, + int32_t numStages) + : OpRewritePattern(ctx), numStages(numStages) { + + this->machineDescr = MachineDescr::get(arch); + } + + // The following is inspired by ROCm Composable Kernel library's V3 pipelining + // (see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp). + // This scheduling requires 1x register and 1x LDS buffers combined with the + // local (LDS to registers) and global (HBM to registers) data prefetching. + void createLocalPrefetchSchedule( + PatternRewriter &rewriter, Location loc, + triton::amdgpu::InstructionSchedHint schedHint) const { + + if (!machineDescr) { + schedHint.emitError("unknown target architecture detected"); + return; + } + + const uint32_t numDsReadInstA = schedHint.getNumDsReadsA().getValue(); + const uint32_t numDsReadInstB = schedHint.getNumDsReadsB().getValue(); + + const uint32_t numDsWriteInstA = schedHint.getNumDsWritesA().getValue(); + const uint32_t numDsWriteInstB = schedHint.getNumDsWritesB().getValue(); + + const uint32_t numBufferLoadInstA = + schedHint.getNumGlobalLoadsA().getValue(); + const uint32_t numBufferLoadInstB = + schedHint.getNumGlobalLoadsB().getValue(); + + if (numBufferLoadInstA == 0) { + schedHint.emitError( + "global/buffer load count for tile A must be initialized"); + return; + } + + if (numBufferLoadInstB == 0) { + schedHint.emitError( + "global/buffer load count for tile B must be initialized"); + return; + } + + const uint32_t numMmaInst = schedHint.getNumMMAs().getValue(); + + auto mmaType = cast(schedHint.getNumMMAs().getType()); + auto maybeMmaExecCycle = machineDescr->getMmaExecCycle(mmaType.getShape()); + if (llvm::failed(maybeMmaExecCycle)) { + schedHint.emitError("unknown mma instruction type"); + return; + } + const uint32_t mmaExecCycle = maybeMmaExecCycle.value(); + + auto dsReadsAType = cast(schedHint.getNumDsReadsA().getType()); + auto dsReadsBType = cast(schedHint.getNumDsReadsB().getType()); + + const uint32_t dsReadAIssueCycle = + machineDescr->getDsReadIssueCycle(dsReadsAType.getShape()[0]); + const uint32_t dsReadBIssueCycle = + machineDescr->getDsReadIssueCycle(dsReadsBType.getShape()[0]); + + const uint32_t mmaIssueCycle = this->machineDescr->getMmaIssueCycle(); + const uint32_t numLdsDataPaths = this->machineDescr->getNumLdsDataPaths(); + + // Compute how many ds_reads from tile A we can put between to adjacent + // MFMAs + const auto dsReadAMmaRate = (mmaExecCycle - mmaIssueCycle + + numLdsDataPaths * dsReadAIssueCycle - 1) / + (numLdsDataPaths * dsReadAIssueCycle); + + // Compute how many ds_reads from tile B we can put between to adjacent + // MFMAs + const auto dsReadBMmaRate = (mmaExecCycle - mmaIssueCycle + + numLdsDataPaths * dsReadBIssueCycle - 1) / + (numLdsDataPaths * dsReadBIssueCycle); + + // Compute how many (MFMA [ds_read]+) clusters we can get from tile A + const auto numDsreadAMma = + (numDsReadInstA + dsReadAMmaRate - 1) / dsReadAMmaRate; + + // Compute how many (MFMA [ds_read]+) clusters we can get from tile B + const auto numDsreadBMma = + (numDsReadInstB + dsReadBMmaRate - 1) / dsReadBMmaRate; + + // Stage 1 + // Compute how many MFMAs we have left for stage 1 - i.e., clusters with + // ds_writes, global/buffer_loads, MFMAs + const auto numMmaStage1 = numMmaInst - (numDsreadAMma + numDsreadBMma); + const auto numMmaPerIssue = + numMmaStage1 / (numBufferLoadInstA + numBufferLoadInstB); + + // Compute how many ds_writes we have per global/buffer load resulting from + // tile A + const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA; + + // Compute how many ds_writes we have per global/buffer load resulting from + // tile B + const auto numDswritePerIssueB = numDsWriteInstB / numBufferLoadInstB; + + for (size_t i = 0; i < numBufferLoadInstA; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueA; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_write, + 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + 1, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + numMmaPerIssue - numDswritePerIssueA, 0); + } + + for (size_t i = 0; i < numBufferLoadInstB; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueB; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_write, + 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + 1, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + numMmaPerIssue - numDswritePerIssueB, 0); + } + + // stage 2 + for (size_t i = 0; i < numDsreadAMma; ++i) { + if ((numDsReadInstA - (i + 1) * dsReadAMmaRate) >= dsReadAMmaRate) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_read, + dsReadAMmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read, + numDsReadInstA - (numDsreadAMma - 1) * dsReadAMmaRate, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0); + } + + for (size_t i = 0; i < numDsreadBMma; ++i) { + if ((numDsReadInstB - (i + 1) * dsReadBMmaRate) >= dsReadBMmaRate) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_read, + dsReadBMmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read, + numDsReadInstB - (numDsreadBMma - 1) * dsReadBMmaRate, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0); + } + + disableInstructionFolding(schedHint); + } + + LogicalResult + matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, + PatternRewriter &rewriter) const override { + auto schedVariant = instructionSchedHint.getVariant(); + if (schedVariant == mlir::triton::amdgpu::SchedHint::none) { + rewriter.eraseOp(instructionSchedHint); + return success(); + } + + // The switch controls whether instructions are allowed to cross the basic + // block boundaries at the very top and at the very bottom. Note, this is + // not supposed to be used together with IGLP OPT according to the AMDGPU + // backend documentation. + const bool limitSchedulingRange = + schedVariant == mlir::triton::amdgpu::SchedHint::local_prefetch; + ; + Location loc = instructionSchedHint->getLoc(); + Block *block = instructionSchedHint->getBlock(); + if (limitSchedulingRange) { + rewriter.setInsertionPointToStart(block); + createSchedBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::none); + } + + rewriter.setInsertionPoint(block, std::prev(block->end())); + + switch (schedVariant) { + case mlir::triton::amdgpu::SchedHint::llvm_iglp_0: + case mlir::triton::amdgpu::SchedHint::llvm_iglp_1: + createIglpOpt(rewriter, loc, static_cast(schedVariant) - 1); + break; + case mlir::triton::amdgpu::SchedHint::local_prefetch: + createLocalPrefetchSchedule(rewriter, loc, instructionSchedHint); + break; + case mlir::triton::amdgpu::SchedHint::none: + default: + break; + } + + if (limitSchedulingRange) + createSchedBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::none); + + rewriter.eraseOp(instructionSchedHint); + return success(); + } + +private: + int32_t numStages; + std::unique_ptr machineDescr; +}; + +struct TritonAMDGPULowerInstructionSchedHints + : public triton::impl::TritonAMDGPULowerInstructionSchedHintsBase< + TritonAMDGPULowerInstructionSchedHints> { + + explicit TritonAMDGPULowerInstructionSchedHints(StringRef arch, + int32_t numStages) { + this->arch = std::move(arch.str()); + this->numStages = numStages; + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + ModuleOp mod = getOperation(); + + ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addIllegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(ctx); + + patterns.add(ctx, this->arch, + this->numStages); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + + signalPassFailure(); + } + } +}; + +struct TritonAMDGPUInsertInstructionSchedHints + : public triton::impl::TritonAMDGPUInsertInstructionSchedHintsBase< + TritonAMDGPUInsertInstructionSchedHints> { + + explicit TritonAMDGPUInsertInstructionSchedHints(StringRef variant) { + this->variant = std::move(variant.str()); + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + ModuleOp mod = getOperation(); + + auto schedHint = mlir::triton::amdgpu::SchedHint::none; + std::transform(variant.begin(), variant.end(), variant.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (auto maybeSchedHint = triton::amdgpu::symbolizeSchedHint(variant)) + schedHint = maybeSchedHint.value(); + else { + LDBG("ignoring instruction scheduling because " + "unknown instruction scheduling variant has been provided"); + return; + } + + if (schedHint != mlir::triton::amdgpu::SchedHint::none) { + mod.walk([&](scf::ForOp forOp) { + // Note, instruction schedule barriers are inserted only in the case of + // a single `tt.dot` op in a `scf::ForOp` scope in the current + // implementation. + if (auto dotOp = getSingleDotOpIfExists(forOp)) { + OpBuilder rewriter(ctx); + rewriter.setInsertionPointAfter(dotOp); + rewriter.create(dotOp->getLoc(), + schedHint); + } + }); + } + } +}; +} // namespace + +namespace mlir::triton { +std::unique_ptr> +createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch, + int32_t numStages) { + return std::make_unique(arch, + numStages); +} + +std::unique_ptr> +createTritonAMDGPUInsertInstructionSchedHintsPass(StringRef variant) { + return std::make_unique(variant); +} +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h new file mode 100644 index 000000000..988896b41 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h @@ -0,0 +1,27 @@ +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_SCHEDINSTRUCTIONS_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_SCHEDINSTRUCTIONS_H_ + +#include "mlir/IR/Types.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// The following functions are used to collect and set side-channel information +// during to LLVM conversion/lowering to facilitate instruction scheduling +// controls. +namespace mlir::triton { +template +void setNumGeneratedMMAs(DotOpType op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType); + +template +void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount, + Type type); +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t numDsReadsCount, + Type type); +void storeOpSchedAnnotations(triton::gpu::LocalStoreOp op, size_t llvmOpCount, + Type type); +triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp); +} // namespace mlir::triton + +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_SCHEDINSTRUCTIONS_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp new file mode 100644 index 000000000..d9aefe628 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -0,0 +1,487 @@ +#include "TargetInfo.h" +#include "SchedInstructions.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "TritonAMDGPUToLLVM/TargetUtils.h" +#include "Utility.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using mlir::triton::AMD::DppCtrl; +namespace mlir::triton::AMD { + +namespace { +template +LLVM::LLVMFuncOp getOrInsertFunction(T &moduleOp, const Location loc, + RewriterBase &rewriter, StringRef name, + LLVM::LLVMFunctionType type) { + LLVM::LLVMFuncOp ret; + if (!(ret = moduleOp.template lookupSymbol(name))) { + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + ret = rewriter.create(loc, name, type, + LLVM::Linkage::External); + } + return ret; +} + +// Extend all values to 64-bit per printf call requirements. +Value printfPromoteValue(RewriterBase &rewriter, Value value) { + auto *context = rewriter.getContext(); + auto loc = UnknownLoc::get(context); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto type = value.getType(); + + if (isa(type)) { + // The llvm.ptrtoint op requires signless integer types. + return b.ptrtoint(i64_ty, value); + } + + assert(type.getIntOrFloatBitWidth() <= 64); + + if (auto floatType = dyn_cast(type)) { + Value newValue = value; + if (!floatType.isF64()) + newValue = b.fpext(f64_ty, newValue); + return b.bitcast(newValue, i64_ty); + } + + assert(type.isIntOrIndex()); + if (type.getIntOrFloatBitWidth() < 64) { + if (type.isUnsignedInteger()) + return b.zext(ui64_ty, value); + if (type.isSignedInteger()) + return b.sext(i64_ty, value); + // Signless integers are printed using unsigned integer formats. + return b.zext(i64_ty, value); + } + + return value; +} +} // namespace + +llvm::AMDGPU::GPUKind TargetInfo::getGPUKind() const { + return llvm::AMDGPU::parseArchAMDGCN(arch); +} + +int TargetInfo::getSharedMemorySize() const { + int kbytes = getISAFamily() == ISAFamily::CDNA4 ? 160 : 64; + return kbytes * 1024; +} + +bool TargetInfo::supportMaximumMinimum() const { return false; } + +Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { + // On AMD hardware we don't have CTA clusters like NVIDIA. So this will always + // be zero. Whoever calling into this should make sure the whole program does + // not try to utilize CTA clusters. + return rewriter.create(loc, 0, 32); +} + +Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const { + return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.ballot", + type, cmp) + ->getResult(0); +} + +void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const { + if (ctaId.has_value()) { + llvm::report_fatal_error( + "AMDGPU does not support cross-CTA shared memory transfers"); + } + mlir::LLVM::AMD::llStore(rewriter, loc, ptr, val, pred); +} + +bool TargetInfo::canUseStMatrix(RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const { + // AMD does not support stmatrix + return false; +} + +bool TargetInfo::canUseLDSTransLoad(int bitwidth) const { + return getISAFamily() == ISAFamily::CDNA4 && + llvm::is_contained({16, 8, 4, 6}, bitwidth); +} + +void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, + Value ptr, Value val) const { + llvm::report_fatal_error("AMDGPU does not support stmatrix"); +} + +Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const { + if (ctaId.has_value()) { + llvm::report_fatal_error( + "AMDGPU does not support cross-CTA shared memory transfers"); + } + Value falseVal = rewriter.create( + loc, elemTy, rewriter.getZeroAttr(elemTy)); + return mlir::LLVM::AMD::llLoad(rewriter, loc, ptr, elemTy, pred, falseVal); +} + +Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::AMD::shuffleXor(loc, rewriter, val, i, getISAFamily()); +} + +Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::AMD::shuffleUp(loc, rewriter, val, i, getISAFamily()); +} + +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); +} + +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const { + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); +} + +Value TargetInfo::programId(RewriterBase &rewriter, Location loc, + ModuleOp moduleOp, int axis) const { + return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis); +} + +// Cast and sext values into specific-length int to meet the requirements of +// instructions like UpdateDpp or readlane if necessary. +static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc, + Value &val, Type fromType, + unsigned toBits) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned originalBits = fromType.getIntOrFloatBitWidth(); + Type toType = fromType; + + if (!fromType.isIntOrIndex()) { + val = b.bitcast(val, int_ty(originalBits)); + toType = int_ty(originalBits); + } + + if (originalBits < toBits) { + val = b.sext(int_ty(toBits), val); + toType = int_ty(toBits); + } + + return toType; +} + +// Trunc the value to specific length and then cast it to given type if +// necessary. This function is typically used in conjunction with +// castToAndSExtInt. +static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc, + Value val, Type valType, + unsigned fromBits) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned originalBits = valType.getIntOrFloatBitWidth(); + Value toVal = val; + + if (originalBits < fromBits) { + toVal = b.trunc(int_ty(originalBits), toVal); + } + + if (!valType.isIntOrIndex()) { + toVal = b.bitcast(toVal, valType); + } + + return toVal; +} + +bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce, + unsigned interleave) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (numLaneToReduce != 64) + return false; + + if (!llvm::is_contained( + {ISAFamily::CDNA2, ISAFamily::CDNA3, ISAFamily::CDNA4}, + getISAFamily())) { + return false; + } + + Operation *reduxOp = op.getSingleCombiner(); + if (!reduxOp) + return false; + + auto createDppReduxOpWithBoundCtrl = [&](Type valType, Value &src, + uint32_t dppCtrl, int rowMask, + int bankMask) -> Value { + // DPP has limited support for data types, so here we need to + // cast non-integer types or integer types shorter than 32 bits + // to int32, except for fp32. + Type actualType = valType; + if (!valType.isF32()) { + actualType = castToAndSExtInt(rewriter, loc, src, valType, 32); + } + + Value dppResult = + rewriter + .create(loc, actualType, src, src, + rewriter.getI32IntegerAttr(dppCtrl), + rewriter.getI32IntegerAttr(rowMask), + rewriter.getI32IntegerAttr(bankMask), + rewriter.getBoolAttr(true)) + .getRes(); + + if (!valType.isF32()) { + src = truncAndCastFromInt(rewriter, loc, src, valType, 32); + dppResult = truncAndCastFromInt(rewriter, loc, dppResult, valType, 32); + } + + IRMapping mapping; + mapping.map(reduxOp->getOperand(0), src); + mapping.map(reduxOp->getOperand(1), dppResult); + return rewriter.clone(*reduxOp, mapping)->getResult(0); + }; + + for (int i = 0; i < acc.size(); i++) { + Value buf; + auto valType = acc[i].getType(); + + // Here's the implementation of full-wavefront reduction using dpp. + // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/ + // + // Each step has a v_mov_dpp instruction following the redux op. In + // some cases, the lower-level compiler could merge them into single + // instruction. For example, v_mov_dpp + max => v_max_dpp. + // + // For gfx9, we have 64 threads per warp. These 64 threads are arranged + // into 4 rows, with each row being 16 threads. Each 16 threads are arranged + // further into 4 banks, with each bank being 4 threads. Overall it's in a + // (row, bank, thread) structure. When shuffling, we use row/bank mask to + // indicate which row/bank to participate. Then modifier like row_shr and + // row_bcast means exact data movement schemes. In the following + // instructions, taking row 0 as an example: + // + // Step 1: Right shift for 8 lanes. + // lane 8-15 = redux(lane 0-7, lane 8-15) + // + // Step 2: Right shift for 4 lanes. + // lane 12-15 = redux(lane 8-11, lane 12-15) + // + // Step 3: Right shift for 2 lanes. + // lane 14-15 = redux(lane 12-13, lane 14-15) + // + // Step 4: Right shift for 1 lane. + // lane 15 = redux(lane 14, lane 15) + // + // Step 5: Broadcast lane 15 of each row to all the lanes of its next row. + // lane 16-31 = redux(lane 15, lane 16-31) + // + // Step 6: Broadcast lane 31 to lane 32-63. + // lane 32-63 = redux(lane 31, lane 32-63) + // + // Now the reduction result is stored in lane 63. + // + // Step 7: Read the reduction result from lane 63 and broadcast with + // readlane. + + const int allRows = 0xf; + const int allBanks = 0xf; + + const uint32_t dppCtrlRowShr = static_cast(DppCtrl::ROW_SHR0); + + // row_shr:8 + buf = createDppReduxOpWithBoundCtrl(valType, acc[i], 8 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:4 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 4 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:2 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 2 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:1 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr, + allRows, allBanks); + + // row_bcast:15 row_mask:0xa + buf = createDppReduxOpWithBoundCtrl( + valType, buf, static_cast(DppCtrl::BCAST15), 0xa, allBanks); + + // row_bcast:31 + buf = createDppReduxOpWithBoundCtrl(valType, buf, + static_cast(DppCtrl::BCAST31), + allRows, allBanks); + + // Similarly, we need to cast data types for readlane instruction. + Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16); + + // Get reduction result from lane 63 + std::string intrinsic = "llvm.amdgcn.readlane"; + Value result = + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType, + ValueRange{buf, b.i32_val(63)}) + ->getResult(0); + + result = truncAndCastFromInt(rewriter, loc, result, valType, 16); + + acc[i] = result; + } + + return true; +} + +void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, + ValueRange args, RewriterBase &rewriter, + bool useStdErr) const { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + auto *ctx = rewriter.getContext(); + mlir::Location loc = UnknownLoc::get(ctx); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // See + // https://github.com/ROCm/ROCm-Device-Libs/blob/rocm-6.0.x/ockl/src/services.cl#L263-L361 + // for details about the following HIP device print functions. + LLVM::LLVMFuncOp printBeginFn = getOrInsertFunction( + moduleOp, loc, rewriter, + useStdErr ? "__ockl_fprintf_stderr_begin" : "__ockl_printf_begin", + LLVM::LLVMFunctionType::get(i64_ty, + useStdErr ? ArrayRef() : i64_ty)); + LLVM::LLVMFuncOp printStrFn = getOrInsertFunction( + moduleOp, loc, rewriter, "__ockl_printf_append_string_n", + LLVM::LLVMFunctionType::get( + i64_ty, {i64_ty, ptr_ty(ctx), /*length=*/i64_ty, /*isLast=*/i32_ty})); + LLVM::LLVMFuncOp printArgsFn; + if (!args.empty()) { + printArgsFn = getOrInsertFunction( + moduleOp, loc, rewriter, "__ockl_printf_append_args", + LLVM::LLVMFunctionType::get( + i64_ty, {i64_ty, /*numArgs=*/i32_ty, i64_ty, i64_ty, i64_ty, i64_ty, + i64_ty, i64_ty, i64_ty, /*isLast=*/i32_ty})); + } + + // Emit the intrinsic function call to begin the printf. + Value zeroI64 = rewriter.create(loc, i64_ty, 0); + Value message = + b.call(printBeginFn, useStdErr ? ValueRange() : zeroI64).getResult(); + + // Emit the intrinsic function call to handle the printf format string. + Value oneI32 = b.i32_val(1); + Value zeroI32 = b.i32_val(0); + Value formatStrLen = + rewriter.create(loc, i64_ty, formatStrByteCount); + SmallVector arguments = {message, formatStrStart, formatStrLen, + args.empty() ? oneI32 : zeroI32}; + message = b.call(printStrFn, arguments).getResult(); + + // Emit the intrinsic function call to handle arguments iteratively. + // We can only handle at most 7 values each time. + constexpr size_t kArgsPerGroup = 7; + for (size_t group = 0; group < args.size(); group += kArgsPerGroup) { + size_t bound = std::min(group + kArgsPerGroup, args.size()); + size_t numArgs = bound - group; + + SmallVector arguments; + arguments.push_back(message); + arguments.push_back(b.i32_val(numArgs)); + for (size_t i = group; i < bound; ++i) { + arguments.push_back(printfPromoteValue(rewriter, args[i])); + } + // Pad out to 7 arguments since the function always needs 7 args. + for (size_t extra = numArgs; extra < kArgsPerGroup; ++extra) { + arguments.push_back(zeroI64); + } + + Value isLast = (bound == args.size()) ? oneI32 : zeroI32; + arguments.push_back(isLast); + message = b.call(printArgsFn, arguments).getResult(); + } +} + +std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { + std::string funcName = + resultElementTy.isInteger(32) ? "__ockl_mul_hi_u32" : "__ockl_mul_hi_u64"; + return funcName; +} + +void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const { + return printfImpl(formatStrStart, formatStrByteCount, args, rewriter, + /*useStdError=*/false); +} + +void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter, + "printfFormat_", msgNewline); + printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); +} + +void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // Compose and print an assert message. + llvm::SmallString<256> msgBuffer; + llvm::Twine("device assertion failed: '" + message + "', in " + func + + " at " + file + ":" + llvm::Twine(line) + "\n\0") + .toStringRef(msgBuffer); + Value msgValue = + LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgBuffer); + printfImpl(msgValue, msgBuffer.size_in_bytes(), /*args=*/ValueRange(), + rewriter, /*useStdError=*/true); + + // Set block barrrier before aborting kernel, give a chance for all + // the threads in a block to check/print the assert failure. + b.barrier(); + // Perform the trap to abort the kernel. + rewriter.create(loc); +} + +int TargetInfo::getSharedAddressSpace() const { return 3; } + +int TargetInfo::getAddressSpace(Attribute addressSpace) const { + int spaceId = 0; + if (isa(addressSpace)) { + spaceId = 3; + } else { + llvm::report_fatal_error("Only support SharedMemorySpace for now"); + } + return spaceId; +} + +bool TargetInfo::supportVectorizedAtomics() const { + // Note: not currently tested or used, but AMD generally supports vectorized + // atomics. + return true; +} + +void TargetInfo::storeOpAnnotation(triton::gpu::LocalStoreOp op, + size_t localStoreOpCount, Type type) const { + storeOpSchedAnnotations(op, localStoreOpCount, type); +} + +bool TargetInfo::supportsDirectToLdsLoadBitWidth(int bitWidth) const { + switch (getISAFamily()) { + case ISAFamily::CDNA1: + case ISAFamily::CDNA2: + case ISAFamily::CDNA3: + return llvm::is_contained({32, 16, 8}, bitWidth); + case ISAFamily::CDNA4: + // Disable 96 bits as it uses 128bit strides between threads in a warp + return llvm::is_contained({128, /*96, */ 32, 16, 8}, bitWidth); + default: + break; + } + + return false; +} + +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h new file mode 100644 index 000000000..60da18aa9 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -0,0 +1,88 @@ +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_TARGETINFO_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_TARGETINFO_H_ + +#include "TritonAMDGPUToLLVM/TargetUtils.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "llvm/TargetParser/TargetParser.h" +#include + +namespace mlir::triton::AMD { +class TargetInfo : public mlir::triton::TargetInfoBase { +public: + explicit TargetInfo(std::string arch) : arch(std::move(arch)) {} + + ISAFamily getISAFamily() const { return deduceISAFamily(arch); } + + llvm::AMDGPU::GPUKind getGPUKind() const; + + int getSharedMemorySize() const; + + bool supportMaximumMinimum() const override; + + Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; + + Value ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const override; + + void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const override; + Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const override; + bool canUseLDSTransLoad(int bitwidth) const; + + bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const override; + void storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, + Value val) const override; + + Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const override; + + Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, + int axis) const override; + + bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override; + + std::string getMulhiFuncName(Type resultElementTy) const override; + + void printf(RewriterBase &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const override; + + void printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const override; + + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, + StringRef file, StringRef func, int line) const override; + + int getSharedAddressSpace() const override; + + int getAddressSpace(Attribute addressSpace) const override; + + bool supportVectorizedAtomics() const override; + + void storeOpAnnotation(triton::gpu::LocalStoreOp op, size_t localStoreOpCount, + Type type) const override; + + bool supportsDirectToLdsLoadBitWidth(int bitWidth) const; + +private: + void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args, + RewriterBase &rewriter, bool useStdErr) const; + + std::string arch; +}; +} // namespace mlir::triton::AMD + +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_TARGETINFO_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp new file mode 100644 index 000000000..fb655694e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp @@ -0,0 +1,52 @@ +#include "TritonAMDGPUToLLVM/TargetUtils.h" +#include "llvm/TargetParser/TargetParser.h" + +namespace mlir::triton::AMD { + +ISAFamily deduceISAFamily(llvm::StringRef arch) { + llvm::AMDGPU::GPUKind kind = llvm::AMDGPU::parseArchAMDGCN(arch); + + // See https://llvm.org/docs/AMDGPUUsage.html#processors for how to categorize + // the following target gfx architectures. + + // CDNA ISA cases + switch (kind) { + case llvm::AMDGPU::GK_GFX950: + return ISAFamily::CDNA4; + case llvm::AMDGPU::GK_GFX942: + return ISAFamily::CDNA3; + case llvm::AMDGPU::GK_GFX90A: + return ISAFamily::CDNA2; + case llvm::AMDGPU::GK_GFX908: + return ISAFamily::CDNA1; + default: + break; + } + + // RNDA ISA cases + if (kind >= llvm::AMDGPU::GK_GFX1100 && kind <= llvm::AMDGPU::GK_GFX1201) + return ISAFamily::RDNA3; + if (kind >= llvm::AMDGPU::GK_GFX1030 && kind <= llvm::AMDGPU::GK_GFX1036) + return ISAFamily::RDNA2; + if (kind >= llvm::AMDGPU::GK_GFX1010 && kind <= llvm::AMDGPU::GK_GFX1013) + return ISAFamily::RDNA1; + + return ISAFamily::Unknown; +} + +bool supportsVDot(llvm::StringRef arch) { + switch (deduceISAFamily(arch)) { + case AMD::ISAFamily::CDNA1: + case AMD::ISAFamily::CDNA2: + case AMD::ISAFamily::CDNA3: + case AMD::ISAFamily::CDNA4: + case AMD::ISAFamily::RDNA2: + case AMD::ISAFamily::RDNA3: + return true; + default: + break; + } + return false; +} + +} // namespace mlir::triton::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp new file mode 100644 index 000000000..6f87ecb67 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -0,0 +1,258 @@ +#include "TritonAMDGPUToLLVM/Passes.h" + +#include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" +#include "TargetInfo.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Pass/Pass.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h" + +namespace mlir::triton { +#define GEN_PASS_DEF_CONVERTTRITONAMDGPUTOLLVM +#include "TritonAMDGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton + +using namespace mlir; + +namespace { + +class TritonLLVMFunctionConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + } +}; + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addLegalOp(); + addLegalOp(); + } +}; + +struct ConvertTritonAMDGPUToLLVM + : public triton::impl::ConvertTritonAMDGPUToLLVMBase< + ConvertTritonAMDGPUToLLVM> { + explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz) { + this->arch = targetArch.str(); + this->ftz = ftz; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + AMD::TargetInfo targetInfo(this->arch.getValue()); + if (targetInfo.getISAFamily() == AMD::ISAFamily::Unknown) { + mod.emitError("unsupported target: '") << this->arch.getValue() << "'"; + return signalPassFailure(); + } + + mlir::LowerToLLVMOptions option(context); + option.overrideIndexBitwidth(32); + + TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); + TritonLLVMConversionTarget convTarget(*context); + + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + + // Allocate shared memory and set barrier + ModuleAllocation allocation(mod); + ModuleMembarAnalysis membarPass(&allocation); + membarPass.run(); + + // Lower functions + { + TritonLLVMFunctionConversionTarget funcTarget(*context); + RewritePatternSet funcPatterns(context); + mlir::triton::populateFuncOpConversionPattern( + typeConverter, funcPatterns, targetInfo, patternBenefitDefault); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) + return signalPassFailure(); + } + + // initSharedMemory is run before the conversion of call and ret ops, + // because the call op has to know the shared memory base address of each + // function + initSharedMemory(typeConverter); + + // Convert call and ret ops + { + TritonLLVMFunctionConversionTarget funcTarget(*context); + RewritePatternSet funcPatterns(context); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) + return signalPassFailure(); + } + + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + + // Emit logics to get threadId/blockIds/linearized clusterCTAId etc. and + // cache the values. The reason to do it here is that cluster_ctaid is + // currently implemented via inline asm, and thus cannot be CSEed. + // clusterCTAId will be emitted only when numCTAs is larger than 1, and + // other values will be DCEed if not used hereafter. + OpBuilder::InsertPoint indexInsertPoint; + + RewritePatternSet patterns(context); + int commonBenefit = patternBenefitPrioritizeOverLLVMConversions; + // Make benefit for AMD specific patterns higher so they apply before common + // patterns + int AMDBenefit = commonBenefit + 1; + auto populatePatterns1 = [&](auto populateFunc, int benefit) { + populateFunc(typeConverter, patterns, axisInfoAnalysis, allocation, + benefit); + }; + + auto populatePatterns5 = [&](auto populateFunc, int benefit) { + populateFunc(typeConverter, patterns, benefit); + }; + + auto populatePatterns6 = [&](auto populateFunc, int benefit) { + populateFunc(typeConverter, patterns, axisInfoAnalysis, allocation, + targetInfo, benefit); + }; + + auto populatePatterns7 = [&](auto populateFunc, int benefit) { + populateFunc(typeConverter, patterns, targetInfo, benefit); + }; + + AMD::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, + patterns, AMDBenefit); + mlir::triton::populateConvertLayoutOpToLLVMPatterns( + typeConverter, targetInfo, patterns, commonBenefit); + AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis, + AMDBenefit); + AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz, + axisInfoAnalysis, allocation, + targetInfo, AMDBenefit); + AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns, + axisInfoAnalysis, AMDBenefit); + populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns, + commonBenefit); + populatePatterns7(mlir::triton::populateScanOpToLLVMPatterns, + commonBenefit); + populatePatterns5(mlir::triton::populateViewOpToLLVMPatterns, + commonBenefit); + populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns, + commonBenefit); + populatePatterns7(mlir::triton::populateGatherOpToLLVMPatterns, + commonBenefit); + + AMD::populateMemoryOpToLLVMPatterns(typeConverter, patterns, targetInfo, + AMDBenefit); + mlir::triton::populateMemoryOpToLLVMPatterns(typeConverter, targetInfo, + patterns, commonBenefit); + mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, + patterns, commonBenefit); + mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, + targetInfo, commonBenefit); + mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, + targetInfo, commonBenefit); + mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, + targetInfo, commonBenefit); + AMD::populateSPMDOpToLLVMPattern(typeConverter, patterns, AMDBenefit); + + mlir::triton::AMD::populateTritonAMDGPUToLLVMPatterns(typeConverter, + patterns, AMDBenefit); + mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns, + targetInfo, AMDBenefit); + + // TODO(thomas): this should probably be done in a separate step to not + // interfere with our own lowering of arith ops. Add arith/math's patterns + // to help convert scalar expression to LLVM. + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); + + // Native lowering patterns + mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, + mlir::gpu::amd::HIP); + + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + patterns); + mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, + targetInfo, commonBenefit); + + mlir::triton::proton::populateRecordOpToLLVMPattern( + typeConverter, patterns, targetInfo, commonBenefit); + + mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { + return signalPassFailure(); + } + } + +private: + void initSharedMemory(LLVMTypeConverter &typeConverter) { + ModuleOp mod = getOperation(); + OpBuilder b(mod.getBodyRegion()); + auto ctx = mod.getContext(); + auto loc = mod.getLoc(); + auto elemTy = typeConverter.convertType(b.getIntegerType(8)); + // Set array size 0 and external linkage indicates that we use dynamic + // shared allocation to allow a larger shared memory size for each kernel. + // + // Ask for 16B alignment on global_smem because that's the largest we should + // ever need (4xi32). + auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0); + auto global = b.create( + loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External, + "global_smem", /*value=*/Attribute(), /*alignment=*/16, + // Add ROCm support. + static_cast(NVVM::NVVMMemorySpace::kSharedMemorySpace)); + } +}; + +} // namespace + +namespace mlir::triton { + +std::unique_ptr> +createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz) { + return std::make_unique(targetArch, ftz); +} + +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp new file mode 100644 index 000000000..94aaa13e3 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -0,0 +1,364 @@ +#include "PatternTritonGPUOpToLLVM.h" + +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { + +SmallVector upcast8xMxfp4(RewriterBase &rewriter, + amdgpu::UpcastMXFPOp upcastOp, bool tofp16, + Value packedVec) { + Location loc = upcastOp.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // MXFP4 has 4 bits, S.EE.M, for Sign, Exponent, and Mantissa respectively. + // For a specific S, we have a total of 8 bit patterns. We can encode all + // these 8 resultant bf16/fp16 bit patterns in a lookup table (LUT). It + // happens that llvm.amdgcn.perm supports selecting 4 bytes from 8 input bytes + // using a 4-byte selector. So the overall idea is to use llvm.amdgcn.perm to + // implement such a LUT; though we need to select the two bytes for the + // resultant bf16/fp16 bit patterns separately. For the byte containing S, we + // also need to handle the S and E bits separately. + + // FP4 has 4 bits: S.EE.M. Bf16/fp16 bit patterns for positive values: + // + // FP4 | BF16 | FP16 | Value + // ------ | ------ | ------ | ----- + // 0.00.0 | 0x0000 | 0x0000 | + 0.0 + // 0.00.1 | 0x3f00 | 0x3800 | + 0.5 + // 0.01.0 | 0x3f80 | 0x3c00 | + 1.0 + // 0.01.1 | 0x3fc0 | 0x3e00 | + 1.5 + // 0.10.0 | 0x4000 | 0x4000 | + 2.0 + // 0.10.1 | 0x4040 | 0x4200 | + 3.0 + // 0.11.0 | 0x4080 | 0x4400 | + 4.0 + // 0.11.1 | 0x40c0 | 0x4600 | + 6.0 + // + // Encode Byte #0 (M) for BF16/FP16 in a LUT. + Value resB0LutLo = tofp16 ? b.i32_val(0) : b.i32_val(0xc0800000); + Value resB0LutHi = tofp16 ? b.i32_val(0) : b.i32_val(0xc0804000); + // Encode Byte #1 (EM, non-S part) for BF16/FP16 in a LUT. + Value resB1LutLoNoS = tofp16 ? b.i32_val(0x3e3c3800) : b.i32_val(0x3f3f3f00); + Value resB1LutHiNoS = tofp16 ? b.i32_val(0x46444240) : b.i32_val(0x40404040); + + Type i32Ty = rewriter.getI32Type(); + auto permU32FnTy = LLVM::LLVMFunctionType::get(i32Ty, {i32Ty, i32Ty, i32Ty}); + LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp( + rewriter, upcastOp, "llvm.amdgcn.perm", permU32FnTy); + + // Start with 8 mxfp4 elements in a single i32 register + // | e7e6 | e5e4 | e3e2 | e1e0 | + Value input = b.bitcast(packedVec, i32Ty); + + // Step 1: extract EM bits for elements 0,2,4,6 and 1,3,5,7 respectively. + // e2m1_6420_idx = | 0[0e6EM] | 0[0e4EM] | 0[0e2EM] | 0[0e0EM] | + Value e2m1_6420_idx = b.and_(input, b.i32_val(0x07070707)); + // e2m1_7531_idx = | [0e7EM]0 | [0e5EM]0 | [0e3EM]0 | [0e1EM]0 | + Value e2m1_7531_idx = b.and_(input, b.i32_val(0x70707070)); + // e2m1_7531_idx = | 0[0e7EM] | 0[0e5EM] | 0[0e3EM] | 0[0e1EM] | + e2m1_7531_idx = b.lshr(e2m1_7531_idx, b.i32_val(4)); + + // Step 2: extract S bit for elements 0,2,4,6 and 1,3,5,7 + // s_6420 = | 0[e6S000] | 0[e4S000] | 0[e2S000] | 0[e0S000] | + Value s_6420 = b.and_(input, b.i32_val(0x08080808)); + // s_6420 = | [e6S000]0 | [e4S000]0 | [e2S000]0 | [e0S000]0 | + s_6420 = b.shl(s_6420, b.i32_val(4)); + // s_7531 = | [e7S000]0 | [e5S000]0 | [e3S000]0 | [e1S000]0 | + Value s_7531 = b.and_(input, b.i32_val(0x80808080)); + + // Step 3: Upcast elements 0,2,4,6 to 4 16-bit elements + // Select Byte #0. It's always 0 if upcasting to fp16. + // resB0_6420 = | e6B0 | e4B0 | e2B0 | e0B0 | + Value resB0_6420 = b.i32_val(0); + if (!tofp16) { + resB0_6420 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {resB0LutHi, resB0LutLo, e2m1_6420_idx}) + .getResult(); + } + // Select Byte #1 + Value resB1NoS_6420 = + LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {resB1LutHiNoS, resB1LutLoNoS, e2m1_6420_idx}) + .getResult(); + // resB1_6420 = | e6B1 | e4B1 | e2B1 | e0B1 | + Value resB1_6420 = b.or_(resB1NoS_6420, s_6420); + // Construct 16-bit values of e0 and e2 + // res_20 = | e2B1 | e2B0 | e0B1 | e0B0 | = | e2_f16 | e0_f16 | + Value res_20 = + LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {resB1_6420, resB0_6420, b.i32_val(0x05010400)}) + .getResult(); + // Construct 16-bit values of e4 and e6 + // res_64 = | e6B1 | e6B0 | e4B1 | e4B0 | = | e6_f16 | e4_f16 | + Value res_64 = + LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {resB1_6420, resB0_6420, b.i32_val(0x07030602)}) + .getResult(); + + // Step 4: Upcast elements 1,3,5,7 to 4 16-bit elements + // This is a copy of step 3 on different group of elements + // Select Byte #0. It's always 0 if upcasting to fp16. + // resB0_7531 = | e7B0 | e5B0 | e3B0 | e1B0 | + Value resB0_7531 = b.i32_val(0); + if (!tofp16) { + resB0_7531 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {resB0LutHi, resB0LutLo, e2m1_7531_idx}) + .getResult(); + } + // Select Byte #1 + Value resB1NoS_7531 = + LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {resB1LutHiNoS, resB1LutLoNoS, e2m1_7531_idx}) + .getResult(); + // resB1_7531 = | e7B1 | e5B1 | e3B1 | e1B1 | + Value resB1_7531 = b.or_(resB1NoS_7531, s_7531); + // Construct 16-bit values of e1 and e3 + // res_31 = | e3B1 | e3B0 | e1B1 | e1B0 | = | e3_f16 | e1_f16 | + Value res_31 = + LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {resB1_7531, resB0_7531, b.i32_val(0x05010400)}) + .getResult(); + // Construct 16-bit values of e5 and e7 + // res_75 = | e7B1 | e7B0 | e5B1 | e5B0 | = | e7_f16 | e5_f16 | + Value res_75 = + LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {resB1_7531, resB0_7531, b.i32_val(0x07030602)}) + .getResult(); + + // Step 5: Reorder 16-bit elements to be 0,1,2,3,4,5,6,7 + // res_10 = | e1_f16 | e0_f16 | + Value res_10 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {res_31, res_20, b.i32_val(0x05040100)}) + .getResult(); + // res_32 = | e3_f16 | e2_f16 | + Value res_32 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {res_31, res_20, b.i32_val(0x07060302)}) + .getResult(); + // res_54 = | e5_f16 | e4_f16 | + Value res_54 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {res_75, res_64, b.i32_val(0x05040100)}) + .getResult(); + // res_76 = | e7_f16 | e6_f16 | + Value res_76 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, + {res_75, res_64, b.i32_val(0x07060302)}) + .getResult(); + + return {res_10, res_32, res_54, res_76}; +} + +SmallVector upcastMxfp4(RewriterBase &rewriter, + amdgpu::UpcastMXFPOp upcastOp, bool toFp16, + ArrayRef values) { + assert(values.size() % 4 == 0); + Location loc = upcastOp.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + SmallVector results; + results.reserve(values.size() * 2); + Type elemType = toFp16 ? f16_ty : bf16_ty; + for (int i = 0; i < values.size(); i += 4) { + Value v0 = values[i]; + Value v1 = values[i + 1]; + Value v2 = values[i + 2]; + Value v3 = values[i + 3]; + Value packedVec = b.undef(vec_ty(i8_ty, 4)); + packedVec = b.insert_element(packedVec, v0, b.i32_val(0)); + packedVec = b.insert_element(packedVec, v1, b.i32_val(1)); + packedVec = b.insert_element(packedVec, v2, b.i32_val(2)); + packedVec = b.insert_element(packedVec, v3, b.i32_val(3)); + SmallVector v4i32 = + upcast8xMxfp4(rewriter, upcastOp, toFp16, packedVec); + for (int j = 0; j < 4; j++) { + Value elements = b.bitcast(v4i32[j], vec_ty(elemType, 2)); + results.push_back(b.extract_element(elements, b.i32_val(0))); + results.push_back(b.extract_element(elements, b.i32_val(1))); + } + } + return results; +} + +Value mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v, Value scale, + bool fastMath) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value scaleF32 = + b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), f32_ty); + Value scaleF16 = + LLVM::AMD::cvtFp32ToFp16(loc, rewriter, scaleF32, RoundingMode::RTNE); + Value mulF16 = b.fmul(v, scaleF16); + if (fastMath) + return mulF16; + // Account for NaN in the scale as per the mxfp specification. + Value scaleIsNan = b.icmp_eq(scale, b.i8_val(0xff)); + Value nanF16 = b.bitcast(b.i16_val(0x7c01), f16_ty); + return b.select(scaleIsNan, nanF16, b.bitcast(mulF16, f16_ty)); +}; + +// Scales the given bf16 v using the given scale factor without relying on bf16 +// multiplication. +// +// In gfx9 architectures, we don't have bf16 VALU ops. So instead this function +// handles v * scale multiplication using fp32 VALU ops. LLVM backend can do it +// for us, just with unnecessary overheads. +Value mxfpScaleBf16ViaF32(RewriterBase &rewriter, Location loc, Value v, + Value scale, bool fastMath) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value c16 = b.i32_val(16); + Value vF32 = + b.bitcast(b.shl(b.zext(i32_ty, b.bitcast(v, i16_ty)), c16), f32_ty); + Value scaleF32 = + b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), f32_ty); + Value mulF32 = b.fmul(vF32, scaleF32); + Value mulI16 = b.trunc(i16_ty, b.lshr(b.bitcast(mulF32, i32_ty), c16)); + Value mulBf16 = b.bitcast(mulI16, bf16_ty); + if (fastMath) + return mulBf16; + // Account for NaN in the scale as per the mxfp specification. + Value scaleIsNan = b.icmp_eq(scale, b.i8_val(0xff)); + Value nanBf16 = b.bitcast(b.i16_val(0x7fff), bf16_ty); + return b.select(scaleIsNan, nanBf16, mulBf16); +}; + +class UpcastMXFPOpPattern + : public ConvertOpToLLVMPattern { +private: + const TargetInfoBase &targetInfo; + +public: + UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(amdgpu::UpcastMXFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto fpType = op.getFpType(); + bool isPacked = fpType == ScaleDotElemType::E2M1; + if (!(isPacked || fpType == ScaleDotElemType::E4M3 || + fpType == ScaleDotElemType::E5M2)) + return rewriter.notifyMatchFailure(op, "NYI: non-mxfp4/mxfp8 cases"); + + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter); + LDBG("x: " << xVals.size() << " x " << xVals.front().getType()); + LDBG("scale: " << scaleVals.size() << " x " << scaleVals.front().getType()); + + // When we lower scaled dot op, we made sure to distribute K only on one + // warp. MXFP spec mandates 1 scale value for every 32 onsecutive values + // along the K dimension. So in total each thread should read 32x main + // element values. + if (xVals.size() != scaleVals.size() * (isPacked ? 16 : 32)) + return rewriter.notifyMatchFailure(op, "unsupported problem size"); + + auto dotEncoding = + cast(op.getSrc().getType().getEncoding()); + auto mfmaEncoding = dyn_cast(dotEncoding.getParent()); + if (!mfmaEncoding) + return rewriter.notifyMatchFailure(op, "NYI: non-mfma dot operand"); + LDBG("mfma: " << mfmaEncoding); + + int mDim = mfmaEncoding.getMDim(); + if (mDim != 32 && mDim != 16) + return rewriter.notifyMatchFailure(op, "NYI: non-mfma32/16 intrinsics"); + + int numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp( + op->getParentOfType()); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + + bool useFp16 = op.getType().getElementType().isF16(); + if (isPacked) { + xVals = upcastMxfp4(rewriter, op, useFp16, xVals); + } + + // Given that MFMA layout for the A tensor arranges thread in a column-major + // manner, for the current tid, it's at row (tid % mDim). When we set up + // blocked layout for the A scale tensor, we made sure that it has a + // threadsPerWarp = [M=mDim, K=64/mDim]. So the threads holding scale values + // for the current thread starts at ((tid % mDim) * (64 / mDim)). + Value offset = + b.mul(b.urem(laneId, b.i32_val(mDim)), b.i32_val(numThreads / mDim)); + + if (mDim == 32) { + // One mfma32 intrinsic processes a 32x8 A tensor slice. Due to how we + // tile, the same warp owns the whole K dim. Inside a warp, each thread + // only holds 4 consecutive elements along K--a 1x4 vector. We need to + // tile the warp 4 times to cover 32 values along K. So for a thread, the + // first 4 1x4 vectors it holds shares the first scale value at row (tid % + // mDim). the second 4 1x4 vectors shares the second scale value at row + // (tid % mDim); and so forth. + std::array scaleThreads = {offset, b.add(offset, b.i32_val(1))}; + + for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { + std::array si = { + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]), + }; + + for (int j = 0; j < 32; ++j) { + int index = 32 * i + j; + xVals[index] = + useFp16 ? mxfpScaleFp16(rewriter, loc, xVals[index], si[j / 16], + op.getFastMath()) + : mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], + si[j / 16], op.getFastMath()); + } + } + } else { + assert(mDim == 16); + // One mfma16 intrinsic processes a 16x16 A tensor slice. Similarly, we + // need to tile the warp 2 times to cover 32 values. So for a thread, the + // first 2 1x4 vectors shares the first scale value at row (tid % mDim). + std::array scaleThreads = {offset, b.add(offset, b.i32_val(1)), + b.add(offset, b.i32_val(2)), + b.add(offset, b.i32_val(3))}; + + for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { + auto si = std::array{ + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[2]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[3]), + }; + + for (int j = 0; j < 32; ++j) { + int index = 32 * i + j; + xVals[index] = useFp16 + ? mxfpScaleFp16(rewriter, loc, xVals[index], + si[j / 8], op.getFastMath()) + : mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], + si[j / 8], op.getFastMath()); + } + } + } + + Value result = + packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType()); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +void mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp new file mode 100644 index 000000000..9105a57d7 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -0,0 +1,631 @@ +#include "Utility.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" + +using mlir::triton::ModuleAxisInfoAnalysis; +using mlir::triton::AMD::DppCtrl; +using mlir::triton::AMD::ISAFamily; +using mlir::triton::gpu::appendOrGetExternFuncOp; +using mlir::triton::gpu::getFunctionType; + +namespace { +enum class ShflKind : uint32_t { + bfly = 0, + up = 1, + down = 2, + idx = 3, +}; + +std::string getTypeString(Type ty) { + std::string str; + llvm::raw_string_ostream rso(str); + ty.print(rso); + rso.flush(); + return str; +} + +std::string mangleFunc(std::string name, Type type) { + auto funcType = dyn_cast(type); + assert(funcType && "Expecting an LLVMFunctionType"); + std::string mangled = name + "_"; + auto retTy = funcType.getReturnType(); + mangled += getTypeString(retTy) + "_"; + auto params = funcType.getParams(); + for (auto paramType : params) { + mangled += getTypeString(paramType) + "_"; + } + return mangled; +} + +// Utility function to create a constant vector mask of length `vecSize` with +// the same `pred` value +Value createVectorMaskFromPredicate(RewriterBase &rewriter, Location loc, + Value pred, int64_t vecSize) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto vecMaskTy = LLVM::getFixedVectorType(rewriter.getI1Type(), vecSize); + Value maskVal = b.undef(vecMaskTy); + for (size_t s = 0; s < vecSize; ++s) { + Value indexVal = + rewriter.create(loc, rewriter.getI64IntegerAttr(s)); + maskVal = b.insert_element(vecMaskTy, maskVal, pred, indexVal); + } + return maskVal; +} + +// Utility function to get the number of elements of a vector or a scalar +int64_t getNumElements(Type ty) { + if (auto vecType = dyn_cast(ty)) + return vecType.getNumElements(); + return 1; +} + +// Utility function to cast the given scalar or vector type to a vector type +Type castToVectorType(Type ty) { + if (isa(ty)) + return ty; + return LLVM::getFixedVectorType(ty, 1); +} + +} // namespace + +namespace mlir::LLVM::AMD { +static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, + ISAFamily isaFamily, Value val, Value i, + int strideInt, ShflKind mode, Value clamp) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned bits = val.getType().getIntOrFloatBitWidth(); + + // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on + // 32bit/dwords so we need promote to 32 here. + auto valType = val.getType(); + if (!valType.isInteger(32) && bits <= 32) { + if (!valType.isIntOrIndex()) + val = b.bitcast(val, int_ty(bits)); + if (bits < 32) + val = b.sext(i32_ty, val); + + val = shuffleCommonImpl(loc, rewriter, isaFamily, val, i, strideInt, mode, + clamp); + + if (bits < 32) + val = b.trunc(int_ty(bits), val); + if (!valType.isIntOrIndex()) + val = b.bitcast(val, valType); + return val; + } + + if (bits == 64) { + Type vecTy = vec_ty(f32_ty, 2); + Value vec = b.bitcast(val, vecTy); + Value val0 = b.extract_element(f32_ty, vec, b.i32_val(0)); + Value val1 = b.extract_element(f32_ty, vec, b.i32_val(1)); + val0 = shuffleCommonImpl(loc, rewriter, isaFamily, val0, i, strideInt, mode, + clamp); + val1 = shuffleCommonImpl(loc, rewriter, isaFamily, val1, i, strideInt, mode, + clamp); + vec = b.undef(vecTy); + vec = b.insert_element(vecTy, vec, val0, b.i32_val(0)); + vec = b.insert_element(vecTy, vec, val1, b.i32_val(1)); + return b.bitcast(vec, val.getType()); + } + + auto mod = rewriter.getBlock()->getParent()->getParentOfType(); + Value threadId = getThreadId(rewriter, loc); + + unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpSize = b.i32_val(iWarpSize); + Value laneId = b.urem(threadId, warpSize); + auto bpermute = [&](Value lane) { + // Multiple lineId by 4. (More on permute instruction semantics: + // https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf#page=180 + Value byteOffset = b.i32_val(2); + Value permuteAddr = b.shl(lane, byteOffset); + return rewriter.create(loc, valType, permuteAddr, val); + }; + + switch (mode) { + case ShflKind::bfly: + if (strideInt > 16) { + Value stride = b.i32_val(32); + Value lineId = b.xor_(threadId, stride); + return bpermute(lineId); + } else if (strideInt == 16) { + Value offset = b.i32_val(0x401F); + return rewriter.create(loc, valType, val, offset); + } else { + if (!llvm::is_contained( + {ISAFamily::CDNA2, ISAFamily::CDNA3, ISAFamily::CDNA4}, + isaFamily)) { + // DPP is only supported for CDNA2/CDNA3/CDNA4 right now, so we fallback + // to ds_swizzle for other architectures. + // + // This map facilates the butterfly shuffle pattern for a stride less + // than 16. The pattern stride is the key of the map. + DenseMap masks{ + {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; + Value offset = b.i32_val(masks[strideInt]); + return rewriter.create(loc, valType, val, offset); + } + + auto createDppOpWithoutBoundCtrl = [&](Value &old, Value &src, + uint32_t dppCtrl, uint32_t rowMask, + uint32_t bankMask) { + return rewriter.create( + loc, valType, old, src, rewriter.getI32IntegerAttr(dppCtrl), + rewriter.getI32IntegerAttr(rowMask), + rewriter.getI32IntegerAttr(bankMask), rewriter.getBoolAttr(false)); + }; + + const int allRows = 0xf; + const int allBanks = 0xf; + + switch (strideInt) { + case 1: { + // quad_perm: 1, 0, 3, 2 + uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); + std::array mask = {1, 0, 3, 2}; + for (int i = 0; i < mask.size(); i++) { + dppCtrl |= mask[i] << (i * 2); + } + return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, + allBanks); + } + case 2: { + // quad_perm: 2, 3, 0, 1 + uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); + std::array mask = {2, 3, 0, 1}; + for (int i = 0; i < mask.size(); i++) { + dppCtrl |= mask[i] << (i * 2); + } + return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, + allBanks); + } + case 4: { + // row_shr:4 bank_mask: 0xa + auto ret = createDppOpWithoutBoundCtrl( + val, val, 4 + static_cast(DppCtrl::ROW_SHR0), + allRows, 0xa) + .getRes(); + + // row_shl:4 bank_mask: 0x5 + return createDppOpWithoutBoundCtrl( + ret, val, 4 + static_cast(DppCtrl::ROW_SHL0), allRows, + 0x5); + } + case 8: { + // row_shr:8 bank_mask: 0xc + auto ret = createDppOpWithoutBoundCtrl( + val, val, 8 + static_cast(DppCtrl::ROW_SHR0), + allRows, 0xc) + .getRes(); + + // row_shl:8 bank_mask: 0x3 + return createDppOpWithoutBoundCtrl( + ret, val, 8 + static_cast(DppCtrl::ROW_SHL0), allRows, + 0x3); + } + default: + assert(false && + "bfly shfl with stride >= 16 should not be handled by dpp."); + } + } + break; + case ShflKind::up: { + Value mask = b.icmp_slt(laneId, i); + Value delta = b.sub(laneId, i); + Value index = b.select(mask, laneId, delta); + return bpermute(index); + } + case ShflKind::idx: + return bpermute(i); + default: + assert(false && "Unsupported ShflKind"); + break; + } + return Value(); +} + +static Value shuffleCommon(Location loc, RewriterBase &rewriter, + ISAFamily isaFamily, Value val, Value i, + int strideInt, ShflKind mode, Value clamp) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // To shuffle pointers, convert them to i64. + Type valTy = val.getType(); + if (isa(valTy)) + val = b.ptrtoint(i64_ty, val); + Value result = shuffleCommonImpl(loc, rewriter, isaFamily, val, i, strideInt, + mode, clamp); + if (isa(valTy)) + result = b.inttoptr(valTy, result); + return result; +} + +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily isaFamily) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, isaFamily, val, b.i32_val(i), i, + ShflKind::bfly, b.i32_val(0x1f)); +} + +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily isaFamily) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, isaFamily, val, b.i32_val(i), i, + ShflKind::up, b.i32_val(0x0)); +} + +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + ISAFamily isaFamily) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleIdx(loc, rewriter, val, b.i32_val(i), isaFamily); +} + +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + ISAFamily isaFamily) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, isaFamily, val, i, 0, ShflKind::idx, + b.i32_val(0x1f)); +} + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis) { + assert(axis >= 0); + assert(axis < 3); + assert(moduleOp); + static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, + mlir::gpu::Dimension::y, + mlir::gpu::Dimension::z}; + Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[axis]); + return rewriter.create(loc, i32_ty, blockId); +} + +Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred, Value falseVal, triton::CacheModifier cm) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + + Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal})); + auto parent = ptr.getParentRegion()->getParentOfType(); + auto getLoadNameRaw = [](triton::CacheModifier cm) { + switch (cm) { + case triton::CacheModifier::CA: + return predicatedLoadCA; + case triton::CacheModifier::CG: + return predicatedLoadCG; + case triton::CacheModifier::CV: + return predicatedLoadCV; + default: + // Do not fail in compile time in the case of unsupported modifier. + // Just apply default config. + return predicatedLoad; + } + }; + + auto funcName = mangleFunc(getLoadNameRaw(cm), funcType); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, parent, funcName, funcType); + return LLVM::createLLVMCallOp(rewriter, loc, funcOp, + ValueRange({ptr, pred, falseVal})) + .getResult(); +} + +void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred, triton::CacheModifier cm) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto ctx = ptr.getContext(); + Type funcType = getFunctionType(void_ty(ctx), ValueRange({ptr, val, pred})); + auto parent = ptr.getParentRegion()->getParentOfType(); + auto getStoreNameRaw = [](triton::CacheModifier cm) { + switch (cm) { + case triton::CacheModifier::WT: + return predicatedStoreWT; + case triton::CacheModifier::CG: + return predicatedStoreCG; + case triton::CacheModifier::CS: + return predicatedStoreCS; + default: + // Do not fail in compile time in the case of unsupported modifier. + // Just apply default config. + return predicatedStore; + } + }; + auto funcName = mangleFunc(getStoreNameRaw(cm), funcType); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, parent, funcName, funcType); + LLVM::createLLVMCallOp(rewriter, loc, funcOp, ValueRange({ptr, val, pred})); +} + +static bool isPredicatedLoadCA(LLVM::CallOp callOp) { + return callOp.getCallee().value().contains(mlir::LLVM::AMD::predicatedLoadCA); +} + +static bool isPredicatedLoadCG(LLVM::CallOp callOp) { + return callOp.getCallee().value().contains(mlir::LLVM::AMD::predicatedLoadCG); +} + +static bool isPredicatedLoadCV(LLVM::CallOp callOp) { + return callOp.getCallee().value().contains(mlir::LLVM::AMD::predicatedLoadCV); +} + +static bool isPredicatedStoreCS(LLVM::CallOp callOp) { + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedStoreCS); +} + +static bool isPredicatedStoreCG(LLVM::CallOp callOp) { + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedStoreCG); +} + +static bool isPredicatedStoreWT(LLVM::CallOp callOp) { + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedStoreWT); +} + +// Utility function that returns flags for a predicated +// Load or Store +// --------------------------------- +// Op | cm | volatile | NT +// -----+-----+--------------------- +// Load | .ca | F | F +// | .cg | F | T +// | .cs | F | T +// | .cv | T | X +// -----+-----+----------+--------- +// Store| .wb | F | F +// | .cg | F | F +// | .cs | F | T +// | .wt | T | X +// -----+-----+----------+--------- +std::pair +getCacheModifierFlagsForPredicatedCall(LLVM::CallOp callOp) { + if (isPredicatedLoadCA(callOp)) + return std::make_pair(false, false); + if (isPredicatedLoadCG(callOp)) + return std::make_pair(false, true); + if (isPredicatedLoadCV(callOp)) + return std::make_pair(true, true); + + if (isPredicatedStoreCG(callOp)) + return std::make_pair(false, false); + if (isPredicatedStoreCS(callOp)) + return std::make_pair(false, true); + if (isPredicatedStoreWT(callOp)) + return std::make_pair(true, true); + // unsupported modifier + return std::make_pair(false, false); +} + +// Create the auxiliary/cachepolicy value of ROCDL::RawPtrBufferLoad/StoreOp +// gfx942 and gfx950: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 +// Vector Memory instructions (Flat, Global, Scratch, and Buffer) have 3 +// bits to control scope and cacheability: +// - SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system +// - NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse +// +// -------+-----+-----+-----+----+-- +// Op | cm | SC1 | SC0 | NT | +// -------+-----+-----+-----+----+-- +// Load | .ca | 0 | 0 | 0 | +// | .cg | 0 | 1 | 1 | +// | .cs | 0 | 1 | 1 | +// | .cv | 1 | 1 | x | +// -------+-----+-----+-----+----+-- +// Store | .wb | 0 | 0 | 0 | +// | .cg | 0 | 0 | 0 | +// | .cs | 0 | 1 | 1 | +// | .wt | 1 | 1 | x | +// -------+-----+-----+-----+----+-- +// Atomic | N/A | 0 | 1 | x | Setting sc0 returns the pre-op value +// | N/A | 1 | 0 | x | Setting sc1 performs a system-scope atomic +// -------+-----+-----+-----+----+-- +static int32_t +getCtrlBitsForCacheModifierOnGFX_942_950(triton::CacheModifier cm, + bool isLoad) { + const int sc0Bit = 0b1, ntBit = 0b10, sc1Bit = 0b10000; + int32_t aux = 0; + switch (cm) { + case triton::CacheModifier::CA: + aux = 0; + break; + case triton::CacheModifier::CG: + if (isLoad) + aux |= sc0Bit | ntBit; + break; + case triton::CacheModifier::CS: + aux |= sc0Bit | ntBit; + break; + case triton::CacheModifier::CV: + assert(isLoad); + aux |= sc0Bit | sc1Bit; + break; + case triton::CacheModifier::WB: + assert(!isLoad); + aux = 0; + break; + case triton::CacheModifier::WT: + assert(!isLoad); + aux |= sc0Bit | sc1Bit; + break; + default: + aux = 0; + } + return aux; +} + +int32_t getCtrlBitsForBufferAtomicsOnGFX_942_950(bool setSC0, bool setSC1, + bool setNT) { + const int sc0Bit = 0b1, ntBit = 0b10, sc1Bit = 0b10000; + int32_t aux = 0; + if (setSC0) + aux |= sc0Bit; + if (setSC1) + aux |= sc1Bit; + if (setNT) + aux |= ntBit; + return aux; +} + +static int32_t getDefaultCtrlBitsForCacheModifier(triton::CacheModifier cm) { + return 0; +} + +// Cache modifiers changes how data is managed in the GPU's cache hierarchy: +// .ca: cache at all levels with LRU policy +// .cg: cache at L2, can use .ca or .cs +// .cs: cache streaming, use data once +// .cv: don't cache and fetch again +// .wb: write-back, writes back data at all cache levels +// .wt: write-through, write data directly to system memory +int32_t getCtrlBitsForCacheModifierOnTarget( + triton::CacheModifier cm, bool isLoad, + const mlir::triton::AMD::TargetInfo &targetInfo) { + switch (targetInfo.getGPUKind()) { + case llvm::AMDGPU::GK_GFX942: + case llvm::AMDGPU::GK_GFX950: + return getCtrlBitsForCacheModifierOnGFX_942_950(cm, isLoad); + default: + return getDefaultCtrlBitsForCacheModifier(cm); + } +} + +Value cvtFp32ToFp16(Location loc, RewriterBase &rewriter, const Value &v, + triton::RoundingMode rounding) { + if (rounding == triton::RoundingMode::RTNE) { + LLVM::RoundingMode rm = LLVM::RoundingMode::NearestTiesToEven; + return rewriter.create( + loc, f16_ty, v, rm, LLVM::FPExceptionBehavior::Ignore); + } + + // TODO: Figure out the test failure with RTZ LLVM::ConstrainedFPTruncIntr and + // switch to not use inline assembly too. + assert(rounding == triton::RoundingMode::RTZ); + GCNBuilder builder; + + auto &cvt = *builder.create("v_cvt_f16_f32"); + auto res = builder.newOperand("=v"); + auto operand = builder.newOperand(v, "v"); + auto &setRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0xc"); + setRTZ(); + cvt(res, operand); + auto &resetRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0x0"); + resetRTZ(); + return builder.launch(rewriter, loc, f16_ty, false); +} + +Type getPointerTypeWithShape(Value basePtr, Value offset) { + Type basePtrType = basePtr.getType(); + auto offsetType = cast(offset.getType()); + return offsetType.cloneWith(std::nullopt, basePtrType); +} + +unsigned getContiguity(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + return axisAnalysisPass.getContiguity(ptr); +} + +unsigned getContiguity(Value ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass) { + + Type type = getPointerTypeWithShape(ptr, offset); + RankedTensorType tensorTy = cast(type); + + // To compute the contiguity of the scalar/warp-uniform ptr and offset pair we + // need to look at the contiguity of the offsets and the alignment of the ptr + auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + auto contiguity = axisAnalysisPass.getContiguity(offset, elemNumBits); + + // To get the alignment of the scalar ptr we need to look at the divisibility + auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr); + auto maxMultipleBytes = axisInfo->getDivisibility(0); + auto elemNumBytes = std::max(elemNumBits / 8, 1); + auto align = std::max(maxMultipleBytes / elemNumBytes, 1); + + // FIXME (Alex): this should not be needed anymore because it's done inside + // getContiguity, but we have an order issues with LL, so we keep this + // until the LL order issue is fixed + auto layout = tensorTy.getEncoding(); + auto linearLayout = triton::gpu::toLinearLayout(tensorTy.getShape(), layout); + auto llAttr = + triton::gpu::LinearEncodingAttr::get(tensorTy.getContext(), linearLayout); + auto order = triton::gpu::getOrder(tensorTy); + auto contigPerThread = llAttr.getContigPerThread(); + assert(order[0] < contigPerThread.size() && + "Unexpected contigPerThread size"); + contiguity = std::min(contiguity, contigPerThread[order[0]]); + + // Final contiguity is a min of the offset contiguity and pointer alignment + return std::min(align, contiguity); +} + +unsigned getVectorSize(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto contiguity = getContiguity(ptr, axisAnalysisPass); + auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); + return std::min(128 / pointeeBitWidth, contiguity); +} + +unsigned getVectorSize(Value ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass) { + auto contiguity = getContiguity(ptr, offset, axisAnalysisPass); + auto pointeeBitWidth = triton::getPointeeBitWidth(ptr.getType()); + return std::min(128 / pointeeBitWidth, contiguity); +} + +Type scaleDotElemTypeToMLIRType(MLIRContext *ctx, triton::ScaleDotElemType t) { + switch (t) { + case triton::ScaleDotElemType::FP16: + return Float16Type::get(ctx); + case triton::ScaleDotElemType::BF16: + return BFloat16Type::get(ctx); + case triton::ScaleDotElemType::E4M3: + return Float8E4M3FNType::get(ctx); + case triton::ScaleDotElemType::E5M2: + return Float8E5M2Type::get(ctx); + case triton::ScaleDotElemType::E3M2: + return Float6E3M2FNType::get(ctx); + case triton::ScaleDotElemType::E2M3: + return Float6E2M3FNType::get(ctx); + case triton::ScaleDotElemType::E2M1: + return Float4E2M1FNType::get(ctx); + default: + llvm_unreachable("unsupported ScaleDotElemType!"); + } +} + +bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter, + RankedTensorType srcTy, + triton::gpu::MemDescType dstTy, + unsigned vectorSize) { + auto shape = srcTy.getShape(); + LinearLayout srcLayout = + triton::gpu::toLinearLayout(shape, srcTy.getEncoding()); + LinearLayout sharedLayout = + triton::gpu::toLinearLayout(shape, dstTy.getEncoding()); + LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout); + + StringAttr kLane = rewriter.getStringAttr("lane"); + for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) { + auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0]; + unsigned expected = vectorSize * (1 << inLane); + if (basis != expected) { + LDBG("detected uncoalesced layout from blocked to shared in async copy " + "for lane " + << 1 + inLane << "; given " << basis << " but expected " + << expected); + return false; + } + } + return true; +} + +} // namespace mlir::LLVM::AMD diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h new file mode 100644 index 000000000..e332a81e2 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -0,0 +1,97 @@ +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_ + +#include "TargetInfo.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "TritonAMDGPUToLLVM/TargetUtils.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir::LLVM::AMD { + +const char predicatedLoad[] = "__predicated_load"; +const char predicatedLoadCA[] = "__predicated_load_CA"; +const char predicatedLoadCG[] = "__predicated_load_CG"; +const char predicatedLoadCV[] = "__predicated_load_CV"; +const char predicatedStore[] = "__predicated_store"; +const char predicatedStoreCG[] = "__predicated_store_CG"; +const char predicatedStoreCS[] = "__predicated_store_CS"; +const char predicatedStoreWT[] = "__predicated_store_WT"; + +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + mlir::triton::AMD::ISAFamily isaFamily = + mlir::triton::AMD::ISAFamily::Unknown); + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis); + +// Loads from shared or global memory with predication. +// `otherElems` is used to mask out the elements that are not loaded +Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred, Value falseVal, + triton::CacheModifier cm = triton::CacheModifier::NONE); + +// Stores to shared or global memory with predication. +void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred, + triton::CacheModifier cm = triton::CacheModifier::NONE); + +// Get cache modifier information for creating load or store instruction +// Get flags for a predicated Load or Store +std::pair getCacheModifierFlagsForPredicatedCall(LLVM::CallOp); +// Get the cachepolicy value for a cache modifier +int32_t +getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier, bool, + const mlir::triton::AMD::TargetInfo &); + +// Get cache modifier information for buffer atomics +int32_t getCtrlBitsForBufferAtomicsOnGFX_942_950(bool setSC0, bool setSC1, + bool setNT); + +Value cvtFp32ToFp16(Location loc, RewriterBase &rewriter, const Value &v, + triton::RoundingMode rounding); + +// Return a tensor of pointers with the same type of `basePtr` and the same +// shape of `offset` +Type getPointerTypeWithShape(Value basePtr, Value offset); + +// Get contiguity for a tensor pointer `ptr` +unsigned getContiguity(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass); + +// Get contiguity for a scalar pointer `ptr` and a tensor `offset` +unsigned getContiguity(Value ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass); + +// Determine the vector size of a tensor of pointers +unsigned getVectorSize(Value ptr, ModuleAxisInfoAnalysis &axisAnalysisPass); + +// Given a scalar pointer and a tensor of offsets, determine the vector size +unsigned getVectorSize(Value ptr, Value offset, + ModuleAxisInfoAnalysis &axisAnalysisPass); + +Type scaleDotElemTypeToMLIRType(MLIRContext *ctx, triton::ScaleDotElemType t); + +// Returns true if we can perform coalesced write from the source encoding to +// the destination encoding. +bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter, + RankedTensorType srcTy, + triton::gpu::MemDescType dstTy, + unsigned vectorSize); + +} // namespace mlir::LLVM::AMD + +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_ diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp new file mode 100644 index 000000000..f1c816b28 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -0,0 +1,1310 @@ +#include "TritonAMDGPUToLLVM/TargetUtils.h" +#include "TritonAMDGPUTransforms/MfmaGroup.h" +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" +#include "llvm/ADT/TypeSwitch.h" +#include + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +using ::mlir::LLVM::AMD::scaleDotElemTypeToMLIRType; +using mlir::triton::gpu::chooseScaledMfmaOperandLayout; +using mlir::triton::gpu::chooseScaledMfmaScaleLayout; + +namespace { +using triton::AMD::ISAFamily; + +int getMfmaVersion(ISAFamily isaFamily) { + switch (isaFamily) { + case ISAFamily::CDNA1: + return 1; + case ISAFamily::CDNA2: + return 2; + case ISAFamily::CDNA3: + return 3; + case ISAFamily::CDNA4: + return 4; + default: + break; + } + return 0; +} + +int getWmmaVersion(StringRef archGen) { + if (archGen.contains("gfx11")) + return 1; + if (archGen.contains("gfx12")) + return 2; + return 0; +} + +FailureOr mlirTypeToScaledElemType(Type type) { + return llvm::TypeSwitch>(type) + .Case([](Type) { return ScaleDotElemType::E4M3; }) + .Case([](Type) { return ScaleDotElemType::E5M2; }) + .Case([](Type) { return ScaleDotElemType::E3M2; }) + .Case([](Type) { return ScaleDotElemType::E2M3; }) + .Case([](Type) { return ScaleDotElemType::E2M1; }) + .Default([](Type) { return failure(); }); +} + +// Check if the result of this tl.dot is used as opA of another tl.dot +// in the same region +bool isChainDotHead(tt::DotOpInterface dotOp) { + auto isInSameRegion = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion(); + }; + ForwardSliceOptions fwdOpt; + fwdOpt.filter = isInSameRegion; + SetVector fwdSlices; + getForwardSlice(dotOp, &fwdSlices, fwdOpt); + for (Operation *op : fwdSlices) { + if (auto dOp = dyn_cast(op)) { + assert(dOp != dotOp); + auto opA = dOp.getA().getDefiningOp(); + if (opA && fwdSlices.contains(opA)) { + return true; + } + } + } + return false; +} + +// Check if the opA of this tl.dot is the result of another tl.dot +// in the same region +bool isChainDotTail(tt::DotOpInterface dotOp) { + auto isInSameRegion = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion(); + }; + BackwardSliceOptions bwdOpt; + bwdOpt.omitBlockArguments = true; + bwdOpt.filter = isInSameRegion; + SetVector bwdSlices; + Operation *opA = dotOp.getA().getDefiningOp(); + if (!opA) + return false; + getBackwardSlice(opA, &bwdSlices, bwdOpt); + if (llvm::find_if(bwdSlices, [](Operation *op) { + return isa(op); + }) != bwdSlices.end()) + return true; + return false; +} + +SmallVector +warpsPerTile(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { + auto rank = shape.size(); + // Case 1: Early exit for batched matmul + if (rank == 3) + return {static_cast(numWarps), 1, 1}; + + // Case 2: For FA-like pattern, i.e. result of 1st tl.dot is used as the opA + // of the 2nd dot, we will set warpsPerCTA differently for 1st and 2nd dot + auto ttDotOp = cast(dotOp); + bool isHeadDot = isChainDotHead(ttDotOp); + bool isTailDot = isChainDotTail(ttDotOp); + // For the 1st dot in chain-dot, we always set warpsPerCTA={numWarps, 1} + // because this eliminates + // 1) inter-warp reduction in the softmax step. + // 2) layout conversion from #mma to #dot_op of the second dot. + if (isHeadDot) + return {static_cast(numWarps), 1}; + // For the 2nd dot in chain-dot, we always distribute warp along dim0 first, + // then dim1. Because + // 1) This is how we distribute the warps for the 1st dot. Now the + // warpsPerCTA for the 1st dot become the warp layout of the dotOperand + // layout of the 2nd dot, which must match the warpsPerCTA of the 2nd dot. + // 2) When shape[0] is small, as in decode kernels, we don't want to + // distribute more warps than shape[0] // mDim. If we do so, each warp + // needs to hold more elements in the final output, which increases + // register pressure, especially for large head dim (e.g. 512) attention + // kernels. + if (isTailDot) { + SmallVector ret = {1, 1}; + ret[0] = static_cast(std::min( + static_cast(numWarps), + static_cast(llvm::divideCeil(shape[0], shapePerWarp.first)))); + ret[1] = numWarps / ret[0]; + return ret; + } + + // Case 3: Regular cases + SmallVector tensorShape = {shape[0], shape[1]}; + SmallVector ret = {1, 1}; + do { + if (ret[0] * ret[1] >= numWarps) + break; + if (tensorShape[0] / (shapePerWarp.first * 2) / ret[0] >= + tensorShape[1] / shapePerWarp.second / ret[1]) { + if (ret[0] < tensorShape[0] / shapePerWarp.first) { + ret[0] *= 2; + } else { + ret[1] *= 2; + } + } else { + ret[1] *= 2; + } + } while (true); + + if (ret[1] * shapePerWarp.second > tensorShape[1]) { + return {ret[1], ret[0]}; + } + + return ret; +} + +SmallVector +warpsPerTileMFMA(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { + return warpsPerTile(dotOp, shape, numWarps, shapePerWarp); +} + +SmallVector +warpsPerTileWMMA(Operation *dotOp, ArrayRef shape, int numWarps) { + auto mnk = ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr(); + return warpsPerTile(dotOp, shape, numWarps, {mnk[0], mnk[1]}); +} + +// Chooses a proper MFMA instruction that can used to compute the given dot op. +// If enforcedNonKDim is not zero, it will be used to overwrite the default +// logic to choose a MFMA with matching M/N dim. +FailureOr +chooseMfmaInstruction(int mfmaVersion, RankedTensorType cType, Type aElemType, + Type bElemType, int inputKSize, int enforcedNonKDim, + bool withScale, bool allowXF32) { + // number of matrix elements along k dim per one MFMA intruction + unsigned kDim = 0; + + auto resShape = cType.getShape(); + auto rank = resShape.size(); + auto M = resShape[rank - 2]; + auto N = resShape[rank - 1]; + + unsigned mDim = 0; + unsigned nDim = 0; + if (enforcedNonKDim != 0) { + mDim = nDim = enforcedNonKDim; + } else { + int minSize = std::min(M, N); + if (minSize >= 32) { + mDim = 32; + nDim = 32; + } + if (minSize >= 16 && minSize < 32) { + mDim = 16; + nDim = 16; + } + } + if (mDim == 0 || nDim == 0) + return failure(); + + FailureOr maybeMfmaIntrinsic = + MfmaIntrinsic::selectFor(mfmaVersion, mDim, nDim, inputKSize, aElemType, + bElemType, withScale, allowXF32); + if (failed(maybeMfmaIntrinsic)) + llvm::report_fatal_error("No match found in MFMA database\n"); + + kDim = maybeMfmaIntrinsic->kDim; + assert(kDim != 0); + assert(enforcedNonKDim != 0 || (M % mDim == 0 && N % nDim == 0)); + // if inputKSize % kDim != 0 this layout will introduce data duplication, + // consider FMA dot is prefered, except cases MFMA layout is enforced. + if (enforcedNonKDim == 0 && inputKSize % kDim != 0) + return failure(); + return maybeMfmaIntrinsic; +} + +FailureOr chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion, + int nonKDim, + bool withScale = false) { + RankedTensorType aType = dot.getA().getType(); + bool allowXF32 = + dot.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3; + return chooseMfmaInstruction( + mfmaVersion, dot.getC().getType(), aType.getElementType(), + dot.getB().getType().getElementType(), aType.getShape().back(), nonKDim, + withScale, allowXF32); +} + +FailureOr chooseMfmaInstruction(tt::DotScaledOp dot, + int mfmaVersion, int nonKDim) { + auto ctx = dot.getContext(); + int64_t inputKDim = dot.getA().getType().getShape().back(); + if (dot.getAElemType() == ScaleDotElemType::E2M1) { + // Since two fp4 are packed into int8, to get the correct K dim size, we + // need to multiply it by 2. + inputKDim *= 2; + } + Type aElemType = scaleDotElemTypeToMLIRType(ctx, dot.getAElemType()); + Type bElemType = scaleDotElemTypeToMLIRType(ctx, dot.getBElemType()); + return chooseMfmaInstruction(mfmaVersion, dot.getC().getType(), aElemType, + bElemType, inputKDim, nonKDim, + /*withScale=*/true, /*allowXF32=*/false); +} + +FailureOr chooseMfmaInstruction(tt::DotScaledOp dot, + int mfmaVersion, int nonKDim, + bool useFp16) { + // For scaled dot, we handle it with fp16 or bf16 emulation for now. + Builder b(dot.getContext()); + Type elemType = useFp16 ? b.getF16Type() : b.getBF16Type(); + return chooseMfmaInstruction(mfmaVersion, dot.getC().getType(), elemType, + elemType, dot.getA().getType().getShape().back(), + nonKDim, + /*withScale=*/false, /*allowXF32=*/false); +} + +using OperandTypesVector = SmallVector; +OperandTypesVector +selectMatrixCoreOperandTypes(tt::DotOp dot, + ArrayRef applicableTypes) { + SmallVector dotOperands = {dot.getA(), dot.getB(), dot.getC(), + dot.getD()}; + OperandTypesVector initElemTypes; + llvm::transform(dotOperands, std::back_inserter(initElemTypes), [](Value v) { + return cast(v.getType()).getElementType(); + }); + + // Use simple costmodel to define optimal set of the dot operands. + // Most expensive - accuracy loss conversions: + // - any larger type -> any smaller type; + // - float -> int; + // - int -> float (not supported for now); + // - signed int -> unsigned int; + // - unsigned int -> signed int with same or less size. + // They are never performed, better to use FMA. + // Supported conversion for now costs `1`, no conversion costs `0`. + // The model could be improved in the future. For example taken into account + // chain dot could be detected and result conversion score is decreased. + int maxConvertCost = + std::numeric_limits::max() / applicableTypes.front().size(); + auto calcConvertCost = [&](Type fromTy, Type toTy) -> int32_t { + if (fromTy == toTy) + return 0; + + // Skip conversion between int and float. Int16/int32 cases are lowered to + // FMA. + if (fromTy.isIntOrIndex() != toTy.isIntOrIndex()) + return maxConvertCost; + + if (fromTy.isIntOrIndex() && toTy.isIntOrIndex() && + fromTy.isUnsignedInteger() != toTy.isUnsignedInteger()) + return fromTy.isUnsignedInteger() && fromTy.getIntOrFloatBitWidth() < + toTy.getIntOrFloatBitWidth() + ? 1 + : maxConvertCost; + + return fromTy.getIntOrFloatBitWidth() <= toTy.getIntOrFloatBitWidth() + ? 1 + : maxConvertCost; + }; + auto minCost = maxConvertCost; + auto optTypes = OperandTypesVector(); + for (auto types : applicableTypes) { + assert(types.size() == initElemTypes.size()); + int accumulatedConvertCost = 0; + for (int i = 0; i < initElemTypes.size(); ++i) { + accumulatedConvertCost += calcConvertCost(initElemTypes[i], types[i]); + } + if (accumulatedConvertCost < minCost) { + minCost = accumulatedConvertCost; + optTypes = types; + } + } + return optTypes; +} + +OperandTypesVector getOperandTypesForWmmaOp(PatternRewriter &rewriter, + tt::DotOp dot, int version) { + Type f16 = rewriter.getF16Type(); + Type f32 = rewriter.getF32Type(); + Type bf16 = rewriter.getBF16Type(); + Type i8 = rewriter.getIntegerType(8); + Type i32 = rewriter.getIntegerType(32); + SmallVector applicableTypes = { + // clang-format off + {f16, f16, f32, f32}, + {f16, f16, f16, f16}, + {bf16, bf16, f32, f32}, + {bf16, bf16, bf16, bf16}, + {i8, i8, i32, i32}, + // i4, i4, i32, i32 - is supported configuration + // by WMMA instruction, but not supported by triton + // clang-format on + }; + // TODO: support fp8 configurations for WMMAv2. The code should be as + // following: + // if (version == 2) { + // Type fp8 = rewriter.getFp8Type(); + // Type bf8 = rewriter.getBF8Type(); + // applicableTypes.append({ + // // clang-format off + // {fp8, fp8, f32, f32}, + // {fp8, bf8, f32, f32}, + // {bf8, fp8, f32, f32}, + // {bf8, bf8, f32, f32}, + // // clang-format on + // }); + // } + return selectMatrixCoreOperandTypes(dot, applicableTypes); +} + +//===---------------------------------------------------------------------===// +// @brief Convert layout and cast element type of a given tensor +// +// If old element type is different from new element type, this function +// creates two new operations: +// 1. %converted_value = layout_convert %value, newEncoding +// 2. %casted_value = cast(fext, ftrunc, etc.) %value, newElemType +// +// If old element type is same as new element type, this function creates only +// one operation: %converted_value = layout_convert %value, newEncoding +// +// @param rewriter +// @param value original tensor value, which we need to convert and cast +// @param newEncoding new encoding for the tensor +// @param newElemType new element type for the tensor +// @return converted and optionally casted tensor value +//===---------------------------------------------------------------------===// +Value convertAndCastTensor(PatternRewriter &rewriter, Value value, + Attribute newEncoding, Type newElemType) { + assert(newElemType.isIntOrFloat()); + + auto loc = value.getLoc(); + auto oldType = cast(value.getType()); + auto oldElemType = oldType.getElementType(); + + assert(oldElemType.isIntOrFloat()); + assert(oldElemType.isIntOrIndex() == newElemType.isIntOrIndex()); + + auto convertedType = + RankedTensorType::get(oldType.getShape(), oldElemType, newEncoding); + + Value convertedTensor = + rewriter.create(loc, convertedType, value); + + if (newElemType == oldElemType) + return convertedTensor; + + Type castedType = convertedType.cloneWith(std::nullopt, newElemType); + + Value castedTensor; + + if (newElemType.isIntOrIndex()) { + unsigned oldWidth = oldElemType.getIntOrFloatBitWidth(); + unsigned newWidth = newElemType.getIntOrFloatBitWidth(); + if (oldWidth == newWidth) + castedTensor = rewriter.create(loc, convertedType, + convertedTensor); + else if (oldWidth > newWidth) + castedTensor = + rewriter.create(loc, castedType, convertedTensor); + else if (oldElemType.isSignedInteger()) + castedTensor = + rewriter.create(loc, castedType, convertedTensor); + else + castedTensor = + rewriter.create(loc, castedType, convertedTensor); + } else { + if (oldElemType.isF16() && newElemType.isF32()) + castedTensor = + rewriter.create(loc, castedType, convertedTensor); + else if (oldElemType.isF32() && newElemType.isF16()) + castedTensor = + rewriter.create(loc, castedType, convertedTensor); + else + castedTensor = + rewriter.create(loc, castedType, convertedTensor); + } + return castedTensor; +} + +class BlockedToMFMA : public OpRewritePattern { + int mfmaVersion; + int nonKDim; + int kPack; + +public: + BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), + nonKDim(nonKDim), kPack(kPack) {} + + LogicalResult matchAndRewrite(tt::DotOp dotOp, + PatternRewriter &rewriter) const override { + RankedTensorType oldRetType = dotOp.getType(); + if (!oldRetType.getEncoding() || + !isa(oldRetType.getEncoding())) + return failure(); + if (!isa_and_nonnull(dotOp.getType().getEncoding())) + return rewriter.notifyMatchFailure( + dotOp, "expected blocked encoding result tensor"); + + auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding()); + + // get MFMA encoding for the given number of warps + auto retShape = oldRetType.getShape(); + int numWarps = ttg::lookupNumWarps(dotOp); + + // operands + Value a = dotOp.getA(); + Value b = dotOp.getB(); + auto oldAType = cast(a.getType()); + auto oldBType = cast(b.getType()); + auto ctx = oldAType.getContext(); + + Type aElemType = oldAType.getElementType(); + Type bElemType = oldBType.getElementType(); + bool withScale = + mfmaVersion == 4 && isF8F6F4(aElemType) && isF8F6F4(bElemType); + + // If mfmaVersion == 4 and both inputs are of F8F6F4 types, we will try to + // use the V_MFMA_*_F8F6F4 instructions since it has higher FLOPs per cycle. + // If we can't find a proper instruction, we will fall back to select from + // normal mfma instructions. + FailureOr mfmaInstr = + chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, withScale); + if (failed(mfmaInstr)) { + if (!withScale) { + return failure(); + } + mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, false); + if (failed(mfmaInstr)) + return failure(); + + withScale = false; + } + + auto mDim = mfmaInstr->mDim; + auto nDim = mfmaInstr->nDim; + auto kDim = mfmaInstr->kDim; + auto kBase = mfmaInstr->kBase; + + auto warpsPerTile = + warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim}); + + // Use transposed mfma layout to enable larger vectorization for global + // store instructions, except for fp8 matmul kernels due to regression + // TODO (lixun): investigate the regression and enable this feature again + auto aElemTy = mfmaInstr->aElementType; + bool isFP8 = llvm::isa(aElemTy); + bool isTransposed = + isChainDotHead(dotOp) || isChainDotTail(dotOp) || !isFP8; + ttg::AMDMfmaEncodingAttr mfmaEnc = ttg::AMDMfmaEncodingAttr::get( + oldRetType.getContext(), + /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile, + /*instrShape*/ mDim, nDim, isTransposed, CTALayout); + + Type mfmaAccType; + if (oldRetType.getElementType().isIntOrIndex()) + mfmaAccType = rewriter.getIntegerType(32); + else + mfmaAccType = rewriter.getF32Type(); + + // convert accumulator + auto oldAcc = dotOp.getC(); + auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType); + + // Here is a brief explanation of kWidth, kBase, and kDim + // 1. kWidth: the number of elements each thread loads from shared memory in + // preparation for mfma instructions. In theory each thread can issue one + // or more load instructions to load a total of kWidth elements, since + // those elements are not required to be in contiguous addresses in + // shared memory. But in practice, we make sure the kWidth elements can + // be loaded from shared memory by a single ds_read instruction by + // setting vecSize of the sharedLayout to be kWidth. + // 2. kDim: the k dimension size of the mfma instruction. E.g. instruction + // mfma_32x32x16 has kDim = 16, meaning this mfma instruction can compute + // a matmul of operands with shape 32x16 and 16x32. + // 3. kBase: the number of elements each thread holds for a single mfma + // instruction. + // 4. relation between kBase and kDim: + // 4.1 For mfma_32, kBase = kDim / 2 + // 4.2 For mfma_16, kBase = kDim / 4 + // 4.3 For mfma_4, it depends on how mfma_4 is used. We'll extend to + // mfma_4 later. + // 5. relation between kWidth and kBase: For now it supports two cases + // 5.1 kWidth = kBase, i.e. kPack = 1. In this case, each load from + // shared memory results in one mfma instruction. + // 5.2 kWidth = 2 * kBase, i.e. kPack = 2. In this case, each load from + // shared memory results in two mfma instructions, since one mfma + // can only consume kBase elements from each thread. + // Note that we cannot have larger kPack since kPack = 2 means + // ds_read_b128, which is the largest vector size for shared memory load. + auto kWidth = kBase; + // in mfma 4x4 case argument matrix groups in 16 groups + if (mDim == 4 && nDim == 4) + kWidth = kDim / 16; + if ((mDim == 4 && nDim == 64) || (mDim == 64 && nDim == 4)) + kWidth = kDim; + + // We want to extend kWidth by kPack (kPack=1 means no extension) + // to increase ds_read vector size + // However, in FA, the second dot can only use kWidth = kBase since it's + // limited by the result of the first dot, which is of mfmaLayout. + if (!isChainDotTail(dotOp)) + kWidth *= kPack; + + Value newDot; + if (withScale) { + // If a scaled mfma instruction is chosen, we will rewrite the DotOp to a + // DotScaledOp. + auto aScaledElemTy = mlirTypeToScaledElemType(aElemType); + auto bScaledElemTy = mlirTypeToScaledElemType(bElemType); + if (failed(aScaledElemTy) || failed(bScaledElemTy)) + return failure(); + + auto aEncLL = chooseScaledMfmaOperandLayout( + mfmaEnc, kWidth, /*dotOperandIdx=*/0, aScaledElemTy.value(), + oldAType.getShape()); + auto bEncLL = chooseScaledMfmaOperandLayout( + mfmaEnc, kWidth, /*dotOperandIdx=*/1, bScaledElemTy.value(), + oldBType.getShape()); + auto newAEncoding = ttg::LinearEncodingAttr::get(ctx, aEncLL); + auto newBEncoding = ttg::LinearEncodingAttr::get(ctx, bEncLL); + + a = convertAndCastTensor(rewriter, a, newAEncoding, + mfmaInstr->aElementType); + b = convertAndCastTensor(rewriter, b, newBEncoding, + mfmaInstr->bElementType); + newDot = rewriter.create( + dotOp.getLoc(), newAcc.getType(), a, b, newAcc, Value(), Value(), + aScaledElemTy.value(), bScaledElemTy.value(), /*fastMath=*/false); + } else { + auto newAEncoding = + ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth); + auto newBEncoding = + ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth); + a = convertAndCastTensor(rewriter, a, newAEncoding, + mfmaInstr->aElementType); + b = convertAndCastTensor(rewriter, b, newBEncoding, + mfmaInstr->bElementType); + newDot = rewriter.create(dotOp.getLoc(), newAcc.getType(), a, + b, newAcc, dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc()); + } + + Value dotOutput = + convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(), + oldRetType.getElementType()); + + rewriter.replaceOp(dotOp, dotOutput); + + return success(); + } +}; + +class ScaledBlockedToMFMA final : public OpRewritePattern { + int mfmaVersion; + int nonKDim; + int kPack; + +public: + ScaledBlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, + int kPack, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), + nonKDim(nonKDim), kPack(kPack) {} + + LogicalResult matchAndRewrite(triton::DotScaledOp dotOp, + PatternRewriter &rewriter) const override { + using TensorValue = TypedValue; + + RankedTensorType oldRetType = dotOp.getType(); + if (!isa_and_nonnull(oldRetType.getEncoding())) + return rewriter.notifyMatchFailure( + dotOp, "expected blocked encoding result tensor"); + unsigned rank = oldRetType.getRank(); + if (rank == 3) + return rewriter.notifyMatchFailure(dotOp, "NYI: 3d case"); + + TensorValue a = dotOp.getA(); + TensorValue b = dotOp.getB(); + TensorValue aScale = dotOp.getAScale(); + TensorValue bScale = dotOp.getBScale(); + if (aScale && bScale) + return rewriter.notifyMatchFailure(dotOp, "NYI: both LHS and RHS scale"); + + ScaleDotElemType aElemType = dotOp.getAElemType(); + ScaleDotElemType bElemType = dotOp.getBElemType(); + auto supportsTypes = [](ScaleDotElemType elemType) { + return elemType == ScaleDotElemType::E2M1 || + elemType == ScaleDotElemType::E4M3 || + elemType == ScaleDotElemType::E5M2 || + elemType == ScaleDotElemType::BF16 || + elemType == ScaleDotElemType::FP16; + }; + if (!supportsTypes(aElemType) || !supportsTypes(bElemType)) + return rewriter.notifyMatchFailure(dotOp, "NYI: mxfp6 operand"); + + MLIRContext *ctx = dotOp.getContext(); + auto moduleOp = dotOp->getParentOfType(); + int numWarps = ttg::lookupNumWarps(dotOp); + + ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding()); + int numThreads = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp); + + // Choose a suitable MFMA instruction for this scaled dot op. + bool useFp16 = aElemType == ScaleDotElemType::FP16 || + bElemType == ScaleDotElemType::FP16; + FailureOr mfmaInstr = + chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, useFp16); + if (failed(mfmaInstr)) + return rewriter.notifyMatchFailure(dotOp, "cannot choose mfma intrinsic"); + + if (useFp16) { + dotOp.emitRemark( + "Warning: detected one dot_scaled operand is fp16 tensor so " + "upcasting to fp16 for computation, which impacts precision; " + "experimental behavior and may change in future"); + } + + unsigned mDim = mfmaInstr->mDim; + unsigned nDim = mfmaInstr->nDim; + unsigned kDim = mfmaInstr->kDim; + unsigned kBase = mfmaInstr->kBase; + + // For mxfp4 A/B tensor, we pack every two values into one int8 value there. + // For such cases, we have different initial kWidth for LHS and RHS, which + // will be "fixed" later by using upcast_mxfp to convert LHS to unpacked + // values. For such packed cases, we cannot support flexible kPack choices + // from the developer--it just does not apply here. So mandate the choice + // here. + bool isAPacked = aElemType == ScaleDotElemType::E2M1; + bool isBPacked = bElemType == ScaleDotElemType::E2M1; + bool isPacked = isAPacked || isBPacked; + unsigned kWidths[] = {isPacked ? (isAPacked ? 4 : 8) : kBase * kPack, + isPacked ? (isAPacked ? 8 : 4) : kBase * kPack}; + + // For A/B tensor, 32 consecutive elements along K dim share the same scale. + // We'd like to keep the scale values together with the base values in the + // same warp to avoid cross-warp data exchange. It means we want warpsPerCTA + // = 1 along the N/M dimension for the mxfp A/B case. We achieve that by + // setting the M/N dimension as numWarps. + SmallVector mfmaWarpsPerCTA(rank, 1); + mfmaWarpsPerCTA[aScale ? 0 : 1] = numWarps; + + // Always use transposed mfma layout. This enables larger vectorization + // for global store instructions. + auto mfmaEnc = ttg::AMDMfmaEncodingAttr::get( + ctx, /*versionMajor=*/mfmaVersion, /*versionMinor=*/0, mfmaWarpsPerCTA, + /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout); + + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), mfmaEnc); + + auto newAcc = rewriter.create( + dotOp.getC().getLoc(), newRetType, dotOp.getC()); + + auto upcastForMMA = [&](TensorValue v, int idx, + ScaleDotElemType type) -> TensorValue { + auto vType = v.getType(); + auto newVEncoding = DotOperandEncodingAttr::get( + ctx, idx, newRetType.getEncoding(), kWidths[idx]); + auto newVType = RankedTensorType::get( + vType.getShape(), vType.getElementType(), newVEncoding); + v = rewriter.create(v.getLoc(), newVType, v); + // Don't need to covert int8 holding mxfp4--the upcast_mxfp op can + // take int8 tensor as input. + if (type == ScaleDotElemType::BF16 || type == ScaleDotElemType::FP16 || + type == ScaleDotElemType::E2M1) + return v; + + auto upcastedType = RankedTensorType::get( + vType.getShape(), + useFp16 ? rewriter.getF16Type() : rewriter.getBF16Type(), + newVEncoding); + return cast( + rewriter.create(v.getLoc(), upcastedType, v).getResult()); + }; + a = upcastForMMA(a, 0, aElemType); + b = upcastForMMA(b, 1, bElemType); + + // We need to have "matching" encoding between the main tensor and scale + // tensor to make sure the scale values needed is in the same warp. So we + // adopt the same CTA layout and warps per CTA. The warp dimensions needs to + // match along M/N dimension too. With in a warp, we have 64 threads. We let + // each thread read in one scale value. So we need a threadsPerWarp = + // mDim/nDim along M/N dimension. Note that For MFMA intrinsics, mDim is + // always the same as nDim. And for scaled dot scale tensor, we always have + // K as the innermost dimension. So we have the same threadsPerWarp in the + // below no matter A or B scale. Similarly for warpsPerCTA, the non-K + // dimension is always at index 0. + assert(mDim == nDim); + SmallVector threadsPerWarp = {mDim, numThreads / mDim}; + SmallVector blockWarpsPerCTA(rank, 1); + blockWarpsPerCTA[0] = numWarps; + auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( + ctx, {1, 1}, threadsPerWarp, blockWarpsPerCTA, {1, 0}, ctaLayout); + + auto upcastMXFP = [&](TensorValue v, TensorValue scale, + ScaleDotElemType elemType, bool fastMath) -> Value { + if (!scale) + return v; + + auto newScaleType = RankedTensorType::get( + scale.getType().getShape(), scale.getType().getElementType(), + newScaleEncoding); + auto convOp = rewriter.create(scale.getLoc(), + newScaleType, scale); + + Builder b(v.getContext()); + // TODO: Emit device assert to check scale tensor range fitting into fp16? + Type outputElemType = useFp16 ? b.getF16Type() : b.getBF16Type(); + auto outputType = + amdgpu::UpcastMXFPOp::deduceOutputType(v, elemType, outputElemType); + return rewriter.create( + dotOp.getLoc(), outputType, v, convOp, elemType, fastMath); + }; + + Value scaledA = + upcastMXFP(a, aScale, dotOp.getAElemType(), dotOp.getFastMath()); + Value scaledB = + upcastMXFP(b, bScale, dotOp.getBElemType(), dotOp.getFastMath()); + auto newDot = rewriter.create(dotOp.getLoc(), newRetType, scaledA, + scaledB, newAcc); + rewriter.replaceOpWithNewOp(dotOp, oldRetType, + newDot); + return success(); + } +}; + +class ScaledBlockedToScaledMFMAF8F6F4 final + : public OpRewritePattern { + int mfmaVersion; + int nonKDim; + +public: + ScaledBlockedToScaledMFMAF8F6F4(MLIRContext *context, int mfmaVersion, + int nonKDim, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), + nonKDim(nonKDim) {} + + LogicalResult matchAndRewrite(triton::DotScaledOp dotOp, + PatternRewriter &rewriter) const override { + using TensorValue = TypedValue; + + if (mfmaVersion != 4) { + return rewriter.notifyMatchFailure( + dotOp, "F8F6F4 scaled dot is only natively supported on gfx950"); + } + + RankedTensorType oldRetType = dotOp.getType(); + if (!isa_and_nonnull(oldRetType.getEncoding())) + return rewriter.notifyMatchFailure( + dotOp, "expected blocked encoding result tensor"); + + unsigned rank = oldRetType.getRank(); + if (rank == 3) + return rewriter.notifyMatchFailure(dotOp, "NYI: 3d case"); + + TensorValue a = dotOp.getA(); + TensorValue b = dotOp.getB(); + TensorValue aScale = dotOp.getAScale(); + TensorValue bScale = dotOp.getBScale(); + auto oldShape = oldRetType.getShape(); + + ScaleDotElemType aElemType = dotOp.getAElemType(); + ScaleDotElemType bElemType = dotOp.getBElemType(); + auto supportsTypes = [](ScaleDotElemType elemType) { + return elemType == ScaleDotElemType::E2M1 || + elemType == ScaleDotElemType::E4M3 || + elemType == ScaleDotElemType::E5M2; + }; + + if (!supportsTypes(aElemType) || !supportsTypes(bElemType)) { + return rewriter.notifyMatchFailure(dotOp, "NYI: mxfp6"); + } + + bool bothScalesAbsent = !aScale && !bScale; + + MLIRContext *ctx = dotOp.getContext(); + + ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding()); + unsigned numWarps = ttg::lookupNumWarps(dotOp); + if (numWarps == 1) + return rewriter.notifyMatchFailure(dotOp, + "num_warps==1 is not supported"); + + // Choose a suitable Scaled MFMA instruction for this scaled dot op. + FailureOr mfmaInstr = + chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim); + if (failed(mfmaInstr)) + return rewriter.notifyMatchFailure(dotOp, + "cannot choose scaled mfma intrinsic"); + + auto mDim = mfmaInstr->mDim; + auto nDim = mfmaInstr->nDim; + auto kDim = mfmaInstr->kDim; + auto kBase = mfmaInstr->kBase; + assert(mDim == nDim); + + auto warpsPerTile = + warpsPerTileMFMA(dotOp, oldShape, numWarps, {mDim, nDim}); + + // Always use transposed mfma layout. This enables larger vectorization + // for global store instructions. + auto mfmaEnc = ttg::AMDMfmaEncodingAttr::get( + ctx, /*versionMajor=*/mfmaVersion, /*versionMinor=*/0, warpsPerTile, + /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout); + auto warpOrder = mfmaEnc.getDefaultWarpOrder(); + + auto newRetType = + RankedTensorType::get(oldShape, oldRetType.getElementType(), mfmaEnc); + + auto newAcc = rewriter.create( + dotOp.getC().getLoc(), newRetType, dotOp.getC()); + + auto order = ttg::getMatrixOrder(rank, /*rowMajor=*/true); + auto standardOutDims = standardOutDimNames(ctx, rank); + + // For the mfma_scale_f32_*_f8f6f4 instructions, each thread consumes 32 + // elements. But since two fp4 elements are packed into one int8, the + // kWidth is 16 for fp4. + const unsigned kWidth = kBase; + using basisT = std::vector>; + + auto aShape = a.getType().getShape(); + auto bShape = b.getType().getShape(); + auto aEncLL = chooseScaledMfmaOperandLayout( + mfmaEnc, kWidth, /*dotOperandIdx=*/0, aElemType, aShape); + auto bEncLL = chooseScaledMfmaOperandLayout( + mfmaEnc, kWidth, /*dotOperandIdx=*/1, bElemType, bShape); + + auto convertInputLayout = [&](TensorValue v, + LinearLayout layout) -> TensorValue { + auto vType = v.getType(); + + auto newEnc = ttg::LinearEncodingAttr::get(ctx, layout); + auto newVType = RankedTensorType::get(vType.getShape(), + vType.getElementType(), newEnc); + return rewriter.create(v.getLoc(), newVType, v); + }; + a = convertInputLayout(a, aEncLL); + b = convertInputLayout(b, bEncLL); + + StringAttr kWarp = StringAttr::get(ctx, "warp"); + auto convertScaleLayout = [&](TensorValue scale, + llvm::ArrayRef valShape, + LinearLayout dotLL, int idx) -> Value { + if (bothScalesAbsent) + return Value(); + + LinearLayout::BasesT scaleBases = dotLL.getBases(); + auto &warpBases = scaleBases[kWarp]; + + SmallVector shape; + if (!scale) { + int64_t nonKDim = idx == 0 ? valShape[0] : valShape[1]; + int64_t k = idx == 0 ? valShape[1] : valShape[0]; + ScaleDotElemType &elemType = idx == 0 ? aElemType : bElemType; + int packSize = elemType == ScaleDotElemType::E2M1 ? 2 : 1; + shape = {nonKDim, k * packSize / 32}; + } else { + shape = llvm::to_vector(scale.getType().getShape()); + } + + LinearLayout newLL = + chooseScaledMfmaScaleLayout(ctx, idx, warpBases, shape, mDim); + + Attribute newScaleEncoding = ttg::LinearEncodingAttr::get(ctx, newLL); + // Scale's data type is always i8 + auto newScaleType = RankedTensorType::get(shape, i8_ty, newScaleEncoding); + + if (!scale) { + // 0x7F is 1.0 in E8M0 + return rewriter.create( + dotOp->getLoc(), newScaleType, + DenseElementsAttr::get(newScaleType, llvm::APInt(8, 0x7F))); + } else { + return rewriter.create(scale.getLoc(), + newScaleType, scale); + } + }; + auto newAScale = + convertScaleLayout(aScale, aShape, aEncLL, /*dotOperandIdx=*/0); + auto newBScale = + convertScaleLayout(bScale, bShape, bEncLL, /*dotOperandIdx=*/1); + + auto newDot = rewriter.create( + dotOp.getLoc(), newRetType, a, b, newAcc, newAScale, newBScale, + aElemType, bElemType, dotOp.getFastMath()); + + rewriter.replaceOpWithNewOp(dotOp, oldRetType, + newDot); + + return success(); + } +}; + +static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, + Type promotedType) { + Type tensorPromotedType = cast(operand.getType()) + .cloneWith(std::nullopt, promotedType); + return builder.create(loc, tensorPromotedType, operand); +} + +// promote operands of dot op if the existing combination is not natively +// supported. +static void decomposeMixedModeDotOp(ModuleOp mod) { + mod.walk([](triton::DotOp dotOp) -> void { + auto D = dotOp.getD(); + OpBuilder builder(dotOp); + Type AElType = dotOp.getA().getType().getElementType(); + Type promoteType; + if (isa(D.getType().getEncoding())) { + Type BElType = dotOp.getB().getType().getElementType(); + + auto maxBitWidth = std::max(AElType.getIntOrFloatBitWidth(), + BElType.getIntOrFloatBitWidth()); + + // TODO check mfma tensor core version compatibility + if (maxBitWidth == 8) + return; + + if (AElType == BElType) + return; + + if (maxBitWidth < 16) + promoteType = builder.getF16Type(); + else if (maxBitWidth <= 32) + promoteType = builder.getF32Type(); + } else if (isa(D.getType().getEncoding())) { + Type BElType = dotOp.getB().getType().getElementType(); + + if (AElType == BElType) + return; + + // Other cases must be filtered earlier + promoteType = + AElType.getIntOrFloatBitWidth() > BElType.getIntOrFloatBitWidth() + ? AElType + : BElType; + } else { + // FMA case is processed in AccelerateBlocked + return; + } + Location loc = dotOp.getLoc(); + Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType); + Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType); + dotOp.setOperand(0, promotedA); + dotOp.setOperand(1, promotedB); + }); +} + +class BlockedToWMMA : public OpRewritePattern { + int wmmaVersion; + +public: + BlockedToWMMA(MLIRContext *context, int wmmaVersion, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), wmmaVersion(wmmaVersion) {} + + LogicalResult matchAndRewrite(tt::DotOp dotOp, + PatternRewriter &rewriter) const override { + auto ctx = dotOp->getContext(); + + Value a = dotOp.getA(); + Value b = dotOp.getB(); + + auto oldRetType = cast(dotOp.getResult().getType()); + auto oldRetEncoding = oldRetType.getEncoding(); + if (!oldRetEncoding || !isa(oldRetEncoding)) + return failure(); + + auto oldAType = cast(a.getType()); + auto oldBType = cast(b.getType()); + auto retShape = oldRetType.getShape(); + auto aShape = oldAType.getShape(); + auto bShape = oldBType.getShape(); + + // check shape + auto mnkDim = ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr(); + auto rank = aShape.size(); + if (aShape[rank - 2] % mnkDim[0] != 0 || // m + bShape[rank - 1] % mnkDim[1] != 0 || // n + aShape[rank - 1] % mnkDim[2] != 0) // k + return failure(); + + if (wmmaVersion == 2 && llvm::isa(oldAType) && + oldAType.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure(dotOp, "not supported yet"); + } + + // get operand types + auto operandTypes = getOperandTypesForWmmaOp(rewriter, dotOp, wmmaVersion); + if (operandTypes.empty()) + return failure(); + + // get WMMA encoding for the given number of warps + int numWarps = ttg::lookupNumWarps(dotOp); + + ttg::AMDWmmaEncodingAttr wmmaEnc; + + auto warpsPerTile = warpsPerTileWMMA(dotOp, retShape, numWarps); + + auto CTALayout = ttg::getCTALayout(oldRetEncoding); + + // TODO implement heuristic/option for this parameter + bool isTransposed = false; + wmmaEnc = ttg::AMDWmmaEncodingAttr::get(ctx, wmmaVersion, isTransposed, + warpsPerTile, CTALayout); + + auto newRetType = RankedTensorType::get(retShape, operandTypes[3], wmmaEnc); + + // convert accumulator + auto oldAcc = dotOp.getC(); + auto newAcc = + convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]); + auto kWidth = wmmaEnc.getKWidthForOperands(); + + auto newAType = RankedTensorType::get( + aShape, operandTypes[0], + ttg::DotOperandEncodingAttr::get(ctx, 0, wmmaEnc, kWidth)); + auto newBType = RankedTensorType::get( + bShape, operandTypes[1], + ttg::DotOperandEncodingAttr::get(ctx, 1, wmmaEnc, kWidth)); + + Value castedA = convertAndCastTensor(rewriter, a, newAType.getEncoding(), + operandTypes[0]); + Value castedB = convertAndCastTensor(rewriter, b, newBType.getEncoding(), + operandTypes[1]); + auto newDot = rewriter.create( + dotOp.getLoc(), newRetType, castedA, castedB, newAcc, + dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); + + Value dotOutput = convertAndCastTensor(rewriter, newDot, oldRetEncoding, + oldRetType.getElementType()); + rewriter.replaceOp(dotOp, dotOutput); + return success(); + } +}; + +class AccelerateBlocked : public OpRewritePattern { + StringRef arch; + +public: + AccelerateBlocked(MLIRContext *context, StringRef arch, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), arch(arch) {} + + bool isFloat(Type t) const { return t.isIntOrFloat() && !t.isIntOrIndex(); } + + Value castToElTy(PatternRewriter &rewriter, Value v, Type elTy) const { + Location loc = v.getLoc(); + auto srcTy = cast(v.getType()); + auto dstTy = srcTy.cloneWith(std::nullopt, elTy); + if (srcTy == dstTy) + return v; + auto srcElTy = srcTy.getElementType(); + auto dstElTy = dstTy.getElementType(); + if (isFloat(srcElTy) && isFloat(dstElTy)) { + auto rmode = + RoundingModeAttr::get(rewriter.getContext(), RoundingMode::RTNE); + return rewriter.create(loc, dstTy, v, rmode); + } + if (!isFloat(srcElTy) && isFloat(dstElTy)) + return rewriter.create(loc, dstTy, v); + if (isFloat(srcElTy) && !isFloat(dstElTy)) + return rewriter.create(loc, dstTy, v); + assert(false && "int -> int cast is unexpected in FMA legalization"); + return Value(); + } + + struct DotElTypes { + Type a, b, c, d; + }; + + bool isLegalFMAForm(DotOp dotOp, const DotElTypes &dotTypes) const { + if (AMD::supportsVDot(arch)) { + auto aOpType = dotOp.getA().getType(); + int rank = aOpType.getRank(); + int k = aOpType.getShape()[rank - 1]; + // Try Fp16 x Fp16 -> Fp32 v_dot + // if k % 2 != 0: can not use fp V_DOT instruction + if (dotTypes.a.isF16() && dotTypes.b.isF16() && dotTypes.c.isF32() && + dotTypes.d.isF32() && k % 2 == 0) { + return true; + } + + // TODO: enable this condition, when fp32 -> fp16 cast works correctly + // Consider this case as non legal, despite this case is covered by fp16 + // FMA. Because v_dot expected to give both better performance and + // computational precision. + if (false && dotTypes.a.isF16() && dotTypes.b.isF16() && + dotTypes.c.isF16() && dotTypes.d.isF16() && k % 2 == 0) { + return false; + } + + // Try I8 x I8 -> I32 v_dot + // if k % 4 != 0: can not use integer V_DOT instruction + if (dotTypes.a.isInteger(8) && dotTypes.b.isInteger(8) && + dotTypes.c.isInteger(32) && dotTypes.d.isInteger(32) && k % 4 == 0) { + return true; + } + } + + auto expectedElTy = dotTypes.a; + for (auto operand : dotOp.getOperands()) { + auto opTy = cast(operand.getType()); + auto elTy = opTy.getElementType(); + if (elTy != expectedElTy) + return false; + if (!elTy.isF16() && !elTy.isF32()) + return false; + } + return true; + } + + LogicalResult tryAccelerateF16WithVDot(DotOp dotOp, PatternRewriter &rewriter, + const DotElTypes &dotTypes) const { + if (!AMD::supportsVDot(arch)) + return failure(); + + // If this is fp16 x fp16 ->fp16 case prioritize using v_dot. + auto aOpType = dotOp.getA().getType(); + int rank = aOpType.getRank(); + int k = aOpType.getShape()[rank - 1]; + if (dotTypes.a.isF16() && dotTypes.b.isF16() && dotTypes.c.isF16() && + dotTypes.d.isF16() && k % 2 == 0) { + auto newC = castToElTy(rewriter, dotOp.getC(), f32_ty); + auto newDot = rewriter.create( + dotOp.getLoc(), newC.getType(), dotOp.getA(), dotOp.getB(), newC, + dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); + auto newD = castToElTy(rewriter, newDot.getResult(), f16_ty); + rewriter.replaceOp(dotOp, newD); + return success(); + } + return failure(); + } + + LogicalResult tryLegalizeFMA(DotOp dotOp, PatternRewriter &rewriter, + const DotElTypes &dotTypes) const { + // Legalize dot for plain FMA case, i.e. same operands and result type. + + // Find common type, larger or equal of all operand types + SmallVector opElTy{dotTypes.a, dotTypes.b, dotTypes.c, dotTypes.d}; + unsigned maxBitsize = 8; + for (auto elTy : opElTy) + maxBitsize = std::max(maxBitsize, elTy.getIntOrFloatBitWidth()); + assert(maxBitsize <= 32); + Type commonTy = + maxBitsize <= 16 ? rewriter.getF16Type() : rewriter.getF32Type(); + + // Check that type is compatible with all operands; fallback to fp32 if not. + if (commonTy.isF16()) { + for (auto elTy : opElTy) { + if (elTy.isInteger() && elTy.getIntOrFloatBitWidth() > 8) { + commonTy = rewriter.getF32Type(); + break; + } + if (elTy.isBF16()) { + commonTy = rewriter.getF32Type(); + break; + } + } + } + + auto newA = castToElTy(rewriter, dotOp.getA(), commonTy); + auto newB = castToElTy(rewriter, dotOp.getB(), commonTy); + auto newC = castToElTy(rewriter, dotOp.getC(), commonTy); + + auto newDot = rewriter.create(dotOp.getLoc(), newC.getType(), newA, + newB, newC, dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc()); + auto newD = castToElTy(rewriter, newDot.getResult(), dotTypes.d); + + rewriter.replaceOp(dotOp, newD); + return success(); + } + + LogicalResult matchAndRewrite(DotOp dotOp, + PatternRewriter &rewriter) const override { + if (!isa(dotOp.getD().getType().getEncoding())) + return failure(); + + DotElTypes dotTypes; + dotTypes.a = dotOp.getA().getType().getElementType(); + dotTypes.b = dotOp.getB().getType().getElementType(); + dotTypes.c = dotOp.getC().getType().getElementType(); + dotTypes.d = dotOp.getD().getType().getElementType(); + + // Check that dot is not legalized already + if (isLegalFMAForm(dotOp, dotTypes)) { + return failure(); + } + + // TODO: enable this condition, when fp32 -> fp16 cast works correctly + if (false && + tryAccelerateF16WithVDot(dotOp, rewriter, dotTypes).succeeded()) { + return success(); + } + + return tryLegalizeFMA(dotOp, rewriter, dotTypes); + } +}; + +} // namespace + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" + +class TritonAMDGPUAccelerateMatmulPass + : public TritonAMDGPUAccelerateMatmulBase< + TritonAMDGPUAccelerateMatmulPass> { +public: + TritonAMDGPUAccelerateMatmulPass() = default; + TritonAMDGPUAccelerateMatmulPass(StringRef archGen, int matrixInstructionSize, + int kPack) { + this->archGenerationName = archGen.data(); + this->matrixInstructionSize = matrixInstructionSize; + this->kPack = kPack; + } + void runOnOperation() override { + + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + RewritePatternSet patterns(context); + switch (auto isaFamily = triton::AMD::deduceISAFamily(archGenerationName)) { + case ISAFamily::CDNA4: + patterns.add<::ScaledBlockedToScaledMFMAF8F6F4>( + context, getMfmaVersion(isaFamily), matrixInstructionSize, + /*benefit=*/10); + [[fallthrough]]; + case ISAFamily::CDNA1: + case ISAFamily::CDNA2: + case ISAFamily::CDNA3: + patterns.add<::BlockedToMFMA, ::ScaledBlockedToMFMA>( + context, getMfmaVersion(isaFamily), matrixInstructionSize, kPack, + /*benefit=*/2); + break; + case ISAFamily::RDNA3: + patterns.add<::BlockedToWMMA>(context, getWmmaVersion(archGenerationName), + /*benefit=*/2); + break; + default: + break; + } + patterns.add(context, archGenerationName, /*benefit=*/1); + if (applyPatternsGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + decomposeMixedModeDotOp(m); + } +}; + +std::unique_ptr mlir::createTritonAMDGPUAccelerateMatmulPass( + std::string archGen, int matrixInstructionSize, int kPack) { + return std::make_unique( + archGen, matrixInstructionSize, kPack); +} diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp new file mode 100644 index 000000000..845b8f5f0 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -0,0 +1,728 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +#define DEBUG_TYPE "tritonamdgpu-block-pingpong" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace ttg = mlir::triton::gpu; +namespace tt = mlir::triton; + +namespace { + +// This pass transforms a for-loop calculating a GEMM. Main purpose of the +// transform is improve the efficiency of the GPU dot instruction (mfma) +// by interleaving the execution of two warps on each SIMD. Especially it groups +// instructions into Dot and Memory clusters so they can efficiently run in +// parallel. Also this pass inserts `rocdl.s.setprio` operation and +// `amdgpu.cond_barrier` to run two parallel warps in synchronization. +// This scheduling doesn't help improving the memory latency itself but it +// relies on software-pipelining to hide the global latency. Likely to improve +// the performance of compute-bound cases. +class Pingponger { + scf::ForOp forOp; + SmallVector gLoadOps; + SmallVector lLoadOps; + SmallVector lStoreOps; + SmallVector dotOps; + SmallVector> subViewOps; + SmallVector> loadSliceOps; + SmallVector dotSliceOps; + SmallVector constOffsets; + Operation *lastInsertedOp; + + // rocdl.s.setprio will be mapped to `s_setprio` instruction which set the + // priority of the warp within a SIMD, determines which warp to occupy the + // instruction unit when they compete on the same instruction. + // We use this instruction in the pingpong scheduling to prevent warps from + // entering into the dot cluster while the other warp is still busy in the dot + // cluster. Otherwise pingpong pattern can be broken and performance drops. + // Currently pingpong only handles two warps, we only need 0/1 priorities. + int lowPriority = 0; + int highPriority = 1; + int32_t kWidth; + int32_t numWarps; + +public: + Pingponger(scf::ForOp forOp, int32_t numWarps) + : forOp(forOp), numWarps(numWarps) {} + void getDotPingponged(); + +private: + void genOffsetConstants(Location loc, OpBuilder &builder, unsigned numSlices, + int64_t sliceWidth); + LogicalResult genLocalSlice(OpBuilder &builder, Value v, + Attribute dotEncoding, unsigned opIdx, + unsigned numSlices, int64_t sliceWidth); + LogicalResult sliceDot(OpBuilder &builder, Location loc, tt::DotOp op, + unsigned numSlices); + void transformOnePPClusters(OpBuilder &builder, Location loc); + LogicalResult transformFourPPClusters(OpBuilder &builder, Location loc); + LogicalResult transformTwoPPClusters(OpBuilder &builder, Location loc); + void addAsymmetricSyncToLoop(OpBuilder &builder, Location loc); + void updateOpInsertion(Operation *Op); + void appendOp(Operation *Op); + void moveOpAndPredecessorsUpSameBlock(Operation *Op); + void appendSlicedLoadAB(int slice); + void appendClusterBarrier(OpBuilder &builder, Location loc); + void appendOpWithPrio(OpBuilder &builder, Operation *Op, Location loc); + void determineDotMemoryOps(tt::DotOp dotOp, + DenseSet &dotGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotLocalStores); + template + void findClosestPredOps(Value v, DenseSet &matchingOps); +}; + +void Pingponger::updateOpInsertion(Operation *op) { lastInsertedOp = op; } +void Pingponger::appendOp(Operation *op) { + assert(lastInsertedOp != nullptr); + op->moveAfter(lastInsertedOp); + lastInsertedOp = op; +} + +// Move the given operations and any predecessors upon which it depends +// up in the block to the last inserted operation. This does not move +// operations that reaches the last inserted operation or +// are not in the same block. The exception is op, which is always moved +// to the new location (can move down or up). +void Pingponger::moveOpAndPredecessorsUpSameBlock(Operation *op) { + assert(lastInsertedOp != nullptr); + // TODO: Enable moving ops across blocks + assert(op->getBlock() == lastInsertedOp->getBlock()); + Operation *checkedOp = lastInsertedOp; + // Check if we are moving the op up, if so we may need to + // move additional ops up to maintain correctness. + if (lastInsertedOp->isBeforeInBlock(op)) { + SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = [&checkedOp](Operation *op) { + return op->getBlock() == checkedOp->getBlock() && + checkedOp->isBeforeInBlock(op); + }; + getBackwardSlice(op, &backwardSlice, opt); + for (auto predOp : backwardSlice) + appendOp(predOp); + appendOp(op); + } else { + auto hasUnsafeUser = [&checkedOp](auto &&user) { + return user != checkedOp && user->getBlock() == checkedOp->getBlock() && + user->isBeforeInBlock(checkedOp); + }; + if (std::any_of(op->user_begin(), op->user_end(), hasUnsafeUser)) + LDBG("Unable to move operation " + << op << " due to use before intended move location"); + else + appendOp(op); + } +} +void Pingponger::appendSlicedLoadAB(int slice) { + appendOp(subViewOps[0][slice]); + appendOp(loadSliceOps[0][slice]); + appendOp(subViewOps[1][slice]); + appendOp(loadSliceOps[1][slice]); +} +// Asymmetrically synchronized loop in the pingpong scheduling synchronizes all +// the warps at the end of each instruction cluster. Since cond_barrier +// triggered a barrier for only half of the warps in a block, at the point +// this clusterBarrier is called, half warps are at dot cluster and the others +// are at the memory cluster. +// Also, SchedBarrier with `0` is set here to tell compiler backend not to +// reorder any instruction across this point. +void Pingponger::appendClusterBarrier(OpBuilder &builder, Location loc) { + // MembarAnalysis can recognize gpu::BarrierOp and skip inserting additional + // barrier + appendOp(builder.create(loc)); + appendOp(builder.create(loc, 0)); +} +void Pingponger::appendOpWithPrio(OpBuilder &builder, Operation *op, + Location loc) { + appendOp(builder.create(loc, highPriority)); + appendOp(op); + appendOp(builder.create(loc, lowPriority)); +} + +// Find all of the "closest" operations that are of a given type T +// in the same basic block. Here "closest" means along any path P, +// the first operation of type T that is encountered when traversing +// P from the given value v. This also includes "later" operations +// for block arguments. Note: That we find all T for every path P. +template +void Pingponger::findClosestPredOps(Value v, DenseSet &matchingOps) { + // Create a cache so we can traverse across block arguments. + DenseSet visitedOps; + std::function impl; + impl = [&matchingOps, &visitedOps, &impl](Value v) { + // If we encounter a block argument we only look at the terminators of the + // current block + if (auto blockArg = dyn_cast(v)) { + auto operandNumber = blockArg.getArgNumber(); + auto block = blockArg.getOwner(); + if (auto yield = dyn_cast(block->getTerminator())) { + auto parentOp = block->getParentOp(); + // Skip the induction variables to find the yield position + if (auto forOp = dyn_cast(parentOp)) { + if (operandNumber < forOp.getNumInductionVars()) + return; + operandNumber -= forOp.getNumInductionVars(); + } + impl(yield->getOperand(operandNumber)); + } + } else { + auto definingOp = v.getDefiningOp(); + if (!definingOp) + return; + else if (visitedOps.contains(definingOp)) + return; + visitedOps.insert(definingOp); + if (auto matchOp = dyn_cast(definingOp)) + matchingOps.insert(matchOp); + else + for (auto predValue : definingOp->getOperands()) + impl(predValue); + } + }; + impl(v); +} + +// Populate the dotGlobalLoads, dotLocalLoads, and dotLocalStores set with +// any loads that are generated by the current dot product. This occurs in +// steps to: +// 1. Determine which loads are generated by the dot product via getA() +// and getB(). +// 2. Determine which local stores are used to populate the inputs to +// the local loads. +// 3. Determine which global loads are used to populate the inputs to +// the local stores. +// Note: This function currently depends on num_stages=2, which is a +// precondition for the pingpong scheduling. +void Pingponger::determineDotMemoryOps( + tt::DotOp dotOp, DenseSet &dotGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotLocalStores) { + // Find the locals loads used to compute the dot inputs. These + // must come before the dot op. + findClosestPredOps(dotOp.getA(), dotLocalLoads); + findClosestPredOps(dotOp.getB(), dotLocalLoads); + + // Determine the local stores from the local loads. + // With pipelining we expect this to be a single local + // store within the loop based on a block argument after routing through + // a ttg.MemDescSubviewOp. + DenseSet subviews; + for (auto &&localLoad : dotLocalLoads) + findClosestPredOps(localLoad.getSrc(), subviews); + + for (auto &&subview : subviews) + for (auto &&user : subview->getUsers()) + if (auto localStore = dyn_cast(user)) + dotLocalStores.insert(localStore); + + // Determine the global loads from the local stores. + // We expect this to just be a global load + // within the loop. + for (auto &&localStore : dotLocalStores) + findClosestPredOps(localStore.getSrc(), dotGlobalLoads); +} + +// Transform a loop into one Dot - Memory (ping - pong) clusters +// Each cluster, especially the Dot cluster is guarded with setprio(1->0) so +// each warp can complete the execution of the cluster without being +// interrupted. This is also supposed to be used with the numWarps=4 case where +// each SIMD runs two warps from different blocks and those two warps don't need +// to be synchronized together. +// Splitting loading A/B and interleave global/local load in order to prevent +// the stalls. +// sched.barriers with 0 mask were used to enforce the boundary of the +// high-level operations, inserting `setPrio` also has a same effect of +// instruction scheduling boundary, too. +void Pingponger::transformOnePPClusters(OpBuilder &builder, Location loc) { + auto dotLoc = dotOps[0]->getPrevNode(); + // sched barrier to prevent memory ops from cross but leave other ops to be + // scheduled across the barrier. + auto preDotBar = builder.create(loc, 1); + updateOpInsertion(dotLoc); + + // Memory cluster #0 + moveOpAndPredecessorsUpSameBlock(lLoadOps[0]); + appendOp(builder.create(loc, highPriority)); + moveOpAndPredecessorsUpSameBlock(gLoadOps[0]); + appendOp(builder.create(loc, 0)); + moveOpAndPredecessorsUpSameBlock(lLoadOps[1]); + appendOp(builder.create(loc, lowPriority)); + moveOpAndPredecessorsUpSameBlock(gLoadOps[1]); + + // Dot cluster #0 + appendOp(preDotBar); + appendOpWithPrio(builder, dotOps[0], loc); + // Add a remark for user feedback + dotOps[0]->emitRemark() << "Performed one ping pong cluster transformation\n"; +} + +void Pingponger::genOffsetConstants(Location loc, OpBuilder &builder, + unsigned numSlices, int64_t sliceWidth) { + for (int i = 0; i < numSlices; i++) { + int64_t offset = sliceWidth * i; + constOffsets.push_back( + builder.create(loc, offset, 32)); + } +} + +// Splits given local_loads for dot into multiple subviews and local_loads. This +// function tries to slice the local_load into the given number of the slices, +// generates ops when succeed, return fail() otherwise. +LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v, + Attribute dotEncoding, unsigned opIdx, + unsigned numSlices, + int64_t sliceWidth) { + SmallVector slices; + SmallVector subviews; + // TODO: support transformed input to dot + auto localLoad = v.getDefiningOp(); + if (!localLoad) + return failure(); + auto memDesc = localLoad.getSrc(); + auto type = cast(memDesc.getType()); + SmallVector shape = llvm::to_vector(type.getShape()); + Type elementType = type.getElementType(); + int64_t kIdx = opIdx == 0 ? 1 : 0; + shape[kIdx] = sliceWidth; + // Each slice cannot be smaller than the smallest supported mfma width. + if (sliceWidth < 16) + return failure(); + auto dotOperandEnc = ttg::DotOperandEncodingAttr::get( + builder.getContext(), opIdx, dotEncoding, kWidth); + auto subviewDescType = ttg::MemDescType::get( + shape, elementType, type.getEncoding(), type.getMemorySpace(), + type.getMutableMemory(), type.getAllocShape()); + for (int i = 0; i < numSlices; i++) { + SmallVector offsetsVal; + SmallVector offsets = {0, 0}; + offsets[kIdx] = i; + for (int64_t off : offsets) { + offsetsVal.push_back(constOffsets[off]); + } + Value newSmem = builder.create( + v.getLoc(), subviewDescType, memDesc, offsetsVal); + Value prefetchSlice = builder.create( + v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), + newSmem); + subviews.push_back(newSmem.getDefiningOp()); + slices.push_back(prefetchSlice.getDefiningOp()); + } + subViewOps.push_back(subviews); + loadSliceOps.push_back(slices); + return success(); +} + +// Split dot into 'numSlices' pieces. This is required by pingpong scheduling +// when it needs to schedule multiple dot clusters. Calls genLocalSlice to +// create corresponding local_load slices. +LogicalResult Pingponger::sliceDot(OpBuilder &builder, Location loc, + tt::DotOp op, unsigned numSlices) { + builder.setInsertionPointToStart(forOp.getBody()); + auto typeB = op.getB().getType(); + auto shapeB = typeB.getShape(); + int64_t sliceWidth = shapeB[0] / numSlices; + if (shapeB[0] % numSlices != 0) + return failure(); + genOffsetConstants(loc, builder, numSlices, sliceWidth); + builder.setInsertionPointAfter(gLoadOps[0]); + auto dotEncoding = op.getType().getEncoding(); + if (genLocalSlice(builder, op.getA(), dotEncoding, 0, numSlices, sliceWidth) + .failed() || + genLocalSlice(builder, op.getB(), dotEncoding, 1, numSlices, sliceWidth) + .failed()) + return failure(); + + // Clone dots to consume all the slices + Operation *prevDot = op; + for (int i = 0; i < numSlices; i++) { + IRMapping mapping; + mapping.map(op.getA(), loadSliceOps[0][i]->getResult(0)); + mapping.map(op.getB(), loadSliceOps[1][i]->getResult(0)); + if (i > 0) + mapping.map(op.getC(), prevDot->getResult(0)); + auto newOp = builder.clone(*op, mapping); + prevDot = newOp; + dotSliceOps.push_back(newOp); + } + op->replaceAllUsesWith(prevDot); + op->erase(); + for (auto loads : lLoadOps) + loads->erase(); + return success(); +} + +// Transform a loop into four Dot - Memory (ping - pong) clusters +// This transform is useful when the original dot tile is too large that there's +// not enough registers to hold data for a Dot cluster. This path slices the dot +// into four pieces and pair with four clusters of reordered memory operations. +// There are multiple guards at the boundary of each cluster. +// (1) sched.barrier : with mask0 to prevent compiler backed from reordering +// instructions across the boundary +// (2) gpu.barrier : ensures asymmetric synchronization at each point +// (3) setprio (1->0) : in order to avoid incoming warp overtaking resource +// while the other warp is actively using it. +// +// Here's overview of the instruction clusters +// mem0: global load A, local load A(1/4), local load B(1/4) +// dot0: dot A(1/4) * B(1/4) +// mem1: global load B, local load A(2/4), local load B(2/4) +// dot1: dot A(2/4) * B(2/4) +// mem2: local load A(3/4, 4/4), local load B(3/4, 4/4) +// dot2: dot A(3/4) * B(3/4) +// mem3: local store A and B +// dot3: dot A(4/4) * B(4/4) + +LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder, + Location loc) { + // First, slice local_loads and dot into 4 parts + if (sliceDot(builder, loc, dotOps[0], 4).failed()) + return failure(); + builder.setInsertionPointAfter(gLoadOps[1]); + // Reorder operations into four mem/dot clusters + + // mem0: global load A, local load A(1/4), local load B(1/4) + // set insertion point at the last global_load where all the addresses are + // ready to be used. + updateOpInsertion(gLoadOps[1]); + appendSlicedLoadAB(/*slice=*/0); + appendClusterBarrier(builder, loc); + + // dot0 (1/4) + appendOpWithPrio(builder, dotSliceOps[0], loc); + appendClusterBarrier(builder, loc); + + // mem1: global load B, local load A(2/4), local load B(2/4) + appendOp(gLoadOps[1]); + appendSlicedLoadAB(/*slice=*/1); + appendClusterBarrier(builder, loc); + + // dot1 (2/4) + appendOpWithPrio(builder, dotSliceOps[1], loc); + appendClusterBarrier(builder, loc); + + // mem2: local load A(3/4, 4/4), local load B(3/4, 4/4) + appendSlicedLoadAB(/*slice=*/2); + appendSlicedLoadAB(/*slice=*/3); + appendClusterBarrier(builder, loc); + + // dot2 (3/4) + appendOpWithPrio(builder, dotSliceOps[2], loc); + appendClusterBarrier(builder, loc); + + // mem3: local store A and B + // Matmul kernels may use the output of the dot product in another operation + // before the local store (e.g. persistent matmul epilogue). To accommodate + // such cases, we need to move the local store up in the loop. + moveOpAndPredecessorsUpSameBlock(lStoreOps[0]); + moveOpAndPredecessorsUpSameBlock(lStoreOps[1]); + appendClusterBarrier(builder, loc); + + // dot3 (4/4) + appendOpWithPrio(builder, dotSliceOps[3], loc); + appendClusterBarrier(builder, loc); + + // Add a remark for user feedback + dotSliceOps[0]->emitRemark() + << "Performed four ping pong cluster transformation\n"; + return success(); +} + +// Transform a loop into two Dot - Memory (ping - pong) clusters +// This is useful for the medium sized tile which doesn't fit to either one/four +// cluster scheduling. +LogicalResult Pingponger::transformTwoPPClusters(OpBuilder &builder, + Location loc) { + // First, slice local_loads and dot into 2 parts + if (sliceDot(builder, loc, dotOps[0], 2).failed()) + return failure(); + // Reorder operations into two mem/dot clusters + + // Memory cluster #0 + // interleave local_loads and global_loads to minimize the stalling + // cycles, sched.barrier prevents backend from canceling the interleaved order + updateOpInsertion(gLoadOps[1]); + appendSlicedLoadAB(/*slice=*/0); + appendOp(builder.create(loc, 0)); + appendOp(gLoadOps[0]); + appendOp(builder.create(loc, 0)); + appendSlicedLoadAB(/*slice=*/1); + appendOp(builder.create(loc, 0)); + appendOp(gLoadOps[1]); + // The first cluster just fits into the two cluster pingpong and cannot + // include wait of the local_load inserted by the gpu.barrier, using s.barrier + // instead. backend will schedule the local memory fences later in the dot0 + // cluster. + appendOp(builder.create(loc)); + appendOp(builder.create(loc, 0)); + + // dot0 (1/2) + appendOpWithPrio(builder, dotSliceOps[0], loc); + appendClusterBarrier(builder, loc); + + // mem1: local store A and B + // Matmul kernels may use the output of the dot product in another operation + // before the local store (e.g. persistent matmul epilogue). To accommodate + // such cases, we need to move the local store up in the loop. + moveOpAndPredecessorsUpSameBlock(lStoreOps[0]); + moveOpAndPredecessorsUpSameBlock(lStoreOps[1]); + appendClusterBarrier(builder, loc); + + // dot1 (2/2) + appendOpWithPrio(builder, dotSliceOps[1], loc); + appendClusterBarrier(builder, loc); + + // Add a remark for user feedback + dotSliceOps[0]->emitRemark() + << "Performed two ping pong cluster transformation\n"; + return success(); +} + +// This function wraps forOp with cond_barrier. First, hold half of the warps +// (warpHigh) in a block before the loop so the barriers in the loop synchronize +// warps at the different point per the warp groups. After the loop, hold +// proceeding warps (warpLow) by calling cond_barrier on them. +void Pingponger::addAsymmetricSyncToLoop(OpBuilder &builder, Location loc) { + builder.setInsertionPointAfter(forOp); + // Set barrier before starting the loop. This resolves any remaining required + // synchronization before beginning the specialized asymmetric + // synchronization. + auto preBarrier = builder.create(loc); + preBarrier->moveBefore(forOp); + builder.setInsertionPointAfter(preBarrier); + + // Insert condbarrier::second_half before starting the loop + auto i32ty = builder.getIntegerType(32); + auto workIDX = builder.create(loc, i32ty); + auto constZero = builder.create(loc, 0, 32); + auto constWarpSize = builder.create(loc, 256, 32); + auto warpIDX = builder.create(loc, workIDX, constWarpSize); + auto warpLow = builder.create(loc, arith::CmpIPredicate::eq, + warpIDX, constZero); + auto warpHigh = builder.create(loc, arith::CmpIPredicate::ne, + warpIDX, constZero); + auto condBarrierHigh = + builder.create(loc, warpHigh); + + // Insert condbarrier::first_half after the end of the loop + builder.setInsertionPointAfter(forOp); + auto condBarrierLow = builder.create(loc, warpLow); +} + +void Pingponger::getDotPingponged() { + OpBuilder builder(forOp); + MLIRContext *ctx = forOp.getContext(); + Location loc = forOp.getLoc(); + + forOp->walk([&](Operation *op) { + if (auto gLoad = dyn_cast(op)) + gLoadOps.push_back(gLoad); + else if (auto lLoad = dyn_cast(op)) { + // This scheduling doesn't help hiding intra-warp latency. So, we only + // collect local_load ops that are software pipelined, which means their + // source is from loop carried values + auto src = lLoad.getSrc(); + if (auto arg = mlir::dyn_cast(src)) + if (auto tiedLoopInit = forOp.getTiedLoopInit(arg)) + if (tiedLoopInit->get()) + lLoadOps.push_back(lLoad); + } else if (auto lStore = dyn_cast(op)) + lStoreOps.push_back(lStore); + else if (auto pingpongDot = dyn_cast(op)) + if (pingpongDot.getType().getRank() == 2) + dotOps.push_back(pingpongDot); + }); + + // Currently, pingpong scheduling is known as helpful under limited condition. + // Individual conditions are checked while collecting each operation such as + // software pipelining and dot rank=2. Also only accept the for-loop with + // supported combination of operations because this transformation is very + // tightly scheduling the latencies. + if (gLoadOps.size() < 2 || lLoadOps.size() < 2 || dotOps.size() != 1) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << gLoadOps.size() << " global loads, " << lLoadOps.size() + << " local loads, " << dotOps.size() << " dot products"; + LDBG(message.str()); + return; + } + + // The existing code depends on the loads being targeted being safe to move, + // which will not hold if we do not properly have a GEMM. As a result, we + // filter the associated load operations to only those that are associated + // // with the GEMM. + DenseSet dotGlobalLoads; + DenseSet dotLocalLoads; + DenseSet dotLocalStores; + determineDotMemoryOps(dotOps[0], dotGlobalLoads, dotLocalLoads, + dotLocalStores); + + auto origGlobalLoadCount = gLoadOps.size(); + auto origLocalLoadCount = lLoadOps.size(); + // Prune Memory operations that may be moved to only those involved in dot + // computation. + auto gLoadIt = llvm::remove_if(gLoadOps, [&dotGlobalLoads](tt::LoadOp op) { + return !dotGlobalLoads.contains(op); + }); + gLoadOps.erase(gLoadIt, gLoadOps.end()); + auto lLoadIt = + llvm::remove_if(lLoadOps, [&dotLocalLoads](ttg::LocalLoadOp op) { + return !dotLocalLoads.contains(op); + }); + lLoadOps.erase(lLoadIt, lLoadOps.end()); + auto lStoreIt = + llvm::remove_if(lStoreOps, [&dotLocalStores](ttg::LocalStoreOp op) { + return !dotLocalStores.contains(op); + }); + lStoreOps.erase(lStoreIt, lStoreOps.end()); + // All PingPong Scheduler assumes there are 2 movable global loads and 2 + // movable local loads. + if (gLoadOps.size() != 2 || lLoadOps.size() != 2) { + std::stringstream message; + message << "Unable to match ping pong slicing pattern. Details: " + << gLoadOps.size() << " global loads in dot computation, " + << lLoadOps.size() << " local loads in dot computation"; + LDBG(message.str()); + return; + } + + // Pingpong scheduling tries to form two different types of the instruction + // clusters, i.e., Dot clusters and Memory clusters. While each SIMD has + // two concurrent warps, both warps can execute a different type of + // instruction cluster in parallel. Here are currently available patterns, + // more patterns could be added later. + // + // (1) One Dot-Memory (ping-pong) cluster + // :Ideal to support small tile size e.g., 128x128x64_FP16. Where amount + // of the data used per each iteration is small enough and not causing + // local_load waiting or register spilling. Currently used for numWarps=4 + // case where SIMD can hold two warps from different blocks. + // + // (2) Four Dot-Memory (ping-pongx4) clusters + // :Useful for the larger tile size e.g., 256x256x64_FP16. Clustering + // the Dot instruction (mfma) all together without fetching data requires + // GPU to hold all the data for the calculation. Such large tile size + // exceeds the amount of register GPU has so, we need to split the dot + // into several pieces. + // + // (3) Two Dot-Memory (ping-pongx2) clusters + // :Covers medium sized tile e.g., 256x128x64_FP16. Different tile size may + // require different scheduling pattern because the loop consists of + // different amount of memory transfer and dot operation. This scheduling + // support the tile sizes not supported by above two methods. + // + // N.B., Tile size smaller than 128x128x64_FP16 is likely not compute-bound + // that pingpong scheduling doesn't help much. + + auto dotType = dotOps[0].getType(); + auto dotShape = dotType.getShape(); + auto aType = dotOps[0].getA().getType(); + auto aShape = aType.getShape(); + auto elemWidth = aType.getElementTypeBitWidth(); + int64_t tileSize = dotShape[0] * dotShape[1] * aShape[1] * elemWidth; + + const int64_t minTile = 262144; // e.g. 32x128x64x16bit + const int64_t smallTile = 16777216; // e.g. 128x128x64x16bit + const int64_t mediumTile = 33554432; // smallTile x 2 + const int64_t largeTile = 67108864; // e.g. 256x256x64x16bit + + auto encoding = cast(aType).getEncoding(); + auto srcEncoding = cast(encoding); + kWidth = srcEncoding.getKWidth(); + auto mfmaEncoding = cast(srcEncoding.getParent()); + SmallVector intShape; + intShape.push_back(mfmaEncoding.getMDim()); + intShape.push_back(mfmaEncoding.getNDim()); + + if (numWarps == 4) { // Pingpong between warps from different blocks + // Transform a loop with small tile size. + // We've observed that this small tile size spent almost equivalent cycle + // times for issuing the memory operations and issuing dot operations, + // smaller tile sizes are not likely to get any advantage from current dot + // centric pingpong scheduling. + if (tileSize <= smallTile && tileSize >= minTile) + transformOnePPClusters(builder, loc); + // numWarps=4 doesn't need asymmetric sync, return. + return; + } else if (numWarps == 8) { // Pingpong between warps from the same block + if (origGlobalLoadCount != 2 || origLocalLoadCount != 2) { + std::stringstream message; + message << "Unable to match ping pong slicing pattern. Details: " + << gLoadOps.size() << " global loads, " << lLoadOps.size() + << " local loads"; + LDBG(message.str()); + return; + } + if (lStoreOps.size() != 2) { + std::stringstream message; + message << "Unable to match ping pong slicing pattern. Details: " + << lStoreOps.size() << " local stores in dot computation "; + LDBG(message.str()); + return; + } + // Transform a loop where the tile size requires dots to be sliced + if (tileSize == mediumTile) { + if (transformTwoPPClusters(builder, dotOps[0]->getLoc()).failed()) { + LDBG("Encountered failure when trying to execute the two ping pong " + "cluster transformation"); + return; + } + } else if (tileSize >= largeTile) { + // Avoid known register spilling. i.e., mfma16x16x16 & largetile & kpack>1 + if (intShape[0] == 16 && intShape[1] == 16 && kWidth == 8) { + LDBG("Reached known register spilling case, skip pingpong scheduling"); + return; + } + if (transformFourPPClusters(builder, dotOps[0]->getLoc()).failed()) { + LDBG("Encountered failure when trying to execute the four ping pong " + "cluster transformation"); + return; + } + } else + return; + + // Let half of the warps start the loop first and the others follow later + // but in the synchronized way. This can be accomplished by calling + // cond_barrier for the second half before the beginning of the loop so they + // can wait until the first half hit the first barrier in the loop. Also + // need to call cond_barrier for the first_half after exiting the loop, so + // all warps can converge again. + addAsymmetricSyncToLoop(builder, loc); + } +} + +class TritonAMDGPUBlockPingpongPass + : public TritonAMDGPUBlockPingpongBase { +public: + TritonAMDGPUBlockPingpongPass() = default; + void runOnOperation() override { + ModuleOp m = getOperation(); + for (auto funcOp : m.getOps()) { + funcOp.walk([&](scf::ForOp forOp) { + Pingponger pingponger(forOp, ttg::lookupNumWarps(forOp)); + pingponger.getDotPingponged(); + }); + } + } +}; +} // namespace + +std::unique_ptr mlir::createTritonAMDGPUBlockPingpongPass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt new file mode 100644 index 000000000..4ca78d3ea --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -0,0 +1,21 @@ +add_triton_library(TritonAMDGPUTransforms + AccelerateAMDMatmul.cpp + BlockPingpong.cpp + CanonicalizePointers.cpp + ConvertToBufferOps.cpp + OptimizeEpilogue.cpp + HoistLayoutConversions.cpp + ReorderInstructions.cpp + StreamPipeline.cpp + MfmaGroup.cpp + + DEPENDS + TritonAMDGPUIR + TritonAMDGPUTransformsIncGen + TritonGPUIR + TritonAMDUtils + TritonAMDAnalysis +) + +target_include_directories(TritonAMDGPUTransforms PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include) +target_include_directories(TritonAMDGPUTransforms PUBLIC ${CMAKE_CURRENT_BINARY_DIR}/../../include) diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp new file mode 100644 index 000000000..d47dbb281 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -0,0 +1,1531 @@ +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/OneToNTypeConversion.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" +#include + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +#define DEBUG_TYPE "tritonamdgpu-canonicalize-pointers" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = triton; + +// ----------------------------------------------------------------------------- +// Pointer canonicalizer utility class +// ----------------------------------------------------------------------------- +// This class iterates through the argument of the `funcOp`, if the argument is +// a pointer, starts a walk through its transitive uses to build a in-memory +// data structure to record the current offset to that pointer. Only when the +// pointer is really loaded/stored we materialize the base pointer with the +// offset. +// +// Let's suppose that `arg0` is a pointer. The algorithm works like that: +// +// a) At the beginning the offset is a tensor initialized to zero, and we +// associate with `%arg0` a `FatPtr{basePtr=%arg0, offset=0}`. Through the +// algorithm `FatPtr.basePtr` represents the scalar base pointer (all the +// uniform updates will go into that) and `FatPtr.offset` represents the +// tensor offset (all the non-uniform updates will go into that) +// +// +// b) Follow the pointer through the IR. When we meet: +// `%ptr = tt.addptr(%arg0, %offset)` +// +// Isolate the uniform and the non-uniform contributions of %offset = +// (%u_offset, %nu_offset) and update the scalar pointer and the tensor +// offset +// ``` +// %s_ptr = addi(%fatPoniters[ptr].basePtr, %u_offset) +// %t_offset = addi(%fatPoniters[ptr].offset, %nu_offset) +// %fatPointers[%ptr0] = FatPtr{base=%s_ptr, offset=%t_offset} +// ``` +// c) When we meet the `tt.load(%ptr)` or `tt.store(%ptr)` instructions, +// replace that instruction with: +// `%t_ptr = tt.splat(%fatPointers[%ptr].basePtr) +// `%fat_ptr = tt.addptr(%t_ptr, %fatPointers[ptr].offset)` +// `%data = tt.load(%fat_ptr)` +// +// Please note that `%offset` might be a 32bit or 64bit integer. If +// we can, we would like to use 32 bit integers. This can happen under +// certain conditions: +// +// a) We can determine that the offset cannot overflow. In this case, we can +// downcast the pointer just before emitting the load +// b) We know that the underlying memory size can be expressed as a 32 bit +// value. In this case we can simply start with a 32bit offset and downcast +// if we ever meet 64 bit operations (because we know that the offset can be +// contained in 32 bits) +// +namespace { + +// Extend a 32bit `offset` into 64bit using a arith.extsi operation +static Value createExtend32bitOffsetTo64Bits(RewriterBase &rewriter, + Location loc, Value offset) { + if (auto tensorType = dyn_cast(offset.getType())) { + auto shape = tensorType.getShape(); + auto newTensorType = RankedTensorType::get(shape, rewriter.getI64Type(), + tensorType.getEncoding()); + return rewriter.create(loc, newTensorType, offset); + } + return rewriter.create(loc, rewriter.getI64Type(), offset); +} + +// Narrow a 64bit `offset` into 32bit using a arith.trunci operation +static Value createNarrow64bitOffsetTo32bits(RewriterBase &rewriter, + Location loc, Value offset) { + Type elementType = getElementTypeOrSelf(offset); + if (elementType.isInteger(32)) + return offset; + + if (auto tensorType = dyn_cast(offset.getType())) { + auto shape = tensorType.getShape(); + auto newTensorType = RankedTensorType::get(shape, rewriter.getI32Type(), + tensorType.getEncoding()); + return rewriter.create(loc, newTensorType, offset); + } + return rewriter.create(loc, rewriter.getI32Type(), offset); +} + +// Helper function to determine if the given `op` is a constant tensor and in +// that case return the scalar value. +std::optional maybeGetOrCreateScalarConstant(RewriterBase &rewriter, + Location loc, Value expr) { + Operation *op = expr.getDefiningOp(); + + // Check for splatness + if (auto splatOp = dyn_cast_or_null(op)) + return splatOp.getSrc(); + + // Check for constant + DenseIntElementsAttr constVal; + if (auto constOp = dyn_cast_or_null(op)) { + Value val = constOp.getResult(); + if (matchPattern(val, m_Constant(&constVal)) && constVal.isSplat()) + return rewriter.create( + loc, constVal.getSplatValue()); + } + + // Check for block arguments + if (auto blockArg = dyn_cast_or_null(expr)) { + Type type = blockArg.getType(); + if (!isa(type)) + return blockArg; + } + + return {}; +} + +// Narrowing logic +// For now we allow to narrow down to 32 bits only in the following case: +// - `baseOffset` is 32-bits and `addOffset`(64-bits) is zero +bool canNarrowOffset(Value baseOffset, Value addOffset) { + Type addOffsetType = getElementTypeOrSelf(addOffset); + auto baseSplatOp = baseOffset.getDefiningOp(); + return baseSplatOp && addOffsetType.isInteger(32); +} + +// Create a zero tensor with a given `type` +Value createTensorZero(RewriterBase &rw, Location loc, RankedTensorType type) { + mlir::Attribute zeroAttr = rw.getZeroAttr(type.getElementType()); + auto zeroDenseAttr = DenseElementsAttr::get(type, zeroAttr); + return rw.create(loc, zeroDenseAttr); +} + +} // namespace + +std::pair createDecomposeOffsetFromExpr(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness); +// Offset extraction logic for an addition op: +// decompose(A+B) = {U(A)+U(B), NU(A)+NU(B)} +std::pair createDecomposeOffsetFromAdd(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness) { + auto addOp = expr.getDefiningOp(); + auto [uniformOffsetL, nonUniformOffsetL] = + createDecomposeOffsetFromExpr(rewriter, loc, addOp.getLhs(), bitness); + auto [uniformOffsetR, nonUniformOffsetR] = + createDecomposeOffsetFromExpr(rewriter, loc, addOp.getRhs(), bitness); + Value uniformAdd = + rewriter.create(loc, uniformOffsetL, uniformOffsetR); + Value nonUniformAdd = + rewriter.create(loc, nonUniformOffsetL, nonUniformOffsetR); + return {uniformAdd, nonUniformAdd}; +} + +// Offset extraction logic for a multiplication op: +// decompose(A*B) = {U(A)*U(B), NU(A)*NU(B)+NU(B)*U(A)+U(A)*NU(B)} +std::pair createDecomposeOffsetFromMul(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness) { + auto mulOp = expr.getDefiningOp(); + auto [uniformOffsetL, nonUniformOffsetL] = + createDecomposeOffsetFromExpr(rewriter, loc, mulOp.getLhs(), bitness); + auto [uniformOffsetR, nonUniformOffsetR] = + createDecomposeOffsetFromExpr(rewriter, loc, mulOp.getRhs(), bitness); + Value uniformMul = + rewriter.create(loc, uniformOffsetL, uniformOffsetR); + + Value uniformOffsetLSplat = rewriter.create( + loc, nonUniformOffsetL.getType(), uniformOffsetL); + Value uniformOffsetRSplat = rewriter.create( + loc, nonUniformOffsetR.getType(), uniformOffsetR); + + Value nonUNonU = + rewriter.create(loc, nonUniformOffsetL, nonUniformOffsetR); + Value nonUU = rewriter.create(loc, uniformOffsetLSplat, + nonUniformOffsetR); + Value uNonU = rewriter.create(loc, nonUniformOffsetL, + uniformOffsetRSplat); + + Value tmp = rewriter.create(loc, nonUNonU, nonUU); + Value nonUniformMul = rewriter.create(loc, tmp, uNonU); + return {uniformMul, nonUniformMul}; +} + +std::pair createDecomposeOffsetFromExpr(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness) { + + // Base case 1: it is a splat. Return the scalar constant as the uniform part + if (auto scalarConst = maybeGetOrCreateScalarConstant(rewriter, loc, expr)) { + auto tensorZero = + createTensorZero(rewriter, loc, cast(expr.getType())); + return {*scalarConst, tensorZero}; + } + + // Base case 2: block argument. Since it is not a scalar constant, it must be + // a tensor. Note that this means we won't be able to decompose across loop + // boundaries (TODO: giuseros). + if (llvm::isa(expr)) { + Value scalarZero = rewriter.create(loc, 0, bitness); + return {scalarZero, expr}; + } + + auto offsets = + llvm::TypeSwitch>( + expr.getDefiningOp()) + .Case([&](auto broadcastOp) { + auto [uniform, nonUniform] = createDecomposeOffsetFromExpr( + rewriter, loc, broadcastOp.getSrc(), bitness); + auto broadcastNonUniform = rewriter.create( + loc, broadcastOp.getType(), nonUniform); + return std::make_pair(uniform, broadcastNonUniform); + }) + .Case([&](auto expandOp) { + auto [uniform, nonUniform] = createDecomposeOffsetFromExpr( + rewriter, loc, expandOp.getSrc(), bitness); + auto expandNonUniform = rewriter.create( + loc, nonUniform, expandOp.getAxis()); + return std::make_pair(uniform, expandNonUniform); + }) + .Case([&](Operation *op) { + return createDecomposeOffsetFromAdd(rewriter, loc, expr, bitness); + }) + .Case([&](Operation *op) { + return createDecomposeOffsetFromMul(rewriter, loc, expr, bitness); + }) + .Default([&](Operation *op) { + // Base case 3: it is not a supported operation. We assume no + // uniform part + Value scalarZero = + rewriter.create(loc, 0, bitness); + return std::make_pair(scalarZero, expr); + }); + + return offsets; +} + +static const std::string kPtrCanonPrefix = "__amdpointercanonicalize."; +static const std::string kSCFThenRewrittenAttr = + kPtrCanonPrefix + "scf-then-rewritten__"; +static const std::string kSCFElseRewrittenAttr = + kPtrCanonPrefix + "scf-else-rewritten__"; +static const std::string kSCFIfOpYieldFatPtrOffsets = + kPtrCanonPrefix + "scf-if-yield-fatptr-offsets__"; + +/// This struct is basically a thin wrapper over DenseMap +/// where fatPtr == (base, offset) and fatPtrAttrs is itself a map of (name, +/// attribute). +/// It is used to associate metadata/attributes with the canonicalized fat +/// pointers, such as `tt.pointer_range` and whether operations involving them +/// can be narrowed (`canNarrow`). +struct FatPointers { + struct FatPtrAttrs { + FatPtrAttrs(const FatPtrAttrs &other) = default; + FatPtrAttrs &operator=(const FatPtrAttrs &other) = default; + // for map default insert + FatPtrAttrs() = default; + + friend bool operator==(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { + return lhs.canNarrow == rhs.canNarrow && lhs.attributes == rhs.attributes; + } + + friend bool operator!=(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { + return !(lhs == rhs); + } + + llvm::DenseMap attributes; + bool canNarrow = false; + }; + + using KeyT = std::pair; + using ValueT = FatPtrAttrs; + using DenseMapT = DenseMap; + + void collectFatPointerAttributes(const KeyT &k); + ValueT &operator[](const KeyT &k) { + if (!pointerAttrs.contains(k)) + collectFatPointerAttributes(k); + return pointerAttrs[k]; + } + + ValueT &operator[](KeyT &&k) { + if (!pointerAttrs.contains(k)) + collectFatPointerAttributes(k); + return pointerAttrs[k]; + } + + template + using const_arg_type_t = typename llvm::const_pointer_or_const_ref::type; + const ValueT &at(const_arg_type_t k) const { + // this is redundant - DenseMap will assert the same thing - but better to + // have our own message + assert(pointerAttrs.contains(k) && + "expected fatPtrs to contain remapped fat pointer"); + return pointerAttrs.at(k); + } + + bool contains(const KeyT &k) { return pointerAttrs.contains(k); } + +private: + DenseMapT pointerAttrs; +}; + +// TODO(max): reconsider this approach, specifically how narrowing and +// attributes are propagated starting from a tt.ptr. +void FatPointers::collectFatPointerAttributes(const KeyT &k) { + auto [base, offset] = k; + // If it is the i-th block argument, then look if the operation defined some + // _argi attribute and add it to the fat pointer attributes + if (auto arg = dyn_cast(base)) { + // If the value is a block parameter, the operation can specify + // an attribute for the given parameter by using `tt.property_argi` + // where `argi` refers to the arg number of the given parameter. + // So we need to iterate through the property, find the right one + // and push the property onto the pointers attributes. + auto op = arg.getOwner()->getParentOp(); + for (NamedAttribute namedAttr : op->getAttrs()) { + StringAttr attrName = namedAttr.getName(); + std::string argSuffix = + llvm::formatv("_arg{0}", arg.getArgNumber()).str(); + if (!attrName.strref().ends_with(argSuffix)) + continue; + + auto newAttrName = attrName.strref().drop_back(argSuffix.size()); + pointerAttrs[k].attributes[newAttrName] = namedAttr.getValue(); + // Propagate the argument to the offset if it is also a block + // argument + if (auto offsetArg = dyn_cast(offset)) + op->setAttr( + llvm::formatv("{0}_arg{1}", newAttrName, offsetArg.getArgNumber()) + .str(), + namedAttr.getValue()); + } + return; + } + + // Otherwise add the attributes of the base to the fat pointer + for (auto baseAttr : base.getDefiningOp()->getAttrs()) + pointerAttrs[k].attributes[baseAttr.getName()] = baseAttr.getValue(); +} + +Value createTensorPointer(RewriterBase &rewriter, Value basePtr, Value offset, + Location loc, + const FatPointers::FatPtrAttrs &fatPtrAttrs) { + auto tensorType = dyn_cast(offset.getType()); + + // Scalar case: we only need to `tt.addptr %basePtr, %offset` + if (!tensorType) { + auto addPtrOp = + rewriter.create(loc, basePtr.getType(), basePtr, offset); + for (const auto &attribute : fatPtrAttrs.attributes) + addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond()); + return addPtrOp.getResult(); + } + + // Tensor case: splat the scalar pointer and add the (tensor) offset: + // ``` + // %tensorBasePtr = tt.splat %basePtr + // %tensorPtr = tt.addptr %tensorBasePtr, %offset + // ``` + ArrayRef offsetShape = tensorType.getShape(); + auto tensorPtrType = RankedTensorType::get(offsetShape, basePtr.getType(), + tensorType.getEncoding()); + if (fatPtrAttrs.canNarrow) + offset = createNarrow64bitOffsetTo32bits(rewriter, loc, offset); + + tt::SplatOp tensorPtr = + rewriter.create(loc, tensorPtrType, basePtr); + tt::AddPtrOp addPtrOp = + rewriter.create(loc, tensorPtrType, tensorPtr, offset); + + for (const auto &attribute : fatPtrAttrs.attributes) + addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond()); + return addPtrOp.getResult(); +} + +/// Flatten the given value ranges into a single vector of values. +static SmallVector flattenValues(ArrayRef values) { + SmallVector result; + for (const ValueRange &vals : values) + llvm::append_range(result, vals); + return result; +} + +/// Assert that the given value range contains a single value and return it. +static Value getSingleValue(ValueRange values) { + assert(values.size() == 1 && "expected single value"); + return values.front(); +} + +/// This is convenience class (that is a copy-paste of some of +/// OpConversionPattern) that keeps track of (and removes from) opToRewrite +/// after successful matchAndRewrite_ calls; subclasses must define +/// matchAndRewrite_ just as that would for conventional OpConversionPatterns. +template +struct PointerCanonicalizationPattern : ConversionPattern { + using OpAdaptor = typename SourceOp::Adaptor; + using OneToNOpAdaptor = + typename SourceOp::template GenericAdaptor>; + + PointerCanonicalizationPattern(MLIRContext *context, + llvm::SetVector &opsToRewrite, + FatPointers &fatPtrs, + PatternBenefit benefit = 1) + : ConversionPattern(SourceOp::getOperationName(), benefit, context), + fatPtrs(fatPtrs), opToRewrite(opsToRewrite) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto sourceOp = cast(op); + if (failed(matchAndRewrite_(sourceOp, OneToNOpAdaptor(operands, sourceOp), + rewriter))) + return failure(); + opToRewrite.remove(op); + return success(); + } + + virtual LogicalResult + matchAndRewrite_(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const = 0; + + FatPointers &fatPtrs; + llvm::SetVector &opToRewrite; +}; + +/// splat integer offset, keep base +class ConvertSplatOp : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite_(tt::SplatOp splatOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange remappedOperands = adaptor.getSrc(); + if (remappedOperands.size() != 2) { + // some prior op materialized the fat ptr, e.g.: + // %3 = tt.bitcast %2 + // %4 = tt.splat %3 + return success(); + } + Value fatPtrBase = remappedOperands[0]; + Value fatPtrOffset = remappedOperands[1]; + if (!llvm::isa(fatPtrBase.getType())) + return rewriter.notifyMatchFailure(splatOp, + "non tt.ptr base unimplemented"); + if (!llvm::isa(fatPtrOffset.getType())) + return rewriter.notifyMatchFailure(splatOp, + "non-integer offset unimplemented"); + + RankedTensorType outType = splatOp.getResult().getType(); + auto newOffsetType = RankedTensorType::get( + outType.getShape(), fatPtrOffset.getType(), outType.getEncoding()); + tt::SplatOp offset = rewriter.create( + splatOp.getLoc(), newOffsetType, fatPtrOffset); + rewriter.replaceOpWithMultiple(splatOp, {{fatPtrBase, offset}}); + fatPtrs[{fatPtrBase, offset}] = fatPtrs.at({fatPtrBase, fatPtrOffset}); + + return success(); + } +}; + +/// Broadcast offset, keep base. +class ConvertBroadcastOp + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite_(tt::BroadcastOp broadcastOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange remappedOperands = adaptor.getSrc(); + if (remappedOperands.size() != 2) { + // some prior op materialized the fat ptr, e.g.: + // %3 = tt.bitcast %2 + // %4 = tt.broadcast %3 + return success(); + } + + Value fatPtrBase = remappedOperands[0]; + Value fatPtrOffset = remappedOperands[1]; + if (!llvm::isa(fatPtrBase.getType())) + return rewriter.notifyMatchFailure(broadcastOp, + "non tt.ptr base unimplemented"); + auto offsetType = dyn_cast(fatPtrOffset.getType()); + if (!offsetType) + return rewriter.notifyMatchFailure(broadcastOp, + "non-tensor offset unimplemented"); + + auto outType = + dyn_cast(broadcastOp.getResult().getType()); + auto newOffsetType = RankedTensorType::get( + outType.getShape(), offsetType.getElementType(), outType.getEncoding()); + tt::BroadcastOp newOffset = rewriter.create( + broadcastOp.getLoc(), newOffsetType, fatPtrOffset); + rewriter.replaceOpWithMultiple(broadcastOp, {{fatPtrBase, newOffset}}); + fatPtrs[{fatPtrBase, newOffset}] = fatPtrs.at({fatPtrBase, fatPtrOffset}); + return success(); + } +}; + +/// Three cases: +/// 1. If it is a scalar pointer update -> bump only the base pointer; +/// 2. Constant tensor offset -> bump only the offset +/// 3. Non-constant tensor offset -> decompose parent(offset) into uniform and +/// non-uniform components. +class ConvertAddPtrOp : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite_(tt::AddPtrOp addPtrOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange remappedPtr = adaptor.getPtr(); + if (remappedPtr.size() != 2) { + // some prior op materialized the fat ptr, e.g.: + // %3 = tt.bitcast %2 + // %4 = tt.addptr %3 + return success(); + } + ValueRange nonRemappedOffset = adaptor.getOffset(); + if (nonRemappedOffset.size() != 1) + return rewriter.notifyMatchFailure( + addPtrOp, "expected AddPtrOp Offset to have not have been remapped"); + Value fatPtrBase = remappedPtr[0]; + Value fatPtrOffset = remappedPtr[1]; + Value origOffset = nonRemappedOffset[0]; + Location curLoc = addPtrOp.getLoc(); + + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(addPtrOp); + + // If it is a scalar pointer update, simply bump the base pointer + if (llvm::isa(addPtrOp.getPtr().getType())) { + assert(llvm::isa(origOffset.getType()) && + "expected offset to be integer type"); + auto newAddPtrOp = rewriter.create( + curLoc, fatPtrBase.getType(), fatPtrBase, origOffset); + rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}}); + fatPtrs[{newAddPtrOp, fatPtrOffset}] = + fatPtrs.at({fatPtrBase, fatPtrOffset}); + return success(); + } + + assert(llvm::isa(addPtrOp.getPtr().getType()) && + "expected Ptr to be RankedTensorType type"); + + // Early exit for the case of a constant tensor + if (auto scalarConst = + maybeGetOrCreateScalarConstant(rewriter, curLoc, origOffset)) { + tt::AddPtrOp newAddPtrOp = rewriter.create( + curLoc, fatPtrBase.getType(), fatPtrBase, *scalarConst); + rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}}); + // If we are updating the tensor pointer with a constant value, we can + // propagate the attributes of the tensor pointer to the fat pointer. + fatPtrs[{newAddPtrOp, fatPtrOffset}] = + fatPtrs.at({fatPtrBase, fatPtrOffset}); + return success(); + } + + int64_t bitness = llvm::cast(origOffset.getType()) + .getElementTypeBitWidth(); + auto [uniformOffset, nonUniformOffset] = + createDecomposeOffsetFromExpr(rewriter, curLoc, origOffset, bitness); + + auto newAddPtrOp = rewriter.create( + curLoc, fatPtrBase.getType(), fatPtrBase, uniformOffset); + + // Vector offset update (if any): bump the tensor offset + bool canNarrow = fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow; + bool propagateAtrs = true; + Value newOffset = fatPtrOffset; + if (!isZeroConst(nonUniformOffset)) { + Type addPtrOffsetType = getElementTypeOrSelf(nonUniformOffset); + Type fatPtrOffsetType = getElementTypeOrSelf(fatPtrOffset); + canNarrow = canNarrow && canNarrowOffset(fatPtrOffset, nonUniformOffset); + // Upcast or downcast the offset accordingly + if (addPtrOffsetType.isInteger(32) && fatPtrOffsetType.isInteger(64)) + nonUniformOffset = + createExtend32bitOffsetTo64Bits(rewriter, curLoc, nonUniformOffset); + else if (addPtrOffsetType.isInteger(64) && fatPtrOffsetType.isInteger(32)) + nonUniformOffset = + createNarrow64bitOffsetTo32bits(rewriter, curLoc, nonUniformOffset); + + newOffset = rewriter.create(curLoc, nonUniformOffset, + fatPtrOffset); + propagateAtrs = false; + } + + rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, newOffset}}); + auto nextFatPtr = std::pair{newAddPtrOp.getResult(), newOffset}; + fatPtrs[nextFatPtr].canNarrow = canNarrow; + if (propagateAtrs) + fatPtrs[nextFatPtr].attributes = + fatPtrs.at({fatPtrBase, fatPtrOffset}).attributes; + + return success(); + } +}; + +using ConversionCallbackFn = + std::function(Type, SmallVectorImpl &)>; + +/// Rewrite init args and result type and bb args. +class ConvertSCFForOp : public PointerCanonicalizationPattern { + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + +public: + LogicalResult + matchAndRewrite_(scf::ForOp forOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector valRangeLens; + ArrayRef remappedInits = adaptor.getInitArgs(); + for (ValueRange remappedInit : remappedInits) + valRangeLens.push_back(remappedInit.size()); + + // rewrite the body bb args + Block *oldBodyBlock = forOp.getBody(); + auto oldTypes = oldBodyBlock->getArgumentTypes(); + TypeConverter::SignatureConversion sigConversion(oldTypes.size()); + // handle the 0th arg which is the induction var + sigConversion.addInputs(0, {oldTypes[0]}); + for (unsigned i = 1, e = oldTypes.size(); i != e; ++i) { + SmallVector remappedInitTypes = + llvm::to_vector(remappedInits[i - 1].getTypes()); + sigConversion.addInputs(i, remappedInitTypes); + } + auto newBodyBlock = + rewriter.applySignatureConversion(oldBodyBlock, sigConversion); + + // propagate fatPtrAttrs to bb arg fatPtrs in for body bb + // skip iv at index 0 + int offset = 1; + for (auto operands : remappedInits) { + if (operands.size() == 2) { + fatPtrs[{newBodyBlock->getArgument(offset), + newBodyBlock->getArgument(offset + 1)}] = + fatPtrs.at({operands[0], operands[1]}); + } + offset += operands.size(); + } + + SmallVector initArgs = flattenValues(adaptor.getInitArgs()); + auto newForOp = rewriter.create( + forOp.getLoc(), getSingleValue(adaptor.getLowerBound()), + getSingleValue(adaptor.getUpperBound()), + getSingleValue(adaptor.getStep()), initArgs); + + newForOp->setAttrs(forOp->getAttrs()); + rewriter.eraseBlock(newForOp.getBody()); + rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), + newForOp.getRegion().end()); + + SmallVector packedRets; + for (unsigned i = 0, offset = 0; i < valRangeLens.size(); i++) { + size_t len = valRangeLens[i]; + assert(offset < newForOp->getNumResults() && + "expected offset to be within bounds of results"); + ValueRange mappedValue = newForOp->getResults().slice(offset, len); + // propagate fatPtrs + if (mappedValue.size() == 2) { + assert(remappedInits[i].size() == 2 && + "expected corresponding inits to be a remapped fat ptr"); + fatPtrs[{mappedValue[0], mappedValue[1]}] = + fatPtrs.at({remappedInits[i][0], remappedInits[i][1]}); + } + packedRets.push_back(mappedValue); + offset += len; + } + + rewriter.replaceOpWithMultiple(forOp, packedRets); + + return success(); + } +}; + +/// Rewrite with new remapped operands but also if the scf.yield is inside of +/// scf.if (possibly) annotate the scf.if. +class ConvertSCFYieldOp : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite_(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ArrayRef remappedYields = adaptor.getOperands(); + SmallVector newYieldedValues = flattenValues(remappedYields); + // have to mutate here because otherwise scf.if, scf.for, and scf.while will + // get confused about which yield is the "correct" yield (since there will + // be two of them before the rewriter DCEs) + rewriter.modifyOpInPlace(yieldOp, [&]() { + yieldOp.getResultsMutable().clear(); + yieldOp.getResultsMutable().append(newYieldedValues); + }); + + // rewriting a parent op from a child op isn't a great idea but there's no + // other to indicate to the parent IfOp that the result type can now be + // rewritten and not before. + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + rewriter.modifyOpInPlace(ifOp, [&] { + ifOp->setDiscardableAttr(ifOp.thenBlock() == yieldOp->getBlock() + ? kSCFThenRewrittenAttr + : kSCFElseRewrittenAttr, + rewriter.getUnitAttr()); + }); + // set indices of fatPtrs so that IfOp can propagate canNarrow to + // result users + int offset = 0; + SmallVector fatPtrOffsets; + for (auto operands : remappedYields) { + if (operands.size() == 2) + fatPtrOffsets.push_back(offset); + offset += operands.size(); + } + if (!fatPtrOffsets.empty()) + yieldOp->setDiscardableAttr( + kSCFIfOpYieldFatPtrOffsets, + rewriter.getDenseI64ArrayAttr(fatPtrOffsets)); + } + + return success(); + } +}; + +/// Simple here means each block arg is replaced 1-1 with the remapped operand +/// types (e.g., scf.for does not use this helper because scf.for needs to skip +/// the 0th bb arg, the induction var). +static void convertSimpleBlockSignature(Block *oldBlock, + ArrayRef remappedOperands, + ConversionPatternRewriter &rewriter, + FatPointers &fatPtrs) { + auto oldBlockTypes = oldBlock->getArgumentTypes(); + TypeConverter::SignatureConversion blockSigConversion(oldBlockTypes.size()); + for (unsigned i = 0, e = oldBlockTypes.size(); i != e; ++i) { + SmallVector remappedInitTypes = + llvm::to_vector(remappedOperands[i].getTypes()); + blockSigConversion.addInputs(i, remappedInitTypes); + } + auto newBlock = + rewriter.applySignatureConversion(oldBlock, blockSigConversion); + + int offset = 0; + for (auto operands : remappedOperands) { + if (operands.size() == 2) { + assert(fatPtrs.contains({operands[0], operands[1]}) && + "expected fatPtrs to contain existing (op0, op1) fat pointer"); + fatPtrs[{newBlock->getArgument(offset), + newBlock->getArgument(offset + 1)}] = + fatPtrs.at({operands[0], operands[1]}); + } + offset += operands.size(); + } +} + +/// Rewrite init_args, result type, before region bb args, after region bb args. +class ConvertSCFWhileOp : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + LogicalResult + matchAndRewrite_(scf::WhileOp whileOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector valRangeLens; + ArrayRef remappedInits = adaptor.getInits(); + for (ValueRange remappedInit : remappedInits) + valRangeLens.push_back(remappedInit.size()); + + convertSimpleBlockSignature(whileOp.getBeforeBody(), remappedInits, + rewriter, fatPtrs); + convertSimpleBlockSignature(whileOp.getAfterBody(), remappedInits, rewriter, + fatPtrs); + + SmallVector initArgs = flattenValues(remappedInits); + SmallVector resultTypes = llvm::map_to_vector( + llvm::make_filter_range( + initArgs, [](Value v) { return !v.getType().isInteger(1); }), + [](Value v) { return v.getType(); }); + auto newWhileOp = + rewriter.create(whileOp.getLoc(), resultTypes, initArgs); + + newWhileOp->setAttrs(whileOp->getAttrs()); + rewriter.inlineRegionBefore(whileOp.getBefore(), newWhileOp.getBefore(), + newWhileOp.getBefore().end()); + rewriter.inlineRegionBefore(whileOp.getAfter(), newWhileOp.getAfter(), + newWhileOp.getAfter().end()); + + SmallVector packedRets; + for (unsigned i = 0, offset = 0; i < valRangeLens.size(); i++) { + // skip %cond + if (remappedInits[i].size() == 1 && + remappedInits[i].getType()[0].isInteger(1)) + continue; + size_t len = valRangeLens[i]; + assert(offset < newWhileOp->getNumResults() && + "expected offset to be within bounds of results"); + ValueRange mappedValue = newWhileOp->getResults().slice(offset, len); + // propagate fatPtrs + if (mappedValue.size() == 2) { + assert(remappedInits[i].size() == 2 && + "expected corresponding inits to be a remapped fat ptr"); + fatPtrs[{mappedValue[0], mappedValue[1]}] = + fatPtrs.at({remappedInits[i][0], remappedInits[i][1]}); + } + packedRets.push_back(mappedValue); + offset += len; + } + + rewriter.replaceOpWithMultiple(whileOp, packedRets); + + return success(); + } +}; + +/// Rewrite with new operands. +class ConvertSCFConditionOp + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + LogicalResult + matchAndRewrite_(scf::ConditionOp condOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector newArgs = flattenValues(adaptor.getArgs()); + // have to mutate here because otherwise scf.while will + // get confused about which condition is the "correct" condition (since + // there will be two of them before the rewriter DCEs) + rewriter.modifyOpInPlace(condOp, [&]() { + condOp.getArgsMutable().clear(); + condOp.getArgsMutable().append(newArgs); + }); + return success(); + } +}; + +/// Rewrite operands for both true dest and false dest. +class ConvertCFCondBranch + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + LogicalResult + matchAndRewrite_(cf::CondBranchOp branchOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ArrayRef remappedTrueOperands = adaptor.getTrueDestOperands(); + ArrayRef remappedFalseOperands = adaptor.getFalseDestOperands(); + SmallVector trueOperands = flattenValues(remappedTrueOperands); + SmallVector falseOperands = flattenValues(remappedFalseOperands); + + rewriter.replaceOpWithNewOp( + branchOp, branchOp.getCondition(), branchOp.getTrueDest(), trueOperands, + branchOp.getFalseDest(), falseOperands); + + convertSimpleBlockSignature(branchOp.getTrueDest(), remappedTrueOperands, + rewriter, fatPtrs); + convertSimpleBlockSignature(branchOp.getFalseDest(), remappedFalseOperands, + rewriter, fatPtrs); + + return success(); + } +}; + +/// Rewrite select(fatPtrTrue, fatPtrFalse) -> +/// ( +/// select(fatPtrTrueBase, fatPtrTrueOffset), +/// select(fatPtrFalseBase, fatPtrFalseOffset) +/// ) +/// +/// Note, this should only be reached after both +/// operands have already been rewritten because DialectConversion walks +/// PreOrder in order ForwardDominance order: see +/// https://github.com/llvm/llvm-project/blob/58389b220a9354ed6c34bdb9310a35165579c5e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L2702 +class ConvertArithSelectOp + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + LogicalResult + matchAndRewrite_(arith::SelectOp selectOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (adaptor.getTrueValue().size() != 2 || + adaptor.getFalseValue().size() != 2) { + assert(adaptor.getTrueValue().size() == adaptor.getFalseValue().size() && + "expected both true and false operands to be the same size"); + return success(); + } + // If both have been traversed, then we can rewrite select of pointers as a + // select of base and offset + ValueRange fatPtrFalse = adaptor.getFalseValue(); + ValueRange fatPtrTrue = adaptor.getTrueValue(); + // Simple case of a scalar select: update the base pointer + if (!isa(selectOp.getType())) { + auto newSelectOp = rewriter.create( + selectOp.getLoc(), selectOp.getType(), selectOp.getCondition(), + fatPtrTrue[0], selectOp.getFalseValue()); + rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrTrue[1]}}); + fatPtrs[{newSelectOp, /*fatPtrOffset*/ fatPtrTrue[1]}] = + fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}); + return success(); + } + + // Rewrite to select(fatBaseT, fatBaseF) and select(fatOffsetT, fatOffsetF) + auto newBase = rewriter.create( + selectOp.getLoc(), selectOp.getCondition(), fatPtrTrue[0], + fatPtrFalse[0]); + auto newOffset = rewriter.create( + selectOp.getLoc(), selectOp.getCondition(), fatPtrTrue[1], + fatPtrFalse[1]); + + assert((fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}) == + fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]})) && + "expected can narrow to be the same for both fatPtrT and fatPtrF"); + + rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}}); + fatPtrs[{newBase, newOffset}] = fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}); + + return success(); + } +}; + +/// Rewrite result type only after both arms have been visited. +/// We contrive this to happen, even though DialectConversion does a PreOrder +/// walk, by checking for two attributes in the ConversionTarget +/// ("then_rewritten", and "else_rewritten"). +class ConvertSCFIfOp : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + LogicalResult + matchAndRewrite_(scf::IfOp ifOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(ifOp.thenYield()->hasAttr(kSCFIfOpYieldFatPtrOffsets) && + "expected then yield to report fat ptr indices"); + + bool withElseRegion = ifOp.getNumRegions() > 1; + +#ifndef NDEBUG + if (withElseRegion) { + assert(ifOp.thenYield().getOperandTypes() == + ifOp.elseYield().getOperandTypes() && + "ifOp types must match in both arms"); + if (auto thenFatPtrIndxs = ifOp.thenYield()->getDiscardableAttr( + kSCFIfOpYieldFatPtrOffsets)) { + assert(ifOp.elseYield()->hasAttr(kSCFIfOpYieldFatPtrOffsets) && + "expected then yield to report fat ptr indices"); + auto elseFatPtrIndxs = + ifOp.elseYield()->getDiscardableAttr(kSCFIfOpYieldFatPtrOffsets); + assert(elseFatPtrIndxs && + "expected else fat ptr indices as well as then fat ptr indices"); + + DenseI64ArrayAttr thenIdxs = + llvm::dyn_cast(thenFatPtrIndxs); + DenseI64ArrayAttr elseIdxs = + llvm::dyn_cast(elseFatPtrIndxs); + assert(bool(thenIdxs) && bool(elseIdxs) && + "expected else fat ptr index attrs to be DenseI64ArrayAttr"); + for (auto [i, j] : + llvm::zip(thenIdxs.asArrayRef(), elseIdxs.asArrayRef())) { + assert(i == j && + "expected thenFatPtrIndxs and elseFatPtrIndxs to agree"); + assert(i < ifOp.thenYield().getNumOperands() && + i + 1 < ifOp.thenYield().getNumOperands() && + "expected idx to be within bounds of IfOp's results"); + Value thenFatPtrBase = ifOp.thenYield().getOperand(i); + Value thenFatPtrOffset = ifOp.thenYield().getOperand(i + 1); + Value elseFatPtrBase = ifOp.elseYield().getOperand(i); + Value elseFatPtrOffset = ifOp.elseYield().getOperand(i + 1); + assert((fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}) == + fatPtrs.at({elseFatPtrBase, elseFatPtrOffset})) && + "expected then fat ptr canNarrow and else fat ptr canNarrow " + "to be equal"); + } + } + } +#endif + + auto newIfOp = rewriter.create( + ifOp.getLoc(), ifOp.thenYield().getOperandTypes(), ifOp.getCondition(), + withElseRegion); + rewriter.inlineBlockBefore(ifOp.thenBlock(), newIfOp.thenBlock(), + newIfOp.thenBlock()->begin()); + if (withElseRegion) + rewriter.inlineBlockBefore(ifOp.elseBlock(), newIfOp.elseBlock(), + newIfOp.elseBlock()->begin()); + + rewriter.replaceOpWithMultiple(ifOp, {newIfOp.getResults()}); + + for (int64_t idx : + llvm::cast(newIfOp.thenYield()->getDiscardableAttr( + kSCFIfOpYieldFatPtrOffsets)) + .asArrayRef()) { + Value thenFatPtrBase = newIfOp.thenYield().getOperand(idx); + Value thenFatPtrOffset = newIfOp.thenYield().getOperand(idx + 1); + fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}] = + fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}); + } + + return success(); + } +}; + +/// Rewrite the non-cond operands and the signature of the dest bb. +class ConvertCFBranch : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + LogicalResult + matchAndRewrite_(cf::BranchOp branchOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ArrayRef remappedDestOperands = adaptor.getDestOperands(); + SmallVector trueOperands = flattenValues(remappedDestOperands); + + rewriter.replaceOpWithNewOp(branchOp, branchOp.getDest(), + trueOperands); + convertSimpleBlockSignature(branchOp.getDest(), remappedDestOperands, + rewriter, fatPtrs); + return success(); + } +}; + +/// Rewrite to expand(base, offset) -> base, expand(offset) +class ConvertExpandDims + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + LogicalResult + matchAndRewrite_(tt::ExpandDimsOp expandOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange remappedOperands = adaptor.getSrc(); + if (remappedOperands.size() != 2) + return success(); + Value fatPtrBase = remappedOperands[0]; + if (!llvm::isa(fatPtrBase.getType())) + return rewriter.notifyMatchFailure( + expandOp, "only scalar base currently supported"); + Value fatPtrOffset = remappedOperands[1]; + + RankedTensorType result = + llvm::cast(expandOp->getResultTypes().front()); + if (!llvm::isa(result.getElementType())) + return rewriter.notifyMatchFailure( + expandOp, "expected expand_dim result to be tensor of tt.ptr"); + + RankedTensorType newResult = RankedTensorType::get( + result.getShape(), + llvm::cast(fatPtrOffset.getType()).getElementType(), + result.getEncoding()); + auto newOffset = rewriter.create( + expandOp.getLoc(), newResult, fatPtrOffset, adaptor.getAxis()); + rewriter.replaceOpWithMultiple(expandOp, {{fatPtrBase, newOffset}}); + fatPtrs[{fatPtrBase, newOffset}] = fatPtrs.at({fatPtrBase, fatPtrOffset}); + + return success(); + } +}; + +/// convert integer offset, keep base +class ConvertConvertLayoutOp + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite_(tt::gpu::ConvertLayoutOp cvtOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange remappedOperands = adaptor.getSrc(); + if (remappedOperands.size() != 2) { + // some prior op materialized the fat ptr, e.g.: + // %3 = tt.bitcast %2 + // %4 = tt.splat %3 + return success(); + } + Value fatPtrBase = remappedOperands[0]; + Value fatPtrOffset = remappedOperands[1]; + if (!llvm::isa(fatPtrBase.getType())) { + return rewriter.notifyMatchFailure(cvtOp, + "non tt.ptr base unimplemented"); + } + auto offsetTensorTy = dyn_cast(fatPtrOffset.getType()); + if (!offsetTensorTy) { + return rewriter.notifyMatchFailure( + cvtOp, "non RankedTensorType offset unimplemented"); + } + + RankedTensorType outType = cvtOp.getResult().getType(); + auto newOffsetType = RankedTensorType::get(outType.getShape(), + offsetTensorTy.getElementType(), + outType.getEncoding()); + tt::gpu::ConvertLayoutOp cvtOffset = + rewriter.create(cvtOp.getLoc(), newOffsetType, + fatPtrOffset); + rewriter.replaceOpWithMultiple(cvtOp, {{fatPtrBase, cvtOffset}}); + fatPtrs[{fatPtrBase, cvtOffset}] = fatPtrs.at({fatPtrBase, fatPtrOffset}); + + return success(); + } +}; + +template +class MaterializeFatPointer : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern< + SourceOp>::PointerCanonicalizationPattern; + + LogicalResult matchAndRewrite_( + SourceOp op, + typename PointerCanonicalizationPattern::OneToNOpAdaptor + adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!llvm::isa( + getElementTypeOrSelf(op->getOperandTypes()[PtrLikeIdx]))) + return rewriter.notifyMatchFailure(op, + "expected operand to be pointer-like"); + ValueRange fatPtr = adaptor.getOperands()[PtrLikeIdx]; + if (fatPtr.size() != 2) { + // some prior op materialized the fat ptr, e.g.: + // %3 = tt.bitcast %2 + // %4 = tt.load %3 + return success(); + } + + Value fatPtrBase = fatPtr[0]; + Value fatPtrOffset = fatPtr[1]; + Location curLoc = op.getLoc(); + + const FatPointers::FatPtrAttrs &fatPtrAttrs = + this->fatPtrs.at({fatPtrBase, fatPtrOffset}); + SmallVector operands = op->getOperands(); + operands[PtrLikeIdx] = createTensorPointer( + rewriter, fatPtrBase, fatPtrOffset, curLoc, fatPtrAttrs); + + if (op->getNumResults()) + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), ValueRange{operands}, op->getAttrs()); + else + rewriter.replaceOpWithNewOp( + op, TypeRange{}, ValueRange{operands}, op->getAttrs()); + return success(); + } +}; + +template +class MaterializeFatPointerVariadic + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern< + SourceOp>::PointerCanonicalizationPattern; + + LogicalResult matchAndRewrite_( + SourceOp op, + typename PointerCanonicalizationPattern::OneToNOpAdaptor + adaptor, + ConversionPatternRewriter &rewriter) const override { + Location curLoc = op.getLoc(); + SmallVector operands = op->getOperands(); + for (auto [i, maybeFatPtr] : llvm::enumerate(adaptor.getOperands())) { + if (maybeFatPtr.size() != 2) + continue; + Value fatPtrBase = maybeFatPtr[0]; + Value fatPtrOffset = maybeFatPtr[1]; + + const FatPointers::FatPtrAttrs &fatPtrAttrs = + this->fatPtrs.at({fatPtrBase, fatPtrOffset}); + Value newPtr = createTensorPointer(rewriter, fatPtrBase, fatPtrOffset, + curLoc, fatPtrAttrs); + operands[i] = newPtr; + } + + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + ValueRange{operands}, op->getAttrs()); + return success(); + } +}; + +static const std::string kInitFuncArgsRewritten = + kPtrCanonPrefix + "init-func-ptr-args"; +/// tt.func gets rewritten differently from all the other ops - the op itself is +/// not rewritten. What is rewritten are all tt.ptr args are rewritten (all +/// uses) to be %1 = unrealize_cast(%arg0: tt.ptr, c0: i32) -> tt.ptr. This +/// unrealized_cast is then (possibly) materialized in the second pass +/// (ConvertUnimplementedOpUnrealizedCasts) if it wasn't DCEd (via a user +/// extracting the tt.ptr and c0 operands). +struct InitFuncPtrArgs : OpRewritePattern { + InitFuncPtrArgs(MLIRContext *context, FatPointers &fatPtrs) + : OpRewritePattern(context, 0), fatPtrs(fatPtrs) {} + + LogicalResult matchAndRewrite(tt::FuncOp newOp, + PatternRewriter &rewriter) const override { + if (newOp->hasAttr(kInitFuncArgsRewritten)) + return failure(); + + int64_t bitness = 64; + rewriter.setInsertionPointToStart(&newOp.getBody().front()); + for (auto [idx, arg] : llvm::enumerate(newOp.getArguments())) { + // The pointer argument needs to be a scalar + if (!isa(arg.getType())) + continue; + if (auto pointerRangeAttr = + newOp.getArgAttrOfType(idx, "tt.pointer_range")) + bitness = pointerRangeAttr.getInt(); + Value zeroOffset = + rewriter.create(newOp.getLoc(), 0, bitness); + auto dummyCast = rewriter.create( + arg.getLoc(), TypeRange{arg.getType()}, ValueRange{arg, zeroOffset}); + rewriter.replaceAllUsesExcept(arg, dummyCast.getResult(0), dummyCast); + fatPtrs[{arg, zeroOffset}].canNarrow = true; + } + + newOp->setDiscardableAttr(kInitFuncArgsRewritten, rewriter.getUnitAttr()); + return success(); + } + + FatPointers &fatPtrs; +}; + +/// No-op to make conversion framework happy. +class ConvertReturnOp : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite_(tt::ReturnOp returnOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto returns = flattenValues(adaptor.getSrcs()); + rewriter.replaceOpWithNewOp(returnOp, TypeRange{}, returns); + return success(); + } +}; + +class ConvertFuncOpArgsUnrealizedCasts + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite_(UnrealizedConversionCastOp castOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (castOp.use_empty()) { + castOp->getParentOfType().emitRemark( + "expected at least 1 use of unrealized_cast"); + return success(); + } + // Exhaustive checking we're converting ONLY unrealized_casts inserted (by + // the 1:N conversion) in ConvertFuncOp. + ArrayRef remappedOperands = adaptor.getOperands(); + if (remappedOperands.size() != 2 || remappedOperands[0].size() != 1 || + remappedOperands[1].size() != 1) + return rewriter.notifyMatchFailure( + castOp, "expected CastOp to have already been remapped"); + Value fatPtrBase = remappedOperands[0][0]; + if (!llvm::isa(fatPtrBase) || + !llvm::isa(fatPtrBase.getParentBlock()->getParentOp()) || + !llvm::isa(fatPtrBase.getType())) + return rewriter.notifyMatchFailure( + castOp, + "expected CastOp first operand to be tt.ptr block arg of tt.func"); + Value fatPtrOffset = remappedOperands[1][0]; + if (llvm::isa(fatPtrOffset) || + !llvm::isa(fatPtrOffset.getDefiningOp())) + return rewriter.notifyMatchFailure( + castOp, "expected CastOp second operand to be arith.constant"); + OpFoldResult maybeScalar = getAsOpFoldResult(fatPtrOffset); + auto maybeAttr = llvm::dyn_cast(maybeScalar); + + if (auto integerAttr = + llvm::dyn_cast_or_null(maybeAttr)) { + if (integerAttr.getValue() == 0) { + rewriter.replaceOpWithMultiple(castOp, {{fatPtrBase, fatPtrOffset}}); + return success(); + } + } + return rewriter.notifyMatchFailure( + castOp, + "expected CastOp second operand to be arith.constant with value 0"); + } +}; + +class ConvertUnimplementedOpUnrealizedCasts + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite_(UnrealizedConversionCastOp castOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (castOp.use_empty()) { + castOp.erase(); + return success(); + } + ArrayRef remappedOperands = adaptor.getOperands(); + if (remappedOperands.size() != 2) + return rewriter.notifyMatchFailure( + castOp, "expected CastOp to have already been remapped"); + Value fatPtrBase = remappedOperands[0][0]; + Value fatPtrOffset = remappedOperands[1][0]; + if (!llvm::isa(fatPtrBase.getType())) + return rewriter.notifyMatchFailure(castOp, + "non tt.ptr base unimplemented"); + + rewriter.setInsertionPointAfter(castOp); + + // shortcut if offset == 0, no need for addptr + OpFoldResult maybeScalar = getAsOpFoldResult(fatPtrOffset); + auto maybeAttr = llvm::dyn_cast(maybeScalar); + if (auto integerAttr = + llvm::dyn_cast_or_null(maybeAttr)) { + if (integerAttr.getValue() == 0) { + rewriter.replaceAllUsesWith(castOp.getResult(0), fatPtrBase); + rewriter.eraseOp(castOp); + return success(); + } + } + + const FatPointers::FatPtrAttrs &fatPtrAttrs = + fatPtrs.at({fatPtrBase, fatPtrOffset}); + auto newPtr = createTensorPointer(rewriter, fatPtrBase, fatPtrOffset, + castOp.getLoc(), fatPtrAttrs); + rewriter.replaceAllUsesWith(newPtr, fatPtrBase); + rewriter.eraseOp(castOp); + return success(); + } +}; + +/// The pass structure/action is roughly: +/// +/// 1. Perform an approximate sparse dataflow analysis to find all transitive +/// uses for `tt.func` args that are `tt.ptr`s; legalize only these ops; +/// 2. Rewrite all operations' `use`s and `result`s to be `(%baseptr, +/// %offsetptr)` using `ConversionPattern`s that takes the new +/// `OneToNOpAdaptor`, which automatically forwards both `%baseptr` and +/// `%offsetptr` through `adaptor.getOperands()`[^3]; +/// 3. Clean up remaining `unrealized_casts` (currently only handling one +/// category of such remaining casts but can be extended to handle all; see +/// bullet 1 in TODOs). +class TritonAMDGPUCanonicalizePointersPass + : public TritonAMDGPUCanonicalizePointersBase< + TritonAMDGPUCanonicalizePointersPass> { +public: + TritonAMDGPUCanonicalizePointersPass() = default; + + void runOnOperation() override; +}; + +/// Forward slice == transitive use +/// This is a port/adaptation of upstream's getForwardSliceImpl +/// that operates on values instead of ops so that we can track tt.ptr through +/// the operands/args of region ops like scf.for/scf.while. +/// It also handles scf.if in a special way beacuse scf.if does not have +/// operands. +/// +/// TODO(max): this is still just a heuristic approximation to a "dataflow +/// analysis" that "understands" the relationship between each operands and +/// results for each op (i.e., whether fat ptrs are actually propagated). +static void getForwardSliceImpl(OpOperand *use, Operation *op, + SetVector *forwardSlice) { + assert(use && op && "expected both use and op to be valid pointers"); + assert(use->getOwner() == op && "expected use's owner to be op"); + + if (!llvm::isa(getElementTypeOrSelf(use->get().getType()))) + return; + + // verbose because you can't construct from + SmallVector nextUses; + auto addUses = [&nextUses](const Value::use_range &uses) { + for (auto &use : uses) + nextUses.emplace_back(&use); + }; + + // all of this is necessary because both the LoopLikeInterface and + // BrancOpInterface are bad... + auto addBlockArgUses = [&use, &addUses]( + const Block::BlockArgListType &blockArgs, + unsigned argOffset = 0, unsigned useOffset = 0) { + for (auto arg : blockArgs) { + if (arg.getArgNumber() - argOffset == use->getOperandNumber() - useOffset) + addUses(arg.getUses()); + } + }; + + if (auto whileLoop = llvm::dyn_cast(op)) { + addBlockArgUses(whileLoop.getBeforeArguments()); + addBlockArgUses(whileLoop.getAfterArguments()); + } else if (auto forLoop = llvm::dyn_cast(op)) { + addBlockArgUses(forLoop.getRegionIterArgs(), forLoop.getNumInductionVars(), + forLoop.getNumControlOperands()); + } else if (auto branchOp = llvm::dyn_cast(op)) { + addBlockArgUses(branchOp.getDest()->getArguments()); + } else if (auto condBranchOp = llvm::dyn_cast(op)) { + // the 0th operand of cf.cond_br is the condition + addBlockArgUses(condBranchOp.getTrueDest()->getArguments(), /*argOffset*/ 0, + /*useOffset*/ 1); + addBlockArgUses(condBranchOp.getFalseDest()->getArguments(), + /*argOffset*/ 0, /*useOffset*/ 1); + } else if (auto yield = llvm::dyn_cast(op)) { + forwardSlice->insert(yield); + if (auto ifOp = llvm::dyn_cast(yield->getParentOp())) + op = ifOp; + } + + for (auto result : op->getResults()) + addUses(result.getUses()); + + for (OpOperand *nextUse : nextUses) { + auto owner = nextUse->getOwner(); + getForwardSliceImpl(nextUse, owner, forwardSlice); + } + + forwardSlice->insert(op); +} + +void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { + LLVM_DEBUG({ + llvm::dbgs() << "before tritonamdgpu-canonicalize-pointers\n"; + getOperation()->getParentOfType()->dump(); + llvm::dbgs() << "\n"; + }); + + auto func = getOperation(); + + FatPointers fatPrs; + PatternRewriter rewriter(&getContext()); + // Convert tt.func; %1 = unrealize_cast(%arg0: tt.ptr, c0: i32) -> tt.ptr + InitFuncPtrArgs pat(&getContext(), fatPrs); + if (failed(pat.matchAndRewrite(func, rewriter))) + return signalPassFailure(); + + llvm::SetVector opsToRewrite; + for (auto arg : func.getArguments()) { + if (llvm::isa(arg.getType())) { + // NB: reusing the same SetVector invalidates the topo order implied by + // getForwardSlice + for (auto &use : arg.getUses()) + getForwardSliceImpl(&use, use.getOwner(), &opsToRewrite); + } + } + + ConversionConfig config; + config.buildMaterializations = false; + ConversionTarget target(getContext()); + auto isLegal = [&opsToRewrite](Operation *op) { + if (auto ifOp = llvm::dyn_cast(op)) { + // This is the only hack in the entire pass; on first traversal, + // `scf.if` will be walked over, but we do not want to rewrite it yet + // because the `yields` in the then/else regions haven't been rewritten + // yet (and those `yields` tell us the final result types of the + // `scf.if`). Therefore, we check for these attributes and if they're + // absent then the `scf.if` is legal. Once both `yields` have been + // rewritten (the corresponding attributes have been added), we report the + // `scf.if` as illegal, and it will be rewritten (the pattern will fire). + return !(ifOp->hasAttr(kSCFThenRewrittenAttr) && + ifOp->hasAttr(kSCFElseRewrittenAttr)); + } + return !opsToRewrite.contains(op); + }; + + target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect(isLegal); + + // Rewrite the rest of the ops. + // Note we *do not* declare unrealized_cast an illegal op here in order that + // the whole conversion passes, even if there are tt ops that we do not + // currently support (their operands will be handled by + // ConvertUnimplementedOpUnrealizedCasts below). Note we *do* add + // ConvertFuncOpArgsUnrealizedCasts because that is necessary for + // "initializing" the chain of fat pointers starting from tt.func tt.ptr args. + RewritePatternSet patterns(&getContext()); + patterns.add< + ConvertFuncOpArgsUnrealizedCasts, ConvertBroadcastOp, ConvertSplatOp, + ConvertConvertLayoutOp, ConvertAddPtrOp, + MaterializeFatPointer, + MaterializeFatPointer, + MaterializeFatPointer, MaterializeFatPointer, + MaterializeFatPointer, + MaterializeFatPointer, MaterializeFatPointer, + MaterializeFatPointerVariadic, + MaterializeFatPointerVariadic, ConvertSCFForOp, + ConvertExpandDims, ConvertSCFYieldOp, ConvertSCFIfOp, + ConvertSCFConditionOp, ConvertSCFWhileOp, ConvertCFCondBranch, + ConvertCFBranch, ConvertArithSelectOp, ConvertReturnOp>( + patterns.getContext(), opsToRewrite, fatPrs); + if (failed(applyPartialConversion(func, target, std::move(patterns), config))) + return signalPassFailure(); + + // Rewrite any lingering unrealized_casts that *should* only be the result of + // unsupported ops. + target.addIllegalOp(); + patterns.clear(); + patterns.add(patterns.getContext(), + opsToRewrite, fatPrs); + if (failed(applyPartialConversion(func, target, std::move(patterns), config))) + return signalPassFailure(); + + func->walk([](Operation *op) { + for (auto attr : op->getDiscardableAttrs()) { + if (attr.getName().strref().starts_with(kPtrCanonPrefix)) + op->removeDiscardableAttr(attr.getName()); + } + }); +} + +std::unique_ptr mlir::createTritonAMDGPUCanonicalizePointersPass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp new file mode 100644 index 000000000..70586e4c1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -0,0 +1,560 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "third_party/amd/include/Analysis/RangeAnalysis.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/TypeSwitch.h" + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +#undef DEBUG_TYPE +#define DEBUG_TYPE "tritonamdgpu-convert-buffer-ops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using ::mlir::LLVM::AMD::getVectorSize; +using mlir::triton::AMD::ISAFamily; + +namespace ttg = mlir::triton::gpu; +namespace tt = mlir::triton; + +namespace { + +bool verifyNonSmallerByAssumption( + Value expr, const DenseSet &assumptions, + const std::function &matchesOther) { + for (Value assume : assumptions) { + if (auto cmpOp = assume.getDefiningOp()) { + switch (cmpOp.getPredicate()) { + case arith::CmpIPredicate::eq: + case arith::CmpIPredicate::sge: + case arith::CmpIPredicate::sgt: { + if (cmpOp.getLhs() == expr && matchesOther(cmpOp.getRhs())) { + LDBG(" " << expr << " non-neg by assumption " << cmpOp); + return true; + } + break; + } + case arith::CmpIPredicate::sle: + case arith::CmpIPredicate::slt: { + if (cmpOp.getRhs() == expr && matchesOther(cmpOp.getLhs())) { + LDBG(" " << expr << " non-neg by assumption " << cmpOp); + return true; + } + break; + } + default: + break; + } + } + } + return false; +} + +bool verifyNonNegativeByAssumption(Value expr, + const DenseSet &assumptions) { + return verifyNonSmallerByAssumption(expr, assumptions, [](auto otherExpr) { + APInt cst; + return matchPattern(otherExpr, m_ConstantInt(&cst)) && cst.isNonNegative(); + }); +} + +bool verifyNonSmallerByAssumption(Value expr, + const DenseSet &assumptions, + Value other) { + return verifyNonSmallerByAssumption( + expr, assumptions, [&](auto otherAssum) { return otherAssum == other; }); +} + +bool verifyNonNegativeExpr(Value expr, const DenseSet &assumptions, + std::shared_ptr solver) { + LDBG("Determing if non-negative: " << expr); + + if (!llvm::isa(expr) && + succeeded(dataflow::staticallyNonNegative(*solver, expr))) { + return true; + } + + // Check if the expression is contained in any assumption + if (verifyNonNegativeByAssumption(expr, assumptions)) { + return true; + } + + // Recurse if the operation is defined + Operation *op = expr.getDefiningOp(); + if (!op) { + LDBG(" No defining op, assuming possibly negative"); + return false; + } + + bool nonNegative = + llvm::TypeSwitch(expr.getDefiningOp()) + // Various unary triton ops that don't change the sign of the operand + .Case([&](auto unaryOp) { + return verifyNonNegativeExpr(unaryOp.getOperand(), assumptions, + solver); + }) + .Case([&](auto gatherOp) { + return verifyNonNegativeExpr(gatherOp.getSrc(), assumptions, + solver); + }) + // Joining two non-negative tensors is still non-negative + .Case([&](auto joinOp) { + return verifyNonNegativeExpr(joinOp.getLhs(), assumptions, + solver) && + verifyNonNegativeExpr(joinOp.getRhs(), assumptions, solver); + }) + // Returns a tensor representing histogram: historgrams only contain + // buckets of non-negative values. + .Case([&](auto) { return true; }) + .Case([&](auto makeRangeOp) { + // See the warning in TritonOps.td: getStart/getEnd return unsigned, + // so we need to look through get*Attr. + return makeRangeOp.getStartAttr().getInt() >= 0 && + makeRangeOp.getEndAttr().getInt() >= 0; + }) + .Case( + [&](auto constIntOp) { return constIntOp.value() >= 0; }) + .Case([&](arith::ConstantOp constOp) { + Value val = constOp.getResult(); + DenseIntElementsAttr constVal; + if (matchPattern(val, m_Constant(&constVal)) && constVal.isSplat()) + return constVal.getSplatValue().isNonNegative(); + return false; + }) + .Case([&](auto) { + // These are defined as signless, but are actually unsigned + return true; + }) + .Case([&](auto maxOp) { + // max(a,b) >= 0 iff a>=0 || b>=0 + return verifyNonNegativeExpr(maxOp.getLhs(), assumptions, solver) || + verifyNonNegativeExpr(maxOp.getRhs(), assumptions, solver); + }) + .Case([&](auto remsiOp) { + // a % b >= 0 iff a>=0 + return verifyNonNegativeExpr(remsiOp.getLhs(), assumptions, solver); + }) + .Case([&](Operation *unaryOp) { + // a = OP b >= 0 iff b >= 0 + return verifyNonNegativeExpr(unaryOp->getOperand(0), assumptions, + solver); + }) + // Casting from arbitrary data does *not* guarantee the offset is in + // range (even if pointer, or the data is non-negative when + // interpreted as the src's type). + .Case( + [&](auto) { return false; }) + .Case( + // These OPs also return unsigned values. + // TODO: We can also sniff whether a Value is unsigned by looking + // for whether or not it's used as an argument to one of + // these OPs. + [&](auto uOp) { return true; }) + .Case( + // Generally speaking, a OP b >= 0 iff a >= 0 && b >= 0 when + // OP != sub + [&](Operation *binOp) { + return verifyNonNegativeExpr(binOp->getOperand(0), assumptions, + solver) && + verifyNonNegativeExpr(binOp->getOperand(1), assumptions, + solver); + }) + // TODO: more scf + .Case([&](auto ifOp) { + auto results = ifOp.getResults(); + auto it = std::find(results.begin(), results.end(), expr); + assert(it != results.end() && "expr should be the result of ifOp"); + auto resultIdx = it - results.begin(); + + // If we're here then we must have both then/else regions + // (each with 1 block) and each region must terminate with an + // `scf.yield` expression. + auto thenYield = cast(ifOp.thenYield()); + auto elseYield = cast(ifOp.elseYield()); + return verifyNonNegativeExpr(thenYield->getOperand(resultIdx), + assumptions, solver) && + verifyNonNegativeExpr(elseYield->getOperand(resultIdx), + assumptions, solver); + }) + .Case([&](auto op) { + // If a user annotates tl.assume(a >= b) then we know a - b >= 0 + return verifyNonSmallerByAssumption(op.getLhs(), assumptions, + op.getRhs()); + }) + .Default([&](Operation *) { + // Conservatively assume that the expression is negative + LDBG(" Unhandled op, cannot assume non-negative"); + return false; + }); + return nonNegative; +} + +// Quick analysis on the Triton IR to decide if we can safely use +// buffer operations +bool canUseBufferOps(Value ptr, const DenseSet &assumptions, + std::shared_ptr solver) { + // 1. Check if the pointer is uniform: i.e., if it comes from a uniform + // pointer(splatted) and non-uniform offset addition + + LDBG("Buffer op checks for: " << ptr); + auto addPtrOp = ptr.getDefiningOp(); + if (!addPtrOp) + return false; + + auto maybeSplatOp = addPtrOp.getPtr().getDefiningOp(); + if (!maybeSplatOp) + return false; + LDBG("Pattern matched"); + + // 2. Check if the offset is a 32-bit tensor + Value offset = addPtrOp.getOffset(); + if (cast(offset.getType()).getElementTypeBitWidth() != 32) + return false; + LDBG("32 bit offset"); + + return verifyNonNegativeExpr(offset, assumptions, std::move(solver)); +} + +// Extract stride of the blocked offset of LD/ST ops. +Value getBlockStride(Location loc, Value offset, PatternRewriter &rewriter) { + // canonicalize pointer pass sets block stride via + // `offset:add-broadcast-muli-splat`, backtrace that pattern to reach the + // stride. + if (auto maybeAdd = offset.getDefiningOp()) + for (auto addOpr : maybeAdd.getOperands()) + if (auto maybeBC = addOpr.getDefiningOp()) { + auto bcSrc = maybeBC.getSrc(); + if (auto maybeMul = bcSrc.getDefiningOp()) + for (auto mulOpr : maybeMul.getOperands()) + if (auto maybeSplat = mulOpr.getDefiningOp()) + return maybeSplat.getSrc(); + } + return nullptr; +} + +} // namespace + +struct ConvertTritonAtomicRMWOpToBufferAtomicRMW + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertTritonAtomicRMWOpToBufferAtomicRMW( + mlir::MLIRContext *context, DenseSet &assumptions, + ModuleAxisInfoAnalysis &axisAnalysisPass, + std::shared_ptr solver) + : mlir::OpRewritePattern(context), + assumptions(assumptions), axisAnalysisPass(axisAnalysisPass), + solver(std::move(solver)) {} + + mlir::LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const override { + LDBG("Try to convert: " << op); + Value ptr = op.getPtr(); + auto atomicRmwOp = op.getAtomicRmwOp(); + auto sem = op.getSem(); + auto scope = op.getScope(); + + // In addition to the `canUserBufferOps` check, we should ensure that + // 1. Perform the canUserBufferOps check + if (!canUseBufferOps(ptr, assumptions, solver)) { + return rewriter.notifyMatchFailure(op, "canUseBufferOps check failed"); + } + + // 2. Check the scope. We support GPU and CTA for now (SYSTEM scope is not + // supported yet) + switch (scope) { + case MemSyncScope::GPU: + case MemSyncScope::CTA: + break; + default: + return rewriter.notifyMatchFailure(op, "RMW with unsupported scope"); + } + LDBG("RMW supported scope"); + + // 3. Check the memory ordering. + // TODO: support monotonic + switch (sem) { + case MemSemantic::RELAXED: + case MemSemantic::RELEASE: + case MemSemantic::ACQUIRE: + case MemSemantic::ACQUIRE_RELEASE: + break; + default: + return rewriter.notifyMatchFailure( + op, "RMW with unsupported memory ordering"); + } + + auto addPtrOp = ptr.getDefiningOp(); + Value tensorPtr = addPtrOp.getPtr(); + Value tensorOffset = addPtrOp.getOffset(); + auto splatOp = tensorPtr.getDefiningOp(); + Value basePtr = splatOp.getSrc(); + + // 4. Buffer atomic RMW does not support FP8 ops + // easier to just check what we support + auto checkType = getElementTypeOrSelf(op.getVal()); + bool isSupportedType = checkType.isF16() || checkType.isBF16() || + checkType.isF32() || checkType.isF64() || + checkType.isInteger(32) || checkType.isInteger(64); + if (!isSupportedType) { + return rewriter.notifyMatchFailure(op, "RMW with unsupported type"); + } + LDBG("RMW supported type"); + + auto vecSize = getVectorSize(ptr, axisAnalysisPass); + // f16/bf16 dtypes could only be efficiently calculated using instructions + // that pack 2 elements (e.g. @llvm.amdgcn.raw.buffer.atomic.fadd.v2f16) + if (vecSize % 2 != 0 && (checkType.isF16() || checkType.isBF16())) { + return rewriter.notifyMatchFailure( + op, "RMW float 16 dtypes must be aligned by 2"); + } + LDBG("RMW passed alignment check"); + + // 5. Check if the RMWOp is supported + switch (atomicRmwOp) { + case RMWOp::AND: + case RMWOp::OR: + case RMWOp::XOR: + case RMWOp::ADD: + case RMWOp::FADD: + case RMWOp::MAX: + case RMWOp::MIN: + case RMWOp::UMAX: + case RMWOp::UMIN: + case RMWOp::XCHG: + break; + default: + auto rmwOpStr = stringifyRMWOp(atomicRmwOp).str(); + return rewriter.notifyMatchFailure(op, "RMW with unsupported op: " + + rmwOpStr); + } + LDBG("RMW supported Op"); + + // 6. Buffer atomics support 32 and 64-bit operations, so inputs must be at + // least 32-bits. Otherwise, fall back to the existing path for atomics + auto opValueType = op.getVal().getType(); + auto opBitWidth = 0; + if (auto tensorType = dyn_cast(opValueType)) { + // We can't just compute the opBitWidth using the numElements * + // elemBitWidth here. In cases such as tensor<2xf16...>, if the elements + // are contiguous we can emit the buffer op. Otherwise, the buffer ops + // lowering will try to emit individual (unsupported) f16/bf16 ops. + auto elemBitWidth = tensorType.getElementTypeBitWidth(); + opBitWidth = vecSize * elemBitWidth; + } else { + opBitWidth = opValueType.getIntOrFloatBitWidth(); + } + + if (opBitWidth < 32) { + return rewriter.notifyMatchFailure(op, "RMW requires opBitWidth >= 32"); + } + + Value maybeMask{}; + if (op.getMask() && !isZeroConst(op.getMask())) + maybeMask = op.getMask(); + Value blockStride = getBlockStride(op->getLoc(), tensorOffset, rewriter); + rewriter.replaceOpWithNewOp( + op, op.getVal().getType(), atomicRmwOp, basePtr, tensorOffset, + op.getVal(), blockStride, sem, scope, maybeMask); + + return success(); + } + +private: + // Assumptions collected through the function + DenseSet assumptions; + ModuleAxisInfoAnalysis &axisAnalysisPass; + std::shared_ptr solver; +}; + +// Workaround to allow static_assert(false) on older compilers as it was +// ill-formed before defect report CWG2518 +// (https://cplusplus.github.io/CWG/issues/2518.html) +template struct always_false : std::false_type {}; + +template +struct ConvertTritonLoadToBufferLoad : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertTritonLoadToBufferLoad(mlir::MLIRContext *context, + DenseSet &assumptions, + std::shared_ptr solver) + : mlir::OpRewritePattern(context), assumptions(assumptions), + solver(std::move(solver)) {} + + mlir::LogicalResult + matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const override { + LDBG("Try to convert: " << op); + Value ptr = op.getOperand(0); + + if (canUseBufferOps(ptr, assumptions, solver)) { + auto addPtrOp = ptr.getDefiningOp(); + Value tensorPtr = addPtrOp.getPtr(); + Value tensorOffset = addPtrOp.getOffset(); + auto splatOp = tensorPtr.getDefiningOp(); + Value basePtr = splatOp.getSrc(); + Value maybeOther{}; + if (op.getOther() && !isZeroConst(op.getOther())) + maybeOther = op.getOther(); + Value maybeMask{}; + if (op.getMask() && !isZeroConst(op.getMask())) + maybeMask = op.getMask(); + Value blockStride = getBlockStride(op->getLoc(), tensorOffset, rewriter); + + auto bufferLoadOp = [&]() { + if constexpr (std::is_same_v) { + return rewriter.create( + op->getLoc(), op.getType(), basePtr, tensorOffset, blockStride, + op.getCache(), maybeMask, maybeOther); + } else if constexpr (std::is_same_v< + SourceOp, + triton::gpu::AsyncCopyGlobalToLocalOp>) { + return rewriter.create( + op->getLoc(), op.getType(), op.getResult(), basePtr, tensorOffset, + maybeMask, maybeOther, blockStride, op.getCache()); + } else { + static_assert(always_false::value, + "Unsupported type in ConvertTritonLoadToBufferLoad"); + } + }(); + + assert(bufferLoadOp); + + // Propagate `OpIdxAttr` if the currently processed `tt.LoadOp` was + // labeled it. The attribute needs to be preserved for custom instruction + // scheduling. + if (auto opIdxAttr = + op->template getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + bufferLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), + opIdxAttr); + } + rewriter.replaceOp(op, bufferLoadOp); + return success(); + } + + LDBG("Failed to convert: " << op); + return rewriter.notifyMatchFailure(op, "Failed to convert LoadOp"); + } + +private: + // Assumptions collected through the function + DenseSet assumptions; + std::shared_ptr solver; +}; + +struct ConvertTritonStoreToBufferStore + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertTritonStoreToBufferStore(mlir::MLIRContext *context, + DenseSet &assumptions, + std::shared_ptr solver) + : mlir::OpRewritePattern(context), + assumptions(assumptions), solver(std::move(solver)) {} + + mlir::LogicalResult + matchAndRewrite(triton::StoreOp op, + PatternRewriter &rewriter) const override { + LDBG("Try to convert: " << op); + Value ptr = op.getPtr(); + + if (canUseBufferOps(ptr, assumptions, solver)) { + auto addPtrOp = ptr.getDefiningOp(); + Value tensorPtr = addPtrOp.getPtr(); + Value tensorOffset = addPtrOp.getOffset(); + auto splatOp = tensorPtr.getDefiningOp(); + Value basePtr = splatOp.getSrc(); + Value maybeMask{}; + if (op.getMask() && !isZeroConst(op.getMask())) + maybeMask = op.getMask(); + Value blockStride = getBlockStride(op->getLoc(), tensorOffset, rewriter); + rewriter.replaceOpWithNewOp( + op, op.getValue(), basePtr, tensorOffset, blockStride, op.getCache(), + maybeMask); + return success(); + } + LDBG("Failed to convert: " << op); + return rewriter.notifyMatchFailure(op, "Failed to convert StoreOp"); + } + +private: + // Assumptions collected through the function + DenseSet assumptions; + std::shared_ptr solver; +}; + +class TritonAMDGPUConvertToBufferOpsPass + : public TritonAMDGPUConvertToBufferOpsBase< + TritonAMDGPUConvertToBufferOpsPass> { + +public: + TritonAMDGPUConvertToBufferOpsPass() = default; + TritonAMDGPUConvertToBufferOpsPass(StringRef archGen) { + this->archGenerationName = archGen.data(); + }; + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp mod = getOperation(); + + // Collect assumptions in the function + DenseSet assumptions; + mod.walk([&](LLVM::AssumeOp op) { + auto oper = op->getOperand(0); + if (oper.getDefiningOp()) + assumptions.insert(oper); + }); + LLVM_DEBUG({ + DBGS() << "Number of assumptions found: " << assumptions.size() << "\n"; + for (Value assume : assumptions) { + DBGS() << "Assumption:" << assume << "\n"; + } + }); + + std::shared_ptr solver = createDataFlowSolver(); + solver->load(); + if (failed(solver->initializeAndRun(getOperation()))) + return signalPassFailure(); + + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + patterns.add, + ConvertTritonLoadToBufferLoad, + ConvertTritonStoreToBufferStore>(context, assumptions, solver); + + // Gate buffer atomics behind CDNA3 (i.e., MI300 series) for now + // GFX942-specific assumptions regarding cache coherence are made when + // lowering to LLVM + if (ISAFamily::CDNA3 == triton::AMD::deduceISAFamily(archGenerationName)) + patterns.add( + context, assumptions, axisInfoAnalysis, solver); + + if (applyPatternsGreedily(mod, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +std::unique_ptr +mlir::createTritonAMDGPUConvertToBufferOpsPass(std::string archGen) { + return std::make_unique(archGen); +} diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp new file mode 100644 index 000000000..416ee581d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/HoistLayoutConversions.cpp @@ -0,0 +1,63 @@ +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +// Hoist convert_layout out of the loop if the src is defined out of the loop. +// This is a heuristic driven by optimizing fused attention kernels, in which +// we want to load Q tensor and keep it in register, instead of loading it +// (neither from global or shared memory) at every iteration of the loop. +static void hoistCvtDotOpOutOfLoop(ttg::ConvertLayoutOp cvtOp) { + // Check the dst of cvt has dotOperand layout + RankedTensorType rtType = dyn_cast(cvtOp.getType()); + if (!rtType) + return; + Attribute encoding = rtType.getEncoding(); + if (!encoding) + return; + if (!isa(encoding)) + return; + // Check the src of cvt is defined out of the loop + auto srcDefOp = cvtOp.getSrc().getDefiningOp(); + if (srcDefOp) { + scf::ForOp parentForOp = cvtOp->getParentOfType(); + if (parentForOp && !parentForOp->isAncestor(srcDefOp)) { + cvtOp->moveAfter(srcDefOp); + } + } +} + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" + +namespace { +struct TritonAMDGPUHoistLayoutConversionsPass + : public TritonAMDGPUHoistLayoutConversionsBase< + TritonAMDGPUHoistLayoutConversionsPass> { + + void runOnOperation() override { + tt::FuncOp funcOp = getOperation(); + + SmallVector cvtOps; + funcOp.walk([&](ttg::ConvertLayoutOp cvtOp) { cvtOps.push_back(cvtOp); }); + + for (auto cvtOp : cvtOps) + hoistCvtDotOpOutOfLoop(cvtOp); + } +}; +} // namespace + +std::unique_ptr mlir::createTritonAMDGPUHoistLayoutConversionsPass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp new file mode 100644 index 000000000..cf4dbe63d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp @@ -0,0 +1,296 @@ +#include "TritonAMDGPUTransforms/MfmaGroup.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/DenseMap.h" +#include + +namespace mlir { +namespace { + +//===----------------------------------------------------------------------===// +// MFMA intrinsic query key +//===----------------------------------------------------------------------===// + +// The tuple used as key to query MFMA intrinsic map. +using MfmaKey = + std::tuple; + +// Returns a key for querying an MFMA intrinsic for the given parameters. +// Updates the passed-in A/B element type to the chosen MFMA intrinsic's A/B +// element type if the chosen intrinsic is not a direct hit and will require +// emulation. +// +// This function adapts certain parameters so we can be flexible when trying +// to query with "mismatches". +MfmaKey composeMfmaKeyFor(unsigned version, unsigned mDim, unsigned nDim, + Type &aElemType, Type &bElemType, bool withScale, + bool useTF32) { + Type aET = aElemType, bET = bElemType; + Builder b(aElemType.getContext()); + if (withScale) { + assert(version == 4 && isF8F6F4(aET) && isF8F6F4(bET)); + // For MXFP types, we have the same intrinsic, which uses FP4 as the key + // in the MFMA map. So adjust to that. + aET = bET = b.getType(); + } else if (useTF32 && aET.isF32() && bET.isF32()) { + // In Triton we use fp32 with TF32 input precision to mean TF32 types. + // In the MFMA map we use the proper TF32 type. So "fix" it here. + assert(version == 3); + aET = bET = b.getType(); + } else if (version <= 3 && isa(aET) && + isa(bET)) { + // For the OCP FP8 E5M2 type, we can emulate the support for it with FP16. + aElemType = bElemType = aET = bET = b.getF16Type(); + } + return {version, mDim, nDim, aET.getTypeID(), bET.getTypeID()}; +} + +//===----------------------------------------------------------------------===// +// MFMA intrinsic map +//===----------------------------------------------------------------------===// + +using MfmaMapValue = + std::tuple; +using MfmaMap = llvm::DenseMap>; + +class MfmaDatabase { +public: + static const MfmaMap &get(MLIRContext *context) { + static MfmaDatabase db(context); + return db.mfmaMap; + } + +private: + explicit MfmaDatabase(MLIRContext *context); + + MfmaMap mfmaMap; +}; + +MfmaDatabase::MfmaDatabase(MLIRContext *context) { +// Macro for defining MFMA intrinsics at a specific gfx version. +#define TRITON_MFMA_v(v, m, n, aET, bET, symbol, k, kBase) \ + { \ + /*key=*/{v, m, n, aET.getTypeID(), bET.getTypeID()}, /*value=*/{ \ + {ROCDL::symbol::getOperationName(), k, kBase}, \ + } \ + } + +// For certain architectures, we can have two intrinsics with the same M/N but +// different K. Order matters here: case1 will be preferred to case2. +#define TRITON_MFMA_v_2case(v, m, n, aET, bET, symbol1, k1, kBase1, symbol2, \ + k2, kBase2) \ + { \ + /*key=*/{v, m, n, aET.getTypeID(), bET.getTypeID()}, /*value=*/{ \ + {ROCDL::symbol1::getOperationName(), k1, kBase1}, \ + {ROCDL::symbol2::getOperationName(), k2, kBase2}, \ + } \ + } +#define TRITON_MFMA_v4_2case(m, n, aET, bET, symbol1, k1, kBase1, symbol2, k2, \ + kBase2) \ + TRITON_MFMA_v_2case(4, m, n, aET, bET, symbol1, k1, kBase1, symbol2, k2, \ + kBase2) +#define TRITON_MFMA_v2_2case(m, n, aET, bET, symbol1, k1, kBase1, symbol2, k2, \ + kBase2) \ + TRITON_MFMA_v_2case(2, m, n, aET, bET, symbol1, k1, kBase1, symbol2, k2, \ + kBase2) + +// Macro for defining MFMA intrinsics existing in multiple gfx versions. +#define TRITON_MFMA_v1to2(m, n, aET, bET, symbol, k, kBase) \ + TRITON_MFMA_v(1, m, n, aET, bET, symbol, k, kBase), \ + TRITON_MFMA_v(2, m, n, aET, bET, symbol, k, kBase) + +#define TRITON_MFMA_v2to3(m, n, aET, bET, symbol, k, kBase) \ + TRITON_MFMA_v(2, m, n, aET, bET, symbol, k, kBase), \ + TRITON_MFMA_v(3, m, n, aET, bET, symbol, k, kBase) + +#define TRITON_MFMA_v3to4(m, n, aET, bET, symbol, k, kBase) \ + TRITON_MFMA_v(3, m, n, aET, bET, symbol, k, kBase), \ + TRITON_MFMA_v(4, m, n, aET, bET, symbol, k, kBase) + +#define TRITON_MFMA_v2to4(m, n, aET, bET, symbol, k, kBase) \ + TRITON_MFMA_v(2, m, n, aET, bET, symbol, k, kBase), \ + TRITON_MFMA_v3to4(m, n, aET, bET, symbol, k, kBase) + +#define TRITON_MFMA_v1to3(m, n, aET, bET, symbol, k, kBase) \ + TRITON_MFMA_v(1, m, n, aET, bET, symbol, k, kBase), \ + TRITON_MFMA_v2to3(m, n, aET, bET, symbol, k, kBase) + +#define TRITON_MFMA_v1to4(m, n, aET, bET, symbol, k, kBase) \ + TRITON_MFMA_v(1, m, n, aET, bET, symbol, k, kBase), \ + TRITON_MFMA_v2to4(m, n, aET, bET, symbol, k, kBase) + + Builder b(context); + auto f32T = b.getF32Type(); + auto tf32T = b.getTF32Type(); + auto f16T = b.getF16Type(); + auto bf16T = b.getBF16Type(); + auto i8T = b.getI8Type(); + auto amdFp8T = b.getType(); + auto amdBf8T = b.getType(); + auto ocpFp8T = b.getType(); + auto ocpBf8T = b.getType(); + auto fp4T = b.getType(); + + mfmaMap = { + // f32 inputs + // mfma_f32_32x32x2f32 + TRITON_MFMA_v1to4(32, 32, f32T, f32T, mfma_f32_32x32x2f32, 2, 1), + // mfma_f32_16x16x4f32 + TRITON_MFMA_v1to4(16, 16, f32T, f32T, mfma_f32_16x16x4f32, 4, 1), + // mfma_f32_4x4x1f32 / mfma_f32_4x4x1_16B_f32 + TRITON_MFMA_v1to4(4, 4, f32T, f32T, mfma_f32_4x4x1f32, 16, 1), + TRITON_MFMA_v1to4(4, 64, f32T, f32T, mfma_f32_4x4x1f32, 1, 1), + TRITON_MFMA_v1to4(64, 4, f32T, f32T, mfma_f32_4x4x1f32, 1, 1), + + // xf32 + // mfma.xf32.16x16x8xf32 + TRITON_MFMA_v(3, 16, 16, tf32T, tf32T, mfma_f32_16x16x8_xf32, 8, 2), + // mfma.xf32.32x32x4.xf32 + TRITON_MFMA_v(3, 32, 32, tf32T, tf32T, mfma_f32_32x32x4_xf32, 4, 2), + + // f16 inputs + // mfma_f32_32x32x16_f16 & mfma_f32_32x32x8f16 + TRITON_MFMA_v4_2case(32, 32, f16T, f16T, mfma_f32_32x32x16_f16, 16, 8, + mfma_f32_32x32x8f16, 8, 4), + // mfma_f32_32x32x8f16 + TRITON_MFMA_v1to3(32, 32, f16T, f16T, mfma_f32_32x32x8f16, 8, 4), + // mfma_f32_16x16x32_f16 & mfma_f32_16x16x16f16 + TRITON_MFMA_v4_2case(16, 16, f16T, f16T, mfma_f32_16x16x32_f16, 32, 8, + mfma_f32_16x16x16f16, 16, 4), + // mfma_f32_16x16x16f16 + TRITON_MFMA_v1to3(16, 16, f16T, f16T, mfma_f32_16x16x16f16, 16, 4), + // mfma_f32_4x4x4f16 + TRITON_MFMA_v1to4(4, 4, f16T, f16T, mfma_f32_4x4x4f16, 64, 4), + TRITON_MFMA_v1to4(4, 64, f16T, f16T, mfma_f32_4x4x4f16, 4, 4), + TRITON_MFMA_v1to4(64, 4, f16T, f16T, mfma_f32_4x4x4f16, 4, 4), + + // bf16 inputs + // mfma_f32_32x32x16_bf16 & mfma_f32_32x32x8_bf16_1K + TRITON_MFMA_v4_2case(32, 32, bf16T, bf16T, mfma_f32_32x32x16_bf16, 16, 8, + mfma_f32_32x32x8bf16_1k, 8, 4), + TRITON_MFMA_v(3, 32, 32, bf16T, bf16T, mfma_f32_32x32x8bf16_1k, 8, 4), + // mfma_f32_32x32x8_bf16_1K & mfma_f32_32x32x4bf16_1k + TRITON_MFMA_v2_2case(32, 32, bf16T, bf16T, mfma_f32_32x32x8bf16_1k, 8, 4, + mfma_f32_32x32x4bf16_1k, 4, 2), + // mfma_f32_16x16x32_bf16 & mfma_f32_16x16x16_bf16_1K + TRITON_MFMA_v4_2case(16, 16, bf16T, bf16T, mfma_f32_16x16x32_bf16, 32, 8, + mfma_f32_16x16x16bf16_1k, 16, 4), + TRITON_MFMA_v(3, 16, 16, bf16T, bf16T, mfma_f32_16x16x16bf16_1k, 16, 4), + // mfma_f32_16x16x16_bf16_1K & mfma_f32_16x16x8_bf16 + TRITON_MFMA_v2_2case(16, 16, bf16T, bf16T, mfma_f32_16x16x16bf16_1k, 16, + 4, mfma_f32_16x16x8bf16, 8, 2), + // mfma_f32_32x32x4_bf16 + TRITON_MFMA_v(1, 32, 32, bf16T, bf16T, mfma_f32_32x32x4bf16, 4, 2), + // mfma_f32_16x16x8_bf16 + TRITON_MFMA_v(1, 16, 16, bf16T, bf16T, mfma_f32_16x16x8bf16, 8, 2), + // mfma_f32_4x4x4_bf16_1K + TRITON_MFMA_v2to4(4, 4, bf16T, bf16T, mfma_f32_4x4x4bf16_1k, 64, 4), + TRITON_MFMA_v2to4(4, 64, bf16T, bf16T, mfma_f32_4x4x4bf16_1k, 4, 4), + TRITON_MFMA_v2to4(64, 4, bf16T, bf16T, mfma_f32_4x4x4bf16_1k, 4, 4), + // mfma_f32_4x4x2_bf16 + TRITON_MFMA_v(1, 4, 4, bf16T, bf16T, mfma_f32_4x4x2bf16, 32, 2), + TRITON_MFMA_v(1, 4, 64, bf16T, bf16T, mfma_f32_4x4x2bf16, 2, 2), + TRITON_MFMA_v(1, 64, 4, bf16T, bf16T, mfma_f32_4x4x2bf16, 2, 2), + + // fp8/bf8 inputs + // mfma_f32_32x32x16_FP8_FP8 + TRITON_MFMA_v(4, 32, 32, ocpFp8T, ocpFp8T, mfma_f32_32x32x16_fp8_fp8, 16, + 8), + TRITON_MFMA_v(3, 32, 32, amdFp8T, amdFp8T, mfma_f32_32x32x16_fp8_fp8, 16, + 8), + // mfma_f32_32x32x16_FP8_BF8 + TRITON_MFMA_v(4, 32, 32, ocpFp8T, ocpBf8T, mfma_f32_32x32x16_fp8_bf8, 16, + 8), + TRITON_MFMA_v(3, 32, 32, amdFp8T, amdBf8T, mfma_f32_32x32x16_fp8_bf8, 16, + 8), + // mfma_f32_32x32x16_BF8_FP8 + TRITON_MFMA_v(4, 32, 32, ocpBf8T, ocpFp8T, mfma_f32_32x32x16_bf8_fp8, 16, + 8), + TRITON_MFMA_v(3, 32, 32, amdBf8T, amdFp8T, mfma_f32_32x32x16_bf8_fp8, 16, + 8), + // mfma_f32_32x32x16_BF8_BF8 + TRITON_MFMA_v(4, 32, 32, ocpBf8T, ocpBf8T, mfma_f32_32x32x16_bf8_bf8, 16, + 8), + TRITON_MFMA_v(3, 32, 32, amdBf8T, amdBf8T, mfma_f32_32x32x16_bf8_bf8, 16, + 8), + // mfma_f32_16x16x32_FP8_FP8 + TRITON_MFMA_v(4, 16, 16, ocpFp8T, ocpFp8T, mfma_f32_16x16x32_fp8_fp8, 32, + 8), + TRITON_MFMA_v(3, 16, 16, amdFp8T, amdFp8T, mfma_f32_16x16x32_fp8_fp8, 32, + 8), + // mfma_f32_16x16x32_FP8_BF8 + TRITON_MFMA_v(4, 16, 16, ocpFp8T, ocpBf8T, mfma_f32_16x16x32_fp8_bf8, 32, + 8), + TRITON_MFMA_v(3, 16, 16, amdFp8T, amdBf8T, mfma_f32_16x16x32_fp8_bf8, 32, + 8), + // mfma_f32_16x16x32_BF8_FP8 + TRITON_MFMA_v(4, 16, 16, ocpBf8T, ocpFp8T, mfma_f32_16x16x32_bf8_fp8, 32, + 8), + TRITON_MFMA_v(3, 16, 16, amdBf8T, amdFp8T, mfma_f32_16x16x32_bf8_fp8, 32, + 8), + // mfma_f32_16x16x32_BF8_BF8 + TRITON_MFMA_v(4, 16, 16, ocpBf8T, ocpBf8T, mfma_f32_16x16x32_bf8_bf8, 32, + 8), + TRITON_MFMA_v(3, 16, 16, amdBf8T, amdBf8T, mfma_f32_16x16x32_bf8_bf8, 32, + 8), + + // int8 inputs + // mfma_i32_32x32x32_i8 & mfma_i32_32x32x16i8 + TRITON_MFMA_v4_2case(32, 32, i8T, i8T, mfma_i32_32x32x32_i8, 32, 16, + mfma_i32_32x32x16_i8, 16, 8), + TRITON_MFMA_v(3, 32, 32, i8T, i8T, mfma_i32_32x32x16_i8, 16, 8), + // mfma_i32_32x32x8i8 + TRITON_MFMA_v1to2(32, 32, i8T, i8T, mfma_i32_32x32x8i8, 8, 4), + // mfma_i32_16x16x64_i8 & mfma_i32_16x16x32i8 + TRITON_MFMA_v4_2case(16, 16, i8T, i8T, mfma_i32_16x16x64_i8, 64, 16, + mfma_i32_16x16x32_i8, 32, 8), + TRITON_MFMA_v(3, 16, 16, i8T, i8T, mfma_i32_16x16x32_i8, 32, 8), + // mfma_i32_16x16x16i8 + TRITON_MFMA_v1to2(16, 16, i8T, i8T, mfma_i32_16x16x16i8, 16, 4), + // mfma_i32_4x4x4i8 + TRITON_MFMA_v1to4(4, 4, i8T, i8T, mfma_i32_4x4x4i8, 64, 4), + TRITON_MFMA_v1to4(4, 64, i8T, i8T, mfma_i32_4x4x4i8, 4, 4), + TRITON_MFMA_v1to4(64, 4, i8T, i8T, mfma_i32_4x4x4i8, 4, 4), + + // Scaled mfma f8f6f4 + // mfma_scale_F32_16x16x128_F8F6F4 + TRITON_MFMA_v(4, 16, 16, fp4T, fp4T, mfma_scale_f32_16x16x128_f8f6f4, 128, + 32), + // mfma_scale_F32_32x32x64_F8F6F4 + TRITON_MFMA_v(4, 32, 32, fp4T, fp4T, mfma_scale_f32_32x32x64_f8f6f4, 64, + 32), + }; +} + +} // namespace + +//===----------------------------------------------------------------------===// +// MFMA intrinsic selection +//===----------------------------------------------------------------------===// + +FailureOr +MfmaIntrinsic::selectFor(int version, unsigned mDim, unsigned nDim, + unsigned inputKDim, Type aElemType, Type bElemType, + bool withScale, bool useTF32) { + const MfmaMap &mfmaMap = MfmaDatabase::get(aElemType.getContext()); + MfmaKey key = composeMfmaKeyFor(version, mDim, nDim, aElemType, bElemType, + withScale, useTF32); + + auto it = mfmaMap.find(key); + if (it == mfmaMap.end()) + return failure(); + + const SmallVector &values = it->second; + + // If We have more than one instrinsics, prefer those with a larger K. + for (const auto [symbol, k, kBase] : llvm::drop_end(values)) { + if (inputKDim >= k) + return MfmaIntrinsic(symbol, mDim, nDim, k, kBase, aElemType, bElemType); + } + + // We always have one choice--the only / smallest-K intrinsic. + auto [symbol, k, kBase] = values.back(); + return MfmaIntrinsic(symbol, mDim, nDim, k, kBase, aElemType, bElemType); +} +} // namespace mlir diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp new file mode 100644 index 000000000..2212fe8d5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; + +namespace { + +bool isOneOperandElementwiseOp(Operation *op) { + if (llvm::isa(op)) + return true; + if (llvm::isa(op)) + return true; + if (llvm::isa(op)) + return true; + if (auto externElementwiseOp = dyn_cast(op)) + return op->getNumOperands() == 1 && op->getNumResults() == 1 && + externElementwiseOp.getPure(); + return false; +} + +// convert(val) : xmma -> blocked +// elementWiseOp(val) : blocked +// ... +// elementWiseOp(val) : blocked +// tt.store(ptr, val, mask, ...) : blocked +// ==> +// convert(ptr) : blocked -> xmma +// convert(mask) : blocked -> xmma +// elementWiseOp(val) : xmma +// ... +// elementWiseOp(val) : xmma +// tt.store(ptr, val, mask, ...) : xmma +// +// Store with xmma layout directly +// +// xmma layout is either MFMA or WMMA +class BypassEpilogueSMEM : public mlir::RewritePattern { + +public: + explicit BypassEpilogueSMEM(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::StoreOp::getOperationName(), 1, context) {} + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + + auto stOp = dyn_cast(op); + if (!stOp) + return mlir::failure(); + Value ptr = stOp.getPtr(); + Value val = stOp.getValue(); + Value mask = stOp.getMask(); + auto ptrType = dyn_cast(ptr.getType()); + auto valType = dyn_cast(val.getType()); + if (!ptrType || !valType || + !isa(ptrType.getEncoding()) || + !isa(valType.getEncoding())) + return mlir::failure(); + + llvm::SmallVector chainedOps; + while (true) { + auto chainedOp = val.getDefiningOp(); + if (!chainedOp) + return mlir::failure(); + if (llvm::isa(chainedOp)) + break; + if (!chainedOp->hasOneUse()) + return mlir::failure(); + if (!isOneOperandElementwiseOp(chainedOp)) + return mlir::failure(); + val = chainedOp->getOperand(0); + chainedOps.push_back(chainedOp); + } + + auto cvtOp = val.getDefiningOp(); + if (!cvtOp) + return mlir::failure(); + + auto encoding = cvtOp.getSrc().getType().getEncoding(); + if (!isa(encoding)) + return mlir::failure(); + + if (!cvtOp.getResult().hasOneUse()) + return mlir::failure(); + + auto newEncoding = + cast(cvtOp.getSrc().getType()).getEncoding(); + + auto newVal = cvtOp.getSrc(); + + auto newPtrType = RankedTensorType::get( + ptrType.getShape(), ptrType.getElementType(), newEncoding); + Value newPtr = rewriter.create( + ptr.getLoc(), newPtrType, ptr); + + for (auto chainedOp : llvm::reverse(chainedOps)) { + auto oldType = + cast(chainedOp->getResult(0).getType()); + chainedOp->setOperand(0, newVal); + newVal = llvm::cast>( + chainedOp->getResult(0)); + auto newType = mlir::RankedTensorType::get( + oldType.getShape(), oldType.getElementType(), newEncoding); + newVal.setType(newType); + } + + Value newMask = mask; + if (mask) { + auto maskType = dyn_cast(mask.getType()); + auto newMaskType = RankedTensorType::get( + maskType.getShape(), maskType.getElementType(), newEncoding); + newMask = rewriter.create( + mask.getLoc(), newMaskType, mask); + } + + rewriter.replaceOpWithNewOp( + stOp, newPtr, newVal, newMask, stOp.getCache(), stOp.getEvict()); + return mlir::success(); + } +}; + +} // namespace + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" + +class TritonAMDGPUOptimizeEpiloguePass + : public TritonAMDGPUOptimizeEpilogueBase< + TritonAMDGPUOptimizeEpiloguePass> { + +public: + TritonAMDGPUOptimizeEpiloguePass() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::RewritePatternSet patterns(context); + + patterns.add(context); + + if (applyPatternsGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + } +}; + +std::unique_ptr mlir::createTritonAMDGPUOptimizeEpiloguePass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp new file mode 100644 index 000000000..23f1a5c21 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -0,0 +1,365 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/PassManager.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +namespace ttg = mlir::triton::gpu; + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +// Return true if the given funcOp is a pure matmul problem; i.e., +// a single main loop with a single dot. +static bool isPureMatmulFunc(triton::FuncOp funcOp) { + bool isMatmul = true; + bool foundLoop = false; + funcOp.walk([&](scf::ForOp forOp) -> void { + int counter = 0; + forOp.walk([&counter](triton::DotOp dotOp) { ++counter; }); + isMatmul = (isMatmul && (counter == 1)); + foundLoop = true; + }); + return foundLoop && isMatmul; +} + +// Return true if the given ForOp contains a pure matmul problem; i.e., +// single dot and at least 2 glboal loads in the main loop. +static bool isPureMatmulLoop(scf::ForOp forOp) { + int dotCounter = 0; + int loadCounter = 0; + forOp.walk([&](Operation *op) { + if (isa(op)) + ++dotCounter; + else if (isa(op)) + ++loadCounter; + }); + return dotCounter == 1 && loadCounter >= 2; +} + +// Search through block to find earliest insertion point for move op. This can +// be either an atomic op or last usage of source pointer. Search ends when move +// op is encountered. +static llvm::ilist::iterator +findEarlyInsertionPoint(Block *block, Operation *move) { + Value src; + if (auto ld = dyn_cast(move)) + src = ld.getPtr(); + + auto ipnt = block->end(); + for (auto bi = block->begin(); bi != block->end(); ++bi) { + auto *op = &*bi; + if (op == move) // Don't move later than current location + break; + + op->walk([&](Operation *wop) { + if (src) { + // Check for ops accessing src value. + for (auto opr : wop->getOperands()) { + if (opr == src) + ipnt = bi; + } + } + // Atomics used for global synchronization. + if (isa(wop)) + ipnt = bi; + // Break at barrier + if (isa(wop)) + ipnt = bi; + // Break at loops. + if (isa(wop)) + ipnt = bi; + }); + } + return ipnt; +} + +// Return the first user in the same block of the given op. If the user is in a +// nested block then return the op owning the block. Return nullptr if not +// existing. +static Operation *getFirstUseInSameBlock(Operation *op) { + SmallVector usersInSameBlock; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + usersInSameBlock.push_back(ancestor); + } + auto minOpIt = + llvm::min_element(usersInSameBlock, [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != usersInSameBlock.end() ? *minOpIt : nullptr; +} + +// Check if the operation opInsideLoop is inside any scf::ForOp and +// opOutsideLoop is not inside the same loop. +static bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, + mlir::Operation *opOutsideLoop) { + scf::ForOp parentForOp = opInsideLoop->getParentOfType(); + return parentForOp && !parentForOp->isAncestor(opOutsideLoop); +} + +//===----------------------------------------------------------------------===// +// Reorder mechanisms +//===----------------------------------------------------------------------===// + +// Sink dot layout conversions into loops to decrease register pressure when +// possible. +static void sinkDotConversion(triton::FuncOp funcOp) { + DenseMap opToMove; + funcOp.walk([&](ttg::ConvertLayoutOp op) { + Attribute encoding = op.getType().getEncoding(); + if (!isa_and_nonnull(encoding)) + return; + if (!op->hasOneUse()) + return; + Operation *user = *op->getUsers().begin(); + if (user->getParentOfType() == + op->getParentOfType()) + return; + opToMove[op] = user; + }); + + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); +} + +// Sink conversion after the last dealloc but before the first use in its block. +// This helps to avoid unnecessary shared memory allocation. +static void moveDownCoversion(triton::FuncOp funcOp) { + SmallVector convertOps; + funcOp.walk([&](ttg::ConvertLayoutOp op) { convertOps.push_back(op); }); + + for (auto op : convertOps) { + Operation *user = getFirstUseInSameBlock(op); + for (auto it = Block::iterator(op), ie = op->getBlock()->end(); + it != ie && &*it != user; ++it) + if (isa(&*it)) + op->moveAfter(&*it); + } +} + +// Move transpositions just after their definition. +static void moveUpTranspose(triton::FuncOp funcOp) { + SmallVector transOps; + funcOp.walk([&](triton::TransposeOpInterface op) { transOps.push_back(op); }); + + for (auto op : transOps) + if (Operation *argOp = op.getSrc().getDefiningOp()) + op->moveAfter(argOp); +} + +// Schedule global load and local store ops for better GEMM performance. +static void scheduleGlobalLoadLocalStore(Operation *parentOp) { + SmallVector moveOps; + + // Search through the forOp initArgs to find global loads for a GEMM that + // the pipeliner may have peeled into a loop prologue. + if (auto forOp = dyn_cast(parentOp)) { + SmallVector vals = forOp.getInitArgs(); + while (!vals.empty()) { + SmallVector nextVals; // Next set of values to search via BFS. + for (size_t i = 0; i < vals.size(); ++i) { + Operation *defOp = vals[i].getDefiningOp(); + if (isa_and_nonnull(defOp)) { + moveOps.push_back(defOp); + continue; + } + + // Find uses of the op that are local_store + for (Operation *op : vals[i].getUsers()) { + if (auto storeOp = dyn_cast(op)) { + // Recurse on operands of the local_store (to find a global_load). + nextVals.push_back(storeOp.getSrc()); + } + } + } + vals.swap(nextVals); + } + } + + // Move local_store ops inside the loop early if dependence distance greater + // than one iteration (i.e., num_stages > 2). For such case, better perf on + // GEMM when local_store ops precede global loads. + parentOp->walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); + // Move global_load ops inside the loop early to prefetch. This may increase + // register pressure but it enables issuing global loads early. + parentOp->walk([&](triton::LoadOp op) { moveOps.push_back(op); }); + + for (auto op : llvm::reverse(moveOps)) { + // Gather use-def chain in block. + Block *block = op->getBlock(); + bool leadsToLoad = false; + SetVector backwardSet; + + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.inclusive = false; + // Slice should inlcude values flowing into op regions + options.omitUsesFromAbove = false; + options.filter = [&](Operation *defOp) -> bool { + Block *defBlock = defOp->getBlock(); + if (!block->findAncestorOpInBlock(*defOp)) + return false; + + // Check for a `load` dependent path. + leadsToLoad |= isa(defOp); + // Only move ops residing in the same block. + return defBlock == block; + }; + mlir::getBackwardSlice(op, &backwardSet, options); + backwardSet.insert(op); + + // Don't move a local_store if its source is a load from + // the same iteration. + if (isa(op) && leadsToLoad) + continue; + + auto ipoint = findEarlyInsertionPoint(block, op); + // Remove ops that already precede the insertion point. This is done + // before moves happen to avoid `Operation::isBeforeInBlock` N^2 + // complexity. + + SmallVector dfg = backwardSet.takeVector(); + if (ipoint != block->end()) { + // Move ops to insertion point. + llvm::erase_if( + dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveAfter(block, ipoint); + } else { + // Move ops to block begin. + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveBefore(block, block->begin()); + } + } +} + +//===-------------------------------------------------------------------===// +// Sched-load optimization for matmul kernels with large tile sizes +// The basic idea of sched-load optimization is to sink the 2nd tt.load +// after local_load so that global_load instructions can be interleaved with +// mfma's. This can help hide the issue latency of global_load instructions +// and improve performance on MI300X. +// +// It's assumed that the IR before this optimization has the following +// structure: +// ```mlir +// scf.for .. +// { +// tileA = tt.load a_ptr +// tileB = tt.load b_ptr +// opA = local_load bufferA +// opB = local_load bufferB +// res = tt.dot opA, opB +// local_store tileA, bufferA +// local_store tileB, bufferB +// } +// ``` +// After this optimization, the IR is transformed to +// ```mlir +// scf.for .. +// { +// tileA = tt.load a_ptr +// opA = local_load bufferA +// opB = local_load bufferB +// tileB = tt.load b_ptr <-- 2nd tt.load is sinked here +// res = tt.dot opA, opB +// local_store tileA, bufferA +// local_store tileB, bufferB +// } +// ``` +// For now, we don't have a perfect hueristic about when should this +// optimization be applied. Therefore, we implement a simple hueristic that +// this is applied when the tile size of A and B are large enough, i.e. +// nonKDim >= 128 and kDim >= 64. And also this is only applied for typical +// matmul kernels, i.e. only two tt.load's and one dotOp inside the loop. We +// are experimenting how to better control instruction scheduling and enable +// such optimizations. +//===-------------------------------------------------------------------===// +static void sinkSecondLoad(scf::ForOp forOp) { + SetVector loadOps; + triton::DotOp dotOp; + for (Operation &op : forOp) { + if (auto loadOp = dyn_cast(&op)) + loadOps.insert(loadOp); + if (auto curOp = dyn_cast(&op)) + dotOp = curOp; + } + // Only apply the optimization when there are 2 load's in the loop + if (loadOps.size() != 2) + return; + + auto ldAOp = loadOps[0]; + auto loadAType = dyn_cast(ldAOp.getType()); + auto ldBOp = loadOps[1]; + auto loadBType = dyn_cast(ldBOp.getType()); + // Only apply the optimization when loading a 2D tensor + if (!loadAType || !loadBType) + return; + auto tileAShape = loadAType.getShape(); + auto tileBShape = loadBType.getShape(); + if (tileAShape.size() != 2 || tileBShape.size() != 2) + return; + // Only apply the optimization when tile size is large enough + // 1. nonKDim >= 128 + // 2. kDim >= 64 + if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128)) + return; + // Only apply the optimization when the moving is legal + // 1. Make sure the 2nd loadOp is before the dot + // 2. Make sure the first user of the 2nd loadOp is after the dot. + bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp); + auto firstUser = *ldBOp.getResult().getUsers().begin(); + bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser); + if (isBeforeDotOp && firstUserAfterDotOp) + // move ldBOp right before tt.dot + ldBOp->moveBefore(dotOp); +} + +//===----------------------------------------------------------------------===// +// Pass definition +//===----------------------------------------------------------------------===// + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +namespace { +struct TritonAMDGPUReorderInstructionsPass + : public TritonAMDGPUReorderInstructionsBase< + TritonAMDGPUReorderInstructionsPass> { + void runOnOperation() override { + ModuleOp m = getOperation(); + for (auto funcOp : m.getOps()) { + sinkDotConversion(funcOp); + moveDownCoversion(funcOp); + + moveUpTranspose(funcOp); + + if (isPureMatmulFunc(funcOp)) { + scheduleGlobalLoadLocalStore(funcOp); + funcOp.walk([&](scf::ForOp forOp) -> void { sinkSecondLoad(forOp); }); + } else { + SmallVector leafForOps = triton::AMD::getLeafForOps(funcOp); + for (auto forOp : leafForOps) { + if (isPureMatmulLoop(forOp)) { + scheduleGlobalLoadLocalStore(forOp); + sinkSecondLoad(forOp); + } + } + } + } + } +}; +} // namespace + +std::unique_ptr mlir::createTritonAMDGPUReorderInstructionsPass() { + return std::make_unique(); +} diff --git a/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp new file mode 100644 index 000000000..47ba991d1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -0,0 +1,943 @@ +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create stream operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop and epilogue. +//===----------------------------------------------------------------------===// + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" + +#define DEBUG_TYPE "tritonamdgpu-stream-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +static Operation *streamPredication(RewriterBase &rewriter, Operation *op, + Value pred) { + // The epilogue peeling generates a select for the stage output. This causes + // too much register pressure with the loop result and the epilogue-dot in + // regs for the select. Conditionally executing the dot will allow the backend + // to optimize the select away as redundant. + if (auto dotOp = dyn_cast(op)) { + auto loc = dotOp->getLoc(); + auto ifOp = rewriter.create(loc, dotOp.getResult().getType(), + pred, /*withElseRegion=*/true); + auto thenB = ifOp.getThenBodyBuilder(); + auto yield = thenB.create(loc, dotOp.getResult()); + dotOp->moveBefore(yield); + ifOp.getElseBodyBuilder().create(loc, dotOp.getC()); + return ifOp; + } + return tt::predicateOp(rewriter, op, pred); +} + +namespace { + +//===----------------------------------------------------------------------===// +// Software pipelining generally works by anchoring on global load ops in the +// main loop and rotating the loop to schedule global load ops for future loop +// iterations together with compute for the current iteration. In this way, we +// can 1) issue memory operations earlier to hide the latency and 2) break the +// strong dependency inside on loop iteration to give backends flexiblity to +// better interleave instructions for better instruction-level parallelism. +// +// This StreamPipeliner class creates the pipelining schedule and calls the +// PipelineExpander to rewrite the `scf.for` loop accordingly. A schedule +// consists of multiple stages, where ops from different stages can overlap +// executions because the dependencies are loop carried. +// +// The general flow of this process is: +// +// 1. The user provides a `num_stages` that specifies how many stages the +// pipeline will have. The number of stages must be larger than the distance +// from the first independent load to the compute in order to pipeline. +// 1.a. User may also specify `global_prefetch=` to set the number of +// stages between tt.load and ttg.local_store ops. +// 1.b. User may also specify `local_prefetch=` to set the number of +// stages between ttg.local_load and compute. +// 2. A schedule is created based on the distance between the global loads +// in the first stages and the compute that uses the loaded values in the +// last stage (num_stages - 1). Each operation will be clustered in the +// order to best overlap with other operations (see details below in the +// initSchedule method). +// 3. When the compute is a tt.dot, the scheduler will insert a shared +// memory allocation between the global load and tt.dot. The ttg.local_store +// will save the global load value to shared memory and the ttg.local_load +// will load the relevant tiles for the tt.dot. These operations will be +// scheduled according to various scheduling schemes outlined below in the +// initSchedule method (see details there). +// 4. Finally the schedule will be passed to the PipelineExpander to rewrite +// accordingly. The new implementation will consist of: +// a. Prologue: containing the ramp-up of num_stages-1 stages for +// iteratorions i=[0, num_stages-1). +// b. New loop: ordered by cluster and iterated on each operation by +// `i + (num_stages-op_stage)`. +// c. Epilogue: ramp-down of the last `num_stages-1` iterations for the +// ops in stages 1 to last_stage. This must consider that the loop +// bounds may be shorter than num_stages. In this case, the epilogue +// iterations must align with the prologue. +// +class StreamPipeliner { + // Define categories of scheduling details per Operation types. + // The StreamPipeliner schedules 4 types of operations: + // 1. GLOBAL_LOAD: tt.load + // 2. LOCAL_STORE: ttg.local_store (created by the StreamPipeliner) + // 3. LOCAL_LOAD: ttg.local_load (created by the StreamPipeliner) + // 4. COMPUTE: ops that use the loaded data + enum SchedType { + SCHED_GLOBAL_LOAD, + SCHED_LOCAL_STORE, + SCHED_LOCAL_LOAD, + SCHED_COMPUTE, + SCHED_SIZE + }; + +public: + StreamPipeliner(scf::ForOp _forOp, int _numStages, int _globalPrefetch, + int _localPrefetch) + : forOp(_forOp), numBuffers(1), numStages(_numStages), + schedule(numStages), + axisInfoAnalysis(forOp->getParentOfType()) { + int lastStage = numStages - 1; + stages[SCHED_GLOBAL_LOAD] = 0; + stages[SCHED_LOCAL_STORE] = _globalPrefetch; + stages[SCHED_LOCAL_LOAD] = lastStage - _localPrefetch; + stages[SCHED_COMPUTE] = lastStage; + + options.supportDynamicLoops = true; + options.peelEpilogue = true; + options.predicateFn = streamPredication; + } + + LogicalResult pipelineLoop(); + +private: + LogicalResult initSchedule(int maxIndirectionLevel); + + void computeLoadOpsToIndirectionLevelAndUse(); + void assignMemoryLayouts(); + LogicalResult scheduleLoads(DenseSet &rootUsers); + void scheduleDependencies(); + void scheduleDistanceOneDependencies(); + void scheduleRemainingToLastStage(); + + LogicalResult preprocessLoopAndBuildSchedule(); + + Value createAlloc(Operation *loadOp, + ttg::SwizzledSharedEncodingAttr sharedEnc); + void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx); + void createStreamOps(); + + void scheduleOp(Operation *op, SchedType type, int stage = -1) { + if (stage < 0) + stage = stages[type]; + schedule.insert(op, stage, clusters[type]); + } + +private: + // Data members + scf::ForOp forOp; + + // User settings + int numStages; + + // Computed number of buffers + int numBuffers; + + // Stage for each SchedType Op + int stages[SCHED_SIZE]; + // Cluster for each SchedType Op + std::array clusters; + + // Scheduling clusters + tt::CoarseSchedule schedule; + + // Mapping and indirection level for each `tt.load` to its use. + SmallVector> loadOpToIndLevelAndUse; + + struct LoadInfo { + // Shared layout is used for loads feeding into dot ops. + ttg::SwizzledSharedEncodingAttr sharedEncoding = nullptr; + // The distance of this load's stage to its use' stage. + int distToUse = 0; + bool usedByDot = false; + }; + + // Mapping for each pipelined load to scheduling details. + llvm::MapVector loadToInfo; + + // Lookup alignment/contiguity mappings for the current module. + tt::ModuleAxisInfoAnalysis axisInfoAnalysis; + + // Capture list of new shared memory buffers. + SmallVector sharedMemAllocs; + + // Pipelining options for the PipelineExpander + tt::PipeliningOption options; +}; + +} // namespace + +// Init Schedule Config based on settings and loop characteristics. +// Create clusters in order of ops in loop. This can interleave ops +// from different stages in the same cluster to achieve better backend +// scheduling. +// WARNING: Changing the order of schedule.clusters.newAtBack() calls +// can cause invalid schedules to be produced. +LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { + + bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0; + stages[SCHED_LOCAL_STORE] += maxIndirectionLevel; + + LDBG( + "Stage schedule:" << " GLOBAL_LOAD stage = " << stages[SCHED_GLOBAL_LOAD] + << ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE] + << ", LOCAL_LOAD stage = " << stages[SCHED_LOCAL_LOAD] + << ", COMPUTE stage = " << stages[SCHED_COMPUTE] + << "; total = " << numStages); + + if (stages[SCHED_LOCAL_STORE] >= numStages || + stages[SCHED_LOCAL_STORE] > stages[SCHED_LOCAL_LOAD]) { + LDBG("Invalid stage schedule"); + return failure(); + } + + // Calculate the number of buffers needed for each load. + // TODO: Use the precise number of buffers needed by the particular load. + numBuffers = + std::max(1, stages[SCHED_LOCAL_LOAD] - stages[SCHED_LOCAL_STORE]); + LDBG("deduced max shared memory buffer number = " << numBuffers); + + // If tt.load and ttg.local_store are in the same stage + // spread them apart to allow overlap with compute + // else + // Initiate ttg.local_store before tt.load + int globalLoadCluster = 0; + int localStoreCluster = 2; + if (!pairedGlobalLoadLocalStore) { + globalLoadCluster = 2; + localStoreCluster = 1; + } + + // If ttg.local_load and ttg.local_store are in the same stage + // spread them apart to allow overlap with compute + // else if they share the buffer + // ttg.local_load must come first + // else + // schedule ttg.local_load in the middle + int localLoadCluster = globalLoadCluster; + if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_LOCAL_STORE]) { + localLoadCluster = std::max(2, localStoreCluster + 1); + } else if (numBuffers == 1 && localLoadCluster >= localStoreCluster) { + // For 1 buffer, ttg.local_load must occur before ttg.local_store + localLoadCluster = localStoreCluster - 1; + } + + // Schedule compute with ttg.local_load if paired + // otherwise, schedule in the middle + int computeCluster = 1; + if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_COMPUTE]) { + computeCluster = localLoadCluster; + } + + // Make assignments + std::array clusterVec = { + schedule.clusters.newAtBack(), schedule.clusters.newAtBack(), + schedule.clusters.newAtBack(), schedule.clusters.newAtBack()}; + + clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster]; + clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster]; + clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster]; + clusters[SCHED_COMPUTE] = clusterVec[computeCluster]; + + LDBG("Cluster schedule:" << " GLOBAL_LOAD cluster = " << globalLoadCluster + << ", LOCAL_STORE cluster = " << localStoreCluster + << ", LOCAL_LOAD cluster = " << localLoadCluster + << ", COMPUTE cluster = " << computeCluster + << "; total = " << SCHED_SIZE); + + return success(); +} + +void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, + Value extractIdx) { + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); + + ttg::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + Operation *copy = builder.clone(*loadOp); + + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + auto subviewTy = ttg::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + // Clean up old local caches. + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + triton::replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) + alloc.erase(); + + // Prefetch load ahead of the dot stage if is used by the dot. + auto storeOp = + builder.create(loc, copy->getResult(0), viewLoad); + scheduleOp(viewLoad, SCHED_LOCAL_STORE); + scheduleOp(storeOp, SCHED_LOCAL_STORE); + + // Create local load + auto sharedLoad = + builder.create(loc, loadOp.getType(), viewLoad); + Value result = sharedLoad.getResult(); + if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE]) + scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); + + // If the currently processed `LoadOp` is labeled with an index regarding + // to which `DotOp` operand the corresponding data belongs to, then label the + // expanded `LocalStoreOp` with the same index. This is required for + // instruction scheduling hints to correctly count the emitted `ds_write` + // instructions for each GEMM tile. + if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) { + storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); + } + + loadOp->replaceAllUsesWith(ValueRange{result}); + + if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] && result.hasOneUse()) { + if (auto cvt = dyn_cast(*result.getUsers().begin())) + scheduleOp(cvt, SCHED_LOCAL_LOAD); + } + + loadOp.erase(); +} + +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return true and get the shared encoding that +// needs to be used to be compatible with users' layouts. +static std::optional +getSharedEncIfAllUsersAreDotEnc(Value val) { + ttg::SwizzledSharedEncodingAttr attr; + for (Operation *user : val.getUsers()) { + ttg::SwizzledSharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = cast(memDesc.getEncoding()); + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + auto dotOpEnc = dyn_cast( + cast(user->getResult(0).getType()) + .getEncoding()); + if (!dotOpEnc) + return std::nullopt; + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + SmallVector sharedOrder; + int rank = order.size(); + // TODO rework this when shared -> dotOperand conversions support + // arbitrary shared memory ordering + if (rank == 3) { + // Move the batch dimension (dim #0) to be the last so that it will be + // the slowest varying dimension. + for (unsigned i = 0; i < rank; ++i) + if (order[i] != 0) + sharedOrder.emplace_back(order[i]); + sharedOrder.emplace_back(0); + } else { + sharedOrder = order; + } + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + val.getContext(), dotOpEnc, srcTy.getShape(), sharedOrder, CTALayout, + bitWidth, /*needTrans=*/false); + } + // Check that the shared encodings needed by the users are compatible. + if (!tempAttr || (attr != nullptr && attr != tempAttr)) + return std::nullopt; + attr = tempAttr; + } + return attr; +} + +// Create a map from load ops to their indirection levels and the final uses +// of the load op (another load op, or a dot op). +// +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +void StreamPipeliner::computeLoadOpsToIndirectionLevelAndUse() { + DenseSet seen; + + // Recursively visit the given op and its operands to discover all load ops + // and collect their indirection levels and uses. + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + // Skip previously visited load ops. + if (!seen.insert(op).second) + return; + + if (isa(op)) { + // TODO: What if there are multiple uses at different distances? + loadOpToIndLevelAndUse.emplace_back(op, distance, use); + use = op; + ++distance; + } + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); + } + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + continue; + seen.clear(); + dfs(&op, 0, &op); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); + } + } +} + +// Goes through all load ops to identify those that can be pipelined and assign +// layout to them. +void StreamPipeliner::assignMemoryLayouts() { + for (auto &[op, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(op)) + // TODO: We'd need to verify that the distance is the same. + continue; + + LoadInfo loadInfo; + auto loadOp = cast(op); + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) { + LDBG("Skip non-tensor load " << *loadOp); + continue; + } + + auto pointeeTy = + cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * pointeeTy.getIntOrFloatBitWidth(); + + // Limit shared memory sharing to width >= 32 elements. + LDBG("Load " << *loadOp << " has width " << width); + if (width < 32) { + LDBG("Skip width<32 load " << *loadOp); + continue; + } + + if (isa(use)) { + // Only use shared memory when feeding into a dot op. + loadInfo.usedByDot = true; + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + } else if (auto useOp = dyn_cast(use)) { + // The use of this loadOp is another loadOp. If the use is not in the + // loadToInfo already, it means that the use is not valid for pipelining + // for some reason. We should skip this loadOp, too. + // + // Note that we have an assumption that the use of this loadOp has already + // be processed in a previous loop iteration. This assumption is held by + // how loadOpsToIndirectionLevelAndUse recursively collects + // loadOpToIndLevelAndUse using DFS. + if (loadToInfo.count(useOp) == 0) { + continue; + } + } + + loadToInfo[op] = loadInfo; + } +} + +LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + computeLoadOpsToIndirectionLevelAndUse(); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return failure(); + + // Check which loads are good for pipelining, and assign them memory layouts. + assignMemoryLayouts(); + if (loadToInfo.empty()) + return failure(); + + // Filter out load ops that cannot be pipelined. + int resize = 0; + for (int i = 0, e = loadOpToIndLevelAndUse.size(); i < e; ++i) { + auto [loadOp, distance, use] = loadOpToIndLevelAndUse[i]; + if (loadToInfo.count(loadOp) != 0) + loadOpToIndLevelAndUse[resize++] = loadOpToIndLevelAndUse[i]; + } + loadOpToIndLevelAndUse.resize(resize); + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + + LDBG("maxIndirectionLevel = " << maxIndirectionLevel); + if (maxIndirectionLevel >= numStages) + return failure(); + + if (failed(initSchedule(maxIndirectionLevel))) + return failure(); + + // The stage gap between chained loads--this allows us to "spread" loads + // with a non-one step in case the number of stages given by the user is + // large. + assert(numStages >= 2 && "requires num_stages=2 at least"); + unsigned stagesBetweenLoads = + llvm::divideCeil(numStages - 2, maxIndirectionLevel + 1); + LDBG("stagesBetweenLoads = " << stagesBetweenLoads); + + // Put the root uses of the loads in the last stage. + for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). + if (!isa(use)) { + scheduleOp(use, SCHED_COMPUTE); + rootUsers.insert(use); + } + } + + // Assign stages to the loads. + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + scheduleOp(loadOp, SCHED_GLOBAL_LOAD, stage); + } + + // Calculate distance from the load to the use. + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } + + LLVM_DEBUG({ + LDBG("Chosen loads to pipeline:"); + for (const auto &[load, info] : loadToInfo) { + LDBG(" - load: " << *load); + LDBG(" distToUse: " << info.distToUse); + LDBG(" usedByDot: " << info.usedByDot); + } + }); + + return success(); +} + +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +void StreamPipeliner::scheduleDependencies() { + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. + for (int stage = 0; stage < numStages; ++stage) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, false); + } + } +} + +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +void StreamPipeliner::scheduleDistanceOneDependencies() { + auto getNestedOperands = [](Operation *op) { + SmallVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) + operands.push_back(operand); + } + }); + return operands; + }; + + // Mapping from the cluster to the cluster before it. + DenseMap + dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + auto arg = dyn_cast(operand); + if (!arg || arg.getArgNumber() == 0 || arg.getOwner() != op.getBlock()) + continue; + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (!defOp || schedule.count(defOp) != 0) + continue; + if (isa(defOp)) { + // Exception: schedule loads with a distance of 1 together with the + // current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], true); + } + } + } +} + +void StreamPipeliner::scheduleRemainingToLastStage() { + int lastStage = numStages - 1; + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + auto cluster = clusters[SCHED_COMPUTE]; + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + opToCluster[&op] = cluster; + } + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == lastStage) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + tt::CoarseSchedule::Cluster userCluster = opToCluster[user]; + tt::CoarseSchedule::Cluster opCluster = schedule[op].second; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, lastStage, cluster); + } +} + +// Create an allocation that can hold distance number of loadOp shapes. +Value StreamPipeliner::createAlloc(Operation *loadOp, + ttg::SwizzledSharedEncodingAttr sharedEnc) { + OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + auto ty = cast(loadOp->getResultTypes()[0]); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), numBuffers); + Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(), + sharedEnc, sharedMemorySpace, + /*mutableMemory=*/true); + auto alloc = builder.create(loadOp->getLoc(), memdescType); + sharedMemAllocs.push_back(alloc); + return alloc; +} + +// Convert load ops into shared memory allocation loads and apply +// multi-buffering based on the required number of buffers. +void StreamPipeliner::createStreamOps() { + SmallVector> loadToAllocs; + for (auto &[loadOp, info] : loadToInfo) { + if (!info.sharedEncoding) + continue; + + Value alloc = createAlloc(loadOp, info.sharedEncoding); + assert(alloc && "Failed to create alloc for the async load."); + loadToAllocs.emplace_back(loadOp, alloc); + } + + IRRewriter builder(forOp.getContext()); + builder.setInsertionPoint(forOp); + + Location loc = forOp.getLoc(); + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value extractIdx = minusOne; + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, {extractIdx}); + forOp.erase(); + forOp = newForOp; + + // Create one counter for the extract indices to avoid creating long + // live range. + extractIdx = newForOp.getBody()->getArgument(newOperandIndex); + + builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + + // Create stream copies. + for (auto &[op, alloc] : loadToAllocs) { + if (auto loadOp = dyn_cast(op)) + createStreamCopy(loadOp, alloc, extractIdx); + } + // Patch the yield with the updated counters. + appendToForOpYield(forOp, {extractIdx}); +} + +LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() { + // Schedule the loads and root ops (dot ops) in the loop. This will give us + // a scaffold for the final schedule. + DenseSet rootUsers; + if (failed(scheduleLoads(rootUsers))) + return failure(); + if (loadToInfo.empty()) + return failure(); + + LLVM_DEBUG({ + LDBG("Coarse schedule loads only:"); + schedule.dump(); + }); + + // Convert the loads into shared memory allocations and loads from them. + createStreamOps(); + + scheduleDependencies(); + LLVM_DEBUG({ + LDBG("Coarse schedule with dependencies:"); + schedule.dump(); + }); + + scheduleDistanceOneDependencies(); + LLVM_DEBUG({ + LDBG("Coarse schedule with dist 1:"); + schedule.dump(); + }); + + scheduleRemainingToLastStage(); + LLVM_DEBUG({ + LDBG("Final coarse schedule:"); + schedule.dump(); + }); + + // Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> coarseSchedule = + schedule.createFinalSchedule(forOp); + + // Fill out the pipeline options. + options.getScheduleFn = + [coarseSchedule](scf::ForOp, + std::vector> &s) { + s = std::move(coarseSchedule); + }; + + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + // Explicitly deallocate created allocations. + for (auto alloc : sharedMemAllocs) + builder.create(forOp.getLoc(), alloc); + + return success(); +} + +LogicalResult StreamPipeliner::pipelineLoop() { + if (failed(preprocessLoopAndBuildSchedule())) + return failure(); + LDBG("Loop before sending to expander:\n" << *forOp); + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + return tt::pipelineForLoop(rewriter, forOp, options); +} + +// Return true if the preconditions for pipelining the loop are met. +static bool checkPrecondition(scf::ForOp forOp) { + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { return !operand.getDefiningOp(); })) + return false; + + auto hasInvalidOp = [forOp](Operation *op) { + // Don't pipeline outer loops. + if (op != forOp && isa(op)) + return WalkResult::interrupt(); + // Don't pipeline loops with barriers. + if (isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }; + return !forOp->walk(hasInvalidOp).wasInterrupted(); +} + +namespace { +// Go through a single use chain to get the result of the target op after all +// unary ops - e.g., `convert_layout`, `fp_to_fp`, etc. +template Operation *passPrevUnaryOps(Value value) { + auto getNextUnaryOps = [](Value value) -> Operation * { + if (auto defOp = value.getDefiningOp()) { + if ((defOp->getNumOperands() == 1) || llvm::dyn_cast(defOp)) + return defOp; + } + return nullptr; + }; + + auto unaryOp = getNextUnaryOps(value); + while (unaryOp) { + if (llvm::dyn_cast(unaryOp)) + return unaryOp; + unaryOp = getNextUnaryOps(unaryOp->getOperand(0)); + } + return nullptr; +} + +// Annotate each `tt.LoadOp` instruction with its corresponding gemm operand +// index. Note, this is a part of the instruction scheduling routine. Currently, +// we support `forOp`s which contain only a single `tt.DotOp` in the bodies. +void labelLoadOpsForTritonDot(scf::ForOp forOp) { + mlir::MLIRContext *ctx = forOp->getContext(); + if (auto dotOp = triton::getSingleDotOpIfExists(forOp)) { + for (auto [opIdx, dotOperand] : llvm::enumerate(dotOp->getOperands())) { + if (auto loadOp = passPrevUnaryOps(dotOperand)) { + auto opIdxAttr = triton::amdgpu::OpIdxAttr::get(ctx, opIdx); + loadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + } + } + } +} + +struct PipelinePass : public TritonAMDGPUStreamPipelineBase { + PipelinePass() = default; + PipelinePass(int32_t _numStages, int32_t _globalPrefetch, + int32_t _localPrefetch) { + this->numStages = _numStages; + + this->globalPrefetch = _globalPrefetch; + this->localPrefetch = _localPrefetch; + } + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + // check numStages + if (globalPrefetch < 0 || globalPrefetch >= numStages) { + moduleOp.emitError("global prefetch control must be in [0, ") + << numStages << "); " << globalPrefetch << " is out of range"; + return signalPassFailure(); + } + + if (localPrefetch < 0 || localPrefetch >= numStages) { + moduleOp.emitError("local prefetch control must be in [0, ") + << numStages << "); " << localPrefetch << " is out of range"; + return signalPassFailure(); + } + + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + labelLoadOpsForTritonDot(forOp); + // Bail out for loops with num_stage <= 1. + if (getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + for (scf::ForOp forOp : loops) { + if (!checkPrecondition(forOp)) + continue; + StreamPipeliner sp(forOp, getNumStagesOrDefault(forOp), globalPrefetch, + localPrefetch); + if (failed(sp.pipelineLoop())) + continue; + } + } + +private: + int getNumStagesOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists, otherwise use the + // global control. + if (auto attr = forOp->getAttrOfType(tt::kNumStagesAttrName)) + return attr.getInt(); + return numStages; + } +}; +} // namespace + +std::unique_ptr +mlir::createTritonAMDGPUStreamPipelinePass(int numStages, int globalPrefetch, + int localPrefetch) { + return std::make_unique(numStages, globalPrefetch, + localPrefetch); +} diff --git a/third_party/enflame/include/triton/third_party/amd/python/test/address_sanitizer_helper.py b/third_party/enflame/include/triton/third_party/amd/python/test/address_sanitizer_helper.py new file mode 100644 index 000000000..a40937677 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/python/test/address_sanitizer_helper.py @@ -0,0 +1,33 @@ +import torch +import triton +import triton.language as tl + +size = 4096 +x = torch.rand(size, device='cuda') +y = torch.rand(size, device='cuda') +output = torch.empty_like(x) +n_elements = output.numel() +grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + #Set access to go out of bounds for ASAN test + offsets = block_start + tl.arange(0, BLOCK_SIZE) + 1 + x = tl.load(x_ptr + offsets) + y = tl.load(y_ptr + offsets) + output = x + y + tl.store(output_ptr + offsets, output) + + +pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) +amdgcn = pgm.asm['amdgcn'] +print(amdgcn) diff --git a/third_party/enflame/include/triton/third_party/amd/python/test/test_address_sanitizer.py b/third_party/enflame/include/triton/third_party/amd/python/test/test_address_sanitizer.py new file mode 100644 index 000000000..c2b626f87 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/python/test/test_address_sanitizer.py @@ -0,0 +1,34 @@ +import os +import subprocess + +import triton + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def test_address_sanitizer(): + if not is_hip(): + return #not supported on NV backend + + # It is recommended to disable various memory caching strategies both within the ROCm stack and PyTorch + # This will give the address sanitizer the best chance at finding the memory fault where it originates, + # otherwise it could be masked by writing past the end of a cached block within a larger allocation. + os.environ["HSA_DISABLE_FRAGMENT_ALLOCATOR"] = "1" + os.environ["AMD_PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1" + os.environ["PYTORCH_NO_HIP_MEMORY_CACHING"] = "1" + os.environ["TRITON_ENABLE_ASAN"] = "1" + + # HSA_XNACK here is required to set the xnack+ setting for the GPU at runtime. + # If it is not set and the default xnack setting of the system is xnack- + # a runtime error something like "No kernel image found" will occur. The system + # xnack setting can be found through rocminfo. xnack+ is required for ASAN. + # More information about xnack in general can be found here: + # https://llvm.org/docs/AMDGPUUsage.html#target-features + # https://rocm.docs.amd.com/en/docs-6.1.0/conceptual/gpu-memory.html + os.environ["HSA_XNACK"] = "1" + + out = subprocess.Popen(["python", "address_sanitizer_helper.py"], stderr=subprocess.PIPE, stdout=subprocess.PIPE) + assert "Begin function __asan_report" in out.stdout.read().decode() + assert "heap-buffer-overflow" in out.stderr.read().decode() diff --git a/third_party/enflame/include/triton/third_party/amd/python/test/test_extract_slice.py b/third_party/enflame/include/triton/third_party/amd/python/test/test_extract_slice.py new file mode 100644 index 000000000..5d2408086 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/python/test/test_extract_slice.py @@ -0,0 +1,111 @@ +import pytest +import torch + +import triton + +from triton._internal_testing import is_hip + +num_ctas_list = [1] + +GPU_DIALECT = "ttg" + +if is_hip(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size +else: + THREADS_PER_WARP = 32 + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +# ----------------------- +# test extract slice +# ----------------------- + +extract_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] +blocked_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", + [[256, 256, 256, 32, 0, 32], [128, 128, 128, 64, 0, 64]]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("extract_layout", extract_layout) +@pytest.mark.parametrize("blocked_layout", blocked_layout) +def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, + extract_layout, device='cuda'): + if not is_hip(): + pytest.skip("extract_slice is AMD specific instruction.") + + ir = f""" + #blocked = {blocked_layout} + #extract_layout = {extract_layout} + module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {str(64)} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> + %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> + %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> + %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %34 = tt.splat %arg1 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> + %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %12 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #extract_layout> + %13 = amdgpu.extract_slice %12 [{M_tile_offset}, {N_tile_offset}] : tensor<{M}x{N}xf16, #extract_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> + %14 = ttg.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> + %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + x = torch.randn((M, N), device=device, dtype=torch.float16) + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + extract_slice = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) + + kernel[(1, 1, 1)](x.data_ptr(), extract_slice) + test_result = torch.equal(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], + extract_slice) + assert test_result diff --git a/third_party/enflame/include/triton/third_party/amd/python/triton_amd.cc b/third_party/enflame/include/triton/third_party/amd/python/triton_amd.cc new file mode 100644 index 000000000..551da3f33 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/python/triton_amd.cc @@ -0,0 +1,299 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "TritonAMDGPUToLLVM/Passes.h" +#include "TritonAMDGPUToLLVM/TargetUtils.h" +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" +#include "passes.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Module.h" +#include "llvm/MC/MCAsmBackend.h" +#include "llvm/MC/MCAsmInfo.h" +#include "llvm/MC/MCCodeEmitter.h" +#include "llvm/MC/MCContext.h" +#include "llvm/MC/MCInstrInfo.h" +#include "llvm/MC/MCObjectFileInfo.h" +#include "llvm/MC/MCObjectWriter.h" +#include "llvm/MC/MCParser/MCAsmParser.h" +#include "llvm/MC/MCParser/MCTargetAsmParser.h" +#include "llvm/MC/MCRegisterInfo.h" +#include "llvm/MC/MCSection.h" +#include "llvm/MC/MCStreamer.h" +#include "llvm/MC/MCSubtargetInfo.h" +#include "llvm/MC/MCTargetOptions.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/TargetParser/TargetParser.h" +#include +#include + +namespace py = pybind11; + +namespace { +const char *const amdTargetTriple = "amdgcn-amd-amdhsa"; + +void init_triton_amd_passes_ttgpuir(py::module &&m) { + using namespace mlir::triton; + m.def("add_to_llvmir", + [](mlir::PassManager &pm, const std::string &arch, bool ftz) { + pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz)); + }); + m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm, bool ftz) { + pm.addPass(createConvertBuiltinFuncToLLVMPass(ftz)); + }); + m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm, + const std::string &variant) { + pm.addPass(createTritonAMDGPUInsertInstructionSchedHintsPass(variant)); + }); + m.def("lower_instruction_sched_hints", + [](mlir::PassManager &pm, const std::string &arch, int32_t numStages) { + pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass( + arch, numStages)); + }); + m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm, + const std::string &arch) { + pm.addPass( + mlir::triton::AMD::createDecomposeUnsupportedConversionsPass(arch)); + }); + ADD_PASS_WRAPPER_2("add_optimize_lds_usage", + mlir::triton::AMD::createOptimizeLDSUsagePass, + const std::string &, int32_t); + ADD_PASS_WRAPPER_3("add_accelerate_matmul", + mlir::createTritonAMDGPUAccelerateMatmulPass, + const std::string, int, int); + ADD_PASS_WRAPPER_0("add_optimize_epilogue", + mlir::createTritonAMDGPUOptimizeEpiloguePass); + m.def("add_hoist_layout_conversions", [](mlir::PassManager &pm) { + pm.addNestedPass( + mlir::createTritonAMDGPUHoistLayoutConversionsPass()); + }); + m.def("add_canonicalize_pointers", [](mlir::PassManager &pm) { + pm.addNestedPass( + mlir::createTritonAMDGPUCanonicalizePointersPass()); + }); + ADD_PASS_WRAPPER_1("add_convert_to_buffer_ops", + mlir::createTritonAMDGPUConvertToBufferOpsPass, + const std::string &); + ADD_PASS_WRAPPER_0("add_reorder_instructions", + mlir::createTritonAMDGPUReorderInstructionsPass); + ADD_PASS_WRAPPER_0("add_block_pingpong", + mlir::createTritonAMDGPUBlockPingpongPass); + ADD_PASS_WRAPPER_3("add_stream_pipeline", + mlir::createTritonAMDGPUStreamPipelinePass, int, int, int); +} + +void addControlConstant(llvm::Module *module, const char *name, + uint32_t bitwidth, uint32_t value) { + using llvm::GlobalVariable; + + llvm::IntegerType *type = + llvm::IntegerType::getIntNTy(module->getContext(), bitwidth); + auto *initializer = llvm::ConstantInt::get(type, value, /*isSigned=*/false); + auto *constant = new llvm::GlobalVariable( + *module, type, /*isConstant=*/true, + GlobalVariable::LinkageTypes::LinkOnceODRLinkage, initializer, name, + /*before=*/nullptr, GlobalVariable::ThreadLocalMode::NotThreadLocal, + /*addressSpace=*/4); + constant->setAlignment(llvm::MaybeAlign(bitwidth / 8)); + constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local); + constant->setVisibility(GlobalVariable::VisibilityTypes::ProtectedVisibility); +} +} // namespace + +void init_triton_amd(py::module &&m) { + m.doc() = "Python bindings to the AMD Triton backend"; + + auto passes = m.def_submodule("passes"); + init_triton_amd_passes_ttgpuir(passes.def_submodule("ttgpuir")); + + m.attr("TARGET_TRIPLE") = amdTargetTriple; + m.attr("CALLING_CONV_AMDGPU_KERNEL") = + (unsigned)llvm::CallingConv::AMDGPU_KERNEL; + + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + // registry.insert(); + mlir::registerROCDLDialectTranslation(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + m.def("attach_target_triple", + [](llvm::Module *module) { module->setTargetTriple(amdTargetTriple); }); + + // Set target architecture ISA version + m.def("set_isa_version", [](llvm::Module *module, const std::string &arch) { + llvm::AMDGPU::IsaVersion version = llvm::AMDGPU::getIsaVersion(arch); + addControlConstant(module, "__oclc_ISA_version", /*bitwidth=*/32, + version.Major * 1000 + version.Minor * 100 + + version.Stepping); + }); + + // Set boolean control constant + m.def("set_bool_control_constant", + [](llvm::Module *module, const std::string &name, bool enable) { + addControlConstant(module, name.c_str(), /*bitwidth=*/8, enable); + }); + + // Set code object ABI version + m.def("set_abi_version", [](llvm::Module *module, int version) { + // Inject the control constant into the LLVM module so that device libraries + // linked against module can resolve their references to it. + llvm::Type *i32Ty = llvm::Type::getInt32Ty(module->getContext()); + llvm::GlobalVariable *abi = new llvm::GlobalVariable( + *module, i32Ty, /*isConstant=*/true, + llvm::GlobalValue::LinkageTypes::LinkOnceODRLinkage, + llvm::ConstantInt::get(i32Ty, version), "__oclc_ABI_version", nullptr, + llvm::GlobalValue::ThreadLocalMode::NotThreadLocal, 4); + abi->setVisibility(llvm::GlobalValue::VisibilityTypes::ProtectedVisibility); + abi->setAlignment(llvm::MaybeAlign(4)); + abi->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Local); + + // Also attach the control attribute on the LLVM module. This is also needed + // in addition to the above for various transformations to know what code + // object version we are targeting at. + module->addModuleFlag(llvm::Module::Error, "amdhsa_code_object_version", + version); + }); + + m.def("cleanup_bitcode_metadata", [](llvm::Module *module) { + // We can have Clang version metadata from device libraries linked in. We + // don't care about them so drop them. + if (auto *ident = module->getNamedMetadata("llvm.ident")) + module->eraseNamedMetadata(ident); + // Also various OpenCL version details. + if (auto *openclVersion = module->getNamedMetadata("opencl.ocl.version")) + module->eraseNamedMetadata(openclVersion); + }); + + m.def("disable_print_inline", [](llvm::Module *module) { + // List of functions name prefixes we want to forbid inline. + std::array prefixes = {"__ockl_fprintf", "__ockl_printf"}; + + for (llvm::Function &f : module->functions()) { + if (!f.hasName()) + continue; + llvm::StringRef name = f.getName(); + + auto isNamePrefixed = [&name](const char *prefix) { + return name.starts_with(prefix); + }; + + if (llvm::any_of(prefixes, isNamePrefixed)) + f.addFnAttr(llvm::Attribute::NoInline); + } + }); + + m.def( + "assemble_amdgcn", + [](const std::string &assembly, const std::string &arch, + const std::string &features) { + std::string error; + + llvm::Triple triple(amdTargetTriple); + const llvm::Target *target = + llvm::TargetRegistry::lookupTarget(triple.normalize(), error); + if (!target) + throw std::runtime_error("target lookup error: " + error); + + llvm::SourceMgr srcMgr; + srcMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(assembly), + llvm::SMLoc()); + + const llvm::MCTargetOptions mcOptions; + std::unique_ptr mri( + target->createMCRegInfo(amdTargetTriple)); + std::unique_ptr mai( + target->createMCAsmInfo(*mri, amdTargetTriple, mcOptions)); + std::unique_ptr sti( + target->createMCSubtargetInfo(amdTargetTriple, arch, features)); + + llvm::MCContext ctx(triple, mai.get(), mri.get(), sti.get(), &srcMgr, + &mcOptions); + std::unique_ptr mofi( + target->createMCObjectFileInfo(ctx, /*PIC=*/false, + /*LargeCodeModel=*/false)); + ctx.setObjectFileInfo(mofi.get()); + + llvm::SmallString<128> cwd; + if (!llvm::sys::fs::current_path(cwd)) + ctx.setCompilationDir(cwd); + + llvm::SmallVector result; + llvm::raw_svector_ostream svos(result); + + std::unique_ptr mcStreamer; + std::unique_ptr mcii(target->createMCInstrInfo()); + + std::unique_ptr ce( + target->createMCCodeEmitter(*mcii, ctx)); + std::unique_ptr mab( + target->createMCAsmBackend(*sti, *mri, mcOptions)); + std::unique_ptr ow(mab->createObjectWriter(svos)); + mcStreamer.reset(target->createMCObjectStreamer( + triple, ctx, std::move(mab), std::move(ow), std::move(ce), *sti)); + + std::unique_ptr parser( + createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai)); + std::unique_ptr tap( + target->createMCAsmParser(*sti, *parser, *mcii, mcOptions)); + if (!tap) + throw std::runtime_error("assembler initializtion error"); + + parser->setTargetParser(*tap); + parser->Run(/*NoInitialTextSection=*/false); + + return py::bytes(std::string(result.begin(), result.end())); + }, + py::return_value_policy::take_ownership); + + m.def("need_extern_lib", [](llvm::Module *module, const std::string &lib) { + for (llvm::Function &f : module->functions()) { + if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) { + llvm::StringRef funcName = f.getName(); + // The rule for linking the extern lib: + // if the function name includes ocml or ockl, link + // ocml or ockl accordingly. + if (funcName.contains(lib)) + return true; + if (funcName.contains("__nv_")) { + std::stringstream message; + message << "Implicit conversion of CUDA " << funcName.str() + << " device function has been dropped; " + << "please, update your source program to use " + "triton.language.extra. " + << "to replace triton.language.extra.cuda."; + throw std::runtime_error(message.str()); + } + } + } + return false; + }); + + m.def("has_matrix_core_feature", [](const std::string &arch) { + using mlir::triton::AMD::ISAFamily; + switch (mlir::triton::AMD::deduceISAFamily(arch)) { + case ISAFamily::CDNA4: + case ISAFamily::CDNA3: + case ISAFamily::CDNA2: + case ISAFamily::CDNA1: + case ISAFamily::RDNA3: + return true; + default: + return false; + } + }); + + m.def("set_all_fn_arg_inreg", [](llvm::Function *fn) { + for (llvm::Argument &arg : fn->args()) { + // Check for incompatible attributes. + if (arg.hasByRefAttr() || arg.hasNestAttr()) + continue; + arg.addAttr(llvm::Attribute::InReg); + } + }); +} diff --git a/third_party/enflame/include/triton/third_party/amd/test/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/test/CMakeLists.txt new file mode 100644 index 000000000..3ea7a4199 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/test/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(lib) diff --git a/third_party/enflame/include/triton/third_party/amd/test/lib/Analysis/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/test/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..a05b67fba --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/test/lib/Analysis/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_library(TritonAMDGPUTestAnalysis + TestAMDRangeAnalysis.cpp + + LINK_LIBS PUBLIC + MLIRPass + ${triton_libs} +) diff --git a/third_party/enflame/include/triton/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp b/third_party/enflame/include/triton/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp new file mode 100644 index 000000000..7ff3f0a02 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp @@ -0,0 +1,67 @@ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "third_party/amd/include/Analysis/RangeAnalysis.h" +#include "triton/Analysis/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +struct TestAMDRangeAnalysisPass + : PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAMDRangeAnalysisPass) + + StringRef getArgument() const final { + return "test-tritonamdgpu-range-analysis"; + } + StringRef getDescription() const final { + return "print the result of the tritonamdgpu-range-analysis pass"; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp mod = getOperation(); + + // Collect assumptions in the function + DenseMap> assumptions = + AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation()); + std::shared_ptr solver = createDataFlowSolver(); + solver->load(assumptions); + if (failed(solver->initializeAndRun(getOperation()))) + return signalPassFailure(); + + auto nonNegativePred = [&solver](Value v) -> bool { + return succeeded(dataflow::staticallyNonNegative(*solver, v)); + }; + mod->walk([&solver, nonNegativePred](Operation *op) { + auto results = op->getResults(); + if (auto outputRanges = AMD::collectRanges(*solver, results)) { + for (const auto &[res, outR] : llvm::zip(results, *outputRanges)) { + std::string rangeS; + llvm::raw_string_ostream rangeSt(rangeS); + rangeSt << outR; + emitRemark(res.getLoc(), rangeS); + } + + if (auto cmpOp = llvm::dyn_cast(op)) { + if (AMD::cmpIIsStaticallyTrue(*solver, cmpOp)) + emitRemark(op->getLoc(), "result is true"); + } + } + + if (!results.empty() && llvm::all_of(results, nonNegativePred)) + emitRemark(op->getLoc(), "non-neg"); + }); + } +}; + +} // namespace + +namespace mlir::test { +void registerTestTritonAMDGPURangeAnalysis() { + PassRegistration(); +} +} // namespace mlir::test diff --git a/third_party/enflame/include/triton/third_party/amd/test/lib/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/test/lib/CMakeLists.txt new file mode 100644 index 000000000..fc6ef10fa --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/test/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Analysis) diff --git a/third_party/enflame/include/triton/third_party/amd/unittest/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/unittest/CMakeLists.txt new file mode 100644 index 000000000..bd3c0c6c0 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/unittest/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Conversion) diff --git a/third_party/enflame/include/triton/third_party/amd/unittest/Conversion/CMakeLists.txt b/third_party/enflame/include/triton/third_party/amd/unittest/Conversion/CMakeLists.txt new file mode 100644 index 000000000..6d7a6b293 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/unittest/Conversion/CMakeLists.txt @@ -0,0 +1,6 @@ +add_triton_ut(NAME TestOptimizeLDS +SRCS OptimizeLDSTest.cpp +LIBS + TritonAnalysis + TritonIR + TritonGPUIR) diff --git a/third_party/enflame/include/triton/third_party/amd/unittest/Conversion/OptimizeLDSTest.cpp b/third_party/enflame/include/triton/third_party/amd/unittest/Conversion/OptimizeLDSTest.cpp new file mode 100644 index 000000000..a9f112239 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/amd/unittest/Conversion/OptimizeLDSTest.cpp @@ -0,0 +1,42 @@ +//===- OptimizeLDSTest.cpp - Tests for OptimizeLDSUtility -----------------===// + +#include "third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h" +#include +#include + +namespace mlir { + +template bool checkProdEq(ArrayRef a) { + unsigned prod = + std::reduce(a.begin(), a.end(), 1u, std::multiplies()); + return prod == P; +} + +TEST(OptimizeLDSUtility, factorizePowerOf2) { + int numwarps; + int rank; + // check rank=1 generation + numwarps = 4; + rank = 1; + auto output1 = triton::AMD::factorizePowerOf2(numwarps, rank); + ASSERT_EQ(output1.size(), 1); + ASSERT_EQ(output1[0][0], numwarps); + // check rank=2 generation + numwarps = 8; + rank = 2; + auto output2 = triton::AMD::factorizePowerOf2(numwarps, rank); + ASSERT_EQ(output2.size(), 4); + ASSERT_TRUE(std::all_of(output2.begin(), output2.end(), checkProdEq<8>)); + ASSERT_TRUE(std::all_of(output2.begin(), output2.end(), + [](auto a) { return a.size() == 2; })); + // check rank=3 generation + numwarps = 8; + rank = 3; + auto output3 = triton::AMD::factorizePowerOf2(numwarps, rank); + ASSERT_EQ(output3.size(), 10); + ASSERT_TRUE(std::all_of(output3.begin(), output3.end(), checkProdEq<8>)); + ASSERT_TRUE(std::all_of(output3.begin(), output3.end(), + [](auto a) { return a.size() == 3; })); +} + +} // namespace mlir diff --git a/third_party/enflame/include/triton/third_party/f2reduce/CMakeLists.txt b/third_party/enflame/include/triton/third_party/f2reduce/CMakeLists.txt new file mode 100644 index 000000000..71db82e3c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/f2reduce/CMakeLists.txt @@ -0,0 +1,3 @@ +add_triton_library(f2reduce + f2reduce.cpp +) diff --git a/third_party/enflame/include/triton/third_party/f2reduce/LICENCE.txt b/third_party/enflame/include/triton/third_party/f2reduce/LICENCE.txt new file mode 100644 index 000000000..bce4fded0 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/f2reduce/LICENCE.txt @@ -0,0 +1,7 @@ +Copyright 2023 Adam P. Goucher, Hatsya Limited + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/third_party/enflame/include/triton/third_party/f2reduce/README.md b/third_party/enflame/include/triton/third_party/f2reduce/README.md new file mode 100644 index 000000000..0be556a55 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/f2reduce/README.md @@ -0,0 +1,66 @@ +f2reduce: a MIT-licenced library for Gaussian elimination over GF(2) +==================================================================== + +This is a very lightweight implementation for converting a binary matrix +to row reduced echelon form. It incorporates the following optimisations: + + - Kronrod's algorithm ('method of four Russians'); + - Designed to properly autovectorise in both GCC and LLVM; + - Attempts to ensure that memory loads/stores are cache-aligned; + - Designed to achieve high instruction-level parallelism; + - Able to use AVX512's `vpternlogq` instruction if present; + - Minimal memory overhead (a few megabytes). + +There are no architecture-specific intrinsics or assembly, so this should +work well on any architecture where the compiler can autovectorise. + +For simplicity, we do not use Strassen, so our performance is overtaken by +[M4RI][1] whenever the matrices are large and have full column rank. + +For all other cases, we have several advantages over M4RI: + + - Substantially better performance on small, wide, or low-rank matrices; + - MIT-licenced rather than GPL-licenced; + - No assumptions about the processor architecture; + - No configuration required (`-O3 -march=native` is enough). + +We expose a single function with the following signature: + + void inplace_rref_strided(uint64_t *matrix, uint64_t rows, uint64_t cols, uint64_t stride); + +The matrix should be in row-major format and is overwritten in-place. The +`stride` parameter specifies the offset between adjacent rows **in 64-bit +words, not bytes**. The mapping between matrix entries and memory is as +follows: + + the (j+64*k)th entry of the ith row is (matrix[i * stride + k] >> j) & 1 + +Since the performance can depend on the stride and how it interacts with +processor caches, we expose another function to return a recommended stride: + + uint64_t get_recommended_stride(uint64_t cols); + +Although `f2reduce` is compiled in C++11, the resulting static library +has C-linkage so can be called from any C/C++ code. + +Dependencies +------------ + +`f2reduce` has no dependencies; just compile `f2reduce.cpp` with the +`-O3 -march=native` flags to produce a static library and include the header +file `f2reduce.h` in your project. + +The automated test suite has dependencies on [M4RI][1] (for benchmarking +timings against M4RI and checking that implementations agree), [GoogleTest][2] +(for unit testing), and [cpads][3] (for high-quality pseudo-random number +generation). Downloading of the dependencies and building of the test suite +is automated by [CMake][4]. + +To build the test suite, you need to manually append `add_subdirectory(test)` +to the end of the `CMakeLists.txt` file. This is so that `f2reduce` does not +have any build dependencies by default. + +[1]: https://github.com/malb/m4ri +[2]: https://github.com/google/googletest +[3]: https://gitlab.com/hatsya/open-source/cpads +[4]: https://cmake.org/ diff --git a/third_party/enflame/include/triton/third_party/f2reduce/VERSION b/third_party/enflame/include/triton/third_party/f2reduce/VERSION new file mode 100644 index 000000000..53a316cb1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/f2reduce/VERSION @@ -0,0 +1,2 @@ +Cloned from https://gitlab.com/hatsya/open-source/f2reduce at revision +949b91d022c001bbce19157f806013d37f05fbf5. diff --git a/third_party/enflame/include/triton/third_party/f2reduce/f2reduce.cpp b/third_party/enflame/include/triton/third_party/f2reduce/f2reduce.cpp new file mode 100644 index 000000000..c29498834 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/f2reduce/f2reduce.cpp @@ -0,0 +1,619 @@ +#include "f2reduce.h" + +#include +#include + +#if defined(_MSC_VER) +#define RESTRICT __restrict +#define NO_INLINE __declspec(noinline) +#elif defined(__GNUC__) +#define RESTRICT __restrict__ +#define NO_INLINE __attribute__((noinline)) +#endif + +namespace f2reduce { + +void swap_rows(uint64_t *RESTRICT x, uint64_t *RESTRICT y, uint64_t n) { + for (uint64_t i = 0; i < n; i++) { + uint64_t z = x[i]; + x[i] = y[i]; + y[i] = z; + } +} + +// the noinline attribute is necessary for gcc to properly vectorise this: +template +NO_INLINE void +memxor_lop7(uint64_t *RESTRICT dst, const uint64_t *RESTRICT src1, + const uint64_t *RESTRICT src2, const uint64_t *RESTRICT src3, + const uint64_t *RESTRICT src4, const uint64_t *RESTRICT src5, + const uint64_t *RESTRICT src6) { + for (uint64_t i = 0; i < N; i++) { + dst[i] ^= src1[i] ^ src2[i] ^ src3[i] ^ src4[i] ^ src5[i] ^ src6[i]; + } +} + +template +NO_INLINE void +memxor_lop5(uint64_t *RESTRICT dst, const uint64_t *RESTRICT src1, + const uint64_t *RESTRICT src2, const uint64_t *RESTRICT src3, + const uint64_t *RESTRICT src4) { + for (uint64_t i = 0; i < N; i++) { + dst[i] ^= src1[i] ^ src2[i] ^ src3[i] ^ src4[i]; + } +} + +template +NO_INLINE void memxor_lop3(uint64_t *RESTRICT dst, + const uint64_t *RESTRICT src1, + const uint64_t *RESTRICT src2) { + for (uint64_t i = 0; i < N; i++) { + dst[i] ^= src1[i] ^ src2[i]; + } +} + +template +void memxor_inplace(uint64_t *RESTRICT dst, const uint64_t *RESTRICT src1, + const uint64_t *RESTRICT src2) { + for (uint64_t i = 0; i < N; i++) { + dst[i] = src1[i] ^ src2[i]; + } +} + +// split k into 6 approximately-equal pieces +void split_k(int k, int *subkays) { + + int k5_k6 = (k <= 32) ? 0 : (k / 3); + int k3_k4 = (k - k5_k6) >> 1; + int k1_k2 = k - k5_k6 - k3_k4; + + subkays[0] = k1_k2 >> 1; + subkays[1] = k1_k2 - subkays[0]; + subkays[2] = k3_k4 >> 1; + subkays[3] = k3_k4 - subkays[2]; + subkays[4] = k5_k6 >> 1; + subkays[5] = k5_k6 - subkays[4]; +} + +/** + * Sextuple Kronrod implementation. + * + * This populates six lookup tables of approximately-equal sizes where each + * entry (8*N bytes) contains a linear combination of rows. The transformation + * encoded in 'workspace' is then applied using ternary XORs which are very + * AVX512-friendly. + */ +template +void kronrod(uint64_t *RESTRICT matrix, uint64_t rows, uint64_t stride, + const uint64_t *RESTRICT workspace, uint64_t *RESTRICT cache, + const uint64_t *RESTRICT pivots, int k) { + constexpr int logwidth = 5; + + static_assert(N <= (1ull << logwidth), "kronrod assumes that N <= 32"); + + int subkays[6]; + int cumkays[6]; + uint64_t *caches[6]; + split_k(k, subkays); + + caches[0] = cache; + cumkays[0] = 0; + + for (int i = 0; i < 5; i++) { + caches[i + 1] = caches[i] + (1ull << (subkays[i] + logwidth)); + cumkays[i + 1] = cumkays[i] + subkays[i]; + } + + // build: + for (int o = 0; o < 6; o++) { + uint64_t *subcache = caches[o]; + memset(subcache, 0, 8 << logwidth); + for (int j = 0; j < subkays[o]; j++) { + uint64_t p = (1ull << j); + memcpy(subcache + (p << logwidth), + matrix + pivots[j + cumkays[o]] * stride, N * 8); + for (uint64_t i = 1; i < p; i++) { + memxor_inplace(subcache + ((i + p) << logwidth), + subcache + (i << logwidth), + subcache + (p << logwidth)); + } + } + } + + uint64_t mask0 = (1ull << subkays[0]) - 1; + uint64_t mask1 = (1ull << subkays[1]) - 1; + uint64_t mask2 = (1ull << subkays[2]) - 1; + uint64_t mask3 = (1ull << subkays[3]) - 1; + uint64_t mask4 = (1ull << subkays[4]) - 1; + uint64_t mask5 = (1ull << subkays[5]) - 1; + + // apply: + for (uint64_t r = 0; r < rows; r++) { + + if (N >= 32) { + // prefetch 256 bytes, 15 rows later: + uint64_t *ppp = matrix + (r + 15) * stride; +#if defined(__GNUC__) + __builtin_prefetch(ppp); + __builtin_prefetch(ppp + 8); + __builtin_prefetch(ppp + 16); + __builtin_prefetch(ppp + 24); +#endif + } + + uint64_t w = workspace[r]; + + uint64_t w0 = w & mask0; + uint64_t w1 = (w >> cumkays[1]) & mask1; + uint64_t w2 = (w >> cumkays[2]) & mask2; + uint64_t w3 = (w >> cumkays[3]) & mask3; + if (k <= 32) { + memxor_lop5(matrix + r * stride, caches[0] + (w0 << logwidth), + caches[1] + (w1 << logwidth), caches[2] + (w2 << logwidth), + caches[3] + (w3 << logwidth)); + } else { + uint64_t w4 = (w >> cumkays[4]) & mask4; + uint64_t w5 = (w >> cumkays[5]) & mask5; + memxor_lop7(matrix + r * stride, caches[0] + (w0 << logwidth), + caches[1] + (w1 << logwidth), caches[2] + (w2 << logwidth), + caches[3] + (w3 << logwidth), caches[4] + (w4 << logwidth), + caches[5] + (w5 << logwidth)); + } + } +} + +bool find_pivots(uint64_t *RESTRICT pivots, uint64_t *RESTRICT this_strip, + uint64_t rows, uint64_t &starting_row, uint64_t *workspace, + uint64_t &next_b, uint64_t final_b, int K, int &k) { + + // sorted copy, so that we can skip existing pivots: + uint64_t spivots[64] = {(uint64_t)-1}; + + // find pivots + uint64_t b = 0; + + while (k < K) { + + int l = 0; + b = ((uint64_t)-1); + + for (uint64_t s = starting_row; s < rows; s++) { + + if (s == spivots[l]) { + // don't use an existing pivot: + l += 1; + continue; + } + + uint64_t this_row = this_strip[s]; + uint64_t a = (this_row & (-this_row)) - 1; + if (a < b) { + b = a; + pivots[k] = s; + if (b == next_b) { + // we've found the best pivot possible: + break; + } + } + } + + if (b == ((uint64_t)-1)) { + // we have exhausted this strip with no pivot found: + return true; + } + + uint64_t j = pivots[k]; + uint64_t wsj = workspace[j] ^ (1ull << k); + uint64_t m = this_strip[j]; + uint64_t ml = m & (-m); + + for (uint64_t s = 0; s < rows; s++) { + if (s == j) { + continue; + } + if (this_strip[s] & ml) { + this_strip[s] ^= m; + workspace[s] ^= wsj; + } + } + + spivots[k] = pivots[k]; + l = k; + while (l-- > 0) { + // insertion sort: + if (spivots[l] > spivots[l + 1]) { + uint64_t x = spivots[l]; + spivots[l] = spivots[l + 1]; + spivots[l + 1] = x; + } + } + + k += 1; + next_b = (b << 1) + 1; + if (b == final_b) { + // we have found a pivot for the last column in this strip: + return true; + } + } + + // we have found K pivots and have not proved that this 64-column strip + // has been fully exhausted: + return false; +} + +/** + * Use Kronrod's algorithm to reduce all strips to the right of the current + * strip. We do this in chunks of between 1 and 32 strips (64 to 2048 columns) + * and attempt to align chunks with cache lines if the stride is a multiple + * of the cache line size. + * + * The long switch statements are because we generate bespoke code for each + * value of the chunk width N, which outperforms having a variable-length loop. + */ +void chunked_kronrod(const uint64_t *RESTRICT pivots, uint64_t *RESTRICT matrix, + uint64_t rows, uint64_t strips, uint64_t stride, + const uint64_t *workspace, uint64_t *RESTRICT cache, + int k) { + + uint64_t re = strips - 1; + +#define KRONROD(N) \ + kronrod(matrix + (strips - re), rows, stride, workspace, cache, pivots, k) + + if ((re > 32) && ((stride & 7) == 0)) { + // try to optimise for cache lines: + uint64_t ptr = ((uint64_t)(matrix + (strips - re))); + + // optimise for both 64-byte and 128-byte cache lines: + uint64_t mask = (stride - 1) & 15; // either 0b0111 or 0b1111 + uint64_t ideal_re = 16 - ((ptr >> 3) & mask); + + switch (ideal_re) { + case 15: + KRONROD(15); + re -= 15; + break; + case 14: + KRONROD(14); + re -= 14; + break; + case 13: + KRONROD(13); + re -= 13; + break; + case 12: + KRONROD(12); + re -= 12; + break; + case 11: + KRONROD(11); + re -= 11; + break; + case 10: + KRONROD(10); + re -= 10; + break; + case 9: + KRONROD(9); + re -= 9; + break; + case 8: + KRONROD(8); + re -= 8; + break; + case 7: + KRONROD(7); + re -= 7; + break; + case 6: + KRONROD(6); + re -= 6; + break; + case 5: + KRONROD(5); + re -= 5; + break; + case 4: + KRONROD(4); + re -= 4; + break; + case 3: + KRONROD(3); + re -= 3; + break; + case 2: + KRONROD(2); + re -= 2; + break; + case 1: + KRONROD(1); + re -= 1; + break; + } + } + + while (re >= 32) { + KRONROD(32); + re -= 32; + } + + if (re >= 16) { + KRONROD(16); + re -= 16; + } + + switch (re) { + // process the last (incomplete) chunk: + case 15: + KRONROD(15); + break; + case 14: + KRONROD(14); + break; + case 13: + KRONROD(13); + break; + case 12: + KRONROD(12); + break; + case 11: + KRONROD(11); + break; + case 10: + KRONROD(10); + break; + case 9: + KRONROD(9); + break; + case 8: + KRONROD(8); + break; + case 7: + KRONROD(7); + break; + case 6: + KRONROD(6); + break; + case 5: + KRONROD(5); + break; + case 4: + KRONROD(4); + break; + case 3: + KRONROD(3); + break; + case 2: + KRONROD(2); + break; + case 1: + KRONROD(1); + break; + } + +#undef KRONROD +} + +/** + * Find up to K pivot rows in this strip of 64 columns, remove them from all + * other rows, and permute them into the correct places. + */ +bool perform_K_steps(uint64_t *RESTRICT matrix, uint64_t *RESTRICT stripspace, + uint64_t rows, uint64_t strips, uint64_t stride, + uint64_t &starting_row, uint64_t *workspace, + uint64_t *RESTRICT cache, uint64_t &next_b, int K, + uint64_t final_b) { + + memset(workspace, 0, 8 * rows); + + // array to contain the indices of the k pivot rows: + uint64_t pivots[64] = {(uint64_t)-1}; + + int k = 0; + bool completed_strip = find_pivots(pivots, stripspace, rows, starting_row, + workspace, next_b, final_b, K, k); + + if (k == 0) { + // no pivots detected: + return true; + } + + for (uint64_t r = 0; r < rows; r++) { + matrix[r * stride] = stripspace[r]; + } + + // for all strips to the right of the current strip, use Kronrod's + // method to XOR the correct linear combination of the k pivot rows + // from each row in the matrix: + chunked_kronrod(pivots, matrix, rows, strips, stride, workspace, cache, k); + + // apply a row permutation so that the k pivot rows are moved to the + // uppermost k slots, incrementing starting_row in the process: + for (int i = 0; i < k; i++) { + if (pivots[i] != starting_row) { + // swap rows in matrix: + swap_rows(matrix + starting_row * stride, matrix + pivots[i] * stride, + strips); + // swap rows in stripspace: + uint64_t x = stripspace[pivots[i]]; + stripspace[pivots[i]] = stripspace[starting_row]; + stripspace[starting_row] = x; + for (int j = 0; j < k; j++) { + if (pivots[j] == starting_row) { + pivots[j] = pivots[i]; + } + } + pivots[i] = starting_row; + } + starting_row += 1; + } + + // determine whether we have exhausted all of the columns in the strip: + return completed_strip; +} + +void inplace_rref_strided_K(uint64_t *RESTRICT matrix, + uint64_t *RESTRICT stripspace, uint64_t rows, + uint64_t cols, uint64_t stride, uint64_t *workspace, + uint64_t *cache, int K) { + + uint64_t strips = (cols + 63) >> 6; + + uint64_t current_row = 0; + + for (uint64_t current_strip = 0; current_strip < strips; current_strip++) { + uint64_t remcols = cols - (current_strip << 6); + if (remcols > 64) { + remcols = 64; + } + uint64_t final_b = (1ull << (remcols - 1)) - 1; + uint64_t next_b = 0; + + uint64_t *offset_matrix = matrix + current_strip; + + // We make a cached copy of the current strip. This has contiguous + // memory layout (unlike the source strip in the matrix), and the + // performance gain from having contiguity massively exceeds the + // cost of copying between the matrix and this cached copy. + for (uint64_t r = 0; r < rows; r++) { + stripspace[r] = offset_matrix[r * stride]; + } + + while (current_row < rows) { + if (perform_K_steps(offset_matrix, stripspace, rows, + strips - current_strip, stride, current_row, + workspace, cache, next_b, K, final_b)) { + break; + } + } + + if (current_row >= rows) { + break; + } + } +} + +void inplace_rref_strided_heap(uint64_t *matrix, uint64_t rows, uint64_t cols, + uint64_t stride, int K) { + + // Array for storing, for each row, the appropriate linear combination of + // the k <= K <= 32 pivot rows that needs to be subtracted: + uint64_t *workspace = ((uint64_t *)malloc(rows * 8)); + + // Array for caching the current strip (64 columns) of the matrix: + uint64_t *stripspace = ((uint64_t *)malloc(rows * 8)); + + int subkays[6]; + split_k(K, subkays); + + // Array for storing 256-byte chunks of linear combinations of pivot rows: + void *cache_raw = malloc(256 * (1 + (1 << subkays[0]) + (1 << subkays[1]) + + (1 << subkays[2]) + (1 << subkays[3]) + + (1 << subkays[4]) + (1 << subkays[5]))); + + // Align to cache lines: + uint64_t cache_ptr = ((uint64_t)cache_raw); + cache_ptr += (128 - (cache_ptr & 127)); + uint64_t *cache = ((uint64_t *)cache_ptr); + + // Convert to row reduced echelon form: + inplace_rref_strided_K(matrix, stripspace, rows, cols, stride, workspace, + cache, K); + + // Free the allocated memory buffers: + free(workspace); + free(stripspace); + free(cache_raw); +} + +void inplace_rref_small(uint64_t *matrix, uint64_t rows, uint64_t cols) { + + uint64_t final_b = (1ull << (cols - 1)) - 1; + uint64_t next_b = 0; + + for (uint64_t r = 0; r < rows; r++) { + + uint64_t b = (matrix[r] & (-matrix[r])) - 1; + + for (uint64_t s = r + 1; s < rows; s++) { + uint64_t this_row = matrix[s]; + uint64_t a = (this_row & (-this_row)) - 1; + + if (a < b) { + b = a; + matrix[s] = matrix[r]; + matrix[r] = this_row; + } + + if (b == next_b) { + break; + } + } + + if (b == ((uint64_t)-1)) { + break; + } + + uint64_t m = matrix[r]; + uint64_t ml = m & (-m); + + for (uint64_t s = 0; s < rows; s++) { + if (s == r) { + continue; + } + if (matrix[s] & ml) { + matrix[s] ^= m; + } + } + + next_b = (b << 1) + 1; + if (b == final_b) { + break; + } + } +} + +} // namespace f2reduce + +namespace f2reduce { + +void inplace_rref_strided(uint64_t *matrix, uint64_t rows, uint64_t cols, + uint64_t stride) { + + if (rows <= 1 || cols == 0) { + // If the matrix has 0 or 1 rows or 0 columns, it must already be in RREF: + return; + } + + if ((rows <= 64) && (cols <= 64)) { + uint64_t matrix2[64]; + for (uint64_t i = 0; i < rows; i++) { + matrix2[i] = matrix[i * stride]; + } + inplace_rref_small(matrix2, rows, cols); + for (uint64_t i = 0; i < rows; i++) { + matrix[i * stride] = matrix2[i]; + } + } else { + // Select value of k to minimise the objective function: + // ceil(64/k) * (rows + 2^(k/2)) + int k = (rows <= 5120) ? 32 : 64; + inplace_rref_strided_heap(matrix, rows, cols, stride, k); + } +} + +uint64_t get_recommended_stride(uint64_t cols) { + + uint64_t stride = (cols + 63) >> 6; + if (stride > 32) { + // pad to a multiple of a 64/128-byte cache line: + stride += (-stride) & 15; + } + if ((stride & 63) == 0) { + // ensure not divisible by 64 to avoid critical stride issues: + stride += 16; + } + return stride; +} + +} // namespace f2reduce diff --git a/third_party/enflame/include/triton/third_party/f2reduce/f2reduce.h b/third_party/enflame/include/triton/third_party/f2reduce/f2reduce.h new file mode 100644 index 000000000..9865541ec --- /dev/null +++ b/third_party/enflame/include/triton/third_party/f2reduce/f2reduce.h @@ -0,0 +1,29 @@ +#pragma once +#include + +// OpenAI change: Switched from `extern "C"` to `namespace f2reduce`. +namespace f2reduce { + +/** + * Converts a matrix over F_2 into row-reduced echelon form. + * + * The matrix should be in row-major format. The stride parameter specifies + * the offset (in 64-bit words, *not* bytes!) between successive rows of the + * matrix, and should obey the inequality: + * + * 64 |stride| >= cols + * + * i.e. that the rows occupy disjoint regions of memory. For best performance + * the stride should be divisible by 16 words (128 bytes). + * + * We adopt 'little-endian' semantics: the element in row i and column j+64*k + * of the matrix (zero-indexed) is given by (matrix[i * stride + k] >> j) & 1. + * + * The matrix is overwritten in place with its row-reduced echelon form. + */ +void inplace_rref_strided(uint64_t *matrix, uint64_t rows, uint64_t cols, + uint64_t stride); + +uint64_t get_recommended_stride(uint64_t cols); + +} // namespace f2reduce diff --git a/third_party/enflame/include/triton/third_party/nvidia/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/CMakeLists.txt new file mode 100644 index 000000000..bab189bcb --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/CMakeLists.txt @@ -0,0 +1,11 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonNVIDIA ${CMAKE_CURRENT_SOURCE_DIR}/triton_nvidia.cc LINK_LIBS TritonNVIDIAGPUToLLVM NVGPUToLLVM) + target_link_libraries(TritonNVIDIA PRIVATE Python3::Module pybind11::headers) +endif() +if(TRITON_BUILD_UT) + add_subdirectory(unittest) +endif() diff --git a/third_party/enflame/include/triton/third_party/nvidia/backend/__init__.py b/third_party/enflame/include/triton/third_party/nvidia/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/include/triton/third_party/nvidia/backend/compiler.py b/third_party/enflame/include/triton/third_party/nvidia/backend/compiler.py new file mode 100644 index 000000000..6db76a352 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/backend/compiler.py @@ -0,0 +1,457 @@ +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes, llvm, nvidia +from triton.runtime.errors import PTXASError + +from dataclasses import dataclass +import functools +from typing import Any, Dict, Tuple, Optional +from types import ModuleType +import hashlib +import re +import tempfile +import signal +import os +import subprocess +from pathlib import Path +import sysconfig + + +def min_dot_size(target: GPUTarget): + + def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m, n, k] + lhs_bitwidth = lhs_type.scalar.primitive_bitwidth + rhs_bitwidth = rhs_type.scalar.primitive_bitwidth + assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same" + if lhs_bitwidth == 8: + return (16, 16, 32) + else: + return (16, 16, 16) + + return check_dot_compatibility + + +@functools.lru_cache() +def _path_to_binary(binary: str): + binary += sysconfig.get_config_var("EXE") + paths = [ + os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), + os.path.join(os.path.dirname(__file__), "bin", binary), + ] + + for path in paths: + if os.path.exists(path) and os.path.isfile(path): + result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT) + if result is not None: + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is not None: + return path, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + +@functools.lru_cache() +def get_ptxas(arch: int): + name = "ptxas-blackwell" if arch >= 100 else "ptxas" + return _path_to_binary(name) + + +@functools.lru_cache() +def get_ptxas_version(arch: int): + mock_ver = os.environ.get('TRITON_MOCK_PTX_VERSION') + if mock_ver is not None: + return mock_ver # This is not really a version of ptxas, but it is good enough for testing + version = subprocess.check_output([get_ptxas(arch)[0], "--version"]).decode("utf-8") + return version + + +@functools.lru_cache() +def ptx_get_version(cuda_version) -> int: + ''' + Get the highest PTX version supported by the current CUDA driver. + ''' + assert isinstance(cuda_version, str) + major, minor = map(int, cuda_version.split('.')) + if major == 12: + if minor < 6: + return 80 + minor + else: + return 80 + minor - 1 + if major == 11: + return 70 + minor + if major == 10: + return 63 + minor + raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version) + + +def get_ptx_version_from_options(options, arch: int): + ptx_version = options.ptx_version + if ptx_version is None: + _, cuda_version = get_ptxas(arch) + ptx_version = ptx_get_version(cuda_version) + return ptx_version + + +@functools.lru_cache() +def get_features(options, arch: int): + ptx_version = get_ptx_version_from_options(options, arch) + + # PTX 8.6 is the max version supported by llvm c1188642. + # + # To check if a newer PTX version is supported, increase this value + # and run a test. If it's not supported, LLVM will print a warning + # like "+ptx8.4 is not a recognized feature for this target". + llvm_ptx_version = min(86, ptx_version) + features = f'+ptx{llvm_ptx_version}' + return features + + +@functools.lru_cache(None) +def file_hash(path): + with open(path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +def sm_arch_from_capability(capability: int): + # TODO: Handle non-"a" sms + suffix = "a" if capability >= 90 else "" + return f"sm_{capability}{suffix}" + + +@dataclass(frozen=True) +class CUDAOptions: + num_warps: int = 4 + num_ctas: int = 1 + num_stages: int = 3 + num_buffers_warp_spec: int = 0 + num_consumer_groups: int = 0 + reg_dec_producer: int = 0 + reg_inc_consumer: int = 0 + # maxnreg corresponds to the ptx parameter .maxnreg, which controls the + # maximum number of 32-bit registers used by one thread. + maxnreg: Optional[int] = None + cluster_dims: tuple = (1, 1, 1) + ptx_version: int = None + enable_fp_fusion: bool = True + launch_cooperative_grid: bool = False + supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15") + deprecated_fp8_dtypes: Tuple[str] = () + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: bool = None + extern_libs: dict = None + debug: bool = False + backend_name: str = 'cuda' + sanitize_overflow: bool = True + arch: str = None + + def __post_init__(self): + default_libdir = Path(__file__).parent / 'lib' + extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) + if not extern_libs.get('libdevice', None): + extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH", str(default_libdir / 'libdevice.10.bc')) + object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" + + def hash(self): + hash_dict = dict(self.__dict__) + hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"])) + key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class CUDABackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == 'cuda' + + def _parse_arch(self, arch): + pattern = r"^sm(\d+)$" + match = re.fullmatch(pattern, arch) + if not match: + raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}") + return int(match.group(1)) + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + self.binary_ext = "cubin" + + def parse_options(self, opts) -> Any: + args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", f"sm{self.target.arch}")} + args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None}) + capability = int(self._parse_arch(args["arch"])) + + if "supported_fp8_dtypes" not in args: + supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes) + if capability >= 89: + supported_fp8_dtypes.add("fp8e4nv") + args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) + + if "deprecated_fp8_dtypes" not in args: + if capability >= 90: + args["deprecated_fp8_dtypes"] = ("fp8e4b15", ) + + if "enable_fp_fusion" not in args: + args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1" + + args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0 + + return CUDAOptions(**args) + + def pack_metadata(self, metadata): + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + ) + + def get_codegen_implementation(self, options): + import triton.language.extra.cuda as cuda + capability = int(self._parse_arch(options.arch)) + codegen_fns = { + "convert_custom_types": + cuda.convert_custom_float8_sm80 if capability >= 80 else cuda.convert_custom_float8_sm70, "min_dot_size": + min_dot_size(self.target) + } + return codegen_fns + + def get_module_map(self) -> Dict[str, ModuleType]: + from triton.language.extra.cuda import libdevice + return {"triton.language.extra.libdevice": libdevice} + + def load_dialects(self, ctx): + nvidia.load_dialects(ctx) + + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_combine(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + passes.ttir.add_loop_unroll(pm) + pm.run(mod) + return mod + + @staticmethod + def make_ttgir(mod, metadata, opt, capability): + cluster_info = nvidia.ClusterInfo() + if opt.cluster_dims is not None: + cluster_info.clusterDimX = opt.cluster_dims[0] + cluster_info.clusterDimY = opt.cluster_dims[1] + cluster_info.clusterDimZ = opt.cluster_dims[2] + pm = ir.pass_manager(mod.context) + dump_enabled = pm.enable_debug() + passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) + # optimize TTGIR + passes.ttgpuir.add_coalesce(pm) + if capability // 10 >= 8: + passes.ttgpuir.add_f32_dot_tc(pm) + # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass + nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_thread_locality(pm) + passes.ttgpuir.add_accelerate_matmul(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) + passes.common.add_cse(pm) + if capability // 10 in [8, 9]: + passes.ttgpuir.add_fuse_nested_loops(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_licm(pm) + passes.ttgpuir.add_optimize_accumulator_init(pm) + passes.common.add_canonicalizer(pm) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups) + passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups) + passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups) + passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups, + opt.reg_dec_producer, opt.reg_inc_consumer) + passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled) + passes.ttgpuir.add_ping_pong_sync(pm, opt.num_consumer_groups) + passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups) + elif capability // 10 >= 10: + passes.ttgpuir.add_fuse_nested_loops(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_licm(pm) + passes.ttgpuir.add_optimize_accumulator_init(pm) + passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups) + passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups) + passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups) + passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups, + opt.reg_dec_producer, opt.reg_inc_consumer) + passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm) + nvidia.passes.ttnvgpuir.add_keep_acc_in_tmem(pm) + passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups) + passes.common.add_canonicalizer(pm) + else: + passes.common.add_licm(pm) + passes.ttgpuir.add_prefetch(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) + passes.ttgpuir.add_coalesce_async_copy(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_reduce_data_duplication(pm) + passes.ttgpuir.add_reorder_instructions(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if capability // 10 >= 9: + nvidia.passes.ttnvgpuir.add_fence_insertion(pm) + nvidia.passes.ttnvgpuir.add_tma_lowering(pm) + passes.common.add_canonicalizer(pm) + if capability // 10 >= 9: + passes.ttgpuir.add_ws_canonicalization(pm, opt.num_consumer_groups) + pm.run(mod) + metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) + return mod + + def make_llir(self, src, metadata, options, capability): + ptx_version = get_ptx_version_from_options(options, self.target.arch) + + mod = src + # TritonGPU -> LLVM-IR (MLIR) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + + nvidia.passes.ttnvgpuir.add_lower_mma(pm) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + passes.ttgpuir.add_allocate_warp_groups(pm) + passes.convert.add_scf_to_cf(pm) + passes.ttgpuir.add_allocate_shared_memory(pm) + nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm) + passes.ttgpuir.add_allocate_global_scratch_memory(pm) + nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) + nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": + passes.llvmir.add_di_scope(pm) + pm.run(mod) + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) + llvm.init_targets() + context = llvm.context() + if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1": + raise RuntimeError( + "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend") + llvm_mod = llvm.to_module(mod, context) + proc = sm_arch_from_capability(capability) + features = get_features(options, self.target.arch) + triple = 'nvptx64-nvidia-cuda' + llvm.attach_datalayout(llvm_mod, triple, proc, features) + nvidia.set_nvvm_reflect_ftz(llvm_mod) + + # Set maxnreg on all kernels, if it was provided. + if options.maxnreg is not None: + for k in llvm_mod.get_functions(): + if not k.is_declaration() and k.is_external_linkage(): + k.set_nvvm_maxnreg(options.maxnreg) + + if options.extern_libs: + paths = [path for (name, path) in options.extern_libs] + llvm.link_extern_libs(llvm_mod, paths) + + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) + + # Get some metadata + # warp-specialization mutates num_warps + total_num_warps = src.get_int_attr("ttg.total-num-warps") + if total_num_warps is not None: + metadata["num_warps"] = total_num_warps + metadata["shared"] = src.get_int_attr("ttg.shared") + metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size") + metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size") + metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment") + ret = str(llvm_mod) + del llvm_mod + del context + return ret + + def make_ptx(self, src, metadata, opt, capability): + ptx_version = get_ptx_version_from_options(opt, self.target.arch) + + triple = 'nvptx64-nvidia-cuda' + proc = sm_arch_from_capability(capability) + features = get_features(opt, self.target.arch) + ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False) + # Find kernel names (there should only be one) + names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret) + assert len(names) == 1 + metadata["name"] = names[0] + # post-process + ptx_version = f'{ptx_version//10}.{ptx_version%10}' + ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE) + ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE) + # Remove the debug flag that prevents ptxas from optimizing the code + ret = re.sub(r",\s*debug|debug,\s*", "", ret) + if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1": + print("// -----// NVPTX Dump //----- //") + print(ret) + return ret + + def make_cubin(self, src, metadata, opt, capability): + ptxas, _ = get_ptxas(self.target.arch) + with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \ + tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog: + fsrc.write(src) + fsrc.flush() + fbin = fsrc.name + '.o' + + line_info = ["-lineinfo", "-suppress-debug-info"] if os.environ.get("TRITON_DISABLE_LINE_INFO", + "0") == "1" else ["-lineinfo"] + fmad = [] if opt.enable_fp_fusion else ['--fmad=false'] + arch = sm_arch_from_capability(capability) + opt_level = ['--opt-level', '0'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1" else [] + ptxas_cmd = [ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name={arch}', fsrc.name, '-o', fbin] + try: + subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog) + if os.path.exists(fsrc.name): + os.remove(fsrc.name) + if os.path.exists(flog.name): + os.remove(flog.name) + except subprocess.CalledProcessError as e: + with open(flog.name) as log_file: + log = log_file.read() + if os.path.exists(flog.name): + os.remove(flog.name) + + if e.returncode == 255: + error = 'Internal Triton PTX codegen error' + elif e.returncode == 128 + signal.SIGSEGV: + error = '`ptxas` raised SIGSEGV' + else: + error = f'`ptxas` failed with error code {e.returncode}' + + raise PTXASError(f"{error}\n" + f"`ptxas` stderr:\n{log}\n" + f'Repro command: {" ".join(ptxas_cmd)}\n') + + with open(fbin, 'rb') as f: + cubin = f.read() + if os.path.exists(fbin): + os.remove(fbin) + return cubin + + def add_stages(self, stages, options): + capability = self._parse_arch(options.arch) + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) + stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch) + stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch) + + @functools.lru_cache() + def hash(self): + version = get_ptxas_version(self.target.arch) + return f'{version}-{self.target.arch}' diff --git a/third_party/enflame/include/triton/third_party/nvidia/backend/driver.c b/third_party/enflame/include/triton/third_party/nvidia/backend/driver.c new file mode 100644 index 000000000..12deb0d1e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/backend/driver.c @@ -0,0 +1,421 @@ +#include "cuda.h" +#include +#include +#define PY_SSIZE_T_CLEAN +#include + +// Raises a Python exception and returns false if code is not CUDA_SUCCESS. +static bool gpuAssert(CUresult code, const char *file, int line) { + if (code == CUDA_SUCCESS) + return true; + + const char *prefix = "Triton Error [CUDA]: "; + const char *str; + cuGetErrorString(code, &str); + char err[1024] = {0}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + return false; +} + +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) + +// Used to check if functions exist in old CUDA driver versions. +#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ + do { \ + if ((funcPointer) == NULL) { \ + (funcPointer) = (initializerFunction)(); \ + if ((funcPointer) == NULL) { \ + return NULL; \ + } \ + } \ + } while (0) + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + // Get device handle + CUdevice device; + cuDeviceGet(&device, device_id); + + // create a struct to hold device properties + int max_shared_mem; + int max_num_regs; + int multiprocessor_count; + int warp_size; + int sm_clock_rate; + int mem_clock_rate; + int mem_bus_width; + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); + CUDA_CHECK_AND_RETURN_NULL( + cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "max_num_regs", max_num_regs, + "multiprocessor_count", multiprocessor_count, "warpSize", + warp_size, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + CUfunction fun; + CUmodule mod; + int32_t n_regs = 0; + int32_t n_spills = 0; + // create driver handles + CUcontext pctx = 0; + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx)); + if (!pctx) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx)); + } + + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuModuleGetFunction(&fun, mod, name)); + // get allocated registers and spilled registers from the function + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + n_spills /= 4; + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + if (shared > 49152 && shared_optin > 49152) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + int shared_total, shared_static; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, + device)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute( + &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_optin - shared_static)); + } + Py_END_ALLOW_THREADS; + + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills); +} + +typedef CUresult (*cuOccupancyMaxActiveClusters_t)( + int *numClusters, CUfunction func, const CUlaunchConfig *config); + +typedef CUresult (*cuTensorMapEncodeTiled_t)( + CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, + const cuuint64_t *globalStrides, const cuuint32_t *boxDim, + const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill); + +#define defineGetFunctionHandle(name, symbolName) \ + static symbolName##_t name() { \ + /* Open the shared library */ \ + void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY); \ + if (!libHandle) { \ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); \ + return NULL; \ + } \ + /* Clear any existing error */ \ + dlerror(); \ + symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \ + /* Check for errors */ \ + const char *err = dlerror(); \ + if (err) { \ + PyErr_SetString(PyExc_RuntimeError, \ + "Failed to retrieve " #symbolName " from libcuda.so.1"); \ + dlclose(libHandle); \ + return NULL; \ + } \ + return funcHandle; \ + } + +defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, + cuOccupancyMaxActiveClusters); + +defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, + cuTensorMapEncodeTiled); + +static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { + int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, + maxActiveClusters = -1; + int shared = 0; + CUfunction func; + + if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX, + &clusterDimY, &clusterDimZ)) { + return NULL; + } + + // Let each SM have one block + int maxActiveBlocks = 1; + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared)); + Py_END_ALLOW_THREADS; + + CUlaunchAttribute launchAttr[1]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = clusterDimX; + launchAttr[0].value.clusterDim.y = clusterDimY; + launchAttr[0].value.clusterDim.z = clusterDimZ; + CUlaunchConfig config; + config.gridDimX = clusterDimX; + config.gridDimY = maxActiveBlocks * clusterDimY; + config.gridDimZ = clusterDimZ; + config.blockDimX = 128; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared; + config.hStream = 0; + config.numAttrs = 1; + config.attrs = launchAttr; + + static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters, + getCuOccupancyMaxActiveClustersHandle); + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config)); + Py_END_ALLOW_THREADS; + return PyLong_FromLong(maxActiveClusters); +} + +static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { + long size; + if (!PyArg_ParseTuple(args, "l", &size)) { + return NULL; + } + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS; + + // Ensure we have an active context. + CUcontext ctx = NULL; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx)); + if (!ctx) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuDevicePrimaryCtxRetain(&ctx, /*device=*/0)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx)); + } + + // We can't set the fifo size after running a kernel that calls printf. This + // is true even if the set() call is a nop and the new size is the same as the + // old size. + // + // This is unfriendly, so check if the old size matches the new size, and skip + // the set() call if so. + size_t oldSize = 0; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE)); + if (oldSize != size) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size)); + } + + Py_END_ALLOW_THREADS; + Py_INCREF(Py_None); + return Py_None; +} + +// Simple helper to experiment creating TMA descriptors on the host. +// This is a useful to test TMA operations independently. +static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { + unsigned long long global_address; + uint64_t dim; + uint32_t tensorDim; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, + &elementSize, &desc_address)) { + return NULL; + } + uint64_t dims[1] = {dim}; + uint64_t globalStrides[1] = {dim * elementSize}; + uint32_t boxDim[1] = {tensorDim}; + uint32_t elementStrides[1] = {1}; + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); + return NULL; + } + assert((elementSize * tensorDim) >= 32 && "block size too small."); + int rank = 1; + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, + getCuTensorMapEncodeTiledHandle); + CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( + (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, + globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + Py_INCREF(Py_None); + return Py_None; +} + +// Simple helper to experiment creating TMA descriptors on the host. +// This is a useful to test TMA operations independently. +static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { + unsigned long long global_address; + uint64_t dims[2]; + uint32_t tensorDims[2]; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0], + &tensorDims[1], &tensorDims[0], &elementSize, + &desc_address)) { + return NULL; + } + uint64_t globalStrides[2] = {dims[0] * elementSize, + dims[0] * dims[1] * elementSize}; + uint32_t elementStrides[2] = {1, 1}; + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); + } + int rank = 2; + // Swizzling should be picked in codegen but since we need to set it on the + // descriptor we rely on a convention between this function and codegen. + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; + if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + assert(false && "block size too small."); + } + // The bounding box inner dimension must be less than or equal to the swizzle + // size. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + // We clamp the block size and the codegen will emit multiple copy operations. + if (contigDimSizeInByte > 128) { + tensorDims[0] = 128 / elementSize; + } + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, + getCuTensorMapEncodeTiledHandle); + CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( + (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + Py_INCREF(Py_None); + return Py_None; +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided cubin into CUDA driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS, + "Python interface for cuOccupancyMaxActiveClusters function"}, + {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS, + "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which " + "controls how many bytes can be streamed from kernels before data starts " + "being dropped. This inherits all the limitations of this call; in " + "particular it's an error to change this value after launching any kernel " + "that calls printf()."}, + {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"}, + {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"}, + + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_cuda_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, ModuleMethods); + + return m; +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/backend/driver.py b/third_party/enflame/include/triton/third_party/nvidia/backend/driver.py new file mode 100644 index 000000000..5f2621ae5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/backend/driver.py @@ -0,0 +1,576 @@ +import functools +import os +import sysconfig +import hashlib +import subprocess +import tempfile +from pathlib import Path +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.runtime import _allocation +from triton.backends.compiler import GPUTarget +from triton.backends.driver import GPUDriver + +dirname = os.path.dirname(os.path.realpath(__file__)) +include_dir = [os.path.join(dirname, "include")] +libdevice_dir = os.path.join(dirname, "lib") +libraries = ['cuda'] + + +@functools.lru_cache() +def libcuda_dirs(): + env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH") + if env_libcuda_path: + return [env_libcuda_path] + + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() + # each line looks like the following: + # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 + locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line] + dirs = [os.path.dirname(loc) for loc in locs] + env_ld_library_path = os.getenv("LD_LIBRARY_PATH") + if env_ld_library_path and not dirs: + dirs = [dir for dir in env_ld_library_path.split(":") if os.path.exists(os.path.join(dir, "libcuda.so.1"))] + msg = 'libcuda.so cannot found!\n' + if locs: + msg += 'Possible files are located at %s.' % str(locs) + msg += 'Please create a symlink of libcuda.so to any of the files.' + else: + msg += 'Please make sure GPU is set up and then run "/sbin/ldconfig"' + msg += ' (requires sudo) to refresh the linker cache.' + assert any(os.path.exists(os.path.join(path, 'libcuda.so.1')) for path in dirs), msg + return dirs + + +@functools.lru_cache() +def library_dirs(): + return [libdevice_dir, *libcuda_dirs()] + + +@functools.lru_cache() +def platform_key(): + from platform import machine, system, architecture + return ",".join([machine(), system(), *architecture()]) + + +def compile_module_from_src(src, name): + key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] + cache_path = cache.get_file(f"{name}.{ext}") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# ------------------------ +# Utils +# ------------------------ + + +class CudaUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(CudaUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils") + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters + self.set_printf_fifo_size = mod.set_printf_fifo_size + self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor + self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor + + +# ------------------------ +# Launcher +# ------------------------ + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "CUdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + "nvTmaDesc": "CUtensorMap", + }[ty] + + +def make_launcher(constants, signature): + + def _serialize_signature(sig): + if isinstance(sig, tuple): + return ','.join(map(_serialize_signature, sig)) + return sig + + def _extracted_type(ty): + if isinstance(ty, tuple): + val = ','.join(map(_extracted_type, ty)) + return f"[{val}]" + if ty[0] == '*': + return "PyObject*" + if ty in ("constexpr", "nvTmaDesc"): + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + if isinstance(ty, tuple): + val = ''.join(map(format_of, ty)) + return f"({val})" + if ty[0] == '*': + return "O" + if ty in ("constexpr", "nvTmaDesc"): + return "O" + return { + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "L", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty_to_cpp(ty)] + + args_format = ''.join([format_of(ty) for ty in signature.values()]) + format = "iiiKKpOOOOO" + args_format + signature = ','.join(map(_serialize_signature, signature.values())) + signature = list(filter(bool, signature.split(','))) + signature = {i: s for i, s in enumerate(signature)} + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr") + internal_args_list = [] + for i, ty in signature.items(): + if ty[0] == "*": + internal_args_list.append(f"ptr_info{i}.dev_ptr") + elif ty == "nvTmaDesc": + # Note: we have to dereference the pointer + internal_args_list.append(f"*tma_ptr{i}") + elif ty != "constexpr": + internal_args_list.append(f"_arg{i}") + params = range(len(signature)) + + # generate glue code + newline = '\n ' + ptr_decls = [ + f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" + for i, ty in signature.items() + if ty[0] == "*" + ] + tma_decls = [ + f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items() + if ty == "nvTmaDesc" + ] + params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"] + params.append("&global_scratch") + src = f""" +#include \"cuda.h\" +#include +#include +#include + +static inline void gpuAssert(CUresult code, const char *file, int line) +{{ + if (code != CUDA_SUCCESS) + {{ + const char* prefix = "Triton Error [CUDA]: "; + const char* str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + }} +}} + +#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra); + +static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ + // Open the shared library + void* handle = dlopen("libcuda.so.1", RTLD_LAZY); + if (!handle) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); + return NULL; + }} + // Clear any existing error + dlerror(); + cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx"); + // Check for errors + const char *dlsym_error = dlerror(); + if (dlsym_error) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1"); + return NULL; + }} + return cuLaunchKernelExHandle; +}} + +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + void *params[] = {{ {', '.join(params)} }}; + if (gridX*gridY*gridZ > 0) {{ + if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{ + CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); + }} else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {{ + CUlaunchAttribute launchAttr[1]; + CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}}; + launchAttr[0] = coopAttr; + + CUlaunchConfig config; + config.gridDimX = gridX; + config.gridDimY = gridY; + config.gridDimZ = gridZ; + config.blockDimX = 32 * num_warps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared_memory; + config.hStream = stream; + config.attrs = launchAttr; + config.numAttrs = 1; + + static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; + if (cuLaunchKernelExHandle == NULL) {{ + cuLaunchKernelExHandle = getLaunchKernelExHandle(); + }} + CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); + + }} else {{ + CUlaunchAttribute launchAttr[3]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = clusterDimX; + launchAttr[0].value.clusterDim.y = clusterDimY; + launchAttr[0].value.clusterDim.z = clusterDimZ; + launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; + + unsigned numAttrs = 2; + if (0 != launch_cooperative_grid) {{ + CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}}; + launchAttr[2] = coopAttr; + numAttrs = 3; + }} + + CUlaunchConfig config; + config.gridDimX = gridX * clusterDimX; + config.gridDimY = gridY * clusterDimY; + config.gridDimZ = gridZ * clusterDimZ; + config.blockDimX = 32 * num_warps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared_memory; + config.hStream = stream; + config.attrs = launchAttr; + config.numAttrs = numAttrs; + static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; + if (cuLaunchKernelExHandle == NULL) {{ + cuLaunchKernelExHandle = getLaunchKernelExHandle(); + }} + CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); + }} + }} +}} + +typedef struct _DevicePtrInfo {{ + CUdeviceptr dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) + return ptr_info; + uint64_t dev_ptr; + int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == CUDA_ERROR_INVALID_VALUE) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + }} else if (status != CUDA_SUCCESS) {{ + CUDA_CHECK(status); // Catch any other cuda API errors + ptr_info.valid = false; + }} + ptr_info.dev_ptr = dev_ptr; + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static inline CUtensorMap* getTmaDesc(PyObject *obj) {{ + if (sizeof(CUtensorMap*) != 8) {{ + PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation"); + return NULL; + }} + + PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr"); + if (!method_handle) {{ + PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist"); + return NULL; + }} + + PyObject *empty_tuple = PyTuple_New(0); + if (!empty_tuple) {{ + Py_DECREF(method_handle); + PyErr_SetString(PyExc_SystemError, "Internal Python error!"); + return NULL; + }} + PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(method_handle); + if (!method_ret) {{ + PyErr_SetString(PyExc_SystemError, "Internal Python error!"); + return NULL; + }} + + if (!PyLong_Check(method_ret)) {{ + PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int"); + Py_DECREF(method_ret); + return NULL; + }} + + uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret); + Py_DECREF(method_ret); + if (!ptr_as_uint) {{ + PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()"); + return NULL; + }} + if (ptr_as_uint % 64 != 0) {{ + PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned"); + return NULL; + }} + + return (CUtensorMap*)(ptr_as_uint); +}} + +static void ensureCudaContext() {{ + CUcontext pctx; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) {{ + // Ensure device context. + CUdevice device; + CUDA_CHECK(cuDeviceGet(&device, 0)); + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + }} +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes + ensureCudaContext(); + + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + int launch_cooperative_grid; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + PyObject *global_scratch_obj = NULL; + {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, + &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook{args_list})) {{ + return NULL; + }} + + int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + return NULL; + }} + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + CUdeviceptr global_scratch = 0; + if (global_scratch_obj != Py_None) {{ + DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1); + if (!global_scratch_info.valid) {{ + return NULL; + }} + global_scratch = global_scratch_info.dev_ptr; + }} + + // raise exception asap + {newline.join(ptr_decls)} + {newline.join(tma_decls)} + Py_BEGIN_ALLOW_THREADS; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); + Py_END_ALLOW_THREADS; + if (PyErr_Occurred()) {{ + return NULL; + }} + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + + }} + + Py_RETURN_NONE; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src + + +class CudaLauncher(object): + + def __init__(self, src, metadata): + constants = src.constants if hasattr(src, "constants") else dict() + arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x + constants = {arg_idx(idx): value for idx, value in constants.items()} + signature = {idx: value for idx, value in src.signature.items()} + src = make_launcher(constants, signature) + mod = compile_module_from_src(src, "__triton_launcher") + self.launch = mod.launch + self.global_scratch_size = metadata.global_scratch_size + self.global_scratch_align = metadata.global_scratch_align + self.launch_cooperative_grid = metadata.launch_cooperative_grid + + def __call__(self, gridX, gridY, gridZ, stream, function, *args): + if self.global_scratch_size > 0: + grid_size = gridX * gridY * gridZ + alloc_size = grid_size * self.global_scratch_size + global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream) + else: + global_scratch = None + self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args) + + +class CudaDriver(GPUDriver): + + def __init__(self): + self.utils = CudaUtils() # TODO: make static + self.launcher_cls = CudaLauncher + super().__init__() + + def get_current_target(self): + device = self.get_current_device() + capability = self.get_device_capability(device) + capability = capability[0] * 10 + capability[1] + warp_size = 32 + return GPUTarget("cuda", capability, warp_size) + + def get_active_torch_device(self): + import torch + return torch.device("cuda", self.get_current_device()) + + def get_device_interface(self): + import torch + return torch.cuda + + @staticmethod + def is_active(): + try: + import torch + return torch.cuda.is_available() and (torch.version.hip is None) + except ImportError: + return False + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench + + def get_empty_cache_for_benchmark(self): + import torch + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 cache + # doesn't contain any input data before the run + cache_size = 256 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') + + def clear_cache(self, cache): + cache.zero_() diff --git a/third_party/enflame/include/triton/third_party/nvidia/backend/lib/libdevice.10.bc b/third_party/enflame/include/triton/third_party/nvidia/backend/lib/libdevice.10.bc new file mode 100644 index 000000000..b2c75a502 Binary files /dev/null and b/third_party/enflame/include/triton/third_party/nvidia/backend/lib/libdevice.10.bc differ diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/include/CMakeLists.txt new file mode 100644 index 000000000..2ef7aab10 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Dialect) +add_subdirectory(TritonNVIDIAGPUToLLVM) +add_subdirectory(NVGPUToLLVM) diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/CMakeLists.txt new file mode 100644 index 000000000..edeac0660 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(NVGPU) diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..f8932cdc4 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS NVGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=nvgpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=nvgpu) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(NVGPUDialect NVGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(NVGPUOps NVGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(NVGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS NVGPUAttrDefs.td) +mlir_tablegen(NVGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(NVGPUAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(NVGPUAttrDefsIncGen) diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h new file mode 100644 index 000000000..6e238af4f --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_NVGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_NVGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h.inc" +#include "nvidia/include/Dialect/NVGPU/IR/OpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "nvidia/include/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "nvidia/include/Dialect/NVGPU/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace nvgpu {} // namespace nvgpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUAttrDefs.td b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUAttrDefs.td new file mode 100644 index 000000000..c904824ef --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUAttrDefs.td @@ -0,0 +1,33 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVGPU_ATTRDEFS +#define NVGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "NVGPUDialect.td" + +class NVGPU_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUDialect.td b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUDialect.td new file mode 100644 index 000000000..6978173d4 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUDialect.td @@ -0,0 +1,40 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVGPU_DIALECT +#define NVGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def NVGPU_Dialect : Dialect { + let name = "nvgpu"; + let cppNamespace = "::mlir::triton::nvgpu"; + + let description = [{ + NVGPU Dialect. + }]; + + let dependentDialects = [ + "mlir::LLVM::LLVMDialect" + ]; +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td new file mode 100644 index 000000000..630564e87 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -0,0 +1,221 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef NVGPU_OPS +#define NVGPU_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "NVGPUDialect.td" +include "NVGPUAttrDefs.td" + +def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>; +def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>; +def LLVM_PointerTensorMemory : LLVM_PointerInAddressSpace<6>; + + +def NVGPU_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def NVGPU_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; +def NVGPU_ScalarLike : AnyTypeOf<[NVGPU_Float, NVGPU_Int]>; + + +def NVGPU_MemSemanticAttr : I32EnumAttr< + "MemSemantic", "", + [ + I32EnumAttrCase<"RELAXED", 1, "relaxed">, + I32EnumAttrCase<"ACQUIRE", 2, "acquire">, + I32EnumAttrCase<"RELEASE", 3, "release">, + I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">, + ]> { + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def NVGPU_MemSyncScopeAttr : I32EnumAttr< + "MemSyncScope", "", + [ + I32EnumAttrCase<"GPU", 1, "gpu">, + I32EnumAttrCase<"CTA", 2, "cta">, + I32EnumAttrCase<"SYSTEM", 3, "sys">, + ]> { + let cppNamespace = "::mlir::triton::nvgpu"; +} + +class NVGPU_Op traits = []> : + LLVM_OpBase; + +def NVGPU_WGMMAFenceOp : NVGPU_Op<"wgmma_fence", []> { + let assemblyFormat = "attr-dict"; +} + + +def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> { + let assemblyFormat = "attr-dict"; +} + +def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", [DeclareOpInterfaceMethods, + AllTypesMatch<["input", "output"]>]> { + let arguments = (ins LLVM_AnyStruct:$input, I32Attr:$pendings); + let results = (outs LLVM_AnyStruct:$output); + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +def MBarrier_ArriveTypeAttr : I32EnumAttr<"MBarriveType", + "mbarrier arrive type, either 'normal', 'expect_tx', 'cp_async'", + [ + I32EnumAttrCase<"normal", 0>, + I32EnumAttrCase<"cp_async", 1>, + I32EnumAttrCase<"expect_tx", 2>, + I32EnumAttrCase<"remote", 3>, + ]>{ + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> { + let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$pred, Optional:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr:$txCount); + let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? attr-dict `:` type($mbarrier)"; +} + +def NVGPU_NamedBarrierArriveOp : NVGPU_Op<"bar_arrive", []> { + let arguments = (ins I32:$bar, I32:$numThreads); + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def NVGPU_NamedBarrierWaitOp : NVGPU_Op<"bar_wait", []> { + let arguments = (ins I32:$bar, I32:$numThreads); + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def WGMMA_LayoutAttr : I32EnumAttr<"WGMMALayout", + "wgmma layout, either 'row' or 'col'", + [ + I32EnumAttrCase<"row", 0>, + I32EnumAttrCase<"col", 1> + ]>{ + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType", + "wgmma operand type, either 's8', 's32', 'e4m3', 'e5m2', 'f16', 'bf16', 'tf32', or 'f32'", + [ + I32EnumAttrCase<"s8", 0>, + I32EnumAttrCase<"s32", 1>, + I32EnumAttrCase<"e4m3", 2>, + I32EnumAttrCase<"e5m2", 3>, + I32EnumAttrCase<"f16", 4>, + I32EnumAttrCase<"bf16", 5>, + I32EnumAttrCase<"tf32", 6>, + I32EnumAttrCase<"f32", 7> + ]>{ + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">; + +def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { + let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, I1:$useC, Optional:$opC, + I32Attr:$m, I32Attr:$n, I32Attr:$k, + WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, + WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); + let results = (outs LLVM_AnyStruct:$res); + let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; +} + +def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> { + let arguments = (ins BoolAttr:$bCluster); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_ClusterArriveOp : NVGPU_Op<"cluster_arrive", []> { + let arguments = (ins I1Attr:$relaxed); + + let assemblyFormat = "attr-dict"; +} + +def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> { + let assemblyFormat = "attr-dict"; +} + +def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> { + let arguments = ( + ins LLVM_PointerShared:$addr, + Variadic:$vals, + UnitAttr:$trans + ); + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + +def NVGPU_LoadMatrixOp : NVGPU_Op<"ldmatrix", [MemoryEffects<[MemRead]>]> { + let arguments = ( + ins LLVM_PointerShared:$addr, + UnitAttr:$trans + ); + let results = (outs LLVM_AnyStruct:$result); + let assemblyFormat = "$addr attr-dict `:` functional-type($addr, $result)"; +} + +def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> { + let results = (outs I32:$result); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_LoadAcquireOp : NVGPU_Op<"ld_acquire", [MemoryEffects<[MemRead]>]> { + let arguments = ( + ins LLVM_PointerGlobal:$addr, + Optional:$mask, + NVGPU_MemSemanticAttr:$sem, + NVGPU_MemSyncScopeAttr:$scope + ); + let results = (outs NVGPU_ScalarLike:$result); + let assemblyFormat = "$sem `,` $scope `,` $addr (`,` $mask^)? attr-dict `:` functional-type($addr, $result)"; +} + +def NVGPU_WarpIdOp : NVGPU_Op<"warp_id", [Pure]> { + let results = (outs I32:$result); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_CanonicalWarpIdOp : NVGPU_Op<"canonical_warp_id", [Pure]> { + let results = (outs I32:$result); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_TensorMemoryBaseAddress : NVGPU_Op<"tensor_memory_base", [Pure]> { + let description = [{ + Op to represent base address of tensor memory in a kernel. + This is used to simplify lowering from TritonGPU to LLVM. + }]; + let results = (outs LLVM_PointerTensorMemory:$result); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_RegAllocOp : NVGPU_Op<"reg_alloc", []> { + let arguments = (ins I32Attr: $regCount); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_RegDeallocOp : NVGPU_Op<"reg_dealloc", []> { + let arguments = (ins I32Attr: $regCount); + let assemblyFormat = "attr-dict"; +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..f89521768 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name NVGPUToLLVM) +add_public_tablegen_target(NVGPUConversionPassIncGen) diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h b/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h new file mode 100644 index 000000000..12ac194a8 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h @@ -0,0 +1,40 @@ +#ifndef TRITON_CONVERSION_NVGPU_TO_LLVM_PASS_H +#define TRITON_CONVERSION_NVGPU_TO_LLVM_PASS_H + +#include +#include +#include +#include + +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +namespace nvgpu { + +using Constraints = std::vector; +using OperandsAndConstraints = std::vector>; + +LogicalResult +rewriteAsPtxAsm(mlir::Operation *op, mlir::PatternRewriter &rewriter, + std::string ptxAsm, + const OperandsAndConstraints &operandsAndConstraints = {}, + const Constraints &outputConstraints = {}); + +} // namespace nvgpu + +std::unique_ptr> createConvertNVGPUToLLVMPass(); + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/Passes.h b/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/Passes.h new file mode 100644 index 000000000..a2d265356 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/Passes.h @@ -0,0 +1,17 @@ +#ifndef NVGPU_CONVERSION_PASSES_H +#define NVGPU_CONVERSION_PASSES_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "nvidia/include/NVGPUToLLVM/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/Passes.td b/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/Passes.td new file mode 100644 index 000000000..345e6408c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/NVGPUToLLVM/Passes.td @@ -0,0 +1,19 @@ +#ifndef NVGPU_CONVERSION_PASSES +#define NVGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertNVGPUToLLVM : Pass<"convert-nv-gpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert NVGPU to LLVM"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertNVGPUToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::LLVM::LLVMDialect", + "mlir::NVVM::NVVMDialect", + "mlir::triton::nvgpu::NVGPUDialect"]; +} + +#endif // NVGPU_CONVERSION_PASSES diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..193f7aff0 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonNVIDIAGPUToLLVM) +add_public_tablegen_target(TritonNVIDIAGPUConversionPassIncGen) diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/PTXAsmFormat.h b/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/PTXAsmFormat.h new file mode 100644 index 000000000..ca86594d8 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/PTXAsmFormat.h @@ -0,0 +1,347 @@ +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_PTX_ASM_FORMAT_H_ +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_PTX_ASM_FORMAT_H_ + +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir { +class ConversionPatternRewriter; +class Location; + +namespace triton { +using llvm::StringRef; + +struct PTXInstr; +struct PTXInstrCommon; +struct PTXInstrExecution; + +// PTXBuilder helps to manage a PTX asm program consists of one or multiple +// instructions. +// +// A helper for building an ASM program, the objective of PTXBuilder is to give +// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear. +// Currently, several factors are introduced to reduce the need for mixing +// string and C++ if-else code. +// +// Usage: +// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k), +// "b"(p)); +// +// PTXBuilder builder; +// auto& add = builder.create<>(); +// add.predicate(pVal).o("lo").o("u32"); // add any suffix +// // predicate here binds %0 to pVal, pVal is a mlir::Value +// +// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal +// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal +// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal +// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate +// +// To get the asm code: +// builder.dump() +// +// To get all the mlir::Value used in the PTX code, +// +// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal} +// +// To get the string containing all the constraints with "," separated, +// builder.getConstraints() // get "=r,r,k" +// +// PTXBuilder can build a PTX asm with multiple instructions, sample code: +// +// PTXBuilder builder; +// auto& mov = builder.create("mov"); +// auto& cp = builder.create("cp"); +// mov(...); +// cp(...); +// This will get a PTX code with two instructions. +// +// Similar to a C function, a declared PTXInstr instance can be launched +// multiple times with different operands, e.g. +// +// auto& mov = builder.create("mov"); +// mov(... some operands ...); +// mov(... some different operands ...); +// +// Finally, we will get a PTX code with two mov instructions. +// +// There are several derived instruction type for typical instructions, for +// example, the PtxIOInstr for ld and st instructions. +struct PTXBuilder { + struct Operand { + std::string constraint; + Value value; + int idx{-1}; + llvm::SmallVector list; + std::function repr; + + // for list + Operand() = default; + Operand(const Operation &) = delete; + Operand(Value value, StringRef constraint) + : constraint(constraint), value(value) {} + + bool isList() const { return !value && constraint.empty(); } + + Operand *listAppend(Operand *arg) { + list.push_back(arg); + return this; + } + + Operand *listGet(size_t nth) const { + assert(nth < list.size()); + return list[nth]; + } + + std::string dump() const; + }; + + template + INSTR *create(Args &&...args) { + instrs.emplace_back(std::make_unique(this, args...)); + return static_cast(instrs.back().get()); + } + + // Create a list of operands. + Operand *newListOperand() { return newOperand(); } + + Operand *newListOperand(ArrayRef> items) { + auto *list = newOperand(); + for (auto &item : items) { + list->listAppend(newOperand(item.first, item.second)); + } + return list; + } + + Operand *newListOperand(unsigned count, mlir::Value val, + const std::string &constraint) { + auto *list = newOperand(); + for (unsigned i = 0; i < count; ++i) { + list->listAppend(newOperand(val, constraint)); + } + return list; + } + + Operand *newListOperand(unsigned count, const std::string &constraint) { + auto *list = newOperand(); + for (unsigned i = 0; i < count; ++i) { + list->listAppend(newOperand(constraint)); + } + return list; + } + + // Create a new operand. It will not add to operand list. + // @value: the MLIR value bind to this operand. + // @constraint: ASM operand constraint, .e.g. "=r" + // @formatter: extra format to represent this operand in ASM code, default is + // "%{0}".format(operand.idx). + Operand *newOperand(mlir::Value value, StringRef constraint, + std::function formatter = nullptr); + + // Create a new operand which is written to, that is, the constraint starts + // with "=", e.g. "=r". + // If the operand will be used in predicated execution, + // users may want to initialize it before use. + // Otherwise if the register is only used in the true branch or the false + // branch but not both, the register is undefined and ptxas can perform + // aggressive optimizations that may lead to incorrect results. + Operand *newOperand(StringRef constraint, bool init = false); + + // Create a new operand that is tied to a previous operand. In this case the + // asm would be permitted to write to an input register. Instead of providing + // constraint code for this operand, the constraint code of the tied operand + // is used. + Operand *newOperand(unsigned operandIndex); + + // Create a constant integer operand. + Operand *newConstantOperand(int64_t v); + // Create a constant operand with explicit code specified. + Operand *newConstantOperand(const std::string &v); + + Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0); + + llvm::SmallVector getAllArgs() const; + + llvm::SmallVector getAllMLIRArgs() const; + + std::string getConstraints() const; + + std::string dump() const; + + mlir::Value launch(OpBuilder &rewriter, Location loc, Type resTy, + bool hasSideEffect = true, bool isAlignStack = false, + ArrayRef attrs = {}) const; + +private: + Operand *newOperand() { + argArchive.emplace_back(std::make_unique()); + return argArchive.back().get(); + } + + void initOperand(Operand *opr); + + // Make the operands in argArchive follow the provided \param order. + void reorderArgArchive(ArrayRef order) { + assert(order.size() == argArchive.size()); + // The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but + // it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are + // determined by PTX code snippet passed from external. + sort(argArchive.begin(), argArchive.end(), + [&](std::unique_ptr &a, std::unique_ptr &b) { + auto ida = std::find(order.begin(), order.end(), a.get()); + auto idb = std::find(order.begin(), order.end(), b.get()); + assert(ida != order.end()); + assert(idb != order.end()); + return ida < idb; + }); + } + + friend struct PTXInstr; + friend struct PTXInstrCommon; + +protected: + llvm::SmallVector, 6> argArchive; + llvm::SmallVector, 2> instrs; + llvm::SmallVector, 4> executions; + int oprCounter{}; +}; + +// PTX instruction common interface. +// Put the generic logic for all the instructions here. +struct PTXInstrCommon { + explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {} + + using Operand = PTXBuilder::Operand; + + // clang-format off + PTXInstrExecution& operator()() { return call({}); } + PTXInstrExecution& operator()(Operand* a) { return call({a}); } + PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); } + PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); } + PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}); } + PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}); } + PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}); } + PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}); } + // clang-format on + + // Set operands of this instruction. + PTXInstrExecution &operator()(llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs = false); + +protected: + // "Call" the instruction with operands. + // \param oprs The operands of this instruction. + // \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments + // to the inline Asm without generating the operand ids(such as $0, $1) in PTX + // code. + PTXInstrExecution &call(llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs = false); + + PTXBuilder *builder{}; + llvm::SmallVector instrParts; + + friend struct PTXInstrExecution; +}; + +template struct PTXInstrBase : public PTXInstrCommon { + using Operand = PTXBuilder::Operand; + + explicit PTXInstrBase(PTXBuilder *builder, const std::string &name) + : PTXInstrCommon(builder) { + o(name); + } + + // Append a suffix to the instruction. + // e.g. PTXInstr("add").o("s32") get a add.s32. + // A predicate is used to tell whether to apply the suffix, so that no if-else + // code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will + // get a `add.s32` if isS32 is true. + ConcreteT &o(const std::string &suffix, bool predicate = true) { + if (predicate) + instrParts.push_back(suffix); + return *static_cast(this); + } +}; + +struct PTXInstr : public PTXInstrBase { + using PTXInstrBase::PTXInstrBase; + + // Append a ".global" to the instruction. + PTXInstr &global(); + + // Append a ".shared" to the instruction. + PTXInstr &shared(); + + // Append a ".v[0-9]+" to the instruction + PTXInstr &v(int vecWidth, bool predicate = true); + + // Append a".b[0-9]+" to the instruction + PTXInstr &b(int width); +}; + +// Record the operands and context for "launching" a PtxInstr. +struct PTXInstrExecution { + using Operand = PTXBuilder::Operand; + + llvm::SmallVector argsInOrder; + + PTXInstrExecution() = default; + explicit PTXInstrExecution(PTXInstrCommon *instr, + llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs) + : argsInOrder(oprs.begin(), oprs.end()), instr(instr), + onlyAttachMLIRArgs(onlyAttachMLIRArgs) {} + + // Prefix a predicate to the instruction. + PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") { + assert(value); + pred = instr->builder->newOperand(value, constraint); + return *this; + } + + // Prefix a predicate to the instruction, if non-null + PTXInstrExecution &maybePredicate(mlir::Value value, + StringRef constraint = "b") { + if (value) + predicate(value, constraint); + return *this; + } + + // Prefix a !predicate to the instruction. + PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) { + pred = instr->builder->newOperand(value, constraint); + pred->repr = [](int idx) { return "@!$" + std::to_string(idx); }; + return *this; + } + + std::string dump() const; + + SmallVector getArgList() const; + + PTXInstrCommon *instr{}; + Operand *pred{}; + bool onlyAttachMLIRArgs{}; +}; + +/// ====== Some instruction wrappers ====== +// We add the wrappers to make the usage more intuitive by avoiding mixing the +// PTX code with some trivial C++ code. + +struct PTXCpAsyncLoadInstr : PTXInstrBase { + explicit PTXCpAsyncLoadInstr(PTXBuilder *builder, + triton::CacheModifier modifier) + : PTXInstrBase(builder, "cp.async") { + o(triton::stringifyCacheModifier(modifier).str()); + o("shared"); + o("global"); + } +}; + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h b/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h new file mode 100644 index 000000000..d8542065d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h @@ -0,0 +1,33 @@ +#ifndef TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_PASSES_H +#define TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +#define GEN_PASS_DECL +#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h.inc" + +std::unique_ptr> createConvertTritonGPUToLLVMPass(); +std::unique_ptr> +createConvertTritonGPUToLLVMPass(int32_t computeCapability); +std::unique_ptr> +createConvertTritonGPUToLLVMPass(int32_t computeCapability, int32_t ptxVersion); + +#define GEN_PASS_REGISTRATION +#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h.inc" + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td b/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td new file mode 100644 index 000000000..ea753f578 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td @@ -0,0 +1,45 @@ +#ifndef TRITONGPU_CONVERSION_PASSES +#define TRITONGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert TritonGPU to LLVM"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::gpu::GPUDialect", + "mlir::scf::SCFDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::nvgpu::NVGPUDialect", + "mlir::NVVM::NVVMDialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"80", + "device compute capability">, + Option<"ptxVersion", "ptx-version", + "int32_t", /*default*/"80", + "PTX version">, + ]; +} + +def ConvertWarpSpecializeToLLVM : Pass<"convert-warp-specialize-to-llvm", "mlir::ModuleOp"> { + let summary = "lower `ttg.warp_specialize` to LLVM"; + let description = [{ + The `convert-warp-specialize-to-llvm` pass performs codegen for warp + specialization. It is a function-level transformation that rewrites + warp-specialized kernels by using shared memory and barriers to communicate + states between the default warpgroup and the worker warps. + }]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::NVVM::NVVMDialect"]; +} + +#endif // TRITONGPU_CONVERSION_PASSES diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h new file mode 100644 index 000000000..6d1c3c06a --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h @@ -0,0 +1,17 @@ +#ifndef TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H +#define TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H + +#include "mlir/IR/Operation.h" + +namespace mlir { +namespace triton { +namespace NVIDIA { + +/// Return true if we can skip a barrier synchronization between two operations +/// even if they access the same shared memory. +bool canSkipBarSync(Operation *before, Operation *after); +} // namespace NVIDIA +} // namespace triton +} // namespace mlir + +#endif // TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/cublas_instance.h b/third_party/enflame/include/triton/third_party/nvidia/include/cublas_instance.h new file mode 100644 index 000000000..d79d4d76b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/cublas_instance.h @@ -0,0 +1,213 @@ +#ifndef TRITON_CUBLAS_INSTANCE_H +#define TRITON_CUBLAS_INSTANCE_H + +#include "cublas_types.h" +#include +#include +#include + +class CublasLtInstance { + // Typedefs for cublas functions + typedef cublasStatus_t (*cublasLtCreate_t)(cublasLtHandle_t *); + typedef cublasStatus_t (*cublasLtDestroy_t)(cublasLtHandle_t); + typedef cublasStatus_t (*cublasLtMatmulDescCreate_t)(cublasLtMatmulDesc_t *, + cublasComputeType_t, + cudaDataType_t); + typedef cublasStatus_t (*cublasLtMatmulDescDestroy_t)(cublasLtMatmulDesc_t); + typedef cublasStatus_t (*cublasLtMatmulDescSetAttribute_t)( + cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, const void *, + size_t); + typedef cublasStatus_t (*cublasLtMatrixLayoutCreate_t)( + cublasLtMatrixLayout_t *, cudaDataType_t, uint64_t, uint64_t, int64_t); + typedef cublasStatus_t (*cublasLtMatrixLayoutDestroy_t)( + cublasLtMatrixLayout_t); + typedef cublasStatus_t (*cublasLtMatmulPreferenceCreate_t)( + cublasLtMatmulPreference_t *); + typedef cublasStatus_t (*cublasLtMatmulPreferenceDestroy_t)( + cublasLtMatmulPreference_t); + typedef cublasStatus_t (*cublasLtMatmulPreferenceSetAttribute_t)( + cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, + const void *, size_t); + typedef cublasStatus_t (*cublasLtMatmulAlgoGetHeuristic_t)( + cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t, + cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, + cublasLtMatmulPreference_t, int, cublasLtMatmulHeuristicResult_t *, + int *); + typedef cublasStatus_t (*cublasLtMatmul_t)( + cublasLtHandle_t, cublasLtMatmulDesc_t, const void *, const void *, + const cublasLtMatrixLayout_t, const void *, const cublasLtMatrixLayout_t, + const void *, const void *, const cublasLtMatrixLayout_t, void *, + const cublasLtMatrixLayout_t, const cublasLtMatmulAlgo_t *, void *, + size_t, cudaStream_t); + + static constexpr const char *name = "libcublas.so"; + + cublasLtCreate_t cublasLtCreate; + cublasLtDestroy_t cublasLtDestroy; + cublasLtMatmulDescCreate_t cublasLtMatmulDescCreate; + cublasLtMatmulDescDestroy_t cublasLtMatmulDescDestroy; + cublasLtMatmulDescSetAttribute_t cublasLtMatmulDescSetAttribute; + cublasLtMatrixLayoutCreate_t cublasLtMatrixLayoutCreate; + cublasLtMatrixLayoutDestroy_t cublasLtMatrixLayoutDestroy; + cublasLtMatmulPreferenceCreate_t cublasLtMatmulPreferenceCreate; + cublasLtMatmulPreferenceDestroy_t cublasLtMatmulPreferenceDestroy; + cublasLtMatmulPreferenceSetAttribute_t cublasLtMatmulPreferenceSetAttribute; + cublasLtMatmulAlgoGetHeuristic_t cublasLtMatmulAlgoGetHeuristic; + cublasLtMatmul_t cublasLtMatmul; + + void *dylibHandle = nullptr; + cublasLtHandle_t ltHandle; + + void *workspace = nullptr; + size_t workspaceSize = 0; + + cublasLtMatmulPreference_t preference = NULL; + + void loadCublasDylib() { + if (dylibHandle == nullptr) { + // First reuse the existing handle + dylibHandle = dlopen(name, RTLD_NOLOAD); + } + if (dylibHandle == nullptr) { + // If not found, try to load it + dylibHandle = dlopen(name, RTLD_LOCAL | RTLD_LAZY); + } + if (dylibHandle == nullptr) { + throw std::runtime_error("Could not find `" + std::string(name) + + "`. Make sure it is in your " + "LD_LIBRARY_PATH."); + } + dlerror(); // Clear any existing error + + cublasLtCreate = (cublasLtCreate_t)dlsym(dylibHandle, "cublasLtCreate"); + cublasLtDestroy = (cublasLtDestroy_t)dlsym(dylibHandle, "cublasLtDestroy"); + cublasLtMatmulDescCreate = (cublasLtMatmulDescCreate_t)dlsym( + dylibHandle, "cublasLtMatmulDescCreate"); + cublasLtMatmulDescDestroy = (cublasLtMatmulDescDestroy_t)dlsym( + dylibHandle, "cublasLtMatmulDescDestroy"); + cublasLtMatmulDescSetAttribute = (cublasLtMatmulDescSetAttribute_t)dlsym( + dylibHandle, "cublasLtMatmulDescSetAttribute"); + cublasLtMatrixLayoutCreate = (cublasLtMatrixLayoutCreate_t)dlsym( + dylibHandle, "cublasLtMatrixLayoutCreate"); + cublasLtMatrixLayoutDestroy = (cublasLtMatrixLayoutDestroy_t)dlsym( + dylibHandle, "cublasLtMatrixLayoutDestroy"); + cublasLtMatmulPreferenceCreate = (cublasLtMatmulPreferenceCreate_t)dlsym( + dylibHandle, "cublasLtMatmulPreferenceCreate"); + cublasLtMatmulPreferenceDestroy = (cublasLtMatmulPreferenceDestroy_t)dlsym( + dylibHandle, "cublasLtMatmulPreferenceDestroy"); + cublasLtMatmulPreferenceSetAttribute = + (cublasLtMatmulPreferenceSetAttribute_t)dlsym( + dylibHandle, "cublasLtMatmulPreferenceSetAttribute"); + cublasLtMatmulAlgoGetHeuristic = (cublasLtMatmulAlgoGetHeuristic_t)dlsym( + dylibHandle, "cublasLtMatmulAlgoGetHeuristic"); + cublasLtMatmul = (cublasLtMatmul_t)dlsym(dylibHandle, "cublasLtMatmul"); + + const char *dlsym_error = dlerror(); + if (dlsym_error) { + throw std::runtime_error("Could not load symbol from `" + + std::string(name) + + "`: " + std::string(dlsym_error)); + } + } + + void unloadCublasDylib() { dlclose(dylibHandle); } + + void successOrExit(cublasStatus_t status) { + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("cuBLAS Error: " + std::to_string(status) + + "\n"); + } + } + + // Simple wrapper around the cublasLtMatmul function + void matmul_impl(int m, int n, int k, uint64_t A, uint64_t B, uint64_t D, + cudaDataType_t dtype) { + cublasLtMatmulDesc_t matmulDesc = NULL; + + cublasOperation_t transa = CUBLAS_OP_T; + cublasOperation_t transb = CUBLAS_OP_N; + + int8_t fastAccum = 1; + + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, + Ddesc = NULL; + + int returnedResults = 0; + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + + successOrExit( + cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + successOrExit(cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); + successOrExit(cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); + if (dtype == CUDA_R_8F_E4M3) { + successOrExit(cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccum, + sizeof(fastAccum))); + } + + successOrExit(cublasLtMatrixLayoutCreate(&Adesc, dtype, k, m, k)); + successOrExit(cublasLtMatrixLayoutCreate(&Bdesc, dtype, k, n, k)); + successOrExit(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_16F, m, n, m)); + successOrExit(cublasLtMatrixLayoutCreate(&Ddesc, dtype, m, n, m)); + + successOrExit(cublasLtMatmulAlgoGetHeuristic( + ltHandle, matmulDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, + &heuristicResult, &returnedResults)); + if (returnedResults == 0) { + throw std::runtime_error( + "No valid algorithm found by cublasLtMatmulAlgoGetHeuristic"); + } + + float alpha = 1.0f; + float beta = 0.0f; + successOrExit(cublasLtMatmul(ltHandle, matmulDesc, &alpha, (void *)A, Adesc, + (void *)B, Bdesc, &beta, nullptr, Cdesc, + (void *)D, Ddesc, &heuristicResult.algo, + (void *)workspace, workspaceSize, 0)); + if (Ddesc) + successOrExit(cublasLtMatrixLayoutDestroy(Ddesc)); + if (Cdesc) + successOrExit(cublasLtMatrixLayoutDestroy(Cdesc)); + if (Bdesc) + successOrExit(cublasLtMatrixLayoutDestroy(Bdesc)); + if (Adesc) + successOrExit(cublasLtMatrixLayoutDestroy(Adesc)); + if (matmulDesc) + successOrExit(cublasLtMatmulDescDestroy(matmulDesc)); + } + +public: + CublasLtInstance(uint64_t workspace, size_t workspaceSize) + : workspace((void *)workspace), workspaceSize(workspaceSize) { + loadCublasDylib(); + cublasLtCreate(<Handle); + + successOrExit(cublasLtMatmulPreferenceCreate(&preference)); + successOrExit(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, + sizeof(workspaceSize))); + } + ~CublasLtInstance() { + if (preference) + successOrExit(cublasLtMatmulPreferenceDestroy(preference)); + + cublasLtDestroy(ltHandle); + unloadCublasDylib(); + } + + // C = A * B + // Matrix B needs to be transposed, while matrix A does not. The function + // *will-not* transpose the matrices, so the caller is responsible for + // ensuring that the matrices are in the correct format and have the correct + // dimensions. + void matmul(int m, int n, int k, uint64_t A, uint64_t B, uint64_t C, + cudaDataType_t dtype) { + // CUDA is column-major, while triton is row-major, therefore we need to + // reverse the order of the matrices ( A * B = (B^T * A^T)^T ). + matmul_impl(n, m, k, B, A, C, dtype); + } +}; + +#endif // TRITON_CUBLAS_INSTANCE_H diff --git a/third_party/enflame/include/triton/third_party/nvidia/include/cublas_types.h b/third_party/enflame/include/triton/third_party/nvidia/include/cublas_types.h new file mode 100644 index 000000000..74c18c68f --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/include/cublas_types.h @@ -0,0 +1,151 @@ +#ifndef TRITON_CUBLAS_TYPES_H +#define TRITON_CUBLAS_TYPES_H + +// Forward declarations of cuBLAS types and functions. + +/* CUBLAS status type returns */ +typedef enum { + CUBLAS_STATUS_SUCCESS = 0, + CUBLAS_STATUS_NOT_INITIALIZED = 1, + CUBLAS_STATUS_ALLOC_FAILED = 3, + CUBLAS_STATUS_INVALID_VALUE = 7, + CUBLAS_STATUS_ARCH_MISMATCH = 8, + CUBLAS_STATUS_MAPPING_ERROR = 11, + CUBLAS_STATUS_EXECUTION_FAILED = 13, + CUBLAS_STATUS_INTERNAL_ERROR = 14, + CUBLAS_STATUS_NOT_SUPPORTED = 15, + CUBLAS_STATUS_LICENSE_ERROR = 16 +} cublasStatus_t; + +typedef enum { + CUBLAS_COMPUTE_16F = 64, /* half - default */ + CUBLAS_COMPUTE_16F_PEDANTIC = 65, /* half - pedantic */ + CUBLAS_COMPUTE_32F = 68, /* float - default */ + CUBLAS_COMPUTE_32F_PEDANTIC = 69, /* float - pedantic */ + CUBLAS_COMPUTE_32F_FAST_16F = + 74, /* float - fast, allows down-converting inputs to half or TF32 */ + CUBLAS_COMPUTE_32F_FAST_16BF = + 75, /* float - fast, allows down-converting inputs to bfloat16 or TF32 */ + CUBLAS_COMPUTE_32F_FAST_TF32 = + 77, /* float - fast, allows down-converting inputs to TF32 */ + CUBLAS_COMPUTE_64F = 70, /* double - default */ + CUBLAS_COMPUTE_64F_PEDANTIC = 71, /* double - pedantic */ + CUBLAS_COMPUTE_32I = 72, /* signed 32-bit int - default */ + CUBLAS_COMPUTE_32I_PEDANTIC = 73, /* signed 32-bit int - pedantic */ +} cublasComputeType_t; + +typedef enum { + CUBLASLT_MATMUL_DESC_COMPUTE_TYPE = 0, + CUBLASLT_MATMUL_DESC_SCALE_TYPE = 1, + CUBLASLT_MATMUL_DESC_POINTER_MODE = 2, + CUBLASLT_MATMUL_DESC_TRANSA = 3, + CUBLASLT_MATMUL_DESC_TRANSB = 4, + CUBLASLT_MATMUL_DESC_TRANSC = 5, + CUBLASLT_MATMUL_DESC_FILL_MODE = 6, + CUBLASLT_MATMUL_DESC_EPILOGUE = 7, + CUBLASLT_MATMUL_DESC_BIAS_POINTER = 8, + CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE = 10, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER = 11, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD = 12, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE = 13, + CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE = 14, + CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET = 15, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER = 17, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER = 18, + CUBLASLT_MATMUL_DESC_C_SCALE_POINTER = 19, + CUBLASLT_MATMUL_DESC_D_SCALE_POINTER = 20, + CUBLASLT_MATMUL_DESC_AMAX_D_POINTER = 21, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE = 22, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER = 23, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_AMAX_POINTER = 24, + CUBLASLT_MATMUL_DESC_FAST_ACCUM = 25, + CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE = 26, + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS = 27, + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS = 28, + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER = 29, + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER = 30, +} cublasLtMatmulDescAttributes_t; + +typedef enum { + CUBLAS_OP_N = 0, + CUBLAS_OP_T = 1, + CUBLAS_OP_C = 2, + CUBLAS_OP_HERMITAN = 2, /* synonym if CUBLAS_OP_C */ + CUBLAS_OP_CONJG = + 3 /* conjugate, placeholder - not supported in the current release */ +} cublasOperation_t; + +typedef enum { + CUBLASLT_MATMUL_PREF_SEARCH_MODE = 0, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES = 1, + CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK = 3, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES = 5, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES = 6, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES = 7, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES = 8, + CUBLASLT_MATMUL_PREF_MAX_WAVES_COUNT = 9, + CUBLASLT_MATMUL_PREF_IMPL_MASK = 12, +} cublasLtMatmulPreferenceAttributes_t; +typedef struct { + uint64_t data[8]; +} cublasLtMatrixLayoutOpaque_t; +typedef cublasLtMatrixLayoutOpaque_t *cublasLtMatrixLayout_t; + +typedef struct { + uint64_t data[8]; +} cublasLtMatmulPreferenceOpaque_t; +typedef cublasLtMatmulPreferenceOpaque_t *cublasLtMatmulPreference_t; + +typedef struct { + uint64_t data[8]; +} cublasLtMatmulAlgo_t; + +typedef struct { + cublasLtMatmulAlgo_t algo; + size_t workspaceSize; + cublasStatus_t state; + float wavesCount; + int reserved[4]; +} cublasLtMatmulHeuristicResult_t; + +typedef enum cudaDataType_t { + CUDA_R_16F = 2, /* real as a half */ + CUDA_C_16F = 6, /* complex as a pair of half numbers */ + CUDA_R_16BF = 14, /* real as a nv_bfloat16 */ + CUDA_C_16BF = 15, /* complex as a pair of nv_bfloat16 numbers */ + CUDA_R_32F = 0, /* real as a float */ + CUDA_C_32F = 4, /* complex as a pair of float numbers */ + CUDA_R_64F = 1, /* real as a double */ + CUDA_C_64F = 5, /* complex as a pair of double numbers */ + CUDA_R_4I = 16, /* real as a signed 4-bit int */ + CUDA_C_4I = 17, /* complex as a pair of signed 4-bit int numbers */ + CUDA_R_4U = 18, /* real as a unsigned 4-bit int */ + CUDA_C_4U = 19, /* complex as a pair of unsigned 4-bit int numbers */ + CUDA_R_8I = 3, /* real as a signed 8-bit int */ + CUDA_C_8I = 7, /* complex as a pair of signed 8-bit int numbers */ + CUDA_R_8U = 8, /* real as a unsigned 8-bit int */ + CUDA_C_8U = 9, /* complex as a pair of unsigned 8-bit int numbers */ + CUDA_R_16I = 20, /* real as a signed 16-bit int */ + CUDA_C_16I = 21, /* complex as a pair of signed 16-bit int numbers */ + CUDA_R_16U = 22, /* real as a unsigned 16-bit int */ + CUDA_C_16U = 23, /* complex as a pair of unsigned 16-bit int numbers */ + CUDA_R_32I = 10, /* real as a signed 32-bit int */ + CUDA_C_32I = 11, /* complex as a pair of signed 32-bit int numbers */ + CUDA_R_32U = 12, /* real as a unsigned 32-bit int */ + CUDA_C_32U = 13, /* complex as a pair of unsigned 32-bit int numbers */ + CUDA_R_64I = 24, /* real as a signed 64-bit int */ + CUDA_C_64I = 25, /* complex as a pair of signed 64-bit int numbers */ + CUDA_R_64U = 26, /* real as a unsigned 64-bit int */ + CUDA_C_64U = 27, /* complex as a pair of unsigned 64-bit int numbers */ + CUDA_R_8F_E4M3 = 28, /* real as a nv_fp8_e4m3 */ + CUDA_R_8F_E5M2 = 29, /* real as a nv_fp8_e5m2 */ +} cudaDataType; + +struct cublasContext; +typedef struct cublasLtContext *cublasLtHandle_t; +struct cublasLtMatmulDescOpaque_t; +typedef cublasLtMatmulDescOpaque_t *cublasLtMatmulDesc_t; +struct CUstream_st; +typedef struct CUstream_st *cudaStream_t; + +#endif // TRITON_CUBLAS_TYPES_H diff --git a/third_party/enflame/include/triton/third_party/nvidia/language/cuda/__init__.py b/third_party/enflame/include/triton/third_party/nvidia/language/cuda/__init__.py new file mode 100644 index 000000000..9fffa216b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/language/cuda/__init__.py @@ -0,0 +1,13 @@ +from . import libdevice + +from .utils import (globaltimer, num_threads, num_warps, smid, convert_custom_float8_sm70, convert_custom_float8_sm80) + +from ._experimental_tma import * # noqa: F403 +from ._experimental_tma import __all__ as _tma_all + +__all__ = [ + "libdevice", "globaltimer", "num_threads", "num_warps", "smid", "convert_custom_float8_sm70", + "convert_custom_float8_sm80", *_tma_all +] + +del _tma_all diff --git a/third_party/enflame/include/triton/third_party/nvidia/language/cuda/_experimental_tma.py b/third_party/enflame/include/triton/third_party/nvidia/language/cuda/_experimental_tma.py new file mode 100644 index 000000000..3f60197ed --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/language/cuda/_experimental_tma.py @@ -0,0 +1,106 @@ +from typing import Sequence + +from triton.language import core +from triton.language import semantic +from triton._C.libtriton import ir + +__all__ = [ + "experimental_device_tensormap_create1d", + "experimental_device_tensormap_create2d", + "experimental_tensormap_fenceproxy_acquire", +] + + +def _determine_elem_type(element_ty: core.dtype): + if element_ty.primitive_bitwidth == 8: + return 0 + elif element_ty.primitive_bitwidth == 16: + return 1 + elif element_ty.primitive_bitwidth == 32: + return 2 + else: + raise ValueError("element_ty must be a primitive of size 1, 2, or 4 bytes but got") + + +@core.builtin +def experimental_device_tensormap_create1d( + desc_ptr: core.tensor, + global_address: core.tensor, + load_size: core.tensor, + global_size: core.tensor, + element_ty: core.dtype, + _builder: ir.builder = None, +): + load_size = core._constexpr_to_value(load_size) + global_size = semantic.to_tensor(global_size, _builder) + element_ty = core._constexpr_to_value(element_ty) + element_stride = [core.full([], 1, core.int32, _builder=_builder)] + + semantic.tensormap_create( + desc_ptr=desc_ptr, + global_address=global_address, + box_dim=[semantic.to_tensor(load_size, _builder)], + global_dim=[global_size], + global_stride=[], + element_stride=element_stride, + elem_type=_determine_elem_type(element_ty), + interleave_layout=0, + swizzle_mode=0, + fill_mode=0, + builder=_builder, + ) + + +@core.builtin +def experimental_device_tensormap_create2d( + desc_ptr: core.tensor, + global_address: core.tensor, + load_size: Sequence[core.constexpr], + global_size: Sequence[core.tensor], + element_ty: core.dtype, + _builder: ir.builder = None, +): + assert len(load_size) == 2 + assert len(global_size) == 2 + load_size = [core._constexpr_to_value(x) for x in load_size] + global_size = [semantic.to_tensor(x, _builder) for x in global_size] + + element_size = element_ty.primitive_bitwidth // 8 + element_size_t = core.full([], element_size, core.int64, _builder=_builder) + global_stride = semantic.mul(element_size_t, global_size[-1], True, _builder) + + contig_dim_size_in_bytes = element_size * load_size[-1] + if contig_dim_size_in_bytes > 128: + load_size[-1] = 128 // element_size + + elem_stride = core.full([], 1, core.int32, _builder=_builder) + + semantic.tensormap_create( + desc_ptr=desc_ptr, + global_address=global_address, + box_dim=[semantic.to_tensor(x, _builder) for x in load_size[::-1]], + global_dim=global_size[::-1], + global_stride=[global_stride], + element_stride=[elem_stride, elem_stride], + elem_type=_determine_elem_type(element_ty), + interleave_layout=0, + swizzle_mode=_determine_swizzle_mode_2d(contig_dim_size_in_bytes, load_size), + fill_mode=0, + builder=_builder, + ) + + +def _determine_swizzle_mode_2d(contig_dim_size_in_bytes, load_size): + if contig_dim_size_in_bytes >= 128: + return 3 + elif contig_dim_size_in_bytes >= 64: + return 2 + elif contig_dim_size_in_bytes >= 32: + return 1 + else: + raise ValueError("block size too small") + + +@core.builtin +def experimental_tensormap_fenceproxy_acquire(desc_ptr: core.tensor, _builder: ir.builder = None): + semantic.tensormap_fenceproxy_acquire(desc_ptr, _builder) diff --git a/third_party/enflame/include/triton/third_party/nvidia/language/cuda/libdevice.py b/third_party/enflame/include/triton/third_party/nvidia/language/cuda/libdevice.py new file mode 100644 index 000000000..37e810bb1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/language/cuda/libdevice.py @@ -0,0 +1,1629 @@ +from triton.language import core + + +@core.extern +def clz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_clz", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_clzll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def popc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_popc", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_popcll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("__nv_byte_perm", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mulhi(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mulhi", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umulhi", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64")): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64")): ("__nv_umul64hi", core.dtype("uint64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul24(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mul24", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umul24", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def brev(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_brev", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_brevll", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sad(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("__nv_sad", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("__nv_usad", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__nv_fabsf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_floorf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_floor", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp64h(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_rcp64h", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_truncf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def saturatef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_saturatef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_ru", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__nv_dmul_ru", core.dtype("fp64")), + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__nv_fmul_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hiloint2double(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hiloint2double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2loint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2hiint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2hiint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_int", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_uint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_uint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def longlong_as_double(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_longlong_as_double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double_as_longlong(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double_as_longlong", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_sinf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_sinf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_cosf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_cosf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log2f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log2f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_logf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_logf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_expf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_tanf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_tanf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_exp10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_exp10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_powf(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_powf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_uhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_urhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ffs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_ffsll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_nearbyint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_isnanf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isnand", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_signbitf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_signbitd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_copysign", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def finitef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_finitef", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_isinff", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isinfd", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_nextafter", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinpi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinpi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cospi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cospi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_atan2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log1p", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_expm1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_hypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_rhypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_norm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_rnorm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_norm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_rnorm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rcbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def yn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_yn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def jn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_jn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfc", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcx", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_lgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_ldexp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def scalbn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_scalbn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fmod", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def remainder(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_remainderf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_remainder", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_powi", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_pow", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def round(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_round", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llround(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_llroundf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llround", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fdim(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fdim", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_ilogb", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def logb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_logb", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isfinited(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_isfinited", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) diff --git a/third_party/enflame/include/triton/third_party/nvidia/language/cuda/utils.py b/third_party/enflame/include/triton/third_party/nvidia/language/cuda/utils.py new file mode 100644 index 000000000..01bc040b2 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/language/cuda/utils.py @@ -0,0 +1,109 @@ +from triton.language import core + + +@core.extern +def globaltimer(_builder=None): + return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1, + _builder=_builder) + + +@core.extern +def smid(_builder=None): + return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1, + _builder=_builder) + + +@core.builtin +def num_threads(_builder=None): + return core.constexpr(_builder.options.num_warps * 32) + + +@core.builtin +def num_warps(_builder=None): + return core.constexpr(_builder.options.num_warps) + + +# ----- FP8E4M3B15 ------ +# This data-type is a variant of the standard FP8E4M3 format. +# It was designed for fast software conversion to FP16 on +# nvidia GPUs that do not support it natively. +# This is the same format as FP8E4M3Nv, but: +# - the exponent bias is 15 instead of 7 +# - 0xff and 0x7f are mapped to +-1.750 instead of +-nan +@core.builtin +def convert_fp8e4b15_to_float16(arg, _builder=None): + return core.inline_asm_elementwise( + "{ \n" + ".reg .b32 a<2>, b<2>; \n" + "prmt.b32 a0, 0, $2, 0x5746; \n" + "and.b32 b0, a0, 0x7f007f00; \n" + "and.b32 b1, a0, 0x00ff00ff; \n" + "and.b32 a1, a0, 0x00800080; \n" + "shr.b32 b0, b0, 1; \n" + "add.u32 b1, b1, a1; \n" + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" + "shl.b32 $1, b1, 7; \n" + "} \n", "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4, + _builder=_builder) + + +@core.builtin +def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None): + asm = """{ + .reg .pred p<4>; + .reg .b32 a<2>, b<2>; + .reg .b16 c<4>; + .reg .b16 max_val_f16; + .reg .b32 max_val_f16x2; + mov.b16 max_val_f16, 0x3F00; + mov.b32 max_val_f16x2, 0x3F003F00; + and.b32 a0, $1, 0x7fff7fff; + and.b32 a1, $2, 0x7fff7fff;""" + if has_minx2: + asm += """min.f16x2 a0, a0, max_val_f16x2; + min.f16x2 a1, a1, max_val_f16x2;""" + else: + asm += """setp.lt.f16x2 p0|p1, a0, max_val_f16x2; + setp.lt.f16x2 p2|p3, a1, max_val_f16x2; + mov.b32 {c0, c1}, a0; + mov.b32 {c2, c3}, a1; + selp.b16 c0, c0, max_val_f16, p0; + selp.b16 c1, c1, max_val_f16, p1; + selp.b16 c2, c2, max_val_f16, p2; + selp.b16 c3, c3, max_val_f16, p3; + mov.b32 a0, {c0, c1}; + mov.b32 a1, {c2, c3};""" + asm += """mad.lo.u32 a0, a0, 2, 0x00800080; + mad.lo.u32 a1, a1, 2, 0x00800080; + lop3.b32 b0, $1, 0x80008000, a0, 0xea; + lop3.b32 b1, $2, 0x80008000, a1, 0xea; + prmt.b32 $0, b0, b1, 0x7531; + }""" + return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4, + _builder=_builder) + + +@core.builtin +def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _builder=None): + if arg.type.scalar.is_fp8e4b15(): + upcast_val = convert_fp8e4b15_to_float16(arg, _builder=_builder) + if dst_ty.scalar.is_fp32(): + upcast_val = upcast_val.to(core.float32, _builder=_builder) + return upcast_val + + assert arg.type.scalar.is_fp16() or arg.type.scalar.is_fp32() + downcast_val = arg + if arg.type.scalar.is_fp32(): + downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _builder=_builder) + downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _builder=_builder) + return downcast_val + + +@core.builtin +def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _builder=None): + return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _builder=_builder) + + +@core.builtin +def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _builder=None): + return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _builder=_builder) diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/lib/CMakeLists.txt new file mode 100644 index 000000000..2ef7aab10 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Dialect) +add_subdirectory(TritonNVIDIAGPUToLLVM) +add_subdirectory(NVGPUToLLVM) diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..edeac0660 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(NVGPU) diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/NVGPU/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/NVGPU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/NVGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/NVGPU/IR/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/NVGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..1fd118d2b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/NVGPU/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(NVGPUIR + Dialect.cpp + + DEPENDS + NVGPUTableGen + NVGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRLLVMDialect +) diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp new file mode 100644 index 000000000..f623f50c6 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +// clang-format off +#include "Dialect/NVGPU/IR/Dialect.h" +#include "Dialect/NVGPU/IR/Dialect.cpp.inc" +// clang-format on + +using namespace mlir; +using namespace mlir::triton::nvgpu; + +void mlir::triton::nvgpu::NVGPUDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "Dialect/NVGPU/IR/Ops.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "Dialect/NVGPU/IR/Ops.cpp.inc" +#include "Dialect/NVGPU/IR/OpsEnums.cpp.inc" diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/NVGPUToLLVM/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/lib/NVGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..937702e5c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/NVGPUToLLVM/CMakeLists.txt @@ -0,0 +1,7 @@ +add_triton_library(NVGPUToLLVM + NVGPUToLLVMPass.cpp + + DEPENDS + NVGPUConversionPassIncGen + NVGPUIR +) diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp new file mode 100644 index 000000000..81116188b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -0,0 +1,943 @@ +#include "NVGPUToLLVM/NVGPUToLLVMPass.h" + +#include "Dialect/NVGPU/IR/Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +using namespace mlir::triton; + +#define GEN_PASS_CLASSES +#include "NVGPUToLLVM/Passes.h.inc" + +namespace ttn = mlir::triton::nvgpu; +using ttn::Constraints; +using ttn::OperandsAndConstraints; + +namespace { + +const std::string kWgmmaFenceOp = "wgmma.fence.sync.aligned;"; +const std::string kWgmmaCommitGroupOp = "wgmma.commit_group.sync.aligned;"; +const std::string kClusterWaitOp = "barrier.cluster.wait.aligned;"; +const std::string kFenceMbarrierInitOp = "fence.mbarrier_init.release.cluster;"; +const std::string kClusterCtaIdOp = "{\n" + ".reg .u32 a<5>; \n" + "mov.u32 a0, %cluster_ctaid.x;\n" // x + "mov.u32 a1, %cluster_ctaid.y;\n" // y + "mov.u32 a2, %cluster_ctaid.z;\n" // z + "mov.u32 a3, %cluster_nctaid.x;\n" // nx + "mov.u32 a4, %cluster_nctaid.y;\n" // ny + "mad.lo.u32 a1, a2, a4, a1; \n" + "mad.lo.u32 $0, a1, a3, a0; \n" + "}"; +const std::string Reg_Alloc_Op = "setmaxnreg.inc.sync.aligned.u32 #regCount;"; +const std::string Reg_Dealloc_Op = "setmaxnreg.dec.sync.aligned.u32 #regCount;"; + +const std::string Named_Barrier_Arrive_Op = "bar.arrive $0, $1;"; +const std::string Named_Barrier_Wait_Op = "bar.sync $0, $1;"; +const std::string Canonical_Warp_Id_Op = + "{\n" + ".reg .u32 a<5>; \n" + "mov.u32 a0, %tid.x; \n" // x + "mov.u32 a1, %tid.y; \n" // y + "mov.u32 a2, %tid.z; \n" // z + "mov.u32 a3, %ntid.x; \n" // nx + "mov.u32 a4, %ntid.y; \n" // ny + "mad.lo.u32 a1, a2, a4, a1; \n" + "mad.lo.u32 a0, a1, a3, a0; \n" + "shr.u32 a0, a0, 5; \n" + ".reg .b32 %tmp<3>; \n" + "mov.u32 %tmp0, -1; \n" + "mov.u32 %tmp1, 31; \n" + "mov.u32 %tmp2, 0; \n" + "shfl.sync.idx.b32 $0, a0, %tmp2, %tmp1, %tmp0; \n" + "}"; + +bool isNumber(const std::string &s) { + return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { + return !std::isdigit(c); + }) == s.end(); +} + +Type getTypeFromConstraint(char constraint, PatternRewriter &rewriter) { + Type ty; + if (constraint == 'b') + ty = IntegerType::get(rewriter.getContext(), 1); + else if (constraint == 'h') + ty = IntegerType::get(rewriter.getContext(), 16); + else if (constraint == 'r') + ty = IntegerType::get(rewriter.getContext(), 32); + else if (constraint == 'l') + ty = IntegerType::get(rewriter.getContext(), 64); + else if (constraint == 'f') + ty = Float32Type::get(rewriter.getContext()); + else if (constraint == 'd') + ty = Float64Type::get(rewriter.getContext()); + else { + assert(false && "Unsupported constraint"); + } + return ty; +} + +// Converts the given value to the type represented by the constraint +// E.g. if val is of type llvmptr and constraint is 'r', then we convert +// val to i32 using ptrtoint(i32_ty, val) +Value convertToType(Value val, std::string constraint, Location loc, + PatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto isConstraintNumber = isNumber(constraint); + if (!isConstraintNumber) { + auto ty = getTypeFromConstraint(constraint[0], rewriter); + if (isa(val.getType())) { + return b.ptrtoint(ty, val); + } else { + assert(val.getType().getIntOrFloatBitWidth() <= + ty.getIntOrFloatBitWidth() && + "Cannot convert to a smaller type"); + if (val.getType().getIntOrFloatBitWidth() < ty.getIntOrFloatBitWidth()) + return b.zext(ty, val); + } + } + return val; +} + +SmallVector +getPtxOutputs(const nvgpu::Constraints &outputConstraints, + PTXBuilder &ptxBuilder) { + SmallVector ptxOutputs; + for (unsigned i = 0; i < outputConstraints.size(); i++) { + auto *ptxOutput = ptxBuilder.newOperand(outputConstraints[i]); + ptxOutputs.push_back(ptxOutput); + } + return ptxOutputs; +} + +OperandsAndConstraints +unpackOperands(const OperandsAndConstraints &operandsAndConstraints, + PTXBuilder &ptxBuilder, Location loc, + PatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + OperandsAndConstraints unpackedOperands; + for (const auto &[operand, constraint] : operandsAndConstraints) { + auto llvmStruct = llvm::dyn_cast(operand.getType()); + // if a constraint is a number, then we are doing input/output tying + // if the operand is a struct, then we need to unpack it, and + // add the constraint to each of the unpacked operands uses the constraint + // as an offset + auto isConstraintNumber = isNumber(constraint); + if (llvmStruct) { + for (unsigned i = 0; i < llvmStruct.getBody().size(); i++) { + if (isConstraintNumber) { + auto constraintInt = std::stoi(constraint) + i; + unpackedOperands.push_back( + {b.extract_val(llvmStruct.getBody()[i], operand, i), + std::to_string(constraintInt)}); + } else { + unpackedOperands.push_back( + {b.extract_val(llvmStruct.getBody()[i], operand, i), constraint}); + } + } + } else { + unpackedOperands.push_back({operand, constraint}); + } + } + return unpackedOperands; +} + +SmallVector +getPtxOperands(const OperandsAndConstraints &operandsAndConstraints, + PTXBuilder &ptxBuilder, Location loc, + PatternRewriter &rewriter) { + SmallVector ptxOperands; + auto unpackedOperandsAndConstraints = + unpackOperands(operandsAndConstraints, ptxBuilder, loc, rewriter); + for (auto &[operand, constraint] : unpackedOperandsAndConstraints) { + auto convertedOperand = convertToType(operand, constraint, loc, rewriter); + auto *ptxOperand = ptxBuilder.newOperand(convertedOperand, constraint); + ptxOperands.push_back(ptxOperand); + } + return ptxOperands; +} + +std::string patchPtxAsm(Operation *op, std::string ptxAsm) { + std::vector> patchLocations; + std::vector patchValues; + auto start = ptxAsm.find("#", 0); + while (start != std::string::npos) { + auto endIterator = + std::find_if(ptxAsm.begin() + start + 1, ptxAsm.end(), + [](unsigned char c) { return !std::isalnum(c); }); + + assert(endIterator != ptxAsm.end() && "unexpected asm format"); + + auto end = std::distance(ptxAsm.begin(), endIterator); + auto patchLocation = std::make_pair(start, end); + patchLocations.push_back(patchLocation); + auto patchValue = ptxAsm.substr(start + 1, end - start - 1); + patchValues.push_back(patchValue); + start = ptxAsm.find("#", end); + } + assert(patchLocations.size() == patchValues.size() && + "patchLocations and patchValues should have the same size"); + if (patchLocations.size() == 0) { + return ptxAsm; + } + std::string res = ""; + size_t prevStart = 0; + unsigned i = 0; + for (auto &[start, end] : patchLocations) { + res += ptxAsm.substr(prevStart, start - prevStart); + auto integerAttr = op->getAttrOfType(patchValues[i]); + auto attr = integerAttr.getInt(); + res += std::to_string(attr); + prevStart = end; + i++; + } + if (prevStart < ptxAsm.size()) + res += ptxAsm.substr(prevStart, ptxAsm.size() - prevStart); + return res; +} + +template +class NVGPUOpGenericPattern : public OpRewritePattern { +public: + explicit NVGPUOpGenericPattern(MLIRContext *context, std::string ptxAsm, + Constraints outputConstraints, + Constraints inputConstraints) + : OpRewritePattern(context), ptxAsm(std::move(ptxAsm)), + outputConstraints(outputConstraints), + inputConstraints(inputConstraints) {} + + LogicalResult matchAndRewrite(SourceOp op, + PatternRewriter &rewriter) const override { + OperandsAndConstraints operandsAndConstraints; + for (unsigned i = 0; i < inputConstraints.size(); i++) { + operandsAndConstraints.push_back( + {op->getOperand(i), inputConstraints[i]}); + } + return rewriteAsPtxAsm(op, rewriter, ptxAsm, operandsAndConstraints, + outputConstraints); + } + +private: + std::string ptxAsm; + Constraints outputConstraints; + Constraints inputConstraints; +}; + +class FenceAsyncSharedOpPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::FenceAsyncSharedOp op, + PatternRewriter &rewriter) const override { + std::string ptxAsm = op.getBCluster() ? "fence.proxy.async.shared::cluster;" + : "fence.proxy.async.shared::cta;"; + return rewriteAsPtxAsm(op, rewriter, std::move(ptxAsm)); + } +}; + +class WarpIdOpPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::WarpIdOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // If this is inside a warp specialize op, compute the relative thread ID + // within the warp group. + Value tid = rewriter.create(loc, i32_ty); + if (std::optional startId = + getWarpGroupStartThreadId(rewriter.getInsertionBlock())) + tid = rewriter.create(loc, tid, b.i32_val(*startId)); + + Value warpId = b.udiv(tid, b.i32_val(32)); + // This indicates to PTXAS that the result and its derived values are + // uniform across the warp. For example, if a branch condition derives from + // this value, it can be proven to be non-divergent. + warpId = LLVM::NVIDIA::shuffleIdx(loc, rewriter, warpId, 0); + rewriter.replaceOp(op, warpId); + return success(); + } +}; + +class ClusterArriveOpPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::ClusterArriveOp op, + PatternRewriter &rewriter) const override { + std::string ptxAsm = op.getRelaxed() + ? "barrier.cluster.arrive.relaxed.aligned;" + : "barrier.cluster.arrive.aligned;"; + return rewriteAsPtxAsm(op, rewriter, std::move(ptxAsm)); + } +}; + +// Base class for Matrix Operation Patterns +template +class MatrixOpPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MatrixOpType op, + PatternRewriter &rewriter) const override { + unsigned vecSize = getVectorSize(op); + bool trans = op.getTrans(); + // Template method for PTX assembly generation + std::string ptxAsm = + (llvm::Twine(ConcreteMatrixOpPattern::kOpCode) + + getPtxModifiers(vecSize, trans) + " " + getOperands(op, vecSize) + ";") + .str(); + + OperandsAndConstraints operandAndConstraints = + getOperandsAndConstraints(op, vecSize); + Constraints outputConstraints = getOutputConstraints(op, vecSize); + + return rewriteAsPtxAsm(op, rewriter, ptxAsm, operandAndConstraints, + outputConstraints); + } + +protected: + // Shared helper methods + std::string getPtxModifiers(unsigned vecSize, bool trans) const { + auto ptxAsmBase = llvm::Twine(".sync.aligned.m8n8"); + const std::string suffix = trans ? ".trans.shared.b16" : ".shared.b16"; + switch (vecSize) { + case 1: + return (ptxAsmBase + ".x1" + suffix).str(); + case 2: + return (ptxAsmBase + ".x2" + suffix).str(); + case 4: + return (ptxAsmBase + ".x4" + suffix).str(); + default: + llvm_unreachable("Invalid vector size"); + } + } + + std::string getPtxRegOperands(unsigned startIdx, unsigned count) const { + llvm::SmallString<20> regOperands; + llvm::raw_svector_ostream stream(regOperands); + stream << "{"; + for (unsigned i = 0; i < count; i++) { + stream << "$" + llvm::utostr(startIdx + i); + if (i != count - 1) + stream << ", "; + } + stream << "}"; + return std::string(regOperands.str()); + } + + std::string getPtxAddrOperand(unsigned idx) const { + return (llvm::Twine("[$") + llvm::utostr(idx) + "]").str(); + } + + virtual std::string getOperands(MatrixOpType op, unsigned vecSize) const = 0; + virtual OperandsAndConstraints + getOperandsAndConstraints(MatrixOpType op, unsigned vecSize) const = 0; + virtual Constraints getOutputConstraints(MatrixOpType op, + unsigned vecSize) const = 0; + virtual unsigned getVectorSize(MatrixOpType op) const = 0; +}; + +// StoreMatrixOp Pattern +class StoreMatrixOpPattern + : public MatrixOpPattern { +public: + using MatrixOpPattern::MatrixOpPattern; + static constexpr const char *kOpCode = "stmatrix"; + +protected: + unsigned getVectorSize(ttn::StoreMatrixOp op) const override { + return op.getVals().size(); + } + + std::string getOperands(ttn::StoreMatrixOp op, + unsigned vecSize) const override { + return (llvm::Twine(getPtxAddrOperand(0)) + ", " + + getPtxRegOperands(1, vecSize)) + .str(); + } + + OperandsAndConstraints + getOperandsAndConstraints(ttn::StoreMatrixOp op, + unsigned vecSize) const override { + OperandsAndConstraints constraints = {{op.getAddr(), "r"}}; + for (unsigned i = 0; i < vecSize; i++) { + constraints.push_back({op.getVals()[i], "r"}); + } + return constraints; + } + + Constraints getOutputConstraints(ttn::StoreMatrixOp op, + unsigned vecSize) const override { + return {}; // No output constraints for StoreMatrixOp + } +}; + +// LoadMatrixOp Pattern +class LoadMatrixOpPattern + : public MatrixOpPattern { +public: + using MatrixOpPattern::MatrixOpPattern; + static constexpr const char *kOpCode = "ldmatrix"; + +protected: + unsigned getVectorSize(ttn::LoadMatrixOp op) const override { + auto resultType = cast(op.getType()); + return resultType.getBody().size(); + } + + std::string getOperands(ttn::LoadMatrixOp op, + unsigned vecSize) const override { + return (llvm::Twine(getPtxRegOperands(0, vecSize)) + ", " + + getPtxAddrOperand(vecSize)) + .str(); + } + + OperandsAndConstraints + getOperandsAndConstraints(ttn::LoadMatrixOp op, + unsigned vecSize) const override { + return {{op.getAddr(), "r"}}; + } + + Constraints getOutputConstraints(ttn::LoadMatrixOp op, + unsigned vecSize) const override { + return Constraints(vecSize, "=r"); + } +}; + +class LoadAcquireOpPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::LoadAcquireOp op, + PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Type valueTy = op.getType(); + const unsigned valueNBits = std::max(8u, valueTy.getIntOrFloatBitWidth()); + const size_t maxWordWidth = std::max(32, valueNBits); + const size_t width = std::min((size_t)valueNBits, maxWordWidth); + + const std::string writeConstraint = + (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); + PTXBuilder ptxBuilder; + bool init = true; + auto *dstOpr = ptxBuilder.newOperand(writeConstraint, init); // =r operation + auto *addrOpr = + ptxBuilder.newAddrOperand(op.getAddr(), "l", 0 /* in_off */); + auto &ld = + ptxBuilder.create<>("ld") + ->global() + .o("cta", op.getScope() == triton::nvgpu::MemSyncScope::CTA) + .o("gpu", op.getScope() == triton::nvgpu::MemSyncScope::GPU) + .o("sys", op.getScope() == triton::nvgpu::MemSyncScope::SYSTEM) + .o("acquire", op.getSem() == triton::nvgpu::MemSemantic::ACQUIRE) + .o("relaxed", op.getSem() == triton::nvgpu::MemSemantic::RELAXED) + .b(width); + ld(dstOpr, addrOpr).maybePredicate(op.getMask(), "b"); + + // Create inline ASM signature + Type retTy = IntegerType::get(getContext(), width); + Value ret = ptxBuilder.launch(rewriter, loc, retTy); + ret = b.bitcast(ret, op.getType()); + + rewriter.replaceOp(op, {ret}); + return success(); + } +}; + +class MBarrierArriveOpPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::MBarrierArriveOp op, + PatternRewriter &rewriter) const override { + return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op), + getOperandsAndConstraints(op)); + } + + OperandsAndConstraints + getOperandsAndConstraints(ttn::MBarrierArriveOp op) const { + OperandsAndConstraints operandsAndTypes; + Value mbarrier = op.getMbarrier(); + Value pred = op.getPred(); + Value ctaId = op.getCtaId(); + auto arriveType = op.getArriveType(); + + switch (arriveType) { + case ttn::MBarriveType::normal: + case ttn::MBarriveType::cp_async: + case ttn::MBarriveType::expect_tx: + operandsAndTypes.push_back({mbarrier, "r"}); + operandsAndTypes.push_back({pred, "b"}); + break; + case ttn::MBarriveType::remote: + operandsAndTypes.push_back({mbarrier, "r"}); + operandsAndTypes.push_back({ctaId, "r"}); + operandsAndTypes.push_back({pred, "b"}); + break; + default: + llvm::errs() << "Unsupported mbarrier arrive type " << arriveType << "\n"; + llvm_unreachable(""); + break; + } + return operandsAndTypes; + } + + std::string getPtxAsm(ttn::MBarrierArriveOp op) const { + Value ctaId = op.getCtaId(); + auto arriveType = op.getArriveType(); + uint32_t txCount = op.getTxCount(); + std::string ptxAsm; + switch (arriveType) { + case ttn::MBarriveType::normal: + ptxAsm = "@$1 mbarrier.arrive.shared.b64 _, [$0];"; + break; + case ttn::MBarriveType::cp_async: + ptxAsm = "@$1 cp.async.mbarrier.arrive.noinc.shared.b64 [$0];"; + break; + case ttn::MBarriveType::expect_tx: + assert(txCount > 0 && "txCount should be valid"); + ptxAsm = "@$1 mbarrier.arrive.expect_tx.shared.b64 _, [$0], " + + std::to_string(txCount) + ";"; + break; + case ttn::MBarriveType::remote: + assert(ctaId && "ctaId should have a valid value"); + ptxAsm = + " { .reg .b32 remAddr32; \n" + " @$2 mapa.shared::cluster.u32 remAddr32, $0, $1; \n" + " @$2 mbarrier.arrive.shared::cluster.b64 _, [remAddr32]; } \n"; + break; + default: + llvm::errs() << "Unsupported mbarrier arrive type " << arriveType << "\n"; + llvm_unreachable(""); + break; + } + return ptxAsm; + } +}; + +class WGMMAWaitGroupOpPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::WGMMAWaitGroupOp op, + PatternRewriter &rewriter) const override { + return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op), + getOperandsAndConstraints(op), + getOutputConstraints(op)); + } + + Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const { + auto outputStructType = cast(op.getType()); + uint32_t numOutputRegs = outputStructType.getBody().size(); + std::string output = + outputStructType.getBody().front().isF32() ? "=f" : "=r"; + return Constraints(numOutputRegs, output); + } + + OperandsAndConstraints + getOperandsAndConstraints(ttn::WGMMAWaitGroupOp op) const { + OperandsAndConstraints operandsAndConstraints; + auto input = op.getInput(); + operandsAndConstraints.push_back({input, "0"}); + return operandsAndConstraints; + } + + std::string getPtxAsm(ttn::WGMMAWaitGroupOp op) const { + auto outputStructType = dyn_cast(op.getType()); + uint32_t numCRegs = outputStructType.getBody().size(); + std::string args = ""; + uint32_t asmOpIdx = 0; + for (uint32_t i = 0; i < numCRegs; ++i) { + args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ","); + } + auto ptxAsm = "// wait for regs: " + args + "\n\t" + + "wgmma.wait_group.sync.aligned #pendings;"; + return ptxAsm; + } +}; + +class WGMMAOpPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::WGMMAOp op, + PatternRewriter &rewriter) const override { + return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op), + getOperandsAndConstraints(op), + getOutputConstraints(op)); + } + + std::vector getOutputConstraints(ttn::WGMMAOp op) const { + // TODO (zahi): Return type must always be a struct for wgmma, currently + // we rely on the size of output constraints vector to determine whether + // the output is a struct or not. We should find a way to pass this info + auto resultType = op.getType(); + + auto outputStructType = dyn_cast(resultType); + uint32_t numOutputRegs = outputStructType.getBody().size(); + std::string output = + outputStructType.getBody().front().isF32() ? "=f" : "=r"; + return std::vector(numOutputRegs, output); + } + + OperandsAndConstraints getOperandsAndConstraints(ttn::WGMMAOp op) const { + OperandsAndConstraints operandsAndConstraints; + auto opA = op.getOpA(); + auto opB = op.getOpB(); + auto opC = op.getOpC(); + auto opScaleD = op.getUseC(); + auto typeA = opA.getType(); + + auto structTypeA = dyn_cast(typeA); + + // TODO (zahi): is this the best way to tie inputs/outputs ? + if (opC) + operandsAndConstraints.push_back({opC, "0"}); + + if (structTypeA) { + operandsAndConstraints.push_back({opA, "r"}); + } else { + operandsAndConstraints.push_back({opA, "l"}); + } + + // Operand B (must be `desc`) + operandsAndConstraints.push_back({opB, "l"}); + + // `scale-d` + if (op.getOpC()) + operandsAndConstraints.push_back({opScaleD, "b"}); + + return operandsAndConstraints; + } + + std::string getPtxAsm(ttn::WGMMAOp op) const { + using namespace ttn; + auto opA = op.getOpA(); + auto opB = op.getOpB(); + auto m = op.getM(); + auto n = op.getN(); + auto k = op.getK(); + auto eltTypeC = op.getEltTypeC(); + auto eltTypeA = op.getEltTypeA(); + auto eltTypeB = op.getEltTypeB(); + auto layoutA = op.getLayoutA(); + auto layoutB = op.getLayoutB(); + + // Register checks + auto typeA = opA.getType(); + auto typeB = opB.getType(); + auto typeOutput = op.getType(); + auto structTypeA = dyn_cast(typeA); + auto structTypeB = dyn_cast(typeB); + auto structTypeOutput = dyn_cast(typeOutput); + assert(!structTypeB && "Operand B can not be registers"); + assert(structTypeOutput && "Output and C operand must be registers"); + + // Element type, MNK shape and transposing support check + // Reference: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-mma + bool transA = layoutA == WGMMALayout::col; + bool transB = layoutB == WGMMALayout::row; + bool supported = false, needTransArgs = false, floatTypeWGMMA = false; + assert(m % 8 == 0 && n % 8 == 0 && k % 8 == 0); + // Below instructions do support transposing, must pass `trans` arguments + supported |= + (eltTypeA == WGMMAEltType::f16) && (eltTypeB == WGMMAEltType::f16) && + (eltTypeC == WGMMAEltType::f16 || eltTypeC == WGMMAEltType::f32) && + (m == 64 && 8 <= n && n <= 256 && k == 16); + supported |= (eltTypeA == WGMMAEltType::bf16) && + (eltTypeB == WGMMAEltType::bf16) && + (eltTypeC == WGMMAEltType::f32) && + (m == 64 && 8 <= n && n <= 256 && k == 16); + needTransArgs = supported; + floatTypeWGMMA = supported; + // Below instructions do not support transposing + if (!supported && !transA && !transB) { + supported |= (eltTypeA == WGMMAEltType::tf32) && + (eltTypeB == WGMMAEltType::tf32) && + (eltTypeC == WGMMAEltType::f32) && + (m == 64 && 8 <= n && n <= 256 && k == 8); + supported |= + (eltTypeA == WGMMAEltType::e4m3 || eltTypeA == WGMMAEltType::e5m2) && + (eltTypeB == WGMMAEltType::e4m3 || eltTypeB == WGMMAEltType::e5m2) && + (eltTypeC == WGMMAEltType::f16 || eltTypeC == WGMMAEltType::f32) && + (m == 64 && 8 <= n && n <= 256 && k == 32); + floatTypeWGMMA = supported; + // Below instructions are integer-based + supported |= (eltTypeA == WGMMAEltType::s8) && + (eltTypeB == WGMMAEltType::s8) && + (eltTypeC == WGMMAEltType::s32) && + (m == 64 && 8 <= n && n <= 224 && k == 32); + } + assert(supported && "WGMMA type or shape is not supported"); + + // Operands + uint32_t asmOpIdx = 0; + std::string args = ""; + + // Output and operand C + uint32_t numCRegs = structTypeOutput.getBody().size(); + + args += "{"; + for (uint32_t i = 0; i < numCRegs; ++i) { + args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ","); + } + args += "}, "; + + if (op.getOpC()) + asmOpIdx += numCRegs; + + // Operand A + if (structTypeA) { + uint32_t numARegs = structTypeA.getBody().size(); + args += "{"; + for (uint32_t i = 0; i < numARegs; ++i) { + args += + "$" + std::to_string(asmOpIdx++) + (i == numARegs - 1 ? "" : ","); + } + args += "}, "; + } else { + args += "$" + std::to_string(asmOpIdx++) + ", "; + } + + // Operand B (must be `desc`) + args += "$" + std::to_string(asmOpIdx++) + ", "; + + // `scale-d` + if (op.getOpC()) + args += "$" + std::to_string(asmOpIdx++); + else + args += "0"; + + // `imm-scale-a`, and `imm-scale-b` are 1 by default only for float-based + // WGMMA + if (floatTypeWGMMA) + args += ", 1, 1"; + + // Push `trans-a` and `trans-b` args if needed (determined as constant) + if (needTransArgs) { + if (!structTypeA) + args += ", " + std::to_string(transA); + args += ", " + std::to_string(transB); + } + + auto ptxAsm = "wgmma.mma_async.sync.aligned" + ".m" + + std::to_string(m) + "n" + std::to_string(n) + "k" + + std::to_string(k) + "." + stringifyEnum(eltTypeC).str() + + "." + stringifyEnum(eltTypeA).str() + "." + + stringifyEnum(eltTypeB).str() + " " + args + ";"; + return ptxAsm; + } +}; + +static Value createTMAlloc(IRRewriter &rewriter, LLVM::LLVMFuncOp func, + size_t size, Value pred, bool twoCTAs) { + PTXBuilder ptxBuilder; + Location loc = func.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value sharedMem = mlir::LLVM::getStackPointer(rewriter, func); + std::string ptxString = + "@$0 tcgen05.alloc.cta_group::" + std::to_string(twoCTAs ? 2 : 1) + + ".sync.aligned.shared::cta.b32 [$1], " + std::to_string(size) + ";"; + + auto &allocOp = *ptxBuilder.create<>(ptxString); + allocOp( + {ptxBuilder.newOperand(pred, "b"), ptxBuilder.newOperand(sharedMem, "r")}, + /*onlyAttachMLIRArgs=*/true); + auto voidTy = void_ty(func->getContext()); + ptxBuilder.launch(rewriter, loc, void_ty(func->getContext())); + rewriter.create(loc); + Value address = b.load(i32_ty, sharedMem); + rewriter.create(loc); + address = b.inttoptr(ptr_ty(func.getContext(), 6), address); + return address; +} + +static void createRelinquishAlloc(IRRewriter &rewriter, Location loc, + Value pred, bool twoCTAs) { + PTXBuilder ptxBuilder; + std::string ptxString = "@$0 tcgen05.relinquish_alloc_permit.cta_group::" + + std::to_string(twoCTAs ? 2 : 1) + ".sync.aligned;"; + auto &f = *ptxBuilder.create<>(ptxString); + f({ptxBuilder.newOperand(pred, "b")}, /*onlyAttachMLIRArgs=*/true); + ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext())); +} + +void freeTMAlloc(LLVM::LLVMFuncOp func, Value alloc, size_t size, Value pred, + bool twoCTAs) { + func.walk([&](LLVM::ReturnOp ret) { + OpBuilder b(ret); + auto ctx = ret->getContext(); + auto loc = ret.getLoc(); + auto voidTy = void_ty(ctx); + PTXBuilder ptxBuilder; + // Calculate the predicate in the inline asm to avoid creating long + // liveranges. + std::string ptxString = + "@$0 tcgen05.dealloc.cta_group::" + std::to_string(twoCTAs ? 2 : 1) + + ".sync.aligned.b32 $1, " + std::to_string(size) + ";"; + auto &dealloc = *ptxBuilder.create<>(ptxString); + dealloc( + {ptxBuilder.newOperand(pred, "b"), ptxBuilder.newOperand(alloc, "r")}, + /*onlyAttachMLIRArgs=*/true); + ptxBuilder.launch(b, loc, void_ty(ctx)); + }); +} + +static Value initTensorMemory(LLVM::LLVMFuncOp func) { + auto mod = func->getParentOfType(); + assert(mod->hasAttr("ttg.tensor_memory_size")); + size_t size = cast(mod->getAttr("ttg.tensor_memory_size")) + .getValue() + .getZExtValue(); + if (size == 0) + return Value(); + IRRewriter rewriter(func.getContext()); + rewriter.setInsertionPointToStart(&func.front()); + auto ctx = mod.getContext(); + auto loc = func.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + // A proper error will be raised by the frontend, but to allow compilation to + // continue we emit a trap. + if (size > 512) { + rewriter.create(loc); + return rewriter.create(loc, ptr_ty(ctx, 6)); + } + + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + // Assume that 2CTAs is used if we have two CTAs this is pessimistic but + // should be fine for now. + bool useTwoCTAs = numCTAs == 2; + // This code is only executed by the default warp group. + Value threadId = rewriter.create(loc, i32_ty); + Value pred = b.icmp_ult(threadId, b.i32_val(32)); + Value alloc = createTMAlloc(rewriter, func, size, pred, useTwoCTAs); + createRelinquishAlloc(rewriter, loc, pred, useTwoCTAs); + // TODO: pred will have a long liverange, we need to check if this is a + // problem and how it can be fixed. + freeTMAlloc(func, alloc, size, pred, useTwoCTAs); + return alloc; +} + +static void lowerTensorMemoryAlloc(ModuleOp mod) { + SmallVector baseOps; + LLVM::LLVMFuncOp kernel = nullptr; + mod.walk([&](ttn::TensorMemoryBaseAddress baseOp) { + baseOps.push_back(baseOp); + if (!kernel) + kernel = baseOp->getParentOfType(); + assert(kernel == baseOp->getParentOfType() && + "TODO: add support for function calls using tmem."); + }); + if (baseOps.empty()) + return; + // TODO: Handle cases of matmul used in noinline functions. + assert(LLVM::isKernel(kernel)); + Value newBase = initTensorMemory(kernel); + if (!newBase) + return; + for (auto baseOp : baseOps) { + baseOp->getResult(0).replaceAllUsesWith(newBase); + baseOp->erase(); + } +} + +class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase { + +public: + explicit ConvertNVGPUToLLVM() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + RewritePatternSet patterns(context); + +#define POPULATE_NVGPU_OP(SRC_OP, ASM) \ + patterns.add>(context, ASM, Constraints(), \ + Constraints()); + POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, kWgmmaFenceOp) + POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, kWgmmaCommitGroupOp) + POPULATE_NVGPU_OP(ttn::ClusterWaitOp, kClusterWaitOp) + POPULATE_NVGPU_OP(ttn::RegAllocOp, Reg_Alloc_Op) + POPULATE_NVGPU_OP(ttn::RegDeallocOp, Reg_Dealloc_Op) +#undef POPULATE_NVGPU_OP + patterns.add>( + context, Named_Barrier_Arrive_Op, Constraints(), + Constraints({"r", "r"})); + patterns.add>( + context, Named_Barrier_Wait_Op, Constraints(), Constraints({"r", "r"})); + patterns.add>( + context, kClusterCtaIdOp, Constraints({"=r"}), Constraints()); + patterns.add>( + context, Canonical_Warp_Id_Op, Constraints({"=r"}), Constraints()); + patterns.add(context); + + if (applyPatternsGreedily(mod, std::move(patterns)).failed()) + signalPassFailure(); + + lowerTensorMemoryAlloc(mod); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { + +LogicalResult +nvgpu::rewriteAsPtxAsm(Operation *op, PatternRewriter &rewriter, + std::string ptxAsm, + const OperandsAndConstraints &operandsAndConstraints, + const Constraints &outputConstraints) { + auto ctx = rewriter.getContext(); + auto loc = op->getLoc(); + ptxAsm = patchPtxAsm(op, std::move(ptxAsm)); + auto hasSideEffects = !isMemoryEffectFree(op); + + PTXBuilder ptxBuilder; + auto ptxOutputs = getPtxOutputs(outputConstraints, ptxBuilder); + auto ptxOperands = + getPtxOperands(operandsAndConstraints, ptxBuilder, loc, rewriter); + SmallVector outputsAndOperands = ptxOutputs; + outputsAndOperands.append(ptxOperands.begin(), ptxOperands.end()); + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + ptxInstr(outputsAndOperands, /*onlyAttachMLIRArgs=*/true); + auto retTy = + op->getNumResults() == 0 ? void_ty(ctx) : op->getResult(0).getType(); + auto res = ptxBuilder.launch(rewriter, loc, retTy, + /*hasSideEffects*/ hasSideEffects); + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + } else { + rewriter.replaceOp(op, res); + } + + return success(); +} + +std::unique_ptr> createConvertNVGPUToLLVMPass() { + return std::make_unique<::ConvertNVGPUToLLVM>(); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp new file mode 100644 index 000000000..9cb4b9e39 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp @@ -0,0 +1,329 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "PatternTritonGPUOpToLLVM.h" +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +struct BarrierOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(mlir::gpu::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + if (op->hasAttr("bar_id")) { + // llvm.nvvm.barrier0 doesn't support bar_id and num_threads attributes, + // so we have to lower it to ptx manually. + auto barId = op->getAttrOfType("bar_id").getInt(); + auto numThreads = op->getAttrOfType("num_threads").getInt(); + ::mlir::triton::PTXBuilder ptxBuilder; + auto &barSyncOp = *ptxBuilder.create<>("bar.sync"); + barSyncOp(ptxBuilder.newConstantOperand(barId), + ptxBuilder.newConstantOperand(numThreads)); + auto voidTy = void_ty(op->getContext()); + ptxBuilder.launch(rewriter, op->getLoc(), voidTy); + rewriter.eraseOp(op); + return success(); + } + // Otherwise we let the default lowering handle it + return failure(); + } +}; + +// -------------------------------------------------------------------------- +// -- MBarrier related Ops lowering, to be moved to a separate file --------- +// -------------------------------------------------------------------------- +struct MBarrierArriveOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::MBarrierArriveOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::MBarrierArriveOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto mbarrier = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getMbarrier(), + typeConverter->convertType(op.getMbarrier().getType().getElementType()), + rewriter); + + bool trackAsyncOp = op.getTrackAsyncOp(); + triton::nvgpu::MBarriveType type = triton::nvgpu::MBarriveType::normal; + uint32_t txCount = op.getTxCount(); + auto remoteCtaId = adaptor.getRemoteCtaId(); + if (trackAsyncOp) { + type = triton::nvgpu::MBarriveType::cp_async; + } else if (remoteCtaId) { + assert(txCount == 0 && + "remote arrive of transaction mbarrier is not implemented yet"); + type = triton::nvgpu::MBarriveType::remote; + } else if (txCount > 0) { + type = triton::nvgpu::MBarriveType::expect_tx; + } + Value pred = adaptor.getPred(); + if (pred == nullptr) { + pred = b.int_val(/*width*/ 1, 1); + } + rewriter.replaceOpWithNewOp( + op, mbarrier.getBase(), pred, remoteCtaId, type, txCount); + return success(); + } +}; + +struct NamedBarrierArriveOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::NamedBarrierArriveOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::NamedBarrierArriveOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getBar(), adaptor.getNumThreads()); + return success(); + } +}; + +struct NamedBarrierWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::NamedBarrierWaitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::NamedBarrierWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getBar(), adaptor.getNumThreads()); + return success(); + } +}; + +struct FenceAsyncSharedOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::FenceAsyncSharedOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::FenceAsyncSharedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getBCluster()); + return success(); + } +}; + +struct InitBarrierOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::InitBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getAlloc(), + typeConverter->convertType(op.getAlloc().getType().getElementType()), + rewriter); + + auto asyncTaskIds = getAsyncTaskIds(op); + int executingThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::lookupNumWarps(op); + int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + executingThreadId = asyncTaskIds[0] * numWarps * warpSize; + } + + auto id = getThreadId(rewriter, loc); + auto pred = b.icmp_eq(id, b.i32_val(executingThreadId)); + ::mlir::triton::PTXBuilder ptxBuilder; + const std::string ptx = "@$0 mbarrier.init.shared::cta.b64 [$1], " + + std::to_string(op.getCount()) + ";"; + auto &barSyncOp = *ptxBuilder.create<>(ptx); + barSyncOp({ptxBuilder.newOperand(pred, "b"), + ptxBuilder.newOperand(smemObj.getBase(), "r")}, + /*onlyAttachMLIRArgs=*/true); + auto voidTy = void_ty(op->getContext()); + ptxBuilder.launch(rewriter, loc, voidTy); + rewriter.eraseOp(op); + return success(); + } +}; + +struct InvalBarrierOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::InvalBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getAlloc(), + typeConverter->convertType(op.getAlloc().getType().getElementType()), + rewriter); + + auto asyncTaskIds = getAsyncTaskIds(op); + int executingThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::lookupNumWarps(op); + int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + executingThreadId = asyncTaskIds[0] * numWarps * warpSize; + } + auto id = getThreadId(rewriter, loc); + Value pred = b.icmp_eq(id, b.i32_val(executingThreadId)); + ::mlir::triton::PTXBuilder ptxBuilder; + const std::string ptx = "@$0 mbarrier.inval.shared::cta.b64 [$1];"; + auto &barSyncOp = *ptxBuilder.create<>(ptx); + barSyncOp({ptxBuilder.newOperand(pred, "b"), + ptxBuilder.newOperand(smemObj.getBase(), "r")}, + /*onlyAttachMLIRArgs=*/true); + auto voidTy = void_ty(op->getContext()); + ptxBuilder.launch(rewriter, loc, voidTy); + rewriter.eraseOp(op); + return success(); + } +}; + +struct BarrierExpectConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::BarrierExpectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getAlloc(), + typeConverter->convertType(op.getAlloc().getType().getElementType()), + rewriter); + + auto asyncTaskIds = getAsyncTaskIds(op); + int executingThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::lookupNumWarps(op); + int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + executingThreadId = asyncTaskIds[0] * numWarps * warpSize; + } + auto id = getThreadId(rewriter, loc); + Value pred = b.icmp_eq(id, b.i32_val(executingThreadId)); + pred = b.and_(pred, adaptor.getPred()); + ::mlir::triton::PTXBuilder ptxBuilder; + const std::string ptx = + "@$0 mbarrier.arrive.expect_tx.shared.b64 _, [$1], " + + std::to_string(op.getSize()) + ";"; + auto &barSyncOp = *ptxBuilder.create<>(ptx); + barSyncOp({ptxBuilder.newOperand(pred, "b"), + ptxBuilder.newOperand(smemObj.getBase(), "r")}, + /*onlyAttachMLIRArgs=*/true); + auto voidTy = void_ty(op->getContext()); + ptxBuilder.launch(rewriter, loc, voidTy); + rewriter.eraseOp(op); + return success(); + } +}; + +struct WaitBarrierOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::WaitBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getAlloc(), + typeConverter->convertType(op.getAlloc().getType().getElementType()), + rewriter); + auto loc = op.getLoc(); + const std::string ptxNoPred = + "{ \n\t" + ".reg .pred P1; \n\t" + "waitLoop: \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [$0], $1; \n\t" + "@!P1 bra.uni waitLoop; \n\t" + "} \n\t"; + const std::string ptxPred = + "{ \n\t" + "@!$2 bra.uni skipWait; \n\t" + ".reg .pred P1; \n\t" + "waitLoop: \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [$0], $1; \n\t" + "@!P1 bra.uni waitLoop; \n\t" + "skipWait: \n\t" + "} \n\t"; + ::mlir::triton::PTXBuilder ptxBuilder; + bool predicated = adaptor.getPred() != nullptr; + std::string ptx = predicated ? ptxPred : ptxNoPred; + auto &waitLoop = *ptxBuilder.create<>(ptx); + SmallVector<::mlir::triton::PTXBuilder::Operand *, 3> operands = { + ptxBuilder.newOperand(smemObj.getBase(), "r"), + ptxBuilder.newOperand(adaptor.getPhase(), "r")}; + if (predicated) + operands.push_back(ptxBuilder.newOperand(adaptor.getPred(), "b")); + + waitLoop(operands, /*onlyAttachMLIRArgs=*/true); + auto voidTy = void_ty(op->getContext()); + ptxBuilder.launch(rewriter, op->getLoc(), voidTy); + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + +void mlir::triton::NVIDIA::populateBarrierOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, + benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..5d48dc6eb --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -0,0 +1,31 @@ +add_triton_library(TritonNVIDIAGPUToLLVM + ConvertLayoutOpToLLVM.cpp + ConvertWarpSpecializeToLLVM.cpp + MemoryOpToLLVM.cpp + DotOpToLLVM/MMAv2.cpp + DotOpToLLVM/MMAv5.cpp + DotOpToLLVM/WGMMA.cpp + DotOpToLLVM.cpp + ElementwiseOpToLLVM.cpp + LoadStoreOpToLLVM.cpp + BarrierOpToLLVM.cpp + TritonGPUToLLVM.cpp + TMAToLLVM.cpp + SPMDOpToLLVM.cpp + TensorMemoryToLLVM.cpp + TensorPtrOpsToLLVM.cpp + ClusterOpsToLLVM.cpp + PTXAsmFormat.cpp + Utility.cpp + Fp4ToFpOpToLLVM.cpp + TargetInfo.cpp + RegReallocOpToLLVM.cpp + + DEPENDS + TritonNVIDIAGPUConversionPassIncGen + NVGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + TritonGPUToLLVM + TritonProtonToLLVM +) diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ClusterOpsToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ClusterOpsToLLVM.cpp new file mode 100644 index 000000000..a5ab6fddb --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ClusterOpsToLLVM.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "Dialect/NVGPU/IR/Dialect.h" +#include "PatternTritonGPUOpToLLVM.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { +struct ClusterArriveOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::ClusterArriveOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::ClusterArriveOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op.getRelaxed()); + return success(); + } +}; + +struct ClusterWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::ClusterWaitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::ClusterWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; +} // namespace + +void mlir::triton::NVIDIA::populateClusterOpsToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + return; +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp new file mode 100644 index 000000000..8abb7131e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -0,0 +1,248 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" +#include "Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +struct ConvertLayoutOpConversion + : public ConvertOpToLLVMPattern { +public: + ConvertLayoutOpConversion(const LLVMTypeConverter &typeConverter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (isa( + srcLayout) && + isa( + dstLayout)) { + if (shouldUseDistSmem(srcLayout, dstLayout)) + return lowerDistToDistWithDistSmem(op, adaptor, rewriter, targetInfo); + } + if (isa(srcLayout) && + isa(dstLayout)) { + return lowerMmaToDotOperand(op, adaptor, rewriter); + } + + return failure(); + } + +private: + LogicalResult + lowerDistToDistWithDistSmem(triton::gpu::ConvertLayoutOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { + MLIRContext *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto typeConverter = getTypeConverter(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto srcShapePerCTA = getShapePerCTA(srcTy); + auto srcCTAsPerCGA = triton::gpu::getCTAsPerCGA(srcLayout); + auto srcCTAOrder = triton::gpu::getCTAOrder(srcLayout); + unsigned rank = srcShapePerCTA.size(); + + auto llvmElemTy = typeConverter->convertType(dstTy.getElementType()); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + smemBase = b.bitcast(smemBase, elemPtrTy); + auto smemShape = convertType(srcShapePerCTA); + + // Store to local shared memory + { + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto inIndices = emitIndices(loc, rewriter, targetInfo, srcLayout, srcTy, + /*withCTAOffset*/ false); + + assert(inIndices.size() == inVals.size() && + "Unexpected number of indices emitted"); + + for (unsigned i = 0; i < inIndices.size(); ++i) { + Value offset = LLVM::linearize(rewriter, loc, inIndices[i], smemShape); + Value ptr = b.gep(elemPtrTy, llvmElemTy, smemBase, offset); + b.store(inVals[i], ptr); + } + } + + // Cluster barrier + rewriter.create(loc, false); + rewriter.create(loc); + + // Load from remote shared memory + { + SmallVector srcShapePerCTACache; + for (unsigned i = 0; i < rank; ++i) + srcShapePerCTACache.push_back(b.i32_val(srcShapePerCTA[i])); + + SmallVector outVals; + auto outIndices = emitIndices(loc, rewriter, targetInfo, dstLayout, dstTy, + /*withCTAOffset*/ true); + + for (unsigned i = 0; i < outIndices.size(); ++i) { + auto coord = outIndices[i]; + assert(coord.size() == rank && "Unexpected rank of index emitted"); + + SmallVector multiDimCTAId, localCoord; + for (unsigned d = 0; d < rank; ++d) { + multiDimCTAId.push_back(b.udiv(coord[d], srcShapePerCTACache[d])); + localCoord.push_back(b.urem(coord[d], srcShapePerCTACache[d])); + } + + Value remoteCTAId = LLVM::linearize(rewriter, loc, multiDimCTAId, + srcCTAsPerCGA, srcCTAOrder); + Value localOffset = + LLVM::linearize(rewriter, loc, localCoord, smemShape); + + Value ptr = b.gep(elemPtrTy, llvmElemTy, smemBase, localOffset); + outVals.push_back(targetInfo.loadDShared(rewriter, loc, ptr, + remoteCTAId, llvmElemTy, + /*pred=*/b.true_val())); + } + + Value result = + packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + } + + // Cluster barrier + rewriter.create(loc, false); + rewriter.create(loc); + + return success(); + } + + // Convert from accumulator MMA layout to 8bit dot operand layout. + // The conversion logic is taken from: + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a9de6446c1c0415c926025cea284210c799b11f8/src/fmha-pipeline/reg2reg.h#L45 + void + convertMMAV3To8BitsDotOperand(triton::gpu::ConvertLayoutOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto dstTy = op.getType(); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector retVals; + for (int i = 0; i < vals.size(); i += 8) { + Value upper = b.undef(vec_ty(i8_ty, 4)); + for (int j = 0; j < 4; j++) { + upper = b.insert_element(vec_ty(i8_ty, 4), upper, vals[i + j], + b.i32_val(j)); + } + upper = b.bitcast(upper, i32_ty); + Value lower = b.undef(vec_ty(i8_ty, 4)); + for (int j = 0; j < 4; j++) { + lower = b.insert_element(vec_ty(i8_ty, 4), lower, vals[i + 4 + j], + b.i32_val(j)); + } + lower = b.bitcast(lower, i32_ty); + + Value threadIdMod4 = b.urem(getThreadId(rewriter, loc), b.i32_val(4)); + Value cnd = b.or_(b.icmp_eq(threadIdMod4, b.i32_val(0)), + b.icmp_eq(threadIdMod4, b.i32_val(3))); + Value selectorEx0 = b.select(cnd, b.i32_val(0x3210), b.i32_val(0x7654)); + Value selectorEx1 = b.select(cnd, b.i32_val(0x7654), b.i32_val(0x3210)); + Value selectorEx4 = b.select(cnd, b.i32_val(0x5410), b.i32_val(0x1054)); + Value selectorEx5 = b.select(cnd, b.i32_val(0x7632), b.i32_val(0x3276)); + + Value isOne = b.icmp_eq(threadIdMod4, b.i32_val(1)); + Value isTwo = b.icmp_eq(threadIdMod4, b.i32_val(2)); + Value isThree = b.icmp_eq(threadIdMod4, b.i32_val(3)); + Value upperIdx = b.i32_val(0); + upperIdx = b.select(isOne, b.i32_val(3), upperIdx); + upperIdx = b.select(isTwo, b.i32_val(1), upperIdx); + upperIdx = b.select(isThree, b.i32_val(2), upperIdx); + + Value lowerIdx = b.i32_val(1); + lowerIdx = b.select(isOne, b.i32_val(2), lowerIdx); + lowerIdx = b.select(isTwo, b.i32_val(0), lowerIdx); + lowerIdx = b.select(isThree, b.i32_val(3), lowerIdx); + + Value upper0 = + LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx0); + Value lower0 = + LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx1); + Value mask = b.i32_val(0xFFFFFFFF); + // Set clamp tp shuffle only within 4 lanes. + Value clamp = b.i32_val(0x1C1F); + upper0 = + rewriter.create(loc, i32_ty, mask, upper0, upperIdx, + clamp, NVVM::ShflKind::idx, UnitAttr()); + lower0 = + rewriter.create(loc, i32_ty, mask, lower0, lowerIdx, + clamp, NVVM::ShflKind::idx, UnitAttr()); + Value upper1 = + LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx4); + Value vecVal = b.bitcast(upper1, vec_ty(i8_ty, 4)); + for (int i = 0; i < 4; i++) { + retVals.push_back(b.extract_element(i8_ty, vecVal, b.i32_val(i))); + } + Value lower1 = + LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx5); + vecVal = b.bitcast(lower1, vec_ty(i8_ty, 4)); + for (int i = 0; i < 4; i++) { + retVals.push_back(b.extract_element(i8_ty, vecVal, b.i32_val(i))); + } + } + Value result = + packLLElements(loc, getTypeConverter(), retVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + } + + // mma -> dot_operand + LogicalResult + lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) { + assert(srcTy.getElementType().getIntOrFloatBitWidth() == 8 && + "Unsupported type size."); + convertMMAV3To8BitsDotOperand(op, adaptor, rewriter); + return success(); + } + return failure(); + } + +private: + const NVIDIA::TargetInfo &targetInfo; +}; + +} // namespace + +void mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + // Give this convertLayoutOpConversion a higher benefit as it only matches + // optimized or cross CTA cases + patterns.add(typeConverter, targetInfo, + benefit.getBenefit() + 1); + mlir::triton::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, + patterns, benefit); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp new file mode 100644 index 000000000..050fe2f4d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp @@ -0,0 +1,429 @@ +#include "TargetInfo.h" +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton { +#define GEN_PASS_DEF_CONVERTWARPSPECIALIZETOLLVM +#include "TritonNVIDIAGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +//===----------------------------------------------------------------------===// +// convertOpTypes +//===----------------------------------------------------------------------===// + +static void convertOpTypes(Operation *op, const TypeConverter &typeConverter) { + ImplicitLocOpBuilder b(op->getLoc(), op); + SmallVector operands = llvm::to_vector(op->getOperands()); + for (Value &operand : operands) { + Type type = typeConverter.convertType(operand.getType()); + if (type != operand.getType()) { + operand = + b.create(type, operand).getResult(0); + } + } + op->setOperands(operands); + + for (Region ®ion : op->getRegions()) { + b.setInsertionPointToStart(®ion.front()); + for (BlockArgument arg : llvm::to_vector(region.getArguments())) { + Type type = typeConverter.convertType(arg.getType()); + BlockArgument newArg = region.addArgument(type, arg.getLoc()); + auto cast = b.create(arg.getType(), newArg); + arg.replaceAllUsesWith(cast.getResult(0)); + region.eraseArgument(0); + } + } + + SmallVector resultTypes; + (void)typeConverter.convertTypes(op->getResultTypes(), resultTypes); + if (TypeRange(resultTypes) == op->getResultTypes()) + return; + OperationState state(op->getLoc(), op->getName(), op->getOperands(), + resultTypes, op->getAttrs()); + for (Region ®ion : op->getRegions()) + state.addRegion()->takeBody(region); + b.setInsertionPoint(op); + Operation *newOp = b.create(state); + + SmallVector results; + for (auto [i, result, type] : + llvm::enumerate(newOp->getResults(), op->getResultTypes())) { + auto cast = b.create(type, result); + op->getResult(i).replaceAllUsesWith(cast.getResult(0)); + } + op->erase(); +} + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +// Reserve one barrier for the default warp group, one for the start barrier, +// and one for the end barrier. +enum BarrierIndex { + kDefaultWarpGroupBarrierIdx, + kSwitchLoopBarrierIdx, + + kNumReservedBarriers, + kNumBarriers = 16 +}; + +static void createBarrier(TritonLLVMIRRewriter &b, unsigned barIdx, + std::optional numThreads, bool aligned) { + assert(barIdx < 16 && "not enough barriers"); + + PTXBuilder ptxBuilder; + std::string ptxString; + llvm::raw_string_ostream os(ptxString); + os << "barrier.sync"; + if (aligned) + os << ".aligned"; + os << ' ' << barIdx; + if (numThreads) + os << ", " << *numThreads; + + (*ptxBuilder.create<>(ptxString))(); + ptxBuilder.launch(b, b.getLoc(), void_ty(b.getContext())); +} + +//===----------------------------------------------------------------------===// +// lowerWarpSpecialize +//===----------------------------------------------------------------------===// + +// Assign hardware barriers to each warp group and rewrite warp group barriers +// into `barrier.sync` instructions. There is a maximum number of barriers. +static LogicalResult rewriteWarpGroupBarriers(LLVM::LLVMFuncOp func, + ArrayRef wsOps, + unsigned threadsPerWarp, + unsigned defaultWarpGroupSize) { + // HACK: Turn all `nvvm.barrier0` ops into warp group barriers. + func.walk([&](Operation *op) { + // Walk into default regions but not partition regions. + if (isa(op)) + return WalkResult::skip(); + + if (auto bar = dyn_cast(op)) { + TritonLLVMIRRewriter b(bar.getLoc(), bar); + createBarrier(b, /*barIdx=*/0, defaultWarpGroupSize, /*aligned=*/true); + bar.erase(); + return WalkResult::advance(); + } + return WalkResult::advance(); + }); + + // Each partition executes simultaneously, so each will get a different + // barrier ID, but note this means there is a maximum of 16 barriers. + for (WarpSpecializeOp op : wsOps) { + for (auto [idx, partition] : llvm::enumerate(op.getPartitionRegions())) { + unsigned barIdx = idx + kNumReservedBarriers; + if (barIdx >= kNumBarriers) { + return func.emitError("cannot support more than ") + << (kNumBarriers - kNumReservedBarriers) + << " warp group partitions"; + } + unsigned warpGroupSize = threadsPerWarp * op.getPartitionNumWarps()[idx]; + partition->walk([&](NVVM::Barrier0Op bar) { + TritonLLVMIRRewriter b(bar.getLoc(), bar); + createBarrier(b, barIdx, warpGroupSize, /*aligned=*/true); + bar.erase(); + }); + } + } + + return success(); +} + +static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop, + const NVIDIA::TargetInfo &targetInfo) { + TritonLLVMIRRewriter b(ws.getLoc(), ws.getContext()); + + for (Region *partition : ws.getPartitionRegions()) { + // Load the explicit captures from shared memory and replace the block args + // if there are any. + b.setInsertionPointToStart(&partition->front()); + if (partition->getNumArguments()) { + auto captureType = LLVM::LLVMStructType::getLiteral( + b.getContext(), llvm::to_vector(partition->getArgumentTypes()), + /*isPacked=*/true); + Value capturePtr = + LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, ws); + LLVM::LLVMPointerType ptrTy = ptr_ty(b.getContext(), 3); + for (auto [i, arg] : + llvm::zip(llvm::seq(partition->getNumArguments()), + partition->getArguments())) { + Value ptr = + b.gep(ptrTy, captureType, capturePtr, ArrayRef{0, i}); + // Each thread in the warp group needs a copy of the value. + Value value = b.load(arg.getType(), ptr, /*align=*/1); + arg.replaceAllUsesWith(value); + } + partition->front().eraseArguments([](auto) { return true; }); + } + + // The shared memory is only live for the entry into the region, so put + // another barrier here. + createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt, + /*aligned=*/false); + + // Rewrite all warp returns. + partition->walk([&](WarpReturnOp op) { + b.setInsertionPoint(op); + createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt, + /*aligned=*/false); + b.replaceOpWithNewOp(op, switchLoop); + }); + } +} + +static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func, + const NVIDIA::TargetInfo &targetInfo) { + SmallVector wsOps; + func.walk([&](WarpSpecializeOp op) { wsOps.push_back(op); }); + // Nothing to do. This kernel is not warp specialized. + if (wsOps.empty()) + return success(); + + // Before lowering away `ttg.warp_specialize`, lower warp group barriers. + auto module = cast(func->getParentOp()); + unsigned threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module); + unsigned defaultNumWarps = lookupNumWarps(func); + unsigned defaultWarpGroupSize = threadsPerWarp * defaultNumWarps; + if (failed(rewriteWarpGroupBarriers(func, wsOps, threadsPerWarp, + defaultWarpGroupSize))) + return failure(); + + MLIRContext *ctx = func.getContext(); + TritonLLVMIRRewriter b(func.getLoc(), ctx); + Builder rewriter(ctx); + + // Generate the function header. + Block *entry = &func.getBody().front(); + SmallVector argLocs = llvm::to_vector(llvm::map_range( + func.getArguments(), [](BlockArgument arg) { return arg.getLoc(); })); + Block *header = b.createBlock(entry, func.getArgumentTypes(), argLocs); + Block *switchLoop = b.createBlock(entry); + b.setInsertionPointToStart(header); + + // This is the absolute thread ID. + Value tid = b.create(i32_ty); + Value wid = b.udiv(tid, b.i32_val(threadsPerWarp)); + // Tell PTXAS this value is warp-uniform. + wid = targetInfo.shuffleIdx(b, b.getLoc(), wid, 0); + Value isDefault = b.icmp_ult(wid, b.i32_val(defaultNumWarps)); + b.create(isDefault, entry, switchLoop); + + // Forward arguments from the header into the old entry block. + for (auto [arg, oldArg] : + llvm::zip(header->getArguments(), entry->getArguments())) + oldArg.replaceAllUsesWith(arg); + entry->eraseArguments([](auto) { return true; }); + + // Generate the switch loop. + auto totalNumWarpsAttr = + module->getAttrOfType("ttg.total-num-warps"); + if (!totalNumWarpsAttr) { + return mlir::emitError(module.getLoc(), + "module missing 'ttg.total-num-warps' attribute"); + } + unsigned totalNumThreads = totalNumWarpsAttr.getInt() * threadsPerWarp; + + // ^switchLoop: + // barrier.sync 1 + // %state_ptr = getelementptr (ptr @shared), + // %rel_tid = sub %tid, + // %rel_wid = udiv %rel_tid, 32 + b.setInsertionPointToStart(switchLoop); + createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt, + /*aligned=*/false); + Value statePtr = LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func); + Value relWid = b.sub(wid, b.i32_val(defaultNumWarps)); + + // The default warp group will populate the state pointer with the state ID + // for all warps. + // %warp_state_ptr = getelementptr ptr %state_tr[%rel_wid] + // %warp_state = load i8 %warp_state_ptr + LLVM::LLVMPointerType ptrTy = ptr_ty(ctx, 3); + Value warpStatePtr = b.gep(ptrTy, i8_ty, statePtr, relWid); + // All threads in a warp reading from the same smem address will not create + // bank conflicts and is better than predicated load. + Value warpState = b.load(i8_ty, warpStatePtr); + + // Pull the partition regions out. Switch based on the state ID to the right + // partition. + SmallVector partitionBlocks; + SmallVector partitionStates; + int32_t partitionStateCounter = 0; + // This represents the data that the default warp group will fill into the + // state pointer before entering each `warp_specialize` region, which maps + // a warp ID to a state ID in the switch. + int32_t maxNumWarps = totalNumWarpsAttr.getInt() - defaultNumWarps; + SmallVector> warpToState( + wsOps.size(), SmallVector(maxNumWarps, -1)); + for (auto [op, stateMap] : llvm::zip(wsOps, warpToState)) { + rewritePartitionRegions(op, switchLoop, targetInfo); + for (auto [partition, partitionNumWarps, startId] : + llvm::zip(op.getPartitionRegions(), op.getPartitionNumWarps(), + *op.getWarpGroupStartIds())) { + partitionStates.push_back(partitionStateCounter++); + partitionBlocks.push_back(&partition->front()); + for (int32_t &stateId : MutableArrayRef(stateMap).slice( + startId - defaultNumWarps, partitionNumWarps)) + stateId = partitionStates.back(); + } + } + if (partitionStateCounter > std::numeric_limits::max()) { + return mlir::emitError(func.getLoc(), + "FIXME: too many warp group partitions"); + } + + // Splice them in reverse order so the IR is easier to read. + Region::BlockListType &funcBlocks = func.getBody().getBlocks(); + for (Block *block : llvm::reverse(partitionBlocks)) { + Region *region = block->getParent(); + funcBlocks.splice(std::next(switchLoop->getIterator()), + region->getBlocks()); + } + + // Default destination. + Block *defaultBlock = new Block; + funcBlocks.insert(std::next(switchLoop->getIterator()), defaultBlock); + b.setInsertionPointToStart(defaultBlock); + createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt, + /*aligned=*/false); + createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt, + /*aligned=*/false); + b.create(switchLoop); + + // Exit state. + Block *switchExit = new Block; + funcBlocks.insert(std::next(defaultBlock->getIterator()), switchExit); + partitionBlocks.push_back(switchExit); + partitionStates.push_back(partitionStateCounter); + + // Create the switch. + b.setInsertionPointToEnd(switchLoop); + SmallVector caseValues; + for (int32_t state : partitionStates) + caseValues.push_back(APInt(8, state)); + b.create(warpState, defaultBlock, ValueRange(), caseValues, + partitionBlocks, + SmallVector(partitionBlocks.size())); + + // Now add synchronization around the default regions. + for (auto [ws, stateMap] : llvm::zip(wsOps, warpToState)) { + Block *before = ws->getBlock(); + Block *after = b.splitBlock(before, ws->getIterator()); + b.setInsertionPointToEnd(before); + Value statePtr = LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func); + for (auto [i, state] : llvm::enumerate(stateMap)) { + b.store(b.i8_val(state), b.gep(ptrTy, i8_ty, statePtr, LLVM::GEPArg(i))); + } + + // Store the captures if there are any. + if (ws.getNumOperands()) { + auto captureType = LLVM::LLVMStructType::getLiteral( + b.getContext(), llvm::to_vector(ws.getOperandTypes()), + /*isPacked=*/true); + Value capturePtr = + LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, ws); + for (auto [i, arg] : llvm::zip(llvm::seq(ws.getNumOperands()), + ws.getOperands())) { + Value ptr = + b.gep(ptrTy, captureType, capturePtr, ArrayRef{0, i}); + b.store(arg, ptr, /*align=*/1); + } + } + + // First barrier releases the waiting warpgroups. The second barrier ensures + // they have read the captures before the memory is released upon entry. + createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt, + /*aligned=*/false); + createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt, + /*aligned=*/false); + b.create(&ws.getDefaultRegion().front()); + + ws.getDefaultRegion().walk([&](WarpYieldOp op) { + b.setInsertionPoint(op); + createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt, + /*aligned=*/false); + b.replaceOpWithNewOp(op, op.getOperands(), after); + }); + after->getParent()->getBlocks().splice(after->getIterator(), + ws.getDefaultRegion().getBlocks()); + + // Replace the results. + auto outputs = after->addArguments( + ws.getResultTypes(), + SmallVector(ws.getNumResults(), ws.getLoc())); + ws.replaceAllUsesWith(outputs); + ws.erase(); + } + + // Signal all warp groups to exit. + func.walk([&](LLVM::ReturnOp op) { + TritonLLVMIRRewriter b(op.getLoc(), op); + Value statePtr = LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func); + Value cst = b.i8_val(partitionStateCounter); + for (int32_t i : llvm::seq(maxNumWarps)) + b.store(cst, b.gep(ptrTy, i8_ty, statePtr, LLVM::GEPArg(i))); + createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt, + /*aligned=*/false); + }); + b.setInsertionPointToStart(switchExit); + b.create(ValueRange()); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace { +struct ConvertWarpSpecializeToLLVM + : public mlir::triton::impl::ConvertWarpSpecializeToLLVMBase< + ConvertWarpSpecializeToLLVM> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + // FIXME: Assume warp specialization only happens on Blackwell. + NVIDIA::TargetInfo targetInfo(/*computeCapability=*/100, + /*ptxVersion=*/100); + + // Convert types and cleanup unrealized conversions. + mlir::LowerToLLVMOptions option(&getContext()); + option.overrideIndexBitwidth(32); + TritonGPUToLLVMTypeConverter typeConverter(&getContext(), option, + targetInfo); + mod.walk([&](Operation *op) { + if (isa(op)) + convertOpTypes(op, typeConverter); + }); + RewritePatternSet patterns(&getContext()); + UnrealizedConversionCastOp::getCanonicalizationPatterns(patterns, + &getContext()); + if (failed(applyPatternsGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + + SmallVector kernels; + for (auto func : mod.getOps()) { + if (func.isPublic()) + kernels.push_back(func); + } + for (LLVM::LLVMFuncOp kernel : kernels) + if (failed(lowerWarpSpecialize(kernel, targetInfo))) + return signalPassFailure(); + } +}; +} // namespace diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp new file mode 100644 index 000000000..4c5840574 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp @@ -0,0 +1,169 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" + +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; + +LogicalResult convertMMA1688(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); + +LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); + +LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op, + triton::nvidia_gpu::WarpGroupDotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Value thread); +namespace { +struct DotOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + // D = A * B + C + Value A = op.getA(); + Value D = op.getResult(); + + // Here we assume the DotOp's operands always comes from shared memory. + auto AShapePerCTA = getShapePerCTA(A.getType()); + size_t reduceAxis = 1; + unsigned K = AShapePerCTA[reduceAxis]; + bool isOuter = K == 1; + + NvidiaMmaEncodingAttr mmaLayout = dyn_cast( + cast(D.getType()).getEncoding()); + if (!isOuter && mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) { + if (mmaLayout.isTuring()) + return convertMMA1688(op, adaptor, getTypeConverter(), rewriter); + if (mmaLayout.isAmpere()) + return convertMMA16816(op, adaptor, getTypeConverter(), rewriter); + + llvm::report_fatal_error( + "Unsupported MMA kind found when converting DotOp to LLVM."); + } + + if (isa( + cast(D.getType()).getEncoding())) + return convertFMADot(op, adaptor, getTypeConverter(), rewriter); + + llvm::report_fatal_error( + "Unsupported DotOp found when converting TritonGPU to LLVM."); + } +}; + +struct WarpGroupDotOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::WarpGroupDotOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + // D = A * B + C + Value A = op.getA(); + Value D = op.getResult(); + + // Here we assume the DotOp's operands always comes from shared memory. + auto AShapePerCTA = getShapePerCTA(A.getType()); + size_t reduceAxis = 1; + unsigned K = AShapePerCTA[reduceAxis]; + bool isOuter = K == 1; + + NvidiaMmaEncodingAttr mmaLayout = dyn_cast( + cast(D.getType()).getEncoding()); + if (!isOuter && mmaLayout && + supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) { + if (mmaLayout.isHopper()) { + return convertWGMMA(op, adaptor, getTypeConverter(), rewriter, + getThreadId(rewriter, loc)); + } + + llvm::report_fatal_error( + "Unsupported MMA kind found when converting WarpGroupDotOp to LLVM."); + } + + llvm::report_fatal_error( + "Unsupported WarpGroupDotOp found when converting TritonGPU to LLVM."); + } +}; + +struct WarpGroupDotWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::WarpGroupDotWaitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::WarpGroupDotWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pendings = op.getPendings(); + Location loc = op.getLoc(); + if (adaptor.getInputs().size() <= 1) { + Value intput = + adaptor.getInputs().size() == 1 ? adaptor.getInputs()[0] : Value(); + rewriter.replaceOpWithNewOp(op, intput, + pendings); + return success(); + } + std::vector types; + // Pack the inputs into a single struct. + for (Value input : adaptor.getInputs()) { + auto structType = dyn_cast(input.getType()); + if (!structType) + return failure(); + for (Type type : structType.getBody()) + types.push_back(type); + } + auto packedType = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); + Value packed = rewriter.create(loc, packedType); + unsigned outputStructIndex = 0; + for (Value input : adaptor.getInputs()) { + auto structType = dyn_cast(input.getType()); + for (unsigned i = 0; i < structType.getBody().size(); ++i) { + Value value = rewriter.create( + loc, structType.getBody()[i], input, i); + packed = rewriter.create( + loc, packedType, packed, value, outputStructIndex++); + } + } + Value packedOutput = + rewriter.create(loc, packed, pendings); + // Unpack the output into the original struct types. + SmallVector outputs; + outputStructIndex = 0; + for (Value input : adaptor.getInputs()) { + auto structType = cast(input.getType()); + Value unpacked = rewriter.create(loc, structType); + for (unsigned i = 0; i < structType.getBody().size(); ++i) { + Value value = rewriter.create( + loc, packedType.getBody()[outputStructIndex], packedOutput, + outputStructIndex); + outputStructIndex++; + unpacked = rewriter.create(loc, structType, + unpacked, value, i); + } + outputs.push_back(unpacked); + } + rewriter.replaceOp(op, outputs); + return success(); + } +}; +} // namespace + +void mlir::triton::NVIDIA::populateDotOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h new file mode 100644 index 000000000..87346f99f --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h @@ -0,0 +1,94 @@ +#include "Utility.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace triton { +namespace NVIDIA { + +// The descriptor format is described in the spec: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor +// Unnamed fieids are not used +union SMEMDescriptor { + uint64_t descriptor; + struct { + uint64_t baseAddress : 14; + uint64_t : 2; + uint64_t leadDimensionBaseOffset : 14; + uint64_t : 2; + uint64_t strideDimensionBaseOffset : 14; + uint64_t : 3; + uint64_t matrixBaseOffset : 3; + uint64_t : 10; + uint64_t swizzlingMode : 2; + }; +}; + +// Abstract class to calculate the address of a shared or tensor memory slice. +class DotOpMmaMemLoader { +public: + virtual ~DotOpMmaMemLoader() = default; + virtual Value memLoad(int a, int b, ConversionPatternRewriter &rewriter, + Location loc) = 0; +}; + +// Helper class to load shared memory slices following MMAv3 layout. +class DotOpMmaV3SmemLoader : public DotOpMmaMemLoader { +public: + DotOpMmaV3SmemLoader() {} + DotOpMmaV3SmemLoader(Value tensor, Value base, SmallVector shape, + Value warpId, unsigned int dimWpt, bool trans, + SmallVector instrShape, + int64_t elementBitwidth, + ConversionPatternRewriter &rewriter, Location loc); + // Return a descriptor pointing to the shared memory slice at coordinates (a, + // b) + Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter, + Location loc); + + Value memLoad(int a, int b, ConversionPatternRewriter &rewriter, + Location loc) override { + return smemLoad(a, b, rewriter, loc); + } + +private: + Value base; + SmallVector shape; + Value warpId; + int dimWpt; + bool trans; + int fastMovingDim; + Value elemsPerSwizzlingRowVal; + SmallVector instrShape; + int elemsPerSwizzlingRow; + int64_t elemBits; + Value descriptor; +}; + +// Helper class to load tensor memory following MMAv5 layout. +class DotOpMmaV5TmemLoader : public DotOpMmaMemLoader { +public: + DotOpMmaV5TmemLoader() {} + DotOpMmaV5TmemLoader(Value tensor, Value base, + SmallVector instrShape, bool interleaved, + bool trans); + Value tmemLoad(int a, int b, ConversionPatternRewriter &rewriter, + Location loc); + + Value memLoad(int a, int b, ConversionPatternRewriter &rewriter, + Location loc) override { + return tmemLoad(a, b, rewriter, loc); + } + +private: + Value base; + bool trans; + bool interleaved; + bool unpacked; + SmallVector instrShape; + int numElementsPer32b; + int numRepM; +}; + +} // namespace NVIDIA +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp new file mode 100644 index 000000000..36fa804e6 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -0,0 +1,616 @@ +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +#include "Utility.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getOrderForDotOperand; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; + +using ValueTableV2 = std::map, Value>; + +Value loadC(Value tensor, Value llTensor, + const LLVMTypeConverter *typeConverter, Location loc, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = tensor.getContext(); + auto tensorTy = cast(tensor.getType()); + size_t fcSize = triton::gpu::getTotalElemsPerThread(tensor.getType()); + + assert(isa(tensorTy.getEncoding()) && + "Currently, we only support $c with a mma layout."); + // Load a normal C tensor with mma layout, that should be a + // LLVM::struct with fcSize elements. + auto structTy = cast(llTensor.getType()); + assert(structTy.getBody().size() == fcSize && + "DotOp's $c operand should pass the same number of values as $d in " + "mma layout."); + + auto numMmaRets = tensorTy.getElementType().getIntOrFloatBitWidth() / 8; + assert(numMmaRets == 4 || numMmaRets == 2); + if (numMmaRets == 4) { + return llTensor; + } else if (numMmaRets == 2) { + auto cPack = SmallVector(); + auto cElemTy = tensorTy.getElementType(); + int numCPackedElem = 4 / numMmaRets; + Type cPackTy = vec_ty(cElemTy, numCPackedElem); + for (int i = 0; i < fcSize; i += numCPackedElem) { + Value pack = rewriter.create(loc, cPackTy); + for (int j = 0; j < numCPackedElem; ++j) { + pack = b.insert_element(cPackTy, pack, + b.extract_val(cElemTy, llTensor, i + j), + b.i32_val(j)); + } + cPack.push_back(pack); + } + + Type structTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(cPack.size(), cPackTy)); + Value result = + packLLElements(loc, typeConverter, cPack, rewriter, structTy); + return result; + } + + return llTensor; +} + +ValueTableV2 getValuesFromDotOperandLayoutStruct( + const LLVMTypeConverter *typeConverter, Location loc, + ConversionPatternRewriter &rewriter, Value value, int batch, int repOuter, + int repK, RankedTensorType type) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto elems = unpackLLElements(loc, value, rewriter); + auto eltTy = typeConverter->convertType(type.getElementType()); + int offset{}; + ValueTableV2 vals; + auto bitwidth = eltTy.getIntOrFloatBitWidth(); + auto numElemsPerVec = 32 / bitwidth; + auto vecTy = vec_ty(eltTy, numElemsPerVec); + + auto packVec = [&](std::array dstIdx) { + Value vec = b.undef(vecTy); + for (auto i = 0; i < numElemsPerVec; ++i) { + vec = b.insert_element(vec, b.bitcast(elems[offset + i], eltTy), + b.i32_val(i)); + } + vals[dstIdx] = b.bitcast(vec, i32_ty); + offset += numElemsPerVec; + }; + + auto dot = cast(type.getEncoding()); + auto kWidth = dot.getKWidth(); + auto largeK = bitwidth * kWidth > 32; + if (largeK) { + // For layouts with a large K dimension, the original register layout needs + // to be divided into multiple MMAs, where each MMA has contiguous 32 bits + // along the K dimension per thread. + // Using kWidth = 8 and bitwidth = 2 as an example, + // we split the MMA into 4 sub-MMAs, each with a stride 4 x 32-bit along the + // K dimension. + llvm::SmallVector si; + auto kIters = kWidth / (32 / bitwidth); + + if (dot.getOpIdx() == 0) { + // Original register layout: + // + // [0, 1, 2, 3, 4, 5, 6, 7], [16, 17, 18, 19, 20, 21, 22, 23, 23] + // [8, 9, 10, 11, 12, 13, 14, 15], [24, 25, 26, 27, 28, 29, 30, 31] + // + // Each element in the layout is a single bf16. + // + // To derive four independent MMA operations, a stride of 4 is applied to + // the original register layout: + // + // 1st MMA: [[0, 1], [8, 9], [16, 17], [24, 25]] + // 2nd MMA: [[2, 3], [10, 11], [18, 19], [26, 27]] + // 3rd MMA: [[4, 5], [12, 13], [20, 21], [28, 29]] + // 4th MMA: [[6, 7], [14, 15], [22, 23], [30, 31]] + if (kIters <= repK) { + for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep) + for (size_t tile = 0; tile < 4; ++tile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(kRep * numElemsPerVec + tile * kWidth + e); + } + } else { + // Suppose kWidth=4 and type=fp32, so numElemsPerVec=1. + // Each tile of the dot operand layout has a size of 16x32. + // However, if the triton tensor size is 16x16, elements along the k + // dimension are duplicated. Within each tile, each register + // contains 2x8 elements arranged as follows: + // + // tile0/0 tile0/1 + // |<--kWidth=4-->| |<--kWidth-->| + // |<-mmaWidth=2->| + // [0, 1, 2, 3] [0, 1, 2, 3] + // [4, 5, 6, 7] [4, 5, 6, 7] + // + // tile0/1 replicates the elements in tile0/0 along the k dimension. + // For a tensor size of 32x32, the next tile on the m dimension is as + // follows: + // + // tile1/0 tile1/1 + // |<--kWidth-->| |<--kWidth-->| + // [8, 9, 10, 11], [8, 9, 10, 11] + // [12, 13, 14, 15], [12, 13, 14, 15] + // + // Within a single tile, we can perform two MMAs, and the + // resulting register layout for each MMA is as follows: + // + // 1st MMA: [0, 4, 1, 5] + // 2nd MMA: [2, 6, 3, 7] + // 3rd MMA: [8, 12, 9, 13] + // 4th MMA: [10, 14, 11, 15] + // + // Additionally, we should reorder the elements by moving the duplicated + // elements to the end. In the example above, we convert the order from + // tile0/0, tile0/1, tile1/0, tile1/1 to tile0/0, tile1/0, tile0/1, + // tile1/1, so that only the first two tiles will be used in the + // computation. + size_t elemsPerTile = 2 * 2 * kWidth; + size_t elemsPerMma = 2 * 2 * numElemsPerVec; + size_t mmaWidth = kWidth / numElemsPerVec / 2; + size_t repMma = elemsPerTile / (mmaWidth * elemsPerMma); + for (size_t rep = 0; rep < repMma; ++rep) + for (size_t tile = 0; tile < elems.size() / elemsPerTile; ++tile) + for (size_t mmaKWidth = 0; mmaKWidth < mmaWidth; ++mmaKWidth) + for (size_t kTile = 0; kTile < 2; ++kTile) + for (size_t mTile = 0; mTile < 2; ++mTile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(rep * mmaWidth * elemsPerMma + + mmaKWidth * 2 * numElemsPerVec + + tile * elemsPerTile + mTile * kWidth + + kTile * numElemsPerVec + e); + } + } + } else { + // Original register layout: + // + // [0, 1, 2, 3, 4, 5, 6, 7]^T, [8, 9, 10, 11, 12, 13, 14, 15]^T + // + // A stride of 4 is applied to derive four independent MMA operations: + // + // 1st MMA: [[0, 1], [8, 9]] + // 2nd MMA: [[2, 3], [10, 11]] + // 3rd MMA: [[4, 5], [12, 13]] + // 4th MMA: [[6, 7], [14, 15]] + if (kIters <= repK) { + for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep) + for (size_t tile = 0; tile < 2; ++tile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(kRep * numElemsPerVec + tile * kWidth + e); + } + } else { + // Suppose kWidth=4 and type=fp32. + // Original register layout: + // + // tile0/0 tile0/1 + // [0, 1, 2, 3]^T, [0, 1, 2, 3]^T + // + // Similar to the opIdx=0 situation, we should reorder the elements by + // moving the duplicated elements to the end. + size_t elemsPerTile = 2 * kWidth; + size_t elemsPerMma = 2 * numElemsPerVec; + size_t mmaWidth = kWidth / numElemsPerVec / 2; + size_t repMma = elemsPerTile / (mmaWidth * elemsPerMma); + for (size_t rep = 0; rep < repMma; ++rep) + for (size_t tile = 0; tile < elems.size() / elemsPerTile; ++tile) + for (size_t mmaKWidth = 0; mmaKWidth < mmaWidth; ++mmaKWidth) + for (size_t kTile = 0; kTile < 2; ++kTile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(rep * mmaWidth * elemsPerMma + + mmaKWidth * 2 * numElemsPerVec + + tile * elemsPerTile + kTile * numElemsPerVec + + e); + } + } + } + + auto step = si.size(); + SmallVector perm(step); + for (auto i = 0; i < elems.size() / step; ++i) { + for (auto j = 0; j < step; ++j) { + perm[j] = elems[i * step + si[j]]; + } + std::copy(perm.begin(), perm.end(), elems.begin() + i * step); + } + } + + if (dot.getOpIdx() == 0) { + for (auto b = 0; b < batch; ++b) + for (auto m = 0; m < repOuter; ++m) + for (auto k = 0; k < repK; ++k) { + packVec({b, 2 * m, 2 * k}); + packVec({b, 2 * m + 1, 2 * k}); + packVec({b, 2 * m, 2 * k + 1}); + packVec({b, 2 * m + 1, 2 * k + 1}); + } + } else { + for (auto b = 0; b < batch; ++b) + for (auto n = 0; n < repOuter; ++n) + for (auto k = 0; k < repK; ++k) { + packVec({b, n, 2 * k}); + packVec({b, n, 2 * k + 1}); + } + } + return vals; +} + +enum class TensorCoreType : uint8_t { + // floating-point tensor core instr + FP32_FP16_FP16_FP32 = 0, // default + FP32_BF16_BF16_FP32, + FP32_TF32_TF32_FP32, + FP16_FP16_FP16_FP16, + FP32_FP8E5M2_FP8E5M2_FP32, + FP32_FP8E5M2_FP8E4M3FN_FP32, + FP32_FP8E4M3FN_FP8E5M2_FP32, + FP32_FP8E4M3FN_FP8E4M3FN_FP32, + // integer tensor core instr + INT32_INT1_INT1_INT32, // Not implemented + INT32_INT4_INT4_INT32, // Not implemented + INT32_INT8_INT8_INT32, // Not implemented + // + NOT_APPLICABLE, +}; + +Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) { + Type fp32Ty = type::f32Ty(ctx); + Type fp16Ty = type::f16Ty(ctx); + Type i32Ty = type::i32Ty(ctx); + Type fp32x4Ty = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp32Ty)); + Type i32x4Ty = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, i32Ty)); + Type fp16x2Pack2Ty = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(2, vec_ty(fp16Ty, 2))); + switch (mmaType) { + case TensorCoreType::FP32_FP16_FP16_FP32: + return fp32x4Ty; + case TensorCoreType::FP32_BF16_BF16_FP32: + return fp32x4Ty; + case TensorCoreType::FP32_TF32_TF32_FP32: + return fp32x4Ty; + case TensorCoreType::FP16_FP16_FP16_FP16: + return fp16x2Pack2Ty; + case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32: + case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32: + case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32: + case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32: + return fp32x4Ty; + case TensorCoreType::INT32_INT8_INT8_INT32: + return i32x4Ty; + default: + llvm::report_fatal_error("Unsupported mma type found"); + } + + return Type{}; +} + +TensorCoreType getMmaType(triton::DotOp op) { + auto aTy = op.getA().getType(); + auto bTy = op.getB().getType(); + // d = a*b + c + auto dTy = op.getD().getType(); + + if (dTy.getElementType().isF32()) { + if (aTy.getElementType().isF16() && bTy.getElementType().isF16()) + return TensorCoreType::FP32_FP16_FP16_FP32; + if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16()) + return TensorCoreType::FP32_BF16_BF16_FP32; + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) + return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32; + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) + return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32; + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) + return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32; + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) + return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32; + if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && + op.getInputPrecision() == InputPrecision::TF32) + return TensorCoreType::FP32_TF32_TF32_FP32; + } else if (dTy.getElementType().isInteger(32)) { + if (aTy.getElementType().isInteger(8) && bTy.getElementType().isInteger(8)) + return TensorCoreType::INT32_INT8_INT8_INT32; + } else if (dTy.getElementType().isF16()) { + if (aTy.getElementType().isF16() && bTy.getElementType().isF16()) + return TensorCoreType::FP16_FP16_FP16_FP16; + } + + return TensorCoreType::NOT_APPLICABLE; +} + +inline static const std::map mmaInstrPtxTuring = { + {TensorCoreType::FP32_FP16_FP16_FP32, + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"}, + + {TensorCoreType::INT32_INT8_INT8_INT32, + "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32"}, + + {TensorCoreType::FP16_FP16_FP16_FP16, + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16"}, +}; + +inline static const std::map mmaInstrPtxAmpere = { + {TensorCoreType::FP32_FP16_FP16_FP32, + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"}, + {TensorCoreType::FP32_BF16_BF16_FP32, + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"}, + {TensorCoreType::FP32_TF32_TF32_FP32, + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"}, + + {TensorCoreType::INT32_INT1_INT1_INT32, + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"}, + {TensorCoreType::INT32_INT4_INT4_INT32, + "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"}, + {TensorCoreType::INT32_INT8_INT8_INT32, + "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"}, + + {TensorCoreType::FP16_FP16_FP16_FP16, + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16"}, + + {TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32, + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32"}, + {TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32, + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32"}, + {TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32, + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32"}, + {TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32, + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"}, +}; + +static void callMmaTuringInt8(PTXBuilder &builder, int b, int m, int n, int k, + mlir::triton::PTXInstr &mma, unsigned numMmaRets, + unsigned colsPerThread, int numCPackedElem, + ValueTableV2 &ha, ValueTableV2 &hb, + const SmallVector &fc) { + auto retArgs1 = builder.newListOperand(numMmaRets / 2, "=r"); + auto retArgs2 = builder.newListOperand(numMmaRets / 2, "=r"); + auto cArgs1 = builder.newListOperand(); + for (int i = 0; i < numMmaRets / 2; ++i) { + cArgs1->listAppend( + builder.newOperand(fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], + std::to_string(i))); + // reuse the output registers + } + auto cArgs2 = builder.newListOperand(); + for (int i = numMmaRets / 2; i < numMmaRets; ++i) { + cArgs2->listAppend( + builder.newOperand(fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], + std::to_string(i))); + // reuse the output registers + } + auto aArgs1 = builder.newListOperand({ + {ha[{b, m, k}], "r"}, + }); + auto bArgs1 = builder.newListOperand({ + {hb[{b, n, k}], "r"}, + }); + auto aArgs2 = builder.newListOperand({ + {ha[{b, m, k + 1}], "r"}, + }); + auto bArgs2 = builder.newListOperand({{hb[{b, n, k + 1}], "r"}}); + auto aArgs3 = builder.newListOperand({ + {ha[{b, m + 1, k}], "r"}, + }); + auto bArgs3 = builder.newListOperand({ + {hb[{b, n, k}], "r"}, + }); + auto aArgs4 = builder.newListOperand({ + {ha[{b, m + 1, k + 1}], "r"}, + }); + auto bArgs4 = builder.newListOperand({{hb[{b, n, k + 1}], "r"}}); + mma(retArgs1, aArgs1, bArgs1, cArgs1); + mma(retArgs1, aArgs2, bArgs2, cArgs1); + mma(retArgs2, aArgs3, bArgs3, cArgs2); + mma(retArgs2, aArgs4, bArgs4, cArgs2); +} + +static void callMmaTuringFp16(PTXBuilder &builder, int b, int m, int n, int k, + mlir::triton::PTXInstr &mma, unsigned numMmaRets, + unsigned colsPerThread, int numCPackedElem, + ValueTableV2 &ha, ValueTableV2 &hb, + const SmallVector &fc, bool isAccF16) { + auto retArgs = builder.newListOperand(numMmaRets, isAccF16 ? "=r" : "=f"); + auto cArgs = builder.newListOperand(); + for (int i = 0; i < numMmaRets; ++i) { + cArgs->listAppend( + builder.newOperand(fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], + std::to_string(i))); + // reuse the output registers + } + auto aArgs1 = builder.newListOperand({ + {ha[{b, m, k}], "r"}, + {ha[{b, m + 1, k}], "r"}, + }); + auto bArgs1 = builder.newListOperand({{hb[{b, n, k}], "r"}}); + auto aArgs2 = builder.newListOperand({ + {ha[{b, m, k + 1}], "r"}, + {ha[{b, m + 1, k + 1}], "r"}, + }); + auto bArgs2 = builder.newListOperand({{hb[{b, n, k + 1}], "r"}}); + mma(retArgs, aArgs1, bArgs1, cArgs); + mma(retArgs, aArgs2, bArgs2, cArgs); +} + +static void callMmaAmpere(PTXBuilder &builder, int b, int m, int n, int k, + mlir::triton::PTXInstr &mma, unsigned numMmaRets, + unsigned colsPerThread, int numCPackedElem, + unsigned batchOffset, ValueTableV2 &ha, + ValueTableV2 &hb, const SmallVector &fc, + bool isAccF16, bool isIntMMA) { + auto retArgs = + builder.newListOperand(numMmaRets, isIntMMA || isAccF16 ? "=r" : "=f"); + auto cArgs = builder.newListOperand(); + for (int i = 0; i < numMmaRets; ++i) { + cArgs->listAppend(builder.newOperand( + fc[(m * colsPerThread + 4 * n) / numCPackedElem + i + batchOffset * b], + std::to_string(i))); + // reuse the output registers + } + auto aArgs = builder.newListOperand({ + {ha[{b, m, k}], "r"}, + {ha[{b, m + 1, k}], "r"}, + {ha[{b, m, k + 1}], "r"}, + {ha[{b, m + 1, k + 1}], "r"}, + }); + auto bArgs = + builder.newListOperand({{hb[{b, n, k}], "r"}, {hb[{b, n, k + 1}], "r"}}); + mma(retArgs, aArgs, bArgs, cArgs); +} + +LogicalResult convertDot(const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Location loc, + Value a, Value b, Value c, Value d, Value loadedA, + Value loadedB, Value loadedC, DotOp op, + DotOpAdaptor adaptor, bool isTuring) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = c.getContext(); + auto aTensorTy = cast(a.getType()); + auto bTensorTy = cast(b.getType()); + auto dTensorTy = cast(d.getType()); + + auto aShapePerCTA = triton::gpu::getShapePerCTA(aTensorTy); + auto bShapePerCTA = triton::gpu::getShapePerCTA(bTensorTy); + auto dShapePerCTA = triton::gpu::getShapePerCTA(dTensorTy); + + int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth(); + auto dotOpA = cast(aTensorTy.getEncoding()); + int kWidth = dotOpA.getKWidth(); + auto repA = + cast(dotOpA.getParent()) + .getRepForOperand(aShapePerCTA, bitwidth, kWidth, dotOpA.getOpIdx()); + auto dotOpB = cast(bTensorTy.getEncoding()); + auto repB = + cast(dotOpB.getParent()) + .getRepForOperand(bShapePerCTA, bitwidth, kWidth, dotOpB.getOpIdx()); + + assert(repA[2] == repB[1]); + assert(repA[0] == repB[0]); + int repM = repA[1], repN = repB[2], repK = repA[2]; + int repBatch = repA[0]; + + // We can reuse the same iteration order in + // getValuesFromDotOperandLayoutStruct as both a and b are K-major + assert(dotOpA.getRepOrder() == getOrderForDotOperand(dotOpA.getOpIdx(), + aShapePerCTA.size(), + /*kContig=*/true)); + auto ha = getValuesFromDotOperandLayoutStruct( + typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy); + + assert(dotOpB.getRepOrder() == getOrderForDotOperand(dotOpB.getOpIdx(), + bShapePerCTA.size(), + /*kContig=*/true)); + auto hb = getValuesFromDotOperandLayoutStruct( + typeConverter, loc, rewriter, loadedB, repBatch, repN, repK, bTensorTy); + + auto fc = unpackLLElements(loc, loadedC, rewriter); + auto numMmaRets = dTensorTy.getElementType().getIntOrFloatBitWidth() / 8; + int numCPackedElem = 4 / numMmaRets; + + auto mmaType = getMmaType(op); + + const auto &mmaInstructions = + isTuring ? mmaInstrPtxTuring : mmaInstrPtxAmpere; + if (mmaInstructions.find(mmaType) == mmaInstructions.end()) { + return emitError(loc, "Unsupported MMA instruction for the given mma type"); + } + auto rank = dTensorTy.getRank(); + auto elemsPerThread = triton::gpu::getElemsPerThread(dTensorTy); + auto batchOffset = + elemsPerThread[rank - 2] * elemsPerThread[rank - 1] / numCPackedElem; + auto callMma = [&](unsigned b, unsigned m, unsigned n, unsigned k) { + unsigned colsPerThread = repN * 2; + PTXBuilder builder; + auto &mma = *builder.create(mmaInstructions.at(mmaType)); + // using =r for float32 works but leads to less readable ptx. + bool isIntMMA = dTensorTy.getElementType().isInteger(32); + bool isAccF16 = dTensorTy.getElementType().isF16(); + + if (isTuring) { + assert(b == 0 && "Turing only supports batch size 1"); + if (isIntMMA) // Turing int8 + callMmaTuringInt8(builder, b, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, ha, hb, fc); + else // Turing fp16 + callMmaTuringFp16(builder, b, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, ha, hb, fc, isAccF16); + } else { // Ampere + callMmaAmpere(builder, b, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, batchOffset, ha, hb, fc, isAccF16, + isIntMMA); + } + + Value mmaOut = + builder.launch(rewriter, loc, getMmaRetType(mmaType, op.getContext())); + + Type elemTy = cast(mmaOut.getType()).getBody()[0]; + for (int i = 0; i < numMmaRets; ++i) { + fc[(m * colsPerThread + 4 * n) / numCPackedElem + i + batchOffset * b] = + tb.extract_val(elemTy, mmaOut, i); + } + }; + + for (int b = 0; b < repBatch; ++b) + for (int k = 0; k < repK; ++k) + for (int m = 0; m < repM; ++m) + for (int n = 0; n < repN; ++n) + callMma(b, 2 * m, n, 2 * k); + + Type resElemTy = dTensorTy.getElementType(); + + // replace with new packed result + Type structTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(fc.size() * numCPackedElem, resElemTy)); + SmallVector results(fc.size() * numCPackedElem); + for (int i = 0; i < fc.size(); ++i) { + for (int j = 0; j < numCPackedElem; ++j) { + results[i * numCPackedElem + j] = + numCPackedElem > 1 + ? tb.bitcast(tb.extract_element(fc[i], tb.i32_val(j)), resElemTy) + : tb.bitcast(fc[i], resElemTy); + } + } + Value res = packLLElements(loc, typeConverter, results, rewriter, structTy); + + rewriter.replaceOp(op, res); + + return success(); +} + +LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, bool isTuring) { + assert(mlir::isa(op.getA().getType().getEncoding()) && + mlir::isa(op.getB().getType().getEncoding()) && + "Both $a and %b should be DotOperand layout."); + + Value loadedC = + loadC(op.getC(), adaptor.getC(), typeConverter, op.getLoc(), rewriter); + return convertDot(typeConverter, rewriter, op.getLoc(), op.getA(), op.getB(), + op.getC(), op.getD(), adaptor.getA(), adaptor.getB(), + loadedC, op, adaptor, isTuring); +} + +// Convert to mma.m16n8k8 +LogicalResult convertMMA1688(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + return convertMMA(op, adaptor, typeConverter, rewriter, true /*isTuring*/); +} + +// Convert to mma.m16n8k16 +LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + return convertMMA(op, adaptor, typeConverter, rewriter, false /*isTuring*/); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp new file mode 100644 index 000000000..e5ec0df36 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp @@ -0,0 +1,637 @@ +#include "Dialect/NVGPU/IR/Dialect.h" +#include "MMAHelpers.h" +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; +using namespace mlir::triton::NVIDIA; + +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::triton::gpu::NVMMASharedEncodingAttr; + +mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::DotOpMmaV5TmemLoader( + Value tensor, Value base, SmallVector instrShape, + bool interleaved, bool trans) + : base(base), instrShape(instrShape), interleaved(interleaved), + trans(trans) { + auto ty = cast(tensor.getType()); + auto tmemEncoding = + cast(ty.getEncoding()); + unpacked = tmemEncoding.getUnpacked(); + int elTyWidth = ty.getElementTypeBitWidth(); + numElementsPer32b = unpacked ? 1 : 32 / elTyWidth; + auto shapePerCTA = triton::gpu::getShapePerCTA(ty); + numRepM = ceil(shapePerCTA[0], instrShape[0]); +} + +Value mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::tmemLoad( + int a, int b, ConversionPatternRewriter &rewriter, Location loc) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + int numRows = 64; + if (interleaved || instrShape[0] >= 128) + numRows = 128; + int numColPerBlock = + ((instrShape[0] * instrShape[1]) / numRows) / numElementsPer32b; + Value address = base; + int blockId = a + b * numRepM; + address = tb.ptrtoint(i32_ty, address); + if (!interleaved) { + address = tb.add(address, tb.i32_val(numColPerBlock * blockId)); + } else { + int blockIdIsOdd = blockId & 1; + int blockIdPrevEven = blockId - blockIdIsOdd; + Value offset = tb.i32_val(numColPerBlock * blockIdPrevEven + + ((16 * blockIdIsOdd) << 16)); + address = tb.add(address, offset); + } + return address; +} + +namespace { + +enum class mxfpKind { mxf8f6f4 = 0, mxf4 = 1, mxf4nvf4 = 2 }; + +inline mxfpKind getMXFPKind(ScaleDotElemType typeA, ScaleDotElemType typeB, + Type scaleAType, Type scaleBType) { + if (typeA == ScaleDotElemType::E2M1 && typeB == ScaleDotElemType::E2M1) { + if (llvm::isa(scaleAType) && + llvm::isa(scaleBType)) { + return mxfpKind::mxf4nvf4; + } + return mxfpKind::mxf4; + } + return mxfpKind::mxf8f6f4; +}; + +static Value createInstDescriptor(ConversionPatternRewriter &rewriter, + triton::nvidia_gpu::TCGen5MMAOp op, int M, + int N, bool transposeA, bool transposeB) { + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + union TCGen5InstructionDescriptor { + uint32_t descriptor; + struct { + uint32_t sparsitySelector : 2; + uint32_t sparsity : 1; + uint32_t : 1; + uint32_t dType : 2; + uint32_t : 1; + uint32_t aType : 3; + uint32_t bType : 3; + uint32_t negateA : 1; + uint32_t negateB : 1; + uint32_t transposeA : 1; + uint32_t transposeB : 1; + uint32_t N : 6; + uint32_t : 1; + uint32_t M : 5; + uint32_t : 1; + uint32_t shift : 2; + }; + }; + auto getTypeEncoding = [](Type type) { + if (type.isF16()) + return 0; + if (type.isBF16()) + return 1; + if (type.isF32()) + return 2; + if (llvm::isa(type)) + return 0; + if (llvm::isa(type)) + return 1; + llvm_unreachable("Unsupported type."); + }; + static_assert(sizeof(TCGen5InstructionDescriptor) == 4, + "instruction descriptor size should be 32 bits."); + TCGen5InstructionDescriptor desc; + desc.descriptor = 0; + desc.transposeA = transposeA; + desc.transposeB = transposeB; + desc.M = M >> 4; + desc.N = N >> 3; + desc.aType = getTypeEncoding(op.getA().getType().getElementType()); + desc.bType = getTypeEncoding(op.getB().getType().getElementType()); + Type dstElType = op.getD().getType().getElementType(); + assert(dstElType.isF16() || dstElType.isF32()); + desc.dType = dstElType.isF16() ? 0 : 1; + return b.int_val(32, desc.descriptor); +} + +static Value createScaleInstDescriptor(ConversionPatternRewriter &rewriter, + triton::nvidia_gpu::TCGen5MMAScaledOp op, + int M, int N, bool transposeA, + bool transposeB, int scaleFactorsubIdxA, + int scaleFactorsubIdxB, + mxfpKind mxfpInstKind) { + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + union TCGen5InstructionDescriptor { + uint32_t descriptor; + struct { + uint32_t sparsitySelector : 2; + uint32_t sparsity : 1; + uint32_t : 1; + uint32_t BScaleFactor : 2; + uint32_t : 1; + uint32_t aType : 3; + uint32_t bType : 3; + uint32_t negateA : 1; + uint32_t negateB : 1; + uint32_t transposeA : 1; + uint32_t transposeB : 1; + uint32_t N : 6; + uint32_t scaleType : 1; + uint32_t M : 5; + uint32_t AScaleFactor : 2; + uint32_t : 1; + }; + }; + auto getTypeEncoding = [](ScaleDotElemType type, bool isMXF4) { + switch (type) { + case ScaleDotElemType::E4M3: + return 0; + case ScaleDotElemType::E5M2: + return 1; + case ScaleDotElemType::E2M3: + return 3; + case ScaleDotElemType::E3M2: + return 4; + case ScaleDotElemType::E2M1: + return !isMXF4 ? 5 : 1; + default: + break; + } + llvm_unreachable("Unsupported type."); + }; + static_assert(sizeof(TCGen5InstructionDescriptor) == 4, + "instruction descriptor size should be 32 bits."); + TCGen5InstructionDescriptor desc; + desc.descriptor = 0; + desc.transposeA = transposeA; + desc.transposeB = transposeB; + desc.M = M >> 4; + desc.N = N >> 3; + desc.aType = + getTypeEncoding(op.getAType(), mxfpInstKind != mxfpKind::mxf8f6f4); + desc.bType = + getTypeEncoding(op.getBType(), mxfpInstKind != mxfpKind::mxf8f6f4); + desc.AScaleFactor = scaleFactorsubIdxA; + desc.BScaleFactor = scaleFactorsubIdxB; + // Hardcoded UE8M0 scale type. + desc.scaleType = 1; + + if (mxfpInstKind != mxfpKind::mxf8f6f4) { + assert(desc.aType == 1 && desc.bType == 1); + assert(desc.AScaleFactor <= 1 && desc.BScaleFactor <= 1); + assert(desc.transposeA == 0 && + "MMAv5 with kind=mxf4 does not support transpose"); + assert(desc.transposeB == 0 && + "MMAv5 with kind=mxf4 does not support transpose"); + if (mxfpInstKind == mxfpKind::mxf4) { + desc.AScaleFactor *= 2; + desc.BScaleFactor *= 2; + assert(desc.AScaleFactor == 0 || + desc.AScaleFactor == 2 && + "MMAv5 with kind=mxf4 only supports SFA_ID 0 or 2"); + assert(desc.BScaleFactor == 0 || + desc.BScaleFactor == 2 && + "MMAv5 with kind=mxf4 only supports SFB_ID 0 or 2"); + } else if (mxfpInstKind == mxfpKind::mxf4nvf4) { + desc.scaleType = 0; // UE4M3 + assert(desc.AScaleFactor == 0 && + "MMAv5 with kind=mxf4nvf4 currently only supports SFA_ID 0"); + assert(desc.BScaleFactor == 0 && + "MMAv5 with kind=mxf4nvf4 currently only supports SFB_ID 0"); + } + } + + return b.int_val(32, desc.descriptor); +} + +static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc, + triton::nvidia_gpu::TCGen5MMAOp op, Value a, Value b, + Value d, Value pred, Value instDescriptor, + Value useInitAcc, bool aInTMem, bool twoCTAs) { + PTXBuilder ptxBuilder; + std::string opcode = + "tcgen05.mma.cta_group::" + std::to_string(twoCTAs ? 2 : 1) + ".kind::"; + Type srcElementTy = op.getA().getType().getElementType(); + if (srcElementTy.isF16() || srcElementTy.isBF16()) + opcode += "f16"; + else if (srcElementTy.isF32()) + opcode += "tf32"; + else if (llvm::isa(srcElementTy)) + opcode += "f8f6f4"; + else + assert(0 && "Unsupported type."); + auto *accOp = ptxBuilder.newAddrOperand(d, "r"); + auto *aOp = aInTMem ? ptxBuilder.newAddrOperand(a, "r") + : ptxBuilder.newOperand(a, "l"); + auto *bOp = ptxBuilder.newOperand(b, "l"); + auto *instDescOp = ptxBuilder.newOperand(instDescriptor, "r"); + auto *useInitAccOp = ptxBuilder.newOperand(useInitAcc, "b"); + auto &mmaOp = *ptxBuilder.create(opcode); + mmaOp({accOp, aOp, bOp, instDescOp, useInitAccOp}).predicate(pred); + ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext())); +} + +static void createScaledGen5MMA(ConversionPatternRewriter &rewriter, + Location loc, + triton::nvidia_gpu::TCGen5MMAScaledOp op, + Value a, Value b, Value d, Value scaleA, + Value scaleB, Value pred, Value instDescriptor, + Value useInitAcc, bool aInTmem, + mxfpKind mxfpInstKind) { + PTXBuilder ptxBuilder; + std::string opcode; + if (mxfpInstKind == mxfpKind::mxf8f6f4) { + opcode = + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X"; + } else if (mxfpInstKind == mxfpKind::mxf4) { + opcode = "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X"; + } else if (mxfpInstKind == mxfpKind::mxf4nvf4) { + opcode = + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X"; + } else { + assert(0 && "Unsupported mxfp kind."); + } + auto *accOp = ptxBuilder.newAddrOperand(d, "r"); + auto *aOp = aInTmem ? ptxBuilder.newAddrOperand(a, "r") + : ptxBuilder.newOperand(a, "l"); + auto *bOp = ptxBuilder.newOperand(b, "l"); + auto *instDescOp = ptxBuilder.newOperand(instDescriptor, "r"); + auto *scaleAOp = ptxBuilder.newAddrOperand(scaleA, "r"); + auto *scaleBOp = ptxBuilder.newAddrOperand(scaleB, "r"); + auto *useInitAccOp = ptxBuilder.newOperand(useInitAcc, "b"); + auto &mmaOp = *ptxBuilder.create(opcode); + mmaOp({accOp, aOp, bOp, instDescOp, scaleAOp, scaleBOp, useInitAccOp}) + .predicate(pred); + ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext())); +} + +static void createMMACommit(ConversionPatternRewriter &rewriter, Location loc, + Value barrier, Value pred, bool twoCTAs = false) { + PTXBuilder ptxBuilder; + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector ptxOperands; + auto *predOperand = ptxBuilder.newOperand(pred, "b"); + ptxOperands.push_back(predOperand); + auto *barrierOperand = ptxBuilder.newOperand(barrier, "l"); + ptxOperands.push_back(barrierOperand); + std::string opcode; + if (twoCTAs) { + // .multicast::cluster and mask 0x3 means the completion of UTCMMA.2CTA will + // be boardcasted into CTAid 0 and 1 + auto *ctaMask = ptxBuilder.newOperand(b.int_val(16, 0x3), "h"); + ptxOperands.push_back(ctaMask); + opcode = "@$0 " + "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::" + "cluster.multicast::cluster.b64 [$1], $2;"; + } else { + opcode = "@$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];"; + } + auto &barrierOp = *ptxBuilder.create(opcode); + barrierOp(ptxOperands, /*onlyAttachMLIRArgs=*/true); + ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext())); +} + +void convertDot(const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Location loc, + triton::nvidia_gpu::TCGen5MMAOp op, Value a, Value b, Value d, + Value loadedA, Value loadedB, Value loadedD, Value useDFlag, + Value pred, Value barrier) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + bool twoCTAs = op.getTwoCtas().has_value(); + // Only run mma on one thread. We currently use elect as ptxas is not able to + // detect that tid.x == 0 is true only for 1 thread. + Value warpId = rewriter.create(loc); + Value wapr0 = tb.icmp_eq(warpId, tb.i32_val(0)); + if (twoCTAs) { + // TODO: we have to sync the two CTAs because we currently don't use remove + // barriers for the copies. + rewriter.create(loc, false); + rewriter.create(loc); + + Value clusterId = rewriter.create(loc); + Value cluster0 = tb.icmp_eq(clusterId, tb.i32_val(0)); + pred = tb.and_(pred, cluster0); + } + pred = tb.and_(pred, wapr0); + + // Wrap the whole mma code sequence within a IF block. + auto *curBlock = rewriter.getInsertionBlock(); + auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint()); + auto *mmaBlock = rewriter.createBlock(curBlock->getParent(), + std::next(Region::iterator(curBlock))); + rewriter.setInsertionPointToEnd(curBlock); + rewriter.create(loc, pred, mmaBlock, endBlock); + // Emit the rest in mmaBlock + rewriter.setInsertionPointToEnd(mmaBlock); + + pred = LLVM::NVIDIA::createElectPredicate(loc, rewriter); + + auto aTensorTy = cast(a.getType()); + auto bTensorTy = cast(b.getType()); + auto dTensorTy = cast(d.getType()); + bool aInTmem = true; + bool transA = false; + if (auto aSharedLayout = + dyn_cast(aTensorTy.getEncoding())) { + transA = aSharedLayout.getTransposed(); + aInTmem = false; + } + auto bSharedLayout = cast(bTensorTy.getEncoding()); + bool transB = !bSharedLayout.getTransposed(); + Value baseA; + if (aInTmem) { + baseA = loadedA; + } else { + baseA = + getSharedMemoryObjectFromStruct( + loc, loadedA, + typeConverter->convertType(aTensorTy.getElementType()), rewriter) + .getBase(); + } + Value baseB = + getSharedMemoryObjectFromStruct( + loc, loadedB, typeConverter->convertType(bTensorTy.getElementType()), + rewriter) + .getBase(); + + SmallVector dstPerCTA = triton::gpu::getShapePerCTA(dTensorTy); + unsigned int M = dstPerCTA[0]; + unsigned int N = dstPerCTA[1]; + unsigned int K = aTensorTy.getDimSize(1); + // Get MMA size based on acc layout. + auto tensorMemAttr = cast( + dTensorTy.getEncoding()); + int mmaSizeM = tensorMemAttr.getBlockM(); + int mmaSizeN = tensorMemAttr.getBlockN(); + int mmaSizeK = 256 / aTensorTy.getElementTypeBitWidth(); + int numRepM = ceil(M, mmaSizeM); + int numRepN = ceil(N, mmaSizeN); + int numRepK = ceil(K, mmaSizeK); + assert((!aTensorTy.getElementType().isF32() || !(transA || transB)) && + "Currently don't support transpose for F32."); + bool interleaved = (mmaSizeM == 64 && (numRepM > 1 || numRepN > 1)); + Value instDescriptor = + createInstDescriptor(rewriter, op, twoCTAs ? mmaSizeM * 2 : mmaSizeM, + mmaSizeN, transA, transB); + Value zero = tb.i32_val(0); + SmallVector shapeA(triton::gpu::getShapePerCTA(aTensorTy)); + SmallVector shapeB(triton::gpu::getShapePerCTA(bTensorTy)); + SmallVector aOperandShape = {(unsigned)mmaSizeM, + (unsigned)mmaSizeK}; + std::unique_ptr aLoader; + if (aInTmem) { + aLoader = std::make_unique(a, baseA, aOperandShape, + interleaved, transA); + } else { + aLoader = std::make_unique( + a, baseA, shapeA, zero, 1, transA, aOperandShape, + aTensorTy.getElementTypeBitWidth(), rewriter, loc); + } + DotOpMmaV3SmemLoader bLoader = + DotOpMmaV3SmemLoader(b, baseB, shapeB, zero, 1, transB, + {(unsigned)mmaSizeN, (unsigned)mmaSizeK}, + bTensorTy.getElementTypeBitWidth(), rewriter, loc); + DotOpMmaV5TmemLoader dLoader = DotOpMmaV5TmemLoader( + d, loadedD, {(unsigned)mmaSizeM, (unsigned)mmaSizeN}, interleaved, false); + for (int m = 0; m < numRepM; m++) { + for (int n = 0; n < numRepN; n++) { + Value useInitAcc = useDFlag; + Value accAddress = dLoader.tmemLoad(m, n, rewriter, loc); + for (int k = 0; k < numRepK; k++) { + a = aLoader->memLoad(m, k, rewriter, loc); + b = bLoader.smemLoad(n, k, rewriter, loc); + createGen5MMA(rewriter, loc, op, a, b, accAddress, pred, instDescriptor, + useInitAcc, aInTmem, twoCTAs); + useInitAcc = tb.i1_val(1); + } + } + } + auto smemObj = + LLVM::getSharedMemoryObjectFromStruct(loc, barrier, i64_ty, rewriter); + createMMACommit(rewriter, loc, smemObj.getBase(), pred, twoCTAs); + rewriter.create(loc, endBlock); +} + +struct TCGen5MMAOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::TCGen5MMAOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto AEnc = op.getA().getType().getEncoding(); + auto BEnc = op.getB().getType().getEncoding(); + assert(mlir::isa(AEnc) || + mlir::isa(AEnc) && + "Operand A should use Shared or Tensor memory layout."); + assert(mlir::isa(BEnc) && + "Operand B should use Shared layout."); + assert(op.getBarrier() && + "tensorcore op should have a barrier at this point."); + auto typeConverter = getTypeConverter(); + convertDot(typeConverter, rewriter, op.getLoc(), op, // + op.getA(), op.getB(), op.getD(), // + adaptor.getA(), adaptor.getB(), adaptor.getD(), + adaptor.getUseD(), adaptor.getPred(), adaptor.getBarrier()); + rewriter.eraseOp(op); + return success(); + } +}; + +static int64_t getFormatBitSize(ScaleDotElemType type) { + switch (type) { + case ScaleDotElemType::E4M3: + return 8; + case ScaleDotElemType::E5M2: + return 8; + case ScaleDotElemType::E2M3: + return 6; + case ScaleDotElemType::E3M2: + return 6; + case ScaleDotElemType::E2M1: + return 4; + default: + llvm_unreachable("Unsupported type."); + } +} + +struct TCGen5MMAScaledOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::TCGen5MMAScaledOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(op.getBarrier() && + "tensorcore op should have a barrier at this point."); + auto typeConverter = getTypeConverter(); + Location loc = op.getLoc(); + auto tb = TritonLLVMOpBuilder(loc, rewriter); + auto aTensorTy = cast(op.getA().getType()); + auto bTensorTy = cast(op.getB().getType()); + auto dTensorTy = cast(op.getD().getType()); + mxfpKind mxfpInstKind = getMXFPKind( + op.getAType(), op.getBType(), op.getAScale().getType().getElementType(), + op.getBScale().getType().getElementType()); + bool opKindIsMXFP4 = mxfpInstKind != mxfpKind::mxf8f6f4; + bool aInTmem = true; + bool transA = false; + if (auto aSharedLayout = + dyn_cast(aTensorTy.getEncoding())) { + transA = aSharedLayout.getTransposed(); + aInTmem = false; + } + auto bSharedLayout = cast(bTensorTy.getEncoding()); + bool transB = !bSharedLayout.getTransposed(); + Value baseA = + getSharedMemoryObjectFromStruct( + loc, adaptor.getA(), + typeConverter->convertType(aTensorTy.getElementType()), rewriter) + .getBase(); + Value baseB = + getSharedMemoryObjectFromStruct( + loc, adaptor.getB(), + typeConverter->convertType(bTensorTy.getElementType()), rewriter) + .getBase(); + Value baseD = adaptor.getD(); + baseD = tb.ptrtoint(i32_ty, baseD); + Value baseScaleA = adaptor.getAScale(); + Value baseScaleB = adaptor.getBScale(); + baseScaleA = tb.ptrtoint(i32_ty, baseScaleA); + baseScaleB = tb.ptrtoint(i32_ty, baseScaleB); + + unsigned int M = dTensorTy.getDimSize(0); + unsigned int N = dTensorTy.getDimSize(1); + int numBitsPerElementA = opKindIsMXFP4 ? getFormatBitSize(op.getAType()) + : aTensorTy.getElementTypeBitWidth(); + int numBitsPerElementB = opKindIsMXFP4 ? getFormatBitSize(op.getBType()) + : bTensorTy.getElementTypeBitWidth(); + unsigned int K = (aTensorTy.getDimSize(1) * 8) / numBitsPerElementA; + + // Get MMA size based on acc layout. + auto tensorMemAttr = cast( + dTensorTy.getEncoding()); + int mmaSizeM = tensorMemAttr.getBlockM(); + int mmaSizeN = tensorMemAttr.getBlockN(); + int mmaSizeK = !opKindIsMXFP4 ? 32 : 64; + int numRepM = ceil(M, mmaSizeM); + int numRepN = ceil(N, mmaSizeN); + int numRepK = ceil(K, mmaSizeK); + bool interleaved = (mmaSizeM == 64 && (numRepM > 1 || numRepN > 1)); + + Value zero = tb.i32_val(0); + SmallVector shapeA( + triton::gpu::getAllocationShapePerCTA(aTensorTy)); + SmallVector shapeB( + triton::gpu::getAllocationShapePerCTA(bTensorTy)); + if (opKindIsMXFP4) { + shapeA[1] *= 2; + shapeB[0] *= 2; + } + SmallVector aOperandShape = {(unsigned)mmaSizeM, + (unsigned)mmaSizeK}; + std::unique_ptr aLoader; + if (aInTmem) { + aLoader = std::make_unique( + op.getA(), baseA, aOperandShape, interleaved, transA); + } else { + aLoader = std::make_unique( + op.getA(), baseA, shapeA, zero, 1, transA, aOperandShape, + numBitsPerElementA, rewriter, loc); + } + DotOpMmaV3SmemLoader bLoader = + DotOpMmaV3SmemLoader(op.getB(), baseB, shapeB, zero, 1, transB, + {(unsigned)mmaSizeN, (unsigned)mmaSizeK}, + numBitsPerElementB, rewriter, loc); + + // Only run mma on one thread. We currently use elect as ptxas is not able + // to detect that tid.x == 0 is true only for 1 thread. + Value pred = + tb.and_(adaptor.getPred(), + LLVM::NVIDIA::createElectPredicateWarp0(loc, rewriter)); + int numRows = 128; + int colSizeInBits = 32; + int numColPerBlock = + ceil((mmaSizeM * mmaSizeN * dTensorTy.getElementTypeBitWidth()), + (numRows * colSizeInBits)); + + int scaleFactorColsPerSet = [](mxfpKind kind) { + switch (kind) { + case mxfpKind::mxf8f6f4: + return 1; + case mxfpKind::mxf4: + return 2; + case mxfpKind::mxf4nvf4: + return 4; + default: + llvm_unreachable("Unsupported mxfp kind."); + } + }(mxfpInstKind); + int numColPerScaleBlockA = + ceil(triton::nvidia_gpu::getTmemAllocSizes( + cast(op.getAScale().getType())) + .numCols, + numRepM * (ceil(numRepK, 4 / scaleFactorColsPerSet))); + int numColPerScaleBlockB = + ceil(triton::nvidia_gpu::getTmemAllocSizes( + cast(op.getBScale().getType())) + .numCols, + numRepN * (ceil(numRepK, 4 / scaleFactorColsPerSet))); + for (int m = 0; m < numRepM; m++) { + for (int n = 0; n < numRepN; n++) { + // Blocks are laid out along M first then N as described in + // `TensorMemorySpace` definition. + int blockId = m + n * numRepM; + Value accAddress = tb.add(baseD, tb.i32_val(numColPerBlock * blockId)); + Value useInitAcc = op.getUseD(); + for (int k = 0; k < numRepK; k++) { + Value a = aLoader->memLoad(m, k, rewriter, loc); + Value b = bLoader.smemLoad(n, k, rewriter, loc); + int subWordIdx = k % (4 / scaleFactorColsPerSet); + int wordIdx = k / (4 / scaleFactorColsPerSet); + Value scaleA = tb.add(baseScaleA, tb.i32_val((m + wordIdx * numRepM) * + numColPerScaleBlockA)); + Value scaleB = tb.add(baseScaleB, tb.i32_val((n + wordIdx * numRepN) * + numColPerScaleBlockB)); + Value instDescriptor = createScaleInstDescriptor( + rewriter, op, mmaSizeM, mmaSizeN, transA, transB, subWordIdx, + subWordIdx, mxfpInstKind); + createScaledGen5MMA(rewriter, loc, op, a, b, accAddress, scaleA, + scaleB, pred, instDescriptor, useInitAcc, aInTmem, + mxfpInstKind); + useInitAcc = tb.i1_val(1); + } + } + } + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getBarrier(), i64_ty, rewriter); + createMMACommit(rewriter, loc, smemObj.getBase(), pred); + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + +namespace mlir { +namespace triton { +namespace NVIDIA { + +void populateTCGen5MMAOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add( + typeConverter, benefit); +} + +} // namespace NVIDIA +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp new file mode 100644 index 000000000..5b99ef061 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -0,0 +1,515 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "MMAHelpers.h" +#include "Utility.h" +#include "mlir/Support/LLVM.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::NVIDIA; + +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::MemDescType; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::NVMMASharedEncodingAttr; + +triton::nvgpu::WGMMAEltType getMmaRetType(Value d) { + auto dTy = cast(d.getType()).getElementType(); + if (dTy.isF32()) { + return triton::nvgpu::WGMMAEltType::f32; + } else if (dTy.isF16()) { + return triton::nvgpu::WGMMAEltType::f16; + } else if (dTy.isInteger(32)) { + return triton::nvgpu::WGMMAEltType::s32; + } else { + llvm::report_fatal_error("Unsupported mma result type found"); + } +} + +triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) { + auto aTy = cast(a.getType()).getElementType(); + if (aTy.isF16()) { + return triton::nvgpu::WGMMAEltType::f16; + } else if (aTy.isBF16()) { + return triton::nvgpu::WGMMAEltType::bf16; + } else if (aTy.isF32() && allowTF32) { + return triton::nvgpu::WGMMAEltType::tf32; + } else if (aTy.isInteger(8)) { + return triton::nvgpu::WGMMAEltType::s8; + } else if (llvm::isa(aTy)) { + return triton::nvgpu::WGMMAEltType::e5m2; + } else if (llvm::isa(aTy)) { + return triton::nvgpu::WGMMAEltType::e4m3; + } else { + llvm::report_fatal_error("Unsupported mma operand type found"); + } +} + +int64_t getSwizzlingFromLayout(const NVMMASharedEncodingAttr &layout, + uint32_t widthInByte) { + uint32_t swizzlingByteWidth = layout.getSwizzlingByteWidth(); + // TODO[biaow]: remove it once we support swizzling size larger than matrix + // width, which requires padding the matrix width to the swizzling size when + // allocating shared memory. + assert(swizzlingByteWidth <= widthInByte && + "swizzling size larger than matrix width is not supported."); + return swizzlingByteWidth; +} + +static Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, + int64_t swizzling, uint32_t stride) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + static_assert(sizeof(SMEMDescriptor) == 8, + "Descriptor size should be 64 bits."); + SMEMDescriptor desc; + desc.descriptor = 0; + switch (swizzling) { + case 0: + desc.swizzlingMode = 0; + break; + case 32: + desc.swizzlingMode = 3; + break; + case 64: + desc.swizzlingMode = 2; + break; + case 128: + desc.swizzlingMode = 1; + break; + default: + llvm::report_fatal_error("Unsupported swizzling size."); + } + desc.strideDimensionBaseOffset = swizzling >> 1; + desc.leadDimensionBaseOffset = (swizzling * stride) >> 4; + return b.int_val(64, desc.descriptor); +} + +mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::DotOpMmaV3SmemLoader( + Value tensor, Value base, SmallVector shape, Value warpId, + unsigned int dimWpt, bool trans, SmallVector instrShape, + int64_t elementBitwidth, ConversionPatternRewriter &rewriter, Location loc) + : base(base), shape(shape), warpId(warpId), dimWpt(dimWpt), trans(trans), + instrShape(instrShape), elemBits(elementBitwidth) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ty = cast(tensor.getType()); + auto sharedLayout = cast(ty.getEncoding()); + fastMovingDim = sharedLayout.getTransposed() ? 0 : 1; + const int swizzlingByteWidth = sharedLayout.getSwizzlingByteWidth(); + elemsPerSwizzlingRow = (swizzlingByteWidth * 8) / elemBits; + elemsPerSwizzlingRowVal = b.i32_val(elemsPerSwizzlingRow); + + uint32_t widthInByte = shape[fastMovingDim] * elemBits / 8; + int64_t swizzling = getSwizzlingFromLayout(sharedLayout, widthInByte); + + descriptor = + createDescriptor(rewriter, loc, swizzling, shape[1 - fastMovingDim]); +} + +Value mlir::triton::NVIDIA::DotOpMmaV3SmemLoader::smemLoad( + int a, int b, ConversionPatternRewriter &rewriter, Location loc) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + Value k = tb.i32_val(b * instrShape[1]); + Value m = tb.add(tb.i32_val(a * dimWpt * instrShape[0]), + tb.mul(warpId, tb.i32_val(instrShape[0]))); + if (trans) { + std::swap(k, m); + } + Value leading_offset = + tb.mul(tb.udiv(k, elemsPerSwizzlingRowVal), + tb.i32_val(shape[1 - fastMovingDim] * elemsPerSwizzlingRow)); + Value stride_offset = tb.mul(m, elemsPerSwizzlingRowVal); + Value offset = tb.add(tb.add(leading_offset, stride_offset), + tb.urem(k, elemsPerSwizzlingRowVal)); + Value off1; + // Avoid the runtime udiv if we know the elements are byte multiples + if (elemBits % 8) { + off1 = tb.udiv(tb.mul(tb.i32_val(elemBits), offset), tb.i32_val(8)); + } else { + off1 = tb.mul(tb.i32_val(elemBits / 8), offset); + } + Value off_ = tb.zext(i64_ty, tb.udiv(off1, tb.i32_val(16))); + + Value loadDesc = tb.add(descriptor, off_); + // Add the base at the end to make it easier to do loop invariant code + // motion. + loadDesc = tb.add( + loadDesc, tb.lshr(tb.shl(tb.ptrtoint(i64_ty, base), tb.int_val(64, 46)), + tb.int_val(64, 50))); + return loadDesc; +} + +DotOpMmaV3SmemLoader loadA(const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Location loc, + const NvidiaMmaEncodingAttr &mmaEncoding, + Value tensor, Value smemObjBase, Value thread) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto aTy = cast(tensor.getType()); + auto aSharedLayout = dyn_cast(aTy.getEncoding()); + assert(aSharedLayout && "only support load dot operand from shared."); + auto instrShape = mmaEncoding.getInstrShape(); + auto wpt = mmaEncoding.getWarpsPerCTA(); + bool transA = aSharedLayout.getTransposed(); + auto shapePerCTA = getShapePerCTA(aTy); + + // The descriptor should be calculated based on the first warp of the + // warpgroup. + Value warp = b.and_(b.udiv(thread, b.i32_val(32)), b.i32_val(0xFFFFFFFC)); + // Workaround for a bug in ptxas 12.3 that cause a failure in + // test_core.py::test_dot. The shuffle will force the compiler to treat the + // value as uniform and prevent wrong optimizations. + warp = mlir::LLVM::NVIDIA::shuffleIdx(loc, rewriter, warp, 0); + Value warpM = b.urem(warp, b.i32_val(wpt[0])); + Value warpId = b.urem(warpM, b.i32_val(shapePerCTA[0] / instrShape[0])); + + return {tensor, + smemObjBase, + shapePerCTA, + warpId, + wpt[0], + transA, + {instrShape[0], instrShape[2]}, + aTy.getElementTypeBitWidth(), + rewriter, + loc}; +} + +DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Location loc, + NvidiaMmaEncodingAttr &mmaEncoding, Value tensor, + Value base, Value thread) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto bTy = cast(tensor.getType()); + auto bSharedLayout = cast(bTy.getEncoding()); + assert(bSharedLayout && "only support load B from shared."); + auto instrShape = mmaEncoding.getInstrShape(); + auto wpt = mmaEncoding.getWarpsPerCTA(); + bool transB = !bSharedLayout.getTransposed(); + auto shapePerCTA = triton::gpu::getShapePerCTA(bTy); + + Value warp = b.and_(b.udiv(thread, b.i32_val(32)), b.i32_val(0xFFFFFFFC)); + Value warpMN = b.udiv(warp, b.i32_val(wpt[0])); + Value warpN = b.urem(warpMN, b.i32_val(wpt[1])); + Value warpId = b.urem(warpN, b.i32_val(shapePerCTA[1] / instrShape[1])); + + return {tensor, + base, + shapePerCTA, + warpId, + wpt[1], + transB, + {instrShape[1], instrShape[2]}, + bTy.getElementTypeBitWidth(), + rewriter, + loc}; +} + +// Return a vector of Value of the accumulator start at startIndex and pack the +// values into 32bits in case the accumulator is fp16. +// +// `elements` contains all loaded register values for operand A. +// This consists of operand A for possibly multiple wgmma instructions. +// For each wgmma, each warp in a warp group feeds a single "warp matrix" +// Each warp matrix consists of 2x2 "quads". +// Each thread holds several elements in each quad. Right before a wgmma, +// the sum of bitwidth of +// the elements in each quad should add up to 32. +// +// These values are stored unrolled in `elements`. +// The ordering of dimensions is as follows: +// batch (only 1 batch for Hopper currently) +// matM (m-index of the "warp matrix") +// matK (k-index of the "warp matrix") +// quadK (k-index of the "quad" in the core matrix) +// quadM (m-index of the "quad" in the core matrix) +// vecIdx (index of the element in the quad; this is always along the k-dim) +// +// This ordering is decided when a tensor in DotOpEnc is lowered into llvm. +// For WGMMA this happens in both SharedToDotOperand and MMAToDotOperand. +// Thus, both lowerings must obey this above ordering for the below code to be +// correct. +llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, + Location loc, + const SmallVector &elements, + int startIndex, int numElements, + Operation *insertBefore) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(insertBefore); + + if (!elements[0].getType().isIntOrFloat() || + elements[0].getType().getIntOrFloatBitWidth() >= 32) { + llvm::SmallVector mmaOut(numElements); + for (int i = 0; i < numElements; ++i) + mmaOut[i] = elements[startIndex + i]; + return mmaOut; + } + Type elementType = elements[0].getType(); + int numElemsPer32Bits = 32 / elementType.getIntOrFloatBitWidth(); + + // For FP16 and BF16 we need to pack accumulator into 32-bit integers. + int num32BitValues = numElements / numElemsPer32Bits; + llvm::SmallVector mmaOut(num32BitValues); + Type packTy = vec_ty(elementType, numElemsPer32Bits); + for (int i = 0; i < num32BitValues; ++i) { + Value pack = rewriter.create(loc, packTy); + for (int j = 0; j < numElemsPer32Bits; ++j) { + Value element = elements[startIndex + i * numElemsPer32Bits + j]; + pack = b.insert_element(packTy, pack, element, b.i32_val(j)); + } + pack = b.bitcast(pack, rewriter.getIntegerType(32)); + mmaOut[i] = pack; + } + return mmaOut; +} + +// If the accumulator is fp16 unpack it from 32-bit integers. +SmallVector unpackAccumulator(ConversionPatternRewriter &rewriter, + Location loc, + const SmallVector &packed, + RankedTensorType tensorTy) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (!tensorTy.getElementType().isF16()) + return packed; + // For fp16 the accumulator is pack into 32-bit integers so we need to unpack + // it. + SmallVector results; + for (Value elem : packed) { + elem = b.bitcast(elem, vec_ty(rewriter.getF16Type(), 2)); + results.push_back( + b.extract_element(rewriter.getF16Type(), elem, b.i32_val(0))); + results.push_back( + b.extract_element(rewriter.getF16Type(), elem, b.i32_val(1))); + } + return results; +} + +static Value faddAccumulate(ConversionPatternRewriter &rewriter, Location loc, + Value a, Value b) { + int numEl = cast(a.getType()).getBody().size(); + Value newStruct = rewriter.create(loc, a.getType()); + for (int i = 0; i < numEl; ++i) { + Value lhs = rewriter.create(loc, a, i); + Value rhs = rewriter.create(loc, b, i); + Value add = rewriter.create(loc, lhs, rhs); + newStruct = rewriter.create(loc, newStruct, add, i); + } + return newStruct; +} + +static SmallVector emitWait(ConversionPatternRewriter &rewriter, + Location loc, SmallVector acc, + int pendings) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector types(acc.size(), acc[0].getType()); + auto structTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); + Value llvmStruct = rewriter.create(loc, structTy); + int i = 0; + for (Value v : acc) { + llvmStruct = b.insert_val(structTy, llvmStruct, v, i++); + } + Value res = rewriter.create(loc, llvmStruct, + pendings); + SmallVector results; + for (int i = 0; i < acc.size(); ++i) { + results.push_back(b.extract_val(types[0], res, i)); + } + return results; +} + +LogicalResult convertDot(const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Location loc, + Operation *op, Value a, Value b, Value c, Value d, + Value useCOperand, Value loadedA, Value loadedB, + Value loadedC, bool allowTF32, + bool needsPartialAccumulator, + uint32_t maxNumImpreciseAcc, bool sync, Value thread) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + auto aTensorTy = cast(a.getType()); + auto bTensorTy = cast(b.getType()); + auto dTensorTy = cast(d.getType()); + auto aSharedLayout = + dyn_cast(aTensorTy.getEncoding()); + auto bSharedLayout = cast(bTensorTy.getEncoding()); + auto mmaEncoding = cast(dTensorTy.getEncoding()); + bool transA = false; + Value baseA; + Value baseB; + if (aSharedLayout) + baseA = + getSharedMemoryObjectFromStruct( + loc, loadedA, + typeConverter->convertType(aTensorTy.getElementType()), rewriter) + .getBase(); + baseB = getSharedMemoryObjectFromStruct( + loc, loadedB, + typeConverter->convertType(bTensorTy.getElementType()), rewriter) + .getBase(); + if (aSharedLayout) { + transA = aSharedLayout.getTransposed(); + } + bool transB = !bSharedLayout.getTransposed(); + auto dShapePerCTA = getShapePerCTA(dTensorTy); + auto instrShape = mmaEncoding.getInstrShape(); + auto accSize = 2 * (instrShape[1] / 4); + int M = 4 * instrShape[0]; + int N = instrShape[1]; + int K = instrShape[2]; + bool zeroAcc = isZeroConst(c); + auto instrMNK = mmaEncoding.getInstrShape(); + auto warpSize = mmaEncoding.getWarpsPerCTA(); + auto shapePerCTATile = SmallVector{instrMNK[0] * warpSize[0], + instrMNK[1] * warpSize[1]}; + int numRepM = ceil(dShapePerCTA[0], shapePerCTATile[0]); + int numRepN = ceil(dShapePerCTA[1], shapePerCTATile[1]); + int numRepK = ceil(aTensorTy.getShape()[1], instrShape[2]); + DotOpMmaV3SmemLoader aLoader; + SmallVector structA; + if (aSharedLayout) { + aLoader = + loadA(typeConverter, rewriter, loc, mmaEncoding, a, baseA, thread); + } else { + structA = unpackLLElements(loc, loadedA, rewriter); + } + DotOpMmaV3SmemLoader bLoader = + loadB(typeConverter, rewriter, loc, mmaEncoding, b, baseB, thread); + + auto fc = unpackLLElements(loc, loadedC, rewriter); + + triton::nvgpu::WGMMAEltType eltTypeC = getMmaRetType(d); + triton::nvgpu::WGMMAEltType eltTypeA = getMmaOperandType(a, allowTF32); + triton::nvgpu::WGMMAEltType eltTypeB = getMmaOperandType(b, allowTF32); + + triton::nvgpu::WGMMALayout layoutA = transA ? triton::nvgpu::WGMMALayout::col + : triton::nvgpu::WGMMALayout::row; + triton::nvgpu::WGMMALayout layoutB = transB ? triton::nvgpu::WGMMALayout::row + : triton::nvgpu::WGMMALayout::col; + + auto func = op->getParentOfType(); + Operation *startSequence = rewriter.create(loc); + SmallVector mmaResults; + for (int m = 0; m < numRepM; ++m) { + for (int n = 0; n < numRepN; ++n) { + llvm::SmallVector mmaOut = + loadReg(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize, + startSequence); + llvm::SmallVector elemTypes; + for (Value accEl : mmaOut) + elemTypes.push_back(accEl.getType()); + auto accTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); + Value d; + Value useC = tb.i1_val(0); + if (!zeroAcc) { + d = packLLElements(loc, typeConverter, mmaOut, rewriter, accTy); + useC = tb.i1_val(1); + } + if (useCOperand) + useC = tb.and_(useC, useCOperand); + uint32_t numLowPrecisionAcc = 0; + Value partialAcc; + for (int k = 0; k < numRepK; ++k) { + Value a; + if (aSharedLayout) { + a = aLoader.smemLoad(m, k, rewriter, loc); + } else { + auto aDotOpEnc = + cast(aTensorTy.getEncoding()); + assert(aDotOpEnc.getKWidth() == + 32 / aTensorTy.getElementTypeBitWidth()); + + unsigned regASize = (instrShape[0] * instrShape[2]) / 32; + llvm::SmallVector regA = + loadReg(rewriter, loc, structA, (m * numRepK + k) * regASize, + regASize, startSequence); + auto regATy = LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), + SmallVector(regA.size(), regA[0].getType())); + a = packLLElements(loc, typeConverter, regA, rewriter, regATy); + } + auto b = bLoader.smemLoad(n, k, rewriter, loc); + numLowPrecisionAcc += K; + // If using native accumulation would cause use to do more low precion + // accumulation than allowed do a separate allocation. + bool requireAddAccumulator = + needsPartialAccumulator && + (numLowPrecisionAcc >= maxNumImpreciseAcc || k == numRepK - 1); + Value mmaAcc = needsPartialAccumulator ? partialAcc : d; + mmaAcc = rewriter.create( + loc, accTy, a, b, useC, mmaAcc, M, N, K, eltTypeC, eltTypeA, + eltTypeB, layoutA, layoutB); + useC = tb.i1_val(1); + if (needsPartialAccumulator) + partialAcc = mmaAcc; + else + d = mmaAcc; + // If we need accumulate separately to have higher precision, insert + // adds. + if (requireAddAccumulator) { + d = d ? faddAccumulate(rewriter, loc, d, partialAcc) : partialAcc; + numLowPrecisionAcc = 0; + partialAcc = Value(); + } + } + auto acc = unpackLLElements(loc, d, rewriter); + for (int i = 0; i < acc.size(); ++i) { + mmaResults.push_back(acc[i]); + } + } + } + rewriter.create(loc); + + if (sync) + mmaResults = emitWait(rewriter, loc, mmaResults, 0); + + SmallVector results = + unpackAccumulator(rewriter, loc, mmaResults, dTensorTy); + + // replace with new packed result + Type structTy = LLVM::LLVMStructType::getLiteral( + mmaEncoding.getContext(), + SmallVector(results.size(), dTensorTy.getElementType())); + auto res = packLLElements(loc, typeConverter, results, rewriter, structTy); + rewriter.replaceOp(op, res); + return success(); +} + +LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op, + triton::nvidia_gpu::WarpGroupDotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Value thread) { + auto AEnc = op.getA().getType().getEncoding(); + auto BEnc = op.getB().getType().getEncoding(); + assert(mlir::isa(AEnc) || + mlir::isa(AEnc)); + assert(mlir::isa(BEnc) && + "Operand B should use Shared layout."); + return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), // + op.getA(), op.getB(), op.getC(), op.getD(), op.getUseC(), // + adaptor.getA(), adaptor.getB(), adaptor.getC(), // + op.getInputPrecision() == InputPrecision::TF32, + op.needsPartialAccumulator(), op.getMaxNumImpreciseAcc(), + !op.getIsAsync(), thread); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp new file mode 100644 index 000000000..4f1c36236 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -0,0 +1,841 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +#include "Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir::triton::gpu; + +namespace mlir::triton { + +namespace gpu { +namespace { + +/* ----- FP8E5M2 ------ */ +// This data-type is the standard FP8E5M2 format + +struct Fp8ConversionDesc { + std::string ptx; + int inVecWidthBits; + int outVecWidthBits; + size_t numElements; +}; + +static const Fp8ConversionDesc Fp16_to_Fp8E5M2_RTNE(bool hasNativeFP) { + Fp8ConversionDesc ret; + if (!hasNativeFP) { + ret = {"{ \n" + ".reg .b32 a<2>; \n" + "and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe + "and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit) + "add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080 + "add.u32 a1, a1, 0x00800080; \n" // (round to nearest) + "prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0 + "}", + 32, 32, 4}; + } else { + ret = {"cvt.rn.satfinite.e5m2x2.f16x2 $0, $1; \n\t", 32, 16, 2}; + } + return ret; +} + +const Fp8ConversionDesc Fp16_to_Fp8E5M2_RTZ = { + "{ \n" + ".reg .b32 a<2>; \n" + "and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe + "and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit) + "prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0 + "}", + 32, 32, 4}; + +static const Fp8ConversionDesc Fp8E5M2_to_Fp16(bool hasNativeFP) { + Fp8ConversionDesc ret; + if (!hasNativeFP) { + ret = {"{ \n" + "prmt.b32 $0, 0, $2, 0x5140; \n\t" + "prmt.b32 $1, 0, $2, 0x7362; \n\t" + "}", + 32, 32, 4}; + } else { + ret = {"cvt.rn.f16x2.e5m2x2 $0, $1; \n\t", 16, 32, 2}; + } + return ret; +} + +static const Fp8ConversionDesc Fp8E5M2_to_Bf16(bool hasNativeFP) { + Fp8ConversionDesc ret; + if (!hasNativeFP) { + ret = { + "{ \n" + ".reg .b32 a<2>, b<2>, c<4>, d<4>, e112; \n" // if input = 0xf1f2f3f4 + "mov.u32 e112, 0x77800000; \n" + "prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400 + "prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200 + "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff + "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign) + "shr.b32 b0, b0, 3; \n" // b0 >>= 3 + "shr.b32 b1, b1, 3; \n" // shift into bf16 + // position + "and.b32 c0, b0, 0xFFFF0000; \n" // c0 = f3 + "shl.b32 c1, b0, 16; \n" // c1 = f4 + "and.b32 c2, b1, 0xFFFF0000; \n" // c2 = f1 + "shl.b32 c3, b1, 16; \n" // c3 = f2 + "mul.f32 d0, c0, e112; \n" // d0 = c0 * 0x77800000 + "mul.f32 d1, c1, e112; \n" // d1 = c1 * 0x77800000 + "mul.f32 d2, c2, e112; \n" // d2 = c2 * 0x77800000 + "mul.f32 d3, c3, e112; \n" // d3 = c3 * 0x77800000 + "prmt.b32 b0, d0, d1, 0x3276; \n" // b0 = 0xd3d4 + "prmt.b32 b1, d2, d3, 0x3276; \n" // b1 = 0xd1d2 + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = + // b0|(0x80008000&a0) + "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign) + "}", + 32, 32, 4}; + } else { + ret = { + "{ \n" + ".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4 + ".reg .b32 e112; \n" + "mov.u32 e112, 0x77807780; \n" // 2**112 represented as + // bf16x2 + "prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400 + "prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200 + "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff + "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign) + "shr.b32 b0, b0, 3; \n" // b0 >>= 3 + "shr.b32 b1, b1, 3; \n" // shift into bf16 position + "lop3.b32 b0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0) + "lop3.b32 b1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign) + "mul.rn.bf16x2 $0, b0, e112; \n" // b0.exp += 2**7-2**4 + "mul.rn.bf16x2 $1, b1, e112; \n" // exponent compensate = 112 + "}", + 32, 32, 4}; + } + return ret; +} + +static const Fp8ConversionDesc Bf16_to_Fp8E5M2(bool hasNativeFP) { + Fp8ConversionDesc ret; + if (!hasNativeFP) { + ret = { + "{ \n" // bf16=fp8>>3 + 112<<7 + ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" // fp8_min = 0b00000000 + ".reg .u32 fp8_min, fp8_max, rn_; \n" // fp8_max = 0b11111111 + "mov.u32 fp8_min, 0x38003800; \n" // so bf16_min = 0x3800 + "mov.u32 fp8_max, 0x57e057e0; \n" // so bf16_max = 0x57e0 + "mov.u32 rn_, 0x00100010; \n" // round to nearest + "and.b32 sign0, $1, 0x80008000; \n" // sign0=in0&0x80008000 + "and.b32 sign1, $2, 0x80008000; \n" // (store sign) + "prmt.b32 sign, sign0, sign1, 0x7531; \n" + "and.b32 nosign0, $1, 0x7fff7fff; \n" // nosign0=in0&0x7fff7fff + "and.b32 nosign1, $2, 0x7fff7fff; \n" // (strip sign) + + // nosign = clamp(nosign, min, max) + ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n" + "and.b32 nosign_0_0, nosign0, 0xffff0000; \n" + "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n" + "min.u32 nosign_0_0, nosign_0_0, 0x57e00000; \n" + "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n" + "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n" + "min.u32 nosign_0_1, nosign_0_1, 0x57e0; \n" + "or.b32 nosign0, nosign_0_0, nosign_0_1; \n" + "and.b32 nosign_1_0, nosign1, 0xffff0000; \n" + "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n" + "min.u32 nosign_1_0, nosign_1_0, 0x57e00000; \n" + "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n" + "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n" + "min.u32 nosign_1_1, nosign_1_1, 0x57e0; \n" + "or.b32 nosign1, nosign_1_0, nosign_1_1; \n" + + "add.u32 nosign0, nosign0, rn_; \n" // nosign0 += rn_ + "add.u32 nosign1, nosign1, rn_; \n" // (round to nearest) + "sub.u32 nosign0, nosign0, 0x38003800; \n" // nosign0-=0x38003800 + "sub.u32 nosign1, nosign1, 0x38003800; \n" // (compensate offset) + "shl.b32 nosign0, nosign0, 3; \n" // nosign0 <<= 3 + "shl.b32 nosign1, nosign1, 3; \n" // shift into to fp8e4 + "prmt.b32 nosign, nosign0, nosign1, 0x7531; \n" // nosign0 = 0xf100f200 + // nosign1 = 0xf300f400 + // nosign = 0xf3f4f1f2 + "or.b32 $0, nosign, sign; \n" // restore sign + "}", + 32, 32, 4}; + } else { + ret = {"{ \n" + ".reg .b16 a<2>; \n" + ".reg .f32 b<2>; \n" + "mov.b32 {a0, a1}, $1; \n" + "cvt.f32.bf16 b0, a0; \n" + "cvt.f32.bf16 b1, a1; \n" + "cvt.rn.satfinite.e5m2x2.f32 $0, b1, b0; \n" + "}", + 32, 16, 2}; + } + return ret; +} + +// Fp8E4M3 (x2) -> Fp16 (x2) (packed) +static const Fp8ConversionDesc Fp8E4M3Nv_to_Fp16 = { + "{ \n" + "cvt.rn.f16x2.e4m3x2 $0, $1; \n" + "}", + 16, 32, 2}; + +// Fp16 (x2) -> Fp8E4M3 (x2) (packed) +static const Fp8ConversionDesc Fp16_to_Fp8E4M3Nv = { + "{ \n" + "cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n" + "}", + 32, 16, 2}; + +static const Fp8ConversionDesc Fp8E4M3Nv_to_Bf16(bool hasNativeFP) { + Fp8ConversionDesc ret; + // Fp8E4M3 (x2) -> Fp16 (x2) (packed) + if (!hasNativeFP) { + ret = {"{ \n" + ".reg .b32 a; \n" + ".reg .f16 a<2>; \n" + ".reg .f32 b<2>; \n" + ".reg .b16 c<2>; \n" + "cvt.rn.f16x2.e4m3x2 a, $1; \n" + "mov.b32 {a0, a1}, a; \n" + "cvt.f32.f16 b0, a0; \n" + "cvt.f32.f16 b1, a1; \n" + "cvt.rn.bf16.f32 c0, b0; \n" + "cvt.rn.bf16.f32 c1, b1; \n" + "mov.b32 $0, {c0, c1}; \n" + "}", + 16, 32, 2}; + } else { + ret = {"{ \n" + ".reg .b32 a; \n" + ".reg .f16 a<2>; \n" + ".reg .b16 b<2>; \n" + "cvt.rn.f16x2.e4m3x2 a, $1; \n" + "mov.b32 {a0, a1}, a; \n" + "cvt.bf16.f16 b0, a0; \n" + "cvt.bf16.f16 b1, a1; \n" + "mov.b32 $0, {b0, b1}; \n" + "}", + 16, 32, 2}; + } + return ret; +} + +// Bf16 (x2) -> Fp8E4M3 (x2) (packed) +static const Fp8ConversionDesc Bf16_to_Fp8E4M3Nv = { + "{ \n" + ".reg .b16 a<2>; \n" + ".reg .f32 b<2>; \n" + "mov.b32 {a0, a1}, $1; \n" + "cvt.f32.bf16 b0, a0; \n" + "cvt.f32.bf16 b1, a1; \n" + "cvt.rn.satfinite.e4m3x2.f32 $0, b1, b0; \n" + "}", + 32, 16, 2}; + +// Fp32 (x2) -> Fp8 (x2) (packed) +static const Fp8ConversionDesc Fp32_to_Fp8E4M3Nv = { + "cvt.rn.satfinite.e4m3x2.f32 $0, $2, $1; \n", 32, 16, 2}; +static const Fp8ConversionDesc Fp32_to_Fp8E5M2 = { + "cvt.rn.satfinite.e5m2x2.f32 $0, $2, $1; \n", 32, 16, 2}; + +/* ----- Packed integer to BF16 ------ */ +static const std::string S8_to_Bf16 = + "{ \n" + ".reg .s8 s<4>; \n" + ".reg .f32 f<4>; \n" + "mov.b32 {s0, s1, s2, s3}, $2; \n" // unpack + "cvt.rn.f32.s8 f0, s0; \n" // no s8->bf16 pre-Hopper + "cvt.rn.f32.s8 f1, s1; \n" // fi[0:15] is always 0 + "cvt.rn.f32.s8 f2, s2; \n" // + "cvt.rn.f32.s8 f3, s3; \n" // + "prmt.b32 $0, f0, f1, 0x7632; \n" // f32->bf16 + pack + "prmt.b32 $1, f2, f3, 0x7632; \n" // + "}"; +// Conversions have low throughput, rely on bit tricks instead of cvt +// instruction on Hopper and later GPUs. +static const std::string S8_to_Bf16_sm90 = + "{ \n" + ".reg .b32 l<3>; \n" + ".reg .b32 h<3>; \n" + "prmt.b32 l0, $2, 0x43, 0x4140; \n" // Unpack to shifted bf16. + "prmt.b32 h0, $2, 0x43, 0x4342; \n" + "and.b32 l1, l0, 0xff7fff7f; \n" // Zero the least exp bit. + "and.b32 h1, h0, 0xff7fff7f; \n" + "and.b32 l2, l0, 0xff80ff80; \n" // Zero the mantissa. + "and.b32 h2, h0, 0xff80ff80; \n" + "sub.bf16x2 $0, l1, l2; \n" // Subtract the offset. + "sub.bf16x2 $1, h1, h2; \n" + "}"; + +typedef std::function(Location, ConversionPatternRewriter &, + const SmallVector &)> + ConverterT; + +static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType, + Type outType, + const int inVecWidthBits = 32, + const int outVecWidthBits = 32) { + ConverterT converter = + [ptxAsm, inType, outType, inVecWidthBits, + outVecWidthBits](Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) -> SmallVector { + auto b = TritonLLVMOpBuilder(loc, rewriter); + int numElements = v.size(); + assert(numElements == 4 || numElements == 2 && "invalid vector size"); + + auto ctx = rewriter.getContext(); + int inBitwidth = inType.getIntOrFloatBitWidth(); + int outBitwidth = outType.getIntOrFloatBitWidth(); + // first, we pack `v` into 32-bit ints + int inVecWidth = inVecWidthBits / inBitwidth; + auto inVecTy = vec_ty(inType, inVecWidth); + SmallVector inPacked(numElements / inVecWidth, b.undef(inVecTy)); + for (size_t i = 0; i < numElements; i++) + inPacked[i / inVecWidth] = b.insert_element( + inVecTy, inPacked[i / inVecWidth], v[i], b.i32_val(i % inVecWidth)); + for (size_t i = 0; i < inPacked.size(); i++) + inPacked[i] = b.bitcast(inPacked[i], int_ty(inVecWidthBits)); + + // then, we run the provided inline PTX + int outVecWidth = outVecWidthBits / outBitwidth; + int outNums = numElements / outVecWidth; + PTXBuilder builder; + SmallVector operands; + auto outConstriant = outVecWidthBits == 16 ? "=h" : "=r"; + auto inConstraint = inVecWidthBits == 16 ? "h" : "r"; + for (int i = 0; i < outNums; i++) { + operands.push_back(builder.newOperand(outConstriant)); + } + + for (Value inVal : inPacked) { + operands.push_back(builder.newOperand(inVal, inConstraint)); + } + + auto &ptxOp = *builder.create(ptxAsm); + ptxOp(operands, /*onlyAttachMLIRArgs=*/true); + auto outVecTy = vec_ty(outType, outVecWidth); + SmallVector outPacked; + if (outNums == 1) + outPacked.push_back(builder.launch(rewriter, loc, outVecTy, false)); + else { + auto outStructTy = struct_ty(SmallVector(outNums, outVecTy)); + auto outStruct = builder.launch(rewriter, loc, outStructTy, false); + for (int i = 0; i < outNums; i++) + outPacked.push_back(b.extract_val(outVecTy, outStruct, i)); + } + // unpack the output + SmallVector ret; + for (size_t i = 0; i < numElements; i++) + ret.push_back(b.extract_element(outType, outPacked[i / outVecWidth], + b.i32_val(i % outVecWidth))); + return ret; + }; + return converter; +} + +// Attempts to use vectorized conversions via inline PTX when possible. +struct FpToFpOpConversion + : public ElementwiseOpConversionBase { + using ElementwiseOpConversionBase< + FpToFpOp, FpToFpOpConversion>::ElementwiseOpConversionBase; + + explicit FpToFpOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + int computeCapability, + PatternBenefit benefit = patternBenefitDefault) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + computeCapability(computeCapability) {} + + static Value convertFp16ToFp32(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v) { + return rewriter.create(loc, f32_ty, v); + } + + static Value convertFp32ToBf16(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v, const RoundingMode rounding) { + StringRef name; + switch (rounding) { + case RoundingMode::RTNE: + name = "llvm.nvvm.f2bf16.rn"; + break; + case RoundingMode::RTZ: + name = "llvm.nvvm.f2bf16.rz"; + break; + default: + emitError(loc) << "unsupported rounding mode for f32->bf16 conversion: " + << stringifyRoundingMode(rounding) << "\n"; + llvm::report_fatal_error( + "unsupported rounding mode for f32->bf16 conversion: " + + stringifyRoundingMode(rounding) + "\n"); + } + return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, name, bf16_ty, {v}) + .getResult(0); + } + + static Value convertFp32ToFp16(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v, const RoundingMode rounding) { + PTXBuilder builder; + StringRef ptx; + switch (rounding) { + case RoundingMode::RTNE: + ptx = "cvt.rn.f16.f32"; + break; + case RoundingMode::RTZ: + ptx = "cvt.rz.f16.f32"; + break; + default: + emitError(loc) << "unsupported rounding mode for f32->f16 conversion: " + << stringifyRoundingMode(rounding) << "\n"; + llvm::report_fatal_error( + "unsupported rounding mode for f32->f16 conversion: " + + stringifyRoundingMode(rounding) + "\n"); + } + auto &cvt = *builder.create(ptx.str()); + auto res = builder.newOperand("=h"); + auto operand = builder.newOperand(v, "r"); + cvt(res, operand); + return builder.launch(rewriter, loc, f16_ty, false); + } + + std::pair + getConversionFunc(Type srcTy, Type dstTy, + std::optional roundingMode) const { + auto F8E4M3TyID = TypeID::get(); + auto F8E5M2TyID = TypeID::get(); + auto F16TyID = TypeID::get(); + auto BF16TyID = TypeID::get(); + auto F32TyID = TypeID::get(); + auto F64TyID = TypeID::get(); + + auto undefRounding = static_cast(-1); + + static DenseMap, Fp8ConversionDesc> + srcMap = { + // F8 -> F16 + {{F8E4M3TyID, F16TyID, undefRounding}, Fp8E4M3Nv_to_Fp16}, + {{F8E5M2TyID, F16TyID, undefRounding}, + Fp8E5M2_to_Fp16(computeCapability >= 90)}, + {{F16TyID, F8E4M3TyID, RoundingMode::RTNE}, Fp16_to_Fp8E4M3Nv}, + {{F16TyID, F8E5M2TyID, RoundingMode::RTNE}, + Fp16_to_Fp8E5M2_RTNE(computeCapability >= 90)}, + {{F16TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp16_to_Fp8E5M2_RTZ}, + // F8 -> BF16 + {{F8E5M2TyID, BF16TyID, undefRounding}, + Fp8E5M2_to_Bf16(computeCapability >= 90)}, + {{F8E4M3TyID, BF16TyID, undefRounding}, + Fp8E4M3Nv_to_Bf16(computeCapability >= 90)}, + // BF16 -> F8 + {{BF16TyID, F8E5M2TyID, RoundingMode::RTNE}, + Bf16_to_Fp8E5M2(computeCapability >= 90)}, + {{BF16TyID, F8E4M3TyID, RoundingMode::RTNE}, Bf16_to_Fp8E4M3Nv}, + // F32 -> F8 + {{F32TyID, F8E4M3TyID, RoundingMode::RTNE}, Fp32_to_Fp8E4M3Nv}, + {{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, Fp32_to_Fp8E5M2}, + }; + std::tuple key = { + srcTy.getTypeID(), dstTy.getTypeID(), + roundingMode.value_or(undefRounding)}; + if (srcMap.count(key) == 0) { + llvm::errs() << "Unsupported conversion from " << srcTy << " to " + << dstTy; + if (roundingMode.has_value()) + llvm::errs() << " with rounding mode " + << stringifyRoundingMode(roundingMode.value()); + llvm::errs() << "\n"; + llvm::report_fatal_error("Unsupported rounding mode for conversion."); + } + if (computeCapability < 89 && (llvm::isa(srcTy) || + llvm::isa(dstTy))) { + llvm::report_fatal_error("Conversion from/to f8e4m3nv is only supported " + "on compute capability >= 89\n"); + } + auto convDesc = srcMap.lookup(key); + return {makeConverterFromPtx( + convDesc.ptx, getTypeConverter()->convertType(srcTy), + getTypeConverter()->convertType(dstTy), convDesc.inVecWidthBits, + convDesc.outVecWidthBits), + convDesc.numElements}; + } + + SmallVector createDestOps(FpToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcElementType = getElementType(op.getSrc()); + auto dstElementType = getElementType(op.getResult()); + auto roundingMode = op.getRounding(); + + if (llvm::isa(dstElementType)) { + assert(roundingMode.has_value() && + "Rounding mode must be specified for convertsions to fp8"); + + // For now only RTNE is supported for conversions from fp16 to fp8 + if (!srcElementType.isF32() && + roundingMode.value() != RoundingMode::RTNE) { + llvm::report_fatal_error( + "Unsupported rounding mode for conversion to fp8: " + + stringifyRoundingMode(roundingMode.value()) + "\n"); + } + } + + if (srcElementType.isF32() && dstElementType.isF16()) { + assert(roundingMode.has_value() && + "rounding mode must be specified for fp32->fp16 conversion"); + SmallVector outVals; + for (Value v : operands[0]) { + outVals.push_back( + convertFp32ToFp16(loc, rewriter, v, roundingMode.value())); + } + return outVals; + } + + if (srcElementType.isF32() && dstElementType.isBF16()) { + assert(roundingMode.has_value() && + "rounding mode must be specified for fp32->bf16 conversion"); + SmallVector outVals; + for (Value v : operands[0]) { + outVals.push_back( + convertFp32ToBf16(loc, rewriter, v, roundingMode.value())); + } + return outVals; + } + + bool useFP16IntermediateSrc = + srcElementType.isF32() && + (!(computeCapability >= 90 && + (llvm::isa(dstElementType))) || + roundingMode.value() == RoundingMode::RTZ); + bool isDstFP32 = dstElementType.isF32(); + Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType; + Type dstType = isDstFP32 ? f16_ty : dstElementType; + auto [cvtFunc, numElements] = + getConversionFunc(srcType, dstType, roundingMode); + SmallVector inVals; + for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) { + inVals.push_back(operands[i][0]); + } + if (useFP16IntermediateSrc) + for (Value &v : inVals) + v = convertFp32ToFp16(loc, rewriter, v, RoundingMode::RTZ); + inVals.resize(numElements, b.undef(typeConverter->convertType(srcType))); + SmallVector outVals = cvtFunc(loc, rewriter, inVals); + assert(outVals.size() == inVals.size()); + outVals.resize(std::min(numElements, operands.size())); + if (isDstFP32) + for (Value &v : outVals) + v = convertFp16ToFp32(loc, rewriter, v); + // Pack values + return outVals; + } + +private: + int computeCapability; +}; + +struct FDivOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::DivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + StringRef name; + Type resultTy; + if (32 == bitwidth) { + name = "llvm.nvvm.div.full"; + resultTy = f32_ty; + } else if (64 == bitwidth) { + name = "llvm.nvvm.div.rn.d"; + resultTy = f64_ty; + } else { + llvm::report_fatal_error("Unsupported bitwidth"); + } + Value args[] = {operands[0][0], operands[0][1]}; + auto callOp = + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, name, resultTy, args); + return {callOp.getResult(0)}; + } +}; + +// Uses inline ptx to convert s8/u8 to bf16, since the +struct SIToFPOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Adaptor = typename Base::OpAdaptor; + + explicit SIToFPOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + int computeCapability, + PatternBenefit benefit = patternBenefitDefault) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + computeCapability(computeCapability) {} + + SmallVector createDestOps(arith::SIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + Type inElemTy = getElementType(op.getIn()); + Type outElemTy = getElementType(op.getOut()); + if (outElemTy.isBF16() && inElemTy.isInteger(8) && operands.size() >= 4) { + auto cvtFunc = makeConverterFromPtx( + computeCapability >= 90 ? S8_to_Bf16_sm90 : S8_to_Bf16, + getTypeConverter()->convertType(inElemTy), + getTypeConverter()->convertType(outElemTy)); + SmallVector inVals = {operands[0][0], operands[1][0], + operands[2][0], operands[3][0]}; + auto outVals = cvtFunc(loc, rewriter, inVals); + assert(outVals.size() == 4); + return outVals; + } else { + return {rewriter.create(loc, elemTy, operands[0][0])}; + } + } + +private: + int computeCapability; +}; + +struct FPToSIOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::FPToSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto inElemTy = getElementType(op.getIn()); + return {rewriter.create(loc, elemTy, operands[0][0])}; + } +}; + +struct ExpOpConversionApprox + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::ExpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // For non-FP32 input, call __nv_expf for higher-precision calculation + if (elemTy.getIntOrFloatBitWidth() != 32) + return {}; + + const double log2e = 1.4426950408889634; + Value prod = b.fmul(f32_ty, operands[0][0], b.f32_val(log2e)); + + Type resultTy = operands[0][0].getType(); + StringRef name = "llvm.nvvm.ex2.approx.f"; + auto callOp = + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, name, resultTy, {prod}); + return {callOp.getResult(0)}; + } +}; + +struct ClampFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit ClampFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + int computeCapability, + PatternBenefit benefit = patternBenefitDefault) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + computeCapability(computeCapability) {} + + bool isClipPattern(ClampFOp op) const { + bool xorsignAbsAvailable = (computeCapability >= 90); + // Pattern matching the sequence of clamp(x, -limit, limit) to generate + // more efficient PTX code. NOTE: This pattern matching is not general + // enough, but it is sufficient. We detect only two cases here: + // 1. where the "-limit" is computed as 0 - limit: + // %cst = arith.constant dense<0.000000e+00> + // %8 = tt.load %7, %2 + // %11 = arith.subf %cst, %8 + // %12 = tt.clamp %5, %11, %8 + // 2. where "-limit" and "limit" are constants. + // %cst_6 = arith.constant dense<-6.0000e+00> + // %cst_7 = arith.constant dense<6.0000e+00> + // %160 = tt.clamp %158, %cst_6, %cst_7 + bool patternFound = false; + + auto getSplatInitializer = [](Value v) -> std::optional { + if (auto constOp = v.getDefiningOp()) { + if (auto attr = mlir::dyn_cast( + constOp.getValueAttr())) { + if (attr.isSplat()) { + return attr.getSplatValue().convertToDouble(); + } + } + } + return std::nullopt; + }; + + if (xorsignAbsAvailable) { + if (auto subOp = op.getOperand(1).getDefiningOp()) { + if (subOp.getOperand(1) == op.getOperand(2)) { + auto initializer = getSplatInitializer(subOp.getOperand(0)); + if (initializer.has_value() && initializer.value() == 0.0) { + patternFound = true; + } + } + } else { + auto initializer1 = getSplatInitializer(op.getOperand(1)); + auto initializer2 = getSplatInitializer(op.getOperand(2)); + if (initializer1.has_value() && initializer2.has_value() && + initializer1.value() == -initializer2.value()) { + patternFound = true; + } + } + } + return patternFound; + } + + SmallVector emitOptimization(ClampFOp op, + ConversionPatternRewriter &rewriter, + Type elemTy, + MultipleOperandsRange operands, + Location loc) const { + std::string name = "llvm.nvvm.fmin"; + if (op.getPropagateNan() == PropagateNan::ALL) { + name += ".nan"; + } + name += ".xorsign.abs"; + if (elemTy.isF32()) { + name += ".f"; + } else if (elemTy.isF16()) { + name += ".f16"; + } + + Type resultTy = operands[0][0].getType(); + Value args[] = {operands[0][0], operands[0][2]}; + auto callOp = + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, name, resultTy, args); + return {callOp.getResult(0)}; + } + + SmallVector createDestOps(ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + if (isClipPattern(op)) { + return emitOptimization(op, rewriter, elemTy, operands, loc); + } + return {}; + } + +private: + int computeCapability; +}; + +template +struct OpToExternCallConversion + : public ElementwiseOpConversionBase> { + using Base = + ElementwiseOpConversionBase>; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit OpToExternCallConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + StringRef externFuncName, + PatternBenefit benefit) + : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, + benefit), + funcName(externFuncName) {} + + SmallVector createDestOps(TritonOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + return { + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; + } + +private: + StringRef funcName; +}; +} // namespace +} // namespace gpu + +} // namespace mlir::triton + +void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, int computeCapability, + const TargetInfo &targetInfo, PatternBenefit benefit) { + using namespace mlir::triton::gpu; + + patterns.add>( + typeConverter, axisInfoAnalysis, "__nv_fsqrt_rn", benefit); + patterns.add>( + typeConverter, axisInfoAnalysis, "__nv_fdiv_rn", benefit); + + mlir::triton::populateElementwiseOpToLLVMPatterns( + typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); + +#define POPULATE_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit) + + POPULATE_OP(arith::SubFOp, LLVM::FSubOp); + POPULATE_OP(arith::AddFOp, LLVM::FAddOp); + POPULATE_OP(arith::MulFOp, LLVM::FMulOp); + + POPULATE_OP(arith::ExtFOp, LLVM::FPExtOp); + POPULATE_OP(arith::TruncFOp, LLVM::FPTruncOp); + +#undef POPULATE_OP + + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, + computeCapability, benefit); + patterns.add(typeConverter, axisInfoAnalysis, + computeCapability, benefit); + + // ExpOpConversionApprox will try using ex2.approx if the input type is + // FP32. For other input types, ExpOpConversionApprox will return failure and + // ElementwiseOpConversion defined below will call + // __nv_expf for higher-precision calculation + patterns.add(typeConverter, axisInfoAnalysis, benefit); + bool hwNanPropagationSupported = computeCapability >= 80; + mlir::triton::populateMinMaxFOpToLLVMPattern( + typeConverter, patterns, axisInfoAnalysis, hwNanPropagationSupported, + benefit); + mlir::triton::populateClampFOpToLLVMPattern( + typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); +} + +void mlir::triton::NVIDIA::populateClampFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, int computeCapability, + PatternBenefit benefit) { + using namespace mlir::triton::gpu; + + patterns.add(typeConverter, axisInfoAnalysis, + computeCapability, benefit); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Fp4ToFpOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Fp4ToFpOpToLLVM.cpp new file mode 100644 index 000000000..dbb27a26e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Fp4ToFpOpToLLVM.cpp @@ -0,0 +1,140 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" + +#include "PatternTritonGPUOpToLLVM.h" + +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// Convert 8 fp4 elements packed into a 32bit reg into 8 bf16 elements packed +// into 4 32bits regs. +static constexpr const char *FP4ToBP16Ptx = + "{\n" + ".reg .b32 a<14>;\n" + "and.b32 a0, $4, -2004318072;\n\t" + "shr.u32 a1, a0, 3;\n\t" + "and.b32 a2, $4, 2004318071;\n\t" + "shr.u32 a3, a2, 16;\n\t" + "shr.u32 a4, a0, 19;\n\t" + "prmt.b32 a5, -1065353216, -1065336832, a2;\n\t" + "prmt.b32 a6, -1065353216, -1065336832, a3;\n\t" + "prmt.b32 a7, 1061109504, 1077952576, a2;\n\t" + "prmt.b32 a8, 1061109504, 1077952576, a3;\n\t" + "prmt.b32 a9, 32768, 0, a1;\n\t" + "prmt.b32 a10, 32768, 0, a4;\n\t" + "or.b32 a11, a7, a9;\n\t" + "or.b32 a12, a8, a10;\n\t" + "prmt.b32 $0, a5, a11, 20800;\n\t" + "prmt.b32 $1, a5, a11, 29538;\n\t" + "prmt.b32 $2, a6, a12, 20800;\n\t" + "prmt.b32 $3, a6, a12, 29538;\n\t" + "}"; + +static constexpr const char *FP4ToFP16Ptx = + "{\n" + ".reg .b32 a<11>;\n" + ".reg .b16 t<4>;\n" + "and.b32 a0, $4, 0x77777777;\n\t" + "and.b32 a1, $4, 0x88888888;\n\t" + "shr.u32 a2, a1, 3;\n\t" + "shr.u32 a3, a0, 16;\n\t" + "shr.u32 a4, a2, 16;\n\t" + "prmt.b32 a5, 0x3C383000, 0x4C484440, a0;\n" + "prmt.b32 a6, 0x3C383000, 0x4C484440, a3;\n" + "prmt.b32 a7, 0x00008000, 0x0, a2;\n" + "prmt.b32 a8, 0x00008000, 0x0, a4;\n" + "or.b32 a9, a5, a7;\n\t" + "or.b32 a10, a6, a8;\n\t" + "mov.b32 {t0, t1}, a9;\n" + "mov.b32 {t2, t3}, a10;\n" + "cvt.rn.f16x2.e4m3x2 $0, t0;\n" + "cvt.rn.f16x2.e4m3x2 $1, t1;\n" + "cvt.rn.f16x2.e4m3x2 $2, t2;\n" + "cvt.rn.f16x2.e4m3x2 $3, t3;\n" + "}"; + +static Value createInlineAsmUpcast(Location loc, RewriterBase &rewriter, + bool toFp16, Type retType, Value packedVec) { + PTXBuilder builder; + SmallVector operands; + for (int i = 0; i < 4; i++) { + operands.push_back(builder.newOperand("=r")); + } + operands.push_back(builder.newOperand(packedVec, "r")); + auto &ptxOp = *builder.create(toFp16 ? FP4ToFP16Ptx : FP4ToBP16Ptx); + ptxOp(operands, /*onlyAttachMLIRArgs=*/true); + Value result = builder.launch(rewriter, loc, retType, false); + return result; +} + +namespace { +class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern { +public: + Fp4ToFpOpPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + + LogicalResult + matchAndRewrite(Fp4ToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + auto *ctx = op.getContext(); + auto elemType = op.getType().getElementType(); + assert(elemType == f16_ty || elemType == bf16_ty); + bool toFp16 = elemType == f16_ty; + + auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + + SmallVector results; + results.reserve(xVals.size() * 2); + assert(xVals.size() % 4 == 0); + auto b = TritonLLVMOpBuilder(loc, rewriter); + for (int i = 0; i < xVals.size(); i += 4) { + Value v0 = xVals[i]; + Value v1 = xVals[i + 1]; + Value v2 = xVals[i + 2]; + Value v3 = xVals[i + 3]; + Value packedVec = b.undef(vec_ty(i8_ty, 4)); + packedVec = b.insert_element(packedVec, v0, b.i32_val(0)); + packedVec = b.insert_element(packedVec, v1, b.i32_val(1)); + packedVec = b.insert_element(packedVec, v2, b.i32_val(2)); + packedVec = b.insert_element(packedVec, v3, b.i32_val(3)); + SmallVector rets(4, i32_ty); + Type retType = struct_ty(rets); + Value ret = + createInlineAsmUpcast(loc, rewriter, toFp16, retType, packedVec); + for (int i = 0; i < 4; i++) { + Value extractI32 = b.extract_val(ret, i); + Value elements = b.bitcast(extractI32, vec_ty(elemType, 2)); + results.push_back(b.extract_element(elements, b.i32_val(0))); + results.push_back(b.extract_element(elements, b.i32_val(1))); + } + } + + Value result = packLLElements(loc, getTypeConverter(), results, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // anonymous namespace + +void mlir::triton::NVIDIA::populateFp4ToFpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp new file mode 100644 index 000000000..dce3f0bc7 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -0,0 +1,1699 @@ +#include "Dialect/NVGPU/IR/Dialect.h" +#include "TargetInfo.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeUtilities.h" + +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" + +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include "triton/Tools/Sys/GetEnv.hpp" +#include + +using namespace mlir; +using namespace mlir::triton; +namespace ttg = mlir::triton::gpu; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getCTALayout; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::NVMMASharedEncodingAttr; + +namespace ttg = mlir::triton::gpu; + +// Toggle this to work around Cooperative Grid Launch ld.acquire optimized path +static constexpr bool disableLDAcquireLowering = false; + +namespace { + +llvm::MapVector getAllFreeVarMasks(MLIRContext *ctx) { + // Mask where all elements are redundant + auto kReg = str_attr("reg"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + int32_t fullMask = -1; + llvm::MapVector ret; + for (auto dimName : {kReg, kLane, kWarp, kBlock}) { + ret[dimName] = fullMask; + } + return ret; +} + +llvm::MapVector getFreeVariableMasks(Type type) { + auto ctx = type.getContext(); + auto tensorTy = dyn_cast(type); + if (!tensorTy) { + return getAllFreeVarMasks(ctx); + } + + auto ll = ttg::toLinearLayout(tensorTy.getShape(), tensorTy.getEncoding()); + return ll.getFreeVariableMasks(); +} + +Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + if (a && b) { + return tb.and_(a, b); + } + return a ? a : b; +} + +// Return a predicate that is true only if the current thread holds unique data, +// according to freeVarsMask. The predicate may be null to indicate no +// predication is required. +Value emitRedundantThreadPredicate( + ModuleOp moduleOp, const llvm::MapVector &freeVarMasks, + ConversionPatternRewriter &rewriter, Location loc, + const NVIDIA::TargetInfo &targetInfo) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ctx = rewriter.getContext(); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + Value zero = b.i32_val(0); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + Value blockId = freeVarMasks.lookup(kBlock) == 0 + ? zero + : targetInfo.getClusterCTAId(rewriter, loc); + + Value pred; + auto dimNames = {kLane, kWarp, kBlock}; + auto dimIds = {laneId, warpId, blockId}; + for (auto [dimName, dimId] : llvm::zip(dimNames, dimIds)) { + int32_t mask = freeVarMasks.lookup(dimName); + if (mask != 0) { + auto dimPred = b.icmp_eq(b.and_(dimId, b.i32_val(mask)), zero); + pred = maybeAnd(rewriter, loc, pred, dimPred); + } + } + return pred; +} + +bool isCanonicalIndex(unsigned index, unsigned freeVarMask) { + return (index & freeVarMask) == 0; +} + +unsigned getCanonicalIndex(unsigned index, unsigned freeVarMask) { + return index & ~freeVarMask; +} + +std::string getRegisterSizeCode(int size, bool is_float) { + switch (size) { + case 1: + return "b"; + case 16: + return "h"; + case 32: + return is_float ? "f" : "r"; + case 64: + return is_float ? "d" : "l"; + case 128: + return "q"; + default: + llvm_unreachable("Unsupported register size"); + } +} + +// Contains some helper functions for both Load and Store conversions. +struct LoadStoreConversionBase { + explicit LoadStoreConversionBase(const NVIDIA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass) + : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {} + + unsigned getContiguity(Value ptr) const { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + return axisAnalysisPass.getContiguity(ptr); + } + + unsigned getVectorSize(Value ptr) const { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto contiguity = getContiguity(ptr); + auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); + LDBG("getVectorSize contiguity = " << contiguity << " pointeeBitWidth = " + << pointeeBitWidth); + // The maximum vector size is 128 bits on NVIDIA GPUs. + return std::min(128 / pointeeBitWidth, contiguity); + } + + unsigned getMaskAlignment(Value mask) const { + return axisAnalysisPass.getMaskAlignment(mask); + } + +protected: + const NVIDIA::TargetInfo &targetInfo; + ModuleAxisInfoAnalysis &axisAnalysisPass; +}; + +struct LoadOpConversion : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + LoadOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ctx = getContext(); + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto typeConverter = getTypeConverter(); + + // original values + Value ptr = op.getPtr(); + Value mask = op.getMask(); + Value other = op.getOther(); + LDBG("Lower LoadOp for " << ptr); + + // adaptor values + assert(!isTensorPointerType(ptr.getType()) && + "Cannot convert load with a tensor pointer into LLVM; " + "this case should be transformed to normal load before lowering"); + Value llPtr = adaptor.getPtr(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); + + // Determine the vectorization size + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(op.getType())); + unsigned vec = getVectorSize(ptr); + unsigned numElems = getTotalElemsPerThread(ptr.getType()); + unsigned vecOrig = vec; + if (llMask) { + LLVM_DEBUG(DBGS() << "vec = " << vec + << " mask_alignment = " << getMaskAlignment(mask)); + vec = std::min(vec, getMaskAlignment(mask)); + LLVM_DEBUG(llvm::dbgs() << " vec = " << vec << '\n'); + } + + if (vec == 1 && numElems > 1) { + int maskValue = !llMask ? -1 : getMaskAlignment(mask); + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " numElems = " << numElems << " mask is " << maskValue + << "\n"; + } + // Get the LLVM values for pointers + auto ptrElems = unpackLLElements(loc, llPtr, rewriter); + assert(ptrElems.size() == numElems); + + // Get the LLVM values for mask + SmallVector maskElems; + if (llMask) { + maskElems = unpackLLElements(loc, llMask, rewriter); + assert(maskElems.size() == numElems); + } + + // Get the LLVM values for `other` + // TODO: (goostavz) handle when other is const but not splat, which + // should be rarely seen + bool otherIsSplatConstInt = false; + DenseElementsAttr constAttr; + int64_t splatVal = 0; + if (other && isa(valueElemTy) && + matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() && + isa(constAttr.getElementType())) { + otherIsSplatConstInt = true; + splatVal = constAttr.getSplatValue().getSExtValue(); + } + SmallVector otherElems; + if (other) { + otherElems = unpackLLElements(loc, llOther, rewriter); + } + + // vectorized iteration through all the pointer/mask/other elements + const int valueElemNBits = + std::max(8u, valueElemTy.getIntOrFloatBitWidth()); + const int numVecs = numElems / vec; + + // Load redundantly in all dims except reg + auto freeVarMasks = getFreeVariableMasks(ptr.getType()); + uint32_t regMask = freeVarMasks[str_attr("reg")]; + + LDBG("LoadOp numElems = " << numElems << " vec = " << vec + << " valueElemNBits = " << valueElemNBits << " " + << op.getType()); + SmallVector loadedVals; + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + if (auto canonicalVecStart = getCanonicalIndex(vecStart, regMask); + vecStart != canonicalVecStart) { + // For redundant registers, refer back to the canonical load + for (auto iVec = 0; iVec < vec; ++iVec) { + loadedVals.push_back(loadedVals[canonicalVecStart + iVec]); + } + continue; + } + + // TODO: optimization when ptr is GEP with constant offset + size_t in_off = 0; + + const size_t maxWordWidth = std::max(32, valueElemNBits); + const size_t totalWidth = valueElemNBits * vec; + const size_t width = std::min(totalWidth, maxWordWidth); + const size_t nWords = std::max(1, totalWidth / width); + const size_t wordNElems = width / valueElemNBits; + const size_t movWidth = width < 16 ? 16 : width; + assert(wordNElems * nWords * numVecs == numElems); + + // TODO(Superjomn) Add cache policy fields to StoreOp. + // TODO(Superjomn) Deal with cache policy here. + const bool hasL2EvictPolicy = false; + + PTXBuilder ptxBuilder; + + Value pred = mask ? maskElems[vecStart] : Value{}; + + const std::string readConstraint = + (width == 64) ? "l" : ((width == 32) ? "r" : "c"); + const std::string writeConstraint = + (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); + + // prepare asm operands + auto *dstsOpr = ptxBuilder.newListOperand(); + // If there is a `other` value, use it to init. + bool init = other == nullptr; + for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { + auto *opr = ptxBuilder.newOperand(writeConstraint, + init); // =r operations + dstsOpr->listAppend(opr); + } + + if (other) { + for (size_t ii = 0; ii < nWords; ++ii) { + // PTX doesn't support mov.u8, so we need to use mov.u16 + PTXInstr &mov = + ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth)); + + size_t size = width / valueElemNBits; + + auto vecTy = LLVM::getFixedVectorType(valueElemTy, size); + Value v = b.undef(vecTy); + for (size_t s = 0; s < size; ++s) { + Value falseVal = otherElems[vecStart + ii * size + s]; + Value sVal = createIndexAttrConstant( + rewriter, loc, typeConverter->getIndexType(), s); + v = b.insert_element(vecTy, v, falseVal, sVal); + } + v = b.bitcast(v, IntegerType::get(getContext(), width)); + + PTXInstr::Operand *opr{}; + + if (otherIsSplatConstInt) { + int64_t replicatedSplatVal = 0; + for (size_t s = 0; s < movWidth; s += valueElemNBits) { + replicatedSplatVal |= splatVal << s; + } + opr = ptxBuilder.newConstantOperand(replicatedSplatVal); + } else + opr = ptxBuilder.newOperand(v, readConstraint); + + mov(dstsOpr->listGet(ii), opr); + } + } + + auto *addrOpr = + ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); + + // Define the instruction opcode + auto &ld = ptxBuilder.create<>("ld") + ->o("volatile", op.getIsVolatile()) + .global() + .o("ca", op.getCache() == triton::CacheModifier::CA) + .o("cg", op.getCache() == triton::CacheModifier::CG) + .o("L1::evict_first", + op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) + .o("L1::evict_last", + op.getEvict() == triton::EvictionPolicy::EVICT_LAST) + .o("L1::cache_hint", hasL2EvictPolicy) + .v(nWords) + .b(width); + + PTXBuilder::Operand *evictOpr{}; + + // Here lack a mlir::Value to bind to this operation, so disabled. + // if (has_l2_evict_policy) + // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); + + if (!evictOpr) + ld(dstsOpr, addrOpr).maybePredicate(pred, "b"); + else + ld(dstsOpr, addrOpr, evictOpr).maybePredicate(pred, "b"); + + // Create inline ASM signature + SmallVector retTys(nWords, IntegerType::get(getContext(), width)); + Type retTy = retTys.size() > 1 + ? LLVM::LLVMStructType::getLiteral(getContext(), retTys) + : retTys[0]; + + // TODO: if (has_l2_evict_policy) + // auto asmDialectAttr = + // LLVM::AsmDialectAttr::get(rewriter.getContext(), + // LLVM::AsmDialect::AD_ATT); + Value ret = ptxBuilder.launch(rewriter, loc, retTy); + + // Extract and store return values + SmallVector rets; + for (unsigned int ii = 0; ii < nWords; ++ii) { + Value curr; + if (isa(retTy)) { + curr = b.extract_val(IntegerType::get(getContext(), width), ret, ii); + } else { + curr = ret; + } + curr = b.bitcast(curr, LLVM::getFixedVectorType( + valueElemTy, width / valueElemNBits)); + rets.push_back(curr); + } + int tmp = width / valueElemNBits; + for (size_t ii = 0; ii < vec; ++ii) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, typeConverter->getIndexType(), ii % tmp); + Value loaded = b.extract_element(valueElemTy, rets[ii / tmp], vecIdx); + loadedVals.push_back(loaded); + } + } // end vec + + Type llvmResultStructTy = typeConverter->convertType(op.getType()); + Value resultStruct = packLLElements(loc, typeConverter, loadedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct StoreOpConversion : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + StoreOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = op.getPtr(); + Value value = op.getValue(); + + Value llPtr = adaptor.getPtr(); + Value llMask = adaptor.getMask(); + Value llValue = adaptor.getValue(); + + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + + auto valueTy = value.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + + unsigned vec = getVectorSize(ptr); + unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType()); + + auto ptrElems = unpackLLElements(loc, llPtr, rewriter); + auto valueElems = unpackLLElements(loc, llValue, rewriter); + assert(ptrElems.size() == valueElems.size()); + + // Determine the vectorization size + unsigned vecOrig = vec; + SmallVector maskElems; + if (llMask) { + Value mask = op.getMask(); + maskElems = unpackLLElements(loc, llMask, rewriter); + assert(valueElems.size() == maskElems.size()); + + unsigned maskAlign = getMaskAlignment(mask); + vec = std::min(vec, maskAlign); + } + + if (vec == 1 && elemsPerThread > 1) { + int mask = !llMask ? -1 : getMaskAlignment(op.getMask()); + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread << " mask is " + << mask << "\n"; + } + + auto moduleOp = op->getParentOfType(); + const size_t dtsize = + std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); + const size_t valueElemNBits = dtsize * 8; + + auto freeVarMasks = getFreeVariableMasks(ptr.getType()); + Value threadPred = emitRedundantThreadPredicate(moduleOp, freeVarMasks, + rewriter, loc, targetInfo); + uint32_t regMask = freeVarMasks[str_attr("reg")]; + + const int numVecs = elemsPerThread / vec; + for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) { + if (!isCanonicalIndex(vecStart, regMask)) { + // Don't emit store ops for redundant elements within a thread + continue; + } + // TODO: optimization when ptr is AddPtr with constant offset + size_t in_off = 0; + + const size_t maxWordWidth = std::max(32, valueElemNBits); + const size_t totalWidth = valueElemNBits * vec; + const size_t width = std::min(totalWidth, maxWordWidth); + const size_t nWords = std::max(1, totalWidth / width); + const size_t wordNElems = width / valueElemNBits; + assert(wordNElems * nWords * numVecs == elemsPerThread); + + // TODO(Superjomn) Add cache policy fields to StoreOp. + // TODO(Superjomn) Deal with cache policy here. + + Type valArgTy = IntegerType::get(ctx, width); + auto wordTy = vec_ty(valueElemTy, wordNElems); + + SmallVector> asmArgs; + for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { + // llWord is a width-len composition + Value llWord = b.undef(wordTy); + // Insert each value element to the composition + for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) { + const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx; + assert(elemOffset < valueElems.size()); + Value elem = valueElems[elemOffset]; + if (elem.getType().isInteger(1)) + elem = b.sext(i8_ty, elem); + elem = b.bitcast(elem, valueElemTy); + + llWord = b.insert_element(wordTy, llWord, elem, b.i32_val(elemIdx)); + } + llWord = b.bitcast(llWord, valArgTy); + std::string constraint = + (width == 64) ? "l" : ((width == 32) ? "r" : "c"); + asmArgs.emplace_back(llWord, constraint); + } + + // Prepare the PTX inline asm. + PTXBuilder ptxBuilder; + auto *asmArgList = ptxBuilder.newListOperand(asmArgs); + + Value pred = threadPred; + if (llMask) { + auto mask = maskElems[vecStart]; + pred = maybeAnd(rewriter, loc, pred, mask); + } + + auto *asmAddr = + ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); + + auto &ptxStoreInstr = + ptxBuilder.create<>("st") + ->global() + .o("wb", op.getCache() == triton::CacheModifier::WB) + .o("cg", op.getCache() == triton::CacheModifier::CG) + .o("cs", op.getCache() == triton::CacheModifier::CS) + .o("wt", op.getCache() == triton::CacheModifier::WT) + .o("L1::evict_first", + op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) + .o("L1::evict_last", + op.getEvict() == triton::EvictionPolicy::EVICT_LAST) + .v(nWords) + .b(width); + ptxStoreInstr(asmAddr, asmArgList).maybePredicate(pred, "b"); + + auto asmReturnTy = void_ty(ctx); + ptxBuilder.launch(rewriter, loc, asmReturnTy); + } + rewriter.eraseOp(op); + return success(); + } +}; + +void createBarrier(ConversionPatternRewriter &rewriter, Operation *op, + int numCTAs) { + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (numCTAs == 1) { + insertBarrier(rewriter, op); + } else { + rewriter.create(loc, false); + rewriter.create(loc); + } +} + +struct AtomicCASOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + AtomicCASOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for AtomicCASOp"); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + + Value llPtr = adaptor.getPtr(); + Value llCmp = adaptor.getCmp(); + Value llVal = adaptor.getVal(); + + auto ptrElements = unpackLLElements(loc, llPtr, rewriter); + auto cmpElements = unpackLLElements(loc, llCmp, rewriter); + auto valElements = unpackLLElements(loc, llVal, rewriter); + + auto valueTy = op.getType(); + auto tensorTy = dyn_cast(valueTy); + Type valueElemTy = + tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType()) + : valueTy; + auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); + auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); + // vec = 1 for scalar + auto vec = getVectorSize(op.getPtr()); + auto vecOrig = vec; + // tensor + if (tensorTy) { + auto valTy = cast(op.getVal().getType()); + vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); + } + + if (vec == 1 && elemsPerThread > 1) + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread << "\n"; + + auto freeVarMasks = getFreeVariableMasks(op.getPtr().getType()); + Value threadPred = emitRedundantThreadPredicate(moduleOp, freeVarMasks, + rewriter, loc, targetInfo); + uint32_t regMask = freeVarMasks[str_attr("reg")]; + + auto vecTy = vec_ty(valueElemTy, vec); + SmallVector resultVals(elemsPerThread); + + for (size_t i = 0; i < elemsPerThread; i += vec) { + if (auto canonicalVecStart = getCanonicalIndex(i, regMask); + canonicalVecStart != i) { + // For redundant registers, refer back to the canonical result + for (auto iVec = 0; iVec < vec; ++iVec) { + resultVals[i + iVec] = resultVals[canonicalVecStart + iVec]; + } + continue; + } + + Value casVal = b.undef(vecTy); + for (int ii = 0; ii < vec; ++ii) { + Value iiVal = createIndexAttrConstant( + rewriter, loc, getTypeConverter()->getIndexType(), ii); + casVal = b.insert_element(vecTy, casVal, valElements[i + ii], iiVal); + } + + Value casPtr = ptrElements[i]; + Value casCmp = cmpElements[i]; + casVal = valElements[i]; + PTXBuilder ptxBuilderAtomicCAS; + std::string tyId = valueElemNBits * vec == 64 + ? "l" + : (valueElemNBits * vec == 32 ? "r" : "h"); + auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=" + tyId, /*init=*/true); + auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l"); + auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, tyId); + auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, tyId); + auto &atom = *ptxBuilderAtomicCAS.create("atom"); + auto sTy = "b" + std::to_string(valueElemNBits); + std::string semStr; + llvm::raw_string_ostream os(semStr); + os << op.getSem(); + auto scope = stringifyMemSyncScope(op.getScope()).str(); + atom.global().o(semStr).o(scope).o("cas").o(sTy); + atom(dstOpr, ptrOpr, cmpOpr, valOpr).maybePredicate(threadPred); + + if (tensorTy) { + auto retType = vec == 1 ? valueElemTy : vecTy; + auto ret = ptxBuilderAtomicCAS.launch(rewriter, loc, retType); + for (int ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = + vec == 1 ? ret + : b.extract_element(valueElemTy, ret, b.i32_val(ii)); + } + } else { + auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); + if (!atomicNeedsSharedMemory(op.getResult())) { + rewriter.eraseOp(op); + return success(); + } + Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + op.getOperation()); + atomPtr = b.bitcast(atomPtr, ptr_ty(ctx, 3)); + // Only threads with mask = True store the result + PTXBuilder ptxBuilderStore; + auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r"); + auto *valOprStore = ptxBuilderStore.newOperand(old, "r"); + auto &st = *ptxBuilderStore.create("st"); + st.shared().o(sTy); + st(dstOprStore, valOprStore).maybePredicate(threadPred); + auto ASMReturnTy = void_ty(ctx); + ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); + createBarrier(rewriter, op, numCTAs); + Value ret = b.load(valueElemTy, atomPtr); + rewriter.replaceOp(op, {ret}); + } + } + + if (tensorTy) { + Type structTy = getTypeConverter()->convertType(tensorTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, structTy); + rewriter.replaceOp(op, {resultStruct}); + } + return success(); + } +}; + +struct AtomicRMWOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + AtomicRMWOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + bool supportsVectorized(RMWOp opType, Type elementType) const { + // vectorized atomics are only supported on hopper, + // and only for specific atomic ops (add, min, max). + // Note that "packed types" like f16x2 are supported sm60+. + if (!targetInfo.supportVectorizedAtomics()) { + return false; + } + + return opType == RMWOp::FADD && + (elementType.isF16() || elementType.isBF16() || elementType.isF32()); + } + + bool isPromotableToNVPTXLD(triton::AtomicRMWOp op) const { + if (disableLDAcquireLowering) + return false; + + Type valueTy = + getTypeConverter()->convertType(getElementTypeOrSelf(op.getType())); + + if (!valueTy.isIntOrFloat()) + return false; + if (op.getSem() != triton::MemSemantic::ACQUIRE && + op.getSem() != triton::MemSemantic::RELAXED) + return false; + if (op.getScope() != triton::MemSyncScope::CTA && + op.getScope() != triton::MemSyncScope::GPU && + op.getScope() != triton::MemSyncScope::SYSTEM) + return false; + + if (op.getAtomicRmwOp() != RMWOp::ADD && op.getAtomicRmwOp() != RMWOp::FADD) + return false; + if (isa(op.getType())) + return false; + if (!op.getVal().getDefiningOp()) + return false; + if (!isa(op.getVal().getDefiningOp())) + return false; + + auto constOp = cast(op.getVal().getDefiningOp()); + if (!isa(constOp.getValueAttr()) && + !isa(constOp.getValueAttr())) + return false; + + if (auto attr = dyn_cast_or_null(constOp.getValueAttr())) + if (!attr.getValue().isZero()) + return false; + + if (auto attr = dyn_cast_or_null(constOp.getValueAttr())) + if (!attr.getValue().isZero()) + return false; + + return true; + } + + LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for AtomicRMWOp"); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + + auto atomicRmwAttr = op.getAtomicRmwOp(); + + Value val = op.getVal(); + Value ptr = op.getPtr(); + + Value llPtr = adaptor.getPtr(); + Value llVal = adaptor.getVal(); + Value llMask = adaptor.getMask(); + + auto valElements = unpackLLElements(loc, llVal, rewriter); + auto ptrElements = unpackLLElements(loc, llPtr, rewriter); + SmallVector maskElements; + if (llMask) + maskElements = unpackLLElements(loc, llMask, rewriter); + + auto valueTy = op.getType(); + auto tensorTy = dyn_cast(valueTy); + Type valueElemTy = + tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType()) + : valueTy; + const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); + auto elemsPerThread = getTotalElemsPerThread(val.getType()); + // packed: e.g. packed=2 for f16x2 + // vec: e.g. .v2, .v4, .v8 version of atom instruction. + unsigned vec, vecOrig; + int numElems, packed; + if (tensorTy) { + vec = getVectorSize(ptr); + if (llMask) { + vec = std::min(vec, getMaskAlignment(op.getMask())); + } + vecOrig = vec; + packed = 1; + auto valTy = cast(val.getType()); + if (!supportsVectorized(atomicRmwAttr, valTy.getElementType())) { + packed = + std::min(vecOrig, valTy.getElementType().isF16() ? 2 : 1); + vec = 1; + } + numElems = tensorTy.getNumElements(); + } else { + // scalar + vec = 1; + vecOrig = 1; + numElems = 1; + packed = 1; + } + assert((packed == 1 || vec == 1) && "packed or vec must be 1"); + + if (vec * packed == 1 && numElems > 1) + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " packed = " << packed << " origin vec = " << vecOrig + << " numElems = " << numElems; + + auto freeVarMasks = getFreeVariableMasks(ptr.getType()); + Value threadPred = emitRedundantThreadPredicate(moduleOp, freeVarMasks, + rewriter, loc, targetInfo); + uint32_t regMask = freeVarMasks[str_attr("reg")]; + + auto packedTy = vec_ty(valueElemTy, packed); + SmallVector resultVals(elemsPerThread); + + // Lower AtomicRMWOp to a ld.acquire if possible + std::unordered_map + ScopeMap = { + {triton::MemSyncScope::CTA, triton::nvgpu::MemSyncScope::CTA}, + {triton::MemSyncScope::GPU, triton::nvgpu::MemSyncScope::GPU}, + {triton::MemSyncScope::SYSTEM, + triton::nvgpu::MemSyncScope::SYSTEM}}; + const bool doPTXLDPromotion = isPromotableToNVPTXLD(op) && vec == 1 && + packed == 1 && ScopeMap.count(op.getScope()); + + for (size_t i = 0; i < elemsPerThread; i += vec * packed) { + if (auto canonicalStart = getCanonicalIndex(i, regMask); + canonicalStart != i) { + // For redundant registers, refer back to the canonical result + for (auto iVecPack = 0; iVecPack < vec * packed; ++iVecPack) { + resultVals[i + iVecPack] = resultVals[canonicalStart + iVecPack]; + } + continue; + } + + Value rmwPtr = ptrElements[i]; + Value pred = llMask ? maybeAnd(rewriter, loc, threadPred, maskElements[i]) + : threadPred; + + if (doPTXLDPromotion) { + Type covertedValueTy = + getTypeConverter()->convertType(getElementTypeOrSelf(op.getType())); + auto loadAcquireOp = rewriter.create( + op.getLoc(), covertedValueTy, rmwPtr, pred, + op.getSem() == triton::MemSemantic::ACQUIRE + ? triton::nvgpu::MemSemantic::ACQUIRE + : triton::nvgpu::MemSemantic::RELAXED, + ScopeMap[op.getScope()]); + + auto ASMReturnTy = void_ty(ctx); + if (!atomicNeedsSharedMemory(op.getResult())) { + rewriter.eraseOp(op); + return success(); + } + Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + op.getOperation()); + atomPtr = b.bitcast(atomPtr, ptr_ty(ctx, 3)); + // Only threads with rmwMask = True store the result + targetInfo.storeShared(rewriter, loc, atomPtr, loadAcquireOp, pred); + createBarrier(rewriter, op, numCTAs); + Value ret = b.load(valueElemTy, atomPtr); + rewriter.replaceOp(op, {ret}); + continue; + } + + std::string sTy; + PTXBuilder ptxBuilderAtomicRMW; + // 16-bit -> "h", 32-bit -> "r", 64-bit -> "l" + std::string tyId = + getRegisterSizeCode(valueElemNBits * packed, /*is_float=*/false); + + PTXBuilder::Operand *dstOpr; + if (vec > 1) { + dstOpr = ptxBuilderAtomicRMW.newListOperand(); + for (unsigned ii = 0; ii < vec; ++ii) { + dstOpr->listAppend( + ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true)); + } + } else { + dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); + } + + auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l"); + + PTXBuilder::Operand *valOpr; + if (vec > 1) { + valOpr = ptxBuilderAtomicRMW.newListOperand(); + for (unsigned ii = 0; ii < vec; ++ii) { + valOpr->listAppend( + ptxBuilderAtomicRMW.newOperand(valElements[i + ii], tyId)); + } + } else if (packed > 1) { + Value rmwVal = b.undef(packedTy); + for (int ii = 0; ii < packed; ++ii) { + rmwVal = b.insert_element(packedTy, rmwVal, valElements[i + ii], + b.i32_val(ii)); + } + valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); + } else { + valOpr = ptxBuilderAtomicRMW.newOperand(valElements[i], tyId); + } + + auto scope = stringifyMemSyncScope(op.getScope()).str(); + auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope); + auto rmwOp = stringifyRMWOp(atomicRmwAttr).str(); + auto sBits = std::to_string(valueElemNBits); + switch (atomicRmwAttr) { + case RMWOp::AND: + sTy = "b" + sBits; + break; + case RMWOp::OR: + sTy = "b" + sBits; + break; + case RMWOp::XOR: + sTy = "b" + sBits; + break; + case RMWOp::ADD: + sTy = "u" + sBits; + break; + case RMWOp::FADD: + rmwOp = "add"; + rmwOp += (valueElemNBits == 16 ? ".noftz" : ""); + sTy = "f" + sBits; + sTy += (packed == 2 && valueElemNBits == 16) ? "x2" : ""; + break; + case RMWOp::MAX: + sTy = "s" + sBits; + break; + case RMWOp::MIN: + sTy = "s" + sBits; + break; + case RMWOp::UMAX: + rmwOp = "max"; + sTy = "u" + sBits; + break; + case RMWOp::UMIN: + rmwOp = "min"; + sTy = "u" + sBits; + break; + case RMWOp::XCHG: + sTy = "b" + sBits; + break; + default: + return failure(); + } + std::string semStr; + llvm::raw_string_ostream os(semStr); + os << op.getSem(); + atom.o(semStr).o(rmwOp).v(vec).o(sTy); + if (tensorTy) { + atom(dstOpr, ptrOpr, valOpr).maybePredicate(pred); + Type retType; + if (vec > 1) { + SmallVector retTys(vec, valueElemTy); + retType = struct_ty(retTys); + } else if (packed > 1) { + retType = packedTy; + } else { + retType = valueElemTy; + } + + auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType); + + if (vec > 1) { + for (unsigned ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = b.extract_val(valueElemTy, ret, ii); + } + } else if (packed > 1) { + for (unsigned ii = 0; ii < packed; ++ii) { + resultVals[i + ii] = + b.extract_element(valueElemTy, ret, b.i32_val(ii)); + } + } else { + resultVals[i] = ret; + } + + } else { + auto ASMReturnTy = void_ty(ctx); + atom(dstOpr, ptrOpr, valOpr).maybePredicate(pred); + auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy); + if (!atomicNeedsSharedMemory(op.getResult())) { + rewriter.eraseOp(op); + return success(); + } + Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + op.getOperation()); + atomPtr = b.bitcast(atomPtr, ptr_ty(ctx, 3)); + // Only threads with rmwMask = True store the result + targetInfo.storeShared(rewriter, loc, atomPtr, old, pred); + createBarrier(rewriter, op, numCTAs); + Value ret = b.load(valueElemTy, atomPtr); + rewriter.replaceOp(op, {ret}); + } + } + if (tensorTy) { + Type structTy = getTypeConverter()->convertType(tensorTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, structTy); + rewriter.replaceOp(op, {resultStruct}); + } + return success(); + } +}; + +struct AsyncCopyGlobalToLocalOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + AsyncCopyGlobalToLocalOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ctx = getContext(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value res = op.getResult(); + Value mask = op.getMask(); + Value other = op.getOther(); + auto funcOp = op->getParentOfType(); + + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getResult().getType(); + auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType()); + auto srcLayout = srcTy.getEncoding(); + + Value llDst = adaptor.getResult(); + Value llSrc = adaptor.getSrc(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); + + // %src + auto srcElems = unpackLLElements(loc, llSrc, rewriter); + + // %dst + auto smemObj = + getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter); + // %mask + SmallVector maskElems; + if (llMask) { + maskElems = unpackLLElements(loc, llMask, rewriter); + assert(srcElems.size() == maskElems.size()); + } + + // %other + SmallVector otherElems; + if (llOther) { + // FIXME(Keren): assume other is 0 for now. + // + // It's not necessary for now because the pipeline pass will skip + // generating insert_slice_async if the load op has any "other" tensor. + otherElems = unpackLLElements(loc, llOther, rewriter); + assert(srcElems.size() == otherElems.size()); + } + + // We can load N elements at a time if: + // 1. Every group of N source pointers are contiguous. For example, if + // N=2, then the pointers should be [x, x+1, y, y+1, ...]. + // 2. The mask (if present) has "alignment" N, meaning that each group of N + // mask bits are the same. For example if N=2, the mask must be + // [x, x, y, y, ...]. + unsigned maxVec = getContiguity(op.getSrc()); + if (mask) { + maxVec = std::min(maxVec, getMaskAlignment(mask)); + } + + // Addresses to store into, one per `vecTy`. + VectorType vecTy; + SmallVector shmemAddrs; + bool ok = emitTransferBetweenRegistersAndShared( + srcTy, dstTy, resElemTy, maxVec, smemObj, loc, rewriter, targetInfo, + [&](VectorType vecTy_, Value shmemAddr) { + vecTy = vecTy_; + shmemAddrs.push_back(shmemAddr); + }); + assert(ok); + + int vecBytes = vecTy.getNumElements() * vecTy.getElementTypeBitWidth() / 8; + assert(llvm::isPowerOf2_32(vecBytes)); + if (vecBytes < 4) { + return emitError(loc, "cp.async does not support transfers smaller than " + "4 bytes; calculated this as ") + << vecBytes << " bytes"; + } + + auto moduleOp = op->getParentOfType(); + auto freeVarMasks = getFreeVariableMasks(srcTy); + // NOTE(@peterbell10): We load redundant data on different CTAs, so the data + // is available in each CTAs respective shared memory. Otherwise, we would + // need an additional broadcast step to copy the data between CTAs. + freeVarMasks[str_attr("block")] = 0; + Value threadPred = emitRedundantThreadPredicate(moduleOp, freeVarMasks, + rewriter, loc, targetInfo); + uint32_t regMask = freeVarMasks[str_attr("reg")]; + + for (int i = 0; i < shmemAddrs.size(); i++) { + // It's possible that vecTy is larger than 128 bits, in which case we have + // to use multiple cp.async instructions. + int wordBytes = std::min(vecBytes, 16); + int wordElems = wordBytes * 8 / vecTy.getElementTypeBitWidth(); + int numWordsInVec = std::max(1, vecBytes / wordBytes); + for (int j = 0; j < numWordsInVec; j++) { + int elemIdx = i * vecTy.getNumElements() + j * wordElems; + + if (!isCanonicalIndex(elemIdx, regMask)) { + continue; // Skip redundant registers + } + + // Tune CG and CA. + CacheModifier srcCacheModifier = + wordBytes == 16 ? CacheModifier::CG : CacheModifier::CA; + assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4); + + PTXBuilder ptxBuilder; + auto ©AsyncOp = + *ptxBuilder.create(srcCacheModifier); + auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddrs[i], "r", + /*offset=*/j * wordBytes); + auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[elemIdx], "l"); + auto *copySize = ptxBuilder.newConstantOperand(wordBytes); + auto *srcSize = copySize; + if (op.getMask()) { + // We don't use predicate in this case, setting src-size to 0 + // if there's any mask. cp.async will automatically fill the + // remaining slots with 0 if cp-size > src-size. + // XXX(Keren): Always assume other = 0 for now. + // When 'other != 0' is supported, we will need to fold the + // op.getMask() and redundantDataMask() into the same predicate, the + // way it is done for LoadOp. + auto selectOp = + b.select(maskElems[elemIdx], b.i32_val(wordBytes), b.i32_val(0)); + srcSize = ptxBuilder.newOperand(selectOp, "r"); + } + + copyAsyncOp(dstOperand, srcOperand, copySize, srcSize) + .maybePredicate(threadPred); + ptxBuilder.launch(rewriter, loc, void_ty(getContext())); + } + } + + // Drop the result token. + Value zero = rewriter.create( + op.getLoc(), IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); + rewriter.replaceOp(op, zero); + return success(); + } +}; + +struct AsyncTMACopyGlobalToLocalOpConversion + : public ConvertOpToLLVMPattern< + triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getCache() != triton::CacheModifier::NONE) + return op.emitError("cache modifiers not supported yet"); + if (op.getEvict() != triton::EvictionPolicy::NORMAL) + return op.emitError("eviction policy not supported yet"); + if (op.getIsVolatile()) + return op.emitError("volatile not supported yet"); + + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Type llvmElemTy = + typeConverter->convertType(op.getResult().getType().getElementType()); + auto barrierMemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getBarrier(), + typeConverter->convertType(op.getBarrier().getType().getElementType()), + rewriter); + auto dstMemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getResult(), llvmElemTy, rewriter); + auto voidTy = void_ty(op->getContext()); + auto id = getThreadId(rewriter, loc); + + auto mod = op->getParentOfType(); + int numWarps = ttg::lookupNumWarps(op); + int warpSize = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpID = rewriter.create(loc); + Value pred = adaptor.getPred(); + // Select just one thread for the TMA copy. This also helps the compiler to + // figure out that the op is uniform. + pred = b.and_(pred, LLVM::NVIDIA::createElectPredicate(loc, rewriter)); + + int elementSizeInBytes = + op.getResult().getType().getElementType().getIntOrFloatBitWidth() / 8; + int totalNumElements = product(op.getResult().getType().getShape()); + int64_t size = totalNumElements * elementSizeInBytes; + + int innerBlockSize = op.getResult().getType().getShape().back(); + int contigDimSizeInByte = innerBlockSize * elementSizeInBytes; + int numCopies = 1; + int rank = op.getCoord().size(); + if (rank > 1) + numCopies = ceil(contigDimSizeInByte, 128); + + auto asyncTaskIds = getAsyncTaskIds(op); + int firstThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + firstThreadId = asyncTaskIds[0] * numWarps * warpSize; + } + + // The bounding box inner dimension must be less than or equal to the + // swizzle size. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + // We clamp the block size and the codegen will emit multiple copy + // operations. + for (int copyIdx = 0; copyIdx < numCopies; copyIdx += numWarps) { + int numWarpsToCopy = std::min(numCopies - copyIdx, numWarps); + if (numWarpsToCopy == 1) + warpID = b.i32_val(0); + Value boxPred = b.and_( + pred, + b.icmp_ult(id, b.i32_val(numWarpsToCopy * warpSize + firstThreadId))); + ::mlir::triton::PTXBuilder ptxBuilderTMA; + Type elemPtrTy = ptr_ty(rewriter.getContext(), 3); + Value copyIdxVal = b.add(warpID, b.i32_val(copyIdx)); + Value shMemOffset = + b.mul(copyIdxVal, b.i32_val(totalNumElements / numCopies)); + Value shMemPtr = + b.gep(elemPtrTy, llvmElemTy, dstMemObj.getBase(), shMemOffset); + SmallVector operands = { + ptxBuilderTMA.newOperand(boxPred, "b"), + ptxBuilderTMA.newOperand(shMemPtr, "r"), + ptxBuilderTMA.newOperand(adaptor.getDescPtr(), "l")}; + std::string tmaInst = + "@$0 cp.async.bulk.tensor." + std::to_string(rank) + + "d.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {"; + int operandIdx = 3; + for (int i = 0; i < rank; i++) { + Value coord = adaptor.getCoord()[rank - i - 1]; + if (i == 0) { + Value offset = b.mul(copyIdxVal, b.i32_val(128 / elementSizeInBytes)); + coord = b.add(coord, offset); + } + operands.push_back(ptxBuilderTMA.newOperand(coord, "r")); + tmaInst += "$" + std::to_string(operandIdx++); + if (i != rank - 1) + tmaInst += ", "; + } + operands.push_back( + ptxBuilderTMA.newOperand(barrierMemObj.getBase(), "r")); + tmaInst += "}], [$" + std::to_string(operandIdx++) + "];"; + + auto &tma = *ptxBuilderTMA.create<>(tmaInst); + tma(operands, /*onlyAttachMLIRArgs=*/true); + ptxBuilderTMA.launch(rewriter, loc, voidTy); + } + rewriter.eraseOp(op); + return success(); + } +}; + +int getWarpOffset(Operation *op) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() > 0) { + return 4 * *std::min_element(asyncTaskIds.begin(), asyncTaskIds.end()); + } + return 0; +} + +struct AsyncTMACopyLocalToGlobalOpConversion + : public ConvertOpToLLVMPattern< + triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Type llvmElemTy = + typeConverter->convertType(op.getSrc().getType().getElementType()); + auto dstMemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getSrc(), llvmElemTy, rewriter); + auto voidTy = void_ty(op->getContext()); + auto id = getThreadId(rewriter, loc); + // Select just one thread for the TMA copy. This also helps the compiler to + // figure out that the op is uniform. + Value pred = LLVM::NVIDIA::createElectPredicate(loc, rewriter); + int elementSizeInBytes = + op.getSrc().getType().getElementType().getIntOrFloatBitWidth() / 8; + int totalNumElements = product(op.getSrc().getType().getShape()); + int64_t size = totalNumElements * elementSizeInBytes; + + auto mod = op->getParentOfType(); + int numWarps = ttg::lookupNumWarps(op); + int warpSize = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpID = rewriter.create(loc); + int innerBlockSize = op.getSrc().getType().getShape().back(); + int contigDimSizeInByte = innerBlockSize * elementSizeInBytes; + int numCopies = 1; + int rank = op.getCoord().size(); + if (rank > 1) + numCopies = ceil(contigDimSizeInByte, 128); + + // The bounding box inner dimension must be less than or equal to the + // swizzle size. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + // We clamp the block size and the codegen will emit multiple copy + // operations. + for (int copyIdx = 0; copyIdx < numCopies; copyIdx += numWarps) { + int numWarpsToCopy = std::min(numCopies - copyIdx, numWarps); + if (numWarpsToCopy == 1) + warpID = b.i32_val(0); + auto warpOffset = getWarpOffset(op); + warpID = b.sub(warpID, b.i32_val(warpOffset)); + id = b.sub(id, b.i32_val(warpOffset * warpSize)); + Value boxPred = + b.and_(pred, b.icmp_ult(id, b.i32_val(numWarpsToCopy * warpSize))); + ::mlir::triton::PTXBuilder ptxBuilderTMA; + Type elemPtrTy = ptr_ty(rewriter.getContext(), 3); + Value copyIdxVal = b.add(warpID, b.i32_val(copyIdx)); + Value shMemOffset = + b.mul(copyIdxVal, b.i32_val(totalNumElements / numCopies)); + Value shMemPtr = + b.gep(elemPtrTy, llvmElemTy, dstMemObj.getBase(), shMemOffset); + SmallVector operands = { + ptxBuilderTMA.newOperand(boxPred, "b"), + ptxBuilderTMA.newOperand(adaptor.getDescPtr(), "l")}; + std::string tmaInst = "@$0 cp.async.bulk.tensor." + std::to_string(rank) + + "d.global.shared::cta.bulk_group [$1, {"; + int operandIdx = 2; + for (int i = 0; i < rank; i++) { + Value coord = adaptor.getCoord()[rank - i - 1]; + if (i == 0) { + Value offset = b.mul(copyIdxVal, b.i32_val(128 / elementSizeInBytes)); + coord = b.add(coord, offset); + } + operands.push_back(ptxBuilderTMA.newOperand(coord, "r")); + tmaInst += "$" + std::to_string(operandIdx++); + if (i != rank - 1) + tmaInst += ", "; + } + operands.push_back(ptxBuilderTMA.newOperand(shMemPtr, "r")); + tmaInst += "}], [$" + std::to_string(operandIdx++) + "];"; + auto &tma = *ptxBuilderTMA.create<>(tmaInst); + tma(operands, /*onlyAttachMLIRArgs=*/true); + ptxBuilderTMA.launch(rewriter, loc, voidTy); + } + + // TODO: Separate the syncronizations operations into separate TTGIR ops to + // be able to schedule them at the high level. + rewriter.create(loc); + + rewriter.eraseOp(op); + return success(); + } +}; + +static LinearLayout getUnswizzledLayout(triton::gpu::MemDescType type) { + return triton::gpu::sharedToLinearLayoutLeadingOffset( + type.getShape(), cast(type.getEncoding()), + /*disableSwizzle=*/true); +} + +// This function is shared between the TMA gather and scatter lowerings. It +// handles the logic for iterating over the x offset values in groups of 4 +// consecutive indices and mapping them to the appropriate shared memory offset. +// +// This invokes a callback with the predicate, shared memory offset, y offset, +// and x offsets. +static LogicalResult iterateGatherScatterIndices( + Operation *op, ConversionPatternRewriter &rewriter, + const TypeConverter &typeConverter, + mlir::TypedValue xCoords, + mlir::TypedValue smem, Value smemObjValue, + Value xOffsetsValue, Value yOffsetValue, Value pred, + function_ref)> callback) { + MLIRContext *ctx = op->getContext(); + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + StringAttr kDim0 = str_attr("dim0"); + StringAttr kDim1 = str_attr("dim1"); + StringAttr kMsg = str_attr("msg"); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + + // Each warp can issue a distinct `gather4` instruction that loads 4 rows into + // consecutive shared memory. Thus, the layout of the x offsets must be such + // that 4 consecutive elements are broadcasted to a warp. + RankedTensorType xCoordsTy = xCoords.getType(); + LinearLayout xCoordsLayout = triton::gpu::toLinearLayout( + xCoordsTy.getShape(), xCoordsTy.getEncoding()); + if (xCoordsLayout.getInDimSize(kRegister) < 4) + return op->emitError("must have at least 4 x offsets per warp"); + // Check that the first two bases are [1] and [2]. + for (unsigned i : {0, 1}) { + if (xCoordsLayout.getBasis(kRegister, i).front() != (1 << i)) + return op->emitError( + "x offsets are not grouped by 4 contiguous elements"); + } + + // TMA expects the memdesc shape to match the alloc shape. + triton::gpu::MemDescType smemType = smem.getType(); + ArrayRef allocShape = smemType.getAllocShape(); + if (allocShape.size() < 2 || smemType.getShape() != allocShape.take_back(2)) + return op->emitError("memdesc shape must match alloc shape"); + // `NVMMASharedEncodingAttr` means the core matrix tiles are placed next to + // each other in shared memory, which lines up with how `gather4` loads data. + if (!isa(smemType.getEncoding())) + return op->emitError("requires dst encoding NVMMASharedEncodingAttr"); + Type llvmElemTy = typeConverter.convertType(smemType.getElementType()); + Type elemPtrTy = ptr_ty(ctx, /*addrspace=*/3); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, smemObjValue, + llvmElemTy, rewriter); + + unsigned threadsPerWarp = xCoordsLayout.getInDimSize(kLane); + unsigned numWarps = xCoordsLayout.getInDimSize(kWarp); + + // Each gather4 instructions reads 128 bytes for 4 rows at a time. + unsigned innerBlockSize = smemType.getShape().back(); + unsigned contigDimSizeInBytes = + innerBlockSize * ceil(smemType.getElementTypeBitWidth(), 8); + unsigned numMessagesPerRow = ceil(contigDimSizeInBytes, 128); + + // `xCoordsLayout` maps the register ID into dim0. Tile dim1 by adding a new + // dimension representing the TMA message ID. + assert(innerBlockSize % numMessagesPerRow == 0); + assert(llvm::isPowerOf2_32(numMessagesPerRow)); + unsigned msgSize = innerBlockSize / numMessagesPerRow; + std::vector> msgBases; + for (unsigned msgId = 1; msgId < numMessagesPerRow; msgId *= 2) + msgBases.push_back({int32_t(msgId * msgSize)}); + LinearLayout msgToCol({{{kMsg, std::move(msgBases)}}}, + {{kDim1, innerBlockSize}}, + /*requiresSurjective=*/false); + LinearLayout msgLayout = xCoordsLayout * msgToCol; + + // `gather4` will put the 128-byte segments of the 4 rows consecutively in + // shared memory. However, if the 4 rows are smaller than the shared memory + // swizzle tile size, e.g. [4, 32] vs. [8, 32], then, for example, the address + // of the 0th element of row 4 will not be at the start of the segment. + LinearLayout sharedLayout = getUnswizzledLayout(smemType); + LinearLayout msgToShared = msgLayout.invertAndCompose(sharedLayout); + + // If there are too few rows, warps will have redundant data. An individual + // thread might also have redundant indices if there is register broadcasting. + auto freeVars = xCoordsLayout.getFreeVariableMasks(); + unsigned regMask = freeVars[kRegister]; + unsigned warpMask = freeVars[kWarp]; + if (freeVars[kLane] != (threadsPerWarp - 1)) + return op->emitError("x offsets must be broadcasted across each warp"); + + Value warpId = rewriter.create(loc); + // Each block has separate shared memory. Multiple CTAs don't work anyways. + Value blockId = b.i32_val(0); + + // Mask out warps with redundant x offsets. + pred = b.and_(pred, + b.icmp_eq(b.i32_val(0), b.and_(warpId, b.i32_val(warpMask)))); + // Select one thread in each warp to issue the gather4 messages. + pred = b.and_(pred, LLVM::NVIDIA::createElectPredicate(loc, rewriter)); + + SmallVector xOffsets = unpackLLElements(loc, xOffsetsValue, rewriter); + // Lane ID doesn't matter. + Value laneId = b.i32_val(0); + for (auto regId : seq(0, xOffsets.size(), 4)) { + // Skip redundant x offsets within a thread. + if ((regMask & regId) != 0) + continue; + Value regIdVal = b.i32_val(regId); + + for (auto msgId : llvm::seq(numMessagesPerRow)) { + Value msgIdVal = b.i32_val(msgId); + + auto result = applyLinearLayout(loc, rewriter, msgToShared, + {{kMsg, msgIdVal}, + {kRegister, regIdVal}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}}); + assert(result.size() == 2 && result.front().first == "offset" && + result.back().first == "block"); + Value shMemOffset = result.front().second; + // Because we checked that the memdesc's allocshape and shape match, we + // can ignore the strides and directly index into the shmem object. + Value shMemPtr = + b.gep(elemPtrTy, llvmElemTy, smemObj.getBase(), shMemOffset); + Value yOffset = b.add(yOffsetValue, b.i32_val(msgId * msgSize)); + + callback(pred, shMemPtr, yOffset, ArrayRef(xOffsets).slice(regId, 4)); + }; + } + + return success(); +} + +struct AsyncTMAGatherOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::AsyncTMAGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +LogicalResult AsyncTMAGatherOpConversion::matchAndRewrite( + triton::nvidia_gpu::AsyncTMAGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + + LLVM::LLVMVoidType voidTy = void_ty(op->getContext()); + auto barrierMemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getBarrier(), + typeConverter->convertType(op.getBarrier().getType().getElementType()), + rewriter); + + // Callback to generate the gather4 instruction. + auto callback = [&](Value pred, Value shMemPtr, Value yOffset, + ArrayRef xOffsets) { + std::string tmaInst = "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared" + "::cluster.global.mbarrier::complete_tx::bytes " + "[$1], [$2, {$3, $4, $5, $6, $7}], [$8];"; + + PTXBuilder ptxBuilder; + SmallVector operands{ + // clang-format off + ptxBuilder.newOperand(pred, "b"), + ptxBuilder.newOperand(shMemPtr, "r"), + ptxBuilder.newOperand(adaptor.getDescPtr(), "l"), + ptxBuilder.newOperand(yOffset, "r") + // clang-format on + }; + for (Value xOffset : xOffsets) + operands.push_back(ptxBuilder.newOperand(xOffset, "r")); + operands.push_back(ptxBuilder.newOperand(barrierMemObj.getBase(), "r")); + + auto &tma = *ptxBuilder.create<>(tmaInst); + tma(operands, /*attachOnlyMLIRArgs=*/true); + ptxBuilder.launch(rewriter, loc, voidTy); + }; + + if (failed(iterateGatherScatterIndices( + op, rewriter, *getTypeConverter(), op.getXOffsets(), op.getResult(), + adaptor.getResult(), adaptor.getXOffsets(), adaptor.getYOffset(), + adaptor.getPred(), callback))) + return failure(); + + rewriter.eraseOp(op); + return success(); +} + +struct AsyncTMAScatterOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::AsyncTMAScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +LogicalResult AsyncTMAScatterOpConversion::matchAndRewrite( + triton::nvidia_gpu::AsyncTMAScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = getContext(); + LLVM::LLVMVoidType voidTy = void_ty(op->getContext()); + + // Callback to generate the scatter4 instruction. + auto callback = [&](Value pred, Value shMemPtr, Value yOffset, + ArrayRef xOffsets) { + std::string tmaInst = "@$0 cp.async.bulk.tensor.2d.tile::scatter4.global" + ".shared::cta.bulk_group " + "[$1, {$2, $3, $4, $5, $6}], [$7];"; + + PTXBuilder ptxBuilder; + SmallVector operands{ + // clang-format off + ptxBuilder.newOperand(pred, "b"), + ptxBuilder.newOperand(adaptor.getDescPtr(), "l"), + ptxBuilder.newOperand(yOffset, "r") + // clang-format on + }; + for (Value xOffset : xOffsets) + operands.push_back(ptxBuilder.newOperand(xOffset, "r")); + operands.push_back(ptxBuilder.newOperand(shMemPtr, "r")); + + auto &tma = *ptxBuilder.create<>(tmaInst); + tma(operands, /*attachOnlyMLIRArgs=*/true); + ptxBuilder.launch(rewriter, loc, voidTy); + }; + + if (failed(iterateGatherScatterIndices( + op, rewriter, *getTypeConverter(), op.getXOffsets(), op.getSrc(), + adaptor.getSrc(), adaptor.getXOffsets(), adaptor.getYOffset(), + /*pred=*/b.true_val(), callback))) + return failure(); + + // TODO: Separate the syncronizations operations into separate TTGIR ops to + // be able to schedule them at the high level. + rewriter.create(loc); + + rewriter.eraseOp(op); + return success(); +} + +struct AsyncWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::AsyncWaitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto num = op->getAttrOfType("num"); + rewriter.create(loc, num); + + // Drop the result token. + TritonLLVMOpBuilder b(loc, rewriter); + rewriter.replaceOp(op, b.i32_val(0)); + return success(); + } +}; + +struct AsyncCommitGroupOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::AsyncCommitGroupOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::AsyncCommitGroupOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + rewriter.create(loc); + + // Drop the result token. + TritonLLVMOpBuilder b(loc, rewriter); + rewriter.replaceOp(op, b.i32_val(0)); + return success(); + } +}; + +struct TMAStoreWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::TMAStoreWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ctx = op.getContext(); + auto isRead = UnitAttr::get(ctx); + rewriter.replaceOpWithNewOp( + op, op.getPendingsAttr(), isRead); + return success(); + } +}; + +} // namespace + +void mlir::triton::NVIDIA::populateLoadStoreOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) { + patterns.add( + typeConverter, targetInfo, axisInfoAnalysis, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns + .add(typeConverter, + benefit); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 000000000..30ddb94d1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,285 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" +#include "Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +struct LocalLoadOpConversion + : public ConvertOpToLLVMPattern { +public: + LocalLoadOpConversion(const LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemDescType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (isa(dstLayout) && + isa( + cast(dstLayout).getParent())) { + auto dotEnc = cast(dstLayout); + auto mmaEnc = cast(dotEnc.getParent()); + auto sharedEnc = dyn_cast(srcLayout); + if (!sharedEnc) + return failure(); + auto bitwidth = dstTy.getElementTypeBitWidth(); + auto vecWidth = 32 / bitwidth; + auto kWidth = dotEnc.getKWidth(); + auto rank = dstTy.getRank(); + auto kOrder = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2; + auto nonKOrder = dotEnc.getOpIdx() == 0 ? rank - 2 : rank - 1; + auto needTrans = kOrder != sharedEnc.getOrder()[0]; + // Limitation 1 [TODO: remove]: Check LL bases to verify register and + // address alignment + auto canUseLdmatrix = (kWidth == vecWidth); + canUseLdmatrix &= (sharedEnc.getMaxPhase() == 1) || + (sharedEnc.getVec() * bitwidth >= 8 * 16); + auto shape = srcTy.getShape(); + // Limitation 2 [TODO: remove]: Only support 2d matrices now but we should + // be able to support 3D minor changes + canUseLdmatrix &= (bitwidth == 16 || !needTrans) && shape.size() <= 2; + // Limitation 3: Minimum tile size (8)x(8x16bits) + canUseLdmatrix &= + shape[kOrder] >= (8 * 16 / bitwidth) && shape[nonKOrder] >= 8; + if (canUseLdmatrix) { + return lowerSharedToDotOperand(op, adaptor, getTypeConverter(), + rewriter); + } + } + return failure(); + } + +private: + LogicalResult + lowerSharedToDotOperand(triton::gpu::LocalLoadOp op, + triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto ctx = rewriter.getContext(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto dstTy = cast(op.getType()); + auto srcTy = cast(op.getSrc().getType()); + auto dotEnc = cast(dstTy.getEncoding()); + auto sharedEnc = cast(srcTy.getEncoding()); + auto shape = dstTy.getShape(); + auto rank = dstTy.getRank(); + auto kOrder = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2; + auto nonKOrder = dotEnc.getOpIdx() == 0 ? rank - 2 : rank - 1; + auto needTrans = kOrder != sharedEnc.getOrder()[0]; + + auto llvmElemTy = typeConverter->convertType(dstTy.getElementType()); + auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + auto ldmatrixLayout = + chooseLdMatrixLayout(dotEnc, shape, needTrans, bitwidth); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + // Emit ldmatrix load operations for values packed in i32s + SmallVector elemsI32; + // Typically we load 32x8 to use ldmatrix.x4, but the minimum tile size for + // opIdx=1 is 16x8. Therefore, we use ldmatrix.x2 instead of + // ldmatrix.x4 in this case. + auto shift = dotEnc.getOpIdx() == 1 && shape[kOrder] < (32 * 16 / bitwidth); + auto maxVecElems = 8 * 16 / bitwidth; + bool valid = emitTransferBetweenRegistersAndShared( + ldmatrixLayout, srcTy, llvmElemTy, + /*maxVecElems=*/maxVecElems, smemObj, loc, rewriter, targetInfo, + [&](VectorType vecTy, Value vecAddr) { + auto numElems = vecTy.getNumElements(); + auto numElemsI32 = (numElems * bitwidth / 32) >> shift; + auto matTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(numElemsI32, i32_ty)); + auto ldMatrixOp = rewriter.create( + loc, matTy, vecAddr, /*needTrans=*/needTrans); + auto res = ldMatrixOp.getResult(); + for (auto i = 0; i < numElemsI32; ++i) { + elemsI32.push_back(b.extract_val(i32_ty, res, i)); + } + }); + assert(valid && "Failed to emit ldmatrix load operations"); + + // Unpack i32 values to the original type + SmallVector elems; + auto numElemsPerVec = 32 / bitwidth; + auto vecTy = vec_ty(llvmElemTy, numElemsPerVec); + for (int v = 0; v < static_cast(elemsI32.size()); ++v) { + auto vec = b.bitcast(elemsI32[v], vecTy); + for (int i = 0; i < numElemsPerVec; ++i) + elems.push_back(b.extract_element(llvmElemTy, vec, b.i32_val(i))); + } + + auto structTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(elems.size(), llvmElemTy)); + auto ret = packLLElements(loc, typeConverter, elems, rewriter, structTy); + rewriter.replaceOp(op, ret); + return success(); + } + +private: + const NVIDIA::TargetInfo &targetInfo; +}; + +LogicalResult lowerDistributedToSharedStmatrix( + Location loc, TypedValue src, MemDescType memDescType, + Value adaptorSrc, Value smemBase, const TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, + std::pair *const llvmOpCount = nullptr) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto mmaEncoding = + dyn_cast(src.getType().getEncoding()); + if (!mmaEncoding) + return failure(); + auto sharedLayout = + dyn_cast(memDescType.getEncoding()); + if (!sharedLayout) + return failure(); + int swizzleByteSize = sharedLayout.getSwizzlingByteWidth(); + + RankedTensorType srcTy = src.getType(); + SmallVector shape = + convertType(srcTy.getShape()); + SmallVector order = sharedLayout.getTransposed() + ? SmallVector({0, 1}) + : SmallVector({1, 0}); + if (!targetInfo.canUseStMatrix(srcTy, shape, shape, order, swizzleByteSize)) { + return failure(); + } + + auto *ctx = rewriter.getContext(); + + auto layout = + chooseStMatrixLayout(rewriter.getContext(), srcTy, swizzleByteSize); + auto llvmElemTy = typeConverter->convertType(memDescType.getElementType()); + auto smemPtrTy = ptr_ty(ctx, 3); + + auto kRegister = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + + auto regBase = applyLinearLayout(loc, rewriter, layout, + {{kRegister, b.i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, b.i32_val(0)}})[0] + .second; + auto srcVals = unpackLLElements(loc, adaptorSrc, rewriter); + auto srcVec = layout.getNumConsecutiveInOut(); + for (int i = 0; i < srcVals.size(); i += srcVec) { + auto regIdx = + layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] + .second; + Value offset = b.xor_(regBase, b.i32_val(regIdx)); + auto vecAddr = b.gep(smemPtrTy, llvmElemTy, smemBase, offset); + vecAddr.setInbounds(true); + SmallVector inValsVec; + for (int j = 0; j < srcVec; j++) + inValsVec.push_back(srcVals[i + j]); + Value valsVec = packLLVector(loc, inValsVec, rewriter); + targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); + } + return success(); +} + +struct LocalAllocOpConversion + : public ConvertOpToLLVMPattern { + LocalAllocOpConversion(const LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getSrc()) + return failure(); + MemDescType memDescType = op.getType(); + RankedTensorType srcTy = op.getSrc().getType(); + Type llvmElemTy = typeConverter->convertType(srcTy.getElementType()); + Value smemBase = + LLVM::getSharedMemoryBase(op.getLoc(), rewriter, targetInfo, op); + + if (lowerDistributedToSharedStmatrix(op.getLoc(), op.getSrc(), memDescType, + adaptor.getSrc(), smemBase, + typeConverter, rewriter, targetInfo) + .failed()) { + return failure(); + } + + auto resultTy = cast(op.getType()); + auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, resultTy.getRank(), + op.getLoc(), rewriter); + auto retVal = + getStructFromSharedMemoryObject(op.getLoc(), smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } + +private: + const NVIDIA::TargetInfo &targetInfo; +}; + +struct LocalStoreOpConversion + : public ConvertOpToLLVMPattern { + LocalStoreOpConversion(const LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type llvmElemTy = + getTypeConverter()->convertType(op.getDst().getType().getElementType()); + SharedMemoryObject smemObj = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + MemDescType memDescType = op.getDst().getType(); + if (lowerDistributedToSharedStmatrix( + op.getLoc(), op.getSrc(), memDescType, adaptor.getSrc(), + smemObj.getBase(), getTypeConverter(), rewriter, targetInfo) + .failed()) { + return failure(); + } + rewriter.eraseOp(op); + return success(); + } + +private: + const NVIDIA::TargetInfo &targetInfo; +}; +} // namespace + +void mlir::triton::NVIDIA::populateMemoryOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + // Backend optimized memory ops get higher benefit + patterns.add(typeConverter, targetInfo, + benefit.getBenefit() + 1); + patterns.add(typeConverter, targetInfo, + benefit.getBenefit() + 1); + patterns.add(typeConverter, targetInfo, + benefit.getBenefit() + 1); + mlir::triton::populateMemoryOpToLLVMPatterns(typeConverter, targetInfo, + patterns, benefit); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PTXAsmFormat.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PTXAsmFormat.cpp new file mode 100644 index 000000000..2f4f03007 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PTXAsmFormat.cpp @@ -0,0 +1,236 @@ +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/AsmFormat.h" +#include "llvm/Support/raw_ostream.h" +// TODO(Superjomn): unify to llvm::raw_string_ostream +#include + +namespace mlir { +namespace triton { + +PTXInstr::Operand * +PTXBuilder::newOperand(mlir::Value value, StringRef constraint, + std::function formatter) { + argArchive.emplace_back(std::make_unique(value, constraint)); + auto *opr = argArchive.back().get(); + opr->repr = formatter; + opr->idx = oprCounter++; + return opr; +} + +void PTXBuilder::initOperand(Operand *opr) { + auto numBits = 0; + // Derive numBits from the constraint. + if (opr->constraint[1] == 'c' || opr->constraint[1] == 'h') + numBits = 16; + else if (opr->constraint[1] == 'r') + numBits = 32; + else if (opr->constraint[1] == 'l') + numBits = 64; + else + llvm_unreachable(("Unknown constraint: " + opr->constraint).c_str()); + // If numBits is less than 16, we use 16 as default because PTX does not + // support 8-bit mov. + numBits = numBits < 16 ? 16 : numBits; + auto *zero = newConstantOperand(0); + auto &init = create<>("mov")->o("u" + std::to_string(numBits)); + init(opr, zero); +} + +PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint, bool init) { + // Constraint should be something like "=r" + assert(constraint.size() == 2 && constraint[0] == '='); + auto *opr = newOperand(); + opr->idx = oprCounter++; + opr->constraint = constraint; + if (init) { + initOperand(opr); + } + return opr; +} + +PTXBuilder::Operand *PTXBuilder::newOperand(unsigned operandIndex) { + assert(operandIndex < oprCounter && "operand index out of range"); + auto *opr = newOperand(); + opr->idx = oprCounter++; + opr->constraint = std::to_string(operandIndex); + return opr; +} + +PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) { + argArchive.emplace_back(std::make_unique()); + argArchive.back()->repr = [v](int idx) { return v; }; + return argArchive.back().get(); +} + +PTXBuilder::Operand *PTXBuilder::newConstantOperand(int64_t v) { + std::stringstream ss; + ss << "0x" << std::hex << v; + return newConstantOperand(ss.str()); +} + +std::string PTXBuilder::getConstraints() const { + auto args = getAllArgs(); + llvm::SmallVector argReprs; + for (auto arg : args) + argReprs.push_back(arg->constraint); + return strJoin(argReprs, ","); +} + +llvm::SmallVector PTXBuilder::getAllMLIRArgs() const { + llvm::SmallVector res; + for (auto &arg : argArchive) { + if (!arg->isList() && arg->value) + res.push_back(arg->value); + } + return res; +} + +SmallVector PTXBuilder::getAllArgs() const { + llvm::SmallVector res; + for (auto &x : argArchive) + if (!x->isList()) + res.push_back(x.get()); + return res; +} + +mlir::Value PTXBuilder::launch(OpBuilder &rewriter, Location loc, Type resTy, + bool hasSideEffect, bool isAlignStack, + ArrayRef attrs) const { + auto *ctx = rewriter.getContext(); + auto inlineAsm = rewriter.create( + loc, resTy, getAllMLIRArgs(), // operands + dump(), // asm_string + getConstraints(), // constraints + hasSideEffect, // has_side_effects + isAlignStack, // is_align_stack + LLVM::AsmDialectAttr::get(ctx, + LLVM::AsmDialect::AD_ATT), // asm_dialect + ArrayAttr::get(ctx, attrs) // operand_attrs + ); + + return inlineAsm.getRes(); +} + +std::string PTXInstr::Operand::dump() const { + if (repr) + return repr(idx); + if (!isList()) + return "$" + std::to_string(idx); + + llvm::SmallVector oprs; + for (auto *opr : list) + oprs.push_back(opr->dump()); + return "{ " + strJoin(oprs, ", ") + " }"; +} + +PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr, + StringRef constraint, int off) { + auto *opr = newOperand(addr, constraint); + opr->repr = [off](int idx) -> std::string { + std::stringstream ss; + ss << "[ $" << idx << " + " << off << " ]"; + return ss.str(); + }; + + return opr; +} + +std::string PTXBuilder::dump() const { + llvm::SmallVector lines; + for (auto &exec : executions) { + lines.push_back(exec->dump()); + } + + return strJoin(lines, "\n\t"); +} + +PTXInstrExecution &PTXInstrCommon::call(ArrayRef oprs, + bool onlyAttachMLIRArgs) { + if (onlyAttachMLIRArgs) { + // Nearly impossible to make the $0,$1 in two PTX code snippets to point to + // the same MLIR values in onlyAttachMLIRArgs mode. + assert(builder->executions.empty() && + "builder can only hold a single execution when onlyAttachMIIRArgs " + "is true."); + builder->reorderArgArchive(oprs); + } + + builder->executions.emplace_back( + std::make_unique(this, oprs, onlyAttachMLIRArgs)); + + return *builder->executions.back(); +} + +PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef oprs, + bool onlyAttachMLIRArgs) { + return call(oprs, onlyAttachMLIRArgs); +} + +std::string PTXInstrExecution::dump() const { + std::string osStr; + llvm::raw_string_ostream os(osStr); + + if (pred) { + if (!pred->repr) + os << "@" << pred->dump() << " "; + else + os << pred->repr(pred->idx) << " "; + } + + std::string instrRepr = strJoin(instr->instrParts, "."); + if (onlyAttachMLIRArgs) { + os << instrRepr; + os.flush(); + return osStr; + } + + llvm::SmallVector argReprs; + for (auto *arg : argsInOrder) { + argReprs.push_back(arg->dump()); + } + + std::string argsRepr = strJoin(argReprs, ", "); + + os << instrRepr << " " << argsRepr << ";"; + os.flush(); + return osStr; +} + +SmallVector +PTXInstrExecution::getArgList() const { + SmallVector args; + for (auto *arg : argsInOrder) { + if (arg->isList()) + args.insert(args.end(), arg->list.begin(), arg->list.end()); + else + args.push_back(arg); + } + return args; +} + +PTXInstr &PTXInstr::global() { + o("global"); + return *this; +} + +PTXInstr &PTXInstr::shared() { + o("shared"); + return *this; +} + +PTXInstr &PTXInstr::v(int vecWidth, bool predicate) { + if (vecWidth > 1) { + o("v" + std::to_string(vecWidth), predicate); + } + return *this; +} + +PTXInstr &PTXInstr::b(int width) { + o("b" + std::to_string(width)); + return *this; +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h new file mode 100644 index 000000000..7938ecd80 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -0,0 +1,88 @@ +#ifndef TRITON_CONVERSION_TRITONNVIDIAGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONNVIDIAGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H + +#include "TargetInfo.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/AxisInfo.h" + +namespace mlir { +namespace triton { + +namespace NVIDIA { + +void populateBarrierOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateClusterOpsToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateMemoryOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateConvertLayoutOpToLLVMOptimizedPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit); + +void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, int computeCapability, + const TargetInfo &targetInfo, PatternBenefit benefit); + +void populateFp4ToFpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit); + +void populateTensorPtrOpsToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateTMAToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateClampFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + int computeCapability, + PatternBenefit benefit); + +void populateTCGen5MMAOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateTensorMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateTensorMemorySubviewOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit); +} // namespace NVIDIA +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/RegReallocOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/RegReallocOpToLLVM.cpp new file mode 100644 index 000000000..51c91c4af --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/RegReallocOpToLLVM.cpp @@ -0,0 +1,47 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { +struct RegAllocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::RegAllocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::RegAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getRegCount()); + return success(); + } +}; + +struct RegDeallocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::RegDeallocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::RegDeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getRegCount()); + return success(); + } +}; +} // namespace + +void mlir::triton::populateRegReallocOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + return; +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 000000000..119a8d72b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,58 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct GetNumProgramsOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::GetNumProgramsOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // It is not easy to get the compute capability here, so we use numCTAs to + // decide the semantic of GetNumProgramsOp. If numCTAs = 1, then + // GetNumProgramsOp is converted to "%nctaid", otherwise it is converted to + // "%nclusterid". + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp"); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + + Location loc = op->getLoc(); + assert(op.getAxisAsInt() < 3); + std::string sreg = numCTAs == 1 ? "nctaid." : "nclusterid."; + sreg.append(1, 'x' + op.getAxisAsInt()); // 0 -> 'x', 1 -> 'y', 2 -> 'z' + + Value numPrograms = LLVM::NVIDIA::getSRegValue(rewriter, loc, sreg); + rewriter.replaceOp(op, numPrograms); + return success(); + } +}; + +struct GetCanonicalWarpIdConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::GetCanonicalWarpIdOp>::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(triton::nvidia_gpu::GetCanonicalWarpIdOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto warpIdOp = rewriter.create( + op->getLoc(), rewriter.getI32Type()); + rewriter.replaceOp(op, warpIdOp); + return success(); + } +}; +} // namespace + +void mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp new file mode 100644 index 000000000..fe2ff915b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp @@ -0,0 +1,337 @@ +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/TypeUtilities.h" + +#include "PatternTritonGPUOpToLLVM.h" +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" + +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" + +#include "Utility.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::nvidia_gpu; + +namespace { + +void tensormap_cp_fenceproxy(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, Value outPtr, + Value inPtr) { + PTXBuilder ptxBuilder; + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // prepare asm operands + auto *outAddrOpr = ptxBuilder.newAddrOperand(outPtr, "l"); + auto *inAddrOpr = ptxBuilder.newAddrOperand(inPtr, "l"); + auto *sizeOpr = ptxBuilder.newConstantOperand(TMA_SIZE_BYTES); + + // Define the instruction opcode + auto &cp = + *ptxBuilder.create<>("tensormap.cp_fenceproxy.global.shared::cta." + "tensormap::generic.release.gpu.sync.aligned"); + + // Execute collectively on first warp in block + constexpr int kWarpSize = 32; + Value threadId = getThreadId(rewriter, loc); + Value pred = b.icmp_slt(threadId, b.i32_val(kWarpSize)); + cp(outAddrOpr, inAddrOpr, sizeOpr).predicate(pred); + + ptxBuilder.launch(rewriter, loc, void_ty(ctx)); +}; + +void tensormap_replace_generic(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + std::string fieldName, Value descPtr, + int32_t newVal) { + PTXBuilder ptxBuilder; + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // prepare asm operands + auto *descAddrOpr = ptxBuilder.newAddrOperand(descPtr, "l"); + auto newValOpr = ptxBuilder.newConstantOperand(newVal); + + // Define the instruction opcode + auto &replace = ptxBuilder.create<>("tensormap.replace.tile") + ->o(fieldName) + .o("shared::cta") + .o("b1024") + .o("b32"); + + Value threadId = getThreadId(rewriter, loc); + Value pred = b.icmp_eq(threadId, b.i32_val(0)); + replace(descAddrOpr, newValOpr).predicate(pred); + + ptxBuilder.launch(rewriter, loc, void_ty(ctx)); +} + +void tensormap_replace_generic(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + std::string fieldName, Value descPtr, + Value newVal, + std::optional ord = std::nullopt) { + PTXBuilder ptxBuilder; + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto newValTy = newVal.getType(); + int width = 0; + + // prepare asm operands + auto *descAddrOpr = ptxBuilder.newAddrOperand(descPtr, "l"); + PTXInstr::Operand *ordOpr = + ord ? ptxBuilder.newConstantOperand(*ord) : nullptr; + PTXInstr::Operand *newValOpr = nullptr; + if (mlir::isa(newValTy)) { + width = mlir::cast(newValTy).getWidth(); + } else { + assert(mlir::isa(newValTy)); + width = 64; + } + const char *constraint = width == 64 ? "l" : "r"; + newValOpr = ptxBuilder.newOperand(newVal, constraint); + + // Define the instruction opcode + auto &replace = ptxBuilder.create<>("tensormap.replace.tile") + ->o(fieldName) + .o("shared::cta") + .o("b1024") + .o("b32", width == 32) + .o("b64", width == 64); + + Value threadId = getThreadId(rewriter, loc); + Value pred = b.icmp_eq(threadId, b.i32_val(0)); + + if (ord) { + replace(descAddrOpr, ordOpr, newValOpr).predicate(pred); + } else { + replace(descAddrOpr, newValOpr).predicate(pred); + } + + ptxBuilder.launch(rewriter, loc, void_ty(ctx)); +} + +void tensormap_replace_global_address(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, Value newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "global_address", descPtr, + newVal); +} + +void tensormap_replace_rank(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, Value descPtr, + int32_t newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "rank", descPtr, newVal); +} + +void tensormap_replace_box_dim(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t ord, Value newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "box_dim", descPtr, newVal, + ord); +} + +void tensormap_replace_global_dim(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t ord, Value newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "global_dim", descPtr, newVal, + ord); +} + +void tensormap_replace_global_stride(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t ord, Value newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "global_stride", descPtr, + newVal, ord); +} + +void tensormap_replace_element_stride(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t ord, + Value newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "element_stride", descPtr, + newVal, ord); +} + +void tensormap_replace_elemtype(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "elemtype", descPtr, newVal); +} + +void tensormap_replace_interleave_layout(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "interleave_layout", descPtr, + newVal); +} + +void tensormap_replace_swizzle_mode(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "swizzle_mode", descPtr, + newVal); +} + +void tensormap_replace_fill_mode(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "fill_mode", descPtr, newVal); +} + +struct ExperimentalTensormapFenceproxyAcquireOpConversion + : public ConvertOpToLLVMPattern< + triton::ExperimentalTensormapFenceproxyAcquireOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ExperimentalTensormapFenceproxyAcquireOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + PTXBuilder ptxBuilder; + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // prepare asm operands + auto *descAddrOpr = ptxBuilder.newAddrOperand(adaptor.getDescPtr(), "l"); + auto *sizeOpr = ptxBuilder.newConstantOperand(TMA_SIZE_BYTES); + + // Define the instruction opcode + constexpr int kWarpSize = 32; + Value threadId = getThreadId(rewriter, loc); + Value pred = b.icmp_slt(threadId, b.i32_val(kWarpSize)); + auto &fence = + *ptxBuilder.create<>("fence.proxy.tensormap::generic.acquire.gpu"); + fence(descAddrOpr, sizeOpr).predicate(pred); + + ptxBuilder.launch(rewriter, loc, getVoidType()); + + // We run the fence on a single warp, then use a barrier to synchronize the + // rest. This ends up being faster than running the fence on each warp. + // TODO: Ideally we only emit one barrier after all fences are issued + insertBarrier(rewriter, op); + + rewriter.eraseOp(op); + return success(); + } +}; + +void zero_fill_tma(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + const NVIDIA::TargetInfo &targetInfo, Value descPtr) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // Write out zeros + constexpr int kWarpSize = 32; + Value threadId = getThreadId(rewriter, loc); + Value pred = b.icmp_slt(threadId, b.i32_val(kWarpSize)); + + auto fillVal = b.i32_val(0); + auto writeAddr = + b.gep(descPtr.getType(), fillVal.getType(), descPtr, threadId); + targetInfo.storeShared(rewriter, loc, writeAddr, fillVal, pred); + LLVM::NVIDIA::createSyncWarp(loc, rewriter); +} + +struct ExperimentalTensormapCreateOpConversion + : public ConvertOpToLLVMPattern { + const NVIDIA::TargetInfo &targetInfo; + + ExperimentalTensormapCreateOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ExperimentalTensormapCreateOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ctx = getContext(); + + bool needsStrideWorkaround = targetInfo.getPtxVersion() <= 85; + auto smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + + zero_fill_tma(loc, ctx, rewriter, targetInfo, smemBase); + tensormap_replace_global_address(loc, ctx, rewriter, smemBase, + adaptor.getGlobalAddress()); + tensormap_replace_rank(loc, ctx, rewriter, smemBase, op.getRank() - 1); + for (int i = 0; i < op.getRank(); ++i) { + tensormap_replace_box_dim(loc, ctx, rewriter, smemBase, i, + op.getBoxDim()[i]); + } + for (int i = 0; i < op.getRank(); ++i) { + tensormap_replace_global_dim(loc, ctx, rewriter, smemBase, i, + op.getGlobalDim()[i]); + } + for (int i = 0; i + 1 < op.getRank(); ++i) { + auto strideVal = op.getGlobalStride()[i]; + if (needsStrideWorkaround) { + // Workaround for a ptxas bug + strideVal = b.ashr(strideVal, b.i64_val(4)); + } + tensormap_replace_global_stride(loc, ctx, rewriter, smemBase, i, + strideVal); + } + for (int i = 0; i < op.getRank(); ++i) { + tensormap_replace_element_stride(loc, ctx, rewriter, smemBase, i, + op.getElementStride()[i]); + } + tensormap_replace_elemtype(loc, ctx, rewriter, smemBase, op.getElemType()); + tensormap_replace_interleave_layout(loc, ctx, rewriter, smemBase, + op.getInterleaveLayout()); + tensormap_replace_swizzle_mode(loc, ctx, rewriter, smemBase, + op.getSwizzleMode()); + tensormap_replace_fill_mode(loc, ctx, rewriter, smemBase, op.getFillMode()); + tensormap_cp_fenceproxy(loc, ctx, rewriter, adaptor.getDescPtr(), smemBase); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ReinterpretTensorDescOpConversion + : public ConvertOpToLLVMPattern { + + ReinterpretTensorDescOpConversion(LLVMTypeConverter &converter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::ReinterpretTensorDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getRawDesc()); + return success(); + } +}; + +struct TensorDescToTMAPtrOpConversion + : public ConvertOpToLLVMPattern { + + TensorDescToTMAPtrOpConversion(LLVMTypeConverter &converter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::TensorDescToTMAPtrOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getDesc()); + return success(); + } +}; + +} // namespace + +void mlir::triton::NVIDIA::populateTMAToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, + targetInfo, benefit); + patterns + .add( + typeConverter, benefit); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp new file mode 100644 index 000000000..552ef476d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -0,0 +1,636 @@ +#include "TargetInfo.h" +#include "Dialect/NVGPU/IR/Dialect.h" +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +#include "Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "llvm/Support/MathExtras.h" + +using namespace mlir; + +using ::mlir::LLVM::linearize; +namespace { +// declare vprintf(i8*, i8*) as external function +LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName("vprintf"); + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *context = rewriter.getContext(); + + SmallVector argsType{ptr_ty(context), ptr_ty(context)}; + auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); + + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(context), funcName, + funcType); +} + +// extend integer to int32, extend float to float64 +// this comes from vprintf alignment requirements. +std::pair printfPromoteValue(RewriterBase &rewriter, Value value) { + auto *context = rewriter.getContext(); + auto type = value.getType(); + Value newOp = value; + Type newType = type; + auto loc = UnknownLoc::get(context); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + bool isUnsigned = type.isUnsignedInteger(); + if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) { + if (isUnsigned) { + newType = ui32_ty; + newOp = b.zext(newType, value); + } else { + newType = i32_ty; + newOp = b.sext(newType, value); + } + } else if (type.isBF16() || type.isF16() || type.isF32()) { + newType = f64_ty; + newOp = b.fpext(newType, value); + } + + return {newType, newOp}; +} + +LLVM::LLVMFuncOp getAssertfailDeclaration(RewriterBase &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName("__assertfail"); + { + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + } + // void __assert_fail(const char * assertion, const char * file, unsigned + // int line, const char * function); + auto *ctx = rewriter.getContext(); + SmallVector argsType{ptr_ty(ctx), ptr_ty(ctx), i32_ty, ptr_ty(ctx), + rewriter.getIntegerType(sizeof(size_t) * 8)}; + auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType); + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + auto funcOp = rewriter.create(UnknownLoc::get(ctx), + funcName, funcType); + + funcOp.setPassthroughAttr( + ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn"))); + return funcOp; +} +} // namespace + +namespace mlir::triton::NVIDIA { + +// Check if the reduction can use a redux op and return the kind. +static std::optional matchReduxKind(triton::ReduceOp op, + int computeCapability) { + if (computeCapability < 80) + return std::nullopt; + Operation *reduceOp = op.getSingleCombiner(); + if (!reduceOp) + return std::nullopt; + auto intType = dyn_cast(reduceOp->getResultTypes()[0]); + if (!intType || intType.getWidth() > 32) + return std::nullopt; + if (isa(reduceOp)) + return NVVM::ReduxKind::ADD; + if (isa(reduceOp)) + return NVVM::ReduxKind::AND; + if (isa(reduceOp)) + return NVVM::ReduxKind::OR; + if (isa(reduceOp)) + return NVVM::ReduxKind::XOR; + if (isa(reduceOp)) + return NVVM::ReduxKind::MIN; + if (isa(reduceOp)) + return NVVM::ReduxKind::UMIN; + if (isa(reduceOp)) + return NVVM::ReduxKind::MAX; + if (isa(reduceOp)) + return NVVM::ReduxKind::UMAX; + return std::nullopt; +} + +bool TargetInfo::supportMaximumMinimum() const { + return computeCapability >= 80; +} + +Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { + return rewriter.create(loc, + rewriter.getI32Type()); +} + +Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value threadMask = b.int_val(type.getIntOrFloatBitWidth(), -1); + return rewriter.create(loc, type, threadMask, cmp); +} + +static Value mapa(RewriterBase &rewriter, Location loc, Value ptr, Value ctaid, + Value pred) { + Value args[] = {ptr, ctaid}; + StringRef name = "llvm.nvvm.mapa.shared.cluster"; + return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, name, ptr.getType(), + args) + .getResult(0); +} + +static std::string getConstraintForBitwidth(unsigned bitwidth) { + switch (bitwidth) { + case 8: + case 16: + return "h"; + case 32: + return "r"; + case 64: + return "l"; + default: + llvm_unreachable("unsupported bitwidth"); + } +} + +static bool isConstantTruePred(Value pred) { + if (auto constOp = pred.getDefiningOp()) { + return cast(constOp.getValue()).getInt() != 0; + } + return false; +} + +void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + auto ptrTy = cast(ptr.getType()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + + if (!isa(val.getType())) { + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, {val}, rewriter), + pred); + return; + } + + auto vecTy = cast(val.getType()); + Type elemTy = vecTy.getElementType(); + unsigned vec = vecTy.getNumElements(); + unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_32(vec)); + + if (elemBitwidth < 8) { + assert(vec == 1 && + "don't know how to load/store vectors of sub-byte elems"); + SmallVector vals = unpackLLVector(loc, val, rewriter); + for (Value &v : vals) { + v = b.zext(int_ty(8), b.bitcast(v, int_ty(elemBitwidth))); + } + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), + pred); + return; + } + + if (!elemTy.isInteger()) { + SmallVector vals = unpackLLVector(loc, val, rewriter); + for (Value &v : vals) { + v = b.bitcast(v, int_ty(elemBitwidth)); + } + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), + pred); + return; + } + + // load/store ops only support v2 and v4. If the vector width is larger than + // 4, we have two strategies for dealing with it. + // 1. If the element type is smaller than b32, store b32's instead. + // 2. Otherwise, split the store into multiple stores. + if (vec > 4 && elemBitwidth < 32) { + assert(llvm::isPowerOf2_32(vec)); + int elemsPerPack = 32 / elemBitwidth; + SmallVector oldVals = unpackLLVector(loc, val, rewriter); + + SmallVector newVals; + for (int i = 0; i < vec / elemsPerPack; i++) { + Value v = packLLVector( + loc, ArrayRef(oldVals).slice(i * elemsPerPack, elemsPerPack), + rewriter); + newVals.push_back(b.bitcast(v, i32_ty)); + } + storeDShared(rewriter, loc, ptr, ctaId, + packLLVector(loc, newVals, rewriter), pred); + return; + } + + if (vec * elemBitwidth > 128) { + assert(llvm::isPowerOf2_32(vec)); + assert(elemBitwidth == 32 || elemBitwidth == 64); + int maxVec = 128 / elemBitwidth; + + auto newVecTy = vec_ty(elemTy, maxVec); + SmallVector vals = unpackLLVector(loc, val, rewriter); + for (int i = 0; i < vec / maxVec; i++) { + auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), + /*inbounds=*/true); + storeDShared( + rewriter, loc, newPtr, ctaId, + packLLVector(loc, ArrayRef(vals).slice(i * maxVec, maxVec), rewriter), + pred); + } + return; + } + + // At this point we're committed to doing the store! + assert(elemBitwidth >= 8); + assert(elemTy.isInteger()); + assert(1 <= vec && vec <= 4); + assert(vec * elemBitwidth <= 128); + + // Get pointer to remote shared memory if needed. + if (ctaId.has_value()) { + ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + } + + PTXBuilder builder; + auto st = builder.create<>("st") + ->o("shared::cta", ctaId.has_value()) + .o("shared", !ctaId.has_value()) + .v(vec, /*predicate=*/vec > 1) + .b(elemBitwidth); + auto *ptrOpr = builder.newAddrOperand(ptr, "r"); + + PTXBuilder::Operand *valOpr; + std::string constraint = getConstraintForBitwidth(elemBitwidth); + if (vec > 1) { + SmallVector> vecVals; + for (int i = 0; i < vec; i++) { + vecVals.push_back({b.extract_element(val, b.i32_val(i)), constraint}); + } + valOpr = builder.newListOperand(vecVals); + } else { + valOpr = builder.newOperand(val, constraint); + } + st(ptrOpr, valOpr).predicate(pred, "b"); + builder.launch(rewriter, loc, void_ty(ctx)); +} + +Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type loadTy, + Value pred) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + auto ptrTy = cast(ptr.getType()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + + if (!isa(loadTy)) { + SmallVector values = unpackLLVector( + loc, loadDShared(rewriter, loc, ptr, ctaId, vec_ty(loadTy, 1), pred), + rewriter); + assert(values.size() == 1); + return values[0]; + } + + auto vecTy = cast(loadTy); + Type elemTy = vecTy.getElementType(); + unsigned vec = vecTy.getNumElements(); + unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_32(vec)); + + if (elemBitwidth < 8) { + assert(vec == 1 && + "don't know how to load/store vectors of sub-byte elems"); + SmallVector vals = unpackLLVector( + loc, loadDShared(rewriter, loc, ptr, ctaId, int_ty(8), pred), rewriter); + assert(vals.size() == 1); + return b.bitcast(b.trunc(int_ty(elemBitwidth), vals[0]), elemTy); + } + + // We only know how to load integers. + if (!elemTy.isInteger()) { + Type newLoadTy = vec_ty(int_ty(elemBitwidth), vec); + SmallVector vals = unpackLLVector( + loc, loadDShared(rewriter, loc, ptr, ctaId, newLoadTy, pred), rewriter); + for (Value &v : vals) { + v = b.bitcast(v, elemTy); + } + return packLLVector(loc, vals, rewriter); + } + + // load/store ops only support v2 and v4. If the vector width is larger than + // 4, we have two strategies for dealing with it. + // 1. If the element type is smaller than b32, load b32's instead. + // 2. Otherwise, split the load into multiple loads. + if (vec > 4 && elemBitwidth < 32) { + int newVec = vec / (32 / elemBitwidth); + auto newVecTy = vec_ty(i32_ty, newVec); + auto res = loadDShared(rewriter, loc, ptr, ctaId, newVecTy, pred); + + // Unpack the b32's into the original vector type. + SmallVector vals; + for (Value v : unpackLLVector(loc, res, rewriter)) { + Value vv = b.bitcast(v, vec_ty(elemTy, 32 / elemBitwidth)); + for (Value vvv : unpackLLVector(loc, vv, rewriter)) { + vals.push_back(vvv); + } + } + return packLLVector(loc, vals, rewriter); + } + + if (vec * elemBitwidth > 128) { + assert(elemBitwidth == 32 || elemBitwidth == 64); + assert(llvm::isPowerOf2_32(vec)); + int maxVec = 128 / elemBitwidth; + + SmallVector vals; + for (int i = 0; i < vec / maxVec; i++) { + auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), + /*inbounds=*/true); + auto newVal = loadDShared(rewriter, loc, newPtr, ctaId, + vec_ty(elemTy, maxVec), pred); + for (Value v : unpackLLVector(loc, newVal, rewriter)) { + vals.push_back(v); + } + } + return packLLVector(loc, vals, rewriter); + } + + // At this point we're committed to actually do the load! + assert(elemBitwidth >= 8); + assert(elemTy.isInteger()); + assert(1 <= vec && vec <= 4); + assert(vec * elemBitwidth <= 128); + + // Get pointer to remote shared memory if needed. + if (ctaId.has_value()) { + ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + } + + PTXBuilder builder; + auto ld = builder.create<>("ld") + ->o("shared::cta", ctaId.has_value()) + .o("shared", !ctaId.has_value()) + .v(vec, /*predicate=*/vec > 1) + .b(elemBitwidth); + + Value load; + if (isConstantTruePred(pred)) { + Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth)) + : Type(vec_ty(int_ty(elemBitwidth), vec)); + load = b.load(resultTy, ptr); + if (vec > 1) { + Type structTy = struct_ty(SmallVector(vec, int_ty(elemBitwidth))); + Value structValue = b.undef(structTy); + for (int i = 0; i < vec; i++) { + structValue = b.insert_val(structTy, structValue, + b.extract_element(load, b.i32_val(i)), i); + } + load = structValue; + } + } else { + std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); + auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint) + : builder.newListOperand(vec, elemConstraint); + ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b"); + + Type resultTy = + vec == 1 + ? Type(int_ty(elemBitwidth)) + : Type(struct_ty(SmallVector(vec, int_ty(elemBitwidth)))); + load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true); + } + SmallVector resultVals = unpackLLElements(loc, load, rewriter); + return packLLVector(loc, resultVals, rewriter); +} + +Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::NVIDIA::shuffleXor(loc, rewriter, val, i); +} + +Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::NVIDIA::shuffleUp(loc, rewriter, val, i); +} + +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const { + return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i); +} + +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const { + return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i); +} + +Value TargetInfo::programId(RewriterBase &rewriter, Location loc, + ModuleOp moduleOp, int axis) const { + return LLVM::NVIDIA::llGetPid(loc, rewriter, moduleOp, axis); +} +bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce, + unsigned interleave) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (auto kind = matchReduxKind(op, computeCapability)) { + // Based on benchmarking on A100 redux op gives a speed up only when doing + // a single reduction (not partitioned) and when the mask is static. + // Therefore we currently only enable it to reduce across all the lanes. + if (numLaneToReduce == 32) { + assert(acc.size() == 1); + Value mask = b.i32_val(0xFFFFFFFF); + // Even though we currently don't use redux for partitioned reduction + // the code below supports it in case we want to tweak the heuristic. + if (numLaneToReduce < 32) { + // For partitioned reduction we need to calculate the mask so that + // each group of numLaneToReduce threads has the correct mask. + unsigned bitmask = (1 << numLaneToReduce) - 1; + Value laneId = getLaneId(rewriter, loc); + mask = b.shl(b.i32_val(bitmask), + b.and_(laneId, b.i32_val(~(numLaneToReduce - 1)))); + } + for (unsigned i = 0; i < acc.size(); ++i) { + unsigned bitwidth = cast(acc[i].getType()).getWidth(); + if (bitwidth < 32) { + if (*kind == NVVM::ReduxKind::MIN || *kind == NVVM::ReduxKind::MAX) + acc[i] = b.sext(i32_ty, acc[i]); + else + acc[i] = b.zext(i32_ty, acc[i]); + } + acc[i] = rewriter.create(loc, acc[i].getType(), acc[0], + *kind, mask); + if (bitwidth < 32) + acc[i] = b.trunc(int_ty(bitwidth), acc[i]); + } + return true; + } + } + return false; +} + +// TODO (Keren): Currently, we have more restrictions than necessary when using +// stmatrix. These restrictions are retained from legacy code, and we could +// relax some of them in the future. +// TODO (Lezcano): The proper way of doing this is to directly try to fit the +// relevant layout and return an std::optional. I'm keeping this +// split to keep the current PR smaller +bool TargetInfo::canUseStMatrix(RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const { + if (computeCapability < 90) { + return false; + } + auto mmaLayout = + mlir::dyn_cast(tensorTy.getEncoding()); + if (!mmaLayout || !mmaLayout.isHopper()) + return false; + if (isa(tensorTy.getElementType())) + return false; + if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16) + return false; + if (order[0] != 1) + return false; + + auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape()); + if (tensorShapePerCTA.size() != 2) + return false; + auto numIterations = ceil(tensorShapePerCTA[1], repShape[1]) * + ceil(tensorShapePerCTA[0], repShape[0]); + if (numIterations > 1) + return false; + if (paddedRepShape[1] % 8 != 0) + return false; + if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 && + swizzleByteSize != 128) + return false; + return true; +} + +void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, + Value ptr, Value val) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto vals = unpackLLVector(loc, val, rewriter); + // Ensure input consists of 4 vectors, each holding 2 elements of 16 bits + assert(vals[0].getType().getIntOrFloatBitWidth() == 16 && + "stmatrix requires elements to be 16-bit integers or floats"); + assert(vals.size() == 8 && + "stmatrix requires exactly 8 elements in the input vector"); + Type packedTy = vec_ty(vals[0].getType(), 2); + SmallVector inputs; + for (int i = 0; i < 4; i++) { + Value input = b.undef(packedTy); + for (int j = 0; j < 2; j++) { + input = b.insert_element(packedTy, input, vals[i * 2 + j], b.i32_val(j)); + } + inputs.push_back(b.bitcast(input, i32_ty)); + } + rewriter.create(loc, ptr, inputs); +} + +std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { + std::string funcName = + resultElementTy.isInteger(32) ? "__nv_umulhi" : "__nv_umul64hi"; + return funcName; +} + +void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, + int /*formatStrByteCount*/, ValueRange args) const { + auto *ctx = rewriter.getContext(); + Type ptr = ptr_ty(ctx); + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + auto funcOp = getVprintfDeclaration(rewriter); + auto loc = UnknownLoc::get(ctx); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + Value one = b.i32_val(1); + Value zero = b.i32_val(0); + + Value bufferPtr = b.null(ptr); + + SmallVector newArgs; + if (args.size() >= 1) { + SmallVector argTypes; + for (auto arg : args) { + Type newType; + Value newArg; + std::tie(newType, newArg) = printfPromoteValue(rewriter, arg); + argTypes.push_back(newType); + newArgs.push_back(newArg); + } + + Type structTy = LLVM::LLVMStructType::getLiteral(ctx, argTypes); + auto allocated = + rewriter.create(loc, ptr_ty(ctx), structTy, one, + /*alignment=*/0); + + for (const auto &entry : llvm::enumerate(newArgs)) { + auto index = b.i32_val(entry.index()); + auto fieldPtr = + b.gep(ptr_ty(ctx), structTy, allocated, ArrayRef{zero, index}); + b.store(entry.value(), fieldPtr); + } + bufferPtr = b.bitcast(allocated, ptr); + } + + SmallVector operands{formatStrStart, bufferPtr}; + b.call(funcOp, operands); +} + +void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter, + "printfFormat_", msgNewline); + printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); +} + +void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto funcOp = getAssertfailDeclaration(rewriter); + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + llvm::SmallString<64> messageString(message), fileString(file), + funcString(func); + messageString.push_back('\0'); + fileString.push_back('\0'); + funcString.push_back('\0'); + Value messageStringVal = + LLVM::addStringToModule(loc, rewriter, "assertMessage_", messageString); + Value fileStringVal = + LLVM::addStringToModule(loc, rewriter, "assertFile_", fileString); + Value funcStringVal = + LLVM::addStringToModule(loc, rewriter, "assertFunc_", funcString); + Value lineNumber = b.i32_val(line); + Value charSize = b.int_val(sizeof(size_t) * 8, sizeof(char)); + SmallVector operands = {messageStringVal, fileStringVal, lineNumber, + funcStringVal, charSize}; + b.call(funcOp, operands); +} + +int TargetInfo::getSharedAddressSpace() const { return 3; } + +int TargetInfo::getAddressSpace(Attribute addressSpace) const { + int spaceId = 0; + if (isa(addressSpace)) { + spaceId = 3; + } else { + llvm::report_fatal_error( + "Only support SharedMemorySpace, TensorMemorySpace for now"); + } + return spaceId; +} + +bool TargetInfo::supportVectorizedAtomics() const { + return computeCapability >= 90 && ptxVersion >= 81; +} + +} // namespace mlir::triton::NVIDIA diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h new file mode 100644 index 000000000..40e415dd4 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -0,0 +1,77 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFONVIDIA_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFONVIDIA_H + +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" + +namespace mlir::triton::NVIDIA { + +class TargetInfo : public mlir::triton::TargetInfoBase { +public: + TargetInfo(int computeCapability, int ptxVersion) + : computeCapability(computeCapability), ptxVersion(ptxVersion) {} + + bool supportMaximumMinimum() const override; + + Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; + + Value ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const override; + + void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const override; + Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const override; + + bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, + int swizzleByteSize) const override; + + void storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, + Value val) const override; + + Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const override; + + Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, + int axis) const override; + + bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override; + + std::string getMulhiFuncName(Type resultElementTy) const override; + + void printf(RewriterBase &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const override; + + void printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const override; + + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, + StringRef file, StringRef func, int line) const override; + + int getSharedAddressSpace() const override; + + int getAddressSpace(Attribute addressSpace) const override; + + bool supportVectorizedAtomics() const override; + + int getPtxVersion() const { return ptxVersion; } + +private: + int computeCapability; + int ptxVersion; +}; + +} // namespace mlir::triton::NVIDIA + +#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFONVIDIA_H diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp new file mode 100644 index 000000000..e9e8909d5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp @@ -0,0 +1,727 @@ +#include "Dialect/NVGPU/IR/Dialect.h" +#include "DotOpToLLVM/MMAHelpers.h" +#include "PatternTritonGPUOpToLLVM.h" +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +#include "Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// The maximum number of tensor memory registers that can be accessed +// by a single message regardless of shape or repetitions +static constexpr int largestTmemLoadStore = 128; +// The maximum number of thread registers that can be populated by +// multiple messages +static constexpr int maxRegisters = 256; + +namespace { + +struct TMemAccessAtom { + int opBitWidth; + int colsPerThread; + int rowsPerThread; + const char *opShape; + bool usesSecondHalfOffset; +}; + +constexpr TMemAccessAtom TMemAccess32x32b{.opBitWidth = 32, + .colsPerThread = 1, + .rowsPerThread = 1, + .opShape = "32x32b", + .usesSecondHalfOffset = false}; + +constexpr TMemAccessAtom TMemAccess16x32bx2{.opBitWidth = 32, + .colsPerThread = 1, + .rowsPerThread = 1, + .opShape = "16x32bx2", + .usesSecondHalfOffset = true}; + +constexpr TMemAccessAtom TMemAccess16x256b{.opBitWidth = 256, + .colsPerThread = 2, + .rowsPerThread = 2, + .opShape = "16x256b", + .usesSecondHalfOffset = false}; + +struct TMemMessageTraits { + TMemAccessAtom atom; + bool usesSecondHalfOffset; + int numThreadsPerWarp; + int maxNumRepeats; + int maxCols; + int numRows; + int numCols; + int numRepeats; + int numRegs; + + bool operator<(const TMemMessageTraits &other) const { + return numRegs < other.numRegs; + } +}; + +struct TMemRuntimeInfo { + static constexpr int numRowsPerWarp = 32; + int numWarps; + int numWarpGroups; + int numElementsPer32B; + int numElements; + int numCols; + int blockM; + int blockN; + bool unpackedb16; + bool useStridedMessage; + int numBlocks; + int numWarpGroupsPerBlock; + bool blocksInterleaved; + int numColsPerBlock; + int colsPerWarpGroup; +}; + +TMemMessageTraits getTMemMessageFromAtom(const TMemAccessAtom &atom, + int narrowingFactor) { + TMemMessageTraits m; + m.atom = atom; + m.usesSecondHalfOffset = atom.usesSecondHalfOffset; + m.numThreadsPerWarp = 32; + m.maxNumRepeats = + largestTmemLoadStore / (atom.colsPerThread * atom.rowsPerThread); + m.maxCols = (atom.opBitWidth / 32) * m.maxNumRepeats; + m.numRows = m.numThreadsPerWarp / atom.rowsPerThread; + m.numCols = m.maxCols / narrowingFactor; + m.numRepeats = m.numCols / (atom.opBitWidth / 32); + m.numRegs = atom.colsPerThread * atom.rowsPerThread * m.numRepeats; + return m; +} + +// Only allows half of the thread registers to be used for tensor memory access +// to avoid register pressure. This ensures the largest tmem message width is +// used for the workload without inducing spills. +int getTMemMessageNarrowingFactor(int workloadThreadRegs) { + const int allowedRegUsage = maxRegisters / 2; + int narrowingFactor = 1; + while (workloadThreadRegs > allowedRegUsage) { + workloadThreadRegs /= 2; + narrowingFactor *= 2; + } + return narrowingFactor; +} + +int getEffectiveRegs(bool unpackedb16, bool useStridedMessage, int numRegs) { + // The effective register count is less when using unpacked or strided + // messages + if (unpackedb16) { + numRegs /= 2; + } + if (useStridedMessage) { + numRegs /= 2; + } + return numRegs; +} + +// If the workload runtime requires fewer registers than the default message +// width, use the widest possible message that matches the workload +TMemMessageTraits constrainMessageFromWorkload(TMemMessageTraits m, + const TMemRuntimeInfo &info, + int numRegs) { + m.numRegs = + getEffectiveRegs(info.unpackedb16, info.useStridedMessage, numRegs); + m.numRegs = std::min(largestTmemLoadStore, m.numRegs); + // Invert the above formulas to calculate the effective runtime message width + m.numCols = (m.numRegs * (m.atom.opBitWidth / 32)) / + (m.atom.colsPerThread * m.atom.rowsPerThread); + // Half as many registers are needed for 16-bit packed elements, + // so twice as many columns are accessed per message. + m.numCols *= info.numElementsPer32B; + return m; +} + +SmallVector packToI32(const SmallVector &values, Location loc, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector packedValues; + Type elType = values[0].getType(); + int numElementsPer32B = 32 / elType.getIntOrFloatBitWidth(); + if (numElementsPer32B == 1) + return values; + Value packed = b.undef(vec_ty(elType, numElementsPer32B)); + for (int i = 0; i < values.size(); i++) { + Value val = values[i]; + packed = b.insert_element(packed.getType(), packed, val, + b.i32_val(i % numElementsPer32B)); + if (i % numElementsPer32B == numElementsPer32B - 1 || + i == values.size() - 1) { + packed = b.bitcast(packed, i32_ty); + packedValues.push_back(packed); + packed = b.undef(vec_ty(elType, numElementsPer32B)); + } + } + return packedValues; +} + +TMemRuntimeInfo getTMemRuntimeInfo(Operation *op, RankedTensorType tensorType, + MemDescType memType) { + TMemRuntimeInfo info; + static_assert(TMemRuntimeInfo::numRowsPerWarp == 32, + "A single warp must access exactly 32 rows of tmem"); + assert( + nvidia_gpu::isDistributedLayoutTMemCompatible(op, tensorType, memType) && + "unsupported distributed layout for tensor memory"); + + info.numWarps = triton::gpu::lookupNumWarps(op); + assert(info.numWarps % 4 == 0 && "Unexpected number of warps"); + info.numWarpGroups = info.numWarps / 4; + info.numElementsPer32B = 32 / tensorType.getElementTypeBitWidth(); + auto shapePerCTA = mlir::triton::gpu::getShapePerCTA(tensorType); + info.numElements = product(shapePerCTA); + + triton::nvidia_gpu::TMemAllocation tmemAlloc = + triton::nvidia_gpu::getTmemAllocSizes(memType); + info.numCols = tmemAlloc.numCols; + + info.blockM = 0; + info.blockN = 0; + info.unpackedb16 = false; + if (auto attr = dyn_cast( + memType.getEncoding())) { + info.blockM = attr.getBlockM(); + info.blockN = attr.getBlockN(); + assert((!attr.getUnpacked() || info.numElementsPer32B <= 2) && + "unsupported unpacked layout"); + info.unpackedb16 = attr.getUnpacked() && (info.numElementsPer32B == 2); + } else { + assert(isa( + memType.getEncoding()) && + "Expecting a tensor memory encoding attribute"); + info.blockM = 128; + info.blockN = 32; + } + + info.useStridedMessage = (info.blockM == 64); + + info.numBlocks = ceil(info.numElements, info.blockM * info.blockN); + info.numWarpGroupsPerBlock = ceil(info.numWarpGroups, info.numBlocks); + info.blocksInterleaved = (info.numBlocks > 1 && info.useStridedMessage); + info.numColsPerBlock = info.numCols / info.numBlocks; + if (info.blocksInterleaved) { + info.numColsPerBlock *= 2; + } + info.colsPerWarpGroup = info.numColsPerBlock / info.numWarpGroupsPerBlock; + // If more than one warp group processes the same block, + // then fewer columns must be processed per message per warp group + info.numColsPerBlock /= info.numWarpGroupsPerBlock; + return info; +} + +void calculateAddressAndEmitTmemMessage( + Location loc, Value baseAddress, const TMemRuntimeInfo &info, + const TMemMessageTraits &message, ConversionPatternRewriter &rewriter, + const std::function &createMemoryOp) { + + TritonLLVMOpBuilder b(loc, rewriter); + Value warpId = rewriter.create(loc); + Value warpIdInGroup = b.urem(warpId, b.i32_val(4)); + Value warpGroupId = b.udiv(warpId, b.i32_val(4)); + + for (int block = 0; block < info.numBlocks; block += info.numWarpGroups) { + Value address = b.ptrtoint(i32_ty, baseAddress); + Value blockId = + b.add(b.i32_val(block), + b.udiv(warpGroupId, b.i32_val(info.numWarpGroupsPerBlock))); + Value warpGroupIdInBlock = + b.urem(warpGroupId, b.i32_val(info.numWarpGroupsPerBlock)); + Value startColumnId = + b.mul(warpGroupIdInBlock, b.i32_val(info.colsPerWarpGroup)); + Value blockRowId = + b.mul(warpIdInGroup, b.i32_val(TMemRuntimeInfo::numRowsPerWarp)); + + if (info.blocksInterleaved) { + Value blockIdIsOdd = b.urem(blockId, b.i32_val(2)); + Value blockIdPrevEven = b.sub(blockId, blockIdIsOdd); + blockRowId = b.add(blockRowId, b.mul(blockIdIsOdd, b.i32_val(16))); + startColumnId = + b.add(startColumnId, + b.mul(blockIdPrevEven, b.i32_val(info.numColsPerBlock / 2))); + } else { + startColumnId = + b.add(startColumnId, b.mul(blockId, b.i32_val(info.numColsPerBlock))); + } + + // A strided message accesses twice as many columns per message, + // thus half as many messages are required + int numColumns = info.useStridedMessage ? info.numColsPerBlock / 2 + : info.numColsPerBlock; + for (int colStart = 0; colStart < numColumns; colStart += message.numCols) { + // For messages that span only 16 rows (e.g. 16x256b), multiple messages + // are required to cover the entire set of rows per warp. + for (int rowStart = 0; rowStart < TMemRuntimeInfo::numRowsPerWarp; + rowStart += message.numRows) { + Value rowOffset = b.add(blockRowId, b.i32_val(rowStart)); + Value warpGroupAddress = + b.add(address, b.shl(rowOffset, b.i32_val(16))); + warpGroupAddress = b.add(warpGroupAddress, startColumnId); + + Value msgAddress = b.add(warpGroupAddress, b.i32_val(colStart)); + int secondHalfColOffset = 0; + if (info.useStridedMessage) { + // Offset to half way through the set of columns for this warpgroup. + secondHalfColOffset = numColumns; + } + createMemoryOp(msgAddress, secondHalfColOffset, info.unpackedb16, + message.numRegs, info.useStridedMessage); + } + } + } +} + +void createTensorMemoryStore(Location loc, Value address, + SmallVector &srcs, int secondHalfOffset, + Value pred, bool unpacked, + const TMemAccessAtom &atom, + ConversionPatternRewriter &rewriter) { + PTXBuilder ptxBuilder; + std::string packedStr = unpacked ? ".unpack::16b" : ""; + unsigned numRepeats = srcs.size() / (atom.rowsPerThread * atom.colsPerThread); + std::string opcode = "@$0 tcgen05.st.sync.aligned." + + std::string(atom.opShape) + ".x" + + std::to_string(numRepeats) + packedStr; + if (secondHalfOffset) + opcode += ".b32 [$1], " + std::to_string(secondHalfOffset) + ", {"; + else + opcode += ".b32 [$1], {"; + + SmallVector operands; + operands.push_back(ptxBuilder.newOperand(pred, "b")); + operands.push_back(ptxBuilder.newOperand(address, "r")); + for (int i = 0; i < srcs.size(); i++) { + opcode += "$" + std::to_string(i + 2); + auto *resultOp = ptxBuilder.newOperand(srcs[i], "r"); + operands.push_back(resultOp); + if (i < srcs.size() - 1) + opcode += ", "; + } + opcode += "};"; + + auto &st = *ptxBuilder.create(opcode); + st(operands, /*onlyAttachMLIRArgs=*/true); + Type voidTy = void_ty(rewriter.getContext()); + ptxBuilder.launch(rewriter, loc, voidTy); +} + +static void createWaitOpSt(Location loc, ConversionPatternRewriter &rewriter) { + PTXBuilder ptxBuilder; + std::string opcode = "tcgen05.wait::st.sync.aligned;"; + auto &wait = *ptxBuilder.create(opcode); + wait({}, /*onlyAttachMLIRArgs=*/true); + ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext())); +} + +TMemMessageTraits selectTMemMessage(const TMemRuntimeInfo &info) { + auto atom = info.useStridedMessage ? TMemAccess16x32bx2 : TMemAccess32x32b; + + int totalRegsNeeded = + getEffectiveRegs(info.unpackedb16, info.useStridedMessage, + info.numCols / info.numWarpGroups); + int narrowingFactor = getTMemMessageNarrowingFactor(totalRegsNeeded); + auto narrowedMessage = getTMemMessageFromAtom(atom, narrowingFactor); + narrowedMessage = constrainMessageFromWorkload(narrowedMessage, info, + narrowedMessage.numRegs); + + auto maxWidthMessage = getTMemMessageFromAtom(atom, /*narrowingFactor=*/1); + maxWidthMessage = constrainMessageFromWorkload(maxWidthMessage, info, + info.colsPerWarpGroup); + return std::min(narrowedMessage, maxWidthMessage); +} + +static void lowerStoreToTensorMemory(Location loc, Operation *op, Value src, + Value dest, Value llSrc, Value pred, + Value tmemBase, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector srcValues = unpackLLElements(loc, llSrc, rewriter); + srcValues = packToI32(srcValues, loc, rewriter); + auto dstType = cast(dest.getType()); + auto info = getTMemRuntimeInfo(op, cast(src.getType()), + cast(dest.getType())); + const TMemMessageTraits message = selectTMemMessage(info); + int regIdx = 0; + calculateAddressAndEmitTmemMessage( + loc, tmemBase, info, message, rewriter, + [&](Value startAddress, int secondHalfColOffset, bool unpackedb16, + int regsPerMsg, bool useStridedMessage) { + SmallVector srcValuesSlice(srcValues.begin() + regIdx, + srcValues.begin() + regIdx + + regsPerMsg); + regIdx += regsPerMsg; + createTensorMemoryStore(loc, startAddress, srcValuesSlice, + secondHalfColOffset, pred, unpackedb16, + message.atom, rewriter); + }); + createWaitOpSt(loc, rewriter); + + // Emit a barrier to ensure all threads have finished writing to tensor memory + // before any use of the tensor memory. + b.barrier(); +} + +struct TensorMemoryAllocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::TMEMAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value base = rewriter.create(loc); + Value baseInt = b.ptrtoint(i32_ty, base); + int colOffset = cast(op->getAttr("tensor_memory_col_offset")) + .getValue() + .getZExtValue(); + int rowOffset = cast(op->getAttr("tensor_memory_row_offset")) + .getValue() + .getZExtValue(); + Value allocAddress = b.add(baseInt, b.i32_val(colOffset | rowOffset << 16)); + // Cast to address space 3 as the shared memory object uses 3. + // TODO: clean this up and use either a int or ptr address space 6 + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); + Value ptr = b.inttoptr(ptrTy, allocAddress); + SmallVector order(op.getType().getRank()); + std::iota(order.begin(), order.end(), 0); + std::reverse(order.begin(), order.end()); + auto shape = op.getType().getShape(); + + if (op.getSrc()) { + lowerStoreToTensorMemory(loc, op, op.getSrc(), op.getResult(), + adaptor.getSrc(), b.i1_val(true), ptr, rewriter); + } + + rewriter.replaceOp(op, ptr); + return success(); + } +}; + +Value createTensorMemoryLoad(Location loc, triton::nvidia_gpu::TMEMLoadOp op, + Value address, int secondHalfOffset, bool unpacked, + int numRegPerMessage, const TMemAccessAtom &atom, + ConversionPatternRewriter &rewriter) { + PTXBuilder ptxBuilder; + // If the memory is unpacked we need to pack on the fly when loading. + std::string packedStr = unpacked ? ".pack::16b" : ""; + unsigned numRepeats = + numRegPerMessage / (atom.rowsPerThread * atom.colsPerThread); + std::string opcode = "tcgen05.ld.sync.aligned." + std::string(atom.opShape) + + ".x" + std::to_string(numRepeats) + packedStr + ".b32 {"; + + SmallVector operands; + for (int i = 0; i < numRegPerMessage; i++) { + opcode += "$" + std::to_string(i); + auto *resultOp = ptxBuilder.newOperand("=r"); + operands.push_back(resultOp); + if (i < numRegPerMessage - 1) + opcode += ", "; + } + opcode += "}, [$" + std::to_string(numRegPerMessage) + "]"; + if (secondHalfOffset) + opcode += ", " + std::to_string(secondHalfOffset); + opcode += ";"; + operands.push_back(ptxBuilder.newOperand(address, "r")); + auto &ld = *ptxBuilder.create(opcode); + ld(operands, /*onlyAttachMLIRArgs=*/true); + SmallVector elemTypes(numRegPerMessage, i32_ty); + MLIRContext *ctx = op.getContext(); + Type structTy = struct_ty(elemTypes); + Value ret = ptxBuilder.launch(rewriter, loc, structTy); + return ret; +} + +static SmallVector unpackResults(Value packedValues, Type elemTy, + int numCols, Location loc, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector resultVals; + int numElementsPer32B = 32 / elemTy.getIntOrFloatBitWidth(); + Type packedType = elemTy; + if (numElementsPer32B > 1) + packedType = vec_ty(elemTy, numElementsPer32B); + for (int i = 0; i < numCols; i++) { + Value result = b.extract_val(i32_ty, packedValues, i); + result = b.bitcast(result, packedType); + if (numElementsPer32B > 1) { + for (int j = 0; j < numElementsPer32B; j++) { + Value elem = b.extract_element(elemTy, result, b.i32_val(j)); + resultVals.push_back(elem); + } + } else { + resultVals.push_back(result); + } + } + return resultVals; +} + +static void createWaitOpLd(Location loc, ConversionPatternRewriter &rewriter) { + PTXBuilder ptxBuilder; + std::string opcode = "tcgen05.wait::ld.sync.aligned;"; + auto &wait = *ptxBuilder.create(opcode); + wait({}, /*onlyAttachMLIRArgs=*/true); + ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext())); +} + +struct TensorMemoryLoadOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::TMEMLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto llvmElemTy = + getTypeConverter()->convertType(op.getSrc().getType().getElementType()); + auto tmemBase = adaptor.getSrc(); + + auto info = getTMemRuntimeInfo(op, cast(op.getType()), + cast(op.getSrc().getType())); + const TMemMessageTraits message = selectTMemMessage(info); + SmallVector resultVals; + calculateAddressAndEmitTmemMessage( + loc, tmemBase, info, message, rewriter, + [&](Value startAddress, int secondHalfColOffset, bool unpackedb16, + int regsPerMessage, bool useStridedMessage) { + Value packedValues = createTensorMemoryLoad( + loc, op, startAddress, secondHalfColOffset, unpackedb16, + regsPerMessage, message.atom, rewriter); + auto results = + unpackResults(packedValues, op.getType().getElementType(), + regsPerMessage, loc, rewriter); + resultVals.append(results.begin(), results.end()); + }); + Type structTy = getTypeConverter()->convertType(op.getType()); + Value resultStruct = + packLLElements(loc, getTypeConverter(), resultVals, rewriter, structTy); + // Wait insertion could be moved to the TTGIR level if needed. + createWaitOpLd(loc, rewriter); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct TensorMemoryStoreOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::TMEMStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto llvmElemTy = + getTypeConverter()->convertType(op.getDst().getType().getElementType()); + auto tmemBase = adaptor.getDst(); + Value pred = adaptor.getPred(); + lowerStoreToTensorMemory(loc, op, op.getSrc(), op.getDst(), + adaptor.getSrc(), pred, tmemBase, rewriter); + + rewriter.eraseOp(op); + return success(); + } +}; + +static Value +createBlockedScalesSMEMDescriptor(ConversionPatternRewriter &rewriter, + Location loc, Value baseSrc) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + static_assert(sizeof(NVIDIA::SMEMDescriptor) == 8, + "Descriptor size should be 64 bits."); + NVIDIA::SMEMDescriptor desc; + desc.descriptor = 0; + desc.swizzlingMode = 0; // No swizzling for now + desc.leadDimensionBaseOffset = 16 >> 4; // 16 bytes + desc.strideDimensionBaseOffset = 128 >> 4; // 8 x 16 bytes + // See matrix-descriptor-encode(x) function in the ptx doc. + // matrix-descriptor-encode(addr) = (addr & 0x3FFFF) >> 4 + auto smemAddr = b.ptrtoint(i64_ty, baseSrc); + return b.add(b.int_val(64, desc.descriptor), + b.lshr(b.shl(smemAddr, b.int_val(64, 46)), b.int_val(64, 50))); +} + +static void createCommit(ConversionPatternRewriter &rewriter, Location loc, + Value barrier, Value pred) { + PTXBuilder ptxBuilder; + auto *barrierOperand = ptxBuilder.newAddrOperand(barrier, "r"); + std::string opcode = "tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64"; + auto &barrierOp = *ptxBuilder.create(opcode); + barrierOp(barrierOperand).predicate(pred); + ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext())); +} + +static void createTcgen05Cp(ConversionPatternRewriter &rewriter, Location loc, + Value tmem_address, Value src_desc, Value pred) { + PTXBuilder ptxBuilder; + auto dst = ptxBuilder.newAddrOperand(tmem_address, "r"); + auto src = ptxBuilder.newOperand(src_desc, "l"); + std::string opcode = "tcgen05.cp.cta_group::1.warpx4.32x128b"; + auto &op = *ptxBuilder.create(opcode); + op({dst, src}).predicate(pred); + ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext())); +} + +struct TensorMemoryCopyOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::TMEMCopyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcTy = cast(op.getSrc().getType()); + assert(isa(srcTy.getMemorySpace())); + assert(isa(srcTy.getEncoding())); + + auto sharedEnc = + cast(srcTy.getEncoding()); + assert( + sharedEnc.getMaxPhase() == 1 && sharedEnc.getPerPhase() == 1 && + sharedEnc.getVec() == 1 && + "The src SMEM of tmem_copy should not have swizzling applied for now"); + + Value baseSrc = + LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getSrc(), + typeConverter->convertType(srcTy.getElementType()), rewriter) + .getBase(); + + Value baseDst = adaptor.getDst(); + + // The following codegen assumes that we use tcgen05.cp only with + // the warpx4.32x128b mode, to load blocked scales from MXFP. + // We will expand the support as we find more use cases for the instruction. + + Value smemDesc = createBlockedScalesSMEMDescriptor(rewriter, loc, baseSrc); + Value pred = LLVM::NVIDIA::createElectPredicateWarp0(loc, rewriter); + + auto createCopy = [&](int repMorN, int repK) { + for (int i = 0; i < repMorN; ++i) { + for (int j = 0; j < repK; ++j) { + // Multiple copies of 32x128b blocks are laid out along M/N first then + // K + auto colOffset = b.int_val(32, (j * repMorN + i) * 4); + auto tmemAddr = b.add(b.ptrtoint(i32_ty, baseDst), colOffset); + createTcgen05Cp(rewriter, loc, tmemAddr, smemDesc, pred); + smemDesc = + b.add(smemDesc, b.int_val(64, 512 >> 4)); // one chunk = 32x16B + } + } + }; + + // Break up src axes into rep_m x rep_k x 32x128b, where rep_m = BLOCK_M / + // 128 and rep_k = BLOCK_K / 128 32x128b blockes are contiguously laid out + // in SMEM. rep_m * rep_k copies of such blocks are consumed by one + // dot_scaled op for given BLOCK_M / BLOCK_K. Some axes of the scale shape + // can be flattened into one, to reduce the rank of the load. Since rep_m + // blocks are not contiguous in SMEM, we need to identify the original rep_m + // axis from the given input shape. + + // The SMEM shapes are expected to be one of the followings. As long as + // rep_m and rep_k can be identified correctly, other patterns are allowed. + // * (rep_m x 32, 16B), meant only for TMEMCopy unit tests + // * (rep_m, rep_k * 32 x 4 x 4B), 2D scale load with cp.async + // * (rep_m, rep_k, 32, 16B), 4D scale load with TMA + // * (rep_m, rep_k, 32, 4, 4B), 5D scale load with cp.async + auto elemBits = srcTy.getElementType().getIntOrFloatBitWidth(); + int prodInner = 1; + int repMorN = 1; + int repK = 1; + + for (int i = srcTy.getRank() - 1; i >= 0; --i) { + prodInner *= srcTy.getDimSize(i); + if (prodInner * elemBits >= 32 * 128) { + if (i == 0) { + repMorN = prodInner * elemBits / (32 * 128); + repK = 1; + } else if (i == 1) { + repMorN = srcTy.getDimSize(0); + repK = prodInner * elemBits / (32 * 128); + } else { + repMorN = srcTy.getDimSize(0); + repK = srcTy.getDimSize(1); + } + break; + } + } + + createCopy(repMorN, repK); + + if (op.getBarrier()) { + auto barrier = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getBarrier(), i64_ty, rewriter); + createCommit(rewriter, loc, barrier.getBase(), pred); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct MemDescSubviewOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::MemDescSubviewOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getResult().getType(); + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + + if (!isa( + srcTy.getEncoding())) { + return failure(); + } + + // newBase = base + offset + auto tmemBase = adaptor.getSrc(); + SmallVector opOffsetVals = op.getOffsets(); + size_t destRank = op.getResult().getType().getRank(); + SmallVector offsetVals; + int rankReduced = srcTy.getRank() - destRank; + for (int i = rankReduced; i < opOffsetVals.size(); i++) { + offsetVals.push_back(opOffsetVals[i]); + } + + triton::nvidia_gpu::TMemAllocation tmemAlloc = + triton::nvidia_gpu::getTmemAllocSizes(cast(dstTy)); + int numColOffset = tmemAlloc.numCols; + Value newBase = b.ptrtoint(rewriter.getI32Type(), tmemBase); + newBase = rewriter.create( + loc, newBase, + rewriter.create(loc, opOffsetVals[0], + b.i32_val(numColOffset))); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + rewriter.replaceOp(op, b.inttoptr(elemPtrTy, newBase)); + return success(); + } +}; + +} // namespace + +void mlir::triton::NVIDIA::populateTensorMemoryOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add( + typeConverter, benefit); + return; +} + +void mlir::triton::NVIDIA::populateTensorMemorySubviewOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + return; +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorPtrOpsToLLVM.cpp new file mode 100644 index 000000000..97e830967 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { +struct MakeTensorPtrOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // struct { offset0, offset1, shape0, shape1, stride0, + // stride1, base_ptr}; + auto offsets = adaptor.getOffsets(); + auto shapes = adaptor.getShape(); + auto strides = adaptor.getStrides(); + auto base = adaptor.getBase(); + auto result = op.getResult(); + + SmallVector elems; + for (auto offset : offsets) + elems.push_back(offset); + for (auto shape : shapes) + elems.push_back(shape); + for (auto stride : strides) + elems.push_back(stride); + + elems.push_back(base); + + auto newValue = packLLElements(op.getLoc(), getTypeConverter(), elems, + rewriter, result.getType()); + rewriter.replaceOp(op, newValue); + return success(); + } +}; + +struct AdvanceOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // struct { offset0, offset1, shape0, shape1, stride0, + // stride1, base_ptr}; + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ptrType = op.getPtr().getType(); + auto tensorPtr = adaptor.getPtr(); + + auto offsets = adaptor.getOffsets(); + auto elems = unpackLLElements(loc, tensorPtr, rewriter); + + SmallVector newOffsets; + + for (auto [offset, oldOffset] : llvm::zip_first(offsets, elems)) { + newOffsets.push_back((b.add(offset, oldOffset))); + } + + for (size_t i = 0; i < newOffsets.size(); ++i) { + elems[i] = newOffsets[i]; + } + + auto newValue = packLLElements(op.getLoc(), getTypeConverter(), elems, + rewriter, ptrType); + rewriter.replaceOp(op, newValue); + return success(); + } +}; +} // namespace + +void mlir::triton::NVIDIA::populateTensorPtrOpsToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + return; +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp new file mode 100644 index 000000000..3bba6d473 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -0,0 +1,274 @@ +#include "Dialect/NVGPU/IR/Dialect.h" +#include "TritonNVIDIAGPUToLLVM/Passes.h" +#include "TritonNVIDIAGPUToLLVM/Utility.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + +#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTTRITONGPUTOLLVM +#include "TritonNVIDIAGPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton::NVIDIA; + +namespace { + +class TritonLLVMFunctionConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + } +}; + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addLegalOp(); + + // Warp specialization is lowered later. + addLegalOp(); + addLegalOp(); + addLegalOp(); + addLegalOp(); + } +}; + +struct ConvertTritonGPUToLLVM + : public triton::impl::ConvertTritonGPUToLLVMBase { + using ConvertTritonGPUToLLVMBase::ConvertTritonGPUToLLVMBase; + + ConvertTritonGPUToLLVM(int32_t computeCapability) + : ConvertTritonGPUToLLVMBase({computeCapability}) {} + ConvertTritonGPUToLLVM(int32_t computeCapability, int32_t ptxVersion) + : ConvertTritonGPUToLLVMBase({computeCapability, ptxVersion}) {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + TargetInfo targetInfo(computeCapability, ptxVersion); + + // Allocate shared memory and set barrier + ModuleAllocation allocation(mod); + ModuleMembarAnalysis membarPass(&allocation); + membarPass.run(); + + mlir::LowerToLLVMOptions option(context); + option.overrideIndexBitwidth(32); + TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); + + // Lower functions + TritonLLVMFunctionConversionTarget funcTarget(*context); + RewritePatternSet funcPatterns(context); + mlir::triton::populateFuncOpConversionPattern( + typeConverter, funcPatterns, targetInfo, patternBenefitDefault); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) + return signalPassFailure(); + + // initSharedMemory is run before the conversion of call and ret ops, + // because the call op has to know the shared memory base address of each + // function + initSharedMemory(typeConverter); + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + + RewritePatternSet patterns(context); + int benefit = patternBenefitPrioritizeOverLLVMConversions; + mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMPatterns( + typeConverter, targetInfo, patterns, benefit); + mlir::triton::NVIDIA::populateTensorMemorySubviewOpToLLVMPattern( + typeConverter, patterns, patternBenefitNvidiaTensorCoreSubviewPattern); + mlir::triton::NVIDIA::populateTMAToLLVMPatterns(typeConverter, targetInfo, + patterns, benefit); + populateDotOpToLLVMPatterns(typeConverter, patterns, benefit); + populateElementwiseOpToLLVMPatterns(typeConverter, patterns, + axisInfoAnalysis, computeCapability, + targetInfo, benefit); + populateClampFOpToLLVMPattern(typeConverter, patterns, axisInfoAnalysis, + computeCapability, + patternBenefitClampOptimizedPattern); + populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns, + axisInfoAnalysis, benefit); + mlir::triton::populateReduceOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateScanOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateGatherOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + populateBarrierOpToLLVMPatterns(typeConverter, patterns, benefit); + populateTensorPtrOpsToLLVMPatterns(typeConverter, patterns, benefit); + populateClusterOpsToLLVMPatterns(typeConverter, patterns, benefit); + mlir::triton::populateHistogramOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + // TODO(thomas): this should probably be done in a separate step to not + // interfere with our own lowering of arith ops. Add arith/math's patterns + // to help convert scalar expression to LLVM. + mlir::arith::populateCeilFloorDivExpandOpsPatterns(patterns); + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + patterns); + mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); + mlir::triton::populateViewOpToLLVMPatterns(typeConverter, patterns, + benefit); + mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::NVIDIA::populateMemoryOpToLLVMPatterns( + typeConverter, targetInfo, patterns, benefit); + mlir::triton::NVIDIA::populateTensorMemoryOpToLLVMPattern( + typeConverter, patterns, benefit); + mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, + patterns, benefit); + mlir::triton::NVIDIA::populateTCGen5MMAOpToLLVMPattern(typeConverter, + patterns, benefit); + mlir::triton::NVIDIA::populateFp4ToFpToLLVMPatterns(typeConverter, patterns, + benefit); + mlir::triton::populateRegReallocOpToLLVMPatterns(typeConverter, patterns, + benefit); + TritonLLVMConversionTarget convTarget(*context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + + // Fold CTAId when there is only 1 CTA. + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + if (numCTAs == 1) { + mod.walk([](triton::nvgpu::ClusterCTAIdOp id) { + OpBuilder b(id); + Value zero = LLVM::createConstantI32(id->getLoc(), b, 0); + id.replaceAllUsesWith(zero); + }); + } + } + +private: + void initSharedMemory(LLVMTypeConverter &typeConverter) { + ModuleOp mod = getOperation(); + OpBuilder b(mod.getBodyRegion()); + auto loc = mod.getLoc(); + auto elemTy = typeConverter.convertType(b.getIntegerType(8)); + // Set array size 0 and external linkage indicates that we use dynamic + // shared allocation to allow a larger shared memory size for each kernel. + // + // Ask for 16B alignment on global_smem because that's the largest we should + // ever need (4xi32). + auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0); + b.create( + loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External, + "global_smem", /*value=*/Attribute(), /*alignment=*/16, + // Add ROCm support. + static_cast(NVVM::NVVMMemorySpace::kSharedMemorySpace)); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { + +std::unique_ptr> createConvertTritonGPUToLLVMPass() { + return std::make_unique(); +} +std::unique_ptr> +createConvertTritonGPUToLLVMPass(int32_t computeCapability) { + return std::make_unique(computeCapability); +} +std::unique_ptr> +createConvertTritonGPUToLLVMPass(int32_t computeCapability, + int32_t ptxVersion) { + return std::make_unique(computeCapability, + ptxVersion); +} + +bool NVIDIA::canSkipBarSync(Operation *before, Operation *after) { + // Multiple init barriers on the same allocation would usually not happen but + // that allows us to avoid barriers between multiple subslice of an array of + // mbarriers. This is still correct even if the inits happen on the same + // allocation. + if (isa(before) && + isa(after)) + return true; + + if (isa(before) && + isa(after)) + return true; + + // We can't have a warp get ahead when we have a chain of mbarrier wait so we + // need a barrier in between two WaitBarrierOp. + if (isa(before) && + isa(after)) + return false; + + // Even though WaitBarrierOp, AsyncTMACopyGlobalToLocalOp and + // AsyncTMACopyGlobalToLocalOp read and write to the mbarrier allocation it is + // valid for them to happen in different order on different threads, therefore + // we don't need a barrier between those operations. + if (isa(before) && + isa(after)) + return true; + + // A mbarrier wait is released only when the whole operations is done, + // therefore any thread can access the memory after the barrier even if some + // threads haven't reached the mbarrier wait. + if (isa(before) && + !isa(after)) + return true; + + return false; +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp new file mode 100644 index 000000000..d22ef12e7 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp @@ -0,0 +1,137 @@ +#include "Utility.h" +#include "Dialect/NVGPU/IR/Dialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir { +namespace LLVM { +namespace NVIDIA { +using namespace mlir::triton; + +static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter, Value val, + Value i, NVVM::ShflKind mode, Value clamp) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned bits = val.getType().getIntOrFloatBitWidth(); + + if (bits == 64) { + Type vecTy = vec_ty(f32_ty, 2); + Value vec = b.bitcast(val, vecTy); + Value val0 = b.extract_element(f32_ty, vec, b.i32_val(0)); + Value val1 = b.extract_element(f32_ty, vec, b.i32_val(1)); + val0 = shuffleCommonImpl(loc, rewriter, val0, i, mode, clamp); + val1 = shuffleCommonImpl(loc, rewriter, val1, i, mode, clamp); + vec = b.undef(vecTy); + vec = b.insert_element(vecTy, vec, val0, b.i32_val(0)); + vec = b.insert_element(vecTy, vec, val1, b.i32_val(1)); + return b.bitcast(vec, val.getType()); + } + Type type = val.getType(); + if (type != i32_ty) { + val = b.bitcast(val, int_ty(bits)); + if (bits < 32) + val = b.zext(i32_ty, val); + } + Value mask = b.i32_val(0xFFFFFFFF); + Value result = rewriter.create(loc, i32_ty, mask, val, i, clamp, + mode, UnitAttr()); + if (type != i32_ty) { + if (bits < 32) + result = b.trunc(int_ty(bits), result); + result = b.bitcast(result, type); + } + return result; +} + +static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, + Value i, NVVM::ShflKind mode, Value clamp) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + // To shuffle pointers, convert them to i64. + Type valTy = val.getType(); + if (isa(valTy)) + val = b.ptrtoint(i64_ty, val); + Value result = shuffleCommonImpl(loc, rewriter, val, i, mode, clamp); + if (isa(valTy)) + result = b.inttoptr(valTy, result); + return result; +} + +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, val, b.i32_val(i), NVVM::ShflKind::bfly, + b.i32_val(0x1f)); +} + +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, val, b.i32_val(i), NVVM::ShflKind::up, + b.i32_val(0x0)); +} + +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleIdx(loc, rewriter, val, b.i32_val(i)); +} + +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + return shuffleCommon(loc, rewriter, val, i, NVVM::ShflKind::idx, + b.i32_val(0x1f)); +} + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis) { + assert(axis >= 0); + assert(axis < 3); + assert(moduleOp); + + // It is not easy to get the compute capability here, so we use numCTAs to + // decide the semantic of GetProgramIdOp. If numCTAs = 1, then + // GetProgramIdOp is converted to "%ctaid", otherwise it is converted to + // "%clusterid". + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + + std::string sreg = numCTAs == 1 ? "ctaid." : "clusterid."; + sreg.append(1, 'x' + axis); // 0 -> 'x', 1 -> 'y', 2 -> 'z' + return getSRegValue(rewriter, loc, sreg); +} + +Value getSRegValue(OpBuilder &rewriter, Location loc, StringRef sRegStr) { + ValueRange args; + auto intrName = Twine("llvm.nvvm.read.ptx.sreg.") + sRegStr; + auto callOp = + createLLVMIntrinsicCallOp(rewriter, loc, intrName.str(), i32_ty, args); + return callOp.getResult(0); +} + +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value mask) { + Value args[] = {a, b, mask}; + auto op = + createLLVMIntrinsicCallOp(rewriter, loc, "llvm.nvvm.prmt", i32_ty, args); + return op.getResult(0); +} + +/// Create a predicate with just single active thread. +Value createElectPredicate(Location loc, RewriterBase &rewriter) { + return rewriter.create(loc, i1_ty); +} + +void createSyncWarp(Location loc, OpBuilder &rewriter) { + TritonLLVMOpBuilder b(loc, rewriter); + Type resultTy = void_ty(rewriter.getContext()); + Value args[] = {b.i32_val(0xffffffff)}; + createLLVMIntrinsicCallOp(rewriter, loc, "llvm.nvvm.bar.warp.sync", resultTy, + args); +} + +Value createElectPredicateWarp0(Location loc, RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value threadId = getThreadId(rewriter, loc); + Value warp0 = b.icmp_ult(threadId, b.i32_val(32)); + return b.and_(warp0, createElectPredicate(loc, rewriter)); +} + +} // namespace NVIDIA +} // namespace LLVM +} // namespace mlir diff --git a/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h new file mode 100644 index 000000000..078af5efa --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h @@ -0,0 +1,51 @@ +#ifndef TRITON_CONVERSION_TRITONNVIDIAGPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONNVIDIAGPU_TO_LLVM_UTILITY_H + +#include "nvidia/include/TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#define DEBUG_TYPE "ttgpu_to_llvm" + +using namespace mlir; +using namespace mlir::triton; + +// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive +// Operators + +namespace mlir { +namespace LLVM { + +namespace NVIDIA { + +Value getSRegValue(OpBuilder &b, Location loc, StringRef sRegStr); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value mask); + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis); + +/// Create a predicate with just single active thread. +Value createElectPredicate(Location loc, RewriterBase &rewriter); +Value createElectPredicateWarp0(Location loc, RewriterBase &rewriter); + +// Create bar.warp.sync +void createSyncWarp(Location loc, OpBuilder &builder); + +} // namespace NVIDIA +} // namespace LLVM + +} // namespace mlir + +#endif diff --git a/third_party/enflame/include/triton/third_party/nvidia/tools/cuda/compile.c b/third_party/enflame/include/triton/third_party/nvidia/tools/cuda/compile.c new file mode 100644 index 000000000..24b369354 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/tools/cuda/compile.c @@ -0,0 +1,68 @@ +/* clang-format off */ +#include +#include +#include +#include +#include + + +// helpers to check for cuda errors +#define CUDA_CHECK(ans) {{\ + gpuAssert((ans), __FILE__, __LINE__);\ + }}\ + +static inline void gpuAssert(CUresult code, const char *file, int line) {{ + if (code != CUDA_SUCCESS) {{ + const char *prefix = "Triton Error [CUDA]: "; + const char *str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + printf("%s\\n", err); + exit(code); + }} +}} + +// globals +#define CUBIN_NAME {kernel_name}_cubin +CUmodule {kernel_name}_mod = NULL; +CUfunction {kernel_name}_func = NULL; +unsigned char CUBIN_NAME[{bin_size}] = {{ {bin_data} }}; + + +void unload_{kernel_name}(void) {{ + CUDA_CHECK(cuModuleUnload({kernel_name}_mod)); +}} + +// TODO: some code duplication with `runtime/backend/cuda.c` +void load_{kernel_name}() {{ + int dev = 0; + void *bin = (void *)&CUBIN_NAME; + int shared = {shared}; + CUDA_CHECK(cuModuleLoadData(&{kernel_name}_mod, bin)); + CUDA_CHECK(cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}")); + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev)); + if (shared > 49152 && shared_optin > 49152) {{ + CUDA_CHECK(cuFuncSetCacheConfig({kernel_name}_func, CU_FUNC_CACHE_PREFER_SHARED)); + CUDA_CHECK(cuFuncSetAttribute({kernel_name}_func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin)) + }} +}} + +/* +{kernel_docstring} +*/ +CUresult {kernel_name}(CUstream stream, {signature}) {{ + if ({kernel_name}_func == NULL) + load_{kernel_name}(); + unsigned int gX = {gridX}; + unsigned int gY = {gridY}; + unsigned int gZ = {gridZ}; + CUdeviceptr global_scratch = 0; + void *args[{num_args}] = {{ {arg_pointers} }}; + // TODO: shared memory + if(gX * gY * gZ > 0) + return cuLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * 32, 1, 1, {shared}, stream, args, NULL); +}} diff --git a/third_party/enflame/include/triton/third_party/nvidia/tools/cuda/compile.h b/third_party/enflame/include/triton/third_party/nvidia/tools/cuda/compile.h new file mode 100644 index 000000000..d98b7063b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/tools/cuda/compile.h @@ -0,0 +1,14 @@ +#ifndef TT_KERNEL_INCLUDES +#define TT_KERNEL_INCLUDES + +#include +#include +#include +#include + +#endif + +void unload_{kernel_name}(void); +void load_{kernel_name}(void); +// tt-linker: {kernel_name}:{full_signature}:{algo_info} +CUresult{_placeholder} {kernel_name}(CUstream stream, {signature}); diff --git a/third_party/enflame/include/triton/third_party/nvidia/triton_nvidia.cc b/third_party/enflame/include/triton/third_party/nvidia/triton_nvidia.cc new file mode 100644 index 000000000..0b280f789 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/triton_nvidia.cc @@ -0,0 +1,170 @@ +#include "Dialect/NVGPU/IR/Dialect.h" +#include "NVGPUToLLVM/NVGPUToLLVMPass.h" +#include "TritonNVIDIAGPUToLLVM/Passes.h" +#include "cublas_instance.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "passes.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "llvm/IR/Constants.h" +#include +#include +#include + +namespace py = pybind11; + +void init_triton_nvidia_passes_ttgpuir(py::module &&m) { + using namespace mlir::triton; + // TODO: it is weird to pass mlir::triton::NVVM here since the conversion is + // nvidia-specificontext + m.def("add_to_llvmir", + [](mlir::PassManager &pm, int32_t capability, int32_t ptxVersion) { + pm.addPass(mlir::triton::createConvertTritonGPUToLLVMPass( + capability, ptxVersion)); + }); +} + +void init_triton_nvidia_passes_ttnvgpuir(py::module &&m) { + ADD_PASS_WRAPPER_1("add_plan_cta", mlir::createTritonNvidiaGPUPlanCTAPass, + mlir::triton::nvidia_gpu::ClusterInfo *); + ADD_PASS_WRAPPER_0("add_fence_insertion", + mlir::createTritonNvidiaGPUFenceInsertionPass); + ADD_PASS_WRAPPER_0("add_tma_lowering", + mlir::createTritonNvidiaGPUTMALoweringPass); + ADD_PASS_WRAPPER_0("add_keep_acc_in_tmem", + mlir::createTritonNvidiaGPUKeepAccInTMemPass); + ADD_PASS_WRAPPER_0("add_promote_lhs_to_tmem", + mlir::createTritonNvidiaGPUPromoteLHSToTMemPass); + ADD_PASS_WRAPPER_0("add_nvgpu_to_llvm", + mlir::triton::createConvertNVGPUToLLVMPass); + ADD_PASS_WRAPPER_0("add_warp_specialize_to_llvm", + mlir::triton::createConvertWarpSpecializeToLLVM); + ADD_PASS_WRAPPER_0("add_allocate_tensor_memory", + mlir::createTensorMemoryAllocationPass); + ADD_PASS_WRAPPER_0("add_lower_mma", + mlir::createTritonNvidiaGPUMMALoweringPass); +} + +void init_triton_nvidia(py::module &&m) { + auto passes = m.def_submodule("passes"); + init_triton_nvidia_passes_ttgpuir(passes.def_submodule("ttgpuir")); + init_triton_nvidia_passes_ttnvgpuir(passes.def_submodule("ttnvgpuir")); + + // cluster info + py::class_(m, "ClusterInfo") + .def(py::init<>()) + .def_readwrite("clusterDimX", + &mlir::triton::nvidia_gpu::ClusterInfo::clusterDimX) + .def_readwrite("clusterDimY", + &mlir::triton::nvidia_gpu::ClusterInfo::clusterDimY) + .def_readwrite("clusterDimZ", + &mlir::triton::nvidia_gpu::ClusterInfo::clusterDimZ) + .def("__repr__", [](mlir::triton::nvidia_gpu::ClusterInfo &self) { + std::ostringstream oss; + oss << "(" << self.clusterDimX << ", " << self.clusterDimY << ", " + << self.clusterDimZ << ")"; + return oss.str(); + }); + + // load dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + mlir::registerNVVMDialectTranslation(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + // TODO: could be done in python if we had a generic interface to set metadata + m.def("set_nvvm_reflect_ftz", [](llvm::Module *mod) { + // please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters + // this will enable fast math path in libdevice + // for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to + // sqrt.approx.ftz.f32 + using namespace llvm; + auto &ctx = mod->getContext(); + Type *i32 = Type::getInt32Ty(ctx); + auto *mdFour = ConstantAsMetadata::get(ConstantInt::getSigned(i32, 4)); + auto *mdName = MDString::get(ctx, "nvvm-reflect-ftz"); + auto *mdOne = ConstantAsMetadata::get(ConstantInt::getSigned(i32, 1)); + auto *reflect = MDNode::get(ctx, {mdFour, mdName, mdOne}); + mod->addModuleFlag(reflect); + }); + + // cublas + auto cublas = m.def_submodule("cublas"); + + py::class_(cublas, "CublasLt") + .def(py::init<>([&](py::object &workspace) { + auto wrk_ptr = workspace.attr("data_ptr")().cast(); + auto wrk_size = workspace.attr("numel")().cast() * + workspace.attr("element_size")().cast(); + return new CublasLtInstance(wrk_ptr, wrk_size); + })) + .def("matmul", [](CublasLtInstance &self, py::object &A, py::object &B, + py::object &C) { + auto A_ptr = A.attr("data_ptr")().cast(); + auto B_ptr = B.attr("data_ptr")().cast(); + auto C_ptr = C.attr("data_ptr")().cast(); + + auto A_shape = A.attr("shape").cast>(); + auto B_shape = B.attr("shape").cast>(); + auto C_shape = C.attr("shape").cast>(); + + auto A_dtype = A.attr("dtype").attr("__str__")().cast(); + auto B_dtype = B.attr("dtype").attr("__str__")().cast(); + auto C_dtype = C.attr("dtype").attr("__str__")().cast(); + + assert(A_dtype == B_dtype && A_dtype == C_dtype); + assert(A_dtype == "torch.float8_e4m3fn" || A_dtype == "torch.float16"); + + std::string dtype_str = A_dtype.substr(A_dtype.find_last_of('.') + 1); + cudaDataType_t dtype; + if (dtype_str == "float8_e4m3fn") { + dtype = CUDA_R_8F_E4M3; + } else if (dtype_str == "float16") { + dtype = CUDA_R_16F; + } + + if (A_shape.size() != 2 || B_shape.size() != 2 || C_shape.size() != 2) { + throw std::runtime_error("Only 2D matrices are supported."); + } + + int k = A_shape[1]; + if (k != B_shape[1]) { + throw std::runtime_error("Matrix dimensions do not match. A is [" + + std::to_string(A_shape[0]) + ", " + + std::to_string(A_shape[1]) + "], B is [" + + std::to_string(B_shape[0]) + ", " + + std::to_string(B_shape[1]) + + "]. Expected A.shape[1] == B.shape[1]. Note " + "that B needs to be transposed."); + } + + int m = A_shape[0]; + if (m != C_shape[0]) { + throw std::runtime_error("Matrix dimensions do not match. A is [" + + std::to_string(A_shape[0]) + ", " + + std::to_string(A_shape[1]) + "], C is [" + + std::to_string(C_shape[0]) + ", " + + std::to_string(C_shape[1]) + + "]. Expected A.shape[0] == C.shape[0]."); + } + + int n = B_shape[0]; + if (n != C_shape[1]) { + throw std::runtime_error("Matrix dimensions do not match. B is [" + + std::to_string(B_shape[0]) + ", " + + std::to_string(B_shape[1]) + "], C is [" + + std::to_string(C_shape[0]) + ", " + + std::to_string(C_shape[1]) + + "]. Expected B.shape[0] == C.shape[1]. Note " + "that B needs to be transposed."); + } + + self.matmul(A_shape[0], B_shape[0], A_shape[1], A_ptr, B_ptr, C_ptr, + dtype); + }); +} diff --git a/third_party/enflame/include/triton/third_party/nvidia/unittest/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/unittest/CMakeLists.txt new file mode 100644 index 000000000..bd3c0c6c0 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/unittest/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Conversion) diff --git a/third_party/enflame/include/triton/third_party/nvidia/unittest/Conversion/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/unittest/Conversion/CMakeLists.txt new file mode 100644 index 000000000..b543b6c62 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/unittest/Conversion/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonGPUToLLVM) diff --git a/third_party/enflame/include/triton/third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/enflame/include/triton/third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..5d2dbbb0b --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,5 @@ +add_triton_ut( + NAME TestPtxAsmFormat + SRCS PTXAsmFormatTest.cpp + LIBS TritonGPUToLLVM TritonNVIDIAGPUToLLVM +) diff --git a/third_party/enflame/include/triton/third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/PTXAsmFormatTest.cpp b/third_party/enflame/include/triton/third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/PTXAsmFormatTest.cpp new file mode 100644 index 000000000..4fd6cefbd --- /dev/null +++ b/third_party/enflame/include/triton/third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/PTXAsmFormatTest.cpp @@ -0,0 +1,154 @@ +#include "nvidia/include/TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/Support/Signals.h" + +#include + +namespace mlir { +namespace triton { +class PTXAsmFormatTest : public ::testing::Test { +protected: + static constexpr int numValues = 4; + + PTXAsmFormatTest() { + ctx.loadDialect(); + + createValues(); + } + + // Creates the test values. + void createValues() { + OpBuilder builder(&ctx); + builder.setInsertionPointToStart(&block); + + // a b1 value for predicate. + v[0] = builder.create(builder.getUnknownLoc(), 1, 1); + for (int i = 0; i < numValues; i++) { + v[i + 1] = + builder.create(builder.getUnknownLoc(), i, 32); + } + } + + MLIRContext ctx; + Block block; + Value v[numValues + 1]; +}; + +TEST_F(PTXAsmFormatTest, basic) { + PTXBuilder builder; + + // Create the operands needed by the instructions in the PTX code. + auto *cst = builder.newConstantOperand(1); + auto *val = builder.newOperand(v[1], "=r"); + + // create an instruction + auto &mov = *builder.create("mov.b16"); + + mov(val, cst).predicate(v[0]); + ASSERT_EQ(builder.dump(), "@$1 mov.b16 $0, 0x1;"); + + auto values = builder.getAllMLIRArgs(); + ASSERT_EQ(values[0], v[1]); // $0 -> v[1] + ASSERT_EQ(values[1], v[0]); // $1 -> v[0] + + auto constraints = builder.getConstraints(); + ASSERT_EQ(constraints, "=r,b"); // $0 -> =r, $1 -> b +} + +TEST_F(PTXAsmFormatTest, complexInstruction) { + using triton::CacheModifier; + using triton::EvictionPolicy; + + PTXBuilder builder; + + int width = 16; + int nWords = 2; + + Value predicateVal = v[0]; + Value addrVal = v[1]; + + auto addr = builder.newAddrOperand(addrVal, "l", 128 /*offset*/); + + bool isVolatile = false; + auto cache = triton::CacheModifier::CA; + auto cachePriority = triton::EvictionPolicy::EVICT_FIRST; + bool hasL2EvictPolicy = true; + + auto &ld = + builder + .create<>("ld") // + ->o("volatile", isVolatile) + .global() + .o("ca", cache == CacheModifier::CA) + .o("cg", cache == CacheModifier::CG) + .o("L1::evict_first", cachePriority == EvictionPolicy::EVICT_FIRST) + .o("L1::evict_last", cachePriority == EvictionPolicy::EVICT_LAST) + .o("L1::cache_hint", hasL2EvictPolicy) + .v(nWords) + .b(width); + + // Link the instruction to operands + ld(addr).predicate(predicateVal); + + EXPECT_EQ( + builder.dump(), + "@$1 ld.global.ca.L1::evict_first.L1::cache_hint.v2.b16 [ $0 + 128 ];"); + auto values = builder.getAllMLIRArgs(); + EXPECT_EQ(values[0], addrVal); // $0 -> predicate + EXPECT_EQ(values[1], predicateVal); // $1 -> addr + EXPECT_EQ(builder.getConstraints(), "l,b"); +} + +TEST_F(PTXAsmFormatTest, MultiLinePTX) { + PTXBuilder builder; + + auto *constVal = builder.newConstantOperand(1); + auto *valVal0 = builder.newOperand(v[1], "=r"); + auto *valVal1 = builder.newOperand(v[2], "=r"); + + auto &mov = *builder.create("mov"); + + mov(valVal0, constVal); + mov(valVal1, constVal); + mov(valVal1, valVal0); + + EXPECT_EQ(builder.dump(), "mov $0, 0x1;\n\t" + "mov $1, 0x1;\n\t" + "mov $1, $0;"); + + auto values = builder.getAllMLIRArgs(); + EXPECT_EQ(values[0], v[1]); // $0 -> v[1] + EXPECT_EQ(values[1], v[2]); // $1 -> v[2] +} + +TEST_F(PTXAsmFormatTest, onlyAttachMLIRArgs) { + PTXBuilder builder; + const char *ptxCode = + ".param .b64 param0;\n" // prepare param0 (format string) + "st.param.b64 [param0], %0;\n" + "st.param.b64 [param0], %1;\n" + "st.param.b64 [param0], %2;\n"; + + auto &ptxSnippet = *builder.create(ptxCode); + auto *opr0 = builder.newOperand(v[0], "r"); + auto *opr1 = builder.newOperand(v[1], "r"); + auto *opr2 = builder.newOperand(v[2], "r"); + ptxSnippet({opr1, opr2, opr0}, true); + + EXPECT_EQ(builder.dump(), ptxCode); + ASSERT_EQ(builder.getAllMLIRArgs()[0], v[1]); + ASSERT_EQ(builder.getAllMLIRArgs()[1], v[2]); + ASSERT_EQ(builder.getAllMLIRArgs()[2], v[0]); + ASSERT_EQ(builder.getAllMLIRArgs().size(), 3); +} + +} // namespace triton +} // namespace mlir + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/enflame/include/triton/third_party/proton/.gitignore b/third_party/enflame/include/triton/third_party/proton/.gitignore new file mode 100644 index 000000000..a44d8282a --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/.gitignore @@ -0,0 +1,5 @@ +build/ +proton.egg-info +proton/_C/libproton.so + +*.hatchet diff --git a/third_party/enflame/include/triton/third_party/proton/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/CMakeLists.txt new file mode 100644 index 000000000..7dfbe35f1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/CMakeLists.txt @@ -0,0 +1,78 @@ +project(Proton LANGUAGES CXX) + +set(PROTON_SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}/csrc") + +# ============ Check for includes ============= +if(NOT CUPTI_INCLUDE_DIR) + message(FATAL_ERROR "CUPTI include directory not defined") +endif() +if(NOT ROCTRACER_INCLUDE_DIR) + message(FATAL_ERROR "ROCTRACER include directory not defined") +endif() +if(NOT JSON_INCLUDE_DIR) + message(FATAL_ERROR "JSON include directory not defined") +endif() + +# ============ Dependencies ============= +find_package(Python3 REQUIRED Interpreter Development.Module) +find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") + +# ============ Define a GLOBAL property to store object-libraries ============ +set_property(GLOBAL PROPERTY PROTON_LIBS "") + +# ============ Define a function to create object libraries ============ +function(add_proton_library name) + add_library(${name} OBJECT ${ARGN}) + + target_link_libraries(${name} PRIVATE Python3::Module pybind11::headers) + + # Use system to skip warnings caused by legacy clang compilers + target_include_directories(${name} + SYSTEM PRIVATE + "${ROCTRACER_INCLUDE_DIR}" + ) + + target_include_directories(${name} + PRIVATE + "${CUPTI_INCLUDE_DIR}" + "${JSON_INCLUDE_DIR}" + "${PROTON_SRC_DIR}/include" + ) + + # If HIP is AMD-based + target_compile_definitions(${name} PRIVATE __HIP_PLATFORM_AMD__) + + # Append this library name to the GLOBAL property "PROTON_LIBS" + set_property(GLOBAL APPEND PROPERTY PROTON_LIBS ${name}) +endfunction() + +# ============ Add subdirectory with actual code that calls add_proton_library ============ +add_subdirectory("${PROTON_SRC_DIR}") + +# ============ Possibly handle macOS specifics ============ +if(APPLE) + set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") + # Other platforms build with -flto, but we found that this adds significant overhead to our macos CI without providing a major benefit. + set(PROTON_PYTHON_LDFLAGS "-undefined dynamic_lookup") +endif() + +# ============ Collect all object libraries from property and build final shared lib ============ +get_property(_proton_obj_libs GLOBAL PROPERTY PROTON_LIBS) + +if(NOT _proton_obj_libs) + message(WARNING "No object libraries were defined in 'PROTON_LIBS'!") +endif() + +set(_proton_obj_sources "") +foreach(_lib IN LISTS _proton_obj_libs) + list(APPEND _proton_obj_sources $) + message(STATUS "Collecting object files from ${_lib}") +endforeach() + +add_library(proton SHARED ${_proton_obj_sources}) + +target_link_libraries(proton PRIVATE Python3::Module) +# Apply any macOS linker flags or extra link options +if(PROTON_PYTHON_LDFLAGS) + target_link_options(proton PRIVATE ${PROTON_PYTHON_LDFLAGS}) +endif() diff --git a/third_party/enflame/include/triton/third_party/proton/README.md b/third_party/enflame/include/triton/third_party/proton/README.md new file mode 100644 index 000000000..bf9acdee6 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/README.md @@ -0,0 +1,319 @@ +# Proton - A Profiler for Triton + +## Introduction + +Proton is a lightweight profiler for Triton, designed to be used for code written in Python and to invoke underlying GPU kernels. Proton provides insightful information about the program context, metadata, and hardware performance metrics of the GPU kernels invoked. + +## Installation + +The following command installs the latest version of Proton. + +```bash +git clone https://github.com/triton-lang/triton +cd triton/python +pip install . +``` + +To **not build** Proton, you can set the `TRITON_BUILD_PROTON` environment variable to `OFF`: + +```bash +TRITON_BUILD_PROTON=OFF pip install . +``` + +## Usage + +### Basic usage + +More examples can be found in the [tutorials](tutorials) directory. + +Proton can be used to profile *functions* and *regions* in Python code. + +- The following examples demonstrate how to use Proton to profile a simple Python function. + +```python +import triton.profiler as proton + +# name: The path to the profile data +# context: The method used to annotate the context of each GPU kernel. Currently, "shadow" and "python" are supported. +session_id = proton.profile(func, name="profile_name", context="python")(args) +``` + +- The following examples demonstrate how to use Proton to profile a region in Python code. + +```python +session_id = proton.start(name="profile_name", context="python") +... +# Skip a region +proton.deactivate(session_id) +... +# Restart profiling +proton.activate(session_id) +... +# Write out the profile data and finalize the profiler +proton.finalize() +``` + +### Scope + +Unlike the *python* context that provide users with files, functions, and lines where the GPU kernels are invoked, the *shadow* context provides users with the annotated regions in the code. The following example demonstrates how to use the *shadow* context. + +```python +import triton.profiler as proton + + +session_id = proton.start(name="profile_name", context="shadow") + +with proton.scope("test0"): + with proton.scope("test1"): + foo[1,](x, y) +with proton.scope("test2"): + foo[1,](x, y) + +... +proton.finalize() +``` + +The *scope* utility also accepts flexible metrics, provided with a dictionary that maps from a string (metric name) to a value (int or float). +Proton will aggregate the metrics for each scope and write them to the profile data. +It is useful for users to understand the performance of the model at a high level. + +```python +with proton.scope("test0", {"bytes": 1000}): + with proton.scope("test1", {"bytes": 2000}): + foo[1,](x, y) +with proton.scope("test2", {"bytes": 3000}): + foo[1,](x, y) +``` + +### Hook + +```python +import triton.profiler as proton +from typing import NamedTuple + +# hook: When hook="triton", it enables proton to invoke launch_metadata function before launching the GPU kernel +proton.start("profile_name", hook="triton") + +def metadata_fn( + grid: tuple, + metadata: NamedTuple, + args: dict +): + return {"name": "", "flops8": 1.0} + +@triton.jit(launch_metadata=metadata_fn) +def foo(x, y): + tl.store(y, tl.load(x)) +``` + +The `metadata_fn` function is called before launching the GPU kernel to provide metadata for the GPU kernel, which returns a dictionary that maps from a string (metadata name) to a value (int or float). + +Currently, **only the triton hook is supported**. In the dictionary returned by the `metadata_fn` function, we can supply the following keys: + +```python +name: str # The name of the kernel +flops8: float # The number of 8-bit floating-point operations +flops16: float # The number of 16-bit floating-point operations +flops32: float # The number of 32-bit floating-point operations +flops64: float # The number of 64-bit floating-point operations +bytes: int # The number of bytes expected to be transferred +``` + +### Command line + +Proton can be used as a command-line tool to profile Python scripts and Pytest tests. +The following examples demonstrate how to use Proton command-line. + +```bash +proton [options] script.py [script_args] [script_options] +proton [options] pytest [pytest_args] [script_options] +python -m triton.profiler.proton [options] script.py [script_args] [script_options] +proton --instrument=[instrumentation pass] script.py +``` + +When profiling in the command line mode, the `proton.start` and `proton.finalize` functions are automatically called before and after the script execution. Any `proton.start` and `proton.finalize` functions in the script are ignored. Also, in the command line mode, only a single *session* is supported. Therefore, `proton.deactivate(session_id=1)` is invalid, while `proton.deactivate(session_id=0)` is valid. + +### Visualizing the profile data + +By default, proton profiles are in the *json* format and can be read by *Hatchet*. The following command visualizes the profile data on terminal. + +```bash +pip install llnl-hatchet +proton-viewer -m time/s +``` + +NOTE: `pip install hatchet` does not work because the API is slightly different. + +### Visualizing sorted profile data + +In addition visualizing the profile data on terminal through Hatchet. A sorted list of the kernels by the first metric can be done using the --print-sorted flag with proton-viewer + +```bash +proton-viewer -m time/ns,time/% --print-sorted +``` + +prints the sorted kernels by the time/ns since it is the first listed. + +More options can be found by running the following command. + +```bash +proton-viewer -h +``` + +## Advanced features and knowledge + +### Thread management + +We guarantee that any call to `libproton.so`, such as `enter_scope`, is synchronized using explicit locks. +For operations that do not trigger calls to libproton.so—including callbacks to CUDA/HIP APIs—we use separated locks to protect data structures that may be accessed concurrently by multiple threads. +For example, the `enter_op` method in `OpInterface` can be invoked by the main thread that involves triton operators, as well as by helper threads that invoke torch operators. + +### `cpu_timed_scope` + +`cpu_timed_scope` is a utility that wraps `scope` to measure the CPU time of a scope along with other metrics. +The following example demonstrates how to use `cpu_timed_scope`: + +```python +import triton.profiler as proton + +with proton.cpu_timed_scope("test"): + foo[1,](x, y) +``` + +The `cpu_timed_scope` output metric is referred to as `cpu_time`, while `time` represents accelerator (e.g., GPU) time. +The key distinction between `cpu_time` and `time` lies in their inclusivity: `cpu_time` is exclusive, whereas `time` is inclusive. +This difference arises because the time spent on individual kernels represents the smallest measurable time granularity, and each kernel is mutually exclusive. +This exclusivity allows time to be accurately accumulated across parent scopes for `time`. +In contrast, `cpu_time` measures the time within a specific scope. +Since a parent scope encompasses the time spent in its child scopes, summing `cpu_time` from child scope into parent scope would result in double counting. +To visualize both the CPU and GPU time, we can use the following command: + +```bash +proton-viewer -m time/ns,cpu_time/ns +``` + +### Metrics naming + +Custom metrics should follow this format: `metric_name (unit) (type)`. +We prefer no space within the metric name. +`unit` and `type` are optional fields. + +There are three types of metrics in proton: inclusive, exclusive, and property metrics. +By default, a metric is inclusive. +The metric types are distinguished by the suffix of their names. +The following table shows the suffix for each type and its meaning: + +| Suffix | Name | Meaning | +| --- | --- | --- | +| (inc) or "" | Inclusive metric | The metric is accumulated at a scope and can be propagated to the parent scope. | +| (exc) | Exclusive metric | The metric is accumulated at a scope and cannot be propagated to the parent scope. | +| (pty) | Property metric | The metric is a property of the scope and cannot be accumulated or propagated. | + +### State annotation + +In addition to `proton.scope`, we can also customize the call path of each GPU operation using `proton.state`. + +`state` is different from `scope` in several ways: + +1. State is not recursive; each operation can have only a single state. Inner most state will overwrite the outer most state. +2. A states is a suffix, meaning that the original call path will append a state above the name of each kernel. +3. State is compatible with both Python and shadow contexts. + +The following example demonstrates a basic use of state: + +```python +with proton.scope("test"): + with proton.state("state0"): + with proton.scope("test0"): + foo0[1,](x, y) + with proton.scope("test1"): + foo1[1,](x, y) +``` + +The call path of `foo1` will be `test->test1->state0`. + +### Instrumentation (experimental) + +In addition to profiling, Proton also incorporates MLIR/LLVM based compiler instrumentation passes to get Triton level analysis +and optimization information. This feature is under active development and the list of available passes is expected to grow. + +#### Available passes + +print-mem-spaces: this pass prints the load and store address spaces (e.g. global, flat, shared) chosen by the compiler and attributes back to Triton source information. + +Example usage with the Proton matmul tutorial: + +```bash +$ proton --instrument=print-mem-spaces matmul.py +0 matmul_kernel matmul.py:180:20 SHARED STORE +1 matmul_kernel matmul.py:181:20 SHARED STORE +2 matmul_kernel matmul.py:180:20 SHARED LOAD +3 matmul_kernel matmul.py:181:20 SHARED LOAD +``` + +Notes: The instrument functionality is currently only available from the command line. Additionally the instrument and profile command line arguments can not be use simulantously. + +### Instruction sampling (experimental) + +Proton supports instruction sampling on NVIDIA GPUs. +Please note that this is an experimental feature and may not work on all GPUs. +You may experience ~20x end-to-end overhead when using instruction sampling, although the overhead for each individual GPU kernel is negligible. +The overhead is mostly caused by data transfer and processing on the CPU. +Additionally, the proton-viewer options `-i -d -t ` can be helpful for filtering out GPU kernels that are not of interest. +The following example demonstrates how to use instruction sampling: + +```python +import triton.profiler as proton + +proton.start(name="profile_name", context="shadow", backend="cupti_pcsampling") +``` + +## Proton *vs* nsys + +- Runtime overhead (up to 1.5x) + +Proton has a lower profiling overhead than nsys. Even for workload with a large number of small GPU kernels, proton triggers less than ~1.5x overhead. + +For GPU-bound workload, both proton and nsys has similar overhead, with little impact on the workload. + +The lower overhead of proton is due to its less profiling metrics and callbacks compared to nsys. + +- Profile size (significantly smaller than nsys) + +nsys traces and records every GPU kernel, while proton aggregates the metrics of GPU kernels under the same calling context. + +As a result, proton's profile size can be up to thousands of times smaller than nsys's profile size, depending on the running time. + +- Portability (support different GPUs) + +Proton is designed to be portable and can be used on AMD GPUs. nsys only supports NVIDIA GPUs. + +- Insights (more insightful than nsys on triton kernels) + +Proton can register hooks to analyze the metadata of triton kernels, while nsys cannot. **Note** that the hooks do add additional overhead to proton. + +## Proton *vs* ncu + +Similar to the comparison between Proton and Nsight Systems (Nsys), Proton has a lower profiling overhead than Nsight Compute (NCU). We also plan to support instruction sampling on AMD GPUs. +However, Nsight Compute supports the collection of more detailed metrics than Proton, such as memory access patterns, memory transactions, and other instruction-level metrics. +In contrast, Proton only supports instruction sampling and is designed to be lightweight and portable. + +## Known issues + +- CUDA graph + +`hooks` cannot be used to accurately accumulate the number of FLOPs in CUDA graph mode profiling because kernels are captured and launched separately; metrics are not accumulated when kernels are launched in graph mode. This issue can be circumvented by using `scope` to supply FLOPs. + +If profiling is initiated after CUDA graph capturing, there may be minor memory leak issues. +This is because the number of kernels in a graph instance (i.e., `cuGraphExec`) is unknown, preventing the deletion of mappings between the kernel ID and the graph ID. + +- Instruction sampling + +If you encounter permission related problems when using instruction sampling, you can lookup this [page](https://developer.nvidia.com/nvidia-development-tools-solutions-err_nvgpuctrperm-permission-issue-performance-counters) for help. + +The overhead of instruction sampling on NVIDIA GPUs is about 20x using Proton because we haven't enabled continuous sampling yet. +Continuous sampling can allow for more runtime optimizations, but it makes it more challenging to attribute performance data back to the GPU kernels because: (1) it enables profiling of concurrent kernels, (2) it doesn't allow profiling of time and instruction samples simultaneously, and (3) it works best if we have a separate thread dedicated to attributing instruction samples to the GPU kernels + +- Visible devices on AMD GPUs + +Environment variables such as `HIP_VISIBLE_DEVICES`, and `CUDA_VISIBLE_DEVICES` are not supported on AMD GPUs. Once it's set, we cannot find a valid mapping between the device ID returned by RocTracer and the physical device ID. Instead, `ROCR_VISIBLE_DEVICES` is recommended to be used. diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/csrc/CMakeLists.txt new file mode 100644 index 000000000..772b58a60 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/CMakeLists.txt @@ -0,0 +1,5 @@ +add_proton_library(Proton + Proton.cpp +) + +add_subdirectory(lib) diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/Proton.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/Proton.cpp new file mode 100644 index 000000000..b4840cca9 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/Proton.cpp @@ -0,0 +1,87 @@ +#include "Proton.h" + +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +using namespace proton; + +void initProton(pybind11::module &&m) { + using ret = pybind11::return_value_policy; + using namespace pybind11::literals; + + m.def("start", + [](const std::string &path, const std::string &contextSourceName, + const std::string &dataName, const std::string &profilerName, + const std::string &profilerPath) { + auto sessionId = SessionManager::instance().addSession( + path, profilerName, profilerPath, contextSourceName, dataName); + SessionManager::instance().activateSession(sessionId); + return sessionId; + }); + + m.def("activate", [](size_t sessionId) { + SessionManager::instance().activateSession(sessionId); + }); + + m.def("activate_all", + []() { SessionManager::instance().activateAllSessions(); }); + + m.def("deactivate", [](size_t sessionId) { + SessionManager::instance().deactivateSession(sessionId); + }); + + m.def("deactivate_all", + []() { SessionManager::instance().deactivateAllSessions(); }); + + m.def("finalize", [](size_t sessionId, const std::string &outputFormat) { + auto outputFormatEnum = parseOutputFormat(outputFormat); + SessionManager::instance().finalizeSession(sessionId, outputFormatEnum); + }); + + m.def("finalize_all", [](const std::string &outputFormat) { + auto outputFormatEnum = parseOutputFormat(outputFormat); + SessionManager::instance().finalizeAllSessions(outputFormatEnum); + }); + + m.def("record_scope", []() { return Scope::getNewScopeId(); }); + + m.def("enter_scope", [](size_t scopeId, const std::string &name) { + SessionManager::instance().enterScope(Scope(scopeId, name)); + }); + + m.def("exit_scope", [](size_t scopeId, const std::string &name) { + SessionManager::instance().exitScope(Scope(scopeId, name)); + }); + + m.def("enter_op", [](size_t scopeId, const std::string &name) { + SessionManager::instance().enterOp(Scope(scopeId, name)); + }); + + m.def("exit_op", [](size_t scopeId, const std::string &name) { + SessionManager::instance().exitOp(Scope(scopeId, name)); + }); + + m.def("enter_state", [](const std::string &state) { + SessionManager::instance().setState(state); + }); + + m.def("exit_state", + []() { SessionManager::instance().setState(std::nullopt); }); + + m.def("add_metrics", + [](size_t scopeId, + const std::map &metrics) { + SessionManager::instance().addMetrics(scopeId, metrics); + }); + + pybind11::bind_map>(m, "MetricMap"); +} + +PYBIND11_MODULE(libproton, m) { + m.doc() = "Python bindings to the Proton API"; + initProton(std::move(m.def_submodule("proton"))); +} diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Context/Context.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Context/Context.h new file mode 100644 index 000000000..4baa357d9 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Context/Context.h @@ -0,0 +1,142 @@ +#ifndef PROTON_CONTEXT_CONTEXT_H_ +#define PROTON_CONTEXT_CONTEXT_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +/// A context is a named object. +struct Context { + std::string name{}; + + Context() = default; + Context(const std::string &name) : name(name) {} + virtual ~Context() = default; + + bool operator==(const Context &other) const { return name == other.name; } + bool operator!=(const Context &other) const { return !(*this == other); } + bool operator<(const Context &other) const { return name < other.name; } + bool operator>(const Context &other) const { return name > other.name; } + bool operator<=(const Context &other) const { return !(*this > other); } + bool operator>=(const Context &other) const { return !(*this < other); } +}; + +/// A context source is an object that can provide a list of contexts. +class ContextSource { +public: + ContextSource() = default; + virtual ~ContextSource() = default; + + std::vector getContexts() { + auto contexts = getContextsImpl(); + if (state.has_value()) { + contexts.push_back(state.value()); + } + return contexts; + } + + void setState(std::optional state) { ContextSource::state = state; } + +protected: + virtual std::vector getContextsImpl() = 0; + static thread_local std::optional state; +}; + +/// A scope is a context with a unique identifier. +struct Scope : public Context { + const static size_t DummyScopeId = std::numeric_limits::max(); + static std::atomic scopeIdCounter; + + static size_t getNewScopeId() { return scopeIdCounter++; } + + size_t scopeId{}; + + explicit Scope(size_t scopeId) : Context(), scopeId(scopeId) {} + + explicit Scope(const std::string &name) : Context(name) { + scopeId = getNewScopeId(); + } + + Scope(size_t scopeId, const std::string &name) + : scopeId(scopeId), Context(name) {} + + Scope() : Scope(DummyScopeId, "") {} + + bool operator==(const Scope &other) const { + return scopeId == other.scopeId && name == other.name; + } + bool operator!=(const Scope &other) const { return !(*this == other); } + bool operator<(const Scope &other) const { + return scopeId < other.scopeId || name < other.name; + } + bool operator>(const Scope &other) const { + return scopeId > other.scopeId || name > other.name; + } + bool operator<=(const Scope &other) const { return !(*this > other); } + bool operator>=(const Scope &other) const { return !(*this < other); } +}; + +/// A scope interface allows to instrument handles before and after a scope. +/// Scopes can be nested. +class ScopeInterface { +public: + ScopeInterface() = default; + virtual ~ScopeInterface() = default; + virtual void enterScope(const Scope &scope) = 0; + virtual void exitScope(const Scope &scope) = 0; +}; + +/// An op interface allows to instrument handles before and after an operation, +/// which cannot be nested. +class OpInterface { +public: + OpInterface() = default; + virtual ~OpInterface() = default; + void enterOp(const Scope &scope) { + if (isOpInProgress()) { + return; + } + startOp(scope); + setOpInProgress(true); + } + void exitOp(const Scope &scope) { + if (!isOpInProgress()) { + return; + } + stopOp(scope); + setOpInProgress(false); + } + +protected: + virtual void startOp(const Scope &scope) = 0; + virtual void stopOp(const Scope &scope) = 0; + virtual bool isOpInProgress() = 0; + virtual void setOpInProgress(bool value) = 0; +}; + +class ThreadLocalOpInterface : public OpInterface { +public: + using OpInterface::OpInterface; + +protected: + bool isOpInProgress() override final { return opInProgress[this]; } + void setOpInProgress(bool value) override final { + opInProgress[this] = value; + if (opInProgress.size() > MAX_CACHE_OBJECTS && !value) + opInProgress.erase(this); + } + +private: + inline static const int MAX_CACHE_OBJECTS = 10; + static thread_local std::map opInProgress; +}; + +} // namespace proton + +#endif // PROTON_CONTEXT_CONTEXT_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Context/Python.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Context/Python.h new file mode 100644 index 000000000..9c34d0f6d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Context/Python.h @@ -0,0 +1,19 @@ +#ifndef PROTON_CONTEXT_PYTHON_H_ +#define PROTON_CONTEXT_PYTHON_H_ + +#include "Context.h" + +namespace proton { + +/// Unwind the Python stack and early return a list of contexts. +class PythonContextSource : public ContextSource { +public: + PythonContextSource() = default; + +private: + std::vector getContextsImpl() override; +}; + +} // namespace proton + +#endif // PROTON_CONTEXT_PYTHON_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Context/Shadow.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Context/Shadow.h new file mode 100644 index 000000000..37a891e5a --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Context/Shadow.h @@ -0,0 +1,40 @@ +#ifndef PROTON_CONTEXT_SHADOW_H_ +#define PROTON_CONTEXT_SHADOW_H_ + +#include "Context.h" +#include + +namespace proton { + +/// ShadowContextSource is designed to: +/// +/// - Maintain a main context stack for the main thread. +/// - Provide thread-local context stacks for individual threads. +/// - Allow threads to inherit and shadow the main context stack with their +/// own user-defined scopes. +/// +/// This implementation is suited for use cases like PyTorch, where: +/// +/// - The main thread initializes the main context stack during session setup. +/// - The backward phase spawns multiple CPU threads. +class ShadowContextSource : public ContextSource, public ScopeInterface { +public: + ShadowContextSource() = default; + + void enterScope(const Scope &scope) override; + + void exitScope(const Scope &scope) override; + +private: + std::vector getContextsImpl() override; + + void initializeThreadContext(); + + std::vector *mainContextStack{}; + static thread_local bool contextInitialized; + static thread_local std::vector threadContextStack; +}; + +} // namespace proton + +#endif // PROTON_CONTEXT_CONTEXT_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/Data.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/Data.h new file mode 100644 index 000000000..e8abd1884 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/Data.h @@ -0,0 +1,56 @@ +#ifndef PROTON_DATA_DATA_H_ +#define PROTON_DATA_DATA_H_ + +#include "Context/Context.h" +#include "Metric.h" +#include +#include +#include +#include + +namespace proton { + +enum class OutputFormat { Hatchet, Count }; + +class Data : public ScopeInterface { +public: + Data(const std::string &path, ContextSource *contextSource = nullptr) + : path(path), contextSource(contextSource) {} + virtual ~Data() = default; + + /// Add an op to the data. + /// If scopeId is already present, add an op under/inside it. + /// Otherwise obtain the current context and append opName to it if opName is + /// not empty. + virtual size_t addOp(size_t scopeId, const std::string &opName = {}) = 0; + + /// Add a single metric to the data. + virtual void addMetric(size_t scopeId, std::shared_ptr metric) = 0; + + /// Add multiple metrics to the data. + virtual void + addMetrics(size_t scopeId, + const std::map &metrics) = 0; + + /// Clear all caching data. + virtual void clear() = 0; + + /// Dump the data to the given output format. + void dump(OutputFormat outputFormat); + +protected: + /// The actual implementation of the dump operation. + virtual void doDump(std::ostream &os, OutputFormat outputFormat) const = 0; + + mutable std::shared_mutex mutex; + const std::string path{}; + ContextSource *contextSource{}; +}; + +OutputFormat parseOutputFormat(const std::string &outputFormat); + +const std::string outputFormatToString(OutputFormat outputFormat); + +} // namespace proton + +#endif // PROTON_DATA_DATA_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/Metric.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/Metric.h new file mode 100644 index 000000000..f03704ead --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/Metric.h @@ -0,0 +1,247 @@ +#ifndef PROTON_DATA_METRIC_H_ +#define PROTON_DATA_METRIC_H_ + +#include "Utility/String.h" +#include "Utility/Traits.h" +#include +#include + +namespace proton { + +enum class MetricKind { Flexible, Kernel, PCSampling, Count }; + +using MetricValueType = std::variant; + +/// A metric is a class that can be associated with a context. +/// `Metric` is the base class for all metrics. +/// Each `Metric` has a name and a set of values. +/// Each value could be of type `uint64_t`, `int64_t`, or `double`, +/// Each value can be inclusive (inc), exclusive (exc), or a property (pty). +/// Inclusive values are aggregated by addition and can be propagated to the +/// parent. +/// Exclusive values can be aggregated at a context but cannot be +/// propagated to the parent. +/// Property values are not aggregated and cannot be propagated to the parent. +class Metric { +public: + Metric(MetricKind kind, size_t size) : kind(kind), values(size) {} + + virtual ~Metric() = default; + + virtual const std::string getName() const = 0; + + virtual const std::string getValueName(int valueId) const = 0; + + virtual bool isProperty(int valueId) const = 0; + + virtual bool isExclusive(int valueId) const = 0; + + std::vector getValues() const { return values; } + + MetricValueType getValue(int valueId) { return values[valueId]; } + + /// Update a specific value id with the new value. + void updateValue(int valueId, MetricValueType value) { + // Handle string and other values separately + if (std::holds_alternative(value)) { + values[valueId] = std::get(value); + } else { + std::visit( + [&](auto &¤tValue, auto &&otherValue) { + using CurrentType = std::decay_t; + using ValueType = std::decay_t; + if constexpr (std::is_same_v) { + if (isProperty(valueId)) { + currentValue = otherValue; + } else { + currentValue += otherValue; + } + } + }, + values[valueId], value); + } + } + + /// Update all values of the metric with the same value. + void updateValue(MetricValueType value) { + for (int i = 0; i < values.size(); ++i) { + updateValue(i, value); + } + } + + /// Update all values with another metric. + void updateMetric(Metric &other) { + for (int i = 0; i < values.size(); ++i) { + updateValue(i, other.values[i]); + } + } + + MetricKind getKind() const { return kind; } + +private: + const MetricKind kind; + const std::string name; + +protected: + std::vector values; +}; + +/// A flexible metric is provided by users but not the backend profiling API. +/// Each flexible metric has a single value. +class FlexibleMetric : public Metric { +public: + FlexibleMetric(const std::string &valueName, + std::variant value) + : Metric(MetricKind::Flexible, 1), valueName(valueName) { + this->exclusive = endWith(valueName, "(exc)"); + this->property = endWith(valueName, "(pty)"); + this->valueName = trim(replace(this->valueName, "(exc)", "")); + this->valueName = trim(replace(this->valueName, "(pty)", "")); + std::visit([&](auto &&v) { this->values[0] = v; }, value); + } + + FlexibleMetric(const std::string &valueName, + std::variant value, bool property, + bool exclusive) + : Metric(MetricKind::Flexible, 1), valueName(valueName), + property(property), exclusive(exclusive) { + std::visit([&](auto &&v) { this->values[0] = v; }, value); + } + + const std::string getName() const override { return "FlexibleMetric"; } + + const std::string getValueName(int valueId) const override { + return valueName; + } + + bool isProperty(int valueId) const override { return property; } + + bool isExclusive(int valueId) const override { return exclusive; } + +private: + bool property{}; + bool exclusive{}; + std::string valueName; +}; + +class KernelMetric : public Metric { +public: + enum kernelMetricKind : int { + StartTime, + EndTime, + Invocations, + Duration, + DeviceId, + DeviceType, + Count, + }; + + KernelMetric() : Metric(MetricKind::Kernel, kernelMetricKind::Count) {} + + KernelMetric(uint64_t startTime, uint64_t endTime, uint64_t invocations, + uint64_t deviceId, uint64_t deviceType) + : KernelMetric() { + this->values[StartTime] = startTime; + this->values[EndTime] = endTime; + this->values[Invocations] = invocations; + this->values[Duration] = endTime - startTime; + this->values[DeviceId] = deviceId; + this->values[DeviceType] = deviceType; + } + + virtual const std::string getName() const { return "KernelMetric"; } + + virtual const std::string getValueName(int valueId) const { + return VALUE_NAMES[valueId]; + } + + virtual bool isProperty(int valueId) const { return PROPERTY[valueId]; } + + virtual bool isExclusive(int valueId) const { return EXCLUSIVE[valueId]; } + +private: + const static inline bool PROPERTY[kernelMetricKind::Count] = { + true, true, false, false, true, true}; + const static inline bool EXCLUSIVE[kernelMetricKind::Count] = { + false, false, false, false, true, true}; + const static inline std::string VALUE_NAMES[kernelMetricKind::Count] = { + "start_time (ns)", "end_time (ns)", "count", + "time (ns)", "device_id", "device_type", + }; +}; + +class PCSamplingMetric : public Metric { +public: + enum PCSamplingMetricKind : int { + NumSamples, + NumStalledSamples, + StalledBranchResolving, + StalledNoInstruction, + StalledShortScoreboard, + StalledWait, + StalledLongScoreboard, + StalledTexThrottle, + StalledBarrier, + StalledMembar, + StalledIMCMiss, + StalledMIOThrottle, + StalledMathPipeThrottle, + StalledDrain, + StalledLGThrottle, + StalledNotSelected, + StalledMisc, + StalledDispatchStall, + StalledSleeping, + StalledSelected, + Count, + }; + + PCSamplingMetric() + : Metric(MetricKind::PCSampling, PCSamplingMetricKind::Count) {} + + PCSamplingMetric(PCSamplingMetricKind kind, uint64_t samples, + uint64_t stalledSamples) + : PCSamplingMetric() { + this->values[kind] = stalledSamples; + this->values[PCSamplingMetricKind::NumSamples] = samples; + this->values[PCSamplingMetricKind::NumStalledSamples] = stalledSamples; + } + + virtual const std::string getName() const { return "PCSamplingMetric"; } + + virtual const std::string getValueName(int valueId) const { + return VALUE_NAMES[valueId]; + } + + virtual bool isProperty(int valueId) const { return false; } + + virtual bool isExclusive(int valueId) const { return false; } + +private: + const static inline std::string VALUE_NAMES[PCSamplingMetricKind::Count] = { + "num_samples", + "num_stalled_samples", + "stalled_branch_resolving", + "stalled_no_instruction", + "stalled_short_scoreboard", + "stalled_wait", + "stalled_long_scoreboard", + "stalled_tex_throttle", + "stalled_barrier", + "stalled_membar", + "stalled_imc_miss", + "stalled_mio_throttle", + "stalled_math_pipe_throttle", + "stalled_drain", + "stalled_lg_throttle", + "stalled_not_Selected", + "stalled_misc", + "stalled_dispatch_stall", + "stalled_sleeping", + "stalled_selected", + }; +}; + +} // namespace proton + +#endif // PROTON_DATA_METRIC_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/TraceData.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/TraceData.h new file mode 100644 index 000000000..dc73343df --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/TraceData.h @@ -0,0 +1,35 @@ +#ifndef PROTON_DATA_TRACE_DATA_H_ +#define PROTON_DATA_TRACE_DATA_H_ + +#include "Data.h" + +namespace proton { + +class TraceData : public Data { +public: + using Data::Data; + virtual ~TraceData() = default; + + size_t addOp(size_t scopeId, const std::string &name) override; + + void addMetric(size_t scopeId, std::shared_ptr metric) override; + + void + addMetrics(size_t scopeId, + const std::map &metrics) override; + + void clear() override; + +protected: + // ScopeInterface + void enterScope(const Scope &scope) override final; + + void exitScope(const Scope &scope) override final; + +private: + void doDump(std::ostream &os, OutputFormat outputFormat) const override; +}; + +} // namespace proton + +#endif // PROTON_DATA_TRACE_DATA_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/TreeData.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/TreeData.h new file mode 100644 index 000000000..64ef530bd --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Data/TreeData.h @@ -0,0 +1,50 @@ +#ifndef PROTON_DATA_TREE_DATA_H_ +#define PROTON_DATA_TREE_DATA_H_ + +#include "Context/Context.h" +#include "Data.h" +#include +#include + +namespace proton { + +class TreeData : public Data { +public: + TreeData(const std::string &path, ContextSource *contextSource); + virtual ~TreeData(); + + TreeData(const std::string &path) : TreeData(path, nullptr) {} + + size_t addOp(size_t scopeId, const std::string &name) override; + + void addMetric(size_t scopeId, std::shared_ptr metric) override; + + void + addMetrics(size_t scopeId, + const std::map &metrics) override; + + void clear() override; + +protected: + // ScopeInterface + void enterScope(const Scope &scope) override; + + void exitScope(const Scope &scope) override; + +private: + void init(); + void dumpHatchet(std::ostream &os) const; + void doDump(std::ostream &os, OutputFormat outputFormat) const override; + + // `tree` and `scopeIdToContextId` can be accessed by both the user thread and + // the background threads concurrently, so methods that access them should be + // protected by a (shared) mutex. + class Tree; + std::unique_ptr tree; + // ScopeId -> ContextId + std::unordered_map scopeIdToContextId; +}; + +} // namespace proton + +#endif // PROTON_DATA_TREE_DATA_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/Device.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/Device.h new file mode 100644 index 000000000..3e414c824 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/Device.h @@ -0,0 +1,48 @@ +#ifndef PROTON_DRIVER_DEVICE_H_ +#define PROTON_DRIVER_DEVICE_H_ + +#include +#include + +namespace proton { + +enum class DeviceType { HIP, CUDA, COUNT }; + +template struct DeviceTraits; + +template <> struct DeviceTraits { + constexpr static DeviceType type = DeviceType::CUDA; + constexpr static const char *name = "CUDA"; +}; + +template <> struct DeviceTraits { + constexpr static DeviceType type = DeviceType::HIP; + constexpr static const char *name = "HIP"; +}; + +struct Device { + DeviceType type; + uint64_t id; + uint64_t clockRate; // khz + uint64_t memoryClockRate; // khz + uint64_t busWidth; + uint64_t numSms; + std::string arch; + + Device() = default; + + Device(DeviceType type, uint64_t id, uint64_t clockRate, + uint64_t memoryClockRate, uint64_t busWidth, uint64_t numSms, + std::string arch) + : type(type), id(id), clockRate(clockRate), + memoryClockRate(memoryClockRate), busWidth(busWidth), numSms(numSms), + arch(arch) {} +}; + +Device getDevice(DeviceType type, uint64_t index); + +const std::string getDeviceTypeString(DeviceType type); + +}; // namespace proton + +#endif // PROTON_DRIVER_DEVICE_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/Dispatch.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/Dispatch.h new file mode 100644 index 000000000..81f6b5b32 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/Dispatch.h @@ -0,0 +1,112 @@ +#ifndef PROTON_DRIVER_DISPATCH_H_ +#define PROTON_DRIVER_DISPATCH_H_ + +#include + +#include +#include + +#define DISPATCH_ARGS_0() +#define DISPATCH_ARGS_1(t1) t1 v1 +#define DISPATCH_ARGS_2(t1, t2) t1 v1, t2 v2 +#define DISPATCH_ARGS_3(t1, t2, t3) t1 v1, t2 v2, t3 v3 +#define DISPATCH_ARGS_4(t1, t2, t3, t4) t1 v1, t2 v2, t3 v3, t4 v4 +#define DISPATCH_ARGS_N(_4, _3, _2, _1, _0, N, ...) DISPATCH_ARGS##N +#define DISPATCH_ARGS(...) \ + DISPATCH_ARGS_N(_0, ##__VA_ARGS__, _4, _3, _2, _1, _0) \ + (__VA_ARGS__) + +#define DISPATCH_VALS_0() +#define DISPATCH_VALS_1(t1) , v1 +#define DISPATCH_VALS_2(t1, t2) , v1, v2 +#define DISPATCH_VALS_3(t1, t2, t3) , v1, v2, v3 +#define DISPATCH_VALS_4(t1, t2, t3, t4) , v1, v2, v3, v4 +#define DISPATCH_VALS_N(_4, _3, _2, _1, _0, N, ...) DISPATCH_VALS##N +#define DISPATCH_VALS(...) \ + DISPATCH_VALS_N(_0, ##__VA_ARGS__, _4, _3, _2, _1, _0) \ + (__VA_ARGS__) + +#define DEFINE_DISPATCH_TEMPLATE(CheckSuccess, FuncName, ExternLib, FuncType, \ + ...) \ + template <> \ + ExternLib::RetType FuncName(DISPATCH_ARGS(__VA_ARGS__)) { \ + typedef typename ExternLib::RetType (*FuncType##_t)(__VA_ARGS__); \ + static FuncType##_t func = nullptr; \ + return Dispatch::exec( \ + func, #FuncType DISPATCH_VALS(__VA_ARGS__)); \ + } + +#define DEFINE_DISPATCH(ExternLib, FuncName, FuncType, ...) \ + DEFINE_DISPATCH_TEMPLATE(true, FuncName, ExternLib, FuncType, __VA_ARGS__) \ + DEFINE_DISPATCH_TEMPLATE(false, FuncName, ExternLib, FuncType, __VA_ARGS__) + +namespace proton { + +struct ExternLibBase { + using RetType = int; // Generic type, can be overridden in derived structs + static constexpr const char *name = ""; // Placeholder + static constexpr RetType success = 0; // Placeholder + ExternLibBase() = delete; + ExternLibBase(const ExternLibBase &) = delete; + ExternLibBase &operator=(const ExternLibBase &) = delete; + static inline void *lib{nullptr}; + static inline std::string defaultDir{""}; +}; + +template class Dispatch { +public: + Dispatch() = delete; + + static void init(const char *name, void **lib) { + if (*lib == nullptr) { + // If not found, try to load it from the default path + auto dir = std::string(ExternLib::defaultDir); + if (dir.length() > 0) { + auto fullPath = dir + "/" + name; + *lib = dlopen(fullPath.c_str(), RTLD_LOCAL | RTLD_LAZY); + } else { + // Only if the default path is not set, we try to load it from the + // system. + // First reuse the existing handle + *lib = dlopen(name, RTLD_NOLOAD); + if (*lib == nullptr) { + // If not found, try to load it from LD_LIBRARY_PATH + *lib = dlopen(name, RTLD_LOCAL | RTLD_LAZY); + } + } + } + if (*lib == nullptr) { + throw std::runtime_error("Could not load `" + std::string(name) + "`"); + } + } + + static void check(typename ExternLib::RetType ret, const char *functionName) { + if (ret != ExternLib::success) { + throw std::runtime_error("Failed to execute " + + std::string(functionName) + " with error " + + std::to_string(ret)); + } + } + + template + static typename ExternLib::RetType + exec(FnT &handler, const char *functionName, Args... args) { + init(ExternLib::name, &ExternLib::lib); + if (handler == nullptr) { + handler = reinterpret_cast(dlsym(ExternLib::lib, functionName)); + if (handler == nullptr) { + throw std::runtime_error("Failed to load " + + std::string(ExternLib::name)); + } + } + auto ret = handler(args...); + if constexpr (CheckSuccess) { + check(ret, functionName); + } + return ret; + } +}; + +} // namespace proton + +#endif // PROTON_DRIVER_DISPATCH_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/CudaApi.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/CudaApi.h new file mode 100644 index 000000000..1178cf5db --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/CudaApi.h @@ -0,0 +1,28 @@ +#ifndef PROTON_DRIVER_GPU_CUDA_H_ +#define PROTON_DRIVER_GPU_CUDA_H_ + +#include "Driver/Device.h" +#include "cuda.h" + +namespace proton { + +namespace cuda { + +template CUresult init(int flags); + +template CUresult ctxSynchronize(); + +template CUresult ctxGetCurrent(CUcontext *pctx); + +template +CUresult deviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); + +template CUresult deviceGet(CUdevice *device, int ordinal); + +Device getDevice(uint64_t index); + +} // namespace cuda + +} // namespace proton + +#endif // PROTON_DRIVER_GPU_CUDA_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h new file mode 100644 index 000000000..28c12f7c4 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h @@ -0,0 +1,116 @@ +#ifndef PROTON_DRIVER_GPU_CUPTI_H_ +#define PROTON_DRIVER_GPU_CUPTI_H_ + +#include "cupti.h" +#include "cupti_pcsampling.h" +#include + +namespace proton { + +namespace cupti { + +template CUptiResult getVersion(uint32_t *version); + +template +CUptiResult getContextId(CUcontext context, uint32_t *pCtxId); + +template +CUptiResult activityRegisterCallbacks( + CUpti_BuffersCallbackRequestFunc funcBufferRequested, + CUpti_BuffersCallbackCompleteFunc funcBufferCompleted); + +template +CUptiResult subscribe(CUpti_SubscriberHandle *subscriber, + CUpti_CallbackFunc callback, void *userdata); + +template +CUptiResult enableDomain(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain); + +template +CUptiResult enableCallback(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain, CUpti_CallbackId cbid); + +template +CUptiResult activityEnableContext(CUcontext context, CUpti_ActivityKind kind); + +template +CUptiResult activityDisableContext(CUcontext context, CUpti_ActivityKind kind); + +template +CUptiResult activityEnable(CUpti_ActivityKind kind); + +template +CUptiResult activityDisable(CUpti_ActivityKind kind); + +template CUptiResult activityFlushAll(uint32_t flag); + +template +CUptiResult activityGetNextRecord(uint8_t *buffer, size_t validBufferSizeBytes, + CUpti_Activity **record); + +template +CUptiResult +activityPushExternalCorrelationId(CUpti_ExternalCorrelationKind kind, + uint64_t id); + +template +CUptiResult activityPopExternalCorrelationId(CUpti_ExternalCorrelationKind kind, + uint64_t *lastId); + +template +CUptiResult activitySetAttribute(CUpti_ActivityAttribute attr, + size_t *valueSize, void *value); + +template +CUptiResult unsubscribe(CUpti_SubscriberHandle subscriber); + +template CUptiResult finalize(); + +template +CUptiResult getGraphExecId(CUgraphExec graph, uint32_t *pId); + +template +CUptiResult getGraphId(CUgraph graph, uint32_t *pId); + +template +CUptiResult getCubinCrc(CUpti_GetCubinCrcParams *pParams); + +template +CUptiResult +getSassToSourceCorrelation(CUpti_GetSassToSourceCorrelationParams *pParams); + +template +CUptiResult +pcSamplingGetNumStallReasons(CUpti_PCSamplingGetNumStallReasonsParams *pParams); + +template +CUptiResult +pcSamplingGetStallReasons(CUpti_PCSamplingGetStallReasonsParams *pParams); + +template +CUptiResult pcSamplingSetConfigurationAttribute( + CUpti_PCSamplingConfigurationInfoParams *pParams); + +template +CUptiResult pcSamplingEnable(CUpti_PCSamplingEnableParams *pParams); + +template +CUptiResult pcSamplingDisable(CUpti_PCSamplingDisableParams *pParams); + +template +CUptiResult pcSamplingGetData(CUpti_PCSamplingGetDataParams *pParams); + +template +CUptiResult pcSamplingStart(CUpti_PCSamplingStartParams *pParams); + +template +CUptiResult pcSamplingStop(CUpti_PCSamplingStopParams *pParams); + +void setLibPath(const std::string &path); + +} // namespace cupti + +} // namespace proton + +#endif // PROTON_EXTERN_DISPATCH_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/HipApi.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/HipApi.h new file mode 100644 index 000000000..fadb9c425 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/HipApi.h @@ -0,0 +1,33 @@ +#ifndef PROTON_DRIVER_GPU_HIP_H_ +#define PROTON_DRIVER_GPU_HIP_H_ + +#include "Driver/Device.h" +#include "hip/hip_runtime_api.h" + +namespace proton { + +namespace hip { + +template hipError_t deviceSynchronize(); + +template +hipError_t deviceGetAttribute(int *value, hipDeviceAttribute_t attribute, + int deviceId); + +template hipError_t getDeviceCount(int *count); + +template +hipError_t getDeviceProperties(hipDeviceProp_t *prop, int deviceId); + +Device getDevice(uint64_t index); + +const std::string getHipArchName(uint64_t index); + +const char *getKernelNameRef(const hipFunction_t f); +const char *getKernelNameRefByPtr(const void *hostFunction, hipStream_t stream); + +} // namespace hip + +} // namespace proton + +#endif // PROTON_DRIVER_GPU_HIP_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/HsaApi.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/HsaApi.h new file mode 100644 index 000000000..c694a11af --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/HsaApi.h @@ -0,0 +1,23 @@ +#ifndef PROTON_DRIVER_GPU_HSA_H_ +#define PROTON_DRIVER_GPU_HSA_H_ + +#include "Driver/Device.h" +#include "hsa/hsa_ext_amd.h" + +namespace proton { + +namespace hsa { + +template +hsa_status_t agentGetInfo(hsa_agent_t agent, hsa_agent_info_t attribute, + void *value); + +hsa_status_t iterateAgents(hsa_status_t (*callback)(hsa_agent_t agent, + void *data), + void *data); + +} // namespace hsa + +} // namespace proton + +#endif // PROTON_DRIVER_GPU_HSA_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/RoctracerApi.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/RoctracerApi.h new file mode 100644 index 000000000..c1ab3260e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Driver/GPU/RoctracerApi.h @@ -0,0 +1,85 @@ +#ifndef PROTON_DRIVER_GPU_ROCTRACER_H_ +#define PROTON_DRIVER_GPU_ROCTRACER_H_ + +#include "roctracer/roctracer.h" + +namespace proton { + +namespace roctracer { + +template +roctracer_status_t setProperties(roctracer_domain_t domain, void *properties); + +template +roctracer_status_t getTimestamp(roctracer_timestamp_t *timestamp); + +void start(); + +void stop(); + +// +// Callbacks +// + +template +roctracer_status_t enableDomainCallback(activity_domain_t domain, + activity_rtapi_callback_t callback, + void *arg); + +template +roctracer_status_t disableDomainCallback(activity_domain_t domain); + +template +roctracer_status_t enableOpCallback(activity_domain_t domain, uint32_t op, + activity_rtapi_callback_t callback, + void *arg); + +template +roctracer_status_t disableOpCallback(activity_domain_t domain, uint32_t op); + +// +// Activity +// + +template +roctracer_status_t openPool(const roctracer_properties_t *properties); + +template roctracer_status_t closePool(); + +template +roctracer_status_t enableOpActivity(activity_domain_t domain, uint32_t op); + +template +roctracer_status_t enableDomainActivity(activity_domain_t domain); + +template +roctracer_status_t disableOpActivity(activity_domain_t domain, uint32_t op); + +template +roctracer_status_t disableDomainActivity(activity_domain_t domain); + +template roctracer_status_t flushActivity(); + +template +roctracer_status_t getNextRecord(const activity_record_t *record, + const activity_record_t **next); + +char *getOpString(uint32_t domain, uint32_t op, uint32_t kind); + +// +// External correlation +// + +template +roctracer_status_t +activityPushExternalCorrelationId(activity_correlation_id_t id); + +template +roctracer_status_t +activityPopExternalCorrelationId(activity_correlation_id_t *last_id); + +} // namespace roctracer + +} // namespace proton + +#endif // PROTON_EXTERN_DISPATCH_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h new file mode 100644 index 000000000..58b6e2be8 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h @@ -0,0 +1,141 @@ +#ifndef PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ +#define PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ + +#include "CuptiProfiler.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/CuptiApi.h" +#include "Utility/Map.h" +#include "Utility/Singleton.h" +#include +#include + +namespace proton { + +struct CubinData { + size_t cubinCrc; + const char *cubin; + size_t cubinSize; + + struct LineInfoKey { + uint32_t functionIndex; + uint64_t pcOffset; + + bool operator<(const LineInfoKey &other) const { + return functionIndex < other.functionIndex || + (functionIndex == other.functionIndex && + pcOffset < other.pcOffset); + } + }; + + struct LineInfoValue { + uint32_t lineNumber{}; + const std::string functionName{}; + const std::string dirName{}; + const std::string fileName{}; + + LineInfoValue() = default; + + LineInfoValue(uint32_t lineNumber, const std::string &functionName, + const std::string &dirName, const std::string &fileName) + : lineNumber(lineNumber), functionName(functionName), dirName(dirName), + fileName(fileName) {} + }; + + std::map lineInfo; +}; + +struct ConfigureData { + ConfigureData() = default; + + ~ConfigureData() { + if (stallReasonNames) { + for (size_t i = 0; i < numStallReasons; i++) { + if (stallReasonNames[i]) + std::free(stallReasonNames[i]); + } + std::free(stallReasonNames); + } + if (stallReasonIndices) + std::free(stallReasonIndices); + if (pcSamplingData.pPcData) { + for (size_t i = 0; i < numValidStallReasons; ++i) { + std::free(pcSamplingData.pPcData[i].stallReason); + } + std::free(pcSamplingData.pPcData); + } + } + + void initialize(CUcontext context); + + CUpti_PCSamplingConfigurationInfo configureStallReasons(); + CUpti_PCSamplingConfigurationInfo configureSamplingPeriod(); + CUpti_PCSamplingConfigurationInfo configureSamplingBuffer(); + CUpti_PCSamplingConfigurationInfo configureScratchBuffer(); + CUpti_PCSamplingConfigurationInfo configureHardwareBufferSize(); + CUpti_PCSamplingConfigurationInfo configureStartStopControl(); + CUpti_PCSamplingConfigurationInfo configureCollectionMode(); + + // The amount of data reserved on the GPU + static constexpr size_t HardwareBufferSize = 128 * 1024 * 1024; + // The amount of data copied from the hardware buffer each time + static constexpr size_t ScratchBufferSize = 16 * 1024 * 1024; + // The number of PCs copied from the scratch buffer each time + static constexpr size_t DataBufferPCCount = 1024; + // The sampling period in cycles = 2^frequency + static constexpr uint32_t DefaultFrequency = 10; + + CUcontext context{}; + uint32_t contextId; + uint32_t numStallReasons{}; + uint32_t numValidStallReasons{}; + char **stallReasonNames{}; + uint32_t *stallReasonIndices{}; + std::map stallReasonIndexToMetricIndex{}; + std::set notIssuedStallReasonIndices{}; + CUpti_PCSamplingData pcSamplingData{}; + // The memory storing configuration information has to be kept alive during + // the profiling session + std::vector configurationInfos; +}; + +class CuptiPCSampling : public Singleton { + +public: + CuptiPCSampling() = default; + virtual ~CuptiPCSampling() = default; + + void initialize(CUcontext context); + + void start(CUcontext context); + + void stop(CUcontext context, uint64_t externId, bool isAPI); + + void finalize(CUcontext context); + + void loadModule(const char *cubin, size_t cubinSize); + + void unloadModule(const char *cubin, size_t cubinSize); + +private: + ConfigureData *getConfigureData(uint32_t contextId); + + CubinData *getCubinData(uint64_t cubinCrc); + + void processPCSamplingData(ConfigureData *configureData, uint64_t externId, + bool isAPI); + + ThreadSafeMap contextIdToConfigureData; + // In case the same cubin is loaded multiple times, we need to keep track of + // all of them + ThreadSafeMap> + cubinCrcToCubinData; + ThreadSafeSet contextInitialized; + + std::atomic pcSamplingStarted{false}; + std::mutex pcSamplingMutex{}; + std::mutex contextMutex{}; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h new file mode 100644 index 000000000..c443ec2e3 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h @@ -0,0 +1,19 @@ +#ifndef PROTON_PROFILER_CUPTI_PROFILER_H_ +#define PROTON_PROFILER_CUPTI_PROFILER_H_ + +#include "Profiler/GPUProfiler.h" + +namespace proton { + +class CuptiProfiler : public GPUProfiler { +public: + CuptiProfiler(); + virtual ~CuptiProfiler(); + +private: + struct CuptiProfilerPimpl; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_CUPTI_PROFILER_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/GPUProfiler.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/GPUProfiler.h new file mode 100644 index 000000000..a12889278 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/GPUProfiler.h @@ -0,0 +1,159 @@ +#ifndef PROTON_PROFILER_GPU_PROFILER_H_ +#define PROTON_PROFILER_GPU_PROFILER_H_ + +#include "Context/Context.h" +#include "Profiler.h" +#include "Utility/Atomic.h" +#include "Utility/Map.h" +#include "Utility/Set.h" + +#include +#include +#include +#include +#include + +namespace proton { + +// Singleton: Each concrete GPU profiler, e.g., +// CuptiProfiler, should be a singleton. +template +class GPUProfiler : public Profiler, + public ThreadLocalOpInterface, + public Singleton { +public: + GPUProfiler() = default; + virtual ~GPUProfiler() = default; + + using CorrIdToExternIdMap = + ThreadSafeMap, /**/ + std::unordered_map>>; + using ApiExternIdSet = ThreadSafeSet>; + + ConcreteProfilerT &enablePCSampling() { + pcSamplingEnabled = true; + return dynamic_cast(*this); + } + ConcreteProfilerT &disablePCSampling() { + pcSamplingEnabled = false; + return dynamic_cast(*this); + } + bool isPCSamplingEnabled() const { return pcSamplingEnabled; } + + ConcreteProfilerT &setLibPath(const std::string &libPath) { + pImpl->setLibPath(libPath); + return dynamic_cast(*this); + } + +protected: + // OpInterface + void startOp(const Scope &scope) override { + this->correlation.pushExternId(scope.scopeId); + for (auto data : getDataSet()) + data->addOp(scope.scopeId, scope.name); + } + void stopOp(const Scope &scope) override { this->correlation.popExternId(); } + + // Profiler + virtual void doStart() override { pImpl->doStart(); } + virtual void doFlush() override { pImpl->doFlush(); } + virtual void doStop() override { pImpl->doStop(); } + + struct ThreadState { + ConcreteProfilerT &profiler; + size_t scopeId{Scope::DummyScopeId}; + + ThreadState(ConcreteProfilerT &profiler) : profiler(profiler) {} + + void enterOp() { + if (profiler.isOpInProgress()) + return; + scopeId = Scope::getNewScopeId(); + profiler.enterOp(Scope(scopeId)); + profiler.correlation.apiExternIds.insert(scopeId); + } + + void exitOp() { + if (!profiler.isOpInProgress()) + return; + profiler.exitOp(Scope(scopeId)); + } + }; + + struct Correlation { + std::atomic maxSubmittedCorrelationId{0}; + std::atomic maxCompletedCorrelationId{0}; + // Mapping from a native profiler correlation id to an external id. + CorrIdToExternIdMap corrIdToExternId; + // A set of kernels triggered by GPU runtime APIs (e.g., torch + // kernels) other than Triton. + // It stores a subset of external ids in corrIdToExternId. + ApiExternIdSet apiExternIds; + static thread_local std::deque externIdQueue; + + Correlation() = default; + + void submit(const uint64_t correlationId) { + atomicMax(maxSubmittedCorrelationId, correlationId); + } + + void complete(const uint64_t correlationId) { + atomicMax(maxCompletedCorrelationId, correlationId); + } + + void pushExternId(size_t externId) { externIdQueue.push_back(externId); } + + void popExternId() { externIdQueue.pop_front(); } + + // Correlate the correlationId with the last externId + void correlate(uint64_t correlationId, size_t numInstances = 1) { + if (externIdQueue.empty()) + return; + corrIdToExternId[correlationId] = {externIdQueue.back(), numInstances}; + } + + template + void flush(uint64_t maxRetries, uint64_t sleepMs, FlushFnT &&flushFn) { + flushFn(); + auto submittedId = maxSubmittedCorrelationId.load(); + auto completedId = maxCompletedCorrelationId.load(); + auto retries = maxRetries; + while ((completedId < submittedId) && retries > 0) { + std::this_thread::sleep_for(std::chrono::microseconds(sleepMs)); + flushFn(); + completedId = maxCompletedCorrelationId.load(); + --retries; + } + } + }; + + static thread_local ThreadState threadState; + Correlation correlation; + + // Use the pimpl idiom to hide the implementation details. This lets us avoid + // including the cupti header from this header. The cupti header and the + // equivalent header from AMD define conflicting macros, so we want to use + // those headers only within cpp files. + class GPUProfilerPimplInterface { + public: + GPUProfilerPimplInterface(ConcreteProfilerT &profiler) + : profiler(profiler) {} + virtual ~GPUProfilerPimplInterface() = default; + + virtual void setLibPath(const std::string &libPath) = 0; + virtual void doStart() = 0; + virtual void doFlush() = 0; + virtual void doStop() = 0; + + protected: + ConcreteProfilerT &profiler; + }; + std::unique_ptr pImpl; + + bool pcSamplingEnabled{false}; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_GPU_PROFILER_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Profiler.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Profiler.h new file mode 100644 index 000000000..15f771fb2 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Profiler.h @@ -0,0 +1,93 @@ +#ifndef PROTON_PROFILER_PROFILER_H_ +#define PROTON_PROFILER_PROFILER_H_ + +#include "Data/Data.h" +#include "Utility/Singleton.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +/// A profiler contains utilities provided by the profiler library to +/// collect and analyze performance data. +class Profiler { +public: + Profiler() = default; + + virtual ~Profiler() = default; + + /// Start the profiler. + /// If the profiler is already started, this function does nothing. + Profiler *start() { + if (!this->started) { + this->started = true; + this->doStart(); + } + return this; + } + + /// Flush the profiler's data from the device to the host. + /// It doesn't stop the profiler. + Profiler *flush() { + this->doFlush(); + return this; + } + + /// Stop the profiler. + /// Do real stop if there's no data to collect. + Profiler *stop() { + if (!this->started) { + return this; + } + if (this->getDataSet().empty()) { + this->started = false; + this->doStop(); + } + return this; + } + + /// Register a data object to the profiler. + /// A profiler can yield metrics to multiple data objects. + Profiler *registerData(Data *data) { + std::unique_lock lock(mutex); + dataSet.insert(data); + return this; + } + + /// Unregister a data object from the profiler. + Profiler *unregisterData(Data *data) { + std::unique_lock lock(mutex); + dataSet.erase(data); + return this; + } + + /// Get the set of data objects registered to the profiler. + std::set getDataSet() const { + std::shared_lock lock(mutex); + return dataSet; + } + +protected: + virtual void doStart() = 0; + virtual void doFlush() = 0; + virtual void doStop() = 0; + + // `dataSet` can be accessed by both the user thread and the background + // threads + mutable std::shared_mutex mutex; + std::set dataSet; + +private: + bool started{}; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_PROFILER_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h new file mode 100644 index 000000000..b9bc08de8 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h @@ -0,0 +1,19 @@ +#ifndef PROTON_PROFILER_ROCTRACER_PROFILER_H_ +#define PROTON_PROFILER_ROCTRACER_PROFILER_H_ + +#include "Profiler/GPUProfiler.h" + +namespace proton { + +class RoctracerProfiler : public GPUProfiler { +public: + RoctracerProfiler(); + virtual ~RoctracerProfiler(); + +private: + struct RoctracerProfilerPimpl; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_ROCTRACER_PROFILER_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Proton.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Proton.h new file mode 100644 index 000000000..92e2fdf0a --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Proton.h @@ -0,0 +1,9 @@ +#ifndef PROTON_H_ +#define PROTON_H_ + +#include "Context/Context.h" +#include "Data/Data.h" +#include "Data/Metric.h" +#include "Session/Session.h" + +#endif // PROTON_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Session/Session.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Session/Session.h new file mode 100644 index 000000000..88b9ab349 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Session/Session.h @@ -0,0 +1,179 @@ +#ifndef PROTON_SESSION_SESSION_H_ +#define PROTON_SESSION_SESSION_H_ + +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Utility/Singleton.h" +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +class Profiler; +class Data; +enum class OutputFormat; + +/// A session is a collection of profiler, context source, and data objects. +/// There could be multiple sessions in the system, each can correspond to a +/// different duration, or the same duration but with different configurations. +class Session { +public: + ~Session() = default; + + void activate(); + + void deactivate(); + + void finalize(OutputFormat outputFormat); + +private: + Session(size_t id, const std::string &path, Profiler *profiler, + std::unique_ptr contextSource, + std::unique_ptr data) + : id(id), path(path), profiler(profiler), + contextSource(std::move(contextSource)), data(std::move(data)) {} + + template std::vector getInterfaces() { + std::vector interfaces; + // There's an implicit order between contextSource and profiler/data. The + // latter two rely on the contextSource to obtain the context, so we need to + // add the contextSource first. + if (auto interface = dynamic_cast(contextSource.get())) { + interfaces.push_back(interface); + } + if (auto interface = dynamic_cast(profiler)) { + interfaces.push_back(interface); + } + if (auto interface = dynamic_cast(data.get())) { + interfaces.push_back(interface); + } + return interfaces; + } + + const std::string path{}; + size_t id{}; + Profiler *profiler{}; + std::unique_ptr contextSource{}; + std::unique_ptr data{}; + + friend class SessionManager; +}; + +/// A session manager is responsible for managing the lifecycle of sessions. +/// There's a single and unique session manager in the system. +class SessionManager : public Singleton { +public: + SessionManager() = default; + ~SessionManager() = default; + + size_t addSession(const std::string &path, const std::string &profilerName, + const std::string &profilerPath, + const std::string &contextSourceName, + const std::string &dataName); + + void finalizeSession(size_t sessionId, OutputFormat outputFormat); + + void finalizeAllSessions(OutputFormat outputFormat); + + void activateSession(size_t sessionId); + + void activateAllSessions(); + + void deactivateSession(size_t sessionId); + + void deactivateAllSessions(); + + void enterScope(const Scope &scope); + + void exitScope(const Scope &scope); + + void enterOp(const Scope &scope); + + void exitOp(const Scope &scope); + + void addMetrics(size_t scopeId, + const std::map &metrics); + + void setState(std::optional context); + +private: + std::unique_ptr makeSession(size_t id, const std::string &path, + const std::string &profilerName, + const std::string &profilerPath, + const std::string &contextSourceName, + const std::string &dataName); + + void activateSessionImpl(size_t sessionId); + + void deActivateSessionImpl(size_t sessionId); + + size_t getSessionId(const std::string &path) { return sessionPaths[path]; } + + bool hasSession(const std::string &path) { + return sessionPaths.find(path) != sessionPaths.end(); + } + + bool hasSession(size_t sessionId) { + return sessions.find(sessionId) != sessions.end(); + } + + void removeSession(size_t sessionId); + + template + void updateInterfaceCount(size_t sessionId, Counter &interfaceCounts) { + auto interfaces = sessions[sessionId]->getInterfaces(); + for (auto *interface : interfaces) { + auto it = std::find_if( + interfaceCounts.begin(), interfaceCounts.end(), + [interface](const auto &pair) { return pair.first == interface; }); + + if (it != interfaceCounts.end()) { + if constexpr (isRegistering) { + ++it->second; + } else { + --it->second; + if (it->second == 0) { + interfaceCounts.erase(it); + } + } + } else if constexpr (isRegistering) { + interfaceCounts.emplace_back(interface, 1); + } + } + } + + template + void registerInterface(size_t sessionId, Counter &interfaceCounts) { + updateInterfaceCount(sessionId, interfaceCounts); + } + + template + void unregisterInterface(size_t sessionId, Counter &interfaceCounts) { + updateInterfaceCount(sessionId, interfaceCounts); + } + + mutable std::mutex mutex; + + size_t nextSessionId{}; + // path -> session id + std::map sessionPaths; + // session id -> active + std::map sessionActive; + // session id -> session + std::map> sessions; + // {scope, active count} + std::vector> scopeInterfaceCounts; + // {op, active count} + std::vector> opInterfaceCounts; + // {context source, active count} + std::vector> contextSourceCounts; +}; + +} // namespace proton + +#endif // PROTON_SESSION_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Atomic.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Atomic.h new file mode 100644 index 000000000..0f759e0d6 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Atomic.h @@ -0,0 +1,39 @@ +#ifndef PROTON_UTILITY_ATOMIC_H_ +#define PROTON_UTILITY_ATOMIC_H_ + +#include +#include + +namespace proton { + +template T atomicMax(std::atomic &target, T value) { + T current = target.load(); + while (current < value && !target.compare_exchange_weak(current, value)) + ; + return current; +} + +template T atomicMin(std::atomic &target, T value) { + T current = target.load(); + while (current > value && !target.compare_exchange_weak(current, value)) + ; + return current; +} + +template +void doubleCheckedLock(Condition enterCondition, std::mutex &lock, + Function function) { + if (!enterCondition()) + return; + + std::unique_lock guard(lock); + + if (!enterCondition()) + return; + + function(); +} + +} // namespace proton + +#endif // PROTON_UTILITY_ATOMIC_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Errors.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Errors.h new file mode 100644 index 000000000..09c44025d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Errors.h @@ -0,0 +1,15 @@ +#ifndef PROTON_UTILITY_ERRORS_H_ +#define PROTON_UTILITY_ERRORS_H_ + +#include + +namespace proton { + +class NotImplemented : public std::logic_error { +public: + NotImplemented() : std::logic_error("Not yet implemented") {}; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_ERRORS_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Map.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Map.h new file mode 100644 index 000000000..c173d163e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Map.h @@ -0,0 +1,61 @@ +#ifndef PROTON_UTILITY_MAP_H_ +#define PROTON_UTILITY_MAP_H_ + +#include +#include + +namespace proton { + +/// A simple thread safe map with read/write lock. +template > +class ThreadSafeMap { +public: + ThreadSafeMap() = default; + + Value &operator[](const Key &key) { + std::unique_lock lock(mutex); + return map[key]; + } + + Value &operator[](Key &&key) { + std::unique_lock lock(mutex); + return map[std::move(key)]; + } + + Value &at(const Key &key) { + std::shared_lock lock(mutex); + return map.at(key); + } + + void insert(const Key &key, const Value &value) { + std::unique_lock lock(mutex); + map[key] = value; + } + + bool contain(const Key &key) { + std::shared_lock lock(mutex); + auto it = map.find(key); + if (it == map.end()) + return false; + return true; + } + + bool erase(const Key &key) { + std::unique_lock lock(mutex); + return map.erase(key) > 0; + } + + void clear() { + std::unique_lock lock(mutex); + map.clear(); + } + +private: + Container map; + std::shared_mutex mutex; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_MAP_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Set.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Set.h new file mode 100644 index 000000000..50ce165db --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Set.h @@ -0,0 +1,45 @@ +#ifndef PROTON_UTILITY_SET_H_ +#define PROTON_UTILITY_SET_H_ + +#include +#include + +namespace proton { + +/// A simple thread safe set with read/write lock. +template > +class ThreadSafeSet { +public: + ThreadSafeSet() = default; + + void insert(const Key &key) { + std::unique_lock lock(mutex); + set.insert(key); + } + + bool contain(const Key &key) { + std::shared_lock lock(mutex); + auto it = set.find(key); + if (it == set.end()) + return false; + return true; + } + + bool erase(const Key &key) { + std::unique_lock lock(mutex); + return set.erase(key) > 0; + } + + void clear() { + std::unique_lock lock(mutex); + set.clear(); + } + +private: + Container set; + std::shared_mutex mutex; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_MAP_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Singleton.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Singleton.h new file mode 100644 index 000000000..f91fef143 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Singleton.h @@ -0,0 +1,22 @@ +#ifndef PROTON_UTILITY_SINGLETON_H_ +#define PROTON_UTILITY_SINGLETON_H_ + +namespace proton { + +template class Singleton { +public: + Singleton(const Singleton &) = delete; + Singleton &operator=(const Singleton &) = delete; + + static T &instance() { + static T _; + return _; + } + +protected: + Singleton() = default; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_SINGLETON_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/String.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/String.h new file mode 100644 index 000000000..74f34ed7c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/String.h @@ -0,0 +1,49 @@ +#ifndef PROTON_UTILITY_STRING_H_ +#define PROTON_UTILITY_STRING_H_ + +#include + +namespace proton { + +inline std::string toLower(const std::string &str) { + std::string lower; + for (auto c : str) { + lower += tolower(c); + } + return lower; +} + +inline std::string replace(const std::string &str, const std::string &src, + const std::string &dst) { + std::string replaced = str; + size_t pos = replaced.find(src); + while (pos != std::string::npos) { + replaced.replace(pos, src.length(), dst); + pos += dst.length(); + pos = replaced.find(src, pos); + } + return replaced; +} + +inline bool endWith(const std::string &str, const std::string &sub) { + if (str.length() < sub.length()) { + return false; + } + return str.compare(str.length() - sub.length(), sub.length(), sub) == 0; +} + +inline std::string trim(const std::string &str) { + size_t start = 0; + size_t end = str.length(); + while (start < end && isspace(str[start])) { + start++; + } + while (end > start && isspace(str[end - 1])) { + end--; + } + return str.substr(start, end - start); +} + +} // namespace proton + +#endif // PROTON_UTILITY_STRING_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Traits.h b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Traits.h new file mode 100644 index 000000000..bdc43906e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/include/Utility/Traits.h @@ -0,0 +1,12 @@ +#ifndef PROTON_UTILITY_TRAITS_H_ +#define PROTON_UTILITY_TRAITS_H_ + +#include +#include + +namespace proton { +template +struct is_one_of : std::disjunction...> {}; +} // namespace proton + +#endif diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/csrc/lib/CMakeLists.txt new file mode 100644 index 000000000..564901af3 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(Context) +add_subdirectory(Data) +add_subdirectory(Driver) +add_subdirectory(Profiler) +add_subdirectory(Session) diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/CMakeLists.txt new file mode 100644 index 000000000..e2385c9fe --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/CMakeLists.txt @@ -0,0 +1,5 @@ +add_proton_library(ProtonContext + Context.cpp + Python.cpp + Shadow.cpp +) diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/Context.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/Context.cpp new file mode 100644 index 000000000..04e5170d0 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/Context.cpp @@ -0,0 +1,13 @@ +#include "Context/Context.h" + +namespace proton { + +/*static*/ thread_local std::optional ContextSource::state = + std::nullopt; + +std::atomic Scope::scopeIdCounter{1}; + +/*static*/ thread_local std::map + ThreadLocalOpInterface::opInProgress; + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/Python.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/Python.cpp new file mode 100644 index 000000000..3dc3aeb64 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/Python.cpp @@ -0,0 +1,97 @@ +#include "Context/Python.h" +#include "pybind11/pybind11.h" +#include +#include + +namespace proton { + +namespace { + +// bpo-42262 added Py_NewRef() to Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_NewRef) +PyObject *_Py_NewRef(PyObject *obj) { + Py_INCREF(obj); + return obj; +} +#define Py_NewRef(obj) _Py_NewRef((PyObject *)(obj)) +#endif + +// bpo-42262 added Py_XNewRef() to Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_XNewRef) +PyObject *_Py_XNewRef(PyObject *obj) { + Py_XINCREF(obj); + return obj; +} +#define Py_XNewRef(obj) _Py_XNewRef((PyObject *)(obj)) +#endif + +// bpo-40421 added PyFrame_GetCode() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 +PyCodeObject *getFrameCodeObject(PyFrameObject *frame) { + assert(frame != nullptr); + assert(frame->f_code != nullptr); + return (PyCodeObject *)(Py_NewRef(frame->f_code)); +} +#else +PyCodeObject *getFrameCodeObject(PyFrameObject *frame) { + assert(frame != nullptr); + return PyFrame_GetCode(frame); +} +#endif + +// bpo-40421 added PyFrame_GetBack() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 +PyFrameObject *getFrameBack(PyFrameObject *frame) { + assert(frame != nullptr); + return (PyFrameObject *)(Py_XNewRef(frame->f_back)); +} +#else +PyFrameObject *getFrameBack(PyFrameObject *frame) { + assert(frame != nullptr); + return PyFrame_GetBack(frame); +} +#endif + +std::string unpackPyobject(PyObject *pyObject) { + if (PyBytes_Check(pyObject)) { + size_t size = PyBytes_GET_SIZE(pyObject); + return std::string(PyBytes_AS_STRING(pyObject), size); + } + if (PyUnicode_Check(pyObject)) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + Py_ssize_t size; + const char *data = PyUnicode_AsUTF8AndSize(pyObject, &size); + if (!data) { + return ""; + } + return std::string(data, (size_t)size); + } + return ""; +} + +} // namespace + +std::vector PythonContextSource::getContextsImpl() { + pybind11::gil_scoped_acquire gil; + + PyFrameObject *frame = PyEval_GetFrame(); + Py_XINCREF(frame); + + std::vector contexts; + while (frame != nullptr) { + PyCodeObject *f_code = getFrameCodeObject(frame); + size_t lineno = PyFrame_GetLineNumber(frame); + size_t firstLineNo = f_code->co_firstlineno; + std::string file = unpackPyobject(f_code->co_filename); + std::string function = unpackPyobject(f_code->co_name); + auto pythonFrame = file + ":" + function + "@" + std::to_string(lineno); + contexts.push_back(Context(pythonFrame)); + auto newFrame = getFrameBack(frame); + Py_DECREF(frame); + frame = newFrame; + } + std::reverse(contexts.begin(), contexts.end()); + return contexts; +} + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/Shadow.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/Shadow.cpp new file mode 100644 index 000000000..b0a6d2b58 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Context/Shadow.cpp @@ -0,0 +1,44 @@ +#include "Context/Shadow.h" + +#include +#include + +namespace proton { + +void ShadowContextSource::initializeThreadContext() { + if (!mainContextStack) { + mainContextStack = &threadContextStack; + contextInitialized = true; + } + if (!contextInitialized) { + threadContextStack = *mainContextStack; + contextInitialized = true; + } +} + +void ShadowContextSource::enterScope(const Scope &scope) { + initializeThreadContext(); + threadContextStack.push_back(scope); +} + +std::vector ShadowContextSource::getContextsImpl() { + initializeThreadContext(); + return threadContextStack; +} + +void ShadowContextSource::exitScope(const Scope &scope) { + if (threadContextStack.empty()) { + throw std::runtime_error("Context stack is empty"); + } + if (threadContextStack.back() != scope) { + throw std::runtime_error("Context stack is not balanced"); + } + threadContextStack.pop_back(); +} + +/*static*/ thread_local std::vector + ShadowContextSource::threadContextStack; + +/*static*/ thread_local bool ShadowContextSource::contextInitialized = false; + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/CMakeLists.txt new file mode 100644 index 000000000..0f835198c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/CMakeLists.txt @@ -0,0 +1,5 @@ +add_proton_library(ProtonData + Data.cpp + TraceData.cpp + TreeData.cpp +) diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/Data.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/Data.cpp new file mode 100644 index 000000000..73df0a705 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/Data.cpp @@ -0,0 +1,41 @@ +#include "Data/Data.h" +#include "Utility/String.h" + +#include +#include +#include + +#include + +namespace proton { + +void Data::dump(OutputFormat outputFormat) { + std::shared_lock lock(mutex); + + std::unique_ptr out; + if (path.empty() || path == "-") { + out.reset(new std::ostream(std::cout.rdbuf())); // Redirecting to cout + } else { + out.reset(new std::ofstream( + path + "." + + outputFormatToString(outputFormat))); // Opening a file for output + } + doDump(*out, outputFormat); +} + +OutputFormat parseOutputFormat(const std::string &outputFormat) { + if (toLower(outputFormat) == "hatchet") { + return OutputFormat::Hatchet; + } + throw std::runtime_error("Unknown output format: " + outputFormat); +} + +const std::string outputFormatToString(OutputFormat outputFormat) { + if (outputFormat == OutputFormat::Hatchet) { + return "hatchet"; + } + throw std::runtime_error("Unknown output format: " + + std::to_string(static_cast(outputFormat))); +} + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/TraceData.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/TraceData.cpp new file mode 100644 index 000000000..acce21463 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/TraceData.cpp @@ -0,0 +1,31 @@ +#include "Data/TraceData.h" +#include "Utility/Errors.h" + +#include + +namespace proton { + +void TraceData::enterScope(const Scope &scope) { throw NotImplemented(); } + +void TraceData::exitScope(const Scope &scope) { throw NotImplemented(); } + +size_t TraceData::addOp(size_t scopeId, const std::string &name) { + throw NotImplemented(); +} + +void TraceData::addMetric(size_t scopeId, std::shared_ptr metric) { + throw NotImplemented(); +} + +void TraceData::addMetrics( + size_t scopeId, const std::map &metrics) { + throw NotImplemented(); +} + +void TraceData::clear() { throw NotImplemented(); } + +void TraceData::doDump(std::ostream &os, OutputFormat outputFormat) const { + throw NotImplemented(); +} + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/TreeData.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/TreeData.cpp new file mode 100644 index 000000000..614635c78 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Data/TreeData.cpp @@ -0,0 +1,299 @@ +#include "Data/TreeData.h" +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Driver/Device.h" +#include "nlohmann/json.hpp" + +#include +#include +#include +#include +#include + +using json = nlohmann::json; + +namespace proton { + +class TreeData::Tree { +public: + struct TreeNode : public Context { + inline static const size_t RootId = 0; + inline static const size_t DummyId = std::numeric_limits::max(); + + TreeNode() = default; + explicit TreeNode(size_t id, const std::string &name) + : id(id), Context(name) {} + TreeNode(size_t id, size_t parentId, const std::string &name) + : id(id), parentId(parentId), Context(name) {} + virtual ~TreeNode() = default; + + void addChild(const Context &context, size_t id) { children[context] = id; } + + bool hasChild(const Context &context) const { + return children.find(context) != children.end(); + } + + size_t getChild(const Context &context) const { + return children.at(context); + } + + size_t parentId = DummyId; + size_t id = DummyId; + std::map children = {}; + std::map> metrics = {}; + std::map flexibleMetrics = {}; + friend class Tree; + }; + + Tree() { + treeNodeMap.try_emplace(TreeNode::RootId, TreeNode::RootId, "ROOT"); + } + + size_t addNode(const Context &context, size_t parentId) { + if (treeNodeMap[parentId].hasChild(context)) { + return treeNodeMap[parentId].getChild(context); + } + auto id = nextContextId++; + treeNodeMap.try_emplace(id, id, parentId, context.name); + treeNodeMap[parentId].addChild(context, id); + return id; + } + + size_t addNode(const std::vector &indices) { + auto parentId = TreeNode::RootId; + for (auto index : indices) { + parentId = addNode(index, parentId); + } + return parentId; + } + + TreeNode &getNode(size_t id) { return treeNodeMap.at(id); } + + enum class WalkPolicy { PreOrder, PostOrder }; + + template void walk(FnT &&fn) { + if constexpr (walkPolicy == WalkPolicy::PreOrder) { + walkPreOrder(TreeNode::RootId, fn); + } else if constexpr (walkPolicy == WalkPolicy::PostOrder) { + walkPostOrder(TreeNode::RootId, fn); + } + } + + template void walkPreOrder(size_t contextId, FnT &&fn) { + fn(getNode(contextId)); + for (auto &child : getNode(contextId).children) { + walkPreOrder(child.second, fn); + } + } + + template void walkPostOrder(size_t contextId, FnT &&fn) { + for (auto &child : getNode(contextId).children) { + walkPostOrder(child.second, fn); + } + fn(getNode(contextId)); + } + +private: + size_t nextContextId = TreeNode::RootId + 1; + // tree node id -> tree node + std::map treeNodeMap; +}; + +void TreeData::init() { tree = std::make_unique(); } + +void TreeData::enterScope(const Scope &scope) { + // enterOp and addMetric maybe called from different threads + std::unique_lock lock(mutex); + std::vector contexts; + if (contextSource != nullptr) + contexts = contextSource->getContexts(); + auto contextId = tree->addNode(contexts); + scopeIdToContextId[scope.scopeId] = contextId; +} + +void TreeData::exitScope(const Scope &scope) {} + +size_t TreeData::addOp(size_t scopeId, const std::string &name) { + std::unique_lock lock(mutex); + auto scopeIdIt = scopeIdToContextId.find(scopeId); + if (scopeIdIt == scopeIdToContextId.end()) { + // Obtain the current context + std::vector contexts; + if (contextSource != nullptr) + contexts = contextSource->getContexts(); + // Add an op under the current context + if (!name.empty()) + contexts.emplace_back(name); + scopeIdToContextId[scopeId] = tree->addNode(contexts); + } else { + // Add a new context under it and update the context + scopeId = Scope::getNewScopeId(); + scopeIdToContextId[scopeId] = + tree->addNode(Context(name), scopeIdIt->second); + } + return scopeId; +} + +void TreeData::addMetric(size_t scopeId, std::shared_ptr metric) { + std::unique_lock lock(mutex); + auto scopeIdIt = scopeIdToContextId.find(scopeId); + // The profile data is deactivated, ignore the metric + if (scopeIdIt == scopeIdToContextId.end()) + return; + auto contextId = scopeIdIt->second; + auto &node = tree->getNode(contextId); + if (node.metrics.find(metric->getKind()) == node.metrics.end()) + node.metrics.emplace(metric->getKind(), metric); + else + node.metrics[metric->getKind()]->updateMetric(*metric); +} + +void TreeData::addMetrics( + size_t scopeId, const std::map &metrics) { + std::unique_lock lock(mutex); + auto scopeIdIt = scopeIdToContextId.find(scopeId); + // The profile data is deactivated, ignore the metric + if (scopeIdIt == scopeIdToContextId.end()) + return; + auto contextId = scopeIdIt->second; + auto &node = tree->getNode(contextId); + for (auto [metricName, metricValue] : metrics) { + if (node.flexibleMetrics.find(metricName) == node.flexibleMetrics.end()) { + node.flexibleMetrics.emplace(metricName, + FlexibleMetric(metricName, metricValue)); + } else { + node.flexibleMetrics.at(metricName).updateValue(metricValue); + } + } +} + +void TreeData::clear() { + std::unique_lock lock(mutex); + scopeIdToContextId.clear(); +} + +void TreeData::dumpHatchet(std::ostream &os) const { + std::map jsonNodes; + json output = json::array(); + output.push_back(json::object()); + jsonNodes[Tree::TreeNode::RootId] = &(output.back()); + std::set inclusiveValueNames; + std::map> deviceIds; + this->tree->template walk( + [&](Tree::TreeNode &treeNode) { + const auto contextName = treeNode.name; + auto contextId = treeNode.id; + json *jsonNode = jsonNodes[contextId]; + (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; + (*jsonNode)["metrics"] = json::object(); + for (auto [metricKind, metric] : treeNode.metrics) { + if (metricKind == MetricKind::Kernel) { + std::shared_ptr kernelMetric = + std::dynamic_pointer_cast(metric); + uint64_t duration = std::get( + kernelMetric->getValue(KernelMetric::Duration)); + uint64_t invocations = std::get( + kernelMetric->getValue(KernelMetric::Invocations)); + uint64_t deviceId = std::get( + kernelMetric->getValue(KernelMetric::DeviceId)); + uint64_t deviceType = std::get( + kernelMetric->getValue(KernelMetric::DeviceType)); + std::string deviceTypeName = + getDeviceTypeString(static_cast(deviceType)); + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::Duration)] = + duration; + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::Invocations)] = + invocations; + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::DeviceId)] = + std::to_string(deviceId); + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::DeviceType)] = + deviceTypeName; + inclusiveValueNames.insert( + kernelMetric->getValueName(KernelMetric::Duration)); + inclusiveValueNames.insert( + kernelMetric->getValueName(KernelMetric::Invocations)); + deviceIds.insert({deviceType, {deviceId}}); + } else if (metricKind == MetricKind::PCSampling) { + auto pcSamplingMetric = + std::dynamic_pointer_cast(metric); + for (size_t i = 0; i < PCSamplingMetric::Count; i++) { + auto valueName = pcSamplingMetric->getValueName(i); + inclusiveValueNames.insert(valueName); + std::visit( + [&](auto &&value) { + (*jsonNode)["metrics"][valueName] = value; + }, + pcSamplingMetric->getValues()[i]); + } + } else { + throw std::runtime_error("MetricKind not supported"); + } + } + for (auto [_, flexibleMetric] : treeNode.flexibleMetrics) { + auto valueName = flexibleMetric.getValueName(0); + if (!flexibleMetric.isExclusive(0)) + inclusiveValueNames.insert(valueName); + std::visit( + [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, + flexibleMetric.getValues()[0]); + } + (*jsonNode)["children"] = json::array(); + auto children = treeNode.children; + for (auto _ : children) { + (*jsonNode)["children"].push_back(json::object()); + } + auto idx = 0; + for (auto child : children) { + auto [index, childId] = child; + jsonNodes[childId] = &(*jsonNode)["children"][idx]; + idx++; + } + }); + // Hints for all inclusive metrics + for (auto valueName : inclusiveValueNames) { + output[Tree::TreeNode::RootId]["metrics"][valueName] = 0; + } + // Prepare the device information + // Note that this is done from the application thread, + // query device information from the tool thread (e.g., CUPTI) will have + // problems + output.push_back(json::object()); + auto &deviceJson = output.back(); + for (auto [deviceType, deviceIds] : deviceIds) { + auto deviceTypeName = + getDeviceTypeString(static_cast(deviceType)); + if (!deviceJson.contains(deviceTypeName)) + deviceJson[deviceTypeName] = json::object(); + for (auto deviceId : deviceIds) { + Device device = getDevice(static_cast(deviceType), deviceId); + deviceJson[deviceTypeName][std::to_string(deviceId)] = { + {"clock_rate", device.clockRate}, + {"memory_clock_rate", device.memoryClockRate}, + {"bus_width", device.busWidth}, + {"arch", device.arch}, + {"num_sms", device.numSms}}; + } + } + os << std::endl << output.dump(4) << std::endl; +} + +void TreeData::doDump(std::ostream &os, OutputFormat outputFormat) const { + if (outputFormat == OutputFormat::Hatchet) { + dumpHatchet(os); + } else { + std::logic_error("OutputFormat not supported"); + } +} + +TreeData::TreeData(const std::string &path, ContextSource *contextSource) + : Data(path, contextSource) { + init(); +} + +TreeData::~TreeData() {} + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/CMakeLists.txt new file mode 100644 index 000000000..ba81a4d54 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/CMakeLists.txt @@ -0,0 +1,8 @@ +add_proton_library(ProtonDriver + Device.cpp + GPU/CudaApi.cpp + GPU/CuptiApi.cpp + GPU/HipApi.cpp + GPU/HsaApi.cpp + GPU/RoctracerApi.cpp +) diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/Device.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/Device.cpp new file mode 100644 index 000000000..1fb1f2361 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/Device.cpp @@ -0,0 +1,28 @@ +#include "Driver/Device.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/HipApi.h" + +#include "Utility/Errors.h" + +namespace proton { + +Device getDevice(DeviceType type, uint64_t index) { + if (type == DeviceType::CUDA) { + return cuda::getDevice(index); + } + if (type == DeviceType::HIP) { + return hip::getDevice(index); + } + throw std::runtime_error("DeviceType not supported"); +} + +const std::string getDeviceTypeString(DeviceType type) { + if (type == DeviceType::CUDA) { + return DeviceTraits::name; + } else if (type == DeviceType::HIP) { + return DeviceTraits::name; + } + throw std::runtime_error("DeviceType not supported"); +} + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp new file mode 100644 index 000000000..d1617b48a --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp @@ -0,0 +1,61 @@ +#include "Driver/GPU/CudaApi.h" +#include "Driver/Dispatch.h" + +namespace proton { + +namespace cuda { + +struct ExternLibCuda : public ExternLibBase { + using RetType = CUresult; + // https://forums.developer.nvidia.com/t/wsl2-libcuda-so-and-libcuda-so-1-should-be-symlink/236301 + // On WSL, "libcuda.so" and "libcuda.so.1" may not be linked, so we use + // "libcuda.so.1" instead. + static constexpr const char *name = "libcuda.so.1"; + static constexpr const char *defaultDir = ""; + static constexpr RetType success = CUDA_SUCCESS; + static void *lib; +}; + +void *ExternLibCuda::lib = nullptr; + +DEFINE_DISPATCH(ExternLibCuda, init, cuInit, int) + +DEFINE_DISPATCH(ExternLibCuda, ctxSynchronize, cuCtxSynchronize) + +DEFINE_DISPATCH(ExternLibCuda, ctxGetCurrent, cuCtxGetCurrent, CUcontext *) + +DEFINE_DISPATCH(ExternLibCuda, deviceGet, cuDeviceGet, CUdevice *, int) + +DEFINE_DISPATCH(ExternLibCuda, deviceGetAttribute, cuDeviceGetAttribute, int *, + CUdevice_attribute, CUdevice) + +Device getDevice(uint64_t index) { + CUdevice device; + cuda::deviceGet(&device, index); + int clockRate; + cuda::deviceGetAttribute(&clockRate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, + device); + int memoryClockRate; + cuda::deviceGetAttribute(&memoryClockRate, + CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device); + int busWidth; + cuda::deviceGetAttribute( + &busWidth, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device); + int numSms; + cuda::deviceGetAttribute( + &numSms, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device); + int major; + cuda::deviceGetAttribute( + &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); + int minor; + cuda::deviceGetAttribute( + &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); + std::string arch = std::to_string(major * 10 + minor); + + return Device(DeviceType::CUDA, index, clockRate, memoryClockRate, busWidth, + numSms, arch); +} + +} // namespace cuda + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp new file mode 100644 index 000000000..f86db8de2 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp @@ -0,0 +1,117 @@ +#include "Driver/GPU/CuptiApi.h" +#include "Driver/Device.h" +#include "Driver/Dispatch.h" + +namespace proton { + +namespace cupti { + +struct ExternLibCupti : public ExternLibBase { + using RetType = CUptiResult; + static constexpr const char *name = "libcupti.so"; + static inline std::string defaultDir = ""; + static constexpr RetType success = CUPTI_SUCCESS; + static void *lib; +}; + +void *ExternLibCupti::lib = nullptr; + +DEFINE_DISPATCH(ExternLibCupti, getVersion, cuptiGetVersion, uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, getContextId, cuptiGetContextId, CUcontext, + uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, activityRegisterCallbacks, + cuptiActivityRegisterCallbacks, + CUpti_BuffersCallbackRequestFunc, + CUpti_BuffersCallbackCompleteFunc) + +DEFINE_DISPATCH(ExternLibCupti, subscribe, cuptiSubscribe, + CUpti_SubscriberHandle *, CUpti_CallbackFunc, void *) + +DEFINE_DISPATCH(ExternLibCupti, enableDomain, cuptiEnableDomain, uint32_t, + CUpti_SubscriberHandle, CUpti_CallbackDomain) + +DEFINE_DISPATCH(ExternLibCupti, enableCallback, cuptiEnableCallback, uint32_t, + CUpti_SubscriberHandle, CUpti_CallbackDomain, CUpti_CallbackId); + +DEFINE_DISPATCH(ExternLibCupti, activityEnable, cuptiActivityEnable, + CUpti_ActivityKind) + +DEFINE_DISPATCH(ExternLibCupti, activityDisable, cuptiActivityDisable, + CUpti_ActivityKind) + +DEFINE_DISPATCH(ExternLibCupti, activityEnableContext, + cuptiActivityEnableContext, CUcontext, CUpti_ActivityKind) + +DEFINE_DISPATCH(ExternLibCupti, activityDisableContext, + cuptiActivityDisableContext, CUcontext, CUpti_ActivityKind) + +DEFINE_DISPATCH(ExternLibCupti, activityFlushAll, cuptiActivityFlushAll, + uint32_t) + +DEFINE_DISPATCH(ExternLibCupti, activityGetNextRecord, + cuptiActivityGetNextRecord, uint8_t *, size_t, + CUpti_Activity **) + +DEFINE_DISPATCH(ExternLibCupti, activityPushExternalCorrelationId, + cuptiActivityPushExternalCorrelationId, + CUpti_ExternalCorrelationKind, uint64_t) + +DEFINE_DISPATCH(ExternLibCupti, activityPopExternalCorrelationId, + cuptiActivityPopExternalCorrelationId, + CUpti_ExternalCorrelationKind, uint64_t *) + +DEFINE_DISPATCH(ExternLibCupti, activitySetAttribute, cuptiActivitySetAttribute, + CUpti_ActivityAttribute, size_t *, void *) + +DEFINE_DISPATCH(ExternLibCupti, unsubscribe, cuptiUnsubscribe, + CUpti_SubscriberHandle) + +DEFINE_DISPATCH(ExternLibCupti, finalize, cuptiFinalize) + +DEFINE_DISPATCH(ExternLibCupti, getGraphExecId, cuptiGetGraphExecId, + CUgraphExec, uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, getGraphId, cuptiGetGraphId, CUgraph, + uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, getCubinCrc, cuptiGetCubinCrc, + CUpti_GetCubinCrcParams *); + +DEFINE_DISPATCH(ExternLibCupti, getSassToSourceCorrelation, + cuptiGetSassToSourceCorrelation, + CUpti_GetSassToSourceCorrelationParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetNumStallReasons, + cuptiPCSamplingGetNumStallReasons, + CUpti_PCSamplingGetNumStallReasonsParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetStallReasons, + cuptiPCSamplingGetStallReasons, + CUpti_PCSamplingGetStallReasonsParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingSetConfigurationAttribute, + cuptiPCSamplingSetConfigurationAttribute, + CUpti_PCSamplingConfigurationInfoParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingEnable, cuptiPCSamplingEnable, + CUpti_PCSamplingEnableParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingDisable, cuptiPCSamplingDisable, + CUpti_PCSamplingDisableParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetData, cuptiPCSamplingGetData, + CUpti_PCSamplingGetDataParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingStart, cuptiPCSamplingStart, + CUpti_PCSamplingStartParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingStop, cuptiPCSamplingStop, + CUpti_PCSamplingStopParams *); + +void setLibPath(const std::string &path) { ExternLibCupti::defaultDir = path; } + +} // namespace cupti + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp new file mode 100644 index 000000000..9e8ef8d22 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp @@ -0,0 +1,91 @@ +#include "Driver/GPU/HipApi.h" +#include "Driver/Dispatch.h" +#include "hip/hip_runtime_api.h" +#include + +namespace proton { + +namespace hip { + +struct ExternLibHip : public ExternLibBase { + using RetType = hipError_t; + static constexpr const char *name = "libamdhip64.so"; + static constexpr const char *defaultDir = ""; + static constexpr RetType success = hipSuccess; + static void *lib; +}; + +void *ExternLibHip::lib = nullptr; + +DEFINE_DISPATCH(ExternLibHip, deviceSynchronize, hipDeviceSynchronize) + +DEFINE_DISPATCH(ExternLibHip, deviceGetAttribute, hipDeviceGetAttribute, int *, + hipDeviceAttribute_t, int); + +DEFINE_DISPATCH(ExternLibHip, getDeviceCount, hipGetDeviceCount, int *); + +DEFINE_DISPATCH(ExternLibHip, getDeviceProperties, hipGetDeviceProperties, + hipDeviceProp_t *, int); + +Device getDevice(uint64_t index) { + int clockRate; + (void)hip::deviceGetAttribute(&clockRate, hipDeviceAttributeClockRate, + index); + int memoryClockRate; + (void)hip::deviceGetAttribute(&memoryClockRate, + hipDeviceAttributeMemoryClockRate, index); + int busWidth; + (void)hip::deviceGetAttribute(&busWidth, + hipDeviceAttributeMemoryBusWidth, index); + int smCount; + (void)hip::deviceGetAttribute( + &smCount, hipDeviceAttributeMultiprocessorCount, index); + + std::string arch = getHipArchName(index); + + return Device(DeviceType::HIP, index, clockRate, memoryClockRate, busWidth, + smCount, arch); +} + +// TODO: hipDeviceProp_t was updated to point from hipDeviceProp_tR0000 -> +// hipDeviceProp_tR0600 as part of a breaking API change in Rocm 6.0 +// https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/driver.c +// uses hipDeviceProp_tR0000 and imports the hip_deprecated.h header file to be +// be back compatible with ROCm 5.x. PyTorch stills needs to support 5.x and the +// hipDeviceProp_tR0600 symbol does not exist pre-Rocm 6.0. Calling +// hipDeviceProp_tR0000 here with Rocm 6.1 causes a stack corruption. Therefore +// were will use hipDeviceProp_t and investigate if we can unify the definitions +// in the two files. + +const std::string getHipArchName(uint64_t index) { + hipDeviceProp_t devProp; + (void)hip::getDeviceProperties(&devProp, index); + std::string gcnArchName(devProp.gcnArchName); + std::string hipArch = gcnArchName.substr(0, 6); + return hipArch; +} + +const char *getKernelNameRef(const hipFunction_t f) { + typedef const char *(*hipKernelNameRef_t)(const hipFunction_t); + static hipKernelNameRef_t func = nullptr; + Dispatch::init(ExternLibHip::name, &ExternLibHip::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibHip::lib, "hipKernelNameRef")); + return (func ? func(f) : NULL); +} + +const char *getKernelNameRefByPtr(const void *hostFunction, + hipStream_t stream) { + typedef const char *(*hipKernelNameRefByPtr_t)(const void *, hipStream_t); + static hipKernelNameRefByPtr_t func = nullptr; + Dispatch::init(ExternLibHip::name, &ExternLibHip::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibHip::lib, "hipKernelNameRefByPtr")); + return (func ? func(hostFunction, stream) : NULL); +} + +} // namespace hip + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/HsaApi.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/HsaApi.cpp new file mode 100644 index 000000000..7c607b4b9 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/HsaApi.cpp @@ -0,0 +1,36 @@ +#include "Driver/GPU/HsaApi.h" +#include "Driver/Dispatch.h" + +namespace proton { + +namespace hsa { + +struct ExternLibHsa : public ExternLibBase { + using RetType = hsa_status_t; + static constexpr const char *name = "libhsa-runtime64.so"; + static constexpr const char *defaultDir = ""; + static constexpr RetType success = HSA_STATUS_SUCCESS; + static void *lib; +}; + +void *ExternLibHsa::lib = nullptr; + +DEFINE_DISPATCH(ExternLibHsa, agentGetInfo, hsa_agent_get_info, hsa_agent_t, + hsa_agent_info_t, void *); + +hsa_status_t iterateAgents(hsa_status_t (*callback)(hsa_agent_t agent, + void *data), + void *data) { + typedef hsa_status_t (*hsa_iterate_agents_t)( + hsa_status_t (*)(hsa_agent_t, void *), void *data); + static hsa_iterate_agents_t func = nullptr; + Dispatch::init(ExternLibHsa::name, &ExternLibHsa::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibHsa::lib, "hsa_iterate_agents")); + return (func ? func(callback, data) : HSA_STATUS_ERROR_FATAL); +} + +} // namespace hsa + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/RoctracerApi.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/RoctracerApi.cpp new file mode 100644 index 000000000..a6dcdcf34 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Driver/GPU/RoctracerApi.cpp @@ -0,0 +1,105 @@ +#include "Driver/GPU/RoctracerApi.h" +#include "Driver/Dispatch.h" + +namespace proton { + +namespace roctracer { + +struct ExternLibRoctracer : public ExternLibBase { + using RetType = roctracer_status_t; + static constexpr const char *name = "libroctracer64.so"; + static constexpr const char *defaultDir = ""; + static constexpr RetType success = ROCTRACER_STATUS_SUCCESS; + static void *lib; +}; + +void *ExternLibRoctracer::lib = nullptr; + +DEFINE_DISPATCH(ExternLibRoctracer, setProperties, roctracer_set_properties, + roctracer_domain_t, void *) + +DEFINE_DISPATCH(ExternLibRoctracer, getTimestamp, roctracer_get_timestamp, + roctracer_timestamp_t *) + +void start() { + typedef void (*roctracer_start_t)(); + static roctracer_start_t func = nullptr; + Dispatch::init(ExternLibRoctracer::name, + &ExternLibRoctracer::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibRoctracer::lib, "roctracer_start")); + if (func) + func(); +} + +void stop() { + typedef void (*roctracer_stop_t)(); + static roctracer_stop_t func = nullptr; + Dispatch::init(ExternLibRoctracer::name, + &ExternLibRoctracer::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibRoctracer::lib, "roctracer_stop")); + if (func) + func(); +} + +char *getOpString(uint32_t domain, uint32_t op, uint32_t kind) { + typedef char *(*roctracer_op_string_t)(uint32_t, uint32_t, uint32_t); + static roctracer_op_string_t func = nullptr; + Dispatch::init(ExternLibRoctracer::name, + &ExternLibRoctracer::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibRoctracer::lib, "roctracer_op_string")); + return (func ? func(domain, op, kind) : NULL); +} + +DEFINE_DISPATCH(ExternLibRoctracer, enableDomainCallback, + roctracer_enable_domain_callback, activity_domain_t, + activity_rtapi_callback_t, void *) + +DEFINE_DISPATCH(ExternLibRoctracer, disableDomainCallback, + roctracer_disable_domain_callback, activity_domain_t) + +DEFINE_DISPATCH(ExternLibRoctracer, enableOpCallback, + roctracer_enable_op_callback, activity_domain_t, uint32_t, + activity_rtapi_callback_t, void *) + +DEFINE_DISPATCH(ExternLibRoctracer, disableOpCallback, + roctracer_disable_op_callback, activity_domain_t, uint32_t) + +DEFINE_DISPATCH(ExternLibRoctracer, openPool, roctracer_open_pool, + const roctracer_properties_t *) + +DEFINE_DISPATCH(ExternLibRoctracer, closePool, roctracer_close_pool) + +DEFINE_DISPATCH(ExternLibRoctracer, enableOpActivity, + roctracer_enable_op_activity, activity_domain_t, uint32_t) + +DEFINE_DISPATCH(ExternLibRoctracer, enableDomainActivity, + roctracer_enable_domain_activity, activity_domain_t) + +DEFINE_DISPATCH(ExternLibRoctracer, disableOpActivity, + roctracer_disable_op_activity, activity_domain_t, uint32_t) + +DEFINE_DISPATCH(ExternLibRoctracer, disableDomainActivity, + roctracer_disable_domain_activity, activity_domain_t) + +DEFINE_DISPATCH(ExternLibRoctracer, flushActivity, roctracer_flush_activity) + +DEFINE_DISPATCH(ExternLibRoctracer, activityPushExternalCorrelationId, + roctracer_activity_push_external_correlation_id, + activity_correlation_id_t) + +DEFINE_DISPATCH(ExternLibRoctracer, activityPopExternalCorrelationId, + roctracer_activity_pop_external_correlation_id, + activity_correlation_id_t *) + +DEFINE_DISPATCH(ExternLibRoctracer, getNextRecord, roctracer_next_record, + const activity_record_t *, const activity_record_t **) + +} // namespace roctracer + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/CMakeLists.txt new file mode 100644 index 000000000..6d26c37a6 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/CMakeLists.txt @@ -0,0 +1,5 @@ +add_proton_library(ProtonProfiler + Cupti/CuptiPCSampling.cpp + Cupti/CuptiProfiler.cpp + RocTracer/RoctracerProfiler.cpp +) diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp new file mode 100644 index 000000000..294460c42 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp @@ -0,0 +1,459 @@ +#include "Profiler/Cupti/CuptiPCSampling.h" +#include "Data/Metric.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/CuptiApi.h" +#include "Utility/Atomic.h" +#include "Utility/Map.h" +#include "Utility/String.h" +#include +#include +#include + +namespace proton { + +namespace { + +uint64_t getCubinCrc(const char *cubin, size_t size) { + CUpti_GetCubinCrcParams cubinCrcParams = { + /*size=*/CUpti_GetCubinCrcParamsSize, + /*cubinSize=*/size, + /*cubin=*/cubin, + /*cubinCrc=*/0, + }; + cupti::getCubinCrc(&cubinCrcParams); + return cubinCrcParams.cubinCrc; +} + +size_t getNumStallReasons(CUcontext context) { + size_t numStallReasons = 0; + CUpti_PCSamplingGetNumStallReasonsParams numStallReasonsParams = { + /*size=*/CUpti_PCSamplingGetNumStallReasonsParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*numStallReasons=*/&numStallReasons}; + cupti::pcSamplingGetNumStallReasons(&numStallReasonsParams); + return numStallReasons; +} + +std::tuple +getSassToSourceCorrelation(const char *functionName, uint64_t pcOffset, + const char *cubin, size_t cubinSize) { + CUpti_GetSassToSourceCorrelationParams sassToSourceParams = { + /*size=*/CUpti_GetSassToSourceCorrelationParamsSize, + /*cubin=*/cubin, + /*functionName=*/functionName, + /*cubinSize=*/cubinSize, + /*lineNumber=*/0, + /*pcOffset=*/pcOffset, + /*fileName=*/NULL, + /*dirName=*/NULL, + }; + // Get source can fail if the line mapping is not available in the cubin so we + // don't check the return value + cupti::getSassToSourceCorrelation(&sassToSourceParams); + auto fileNameStr = sassToSourceParams.fileName + ? std::string(sassToSourceParams.fileName) + : ""; + auto dirNameStr = + sassToSourceParams.dirName ? std::string(sassToSourceParams.dirName) : ""; + // It's user's responsibility to free the memory + if (sassToSourceParams.fileName) + std::free(sassToSourceParams.fileName); + if (sassToSourceParams.dirName) + std::free(sassToSourceParams.dirName); + return std::make_tuple(sassToSourceParams.lineNumber, fileNameStr, + dirNameStr); +} + +std::pair +getStallReasonNamesAndIndices(CUcontext context, size_t numStallReasons) { + char **stallReasonNames = + static_cast(std::calloc(numStallReasons, sizeof(char *))); + for (size_t i = 0; i < numStallReasons; i++) { + stallReasonNames[i] = static_cast( + std::calloc(CUPTI_STALL_REASON_STRING_SIZE, sizeof(char))); + } + uint32_t *stallReasonIndices = + static_cast(std::calloc(numStallReasons, sizeof(uint32_t))); + // Initialize the names with 128 characters to avoid buffer overflow + CUpti_PCSamplingGetStallReasonsParams stallReasonsParams = { + /*size=*/CUpti_PCSamplingGetStallReasonsParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*numStallReasons=*/numStallReasons, + /*stallReasonIndex=*/stallReasonIndices, + /*stallReasons=*/stallReasonNames, + }; + cupti::pcSamplingGetStallReasons(&stallReasonsParams); + return std::make_pair(stallReasonNames, stallReasonIndices); +} + +size_t matchStallReasonsToIndices( + size_t numStallReasons, char **stallReasonNames, + uint32_t *stallReasonIndices, + std::map &stallReasonIndexToMetricIndex, + std::set ¬IssuedStallReasonIndices) { + // In case there's any invalid stall reasons, we only collect valid ones. + // Invalid ones are swapped to the end of the list + std::vector validIndex(numStallReasons, false); + size_t numValidStalls = 0; + for (size_t i = 0; i < numStallReasons; i++) { + bool notIssued = std::string(stallReasonNames[i]).find("not_issued") != + std::string::npos; + std::string cuptiStallName = std::string(stallReasonNames[i]); + for (size_t j = 0; j < PCSamplingMetric::PCSamplingMetricKind::Count; j++) { + auto metricName = PCSamplingMetric().getValueName(j); + if (cuptiStallName.find(metricName) != std::string::npos) { + if (notIssued) + notIssuedStallReasonIndices.insert(stallReasonIndices[i]); + stallReasonIndexToMetricIndex[stallReasonIndices[i]] = j; + validIndex[i] = true; + numValidStalls++; + break; + } + } + } + int invalidIndex = -1; + for (size_t i = 0; i < numStallReasons; i++) { + if (invalidIndex == -1 && !validIndex[i]) { + invalidIndex = i; + } else if (invalidIndex != -1 && validIndex[i]) { + std::swap(stallReasonIndices[invalidIndex], stallReasonIndices[i]); + std::swap(stallReasonNames[invalidIndex], stallReasonNames[i]); + validIndex[invalidIndex] = true; + invalidIndex++; + } + } + return numValidStalls; +} + +#define CUPTI_CUDA12_4_VERSION 22 +#define CUPTI_CUDA12_4_PC_DATA_PADDING_SIZE sizeof(uint32_t) + +CUpti_PCSamplingData allocPCSamplingData(size_t collectNumPCs, + size_t numValidStallReasons) { + uint32_t libVersion = 0; + cupti::getVersion(&libVersion); + size_t pcDataSize = sizeof(CUpti_PCSamplingPCData); + // Since CUPTI 12.4, a new field (i.e., correlationId) is added to + // CUpti_PCSamplingPCData, which breaks the ABI compatibility. + // Instead of using workarounds, we emit an error message and exit the + // application. + if ((libVersion < CUPTI_CUDA12_4_VERSION && + CUPTI_API_VERSION >= CUPTI_CUDA12_4_VERSION) || + (libVersion >= CUPTI_CUDA12_4_VERSION && + CUPTI_API_VERSION < CUPTI_CUDA12_4_VERSION)) { + throw std::runtime_error( + "[PROTON] CUPTI API version: " + std::to_string(CUPTI_API_VERSION) + + " and CUPTI driver version: " + std::to_string(libVersion) + + " are not compatible. Please set the environment variable " + " TRITON_CUPTI_INCLUDE_PATH and TRITON_CUPTI_LIB_PATH to resolve the " + "problem."); + } + CUpti_PCSamplingData pcSamplingData{ + /*size=*/sizeof(CUpti_PCSamplingData), + /*collectNumPcs=*/collectNumPCs, + /*totalSamples=*/0, + /*droppedSamples=*/0, + /*totalNumPcs=*/0, + /*remainingNumPcs=*/0, + /*rangeId=*/0, + /*pPcData=*/ + static_cast( + std::calloc(collectNumPCs, sizeof(CUpti_PCSamplingPCData)))}; + for (size_t i = 0; i < collectNumPCs; ++i) { + pcSamplingData.pPcData[i].stallReason = + static_cast(std::calloc( + numValidStallReasons, sizeof(CUpti_PCSamplingStallReason))); + } + return pcSamplingData; +} + +void enablePCSampling(CUcontext context) { + CUpti_PCSamplingEnableParams params = { + /*size=*/CUpti_PCSamplingEnableParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + }; + cupti::pcSamplingEnable(¶ms); +} + +void disablePCSampling(CUcontext context) { + CUpti_PCSamplingDisableParams params = { + /*size=*/CUpti_PCSamplingDisableParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + }; + cupti::pcSamplingDisable(¶ms); +} + +void startPCSampling(CUcontext context) { + CUpti_PCSamplingStartParams params = { + /*size=*/CUpti_PCSamplingStartParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + }; + cupti::pcSamplingStart(¶ms); +} + +void stopPCSampling(CUcontext context) { + CUpti_PCSamplingStopParams params = { + /*size=*/CUpti_PCSamplingStopParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + }; + cupti::pcSamplingStop(¶ms); +} + +void getPCSamplingData(CUcontext context, + CUpti_PCSamplingData *pcSamplingData) { + CUpti_PCSamplingGetDataParams params = { + /*size=*/CUpti_PCSamplingGetDataParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*pcSamplingData=*/pcSamplingData, + }; + cupti::pcSamplingGetData(¶ms); +} + +void setConfigurationAttribute( + CUcontext context, + std::vector &configurationInfos) { + CUpti_PCSamplingConfigurationInfoParams infoParams = { + /*size=*/CUpti_PCSamplingConfigurationInfoParamsSize, + /*pPriv=*/NULL, + /*ctx=*/context, + /*numAttributes=*/configurationInfos.size(), + /*pPCSamplingConfigurationInfo=*/configurationInfos.data(), + }; + cupti::pcSamplingSetConfigurationAttribute(&infoParams); +} + +} // namespace + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureStallReasons() { + numStallReasons = getNumStallReasons(context); + std::tie(this->stallReasonNames, this->stallReasonIndices) = + getStallReasonNamesAndIndices(context, numStallReasons); + numValidStallReasons = matchStallReasonsToIndices( + numStallReasons, stallReasonNames, stallReasonIndices, + stallReasonIndexToMetricIndex, notIssuedStallReasonIndices); + CUpti_PCSamplingConfigurationInfo stallReasonInfo{}; + stallReasonInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_STALL_REASON; + stallReasonInfo.attributeData.stallReasonData.stallReasonCount = + numValidStallReasons; + stallReasonInfo.attributeData.stallReasonData.pStallReasonIndex = + stallReasonIndices; + return stallReasonInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureSamplingPeriod() { + CUpti_PCSamplingConfigurationInfo samplingPeriodInfo{}; + samplingPeriodInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_PERIOD; + samplingPeriodInfo.attributeData.samplingPeriodData.samplingPeriod = + DefaultFrequency; + return samplingPeriodInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureSamplingBuffer() { + CUpti_PCSamplingConfigurationInfo samplingBufferInfo{}; + samplingBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_DATA_BUFFER; + this->pcSamplingData = + allocPCSamplingData(DataBufferPCCount, numValidStallReasons); + samplingBufferInfo.attributeData.samplingDataBufferData.samplingDataBuffer = + &this->pcSamplingData; + return samplingBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureScratchBuffer() { + CUpti_PCSamplingConfigurationInfo scratchBufferInfo{}; + scratchBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SCRATCH_BUFFER_SIZE; + scratchBufferInfo.attributeData.scratchBufferSizeData.scratchBufferSize = + ScratchBufferSize; + return scratchBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureHardwareBufferSize() { + CUpti_PCSamplingConfigurationInfo hardwareBufferInfo{}; + hardwareBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_HARDWARE_BUFFER_SIZE; + hardwareBufferInfo.attributeData.hardwareBufferSizeData.hardwareBufferSize = + HardwareBufferSize; + return hardwareBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureStartStopControl() { + CUpti_PCSamplingConfigurationInfo startStopControlInfo{}; + startStopControlInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL; + startStopControlInfo.attributeData.enableStartStopControlData + .enableStartStopControl = true; + return startStopControlInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureCollectionMode() { + CUpti_PCSamplingConfigurationInfo collectionModeInfo{}; + collectionModeInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_COLLECTION_MODE; + collectionModeInfo.attributeData.collectionModeData.collectionMode = + CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS; + return collectionModeInfo; +} + +void ConfigureData::initialize(CUcontext context) { + this->context = context; + cupti::getContextId(context, &contextId); + configurationInfos.emplace_back(configureStallReasons()); + configurationInfos.emplace_back(configureSamplingPeriod()); + configurationInfos.emplace_back(configureHardwareBufferSize()); + configurationInfos.emplace_back(configureScratchBuffer()); + configurationInfos.emplace_back(configureSamplingBuffer()); + configurationInfos.emplace_back(configureStartStopControl()); + configurationInfos.emplace_back(configureCollectionMode()); + setConfigurationAttribute(context, configurationInfos); +} + +ConfigureData *CuptiPCSampling::getConfigureData(uint32_t contextId) { + return &contextIdToConfigureData[contextId]; +} + +CubinData *CuptiPCSampling::getCubinData(uint64_t cubinCrc) { + return &(cubinCrcToCubinData[cubinCrc].first); +} + +void CuptiPCSampling::initialize(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() { return !contextInitialized.contain(contextId); }, + contextMutex, + [&]() { + enablePCSampling(context); + getConfigureData(contextId)->initialize(context); + contextInitialized.insert(contextId); + }); +} + +void CuptiPCSampling::start(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() -> bool { return !pcSamplingStarted; }, + pcSamplingMutex, + [&]() { + initialize(context); + // Ensure all previous operations are completed + cuda::ctxSynchronize(); + startPCSampling(context); + pcSamplingStarted = true; + }); +} + +void CuptiPCSampling::processPCSamplingData(ConfigureData *configureData, + uint64_t externId, bool isAPI) { + auto *pcSamplingData = &configureData->pcSamplingData; + auto &profiler = CuptiProfiler::instance(); + auto dataSet = profiler.getDataSet(); + // In the first round, we need to call getPCSamplingData to get the unsynced + // data from the hardware buffer + bool firstRound = true; + while (pcSamplingData->totalNumPcs > 0 || + pcSamplingData->remainingNumPcs > 0 || firstRound) { + // Handle data + for (size_t i = 0; i < pcSamplingData->totalNumPcs; ++i) { + auto *pcData = pcSamplingData->pPcData + i; + auto *cubinData = getCubinData(pcData->cubinCrc); + auto key = + CubinData::LineInfoKey{pcData->functionIndex, pcData->pcOffset}; + if (cubinData->lineInfo.find(key) == cubinData->lineInfo.end()) { + auto [lineNumber, fileName, dirName] = + getSassToSourceCorrelation(pcData->functionName, pcData->pcOffset, + cubinData->cubin, cubinData->cubinSize); + cubinData->lineInfo.try_emplace(key, lineNumber, + std::string(pcData->functionName), + dirName, fileName); + } + auto &lineInfo = cubinData->lineInfo[key]; + for (size_t j = 0; j < pcData->stallReasonCount; ++j) { + auto *stallReason = &pcData->stallReason[j]; + if (!configureData->stallReasonIndexToMetricIndex.count( + stallReason->pcSamplingStallReasonIndex)) + throw std::runtime_error("[PROTON] Invalid stall reason index"); + for (auto *data : dataSet) { + auto scopeId = externId; + if (isAPI) + scopeId = data->addOp(externId, lineInfo.functionName); + if (lineInfo.fileName.size()) + scopeId = data->addOp( + scopeId, lineInfo.dirName + "/" + lineInfo.fileName + ":" + + std::to_string(lineInfo.lineNumber) + "@" + + lineInfo.functionName); + auto metricKind = static_cast( + configureData->stallReasonIndexToMetricIndex + [stallReason->pcSamplingStallReasonIndex]); + auto samples = stallReason->samples; + auto stalledSamples = + configureData->notIssuedStallReasonIndices.count( + stallReason->pcSamplingStallReasonIndex) + ? 0 + : samples; + auto metric = std::make_shared(metricKind, samples, + stalledSamples); + data->addMetric(scopeId, metric); + } + } + } + if (pcSamplingData->remainingNumPcs > 0 || firstRound) { + getPCSamplingData(configureData->context, pcSamplingData); + firstRound = false; + } else + break; + } +} + +void CuptiPCSampling::stop(CUcontext context, uint64_t externId, bool isAPI) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() -> bool { return pcSamplingStarted; }, + pcSamplingMutex, + [&]() { + auto *configureData = getConfigureData(contextId); + stopPCSampling(context); + pcSamplingStarted = false; + processPCSamplingData(configureData, externId, isAPI); + }); +} + +void CuptiPCSampling::finalize(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + if (!contextInitialized.contain(contextId)) + return; + auto *configureData = getConfigureData(contextId); + contextIdToConfigureData.erase(contextId); + contextInitialized.erase(contextId); + disablePCSampling(context); +} + +void CuptiPCSampling::loadModule(const char *cubin, size_t cubinSize) { + auto cubinCrc = getCubinCrc(cubin, cubinSize); + auto *cubinData = getCubinData(cubinCrc); + cubinData->cubinCrc = cubinCrc; + cubinData->cubinSize = cubinSize; + cubinData->cubin = cubin; +} + +void CuptiPCSampling::unloadModule(const char *cubin, size_t cubinSize) { + // XXX: Unload module is supposed to be called in a thread safe manner + // i.e., no two threads will be calling unload module the same time + auto cubinCrc = getCubinCrc(cubin, cubinSize); + auto count = cubinCrcToCubinData[cubinCrc].second; + if (count > 1) + cubinCrcToCubinData[cubinCrc].second = count - 1; + else + cubinCrcToCubinData.erase(cubinCrc); +} + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp new file mode 100644 index 000000000..2c60b536d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp @@ -0,0 +1,438 @@ +#include "Profiler/Cupti/CuptiProfiler.h" +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Driver/Device.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/CuptiApi.h" +#include "Profiler/Cupti/CuptiPCSampling.h" +#include "Utility/Map.h" + +#include +#include +#include +#include + +namespace proton { + +template <> +thread_local GPUProfiler::ThreadState + GPUProfiler::threadState(CuptiProfiler::instance()); + +template <> +thread_local std::deque + GPUProfiler::Correlation::externIdQueue{}; + +namespace { + +std::shared_ptr convertActivityToMetric(CUpti_Activity *activity) { + std::shared_ptr metric; + switch (activity->kind) { + case CUPTI_ACTIVITY_KIND_KERNEL: + case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: { + auto *kernel = reinterpret_cast(activity); + if (kernel->start < kernel->end) { + metric = std::make_shared( + static_cast(kernel->start), + static_cast(kernel->end), 1, + static_cast(kernel->deviceId), + static_cast(DeviceType::CUDA)); + } // else: not a valid kernel activity + break; + } + default: + break; + } + return metric; +} + +uint32_t +processActivityKernel(CuptiProfiler::CorrIdToExternIdMap &corrIdToExternId, + CuptiProfiler::ApiExternIdSet &apiExternIds, + std::set &dataSet, CUpti_Activity *activity) { + // Support CUDA >= 11.0 + auto *kernel = reinterpret_cast(activity); + auto correlationId = kernel->correlationId; + if (/*Not a valid context*/ !corrIdToExternId.contain(correlationId)) + return correlationId; + auto [parentId, numInstances] = corrIdToExternId.at(correlationId); + if (kernel->graphId == 0) { + // Non-graph kernels + for (auto *data : dataSet) { + auto scopeId = parentId; + if (apiExternIds.contain(scopeId)) { + // It's triggered by a CUDA op but not triton op + scopeId = data->addOp(parentId, kernel->name); + } + data->addMetric(scopeId, convertActivityToMetric(activity)); + } + } else { + // Graph kernels + // A single graph launch can trigger multiple kernels. + // Our solution is to construct the following maps: + // --- Application threads --- + // 1. graphId -> numKernels + // 2. graphExecId -> graphId + // --- CUPTI thread --- + // 3. corrId -> numKernels + for (auto *data : dataSet) { + auto externId = data->addOp(parentId, kernel->name); + data->addMetric(externId, convertActivityToMetric(activity)); + } + } + apiExternIds.erase(parentId); + --numInstances; + if (numInstances == 0) { + corrIdToExternId.erase(correlationId); + } else { + corrIdToExternId[correlationId].second = numInstances; + } + return correlationId; +} + +uint32_t processActivity(CuptiProfiler::CorrIdToExternIdMap &corrIdToExternId, + CuptiProfiler::ApiExternIdSet &apiExternIds, + std::set &dataSet, CUpti_Activity *activity) { + auto correlationId = 0; + switch (activity->kind) { + case CUPTI_ACTIVITY_KIND_KERNEL: + case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: { + correlationId = processActivityKernel(corrIdToExternId, apiExternIds, + dataSet, activity); + break; + } + default: + break; + } + return correlationId; +} + +void setRuntimeCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { +#define CALLBACK_ENABLE(id) \ + cupti::enableCallback(static_cast(enable), subscriber, \ + CUPTI_CB_DOMAIN_RUNTIME_API, id) + + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_v3020); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_ptsz_v7000); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_ptsz_v7000); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernelExC_v11060); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernelExC_ptsz_v11060); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000); + CALLBACK_ENABLE( + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_ptsz_v9000); + CALLBACK_ENABLE( + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaGraphLaunch_v10000); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaGraphLaunch_ptsz_v10000); + +#undef CALLBACK_ENABLE +} + +void setDriverCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { +#define CALLBACK_ENABLE(id) \ + cupti::enableCallback(static_cast(enable), subscriber, \ + CUPTI_CB_DOMAIN_DRIVER_API, id) + + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunch); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchGridAsync); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel_ptsz); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx_ptsz); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel_ptsz); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz); +#undef CALLBACK_ENABLE +} + +void setGraphCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { + +#define CALLBACK_ENABLE(id) \ + cupti::enableCallback(static_cast(enable), subscriber, \ + CUPTI_CB_DOMAIN_RESOURCE, id) + + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPHNODE_CREATED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPHNODE_CLONED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPHNODE_DESTROY_STARTING); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPHEXEC_CREATED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPHEXEC_DESTROY_STARTING); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPH_DESTROY_STARTING); +#undef CALLBACK_ENABLE +} + +void setResourceCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { +#define CALLBACK_ENABLE(id) \ + cupti::enableCallback(static_cast(enable), subscriber, \ + CUPTI_CB_DOMAIN_RESOURCE, id) + + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_MODULE_LOADED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_MODULE_UNLOAD_STARTING); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_CONTEXT_CREATED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING); +#undef CALLBACK_ENABLE +} + +bool isDriverAPILaunch(CUpti_CallbackId cbId) { + return cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunch || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchGridAsync || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel_ptsz || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx_ptsz || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel_ptsz || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice || + cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch || + cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz; +} + +} // namespace + +struct CuptiProfiler::CuptiProfilerPimpl + : public GPUProfiler::GPUProfilerPimplInterface { + CuptiProfilerPimpl(CuptiProfiler &profiler) + : GPUProfiler::GPUProfilerPimplInterface(profiler) {} + virtual ~CuptiProfilerPimpl() = default; + + void setLibPath(const std::string &libPath) override { + cupti::setLibPath(libPath); + } + void doStart() override; + void doFlush() override; + void doStop() override; + + static void allocBuffer(uint8_t **buffer, size_t *bufferSize, + size_t *maxNumRecords); + static void completeBuffer(CUcontext context, uint32_t streamId, + uint8_t *buffer, size_t size, size_t validSize); + static void callbackFn(void *userData, CUpti_CallbackDomain domain, + CUpti_CallbackId cbId, const void *cbData); + + static constexpr size_t AlignSize = 8; + static constexpr size_t BufferSize = 64 * 1024 * 1024; + static constexpr size_t AttributeSize = sizeof(size_t); + + CUpti_SubscriberHandle subscriber{}; + CuptiPCSampling pcSampling; + + ThreadSafeMap> + graphIdToNumInstances; + ThreadSafeMap> + graphExecIdToGraphId; +}; + +void CuptiProfiler::CuptiProfilerPimpl::allocBuffer(uint8_t **buffer, + size_t *bufferSize, + size_t *maxNumRecords) { + *buffer = static_cast(aligned_alloc(AlignSize, BufferSize)); + if (*buffer == nullptr) { + throw std::runtime_error("[PROTON] aligned_alloc failed"); + } + *bufferSize = BufferSize; + *maxNumRecords = 0; +} + +void CuptiProfiler::CuptiProfilerPimpl::completeBuffer(CUcontext ctx, + uint32_t streamId, + uint8_t *buffer, + size_t size, + size_t validSize) { + CuptiProfiler &profiler = threadState.profiler; + auto dataSet = profiler.getDataSet(); + uint32_t maxCorrelationId = 0; + CUptiResult status; + CUpti_Activity *activity = nullptr; + do { + status = cupti::activityGetNextRecord(buffer, validSize, &activity); + if (status == CUPTI_SUCCESS) { + auto correlationId = + processActivity(profiler.correlation.corrIdToExternId, + profiler.correlation.apiExternIds, dataSet, activity); + maxCorrelationId = std::max(maxCorrelationId, correlationId); + } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { + break; + } else { + throw std::runtime_error("[PROTON] cupti::activityGetNextRecord failed"); + } + } while (true); + + std::free(buffer); + + profiler.correlation.complete(maxCorrelationId); +} + +void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData, + CUpti_CallbackDomain domain, + CUpti_CallbackId cbId, + const void *cbData) { + CuptiProfiler &profiler = threadState.profiler; + if (domain == CUPTI_CB_DOMAIN_RESOURCE) { + auto *resourceData = + static_cast(const_cast(cbData)); + auto *pImpl = dynamic_cast(profiler.pImpl.get()); + if (cbId == CUPTI_CBID_RESOURCE_MODULE_LOADED) { + auto *moduleResource = static_cast( + resourceData->resourceDescriptor); + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.loadModule(moduleResource->pCubin, + moduleResource->cubinSize); + } + } else if (cbId == CUPTI_CBID_RESOURCE_MODULE_UNLOAD_STARTING) { + auto *moduleResource = static_cast( + resourceData->resourceDescriptor); + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.unloadModule(moduleResource->pCubin, + moduleResource->cubinSize); + } + } else if (cbId == CUPTI_CBID_RESOURCE_CONTEXT_CREATED) { + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.initialize(resourceData->context); + } + } else if (cbId == CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING) { + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.finalize(resourceData->context); + } + } else { + auto *graphData = + static_cast(resourceData->resourceDescriptor); + uint32_t graphId = 0; + uint32_t graphExecId = 0; + if (graphData->graph) + cupti::getGraphId(graphData->graph, &graphId); + if (graphData->graphExec) + cupti::getGraphExecId(graphData->graphExec, &graphExecId); + if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_CREATED || + cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_CLONED) { + if (!pImpl->graphIdToNumInstances.contain(graphId)) + pImpl->graphIdToNumInstances[graphId] = 1; + else + pImpl->graphIdToNumInstances[graphId]++; + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_DESTROY_STARTING) { + pImpl->graphIdToNumInstances[graphId]--; + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHEXEC_CREATED) { + pImpl->graphExecIdToGraphId[graphExecId] = graphId; + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHEXEC_DESTROY_STARTING) { + pImpl->graphExecIdToGraphId.erase(graphExecId); + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPH_DESTROY_STARTING) { + pImpl->graphIdToNumInstances.erase(graphId); + } + } + } else { + const CUpti_CallbackData *callbackData = + static_cast(cbData); + auto *pImpl = dynamic_cast(profiler.pImpl.get()); + if (callbackData->callbackSite == CUPTI_API_ENTER) { + threadState.enterOp(); + size_t numInstances = 1; + if (cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch || + cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz) { + auto graphExec = static_cast( + callbackData->functionParams) + ->hGraph; + uint32_t graphExecId = 0; + cupti::getGraphExecId(graphExec, &graphExecId); + numInstances = std::numeric_limits::max(); + auto findGraph = false; + if (pImpl->graphExecIdToGraphId.contain(graphExecId)) { + auto graphId = pImpl->graphExecIdToGraphId[graphExecId]; + if (pImpl->graphIdToNumInstances.contain(graphId)) { + numInstances = pImpl->graphIdToNumInstances[graphId]; + findGraph = true; + } + } + if (!findGraph) + std::cerr << "[PROTON] Cannot find graph for graphExecId: " + << graphExecId + << ", and t may cause memory leak. To avoid this problem, " + "please start profiling before the graph is created." + << std::endl; + } + profiler.correlation.correlate(callbackData->correlationId, numInstances); + if (profiler.isPCSamplingEnabled() && isDriverAPILaunch(cbId)) { + pImpl->pcSampling.start(callbackData->context); + } + } else if (callbackData->callbackSite == CUPTI_API_EXIT) { + if (profiler.isPCSamplingEnabled() && isDriverAPILaunch(cbId)) { + // XXX: Conservatively stop every GPU kernel for now + auto scopeId = profiler.correlation.externIdQueue.back(); + pImpl->pcSampling.stop( + callbackData->context, scopeId, + profiler.correlation.apiExternIds.contain(scopeId)); + } + threadState.exitOp(); + profiler.correlation.submit(callbackData->correlationId); + } + } +} + +void CuptiProfiler::CuptiProfilerPimpl::doStart() { + cupti::subscribe(&subscriber, callbackFn, nullptr); + if (profiler.isPCSamplingEnabled()) { + setResourceCallbacks(subscriber, /*enable=*/true); + // Continuous PC sampling is not compatible with concurrent kernel profiling + cupti::activityEnable(CUPTI_ACTIVITY_KIND_KERNEL); + } else { + cupti::activityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); + } + cupti::activityRegisterCallbacks(allocBuffer, completeBuffer); + setGraphCallbacks(subscriber, /*enable=*/true); + setRuntimeCallbacks(subscriber, /*enable=*/true); + setDriverCallbacks(subscriber, /*enable=*/true); +} + +void CuptiProfiler::CuptiProfilerPimpl::doFlush() { + // cuptiActivityFlushAll returns the activity records associated with all + // contexts/streams. + // This is a blocking call but it doesn’t issue any CUDA synchronization calls + // implicitly thus it’s not guaranteed that all activities are completed on + // the underlying devices. + // We do an "opportunistic" synchronization here to try to ensure that all + // activities are completed on the current context. + // If the current context is not set, we don't do any synchronization. + CUcontext cuContext = nullptr; + cuda::ctxGetCurrent(&cuContext); + if (cuContext) { + cuda::ctxSynchronize(); + } + profiler.correlation.flush( + /*maxRetries=*/100, /*sleepMs=*/10, + /*flush=*/[]() { + cupti::activityFlushAll( + /*flag=*/0); + }); + // CUPTI_ACTIVITY_FLAG_FLUSH_FORCED is used to ensure that even incomplete + // activities are flushed so that the next profiling session can start with + // new activities. + cupti::activityFlushAll(/*flag=*/CUPTI_ACTIVITY_FLAG_FLUSH_FORCED); +} + +void CuptiProfiler::CuptiProfilerPimpl::doStop() { + if (profiler.isPCSamplingEnabled()) { + profiler.disablePCSampling(); + CUcontext cuContext = nullptr; + cuda::ctxGetCurrent(&cuContext); + if (cuContext) + pcSampling.finalize(cuContext); + setResourceCallbacks(subscriber, /*enable=*/false); + cupti::activityDisable(CUPTI_ACTIVITY_KIND_KERNEL); + } else { + cupti::activityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); + } + setGraphCallbacks(subscriber, /*enable=*/false); + setRuntimeCallbacks(subscriber, /*enable=*/false); + setDriverCallbacks(subscriber, /*enable=*/false); + cupti::unsubscribe(subscriber); + cupti::finalize(); +} + +CuptiProfiler::CuptiProfiler() { + pImpl = std::make_unique(*this); +} + +CuptiProfiler::~CuptiProfiler() = default; + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp new file mode 100644 index 000000000..f5d66907e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp @@ -0,0 +1,401 @@ +#include "Profiler/Roctracer/RoctracerProfiler.h" +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Driver/GPU/HipApi.h" +#include "Driver/GPU/HsaApi.h" +#include "Driver/GPU/RoctracerApi.h" + +#include "hip/amd_detail/hip_runtime_prof.h" +#include "roctracer/roctracer_ext.h" +#include "roctracer/roctracer_hip.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace proton { + +template <> +thread_local GPUProfiler::ThreadState + GPUProfiler::threadState(RoctracerProfiler::instance()); + +template <> +thread_local std::deque + GPUProfiler::Correlation::externIdQueue{}; + +namespace { + +class DeviceInfo : public Singleton { +public: + DeviceInfo() = default; + int mapDeviceId(int id) { + // Lazy initialization of device offset by calling hip API. + // Otherwise on nvidia platforms, the HSA call will fail because of no + // available libraries. + std::call_once(deviceOffsetFlag, [this]() { initDeviceOffset(); }); + return id - deviceOffset; + } + +private: + void initDeviceOffset() { + int dc = 0; + auto ret = hip::getDeviceCount(&dc); + hsa::iterateAgents( + [](hsa_agent_t agent, void *data) { + auto &offset = *static_cast(data); + int nodeId; + hsa::agentGetInfo( + agent, + static_cast(HSA_AMD_AGENT_INFO_DRIVER_NODE_ID), + &nodeId); + int deviceType; + hsa::agentGetInfo( + agent, static_cast(HSA_AGENT_INFO_DEVICE), + &deviceType); + if ((nodeId < offset) && (deviceType == HSA_DEVICE_TYPE_GPU)) + offset = nodeId; + + return HSA_STATUS_SUCCESS; + }, + &deviceOffset); + } + + std::once_flag deviceOffsetFlag; + int deviceOffset = 0x7fffffff; +}; + +std::shared_ptr +convertActivityToMetric(const roctracer_record_t *activity) { + std::shared_ptr metric; + switch (activity->kind) { + case kHipVdiCommandTask: + case kHipVdiCommandKernel: { + if (activity->begin_ns < activity->end_ns) { + metric = std::make_shared( + static_cast(activity->begin_ns), + static_cast(activity->end_ns), 1, + static_cast( + DeviceInfo::instance().mapDeviceId(activity->device_id)), + static_cast(DeviceType::HIP)); + } + break; + } + default: + break; + } + return metric; +} + +void processActivityKernel( + RoctracerProfiler::CorrIdToExternIdMap &corrIdToExternId, size_t externId, + std::set &dataSet, const roctracer_record_t *activity, bool isAPI, + bool isGraph) { + if (externId == Scope::DummyScopeId) + return; + auto correlationId = activity->correlation_id; + auto [parentId, numInstances] = corrIdToExternId.at(correlationId); + if (!isGraph) { + for (auto *data : dataSet) { + auto scopeId = parentId; + if (isAPI) + scopeId = data->addOp(parentId, activity->kernel_name); + data->addMetric(scopeId, convertActivityToMetric(activity)); + } + } else { + // Graph kernels + // A single grpah launch can trigger multiple kernels. + // Our solution is to construct the following maps: + // --- Application threads --- + // 1. Graph -> numKernels + // 2. GraphExec -> Graph + // --- Roctracer thread --- + // 3. corrId -> numKernels + for (auto *data : dataSet) { + auto externId = data->addOp(parentId, activity->kernel_name); + data->addMetric(externId, convertActivityToMetric(activity)); + } + } + --numInstances; + if (numInstances == 0) { + corrIdToExternId.erase(correlationId); + } else { + corrIdToExternId[correlationId].second = numInstances; + } + return; +} + +void processActivity(RoctracerProfiler::CorrIdToExternIdMap &corrIdToExternId, + RoctracerProfiler::ApiExternIdSet &apiExternIds, + size_t externId, std::set &dataSet, + const roctracer_record_t *record, bool isAPI, + bool isGraph) { + switch (record->kind) { + case kHipVdiCommandTask: + case kHipVdiCommandKernel: { + processActivityKernel(corrIdToExternId, externId, dataSet, record, isAPI, + isGraph); + break; + } + default: + break; + } +} + +} // namespace + +namespace { + +std::pair matchKernelCbId(uint32_t cbId) { + bool isRuntimeApi = false; + bool isDriverApi = false; + switch (cbId) { + // TODO: switch to directly subscribe the APIs + case HIP_API_ID_hipStreamBeginCapture: + case HIP_API_ID_hipStreamEndCapture: + case HIP_API_ID_hipExtLaunchKernel: + case HIP_API_ID_hipExtLaunchMultiKernelMultiDevice: + case HIP_API_ID_hipExtModuleLaunchKernel: + case HIP_API_ID_hipHccModuleLaunchKernel: + case HIP_API_ID_hipLaunchCooperativeKernel: + case HIP_API_ID_hipLaunchCooperativeKernelMultiDevice: + case HIP_API_ID_hipLaunchKernel: + case HIP_API_ID_hipModuleLaunchKernel: + case HIP_API_ID_hipGraphLaunch: + case HIP_API_ID_hipModuleLaunchCooperativeKernel: + case HIP_API_ID_hipModuleLaunchCooperativeKernelMultiDevice: + case HIP_API_ID_hipGraphExecDestroy: + case HIP_API_ID_hipGraphInstantiateWithFlags: + case HIP_API_ID_hipGraphInstantiate: { + isRuntimeApi = true; + break; + } + default: + break; + } + return std::make_pair(isRuntimeApi, isDriverApi); +} + +} // namespace + +struct RoctracerProfiler::RoctracerProfilerPimpl + : public GPUProfiler::GPUProfilerPimplInterface { + RoctracerProfilerPimpl(RoctracerProfiler &profiler) + : GPUProfiler::GPUProfilerPimplInterface(profiler) {} + virtual ~RoctracerProfilerPimpl() = default; + + void setLibPath(const std::string &libPath) override {} + void doStart() override; + void doFlush() override; + void doStop() override; + + static void apiCallback(uint32_t domain, uint32_t cid, + const void *callbackData, void *arg); + static void activityCallback(const char *begin, const char *end, void *arg); + + static constexpr size_t BufferSize = 64 * 1024 * 1024; + + ThreadSafeMap> + CorrIdToIsHipGraph; + + ThreadSafeMap> + GraphExecToGraph; + + ThreadSafeMap> + GraphToNumInstances; + + ThreadSafeMap> + StreamToCaptureCount; + + ThreadSafeMap> + StreamToCapture; +}; + +void RoctracerProfiler::RoctracerProfilerPimpl::apiCallback( + uint32_t domain, uint32_t cid, const void *callbackData, void *arg) { + auto [isRuntimeAPI, isDriverAPI] = matchKernelCbId(cid); + + if (!(isRuntimeAPI || isDriverAPI)) { + return; + } + + auto &profiler = + dynamic_cast(RoctracerProfiler::instance()); + auto *pImpl = dynamic_cast( + profiler.pImpl.get()); + if (domain == ACTIVITY_DOMAIN_HIP_API) { + const hip_api_data_t *data = (const hip_api_data_t *)(callbackData); + if (data->phase == ACTIVITY_API_PHASE_ENTER) { + // Valid context and outermost level of the kernel launch + threadState.enterOp(); + size_t numInstances = 1; + if (cid == HIP_API_ID_hipGraphLaunch) { + pImpl->CorrIdToIsHipGraph[data->correlation_id] = true; + hipGraphExec_t GraphExec = data->args.hipGraphLaunch.graphExec; + numInstances = std::numeric_limits::max(); + bool findGraph = false; + if (pImpl->GraphExecToGraph.contain(GraphExec)) { + hipGraph_t Graph = pImpl->GraphExecToGraph[GraphExec]; + if (pImpl->GraphToNumInstances.contain(Graph)) { + numInstances = pImpl->GraphToNumInstances[Graph]; + findGraph = true; + } + } + if (!findGraph) + std::cerr + << "[PROTON] Cannot find graph and it may cause a memory leak." + "To avoid this problem, please start profiling before the " + "graph is created." + << std::endl; + } + profiler.correlation.correlate(data->correlation_id, numInstances); + } else if (data->phase == ACTIVITY_API_PHASE_EXIT) { + switch (cid) { + case HIP_API_ID_hipStreamBeginCapture: { + hipStream_t Stream = data->args.hipStreamBeginCapture.stream; + pImpl->StreamToCaptureCount[Stream] = 0; + pImpl->StreamToCapture[Stream] = true; + break; + } + case HIP_API_ID_hipStreamEndCapture: { + hipGraph_t Graph = *(data->args.hipStreamEndCapture.pGraph); + hipStream_t Stream = data->args.hipStreamEndCapture.stream; + // How many times did we capture a kernel launch for this stream + uint32_t StreamCaptureCount = pImpl->StreamToCaptureCount[Stream]; + pImpl->GraphToNumInstances[Graph] = StreamCaptureCount; + pImpl->StreamToCapture.erase(Stream); + } + case HIP_API_ID_hipLaunchKernel: { + hipStream_t Stream = data->args.hipLaunchKernel.stream; + if (pImpl->StreamToCapture.contain(Stream)) + pImpl->StreamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipExtLaunchKernel: { + hipStream_t Stream = data->args.hipExtLaunchKernel.stream; + if (pImpl->StreamToCapture.contain(Stream)) + pImpl->StreamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipLaunchCooperativeKernel: { + hipStream_t Stream = data->args.hipLaunchCooperativeKernel.stream; + if (pImpl->StreamToCapture.contain(Stream)) + pImpl->StreamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipModuleLaunchKernel: { + hipStream_t Stream = data->args.hipModuleLaunchKernel.stream; + if (pImpl->StreamToCapture.contain(Stream)) + pImpl->StreamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipModuleLaunchCooperativeKernel: { + hipStream_t Stream = data->args.hipModuleLaunchCooperativeKernel.stream; + if (pImpl->StreamToCapture.contain(Stream)) + pImpl->StreamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipGraphInstantiateWithFlags: { + hipGraph_t Graph = data->args.hipGraphInstantiateWithFlags.graph; + hipGraphExec_t GraphExec = + *(data->args.hipGraphInstantiateWithFlags.pGraphExec); + pImpl->GraphExecToGraph[GraphExec] = Graph; + break; + } + case HIP_API_ID_hipGraphInstantiate: { + hipGraph_t Graph = data->args.hipGraphInstantiate.graph; + hipGraphExec_t GraphExec = *(data->args.hipGraphInstantiate.pGraphExec); + pImpl->GraphExecToGraph[GraphExec] = Graph; + break; + } + } + threadState.exitOp(); + // Track outstanding op for flush + profiler.correlation.submit(data->correlation_id); + } + } +} + +void RoctracerProfiler::RoctracerProfilerPimpl::activityCallback( + const char *begin, const char *end, void *arg) { + auto &profiler = + dynamic_cast(RoctracerProfiler::instance()); + auto *pImpl = dynamic_cast( + profiler.pImpl.get()); + auto dataSet = profiler.getDataSet(); + auto &correlation = profiler.correlation; + + const roctracer_record_t *record = + reinterpret_cast(begin); + const roctracer_record_t *endRecord = + reinterpret_cast(end); + uint64_t maxCorrelationId = 0; + + while (record != endRecord) { + // Log latest completed correlation id. Used to ensure we have flushed all + // data on stop + maxCorrelationId = + std::max(maxCorrelationId, record->correlation_id); + // TODO(Keren): Roctracer doesn't support cuda graph yet. + auto externId = + correlation.corrIdToExternId.contain(record->correlation_id) + ? correlation.corrIdToExternId.at(record->correlation_id).first + : Scope::DummyScopeId; + auto isAPI = correlation.apiExternIds.contain(externId); + bool isGraph = pImpl->CorrIdToIsHipGraph.contain(record->correlation_id); + processActivity(correlation.corrIdToExternId, correlation.apiExternIds, + externId, dataSet, record, isAPI, isGraph); + // Track correlation ids from the same stream and erase those < + // correlationId + correlation.corrIdToExternId.erase(record->correlation_id); + correlation.apiExternIds.erase(externId); + roctracer::getNextRecord(record, &record); + } + correlation.complete(maxCorrelationId); +} + +void RoctracerProfiler::RoctracerProfilerPimpl::doStart() { + roctracer::enableDomainCallback(ACTIVITY_DOMAIN_HIP_API, apiCallback, + nullptr); + // Activity Records + roctracer_properties_t properties{0}; + properties.buffer_size = BufferSize; + properties.buffer_callback_fun = activityCallback; + roctracer::openPool(&properties); + roctracer::enableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS); + roctracer::start(); +} + +void RoctracerProfiler::RoctracerProfilerPimpl::doFlush() { + // Implement reliable flushing. + // Wait for all dispatched ops to be reported. + std::ignore = hip::deviceSynchronize(); + // If flushing encounters an activity record still being written, flushing + // stops. Use a subsequent flush when the record has completed being written + // to resume the flush. + profiler.correlation.flush( + /*maxRetries=*/100, /*sleepMs=*/10, /*flush=*/ + []() { roctracer::flushActivity(); }); +} + +void RoctracerProfiler::RoctracerProfilerPimpl::doStop() { + roctracer::stop(); + roctracer::disableDomainCallback(ACTIVITY_DOMAIN_HIP_API); + roctracer::disableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS); + roctracer::closePool(); +} + +RoctracerProfiler::RoctracerProfiler() { + pImpl = std::make_unique(*this); +} + +RoctracerProfiler::~RoctracerProfiler() = default; + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Session/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Session/CMakeLists.txt new file mode 100644 index 000000000..f84eb610a --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Session/CMakeLists.txt @@ -0,0 +1,3 @@ +add_proton_library(ProtonSession + Session.cpp +) diff --git a/third_party/enflame/include/triton/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Session/Session.cpp new file mode 100644 index 000000000..26f0fbf89 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/csrc/lib/Session/Session.cpp @@ -0,0 +1,243 @@ +#include "Session/Session.h" +#include "Context/Python.h" +#include "Context/Shadow.h" +#include "Data/TreeData.h" +#include "Profiler/Cupti/CuptiProfiler.h" +#include "Profiler/Roctracer/RoctracerProfiler.h" +#include "Utility/String.h" + +namespace proton { + +namespace { +Profiler *getProfiler(const std::string &name, const std::string &path) { + if (proton::toLower(name) == "cupti") { + return &CuptiProfiler::instance().setLibPath(path); + } + if (proton::toLower(name) == "cupti_pcsampling") { + return &CuptiProfiler::instance().setLibPath(path).enablePCSampling(); + } + if (proton::toLower(name) == "roctracer") { + return &RoctracerProfiler::instance(); + } + throw std::runtime_error("Unknown profiler: " + name); +} + +std::unique_ptr makeData(const std::string &dataName, + const std::string &path, + ContextSource *contextSource) { + if (toLower(dataName) == "tree") { + return std::make_unique(path, contextSource); + } + throw std::runtime_error("Unknown data: " + dataName); +} + +std::unique_ptr +makeContextSource(const std::string &contextSourceName) { + if (toLower(contextSourceName) == "shadow") { + return std::make_unique(); + } else if (toLower(contextSourceName) == "python") { + return std::make_unique(); + } + throw std::runtime_error("Unknown context source: " + contextSourceName); +} + +void throwIfSessionNotInitialized( + const std::map> &sessions, + size_t sessionId) { + if (!sessions.count(sessionId)) { + throw std::runtime_error("Session has not been initialized: " + + std::to_string(sessionId)); + } +} + +} // namespace + +void Session::activate() { + profiler->start(); + profiler->flush(); + profiler->registerData(data.get()); +} + +void Session::deactivate() { + profiler->flush(); + profiler->unregisterData(data.get()); + data->clear(); +} + +void Session::finalize(OutputFormat outputFormat) { + profiler->stop(); + data->dump(outputFormat); +} + +std::unique_ptr SessionManager::makeSession( + size_t id, const std::string &path, const std::string &profilerName, + const std::string &profilerPath, const std::string &contextSourceName, + const std::string &dataName) { + auto profiler = getProfiler(profilerName, profilerPath); + auto contextSource = makeContextSource(contextSourceName); + auto data = makeData(dataName, path, contextSource.get()); + auto *session = new Session(id, path, profiler, std::move(contextSource), + std::move(data)); + return std::unique_ptr(session); +} + +void SessionManager::activateSession(size_t sessionId) { + std::lock_guard lock(mutex); + activateSessionImpl(sessionId); +} + +void SessionManager::activateAllSessions() { + std::lock_guard lock(mutex); + for (auto iter : sessionActive) { + activateSessionImpl(iter.first); + } +} + +void SessionManager::deactivateSession(size_t sessionId) { + std::lock_guard lock(mutex); + deActivateSessionImpl(sessionId); +} + +void SessionManager::deactivateAllSessions() { + std::lock_guard lock(mutex); + for (auto iter : sessionActive) { + deActivateSessionImpl(iter.first); + } +} + +void SessionManager::activateSessionImpl(size_t sessionId) { + throwIfSessionNotInitialized(sessions, sessionId); + if (sessionActive[sessionId]) + return; + sessionActive[sessionId] = true; + sessions[sessionId]->activate(); + registerInterface(sessionId, scopeInterfaceCounts); + registerInterface(sessionId, opInterfaceCounts); + registerInterface(sessionId, contextSourceCounts); +} + +void SessionManager::deActivateSessionImpl(size_t sessionId) { + throwIfSessionNotInitialized(sessions, sessionId); + if (!sessionActive[sessionId]) { + return; + } + sessionActive[sessionId] = false; + sessions[sessionId]->deactivate(); + unregisterInterface(sessionId, scopeInterfaceCounts); + unregisterInterface(sessionId, opInterfaceCounts); + unregisterInterface(sessionId, contextSourceCounts); +} + +void SessionManager::removeSession(size_t sessionId) { + if (!hasSession(sessionId)) { + return; + } + auto path = sessions[sessionId]->path; + sessionPaths.erase(path); + sessionActive.erase(sessionId); + sessions.erase(sessionId); +} + +size_t SessionManager::addSession(const std::string &path, + const std::string &profilerName, + const std::string &profilerPath, + const std::string &contextSourceName, + const std::string &dataName) { + std::lock_guard lock(mutex); + if (hasSession(path)) { + auto sessionId = getSessionId(path); + activateSessionImpl(sessionId); + return sessionId; + } + auto sessionId = nextSessionId++; + sessionPaths[path] = sessionId; + sessions[sessionId] = makeSession(sessionId, path, profilerName, profilerPath, + contextSourceName, dataName); + return sessionId; +} + +void SessionManager::finalizeSession(size_t sessionId, + OutputFormat outputFormat) { + std::lock_guard lock(mutex); + if (!hasSession(sessionId)) { + return; + } + deActivateSessionImpl(sessionId); + sessions[sessionId]->finalize(outputFormat); + removeSession(sessionId); +} + +void SessionManager::finalizeAllSessions(OutputFormat outputFormat) { + std::lock_guard lock(mutex); + auto sessionIds = std::vector{}; + for (auto &[sessionId, session] : sessions) { + deActivateSessionImpl(sessionId); + session->finalize(outputFormat); + sessionIds.push_back(sessionId); + } + for (auto sessionId : sessionIds) { + removeSession(sessionId); + } +} + +void SessionManager::enterScope(const Scope &scope) { + std::lock_guard lock(mutex); + for (auto iter : scopeInterfaceCounts) { + auto [scopeInterface, count] = iter; + if (count > 0) { + scopeInterface->enterScope(scope); + } + } +} + +void SessionManager::exitScope(const Scope &scope) { + std::lock_guard lock(mutex); + for (auto iter : scopeInterfaceCounts) { + auto [scopeInterface, count] = iter; + if (count > 0) { + scopeInterface->exitScope(scope); + } + } +} + +void SessionManager::enterOp(const Scope &scope) { + std::lock_guard lock(mutex); + for (auto iter : opInterfaceCounts) { + auto [opInterface, count] = iter; + if (count > 0) { + opInterface->enterOp(scope); + } + } +} + +void SessionManager::exitOp(const Scope &scope) { + std::lock_guard lock(mutex); + for (auto iter : opInterfaceCounts) { + auto [opInterface, count] = iter; + if (count > 0) { + opInterface->exitOp(scope); + } + } +} + +void SessionManager::addMetrics( + size_t scopeId, const std::map &metrics) { + std::lock_guard lock(mutex); + for (auto [sessionId, active] : sessionActive) { + if (active) { + sessions[sessionId]->data->addMetrics(scopeId, metrics); + } + } +} + +void SessionManager::setState(std::optional context) { + std::lock_guard lock(mutex); + for (auto iter : contextSourceCounts) { + auto [contextSource, count] = iter; + if (count > 0) { + contextSource->setState(context); + } + } +} + +} // namespace proton diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/dialect/CMakeLists.txt new file mode 100644 index 000000000..cfa593887 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/CMakeLists.txt @@ -0,0 +1,8 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc) + target_link_libraries(TritonProton PRIVATE ProtonIR Python3::Module pybind11::headers) +endif() diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/include/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/dialect/include/CMakeLists.txt new file mode 100644 index 000000000..0ca0f41c5 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/CMakeLists.txt new file mode 100644 index 000000000..f18c30ba1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Proton) diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 000000000..4645b0ebc --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS ProtonOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(ProtonDialect ProtonDialect dialects/ -gen-dialect-doc) +add_mlir_doc(ProtonOps ProtonOps dialects/ -gen-op-doc) +add_public_tablegen_target(ProtonTableGen) + +set(LLVM_TARGET_DEFINITIONS ProtonAttrDefs.td) +mlir_tablegen(ProtonAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(ProtonAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(ProtonAttrDefsIncGen) diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h new file mode 100644 index 000000000..680a205f0 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h @@ -0,0 +1,23 @@ +#ifndef TRITON_DIALECT_PROTON_IR_DIALECT_H_ +#define TRITON_DIALECT_PROTON_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc" +#include "proton/dialect/include/Dialect/Proton/IR/OpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "proton/dialect/include/Dialect/Proton/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace proton {} // namespace proton +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_PROTON_IR_DIALECT_H_ diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td new file mode 100644 index 000000000..d469fbb35 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td @@ -0,0 +1,12 @@ +#ifndef PROTON_ATTRDEFS +#define PROTON_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "ProtonDialect.td" + +class Proton_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +#endif // PROTON_ATTRDEFS diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td new file mode 100644 index 000000000..245f2e09a --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td @@ -0,0 +1,18 @@ +#ifndef PROTON_DIALECT +#define PROTON_DIALECT + +include "mlir/IR/OpBase.td" + +def Proton_Dialect : Dialect { + let name = "proton"; + let cppNamespace = "::mlir::triton::proton"; + + let description = [{ + Proton Dialect provides core ops for building third-party compiler-based + performance profiling and analysis tools. + }]; + + let dependentDialects = []; +} + +#endif diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td new file mode 100644 index 000000000..d18a48d5d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td @@ -0,0 +1,65 @@ +#ifndef PROTON_OPS +#define PROTON_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "ProtonDialect.td" +include "ProtonAttrDefs.td" + +class TT_Proton_Op traits = []> : + Op { +} + +// Proton profiling metric. +def MetricAttr : I32EnumAttr< + "Metric", "", + [ + I32EnumAttrCase<"CYCLE", 0, "cycle">, + ]> { + let cppNamespace = "::mlir::triton::proton"; +} + +// Proton profiling granularity. +def GranularityAttr : I32EnumAttr< + "Granularity", "", + [ + I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">, + I32EnumAttrCase<"WARP", 1, "warp">, + ]> { + let cppNamespace = "::mlir::triton::proton"; +} + +def TT_RecordOp : TT_Proton_Op<"record", [DeclareOpInterfaceMethods]> { + let summary = "Record a GPU hardware event"; + + let description = [{ + The operator records GPU events from performance counters. + Currently only cycle counter is supported. + + Example: + + ```mlir + proton.record() {isStart = true, regionId = 4 : i32} + ... + proton.record() {isStart = false, regionId = 4 : i32} + ... + proton.record() {isStart = true, regionId = 1 : i32, granularity = 1 : i32} + ... + proton.record() {isStart = false, regionId = 1 : i32, granularity = 1 : i32} + ``` + }]; + let arguments = ( + ins BoolAttr: $isStart, + ConfinedAttr:$regionId, + DefaultValuedAttr:$metric, + DefaultValuedAttr:$granularity + ); + let assemblyFormat = " `(` operands `)` attr-dict"; +} + +#endif // PROTON_OPS diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h b/third_party/enflame/include/triton/third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h new file mode 100644 index 000000000..a123e8fbf --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h @@ -0,0 +1,16 @@ +#ifndef TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" + +namespace mlir::triton { +class TargetInfoBase; +namespace proton { +void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +} // namespace proton +} // namespace mlir::triton + +#endif diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/lib/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/dialect/lib/CMakeLists.txt new file mode 100644 index 000000000..a224fd6f2 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/lib/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Dialect) +add_subdirectory(TritonProtonToLLVM) diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..f18c30ba1 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Proton) diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 000000000..5eea5cb3c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(ProtonIR + Dialect.cpp + Ops.cpp + + DEPENDS + ProtonTableGen + ProtonAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRLLVMDialect + TritonIR + TritonGPUIR +) diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp b/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp new file mode 100644 index 000000000..60c285265 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp @@ -0,0 +1,25 @@ +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +// clang-format off +#include "Dialect/Proton/IR/Dialect.h" +#include "Dialect/Proton/IR/Dialect.cpp.inc" +// clang-format on + +using namespace mlir; +using namespace mlir::triton::proton; + +void mlir::triton::proton::ProtonDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "Dialect/Proton/IR/Ops.cpp.inc" + >(); +} + +#define GET_ATTRDEF_CLASSES +#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc" diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp b/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp new file mode 100644 index 000000000..1a0799aea --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp @@ -0,0 +1,33 @@ +#include "Dialect/Proton/IR/Dialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#define GET_OP_CLASSES +#include "Dialect/Proton/IR/Ops.cpp.inc" +#include "Dialect/Proton/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { +namespace proton { + +// -- RecordOp -- +void RecordOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace proton +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt b/third_party/enflame/include/triton/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt new file mode 100644 index 000000000..84b134fda --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt @@ -0,0 +1,6 @@ +add_triton_library(TritonProtonToLLVM + RecordOpToLLVM.cpp + + LINK_LIBS PUBLIC + ProtonIR +) diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/lib/TritonProtonToLLVM/RecordOpToLLVM.cpp b/third_party/enflame/include/triton/third_party/proton/dialect/lib/TritonProtonToLLVM/RecordOpToLLVM.cpp new file mode 100644 index 000000000..9b0b08ed7 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/lib/TritonProtonToLLVM/RecordOpToLLVM.cpp @@ -0,0 +1,41 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" +#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h" + +namespace { + +struct RecordOpConversion + : public ConvertOpToLLVMPattern { + explicit RecordOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern( + typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::triton::proton::RecordOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.eraseOp(op); + return success(); + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::proton::populateRecordOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/enflame/include/triton/third_party/proton/dialect/triton_proton.cc b/third_party/enflame/include/triton/third_party/proton/dialect/triton_proton.cc new file mode 100644 index 000000000..804653979 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/dialect/triton_proton.cc @@ -0,0 +1,20 @@ +#include "Dialect/Proton/IR/Dialect.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include +#include +#include + +namespace py = pybind11; + +void init_triton_proton(py::module &&m) { + auto passes = m.def_submodule("passes"); + + // load dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); +} diff --git a/third_party/enflame/include/triton/third_party/proton/proton/_C/include b/third_party/enflame/include/triton/third_party/proton/proton/_C/include new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/include/triton/third_party/proton/proton/__init__.py b/third_party/enflame/include/triton/third_party/proton/proton/__init__.py new file mode 100644 index 000000000..161235e80 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/proton/__init__.py @@ -0,0 +1,11 @@ +# ruff: noqa +from .scope import scope, cpu_timed_scope, enter_scope, exit_scope +from .state import state, enter_state, exit_state +from .profile import ( + start, + activate, + deactivate, + finalize, + profile, + DEFAULT_PROFILE_NAME, +) diff --git a/third_party/enflame/include/triton/third_party/proton/proton/flags.py b/third_party/enflame/include/triton/third_party/proton/proton/flags.py new file mode 100644 index 000000000..37c75b243 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/proton/flags.py @@ -0,0 +1,33 @@ +""" +This file contains the global flags used in the proton package. +""" + +# Whether to enable profiling. Default is False. +profiling_on = False +# Whether the script is run from the command line. Default is False. +command_line = False + + +def set_profiling_on(): + global profiling_on + profiling_on = True + + +def set_profiling_off(): + global profiling_on + profiling_on = False + + +def get_profiling_on(): + global profiling_on + return profiling_on + + +def set_command_line(): + global command_line + command_line = True + + +def is_command_line(): + global command_line + return command_line diff --git a/third_party/enflame/include/triton/third_party/proton/proton/hook.py b/third_party/enflame/include/triton/third_party/proton/proton/hook.py new file mode 100644 index 000000000..e40e1b38c --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/proton/hook.py @@ -0,0 +1,34 @@ +from .state import enter_state, exit_state +from .scope import enter_scope, exit_scope +from triton.compiler import CompiledKernel, LazyDict + +COMPUTE_METADATA_SCOPE_NAME = "__proton_launch_metadata" + + +class TritonHook: + flops_width = [8, 16, 32, 64] + metrics = [f"flops{width}" for width in flops_width] + ["bytes"] + ["flops"] + + @staticmethod + def enter(lazy_dict: LazyDict) -> None: + enter_state(COMPUTE_METADATA_SCOPE_NAME) + metadata = lazy_dict.get() + exit_state() + fn_metrics = {k: metadata[k] for k in TritonHook.metrics if k in metadata} + enter_scope(metadata["name"], triton_op=True, metrics=fn_metrics) + + @staticmethod + def exit(lazy_dict: LazyDict) -> None: + exit_scope(triton_op=True) + + +def register_triton_hook() -> None: + if CompiledKernel.launch_enter_hook is None: + CompiledKernel.launch_enter_hook = TritonHook.enter + CompiledKernel.launch_exit_hook = TritonHook.exit + + +def unregister_triton_hook() -> None: + if CompiledKernel.launch_enter_hook == TritonHook.enter: + CompiledKernel.launch_enter_hook = None + CompiledKernel.launch_exit_hook = None diff --git a/third_party/enflame/include/triton/third_party/proton/proton/language.py b/third_party/enflame/include/triton/third_party/proton/proton/language.py new file mode 100644 index 000000000..b88934b21 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/proton/language.py @@ -0,0 +1,11 @@ +from triton.language import core as tl +from triton.language.core import builtin +import warnings + + +@builtin +def record(isStart: bool, regionId: int, _builder=None): + warnings.warn( + "\nWarning the proton language module within Proton contains under development features that are not intended to be used outside of the core development team" + ) + return tl.tensor(_builder.create_proton_record(isStart, regionId), tl.void) diff --git a/third_party/enflame/include/triton/third_party/proton/proton/profile.py b/third_party/enflame/include/triton/third_party/proton/proton/profile.py new file mode 100644 index 000000000..5ee01f7b4 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/proton/profile.py @@ -0,0 +1,226 @@ +import functools +import triton +import os +import pathlib + +from triton._C.libproton import proton as libproton +from .hook import register_triton_hook, unregister_triton_hook +from .flags import set_profiling_off, set_profiling_on, is_command_line +from typing import Optional + +DEFAULT_PROFILE_NAME = "proton" + + +def _select_backend() -> str: + backend = triton.runtime.driver.active.get_current_target().backend + if backend == "cuda": + return "cupti" + elif backend == "hip": + return "roctracer" + else: + raise ValueError("No backend is available for the current target.") + + +def _get_backend_default_path(backend: str) -> str: + lib_path = "" + if backend == "cupti": + # First try to get the path from the environment variable that overrides the default path + lib_path = os.getenv("TRITON_CUPTI_LIB_PATH", None) + if lib_path is None: + # Get the default path for the cupti backend, + # which is the most compatible with the current CUPTI header file triton is compiled with + lib_path = str(pathlib.Path(__file__).parent.parent.absolute() / "backends" / "nvidia" / "lib" / "cupti") + return lib_path + + +def _check_env(backend: str) -> None: + if backend == "roctracer": + hip_device_envs = ["HIP_VISIBLE_DEVICES", "CUDA_VISIBLE_DEVICES"] + for env in hip_device_envs: + if os.getenv(env, None) is not None: + raise ValueError( + f"Proton does not work when the environment variable {env} is set on AMD GPUs. Please unset it and use `ROCR_VISIBLE_DEVICES` instead" + ) + + +def start( + name: Optional[str] = None, + *, + context: Optional[str] = "shadow", + data: Optional[str] = "tree", + backend: Optional[str] = None, + hook: Optional[str] = None, +): + """ + Start profiling with the given name and backend. + + Usage: + + ```python + proton.start("my_profile") + # do something + proton.finalize() + ``` + + Args: + name (str, optional): The name (with path) of the profiling session. + If not provided, the default name is "~/proton.hatchet". + backend (str, optional): The backend to use for profiling. + Available options are [None, "cupti", "cupti_pcsampling", "roctracer"]. + Defaults to None, which automatically selects the backend matching the current active runtime. + context (str, optional): The context to use for profiling. + Available options are ["shadow", "python"]. + Defaults to "shadow". + data (str, optional): The data structure to use for profiling. + Available options are ["tree"]. + Defaults to "tree". + hook (str, optional): The hook to use for profiling. + Available options are [None, "triton"]. + Defaults to None. + Returns: + session (int): The session ID of the profiling session. + """ + if is_command_line(): + # Ignore the start() call if the script is run from the command line. + return + + if name is None: + name = DEFAULT_PROFILE_NAME + + if backend is None: + backend = _select_backend() + + _check_env(backend) + + backend_path = _get_backend_default_path(backend) + + set_profiling_on() + if hook and hook == "triton": + register_triton_hook() + return libproton.start(name, context, data, backend, backend_path) + + +def activate(session: Optional[int] = None) -> None: + """ + Activate the specified session. + The profiling session will be active and data will be recorded. + + Args: + session (int): The session ID of the profiling session. Defaults to None (all sessions) + + Returns: + None + """ + if is_command_line() and session != 0: + raise ValueError("Only one session can be activated when running from the command line.") + if session is None: + libproton.activate_all() + else: + libproton.activate(session) + + +def deactivate(session: Optional[int] = None) -> None: + """ + Stop the specified session. + The profiling session's data will still be in the memory, but no more data will be recorded. + + Args: + session (int): The session ID of the profiling session. Defaults to None (all sessions) + + Returns: + None + """ + if is_command_line() and session != 0: + raise ValueError("Only one session can be deactivated when running from the command line.") + if session is None: + libproton.deactivate_all() + else: + libproton.deactivate(session) + + +def finalize(session: Optional[int] = None, output_format: str = "hatchet") -> None: + """ + Finalizes a profiling session. + Flush and write the profiling data to the file specified by the session name. + + Args: + session (int, optional): The session ID to finalize. If None, all sessions are finalized. Defaults to None. + output_format (str, optional): The output format for the profiling results. + Aavailable options are ["hatchet"]. + + Returns: + None + """ + if session is None: + set_profiling_off() + libproton.finalize_all(output_format) + unregister_triton_hook() + else: + if is_command_line() and session != 0: + raise ValueError("Only one session can be finalized when running from the command line.") + libproton.finalize(session, output_format) + + +def _profiling( + func, + name: Optional[str] = None, + context: Optional[str] = "shadow", + data: Optional[str] = "tree", + backend: Optional[str] = None, + hook: Optional[str] = None, +): + """ + Context manager for profiling. Internally use only. + + Args: + See start() for the arguments. + + Returns: + wrapper (function): The wrapped function. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + session = start(name, context=context, data=data, backend=backend, hook=hook) + ret = func(*args, **kwargs) + deactivate(session) + return ret + + return wrapper + + +def profile( + func=None, + *, + name: Optional[str] = None, + context: Optional[str] = "shadow", + data: Optional[str] = "tree", + backend: Optional[str] = None, + hook: Optional[str] = None, +): + """ + Decorator for profiling. + + Usage: + + ```python + @proton.profile + def foo(): + pass + ``` + + Args: + See start() for the arguments. + + Returns: + decorator (function): The decorator function. + """ + if func is None: + # It's being used with parentheses, so return a decorator + def decorator(f): + return _profiling(f, name=name, context=context, data=data, backend=backend, hook=hook) + + return decorator + else: + # It's being used without parentheses, so apply the decorator directly + return _profiling(func, name=name, context=context, data=data, backend=backend, hook=hook) diff --git a/third_party/enflame/include/triton/third_party/proton/proton/proton.py b/third_party/enflame/include/triton/third_party/proton/proton/proton.py new file mode 100644 index 000000000..0eacc850e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/proton/proton.py @@ -0,0 +1,105 @@ +import argparse +import sys +import os +import pathlib +from .profile import start, finalize, _select_backend +from .flags import set_command_line +import triton + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="The proton command utility for profiling scripts and pytest tests.", usage=""" + proton [options] script.py [script_args] [script_options] + proton [options] pytest [pytest_args] [script_options] + python -m triton.profiler.proton [options] script.py [script_args] [script_options] +""", formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("-n", "--name", type=str, help="Name of the profiling session") + parser.add_argument("-b", "--backend", type=str, help="Profiling backend", default=None, + choices=["cupti", "cupti_pcsampling", "roctracer"]) + parser.add_argument("-c", "--context", type=str, help="Profiling context", default="shadow", + choices=["shadow", "python"]) + parser.add_argument("-d", "--data", type=str, help="Profiling data", default="tree", choices=["tree"]) + parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "triton"]) + parser.add_argument("-i", "--instrument", type=str, help="Instrumentation analysis type", default=None, + choices=[None, "print-mem-spaces"]) + parser.add_argument('target_args', nargs=argparse.REMAINDER, help='Subcommand and its arguments') + args = parser.parse_args() + return args, args.target_args + + +def is_pytest(script): + return os.path.basename(script) == 'pytest' + + +def execute_as_main(script, args, instrumentation_pass=None): + script_path = os.path.abspath(script) + # Prepare a clean global environment + clean_globals = { + "__name__": "__main__", + "__file__": script_path, + "__builtins__": __builtins__, + sys.__name__: sys, + } + + original_argv = sys.argv + sys.argv = [script] + args + # Append the script's directory in case the script uses relative imports + sys.path.append(os.path.dirname(script_path)) + top_level_triton_path = os.path.dirname(triton.__file__) + + if instrumentation_pass == "print-mem-spaces": + instrumentation_pass_path = str( + next(pathlib.Path(top_level_triton_path).rglob("libPrintLoadStoreMemSpaces.so"), None)) + os.environ['TRITON_ALWAYS_COMPILE'] = "1" + os.environ['TRITON_DISABLE_LINE_INFO'] = "0" + os.environ['LLVM_PASS_PLUGIN_PATH'] = instrumentation_pass_path + + # Execute in the isolated environment + try: + with open(script_path, 'rb') as file: + code = compile(file.read(), script_path, 'exec') + exec(code, clean_globals) + except Exception as e: + print(f"An error occurred while executing the script: {e}") + finally: + sys.argv = original_argv + + +def do_setup_and_execute(target_args, instrumentation_pass=None): + # Set the command line mode to avoid any `start` calls in the script. + set_command_line() + + script = target_args[0] + script_args = target_args[1:] if len(target_args) > 1 else [] + if is_pytest(script): + import pytest + pytest.main(script_args) + else: + execute_as_main(script, script_args, instrumentation_pass) + + +def run_profiling(args, target_args): + backend = args.backend if args.backend else _select_backend() + + start(args.name, context=args.context, data=args.data, backend=backend, hook=args.hook) + + do_setup_and_execute(target_args) + + finalize() + + +def run_instrumentation(args, target_args): + do_setup_and_execute(target_args, args.instrument) + + +def main(): + args, target_args = parse_arguments() + if args.instrument: + run_instrumentation(args, target_args) + return + run_profiling(args, target_args) + + +if __name__ == "__main__": + main() diff --git a/third_party/enflame/include/triton/third_party/proton/proton/scope.py b/third_party/enflame/include/triton/third_party/proton/proton/scope.py new file mode 100644 index 000000000..bcd51c97e --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/proton/scope.py @@ -0,0 +1,130 @@ +import threading +import time +from functools import wraps +from typing import Optional, Union + +from .flags import get_profiling_on +from triton._C.libproton import proton as libproton + +thread_local_scopes = threading.local() + +MetricValueType = Union[float, int] + + +class scope: + """ + A context manager and decorator for entering and exiting a scope. + + Usage: + context manager: + ```python + with proton.scope("test0", {metric_name: metric_value}): + foo[1,](x, y) + ``` + + decorator: + ```python + @proton.scope("test0", {metric_name: metric_value}) + def foo(x, y): + ... + ``` + + Args: + name (str): The name of the scope. + metrics (dict[str, float], optional): The metrics of the scope. Default is None. + """ + + def __init__(self, name: str, metrics: Optional[dict[str, MetricValueType]] = None) -> None: + self.name = name + self.metrics = metrics + self.id = None + + def _enter_scope(self): + if not get_profiling_on(): + return + self.id = libproton.record_scope() + libproton.enter_scope(self.id, self.name) + if self.metrics: + libproton.add_metrics(self.id, self.metrics) + + def _exit_scope(self): + if not get_profiling_on() or self.id is None: + return + libproton.exit_scope(self.id, self.name) + + def __enter__(self): + self._enter_scope() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._exit_scope() + + def __call__(self, func): + + @wraps(func) + def wrapper(*args, **kwargs): + self._enter_scope() + try: + return func(*args, **kwargs) + finally: + self._exit_scope() + + return wrapper + + +class cpu_timed_scope(scope): + """ + A scope that measures elapsed time (cpu_time). + + Args: + name (str): The name of the scope. + metrics (dict[str, float], optional): Additional metrics to add. Default is None. + """ + + def __init__(self, name: str, metrics: Optional[dict[str, float]] = None) -> None: + super().__init__(name, metrics) + self.start_time = None + if metrics and "cpu_time" in metrics: + raise ValueError("The metric name 'cpu_time' is reserved.") + + def _enter_scope(self): + if not get_profiling_on(): + return + self.start_time = time.time_ns() + super()._enter_scope() + + def _exit_scope(self): + if not get_profiling_on(): + return + super()._exit_scope() + if self.start_time is not None: + cpu_time = time.time_ns() - self.start_time + libproton.add_metrics(self.id, {"cpu_time (ns)(exc)": cpu_time}) + + +def enter_scope(name: str, *, triton_op: bool = False, metrics: Optional[dict[str, MetricValueType]] = None) -> int: + if not get_profiling_on(): + return -1 + id = libproton.record_scope() + thread_local_scopes.scopes = getattr(thread_local_scopes, "scopes", []) + thread_local_scopes.scopes.append((id, name)) + if triton_op: + libproton.enter_op(id, name) + else: + libproton.enter_scope(id, name) + if metrics: + libproton.add_metrics(id, metrics) + return id + + +def exit_scope(triton_op: bool = False, metrics: Optional[dict[str, MetricValueType]] = None) -> int: + if not get_profiling_on(): + return -1 + id, name = thread_local_scopes.scopes.pop() + if triton_op: + libproton.exit_op(id, name) + else: + libproton.exit_scope(id, name) + if metrics: + libproton.add_metrics(id, metrics) + return id diff --git a/third_party/enflame/include/triton/third_party/proton/proton/state.py b/third_party/enflame/include/triton/third_party/proton/proton/state.py new file mode 100644 index 000000000..dd1e47801 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/proton/state.py @@ -0,0 +1,61 @@ +from triton._C.libproton import proton as libproton +from .flags import get_profiling_on +from functools import wraps + + +class state: + """ + A context manager and decorator for entering and exiting a state. + + Usage: + context manager: + ```python + with proton.state("test0"): + foo[1,](x, y) + ``` + + decorator: + ```python + @proton.state("test0") + def foo(x, y): + ... + ``` + + Args: + name (str): The name of the state. + """ + + def __init__(self, name: str) -> None: + self.name = name + + def __enter__(self): + if not get_profiling_on(): + return self + libproton.enter_state(self.name) + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + if not get_profiling_on(): + return + libproton.exit_state() + + def __call__(self, func): + + @wraps(func) + def wrapper(*args, **kwargs): + if get_profiling_on(): + libproton.enter_state(self.name) + ret = func(*args, **kwargs) + if get_profiling_on(): + libproton.exit_state() + return ret + + return wrapper + + +def enter_state(name: str) -> None: + libproton.enter_state(name) + + +def exit_state() -> None: + libproton.exit_state() diff --git a/third_party/enflame/include/triton/third_party/proton/proton/viewer.py b/third_party/enflame/include/triton/third_party/proton/proton/viewer.py new file mode 100644 index 000000000..896408704 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/proton/viewer.py @@ -0,0 +1,414 @@ +import argparse +from collections import namedtuple +import json +import pandas as pd +try: + import hatchet as ht + from hatchet.query import NegationQuery +except ImportError: + raise ImportError("Failed to import hatchet. `pip install llnl-hatchet` to get the correct version.") +import numpy as np +from triton.profiler.hook import COMPUTE_METADATA_SCOPE_NAME, TritonHook + + +def match_available_metrics(metrics, inclusive_metrics, exclusive_metrics): + ret = [] + if not isinstance(metrics, list): + metrics = [metrics] + if metrics: + for metric in metrics: + metric = metric.lower() + for raw_metric in inclusive_metrics + exclusive_metrics: + suffix = " (inc)" if raw_metric in inclusive_metrics else "" + raw_metric_no_unit = raw_metric.split("(")[0].strip().lower() + if metric in (raw_metric, raw_metric_no_unit): + ret.append(raw_metric + suffix) + break + if len(ret) == 0: + raise RuntimeError(f"Metric {metric} is not found. Use the --list flag to list available metrics") + return ret + + +def remove_frames(database: json): + # We first fine frames that match either one of the two conditions: + # 1. The frame name is COMPUTE_METADATA_SCOPE_NAME + # 2. The frame has no metrics and no children + # Then we go up from the located nodes and remove the parents if all children were + # metadata nodes + def remove_frame_helper(node): + if "frame" not in node: + return node + if node["frame"]["name"] == COMPUTE_METADATA_SCOPE_NAME: + return None + if len(node["metrics"]) == 0 and len(node["children"]) == 0: + return None + children = node.get("children", []) + new_children = [] + for child in children: + new_child = remove_frame_helper(child) + if new_child is not None: + new_children.append(new_child) + if len(new_children) > 0 or len(children) == 0: + node["children"] = new_children + return node + return None + + new_database = [] + for node in database: + new_node = remove_frame_helper(node) + if new_node is not None: + new_database.append(new_node) + return new_database + + +def get_raw_metrics(file): + database = json.load(file) + database = remove_frames(database) + device_info = database.pop(1) + gf = ht.GraphFrame.from_literal(database) + inclusive_metrics = gf.show_metric_columns() + exclusive_metrics = [metric for metric in gf.dataframe.columns if metric not in inclusive_metrics] + return gf, inclusive_metrics, exclusive_metrics, device_info + + +def get_min_time_flops(df, device_info): + min_time_flops = pd.DataFrame(0.0, index=df.index, columns=["min_time"]) + for device_type in device_info: + for device_index in device_info[device_type]: + arch = device_info[device_type][device_index]["arch"] + num_sms = device_info[device_type][device_index]["num_sms"] + clock_rate = device_info[device_type][device_index]["clock_rate"] + for width in TritonHook.flops_width: + idx = df["device_id"] == device_index + device_frames = df[idx] + if f"flops{width}" not in device_frames.columns: + continue + max_flops = 0 + if device_type == "CUDA": + if arch == "80": + max_flops = 624e12 / (width / 8) + elif arch == "89": + # TODO(Keren): Implement fp16 acc-> 660.6 fp8 + max_flops = (330.3 * 1e12) / (width / 8) + elif arch == "90": + # 114 sms and 1755mhz is the base number of sms and clock rate of H100 pcie + max_flops = ((num_sms / 114 * clock_rate / (1755 * 1e3) * 1513) * 1e12) / (width / 8) + elif arch == "100": + max_flops = (num_sms * 16384 * (clock_rate / 1e3) * 1e6) / (width / 8) + elif device_type == "HIP": + if arch == "gfx90a": + max_flops = 383e12 / (width / 8) + elif arch == "gfx941" or arch == "gfx942": + max_flops = 2614.9e12 / (width / 8) + else: + raise ValueError(f"Unsupported device type: {device_type}") + min_time_flops.loc[idx, "min_time"] += device_frames[f"flops{width}"].fillna(0) / max_flops + return min_time_flops + + +def get_min_time_bytes(df, device_info): + min_time_bytes = pd.DataFrame(0.0, index=df.index, columns=["min_time"]) + for device_type in device_info: + for device_index in device_info[device_type]: + idx = df["device_id"] == device_index + device_frames = df[idx] + memory_clock_rate = device_info[device_type][device_index]["memory_clock_rate"] # in khz + bus_width = device_info[device_type][device_index]["bus_width"] # in bits + peak_bandwidth = 2 * bus_width * memory_clock_rate * 1e3 / 8 + min_time_bytes.loc[idx, "min_time"] += device_frames["bytes"] / peak_bandwidth + return min_time_bytes + + +FactorDict = namedtuple("FactorDict", ["name", "factor"]) +time_factor_dict = FactorDict("time", {"time/s": 1, "time/ms": 1e-3, "time/us": 1e-6, "time/ns": 1e-9}) +avg_time_factor_dict = FactorDict("avg_time", {f"avg_{key}": value for key, value in time_factor_dict.factor.items()}) +cpu_time_factor_dict = FactorDict("cpu_time", + {"cpu_time/s": 1, "cpu_time/ms": 1e-3, "cpu_time/us": 1e-6, "cpu_time/ns": 1e-9}) +avg_cpu_time_factor_dict = FactorDict("avg_cpu_time", + {f"avg_{key}": value + for key, value in cpu_time_factor_dict.factor.items()}) +bytes_factor_dict = FactorDict("bytes", {"byte/s": 1, "gbyte/s": 1e9, "tbyte/s": 1e12}) + +derivable_metrics = { + **{key: bytes_factor_dict + for key in bytes_factor_dict.factor.keys()}, +} + +# FLOPS have a specific width to their metric +default_flop_factor_dict = {"flop/s": 1, "gflop/s": 1e9, "tflop/s": 1e12} +derivable_metrics.update( + {key: FactorDict("flops", default_flop_factor_dict) + for key in default_flop_factor_dict.keys()}) +for width in TritonHook.flops_width: + factor_name = f"flops{width}" + factor_dict = {f"flop{width}/s": 1, f"gflop{width}/s": 1e9, f"tflop{width}/s": 1e12} + derivable_metrics.update({key: FactorDict(factor_name, factor_dict) for key in factor_dict.keys()}) + + +def derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info): + derived_metrics = [] + + def get_time_seconds(df, metric, factor_dict): + time_metric_name = match_available_metrics(metric, inclusive_metrics, exclusive_metrics)[0] + time_unit = (factor_dict.name + "/" + time_metric_name.split("(")[1].split(")")[0]) + return df[time_metric_name] * factor_dict.factor[time_unit] + + for metric in metrics: + if metric == "util": # exclusive + min_time_bytes = get_min_time_bytes(gf.dataframe, device_info) + min_time_flops = get_min_time_flops(gf.dataframe, device_info) + time_sec = get_time_seconds(gf.dataframe, "time", time_factor_dict) + internal_frame_indices = gf.dataframe["device_id"].isna() + gf.dataframe["util"] = min_time_flops["min_time"].combine(min_time_bytes["min_time"], max) / time_sec + gf.dataframe.loc[internal_frame_indices, "util"] = np.nan + derived_metrics.append("util") + elif metric in derivable_metrics: # flop/s, byte/s, inclusive + derivable_metric = derivable_metrics[metric] + metric_name = derivable_metric.name + metric_factor_dict = derivable_metric.factor + matched_metric_name = match_available_metrics(metric_name, inclusive_metrics, exclusive_metrics)[0] + gf.dataframe[f"{metric} (inc)"] = (gf.dataframe[matched_metric_name] / + (get_time_seconds(gf.dataframe, "time", time_factor_dict)) / + metric_factor_dict[metric]) + derived_metrics.append(f"{metric} (inc)") + elif metric in time_factor_dict.factor or metric in cpu_time_factor_dict.factor or \ + metric in avg_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor: # inclusive + is_cpu = metric in cpu_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor + is_avg = metric in avg_time_factor_dict.factor or metric in avg_cpu_time_factor_dict.factor + + factor_dict = (avg_cpu_time_factor_dict if is_avg else cpu_time_factor_dict) if is_cpu \ + else (avg_time_factor_dict if is_avg else time_factor_dict) + metric_name = "cpu_time" if is_cpu else "time" + metric_time_unit = factor_dict.name + "/" + metric.split("/")[1] + + time_value = get_time_seconds(gf.dataframe, metric_name, factor_dict) + if is_avg: + time_value = time_value / gf.dataframe["count (inc)"] + + gf.dataframe[f"{metric} (inc)"] = time_value / factor_dict.factor[metric_time_unit] + derived_metrics.append(f"{metric} (inc)") + else: + metric_name_and_unit = metric.split("/") + metric_name = metric_name_and_unit[0] + if len(metric_name_and_unit) > 1: # percentage, exclusive or inclusive + metric_unit = metric_name_and_unit[1] + if metric_unit != "%": + raise ValueError(f"Unsupported unit {metric_unit}") + matched_metric_name = match_available_metrics(metric_name, inclusive_metrics, exclusive_metrics)[0] + single_frame = gf.dataframe[matched_metric_name] + suffix = "" + if "(inc)" in matched_metric_name: + suffix = " (inc)" + total = gf.dataframe[matched_metric_name].iloc[0] + else: + total = gf.dataframe[matched_metric_name].sum() + gf.dataframe[metric + suffix] = (single_frame / total) * 100.0 + derived_metrics.append(metric + suffix) + else: + matched_metric_name = match_available_metrics(metric_name, inclusive_metrics, exclusive_metrics)[0] + derived_metrics.append(matched_metric_name) + + # Update derived metrics to the graph frame + for derived_metric in derived_metrics: + if derived_metric.endswith("(inc)"): + gf.inc_metrics.append(derived_metric) + else: + gf.exc_metrics.append(derived_metric) + + return derived_metrics + + +def format_frames(gf, format): + if format == "file_function_line": + gf.dataframe["name"] = gf.dataframe["name"].apply(lambda x: x.split("/")[-1]) + elif format == "function_line": + gf.dataframe["name"] = gf.dataframe["name"].apply(lambda x: x.split(":")[-1]) + elif format == "file_function": + gf.dataframe["name"] = gf.dataframe["name"].apply( + lambda x: f"{x.split('/')[-1].split(':')[0]}@{x.split('@')[-1].split(':')[0]}") + return gf + + +def filter_frames(gf, include=None, exclude=None, threshold=None, metric=None): + if include: + query = f""" +MATCH ("*")->(".", p)->("*") +WHERE p."name" =~ "{include}" +""" + gf = gf.filter(query, squash=True) + if exclude: + inclusion_query = f""" +MATCH (".", p)->("*") +WHERE p."name" =~ "{exclude}" +""" + query = NegationQuery(inclusion_query) + gf = gf.filter(query, squash=True) + if threshold: + query = ["*", {metric: f">= {threshold}"}] + gf = gf.filter(query, squash=True) + return gf + + +def emit_warnings(gf, metrics): + if "bytes (inc)" in metrics: + byte_values = gf.dataframe["bytes (inc)"].values + min_byte_value = np.nanmin(byte_values) + if min_byte_value < 0: + print("Warning: Negative byte values detected, this is usually the result of a datatype overflow\n") + + +def print_tree(gf, metrics, depth=100, format=None, print_sorted=False): + gf = format_frames(gf, format) + print(gf.tree(metric_column=metrics, expand_name=True, depth=depth, render_header=False)) + + if print_sorted: + print("Sorted kernels by metric " + metrics[0]) + sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False) + for row in range(1, len(sorted_df)): + kernel_name = sorted_df.iloc[row]['name'][:100] + "..." if len( + sorted_df.iloc[row]['name']) > 100 else sorted_df.iloc[row]['name'] + print("{:105} {:.4}".format(kernel_name, sorted_df.iloc[row][metrics[0]])) + emit_warnings(gf, metrics) + + +def parse(metrics, filename, include=None, exclude=None, threshold=None): + with open(filename, "r") as f: + gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(f) + assert len(inclusive_metrics + exclusive_metrics) > 0, "No metrics found in the input file" + gf.update_inclusive_columns() + metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info) + # TODO: generalize to support multiple metrics, not just the first one + gf = filter_frames(gf, include, exclude, threshold, metrics[0]) + return gf, metrics + + +def show_metrics(file_name): + with open(file_name, "r") as f: + _, inclusive_metrics, exclusive_metrics, _ = get_raw_metrics(f) + print("Available inclusive metrics:") + if inclusive_metrics: + for raw_metric in inclusive_metrics: + raw_metric_no_unit = raw_metric.split("(")[0].strip().lower() + print(f"- {raw_metric_no_unit}") + print("Available exclusive metrics:") + if exclusive_metrics: + for raw_metric in exclusive_metrics: + raw_metric_no_unit = raw_metric.split("(")[0].strip().lower() + print(f"- {raw_metric_no_unit}") + + +def main(): + argparser = argparse.ArgumentParser( + description="Performance data viewer for proton profiles.", + formatter_class=argparse.RawTextHelpFormatter, + ) + argparser.add_argument( + "-l", + "--list", + action="store_true", + help="""List available metrics. Metric names are case insensitive and ignore units. +Derived metrics can be created when source metrics are available. +- time/s, time/ms, time/us, time/ns: time +- avg_time/s, avg_time/ms, avg_time/us, avg_time/ns: time / count +- flop[<8/16/32/64>]/s, gflop[<8/16/32/64>]/s, tflop[<8/16/32/64>]/s: flops / time +- byte/s, gbyte/s, tbyte/s: bytes / time +- util: max(sum(flops) / peak_flops_time, sum(bytes) / peak_bandwidth_time) +- /%%: frame(metric) / sum(metric). Only availble for inclusive metrics (e.g. time) +""", + ) + argparser.add_argument( + "-m", + "--metrics", + type=str, + default=None, + help="""At maximum two metrics can be specified, separated by comma. +There are two modes: +1) Choose the output metric to display. It's case insensitive and ignore units. +2) Derive a new metric from existing metrics. +""", + ) + argparser.add_argument( + "-i", + "--include", + type=str, + default=None, + help= + """Find frames that match the given regular expression and return all nodes in the paths that pass through the matching frames. +For example, the following command will display all paths that contain frames that contains "test": +``` +proton-viewer -i ".*test.*" path/to/file.json +``` +""", + ) + argparser.add_argument( + "-e", + "--exclude", + type=str, + default=None, + help="""Exclude frames that match the given regular expression and their children. +For example, the following command will exclude all paths starting from frames that contains "test": +``` +proton-viewer -e ".*test.*" path/to/file.json +``` +""", + ) + argparser.add_argument( + "-t", + "--threshold", + type=float, + default=None, + help= + "Exclude frames(kernels) whose metrics are below the given threshold. This filter only applies on the first metric.", + ) + argparser.add_argument( + "-d", + "--depth", + type=int, + default=100, + help="The depth of the tree to display", + ) + argparser.add_argument( + "-f", "--format", type=str, choices=["full", "file_function_line", "function_line", "file_function"], + default="full", help="""Formatting the frame name. +- full: include the path, file name, function name and line number. +- file_function_line: include the file name, function name and line number. +- function_line: include the function name and line number. +- file_function: include the file name and function name. +""") + argparser.add_argument( + "--print-sorted", + action='store_true', + default=False, + help="Sort output by metric value instead of chronologically", + ) + argparser.add_argument( + "--diff-profile", "-diff", type=str, default=None, + help="Compare two profiles. When used as 'proton-viewer -m time -diff file1.log file2.log', " + "computes the difference: file2['time'] - file1['time']") + + args, target_args = argparser.parse_known_args() + assert len(target_args) == 1, "Must specify a file to read" + + file_name = target_args[0] + metrics = args.metrics.split(",") if args.metrics else None + include = args.include + exclude = args.exclude + threshold = args.threshold + depth = args.depth + format = args.format + diff = args.diff_profile + print_sorted = args.print_sorted + if include and exclude: + raise ValueError("Cannot specify both include and exclude") + if args.list: + show_metrics(file_name) + elif metrics: + gf, derived_metrics = parse(metrics, file_name, include, exclude, threshold) + if diff: + gf2, _ = parse(metrics, diff, include, exclude, threshold) + gf = gf.sub(gf2) + print_tree(gf, derived_metrics, depth, format, print_sorted) + + +if __name__ == "__main__": + main() diff --git a/third_party/enflame/include/triton/third_party/proton/test/examples/cuda.json b/third_party/enflame/include/triton/third_party/proton/test/examples/cuda.json new file mode 100644 index 000000000..bcf433d60 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/examples/cuda.json @@ -0,0 +1,86 @@ +[ + { + "children": [ + { + "children": [], + "frame": { + "name": "foo0", + "type": "function" + }, + "metrics": { + "count": 10, + "device_id": "1", + "device_type": "CUDA", + "time (ns)": 204800, + "flops8": 1e11, + "bytes": 1e8 + } + }, + { + "children": [], + "frame": { + "name": "foo1", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "CUDA", + "time (ns)": 204800, + "flops8": 1e10, + "bytes": 1e7 + } + }, + { + "children": [], + "frame": { + "name": "foo2", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "2", + "device_type": "CUDA", + "time (ns)": 204800, + "flops8": 1e11, + "bytes": 1e7 + } + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "count": 0, + "time (ns)": 0, + "flops8": 0, + "bytes": 0 + } + }, + { + "CUDA": { + "0": { + "arch": "89", + "bus_width": 384, + "clock_rate": 2625000, + "memory_clock_rate": 10501000, + "num_sms": 128 + }, + "1": { + "arch": "90", + "bus_width": 6144, + "clock_rate": 1980000, + "memory_clock_rate": 2619000, + "num_sms": 132 + }, + "2": { + "arch": "100", + "bus_width": 6144, + "clock_rate": 1700000, + "memory_clock_rate": 2619000, + "num_sms": 148 + } + } + } +] diff --git a/third_party/enflame/include/triton/third_party/proton/test/examples/frame.json b/third_party/enflame/include/triton/third_party/proton/test/examples/frame.json new file mode 100644 index 000000000..cd671c9df --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/examples/frame.json @@ -0,0 +1,58 @@ +[ + { + "children": [ + { + "children": [ + { + "children": [], + "frame": { + "name": "/home/user/projects/example.py/test.py:1@foo", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 204800 + } + } + ], + "frame": { + "name": "test0" + }, + "metrics": {} + }, + { + "children": [], + "frame": { + "name": "test1" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 204800 + } + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "count": 0, + "time (ns)": 0 + } + }, + { + "HIP": { + "0": { + "arch": "gfx90a", + "bus_width": 4096, + "clock_rate": 1700000, + "memory_clock_rate": 1600000, + "num_sms": 104 + } + } + } +] diff --git a/third_party/enflame/include/triton/third_party/proton/test/examples/hip.json b/third_party/enflame/include/triton/third_party/proton/test/examples/hip.json new file mode 100644 index 000000000..70eaf325d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/examples/hip.json @@ -0,0 +1,64 @@ +[ + { + "children": [ + { + "children": [], + "frame": { + "name": "foo0", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "1", + "device_type": "HIP", + "time (ns)": 204800, + "flops8": 1e11, + "bytes": 1e8 + } + }, + { + "children": [], + "frame": { + "name": "foo1", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 204800, + "flops8": 1e10, + "bytes": 1e7 + } + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "count": 0, + "time (ns)": 0, + "flops8": 0, + "bytes": 0 + } + }, + { + "HIP": { + "0": { + "arch": "gfx90a", + "bus_width": 4096, + "clock_rate": 1700000, + "memory_clock_rate": 1600000, + "num_sms": 104 + }, + "1": { + "arch": "gfx941", + "bus_width": 8192, + "clock_rate": 5200000, + "memory_clock_rate": 2525000, + "num_sms": 304 + } + } + } +] diff --git a/third_party/enflame/include/triton/third_party/proton/test/examples/leaf_nodes.json b/third_party/enflame/include/triton/third_party/proton/test/examples/leaf_nodes.json new file mode 100644 index 000000000..5930664dd --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/examples/leaf_nodes.json @@ -0,0 +1,168 @@ +[ + { + "children": [ + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_1_2_2", + "type": "function" + }, + "metrics": { + "count": 402, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 78190414 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_1_3_1", + "type": "function" + }, + "metrics": { + "count": 502, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 24125138 + } + } + ], + "frame": { + "name": "kernel_1_2_1", + "type": "function" + }, + "metrics": { + "bytes": 3997237248, + "flops": 1534939103232 + } + } + ], + "frame": { + "name": "kernel_1_1_1", + "type": "function" + }, + "metrics": {} + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_2_2_2", + "type": "function" + }, + "metrics": { + "count": 120, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 23174888 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_2_3_1", + "type": "function" + }, + "metrics": { + "count": 149, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 1040322 + } + } + ], + "frame": { + "name": "kernel_2_2_1", + "type": "function" + }, + "metrics": { + "bytes": 58589184, + "flops": 4999610368 + } + } + ], + "frame": { + "name": "kernel_2_1_1", + "type": "function" + }, + "metrics": {} + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_3_2_2", + "type": "function" + }, + "metrics": { + "count": 480, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 93036508 + } + }, + { + "children": [ + { + "children": [], + "frame": { + "name": "kernel_3_2_1", + "type": "function" + }, + "metrics": { + "count": 599, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 6306402 + } + } + ], + "frame": { + "name": "kernel_3_2_1", + "type": "function" + }, + "metrics": { + "bytes": 529956864, + "flops": 67834478592 + } + } + ], + "frame": { + "name": "kernel_3_1_1", + "type": "function" + }, + "metrics": {} + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "bytes": 0, + "count": 0, + "flops": 0, + "time (ns)": 0 + } + }, + { + "HIP": { + "0": { + "arch": "gfx90a", + "bus_width": 4096, + "clock_rate": 1700000, + "memory_clock_rate": 1600000, + "num_sms": 104 + } + } + } +] diff --git a/third_party/enflame/include/triton/third_party/proton/test/examples/triton.json b/third_party/enflame/include/triton/third_party/proton/test/examples/triton.json new file mode 100644 index 000000000..2a29ee358 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/examples/triton.json @@ -0,0 +1,73 @@ +[ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [], + "frame": { + "name": "cuda_kernel", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "CUDA", + "time (ns)": 4064 + } + } + ], + "frame": { + "name": "__proton_launch_metadata", + "type": "function" + }, + "metrics": {} + }, + { + "children": [], + "frame": { + "name": "triton_kernel", + "type": "function" + }, + "metrics": { + "bytes": 2.0, + "count": 1, + "device_id": "0", + "device_type": "CUDA", + "time (ns)": 1664 + } + } + ], + "frame": { + "name": "scope", + "type": "function" + }, + "metrics": { + "cpu_time (ns)": 12345 + } + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "bytes": 0, + "count": 0, + "time (ns)": 0 + } + }, + { + "CUDA": { + "0": { + "arch": "86", + "bus_width": 128, + "clock_rate": 1140000, + "memory_clock_rate": 5501000, + "num_sms": 16 + } + } + } +] diff --git a/third_party/enflame/include/triton/third_party/proton/test/helper.py b/third_party/enflame/include/triton/third_party/proton/test/helper.py new file mode 100644 index 000000000..4591aeb54 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/helper.py @@ -0,0 +1,21 @@ +import triton.profiler as proton + +import torch +import sys + +from helper_kernels import custom_add + + +def main(): + a = torch.zeros(1, device="cuda") + with proton.scope("test"): + custom_add[(1, )](a) + + +def test_main(): + main() + + +if __name__ == "__main__": + if sys.argv[1] == "test": + main() diff --git a/third_party/enflame/include/triton/third_party/proton/test/helper_kernels.py b/third_party/enflame/include/triton/third_party/proton/test/helper_kernels.py new file mode 100644 index 000000000..7a128dbac --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/helper_kernels.py @@ -0,0 +1,7 @@ +import triton.language as tl +import triton + + +@triton.jit +def custom_add(a_ptr): + tl.store(a_ptr, 1.0) diff --git a/third_party/enflame/include/triton/third_party/proton/test/instrument.py b/third_party/enflame/include/triton/third_party/proton/test/instrument.py new file mode 100644 index 000000000..769fab19d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/instrument.py @@ -0,0 +1,65 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b, activation=""): + # Check constraints. + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + # 1D launch kernel where each block gets its own program. + matmul_kernel[(1, )]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + 128, 256, 64, 8) + return c + + +a = torch.randn((32, 32), device="cuda", dtype=torch.float16) +b = torch.randn((32, 32), device="cuda", dtype=torch.float16) +matmul(a, b) diff --git a/third_party/enflame/include/triton/third_party/proton/test/test_api.py b/third_party/enflame/include/triton/third_party/proton/test/test_api.py new file mode 100644 index 000000000..804b35573 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/test_api.py @@ -0,0 +1,255 @@ +import json +import triton.profiler as proton +import pathlib + + +def test_profile_single_session(tmp_path: pathlib.Path): + temp_file0 = tmp_path / "test_profile0.hatchet" + session_id0 = proton.start(str(temp_file0.with_suffix(""))) + proton.activate() + proton.deactivate() + proton.finalize() + assert session_id0 == 0 + assert temp_file0.exists() + + temp_file1 = tmp_path / "test_profile1.hatchet" + session_id1 = proton.start(str(temp_file1.with_suffix(""))) + proton.activate(session_id1) + proton.deactivate(session_id1) + proton.finalize(session_id1) + assert session_id1 == session_id0 + 1 + assert temp_file1.exists() + + session_id2 = proton.start("test") + proton.activate(session_id2) + proton.deactivate(session_id2) + proton.finalize() + assert session_id2 == session_id1 + 1 + assert pathlib.Path("test.hatchet").exists() + pathlib.Path("test.hatchet").unlink() + + +def test_profile_multiple_sessions(tmp_path: pathlib.Path): + temp_file0 = tmp_path / "test_profile0.hatchet" + proton.start(str(temp_file0.with_suffix(""))) + temp_file1 = tmp_path / "test_profile1.hatchet" + proton.start(str(temp_file1.with_suffix(""))) + proton.activate() + proton.deactivate() + proton.finalize() + assert temp_file0.exists() + assert temp_file1.exists() + + temp_file2 = tmp_path / "test_profile2.hatchet" + session_id2 = proton.start(str(temp_file2.with_suffix(""))) + temp_file3 = tmp_path / "test_profile3.hatchet" + session_id3 = proton.start(str(temp_file3.with_suffix(""))) + proton.deactivate(session_id2) + proton.deactivate(session_id3) + proton.finalize() + assert temp_file2.exists() + assert temp_file3.exists() + + +def test_profile_decorator(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_profile_decorator.hatchet" + + @proton.profile(name=str(temp_file.with_suffix(""))) + def foo0(a, b): + return a + b + + foo0(1, 2) + proton.finalize() + assert temp_file.exists() + + @proton.profile + def foo1(a, b): + return a + b + + foo1(1, 2) + proton.finalize() + default_file = pathlib.Path(proton.DEFAULT_PROFILE_NAME + ".hatchet") + assert default_file.exists() + default_file.unlink() + + +def test_scope(tmp_path: pathlib.Path): + # Scope can be annotated even when profiling is off + with proton.scope("test"): + pass + + temp_file = tmp_path / "test_scope.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test"): + pass + + @proton.scope("test") + def foo(): + pass + + foo() + + proton.enter_scope("test") + proton.exit_scope() + proton.finalize() + assert temp_file.exists() + + +def test_hook(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_hook.hatchet" + session_id0 = proton.start(str(temp_file.with_suffix("")), hook="triton") + proton.activate(session_id0) + proton.deactivate(session_id0) + proton.finalize(None) + assert temp_file.exists() + + +def test_scope_metrics(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_metrics.hatchet" + session_id = proton.start(str(temp_file.with_suffix(""))) + # Test different scope creation methods + with proton.scope("test0", {"a": 1.0}): + pass + + @proton.scope("test1", {"a": 1.0}) + def foo(): + pass + + foo() + + # After deactivation, the metrics should be ignored + proton.deactivate(session_id) + proton.enter_scope("test2", metrics={"a": 1.0}) + proton.exit_scope() + + # Metrics should be recorded again after reactivation + proton.activate(session_id) + proton.enter_scope("test3", metrics={"a": 1.0}) + proton.exit_scope() + + proton.enter_scope("test3", metrics={"a": 1.0}) + proton.exit_scope() + + # exit_scope can also take metrics + proton.enter_scope("test4") + proton.exit_scope(metrics={"b": 1.0}) + + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 4 + for child in data[0]["children"]: + if child["frame"]["name"] == "test3": + assert child["metrics"]["a"] == 2.0 + elif child["frame"]["name"] == "test4": + assert child["metrics"]["b"] == 1.0 + + +def test_scope_properties(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_properties.hatchet" + proton.start(str(temp_file.with_suffix(""))) + # Test different scope creation methods + # Different from metrics, properties could be str + with proton.scope("test0", {"a (pty)": "1"}): + pass + + @proton.scope("test1", {"a (pty)": "1"}) + def foo(): + pass + + foo() + + # Properties do not aggregate + proton.enter_scope("test2", metrics={"a (pty)": 1.0}) + proton.exit_scope() + + proton.enter_scope("test2", metrics={"a (pty)": 1.0}) + proton.exit_scope() + + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: + data = json.load(f) + for child in data[0]["children"]: + if child["frame"]["name"] == "test2": + assert child["metrics"]["a"] == 1.0 + elif child["frame"]["name"] == "test0": + assert child["metrics"]["a"] == "1" + + +def test_scope_exclusive(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_exclusive.hatchet" + proton.start(str(temp_file.with_suffix(""))) + # metric a only appears in the outermost scope + # metric b only appears in the innermost scope + # both metrics do not appear in the root scope + with proton.scope("test0", metrics={"a (exc)": "1"}): + with proton.scope("test1", metrics={"b (exc)": "1"}): + pass + + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: + data = json.load(f) + root_metrics = data[0]["metrics"] + assert len(root_metrics) == 0 + test0_frame = data[0]["children"][0] + test0_metrics = test0_frame["metrics"] + assert len(test0_metrics) == 1 + assert test0_metrics["a"] == "1" + test1_frame = test0_frame["children"][0] + test1_metrics = test1_frame["metrics"] + assert len(test1_metrics) == 1 + assert test1_metrics["b"] == "1" + + +def test_state(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_state.hatchet" + proton.start(str(temp_file.with_suffix(""))) + proton.enter_scope("test0") + proton.enter_state("state") + proton.enter_scope("test1", metrics={"a": 1.0}) + proton.exit_scope() + proton.exit_state() + proton.exit_scope() + proton.finalize() + assert temp_file.exists() + with temp_file.open() as f: + data = json.load(f) + # test0->test1->state + assert len(data[0]["children"]) == 1 + child = data[0]["children"][0] + assert child["frame"]["name"] == "test0" + assert len(child["children"]) == 1 + child = child["children"][0] + assert child["frame"]["name"] == "test1" + assert len(child["children"]) == 1 + child = child["children"][0] + assert child["frame"]["name"] == "state" + assert child["metrics"]["a"] == 1.0 + + +def test_throw(tmp_path: pathlib.Path): + # Catch an exception thrown by c++ + session_id = 100 + temp_file = tmp_path / "test_throw.hatchet" + activate_error = "" + try: + session_id = proton.start(str(temp_file.with_suffix(""))) + proton.activate(session_id + 1) + except Exception as e: + activate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in activate_error + + deactivate_error = "" + try: + session_id = proton.start(str(temp_file.with_suffix(""))) + proton.deactivate(session_id + 1) + except Exception as e: + deactivate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in deactivate_error diff --git a/third_party/enflame/include/triton/third_party/proton/test/test_cmd.py b/third_party/enflame/include/triton/third_party/proton/test/test_cmd.py new file mode 100644 index 000000000..1efa3e0a7 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/test_cmd.py @@ -0,0 +1,65 @@ +import triton +import pytest +import subprocess +import json +import pathlib + + +def test_help(): + # Only check if the viewer can be invoked + subprocess.check_call(["proton", "-h"], stdout=subprocess.DEVNULL) + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@pytest.mark.parametrize("mode", ["script", "python", "pytest"]) +def test_exec(mode, tmp_path: pathlib.Path): + file_path = __file__ + helper_file = file_path.replace("test_cmd.py", "helper.py") + temp_file = tmp_path / "test_exec.hatchet" + name = str(temp_file.with_suffix("")) + if mode == "script": + subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) + elif mode == "python": + subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], + stdout=subprocess.DEVNULL) + elif mode == "pytest": + subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], + stdout=subprocess.DEVNULL) + with temp_file.open() as f: + data = json.load(f, ) + kernels = data[0]["children"] + assert len(kernels) == 2 + assert kernels[0]["frame"]["name"] == "test" or kernels[1]["frame"]["name"] == "test" + + +def test_instrument_exec(): + + try: + out = subprocess.Popen(["proton", "--instrument=print-mem-spaces", "instrument.py"], + cwd=pathlib.Path(__file__).parent, stderr=subprocess.PIPE, stdout=subprocess.PIPE) + except Exception as e: + print(f"An error occurred while executing proton: {e}") + + result = [] + for line in str(out.stderr.read().decode()).split("\n"): + if line: + result.append(line.split()) + + if is_hip(): + assert len(result) == 7 + assert result[0] == ['0', 'matmul_kernel', 'instrument.py:32:20', 'GLOBAL', 'LOAD'] + assert result[1] == ['1', 'matmul_kernel', 'instrument.py:33:20', 'GLOBAL', 'LOAD'] + assert result[2] == ['2', 'matmul_kernel', 'instrument.py:32:20', 'SHARED', 'STORE'] + assert result[3] == ['3', 'matmul_kernel', 'instrument.py:33:20', 'SHARED', 'STORE'] + assert result[4] == ['4', 'matmul_kernel', 'instrument.py:32:20', 'SHARED', 'LOAD'] + assert result[5] == ['5', 'matmul_kernel', 'instrument.py:33:20', 'SHARED', 'LOAD'] + assert result[6] == ['6', 'matmul_kernel', 'instrument.py:42:21', 'GLOBAL', 'STORE'] + else: + assert [row[0] for row in result] == ['0'] + assert [row[1] for row in result] == ['matmul_kernel'] + assert [row[2] for row in result] == ['instrument.py:42:21'] + assert [row[3] for row in result] == ['SHARED'] + assert [row[4] for row in result] == ['LOAD'] diff --git a/third_party/enflame/include/triton/third_party/proton/test/test_lib.py b/third_party/enflame/include/triton/third_party/proton/test/test_lib.py new file mode 100644 index 000000000..c1936c73d --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/test_lib.py @@ -0,0 +1,51 @@ +import pathlib + +import triton._C.libproton.proton as libproton +from triton.profiler.profile import _select_backend + + +def test_record(): + id0 = libproton.record_scope() + id1 = libproton.record_scope() + assert id1 == id0 + 1 + + +def test_state(): + libproton.enter_state("zero") + libproton.exit_state() + + +def test_scope(): + id0 = libproton.record_scope() + libproton.enter_scope(id0, "zero") + id1 = libproton.record_scope() + libproton.enter_scope(id1, "one") + libproton.exit_scope(id1, "one") + libproton.exit_scope(id0, "zero") + + +def test_op(): + id0 = libproton.record_scope() + libproton.enter_op(id0, "zero") + libproton.exit_op(id0, "zero") + + +def test_session(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_session.hatchet" + session_id = libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend(), "") + libproton.deactivate(session_id) + libproton.activate(session_id) + libproton.finalize(session_id, "hatchet") + libproton.finalize_all("hatchet") + assert temp_file.exists() + + +def test_add_metrics(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_add_metrics.hatchet" + libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend(), "") + id1 = libproton.record_scope() + libproton.enter_scope(id1, "one") + libproton.add_metrics(id1, {"a": 1.0, "b": 2.0}) + libproton.exit_scope(id1, "one") + libproton.finalize_all("hatchet") + assert temp_file.exists() diff --git a/third_party/enflame/include/triton/third_party/proton/test/test_profile.py b/third_party/enflame/include/triton/third_party/proton/test/test_profile.py new file mode 100644 index 000000000..a673c1da6 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/test_profile.py @@ -0,0 +1,313 @@ +import torch +import triton +import triton.profiler as proton +import json +import pytest +from typing import NamedTuple +import pathlib + +import triton.language as tl +from triton.profiler.hook import COMPUTE_METADATA_SCOPE_NAME + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@pytest.mark.parametrize("context", ["shadow", "python"]) +def test_torch(context, tmp_path: pathlib.Path): + temp_file = tmp_path / "test_torch.hatchet" + proton.start(str(temp_file.with_suffix("")), context=context) + proton.enter_scope("test") + torch.ones((2, 2), device="cuda") + proton.exit_scope() + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + if context == "shadow": + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test" + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + elif context == "python": + assert len(data[0]["children"]) == 1 + # bfs search until find the "elementwise_kernel" and then check its children + queue = [data[0]] + while len(queue) > 0: + parent_frame = queue.pop(0) + for child in parent_frame["children"]: + if "elementwise_kernel" in child["frame"]["name"]: + assert len(child["children"]) == 0 + return + queue.append(child) + + +def test_triton(tmp_path: pathlib.Path): + + @triton.jit + def foo(x, y): + tl.store(y, tl.load(x)) + + x = torch.tensor([2], device="cuda") + y = torch.zeros_like(x) + temp_file = tmp_path / "test_triton.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test0"): + with proton.scope("test1"): + foo[(1, )](x, y) + with proton.scope("test2"): + foo[(1, )](x, y) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 2 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert len(data[0]["children"][0]["children"]) == 1 + assert data[0]["children"][0]["children"][0]["frame"]["name"] == "test1" + assert data[0]["children"][1]["frame"]["name"] == "test2" + + +def test_cudagraph(tmp_path: pathlib.Path): + stream = torch.cuda.Stream() + torch.cuda.set_stream(stream) + + @triton.jit + def foo(x, y, z): + tl.store(z, tl.load(y) + tl.load(x)) + + def fn(): + a = torch.ones((2, 2), device="cuda") + b = torch.ones((2, 2), device="cuda") + c = a + b + foo[(1, )](a, b, c) + + temp_file = tmp_path / "test_cudagraph.hatchet" + proton.start(str(temp_file.with_suffix("")), context="shadow") + + # warmup + # four kernels + fn() + + # no kernels + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(10): + fn() + + proton.enter_scope("test") + g.replay() + g.reset() + torch.cuda.synchronize() + proton.exit_scope() + proton.finalize() + + with temp_file.open() as f: + data = json.load(f) + # CUDA/HIP graph may also invoke additional kernels to reset outputs + # {torch.ones, add, foo, test} + assert len(data[0]["children"]) >= 4 + # find the test frame + test_frame = None + for child in data[0]["children"]: + if child["frame"]["name"] == "test": + test_frame = child + break + assert test_frame is not None + # {torch.ones, add, foo} + if is_hip(): + assert len(test_frame["children"]) >= 2 + else: + assert len(test_frame["children"]) >= 3 + assert test_frame["children"][0]["metrics"]["time (ns)"] > 0 + + +def test_metrics(tmp_path: pathlib.Path): + + @triton.jit + def foo(x, y): + tl.store(y, tl.load(x)) + + x = torch.tensor([2], device="cuda") + y = torch.zeros_like(x) + temp_file = tmp_path / "test_metrics.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("test0", {"foo": 1.0}): + foo[(1, )](x, y) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert data[0]["children"][0]["metrics"]["foo"] == 1.0 + + +def test_scope_backward(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_scope_backward.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.scope("ones1"): + a = torch.ones((100, 100), device="cuda", requires_grad=True) + with proton.scope("plus"): + a2 = a * a * a + with proton.scope("ones2"): + loss = torch.ones_like(a2) + + # Backward triggers two kernels in a single scope + with proton.scope("backward"): + a2.backward(loss) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 4 + + +def test_cpu_timed_scope(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_cpu_timed_scope.hatchet" + proton.start(str(temp_file.with_suffix(""))) + with proton.cpu_timed_scope("test0"): + with proton.cpu_timed_scope("test1"): + torch.ones((100, 100), device="cuda") + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 1 + test0_frame = data[0]["children"][0] + assert test0_frame["metrics"]["cpu_time (ns)"] > 0 + test1_frame = test0_frame["children"][0] + assert test1_frame["metrics"]["cpu_time (ns)"] > 0 + kernel_frame = test1_frame["children"][0] + assert kernel_frame["metrics"]["time (ns)"] > 0 + + +def test_hook(tmp_path: pathlib.Path): + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + # get arg's element size + element_size = args["x"].element_size() # non-const + size = args["size"] # const + key = "flops" + str(element_size * 8) + num_ctas = metadata.num_ctas + return {"name": f"foo_test_{num_ctas}ctas_{size}elems", key: 1.0} + + @triton.jit(launch_metadata=metadata_fn) + def foo(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + x = torch.tensor([2], device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + temp_file = tmp_path / "test_hook.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton") + with proton.scope("test0"): + foo[(1, )](x, 1, y, num_warps=4) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + assert len(data[0]["children"]) == 1 + assert data[0]["children"][0]["frame"]["name"] == "test0" + assert data[0]["children"][0]["children"][0]["frame"]["name"] == "foo_test_1ctas_1elems" + assert data[0]["children"][0]["children"][0]["metrics"]["flops32"] == 1.0 + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + + +@pytest.mark.parametrize("context", ["shadow", "python"]) +def test_hook_gpu_kernel(tmp_path: pathlib.Path, context: str): + + def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): + x = args["x"] + # A gpu kernel, but it should be under the metadata state + return {"name": "foo_test", "bytes": x.sum().item()} + + @triton.jit(launch_metadata=metadata_fn) + def foo(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + x = torch.tensor([2], device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + temp_file = tmp_path / "test_hook.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton", context=context) + with proton.scope("test0"): + foo[(1, )](x, 1, y, num_warps=4) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + # bfs search until find the reduce kernel and then check its parent + queue = [data[0]] + while len(queue) > 0: + parent_frame = queue.pop(0) + for child in parent_frame["children"]: + if "reduce" in child["frame"]["name"]: + assert parent_frame["frame"]["name"] == COMPUTE_METADATA_SCOPE_NAME + return + queue.append(child) + + +def test_pcsampling(tmp_path: pathlib.Path): + if is_hip(): + pytest.skip("HIP backend does not support pc sampling") + + import os + if os.environ.get("PROTON_SKIP_PC_SAMPLING_TEST", "0") == "1": + pytest.skip("PC sampling test is disabled") + + @triton.jit + def foo(x, y, size: tl.constexpr): + offs = tl.arange(0, size) + for _ in range(1000): + tl.store(y + offs, tl.load(x + offs)) + + temp_file = tmp_path / "test_pcsampling.hatchet" + proton.start(str(temp_file.with_suffix("")), hook="triton", backend="cupti_pcsampling") + with proton.scope("init"): + x = torch.ones((1024, ), device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + with proton.scope("test"): + foo[(1, )](x, y, x.size()[0], num_warps=4) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + init_frame = data[0]["children"][0] + test_frame = data[0]["children"][1] + # With line mapping + assert "foo" in test_frame["children"][0]["frame"]["name"] + assert test_frame["children"][0]["children"][0]["metrics"]["num_samples"] > 0 + assert "@" in test_frame["children"][0]["children"][0]["frame"]["name"] + # Without line mapping + assert "elementwise" in init_frame["children"][0]["frame"]["name"] + assert init_frame["children"][0]["metrics"]["num_samples"] > 0 + + +def test_deactivate(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_deactivate.hatchet" + session_id = proton.start(str(temp_file.with_suffix("")), hook="triton") + proton.deactivate(session_id) + torch.randn((10, 10), device="cuda") + proton.activate(session_id) + torch.zeros((10, 10), device="cuda") + proton.deactivate(session_id) + proton.finalize() + with temp_file.open() as f: + data = json.load(f) + # Root shouldn't have device id + assert "device_id" not in data[0]["metrics"] + assert len(data[0]["children"]) == 1 + assert "device_id" in data[0]["children"][0]["metrics"] + + +def test_multiple_sessions(tmp_path: pathlib.Path): + temp_file0 = tmp_path / "test_multiple_sessions0.hatchet" + temp_file1 = tmp_path / "test_multiple_sessions1.hatchet" + session_id0 = proton.start(str(temp_file0.with_suffix(""))) + session_id1 = proton.start(str(temp_file1.with_suffix(""))) + torch.randn((10, 10), device="cuda") + torch.randn((10, 10), device="cuda") + proton.deactivate(session_id0) + proton.finalize(session_id0) + torch.randn((10, 10), device="cuda") + proton.finalize(session_id1) + # kernel has been invokved twice in session 0 and three times in session 1 + with temp_file0.open() as f: + data = json.load(f) + assert int(data[0]["children"][0]["metrics"]["count"]) == 2 + with temp_file1.open() as f: + data = json.load(f) + assert int(data[0]["children"][0]["metrics"]["count"]) == 3 diff --git a/third_party/enflame/include/triton/third_party/proton/test/test_record.py b/third_party/enflame/include/triton/third_party/proton/test/test_record.py new file mode 100644 index 000000000..57a233790 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/test_record.py @@ -0,0 +1,40 @@ +import torch +import pathlib + +import triton +import triton.language as tl +import triton.profiler.language as pl + + +def test_proton_record(tmp_path: pathlib.Path): + + @triton.jit + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + pl.record(True, 0) + y = tl.load(y_ptr + offsets, mask=mask) + pl.record(False, 0) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + torch.manual_seed(0) + size = 2**12 + x = torch.rand(size, device='cuda') + y = torch.rand(size, device='cuda') + output = torch.empty_like(x) + n_elements = output.numel() + grid = (1, 1, 1) + pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + ttir = pgm.asm['ttir'] + assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir + assert "proton.record() {isStart = false, regionId = 0 : i32}" in ttir diff --git a/third_party/enflame/include/triton/third_party/proton/test/test_viewer.py b/third_party/enflame/include/triton/third_party/proton/test/test_viewer.py new file mode 100644 index 000000000..0e526116a --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/test/test_viewer.py @@ -0,0 +1,199 @@ +import pytest +import subprocess +from triton.profiler.viewer import get_min_time_flops, get_min_time_bytes, get_raw_metrics, format_frames, derive_metrics, filter_frames, parse +from triton.profiler.hook import COMPUTE_METADATA_SCOPE_NAME +import numpy as np + +file_path = __file__ +triton_example_file = file_path.replace("test_viewer.py", "examples/triton.json") +cuda_example_file = file_path.replace("test_viewer.py", "examples/cuda.json") +hip_example_file = file_path.replace("test_viewer.py", "examples/hip.json") +frame_example_file = file_path.replace("test_viewer.py", "examples/frame.json") +leaf_example_file = file_path.replace("test_viewer.py", "examples/leaf_nodes.json") + + +def test_help(): + # Only check if the viewer can be invoked + subprocess.check_call(["proton-viewer", "-h"], stdout=subprocess.DEVNULL) + + +def test_exclusive_metrics(): + with open(triton_example_file, "r") as f: + gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(f) + gf.update_inclusive_columns() + metrics = ["cpu_time/ns"] + metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info) + gf = filter_frames(gf, None, None, None, metrics[0]) + sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False) + actual = sorted_df.iloc[0:1]["name"].values[0] + assert actual == "scope" + + +def test_sort(): + with open(leaf_example_file, "r") as f: + gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(f) + gf = format_frames(gf, None) + gf.update_inclusive_columns() + metrics = ["time/s", "time/ms", "time/us", "time/ns"] + metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info) + gf = filter_frames(gf, None, None, None, metrics[0]) + sorted_df = gf.dataframe.sort_values(by=[metrics[0]], ascending=False) + actual = sorted_df.iloc[0:5]["name"].values + expected = ["ROOT", "kernel_1_1_1", "kernel_3_1_1", "kernel_3_2_2", "kernel_1_2_2"] + assert len(actual) == len(expected) + assert all(a == b for a, b in zip(actual, expected)) + + +@pytest.mark.parametrize("option", ["full", "file_function_line", "function_line", "file_function"]) +def test_format_frames(option): + with open(frame_example_file, "r") as f: + gf, _, _, _ = get_raw_metrics(f) + gf = format_frames(gf, option) + if option == "full": + idx = gf.dataframe["name"] == "/home/user/projects/example.py/test.py:1@foo" + elif option == "file_function_line": + idx = gf.dataframe["name"] == "test.py:1@foo" + elif option == "function_line": + idx = gf.dataframe["name"] == "1@foo" + elif option == "file_function": + idx = gf.dataframe["name"] == "test.py@foo" + assert idx.sum() == 1 + + +@pytest.mark.parametrize("option", ["include", "exclude"]) +def test_filter_frames(option): + include = "" + exclude = "" + with open(frame_example_file, "r") as f: + gf, _, _, _ = get_raw_metrics(f) + if option == "include": + include = ".*test0.*" + elif option == "exclude": + exclude = ".*test1.*" + gf = filter_frames(gf, include=include, exclude=exclude) + idx = gf.dataframe["name"] == "test1" + assert idx.sum() == 0 + idx = gf.dataframe["name"] == "test0" + assert idx.sum() == 1 + + +def test_filter_metadata(): + with open(triton_example_file, "r") as f: + gf, _, _, _ = get_raw_metrics(f) + assert COMPUTE_METADATA_SCOPE_NAME not in gf.dataframe["name"].tolist() + assert "cuda_kernel" not in gf.dataframe["name"].tolist() + assert "scope" in gf.dataframe["name"].tolist() + assert "triton_kernel" in gf.dataframe["name"].tolist() + + +def test_parse(): + gf, derived_metrics = parse(["time/s"], triton_example_file) + for derived_metric in derived_metrics: + assert derived_metric in gf.inc_metrics or derived_metric in gf.exc_metrics + + +def test_min_time_flops(): + with open(cuda_example_file, "r") as f: + gf, _, _, device_info = get_raw_metrics(f) + ret = get_min_time_flops(gf.dataframe, device_info) + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" + device2_idx = gf.dataframe["device_id"] == "2" + # sm89 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000025]], atol=1e-5) + # sm90 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[0.00005]], atol=1e-5) + # sm100 + np.testing.assert_allclose(ret[device2_idx].to_numpy(), [[0.000025]], atol=1e-5) + with open(hip_example_file, "r") as f: + gf, _, _, device_info = get_raw_metrics(f) + ret = get_min_time_flops(gf.dataframe, device_info) + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" + # MI200 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000026]], atol=1e-5) + # MI300 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[0.000038]], atol=1e-5) + + +def test_min_time_bytes(): + with open(cuda_example_file, "r") as f: + gf, _, _, device_info = get_raw_metrics(f) + ret = get_min_time_bytes(gf.dataframe, device_info) + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" + # sm89 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[9.91969e-06]], atol=1e-6) + # sm90 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[2.48584e-05]], atol=1e-6) + with open(hip_example_file, "r") as f: + gf, _, _, device_info = get_raw_metrics(f) + ret = get_min_time_bytes(gf.dataframe, device_info) + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" + # MI200 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[6.10351e-06]], atol=1e-6) + # MI300 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[1.93378e-05]], atol=1e-6) + + +def test_percentage(): + pass + + +def derivation_metrics_test(metrics, expected_data, sample_file, rtol=1e-7, atol=1e-6): + with open(sample_file, "r") as f: + gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(f) + assert len(inclusive_metrics + exclusive_metrics) > 0, "No metrics found in the input file" + gf.update_inclusive_columns() + derived_metrics = derive_metrics(gf, metrics, inclusive_metrics, exclusive_metrics, device_info) + for derived_metric in derived_metrics: + np.testing.assert_allclose(gf.dataframe[derived_metric].to_numpy(), expected_data[derived_metric], + rtol=rtol, atol=atol) + + +def test_avg_time_derivation(): + derivation_metrics_test( + metrics=["avg_time/s", "avg_time/ms", "avg_time/us", "avg_time/ns"], expected_data={ + "avg_time/s (inc)": [0.0000512, 0.0000205, 0.000205, + 0.000205], "avg_time/ms (inc)": [0.0512, 0.02048, 0.2048, 0.2048], "avg_time/us (inc)": + [51.2, 20.48, 204.8, 204.8], "avg_time/ns (inc)": [51200.0, 20480.0, 204800.0, 204800.0] + }, sample_file=cuda_example_file) + + +def test_util(): + derivation_metrics_test(metrics=["util"], expected_data={ + "util": [np.nan, 0.247044, 0.147830, 0.118451], + }, sample_file=cuda_example_file) + + +def test_time_derivation(): + derivation_metrics_test( + metrics=["time/s", "time/ms", "time/us", "time/ns"], expected_data={ + "time/s (inc)": [0.000614, 0.0002048, 0.0002048, 0.0002048], + "time/ms (inc)": [0.6144, 0.2048, 0.2048, 0.2048], + "time/us (inc)": [614.4, 204.8, 204.8, 204.8], + "time/ns (inc)": [614400.0, 204800.0, 204800.0, 204800.0], + "time/% (inc)": [100.0, 50.0, 50.0, 50.0], + }, sample_file=cuda_example_file) + + +def test_bytes_derivation(): + derivation_metrics_test( + metrics=["byte/s", "gbyte/s", "tbyte/s"], expected_data={ + "byte/s (inc)": [1.953125e+11, 4.88281250e+11, 4.88281250e+10, + 4.88281250e+10], "gbyte/s (inc)": [195.3125, 488.28125, 48.828125, 48.828125], + "tbyte/s (inc)": [0.195312, 0.48828125, 0.04882812, 0.04882812] + }, sample_file=cuda_example_file) + + +def test_flops_derivation(): + derivation_metrics_test( + metrics=["flop8/s", "gflop8/s", "tflop8/s"], + expected_data={ + "flop8/s (inc)": [3.417969e+14, 4.88281250e+14, 4.88281250e+13, + 4.88281250e+14], "gflop8/s (inc)": [341796.875, 488281.25, 48828.125, 488281.25], + "tflop8/s (inc)": [341.796875, 488.28125, 48.828125, 488.28125] + }, + sample_file=cuda_example_file, + ) diff --git a/third_party/enflame/include/triton/third_party/proton/tutorials/dynamic_net.py b/third_party/enflame/include/triton/third_party/proton/tutorials/dynamic_net.py new file mode 100644 index 000000000..5793bebd0 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/tutorials/dynamic_net.py @@ -0,0 +1,102 @@ +import random +import torch +import math + +import triton.profiler as proton +import argparse + +mode = "torch" + + +class DynamicNet(torch.nn.Module): + # https://pytorch.org/tutorials/beginner/examples_nn/dynamic_net.html + def __init__(self): + """ + In the constructor we instantiate five parameters and assign them as members. + """ + super().__init__() + self.a = torch.nn.Parameter(torch.randn(())) + self.b = torch.nn.Parameter(torch.randn(())) + self.c = torch.nn.Parameter(torch.randn(())) + self.d = torch.nn.Parameter(torch.randn(())) + self.e = torch.nn.Parameter(torch.randn(())) + + def forward(self, x): + """ + For the forward pass of the model, we randomly choose either 4, 5 + and reuse the e parameter to compute the contribution of these orders. + + Since each forward pass builds a dynamic computation graph, we can use normal + Python control-flow operators like loops or conditional statements when + defining the forward pass of the model. + + Here we also see that it is perfectly safe to reuse the same parameter many + times when defining a computational graph. + """ + y = self.a + self.b * x + self.c * x**2 + self.d * x**3 + for exp in range(4, random.randint(4, 6)): + y = y + self.e * x**exp + return y + + def string(self): + """ + Just like any class in Python, you can also define custom method on PyTorch modules + """ + return f"y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?" + + +def run(): + # Create Tensors to hold input and outputs. + with proton.scope("init"): + x = torch.linspace(-math.pi, math.pi, 2000, device="cuda") + y = torch.sin(x) + + # Construct our model by instantiating the class defined above + model = DynamicNet().to("cuda") + if mode == "torchinductor": + model = torch.compile(model) + + # Construct our loss function and an Optimizer. Training this strange model with + # vanilla stochastic gradient descent is tough, so we use momentum + criterion = torch.nn.MSELoss(reduction="sum") + optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9) + for t in range(1000): + # Forward pass: Compute predicted y by passing x to the model + with proton.scope("forward"): + y_pred = model(x) + + # Compute and print loss + with proton.scope("loss"): + loss = criterion(y_pred, y) + if t % 200 == 199: + print(t, loss.item()) + + # Zero gradients, perform a backward pass, and update the weights. + with proton.scope("backward"): + optimizer.zero_grad() + loss.backward() + with proton.scope("optimizer"): + optimizer.step() + + print(f"Result: {model.string()}") + + +argparser = argparse.ArgumentParser() +argparser.add_argument("--profile", action="store_true") +argparser.add_argument("--mode", default="torch", choices=["torch", "torchinductor"]) +argparser.add_argument("--context", default="shadow", choices=["shadow", "python"]) +argparser.add_argument("--backend", default=None, choices=["cupti", "roctracer", "cupti_pcsampling"]) + +args = argparser.parse_args() + +mode = args.mode + +if args.profile: + func = proton.profile(run, name="dynamic_net", context=args.context, backend=args.backend) +else: + func = run + +func() +# Write out the profile +# Visualize using `proton-viewer -m time/s ./dynamic_net.hatchet` +proton.finalize() diff --git a/third_party/enflame/include/triton/third_party/proton/tutorials/matmul.py b/third_party/enflame/include/triton/third_party/proton/tutorials/matmul.py new file mode 100644 index 000000000..0a7ffc163 --- /dev/null +++ b/third_party/enflame/include/triton/third_party/proton/tutorials/matmul.py @@ -0,0 +1,318 @@ +import torch + +import triton +import triton.language as tl +import triton.profiler as proton +from typing import NamedTuple +import argparse + + +def unpack_grid(grid): + if len(grid) == 1: + return grid[0], 1, 1 + if len(grid) == 2: + return grid[0], grid[1], 1 + if len(grid) == 3: + return grid[0], grid[1], grid[2] + + +def metadata_fn( + grid: tuple, + metadata: NamedTuple, + args: dict, +): + grid_x, grid_y, grid_z = unpack_grid(grid) + num_warps = metadata.num_warps + num_stages = metadata.num_stages + cluster_x, cluster_y, cluster_z = metadata.cluster_dims + shared_memory = metadata.shared + M, K = args["a_ptr"].shape + K, N = args["b_ptr"].shape + return { + "name": + f"matmul_____", + "flops": 2 * M * N * K, + "bytes": (M * N + N * K + K * M) * args["a_ptr"].element_size(), + } + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + ], + key=["M", "N", "K"], +) +@triton.jit(launch_metadata=metadata_fn) +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. +@triton.jit +def leaky_relu(x): + x = x + 1 + return tl.where(x >= 0, x, 0.01 * x) + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + + # 1D launch kernel where each block gets its own program. + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ACTIVATION=activation, # + ) + return c + + +argparser = argparse.ArgumentParser() +argparser.add_argument("--profile", action="store_true") +argparser.add_argument("--pcsampling", action="store_true", default=False) +argparser.add_argument("--cudagraph", action="store_true", default=False) +args = argparser.parse_args() + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + x_vals=[128 * i for i in range(2, 10)], # Different possible values for `x_name` + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + line_vals=["cublas", "triton"], + # Label name for the lines + line_names=["cuBLAS", "Triton"], + # Line styles + styles=[("green", "-"), ("blue", "-")], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. + args={}, + )) +def benchmark(M, N, K, provider): + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) + quantiles = [0.5, 0.2, 0.8] + with proton.scope(f"matmul_{M}_{N}_{K}"): + if provider == "cublas": + + @proton.scope( + "cublas", + metrics={ + "flops": 2 * M * N * K, + "bytes": (M * N + N * K + K * M) * a.element_size(), + }, + ) + def cublas_matmul(a, b): + torch.matmul(a, b) + + if args.cudagraph: + ms = triton.testing.do_bench_cudagraph(lambda: cublas_matmul(a, b)) + min_ms = max_ms = ms + else: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: cublas_matmul(a, b), quantiles=quantiles) + if provider == "triton": + + def enter_autotune(args, reset_only=False): + if reset_only: + return + proton.enter_scope("") + + def exit_autotune(args, exception): + proton.exit_scope() + + matmul_kernel.pre_hook = enter_autotune + matmul_kernel.post_hook = exit_autotune + with proton.scope("triton"): + if args.cudagraph: + ms = triton.testing.do_bench_cudagraph(lambda: matmul(a, b)) + min_ms = max_ms = ms + else: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + + def perf(ms): + return 2 * M * N * K * 1e-12 / (ms * 1e-3) + + return perf(ms), perf(max_ms), perf(min_ms) + + +if args.profile: + if args.pcsampling: + # proton-viewer -m num_samples/%,time/s ./matmul.hatchet + proton.start("matmul", hook="triton", backend="cupti_pcsampling") + else: + # proton-viewer -m tflop/s,time/s ./matmul.hatchet + proton.start("matmul", hook="triton") + benchmark.run(show_plots=True, print_data=True) + proton.finalize() +else: + benchmark.run(show_plots=True, print_data=True) diff --git a/third_party/enflame/include/triton/unittest/Analysis/CMakeLists.txt b/third_party/enflame/include/triton/unittest/Analysis/CMakeLists.txt new file mode 100644 index 000000000..829d0bff7 --- /dev/null +++ b/third_party/enflame/include/triton/unittest/Analysis/CMakeLists.txt @@ -0,0 +1,8 @@ +add_triton_ut( + NAME TestTritonAnalysis + SRCS UtilityTest.cpp + LIBS + TritonAnalysis + TritonIR + TritonGPUIR +) diff --git a/third_party/enflame/include/triton/unittest/Analysis/UtilityTest.cpp b/third_party/enflame/include/triton/unittest/Analysis/UtilityTest.cpp new file mode 100644 index 000000000..70d95363e --- /dev/null +++ b/third_party/enflame/include/triton/unittest/Analysis/UtilityTest.cpp @@ -0,0 +1,32 @@ +#include "triton/Dialect/Triton/IR/Utility.h" + +#include "llvm/Support/Signals.h" +#include + +namespace mlir { + +TEST(Analysis, reorder) { + SmallVector shape({10, 20, 30}); + { + SmallVector order({2, 1, 0}); + auto reordered = triton::applyPermutation(shape, order); + EXPECT_EQ(reordered[0], 30); + EXPECT_EQ(reordered[1], 20); + EXPECT_EQ(reordered[2], 10); + } + { + SmallVector order({1, 0, 2}); + auto reordered = triton::applyPermutation(shape, order); + EXPECT_EQ(reordered[0], 20); + EXPECT_EQ(reordered[1], 10); + EXPECT_EQ(reordered[2], 30); + } +} + +} // namespace mlir + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/enflame/include/triton/unittest/CMakeLists.txt b/third_party/enflame/include/triton/unittest/CMakeLists.txt new file mode 100644 index 000000000..b4061e90c --- /dev/null +++ b/third_party/enflame/include/triton/unittest/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Analysis) +add_subdirectory(Dialect) +add_subdirectory(Tools) diff --git a/third_party/enflame/include/triton/unittest/Dialect/CMakeLists.txt b/third_party/enflame/include/triton/unittest/Dialect/CMakeLists.txt new file mode 100644 index 000000000..eba47a67c --- /dev/null +++ b/third_party/enflame/include/triton/unittest/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonGPU) diff --git a/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/CMakeLists.txt b/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..ad9629323 --- /dev/null +++ b/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,21 @@ +add_triton_ut( + NAME TestSwizzling + SRCS SwizzleTest.cpp + LIBS TritonGPUIR TritonNvidiaGPUIR +) +add_triton_ut( + NAME Dialect + SRCS DialectTest.cpp + LIBS TritonGPUIR +) +add_triton_ut( + NAME LinearLayoutConversions + SRCS LinearLayoutConversionsTest.cpp + LIBS TritonGPUIR +) + +add_triton_ut( + NAME DumpLayoutTest + SRCS DumpLayoutTest.cpp + LIBS TritonGPUIR +) diff --git a/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/DialectTest.cpp b/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/DialectTest.cpp new file mode 100644 index 000000000..ff3fcffcf --- /dev/null +++ b/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -0,0 +1,773 @@ +#include +#include +#include +#include + +#include "mlir/AsmParser/AsmParser.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Signals.h" + +namespace { + +template std::string stringifyLLVMType(const T &t) { + std::string str; + llvm::raw_string_ostream ros(str); + ros << t; + return str; +} +} // namespace + +namespace mlir { +// gtest printer for mlir::Attribute. This must live in namespace mlir in order +// for it to be found via ADL. +void PrintTo(const Attribute &attr, std::ostream *os) { + *os << stringifyLLVMType(attr); +} +} // namespace mlir + +namespace mlir::triton::gpu { +namespace { + +std::vector +createDistributedEncodings(MLIRContext &ctx) { + // Assorted distributed encodings to run tests on + // Define a tensor shape + auto rank = 2; + SmallVector> orders = {{0, 1}, {1, 0}}; + SmallVector ctaLayouts = { + triton::gpu::CTALayoutAttr::getDefault(&ctx, rank), + triton::gpu::CTALayoutAttr::get(&ctx, {4, 2}, {2, 2}, {1, 0}), + }; + std::vector distributedEncodings; + + // Create blocked and slice(blocked) encodings + { + SmallVector sizePerThread = {4, 4}; + SmallVector threadsPerWarp = {4, 8}; + SmallVector warpsPerCTA = {2, 2}; + + for (auto ctaLayout : ctaLayouts) { + for (const auto &order : orders) { + auto blockedEncoding = triton::gpu::BlockedEncodingAttr::get( + &ctx, sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + distributedEncodings.push_back(blockedEncoding); + distributedEncodings.push_back( + triton::gpu::SliceEncodingAttr::get(&ctx, 0, blockedEncoding)); + } + } + } + + // Create an MMAv2 and DotOperandEncodingAttr (MMAv3 doesn't support linear + // layouts yet) + { + for (auto versionMajor : {2, 3}) { + unsigned versionMinor = 0; + auto kWidth = 2; + SmallVector warpsPerCTA{4, 2}; + auto instrShape = versionMajor == 2 ? SmallVector{16, 8} + : SmallVector{16, 32, 16}; + auto mma = triton::gpu::NvidiaMmaEncodingAttr::get( + &ctx, versionMajor, versionMinor, warpsPerCTA, ctaLayouts[0], + instrShape); + distributedEncodings.push_back(mma); + // Create an opIdx=0 and opIdx=1 encoding + for (unsigned opIdx = 0; opIdx < 2; ++opIdx) { + if (opIdx == 1 && versionMajor == 3) { + // MMAv3 doesn't support register operand on the rhs + continue; + } + distributedEncodings.push_back( + triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx, mma, kWidth)); + } + } + } + return distributedEncodings; +} + +std::string strReplace(std::string s, const std::string &from, + const std::string &to) { + size_t start_pos = 0; + while ((start_pos = s.find(from, start_pos)) != std::string::npos) { + s.replace(start_pos, from.length(), to); + start_pos += to.length(); + } + return s; +} + +// We use some abbreviations when spelling out MLIR types. +std::string expandTyStr(std::string s) { + s = strReplace(s, "T<", "tensor<"); + s = strReplace(s, "#B", "#ttg.blocked"); + s = strReplace(s, "spt", "sizePerThread"); + s = strReplace(s, "tpw", "threadsPerWarp"); + s = strReplace(s, "wpc", "warpsPerCTA"); + s = strReplace(s, "ord", "order"); + return s; +} + +// Advances a multidimensional index. Returns true if we wrapped around to the +// beginning. +bool advance(MutableArrayRef idx, ArrayRef shape, + ArrayRef order) { + for (int dim : order) { + if (idx[dim] < shape[dim] - 1) { + idx[dim]++; + return false; + } + idx[dim] = 0; + } + return true; +} + +// Gets a flat index from a multidimensional index. +int64_t getFlatIdx(ArrayRef idx, ArrayRef shape, + ArrayRef order) { + int64_t flatIdx = 0; + int64_t stride = 1; + for (int i = 0; i < idx.size(); i++) { + flatIdx += idx[order[i]] * stride; + stride *= shape[order[i]]; + } + return flatIdx; +} + +class InferLayoutTest : public ::testing::Test { +public: + InferLayoutTest() + : inferLayout( + ctx.getOrLoadDialect() + ->getRegisteredInterface()) {} + +protected: + static MLIRContext ctx; + + DialectInferLayoutInterface *inferLayout; +}; + +/*static*/ MLIRContext InferLayoutTest::ctx; + +void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, + std::optional expectedDstEnc, + DialectInferLayoutInterface *inferLayout, + bool longErrors = true) { + + MLIRContext *ctx = srcTy.getContext(); + + // Capture any errors from calling inferReshapeNoOpReorderEncoding, so we can + // print them if we expected the reshape to succeed but it failed. + std::vector diags; + Attribute inferredEnc; + LogicalResult result = success(); + { + ScopedDiagnosticHandler scopedHandler( + ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); }); + result = inferLayout->inferReshapeOpEncoding( + srcTy.getShape(), srcTy.getEncoding(), dstTy.getShape(), inferredEnc, + UnknownLoc::get(ctx)); + } + + // We expect the reshape to succeed as long as the inputs have the same + // number of elements + EXPECT_TRUE(succeeded(result)) + << "Expected reshape to succeed, but it didn't! Error(s):\n" + << join(diags, "\n"); + + if (auto expectedEnc = dstTy.getEncoding()) { + EXPECT_EQ(inferredEnc, expectedEnc); + } + + // We know that infer(srcShape, srcEnc, dstShape) => dstEnc. Check that it + // works the other way around too: infer(dstShape, dstEnc, srcShape) => + // srcEnc. (This is an invariant of the inference function.) + // Even more, we check that the inferred encoding is structurally the same as + // the src encoding, showing that the inference is consistent. + { + std::vector diags; + ScopedDiagnosticHandler scopedHandler( + ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); }); + Attribute inferredSrcEnc; + auto result = inferLayout->inferReshapeOpEncoding( + dstTy.getShape(), inferredEnc, srcTy.getShape(), inferredSrcEnc, + UnknownLoc::get(ctx)); + EXPECT_TRUE(succeeded(result)) + << "Inverse encoding inference (" << triton::join(dstTy.getShape(), "x") + << " " << stringifyLLVMType(inferredEnc) << " -> " + << triton::join(srcTy.getShape(), "x") << "failed:\n" + << join(diags, "\n"); + auto srcLinear = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + auto inferredSrcLinear = toLinearLayout(srcTy.getShape(), inferredSrcEnc); + EXPECT_EQ(inferredSrcLinear, srcLinear) + << "Inverse encoding inference (" << triton::join(dstTy.getShape(), "x") + << " " << stringifyLLVMType(inferredEnc) << " -> " + << triton::join(srcTy.getShape(), "x") + << " gave the wrong result. Expected " << srcLinear.toString() + << " but " + << "got " << inferredSrcLinear.toString() << ".\n"; + } + + // The funtional characterisation of resize is that, if we have a srcLayout + // and a dstLayout, then the flattened layouts are views of the same data + // when considered as C-contiguous. + auto makeFlattenedCContig = [](ArrayRef shape, Attribute layout) { + auto ctx = layout.getContext(); + auto linear = toLinearLayout(shape, layout); + auto dims = standardOutDimNames(ctx, shape.size()); + std::reverse(dims.begin(), dims.end()); + return linear.transposeOuts(dims).reshapeOuts( + {{dims.back(), linear.getTotalOutDimSize()}}); + }; + EXPECT_EQ(makeFlattenedCContig(srcTy.getShape(), srcTy.getEncoding()), + makeFlattenedCContig(dstTy.getShape(), inferredEnc)); +} + +class InferReshapeOpEncodingTest + : public InferLayoutTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(InferReshapeOpEncodingTest, DoIt) { + std::string srcTyStr = expandTyStr(std::get<0>(GetParam())); + std::string dstTyStr = expandTyStr(std::get<1>(GetParam())); + + auto src = mlir::parseType(srcTyStr, &ctx); + if (!src) + FAIL() << "Could not parse source type: " << srcTyStr; + + auto dst = mlir::parseType(dstTyStr, &ctx); + if (!dst) + FAIL() << "Could not parse destination type: " << dstTyStr; + + std::optional expectedDstEnc; + if (auto dstEnc = cast(dst).getEncoding()) { + expectedDstEnc = cast(dstEnc); + } + + testReshape(cast(src), cast(dst), + expectedDstEnc, inferLayout, /*longErrors=*/true); +} + +// A testcase of {a, b, c} means: +// - if `c` is false, check that a reshape from shape+encoding `a` to shape `b` +// is deemed impossible. +// - else if `c` is true: +// - check that a reshape from shape+encoding `a` to shape `b` yields an +// encoding that makes the reshape a nop, and +// - if b has an encoding, check that the inferred encoding matches b's. +INSTANTIATE_TEST_SUITE_P( + Reshapes, InferReshapeOpEncodingTest, + ::testing::ValuesIn(std::vector>({ + // Use raw strings in here so clang-format doesn't try to wrap them. + {R"(T<128x64xf32, #B<{spt=[1,1], tpw=[1,32], wpc=[1,1], ord=[1,0]}>>)", + R"(T<8192xf32, #B<{spt=[1], tpw=[32], wpc=[1], ord=[0]}>>)"}, + + {R"(T<128xf32, #B<{spt=[4], tpw=[32], wpc=[1], ord=[0]}>>)", + R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)"}, + + {R"(T<128xf32, #B<{spt=[4], tpw=[32], wpc=[1], ord=[0]}>>)", + R"(T<16x8xf32, #B<{spt=[1,4], tpw=[16,2], wpc=[1,1], ord=[1,0]}>>)"}, + + {R"(T<32x32xf32, #B<{spt=[2,2], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", + "T<1024xf32>"}, + + {R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", + R"(T<2x16x2x2xf32, #B<{spt=[1,1,2,2], tpw=[2,16,1,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, + + {R"(T<4x32xf32, #B<{spt=[4,1], tpw=[1,32], wpc=[1,1], ord=[0,1]}>>)", + R"(T<2x2x2x16xf32, #B<{spt=[2,2,1,1], tpw=[1,1,2,16], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)"}, + + {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[1,0]}>>)", + R"(T<2x16x2x16xf32, #B<{spt=[1,4,1,4], tpw=[1,4,2,4], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, + + {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[1,0]}>>)", + R"(T<16x2x16x2xf32, #B<{spt=[2,2,2,2], tpw=[4,1,8,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, + + {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[0,1]}>>)", + R"(T<16x2x16x2xf32>)"}, + + // nop reshape, but the block size is 2x larger than the tensor. + {R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)", + R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)"}, + + {R"(T<2x4x2x4xf32, #B<{spt=[1,2,2,1], tpw=[1,2,1,2], wpc=[1,2,2,1], ord=[2,1,0,3]}>>)", + R"(T<4x2x2x4xf32>)"}, + + {R"(T<1x2x2x4xf32, #B<{spt=[1,32,4,4], tpw=[4,4,16,16], wpc=[8,8,8,1], ord=[0,1,2,3]}>>)", + R"(T<2x2x4x1xf32>)"}, + + {R"(T<2x2x2x2xf32, #B<{spt=[2,2,2,2], tpw=[1,1,1,1], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)", + R"(T<4x4xf32>)"}, + + {R"(T<16x8xf32, #B<{spt=[1,2], tpw=[2,4], wpc=[2,1], ord=[1,0]}>>)", + R"(T<128xf32>)"}, + + {R"(T<16x1x8xf32, #B<{spt=[8,1,1], tpw=[2,1,1], wpc=[1,1,8], ord=[2,1,0]}>>)", + R"(T<128x1xf32>)"}, + + {R"(T<16x1x8xf32, #B<{spt=[1,1,8], tpw=[2,1,1], wpc=[8,1,1], ord=[2,1,0]}>>)", + R"(T<128x1xf32>)"}, + + {R"(T<32x32xf32, #B<{spt=[1,2], tpw=[1,8], wpc=[1,1], ord=[1,0]}>>)", + R"(T<1024xf32>)"}, + + {R"(T<4x4xf32, #B<{spt=[1,1], tpw=[2,4], wpc=[2,1], ord=[0,1]}>>)", + R"(T<16xf32>)"}, + + {R"(T<32xf32, #B<{spt=[2], tpw=[32], wpc=[2], ord=[0]}>>)", + R"(T<16x2xf32, #B<{spt=[1,2], tpw=[32,1], wpc=[2,1], ord=[1,0]}>>)"}, + + {R"(T<2x1x2xf32, #B<{spt=[2,1,1], tpw=[2,1,2], wpc=[4,1,8], ord=[2,1,0]}>>)", + R"(T<2x2xf32, #B<{spt=[2,1], tpw=[2,2], wpc=[4,8], ord=[1,0]}>>)"}, + }))); + +class Fp4ToFpOpTest : public ::testing::Test { +public: + Fp4ToFpOpTest() { ctx.getOrLoadDialect(); } + +protected: + MLIRContext ctx; +}; + +TEST_F(Fp4ToFpOpTest, Fp4ToFpOpLayoutPropagation) { + SmallVector> shapes = {{64, 128}, {256, 1024}}; + auto distributedEncodings = createDistributedEncodings(ctx); + auto *inferLayout = + ctx.getOrLoadDialect() + ->getRegisteredInterface(); + + for (auto enc : distributedEncodings) { + for (auto shape : shapes) { + if (auto sliceEncoding = dyn_cast(enc)) { + shape.erase(shape.begin() + sliceEncoding.getDim()); + } + auto rank = shape.size(); + auto axis = rank - 1; + // Test that we can do a round trip from src to dst encoding and back. + Attribute dstEnc; + LogicalResult result = inferLayout->inferFp4ToFpOpEncoding( + shape, axis, enc, dstEnc, /*fwdInference=*/true, std::nullopt); + EXPECT_TRUE(succeeded(result)); + Attribute newSrcEnc; + auto newShape = shape; + newShape[axis] *= 2; + result = inferLayout->inferFp4ToFpOpEncoding( + newShape, axis, dstEnc, newSrcEnc, /*fwdInference=*/false, + std::nullopt); + EXPECT_TRUE(succeeded(result)); + // Structural equality. + EXPECT_EQ(toLinearLayout(shape, newSrcEnc), toLinearLayout(shape, enc)); + // We'll have equality iff dstEnc is a legacy encoding. + if (!isa(dstEnc)) { + EXPECT_EQ(newSrcEnc, enc); + } + } + } +} + +class JoinOpTest : public ::testing::Test { +public: + JoinOpTest() { ctx.getOrLoadDialect(); } + +protected: + MLIRContext ctx; +}; + +TEST_F(JoinOpTest, JoinOpLayoutPropagation) { + SmallVector> shapes = {{64, 128}, {256, 1024}}; + auto distributedEncodings = createDistributedEncodings(ctx); + auto *inferLayout = + ctx.getOrLoadDialect() + ->getRegisteredInterface(); + + for (auto enc : distributedEncodings) { + for (auto shape : shapes) { + if (auto sliceEncoding = dyn_cast(enc)) { + shape.erase(shape.begin() + sliceEncoding.getDim()); + } + auto rank = shape.size(); + // Join only supports Linear or Blocked + auto linear = LinearEncodingAttr::get(&ctx, toLinearLayout(shape, enc)); + // Test that we can do a round trip from src to dst encoding and back. + Attribute dstEnc; + LogicalResult result = + inferLayout->inferJoinOpEncoding(linear, dstEnc, shape, std::nullopt); + EXPECT_TRUE(succeeded(result)); + Attribute newSrcEnc; + auto newShape = shape; + newShape.push_back(2); + result = inferLayout->inferSplitOpEncoding(dstEnc, newSrcEnc, newShape, + std::nullopt); + EXPECT_TRUE(succeeded(result)); + // Structural equality. + EXPECT_EQ(toLinearLayout(shape, newSrcEnc), toLinearLayout(shape, enc)); + // We'll have equality iff dstEnc is a legacy encoding. + if (!isa(dstEnc)) { + EXPECT_EQ(newSrcEnc, enc); + } + + // We test against this decomposition: + // newShape = shape + // newShape[axis] *= 2 + // rank = len(shape) + // transShape = list(range(rank)) + // transShape.insert(axis + 1, rank) + // join(enc, enc).trans(transShape).reshape(newShape) + auto axis = rank - 1; + auto transPerm = llvm::to_vector(llvm::seq(0, rank)); + transPerm.insert(transPerm.begin() + axis + 1, rank); + Attribute joinedEnc; + result = + inferLayout->inferJoinOpEncoding(enc, joinedEnc, shape, std::nullopt); + auto joinShape = shape; + joinShape.push_back(2); + assert(succeeded(result)); + Attribute transEnc; + result = inferLayout->inferTransOpEncoding(joinedEnc, joinShape, + transPerm, transEnc); + assert(succeeded(result)); + SmallVector transShape; + for (auto i : transPerm) { + transShape.push_back(joinShape[i]); + } + Attribute reshapedEnc; + result = inferLayout->inferReshapeOpEncoding( + transShape, transEnc, newShape, reshapedEnc, std::nullopt); + assert(succeeded(result)); + // The layouts should be structurally the same + // but reshapeEnc will likely be a LinearEncodingAttr + EXPECT_EQ(toLinearLayout(newShape, reshapedEnc), + toLinearLayout(newShape, dstEnc)); + } + } +} + +class AMDLayoutTest : public ::testing::Test { +public: + AMDLayoutTest() { + ctx.getOrLoadDialect(); + ctaLayout = + triton::gpu::CTALayoutAttr::get(&ctx, ctaPerCGA, ctaSplit, ctaOrder); + f16Ty = Float16Type::get(&ctx); + } + + triton::gpu::DotOperandEncodingAttr + createDotOperand(int idx, Attribute parent, int kWidth) { + return triton::gpu::DotOperandEncodingAttr::get(&ctx, idx, parent, kWidth); + } + +protected: + MLIRContext ctx; + const SmallVector ctaPerCGA{1, 1, 1}; + const SmallVector ctaSplit{1, 1, 1}; + const SmallVector ctaOrder{2, 1, 0}; + triton::gpu::CTALayoutAttr ctaLayout; + Type f16Ty; +}; + +class AMDMfmaLayoutTest : public AMDLayoutTest { +public: + AMDMfmaLayoutTest() = default; + + triton::gpu::AMDMfmaEncodingAttr createMFMA(int mDim, int nDim, + ArrayRef warpsPerCTA) { + return triton::gpu::AMDMfmaEncodingAttr::get( + &ctx, /*versionMajor=*/2, /*versionMinor=*/0, warpsPerCTA, mDim, nDim, + /*isTransposed=*/false, ctaLayout); + } + + triton::gpu::AMDMfmaEncodingAttr + createTransposedMFMA(int mDim, int nDim, ArrayRef warpsPerCTA) { + return triton::gpu::AMDMfmaEncodingAttr::get( + &ctx, /*versionMajor=*/2, /*versionMinor=*/0, warpsPerCTA, mDim, nDim, + /*isTransposed=*/true, ctaLayout); + } +}; + +class AMDWmmaLayoutTest : public AMDLayoutTest { +public: + AMDWmmaLayoutTest() = default; + + triton::gpu::AMDWmmaEncodingAttr + createWMMAv1(ArrayRef warpsPerCTA) { + return triton::gpu::AMDWmmaEncodingAttr::get( + &ctx, /*version=*/1, /*isTransposed=*/false, warpsPerCTA, ctaLayout); + } + + triton::gpu::AMDWmmaEncodingAttr + createWMMAv2(bool isTransposed, ArrayRef warpsPerCTA) { + return triton::gpu::AMDWmmaEncodingAttr::get( + &ctx, /*version=*/2, isTransposed, warpsPerCTA, ctaLayout); + } +}; + +TEST_F(AMDMfmaLayoutTest, mfma32) { + auto mfma2d = createMFMA(32, 32, {2, 4}); + ASSERT_THAT(mfma2d.getDefaultThreadOrder(), testing::ElementsAre(1u, 0u)); + ASSERT_THAT(mfma2d.getDefaultWarpOrder(), testing::ElementsAre(1u, 0u)); + + auto tmfma2d = createTransposedMFMA(32, 32, {2, 4}); + ASSERT_THAT(tmfma2d.getDefaultThreadOrder(), testing::ElementsAre(0u, 1u)); + ASSERT_THAT(tmfma2d.getDefaultWarpOrder(), testing::ElementsAre(1u, 0u)); + + auto mfma3d = createMFMA(32, 32, {2, 4, 1}); + ASSERT_THAT(mfma3d.getDefaultThreadOrder(), testing::ElementsAre(2u, 1u, 0u)); + ASSERT_THAT(mfma3d.getDefaultWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); + + auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1}); + ASSERT_THAT(tmfma3d.getDefaultThreadOrder(), + testing::ElementsAre(1u, 2u, 0u)); + ASSERT_THAT(tmfma3d.getDefaultWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); +} + +TEST_F(AMDMfmaLayoutTest, mfma16) { + auto mfma2d = createMFMA(16, 16, {2, 4}); + ASSERT_THAT(mfma2d.getDefaultThreadOrder(), testing::ElementsAre(1u, 0u)); + ASSERT_THAT(mfma2d.getDefaultWarpOrder(), testing::ElementsAre(1u, 0u)); + + auto tmfma2d = createTransposedMFMA(16, 16, {2, 4}); + ASSERT_THAT(tmfma2d.getDefaultThreadOrder(), testing::ElementsAre(0u, 1u)); + ASSERT_THAT(tmfma2d.getDefaultWarpOrder(), testing::ElementsAre(1u, 0u)); + + auto mfma3d = createMFMA(16, 16, {2, 4, 1}); + ASSERT_THAT(mfma3d.getDefaultThreadOrder(), testing::ElementsAre(2u, 1u, 0u)); + ASSERT_THAT(mfma3d.getDefaultWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); + + auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1}); + ASSERT_THAT(tmfma3d.getDefaultThreadOrder(), + testing::ElementsAre(1u, 2u, 0u)); + ASSERT_THAT(tmfma3d.getDefaultWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); +} + +TEST_F(AMDMfmaLayoutTest, mfma_dot_op) { + auto mfma2d = createMFMA(32, 32, {2, 4}); + auto dot2dOp0 = createDotOperand(0, mfma2d, 4); + auto dot2dOp1 = createDotOperand(1, mfma2d, 4); + ASSERT_THAT(dot2dOp0.getDefaultWarpOrder(), mfma2d.getDefaultWarpOrder()); + ASSERT_THAT(dot2dOp1.getDefaultWarpOrder(), mfma2d.getDefaultWarpOrder()); + ASSERT_THAT(dot2dOp0.getThreadsPerWarp(), testing::ElementsAre(32u, 2u)); + ASSERT_THAT(dot2dOp1.getThreadsPerWarp(), testing::ElementsAre(2u, 32u)); + + auto tmfma2d = createTransposedMFMA(32, 32, {2, 4}); + auto tdot2dOp0 = createDotOperand(0, tmfma2d, 4); + auto tdot2dOp1 = createDotOperand(1, tmfma2d, 4); + ASSERT_THAT(tdot2dOp0.getDefaultWarpOrder(), tmfma2d.getDefaultWarpOrder()); + ASSERT_THAT(tdot2dOp1.getDefaultWarpOrder(), tmfma2d.getDefaultWarpOrder()); + + auto mfma3d = createMFMA(32, 32, {2, 4, 1}); + auto dot3dOp0 = createDotOperand(0, mfma3d, 4); + auto dot3dOp1 = createDotOperand(1, mfma3d, 4); + ASSERT_THAT(dot3dOp0.getDefaultWarpOrder(), mfma3d.getDefaultWarpOrder()); + ASSERT_THAT(dot3dOp1.getDefaultWarpOrder(), mfma3d.getDefaultWarpOrder()); + ASSERT_THAT(dot3dOp0.getThreadsPerWarp(), testing::ElementsAre(1u, 32u, 2u)); + ASSERT_THAT(dot3dOp1.getThreadsPerWarp(), testing::ElementsAre(1u, 2u, 32u)); + + auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1}); + auto tdot3dOp0 = createDotOperand(0, tmfma3d, 4); + auto tdot3dOp1 = createDotOperand(1, tmfma3d, 4); + ASSERT_THAT(tdot3dOp0.getDefaultWarpOrder(), tmfma3d.getDefaultWarpOrder()); + ASSERT_THAT(tdot3dOp1.getDefaultWarpOrder(), tmfma3d.getDefaultWarpOrder()); + + auto mfma16_2d = createMFMA(16, 16, {2, 4}); + auto dot16_2dOp0 = createDotOperand(0, mfma16_2d, 4); + auto dot16_2dOp1 = createDotOperand(1, mfma16_2d, 4); + ASSERT_THAT(dot16_2dOp0.getThreadsPerWarp(), testing::ElementsAre(16u, 4u)); + ASSERT_THAT(dot16_2dOp1.getThreadsPerWarp(), testing::ElementsAre(4u, 16u)); + + auto mfma16_3d = createMFMA(16, 16, {2, 4, 1}); + auto dot16_3dOp0 = createDotOperand(0, mfma16_3d, 4); + auto dot16_3dOp1 = createDotOperand(1, mfma16_3d, 4); + ASSERT_THAT(dot16_3dOp0.getThreadsPerWarp(), + testing::ElementsAre(1u, 16u, 4u)); + ASSERT_THAT(dot16_3dOp1.getThreadsPerWarp(), + testing::ElementsAre(1u, 4u, 16u)); +} + +TEST_F(AMDWmmaLayoutTest, wmmaV1) { + auto wmma2d = createWMMAv1({2, 4}); + ASSERT_THAT(wmma2d.getDefaultThreadOrder(), testing::ElementsAre(1u, 0u)); + ASSERT_THAT(wmma2d.getDefaultWarpOrder(), testing::ElementsAre(1u, 0u)); + ASSERT_THAT(wmma2d.getThreadsPerWarp(), testing::ElementsAre(2u, 16u)); + + auto wmma3d = createWMMAv1({2, 4, 1}); + ASSERT_THAT(wmma3d.getDefaultThreadOrder(), testing::ElementsAre(2u, 1u, 0u)); + ASSERT_THAT(wmma3d.getDefaultWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); + ASSERT_THAT(wmma3d.getThreadsPerWarp(), testing::ElementsAre(1, 2u, 16u)); +} + +TEST_F(AMDWmmaLayoutTest, wmmaV2) { + auto wmma2d = createWMMAv2(false, {2, 4}); + ASSERT_THAT(wmma2d.getDefaultThreadOrder(), testing::ElementsAre(1u, 0u)); + ASSERT_THAT(wmma2d.getDefaultWarpOrder(), testing::ElementsAre(1u, 0u)); + ASSERT_THAT(wmma2d.getThreadsPerWarp(), testing::ElementsAre(2u, 16u)); + + auto wmma3d = createWMMAv2(false, {2, 4, 1}); + ASSERT_THAT(wmma3d.getDefaultThreadOrder(), testing::ElementsAre(2u, 1u, 0u)); + ASSERT_THAT(wmma3d.getDefaultWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); + ASSERT_THAT(wmma3d.getThreadsPerWarp(), testing::ElementsAre(1u, 2u, 16u)); + + auto twmma2d = createWMMAv2(true, {2, 4}); + ASSERT_THAT(twmma2d.getDefaultThreadOrder(), testing::ElementsAre(0u, 1u)); + ASSERT_THAT(twmma2d.getDefaultWarpOrder(), testing::ElementsAre(1u, 0u)); + ASSERT_THAT(twmma2d.getThreadsPerWarp(), testing::ElementsAre(16u, 2u)); + + auto twmma3d = createWMMAv2(true, {2, 4, 1}); + ASSERT_THAT(twmma3d.getDefaultThreadOrder(), + testing::ElementsAre(1u, 2u, 0u)); + ASSERT_THAT(twmma3d.getDefaultWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); + ASSERT_THAT(twmma3d.getThreadsPerWarp(), testing::ElementsAre(1u, 16u, 2u)); +} + +TEST_F(AMDWmmaLayoutTest, wmma_dot_op) { + auto wmma2dVer1 = createWMMAv1({2, 4}); + auto dot2dVer1Op0 = createDotOperand(0, wmma2dVer1, 16); + auto dot2dVer1Op1 = createDotOperand(1, wmma2dVer1, 16); + ASSERT_THAT(dot2dVer1Op0.getDefaultWarpOrder(), + wmma2dVer1.getDefaultWarpOrder()); + ASSERT_THAT(dot2dVer1Op1.getDefaultWarpOrder(), + wmma2dVer1.getDefaultWarpOrder()); + ASSERT_THAT(dot2dVer1Op0.getThreadsPerWarp(), testing::ElementsAre(16u, 1u)); + ASSERT_THAT(dot2dVer1Op1.getThreadsPerWarp(), testing::ElementsAre(1u, 16u)); + + auto wmma3dVer1 = createWMMAv1({2, 4, 1}); + auto dot3dVer1Op0 = createDotOperand(0, wmma3dVer1, 16); + auto dot3dVer1Op1 = createDotOperand(1, wmma3dVer1, 16); + ASSERT_THAT(dot3dVer1Op0.getDefaultWarpOrder(), + wmma3dVer1.getDefaultWarpOrder()); + ASSERT_THAT(dot3dVer1Op1.getDefaultWarpOrder(), + wmma3dVer1.getDefaultWarpOrder()); + ASSERT_THAT(dot3dVer1Op0.getThreadsPerWarp(), + testing::ElementsAre(1, 16u, 1u)); + ASSERT_THAT(dot3dVer1Op1.getThreadsPerWarp(), + testing::ElementsAre(1, 1u, 16u)); + + auto wmma2dVer2 = createWMMAv2(false, {2, 4}); + auto dot2dVer2Op0 = createDotOperand(0, wmma2dVer2, 16); + auto dot2dVer2Op1 = createDotOperand(1, wmma2dVer2, 16); + ASSERT_THAT(dot2dVer2Op0.getDefaultWarpOrder(), + wmma2dVer2.getDefaultWarpOrder()); + ASSERT_THAT(dot2dVer2Op1.getDefaultWarpOrder(), + wmma2dVer2.getDefaultWarpOrder()); + ASSERT_THAT(dot2dVer2Op0.getThreadsPerWarp(), testing::ElementsAre(16u, 2u)); + ASSERT_THAT(dot2dVer2Op1.getThreadsPerWarp(), testing::ElementsAre(2u, 16u)); + + auto wmma3dVer2 = createWMMAv2(false, {2, 4, 1}); + auto dot3dVer2Op0 = createDotOperand(0, wmma3dVer2, 16); + auto dot3dVer2Op1 = createDotOperand(1, wmma3dVer2, 16); + ASSERT_THAT(dot3dVer2Op0.getDefaultWarpOrder(), + wmma3dVer2.getDefaultWarpOrder()); + ASSERT_THAT(dot3dVer2Op1.getDefaultWarpOrder(), + wmma3dVer2.getDefaultWarpOrder()); + ASSERT_THAT(dot3dVer2Op0.getThreadsPerWarp(), + testing::ElementsAre(1, 16u, 2u)); + ASSERT_THAT(dot3dVer2Op1.getThreadsPerWarp(), + testing::ElementsAre(1, 2u, 16u)); +} + +class LinearEncodingTest : public ::testing::Test { +public: + LinearEncodingTest() { ctx.getOrLoadDialect(); } + +protected: + MLIRContext ctx; +}; + +TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) { + // Define a tensor shape + auto rank = 2; + SmallVector> shapes = {{64, 128}, {256, 1024}}; + std::vector distributedEncodings = + createDistributedEncodings(ctx); + + auto n = distributedEncodings.size(); + for (auto i = 0; i < n; ++i) { + if (auto blocked = dyn_cast( + distributedEncodings[i])) { + for (unsigned opIdx = 0; opIdx < 2; ++opIdx) { + distributedEncodings.push_back( + triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx, blocked, 0)); + } + } + } + + auto is_dot_op_with_block_parent = [](Attribute layout) { + auto dot_layout = dyn_cast(layout); + return dot_layout && + isa(dot_layout.getParent()); + }; + + for (const auto &distributedEncoding : distributedEncodings) { + for (auto shape : shapes) { + if (auto sliceEncoding = + dyn_cast(distributedEncoding)) { + shape.erase(shape.begin() + sliceEncoding.getDim()); + } + + // Create LinearEncodingAttr from the LinearLayout + auto linearLayout = distributedEncoding.toLinearLayout(shape); + auto linearEncoding = + triton::gpu::LinearEncodingAttr::get(&ctx, linearLayout); + + // Test that the canonical form of the LinearLayout is indeed canonical + // by expanding it to the original shape + auto expandedLL = linearEncoding.toLinearLayout(shape); + ASSERT_EQ(linearLayout, expandedLL); + + // Test that methods of DistributedEncoding return the same values + Type eltTy = Float32Type::get(&ctx); + + ASSERT_EQ(distributedEncoding.getTotalElemsPerThread(shape), + linearEncoding.getTotalElemsPerThread(shape)); + ASSERT_EQ(distributedEncoding.getElemsPerThread(shape), + linearEncoding.getElemsPerThread(shape)); + if (!is_dot_op_with_block_parent(distributedEncoding)) { + ASSERT_EQ(distributedEncoding.getRepOrder(), + linearEncoding.getRepOrder()); + } + // DotOperandEncodingAttr::getDefaultWarpOrder() is not defined + if (!isa(distributedEncoding)) { + ASSERT_EQ(distributedEncoding.getDefaultWarpOrder(), + linearEncoding.getWarpOrder()); + } + if (!is_dot_op_with_block_parent(distributedEncoding)) { + ASSERT_EQ(distributedEncoding.getDefaultThreadOrder(), + linearEncoding.getThreadOrder()); + } + // For slice these do not equal the total number of lines / warps + // See [Note. Divergence of methods wrt. legacy layouts] + if (!isa(distributedEncoding)) { + ASSERT_EQ(distributedEncoding.getWarpsPerCTA(), + linearEncoding.getWarpsPerCTA()); + if (!is_dot_op_with_block_parent(distributedEncoding)) { + ASSERT_EQ(distributedEncoding.getThreadsPerWarp(), + linearEncoding.getThreadsPerWarp()); + } + } + + // block level + // SliceEncoding is not well-defined for CGAs + if (!isa(distributedEncoding)) { + auto baseEncoding = cast(distributedEncoding); + ASSERT_EQ(baseEncoding.getCTASplitNum(), + linearEncoding.getCTASplitNum()); + ASSERT_EQ(baseEncoding.getCTAsPerCGA(), baseEncoding.getCTAsPerCGA()); + // If we are not using CGAs, the order is meaningless + auto useCGA = + baseEncoding.getCTAsPerCGA() != SmallVector(rank, 1); + if (useCGA && !is_dot_op_with_block_parent(distributedEncoding)) { + ASSERT_EQ(baseEncoding.getCTAOrder(), linearEncoding.getCTAOrder()); + } + } + } + } +} +} // namespace +} // namespace mlir::triton::gpu + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/DumpLayoutTest.cpp b/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/DumpLayoutTest.cpp new file mode 100644 index 000000000..5c6260371 --- /dev/null +++ b/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/DumpLayoutTest.cpp @@ -0,0 +1,525 @@ +#include "mlir/IR/MLIRContext.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Signals.h" +#include +#include + +namespace mlir::triton::gpu { +namespace { + +class DumpLayoutTest : public ::testing::Test { +public: + void SetUp() { ctx.getOrLoadDialect(); } + + BlockedEncodingAttr blocked(ArrayRef spt, ArrayRef tpw, + ArrayRef wpb, ArrayRef cpg, + ArrayRef cSplit, ArrayRef ord, + ArrayRef cOrd) { + return BlockedEncodingAttr::get( + &ctx, spt, tpw, wpb, ord, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); + } + + SwizzledSharedEncodingAttr shared(unsigned vec, unsigned perPhase, + unsigned maxPhase, ArrayRef cpg, + ArrayRef cSplit, + ArrayRef ord, + ArrayRef cOrd) { + return SwizzledSharedEncodingAttr::get( + &ctx, vec, perPhase, maxPhase, ord, + CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); + } + + void assertSameStr(const std::string &refStr, const std::string &output) { + if (refStr != output) { + llvm::outs() << "RefStr =\n" + << refStr << "\n" + << "\n" + << "Output =\n" + << output << "\n"; + FAIL() << "Incorrect output string"; + } + } + +protected: + MLIRContext ctx; +}; + +TEST_F(DumpLayoutTest, SimpleBlocked) { + std::string ref = + R"([ T0:0| T4:0| T8:0|T12:0|T16:0|T20:0|T24:0|T28:0, T1:0| T5:0| T9:0|T13:0|T17:0|T21:0|T25:0|T29:0, T2:0| T6:0|T10:0|T14:0|T18:0|T22:0|T26:0|T30:0, T3:0| T7:0|T11:0|T15:0|T19:0|T23:0|T27:0|T31:0] +)"; + auto blockedLayout = blocked({1}, {8}, {4}, {1}, {1}, {0}, {0}); + auto tensorType = RankedTensorType::get( + {4}, IntegerType::get(blockedLayout.getContext(), 32), blockedLayout); + std::string layout = getLayoutStr(tensorType, /*useHWPointOfView=*/false); + assertSameStr(ref, layout); + + std::string refHWRep = + R"(Warp0: +(0), (1), (2), (3), (0), (1), (2), (3) +Warp1: +(0), (1), (2), (3), (0), (1), (2), (3) +Warp2: +(0), (1), (2), (3), (0), (1), (2), (3) +Warp3: +(0), (1), (2), (3), (0), (1), (2), (3) +)"; + std::string layoutHW = getLayoutStr(tensorType, /*useHWPointOfView=*/true); + assertSameStr(refHWRep, layoutHW); +} + +TEST_F(DumpLayoutTest, NDTensor) { + auto blockedLayout = blocked({2, 1, 4}, {2, 2, 2}, {1, 2, 1}, {1, 1, 1}, + {1, 1, 1}, {2, 1, 0}, {2, 1, 0}); + auto tensorType = RankedTensorType::get( + {8, 2, 16}, IntegerType::get(blockedLayout.getContext(), 32), + blockedLayout); + std::string ref = + R"([[[ T0:0| T8:0, T0:1| T8:1, T0:2| T8:2, T0:3| T8:3, T1:0| T9:0, T1:1| T9:1, T1:2| T9:2, T1:3| T9:3, T0:8| T8:8, T0:9| T8:9, T0:10| T8:10, T0:11| T8:11, T1:8| T9:8, T1:9| T9:9, T1:10| T9:10, T1:11| T9:11] +[ T2:0| T10:0, T2:1| T10:1, T2:2| T10:2, T2:3| T10:3, T3:0| T11:0, T3:1| T11:1, T3:2| T11:2, T3:3| T11:3, T2:8| T10:8, T2:9| T10:9, T2:10|T10:10, T2:11|T10:11, T3:8| T11:8, T3:9| T11:9, T3:10|T11:10, T3:11|T11:11]] +[[ T0:4| T8:4, T0:5| T8:5, T0:6| T8:6, T0:7| T8:7, T1:4| T9:4, T1:5| T9:5, T1:6| T9:6, T1:7| T9:7, T0:12| T8:12, T0:13| T8:13, T0:14| T8:14, T0:15| T8:15, T1:12| T9:12, T1:13| T9:13, T1:14| T9:14, T1:15| T9:15] +[ T2:4| T10:4, T2:5| T10:5, T2:6| T10:6, T2:7| T10:7, T3:4| T11:4, T3:5| T11:5, T3:6| T11:6, T3:7| T11:7, T2:12|T10:12, T2:13|T10:13, T2:14|T10:14, T2:15|T10:15, T3:12|T11:12, T3:13|T11:13, T3:14|T11:14, T3:15|T11:15]] +[[ T4:0| T12:0, T4:1| T12:1, T4:2| T12:2, T4:3| T12:3, T5:0| T13:0, T5:1| T13:1, T5:2| T13:2, T5:3| T13:3, T4:8| T12:8, T4:9| T12:9, T4:10|T12:10, T4:11|T12:11, T5:8| T13:8, T5:9| T13:9, T5:10|T13:10, T5:11|T13:11] +[ T6:0| T14:0, T6:1| T14:1, T6:2| T14:2, T6:3| T14:3, T7:0| T15:0, T7:1| T15:1, T7:2| T15:2, T7:3| T15:3, T6:8| T14:8, T6:9| T14:9, T6:10|T14:10, T6:11|T14:11, T7:8| T15:8, T7:9| T15:9, T7:10|T15:10, T7:11|T15:11]] +[[ T4:4| T12:4, T4:5| T12:5, T4:6| T12:6, T4:7| T12:7, T5:4| T13:4, T5:5| T13:5, T5:6| T13:6, T5:7| T13:7, T4:12|T12:12, T4:13|T12:13, T4:14|T12:14, T4:15|T12:15, T5:12|T13:12, T5:13|T13:13, T5:14|T13:14, T5:15|T13:15] +[ T6:4| T14:4, T6:5| T14:5, T6:6| T14:6, T6:7| T14:7, T7:4| T15:4, T7:5| T15:5, T7:6| T15:6, T7:7| T15:7, T6:12|T14:12, T6:13|T14:13, T6:14|T14:14, T6:15|T14:15, T7:12|T15:12, T7:13|T15:13, T7:14|T15:14, T7:15|T15:15]] +[[ T0:16| T8:16, T0:17| T8:17, T0:18| T8:18, T0:19| T8:19, T1:16| T9:16, T1:17| T9:17, T1:18| T9:18, T1:19| T9:19, T0:24| T8:24, T0:25| T8:25, T0:26| T8:26, T0:27| T8:27, T1:24| T9:24, T1:25| T9:25, T1:26| T9:26, T1:27| T9:27] +[ T2:16|T10:16, T2:17|T10:17, T2:18|T10:18, T2:19|T10:19, T3:16|T11:16, T3:17|T11:17, T3:18|T11:18, T3:19|T11:19, T2:24|T10:24, T2:25|T10:25, T2:26|T10:26, T2:27|T10:27, T3:24|T11:24, T3:25|T11:25, T3:26|T11:26, T3:27|T11:27]] +[[ T0:20| T8:20, T0:21| T8:21, T0:22| T8:22, T0:23| T8:23, T1:20| T9:20, T1:21| T9:21, T1:22| T9:22, T1:23| T9:23, T0:28| T8:28, T0:29| T8:29, T0:30| T8:30, T0:31| T8:31, T1:28| T9:28, T1:29| T9:29, T1:30| T9:30, T1:31| T9:31] +[ T2:20|T10:20, T2:21|T10:21, T2:22|T10:22, T2:23|T10:23, T3:20|T11:20, T3:21|T11:21, T3:22|T11:22, T3:23|T11:23, T2:28|T10:28, T2:29|T10:29, T2:30|T10:30, T2:31|T10:31, T3:28|T11:28, T3:29|T11:29, T3:30|T11:30, T3:31|T11:31]] +[[ T4:16|T12:16, T4:17|T12:17, T4:18|T12:18, T4:19|T12:19, T5:16|T13:16, T5:17|T13:17, T5:18|T13:18, T5:19|T13:19, T4:24|T12:24, T4:25|T12:25, T4:26|T12:26, T4:27|T12:27, T5:24|T13:24, T5:25|T13:25, T5:26|T13:26, T5:27|T13:27] +[ T6:16|T14:16, T6:17|T14:17, T6:18|T14:18, T6:19|T14:19, T7:16|T15:16, T7:17|T15:17, T7:18|T15:18, T7:19|T15:19, T6:24|T14:24, T6:25|T14:25, T6:26|T14:26, T6:27|T14:27, T7:24|T15:24, T7:25|T15:25, T7:26|T15:26, T7:27|T15:27]] +[[ T4:20|T12:20, T4:21|T12:21, T4:22|T12:22, T4:23|T12:23, T5:20|T13:20, T5:21|T13:21, T5:22|T13:22, T5:23|T13:23, T4:28|T12:28, T4:29|T12:29, T4:30|T12:30, T4:31|T12:31, T5:28|T13:28, T5:29|T13:29, T5:30|T13:30, T5:31|T13:31] +[ T6:20|T14:20, T6:21|T14:21, T6:22|T14:22, T6:23|T14:23, T7:20|T15:20, T7:21|T15:21, T7:22|T15:22, T7:23|T15:23, T6:28|T14:28, T6:29|T14:29, T6:30|T14:30, T6:31|T14:31, T7:28|T15:28, T7:29|T15:29, T7:30|T15:30, T7:31|T15:31]]] +)"; + std::string layout = getLayoutStr(tensorType, /*useHWPointOfView=*/false); + assertSameStr(ref, layout); + std::string refHWRep = + R"(Warp0: +(0,0, 0), (0,0, 4), (0,1, 0), (0,1, 4), (2,0, 0), (2,0, 4), (2,1, 0), (2,1, 4) +(0,0, 1), (0,0, 5), (0,1, 1), (0,1, 5), (2,0, 1), (2,0, 5), (2,1, 1), (2,1, 5) +(0,0, 2), (0,0, 6), (0,1, 2), (0,1, 6), (2,0, 2), (2,0, 6), (2,1, 2), (2,1, 6) +(0,0, 3), (0,0, 7), (0,1, 3), (0,1, 7), (2,0, 3), (2,0, 7), (2,1, 3), (2,1, 7) +(1,0, 0), (1,0, 4), (1,1, 0), (1,1, 4), (3,0, 0), (3,0, 4), (3,1, 0), (3,1, 4) +(1,0, 1), (1,0, 5), (1,1, 1), (1,1, 5), (3,0, 1), (3,0, 5), (3,1, 1), (3,1, 5) +(1,0, 2), (1,0, 6), (1,1, 2), (1,1, 6), (3,0, 2), (3,0, 6), (3,1, 2), (3,1, 6) +(1,0, 3), (1,0, 7), (1,1, 3), (1,1, 7), (3,0, 3), (3,0, 7), (3,1, 3), (3,1, 7) +(0,0, 8), (0,0,12), (0,1, 8), (0,1,12), (2,0, 8), (2,0,12), (2,1, 8), (2,1,12) +(0,0, 9), (0,0,13), (0,1, 9), (0,1,13), (2,0, 9), (2,0,13), (2,1, 9), (2,1,13) +(0,0,10), (0,0,14), (0,1,10), (0,1,14), (2,0,10), (2,0,14), (2,1,10), (2,1,14) +(0,0,11), (0,0,15), (0,1,11), (0,1,15), (2,0,11), (2,0,15), (2,1,11), (2,1,15) +(1,0, 8), (1,0,12), (1,1, 8), (1,1,12), (3,0, 8), (3,0,12), (3,1, 8), (3,1,12) +(1,0, 9), (1,0,13), (1,1, 9), (1,1,13), (3,0, 9), (3,0,13), (3,1, 9), (3,1,13) +(1,0,10), (1,0,14), (1,1,10), (1,1,14), (3,0,10), (3,0,14), (3,1,10), (3,1,14) +(1,0,11), (1,0,15), (1,1,11), (1,1,15), (3,0,11), (3,0,15), (3,1,11), (3,1,15) +(4,0, 0), (4,0, 4), (4,1, 0), (4,1, 4), (6,0, 0), (6,0, 4), (6,1, 0), (6,1, 4) +(4,0, 1), (4,0, 5), (4,1, 1), (4,1, 5), (6,0, 1), (6,0, 5), (6,1, 1), (6,1, 5) +(4,0, 2), (4,0, 6), (4,1, 2), (4,1, 6), (6,0, 2), (6,0, 6), (6,1, 2), (6,1, 6) +(4,0, 3), (4,0, 7), (4,1, 3), (4,1, 7), (6,0, 3), (6,0, 7), (6,1, 3), (6,1, 7) +(5,0, 0), (5,0, 4), (5,1, 0), (5,1, 4), (7,0, 0), (7,0, 4), (7,1, 0), (7,1, 4) +(5,0, 1), (5,0, 5), (5,1, 1), (5,1, 5), (7,0, 1), (7,0, 5), (7,1, 1), (7,1, 5) +(5,0, 2), (5,0, 6), (5,1, 2), (5,1, 6), (7,0, 2), (7,0, 6), (7,1, 2), (7,1, 6) +(5,0, 3), (5,0, 7), (5,1, 3), (5,1, 7), (7,0, 3), (7,0, 7), (7,1, 3), (7,1, 7) +(4,0, 8), (4,0,12), (4,1, 8), (4,1,12), (6,0, 8), (6,0,12), (6,1, 8), (6,1,12) +(4,0, 9), (4,0,13), (4,1, 9), (4,1,13), (6,0, 9), (6,0,13), (6,1, 9), (6,1,13) +(4,0,10), (4,0,14), (4,1,10), (4,1,14), (6,0,10), (6,0,14), (6,1,10), (6,1,14) +(4,0,11), (4,0,15), (4,1,11), (4,1,15), (6,0,11), (6,0,15), (6,1,11), (6,1,15) +(5,0, 8), (5,0,12), (5,1, 8), (5,1,12), (7,0, 8), (7,0,12), (7,1, 8), (7,1,12) +(5,0, 9), (5,0,13), (5,1, 9), (5,1,13), (7,0, 9), (7,0,13), (7,1, 9), (7,1,13) +(5,0,10), (5,0,14), (5,1,10), (5,1,14), (7,0,10), (7,0,14), (7,1,10), (7,1,14) +(5,0,11), (5,0,15), (5,1,11), (5,1,15), (7,0,11), (7,0,15), (7,1,11), (7,1,15) +Warp1: +(0,0, 0), (0,0, 4), (0,1, 0), (0,1, 4), (2,0, 0), (2,0, 4), (2,1, 0), (2,1, 4) +(0,0, 1), (0,0, 5), (0,1, 1), (0,1, 5), (2,0, 1), (2,0, 5), (2,1, 1), (2,1, 5) +(0,0, 2), (0,0, 6), (0,1, 2), (0,1, 6), (2,0, 2), (2,0, 6), (2,1, 2), (2,1, 6) +(0,0, 3), (0,0, 7), (0,1, 3), (0,1, 7), (2,0, 3), (2,0, 7), (2,1, 3), (2,1, 7) +(1,0, 0), (1,0, 4), (1,1, 0), (1,1, 4), (3,0, 0), (3,0, 4), (3,1, 0), (3,1, 4) +(1,0, 1), (1,0, 5), (1,1, 1), (1,1, 5), (3,0, 1), (3,0, 5), (3,1, 1), (3,1, 5) +(1,0, 2), (1,0, 6), (1,1, 2), (1,1, 6), (3,0, 2), (3,0, 6), (3,1, 2), (3,1, 6) +(1,0, 3), (1,0, 7), (1,1, 3), (1,1, 7), (3,0, 3), (3,0, 7), (3,1, 3), (3,1, 7) +(0,0, 8), (0,0,12), (0,1, 8), (0,1,12), (2,0, 8), (2,0,12), (2,1, 8), (2,1,12) +(0,0, 9), (0,0,13), (0,1, 9), (0,1,13), (2,0, 9), (2,0,13), (2,1, 9), (2,1,13) +(0,0,10), (0,0,14), (0,1,10), (0,1,14), (2,0,10), (2,0,14), (2,1,10), (2,1,14) +(0,0,11), (0,0,15), (0,1,11), (0,1,15), (2,0,11), (2,0,15), (2,1,11), (2,1,15) +(1,0, 8), (1,0,12), (1,1, 8), (1,1,12), (3,0, 8), (3,0,12), (3,1, 8), (3,1,12) +(1,0, 9), (1,0,13), (1,1, 9), (1,1,13), (3,0, 9), (3,0,13), (3,1, 9), (3,1,13) +(1,0,10), (1,0,14), (1,1,10), (1,1,14), (3,0,10), (3,0,14), (3,1,10), (3,1,14) +(1,0,11), (1,0,15), (1,1,11), (1,1,15), (3,0,11), (3,0,15), (3,1,11), (3,1,15) +(4,0, 0), (4,0, 4), (4,1, 0), (4,1, 4), (6,0, 0), (6,0, 4), (6,1, 0), (6,1, 4) +(4,0, 1), (4,0, 5), (4,1, 1), (4,1, 5), (6,0, 1), (6,0, 5), (6,1, 1), (6,1, 5) +(4,0, 2), (4,0, 6), (4,1, 2), (4,1, 6), (6,0, 2), (6,0, 6), (6,1, 2), (6,1, 6) +(4,0, 3), (4,0, 7), (4,1, 3), (4,1, 7), (6,0, 3), (6,0, 7), (6,1, 3), (6,1, 7) +(5,0, 0), (5,0, 4), (5,1, 0), (5,1, 4), (7,0, 0), (7,0, 4), (7,1, 0), (7,1, 4) +(5,0, 1), (5,0, 5), (5,1, 1), (5,1, 5), (7,0, 1), (7,0, 5), (7,1, 1), (7,1, 5) +(5,0, 2), (5,0, 6), (5,1, 2), (5,1, 6), (7,0, 2), (7,0, 6), (7,1, 2), (7,1, 6) +(5,0, 3), (5,0, 7), (5,1, 3), (5,1, 7), (7,0, 3), (7,0, 7), (7,1, 3), (7,1, 7) +(4,0, 8), (4,0,12), (4,1, 8), (4,1,12), (6,0, 8), (6,0,12), (6,1, 8), (6,1,12) +(4,0, 9), (4,0,13), (4,1, 9), (4,1,13), (6,0, 9), (6,0,13), (6,1, 9), (6,1,13) +(4,0,10), (4,0,14), (4,1,10), (4,1,14), (6,0,10), (6,0,14), (6,1,10), (6,1,14) +(4,0,11), (4,0,15), (4,1,11), (4,1,15), (6,0,11), (6,0,15), (6,1,11), (6,1,15) +(5,0, 8), (5,0,12), (5,1, 8), (5,1,12), (7,0, 8), (7,0,12), (7,1, 8), (7,1,12) +(5,0, 9), (5,0,13), (5,1, 9), (5,1,13), (7,0, 9), (7,0,13), (7,1, 9), (7,1,13) +(5,0,10), (5,0,14), (5,1,10), (5,1,14), (7,0,10), (7,0,14), (7,1,10), (7,1,14) +(5,0,11), (5,0,15), (5,1,11), (5,1,15), (7,0,11), (7,0,15), (7,1,11), (7,1,15) +)"; + std::string layoutHW = getLayoutStr(tensorType, /*useHWPointOfView=*/true); + assertSameStr(refHWRep, layoutHW); +} + +TEST_F(DumpLayoutTest, Simple1DShared) { + std::string refStr = + "[( 0),( 1),( 2),( 3),( 4),( 5),( 6),( 7),( 8),( " + "9),(10),(11),(12),(13),(14),(15),(16),(17),(18),(19),(20),(21),(22),(23)" + ",(24),(25),(26),(27),(28),(29),(30),(31)]\n"; + + auto sharedLayout = shared(1, /* vec */ + 1, /* perPhase */ + 4, /* maxPhase */ + {1}, /* cpg */ + {1}, /* csplit */ + {1}, /* ord, row-major */ + {1}); /* cOrd */ + + auto elemTy = Float16Type::get(sharedLayout.getContext()); + auto tensorType = RankedTensorType::get({32}, elemTy, sharedLayout); + std::string layout = getLayoutStr(tensorType, /*useHWPointOfView=*/false); + assertSameStr(refStr, layout); +} + +TEST_F(DumpLayoutTest, Larger2DShared) { + + std::string refStr = + "[[(0: 0),(0: 1),(0: 2),(0: 3),(0: 4),(0: 5),(0: 6),(0: 7),(0: 8),(0: " + "9),(0:10),(0:11),(0:12),(0:13),(0:14),(0:15),(0:16),(0:17),(0:18),(0:19)" + ",(0:20),(0:21),(0:22),(0:23),(0:24),(0:25),(0:26),(0:27),(0:28),(0:29),(" + "0:30),(0:31)]\n" + "[ (1: 0),(1: 1),(1: 2),(1: 3),(1: 4),(1: 5),(1: 6),(1: 7),(1: 8),(1: " + "9),(1:10),(1:11),(1:12),(1:13),(1:14),(1:15),(1:16),(1:17),(1:18),(1:19)" + ",(1:20),(1:21),(1:22),(1:23),(1:24),(1:25),(1:26),(1:27),(1:28),(1:29),(" + "1:30),(1:31)]\n" + "[ (2: 8),(2: 9),(2:10),(2:11),(2:12),(2:13),(2:14),(2:15),(2: 0),(2: " + "1),(2: 2),(2: 3),(2: 4),(2: 5),(2: 6),(2: " + "7),(2:24),(2:25),(2:26),(2:27),(2:28),(2:29),(2:30),(2:31),(2:16),(2:17)" + ",(2:18),(2:19),(2:20),(2:21),(2:22),(2:23)]\n" + "[ (3: 8),(3: 9),(3:10),(3:11),(3:12),(3:13),(3:14),(3:15),(3: 0),(3: " + "1),(3: 2),(3: 3),(3: 4),(3: 5),(3: 6),(3: " + "7),(3:24),(3:25),(3:26),(3:27),(3:28),(3:29),(3:30),(3:31),(3:16),(3:17)" + ",(3:18),(3:19),(3:20),(3:21),(3:22),(3:23)]\n" + "[ " + "(4:16),(4:17),(4:18),(4:19),(4:20),(4:21),(4:22),(4:23),(4:24),(4:25),(" + "4:26),(4:27),(4:28),(4:29),(4:30),(4:31),(4: 0),(4: 1),(4: 2),(4: " + "3),(4: 4),(4: 5),(4: 6),(4: 7),(4: 8),(4: " + "9),(4:10),(4:11),(4:12),(4:13),(4:14),(4:15)]\n" + "[ " + "(5:16),(5:17),(5:18),(5:19),(5:20),(5:21),(5:22),(5:23),(5:24),(5:25),(" + "5:26),(5:27),(5:28),(5:29),(5:30),(5:31),(5: 0),(5: 1),(5: 2),(5: " + "3),(5: 4),(5: 5),(5: 6),(5: 7),(5: 8),(5: " + "9),(5:10),(5:11),(5:12),(5:13),(5:14),(5:15)]\n" + "[ " + "(6:24),(6:25),(6:26),(6:27),(6:28),(6:29),(6:30),(6:31),(6:16),(6:17),(" + "6:18),(6:19),(6:20),(6:21),(6:22),(6:23),(6: 8),(6: " + "9),(6:10),(6:11),(6:12),(6:13),(6:14),(6:15),(6: 0),(6: 1),(6: 2),(6: " + "3),(6: 4),(6: 5),(6: 6),(6: 7)]\n" + "[ " + "(7:24),(7:25),(7:26),(7:27),(7:28),(7:29),(7:30),(7:31),(7:16),(7:17),(" + "7:18),(7:19),(7:20),(7:21),(7:22),(7:23),(7: 8),(7: " + "9),(7:10),(7:11),(7:12),(7:13),(7:14),(7:15),(7: 0),(7: 1),(7: 2),(7: " + "3),(7: 4),(7: 5),(7: 6),(7: 7)]]\n"; + + auto sharedLayout = shared(8, /* vec */ + 2, /* perPhase */ + 8, /* maxPhase */ + {1, 1}, /* cpg */ + {1, 1}, /* csplit */ + {1, 0}, /* ord, row-major */ + {1, 0}); /* cOrd */ + + auto elemTy = Float16Type::get(sharedLayout.getContext()); + auto tensorType = RankedTensorType::get({8, 32}, elemTy, sharedLayout); + std::string layout = getLayoutStr(tensorType, /*useHWPointOfView=*/false); + assertSameStr(refStr, layout); + + std::string refHWRep = + R"(Block: 0: +Offset: 0 -> (0, 0) +Offset: 1 -> (0, 1) +Offset: 2 -> (0, 2) +Offset: 3 -> (0, 3) +Offset: 4 -> (0, 4) +Offset: 5 -> (0, 5) +Offset: 6 -> (0, 6) +Offset: 7 -> (0, 7) +Offset: 8 -> (0, 8) +Offset: 9 -> (0, 9) +Offset: 10 -> (0,10) +Offset: 11 -> (0,11) +Offset: 12 -> (0,12) +Offset: 13 -> (0,13) +Offset: 14 -> (0,14) +Offset: 15 -> (0,15) +Offset: 16 -> (0,16) +Offset: 17 -> (0,17) +Offset: 18 -> (0,18) +Offset: 19 -> (0,19) +Offset: 20 -> (0,20) +Offset: 21 -> (0,21) +Offset: 22 -> (0,22) +Offset: 23 -> (0,23) +Offset: 24 -> (0,24) +Offset: 25 -> (0,25) +Offset: 26 -> (0,26) +Offset: 27 -> (0,27) +Offset: 28 -> (0,28) +Offset: 29 -> (0,29) +Offset: 30 -> (0,30) +Offset: 31 -> (0,31) +Offset: 32 -> (1, 2) +Offset: 33 -> (1, 3) +Offset: 34 -> (1, 0) +Offset: 35 -> (1, 1) +Offset: 36 -> (1, 6) +Offset: 37 -> (1, 7) +Offset: 38 -> (1, 4) +Offset: 39 -> (1, 5) +Offset: 40 -> (1,10) +Offset: 41 -> (1,11) +Offset: 42 -> (1, 8) +Offset: 43 -> (1, 9) +Offset: 44 -> (1,14) +Offset: 45 -> (1,15) +Offset: 46 -> (1,12) +Offset: 47 -> (1,13) +Offset: 48 -> (1,18) +Offset: 49 -> (1,19) +Offset: 50 -> (1,16) +Offset: 51 -> (1,17) +Offset: 52 -> (1,22) +Offset: 53 -> (1,23) +Offset: 54 -> (1,20) +Offset: 55 -> (1,21) +Offset: 56 -> (1,26) +Offset: 57 -> (1,27) +Offset: 58 -> (1,24) +Offset: 59 -> (1,25) +Offset: 60 -> (1,30) +Offset: 61 -> (1,31) +Offset: 62 -> (1,28) +Offset: 63 -> (1,29) +Offset: 64 -> (2, 4) +Offset: 65 -> (2, 5) +Offset: 66 -> (2, 6) +Offset: 67 -> (2, 7) +Offset: 68 -> (2, 0) +Offset: 69 -> (2, 1) +Offset: 70 -> (2, 2) +Offset: 71 -> (2, 3) +Offset: 72 -> (2,12) +Offset: 73 -> (2,13) +Offset: 74 -> (2,14) +Offset: 75 -> (2,15) +Offset: 76 -> (2, 8) +Offset: 77 -> (2, 9) +Offset: 78 -> (2,10) +Offset: 79 -> (2,11) +Offset: 80 -> (2,20) +Offset: 81 -> (2,21) +Offset: 82 -> (2,22) +Offset: 83 -> (2,23) +Offset: 84 -> (2,16) +Offset: 85 -> (2,17) +Offset: 86 -> (2,18) +Offset: 87 -> (2,19) +Offset: 88 -> (2,28) +Offset: 89 -> (2,29) +Offset: 90 -> (2,30) +Offset: 91 -> (2,31) +Offset: 92 -> (2,24) +Offset: 93 -> (2,25) +Offset: 94 -> (2,26) +Offset: 95 -> (2,27) +Offset: 96 -> (3, 6) +Offset: 97 -> (3, 7) +Offset: 98 -> (3, 4) +Offset: 99 -> (3, 5) +Offset: 100 -> (3, 2) +Offset: 101 -> (3, 3) +Offset: 102 -> (3, 0) +Offset: 103 -> (3, 1) +Offset: 104 -> (3,14) +Offset: 105 -> (3,15) +Offset: 106 -> (3,12) +Offset: 107 -> (3,13) +Offset: 108 -> (3,10) +Offset: 109 -> (3,11) +Offset: 110 -> (3, 8) +Offset: 111 -> (3, 9) +Offset: 112 -> (3,22) +Offset: 113 -> (3,23) +Offset: 114 -> (3,20) +Offset: 115 -> (3,21) +Offset: 116 -> (3,18) +Offset: 117 -> (3,19) +Offset: 118 -> (3,16) +Offset: 119 -> (3,17) +Offset: 120 -> (3,30) +Offset: 121 -> (3,31) +Offset: 122 -> (3,28) +Offset: 123 -> (3,29) +Offset: 124 -> (3,26) +Offset: 125 -> (3,27) +Offset: 126 -> (3,24) +Offset: 127 -> (3,25) +Offset: 128 -> (4, 8) +Offset: 129 -> (4, 9) +Offset: 130 -> (4,10) +Offset: 131 -> (4,11) +Offset: 132 -> (4,12) +Offset: 133 -> (4,13) +Offset: 134 -> (4,14) +Offset: 135 -> (4,15) +Offset: 136 -> (4, 0) +Offset: 137 -> (4, 1) +Offset: 138 -> (4, 2) +Offset: 139 -> (4, 3) +Offset: 140 -> (4, 4) +Offset: 141 -> (4, 5) +Offset: 142 -> (4, 6) +Offset: 143 -> (4, 7) +Offset: 144 -> (4,24) +Offset: 145 -> (4,25) +Offset: 146 -> (4,26) +Offset: 147 -> (4,27) +Offset: 148 -> (4,28) +Offset: 149 -> (4,29) +Offset: 150 -> (4,30) +Offset: 151 -> (4,31) +Offset: 152 -> (4,16) +Offset: 153 -> (4,17) +Offset: 154 -> (4,18) +Offset: 155 -> (4,19) +Offset: 156 -> (4,20) +Offset: 157 -> (4,21) +Offset: 158 -> (4,22) +Offset: 159 -> (4,23) +Offset: 160 -> (5,10) +Offset: 161 -> (5,11) +Offset: 162 -> (5, 8) +Offset: 163 -> (5, 9) +Offset: 164 -> (5,14) +Offset: 165 -> (5,15) +Offset: 166 -> (5,12) +Offset: 167 -> (5,13) +Offset: 168 -> (5, 2) +Offset: 169 -> (5, 3) +Offset: 170 -> (5, 0) +Offset: 171 -> (5, 1) +Offset: 172 -> (5, 6) +Offset: 173 -> (5, 7) +Offset: 174 -> (5, 4) +Offset: 175 -> (5, 5) +Offset: 176 -> (5,26) +Offset: 177 -> (5,27) +Offset: 178 -> (5,24) +Offset: 179 -> (5,25) +Offset: 180 -> (5,30) +Offset: 181 -> (5,31) +Offset: 182 -> (5,28) +Offset: 183 -> (5,29) +Offset: 184 -> (5,18) +Offset: 185 -> (5,19) +Offset: 186 -> (5,16) +Offset: 187 -> (5,17) +Offset: 188 -> (5,22) +Offset: 189 -> (5,23) +Offset: 190 -> (5,20) +Offset: 191 -> (5,21) +Offset: 192 -> (6,12) +Offset: 193 -> (6,13) +Offset: 194 -> (6,14) +Offset: 195 -> (6,15) +Offset: 196 -> (6, 8) +Offset: 197 -> (6, 9) +Offset: 198 -> (6,10) +Offset: 199 -> (6,11) +Offset: 200 -> (6, 4) +Offset: 201 -> (6, 5) +Offset: 202 -> (6, 6) +Offset: 203 -> (6, 7) +Offset: 204 -> (6, 0) +Offset: 205 -> (6, 1) +Offset: 206 -> (6, 2) +Offset: 207 -> (6, 3) +Offset: 208 -> (6,28) +Offset: 209 -> (6,29) +Offset: 210 -> (6,30) +Offset: 211 -> (6,31) +Offset: 212 -> (6,24) +Offset: 213 -> (6,25) +Offset: 214 -> (6,26) +Offset: 215 -> (6,27) +Offset: 216 -> (6,20) +Offset: 217 -> (6,21) +Offset: 218 -> (6,22) +Offset: 219 -> (6,23) +Offset: 220 -> (6,16) +Offset: 221 -> (6,17) +Offset: 222 -> (6,18) +Offset: 223 -> (6,19) +Offset: 224 -> (7,14) +Offset: 225 -> (7,15) +Offset: 226 -> (7,12) +Offset: 227 -> (7,13) +Offset: 228 -> (7,10) +Offset: 229 -> (7,11) +Offset: 230 -> (7, 8) +Offset: 231 -> (7, 9) +Offset: 232 -> (7, 6) +Offset: 233 -> (7, 7) +Offset: 234 -> (7, 4) +Offset: 235 -> (7, 5) +Offset: 236 -> (7, 2) +Offset: 237 -> (7, 3) +Offset: 238 -> (7, 0) +Offset: 239 -> (7, 1) +Offset: 240 -> (7,30) +Offset: 241 -> (7,31) +Offset: 242 -> (7,28) +Offset: 243 -> (7,29) +Offset: 244 -> (7,26) +Offset: 245 -> (7,27) +Offset: 246 -> (7,24) +Offset: 247 -> (7,25) +Offset: 248 -> (7,22) +Offset: 249 -> (7,23) +Offset: 250 -> (7,20) +Offset: 251 -> (7,21) +Offset: 252 -> (7,18) +Offset: 253 -> (7,19) +Offset: 254 -> (7,16) +Offset: 255 -> (7,17) +)"; + auto sharedLayoutHW = shared(2, /* vec */ + 1, /* perPhase */ + 32, /* maxPhase */ + {1, 1}, /* cpg */ + {1, 1}, /* csplit */ + {1, 0}, /* ord, row-major */ + {1, 0}); /* cOrd */ + + auto elemTyHW = Float16Type::get(sharedLayoutHW.getContext()); + auto tensorTypeHW = RankedTensorType::get({8, 32}, elemTyHW, sharedLayoutHW); + + std::string layoutHW = getLayoutStr(tensorTypeHW, /*useHWPointOfView=*/true); + assertSameStr(refHWRep, layoutHW); +} + +} // anonymous namespace +} // namespace mlir::triton::gpu + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp new file mode 100644 index 000000000..27268d7a7 --- /dev/null +++ b/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -0,0 +1,2808 @@ +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" + +#include "mlir/IR/MLIRContext.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/Signals.h" +#include +#include + +namespace mlir { +std::ostream &operator<<(std::ostream &os, StringAttr str) { + os << str.str(); + return os; +} +} // namespace mlir + +namespace mlir::triton::gpu { +namespace { + +class LinearLayoutConversionsTest : public ::testing::Test { +public: + void SetUp() { ctx.getOrLoadDialect(); } + + BlockedEncodingAttr blocked(ArrayRef spt, ArrayRef tpw, + ArrayRef wpb, ArrayRef cpg, + ArrayRef cSplit, ArrayRef ord, + ArrayRef cOrd) { + return BlockedEncodingAttr::get( + &ctx, spt, tpw, wpb, ord, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); + } + + NvidiaMmaEncodingAttr mma(unsigned versionMaj, unsigned versionMin, + ArrayRef instrShape, + ArrayRef wbp, ArrayRef cpg, + ArrayRef cSplit, + ArrayRef cOrd) { + return NvidiaMmaEncodingAttr::get( + &ctx, versionMaj, versionMin, wbp, + CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape); + } + + NvidiaMmaEncodingAttr mma(unsigned versionMaj, unsigned versionMin, + ArrayRef instrShape, + ArrayRef numWarps) { + auto ctaLayout = CTALayoutAttr::getDefault(&ctx, numWarps.size()); + return NvidiaMmaEncodingAttr::get(&ctx, versionMaj, versionMin, numWarps, + std::move(ctaLayout), instrShape); + } + + DotOperandEncodingAttr dot(Attribute parent, int idx, int kWidth) { + return DotOperandEncodingAttr::get(&ctx, idx, parent, /*kWidth=*/kWidth); + } + + AMDMfmaEncodingAttr mfma(ArrayRef warps, unsigned mDim, + unsigned nDim, bool isTransposed) { + SmallVector cpg(warps.size(), 1u); + SmallVector cSplit(warps.size(), 1u); + SmallVector cOrd(warps.size()); + std::iota(cOrd.begin(), cOrd.end(), 0); + return AMDMfmaEncodingAttr::get( + &ctx, /*versionMajor=*/2, /*versionMinor=*/0, warps, mDim, nDim, + isTransposed, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); + } + + DotOperandEncodingAttr mfmaDotOp(AMDMfmaEncodingAttr mfma, unsigned opIdx, + unsigned kWidth) { + return DotOperandEncodingAttr::get(&ctx, opIdx, mfma, kWidth); + } + + AMDWmmaEncodingAttr wmma(ArrayRef warps, int version, + bool transposed) { + SmallVector cpg(warps.size(), 1u); + SmallVector cSplit(warps.size(), 1u); + SmallVector cOrd(warps.size()); + std::iota(cOrd.begin(), cOrd.end(), 0); + return AMDWmmaEncodingAttr::get( + &ctx, version, transposed, warps, + CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); + } + + DotOperandEncodingAttr wmmaDotOp(AMDWmmaEncodingAttr wmma, unsigned opIdx, + unsigned kWidth) { + return DotOperandEncodingAttr::get(&ctx, opIdx, wmma, kWidth); + } + + SliceEncodingAttr slice(DistributedEncodingTrait parent, int dim) { + return SliceEncodingAttr::get(&ctx, dim, parent); + } + + SwizzledSharedEncodingAttr shared(unsigned vec, unsigned perPhase, + unsigned maxPhase, ArrayRef cpg, + ArrayRef cSplit, + ArrayRef ord, + ArrayRef cOrd) { + return SwizzledSharedEncodingAttr::get( + &ctx, vec, perPhase, maxPhase, ord, + CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); + } + + NVMMASharedEncodingAttr + nvmmaShared(unsigned swizzleSizeInBytes, bool transposed, + unsigned elementBitWidth, ArrayRef cpg, + ArrayRef cSplit, ArrayRef ord, + ArrayRef cOrd, bool fp4Padded = false) { + return NVMMASharedEncodingAttr::get( + &ctx, swizzleSizeInBytes, transposed, elementBitWidth, fp4Padded, + CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); + } + + StringAttr S(StringRef str) { return StringAttr::get(&ctx, str); } + +protected: + MLIRContext ctx; +}; + +TEST_F(LinearLayoutConversionsTest, SimpleBlocked) { + auto layout = + toLinearLayout({16}, blocked({1}, {4}, {4}, {1}, {1}, {0}, {0})); + EXPECT_THAT(layout, LinearLayout( + { + {S("register"), {}}, + {S("lane"), {{1}, {2}}}, + {S("warp"), {{4}, {8}}}, + {S("block"), {}}, + }, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, CTADuplication) { + auto layout = toLinearLayout( + {32}, blocked({1}, {4}, {4}, /*cpg=*/{4}, /*cSplit=*/{2}, {0}, {0})); + EXPECT_EQ(layout, LinearLayout( + { + {S("register"), {}}, + {S("lane"), {{1}, {2}}}, + {S("warp"), {{4}, {8}}}, + {S("block"), {{16}, {0}}}, + }, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, CTABroadcast) { + auto layout = + toLinearLayout({64, 128}, blocked({8, 1}, {8, 4}, {1, 4}, {1, 2}, {1, 2}, + {0, 1}, {1, 0})); + EXPECT_EQ( + layout, + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {0, 16}, {0, 32}}}, + {S("lane"), {{8, 0}, {16, 0}, {32, 0}, {0, 1}, {0, 2}}}, + {S("warp"), {{0, 4}, {0, 8}}}, + {S("block"), {{0, 64}}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, ShapeLargerThanLayout) { + // The layout is 16 elements, but the shape is 128, so it's repeated 128/16 = + // 8 times. + auto layout = + toLinearLayout({128}, blocked({1}, {4}, {4}, {1}, {1}, {0}, {0})); + EXPECT_EQ(layout, LinearLayout( + { + {S("register"), {{16}, {32}, {64}}}, + {S("lane"), {{1}, {2}}}, + {S("warp"), {{4}, {8}}}, + {S("block"), {}}, + }, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, ShapeLargerThanLayout2DDegenerate) { + auto layout = toLinearLayout({128, 1}, blocked({1, 1}, {4, 1}, {4, 1}, {1, 1}, + {1, 1}, {0, 1}, {1, 0})); + EXPECT_EQ(layout, LinearLayout( + { + {S("register"), {{16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}}}, + {S("warp"), {{4, 0}, {8, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, ShapeSmallerThanLayout) { + // The shape is 8 elements, but the layout is 4*4*4 = 64 elems. Therefore the + // log2(64/8) = 3 most major bases are 0. + auto layout = toLinearLayout({8}, blocked({4}, {4}, {4}, {1}, {1}, {0}, {0})); + EXPECT_EQ(layout, LinearLayout( + { + {S("register"), {{1}, {2}}}, + {S("lane"), {{4}, {0}}}, + {S("warp"), {{0}, {0}}}, + {S("block"), {}}, + }, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, ReversedOrder) { + auto layout = toLinearLayout({1, 64}, blocked({1, 1}, {32, 1}, {1, 8}, {1, 1}, + {1, 1}, {0, 1}, {1, 0})); + EXPECT_EQ(layout, + LinearLayout( + { + {S("register"), {{0, 8}, {0, 16}, {0, 32}}}, + {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}}}, + {S("warp"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, ReplicateInRegisterDim) { + auto layout = + toLinearLayout({32}, blocked({2}, {4}, {1}, {1}, {1}, {0}, {0})); + EXPECT_EQ(layout, LinearLayout( + { + {S("register"), {{1}, {8}, {16}}}, + {S("lane"), {{2}, {4}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, OneDimTooLargeAnotherTooSmall) { + auto blockedLayout = + blocked({1, 4}, {8, 4}, {4, 1}, {2, 2}, {2, 1}, {1, 0}, {1, 0}); + auto ll = toLinearLayout({128, 16}, blockedLayout); + EXPECT_EQ(ll, LinearLayout( + { + {S("register"), {{0, 1}, {0, 2}, {32, 0}}}, + {S("lane"), {{0, 4}, {0, 8}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{8, 0}, {16, 0}}}, + {S("block"), {{0, 0}, {64, 0}}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, RepeatInCTGDimFirst) { + // We have a 4-element shape and an 8-element layout (4 elems per CTA). So + // the layout will map two inputs to each output. The question is, which two + // inputs? The answer is, we split between CTAs first, so the two CTAs have + // distinct elements. + auto blockedLayout = blocked({1}, {1}, {4}, {2}, {2}, {0}, {0}); + auto ll = toLinearLayout({4}, blockedLayout); + EXPECT_EQ(ll, LinearLayout( + { + {S("register"), {}}, + {S("lane"), {}}, + {S("warp"), {{1}, {0}}}, + {S("block"), {{2}}}, + }, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, SmallerThanCTALayout) { + auto blockedLayout = blocked({1}, {1}, {1}, {4}, {4}, {0}, {0}); + auto ll = toLinearLayout({2}, blockedLayout); + EXPECT_EQ(ll, LinearLayout( + { + {S("register"), {}}, + {S("lane"), {}}, + {S("warp"), {}}, + {S("block"), {{1}, {0}}}, + }, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, Skinny) { + auto blockedLayout = + blocked({8, 1}, {8, 4}, {1, 4}, {1, 2}, {1, 2}, {0, 1}, {0, 1}); + auto ll = toLinearLayout({64, 1}, blockedLayout); + EXPECT_EQ(ll, LinearLayout( + { + {S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{8, 0}, {16, 0}, {32, 0}, {0, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {{0, 0}}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, BlockedOrder) { + auto ll = toLinearLayout({1024, 128}, blocked({2, 2}, {4, 8}, {2, 2}, {2, 2}, + {2, 2}, {1, 0}, {1, 0})); + EXPECT_EQ(ll, LinearLayout( + { + {S("register"), + { + {0, 1}, + {1, 0}, + {0, 32}, + {16, 0}, + {32, 0}, + {64, 0}, + {128, 0}, + {256, 0}, + }}, + {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {2, 0}, {4, 0}}}, + {S("warp"), {{0, 16}, {8, 0}}}, + {S("block"), {{0, 64}, {512, 0}}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, Blocked4D) { + auto ll = toLinearLayout({2, 1, 1, 1}, + blocked({1, 1, 1, 4}, {2, 1, 1, 16}, {1, 2, 4, 1}, + {1, 1, 1, 1}, {1, 1, 1, 1}, {3, 0, 1, 2}, + {3, 2, 1, 0})); + EXPECT_EQ(ll, LinearLayout( + { + {S("register"), {{0, 0, 0, 0}, {0, 0, 0, 0}}}, + {S("lane"), + {{0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 0, 0, 0}}}, + {S("warp"), {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1"), S("dim2"), S("dim3")})); +} + +TEST_F(LinearLayoutConversionsTest, BlockedDotOperandLhs) { + auto parent = blocked(/*size*/ {2, 4}, /*threads*/ {8, 4}, /*warps*/ {2, 4}, + /*ctas*/ {1, 1}, /*splits*/ {1, 1}, /*order*/ {1, 0}, + /*cta order*/ {1, 0}); + auto dotOperand = dot(parent, /*idx*/ 0, /*kWidth*/ 0); + EXPECT_EQ( + toLinearLayout({32, 16}, dotOperand), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("lane"), {{0, 0}, {0, 0}, {2, 0}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, BlockedDot3dOperandLhs) { + auto parent = + blocked(/*size*/ {2, 2, 4}, /*threads*/ {2, 4, 4}, /*warps*/ {2, 2, 2}, + /*ctas*/ {1, 1, 1}, /*splits*/ {1, 1, 1}, /*order*/ {2, 1, 0}, + /*cta order*/ {2, 1, 0}); + auto dotOperand = dot(parent, /*idx*/ 0, /*kWidth*/ 0); + EXPECT_EQ( + toLinearLayout({16, 32, 4}, dotOperand), + LinearLayout( + {{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 1, 0}, + {1, 0, 0}, + {0, 16, 0}, + {8, 0, 0}}}, + {S("lane"), {{0, 0, 0}, {0, 0, 0}, {0, 2, 0}, {0, 4, 0}, {2, 0, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 8, 0}, {4, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, BlockedDotOperandRhs) { + auto parent = blocked(/*size*/ {2, 4}, /*threads*/ {8, 4}, /*warps*/ {2, 4}, + /*ctas*/ {1, 1}, /*splits*/ {1, 1}, /*order*/ {1, 0}, + /*cta order*/ {1, 0}); + auto dotOperand = dot(parent, /*idx*/ 1, /*kWidth*/ 0); + EXPECT_EQ(toLinearLayout({16, 64}, dotOperand), + LinearLayout({{S("register"), + {{0, 1}, {0, 2}, {1, 0}, {2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 4}, {0, 8}, {0, 0}, {0, 0}, {0, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, BlockedDot3dOperandRhs) { + auto parent = + blocked(/*size*/ {2, 2, 4}, /*threads*/ {2, 4, 4}, /*warps*/ {2, 2, 2}, + /*ctas*/ {1, 1, 1}, /*splits*/ {1, 1, 1}, /*order*/ {2, 1, 0}, + /*cta order*/ {2, 1, 0}); + auto dotOperand = dot(parent, /*idx*/ 1, /*kWidth*/ 0); + EXPECT_EQ( + toLinearLayout({16, 4, 64}, dotOperand), + LinearLayout( + {{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 1, 0}, + {0, 2, 0}, + {1, 0, 0}, + {0, 0, 32}, + {8, 0, 0}}}, + {S("lane"), {{0, 0, 4}, {0, 0, 8}, {0, 0, 0}, {0, 0, 0}, {2, 0, 0}}}, + {S("warp"), {{0, 0, 16}, {0, 0, 0}, {4, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) { + EXPECT_EQ(toLinearLayout({16, 16}, + mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, MMAv2_32x32) { + EXPECT_EQ(toLinearLayout({32, 32}, + mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}, {0, 16}, {16, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, MMAv2_ExtendDim2) { + EXPECT_EQ(toLinearLayout({16, 128}, + mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), + LinearLayout( + { + {S("register"), + {{0, 1}, {8, 0}, {0, 8}, {0, 16}, {0, 32}, {0, 64}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, MMAv2_Cga) { + EXPECT_EQ( + toLinearLayout({64, 128, 128}, mma(2, 0, {1, 16, 8}, {16, 1, 1}, + {4, 2, 2}, {4, 2, 1}, {2, 1, 0})), + LinearLayout( + { + {S("register"), + { + {0, 0, 1}, + {0, 8, 0}, + {0, 0, 8}, + {0, 0, 16}, + {0, 0, 32}, + {0, 0, 64}, + {0, 16, 0}, + {0, 32, 0}, + }}, + {S("lane"), + {{0, 0, 2}, {0, 0, 4}, {0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("warp"), {{1, 0, 0}, {2, 0, 0}, {4, 0, 0}, {8, 0, 0}}}, + {S("block"), {{0, 0, 0}, {0, 64, 0}, {16, 0, 0}, {32, 0, 0}}}, + }, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, MMAv2_Small3D) { + EXPECT_EQ(toLinearLayout({1, 128, 128}, mma(2, 0, {1, 16, 8}, {16, 1, 1}, + {4, 2, 2}, {4, 2, 1}, {2, 1, 0})), + LinearLayout( + { + {S("register"), + { + {0, 0, 1}, + {0, 8, 0}, + {0, 0, 8}, + {0, 0, 16}, + {0, 0, 32}, + {0, 0, 64}, + {0, 16, 0}, + {0, 32, 0}, + }}, + {S("lane"), + {{0, 0, 2}, {0, 0, 4}, {0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {{0, 0, 0}, {0, 64, 0}, {0, 0, 0}, {0, 0, 0}}}, + }, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, MMAv3_64x16) { + SmallVector, 2> instrShapes = {{16, 16, 8}, {16, 8, 8}}; + for (auto instrShape : instrShapes) { + SCOPED_TRACE(triton::join(instrShape, ",")); + EXPECT_EQ(toLinearLayout({64, 16}, mma(3, 0, instrShape, {4, 1}, {1, 1}, + {1, 1}, {1, 0})), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + } +} + +TEST_F(LinearLayoutConversionsTest, MMAv3_128x16) { + EXPECT_EQ(toLinearLayout({128, 16}, mma(3, 0, {16, 16, 8}, {4, 1}, {1, 1}, + {1, 1}, {1, 0})), + LinearLayout({{S("register"), {{0, 1}, {8, 0}, {0, 8}, {64, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, MMAv3_1024x1024) { + EXPECT_EQ(toLinearLayout({1024, 1024}, mma(3, 0, {16, 16, 8}, {4, 1}, {1, 1}, + {1, 1}, {1, 0})), + LinearLayout({{S("register"), + {{0, 1}, + {8, 0}, + {0, 8}, + {0, 16}, + {0, 32}, + {0, 64}, + {0, 128}, + {0, 256}, + {0, 512}, + {64, 0}, + {128, 0}, + {256, 0}, + {512, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, MMAv3_4x2Warps) { + auto legacy = mma(3, 0, {16, 32, 16}, {4, 2}, {1, 1}, {1, 1}, {1, 0}); + + EXPECT_EQ(toLinearLayout({64, 32}, legacy), + LinearLayout({{S("register"), {{0, 1}, {8, 0}, {0, 8}, {0, 16}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 64}, legacy), + LinearLayout({{S("register"), {{0, 1}, {8, 0}, {0, 8}, {0, 16}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}, {0, 32}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({128, 64}, legacy), + LinearLayout({{S("register"), {{0, 1}, {8, 0}, {0, 8}, {0, 16}, {64, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}, {0, 32}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({256, 64}, legacy), + LinearLayout({{S("register"), + {{0, 1}, {8, 0}, {0, 8}, {0, 16}, {64, 0}, {128, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}, {0, 32}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) { + auto legacy = mma(3, 0, {16, 16, 8}, {4, 4}, {1, 1}, {1, 1}, {1, 0}); + + EXPECT_EQ(toLinearLayout({16, 16}, legacy), + LinearLayout({{S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 16}, legacy), + LinearLayout({{S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 16}, legacy), + LinearLayout({{S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({128, 16}, legacy), + LinearLayout({{S("register"), {{0, 1}, {8, 0}, {0, 8}, {64, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 32}, legacy), + LinearLayout({{S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {0, 0}, {0, 16}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 32}, legacy), + LinearLayout({{S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}, {0, 16}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { + auto parent = mma(2, 0, {16, 8}, {1, 1}); + EXPECT_EQ(toLinearLayout({16, 64}, dot(parent, 0, 8)), + LinearLayout( + { + {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 8}, dot(parent, 1, 8)), + LinearLayout( + { + {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { + auto parent = mma(2, 0, {16, 8}, {4, 1}); + EXPECT_EQ( + toLinearLayout({128, 128}, dot(parent, 0, 8)), + LinearLayout( + { + {S("register"), + {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {64, 0}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({128, 64}, dot(parent, 1, 8)), + LinearLayout( + { + {S("register"), + {{1, 0}, + {2, 0}, + {4, 0}, + {32, 0}, + {64, 0}, + {0, 8}, + {0, 16}, + {0, 32}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + { + S("warp"), + {{0, 0}, {0, 0}}, + }, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 128}, dot(parent, 1, 8)), + LinearLayout( + { + {S("register"), + {{1, 0}, + {2, 0}, + {4, 0}, + {32, 0}, + {0, 8}, + {0, 16}, + {0, 32}, + {0, 64}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + { + S("warp"), + {{0, 0}, {0, 0}}, + }, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv2_3D) { + // We implement one that exercises all the paths + auto parent = mma(2, 0, {1, 16, 8}, {2, 4, 2}); + EXPECT_EQ(toLinearLayout({16, 128, 128}, dot(parent, 0, 8)), + LinearLayout( + { + {S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 8, 0}, + {0, 0, 32}, + {0, 0, 64}, + {0, 64, 0}, + {2, 0, 0}, + {4, 0, 0}, + {8, 0, 0}}}, + {S("lane"), + {{0, 0, 8}, {0, 0, 16}, {0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 16, 0}, {0, 32, 0}, {1, 0, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ(toLinearLayout({8, 128, 64}, dot(parent, 1, 8)), + LinearLayout( + { + {S("register"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 32, 0}, + {0, 64, 0}, + {0, 0, 16}, + {0, 0, 32}, + {2, 0, 0}, + {4, 0, 0}}}, + {S("lane"), + {{0, 8, 0}, {0, 16, 0}, {0, 0, 1}, {0, 0, 2}, {0, 0, 4}}}, + { + S("warp"), + {{0, 0, 8}, {0, 0, 0}, {0, 0, 0}, {1, 0, 0}}, + }, + {S("block"), {}}, + }, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv3_warp4_kwidth2) { + auto parent = mma(3, 0, {16, 16, 8}, {4, 1}); + auto dotOp = dot(parent, 0, 2); + + EXPECT_EQ(toLinearLayout({64, 16}, dotOp), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 16}, dotOp), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}, {64, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 32}, dotOp), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}, {0, 16}, {64, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv3_mixed_warp_kwidth4) { + // Testing dot with MMAv3 encoding for opIdx = 0 and kWidth = 4 + auto parent = mma(3, 0, {16, 16, 8}, {4, 2}); + auto dotOp = dot(parent, 0, 4); + + EXPECT_EQ(toLinearLayout({128, 64}, dotOp), + LinearLayout( + { + {S("register"), + {{0, 1}, {0, 2}, {8, 0}, {0, 16}, {0, 32}, {64, 0}}}, + {S("lane"), {{0, 4}, {0, 8}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}, {0, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) { + auto parent = mma(2, 0, {16, 8}, {2, 2}); + EXPECT_EQ( + toLinearLayout({32, 64}, dot(parent, 0, 8)), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({64, 16}, dot(parent, 1, 8)), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("warp"), {{0, 8}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 128}, dot(parent, 0, 8)), + LinearLayout( + {{S("register"), + {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {32, 0}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({128, 32}, dot(parent, 1, 8)), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 16}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("warp"), {{0, 8}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, SliceDot) { + // Slice layout with a DotOperand (MMAv2) as the parent. + auto parentV2 = dot(mma(2, 0, {16, 8}, {1, 1}), /*opIdx=*/0, /*kWidth=*/8); + auto sliceV2 = slice(parentV2, /*dim=*/1); + + EXPECT_EQ(toLinearLayout({16}, sliceV2), + LinearLayout( + { + {S("register"), {{8}}}, + {S("lane"), {{0}, {0}, {1}, {2}, {4}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0")})); + + // Slice layout with a DotOperand (MMAv3) as the parent. + auto parentV3 = + dot(mma(3, 0, {16, 16, 8}, {4, 1}), /*opIdx=*/0, /*kWidth=*/2); + auto sliceV3 = slice(parentV3, /*dim=*/0); + + EXPECT_EQ(toLinearLayout({16}, sliceV3), + LinearLayout( + { + {S("register"), {{1}, {8}}}, + {S("lane"), {{2}, {4}, {0}, {0}, {0}}}, + {S("warp"), {{0}, {0}}}, + {S("block"), {}}, + }, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) { + auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + + EXPECT_EQ(toLinearLayout({32, 32}, mfmaNT), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 32}, mfmaNT), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaNT), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto mfmaT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + + EXPECT_EQ(toLinearLayout({32, 32}, mfmaT), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 32}, mfmaT), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaT), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 32}, {0, 64}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, MFMA16_2x4Warps) { + auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaNT), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, MFMA32_2x4x1Warps) { + auto mfmaNT = mfma(/*warps=*/{2, 4, 1}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + + EXPECT_EQ(toLinearLayout({1, 128, 128}, mfmaNT), + LinearLayout({{S("register"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 8, 0}, + {0, 16, 0}, + {0, 0, 32}, + {0, 0, 64}}}, + {S("lane"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 8}, + {0, 0, 16}, + {0, 4, 0}}}, + {S("warp"), {{0, 32, 0}, {0, 64, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ(toLinearLayout({2, 32, 32}, mfmaNT), + LinearLayout( + {{S("register"), {{0, 1, 0}, {0, 2, 0}, {0, 8, 0}, {0, 16, 0}}}, + {S("lane"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 8}, + {0, 0, 16}, + {0, 4, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ(toLinearLayout({2, 64, 32}, mfmaNT), + LinearLayout( + {{S("register"), {{0, 1, 0}, {0, 2, 0}, {0, 8, 0}, {0, 16, 0}}}, + {S("lane"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 8}, + {0, 0, 16}, + {0, 4, 0}}}, + {S("warp"), {{0, 32, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + + auto mfmaT = mfma(/*warps=*/{2, 4, 1}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + + EXPECT_EQ(toLinearLayout({1, 128, 128}, mfmaT), + LinearLayout({{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 8}, + {0, 0, 16}, + {0, 0, 32}, + {0, 0, 64}}}, + {S("lane"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 16, 0}, + {0, 0, 4}}}, + {S("warp"), {{0, 32, 0}, {0, 64, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ(toLinearLayout({2, 32, 32}, mfmaT), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 8}, {0, 0, 16}}}, + {S("lane"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 16, 0}, + {0, 0, 4}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ(toLinearLayout({2, 64, 32}, mfmaT), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 8}, {0, 0, 16}}}, + {S("lane"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 16, 0}, + {0, 0, 4}}}, + {S("warp"), {{0, 32, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_lhs_kwidth8) { + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {0, 128}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 16}, {0, 32}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {0, 128}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); + EXPECT_EQ( + toLinearLayout({128, 128}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ( + toLinearLayout({128, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 128}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ( + toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}, {128, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 128}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{1, 0}, + {2, 0}, + {4, 0}, + {16, 0}, + {32, 0}, + {64, 0}, + {128, 0}, + {0, 128}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 64}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_lhs_kwidth8) { + auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {16, 0}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({1, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + { + {0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + }}, + {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ( + toLinearLayout({128, 1}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{0, 0}, {0, 0}, {0, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {0, 128}, + {16, 0}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {0, 128}, + {16, 0}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8_1 = mfma(/*warps=*/{1, 1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8_1 = mfmaDotOp(parentMfma_1_8_1, /*opIdx=*/0, /*kWidth=*/8); + + EXPECT_EQ(toLinearLayout({1, 256, 256}, mfmaDot_1_8_1), + LinearLayout({{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 32}, + {0, 0, 64}, + {0, 0, 128}, + {0, 16, 0}, + {0, 32, 0}, + {0, 64, 0}, + {0, 128, 0}}}, + {S("lane"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 0, 8}, + {0, 0, 16}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { + auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); + EXPECT_EQ( + toLinearLayout({128, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 64}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}, + {S("warp"), {{0, 16}, {0, 32}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({1, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{0, 0}, {0, 0}, {0, 0}, {0, 64}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {0, 0}}}, + {S("warp"), {{0, 16}, {0, 32}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 1}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {8, 0}, {16, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{1, 0}, + {2, 0}, + {4, 0}, + {32, 0}, + {64, 0}, + {128, 0}, + {0, 64}, + {0, 128}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}, + {S("warp"), {{0, 16}, {0, 32}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); + EXPECT_EQ( + toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {128, 0}, {0, 128}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 64}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8_1 = mfma(/*warps=*/{1, 1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8_1 = mfmaDotOp(parentMfma_1_8_1, /*opIdx=*/1, /*kWidth=*/8); + + EXPECT_EQ(toLinearLayout({1, 256, 256}, mfmaDot_1_8_1), + LinearLayout({{S("register"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 32, 0}, + {0, 64, 0}, + {0, 128, 0}, + {0, 0, 128}}}, + {S("lane"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 8}, + {0, 8, 0}, + {0, 16, 0}}}, + {S("warp"), {{0, 0, 16}, {0, 0, 32}, {0, 0, 64}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, mfma32_dot_op_lhs_kwidth4) { + auto parentMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDotOp0_32 = mfmaDotOp(parentMfma32, /*opIdx=*/0, /*kWidth=*/4); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDotOp0_32), + LinearLayout( + {{S("register"), + {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {0, 32}, {0, 64}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 32}, mfmaDotOp0_32), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp0_32), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + auto tmfmaDotOp0_32 = mfmaDotOp(parentTMfma32, /*opIdx=*/0, /*kWidth=*/4); + + EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp0_32), + toLinearLayout({128, 128}, mfmaDotOp0_32)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp0_32), + toLinearLayout({64, 32}, mfmaDotOp0_32)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp0_32), + toLinearLayout({16, 16}, mfmaDotOp0_32)); +} + +TEST_F(LinearLayoutConversionsTest, mfma16_dot_op_lhs_kwidth4) { + auto parentMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDotOp0_16 = mfmaDotOp(parentMfma16, /*opIdx=*/0, /*kWidth=*/4); + EXPECT_EQ( + toLinearLayout({128, 128}, mfmaDotOp0_16), + LinearLayout( + {{S("register"), + {{0, 1}, {0, 2}, {0, 16}, {0, 32}, {0, 64}, {32, 0}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 32}, mfmaDotOp0_16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 16}, {32, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp0_16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/true); + auto tmfmaDotOp0_16 = mfmaDotOp(parentTMfma16, /*opIdx=*/0, /*kWidth=*/4); + + EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp0_16), + toLinearLayout({128, 128}, mfmaDotOp0_16)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp0_16), + toLinearLayout({64, 32}, mfmaDotOp0_16)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp0_16), + toLinearLayout({16, 16}, mfmaDotOp0_16)); +} + +TEST_F(LinearLayoutConversionsTest, mfma32_dot_op_rhs_kwidth4) { + auto parentMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDotOp1_32 = mfmaDotOp(parentMfma32, /*opIdx=*/1, /*kWidth=*/4); + EXPECT_EQ( + toLinearLayout({128, 128}, mfmaDotOp1_32), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDotOp1_32), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 32}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp1_32), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {4, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + auto tmfmaDotOp1_32 = mfmaDotOp(parentTMfma32, /*opIdx=*/1, /*kWidth=*/4); + + EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp1_32), + toLinearLayout({128, 128}, mfmaDotOp1_32)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp1_32), + toLinearLayout({64, 32}, mfmaDotOp1_32)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp1_32), + toLinearLayout({16, 16}, mfmaDotOp1_32)); +} + +TEST_F(LinearLayoutConversionsTest, mfma16_dot_op_rhs_kwidth4) { + auto parentMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDotOp1_16 = mfmaDotOp(parentMfma16, /*opIdx=*/1, /*kWidth=*/4); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDotOp1_16), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {16, 0}, {32, 0}, {64, 0}, {0, 64}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDotOp1_16), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp1_16), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/true); + auto tmfmaDotOp1_16 = mfmaDotOp(parentTMfma16, /*opIdx=*/1, /*kWidth=*/4); + + EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp1_16), + toLinearLayout({128, 128}, mfmaDotOp1_16)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp1_16), + toLinearLayout({64, 32}, mfmaDotOp1_16)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp1_16), + toLinearLayout({16, 16}, mfmaDotOp1_16)); +} + +TEST_F(LinearLayoutConversionsTest, mfma16_dot_op_lhs_trans) { + auto parentMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDotOp0_kwidth_8 = mfmaDotOp(parentMfma16, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {128, 128}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {0, 4}, {0, 32}, {0, 64}, {32, 0}, {64, 0}}}, + {S("lane"), {{4, 0}, {8, 0}, {0, 1}, {0, 2}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {32, 64}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {0, 4}, {0, 32}}}, + {S("lane"), {{4, 0}, {8, 0}, {0, 1}, {0, 2}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto mfmaDotOp0_kwidth_4 = mfmaDotOp(parentMfma16, /*opIdx=*/0, /*kWidth=*/4); + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_4, {16, 16}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}}}, + {S("lane"), {{4, 0}, {8, 0}, {0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto mfmaDotOp0_kwidth_16 = + mfmaDotOp(parentMfma16, /*opIdx=*/0, /*kWidth=*/16); + EXPECT_EQ( + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_16, {128, 128}, + /*elemBitWidth=*/8), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {4, 0}, {0, 8}, {0, 64}, {32, 0}, {64, 0}}}, + {S("lane"), {{8, 0}, {0, 1}, {0, 2}, {0, 4}, {0, 16}, {0, 32}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {16, 32}, + /*elemBitWidth=*/8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{8, 0}, {0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand for LDS transpose load on transposed mfma layout has same + // layout as ordinary + auto parentTMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/true); + auto tmfmaDotOp0_kwidth_16 = + mfmaDotOp(parentTMfma16, /*opIdx=*/0, /*kWidth=*/16); + auto tmfmaDotOp0_kwidth_8 = + mfmaDotOp(parentTMfma16, /*opIdx=*/0, /*kWidth=*/8); + auto tmfmaDotOp0_kwidth_4 = + mfmaDotOp(parentTMfma16, /*opIdx=*/0, /*kWidth=*/4); + + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp0_kwidth_8, {128, 128}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {128, 128}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp0_kwidth_16, {128, 128}, + /*elemBitWidth=*/8), + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_16, {128, 128}, + /*elemBitWidth=*/8)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp0_kwidth_8, {64, 32}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {64, 32}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp0_kwidth_4, {16, 16}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_4, {16, 16}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp0_kwidth_8, {16, 32}, + /*elemBitWidth=*/8), + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {16, 32}, + /*elemBitWidth=*/8)); +} + +TEST_F(LinearLayoutConversionsTest, mfma16_dot_op_rhs_trans) { + auto parentMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDotOp1_kwidth_8 = mfmaDotOp(parentMfma16, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDotOp1_kwidth_16 = + mfmaDotOp(parentMfma16, /*opIdx=*/1, /*kWidth=*/16); + EXPECT_EQ( + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_16, {128, 128}, + /*elemBitWidth=*/8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {64, 0}, {0, 64}}}, + {S("lane"), {{0, 8}, {1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {128, 128}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {4, 0}, {32, 0}, {64, 0}, {0, 64}}}, + {S("lane"), {{0, 4}, {0, 8}, {1, 0}, {2, 0}, {8, 0}, {16, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {32, 64}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {4, 0}}}, + {S("lane"), {{0, 4}, {0, 8}, {1, 0}, {2, 0}, {8, 0}, {16, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto mfmaDotOp1_kwidth_4 = mfmaDotOp(parentMfma16, /*opIdx=*/1, /*kWidth=*/4); + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_4, {16, 16}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}}}, + {S("lane"), {{0, 4}, {0, 8}, {1, 0}, {2, 0}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {32, 16}, + /*elemBitWidth=*/8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{0, 8}, {1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand for LDS transpose load based on transposed mfma layout has + // same layout as ordinary. + auto parentTMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/true); + + auto tmfmaDotOp1_kwidth_16 = + mfmaDotOp(parentTMfma16, /*opIdx=*/1, /*kWidth=*/8); + auto tmfmaDotOp1_kwidth_8 = + mfmaDotOp(parentTMfma16, /*opIdx=*/1, /*kWidth=*/8); + auto tmfmaDotOp1_kwidth_4 = + mfmaDotOp(parentTMfma16, /*opIdx=*/1, /*kWidth=*/4); + + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp1_kwidth_8, {128, 128}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {128, 128}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp1_kwidth_16, {128, 128}, + /*elemBitWidth=*/8), + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {128, 128}, + /*elemBitWidth=*/8)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp1_kwidth_8, {64, 32}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {64, 32}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp1_kwidth_4, {16, 16}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_4, {16, 16}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp1_kwidth_8, {32, 16}, + /*elemBitWidth=*/8), + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {32, 16}, + /*elemBitWidth=*/8)); +} + +TEST_F(LinearLayoutConversionsTest, mfma32_dot_op_lhs_trans) { + auto parentMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDotOp0_kwidth_8 = mfmaDotOp(parentMfma32, /*opIdx=*/0, /*kWidth=*/8); + auto mfmaDotOp0_kwidth_16 = + mfmaDotOp(parentMfma32, /*opIdx=*/0, /*kWidth=*/16); + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {128, 128}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {0, 4}, {0, 16}, {0, 32}, {0, 64}, {64, 0}}}, + {S("lane"), {{4, 0}, {8, 0}, {0, 1}, {0, 2}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_16, {128, 128}, + /*elemBitWidth=*/8), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {4, 0}, {0, 8}, {0, 32}, {0, 64}, {64, 0}}}, + {S("lane"), {{8, 0}, {0, 1}, {0, 2}, {0, 4}, {16, 0}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {32, 64}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {0, 4}, {0, 16}, {0, 32}}}, + {S("lane"), {{4, 0}, {8, 0}, {0, 1}, {0, 2}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto mfmaDotOp0_kwidth_4 = mfmaDotOp(parentMfma32, /*opIdx=*/0, + /*kWidth=*/4); + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_4, {32, 8}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}}}, + {S("lane"), {{4, 0}, {8, 0}, {0, 1}, {0, 2}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {32, 16}, + /*elemBitWidth=*/8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{8, 0}, {0, 1}, {0, 2}, {0, 4}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand for LDS transpose load based on transposed mfma layout has + // same layout as ordinary. + auto parentTMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + auto tmfmaDotOp0_kwidth_16 = + mfmaDotOp(parentTMfma32, /*opIdx=*/0, /*kWidth=*/16); + auto tmfmaDotOp0_kwidth_8 = + mfmaDotOp(parentTMfma32, /*opIdx=*/0, /*kWidth=*/8); + auto tmfmaDotOp0_kwidth_4 = + mfmaDotOp(parentTMfma32, /*opIdx=*/0, /*kWidth=*/4); + + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp0_kwidth_8, {128, 128}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {128, 128}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp0_kwidth_16, {128, 128}, + /*elemBitWidth=*/8), + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_16, {128, 128}, + /*elemBitWidth=*/8)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp0_kwidth_8, {64, 32}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {64, 32}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp0_kwidth_4, {32, 8}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_4, {32, 8}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp0_kwidth_8, {32, 16}, + /*elemBitWidth=*/8), + chooseDsReadB64TrLayout(mfmaDotOp0_kwidth_8, {32, 16}, + /*elemBitWidth=*/8)); +} + +TEST_F(LinearLayoutConversionsTest, mfma32_dot_op_rhs_trans) { + auto parentMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDotOp1_kwidth_8 = mfmaDotOp(parentMfma16, /*opIdx=*/1, /*kWidth=*/8); + auto mfmaDotOp1_kwidth_16 = + mfmaDotOp(parentMfma16, /*opIdx=*/1, /*kWidth=*/16); + + EXPECT_EQ( + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {128, 128}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {4, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 4}, {0, 8}, {1, 0}, {2, 0}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_16, {128, 128}, + /*elemBitWidth=*/8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 8}, {1, 0}, {2, 0}, {4, 0}, {0, 16}, {16, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {32, 64}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {4, 0}, {16, 0}}}, + {S("lane"), {{0, 4}, {0, 8}, {1, 0}, {2, 0}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto mfmaDotOp1_kwidth_4 = mfmaDotOp(parentMfma16, /*opIdx=*/1, /*kWidth=*/4); + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_4, {8, 32}, + /*elemBitWidth=*/16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}}}, + {S("lane"), {{0, 4}, {0, 8}, {1, 0}, {2, 0}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {16, 32}, + /*elemBitWidth=*/8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{0, 8}, {1, 0}, {2, 0}, {4, 0}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand for LDS transpose load based on transposed mfma layout has + // same layout as ordinary. + auto parentTMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + auto tmfmaDotOp1_kwidth_16 = + mfmaDotOp(parentTMfma16, /*opIdx=*/1, /*kWidth=*/8); + auto tmfmaDotOp1_kwidth_8 = + mfmaDotOp(parentTMfma16, /*opIdx=*/1, /*kWidth=*/8); + auto tmfmaDotOp1_kwidth_4 = + mfmaDotOp(parentTMfma16, /*opIdx=*/1, /*kWidth=*/4); + + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp1_kwidth_8, {128, 128}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {128, 128}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp1_kwidth_16, {128, 128}, + /*elemBitWidth=*/8), + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_16, {128, 128}, + /*elemBitWidth=*/8)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp1_kwidth_8, {64, 32}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {64, 32}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp1_kwidth_4, {8, 32}, + /*elemBitWidth=*/16), + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_4, {8, 32}, + /*elemBitWidth=*/16)); + EXPECT_EQ(chooseDsReadB64TrLayout(tmfmaDotOp1_kwidth_8, {16, 32}, + /*elemBitWidth=*/8), + chooseDsReadB64TrLayout(mfmaDotOp1_kwidth_8, {16, 32}, + /*elemBitWidth=*/8)); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v1_2x4Warps) { + auto legacy = wmma(/*warps=*/{2, 4}, /*version=*/1, /*transposed=*/false); + + EXPECT_EQ(toLinearLayout({16, 16}, legacy), + LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + // For 32x16, we need 2x1 WMMA instances. We have 2x4 warps, so we are + // broadcasted along the warp N dimension, distributed along the warp M + // dimension. + EXPECT_EQ(toLinearLayout({32, 16}, legacy), + LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + // For 16x32, we need 1x2 WMMA instances. We have 2x4 warps, so along the warp + // N dimension, warp 0/2 gets the first distributed instance, warp 1/3 gets + // the second distributed instance. Along the warp M dimension, all are + // broadcasted. + EXPECT_EQ(toLinearLayout({16, 32}, legacy), + LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 16}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + // For 128x128, we need 8x8 WMMA instances. Given that we have 2x4 warps, each + // warp handles 4x2 instances. So for both the warp M and N dimension, we + // distribute. The register dimension will handle (8 x 4x2 =) 64 values--those + // additional base vectors after the intrinsic shape are next power of two + // values following the warp dimension, given that we are tiling cyclically + // among warps. + EXPECT_EQ(toLinearLayout({128, 128}, legacy), + LinearLayout({{S("register"), + {{2, 0}, {4, 0}, {8, 0}, {0, 64}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v1_2x4x1Warps) { + auto legacy = wmma(/*warps=*/{2, 4, 1}, /*version=*/1, /*transposed=*/false); + + EXPECT_EQ( + toLinearLayout({1, 16, 16}, legacy), + LinearLayout( + {{S("register"), {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 16, 16}, legacy), + LinearLayout( + {{S("register"), {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({8, 16, 16}, legacy), + LinearLayout( + {{S("register"), + {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {2, 0, 0}, {4, 0, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v1_2x4Warps_lhs) { + auto dot = wmma(/*warps=*/{2, 4}, /*version=*/1, /*transposed=*/false); + auto wmmaOperand = wmmaDotOp(dot, 0, 16); + + EXPECT_EQ(toLinearLayout({16, 16}, wmmaOperand), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 16}, wmmaOperand), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 64}, wmmaOperand), + LinearLayout({{S("register"), + {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 128}, wmmaOperand), + LinearLayout({{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {0, 16}, + {0, 32}, + {0, 64}, + {32, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v1_2x4Warps_rhs) { + auto dot = wmma(/*warps=*/{2, 4}, /*version=*/1, /*transposed=*/false); + auto wmmaOperand = wmmaDotOp(dot, 1, 16); + + EXPECT_EQ(toLinearLayout({16, 16}, wmmaOperand), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({32, 16}, wmmaOperand), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({32, 64}, wmmaOperand), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 128}, wmmaOperand), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {32, 0}, {0, 64}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v1_2x4x1Warps_lhs) { + auto dot = wmma(/*warps=*/{2, 4, 1}, /*version=*/1, /*transposed=*/false); + auto wmmaOperand = wmmaDotOp(dot, 0, 16); + + EXPECT_EQ( + toLinearLayout({1, 16, 16}, wmmaOperand), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}}}, + {S("lane"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 0, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 32, 16}, wmmaOperand), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}}}, + {S("lane"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 0, 0}}}, + {S("warp"), {{0, 16, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 64, 16}, wmmaOperand), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}}}, + {S("lane"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 0, 0}}}, + {S("warp"), {{0, 16, 0}, {0, 32, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({4, 128, 32}, wmmaOperand), + LinearLayout( + {{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 8}, + {0, 0, 16}, + {0, 64, 0}, + {2, 0, 0}}}, + {S("lane"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 0, 0}}}, + {S("warp"), {{0, 16, 0}, {0, 32, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v1_2x4x1Warps_rhs) { + auto dot = wmma(/*warps=*/{2, 4, 1}, /*version=*/1, /*transposed=*/false); + auto wmmaOperand = wmmaDotOp(dot, 1, 16); + + EXPECT_EQ( + toLinearLayout({1, 16, 16}, wmmaOperand), + LinearLayout( + {{S("register"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 0, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 32, 16}, wmmaOperand), + LinearLayout( + {{S("register"), + {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 16, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 0, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 64, 16}, wmmaOperand), + LinearLayout( + {{S("register"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 16, 0}, + {0, 32, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 0, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({4, 128, 32}, wmmaOperand), + LinearLayout( + {{S("register"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 16, 0}, + {0, 32, 0}, + {0, 64, 0}, + {0, 0, 16}, + {2, 0, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 0, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v2_2x4Warps) { + auto layout = wmma(/*warps=*/{2, 4}, /*version=*/2, /*transposed=*/false); + + EXPECT_EQ(toLinearLayout({16, 16}, layout), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 16}, layout), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 32}, layout), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}}}, + {S("warp"), {{0, 16}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({64, 128}, layout), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {0, 64}, {32, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v2_2x2x2Warps) { + auto layout = wmma(/*warps=*/{2, 2, 2}, /*version=*/2, /*transposed=*/false); + + EXPECT_EQ( + toLinearLayout({1, 16, 16}, layout), + LinearLayout( + {{S("register"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 8, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 16, 16}, layout), + LinearLayout( + {{S("register"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 8, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({4, 64, 64}, layout), + LinearLayout( + {{S("register"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 0, 32}, + {0, 32, 0}, + {2, 0, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 8, 0}}}, + {S("warp"), {{0, 0, 16}, {0, 16, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, TWMMA_v2_2x4Warps) { + auto layout = wmma(/*warps=*/{2, 4}, /*version=*/2, /*transposed=*/true); + + EXPECT_EQ(toLinearLayout({16, 16}, layout), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 16}, layout), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 32}, layout), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}}}, + {S("warp"), {{0, 16}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({64, 128}, layout), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 64}, {32, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}}}, + {S("warp"), {{0, 16}, {0, 32}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, TWMMA_v2_2x2x2Warps) { + auto layout = wmma(/*warps=*/{2, 2, 2}, /*version=*/2, /*transposed=*/true); + + EXPECT_EQ( + toLinearLayout({1, 16, 16}, layout), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}}}, + {S("lane"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 0, 8}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 16, 16}, layout), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}}}, + {S("lane"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 0, 8}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({4, 64, 64}, layout), + LinearLayout( + {{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 32}, + {0, 32, 0}, + {2, 0, 0}}}, + {S("lane"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 0, 8}}}, + {S("warp"), {{0, 0, 16}, {0, 16, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v2_2x4Warps_lhs) { + auto dot = wmma(/*warps=*/{2, 4}, /*version=*/2, /*transposed=*/false); + + auto wmmaOperandK8 = wmmaDotOp(dot, 0, 8); + EXPECT_EQ(toLinearLayout({16, 16}, wmmaOperandK8), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 16}, wmmaOperandK8), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({32, 64}, wmmaOperandK8), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 16}, {0, 32}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 128}, wmmaOperandK8), + LinearLayout( + {{S("register"), + {{0, 1}, {0, 2}, {0, 4}, {0, 16}, {0, 32}, {0, 64}, {32, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto wmmaOperandK16 = wmmaDotOp(dot, 0, 16); + EXPECT_EQ( + toLinearLayout({16, 32}, wmmaOperandK16), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({32, 32}, wmmaOperandK16), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({32, 128}, wmmaOperandK16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 32}, {0, 64}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 128}, wmmaOperandK16), + LinearLayout( + {{S("register"), + {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 32}, {0, 64}, {32, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v2_2x4Warps_rhs) { + auto dot = wmma(/*warps=*/{2, 4}, /*version=*/2, /*transposed=*/false); + + auto wmmaOperandK8 = wmmaDotOp(dot, 1, 8); + EXPECT_EQ(toLinearLayout({16, 16}, wmmaOperandK8), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 16}, wmmaOperandK8), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 64}, wmmaOperandK8), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 128}, wmmaOperandK8), + LinearLayout({{S("register"), + {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {0, 64}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto wmmaOperandK16 = wmmaDotOp(dot, 1, 16); + EXPECT_EQ( + toLinearLayout({32, 16}, wmmaOperandK16), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {16, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({32, 32}, wmmaOperandK16), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {16, 0}}}, + {S("warp"), {{0, 16}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({64, 64}, wmmaOperandK16), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {32, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {16, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({128, 128}, wmmaOperandK16), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {32, 0}, {64, 0}, {0, 64}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {16, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v2_2x4x1Warps_lhs) { + auto dot = wmma(/*warps=*/{2, 4, 1}, /*version=*/2, /*transposed=*/false); + auto wmmaOperandK8 = wmmaDotOp(dot, 0, 8); + + EXPECT_EQ( + toLinearLayout({1, 16, 16}, wmmaOperandK8), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}}}, + {S("lane"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 0, 8}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 32, 16}, wmmaOperandK8), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}}}, + {S("lane"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 0, 8}}}, + {S("warp"), {{0, 16, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 64, 16}, wmmaOperandK8), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}}}, + {S("lane"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 0, 8}}}, + {S("warp"), {{0, 16, 0}, {0, 32, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({4, 128, 32}, wmmaOperandK8), + LinearLayout( + {{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 16}, + {0, 64, 0}, + {2, 0, 0}}}, + {S("lane"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {0, 0, 8}}}, + {S("warp"), {{0, 16, 0}, {0, 32, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_v2_2x4x1Warps_rhs) { + auto dot = wmma(/*warps=*/{2, 4, 1}, /*version=*/2, /*transposed=*/false); + auto wmmaOperandK8 = wmmaDotOp(dot, 1, 8); + + EXPECT_EQ( + toLinearLayout({1, 16, 16}, wmmaOperandK8), + LinearLayout( + {{S("register"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 8, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 32, 16}, wmmaOperandK8), + LinearLayout( + {{S("register"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 16, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 8, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 64, 16}, wmmaOperandK8), + LinearLayout( + {{S("register"), + {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}, {0, 16, 0}, {0, 32, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 8, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({4, 128, 32}, wmmaOperandK8), + LinearLayout( + {{S("register"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 16, 0}, + {0, 32, 0}, + {0, 64, 0}, + {0, 0, 16}, + {2, 0, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 8, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, SliceOfBlocked) { + auto parent = blocked({2, 4}, {4, 2}, {2, 2}, {2, 2}, {2, 2}, {1, 0}, {1, 0}); + EXPECT_EQ(toLinearLayout({128}, slice(parent, 0)), + LinearLayout({{S("register"), {{1}, {2}, {16}, {32}}}, + {S("lane"), {{4}, {0}, {0}}}, + {S("warp"), {{8}, {0}}}, + {S("block"), {{64}, {0}}}}, + {S("dim0")})); + EXPECT_EQ(toLinearLayout({128}, slice(parent, 1)), + LinearLayout({{S("register"), {{1}, {16}, {32}}}, + {S("lane"), {{0}, {2}, {4}}}, + {S("warp"), {{0}, {8}}}, + {S("block"), {{0}, {64}}}}, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, SliceWithShape1) { + auto parent = blocked({1, 4}, {8, 4}, {2, 2}, {1, 1}, {1, 1}, {0, 1}, {1, 0}); + EXPECT_EQ(toLinearLayout({1}, slice(parent, 0)), + LinearLayout({{S("register"), {}}, + {S("lane"), {{0}, {0}, {0}, {0}, {0}}}, + {S("warp"), {{0}, {0}}}, + {S("block"), {}}}, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, Slice4D) { + auto parent = blocked({1, 1, 1, 4}, {2, 1, 1, 16}, {1, 2, 4, 1}, {1, 1, 1, 1}, + {1, 1, 1, 1}, {3, 0, 1, 2}, {3, 2, 1, 0}); + EXPECT_EQ(toLinearLayout({2, 1, 1}, slice(parent, 3)), + LinearLayout( + { + {S("register"), {}}, + {S("lane"), + {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, SliceOfMmaV2) { + auto parent = mma(2, 0, {16, 8}, {2, 2}, {1, 1}, {1, 1}, {0, 1}); + EXPECT_EQ(toLinearLayout({16}, slice(parent, 0)), + LinearLayout({{S("register"), {{1}}}, + {S("lane"), {{2}, {4}, {0}, {0}, {0}}}, + {S("warp"), {{8}, {0}}}, + {S("block"), {}}}, + {S("dim0")})); + EXPECT_EQ(toLinearLayout({128}, slice(parent, 0)), + LinearLayout({{S("register"), {{1}, {16}, {32}, {64}}}, + {S("lane"), {{2}, {4}, {0}, {0}, {0}}}, + {S("warp"), {{8}, {0}}}, + {S("block"), {}}}, + {S("dim0")})); + EXPECT_EQ(toLinearLayout({8}, slice(parent, 1)), + LinearLayout({{S("register"), {}}, + {S("lane"), {{0}, {0}, {1}, {2}, {4}}}, + {S("warp"), {{0}, {0}}}, + {S("block"), {}}}, + {S("dim0")})); + EXPECT_EQ(toLinearLayout({128}, slice(parent, 1)), + LinearLayout({{S("register"), {{8}, {32}, {64}}}, + {S("lane"), {{0}, {0}, {1}, {2}, {4}}}, + {S("warp"), {{0}, {16}}}, + {S("block"), {}}}, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSimple1D) { + EXPECT_EQ(toLinearLayout({1024}, shared(1, 1, 1, {1}, {1}, {0}, {0})), + LinearLayout::identity1D(1024, S("offset"), S("dim0")) * + LinearLayout::identity1D(1, S("block"), S("dim0"))); +} + +TEST_F(LinearLayoutConversionsTest, SharedSimple2D) { + EXPECT_EQ(toLinearLayout({128, 128}, + shared(1, 1, 1, {1, 1}, {1, 1}, {1, 0}, {1, 0})), + (LinearLayout::identity1D(128, S("offset"), S("dim1")) * + LinearLayout::identity1D(128, S("offset"), S("dim0")) * + LinearLayout::identity1D(1, S("block"), S("dim0"))) + .transposeOuts({S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSimple2D_Order01) { + EXPECT_EQ(toLinearLayout({128, 128}, + shared(1, 1, 1, {1, 1}, {1, 1}, {0, 1}, {1, 0})), + LinearLayout::identity1D(128, S("offset"), S("dim0")) * + LinearLayout::identity1D(128, S("offset"), S("dim1")) * + LinearLayout::identity1D(1, S("block"), S("dim0"))); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled2D_MaxPhaseOnly) { + EXPECT_EQ( + toLinearLayout({32, 32}, shared(1, 1, 4, {1, 1}, {1, 1}, {1, 0}, {1, 0})), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {0, 16}, + {1, 1}, + {2, 2}, + {4, 0}, + {8, 0}, + {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled2D_PerPhaseMaxPhase) { + EXPECT_EQ( + toLinearLayout({32, 32}, shared(1, 2, 4, {1, 1}, {1, 1}, {1, 0}, {1, 0})), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {0, 16}, + {1, 0}, + {2, 1}, + {4, 2}, + {8, 0}, + {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled2D_Vec) { + EXPECT_EQ( + toLinearLayout({4, 8}, shared(2, 1, 4, {1, 1}, {1, 1}, {1, 0}, {1, 0})), + LinearLayout({{S("offset"), {{0, 1}, {0, 2}, {0, 4}, {1, 2}, {2, 4}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled2D_PerPhaseMaxPhaseVec) { + EXPECT_EQ( + toLinearLayout({32, 32}, shared(2, 2, 4, {1, 1}, {1, 1}, {1, 0}, {1, 0})), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {0, 16}, + {1, 0}, + {2, 2}, + {4, 4}, + {8, 0}, + {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled4D) { + EXPECT_EQ( + toLinearLayout({2, 4, 32, 32}, shared(2, 2, 4, {1, 1, 1, 1}, {1, 1, 1, 1}, + {3, 2, 1, 0}, {3, 2, 1, 0})), + LinearLayout({{S("offset"), + {{0, 0, 0, 1}, + {0, 0, 0, 2}, + {0, 0, 0, 4}, + {0, 0, 0, 8}, + {0, 0, 0, 16}, + {0, 0, 1, 0}, + {0, 0, 2, 2}, + {0, 0, 4, 4}, + {0, 0, 8, 0}, + {0, 0, 16, 0}, + {0, 1, 0, 0}, + {0, 2, 0, 0}, + {1, 0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2"), S("dim3")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled2D_Order01) { + EXPECT_EQ( + toLinearLayout({4, 8}, shared(1, 1, 4, {1, 1}, {1, 1}, {0, 1}, {0, 1})), + LinearLayout({{S("offset"), {{1, 0}, {2, 0}, {1, 1}, {2, 2}, {0, 4}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, LeadingOffset_8x16_4_2) { + EXPECT_EQ( + toLinearLayout( + {8, 16}, nvmmaShared(32, false, 16, {1, 1}, {1, 1}, {1, 0}, {1, 0})), + LinearLayout({{S("offset"), + {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}, {4, 8}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, LeadingOffset_128x16_4_2) { + EXPECT_EQ(toLinearLayout({128, 16}, nvmmaShared(32, false, 16, {1, 1}, {1, 1}, + {1, 0}, {1, 0})), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {1, 0}, + {2, 0}, + {4, 8}, + {8, 0}, + {16, 0}, + {32, 0}, + {64, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, LeadingOffset_8x32_2_4) { + EXPECT_EQ( + toLinearLayout( + {8, 32}, nvmmaShared(64, false, 16, {1, 1}, {1, 1}, {1, 0}, {1, 0})), + LinearLayout( + {{S("offset"), + {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}, {2, 8}, {4, 16}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, LeadingOffset_8x64_1_8) { + EXPECT_EQ(toLinearLayout({8, 64}, nvmmaShared(128, false, 16, {1, 1}, {1, 1}, + {1, 0}, {1, 0})), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {0, 16}, + {0, 32}, + {1, 8}, + {2, 16}, + {4, 32}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, LeadingOffset_8x64_1_8_32b) { + EXPECT_EQ(toLinearLayout({8, 64}, nvmmaShared(128, false, 32, {1, 1}, {1, 1}, + {1, 0}, {1, 0})), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {0, 16}, + {1, 4}, + {2, 8}, + {4, 16}, + {0, 32}}}, + {S("block"), {}}}, + {{S("dim0"), 8}, {S("dim1"), 64}}, + /*requireSurjective=*/false)); +} + +TEST_F(LinearLayoutConversionsTest, Shared1DSwizzle) { + EXPECT_EQ( + toLinearLayout({64, 1}, shared(2, 2, 4, {1, 1}, {1, 1}, {1, 0}, {1, 0})), + LinearLayout::identity1D(64, S("offset"), S("dim0")) * + LinearLayout::identity1D(1, S("offset"), S("dim1")) * + LinearLayout::identity1D(1, S("block"), S("dim0"))); +} + +TEST_F(LinearLayoutConversionsTest, ChooseShmemLayout) { + LinearLayout ll = LinearLayout({{S("register"), {{1}, {2}, {2}, {8}}}, + {S("lane"), {{8}, {4}, {1}}}, + {S("warp"), {{16}, {32}, {0}}}, + {S("block"), {}}}, + {S("dim0")}); + EXPECT_EQ(chooseShemLayoutForRegToRegConversion(&ctx, /*tensorShape=*/{64}, + /*repShape=*/{64}, + /*order=*/{0}), + LinearLayout({{S("offset"), {{1}, {2}, {4}, {8}, {16}, {32}}}, + {S("iteration"), {}}, + {S("block"), {}}}, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, ChooseShmemLayout_Empty) { + LinearLayout ll = LinearLayout({{S("register"), {{0}}}, + {S("lane"), {{0}}}, + {S("warp"), {{0}}}, + {S("block"), {}}}, + {S("dim0")}); + EXPECT_EQ( + chooseShemLayoutForRegToRegConversion(&ctx, /*tensorShape=*/{}, + /*repShape=*/{}, /*order=*/{}), + LinearLayout({{S("offset"), {}}, {S("iteration"), {}}, {S("block"), {}}}, + {})); +} + +TEST_F(LinearLayoutConversionsTest, ChooseShmemLayout_Multidim) { + LinearLayout src( + {{S("register"), {}}, + {S("lane"), + {{0, 0, 1, 0}, {0, 0, 2, 0}, {1, 0, 0, 0}, {2, 0, 0, 0}, {0, 0, 0, 1}}}, + {S("warp"), {{0, 0, 0, 2}, {0, 1, 0, 0}, {0, 2, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2"), S("dim3")}); + EXPECT_EQ( + chooseShemLayoutForRegToRegConversion(&ctx, /*tensorShape=*/{4, 4, 4, 4}, + /*repShape=*/{2, 2, 2, 2}, + /*order=*/{3, 2, 1, 0}), + LinearLayout({{S("offset"), + {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}}}, + {S("iteration"), + {{2, 0, 0, 0}, {0, 2, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 2}}}, + {S("block"), {}}}, + {S("dim3"), S("dim2"), S("dim1"), S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, MMAv5Fp4Padded) { + auto ll = toLinearLayout({32, 64}, nvmmaShared(128, false, 8, {1, 1}, {1, 1}, + {1, 0}, {1, 0}, true)); + EXPECT_EQ(ll, LinearLayout( + {{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 0}, // offset 8 maps to the same indices as offset 0 + {0, 8}, + {0, 16}, + {0, 32}, + {1, 8}, + {2, 16}, + {4, 32}, + {8, 0}, + {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +} // anonymous namespace +} // namespace mlir::triton::gpu + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/SwizzleTest.cpp b/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/SwizzleTest.cpp new file mode 100644 index 000000000..aa1f2290d --- /dev/null +++ b/third_party/enflame/include/triton/unittest/Dialect/TritonGPU/SwizzleTest.cpp @@ -0,0 +1,65 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/Signals.h" +#include + +using namespace mlir; +using mlir::triton::gpu::SwizzledSharedEncodingAttr; + +struct swizzleParams { + int vec; + int perPhase; + int maxPhase; +}; + +struct ParamT { + std::array shape; + int opIdx; + int typeWidth; + swizzleParams refSwizzle; +}; + +class SwizzleDotOperandTestFixture : public ::testing::TestWithParam { +protected: + ParamType param; +}; + +TEST_P(SwizzleDotOperandTestFixture, DotOperands) { + auto params = GetParam(); + // init context + MLIRContext ctx; + ctx.loadDialect(); + + auto CTALayout = + triton::gpu::CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {0, 1}); + + // create encoding + auto parent = triton::gpu::NvidiaMmaEncodingAttr::get( + &ctx, 2, 0, {1, 1}, CTALayout, {16, 64, 16}); + auto encoding = triton::gpu::DotOperandEncodingAttr::get( + &ctx, params.opIdx, parent, 32 / params.typeWidth); + + // create element type + Type eltType = IntegerType::get(&ctx, params.typeWidth); + auto layout = SwizzledSharedEncodingAttr::get(&ctx, encoding, params.shape, + {1, 0}, CTALayout, eltType); + + ASSERT_EQ(layout.getVec(), params.refSwizzle.vec); + ASSERT_EQ(layout.getPerPhase(), params.refSwizzle.perPhase); + ASSERT_EQ(layout.getMaxPhase(), params.refSwizzle.maxPhase); +} + +INSTANTIATE_TEST_SUITE_P(TestDotOperands, SwizzleDotOperandTestFixture, + ::testing::Values(ParamT{{128, 64}, 0, 16, {8, 1, 8}}, + ParamT{{64, 256}, 1, 16, {8, 1, 8}}, + ParamT{{128, 32}, 0, 16, {8, 2, 4}}, + ParamT{{32, 128}, 1, 16, {8, 1, 8}}, + ParamT{{32, 32}, 0, 16, {8, 2, 4}}, + ParamT{{32, 32}, 1, 16, {8, 2, 4}}, + ParamT{{16, 16}, 0, 16, {8, 4, 2}}, + ParamT{{16, 16}, 1, 16, {8, 4, 2}})); + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/enflame/include/triton/unittest/Tools/CMakeLists.txt b/third_party/enflame/include/triton/unittest/Tools/CMakeLists.txt new file mode 100644 index 000000000..3615a88f1 --- /dev/null +++ b/third_party/enflame/include/triton/unittest/Tools/CMakeLists.txt @@ -0,0 +1,5 @@ +add_triton_ut( + NAME LinearLayout + SRCS LayoutUtilsTest.cpp LinearLayoutTest.cpp + LIBS TritonTools +) diff --git a/third_party/enflame/include/triton/unittest/Tools/LayoutUtilsTest.cpp b/third_party/enflame/include/triton/unittest/Tools/LayoutUtilsTest.cpp new file mode 100644 index 000000000..b4f4e382c --- /dev/null +++ b/third_party/enflame/include/triton/unittest/Tools/LayoutUtilsTest.cpp @@ -0,0 +1,49 @@ +#include "triton/Tools/LayoutUtils.h" + +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Signals.h" +#include +#include + +namespace mlir::triton { +namespace { + +class LayoutUtilsTest : public ::testing::Test { +public: + StringAttr S(StringRef str) { return StringAttr::get(&ctx, str); } + +protected: + MLIRContext ctx; +}; + +TEST_F(LayoutUtilsTest, SquareSublayoutIsIdentity) { + EXPECT_TRUE(squareSublayoutIsIdentity( + LinearLayout::identity1D(4, S("in"), S("in")), {S("in")})); + EXPECT_TRUE(squareSublayoutIsIdentity( + LinearLayout::identity1D(4, S("in"), S("in")), {})); + + LinearLayout l1( + {{S("in1"), {{1, 1}, {2, 2}, {4, 4}}}, {S("in2"), {{2, 1}, {1, 2}}}}, + {{S("in1"), 8}, {S("in2"), 8}}, /*requireSurjective=*/false); + EXPECT_TRUE(squareSublayoutIsIdentity(l1, {S("in1")})); + EXPECT_FALSE(squareSublayoutIsIdentity(l1, {S("in2")})); + + LinearLayout l2 = LinearLayout::identity1D(4, S("in1"), S("in1")) * + LinearLayout::identity1D(8, S("in2"), S("in2")) * + LinearLayout({{S("in3"), {{1, 1, 1}}}}, + {{S("in1"), 2}, {S("in2"), 2}, {S("in3"), 2}}, + /*requireSurjective=*/false); + EXPECT_FALSE(squareSublayoutIsIdentity(l2, {S("in1")})); + EXPECT_FALSE(squareSublayoutIsIdentity(l2, {S("in2")})); + EXPECT_TRUE(squareSublayoutIsIdentity(l2, {S("in3")})); + EXPECT_FALSE(squareSublayoutIsIdentity(l2, {S("in1"), S("in2")})); + + LinearLayout l3 = LinearLayout::identity1D(4, S("in1"), S("in1")) * + LinearLayout::identity1D(8, S("in2"), S("in2")); + EXPECT_TRUE(squareSublayoutIsIdentity(l3, {S("in1")})); + EXPECT_TRUE(squareSublayoutIsIdentity(l3, {S("in2")})); + EXPECT_TRUE(squareSublayoutIsIdentity(l3, {S("in1"), S("in2")})); +} + +} // namespace +} // namespace mlir::triton diff --git a/third_party/enflame/include/triton/unittest/Tools/LinearLayoutTest.cpp b/third_party/enflame/include/triton/unittest/Tools/LinearLayoutTest.cpp new file mode 100644 index 000000000..4f89bc9c0 --- /dev/null +++ b/third_party/enflame/include/triton/unittest/Tools/LinearLayoutTest.cpp @@ -0,0 +1,931 @@ +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/LayoutUtils.h" + +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Signals.h" +#include +#include + +namespace mlir { +std::ostream &operator<<(std::ostream &os, StringAttr str) { + os << str.str(); + return os; +} +} // namespace mlir + +namespace mlir::triton { +namespace { + +using ::llvm::to_vector; +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::Pair; + +using BasesT = LinearLayout::BasesT; + +class LinearLayoutTest : public ::testing::Test { +public: + StringAttr S(StringRef str) { return StringAttr::get(&ctx, str); } + +protected: + MLIRContext ctx; +}; + +TEST_F(LinearLayoutTest, Empty) { + LinearLayout layout = LinearLayout::empty(); + EXPECT_THAT(layout.getBases(), IsEmpty()); + EXPECT_THAT(to_vector(layout.getInDimNames()), IsEmpty()); + EXPECT_THAT(to_vector(layout.getOutDimNames()), IsEmpty()); +} + +TEST_F(LinearLayoutTest, Identity1D) { + LinearLayout layout = + LinearLayout::identity1D(32, S("testIns"), S("testOuts")); + EXPECT_THAT(layout, LinearLayout({{S("testIns"), {{1}, {2}, {4}, {8}, {16}}}}, + {S("testOuts")})); + EXPECT_THAT(to_vector(layout.getInDimNames()), ElementsAre(S("testIns"))); + EXPECT_THAT(to_vector(layout.getOutDimNames()), ElementsAre(S("testOuts"))); + EXPECT_THAT(layout.getInDimSizeLog2(S("testIns")), 5); + EXPECT_THAT(layout.getOutDimSizeLog2(S("testOuts")), 5); +} + +TEST_F(LinearLayoutTest, Identity1DSize1) { + LinearLayout layout = + LinearLayout::identity1D(1, S("testIns"), S("testOuts")); + EXPECT_EQ(layout, LinearLayout({{S("testIns"), {}}}, {S("testOuts")})); + EXPECT_THAT(to_vector(layout.getInDimNames()), ElementsAre(S("testIns"))); + EXPECT_THAT(to_vector(layout.getOutDimNames()), ElementsAre(S("testOuts"))); + EXPECT_THAT(layout.getInDimSizeLog2(S("testIns")), 0); + EXPECT_THAT(layout.getOutDimSizeLog2(S("testOuts")), 0); +} + +TEST_F(LinearLayoutTest, Zeros1D) { + LinearLayout layout = LinearLayout::zeros1D(32, S("ins"), S("outs")); + EXPECT_EQ(layout, + LinearLayout({{S("ins"), {{0}, {0}, {0}, {0}, {0}}}}, {S("outs")})); +} + +TEST_F(LinearLayoutTest, MultiplyIdentity) { + LinearLayout prod = LinearLayout::identity1D(16, S("in"), S("out")) * + LinearLayout::identity1D(32, S("in"), S("out")); + EXPECT_EQ(prod, LinearLayout( + {{S("in"), + {{1}, {2}, {4}, {8}, {16}, {32}, {64}, {128}, {256}}}}, + {S("out")})); + EXPECT_THAT(to_vector(prod.getInDimNames()), ElementsAre(S("in"))); + EXPECT_THAT(to_vector(prod.getOutDimNames()), ElementsAre(S("out"))); +} + +TEST_F(LinearLayoutTest, MultiplyDisjoint) { + LinearLayout prod = LinearLayout::identity1D(32, S("in1"), S("out1")) * + LinearLayout::identity1D(16, S("in2"), S("out2")); + EXPECT_EQ(prod, LinearLayout( + { + {S("in1"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}}}, + {S("in2"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + }, + {S("out1"), S("out2")})); + EXPECT_THAT(to_vector(prod.getInDimNames()), ElementsAre(S("in1"), S("in2"))); + EXPECT_THAT(to_vector(prod.getOutDimNames()), + ElementsAre(S("out1"), S("out2"))); +} + +TEST_F(LinearLayoutTest, MultiplyByEmpty) { + LinearLayout prod = + LinearLayout::empty() * LinearLayout::identity1D(32, S("in"), S("out")); + EXPECT_EQ(prod, LinearLayout::identity1D(32, S("in"), S("out"))); +} + +TEST_F(LinearLayoutTest, MultiplyByZeros) { + LinearLayout prod = LinearLayout::identity1D(8, S("in"), S("out")) * + LinearLayout::zeros1D(16, S("in"), S("out")); + EXPECT_EQ(prod, LinearLayout({{S("in"), {{1}, {2}, {4}, {0}, {0}, {0}, {0}}}}, + {S("out")})); +} + +TEST_F(LinearLayoutTest, MultiplyZerosByDegenerate) { + LinearLayout prod = LinearLayout::zeros1D(16, S("in"), S("out1")) * + LinearLayout({{S("in"), {}}}, {S("out2")}); + EXPECT_EQ(prod, LinearLayout({{S("in"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}}}}, + {S("out1"), S("out2")})); +} + +TEST_F(LinearLayoutTest, MultiplyEmptyIdentityAndZeros) { + LinearLayout prod = LinearLayout::identity1D(0, S("in"), S("out")) * + LinearLayout::zeros1D(4, S("in"), S("out")); + EXPECT_EQ(prod, LinearLayout({{S("in"), {{0}, {0}}}}, {S("out")})); +} + +TEST_F(LinearLayoutTest, MultiplyOverlapping) { + LinearLayout prod = LinearLayout::identity1D(4, S("in"), S("out1")) * + LinearLayout::identity1D(8, S("in"), S("out2")); + EXPECT_EQ(prod, + LinearLayout({{S("in"), {{1, 0}, {2, 0}, {0, 1}, {0, 2}, {0, 4}}}}, + {S("out1"), S("out2")})); +} + +TEST_F(LinearLayoutTest, TimesEquals) { + LinearLayout prod = LinearLayout::empty(); + prod *= LinearLayout::identity1D(32, S("in"), S("out")); + EXPECT_EQ(prod, LinearLayout::identity1D(32, S("in"), S("out"))); +} + +TEST_F(LinearLayoutTest, GetOutDimSizeLog2) { + LinearLayout layout( + { + {S("in0"), {{0}, {0}, {0}}}, + {S("in1"), {{1}, {2}}}, + }, + {S("dim0")}); + EXPECT_EQ(layout.getOutDimSizeLog2(S("dim0")), 2); +} + +TEST_F(LinearLayoutTest, TransposeOuts) { + LinearLayout layout = (LinearLayout::identity1D(32, S("in1"), S("out1")) * + LinearLayout::identity1D(16, S("in2"), S("out2"))) + .transposeOuts({S("out2"), S("out1")}); + EXPECT_THAT(to_vector(layout.getOutDimNames()), + ElementsAre(S("out2"), S("out1"))); + EXPECT_EQ(layout, + LinearLayout( + { + {S("in1"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}}}, + {S("in2"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}, + }, + {S("out2"), S("out1")})); +} + +TEST_F(LinearLayoutTest, TransposeOutsDegenerate) { + LinearLayout layout = (LinearLayout::identity1D(32, S("in1"), S("out1")) * + LinearLayout::identity1D(1, S("in2"), S("out2"))) + .transposeOuts({S("out2"), S("out1")}); + EXPECT_THAT(to_vector(layout.getOutDimNames()), + ElementsAre(S("out2"), S("out1"))); + EXPECT_EQ(layout, + LinearLayout( + { + {S("in1"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}}}, + {S("in2"), {}}, + }, + {S("out2"), S("out1")})); +} + +TEST_F(LinearLayoutTest, TransposeIns) { + LinearLayout layout = (LinearLayout::identity1D(32, S("in1"), S("out1")) * + LinearLayout::identity1D(16, S("in2"), S("out2"))) + .transposeIns({S("in2"), S("in1")}); + EXPECT_THAT(to_vector(layout.getInDimNames()), + ElementsAre(S("in2"), S("in1"))); + EXPECT_EQ(layout, + LinearLayout( + { + {S("in2"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + {S("in1"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}}}, + }, + {S("out1"), S("out2")})); +} + +TEST_F(LinearLayoutTest, EmptyToString) { + // Mostly I just want to make sure it doesn't crash. + EXPECT_EQ(LinearLayout::empty().toString(), "\n(empty layout)"); +} + +TEST_F(LinearLayoutTest, Apply) { + LinearLayout layout( + { + {S("in1"), {{4, 2}, {2, 1}, {1, 0}}}, + {S("in2"), {{1, 2}, {2, 1}}}, + }, + {{S("out1"), 8}, {S("out2"), 4}}, /*requireSurjective=*/false); + EXPECT_THAT(layout.apply({{S("in1"), 0}, {S("in2"), 0}}), + ElementsAre(Pair(S("out1"), 0), Pair(S("out2"), 0))); + EXPECT_THAT(layout.apply({{S("in2"), 0}, {S("in1"), 1}}), + ElementsAre(Pair(S("out1"), 4), Pair(S("out2"), 2))); + EXPECT_THAT(layout.apply({{S("in2"), 1}, {S("in1"), 0}}), + ElementsAre(Pair(S("out1"), 1), Pair(S("out2"), 2))); +} + +// This is really more of a benchmark than a test. We're checking that it +// doesn't take so long to run that a human notices and says "hmm". :) +TEST_F(LinearLayoutTest, ConstructLargeLayout) { + std::vector> pows2; + for (int i = 0; i < 25; i++) { + pows2.emplace_back().push_back(1 << i); + } + LinearLayout layout({{S("in"), pows2}}, {S("out")}); + (void)layout; +} + +TEST_F(LinearLayoutTest, Compose) { + LinearLayout l1( + { + {S("in1"), {{1, 1}, {0, 1}}}, + {S("in2"), {{1, 0}, {1, 2}}}, + }, + {S("out1"), S("out2")}); + LinearLayout l2( + { + {S("out1"), {{2, 2}, {1, 0}}}, + {S("out2"), {{1, 1}, {2, 1}}}, + }, + {S("out3"), S("out4")}); + LinearLayout composition = l1.compose(l2); + EXPECT_EQ(composition, + LinearLayout( + { + {S("in1"), {{3, 3}, {1, 1}}}, + {S("in2"), {{2, 2}, {0, 3}}}, + }, + {{S("out3"), 4}, {S("out4"), 4}}, /*requireSurjective=*/false)); + EXPECT_FALSE(composition.isSurjective()); +} + +TEST_F(LinearLayoutTest, Compose4D) { + LinearLayout l1( + {{S("in0"), {{1, 0, 0, 0}, {2, 0, 0, 0}}}, + {S("in1"), {{4, 0, 0, 0}, {8, 0, 0, 0}, {16, 0, 0, 0}, {32, 0, 0, 0}}}, + {S("in2"), {{0, 0, 1, 0}, {0, 0, 0, 1}, {0, 0, 0, 2}}}, + {S("in3"), {}}}, + {S("out3"), S("out0"), S("out1"), S("out2")}); + LinearLayout l2( + { + {S("out3"), + {{1, 0, 0, 0}, + {2, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}}}, + {S("out0"), {{0, 1, 0, 0}}}, + {S("out1"), {{0, 0, 1, 0}}}, + {S("out2"), {{0, 0, 0, 1}, {0, 0, 0, 2}}}, + }, + {S("out3"), S("out2"), S("out1"), S("out0")}); + EXPECT_EQ( + l1.compose(l2), + LinearLayout( + { + {S("in0"), {{1, 0, 0, 0}, {2, 0, 0, 0}}}, + {S("in1"), + {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}}}, + {S("in2"), {{0, 0, 1, 0}, {0, 0, 0, 1}, {0, 0, 0, 2}}}, + {S("in3"), {}}, + }, + {{S("out3"), 4}, {S("out2"), 2}, {S("out1"), 2}, {S("out0"), 4}}, + /*requireSurjective=*/false)); +} + +TEST_F(LinearLayoutTest, ReshapeIns) { + LinearLayout ll({{S("in1"), {{1}, {4}, {8}}}, {S("in2"), {{2}}}}, {S("out")}); + EXPECT_EQ(ll.reshapeIns({{S("in3"), {2}}, {S("in4"), {8}}}), + LinearLayout({{S("in3"), {{1}}}, {S("in4"), {{4}, {8}, {2}}}}, + {S("out")})); +} + +TEST_F(LinearLayoutTest, ReshapeInsDegenerateIn) { + LinearLayout ll({{S("in1"), {{1}, {4}, {2}}}, {S("in2"), {}}}, {S("out")}); + EXPECT_EQ( + ll.reshapeIns({{S("in3"), {4}}, {S("in4"), {2}}}), + LinearLayout({{S("in3"), {{1}, {4}}}, {S("in4"), {{2}}}}, {S("out")})); +} + +TEST_F(LinearLayoutTest, ReshapeInsDegenerateOut) { + LinearLayout ll({{S("in1"), {{1}, {4}}}, {S("in2"), {{2}}}}, {S("out")}); + EXPECT_EQ( + ll.reshapeIns({{S("in3"), {8}}, {S("in4"), {1}}}), + LinearLayout({{S("in3"), {{1}, {4}, {2}}}, {S("in4"), {}}}, {S("out")})); +} + +TEST_F(LinearLayoutTest, ReshapeInsDegenerateFirstOut) { + LinearLayout ll({{S("in1"), {{1}, {4}}}, {S("in2"), {{2}}}}, {S("out")}); + EXPECT_EQ( + ll.reshapeIns({{S("in3"), {1}}, {S("in4"), {8}}}), + LinearLayout({{S("in3"), {}}, {S("in4"), {{1}, {4}, {2}}}}, {S("out")})); +} + +TEST_F(LinearLayoutTest, FlattenIns) { + LinearLayout ll({{S("in1"), {{1}, {4}, {8}}}, {S("in2"), {{2}}}}, {S("out")}); + EXPECT_EQ(ll.flattenIns(), + LinearLayout({{S("in1"), {{1}, {4}, {8}, {2}}}}, {S("out")})); +} + +TEST_F(LinearLayoutTest, FlattenInsEdgeCases) { + EXPECT_EQ(LinearLayout({{S("in1"), {}}}, {S("out")}).flattenIns(), + LinearLayout({{S("in1"), {}}}, {S("out")})); + EXPECT_EQ(LinearLayout({{S("in1"), {}}}, {}).flattenIns(), + LinearLayout({{S("in1"), {}}}, {})); + using BasesArray = + ArrayRef>>>; + EXPECT_EQ(LinearLayout(BasesArray{}, {S("out")}).flattenIns(), + LinearLayout(BasesArray{}, {S("out")})); + EXPECT_EQ(LinearLayout(BasesArray{}, {}).flattenIns(), + LinearLayout(BasesArray{}, {})); +} + +TEST_F(LinearLayoutTest, ReshapeOuts) { + LinearLayout ll({{S("in1"), {{1}, {4}, {8}}}, {S("in2"), {{3}}}}, {S("out")}); + EXPECT_EQ(ll.getTotalOutDimSize(), 16); + EXPECT_EQ( + ll.reshapeOuts({{S("out2"), {2}}, {S("out3"), {8}}}), + LinearLayout({{S("in1"), {{1, 0}, {0, 2}, {0, 4}}}, {S("in2"), {{1, 1}}}}, + {S("out2"), S("out3")})); +} + +TEST_F(LinearLayoutTest, ReshapeOutsDegenerateIn) { + LinearLayout ll({{S("in1"), {{1}, {4}, {2}}}, {S("in2"), {}}}, {S("out")}); + EXPECT_EQ(ll.reshapeOuts({{S("out1"), {4}}, {S("out2"), {2}}}), + LinearLayout({{S("in1"), {{1, 0}, {0, 1}, {2, 0}}}, {S("in2"), {}}}, + {S("out1"), S("out2")})); +} + +TEST_F(LinearLayoutTest, ReshapeOutsDegenerateOut) { + LinearLayout ll({{S("in1"), {{1}, {4}}}, {S("in2"), {{2}}}}, {S("out")}); + EXPECT_EQ(ll.reshapeOuts({{S("out1"), {8}}, {S("out2"), {1}}}), + LinearLayout({{S("in1"), {{1, 0}, {4, 0}}}, {S("in2"), {{2, 0}}}}, + {S("out1"), S("out2")})); +} + +TEST_F(LinearLayoutTest, FlattenOuts) { + LinearLayout ll({{S("in1"), {{1, 0}, {4, 1}, {8, 4}}}, {S("in2"), {{3, 2}}}}, + {{S("out1"), 16}, {S("out2"), 8}}, + /*requireSurjective=*/false); + EXPECT_EQ(ll.flattenOuts(), + LinearLayout({{S("in1"), {{1}, {4 + 16}, {8 + 4 * 16}}}, + {S("in2"), {{3 + 2 * 16}}}}, + {{S("out1"), 16 * 8}}, /*requireSurjective=*/false)); +} + +TEST_F(LinearLayoutTest, FlattenOutsEdgeCases) { + EXPECT_EQ(LinearLayout({{S("in1"), {}}}, {S("out")}).flattenOuts(), + LinearLayout({{S("in1"), {}}}, {S("out")})); + EXPECT_EQ(LinearLayout({{S("in1"), {}}}, {}).flattenOuts(), + LinearLayout({{S("in1"), {}}}, {})); + using BasesArray = + ArrayRef>>>; + EXPECT_EQ(LinearLayout(BasesArray{}, {S("out")}).flattenOuts(), + LinearLayout(BasesArray{}, {S("out")})); + EXPECT_EQ(LinearLayout(BasesArray{}, {}).flattenOuts(), + LinearLayout(BasesArray{}, {})); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_Simple) { + LinearLayout l1({{S("in1"), {{2}, {1}, {4}}}}, {S("out")}); + LinearLayout l2({{S("in2"), {{4}, {1}, {2}}}}, {S("out")}); + + // Inverse of l2 is + // out(1) => in2=2 + // out(2) => in2=4 + // out(4) => in2=1. + // + // Composing with l1 gives + // l2^-1(l1(1)) = l2^-1(2) = 4 + // l2^-1(l1(2)) = l2^-1(1) = 2 + // l2^-1(l1(4)) = l2^-1(4) = 1 + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in1"), {{4}, {2}, {1}}}}, {S("in2")})); + // L2 ∘ L2^-1 ∘ L1 == L1. + EXPECT_EQ(composition.compose(l2), l1); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_NonInjective) { + LinearLayout l1({{S("in1"), {{2}, {1}, {4}}}}, {S("out")}); + LinearLayout l2({{S("in2"), {{0}, {2}, {1}, {4}}}}, {S("out")}); + + // The pseudo-inverse of l2 is + // out(1) => in2=4 + // out(2) => in2=2 + // out(4) => in2=8. + // + // Composing with l1 gives + // l2^-1(l1(1)) = l2^-1(2) = 2 + // l2^-1(l1(2)) = l2^-1(0) = 4 + // l2^-1(l1(4)) = l2^-1(4) = 8 + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in1"), {{2}, {4}, {8}}}}, {{S("in2"), 16}}, + /*requireSurjective=*/false)); + EXPECT_FALSE(composition.isSurjective()); + + // L2 ∘ L2^-1 ∘ L1 == L1. + EXPECT_EQ(composition.compose(l2), l1); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedInDim) { + LinearLayout l1({{S("in1"), {{2}, {1}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); + LinearLayout l2({{S("in"), {{4}, {1}, {2}}}}, {S("out")}); + // Inverse of l2 is + // out(1) = 2 + // out(2) = 4 + // out(4) = 1 + // + // Composing with l1 gives + // + // l2^-1(l1(1, 0)) = l2^-1(2) = 4 + // l2^-1(l1(2, 0)) = l2^-1(1) = 2 + // l2^-1(l1(4, 0)) = l2^-1(4) = 1 + // l2^-1(l1(0, 1)) = l2^-1(0) = 0 + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in1"), {{4}, {2}, {1}}}, {S("in2"), {{0}}}}, + {S("in")})); + EXPECT_EQ(composition.compose(l2), l1); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastAtBeginningOfSecond) { + LinearLayout l1({{S("in"), {{1}, {2}, {4}}}}, {S("out")}); + LinearLayout l2({{S("in"), {{0}, {4}, {1}, {2}}}}, {S("out")}); + // Pseudo-inverse of l2 is + // out(1) = 4 + // out(2) = 8 + // out(4) = 2 + // + // l1 is the identity, so composing with l1 gives back l2^-1. + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in"), {{4}, {8}, {2}}}}, {{S("in"), 16}}, + /*requireSurjective=*/false)); + EXPECT_EQ(composition.compose(l2), l1); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastAtEndOfSecond) { + LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}}, {S("out")}); + LinearLayout l2({{S("in2"), {{4}, {1}, {2}, {0}}}}, {S("out")}); + // Pseudo-inverse of l2 is + // + // out(1) = 2 + // out(2) = 4 + // out(4) = 1 + // + // l1 is the identity, so composing with l1 gives back l2^-1. + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in1"), {{2}, {4}, {1}}}}, {{S("in2"), 16}}, + /*requireSurjective=*/false)); + EXPECT_TRUE(composition.compose(l2).equalIgnoringOutDimSizes(l1)); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastBeginningAndEndOfSecond) { + LinearLayout l1({{S("in"), {{1}, {2}, {4}}}}, {S("out")}); + LinearLayout l2({{S("in"), {{0}, {4}, {1}, {2}, {0}}}}, {S("out")}); + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in"), {{4}, {8}, {2}}}}, {{S("in"), 32}}, + /*requireSurjective=*/false)); + EXPECT_EQ(composition.compose(l2), l1); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_Multidim) { + LinearLayout l1( + {{S("in1"), {{1, 0}, {0, 1}, {2, 0}, {3, 2}}}, {S("in2"), {{2, 2}}}}, + {S("out1"), S("out2")}); + LinearLayout l2({{S("in3"), {{0, 1}, {1, 0}, {0, 0}, {0, 2}, {2, 1}}}}, + {S("out2"), S("out1")}); + + LinearLayout c1 = l1.invertAndCompose(l2); + EXPECT_EQ(c1.compose(l2), + l1.transposeOuts(llvm::to_vector(l2.getOutDimNames()))); + + LinearLayout c2 = l2.invertAndCompose(l1); + EXPECT_EQ(c2.compose(l1), + l2.transposeOuts(llvm::to_vector(l1.getOutDimNames()))); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedDims) { + LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); + LinearLayout l2({{S("in3"), {{1}, {2}, {4}}}, {S("in4"), {{0}}}}, {S("out")}); + LinearLayout c = l1.invertAndCompose(l2); + EXPECT_EQ(c, LinearLayout( + {{S("in1"), {{1, 0}, {2, 0}, {4, 0}}}, {S("in2"), {{0, 0}}}}, + {{S("in3"), 8}, {S("in4"), 2}}, + /*requireSurjective=*/false)); + EXPECT_EQ(c.compose(l2), + l1.transposeOuts(llvm::to_vector(l2.getOutDimNames()))); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedDims2) { + LinearLayout a({{S("in1"), {{1}, {2}}}, {S("in2"), {{0}}}}, {S("out")}); + LinearLayout b({{S("in3"), {{2}, {1}}}, {S("in4"), {{0}}}}, {S("out")}); + LinearLayout c = a.invertAndCompose(b); + EXPECT_EQ(c, + LinearLayout({{S("in1"), {{2, 0}, {1, 0}}}, {S("in2"), {{0, 0}}}}, + {{S("in3"), 4}, {S("in4"), 2}}, + /*requireSurjective=*/false)); + EXPECT_EQ(c.compose(b), a.transposeOuts(llvm::to_vector(b.getOutDimNames()))); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_IdentityInDim) { + SmallVector outDims = {S("dim0"), S("dim1"), S("dim2"), + S("dim3"), S("dim4"), S("dim5"), + S("dim6"), S("dim7"), S("dim8")}; + + LinearLayout src({{S("register"), + { + {0, 0, 0, 0, 0, 0, 0, 0, 1}, + {0, 0, 0, 0, 0, 0, 0, 1, 0}, + }}, + {S("lane"), + { + {0, 0, 0, 0, 0, 0, 1, 0, 0}, + {0, 0, 0, 0, 0, 1, 0, 0, 0}, + {0, 0, 0, 0, 1, 0, 0, 0, 0}, + {0, 0, 0, 1, 0, 0, 0, 0, 0}, + {0, 0, 1, 0, 0, 0, 0, 0, 0}, + }}, + {S("warp"), + { + {0, 1, 0, 0, 0, 0, 0, 0, 0}, + {1, 0, 0, 0, 0, 0, 0, 0, 0}, + }}, + {S("block"), {}}}, + outDims); + LinearLayout dst({{S("register"), + { + {0, 0, 0, 0, 0, 0, 0, 0, 1}, + {0, 0, 0, 0, 0, 0, 0, 1, 0}, + }}, + {S("lane"), + { + {1, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 1, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 1, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 1, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 1, 0, 0, 0, 0}, + }}, + {S("warp"), + { + {0, 0, 0, 0, 0, 1, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 1, 0, 0}, + }}, + {S("block"), {}}}, + outDims); + + LinearLayout cvt = dst.invertAndCompose(src); + SmallVector> k = { + {S("register"), 3}, {S("lane"), 0}, {S("warp"), 2}, {S("block"), 0}}; + + EXPECT_EQ(dst.apply(k), src.apply(cvt.apply(k))); +} + +TEST_F(LinearLayoutTest, NumConsecutiveInOut) { + EXPECT_EQ( + 1, + LinearLayout::identity1D(1, S("in"), S("out")).getNumConsecutiveInOut()); + EXPECT_EQ( + 4, + LinearLayout::identity1D(4, S("in"), S("out")).getNumConsecutiveInOut()); + EXPECT_EQ(4, (LinearLayout::identity1D(4, S("in1"), S("out")) * + LinearLayout::identity1D(8, S("in2"), S("out"))) + .getNumConsecutiveInOut()); + EXPECT_EQ(4, (LinearLayout::identity1D(4, S("in"), S("out1")) * + LinearLayout::identity1D(8, S("in"), S("out2"))) + .getNumConsecutiveInOut()); + EXPECT_EQ(1, (LinearLayout::zeros1D(4, S("in"), S("out1")) * + LinearLayout::identity1D(4, S("in"), S("out2"))) + .getNumConsecutiveInOut()); + EXPECT_EQ(1, LinearLayout({{S("in"), {{1}, {2}, {4}, {9}}}}, {S("out")}) + .getNumConsecutiveInOut()); + EXPECT_EQ(2, LinearLayout({{S("in"), {{1}, {2}, {4}, {10}}}}, {S("out")}) + .getNumConsecutiveInOut()); + EXPECT_EQ(2, LinearLayout({{S("in"), {{1}, {4}, {2}}}}, {S("out")}) + .getNumConsecutiveInOut()); + EXPECT_EQ(2, LinearLayout( + { + {S("in"), {{1}, {2}, {4}}}, + {S("in2"), {{8}, {18}}}, + }, + {S("out")}) + .getNumConsecutiveInOut()); +} + +TEST_F(LinearLayoutTest, EqualsChecksOutDimSizes) { + EXPECT_FALSE(LinearLayout::identity1D(4, S("in"), S("out")) == + LinearLayout({{S("in"), {{1}, {2}}}}, {{S("out"), 8}}, + /*requireSurjective=*/false)); + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) != + LinearLayout({{S("in"), {{1}, {2}}}}, {{S("out"), 8}}, + /*requireSurjective=*/false)); + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) + .equalIgnoringOutDimSizes( + LinearLayout({{S("in"), {{1}, {2}}}}, {{S("out"), 8}}, + /*requireSurjective=*/false))); +} + +TEST_F(LinearLayoutTest, Sublayout) { + LinearLayout l1({{S("in1"), {{1, 0}, {0, 1}, {2, 0}}}, {S("in2"), {{0, 1}}}}, + {S("out1"), S("out2")}); + EXPECT_EQ(l1.sublayout({S("in1"), S("in2")}, {S("out1")}), + LinearLayout({{S("in1"), {{1}, {0}, {2}}}, {S("in2"), {{0}}}}, + {S("out1")})); + EXPECT_EQ(l1.sublayout({S("in2"), S("in1")}, {S("out1")}), + LinearLayout({{S("in1"), {{1}, {0}, {2}}}, {S("in2"), {{0}}}}, + {S("out1")})); + EXPECT_EQ(l1.sublayout({S("in2"), S("in1")}, {S("out2"), S("out1")}), l1); + EXPECT_EQ(l1.sublayout({S("in1")}, {S("out1")}), + LinearLayout({{S("in1"), {{1}, {0}, {2}}}}, {S("out1")})); + EXPECT_EQ(l1.sublayout({}, {}), LinearLayout::empty()); + EXPECT_EQ(l1.sublayout({S("in1")}, {}), + LinearLayout({{S("in1"), {{}, {}, {}}}}, {})); + EXPECT_EQ(l1.sublayout({}, {S("out1")}), + LinearLayout(LinearLayout::BasesT{}, {{S("out1"), 4}}, + /*requireSurjective=*/false)); +} + +TEST_F(LinearLayoutTest, SublayoutIsZero) { + EXPECT_FALSE(LinearLayout::identity1D(4, S("in"), S("out")) + .sublayoutIsZero({S("in")}, {S("out")})); + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) + .sublayoutIsZero({}, {S("out")})); + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) + .sublayoutIsZero({S("in")}, {})); + EXPECT_TRUE( + LinearLayout::identity1D(4, S("in"), S("out")).sublayoutIsZero({}, {})); + + LinearLayout l1({{S("in1"), {{0, 1}, {0, 2}}}, {S("in2"), {{1, 1}}}}, + {S("out1"), S("out2")}); + EXPECT_TRUE(l1.sublayoutIsZero({S("in1")}, {S("out1")})); + EXPECT_FALSE(l1.sublayoutIsZero({S("in1")}, {S("out2")})); + EXPECT_FALSE(l1.sublayoutIsZero({S("in2")}, {S("out1")})); + EXPECT_FALSE(l1.sublayoutIsZero({S("in2")}, {S("out2")})); +} + +TEST_F(LinearLayoutTest, FreeVariableMasks) { + using llvm::to_vector; + using AR = llvm::ArrayRef>; + + EXPECT_EQ(AR(to_vector(LinearLayout::identity1D(4, S("in"), S("out")) + .getFreeVariableMasks())), + AR({{S("in"), 0}})); + EXPECT_EQ( + AR(to_vector( + LinearLayout::zeros1D(16, S("in"), S("out")).getFreeVariableMasks())), + AR({{S("in"), 0b1111}})); + EXPECT_EQ(AR(to_vector((LinearLayout::identity1D(2, S("in"), S("out")) * + LinearLayout::zeros1D(4, S("in"), S("out")) * + LinearLayout::identity1D(4, S("in"), S("out")) * + LinearLayout::zeros1D(2, S("in"), S("out"))) + .getFreeVariableMasks())), + AR({{S("in"), 0b100110}})); + EXPECT_EQ(AR(to_vector((LinearLayout::identity1D(2, S("in"), S("out")) * + LinearLayout::zeros1D(4, S("in"), S("out")) * + LinearLayout::identity1D(4, S("in"), S("out")) * + LinearLayout::zeros1D(2, S("in"), S("out"))) + .getFreeVariableMasks())), + AR({{S("in"), 0b100110}})); + EXPECT_EQ(AR(to_vector(LinearLayout({{S("in1"), {{1, 1}, {2, 2}, {0, 0}}}, + {S("in2"), {{1, 0}, {0, 1}, {2, 0}}}}, + {S("out1"), S("out2")}) + .getFreeVariableMasks())), + AR({{S("in1"), 0b100}, {S("in2"), 0b10}})); +} + +TEST_F(LinearLayoutTest, QuotientOneDimension) { + LinearLayout layout( + { + {S("dim1"), {{1, 0}}}, + {S("dim2"), {{0, 0}}}, + }, + {{S("dim1"), 2}, {S("dim2"), 1}}, /*requireSurjective=*/false); + + // Quotient over dim1, which is trivial + auto quotientLayout = layout.quotient({S("dim1")}); + ASSERT_TRUE(quotientLayout.has_value()); + EXPECT_EQ(*quotientLayout, LinearLayout::zeros1D(2, S("dim2"), S("dim2"))); + // dim2 is zero, not the identity + ASSERT_FALSE(quotientLayout->quotient({S("dim2")}).has_value()); +} + +TEST_F(LinearLayoutTest, QuotientSeveralDimensions) { + LinearLayout layout( + { + {S("dim1"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("dim2"), {{0, 1}, {0, 2}}}, + }, + {S("dim1"), S("dim2")}); + + auto quotientLayout = layout.quotient({S("dim1"), S("dim2")}); + EXPECT_TRUE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientMultipleTrivialDimensions) { + LinearLayout layout( + { + {S("dim1"), {{1, 0, 2}, {2, 0, 1}}}, + {S("dim2"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("dim3"), {{0, 0, 1}, {0, 0, 2}}}, + }, + {S("dim1"), S("dim2"), S("dim3")}); + + // Quotient over dim2 is trivial, even if there's some funny business + // going on in the other dimensions + auto quotientLayout = layout.quotient({S("dim2")}); + ASSERT_TRUE(quotientLayout.has_value()); + + layout = LinearLayout( + { + {S("dim1"), {{1, 0, 2}, {2, 0, 1}}}, + {S("dim2"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("dim3"), {{0, 1, 1}, {0, 0, 2}}}, + }, + {S("dim1"), S("dim2"), S("dim3")}); + + // As soon as one maps into the dimension being quotiented or out of it + // (in this case dim3 depends on dim2), we cannot quotient + quotientLayout = layout.quotient({S("dim2")}); + ASSERT_FALSE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientEmptyLayout) { + LinearLayout layout = LinearLayout::empty(); + + // Quotienting over a dimension that doesn't exist is invalid + auto quotientLayout = layout.quotient({S("dim1")}); + ASSERT_FALSE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) { + // Test quotient on identity layout with multiple dimensions + LinearLayout layout = LinearLayout::identity1D(8, S("dim1"), S("dim1")) * + LinearLayout::identity1D(2, S("dim2"), S("dim2")) * + LinearLayout::identity1D(4, S("dim3"), S("dim3")); + + // We can quotient over all dimensions in any order + auto quotientLayout = layout.quotient({S("dim1"), S("dim3")}); + ASSERT_TRUE(quotientLayout.has_value()); + ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); +} + +LinearLayout getPackedCoordtoPaddedOffset(int M, int KPacked8b, StringAttr row, + StringAttr col, StringAttr offset) { + std::vector> basesRows, basesCols; + for (int i = 0; i < llvm::Log2_32(M); ++i) { + int row = 1 << i; + int col = 0; + int linearCoord = row * KPacked8b + col; + int offset = (linearCoord / 8) * 16 + (linearCoord % 8); + basesRows.push_back({offset}); + } + + for (int j = 0; j < llvm::Log2_32(KPacked8b); ++j) { + int row = 0; + int col = 1 << j; + int linearCoord = row * KPacked8b + col; + int offset = (linearCoord / 8) * 16 + (linearCoord % 8); + basesCols.push_back({offset}); + } + + return LinearLayout({{row, basesRows}, {col, basesCols}}, + {{offset, M * KPacked8b * 2}}, /*surjective*/ false); +} + +TEST_F(LinearLayoutTest, BlackwellMixedPrecisionDotScaledSMEM) { + std::vector> basesRows, basesCols, basesOffset; + int numFp4Elems = 128; + int M = 16; + int KPacked8b = numFp4Elems / M / 2; + int KPadded8b = numFp4Elems / M; + + for (int i = 0; i < llvm::Log2_32(M * KPadded8b); ++i) { + int offset = 1 << i; + int linearCoordPacked = offset / 16 * 8 + offset % 8; + int row = linearCoordPacked / KPacked8b; + int col = linearCoordPacked % KPacked8b; + basesOffset.push_back({row, col}); + } + + LinearLayout layout({{S("offset"), basesOffset}}, {S("row"), S("col")}); + LinearLayout layoutInverseComputed = layout.pseudoinvert(); + LinearLayout layoutInverseManual = getPackedCoordtoPaddedOffset( + M, KPacked8b, S("row"), S("col"), S("offset")); + + for (int i = 0; i < M; ++i) { + for (int j = 0; j < KPacked8b; ++j) { + auto off1 = layoutInverseManual.apply({{S("row"), i}, {S("col"), j}}); + auto off2 = layoutInverseComputed.apply({{S("row"), i}, {S("col"), j}}); + EXPECT_EQ(off1[0].second, off2[0].second); + } + } +} + +TEST_F(LinearLayoutTest, BlackwellMixedPrecisionDotScaledSMEMSwizzled) { + int M = 16; + int KPadded8b = 128; + int numFp4Elems = M * KPadded8b; + int KPacked8b = KPadded8b / 2; + int elemBitWidth = 8; + int tileWidthBytes = 128; + int tileRows = 8; + int tileCols = 8 * tileWidthBytes / elemBitWidth; + int vec = 16; + + std::vector> bases2D; + for (int logCol = 0; logCol < llvm::Log2_32(tileCols); logCol++) { + int colPadded = 1 << logCol; + int colPacked = colPadded / 16 * 8 + colPadded % 8; + bases2D.push_back({0, colPacked}); + } + for (int logRow = 0; logRow < llvm::Log2_32(tileRows); logRow++) { + int row = 1 << logRow; + int perPhase = 1; + int maxPhase = 8; + int colPadded = vec * ((row / perPhase) % maxPhase); + int colPacked = colPadded / 16 * 8 + colPadded % 8; + bases2D.push_back({row, colPacked}); + } + + LinearLayout layoutSwizzled({{S("offset"), bases2D}}, {S("row"), S("col")}); + layoutSwizzled = ensureLayoutNotSmallerThan( + layoutSwizzled, {{S("row"), M}, {S("col"), KPacked8b}}); + + auto layoutInverseSwizzled = layoutSwizzled.pseudoinvert(); + + LinearLayout layoutInverseNoSwizzle = getPackedCoordtoPaddedOffset( + M, KPacked8b, S("row"), S("col"), S("offset")); + + for (int i = 0; i < M; ++i) { + for (int j = 0; j < KPacked8b; ++j) { + auto nonSwizzleOffset = + layoutInverseNoSwizzle.apply({{S("row"), i}, {S("col"), j}})[0] + .second; + auto swizzledOffset = + layoutInverseSwizzled.apply({{S("row"), i}, {S("col"), j}})[0].second; + int row = nonSwizzleOffset / KPadded8b; + int col = nonSwizzleOffset % KPadded8b; + int colSwizzled = ((col / 16) ^ (row % 8)) * 16 + col % 16; + EXPECT_EQ(row * KPadded8b + colSwizzled, swizzledOffset); + } + } +} + +static SmallVector makeList(MLIRContext *ctx, + llvm::ArrayRef list) { + SmallVector ret; + for (auto s : list) + ret.push_back(StringAttr::get(ctx, s)); + return ret; +} + +TEST(SupremumTest, IdenticalLists) { + MLIRContext ctx; + SmallVector x = makeList(&ctx, {"a", "b", "c"}); + SmallVector y = makeList(&ctx, {"a", "b", "c"}); + EXPECT_EQ(supremum(x, y), x); +} + +TEST(SupremumTest, NonUniqueSupremumFirstListPriority) { + MLIRContext ctx; + // sup([a, b], [a, c]) should yield [a, b, c] + SmallVector x = makeList(&ctx, {"a", "b"}); + SmallVector y = makeList(&ctx, {"a", "c"}); + EXPECT_EQ(supremum(x, y), makeList(&ctx, {"a", "b", "c"})); +} + +TEST(SupremumTest, NonUniqueSupremumAlternate) { + MLIRContext ctx; + // sup([a, b], [b, c]) should yield [a, b, c] + SmallVector x = makeList(&ctx, {"a", "b"}); + SmallVector y = makeList(&ctx, {"b", "c"}); + EXPECT_EQ(supremum(x, y), makeList(&ctx, {"a", "b", "c"})); +} + +TEST(SupremumTest, DifferentLengths) { + MLIRContext ctx; + // sup([a, b, c], [a, d]) should yield [a, b, c, d] + SmallVector x = makeList(&ctx, {"a", "b", "c"}); + SmallVector y = makeList(&ctx, {"a", "d"}); + EXPECT_EQ(supremum(x, y), makeList(&ctx, {"a", "b", "c", "d"})); +} + +TEST(SupremumTest, SupremumEmptyLists) { + MLIRContext ctx; + SmallVector x; + SmallVector y; + EXPECT_TRUE(supremum(x, y).empty()); +} + +TEST(SupremumTest, OneEmptyList) { + MLIRContext ctx; + // sup([a, b], []) should yield [a, b] + SmallVector x = makeList(&ctx, {"a", "b"}); + SmallVector y; + EXPECT_EQ(supremum(x, y), makeList(&ctx, {"a", "b"})); +} + +#ifdef LLVM_ENABLE_ASSERTIONS +TEST(SupremumTest, ErrorOnInconsistentOrder) { + MLIRContext ctx; + // sup([a, b], [b, a]) has no consistent ordering so it should trigger + // llvm_unreachable. + SmallVector x = makeList(&ctx, {"a", "b"}); + SmallVector y = makeList(&ctx, {"b", "a"}); + ASSERT_DEATH({ supremum(x, y); }, "Supremum does not exist"); +} +#endif +} // anonymous namespace +} // namespace mlir::triton + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/enflame/include/triton/unittest/googletest.cmake b/third_party/enflame/include/triton/unittest/googletest.cmake new file mode 100644 index 000000000..41d3d4fa4 --- /dev/null +++ b/third_party/enflame/include/triton/unittest/googletest.cmake @@ -0,0 +1,23 @@ +include(FetchContent) + +set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") + +if(GOOGLETEST_DIR) + set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") +endif() + +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG release-1.12.1 + ) + +FetchContent_GetProperties(googletest) + +if(NOT googletest_POPULATED) + FetchContent_Populate(googletest) + if (MSVC) + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + endif() + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) +endif() diff --git a/third_party/enflame/include/triton/utils/generate-test-checks.py b/third_party/enflame/include/triton/utils/generate-test-checks.py new file mode 100755 index 000000000..a271c3f8b --- /dev/null +++ b/third_party/enflame/include/triton/utils/generate-test-checks.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +""" +=============================================================== +A script to generate FileCheck statements for mlir unit tests. +=============================================================== + +This script is a utility to add FileCheck patterns to an mlir file. + +NOTE: The input ``.mlir`` is expected to be the output from the parser, not a +stripped down variant. + +Example usage: + +.. code-block:: shell + + $ generate-test-checks.py foo.mlir + $ mlir-opt foo.mlir -transformation | generate-test-checks.py + $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir + $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i + $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @' + +The script will heuristically generate CHECK/CHECK-LABEL commands for each line +within the file. By default this script will also try to insert string +substitution blocks for all SSA value names. If ``--source file`` is specified, the +script will attempt to insert the generated CHECKs to the source file by looking +for line positions matched by ``--source_delim_regex``. + +The script is designed to make adding checks to a test case fast, it is *not* +designed to be authoritative about what constitutes a good test! +""" + +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import os # Used to advertise this file's name ("autogenerated_note"). +import re +import sys +from typing import Optional + +ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by " +ADVERT_END = """ +// The script is designed to make adding checks to +// a test case fast, it is *not* designed to be authoritative +// about what constitutes a good test! The CHECK should be +// minimized and named to reflect the test intent. +""" + +# Regex command to match an SSA identifier. +SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*" +SSA_RE = re.compile(SSA_RE_STR) + +# Regex matching the left-hand side of an assignment +SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*=' +SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR) + +# Regex matching attributes +ATTR_RE_STR = r'(#[a-zA-Z._-][a-zA-Z0-9._-]*)' +ATTR_RE = re.compile(ATTR_RE_STR) + +# Regex matching the left-hand side of an attribute definition +ATTR_DEF_RE_STR = r'\s*' + ATTR_RE_STR + r'\s*=' +ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR) + + +# Class used to generate and manage string substitution blocks for SSA value +# names. +class VariableNamer: + + def __init__(self, variable_names): + self.scopes = [] + self.name_counter = 0 + + # Number of variable names to still generate in parent scope + self.generate_in_parent_scope_left = 0 + + # Parse variable names + self.variable_names = [name.upper() for name in variable_names.split(',')] + self.used_variable_names = set() + + # Generate the following 'n' variable names in the parent scope. + def generate_in_parent_scope(self, n): + self.generate_in_parent_scope_left = n + + # Generate a substitution name for the given ssa value name. + def generate_name(self, source_variable_name): + + # Compute variable name + variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else '' + if variable_name == '': + variable_name = "VAL_" + str(self.name_counter) + self.name_counter += 1 + + # Scope where variable name is saved + scope = len(self.scopes) - 1 + if self.generate_in_parent_scope_left > 0: + self.generate_in_parent_scope_left -= 1 + scope = len(self.scopes) - 2 + assert (scope >= 0) + + # Save variable + if variable_name in self.used_variable_names: + raise RuntimeError(variable_name + ': duplicate variable name') + self.scopes[scope][source_variable_name] = variable_name + self.used_variable_names.add(variable_name) + + return variable_name + + # Push a new variable name scope. + def push_name_scope(self): + self.scopes.append({}) + + # Pop the last variable name scope. + def pop_name_scope(self): + self.scopes.pop() + + # Return the level of nesting (number of pushed scopes). + def num_scopes(self): + return len(self.scopes) + + # Reset the counter and used variable names. + def clear_names(self): + self.name_counter = 0 + self.used_variable_names = set() + + +class AttributeNamer: + + def __init__(self, attribute_names): + self.name_counter = 0 + self.attribute_names = [name.upper() for name in attribute_names.split(',')] + self.map = {} + self.used_attribute_names = set() + + # Generate a substitution name for the given attribute name. + def generate_name(self, source_attribute_name): + + # Compute FileCheck name + attribute_name = self.attribute_names.pop(0) if len(self.attribute_names) > 0 else '' + if attribute_name == '': + attribute_name = "ATTR_" + str(self.name_counter) + self.name_counter += 1 + + # Prepend global symbol + attribute_name = '$' + attribute_name + + # Save attribute + if attribute_name in self.used_attribute_names: + raise RuntimeError(attribute_name + ': duplicate attribute name') + self.map[source_attribute_name] = attribute_name + self.used_attribute_names.add(attribute_name) + return attribute_name + + # Get the saved substitution name for the given attribute name, if it exists. + def get_name(self, source_attribute_name) -> Optional[str]: + return self.map.get(source_attribute_name) + + +# Return the number of SSA results in a line of type +# %0, %1, ... = ... +# The function returns 0 if there are no results. +def get_num_ssa_results(input_line): + m = SSA_RESULTS_RE.match(input_line) + return m.group().count('%') if m else 0 + + +# Process a line of input that has been split at each SSA identifier '%'. +def process_line(line_chunks, variable_namer): + output_line = "" + + # Process the rest that contained an SSA value name. + for chunk in line_chunks: + m = SSA_RE.match(chunk) + ssa_name = m.group(0) if m is not None else '' + + # Check if an existing variable exists for this name. + variable = None + for scope in variable_namer.scopes: + variable = scope.get(ssa_name) + if variable is not None: + break + + # If one exists, then output the existing name. + if variable is not None: + output_line += "%[[" + variable + "]]" + else: + # Otherwise, generate a new variable. + variable = variable_namer.generate_name(ssa_name) + output_line += "%[[" + variable + ":.*]]" + + # Append the non named group. + output_line += chunk[len(ssa_name):] + + return output_line.rstrip() + "\n" + + +# Process the source file lines. The source file doesn't have to be .mlir. +def process_source_lines(source_lines, note, args): + source_split_re = re.compile(args.source_delim_regex) + + source_segments = [[]] + for line in source_lines: + # Remove previous note. + if line == note: + continue + # Remove previous CHECK lines. + if line.find(args.check_prefix) != -1: + continue + # Segment the file based on --source_delim_regex. + if source_split_re.search(line): + source_segments.append([]) + + source_segments[-1].append(line + "\n") + return source_segments + + +def process_attribute_definition(line, attribute_namer, output): + m = ATTR_DEF_RE.match(line) + if m: + attribute_name = attribute_namer.generate_name(m.group(1)) + line = '// CHECK: #[[' + attribute_name + ':.+]] =' + line[len(m.group(0)):] + '\n' + output.append(line) + + +def process_attribute_references(line, attribute_namer): + + output_line = '' + components = ATTR_RE.split(line) + for component in components: + m = ATTR_RE.match(component) + name = attribute_namer.get_name(m.group(1)) if m else None + if name is None: + output_line += component + else: + output_line += '#[[' + name + ']]' + output_line += component[len(m.group()):] + return output_line + + +# Pre-process a line of input to remove any character sequences that will be +# problematic with FileCheck. +def preprocess_line(line): + # Replace any double brackets, '[[' with escaped replacements. '[[' + # corresponds to variable names in FileCheck. + output_line = line.replace("[[", "{{\\[\\[}}") + + # Replace any single brackets that are followed by an SSA identifier, the + # identifier will be replace by a variable; Creating the same situation as + # above. + output_line = output_line.replace("[%", "{{\\[}}%") + + return output_line + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("--check-prefix", default="CHECK", help="Prefix to use from check file.") + parser.add_argument("-o", "--output", nargs="?", type=argparse.FileType("w"), default=None) + parser.add_argument("input", nargs="?", type=argparse.FileType("r"), default=sys.stdin) + parser.add_argument( + "--source", + type=str, + help="Print each CHECK chunk before each delimiter line in the source" + "file, respectively. The delimiter lines are identified by " + "--source_delim_regex.", + ) + parser.add_argument("--source_delim_regex", type=str, default="func @") + parser.add_argument( + "--starts_from_scope", + type=int, + default=1, + help="Omit the top specified level of content. For example, by default " + 'it omits "module {"', + ) + parser.add_argument("-i", "--inplace", action="store_true", default=False) + parser.add_argument( + "--variable_names", type=str, default='', + help="Names to be used in FileCheck regular expression to represent SSA " + "variables in the order they are encountered. Separate names with commas, " + "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')") + parser.add_argument( + "--attribute_names", type=str, default='', help="Names to be used in FileCheck regular expression to represent " + "attributes in the order they are defined. Separate names with commas," + "commas, and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')") + + args = parser.parse_args() + + # Open the given input file. + input_lines = [l.rstrip() for l in args.input] + args.input.close() + + # Generate a note used for the generated check file. + script_name = os.path.basename(__file__) + autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END + + source_segments = None + if args.source: + source_segments = process_source_lines([l.rstrip() for l in open(args.source, "r")], autogenerated_note, args) + + if args.inplace: + assert args.output is None + output = open(args.source, "w") + elif args.output is None: + output = sys.stdout + else: + output = args.output + + output_segments = [[]] + + # Namers + variable_namer = VariableNamer(args.variable_names) + attribute_namer = AttributeNamer(args.attribute_names) + + # Process lines + for input_line in input_lines: + if not input_line: + continue + + # Check if this is an attribute definition and process it + process_attribute_definition(input_line, attribute_namer, output_segments[-1]) + + # Lines with blocks begin with a ^. These lines have a trailing comment + # that needs to be stripped. + lstripped_input_line = input_line.lstrip() + is_block = lstripped_input_line[0] == "^" + if is_block: + input_line = input_line.rsplit("//", 1)[0].rstrip() + + cur_level = variable_namer.num_scopes() + + # If the line starts with a '}', pop the last name scope. + if lstripped_input_line[0] == "}": + variable_namer.pop_name_scope() + cur_level = variable_namer.num_scopes() + + # If the line ends with a '{', push a new name scope. + if input_line[-1] == "{": + variable_namer.push_name_scope() + if cur_level == args.starts_from_scope: + output_segments.append([]) + + # Result SSA values must still be pushed to parent scope + num_ssa_results = get_num_ssa_results(input_line) + variable_namer.generate_in_parent_scope(num_ssa_results) + + # Omit lines at the near top level e.g. "module {". + if cur_level < args.starts_from_scope: + continue + + if len(output_segments[-1]) == 0: + variable_namer.clear_names() + + # Preprocess the input to remove any sequences that may be problematic with + # FileCheck. + input_line = preprocess_line(input_line) + + # Process uses of attributes in this line + input_line = process_attribute_references(input_line, attribute_namer) + + # Split the line at the each SSA value name. + ssa_split = input_line.split("%") + + # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'. + if len(output_segments[-1]) != 0 or not ssa_split[0]: + output_line = "// " + args.check_prefix + ": " + # Pad to align with the 'LABEL' statements. + output_line += " " * len("-LABEL") + + # Output the first line chunk that does not contain an SSA name. + output_line += ssa_split[0] + + # Process the rest of the input line. + output_line += process_line(ssa_split[1:], variable_namer) + + else: + # Output the first line chunk that does not contain an SSA name for the + # label. + output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n" + + # Process the rest of the input line on separate check lines. + output_line += "// " + args.check_prefix + "-SAME: " + output_line += process_line(ssa_split[1:], variable_namer) + + # Append the output line. + output_segments[-1].append(output_line) + + output.write(autogenerated_note + "\n") + + # Write the output. + if source_segments: + assert len(output_segments) == len(source_segments), (len(output_segments), len(source_segments)) + for check_segment, source_segment in zip(output_segments, source_segments): + for line in check_segment: + output.write(line) + for line in source_segment: + output.write(line) + else: + for segment in output_segments: + output.write("\n") + for output_line in segment: + output.write(output_line) + output.write("\n") + output.close() + + +if __name__ == "__main__": + main() diff --git a/third_party/enflame/include/triton/utils/nightly.pypirc b/third_party/enflame/include/triton/utils/nightly.pypirc new file mode 100644 index 000000000..d62ba7b93 --- /dev/null +++ b/third_party/enflame/include/triton/utils/nightly.pypirc @@ -0,0 +1,6 @@ +[distutils] +Index-servers = + Triton-Nightly + +[Triton-Nightly] +Repository = https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/upload/ diff --git a/third_party/enflame/include/version.h.in b/third_party/enflame/include/version.h.in new file mode 100644 index 000000000..035fb3229 --- /dev/null +++ b/third_party/enflame/include/version.h.in @@ -0,0 +1,14 @@ +#define TRITON_GCU_VERSION @CMAKE_PROJECT_VERSION@ +#define TRITON_GCU_VERSION_MAJOR @CMAKE_PROJECT_VERSION_MAJOR@ +#define TRITON_GCU_VERSION_MINOR @CMAKE_PROJECT_VERSION_MINOR@ +#define TRITON_GCU_VERSION_PATCH @CMAKE_PROJECT_VERSION_PATCH@ +#define TRITON_GCU_VERSION_TWEAK @CMAKE_PROJECT_VERSION_TWEAK@ +#define TRITON_GCU_VERSION_VALUE VERSION_VALUE(TRITON_GCU) + +#define TRITON_GCU_VERSION_STR "@CMAKE_PROJECT_VERSION@" +#define TRITON_GCU_VERSION_STR_MAJOR "@CMAKE_PROJECT_VERSION_MAJOR@" +#define TRITON_GCU_VERSION_STR_MINOR "@CMAKE_PROJECT_VERSION_MINOR@" +#define TRITON_GCU_VERSION_STR_PATCH "@CMAKE_PROJECT_VERSION_PATCH@" +#define TRITON_GCU_VERSION_STR_TWEAK "@CMAKE_PROJECT_VERSION_TWEAK@" +#define TRITON_GCU_GIT_VERSION_STR "@TOTT_WC_REVISION@" +#define CLANG_BINARY_SHORT_HASH "@CLANG_BINARY_SHORT_HASH@" diff --git a/third_party/enflame/language/gcu/__init__.py b/third_party/enflame/language/gcu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/language/gcu/libdevice.py b/third_party/enflame/language/gcu/libdevice.py new file mode 100644 index 000000000..b467db849 --- /dev/null +++ b/third_party/enflame/language/gcu/libdevice.py @@ -0,0 +1,1058 @@ +import os +from triton.language import core +import triton.language.extra +from triton.language.extra import libdevice + + +def get_bool_env(env): + s = os.getenv(env, "").lower() + if (s == "1" or s == "true" or s == "on"): + return True + return False + + +# unary op +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__nv_fabsf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cbrt(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cbrtf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cospi(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cospif", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erff", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp10(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp2f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_floorf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j0f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j1f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_lgammaf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log1pf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_logf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rint(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def round(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_roundf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinhf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sqrtf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanhf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_truncf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y0f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y1f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +# unary int float +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_signbitf", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_int", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rn", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_ru", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rz", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +# unary int float + + +# unary half float +@core.extern +def float2half_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2half_rn", core.dtype("uint16")), + }, is_pure=True, _builder=_builder) + + +# unary half float + + +# unary long long float +@core.extern +def ffs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + # gcu don't support the return type int32 + # (core.dtype("int64"), ): ("__nv_ffsll", core.dtype("int32")), + ( + core.dtype("int64"), ): ("__nv_ffsll", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llround(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_llroundf", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +# unary long long float +# unary op + + +# binary op +@core.extern +def max(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int64"), core.dtype("int64")): ("__nv_llmax", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64")): ("__nv_ullmax", core.dtype("uint64")), + (core.dtype("int32"), core.dtype("int32")): ("__nv_max", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umax", core.dtype("uint32")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaxf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def min(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int64"), core.dtype("int64")): ("__nv_llmin", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64")): ("__nv_ullmin", core.dtype("uint64")), + (core.dtype("int32"), core.dtype("int32")): ("__nv_min", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umin", core.dtype("uint32")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fminf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +# binary i32 +@core.extern +def mulhi(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umulhi", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +# binary i32 + + +# binary float +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rd(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rn(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_ru(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rz(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fdim(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rd(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rn(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_ru(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rz(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmodf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rd(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rn(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_ru(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rz(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rd(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rn(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_ru(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rz(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_powf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +# binary float +# binary + + +# ternary +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +# ternary +# other +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_isinff", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_isnanf", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def finitef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_finitef", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +# TODO: the implementation of logger is required to be added +@core.extern +def logger(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_dummy_logger", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def begin_clock(_builder=None): + return core.extern_elementwise("", "", [], { + (): ("__gcu_begin_clock", core.dtype("int64")), + }, is_pure=False, _builder=_builder) + + +@core.extern +def end_clock(_builder=None): + return core.extern_elementwise("", "", [], { + (): ("__gcu_end_clock", core.dtype("int64")), + }, is_pure=False, _builder=_builder) + + +# other + +if (get_bool_env("TRITON_GCU_DEBUG")): + libdevice.begin_clock = begin_clock + libdevice.end_clock = end_clock + +func_mapping = { + # unary + "abs": abs, + "acos": acos, + "acosh": acosh, + "asin": asin, + "asinh": asinh, + "atan": atan, + "atanh": atanh, + "cbrt": cbrt, + "ceil": ceil, + "cos": cos, + "cosh": cosh, + "cospi": cospi, + "erf": erf, + "exp10": exp10, + "exp2": exp2, + "exp": exp, + "expm1": expm1, + "floor": floor, + "rcp_rd": rcp_rd, + "rcp_rn": rcp_rn, + "rcp_ru": rcp_ru, + "rcp_rz": rcp_rz, + "rsqrt_rn": rsqrt_rn, + "sqrt_rd": sqrt_rd, + "sqrt_rn": sqrt_rn, + "sqrt_ru": sqrt_ru, + "sqrt_rz": sqrt_rz, + "j0": j0, + "j1": j1, + "lgamma": lgamma, + "log10": log10, + "log1p": log1p, + "log2": log2, + "log": log, + "rint": rint, + "round": round, + "rsqrt": rsqrt, + "sin": sin, + "sinh": sinh, + "sqrt": sqrt, + "tan": tan, + "tanh": tanh, + "trunc": trunc, + "y0": y0, + "y1": y1, + # unary half float + "float2half_rn": float2half_rn, + # unary int float + "ilogb": ilogb, + "signbit": signbit, + "float_as_int": float_as_int, + "int_as_float": int_as_float, + "float2int_rd": float2int_rd, + "float2int_rn": float2int_rn, + "float2int_ru": float2int_ru, + "float2int_rz": float2int_rz, + "float2uint_rd": float2uint_rd, + "float2uint_rn": float2uint_rn, + "float2uint_ru": float2uint_ru, + "float2uint_rz": float2uint_rz, + "int2float_rd": int2float_rd, + "int2float_rn": int2float_rn, + "int2float_ru": int2float_ru, + "int2float_rz": int2float_rz, + "uint2float_rd": uint2float_rd, + "uint2float_rn": uint2float_rn, + "uint2float_ru": uint2float_ru, + "uint2float_rz": uint2float_rz, + # unary long long float + "ffs": ffs, + "llrint": llrint, + "llround": llround, + "float2ll_rd": float2ll_rd, + "float2ll_rn": float2ll_rn, + "float2ll_ru": float2ll_ru, + "float2ll_rz": float2ll_rz, + "float2ull_rd": float2ull_rd, + "float2ull_rn": float2ull_rn, + "float2ull_ru": float2ull_ru, + "float2ull_rz": float2ull_rz, + "ll2float_rd": ll2float_rd, + "ll2float_rn": ll2float_rn, + "ll2float_ru": ll2float_ru, + "ll2float_rz": ll2float_rz, + "ull2float_rd": ull2float_rd, + "ull2float_rn": ull2float_rn, + "ull2float_ru": ull2float_ru, + "ull2float_rz": ull2float_rz, + # binary + "max": max, + "min": min, + # binary i32 + "mulhi": mulhi, + # binary float + "atan2": atan2, + "copysign": copysign, + "add_rd": add_rd, + "add_rn": add_rn, + "add_ru": add_ru, + "add_rz": add_rz, + "fdim": fdim, + "div_rd": div_rd, + "div_rn": div_rn, + "div_ru": div_ru, + "div_rz": div_rz, + "fmod": fmod, + "mul_rd": mul_rd, + "mul_rn": mul_rn, + "mul_ru": mul_ru, + "mul_rz": mul_rz, + "sub_rd": sub_rd, + "sub_rn": sub_rn, + "sub_ru": sub_ru, + "sub_rz": sub_rz, + "hypot": hypot, + "pow": pow, + # ternary + "fma": fma, + "fma_rd": fma_rd, + "fma_rn": fma_rn, + "fma_ru": fma_ru, + "fma_rz": fma_rz, + # other + "isinf": isinf, + "isnan": isnan, + "finitef": finitef, + "logger": logger, +} + +for name, func in func_mapping.items(): + if hasattr(libdevice, name): + existing_func = getattr(libdevice, name) + if existing_func is not None: + setattr(libdevice, name, func) diff --git a/third_party/enflame/python/test/unit/conftest.py b/third_party/enflame/python/test/unit/conftest.py new file mode 100644 index 000000000..b4b541435 --- /dev/null +++ b/third_party/enflame/python/test/unit/conftest.py @@ -0,0 +1,39 @@ +# content of conftest.py +import os +import pytest +import tempfile + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default='gcu') + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") + + +@pytest.fixture +def fresh_triton_cache(): + with tempfile.TemporaryDirectory() as tmpdir: + try: + os.environ["TRITON_CACHE_DIR"] = tmpdir + yield tmpdir + finally: + os.environ.pop("TRITON_CACHE_DIR", None) + + +def pytest_terminal_summary(terminalreporter, exitstatus, config): + skipped = terminalreporter.stats.get("skipped", []) + if skipped: + terminalreporter.write_sep("=", "detailed skipped tests") + for report in skipped: + node_id = report.nodeid + skip_info = "" + + if isinstance(report.longrepr, tuple): + skip_info = report.longrepr[2] if len(report.longrepr) > 2 else str(report.longrepr) + else: + skip_info = str(report.longrepr).split("\n")[-1] + + terminalreporter.write_line(f"{node_id} - {skip_info}") diff --git a/third_party/enflame/python/test/unit/language/.coveragerc b/third_party/enflame/python/test/unit/language/.coveragerc new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/python/test/unit/language/assert_helper.py b/third_party/enflame/python/test/unit/language/assert_helper.py new file mode 100644 index 000000000..04ae0876b --- /dev/null +++ b/third_party/enflame/python/test/unit/language/assert_helper.py @@ -0,0 +1,136 @@ +import sys + +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton +import torch_gcu + + +@triton.jit +def kernel_device_assert(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_assert(x < 32, "x >= 32") + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # Trivial assert + tl.device_assert(0 == 0, "x != 0") + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit(debug=False) +def kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_assert(x < 32, "x >= 32") + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_assert(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + assert x < 32, "x >= 32" + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_static_assert(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.static_assert(BLOCK == 128, "BLOCK != 128") + tl.store(Y + tl.arange(0, BLOCK), x) + + +def test_assert(func: str): + shape = (128, ) + x = torch.arange(0, shape[0], dtype=torch.int32, device='gcu') + y = torch.zeros(shape, dtype=x.dtype, device="gcu") + if func == "device_assert": + kernel_device_assert[(1, )](x, y, BLOCK=shape[0]) + kernel_device_assert_scalar[(1, )](x, y, BLOCK=shape[0]) + elif func == "no_debug": + # TRITON_DEBUG=1 can override the debug flag + kernel_device_assert_no_debug[(1, )](x, y, BLOCK=shape[0]) + elif func == "assert": + kernel_assert[(1, )](x, y, BLOCK=shape[0]) + elif func == "static_assert": + kernel_static_assert[(1, )](x, y, BLOCK=shape[0]) + assert_close(y, x) + + +@triton.jit +def jit_device_assert_none(x): + tl.device_assert(x == 0, "x != 0") + + +@triton.jit(debug=True) +def jit_device_assert_true(x): + tl.device_assert(x == 0, "x != 0") + + +@triton.jit(debug=False) +def jit_device_assert_false(x): + tl.device_assert(x == 0, "x != 0") + + +@triton.jit +def kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + if jit_debug == "true": + jit_device_assert_true(x) + elif jit_debug == "false": + jit_device_assert_false(x) + else: + jit_device_assert_none(x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit(debug=True) +def kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + if jit_debug == "true": + jit_device_assert_true(x) + elif jit_debug == "false": + jit_device_assert_false(x) + else: + jit_device_assert_none(x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit(debug=False) +def kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + if jit_debug == "true": + jit_device_assert_true(x) + elif jit_debug == "false": + jit_device_assert_false(x) + else: + jit_device_assert_none(x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +def test_assert_nested(caller: str, callee: str): + shape = (128, ) + x = torch.arange(0, shape[0], dtype=torch.int32, device='gcu') + y = torch.zeros(shape, dtype=x.dtype, device="gcu") + if caller == "none": + kernel_device_assert_nested[(1, )](x, y, BLOCK=shape[0], jit_debug=callee) + elif caller == "true": + kernel_device_assert_nested_true[(1, )](x, y, BLOCK=shape[0], jit_debug=callee) + elif caller == "false": + kernel_device_assert_nested_false[(1, )](x, y, BLOCK=shape[0], jit_debug=callee) + assert_close(y, x) + + +if __name__ == "__main__": + if len(sys.argv) == 3: + test_assert_nested(sys.argv[1], sys.argv[2]) + else: + test_assert(sys.argv[1]) diff --git a/third_party/enflame/python/test/unit/language/common_utils.py b/third_party/enflame/python/test/unit/language/common_utils.py new file mode 100644 index 000000000..ea0806078 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/common_utils.py @@ -0,0 +1,74 @@ +import pytest +import csv +import re +import os +import triton +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + + +def check_skip(file_path, test_name, params=None): + """ + Checks if a test case should be skipped based on data from 'skip_tests_write.csv'. + Args: + file_path (str): relative path of csv file + test_name (str): the name of the test function + params (dict, optional): input params to the test function + Usage: + import inspect + + def test_func(): + check_skip("skip.csv", inspect.currentframe().f_code.co_name, locals()) + """ + if params is None: + params = {} + + # Remove 'device' key if present + if (test_name != "test_pointer_arguments" or params["device"] == "cuda"): + params.pop("device", None) + + if "tmp_path" in params: + params.pop("tmp_path", None) + + if "test_gather_warp_shuffle" in test_name: + params.pop("src_layout", None) + params.pop("indices_layout", None) + + # for test_convertmma2mma convert mma_pair object to parseable string + if "mma_pair" in params: + params["mma_pair"] = [ + re.sub(r"versionMajor=\d+,?\s*|versionMinor=\d+,?\s*", "", str(x)) for x in params["mma_pair"] + ] + + target_arch = triton.runtime.driver.active.get_current_target().arch.split("--")[1] + + # Generate a unique case identifier + case_identifier = (test_name + "_" + "".join(f"{key}={value}-" for key, value in params.items())) + + try: + print(f"\nDEBUG(check_skip): Case Identifier: <{case_identifier}>\n") + if (target_arch == 'gcu400' or target_arch == 'gcu410'): + if ("num_warps" in params and params["num_warps"] == 8): + pytest.skip("num_warps 8 is not supported in gcu4xx") + # Read the CSV file and check if the test case should be skipped + script_dir = os.path.dirname(os.path.abspath(__file__)) + file_path = os.path.join(script_dir, file_path) + + with open(file_path, mode="r", newline="", encoding="utf-8") as file: + reader = csv.DictReader(file) + for row in reader: + if not (target_arch in row["arch"].split("/")): + continue + if row["case_identifier"] == case_identifier: + # Construct skip message + reason = row["resb"] if row["resb"] else "Failed on gcu" + skip_message = f"{reason}: {row['resa']}" + print("Skipping with message: ", skip_message) + + pytest.skip(skip_message) + + except FileNotFoundError: + print("Error: 'test_core_skip.csv' not found.") + except KeyError as e: + print(f"Missing expected column in CSV: {e}") diff --git a/third_party/enflame/python/test/unit/language/conftest.py b/third_party/enflame/python/test/unit/language/conftest.py new file mode 100644 index 000000000..091f9ea41 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/conftest.py @@ -0,0 +1,5 @@ +# content of conftest.py + + +def pytest_configure(config): + config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") diff --git a/third_party/enflame/python/test/unit/language/print_helper.py b/third_party/enflame/python/test/unit/language/print_helper.py new file mode 100644 index 000000000..6246d0e1a --- /dev/null +++ b/third_party/enflame/python/test/unit/language/print_helper.py @@ -0,0 +1,153 @@ +import sys +import uuid + +import torch +import torch_gcu +from torch.testing import assert_close + +import triton +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton +import triton.language as tl + + +def get_current_target_warp_size(): + return triton.runtime.driver.active.get_current_target().warp_size + + +@triton.jit +def kernel_device_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_hex(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x, hex=True) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # Triton should add a space after this prefix. + print("x:", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_scalar(SCALAR): + x = tl.load(SCALAR) + # Triton should add a space after this prefix. + print("x:", x) + + +@triton.jit +def kernel_device_print_large( + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32) + # Triton should change this prefix to "x: ". + tl.device_print("x ", x) + + +@triton.jit +def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + print("", x, y) + + +@triton.jit +def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + tl.device_print("", x, y) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr): + # This function takes an extra value as a tl.constexpr so this kernel is not + # cached. This way the static print is run every time. + x = tl.load(X + tl.arange(0, BLOCK)) + tl.static_print("", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_no_arg_print(): + print("", tl.program_id(0)) + + +@triton.jit +def kernel_print_no_arg(): + print("no arg") + + +@triton.jit +def kernel_print_pointer(X, Y, BLOCK: tl.constexpr): + tl.device_print("ptr ", X + tl.arange(0, BLOCK)) + + +def test_print(func: str, data_type: str, device: str): + N = 128 # This value should match with test_print in test_subprocess.py. + # TODO(antiagainst): Currently the warp count is chosen to make sure wedon't have multiple + # threads printing duplicated messages due to broadcasting. Improve print op lowering logic + # to filter out duplicated data range. + + # For triton_gcu backend, num_warps must be a power of 2. If we give a value of 128 for N, and get_current_target_warp_size() returns 12 for gcu. The num_warps will be 10 which is not a power of 2. + # num_warps = N // get_current_target_warp_size() + num_warps = 4 + + x = torch.arange(0, N, dtype=torch.int32, device=device).to(getattr(torch, data_type)) + y = torch.zeros((N, ), dtype=x.dtype, device=device) + if func == "device_print": + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_scalar": + scalar = torch.tensor(42, dtype=x.dtype, device=device) + kernel_device_print_scalar[(1, )](scalar, num_warps=num_warps) + elif func == "device_print_negative": + x = -x + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_uint": + x = torch.arange((1 << 31), (1 << 31) + N, device=device).to(getattr(torch, data_type)) + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "print": + kernel_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_large": + kernel_device_print_large[(1, 2)](BLOCK_M=64, num_warps=num_warps, BLOCK_N=N) + elif func == "print_multiple_args": + kernel_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_multiple_args": + kernel_device_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "static_print": + kernel_static_print[(1, )](x, y, num_warps=num_warps, BLOCK=N, PLACEHOLDER=uuid.uuid4()) + elif func == "no_arg_print": + kernel_no_arg_print[(1, )](num_warps=num_warps) + elif func == "print_no_arg": + kernel_print_no_arg[(1, )](num_warps=num_warps) + elif func == "device_print_hex": + kernel_device_print_hex[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_pointer": + kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N) + else: + assert f"Unknown kernel: {func}" + + if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \ + func != "print_multiple_args" and func != "device_print_multiple_args" and \ + func != "device_print_pointer" and func != "device_print_scalar": + assert_close(y, x) + + # Wait until driver complete all the jobs for the device_print, especially test_subprocess + # require this which captures stdout when child exits. + getattr(torch, device).synchronize() + + +if __name__ == "__main__": + fn = globals()[sys.argv[1]] + fn(*sys.argv[2:]) diff --git a/third_party/enflame/python/test/unit/language/test_annotations.py b/third_party/enflame/python/test/unit/language/test_annotations.py new file mode 100644 index 000000000..391d1a981 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_annotations.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import torch +import triton +import triton.language as tl +import pytest + +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton +import torch_gcu + + +def test_annotations(device): + + @triton.jit + def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr): + pass + + x = torch.empty(1, device=device) + _kernel[(1, )](x, x.shape[0], 32) + try: + _kernel[(1, )](x.shape[0], x.shape[0], 32) + except AttributeError: + pass + + +def annotated_function(return_type=None, **arg_types): + """A decorator to add annotations to a function.""" + + def decorator(func): + func.__annotations__ = {**arg_types, 'return': return_type} + return func + + return decorator + + +# Test integer annotations +@pytest.mark.parametrize(("signed", "width"), [ + (signed, width) for signed in [False, True]\ + for width in [8, 16, 32]#, 64] +] + [(False, 1)] + ) +def test_int_annotation(signed, width, device): + + @triton.jit + @annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}") + def _kernel(X, v): + tl.store(X + v, v) + + h = _kernel[(1, )](torch.empty(1, device=device), 3) + pfx = 'si' if signed else 'ui' + if not signed and width < 64: + assert "arith.extui %arg1" in h.asm["ttir"] + assert f'%arg1: i{width}' in h.asm["ttir"] + assert f'arith.{pfx}tofp' in h.asm["ttir"] + + +# Test that unknown annotations do not emit an error +def test_unknown_annotation(device): + + @triton.jit + def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr): + pass + + x = torch.empty(1, device=device) + _kernel[(1, )](x, x.shape[0], 32) + try: + _kernel[(1, )](x.shape[0], x.shape[0], 32) + except AttributeError: + pass diff --git a/third_party/enflame/python/test/unit/language/test_block_pointer.py b/third_party/enflame/python/test/unit/language/test_block_pointer.py new file mode 100644 index 000000000..cad8cc797 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_block_pointer.py @@ -0,0 +1,95 @@ +import pytest +import torch + +import triton +import triton.language as tl + +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton +import torch_gcu + + +@triton.jit +def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr): + pid = tl.program_id(0) + # We only copy half of the data to see if the padding works + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option) + tl.store(b_block_ptr, a, boundary_check=(0, )) + + +@pytest.mark.parametrize("dtype_str, n, padding_option", [(dtype_str, n, padding) + for dtype_str in ("bool", "int16", "float16") + for n in (64, 128, 256, 512, 1024) + for padding in ("zero", "nan")]) +def test_block_copy(dtype_str, n, padding_option, device): + capability = torch.gcu.get_device_capability() + if capability[0] >= 9: + pytest.skip("Hopper support is working in progress") + + dtype = getattr(torch, dtype_str) + if dtype_str in ("bool", "int16"): + if padding_option == "nan": + pytest.skip("Padding with NaN is not supported for integer types") + a = torch.randint(0, 2, (n, ), device=device, dtype=dtype) + else: + a = torch.randn((n, ), device=device, dtype=dtype) + b = torch.zeros((n, ), device=device, dtype=dtype) + + grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) + block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option) + + assert torch.all(a[0:n // 2] == b[0:n // 2]) + if padding_option == "zero": + assert torch.all(b[n // 2:n] == 0) + else: + assert torch.all(torch.isnan(b[n // 2:n])) + + +@triton.jit +def matmul_no_scf_with_advance_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, + stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(1, 0)) + # Below two lines are just for testing negative offsets for the `advance` API, which could be removed + a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K)) + a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K)) + a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option="zero") + b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option="zero") + + c = tl.dot(a, b) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, c) + + +@pytest.mark.parametrize("shape, num_warps", [(shape, num_warps) for shape in [ + [64, 64, 16], + [64, 64, 32], + [64, 64, 64], +] for num_warps in [4, 8]]) +def test_block_ptr_matmul_no_scf(shape, num_warps, device): + capability = torch.gcu.get_device_capability() + if capability[0] >= 9: + pytest.skip("Hopper support is working in progress") + + m, n, k = shape + a = torch.randn((m, k), device=device, dtype=torch.float16) + b = torch.randn((k, n), device=device, dtype=torch.float16) + c = torch.empty((m, n), device=device, dtype=torch.float32) + + grid = lambda META: (1, ) + matmul_no_scf_with_advance_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=m, N=n, K=k, stride_am=a.stride(0), + stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=m, BLOCK_N=n, + BLOCK_K=k, num_warps=num_warps) + golden = torch.matmul(a, b) + torch.testing.assert_allclose(c, golden) diff --git a/third_party/enflame/python/test/unit/language/test_compile_errors.py b/third_party/enflame/python/test/unit/language/test_compile_errors.py new file mode 100644 index 000000000..8845d5fc7 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_compile_errors.py @@ -0,0 +1,460 @@ +import contextlib +import pytest +import os + +import torch +import triton +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton +import torch_gcu +import triton.language as tl + +from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure +import traceback +from triton._internal_testing import is_cuda, is_hip, is_hip_mi300, is_hip_mi350 + + +def format_exception(type, value, tb): + list_msg = traceback.format_exception(type, value, tb, chain=False) + return "\n".join(list_msg) + + +def test_err_undefined_variable(): + + @triton.jit + def kernel(): + a += 1 # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "is not defined" in err_msg, "error should mention the undefined variable" + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_operator(): + + @triton.jit + def kernel(): + 0 + "a" # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in err_msg, "error should point to the 0" + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_static_assert(): + + @triton.jit + def kernel(): + tl.static_assert(isinstance(0, tl.tensor)) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + assert isinstance(e.value, CompileTimeAssertionFailure) + assert e.value.__cause__ is None + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + print(err_msg) + assert "at 2:4:" in err_msg, "error should point to the static_assert call" + assert "" not in err_msg + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_unary_op(): + # Currently Triton can't evaluate `not` of a tuple at compile time. That's + # ok, but the error message needs to point to the correct spot. + @triton.jit + def kernel(): + not (0, 0) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + assert e.value.__cause__ is None + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in err_msg, "error should point to the `not`" + assert "" not in err_msg + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_op(): + + @triton.jit + def kernel(): + 1.0 << 1 # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + err_msg = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in err_msg, "error should point to the 1.0" + assert "" not in err_msg + assert "code_generator.py" not in err_msg + except AssertionError as assertion_err: + raise assertion_err from e.value + + +# This has to be defined as a top-level function; jit'ed functions can't call +# nested functions. +@triton.jit +def nested_call(): + xyz # noqa + + +def test_err_in_nested_call(): + + @triton.jit + def kernel(): + # this is a comment to push nested_call() onto the next line + nested_call() + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + inner_exc = e.value.__cause__ + inner = format_exception(inner_exc.__class__, inner_exc, inner_exc.__traceback__) + assert "at 2:4:" in inner, "error should point to xyz" + assert "" not in inner + assert "code_generator.py" not in inner + + outer = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 3:4" in outer, "error should point to the nested_call" + assert "" not in outer + assert "code_generator.py" not in outer + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_builtin(): + + # The root error here comes from core.py. Make sure the stacktrace reflects + # this. + @triton.jit + def kernel(): + tl.expand_dims(None, -1) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + try: + inner_exc = e.value.__cause__ + inner = format_exception(inner_exc.__class__, inner_exc, inner_exc.__traceback__) + assert f"{os.sep}core.py" in inner, "error should point inside core.py" + assert "code_generator.py" not in inner + + outer = format_exception(e.type, value=e.value, tb=e.tb) + assert "at 2:4:" in outer, "error should point to expand_dims call" + assert "" not in outer + assert "code_generator.py" not in outer + except AssertionError as assertion_err: + raise assertion_err from e.value + + +@triton.jit +def two_returns(): + return tl.arange(0, 4) + return tl.arange(0, 8) + + +def test_two_returns_no_err(): + # This program is valid; `a` has shape (10,). + @triton.jit + def kernel(): + a = two_returns() + a + tl.arange(0, 4) # only works if we took the first return + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +def test_not_const_annotate_no_err(): + + @triton.jit + def kernel(N: int = 1): + pass + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) + + +@triton.jit +def returns_branched_on_constexpr(N: tl.constexpr): + if N == 0: + return tl.arange(0, 4) + # Ideally this would work even without the `else`, but we're not that smart + # yet. + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_constexpr(): + + @triton.jit + def kernel1(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 4) + + triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0})) + + @triton.jit + def kernel2(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 8) + + triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1})) + + +@triton.jit +def returns_branched_on_non_constexpr(N: int): + if N == 0: + return tl.arange(0, 4) + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_non_constexpr(): + + @triton.jit + def kernel(N: int): + returns_branched_on_non_constexpr(N) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the function call" + assert "at 5:8:" in str(e.value.__cause__), "error should point to the second `return`" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_power_of_two_shapes(): + + @triton.jit + def kernel(): + tl.arange(2, 7) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert str(e.value.__cause__) == "arange's range must be a power of 2" + + +def test_power_of_two_shapes_2(): + + @triton.jit + def kernel(): + tl.full((33, ), 0, dtype=tl.int64) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert str(e.value.__cause__) == "Shape element 0 must be a power of 2" + + +def test_captured_var_access(): + + CAPTURED = 42 + + @triton.jit + def kernel(): + a = CAPTURED # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert "CAPTURED is not defined" in str(e.value) + + +GLOBAL = 42 + + +def test_global_var_access(): + + @triton.jit + def kernel(): + a = GLOBAL # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert "global variable" in str(e.value) + + +CONSTEXPR_ANNOTATED_GLOBAL: tl.constexpr = 42 + + +def test_constexpr_annotated_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_ANNOTATED_GLOBAL # noqa + + # No error. + try: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + assert False, "Using a constexpr annotated global variable should not be allowed" + except CompilationError as e: + assert "Cannot access global variable" in str(e) + + +CONSTEXPR_GLOBAL = tl.constexpr(42) + + +def test_constexpr_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +TYPE_ALIAS = tl.pointer_type(tl.int32) + + +def test_global_type_alias_access(): + + @triton.jit + def kernel(): + a = TYPE_ALIAS # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +def test_global_access_in_fn_default_arg(): + + @triton.jit + def kernel(a=GLOBAL): + pass + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={})) + + +def test_defaults_assign_no_err(): + + @triton.jit + def kernel(a=1, B: tl.constexpr = ""): + pass + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""})) + + +def test_where_warning(fresh_triton_cache): + + @triton.jit + def kernel(): + a = tl.full((64, ), 0, tl.uint32) + b = tl.full((64, ), 1, tl.float32) + c = tl.full((64, ), 2, tl.float32) + tl.where(a, b, c) + + with pytest.warns(UserWarning): + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + + +@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]) +def test_fp8_support(fresh_triton_cache, dtype): + warning_dtypes = [] + supported_dtypes = [] + if is_cuda(): + cc = torch.cuda.get_device_capability(0) + supported_dtypes.append(tl.float8e4b15) + if cc >= (9, 0): + warning_dtypes.append(tl.float8e4b15) + if cc >= (8, 9): + supported_dtypes.append(tl.float8e4nv) + elif is_hip(): + if is_hip_mi300(): + supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16] + if is_hip_mi350(): + supported_dtypes += [tl.float8e4nv] + + @triton.jit + def dtype_kernel(dtype: tl.constexpr): + _ = tl.full((256, ), 0.0, dtype) + + if dtype in warning_dtypes: + ctx = pytest.warns(UserWarning, match=r"fp8e4b15 is deprecated in this architecture") + elif dtype in supported_dtypes: + ctx = contextlib.nullcontext() + else: + ctx = pytest.raises(CompilationError, match="") + + with ctx as e: + triton.compile( + triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + + if dtype not in supported_dtypes: + try: + assert ("not supported in this architecture" in str(e.value.__cause__)) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +@pytest.mark.parametrize("dtype", [tl.float8e5, tl.int8, tl.float16]) +def test_min_dot_size(dtype): + error_msg = "Input shapes should have " + if is_cuda(): + if dtype.primitive_bitwidth == 8: + error_msg += "M >= 16, N >= 16 and K >= 32" + else: + error_msg = "M >= 16, N >= 16 and K >= 16" + elif is_hip(): + # hip supports arbitrary sizes + error_msg = None + else: + pytest.skip("Test only supported on CUDA and HIP") + + @triton.jit + def dot_kernel(dtype: tl.constexpr): + SIZE: tl.constexpr = 8 + a = tl.full((SIZE, SIZE), 0.0, dtype) + b = tl.full((SIZE, SIZE), 0.0, dtype) + tl.dot(a, b) + + if error_msg is None: + triton.compile( + triton.compiler.ASTSource(fn=dot_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + else: + with pytest.raises(CompilationError) as e: + triton.compile( + triton.compiler.ASTSource(fn=dot_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + try: + assert (error_msg in str(e.value.__cause__)) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +# def test_max_num_imprecise_acc_limit(): + +# @triton.jit +# def dot_kernel(): +# SIZE: tl.constexpr = 64 +# a = tl.full((SIZE, SIZE), 0.0, tl.float8e5) +# b = tl.full((SIZE, SIZE), 0.0, tl.float8e5) +# tl.dot(a, b, max_num_imprecise_acc=128) + +# with pytest.raises(CompilationError) as e: +# triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={})) +# try: +# assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)") +# except AssertionError as assertion_err: +# raise assertion_err from e.value diff --git a/third_party/enflame/python/test/unit/language/test_conversions.py b/third_party/enflame/python/test/unit/language/test_conversions.py new file mode 100644 index 000000000..14af6cdf2 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_conversions.py @@ -0,0 +1,359 @@ +# fmt: off + + +import os +import numpy as np +import torch +import torch_gcu +from torch_gcu import transfer_to_gcu +import pytest +import triton +import triton.language as tl +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + +def is_cuda(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda" + +def is_hip(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip" + +def is_on_mi300(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942') + +def matching_int(dtype): + if dtype.primitive_bitwidth == 8: + return torch.int8 + elif dtype.primitive_bitwidth == 16: + return torch.int16 + elif dtype.primitive_bitwidth == 32: + return torch.int32 + elif dtype.primitive_bitwidth == 64: + return torch.int64 + else: + raise ValueError('unsupported number of bits') + +@triton.jit +def type_convert_triton(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding) + tl.store(dst + idxs, y) + + +def launch_type_convert_triton(src, src_dtype, dst_dtype, device, rounding=None, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + type_convert_triton[(src.shape[0] // BLOCK_SIZE,)](triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE) + return dst + + +@triton.jit +def exhaustive_populate(dst, offset, BLOCK_SIZE : tl.constexpr, force_odd : tl.constexpr, output_bits : tl.constexpr, max_repr : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + vals = (idxs + offset).to(tl.uint32) + + # pseudorandom permutation: + multiplier = vals << 1 + multiplier += 3511 + vals *= multiplier + + if force_odd: + vals *= 2 + vals += 1 + + if (output_bits == 8): + vals &= 0xff + avals = vals & 0x7f + elif (output_bits == 16): + vals &= 0xffff + avals = vals & 0x7fff + elif (output_bits == 32): + avals = vals & 0x7fffffff + + vals = tl.where(avals <= max_repr, vals, 0) + + if (output_bits == 8): + vals = vals.to(tl.uint8) + elif (output_bits == 16): + vals = vals.to(tl.uint16) + + vals = vals.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, vals) + + +def launch_exhaustive_populate(dst_dtype, offset, numel, force_odd, output_bits, max_repr, device, BLOCK_SIZE=4096): + + assert(numel % BLOCK_SIZE == 0) + dst = torch.empty((numel,), dtype=matching_int(dst_dtype), device=device) + exhaustive_populate[(numel // BLOCK_SIZE,)](triton.reinterpret(dst, dst_dtype), offset, BLOCK_SIZE, force_odd, output_bits, max_repr) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. We don't need to have that + # as input to the conversion kernels. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(x.dtype == tl.float32, "input must be float32") + numbits_dst : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_dst == 8) or (numbits_dst == 16), "numbits_dst must be 8 or 16") + + x = x.to(tl.uint32, bitcast=True) + + mantissa = (x & 0x7fffff) + exponent = ((x >> 23) & 0xff).to(tl.int32) + mantissa = tl.where(exponent == 0, mantissa, mantissa + 0x800000).to(tl.int32) + exponent = tl.where(exponent == 0, exponent, exponent - 1) + + sign = (x >> 31) + + exponent = exponent + exponent_bias - 127 + adjustment : tl.constexpr = 0.5 ** (23 - mantissa_bits) + mantissa = mantissa.to(tl.float32) * adjustment + + # make exponent nonnegative: + mantissa = tl.where(exponent > -16, mantissa, 0.0) # destination has fewer than 16 mantissa bits, so safe + exponent = tl.where(exponent > -16, exponent, 0) + mantissa = tl.where(exponent > -8, mantissa, mantissa * 0.00390625) + exponent = tl.where(exponent > -8, exponent, exponent + 8) + mantissa = tl.where(exponent > -4, mantissa, mantissa * 0.0625) + exponent = tl.where(exponent > -4, exponent, exponent + 4) + mantissa = tl.where(exponent > -2, mantissa, mantissa * 0.25) + exponent = tl.where(exponent > -2, exponent, exponent + 2) + mantissa = tl.where(exponent > -1, mantissa, mantissa * 0.5) + exponent = tl.where(exponent > -1, exponent, exponent + 1) + + if rounding == 'rtne': + # Bring the value to the range [2 ** 23, 2 ** 24] + # where the representable floats map exactly to integers. + # Addition has RTNE semantics. + mantissa += 0x800000 + # Bring the value back to the original range. + mantissa -= 0x800000 + mantissa = mantissa.to(tl.int32) + elif rounding == 'rtz': + mantissa = mantissa.to(tl.int32) + else: + raise ValueError('unrecognized rounding mode') + + # Reassemble output floating-point representation: + exponent = exponent.to(tl.uint32) + y = (sign << (exponent_bits + mantissa_bits)) + (exponent << mantissa_bits) + mantissa + if numbits_dst == 8: + y = y.to(tl.uint8) + elif numbits_dst == 16: + y = y.to(tl.uint16) + return y + + +@triton.jit +def downcast_emulated(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(src.dtype.element_ty == tl.float32, "src dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = arbitrary_fp32_downcast(x, rounding, exponent_bits, mantissa_bits, exponent_bias) + y = y.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, y) + + +def launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + downcast_emulated[(src.shape[0] // BLOCK_SIZE,)]( + triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. downcast_emulated kernel will + # convert -0. in higher precision to 0x80 and thus need to fix the result to 0. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def upcast_emulated(src, dst, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + exponent_compensator : tl.constexpr = 2.0 ** (127 - exponent_bias) + + numbits_src : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_src == 8) or (numbits_src == 16), "numbits_src must be 8 or 16") + tl.static_assert(dst.dtype.element_ty == tl.float32, "dst dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + + if numbits_src == 8: + x = x.to(tl.uint8, bitcast=True) + elif numbits_src == 16: + x = x.to(tl.uint16, bitcast=True) + + x = x.to(tl.uint32) + + mantissa_mask : tl.constexpr = (1 << mantissa_bits) - 1 + exponent_mask : tl.constexpr = (1 << exponent_bits) - 1 + + mantissa = x & mantissa_mask + exponent = (x >> mantissa_bits) & exponent_mask + sign = (x >> (numbits_src - 1)) + + y = (sign << 31) | (exponent << 23) | (mantissa << (23 - mantissa_bits)) + y = y.to(tl.float32, bitcast=True) + y = y * exponent_compensator + + tl.store(dst + idxs, y) + + +def launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=torch.int32, device=device) + upcast_emulated[(src.shape[0] // BLOCK_SIZE,)](src, triton.reinterpret(dst, tl.float32), BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + return dst + + +def downcast_test(src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, max_repr, offset, device): + + src = launch_exhaustive_populate(src_dtype, offset << 24, 2**24, False, src_dtype.primitive_bitwidth, max_repr, device) + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device, rounding=rounding) + src = launch_type_convert_triton(src, src_dtype, tl.float32, device=device) + + dst2 = launch_downcast_emulated(src, tl.float32, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device=device) + + dst = launch_upcast_emulated(dst, exponent_bits, mantissa_bits, exponent_bias, device=device) + dst2 = launch_upcast_emulated(dst2, exponent_bits, mantissa_bits, exponent_bias, device=device) + + if not (torch.equal(dst, dst2)): + print('Error!!!') + + dst = dst.cpu().detach().numpy() + dst2 = dst2.cpu().detach().numpy() + src = src.cpu().detach().numpy() + + print(src[dst != dst2][0]) + print(dst[dst != dst2][0]) + print(dst2[dst != dst2][0]) + print(hex(src.view(np.uint32)[dst != dst2][0])) + print(hex(dst.view(np.uint32)[dst != dst2][0])) + print(hex(dst2.view(np.uint32)[dst != dst2][0])) + print('') + raise ValueError('%d elements mismatch' % (dst != dst2).sum()) + + +def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bias, max_repr, device): + + numbits_src = exponent_bits + mantissa_bits + 1 + + src = launch_exhaustive_populate(src_dtype, 0, 65536, False, numbits_src, max_repr, device=device) + + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device) + dst = launch_type_convert_triton(dst, dst_dtype, tl.float32, device=device) + + dst2 = launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device=device) + + assert(torch.equal(dst, dst2)) + + +@pytest.mark.parametrize("src_dtype, dst_dtype", [ + ('float16', 'float32'), + ('bfloat16', 'float32'), + + # ('float8e5', 'float16'), + # ('float8e5', 'bfloat16'), + # ('float8e5', 'float32'), + + # ('float8e4b15', 'float16'), + # # ('float8e4b15', 'bfloat16'), # Unsupported conversion from f8E4M3B11FNUZ to bf16 + # ('float8e4b15', 'float32'), + + # ('float8e4nv', 'float16'), + # ('float8e4nv', 'bfloat16'), + # ('float8e4nv', 'float32'), + + ('float8e4b8', 'float32'), + ('float8e4b8', 'float16'), + + ('float8e5b16', 'float32'), + ('float8e5b16', 'float16'), +]) +def test_typeconvert_upcast(src_dtype, dst_dtype, device): + + if src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("float8e4nv upcast tests only supported on NVGPU with compute capability 9.0+") + + if src_dtype in ('float8e4nv', 'float8e4b15') and is_hip(): + pytest.skip(f"{src_dtype} upcast tests not supported on ROCm") + + if src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()): + pytest.skip("{src_dtype} upcast tests only supported on AMDGPU MI300") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr) + stuff = { + 'float8e4b15': (4, 3, 15, 0x7e), + 'float8e4nv': (4, 3, 7, 0x7e), + 'float8e5': (5, 2, 15, 0x7b), + 'float8e4b8': (4, 3, 8, 0x7f), + 'float8e5b16': (5, 2, 16, 0x7f), + 'float16': (5, 10, 15, 0x7bff), + 'bfloat16': (8, 7, 127, 0x7f7f), + }[src_dtype] + + upcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), *stuff, device=device) + +# @pytest.mark.parametrize("src_dtype, dst_dtype, rounding, max_repr", [ +# ('float32', 'float16', 'rtne', 0x477fe000), +# ('float32', 'float16', 'rtz', 0x477fe000), +# ('float32', 'bfloat16', 'rtne', 0x7f7f0000), +# ('float32', 'bfloat16', 'rtz', 0x7f7f0000), +# ('float32', 'float8e5', 'rtne', 0x47600000), +# ('float32', 'float8e5', 'rtz', 0x47600000), +# ('float32', 'float8e4nv', 'rtne', 0x43e00000), +# ('float32', 'float8e4b8', 'rtne', 0x43700000), +# ('float32', 'float8e5b16', 'rtne', 0x47600000), +# # ('float32', 'float8e4b15', 'rtne', 0x3fe00000), # Skip, no HW rtne conversion from f32 to f8e4b15 + +# ('bfloat16', 'float8e5', 'rtne', 0x4760), +# ('bfloat16', 'float8e4nv', 'rtne', 0x43e0), + +# ('float16', 'float8e5', 'rtne', 0x7b00), +# ('float16', 'float8e4nv', 'rtne', 0x5f00), + +# ('bfloat16', 'float8e5b16', 'rtne', 0x4760), +# ('bfloat16', 'float8e4b8', 'rtne', 0x4370), + +# ('float16', 'float8e5b16', 'rtne', 0x7b00), +# ('float16', 'float8e4b8', 'rtne', 0x5b80), +# ]) +# def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device): + +# if src_dtype != 'float32' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): +# pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+") + +# if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)): +# pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") + +# if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_on_mi300()): +# pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300") + +# # dtype : (exponent_bits, mantissa_bits, exponent_bias) +# stuff = { +# 'float16': (5, 10, 15), +# 'bfloat16': (8, 7, 127), +# 'float8e5': (5, 2, 15), +# 'float8e4b15': (4, 3, 15), +# 'float8e4nv': (4, 3, 7), +# 'float8e4b8': (4, 3, 8), +# 'float8e5b16': (5, 2, 16), +# }[dst_dtype] + +# for i in range(256): +# downcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), rounding, *stuff, max_repr, i, device=device) diff --git a/third_party/enflame/python/test/unit/language/test_core.py b/third_party/enflame/python/test/unit/language/test_core.py new file mode 100644 index 000000000..6d599706d --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_core.py @@ -0,0 +1,7488 @@ +# ruff: noqa: F821,F841 +import contextlib +import itertools +import re +from typing import Optional +import math +import textwrap +import pathlib + +import numpy as np +import pytest +import torch +import os +import inspect +from numpy.random import RandomState + +import triton +import triton.language as tl +from triton.language.extra import libdevice + +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton +import sys +import torch_gcu + +from triton._internal_testing import ( + integral_dtypes, + int_dtypes, + str_to_triton_dtype, + uint_dtypes, + float_dtypes, + float_dtypes_with_bfloat16, + dtypes, + dtypes_with_bfloat16, + is_cuda, + is_interpreter, + is_hopper, + is_hip, + is_hip_cdna, + is_hip_mi200, + is_hip_mi300, + is_hip_mi350, + is_xpu, + get_arch, + torch_float8_dtypes, + torch_dtypes, + numpy_random, + to_triton, + torch_dtype_name, + to_numpy, +) +from triton.runtime.errors import InterpreterError + +from common_utils import check_skip + + +def is_gcu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "gcu" + + +@contextlib.contextmanager +def promotion_numpy_2_0(): + state = np._get_promotion_state() + np._set_promotion_state("weak") + try: + yield + finally: + np._set_promotion_state(state) + + +# No need to emulate NumPy 2.0 if the user has NumPy 2.0 +if np.__version__[0] != "1": + promotion_numpy_2_0 = contextlib.nullcontext + +# TODO: enable multiple cta cluster testing. +# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] +num_ctas_list = [1] + +mma_nonk_sizes = [] + +GPU_DIALECT = "ttg" +if is_interpreter(): + THREADS_PER_WARP = 1 +elif is_hip(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size + # for CDNA multiple variants of mma instructions are supported: + # mfma 16x16/mfma 32x32 + # 0 is a special value for automatic heuristic + if is_hip_cdna(): + mma_nonk_sizes = [0, 16, 32] +elif is_gcu(): + THREADS_PER_WARP = 1 +else: + THREADS_PER_WARP = 32 + + +def _bitwidth(dtype: str) -> int: + # ex.: "int64" -> 64 + return int(re.search(r'(\d+)$', dtype).group(1)) + + +def _dtype(dtype: str) -> str: + # ex.: "int64" -> "int" + return re.match(r'([a-zA-Z]+)', dtype).group(0) + + +def patch_kernel(template, to_replace): + if is_interpreter(): + local_namespace = {} + src = textwrap.dedent(inspect.getsource(template.fn)) + for k, v in to_replace.items(): + src = src.replace(k, v) + exec(src, globals(), local_namespace) + return local_namespace[template.fn.__name__] + else: + kernel = triton.JITFunction(template.fn) + for key, value in to_replace.items(): + kernel._unsafe_update_src(kernel.src.replace(key, value)) + return kernel + + +def check_cuda_or_hip(device): + # CUDA and HIP both use pytorch device 'cuda'. Other backends like Intel + # GPU do not. + if device not in ['cuda']: + pytest.skip("Not supported on gcu: CUDA is not supported on gcu") + + +def check_type_supported(dtype, device): + ''' + skip test if dtype is not supported on the current device + ''' + target_arch = triton.runtime.driver.active.get_current_target().arch.split("--")[1] + if device in ['cuda']: + cc = torch.cuda.get_device_capability() + if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): + pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") + if cc[0] < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}: + pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90") + if is_interpreter(): + if dtype in [tl.bfloat16, "bfloat16", torch.bfloat16]: + pytest.skip("bfloat16 is not supported in the interpreter") + if dtype in ["float64"]: + pytest.skip("Not supported on gcu: dtype fp64 is not supported on gcu") + elif dtype in ["tf32"]: + pytest.skip("Not supported on gcu: dtype tf32 is not supported on gcu") + elif dtype in [tl.float8e4nv, "float8e4nv"] and target_arch == 'gcu300': + pytest.skip("Not supported on gcu300: dtype fp8 is not supported on gcu300") + elif dtype in ["float8_e4m3fn"]: + pytest.skip("Not supported on gcu300/gcu400: dtype float8_e4m3fn is not supported on gcu300/gcu400") + elif dtype in ["int64"]: + pytest.skip("Not supported on gcu300: dtype i64 is not supported on gcu300") + elif dtype in ["uint64"]: + pytest.skip("Not supported on gcu300: dtype ui64 is not supported on gcu300") + + +class MfmaLayout: + + def __init__(self, version, warps_per_cta, instr_shape, is_transposed): + self.version = version + self.warps_per_cta = warps_per_cta + self.instr_shape = instr_shape + self.is_transposed = is_transposed + + def __str__(self): + return f"#{GPU_DIALECT}.amd_mfma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA = {self.warps_per_cta}, instrShape={self.instr_shape}, isTransposed = {str(self.is_transposed).lower()}}}>" + + +class WmmaLayout: + + def __init__(self, version, warps_per_cta): + self.version = version + self.warps_per_cta = warps_per_cta + + def __str__(self): + return f"#{GPU_DIALECT}.amd_wmma<{{version = {self.version}, warpsPerCTA = {self.warps_per_cta}}}>" + + +class MmaLayout: + + def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape): + self.version = version + self.warps_per_cta = warps_per_cta + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + self.instr_shape = instr_shape + + def __str__(self): + return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" + + +class DotOperandLayout: + + def __init__(self, parent, op_idx, k_width): + self.parent = parent + self.op_idx = op_idx + self.k_width = k_width + + def __str__(self): + return f"#{GPU_DIALECT}.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>" + + +class SliceLayout: + + def __init__(self, dim, parent): + self.dim = dim + self.parent = parent + + def __str__(self): + return f"#{GPU_DIALECT}.slice<{{dim = {self.dim}, parent = {self.parent}}}>" + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1], + cta_split_num=[1, 1], cta_order=[0, 1]): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +class SharedLayout: + + def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): + self.vec = vec + self.per_phase = per_phase + self.max_phase = max_phase + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +class NVMMASharedLayout: + + def __init__(self, swizzle, transpose, element_bit_width, ctas_per_cga, cta_split_num, cta_order): + self.swizzle = swizzle + self.transpose = transpose + self.element_bit_width = element_bit_width + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + transpose_str = "true" if self.transpose else "false" + return f"#{GPU_DIALECT}.nvmma_shared<{{swizzlingByteWidth={self.swizzle}, transposed={transpose_str}, elementBitWidth={self.element_bit_width}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +class LinearLayout: + + def __init__(self, register, lane, warp, block): + self.register = register + self.lane = lane + self.warp = warp + self.block = block + + def __str__(self): + return f"#{GPU_DIALECT}.linear<{{register={self.register}, lane={self.lane}, warp={self.warp}, block={self.block}}}>" + + +# Python impl of LinearEncodingAttr::basesPerDim +def bases_per_dim(layout, dim, rank, skip_broadcast=True): + assert isinstance(layout, LinearLayout) + bases = getattr(layout, dim) + result = [1] * rank + + if not bases: + return result + + non_zero_idx = None + + for basis in bases: + # Find the first non-zero index in the current basis + idx = next((i for i, v in enumerate(basis) if v != 0), None) + if idx is not None: + non_zero_idx = idx + result[idx] *= 2 + elif not skip_broadcast: + # If no non-zero found and we're not skipping broadcasts, use the last found non-zero index + assert non_zero_idx is not None + result[non_zero_idx] *= 2 + + return result + + +def warps_per_cta(layout, shape): + if isinstance(layout, LinearLayout): + return bases_per_dim(layout, 'warp', len(shape)) + elif isinstance(layout, (SliceLayout, DotOperandLayout)): + return warps_per_cta(layout.parent, shape) + else: + return layout.warps_per_cta + + +def is_layout_applicable(layout) -> bool: + if isinstance(layout, (BlockedLayout, SharedLayout)): + return True + elif isinstance(layout, SliceLayout): + return is_layout_applicable(layout.parent) + elif is_gcu(): + return False + elif isinstance(layout, LinearLayout): + return True + elif is_cuda(): + mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout + if not isinstance(mma_layout, MmaLayout): + return False + if mma_layout.version[0] >= 3 and not is_hopper(): + return False + return True + elif is_hip(): + target_arch = triton.runtime.driver.active.get_current_target().arch + if "gfx11" in target_arch: + # RDNA 3 + return isinstance(layout, WmmaLayout) + elif any(arch for arch in ["gfx8", "gfx9"] if arch in target_arch): + # CDNA 1, 2, 3 + return isinstance(layout, MfmaLayout) + else: + return False + else: + return True + + +def filter_layouts(layouts): + return [l for l in layouts if is_layout_applicable(l)] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) +def test_empty_kernel(dtype_x, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + SIZE = 128 + + @triton.jit + def kernel(X, SIZE: tl.constexpr): + pass + + check_type_supported(dtype_x, device) + x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) + kernel[(1, )](x, SIZE=SIZE, num_warps=4) + + +def test_scalar_overflow(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(): + huge_int: tl.constexpr = 0xFFFFFFFFFFFFFF + x = tl.full((), 32, dtype=tl.int32) + y = x + huge_int + + with pytest.raises(triton.TritonError, match="out of range"): + kernel[(1, )]() + + +# generic test functions +def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda', num_ctas=1): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) + # inputs + x = numpy_random(SIZE, dtype_str=dtype_x) + if 'log' in expr: + x = np.abs(x) + 0.01 + # reference result + z_ref = eval(expr if numpy_expr is None else numpy_expr) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_x) + kernel[(1, )](Z=z_tri, X=x_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + # compare + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + + +def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: + """ + Given two dtype strings, returns the numpy dtype Triton thinks binary + operations on the two types should return. Returns None if the return value + matches numpy. This is generally needed because Triton and pytorch return + narrower floating point types than numpy in mixed operations, and because + Triton follows C/C++ semantics around mixed signed/unsigned operations, and + numpy/pytorch do not. + """ + overrides = { + ('float16', 'int16'): np.float16, + ('float16', 'int32'): np.float16, + ('float16', 'int64'): np.float16, + ('float16', 'uint16'): np.float16, + ('float16', 'uint32'): np.float16, + ('float16', 'uint64'): np.float16, + ('int8', 'uint8'): np.uint8, + ('int8', 'uint16'): np.uint16, + ('int8', 'uint32'): np.uint32, + ('int8', 'uint64'): np.uint64, + ('int16', 'uint16'): np.uint16, + ('int16', 'uint32'): np.uint32, + ('int16', 'uint64'): np.uint64, + ('int32', 'uint32'): np.uint32, + ('int32', 'uint64'): np.uint64, + ('int64', 'uint64'): np.uint64, + } + key = (a, b) if a < b else (b, a) + return overrides.get(key) + + +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, + x_low=None, x_high=None, y_low=None, y_high=None, filter_y=None, test_broadcast=True, + test_scalar=True): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + check_type_supported(dtype_y, device) + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_lhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_rhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_scalar_rhs(Z, X, y: tl.constexpr, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + replacements = {'GENERATE_TEST_HERE': expr} + kernel = patch_kernel(kernel, replacements) + kernel_broadcast_lhs = patch_kernel(kernel_broadcast_lhs, replacements) + kernel_broadcast_rhs = patch_kernel(kernel_broadcast_rhs, replacements) + kernel_scalar_rhs = patch_kernel(kernel_scalar_rhs, replacements) + + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs, low=x_low, high=x_high) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) + if filter_y: + y[filter_y(y)] = 1 + if mode_x == 'nan': + x[:] = float('nan') + if mode_y == 'nan': + y[:] = float('nan') + + def do_test(x, y, kernel_fn): + x_is_scalar = isinstance(x, (bool, int, float)) + y_is_scalar = isinstance(y, (bool, int, float)) + scalar_test = x_is_scalar or y_is_scalar + + # For scalars, we follow the NumPy 2.0 (and JAX/PyTorch pretty much) casting rules. + if scalar_test: + # We remove any explicit casting + pattern = r'\.astype\(np\.\w+\)' + scalar_expr = expr if numpy_expr is None else re.sub(pattern, '', numpy_expr) + with promotion_numpy_2_0(): + z_ref = eval(scalar_expr) + else: + z_ref = eval(expr if numpy_expr is None else numpy_expr) + + dtype_z = _binary_op_dtype_override(dtype_x, dtype_y) + if not scalar_test and dtype_z is not None: + z_ref = z_ref.astype(dtype_z) + if z_ref.dtype == np.float64: + z_ref = z_ref.astype(np.float32) + # triton result + x_tri = x if x_is_scalar else to_triton(x, device=device, dst_type=dtype_x) + y_tri = y if y_is_scalar else to_triton(y, device=device, dst_type=dtype_y) + z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) + kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + err_msg = f"{expr}, {kernel_fn.__name__}" + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=7e-3, rtol=0.01) + + def get_scalar(x, dtype, low, high, filter): + # If dtype is int, don't choose a huge number for the scalar + # as it'll overflow easily when converted to the other dtype + if dtype in integral_dtypes: + # Choose in range [-7, 7] ([0, 7] for uints) + low_x = 0 if dtype in uint_dtypes else -7 + if low is not None: + low_x = max(low_x, low) + high_x = 7 + if high is not None: + high_x = min(high_x, high) + scalar = numpy_random((), dtype_str=dtype, rs=rs, low=low_x, high=high_x).item() + if filter and filter(scalar): + # https://xkcd.com/221/ + scalar = 4 + else: + scalar = x.flat[0].item() + return scalar + + do_test(x, y, kernel) + if mode_y != 'nan' and test_scalar: + if dtype_x in uint_dtypes: + low = 0 if y_low is None else max(y_low, 0) + else: + low = y_low + y_scalar = get_scalar(y, dtype_y, low, y_high, filter_y) + do_test(x, y_scalar, kernel_scalar_rhs) + if test_broadcast: + do_test(x[:1].reshape(()), y, kernel_broadcast_lhs) + do_test(x, y[:1].reshape(()), kernel_broadcast_rhs) + + +def _min_max_integral_mod_value(dtype_x, dtype_y) -> Optional[int]: + """ + Limit min/max values for integral types for mod values. Leads to + overflow/underflow when casting large integral types to floats. + """ + x_bitwidth = _bitwidth(dtype_x) + y_bitwidth = _bitwidth(dtype_y) + + # hard cap max value bit-width to 32 if 64 bit-width types + min_bitwidth = min(x_bitwidth, y_bitwidth, 32) + + # Limit max value bit-width to be one integral type less than the min bit-width + # For example: + # int64, float32 -> int16 + # uint16, float16 -> uint8 + x_dtype = _dtype(dtype_x) + max_bitwidth = max(min_bitwidth >> 1, 8) + dtype_max = x_dtype + str(max_bitwidth) + + max_info = np.iinfo(getattr(np, dtype_max)) + + # Still need to limit values here for uints + if max_bitwidth >= 16 and dtype_max in uint_dtypes: + return max_info.min, max_info.max // 4 + else: + return max_info.min, max_info.max + + +def test_dtype_codegen(): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + for dtype in dtypes_with_bfloat16: + full_name = f"triton.language.{dtype}" + assert repr(eval(full_name)) == full_name + + +# --------------- +# test binary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['+', '-', '*', '/', '%'] + for dtype_x in dtypes_with_bfloat16 + for dtype_y in dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + expr = f'x {op} y' + np_expr_gen = (lambda x, y: f'{x} {op} {y}') if op != '%' else (lambda x, y: f'np.fmod({x}, {y})') + + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. + def promote_to_fp32(dtype_x, dtype_y): + return dtype_x in ('float16', 'bfloat16') and dtype_y not in ('float32', 'float64') + + if op in ('/', '%') and (promote_to_fp32(dtype_x, dtype_y) or promote_to_fp32(dtype_y, dtype_x)): + numpy_expr = np_expr_gen('x.astype(np.float32)', 'y.astype(np.float32)') + elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_x})', f'y.astype(np.{dtype_x})') + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = np_expr_gen(f'x.astype(np.{dtype_y})', f'y.astype(np.{dtype_y})') + elif op == '%': + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = np_expr_gen('x', 'y') + else: + numpy_expr = None + + if (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + else: + # skip when bfloat16, as NumPy's ref performs the computation in float32 + # while Triton performs it in bfloat16 + skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y) + or (op in ('/', '%') and dtype_x in ("float16", "bfloat16"))) + # can't divide by zero + not_zero = op in ('/', '%') and dtype_x in integral_dtypes and dtype_y in integral_dtypes + # can't represent -int(max) + not_minus_one = op in ('*', '/') and dtype_x in int_dtypes and dtype_y in int_dtypes + if not_zero or not_minus_one: + filter_y = lambda y: not_zero * (y == 0) | not_minus_one * (y == -1) + else: + filter_y = None + + if op == "%" and dtype_x in integral_dtypes and dtype_y in float_dtypes_with_bfloat16: + x_low, x_high = _min_max_integral_mod_value(dtype_x, dtype_y) + else: + x_low, x_high = None, None + + _test_binary( + dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, + # fails with values where fmod(x, y) is roughly zero, but happens to + # pass with the random values chosen for non-broadcast tests + test_broadcast=(op != "%"), x_low=x_low, x_high=x_high, filter_y=filter_y, test_scalar=not skip_scalar_test) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) +def test_addptr(dtype, order, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype, device) + + @triton.jit + def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): + offs = tl.arange(0, SIZE) + if ORDER == 0: + tl.store(y + offs, tl.load(x + offs)) + else: + tl.store(offs + y, tl.load(offs + x)) + + SIZE = 1024 + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + x_tri = to_triton(x, dst_type=dtype, device=device) + y_tri = to_triton(y, dst_type=dtype, device=device) + y = x + kernel[ + 1, + ](x_tri, y_tri, order, SIZE) + np.testing.assert_allclose(y, to_numpy(y_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y", [ # + (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes +] + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_floordiv(dtype_x, dtype_y, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + expr = 'x // y' + numpy_expr = '((x - np.fmod(x, y)) / y)' + # can't represent -int(max) + not_minus_one = dtype_x in int_dtypes and dtype_y in int_dtypes + if not_minus_one: + filter_y = lambda y: y == -1 + else: + filter_y = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, filter_y=filter_y, device=device, num_ctas=num_ctas) + + +def test_unsigned_name_mangling(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + # Test that uint32 and int32 are mangled differently by the compiler + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(O1, O2, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + out1 = tl.abs(x) # uint32 -> nop + out2 = tl.abs(-y) # int32 -> should have an effect + tl.store(O1 + off, out1) + tl.store(O2 + off, out2) + + dtype_x = 'uint32' + dtype_y = 'int32' + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + # reference result + expect = (np.abs(x), np.abs(-y)) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + actual = tuple(to_triton(np.empty_like(e), device=device) for e in expect) + kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4) + + # Bitwise op, so expect exact equality + assert (expect[0] == to_numpy(actual[0])).all() + assert (expect[1] == to_numpy(actual[1])).all() + + +# test bitwise ops +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['&', '|', '^'] + for dtype_x in dtypes + dtypes_with_bfloat16 + for dtype_y in dtypes + dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if 'float' in dtype_x + dtype_y: + # The CompilationError must have been caused by a C++ exception with this text. + with pytest.raises(triton.TritonError, match='invalid operands of type'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device, num_ctas=num_ctas) + else: + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes for dtype_y in uint_dtypes +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + expr = f'x {op} y' + bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) + if dtype_x.startswith('int'): + dtype_z = f'int{bw}' + else: + dtype_z = f'uint{bw}' + numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, y_low=0, y_high=bw) + + +# --------------- +# test compare ops +# --------------- +ops = ['==', '!=', '>', '<', '>=', '<='] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "dtype_x, dtype_y, op, mode_x, mode_y", + # real + [(dtype_x, dtype_y, op, 'real', 'real') for op in ops for dtype_x in dtypes for dtype_y in dtypes] + # NaNs + + [('float32', 'float32', op, mode_x, mode_y) + for op in ops + for mode_x, mode_y in [('nan', 'real'), ('real', 'nan'), ('nan', 'nan')]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas) + + +# --------------- +# test broadcast +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) +def test_broadcast(dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype, device) + + @triton.jit + def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) + y = tl.load(y_ptr + offset2) + _, y_broadcasted = tl.broadcast(x, y) + tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) + + M = 32 + N = 64 + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype, rs=rs) + y = numpy_random(N, dtype_str=dtype, rs=rs) + _, y_broadcasted_np = np.broadcast_arrays(x, y) + + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) + + broadcast_kernel[(1, )](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) + assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() + + +# ---------- +# test slice +# ---------- + + +@pytest.mark.interpreter +def test_slice(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def slice_kernel(XBLOCK: tl.constexpr): + data = tl.arange(0, XBLOCK) + tl.static_assert(data.shape == [XBLOCK]) + + t = data[None, :] + tl.static_assert(t.shape == [1, XBLOCK]) + + t = data[None, :, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + scalar = tl.full([], 1, tl.int32) + tl.static_assert(scalar.shape == []) + + t = scalar[None] + tl.static_assert(t.shape == [1]) + + t = scalar[None, None] + tl.static_assert(t.shape == [1, 1]) + + slice_kernel[(1, )](XBLOCK=32) + + +# ------------------ +# test invalid slice +# ------------------ + + +@pytest.mark.interpreter +def test_invalid_slice(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + dst[10:] + + with pytest.raises(triton.TritonError, match='unsupported tensor index'): + _kernel[(1, )](dst=dst) + + +# ---------------- +# test expand_dims +# ---------------- +@pytest.mark.interpreter +def test_expand_dims(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def expand_dims_kernel(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 0) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, 1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -2) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, (0, -1)) + tl.static_assert(t.shape == [1, N, 1]) + + t = tl.expand_dims(offset1, (0, 1, 3)) + tl.static_assert(t.shape == [1, 1, N, 1]) + + t = tl.expand_dims(offset1, (-4, 2, -1)) + tl.static_assert(t.shape == [1, N, 1, 1]) + + t = tl.expand_dims(offset1, (3, 1, 2)) + tl.static_assert(t.shape == [N, 1, 1, 1]) + + scalar = tl.sum(offset1) + tl.static_assert(scalar.shape == []) + t = tl.expand_dims(scalar, 0) + tl.static_assert(t.shape == [1]) + + t = tl.expand_dims(scalar, -1) + tl.static_assert(t.shape == [1]) + + # N is a scalar that's not even a tl.tensor -- this should work too. + t = tl.expand_dims(N, -1) + tl.static_assert(t.shape == [1]) + + N = 32 + dummy_tensor = torch.empty((), device=device) + expand_dims_kernel[(1, )](dummy_tensor, N) + + +@pytest.mark.interpreter +def test_expand_dims_error_cases(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def dim_out_of_range1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, -2) + t = tl.expand_dims(offset1, -3) + + @triton.jit + def dim_out_of_range2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 1) + t = tl.expand_dims(offset1, 2) + + @triton.jit + def dim_out_of_range3(dummy, N: tl.constexpr): + offset1 = tl.arange(0, 1) + scalar = tl.sum(offset1) + + t = tl.expand_dims(scalar, 1) + + @triton.jit + def duplicate_dim1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, 0)) + + @triton.jit + def duplicate_dim2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, -3)) + + N = 32 + dummy_tensor = torch.empty((), device=device) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range1[(1, )](dummy_tensor, N) + assert "invalid axis -3" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range2[(1, )](dummy_tensor, N) + assert "invalid axis 2" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range3[(1, )](dummy_tensor, N) + assert "invalid axis 1" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim1[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim2[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + +# ---------------------------- +# test invalid program id axis +# ---------------------------- +@pytest.mark.interpreter +def test_invalid_pid_axis(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pid = tl.program_id(20) + + with pytest.raises(triton.TritonError) as exc_info: + _kernel[(1, )](dst) + assert re.search(r"program_id axis must be 0, 1, or 2 but got 20", str(exc_info.value.__cause__)) + + +# --------------- +# test where +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where(dtype, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + select_ptrs = False + if dtype == "*int32": + dtype = "int64" + select_ptrs = True + check_type_supported(dtype, device) + + @triton.jit + def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + decide = tl.load(cond_ptr + offsets, mask=mask) + if TEST_SCALAR_POINTERS: + ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr) + output = tl.load(ptr + offsets, mask=mask) + else: + if TEST_POINTERS: + a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t) + b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t) + else: + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + output = tl.where(decide, a, b) + tl.store(output_ptr + offsets, output, mask=mask) + + SIZE = 1_000 + rs = RandomState(17) + cond = numpy_random(SIZE, 'bool', rs) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + z = np.where(cond, x, y) + + cond_tri = to_triton(cond, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device=device, dst_type=dtype) + + grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) + assert (z == to_numpy(z_tri)).all() + if select_ptrs: + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=True) + z = np.where(cond[0], x, y) + assert (z == to_numpy(z_tri)).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where_broadcast(num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + + mask = tl.load(cond_ptr + yoffsets) + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + @triton.jit + def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + mask = False + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + SIZE = 32 + dtype = 'float32' + rs = RandomState(17) + x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs) + mask = numpy_random(SIZE, 'bool', rs=rs) + z = np.where(mask, x, 0) + cond_tri = to_triton(mask, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device=device, dst_type=dtype) + where_kernel[(1, )](cond_tri, x_tri, z_tri, SIZE) + assert (z == to_numpy(z_tri)).all() + where_scalar_condition[(1, )](x_tri, z_tri, SIZE, num_ctas=num_ctas) + z = np.where(0, x, 0) + assert (z == to_numpy(z_tri)).all() + + +# --------------- +# test unary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr", + [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') + for dtype_x in int_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_unary_op(dtype_x, expr, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + _test_unary(dtype_x, expr, device=device, num_ctas=num_ctas) + + +# ---------------- +# test math ops +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr, x", + [(dtype_x, expr, x) + for dtype_x in ["float32", "float64"] + for expr in ['exp', 'log', 'cos', 'sin', 'exp2', 'log2', 'sqrt', 'floor', 'ceil'] + for x in ['x', '3.0']]) +def test_math_op(dtype_x, expr, x, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_erf_op(dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.math.erf(x) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = torch.erf(x) + z_tri = torch.zeros_like(x) + kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_fma_op(dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, Y, W, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + w = tl.load(W + off) + z = tl.math.fma(x, y, w) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + y = torch.randn(SIZE, dtype=torch_dtype, device=device) + w = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = x * y + w + z_tri = torch.zeros_like(x) + kernel[(1, )](z_tri, x, y, w, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_math_divide_op(expr, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + numpy_expr = "x / y" + dtype = "float32" + _test_binary(dtype, dtype, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +# ------------- +# test precise math +# ------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("expr_prec, expr_ref", + [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), + ('tl.math.div_rn(x,y)', '(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)')]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_precise_math(expr_prec, expr_ref, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + prec = PREC_CALC + ref = REF_CALC + tl.store(OUT + tl.arange(0, BLOCK), prec) + tl.store(OUT_REF + tl.arange(0, BLOCK), ref) + + shape = (128, ) + out = torch.zeros(shape, dtype=torch.float32, device=device) + out_ref = torch.zeros(shape, dtype=torch.float32, device=device) + + x = torch.randn(shape, dtype=torch.float32, device=device) + y = torch.randn(shape, dtype=torch.float32, device=device) + + if (expr_prec.count('sqrt') > 0): + x = torch.abs(x) + + if (expr_prec.count('div') > 0): + y += 1e-6 + + kernel = patch_kernel(kernel, {'PREC_CALC': expr_prec, 'REF_CALC': expr_ref}) + + kernel[(1, )](x, y, out, out_ref, BLOCK=shape[0], num_ctas=num_ctas) + assert torch.all(out == out_ref) # bitwise exact + + +# ---------------- +# test abs +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_abs(dtype_x, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5]) +def test_abs_fp8(in_dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_hip(): + pytest.skip('test_abs_fp8 not supported on HIP.') + elif is_cuda(): + cc = torch.cuda.get_device_capability() + if in_dtype == tl.float8e4b15 and cc >= (9, 0): + pytest.skip("float8e4b15 not supported on CUDA >= 9.0") + if in_dtype == tl.float8e4nv and cc < (8, 9): + pytest.skip("float8e4nv not supported on CUDA < 8.9") + print(in_dtype) + + @triton.jit + def abs_kernel(X, Z, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.abs(x) + tl.store(Z + off, z) + + f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device=device) + # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan + all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width + f8_tensor[all_exp_ones] = 0 + f8 = triton.reinterpret(f8_tensor, in_dtype) + n_elements = f8_tensor.numel() + out_f8 = torch.empty_like(f8_tensor) + abs_kernel[(1, )](f8, triton.reinterpret(out_f8, in_dtype), n_elements) + + f32_tensor = convert_float_to_float32(f8_tensor, in_dtype) + expect = f32_tensor.abs() + actual_f8 = convert_float_to_float32(out_f8, in_dtype) + torch.testing.assert_close(actual_f8, expect, equal_nan=True) + + +# ---------------- +# test passing shapes as individual params rather than tuples +# ---------------- + + +@pytest.mark.interpreter +def test_shapes_as_params(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(): + a = tl.arange(0, 32).expand_dims(-1).broadcast_to(32, 32) + tl.static_assert(a.shape == [tl.constexpr(32), tl.constexpr(32)]) + + a = tl.arange(0, 32).reshape(4, 8).permute(1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + + a = tl.arange(0, 32).reshape(4, 8).trans() + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + + a = tl.arange(0, 32).reshape(4, 8).reshape(32) + tl.static_assert(a.shape == [tl.constexpr(32)]) + + a = tl.arange(0, 64).reshape(2, 4, 8).trans(2, 1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + + a = tl.arange(0, 64).reshape(2, 4, 8).trans((2, 1, 0)) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + + a = tl.arange(0, 64).view(2, 4, 8) + tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)]) + + kernel[(1, )]() + + +# ---------------- +# test transpose +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_transpose(dtype_x, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype_x, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + off2d = off[None, :] + (tl.arange(0, 2) * SIZE)[:, None] + x = tl.load(X + off2d) + z = x.T + tl.store(Z + off2d.T, z) + + x = numpy_random([SIZE, 2], dtype_str=dtype_x) + z_ref = x.T + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE) + np.testing.assert_allclose(z_ref, to_numpy(z_tri)) + + +# ---------------- +# test indexing +# ---------------- + + +def make_ptr_str(name, shape): + rank = len(shape) + offsets = [] + stride = 1 + for i in reversed(range(rank)): + idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) + offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}'] + stride *= shape[i] + return f"{name} + {' + '.join(offsets)}" + + +# TODO: handle `%4 = ttg.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>`` +@pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16']]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_index1d(expr, dtype_str, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + rank_x = expr.count(':') + rank_y = expr.count(',') + 1 + shape_x = [32 for _ in range(rank_x)] + shape_z = [32 for _ in range(rank_y)] + shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)] + shape_z_dim_mismatch = [64 for _ in range(rank_y)] + + # Triton kernel + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + m = tl.arange(0, SIZE) + n = tl.arange(0, SIZE) + x = tl.load(X_PTR_EXPR) + z = GENERATE_TEST_HERE + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + 'GENERATE_TEST_HERE': expr, + } + return patch_kernel(kernel, to_replace) + + kernel_match = generate_kernel(shape_x, shape_z) + kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch) + kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch) + + # torch result + x = numpy_random(shape_x, dtype_str=dtype_str) + y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) + z_ref = eval(expr) + y + # triton result + z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x, device=device) + kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + # compare + assert (z_ref == to_numpy(z_tri)).all() + + def catch_compilation_error(kernel): + try: + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0], num_ctas=num_ctas) + except triton.CompilationError as e: + np.testing.assert_(True) + except BaseException: + np.testing.assert_(False) + + catch_compilation_error(kernel_dim_mismatch) + catch_compilation_error(kernel_rank_mismatch) + + +# --------------- +# test tuples +# --------------- + + +@triton.jit +def tuples_fn(a, b): + return a + b, \ + a - b, \ + a * b + + +@pytest.mark.interpreter +def test_tuples(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def with_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = tuples_fn(x, y) + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + @triton.jit + def without_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = x + y, x - y, x * y + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + x = torch.tensor([1.3], device=device, dtype=torch.float32) + y = torch.tensor([1.9], device=device, dtype=torch.float32) + a_tri = torch.tensor([0], device=device, dtype=torch.float32) + b_tri = torch.tensor([0], device=device, dtype=torch.float32) + c_tri = torch.tensor([0], device=device, dtype=torch.float32) + for kernel in [with_fn, without_fn]: + kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1) + a_ref, b_ref, c_ref = x + y, x - y, x * y + assert a_tri == a_ref + assert b_tri == b_ref + assert c_tri == c_ref + + +@triton.jit(noinline=True) +def noinline_simple_fn(x, y, Z): + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_graph_fn1(x): + return x + 1 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn2(y): + return y + 2 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn(x, y, Z): + t0 = noinline_call_graph_fn1(x) + t1 = noinline_call_graph_fn2(y) + z = t0 + t1 + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_shared_fn(x, y, Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + x + y + tl.store(Z + offs, z) + + +@triton.jit(noinline=True) +def noinline_dynamic_fn(x, y, Z): + if x >= 1: + x = noinline_call_graph_fn1(x) + else: + x = noinline_call_graph_fn2(x) + if y >= 2: + y = noinline_call_graph_fn2(y) + else: + y = noinline_call_graph_fn1(y) + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_multi_values_fn(x, y): + return x + 1, y + 2 + + +@triton.jit(noinline=True) +def noinline_multi_values_fn(x, y, Z): + x, y = noinline_call_multi_values_fn(x, y) + z = x + y + tl.store(Z, z) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) +def test_noinline(mode, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + GENERATE_TEST_HERE(x, y, Z) + + func_name = f'noinline_{mode}_fn' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': func_name}) + x = torch.tensor([1.0], device=device, dtype=torch.float32) + y = torch.tensor([2.0], device=device, dtype=torch.float32) + if mode == "shared": + z = torch.ones((16, 16), device=device, dtype=torch.float32) + else: + z = torch.tensor([0.0], device=device, dtype=torch.float32) + kernel[(1, )](x, y, z, num_warps=1) + if mode == "simple": + assert torch.equal(z, x + y) + elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values": + assert torch.equal(z, x + 1 + y + 2) + elif mode == "shared": + ref = torch.full((16, 16), 16, device=device, dtype=torch.float32) + assert torch.equal(z, ref + x + y) + + +# --------------- +# test atomics +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_x_str, mode, sem", + itertools.chain.from_iterable([[ + ('add', 'float16', mode, sem), + ('add', 'uint32', mode, sem), + ('add', 'int32', mode, sem), + ('add', 'float32', mode, sem), + ('add', 'uint64', mode, sem), + ('add', 'int64', mode, sem), + ('add', 'float64', mode, sem), + ('max', 'uint32', mode, sem), + ('max', 'int32', mode, sem), + ('max', 'float32', mode, sem), + ('max', 'uint64', mode, sem), + ('max', 'int64', mode, sem), + ('max', 'float64', mode, sem), + ('min', 'uint32', mode, sem), + ('min', 'int32', mode, sem), + ('min', 'float32', mode, sem), + ('min', 'uint64', mode, sem), + ('min', 'int64', mode, sem), + ('min', 'float64', mode, sem), + ] + for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] + for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) +def test_atomic_rmw(op, dtype_x_str, mode, sem, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype_x_str, device) + if is_interpreter(): + if dtype_x_str == 'float16': + pytest.skip("Only test atomic float16 ops on GPU") + + n_programs = 5 + + # triton kernel + @triton.jit + def kernel(X, Z): + pid = tl.program_id(0) + x = tl.load(X + pid) + old = GENERATE_TEST_HERE + tl.static_assert(old.dtype == x.dtype) + + sem_arg = sem if sem is None else f'"{sem}"' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'}) + numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op] + max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min + min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max + neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op] + + # triton result + rs = RandomState(17) + x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str)) + if mode == 'all_neg': + x = -np.abs(x) + if mode == 'all_pos': + x = np.abs(x) + if mode == 'min_neg': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = -np.max(np.abs(x)) - 1 + if mode == 'max_pos': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = np.max(np.abs(x)) + 1 + x_tri = to_triton(x, device=device) + + z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device) + h = kernel[(n_programs, )](x_tri, z_tri) + # torch result + z_ref = numpy_op(x).astype(getattr(np, dtype_x_str)) + # compare + exact = op not in ['add'] + if exact: + assert z_ref.item() == to_numpy(z_tri).item() + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + sem_str = "acq_rel" if sem is None else sem + if not is_cuda(): + return + + assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_rmw_predicate(num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X): + val = tl.program_id(0) + if val < 64: + tl.atomic_max(X, val) + + x = torch.zeros((1, ), device=device, dtype=torch.int32) + kernel[(4096, )](x, num_ctas=num_ctas) + assert x.item() == 63 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, axis, num_ctas, dtype_x_str, check_return_val", + [(shape, axis, num_ctas, dtype_x_str, check_return_val) + for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] + for axis in [0, 1] + for num_ctas in num_ctas_list + for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64'] + for check_return_val in ([True, False] if is_hip() else [True])]) +def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype_x_str, device) + shape0, shape1 = shape + # triton kernel + + @triton.jit + def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr, + RETURN_VAL: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) + + if DTYPE == tl.float16: + # sum can have bad numerics when accumulating in float16. + # if we're dealing with float16, do the sum in float32. + x = x.to(tl.float32) + + z = tl.sum(x, axis=AXIS) + + if DTYPE == tl.float16: + z = z.to(DTYPE) + + if AXIS == 1: + old = tl.atomic_add(Z + off0, z) + if RETURN_VAL: + tl.store(OLD + off0, old) + else: + old = tl.atomic_add(Z + off1, z) + if RETURN_VAL: + tl.store(OLD + off1, old) + + rs = RandomState(17) + x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs) + z_shape = (shape0, ) if axis == 1 else (shape1, ) + z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs) + old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str)) + # reference results + if x.dtype == np.float16: + # do the sum in float32 to reduce numerical variation + z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype) + else: + z_ref = z + np.sum(x, axis=axis, keepdims=False) + old_ref = np.copy(z) + # triton result + x_tri = to_triton(x, device=device) + z_tri = to_triton(z, device=device) + old_tri = to_triton(old, device=device) + + def torch_to_triton_dtype(t): + if t == torch.float16: + return tl.float16 + return None + + kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), check_return_val, + num_ctas=num_ctas) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) + if check_return_val: + np.testing.assert_equal(old_ref, to_numpy(old_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str) + for size in [2, 4, 8, 32, 64, 128] + for num_ctas in num_ctas_list + for dtype_x_str in ['float16', 'float32']]) +def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, val, NUM: tl.constexpr): + off = tl.arange(0, NUM) + offset = off[:, None] * NUM + off[None, :] + val = tl.load(val + offset) + tl.atomic_add(X + offset // 2, val) + + shape = (size // 2, size) + x = torch.zeros(shape, dtype=getattr(torch, dtype_x_str), device=device) + val = torch.randn((size**2), dtype=getattr(torch, dtype_x_str), device=device) + kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas) + ref = val[0::2] + val[1::2] + torch.testing.assert_close(ref, x.reshape(math.prod(shape))) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, idx_order, mask_step, num_ctas, dtype_x_str", + [(shape, idx_order, mask_step, num_ctas, dtype_x_str) + for shape in [(2, 2), (4, 4), (5, 5), (6, 6), (8, 8)] + for idx_order in ['increase', 'decrease', 'random_no_duplication', 'random'] + for mask_step in range(1, 5) + for num_ctas in num_ctas_list + for dtype_x_str in ['float16', 'float32']]) +def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas, dtype_x_str, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype_x_str, device) + if is_interpreter(): + pytest.skip("not supported in the interpreter") + + @triton.jit + def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + x_idx = xoffset + tl.arange(0, XBLOCK)[:] + mask = x_idx < shape0 * shape1 + mask = mask and (x_idx % mask_step != 0) + idx_base = shape1 * (x_idx // shape1) + idx_offset = tl.load(idx_ptr + x_idx, mask) + in_elem = tl.load(in_ptr + x_idx, mask) + tl.atomic_add(out_ptr + (idx_offset + idx_base), in_elem, mask, sem='relaxed') + + shape0, shape1 = shape + idx_row = torch.arange(0, shape1, device=device) + if idx_order == 'increase': + idx = torch.stack([idx_row.repeat_interleave(i + 1)[:shape1] for i in range(shape0)]) + if idx_order == 'decrease': + idx = torch.stack([idx_row.flip(0).repeat_interleave(i + 1)[:shape1] for i in range(shape0)]) + if idx_order == 'random_no_duplication': + idx = torch.stack([torch.randperm(shape1, device=device) for _ in idx_row]) + if idx_order == 'random': + idx = torch.randint(0, shape1, size=(shape0, shape1), device=device) + + val = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device) + dst = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device) + + dst_ref = dst.clone() + + cnt = 0 + for i, row in enumerate(idx): + for j, elem in enumerate(row): + if cnt % mask_step != 0: + dst_ref[i][elem] += val[i][j] + cnt += 1 + + kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas) + np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_rmw_block(num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + shape = (8, 8) + + @triton.jit + def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + offs = off0[:, None] * SHAPE1 + off1[None, :] + val = offs.to(tl.float32) + x = X + offs + tl.atomic_min(x, val) + + x = torch.ones((8, 8), device=device, dtype=torch.float32) + kernel[(2, )](x, shape[0], shape[1], num_ctas=num_ctas) + assert torch.min(x).item() == 0.0 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_cas(sem, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + # 1. make sure that atomic_cas changes the original value (Lock) + @triton.jit + def change_value(Lock): + tl.atomic_cas(Lock, 0, 1) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + change_value[(1, )](Lock) + + assert (Lock[0] == 1) + + # 2. only one block enters the critical section + @triton.jit + def serialized_add(data, Lock, SEM: tl.constexpr): + ptrs = data + tl.arange(0, 128) + while tl.atomic_cas(Lock, 0, 1, SEM) == 1: + pass + + tl.store(ptrs, tl.load(ptrs) + 1.0) + + # insert barrier to set a fence between tl.store and + # tl.atomic_xchg in a block. + tl.debug_barrier() + + # release lock + tl.atomic_xchg(Lock, 0) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + data = torch.zeros((128, ), device=device, dtype=torch.float32) + ref = torch.full((128, ), 2000.0) + h = serialized_add[(2000, )](data, Lock, SEM=sem, num_ctas=num_ctas) + sem_str = "acq_rel" if sem is None else sem + np.testing.assert_allclose(to_numpy(data), to_numpy(ref)) + if not is_cuda(): + return + assert f"atom.global.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_cas(sem, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + t1 = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64) + t2 = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64) + tl.atomic_cas(X + offsets, t1, t2, sem=sem) + + X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64) + Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=torch.int64) + + change_value[(2, )](X, 4, sem) + assert (torch.equal(X, Y)) + + +@pytest.mark.interpreter +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, + reason="Requires compute capability >= 9 for NV") +def test_load_scope_sem_coop_grid_cta_not_one(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr): + numel = 512 + offset = tl.program_id(0) * BLOCK_SIZE + index = offset + mask = index < numel + a = tl.load(ptrs, mask=mask) + tl.store(ptrs, a) + + block_size = 128 + data = torch.zeros((128, ), device=device, dtype=torch.float32) + + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=True) + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=False) + + +@pytest.mark.interpreter +def test_load_scope_sem_coop_grid_cta_one(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr): + numel = 512 + offset = tl.program_id(0) * BLOCK_SIZE + index = offset + mask = index < numel + a = tl.load(ptrs, mask=mask) + tl.store(ptrs, a) + + block_size = 128 + data = torch.zeros((128, ), device=device, dtype=torch.float32) + + # Should do nothing different for num_ctas=1 (with coop launch grid) + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=True) + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=False) + + +# --------------- +# test cast +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", + [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ + ('float32', 'bfloat16', False, 1024), + ('bfloat16', 'float32', False, 1024), + ('float32', 'int32', True, 1024), + ('float32', 'bool', False, 1024), + ('int8', 'bfloat16', False, 1024), + ] + [(f'uint{x}', f'int{x}', True, 1024) + for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) + for x in [8, 16, 32, 64]] + + (([(dtype_x, dtype_z, False, size) + for dtype_x in torch_float8_dtypes + for dtype_z in ["float16", "float32", "bfloat16"] + for size in [1024, 32]] # + + [(dtype_x, dtype_z, False, size) + for dtype_z in torch_float8_dtypes + for dtype_x in ["float16", "float32", "bfloat16"] + for size in [1024, 32]]) if torch.__version__ >= "2.1" else [])) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + # CUDA: bfloat16 on cc < 80 will not be tested + # Interpreter: Only bfloat16 <-> float32 is supported + if not is_interpreter() or \ + (is_interpreter() and not ((dtype_z == 'bfloat16' and dtype_x == 'float32') + or (dtype_z == 'float32' and dtype_x == 'bfloat16'))): + check_type_supported(dtype_x, device) + check_type_supported(dtype_z, device) + + if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"): + pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') + + torch.manual_seed(0) + # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. + if dtype_x.startswith('bfloat'): + x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device) + elif dtype_x.startswith('float8'): + x_tri = torch.randn(size, dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_x)) + else: + x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10 + # Triton clamps negative values to zero, while numpy wraps around + # intmax, so avoid negatives for now. + # TODO: figure out which one should actually be happening, and test it + if dtype_z in uint_dtypes: + x = np.absolute(x) + x_tri = to_triton(x, device=device) + if 'float' in dtype_z and 'float' in dtype_x: + # make sure we use values that can be represented in both types + x_tri = x_tri.to(getattr(torch, dtype_z)).to(getattr(torch, dtype_x)) + # triton kernel + + @triton.jit + def kernel(X, Z, TO_TYPE: tl.constexpr, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): + x_ptr = X + tl.arange(0, SIZE) + z_ptr = Z + tl.arange(0, SIZE) + x = tl.load(x_ptr) + + # Depending on the value of ARG_HASH (a "random" number determined by + # the test parameters), spell the cast one of three different ways. + if ARG_HASH % 4 == 0: + z = x.to(Z.dtype.element_ty, bitcast=BITCAST) + elif ARG_HASH % 4 == 1: + z = x.cast(Z.dtype.element_ty, bitcast=BITCAST) + elif ARG_HASH % 4 == 2: + z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST) + else: + z = tl.cast(x, TO_TYPE, bitcast=BITCAST) + + tl.store(z_ptr, z) + + # "Random" number used inside the kernel to determine how we spell the cast. + # This way we don't have to increase the number of tests. + arg_hash = hash((dtype_x, dtype_z, bitcast, size, num_ctas)) + + dtype_z_np = dtype_z if dtype_z != 'bool' else 'bool_' + # triton result + if dtype_z.startswith('bfloat'): + z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device) + elif dtype_z.startswith('float8'): + z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) + else: + z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) + + dtype_z_tri = str_to_triton_dtype(dtype_z) + kernel[(1, )](x_tri, z_tri, TO_TYPE=dtype_z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, + num_ctas=num_ctas) + # torch result + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( + 'float8') or dtype_x.startswith('float8'): + assert bitcast is False + z_ref = x_tri.to(z_tri.dtype) + if dtype_z.startswith('float8') and device not in ['cuda']: + t = z_ref.byte() ^ z_tri.byte() + torch.testing.assert_close(torch.zeros_like(t, dtype=torch.uint8), t) + else: + torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0) + else: + if bitcast: + z_ref = x.view(getattr(np, dtype_z_np)) + else: + z_ref = x.astype(getattr(np, dtype_z_np)) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, num_warps", + [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) +def test_cat(dtype_str, num_warps, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype_str, device) + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.cat(x, y, can_reorder=True) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(getattr(torch, dtype_str)) + y = torch.arange(-128, 0, device=device).to(getattr(torch, dtype_str)) + z_ref = torch.cat([x, y], dim=0).sum() + z = torch.zeros((256, ), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](x, y, z, N=128, num_warps=num_warps) + assert z.sum() == z_ref + # check if there's no duplicate value in z + # TODO(chongzhou.yang): fallback it to cpu until topsatenunique2 is supported + assert z.to('cpu').unique().size(0) == z.to('cpu').size(0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", list(torch_dtypes)) +@pytest.mark.parametrize("constant_field", ["value", "mask"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant(num_ctas, dtype_str, constant_field, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype_str, device) + + @triton.jit + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if CONSTANT_FIELD == "value": + value = 1 + output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) + mask = offsets < n_elements + elif CONSTANT_FIELD == "mask": + output = offsets < n_elements + mask = False + tl.store(output_ptr + offsets, output, mask=mask) + + block_size = 128 + ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) + output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field) + + if constant_field == "value": + assert torch.all(output == ref) + else: + assert torch.all(output == 0) + + +def test_load_store_same_ptr(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit() + def kernel(in_out_ptr): + pid = tl.program_id(axis=0) + x = tl.load(in_out_ptr + pid) + out = x * 2 + tl.store(in_out_ptr + pid, out) + + for _ in range(1000): + x = torch.ones((65536, ), device=device, dtype=torch.float32) + if is_hip(): + kernel[(65536, )](x, num_warps=16) # threads per Warp for ROCM is 64 + else: + kernel[(65536, )](x, num_warps=32) + assert torch.all(x == 2) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['int32']) +def test_umulhi(dtype_str, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.umulhi(x, y) + tl.store(Z + tl.arange(0, N), z) + + def umulhi32(a, b): + # Convert to 64-bit unsigned integers to prevent overflow + a_64 = a.astype(np.int64) + b_64 = b.astype(np.int64) + + # Perform the multiplication in 64-bit + product_64 = a_64 * b_64 + + # Shift right by 32 bits to get the high part of the product + result_high_32 = product_64 >> 32 + return result_high_32 + + rs = RandomState(17) + N = 128 + x = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + x_tri = to_triton(x, device=device) + y = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + y_tri = to_triton(y, device=device) + z_tri = torch.zeros_like(x_tri) + kernel[(1, )](x_tri, y_tri, z_tri, N=N) + + z_ref = umulhi32(x, y) + np.testing.assert_equal(z_ref, to_numpy(z_tri)) + + +@pytest.mark.interpreter +def test_join(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.join(x, y) + tl.store(Z + tl.arange(0, N)[:, None] * 2 + tl.arange(0, 2)[None, :], z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(-128, 0, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, y, z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_scalars(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + z = tl.join(x, y) + tl.static_assert(z.shape == [2]) + tl.store(Z + tl.arange(0, 2), z) + + x = torch.full([1], 42, device=device).to(torch.int32) + y = torch.full([1], 100, device=device).to(torch.int32) + z = torch.zeros([2], device=device) + kernel[(1, )](x, y, z) + + np.testing.assert_equal([42, 100], to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_with_mma(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, Z): + x = tl.load(X + 16 * tl.arange(0, 32)[:, None] + tl.arange(0, 16)[None, :]) # (32,16) + x2 = tl.join(x, 2 * x) # (32,16,2) + x3 = tl.reshape(x2, (32, 32)) + z = tl.dot(x3, x3) # (32,32) + tl.store(Z + 32 * tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :], z) + + x = torch.arange(0, 32 * 16, device=device, dtype=torch.float32).reshape((32, 16)) + r = torch.stack([x, 2 * x], dim=-1).reshape((32, 32)) + z_ref = torch.matmul(r, r) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, z) + + torch.testing.assert_close(z, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("debug", [False, True]) +def test_interleave(device, debug): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit(debug=debug) + def kernel(Z, N: tl.constexpr): + z = tl.interleave(tl.arange(0, N), tl.arange(N, 2 * N)) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(128, 256, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1).reshape(256) + z = torch.zeros_like(z_ref) + kernel[(1, )](z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_interleave_scalars(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, Y, Z): + z = tl.interleave(X, Y) + tl.static_assert(z.shape == [tl.constexpr(2)]) + tl.store(Z + tl.arange(0, 2), z) + + z = torch.zeros(2, device=device) + kernel[(1, )](10, 20, z) + + np.testing.assert_equal([10, 20], to_numpy(z)) + + +@pytest.mark.interpreter +def test_split(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, Z1, Z2, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + x1 = tl.reshape(x, (N // 2, 2)) + z1, z2 = tl.split(x1) + tl.store(Z1 + tl.arange(0, N // 2), z1) + tl.store(Z2 + tl.arange(0, N // 2), z2) + + x = torch.arange(0, 256, device=device).to(torch.int32).reshape((128, 2)) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2, N=256) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +@pytest.mark.interpreter +def test_split_to_scalar(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, Z1, Z2): + offs = tl.arange(0, 2) + x = tl.load(X + offs) + z1, z2 = tl.split(x) + tl.static_assert(isinstance(z1, tl.tensor)) + tl.static_assert(isinstance(z2, tl.tensor)) + tl.static_assert(z1.shape == []) + tl.static_assert(z2.shape == []) + tl.store(Z1, z1) + tl.store(Z2, z2) + + N = 2 + # not support int64 + # x = torch.arange(0, N, device=device).reshape(N // 2, 2) + x = torch.arange(0, N, device=device).to(torch.int32).reshape(N // 2, 2) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +def convert_float_to_float32(fp: torch.tensor, dtype=None): + if not dtype: + dtype = getattr(tl, torch_dtype_name(fp.dtype)) + + fp = fp.view(getattr(torch, f"int{dtype.primitive_bitwidth}")) + exp_width = dtype.primitive_bitwidth - dtype.fp_mantissa_width - 1 + exp_bias = dtype.exponent_bias + sign = ((fp >> (dtype.primitive_bitwidth - 1)) & 0x01).int() + exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int() + frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int() + + output = torch.where( + exp == 0, + # subnormal + ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (frac / (2.0**dtype.fp_mantissa_width)), + # normal + ((-1.0)**sign) * (2.0**(exp - exp_bias)) * (1.0 + frac / (2.0**dtype.fp_mantissa_width))).float() + + extended_exp = ( + (1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width + # special cases, exp is 0b11..1 + if dtype in [tl.float8e4nv, tl.float8e4b15]: + # float8e4m3nv does not have infinities + output[fp == 0b01111111] = torch.nan + output[fp == 0b11111111] = torch.nan + else: + output = torch.where(exp == (1 << exp_width) - 1, + ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp + | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))) # + .view(torch.float32), output) + return output + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16]) +def test_convert_float16_to_float32(in_dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + """Tests that check convert_float_to_float32 function""" + check_type_supported(in_dtype, device) + + f16_input = torch.tensor(range(-int(2**(16 - 1)), int(2**(16 - 1))), dtype=torch.int16).view(in_dtype) + f32_output = convert_float_to_float32(f16_input) + + nan = f16_input.isnan() + assert torch.all(f32_output[nan].isnan()) + inf = f16_input.isinf() + assert torch.all(f32_output[inf].isinf()) + other = torch.logical_not(torch.logical_or(nan, inf)) + assert torch.all(f16_input[other] == f32_output[other]) + + +def serialize_fp8(np_data, in_dtype): + return np_data + + +# inverse of `serialize_fp8` + + +def deserialize_fp8(np_data, in_dtype): + return np_data + + +# --------------- +# test reduce +# --------------- + + +@pytest.mark.interpreter +def test_max_returns_zero(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + # Simple test with a tl.max call that returns 0. The interpreter had a bug + # where it didn't handle this correctly. + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + z = tl.max(x) + tl.store(Z, z) + + BLOCK = 128 + x = torch.zeros((BLOCK, ), device=device) + z = torch.ones((1, ), device=device) + + kernel[(1, )](x, z, BLOCK=BLOCK) + assert z[0] == 0 + + +def get_reduced_dtype(dtype_str, op): + if op in ('argmin', 'argmax'): + return 'int32' + if dtype_str == 'bfloat16': + return 'float32' + return dtype_str + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [ + 'min', + 'max', + 'min-with-indices', + 'max-with-indices', + 'argmin-tie-break-left', + 'argmax-tie-break-left', + 'sum', +] for dtype in dtypes_with_bfloat16 for shape in [32, 64, 128, 512]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce1d(op, dtype_str, shape, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + # triton kernel + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + GENERATE_TEST_HERE + tl.store(Z, z) + + if 'with-indices' in op: + patch = f'z, _ = tl.{op.split("-")[0]}(x, axis=0, return_indices=True)' + elif 'arg' in op: + tie_break_left = 'tie-break-left' in op + patch = f'z = tl.{op.split("-")[0]}(x, axis=0, tie_break_left={tie_break_left})' + else: + patch = f'z = tl.{op}(x, axis=0)' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs) + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + 'max-with-indices': np.max, + 'min-with-indices': np.min, + 'argmin-tie-break-left': np.argmin, + 'argmax-tie-break-left': np.argmax, + }[op] + if 'tie-break-left' in op: + x[3:10] = x[numpy_op(x)] + x_tri = to_triton(x, device=device) + # numpy result + z_dtype_str = 'int32' if 'tie-break-left' in op else dtype_str + z_tri_dtype_str = z_dtype_str + if 'tie-break-left' not in op and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + z_tri_dtype_str = 'bfloat16' + else: + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # triton result + z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas) + z_tri = to_numpy(z_tri) + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if 'tie-break-left' in op: + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + np.testing.assert_equal(x[z_ref], x[z_tri]) + else: + np.testing.assert_equal(z_ref, z_tri) + + +# TODO: [Qingyi] Fix argmin / argmax +reduce_configs1 = [(op, dtype, (1, 1024), axis, False) + for dtype in dtypes_with_bfloat16 + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [1]] + +# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory +# exceeds the limit of 99KB +reduce2d_shapes = [(2, 32), (4, 32), (4, 128)] +# TODO: fix and uncomment +# , (32, 64), (64, 128)] +if is_cuda() and 'V100' in torch.cuda.get_device_name(0): + reduce2d_shapes += [(128, 256) and (32, 1024)] + +reduce_configs2 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce2d_shapes + for axis in [0, 1]] + [(op, 'float32', [16, 32], None, False) for op in ['min', 'max', 'sum']] + +reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)] +reduce_configs3 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce3d_shapes + for axis in [0, 1, 2]] +invalid_config = [('sum', 'float32', (32, 32), axis, False) for axis in [2, 3]] +negative_config = [('sum', 'float32', (32, 32), -1, False)] +keep_dims_2d_configs = [(op, 'float32', (32, 32), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1]] + [(op, 'float32', (32, 32), None, True) for op in ['min', 'max', 'sum']] +keep_dims_3d_configs = [(op, 'float32', (32, 2, 16), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1, 2]] + [(op, 'float32', (32, 2, 16), None, True) + for op in ['min', 'max', 'sum']] +reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + + negative_config + keep_dims_2d_configs + keep_dims_3d_configs + reduce_bool) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, + AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr, USE_I1: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + range_k = tl.arange(0, BLOCK_K) + if IS_3D: + x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + + range_k[None, None, :]) + else: + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + if USE_I1: + x = tl.cast(x, tl.int1) + z = GENERATE_TEST_HERE + z_ptr = Z + if KEEP_DIMS and AXIS is None: + if IS_3D: + z_ptr = z_ptr[None, None, None, :] + else: + z_ptr = z_ptr[None, None, :] + if IS_3D: + if AXIS == 0: + z_ptr = Z + range_n[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 1 or AXIS == -2: + z_ptr = Z + range_m[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 2 or AXIS == -1: + z_ptr = Z + range_m[:, None] * BLOCK_N + range_n[None, :] + else: + if AXIS == 0: + z_ptr = Z + range_n + elif AXIS == 1 or AXIS == -1: + z_ptr = Z + range_m + if KEEP_DIMS and AXIS is not None: + z_ptr = tl.expand_dims(z_ptr, axis=AXIS) + tl.store(z_ptr, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS, keep_dims=KEEP_DIMS)'}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_tri = to_triton(x, device=device) + numpy_op = { + 'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax, 'xor_sum': + np.bitwise_xor.reduce + }[op] + z_dtype_str = get_reduced_dtype(dtype_str, op) + z_tri_dtype_str = z_dtype_str + if z_dtype_str == 'bool': + z_dtype_str = 'int8' + + # numpy result + # Silence numpy error on axis out of bounds, to give triton a chance to fail + np_axis = axis if axis is not None and axis < len(shape) else None + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_tri_dtype_str = 'bfloat16' + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + else: + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + + # triton result + z_shape = z_ref.shape + z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + BLOCK_K = 1 if len(shape) == 2 else shape[2] + IS_3D = bool(len(shape) == 3) + USE_I1 = dtype_str == 'bool' + if axis is not None and axis >= len(shape): + with pytest.raises(triton.TritonError): + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, USE_I1=USE_I1, num_ctas=num_ctas) + return + else: + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, USE_I1=USE_I1, num_ctas=num_ctas) + + z_tri = to_numpy(z_tri) + + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if op in ('argmin', 'argmax'): + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + z_ref_index = z_ref + z_tri_index = z_tri + if not keep_dims: + z_ref_index = np.expand_dims(z_ref, axis=axis) + z_tri_index = np.expand_dims(z_tri, axis=axis) + z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis) + z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis) + np.testing.assert_equal(z_ref_value, z_tri_value) + else: + np.testing.assert_equal(z_ref, z_tri) + + +scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)] + +scan_configs = [(op, type, shape, axis, reverse, num_warps) + for num_warps in [4, 16] + for type in ['int32', 'float32', 'bfloat16'] + for axis in [1, 0] + for reverse in [True, False] + for shape in scan2d_shapes + for op in ['cumsum', 'cumprod', 'get_first_element', 'linear_recurrence', 'cummax', 'roll']] +negative_config = [('cumsum', 'float32', (32, 32), -1, False, 4)] + + +def test_sum_dtype(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel_dtype(out_ptr, init, in_dtype: tl.constexpr, out_dtype: tl.constexpr): + x = tl.full((32, 32), init, dtype=in_dtype) + x = tl.sum(x, dtype=out_dtype) + tl.store(out_ptr, x.to(tl.int32)) + + @triton.jit + def kernel_default_int(out_ptr): + x = tl.full((32, 32), 1, dtype=tl.int1) + x = tl.sum(x) + tl.store(out_ptr, x) + + @triton.jit + def kernel_default_float(out_ptr): + x = tl.full((32, 32), 1.0, dtype=tl.bfloat16) + x = tl.sum(x) + tl.store(out_ptr, x) + + out = torch.empty(1, dtype=torch.int32, device=device) + kernel_dtype[(1, )](out, init=1, in_dtype=tl.int1, out_dtype=None) + assert out[0] == 32 * 32 + + kernel_dtype[(1, )](out, init=1, in_dtype=tl.int1, out_dtype=tl.int1) + assert out[0] == 0 + + kernel_dtype[(1, )](out, init=7, in_dtype=tl.int8, out_dtype=tl.int8) + assert out[0] == (7 * 32 * 32) % 256 + + kernel_dtype[(1, )](out, init=1, in_dtype=tl.int32, out_dtype=None) + assert out[0] == 32 * 32 + + kernel_default_int[(1, )](out) + assert out[0] == 32 * 32 + + out = torch.empty(1, dtype=torch.bfloat16, device=device) + kernel_default_float[(1, )](out) + torch.testing.assert_close(out[0], torch.tensor(32 * 32, dtype=torch.bfloat16, device=device)) + + +@triton.jit +# trivial associative but not commutative function +def get_first_element(a, b): + return a + + +# Compute x_i = a_i * x_{i-1} + b_i +@triton.jit +def linear_recurrence(a1, b1, a2, b2): + return a1 * a2, b1 * a2 + b2 + + +@triton.jit +def cummax(v0, i0, v1, i1): + gt = v0 > v1 + return tl.where(gt, v0, v1), tl.where(gt, i0, i1) + + +@triton.jit +def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur): + return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config) +def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype_str, device) + if dtype_str == 'bfloat16': + if op == 'linear_recurrence': + pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy issues") + numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str + + # triton kernel + @triton.jit + def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + y = tl.load(Y + range_m[:, None] * BLOCK_N + range_n[None, :]) + GENERATE_TEST_HERE + tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z) + + if op == 'cumsum' or op == 'cumprod': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'z = tl.{op}(x, axis={axis}, reverse={reverse})'}) + elif op == 'get_first_element': + kernel = patch_kernel( + kernel, + {'GENERATE_TEST_HERE': f'z = tl.associative_scan(x, axis={axis}, combine_fn={op}, reverse={reverse})'}) + elif op == 'cummax': + rg = "range_m[:, None]" if axis == 0 else "range_n[None, :]" + rg = f"tl.broadcast_to({rg}.to(tl.int64), [BLOCK_M, BLOCK_N])" + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, {rg}), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + elif op == 'roll': + assert op == 'roll' + kernel = patch_kernel( + kernel, { + 'GENERATE_TEST_HERE': + f'_, z, _ = tl.associative_scan((1 + 0* x, 0 * x, x), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + else: + assert op == 'linear_recurrence' + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, y), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + # input + rs = RandomState(17) + if op == 'linear_recurrence' and dtype_str in int_dtypes: + # If the numbers are too large the op will overflow + # We sample numbers in -1, 0, 1 + x = rs.randint(-1, 2, shape, dtype=dtype_str) + y = rs.randint(-1, 2, shape, dtype=dtype_str) + else: + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + # y is just used in linear_recurrence + y = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_in = x + if reverse: + x_in = np.flip(x, axis) + z = np.empty_like(x) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + y_tri = to_triton(y, device=device, dst_type=dtype_str) + if op == 'cumsum' or op == 'cumprod': + numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] + z_ref = numpy_op(x_in, axis=axis).astype(getattr(np, numpy_dtype_str)) + if reverse: + z_ref = np.flip(z_ref, axis) + + elif op == 'cummax': + # NumPy does not have cummax + # z = z.astype(np.int64) + z = z.astype(np.int32) + z_ref = torch.cummax(torch.from_numpy(x_in.copy()), axis=axis).indices.numpy() + if reverse: + z_ref = x_in.shape[axis] - np.flip(z_ref, axis) - 1 + elif op == 'roll': + ROLL = 1 + z_ref = np.roll(x_in.copy(), ROLL, axis=axis) + if axis == 0: + z_ref[:ROLL] = 0 + else: + z_ref[:, :ROLL] = 0 + + if reverse: + z_ref = np.flip(z_ref, axis) + elif op == 'linear_recurrence': + # Simplify to the axis=1 case + x_ref = x.T if axis == 0 else x + y_ref = y.T if axis == 0 else y + if reverse: + x_ref = np.flip(x_ref, 1) + y_ref = np.flip(y_ref, 1) + + result = [] + for x_refi, y_refi in zip(x_ref, y_ref): + li = [] + acc = 0 + for xi, yi in zip(x_refi, y_refi): + acc = xi * acc + yi + li.append(acc) + result.append(li) + z_ref = np.array(result) + if reverse: + z_ref = np.flip(z_ref, 1) + + if axis == 0: + z_ref = z_ref.T + else: + assert op == 'get_first_element' + z_ref = x + if axis == 0: + if reverse: + z_ref[:-1] = x[-1] + else: + z_ref[1:] = x[0] + else: + if reverse: + z_ref[:, :-1] = x[:, -1:] + else: + z_ref[:, 1:] = x[:, 0:1] + + # triton result + # we don't cast the `fp32 = bf16 op bf16` result to bfloat16 to alleviate accuracy issues + z_tri = to_triton(z, device=device) + kernel[(1, )](x_tri, y_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) + + z_tri = to_numpy(z_tri) + # compare + if dtype_str not in int_dtypes: + if op == 'cumprod': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01, atol=1e-3) + else: + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + np.testing.assert_equal(z_ref, z_tri) + + +# --------------- +# test histogram +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]]) +def test_histogram(M, N, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + offset1) + z = tl.histogram(x, N) + bias = tl.full([M, N], 1, dtype=tl.int32) + # check that histogram produces object compatible with broadcasting + biased = z + bias + tl.store(z_ptr + offset2, z) + + torch.manual_seed(17) + x = torch.randint(0, N, (M, ), device=device, dtype=torch.int32) + z = torch.empty(N, dtype=torch.int32, device=device) + # torch.histc does not work when the input type is not float and the device is CPU + # https://github.com/pytorch/pytorch/issues/74236 + # This is a workload by converting the input to float + z_torch = torch.histc(x.float(), bins=N, min=0, max=N - 1) + histogram_kernel[(1, )](x, z, M=M, N=N) + assert (z_torch == z).all() + + +@pytest.mark.parametrize("M, N", [(1, 64), (2, 32), (4, 16), (8, 8), (16, 4), (32, 2), (64, 1)]) +def test_scan_1d(M, N, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def scan_kernel(out_ptr, in_ptr, M: tl.constexpr, N: tl.constexpr): + input = tl.load(in_ptr + tl.arange(0, M)) + output = tl.cumsum(input).reshape([1, M]).broadcast_to([N, M]) + tl.store(out_ptr + tl.arange(0, M * N), output.reshape([M * N])) + + x = torch.randint(-100, 100, (M, ), dtype=torch.int32, device=device) + output = torch.empty(M * N, dtype=torch.int32, device=device) + + scan_kernel[(1, )](output, x, M, N) + + ref = torch.cumsum(x, dim=0).reshape([1, M]).broadcast_to([N, M]).reshape([M * N]) + torch.testing.assert_close(ref.to(torch.int32), output, atol=0, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['sum', 'max', 'min']) +@pytest.mark.parametrize("BLOCK_N", [32, 64, 128]) +@pytest.mark.parametrize("N", [512, 1024, 2048]) +@pytest.mark.parametrize("num_pid_n", [2, 4]) +def test_optimize_thread_locality(op, BLOCK_N, N, num_pid_n, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + start_m = tl.program_id(0) + pid_n = tl.program_id(1) + num_pid_n = tl.num_programs(1) + local = INITIALIZE_PATCH + off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), num_pid_n): + off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * N + off_n[None, :] + x = tl.load(Xs) + local = ACCUMULATE_PATCH + tl.store(Y + off_m * num_pid_n + pid_n, local) + + initialize_patch = { + 'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)', + 'max': 'tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)', + 'min': 'tl.full([BLOCK_M], float("inf"), dtype=tl.float32)', + }[op] + reduce_patch = { + 'sum': 'local + tl.sum(x, axis=1)', + 'max': 'tl.maximum(local, tl.max(x, axis=1))', + 'min': 'tl.minimum(local, tl.min(x, axis=1))', + }[op] + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + }[op] + kernel = patch_kernel(kernel, {'ACCUMULATE_PATCH': reduce_patch, 'INITIALIZE_PATCH': initialize_patch}) + torch.manual_seed(0) + BLOCK_M = 32 + x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device) + y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device) + h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N) + if not is_interpreter(): + assert h.asm['ttgir'].count( + '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" + y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True) + y_tri = numpy_op(y.cpu().numpy(), axis=1, keepdims=True) + np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3) + + +# scan_layouts = [ +# BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +# BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +# BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), +# BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), +# BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), +# BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), +# BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), +# BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), +# BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), +# BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), +# BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), +# ] +scan_layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [1, THREADS_PER_WARP], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [1, THREADS_PER_WARP], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [1, THREADS_PER_WARP], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 2], [1, THREADS_PER_WARP], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) +@pytest.mark.parametrize("src_layout", scan_layouts) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("add_overflow_check", [False, True]) +def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + overflow_check = """ + %17 = arith.extsi %arg2 : i32 to i64 + %18 = arith.extsi %arg3 : i32 to i64 + %19 = arith.addi %17, %18 : i64 + %i32.min = arith.constant -2147483648: i64 + %i32.max = arith.constant 2147483647: i64 + %20 = arith.cmpi slt, %19, %i32.max : i64 + %21 = arith.cmpi sge, %19, %i32.min : i64 + %22 = arith.andi %20, %21 : i1 + tt.assert %22, "overflow detected" : i1 + """ + + ir = f""" + #blocked = {src_layout} + module attributes {{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %7 = tt.broadcast %4 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %8 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #blocked> + %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{ + ^bb0(%arg2: i32, %arg3: i32): + %16 = arith.addi %arg2, %arg3 : i32{overflow_check if add_overflow_check else ""} + tt.scan.return %16 : i32 + }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %14 = tt.broadcast %13 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + tt.store %15, %11 : tensor<{M}x{N}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + + temp_file = tmp_path / "test_scan_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + rs = RandomState(17) + x = rs.randint(-100, 100, (M, N)).astype('int32') + + z = np.zeros((M, N)).astype('int32') + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + kernel[(1, 1, 1)](x_tri, z_tri) + + z_ref = np.cumsum(x, axis=axis) + + np.testing.assert_equal(z_ref, z_tri.cpu().numpy()) + + +layouts = [ + # BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + # BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP // 1], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP // 1], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[2, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 16, 16]), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=True), + WmmaLayout(version=1, warps_per_cta=[4, 1]), + WmmaLayout(version=1, warps_per_cta=[1, 4]), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 4], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + DotOperandLayout(parent=MmaLayout([3, 0], [8, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2), + # FIXME: Do not enable these tests until the SLPVectorizor problem with nvptx target has been resolved + # SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 1, 4], [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2])), + # SliceLayout(dim=0, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 4, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2])), + SliceLayout(dim=0, parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8])), + SliceLayout( + dim=1, parent=DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), + op_idx=1, k_width=2)), + LinearLayout(register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], + [0, 8]], warp=[[32, 0], [0, 32]], + block=[]), +] + + +@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128], [32, 32], [16, 16]]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) +@pytest.mark.parametrize("dtype_str,add_overflow_check", [("int32", False), ("int32", True), ("float32", False), + ("float16", False)]) +@pytest.mark.parametrize("reduce_op", ["sum", "max"]) +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_overflow_check, reduce_op, device, + tmp_path: pathlib.Path): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if isinstance(src_layout, + (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): + pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") + if is_hip() and isinstance(src_layout, MfmaLayout) and ((M, N) == (128, 128)): + pytest.skip("Skipping test because it runs out of shared memory") + if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024: + pytest.skip("Skipping sum reduction on float16 due to accuracy issues") + if is_hip() and isinstance(src_layout, LinearLayout): + pytest.skip("FIXME: LinearLayout not supported on HIP") + + if isinstance(src_layout, MmaLayout) and src_layout.version == 3: + src_layout.instr_shape[2] = 16 if dtype_str == "float16" else 8 + + overflow_check = """ + %18 = arith.extsi %arg3 : i32 to i64 + %19 = arith.extsi %arg4 : i32 to i64 + %20 = arith.addi %18, %19 : i64 + %i32.min = arith.constant -2147483648: i64 + %i32.max = arith.constant 2147483647: i64 + %21 = arith.cmpi slt, %20, %i32.max : i64 + %22 = arith.cmpi sge, %20, %i32.min : i64 + %23 = arith.andi %21, %22 : i1 + tt.assert %23, "overflow detected" : i1 + """ + + ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] + arith_op = { + "max": {"int32": "arith.maxsi", "float32": "arith.maxnumf", "float16": "arith.maxnumf"}, # + "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"} + }[reduce_op][dtype_str] + numpy_op = {"max": np.max, "sum": np.sum}[reduce_op] + rdims_1d = f"{N}" if axis == 0 else f"{M}" + rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" + store_range = "%7" if axis == 0 else "%1" + warps = warps_per_cta(src_layout, [M, N]) + num_warps = np.prod(warps) + # blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, num_warps // 4], [0, 1], [1, 1], [1, 1], [0, 1]) + blocked = BlockedLayout([1, 1], [1, THREADS_PER_WARP], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]) + one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [num_warps], [0], [1], [1], [0]) + + expanded_shape = f"1x{N}" if axis == 0 else f"{M}x1" + other_axis = 1 - axis + epilogue = { + "reduce1d": + f""" + %14 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked> + %16 = {GPU_DIALECT}.convert_layout %13 : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> + %17 = tt.expand_dims %16 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> -> tensor<{rdims_2d}x{ty}, #blocked> + tt.store %15, %17 : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + tt.return + }} + }} + """, "reduce2d": + f""" + %14 = "tt.reduce"(%13) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} + tt.reduce.return %17 : {ty} + }}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty} + tt.store %arg2, %14 : !tt.ptr<{ty}> + tt.return + }} + }} + """, "expand_reduce2d": + f""" + %14 = tt.expand_dims %13 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{expanded_shape}x{ty}, #src> + %15 = "tt.reduce"(%14) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} + tt.reduce.return %17 : {ty} + }}) {{axis = {other_axis} : i32}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>>) + %16 = ttg.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> + %17 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.store %17, %16 : tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.return + }} + }} + """ + }[epilogue_kind] + + ir = f""" + #blocked = {blocked} + #src = {src_layout} + #one_d_layout = {one_d_layout} + module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = tt.splat %arg1 : i32 -> tensor<{M}x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked> + %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked> + %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> + %7 = tt.expand_dims %6 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %9 = tt.broadcast %7 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src> + %13 = "tt.reduce"(%12) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} + tt.reduce.return %17 : {ty} + }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> + """ + epilogue + + temp_file = tmp_path / "test_reduce_layouts.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) + reduce2d = 'reduce2d' in epilogue_kind + z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1) + z = np.zeros(z_shape).astype(dtype_str) + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri) + z_ref = numpy_op(x) if reduce2d else numpy_op(x, axis=axis, keepdims=True) + + if dtype_str == 'float16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) +] + + +@pytest.mark.parametrize("M", [32, 64, 128, 256]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +def test_store_op(M, src_layout, device, tmp_path: pathlib.Path): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 1 : i32}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xf32, #src> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %6 = tt.expand_dims %5 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #src> + %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> + tt.store %8, %4 : tensor<{M}x1x!tt.ptr, #src> + tt.return + }} + }} + """ + + temp_file = tmp_path / "test_store_op.ttgir" + temp_file.write_text(ir) + store_kernel = triton.compile(str(temp_file)) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, 1)).astype('float32') + y = np.zeros((M, 1), dtype='float32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + + pgm = store_kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +layouts = [ + # TODO (lixun): Add MfmaLayout + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), +] + + +@pytest.mark.parametrize("M", [64, 128, 256]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("src_dim", [0, 1]) +@pytest.mark.parametrize("dst_dim", [0, 1]) +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path: pathlib.Path): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + ir = f""" + #dst = {dst_layout} + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %7 = {GPU_DIALECT}.convert_layout %3 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.store %6, %7 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.return + }} + }} + """ + temp_file = tmp_path / "test_convert1d.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, )).astype('int32') + y = np.zeros((M, ), dtype='int32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + pgm = kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@triton.jit +def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = weight_2 / new_weight + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + # [HIP] TO DO: some tests are flaky with the layout, so turn off them for now. + # BlockedLayout([1, 4], [1, THREADS_PER_WARP], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + # BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + # BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]) + BlockedLayout([1, 4], [THREADS_PER_WARP, 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]) +] + + +@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("op", ["sum", "max"]) +@pytest.mark.parametrize("first_axis", [0, 1]) +def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathlib.Path): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + op_str = "" + if op == "sum": + op_str = """ + %13 = arith.addi %arg2, %arg3 : i32 + tt.reduce.return %13 : i32""" + elif op == "max": + op_str = """ + %13 = arith.cmpi "sgt", %arg2, %arg3 : i32 + %14 = arith.select %13, %arg2, %arg3 : i32 + tt.reduce.return %14 : i32""" + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src> + %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %5 = tt.broadcast %2 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %6 = tt.broadcast %4 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #src> + %11 = "tt.reduce"(%10) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = {first_axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>> + %12 = "tt.reduce"(%11) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = 0 : i32}} : (tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>) -> i32 + tt.store %arg1, %12 : !tt.ptr + tt.return + }} + }} + """ + temp_file = tmp_path / "test_chain_reduce.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, N)).astype('int32') + + z = np.zeros((1, )).astype('int32') + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, z_tri) + if op == "sum": + z_ref = np.sum(x) + elif op == "max": + z_ref = np.max(x) + + np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@pytest.mark.interpreter +def test_generic_reduction(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): + xindex = tl.arange(0, BLOCK) + x = tl.load(X + xindex) + mean = x + m2 = tl.zeros_like(x) + weight = tl.full(x.shape, 1, x.dtype) + (mean, m2, weight) = tl.reduce((mean, m2, weight), 0, _welford_combine) + tl.store(out_mean, mean) + tl.store(out_var, m2 / weight) + + SIZE = 512 + x = torch.rand(SIZE, device=device) + out_mean = torch.empty((), device=device) + out_var = torch.empty((), device=device) + + var_mean_kernel[(1, )](x, out_mean, out_var, BLOCK=SIZE) + # TODO(chongzhou.yang): torch.var_mean is not yet supported in gcu400, fallback it to cpu until 2025.12.31 + expect_var, expect_mean = torch.var_mean(x.to('cpu'), dim=0, correction=0) + torch.testing.assert_close(out_mean, expect_mean.to('gcu')) + torch.testing.assert_close(out_var, expect_var.to('gcu')) + + +# --------------- +# test permute +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) + # TODO: bfloat16 + for dtype in ['float8e4b15', 'float16', 'float32'] + for shape in [(64, 64), (128, 128)] + for perm in [(1, 0)]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_permute(dtype_str, shape, perm, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + if dtype_str == "float8e4b15" and (is_hip() or (is_cuda() and torch.cuda.get_device_capability() >= (9, 0))): + pytest.skip("float8e4b15 not supported on ROCm or CUDA >= 9.0") + if is_hip(): + if shape == (128, 128) and dtype_str == 'float32': + pytest.skip("TODO Out of LDS for float32 with shape 128x128") + + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + tl.store(Zs, tl.load(Xs)) + + # input + x = numpy_random(shape, dtype_str=dtype_str) + # triton result + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), + x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), + z_tri_contiguous.stride(1), BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + # numpy result + if dtype_str == 'float8e4b15': + ty = tl.float8e4b15 + z_ref = serialize_fp8(deserialize_fp8(x, ty).T.copy(), ty) + z_tri = z_tri.base + z_tri_contiguous = z_tri_contiguous.base + else: + z_ref = x.transpose(*perm) + # compare + np.testing.assert_allclose(to_numpy(z_tri), z_ref) + np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref) + + if not is_cuda(): + return + + # parse ptx to make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + ptx = pgm_contiguous.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 4), (16, 16)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1]))) +def test_trans_2d(dtype_str, shape, perm, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1: tl.constexpr, + ou_shape2: tl.constexpr, trans1: tl.constexpr, trans2: tl.constexpr): + in_offs = tl.arange(0, in_shape1)[:, None] * in_shape2 + tl.arange(0, in_shape2)[None, :] + ou_offs = tl.arange(0, ou_shape1)[:, None] * ou_shape2 + tl.arange(0, ou_shape2)[None, :] + tl.store(Out + ou_offs, tl.permute(tl.load(In + in_offs), (trans1, trans2))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 4)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1, 2, 3]))) +def test_trans_4d(dtype_str, shape, perm, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(In, Out, # + in_shape1: tl.constexpr, in_shape2: tl.constexpr, in_shape3: tl.constexpr, in_shape4: tl.constexpr, + ou_shape1: tl.constexpr, ou_shape2: tl.constexpr, ou_shape3: tl.constexpr, ou_shape4: tl.constexpr, + trans1: tl.constexpr, trans2: tl.constexpr, trans3: tl.constexpr, trans4: tl.constexpr): + in_ptr = tl.make_block_ptr( + base=In, + shape=(in_shape1, in_shape2, in_shape3, in_shape4), + strides=(in_shape4 * in_shape3 * in_shape2, in_shape4 * in_shape3, in_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(in_shape1, in_shape2, in_shape3, in_shape4), + order=(3, 2, 1, 0), + ) + out_ptr = tl.make_block_ptr( + base=Out, + shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + strides=(ou_shape4 * ou_shape3 * ou_shape2, ou_shape4 * ou_shape3, ou_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + order=(3, 2, 1, 0), + ) + tl.store(out_ptr, tl.load(in_ptr).permute((trans1, trans2, trans3, trans4))) + + ## Just support bf16,fp16,fp32,i32 data type in topsAtenArrang + # input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + input = torch.arange(math.prod(shape), device=device).to(getattr(torch, dtype_str)).reshape(shape) + + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + # num_warps in gcu4xx can not exceed 4 + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm, num_warps=4) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# --------------- +# test dot +# --------------- + + +def convert_fp8_to_fp32(x, device, dtype_str): + if dtype_str == 'float8e4nv': + return torch.tensor(x, device=device).view(torch.float8_e4m3fn).to(torch.float32) + elif dtype_str == 'float8e5': + return torch.tensor(x, device=device).view(torch.float8_e5m2).to(torch.float32) + elif dtype_str == 'float8e4b8': + return torch.tensor(x, device=device).view(torch.float8_e4m3fnuz).to(torch.float32) + elif dtype_str == 'float8e5b16': + return torch.tensor(x, device=device).view(torch.float8_e5m2fnuz).to(torch.float32) + assert "Unsupported float8 dtype" + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +def get_test_dot_base_cases(): + return [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None) + for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] + for input_precision in ['tf32', 'tf32x3', 'ieee'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')] + if not (input_precision != 'ieee' and (in_dtype in ['float16']))] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +def get_test_dot_mixed_sizes_cases(): + available_kpack = [1, 2 if is_hip() else 1] + available_precision = ["tf32" if is_cuda() else "ieee"] + return [ + (*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack, None) + for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], + [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] + for input_precision in available_precision + for col_a in [True, False] + for col_b in [True, False] + for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', + 'float32'), ('float32', 'float32')] + for kpack in available_kpack + ] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #2370 +def get_test_dot_transposed_op_base_cases(): + return [(64, 64, 64, 4, col_a, col_b, 'none', 'ieee', 'float32', 'float32', 1, None) + for col_a in [True, False] + for col_b in [True, False]] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# Introduced in #2750 +def get_test_dot_h100_shortcut_cases(): + return [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32', 1, None)] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #3908 +def get_test_dot_mfma_edge_cases(): + if not is_hip_cdna(): + return [] + return [(16, 16, 8, 4, False, False, 'None', 'ieee', 'float32', 'float32', 1, None), + (32, 16, 8, 4, False, False, 'None', 'ieee', 'float16', 'float16', 1, None)] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #3370 +def get_test_dot_fp8_output_cases(): + return [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1, None) + for float8_type in ["float8e5", "float8e4nv"]] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #5406 +def get_test_dot_small_k_mfma_cases(): + if not is_hip_cdna(): + return [] + return [(32, 32, k_size, 4, False, False, 'None', 'ieee', in_dtype, out_dtype, 1, mma_nonk_size) + for k_size in [1, 2, 4, 8] + for in_dtype, out_dtype in [('float16', 'float32'), ('int8', 'int32')] + for mma_nonk_size in mma_nonk_sizes] + + +# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size +# introduced in #4516 +def get_test_dot_small_mn_fma_cases(): + return [(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1, None) + for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]] + + +def get_test_dot_double_rate_cases(): + if not is_hip_cdna(): + return [] + return [(32, 32, 16, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None), + (32, 32, 16, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None), + (16, 16, 32, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None), + (16, 16, 32, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None)] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size", + get_test_dot_double_rate_cases() + \ + get_test_dot_base_cases() + \ + get_test_dot_mixed_sizes_cases() + \ + get_test_dot_transposed_op_base_cases() + \ + get_test_dot_h100_shortcut_cases() + \ + get_test_dot_mfma_edge_cases() + \ + get_test_dot_fp8_output_cases() + \ + get_test_dot_small_k_mfma_cases() + \ + get_test_dot_small_mn_fma_cases()) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size, + num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_interpreter(): + if in_dtype == 'bfloat16': + pytest.skip("bfloat16 is not supported in the interpreter") + else: + if not is_hip() and (M < 16 or N < 16 or K < 16): + pytest.skip("small dots are supported only on HIP at the moment") + if is_cuda(): + capability = torch.cuda.get_device_capability() + + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + if capability[0] < 8: + if capability[1] == 0 and in_dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 75") + if input_precision != "ieee": + pytest.skip("Only test tf32 on devices with sm >= 80") + if capability[0] == 7: + if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: + pytest.skip("shared memory out of resource") + if out_dtype == 'float16': + # TODO: support out_dtype=float16 for tl.dot on V100 + pytest.skip("Only test out_dtype=float16 on devices with sm >=80") + if capability[0] < 9 and in_dtype == 'float8e4nv': + pytest.skip("float8e4nv not supported on sm <= 80") + + if is_hip(): + if in_dtype in ("float8e5", "float8e4nv") and not is_hip_mi350(): + pytest.skip(f"{in_dtype} only supported on mi350") + if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_mi300(): + pytest.skip(f"{in_dtype} only supported on mi300") + if not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_mi300())): + pytest.skip(f"{input_precision} not supported on HIP") + if kpack == 2 and in_dtype == 'int8' and K < 64: + pytest.skip("kpack too large for K") + if not is_hip() and kpack == 2: + pytest.skip("Skip duplicated tests on nv path") + + torch.backends.cuda.matmul.allow_tf32 = input_precision == "tf32" + + if num_ctas > 1 and in_dtype == 'int8': + # FIXME: mma v2 with num_ctas > 1 does not work + pytest.skip() + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, stride_wl, Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, + ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, INPUT_PRECISION: tl.constexpr, DO_SOFTMAX: tl.constexpr, + CHAIN_DOT: tl.constexpr, COL_A: tl.constexpr, COL_B: tl.constexpr, out_dtype: tl.constexpr = tl.float32): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_l = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk + Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn + Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + x = tl.load(Xs) + y = tl.load(Ys) + z = tl.dot(x, y, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + if ADD_MATRIX: + z += tl.load(Zs) + if ADD_ROWS: + ZRs = Z + off_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = Z + off_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z.to(tl.float32)).to(max.dtype) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + w = tl.load(Ws) + z = tl.dot(z.to(w.dtype), w, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + tl.store(Zs, z) + + # input + rs = RandomState(17) + if col_a: + x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T + else: + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + if col_b: + y = numpy_random((N, K), dtype_str=in_dtype, rs=rs).T + else: + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + w = numpy_random((N, N), dtype_str=in_dtype, rs=rs) + if 'int' not in in_dtype and 'float8' not in in_dtype: + x *= .1 + y *= .1 + if in_dtype == 'float32' and input_precision == "tf32": + x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') + y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') + w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') + x_tri = to_triton(x, device=device, dst_type=in_dtype) + y_tri = to_triton(y, device=device, dst_type=in_dtype) + w_tri = to_triton(w, device=device, dst_type=in_dtype) + # triton result + if out_dtype == 'int8': + z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs) + else: + z = 1 + numpy_random((M, N), dtype_str=in_dtype, rs=rs) * .1 + + # Workaround: gcu400 does not support fp64 type + if in_dtype == 'float8e4nv' or in_dtype == 'float8e5': + z = z.astype(np.float32) + z_tri = to_triton(z, device=device) + if epilogue == 'trans': + z_tri = torch.as_strided(z_tri, (M, N), [1, M]) + + if out_dtype == 'int8': + out_dtype = tl.int8 + elif out_dtype == 'float16' and epilogue != 'softmax': + # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will + # fail with the following error: 'llvm.fmul' op requires the same type + # for all operands and results + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + kern_kwargs = { + 'COL_A': col_a, 'COL_B': col_b, 'BLOCK_M': M, 'BLOCK_K': K, 'BLOCK_N': N, 'ADD_MATRIX': + epilogue == 'add-matrix', 'ADD_ROWS': epilogue == 'add-rows', 'ADD_COLS': epilogue == 'add-cols', 'DO_SOFTMAX': + epilogue == 'softmax', 'CHAIN_DOT': epilogue == 'chain-dot', 'INPUT_PRECISION': input_precision, 'num_warps': + num_warps, 'num_ctas': num_ctas, 'out_dtype': out_dtype + } + + if is_hip(): + kern_kwargs['kpack'] = kpack + if mma_nonk_size is not None: + kern_kwargs['matrix_instr_nonkdim'] = mma_nonk_size + + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, + w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs) + + # torch result + if in_dtype == 'int8': + z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32) + elif 'float8' in in_dtype: + x = convert_fp8_to_fp32(x, device, in_dtype) + y = convert_fp8_to_fp32(y, device, in_dtype) + z_ref = to_numpy(torch.matmul(x, y)) + else: + z_ref = np.matmul(x, y) + + if epilogue == 'add-matrix': + z_ref += z + if epilogue == 'add-rows': + z_ref += z[:, 0][:, None] + if epilogue == 'add-cols': + z_ref += z[0, :][None, :] + if epilogue == 'softmax': + num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True)) + denom = np.sum(num, axis=-1, keepdims=True) + z_ref = num / denom + if epilogue == 'chain-dot': + if 'float8' in in_dtype: + # Reduce z_ref's precision to fp8 to match the kernel behavior + if in_dtype == 'float8e4nv': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fn) + elif in_dtype == 'float8e5': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2) + elif in_dtype == 'float8e4b8': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fnuz) + elif in_dtype == 'float8e5b16': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2fnuz) + else: + assert "Unsupported float8 dtype" + z_ref = to_numpy(z_fp8.to(torch.float32)) + w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype)) + z_ref = np.matmul(z_ref, w) + # compare + if in_dtype == 'float32': + # XXX: Somehow there's a larger difference when we use float32 + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + elif out_dtype == tl.float16 or in_dtype == 'bfloat16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + # added atol, to loose precision for float16xfloat16->float32 case + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + if not is_cuda(): + return + # make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): + # XXX: skip small sizes because they are not vectorized + assert 'ld.global.v4' in ptx + if 'float8' in in_dtype: + assert 'st.global.v2' in ptx + else: + assert 'st.global.v4' in ptx + + is_tcgen5 = (capability[0] == 10) and (num_warps % 4) == 0 and (M % 64) == 0 and (N % 8) == 0 + + if in_dtype == 'float32' and input_precision != "ieee": + if is_tcgen5: + assert re.search(r'tcgen05.mma.cta_group::1.kind::tf32', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float32: + if is_tcgen5: + assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) + elif capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float16: + if is_tcgen5: + assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) + elif capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx) + elif in_dtype == 'int8': + if capability[0] == 7 and capability[1] == 5: # Turing + assert 'mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32' in ptx + else: + assert 'wgmma.mma_async.sync.aligned' in ptx or\ + 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx + elif in_dtype == "float8e5" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2' in ptx + elif in_dtype == "float8e4nv" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx + + +# @pytest.mark.parametrize("M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, num_warps, mma, kpack", +# [(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, 4, mma, kpack) +# for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) +# for col_a, col_b in itertools.product([True, False], repeat=2) +# for rhs_scale in [False, True] +# for mxfp_type in ["e2m1", "e4m3", "e5m2"] +# for normal_type in ["e4m3", "e5m2", "bf16", "fp16"] +# for mma in (mma_nonk_sizes if is_hip() else [16]) +# for kpack in ([1, 2] if is_hip() else [1])]) +# def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, num_warps, mma, kpack, device): +# check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) +# if is_cuda(): +# cc = torch.cuda.get_device_capability() +# if cc < (8, 9): +# pytest.skip("float8e4nv not supported on CUDA < 8.9") +# if is_hip(): +# if not is_hip_cdna(): +# pytest.skip("scaled_dot only implemented for HIP CDNA") +# if "e4m3" in (mxfp_type, normal_type): +# if not (is_hip_mi300() or is_hip_mi350()): +# pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) only implemented for MI300 and MI350") +# if mma == 16 and K == 64: +# pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot") + +# @triton.jit +# def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out, +# BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, +# type_b: tl.constexpr): +# DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1 +# DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1 +# PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A +# PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B +# a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, +# PACKED_BLOCK_K_A)[None, :] * stride_a1 +# b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, +# BLOCK_N)[None, :] * stride_b1 + +# a = tl.load(a_ptr) +# b = tl.load(b_ptr) +# SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 +# if a_scale is not None: +# scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, +# SCALE_BLOCK_K)[None, :] +# a_scale = tl.load(scale_a_ptr) +# if b_scale is not None: +# scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, +# SCALE_BLOCK_K)[None, :] +# b_scale = tl.load(scale_b_ptr) +# c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b) +# out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] +# tl.store(out_ptr, c.to(tl.bfloat16)) + +# @triton.jit +# def mxfp_upcast_kernel( +# x_ptr, +# scale_ptr, +# mxfp_ptr, +# N, +# e_bits: tl.constexpr, +# m_bits: tl.constexpr, +# to_type: tl.constexpr, +# BLOCK_SIZE: tl.constexpr, +# ): +# # x.shape == (N, 32) for fp8 or (N, 16) for fp4 +# # scale.shape == (N,) +# # out.shape == (N, 32) +# is_fp8: tl.constexpr = e_bits + m_bits == 7 +# # fp8: BLOCK_SIZE -> BLOCK_SIZE // 32, 32 +# # fp4: BLOCK_SIZE // 2 -> BLOCK_SIZE // 32 , 16 +# PARALLEL_DIM: tl.constexpr = BLOCK_SIZE // 32 +# LAST_DIM: tl.constexpr = 32 if is_fp8 else 16 +# LOAD_SIZE: tl.constexpr = LAST_DIM * PARALLEL_DIM + +# offsets = (tl.program_id(0) * LOAD_SIZE + tl.arange(0, PARALLEL_DIM)[:, None] * LAST_DIM + +# tl.arange(0, LAST_DIM)[None, :]) +# x = tl.load(x_ptr + offsets, mask=offsets < N * LAST_DIM) + +# offsets = tl.program_id(0) * PARALLEL_DIM + tl.arange(0, PARALLEL_DIM)[:, None] +# scale = tl.load(scale_ptr + offsets, mask=offsets < N) +# tl.static_assert(scale.dtype == tl.uint8) +# tl.static_assert(x.dtype == tl.uint8) + +# if to_type == tl.bfloat16: +# upcasted_scale = (scale.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True) +# else: +# tl.static_assert(to_type == tl.float16) +# scale_fp32 = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) +# upcasted_scale = scale_fp32.to(tl.float16) + +# to_e_bits: tl.constexpr = 8 if to_type == tl.bfloat16 else 5 +# to_m_bits: tl.constexpr = 7 if to_type == tl.bfloat16 else 10 +# if is_fp8: +# if e_bits == 5 and m_bits == 2: +# x_f8 = x.to(tl.float8e5, bitcast=True) +# upcasted_x = x_f8.to(to_type) +# # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! +# non_finite_mask: tl.constexpr = ((1 << e_bits) - 1) << m_bits +# non_finite_mask_16bit: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits +# upcasted_x = tl.where( +# x & non_finite_mask == non_finite_mask, +# (upcasted_x.to(tl.uint16, bitcast=True) | non_finite_mask_16bit).to(to_type, bitcast=True), +# upcasted_x, +# ) +# else: +# tl.static_assert(e_bits == 4 and m_bits == 3) +# x_f8 = x.to(tl.float8e4nv, bitcast=True) +# upcasted_x = x_f8.to(to_type) +# else: +# to_bias: tl.constexpr = 127 if to_type == tl.bfloat16 else 15 +# to_point5: tl.constexpr = 16128 if to_type == tl.bfloat16 else 0x3800 +# # e2m1 +# em0 = x & 0x7 +# em1 = x & 0x70 +# x0 = (em0.to(tl.uint16) << (to_m_bits - 1)) | ((x & 0x8).to(tl.uint16) << 12) +# x1 = (em1.to(tl.uint16) << (to_m_bits - 1 - 4)) | ((x & 0x80).to(tl.uint16) << 8) +# # Three cases: +# # 1) x is normal and non-zero: Correct bias +# x0 = tl.where((em0 & 0x6) != 0, x0 + ((to_bias - 1) << to_m_bits), x0) +# x1 = tl.where((em1 & 0x60) != 0, x1 + ((to_bias - 1) << to_m_bits), x1) +# # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16 +# x0 = tl.where(em0 == 0x1, to_point5 | (x0 & 0x8000), x0) +# x1 = tl.where(em1 == 0x10, to_point5 | (x1 & 0x8000), x1) +# # 3) x is zero, do nothing +# upcasted_x = tl.interleave(x0, x1).to(to_type, bitcast=True) +# # Multiplication preserves infs and NaNs in upcasted_x +# mxfp = upcasted_x * upcasted_scale +# # If scale is NaN, we encode it as an inf, so we need to correct for that +# mxfp = tl.where(scale == 0xFF, float("nan"), mxfp) + +# offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) +# tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32) + +# def dot_scale_ref(x, scale_x, y, scale_y, type_x, type_y): + +# def upcast(v, scale, type, comp_dtype, transposed): +# if scale is None: +# type = { +# "e4m3": torch.float8_e4m3fn, +# "e5m2": torch.float8_e5m2, +# "bf16": torch.bfloat16, +# "fp16": torch.float16, +# }[type] +# return v.view(type).to(comp_dtype) +# e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type] +# # Packing is always on the K dimension so we transpose before upcasting then transpose back. +# if transposed: +# v = v.mT.contiguous() +# v = v.contiguous() +# v_upcast = v.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) +# N = v_upcast.numel() +# BLOCK_SIZE = 512 +# grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) +# comp_dtype = tl.float16 if comp_dtype == torch.float16 else tl.bfloat16 +# mxfp_upcast_kernel[grid](v, scale, v_upcast, scale.numel(), e_bits, m_bits, comp_dtype, BLOCK_SIZE, +# num_warps=num_warps) +# assert v_upcast.isfinite().all() +# if transposed: +# v_upcast = v_upcast.mT +# return v_upcast + +# # Upcast to fp16 if one of the input is fp16 +# comp_dtype = torch.float16 if "fp16" in (type_x, type_y) else torch.bfloat16 + +# x_upcast = upcast(x, scale_x, type_x, comp_dtype, False) +# y_upcast = upcast(y, scale_y, type_y, comp_dtype, True) + +# class AccumulateInFp32: + +# def __enter__(self): +# self.prev_value = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction +# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False + +# def __exit__(self, exc_type, exc_val, exc_tb): +# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value + +# with AccumulateInFp32(): +# return torch.matmul(x_upcast, y_upcast) + +# comp_dtype = torch.float16 if normal_type == "fp16" else torch.bfloat16 +# # The max exponent we use to initialize data in the x/y and associated scale tensor to avoid +# # overflow when scaling. +# comp_dtype_max_exp = 6 if normal_type == "fp16" else 15 + +# torch.manual_seed(0) + +# def make_arg(shape, ty, col_major=False): +# if col_major: +# shape = shape[:-2] + (shape[-1], shape[-2]) +# if ty == "bf16" or ty == "fp16": +# ret = torch.randn(shape, dtype=comp_dtype, device=device) +# # Clamp to avoid relative error issues +# ret.clamp_(-2**comp_dtype_max_exp, 2**comp_dtype_max_exp - 1) +# else: +# if is_hip_mi350(): +# # On other chips, the A/B operands are upcasted to fp16/bf16 +# # before matmul, which has larger range to avoid overflow. +# # On MI350, we use the V_MFMA_*_F8F6F4 instructions to +# # directly calculate matmul on F8F6F4 data. So we need +# # to narrow down the range of input to avoid overflow. +# ret = torch.randint(20, 40, shape, dtype=torch.uint8, device=device) +# else: +# ret = torch.randint(256, shape, dtype=torch.uint8, device=device) +# if col_major: +# ret = ret.mT +# return ret + +# type_a = normal_type if rhs_scale else mxfp_type +# type_b = mxfp_type if rhs_scale else normal_type + +# DIV_FACTOR_A = 2 if type_a == "e2m1" else 1 +# DIV_FACTOR_B = 2 if type_b == "e2m1" else 1 +# x = make_arg((M, K // DIV_FACTOR_A), type_a, col_major=col_a) +# y = make_arg((K // DIV_FACTOR_B, N), type_b, col_major=col_b) + +# min_scale, max_scale = (0, 142) if comp_dtype == torch.bfloat16 else (124, 131) +# scale_x = torch.randint(min_scale, max_scale + 1, (M, K // 32), dtype=torch.uint8, device=device) +# scale_y = torch.randint(min_scale, max_scale + 1, (N, K // 32), dtype=torch.uint8, device=device) +# if rhs_scale: +# scale_x = None +# else: +# scale_y = None + +# def make_finite(x, dtype): +# # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and +# # Fp8E5M2_to_Bf16 doesn't preserve NaNs (fixme) +# if dtype not in ("e5m2", "e4m3"): +# return x +# if dtype == "e5m2" and comp_dtype == torch.float16: +# x = x & 0xB +# mask = 0x7C if dtype == "e5m2" else 0x7F +# finite = torch.arange(x.numel(), device=device, dtype=torch.uint8).reshape_as(x) % mask +# x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x) +# x.copy_(x_finite) +# return x + +# x = make_finite(x, type_a) +# y = make_finite(y, type_b) +# kernel_kwargs = {"num_warps": num_warps} +# if is_hip(): +# kernel_kwargs["kpack"] = kpack +# kernel_kwargs["matrix_instr_nonkdim"] = mma +# z = x.new_empty((M, N), dtype=comp_dtype) +# pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b, +# **kernel_kwargs) +# z_ref = dot_scale_ref(x, scale_x, y, scale_y, type_a, type_b) +# # Bigger tolerance for AMD MI200 devices. +# # MI200 devices use reduced precision fp16 and bf16 and flush input and output denormal values +# # to zero. Detailed info is at: +# # https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices +# atol = 2e-4 if is_hip_mi200() else 1e-5 +# rtol = 2e-2 if is_hip_mi200() else 1e-2 +# torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) + +# # make sure ld/st are vectorized +# if is_cuda(): +# ptx = pgm.asm['ptx'] +# if (max(M, N) * K) // (num_warps * 32) >= 4: +# assert 'ld.global.v4' in ptx +# if M * N // (num_warps * 32) >= 4: +# assert 'st.global.v4' in ptx +# assert (re.search(r'(mma|wgmma.mma_async).sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.(f|bf)16.(f|bf)16', ptx) +# or "tcgen05.mma.cta_group::1.kind::f16" in ptx) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str", + [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) + for B in [1, 2, 4, 8] + for num_warps in [1, 2, 4, 8, 16] + for BLOCK_M, BLOCK_N in [(32, 32)] + for M, N, K in [(64, 64, 64), (32, 32, 32)] + for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), + ('float16', 'float32'), ('float32', 'float32')]] + + # Large block sizes + [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')] + + # Small block sizes + [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) + for B in [1, 2, 8] + for num_warps in [1, 2, 4] + for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)] + for M, N, K in [(32, 32, 32)] + for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]]) +def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str, device): + if is_hip() | (device == 'gcu'): + # hip does not support tf32 precision, so use ieee for all tests + input_precision = "ieee" + arch = triton.runtime.driver.active.get_current_target().arch + if "gfx11" in arch or "gfx12" in arch: + if in_dtype_str == "float32": + pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") + if out_dtype_str == "float16": + pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") + else: + input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee" + if not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16): + pytest.skip("small dots are supported only on HIP at the moment") + + if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32": + if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties( + triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 131072: + pytest.skip( + "Skipping tests with B = 8, M = 64, in_type = float32, out_type = float32 due to insufficient shared memory (less than 128 KB per SM) on this GPU." + ) + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel( + q_ptr, + k_ptr, + o_ptr, + stride_qb, + stride_qm, + stride_qk, + stride_kb, + stride_kk, + stride_kn, + stride_ob, + stride_om, + stride_on, + BLOCK_B: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + INPUT_PRECISION: tl.constexpr, + out_dtype: tl.constexpr = tl.float32, + ): + startm = tl.program_id(0) * BLOCK_M + startn = tl.program_id(1) * BLOCK_N + offs_b = tl.arange(0, BLOCK_B) + offs_m = startm + tl.arange(0, BLOCK_M) + offs_n = startn + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + q_ptrs = q_ptr + offs_b[:, None, None] * stride_qb + offs_m[None, :, None] * stride_qm + offs_k[ + None, None, :] * stride_qk + k_ptrs = k_ptr + offs_b[:, None, None] * stride_kb + offs_k[None, :, None] * stride_kk + offs_n[ + None, None, :] * stride_kn + q = tl.load(q_ptrs) + k = tl.load(k_ptrs) + qk = tl.dot(q, k, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + o_ptrs = o_ptr + offs_b[:, None, None] * stride_ob + offs_m[None, :, None] * stride_om + offs_n[ + None, None, :] * stride_on + tl.store(o_ptrs, qk) + + if out_dtype_str == 'int8': + out_dtype = tl.int8 + elif out_dtype_str == 'float16': + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + rs = RandomState(17) + x = numpy_random((B, M, K), dtype_str=in_dtype_str, rs=rs) + y = numpy_random((B, K, N), dtype_str=in_dtype_str, rs=rs) + if in_dtype_str == 'int8': + out = numpy_random((B, M, N), dtype_str='int32', rs=rs) + else: + if is_hip() and (BLOCK_M < 16 or BLOCK_N < 16) and out_dtype_str == 'float16': + # float16 accumulator in FMA dot loose precision too fast + x *= 0.1 + y *= 0.1 + out = numpy_random((B, M, N), dtype_str=out_dtype_str, rs=rs) + + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + out_tri = to_triton(out, device=device) + + BLOCK_B = B + BLOCK_K = K + + grid = ( + triton.cdiv(M, BLOCK_M), + triton.cdiv(N, BLOCK_N), + ) + kernel[grid]( + x_tri, + y_tri, + out_tri, + x_tri.stride(0), + x_tri.stride(1), + x_tri.stride(2), + y_tri.stride(0), + y_tri.stride(1), + y_tri.stride(2), + out_tri.stride(0), + out_tri.stride(1), + out_tri.stride(2), + BLOCK_B=BLOCK_B, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + INPUT_PRECISION=input_precision, + out_dtype=out_dtype, + num_warps=num_warps, + ) + + if in_dtype_str == 'int8': + out_ref = np.matmul(x.astype(np.float32), y.astype(np.float32)).astype(np.int32) + else: + out_ref = np.matmul(x, y) + np.testing.assert_allclose(out_ref, to_numpy(out_tri), rtol=0.01, atol=1e-2) + + +@pytest.mark.parametrize('in_dtype', ['float32']) +def test_dot_mulbroadcasted(in_dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + pytest.skip("Requires sm >= 80 to run") + + @triton.jit + def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.constexpr, BN: tl.constexpr, + BK: tl.constexpr): + pidn = tl.program_id(1) + pidm = tl.program_id(0) + offm = tl.arange(0, BM)[:, None] + offn = tl.arange(0, BN)[None, :] + offak = tl.arange(0, BK)[None, :] + offbk = tl.arange(0, BK)[:, None] + acc = tl.full((BM, BN), 0.0, tl.float32) + for ridx5 in range(0, K // BK): + x = tl.load(X + ((pidm * K * BM) + (offm * K) + (ridx5 * BK) + offak)) + y = tl.load(Y + ((pidn * BN) + (offbk * N) + (ridx5 * N * BK) + offn)) + x = tl.expand_dims(x, axis=2) + y = tl.expand_dims(y, axis=0) + t = tl.sum(x * y, axis=1) + acc = t + acc + tl.store(Z + ((pidm * BM * N) + (pidn * BN) + (offm * N) + offn), acc) + + M, N, K = 256, 192, 160 + BM, BN, BK = 128, 32, 32 + rs = RandomState(17) + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + x = x * 0.1 + y = y * 0.1 + z = numpy_random((M, N), dtype_str=in_dtype, rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(z, device=device) + grid = M // BM, N // BN + h = kernel[grid](z_tri, x_tri, y_tri, M, N, K, BM, BN, BK) + z_ref = np.matmul(x, y) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), atol=0.01) + + if not is_cuda(): + return + assert "tt.dot" in h.asm['ttir'] + assert re.search(r"ttg.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) +@pytest.mark.parametrize("shape", [(), (1, ), (128, )]) +def test_full(dtype_str, shape, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): + # PyTorch only has unsigned 8, but not 16, 32, or 64 + dtype = getattr(torch, dtype_str[1:]) # uintx -> intx + else: + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel_static(out): + a = GENERATE_TEST_HERE + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + @triton.jit + def kernel_dynamic(out, val, dtype: tl.constexpr): + a = tl.full(SHAPE, val, dtype) + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + kernel_static_patched = patch_kernel(kernel_static, { + 'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})", + 'SHAPE': str(list(shape)), + }) + out_static = torch.zeros((128), dtype=dtype, device=device) + kernel_static_patched[(1, )](out_static) + assert torch.all(out_static == 2) + + kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) + out_dynamic = torch.zeros((128), dtype=dtype, device=device) + kernel_dynamic_patched[(1, )](out_dynamic, 2, getattr(triton.language, dtype_str)) + assert torch.all(out_dynamic == 2) + + +@pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), + ('float("-inf")', "f32"), ('float("nan")', "f32"), + ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) +def test_constexpr(literal, dtype_str, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(out_ptr): + val = GENERATE_TEST_HERE + tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val) + + kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"}) + out = torch.zeros((1, ), dtype=torch.float32, device=device) + h = kernel_patched[(1, )](out) + assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None + + +@triton.jit +def pass_const(a, b, choose_b): + if choose_b: + return b + else: + return a + + +@pytest.mark.parametrize("choose_const", [True, False]) +@pytest.mark.parametrize("constexpr", [True, False]) +@pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"]) +def test_const(device, choose_const, constexpr, mode): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit(do_not_specialize=["choose_const"]) + def kernel(in_ptr: tl.const, out, c_out: tl.const, choose_const, n_elems: tl.int32, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + @triton.jit + def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.constexpr, n_elems: tl.int32, + BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + if mode == "direct": + if choose_const: + LOSE_TAIL = "final_out = c_out" + else: + LOSE_TAIL = "final_out = out" + elif mode == "call": + LOSE_TAIL = "final_out = pass_const(out, c_out, choose_const)" + elif mode == "ternary": + LOSE_TAIL = "final_out = c_out if choose_const else out" + elif mode == "if": + LOSE_TAIL = """ + if choose_const: + final_out = c_out + else: + final_out = out +""" + + SIZE = 128 + input = torch.randn((SIZE, ), dtype=torch.float32, device=device) + output = torch.zeros((SIZE, ), dtype=torch.float32, device=device) + patched_kernel = patch_kernel(kernel_constexpr if constexpr else kernel, {'LOSE_TAIL': LOSE_TAIL, 'CONSTEXPR': ''}) + + expect_fail = (not constexpr and mode != "direct") or choose_const + if expect_fail: + with pytest.raises(triton.CompilationError) as exc_info: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + if constexpr: + error = "Cannot store to a constant pointer" + else: + if mode == "call": + error = "Inconsistent return types" + elif mode == "if": + error = "Mismatched type for final_out" + elif mode == "ternary": + error = "Ternary expression with dynamic condition has inconsistent type" + else: + assert mode == "direct" and choose_const + error = "Cannot store to a constant pointer" + error_msg = exc_info.value.error_message or str(exc_info.value.__cause__) + assert error in error_msg, "Wrong error message!" + else: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + assert torch.all(input == output) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['float32', 'float16']) +def test_dot_without_load(dtype_str, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def _kernel(out): + a = GENERATE_TEST_HERE + b = GENERATE_TEST_HERE + c = tl.dot(a, b) + out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) + a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + out_ref = torch.matmul(a, b) + out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](out) + assert torch.all(out == out_ref) + + +# --------------- +# test arange +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("start", [0, 1, 7, 16]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_arange(start, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + BLOCK = 128 + z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): + off = tl.arange(0, BLOCK) + val = tl.arange(START, END) + tl.store(z + off, val) + + _kernel[(1, )](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) + z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) + np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref)) + + +# --------------- +# test load +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, size, size_diff, other) + for dtype_str in torch_dtypes + for size in [128, 512] + for size_diff in [0, 1, 2, 3, 4] + for other in [0, 1]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_masked_load(dtype_str, size, size_diff, other, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + input_size = size - size_diff + output_size = size + if dtype_str == 'bool': + input = torch.randint(0, 2, (input_size, ), dtype=dtype, device=device) + elif dtype_str in int_dtypes or dtype_str in uint_dtypes: + input = torch.randint(0, 127, (input_size, ), dtype=dtype, device=device) + else: + input = torch.rand(input_size, dtype=dtype, device=device) + output = torch.zeros((output_size, ), dtype=dtype, device=device) + + @triton.jit + def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): + in_offsets = tl.arange(0, out_size) + # Load inputs. + x = GENERATE_TEST_HERE + # Store output + output_offsets = tl.arange(0, out_size) + tl.store(out_ptr + output_offsets, x) + + mask_str = f"mask=in_offsets < in_size, other={other}" if size_diff > 0 else "None" + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"}) + kernel[(1, )](input, output, input_size, output_size, num_ctas=num_ctas) + + reference_out = torch.cat((input, torch.full((size_diff, ), other, dtype=dtype, device=device))) + torch.testing.assert_close(output, reference_out) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +@pytest.mark.parametrize("mask_val", [True, False]) +@pytest.mark.parametrize("other_val", [0, 1]) +def test_masked_load_scalar(num_ctas, mask_val, other_val, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + input_val = 4.0 + size = 128 + dtype = torch.float32 + input = torch.full((size, ), input_val, dtype=dtype, device=device) + output = torch.zeros((size, ), dtype=dtype, device=device) + + @triton.jit + def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.constexpr): + offsets = tl.arange(0, size) + x = tl.load(in_ptr + offsets, mask=mask, other=other) + tl.store(out_ptr + offsets, x) + + kernel[(1, )](input, output, size, mask_val, other_val, num_ctas=num_ctas) + + if mask_val: + reference_out = torch.full((size, ), input_val, dtype=dtype, device=device) + else: + reference_out = torch.full((size, ), other_val, dtype=dtype, device=device) + + torch.testing.assert_close(output, reference_out) + + +# Testing masked loads with a copy to shared memory. +# FIXME: Shape too small for ldmatrix when num_ctas=4 +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_masked_load_shared_memory(dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + M = 32 + N = 32 + K = 16 + + in1 = torch.rand((M, K), dtype=dtype, device=device) + in2 = torch.rand((K, N), dtype=dtype, device=device) + out = torch.zeros((M, N), dtype=dtype, device=device) + + @triton.jit + def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_numel, in2_numel, out_numel, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): + + M_offsets = tl.arange(0, M) + N_offsets = tl.arange(0, N) + K_offsets = tl.arange(0, K) + + in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :] + in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :] + + # Load inputs. + x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K) + w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N) + + # Without a dot product the memory doesn't get promoted to shared. + o = tl.dot(x, w, out_dtype=tl.float32) + + # Store output + output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] + tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N) + + pgm = _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], in1.numel(), in2.numel(), + out.numel(), M=M, N=N, K=K) + + reference_out = torch.matmul(in1, in2) + torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cv"]) +def test_load_cache_modifier(cache, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets, cache_modifier=CACHE) + tl.store(dst + offsets, x) + + pgm = _kernel[(1, )](dst, src, CACHE=cache) + + if is_hip(): + target_arch = get_arch() + # TODO: support testing for remaining architectures + if 'gfx94' not in target_arch: + return + amdgcn = pgm.asm['amdgcn'] + cg_cache_modifier_str = 'nt' + cv_cache_modifier_str = 'sc0 sc1' + buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line] + global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line] + load_line = global_load_line[0] if global_load_line else buffer_load_line[0] + if cache == '' or cache == '.ca': + assert cg_cache_modifier_str not in load_line + if cache == '.cg': + assert cg_cache_modifier_str in load_line + if cache == '.cv': + assert cv_cache_modifier_str in load_line + + if is_cuda(): + ptx = pgm.asm['ptx'] + if cache == '': + assert 'ld.global.ca' not in ptx + assert 'ld.global.cg' not in ptx + if cache == '.cg': + assert 'ld.global.cg' in ptx + assert 'ld.global.ca' not in ptx + if cache == '.ca': + assert 'ld.global.ca' in ptx + assert 'ld.global.cg' not in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("N", [16, 10, 11, 1024]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_vectorization(N, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + block_size = 1024 * num_ctas + src = torch.randn(block_size, device=device) + dst = torch.empty(block_size, device=device) + + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, N=N, BLOCK_SIZE=block_size) + + if not is_cuda(): + return + + ptx = pgm.asm["ptx"] + if N % 16 == 0: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.b32" in ptx + torch.testing.assert_close(dst[:N], src[:N], atol=1e-6, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("has_hints", [False, True]) +def test_vectorization_hints(has_hints, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + src = torch.empty(1024, device=device) + dst = torch.empty(1024, device=device) + off = torch.zeros(1, device=device, dtype=torch.int32) + + @triton.jit + def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offsets = offsets + tl.load(off) + if HINT: + tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) + if not is_cuda(): + return + + ptx = pgm.asm["ptx"] + if has_hints: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.v4.b32" not in ptx + + +@pytest.mark.interpreter +def test_assume(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr): + current_size = N - tl.program_id(0) * BLOCK_N + tl.assume(current_size >= BLOCK_N) + if current_size >= 128: + tl.store(out_ptr + tl.program_id(0), current_size) + else: + tl.store(out_ptr + tl.program_id(0), current_size + 101024) + + output = torch.zeros(1024 // 128, device=device) + pgm = _kernel[(1024 // 128, )](output, N=1024, BLOCK_N=128) + + if is_interpreter(): + return + + # assert 'llvm.assume' in pgm.asm['llir'] + assert 'llvm.intr.assume' in pgm.asm['llir'] + + +# --------------- +# test store +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs", ".wt"]) +def test_store_cache_modifier(cache, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets) + tl.store(dst + offsets, x, cache_modifier=CACHE) + + pgm = _kernel[(1, )](dst, src, CACHE=cache) + + if is_hip(): + target_arch = get_arch() + # TODO: support testing for remaining architectures + if 'gfx94' not in target_arch: + return + amdgcn = pgm.asm['amdgcn'] + cs_cache_modifier_str = 'nt' + wt_cache_modifier_str = 'sc0 sc1' + buffer_store_line = [line for line in amdgcn.splitlines() if "buffer_store" in line] + global_store_line = [line for line in amdgcn.splitlines() if "global_store" in line] + store_line = global_store_line[0] if global_store_line else buffer_store_line[0] + if cache == '' or cache == '.cg': + assert cs_cache_modifier_str not in store_line + assert wt_cache_modifier_str not in store_line + if cache == '.cs': + assert cs_cache_modifier_str in store_line + assert wt_cache_modifier_str not in store_line + if cache == '.wt': + assert cs_cache_modifier_str not in store_line + assert wt_cache_modifier_str in store_line + + if is_cuda(): + ptx = pgm.asm['ptx'] + if cache == '': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.wb': + assert 'st.global.wb' in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cg': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cs': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' in ptx + assert 'st.global.wt' not in ptx + if cache == '.wt': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("eviction_policy", ["", "evict_last", "evict_first"]) +def test_store_eviction_policy(eviction_policy, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, POLICY: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets) + tl.store(dst + offsets, x, eviction_policy=POLICY) + + if not is_cuda(): + return + pgm = _kernel[(1, )](dst, src, POLICY=eviction_policy) + ptx = pgm.asm['ptx'] + if eviction_policy == '': + assert 'evict_last' not in ptx + assert 'evict_first' not in ptx + if eviction_policy == 'evict_last': + assert 'evict_last' in ptx + assert 'evict_first' not in ptx + if eviction_policy == 'evict_first': + assert 'evict_last' not in ptx + assert 'evict_first' in ptx + + +# --------------- +# test default +# --------------- +# TODO: can't be local to test_default + + +@triton.jit +def _impl(value=10): + return value + + +@pytest.mark.interpreter +def test_default(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + value = 5 + ret0 = torch.zeros(1, dtype=torch.int32, device=device) + ret1 = torch.zeros(1, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(ret0, ret1, value=3): + tl.store(ret0, _impl()) + tl.store(ret1, _impl(value)) + + _kernel[(1, )](ret0, ret1, value) + assert ret0.item() == 10 + assert ret1.item() == value + + _kernel[(1, )](ret0, ret1) + assert ret0.item() == 10 + assert ret1.item() == 3 + + +# --------------- +# test noop +# ---------------- + + +@pytest.mark.interpreter +def test_noop(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(x): + pass + + x = to_triton(numpy_random((1, ), dtype_str='int32'), device=device) + kernel[(1, )](x) + + +@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned']) +def test_pointer_arguments(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(x): + pass + + pin_memory = 'pinned' in device + x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory) + if device == "cpu": + with pytest.raises(ValueError): + kernel[(1, )](x) + else: + kernel[(1, )](x) + + +@pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) +def test_value_specialization(value: int, value_type: str, device) -> None: + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + def repr(specialization): + ty = specialization.signature["value1"] + cst = '_'.join([k for k, v in specialization.constants.items() if isinstance(k, str) and v == 1]) + return f"kernel_{ty}_{cst}" + + @triton.jit(repr=repr) + def kernel(value1, is_one, X): + pass + + x = torch.tensor([3.14159], device=device) + h = kernel[(1, )](value, 1, x) + assert "is_one" in h.name + assert value_type in h.name + + +# -------------------- +# value specialization +# -------------------- + + +@pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) +def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device=device) + + if overflow: + with pytest.raises(OverflowError): + kernel[(1, )](value, x) + else: + kernel[(1, )](value, x) + + +# ---------------- +# test constexpr +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) +@pytest.mark.parametrize("is_lhs_constexpr", [False, True]) +@pytest.mark.parametrize("is_rhs_constexpr", [True, False]) +def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(Z, X, Y): + x = tl.load(X) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z, z) + + if op in ['<<', '>>', '&', '^', '|']: # int op + x_str = "3" if is_lhs_constexpr else "x" + y_str = "4" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="int32") + + # NOTE: bitshifting beyond bitwidth can lead to undefined behavior + if op in ['<<', '>>']: + y = numpy_random((1, ), dtype_str="int32", low=0, high=_bitwidth("int32")) + else: + y = numpy_random((1, ), dtype_str="int32") + else: + x_str = "3.14" if is_lhs_constexpr else "x" + y_str = "4.13" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="float32") + y = numpy_random((1, ), dtype_str="float32") + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"}) + z = np.array(eval(f"{x_str} {op} {y_str}")) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(np.empty((1, ), dtype=z.dtype), device=device) + kernel[(1, )](z_tri, x_tri, y_tri) + np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) + + +@pytest.mark.interpreter +def test_constexpr_shape(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X): + off = tl.arange(0, 128 + 128) + tl.store(X + off, off) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) + + +@pytest.mark.interpreter +def test_constexpr_scalar_shape(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, s): + off = tl.arange(0, 256) + val = off % (256 // s) + tl.store(X + off, val) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri, 32) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) + + +reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("formats", reshape_list) +def test_reshape(formats, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + in_format, out_format = formats + + @triton.jit + def kernel(Z, X, out_tuple: tl.constexpr): + x = tl.load(X_PTR_EXPR) + z = tl.reshape(x, out_tuple) + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + } + return patch_kernel(kernel, to_replace) + + x = numpy_random(in_format, dtype_str="int32") + z = x.reshape(out_format) + x_tri = to_triton(x, device=device) + patched_kernel = generate_kernel(in_format, out_format) + z_tri = to_triton(np.empty(out_format, dtype=np.int32), device=device) + patched_kernel[(1, )](z_tri, x_tri, out_format) + np.testing.assert_equal(z, to_numpy(z_tri)) + + +def test_reshape_err(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(): + x = tl.arange(0, 8 * 8) + y = tl.reshape(x, (8 * 4, )) + + with pytest.raises(triton.CompilationError) as exc_info: + kernel[(1, )]() + + assert "reshape" in str(exc_info.value) + + +@pytest.mark.interpreter +def test_tma_load_block_shape_err(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(ptr): + desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 32]) + desc.load([0, 0]) + + input = torch.empty((128, 128), dtype=torch.int32, device=device) + errc = triton.CompilationError if not is_interpreter() else InterpreterError + with pytest.raises(errc) as e: + kernel[(1, )](input) + + assert "tensor descriptor block shape must have at least 8 rows" in str(e.value.__cause__) + + +@pytest.mark.interpreter +def test_tma_store_block_shape_err(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(ptr): + desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [8, 8]) + desc.store([0, 0], tl.zeros((1, 32), dtype=tl.int16)) + + input = torch.empty((128, 128), dtype=torch.int16, device=device) + errc = triton.CompilationError if not is_interpreter() else InterpreterError + with pytest.raises(errc) as e: + kernel[(1, )](input) + + assert "int16 tensor descriptor block shape must have at least 16 columns" in str(e.value.__cause__) + + +def test_trans_reshape(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.constexpr): + + in_block_ptr = tl.make_block_ptr( + base=in_base_ptr, + shape=(IN_SHAPE0, IN_SHAPE1), + strides=(IN_SHAPE1, 1), + offsets=(0, 0), + block_shape=(IN_SHAPE0, IN_SHAPE1), + order=(1, 0), + ) + x = tl.load(in_block_ptr) + x = tl.reshape(x, (32, 4, 4, 2)) + x = tl.permute(x, (1, 2, 3, 0)) + x = tl.reshape(x, (IN_SHAPE0 * IN_SHAPE1, )) + tl.store(out_base_ptr + tl.arange(0, IN_SHAPE0 * IN_SHAPE1), x) + + shape = (32, 32) + input = torch.arange(math.prod(shape), dtype=torch.int32, device=device).reshape(shape) + expected = torch.permute(input, (1, 0)) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=torch.int32, device=device) + + k = kernel[(1, )](input, actual, shape[0], shape[1]) + assert k.asm['ttgir'].count( + 'ttg.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# ------------- +# test call +# ------------- + + +@triton.jit +def val_multiplier(val, i): + return val * i + + +@triton.jit(noinline=True) +def val_multiplier_noinline(val, i): + return val * i + + +@triton.jit +def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): + pid = tl.program_id(axis=0) + offsets = pid * 128 + tl.arange(0, 128) + mask = offsets < n_elements + vec = tl.load(ptr + offsets, mask=mask) + for i in range(1, rep): + if type == "inline": + vec = val_multiplier(vec, i) + else: + vec = val_multiplier_noinline(vec, i) + tl.store(ptr + offsets, vec, mask=mask) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("type", ["inline", "noinline"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_call(type, num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): + vecmul_kernel(ptr, n_elements, num1, type) + vecmul_kernel(ptr, n_elements, num2, type) + + size = 1024 + rand_val = numpy_random((size, ), dtype_str="float32") + rand_val_tri = to_triton(rand_val, device=device) + err_msg = "" + try: + kernel[(size // 128, )](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) + except Exception as e: + err_msg = str(e) + + if type == "noinline" and not is_interpreter(): + assert err_msg != "" + else: + ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 + np.testing.assert_equal(to_numpy(rand_val_tri), ans) + + +# ------------- +# test if +# ------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("if_type", [ + "if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", + "if_and_static" +]) +def test_if(if_type, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr): + pid = tl.program_id(0) + cond = tl.load(Cond) + if IfType == "if": + if pid % 2 == 0: # eq + tl.store(Ret, tl.load(XTrue)) + elif 1 == pid % 2: # req + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_dynamic": + val = tl.load(XTrue) if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_constexpr": + val = 3.14 if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_void": + tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_static": + tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_dynamic": + if BoolVar and (1 != pid % 2 and pid % 2 != 1): # rne and ne + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_static": + if StaticVaue != 0 and StaticVaue != 0: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + + cond = torch.ones(1, dtype=torch.int32, device=device) + x_true = torch.tensor([3.14], dtype=torch.float32, device=device) + x_false = torch.tensor([1.51], dtype=torch.float32, device=device) + ret = torch.zeros(1, dtype=torch.float32, device=device) + + kernel[(1, )](cond, x_true, x_false, ret, if_type, True, 1) + assert torch.equal(ret, x_true) + + +def test_num_warps_pow2(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pass + + with pytest.raises(AssertionError, match='must be a power of 2'): + _kernel[(1, )](dst=dst, num_warps=3) + _kernel[(1, )](dst=dst, num_warps=1) + _kernel[(1, )](dst=dst, num_warps=2) + _kernel[(1, )](dst=dst, num_warps=4) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func_str", ['sqrt', 'rsqrt', 'exp', 'exp2', 'log', 'log2', 'sin', 'cos']) +def test_unary_math(func_str, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.FUNC_STR(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + kernel = patch_kernel(kernel, {'FUNC_STR': func_str}) + + shape = (128, ) + x = torch.randn(shape, dtype=torch.float32, device=device) + if func_str in ['sqrt', 'rsqrt']: + x = torch.abs(x) + if func_str in ['log', 'log2']: + x = torch.max(x, torch.tensor(1e-6, dtype=torch.float32, device=device)) + y = torch.zeros(shape, dtype=torch.float32, device=device) + + kernel[(1, )](x, y, BLOCK=shape[0]) + torch.allclose(getattr(torch, func_str)(x), y, rtol=1e-3) + + +# ----------------------- +# test inline asm +# ----------------------- + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm(num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + s = tl.full([BLOCK], n, tl.int32) + z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, + is_pure=True, pack=1) + tl.store(Z + tl.arange(0, BLOCK), z) + + shape = (128, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint32', rs=rs) + y = numpy_random(shape, dtype_str='uint32', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + n = 17 + z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = (y << n) | (x >> (32 - n)) + # compare + np.testing.assert_equal(y_ref, to_numpy(z_tri)) + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm_packed(num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # shift 4x8bits values together. + y = tl.inline_asm_elementwise( + "and.b32 $0, $1, 0x1F1F1F1F; \ + shl.b32 $0, $0, 3;", "=r,r", [ + x, + ], dtype=tl.int8, is_pure=True, pack=4) + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +@pytest.mark.parametrize('num_ctas', num_ctas_list) +def test_inline_asm_with_pointers(num_ctas, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x_ptrs = X + tl.arange(0, BLOCK) + y_ptrs = Y + tl.arange(0, BLOCK) + tl.inline_asm_elementwise( + "ld.global.b8 $0, [$1]; \ + shl.b32 $0, $0, 3; \ + st.global.b8 [$2], $0;", "=r,l,l", [x_ptrs, y_ptrs], dtype=tl.int8, is_pure=False, + pack=1) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +def test_inline_asm_multiple_outputs(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # C = A - B + # D = B - A + (c, d) = tl.inline_asm_elementwise( + asm=""" + sub.u32 $0, $2, $3; // C = A - B + sub.u32 $1, $3, $2; // D = B - A + """, + constraints=( + # 2 output registers: $0=C and $1=D. + "=r,=r," + # 2 input registers: $2=A and $3=B. + "r,r"), + args=[a, b], + dtype=(tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint32', rs=rs) + B = numpy_random(shape, dtype_str='uint32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A - B + D_ref = B - A + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +def test_inline_asm_packed_multiple_outputs(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint8', rs=rs) + B = numpy_random(shape, dtype_str='float32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='int32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='float32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A.astype(np.int32) + D_ref = np.maximum(A.astype(np.float32), B) + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +# ----------------------- +# test control flow +# ----------------------- + + +@pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), + (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) +def test_for_iv(lo, hi, iv, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(Out, lo, hi, iv: tl.constexpr): + acc = 0 + acc = acc.to(tl.int64) + for i in range(lo, hi, iv): + acc += i + tl.store(Out, acc) + + lo = 2**35 + hi = 2**35 + 20 + out = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + kernel[(1, )](out, lo, hi, iv) + assert out[0] == sum(range(lo, hi, iv)) + + +@pytest.mark.interpreter +def test_if_else(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(Cond, TrueVal, FalseVal, Out): + if tl.load(Cond): + val = tl.load(TrueVal) + else: + val = tl.load(FalseVal) + tl.store(Out, val) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + true_val = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + false_val = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + cond = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # True + cond[0] = True + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == true_val[0] + # False + cond[0] = False + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == false_val[0] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["dynamic", "static"]) +def test_if_return(mode, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr): + if mode == "dynamic": + if tl.load(ExitEarly): + tl.store(Out, 0) + return + else: + if cond: + tl.store(Out, 0) + return + tl.store(Out, 1) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + exit_early = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # exit early path taken + exit_early[0] = 1 + kernel[(1, )](exit_early, out, True, mode) + assert to_numpy(out)[0] == 0 + # exit early path not taken + exit_early[0] = 0 + kernel[(1, )](exit_early, out, False, mode) + assert to_numpy(out)[0] == 1 + + +@triton.jit +def add_fn(x): + return x + 1 + + +@triton.jit(noinline=True) +def add_fn_noinline(x): + return x + 1 + + +@triton.jit +def add_fn_return(x, pid): + if pid == 0: + return x + 1 + else: + return x + 2 + + +@triton.jit +def add_fn_expr(Out, x): + tl.store(Out, x) + + +@triton.jit +def add_fn_static_cond(x, cond: tl.constexpr): + if cond == "": + return x + else: + return x + 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "call_type", + ["attribute", "attribute_jit", "jit", "jit_if", "jit_expr", "jit_static_cond", "jit_noinline", "jit_extern"]) +def test_if_call(call_type, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(Out, call_type: tl.constexpr): + pid = tl.program_id(0) + o = tl.load(Out) + if call_type == "attribute": + # call attribute + if pid == 0: + a = o + a = a.to(tl.int32).to(tl.int32) + 1 + o = a + elif call_type == "attribute_jit": + # call attribute and jit function + if pid == 0: + a = o + a = tl.load(Out + add_fn(a) - 1).to(tl.int32) + 1 + o = a + elif call_type == "jit": + if pid == 0: + # regular function call + a = o + a = add_fn(a) + o = a + elif call_type == "jit_if": + # function without end_if block + if pid == 0: + a = o + a = add_fn_return(a, pid) + o = a + elif call_type == "jit_if_exp": + # ifexp expression + if pid == 0: + a = o + a = add_fn(a) if pid == 0 else add_fn_return(a, pid) + o = a + elif call_type == "jit_expr": + # call without return + if pid == 0: + a = o + 1 + add_fn_expr(Out, a) + o = a + elif call_type == "jit_static_cond": + if pid == 0: + a = o + 1 + add_fn_static_cond(o, call_type) + o = a + elif call_type == "jit_noinline": + if pid == 0: + a = o + 1 + add_fn_noinline(a) + o = a + elif call_type == "jit_extern": + if pid == 0: + a = o + 1 + tl.cdiv(a, a) + o = a + + tl.store(Out, o) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + kernel[(1, )](out, call_type) + assert to_numpy(out)[0] == 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("_cond1", [True, False]) +@pytest.mark.parametrize("_cond2", [True, False]) +@pytest.mark.parametrize("_cond3", [True, False]) +def test_nested_if_else_return(_cond1, _cond2, _cond3, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): + val = 0 + if tl.load(Cond1): + if tl.load(Cond2): + val = tl.load(Val1) + else: + return + else: + if tl.load(Cond3): + val = tl.load(Val2) + else: + val = tl.load(Val3) + tl.store(Out, val) + + out = to_triton(np.full((1, ), -1, dtype=np.int32), device=device) + cond1 = to_triton(np.full((1, ), _cond1, dtype=np.int32), device=device) + cond2 = to_triton(np.full((1, ), _cond2, dtype=np.int32), device=device) + cond3 = to_triton(np.full((1, ), _cond3, dtype=np.int32), device=device) + val1 = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + val2 = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + val3 = to_triton(np.full((1, ), 3, dtype=np.int32), device=device) + kernel[(1, )](cond1, cond2, cond3, val1, val2, val3, out) + targets = { + (True, True, True): val1[0], + (True, True, False): val1[0], + (True, False, True): out[0], + (True, False, False): out[0], + (False, True, True): val2[0], + (False, True, False): val3[0], + (False, False, True): val2[0], + (False, False, False): val3[0], + } + assert out[0] == targets[(_cond1, _cond2, _cond3)] + + +@pytest.mark.interpreter +def test_while(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): + init_i = tl.load(InitI) + curr_i = init_i + j = 0 + # Check that init_i is not updated by the loop + while j < tl.load(Bound): + curr_i = curr_i + (j == tl.load(CutOff)) + j += 1 + tl.store(OutInitI, init_i) + tl.store(OutI, curr_i) + tl.store(OutJ, j) + + out_i = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + out_j = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + init_i = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + out_init_i = to_triton(np.full((1, ), 0, dtype=np.int32), device=device) + bound = to_triton(np.full((1, ), 10, dtype=np.int32), device=device) + cut_off = to_triton(np.full((1, ), 5, dtype=np.int32), device=device) + kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j) + assert out_init_i[0] == init_i[0] + assert out_i[0] == init_i[0] + 1 + assert out_j[0] == bound[0] + + +@pytest.mark.interpreter +def test_nested_while(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def nested_while(data, countPtr): + for i in range(10): + count = tl.load(countPtr) + while count > 0: + tl.store(data, tl.load(data) + 1.0) + count = count - 2 + + counter = torch.tensor([8], dtype=torch.int32, device=device) + data = torch.zeros((1, ), device=device, dtype=torch.float32) + nested_while[(1, )](data, counter) + assert data[0] == 40 + + +def test_constexpr_if_return(device): + # Reproducer for #4883, return statement in an if with a constexpr causes + # errors when combined with non-trivial control flow graphs + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(Semaphore, Out, total: tl.constexpr): + if total == 1: + tl.store(Out, tl.program_id(0)) + return + + prev = tl.atomic_add(Semaphore, 1) + if prev + 1 != total: + return + + tl.store(Out, tl.program_id(0) + prev) + + sem = torch.zeros((), device=device, dtype=torch.int32) + out = torch.empty((), device=device, dtype=torch.int32) + kernel[(1, )](sem, out, 1) + assert out.item() == 0 + + sem = torch.zeros((), device=device, dtype=torch.int32) + out = torch.full((), fill_value=-1, device=device, dtype=torch.int32) + kernel[(4, )](sem, out, 4) + assert out.item() >= 0 + + +@triton.jit +def return_poison(x): + a = False + if a: + return x + + +def test_poison_return(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(Out): + tl.store(Out, return_poison(0)) + + a = torch.empty((), device=device, dtype=torch.int32) + h = kernel[(1, )](a) + assert "ub.poison" in h.asm["ttir"], h.asm["ttir"] + # hip/xpu uses llvm.store, which in this case is removed by the optimizer + if not (is_hip() or is_xpu()): + assert "poison" in h.asm["llir"], h.asm["llir"] + + +# ----------------------- +# test extra +# ----------------------- + + +def test_num_threads(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_hip(): + pytest.skip("test_num_threads is not supported in HIP") + + @triton.jit + def kernel(Out): + num_threads: tl.constexpr = tl.extra.cuda.num_threads() + offs = tl.arange(0, num_threads) + tl.store(Out + offs, 1) + + num_threads = 256 + out = to_triton(np.zeros((num_threads, ), dtype=np.int32), device=device) + kernel[(1, )](out, num_warps=num_threads // 32) + assert torch.sum(out) == 256 + + +def test_globaltimer(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_hip(): + pytest.skip("test_globaltimer is not supported in HIP") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out1, Out2): + start = tl.extra.cuda.globaltimer() + off = tl.arange(0, 128) + for i in range(10000): + tl.store(Out1 + off, tl.load(Out1 + off) + 1) + end = tl.extra.cuda.globaltimer() + tl.store(Out2, end - start) + + out1 = to_triton(np.zeros((128, ), dtype=np.int64), device=device) + out2 = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + h = kernel[(1, )](out1, out2) + assert out2[0] > 0 + assert h.asm["ptx"].count("%globaltimer") == 2 + + +def test_smid(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_hip(): + pytest.skip("test_smid is not supported in HIP") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out): + tl.store(Out + tl.program_id(0), tl.extra.cuda.smid()) + + out = to_triton(np.zeros((1024, ), dtype=np.int32), device=device) + h = kernel[(out.shape[0], )](out) + assert out.sort()[0].unique().shape[0] > 0 + assert h.asm["ptx"].count("%smid") == 1 + + +# ----------------------- +# test layout conversions +# ----------------------- +# TODO: backend should be tested separately + +layouts = [ + BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + # BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 16], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([3, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 32, 16]), op_idx=0, k_width=1), + MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + SliceLayout( + dim=1, + parent=DotOperandLayout(parent=MmaLayout([3, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [16, 32, 16]), + op_idx=0, k_width=2)), + SliceLayout( + dim=1, parent=DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), + op_idx=1, k_width=2)), +] + +intermediate_layouts = [ + None, + SharedLayout(1, 1, 1, [0, 1], [1, 1], [1, 1], [0, 1]), + SharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), +] + + +def compute_rep_shape(layout): + if type(layout) is BlockedLayout: + warp_shape = np.multiply(layout.sz_per_thread, layout.threads_per_warp) + rep_shape = np.multiply(warp_shape, layout.warps_per_cta) + return rep_shape + else: + assert False, "TODO: support compute_rep_shape for layout " + str(type(layout)) + + +# This function gives a lower bound approximation of scratch buffer shape for convert_layout operation +def compute_scratch_buffer_shape(src_layout, dst_layout, shape): + src_rep_shape = compute_rep_shape(src_layout) + dst_rep_shape = compute_rep_shape(dst_layout) + full_scratch_shape = np.maximum(src_rep_shape, dst_rep_shape) + return np.minimum(full_scratch_shape, shape) + + +@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("interm_layout", intermediate_layouts) +@pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if str(src_layout) == str(dst_layout): + pytest.skip() + if (isinstance(src_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)) or (isinstance(dst_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)): + pytest.skip("DotOperandLayout <-> SharedLayout conversion is not completely supported") + if is_hip(): + try: + scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N)) + except AssertionError: + pytest.skip("Can't compute scratch buffer size") + lds_size = 65536 + # consider int32 dtype in scratch buffer size, + # because it is the largest dtype used in convert_layout in this test + int32_size = 4 + # skip even if scratch buffer equal to lds_size, because real scratch buffer is typically larger due to padding + if scratch_shape[0] * scratch_shape[1] * int32_size >= lds_size: + pytest.skip("Scratch buffer is too large") + + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + #smem = #ttg.shared_memory + """ if interm_layout is None else f""" + #src = {src_layout} + #interm = {interm_layout} + #dst = {dst_layout} + #smem = #ttg.shared_memory + """ + + conversion = f""" + %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ if interm_layout is None else f""" + %15 = ttg.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !ttg.memdesc<{M}x{N}xi32, #interm, #smem> + %16 = ttg.local_load %15 : !ttg.memdesc<{M}x{N}xi32, #interm, #smem> -> tensor<{M}x{N}xi32, #src> + %17 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !ttg.memdesc<{M}x{N}xf16, #interm, #smem> + %18 = ttg.local_load %17 : !ttg.memdesc<{M}x{N}xf16, #interm, #smem> -> tensor<{M}x{N}xf16, #src> + + %12 = ttg.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + ir = layouts + f""" + module attributes {{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} +}} +""" + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_convert2d.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + torch.testing.assert_close(z, x, rtol=0, atol=0) + + +layouts_3d = [ + BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), op_idx=0, + k_width=1), +] + +shared_layouts_3d = [ + SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), +] + + +@pytest.mark.parametrize("M, N, K", [[8, 16, 32]]) +@pytest.mark.parametrize("shared_layout", shared_layouts_3d) +@pytest.mark.parametrize("dist_layout", filter_layouts(layouts_3d)) +def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + layouts = f""" + #dist = {dist_layout} + #shared = {shared_layout} + #smem = #ttg.shared_memory + """ + ir = layouts + f""" + module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %cst = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist> + %cst_0 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist> + %cst_1 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist> + %cst_2 = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist> + %0 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> + %1 = tt.expand_dims %0 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<1x1x{K}x!tt.ptr, #dist> + %4 = tt.addptr %3, %2 : tensor<1x1x{K}x!tt.ptr, #dist>, tensor<1x1x{K}xi32, #dist> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %7 = tt.expand_dims %6 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist> + %8 = arith.muli %7, %cst_2 : tensor<1x{N}x1xi32, #dist> + %9 = tt.broadcast %4 : tensor<1x1x{K}x!tt.ptr, #dist> -> tensor<1x{N}x{K}x!tt.ptr, #dist> + %10 = tt.broadcast %8 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist> + %11 = tt.addptr %9, %10 : tensor<1x{N}x{K}x!tt.ptr, #dist>, tensor<1x{N}x{K}xi32, #dist> + %12 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %13 = tt.expand_dims %12 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %14 = tt.expand_dims %13 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist> + %15 = arith.muli %14, %cst_1 : tensor<{M}x1x1xi32, #dist> + %16 = tt.broadcast %11 : tensor<1x{N}x{K}x!tt.ptr, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist> + %18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr, #dist>, tensor<{M}x{N}x{K}xi32, #dist> + %19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem> + %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem> -> tensor<{M}x{N}x{K}xi32, #dist> + %22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> + %23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist> + %25 = tt.splat %arg1 : !tt.ptr -> tensor<1x1x{K}x!tt.ptr, #dist> + %26 = tt.addptr %25, %24 : tensor<1x1x{K}x!tt.ptr, #dist>, tensor<1x1x{K}xi32, #dist> + %27 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %28 = tt.expand_dims %27 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %29 = tt.expand_dims %28 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist> + %30 = arith.muli %29, %cst : tensor<1x{N}x1xi32, #dist> + %31 = tt.broadcast %26 : tensor<1x1x{K}x!tt.ptr, #dist> -> tensor<1x{N}x{K}x!tt.ptr, #dist> + %32 = tt.broadcast %30 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist> + %33 = tt.addptr %31, %32 : tensor<1x{N}x{K}x!tt.ptr, #dist>, tensor<1x{N}x{K}xi32, #dist> + %34 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %35 = tt.expand_dims %34 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %36 = tt.expand_dims %35 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist> + %37 = arith.muli %36, %cst_0 : tensor<{M}x1x1xi32, #dist> + %38 = tt.broadcast %33 : tensor<1x{N}x{K}x!tt.ptr, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %39 = tt.broadcast %37 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist> + %40 = tt.addptr %38, %39 : tensor<{M}x{N}x{K}x!tt.ptr, #dist>, tensor<{M}x{N}x{K}xi32, #dist> + tt.store %40, %21 : tensor<{M}x{N}x{K}x!tt.ptr, #dist> + tt.return + }} +}} +""" + + x = torch.arange(0, M * N * K, device=device, dtype=torch.int32).reshape(M, N, K) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_local_load_store.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x, z) + assert torch.equal(z, x) + + +dot_layouts = [ + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=4), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=1, k_width=4), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=0, k_width=1), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=1), +] + +shared_layouts = [ + SharedLayout(4, 2, 4, [0, 1], [1, 1], [1, 1], [0, 1]), + SharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(16, 1, 16, [1, 0], [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N", [[16, 32]]) +@pytest.mark.parametrize("dtype", ['float16', 'float8e5', 'float32']) +@pytest.mark.parametrize("shared_layout", shared_layouts) +@pytest.mark.parametrize("dist_layout", filter_layouts(dot_layouts)) +def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, tmp_path: pathlib.Path): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if dtype == "float32": + mlir_dtype = "f32" + elif dtype == "float16": + mlir_dtype = "f16" + elif dtype == "float8e5": + mlir_dtype = "f8E5M2" + + layouts = f""" + #dist = {dist_layout} + #shared = {shared_layout} + #smem = #ttg.shared_memory + """ + ir = layouts + f""" + module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #dist> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> + %2 = tt.splat %arg0 : !tt.ptr<{mlir_dtype}> -> tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + %3 = tt.splat %arg1 : !tt.ptr<{mlir_dtype}> -> tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<{M}x1xi32, #dist> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #dist> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> -> tensor<1x{N}xi32, #dist> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #dist> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>, tensor<{M}x{N}xi32, #dist> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + %12 = ttg.local_alloc %11 : (tensor<{M}x{N}x{mlir_dtype}, #dist>) -> !ttg.memdesc<{M}x{N}x{mlir_dtype}, #shared, #smem> + %13 = ttg.local_load %12 : !ttg.memdesc<{M}x{N}x{mlir_dtype}, #shared, #smem> -> tensor<{M}x{N}x{mlir_dtype}, #dist> + %14 = tt.addptr %3, %9 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>, tensor<{M}x{N}xi32, #dist> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + tt.return + }} +}} +""" + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_local_load_store_dot.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x, z) + assert torch.equal(z, x) + + +mma_layouts = [ + MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # simple 4 warps case + MmaLayout((3, 0), [8, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # simple 8 warps case + MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # multiple warps on the row + MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 64, 16]), # small instrN + MmaLayout((3, 0), [8, 4], [1, 1], [1, 1], [0, 1], [16, 64, 16]), # large number of warps +] + +shared_layouts = [ + SharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), + NVMMASharedLayout(64, False, 16, [1, 1], [1, 1], [0, 1]), + NVMMASharedLayout(128, False, 16, [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N", [[128, 128]]) +@pytest.mark.parametrize("mma_layout", filter_layouts(mma_layouts)) +@pytest.mark.parametrize("shared_layout", shared_layouts) +def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path: pathlib.Path): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + num_warps = np.prod(mma_layout.warps_per_cta) + + layouts = f""" + #dist = {mma_layout} + #shared = {shared_layout} + #smem = #ttg.shared_memory + """ + ir = layouts + f""" + module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #dist> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dist> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dist> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<{M}x1xi32, #dist> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #dist> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> -> tensor<1x{N}xi32, #dist> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #dist> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #dist>, tensor<{M}x{N}xi32, #dist> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #dist> + %12 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #dist>) -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem> + %13 = ttg.local_load %12 : !ttg.memdesc<{M}x{N}xf16, #shared, #smem> -> tensor<{M}x{N}xf16, #dist> + %14 = tt.addptr %3, %9 : tensor<{M}x{N}x!tt.ptr, #dist>, tensor<{M}x{N}xi32, #dist> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dist> + tt.return + }} +}} +""" + + x = torch.arange(0, M * N, device=device, dtype=torch.float16).reshape(M, N) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_local_load_store_mma.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x, z) + assert torch.equal(z, x) + + if isinstance(shared_layout, NVMMASharedLayout) and mma_layout.version[0] >= 3: + assert "stmatrix" in kernel.asm["ptx"] + + +def filter_layout_pairs(layout_pairs): + return [pair for pair in layout_pairs if is_layout_applicable(pair[0]) and is_layout_applicable(pair[1])] + + +mma_pairs = [ + [ + MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + ], + [ + MmaLayout((3, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + ], + [ + MmaLayout((3, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + MmaLayout((3, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 16]), + ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 16]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), + ], + [ + WmmaLayout(1, [4, 4]), + WmmaLayout(1, [16, 1]), + ], + [ + WmmaLayout(1, [16, 1]), + WmmaLayout(1, [4, 4]), + ], + [ + WmmaLayout(2, [4, 4]), + WmmaLayout(2, [16, 1]), + ], + [ + WmmaLayout(2, [16, 1]), + WmmaLayout(2, [4, 4]), + ], + [ + MfmaLayout([2, 0], [2, 2], [32, 32], False), + MfmaLayout([2, 0], [4, 1], [32, 32], False), + ], + [ + MfmaLayout([2, 0], [4, 1], [32, 32], False), + MfmaLayout([2, 0], [2, 2], [32, 32], False), + ], + [ + MfmaLayout([2, 0], [2, 2], [32, 32], False), + MfmaLayout([2, 0], [4, 1], [32, 32], True), + ], + [ + MfmaLayout([2, 0], [4, 1], [32, 32], False), + MfmaLayout([2, 0], [2, 2], [32, 32], True), + ], + [ + MfmaLayout([2, 0], [4, 4], [16, 16], False), + MfmaLayout([2, 0], [16, 1], [16, 16], False), + ], + [ + MfmaLayout([2, 0], [16, 1], [16, 16], False), + MfmaLayout([2, 0], [4, 4], [16, 16], False), + ], + [ + MfmaLayout([2, 0], [4, 4], [16, 16], False), + MfmaLayout([2, 0], [16, 1], [16, 16], True), + ], + [ + MfmaLayout([2, 0], [16, 1], [16, 16], False), + MfmaLayout([2, 0], [4, 4], [16, 16], True), + ], +] + + +@pytest.mark.parametrize("M, N", [[16, 16], [64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("mma_pair", filter_layout_pairs(mma_pairs)) +def test_convert_mma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_hip(): + if isinstance(mma_pair[1], MfmaLayout) and (mma_pair[1].instr_shape[1] > M or mma_pair[1].instr_shape[1] > N): + pytest.skip("HIP do not fully support skinny tensor store") + + src_layout, _ = mma_pair + num_warps = np.prod(src_layout.warps_per_cta) + warp_size = THREADS_PER_WARP + + def do_test(src_layout, dst_layout): + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + """ + + ir = layouts + f""" + module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {warp_size} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} + }} + """ + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x) + + temp_file = tmp_path / "test_convert_mma2mma.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + assert torch.equal(z, x) + + do_test(mma_pair[0], mma_pair[1]) + do_test(mma_pair[1], mma_pair[0]) + + +single_warp_layouts = [ + BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [1, 1], [1, 0]), + BlockedLayout([1, 1], [THREADS_PER_WARP // 2, 2], [1, 1], [1, 0]), + BlockedLayout([1, 1], [THREADS_PER_WARP // 4, 4], [1, 1], [1, 0]), + BlockedLayout([1, 1], [THREADS_PER_WARP // 8, 8], [1, 1], [1, 0]), + BlockedLayout([1, 1], [THREADS_PER_WARP // 16, 16], [1, 1], [1, 0]), + BlockedLayout([1, 1], [THREADS_PER_WARP // 32, 32], [1, 1], [1, 0]), + BlockedLayout([32, 1], [1, THREADS_PER_WARP], [1, 1], [1, 0]), + BlockedLayout([16, 1], [2, THREADS_PER_WARP // 2], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP, 1], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 2, 2], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 4, 4], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 8, 8], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 16, 16], [1, 1], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 1], [1, 0]), +] + + +@pytest.mark.parametrize("M, N", [[32, 32], [64, 64]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("src_layout", single_warp_layouts) +@pytest.mark.parametrize("dst_layout", single_warp_layouts) +def test_convert_warp_local(M, N, src_layout, dst_layout, dtype, device, tmp_path: pathlib.Path): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if str(src_layout) == str(dst_layout): + pytest.skip() + if np.prod(src_layout.threads_per_warp) == 0 or np.prod(dst_layout.threads_per_warp) == 0: + pytest.skip() + + # Test layout pairs that are likely to codegen warp shuffles. + a, b = list(np.array(src_layout.threads_per_warp) // np.array(dst_layout.threads_per_warp)) + c = a if a != 0 else b + if c > 2: + pytest.skip() + + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + #smem = #ttg.shared_memory + """ + + conversion = f""" + %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + ir = layouts + f""" + module attributes {{"ttg.num-warps" = 1 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} +}} +""" + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_convert_warp_local.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + torch.testing.assert_close(z, x, rtol=0, atol=0) + + +@pytest.mark.interpreter +def test_load_scalar_with_mask(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(Input, Index, Out, N: int): + index = tl.load(Index) + scalar = tl.load(Input + index, mask=index < N, other=0) + tl.store(Out, scalar, mask=index < N) + + Index = torch.tensor([0], dtype=torch.int32, device=device) + Input = torch.tensor([0], dtype=torch.int32, device=device) + Out = torch.empty_like(Index, device=device) + kernel[(1, )](Input, Index, Out, Index.numel()) + assert Out.data[0] == 0 + + +# This test is used to test our own PTX codegen for float16 and int16 conversions +# maybe delete it later after ptxas has been fixed +@pytest.mark.parametrize("dtype_str", ['float16', 'int16']) +def test_ptx_cast(dtype_str, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + rbase = tl.arange(0, RBLOCK)[None, :] + x0 = xindex + _tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype) + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + r1 = rindex + tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype) + tmp1 = 2 + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(dtype) + tmp5 = _tmp4 < tmp3 + _tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4) + tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask) + + torch.manual_seed(123) + if dtype_str == 'int16': + torch_dtype = torch.int16 + triton_dtype = tl.int32 + else: + torch_dtype = torch.float16 + triton_dtype = tl.float32 + + s0 = 4 + buf11 = -torch.ones((6 * s0, 197, 197), device=device, dtype=torch_dtype) + buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype) + kernel[(4728, )](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) + assert buf14.to(torch.float32).mean() == -2.0 + + +# ----------------------- +# test fp8 -> fp32 dot +# ----------------------- + + +def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + +@triton.jit +def matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + low_precision_acc: tl.constexpr, # + num_stages: tl.constexpr = 3 # +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + # TODO: http://jira.enflame.cn/browse/TCC-2285 + # offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + # offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask_a = (offs_am < M)[:, None] + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_b = (offs_bn < N)[None, :] + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages): + a = tl.load(a_ptrs, mask_a) + b = tl.load(b_ptrs, mask_b) + accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, accumulator) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N, K", [(128, 256, 256)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), (64, 64, 64)]) +@pytest.mark.parametrize( + "in_type_str", + ['float8e5', 'float8e5b16', 'float8e4b8', 'float8e4nv'] if is_hip() else ['float8e5', 'float8e4nv', 'float8e4b15']) +@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) +def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + num_stages = 3 + if is_cuda(): + cc = torch.cuda.get_device_capability() + if cc[0] >= 9 and in_type_str == "float8e4b15": + pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90") + elif is_hip(): + num_stages = 2 + if in_type_str in ("float8e5b16", "float8e4b8") and not is_hip_mi300(): + pytest.skip(f"{in_type_str} only supported on mi300") + if in_type_str in ("float8e5", "float8e4nv") and not is_hip_mi350(): + pytest.skip(f"{in_type_str} only supported on mi350") + + check_type_supported(in_type_str, device) + A = numpy_random((M, K), dtype_str=in_type_str) + B = numpy_random((K, N), dtype_str=in_type_str) + C = torch.empty((M, N), dtype=torch.float32, device=device) + num_warps = 4 + a = to_triton(A, device=device, dst_type=in_type_str) + b = to_triton(B, device=device, dst_type=in_type_str) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None + h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), + C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps, + num_stages=num_stages) + torch_a = torch.from_numpy(A).to(device=device) + th_a = f8_to_f16(torch_a, in_type_str) + torch_b = torch.from_numpy(B).to(device=device) + th_b = f8_to_f16(torch_b, in_type_str) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if in_type_str == 'float8e4nv': + torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) + else: + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + if is_cuda() and low_precision_acc > 0 and torch.cuda.get_device_capability()[0] == 9: + assert h.asm["ptx"].count("add.f32") == (BLOCK_M * BLOCK_N) // (32 * num_warps) * (BLOCK_K // low_precision_acc) + + +# ----------------------- +# test enable_fp_fusion +# ----------------------- + + +@pytest.mark.parametrize("enable_fp_fusion", [False, True]) +@pytest.mark.parametrize("default_override", [False, True]) +def test_enable_fp_fusion(enable_fp_fusion, default_override, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_hip(): + pytest.skip( + 'test_enable_fp_fusion for HIP currently broken in https://github.com/triton-lang/triton. Use https://github.com/ROCmSoftwarePlatform/triton' + ) + + # Sequential multiply add can be fused by backend + @triton.jit + def mul_add(data): + ptrs = data + tl.arange(0, 128) + tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0) + + data = torch.randn((128, ), device=device, dtype=torch.float32) + if default_override: + os.environ["TRITON_DEFAULT_FP_FUSION"] = "1" if enable_fp_fusion else "0" + h = mul_add[(1, )](data) + else: + h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) + + if not is_cuda(): + return + found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None + assert found_fma == enable_fp_fusion + + +# ----------------------- +# test override_arch +# ----------------------- + + +@pytest.mark.parametrize("arch", ["sm70", "sm80", "sm90"]) +@pytest.mark.parametrize("env_var_override", [False, True]) +def test_override_arch(arch, env_var_override, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if not is_cuda(): + pytest.skip('arch only for CUDA') + + @triton.jit + def simple(data, out): + in_ptrs = data + tl.arange(0, 128) + out_ptrs = out + tl.arange(0, 128) + tl.store(out_ptrs, tl.load(in_ptrs) * 1.5 + 1.0) + + data = torch.randn((128, ), device=device, dtype=torch.float32) + out = torch.empty_like(data) + + if env_var_override: + os.environ["TRITON_OVERRIDE_ARCH"] = str(arch) + h = simple[(1, )](data, out) + os.environ.pop("TRITON_OVERRIDE_ARCH") + else: + h = simple[(1, )](data, out, arch=arch) + torch.testing.assert_close(data * 1.5 + 1.0, out) + ttgir_cc = re.search(r'cuda:(\d+)', h.asm["ttgir"]) + assert ttgir_cc.group(1) == arch[2:] + + +# ----------------------- +# test propagate_nan +# ----------------------- + + +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +@pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL']) +@pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp']) +def test_propagate_nan(dtype, propagate_nan, func, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr): + if func == 'clamp': + tl.store( + C, + getattr(tl, func)(tl.load(A), -tl.load(B), tl.load(B), + propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + else: + tl.store(C, + getattr(tl, func)(tl.load(A), tl.load(B), propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + + for mode in ['A', 'B', 'both']: + if func == 'clamp' and mode == 'B': + # clamp does not guarantee propagation from 'min' and 'max' args + continue + A = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'A' or mode == 'both': A[0] = torch.nan + B = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'B' or mode == 'both': B[0] = torch.nan + C = torch.zeros_like(A, device=device, dtype=getattr(torch, dtype)) + kernel[(1, )](A, B, C, propagate_nan, func) + + if mode == 'both' or propagate_nan == 'ALL': + assert torch.isnan(C[0]) + else: + assert not torch.isnan(C[0]) + + +# ----------------------- +# test clamp +# ----------------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp(dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + min = tl.load(min_ptr + off, mask=mask) + max = tl.load(max_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, min, max), mask=mask) + ref_val = tl.minimum(tl.maximum(x, min), max) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + a = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + b = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + min = torch.min(a, b) + max = torch.max(a, b) + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, min, max, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# Test for symmetric clamp(x, -limit, limit), as it may go through optimized +# codegen in the backends +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['bfloat16', 'float16', 'float32']) +def test_clamp_symmetric(dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(x_ptr, limit_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + limit = tl.load(limit_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, -limit, limit), mask=mask) + ref_val = tl.minimum(tl.maximum(x, -limit), limit) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + limit = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)).abs() + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, limit, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# ----------------------- +# test iterators +# ----------------------- + + +@pytest.mark.interpreter +def test_static_range(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def loop_kernel(Z, N: tl.constexpr, step: tl.constexpr): + acc = 0 + for i in tl.static_range(0, N, step=step): + acc += i + tl.store(Z, acc) + + N = 100 + step = 7 + Out = torch.empty(1, dtype=torch.int32, device=device) + loop_kernel[(1, )](Out, N, step) + Acc = torch.tensor([0], dtype=torch.int32, device=device) + for i in range(0, N, step): + Acc += i + assert (Out == Acc).all(), (Out, Acc) + + +@pytest.mark.interpreter +def test_tl_range_num_stages(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_hip(): + pytest.skip("test_tl_range is not supported in HIP") + M, N, K = 64, 64, 512 + BLOCK_M, BLOCK_N, BLOCK_K = M, N, 64 + a = torch.randn((M, K), device=device, dtype=torch.float16) + b = torch.randn((K, N), device=device, dtype=torch.float16) + c = torch.empty((M, N), dtype=torch.float32, device=device) + pgm = matmul_kernel[ + 1, + ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, + BLOCK_K, 0, num_stages=5) + ref_out = torch.matmul(a, b).to(torch.float32) + if is_interpreter(): + # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. + # Thus we use a higher tolerance + torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1) + else: + torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3) + if device in ['cuda']: + capability = torch.cuda.get_device_capability() + if capability[0] >= 8: + ptx = pgm.asm['ptx'] + # check that the loop got pipelined with the right number of stages. + assert 'cp.async.wait_group 6' in ptx + + +def test_tl_range_fuse(): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_hip(): + pytest.skip("loop fusion is not enabled on AMD") + + @triton.jit + def kernel(ub): + for i in tl.range(0, ub, flatten=True): + for j in tl.range(0, ub): + print("i", i) + + compiled_kernel = kernel.warmup(10, grid=(1, )) + assert "tt.flatten" in compiled_kernel.asm["ttir"] + assert compiled_kernel.asm["ttgir"].count("scf.for") == 1 + + +def test_tl_range_option_none(): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(ub): + for i in tl.range(0, ub, num_stages=None, loop_unroll_factor=None): + print("i", i) + + compiled_kernel = kernel.warmup(10, grid=(1, )) + assert "num_stages" not in compiled_kernel.asm["ttir"] + assert "loop_unroll_factor" not in compiled_kernel.asm["ttir"] + + +@triton.jit(noinline=True) +def maxnreg_noinline1(X): + tl.store(X, 0) + + +@triton.jit(noinline=True) +def maxnreg_noinline2(X): + tl.store(X, 0) + + +@pytest.mark.interpreter +def test_maxnreg(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if not is_cuda(): + pytest.skip('maxnreg only works on CUDA') + + # triton kernel + @triton.jit + def kernel(X): + maxnreg_noinline1(X) + tl.store(X, 0) + maxnreg_noinline2(X) + + X = torch.empty(1, dtype=torch.int32, device=device) + k = kernel[(1, )](X, maxnreg=42) + + if not is_interpreter(): + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). + try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise + + +@pytest.mark.interpreter +def test_temp_var_in_loop(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): + acc = tl.full((BLOCK, ), 0, dtype=tl.int32) + for i in range(N): + if i == 0: + temp = tl.full((BLOCK, ), 2, dtype=tl.int32) + acc = temp + else: + acc += tl.full((BLOCK, ), 1, dtype=tl.int32) + # reuse the temp variable and make sure to check that it isn't creating incorrect IR. + temp = tl.full((BLOCK, ), 1, dtype=tl.int32) + acc += temp + z = Z + tl.arange(0, BLOCK) + tl.store(z, acc) + + N = 10 + BLOCK = 32 + out = torch.empty((BLOCK, ), dtype=torch.int32, device=device) + temp_in_loop[(1, )](out, N, BLOCK) + acc = torch.full((BLOCK, ), 0, dtype=torch.int32, device=device) + for i in range(N): + if i == 0: + temp = torch.full((BLOCK, ), 2, dtype=torch.int32, device=device) + acc = temp + else: + acc += torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + temp = torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + acc += temp + assert (acc == out).all() + + +@pytest.mark.interpreter +def test_num_programs(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + # Assuming that the kernel is launched with a grid of (11, 21, 31) + grid = (11, 21, 31) + input = torch.empty((3, ), dtype=torch.int32, device=device) + + @triton.jit + def kernel(input): + num_programs_0 = tl.num_programs(0) + num_programs_1 = tl.num_programs(1) + num_programs_2 = tl.num_programs(2) + tl.store(input, num_programs_0) + tl.store(input + 1, num_programs_1) + tl.store(input + 2, num_programs_2) + + kernel[grid](input) + assert torch.all(input == torch.tensor(grid, device=device)) + + +# ----------------------- +# test extern functions +# ----------------------- + + +@pytest.mark.parametrize("dtype_str", ['float32', 'float64']) +def test_math_extern(dtype_str, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_interpreter(): + pytest.skip('math_extern does not work in the interpreter mode') + + @triton.jit + def kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = libdevice.tanh(x) + tl.store(y_ptr + offsets, y, mask=mask) + + shape = (128, ) + rs = RandomState(17) + + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + y_ref = np.tanh(x) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str=dtype_str, rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, shape[0], BLOCK_SIZE=shape[0]) + # compare + np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) + + +# ----------------------- +# test loop unrolling +# ----------------------- + + +def test_unroll_attr(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def _kernel(dst, unroll_factor: tl.constexpr): + pid = tl.program_id(axis=0) + for i in tl.range(0, 10, loop_unroll_factor=unroll_factor): + tl.atomic_add(dst + pid, i + pid) + + def check_loop_unroll_count(ir, opStr, loop_unroll_factor): + for line in ir.splitlines(): + if opStr in line: + loop_unroll_factor = loop_unroll_factor - 1 + # Sometimes we get a remainder loop + assert loop_unroll_factor <= 0 + + # Try for all different loop unroll factors: + for unroll_factor in [1, 2, 4, 5, 8]: + h = _kernel[(1, )](torch.empty(1, device=device), unroll_factor) + check_loop_unroll_count(h.asm["ttir"], 'tt.atomic_rmw', unroll_factor) + + +@triton.jit +def sanitize_add(a, b): + a64 = a.to(tl.int64) + b64 = b.to(tl.int64) + r64 = a64 + b64 + tl.device_assert((r64 >= -2**31) & (r64 <= 2**31 - 1)) + return a + b + + +def test_side_effectful_reduction(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.reduce(vals, 0, sanitize_add) + tl.store(Z, z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros((), device="cuda", dtype=torch.int32) + sanitize_sum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.sum().to(torch.int32)) + + +@pytest.mark.parametrize("reduce_dim", [0, 1]) +def test_side_effectful_reduction_2d(device, reduce_dim): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, reduce_dim: tl.constexpr, + NON_REDUCE_DIM: tl.constexpr): + offsets = tl.arange(0, BLOCK_0)[:, None] * BLOCK_1 + tl.arange(0, BLOCK_1)[None, :] + vals = tl.load(X + offsets) + z = tl.reduce(vals, reduce_dim, sanitize_add) + tl.store(Z + tl.arange(0, NON_REDUCE_DIM), z) + + BLOCK_0 = 16 + BLOCK_1 = 32 + NON_REDUCE_DIM = BLOCK_1 if reduce_dim == 0 else BLOCK_0 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK_0, BLOCK_1], device="cuda", dtype=torch.int32) + Z = torch.zeros([NON_REDUCE_DIM], device="cuda", dtype=torch.int32) + sanitize_sum_2d_kernel[(1, )](Z, X, BLOCK_0=BLOCK_0, BLOCK_1=BLOCK_1, reduce_dim=reduce_dim, + NON_REDUCE_DIM=NON_REDUCE_DIM) + torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32)) + + +def test_dtype(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(X): + dtype_x: tl.constexpr = X.dtype.element_ty + tl.static_assert(dtype_x == tl.int32) + tl.static_assert(dtype_x == tl.constexpr(tl.int32)) + tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32)) + + X = torch.zeros(1, dtype=torch.int32, device=device) + kernel[(1, )](X) + + +def test_side_effectful_scan(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.associative_scan(vals, 0, sanitize_add) + tl.store(Z + tl.arange(0, BLOCK), z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros_like(X) + sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32)) + + +# stress test slice layout usages in reductions. +@pytest.mark.parametrize("in_shape, perm, red_dims", [ + ((4, 32, 32, 4, 2), [2, 1, 0, 3, 4], [3, 1, 0]), + ((8, 2, 32, 4, 16), [4, 0, 1, 3, 2], [0, 2, 0]), +]) +def test_chained_reductions(in_shape, perm, red_dims, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def kernel(In, Out, # + dim_0: tl.constexpr, dim_1: tl.constexpr, dim_2: tl.constexpr, dim_3: tl.constexpr, dim_4: tl.constexpr, + perm_0: tl.constexpr, perm_1: tl.constexpr, perm_2: tl.constexpr, perm_3: tl.constexpr, + perm_4: tl.constexpr, red_dim_0: tl.constexpr, red_dim_1: tl.constexpr, red_dim_2: tl.constexpr): + idx = tl.arange(0, dim_0 * dim_1 * dim_2 * dim_3 * dim_4) + idx = idx.reshape(dim_0, dim_1, dim_2, dim_3, dim_4) + vals = tl.load(In + idx) + vals = tl.permute(vals, [perm_0, perm_1, perm_2, perm_3, perm_4]) + r = tl.sum(tl.sum(tl.sum(vals, red_dim_0), red_dim_1), red_dim_2) + st_idx = tl.arange(0, r.shape[0] * r.shape[1]).reshape(r.shape) + tl.store(Out + st_idx, r) + + input = torch.randint(0, 1000, in_shape, device=device, dtype=torch.int32) + temp = torch.permute(input, perm).contiguous() + ref = torch.sum(torch.sum(torch.sum(temp, dim=red_dims[0]), dim=red_dims[1]), dim=red_dims[2]) + result = torch.empty_like(ref) + kernel[(1, )](input, result, input.shape[0], input.shape[1], input.shape[2], input.shape[3], input.shape[4], + perm[0], perm[1], perm[2], perm[3], perm[4], red_dims[0], red_dims[1], red_dims[2]) + + assert torch.all(ref == result) + + +@triton.jit +def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, + src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, + idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, + out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, + out_stride1: tl.constexpr): + src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) + src = tl.load(src_ptr + src_offs) + + idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) + idx = tl.load(idx_ptr + idx_offs) + + out = tl.gather(src, idx, axis) + + out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) + tl.store(out_ptr + out_offs, out) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("src_shape, indices_shape, axis", [ + ([4, 4], [8, 4], 0), + ([128, 64], [256, 64], 0), + ([128, 64], [128, 128], 1), +]) +def test_gather(src_shape, indices_shape, axis, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + + gather_test_kernel[(1, )](src, indices, output, axis, src.shape[0], + src.shape[1], src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], + indices.stride(0), indices.stride(1), output.shape[0], output.shape[1], + output.stride(0), output.stride(1)) + + return output + + src = torch.randn(src_shape, device=device) + indices = torch.randint(0, src.shape[axis], indices_shape, device=device) + ref = torch.gather(src, axis, indices) + result = triton_gather(src, axis, indices) + torch.testing.assert_close(result, ref, rtol=0, atol=0) + + +# These layouts are specially chosen to trigger the warp shuffle codegen. +@pytest.mark.parametrize("src_shape, indices_shape, axis, src_layout, indices_layout", [ + ([32, 16], [32, 16], 0, + "linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), + ([128, 64], [256, 64], 0, + "linear<{register = [[0, 2], [32, 0], [2, 0], [0, 16], [0, 32], [64, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[0, 2], [32, 0], [0, 32], [2, 0], [0, 16], [64, 0], [128, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), +]) +def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path, + device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + if is_hip(): + pytest.skip("warp-local gather has issues on HIP") + + def prepare_kernel(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + compiled = gather_test_kernel.warmup(src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0), + src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), + indices.stride(1), output.shape[0], output.shape[1], output.stride(0), + output.stride(1), grid=(1, )) + return output, compiled + + def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout, idx_layout): + ir = f""" +#src_layout = #ttg.{src_layout} +#idx_layout = #ttg.{idx_layout} +{ir}""" + + dtypes = {torch.int32: "i32", torch.float32: "f32", torch.int64: "i64", torch.float64: "f64"} + + src_spec = f"{src.shape[0]}x{src.shape[1]}x{dtypes[src.dtype]}" + indices_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[indices.dtype]}" + output_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[src.dtype]}" + + pat = r"(%[0-9]+) = tt.gather (%[0-9]+)\[(%[0-9]+)\] {axis = " + pat += str(axis) + pat += r" : i32} : \(tensor\<" + pat += src_spec + pat += r", (#[a-z]+[0-9]+)\>, tensor\<" + pat += indices_spec + pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<" + pat += output_spec + pat += r", (#[a-z]+[0-9]+)\>" + + repl = r""" + %src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout> + %idx = ttg.convert_layout \3 : tensor<""" + indices_spec + r""", \5> -> tensor<""" + indices_spec + r""", #idx_layout> + %out = tt.gather %src[%idx] {axis = """ + str( + axis + ) + r""" : i32} : (tensor<""" + src_spec + r""", #src_layout>, tensor<""" + indices_spec + r""", #idx_layout>) -> tensor<""" + output_spec + r""", #idx_layout> + \1 = ttg.convert_layout %out : tensor<""" + output_spec + r""", #idx_layout> -> tensor<""" + output_spec + r""", \6>""" + return re.sub(pat, repl, ir) + + src = torch.randn(src_shape, device=device) + indices = torch.randint(0, src.shape[axis], indices_shape, device=device) + ref = torch.gather(src, axis, indices) + + output, compiled = prepare_kernel(src, axis, indices) + ir = compiled.asm["ttgir"] + ir = inject_layout(ir, src, axis, indices, src_layout, indices_layout) + + temp_file = tmp_path / "test_warp_gather.ttgir" + temp_file.write_text(ir) + + kernel = triton.compile(str(temp_file)) + assert ("nvvm.shfl.sync.idx" in kernel.asm["llir"]) or ("llvm.amdgcn.ds.bpermute" in kernel.asm["llir"]) + + kernel[(1, 1, 1)](src, indices, output) + + torch.testing.assert_close(output, ref, rtol=0, atol=0) + + +@triton.jit +def mul_jit_function(x, y): + return x * y + + +@triton.jit +def apply_binary_op(x, combine_op): + return combine_op(x, x) + + +def test_jit_function_arg(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def square_kernel_jit_function(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + in_data = tl.load(in_ptr + offsets) + out_data = apply_binary_op(in_data, mul_jit_function) # pass a JITFunction into another JITFunction + tl.store(out_ptr + offsets, out_data) + + BLOCK_SIZE = 16 + x = torch.full((BLOCK_SIZE, ), 3.0, device=device) + out = torch.empty((BLOCK_SIZE, ), device=device) + expect = torch.full((BLOCK_SIZE, ), 9.0, dtype=x.dtype, device=device) + + square_kernel_jit_function[(1, )](x, out, BLOCK_SIZE) + + torch.testing.assert_close(out, expect) + + +@pytest.mark.interpreter +def test_zero_strided_tensors(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def _simple_add( + X, + stride_x_a, + stride_x_b, + ): + pid_a = tl.program_id(0) + pid_b = tl.program_id(1) + + # doesn't directly index c dim, so relies on 0-strided c dim to affect every element + x_ptr = X + pid_a * stride_x_a + pid_b * stride_x_b + + tl.atomic_add(x_ptr, 1) + + x = torch.zeros((2, 2, 1), device=device) + c_dim = 3 + x = x.expand((2, 2, c_dim)) + + a, b, c = x.shape + grid = (a, b, c) + with torch.cuda.device(x.device.index): + _simple_add[grid](x, x.stride(0), x.stride(1)) + + assert torch.allclose(x, torch.ones_like(x) * c_dim) + + +@pytest.mark.interpreter +def test_aliasing(device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def aliasing_kernel(buffer, buffer2): + triton.language.store(buffer, 1) + + buffer = torch.zeros(1, device=device) + aliasing_kernel[(1, )](buffer, buffer) + assert buffer[0] == 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_strided_load(dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def take_every_second_element(x_ptr, output_ptr, BLOCK_SIZE: tl.constexpr): + strided_offsets = tl.arange(0, BLOCK_SIZE) * 2 + linear_offsets = tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + strided_offsets) + tl.store(output_ptr + linear_offsets, x) + + STRIDE = 2 + SIZE = 512 + OUT_SIZE = SIZE // STRIDE + + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + out_tri = torch.empty(OUT_SIZE, device=device) + take_every_second_element[(1, 1)](x_tri, out_tri, OUT_SIZE) + + # Test that every second element (starting from [0]) from x is stored in out_tri + np.testing.assert_allclose(x[::2], to_numpy(out_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_strided_store(dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def store_into_every_second(x_ptr, output_ptr, BLOCK_SIZE: tl.constexpr): + strided_offsets = tl.arange(0, BLOCK_SIZE) * 2 + linear_offsets = tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + linear_offsets) + tl.store(output_ptr + strided_offsets, x) + + STRIDE = 2 + SIZE = 512 + OUT_SIZE = SIZE * STRIDE + + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + out_tri = torch.zeros(OUT_SIZE, device=device) + store_into_every_second[(1, 1)](x_tri, out_tri, SIZE) + + # Test that every second element (starting from [0]) is the same as in x + np.testing.assert_allclose(x, to_numpy(out_tri)[::2]) + # Test that every second element (starting from [1]) is still zero + np.testing.assert_allclose(np.zeros_like(x), to_numpy(out_tri)[1::2]) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_indirect_load(dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def indirect_load(offset_ptr, x_ptr, output_ptr, SIZE: tl.constexpr): + linear_offsets = tl.arange(0, SIZE) + offsets = tl.load(offset_ptr + linear_offsets) + x = tl.load(x_ptr + offsets) + tl.store(output_ptr + linear_offsets, x) + + SIZE = 512 + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + # Flip the range to load the tensor in reverse order + ptr = torch.arange(SIZE, device=device, dtype=torch.int32).flip(0) + out_tri = torch.empty(SIZE, device=device) + indirect_load[(1, 1)](ptr, x_tri, out_tri, SIZE) + + np.testing.assert_allclose(np.flip(x), to_numpy(out_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", list(dtypes) + ["bfloat16"]) +def test_indirect_store(dtype, device): + check_skip("test_core_skip.csv", inspect.currentframe().f_code.co_name, locals()) + + @triton.jit + def indirect_store(offset_ptr, x_ptr, output_ptr, SIZE: tl.constexpr): + linear_offsets = tl.arange(0, SIZE) + offsets = tl.load(offset_ptr + linear_offsets) + x = tl.load(x_ptr + linear_offsets) + tl.store(output_ptr + offsets, x) + + SIZE = 512 + x = numpy_random(SIZE, dtype_str=dtype) + x_tri = to_triton(x, device) + # Flip the range to store the tensor in reverse order + ptr = torch.arange(SIZE, device=device, dtype=torch.int32).flip(0) + out_tri = torch.empty(SIZE, device=device) + indirect_store[(1, 1)](ptr, x_tri, out_tri, SIZE) + + np.testing.assert_allclose(np.flip(x), to_numpy(out_tri)) diff --git a/third_party/enflame/python/test/unit/language/test_core_skip.csv b/third_party/enflame/python/test/unit/language/test_core_skip.csv new file mode 100644 index 000000000..bd50b6410 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_core_skip.csv @@ -0,0 +1,3090 @@ +case_name,resa,resb,case_identifier,arch +test_bin_op_1-int32-float32-+,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int32-dtype_y=float32-op=+-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-float32-+,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint32-dtype_y=float32-op=+-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-float32-int32-+,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=float32-dtype_y=int32-op=+-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-float32-uint32-+,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=float32-dtype_y=uint32-op=+-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-float32--,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int32-dtype_y=float32-op=--num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-float32--,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint32-dtype_y=float32-op=--num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-float32-int32--,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=float32-dtype_y=int32-op=--num_ctas=1-,gcu300 +test_bin_op_1-float32-uint32--,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=float32-dtype_y=uint32-op=--num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-float32-*,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int32-dtype_y=float32-op=*-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-float32-*,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint32-dtype_y=float32-op=*-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-float32-int32-*,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=float32-dtype_y=int32-op=*-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-float32-uint32-*,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=float32-dtype_y=uint32-op=*-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int8-int8-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int8-dtype_y=int8-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int8-int16-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int8-dtype_y=int16-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int8-int32-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int8-dtype_y=int32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int16-int8-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int16-dtype_y=int8-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int16-int32-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int16-dtype_y=int32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-int8-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int32-dtype_y=int8-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-int16-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int32-dtype_y=int16-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-int32-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int32-dtype_y=int32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-float32-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int32-dtype_y=float32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint8-uint8-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint8-dtype_y=uint8-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint8-uint16-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint8-dtype_y=uint16-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint8-uint32-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint8-dtype_y=uint32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint16-uint8-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint16-dtype_y=uint8-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint16-uint16-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint16-dtype_y=uint16-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint16-uint32-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint16-dtype_y=uint32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-uint8-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint32-dtype_y=uint8-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-uint16-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint32-dtype_y=uint16-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-uint32-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint32-dtype_y=uint32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-float32-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint32-dtype_y=float32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-float32-int32-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=float32-dtype_y=int32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-float32-uint32-/,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=float32-dtype_y=uint32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-float32-%,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=int32-dtype_y=float32-op=%-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-float32-%,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=uint32-dtype_y=float32-op=%-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-float32-int32-%,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=float32-dtype_y=int32-op=%-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-float32-uint32-%,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_dtype_x=float32-dtype_y=uint32-op=%-num_ctas=1-,gcu300/gcu400/gcu410 +test_precise_math_1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32),dtype fp64 is not supported on gcu,Not supported on gcu,test_precise_math_expr_prec=tl.math.sqrt_rn(x)-expr_ref=tl.math.sqrt(x.to(tl.float64)).to(tl.float32)-num_ctas=1-,gcu300/gcu400/gcu410 +"test_precise_math_1-tl.math.div_rn(x,y)-(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)",dtype fp64 is not supported on gcu,Not supported on gcu,"test_precise_math_expr_prec=tl.math.div_rn(x,y)-expr_ref=(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)-num_ctas=1-",gcu300/gcu400/gcu410 +test_abs_fp8_in_dtype0, dtype fp8 is not supported on gcu300,Not supported on gcu300,test_abs_fp8_in_dtype=fp8e4b15-,gcu300/gcu400/gcu410 +test_abs_fp8_in_dtype1,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_abs_fp8_in_dtype=fp8e4nv-,gcu300 +test_abs_fp8_in_dtype2,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_abs_fp8_in_dtype=fp8e5-,gcu300 +test_atomic_rmw_add-float16-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=all_neg-sem=None-,gcu300 +test_atomic_rmw_add-uint32-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=all_neg-sem=None-,gcu300 +test_atomic_rmw_add-int32-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=all_neg-sem=None-,gcu300 +test_atomic_rmw_add-float32-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=all_neg-sem=None-,gcu300 +test_atomic_rmw_add-uint64-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=all_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=all_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-all_neg-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=all_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=all_neg-sem=None-,gcu300 +test_atomic_rmw_max-int32-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=all_neg-sem=None-,gcu300 +test_atomic_rmw_max-float32-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=all_neg-sem=None-,gcu300 +test_atomic_rmw_max-uint64-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=all_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=all_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-all_neg-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=all_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=all_neg-sem=None-,gcu300 +test_atomic_rmw_min-int32-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=all_neg-sem=None-,gcu300 +test_atomic_rmw_min-float32-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=all_neg-sem=None-,gcu300 +test_atomic_rmw_min-uint64-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=all_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-all_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=all_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-all_neg-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=all_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=all_neg-sem=acquire-,gcu300 +test_atomic_rmw_add-uint32-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=all_neg-sem=acquire-,gcu300 +test_atomic_rmw_add-int32-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=all_neg-sem=acquire-,gcu300 +test_atomic_rmw_add-float32-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=all_neg-sem=acquire-,gcu300 +test_atomic_rmw_add-uint64-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=all_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=all_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-all_neg-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=all_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=all_neg-sem=acquire-,gcu300 +test_atomic_rmw_max-int32-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=all_neg-sem=acquire-,gcu300 +test_atomic_rmw_max-float32-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=all_neg-sem=acquire-,gcu300 +test_atomic_rmw_max-uint64-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=all_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=all_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-all_neg-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=all_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=all_neg-sem=acquire-,gcu300 +test_atomic_rmw_min-int32-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=all_neg-sem=acquire-,gcu300 +test_atomic_rmw_min-float32-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=all_neg-sem=acquire-,gcu300 +test_atomic_rmw_min-uint64-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=all_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-all_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=all_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-all_neg-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=all_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=all_neg-sem=release-,gcu300 +test_atomic_rmw_add-uint32-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=all_neg-sem=release-,gcu300 +test_atomic_rmw_add-int32-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=all_neg-sem=release-,gcu300 +test_atomic_rmw_add-float32-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=all_neg-sem=release-,gcu300 +test_atomic_rmw_add-uint64-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=all_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=all_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-all_neg-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=all_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=all_neg-sem=release-,gcu300 +test_atomic_rmw_max-int32-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=all_neg-sem=release-,gcu300 +test_atomic_rmw_max-float32-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=all_neg-sem=release-,gcu300 +test_atomic_rmw_max-uint64-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=all_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=all_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-all_neg-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=all_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=all_neg-sem=release-,gcu300 +test_atomic_rmw_min-int32-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=all_neg-sem=release-,gcu300 +test_atomic_rmw_min-float32-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=all_neg-sem=release-,gcu300 +test_atomic_rmw_min-uint64-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=all_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-all_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=all_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-all_neg-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=all_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=all_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_add-uint32-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=all_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_add-int32-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=all_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_add-float32-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=all_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_add-uint64-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=all_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=all_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-all_neg-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=all_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=all_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_max-int32-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=all_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_max-float32-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=all_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_max-uint64-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=all_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=all_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-all_neg-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=all_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=all_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_min-int32-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=all_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_min-float32-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=all_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_min-uint64-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=all_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-all_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=all_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-all_neg-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=all_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=all_neg-sem=relaxed-,gcu300 +test_atomic_rmw_add-uint32-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=all_neg-sem=relaxed-,gcu300 +test_atomic_rmw_add-int32-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=all_neg-sem=relaxed-,gcu300 +test_atomic_rmw_add-float32-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=all_neg-sem=relaxed-,gcu300 +test_atomic_rmw_add-uint64-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=all_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=all_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-all_neg-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=all_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=all_neg-sem=relaxed-,gcu300 +test_atomic_rmw_max-int32-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=all_neg-sem=relaxed-,gcu300 +test_atomic_rmw_max-float32-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=all_neg-sem=relaxed-,gcu300 +test_atomic_rmw_max-uint64-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=all_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=all_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-all_neg-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=all_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=all_neg-sem=relaxed-,gcu300 +test_atomic_rmw_min-int32-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=all_neg-sem=relaxed-,gcu300 +test_atomic_rmw_min-float32-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=all_neg-sem=relaxed-,gcu300 +test_atomic_rmw_min-uint64-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=all_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-all_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=all_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-all_neg-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=all_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=all_pos-sem=None-,gcu300 +test_atomic_rmw_add-uint32-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=all_pos-sem=None-,gcu300 +test_atomic_rmw_add-int32-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=all_pos-sem=None-,gcu300 +test_atomic_rmw_add-float32-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=all_pos-sem=None-,gcu300 +test_atomic_rmw_add-uint64-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=all_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=all_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-all_pos-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=all_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=all_pos-sem=None-,gcu300 +test_atomic_rmw_max-int32-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=all_pos-sem=None-,gcu300 +test_atomic_rmw_max-float32-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=all_pos-sem=None-,gcu300 +test_atomic_rmw_max-uint64-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=all_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=all_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-all_pos-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=all_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=all_pos-sem=None-,gcu300 +test_atomic_rmw_min-int32-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=all_pos-sem=None-,gcu300 +test_atomic_rmw_min-float32-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=all_pos-sem=None-,gcu300 +test_atomic_rmw_min-uint64-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=all_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-all_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=all_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-all_pos-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=all_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=all_pos-sem=acquire-,gcu300 +test_atomic_rmw_add-uint32-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=all_pos-sem=acquire-,gcu300 +test_atomic_rmw_add-int32-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=all_pos-sem=acquire-,gcu300 +test_atomic_rmw_add-float32-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=all_pos-sem=acquire-,gcu300 +test_atomic_rmw_add-uint64-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=all_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=all_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-all_pos-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=all_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=all_pos-sem=acquire-,gcu300 +test_atomic_rmw_max-int32-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=all_pos-sem=acquire-,gcu300 +test_atomic_rmw_max-float32-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=all_pos-sem=acquire-,gcu300 +test_atomic_rmw_max-uint64-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=all_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=all_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-all_pos-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=all_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=all_pos-sem=acquire-,gcu300 +test_atomic_rmw_min-int32-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=all_pos-sem=acquire-,gcu300 +test_atomic_rmw_min-float32-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=all_pos-sem=acquire-,gcu300 +test_atomic_rmw_min-uint64-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=all_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-all_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=all_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-all_pos-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=all_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=all_pos-sem=release-,gcu300 +test_atomic_rmw_add-uint32-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=all_pos-sem=release-,gcu300 +test_atomic_rmw_add-int32-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=all_pos-sem=release-,gcu300 +test_atomic_rmw_add-float32-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=all_pos-sem=release-,gcu300 +test_atomic_rmw_add-uint64-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=all_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=all_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-all_pos-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=all_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=all_pos-sem=release-,gcu300 +test_atomic_rmw_max-int32-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=all_pos-sem=release-,gcu300 +test_atomic_rmw_max-float32-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=all_pos-sem=release-,gcu300 +test_atomic_rmw_max-uint64-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=all_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=all_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-all_pos-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=all_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=all_pos-sem=release-,gcu300 +test_atomic_rmw_min-int32-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=all_pos-sem=release-,gcu300 +test_atomic_rmw_min-float32-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=all_pos-sem=release-,gcu300 +test_atomic_rmw_min-uint64-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=all_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-all_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=all_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-all_pos-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=all_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=all_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_add-uint32-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=all_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_add-int32-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=all_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_add-float32-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=all_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_add-uint64-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=all_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=all_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-all_pos-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=all_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=all_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_max-int32-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=all_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_max-float32-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=all_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_max-uint64-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=all_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=all_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-all_pos-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=all_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=all_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_min-int32-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=all_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_min-float32-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=all_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_min-uint64-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=all_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-all_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=all_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-all_pos-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=all_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=all_pos-sem=relaxed-,gcu300 +test_atomic_rmw_add-uint32-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=all_pos-sem=relaxed-,gcu300 +test_atomic_rmw_add-int32-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=all_pos-sem=relaxed-,gcu300 +test_atomic_rmw_add-float32-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=all_pos-sem=relaxed-,gcu300 +test_atomic_rmw_add-uint64-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=all_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=all_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-all_pos-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=all_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=all_pos-sem=relaxed-,gcu300 +test_atomic_rmw_max-int32-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=all_pos-sem=relaxed-,gcu300 +test_atomic_rmw_max-float32-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=all_pos-sem=relaxed-,gcu300 +test_atomic_rmw_max-uint64-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=all_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=all_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-all_pos-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=all_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=all_pos-sem=relaxed-,gcu300 +test_atomic_rmw_min-int32-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=all_pos-sem=relaxed-,gcu300 +test_atomic_rmw_min-float32-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=all_pos-sem=relaxed-,gcu300 +test_atomic_rmw_min-uint64-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=all_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-all_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=all_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-all_pos-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=all_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=min_neg-sem=None-,gcu300 +test_atomic_rmw_add-uint32-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=min_neg-sem=None-,gcu300 +test_atomic_rmw_add-int32-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=min_neg-sem=None-,gcu300 +test_atomic_rmw_add-float32-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=min_neg-sem=None-,gcu300 +test_atomic_rmw_add-uint64-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=min_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=min_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-min_neg-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=min_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=min_neg-sem=None-,gcu300 +test_atomic_rmw_max-int32-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=min_neg-sem=None-,gcu300 +test_atomic_rmw_max-float32-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=min_neg-sem=None-,gcu300 +test_atomic_rmw_max-uint64-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=min_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=min_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-min_neg-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=min_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=min_neg-sem=None-,gcu300 +test_atomic_rmw_min-int32-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=min_neg-sem=None-,gcu300 +test_atomic_rmw_min-float32-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=min_neg-sem=None-,gcu300 +test_atomic_rmw_min-uint64-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=min_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-min_neg-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=min_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-min_neg-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=min_neg-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=min_neg-sem=acquire-,gcu300 +test_atomic_rmw_add-uint32-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=min_neg-sem=acquire-,gcu300 +test_atomic_rmw_add-int32-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=min_neg-sem=acquire-,gcu300 +test_atomic_rmw_add-float32-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=min_neg-sem=acquire-,gcu300 +test_atomic_rmw_add-uint64-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=min_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=min_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-min_neg-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=min_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=min_neg-sem=acquire-,gcu300 +test_atomic_rmw_max-int32-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=min_neg-sem=acquire-,gcu300 +test_atomic_rmw_max-float32-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=min_neg-sem=acquire-,gcu300 +test_atomic_rmw_max-uint64-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=min_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=min_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-min_neg-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=min_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=min_neg-sem=acquire-,gcu300 +test_atomic_rmw_min-int32-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=min_neg-sem=acquire-,gcu300 +test_atomic_rmw_min-float32-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=min_neg-sem=acquire-,gcu300 +test_atomic_rmw_min-uint64-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=min_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-min_neg-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=min_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-min_neg-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=min_neg-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=min_neg-sem=release-,gcu300 +test_atomic_rmw_add-uint32-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=min_neg-sem=release-,gcu300 +test_atomic_rmw_add-int32-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=min_neg-sem=release-,gcu300 +test_atomic_rmw_add-float32-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=min_neg-sem=release-,gcu300 +test_atomic_rmw_add-uint64-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=min_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=min_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-min_neg-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=min_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=min_neg-sem=release-,gcu300 +test_atomic_rmw_max-int32-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=min_neg-sem=release-,gcu300 +test_atomic_rmw_max-float32-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=min_neg-sem=release-,gcu300 +test_atomic_rmw_max-uint64-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=min_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=min_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-min_neg-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=min_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=min_neg-sem=release-,gcu300 +test_atomic_rmw_min-int32-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=min_neg-sem=release-,gcu300 +test_atomic_rmw_min-float32-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=min_neg-sem=release-,gcu300 +test_atomic_rmw_min-uint64-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=min_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-min_neg-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=min_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-min_neg-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=min_neg-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=min_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_add-uint32-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=min_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_add-int32-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=min_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_add-float32-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=min_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_add-uint64-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=min_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=min_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-min_neg-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=min_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=min_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_max-int32-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=min_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_max-float32-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=min_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_max-uint64-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=min_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=min_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-min_neg-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=min_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=min_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_min-int32-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=min_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_min-float32-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=min_neg-sem=acq_rel-,gcu300 +test_atomic_rmw_min-uint64-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=min_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-min_neg-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=min_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-min_neg-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=min_neg-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=min_neg-sem=relaxed-,gcu300 +test_atomic_rmw_add-uint32-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=min_neg-sem=relaxed-,gcu300 +test_atomic_rmw_add-int32-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=min_neg-sem=relaxed-,gcu300 +test_atomic_rmw_add-float32-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=min_neg-sem=relaxed-,gcu300 +test_atomic_rmw_add-uint64-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=min_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=min_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-min_neg-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=min_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=min_neg-sem=relaxed-,gcu300 +test_atomic_rmw_max-int32-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=min_neg-sem=relaxed-,gcu300 +test_atomic_rmw_max-float32-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=min_neg-sem=relaxed-,gcu300 +test_atomic_rmw_max-uint64-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=min_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=min_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-min_neg-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=min_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=min_neg-sem=relaxed-,gcu300 +test_atomic_rmw_min-int32-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=min_neg-sem=relaxed-,gcu300 +test_atomic_rmw_min-float32-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=min_neg-sem=relaxed-,gcu300 +test_atomic_rmw_min-uint64-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=min_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-min_neg-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=min_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-min_neg-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=min_neg-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=max_pos-sem=None-,gcu300 +test_atomic_rmw_add-uint32-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=max_pos-sem=None-,gcu300 +test_atomic_rmw_add-int32-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=max_pos-sem=None-,gcu300 +test_atomic_rmw_add-float32-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=max_pos-sem=None-,gcu300 +test_atomic_rmw_add-uint64-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=max_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=max_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-max_pos-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=max_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=max_pos-sem=None-,gcu300 +test_atomic_rmw_max-int32-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=max_pos-sem=None-,gcu300 +test_atomic_rmw_max-float32-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=max_pos-sem=None-,gcu300 +test_atomic_rmw_max-uint64-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=max_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=max_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-max_pos-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=max_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=max_pos-sem=None-,gcu300 +test_atomic_rmw_min-int32-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=max_pos-sem=None-,gcu300 +test_atomic_rmw_min-float32-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=max_pos-sem=None-,gcu300 +test_atomic_rmw_min-uint64-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=max_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-max_pos-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=max_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-max_pos-none,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=max_pos-sem=None-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=max_pos-sem=acquire-,gcu300 +test_atomic_rmw_add-uint32-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=max_pos-sem=acquire-,gcu300 +test_atomic_rmw_add-int32-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=max_pos-sem=acquire-,gcu300 +test_atomic_rmw_add-float32-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=max_pos-sem=acquire-,gcu300 +test_atomic_rmw_add-uint64-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=max_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=max_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-max_pos-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=max_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=max_pos-sem=acquire-,gcu300 +test_atomic_rmw_max-int32-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=max_pos-sem=acquire-,gcu300 +test_atomic_rmw_max-float32-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=max_pos-sem=acquire-,gcu300 +test_atomic_rmw_max-uint64-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=max_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=max_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-max_pos-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=max_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=max_pos-sem=acquire-,gcu300 +test_atomic_rmw_min-int32-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=max_pos-sem=acquire-,gcu300 +test_atomic_rmw_min-float32-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=max_pos-sem=acquire-,gcu300 +test_atomic_rmw_min-uint64-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=max_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-max_pos-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=max_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-max_pos-acquire,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=max_pos-sem=acquire-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=max_pos-sem=release-,gcu300 +test_atomic_rmw_add-uint32-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=max_pos-sem=release-,gcu300 +test_atomic_rmw_add-int32-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=max_pos-sem=release-,gcu300 +test_atomic_rmw_add-float32-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=max_pos-sem=release-,gcu300 +test_atomic_rmw_add-uint64-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=max_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=max_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-max_pos-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=max_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=max_pos-sem=release-,gcu300 +test_atomic_rmw_max-int32-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=max_pos-sem=release-,gcu300 +test_atomic_rmw_max-float32-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=max_pos-sem=release-,gcu300 +test_atomic_rmw_max-uint64-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=max_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=max_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-max_pos-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=max_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=max_pos-sem=release-,gcu300 +test_atomic_rmw_min-int32-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=max_pos-sem=release-,gcu300 +test_atomic_rmw_min-float32-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=max_pos-sem=release-,gcu300 +test_atomic_rmw_min-uint64-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=max_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-max_pos-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=max_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-max_pos-release,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=max_pos-sem=release-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=max_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_add-uint32-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=max_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_add-int32-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=max_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_add-float32-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=max_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_add-uint64-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=max_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=max_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-max_pos-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=max_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=max_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_max-int32-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=max_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_max-float32-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=max_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_max-uint64-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=max_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=max_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-max_pos-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=max_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=max_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_min-int32-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=max_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_min-float32-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=max_pos-sem=acq_rel-,gcu300 +test_atomic_rmw_min-uint64-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=max_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-max_pos-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=max_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-max_pos-acq_rel,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=max_pos-sem=acq_rel-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float16-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float16-mode=max_pos-sem=relaxed-,gcu300 +test_atomic_rmw_add-uint32-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint32-mode=max_pos-sem=relaxed-,gcu300 +test_atomic_rmw_add-int32-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int32-mode=max_pos-sem=relaxed-,gcu300 +test_atomic_rmw_add-float32-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=float32-mode=max_pos-sem=relaxed-,gcu300 +test_atomic_rmw_add-uint64-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=uint64-mode=max_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-int64-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=add-dtype_x_str=int64-mode=max_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_add-float64-max_pos-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=add-dtype_x_str=float64-mode=max_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-uint32-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint32-mode=max_pos-sem=relaxed-,gcu300 +test_atomic_rmw_max-int32-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int32-mode=max_pos-sem=relaxed-,gcu300 +test_atomic_rmw_max-float32-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=float32-mode=max_pos-sem=relaxed-,gcu300 +test_atomic_rmw_max-uint64-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=uint64-mode=max_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-int64-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=max-dtype_x_str=int64-mode=max_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_max-float64-max_pos-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=max-dtype_x_str=float64-mode=max_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-uint32-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint32-mode=max_pos-sem=relaxed-,gcu300 +test_atomic_rmw_min-int32-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int32-mode=max_pos-sem=relaxed-,gcu300 +test_atomic_rmw_min-float32-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=float32-mode=max_pos-sem=relaxed-,gcu300 +test_atomic_rmw_min-uint64-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=uint64-mode=max_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-int64-max_pos-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_op=min-dtype_x_str=int64-mode=max_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_min-float64-max_pos-relaxed,dtype fp64 is not supported on gcu,Not supported on gcu,test_atomic_rmw_op=min-dtype_x_str=float64-mode=max_pos-sem=relaxed-,gcu300/gcu400/gcu410 +test_atomic_rmw_predicate_1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_rmw_predicate_num_ctas=1-,gcu300 +test_tensor_atomic_rmw_shape0-0-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(2, 2)-axis=0-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape1-1-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(2, 2)-axis=1-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape2-0-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(2, 8)-axis=0-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape3-1-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(2, 8)-axis=1-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape4-0-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(8, 2)-axis=0-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape5-1-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(8, 2)-axis=1-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape6-0-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(8, 8)-axis=0-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape7-1-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(8, 8)-axis=1-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape8-0-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(32, 32)-axis=0-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape9-1-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(32, 32)-axis=1-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape10-0-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(64, 64)-axis=0-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape11-1-1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,"test_tensor_atomic_rmw_shape=(64, 64)-axis=1-num_ctas=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_block_1,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_tensor_atomic_rmw_block_num_ctas=1-,gcu300/gcu400/gcu410 +test_atomic_cas_1-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_cas_sem=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_atomic_cas_1-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_cas_sem=acquire-num_ctas=1-,gcu300/gcu400/gcu410 +test_atomic_cas_1-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_cas_sem=release-num_ctas=1-,gcu300/gcu400/gcu410 +test_atomic_cas_1-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_cas_sem=acq_rel-num_ctas=1-,gcu300/gcu400/gcu410 +test_atomic_cas_1-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_atomic_cas_sem=relaxed-num_ctas=1-,gcu300/gcu400/gcu410 +test_tensor_atomic_cas_1-none,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_tensor_atomic_cas_sem=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_tensor_atomic_cas_1-acquire,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_tensor_atomic_cas_sem=acquire-num_ctas=1-,gcu300/gcu400/gcu410 +test_tensor_atomic_cas_1-release,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_tensor_atomic_cas_sem=release-num_ctas=1-,gcu300/gcu400/gcu410 +test_tensor_atomic_cas_1-acq_rel,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_tensor_atomic_cas_sem=acq_rel-num_ctas=1-,gcu300/gcu400/gcu410 +test_tensor_atomic_cas_1-relaxed,tt.atomic_xxx op is not supported on gcu300,Not supported on gcu300,test_tensor_atomic_cas_sem=relaxed-num_ctas=1-,gcu300/gcu400/gcu410 +test_cast_1-float8_e5m2-float16-false-1024,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=float8_e5m2-dtype_z=float16-bitcast=False-size=1024-num_ctas=1-,gcu300/gcu400/gcu410 +test_cast_1-float8_e5m2-float16-false-32,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=float8_e5m2-dtype_z=float16-bitcast=False-size=32-num_ctas=1-,gcu300/gcu400/gcu410 +test_cast_1-float8_e5m2-float32-false-1024,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=float8_e5m2-dtype_z=float32-bitcast=False-size=1024-num_ctas=1-,gcu300/gcu400/gcu410 +test_cast_1-float8_e5m2-float32-false-32,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=float8_e5m2-dtype_z=float32-bitcast=False-size=32-num_ctas=1-,gcu300/gcu400/gcu410 +test_cast_1-float16-float8_e5m2-false-1024,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=float16-dtype_z=float8_e5m2-bitcast=False-size=1024-num_ctas=1-,gcu300/gcu400/gcu410 +test_cast_1-float16-float8_e5m2-false-32,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=float16-dtype_z=float8_e5m2-bitcast=False-size=32-num_ctas=1-,gcu300/gcu400/gcu410 +test_cast_1-float32-float8_e5m2-false-1024,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=float32-dtype_z=float8_e5m2-bitcast=False-size=1024-num_ctas=1-,gcu300/gcu400/gcu410 +test_cast_1-float32-float8_e5m2-false-32,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=float32-dtype_z=float8_e5m2-bitcast=False-size=32-num_ctas=1-,gcu300/gcu400/gcu410 +test_load_store_same_ptr,grid size exceeds the limitation of gcu300(0xffff),Not supported on gcu,test_load_store_same_ptr_,gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape504-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape505-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape506-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape507-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape508-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape509-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape510-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape511-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape512-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape513-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape514-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape515-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape516-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape517-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape518-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape519-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape520-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape521-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape522-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape523-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape524-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape525-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape526-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape527-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape528-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape529-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape530-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape531-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape532-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape533-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape534-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape535-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape536-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape537-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape538-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape539-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape540-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape541-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape542-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape543-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape544-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape545-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape546-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape547-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape548-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape549-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape550-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape551-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape552-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape553-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape554-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape555-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape556-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape557-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape558-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape559-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape560-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape561-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape562-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape563-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape564-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape565-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape566-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape567-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape568-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape569-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape570-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape571-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape572-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape573-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape574-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape575-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape576-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape577-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape578-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape579-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape580-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape581-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape582-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape583-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape584-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape585-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape586-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape587-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape588-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape589-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape590-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape591-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape592-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape593-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape594-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape595-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape596-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape597-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape598-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape599-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape600-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape601-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape602-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape603-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape604-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape605-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape606-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape607-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape608-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape609-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape610-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape611-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape612-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape613-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape614-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape615-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape616-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape617-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape618-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape619-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape620-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape621-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape622-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape623-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape624-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape625-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape626-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape627-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape628-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape629-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape630-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape631-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape632-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape633-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape634-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape635-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape636-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape637-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape638-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape639-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape640-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape641-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape642-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape643-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape644-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape645-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape646-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape647-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape648-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape649-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape650-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape651-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape652-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape653-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape654-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape655-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape656-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape657-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape658-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape659-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape660-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape661-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape662-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape663-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape664-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape665-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-int32-shape666-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-int32-shape667-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-int32-shape668-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-int32-shape669-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape670-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-int32-shape671-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape672-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape673-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape674-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape675-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape676-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape677-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape678-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape679-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape680-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape681-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape682-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape683-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape684-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape685-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape686-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape687-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape688-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape689-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape690-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape691-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape692-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape693-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape694-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape695-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape696-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape697-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape698-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape699-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape700-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape701-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape702-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape703-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape704-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape705-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape706-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape707-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape708-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape709-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape710-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape711-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape712-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape713-1-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape714-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape715-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape716-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape717-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape718-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape719-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape720-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape721-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape722-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape723-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape724-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape725-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape726-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape727-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape728-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape729-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape730-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape731-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape732-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape733-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape734-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape735-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape736-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape737-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape738-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape739-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape740-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape741-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape742-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape743-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape744-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape745-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape746-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape747-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape748-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape749-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape750-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape751-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape752-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape753-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape754-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape755-1-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape756-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape757-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape758-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape759-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape760-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape761-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape762-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape763-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape764-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape765-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape766-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape767-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape768-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape769-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape770-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape771-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape772-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape773-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape774-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape775-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape776-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape777-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape778-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape779-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape780-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape781-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape782-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape783-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape784-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape785-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape786-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape787-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape788-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape789-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape790-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape791-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape792-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape793-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape794-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape795-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape796-0-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape797-0-true-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape798-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape799-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape800-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape801-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape802-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape803-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape804-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape805-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape806-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape807-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape808-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape809-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape810-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape811-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape812-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape813-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape814-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape815-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape816-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape817-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape818-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape819-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape820-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape821-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape822-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape823-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape824-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape825-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape826-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape827-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape828-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape829-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape830-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape831-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape832-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape833-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-float32-shape834-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumsum-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-float32-shape835-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=cumprod-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-float32-shape836-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=get_first_element-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_linear_recurrence-float32-shape837-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=linear_recurrence-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape838-0-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cummax-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-float32-shape839-0-false-16,num_warps exceeds the limitation of gcu300(8),Not supported on gcu,"test_scan2d_op=roll-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape840-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape841-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape842-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape845-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape846-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape847-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape848-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape851-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape852-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape853-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape854-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape857-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape858-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape859-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape860-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape863-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape864-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape865-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape866-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape869-1-true-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape870-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape871-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape872-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape875-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape876-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape877-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape878-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape881-1-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape882-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape883-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape884-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape887-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape888-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape889-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape890-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape893-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape894-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape895-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape896-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape899-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape900-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape901-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape902-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape905-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape906-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape907-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape908-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape911-1-false-16,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape912-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape913-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape914-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape917-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape918-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape919-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape920-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape923-1-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape924-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape925-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape926-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape929-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape930-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape931-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape932-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape935-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape936-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape937-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape938-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape941-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape942-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape943-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape944-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape947-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape948-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape949-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape950-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape953-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape954-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape955-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape956-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape959-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape960-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape961-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape962-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape965-0-true-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape966-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape967-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape968-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape971-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape972-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape973-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape974-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape977-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape978-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape979-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape980-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape983-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape984-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape985-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape986-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape989-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape990-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape991-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape992-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape995-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape996-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape997-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape998-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape1001-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumsum-bfloat16-shape1002-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumsum-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cumprod-bfloat16-shape1003-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=cumprod-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_get_first_element-bfloat16-shape1004-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=get_first_element-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_roll-bfloat16-shape1007-0-false-16,block dims not supported on gcu300,Not supported on gcu300,"test_scan2d_op=roll-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout6-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout6-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout6-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout7-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout7-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout7-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout8-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout8-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout8-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout9-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=128-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout9-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout9-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout9-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce1d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout6-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout6-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout6-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout7-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout7-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout7-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout8-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout8-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout8-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout9-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=128-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout9-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout9-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout9-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-expand_reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout6-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout6-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout6-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout7-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout7-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout7-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout8-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout8-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout8-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout9-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=128-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout9-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout9-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout9-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce1d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout6-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout6-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout6-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout7-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout7-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout7-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout8-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout8-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout8-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout9-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=128-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout9-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout9-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout9-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float32-expand_reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout6-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout7-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout8-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout9-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce1d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout6-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout7-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout8-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout9-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-float16-expand_reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout6-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout6-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout6-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout7-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout7-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout7-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout8-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout8-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout8-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout9-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=128-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout9-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout9-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout9-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce1d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-expand_reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-expand_reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-expand_reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-expand_reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-expand_reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-expand_reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-expand_reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-expand_reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-expand_reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout6-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout6-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout6-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout7-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout7-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout7-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout8-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout8-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout8-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout9-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=128-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout9-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout9-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout9-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce1d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout6-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout6-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout6-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout7-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout7-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout7-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout8-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout8-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout8-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout9-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=128-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout9-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout9-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout9-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float32-expand_reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float32-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout6-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout6-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout6-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout7-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout7-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout7-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout8-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout8-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout8-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout9-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=128-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout9-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout9-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout9-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce1d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce1d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout6-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout6-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout6-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout6-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout7-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout7-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout7-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout7-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[2, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout8-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout8-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout8-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout8-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 16, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout9-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=128-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout9-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout9-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout9-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=3, versionMinor=0, warpsPerCTA=[4, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[1, 0], instrShape=[16, 32, 16]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout0-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout0-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout0-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout0-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout1-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout1-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout1-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout1-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout2-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout2-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout2-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout2-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout3-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout3-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout3-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout3-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout4-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout4-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout4-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout4-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[4, 4], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout5-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout5-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout5-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout5-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 2], threadsPerWarp=[4, 8], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout10-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout10-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout10-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout11-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout11-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout11-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout12-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout12-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout12-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = false}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout13-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout13-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout13-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [2, 2], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout14-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout14-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout14-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout15-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout15-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout15-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [1, 4], instrShape=[32, 32], isTransposed = true}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout16-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout16-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout16-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout16-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout17-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout17-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout17-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout17-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [4, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout18-64-64,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=64-N=64-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout18-32-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=128-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout18-32-32,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=32-N=32-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-float16-expand_reduce2d-1-src_layout18-16-16,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_reduce_layouts_M=16-N=16-src_layout=#triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=float16-reduce_op=max-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout0-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout0-256-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout0-256-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout0-128-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout1-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout1-256-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout1-256-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout1-128-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout2-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[1, 4], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout2-256-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[1, 4], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout2-256-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[1, 4], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout2-128-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[1, 4], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout3-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout3-256-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout3-256-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-sum-src_layout3-128-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-max-src_layout0-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-max-src_layout0-256-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-max-src_layout1-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-max-src_layout1-256-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-max-src_layout1-256-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-max-src_layout1-128-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 32], warpsPerCTA=[2, 2], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-max-src_layout3-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-max-src_layout3-256-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-max-src_layout3-256-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_0-max-src_layout3-128-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=0-",gcu300/gcu400/gcu410 +test_chain_reduce_1-sum-src_layout3-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=1-",gcu300/gcu400/gcu410 +test_chain_reduce_1-sum-src_layout3-256-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=1-",gcu300/gcu400/gcu410 +test_chain_reduce_1-sum-src_layout3-256-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=1-",gcu300/gcu400/gcu410 +test_chain_reduce_1-sum-src_layout3-128-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=sum-first_axis=1-",gcu300/gcu400/gcu410 +test_chain_reduce_1-max-src_layout3-128-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=1-",gcu300/gcu400/gcu410 +test_chain_reduce_1-max-src_layout3-256-128,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=128-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=1-",gcu300/gcu400/gcu410 +test_chain_reduce_1-max-src_layout3-256-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=256-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=1-",gcu300/gcu400/gcu410 +test_chain_reduce_1-max-src_layout3-128-256,error: 'tt.reduce' op element number mismatch: 1 vs,,"test_chain_reduce_M=128-N=256-src_layout=#triton_gpu.blocked<{sizePerThread=[1, 4], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-op=max-first_axis=1-",gcu300/gcu400/gcu410 +test_permute_1-float8e4b15-shape0-perm0, dtype fp8 is not supported on gcu300,Not supported on gcu300,"test_permute_dtype_str=float8e4b15-shape=(64, 64)-perm=(1, 0)-num_ctas=1-",gcu300/gcu400/gcu410 +test_permute_1-float8e4b15-shape1-perm1, dtype fp8 is not supported on gcu300,Not supported on gcu300,"test_permute_dtype_str=float8e4b15-shape=(128, 128)-perm=(1, 0)-num_ctas=1-",gcu300/gcu400/gcu410 +test_trans_2d_perm0-shape0-int8,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_trans_2d_dtype_str=int8-perm=(0, 1)-shape=(2, 4)-",gcu300/gcu400/gcu410 +test_trans_2d_perm0-shape1-int8,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_trans_2d_dtype_str=int8-perm=(0, 1)-shape=(16, 16)-",gcu300/gcu400/gcu410 +test_trans_2d_perm1-shape0-int8,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_trans_2d_dtype_str=int8-perm=(1, 0)-shape=(2, 4)-",gcu300/gcu400/gcu410 +test_trans_2d_perm1-shape1-int8,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_trans_2d_dtype_str=int8-perm=(1, 0)-shape=(16, 16)-",gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-none-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-none-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-trans-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-trans-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-add-matrix-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-add-matrix-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-add-rows-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-add-rows-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-add-cols-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-add-cols-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-softmax-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-softmax-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-chain-dot-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-false-false-chain-dot-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-none-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-none-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-trans-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-trans-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-add-matrix-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-add-matrix-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-add-rows-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-add-rows-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-add-cols-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-add-cols-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-softmax-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-softmax-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-chain-dot-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-32-4-false-false-chain-dot-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-none-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-none-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-trans-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-trans-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-add-matrix-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-add-matrix-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-add-rows-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-add-rows-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-add-cols-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-add-cols-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-softmax-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-softmax-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-chain-dot-tf32-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-16-16-16-4-false-false-chain-dot-tf32x3-float32-float32-1,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-true-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300 +test_dot_1-128-256-32-8-true-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300 +test_dot_1-128-256-32-8-false-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-256-32-8-false-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=256-K=32-num_warps=8-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-true-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-16-32-4-false-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=16-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-true-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-4-false-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-true-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-true-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-4-false-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-true-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-128-64-2-false-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-true-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-64-32-4-false-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=64-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-true-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-false-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-true-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-2-false-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=128-N=128-K=64-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-true-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=True-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-true-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-true-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-true-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-true-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-true-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-true-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-true-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-true-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=True-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-false-none-tf32-int8-int8-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-false-none-tf32-int8-int8-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=int8-out_dtype=int8-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-false-none-tf32-float16-float16-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-false-none-tf32-float16-float16-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float16-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-false-none-tf32-float16-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-false-none-tf32-float16-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float16-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-false-none-tf32-float32-float32-10,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-64-128-128-2-false-false-none-tf32-float32-float32-11,dtype tf32 is not supported on gcu,Not supported on gcu,test_dot_M=64-N=128-K=128-num_warps=2-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-false-false-chain-dot-ieee-float8e5-float32-1,attribute 'max_num_imprecise_acc_default' is not supported on gcu300,Not supported on gcu300,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=ieee-in_dtype=float8e5-out_dtype=float32-kpack=1-num_ctas=1-,gcu300 +test_dot_1-128-128-64-4-false-false-chain-dot-ieee-float8e4nv-float32-1,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=ieee-in_dtype=float8e4nv-out_dtype=float32-kpack=1-num_ctas=1-,gcu300 +test_max_num_imprecise_acc,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_max_num_imprecise_acc_,gcu300/gcu400/gcu410 +test_full_shape0-int64,dtype i64 is not supported on gcu300,Not supported on gcu300,test_full_dtype_str=int64-shape=()-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_full_shape0-uint64,dtype i64 is not supported on gcu300,Not supported on gcu300,test_full_dtype_str=uint64-shape=()-dtype=torch.uint64-,gcu300/gcu400/gcu410 +test_full_shape0-float64,dtype fp64 is not supported on gcu,Not supported on gcu,test_full_dtype_str=float64-shape=()-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_full_shape1-int64,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_full_dtype_str=int64-shape=(1,)-dtype=torch.int64-",gcu300/gcu400/gcu410 +test_full_shape1-uint64,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_full_dtype_str=uint64-shape=(1,)-dtype=torch.uint64-",gcu300/gcu400/gcu410 +test_full_shape1-float64,dtype fp64 is not supported on gcu,Not supported on gcu,"test_full_dtype_str=float64-shape=(1,)-dtype=torch.float64-",gcu300/gcu400/gcu410 +test_full_shape2-int64,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_full_dtype_str=int64-shape=(128,)-dtype=torch.int64-",gcu300/gcu400/gcu410 +test_full_shape2-uint64,dtype i64 is not supported on gcu300,Not supported on gcu300,"test_full_dtype_str=uint64-shape=(128,)-dtype=torch.uint64-",gcu300/gcu400/gcu410 +test_full_shape2-float64,dtype fp64 is not supported on gcu,Not supported on gcu,"test_full_dtype_str=float64-shape=(128,)-dtype=torch.float64-",gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-0-0,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=0-other=0-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-0-1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=0-other=1-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-1-0,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=1-other=0-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-1-1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=1-other=1-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-2-0,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=2-other=0-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-2-1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=2-other=1-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-3-0,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=3-other=0-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-3-1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=3-other=1-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-4-0,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=4-other=0-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-4-1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=4-other=1-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-0-0,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=0-other=0-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-0-1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=0-other=1-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-1-0,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=1-other=0-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-1-1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=1-other=1-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-2-0,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=2-other=0-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-2-1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=2-other=1-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-3-0,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=3-other=0-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-3-1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=3-other=1-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-4-0,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=4-other=0-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-4-1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=4-other=1-num_ctas=1-dtype=torch.int64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-0-0,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=128-size_diff=0-other=0-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-0-1,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=128-size_diff=0-other=1-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-1-0,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=128-size_diff=1-other=0-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-1-1,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=128-size_diff=1-other=1-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-2-0,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=128-size_diff=2-other=0-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-2-1,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=128-size_diff=2-other=1-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-3-0,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=128-size_diff=3-other=0-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-3-1,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=128-size_diff=3-other=1-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-4-0,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=128-size_diff=4-other=0-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-4-1,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=128-size_diff=4-other=1-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-0-0,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=512-size_diff=0-other=0-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-0-1,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=512-size_diff=0-other=1-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-1-0,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=512-size_diff=1-other=0-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-1-1,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=512-size_diff=1-other=1-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-2-0,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=512-size_diff=2-other=0-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-2-1,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=512-size_diff=2-other=1-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-3-0,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=512-size_diff=3-other=0-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-3-1,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=512-size_diff=3-other=1-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-4-0,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=512-size_diff=4-other=0-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-4-1,dtype fp64 is not supported on gcu,Not supported on gcu,test_masked_load_dtype_str=float64-size=512-size_diff=4-other=1-num_ctas=1-dtype=torch.float64-,gcu300/gcu400/gcu410 +test_pointer_arguments_cuda,CUDA is not supported on gcu,Not supported on gcu,test_pointer_arguments_,gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair0-float16-64-64,cuda specific tests,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[1, 4], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=64-N=64-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair0-float16-128-128,cuda specific tests,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[1, 4], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=128-N=128-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair0-float16-256-256,cuda specific tests,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[1, 4], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=256-N=256-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair0-float16-1-64,cuda 111specific tests,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[1, 4], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=1-N=64-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair1-float16-64-1,Mmalayout is not supported on gcu,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[2, 8], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[8, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=64-N=1-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair1-float16-1-64,Mmalayout is not supported on gcu,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[2, 8], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[8, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=1-N=64-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair1-float16-64-64,Mmalayout is not supported on gcu,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[2, 8], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[8, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=64-N=64-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair1-float16-128-128,Mmalayout is not supported on gcu,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[2, 8], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[8, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=128-N=128-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair1-float16-256-256,Mmalayout is not supported on gcu,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[2, 8], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[8, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=256-N=256-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair2-float16-64-1,cuda specific tests,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[1, 4], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=64-N=1-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair2-float16-64-64,cuda specific tests,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[1, 4], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=64-N=64-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair2-float16-128-128,cuda specific tests,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[1, 4], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=128-N=128-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair2-float16-256-256,cuda specific tests,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[1, 4], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=256-N=256-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair2-float16-1-64,cuda specific tests,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[1, 4], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=1-N=64-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair3-float16-64-1,Mmalayout is not supported on gcu,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[2, 8], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[8, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=64-N=1-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair3-float16-1-64,Mmalayout is not supported on gcu,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[2, 8], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[8, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=1-N=64-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair3-float16-64-64,Mmalayout is not supported on gcu,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[2, 8], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[8, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=64-N=64-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair3-float16-128-128,Mmalayout is not supported on gcu,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[2, 8], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[8, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=128-N=128-dtype=float16-",gcu300/gcu400/gcu410 +test_convertmma2mma_mma_pair3-float16-256-256,Mmalayout is not supported on gcu,Not supported on gcu,"test_convertmma2mma_mma_pair=['#triton_gpu.nvidia_mma<{warpsPerCTA=[2, 8], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>', '#triton_gpu.nvidia_mma<{warpsPerCTA=[8, 2], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>']-M=256-N=256-dtype=float16-",gcu300/gcu400/gcu410 +test_ptx_cast_float16,CUDA is not supported on gcu,Not supported on gcu,test_ptx_cast_dtype_str=float16-,gcu300/gcu400/gcu410 +test_ptx_cast_int16,CUDA is not supported on gcu,Not supported on gcu,test_ptx_cast_dtype_str=int16-,gcu300/gcu400/gcu410 +test_fp8_dot_acc_0-float8e5,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_fp8_dot_acc_in_type_str=float8e5-low_precision_acc=0-,gcu300/gcu400/gcu410 +test_fp8_dot_acc_0-float8e4b15,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_fp8_dot_acc_in_type_str=float8e4b15-low_precision_acc=0-,gcu300/gcu400/gcu410 +test_fp8_dot_acc_32-float8e5,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_fp8_dot_acc_in_type_str=float8e5-low_precision_acc=32-,gcu300/gcu400/gcu410 +test_fp8_dot_acc_32-float8e4b15,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_fp8_dot_acc_in_type_str=float8e4b15-low_precision_acc=32-,gcu300/gcu400/gcu410 +test_fp8_dot_acc_64-float8e5,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_fp8_dot_acc_in_type_str=float8e5-low_precision_acc=64-,gcu300/gcu400/gcu410 +test_fp8_dot_acc_64-float8e4b15,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_fp8_dot_acc_in_type_str=float8e4b15-low_precision_acc=64-,gcu300/gcu400/gcu410 +test_fp8_dot_acc_128-float8e5,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_fp8_dot_acc_in_type_str=float8e5-low_precision_acc=128-,gcu300/gcu400/gcu410 +test_fp8_dot_acc_128-float8e4b15,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_fp8_dot_acc_in_type_str=float8e4b15-low_precision_acc=128-,gcu300/gcu400/gcu410 +test_maxnreg,CUDA is not supported on gcu,Not supported on gcu,test_maxnreg_,gcu300/gcu400/gcu410 +test_cast_1-float8_e5m2-bfloat16-false-1024,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=float8_e5m2-dtype_z=bfloat16-bitcast=False-size=1024-num_ctas=1-,gcu300/gcu400/gcu410 +test_cast_1-float8_e5m2-bfloat16-false-32,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=float8_e5m2-dtype_z=bfloat16-bitcast=False-size=32-num_ctas=1-,gcu300/gcu400/gcu410 +test_cast_1-bfloat16-float8_e5m2-false-1024,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=bfloat16-dtype_z=float8_e5m2-bitcast=False-size=1024-num_ctas=1-,gcu300/gcu400/gcu410 +test_cast_1-bfloat16-float8_e5m2-false-32,dtype fp8 is not supported on gcu300,Not supported on gcu300,test_cast_dtype_x=bfloat16-dtype_z=float8_e5m2-bitcast=False-size=32-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int16-float16-%,mismatch,,test_bin_op_dtype_x=int16-dtype_y=float16-op=%-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-bfloat16-add,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=int32-dtype_y=bfloat16-op=+-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-bfloat16-add,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=uint32-dtype_y=bfloat16-op=+-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-bfloat16-int32-add,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=bfloat16-dtype_y=int32-op=+-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-bfloat16-uint32-add,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=bfloat16-dtype_y=uint32-op=+-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-bfloat16--,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=int32-dtype_y=bfloat16-op=--num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-bfloat16--,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=uint32-dtype_y=bfloat16-op=--num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-bfloat16-int32--,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=bfloat16-dtype_y=int32-op=--num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-bfloat16-uint32--,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=bfloat16-dtype_y=uint32-op=--num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-bfloat16-mul,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=int32-dtype_y=bfloat16-op=*-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-bfloat16-mul,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=uint32-dtype_y=bfloat16-op=*-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-bfloat16-int32-mul,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=bfloat16-dtype_y=int32-op=*-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-bfloat16-uint32-mul,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=bfloat16-dtype_y=uint32-op=*-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-bfloat16-div,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=int32-dtype_y=bfloat16-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-bfloat16-div,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=uint32-dtype_y=bfloat16-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-bfloat16-int32-div,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=bfloat16-dtype_y=int32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-bfloat16-uint32-div,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=bfloat16-dtype_y=uint32-op=/-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-int32-bfloat16-mod,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=int32-dtype_y=bfloat16-op=%-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-uint32-bfloat16-mod,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=uint32-dtype_y=bfloat16-op=%-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-bfloat16-int32-mod,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=bfloat16-dtype_y=int32-op=%-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_1-bfloat16-uint32-mod,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_dtype_x=bfloat16-dtype_y=uint32-op=%-num_ctas=1-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_true-true-+,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_constexpr_op=+-is_lhs_constexpr=True-is_rhs_constexpr=True-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_true-true--,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_constexpr_op=--is_lhs_constexpr=True-is_rhs_constexpr=True-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_true-true-*,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_constexpr_op=*-is_lhs_constexpr=True-is_rhs_constexpr=True-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_true-true-/,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_constexpr_op=/-is_lhs_constexpr=True-is_rhs_constexpr=True-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_true-true-%,dtype fp64 is not supported on gcu,Not supported on gcu,test_bin_op_constexpr_op=%-is_lhs_constexpr=True-is_rhs_constexpr=True-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_false-false-%,mismatch,,test_bin_op_constexpr_op=%-is_lhs_constexpr=False-is_rhs_constexpr=False-,gcu300/gcu400/gcu410 +test_store_op_src_layout2-32,Mmalayout is not supported on gcu,Not supported on gcu,"test_store_op_M=32-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-",gcu300/gcu400/gcu410 +test_store_op_src_layout2-64,Mmalayout is not supported on gcu,Not supported on gcu,"test_store_op_M=64-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-",gcu300/gcu400/gcu410 +test_store_op_src_layout2-128,Mmalayout is not supported on gcu,Not supported on gcu,"test_store_op_M=128-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-",gcu300/gcu400/gcu410 +test_store_op_src_layout2-256,Mmalayout is not supported on gcu,Not supported on gcu,"test_store_op_M=256-src_layout=#triton_gpu.nvidia_mma<{versionMajor=2, versionMinor=0, warpsPerCTA=[4, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1], instrShape=[16, 8]}>-",gcu300/gcu400/gcu410 +test_value_specialization_2147483648-i64,dtype i64 is not supported on gcu300,Not supported on gcu300,test_value_specialization_value=2147483648-value_type=i64-,gcu300/gcu400/gcu410 +test_value_specialization_4294967295-i64,dtype i64 is not supported on gcu300,Not supported on gcu300,test_value_specialization_value=4294967295-value_type=i64-,gcu300/gcu400/gcu410 +test_value_specialization_4294967296-i64,dtype i64 is not supported on gcu300,Not supported on gcu300,test_value_specialization_value=4294967296-value_type=i64-,gcu300/gcu400/gcu410 +test_value_specialization_9223372036854775807-i64,dtype i64 is not supported on gcu300,Not supported on gcu300,test_value_specialization_value=9223372036854775807-value_type=i64-,gcu300 +test_value_specialization_-9223372036854775808-i64,dtype i64 is not supported on gcu300,Not supported on gcu300,test_value_specialization_value=-9223372036854775808-value_type=i64-,gcu300 +test_value_specialization_9223372036854775808-u64,dtype i64 is not supported on gcu300,Not supported on gcu300,test_value_specialization_value=9223372036854775808-value_type=u64-,gcu300 +test_value_specialization_18446744073709551615-u64,dtype i64 is not supported on gcu300,Not supported on gcu300,test_value_specialization_value=18446744073709551615-value_type=u64-,gcu300 +test_value_specialization_overflow_18446744073709551615-False,dtype i64 is not supported on gcu300,Not supported on gcu300,test_value_specialization_overflow_value=18446744073709551615-overflow=False-,gcu300/gcu400/gcu410 +test_value_specialization_overflow_18446744073709551616-True,dtype i64 is not supported on gcu300,Not supported on gcu300,test_value_specialization_overflow_value=18446744073709551616-overflow=True-,gcu300/gcu400/gcu410 +test_value_specialization_overflow_-9223372036854775808-False,dtype i64 is not supported on gcu300,Not supported on gcu300,test_value_specialization_overflow_value=-9223372036854775808-overflow=False-,gcu300/gcu400/gcu410 +test_value_specialization_overflow_-9223372036854775809-True,dtype i64 is not supported on gcu300,Not supported on gcu300,test_value_specialization_overflow_value=-9223372036854775809-overflow=True-,gcu300/gcu400/gcu410 +test_for_iv_34359738368-34359738388-1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_for_iv_lo=34359738368-hi=34359738388-iv=1-,gcu300/gcu400/gcu410 +test_for_iv_34359738368-34359738388-2,dtype i64 is not supported on gcu300,Not supported on gcu300,test_for_iv_lo=34359738368-hi=34359738388-iv=2-,gcu300/gcu400/gcu410 +test_for_iv_34359738368-34359738388-3,dtype i64 is not supported on gcu300,Not supported on gcu300,test_for_iv_lo=34359738368-hi=34359738388-iv=3-,gcu300/gcu400/gcu410 +test_for_iv_15--16--1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_for_iv_lo=15-hi=-16-iv=-1-,gcu300/gcu400/gcu410 +test_for_iv_15--16--2,dtype i64 is not supported on gcu300,Not supported on gcu300,test_for_iv_lo=15-hi=-16-iv=-2-,gcu300/gcu400/gcu410 +test_for_iv_15--16--3,dtype i64 is not supported on gcu300,Not supported on gcu300,test_for_iv_lo=15-hi=-16-iv=-3-,gcu300/gcu400/gcu410 +test_for_iv_-18--22--1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_for_iv_lo=-18-hi=-22-iv=-1-,gcu300/gcu400/gcu410 +test_for_iv_22-18--1,dtype i64 is not supported on gcu300,Not supported on gcu300,test_for_iv_lo=22-hi=18-iv=-1-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_True-True-<<,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_constexpr_op=<<-is_lhs_constexpr=True-is_rhs_constexpr=True-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_True-True->>,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_constexpr_op=>>-is_lhs_constexpr=True-is_rhs_constexpr=True-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_True-True-&,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_constexpr_op=&-is_lhs_constexpr=True-is_rhs_constexpr=True-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_True-True-^,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_constexpr_op=^-is_lhs_constexpr=True-is_rhs_constexpr=True-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_True-True-|,dtype i64 is not supported on gcu300,Not supported on gcu300,test_bin_op_constexpr_op=|-is_lhs_constexpr=True-is_rhs_constexpr=True-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_True-False-%,error: relocation R_DTU_ADDR16_LO_ICALL cannot be used against symbol tops_remf_fp32,,test_bin_op_constexpr_op=%-is_lhs_constexpr=False-is_rhs_constexpr=True-,gcu300/gcu400/gcu410 +test_bin_op_constexpr_False-True-%,error: relocation R_DTU_ADDR16_LO_ICALL cannot be used against symbol tops_remf_fp32,,test_bin_op_constexpr_op=%-is_lhs_constexpr=True-is_rhs_constexpr=False-,gcu300/gcu400/gcu410 +test_dot3d_1-16-64-64-64-32-32-int8-int8,block dims exceed the threshold,,test_dot3d_B=1-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=int8-out_dtype_str=int8-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_1-16-64-64-64-32-32-float16-float16,block dims exceed the threshold,,test_dot3d_B=1-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float16-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_1-16-64-64-64-32-32-float16-float32,block dims exceed the threshold,,test_dot3d_B=1-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_1-16-64-64-64-32-32-float32-float32,block dims exceed the threshold,,test_dot3d_B=1-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float32-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_1-16-32-32-32-32-32-int8-int8,block dims exceed the threshold,,test_dot3d_B=1-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=int8-out_dtype_str=int8-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_1-16-32-32-32-32-32-float16-float16,block dims exceed the threshold,,test_dot3d_B=1-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float16-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_1-16-32-32-32-32-32-float16-float32,block dims exceed the threshold,,test_dot3d_B=1-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_1-16-32-32-32-32-32-float32-float32,block dims exceed the threshold,,test_dot3d_B=1-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float32-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_2-16-64-64-64-32-32-int8-int8,block dims exceed the threshold,,test_dot3d_B=2-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=int8-out_dtype_str=int8-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_2-16-64-64-64-32-32-float16-float16,block dims exceed the threshold,,test_dot3d_B=2-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float16-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_2-16-64-64-64-32-32-float16-float32,block dims exceed the threshold,,test_dot3d_B=2-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_2-16-64-64-64-32-32-float32-float32,block dims exceed the threshold,,test_dot3d_B=2-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float32-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_2-16-32-32-32-32-32-int8-int8,block dims exceed the threshold,,test_dot3d_B=2-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=int8-out_dtype_str=int8-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_2-16-32-32-32-32-32-float16-float16,block dims exceed the threshold,,test_dot3d_B=2-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float16-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_2-16-32-32-32-32-32-float16-float32,block dims exceed the threshold,,test_dot3d_B=2-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_2-16-32-32-32-32-32-float32-float32,block dims exceed the threshold,,test_dot3d_B=2-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float32-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_4-16-64-64-64-32-32-int8-int8,block dims exceed the threshold,,test_dot3d_B=4-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=int8-out_dtype_str=int8-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_4-16-64-64-64-32-32-float16-float16,block dims exceed the threshold,,test_dot3d_B=4-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float16-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_4-16-64-64-64-32-32-float16-float32,block dims exceed the threshold,,test_dot3d_B=4-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_4-16-64-64-64-32-32-float32-float32,block dims exceed the threshold,,test_dot3d_B=4-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float32-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_4-16-32-32-32-32-32-int8-int8,block dims exceed the threshold,,test_dot3d_B=4-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=int8-out_dtype_str=int8-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_4-16-32-32-32-32-32-float16-float16,block dims exceed the threshold,,test_dot3d_B=4-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float16-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_4-16-32-32-32-32-32-float16-float32,block dims exceed the threshold,,test_dot3d_B=4-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_4-16-32-32-32-32-32-float32-float32,block dims exceed the threshold,,test_dot3d_B=4-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float32-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_8-16-64-64-64-32-32-int8-int8,block dims exceed the threshold,,test_dot3d_B=8-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=int8-out_dtype_str=int8-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_8-16-64-64-64-32-32-float16-float16,block dims exceed the threshold,,test_dot3d_B=8-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float16-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_8-16-64-64-64-32-32-float16-float32,block dims exceed the threshold,,test_dot3d_B=8-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_8-16-64-64-64-32-32-float32-float32,block dims exceed the threshold,,test_dot3d_B=8-num_warps=16-M=64-N=64-K=64-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float32-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_8-16-32-32-32-32-32-int8-int8,block dims exceed the threshold,,test_dot3d_B=8-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=int8-out_dtype_str=int8-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_8-16-32-32-32-32-32-float16-float16,block dims exceed the threshold,,test_dot3d_B=8-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float16-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_8-16-32-32-32-32-32-float16-float32,block dims exceed the threshold,,test_dot3d_B=8-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float16-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_dot3d_8-16-32-32-32-32-32-float32-float32,block dims exceed the threshold,,test_dot3d_B=8-num_warps=16-M=32-N=32-K=32-BLOCK_M=32-BLOCK_N=32-in_dtype_str=float32-out_dtype_str=float32-input_precision=ieee-arch=dtu-enflame-tops--gcu300-,gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout0-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout1-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-0-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce1d-1-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout0-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout1-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-0-src_layout2-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout0-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout1-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-reduce2d-1-src_layout2-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout0-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout1-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-0-src_layout2-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout0-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout1-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_sum-int32-True-expand_reduce2d-1-src_layout2-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=sum-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout0-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout1-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-0-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce1d-1-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce1d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout0-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout1-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-0-src_layout2-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout0-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout1-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-reduce2d-1-src_layout2-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout0-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout1-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-0-src_layout2-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=0-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout0-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout0-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout0-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout0-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout0-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout1-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout1-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout1-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout1-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout1-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[1, 1], warpsPerCTA=[4, 1], order=[0, 1], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout2-128-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout2-128-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=128-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout2-32-128,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=128-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout2-32-32,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=32-N=32-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_reduce_layouts_max-int32-True-expand_reduce2d-1-src_layout2-16-16,error: 'gcu.assert' op operand #0 must be 1-bit signless integer,,"test_reduce_layouts_M=16-N=16-src_layout=#ttg.linear<{register=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], warp=[[32, 0], [0, 32]], block=[]}>-axis=1-epilogue_kind=expand_reduce2d-dtype_str=int32-add_overflow_check=True-reduce_op=max-",gcu300/gcu400/gcu410 +test_full_shape0-int64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_full_dtype_str=int64-shape=()-,gcu300/gcu400/gcu410 +test_full_shape0-uint64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_full_dtype_str=uint64-shape=()-,gcu300/gcu400/gcu410 +test_full_shape0-float64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_full_dtype_str=float64-shape=()-,gcu300/gcu400/gcu410 +test_full_shape1-int64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_full_dtype_str=int64-shape=(1,)-",gcu300/gcu400/gcu410 +test_full_shape1-uint64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_full_dtype_str=uint64-shape=(1,)-",gcu300/gcu400/gcu410 +test_full_shape1-float64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_full_dtype_str=float64-shape=(1,)-",gcu300/gcu400/gcu410 +test_full_shape2-int64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_full_dtype_str=int64-shape=(128,)-",gcu300/gcu400/gcu410 +test_full_shape2-uint64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_full_dtype_str=uint64-shape=(128,)-",gcu300/gcu400/gcu410 +test_full_shape2-float64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_full_dtype_str=float64-shape=(128,)-",gcu300/gcu400/gcu410 +test_dot_1-64-64-64-4-False-False-none-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-none-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-trans-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-trans-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-add-matrix-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-add-matrix-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-add-rows-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-add-rows-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-add-cols-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-add-cols-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-softmax-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-softmax-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-chain-dot-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-64-64-64-4-False-False-chain-dot-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=64-N=64-K=64-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-none-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-none-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-trans-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-trans-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-add-matrix-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-add-matrix-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-add-rows-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-add-rows-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-add-cols-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-add-cols-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-softmax-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-softmax-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-chain-dot-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-32-4-False-False-chain-dot-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=32-N=32-K=32-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-none-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-none-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=none-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-trans-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-trans-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=trans-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-add-matrix-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-add-matrix-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-matrix-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-add-rows-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-add-rows-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-rows-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-add-cols-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-add-cols-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=add-cols-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-softmax-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-softmax-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=softmax-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-chain-dot-tf32-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-16-16-16-4-False-False-chain-dot-tf32x3-float32-float32-1-None,AssertionError: input_precision must be one of 'ieee' but got tf32,,test_dot_M=16-N=16-K=16-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=tf32x3-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-32-32-128-16-True-True-none-ieee-int8-int8-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=ieee-in_dtype=int8-out_dtype=int8-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-True-none-ieee-int8-int8-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=ieee-in_dtype=int8-out_dtype=int8-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-True-none-ieee-float16-float16-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float16-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-True-none-ieee-float16-float16-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float16-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-True-none-ieee-float16-float32-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-True-none-ieee-float16-float32-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-True-none-ieee-float32-float32-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-True-none-ieee-float32-float32-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-False-none-ieee-int8-int8-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=ieee-in_dtype=int8-out_dtype=int8-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-False-none-ieee-int8-int8-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=ieee-in_dtype=int8-out_dtype=int8-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-False-none-ieee-float16-float16-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float16-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-False-none-ieee-float16-float16-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float16-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-False-none-ieee-float16-float32-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-False-none-ieee-float16-float32-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-False-none-ieee-float32-float32-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-True-False-none-ieee-float32-float32-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=True-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-True-none-ieee-int8-int8-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=ieee-in_dtype=int8-out_dtype=int8-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-True-none-ieee-int8-int8-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=ieee-in_dtype=int8-out_dtype=int8-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-True-none-ieee-float16-float16-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float16-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-True-none-ieee-float16-float16-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float16-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-True-none-ieee-float16-float32-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-True-none-ieee-float16-float32-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-True-none-ieee-float32-float32-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-True-none-ieee-float32-float32-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=True-epilogue=none-input_precision=ieee-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-False-none-ieee-int8-int8-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=ieee-in_dtype=int8-out_dtype=int8-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-False-none-ieee-int8-int8-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=ieee-in_dtype=int8-out_dtype=int8-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-False-none-ieee-float16-float16-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float16-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-False-none-ieee-float16-float16-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float16-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-False-none-ieee-float16-float32-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-False-none-ieee-float16-float32-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float16-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-False-none-ieee-float32-float32-1-None0,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-32-32-128-16-False-False-none-ieee-float32-float32-1-None1,block dims exceed the threshold,,test_dot_M=32-N=32-K=128-num_warps=16-col_a=False-col_b=False-epilogue=none-input_precision=ieee-in_dtype=float32-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300/gcu400/gcu410 +test_dot_1-128-128-64-4-False-False-chain-dot-ieee-float8e5-float32-1-None,ValueError: type fp8e5 not supported in this architecture,,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=ieee-in_dtype=float8e5-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_dot_1-128-128-64-4-False-False-chain-dot-ieee-float8e4nv-float32-1-None,ValueError: type fp8e4nv not supported in this architecture,,test_dot_M=128-N=128-K=64-num_warps=4-col_a=False-col_b=False-epilogue=chain-dot-input_precision=ieee-in_dtype=float8e4nv-out_dtype=float32-kpack=1-mma_nonk_size=None-num_ctas=1-,gcu300 +test_scan2d_cummax-int32-shape28-1-True-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=True-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape70-1-False-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=int32-shape=(1024, 2)-axis=1-reverse=False-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape106-0-True-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=True-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape124-0-True-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=True-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape148-0-False-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=int32-shape=(2, 1024)-axis=0-reverse=False-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape166-0-False-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=int32-shape=(1, 1024)-axis=0-reverse=False-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape196-1-True-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=True-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape238-1-False-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=float32-shape=(1024, 2)-axis=1-reverse=False-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape274-0-True-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=True-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape292-0-True-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=True-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape316-0-False-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=float32-shape=(2, 1024)-axis=0-reverse=False-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-float32-shape334-0-False-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=float32-shape=(1, 1024)-axis=0-reverse=False-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape364-1-True-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=True-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape406-1-False-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=False-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape442-0-True-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=True-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape460-0-True-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=True-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape484-0-False-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=False-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape502-0-False-4,error: error: 'gcu.vector_step' op result #0 must be of ranks 1 of length 512/1024/2048vector of 1-bit signless integer or …… but got 'vector<64xi32>',,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=False-num_warps=4-",gcu300/gcu400/gcu410 +test_scan2d_cummax-int32-shape4-1-True-4,error: error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(8, 32)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape10-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(16, 32)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape16-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 16)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape22-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape34-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 32)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape40-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape46-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(8, 32)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape52-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(16, 32)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape58-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 16)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape64-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(2, 1024)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape76-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 32)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape82-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(1, 1024)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape88-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(8, 32)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape94-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(16, 32)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape100-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 16)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape112-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape118-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 32)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape130-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(8, 32)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape136-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(16, 32)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape142-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 16)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape154-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(1024, 2)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-int32-shape160-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=int32-shape=(32, 32)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape172-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(8, 32)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape178-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(16, 32)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape184-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 16)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape190-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape202-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 32)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape208-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape214-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(8, 32)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape220-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(16, 32)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape226-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 16)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape232-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(2, 1024)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape244-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 32)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape250-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(1, 1024)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape256-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(8, 32)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape262-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(16, 32)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape268-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 16)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape280-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape286-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 32)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape298-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(8, 32)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape304-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(16, 32)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape310-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 16)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape322-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(1024, 2)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-float32-shape328-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=float32-shape=(32, 32)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape340-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape346-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape352-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape358-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape370-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape376-1-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape382-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape388-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape394-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape400-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape412-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape418-1-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape424-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape430-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape436-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape448-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape454-0-True-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=True-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape466-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape472-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape478-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape490-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape496-0-False-4,error: failed to legalize operation 'arith.extsi' that was explicitly marked illegal,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=False-num_warps=4-",gcu300 +test_scan2d_cummax-bfloat16-shape844-1-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape850-1-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape856-1-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape862-1-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape868-1-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape874-1-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape880-1-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape886-1-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(8, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape892-1-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(16, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape898-1-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 16)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape904-1-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(2, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape910-1-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1024, 2)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape916-1-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 32)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape922-1-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1, 1024)-axis=1-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape928-0-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape934-0-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape940-0-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape946-0-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape952-0-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape958-0-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape964-0-True-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=True-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape970-0-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(8, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape976-0-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(16, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape982-0-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 16)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape988-0-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(2, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape994-0-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1024, 2)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape1000-0-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(32, 32)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_scan2d_cummax-bfloat16-shape1006-0-False-16,error: block dims exceed the threshold,,"test_scan2d_op=cummax-dtype_str=bfloat16-shape=(1, 1024)-axis=0-reverse=False-num_warps=16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_2-1-float16,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=2-num_ctas=1-dtype_x_str=float16-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_2-1-float32,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=2-num_ctas=1-dtype_x_str=float32-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_4-1-float16,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=4-num_ctas=1-dtype_x_str=float16-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_4-1-float32,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=4-num_ctas=1-dtype_x_str=float32-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_8-1-float16,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=8-num_ctas=1-dtype_x_str=float16-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_8-1-float32,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=8-num_ctas=1-dtype_x_str=float32-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_32-1-float16,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=32-num_ctas=1-dtype_x_str=float16-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_32-1-float32,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=32-num_ctas=1-dtype_x_str=float32-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_64-1-float16,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=64-num_ctas=1-dtype_x_str=float16-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_64-1-float32,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=64-num_ctas=1-dtype_x_str=float32-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_128-1-float16,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=128-num_ctas=1-dtype_x_str=float16-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_non_exclusive_offset_128-1-float32,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_tensor_atomic_add_non_exclusive_offset_size=128-num_ctas=1-dtype_x_str=float32-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-0-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=0-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-0-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=0-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-1-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=1-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-1-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=1-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-2-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=2-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-2-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=2-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-3-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=3-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-3-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=3-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-4-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=4-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-128-4-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=128-size_diff=4-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-0-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=0-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-0-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=0-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-1-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=1-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-1-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=1-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-2-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=2-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-2-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=2-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-3-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=3-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-3-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=3-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-4-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=4-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-int64-512-4-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=int64-size=512-size_diff=4-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-0-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=128-size_diff=0-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-0-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=128-size_diff=0-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-1-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=128-size_diff=1-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-1-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=128-size_diff=1-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-2-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=128-size_diff=2-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-2-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=128-size_diff=2-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-3-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=128-size_diff=3-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-3-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=128-size_diff=3-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-4-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=128-size_diff=4-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-128-4-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=128-size_diff=4-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-0-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=512-size_diff=0-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-0-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=512-size_diff=0-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-1-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=512-size_diff=1-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-1-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=512-size_diff=1-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-2-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=512-size_diff=2-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-2-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=512-size_diff=2-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-3-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=512-size_diff=3-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-3-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=512-size_diff=3-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-4-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=512-size_diff=4-other=0-num_ctas=1-,gcu300/gcu400/gcu410 +test_masked_load_1-float64-512-4-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_masked_load_dtype_str=float64-size=512-size_diff=4-other=1-num_ctas=1-,gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape0-increase-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=increase-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape1-increase-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=increase-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape2-increase-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=increase-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape3-increase-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=increase-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape4-increase-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=increase-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape5-increase-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=increase-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape6-increase-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=increase-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape7-increase-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=increase-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape8-decrease-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=decrease-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape9-decrease-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=decrease-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape10-decrease-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=decrease-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape11-decrease-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=decrease-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape12-decrease-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=decrease-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape13-decrease-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=decrease-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape14-decrease-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=decrease-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape15-decrease-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=decrease-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape16-random_no_duplication-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random_no_duplication-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape17-random_no_duplication-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random_no_duplication-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape18-random_no_duplication-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random_no_duplication-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape19-random_no_duplication-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random_no_duplication-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape20-random_no_duplication-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random_no_duplication-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape21-random_no_duplication-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random_no_duplication-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape22-random_no_duplication-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random_no_duplication-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape23-random_no_duplication-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random_no_duplication-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape24-random-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape25-random-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape26-random-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape27-random-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape28-random-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape29-random-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape30-random-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape31-random-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(2, 2)-idx_order=random-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape32-increase-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=increase-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape33-increase-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=increase-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape34-increase-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=increase-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape35-increase-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=increase-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape36-increase-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=increase-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape37-increase-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=increase-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape38-increase-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=increase-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape39-increase-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=increase-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape40-decrease-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=decrease-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape41-decrease-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=decrease-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape42-decrease-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=decrease-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape43-decrease-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=decrease-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape44-decrease-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=decrease-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape45-decrease-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=decrease-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape46-decrease-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=decrease-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape47-decrease-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=decrease-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape48-random_no_duplication-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random_no_duplication-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape49-random_no_duplication-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random_no_duplication-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape50-random_no_duplication-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random_no_duplication-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape51-random_no_duplication-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random_no_duplication-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape52-random_no_duplication-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random_no_duplication-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape53-random_no_duplication-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random_no_duplication-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape54-random_no_duplication-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random_no_duplication-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape55-random_no_duplication-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random_no_duplication-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape56-random-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape57-random-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape58-random-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape59-random-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape60-random-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape61-random-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape62-random-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape63-random-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(4, 4)-idx_order=random-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape64-increase-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=increase-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape65-increase-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=increase-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape66-increase-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=increase-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape67-increase-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=increase-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape68-increase-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=increase-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape69-increase-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=increase-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape70-increase-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=increase-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape71-increase-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=increase-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape72-decrease-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=decrease-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape73-decrease-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=decrease-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape74-decrease-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=decrease-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape75-decrease-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=decrease-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape76-decrease-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=decrease-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape77-decrease-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=decrease-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape78-decrease-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=decrease-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape79-decrease-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=decrease-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape80-random_no_duplication-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random_no_duplication-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape81-random_no_duplication-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random_no_duplication-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape82-random_no_duplication-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random_no_duplication-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape83-random_no_duplication-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random_no_duplication-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape84-random_no_duplication-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random_no_duplication-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape85-random_no_duplication-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random_no_duplication-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape86-random_no_duplication-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random_no_duplication-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape87-random_no_duplication-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random_no_duplication-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape88-random-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape89-random-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape90-random-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape91-random-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape92-random-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape93-random-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape94-random-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape95-random-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(5, 5)-idx_order=random-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape96-increase-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=increase-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape97-increase-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=increase-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape98-increase-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=increase-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape99-increase-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=increase-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape100-increase-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=increase-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape101-increase-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=increase-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape102-increase-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=increase-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape103-increase-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=increase-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape104-decrease-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=decrease-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape105-decrease-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=decrease-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape106-decrease-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=decrease-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape107-decrease-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=decrease-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape108-decrease-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=decrease-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape109-decrease-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=decrease-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape110-decrease-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=decrease-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape111-decrease-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=decrease-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape112-random_no_duplication-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random_no_duplication-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape113-random_no_duplication-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random_no_duplication-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape114-random_no_duplication-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random_no_duplication-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape115-random_no_duplication-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random_no_duplication-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape116-random_no_duplication-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random_no_duplication-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape117-random_no_duplication-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random_no_duplication-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape118-random_no_duplication-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random_no_duplication-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape119-random_no_duplication-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random_no_duplication-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape120-random-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape121-random-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape122-random-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape123-random-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape124-random-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape125-random-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape126-random-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape127-random-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(6, 6)-idx_order=random-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape128-increase-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=increase-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape129-increase-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=increase-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape130-increase-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=increase-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape131-increase-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=increase-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape132-increase-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=increase-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape133-increase-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=increase-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape134-increase-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=increase-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape135-increase-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=increase-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape136-decrease-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=decrease-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape137-decrease-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=decrease-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape138-decrease-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=decrease-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape139-decrease-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=decrease-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape140-decrease-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=decrease-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape141-decrease-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=decrease-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape142-decrease-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=decrease-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape143-decrease-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=decrease-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape144-random_no_duplication-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random_no_duplication-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape145-random_no_duplication-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random_no_duplication-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape146-random_no_duplication-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random_no_duplication-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape147-random_no_duplication-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random_no_duplication-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape148-random_no_duplication-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random_no_duplication-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape149-random_no_duplication-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random_no_duplication-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape150-random_no_duplication-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random_no_duplication-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape151-random_no_duplication-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random_no_duplication-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape152-random-1-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random-mask_step=1-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape153-random-1-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random-mask_step=1-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape154-random-2-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random-mask_step=2-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape155-random-2-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random-mask_step=2-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape156-random-3-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random-mask_step=3-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape157-random-3-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random-mask_step=3-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape158-random-4-1-float16,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random-mask_step=4-num_ctas=1-dtype_x_str=float16-",gcu300/gcu400/gcu410 +test_tensor_atomic_add_access_patterns_shape159-random-4-1-float32,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_tensor_atomic_add_access_patterns_shape=(8, 8)-idx_order=random-mask_step=4-num_ctas=1-dtype_x_str=float32-",gcu300/gcu400/gcu410 +test_constexpr_if_return,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_constexpr_if_return_,gcu300/gcu400/gcu410 +test_slice,distutils.errors.DistutilsExecError: command '/usr/bin/x86_64-linux-gnu-gcc' failed with exit code 1,,test_slice_,gcu300/gcu400/gcu410 +test_tl_range_fuse,AssertionError,,test_tl_range_fuse_,gcu300/gcu400/gcu410 +test_math_extern_float64,ValueError: input arg type does not match,,test_math_extern_dtype_str=float64-,gcu300/gcu400/gcu410 +test_unroll_attr,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,test_unroll_attr_,gcu300/gcu400/gcu410 +test_chained_reductions_in_shape0-perm0-red_dims0,error: 64-bit datatype not supported on GCU300!,Not supported on gcu300,"test_chained_reductions_in_shape=(4, 32, 32, 4, 2)-perm=[2, 1, 0, 3, 4]-red_dims=[3, 1, 0]-",gcu300/gcu400/gcu410 +test_chained_reductions_in_shape1-perm1-red_dims1,error: 64-bit datatype not supported on GCU300!,Not supported on gcu300,"test_chained_reductions_in_shape=(8, 2, 32, 4, 16)-perm=[4, 0, 1, 3, 2]-red_dims=[0, 2, 0]-",gcu300/gcu400/gcu410 +test_local_load_store_dist_layout0-shared_layout0-8-16-32,RuntimeError: Parse MLIR file failed,,"test_local_load_store_M=8-N=16-K=32-dist_layout=#ttg.blocked<{sizePerThread=[4, 4, 1], threadsPerWarp=[1, 8, 0], warpsPerCTA=[2, 2, 1], order=[2, 1, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-shared_layout=#ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=1, order=[2, 1, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-",gcu300/gcu400/gcu410 +test_local_load_store_dist_layout0-shared_layout1-8-16-32,RuntimeError: Parse MLIR file failed,,"test_local_load_store_M=8-N=16-K=32-dist_layout=#ttg.blocked<{sizePerThread=[4, 4, 1], threadsPerWarp=[1, 8, 0], warpsPerCTA=[2, 2, 1], order=[2, 1, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-shared_layout=#ttg.swizzled_shared<{vec=4, perPhase=2, maxPhase=4, order=[1, 2, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-",gcu300/gcu400/gcu410 +test_local_load_store_dist_layout0-shared_layout2-8-16-32,RuntimeError: Parse MLIR file failed,,"test_local_load_store_M=8-N=16-K=32-dist_layout=#ttg.blocked<{sizePerThread=[4, 4, 1], threadsPerWarp=[1, 8, 0], warpsPerCTA=[2, 2, 1], order=[2, 1, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-shared_layout=#ttg.swizzled_shared<{vec=8, perPhase=2, maxPhase=4, order=[0, 2, 1], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-",gcu300/gcu400/gcu410 +test_local_load_store_dist_layout0-shared_layout3-8-16-32,RuntimeError: Parse MLIR file failed,,"test_local_load_store_M=8-N=16-K=32-dist_layout=#ttg.blocked<{sizePerThread=[4, 4, 1], threadsPerWarp=[1, 8, 0], warpsPerCTA=[2, 2, 1], order=[2, 1, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-shared_layout=#ttg.swizzled_shared<{vec=4, perPhase=2, maxPhase=1, order=[2, 0, 1], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-",gcu300/gcu400/gcu410 +test_local_load_store_dist_layout1-shared_layout0-8-16-32,RuntimeError: Parse MLIR file failed,,"test_local_load_store_M=8-N=16-K=32-dist_layout=#ttg.blocked<{sizePerThread=[1, 1, 4], threadsPerWarp=[8, 0, 1], warpsPerCTA=[2, 1, 2], order=[1, 2, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-shared_layout=#ttg.swizzled_shared<{vec=1, perPhase=1, maxPhase=1, order=[2, 1, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-",gcu300/gcu400/gcu410 +test_local_load_store_dist_layout1-shared_layout1-8-16-32,RuntimeError: Parse MLIR file failed,,"test_local_load_store_M=8-N=16-K=32-dist_layout=#ttg.blocked<{sizePerThread=[1, 1, 4], threadsPerWarp=[8, 0, 1], warpsPerCTA=[2, 1, 2], order=[1, 2, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-shared_layout=#ttg.swizzled_shared<{vec=4, perPhase=2, maxPhase=4, order=[1, 2, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-",gcu300/gcu400/gcu410 +test_local_load_store_dist_layout1-shared_layout2-8-16-32,RuntimeError: Parse MLIR file failed,,"test_local_load_store_M=8-N=16-K=32-dist_layout=#ttg.blocked<{sizePerThread=[1, 1, 4], threadsPerWarp=[8, 0, 1], warpsPerCTA=[2, 1, 2], order=[1, 2, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-shared_layout=#ttg.swizzled_shared<{vec=8, perPhase=2, maxPhase=4, order=[0, 2, 1], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-",gcu300/gcu400/gcu410 +test_local_load_store_dist_layout1-shared_layout3-8-16-32,RuntimeError: Parse MLIR file failed,,"test_local_load_store_M=8-N=16-K=32-dist_layout=#ttg.blocked<{sizePerThread=[1, 1, 4], threadsPerWarp=[8, 0, 1], warpsPerCTA=[2, 1, 2], order=[1, 2, 0], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-shared_layout=#ttg.swizzled_shared<{vec=4, perPhase=2, maxPhase=1, order=[2, 0, 1], CTAsPerCGA=[1, 1, 1], CTASplitNum=[1, 1, 1], CTAOrder=[0, 1, 2]}>-",gcu300 +test_gather_warp_shuffle_src_shape0-indices_shape0-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_gather_warp_shuffle_src_shape=[32, 16]-indices_shape=[32, 16]-axis=0-",gcu300/gcu400/gcu410 +test_gather_warp_shuffle_src_shape1-indices_shape1-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_gather_warp_shuffle_src_shape=[128, 64]-indices_shape=[256, 64]-axis=0-",gcu300/gcu400/gcu410 +test_dot_max_num_imprecise_acc_0-float8e5-128-256-128-128-256-256,ValueError: type fp8e5 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=128-BLOCK_N=256-BLOCK_K=128-in_type_str=float8e5-low_precision_acc=0-,gcu300 +test_dot_max_num_imprecise_acc_0-float8e5-64-64-64-128-256-256,ValueError: type fp8e5 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=64-BLOCK_N=64-BLOCK_K=64-in_type_str=float8e5-low_precision_acc=0-,gcu300 +test_dot_max_num_imprecise_acc_0-float8e4b15-128-256-128-128-256-256,ValueError: type fp8e4b15 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=128-BLOCK_N=256-BLOCK_K=128-in_type_str=float8e4b15-low_precision_acc=0-,gcu300/gcu400/gcu410 +test_dot_max_num_imprecise_acc_0-float8e4b15-64-64-64-128-256-256,ValueError: type fp8e4b15 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=64-BLOCK_N=64-BLOCK_K=64-in_type_str=float8e4b15-low_precision_acc=0-,gcu300/gcu400/gcu410 +test_dot_max_num_imprecise_acc_32-float8e5-128-256-128-128-256-256,ValueError: type fp8e5 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=128-BLOCK_N=256-BLOCK_K=128-in_type_str=float8e5-low_precision_acc=32-,gcu300 +test_dot_max_num_imprecise_acc_32-float8e5-64-64-64-128-256-256,ValueError: type fp8e5 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=64-BLOCK_N=64-BLOCK_K=64-in_type_str=float8e5-low_precision_acc=32-,gcu300 +test_dot_max_num_imprecise_acc_32-float8e4b15-128-256-128-128-256-256,ValueError: type fp8e4b15 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=128-BLOCK_N=256-BLOCK_K=128-in_type_str=float8e4b15-low_precision_acc=32-,gcu300/gcu400/gcu410 +test_dot_max_num_imprecise_acc_32-float8e4b15-64-64-64-128-256-256,ValueError: type fp8e4b15 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=64-BLOCK_N=64-BLOCK_K=64-in_type_str=float8e4b15-low_precision_acc=32-,gcu300/gcu400/gcu410 +test_dot_max_num_imprecise_acc_64-float8e5-128-256-128-128-256-256,ValueError: type fp8e5 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=128-BLOCK_N=256-BLOCK_K=128-in_type_str=float8e5-low_precision_acc=64-,gcu300 +test_dot_max_num_imprecise_acc_64-float8e5-64-64-64-128-256-256,ValueError: type fp8e5 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=64-BLOCK_N=64-BLOCK_K=64-in_type_str=float8e5-low_precision_acc=64-,gcu300 +test_dot_max_num_imprecise_acc_64-float8e4b15-128-256-128-128-256-256,ValueError: type fp8e4b15 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=128-BLOCK_N=256-BLOCK_K=128-in_type_str=float8e4b15-low_precision_acc=64-,gcu300/gcu400/gcu410 +test_dot_max_num_imprecise_acc_64-float8e4b15-64-64-64-128-256-256,ValueError: type fp8e4b15 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=64-BLOCK_N=64-BLOCK_K=64-in_type_str=float8e4b15-low_precision_acc=64-,gcu300/gcu400/gcu410 +test_dot_max_num_imprecise_acc_128-float8e5-128-256-128-128-256-256,ValueError: type fp8e5 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=128-BLOCK_N=256-BLOCK_K=128-in_type_str=float8e5-low_precision_acc=128-,gcu300 +test_dot_max_num_imprecise_acc_128-float8e5-64-64-64-128-256-256,ValueError: type fp8e5 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=64-BLOCK_N=64-BLOCK_K=64-in_type_str=float8e5-low_precision_acc=128-,gcu300 +test_dot_max_num_imprecise_acc_128-float8e4b15-128-256-128-128-256-256,ValueError: type fp8e4b15 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=128-BLOCK_N=256-BLOCK_K=128-in_type_str=float8e4b15-low_precision_acc=128-,gcu300/gcu400/gcu410 +test_dot_max_num_imprecise_acc_128-float8e4b15-64-64-64-128-256-256,ValueError: type fp8e4b15 not supported in this architecture,,test_dot_max_num_imprecise_acc_M=128-N=256-K=256-BLOCK_M=64-BLOCK_N=64-BLOCK_K=64-in_type_str=float8e4b15-low_precision_acc=128-,gcu300/gcu400/gcu410 +test_gather_src_shape0-indices_shape0-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_gather_src_shape=[4, 4]-indices_shape=[8, 4]-axis=0-",gcu300/gcu400/gcu410 +test_gather_src_shape1-indices_shape1-0,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_gather_src_shape=[128, 64]-indices_shape=[256, 64]-axis=0-",gcu300/gcu400/gcu410 +test_gather_src_shape2-indices_shape2-1,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,"test_gather_src_shape=[128, 64]-indices_shape=[128, 128]-axis=1-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape0-0-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(2, 2)-axis=0-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape1-0-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(2, 2)-axis=0-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape5-1-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(2, 2)-axis=1-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape6-1-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(2, 2)-axis=1-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape10-0-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(2, 8)-axis=0-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape11-0-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(2, 8)-axis=0-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape15-1-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(2, 8)-axis=1-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape16-1-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(2, 8)-axis=1-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape20-0-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(8, 2)-axis=0-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape21-0-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(8, 2)-axis=0-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape25-1-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(8, 2)-axis=1-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape26-1-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(8, 2)-axis=1-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape30-0-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(8, 8)-axis=0-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape31-0-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(8, 8)-axis=0-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape35-1-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(8, 8)-axis=1-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape36-1-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(8, 8)-axis=1-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape40-0-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(32, 32)-axis=0-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape41-0-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(32, 32)-axis=0-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape45-1-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(32, 32)-axis=1-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape46-1-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(32, 32)-axis=1-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape50-0-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(64, 64)-axis=0-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape51-0-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(64, 64)-axis=0-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape55-1-1-float16-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(64, 64)-axis=1-num_ctas=1-dtype_x_str=float16-check_return_val=True-",gcu300/gcu400/gcu410 +test_tensor_atomic_rmw_shape56-1-1-float32-True,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_tensor_atomic_rmw_shape=(64, 64)-axis=1-num_ctas=1-dtype_x_str=float32-check_return_val=True-",gcu300/gcu400/gcu410 +test_indirect_load_int64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_indirect_load_dtype=int64-,gcu300/gcu400/gcu410 +test_indirect_load_uint64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_indirect_load_dtype=uint64-,gcu300/gcu400/gcu410 +test_indirect_load_float64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_indirect_load_dtype=float64-,gcu300/gcu400/gcu410 +test_strided_load_int64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_strided_load_dtype=int64-,gcu300/gcu400/gcu410 +test_strided_load_uint64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_strided_load_dtype=uint64-,gcu300/gcu400/gcu410 +test_strided_load_float64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_strided_load_dtype=float64-,gcu300/gcu400/gcu410 +test_indirect_store_int64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_indirect_store_dtype=int64-,gcu300/gcu400/gcu410 +test_indirect_store_uint64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_indirect_store_dtype=uint64-,gcu300/gcu400/gcu410 +test_indirect_store_float64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_indirect_store_dtype=float64-,gcu300/gcu400/gcu410 +test_strided_store_int64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_strided_store_dtype=int64-,gcu300/gcu400/gcu410 +test_strided_store_uint64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_strided_store_dtype=uint64-,gcu300/gcu400/gcu410 +test_strided_store_float64,error: 64-bit data type not supported on GCU300!,Not supported on gcu300,test_strided_store_dtype=float64-,gcu300/gcu400/gcu410 +test_zero_strided_tensors,error: failed to legalize operation 'tt.atomic_rmw' that was explicitly marked illegal,,"test_zero_strided_tensors_",gcu300/gcu400/gcu410 +test_reshape_formats2,tensor rank of slice_start larger than 5 is not support on gcu,Not supported on gcu,"test_reshape_formats=((512,), (2, 2, 2, 2, 2, 2, 2, 2, 2))-",gcu300 diff --git a/third_party/enflame/python/test/unit/language/test_decorator.py b/third_party/enflame/python/test/unit/language/test_decorator.py new file mode 100644 index 000000000..028dccccc --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_decorator.py @@ -0,0 +1,53 @@ +import torch + +import triton +import triton.language as tl +import pytest + +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton +import torch_gcu + + +def test_decorator_with_def(device): + + def triton_heuristics_pointwise(**kwargs): + + def decorator(func): + return func + + return decorator + + # "def" might appear in a decorator call, e.g. a hash string argument. + # This test makes sure the compiler can find the right position of function + # definition. + @triton_heuristics_pointwise(inductor_meta={'backend_hash': 'def0aeffabe53b3f8'}, ) + @triton.jit + def kernel(): + pass + + try: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + except Exception as e: + pytest.fail(f"triton compile failed with error: {e}") + + +def test_triton_heuristic(device): + N = 1023 + src = torch.empty(N, device=device) + dst = torch.zeros(N, device=device) + + @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], warmup=1, rep=1) + @triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0}) # test kwargs + @triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0}) # test args + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr, EVEN_N: tl.constexpr, EVEN_src: tl.constexpr): + tl.store(dst, EVEN_N) + tl.store(dst + 1, EVEN_src) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + assert dst[0].item() == 0.0 + assert dst[1].item() == 1.0 + assert _kernel.base_fn.__name__ == "_kernel" diff --git a/third_party/enflame/python/test/unit/language/test_line_info.py b/third_party/enflame/python/test/unit/language/test_line_info.py new file mode 100644 index 000000000..e8e7746f5 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_line_info.py @@ -0,0 +1,134 @@ +''' +This file is commented out for release. +The original content is preserved below for reference. + +--- +Original file content: +--- + +import subprocess +import tempfile + +import pytest +import torch +import torch_gcu + +import triton +import triton.language as tl +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + + +@triton.jit +def kernel_single(X, + Y, + BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def device_inline(x): + return x + x + + +@triton.jit +def kernel_call(X, + Y, + BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = device_inline(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit(noinline=True) +def device_noinline(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = x + x + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_call_noinline(X, Y, BLOCK: tl.constexpr): + device_noinline(X, Y, BLOCK) + + +@triton.jit +def kernel_multi_files(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.softmax(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + +def extract_file_lines(asm): + fd, path = tempfile.mkstemp() + with open(fd, 'wb') as cubin: + cubin.write(asm) + asm = subprocess.check_output(["nvdisasm", "-g", path]).decode("utf-8") + file_lines = [] + lines = asm.splitlines() + for line in lines: + if "## File" in line: + entries = line[line.index("## File"):].split(",") + file_lines.append((entries[0].strip(), entries[1].strip())) + return file_lines + + +def check_file_lines(file_lines, file_name, lineno): + for file, line in file_lines: + # -1 means do not check line number + if lineno == -1: + if file_name in file: + return True + if file_name in file and str(lineno) in line: + return True + return False + + +func_types = ["single", "call", "call_noinline", "multi_files"] + + +@pytest.mark.parametrize("func", func_types) +def test_line_info(func: str): + try: + subprocess.check_output(["nvdisasm", "-h"]) + except BaseException: + pytest.skip("nvdisasm is not available") + + shape = (128, ) + x = torch.arange(0, shape[0], dtype=torch.float32, device='gcu') + y = torch.zeros(shape, dtype=x.dtype, device="gcu") + kernel_info = {} + if func == "single": + kernel_info = kernel_single[(1,)](x, y, BLOCK=shape[0]) + elif func == "call": + kernel_info = kernel_call[(1,)](x, y, BLOCK=shape[0]) + elif func == "call_noinline": + kernel_info = kernel_call_noinline[(1,)](x, y, BLOCK=shape[0]) + elif func == "multi_files": + kernel_info = kernel_multi_files[(1,)](x, y, BLOCK=shape[0]) + + file_lines = extract_file_lines(kernel_info.asm["cubin"]) + if func == "single": + assert (check_file_lines(file_lines, "test_line_info.py", 15)) + assert (check_file_lines(file_lines, "test_line_info.py", 16)) + elif func == "call": + assert (check_file_lines(file_lines, "test_line_info.py", 28)) + assert (check_file_lines(file_lines, "test_line_info.py", 21)) + assert (check_file_lines(file_lines, "test_line_info.py", 30)) + elif func == "call_noinline": + assert (check_file_lines(file_lines, "test_line_info.py", 42)) + assert (check_file_lines(file_lines, "test_line_info.py", 35)) + assert (check_file_lines(file_lines, "test_line_info.py", 36)) + assert (check_file_lines(file_lines, "test_line_info.py", 37)) + elif func == "multi_files": + assert (check_file_lines(file_lines, "test_line_info.py", 47)) + assert (check_file_lines(file_lines, "test_line_info.py", 49)) + assert (check_file_lines(file_lines, "standard.py", 33)) + assert (check_file_lines(file_lines, "standard.py", 34)) + assert (check_file_lines(file_lines, "standard.py", 36)) + # core.py is changed frequently, so we only check if it exists + assert (check_file_lines(file_lines, "core.py", -1)) + +''' diff --git a/third_party/enflame/python/test/unit/language/test_random.py b/third_party/enflame/python/test/unit/language/test_random.py new file mode 100644 index 000000000..7694978e3 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_random.py @@ -0,0 +1,225 @@ +import numpy as np +import pytest +import scipy.stats +import torch +import torch_gcu + +import triton +import triton.language as tl +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + +##################################### +# Reference Philox Implementation +##################################### + + +class PhiloxConfig: + + def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): + self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) + self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE) + self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE) + self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE) + self.DTYPE = DTYPE + + +# This is better for GPU +PHILOX_32 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B9, + PHILOX_KEY_B=0xBB67AE85, + PHILOX_ROUND_A=0xD2511F53, + PHILOX_ROUND_B=0xCD9E8D57, + DTYPE=np.uint32, +) + +# This is what numpy implements +PHILOX_64 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B97F4A7C15, + PHILOX_KEY_B=0xBB67AE8584CAA73B, + PHILOX_ROUND_A=0xD2E7470EE14C6C93, + PHILOX_ROUND_B=0xCA5A826395121157, + DTYPE=np.uint64, +) + + +class CustomPhilox4x: + + def __init__(self, seed, config): + self._config = config + seed = self._into_pieces(seed) + self._key = np.array(seed[:2], dtype=self._dtype) + self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype) + + @property + def _dtype(self): + return self._config.DTYPE + + def _into_pieces(self, n, pad=4): + res = [] + while len(res) < pad: + res.append(np.array(n, dtype=self._dtype)) + n >>= (np.dtype(self._dtype).itemsize * 8) + assert n == 0 + return tuple(res) + + def _multiply_low_high(self, a, b): + low = a * b + high = int(a) * int(b) + high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype) + return low, high + + def _single_round(self, counter, key): + lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0]) + lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2]) + ret0 = hi1 ^ counter[1] ^ key[0] + ret1 = lo1 + ret2 = hi0 ^ counter[3] ^ key[1] + ret3 = lo0 + return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) + + def _raise_key(self, key): + pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] + return key + np.array(pk, dtype=self._dtype) + + def random_raw(self): + counter = self._counter + key = self._key + for _ in range(10): + counter = self._single_round(counter, key) + key = self._raise_key(key) + self.advance(1) + return counter + + def advance(self, n_steps): + self._counter[0] += n_steps + assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets" + + +class CustomPhilox(CustomPhilox4x): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.buffer = [] + + def random_raw(self): + if len(self.buffer) == 0: + self.buffer = list(super().random_raw())[::-1] + return int(self.buffer.pop()) + + +##################################### +# Unit Tests +##################################### + +BLOCK = tl.constexpr(1024) + +# test generation of random uint32 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in ['1', '4,53', '400'] + for seed in [0, 42, 124, 54, 0xffffffff, 0x0000000fcafeb0ba] + for dtype in ['int32'] #, 'int64'] + for const_seed in [True, False]]) +def test_randint(size, seed, device, dtype, const_seed): + if seed in [0xffffffff, 0x0000000fcafeb0ba] and const_seed is False: + pytest.skip("Not supported on gcu300: dtype i64 is not supported on gcu300") + + size = list(map(int, size.split(','))) + torch_dtype = getattr(torch, dtype) + numpy_dtype = getattr(np, f"u{dtype}") + config = {'int32': PHILOX_32, 'int64': PHILOX_64}[dtype] + + @triton.jit + def kernel(X, N, seed): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch_dtype, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed) + else: + kernel[grid](x, N, seed) + out_tri = x.cpu().numpy().astype(numpy_dtype).flatten().tolist() + # reference result + gen = CustomPhilox4x(seed, config=config) + out_ref = [gen.random_raw()[0] for _ in out_tri] + assert out_tri == out_ref + + +# test uniform PRNG + + +@pytest.mark.parametrize('size, seed', [(size, seed) for size in [1000000] for seed in [0, 42, 124, 54]]) +def test_rand(size, seed, device): + + @triton.jit + def kernel(X, N, seed): + offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + kernel[grid](x, N, seed) + assert all((x >= 0) & (x <= 1)) + assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 + + +# test normal PRNG + + +@pytest.mark.parametrize('size, seed', [(size, seed) for size in [1000000] for seed in [0, 42, 124, 54]]) +def test_randn(size, seed, device): + + @triton.jit + def kernel(X, N, seed): + offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + kernel[grid](x, N, seed) + assert abs(x.mean()) < 1e-2 + assert abs(x.std() - 1) < 1e-2 + + +# tl.rand() should never produce >=1.0 + +# def test_rand_limits(device): +# @triton.jit +# def kernel(input, output, n: tl.constexpr): +# idx = tl.arange(0, n) +# x = tl.load(input + idx) +# y = tl.random.uint32_to_uniform_float(x) +# tl.store(output + idx, y) + +# min_max_int32 = torch.tensor([ +# torch.iinfo(torch.int32).min, +# torch.iinfo(torch.int32).max, +# ], dtype=torch.int32, device=device) +# output = torch.empty(2, dtype=torch.float32, device=device) +# kernel[(1,)](min_max_int32, output, 2) + +# assert output[0] == output[1] +# assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0 diff --git a/third_party/enflame/python/test/unit/language/test_reduce.py b/third_party/enflame/python/test/unit/language/test_reduce.py new file mode 100644 index 000000000..58c02d8a8 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_reduce.py @@ -0,0 +1,81 @@ +import pytest +import torch + +import triton +import triton.language as tl +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + +import numpy as np +import os +import torch_gcu + +os.environ["ENFLAME_LOG_DEBUG_MOD"] = "TORCH_GCU/OP" + + +@triton.jit +def _max_kernel_reduce(INPUT, OUT, input_stride0, input_stride1, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + M: tl.constexpr, N: tl.constexpr): + start_m = tl.program_id(0) + I_block_ptr = tl.make_block_ptr(base=INPUT, shape=(M, N), strides=(input_stride0, input_stride1), + offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, N), order=(1, 0)) + O_block_ptr = tl.make_block_ptr(base=OUT, shape=(M, ), strides=(1, ), offsets=(start_m * BLOCK_M, ), + block_shape=(BLOCK_M, ), order=(0, )) + i = tl.load(I_block_ptr) + out = tl.max(i, 1) + + tl.store(O_block_ptr, out) + I_block_ptr = tl.advance(I_block_ptr, (0, BLOCK_N)) + + +@staticmethod +def tri_max_reduce(input, out, M, N): + BLOCK_M = 32 + BLOCK_N = 16 + grid = (triton.cdiv(input.shape[0], BLOCK_M), 1, 1) + num_warps = 4 + _max_kernel_reduce[grid](input, out, input.stride(0), input.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, M=M, N=N, + num_warps=num_warps, num_stages=4) + return + + +def dump_tensor(tensor): + for a in tensor.tolist(): + npro = [] + for j in a: + npro.append(round(j, 3)) + print("[", npro, "]") + print("\n") + + +def cosine_similarity(ta, tb): + assert ta.shape == tb.shape + sum_a = np.square(ta).sum() + sum_b = np.square(tb).sum() + if sum_a == 0 or sum_b == 0: + return 0.0 + else: + return np.float64(np.sum(ta * tb) / np.sqrt(sum_a) / np.sqrt(sum_b)) + + +@pytest.mark.parametrize('M, N', [(16, 16), (64, 16), (128, 32)]) +def test_op(M, N, dtype=torch.float32): + input = torch.empty((M, N), dtype=dtype, device="gcu").normal_(mean=0., std=0.5) + out = torch.empty((M), dtype=dtype, device="gcu") + + ref_out, ref_out_index = torch.max(input, dim=1) + # triton implementation + tri_out = tri_max_reduce(input, out, M, N) + print("gcu caculate done!") + # compare + assert torch.allclose(ref_out, out, atol=1e-2, rtol=0) + + # cos_sim = cosine_similarity(ref_out.cpu().numpy(), tri_out.cpu().numpy()) + # print("output cos similarity: {}".format(cos_sim)) + print("ok") + + +test_op(16, 16, torch.float32) +test_op(64, 16, torch.float32) +test_op(128, 32, torch.float32) diff --git a/third_party/enflame/python/test/unit/language/test_reproducer.py b/third_party/enflame/python/test/unit/language/test_reproducer.py new file mode 100644 index 000000000..5de1c3159 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_reproducer.py @@ -0,0 +1,45 @@ +import os +import shutil + +import pytest + +import torch +import torch_gcu +import triton +import re +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + + +@triton.jit +def triton_(): + return + + +def test_reproducer(): + tmpdir = ".tmp" + reproducer = 'triton-reproducer.mlir' + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) + os.environ["TRITON_CACHE_DIR"] = tmpdir + os.environ["TRITON_REPRODUCER_PATH"] = reproducer + triton_[(1, )]() + foundPipeline = "" + with open(reproducer, 'r') as f: + line = f.read() + if 'pipeline:' in line: + foundPipeline = line + if 0 == len(foundPipeline): + raise Exception("Failed to find pipeline info in reproducer file.") + + ttgir_to_llvm_pass = re.compile("convert-triton-{{.*}}gpu-to-llvm") + if ttgir_to_llvm_pass.search(foundPipeline): + raise Exception("Failed to find triton passes in pipeline") + # cleanup + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) diff --git a/third_party/enflame/python/test/unit/language/test_standard.py b/third_party/enflame/python/test/unit/language/test_standard.py new file mode 100644 index 000000000..c3c7d7682 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_standard.py @@ -0,0 +1,100 @@ +import triton +import pytest +import torch +import torch_gcu +import triton.language as tl +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + +from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random + +# --------------- +# test maximum/minimum ops +# --------------- + + +# TODO: Tests with unsigned integers failed at compilation stage. +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", int_dtypes + uint_dtypes + float_dtypes + ["bfloat16"]) +@pytest.mark.parametrize("op", ["maximum", "minimum"]) +def test_maximum_minimum(dtype, op, device): + expr = f'tl.{op}(x, y)' + numpy_expr = f'np.{op}(x, y)' + _test_binary(dtype, dtype, expr, numpy_expr, device=device) + + +# --------------- +# test sort op +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +def test_sort(M, N, descending, dtype_str, device): + + @triton.jit + def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.sort(x, descending=descending) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.sort(x, descending=descending)[0] + z = torch.empty_like(x) + # num_warps in gcu4xx can not exceed 4 + sort_kernel[(1, )](x, z, N, M, descending, num_warps=4) + assert (y == z).all(), (y, z) + + +# --------------- +# test flip op +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +def test_flip(M, N, dtype_str, device): + + @triton.jit + def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.flip(x) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + # TODO(chongzhou.yang): torch.flip is not yet supported in gcu400, fallback it to cpu until 2025.12.31 + y = torch.flip(x.to('cpu'), (1, )) + z = torch.empty_like(x, device=device) + # num_warps in gcu4xx can not exceed 4 + flip_kernel[(1, )](x, z, N, M, num_warps=4) + assert (y.to('gcu') == z).all(), (y, z) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("size_i, size_j, size_g", [[5, 7, 3]]) +def test_swizzle2d(size_i, size_j, size_g, device): + + @triton.jit + def swizzle2d_kernel(output, size_i, size_j, size_g): + for i in tl.range(0, size_i, 1): + for j in tl.range(0, size_j, 1): + new_i, new_j = tl.swizzle2d(i, j, size_i, size_j, size_g) + tl.store(output + new_i * size_j + new_j, i * size_j + j) + + output = torch.zeros(size_i, size_j).to(device) + swizzle2d_kernel[(1, )](output, size_i, size_j, size_g) + expected_order = torch.tensor([[0, 3, 6, 9, 12, 15, 18], [1, 4, 7, 10, 13, 16, 19], [2, 5, 8, 11, 14, 17, 20], + [21, 23, 25, 27, 29, 31, 33], [22, 24, 26, 28, 30, 32, 34]]).to(device) + assert (output == expected_order).all(), (output, expected_order) diff --git a/third_party/enflame/python/test/unit/language/test_subprocess.py b/third_party/enflame/python/test/unit/language/test_subprocess.py new file mode 100644 index 000000000..ed6a6ec37 --- /dev/null +++ b/third_party/enflame/python/test/unit/language/test_subprocess.py @@ -0,0 +1,141 @@ +''' +This file is commented out for release. +The original content is preserved below for reference. + +--- +Original file content: +--- + +import itertools +import os +import subprocess +import sys +from collections import Counter + +import pytest +import torch_gcu + +dir_path = os.path.dirname(os.path.realpath(__file__)) +print_path = os.path.join(dir_path, "print_helper.py") +# (TODO) long and float64 are not supported on gcu300 +#torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] +torch_types = ["int8", "uint8", "int16", "int32", "float16", "float32"] + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +# TODO: Print with multiple operands + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func_type, data_type", [(fn, data_type) + #TODO(triton3.2) newly added func_types are required to be suopprted + #for fn in ["device_print", "device_print_scalar"] + for fn in ["device_print"] + for data_type in torch_types] + [ + ("print", "int32"), + ("static_print", "int32"), + #TODO(triton3.2) newly added func_types are required to be suopprted + # ("no_arg_print", "int32"), + # ("print_no_arg", "int32"), + # ("device_print_large", "int32"), + # ("print_multiple_args", "int32"), + # ("device_print_multiple_args", "int32"), + # ("device_print_hex", "int16"), + # ("device_print_hex", "int32"), + # ("device_print_hex", "int64"), + # ("device_print_pointer", "int32"), + # ("device_print_negative", "int32"), + ("device_print_uint", "uint32"), + ]) +def test_print(func_type: str, data_type: str, device: str): + proc = subprocess.run( + [sys.executable, print_path, "test_print", func_type, data_type, device], + capture_output=True, + ) + assert proc.returncode == 0 + + if is_interpreter() and func_type != "static_assert": + # Interpreter uses a different format for device_print + # Only check if there's no error + assert proc.stderr == b'' + return + + outs = [line for line in proc.stdout.decode("UTF-8").splitlines() if line] + # The total number of elements in the 1-D tensor to print. + N = 128 + + # Constant for testing the printing of scalar values + SCALAR_VAL = 42 + + # Format is + # pid (, , ) idx (, , ...) (operand ) + expected_lines = Counter() + if func_type in ("print", "device_print", "device_print_uint"): + for i in range(N): + offset = (1 << 31) if data_type == "uint32" else 0 + # line = f"pid (0, 0, 0) idx ({i:3}) x: {i + offset}" + line = f"[0, {(i // 32)}] x: : {i + offset} (idx {i % 32})" + if data_type.startswith("float"): + original_str = line + target = " (idx" + insert_idx = original_str.find(target) + assert insert_idx != -1, "cannot find \" (idx\" in the print strings" + line = original_str[:insert_idx] + ".000000" + original_str[insert_idx:] + expected_lines[line] = 1 + elif func_type == "device_print_scalar": + line = f"pid (0, 0, 0) idx () x: {SCALAR_VAL}" + if data_type.startswith("float"): + line += ".000000" + expected_lines[line] = N + elif func_type == "device_print_negative": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: {-i}" + expected_lines[line] = 1 + elif func_type == "device_print_hex": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: 0x" + if data_type == "int16": + line += f"{i:04x}" + if data_type == "int32": + line += f"{i:08x}" + if data_type == "int64": + line += f"{i:016x}" + expected_lines[line] = 1 + elif func_type == "static_print": + expected_lines[f" int32[constexpr[{N}]]"] = 1 + elif func_type == "no_arg_print": + expected_lines["pid (0, 0, 0) idx (): 0"] = N + elif func_type == "print_no_arg": + expected_lines["pid (0, 0, 0) no arg"] = N + elif func_type == "device_print_large": + for i, j, k in itertools.product(range(2), range(64), range(N)): + expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1 + elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1 + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1 + elif func_type == "device_print_pointer": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}) ptr: 0x"] = 1 + + actual_lines = Counter() + for line in outs: + # Trim the exact pointer address in the output--they can change per run. + line = (line.split(':')[0] + ": 0x") if func_type == "device_print_pointer" else line + actual_lines[line] += 1 + + # # filter out lines starting with "warning: " + actual_lines = [line for line in actual_lines if not line.lstrip().lower().startswith("warning: ")] + + diff = Counter(actual_lines) + diff.subtract(expected_lines) + for line, delta in diff.items(): + if delta == 0: + continue + + print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') + assert all(delta == 0 for delta in diff.values()) +''' diff --git a/third_party/enflame/python/test/unit/operators/.coveragerc b/third_party/enflame/python/test/unit/operators/.coveragerc new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/python/test/unit/operators/test_blocksparse.py b/third_party/enflame/python/test/unit/operators/test_blocksparse.py new file mode 100644 index 000000000..844c4c0be --- /dev/null +++ b/third_party/enflame/python/test/unit/operators/test_blocksparse.py @@ -0,0 +1,235 @@ +''' +This file is commented out for release. +The original content is preserved below for reference. + +--- +Original file content: +--- + +import pytest +import torch + +import triton +import triton.ops + +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + +def sparsify_tensor(x, mask, block): + ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) + for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): + ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] + return ret + + +def make_pair(shape, device="gcu", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32): + if data is None: + data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device) + ref_ret = data + ref_ret = ref_ret * alpha + beta + ref_ret = ref_ret.half().to(dtype) + if trans: + ref_ret = ref_ret.t().requires_grad_() + ref_ret = ref_ret.detach().requires_grad_() + tri_ret = ref_ret.clone().detach().requires_grad_() + return ref_ret, tri_ret + + +def mask_tensor(x, mask, block, value=0): + ret = x.clone() + for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)): + ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value + return ret + +# TODO(rickon.wang): MODE "dds" sip hang. backward not working. param 'BLOCK" is +# set 64 due to runtime limit. +@pytest.mark.parametrize("MODE", ["sdd", "dsd"]) +@pytest.mark.parametrize("TRANS_A", [False, True]) +@pytest.mark.parametrize("TRANS_B", [False, True]) +@pytest.mark.parametrize("BLOCK", [64]) +@pytest.mark.parametrize("DTYPE", [torch.float16]) +def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256): + seed = 0 + torch.manual_seed(seed) + is_sdd = MODE == "sdd" + is_dsd = MODE == "dsd" + is_dds = MODE == "dds" + do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK) + do_mask = lambda x: mask_tensor(x, layout, BLOCK) + # create inputs + # create op + a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K) + b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N) + c_shape = (Z, H, M, N) + shape = { + "sdd": (M, N), + "dsd": (a_shape[2], a_shape[3]), + "dds": (b_shape[2], b_shape[3]), + }[MODE] + layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # create data + a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE) + b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE) + # dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE) + # compute [torch] + # dc_ref = do_mask(dc_ref) if is_sdd else dc_ref + a_ref = do_mask(a_ref) if is_dsd else a_ref + b_ref = do_mask(b_ref) if is_dds else b_ref + # a_ref.retain_grad() + # b_ref.retain_grad() + c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, + b_ref.transpose(2, 3) if TRANS_B else b_ref) + # c_ref.backward(dc_ref) + c_ref = do_sparsify(c_ref) if is_sdd else c_ref + # da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad + # db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad + # triton result + # dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri + a_tri = do_sparsify(a_tri) if is_dsd else a_tri + b_tri = do_sparsify(b_tri) if is_dds else b_tri + # a_tri.retain_grad() + # b_tri.retain_grad() + op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="gcu") + c_tri = op(a_tri, b_tri) + # c_tri.backward(dc_tri) + # da_tri = a_tri.grad + # db_tri = b_tri.grad + # compare + torch.testing.assert_allclose(c_ref, c_tri) + # torch.testing.assert_allclose(da_ref, da_tri) + # torch.testing.assert_allclose(db_ref, db_tri) + + +configs = [ + (16, 128), + # (16, 256), + # (32, 576), + # (64, 1871), + # (128, 2511), +] + + +@pytest.mark.parametrize("is_dense", [False, True]) +@pytest.mark.parametrize("BLOCK, WIDTH", configs) +def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4): + # set seed + torch.random.manual_seed(0) + Z, H, M, N = 2, 3, WIDTH, WIDTH + # initialize layout + # make sure each row has at least one non-zero element + layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) + if is_dense: + layout[:] = 1 + else: + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # initialize data + a_shape = (Z, H, M, N) + a_ref, a_tri = make_pair(a_shape) + dout_ref, dout_tri = make_pair(a_shape) + # compute [torch] + a_ref = mask_tensor(a_ref, layout, BLOCK, value=float("-inf")) + a_ref.retain_grad() + at_mask = torch.ones((M, N), device="gcu") + if is_causal: + at_mask = torch.tril(at_mask) + M = at_mask[None, None, :, :] + torch.zeros_like(a_ref) + a_ref[M == 0] = float("-inf") + out_ref = torch.softmax(a_ref * scale, -1) + out_ref.backward(dout_ref) + out_ref = sparsify_tensor(out_ref, layout, BLOCK) + da_ref = sparsify_tensor(a_ref.grad, layout, BLOCK) + # compute [triton] + a_tri = sparsify_tensor(a_tri, layout, BLOCK) + a_tri.retain_grad() + dout_tri = sparsify_tensor(dout_tri, layout, BLOCK) + op = triton.ops.blocksparse.softmax(layout, BLOCK, device="gcu", is_dense=is_dense) + out_tri = op(a_tri, scale=scale, is_causal=is_causal) + out_tri.backward(dout_tri) + da_tri = a_tri.grad + # compare + torch.testing.assert_allclose(out_tri, out_ref) + torch.testing.assert_allclose(da_tri, da_ref) + + +#TODO(rickon.wang): backward disabled. runtime limit. +@pytest.mark.parametrize("block", [16, 32, 64]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_attention_fwd_bwd( + block, + dtype, + input_scale=1.0, + scale=1 / 8.0, + n_ctx=256, + batch_size=2, + n_heads=2, +): + capability = torch.gcu.get_device_capability() + # if capability[0] < 7: + # pytest.skip("Only test tl.dot() on devices with sm >= 70") + + # inputs + qkv_shape = (batch_size, n_heads, n_ctx, 64) + qkvs = [ + torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).gcu() for _ in range(3) + ] + + # Triton: + n_blocks = n_ctx // block + layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long)) + query, key, value = [x.clone() for x in qkvs] + # query.retain_grad() + # key.retain_grad() + # value.retain_grad() + attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale) + # ad hoc loss + loss = (attn_out ** 2).mean() + # loss.backward() + # grads = [query.grad, key.grad, value.grad] + + # Torch version: + torch_q, torch_k, torch_v = [x.clone() for x in qkvs] + attn_mask = torch.ones([n_ctx, n_ctx], device="gcu", dtype=dtype) + attn_mask = torch.tril(attn_mask, diagonal=0) + attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).gcu())) + # torch_q.retain_grad() + # torch_k.retain_grad() + # torch_v.retain_grad() + scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k) + scores = scores + attn_mask + probs = torch.softmax(scores, dim=-1) + torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) + # ad hoc loss + torch_loss = (torch_attn_out ** 2).mean() + # torch_loss.backward() + # torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad] + + # comparison + # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...") + torch.testing.assert_allclose(loss, torch_loss, atol=1e-3, rtol=0) + # for g1, g2 in zip(grads, torch_grads): + # torch.testing.assert_allclose(g1, g2) + + +@pytest.mark.parametrize("block", [16, 32, 64]) +def triton_attention( + layout, + block: int, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, +): + sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device) + sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device) + sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device) + + w = sparse_dot_sdd_nt(query, key) + w = sparse_softmax(w, scale=scale, is_causal=True) + a = sparse_dot_dsd_nn(w, value) + return a + +''' diff --git a/third_party/enflame/python/test/unit/operators/test_cross_entropy.py b/third_party/enflame/python/test/unit/operators/test_cross_entropy.py new file mode 100644 index 000000000..1da1004a4 --- /dev/null +++ b/third_party/enflame/python/test/unit/operators/test_cross_entropy.py @@ -0,0 +1,50 @@ +''' +This file is commented out for release. +The original content is preserved below for reference. + +--- +Original file content: +--- + +import pytest +import torch + +import triton +import triton.ops + + +@pytest.mark.parametrize("M, N, dtype, mode", + [ + (M, N, dtype, mode) for M in [1024, 821] + for N in [512, 857, 1871, 2089, 8573, 31000] + for dtype in ['float16', 'float32'] + for mode in ['forward', 'backward'] + ] + ) +def test_op(M, N, dtype, mode): + capability = torch.cuda.get_device_capability() + if capability[0] < 8 and dtype == "bfloat16": + pytest.skip("Only test bfloat16 on devices with sm >= 80") + dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype] + # create inputs + x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True) + idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda') + # forward pass + tt_y = triton.ops.cross_entropy(x, idx) + th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx) + if mode == 'forward': + torch.testing.assert_allclose(th_y, tt_y) + # backward pass + elif mode == 'backward': + dy = torch.randn_like(tt_y) + # triton backward + tt_y.backward(dy) + tt_dx = x.grad.clone() + # torch backward + x.grad = None + th_y.backward(dy) + th_dx = x.grad.clone() + + torch.testing.assert_allclose(th_dx, tt_dx) + +''' diff --git a/third_party/enflame/python/test/unit/operators/test_flash_attention.py b/third_party/enflame/python/test/unit/operators/test_flash_attention.py new file mode 100644 index 000000000..acee1459a --- /dev/null +++ b/third_party/enflame/python/test/unit/operators/test_flash_attention.py @@ -0,0 +1,60 @@ +''' +This file is commented out for release. +The original content is preserved below for reference. + +--- +Original file content: +--- + +import pytest +import torch + +import triton +import triton.ops + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16), + (4, 48, 1024, 32), + (4, 48, 1024, 64), + (4, 48, 1024, 128)]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('seq_par', [True, False]) +def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + pytest.skip("Flash attention only supported for compute capability < 80") + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).to(dtype) + # p = torch.exp(p) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # # triton implementation + tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par) + # print(ref_out) + # print(tri_out) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + atol = 1e-1 if dtype == torch.bfloat16 else 1e-2 + torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0) + torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0) + torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0) + torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0) + +''' diff --git a/third_party/enflame/python/test/unit/operators/test_inductor.py b/third_party/enflame/python/test/unit/operators/test_inductor.py new file mode 100644 index 000000000..02a72b57f --- /dev/null +++ b/third_party/enflame/python/test/unit/operators/test_inductor.py @@ -0,0 +1,168 @@ +''' +This file is commented out for release. +The original content is preserved below for reference. + +--- +Original file content: +--- + +import torch + +import triton +import triton.language as tl + +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + +def test_normalization_with_remat(): + + @triton.jit + def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + xnumel = xnumel + rnumel = rnumel + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + rbase = tl.arange(0, RBLOCK)[None, :] + x3 = xindex + x0 = xindex % 32 + tmp1 = tl.load(in_ptr0 + (x0), xmask) + tmp3 = tl.load(in_ptr1 + (x0), xmask) + tmp11 = tl.load(in_ptr2 + (x0), xmask) + tmp13 = tl.load(in_ptr3 + (x0), xmask) + _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0 + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + r2 = rindex + tmp0 = tl.load(in_out_ptr0 + (r2 + (4096 * x3)), rmask & xmask, eviction_policy='evict_last', other=0) + tmp2 = tmp0 - tmp1 + tmp4 = 1e-05 + tmp5 = tmp3 + tmp4 + tmp6 = tl.sqrt(tmp5) + tmp7 = 1 / tmp6 + tmp8 = 1.0 + tmp9 = tmp7 * tmp8 + tmp10 = tmp2 * tmp9 + tmp12 = tmp10 * tmp11 + tmp14 = tmp12 + tmp13 + _tmp17 = tl.where(rmask & xmask, _tmp17 + tmp14, _tmp17) + tl.store(in_out_ptr0 + (r2 + (4096 * x3) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp14, rmask & xmask) + tmp17 = tl.sum(_tmp17, 1)[:, None] + tmp18 = 4096.0 + tmp19 = tmp17 / tmp18 + tl.store(in_out_ptr1 + (x3 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask) + + torch.manual_seed(123) + + buf14 = torch.rand(8, 32, 64, 64, device="gcu") + buf16 = torch.rand(8, 1, 32, device="gcu") + arg114_1 = torch.rand(32, device="gcu") + arg115_1 = torch.rand(32, device="gcu") + arg8_1 = torch.rand(32, device="gcu") + arg9_1 = torch.rand(32, device="gcu") + triton_[(256,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 256, 4096, 1, 2048, num_warps=4) + torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0) + + +def test_avg_pool_bw(): + + @triton.jit + def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + x1 = (xindex // 8) % 8 + x0 = xindex % 8 + x2 = (xindex // 64) + x5 = xindex + tmp0 = (-1) + x1 + tmp1 = (-1) + x0 + tmp2 = 2 + x1 + tmp3 = 2 + x0 + tmp4 = 0 + tmp5 = tl.where(tmp0 != tmp0, tmp0, tl.where(tmp0 > tmp4, tmp0, tmp4)) + tmp6 = tl.where(tmp1 != tmp1, tmp1, tl.where(tmp1 > tmp4, tmp1, tmp4)) + tmp7 = 8 + tmp8 = tl.where(tmp2 != tmp2, tmp2, tl.where(tmp2 < tmp7, tmp2, tmp7)) + tmp9 = tl.where(tmp3 != tmp3, tmp3, tl.where(tmp3 < tmp7, tmp3, tmp7)) + tmp10 = tmp5 + tmp4 + tmp11 = tmp6 + tmp4 + tmp12 = 1 + tmp13 = tmp8 - tmp12 + tmp14 = tl.where(tmp10 != tmp10, tmp10, tl.where(tmp10 < tmp13, tmp10, tmp13)) + tmp15 = tmp9 - tmp12 + tmp16 = tl.where(tmp11 != tmp11, tmp11, tl.where(tmp11 < tmp15, tmp11, tmp15)) + tmp17 = tl.load(in_ptr0 + (tmp16 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp18 = tmp17 / 9 + tmp19 = tmp10 < tmp8 + tmp20 = tmp11 < tmp9 + tmp21 = tmp19 & tmp20 + tmp22 = 0.0 + tmp23 = tl.where(tmp21, tmp18, tmp22) + tmp24 = tmp6 + tmp12 + tmp25 = tl.where(tmp24 != tmp24, tmp24, tl.where(tmp24 < tmp15, tmp24, tmp15)) + tmp26 = tl.load(in_ptr0 + (tmp25 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp27 = tmp26 / 9 + tmp28 = tmp24 < tmp9 + tmp29 = tmp19 & tmp28 + tmp30 = tmp23 + tmp27 + tmp31 = tl.where(tmp29, tmp30, tmp23) + tmp32 = 2 + tmp33 = tmp6 + tmp32 + tmp34 = tl.where(tmp33 != tmp33, tmp33, tl.where(tmp33 < tmp15, tmp33, tmp15)) + tmp35 = tl.load(in_ptr0 + (tmp34 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp36 = tmp35 / 9 + tmp37 = tmp33 < tmp9 + tmp38 = tmp19 & tmp37 + tmp39 = tmp31 + tmp36 + tmp40 = tl.where(tmp38, tmp39, tmp31) + tmp41 = tmp5 + tmp12 + tmp42 = tl.where(tmp41 != tmp41, tmp41, tl.where(tmp41 < tmp13, tmp41, tmp13)) + tmp43 = tl.load(in_ptr0 + (tmp16 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp44 = tmp43 / 9 + tmp45 = tmp41 < tmp8 + tmp46 = tmp45 & tmp20 + tmp47 = tmp40 + tmp44 + tmp48 = tl.where(tmp46, tmp47, tmp40) + tmp49 = tl.load(in_ptr0 + (tmp25 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp50 = tmp49 / 9 + tmp51 = tmp45 & tmp28 + tmp52 = tmp48 + tmp50 + tmp53 = tl.where(tmp51, tmp52, tmp48) + tmp54 = tl.load(in_ptr0 + (tmp34 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp55 = tmp54 / 9 + tmp56 = tmp45 & tmp37 + tmp57 = tmp53 + tmp55 + tmp58 = tl.where(tmp56, tmp57, tmp53) + tmp59 = tmp5 + tmp32 + tmp60 = tl.where(tmp59 != tmp59, tmp59, tl.where(tmp59 < tmp13, tmp59, tmp13)) + tmp61 = tl.load(in_ptr0 + (tmp16 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp62 = tmp61 / 9 + tmp63 = tmp59 < tmp8 + tmp64 = tmp63 & tmp20 + tmp65 = tmp58 + tmp62 + tmp66 = tl.where(tmp64, tmp65, tmp58) + tmp67 = tl.load(in_ptr0 + (tmp25 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp68 = tmp67 / 9 + tmp69 = tmp63 & tmp28 + tmp70 = tmp66 + tmp68 + tmp71 = tl.where(tmp69, tmp70, tmp66) + tmp72 = tl.load(in_ptr0 + (tmp34 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp73 = tmp72 / 9 + tmp74 = tmp63 & tmp37 + tmp75 = tmp71 + tmp73 + tmp76 = tl.where(tmp74, tmp75, tmp71) + tl.store(out_ptr0 + (x5 + tl.zeros([XBLOCK], tl.int32)), tmp76, None) + + inp = torch.ones(8, 2048, 8, 8, device="gcu", dtype=torch.half) + out = torch.ones_like(inp) * 3 + numel = inp.numel() + triton_[(numel // 1024,)](inp, out, 1024, num_warps=1) + out_ref = torch.ones_like(inp) + out_ref[:, :, 1:7, 0::7] = 2 / 3 + out_ref[:, :, 0::7, 1:7] = 2 / 3 + out_ref[:, :, 0::7, 0::7] = 4 / 9 + torch.testing.assert_allclose(out, out_ref) + +''' diff --git a/third_party/enflame/python/test/unit/operators/test_matmul.py b/third_party/enflame/python/test/unit/operators/test_matmul.py new file mode 100644 index 000000000..e0ba372a1 --- /dev/null +++ b/third_party/enflame/python/test/unit/operators/test_matmul.py @@ -0,0 +1,171 @@ +''' +This file is commented out for release. +The original content is preserved below for reference. + +--- +Original file content: +--- + +import itertools + +import pytest +import torch + +import triton +import triton.language as tl +import triton.ops + + +def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + +@pytest.mark.parametrize( + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE", + itertools.chain( + *[ + [ + # 1 warp + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE), + # 2 warp + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE), + # 4 warp + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), + # 8 warp + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE), + # split-k + (64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), + (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), + # variable input + (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE), + (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE, DTYPE), + (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE, DTYPE), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE), + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] + ], + # n-stage + *[ + [ + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE), + # split-k + (64, 64, 16, 8, 4, STAGES, 128, 128, 768, AT, BT, DTYPE, DTYPE), + (64, 64, 16, 8, 4, STAGES, 128, 128, 32, AT, BT, DTYPE, DTYPE), + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4] + ], + # mixed-precision + *[ + [ + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE), + (128, 128, 32, 8, 4, 2, 256, 256, 128, AT, BT, ADTYPE, BDTYPE), + ] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"), + ("float8e4", "float16"), + ("float16", "float8e5"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] + ] + ), +) +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE): + capability = torch.cuda.get_device_capability() + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + if capability[0] < 8 and (ADTYPE == "bfloat16" or BDTYPE == "bfloat16"): + pytest.skip("Only test bfloat16 on devices with sm >= 80") + if (ADTYPE == "bfloat16" or BDTYPE == "bfloat16") and SPLIT_K != 1: + pytest.skip("bfloat16 matmuls don't allow split_k for now") + torch.manual_seed(0) + # nuke kernel decorators -- will set meta-parameters manually + kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} + pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_() + configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)] + kernel = triton.ops._matmul.kernel + kernel.configs = configs + # kernel.run = kernel.run.run.run + + # get matrix shape + M = BLOCK_M if M is None else M + N = BLOCK_N if N is None else N + K = BLOCK_K * SPLIT_K if K is None else K + a_fp8 = "float8" in ADTYPE + b_fp8 = "float8" in BDTYPE + + def maybe_upcast(x, dtype, is_float8): + if is_float8: + return f8_to_f16(x, dtype) + return x + + def init_input(n, m, t, dtype, is_float8): + if t: + return init_input(m, n, False, dtype, is_float8).t() + if is_float8: + return torch.randint(20, 60, (n, m), device="cuda", dtype=torch.int8) + dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype] + return .1 * torch.randn((n, m), device="cuda", dtype=dtype) + + # allocate/transpose inputs + a = init_input(M, K, AT, ADTYPE, a_fp8) + b = init_input(K, N, BT, BDTYPE, b_fp8) + # run test + th_a = maybe_upcast(a, ADTYPE, a_fp8).to(torch.float32) + if AT and a_fp8: + th_a = th_a.view(th_a.shape[::-1]).T + th_b = maybe_upcast(b, BDTYPE, b_fp8).to(torch.float32) + if BT and b_fp8: + th_b = th_b.view(th_b.shape[::-1]).T + th_c = torch.matmul(th_a, th_b) + try: + if a_fp8: + a = triton.reinterpret(a, getattr(tl, ADTYPE)) + if b_fp8: + b = triton.reinterpret(b, getattr(tl, BDTYPE)) + tt_c = triton.ops.matmul(a, b) + atol, rtol = 1e-2, 0 + if ADTYPE == torch.bfloat16 or BDTYPE == torch.bfloat16: + atol, rtol = 3.5e-2, 0 + torch.testing.assert_allclose(th_c, tt_c, atol=atol, rtol=rtol) + except triton.OutOfResources as e: + pytest.skip(str(e)) + +''' diff --git a/third_party/enflame/python/test/unit/runtime/.coveragerc b/third_party/enflame/python/test/unit/runtime/.coveragerc new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/python/test/unit/runtime/test_autotuner.py b/third_party/enflame/python/test/unit/runtime/test_autotuner.py new file mode 100644 index 000000000..9c90f771e --- /dev/null +++ b/third_party/enflame/python/test/unit/runtime/test_autotuner.py @@ -0,0 +1,137 @@ +import torch +import torch_gcu +from torch_gcu import transfer_to_gcu + +import triton +import triton.language as tl +import pytest +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + + +@pytest.mark.parametrize('use_cuda_graph', [False, True]) +def test_kwargs(use_cuda_graph: bool): + N = 1024 + src = torch.empty(N, device='cuda') + dst = torch.empty(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N) + _kernel[grid](dst=dst, src=src, N=N) + + +def test_restore(): + N = 1024 + src = torch.zeros(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], restore_value=['src'], warmup=1, rep=1) + @triton.jit + def _kernel(src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + 1 + tl.store(src + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](src, N) + triton.testing.assert_close(src, torch.ones_like(src)) + + +def test_hooks(): + # Autotuner's pre- and post- hooks should be called the same number of times + N = 4096 + src = torch.zeros(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 4096}), triton.Config(kwargs={'BLOCK_SIZE': 32})] + + values = {"counter": 0, "has_exception": False} + + def _pre_hook(*args, **kwargs): + values["counter"] += 1 + + def _post_hook(*args, exception): + values["counter"] -= 1 + if exception is not None: + values["has_exception"] = True + assert values["counter"] == 0 + + @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, pre_hook=_pre_hook, post_hook=_post_hook) + @triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4}) + @triton.jit + def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + max_iters = tl.cdiv(N, BLOCK_SIZE) + for _ in tl.range(max_iters, num_stages=N_STAGES): + x = tl.load(src + offsets, mask=offsets < N) + tl.store(src + offsets, x, mask=offsets < N) + offsets += BLOCK_SIZE + + _kernel[(1, )](src, N) + + # On NVIDIA GPUs: + # The tuning knob `num_stages` can be set by users. + # This will cause out of resources when N_STAGES = 100 + # shared memory bytes = N_STAGES * BLOCK_SIZE * sizeof(float) + # On AMD GPUs: + # `num_stages` is a fixed value of 2, so it won't cause out of resources + if triton.runtime.driver.active.get_current_target().backend == "cuda": + assert values["has_exception"] is True + else: + assert values["has_exception"] is False + + +@pytest.mark.parametrize('with_perf_model', [False, True]) +def test_prune_configs(with_perf_model: bool): + N = 1024 + src = torch.empty(N, device='cuda') + dst = torch.empty(N, device='cuda') + records = {} + + def early_config_prune(configs, named_args, **kwargs): + records['run_early_config_prune'] = True + if "N" in kwargs and kwargs["N"] == 1024: + records['capture_kwargs'] = True + if "dst" in named_args and "src" in named_args and len(named_args) == 2: + records['capture_named_args'] = True + return [configs[0]] + + def perf_model(*args, **kwargs): + records['run_perf_model'] = True + return kwargs['BLOCK_SIZE'] + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + if with_perf_model: + prune_configs_by = {'perf_model': perf_model, 'top_k': 1} + else: + prune_configs_by = {'early_config_prune': early_config_prune} + + @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, warmup=1, rep=1) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + torch.testing.assert_close(src, dst) + if with_perf_model: + assert len(records) == 1 + assert records['run_perf_model'] + else: + assert len(records) == 3 + assert records['run_early_config_prune'] + assert records['capture_kwargs'] + assert records['capture_named_args'] diff --git a/third_party/enflame/python/test/unit/runtime/test_bindings.py b/third_party/enflame/python/test/unit/runtime/test_bindings.py new file mode 100644 index 000000000..524f937cf --- /dev/null +++ b/third_party/enflame/python/test/unit/runtime/test_bindings.py @@ -0,0 +1,108 @@ +import triton +import triton.language as tl + +import torch +import math +import torch_gcu +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + + +@triton.jit +def add_helper(x, y): + return x + y + + +@triton.jit +def add_kernel( + in_ptr0, + in_ptr1, + n_elements, + out_ptr, + BLOCK_SIZE: "tl.constexpr", +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = add_helper(x, y) + tl.store(out_ptr + offsets, output, mask=mask) + + +def test_module_walk(device): + """ + Test the MLIR bindings exposed for the out-of-tree walk. + """ + + def walk_fn(op): + name = op.get_name() + for i in range(op.get_num_results()): + op.get_result(i).id() + for i in range(op.get_num_operands()): + op.get_operand(i).id() + for i in range(op.get_num_regions()): + op.get_region(i).id() + block = op.get_block() + if block is not None: + block.id() + for i in range(block.get_num_arguments()): + block.get_argument(i) + if name == "tt.func": + op.get_str_attr("sym_name") + if name == "tt.call": + op.get_flat_symbol_ref_attr("callee") + + kernel = add_kernel + args = [ + torch.empty((32, 32), device=device), # in_ptr0 + torch.empty((32, 32), device=device), # in_ptr1 + 1024, # n_elements + torch.empty((32, 32), device=device), # out_ptr + 16, # BLOCK_SIZE + ] + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + src = triton.compiler.compiler.ASTSource( + fn=kernel, + signature={kernel.arg_names[i]: triton.runtime.jit.mangle_type(arg) + for i, arg in enumerate(args)}, + constexprs={kernel.arg_names[i]: arg + for i, arg in enumerate(args) + if not isinstance(arg, torch.Tensor)}, + ) + + context = triton._C.libtriton.ir.context() + options = backend.parse_options(dict()) + codegen_fns = dict() + module_map = backend.get_module_map() + triton._C.libtriton.ir.load_dialects(context) + backend.load_dialects(context) + + ttir_module = src.make_ir(options, codegen_fns, module_map, context) + ttir_module.walk(walk_fn) + + +def test_python_func_in_visit_call(device): + + @triton.jit + def test_py_call_const_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + log2e: tl.constexpr = math.log2(math.e) + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = x * log2e + tl.store(out_ptr + offsets, output, mask=mask) + + x = torch.randn(4, device=device) + out = torch.zeros_like(x) + test_py_call_const_kernel[(4, )](x, out, 4, 4) diff --git a/third_party/enflame/python/test/unit/runtime/test_cache.py b/third_party/enflame/python/test/unit/runtime/test_cache.py new file mode 100644 index 000000000..e0a7e3b19 --- /dev/null +++ b/third_party/enflame/python/test/unit/runtime/test_cache.py @@ -0,0 +1,634 @@ +import importlib.util +import itertools +import os +import shutil +import pathlib + +import pytest +import torch +import torch_gcu +from torch_gcu import transfer_to_gcu + +import triton +import triton.language as tl +from triton.runtime.jit import JITFunction +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + + +@triton.jit +def function_0(i): + return i + 1 + + +@triton.jit +def function_1(i): + i = i + 1 + cond: tl.constexpr = True + if cond: + FN: tl.constexpr = function_2 + else: + FN: tl.constexpr = function_0 + return FN(i) + + +@triton.jit +def function_2(i): + i = i + 1 + return i + + +@triton.jit +def combine_fn(a, b): + return COMBINE_OP # noqa: F821 + + +@triton.jit +def kernel(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit(do_not_specialize=["i"]) +def kernel_nospec(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit(do_not_specialize_on_alignment=["i"]) +def kernel_nospec_on_alignment(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit +def kernel_with_combine_fn(X, BLOCK: tl.constexpr): + i = tl.arange(0, BLOCK) + i = REDUCE_OR_SCAN(i, 0, combine_fn) # noqa: F821 + tl.store(X, i) + + +def apply_src_change(target, old, new, to_modify): + kernel.hash = None + function_0.hash = None + function_1.hash = None + function_2.hash = None + to_modify._unsafe_update_src(to_modify.src.replace(old, new)) + ret = target.cache_key + to_modify._unsafe_update_src(to_modify.src.replace(new, old)) + return ret + + +def test_nochange(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 1', function_1) + assert baseline == updated + + +def test_toplevel_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_1) + assert baseline != updated + + +def test_nested1_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_2) + assert baseline != updated + + +def test_nested2_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_0) + assert baseline != updated + + +def test_combine_fn_change(): + # Test that tl.reduce and associative_scan calls include + # the combine_fn in the hash + + orig_combine_fn_src = combine_fn.src + orig_kernel_src = kernel_with_combine_fn.src + seen_keys = set() + + for reduce_or_scan, combine_op in itertools.product( + ["tl.reduce", "tl.associative_scan"], + ["a + b", "a * b"], + ): + combine_fn._unsafe_update_src(orig_combine_fn_src.replace("COMBINE_OP", combine_op)) + kernel_with_combine_fn._unsafe_update_src(orig_kernel_src.replace("REDUCE_OR_SCAN", reduce_or_scan)) + try: + key = kernel_with_combine_fn.cache_key + finally: + combine_fn._unsafe_update_src(orig_combine_fn_src) + kernel_with_combine_fn._unsafe_update_src(orig_kernel_src) + + assert key not in seen_keys + seen_keys.add(key) + + +def write_and_load_module(temp_file: pathlib.Path, code, num_extra_lines): + temp_file.write_text(('# extra line\n' * num_extra_lines) + code) + spec = importlib.util.spec_from_file_location("module.name", str(temp_file)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_changed_line_numbers_invalidate_cache(tmp_path: pathlib.Path): + from textwrap import dedent + code = dedent(""" + import triton + @triton.jit + def test_kernel(i): + i = i + 1 + """) + temp_file0 = tmp_path / "test_changed_line_numbers_invalidate_cache0.py" + orig_mod = write_and_load_module(temp_file0, code, 0) + orig_cache_key = orig_mod.test_kernel.cache_key + + temp_file1 = tmp_path / "test_changed_line_numbers_invalidate_cache1.py" + updated_mod = write_and_load_module(temp_file1, code, 1) + updated_cache_key = updated_mod.test_kernel.cache_key + assert orig_cache_key != updated_cache_key + + +def test_reuse(device, fresh_triton_cache): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + x = torch.empty(1, dtype=torch.int32, device=device) + for i in range(10): + kernel[(1, )](x, 1, BLOCK=1024) + assert counter == 1 + + +@pytest.mark.parametrize('mode', ['enable', 'disable', 'disable_on_alignment']) +def test_specialize(mode, device, fresh_triton_cache): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + x = torch.empty(1, dtype=torch.int32, device=device) + function = {'enable': kernel, 'disable': kernel_nospec, 'disable_on_alignment': kernel_nospec_on_alignment}[mode] + target = {'enable': 3, 'disable': 1, 'disable_on_alignment': 2}[mode] + for i in [1, 2, 4, 8, 16, 32]: + function[(1, )](x, i, BLOCK=512) + assert counter == target + + +def test_annotation(device): + + @triton.jit + def kernel(X, i: tl.int32): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device=device) + + device = getattr(torch, device).current_device() + kernel[(1, )](x, 1) + kernel[(1, )](x, 8) + kernel[(1, )](x, 16) + kernel[(1, )](x, 17) + assert len(kernel.device_caches[device][0]) == 3 + + +GLOBAL_DEFAULT_ARG = 1 + + +def test_kernel_default_arg(device): + global GLOBAL_DEFAULT_ARG + + @triton.jit + def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device=device) + kernel[(1, )](x) + assert x == torch.ones_like(x) + + # Changing the global variable should not change the default argument in + # `kernel`. That value gets set at the time the function is declared. + GLOBAL_DEFAULT_ARG = 2 + kernel[(1, )](x) + assert x == torch.ones_like(x) + + device = getattr(torch, device).current_device() + assert len(kernel.device_caches[device][0]) == 1 + + +GLOBAL_VAR = tl.constexpr(1) + + +def test_kernel_global_var_change(device): + global GLOBAL_VAR + + @triton.jit + def kernel(X): + tl.store(X, GLOBAL_VAR) + + x = torch.empty(1, dtype=torch.int32, device=device) + kernel[(1, )](x) + assert x == torch.ones_like(x) + + GLOBAL_VAR = 2 + with pytest.raises(RuntimeError) as e: + kernel[(1, )](x) + + assert "global variable" in str(e.value).lower() + + +GLOBAL = 42 # noqa + + +def test_local_shadows_global(): + global GLOBAL + + @triton.jit + def kernel(): + _, GLOBAL = 0, 0 # noqa + a = GLOBAL # noqa + + # No error because the `GLOBAL` we're modifying is not the same `GLOBAL` as + # inside the kernel. + GLOBAL = 42 + kernel[(1, )]() + GLOBAL = 43 + kernel[(1, )]() + + +CONSTEXPR_GLOBAL = tl.constexpr(42) + + +def test_local_does_not_shadow_global(): + global CONSTEXPR_GLOBAL + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + _, CONSTEXPR_GLOBAL = 0, 0 # noqa + + CONSTEXPR_GLOBAL = tl.constexpr(42) + kernel[(1, )]() + CONSTEXPR_GLOBAL = tl.constexpr(43) + + # Error because the `CONSTEXPR_GLOBAL` we're modifying is the same + # `CONSTEXPR_GLOBAL` that's read inside `kernel`. (Alternatively, we could + # make this kernel an error altogether, as it is if it's a pure Python + # function -- the fact that we store to `CONSTEXPR_GLOBAL` inside the kernel + # makes the first read a read of the local variable, which doesn't exist + # yet.) + with pytest.raises(RuntimeError): + kernel[(1, )]() + + +CONFLICTING_GLOBAL = tl.constexpr(0) + + +@triton.jit +def conflicting_global_inner(): + a = CONFLICTING_GLOBAL # noqa + + +def test_conflicting_global_in_inner_function(): + global CONFLICTING_GLOBAL + + @triton.jit + def kernel1(): + a = CONFLICTING_GLOBAL # noqa + conflicting_global_inner() + + @triton.jit + def kernel2(): + a = CONFLICTING_GLOBAL #noqa + conflicting_global_inner() + + kernel1[(1, )]() + + # This should be an error because kernel2 calls conflicting_global_inner, + # which saw a value for 42 for the global when it was first compiled. + CONFLICTING_GLOBAL = 1 + + with pytest.raises(RuntimeError) as e: + kernel2[(1, )]() + + assert "Global variable CONFLICTING_GLOBAL has value" in str(e.value) + + +def test_use_builtin(): + + @triton.jit + def kernel(): + a = float(0) # noqa + + # No error about the value of `float` changing. + kernel[(1, )]() + kernel[(1, )]() + + +def test_no_cache_module_as_global(): + + @triton.jit + def kernel(): + tl.arange(0, 16) + + kernel[(1, )]() + # `tl` should not be entered into used_global_vals + assert not kernel.used_global_vals + + +BUILTIN_AS_GLOBAL = tl.int32 + + +def test_cache_builtin_as_global(): + global BUILTIN_AS_GLOBAL + + @triton.jit + def kernel(): + x = BUILTIN_AS_GLOBAL # noqa + + kernel[(1, )]() + + BUILTIN_AS_GLOBAL = tl.int64 + with pytest.raises(RuntimeError) as e: + kernel[(1, )]() + + assert "global variable" in str(e.value).lower() + + +@triton.jit +def no_cache_callable_inner(): + pass + + +def test_no_cache_callable(): + + @triton.jit + def kernel(): + no_cache_callable_inner() + + kernel[(1, )]() + # `no_cache_callable_inner` should not be entered into used_global_vals. + assert not kernel.used_global_vals + + +def test_jit_warmup_cache(device) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + args = [ + torch.randn(32, dtype=torch.float32, device=device), + torch.randn(32, dtype=torch.float32, device=device), + torch.randn(32, dtype=torch.float32, device=device), + 32, + ] + device = getattr(torch, device).current_device() + assert len(kernel_add.device_caches[device][0]) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.device_caches[device][0]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.device_caches[device][0]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.device_caches[device][0]) == 1 + + +def test_jit_debug(device) -> None: + + @triton.jit + def kernel(tmp): + tl.device_assert(tl.load(tmp) == 1, "tmp == 1") + + device = getattr(torch, device).current_device() + tmp = torch.tensor([1], dtype=torch.int32, device=device) + assert len(kernel.device_caches[device][0]) == 0 + kernel[(1, )](tmp, debug=False) + assert len(kernel.device_caches[device][0]) == 1 + kernel[(1, )](tmp, debug=True) + assert len(kernel.device_caches[device][0]) == 2 + bins = list(kernel.device_caches[device][0].values()) + assert bins[0].asm['ttir'] != bins[1].asm['ttir'] + + +@triton.jit +def add_fn(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + +def test_jit_noinline(device) -> None: + + @triton.jit + def kernel_add_device(a, b, o, N: tl.constexpr): + add_fn(a, b, o, N) + + device = getattr(torch, device).current_device() + assert len(kernel_add_device.device_caches[device][0]) == 0 + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.device_caches[device][0]) == 1 + bins = list(kernel_add_device.device_caches[device][0].values()) + inline_ttir = bins[0].asm['ttir'] + add_fn.noinline = True + add_fn.hash = None + kernel_add_device.hash = None + kernel_add_device.device_caches[device][0].clear() + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.device_caches[device][0]) == 1 + bins = list(kernel_add_device.device_caches[device][0].values()) + noinline_ttir = bins[0].asm['ttir'] + assert inline_ttir != noinline_ttir + + +def test_memory_leak() -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + +def test_preload(device, fresh_triton_cache) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx)) + + device = getattr(torch, device).current_device() + + # get the serialized specialization data + specialization_data = None + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + + JITFunction.cache_hook = cache_hook + pre_compile = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + hash = pre_compile.hash + assert specialization_data is not None + + # clear the cache + shutil.rmtree(fresh_triton_cache) + kernel_add.device_caches[device][0].clear() + + # preload the kernel + kernel_preload = kernel_add.preload(specialization_data) + assert kernel_preload.hash == hash + assert len(kernel_add.device_caches[device][0]) == 1 + + # we should hit the cache and not compile anything + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + final_kernel = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + JITFunction.cache_hook = None + assert counter == 0 + assert len(kernel_add.device_caches[device][0]) == 1 + assert final_kernel.hash == hash + + # test that we can't preload a mismatched kernel + with pytest.raises(RuntimeError, match="Specialization data is for"): + kernel_sub.preload(specialization_data) + + +def test_hooks(device, fresh_triton_cache) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + # get the serialized specialization data + specialization_data = None + is_warmup = False + key = 0 + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + nonlocal is_warmup + is_warmup = kwargs["compile"]["is_warmup"] + nonlocal key + key = kwargs["compile"]["key"] + + specialization_data_compiled = None + + def compiled_hook(*args, **kwargs): + nonlocal specialization_data_compiled + specialization_data_compiled = kwargs["compile"]["specialization_data"] + + JITFunction.cache_hook = cache_hook + JITFunction.compiled_hook = compiled_hook + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + assert specialization_data is not None and specialization_data_compiled == specialization_data + assert is_warmup is True + assert key in kernel_add.device_caches[getattr(torch, device).current_device()][0] + + +@pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=True) +def test_within_2gb(device, fresh_triton_cache) -> None: + default_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") + from triton.backends import backends + + amd_backend = backends["amd"] + try: + use_buffer_ops_opts = ["1", "0"] + # The ranges should only be available when buffer ops are enabled + pointer_ranges = [[(0, )], []] + for use_buffer_ops, pointer_range in zip(use_buffer_ops_opts, pointer_ranges): + # Set AMDGCN_USE_BUFFER_OPS + amd_backend.compiler.use_buffer_ops.cache_clear() + os.environ["AMDGCN_USE_BUFFER_OPS"] = use_buffer_ops + + @triton.jit + def kernel_add(a): + tl.load(a) + + # This is the attribute we want to test + pointer_range_32 = None + + def cache_hook(*args, **kwargs): + nonlocal pointer_range_32 + pointer_range_32 = [ + k for k, v in kwargs["compile"]["configs"][0].items() if ["tt.pointer_range", 32] in v + ] + + JITFunction.cache_hook = cache_hook + # In warmup we assume that the pointer range is 32 bits + kernel_add.warmup(torch.float32, grid=(1, )) + assert pointer_range_32 == pointer_range + # Torch tensor > 2GB + kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device)) + assert len(pointer_range_32) == 0 + # Torch tensor <= 2GB + kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device)) + assert pointer_range_32 == pointer_range + finally: + amd_backend.compiler.use_buffer_ops.cache_clear() + os.environ["AMDGCN_USE_BUFFER_OPS"] = default_buffer_ops + + +# [ TODO ] new case in triton3.3, test failed +# def test_function_arguments(device): + +# @triton.jit +# def func1(): +# return 1 + +# @triton.jit +# def func2(): +# return 2 + +# @triton.jit +# def func3(x): +# return x + +# @triton.jit +# def func4(x, y): +# return x + y + +# @triton.jit +# def kernel(Y, fn: tl.constexpr, fn_args): +# tl.store(Y, fn(*fn_args)) + +# JITFunction.cache_hook = None +# JITFunction.compiled_hook = None +# y = torch.zeros((5, ), dtype=torch.int32, device=device) +# kernel[(1, )](y[0], func1, tuple()) +# kernel[(1, )](y[1], func2, tuple()) +# kernel[(1, )](y[2], func3, (3, )) +# kernel[(1, )](y[3], func4, (3, 4)) +# kernel[(1, )](y[4], func1, tuple()) +# assert len(kernel.device_caches[0][0]) == 4 +# assert y.tolist() == [1, 2, 3, 7, 1] diff --git a/third_party/enflame/python/test/unit/runtime/test_driver.py b/third_party/enflame/python/test/unit/runtime/test_driver.py new file mode 100644 index 000000000..2baae9bac --- /dev/null +++ b/third_party/enflame/python/test/unit/runtime/test_driver.py @@ -0,0 +1,18 @@ +import sys + +import triton +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton +import torch_gcu + + +def test_is_lazy(): + from importlib import reload + reload(sys.modules["triton.runtime.driver"]) + reload(sys.modules["triton.runtime"]) + mod = sys.modules[triton.runtime.driver.__module__] + assert isinstance(triton.runtime.driver.active, getattr(mod, "LazyProxy")) + assert triton.runtime.driver.active._obj is None + utils = triton.runtime.driver.active.utils # noqa: F841 + assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase")) diff --git a/third_party/enflame/python/test/unit/runtime/test_jit.py b/third_party/enflame/python/test/unit/runtime/test_jit.py new file mode 100644 index 000000000..f569995e0 --- /dev/null +++ b/third_party/enflame/python/test/unit/runtime/test_jit.py @@ -0,0 +1,46 @@ +import itertools +import pytest +import torch +import torch_gcu + +import triton +import triton.language as tl +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + + +def test_pre_call_hooks(device): + + @triton.jit + def add_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + class MyTensor(torch.Tensor): + pass + + def my_hook(*args, **kwargs): + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, MyTensor): + raise Exception("MyTensor is not allowed") + + add_kernel.add_pre_run_hook(my_hook) + + x = torch.randn(4, device=device) + y = MyTensor(x) + out = torch.zeros_like(x) + with pytest.raises(Exception): + add_kernel[(4, )](x, y, out, 4, 4) diff --git a/third_party/enflame/python/test/unit/runtime/test_launch.py b/third_party/enflame/python/test/unit/runtime/test_launch.py new file mode 100644 index 000000000..4e54b2dd3 --- /dev/null +++ b/third_party/enflame/python/test/unit/runtime/test_launch.py @@ -0,0 +1,138 @@ +import gc +# import importlib +# import os +# import sys +# import tempfile +# import textwrap +# import time +import tracemalloc + +import torch +import torch_gcu + +import triton +import triton.language as tl +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + +# from typing import Tuple + + +def test_metadata() -> None: + + used_hook = False + + def _launch_metadata(grid, kernel, args): + ret = dict() + ret["grid"] = grid + ret["value"] = args["x"] + return ret + + def hook(launch_metadata): + nonlocal used_hook + metadata = launch_metadata.get() + assert metadata["grid"] == (1, 3, 2) + assert metadata["value"] == 6 + used_hook = True + + @triton.jit(launch_metadata=_launch_metadata) + def kernel(x): + pass + + # launch kernel + triton.compiler.CompiledKernel.launch_enter_hook = hook + kernel[(1, 3, 2)](6) + triton.compiler.CompiledKernel.launch_enter_hook = None + assert used_hook + + +def test_memory_leak() -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + tracemalloc.start() + try: + inp = torch.randn(10, device='gcu') + out = torch.randn(10, device='gcu') + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + begin, _ = tracemalloc.get_traced_memory() + for _ in range(100): + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + end, _ = tracemalloc.get_traced_memory() + assert end - begin < 30000 + finally: + tracemalloc.stop() + + +# LATENCY_THRESHOLD_US = 46 + +# def test_kernel_launch_latency() -> None: +# def define_kernel(kernel_name: str, num_tensor_args: int) -> str: +# arg_str = ",".join([f"arg{i}: torch.Tensor" for i in range(num_tensor_args)]) +# arg_str += ", n_elements: int, BLOCK_SIZE: tl.constexpr" +# func_str = f""" +# import torch + +# import triton +# import triton.language as tl + +# @triton.jit +# def {kernel_name}({arg_str}): +# pass +# """ +# with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py", delete=False) as temp_file: +# temp_file.write(textwrap.dedent(func_str)) +# temp_file_path = temp_file.name + +# return temp_file_path + +# def import_kernel(file_path, kernel_name): +# directory, filename = os.path.split(file_path) +# module_name, _ = os.path.splitext(filename) +# sys.path.insert(0, directory) + +# module = importlib.import_module(module_name) +# kernel = getattr(module, kernel_name) +# return kernel + +# def empty(*kernel_args: Tuple[torch.Tensor]): +# first_arg = kernel_args[0] +# n_elements = first_arg.numel() +# grid = (triton.cdiv(n_elements, 1024),) +# device = torch.cuda.current_device() +# # Warmup +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# torch.cuda.synchronize() +# # Measure launch overhead at steady state +# num_runs = 1000 +# start_time = time.time() +# for i in range(num_runs): +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# end_time = time.time() +# latency_us = (end_time - start_time) / num_runs * 1e6 + +# assert latency_us < LATENCY_THRESHOLD_US, "Kernel launch time has increased!" + +# num_tensor_args = 40 +# kernel_name = 'empty_kernel' +# file_path = define_kernel(kernel_name, num_tensor_args) +# empty_kernel = import_kernel(file_path, kernel_name) + +# # Initialize random tensors for the empty_kernel +# torch.manual_seed(0) +# size = 1024 +# kernel_args = (torch.rand(size, device='cuda') for i in range(num_tensor_args)) + +# # Run empty, which would run empty_kernel internally +# empty(*kernel_args) diff --git a/third_party/enflame/python/test/unit/runtime/test_subproc.py b/third_party/enflame/python/test/unit/runtime/test_subproc.py new file mode 100644 index 000000000..204a39688 --- /dev/null +++ b/third_party/enflame/python/test/unit/runtime/test_subproc.py @@ -0,0 +1,88 @@ +''' +This file is commented out for release. +The original content is preserved below for reference. + +--- +Original file content: +--- + +import multiprocessing +import os +import shutil + +import torch +import torch_gcu +from torch_gcu import transfer_to_gcu + +import triton +import triton.language as tl +from triton.compiler import ASTSource +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton + +tmpdir = ".tmp" + +target = triton.runtime.driver.active.get_current_target() + + +def reset_tmp_dir(): + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + + +def compile_fn(attrs, capability): + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) + + src = ASTSource( + fn=kernel_sub, + constants={3: 32}, + signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + attrs=attrs, + ) + triton.compile(src=src, target=target) + + +def test_compile_in_subproc() -> None: + major, minor = torch.cuda.get_device_capability(0) + cc = major * 10 + minor + config = triton.backends.compiler.AttrsDescriptor(tuple(range(4)), ()) + + multiprocessing.set_start_method('fork') + proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) + proc.start() + proc.join() + assert proc.exitcode == 0 + + +def compile_fn_dot(attrs, capability): + + @triton.jit + def kernel_dot(Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + tl.store(Z + offs, z) + + src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict()) + triton.compile(src=src, target=target) + + +def test_compile_in_forked_subproc() -> None: + reset_tmp_dir() + major, minor = torch.cuda.get_device_capability(0) + capability = major * 10 + minor + config = triton.backends.compiler.AttrsDescriptor(tuple(range(1)), ()) + + assert multiprocessing.get_start_method() == 'fork' + proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) + proc.start() + proc.join() + assert proc.exitcode == 0 + +''' diff --git a/third_party/enflame/python/test/unit/tools/.coveragerc b/third_party/enflame/python/test/unit/tools/.coveragerc new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/enflame/python/test/unit/tools/test_aot.py b/third_party/enflame/python/test/unit/tools/test_aot.py new file mode 100644 index 000000000..c1fdde563 --- /dev/null +++ b/third_party/enflame/python/test/unit/tools/test_aot.py @@ -0,0 +1,456 @@ +''' +This file is commented out for release. +The original content is preserved below for reference. + +--- +Original file content: +--- + +import glob +import os +import subprocess +import sys +import tempfile + +import numpy as np + +import triton +from triton.backends.compiler import GPUTarget +from triton.backends.nvidia.driver import include_dir, library_dirs +import importlib.util +if importlib.util.find_spec("triton.backends.enflame") is None: + import triton_gcu.triton +import torch_gcu + +kernel_utils_src = """ +import triton + +@triton.jit +def mul(x, y): + return x * y +""" + +kernel_src = """ +import triton +import triton.language as tl +import kernel_utils + +@triton.jit +def kernel(C, A, B, M, N, K, + stride_cm, stride_cn, + stride_am, stride_ak, + stride_bk, stride_bn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + c = kernel_utils.mul(accumulator, accumulator) + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, c) +""" + +test_utils_src = """ +#include +#include +#include +#include +#include +#include "kernel.h" + +static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) { + FILE *file = fopen(filename, "w"); + if (file == NULL) { + printf("Could not open file %s\\n", filename); + return; + } + for (int i = 0; i < size; i++) { + fprintf(file, "%d", buffer[i]); + if (i < size - 1) { + fprintf(file, ","); + } + } + fclose(file); +} + +static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) { + FILE *file = fopen(filename, "r"); + if (file == NULL) { + printf("Could not open file %s\\n", filename); + return; + } + int index = 0; + while (fscanf(file, "%hd,", &buffer[index]) != EOF && index < size) { + index++; + } + fclose(file); +}""" + + +def gen_kernel_library(dir, libname): + c_files = glob.glob(os.path.join(dir, "*.c")) + subprocess.run( + ["gcc"] + c_files + ["-I", include_dir[0], "-c", "-fPIC"], + check=True, + cwd=dir, + ) + o_files = glob.glob(os.path.join(dir, "*.o")) + + command = ["gcc", *o_files, "-shared", "-o", libname] + for lib_dir in library_dirs(): + command.extend(["-L", lib_dir]) + subprocess.run(command, check=True, cwd=dir) + + +def gen_test_bin(dir, M, N, K, exe="test", algo_id=0): + test_src = f""" +int main(int argc, char **argv) {{ + int M = {M}, N = {N}, K = {K}; + + // initialize CUDA handles + CUdevice dev; + CUcontext ctx; + CUstream stream; + CUdeviceptr A, B, C; + CUresult err = 0; + cuInit(0); + cuDeviceGet(&dev, 0); + cuCtxCreate(&ctx, 0, dev); + cuMemAlloc(&A, M * K * 2); + cuMemAlloc(&B, K * N * 2); + cuMemAlloc(&C, M * N * 4); + cuStreamCreate(&stream, 0); + load_matmul_fp16(); + + // initialize input data + int16_t hA[M*K]; + int16_t hB[K*N]; + memset(hA, 0, M*K*2); + memset(hB, 0, K*N*2); + read_csv_to_buffer(argv[1], hA, M*K); + read_csv_to_buffer(argv[2], hB, K*N); + cuMemcpyHtoD(A, hA, M*K*2); + cuMemcpyHtoD(B, hB, K*N*2); + + // launch kernel + cuStreamSynchronize(stream); + CUresult ret; + int algo_id = {algo_id}; + if (algo_id == 0) {{ + ret = matmul_fp16_default(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1); + }} else {{ + ret = matmul_fp16(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1, {algo_id}); + }} + if (ret != 0) fprintf(stderr, "kernel launch failed\\n"); + assert(ret == 0); + + cuStreamSynchronize(stream); + + // read data + int32_t hC[M*N]; + memset(hC, 0, M*N*4); + cuMemcpyDtoH(hC, C, M*N*4); + write_buffer_to_csv(argv[3], hC, M*N); + + // free cuda handles + unload_matmul_fp16(); + cuMemFree(A); + cuMemFree(B); + cuMemFree(C); + cuCtxDestroy(ctx); +}} +""" + src = test_utils_src + test_src + with open(os.path.join(dir, "test.c"), "w") as file: + file.write(src) + + command = ["gcc", "test.c"] + for inc_dir in include_dir: + command.extend(["-I", inc_dir]) + for lib_dir in library_dirs(): + command.extend(["-L", lib_dir]) + command.extend(["-l", "cuda", "-L", dir, "-l", "kernel", "-o", exe]) + subprocess.run(command, check=True, cwd=dir) + + +def write_triton_kernels(dir, src, util_src): + kernel_path = os.path.join(dir, "kernel.py") + with open(kernel_path, "w") as file: + file.write(src) + + kernel_utils_path = os.path.join(dir, "kernel_utils.py") + with open(kernel_utils_path, "w") as file: + file.write(util_src) + + return kernel_path + + +def _compile_kernel(dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path): + compiler_path = os.path.join(triton.tools.__path__[0], "compile.py") + + subprocess.run( + [ + sys.executable, + compiler_path, + "-n", + kernel_name, + "--signature", + signature, + "--out-name", + out_name, + "-o", + out_path, + "-w", + str(num_warps), + "-g", + grid, + kernel_path, + ], + check=True, + cwd=dir, + ) + + +# Edge case kernel with no specialization +def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK): + # compile all desired configs + sig = f"*fp32, *{dtype}, *{dtype}, i32, i32, i32, i32, i32, i32, i32, i32, i32, {BM}, {BN}, {BK}" + name = f"matmul_{dtype}" + grid = f"M/{BM}, N/{BN}, 1" + _compile_kernel( + dir=dir, + signature=sig, + kernel_name="kernel", + out_name=name, + out_path=name, + num_warps=1, + grid=grid, + kernel_path=kernel_path, + ) + + +def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints): + # compile all desired configs + for ha in ha_hb_hints: + for hb in ha_hb_hints: + sig = f"*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}" + name = f"matmul_{dtype}" + grid = f"M/{BM}, N/{BN}, 1" + _compile_kernel( + dir=dir, + signature=sig, + kernel_name="kernel", + out_name=name, + out_path=name, + num_warps=1, + grid=grid, + kernel_path=kernel_path, + ) + + +def link_aot_kernels(dir): + linker_path = os.path.join(triton.tools.__path__[0], "link.py") + + # link all desired configs + h_files = glob.glob(os.path.join(dir, "*.h")) + subprocess.run([sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=dir) + + +def generate_matmul_test_data(dir, M, N, K): + a = np.random.randn(M * K).astype(np.float16).reshape((M, K)) + b = np.random.randn(M * K).astype(np.float16).reshape((K, N)) + a_path = os.path.join(dir, "a.csv") + b_path = os.path.join(dir, "b.csv") + c_path = os.path.join(dir, "c.csv") + for x, path in [(a, a_path), (b, b_path)]: + x.view(np.int16).ravel().tofile(path, sep=",") + return a, b, a_path, b_path, c_path + + +# Test edge case where the provided kernel signature has no specializations +def test_compile_link_matmul_no_specialization(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + dtype = "fp16" + BM, BN, BK = 16, 16, 16 + + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK) + link_aot_kernels(tmp_dir) + + # compile test case + M, N, K = 16, 16, 16 + gen_kernel_library(tmp_dir, "libkernel.so") + gen_test_bin(tmp_dir, M, N, K) + + # initialize test data + a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) + + # run test case + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir) + + # read data and compare against reference + c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) + c_tri = c.reshape((M, N)).view(np.float32) + c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0) + + +def test_compile_link_matmul(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + dtype = "fp16" + BM, BN, BK = 16, 16, 16 + + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]) + link_aot_kernels(tmp_dir) + + # compile test case + M, N, K = 16, 16, 16 + gen_kernel_library(tmp_dir, "libkernel.so") + gen_test_bin(tmp_dir, M, N, K) + + # initialize test data + a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) + + # run test case + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir) + + # read data and compare against reference + c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) + c_tri = c.reshape((M, N)).view(np.float32) + c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0) + + +def test_launcher_has_no_available_kernel(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + dtype = "fp16" + BM, BN, BK = 16, 16, 16 + + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[":1"]) + link_aot_kernels(tmp_dir) + + # compile test case + M, N, K = 16, 16, 16 + gen_kernel_library(tmp_dir, "libkernel.so") + gen_test_bin(tmp_dir, M, N, K) + + # initialize test data + a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) + + # run test case + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + result = subprocess.run( + ["./test", a_path, b_path, c_path], + env=env, + cwd=tmp_dir, + capture_output=True, + text=True, + ) + + # It should fail since the launcher requires all the strides be 1 while they are not. + assert result.returncode == -6 + assert "kernel launch failed" in result.stderr + + +def test_compile_link_autotune_matmul(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + dtype = "fp16" + + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + + tile_sizes = [ + [16, 16, 16], + [32, 32, 16], + [32, 32, 32], + [64, 64, 32], + ] + + for ts in tile_sizes: + BM, BN, BK = ts[0], ts[1], ts[2] + compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]) + + link_aot_kernels(tmp_dir) + + gen_kernel_library(tmp_dir, "libkernel.so") + + # compile test case + M, N, K = 64, 64, 64 + # initialize test data + a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) + c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) + + for algo_id in range(len(tile_sizes)): + # generate and run test case + test_name = f"test_{algo_id}" + gen_test_bin(tmp_dir, M, N, K, exe=test_name, algo_id=algo_id) + + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + subprocess.run( + [f"./{test_name}", a_path, b_path, c_path], + check=True, + cwd=tmp_dir, + env=env, + ) + + # read data and compare against reference + c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) + c_tri = c.reshape((M, N)).view(np.float32) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=1e-4) + + +def test_ttgir_to_ptx(): + src = """ +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { + tt.return + } +} +""" + with tempfile.TemporaryDirectory() as tmp_dir: + kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir") + with open(kernel_path, "w") as fp: + fp.write(src) + k = triton.compile(kernel_path, target=GPUTarget("cuda", 80, 32)) + ptx = k.asm["ptx"] + assert ".target sm_80" in ptx + assert ".address_size 64" in ptx + +''' diff --git a/third_party/enflame/triton_enflame.cc b/third_party/enflame/triton_enflame.cc new file mode 100644 index 000000000..269f7423f --- /dev/null +++ b/third_party/enflame/triton_enflame.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace py = pybind11; + +// Empty initialization function to resolve linking issues +void init_triton_enflame(py::module &&m) { + // Temporarily provide empty implementation to resolve linking issues + // TODO: Add GCU-related MLIR dialects and transformation passes later + + m.doc() = "Enflame GCU backend for Triton"; + + // Can add some basic utility functions + m.def("get_gcu_arch", []() { + return "gcu300"; // Default architecture + }); + + m.def("is_gcu_available", []() { + // Simple availability check + return true; + }); + + // Empty dialect loading function + m.def("load_dialects", [](py::object context) { + // TODO: Load GCU-related MLIR dialects + // Temporarily empty implementation, using py::object to avoid MLIR + // dependency + }); +} diff --git a/third_party/enflame/triton_gcu/CMakeLists.txt b/third_party/enflame/triton_gcu/CMakeLists.txt new file mode 100644 index 000000000..7215d22ad --- /dev/null +++ b/third_party/enflame/triton_gcu/CMakeLists.txt @@ -0,0 +1,18 @@ + +set(TRITON_TAG_GCU300 d43f9ac82d21bba6ed909a49bcbfd16745ab24ed) # branch gcu_triton_3.3.1 + +include(ExternalProject) + +set(FETCHCONTENT_QUIET OFF) +if(CMAKE_TOOLCHAIN_FILE_FULL_PATH) + message(STATUS "[triton-${arch}]: CMAKE_TOOLCHAIN_FILE_FULL_PATH: ${CMAKE_TOOLCHAIN_FILE_FULL_PATH}") + set(TOOLCHAIN_PARM "-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE_FULL_PATH}") +endif() + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +foreach(arch gcu300) + # get upper case of arch + string(TOUPPER ${arch} ARCH) + add_subdirectory(triton_${arch}) +endforeach() diff --git a/third_party/enflame/triton_gcu/RegisterGCUDialects.h b/third_party/enflame/triton_gcu/RegisterGCUDialects.h new file mode 100644 index 000000000..3755e7fab --- /dev/null +++ b/third_party/enflame/triton_gcu/RegisterGCUDialects.h @@ -0,0 +1,69 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_REGISTER_GCU_DIALECT_H +#define GCU_REGISTER_GCU_DIALECT_H + +#include "Conversion/Passes.h" +#include "Dialect/GCU/IR/Dialect.h" +#include "Dialect/MathExt/IR/MathExt.h" +#include "Dialect/MemrefExt/IR/MemrefExt.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Transforms/Passes.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/InitAllPasses.h" +#include "mlir/InitAllTranslations.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h" + +namespace mlir { +namespace gcu { + +inline void registerGCUDialects(mlir::DialectRegistry ®istry) { + // dialects + mlir::registerAllDialects(registry); + registry.insert(); + + // extensions + mlir::registerAllExtensions(registry); + + // passes + mlir::registerAllPasses(); + + mlir::registerTritonGCUConversionPasses(); + mlir::registerTritonGCUTransformsPasses(); + + // translation + mlir::registerAllTranslations(); + + mlir::registerBuiltinDialectTranslation(registry); + mlir::registerGPUDialectTranslation(registry); + mlir::registerLLVMDialectTranslation(registry); + mlir::registerX86VectorDialectTranslation(registry); + + // Extension required for translating GPU offloading Ops. + gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); +} + +} // namespace gcu +} // namespace mlir + +#endif // GCU_REGISTER_GCU_DIALECT_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/CMakeLists.txt new file mode 100644 index 000000000..7da90f47e --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/CMakeLists.txt @@ -0,0 +1,4 @@ +# For Triton +# ###################################################### + +include(triton_gcu) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/AxisInfoEx.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/AxisInfoEx.h new file mode 100644 index 000000000..1818210c2 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/AxisInfoEx.h @@ -0,0 +1,278 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRITON_ANALYSIS_AXISINFO_EX_H +#define TRITON_ANALYSIS_AXISINFO_EX_H +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/raw_ostream.h" + +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +//===--------------------------------------------------------------------===// +// The main logical is modified from +// include/triton/Analysis/AxisInfo.h in the triton repo. +//===--------------------------------------------------------------------===// + +namespace mlir { +namespace triton { +namespace gcu { +//===----------------------------------------------------------------------===// +// AxisInfoEx +//===----------------------------------------------------------------------===// + +/// This lattice value represents known information on the axes of a lattice. +class AxisInfoEx { +public: + typedef SmallVector DimVectorT; + typedef ArrayRef DimRefT; + constexpr static int64_t kInitDivisibility = 1; + constexpr static int64_t kDefaultContinueSize = 1; + constexpr static int64_t kDefaultContinualInterval = -1; + +public: + AxisInfoEx() : AxisInfoEx({}, {}, {}) {} + + AxisInfoEx(DimRefT pDivisibility, DimRefT pContinualSize, + DimRefT pContinualInterval) + : AxisInfoEx(pDivisibility, pContinualSize, pContinualInterval, + std::nullopt) {} + + AxisInfoEx(DimRefT pDivisibility, DimRefT pContinualSize, + DimRefT pContinualInterval, + std::optional pConstantValue) { + divisibility.append(pDivisibility.begin(), pDivisibility.end()); + continualSize.append(pContinualSize.begin(), pContinualSize.end()); + continualInterval.append(pContinualInterval.begin(), + pContinualInterval.end()); + constantValue = pConstantValue; + rank = continualSize.size(); + assert(divisibility.size() == static_cast(rank)); + assert(continualSize.size() == static_cast(rank)); + assert(continualInterval.size() == static_cast(rank)); + } + + int64_t getContiguity(size_t dim) const { + int64_t dimContinualSize = 1; + if (continualInterval[dim] == 1) + dimContinualSize = continualSize[dim]; + return dimContinualSize; + } + + int64_t getConstancy(size_t dim) const { + int64_t dimConstancySize = 1; + if (continualInterval[dim] == 0) + dimConstancySize = continualSize[dim]; + return dimConstancySize; + } + + int64_t getDivisibility(size_t dim) const { return divisibility[dim]; } + const DimVectorT &getDivisibility() const { return divisibility; } + + int64_t getContinualSize(size_t dim) const { return continualSize[dim]; } + const DimVectorT &getContinualSize() const { return continualSize; } + + int64_t getContinualInterval(size_t dim) const { + return continualInterval[dim]; + } + const DimVectorT &getContinualInterval() const { return continualInterval; } + + int getRank() const { return rank; } + + std::optional getConstantValue() const { return constantValue; } + + bool isContinualLowDim(ArrayRef shape, int dim) const { + return getContiguity(dim) == shape[dim]; + } + + bool isConstantDim(ArrayRef shape, int dim) const { + return getConstancy(dim) == shape[dim]; + } + + bool isContinualDim(ArrayRef shape, int dim) const { + return getContinualSize(dim) == shape[dim]; + } + + bool isStridedContinualDim(ArrayRef shape, int dim) const { + if (continualInterval.size() < 1 || continualSize.size() < 1) + return false; + if (shape[dim] == 1) + return true; + return getContinualInterval(dim) == 1 && getContinualSize(dim) > 1 && + shape[dim] % getContiguity(dim) == 0; + } + + bool isStridedConstantDim(ArrayRef shape, int dim) const { + if (continualInterval.size() < 1 || continualSize.size() < 1) + return false; + if (shape[dim] == 1) + return true; + return getContinualInterval(dim) == 0 && getContinualSize(dim) > 1 && + shape[dim] % getConstancy(dim) == 0; + } + + bool operator==(const AxisInfoEx &other) const { + return divisibility == other.divisibility && + continualSize == other.continualSize && + continualInterval == other.continualInterval && + constantValue == other.constantValue && rank == other.rank; + } + + template + static void initPessimisticStateFromFunc(int argNumber, T funcOp, int rank, + DimVectorT *contiguity, + DimVectorT *divisibility, + DimVectorT *constancy); + + static AxisInfoEx getPessimisticValueState(Value value); + + // The gcd of both arguments for each dimension + static AxisInfoEx join(const AxisInfoEx &lhs, const AxisInfoEx &rhs); + + void print(raw_ostream &os) const { + auto print = [&](StringRef name, DimVectorT vec) { + os << name << " = ["; + llvm::interleaveComma(vec, os); + os << "]"; + }; + print("divisibility", divisibility); + print(", continualSize", continualSize); + print(", continualInterval", continualInterval); + os << ", constant_value = "; + if (constantValue) + os << *constantValue; + else + os << ""; + } + +private: + // The continual information maps the `d`-th + // dimension to the length of the shortest + // sequence of integers of the same continualInterval along it. + // Suppose we have an array of N elements, + // with a continualSize value C, + // the array can be divided into a list of + // N/C sequences of C subsequence of integers of the same continualInterval. + // For example: + // [10, 11, 12, 13, 18, 19, 20, 21] + // [20, 21, 22, 23, 28, 29, 30, 31] + // Would have continualSize [2, 4] and have continualInterval [10, 1]. + // and + // [12, 16, 20, 24] + // [13, 17, 21, 25] + // [14, 18, 22, 26] + // [15, 19, 23, 27] + // [18, 22, 26, 30] + // [19, 23, 27, 31] + // Would have continualSize [2, 4] and have continualInterval [1, 4]. + DimVectorT continualSize; + DimVectorT continualInterval; + + // the divisibility information maps the `d`-th dimension to + // the largest power of two that divides the first element + // of all the values along it. + // For example, + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has divisibility [1, 2], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27]] + // + // has divisibility [4, 1]. + DimVectorT divisibility; + + // The constant value of the lattice if we can infer it. + std::optional constantValue; + + // Number of dimensions of the lattice. + int rank{}; +}; + +// Module level axis info analysis based on the call graph, assuming that we do +// not have recursive functions. +// +// Since each function will be called multiple times, we need to calculate the +// axis info based on the axis info of all the callers. In the future, we can +// perform optimization using function cloning so that each call site will have +// unique axis info. +using AxisInfoExMapT = DenseMap; +class ModuleAxisInfoExAnalysis : public CallGraph { +public: + explicit ModuleAxisInfoExAnalysis(ModuleOp moduleOp) + : CallGraph(moduleOp) { + SmallVector funcs; + for (auto root : getRoots()) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + funcs.push_back(funcOp); + funcMap.try_emplace(funcOp, AxisInfoExMapT{}); + }); + (void)root; + } + SetVector sortedFuncs(funcs.begin(), funcs.end()); + for (auto funcOp : llvm::reverse(sortedFuncs)) { + initialize(funcOp); + funcOp.walk([&](CallOpInterface callOp) { + auto callee = dyn_cast(callOp.resolveCallable()); + update(callOp, callee); + }); + } + } + + AxisInfoEx *getAxisInfoEx(Value value) { + auto funcOp = + value.getParentRegion()->getParentOfType(); + auto *axisInfoExMap = getFuncData(funcOp); + if (!axisInfoExMap) { + return nullptr; + } + auto it = axisInfoExMap->find(value); + if (it == axisInfoExMap->end()) { + return nullptr; + } + return &(it->second); + } + + unsigned getPtrContiguity(Value ptr); + unsigned getPtrAlignment(Value ptr); + unsigned getMaskAlignment(Value mask); + +private: + void initialize(FunctionOpInterface funcOp); + void update(CallOpInterface callOp, FunctionOpInterface funcOp); +}; + +} // namespace gcu +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/FirstLastUserAnalysis.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/FirstLastUserAnalysis.h new file mode 100644 index 000000000..8c076c80b --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/FirstLastUserAnalysis.h @@ -0,0 +1,75 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_ANALYSIS_FIRSTLASTUSERANALYSIS_H +#define GCU_ANALYSIS_FIRSTLASTUSERANALYSIS_H + +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Visitors.h" +#include "llvm/ADT/DenseMap.h" + +namespace mlir { +namespace triton { +namespace gcu { + +using namespace mlir; + +class FirstLastUserAnalysis { +public: + using OptoOpT = llvm::DenseMap; + + explicit FirstLastUserAnalysis(Operation *op) + : moduleOp(op), dominators(op), postDominators(op) { + start(); + } + + Operation *getLastUserOp(Value value, Region *opRegion); + + Operation *getLastUserOp(Operation *op) const { + if (lastUserMap.count(op) == 0) { + llvm::errs() << "op: " << *op << " has no last user\n"; + llvm::report_fatal_error("No last user found for op"); + } + return lastUserMap.lookup(op); + } + + Operation *getFirstUserOp(Operation *op) const { + if (firstUserMap.count(op) == 0) { + llvm::errs() << "op: " << *op << " has no first user\n"; + llvm::report_fatal_error("No first user found for op"); + } + return firstUserMap.lookup(op); + } + +private: + void start(); + +private: + Operation *moduleOp; + DominanceInfo dominators; + PostDominanceInfo postDominators; + + OptoOpT lastUserMap; + OptoOpT firstUserMap; +}; + +} // namespace gcu +} // namespace triton +} // namespace mlir + +#endif // GCU_ANALYSIS_FIRSTLASTUSERANALYSIS_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/MaskAnalysis.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/MaskAnalysis.h new file mode 100644 index 000000000..cd8e128c9 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/MaskAnalysis.h @@ -0,0 +1,190 @@ + +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_ANALYSIS_MASKANALYSIS_H +#define GCU_ANALYSIS_MASKANALYSIS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +//===--------------------------------------------------------------------===// +// The main code is modified from triton-to-linalg branch in triton repo. +//===--------------------------------------------------------------------===// + +namespace mlir { +namespace triton { +namespace gcu { +// Data structure used to decode the pattern in a mask used for load and store. +// start and end field represent the start and end index of a range (produced +// by make_range, addi, etc.). While multi-dimensional data is possible, we +// assume range comparison can only be done on 1 dimension at a time (and +// results of range comparisons across dimensions can be combined), hence start +// and end are not vectors. dims represents the real access size for load/store +// (instead of the tensor/memref size specified by the IR). scalar is a shortcut +// used when the entire state contains a single scalar value. +// +// The general lifetime of this data structure is roughly: +// 1. A range is created by make_range and optionally operated on by addi w/ +// result of splat, expand_dims, etc. During this phase, either (1) both start +// and end are populated, or (2) scalar is populated. Only one of the dimensions +// (that contains the range) can have dim > 1. +// 2. Result from step 1 is compared with a another MaskState that represents a +// scalar value. The resulting state only has dims populated. +// 3. Optionally, result from step 2 can be broadcasted and anded with other +// results from step 2. The resulting state only has dims populated. +// +// Example of creating 2D mask: +// mask = (rows[:, None] < M) & (cols[None, :] < N) +struct MaskState { + OpFoldResult start; + OpFoldResult end; + SmallVector dims; + OpFoldResult scalar; + + int64_t getRank() const { return dims.size(); } + + bool isEmpty() const { return getRank() == 0 && !scalar && !start && !end; } + + bool isMask() const { return !start && !end && !scalar && dims.size() != 0; } + + void addStateScalar(OpBuilder &builder, Location loc, const MaskState &state, + const OpFoldResult scalar); + + void addStates(OpBuilder &builder, Location loc, const MaskState &lhsState, + const MaskState &rhsState); + + void minStates(OpBuilder &builder, Location loc, const MaskState &lhsState, + const MaskState &rhsState); + + void setStates(OpBuilder &builder, Location loc, const MaskState &srcState); +}; + +class MaskAnalysis { +public: + // Recursively parse a Value; call the corresponding function based on the + // defining operation and Value type + static void parse(OpBuilder &builder, Location loc, Value operand, + MaskState &state, + llvm::SmallDenseMap &knownMasks); + + static void + parseBlockArgument(OpBuilder &builder, Location loc, BlockArgument blockArg, + MaskState &state, + llvm::SmallDenseMap &knownMasks); + + static void parseIntScalar(OpBuilder &builder, Location loc, Value scalar, + MaskState &state, + llvm::SmallDenseMap &knownMasks); + + // Operand is the result of a constant + // Get the value of the constant and assign it to scalar. + static void parseConstant(OpBuilder &builder, Location loc, + arith::ConstantOp constOp, MaskState &state, + llvm::SmallDenseMap &knownMasks); + + // Operand is the result of addi + // One and only one of the operands should be a scalar. Increment both start + // and end, dims remains unchanged, and scalar is empty. + static void parseAdd(OpBuilder &builder, Location loc, arith::AddIOp addOp, + MaskState &state, + llvm::SmallDenseMap &knownMasks); + + // Operand is the result of andi + // Each of the result state dims is smaller of the two operands' dims. + // Insert instruction if needed to get new dims. + static void parseAnd(OpBuilder &builder, Location loc, arith::AndIOp andOp, + MaskState &state, + llvm::SmallDenseMap &knownMasks); + + // Operand is the result of cmpi + // Assume only of the dimensions have size > 1. Only support slt for now. + // For that dimension, calculate this new dim as: dim = min(end, value) - + // start + static void parseCmp(OpBuilder &builder, Location loc, arith::CmpIOp cmpOp, + MaskState &state, + llvm::SmallDenseMap &knownMasks); + + // Operand is the result of make_range + // Set start and end accordingly; step size must be 1. + static void parseMakeRange(OpBuilder &builder, Location loc, + triton::MakeRangeOp rangeOp, MaskState &state, + llvm::SmallDenseMap &knownMasks); + + // Operand is the result of broadcast + // Change dims only; assume only applies to tensors. + static void parseBroadcast(OpBuilder &builder, Location loc, + triton::BroadcastOp broadcastOp, MaskState &state, + llvm::SmallDenseMap &knownMasks); + + // Operand is the result of splat + // Assume only applies to scalar. start and end are left empty; scalar will + // be assigned, and dims will be updated. + static void parseSplat(OpBuilder &builder, Location loc, + triton::SplatOp splatOp, MaskState &state, + llvm::SmallDenseMap &knownMasks); + + // Operand is the result of expand_dims + // Insert additional dims; start and end do not change and correspond to the + // dimension that contains the range. + static void + parseExpandDims(OpBuilder &builder, Location loc, + triton::ExpandDimsOp expandDimsOp, MaskState &state, + llvm::SmallDenseMap &knownMasks); + + // Operand is the result of DotC + static void parseDot(OpBuilder &builder, Location loc, triton::DotOp dotOp, + MaskState &state, + llvm::SmallDenseMap &knownMasks); + // Operand is the result of remsi + // One and only one of the operands should be a scalar. Increment both start + // and end, dims remains unchanged, and scalar is empty. + static void parseRemsi(OpBuilder &builder, Location loc, + arith::RemSIOp RemSIOp, MaskState &state, + llvm::SmallDenseMap &knownMasks); + // Operand is the result of SelectOp + // only for bypass + static void parseSelect(OpBuilder &builder, Location loc, + arith::SelectOp SelectOp, MaskState &state, + llvm::SmallDenseMap &knownMasks); + // Operand is the result of ReduceOp + // only for bypass + static void parseReduce(OpBuilder &builder, Location loc, + triton::ReduceOp ReduceOp, MaskState &state, + llvm::SmallDenseMap &knownMasks); + // Operand is the result of LoadOp + // only for bypass + static void parseLoad(OpBuilder &builder, Location loc, triton::LoadOp LoadOp, + MaskState &state, + llvm::SmallDenseMap &knownMasks); + // Operand is the result of ExtSIOp + // only for bypass + static void parseExtsi(OpBuilder &builder, Location loc, + arith::ExtSIOp ExtSIOp, MaskState &state, + llvm::SmallDenseMap &knownMasks); + // Operand is the result of ExtUIOp + // only for bypass + static void parseExtui(OpBuilder &builder, Location loc, + arith::ExtUIOp ExtUIOp, MaskState &state, + llvm::SmallDenseMap &knownMasks); +}; + +} // namespace gcu +} // namespace triton +} // namespace mlir + +#endif // GCU_ANALYSIS_MASKANALYSIS_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/OpFoldResultUtils.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/OpFoldResultUtils.h new file mode 100644 index 000000000..8209850b0 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/OpFoldResultUtils.h @@ -0,0 +1,86 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_ANALYSIS_OPFOLDRESULTUTILS_H +#define GCU_ANALYSIS_OPFOLDRESULTUTILS_H + +#include + +#include "mlir/IR/Location.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVector.h" + +//===--------------------------------------------------------------------===// +// The main code is modified from triton-to-linalg branch in triton repo. +//===--------------------------------------------------------------------===// + +namespace mlir { +namespace triton { +namespace gcu { + +// Return integer if ofr is an IntegerAttr. Note that this function differs +// from getConstantIntValue, which returns an integer if ofr is the constant +// result of an operation too. +std::optional getIntAttr(const OpFoldResult ofr); + +Value getValue(OpBuilder &builder, Location loc, const OpFoldResult ofr); + +llvm::SmallVector getValues(OpBuilder &builder, Location loc, + const llvm::SmallVector &ofr); + +std::optional getScalarValue(OpBuilder &builder, Location loc, Value v); + +// Process addition of two OFRs. If both OFRs are Integer Attributes, result +// is an Integer Attribute. Otherwise, insert the arith.addi instruction if +// needed and use its result Value. +OpFoldResult addOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs); + +// Produce result = lhs - rhs. If both OFRs are Integer Attributes, result +// is an Integer Attribute. Otherwise, insert the arith.addi instruction if +// needed and use its result Value. +OpFoldResult subOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs); + +// Process multiplication of two OFRs. If both OFRs are Integer Attributes, +// result is an Integer Attribute. Otherwise, insert the arith.muli +// instruction if needed and use its result Value. +OpFoldResult mulOFRValue(OpBuilder &builder, Location loc, + const OpFoldResult lhs, const Value rhs); + +OpFoldResult minOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs); + +OpFoldResult maxOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs); + +// Produce result = lhs % rhs. If both OFRs are Integer Attributes, result +// is an Integer Attribute. Otherwise, insert the arith.remsi instruction if +// needed and use its result Value. +OpFoldResult remOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs); + +// Produce result = lhs / rhs. If both OFRs are Integer Attributes, result +// is an Integer Attribute. Otherwise, insert the arith.divsi instruction if +// needed and use its result Value. +OpFoldResult divOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs); + +} // namespace gcu +} // namespace triton +} // namespace mlir + +#endif // GCU_ANALYSIS_OPFOLDRESULTUTILS_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/PtrAnalysis.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/PtrAnalysis.h new file mode 100644 index 000000000..af0cc55fa --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Analysis/PtrAnalysis.h @@ -0,0 +1,347 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_ANALYSIS_PTRANALYSIS_H +#define GCU_ANALYSIS_PTRANALYSIS_H + +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +//===--------------------------------------------------------------------===// +// The main code is modified from triton-to-linalg branch in triton repo. +//===--------------------------------------------------------------------===// + +namespace mlir { + +class ModuleOp; +class PatternRewriter; + +namespace triton { +namespace gcu { + +using llvm::SmallDenseMap; +using llvm::SmallVector; + +struct MaskState; + +struct PtrInfo { + Value base; + llvm::SmallVector shape; + llvm::SmallVector strides; + llvm::SmallVector offsets; + llvm::DenseSet broadcastDims; +}; + +// Data structure used to decode pointer arithmetic and potentially to be +// translate it into memref. offsets, sizes, and strides are in unit of elements +// in a linearly laid-out memory, which is the same as pointer arithmetic +// operations in Triton language. scalar is a shortcut used when the entire +// state describes a single scalar value. source is the base pointer. +struct PtrState { + llvm::SmallVector offsets; + llvm::SmallVector sizes; + llvm::SmallVector strides; + + Value source; + Value scalar; + + int64_t getRank() const; + bool isEmpty() const; + + // Process addition of two PtrStates + void addState(OpBuilder &builder, Location loc, const PtrState &lhsState, + const PtrState &rhsState); + + // Process multiplication of two PtrStates + void mulState(OpBuilder &builder, Location loc, const PtrState &lhsState, + const PtrState &rhsState); + + // Process division remainder of two PtrStates + void remState(OpBuilder &builder, Location loc, const PtrState &lhsState, + const PtrState &rhsState); + + // Process division of two PtrStates + void divState(OpBuilder &builder, Location loc, const PtrState &lhsState, + const PtrState &rhsState); + + // set state for srcState + void setState(OpBuilder &builder, Location loc, const PtrState &srcState); + + PtrInfo getPtrInfo(OpBuilder &builder, Location loc, const MaskState &mstate); +}; + +class PtrAnalysis { +public: + // Recursively parse a Value; call the corresponding function based on the + // defining operation and Value type + static void visitOperand(PatternRewriter &rewriter, Location loc, + Value operand, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is a block argument + static void + visitBlockArgument(PatternRewriter &rewriter, Location loc, BlockArgument arg, + PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of arith.constant that is a splat + // Main assumptions: + // Source is a constant op that produces a constant dense tensor where all + // elements are the same (i.e.: a constant that is splatted) + // Expected result: + // sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = + // splat value if i == 0, otherwise 0 + static void + visitOperandConstSplat(PatternRewriter &rewriter, Location loc, + arith::ConstantOp op, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of arith.addi. Process both arguments and insert any + // arith.addi instruction as needed. + // Main assumptions: + // Only one of lhsState and rhsState has source field set + // Current PtrState should be empty + // Expected result: + // source = lhsState.source ? lhsState.source : rhsState.source + // sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i]) + // offsets[i] = lhsState.offsets[i] + rhsState.offsets[i] + // strides[i] = lhsState.strides[i] + rhsState.strides[i] + static void visitOperandAdd(PatternRewriter &rewriter, Location loc, + arith::AddIOp addOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of arith.muli. Process both arguments and insert any + // arith.muli instruction as needed. + // Main assumptions: + // Neither lhsState nor rhsState has source field set + // Current PtrState should be empty + // Currently only support one of the operand is a scalar index + // Expected result (scalar and tensorState represent the two operands): + // source = null + // sizes[i] = tensorState.sizes[i] + // offsets[i] = tensorState.offsets[i] * scalar + // strides[i] = tensorState.strides[i] * scalar + static void visitOperandMul(PatternRewriter &rewriter, Location loc, + arith::MulIOp mulOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of arith.remsi. Process both arguments and insert any + // arith.remsi instruction as needed. + // Main assumptions: + // Only one of lhsState and rhsState has source field set + // Current PtrState should be empty + // Expected result: + // source = null + // sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i]) + // offsets[i] = lhsState.offsets[i] % rhsState.offsets[i] + // strides[i] = lhsState.strides[i] (which should match rhsState.strides[i]) + static void visitOperandRem(PatternRewriter &rewriter, Location loc, + arith::RemSIOp remOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of arith.divsi. Process both arguments and insert any + // arith.divsi instruction as needed. + // Main assumptions: + // Only one of lhsState and rhsState has source field set + // Current PtrState should be empty + // Expected result: + // source = null + // sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i]) + // offsets[i] = lhsState.offsets[i] / rhsState.offsets[i] + // strides[i] = lhsState.strides[i] (which should match rhsState.strides[i]) + static void visitOperandDiv(PatternRewriter &rewriter, Location loc, + arith::DivSIOp divOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of arith.select. Process both arguments and insert + // any arith.select instruction as needed. + // Main assumptions: + // Select lhsState or rhsState + // Current PtrState should be empty + // Expected result: + // The resulting state is lhsState or rhsState + static void + visitOperandSelect(PatternRewriter &rewriter, Location loc, + arith::SelectOp selectOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of make_range. + // Main assumptions: + // start, end, and shape are all statically known + // The output of make_range is 1-dimensional + // Does not check validity of inputs (e.g., stride > 0) + // Expected result: + // source = null + // sizes[0] = shape[0] + // offset[0] = start + // strides[0] = ceiling( (end - start) / shape[0] ) + static void + visitOperandMakeRange(PatternRewriter &rewriter, Location loc, + triton::MakeRangeOp rangeOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of expand_dims + // Main assumptions: + // Only 1 dimension changes for each invocation of reshape + // The changed dimension must have size of 1 + // Expected result: + // Insert a dimension of size 1, stride 0, and offset 0 + static void + visitOperandExpandDims(PatternRewriter &rewriter, Location loc, + triton::ExpandDimsOp expandDimsOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of broadcast + // Main assumptions: + // Rank of source and result is the same + // Expected result: + // Update sizes[i] only, no changes to other fields + static void + visitOperandBroadcast(PatternRewriter &rewriter, Location loc, + triton::BroadcastOp broadcastOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of splat + // Main assumptions: + // Source is a scalar value (i.e., an integer or a pointer, not a tensor) + // Expected result: + // sizes[i] reflect the shape of the result, strides[i] = 0, offsets[i] = 0 + // if source is an integer, offset[0] = scalar = source + static void + visitOperandSplat(PatternRewriter &rewriter, Location loc, + triton::SplatOp splatOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of addptr. + // Main assumptions: + // The ptr field should populate the source field + // ptr and offset fields should result in same rank + // Expected result: + // The resulting state for ptr and offset will be added + static void + visitOperandAddptr(PatternRewriter &rewriter, Location loc, + triton::AddPtrOp addptrOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of bitcast + // Main assumptions: + // Rank of source and result is the same + // Expected result: + // The resulting state for all will be added + static void + visitOperandBitcast(PatternRewriter &rewriter, Location loc, + triton::BitcastOp bitcastOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of trans + // Main assumptions: + // Rank of source and result is the same + // Expected result: + // The resulting state for all will be trans + static void + visitOperandTrans(PatternRewriter &rewriter, Location loc, + triton::TransOp transOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of dots + // Main assumptions: + // Rank of source and result is the same + // Expected result: + // The resulting state for all will be same + static void visitOperandDot(PatternRewriter &rewriter, Location loc, + triton::DotOp dotOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of reduce + // Main assumptions: + // Rank of source and result is the same + // Expected result: + // The resulting state for all will be same + static void + visitOperandReduce(PatternRewriter &rewriter, Location loc, + triton::ReduceOp reduceOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of load + // Main assumptions: + // Rank of source and result is the same + // Expected result: + // The resulting state for all will be same + static void visitOperandLoad(PatternRewriter &rewriter, Location loc, + triton::LoadOp loadOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of extsi + // Main assumptions: + // Rank of source and result is the same + // Expected result: + // The resulting state for all will be same + static void + visitOperandExtsi(PatternRewriter &rewriter, Location loc, + arith::ExtSIOp extsiOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // Operand is the result of extui + // Main assumptions: + // Rank of source and result is the same + // Expected result: + // The resulting state for all will be same + static void + visitOperandExtui(PatternRewriter &rewriter, Location loc, + arith::ExtUIOp extuiOp, PtrState &state, + llvm::SmallDenseMap &knownPtrs); + + // bypass ForOp not include ld/st. + static bool byPassForOp(PatternRewriter &rewriter, scf::ForOp op, + const SmallVector &candidateOps); + // Parse the state of ForOp, insert any instruction needed to calculate + // strides and offsets, build PtrState for this operand, and record PtrState + // in knownPtrs. + static LogicalResult rewriteForOp( + PatternRewriter &rewriter, scf::ForOp op, + SmallDenseMap &knownPtrs, + SmallDenseMap &knownMasks, + SmallVector &candidateOps, + SmallDenseMap> &candidateHints); + + // Parse the state of YieldOp, insert any instruction needed to calculate + // strides and offsets, build PtrState for this operand, and record PtrState + // in knownPtrs. + static void rewriteYieldOp(PatternRewriter &rewriter, scf::YieldOp op, + llvm::SmallDenseMap &knownPtrs, + llvm::SmallDenseMap &knownMasks); + + // Parse the iter arg of ForOp, fold away unused ones. + static void foldAwayForOp(PatternRewriter &rewriter, scf::ForOp op, + llvm::SmallDenseMap &knownPtrs); + + // Collect candidate load/store op which could be converted to dma. + static void collectCandidateLoadStoreOps( + ModuleOp &moduleOp, llvm::SmallVector &candidates, + llvm::SmallDenseMap> &candidateOrders); +}; + +} // namespace gcu +} // namespace triton +} // namespace mlir + +#endif // GCU_ANALYSIS_PTRANALYSIS_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/CMakeLists.txt new file mode 100644 index 000000000..bd79ad95a --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Dialect) +add_subdirectory(Conversion) +add_subdirectory(Transforms) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/CMakeLists.txt new file mode 100644 index 000000000..d900bae06 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/CMakeLists.txt @@ -0,0 +1,8 @@ + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGCUConversion) +mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix TritonGCUConversion) +mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix TritonGCUConversion) +add_mlir_doc(Passes TritonGCUConversionPasses_${arch} conversions/ -gen-pass-doc) + +add_public_tablegen_target(MLIRTritonGCUConversionPassIncGen_${arch}) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/Passes.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/Passes.h new file mode 100644 index 000000000..6318955b0 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/Passes.h @@ -0,0 +1,31 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TRITON_GCU_CONVERSION_PASSES_H +#define TRITON_GCU_CONVERSION_PASSES_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +#include "Conversion/TritonToGCU/TritonToGCUPass.h" + +namespace mlir { +/// Generate the code for registering conversion passes. +#define GEN_PASS_REGISTRATION +#include "Conversion/Passes.h.inc" + +} // namespace mlir + +#endif // TRITON_GCU_CONVERSION_PASSES_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/Passes.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/Passes.td new file mode 100644 index 000000000..d62f346ba --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/Passes.td @@ -0,0 +1,154 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_CONVERSION_PASSES +#define GCU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// TritonToGCU +//===----------------------------------------------------------------------===// +def ConvertTritonToGCUPass: Pass<"convert-triton-to-gcu", "gpu::GPUModuleOp"> { + let summary = "Convert triton ir to gcu operations."; + let description = [{ + This pass converts triton ir to gcu operations. + }]; + let options = [ + Option<"arch", "arch", "std::string", + /*default=*/"\"gcu300\"", + "Architecture that these operations will run on">, + Option<"vectorLength", "vector-length", "unsigned", /*default=*/"512u", + "Length of vector in byte">, + Option<"vectorizationMaxLength", "vectorization-max-length", "unsigned", /*default=*/"16384u", + "Maximum vector length in byte for vectorization">, + ]; +} + +//===----------------------------------------------------------------------===// +// TrtionFunsion +//===----------------------------------------------------------------------===// +def GCUTritonFusionPass: Pass<"gcu-triton-fusion", "gpu::GPUModuleOp"> { + let summary = "Fuse triton ir to gcu operations."; + let description = [{ + This pass fuse triton ir to gcu operations. + }]; + let options = [ + Option<"arch", "arch", "std::string", + /*default=*/"\"gcu300\"", + "Architecture that these operations will run on"> + ]; +} + +//===----------------------------------------------------------------------===// +// TritonLoadStoreToGCUDma +//===----------------------------------------------------------------------===// +def ConvertTritonLoadStoreToGCUDmaPass: Pass<"convert-triton-load-store-to-gcu-dma", "gpu::GPUModuleOp"> { + let summary = "Convert triton load/store ir to gcu dma operations."; + let description = [{ + This pass converts triton load/store ir to gcu dma operations. + }]; + let options = [ + Option<"arch", "arch", "std::string", + /*default=*/"\"gcu300\"", + "Architecture that these operations will run on"> + ]; +} + + +//===----------------------------------------------------------------------===// +// TritonGCULayoutOptimize +//===----------------------------------------------------------------------===// + def TritonGCULayoutOptimizePass: Pass<"triton-gcu-data-layout-optimize", "gpu::GPUModuleOp"> { + let summary = "optimize data layout for gcu target"; + let description = [{ + This pass gerneric optimize triton layout for gcu target. + }]; + let options = [ + Option<"arch", "arch", "std::string", + /*default=*/"\"gcu300\"", + "Architecture that these operations will run on"> + ]; +} + +//===----------------------------------------------------------------------===// +// TritonDotLayoutOptimize +//===----------------------------------------------------------------------===// +def TritonGCUDotLayoutOptimizePass: Pass<"triton-gcu-dot-layout-optimize", "gpu::GPUModuleOp"> { + let summary = "optimize dot layout for gcu target"; + let description = [{ + This pass specified optimize triton layout for dot. + }]; + let options = [ + Option<"arch", "arch", "std::string", + /*default=*/"\"gcu300\"", + "Architecture that these operations will run on"> + ]; +} + +//===----------------------------------------------------------------------===// +// FlattenTritonFunc +//===----------------------------------------------------------------------===// +def GCUFlattenTritonFuncPass: Pass<"flatten-triton-func", "gpu::GPUModuleOp"> { + let summary = "flatten func op for triton."; + let description = [{ + This pass flatten func op which is called multiple times. + }]; + let options = [ + ]; +} + +//===----------------------------------------------------------------------===// +// ConvertTensorPointer +//===----------------------------------------------------------------------===// +def ConvertTensorPointerPass: Pass<"convert-tensor-pointer", "gpu::GPUModuleOp"> { + let summary = "Convert triton make_tensor_ptr ir to make_range operations."; + let description = [{ + This pass converts triton make_tensor_ptr ir to make_range operations. + }]; + let options = [ + ]; +} + +//===----------------------------------------------------------------------===// +// TritonGCUPingpongPass +//===----------------------------------------------------------------------===// +def TritonGCUPingpongPass: Pass<"triton-gcu-pingpong", "gpu::GPUModuleOp"> { + let summary = "optimize dot layout for gcu target"; + let description = [{ + This pass async load global to share and maybe only for scorpio. + }]; + let options = [ + Option<"arch", "arch", "std::string", + /*default=*/"\"gcu300\"", + "Architecture that these operations will run on">, + Option<"numStages", "num_stages", "unsigned", /*default=*/"3u", + "number stage of pipeline"> + ]; + } + +//===----------------------------------------------------------------------===// +// TritonGPUtoTritonGCU +//===----------------------------------------------------------------------===// +def TritonGPUToTritonGCUPass: Pass<"triton-gpu-to-triton-gcu", "gpu::GPUModuleOp"> { + let summary = "change some special triton GPU op to gcu "; + let description = [{ + This pass change some special triton GPU op to gcu . + }]; + let options = [ + ]; +} + +#endif // GCU_CONVERSION_PASSES diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/TritonToGCU/TritonToGCUPass.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/TritonToGCU/TritonToGCUPass.h new file mode 100644 index 000000000..c3490187a --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Conversion/TritonToGCU/TritonToGCUPass.h @@ -0,0 +1,36 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_CONVERSION_TRITONTOGCU_TRITONTOGCUPASS_H +#define GCU_CONVERSION_TRITONTOGCU_TRITONTOGCUPASS_H + +namespace mlir { +template class InterfacePass; +class Pass; + +#define GEN_PASS_DECL_CONVERTTRITONTOGCUPASS +#define GEN_PASS_DECL_GCUTRITONFUSIONPASS +#define GEN_PASS_DECL_CONVERTTRITONLOADSTORETOGCUDMAPASS +#define GEN_PASS_DECL_TRITONGCULAYOUTOPTIMIZEPASS +#define GEN_PASS_DECL_TRITONGCUDOTLAYOUTOPTIMIZEPASS +#define GEN_PASS_DECL_GCUFLATTENTRITONFUNCPASS +#define GEN_PASS_DECL_CONVERTTENSORPOINTERPASS +#define GEN_PASS_DECL_TRITONGCUPINGPONGPASS +#define GEN_PASS_DECL_TRITONGPUTOTRITONGCUPASS +#include "mlir/Conversion/Passes.h.inc" + +} // namespace mlir + +#endif // GCU_CONVERSION_TRITONTOGCU_TRITONTOGCUPASS_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/CMakeLists.txt new file mode 100644 index 000000000..354edf74b --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(GCU) +add_subdirectory(MemrefExt) +add_subdirectory(MathExt) +add_subdirectory(TritonGCU) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/CMakeLists.txt new file mode 100644 index 000000000..3ad697783 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/CMakeLists.txt @@ -0,0 +1,25 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS GCUOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(OpsAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(OpsAttributes.cpp.inc -gen-attrdef-defs) +add_mlir_doc(GCUOps GCUOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS GCUDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +add_mlir_doc(GCUDialect GCUDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS GCUTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS GCUInterfaces.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) + +add_public_tablegen_target(GCUTableGen${arch}) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Dialect.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Dialect.h new file mode 100644 index 000000000..f7d30a2c3 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Dialect.h @@ -0,0 +1,49 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_DIALECT_GCU_IR_DIALECT_H +#define GCU_DIALECT_GCU_IR_DIALECT_H + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" + +#include "Dialect/GCU/IR/Dialect.h.inc" +#include "Dialect/GCU/IR/OpsEnums.h.inc" +#include "Dialect/GCU/IR/Traits.h" +#include "Dialect/GCU/IR/Types.h" + +#define GET_ATTRDEF_CLASSES +#include "Dialect/GCU/IR/OpsAttributes.h.inc" +#define GET_OP_CLASSES +#include "Dialect/GCU/IR/Ops.h.inc" + +namespace mlir { +namespace gcu {} // namespace gcu +} // namespace mlir + +#endif // GCU_DIALECT_GCU_IR_DIALECT_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUAttrDefs.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUAttrDefs.td new file mode 100644 index 000000000..3b59b1677 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUAttrDefs.td @@ -0,0 +1,230 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_ATTR_DEFS +#define GCU_ATTR_DEFS + +include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td" +include "mlir/IR/EnumAttr.td" +include "Dialect/GCU/IR/GCUDialect.td" + +class GCU_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + +//===----------------------------------------------------------------------===// +// GCU address space attribute. +//===----------------------------------------------------------------------===// + +class GCU_I32Enum cases> + : I32EnumAttr { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::gcu"; +} +class GCU_I32EnumAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def GCU_AddressSpaceGlobal : I32EnumAttrCase<"Global", 1, "global">; +def GCU_AddressSpaceWorkgroup : I32EnumAttrCase<"Workgroup", 2, "workgroup">; +def GCU_AddressSpacePrivate : I32EnumAttrCase<"Private", 3, "private">; +def GCU_AddressSpaceLocal : I32EnumAttrCase<"Local", 4, "local">; +def GCU_AddressSpaceEnum : GCU_I32Enum< + "AddressSpace", "GCU address space", [ + GCU_AddressSpaceGlobal, + GCU_AddressSpaceWorkgroup, + GCU_AddressSpacePrivate, + GCU_AddressSpaceLocal + ]>; + +def GCU_AddressSpaceAttr : + GCU_I32EnumAttr<"address_space", GCU_AddressSpaceEnum> { +} + +def GCU_MemoryFenceEnum : GCU_I32Enum< + "MFenceType", "GCU memory fence type", [ + I32EnumAttrCase<"Memory", 0, "memory">, + I32EnumAttrCase<"Local", 1, "local">, + I32EnumAttrCase<"Share", 2, "share">, + I32EnumAttrCase<"Device", 3, "device"> + ]>; + +def GCU_MemoryFenceAttr : + GCU_I32EnumAttr<"mfence_type", GCU_MemoryFenceEnum> { +} + + + +//===----------------------------------------------------------------------===// +// GCU architecture attribute. +//===----------------------------------------------------------------------===// + +def GCU_Architecture300 : I32EnumAttrCase<"GCU300", 1, "gcu300">; +def GCU_Architecture400 : I32EnumAttrCase<"GCU400", 2, "gcu400">; +def GCU_Architecture410 : I32EnumAttrCase<"GCU410", 3, "gcu410">; +def GCU_ArchitectureEnum : GCU_I32Enum< + "Architecture", "GCU Architecture", [ + GCU_Architecture300, + GCU_Architecture400, + GCU_Architecture410, + ]>; + +def GCU_ArchAttr : GCU_I32EnumAttr<"gcu_architecture", GCU_ArchitectureEnum> { +} + + +//===----------------------------------------------------------------------===// +// GCU target attribute. +//===----------------------------------------------------------------------===// + +def GCU_TargettAttr : + GCU_Attr<"GCUTarget", "target"> { + let description = [{ + GCU target attribute for controlling compilation of GCU targets. All + parameters decay into default values if not present. + + Examples: + + 1. Target with default values. + ``` + gpu.module @mymodule [#gcu.target] attributes {...} { + ... + } + ``` + + 2. Target with `gcu300` chip and fast math. + ``` + gpu.module @mymodule [#gcu.target] { + ... + } + ``` + }]; + let parameters = (ins + DefaultValuedParameter<"int", "3", "Optimization level to apply.">:$O, + StringRefParameter<"Target triple.", "\"amdgcn-amd-amdhsa\"">:$triple, + StringRefParameter<"Target chip.", "\"gfx900\"">:$chip, + StringRefParameter<"Target arch.", "\"gcu300\"">:$arch, + StringRefParameter<"Target chip features.", "\"\"">:$features, + StringRefParameter<"ABI version.", "\"1\"">:$abi, + OptionalParameter<"DictionaryAttr", "Target specific flags.">:$flags, + OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link + ); + let assemblyFormat = [{ + (`<` struct($O, $triple, $chip, $arch, $features, $abi, $flags, $link)^ `>`)? + }]; + let builders = [ + AttrBuilder<(ins CArg<"int", "3">:$optLevel, + CArg<"StringRef", "\"amdgcn-amd-amdhsa\"">:$triple, + CArg<"StringRef", "\"gfx900\"">:$chip, + CArg<"StringRef", "\"gcu300\"">:$arch, + CArg<"StringRef", "\"\"">:$features, + CArg<"StringRef", "\"1\"">:$abiVersion, + CArg<"DictionaryAttr", "nullptr">:$targetFlags, + CArg<"ArrayAttr", "nullptr">:$linkFiles), [{ + return Base::get($_ctxt, optLevel, triple, chip, arch, features, abiVersion, + targetFlags, linkFiles); + }]> + ]; + let skipDefaultBuilders = 1; + let genVerifyDecl = 1; + let extraClassDeclaration = [{ + bool hasFlag(StringRef flag) const; + }]; + let extraClassDefinition = [{ + bool $cppClass::hasFlag(StringRef flag) const { + if (DictionaryAttr flags = getFlags()) + return flags.get(flag) != nullptr; + return false; + } + }]; +} + +def GCU_ReduceOperationAttr : I64EnumAttr< + "ReduceOperation", "", + [ + I64EnumAttrCase<"SUM", 0, "sum">, + I64EnumAttrCase<"MAXSI", 1, "maxsi">, + I64EnumAttrCase<"MAXUI", 2, "maxui">, + I64EnumAttrCase<"MAXF", 3, "maxf">, + I64EnumAttrCase<"MINSI", 4, "minsi">, + I64EnumAttrCase<"MINUI", 5, "minui">, + I64EnumAttrCase<"MINF", 6, "minf">, + ]> { + let cppNamespace = "::mlir::gcu"; +} + +def GCU_VectorMovSftModeAttr : I64EnumAttr< + "VectorMovSftMode", "", + [ + I64EnumAttrCase<"SHFLB", 0, "shift_left_with_byte">, + I64EnumAttrCase<"SHFRB", 1, "shift_right_with_byte">, + I64EnumAttrCase<"SHFLQW", 2, "shift_left_with_qw">, + I64EnumAttrCase<"SHFRQW", 3, "shift_right_with_qw"> + ]> { + let cppNamespace = "::mlir::gcu"; +} + +// InputPrecision +def GCU_InputPrecisionAttr : I32EnumAttr< + "InputPrecision", "", + [ + I32EnumAttrCase<"TF32", 0, "tf32">, + I32EnumAttrCase<"TF32x3", 1, "tf32x3">, + I32EnumAttrCase<"IEEE", 2, "ieee"> + ]>{ + let cppNamespace = "::mlir::gcu"; +} + +// atomic +def GCU_AtomicRMWAttr : I32EnumAttr< + "RMWOp", "", + [ + I32EnumAttrCase<"AND", 1, "and">, + I32EnumAttrCase<"OR", 2, "or">, + I32EnumAttrCase<"XOR", 3, "xor">, + I32EnumAttrCase<"ADD", 4, "add">, + I32EnumAttrCase<"MAX", 5, "max">, + I32EnumAttrCase<"MIN", 6, "min">, + I32EnumAttrCase<"UMAX", 7, "umax">, + I32EnumAttrCase<"UMIN", 8, "umin">, + I32EnumAttrCase<"XCHG", 9, "exch"> + ]> { + let cppNamespace = "::mlir::gcu"; +} + +def GCU_MemSemanticAttr : I32EnumAttr< + "MemSemantic", "", + [ + I32EnumAttrCase<"RELAXED", 1, "relaxed">, + I32EnumAttrCase<"ACQUIRE", 2, "acquire">, + I32EnumAttrCase<"RELEASE", 3, "release">, + I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">, + ]> { + let cppNamespace = "::mlir::gcu"; +} + +def GCU_MemSyncScopeAttr : I32EnumAttr< + "MemSyncScope", "", + [ + I32EnumAttrCase<"GCU", 1, "gcu">, + I32EnumAttrCase<"CTA", 2, "cta">, + I32EnumAttrCase<"SYSTEM", 3, "sys">, + ]> { + let cppNamespace = "::mlir::gcu"; +} + +#endif // GCU_ATTR_DEFS diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUDialect.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUDialect.td new file mode 100644 index 000000000..26f1fade3 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUDialect.td @@ -0,0 +1,55 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_DIALECT +#define GCU_DIALECT + +include "mlir/IR/OpBase.td" + +def GCU_Dialect : Dialect { + let name = "gcu"; + + let cppNamespace = "::mlir::gcu"; + + let summary = "The GCU IR in MLIR"; + + let description = [{ + GCU Dialect. + }]; + + let dependentDialects = [ + "arith::ArithDialect", + "math::MathDialect", + "scf::SCFDialect", + "cf::ControlFlowDialect", + "memref::MemRefDialect", + "vector::VectorDialect", + "affine::AffineDialect", + "transform::TransformDialect", + "gpu::GPUDialect", + "LLVM::LLVMDialect", + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; + + let hasConstantMaterializer = 1; + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif // GCU_DIALECT diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUInterfaces.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUInterfaces.td new file mode 100644 index 000000000..3e2943e86 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUInterfaces.td @@ -0,0 +1,21 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_INTERFACES +#define GCU_INTERFACES + +include "mlir/IR/OpBase.td" + +#endif // GCU_INTERFACES diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUOps.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUOps.td new file mode 100644 index 000000000..52eeb6e09 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUOps.td @@ -0,0 +1,1093 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_OPS +#define GCU_OPS + +include "Dialect/GCU/IR/GCUDialect.td" +include "Dialect/GCU/IR/GCUAttrDefs.td" +include "Dialect/GCU/IR/GCUInterfaces.td" +include "Dialect/GCU/IR/GCUTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface +include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType + +// +// Op Base +// +class GCU_Op traits = []> : + Op { +} + +// +// AllocDTE Op +// +def GCU_AllocDTEOp : GCU_Op<"alloc_dte", [MemoryEffects<[MemAlloc]>]> { + let summary = "Allocate a DTE context"; + let description = [{ + `gcu.alloc_dte` allocate a DTE context. If addrespace != private, the dte context needs + to be called by `init` explicitly. + }]; + let results = (outs GCU_DTEType:$dte); + let assemblyFormat = "attr-dict `:` type($dte)"; +} + +// +// InitDTE Op +// +def GCU_InitDTEOp : GCU_Op<"init_dte", [MemoryEffectsOpInterface]> { + let summary = "Initialize a DTE context"; + let description = [{ + `gcu.init_dte` initialize a DTE context and only for non-private dte contexts. + }]; + let arguments = (ins GCU_DTEType:$dte); + let assemblyFormat = "$dte attr-dict `:` type($dte)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// ConnectDTE Op +// +def GCU_ConnectDTEOp : GCU_Op<"connect_dte", [MemoryEffectsOpInterface]> { + let summary = "Connect two DTE contexts"; + let description = [{ + `gcu.connect_dte` connect two DTE contexts. + }]; + let arguments = (ins GCU_DTEType:$from, GCU_DTEType:$to); + let assemblyFormat = "$from `,` $to attr-dict `:` type($from) `to` type($to)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// DestroyDTE Op +// +def GCU_DestroyDTEOp : GCU_Op<"destroy_dte", [MemoryEffectsOpInterface]> { + let summary = "Destroy a DTE context"; + let description = [{ + `gcu.destroy_dte` destroy a DTE context and only for non-private dte contexts. + }]; + let arguments = (ins GCU_DTEType:$dte); + let assemblyFormat = "$dte attr-dict `:` type($dte)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// DeallocDTE Op +// +def GCU_DeallocDTEOp : GCU_Op<"dealloc_dte", [MemoryEffects<[MemFree]>]> { + let summary = "Deallocate a DTE context"; + let description = [{ + `gcu.dealloc_dte` deallocate a DTE context. If addrespace != private, the dte context needs + to be called by `destroy` explicitly. + }]; + let arguments = (ins GCU_DTEType:$dte); + let assemblyFormat = "$dte attr-dict `:` type($dte)"; +} + +// +// TriggerDTE Op +// +def GCU_TriggerDTEOp : GCU_Op<"trigger_dte", [MemoryEffectsOpInterface]> { + let summary = "trigger a dte operation"; + let description = [{ + `gcu.trigger_dte` trigger a dte operation. + }]; + let arguments = (ins GCU_DTEType:$dte); + let assemblyFormat = "$dte attr-dict `:` type($dte)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// WaitDTE Op +// +def GCU_WaitDTEOp : GCU_Op<"wait_dte", [MemoryEffectsOpInterface]> { + let summary = "wait a dte operation"; + let description = [{ + `gcu.wait_dte` wait a dte operation. + }]; + let arguments = (ins GCU_DTEType:$dte); + let assemblyFormat = "$dte attr-dict `:` type($dte)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// SetDstAddrOp Op +// +def GCU_SetDstAddrOp : GCU_Op<"set_dst_addr", [MemoryEffectsOpInterface]> { + let summary = "Set destination address for the operation bound to the DTE context."; + let description = [{ + `gcu.set_dst_addr` set destination address for the operation bound to the DTE context.. + }]; + let arguments = (ins GCU_DTEType:$dte, GCU_PtrType:$addr); + let assemblyFormat = "$dte `,` $addr attr-dict `:` type($dte) `,` type($addr)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// SetSrcOffsetOp Op +// +def GCU_SetSrcOffsetOp : GCU_Op<"set_src_offset", [MemoryEffectsOpInterface]> { + let summary = "Set source address offset for the operation bound to the DTE"; + let description = [{ + `gcu.set_src_offset` set source address offset for the operation bound to the DTE. + `dim` is an int number representing the dimension. + `offset ` is an int number representing the offset. + }]; + let arguments = (ins GCU_DTEType:$dte, I32:$dim, I32:$offset); + let assemblyFormat = "$dte `,` $dim `,` $offset attr-dict `:` type($dte)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// SetDstOffsetOp Op +// +def GCU_SetDstOffsetOp : GCU_Op<"set_dst_offset", [MemoryEffectsOpInterface]> { + let summary = " Set destination address offset for the operation bound to the DTE."; + let description = [{ + `gcu.set_dst_offset` set destination address offset for the operation bound to the DTE. + `dim` is an int number representing the dimension. + `offset `is an int number representing the offset. + }]; + let arguments = (ins GCU_DTEType:$dte, I32:$dim, I32:$offset); + let assemblyFormat = "$dte `,` $dim `,` $offset attr-dict `:` type($dte)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Memset Op +// +def GCU_MemsetAsyncOp : GCU_Op<"memset_async", [MemoryEffectsOpInterface]> { + let summary = "Memset"; + let description = [{ + `gcu.memset_async` set a value to a buffer. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, GCU_FpInt:$value); + let assemblyFormat = "$dte `,` $dst `,` $value attr-dict `:` type($dte) `,` type($dst) `,` type($value)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Memcpy Op +// +def GCU_MemcpyAsyncOp : GCU_Op<"memcpy_async", [MemoryEffectsOpInterface]> { + let summary = "Memcpy"; + let description = [{ + `gcu.memcpy_async` copy a data from source to dst. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, AnyMemRef:$src); + let assemblyFormat = "$dte `,` $dst `,` $src attr-dict `:` type($dte) `,` type($dst) `,` type($src)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Slice Op +// +def GCU_SliceAsyncOp : GCU_Op<"slice_async", [MemoryEffectsOpInterface]> { + let summary = "Slice"; + let description = [{ + `gcu.slice_async` slice a data from source to dst. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, AnyMemRef:$src, Variadic:$offsets, AnyType:$default_value); + let assemblyFormat = "$dte `,` $dst `,` $src `[` $offsets `]` `,` $default_value attr-dict `:` type($dte) `,` type($dst) `,` type($src) `,` type($default_value)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// SlicePad Op +// +def GCU_SlicePadAsyncOp : GCU_Op<"slice_pad_async", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Slice and pad"; + let description = [{ + `gcu.slice_pad_async` slice a data from source to dst with pad value. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, AnyMemRef:$src, Variadic:$offsets, Variadic:$slice_shape, AnyType:$pad_value); + let assemblyFormat = "$dte `,` $dst `,` $src `[` $offsets `]` `[` $slice_shape `]` `,` $pad_value attr-dict `:` type($dte) `,` type($dst) `,` type($src) `,` type($pad_value)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Deslice Op +// +def GCU_DesliceAsyncOp : GCU_Op<"deslice_async", [MemoryEffectsOpInterface]> { + let summary = "Deslice"; + let description = [{ + `gcu.deslice_async` deslice a data from source to dst. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, AnyMemRef:$src, Variadic:$offsets); + let assemblyFormat = "$dte `,` $dst `[` $offsets `]` `,` $src attr-dict `:` type($dte) `,` type($dst) `,` type($src)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Slice_Deslice Op +// +def GCU_SliceDesliceAsyncOp : GCU_Op<"slice_deslice_async", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Slice_Deslice"; + let description = [{ + `gcu.slice_deslice_async` slice_deslice a data from source to dst. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, AnyMemRef:$src, Variadic:$offsets, Variadic:$slice_shape, Variadic:$dst_offsets); + let assemblyFormat = "$dte `,` $dst `[` $offsets `]` `[` $slice_shape `]` `[` $dst_offsets `]` `,` $src attr-dict `:` type($dte) `,` type($dst) `,` type($src)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Transpose Op +// +def GCU_TransposeAsyncOp : GCU_Op<"transpose_async", [MemoryEffectsOpInterface]> { + let summary = "Transpose"; + let description = [{ + `gcu.transpose_async` transpose a data from source to dst. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, AnyMemRef:$src, Variadic:$layout); + let assemblyFormat = "$dte `,` $dst `,` $src `[` $layout `]` attr-dict `:` type($dte) `,` type($dst) `,` type($src)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Broadcast Op +// +def GCU_BroadcastAsyncOp : GCU_Op<"broadcast_async", [MemoryEffectsOpInterface]> { + let summary = "Broadcast"; + let description = [{ + `gcu.broadcast_async` broadcast a data from source to dst. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, AnyMemRef:$src); + let assemblyFormat = "$dte `,` $dst `,` $src attr-dict `:` type($dte) `,` type($dst) `,` type($src)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Slice_Broadcast Op +// +def GCU_SliceBroadcastAsyncOp : GCU_Op<"slice_broadcast_async", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Slice and broadcast"; + let description = [{ + `gcu.slice_broadcast_async` slice_broadcast a data from source to dst. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, AnyMemRef:$src, Variadic:$offsets, Variadic:$slice_shape); + let assemblyFormat = "$dte `,` $dst `,` $src `[` $offsets `]` `[` $slice_shape `]` attr-dict `:` type($dte) `,` type($dst) `,` type($src)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Slice_Transpose Op +// +def GCU_SliceTransposeAsyncOp : GCU_Op<"slice_transpose_async", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Slice and transpose"; + let description = [{ + `gcu.slice_transpose_async` slice_transpose a data from source to dst. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, AnyMemRef:$src, Variadic:$offsets, Variadic:$layout, GCU_FpInt:$value); + let assemblyFormat = "$dte `,` $dst `,` $src `[` $offsets `]` `[` $layout `]` `,` $value attr-dict `:` type($dte) `,` type($dst) `,` type($src) `,` type($value)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Transpose_Deslice Op +// +def GCU_TransposeDesliceAsyncOp : GCU_Op<"transpose_deslice_async", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Transpose and delice"; + let description = [{ + `gcu.transpose_deslice_async` transpose_deslice a data from source to dst. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, AnyMemRef:$src, Variadic:$layout, Variadic:$offsets); + let assemblyFormat = "$dte `,` $dst`[` $offsets `]` `,` $src `[` $layout `]` attr-dict `:` type($dte) `,` type($dst) `,` type($src)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Memset_Deslice Op +// +def GCU_MemsetDesliceAsyncOp : GCU_Op<"memset_deslice_async", [MemoryEffectsOpInterface]> { + let summary = "Memset and deslice"; + let description = [{ + `gcu.memset_deslice_async` memset_deslice a data from source to dst. + }]; + let arguments = (ins GCU_DTEType:$dte, AnyMemRef:$dst, AnyMemRef:$src, Variadic:$offsets, GCU_FpInt:$value); + let assemblyFormat = "$dte `,` $dst `[` $offsets `]` `,` $src `,` $value attr-dict `:` type($dte) `,` type($dst) `,` type($src) `,` type($value)"; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + + +// +// DynamicSharedMemory Op +// +def GCU_DynamicSharedMemoryOp : GCU_Op<"dynamic_shared_memory", [Pure]> +{ + let summary = "Get the memref for dynamic shared memory"; + + let description = [{ + }]; + let arguments = (ins); + let results = (outs Arg>:$resultMemref); + let assemblyFormat = [{ attr-dict `:` type($resultMemref) }]; + let hasVerifier = 1; +} + +// +// AllocBarrier Op +// +def GCU_AllocBarrierOp : GCU_Op<"alloc_barrier", [MemoryEffects<[MemAlloc]>]> { + let summary = "Allocate a barrier"; + let description = [{ + `gcu.alloc_barrier` allocate a barrier and the barrier needs to be called by `init` explicitly. + }]; + let results = (outs GCU_BarrierType:$barrier); + let assemblyFormat = "attr-dict `:` type($barrier)"; + let hasVerifier = 1; +} + +// +// InitBarrier Op +// +def GCU_InitBarrierOp : GCU_Op<"init_barrier", [MemoryEffectsOpInterface]> { + let summary = "Initialize a barrier"; + let description = [{ + `gcu.init_barrier` initialize a barrier. + }]; + let arguments = (ins GCU_BarrierType:$barrier, I32:$count); + let assemblyFormat = "$barrier `,` $count attr-dict `:` type($barrier) `,` type($count)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// ArriveAndWaitBarrier Op +// +def GCU_ArriveAndWaitBarrierOp : GCU_Op<"arrive_and_wait_barrier", [MemoryEffectsOpInterface]> { + let summary = "Arrive and wait a barrier"; + let description = [{ + `gcu.arrive_and_wait_barrier` arrive and wait a barrier. + }]; + let arguments = (ins GCU_BarrierType:$barrier); + let assemblyFormat = "$barrier attr-dict `:` type($barrier)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// DestroyBarrier Op +// +def GCU_DestroyBarrierOp : GCU_Op<"destroy_barrier", [MemoryEffectsOpInterface]> { + let summary = "Destroy a barrier"; + let description = [{ + `gcu.destroy_barrier` destroy a barrier. + }]; + let arguments = (ins GCU_BarrierType:$barrier); + let assemblyFormat = "$barrier attr-dict `:` type($barrier)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// DeallocBarrier Op +// +def GCU_DeallocBarrierOp : GCU_Op<"dealloc_barrier", [MemoryEffects<[MemFree]>]> { + let summary = "Deallocate a barrier"; + let description = [{ + `gcu.dealloc_barrier` deallocate a barrier and the barrier needs to be called by `destroy` explicitly. + }]; + let arguments = (ins GCU_BarrierType:$barrier); + let assemblyFormat = "$barrier attr-dict `:` type($barrier)"; +} + +// +// VectorConverter Op +// +def GCU_VectorConvertOp : GCU_Op<"vector_convert", [Pure]> { + let summary = "Convert a vector type"; + let description = [{ + `gcu.vector_convert` convert a vector type. + }]; + let arguments = (ins Variadic:$inputs); + let results = (outs Variadic:$outputs); + let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($outputs) "; + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + +// +// PtrToMemRef Op +// +def GCU_PtrToMemRefOp : GCU_Op<"ptr2memref", [Pure]> +{ + let summary = "Convert a pointer to memref"; + + let description = [{ + }]; + let arguments = (ins GCU_PtrType:$ptr); + let results = (outs Arg>:$resultMemref); + let assemblyFormat = [{ $ptr attr-dict `:` type($ptr) `to` type($resultMemref) }]; + let hasVerifier = 1; +} + +// +// MemRefToPtr Op +// +def GCU_MemRefToPtrOp : GCU_Op<"memref2ptr", [Pure]> +{ + let summary = "Convert a memref to pointer"; + + let description = [{ + }]; + let arguments = (ins AnyMemRef:$memref); + let results = (outs GCU_PtrType:$ptr); + let assemblyFormat = [{ $memref attr-dict `:` type($memref) `to` type($ptr) }]; + let hasVerifier = 1; +} + +// +// PtrToInt Op +// +def GCU_PtrToIntOp : GCU_Op<"ptr2int", [Pure]> +{ + let summary = "Convert a pointer to integer"; + + let description = [{ + }]; + let arguments = (ins GCU_PtrType:$ptr); + let results = (outs I64:$result); + let assemblyFormat = [{ $ptr attr-dict `:` type($ptr) }]; +} + + +// +// IntToPtr Op +// +def GCU_IntToPtrOp : GCU_Op<"int2ptr", [Pure]> +{ + let summary = "Convert an integer to pointer"; + + let description = [{ + }]; + let arguments = (ins I64:$value); + let results = (outs GCU_PtrType:$ptr); + let assemblyFormat = [{ $value attr-dict `:` type($ptr) }]; +} + +// +// GetMemRefOffset Op +// +def GCU_GetMemRefOffsetOp : GCU_Op<"get_memref_offset", [Pure]> +{ + let summary = "Get a memref offset"; + + let description = [{ + }]; + let arguments = (ins AnyMemRef:$memref, Variadic:$offsets); + let results = (outs Index:$storage_offset); + let assemblyFormat = [{ $memref $offsets attr-dict `:` type($memref) }]; +} + +// +// MaterializeInDestination Op +// +def GCU_MaterializeInDestinationOp : GCU_Op<"materialize_in_destination", [MemoryEffectsOpInterface]> +{ + let summary = "Materialize a tensor to memref and reuse dest memref"; + + let description = [{ + }]; + let arguments = (ins AnyTensor:$source, AnyMemRef:$dest); + let assemblyFormat = [{ $source `in` $dest attr-dict `:` type($source) `in` type($dest)}]; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + } + }]; +} + +def GCU_ExternElementwiseOp : GCU_Op<"extern_elementwise_op", [SameVariadicOperandSize]> { + let summary = "gcu libdevice api for gcu400"; + let arguments = (ins Variadic>:$srcs, StrAttr:$symbol); + let results = (outs VectorOfRank<[1]>:$result); + + let assemblyFormat = "$srcs attr-dict `:` type($srcs) `to` type($result) "; +} + +// +// BuiltinElementwiseOp +// +def GCU_BuiltinElementwiseOp : GCU_Op<"builtin_elementwise_op", [MemoryEffectsOpInterface, AttrSizedOperandSegments]> { + let summary = "Builtin elementwise operation"; + let description = [{ + "Call an elementwise function $symbol implemented in builtin" + }]; + let arguments = (ins AnyMemRef:$output, Variadic>:$inputs, StrAttr:$symbol, Variadic:$params); + let assemblyFormat = [{$output `,` $inputs ( `,` $params^ )? attr-dict `:` type($output) `,` type($inputs) ( `,` type($params)^ )?}]; + let builders = [ + OpBuilder<(ins + "mlir::Value":$output, + "mlir::ValueRange":$inputs, + "mlir::StringAttr":$symbol) + , [{ + return build($_builder, $_state, output, inputs, symbol, {});}]>, + OpBuilder<(ins + "mlir::Value":$output, + "mlir::ValueRange":$inputs, + "mlir::ValueRange":$params, + "mlir::StringAttr":$symbol) + , [{ + return build($_builder, $_state, output, inputs, symbol, params);}]> + ]; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + + +// +// MFence Op +// +def GCU_MFenceOp : GCU_Op<"mfence", []> +{ + let summary = "memory fence"; + let arguments = (ins DefaultValuedAttr:$mfence_type); + let description = [{ + }]; + let assemblyFormat = [{ attr-dict }]; +} + +// +// GatherLoad Op +// +def GCU_GatherLoadOp : GCU_Op<"gather_load", [MemoryEffectsOpInterface]> { + let summary = "Gather load"; + let description = [{ + `gcu.gather_load` gather a tensor from source to dst. + }]; + let arguments = (ins GCU_PtrType:$dst, GCU_PtrType:$src, GCU_PtrType:$offsets, GCU_PtrType:$masks, GCU_PtrType:$others, I32:$size); + let assemblyFormat = "$dst `,` $src `,` $offsets `,` $masks `,` $others `,` $size attr-dict `:` type($dst) `,` type($src) `,` type($offsets) `,` type($masks) `,` type($others) `,` type($size)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(3), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(4), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// ScatterStore Op +// +def GCU_ScatterStoreOp : GCU_Op<"scatter_store", [MemoryEffectsOpInterface]> { + let summary = "Scatter store"; + let description = [{ + `gcu.scatter_store` scatter a tensor from source to dst. + }]; + let arguments = (ins GCU_PtrType:$dst, GCU_PtrType:$src, GCU_PtrType:$offsets, GCU_PtrType:$masks, I32:$size); + let assemblyFormat = "$dst `,` $src `,` $offsets `,` $masks `,` $size attr-dict `:` type($dst) `,` type($src) `,` type($offsets) `,` type($masks) `,` type($size)"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(3), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// MatMul Op +// +def GCU_MatMulOp : GCU_Op<"matmul", [MemoryEffectsOpInterface]> { + let summary = "Matrix multiplication with bias"; + let description = [{ + Performs a two dimensional matrix multiplication. + }]; + let arguments = (ins AnyMemRef:$out, AnyMemRef:$lhs, AnyMemRef:$rhs, Optional:$bias, + DefaultValuedAttr:$inputPrecision); + let assemblyFormat = "$out `,` $lhs `,` $rhs (`,` $bias^)? attr-dict `:` type($out) `,` type($lhs) `,` type($rhs) (`,` type($bias)^)?"; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; + let hasVerifier = 1; +} + +// +// Assert Op +// +def GCU_AssertOp : GCU_Op<"assert", []> +{ + let summary = "device assert"; + let description = [{ + gcu device assert. + }]; + let arguments = (ins I1:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line); + let assemblyFormat = [{ + $condition attr-dict `:` type($condition) + }]; +} + +// +// triton debug Op +// +def GCU_TritonBorder : GCU_Op<"border", []> +{ + let summary = "triton border"; + let description = [{ + }]; + let assemblyFormat = [{ attr-dict }]; +} + +// +// Reduce Op +// + +def GCU_ReduceMemRef : MemRefRankOf<[AnyIntOfWidths<[8, 16, 32, 64]>, BF16, F16, F32], [3]>; + +def GCU_ReduceOp : GCU_Op<"reduce", [MemoryEffectsOpInterface]> { + let summary = "Reduction using specified combination algorithm"; + let arguments = (ins GCU_ReduceOperationAttr:$op, GCU_ReduceMemRef:$out, GCU_ReduceMemRef:$in, Optional:$workspace, I32Attr:$axis); + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(3), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + } + }]; + let hasVerifier = 1; +} + +def GCU_VectorMovSftOp : GCU_Op<"vector_movsft", [AllTypesMatch<["vin", "vout"]>]> { + let summary = "vector move shift"; + let arguments = (ins GCU_VectorMovSftModeAttr:$mode, VectorOfLengthAndType<[32], [I32, F32]>:$vin, I32Attr:$unit); + let results = (outs VectorOfLengthAndType<[32], [I32, F32]>:$vout); +} + + +// for performance tool + +// +// BeginClock +// + +def GCU_BeginClockOp : GCU_Op<"begin_clock", []> +{ + let summary = "get clock start counter"; + + let description = [{ + }]; + let results = (outs I64:$clock); + let assemblyFormat = [{ attr-dict `:` type($clock) }]; +} + +// +// EndClock +// + +def GCU_EndClockOp : GCU_Op<"end_clock", []> +{ + let summary = "get clock end counter"; + + let description = [{ + }]; + let results = (outs I64:$clock); + let assemblyFormat = [{ attr-dict `:` type($clock) }]; +} + +def GCU_TarInitOp : GCU_Op<"tar_init", []> { + let summary = "Thread Address Register only for gcu400"; + let arguments = (ins I64:$in); + let results = (outs GCU_TarType:$out); +} + +def GCU_TarLoadOp : GCU_Op<"tar_load", []> { + let summary = "vector load from TAR only for gcu400"; + let arguments = (ins GCU_TarType:$src_addr, GCU_TarType:$stride); + let results = (outs GCU_VectorType:$v, GCU_TarType:$dst_addr); +} + +def GCU_TarStoreOp : GCU_Op<"tar_store", []> { + let summary = "vector store to TAR only for gcu400"; + let arguments = (ins GCU_VectorType:$v, GCU_TarType:$src_addr, GCU_TarType:$stride); + let results = (outs GCU_TarType:$dst_addr); +} + +def GCU_TarScatterOp : GCU_Op<"tar_scatter", []> { + let summary = "vector scatter store to TAR only for gcu400"; + let arguments = (ins GCU_TarType:$src_addr, GCU_VectorType:$v, I32:$num, Optional:$mask); + let results = (outs GCU_TarType:$dst_addr); +} + +def GCU_VectorStepOp : GCU_Op<"vector_step", [Pure, + PredOpTrait<"source operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>]> { + let summary = "a linear sequence of values"; + let arguments = (ins AnyIntOfWidths<[8, 16, 32, 64]>:$start); + let results = (outs GCU_VectorType:$result); +} + +def GCU_MemMapOp : GCU_Op<"mem_map", []> { + let summary = "mmu map memory, only for gcu300"; + let arguments = (ins GCU_PtrType:$ptr, I32:$num); + let results = (outs I32:$addr); +} + +def GCU_MemUnmapOp : GCU_Op<"mem_unmap", []> { + let summary = "mmu unmap memory, only for gcu300"; + let arguments = (ins I32:$addr, I32:$num); +} + +def GCU_AtomicRMWOp : GCU_Op<"atomic_rmw", []> { + let summary = "atomic rmw"; + let description = [{ + load data at $ptr, do $rmw_op with $val, and store result to $ptr. + return old value at $ptr + }]; + let hasVerifier = 1; + + let arguments = (ins + GCU_AtomicRMWAttr:$atomic_rmw_op, + GCU_PtrType:$ptr, + AnyTypeOf<[AnyMemRef, AnyI8, AnyI16, AnyI32, F16, F32, BF16]>:$val, + Optional:$mask, + GCU_MemSemanticAttr:$sem, + GCU_MemSyncScopeAttr:$scope + ); + + let results = (outs AnyTypeOf<[AnyI8, AnyI16, AnyI32, F16, F32, BF16, RankedTensorOf<[AnyI8, AnyI16, AnyI32, F16, F32, BF16]>]>:$result); + + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)? attr-dict `:` + functional-type(operands, $result) + }]; +} + +def GCU_AtomicCASOp : GCU_Op<"atomic_cas", []> { + let summary = "atomic cas"; + let description = [{ + compare $cmp with data $old at location $ptr, + if $old == $cmp, store $val to $ptr, + else store $old to $ptr, + return $old + }]; + let hasVerifier = 1; + + let arguments = (ins + GCU_PtrType:$ptr, + AnyTypeOf<[AnyMemRef, AnyI8, AnyI16, AnyI32, F16, F32, BF16]>:$cmp, + AnyTypeOf<[AnyMemRef, AnyI8, AnyI16, AnyI32, F16, F32, BF16]>:$val, + GCU_MemSemanticAttr:$sem, + GCU_MemSyncScopeAttr:$scope + ); + + let results = (outs AnyTypeOf<[AnyI8, AnyI16, AnyI32, F16, F32, BF16, RankedTensorOf<[AnyI8, AnyI16, AnyI32, F16, F32, BF16]>]>:$result); + + let assemblyFormat = [{ + $sem `,` $scope `,` $ptr `,` $cmp `,` $val attr-dict `:` + functional-type(operands, $result) + }]; +} + +#endif // GCU_OPS diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUTypes.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUTypes.td new file mode 100644 index 000000000..203a2048a --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/GCUTypes.td @@ -0,0 +1,129 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_TYPES +#define GCU_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "Dialect/GCU/IR/GCUDialect.td" +include "mlir/IR/BuiltinTypeInterfaces.td" + +// +// Types +// +class GCUTypeDef traits = []> + : TypeDef { + // Used by dte ctx + let mnemonic = _mnemonic; +} + +// Floating-point Type +def GCU_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F16, BF16, F32, F64], "floating-point">; +def GCU_FloatTensor : TensorOf<[GCU_Float]>; +def GCU_FloatLike : AnyTypeOf<[GCU_Float, GCU_FloatTensor]>; + +// Boolean Type +// GCU_Bool -> I1 +def GCU_BoolTensor : TensorOf<[I1]>; +def GCU_BoolLike : AnyTypeOf<[I1, GCU_BoolTensor]>; + +// Integer Type +def GCU_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; +def GCU_IntTensor : TensorOf<[GCU_Int]>; +def GCU_IntLike : AnyTypeOf<[GCU_Int, GCU_IntTensor]>; + +// I32 Type +// GCU_I32 -> I32 +// GCU_I32Tensor -> I32Tensor +def GCU_I32Like : AnyTypeOf<[I32, I32Tensor]>; + +// I64 Type +// GCU_I64 -> I64 +// GCU_I64Tensor -> I64Tensor +def GCU_I64Like : AnyTypeOf<[I64, I64Tensor]>; + +// Tensor Type +def GCU_FpIntTensor : AnyTypeOf<[GCU_FloatTensor, GCU_IntTensor]>; +def GCU_Tensor : AnyTypeOf<[GCU_FpIntTensor]>; + +// DTE Ctx Type +def GCU_DTEType : GCUTypeDef<"DTE", "dte", [MemRefElementTypeInterface]> { + let summary = "DTE ctx type (`::mlir::gcu::DTEType`) in GCU IR type system"; + let description = [{ + DTE ctx type in GCU IR type system. + }]; + let parameters = (ins "AddressSpaceAttr":$addressSpace); + let assemblyFormat = [{ `<` $addressSpace `>` }]; +} + +// Barrier Type +def GCU_BarrierType : GCUTypeDef<"Barrier", "barrier", [MemRefElementTypeInterface]> { + let summary = "DTE ctx type (`::mlir::gcu::BarrierType`) in GCU IR type system"; + let description = [{ + Barrier type in GCU IR type system. + }]; + let parameters = (ins "AddressSpaceAttr":$addressSpace); + let assemblyFormat = [{ `<` $addressSpace `>` }]; +} + +// Pointer Type +def GCU_PtrType : GCUTypeDef<"Ptr", "ptr", [MemRefElementTypeInterface]> { + let summary = "Pointer type (`::mlir::gcu::PtrType`) in GCU IR type system"; + let description = [{ + Pointer type in GCU IR type system. + }]; + let parameters = (ins "Type":$elementType); + let assemblyFormat = [{ `<` $elementType `>` }]; +} + +// Any Type in IR +def GCU_FpInt : AnyTypeOf<[GCU_Float, GCU_Int]>; +def GCU_Type : AnyTypeOf<[GCU_FloatLike, GCU_IntLike, GCU_FpIntTensor, GCU_DTEType]>; + +// Vector 1d +def GCU_Vector1D : VectorOfRank<[1]>; + + +// todo_AT +// TileDesc Type +def GCU_TileDescType : GCUTypeDef<"TileDesc", "tile_desc", []> { + let summary = "DTE ctx type (`::mlir::gcu::TileDescType`) in GCU IR type system"; + let description = [{ + Tile descriptor type in GCU IR type system. + }]; + let parameters = (ins "RankedTensorType":$desc); + let assemblyFormat = [{ `<` $desc `>` }]; +} + +class VectorOfRankAndLengthAndType allowedRanks, + list allowedLengths, + list allowedTypes> : AllOfType< + [VectorOfRank, VectorOfLength, + VectorOfAnyRankOf], + VectorOfRank.summary # + VectorOfLength.summary # + VectorOfAnyRankOf.summary, + "::mlir::VectorType">; + +def GCU_TarType : VectorOfLengthAndType<[1],[I64]>; + +def GCU_FP8 : AnyTypeOf<[F8E4M3FN, F8E5M2]>; + +def GCU_VectorType : AnyTypeOf<[VectorOfRankAndLengthAndType<[1], [512, 1024, 2048], [I1, I8, GCU_FP8]>, + VectorOfRankAndLengthAndType<[1], [256, 512, 1024], [I16, F16, BF16]>, + VectorOfRankAndLengthAndType<[1], [128, 256, 512], [I32, F32]>, + VectorOfRankAndLengthAndType<[1], [64, 128, 256], [I64]>]>; + +#endif //GCU_TYPES diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Interfaces.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Interfaces.h new file mode 100644 index 000000000..3c6152f3c --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Interfaces.h @@ -0,0 +1,24 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_DIALECT_GCU_IR_INTERFACES_H +#define GCU_DIALECT_GCU_IR_INTERFACES_H + +#include "mlir/IR/OpDefinition.h" + +#define GET_TYPEDEF_CLASSES +#include "Dialect/GCU/IR/AttrInterfaces.h.inc" + +#endif // GCU_DIALECT_GCU_IR_INTERFACES_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Traits.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Traits.h new file mode 100644 index 000000000..442f75fd0 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Traits.h @@ -0,0 +1,29 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_DIALECT_GCU_IR_TRAITS_H +#define GCU_DIALECT_GCU_IR_TRAITS_H + +#include + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace OpTrait {} // namespace OpTrait +} // namespace mlir + +#endif // GCU_DIALECT_GCU_IR_TRAITS_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Types.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Types.h new file mode 100644 index 000000000..0bd6aa40c --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/GCU/IR/Types.h @@ -0,0 +1,32 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_DIALECT_GCU_IR_TYPES_H +#define GCU_DIALECT_GCU_IR_TYPES_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace gcu { +class AddressSpaceAttr; +} // namespace gcu +} // namespace mlir + +#define GET_TYPEDEF_CLASSES +#include "Dialect/GCU/IR/Types.h.inc" + +#endif // GCU_DIALECT_GCU_IR_TYPES_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/CMakeLists.txt new file mode 100644 index 000000000..7c0acdb8e --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/CMakeLists.txt @@ -0,0 +1,21 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS MathExtOps.td) +mlir_tablegen(MathExtOps.h.inc -gen-op-decls) +mlir_tablegen(MathExtOps.cpp.inc -gen-op-defs) +# mlir_tablegen(MathExtOpsEnums.h.inc -gen-enum-decls) +# mlir_tablegen(MathExtOpsEnums.cpp.inc -gen-enum-defs) +# mlir_tablegen(MathExtOpsAttributes.h.inc -gen-attrdef-decls) +# mlir_tablegen(MathExtOpsAttributes.cpp.inc -gen-attrdef-defs) +add_mlir_doc(MathExtOps MathExtOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS MathExtDialect.td) +mlir_tablegen(MathExtDialect.h.inc -gen-dialect-decls) +mlir_tablegen(MathExtDialect.cpp.inc -gen-dialect-defs) +add_mlir_doc(MathExtDialect MathExtDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS MathExtTypes.td) +mlir_tablegen(MathExtTypes.h.inc -gen-typedef-decls) +mlir_tablegen(MathExtTypes.cpp.inc -gen-typedef-defs) + +add_public_tablegen_target(MathExtTableGen) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExt.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExt.h new file mode 100644 index 000000000..6e5d68963 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExt.h @@ -0,0 +1,31 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_DIALECT_MATH_EXT_IR_DIALECT_H +#define GCU_DIALECT_MATH_EXT_IR_DIALECT_H + +#include "mlir/Dialect/Arith/IR/Arith.h" + +#include "Dialect/MathExt/IR/MathExtTypes.h" + +#include "Dialect/MathExt/IR/MathExtDialect.h.inc" +#define GET_OP_CLASSES +#include "Dialect/MathExt/IR/MathExtOps.h.inc" + +namespace mlir { +namespace math_ext {} // namespace math_ext +} // namespace mlir + +#endif // GCU_DIALECT_MATH_EXT_IR_DIALECT_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtDialect.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtDialect.td new file mode 100644 index 000000000..2ab583980 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtDialect.td @@ -0,0 +1,46 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MATH_EXT_DIALECT +#define MATH_EXT_DIALECT + +include "mlir/IR/OpBase.td" + +def MathExt_Dialect : Dialect { + let name = "math_ext"; + + let cppNamespace = "::mlir::math_ext"; + + let summary = "The MathExt IR in MLIR"; + + let description = [{ + MathExt Dialect. + }]; + +// let dependentDialects = [ +// +// ]; + +// let extraClassDeclaration = [{ +// void registerTypes(); +// }]; + + let hasConstantMaterializer = 1; +// let useDefaultTypePrinterParser = 1; +// let useDefaultAttributePrinterParser = 1; +// let usePropertiesForAttributes = 1; +} + +#endif // MATH_EXT_DIALECT diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtOps.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtOps.td new file mode 100644 index 000000000..40a934c3a --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtOps.td @@ -0,0 +1,79 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MATH_EXT_OPS +#define MATH_EXT_OPS + +include "Dialect/MathExt/IR/MathExtDialect.td" +include "Dialect/MathExt/IR/MathExtTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure + +// +// Op Base +// +class MathExt_Op traits = []> : + Op { +} + +// +// Umulhi Op +// +def MathExt_UmulhiOp : MathExt_Op<"umulhi", []> { + let summary = "Umulhi"; + let description = [{ + Performs unsigned multiplication of two I32/I64 inputs and returns the high 32/64 bits of the result as an I32/I64 output. + }]; + let results = (outs MathExt_IntLike:$result); + let arguments = (ins MathExt_IntLike:$lhs, MathExt_IntLike:$rhs); + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + } + }]; + let hasVerifier = 1; +} + +// +// Histogram Op +// +def MathExt_HistogramOp : MathExt_Op<"histogram", [MemoryEffectsOpInterface]> { + let summary = "Histogram"; + let description = [{ + Return the histogram of the input tensor. The number of bins is equal to + the dimension of the output tensor. Each bins has a width of 1 and bins + start at 0. + }]; + let arguments = (ins AnyMemRef:$result, AnyMemRef:$operand); + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + } + }]; + let hasVerifier = 1; +} + +#endif // MATH_EXT_OPS diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtTypes.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtTypes.h new file mode 100644 index 000000000..1aaf71eb9 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtTypes.h @@ -0,0 +1,30 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_DIALECT_MATH_EXT_IR_TYPES_H +#define GCU_DIALECT_MATH_EXT_IR_TYPES_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace math_ext {} // namespace math_ext +} // namespace mlir + +#define GET_TYPEDEF_CLASSES +#include "Dialect/MathExt/IR/MathExtTypes.h.inc" + +#endif // GCU_DIALECT_MATH_EXT_IR_TYPES_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtTypes.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtTypes.td new file mode 100644 index 000000000..db4fd9f1b --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MathExt/IR/MathExtTypes.td @@ -0,0 +1,35 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MATH_EXT_TYPES +#define MATH_EXT_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "Dialect/MathExt/IR/MathExtDialect.td" +include "mlir/IR/BuiltinTypeInterfaces.td" + +// +// Types +// +// I32 Type +// MathExt_I32 -> I32 +// MathExt_I32Tensor -> I32Tensor +def MathExt_I32Like : AnyTypeOf<[I32, I32Tensor]>; + +def MathExt_Int : AnyTypeOf<[I32, I64], "integer">; +def MathExt_IntTensor : TensorOf<[MathExt_Int]>; +def MathExt_IntLike : AnyTypeOf<[MathExt_Int, MathExt_IntTensor]>; + +#endif //MATH_EXT_TYPES diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/CMakeLists.txt new file mode 100644 index 000000000..ee83980bd --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/CMakeLists.txt @@ -0,0 +1,11 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS MemrefExtDialect.td) +mlir_tablegen(MemrefExtDialect.h.inc -gen-dialect-decls) +mlir_tablegen(MemrefExtDialect.cpp.inc -gen-dialect-defs) + +set(LLVM_TARGET_DEFINITIONS MemrefExtOps.td) +mlir_tablegen(MemrefExtOps.h.inc -gen-op-decls) +mlir_tablegen(MemrefExtOps.cpp.inc -gen-op-defs) + +add_public_tablegen_target(MemrefExtTableGen${arch}) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/MemrefExt.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/MemrefExt.h new file mode 100644 index 000000000..32557580a --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/MemrefExt.h @@ -0,0 +1,35 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_DIALECT_MEMREF_EXT_DIALECT_H +#define GCU_DIALECT_MEMREF_EXT_DIALECT_H + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" + +#include "Dialect/MemrefExt/IR/MemrefExtDialect.h.inc" +#define GET_OP_CLASSES +#include "Dialect/MemrefExt/IR/MemrefExtOps.h.inc" +#define GET_ATTRDEF_CLASSES + +namespace mlir { +namespace memref_ext {} // namespace memref_ext +} // namespace mlir + +#endif // GCU_DIALECT_MEMREF_EXT_DIALECT_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/MemrefExtDialect.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/MemrefExtDialect.td new file mode 100644 index 000000000..969f0132b --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/MemrefExtDialect.td @@ -0,0 +1,47 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MEMREF_EXT_DIALECT +#define MEMREF_EXT_DIALECT + +include "mlir/IR/OpBase.td" + +def MemrefExt_Dialect : Dialect { + let name = "memref_ext"; + + let cppNamespace = "::mlir::memref_ext"; + + let summary = "The MemrefExt IR in MLIR"; + + let description = [{ + MemrefExt Dialect. + }]; + + let dependentDialects = [ + "arith::ArithDialect", + "memref::MemRefDialect", + ]; + + // let extraClassDeclaration = [{ + // void registerTypes(); + // }]; + + let hasConstantMaterializer = 1; + // let useDefaultTypePrinterParser = 1; + // let useDefaultAttributePrinterParser = 1; + // let usePropertiesForAttributes = 1; +} + +#endif // MEMREF_EXT_DIALECT diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/MemrefExtOps.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/MemrefExtOps.td new file mode 100644 index 000000000..a6e1116fa --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/MemrefExt/IR/MemrefExtOps.td @@ -0,0 +1,360 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MemrefExt_OPS +#define MemrefExt_OPS + +include "Dialect/MemrefExt/IR/MemrefExtDialect.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface +include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType + +// +// Op Base +// +class MemrefExt_Op traits = []> : + Op { +} +// +// Memset Op (tag version) +// +def GCU_MemsetStartOp : MemrefExt_Op<"memset_start", [MemoryEffectsOpInterface]> { + let summary = "Memset"; + let description = [{ + `memref_ext.memset_start` msmet a value to buffer. + }]; + let arguments = (ins AnyMemRef:$dst, + AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value, + AnyMemRef:$tag_memref, + Variadic:$tag_indices); + + let assemblyFormat = [{$dst `,` $value `,` $tag_memref `[` $tag_indices `]` attr-dict `:` + type($dst) `,` type($value) `,` type($tag_memref)}]; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Slice Op (tag version) +// +def GCU_SliceStartOp : MemrefExt_Op<"slice_start", + [AttrSizedOperandSegments, + MemoryEffectsOpInterface]> { + let summary = "Slice"; + let description = [{ + `memref_ext.slice_start` slice a data from source to dst. + }]; + let arguments = (ins AnyMemRef:$dst, + AnyMemRef:$src, + Variadic:$offsets, + AnyType:$default_value, + AnyMemRef:$tag_memref, + Variadic:$tag_indices); + + let assemblyFormat = [{$dst `,` $src `[` $offsets `]` `,` $default_value `,` $tag_memref `[` $tag_indices `]` attr-dict `:` + type($dst) `,` type($src) `,` type($default_value) `,` type($tag_memref)}]; + let hasVerifier = 1; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(4), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// SlicePad Op (tag version) +// +def GCU_SlicePadStartOp : MemrefExt_Op<"slice_pad_start", + [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Slice"; + let description = [{ + `memref_ext.slice_pad_start` slice a data from source to dst with pad value. + }]; + let arguments = (ins AnyMemRef:$dst, + AnyMemRef:$src, + Variadic:$offsets, + Variadic:$slice_shape, + AnyType:$pad_value, + AnyMemRef:$tag_memref, + Variadic:$tag_indices); + + let assemblyFormat = [{$dst `,` $src `[` $offsets `]` `[` $slice_shape `]` `,` $pad_value `,` $tag_memref `[` $tag_indices `]` attr-dict `:` + type($dst) `,` type($src) `,` type($pad_value) `,` type($tag_memref)}]; + + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(5), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Deslice Op (tag version) +// +def GCU_DesliceStartOp : MemrefExt_Op<"deslice_start", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Deslice"; + let description = [{ + `memref_ext.deslice_start` deslice a data from source to dst. + }]; + let arguments = (ins AnyMemRef:$dst, AnyMemRef:$src, Variadic:$offsets, AnyMemRef:$tag_memref, Variadic:$tag_indices); + let assemblyFormat = [{$dst `,` $src `[` $offsets `]` $tag_memref `[` $tag_indices `]` attr-dict `:` + type($dst) `,` type($src) `,` type($tag_memref)}]; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(3), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Slice_Deslice Op (tag version) +// +def GCU_SliceDesliceStartOp : MemrefExt_Op<"slice_deslice_start", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Slice_Deslice"; + let description = [{ + `memref_ext.slice_deslice_start` deslice a data from source to dst. + }]; + let arguments = (ins AnyMemRef:$dst, AnyMemRef:$src, Variadic:$offsets, Variadic:$slice_shape, Variadic:$dst_offsets, AnyMemRef:$tag_memref, Variadic:$tag_indices); + let assemblyFormat = [{$dst `,` $src `[` $offsets `]` `[` $slice_shape `]` `[` $dst_offsets `]` $tag_memref `[` $tag_indices `]` attr-dict `:` + type($dst) `,` type($src) `,` type($tag_memref)}]; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(5), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Transpose Op (tag version) +// +def GCU_TransposeStartOp : MemrefExt_Op<"transpose_start", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Transpose"; + let description = [{ + `memref_ext.transpose_start` transpose a data from source to dst. + }]; + let arguments = (ins AnyMemRef:$dst, + AnyMemRef:$src, + Variadic:$layout, + AnyMemRef:$tag_memref, + Variadic:$tag_indices); + + let assemblyFormat = [{$dst `,` $src `[` $layout `]` $tag_memref `[` $tag_indices `]` attr-dict `:` + type($dst) `,` type($src) `,` type($tag_memref)}]; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(3), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Broadcast Op (tag version) +// +def GCU_BroadcastStartOp : MemrefExt_Op<"broadcast_start", [MemoryEffectsOpInterface]> { + let summary = "Broadcast"; + let description = [{ + `memref_ext.broadcast_start` broadcast a data from source to dst. + }]; + let arguments = (ins AnyMemRef:$dst, + AnyMemRef:$src, + AnyMemRef:$tag_memref, + Variadic:$tag_indices); + let assemblyFormat = [{$dst `,` $src `,` $tag_memref `[` $tag_indices `]` attr-dict `:` + type($dst) `,` type($src) `,` type($tag_memref)}]; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(2), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Slice_Broadcast Op (tag version) +// +def GCU_SliceBroadcastStartOp : MemrefExt_Op<"slice_broadcast_start", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Slice and broadcast"; + let description = [{ + `memref_ext.slice_broadcast_start` slice_broadcast a data from source to dst. + }]; + let arguments = (ins AnyMemRef:$dst, + AnyMemRef:$src, + Variadic:$offsets, + Variadic:$slice_shape, + AnyMemRef:$tag_memref, + Variadic:$tag_indices); + let assemblyFormat = [{$dst `,` $src `[` $offsets `]` `[` $slice_shape `]` `,` $tag_memref `[` $tag_indices `]` attr-dict `:` + type($dst) `,` type($src) `,` type($tag_memref)}]; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(4), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Slice_Transpose Op (tag version) +// +def GCU_SliceTransposeStartOp : MemrefExt_Op<"slice_transpose_start", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Slice and transpose"; + let description = [{ + `memref_ext.slice_transpose_start` slice_transpose a data from source to dst. + }]; + let arguments = (ins AnyMemRef:$dst, + AnyMemRef:$src, + Variadic:$offsets, + Variadic:$layout, + AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value, + AnyMemRef:$tag_memref, + Variadic:$tag_indices); + let assemblyFormat = [{$dst `,` $src `[` $offsets `]` `[` $layout `]` `,` $value `,` $tag_memref `[` $tag_indices `]` attr-dict `:` + type($dst) `,` type($src) `,` type($value) `,` type($tag_memref)}]; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(5), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Transpose_Deslice Op (tag version) +// +def GCU_TransposeDesliceStartOp : MemrefExt_Op<"transpose_deslice_start", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Transpose and deslice"; + let description = [{ + `memref_ext.transpose_deslice_start` transpose_deslice a data from source to dst. + }]; + let arguments = (ins AnyMemRef:$dst, + AnyMemRef:$src, + Variadic:$layout, + Variadic:$offsets, + AnyMemRef:$tag_memref, + Variadic:$tag_indices); + let assemblyFormat = [{$dst `[` $offsets `]` `,` $src `[` $layout `]` `,` $tag_memref `[` $tag_indices `]` attr-dict `:` + type($dst) `,` type($src) `,` type($tag_memref)}]; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(4), + SideEffects::DefaultResource::get()); + } + }]; +} + +// +// Memset_Deslice Op (tag version) +// +def GCU_MemsetDesliceStartOp : MemrefExt_Op<"memset_deslice_start", [AttrSizedOperandSegments, MemoryEffectsOpInterface]> { + let summary = "Memset and deslice"; + let description = [{ + `memref_ext.memset_deslice_start` memset_deslice a data from source to dst. + }]; + let arguments = (ins AnyMemRef:$dst, + AnyMemRef:$src, + Variadic:$offsets, + AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value, + AnyMemRef:$tag_memref, + Variadic:$tag_indices); + let assemblyFormat = [{$dst `[` $offsets `]` `,` $src `,` $value `,` $tag_memref `[` $tag_indices `]` attr-dict `:` + type($dst) `,` type($src) `,` type($value) `,` type($tag_memref)}]; + let extraClassDeclaration = [{ + void getEffects( + SmallVectorImpl> & + effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(0), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getOperation()->getOpOperand(1), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getOperation()->getOpOperand(4), + SideEffects::DefaultResource::get()); + } + }]; +} + +#endif // MemrefExt_OPS diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/CMakeLists.txt new file mode 100644 index 000000000..e6346cfea --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/CMakeLists.txt @@ -0,0 +1,23 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonGCUDialect.td) +mlir_tablegen(TritonGCUDialect.h.inc -gen-dialect-decls) +mlir_tablegen(TritonGCUDialect.cpp.inc -gen-dialect-defs) + +set(LLVM_TARGET_DEFINITIONS TritonGCUTypes.td) +mlir_tablegen(TritonGCUTypes.h.inc -gen-typedef-decls) +mlir_tablegen(TritonGCUTypes.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS TritonGCUOps.td) +mlir_tablegen(TritonGCUOps.h.inc -gen-op-decls) +mlir_tablegen(TritonGCUOps.cpp.inc -gen-op-defs) +mlir_tablegen(TritonGCUOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(TritonGCUOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(TritonGCUOpsAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonGCUOpsAttributes.cpp.inc -gen-attrdef-defs) + +set(LLVM_TARGET_DEFINITIONS TritonGCUInterfaces.td) +mlir_tablegen(TritonGCUAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(TritonGCUAttrInterfaces.cpp.inc -gen-attr-interface-defs) + +add_public_tablegen_target(TritonGCUTableGen_${arch}) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUAttrDefs.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUAttrDefs.td new file mode 100644 index 000000000..9451fe7ee --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUAttrDefs.td @@ -0,0 +1,28 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TRITON_GCU_ATTR_DEFS +#define TRITON_GCU_ATTR_DEFS + +include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td" +include "mlir/IR/EnumAttr.td" +include "Dialect/TritonGCU/IR/TritonGCUDialect.td" + +class TritonGCU_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + +#endif // TRITON_GCU_ATTR_DEFS diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUDialect.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUDialect.h new file mode 100644 index 000000000..0a81ace33 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUDialect.h @@ -0,0 +1,46 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_DIALECT_TRITON_GCU_DIALECT_H +#define GCU_DIALECT_TRITON_GCU_DIALECT_H + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" + +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h.inc" +#include "Dialect/TritonGCU/IR/TritonGCUOpsEnums.h.inc" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" +#define GET_OP_CLASSES +#include "Dialect/TritonGCU/IR/TritonGCUOps.h.inc" +#define GET_ATTRDEF_CLASSES +#include "Dialect/TritonGCU/IR/TritonGCUOpsAttributes.h.inc" + +namespace mlir { +namespace triton { +namespace gcu {} // namespace gcu +} // namespace triton +} // namespace mlir + +#endif // GCU_DIALECT_TRITON_GCU_DIALECT_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUDialect.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUDialect.td new file mode 100644 index 000000000..f3db79070 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUDialect.td @@ -0,0 +1,31 @@ +#ifndef TRITONGCU_DIALECT +#define TRITONGCU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonGCU_Dialect : Dialect { + let name = "triton_gcu"; + + let cppNamespace = "::mlir::triton::gcu"; + + let summary = "The TritonGCU IR in MLIR"; + + let description = [{ + TritonGCU Dialect. + }]; + + let dependentDialects = [ + "::mlir::arith::ArithDialect", + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; + + let hasConstantMaterializer = 1; + let useDefaultTypePrinterParser = 1; + // let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif // TRITONGCU_DIALECT diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUInterfaces.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUInterfaces.h new file mode 100644 index 000000000..231dbf734 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUInterfaces.h @@ -0,0 +1,24 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TRITON_GCU_DIALECT_GCU_IR_INTERFACES_H +#define TRITON_GCU_DIALECT_GCU_IR_INTERFACES_H + +#include "mlir/IR/OpDefinition.h" + +#define GET_TYPEDEF_CLASSES +#include "Dialect/TritonGCU/IR/TritonGCUAttrInterfaces.h.inc" + +#endif // TRITON_GCU_DIALECT_GCU_IR_INTERFACES_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUInterfaces.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUInterfaces.td new file mode 100644 index 000000000..5a9c4eef4 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUInterfaces.td @@ -0,0 +1,21 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TRITON_GCU_INTERFACES +#define TRITON_GCU_INTERFACES + +include "mlir/IR/OpBase.td" + +#endif // TRITON_GCU_INTERFACES diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUOps.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUOps.td new file mode 100644 index 000000000..33a32ae74 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUOps.td @@ -0,0 +1,402 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TRITONGCU_OPS +#define TRITONGCU_OPS + +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td" +include "Dialect/TritonGCU/IR/TritonGCUDialect.td" +include "Dialect/TritonGCU/IR/TritonGCUAttrDefs.td" +include "Dialect/TritonGCU/IR/TritonGCUInterfaces.td" +include "Dialect/TritonGCU/IR/TritonGCUTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface +include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface +include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/IR/BuiltinTypeInterfaces.td" + + + +// +// Op Base +// +class TTGCU_Op traits = []> : + Op { +} + +// +// Load Op +// +def TTGCU_LoadOp : TTGCU_Op<"load", + [MemoryEffects<[MemRead]>, AttrSizedOperandSegments]> { + let summary = "Load"; + let description = [{ + }]; + let arguments = (ins TTGCU_PtrType:$ptr, + Variadic:$shape, + Variadic:$strides, + Variadic:$offsets, + Optional:$default_value, + DefaultValuedAttr{}">:$order_hint); + let results = (outs AnyStaticShapeTensor:$result); + let assemblyFormat = [{$ptr `,` `[` $shape `]` `,` + `[` $strides `]` `,` `[` $offsets `]` + (`,` $default_value^)? `,` `[` $order_hint `]` + attr-dict `:` type($ptr) + (`,` type($default_value)^)? `->` type($result)}]; + let hasVerifier = 1; +} + +// +// Store Op +// +def TTGCU_StoreOp : TTGCU_Op<"store", + [MemoryEffects<[MemWrite]>, AttrSizedOperandSegments]> { + let summary = "store"; + let description = [{ + }]; + let arguments = (ins AnyStaticShapeTensor:$value, + TTGCU_PtrType:$ptr, + Variadic:$shape, + Variadic:$strides, + Variadic:$offsets, + DefaultValuedAttr{}">:$order_hint); + let assemblyFormat = [{$value `,` $ptr `,` `[` $shape `]` `,` `[` $strides `]` + `,` `[` $offsets `]` `,` `[` $order_hint `]` + attr-dict `:` type($value) `,` type($ptr)}]; + let hasVerifier = 1; +} +// +// Yield Op +// +def TTGCU_YieldOp : TTGCU_Op<"yield", [ + HasParent<"ElementwiseFusionRegionOp">, + Pure, + ReturnLike, + Terminator]> +{ + let summary = "yield value"; + + let description = [{ + }]; + let arguments = (ins Variadic:$operands); + let assemblyFormat = [{ $operands attr-dict `:` type($operands)}]; +} + +// +// PtrToInt Op +// +def GCU_PtrToIntOp : TTGCU_Op<"ptr2int", [Pure]> +{ + let summary = "Convert a pointer to integer"; + + let description = [{ + }]; + let arguments = (ins TTGCU_PtrType:$ptr); + let results = (outs I64:$result); + let assemblyFormat = [{ $ptr attr-dict `:` type($ptr) }]; +} + +// +// IntToPtr Op +// +def TritonGCU_IntToPtrOp : TTGCU_Op<"int2ptr", [Pure]> +{ + let summary = "Convert an integer to pointer"; + + let description = [{ + }]; + let arguments = (ins I64:$value); + let results = (outs TTGCU_PtrType:$ptr); + let assemblyFormat = [{ $value attr-dict `:` type($ptr) }]; +} + +// +// ElementwiseFusionRegionOp +// +def ElementwiseFusionRegionOp : TTGCU_Op<"elementwise_fusion_region", [ + IsolatedFromAbove, + PredOpTrait< + "requires the same shape for ranked tensor operands and results", + CPred<[{ + [&]() { + ArrayRef shape; + return llvm::all_of(getOperandTypes(), + [&](auto type) { + auto operandType = dyn_cast(type); + if (operandType) { + if (shape.empty()) { + shape = operandType.getShape(); + } else { + return shape == operandType.getShape(); + } + } + return true; + }) && + ((!shape.empty() && getNumResults() > 0) + ? cast(getResultTypes()[0]).getShape() == shape + : true); + }() + }]> + >, + AllShapesMatch<["results"]>, + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "Elementwise fusion region"; + let description = [{ + }]; + let arguments = (ins Variadic>:$operands); + let results = (outs Variadic); + let regions = (region AnyRegion:$region); +} + +// +// Assert Op +// +def TritonGCU_AssertOp : TTGCU_Op<"assert", []> +{ + let summary = "device assert"; + let description = [{ + gcu device assert. + }]; + let arguments = (ins I1:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line); + let assemblyFormat = [{ + $condition attr-dict `:` type($condition) + }]; +} + +// +// Allocate shared memory +// +def TritonGCU_ShareAllocOp : TTGCU_Op<"share_alloc", []> { + let summary = "allocate tensor"; + let description = [{ + This operation allocates buffer in shared memory and return a descriptor + containing the address and a view of the buffer. + + Explicitly deallocating a buffer is optional; see local_dealloc. + }]; + let arguments = (ins Optional:$src); + + let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}]; + + let results = (outs TTG_MemDescType:$result); +} + +// +// Deallocate shared memory +// +def TritonGCU_ShareDeallocOp : TTGCU_Op<"share_dealloc", []> { + let summary = "dealloc buffer"; + + let description = [{ + This operation deallocates a buffer explicitly. Using the buffer after this + operation is undefined. + + This operation is optional. If you don't explicitly dealloc a buffer, the + compiler assumes it's deallocated at the first point that post-dominates all + uses of the alloc. + + Because we assume a memdesc is dead at the first point that post-dominates + its uses, ops that wait for an async operation on a memdesc to complete + (such as triton_nvidia_gpu.dot_wait) should also take the memdesc as an + operand. + }]; + + let arguments = (ins TTG_MemDescType:$src); + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}]; +} + + +// +// async-load-global-to-share +// +def TritonGCU_AsyncLoadGlobalToShareOp : TTGCU_Op<"async_load_global_to_share", [MemoryEffects<[MemRead]>, AttrSizedOperandSegments]> { + let summary = "AsyncLoadGlobalToShare"; + let description = [{ + }]; + let arguments = (ins TTGCU_PtrType:$ptr, + Variadic:$shape, + Variadic:$strides, + Variadic:$offsets, + TTG_MemDescType:$dstMem, + Optional:$default_value, + DefaultValuedAttr{}">:$order_hint); + let results = (outs TTG_AsyncToken:$result); + let assemblyFormat = [{$ptr `,` `shape_` `[` $shape `]` `,` + `stride_` `[` $strides `]` `,` `offset_` `[` $offsets `]` `,`$dstMem + (`,` $default_value^)? `,` `[` $order_hint `]` + attr-dict `:` type($ptr)`,`type($dstMem) + (`,` type($default_value)^)? `->` type($result)}]; +} + +def TritonGCU_AsyncWaitOp :TTGCU_Op<"async_wait"> { + let summary = "async wait"; + + let arguments = (ins Variadic:$asyncToken); + let results = (outs TTG_AsyncToken:$retToken); + let assemblyFormat = "$asyncToken attr-dict"; +} + +def TritonGCU_LocalLoadOp : TTGCU_Op<"local_load", []> { + let summary = "Load a buffer from local memory into a distributed tensor"; + + let description = [{ + Load a tensor from the local memory descriptor into a distributed tensor. + }]; + let arguments = (ins TTG_MemDescType:$src, Optional :$token); + + let builders = [ + OpBuilder<(ins "Type":$retType, "Value":$src), + [{ + build($_builder, $_state, retType, src, /*token=*/static_cast(nullptr)); + }]>]; + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}]; + + let results = (outs TT_Tensor:$result); +} + + +def TritonGCU_MatmulOp : TTGCU_Op<"matmul", [Pure]> { + let summary = "matmul"; + + let description = [{ + $d = matrix_multiply($a, $b)$inputPrecision describes how to exercise the TC + when the inputs are f32. It can be one of: tf32, tf32x3, ieee. + tf32: use TC with tf32 ops. + tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp + ieee: don't use TC, implement dot in software. + If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. + }]; + + let arguments = ( + ins + TT_FpIntTensor:$a, + TT_FpIntTensor:$b + ); + + let results = (outs TT_FpIntTensor:$d); + + // attr-dict prints enums as integers. To get inputPrecision printed as a + // string, we need to specify it explicitly. + let assemblyFormat = [{ + $a`,` $b attr-dict `:` + type($a) `*` type($b) `->` type($d) + }]; +} + +def TritonGCU_MaskedLoadOp : TTGCU_Op<"maskedload", [ + MemoryEffects<[MemRead]>, + AttrSizedOperandSegments, + InferTypeOpInterface, + PredOpTrait< + "requires the same shape and encoding for ranked tensor operands", + CPred<[{ + [&](){ + auto resultType = cast(getResult().getType()); + auto encoding = resultType.getEncoding(); + auto shape = resultType.getShape(); + return llvm::all_of(llvm::drop_begin(getOperandTypes(), 1), [&](auto type){ + auto operandType = dyn_cast(type); + return !operandType || (encoding == operandType.getEncoding() && shape == operandType.getShape()); + }); + }() + }]> + >, + PredOpTrait< + "value type matches ptr element type", + CPred<"!getOther() || getPtr().getType().getPointeeType() == getElementTypeOrSelf(getOther())"> + > +]> { + let summary = "Load from a tensor of pointers"; + + let arguments = (ins + TT_Ptr:$ptr, + TT_IntTensor:$offset, + Optional:$mask, + Optional:$other + ); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $ptr `,` $offset (`,` $mask^)? (`,` $other^)? + attr-dict `:` type($ptr) `,` type($offset) (`,` type($mask)^)? (`,` type($other)^)? `->` type($result) + }]; + + let extraClassDeclaration = [{ + static ::llvm::LogicalResult inferReturnTypes( + ::mlir::MLIRContext * context, ::std::optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> & inferredReturnTypes) { + inferredReturnTypes.resize(1); + auto elementTy = cast<::mlir::triton::PointerType>(operands[0].getType()) + .getPointeeType(); + auto rankTensorType = cast<::mlir::RankedTensorType>(operands[1].getType()); + auto rank = rankTensorType.getShape(); + auto encoding = rankTensorType.getEncoding(); + inferredReturnTypes[0] = RankedTensorType::get(rank, elementTy, encoding); + return ::mlir::success(); + } + }]; +} + +def TritonGCU_MaskedStoreOp : TTGCU_Op<"maskedstore", [ + MemoryEffects<[MemWrite]>, + PredOpTrait< + "requires the same shape and encoding for ranked tensor operands", + CPred<[{ + [&](){ + auto rankTensorType = cast(getOffset().getType()); + auto encoding = rankTensorType.getEncoding(); + auto shape = rankTensorType.getShape(); + return llvm::all_of(llvm::drop_begin(getOperandTypes(), 2), [&](auto type){ + auto operandType = cast(type); + return encoding == operandType.getEncoding() && shape == operandType.getShape(); + }); + }() + }]> + >, + PredOpTrait< + "value type matches ptr type", + CPred<"getPtr().getType().getPointeeType() == getElementTypeOrSelf(getValue())"> + >, + OptionalTypesMatchWith<"mask type matches offset type", "offset", "mask", + "getI1SameShape($_self)"> +]> { + let summary = "Store by a tensor of pointers"; + + let arguments = (ins + TT_Ptr:$ptr, + TT_IntTensor:$offset, + TT_Tensor:$value, + Optional:$mask + ); + + let assemblyFormat = [{ + $ptr `,` $offset `,` $value (`,` $mask^)? + attr-dict `:` type($ptr) `,` type($offset) `,` type($value) (`,` type($mask)^)? + }]; +} +#endif // TRITONGCU_OPS diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUTypes.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUTypes.h new file mode 100644 index 000000000..a87d165cb --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUTypes.h @@ -0,0 +1,32 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_DIALECT_TRITON_GCU_IR_TYPES_H +#define GCU_DIALECT_TRITON_GCU_IR_TYPES_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace triton { +namespace gcu {} // namespace gcu +} // namespace triton +} // namespace mlir + +#define GET_TYPEDEF_CLASSES +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h.inc" + +#endif // GCU_DIALECT_TRITON_GCU_IR_TYPES_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUTypes.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUTypes.td new file mode 100644 index 000000000..413969f9d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Dialect/TritonGCU/IR/TritonGCUTypes.td @@ -0,0 +1,42 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TRITON_GCU_TYPES +#define TRITON_GCU_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "Dialect/TritonGCU/IR/TritonGCUDialect.td" +include "mlir/IR/BuiltinTypeInterfaces.td" + +// +// Types +// +class TTGCUTypeDef traits = []> + : TypeDef { + // Used by dte cte + let mnemonic = _mnemonic; +} + +// Pointer Type +def TTGCU_PtrType : TTGCUTypeDef<"Ptr", "ptr", [MemRefElementTypeInterface]> { + let summary = "Pointer type (`::mlir::triton::gcu::PtrType`) in GCU IR type system"; + let description = [{ + Pointer type in GCU IR type system. + }]; + let parameters = (ins "Type":$elementType); + let assemblyFormat = [{ `<` $elementType `>` }]; +} + +#endif //TRITON_GCU_TYPES diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Transforms/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/include/Transforms/CMakeLists.txt new file mode 100644 index 000000000..8483798da --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Transforms/CMakeLists.txt @@ -0,0 +1,8 @@ + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGCUTransforms) +mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix TritonGCUTransforms) +mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix TritonGCUTransforms) +add_mlir_doc(Passes TritonGCUTransformsPasses_${arch} transforms/ -gen-pass-doc) + +add_public_tablegen_target(TritonGCUTransformsPassIncGen_${arch}) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Transforms/Passes.h b/third_party/enflame/triton_gcu/triton_gcu300/include/Transforms/Passes.h new file mode 100644 index 000000000..5f11a6672 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Transforms/Passes.h @@ -0,0 +1,34 @@ +/** + * Copyright 2025-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TRITON_GCU_TRANSFORMS_PASSES_H +#define TRITON_GCU_TRANSFORMS_PASSES_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +namespace mlir { +template class InterfacePass; +class Pass; + +#define GEN_PASS_DECL_GCU64TYPEVERIFIERPASS +#define GEN_PASS_DECL_GCUCOMBINEOPS + +/// Generate the code for registering conversion passes. +#define GEN_PASS_REGISTRATION +#include "Transforms/Passes.h.inc" +} // namespace mlir + +#endif // TRITON_GCU_TRANSFORMS_PASSES_H diff --git a/third_party/enflame/triton_gcu/triton_gcu300/include/Transforms/Passes.td b/third_party/enflame/triton_gcu/triton_gcu300/include/Transforms/Passes.td new file mode 100644 index 000000000..4329082e5 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/include/Transforms/Passes.td @@ -0,0 +1,51 @@ +/** + * Copyright 2025-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GCU_TRANSFORMS_PASSES +#define GCU_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" +include "triton/Dialect/Triton/IR/TritonOps.td" + +def GCU64TypeVerifierPass: Pass<"gcu64-type-verifier", "gpu::GPUModuleOp"> { + let summary = "Verify that all types are GCU64 types."; + let description = "Verify that no 64bit type on func arguments."; + let dependentDialects = [ + "gpu::GPUDialect", + "triton::TritonDialect", + ]; + let options = [ + Option<"arch", "arch", "std::string", + /*default=*/"\"gcu300\"", + "Architecture that these operations will run on">, + Option<"test_mode", "test-mode", "bool", /*default=*/"false", + "test mode for filecheck"> + ]; +} + +def GCUCombineOps : Pass<"gcu-combine-ops", "gpu::GPUModuleOp"> { + let summary = "combine ops"; + let description = [{ + This pass aims to optimize the specific patterns. + }]; + + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "triton::TritonDialect" + ]; +} + +#endif // GCU_TRANSFORMS_PASSES diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/AxisInfoEx.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/AxisInfoEx.cpp new file mode 100644 index 000000000..e29379315 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/AxisInfoEx.cpp @@ -0,0 +1,1446 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Analysis/AxisInfoEx.h" +#include +#include +#include + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define DEBUG_TYPE "axis-info-ex" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace gcu { +namespace { + +int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) { + // Base Case + if (a == 0) { + *x = 0; + *y = 1; + return b; + } + int64_t x1, y1; // To store results of recursive call + int64_t gcd = gcdImpl(b % a, a, &x1, &y1); + // Update x and y using results of + // recursive call + *x = y1 - (b / a) * x1; + *y = x1; + return gcd; +} + +int64_t gcd(int64_t a, int64_t b) { + if (a == 0) + return b; + if (b == 0) + return a; + int64_t x, y; + return gcdImpl(a, b, &x, &y); +} + +constexpr int log2Int(int64_t num) { + return (num > 1) ? 1 + log2Int(num / 2) : 0; +} + +// If lhs * rhs overflows, return max value possible value for the type +int64_t multiplyDivisor(int64_t lhs, int64_t rhs) { + int64_t maxDivisor = highestPowOf2Divisor(0); + if (lhs > maxDivisor / rhs) + return maxDivisor; + return lhs * rhs; +} + +class AxisInfoExVisitor { +public: + AxisInfoExVisitor() = default; + virtual ~AxisInfoExVisitor() = default; + + static bool isContiguousDim(const AxisInfoEx &info, ArrayRef shape, + int dim) { + return info.getContiguity(dim) == shape[dim]; + } + + static bool isConstantDim(const AxisInfoEx &info, ArrayRef shape, + int dim) { + return info.getConstancy(dim) == shape[dim]; + } + + virtual AxisInfoEx + getAxisInfoEx(Operation *op, + ArrayRef *> operands) = 0; + + virtual bool match(Operation *op) = 0; +}; + +// Base class for all operations +template +class AxisInfoExVisitorImpl : public AxisInfoExVisitor { +public: + using AxisInfoExVisitor::AxisInfoExVisitor; + + AxisInfoEx getAxisInfoEx( + Operation *op, + ArrayRef *> operands) final { + return getAxisInfoEx(cast(op), operands); + } + + bool match(Operation *op) final { return isa(op); } + + virtual AxisInfoEx + getAxisInfoEx(OpTy op, + ArrayRef *> operands) = 0; +}; + +// Binary operations +template +class BinaryOpVisitorImpl : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + assert(operands.size() == 2 && "Expected two operands"); + AxisInfoEx::DimVectorT divisibility; + AxisInfoEx::DimVectorT continualSize; + AxisInfoEx::DimVectorT continualInterval; + auto constantValue = getConstantValue(op, lhsInfo, rhsInfo); + for (auto i = 0; i < rank; ++i) { + if (constantValue.has_value()) { + divisibility.push_back( + highestPowOf2Divisor(constantValue.value())); + continualSize.push_back( + std::max(lhsInfo.getContinualSize(i), rhsInfo.getContinualSize(i))); + continualInterval.push_back(0); + } else { + divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, i)); + continualSize.push_back(getContinualSize(op, lhsInfo, rhsInfo, i)); + continualInterval.push_back( + getContinualInterval(op, lhsInfo, rhsInfo, i)); + } + } + return AxisInfoEx(divisibility, continualSize, continualInterval, + constantValue); + } + +protected: + virtual int64_t getDivisibility(OpTy op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) { + return 1; + } + + virtual int64_t getContinualSize(OpTy op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) { + return 1; + } + + virtual int64_t getContinualInterval(OpTy op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) { + return 1; + } + + virtual std::optional + getConstantValue(OpTy op, const AxisInfoEx &lhs, const AxisInfoEx &rhs) { + return std::nullopt; + } +}; + +class AxisInfoExVisitorList { +public: + template > + void append() { + (visitors.emplace_back(std::make_unique()), ...); + } + + AxisInfoEx apply(Operation *op, + ArrayRef *> operands) { + for (auto &visitor : visitors) + if (visitor->match(op)) + return visitor->getAxisInfoEx(op, operands); + return AxisInfoEx(); + } + +private: + std::vector> visitors; +}; + +class AxisInfoExAnalysis : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +private: + AxisInfoExVisitorList visitors; + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged(lattice, + lattice->join(AxisInfoEx::getPessimisticValueState( + lattice->getAnchor()))); + } + + void visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, + ArrayRef *> argLattices, + unsigned firstIndex) override { + if (auto forOp = dyn_cast(op)) { + visitForOpInductionVar(forOp, argLattices); + } else { + setAllToEntryStates(argLattices.take_front(firstIndex)); + setAllToEntryStates(argLattices.drop_front( + firstIndex + successor.getSuccessorInputs().size())); + } + } + +public: + explicit AxisInfoExAnalysis(DataFlowSolver &solver); + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + using FuncAxisInfoMapT = DenseMap; + + llvm::LogicalResult + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; + void + visitForOpInductionVar(scf::ForOp op, + ArrayRef *> argLattices); +}; + +template +class CastOpAxisInfoExVisitor final : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + OpTy op, + ArrayRef *> operands) override { + return operands[0]->getValue(); + } +}; + +class MakeRangeOpAxisInfoExVisitor final + : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + triton::MakeRangeOp op, + ArrayRef *> operands) override { + auto start = op.getStart(); + auto end = op.getEnd(); + return AxisInfoEx(/*divisibility=*/{highestPowOf2Divisor(start)}, + /*continualSize=*/{end - start}, + /*continualInterval=*/{1}); + } +}; + +template +class ConstantOpAxisInfoExVisitor final : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + OpTy op, + ArrayRef *> operands) override { + auto intAttr = dyn_cast(op.getValue()); + auto boolAttr = dyn_cast(op.getValue()); + if (intAttr || boolAttr) { + int64_t value{}; + if (intAttr) + value = intAttr.getValue().getZExtValue(); + else + value = boolAttr.getValue() ? 1 : 0; + return AxisInfoEx(/*divisibility=*/{highestPowOf2Divisor(value)}, + /*continualSize=*/{AxisInfoEx::kDefaultContinueSize}, + /*continualInterval=*/{0}, + /*knownConstantValue=*/{value}); + } + + auto splatAttr = dyn_cast(op.getValue()); + if (splatAttr && splatAttr.getElementType().isIntOrIndex()) { + int64_t value = splatAttr.template getSplatValue().getZExtValue(); + TensorType ty = cast(splatAttr.getType()); + return AxisInfoEx( + /*divisibility=*/ + AxisInfoEx::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)), + /*continualSize=*/ + AxisInfoEx::DimVectorT(ty.getShape().begin(), ty.getShape().end()), + /*continualInterval=*/ + AxisInfoEx::DimVectorT(ty.getRank(), 0), + /*knownConstantValue=*/{value}); + } + return AxisInfoEx(); + } +}; + +template +class AddSubOpAxisInfoExVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getDivisibility(OpTy op, const AxisInfoEx &lhs, const AxisInfoEx &rhs, + int dim) override { + // lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs) + // rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs) + // lhs + rhs = k * d_lhs + p * d_rhs = (k * d_lhs + p * d_rhs) * + // gcd(d_lhs, d_rhs) + auto rhsDivisibility = rhs.getDivisibility(dim); + return gcd(lhs.getDivisibility(dim), rhsDivisibility); + } + + int64_t getContinualSize(OpTy op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) override { + return gcd(lhs.getContinualSize(dim), rhs.getContinualSize(dim)); + } + + int64_t getContinualInterval(OpTy op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) override { + if (lhs.getContinualInterval(dim) == + AxisInfoEx::kDefaultContinualInterval || + rhs.getContinualInterval(dim) == AxisInfoEx::kDefaultContinualInterval) + return AxisInfoEx::kDefaultContinualInterval; + return std::abs( + applyOp(lhs.getContinualInterval(dim), rhs.getContinualInterval(dim))); + } + + std::optional getConstantValue(OpTy op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs) override { + if (!lhs.getConstantValue().has_value() || + !rhs.getConstantValue().has_value()) { + return std::nullopt; + } + + return {applyOp(lhs.getConstantValue().value(), + rhs.getConstantValue().value())}; + } + +private: + static int64_t applyOp(int64_t lhs, int64_t rhs) { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v); + if constexpr (std::is_same_v) { + return lhs - rhs; + } + return lhs + rhs; + } +}; + +class MulIOpAxisInfoExVisitor final + : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getDivisibility(arith::MulIOp op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) override { + auto lhsDivisibility = lhs.getDivisibility(dim); + auto rhsDivisibility = rhs.getDivisibility(dim); + return multiplyDivisor(lhsDivisibility, rhsDivisibility); + } + + int64_t getContinualSize(arith::MulIOp op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) override { + return std::max(gcd(lhs.getConstancy(dim), rhs.getContinualSize(dim)), + gcd(lhs.getContinualSize(dim), rhs.getConstancy(dim))); + } + + int64_t getContinualInterval(arith::MulIOp op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) override { + if (lhs.getContinualInterval(dim) == + AxisInfoEx::kDefaultContinualInterval || + rhs.getContinualInterval(dim) == AxisInfoEx::kDefaultContinualInterval) + return AxisInfoEx::kDefaultContinualInterval; + + // lhs * cst + auto lhsStrideValue = + rhs.getConstantValue().has_value() + ? lhs.getContinualInterval(dim) * rhs.getConstantValue().value() + : AxisInfoEx::kDefaultContinualInterval; + // cst * rhs + auto rhsStrideValue = + lhs.getConstantValue().has_value() + ? rhs.getContinualInterval(dim) * lhs.getConstantValue().value() + : AxisInfoEx::kDefaultContinualInterval; + return std::max(lhsStrideValue, rhsStrideValue); + } + + std::optional getConstantValue(arith::MulIOp op, + const AxisInfoEx &lhs, + const AxisInfoEx &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() * rhs.getConstantValue().value()}; + return std::nullopt; + } +}; + +template +class DivOpAxisInfoExVisitor final : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + OpTy op, + ArrayRef *> operands) override { + assert(operands.size() == 2 && "Expected two operands"); + auto resTy = dyn_cast(op.getResult().getType()); + if (!resTy) + return AxisInfoEx{}; + + auto shape = resTy.getShape(); + short rank = resTy.getRank(); + auto &lhs = operands[0]->getValue(); + auto &rhs = operands[1]->getValue(); + + AxisInfoEx::DimVectorT divisibility, continualSize, continualInterval; + std::optional constantValue; + for (short i = 0; i < rank; ++i) { + if ((rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) || + (lhs.getConstantValue().has_value() && + lhs.getConstantValue().value() == 0)) { + // Case1: lhs / 1 or 0 / rhs, the result both equal to lhs. + divisibility.push_back(lhs.getDivisibility(i)); + continualSize.push_back(lhs.getContinualSize(i)); + continualInterval.push_back(lhs.getContinualInterval(i)); + constantValue = {lhs.getConstantValue()}; + } else if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + // Case2: cst1 / cst2. + continualSize.push_back(lhs.getConstancy(i)); + continualInterval.push_back(0); + constantValue = {lhs.getConstantValue().value() / + rhs.getConstantValue().value()}; + divisibility.push_back(highestPowOf2Divisor(constantValue.value())); + } else if (!lhs.isConstantDim(shape, i) && lhs.isContinualDim(shape, i) && + rhs.isConstantDim(shape, i) && + rhs.getConstantValue().has_value() && + llvm::isPowerOf2_64(lhs.getContinualInterval(i))) { + // Case 3: lhs stride(stride_val is power of 2), rhs constant. + // lhs: d_lhs * k, d_lhs * k + s, ..., d_lhs * k + n * s + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + s) / (d_rhs * p), + // ..., (d_lhs * k + n*s) / (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // the minimal stride is + // minStride = max(gcd(d_lhs, d_rhs) / strideVal, 1). + // Since minStride maybe > len(lhs), + // we need to use another gcd to get the actual constancy. + int64_t divisibilityGCD = + gcd(lhs.getDivisibility(i), rhs.getDivisibility(i)); + bool isContinual = + lhs.getContinualInterval(i) % rhs.getConstantValue().value() == 0; + int64_t newContinualSize = + isContinual ? lhs.getContinualSize(i) + : std::max( + divisibilityGCD / lhs.getContinualInterval(i), 1); + continualSize.push_back(gcd(lhs.getContinualSize(i), newContinualSize)); + continualInterval.push_back(lhs.getContinualInterval(i) / + rhs.getConstantValue().value()); + divisibility.push_back(std::max( + lhs.getDivisibility(i) / rhs.getConstantValue().value(), 1)); + } else if (lhs.isStridedConstantDim(shape, i) && + rhs.getConstantValue().has_value()) { + divisibility.push_back(std::max( + lhs.getDivisibility(i) / rhs.getConstantValue().value(), 1)); + continualSize.push_back( + gcd(lhs.getContinualSize(i), rhs.getContinualSize(i))); + continualInterval.push_back(0); + } else { + divisibility.push_back(AxisInfoEx::kInitDivisibility); + continualSize.push_back(AxisInfoEx::kDefaultContinueSize); + continualInterval.push_back(AxisInfoEx::kDefaultContinualInterval); + } + } + return AxisInfoEx(divisibility, continualSize, continualInterval, + constantValue); + } +}; + +template +class RemOpAxisInfoExVisitor final : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + OpTy op, + ArrayRef *> operands) override { + assert(operands.size() == 2 && "Expected two operands"); + auto resTy = dyn_cast(op.getResult().getType()); + if (!resTy) + return AxisInfoEx{}; + + auto shape = resTy.getShape(); + short rank = resTy.getRank(); + auto &lhs = operands[0]->getValue(); + auto &rhs = operands[1]->getValue(); + + AxisInfoEx::DimVectorT divisibility, continualSize, continualInterval; + std::optional constantValue; + for (short i = 0; i < rank; ++i) { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) { + // Case1: lhs % 1. + divisibility.push_back(highestPowOf2Divisor(0)); + continualSize.push_back(shape[i]); + continualInterval.push_back(0); + constantValue = {0}; + } else if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + // Case2: cst1 % cst2. + constantValue = {lhs.getConstantValue().value() % + rhs.getConstantValue().value()}; + divisibility.push_back(highestPowOf2Divisor(constantValue.value())); + continualSize.push_back(lhs.getConstancy(i)); + continualInterval.push_back(0); + } else if (lhs.isContinualLowDim(shape, i) && + rhs.isConstantDim(shape, i)) { + // Case3: lhs contiguous, rhs constant. + // lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k'' + // rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p'' + // lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r + // r must be divisible by gcd(d_lhs, d_rhs) + divisibility.push_back( + gcd(lhs.getDivisibility(i), rhs.getDivisibility(i))); + + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs % rhs = d_lhs * k % (d_rhs * p), (d_lhs * k + 1) % (d_rhs * p), + // ..., (d_lhs * k + n) % (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // The minimal contiguity is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual contiguity. + continualSize.push_back( + gcd(lhs.getContiguity(i), + gcd(lhs.getDivisibility(i), rhs.getDivisibility(i)))); + continualInterval.push_back(1); + } else if (lhs.isStridedContinualDim(shape, i) && + rhs.getConstantValue().has_value()) { + // Case4: lhs strided contiguous, rhs constant value. + divisibility.push_back( + gcd(lhs.getDivisibility(i), rhs.getDivisibility(i))); + continualSize.push_back( + gcd(lhs.getContiguity(i), + gcd(lhs.getDivisibility(i), rhs.getDivisibility(i)))); + continualInterval.push_back(lhs.getContinualInterval(i) % + rhs.getConstantValue().value()); + } else if (lhs.isStridedConstantDim(shape, i) && + rhs.getConstantValue().has_value()) { + // Case5: lhs strided constant, rhs constant value. + divisibility.push_back( + gcd(lhs.getDivisibility(i), rhs.getDivisibility(i))); + continualSize.push_back(lhs.getConstancy(i)); + continualInterval.push_back(0); + } else { + divisibility.push_back(AxisInfoEx::kInitDivisibility); + continualSize.push_back(AxisInfoEx::kDefaultContinueSize); + continualInterval.push_back(AxisInfoEx::kDefaultContinualInterval); + } + } + + return AxisInfoEx(divisibility, continualSize, continualInterval, + constantValue); + } +}; + +class SplatOpAxisInfoExVisitor final + : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + triton::SplatOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + TensorType retTy = cast(_retTy); + AxisInfoEx opInfo = operands[0]->getValue(); + AxisInfoEx::DimVectorT divisibility, continualSize, continualInterval; + for (int i = 0; i < retTy.getRank(); ++i) { + divisibility.push_back(opInfo.getDivisibility(0)); + continualSize.push_back(retTy.getShape()[i]); + continualInterval.push_back(0); + } + return AxisInfoEx(divisibility, continualSize, continualInterval, + operands[0]->getValue().getConstantValue()); + } +}; + +class LoadOpAxisInfoExVisitor final + : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + triton::LoadOp op, + ArrayRef *> operands) override { + // If pointers and mask both have constancy properties, those properties + // will also extend to output. + AxisInfoEx ptrInfo = operands[0]->getValue(); + std::optional maskInfo; + if (operands.size() > 1) { + maskInfo = operands[1]->getValue(); + } + AxisInfoEx::DimVectorT divisibility, continualSize, continualInterval; + + for (int i = 0; i < ptrInfo.getRank(); ++i) { + divisibility.push_back(ptrInfo.getDivisibility(i)); + continualSize.push_back( + gcd(ptrInfo.getContinualSize(i), + maskInfo.has_value() ? maskInfo->getConstancy(i) : 0)); + continualInterval.push_back(ptrInfo.getContinualInterval(i)); + } + + return AxisInfoEx(divisibility, continualSize, continualInterval); + } +}; + +class ExpandDimsOpAxisInfoExVisitor final + : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + triton::ExpandDimsOp op, + ArrayRef *> operands) override { + AxisInfoEx opInfo = operands[0]->getValue(); + AxisInfoEx::DimVectorT divisibility = opInfo.getDivisibility(); + AxisInfoEx::DimVectorT continualSize = opInfo.getContinualSize(); + AxisInfoEx::DimVectorT continualInterval = opInfo.getContinualInterval(); + + ArrayRef srcShape = op.getSrc().getType().getShape(); + int64_t expandedDim = std::max(static_cast(op.getAxis()) - 1, 0); + int64_t expandedDivisibility = opInfo.isConstantDim(srcShape, expandedDim) + ? divisibility[expandedDim] + : AxisInfoEx::kInitDivisibility; + divisibility.insert(divisibility.begin() + op.getAxis(), + expandedDivisibility); + continualSize.insert(continualSize.begin() + op.getAxis(), + AxisInfoEx::kDefaultContinueSize); + continualInterval.insert(continualInterval.begin() + op.getAxis(), 0); + return AxisInfoEx(divisibility, continualSize, continualInterval, + opInfo.getConstantValue()); + } +}; + +class BroadcastOpAxisInfoExVisitor final + : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + triton::BroadcastOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + Type _opTy = *op->operand_type_begin(); + TensorType retTy = cast(_retTy); + TensorType opTy = cast(_opTy); + ArrayRef retShape = retTy.getShape(); + ArrayRef opShape = opTy.getShape(); + AxisInfoEx opInfo = operands[0]->getValue(); + AxisInfoEx::DimVectorT divisibility, continualSize, continualInterval; + for (int i = 0; i < retTy.getRank(); ++i) { + divisibility.push_back(opInfo.getDivisibility(i)); + continualSize.push_back(opShape[i] == 1 ? retShape[i] + : opInfo.getContinualSize(i)); + continualInterval.push_back( + opShape[i] == 1 ? 0 : opInfo.getContinualInterval(i)); + } + return AxisInfoEx(divisibility, continualSize, continualInterval, + opInfo.getConstantValue()); + } +}; + +class TransOpAxisInfoExVisitor final + : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + triton::TransOp op, + ArrayRef *> operands) override { + ArrayRef trans_order = op.getOrder(); + AxisInfoEx opInfo = operands[0]->getValue(); + AxisInfoEx::DimVectorT divisibility, continualSize, continualInterval; + for (unsigned i = 0; i < trans_order.size(); ++i) { + divisibility.push_back(opInfo.getDivisibility(trans_order[i])); + continualSize.push_back(opInfo.getContinualSize(trans_order[i])); + continualInterval.push_back(opInfo.getContinualInterval(trans_order[i])); + } + return AxisInfoEx(divisibility, continualSize, continualInterval, + opInfo.getConstantValue()); + } +}; + +template +class CmpOpAxisInfoExVisitor final : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + OpTy op, + ArrayRef *> operands) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return AxisInfoEx(); + auto shape = resTy.getShape(); + short rank = resTy.getRank(); + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + + AxisInfoEx::DimVectorT divisibility, continualSize, continualInterval; + std::optional constantValue; + for (short d = 0; d < rank; ++d) { + int64_t constancyHint = AxisInfoEx::kDefaultContinueSize; + int64_t continualIntervalHint = AxisInfoEx::kDefaultContinualInterval; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + constancyHint = lhsInfo.getConstancy(d); + constantValue = + compare(getPredicate(op), lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value()) + ? 1 + : 0; + continualIntervalHint = 0; + } else if (gtPredicate(getPredicate(op)) || + ltPredicate(getPredicate(op))) { + // Lhs and rhs are both partial constants. + constancyHint = gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)); + auto commonDivisor = + gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)); + if (lhsInfo.isConstantDim(shape, d) && + rhsInfo.isContinualLowDim(shape, d)) { + // Case 2: lhs all constant, rhs all contiguous + // NOTE: + // lhs: k0 * d, k0 * d, ... + // rhs: k1 * d, k1 * d + 1, ... + // lhs lt rhs: 1, 1, 1, 1 (minimal len: d if k0 <= k1) + // lhs lt rhs: 0, 0, 0, 0 (minimal len: d if k0 > k1) + // lhs gt rhs: 0, 0, 0, 0 (minimal len: d if k0 <= k1) + // lhs gt rhs: 1, 1, 1, 1 (minimal len: d if k0 > k1) + constancyHint = std::max( + constancyHint, gcd(rhsInfo.getContiguity(d), commonDivisor)); + } else if (lhsInfo.isContinualLowDim(shape, d) && + rhsInfo.isConstantDim(shape, d)) { + // Case 3: lhs all contiguous, rhs all constant + // NOTE + // lhs: k0 * d, k0 * d + 1, ... + // rhs: k1 * d, k1 * d, ... + // lhs gt rhs: 1, 1, 1, 1 (minimal len: d if k0 >= k1) + // lhs gt rhs: 0, 0, 0, 0 (minimal len: d if k0 < k1) + // lhs lt rhs: 0, 0, 0, 0 (minimal len: d if k0 >= k1) + // lhs lt rhs: 1, 1, 1, 1 (minimal len: d if k0 < k1) + constancyHint = std::max( + constancyHint, gcd(lhsInfo.getContiguity(d), commonDivisor)); + } else if (lhsInfo.isConstantDim(shape, d) && + rhsInfo.isConstantDim(shape, d)) { + // Case 4: lhs all constant, rhs all constant + continualIntervalHint = 0; + } + } + + divisibility.push_back(AxisInfoEx::kInitDivisibility); + continualSize.push_back(constancyHint); + continualInterval.push_back(continualIntervalHint); + } + + return AxisInfoEx(divisibility, continualSize, continualInterval, + constantValue); + } + +private: + static arith::CmpIPredicate getPredicate(arith::CmpIOp op) { + return op.getPredicate(); + } + + static bool gtPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sgt || + predicate == arith::CmpIPredicate::ugt; + } + + static bool gePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sge || + predicate == arith::CmpIPredicate::uge; + } + + static bool ltPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::slt || + predicate == arith::CmpIPredicate::ult; + } + + static bool lePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sle || + predicate == arith::CmpIPredicate::ule; + } + + static bool compare(arith::CmpIPredicate predicate, int64_t lhs, + int64_t rhs) { + switch (predicate) { + case arith::CmpIPredicate::eq: + return lhs == rhs; + case arith::CmpIPredicate::ne: + return lhs != rhs; + case arith::CmpIPredicate::slt: + return lhs < rhs; + case arith::CmpIPredicate::sle: + return lhs <= rhs; + case arith::CmpIPredicate::sgt: + return lhs > rhs; + case arith::CmpIPredicate::sge: + return lhs >= rhs; + case arith::CmpIPredicate::ult: + return static_cast(lhs) < static_cast(rhs); + case arith::CmpIPredicate::ule: + return static_cast(lhs) <= static_cast(rhs); + case arith::CmpIPredicate::ugt: + return static_cast(lhs) > static_cast(rhs); + case arith::CmpIPredicate::uge: + return static_cast(lhs) >= static_cast(rhs); + default: + break; + } + llvm_unreachable("unknown comparison predicate"); + } +}; + +template +class SelectOpAxisInfoExVisitor final : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + OpTy op, + ArrayRef *> operands) override { + auto pResultInfo = operands[0]->getValue(); + auto lhsInfo = operands[1]->getValue(); + auto rhsInfo = operands[2]->getValue(); + auto rank = lhsInfo.getRank(); + + AxisInfoEx::DimVectorT divisibility, continualSize, continualInterval; + std::optional constantValue; + if (pResultInfo.getConstantValue().has_value()) { + if (pResultInfo.getConstantValue() == 0) { + divisibility = rhsInfo.getDivisibility(); + continualSize = rhsInfo.getContinualSize(); + continualInterval = rhsInfo.getContinualInterval(); + constantValue = rhsInfo.getConstantValue(); + } else { + divisibility = lhsInfo.getDivisibility(); + continualSize = lhsInfo.getContinualSize(); + continualInterval = lhsInfo.getContinualInterval(); + constantValue = lhsInfo.getConstantValue(); + } + } else { + bool i1Cond = isa(op.getOperand(0).getType()); + for (auto d = 0; d < rank; ++d) { + if (i1Cond) { + continualSize.push_back( + gcd(lhsInfo.getContinualSize(d), rhsInfo.getContinualSize(d))); + } else { + continualSize.push_back(gcd( + gcd(lhsInfo.getContinualSize(d), pResultInfo.getConstancy(d)), + gcd(rhsInfo.getContinualSize(d), pResultInfo.getConstancy(d)))); + } + continualInterval.push_back(AxisInfoEx::kDefaultContinualInterval); + divisibility.push_back( + std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + } + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value() && + lhsInfo.getConstantValue() == rhsInfo.getConstantValue()) + constantValue = lhsInfo.getConstantValue(); + } + + return AxisInfoEx(divisibility, continualSize, continualInterval, + constantValue); + } +}; + +template +class LogicalOpAxisInfoExVisitor final : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + OpTy op, + ArrayRef *> operands) override { + assert((std::is_same_v || + std::is_same_v || + std::is_same_v) && + "LogicalOp not support"); + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + AxisInfoEx::DimVectorT divisibility, continualSize, continualInterval; + std::optional constantValue; + for (int d = 0; d < rank; ++d) { + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + if constexpr (std::is_same_v) { + constantValue = {lhsInfo.getConstantValue().value() & + rhsInfo.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + constantValue = {lhsInfo.getConstantValue().value() | + rhsInfo.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + constantValue = {lhsInfo.getConstantValue().value() ^ + rhsInfo.getConstantValue().value()}; + } + } + if (lhsInfo.getContinualInterval(d) == 0 && + rhsInfo.getContinualInterval(d) == 0) { + divisibility.push_back( + gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + continualSize.push_back( + gcd(lhsInfo.getContinualSize(d), rhsInfo.getContinualSize(d))); + continualInterval.push_back(0); + continue; + } + + divisibility.push_back(AxisInfoEx::kInitDivisibility); + continualSize.push_back(AxisInfoEx::kDefaultContinueSize); + continualInterval.push_back(AxisInfoEx::kDefaultContinualInterval); + } + + return AxisInfoEx(divisibility, continualSize, continualInterval, + constantValue); + } +}; + +class ShLIOpAxisInfoExVisitor final + : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getDivisibility(arith::ShLIOp op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) override { + auto shift = rhs.getConstantValue().has_value() + ? rhs.getConstantValue().value() + : rhs.getDivisibility(dim); + auto numBits = log2Int(lhs.getDivisibility(dim)); + auto maxBits = log2Int(highestPowOf2Divisor(0)); + // Make sure the return value doesn't exceed + // highestPowOf2Divisor(0). + if (shift + numBits > maxBits) + return highestPowOf2Divisor(0); + return lhs.getDivisibility(dim) << shift; + } + + int64_t getContinualSize(arith::ShLIOp op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) override { + int64_t dimContinueSize = AxisInfoEx::kDefaultContinueSize; + if (rhs.getConstantValue().has_value()) + dimContinueSize = lhs.getContiguity(dim); + return dimContinueSize; + } + + int64_t getContinualInterval(arith::ShLIOp op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) override { + int64_t dimContinualInterval = AxisInfoEx::kDefaultContinualInterval; + if (rhs.getConstantValue().has_value()) { + auto shift = rhs.getConstantValue().value(); + auto numBits = log2Int(shift); + auto maxBits = log2Int(highestPowOf2Divisor(0)); + if (shift + numBits <= maxBits) + dimContinualInterval = lhs.getContinualInterval(dim) + << rhs.getConstantValue().value(); + } + return dimContinualInterval; + } + + std::optional getConstantValue(arith::ShLIOp op, + const AxisInfoEx &lhs, + const AxisInfoEx &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() << rhs.getConstantValue().value()}; + return std::nullopt; + } +}; + +template +class ShROpAxisInfoExVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getDivisibility(OpTy op, const AxisInfoEx &lhs, const AxisInfoEx &rhs, + int dim) override { + if (rhs.getConstantValue().has_value()) + return std::max(AxisInfoEx::kInitDivisibility, + lhs.getDivisibility(dim) / + (1 << rhs.getConstantValue().value())); + return AxisInfoEx::kInitDivisibility; + } + + int64_t getContinualSize(OpTy op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) override { + int64_t dimContinueSize = AxisInfoEx::kDefaultContinueSize; + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + dimContinueSize = lhs.getContiguity(dim); + return dimContinueSize; + } + + int64_t getContinualInterval(OpTy op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs, int dim) override { + int64_t dimContinualInterval = AxisInfoEx::kDefaultContinualInterval; + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + dimContinualInterval = lhs.getContinualInterval(dim); + return dimContinualInterval; + } + + std::optional getConstantValue(OpTy op, const AxisInfoEx &lhs, + const AxisInfoEx &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() >> rhs.getConstantValue().value()}; + return std::nullopt; + } +}; + +template +class MaxMinOpAxisInfoExVisitor final : public AxisInfoExVisitorImpl { +public: + using AxisInfoExVisitorImpl::AxisInfoExVisitorImpl; + + AxisInfoEx getAxisInfoEx( + OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + AxisInfoEx::DimVectorT divisibility, continualSize, continualInterval; + std::optional constantValue; + + for (int d = 0; d < rank; ++d) { + const AxisInfoEx *resInfo = nullptr; + if constexpr (std::is_same_v || + std::is_same_v) { + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + constantValue = {std::max(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + if (lhsInfo.getConstantValue().value() >= + rhsInfo.getConstantValue().value()) { + resInfo = &lhsInfo; + } else { + resInfo = &rhsInfo; + } + divisibility.push_back(resInfo->getDivisibility(d)); + continualSize.push_back(resInfo->getContinualSize(d)); + continualInterval.push_back(resInfo->getContinualInterval(d)); + continue; + } + } else { + assert((std::is_same_v || + std::is_same_v) && + "MaxMinOp not support"); + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + constantValue = {std::min(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + if (lhsInfo.getConstantValue().value() <= + rhsInfo.getConstantValue().value()) { + resInfo = &lhsInfo; + } else { + resInfo = &rhsInfo; + } + divisibility.push_back(resInfo->getDivisibility(d)); + continualSize.push_back(resInfo->getContinualSize(d)); + continualInterval.push_back(resInfo->getContinualInterval(d)); + continue; + } + } + + divisibility.push_back(AxisInfoEx::kInitDivisibility); + continualSize.push_back(AxisInfoEx::kDefaultContinueSize); + continualInterval.push_back(AxisInfoEx::kDefaultContinualInterval); + } + + return AxisInfoEx(divisibility, continualSize, continualInterval, + constantValue); + } +}; + +//===----------------------------------------------------------------------===// +// AxisInfoExAnalysis +//===----------------------------------------------------------------------===// + +AxisInfoExAnalysis::AxisInfoExAnalysis(DataFlowSolver &solver) + : dataflow::SparseForwardDataFlowAnalysis>( + solver) { + // UnrealizedConversionCast: + // This is needed by TritonGPUToLLVM, to get AxisInfoEx when the graph is + // in the process of a PartialConversion, where UnrealizedConversionCast + // may exist + visitors.append, + CastOpAxisInfoExVisitor, + CastOpAxisInfoExVisitor, + CastOpAxisInfoExVisitor, + CastOpAxisInfoExVisitor, + CastOpAxisInfoExVisitor, + CastOpAxisInfoExVisitor>(); + + // when scf.for supports integer induction variables + visitors.append(); + visitors.append, + ConstantOpAxisInfoExVisitor>(); + visitors.append, + AddSubOpAxisInfoExVisitor, + AddSubOpAxisInfoExVisitor, + AddSubOpAxisInfoExVisitor>(); + visitors.append(); + visitors.append, + DivOpAxisInfoExVisitor>(); + visitors.append, + RemOpAxisInfoExVisitor>(); + visitors.append(); + visitors.append(); + visitors.append(); + visitors.append(); + visitors.append>(); + visitors.append, + LogicalOpAxisInfoExVisitor, + LogicalOpAxisInfoExVisitor>(); + visitors.append>(); + visitors + .append, + ShROpAxisInfoExVisitor>(); + visitors.append, + MaxMinOpAxisInfoExVisitor, + MaxMinOpAxisInfoExVisitor, + MaxMinOpAxisInfoExVisitor>(); + visitors.append(); +} + +llvm::LogicalResult AxisInfoExAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + // but why is scf.if not initialized otherwise? + for (auto op : operands) + if (op->getValue().getRank() == 0) + setToEntryState((dataflow::Lattice *)op); + AxisInfoEx curr = visitors.apply(op, operands); + if (curr.getRank() == 0) { + setAllToEntryStates(results); + return success(); + } + // override with hint + auto newDivisibility = curr.getDivisibility(); + auto continualSize = curr.getContinualSize(); + auto continualInterval = curr.getContinualInterval(); + if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { + auto vals = cast(attr).getValues(); + newDivisibility = AxisInfoEx::DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { + auto vals = cast(attr).getValues(); + continualSize = AxisInfoEx::DimVectorT(vals.begin(), vals.end()); + continualInterval = AxisInfoEx::DimVectorT(vals.size(), 1); + } + if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { + assert(!op->getAttr("tt.contiguity") && + "Get tt.constancy and tt.contiguity attribute at the same op"); + auto vals = cast(attr).getValues(); + continualSize = AxisInfoEx::DimVectorT(vals.begin(), vals.end()); + continualInterval = AxisInfoEx::DimVectorT(vals.size(), 0); + } + curr = AxisInfoEx(newDivisibility, continualSize, continualInterval, + curr.getConstantValue()); + // join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(curr)); + + return success(); +} + +void AxisInfoExAnalysis::visitForOpInductionVar( + scf::ForOp op, ArrayRef *> argLattices) { + ProgramPoint *programPoint = getProgramPointAfter(op); + auto lb = getLatticeElementFor(programPoint, op.getLowerBound())->getValue(); + auto step = getLatticeElementFor(programPoint, op.getStep())->getValue(); + std::optional iv = op.getSingleInductionVar(); + assert(iv && "visitForOpInduction not support"); + auto rank = 1; + TensorType ty = dyn_cast(iv.value().getType()); + if (ty) + rank = ty.getRank(); + + auto divValue = AxisInfoEx::kInitDivisibility; + auto divHint = gcd(lb.getDivisibility(0), step.getDivisibility(0)); + if (divHint != 0) + divValue = divHint; + + AxisInfoEx::DimVectorT knownDivisibility(rank, divValue); + AxisInfoEx::DimVectorT knowContinualSize(rank, + AxisInfoEx::kDefaultContinueSize); + AxisInfoEx::DimVectorT knowContinualInterval( + rank, AxisInfoEx::kDefaultContinualInterval); + auto inductionVar = + AxisInfoEx(knownDivisibility, knowContinualSize, knowContinualInterval); + (void)argLattices[0]->join(inductionVar); +} +} // anonymous namespace + +template +void AxisInfoEx::initPessimisticStateFromFunc(int argNumber, T funcOp, int rank, + DimVectorT *contiguity, + DimVectorT *divisibility, + DimVectorT *constancy) { + // liast of attributes that we care about + SmallVector> retVecs; + retVecs.push_back({contiguity, "tt.contiguity"}); + retVecs.push_back({divisibility, "tt.divisibility"}); + retVecs.push_back({constancy, "tt.constancy"}); + // initialize attributes one by one + for (auto [vec, attrName] : retVecs) { + Attribute attr = funcOp.getArgAttr(argNumber, attrName); + if (auto int_attr = dyn_cast_or_null(attr)) + *vec = DimVectorT(rank, int_attr.getValue().getZExtValue()); + if (auto dense_attr = dyn_cast_or_null(attr)) { + auto vals = dense_attr.getValues(); + *vec = DimVectorT(vals.begin(), vals.end()); + } + } +} + +/*static*/ AxisInfoEx AxisInfoEx::getPessimisticValueState(Value value) { + int rank = 1; + if (TensorType ty = dyn_cast(value.getType())) + rank = ty.getRank(); + if (triton::PointerType ty = dyn_cast(value.getType())) + if (TensorType elemTy = dyn_cast(ty.getPointeeType())) + rank = elemTy.getRank(); + + DimVectorT continualSize(rank, kDefaultContinueSize); + DimVectorT continualInterval(rank, kDefaultContinualInterval); + DimVectorT knownDivisibility, knownContiguity, knownConstancy; + BlockArgument blockArg = dyn_cast(value); + if (blockArg && blockArg.getOwner()->isEntryBlock()) { + Operation *op = blockArg.getOwner()->getParentOp(); + auto fun = dyn_cast(op); + if (!fun) + fun = dyn_cast(op); + + if (fun) + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, rank, + &knownContiguity, &knownDivisibility, + &knownConstancy); + } else if (Operation *op = value.getDefiningOp()) { + if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { + auto vals = cast(attr).getValues(); + knownDivisibility = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { + auto vals = cast(attr).getValues(); + knownContiguity = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { + auto vals = cast(attr).getValues(); + knownConstancy = DimVectorT(vals.begin(), vals.end()); + } + } + + if (knownDivisibility.empty()) { + knownDivisibility = DimVectorT(rank, kInitDivisibility); + } + if (!knownConstancy.empty()) { + assert(knownContiguity.empty() && + "Get tt.constancy and tt.contiguity attribute at the same arg"); + continualSize = knownConstancy; + continualInterval = DimVectorT(rank, 0); + } + if (!knownContiguity.empty()) { + continualSize = knownContiguity; + continualInterval = DimVectorT(rank, 1); + } + + return AxisInfoEx(knownDivisibility, continualSize, continualInterval); +} + +/*static*/ AxisInfoEx AxisInfoEx::join(const AxisInfoEx &lhs, + const AxisInfoEx &rhs) { + auto lhsRank = lhs.getRank(); + auto rhsRank = rhs.getRank(); + // If one argument is not initialized, return the other. + if (lhsRank == 0) + return rhs; + if (rhsRank == 0) + return lhs; + assert(lhsRank == rhsRank && "lhsRank and rhsRank are mismatch"); + + DimVectorT divisibility(lhsRank, kInitDivisibility); + DimVectorT continualSize(lhsRank, kDefaultContinueSize); + DimVectorT continualInterval(lhsRank, kDefaultContinualInterval); + for (auto i = 0; i < lhsRank; ++i) { + divisibility[i] = (gcd(lhs.getDivisibility(i), rhs.getDivisibility(i))); + continualSize[i] = (gcd(lhs.getContinualSize(i), rhs.getContinualSize(i))); + if (lhs.continualInterval[i] == rhs.continualInterval[i]) + continualInterval[i] = lhs.continualInterval[i]; + } + std::optional constantValue; + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value() && + lhs.getConstantValue() == rhs.getConstantValue()) + constantValue = lhs.getConstantValue(); + return AxisInfoEx(divisibility, continualSize, continualInterval, + constantValue); +} + +unsigned ModuleAxisInfoExAnalysis::getPtrContiguity(Value ptr) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto layout = tensorTy.getEncoding(); + + // Here order should be ordered by contiguous first, so the first element + // should have the largest contiguous. + auto order = triton::gpu::getOrder(tensorTy); + unsigned align = getPtrAlignment(ptr); + + // auto uniqueContigPerThread = + // triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape()); + // assert(order[0] < uniqueContigPerThread.size() && + // "Unexpected uniqueContigPerThread size"); + // unsigned contiguity = uniqueContigPerThread[order[0]]; + // LDBG("getPtrContiguity uniqueContigPerThread = " << contiguity); + // contiguity = std::min(align, contiguity); + + return 0; +} + +unsigned ModuleAxisInfoExAnalysis::getPtrAlignment(Value ptr) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfoEx(ptr); + if (!axisInfo) + return 1; + auto layout = tensorTy.getEncoding(); + auto order = triton::gpu::getOrder(tensorTy); + auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); + auto maxContig = axisInfo->getContiguity(order[0]); + auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + auto elemNumBytes = std::max(elemNumBits / 8, 1); + auto maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1); + unsigned alignment = std::min(maxMultiple, maxContig); + LDBG("getPtrAlignment order[0] " + << order[0] << " maxMultipleBytes = " << maxMultipleBytes + << " maxContig = " << maxContig << " elemNumBits = " << elemNumBits + << " maxMultiple = " << maxMultiple << " alignment " << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +unsigned ModuleAxisInfoExAnalysis::getMaskAlignment(Value mask) { + auto tensorTy = dyn_cast(mask.getType()); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfoEx(mask); + if (!axisInfo) + return 1; + auto maskOrder = triton::gpu::getOrder(tensorTy); + auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); + LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment " + << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +void ModuleAxisInfoExAnalysis::initialize(FunctionOpInterface funcOp) { + std::unique_ptr solver = createDataFlowSolver(); + AxisInfoExAnalysis *analysis = solver->load(); + if (failed(solver->initializeAndRun(funcOp))) + return; + auto *axisInfoMap = getFuncData(funcOp); + auto updateAxisInfoMap = [&](Value value) { + auto axisInfoEx = analysis->getLatticeElement(value)->getValue(); + AxisInfoEx curAxisInfo; + if (axisInfoMap->count(value)) { + curAxisInfo = AxisInfoEx::join(axisInfoEx, axisInfoMap->lookup(value)); + } else { + curAxisInfo = axisInfoEx; + } + (*axisInfoMap)[value] = curAxisInfo; + }; + funcOp.walk([&](Operation *op) { + for (auto value : op->getResults()) { + updateAxisInfoMap(value); + } + }); + funcOp.walk([&](Block *block) { + for (auto value : block->getArguments()) { + updateAxisInfoMap(value); + } + }); +} + +void ModuleAxisInfoExAnalysis::update(CallOpInterface callOp, + FunctionOpInterface callee) { + auto caller = callOp->getParentOfType(); + auto *axisInfoExMap = getFuncData(caller); + for (auto entry : llvm::enumerate(callOp->getOperands())) { + auto index = entry.index(); + auto value = entry.value(); + auto setAttrFn = [&](StringRef attrName, int64_t prevValue) { + auto curValue = highestPowOf2Divisor(0); + if (callee.getArgAttrOfType(index, attrName)) { + curValue = + callee.getArgAttrOfType(index, attrName).getInt(); + } + auto attr = IntegerAttr::get(IntegerType::get(callee.getContext(), 64), + gcd(prevValue, curValue)); + callee.setArgAttr(index, attrName, attr); + }; + auto axisInfoEx = axisInfoExMap->lookup(value); + assert(axisInfoEx.getRank() == 1 && "only scalar arguments are supported"); + setAttrFn("tt.divisibility", axisInfoEx.getDivisibility(0)); + if (axisInfoEx.getContinualInterval(0) == 0) + setAttrFn("tt.constancy", axisInfoEx.getContinualSize(0)); + else if (axisInfoEx.getContinualInterval(0) == 1) + setAttrFn("tt.contiguity", axisInfoEx.getContinualSize(0)); + } +} + +} // namespace gcu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..48a6c4d53 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_library(TritonGCUAnalysis_${arch} + AxisInfoEx.cpp + FirstLastUserAnalysis.cpp + MaskAnalysis.cpp + PtrAnalysis.cpp + OpFoldResultUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/triton_gcu/Analysis/ + + DEPENDS + TritonGCUTableGen_${arch} + triton_${arch} + + LINK_LIBS PUBLIC + MLIRAnalysis +) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/FirstLastUserAnalysis.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/FirstLastUserAnalysis.cpp new file mode 100644 index 000000000..69d6bf4bb --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/FirstLastUserAnalysis.cpp @@ -0,0 +1,471 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Analysis/FirstLastUserAnalysis.h" + +#include +#include + +#include "Conversion/TritonToGCU/Utils.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define DEBUG_TYPE "first-last-user-analysis" + +namespace mlir { +namespace triton { +namespace gcu { + +namespace { + +bool isMustAliasOp(mlir::Operation *op) { + if (llvm::isa(op)) { + return true; + } else if (llvm::isa(op)) { + auto convertLayout = cast(op); + auto src = convertLayout.getSrc(); + auto srcNumElems = triton::gcu::getElemsPerThread(src.getType()); + auto dstNumElems = triton::gcu::getElemsPerThread(convertLayout.getType()); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(convertLayout.getType()); + if ((!srcTy) || (!dstTy)) { + assert(false && "srcTy or dstTy not a RankedTensorType"); + } + + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (srcLayout == dstLayout) { + return true; + } + if (srcNumElems == dstNumElems && + src.getType().getShape() == convertLayout.getType().getShape()) { + if (!mlir::isa(srcLayout)) { + return true; + } else if (isa(srcLayout) && + isa(dstLayout)) { + if (cast(srcLayout).getDim() == + cast(dstLayout).getDim()) { + return true; + } + } + } + return false; + } else if (isa(op)) { + auto expandDimOp = cast(op); + auto srcNumElems = + triton::gcu::getElemsPerThread(expandDimOp.getSrc().getType()); + auto dstNumElems = triton::gcu::getElemsPerThread(expandDimOp.getType()); + srcNumElems.insert(srcNumElems.begin() + expandDimOp.getAxis(), 1); + if (srcNumElems == dstNumElems) { + return true; + } + return false; + } else if (isa(op)) { + auto reshapeOp = cast(op); + auto srcNumElems = + triton::gcu::getElemsPerThread(reshapeOp.getSrc().getType()); + auto dstNumElems = triton::gcu::getElemsPerThread(reshapeOp.getType()); + if (srcNumElems == dstNumElems) { + return true; + } + return false; + } else if (isa(op)) { + auto broastOp = cast(op); + auto srcNumElems = + triton::gcu::getElemsPerThread(broastOp.getSrc().getType()); + auto dstNumElems = triton::gcu::getElemsPerThread(broastOp.getType()); + if (srcNumElems == dstNumElems) { + return true; + } + return false; + } else { + return false; + } +} + +template +void getUsersWithAlias(mlir::Operation *op, DomInfoT &domInfo, FuncT &funcInfo, + std::vector &userList, + std::vector &blockList, + std::vector &aliasList) { + auto opRegion = op->getParentRegion(); + for (auto user : op->getUsers()) { + if (llvm::isa(user)) { + llvm::report_fatal_error("double free please checek IR"); + } + + if (user->getParentRegion() == opRegion) { + userList.push_back(user); + blockList.push_back(user->getBlock()); + if (isMustAliasOp(user) || llvm::isa(user)) { + aliasList.push_back(user); + } + } else { + auto parent = user->getParentOp(); + auto curUser = user; + bool mayAlias = + llvm::isa( + parent) && + llvm::isa(curUser); + + while ((!isa(parent)) && + (!isa(parent))) { + if (parent->getParentRegion() == opRegion) + break; + + curUser = mayAlias ? funcInfo(parent, domInfo) : nullptr; + parent = parent->getParentOp(); + mayAlias = + llvm::isa( + parent) && + llvm::isa_and_nonnull(curUser); + } + + if (parent->getParentRegion() != opRegion) { + parent->dump(); + op->dump(); + llvm::report_fatal_error("invalid user please checek IR"); + } + userList.push_back(parent); + blockList.push_back(parent->getBlock()); + if (mayAlias) { + aliasList.push_back(parent); + } + } + } +} + +mlir::Operation *getLastUserOfOp(mlir::Operation *op, + PostDominanceInfo &postDomInfo) { + auto opRegion = op->getParentRegion(); + std::vector userList; + std::vector blockList; + std::vector aliasList; + + getUsersWithAlias(op, postDomInfo, getLastUserOfOp, userList, blockList, + aliasList); + + // Analysis alias op + while (!aliasList.empty()) { + std::vector tmpList(aliasList.size()); + std::copy(aliasList.begin(), aliasList.end(), tmpList.begin()); + aliasList.clear(); + for (auto tmp : tmpList) { + getUsersWithAlias(tmp, postDomInfo, getLastUserOfOp, userList, blockList, + aliasList); + } + } + + if (blockList.empty()) + return nullptr; + + Block *dom = postDomInfo.findNearestCommonDominator(blockList); + + /** B0 + // / \ + // v v + // B1 <- B2 + // | + // v + // B3 + // 1). B1 and B3 has a "return" op. + // 2). B2 and B3 has a use for a alloc op which locate in B0. At the time, + // the "postDomInfo.findNearestCommonDominator" return nullptr + **/ + if (dom == nullptr) { + auto lastBlock = blockList[0]; + for (auto iter = opRegion->rbegin(); iter != opRegion->rend(); iter++) { + auto index = std::find(blockList.begin(), blockList.end(), &(*iter)); + if (index != blockList.end()) { + lastBlock = &(*iter); + break; + } + } + dom = lastBlock; + } + + if (dom->empty()) { + llvm::report_fatal_error("dominator block is empty"); + } + + mlir::Operation *lastUser = nullptr; + + // Block dom maybe is not in the blockList + auto lastBlockIter = std::find(blockList.begin(), blockList.end(), dom); + if (lastBlockIter == blockList.end()) { + lastUser = &(dom->front()); + } else { + for (size_t i = 0; i < userList.size(); ++i) { + if (dom != userList[i]->getBlock()) { + continue; + } + + if (lastUser == nullptr) { + lastUser = userList[i]; + continue; + } + + if (lastUser->isBeforeInBlock(userList[i])) { + lastUser = userList[i]; + } + } + } + + if (lastUser && isMustAliasOp(lastUser)) { + lastUser = nullptr; + } + return lastUser; +} + +mlir::Operation *getFirstUserOfOp(mlir::Operation *op, DominanceInfo &domInfo) { + auto opRegion = op->getParentRegion(); + + std::vector userList; + std::vector blockList; + std::vector aliasList; + + getUsersWithAlias(op, domInfo, getFirstUserOfOp, userList, blockList, + aliasList); + + // Analysis alias op + while (!aliasList.empty()) { + std::vector tmpList(aliasList.size()); + std::copy(aliasList.begin(), aliasList.end(), tmpList.begin()); + aliasList.clear(); + for (auto tmp : tmpList) { + getUsersWithAlias(tmp, domInfo, getFirstUserOfOp, userList, blockList, + aliasList); + } + } + + if (blockList.empty()) + return nullptr; + + Block *dom = domInfo.findNearestCommonDominator(blockList); + if (dom == nullptr) { + llvm::report_fatal_error("cannot find nearest common dominator block"); + } + + if (dom->empty()) { + llvm::report_fatal_error("dominator block is empty"); + } + + mlir::Operation *firstUser = nullptr; + auto firstBlockIter = std::find(blockList.begin(), blockList.end(), dom); + if (firstBlockIter == blockList.end()) { + firstUser = &(dom->back()); + } else { + for (size_t i = 0; i < userList.size(); ++i) { + if (dom == userList[i]->getBlock()) { + if (firstUser == nullptr) { + firstUser = userList[i]; + continue; + } + if (userList[i]->isBeforeInBlock(firstUser)) { + firstUser = userList[i]; + } + } + } + } + + if (firstUser && isMustAliasOp(firstUser)) { + firstUser = nullptr; + } + + if (firstUser == nullptr) { + return nullptr; + } + + auto nextOp = op->getNextNode(); + while (isa(nextOp)) { + nextOp = nextOp->getNextNode(); + } + if (nextOp == firstUser) { + firstUser = nullptr; + } + return firstUser; +} + +} // namespace + +mlir::Operation *FirstLastUserAnalysis::getLastUserOp(mlir::Value value, + mlir::Region *opRegion) { + std::vector userList; + std::vector blockList; + std::vector aliasList; + + for (auto user : value.getUsers()) { + if (llvm::isa(user)) { + llvm::report_fatal_error("double free please checek IR"); + } + + if (user->getParentRegion() == opRegion) { + userList.push_back(user); + blockList.push_back(user->getBlock()); + if (isMustAliasOp(user)) { + aliasList.push_back(user); + } + } else { + auto parent = user->getParentOp(); + auto curUser = user; + bool mayAlias = llvm::isa(parent) && + llvm::isa(curUser); + + while ((!isa(parent)) && + (!isa(parent))) { + if (parent->getParentRegion() == opRegion) + break; + + curUser = mayAlias ? getLastUserOfOp(parent, postDominators) : nullptr; + parent = parent->getParentOp(); + mayAlias = llvm::isa(parent) && + llvm::isa_and_nonnull(curUser); + } + + if (parent->getParentRegion() != opRegion) { + parent->dump(); + value.dump(); + llvm_unreachable("invalid user please checek IR 1"); + } + userList.push_back(parent); + blockList.push_back(parent->getBlock()); + if (mayAlias) { + aliasList.push_back(parent); + } + } + } + + // Analysis alias op + while (!aliasList.empty()) { + std::vector tmpList(aliasList.size()); + std::copy(aliasList.begin(), aliasList.end(), tmpList.begin()); + aliasList.clear(); + for (auto tmp : tmpList) { + getUsersWithAlias(tmp, postDominators, getLastUserOfOp, userList, + blockList, aliasList); + } + } + + if (blockList.empty()) + return nullptr; + + Block *dom = postDominators.findNearestCommonDominator(blockList); + + /** B0 + // / \ + // B1 <- B2 + // | + // B3 + // 1). B1 and B3 has a "return" op. + // 2). B2 and B3 has a use for a alloc op which locate in B0. At the time, + // the "postDominators.findNearestCommonDominator" return nullptr + **/ + if (dom == nullptr) { + auto lastBlock = blockList[0]; + for (auto iter = opRegion->rbegin(); iter != opRegion->rend(); iter++) { + auto index = std::find(blockList.begin(), blockList.end(), &(*iter)); + if (index != blockList.end()) { + lastBlock = &(*iter); + break; + } + } + dom = lastBlock; + } + + if (dom->empty()) { + llvm::report_fatal_error("dominator block is empty"); + } + + mlir::Operation *lastUser = nullptr; + + // Block dom maybe is not in the blockList + auto lastBlockIter = std::find(blockList.begin(), blockList.end(), dom); + if (lastBlockIter == blockList.end()) { + lastUser = &(dom->front()); + } else { + for (size_t i = 0; i < userList.size(); ++i) { + if (dom != userList[i]->getBlock()) { + continue; + } + + if (lastUser == nullptr) { + lastUser = userList[i]; + continue; + } + + if (lastUser->isBeforeInBlock(userList[i])) { + lastUser = userList[i]; + } + } + } + + if (lastUser && isMustAliasOp(lastUser)) { + lastUser = nullptr; + } + return lastUser; +} + +void FirstLastUserAnalysis::start() { + assert(llvm::isa(moduleOp) && + "The input operation is not a gpu module"); + moduleOp->walk([&](mlir::Operation *_op) { + if (_op->getResults().empty()) + return; + + if (llvm::isa(_op) && + llvm::any_of(_op->getResultTypes(), llvm::IsaPred)) { + lastUserMap[_op] = getLastUserOfOp(_op, postDominators); + } else if (llvm::isa< + scf::IfOp, scf::IndexSwitchOp, scf::WhileOp, scf::ForOp, + triton::SplatOp, arith::ConstantOp, triton::AddPtrOp, + triton::PtrToIntOp, triton::IntToPtrOp, + triton::gcu::PtrToIntOp, triton::gcu::IntToPtrOp, + triton::MulhiUIOp, triton::ScanOp, triton::HistogramOp, + triton::gcu::LoadOp, triton::LoadOp, triton::BroadcastOp, + triton::ExpandDimsOp, triton::ReshapeOp, triton::SplitOp, + triton::JoinOp, triton::CatOp, triton::gcu::MatmulOp, + triton::DotOp, triton::ReduceOp, triton::MakeRangeOp, + triton::BitcastOp, triton::gcu::ElementwiseFusionRegionOp>( + _op)) { + lastUserMap[_op] = getLastUserOfOp(_op, postDominators); + } else if (llvm::isa(_op)) { + lastUserMap[_op] = getLastUserOfOp(_op, postDominators); + firstUserMap[_op] = getFirstUserOfOp(_op, dominators); + } + }); +} +} // namespace gcu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/MaskAnalysis.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/MaskAnalysis.cpp new file mode 100644 index 000000000..34f0888aa --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/MaskAnalysis.cpp @@ -0,0 +1,574 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "Analysis/MaskAnalysis.h" + +#include "Analysis/OpFoldResultUtils.h" +#include "llvm/Support/Debug.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#define DEBUG_TYPE "mask-analysis" + +namespace mlir { +namespace triton { +namespace gcu { + +static int64_t kIndentSpaceNum = 0; + +static void printBeforeVisit(Operation *op) { + auto spaces = std::string(kIndentSpaceNum, ' '); + kIndentSpaceNum += 3; + + LLVM_DEBUG({ + llvm::dbgs() << spaces << "=== visit operand of " << op->getName() + << " ENTER === \n"; + llvm::dbgs() << spaces; + op->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); +} + +static void printAfterVisit(Operation *op) { + kIndentSpaceNum -= 3; + auto spaces = std::string(kIndentSpaceNum, ' '); + + LLVM_DEBUG({ + llvm::dbgs() << spaces << "=== visit operand of " << op->getName() + << " EXIT ===\n"; + }); + + if (kIndentSpaceNum == 0) { + LLVM_DEBUG(llvm::dbgs() << "\n"); + } +} + +void MaskState::addStateScalar(OpBuilder &builder, Location loc, + const MaskState &state, + const OpFoldResult scalar) { + this->start = addOFRs(builder, loc, state.start, scalar); + this->end = addOFRs(builder, loc, state.end, scalar); + this->dims = state.dims; +} + +void MaskState::addStates(OpBuilder &builder, Location loc, + const MaskState &lhsState, + const MaskState &rhsState) { + assert(((lhsState.scalar && !rhsState.scalar) || + (!lhsState.scalar && rhsState.scalar) || + (lhsState.scalar && rhsState.scalar)) && + "unsupported scenario where neither lhs nor rhs is a scalar\n"); + + if (lhsState.scalar && rhsState.scalar) { + this->scalar = addOFRs(builder, loc, lhsState.scalar, rhsState.scalar); + this->dims = lhsState.getRank() != 0 ? lhsState.dims : rhsState.dims; + return; + } + + if (lhsState.scalar) + addStateScalar(builder, loc, rhsState, lhsState.scalar); + else + addStateScalar(builder, loc, lhsState, rhsState.scalar); +} + +void MaskState::minStates(OpBuilder &builder, Location loc, + const MaskState &lhsState, + const MaskState &rhsState) { + assert((lhsState.getRank() == rhsState.getRank()) && + "unexpected case where lhs and rhs have different ranks"); + + for (int64_t i = 0; i < lhsState.getRank(); ++i) { + auto lhsDim = lhsState.dims[i]; + auto rhsDim = rhsState.dims[i]; + this->dims.push_back(minOFRs(builder, loc, lhsDim, rhsDim)); + } +} + +void MaskState::setStates(OpBuilder &builder, Location loc, + const MaskState &srcState) { + if (srcState.start) + this->start = srcState.start; + if (srcState.end) + this->end = srcState.end; + if (srcState.scalar) + this->scalar = srcState.scalar; + + for (int64_t i = 0; i < srcState.getRank(); ++i) { + this->dims.push_back(srcState.dims[i]); + } +} + +void MaskAnalysis::parse(OpBuilder &builder, Location loc, Value operand, + MaskState &state, + llvm::SmallDenseMap &knownMasks) { + LLVM_DEBUG(llvm::dbgs() << std::string(kIndentSpaceNum, ' ') << "enter parse " + << operand << "\n"); + if (knownMasks.find(operand) != knownMasks.end()) { + state = knownMasks.lookup(operand); + LLVM_DEBUG(llvm::dbgs() << std::string(kIndentSpaceNum, ' ') + << "operand is a known mask " << operand << "\n"); + return; + } + + if (isa(operand.getType())) { + LLVM_DEBUG(llvm::dbgs() << std::string(kIndentSpaceNum, ' ') + << "operand is an integer " << operand << "\n"); + parseIntScalar(builder, loc, operand, state, knownMasks); + return; + } + + if (auto arg = dyn_cast(operand)) { + LLVM_DEBUG(llvm::dbgs() << std::string(kIndentSpaceNum, ' ') + << "operand is block argument " << arg << "\n"); + parseBlockArgument(builder, loc, arg, state, knownMasks); + return; + } + + if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseConstant(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseAdd(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseAnd(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseCmp(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseMakeRange(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseBroadcast(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseSplat(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseExpandDims(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseDot(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseRemsi(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseSelect(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseReduce(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseLoad(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseExtsi(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + parseExtui(builder, loc, op, state, knownMasks); + printAfterVisit(op); + } else { + operand.dump(); + llvm_unreachable("unexpected to parse the operand\n"); + } +} + +void MaskAnalysis::parseBlockArgument( + OpBuilder &builder, Location loc, BlockArgument blockArg, MaskState &state, + llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + assert(isa(blockArg.getOwner()->getParentOp())); + + auto forOp = cast(blockArg.getOwner()->getParentOp()); + + if (blockArg.getArgNumber() == 0) { + auto castOp = builder.create( + loc, builder.getIndexType(), forOp.getInductionVar()); + state.scalar = castOp.getResult(); + } else { + auto regionIterIndex = + blockArg.getArgNumber() - forOp.getNumInductionVars(); + Value regionArg = forOp.getRegionIterArgs()[regionIterIndex]; + assert(knownMasks.count(regionArg) != 0 && + "can't find value in knownMasks"); + state = knownMasks.lookup(regionArg); + } +} + +void MaskAnalysis::parseConstant( + OpBuilder &builder, Location loc, arith::ConstantOp constOp, + MaskState &state, llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + + if (isa(constOp.getValue())) { + auto attr = cast(constOp.getValue()); + auto elementType = attr.getElementType(); + (void)elementType; + assert(attr.isSplat() && isa(elementType) && + "all elements must share a single integer constant value"); + auto values = attr.getValues(); + auto value = values[0].getValue(); + state.scalar = builder.getIndexAttr(value.getSExtValue()); + + auto dst = constOp.getResult(); + auto dstShape = cast(dst.getType()).getShape(); + for (auto s : dstShape) + state.dims.push_back(builder.getIndexAttr(s)); + } else { + auto value = cast(constOp.getValue()).getInt(); + state.scalar = builder.getIndexAttr(value); + } +} + +void MaskAnalysis::parseIntScalar( + OpBuilder &builder, Location loc, Value scalar, MaskState &state, + llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + // scalar is the first argument of ForOp, which means it is an induction var + // if (auto arg = dyn_cast(scalar)) { + // if (isa(arg.getOwner()->getParentOp()) && + // arg.getArgNumber() == 0) { + // auto forOp = cast(arg.getOwner()->getParentOp()); + // auto castOp = builder.create( + // loc, builder.getIndexType(), forOp.getInductionVar()); + // state.scalar = castOp.getResult(); + // return; + // } + // } + auto castOp = + builder.create(loc, builder.getIndexType(), scalar); + state.scalar = castOp.getResult(); +} + +void MaskAnalysis::parseAdd(OpBuilder &builder, Location loc, + arith::AddIOp addOp, MaskState &state, + llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + + MaskState lhsState; + parse(builder, loc, addOp.getLhs(), lhsState, knownMasks); + assert(!lhsState.isEmpty()); + + MaskState rhsState; + parse(builder, loc, addOp.getRhs(), rhsState, knownMasks); + assert(!rhsState.isEmpty()); + + state.addStates(builder, loc, lhsState, rhsState); +} + +void MaskAnalysis::parseAnd(OpBuilder &builder, Location loc, + arith::AndIOp andOp, MaskState &state, + llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + + MaskState lhsState; + parse(builder, loc, andOp.getLhs(), lhsState, knownMasks); + assert(!lhsState.isEmpty()); + + MaskState rhsState; + parse(builder, loc, andOp.getRhs(), rhsState, knownMasks); + assert(!rhsState.isEmpty()); + + state.minStates(builder, loc, lhsState, rhsState); + if (lhsState.start && rhsState.start) + state.start = maxOFRs(builder, loc, lhsState.start, rhsState.start); + else if (lhsState.start) + state.start = lhsState.start; + else if (rhsState.start) + state.start = rhsState.start; +} + +void MaskAnalysis::parseCmp(OpBuilder &builder, Location loc, + arith::CmpIOp cmpOp, MaskState &state, + llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + + assert(cmpOp.getPredicate() == arith::CmpIPredicate::slt || + cmpOp.getPredicate() == arith::CmpIPredicate::ult || + cmpOp.getPredicate() == arith::CmpIPredicate::sge || + cmpOp.getPredicate() == arith::CmpIPredicate::uge); + + MaskState lhsState; + parse(builder, loc, cmpOp.getLhs(), lhsState, knownMasks); + assert(!lhsState.isEmpty()); + + MaskState rhsState; + parse(builder, loc, cmpOp.getRhs(), rhsState, knownMasks); + assert(!rhsState.isEmpty()); + + // Process 1x1 tensor + bool allOnes = true; + for (int64_t i = 0; i < lhsState.getRank(); ++i) { + auto dimIntAttr = getIntAttr(lhsState.dims[i]); + allOnes &= (dimIntAttr && (dimIntAttr.value() == 1)); + } + if (allOnes) { + assert((lhsState.scalar && rhsState.scalar) && "unsupported cmpi scenario"); + + arith::CmpIOp cmpiOp = builder.create( + loc, cmpOp.getPredicate(), getValue(builder, loc, lhsState.scalar), + getValue(builder, loc, rhsState.scalar)); + + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + Value result = builder.create(loc, cmpiOp, one, zero); + for (int64_t i = 0; i < lhsState.getRank(); ++i) { + state.dims.push_back(result); + } + return; + } + + assert((!lhsState.scalar && rhsState.scalar) && "unsupported cmpi scenario"); + int64_t cmpDim = -1; + for (int64_t i = 0; i < lhsState.getRank(); ++i) { + auto dimIntAttr = getIntAttr(lhsState.dims[i]); + if (!dimIntAttr || dimIntAttr.value() != 1) { + assert((cmpDim == -1) && "unsupported cmpi with more than one " + "dimension with size larger than 1"); + cmpDim = i; + } + } + assert(cmpDim != -1 && + "unexpected case where no dimension has size larger than 1"); + + auto newDim = lhsState.dims[cmpDim]; + if (cmpOp.getPredicate() == arith::CmpIPredicate::slt || + cmpOp.getPredicate() == arith::CmpIPredicate::ult) { + auto newEnd = minOFRs(builder, loc, lhsState.end, rhsState.scalar); + newDim = subOFRs(builder, loc, newEnd, lhsState.start); + } else { + auto newstart = maxOFRs(builder, loc, lhsState.start, rhsState.scalar); + state.start = newstart; + } + + for (int64_t i = 0; i < lhsState.getRank(); ++i) { + if (i == cmpDim) + state.dims.push_back(newDim); + else + state.dims.push_back(lhsState.dims[i]); + } +} + +void MaskAnalysis::parseMakeRange( + OpBuilder &builder, Location loc, triton::MakeRangeOp rangeOp, + MaskState &state, llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + + auto shape = cast(rangeOp.getType()).getShape(); + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + (void)stride; + + assert((stride == 1) && + "stride must be 1 for make_range whose result is used " + "as load or store masks"); + + state.start = builder.getIndexAttr(start); + state.end = builder.getIndexAttr(end); + state.dims.push_back(builder.getIndexAttr(shape[0])); +} + +void MaskAnalysis::parseBroadcast( + OpBuilder &builder, Location loc, triton::BroadcastOp broadcastOp, + MaskState &state, llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + + auto src = broadcastOp.getSrc(); + auto dst = broadcastOp.getResult(); + assert(isa(src.getType()) && + "input to tt.broadcast should be a tensor"); + + auto srcShape = cast(src.getType()).getShape(); + auto dstShape = cast(dst.getType()).getShape(); + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + parse(builder, loc, src, state, knownMasks); + + for (uint64_t i = 0; i < srcShape.size(); ++i) { + if (srcShape[i] == dstShape[i]) + continue; + else if (srcShape[i] < dstShape[i]) + state.dims[i] = builder.getIndexAttr(dstShape[i]); + else + llvm_unreachable("unexpected dimensions used in broadcast\n"); + } +} + +void MaskAnalysis::parseSplat( + OpBuilder &builder, Location loc, triton::SplatOp splatOp, MaskState &state, + llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + + assert(isa(splatOp.getSrc().getType()) && + "splat source must be an integer scalar for load/store masks"); + + auto src = splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = cast(dst.getType()).getShape(); + + // splat bool means either all or none are masked + if (src.getType().getIntOrFloatBitWidth() == 1) { + Value zero = builder.create(loc, 0); + for (auto s : dstShape) { + Value dim = builder.create(loc, s); + Value result = builder.create(loc, src, dim, zero); + state.dims.push_back(result); + } + } else { + parse(builder, loc, src, state, knownMasks); + for (auto s : dstShape) + state.dims.push_back(builder.getIndexAttr(s)); + } +} + +void MaskAnalysis::parseExpandDims( + OpBuilder &builder, Location loc, triton::ExpandDimsOp expandDimsOp, + MaskState &state, llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + + parse(builder, loc, expandDimsOp.getSrc(), state, knownMasks); + auto dstShape = + cast(expandDimsOp.getResult().getType()).getShape(); + auto axis = expandDimsOp.getAxis(); + (void)dstShape; + assert(dstShape[axis] == 1 && + "expect changed dimension to be 1 in expand_dims"); + state.dims.insert(state.dims.begin() + axis, builder.getIndexAttr(1)); +} + +void MaskAnalysis::parseDot(OpBuilder &builder, Location loc, + triton::DotOp dotOp, MaskState &state, + llvm::SmallDenseMap &knownMasks) { + MaskState srcState; + parse(builder, loc, dotOp.getC(), srcState, knownMasks); + assert(!srcState.isEmpty()); + + state.start = srcState.start; + state.end = srcState.end; + state.scalar = srcState.scalar; + for (int64_t i = 0; i < srcState.getRank(); ++i) + state.dims.push_back(srcState.dims[i]); +} + +void MaskAnalysis::parseRemsi( + OpBuilder &builder, Location loc, arith::RemSIOp RemSIOp, MaskState &state, + llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + + MaskState lhsState; + parse(builder, loc, RemSIOp.getLhs(), lhsState, knownMasks); + assert(!lhsState.isEmpty()); + + MaskState rhsState; + parse(builder, loc, RemSIOp.getRhs(), rhsState, knownMasks); + assert(!rhsState.isEmpty()); + + state.addStates(builder, loc, lhsState, rhsState); +} + +void MaskAnalysis::parseSelect( + OpBuilder &builder, Location loc, arith::SelectOp SelectOp, + MaskState &state, llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + MaskState trueState; + parse(builder, loc, SelectOp.getTrueValue(), trueState, knownMasks); + assert(!trueState.isEmpty()); + + state.setStates(builder, loc, trueState); +} + +void MaskAnalysis::parseReduce( + OpBuilder &builder, Location loc, triton::ReduceOp ReduceOp, + MaskState &state, llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + auto src = ReduceOp.getSrcs()[0]; + auto axis = ReduceOp.getAxis(); + + MaskState srcState; + parse(builder, loc, src, srcState, knownMasks); + + if (srcState.start) + state.start = srcState.start; + if (srcState.end) + state.end = srcState.end; + if (srcState.scalar) + state.scalar = srcState.scalar; + for (uint32_t i = 0; i < srcState.dims.size(); ++i) { + if (i != axis) { + state.dims.push_back(srcState.dims[i]); + } + } +} + +void MaskAnalysis::parseLoad( + OpBuilder &builder, Location loc, triton::LoadOp LoadOp, MaskState &state, + llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + + MaskState srcState; + parse(builder, loc, LoadOp.getPtr(), srcState, knownMasks); + assert(!srcState.isEmpty()); + + state.setStates(builder, loc, srcState); +} + +void MaskAnalysis::parseExtsi( + OpBuilder &builder, Location loc, arith::ExtSIOp ExtSIOp, MaskState &state, + llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + MaskState inState; + + parse(builder, loc, ExtSIOp.getIn(), inState, knownMasks); + assert(!inState.isEmpty()); + + state.setStates(builder, loc, inState); +} + +void MaskAnalysis::parseExtui( + OpBuilder &builder, Location loc, arith::ExtUIOp ExtUIOp, MaskState &state, + llvm::SmallDenseMap &knownMasks) { + assert(state.isEmpty()); + MaskState inState; + + parse(builder, loc, ExtUIOp.getIn(), inState, knownMasks); + assert(!inState.isEmpty()); + + state.setStates(builder, loc, inState); +} + +} // namespace gcu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/OpFoldResultUtils.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/OpFoldResultUtils.cpp new file mode 100644 index 000000000..1e2008b0e --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/OpFoldResultUtils.cpp @@ -0,0 +1,335 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "Analysis/OpFoldResultUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace gcu { + +std::optional getIntAttr(const OpFoldResult ofr) { + if (llvm::isa(ofr) && isa(ofr.get())) + return dyn_cast(ofr.get()).getInt(); + + return std::nullopt; +} + +Value getValue(OpBuilder &builder, Location loc, const OpFoldResult ofr) { + if (auto attr = getIntAttr(ofr)) { + return builder.create(loc, attr.value()) + .getResult(); + } else { + return dyn_cast(ofr); + } +} + +llvm::SmallVector getValues(OpBuilder &builder, Location loc, + const llvm::SmallVector &ofr) { + llvm::SmallVector values; + for (uint64_t i = 0; i < ofr.size(); ++i) { + values.push_back(getValue(builder, loc, ofr[i])); + } + return values; +} + +// Extract a scalar value from v. +// If v is a scalar, return that directly. Otherwise, parse through operations +// (currently only support splat and sitofp) that produce it and to extract they +// underlying scalar value . If no scalar value can be extracted, a nullptr is +// returned. +std::optional getScalarValue(OpBuilder &builder, Location loc, Value v) { + // Record if an sitofp op was in the chain of ops that produce the scalar + Operation *siToFp = nullptr; + + while (true) { + if (!dyn_cast(v.getType())) { + break; + } + + if (auto op = v.getDefiningOp()) { + if (auto attr = dyn_cast(op.getValue())) { + if (!attr.isSplat()) { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load " + "produced by unsupported instruction"; + return nullptr; + } + + auto typedAttr = attr.getSplatValue(); + v = builder.create(loc, attr.getElementType(), + typedAttr); + } + } else if (auto op = v.getDefiningOp()) { + v = op.getSrc(); + } else if (auto op = v.getDefiningOp()) { + siToFp = op; + v = op.getIn(); + } else { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; + } + } + + if (siToFp) { + auto resType = siToFp->getResult(0).getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return builder.create(loc, resType, v); + } + + return v; +} + +OpFoldResult addOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs) { + auto lhsIntAttr = getIntAttr(lhs); + auto rhsIntAttr = getIntAttr(rhs); + + // shortcut for special cases + if (!lhsIntAttr && rhsIntAttr && rhsIntAttr.value() == 0) + return lhs; + if (!rhsIntAttr && lhsIntAttr && lhsIntAttr.value() == 0) + return rhs; + + // both lhs and rhs are constants, return result directly + if (lhsIntAttr && rhsIntAttr) + return builder.getIndexAttr(lhsIntAttr.value() + rhsIntAttr.value()); + + // otherwise, need to create instructions to calculate new attribute value + auto lhsValue = dyn_cast(lhs); + if (lhsIntAttr) { + auto lhsOp = builder.create( + loc, builder.getIndexAttr(lhsIntAttr.value())); + lhsValue = lhsOp.getResult(); + } else { + assert(isa(lhsValue.getType())); + } + + auto rhsValue = dyn_cast(rhs); + if (rhsIntAttr) { + auto rhsOp = builder.create( + loc, builder.getIndexAttr(rhsIntAttr.value())); + rhsValue = rhsOp.getResult(); + } else { + assert(isa(lhsValue.getType())); + } + + return builder.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult subOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs) { + auto lhsIntAttr = getIntAttr(lhs); + auto rhsIntAttr = getIntAttr(rhs); + + // shortcut for special cases + if (!lhsIntAttr && rhsIntAttr && rhsIntAttr.value() == 0) + return lhs; + + // both lhs and rhs are constants, return result directly + if (lhsIntAttr && rhsIntAttr) + return builder.getIndexAttr(lhsIntAttr.value() - rhsIntAttr.value()); + + // otherwise, need to create instructions to calculate new attribute value + auto lhsValue = dyn_cast(lhs); + if (lhsIntAttr) { + auto lhsOp = builder.create( + loc, builder.getIndexAttr(lhsIntAttr.value())); + lhsValue = lhsOp.getResult(); + } + + auto rhsValue = dyn_cast(rhs); + if (rhsIntAttr) { + auto rhsOp = builder.create( + loc, builder.getIndexAttr(rhsIntAttr.value())); + rhsValue = rhsOp.getResult(); + } + + auto sumOp = builder.create(loc, lhsValue, rhsValue); + return sumOp.getResult(); +} + +OpFoldResult mulOFRValue(OpBuilder &builder, Location loc, + const OpFoldResult lhs, const Value rhs) { + auto lhsIntAttr = getIntAttr(lhs); + + auto rhsIsConst = false; + // if rhs is not a const, use max value since min is used to represent + // dynamic size or stride + auto rhsConstValue = std::numeric_limits::max(); + auto rhsOp = rhs.getDefiningOp(); + if (rhsOp) { + rhsIsConst = true; + rhsConstValue = cast(rhsOp.getValue()).getInt(); + } + + // shortcuts for special cases + if (lhsIntAttr) { + if (lhsIntAttr.value() == 0) + return lhs; + if (lhsIntAttr.value() == 1) + return rhs; + } + if (rhsIsConst) { + if (rhsConstValue == 0) + return rhsOp.getResult(); + if (rhsConstValue == 1) + return lhs; + } + + // 0. both lhs and rhs are constants + if (lhsIntAttr && rhsIsConst) + return builder.getIndexAttr(lhsIntAttr.value() * rhsConstValue); + + // 1. if lhs is constant but rhs is not + if (lhsIntAttr && !rhsIsConst) { + auto lhsConstOp = builder.create( + loc, builder.getIndexAttr(lhsIntAttr.value())); + auto mulOp = + builder.create(loc, lhsConstOp.getResult(), rhs); + return mulOp.getResult(); + } + + // 2. if lhs is not constant + assert(!lhsIntAttr); + auto mulOp = builder.create(loc, lhs.get(), rhs); + return mulOp.getResult(); +} + +OpFoldResult minOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs) { + auto lhsIntAttr = getIntAttr(lhs); + auto rhsIntAttr = getIntAttr(rhs); + + // both lhs and rhs are constants, return result directly + if (lhsIntAttr && rhsIntAttr) + return builder.getIndexAttr( + std::min(lhsIntAttr.value(), rhsIntAttr.value())); + + // otherwise, need to create instructions to calculate new attribute value + auto lhsValue = dyn_cast(lhs); + if (lhsIntAttr) { + auto lhsOp = builder.create( + loc, builder.getIndexAttr(lhsIntAttr.value())); + lhsValue = lhsOp.getResult(); + } + + auto rhsValue = dyn_cast(rhs); + if (rhsIntAttr) { + auto rhsOp = builder.create( + loc, builder.getIndexAttr(rhsIntAttr.value())); + rhsValue = rhsOp.getResult(); + } + + auto minOp = builder.create(loc, lhsValue, rhsValue); + return minOp.getResult(); +} + +OpFoldResult maxOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs) { + auto lhsIntAttr = getIntAttr(lhs); + auto rhsIntAttr = getIntAttr(rhs); + + // both lhs and rhs are constants, return result directly + if (lhsIntAttr && rhsIntAttr) + return builder.getIndexAttr( + std::max(lhsIntAttr.value(), rhsIntAttr.value())); + + // otherwise, need to create instructions to calculate new attribute value + auto lhsValue = dyn_cast(lhs); + if (lhsIntAttr) { + auto lhsOp = builder.create( + loc, builder.getIndexAttr(lhsIntAttr.value())); + lhsValue = lhsOp.getResult(); + } + + auto rhsValue = dyn_cast(rhs); + if (rhsIntAttr) { + auto rhsOp = builder.create( + loc, builder.getIndexAttr(rhsIntAttr.value())); + rhsValue = rhsOp.getResult(); + } + + auto maxOp = builder.create(loc, lhsValue, rhsValue); + return maxOp.getResult(); +} + +OpFoldResult remOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs) { + auto lhsIntAttr = getIntAttr(lhs); + auto rhsIntAttr = getIntAttr(rhs); + + // both lhs and rhs are constants, return result directly + if (lhsIntAttr && rhsIntAttr) + return builder.getIndexAttr(lhsIntAttr.value() % rhsIntAttr.value()); + + // otherwise, need to create instructions to calculate new attribute value + auto lhsValue = dyn_cast(lhs); + if (lhsIntAttr) { + auto lhsOp = builder.create( + loc, builder.getIndexAttr(lhsIntAttr.value())); + lhsValue = lhsOp.getResult(); + } + + auto rhsValue = dyn_cast(rhs); + if (rhsIntAttr) { + auto rhsOp = builder.create( + loc, builder.getIndexAttr(rhsIntAttr.value())); + rhsValue = rhsOp.getResult(); + } + + auto remOp = builder.create(loc, lhsValue, rhsValue); + return remOp.getResult(); +} + +OpFoldResult divOFRs(OpBuilder &builder, Location loc, const OpFoldResult lhs, + const OpFoldResult rhs) { + auto lhsIntAttr = getIntAttr(lhs); + auto rhsIntAttr = getIntAttr(rhs); + + // both lhs and rhs are constants, return result directly + if (lhsIntAttr && rhsIntAttr) + return builder.getIndexAttr(lhsIntAttr.value() / rhsIntAttr.value()); + + // otherwise, need to create instructions to calculate new attribute value + auto lhsValue = dyn_cast(lhs); + if (lhsIntAttr) { + auto lhsOp = builder.create( + loc, builder.getIndexAttr(lhsIntAttr.value())); + lhsValue = lhsOp.getResult(); + } + + auto rhsValue = dyn_cast(rhs); + if (rhsIntAttr) { + auto rhsOp = builder.create( + loc, builder.getIndexAttr(rhsIntAttr.value())); + rhsValue = rhsOp.getResult(); + } + + auto divOp = builder.create(loc, lhsValue, rhsValue); + return divOp.getResult(); +} + +} // namespace gcu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/PtrAnalysis.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/PtrAnalysis.cpp new file mode 100644 index 000000000..b3cf21f5f --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Analysis/PtrAnalysis.cpp @@ -0,0 +1,1607 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "Analysis/PtrAnalysis.h" + +#include "Analysis/AxisInfoEx.h" +#include "Analysis/MaskAnalysis.h" +#include "Analysis/OpFoldResultUtils.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" + +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define DEBUG_TYPE "ptr-analysis" + +namespace mlir { +namespace triton { +namespace gcu { + +static llvm::DenseSet addedAssertOps; +static int64_t kIndentSpaceNum = 0; + +static void printBeforeVisit(Operation *op) { + auto spaces = std::string(kIndentSpaceNum, ' '); + kIndentSpaceNum += 3; + + LLVM_DEBUG({ + llvm::dbgs() << spaces << "=== visit operand of " << op->getName() + << " ENTER ===\n"; + llvm::dbgs() << spaces; + op->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); +} + +static void printAfterVisit(Operation *op) { + kIndentSpaceNum -= 3; + auto spaces = std::string(kIndentSpaceNum, ' '); + + LLVM_DEBUG({ + llvm::dbgs() << spaces << "=== visit operand of " << op->getName() + << " EXIT ===\n"; + }); + + if (kIndentSpaceNum == 0) { + LLVM_DEBUG(llvm::dbgs() << "\n"); + } +} + +int64_t PtrState::getRank() const { + assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); + return offsets.size(); +} + +bool PtrState::isEmpty() const { + return (getRank() == 0 && !source && !scalar); +} + +void PtrState::addState(OpBuilder &builder, Location loc, + const PtrState &lhsState, const PtrState &rhsState) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + + // at most one of lhs and rhs should have valid source, since otherwise we + // will be losing information + assert(!(lhsState.source && rhsState.source)); + this->source = lhsState.source ? lhsState.source : rhsState.source; + + if (lhsState.scalar && rhsState.scalar) { + auto addOp = + builder.create(loc, lhsState.scalar, rhsState.scalar); + this->scalar = addOp.getResult(); + } else if (lhsState.getRank() == 0) { + this->scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar; + } + + for (uint64_t i = 0; i < lhsState.sizes.size(); ++i) { + auto newOffset = + addOFRs(builder, loc, lhsState.offsets[i], rhsState.offsets[i]); + this->offsets.push_back(newOffset); + + auto newStride = + addOFRs(builder, loc, lhsState.strides[i], rhsState.strides[i]); + this->strides.push_back(newStride); + + this->sizes.push_back(lhsState.sizes[i]); + } +} + +void PtrState::mulState(OpBuilder &builder, Location loc, + const PtrState &lhsState, const PtrState &rhsState) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + + if (lhsState.scalar && rhsState.scalar) { + LLVM_DEBUG(llvm::dbgs() << "both PtrStates are scalars\n"); + for (uint64_t i = 0; i < lhsState.sizes.size(); ++i) { + this->offsets.push_back( + mulOFRValue(builder, loc, lhsState.offsets[i], rhsState.scalar)); + this->strides.push_back( + mulOFRValue(builder, loc, lhsState.strides[i], rhsState.scalar)); + this->sizes.push_back(lhsState.sizes[i]); + } + return; + } + + bool rhsScalar = true; + // neither lhs nor rhs should have source, since multiplying base pointer + // does not make sense + assert(!(lhsState.source && rhsState.source)); + this->source = lhsState.source ? lhsState.source : rhsState.source; + + assert((lhsState.scalar || rhsState.scalar) && + !(lhsState.scalar && rhsState.scalar) && + "currently does not support both tensors are effectively non-scalar"); + if (!rhsState.scalar && lhsState.scalar) + rhsScalar = false; + + for (uint64_t i = 0; i < lhsState.sizes.size(); ++i) { + OpFoldResult newOffset; + OpFoldResult newStride; + if (rhsScalar) { + newOffset = + mulOFRValue(builder, loc, lhsState.offsets[i], rhsState.scalar); + newStride = + mulOFRValue(builder, loc, lhsState.strides[i], rhsState.scalar); + } else { + newOffset = + mulOFRValue(builder, loc, rhsState.offsets[i], lhsState.scalar); + newStride = + mulOFRValue(builder, loc, rhsState.strides[i], lhsState.scalar); + } + this->offsets.push_back(newOffset); + this->strides.push_back(newStride); + this->sizes.push_back(lhsState.sizes[i]); + } +} + +void PtrState::remState(OpBuilder &builder, Location loc, + const PtrState &lhsState, const PtrState &rhsState) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + assert(!lhsState.source && !rhsState.source); + + this->source = lhsState.source; + for (uint64_t i = 0; i < lhsState.sizes.size(); ++i) { + auto newOffset = + remOFRs(builder, loc, lhsState.offsets[i], rhsState.offsets[i]); + this->offsets.push_back(newOffset); + + this->strides.push_back(lhsState.strides[i]); + this->sizes.push_back(lhsState.sizes[i]); + } +} + +void PtrState::divState(OpBuilder &builder, Location loc, + const PtrState &lhsState, const PtrState &rhsState) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + assert(!lhsState.source && !rhsState.source); + + this->source = lhsState.source; + for (uint64_t i = 0; i < lhsState.sizes.size(); ++i) { + auto newOffset = + divOFRs(builder, loc, lhsState.offsets[i], rhsState.offsets[i]); + this->offsets.push_back(newOffset); + + this->strides.push_back(lhsState.strides[i]); + this->sizes.push_back(lhsState.sizes[i]); + } +} + +void PtrState::setState(OpBuilder &builder, Location loc, + const PtrState &srcState) { + if (srcState.source) + this->source = srcState.source; + if (srcState.scalar) + this->scalar = srcState.scalar; + for (uint64_t i = 0; i < srcState.sizes.size(); ++i) { + this->offsets.push_back(srcState.offsets[i]); + this->strides.push_back(srcState.strides[i]); + this->sizes.push_back(srcState.sizes[i]); + } +} + +bool isZeroStride(OpBuilder &builder, Location loc, const OpFoldResult ofr) { + if (auto attr = getIntAttr(ofr)) { + return attr.value() == 0; + } + + // Assume stride can not be changed, if it is an argument of ForOp and it is + // value is 0, then take it as zero stride. + auto value = getValue(builder, loc, ofr); + if (auto arg = dyn_cast(value)) { + if (auto forOp = dyn_cast(arg.getOwner()->getParentOp())) { + auto argIndex = arg.getArgNumber() - forOp.getNumInductionVars(); + assert(argIndex < forOp.getInitArgs().size()); + + auto initArg = forOp.getInitArgs()[argIndex]; + if (auto constOp = + dyn_cast(initArg.getDefiningOp())) { + return constOp.value() == 0; + } + } + } + + return false; +} + +PtrInfo PtrState::getPtrInfo(OpBuilder &builder, Location loc, + const MaskState &mstate) { + PtrInfo ptrInfo; + + assert(isa(this->source.getType())); + auto elemType = + cast(this->source.getType()).getPointeeType(); + auto bpe = elemType.getIntOrFloatBitWidth() / 8; + + auto offsets = getValues(builder, loc, this->offsets); + auto strides = getValues(builder, loc, this->strides); + auto sizes = getValues(builder, loc, this->sizes); + + auto rank = getRank(); + auto zero = builder.create(loc, 0); + auto one = builder.create(loc, 1); + + auto addr = builder.create(loc, builder.getI64Type(), + this->source); + auto base = builder.create( + loc, addr, + builder.create( + loc, builder.create(loc, bpe, /*width=*/64), + builder.create(loc, builder.getI64Type(), + offsets[0]))); + if (rank == 1) { + ptrInfo.base = builder.create( + loc, PtrType::get(builder.getContext(), elemType), base.getResult()); + ptrInfo.shape.push_back( + mstate.isEmpty() ? sizes[0] : getValues(builder, loc, mstate.dims)[0]); + + ptrInfo.offsets.push_back(zero); + if (!isZeroStride(builder, loc, this->strides[0])) { + ptrInfo.strides.push_back(strides[0]); + } else { + ptrInfo.strides.push_back(one); + ptrInfo.broadcastDims.insert(0); + } + } else if (rank >= 2 && rank <= 4) { + for (int i = 1; i < rank; ++i) { + base = builder.create( + loc, base, + builder.create( + loc, builder.create(loc, bpe, /*width=*/64), + builder.create(loc, builder.getI64Type(), + offsets[i]))); + } + ptrInfo.base = builder.create( + loc, PtrType::get(builder.getContext(), elemType), base.getResult()); + for (int i = rank - 1; i >= 0; --i) { + ptrInfo.offsets.push_back(zero); + + auto shapeBegin = ptrInfo.shape.begin(); + if (!mstate.isEmpty()) { + ptrInfo.shape.insert(shapeBegin, + getValues(builder, loc, mstate.dims)[i]); + } else { + ptrInfo.shape.insert(shapeBegin, sizes[i]); + } + + auto strideBegin = ptrInfo.strides.begin(); + if (!isZeroStride(builder, loc, this->strides[i])) { + ptrInfo.strides.insert(strideBegin, strides[i]); + } else { + ptrInfo.broadcastDims.insert(i); + ptrInfo.strides.insert(strideBegin, zero); + } + } + } else { + // not support + assert(false && "not support rank >= 5"); + } + return ptrInfo; +} + +void PtrAnalysis::visitOperand( + PatternRewriter &rewriter, Location loc, Value operand, PtrState &state, + llvm::SmallDenseMap &knownPtrs) { + LLVM_DEBUG(llvm::dbgs() << std::string(kIndentSpaceNum, ' ') + << "enter visitOperand" << operand << "\n"); + if (knownPtrs.find(operand) != knownPtrs.end()) { + state = knownPtrs.lookup(operand); + LLVM_DEBUG(llvm::dbgs() << std::string(kIndentSpaceNum, ' ') + << "operand is a known ptr " << operand << "\n"); + return; + } + if (isa(operand.getType())) { + auto castOp = rewriter.create( + loc, rewriter.getIndexType(), operand); + state.scalar = castOp.getResult(); + LLVM_DEBUG(llvm::dbgs() << std::string(kIndentSpaceNum, ' ') + << "operand is an integer " << operand << "\n"); + return; + } + + if (auto arg = dyn_cast(operand)) { + LLVM_DEBUG(llvm::dbgs() << std::string(kIndentSpaceNum, ' ') + << "operand is block argument " << arg << "\n"); + visitBlockArgument(rewriter, loc, arg, state, knownPtrs); + return; + } + // Supported ops + if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandConstSplat(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandAdd(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandMul(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandRem(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandDiv(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandSelect(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandExtsi(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandExtui(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandMakeRange(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandBroadcast(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandSplat(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandExpandDims(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandAddptr(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandBitcast(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandTrans(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandDot(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandReduce(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else if (auto op = operand.getDefiningOp()) { + printBeforeVisit(op); + visitOperandLoad(rewriter, loc, op, state, knownPtrs); + printAfterVisit(op); + } else { + operand.dump(); + llvm_unreachable("unexpected to visit the operand\n"); + } +} + +void PtrAnalysis::visitBlockArgument( + PatternRewriter &rewriter, Location loc, BlockArgument blockArg, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + assert(!isa(blockArg.getOwner()->getParentOp())); + state.source = blockArg; +} + +void PtrAnalysis::visitOperandConstSplat( + PatternRewriter &rewriter, Location loc, arith::ConstantOp op, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + // this condition is to handle cases where tt.broadcast and tt.splat are + // folded + auto attr = cast(op.getValue()); + assert(attr.isSplat() && isa(attr.getElementType())); + auto values = attr.getValues(); + auto value = values[0].getValue(); + auto constAttr = rewriter.getIndexAttr(value.getSExtValue()); + auto constOp = rewriter.create( + loc, rewriter.getIndexType(), constAttr); + + state.scalar = constOp; + + auto resultType = cast(op.getResult().getType()); + for (uint64_t i = 0; i < resultType.getShape().size(); ++i) { + if (i == 0) { + state.offsets.push_back(constOp.getResult()); + } else { + state.offsets.push_back(rewriter.getIndexAttr(0)); + } + + state.sizes.push_back(rewriter.getIndexAttr(resultType.getShape()[i])); + state.strides.push_back(rewriter.getIndexAttr(0)); + } +} + +void PtrAnalysis::visitOperandAdd( + PatternRewriter &rewriter, Location loc, arith::AddIOp addOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + PtrState lhsState; + visitOperand(rewriter, loc, addOp.getLhs(), lhsState, knownPtrs); + + PtrState rhsState; + visitOperand(rewriter, loc, addOp.getRhs(), rhsState, knownPtrs); + + state.addState(rewriter, loc, lhsState, rhsState); +} + +void PtrAnalysis::visitOperandMul( + PatternRewriter &rewriter, Location loc, arith::MulIOp mulOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + PtrState lhsState; + visitOperand(rewriter, loc, mulOp.getLhs(), lhsState, knownPtrs); + + PtrState rhsState; + visitOperand(rewriter, loc, mulOp.getRhs(), rhsState, knownPtrs); + + state.mulState(rewriter, loc, lhsState, rhsState); +} + +void PtrAnalysis::visitOperandRem( + PatternRewriter &rewriter, Location loc, arith::RemSIOp remOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + PtrState lhsState; + visitOperand(rewriter, loc, remOp.getLhs(), lhsState, knownPtrs); + + PtrState rhsState; + visitOperand(rewriter, loc, remOp.getRhs(), rhsState, knownPtrs); + + state.remState(rewriter, loc, lhsState, rhsState); +} + +void PtrAnalysis::visitOperandDiv( + PatternRewriter &rewriter, Location loc, arith::DivSIOp divOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + PtrState lhsState; + visitOperand(rewriter, loc, divOp.getLhs(), lhsState, knownPtrs); + + PtrState rhsState; + visitOperand(rewriter, loc, divOp.getRhs(), rhsState, knownPtrs); + + state.divState(rewriter, loc, lhsState, rhsState); +} + +void PtrAnalysis::visitOperandSelect( + PatternRewriter &rewriter, Location loc, arith::SelectOp selectOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + PtrState trueState; + visitOperand(rewriter, loc, selectOp.getTrueValue(), trueState, knownPtrs); + + PtrState falseState; + visitOperand(rewriter, loc, selectOp.getFalseValue(), falseState, knownPtrs); + + // now selectop is bypass, the state is unuse; In the future, we will analyze + // it under certain constraints. + state.setState(rewriter, loc, trueState); +} + +void PtrAnalysis::visitOperandMakeRange( + PatternRewriter &rewriter, Location loc, triton::MakeRangeOp rangeOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + auto shape = cast(rangeOp.getType()).getShape(); + + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + + state.offsets.push_back(rewriter.getIndexAttr(start)); + state.sizes.push_back(rewriter.getIndexAttr(shape[0])); + state.strides.push_back(rewriter.getIndexAttr(stride)); +} + +void PtrAnalysis::visitOperandExpandDims( + PatternRewriter &rewriter, Location loc, triton::ExpandDimsOp expandDimsOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + visitOperand(rewriter, loc, expandDimsOp.getSrc(), state, knownPtrs); + + auto axis = expandDimsOp.getAxis(); + + assert( + cast(expandDimsOp.getResult().getType()).getShape()[axis] == + 1 && + "expect changed dimension to be 1 in expand_dims"); + + // insert dimension info + state.offsets.insert(state.offsets.begin() + axis, rewriter.getIndexAttr(0)); + state.sizes.insert(state.sizes.begin() + axis, rewriter.getIndexAttr(1)); + state.strides.insert(state.strides.begin() + axis, rewriter.getIndexAttr(0)); +} + +void PtrAnalysis::visitOperandBroadcast( + PatternRewriter &rewriter, Location loc, triton::BroadcastOp broadcastOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + auto src = broadcastOp.getSrc(); + auto dst = broadcastOp.getResult(); + assert(isa(src.getType()) && + "input to tt.broadcast should be a tensor"); + + auto srcShape = cast(src.getType()).getShape(); + auto dstShape = cast(dst.getType()).getShape(); + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + visitOperand(rewriter, loc, src, state, knownPtrs); + + for (uint64_t i = 0; i < srcShape.size(); ++i) { + if (srcShape[i] == dstShape[i]) + continue; + else if (srcShape[i] < dstShape[i]) + state.sizes[i] = rewriter.getIndexAttr(dstShape[i]); + else + llvm_unreachable("unexpected dimensions used in broadcast"); + } +} + +void PtrAnalysis::visitOperandSplat( + PatternRewriter &rewriter, Location loc, triton::SplatOp splatOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + auto src = splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = cast(dst.getType()).getShape(); + + visitOperand(rewriter, loc, src, state, knownPtrs); + + if (isa(src.getType()) || + isa(src.getType())) { + for (auto s : dstShape) { + state.offsets.push_back(rewriter.getIndexAttr(0)); + state.sizes.push_back(rewriter.getIndexAttr(s)); + state.strides.push_back(rewriter.getIndexAttr(0)); + } + } else { + llvm_unreachable("unexpected src type used in splat"); + } + + // If we splat a integer value, scalar should become the offset of the outer + // most dimension + if (state.scalar) { + state.offsets[0] = state.scalar; + } +} + +void PtrAnalysis::visitOperandAddptr( + PatternRewriter &rewriter, Location loc, triton::AddPtrOp addptrOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + + PtrState ptrState; + visitOperand(rewriter, loc, addptrOp.getPtr(), ptrState, knownPtrs); + + PtrState offsetState; + visitOperand(rewriter, loc, addptrOp.getOffset(), offsetState, knownPtrs); + + assert(ptrState.source && "ptr field should provide source / base pointer"); + assert(ptrState.getRank() == offsetState.getRank() && + "ptr and offset field should have the same rank"); + + state.addState(rewriter, loc, ptrState, offsetState); +} + +void PtrAnalysis::visitOperandBitcast( + PatternRewriter &rewriter, Location loc, triton::BitcastOp bitcastOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + auto src = bitcastOp.getSrc(); + auto dst = bitcastOp.getResult(); + if ((!state.source) && isa(src.getType())) { + assert(isa(dst.getType()) && + "input is pointer output should also a pointer"); + PtrState srcState; + visitOperand(rewriter, loc, src, srcState, knownPtrs); + for (uint64_t i = 0; i < srcState.sizes.size(); ++i) { + state.offsets.push_back(srcState.offsets[i]); + state.sizes.push_back(srcState.sizes[i]); + state.strides.push_back(srcState.strides[i]); + } + state.scalar = srcState.scalar; + + auto newBitcastOp = + rewriter.create(loc, dst.getType(), srcState.source); + state.source = newBitcastOp.getResult(); + return; + } + assert(isa(src.getType()) && + "input to tt.bitcast should be a tensor"); + + PtrState srcState; + visitOperand(rewriter, loc, src, srcState, knownPtrs); + + auto dstElemType = cast(dst.getType()).getElementType(); + assert(isa( + cast(src.getType()).getElementType())); + + for (uint64_t i = 0; i < srcState.sizes.size(); ++i) { + state.offsets.push_back(srcState.offsets[i]); + state.sizes.push_back(srcState.sizes[i]); + state.strides.push_back(srcState.strides[i]); + } + state.scalar = srcState.scalar; + + auto newBitcastOp = + rewriter.create(loc, dstElemType, srcState.source); + state.source = newBitcastOp.getResult(); +} + +void PtrAnalysis::visitOperandTrans( + PatternRewriter &rewriter, Location loc, triton::TransOp transOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + auto src = transOp.getSrc(); + + PtrState srcState; + visitOperand(rewriter, loc, src, srcState, knownPtrs); + llvm::ArrayRef transOrder = transOp.getOrder(); + + for (uint64_t i = 0; i < transOrder.size(); ++i) { + state.offsets.push_back(srcState.offsets[transOrder[i]]); + state.sizes.push_back(srcState.sizes[transOrder[i]]); + state.strides.push_back(srcState.strides[transOrder[i]]); + } + state.scalar = srcState.scalar; + state.source = srcState.source; +} + +void PtrAnalysis::visitOperandDot( + PatternRewriter &rewriter, Location loc, triton::DotOp dotOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + auto src = dotOp.getC(); + + PtrState srcState; + visitOperand(rewriter, loc, src, srcState, knownPtrs); + + state.setState(rewriter, loc, srcState); +} + +void PtrAnalysis::visitOperandReduce( + PatternRewriter &rewriter, Location loc, triton::ReduceOp reduceOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + auto src = reduceOp.getSrcs()[0]; + auto axis = reduceOp.getAxis(); + + PtrState srcState; + visitOperand(rewriter, loc, src, srcState, knownPtrs); + + state.scalar = srcState.scalar; + state.source = srcState.source; + for (uint32_t i = 0; i < srcState.offsets.size(); ++i) { + if (i != axis) { + state.offsets.push_back(srcState.offsets[i]); + state.sizes.push_back(srcState.sizes[i]); + state.strides.push_back(srcState.strides[i]); + } + } +} + +void PtrAnalysis::visitOperandLoad( + PatternRewriter &rewriter, Location loc, triton::LoadOp loadOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + + auto src = loadOp.getPtr(); + PtrState srcState; + visitOperand(rewriter, loc, src, srcState, knownPtrs); + + state.setState(rewriter, loc, srcState); +} + +void PtrAnalysis::visitOperandExtsi( + PatternRewriter &rewriter, Location loc, arith::ExtSIOp extsiOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + auto src = extsiOp.getIn(); + + PtrState srcState; + visitOperand(rewriter, loc, src, srcState, knownPtrs); + + state.setState(rewriter, loc, srcState); +} + +void PtrAnalysis::visitOperandExtui( + PatternRewriter &rewriter, Location loc, arith::ExtUIOp extuiOp, + PtrState &state, llvm::SmallDenseMap &knownPtrs) { + assert(state.isEmpty()); + auto src = extuiOp.getIn(); + + PtrState srcState; + visitOperand(rewriter, loc, src, srcState, knownPtrs); + + state.setState(rewriter, loc, srcState); +} + +bool isPtrFromLoad(Value v, llvm::DenseMap &valueFromLoads); +bool isMaskCandidate(Value v, llvm::DenseMap &valueToCandiates); + +void PtrAnalysis::rewriteYieldOp( + PatternRewriter &rewriter, scf::YieldOp op, + llvm::SmallDenseMap &knownPtrs, + llvm::SmallDenseMap &knownMasks) { + // any inserted instruction should be before this yield + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + + auto adaptor = scf::YieldOp::Adaptor(op); + llvm::SmallVector yieldArgState; + llvm::SmallVector yieldArgMaskState; + llvm::SmallVector operands(adaptor.getOperands()); + + LLVM_DEBUG(llvm::dbgs() << "ptr rewriteYieldOp start: \n"); + // For each of the init arg that we added additional Values in for loop, we + // need to add corresponding Values as yield operands. The loop below gathers + // PtrState for those values. + for (auto [i, v] : llvm::enumerate(operands)) { + (void)i; + auto tType = dyn_cast(v.getType()); + if (tType && ((tType.getElementType().isIntOrIndex() && + !tType.getElementType().isInteger(1)) || + isa(tType.getElementType()))) { + llvm::DenseMap valueFromLoads; + if (!isPtrFromLoad(v, valueFromLoads)) { + PtrState state; + visitOperand(rewriter, op.getLoc(), v, state, knownPtrs); + yieldArgState.push_back(state); + LLVM_DEBUG(llvm::dbgs() << "ptr yieldArgState size:" + << yieldArgState.size() << "\n"); + } + } + } + + // mask info process + for (auto [i, v] : llvm::enumerate(operands)) { + auto tType = dyn_cast(v.getType()); + if (tType && ((tType.getElementType().isIntOrIndex() && + !tType.getElementType().isInteger(1)))) { + llvm::DenseMap valueToCandiates; + if (isMaskCandidate(v, valueToCandiates)) { + MaskState state; + gcu::MaskAnalysis::parse(rewriter, op.getLoc(), v, state, knownMasks); + yieldArgMaskState.push_back(state); + LLVM_DEBUG(llvm::dbgs() << "ptr yieldArgMaskState size:" + << yieldArgMaskState.size() << "\n"); + } + } + (void)i; + } + + // For each of the PtrState recorded in the last step, extract value + // that correspond to offset and stride for each dimension and append + // them to yield operands. + for (auto state : yieldArgState) { + if (state.scalar) { + operands.push_back(state.scalar); + } + for (auto s : state.offsets) { + // offsets can be IntAttr zeroes, since reinterpret_cast collapses them + // for the input memref, and the for loop may not update offsets other + // than offsets[0]. Create constants Values for those zeroes. + if (auto sIntAttr = getIntAttr(s)) { + assert(sIntAttr.value() == 0 && "attribute offsets should be zeroes"); + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(0)); + operands.push_back(constOp.getResult()); + } else { + operands.push_back(s.get()); + } + } + + for (auto s : state.strides) { + assert(!getIntAttr(s) && + "PtrState strides for yield within for loop not expected to be " + "attribute."); + operands.push_back(s.get()); + } + } + + // mask info process + for (auto state : yieldArgMaskState) { + if (state.start) { + auto sIntAttr = getIntAttr(state.start); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + operands.push_back(constOp.getResult()); + state.start = constOp.getResult(); + } else { + operands.push_back(state.start.get()); + } + } + if (state.end) { + auto sIntAttr = getIntAttr(state.end); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + operands.push_back(constOp.getResult()); + state.end = constOp.getResult(); + } else { + operands.push_back(state.end.get()); + } + } + } + + // Yield is a terminator op that must be at the end of the function + rewriter.setInsertionPointAfter(op); + rewriter.replaceOpWithNewOp(op, operands); + assert(op->getNumResults() == 0); +} + +bool PtrAnalysis::byPassForOp(PatternRewriter &rewriter, scf::ForOp op, + const SmallVector &candidateOps) { + bool bypass = true; + + op.walk([&](mlir::Operation *_op) { + bypass = mlir::TypeSwitch(_op) + .Case([&](auto loadstoreOp) { + auto iter = + std::find(candidateOps.begin(), candidateOps.end(), + loadstoreOp.getOperation()); + return iter == candidateOps.end(); + }) + .Default([&](auto op) { return true; }); + return !bypass ? WalkResult::interrupt() : WalkResult::advance(); + }); + + return bypass; +} + +LogicalResult PtrAnalysis::rewriteForOp( + PatternRewriter &rewriter, scf::ForOp op, + SmallDenseMap &knownPtrs, + SmallDenseMap &knownMasks, + SmallVector &candidateOps, + SmallDenseMap> &candidateHints) { + llvm::SmallVector newInitArgs; + llvm::SmallVector> initArgIndexState; + llvm::SmallVector> initmaskIndexState; + + LLVM_DEBUG(llvm::dbgs() << "rewriteForOp: \n"); + // Create a new list of init args + for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { + newInitArgs.push_back(arg); + // Only parse those args whose type is ptr or int tensor, since they will + // be possible of operands of triton::AddPtrOp + LLVM_DEBUG(llvm::dbgs() << "i: " << i << "\n"); + auto tType = dyn_cast(arg.getType()); + if (tType && ((tType.getElementType().isIntOrIndex() && + !tType.getElementType().isInteger(1)) || + isa(tType.getElementType()))) { + llvm::DenseMap valueFromLoads; + if (!isPtrFromLoad(op.getRegionIterArg(i), valueFromLoads)) { + PtrState state; + visitOperand(rewriter, op.getLoc(), arg, state, knownPtrs); + // Record the PtrState for later processing + initArgIndexState.push_back(std::make_pair(i, state)); + } + } + } + // mask info process + for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { + auto tType = dyn_cast(arg.getType()); + if (tType && ((tType.getElementType().isIntOrIndex() && + !tType.getElementType().isInteger(1)))) { + llvm::DenseMap valueToCandiates; + if (isMaskCandidate(op.getRegionIterArg(i), valueToCandiates)) { + MaskState state; + gcu::MaskAnalysis::parse(rewriter, op.getLoc(), arg, state, knownMasks); + initmaskIndexState.push_back(std::make_pair(i, state)); + } + } + } + + if (initmaskIndexState.size() == 0 && initArgIndexState.size() == 0) + return failure(); + + // Set insertion point to be before the for loop for new variables passed + // into the new loop. + auto origIp = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + + // For each of the PtrState recorded in the last step, insert new + // instructions to describe offset and stride for each dimension and append + // them to init args + for (auto [i, state] : initArgIndexState) { + // For each dimension, if the corresponding offset and stride is an + // integer attribute, create a constant value and append them at the end + // of init arg list. + (void)i; + if (state.scalar) { + newInitArgs.push_back(state.scalar); + } + for (auto [j, s] : llvm::enumerate(state.offsets)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + state.offsets[j] = constOp.getResult(); + } else { + newInitArgs.push_back(s.get()); + } + } + + for (auto [j, s] : llvm::enumerate(state.strides)) { + auto sIntAttr = getIntAttr(s); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + state.strides[j] = constOp.getResult(); + } else { + newInitArgs.push_back(s.get()); + } + } + } + + // mask info process + for (auto [i, state] : initmaskIndexState) { + if (state.start) { + auto sIntAttr = getIntAttr(state.start); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + state.start = constOp.getResult(); + } else { + newInitArgs.push_back(state.start.get()); + } + } + + if (state.end) { + auto sIntAttr = getIntAttr(state.end); + if (sIntAttr) { + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(sIntAttr.value())); + newInitArgs.push_back(constOp.getResult()); + state.end = constOp.getResult(); + } else { + newInitArgs.push_back(state.end.get()); + } + } + (void)i; + } + rewriter.restoreInsertionPoint(origIp); + + // Create a new scf::ForOp that uses updated init args and same loop body + auto newOp = rewriter.create( + op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), + newInitArgs, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange args) { + IRMapping mapping; + mapping.map(op.getInductionVar(), iv); + mapping.map(op.getInitArgs(), newInitArgs); + mapping.map(op.getRegionIterArgs(), args); + for (Operation &bodyOp : op.getBody()->getOperations()) { + Operation *newOp = builder.clone(bodyOp, mapping); + if (candidateHints.contains(&bodyOp)) { + auto strideHint = candidateHints[&bodyOp]; + candidateHints.erase(&bodyOp); + candidateHints.insert(std::make_pair(newOp, strideHint)); + + auto it = + std::find(candidateOps.begin(), candidateOps.end(), &bodyOp); + assert(it != candidateOps.end()); + + candidateOps.erase(it); + candidateOps.push_back(newOp); + } + } + }); + newOp->setAttrs(op->getAttrs()); + // Convert the book-keeping data structure to use the correct key and value. + // Key is converted from init arg index to newly created block arg, and + // Value's PtrState fields are converted from init arg to newly created block + // arg + int cnt = op.getRegionIterArgs().size(); + LLVM_DEBUG(llvm::dbgs() << "rewriteForOp RegionIterArgs init size: " << cnt + << "\n"); + + for (auto [i, state] : initArgIndexState) { + if (state.scalar) { + state.scalar = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + for (auto it = state.offsets.begin(); it != state.offsets.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + for (auto it = state.strides.begin(); it != state.strides.end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + LLVM_DEBUG(llvm::dbgs() + << "rewriteForOp RegionIterArgs loop size: " << cnt << "\n"); + auto key = newOp.getRegionIterArgs()[i]; + knownPtrs.insert(std::make_pair(key, state)); + } + + // mask info process + for (auto [i, state] : initmaskIndexState) { + if (state.start) { + state.start = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + if (state.end) { + state.end = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + auto key = newOp.getRegionIterArgs()[i]; + knownMasks.insert(std::make_pair(key, state)); + } + + assert(static_cast(cnt) == newOp.getRegionIterArgs().size() && + "expect to remap all new block args"); + LLVM_DEBUG(llvm::dbgs() << "rewriteForOp getNumResults size: " + << op.getNumResults() << "\n"); + // Replace only the results that correspond to the original scf.for + auto resultsToReplaceWith = ResultRange( + newOp.result_begin(), newOp.result_begin() + op.getNumResults()); + rewriter.replaceOp(op, resultsToReplaceWith); + if (newOp.getNumRegionIterArgs()) { + LLVM_DEBUG(llvm::dbgs() << "newOp getNumRegionIterArgs size: " + << newOp.getNumRegionIterArgs() << "\n"); + auto yieldOp = cast(newOp.getBody()->getTerminator()); + rewriteYieldOp(rewriter, yieldOp, knownPtrs, knownMasks); + } + LLVM_DEBUG({ + llvm::dbgs() << "ptr analysis create new for\n"; + newOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + return success(); +} + +void PtrAnalysis::foldAwayForOp( + PatternRewriter &rewriter, scf::ForOp forOp, + llvm::SmallDenseMap &knownPtrs) { + LLVM_DEBUG(llvm::dbgs() << "foldAwayForOp: \n"); + for (auto it : llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs(), + forOp.getResults(), forOp.getYieldedValues())) { + // mlir's std canonicalize pass will handle this case + bool forwarded = + ((std::get<1>(it) == std::get<3>(it)) || + (std::get<1>(it).use_empty() && + (std::get<0>(it) == std::get<3>(it) || std::get<2>(it).use_empty()))); + if (forwarded || !std::get<1>(it).hasOneUse()) + continue; + + // Note: try to support more patterns + for (auto op : std::get<1>(it).getUsers()) { + size_t totalUsers = 0; + for (OpResult result : op->getResults()) { + auto userRange = result.getUsers(); + totalUsers += std::distance(userRange.begin(), userRange.end()); + } + + if (totalUsers == 1 && op->getResult(0) == std::get<3>(it) && + std::get<2>(it).use_empty()) { + op->getResult(0).replaceAllUsesWith(std::get<1>(it)); + } + } + } +} + +bool checkElemType(Type t) { + if (!isa(t)) + return false; + + auto tensorType = dyn_cast(t); + if (!tensorType.getElementType().isIntOrFloat() || + tensorType.getElementType().getIntOrFloatBitWidth() > 32) + return false; + + // Note: add other limits if needed + return true; +} + +bool checkNoScalar(Type t) { + if (!isa(t)) + return false; + + auto tensorType = dyn_cast(t); + auto shape = tensorType.getShape(); + if (std::all_of(shape.begin(), shape.end(), [](int i) { return i == 1; })) { + return false; + } + + // Note: add other limits if needed + return true; +} + +bool checkPtrType(Type t) { + if (!isa(t)) + return false; + + if (isa(t) && + isa( + dyn_cast(t).getPointeeType())) { + return false; + } + + // Note: add other limits if needed + return true; +} + +// If load/store's ptr operand (actually the offsets) is from other load op, +// then bypass this load/store op. Since the offsets are dynamic, there is no +// way to check whether offsets are continuous +bool isPtrFromLoad(Value v, llvm::DenseMap &valueFromLoads) { + if (valueFromLoads.contains(v)) { + return valueFromLoads.at(v); + } + + if (isa(v.getType())) { + valueFromLoads.insert(std::make_pair(v, false)); + return false; + } + + bool bypass = false; + // need more check if it is the block argument of ForOp + if (!v.getDefiningOp()) { + auto blockArgOp = dyn_cast_or_null(v); + if (blockArgOp && isa(blockArgOp.getOwner()->getParentOp())) { + auto forOp = dyn_cast(blockArgOp.getOwner()->getParentOp()); + auto idx = blockArgOp.getArgNumber() - forOp.getNumInductionVars(); + + auto initValue = forOp.getInitArgs()[idx]; + bypass = initValue.getDefiningOp() + ? isPtrFromLoad(initValue, valueFromLoads) + : true; + + /// yieldOp maybe use the block argument which produce infinite loop. + valueFromLoads.insert(std::make_pair(v, bypass)); + + /// if already bypass, no need to analysis yield. + if (!bypass) { + auto yieldOp = cast(forOp.getBody()->getTerminator()); + auto yieldValue = yieldOp.getOperands()[idx]; + bool yieldBypass = isPtrFromLoad(yieldValue, valueFromLoads); + bypass = bypass || yieldBypass; + + valueFromLoads[v] = bypass; + } + } else { + valueFromLoads.insert(std::make_pair(v, bypass)); + } + + return bypass; + } + TypeSwitch(v.getDefiningOp()) + .Case([&](auto op) { bypass = true; }) + .Case( + [&](auto op) { bypass = false; }) + .Case([&](triton::AddPtrOp op) { + bypass = isPtrFromLoad(op.getPtr(), valueFromLoads) || + isPtrFromLoad(op.getOffset(), valueFromLoads); + }) + .Case( + [&](auto op) { bypass = isPtrFromLoad(op.getSrc(), valueFromLoads); }) + .Case( + [&](auto op) { + bypass = isPtrFromLoad(op.getLhs(), valueFromLoads) || + isPtrFromLoad(op.getRhs(), valueFromLoads); + }) + .Case( + [&](auto op) { bypass = isPtrFromLoad(op.getIn(), valueFromLoads); }) + .Case( + [&](auto op) { + // Now bypass SelectOP, SubIOp, DivSIOp, RemSIOp and RemUIOp. + // Optimization will be considered in subsequent steps + LLVM_DEBUG(llvm::dbgs() + << "bypass from :" << op->getName().getStringRef().str() + << "\n"); + bypass = true; + }) + .Case([&](auto op) { + // Now bypass ForOp, WhileOp, IfOp op + LLVM_DEBUG(llvm::dbgs() << "bypass from :" + << op->getName().getStringRef().str() << "\n"); + bypass = true; + }) + .Default([&](auto op) { + std::string info = std::string("add logic to support op ") + + op->getName().getStringRef().str(); + llvm_unreachable(info.c_str()); + }); + valueFromLoads.insert(std::make_pair(v, bypass)); + return bypass; +} + +bool isPtrCandidate(Value v, const gcu::AxisInfoEx *axisInfoEx, + SmallVector &strideHint) { + if (!axisInfoEx) { + LLVM_DEBUG(llvm::dbgs() << "bypass load/store op not get axisInfoEx: \n"); + return false; + } + + llvm::DenseMap valueFromLoads; + if (isPtrFromLoad(v, valueFromLoads)) { + LLVM_DEBUG(llvm::dbgs() << "bypass load/store op is isPtrFromLoad: \n"); + return false; + } + + if (!isa(v.getType())) { + LLVM_DEBUG(llvm::dbgs() << "bypass load/store op is type error: \n"); + return false; + } + + assert(isa(v.getType())); + auto tensorType = dyn_cast(v.getType()); + auto tshape = tensorType.getShape(); + assert(tshape.size() == static_cast(axisInfoEx->getRank())); + + // bool isContiguous = false; + auto rank = axisInfoEx->getRank(); + if (rank >= 5) + return false; + + bool bcheckNotZero = false; + for (int idx = 0; idx < rank; ++idx) { + if (axisInfoEx->getContinualInterval(idx) != 0) { + bcheckNotZero = true; + break; + } + } + // assert((rank == 1 || bcheckNotZero) && "not support all stride is zero"); + + for (int i = 0; i < rank - 1; ++i) { + // skip ptr whose shape including dim size = 1 + if (tshape[i] == 1) { + LLVM_DEBUG(llvm::dbgs() << "bypass load/store op shape including 1: \n"); + return false; + } + if (axisInfoEx->getContinualInterval(i) <= 0) + continue; + for (int j = i + 1; j < rank; ++j) { + if (axisInfoEx->getContinualInterval(j) <= 0) + continue; + if ((axisInfoEx->getContinualInterval(i) % + axisInfoEx->getContinualInterval(j) != + 0) && + (axisInfoEx->getContinualInterval(j) % + axisInfoEx->getContinualInterval(i) != + 0)) { + LLVM_DEBUG(llvm::dbgs() + << "bypass load/store op static stride is not ratio: \n"); + return false; + } + } + } + + // check mismatch shape&stride, only check stride=1 + // auto minShapeContinue = tshape[0]; + // for (int i = 1; i < rank; ++i) + // if (tshape[i] < minShapeContinue) + // minShapeContinue = tshape[i]; + // for (int i = 0; i < rank; ++i) { + // if (axisInfoEx->isContinualLowDim(tshape, i)) { + // for (int j = 0; j < rank; ++j) { + // if (j == i || axisInfoEx->getContinualInterval(j) <= 0) + // continue; + + // if ((axisInfoEx->getContinualInterval(j) == 1) && + // (axisInfoEx->getContinualSize(j) == 1)) + // continue; + + // if (axisInfoEx->getContinualInterval(j) < minShapeContinue) { + // //std::cout << "check mismatch shape&stride error \n"; + // LLVM_DEBUG(llvm::dbgs() + // << "bypass load op static stride is mismatch shape: \n"); + // return false; + // } + // } + // } + // } + for (int i = 0; i < rank; ++i) { + int32_t stride = axisInfoEx->getContinualInterval(i); + strideHint.push_back(stride); + // if (stride < 0) + // bDynamicStride = true; + } + + if (std::count(strideHint.begin(), strideHint.end(), 1) > 1) { + LLVM_DEBUG(llvm::dbgs() + << "bypass load/store op including two dim with stride 1: \n"); + return false; + } + + for (int i = 0; i < rank; ++i) { + if (!axisInfoEx->isContinualDim(tshape, i)) { + LLVM_DEBUG(llvm::dbgs() + << "bypass load/store op is not continue shape: \n"); + return false; + } + } + + // isContiguous = true; + // if (bDynamicStride) { + // isContiguous = true; + // } else { + // for (int i = 0; i < rank; ++i) { + // if (axisInfoEx->isContinualLowDim(tshape, i) || + // axisInfoEx->getContinualInterval(i) == 0) { + // isContiguous = true; + // break; + // } + // } + // } + + LLVM_DEBUG(llvm::dbgs() << "ptr contiguous true:\n"); + for (int k = 0; k < rank; ++k) { + LLVM_DEBUG(llvm::dbgs() << "dim: " << k << "\n" + << "axisInfoEx.divisibility: " + << axisInfoEx->getDivisibility(k) << "\n" + << "axisInfoEx.continualsize: " + << axisInfoEx->getContinualSize(k) << "\n" + << "axisInfoEx.continualinterval: " + << axisInfoEx->getContinualInterval(k) << "\n" + << "tensor shape: " << tshape[k] << "\n" + << "stride hint: " << strideHint[k] << "\n"); + } + + return true; +} + +bool isMaskCandidate(Value v, llvm::DenseMap &valueToCandiates) { + if (valueToCandiates.contains(v)) { + return valueToCandiates.at(v); + } + + if (isa(v.getType())) { + valueToCandiates.insert(std::make_pair(v, true)); + return true; + } + + bool candidate = true; + if (auto arg = dyn_cast(v)) { + auto blockArgOp = dyn_cast_or_null(v); + if (blockArgOp && isa(blockArgOp.getOwner()->getParentOp())) { + auto forOp = dyn_cast(blockArgOp.getOwner()->getParentOp()); + auto idx = blockArgOp.getArgNumber() - forOp.getNumInductionVars(); + + auto initValue = forOp.getInitArgs()[idx]; + candidate = initValue.getDefiningOp() + ? isMaskCandidate(initValue, valueToCandiates) + : true; + + /// yieldOp maybe use the block argument which produce infinite loop. + valueToCandiates.insert(std::make_pair(v, candidate)); + /// if already not be candidate, no need to analysis yield. + if (candidate) { + auto yieldOp = cast(forOp.getBody()->getTerminator()); + auto yieldValue = yieldOp.getOperands()[idx]; + bool yieldCandidate = isMaskCandidate(yieldValue, valueToCandiates); + candidate = candidate && yieldCandidate; + + valueToCandiates[v] = candidate; + } + } else { + valueToCandiates.insert(std::make_pair(v, candidate)); + } + return candidate; + } + + candidate = false; + TypeSwitch(v.getDefiningOp()) + .Case([&](auto op) { candidate = false; }) + .Case( + [&](auto op) { candidate = true; }) + .Case( + [&](auto op) { + candidate = isMaskCandidate(op.getSrc(), valueToCandiates); + }) + .Case([&](auto op) { + candidate = isMaskCandidate(op.getLhs(), valueToCandiates) && + isMaskCandidate(op.getRhs(), valueToCandiates); + }) + .Case([&](auto op) { + candidate = isMaskCandidate(op.getIn(), valueToCandiates); + }) + .Case([&](auto op) { + // bypass DivSIOp, which is completely discontiguous index operation, + // and cannot be converted to dte + LLVM_DEBUG(llvm::dbgs() << "bypass from :" + << op->getName().getStringRef().str() << "\n"); + candidate = false; + }) + .Case([&](auto op) { + // bypass ForOp, IfOp, WhileOp, + // which is maybe discontiguous index operation. + LLVM_DEBUG(llvm::dbgs() << "bypass from :" + << op->getName().getStringRef().str() << "\n"); + candidate = false; + }) + .Case([&](auto op) { + assert(isa(op.getSrc().getType()) && + "splat source must be an integer scalar for load/store masks"); + candidate = isMaskCandidate(op.getSrc(), valueToCandiates); + }) + .Case([&](auto op) { + if (op.getPredicate() == arith::CmpIPredicate::slt || + op.getPredicate() == arith::CmpIPredicate::ult || + op.getPredicate() == arith::CmpIPredicate::sge || + op.getPredicate() == arith::CmpIPredicate::uge) { + if (auto tensorType = dyn_cast(op.getLhs().getType())) { + auto shape = tensorType.getShape(); + if (shape.size() >= 2 && + std::all_of(shape.begin(), shape.end(), + [](int i) { return i != 1; })) { + candidate = false; + } else { + candidate = isMaskCandidate(op.getLhs(), valueToCandiates) && + isMaskCandidate(op.getRhs(), valueToCandiates); + } + } else { + candidate = isMaskCandidate(op.getLhs(), valueToCandiates) && + isMaskCandidate(op.getRhs(), valueToCandiates); + } + } else { + candidate = false; + } + }) + .Default([&](auto op) { + std::string info = std::string("add logic to support op ") + + op->getName().getStringRef().str(); + llvm_unreachable(info.c_str()); + }); + + valueToCandiates.insert(std::make_pair(v, candidate)); + return candidate; +} + +void PtrAnalysis::collectCandidateLoadStoreOps( + ModuleOp &moduleOp, llvm::SmallVector &candidates, + llvm::SmallDenseMap> &candidateHints) { + gcu::ModuleAxisInfoExAnalysis axisInfoExAnalysis(moduleOp); + + llvm::SmallVector loadstoreOps; + moduleOp.walk([&](triton::FuncOp funcOp) { + funcOp.walk([&](Operation *op) { + // Note: try to support nested for loop if needed + TypeSwitch(op).Case( + [&](auto matchOp) { + loadstoreOps.push_back(matchOp.getOperation()); + }); + // Note: try to support other cases like func call if needed + }); + }); + + for (auto op : loadstoreOps) { + if (auto loadOp = dyn_cast(op)) { + auto ptr = loadOp.getPtr(); + auto axisInfoEx = axisInfoExAnalysis.getAxisInfoEx(ptr); + + if (!checkNoScalar(loadOp.getType())) { + LLVM_DEBUG(llvm::dbgs() << "bypass load op due to scalar data type: " + << loadOp << "\n"); + continue; + } + + if (!checkElemType(loadOp.getType())) { + LLVM_DEBUG(llvm::dbgs() + << "bypass load op due to noncandidate element type: " + << loadOp << "\n"); + continue; + } + + if (!checkPtrType(ptr.getType())) { + LLVM_DEBUG(llvm::dbgs() + << "bypass load op due to noncandidate ptr type: " << loadOp + << "\n"); + continue; + } + SmallVector strideHint; + if (!isPtrCandidate(ptr, axisInfoEx, strideHint)) { + LLVM_DEBUG(llvm::dbgs() << "bypass load op due to noncandidate ptr: " + << loadOp << "\n"); + continue; + } + llvm::DenseMap valueToCandiates; + if (loadOp.getMask() && + !isMaskCandidate(loadOp.getMask(), valueToCandiates)) { + LLVM_DEBUG(llvm::dbgs() << "bypass load op due to noncandidate mask: " + << loadOp << "\n"); + continue; + } + + // Great to arrive here + LLVM_DEBUG(llvm::dbgs() << "candidate load op " << loadOp << "\n"); + candidates.push_back(op); + candidateHints.insert(std::make_pair(op, strideHint)); + } else { + assert(isa(op)); + + auto storeOp = dyn_cast(op); + auto ptr = storeOp.getPtr(); + auto axisInfoEx = axisInfoExAnalysis.getAxisInfoEx(ptr); + + if (!checkNoScalar(storeOp.getValue().getType())) { + LLVM_DEBUG(llvm::dbgs() << "bypass store op due to scalar data type: " + << storeOp << "\n"); + continue; + } + + if (!checkElemType(storeOp.getValue().getType())) { + LLVM_DEBUG(llvm::dbgs() + << "bypass store op due to noncandidate element type: " + << storeOp << "\n"); + continue; + } + + if (!checkPtrType(ptr.getType())) { + LLVM_DEBUG(llvm::dbgs() + << "bypass store op due to noncandidate ptr type: " + << storeOp << "\n"); + continue; + } + + SmallVector strideHint; + if (!isPtrCandidate(ptr, axisInfoEx, strideHint)) { + LLVM_DEBUG(llvm::dbgs() << "bypass store op due to noncandidate ptr: " + << storeOp << "\n"); + continue; + } + llvm::DenseMap valueToCandiates; + if (storeOp.getMask() && + !isMaskCandidate(storeOp.getMask(), valueToCandiates)) { + LLVM_DEBUG(llvm::dbgs() << "bypass store op due to noncandidate mask: " + << storeOp << "\n"); + continue; + } + + // Great to arrive here + LLVM_DEBUG(llvm::dbgs() << "candidate store op " << storeOp << "\n"); + candidates.push_back(op); + candidateHints.insert(std::make_pair(op, strideHint)); + } + } +} + +} // namespace gcu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/CMakeLists.txt new file mode 100644 index 000000000..6ff6fde69 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(Analysis) +add_subdirectory(Dialect) +add_subdirectory(Conversion) +add_subdirectory(Transforms) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/CMakeLists.txt new file mode 100644 index 000000000..96c91446f --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonToGCU) +add_subdirectory(Common) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/Common/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/Common/CMakeLists.txt new file mode 100644 index 000000000..881d6ccc3 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/Common/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_conversion_library(MLIRTritonGCUCommon_${arch} + ConstantUtil.cpp + + DEPENDS + + LINK_LIBS PUBLIC + MLIRAffineDialect + MLIRAffineToStandard + MLIRArithDialect + MLIRIR + MLIRLinalgDialect + MLIRMemRefDialect + MLIRSupport + MLIRSideEffectInterfaces + ) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/Common/ConstantUtil.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/Common/ConstantUtil.cpp new file mode 100644 index 000000000..1fb859e42 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/Common/ConstantUtil.cpp @@ -0,0 +1,144 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "ConstantUtil.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir { +namespace triton { +namespace gcu { + +Value createConstantZero(OpBuilder &builder, Location loc, Type elemType) { + if (elemType.isIntOrIndex()) { + return builder.create(loc, 0, elemType); + } else if (elemType.isF32()) { + return builder.create( + loc, APFloat(llvm::APFloatBase::IEEEsingle(), "0"), + dyn_cast(elemType)); + } else if (elemType.isF16()) { + return builder.create( + loc, APFloat(llvm::APFloatBase::IEEEhalf(), "0"), + dyn_cast(elemType)); + } else if (elemType.isBF16()) { + return builder.create( + loc, APFloat(llvm::APFloatBase::BFloat(), "0"), + dyn_cast(elemType)); + } else if (elemType.isF64()) { + return builder.create( + loc, APFloat(llvm::APFloatBase::IEEEdouble(), "0"), + dyn_cast(elemType)); + } else if (llvm::isa(elemType)) { + return builder.create( + loc, APFloat(llvm::APFloatBase::Float8E4M3B11FNUZ(), "0"), + dyn_cast(elemType)); + } else if (llvm::isa(elemType)) { + return builder.create( + loc, APFloat(llvm::APFloatBase::Float8E4M3FNUZ(), "0"), + dyn_cast(elemType)); + } else if (llvm::isa(elemType)) { + return builder.create( + loc, APFloat(llvm::APFloatBase::Float8E5M2FNUZ(), "0"), + dyn_cast(elemType)); + } else if (llvm::isa(elemType)) { + return builder.create( + loc, APFloat(llvm::APFloatBase::Float8E4M3FN(), "0"), + dyn_cast(elemType)); + } else if (llvm::isa(elemType)) { + return builder.create( + loc, APFloat(llvm::APFloatBase::Float8E5M2(), "0"), + dyn_cast(elemType)); + } else { + std::string o; + llvm::raw_string_ostream os(o); + elemType.print(os); + llvm_unreachable((o + " is unsupported").c_str()); + } + return Value(); +} + +Value createConstantNaN(OpBuilder &builder, Location loc, Type elemType) { + const llvm::fltSemantics *sem = nullptr; + if (elemType.isF32()) { + sem = &llvm::APFloatBase::IEEEsingle(); + } else if (elemType.isBF16()) { + sem = &llvm::APFloatBase::BFloat(); + } else if (elemType.isF16()) { + sem = &llvm::APFloatBase::IEEEhalf(); + } else if (elemType.isF64()) { + sem = &llvm::APFloatBase::IEEEdouble(); + } else if (llvm::isa(elemType)) { + sem = &llvm::APFloatBase::Float8E4M3B11FNUZ(); + } else if (llvm::isa(elemType)) { + sem = &llvm::APFloatBase::Float8E4M3FNUZ(); + } else if (llvm::isa(elemType)) { + sem = &llvm::APFloatBase::Float8E5M2FNUZ(); + } else if (llvm::isa(elemType)) { + sem = &llvm::APFloatBase::Float8E4M3FN(); + } else if (llvm::isa(elemType)) { + sem = &llvm::APFloatBase::Float8E5M2(); + } else { + std::string o; + llvm::raw_string_ostream os(o); + elemType.print(os); + llvm_unreachable((o + " is unsupported").c_str()); + } + return builder.create(loc, APFloat::getNaN(*sem), + dyn_cast(elemType)); +} + +Value createConstantInf(OpBuilder &builder, Location loc, Type elemType, + bool isNegative) { + const llvm::fltSemantics *sem = nullptr; + if (elemType.isF32()) { + sem = &llvm::APFloatBase::IEEEsingle(); + } else if (elemType.isBF16()) { + sem = &llvm::APFloatBase::BFloat(); + } else if (elemType.isF16()) { + sem = &llvm::APFloatBase::IEEEhalf(); + } else if (elemType.isF64()) { + sem = &llvm::APFloatBase::IEEEdouble(); + } else if (llvm::isa(elemType)) { + sem = &llvm::APFloatBase::Float8E4M3B11FNUZ(); + } else if (llvm::isa(elemType)) { + sem = &llvm::APFloatBase::Float8E4M3FNUZ(); + } else if (llvm::isa(elemType)) { + sem = &llvm::APFloatBase::Float8E5M2FNUZ(); + } else if (llvm::isa(elemType)) { + sem = &llvm::APFloatBase::Float8E4M3FN(); + } else if (llvm::isa(elemType)) { + sem = &llvm::APFloatBase::Float8E5M2(); + } else { + std::string o; + llvm::raw_string_ostream os(o); + elemType.print(os); + llvm_unreachable((o + " is unsupported").c_str()); + } + return builder.create( + loc, APFloat::getInf(*sem, isNegative), dyn_cast(elemType)); +} + +} // namespace gcu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/Common/ConstantUtil.h b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/Common/ConstantUtil.h new file mode 100644 index 000000000..e0f160d8f --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/Common/ConstantUtil.h @@ -0,0 +1,35 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KURAMA_COMMON_CONSTANT_UTIL_H_ +#define KURAMA_COMMON_CONSTANT_UTIL_H_ + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" + +namespace mlir { +namespace triton { +namespace gcu { +Value createConstantZero(OpBuilder &builder, Location loc, Type elemType); +Value createConstantNaN(OpBuilder &builder, Location loc, Type elemType); +Value createConstantInf(OpBuilder &builder, Location loc, Type elemType, + bool isNegative = false); +} // namespace gcu +} // namespace triton +} // namespace mlir + +#endif // KURAMA_COMMON_CONSTANT_UTIL_H_ diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/CMakeLists.txt new file mode 100644 index 000000000..3a15264dc --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/CMakeLists.txt @@ -0,0 +1,49 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../Common) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +add_mlir_conversion_library(MLIRTritonToGCU_${arch} + TritonToGCU.cpp + ReduceOpToGCU.cpp + ElementwiseFusionOpToGCU.cpp + TritonGCUToGCU/TritonGCUToGCUUtils.cpp + TritonFuncOpFlatten.cpp + TritonFusion.cpp + TritonLoadStoreToDma.cpp + TritonGCULayoutOptimize.cpp + OptimizeDotLayout.cpp + TritonConvertTensorPointer.cpp + Utils.cpp + TritonGCUPingpong/TritonGCUPingpong.cpp + TritonGCUPingpong/MatmulLoopPipeline.cpp + TritonGCUPingpong/OuterLoopPipeline.cpp + TritonGCUPingpong/PipelineExpander.cpp + TritonGCUPingpong/PipeliningUtility.cpp + TritonGPUtoTrtionGCU.cpp + + DEPENDS + TritonGCUTableGen_${arch} + MLIRTritonGCUConversionPassIncGen_${arch} + triton_${arch} + + MemrefExtIR${arch} + MathExtIR${arch} + + LINK_LIBS PUBLIC + TritonGCUAnalysis_${arch} + MemrefExtIR${arch} + MathExtIR${arch} + MLIRAffineDialect + MLIRAffineToStandard + MLIRArithDialect + MLIRComplexDialect + MLIRGPUTransforms + MLIRIR + MLIRLinalgDialect + MLIRMemRefDialect + MLIRPass + MLIRSupport + MLIRSideEffectInterfaces + MLIRTransforms + MLIRTritonGCUCommon_${arch} + ) + +target_include_directories(MLIRTritonToGCU_${arch} PRIVATE ../Common) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/ElementwiseFusionOpToGCU.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/ElementwiseFusionOpToGCU.cpp new file mode 100644 index 000000000..cc3e419db --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/ElementwiseFusionOpToGCU.cpp @@ -0,0 +1,789 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "Dialect/GCU/IR/Dialect.h" +#include "Dialect/MemrefExt/IR/MemrefExt.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "PatternTritonGPUOpToGCU.h" + +#include "Analysis/FirstLastUserAnalysis.h" +#include "TritonGCUToGCU/TritionToGCUBase.h" +#include "TritonGCUToGCU/TritonGCUToGCUUtils.h" +#include "Utils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; + +namespace { + +static constexpr unsigned vectorizationMaxLength = 16384; + +bool hasBuiltinImpl(Operation *op) { + auto isArithOp = + isa(op); + auto isMathOp = + isa(op); + + if (auto ewOp = dyn_cast(op)) { + return ( + ewOp.getSymbol() != "__nv_ffs" && ewOp.getSymbol() != "__nv_isnanf" && + ewOp.getSymbol() != "__nv_isinff" && + ewOp.getSymbol() != "__nv_finitef" && ewOp.getSymbol() != "__nv_fmodf"); + } + return isArithOp || isMathOp; +} + +std::string getBuiltinOpSymbol(Operation *op) { + if (isa(op)) { + auto ewOp = dyn_cast(op); + auto name = ewOp.getSymbol(); + if (name == "__nv_fmaxf") { + return "maximumf"; + } else if (name == "__nv_fminf") { + return "minimumf"; + } else if (name == "__nv_floorf") { + return "floor"; + } else if (name == "__nv_min") { + return "minsi"; + } else if (name == "__nv_max") { + return "maxsi"; + } else if (name == "__nv_umin") { + return "minui"; + } else if (name == "__nv_umax") { + return "maxui"; + } else if (name == "__nv_powf") { + return "powf"; + } else if (name == "__nv_powif") { + return "fpowi"; + } else if (name == "__nv_log2f") { + return "log2"; + } else if (name == "__nv_exp2f") { + return "exp2"; + } else if (name == "__nv_atan2f") { + return "atan2"; + } else if (name == "__nv_tanhf") { + return "tanh"; + } else if (name == "__nv_erff") { + return "erf"; + } else if (name == "__nv_sqrtf") { + return "sqrt"; + } else if (name == "__nv_rsqrtf") { + return "rsqrt"; + } else if (name == "__nv_rintf") { + return "roundeven"; + } else { + llvm_unreachable( + ("unsupported extern elementwise: " + name).str().c_str()); + } + } else { + return op->getName().getStringRef().split('.').second.str(); + } +} + +struct GCUElementwiseFusionOpLowering + : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::gcu::ElementwiseFusionRegionOp op, + SharedConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + auto loc = op.getLoc(); + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto totalNumElems = + triton::gcu::getTotalElemsPerThread(op.getResultTypes().front()); + + DenseSet elementTypeSet; + SmallVector results; + SmallVector outputs; + for (auto type : op.getResultTypes()) { + auto resultType = + dyn_cast(getTypeConverter()->convertType(type)); + auto result = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultType); + results.push_back(result); + auto elementTy = resultType.getElementType(); + elementTypeSet.insert(elementTy); + if (elementTy.isInteger(1)) { + outputs.emplace_back(rewriter.create( + loc, + MemRefType::get(ArrayRef{totalNumElems}, + rewriter.getIntegerType(8)), + rewriter.create( + loc, + MemRefType::get(ArrayRef{ShapedType::kDynamic}, + rewriter.getIntegerType(8)), + rewriter.create( + loc, + mlir::gcu::PtrType::get(rewriter.getContext(), elementTy), + result)), + 0, ArrayRef{totalNumElems}, ArrayRef{1})); + } else { + outputs.emplace_back(rewriter.create( + loc, MemRefType::get(ArrayRef{totalNumElems}, elementTy), + result, 0, ArrayRef{totalNumElems}, ArrayRef{1})); + } + } + + SmallVector inputs; + SmallVector elementTypes; + for (auto operand : adaptor.getOperands()) { + auto operandType = operand.getType(); + if (isa(operandType)) { + auto elementTy = cast(operandType).getElementType(); + elementTypes.push_back(elementTy); + elementTypeSet.insert(elementTy); + if (elementTy.isInteger(1)) { + inputs.emplace_back(rewriter.create( + loc, + MemRefType::get(ArrayRef{totalNumElems}, + rewriter.getIntegerType(8)), + rewriter.create( + loc, + MemRefType::get(ArrayRef{ShapedType::kDynamic}, + rewriter.getIntegerType(8)), + rewriter.create( + loc, + mlir::gcu::PtrType::get(rewriter.getContext(), elementTy), + operand)), + 0, ArrayRef{totalNumElems}, ArrayRef{1})); + } else { + inputs.emplace_back(rewriter.create( + loc, MemRefType::get(ArrayRef{totalNumElems}, elementTy), + operand, 0, ArrayRef{totalNumElems}, + ArrayRef{1})); + } + } else { + elementTypeSet.insert(operandType); + elementTypes.push_back(operandType); + inputs.push_back(operand); + } + } + + for (auto &o : op.getRegion().back().without_terminator()) { + for (auto type : o.getResultTypes()) { + auto elementTy = getTypeConverter()->convertType( + cast(type).getElementType()); + if (!elementTy.isInteger(1)) { + elementTypeSet.insert(elementTy); + } + } + } + + unsigned maxBpe = 1; + unsigned minBpe = targetInfo[GCU300].supportI64 ? 8 : 4; + for (auto elementTy : elementTypeSet) { + auto bpe = mlir::triton::gcu::getBpe(elementTy); + maxBpe = bpe > maxBpe ? bpe : maxBpe; + minBpe = bpe < minBpe ? bpe : minBpe; + } + unsigned numVacc = maxBpe / minBpe; + assert(numVacc <= 4); + (void)numVacc; + unsigned vectorLength = targetInfo[GCU300].vaccSizeInBytes * + targetInfo[GCU300].preferVaccNum / maxBpe; + + // TODO(peng.tian) Remove after enable some optimization. + if (llvm::all_of( + adaptor.getOperands(), + [](auto operand) { return isa(operand.getType()); }) && + op.getRegion().hasOneBlock()) { + auto &ops = op.getRegion().front().getOperations(); + if (totalNumElems > vectorizationMaxLength / maxBpe) { + if (ops.size() == 2 && hasBuiltinImpl(&ops.front())) { + SmallVector builtinOperands; + for (auto operand : ops.front().getOperands()) { + builtinOperands.push_back( + inputs[dyn_cast(operand).getArgNumber()]); + } + auto opName = getBuiltinOpSymbol(&ops.front()); + rewriter.create( + loc, outputs[0], builtinOperands, rewriter.getStringAttr(opName)); + rewriter.replaceOp(op, results); + return success(); + } + if (ops.size() == 3) { + auto &op0 = ops.front(); + auto &elementWiseOp = *std::next(ops.begin(), 1); + if (isa(op0) && + isa( + &elementWiseOp) && + llvm::hasSingleElement(op0.getUsers()) && + *op0.user_begin() == &elementWiseOp) { + SmallVector builtinOperands; + for (auto operand : elementWiseOp.getOperands()) { + if (isa(operand)) { + builtinOperands.push_back( + inputs[cast(operand).getArgNumber()]); + } else { + assert(operand.getDefiningOp() == &op0); + if (auto constantOp = dyn_cast(op0)) { + auto elementTy = getTypeConverter()->convertType( + dyn_cast(operand.getType()).getElementType()); + Value builtinOperand = rewriter.create( + loc, elementTy, + dyn_cast(constantOp.getValue()) + .getSplatValue()); + if (elementTy.isInteger(1)) { + builtinOperand = rewriter.create( + loc, rewriter.getIntegerType(8), builtinOperand); + } + builtinOperands.push_back(builtinOperand); + } else if (auto splatOp = dyn_cast(op0)) { + assert(isa(splatOp.getSrc())); + Value builtinOperand = + inputs[cast(splatOp.getSrc()) + .getArgNumber()]; + if (splatOp.getSrc().getType().isInteger(1)) { + builtinOperand = rewriter.create( + loc, rewriter.getIntegerType(8), builtinOperand); + } + builtinOperands.push_back(builtinOperand); + } + } + } + auto opName = getBuiltinOpSymbol(&elementWiseOp); + rewriter.create( + loc, outputs[0], builtinOperands, + rewriter.getStringAttr(opName)); + rewriter.replaceOp(op, results); + return success(); + } + } + } + } + + constexpr unsigned loopUnrollTime = 1; + auto loopLimit = ceil(totalNumElems, vectorLength); + auto loopCnt = loopUnrollTime > loopLimit ? loopLimit : loopUnrollTime; + + auto insertPoint = rewriter.saveInsertionPoint(); + + SmallVector operandMaps(loopCnt); + SmallVector initValues; + Value step; + + for (auto &o : op.getRegion().back().without_terminator()) { + if (auto makeRangeOp = dyn_cast(o)) { + auto startIdx = makeRangeOp.getStart(); + auto elementTy = makeRangeOp.getResult().getType().getElementType(); + Value start = + rewriter.create(loc, startIdx, elementTy) + .getResult(); + if (!getSlicedAxies(makeRangeOp.getType()).empty()) { + start = rewriter.create( + loc, + rewriter.create( + loc, + rewriter.create( + loc, elementTy, + getWarpIds(rewriter, loc, makeRangeOp.getType()).front()), + rewriter.create(loc, totalNumElems, + elementTy)), + start); + } + initValues.emplace_back( + rewriter + .create( + loc, + VectorType::get(ArrayRef{vectorLength}, elementTy), + start) + .getResult()); + } + } + + rewriter.create( + loc, rewriter.create(loc, 0), + rewriter.create(loc, totalNumElems), + rewriter.create(loc, vectorLength * loopCnt), + initValues, + [&](OpBuilder &builder, Location loc, Value iter, ValueRange iterArgs) { + SmallVector args(iterArgs); + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < inputs.size(); ++j) { + if (isa(inputs[j].getType())) { + auto elementTy = elementTypes[j]; + if (elementTy.isInteger(1)) { + elementTy = builder.getIntegerType(8); + } + operandMaps[i].map( + op.getRegion().getArgument(j), + builder.create( + loc, + VectorType::get(ArrayRef{vectorLength}, + elementTy), + inputs[j], + ValueRange{builder.create( + loc, + builder.create( + loc, i * vectorLength), + iter)})); + } else { + operandMaps[i].map(op.getRegion().getArgument(j), inputs[j]); + } + } + } + for (unsigned i = 0; i < loopCnt; ++i) { + unsigned argIndex = 0; + for (auto &o : op.getRegion().back().without_terminator()) { + if (auto bitcastOp = dyn_cast(o)) { + handleBitcastOp(bitcastOp, builder, operandMaps[i], + vectorLength); + } else if (auto splatOp = dyn_cast(o)) { + if (i == 0) { + OpBuilder::InsertionGuard guard(builder); + builder.restoreInsertionPoint(insertPoint); + handleSplatOp(splatOp, builder, operandMaps[i], vectorLength); + } else { + operandMaps[i].map( + splatOp.getResult(), + operandMaps[0].lookup(splatOp.getResult())); + } + } else if (auto constantOp = dyn_cast(o)) { + if (i == 0) { + OpBuilder::InsertionGuard guard(builder); + builder.restoreInsertionPoint(insertPoint); + handleConstantOp(constantOp, builder, operandMaps[i], + vectorLength); + } else { + operandMaps[i].map( + constantOp.getResult(), + operandMaps[0].lookup(constantOp.getResult())); + } + } else if (auto externElementwiseOp = + dyn_cast(o)) { + handleExternElementwiseOp(externElementwiseOp, builder, + operandMaps[i], vectorLength); + } else if (auto makeRangeOp = dyn_cast(o)) { + if (i == 0) { + auto elementTy = + makeRangeOp.getResult().getType().getElementType(); + step = rewriter.create( + loc, + VectorType::get(ArrayRef{vectorLength}, + elementTy), + rewriter.create(loc, vectorLength, + elementTy)); + } + operandMaps[i].map(makeRangeOp.getResult(), args[argIndex]); + args[argIndex] = + rewriter.create(loc, args[argIndex], step); + ++argIndex; + } else { + handleCommonOp(o, builder, operandMaps[i], vectorLength); + } + } + } + if (auto yieldOp = cast( + op.getRegion().back().getTerminator())) { + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < yieldOp.getNumOperands(); ++j) { + auto v = operandMaps[i].lookup(yieldOp.getOperand(j)); + if (dyn_cast(v.getType()) + .getElementType() + .isInteger(1)) { + OpBuilder::InsertionGuard guard(builder); + auto defOp = v.getDefiningOp(); + assert(defOp); + builder.setInsertionPointAfter(defOp); + v = builder + .create( + loc, + VectorType::get(ArrayRef{vectorLength}, + builder.getIntegerType(8)), + v) + .getResult(0); + } + builder.create( + loc, v, outputs[j], + ValueRange{builder.create( + loc, + builder.create( + loc, i * vectorLength), + iter)}); + } + } + builder.create(loc, args); + } + }); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, results); + return success(); + } + +private: + void handleConstantOp(arith::ConstantOp op, OpBuilder &builder, + IRMapping &map, unsigned vectorLength) const { + auto loc = op.getLoc(); + auto elementTy = getTypeConverter()->convertType( + dyn_cast(op.getType()).getElementType()); + auto vectorType = + VectorType::get(ArrayRef{vectorLength}, elementTy); + Value v; + if (elementTy.isInteger(1)) { + v = builder.create( + loc, + VectorType::get(ArrayRef{vectorLength}, + builder.getIntegerType(8)), + builder.create( + loc, builder.getIntegerType(8), + builder.create( + loc, elementTy, + dyn_cast(op.getValue()) + .getSplatValue()))); + } else { + v = builder.create( + loc, vectorType, + builder.create( + loc, elementTy, + dyn_cast(op.getValue()) + .getSplatValue())); + } + map.map(op.getResult(), v); + } + + void handleSplatOp(triton::SplatOp op, OpBuilder &builder, IRMapping &map, + unsigned vectorLength) const { + auto loc = op.getLoc(); + auto elementTy = getTypeConverter()->convertType( + dyn_cast(op.getType()).getElementType()); + Value v; + if (elementTy.isInteger(1)) { + v = builder + .create( + loc, + VectorType::get(ArrayRef{vectorLength}, + builder.getIntegerType(8)), + builder.create(loc, builder.getIntegerType(8), + map.lookup(op.getSrc()))) + .getResult(); + } else { + v = builder.create( + loc, VectorType::get(ArrayRef{vectorLength}, elementTy), + map.lookup(op.getSrc())); + } + map.map(op.getResult(), v); + } + + void handleBitcastOp(triton::BitcastOp op, OpBuilder &builder, IRMapping &map, + unsigned vectorLength) const { + auto loc = op.getLoc(); + auto vectorType = VectorType::get( + ArrayRef{vectorLength}, + getTypeConverter()->convertType( + dyn_cast(op.getType()).getElementType())); + auto newOp = builder.create(loc, vectorType, + map.lookup(op.getOperand())); + map.map(op.getResult(), newOp.getResult()); + } + + void handleExternElementwiseOp(triton::ExternElementwiseOp op, + OpBuilder &builder, IRMapping &map, + unsigned vectorLength) const { + SmallVector operands; + auto loc = op.getLoc(); + for (auto operand : op.getOperands()) { + operands.push_back(map.lookup(operand)); + } + auto symbol = op.getSymbol(); + Operation *newOp; + if (symbol == "__nv_fmaxf") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_fminf") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_floorf") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_min") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_max") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_umin") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_umax") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_powf") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_powif") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_log2f") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_exp2f") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_ffs") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_erff") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_tanhf") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_fdiv_rn") { + auto resType = operands.front().getType(); + auto vectorType = dyn_cast(resType); + auto elemType = vectorType.getElementType(); + auto constValue = builder.create( + loc, DenseElementsAttr::get(vectorType, + builder.getFloatAttr(elemType, 0.5))); + + auto div = builder.create(loc, resType, operands); + newOp = builder.create( + loc, resType, builder.create(loc, div, constValue)); + } else if (symbol == "__nv_fdiv_rz") { + auto resType = operands.front().getType(); + auto vectorType = dyn_cast(resType); + auto elemType = vectorType.getElementType(); + auto zero = builder.create( + loc, DenseElementsAttr::get(vectorType, + builder.getFloatAttr(elemType, 0))); + + auto div = builder.create(loc, resType, operands); + + newOp = builder.create( + loc, + builder.create(loc, arith::CmpFPredicate::OGE, div, + zero), + builder.create(loc, resType, div), + builder.create(loc, resType, div)); + } else if (symbol == "__nv_fmodf") { + auto resType = operands.front().getType(); + auto vectorType = dyn_cast(resType); + auto elemType = vectorType.getElementType(); + auto zero = builder.create( + loc, DenseElementsAttr::get(vectorType, + builder.getFloatAttr(elemType, 0))); + + auto div = builder.create(loc, resType, operands); + + auto vfloor = builder.create( + loc, + builder.create(loc, arith::CmpFPredicate::OGE, div, + zero), + builder.create(loc, resType, div), + builder.create(loc, resType, div)); + + newOp = builder.create( + loc, operands[0], + builder.create(loc, vfloor, operands[1])); + } else if (symbol == "__nv_truncf") { + auto resType = operands.front().getType(); + auto vectorType = dyn_cast(resType); + auto elemType = vectorType.getElementType(); + auto zero = builder.create( + loc, DenseElementsAttr::get(vectorType, + builder.getFloatAttr(elemType, 0))); + auto cmp = builder.create(loc, arith::CmpFPredicate::OGE, + operands[0], zero); + newOp = builder.create( + loc, cmp, builder.create(loc, resType, operands[0]), + builder.create(loc, resType, operands[0])); + } else if (symbol == "__nv_sqrtf") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_rsqrtf") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else if (symbol == "__nv_isnanf") { + // isnan(x) -> cmpf uno x, x, e.g. + // cmpf uno 1.0, 1.0 -> false + // cmpf uno nan, nan -> true + auto resVectorType = VectorType::get( + ArrayRef{vectorLength}, + dyn_cast(op.getResult().getType()).getElementType()); + auto cmpFOp = builder.create( + loc, arith::CmpFPredicate::UNO, operands.front(), operands.front()); + newOp = builder.create(loc, resVectorType, cmpFOp); + } else if (symbol == "__nv_isinff") { + auto elemType = + dyn_cast(operands.front().getType()).getElementType(); + auto constPositiveInf = + triton::gcu::createConstantInf(builder, loc, elemType); + auto vectorType = + VectorType::get(ArrayRef{vectorLength}, elemType); + auto broadCastPositiveInfOp = builder.create( + loc, vectorType, constPositiveInf); + auto positiveInput = builder.create(loc, operands.front()); + auto cmpFOpPositiveInf = + builder.create(loc, arith::CmpFPredicate::OEQ, + positiveInput, broadCastPositiveInfOp); + auto resVectorType = VectorType::get( + ArrayRef{vectorLength}, + dyn_cast(op.getResult().getType()).getElementType()); + newOp = + builder.create(loc, resVectorType, cmpFOpPositiveInf); + } else if (symbol == "__nv_finitef") { + // isfinitf(x) -> + // %1 = fcmp uno %x, %x + // %2 = absf %x + // %3 = fcmp oeq %2, +inf + // %4 = ori %1, %3 + // %5 = xori %4, true + auto isNanfOp = builder.create( + loc, arith::CmpFPredicate::UNO, operands.front(), operands.front()); + auto elemType = + dyn_cast(operands.front().getType()).getElementType(); + auto constPositiveInf = + triton::gcu::createConstantInf(builder, loc, elemType); + auto vectorType = + VectorType::get(ArrayRef{vectorLength}, elemType); + auto broadCastPositiveInfOp = builder.create( + loc, vectorType, constPositiveInf); + auto positiveInput = builder.create(loc, operands.front()); + auto isInffOp = + builder.create(loc, arith::CmpFPredicate::OEQ, + positiveInput, broadCastPositiveInfOp); + auto oneMask = + builder.create(loc, 1, builder.getI32Type()) + .getResult(); + auto oneMaskVectorType = VectorType::get(ArrayRef{vectorLength}, + builder.getI32Type()); + auto oneMaskVec = + builder.create(loc, oneMaskVectorType, oneMask); + auto isNanfOrInffOp = + builder.create(loc, isNanfOp, isInffOp); + auto resVectorType = VectorType::get( + ArrayRef{vectorLength}, + dyn_cast(op.getResult().getType()).getElementType()); + auto isNanfOrInffOpI32 = + builder.create(loc, resVectorType, isNanfOrInffOp); + newOp = builder.create(loc, isNanfOrInffOpI32, oneMaskVec); + } else if (symbol == "__nv_rintf") { + newOp = builder.create(loc, operands.front().getType(), + operands); + } else { + llvm_unreachable( + ("unsupported extern elementwise: " + symbol).str().c_str()); + } + map.map(op.getResult(), newOp->getResult(0)); + } + + void handleCommonOp(Operation &op, OpBuilder &builder, IRMapping &map, + unsigned vectorLength) const { + Operation *newOp; + if (auto selectOp = dyn_cast(op)) { + auto condition = selectOp.getCondition(); + auto mapValue = map.lookup(condition); + if (cast(mapValue.getType()).getElementType().isInteger(8)) { + map.map(condition, + builder + .create( + op.getLoc(), + VectorType::get(ArrayRef{vectorLength}, + builder.getIntegerType(1)), + mapValue) + .getResult(0)); + newOp = builder.clone(op, map); + map.map(condition, mapValue); + } else { + newOp = builder.clone(op, map); + } + } else if (auto cvtOp = dyn_cast(op)) { + if (cast(cvtOp.getIn().getType()) + .getElementType() + .isInteger(1) && + cast(cvtOp.getOut().getType()) + .getElementType() + .isInteger(8)) { + map.map(cvtOp.getOut(), map.lookup(cvtOp.getIn())); + return; + } else { + newOp = builder.clone(op, map); + } + } else { + newOp = builder.clone(op, map); + } + SmallVector resultTypes; + auto typeInterface = dyn_cast(newOp); + if (!typeInterface || + failed(typeInterface.inferReturnTypes( + newOp->getContext(), newOp->getLoc(), newOp->getOperands(), + newOp->getAttrDictionary(), newOp->getPropertiesStorage(), + newOp->getRegions(), resultTypes))) { + resultTypes.clear(); + llvm::transform( + op.getResultTypes(), std::back_inserter(resultTypes), + [&](auto resultType) { + return VectorType::get( + ArrayRef{vectorLength}, + getTypeConverter()->convertType( + dyn_cast(resultType).getElementType())); + }); + } + + for (auto [resultType, result, newResult] : + llvm::zip(resultTypes, op.getResults(), newOp->getResults())) { + newResult.setType(resultType); + if (isa(op)) { + map.map(result, builder + .create( + op.getLoc(), + VectorType::get(ArrayRef{vectorLength}, + builder.getIntegerType(8)), + newResult) + .getResult(0)); + } else { + map.map(result, newResult); + } + } + } +}; +} // namespace + +void mlir::triton::populateElementwiseFusionOpToGCUPatterns( + const TypeConverter &converter, RewritePatternSet &patterns, + gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin) { + patterns.add(converter, patterns.getContext(), + userAnalysis, replaced2Origin); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/OptimizeDotLayout.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/OptimizeDotLayout.cpp new file mode 100644 index 000000000..f2848a9be --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/OptimizeDotLayout.cpp @@ -0,0 +1,322 @@ +/* + * Copyright 2020 - 2022 Enflame.All Rights Reserved. + * + */ + +#include "Conversion/TritonToGCU/TritonToGCUPass.h" + +#include "Utils.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +namespace mlir { +#define GEN_PASS_DEF_TRITONGCUDOTLAYOUTOPTIMIZEPASS +#include "Conversion/Passes.h.inc" +} // namespace mlir + +#define DEBUG_TYPE "triton-gcu-dot-layout-optimize" +namespace { +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +struct TritonGCUDotLayoutOptimizePass + : public mlir::impl::TritonGCUDotLayoutOptimizePassBase< + TritonGCUDotLayoutOptimizePass> { + using Base::Base; + + void runOnOperation() override; + void RefineDotLayout(); + void reWriteDotLayout(triton::DotOp op); + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } +}; + +/* fit for gcu dot layout*/ +void TritonGCUDotLayoutOptimizePass::reWriteDotLayout(triton::DotOp dot) { + auto loc = dot.getLoc(); + auto retType = dyn_cast(dot.getType()); + auto retShape = retType.getShape(); + int rank = retShape.size(); + auto retLayout = retType.getEncoding(); + if (!isa(retLayout)) { + LLVM_DEBUG({ + llvm::dbgs() << "bad dot for gcu layout \n"; + dot.dump(); + }); + return; + } + auto dOriBlockEncoding = + dyn_cast(retLayout); + SmallVector warpsPerCTA = dOriBlockEncoding.getWarpsPerCTA(); + assert((static_cast(rank) == warpsPerCTA.size()) && + "warpsPerCTA size is not equal to rank!\n"); + // check data type + + auto numWarpsPerCTA = product(warpsPerCTA); + auto totalElement = product(retShape); + (void)totalElement; + (void)numWarpsPerCTA; + assert((totalElement >= numWarpsPerCTA) && + "hi your data is too little, please do't config so large NumberWarps " + "!\n"); + // try to get alignment to acore + auto dotOutElementType = retType.getElementType(); + Value inputA = dot.getA(); + Value inputB = dot.getB(); + Value inputC = dot.getC(); + auto tTypeA = dyn_cast(inputA.getType()); + auto k = tTypeA.getShape()[1]; + auto dotM = retShape[rank - 2]; + auto dotN = retShape[rank - 1]; + auto dotSrcElementType = tTypeA.getElementType(); + auto totalMatmulWarp = warpsPerCTA[rank - 2] * warpsPerCTA[rank - 1]; + if (totalMatmulWarp == 1) { + return; + } + LLVM_DEBUG({ + llvm::dbgs() << "refine gcu dot layout \n"; + dot.dump(); + }); + if (dotM < dotN) { + if (dotN / dotM > totalMatmulWarp) { + warpsPerCTA[rank - 2] = 1; + warpsPerCTA[rank - 1] = totalMatmulWarp; + } else { + if (totalMatmulWarp == 8) { + warpsPerCTA[rank - 2] = 2; + warpsPerCTA[rank - 1] = 4; + } else if (totalMatmulWarp == 4) { + warpsPerCTA[rank - 2] = 2; + warpsPerCTA[rank - 1] = 2; + } else if (totalMatmulWarp == 2) { + warpsPerCTA[rank - 2] = 1; + warpsPerCTA[rank - 1] = 2; + } + } + } else if (dotN < dotM) { + if (dotM / dotN > totalMatmulWarp) { + warpsPerCTA[rank - 2] = totalMatmulWarp; + warpsPerCTA[rank - 1] = 1; + } else { + if (totalMatmulWarp == 8) { + warpsPerCTA[rank - 2] = 4; + warpsPerCTA[rank - 1] = 2; + } else if (totalMatmulWarp == 4) { + warpsPerCTA[rank - 2] = 2; + warpsPerCTA[rank - 1] = 2; + } else if (totalMatmulWarp == 2) { + warpsPerCTA[rank - 2] = 2; + warpsPerCTA[rank - 1] = 1; + } + } + } else { + if (totalMatmulWarp == 8) { + warpsPerCTA[rank - 2] = 4; + warpsPerCTA[rank - 1] = 2; + } else if (totalMatmulWarp == 4) { + warpsPerCTA[rank - 2] = 2; + warpsPerCTA[rank - 1] = 2; + } else if (totalMatmulWarp == 2) { + warpsPerCTA[rank - 2] = 2; + warpsPerCTA[rank - 1] = 1; + } + } + // dot cost is high priority + int64_t perWarpM = dotM / warpsPerCTA[rank - 2]; + int64_t perWarpN = dotN / warpsPerCTA[rank - 1]; + if (perWarpM < 1 || perWarpN < 1) { + LLVM_DEBUG({ + llvm::dbgs() << "bad dot for gcu layout \n"; + dot.dump(); + }); + return; + } + // try best to match acore alignment and buffer size balance + if (dotSrcElementType.isBF16() && dotOutElementType.isBF16()) { + if (perWarpM < 32 && perWarpN > 64 && warpsPerCTA[rank - 2] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] / 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] * 2; + } else if (perWarpN < 64 && perWarpM > 32 && warpsPerCTA[rank - 1] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] * 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] / 2; + } + + } else if (dotSrcElementType.isF16() && dotOutElementType.isF16()) { + if ((k % 32 == 0) && perWarpM < 32 && perWarpN > 128 && + warpsPerCTA[rank - 2] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] / 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] * 2; + } else if ((k % 64 == 0) && perWarpM < 32 && perWarpN > 64 && + warpsPerCTA[rank - 2] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] / 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] * 2; + } else if ((k % 32 == 0) && perWarpN < 64 && perWarpM > 128 && + warpsPerCTA[rank - 1] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] * 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] / 2; + } else if ((k % 64 == 0) && perWarpN < 64 && perWarpM > 32 && + warpsPerCTA[rank - 1] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] * 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] / 2; + } + } else if (dotSrcElementType.isF32() && dotOutElementType.isF32()) { + // acore + if (perWarpM < 64 && perWarpN > 64 && warpsPerCTA[rank - 2] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] / 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] * 2; + } else if (perWarpN < 64 && perWarpM > 64 && warpsPerCTA[rank - 1] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] * 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] / 2; + } + } else if (dotSrcElementType.isF16() && dotOutElementType.isF32()) { + if (perWarpM < 32 && (k % 64 == 0) && perWarpN > 64 && + warpsPerCTA[rank - 2] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] / 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] * 2; + } else if (perWarpM < 64 && (k % 32 == 0) && perWarpN > 128 && + warpsPerCTA[rank - 2] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] / 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] * 2; + } else if (perWarpM < 128 && (k % 32 == 0) && perWarpN > 64 && + warpsPerCTA[rank - 2] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] / 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] * 2; + } else if ((k % 64 == 0) && perWarpN < 64 && perWarpM > 32 && + warpsPerCTA[rank - 1] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] * 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] / 2; + } else if ((k % 32 == 0) && perWarpN < 64 && perWarpM > 64 && + warpsPerCTA[rank - 1] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] * 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] / 2; + } + } else if (dotSrcElementType.isBF16() && dotOutElementType.isF32()) { + // acore + if (perWarpM < 32 && perWarpN > 64 && warpsPerCTA[rank - 2] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] / 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] * 2; + } else if (perWarpN < 64 && perWarpM > 32 && warpsPerCTA[rank - 1] >= 2) { + warpsPerCTA[rank - 2] = warpsPerCTA[rank - 2] * 2; + warpsPerCTA[rank - 1] = warpsPerCTA[rank - 1] / 2; + } + } + // TODO(xingxing): refine 400 if need + SmallVector origonWarpsPerCTA = dOriBlockEncoding.getWarpsPerCTA(); + if (origonWarpsPerCTA == warpsPerCTA) { + LLVM_DEBUG({ + llvm::dbgs() << "no need or no Opportunity to refine dot layout\n"; + dot.dump(); + }); + return; + } + LLVM_DEBUG({ + llvm::dbgs() << "hi refine dot layout\n"; + dot.dump(); + }); + OpBuilder rewriter(dot); + Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get( + dot.getContext(), + llvm::ArrayRef(dOriBlockEncoding.getSizePerThread()), + llvm::ArrayRef(dOriBlockEncoding.getThreadsPerWarp()), + llvm::ArrayRef(warpsPerCTA), + llvm::ArrayRef(dOriBlockEncoding.getOrder()), + dOriBlockEncoding.getCTALayout()); + // new A + Attribute newAencoding = triton::gpu::DotOperandEncodingAttr::get( + dot.getContext(), 0, dEncoding, tTypeA.getElementType()); + auto dstAType = RankedTensorType::get(tTypeA.getShape(), + tTypeA.getElementType(), newAencoding); + Value newA = + rewriter.create(loc, dstAType, inputA); + // new B + auto tTypeB = dyn_cast(inputB.getType()); + Attribute newBencoding = triton::gpu::DotOperandEncodingAttr::get( + dot.getContext(), 1, dEncoding, tTypeB.getElementType()); + auto dstBType = RankedTensorType::get(tTypeB.getShape(), + tTypeB.getElementType(), newBencoding); + Value newB = + rewriter.create(loc, dstBType, inputB); + // new C + auto tTypeC = dyn_cast(inputC.getType()); + RankedTensorType dstCType = RankedTensorType::get( + tTypeC.getShape(), tTypeC.getElementType(), dEncoding); + auto newC = + rewriter.create(loc, dstCType, inputC); + + // new retType + auto newRetType = + RankedTensorType::get(retShape, retType.getElementType(), dEncoding); + auto newDot = rewriter.create(loc, newRetType, newA, newB, + newC, dot.getInputPrecision(), + dot.getMaxNumImpreciseAcc()); + auto newOp = newDot.getOperation(); + for (const NamedAttribute attr : dot->getAttrs()) + if (!newOp->hasAttr(attr.getName())) + newOp->setAttr(attr.getName(), attr.getValue()); + auto newResult = rewriter.create( + loc, dot.getType(), newDot.getResult()); + dot.getResult().replaceAllUsesWith(newResult.getResult()); + dot.erase(); +} + +void TritonGCUDotLayoutOptimizePass::RefineDotLayout() { + auto trionModule = getOperation(); + llvm::SmallVector dotList; + // if had reduce or Scanop it's result is a slice layout skip which need some + // case to trace performance + // bool IsOnlyDotAnchor = true; + // trionModule.walk([&](mlir::Operation *op) { + // if (isa(op) || isa(op)) { + // IsOnlyDotAnchor = false; + // return; + // } + // }); + // if (!IsOnlyDotAnchor) { + // LLVM_DEBUG(llvm::dbgs() + // << "had reduce or Scanop it's result is a slice layout skip " + // " need some case to trace performance\n"); + // return; + // } + dotList.clear(); + trionModule.walk([&](DotOp dot) { dotList.push_back(dot); }); + for (auto &dot : dotList) { + auto retType = dyn_cast(dot.getType()); + int rank = retType.getShape().size(); + if (rank > 2) { + // need test case for 3D dot + continue; + } + // refine dot layout if need + reWriteDotLayout(dot); + } +} + +} // namespace +using namespace mlir; +void TritonGCUDotLayoutOptimizePass::runOnOperation() { + LLVM_DEBUG(llvm::dbgs() << "TritonGCUDotLayoutOptimizePass\n"); + RefineDotLayout(); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/PatternTritonGPUOpToGCU.h b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/PatternTritonGPUOpToGCU.h new file mode 100644 index 000000000..aa118a160 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/PatternTritonGPUOpToGCU.h @@ -0,0 +1,57 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef KURANA_TRITON_GPU_OP_TO_GCU_CONVERSION_H +#define KURANA_TRITON_GPU_OP_TO_GCU_CONVERSION_H + +#include +#include + +#include "TritonGCUToGCU/TritionToGCUBase.h" + +namespace { +enum GCUArch { GCU300 = 0, GCU400 }; + +struct GCUInfo { + unsigned vaccSizeInBytes; + bool supportI64; + unsigned preferVaccNum; +}; + +static const GCUInfo targetInfo[] = {{128, false, 4}, {512, false, 1}}; +} // namespace + +namespace mlir { +namespace triton { + +namespace gcu { +class FirstLastUserAnalysis; +} + +void populateReduceOpToGCUPatterns( + const TypeConverter &converter, RewritePatternSet &patterns, + gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin); + +void populateElementwiseFusionOpToGCUPatterns( + const TypeConverter &converter, RewritePatternSet &patterns, + gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/ReduceOpToGCU.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/ReduceOpToGCU.cpp new file mode 100644 index 000000000..22fb420ea --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/ReduceOpToGCU.cpp @@ -0,0 +1,1739 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "Dialect/GCU/IR/Dialect.h" +#include "Dialect/MemrefExt/IR/MemrefExt.h" +#include "PatternTritonGPUOpToGCU.h" + +#include "Analysis/FirstLastUserAnalysis.h" +#include "TritonGCUToGCU/TritonGCUToGCUUtils.h" +#include "Utils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +namespace { + +template +bool patternMatch(Block::iterator beg, Block::iterator end) { + if (beg == end) + return false; + return isa(&*beg) && ++beg == end; +} + +template +bool patternMatch(Block::iterator beg, Block::iterator end) { + if (beg == end) + return false; + auto ret = isa(&*beg); + return ret ? patternMatch(++beg, end) : ret; +} + +template +bool patternMatch(Block::OpListType *opList) { + return patternMatch(opList->begin(), opList->end()); +} + +std::optional +matchReduceCombiningKind(Region &combineOp) { + auto &opList = combineOp.front().getOperations(); + if (opList.size() != 2) { + return std::nullopt; + } + auto elementWiseOp = &opList.front(); + return TypeSwitch>( + elementWiseOp) + .Case( + [&](auto externElementwiseOp) + -> std::optional { + auto symbol = externElementwiseOp.getSymbol(); + if (symbol == "__nv_fmaxf") { + return vector::CombiningKind::MAXIMUMF; + } else if (symbol == "__nv_max") { + return vector::CombiningKind::MAXSI; + } else if (symbol == "__nv_umax") { + return vector::CombiningKind::MAXUI; + } else if (symbol == "__nv_fminf") { + return vector::CombiningKind::MINIMUMF; + } else if (symbol == "__nv_min") { + return vector::CombiningKind::MINSI; + } else if (symbol == "__nv_umin") { + return vector::CombiningKind::MINUI; + } else { + return std::nullopt; + } + }) + .Case( + [&](auto op) { return vector::CombiningKind::ADD; }) + .Case( + [&](auto op) { return vector::CombiningKind::MUL; }) + .Case( + [&](auto op) { return vector::CombiningKind::MAXSI; }) + .Case( + [&](auto op) { return vector::CombiningKind::MAXUI; }) + .Case( + [&](auto op) { return vector::CombiningKind::MAXNUMF; }) + .Case( + [&](auto op) { return vector::CombiningKind::MAXIMUMF; }) + .Case( + [&](auto op) { return vector::CombiningKind::MINSI; }) + .Case( + [&](auto op) { return vector::CombiningKind::MINUI; }) + .Case( + [&](auto op) { return vector::CombiningKind::MINNUMF; }) + .Case( + [&](auto op) { return vector::CombiningKind::MINIMUMF; }) + .Case([&](auto op) { return vector::CombiningKind::AND; }) + .Case([&](auto op) { return vector::CombiningKind::OR; }) + .Case([&](auto op) { return vector::CombiningKind::XOR; }) + .Default([&](auto op) { return std::nullopt; }); +} + +SmallVector vectorizeCombineOpWithoutTerminator( + Location loc, OpBuilder &builder, Region &combineOp, ValueRange operands, + unsigned vectorLength, bool needCvtDataLayout = false) { + IRMapping map; + for (auto [arg, operand] : llvm::zip(combineOp.getArguments(), operands)) { + map.map(arg, operand); + } + for (auto &o : combineOp.back().without_terminator()) { + for (auto operand : o.getOperands()) { + if (auto constantOp = operand.getDefiningOp()) { + if (!map.lookupOrNull(operand)) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(constantOp); + if (operand.getType().isInteger(1)) { + auto boolAttr = dyn_cast(constantOp.getValue()); + auto integerAttr = dyn_cast(constantOp.getValue()); + if ((boolAttr && !boolAttr.getValue()) || + (integerAttr && integerAttr.getValue().isZero())) { + map.map(operand, + builder.create( + loc, + VectorType::get(ArrayRef{vectorLength}, + operand.getType()), + DenseI64ArrayAttr::get(builder.getContext(), + ArrayRef{0}))); + } else { + map.map( + operand, + builder.create( + loc, + VectorType::get(ArrayRef{vectorLength}, + operand.getType()), + DenseI64ArrayAttr::get(builder.getContext(), + ArrayRef{vectorLength}))); + } + } else { + map.map(operand, + builder.create( + loc, + VectorType::get(ArrayRef{vectorLength}, + operand.getType()), + operand)); + } + } + } + } + Operation *newOp; + if (auto selectOp = dyn_cast(o)) { + auto condition = selectOp.getCondition(); + auto mapValue = map.lookup(condition); + if (cast(mapValue.getType()).getElementType().isInteger(8)) { + map.map(condition, + builder + .create( + loc, + VectorType::get(ArrayRef{vectorLength}, + builder.getIntegerType(1)), + mapValue) + .getResult(0)); + newOp = builder.clone(o, map); + map.map(condition, mapValue); + } else { + newOp = builder.clone(o, map); + } + } else { + newOp = builder.clone(o, map); + } + SmallVector resultTypes; + auto typeInterface = dyn_cast(newOp); + if (typeInterface && + succeeded(typeInterface.inferReturnTypes( + newOp->getContext(), newOp->getLoc(), newOp->getOperands(), + newOp->getAttrDictionary(), newOp->getPropertiesStorage(), + newOp->getRegions(), resultTypes))) { + for (auto [resultType, result, newResult] : + llvm::zip(resultTypes, o.getResults(), newOp->getResults())) { + newResult.setType(resultType); + map.map(result, newResult); + } + } else { + for (auto [result, newResult] : + llvm::zip(o.getResults(), newOp->getResults())) { + auto vectorTy = + VectorType::get(ArrayRef{vectorLength}, result.getType()); + newResult.setType(vectorTy); + map.map(result, newResult); + } + } + } + auto terminatorOprands = llvm::to_vector(llvm::map_range( + llvm::cast(combineOp.back().getTerminator()) + .getResult(), + [&](auto v) { + auto mappingValue = map.lookupOrNull(v); + assert(mappingValue != nullptr); + if (v.getType().isInteger(1) && needCvtDataLayout) { + mappingValue = + builder + .create( + loc, + VectorType::get(ArrayRef{vectorLength}, + builder.getIntegerType(8)), + mappingValue) + .getResult(0); + } + return mappingValue; + })); + return terminatorOprands; +} + +void vectorizeCombineOpTerminator(Location loc, OpBuilder &builder, + ValueRange operands) { + builder.create(loc, operands); +} + +struct TTReduceOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + auto loc = op.getLoc(); + auto axis = op.getAxis(); + SmallVector outputs; + SmallVector isSingleElements; + SmallVector elemTypes; + auto numOutput = op.getResults().size(); + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto inputType = dyn_cast(op.getSrcs()[0].getType()); + auto numElems = triton::gcu::getElemsPerThread(inputType); + SmallVector outputShape(numElems.begin(), numElems.end()); + outputShape[axis] = 1; + + for (unsigned i = 0; i < numOutput; ++i) { + auto resultType = getTypeConverter()->convertType(op.getType(i)); + bool isSingleElement = !isa(resultType); + isSingleElements.push_back(isSingleElement); + auto elemType = isSingleElement + ? resultType + : dyn_cast(resultType).getElementType(); + elemTypes.push_back(elemType); + auto resultMemRefType = MemRefType::get(outputShape, elemType); + Value output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultMemRefType); + outputs.push_back(output); + } + + std::array reduceInputDims = {1, 1, 1}; + std::array reduceOutputDims = {1, 1, 1}; + int64_t reduceAxis = 2; + for (int i = numElems.size() - 1, j = 2; i >= 0; i--) { + if (static_cast(i) == axis) { + if (reduceInputDims[j] == 1) { + reduceInputDims[j] = numElems[i]; + } else { + reduceInputDims[--j] = numElems[i]; + } + reduceAxis = j; + reduceOutputDims[reduceAxis] = 1; + --j; + } else { + reduceInputDims[j] *= numElems[i]; + reduceOutputDims[j] = reduceInputDims[j]; + } + } + assert(reduceAxis == 1 || reduceAxis == 2); + + SmallVector reduceInputs; + SmallVector reduceOutputs; + llvm::transform( + adaptor.getSrcs(), std::back_inserter(reduceInputs), [&](auto input) { + return rewriter.create( + loc, + MemRefType::get( + reduceInputDims, + cast(input.getType()).getElementType()), + input, ValueRange{}, ValueRange{}, ValueRange{}, + ArrayRef{0}, ArrayRef{reduceInputDims}, + ArrayRef{reduceInputDims[1] * reduceInputDims[2], + reduceInputDims[2], 1}); + }); + llvm::transform( + outputs, std::back_inserter(reduceOutputs), [&](auto output) { + return rewriter.create( + loc, + MemRefType::get( + reduceOutputDims, + cast(output.getType()).getElementType()), + output, ValueRange{}, ValueRange{}, ValueRange{}, + ArrayRef{0}, ArrayRef{reduceOutputDims}, + ArrayRef{reduceOutputDims[1] * reduceOutputDims[2], + reduceOutputDims[2], 1}); + }); + + applyReduce(op, rewriter, reduceOutputs, reduceInputs, reduceInputDims, + reduceAxis); + auto slicedAxies = getSlicedAxies(inputType); + if (slicedAxies.count(axis) != 0) { + SmallVector sharedMemShape(inputType.getShape()); + auto encodingAttr = dyn_cast(inputType).getEncoding(); + // use gcu triton::gcu::getWarpsPerCTA + auto warpsPerCTA = triton::gcu::getWarpsPerCTA(encodingAttr); + if (warpsPerCTA.size() != sharedMemShape.size()) { + op.dump(); + assert(false && "the reduce input layout is not a blockencoding!"); + } + + if (warpsPerCTA[axis] < sharedMemShape[axis]) { + sharedMemShape[axis] = warpsPerCTA[axis]; + } + + bool isReduce1D = + sharedMemShape[axis] == std::accumulate(sharedMemShape.begin(), + sharedMemShape.end(), 1, + std::multiplies()); + Value tag; + if (!isReduce1D) { + tag = getPrivateDTETag(rewriter, op); + } + SmallVector sharedBuffers; + auto zero = rewriter.create(loc, 0); + for (unsigned i = 0; i < numOutput; ++i) { + auto sharedMemRefType = + MemRefType::get(sharedMemShape, elemTypes[i], AffineMap{}, + rewriter.getI64IntegerAttr(2) /*shared memory*/); + sharedBuffers.emplace_back( + rewriter.create(loc, sharedMemRefType)); + if (isReduce1D) { + rewriter.create( + loc, + rewriter.create(loc, reduceOutputs[i], + ValueRange{zero, zero, zero}), + sharedBuffers.back(), + ValueRange{getWarpIds(rewriter, loc, inputType)}); + rewriter.create(loc); + } else { + storeToSharedMem(rewriter, tag, inputType, sharedBuffers.back(), + outputs[i], false); + } + } + + if (warpsPerCTA[axis] < sharedMemShape[axis]) { + reduceInputDims[reduceAxis] = warpsPerCTA[axis]; + } else { + reduceInputDims[reduceAxis] = sharedMemShape[axis]; + } + auto loadFromShareForAllReduce = + [&](OpBuilder &builder, Value tag, Type type, Value buffer, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin) { + auto loc = buffer.getLoc(); + auto srcType = dyn_cast(buffer.getType()); + auto numElems = triton::gcu::getElemsPerThread(type); + numElems[axis] = warpsPerCTA[axis]; + auto totalNumElems = builder.create( + loc, std::accumulate(numElems.begin(), numElems.end(), 1, + std::multiplies())); + auto outputType = MemRefType::get( + SmallVector(numElems.begin(), numElems.end()), + srcType.getElementType()); + auto warpIds = getWarpIds(builder, loc, type); + SmallVector offsets; + auto zero = builder.create(loc, 0); + for (unsigned i = 0; i < srcType.getRank(); ++i) { + if (i == axis) { + offsets.push_back(builder.create( + loc, 0, builder.getI32Type())); + } else { + offsets.push_back(builder.create( + loc, + builder.create(loc, numElems[i], + builder.getI32Type()), + builder.create( + loc, builder.getI32Type(), warpIds[i]))); + } + } + auto output = + syncAllocOp(builder, loc, op.getOperation(), userAnalysis, + replaced2Origin, outputType); + auto defaultValue = triton::gcu::createConstantZero( + builder, loc, srcType.getElementType()); + if (srcType.getRank() > 5) { + SmallVector mergedOffsets; + Value src; + Value dst; + mergeContinuousDims(builder, loc, src, dst, offsets, + mergedOffsets, srcType, outputType, buffer, + output); + builder.create( + loc, dst, src, mergedOffsets, defaultValue, tag, + ValueRange{zero}); + auto [oriOutputStrides, oriOutputOffset] = + outputType.getStridesAndOffset(); + builder.create( + loc, outputType, dst, oriOutputOffset, + SmallVector(numElems.begin(), numElems.end()), + oriOutputStrides); + } else { + builder.create(loc, output, buffer, + offsets, defaultValue, + tag, ValueRange{zero}); + } + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + return output; + }; + + SmallVector warpReduceInputs; + for (unsigned i = 0; i < numOutput; ++i) { + if (isReduce1D) { + warpReduceInputs.push_back(sharedBuffers[i]); + } else { + auto tensorType = + RankedTensorType::get(sharedMemShape, elemTypes[i], encodingAttr); + warpReduceInputs.emplace_back(loadFromShareForAllReduce( + rewriter, tag, tensorType, sharedBuffers[i], userAnalysis, + replaced2Origin)); + } + } + + llvm::transform( + warpReduceInputs, warpReduceInputs.begin(), [&](auto input) { + return rewriter.create( + loc, + MemRefType::get( + reduceInputDims, + cast(input.getType()).getElementType(), + MemRefLayoutAttrInterface{}, + isReduce1D ? rewriter.getI64IntegerAttr(2) : Attribute{}), + input, ValueRange{}, ValueRange{}, ValueRange{}, + ArrayRef{0}, ArrayRef{reduceInputDims}, + ArrayRef{reduceInputDims[1] * reduceInputDims[2], + reduceInputDims[2], 1}); + }); + applyReduce(op, rewriter, reduceOutputs, warpReduceInputs, + reduceInputDims, reduceAxis); + for (auto buffer : sharedBuffers) { + rewriter.create(loc, buffer); + } + } + + SmallVector finalOutputs; + for (unsigned i = 0; i < numOutput; ++i) { + auto output = outputs[i]; + if (isSingleElements[i]) { + auto zero = rewriter.create(loc, 0); + output = rewriter.create(loc, output, ValueRange{zero}); + } else { + auto resultType = dyn_cast( + getTypeConverter()->convertType(op.getResultTypes()[i])); + if (resultType.getNumElements() != + dyn_cast(output.getType()).getNumElements()) { + return op.emitOpError("element number mismatch: ") + << resultType.getNumElements() << " vs " + << dyn_cast(output.getType()).getNumElements(); + } + auto [strides, offset] = resultType.getStridesAndOffset(); + output = rewriter.create( + loc, resultType, output, offset, resultType.getShape(), strides); + } + finalOutputs.push_back(output); + } + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, finalOutputs); + return success(); + } + +private: + const GCUArch arch = GCUArch::GCU300; + +private: + void applyReduce(triton::ReduceOp op, OpBuilder &rewriter, + ArrayRef outputs, ArrayRef inputs, + const std::array &reduceDims, + int64_t reduceAxis) const { + if (succeeded(applyVectorizationImpl(op, rewriter, outputs, inputs, + reduceDims, reduceAxis))) { + return; + } + if (succeeded(applySmallSizeImpl(op, rewriter, outputs, inputs, reduceDims, + reduceAxis))) { + return; + } + applyScalarImpl(op, rewriter, outputs, inputs, reduceDims, reduceAxis); + } + + LogicalResult applyBuiltinImpl(triton::ReduceOp op, OpBuilder &rewriter, + ArrayRef outputs, + ArrayRef inputs, + const std::array &reduceDims, + int64_t reduceAxis) const { + auto loc = op.getLoc(); + auto &opList = op.getCombineOp().front().getOperations(); + if (patternMatch( + &opList)) { + // Reduce Max/Min + auto externElementwiseOp = + cast(opList.front()); + auto symbol = externElementwiseOp.getSymbol(); + if ((symbol == "__nv_fmaxf" || symbol == "__nv_max" || + symbol == "__nv_umax" || symbol == "__nv_fminf" || + symbol == "__nv_min" || symbol == "__nv_umin") && + (reduceDims[1] % 16 == 0 && reduceDims[2] % 512 == 0)) { + Value output = outputs[0]; + Value input = inputs[0]; + Value workspace = nullptr; + if (reduceAxis == 2) { + std::array dims = {reduceDims[0], reduceDims[1], 1}; + auto elementTy = cast(input.getType()).getElementType(); + auto bitWidth = elementTy.getIntOrFloatBitWidth(); + if (bitWidth == 8) { + dims[2] = 512; + } else if (bitWidth == 16) { + dims[2] = 256; + } else if (bitWidth == 32) { + dims[2] = 128; + } else { + return failure(); + } + workspace = rewriter.create( + loc, MemRefType::get(dims, elementTy)); + } else { + workspace = output; + } + + if (symbol == "__nv_fmaxf") { + rewriter.create(loc, gcu::ReduceOperation::MAXF, + output, input, workspace, reduceAxis); + } else if (symbol == "__nv_max") { + rewriter.create(loc, gcu::ReduceOperation::MAXSI, + output, input, workspace, reduceAxis); + } else if (symbol == "__nv_umax") { + rewriter.create(loc, gcu::ReduceOperation::MAXUI, + output, input, workspace, reduceAxis); + } else if (symbol == "__nv_fminf") { + rewriter.create(loc, gcu::ReduceOperation::MINF, + output, input, workspace, reduceAxis); + } else if (symbol == "__nv_min") { + rewriter.create(loc, gcu::ReduceOperation::MINSI, + output, input, workspace, reduceAxis); + } else { + rewriter.create(loc, gcu::ReduceOperation::MINUI, + output, input, workspace, reduceAxis); + } + if (reduceDims[1] % 16 == 0 && reduceDims[2] % 512 == 0 && + reduceAxis == 2) { + rewriter.create(loc, workspace); + } + } else { + return failure(); + } + } else if ((patternMatch(&opList) || + patternMatch(&opList)) && + (reduceDims[2] % 128 == 0 && reduceDims[1] % 128 == 0)) { + // Reduce Sum + Value output = outputs[0]; + Value input = inputs[0]; + Value workspace = output; + rewriter.create(loc, gcu::ReduceOperation::SUM, output, + input, workspace, reduceAxis); + } else { + return failure(); + } + doMemFence(rewriter, op); + return success(); + } + + void applyScalarImpl(triton::ReduceOp op, OpBuilder &rewriter, + ArrayRef outputs, ArrayRef inputs, + const std::array &reduceDims, + int64_t reduceAxis) const { + auto loc = op.getLoc(); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto numOutput = outputs.size(); + rewriter.create( + loc, zero, rewriter.create(loc, reduceDims[0]), + one, ValueRange{}, + [&](OpBuilder &builder, Location loc, Value iter0, + ValueRange iterArgs) { + builder.create( + loc, zero, + builder.create( + loc, reduceDims[3 - reduceAxis]), + one, ValueRange{}, + [&](OpBuilder &builder, Location loc, Value iter1, + ValueRange iterArgs) { + SmallVector iterators(3); + iterators[0] = iter0; + iterators[3 - reduceAxis] = iter1; + iterators[reduceAxis] = zero; + SmallVector initValues; + llvm::transform(inputs, std::back_inserter(initValues), + [&](auto input) { + return builder.create( + loc, input, iterators); + }); + auto loop = builder.create( + loc, one, + builder.create( + loc, reduceDims[reduceAxis]), + one, initValues, + [&](OpBuilder &builder, Location loc, Value iter2, + ValueRange iterArgs) { + SmallVector operands(iterArgs.begin(), + iterArgs.end()); + SmallVector resultElemTypes; + iterators[reduceAxis] = iter2; + for (unsigned i = 0; i < numOutput; ++i) { + operands.push_back(builder.create( + loc, inputs[i], iterators)); + resultElemTypes.push_back(operands.back().getType()); + } + auto executeRegionOp = + builder.create(loc, + resultElemTypes); + executeRegionOp.getRegion().emplaceBlock(); + IRMapping map; + for (auto [arg, operand] : llvm::zip( + op.getCombineOp().getArguments(), operands)) { + map.map(arg, operand); + } + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart( + &executeRegionOp.getRegion().getBlocks().back()); + for (auto &o : op.getCombineOp().getBlocks().back()) { + auto newOp = builder.clone(o, map); + for (auto [result, newResult] : + llvm::zip(o.getResults(), newOp->getResults())) { + map.map(result, newResult); + } + } + } + builder.create( + loc, executeRegionOp.getResults()); + }); + iterators[reduceAxis] = zero; + for (unsigned i = 0; i < numOutput; ++i) { + builder.create(loc, loop.getResult(i), + outputs[i], iterators); + } + builder.create(loc); + }); + builder.create(loc); + }); + doMemFence(rewriter, op); + } + + LogicalResult applyVectorizationImpl(triton::ReduceOp op, OpBuilder &rewriter, + ArrayRef outputs, + ArrayRef inputs, + const std::array &reduceDims, + int64_t reduceAxis) const { + SmallVector elementTypes; + unsigned maxBpe = 1; + unsigned minBpe = targetInfo[arch].supportI64 ? 8 : 4; + for (auto output : outputs) { + auto elementType = cast(output.getType()).getElementType(); + if (!elementType.isInteger(1) && !elementType.isInteger(8) && + !elementType.isInteger(16) && !elementType.isInteger(32) && + !elementType.isBF16() && !elementType.isF16() && + !elementType.isF32() && + (targetInfo[arch].supportI64 && !elementType.isInteger(64))) { + return failure(); + } + auto bpe = mlir::triton::gcu::getBpe(elementType); + maxBpe = bpe > maxBpe ? bpe : maxBpe; + minBpe = bpe < minBpe ? bpe : minBpe; + elementTypes.push_back(elementType); + } + auto numVacc = maxBpe / minBpe; + if (numVacc > 4) { + return failure(); + } + int64_t vectorizeAxis = 3; + unsigned vectorLength = targetInfo[arch].vaccSizeInBytes / minBpe; + for (auto i = 2; i >= 0; --i) { + if (reduceDims[i] >= vectorLength) { + vectorizeAxis = i; + break; + } + } + if (vectorizeAxis == 3) { + return failure(); + } + while (numVacc < targetInfo[arch].preferVaccNum) { + int64_t axis = 3; + unsigned vLen = 2 * vectorLength; + for (auto i = 2; i >= 0; --i) { + if (reduceDims[i] >= vLen) { + axis = i; + break; + } + } + numVacc *= 2; + if (axis == 3 || (vectorizeAxis == 2 && axis != 2)) { + break; + } else { + vectorizeAxis = axis; + vectorLength = vLen; + } + } + + auto loc = op.getLoc(); + auto numOutput = outputs.size(); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + SmallVector vectorTypes; + constexpr int loopUnrollTime = 16; + + SmallVector tmpBuffers; + SmallVector reduceOutputs(outputs.begin(), outputs.end()); + std::array reduceInputDims = reduceDims; + std::array reduceOutputDims = reduceInputDims; + reduceOutputDims[reduceAxis] = 1; + + bool needTranspose = false; + std::array transposeLayout = {0, 1, 2}; + SmallVector transposeLayoutValue; + if (vectorizeAxis != 2) { + needTranspose = true; + transposeLayout[vectorizeAxis] = 2; + transposeLayout[2] = vectorizeAxis; + } + + if (needTranspose) { + reduceInputDims[0] = reduceDims[transposeLayout[0]]; + reduceInputDims[1] = reduceDims[transposeLayout[1]]; + reduceInputDims[2] = reduceDims[transposeLayout[2]]; + reduceAxis = transposeLayout[reduceAxis]; + llvm::transform(transposeLayout, std::back_inserter(transposeLayoutValue), + [&](auto dim) { + return rewriter.create( + loc, dim, rewriter.getI32Type()); + }); + auto tag = getPrivateDTETag(rewriter, op); + llvm::transform(inputs, std::back_inserter(tmpBuffers), [&](auto input) { + auto memrefTy = cast(input.getType()); + auto elementTy = memrefTy.getElementType(); + auto tmpBuffer = rewriter.create( + loc, + MemRefType::get(ArrayRef{reduceInputDims}, elementTy)); + rewriter.create( + loc, tmpBuffer, input, transposeLayoutValue, tag, ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, + memrefTy.getNumElements())); + return tmpBuffer; + }); + inputs = tmpBuffers; + reduceOutputDims = reduceInputDims; + reduceOutputDims[reduceAxis] = 1; + if (vectorizeAxis == 0) { + assert(reduceAxis == 1); + llvm::transform( + elementTypes, reduceOutputs.begin(), [&](auto elementTy) { + return rewriter.create( + loc, MemRefType::get(ArrayRef{reduceOutputDims}, + elementTy)); + }); + } else { + llvm::transform(outputs, reduceOutputs.begin(), [&](auto output) { + auto memrefType = MemRefType::get( + ArrayRef{reduceOutputDims[0], reduceOutputDims[1], + reduceOutputDims[2]}, + cast(output.getType()).getElementType()); + return rewriter.create( + loc, memrefType, output, 0, ArrayRef{reduceOutputDims}, + ArrayRef{reduceOutputDims[2] * reduceOutputDims[1], + reduceOutputDims[2], 1}); + }); + needTranspose = false; + } + vectorizeAxis = 2; + } + + assert(vectorizeAxis == 2); + assert(reduceAxis == 1 || reduceAxis == 2); + bool isReduce1D = reduceDims[0] == 1 && reduceDims[1] == 1; + auto &combineOp = op.getCombineOp(); + SmallVector reduceInputs(inputs.begin(), inputs.end()); + for (unsigned i = 0; i < numOutput; ++i) { + auto elementTy = cast(inputs[i].getType()).getElementType(); + if (elementTy.isInteger(1)) { + reduceInputs[i] = rewriter.create( + loc, MemRefType::get(reduceInputDims, rewriter.getIntegerType(8)), + rewriter.create( + loc, + MemRefType::get(ArrayRef{ShapedType::kDynamic}, + rewriter.getIntegerType(8)), + rewriter.create( + loc, + mlir::gcu::PtrType::get(rewriter.getContext(), elementTy), + reduceInputs[i])), + 0, ArrayRef{reduceInputDims}, + ArrayRef{reduceInputDims[2] * reduceInputDims[1], + reduceInputDims[2], 1}); + reduceOutputs[i] = rewriter.create( + loc, MemRefType::get(reduceOutputDims, rewriter.getIntegerType(8)), + rewriter.create( + loc, + MemRefType::get(ArrayRef{ShapedType::kDynamic}, + rewriter.getIntegerType(8)), + rewriter.create( + loc, + mlir::gcu::PtrType::get(rewriter.getContext(), elementTy), + reduceOutputs[i])), + 0, ArrayRef{reduceOutputDims}, + ArrayRef{reduceOutputDims[2] * reduceOutputDims[1], + reduceOutputDims[2], 1}); + elementTypes[i] = rewriter.getIntegerType(8); + } + } + inputs = reduceInputs; + llvm::transform(elementTypes, std::back_inserter(vectorTypes), + [vectorLength](auto elementTy) { + return VectorType::get(ArrayRef{vectorLength}, + elementTy); + }); + + auto vLength = rewriter.create(loc, vectorLength); + SmallVector cur; + SmallVector next; + auto loopLimit = reduceAxis == 1 || !isReduce1D + ? reduceInputDims[1] + : reduceInputDims[2] / vectorLength; + auto loopCnt = loopLimit > loopUnrollTime ? loopUnrollTime : loopLimit; + auto loopCntValue = rewriter.create(loc, loopCnt); + if (reduceAxis == 1) { + rewriter.create( + loc, zero, + rewriter.create(loc, reduceInputDims[0]), one, + ValueRange{}, + [&](OpBuilder &builder, Location loc, Value iter0, + ValueRange iterArgs) { + builder.create( + loc, zero, + builder.create(loc, reduceInputDims[2]), + vLength, ValueRange{}, + [&](OpBuilder &builder, Location loc, Value iter2, + ValueRange iterArgs) { + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < numOutput; ++j) { + cur.emplace_back(builder.create( + loc, vectorTypes[j], inputs[j], + ValueRange{ + iter0, + builder.create(loc, i), + iter2})); + } + } + auto loop = builder.create( + loc, loopCntValue, + builder.create( + loc, reduceInputDims[1]), + loopCntValue, cur, + [&](OpBuilder &builder, Location loc, Value iter1, + ValueRange iterArgs) { + next.resize(loopCnt * numOutput); + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < numOutput; ++j) { + next[i * numOutput + j] = + builder.create( + loc, vectorTypes[j], inputs[j], + ValueRange{ + iter0, + builder.create( + loc, + builder + .create( + loc, i), + iter1), + iter2}); + } + } + SmallVector args(numOutput * 2); + SmallVector terminatorOperands; + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < numOutput; ++j) { + args[j] = iterArgs[i * numOutput + j]; + args[numOutput + j] = next[i * numOutput + j]; + } + terminatorOperands.append( + vectorizeCombineOpWithoutTerminator( + loc, builder, combineOp, args, vectorLength)); + } + vectorizeCombineOpTerminator(loc, builder, + terminatorOperands); + }); + cur.reserve(cur.size() * 2); + cur.assign(loop.getResults().begin(), + loop.getResults().end()); + auto iter = cur.begin(); + while (loopCnt != 1) { + loopCnt /= 2; + for (auto i = 0; i < loopCnt; ++i) { + cur.append(vectorizeCombineOpWithoutTerminator( + loc, builder, combineOp, + ValueRange(iter, 2 * numOutput), vectorLength, + loopCnt == 1)); + iter = std::next(iter, 2 * numOutput); + } + } + for (unsigned i = 0; i < numOutput; ++i) { + builder.create( + loc, *iter++, reduceOutputs[i], + ValueRange{iter0, zero, iter2}); + } + builder.create(loc); + }); + builder.create(loc); + }); + } else { + // reduceAxis == 2 + rewriter.create( + loc, zero, + rewriter.create(loc, reduceInputDims[0]), one, + ValueRange{}, + [&](OpBuilder &builder, Location loc, Value iter0, + ValueRange iterArgs) { + if (isReduce1D) { + builder.create( + loc, zero, + builder.create(loc, + reduceInputDims[1]), + one, ValueRange{}, + [&](OpBuilder &builder, Location loc, Value iter1, + ValueRange iterArgs) { + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < numOutput; ++j) { + cur.emplace_back(builder.create( + loc, vectorTypes[j], inputs[j], + ValueRange{iter0, iter1, + builder.create( + loc, i * vectorLength)})); + } + } + auto loop = builder.create( + loc, + builder.create( + loc, loopCnt * vectorLength), + builder.create( + loc, reduceInputDims[2]), + builder.create( + loc, loopCnt * vectorLength), + cur, + [&](OpBuilder &builder, Location loc, Value iter2, + ValueRange iterArgs) { + next.resize(loopCnt * numOutput); + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < numOutput; ++j) { + next[i * numOutput + j] = + builder.create( + loc, vectorTypes[j], inputs[j], + ValueRange{ + iter0, iter1, + builder.create( + loc, + builder.create< + arith::ConstantIndexOp>( + loc, i * vectorLength), + iter2)}); + } + } + SmallVector args(numOutput * 2); + SmallVector terminatorOperands; + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < numOutput; ++j) { + args[j] = iterArgs[i * numOutput + j]; + args[numOutput + j] = next[i * numOutput + j]; + } + terminatorOperands.append( + vectorizeCombineOpWithoutTerminator( + loc, builder, combineOp, args, + vectorLength)); + } + vectorizeCombineOpTerminator(loc, builder, + terminatorOperands); + }); + cur.reserve(cur.size() * 2); + cur.assign(loop.getResults().begin(), + loop.getResults().end()); + auto iter = cur.begin(); + while (loopCnt != 1) { + loopCnt /= 2; + for (unsigned i = 0; i < loopCnt; ++i) { + cur.append(vectorizeCombineOpWithoutTerminator( + loc, builder, combineOp, + ValueRange(iter, 2 * numOutput), vectorLength)); + iter = std::next(iter, 2 * numOutput); + } + } + auto results = + vReduce(loc, builder, combineOp, + ValueRange(iter, numOutput), vectorLength); + for (unsigned i = 0; i < numOutput; ++i) { + if (cast(reduceOutputs[i].getType()) + .getElementType() + .isInteger(1)) { + results[i] = rewriter.create( + loc, rewriter.getI1Type(), results[i]); + } + builder.create( + loc, results[i], reduceOutputs[i], + ValueRange{iter0, iter1, zero}); + } + builder.create(loc); + }); + } else { + builder.create( + loc, zero, + builder.create(loc, + reduceInputDims[1]), + loopCntValue, ValueRange{}, + [&](OpBuilder &builder, Location loc, Value iter1, + ValueRange iterArgs) { + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < numOutput; ++j) { + cur.emplace_back(builder.create( + loc, vectorTypes[j], inputs[j], + ValueRange{ + iter0, + builder.create( + loc, + builder.create(loc, + i), + iter1), + zero})); + } + } + auto loop = builder.create( + loc, + builder.create(loc, + vectorLength), + builder.create( + loc, reduceInputDims[2]), + builder.create(loc, + vectorLength), + cur, + [&](OpBuilder &builder, Location loc, Value iter2, + ValueRange iterArgs) { + next.resize(loopCnt * numOutput); + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < numOutput; ++j) { + next[i * numOutput + j] = + builder.create( + loc, vectorTypes[j], inputs[j], + ValueRange{ + iter0, + builder.create( + loc, + builder.create< + arith::ConstantIndexOp>(loc, + i), + iter1), + iter2}); + } + } + SmallVector args(numOutput * 2); + SmallVector terminatorOperands; + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < numOutput; ++j) { + args[j] = iterArgs[i * numOutput + j]; + args[numOutput + j] = next[i * numOutput + j]; + } + terminatorOperands.append( + vectorizeCombineOpWithoutTerminator( + loc, builder, combineOp, args, + vectorLength)); + } + vectorizeCombineOpTerminator(loc, builder, + terminatorOperands); + }); + for (unsigned i = 0; i < loopCnt; ++i) { + for (unsigned j = 0; j < numOutput; ++j) { + auto results = + vReduce(loc, builder, combineOp, + ValueRange(loop.getResults().slice( + i * numOutput, numOutput)), + vectorLength); + if (cast(reduceOutputs[j].getType()) + .getElementType() + .isInteger(1)) { + results[j] = rewriter.create( + loc, rewriter.getI1Type(), results[j]); + } + builder.create( + loc, results[j], reduceOutputs[j], + ValueRange{ + iter0, + builder.create( + loc, + builder.create(loc, + i), + iter1), + zero}); + } + } + builder.create(loc); + }); + } + builder.create(loc); + }); + } + + for (auto buffer : tmpBuffers) { + rewriter.create(loc, buffer); + } + if (needTranspose) { + for (unsigned i = 0; i < numOutput; ++i) { + auto memrefTy = cast(outputs[i].getType()); + auto tag = getPrivateDTETag(rewriter, op); + rewriter.create( + loc, outputs[i], reduceOutputs[i], transposeLayoutValue, tag, + ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, + memrefTy.getNumElements())); + rewriter.create(loc, reduceOutputs[i]); + } + doMemFence(rewriter, op); + } else { + rewriter.create(loc, gcu::MFenceType::Local); + } + return success(); + } + + LogicalResult + applyVectorizationImplV2(triton::ReduceOp op, OpBuilder &rewriter, + ArrayRef outputs, ArrayRef inputs, + const std::array &reduceDims, + int64_t reduceAxis) const { + SmallVector elementTypes; + unsigned maxBpe = 4; + unsigned minBpe = targetInfo[arch].supportI64 ? 8 : 4; + for (auto output : outputs) { + auto elementType = cast(output.getType()).getElementType(); + if (!elementType.isInteger(1) && !elementType.isInteger(8) && + !elementType.isInteger(16) && !elementType.isInteger(32) && + !elementType.isBF16() && !elementType.isF16() && + !elementType.isF32() && + (targetInfo[arch].supportI64 && !elementType.isInteger(64))) { + return failure(); + } + auto bpe = mlir::triton::gcu::getBpe(elementType); + maxBpe = bpe > maxBpe ? bpe : maxBpe; + minBpe = bpe < minBpe ? bpe : minBpe; + elementTypes.push_back(elementType); + } + auto numVacc = maxBpe / minBpe; + if (numVacc > 4) { + return failure(); + } + int64_t vectorizeAxis = 3; + unsigned vectorLength = targetInfo[arch].vaccSizeInBytes / minBpe; + for (auto i = 2; i >= 0; --i) { + if (reduceDims[i] >= vectorLength) { + vectorizeAxis = i; + break; + } + } + if (vectorizeAxis == 3) { + return failure(); + } + while (numVacc < targetInfo[arch].preferVaccNum) { + int64_t axis = 3; + unsigned vLen = 2 * vectorLength; + for (auto i = 2; i >= 0; --i) { + if (reduceDims[i] >= vLen) { + axis = i; + break; + } + } + numVacc *= 2; + if (axis == 3 || (vectorizeAxis == 2 && axis != 2)) { + break; + } else { + vectorizeAxis = axis; + vectorLength = vLen; + } + } + + auto vectorTypes = llvm::to_vector( + llvm::map_range(elementTypes, [vectorLength](auto elementTy) { + return VectorType::get(ArrayRef{vectorLength}, elementTy); + })); + auto loc = op.getLoc(); + SmallVector reduceBuffers; + llvm::transform(outputs, std::back_inserter(reduceBuffers), + [&](auto output) { + if (vectorizeAxis == reduceAxis) { + auto memrefTy = cast(output.getType()); + auto elementTy = memrefTy.getElementType(); + SmallVector reduceBufferDims( + reduceDims.begin(), reduceDims.end()); + reduceBufferDims[vectorizeAxis] = vectorLength; + Value buffer = rewriter.create( + loc, MemRefType::get(reduceBufferDims, elementTy)); + return buffer; + } else { + return output; + } + }); + auto tag = getPrivateDTETag(rewriter, op); + auto numOutput = outputs.size(); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto stride = + std::accumulate(reduceDims.begin() + reduceAxis, reduceDims.end(), 1, + std::multiplies()); + auto elementsPerStride = stride / reduceDims[reduceAxis]; + auto numElements = product(reduceDims) / reduceDims[reduceAxis]; + if (vectorizeAxis == reduceAxis) { + elementsPerStride *= vectorLength; + numElements *= vectorLength; + } + + for (unsigned i = 0; i < numOutput; ++i) { + rewriter.create( + loc, inputs[i], SmallVector(reduceDims.size(), zero), + reduceBuffers[i], SmallVector(reduceDims.size(), zero), + rewriter.create(loc, numElements), tag, + ValueRange{zero}, + rewriter.create(loc, stride), + rewriter.create(loc, elementsPerStride)); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, numElements)); + } + SmallVector lbs(reduceDims.size(), zero); + if (vectorizeAxis == reduceAxis) { + lbs[reduceAxis] = + rewriter.create(loc, vectorLength); + } else { + lbs[reduceAxis] = one; + } + SmallVector ubs; + for (auto dim : reduceDims) { + ubs.push_back(rewriter.create(loc, dim)); + } + + SmallVector step(reduceDims.size(), one); + step[vectorizeAxis] = + rewriter.create(loc, vectorLength); + auto maskType = + VectorType::get(ArrayRef{vectorLength}, rewriter.getI1Type()); + Value mask = rewriter.create( + loc, maskType, + DenseI64ArrayAttr::get(rewriter.getContext(), + ArrayRef{vectorLength})); + unsigned strideOnVectorizeAxis = + std::accumulate(reduceDims.begin() + vectorizeAxis + 1, + reduceDims.end(), 1, std::multiplies()); + if (vectorizeAxis < reduceAxis) { + strideOnVectorizeAxis /= reduceDims[reduceAxis]; + } + + auto vecIndexTy = VectorType::get(ArrayRef{vectorLength}, + rewriter.getIndexType()); + auto vecTy = + VectorType::get(ArrayRef{vectorLength}, rewriter.getI32Type()); + auto indexVec0 = rewriter.create( + loc, + rewriter + .create( + loc, vecTy, + rewriter.create(loc, vecIndexTy).getResult()) + .getResult(0), + rewriter.create( + loc, vecTy, + rewriter.create( + loc, rewriter.getI32Type(), + rewriter.getI32IntegerAttr(strideOnVectorizeAxis)))); + Value indexVec1 = indexVec0; + if (vectorizeAxis < reduceAxis) { + indexVec1 = rewriter.create( + loc, indexVec1, + rewriter.create( + loc, vecTy, + rewriter.create( + loc, rewriter.getI32Type(), + rewriter.getI32IntegerAttr(reduceDims[reduceAxis])))); + } + + SmallVector passThruValues; + for (unsigned i = 0; i < numOutput; ++i) { + passThruValues.push_back( + rewriter.create(loc, vectorTypes[i], zero)); + } + auto &combineOp = op.getCombineOp(); + scf::buildLoopNest( + rewriter, loc, + ArrayRef(lbs.begin(), lbs.begin() + vectorizeAxis), + ArrayRef(ubs.begin(), ubs.begin() + vectorizeAxis), + ArrayRef(step.begin(), step.begin() + vectorizeAxis), + [&](OpBuilder &builder, Location loc, ValueRange outerIters) { + scf::buildLoopNest( + rewriter, loc, + ArrayRef(lbs.begin() + vectorizeAxis, lbs.end()), + ArrayRef(ubs.begin() + vectorizeAxis, ubs.end()), + ArrayRef(step.begin() + vectorizeAxis, step.end()), + [&](OpBuilder &builder, Location loc, ValueRange innerIters) { + SmallVector inputIndices; + SmallVector outputIndices; + SmallVector resultElemTypes; + SmallVector operands; + SmallVector ivs; + for (auto iv : outerIters) { + ivs.push_back(iv); + } + for (auto iv : innerIters) { + ivs.push_back(iv); + } + for (unsigned i = 0; i < ivs.size(); ++i) { + inputIndices.push_back(ivs[i]); + if (i == reduceAxis) { + outputIndices.push_back(zero); + } else { + outputIndices.push_back(inputIndices[i]); + } + } + for (unsigned i = 0; i < numOutput; ++i) { + operands.push_back(builder.create( + loc, vectorTypes[i], reduceBuffers[i], outputIndices, + indexVec0, mask, passThruValues[i])); + resultElemTypes.push_back(vectorTypes[i]); + } + for (unsigned i = 0; i < numOutput; ++i) { + operands.push_back(builder.create( + loc, vectorTypes[i], inputs[i], inputIndices, indexVec1, + mask, operands[i])); + } + auto executeRegionOp = + builder.create(loc, resultElemTypes); + executeRegionOp.getRegion().emplaceBlock(); + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart( + &executeRegionOp.getRegion().back()); + auto terminatorOperands = vectorizeCombineOpWithoutTerminator( + loc, builder, combineOp, operands, vectorLength); + vectorizeCombineOpTerminator(loc, builder, + terminatorOperands); + } + for (unsigned i = 0; i < numOutput; ++i) { + builder.create( + loc, reduceBuffers[i], outputIndices, indexVec0, mask, + executeRegionOp.getResult(i)); + } + }); + if (reduceAxis == vectorizeAxis) { + SmallVector lowerBounds(lbs.begin() + reduceAxis, + lbs.end()); + SmallVector upperBounds(ubs.begin() + reduceAxis, + ubs.end()); + lowerBounds[0] = zero; + upperBounds[0] = one; + scf::buildLoopNest( + rewriter, loc, lowerBounds, upperBounds, + ArrayRef(step.begin() + reduceAxis, step.end()), + [&](OpBuilder &builder, Location loc, ValueRange innerIters) { + auto loopCnt = std::log2(vectorLength); + SmallVector outputIndices; + SmallVector resultElemTypes; + for (auto iv : outerIters) { + outputIndices.push_back(iv); + } + for (auto iv : innerIters) { + outputIndices.push_back(iv); + } + builder.create( + loc, zero, + builder.create(loc, loopCnt), one, + ValueRange{builder.create( + loc, vectorLength)}, + [&](OpBuilder &builder, Location loc, Value iter, + ValueRange iterArgs) { + SmallVector args; + for (unsigned i = 0; i < numOutput; ++i) { + args.push_back(builder.create( + loc, vectorTypes[i], reduceBuffers[i], + outputIndices, indexVec0, mask, + passThruValues[i])); + } + auto stride = builder.create( + loc, iterArgs[0], + builder.create(loc, 2)); + SmallVector indices(outputIndices); + indices[vectorizeAxis] = stride; + auto strideMask = builder.create( + loc, maskType, ValueRange{stride}); + for (unsigned i = 0; i < numOutput; ++i) { + args.push_back(builder.create( + loc, vectorTypes[i], reduceBuffers[i], indices, + indexVec0, strideMask, args[i])); + resultElemTypes.push_back(vectorTypes[i]); + } + auto executeRegion = + builder.create( + loc, resultElemTypes); + executeRegion.getRegion().emplaceBlock(); + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart( + &executeRegion.getRegion().back()); + auto terminatorOperands = + vectorizeCombineOpWithoutTerminator( + loc, builder, combineOp, args, vectorLength); + vectorizeCombineOpTerminator(loc, builder, + terminatorOperands); + } + for (unsigned i = 0; i < numOutput; ++i) { + builder.create( + loc, reduceBuffers[i], outputIndices, indexVec0, + strideMask, executeRegion.getResult(i)); + } + builder.create(loc, ValueRange{stride}); + }); + for (unsigned i = 0; i < numOutput; ++i) { + builder.create( + loc, + builder.create(loc, reduceBuffers[i], + outputIndices), + outputs[i], outputIndices); + } + }); + } + }); + if (vectorizeAxis == reduceAxis) { + for (auto buffer : reduceBuffers) + rewriter.create(loc, buffer); + } + doMemFence(rewriter, op); + return success(); + } + + LogicalResult applySmallSizeImpl(triton::ReduceOp op, OpBuilder &rewriter, + ArrayRef outputs, + ArrayRef inputs, + const std::array &reduceDims, + int64_t reduceAxis) const { + auto loc = op.getLoc(); + if (reduceDims[reduceAxis] == 1) { + for (auto [input, output] : llvm::zip(inputs, outputs)) { + rewriter.create(loc, input, output); + } + return success(); + } + bool isReduce1D = reduceDims[0] == 1 && reduceDims[1] == 1; + if (isReduce1D) { + unsigned minBpe = 4; + unsigned vectorSizeInBytes = targetInfo[arch].vaccSizeInBytes; + SmallVector elementTypes; + for (auto output : outputs) { + auto elementTy = cast(output.getType()).getElementType(); + if (!elementTy.isInteger(8) && !elementTy.isInteger(16) && + !elementTy.isInteger(32) && !elementTy.isBF16() && + !elementTy.isF16() && !elementTy.isF32()) { + return failure(); + } + elementTypes.push_back(elementTy); + auto bpe = mlir::triton::gcu::getBpe(elementTy); + minBpe = bpe < minBpe ? bpe : minBpe; + } + assert(minBpe * reduceDims[2] < vectorSizeInBytes); + unsigned vectorLength = vectorSizeInBytes / minBpe; + + SmallVector values; + auto zero = rewriter.create(loc, 0); + for (auto [input, elementTy] : llvm::zip(inputs, elementTypes)) { + Value v = rewriter.create( + loc, VectorType::get(ArrayRef{vectorLength}, elementTy), + input, ValueRange{zero, zero, zero}); + if (elementTy.isBF16() || elementTy.isF16()) { + values.emplace_back(rewriter.create( + loc, + VectorType::get(ArrayRef{vectorLength}, + rewriter.getF32Type()), + v)); + } else if (elementTy.isInteger(1) || elementTy.isInteger(8) || + elementTy.isInteger(16)) { + values.emplace_back(rewriter.create( + loc, + VectorType::get(ArrayRef{vectorLength}, + rewriter.getI32Type()), + v)); + } else if (elementTy.isF32() || elementTy.isInteger(32)) { + values.push_back(v); + } else { + llvm_unreachable("unsupported element type"); + } + } + auto &combineOp = op.getCombineOp(); + auto vaccLen = targetInfo[arch].vaccSizeInBytes / 4; + auto argNum = values.size(); + auto vrLen = vaccLen / 2; + unsigned validLength = reduceDims[2]; + if (vectorLength > vaccLen) { + auto validVaccNum = ceil(validLength, vrLen); + int numVacc = vectorLength / vaccLen; + SmallVector splitValues(argNum * validVaccNum); + for (size_t i = 0; i < argNum; ++i) { + auto elementType = + cast(values[i].getType()).getElementType(); + SmallVector resultTypes( + numVacc, + VectorType::get(ArrayRef{vaccLen}, elementType)); + auto vectorConvertOp = rewriter.create( + loc, resultTypes, values[i]); + for (auto j = 0; j < validVaccNum; ++j) { + splitValues[j * argNum + i] = vectorConvertOp.getResult(j); + } + } + splitValues.reserve(splitValues.size() * 2); + auto iter = splitValues.begin(); + while (validVaccNum != 1) { + validVaccNum /= 2; + for (auto i = 0; i < validVaccNum; ++i) { + splitValues.append(vectorizeCombineOpWithoutTerminator( + loc, rewriter, combineOp, ValueRange(iter, 2 * argNum), + vaccLen)); + iter = std::next(iter, 2 * argNum); + } + } + values.assign(iter, iter + argNum); + validLength = validLength < vrLen ? validLength : vrLen; + } + if (validLength == vrLen) { + for (size_t i = 0; i < argNum; ++i) { + values.emplace_back(rewriter.create( + loc, gcu::VectorMovSftMode::SHFLQW, values[i], 2)); + } + values = vectorizeCombineOpWithoutTerminator(loc, rewriter, combineOp, + values, vaccLen); + } else { + for (size_t i = 0; i < argNum; ++i) { + values[i] = rewriter.create( + loc, gcu::VectorMovSftMode::SHFLQW, values[i], 2); + } + } + if (validLength >= vrLen / 2) { + for (size_t i = 0; i < argNum; ++i) { + values.emplace_back(rewriter.create( + loc, gcu::VectorMovSftMode::SHFLQW, values[i], 1)); + } + values = vectorizeCombineOpWithoutTerminator(loc, rewriter, combineOp, + values, vaccLen); + } else { + for (size_t i = 0; i < argNum; ++i) { + values[i] = rewriter.create( + loc, gcu::VectorMovSftMode::SHFLQW, values[i], 1); + } + } + if (validLength >= vrLen / 4) { + for (size_t i = 0; i < argNum; ++i) { + values.emplace_back(rewriter.create( + loc, gcu::VectorMovSftMode::SHFLB, values[i], 8)); + } + values = vectorizeCombineOpWithoutTerminator(loc, rewriter, combineOp, + values, vaccLen); + } else { + for (size_t i = 0; i < argNum; ++i) { + values[i] = rewriter.create( + loc, gcu::VectorMovSftMode::SHFLB, values[i], 8); + } + } + for (size_t i = 0; i < argNum; ++i) { + values.emplace_back(rewriter.create( + loc, gcu::VectorMovSftMode::SHFLB, values[i], 4)); + } + values = vectorizeCombineOpWithoutTerminator(loc, rewriter, combineOp, + values, vaccLen); + for (size_t i = 0; i < argNum; ++i) { + values[i] = rewriter.create( + loc, values[i], rewriter.create(loc, 15)); + auto elementType = elementTypes[i]; + if (elementType.isBF16() || elementType.isF16()) { + values[i] = + rewriter.create(loc, elementType, values[i]); + } else if (elementType.isInteger(1) || elementType.isInteger(8) || + elementType.isInteger(16)) { + values[i] = + rewriter.create(loc, elementType, values[i]); + } + + rewriter.create(loc, values[i], outputs[i], + ValueRange{zero, zero, zero}); + } + return success(); + } + return failure(); + } + + SmallVector vReduce(Location loc, OpBuilder &builder, + Region &combineOp, ValueRange vecValues, + int64_t vectorLength) const { + assert(llvm::all_of(vecValues.getTypes(), [&](auto ty) { + auto vecTy = dyn_cast(ty); + return vecTy && vecTy.getRank() == 1 && + vecTy.getDimSize(0) == vectorLength; + })); + + if (auto kind = matchReduceCombiningKind(combineOp)) { + if (kind == vector::CombiningKind::MAXNUMF) { + assert(vecValues.size() == 1); + SmallVector values{ + builder.create(loc, *kind, vecValues[0])}; + return values; + } + } + + SmallVector values; + SmallVector elementTypes; + for (auto v : vecValues) { + auto elementTy = cast(v.getType()).getElementType(); + elementTypes.push_back(elementTy); + if (elementTy.isBF16() || elementTy.isF16()) { + values.emplace_back(builder.create( + loc, + VectorType::get(ArrayRef{vectorLength}, + builder.getF32Type()), + v)); + } else if (elementTy.isInteger(1) || elementTy.isInteger(8) || + elementTy.isInteger(16)) { + values.emplace_back(builder.create( + loc, + VectorType::get(ArrayRef{vectorLength}, + builder.getI32Type()), + v)); + } else if (elementTy.isF32() || elementTy.isInteger(32)) { + values.push_back(v); + } else { + llvm_unreachable("unsupported element type"); + } + } + auto argNum = values.size(); + auto vaccLen = targetInfo[arch].vaccSizeInBytes / 4; + if (vectorLength != vaccLen) { + auto splitNum = vectorLength / vaccLen; + SmallVector splitValues(argNum * splitNum); + for (size_t i = 0; i < argNum; ++i) { + auto elementType = + cast(values[i].getType()).getElementType(); + SmallVector resultTypes( + splitNum, VectorType::get(ArrayRef{vaccLen}, elementType)); + auto vectorConvertOp = + builder.create(loc, resultTypes, values[i]); + for (auto j = 0; j < splitNum; ++j) { + splitValues[j * argNum + i] = vectorConvertOp.getResult(j); + } + } + splitValues.reserve(splitValues.size() * 2); + auto iter = splitValues.begin(); + while (splitNum != 1) { + splitNum /= 2; + for (auto i = 0; i < splitNum; ++i) { + splitValues.append(vectorizeCombineOpWithoutTerminator( + loc, builder, combineOp, ValueRange(iter, 2 * argNum), vaccLen)); + iter = std::next(iter, 2 * argNum); + } + } + values.assign(iter, iter + argNum); + } + for (size_t i = 0; i < argNum; ++i) { + values.emplace_back(builder.create( + loc, gcu::VectorMovSftMode::SHFRQW, values[i], 2)); + } + values = vectorizeCombineOpWithoutTerminator(loc, builder, combineOp, + values, vaccLen); + for (size_t i = 0; i < argNum; ++i) { + values.emplace_back(builder.create( + loc, gcu::VectorMovSftMode::SHFRQW, values[i], 1)); + } + values = vectorizeCombineOpWithoutTerminator(loc, builder, combineOp, + values, vaccLen); + for (size_t i = 0; i < argNum; ++i) { + values.emplace_back(builder.create( + loc, gcu::VectorMovSftMode::SHFRB, values[i], 8)); + } + values = vectorizeCombineOpWithoutTerminator(loc, builder, combineOp, + values, vaccLen); + for (size_t i = 0; i < argNum; ++i) { + values.emplace_back(builder.create( + loc, gcu::VectorMovSftMode::SHFRB, values[i], 4)); + } + values = vectorizeCombineOpWithoutTerminator(loc, builder, combineOp, + values, vaccLen); + for (size_t i = 0; i < argNum; ++i) { + values.emplace_back(builder.create( + loc, values[i].getType(), + builder.create( + loc, values[i], + builder.create(loc, 16)))); + } + values = vectorizeCombineOpWithoutTerminator(loc, builder, combineOp, + values, vaccLen); + for (size_t i = 0; i < argNum; ++i) { + values[i] = builder.create( + loc, values[i], builder.create(loc, 0)); + } + + for (size_t i = 0; i < argNum; ++i) { + auto elementType = elementTypes[i]; + if (elementType.isBF16() || elementType.isF16()) { + values[i] = + builder.create(loc, elementType, values[i]); + } else if (elementType.isInteger(1) || elementType.isInteger(8) || + elementType.isInteger(16)) { + values[i] = + builder.create(loc, elementType, values[i]); + } + } + return values; + } +}; +} // namespace + +void mlir::triton::populateReduceOpToGCUPatterns( + const TypeConverter &converter, RewritePatternSet &patterns, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin) { + patterns.add(converter, patterns.getContext(), + userAnalysis, replaced2Origin); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonConvertTensorPointer.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonConvertTensorPointer.cpp new file mode 100644 index 000000000..51169856f --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonConvertTensorPointer.cpp @@ -0,0 +1,615 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "Conversion/TritonToGCU/TritonToGCUPass.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/Support/Debug.h" + +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#define DEBUG_TYPE "triton-convert-tensor-pointer" +namespace mlir { +#define GEN_PASS_DEF_CONVERTTENSORPOINTERPASS +#include "Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +/// An additional struct to record the meta information of operations +/// with tensor pointers +struct RewritedInfo { +private: + Value base; + SmallVector shape; + SmallVector strides; + SmallVector offsets; + ArrayRef tensorShape; + Attribute encoding; + + // A cache to avoid generating the same offset with range + DenseMap cachedOffsetWithRange; + +public: + RewritedInfo() = default; + + RewritedInfo(const RewritedInfo &other) = default; + RewritedInfo &operator=(const RewritedInfo &other) = default; + + RewritedInfo(Value base, const SmallVector &shape, + const SmallVector &strides, + const SmallVector &offsets, + const ArrayRef &tensorShape, const Attribute &encoding) + : base(base), shape(shape), strides(strides), offsets(offsets), + tensorShape(tensorShape), encoding(encoding) { + assert(shape.size() == strides.size() && shape.size() == offsets.size() && + shape.size() == tensorShape.size()); + } + + unsigned int length() const { return shape.size(); } + + Value getOffset(unsigned i) { return offsets[i]; } + + SmallVector getOffsets() { return offsets; } + + void setOffset(unsigned i, Value newOffset) { + offsets[i] = newOffset; + cachedOffsetWithRange.clear(); + } + + void setOffsets(const SmallVector &newOffsets) { + offsets = newOffsets; + cachedOffsetWithRange.clear(); + } + + Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc, + unsigned i) { + if (cachedOffsetWithRange.count(i)) + return cachedOffsetWithRange[i]; + + auto srcEncoding = cast(encoding); + for (int j = tensorShape.size() - 1; j >= 0; j--) { + if (static_cast(j) == i) + continue; + srcEncoding = triton::gpu::SliceEncodingAttr::get( + builder.getContext(), j, + cast(srcEncoding)); + } + + auto indexI32RowType = RankedTensorType::get( + {tensorShape[i]}, builder.getI32Type(), srcEncoding); + + Value splatOffset = + builder.create(loc, indexI32RowType, offsets[i]); + Value range = builder.create(loc, indexI32RowType, 0, + tensorShape[i]); + // Expand dimensions + Value expandedResult = + builder.create(loc, splatOffset, range); + + for (unsigned j = 0; j < tensorShape.size(); ++j) { + if (j == i) + continue; + + expandedResult = + builder.create(loc, expandedResult, j); + } + + return cachedOffsetWithRange[i] = expandedResult; + } + + Value generatePtr(OpBuilder &builder, const Location &loc) { + assert(tensorShape.size() == offsets.size() && + tensorShape.size() == strides.size()); + + auto indexTensorType = + RankedTensorType::get(tensorShape, builder.getI32Type(), encoding); + auto ptrType = cast(base.getType()); + auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType, encoding); + + // Generate offsets per dimension + Value ptr = builder.create(loc, ptrTensorType, base); + for (unsigned i = 0; i < tensorShape.size(); ++i) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // We must splat strides into the expanded shape not a row for retaining + // the divisibility information given by strides + Value splatStride = builder.create( + loc, offsetWithRange.getType(), strides[i]); + Value offsetWithStride = + builder.create(loc, offsetWithRange, splatStride); + Value broadcasted = builder.create( + loc, indexTensorType, offsetWithStride); + + // Add to the pointer + ptr = builder.create(loc, ptrTensorType, ptr, + broadcasted); + } + + return ptr; + } + + Value generateMask(OpBuilder &builder, const Location &loc, + const std::optional> &boundaryCheck) { + if (!boundaryCheck.has_value()) + return {}; + + // Generate mask per dimension + auto maskTensorType = + RankedTensorType::get(tensorShape, builder.getI1Type(), encoding); + Value mask; + for (auto i : boundaryCheck.value()) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // Compare with lower bound + Value lowerBound = builder.create( + loc, 0, builder.getI32Type()); + Value splatLowerBound = builder.create( + loc, offsetWithRange.getType(), lowerBound); + Value cmpLower = builder.create( + loc, arith::CmpIPredicate::sge, offsetWithRange, splatLowerBound); + + // Compare with upper bound + Value splatUpperBound = builder.create( + loc, offsetWithRange.getType(), shape[i]); + Value cmpUpper = builder.create( + loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound); + + // And and broadcast + Value andResult = builder.create(loc, cmpLower, cmpUpper); + Value broadcasted = + builder.create(loc, maskTensorType, andResult); + + // And up all results + if (!mask) { + mask = broadcasted; + } else { + mask = builder.create(loc, mask, broadcasted); + } + } + + return mask; + } + + Value generateOther(OpBuilder &builder, const Location &loc, + const std::optional &padding) { + if (!padding.has_value()) + return Value(); + + // Create element attribute + auto elementType = + cast(base.getType()).getPointeeType(); + auto otherTensorType = + RankedTensorType::get(tensorShape, elementType, encoding); + + // Set zero padding value + TypedAttr attr = + elementType.isIntOrIndex() + ? cast(builder.getIntegerAttr(elementType, 0)) + : cast(builder.getFloatAttr(elementType, 0)); + + // Float NaN padding case + if (padding.value() == triton::PaddingOption::PAD_NAN) { + assert(!elementType.isIntOrIndex()); + auto apNaN = llvm::APFloat::getNaN( + cast(attr).getValue().getSemantics()); + attr = builder.getFloatAttr(elementType, apNaN); + } + + // Create tensor + Value constant = builder.create(loc, attr); + return builder.create(loc, otherTensorType, constant); + } +}; + +static bool needRewrite(Operation *op) { + return std::any_of(op->getOperands().begin(), op->getOperands().end(), + [](Value operand) { + return triton::isTensorPointerType(operand.getType()); + }); +} + +static SmallVector +generateNewOperands(const SmallVector &oldOperands, unsigned index, + const SmallVector &newValues) { + assert(index < oldOperands.size()); + SmallVector newOperands; + for (unsigned i = 0; i < index; ++i) + newOperands.push_back(oldOperands[i]); + for (auto value : newValues) + newOperands.push_back(value); + for (auto i = index + 1; i < oldOperands.size(); ++i) + newOperands.push_back(oldOperands[i]); + return newOperands; +} + +struct ConvertTensorPointerPass + : public mlir::impl::ConvertTensorPointerPassBase< + ConvertTensorPointerPass> { + using Base::Base; + void runOnOperation() override; + + Operation *rewriteMakeTensorPtrOp(OpBuilder &builder, + triton::MakeTensorPtrOp op, + std::stack &eraser); + Operation *rewriteAdvanceOp(OpBuilder &builder, triton::AdvanceOp op, + std::stack &eraser); + Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op, + std::stack &eraser); + Operation *rewriteIfOp(OpBuilder &builder, scf::IfOp op, + std::stack &eraser); + Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op, + std::stack &eraser); + Operation *rewriteYieldOp(OpBuilder &builder, scf::YieldOp op, + std::stack &eraser); + Operation *rewriteOp(Operation *op, std::stack &eraser); + void visitOperation(Operation *op, std::stack &eraser); + +private: + DenseMap rewritedInfo; +}; + +} // namespace + +Operation *ConvertTensorPointerPass::rewriteMakeTensorPtrOp( + OpBuilder &builder, triton::MakeTensorPtrOp op, + std::stack &eraser) { + // Save info for later use + auto ptrType = cast(op.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + + // Cast I64 Shape,Strides into I32 + SmallVector i32Shape, i32Strides; + for (auto dim : op.getShape()) { + auto i32Dim = + builder.create(op.getLoc(), builder.getI32Type(), dim); + i32Shape.push_back(i32Dim); + } + + for (auto stride : op.getStrides()) { + auto i32Stride = builder.create( + op.getLoc(), builder.getI32Type(), stride); + i32Strides.push_back(i32Stride); + } + + // Save information + rewritedInfo[op.getResult()] = + RewritedInfo(op.getBase(), i32Shape, i32Strides, op.getOffsets(), + tensorType.getShape(), tensorType.getEncoding()); + + // Erase the original operation + eraser.push(op); + return nullptr; +} + +Operation *ConvertTensorPointerPass::rewriteAdvanceOp( + OpBuilder &builder, triton::AdvanceOp op, std::stack &eraser) { + // Get info from previous results + assert(rewritedInfo.count(op.getPtr())); + auto info = rewritedInfo[op.getPtr()]; + + // Calculate new offsets + assert(info.length() == op.getOffsets().size()); + SmallVector newOffsets; + for (unsigned i = 0; i < info.length(); ++i) { + // Value i64Offset = builder.create( + // op.getLoc(), builder.getI64Type(), op.getOffsets()[i]); + Value newOffset = builder.create( + op.getLoc(), info.getOffset(i), /*i64Offset*/ op.getOffsets()[i]); + newOffsets.push_back(newOffset); + } + + // Save info for later use + info.setOffsets(newOffsets); + rewritedInfo[op.getResult()] = info; + + // Erase the original operation + eraser.push(op); + return nullptr; +} + +Operation * +ConvertTensorPointerPass::rewriteLoadStoreOp(OpBuilder &builder, Operation *op, + std::stack &eraser) { + assert(isa(op) || isa(op)); + + // We only have to rewrite load/stores with tensor pointers + auto ptr = op->getOperand(0); + if (!triton::isTensorPointerType(ptr.getType())) + return nullptr; + + // Get info from previous results + assert(rewritedInfo.count(ptr)); + auto info = rewritedInfo[ptr]; + + // Load/store with tensor pointers implicitly will check the bound while + // accessing memory, so we should set `mask` and `other` (according to the + // padding). Also note that load with tensor pointers do not have `mask` and + // `other` while building IR from Python AST + std::optional> boundaryCheck; + if (auto loadOp = dyn_cast(op)) { + assert(!loadOp.getMask() && !loadOp.getOther()); + boundaryCheck = loadOp.getBoundaryCheck(); + } else if (auto storeOp = dyn_cast(op)) { + assert(!storeOp.getMask()); + boundaryCheck = storeOp.getBoundaryCheck(); + } + + // Generate new `ptr`, `mask` and `other` + auto newPtr = info.generatePtr(builder, op->getLoc()); + auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck); + Value newOther; + if (auto loadOp = dyn_cast(op)) + newOther = info.generateOther(builder, op->getLoc(), loadOp.getPadding()); + + // Create a new operation + if (auto loadOp = dyn_cast(op)) { + auto newResult = builder.create( + loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(), + loadOp.getEvict(), loadOp.getIsVolatile()); + op->getResult(0).replaceAllUsesWith(newResult); + } else if (auto storeOp = dyn_cast(op)) { + builder.create(storeOp.getLoc(), newPtr, + storeOp.getValue(), newMask, + storeOp.getCache(), storeOp.getEvict()); + } + + // Erase the original operation + eraser.push(op); + return nullptr; +} + +Operation * +ConvertTensorPointerPass::rewriteIfOp(OpBuilder &builder, scf::IfOp op, + std::stack &eraser) { + auto thenYieldOp = op.thenYield(); + assert(op.getNumResults() == thenYieldOp.getNumOperands()); + SmallVector results = thenYieldOp.getOperands(); + + // get new result types + SmallVector newRetTypes; + bool needRewrite = false; + for (unsigned i = 0; i < results.size(); ++i) { + if (!triton::isTensorPointerType(results[i].getType())) { + newRetTypes.push_back(results[i].getType()); + continue; + } + needRewrite = true; + auto makeTensorPtrOp = getMakeTensorPtrOp(results[i]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + newRetTypes.push_back(builder.getI64Type()); + } + } + if (!needRewrite) + return op; + // create and clone new IfOp + bool hasElse = !op.getElseRegion().empty(); + scf::IfOp newOp = builder.create(op.getLoc(), newRetTypes, + op.getCondition(), hasElse); + IRMapping mapping; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + mapping.map(op->getOperand(i), newOp->getOperand(i)); + } + auto rematerialize = [&](Block *block) { + for (Operation &opInIf : block->getOperations()) { + builder.clone(opInIf, mapping); + } + }; + builder.setInsertionPointToStart(newOp.thenBlock()); + rematerialize(op.thenBlock()); + if (hasElse) { + builder.setInsertionPointToStart(newOp.elseBlock()); + rematerialize(op.elseBlock()); + } + + // update rewritedInfo + unsigned oldResIdx = 0, newResIdx = 0; + while (oldResIdx < results.size()) { + if (!triton::isTensorPointerType(results[oldResIdx].getType())) { + oldResIdx++; + newResIdx++; + } else { + auto makeTensorPtrOp = getMakeTensorPtrOp(results[oldResIdx]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + info.setOffset(j, newOp->getResult(newResIdx++)); + } + rewritedInfo[op.getResult(oldResIdx)] = info; + oldResIdx++; + } + } + + eraser.push(op); + return newOp; +} + +Operation * +ConvertTensorPointerPass::rewriteForOp(OpBuilder &builder, scf::ForOp op, + std::stack &eraser) { + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; + ++i, ++oldI) { + if (!triton::isTensorPointerType(newIterOperands[i].getType())) + continue; + + // Expand the tensor pointer into offsets + assert(rewritedInfo.count(newIterOperands[i])); + auto info = rewritedInfo[newIterOperands[i]]; + newIterOperands = + generateNewOperands(newIterOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + + // Rebuild the loop type + auto newForOp = builder.create(op.getLoc(), op.getLowerBound(), + op.getUpperBound(), op.getStep(), + newIterOperands); + + // Create value mapping. Note that for tensor pointers, we use identity + // mapping. It may refer to a value in the old loop, but we will rewrite it + // later + IRMapping mapping; + for (unsigned i = 0, oldI = 0, sz = op.getInitArgs().size(); oldI < sz; + ++i, ++oldI) { + auto oldRegionIterArg = op.getRegionIterArg(oldI); + if (triton::isTensorPointerType(oldRegionIterArg.getType())) { + // Pass rewritten info inside + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + mapping.map(oldRegionIterArg, oldRegionIterArg); + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getRegionIterArg(i + j)); + rewritedInfo[oldRegionIterArg] = info; + i += info.length() - 1; + } else { + mapping.map(oldRegionIterArg, newForOp.getRegionIterArg(i)); + } + } + mapping.map(op.getInductionVar(), newForOp.getInductionVar()); + + // Clone body + builder.setInsertionPointToStart(newForOp.getBody()); + for (auto &opInFor : *op.getBody()) { + auto *newOp = builder.clone(opInFor, mapping); + for (unsigned i = 0; i < opInFor.getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + } + + // Replace later usages + assert(op.getNumResults() == op.getInitArgs().size()); + for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { + auto oldResult = op.getResult(oldI); + if (triton::isTensorPointerType(oldResult.getType())) { + // Pack new offsets into rewritten info + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getResult(i + j)); + i += info.length() - 1; + rewritedInfo[oldResult] = info; + } else { + oldResult.replaceAllUsesWith(newForOp.getResult(i)); + } + } + + // Erase later + eraser.push(op); + return newForOp; +} + +Operation * +ConvertTensorPointerPass::rewriteYieldOp(OpBuilder &builder, scf::YieldOp op, + std::stack &eraser) { + // Replace tensor pointers with offsets + SmallVector newOperands = op->getOperands(); + for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) { + if (!triton::isTensorPointerType(newOperands[i].getType())) + continue; + + assert(rewritedInfo.count(newOperands[i])); + auto info = rewritedInfo[newOperands[i]]; + newOperands = generateNewOperands(newOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + op->setOperands(newOperands); + + // No need to erase + return nullptr; +} + +Operation * +ConvertTensorPointerPass::rewriteOp(Operation *op, + std::stack &eraser) { + OpBuilder builder(op); + + // Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers + // Rewriting functions return the next operation to visit, if there is no + // next one, simply return `nullptr` + if (auto makeTensorPtrOp = dyn_cast(op)) { + return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser); + } else if (auto advanceOp = dyn_cast(op)) { + return rewriteAdvanceOp(builder, advanceOp, eraser); + } else if (isa(op) || isa(op)) { + return rewriteLoadStoreOp(builder, op, eraser); + } else if (op->getDialect()->getNamespace() == "scf" || + op->getDialect()->getNamespace() == "cf") { + if (auto ifOp = dyn_cast(op)) { + return rewriteIfOp(builder, ifOp, eraser); + } + if (!needRewrite(op)) + return op; + + if (auto forOp = dyn_cast(op)) { + return rewriteForOp(builder, forOp, eraser); + } else if (auto yieldOp = dyn_cast(op)) { + return rewriteYieldOp(builder, yieldOp, eraser); + } else { + llvm_unreachable("Currently we only support tensor pointer usages " + "inside a `scf::ForOp` or `scf::IfOp`, others such as " + "`scf::WhileOp`, `cf::BranchOp` or `cf::CondBranchOp` " + "are not supported yet"); + } + } + + // Otherwise return the original one + return op; +} + +void ConvertTensorPointerPass::visitOperation(Operation *op, + std::stack &eraser) { + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + // We need an extra copy because erasing operations may break the + // iterator behavior + SmallVector blockCopy; + for (auto &nestedOp : block) + blockCopy.push_back(&nestedOp); + + // Rewrite and recursively visit + for (auto &nestedOp : blockCopy) { + if (auto newOp = rewriteOp(nestedOp, eraser)) + visitOperation(newOp, eraser); + } + } + } +} + +void ConvertTensorPointerPass::runOnOperation() { + std::stack eraser; + visitOperation(getOperation(), eraser); + rewritedInfo.clear(); + while (!eraser.empty()) { + auto op = eraser.top(); + eraser.pop(); + op->erase(); + } +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonFuncOpFlatten.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonFuncOpFlatten.cpp new file mode 100644 index 000000000..25c17db7d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonFuncOpFlatten.cpp @@ -0,0 +1,110 @@ +/* + * Copyright 2023 - 2024 Enflame.All Rights Reserved. + * + */ +#include +#include +#include + +#include "Conversion/TritonToGCU/TritonToGCUPass.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +#define GEN_PASS_DEF_GCUFLATTENTRITONFUNCPASS +#include "Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; + +namespace { +static uint64_t uniqueId = 1; + +struct FlattenTritonFuncPass + : public mlir::impl::GCUFlattenTritonFuncPassBase { + using Base::Base; + + void runOnOperation() override; + +private: + void flattenFuncOp( + Operation *funcOp, + std::map> &funcName2CallOps); + + gpu::GPUModuleOp moduleOp_; +}; +} // namespace + +void FlattenTritonFuncPass::runOnOperation() { + moduleOp_ = getOperation(); + + std::vector publicFuncOps; + for (auto ttFunc : moduleOp_.getOps()) { + if (ttFunc.isPublic()) { + publicFuncOps.push_back(ttFunc.getOperation()); + } + } + + for (auto &publicFuncOp : publicFuncOps) { + std::map> funcName2CallOps; + funcName2CallOps.clear(); + flattenFuncOp(publicFuncOp, funcName2CallOps); + } +} + +void FlattenTritonFuncPass::flattenFuncOp( + Operation *funcOp, + std::map> &funcName2CallOps) { + auto func = llvm::dyn_cast(funcOp); + auto curFuncName = func.getName(); + + std::vector calleeFuncOps; + + func.walk([&](triton::CallOp call) { + auto calleeFuncName = call.getCallee(); + auto calleeFunc = + moduleOp_.template lookupSymbol(calleeFuncName); + + if (curFuncName == calleeFuncName) { + std::string o; + llvm::raw_string_ostream os(o); + funcOp->print(os); + os.str(); + llvm_unreachable( + (std::string("unsupported recursive call: \n") + o).c_str()); + } + + if (calleeFunc.isExternal()) + return; + + if (funcName2CallOps[calleeFuncName].size() != 0) { + std::string newFuncName = + std::string(calleeFuncName) + "_clone" + std::to_string(uniqueId++); + + auto calleeFuncClone = calleeFunc.clone(); + calleeFuncClone.setName(newFuncName); + call.setCallee(newFuncName); + + moduleOp_.insert(calleeFunc, calleeFuncClone); + + calleeFuncOps.push_back(calleeFuncClone.getOperation()); + funcName2CallOps[newFuncName].push_back(call.getOperation()); + } else { + calleeFuncOps.push_back(calleeFunc.getOperation()); + funcName2CallOps[calleeFuncName].push_back(call.getOperation()); + } + }); + + for (unsigned int i = 0; i < calleeFuncOps.size(); i++) { + flattenFuncOp(calleeFuncOps[i], funcName2CallOps); + } +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonFusion.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonFusion.cpp new file mode 100644 index 000000000..df426bddb --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonFusion.cpp @@ -0,0 +1,463 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "Conversion/TritonToGCU/TritonToGCUPass.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Utils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +#define GEN_PASS_DEF_GCUTRITONFUSIONPASS +#include "Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gcu; + +namespace { + +struct LoadOptimizationPatter : public OpRewritePattern { + explicit LoadOptimizationPatter(MLIRContext *context) + : OpRewritePattern(context) {} + mlir::LogicalResult + matchAndRewrite(triton::LoadOp op, + mlir::PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto addPtrOp = op.getPtr().getDefiningOp(); + if (!addPtrOp) { + return failure(); + } + + auto splatOp = addPtrOp.getPtr().getDefiningOp(); + if (!splatOp) { + return failure(); + } + auto ptr = splatOp.getSrc(); + auto offset = addPtrOp.getOffset(); + if (!getElementTypeOrSelf(offset.getType()).isInteger(32)) { + return failure(); + } + + while (auto addPtrOp = ptr.getDefiningOp()) { + offset = rewriter.create( + loc, offset, + rewriter.create(loc, offset.getType(), + addPtrOp.getOffset())); + ptr = addPtrOp.getPtr(); + } + + auto mask = op.getMask(); + auto other = op.getOther(); + + rewriter.replaceOp(op, rewriter.create( + loc, ptr, offset, mask, other)); + return ::mlir::success(); + } +}; + +struct StoreOptimizationPatter : public OpRewritePattern { + explicit StoreOptimizationPatter(MLIRContext *context) + : OpRewritePattern(context) {} + mlir::LogicalResult + matchAndRewrite(triton::StoreOp op, + mlir::PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto addPtrOp = op.getPtr().getDefiningOp(); + if (!addPtrOp) { + return failure(); + } + + auto splatOp = addPtrOp.getPtr().getDefiningOp(); + if (!splatOp) { + return failure(); + } + + auto offset = addPtrOp.getOffset(); + if (!getElementTypeOrSelf(offset.getType()).isInteger(32)) { + return failure(); + } + auto ptr = splatOp.getSrc(); + while (auto addPtrOp = ptr.getDefiningOp()) { + offset = rewriter.create( + loc, offset, + rewriter.create(loc, offset.getType(), + addPtrOp.getOffset())); + ptr = addPtrOp.getPtr(); + } + + auto mask = op.getMask(); + auto value = op.getValue(); + + rewriter.create(loc, ptr, offset, value, mask); + rewriter.eraseOp(op); + return ::mlir::success(); + } +}; + +struct ConvertClampFOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::ClampFOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + if (stringifyPropagateNan(op.getPropagateNan()) == "all") { + auto newOp = rewriter.create( + loc, op.getMin(), + rewriter.create(loc, op.getX(), op.getMax())); + rewriter.replaceOp(op, newOp); + } else { + auto newOp = rewriter.create( + loc, op.getMin(), + rewriter.create(loc, op.getX(), op.getMax())); + rewriter.replaceOp(op, newOp); + } + return success(); + } +}; + +struct ConvertPreciseDivFOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::PreciseDivFOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto newOp = rewriter.create(loc, op.getX(), op.getY()); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +struct ConvertFpToFpOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::FpToFpOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + mlir::Value newOp = nullptr; + auto srcType = cast(op.getSrc().getType()); + auto resType = cast(op.getResult().getType()); + if (getElementBitWidth(resType) > + getElementBitWidth(srcType)) { // Cast from floating-point + // to wider floating-point, fp8->fp32 + newOp = rewriter.create(loc, op.getType(), op.getSrc()); + } else { // Cast from floating-point to narrower floating-point, fp32->fp8 + newOp = rewriter.create(loc, op.getType(), op.getSrc()); + } + + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +struct ConvertPreciseSqrtOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::PreciseSqrtOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto newOp = rewriter.create(loc, op.getX()); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +} // namespace + +namespace { +struct GCUTritonFusionPass + : public mlir::impl::GCUTritonFusionPassBase { + using Base::Base; + + void runOnOperation() override { + dotZeroBiasFusion(); + // should do bias fusion before constant zero be fusioned + auto module = getOperation(); + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + + patterns.add(ctx); + if (failed(applyPatternsGreedily(module, std::move(patterns)))) + signalPassFailure(); + + auto hasNoSideEffects = [](Operation *op) { + return !isa(op); + }; + + DenseSet eraseOpSet; + for (Operation *func : module.getOps()) { + func->walk([&](Operation *op) { + if (isElementwiseOp(op) && canVectorize(op)) { + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + auto def = op->getOperand(i).getDefiningOp(); + if (def && ((isa(def) && hasNoSideEffects(op)) || + isa(def))) { + OpBuilder builder(op); + eraseOpSet.insert(def); + op->replaceUsesOfWith(op->getOperand(i), + builder.clone(*def)->getResult(0)); + } + } + } + }); + } + + for (auto op : eraseOpSet) { + if (op->getUses().empty()) { + op->erase(); + } + } + + for (auto func : module.getOps()) { + for (auto ®ion : func->getRegions()) { + runFuse(region); + } + } + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + +private: + bool isElementwiseOp(Operation *op); + bool canVectorize(Operation *op); + void runFuse(Region ®ion); + void fuseOps(ArrayRef ops); + void dotZeroBiasFusion(); +}; +} // namespace + +void GCUTritonFusionPass::runFuse(Region ®ion) { + SmallVector> fusionOps; + fusionOps.emplace_back(); + + for (auto &block : region) { + if (!fusionOps.back().empty()) { + fusionOps.emplace_back(); + } + for (auto &op : llvm::make_early_inc_range(block.getOperations())) { + for (auto ®ion : op.getRegions()) { + runFuse(region); + } + auto &fusionOp = fusionOps.back(); + if (isElementwiseOp(&op) && canVectorize(&op)) { + if (fusionOp.empty()) { + fusionOp.push_back(&op); + } else { + auto curType = + isa(op) + ? cast(op).getValue().getType() + : cast(op.getResultTypes().front()); + auto preType = isa(fusionOp.back()) + ? cast(fusionOp.back()) + .getValue() + .getType() + : cast( + fusionOp.back()->getResultTypes().front()); + if (curType.getShape() == preType.getShape() && + getElemsPerThread(curType) == getElemsPerThread(preType)) { + fusionOp.push_back(&op); + } else { + fusionOps.emplace_back().push_back(&op); + } + } + } else if (!fusionOp.empty()) { + if (op.hasTrait() && + llvm::all_of(fusionOp, [&](auto innerOp) { + return llvm::none_of(innerOp->getUsers(), + [&](auto user) { return user == &op; }); + })) { + op.moveBefore(fusionOp.front()); + } else { + fusionOps.emplace_back(); + } + } + } + } + + for (auto fusionOp : fusionOps) { + if (!fusionOp.empty()) + fuseOps(fusionOp); + } +} + +bool GCUTritonFusionPass::isElementwiseOp(Operation *op) { + if (!op) { + return false; + } + if (!llvm::all_of(op->getResultTypes(), llvm::IsaPred)) { + return false; + } + if (auto constantOp = dyn_cast(op)) { + auto valueAttr = dyn_cast(constantOp.getValue()); + return valueAttr && valueAttr.isSplat(); + } + if (isa(op)) { + return true; + } + + return OpTrait::hasElementwiseMappableTraits(op); +} + +bool GCUTritonFusionPass::canVectorize(Operation *op) { + if (!llvm::all_of(op->getResultTypes(), llvm::IsaPred)) { + return false; + } + return llvm::all_of(op->getResultTypes(), [](auto type) { + auto elementTy = cast(type).getElementType(); + return elementTy.isBF16() || elementTy.isF16() || elementTy.isF32() || + elementTy.isInteger(1) || elementTy.isInteger(8) || + elementTy.isInteger(16) || elementTy.isInteger(32); + }); +} + +void GCUTritonFusionPass::fuseOps(ArrayRef ops) { + DenseSet fusionOp(ops.begin(), ops.end()); + + SetVector fusionOperands; + SetVector fusionResults; + for (auto op : ops) { + for (auto v : op->getOperands()) { + if (!fusionOp.count(v.getDefiningOp())) { + fusionOperands.insert(v); + } + } + for (auto result : op->getResults()) { + for (auto user : result.getUsers()) { + if (!fusionOp.count(user)) { + fusionResults.insert(result); + } + } + } + } + + OpBuilder builder(ops.front()); + auto loc = ops.front()->getLoc(); + auto operands = fusionOperands.takeVector(); + auto results = fusionResults.takeVector(); + auto resultTypes = llvm::to_vector( + llvm::map_range(results, [](auto result) { return result.getType(); })); + + auto fusedOp = builder.create( + loc, resultTypes, operands); + auto &entryBlock = fusedOp.getRegion().emplaceBlock(); + { + IRMapping map; + for (auto operand : operands) { + auto arg = entryBlock.addArgument(operand.getType(), loc); + map.map(operand, arg); + } + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&entryBlock); + for (auto op : ops) { + auto newOp = builder.clone(*op, map); + for (auto [result, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + map.map(result, newResult); + } + } + auto fusedResults = llvm::to_vector(llvm::map_range( + results, [&map](auto result) { return map.lookup(result); })); + builder.create(loc, fusedResults); + } + + for (auto [result, fusedResult] : llvm::zip(results, fusedOp.getResults())) { + result.replaceAllUsesWith(fusedResult); + } + + for (auto op : ops) { + op->dropAllUses(); + op->erase(); + } +} + +void GCUTritonFusionPass::dotZeroBiasFusion() { + auto trionModule = getOperation(); + llvm::SmallVector dotList; + trionModule.walk([&](DotOp dot) { dotList.push_back(dot); }); + for (auto &dot : dotList) { + auto retType = dyn_cast(dot.getType()); + int rank = retType.getShape().size(); + if (rank > 2) { + // need test case for 3D dot + continue; + } + // fusion dot and result cast + OpBuilder rewriter(dot); + auto dotUsers = dot.getResult().getUsers(); + auto userNumber = std::distance(dotUsers.begin(), dotUsers.end()); + bool isOnlyMatmul = false; + auto CTensor = dot.getC().getDefiningOp(); + if (CTensor && isa(CTensor)) { + auto constantC = llvm::cast(CTensor); + if (auto splatAttr = + llvm::dyn_cast(constantC.getValue())) { + mlir::Type elementType = splatAttr.getElementType(); + if (elementType.isInteger()) { + if (splatAttr.getSplatValue().isZero()) { + isOnlyMatmul = true; + } + } else if (elementType.isBF16() || elementType.isF16() || + elementType.isTF32() || elementType.isF32()) { + if (splatAttr.getSplatValue().isZero()) { + isOnlyMatmul = true; + } + } + } + } + if (userNumber == 1 && isOnlyMatmul && + llvm::isa(*dotUsers.begin())) { + auto castDotResult = llvm::cast(*dotUsers.begin()); + auto newDot = rewriter.create( + dot.getLoc(), castDotResult.getType(), dot.getA(), dot.getB()); + castDotResult.getResult().replaceAllUsesWith(newDot.getResult()); + castDotResult.erase(); + dot.erase(); + continue; + } + if (isOnlyMatmul) { + auto newDot = rewriter.create( + dot.getLoc(), dot.getType(), dot.getA(), dot.getB()); + dot.getResult().replaceAllUsesWith(newDot.getResult()); + dot.erase(); + continue; + } + // maybe we can dot other fusion about dot in future as trion gpu' combine + } +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCULayoutOptimize.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCULayoutOptimize.cpp new file mode 100644 index 000000000..aad56d57f --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCULayoutOptimize.cpp @@ -0,0 +1,86 @@ +/* + * Copyright 2020 - 2022 Enflame.All Rights Reserved. + * + */ + +#include "Conversion/TritonToGCU/TritonToGCUPass.h" + +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" +#include "Utils.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +namespace mlir { +#define GEN_PASS_DEF_TRITONGCULAYOUTOPTIMIZEPASS +#include "Conversion/Passes.h.inc" +} // namespace mlir + +#define DEBUG_TYPE "triton-gcu-data-layout-optimize" +namespace { +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +struct TritonGCULayoutOptimizePass + : public mlir::impl::TritonGCULayoutOptimizePassBase< + TritonGCULayoutOptimizePass> { + using Base::Base; + + void runOnOperation() override; + void RefineGcuLoadStoreLayout(); + void reWriteGcuStoreLayout(triton::gcu::StoreOp store); + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } +}; + +void TritonGCULayoutOptimizePass::reWriteGcuStoreLayout( + triton::gcu::StoreOp store) { + auto src = store.getValue().getDefiningOp(); + if (auto convetLayout = dyn_cast(src)) { + auto users = src->getUsers(); + auto userNumber = std::distance(users.begin(), users.end()); + // only one user delete and fusion convert layout to store + if (userNumber == 1) { + auto trueSrc = convetLayout.getSrc(); + convetLayout.getResult().replaceAllUsesWith(trueSrc); + convetLayout.erase(); + } + } +} + +void TritonGCULayoutOptimizePass::RefineGcuLoadStoreLayout() { + auto trionModule = getOperation(); + llvm::SmallVector storeList; + trionModule.walk( + [&](triton::gcu::StoreOp store) { storeList.push_back(store); }); + for (auto &store : storeList) { + reWriteGcuStoreLayout(store); + } +} +} // namespace +using namespace mlir; +void TritonGCULayoutOptimizePass::runOnOperation() { + RefineGcuLoadStoreLayout(); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/MatmulLoopPipeline.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/MatmulLoopPipeline.cpp new file mode 100644 index 000000000..86cecc65d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/MatmulLoopPipeline.cpp @@ -0,0 +1,916 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" +#include "PipelineExpander.h" +#include "PipeliningUtility.h" +#include "Schedule.h" +#include "Utils.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" +#include +#include +#include + +#define DEBUG_TYPE "triton-matmul-loop-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#define int_attr(num) builder.getI64IntegerAttr(num) + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +namespace { + +struct LoadInfo { + // Layout of the data in the shared memory. + ttg::SwizzledSharedEncodingAttr sharedEncoding = nullptr; + int distToUse = 0; + // always preload + bool usedByDot = true; +}; +class CoarseSchedule { +public: + class ClusterList { + std::list orderClusters; + + public: + using iterator = decltype(orderClusters)::iterator; + ClusterList() = default; + iterator begin() { return orderClusters.begin(); } + iterator end() { return orderClusters.end(); } + size_t size() { return orderClusters.size(); } + iterator newAtBack() { + orderClusters.push_back(orderClusters.size()); + return std::prev(orderClusters.end()); + } + iterator newAtFront() { + orderClusters.push_front(-1); + for (auto &clusterId : orderClusters) { + clusterId++; + } + return orderClusters.begin(); + } + iterator newBefore(iterator cluster) { + auto ret = orderClusters.insert(cluster, *cluster); + for (auto &clusterId : llvm::make_range(cluster, orderClusters.end())) { + clusterId++; + } + return ret; + } + }; + + explicit CoarseSchedule(int numStages) : numStages(numStages) { + LDBG("CoarseSchedule " + << "#######################################################\n"); + } + int numStages; + ClusterList clusters; + using Cluster = decltype(clusters)::iterator; + + DenseMap> opToStageAndCluster; + + void insert(Operation *op, int stage, Cluster cluster) { + opToStageAndCluster[op] = {stage, cluster}; + } + + bool insertIfAbsent(Operation *op, int stage, Cluster cluster) { + if (opToStageAndCluster.count(op)) + return false; + insert(op, stage, cluster); + return true; + } + + void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, + bool includeArg) { + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::DenseSet seen; + while (auto arg = dyn_cast(v)) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + if (insertIfAbsent(defOp, stage, cluster)) { + insertDepsOfOp(defOp, stage, cluster, includeArg); + } + } + } + } + + void erase(Operation *op) { opToStageAndCluster.erase(op); } + + int count(Operation *op) { return opToStageAndCluster.count(op); } + + std::pair operator[](Operation *op) { + return opToStageAndCluster[op]; + } + + SmallVector> + getOpsInOrder(scf::ForOp forOp) { + LLVM_DEBUG({ llvm::dbgs() << "getOpsInOrder enter\n"; }); + + SmallVector>, 8> + orderClusters(clusters.size()); + for (auto &op : forOp.getBody()->without_terminator()) { + if (opToStageAndCluster.count(&op) == 0) { + continue; + } + assert(opToStageAndCluster[&op].first < numStages && + "Op with invalid stage!"); + int clusterId = *opToStageAndCluster[&op].second; + assert(clusterId == std::distance(clusters.begin(), + opToStageAndCluster[&op].second) && + "Cluster ID mismatch!"); + LLVM_DEBUG({ + llvm::dbgs() << "orderClusters push\n"; + op.dump(); + }); + orderClusters[clusterId].push_back( + make_tuple(&op, opToStageAndCluster[&op].first, + opToStageAndCluster[&op].second)); + } + SmallVector> opsInOrder; + for (size_t i = 0; i < orderClusters.size(); i++) { + for (auto [op, stage, cluster] : orderClusters[i]) { + opsInOrder.push_back({op, stage, cluster}); + } + } + + return opsInOrder; + } + + std::vector> + createFinalSchedule(scf::ForOp forOp) { + SmallVector> opsInOrder = + getOpsInOrder(forOp); + std::vector> schedule; + for (auto [op, stage, cluster] : opsInOrder) { + (void)cluster; + LDBG("Adding op to FinalSchedule at stage" << stage << " cluster " + << *cluster << ":" << *op); + schedule.push_back({op, stage}); + } + return schedule; + } + + void dump() { + for (int i = 0; i < numStages; i++) { + LDBG("- Ops in stage " << i); + for (auto &[op, stageAndCluster] : opToStageAndCluster) { + if (i == stageAndCluster.first) { + llvm::outs() << " cluster: " << *stageAndCluster.second << " "; + op->dump(); + } + } + } + } +}; +} // namespace + +// Replace the ForOp's yield with a new one with the given operands appended. +static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { + // Fix up the yield op. + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands()); + operands.append(newOperands.begin(), newOperands.end()); + + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + +static void createAsyncCopy(scf::ForOp &forOp, triton::gcu::LoadOp loadOp, + Value alloc, Value insertIdx, Value extractIdx, + CoarseSchedule &schedule, + CoarseSchedule::Cluster prefetchCluster, + llvm::MapVector &loadToInfo, + int numStages) { + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + + ttg::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + copyOffsets[0] = insertIdx; + ttg::MemDescType subviewTy = ttg::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), allocTy.getMemorySpace(), /*mutableMemory=*/true); + auto view = + builder.create(loc, subviewTy, alloc, copyOffsets); + Operation *copy = builder.create( + loadOp.getLoc(), loadOp.getPtr(), loadOp.getShape(), loadOp.getStrides(), + loadOp.getOffsets(), view, loadOp.getDefaultValue(), + loadOp.getOrderHint()); + Operation *wait = + builder.create(loc, copy->getResult(0)); + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + + auto loadUsers = loadOp->getResult(0).getUsers(); + auto userNumber = std::distance(loadUsers.begin(), loadUsers.end()); + if (userNumber == 0) { + loadOp.dump(); + llvm::report_fatal_error( + "[Error] there more than one loadUser of load that can't do pinpong\n"); + } + auto loadUser = *loadUsers.begin(); + if (userNumber == 1 && + llvm::isa_and_nonnull(loadUser)) { + auto converLayout = dyn_cast(loadUser); + auto localLoad = builder.create( + loc, converLayout.getType(), viewLoad, wait->getResult(0)); + auto result = localLoad->getResults(); + converLayout->getResult(0).replaceAllUsesWith(result[0]); + converLayout.erase(); + loadOp.erase(); + + } else { + auto localLoad = builder.create( + loc, loadOp.getType(), viewLoad, wait->getResult(0)); + loadOp->getResult(0).replaceAllUsesWith(localLoad->getResult(0)); + loadOp.erase(); + } + // Prefetch load if is not MMAV3 and is used by the dot. + schedule.insert(wait, numStages - 2, prefetchCluster); + schedule.insert(viewLoad, numStages - 2, prefetchCluster); +} + +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return true and get the shared encoding that +// needs to be used to be compatible with users' layouts. +static std::optional +getSharedEncIfAllUsersAreSameEnc(Value val) { + ttg::SwizzledSharedEncodingAttr attr; + for (Operation *user : val.getUsers()) { + ttg::SwizzledSharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (!isa(user)) + return std::nullopt; + int64_t bufferSize = 1; + if (auto tType = dyn_cast(val.getType())) { + bufferSize = triton::gcu::getBpe(tType.getElementType()); + auto shapes = tType.getShape(); + for (int32_t i = 0; i < tType.getRank(); i++) { + bufferSize = bufferSize * shapes[i]; + } + } + auto converLayout = dyn_cast(user); + auto srcNumElems = + triton::gcu::getElemsPerThread(converLayout.getSrc().getType()); + // if no warp share and buffer is small than 16k skip pinpong + auto dstNumElems = triton::gcu::getElemsPerThread(converLayout.getType()); + // only for 300 now + if (srcNumElems == dstNumElems && bufferSize < 32 * 1024) { + return std::nullopt; + } + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + SmallVector sharedOrder; + unsigned rank = order.size(); + if (rank == 3) { + // Move the batch dimension (dim #0) to be the last so that it will be + // the slowest varying dimension. + for (unsigned i = 0; i < rank; ++i) + if (order[i] != 0) + sharedOrder.emplace_back(order[i]); + sharedOrder.emplace_back(0); + } else { + sharedOrder = order; + } + if (auto dotOpEnc = dyn_cast( + cast(user->getResult(0).getType()) + .getEncoding())) { + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + val.getContext(), dotOpEnc, srcTy.getShape(), sharedOrder, CTALayout, + bitWidth, /*needTrans=*/false); + } else { + tempAttr = ttg::SwizzledSharedEncodingAttr::get(val.getContext(), 1, 1, 1, + sharedOrder, CTALayout); + } + + // Check that the shared encodings needed by the users are compatible. + if (!tempAttr || (attr != nullptr && attr != tempAttr)) + return std::nullopt; + attr = tempAttr; + } + return attr; +} + +// Create a map from load ops to their indirection level and the +// final use of the load op (another load op, or a dot op). +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +static llvm::SmallVector> +loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { + llvm::SmallVector> + loadOpToIndLevelAndUse; + DenseSet seen; + + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + if (!seen.insert(op).second) + return; + if (isa(op)) { + // TODO(xingxing): What if there are multiple uses at different + // distances? + loadOpToIndLevelAndUse.push_back(std::make_tuple(op, distance, use)); + use = op; + distance++; + } + for (Value operand : op->getOperands()) { + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); + } + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + continue; + seen.clear(); + dfs(&op, 0, &op); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (forOp->hasAttr("tt.num_stages")) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); + } + } + + return loadOpToIndLevelAndUse; +} + +static llvm::MapVector +assignMemoryLayouts(llvm::SmallVector> + &loadOpToIndLevelAndUse) { + llvm::MapVector loadToInfo; + + for (auto &[op, dist, use] : loadOpToIndLevelAndUse) { + (void)dist; + if (loadToInfo.count(op)) + // TODO(pawel): err, we'd need to verify that the distance is the same + continue; + LoadInfo loadInfo; + // for dot + if (isa(use) && + isa(op)) { + loadInfo.usedByDot = true; + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreSameEnc(op->getResult(0)).value_or(nullptr); + } else if (auto loadOp = dyn_cast(use)) { + // The use of this loadOp is another loadOp. If the use is not in the + // loadsToPipeline already, it means that the use is not valid for + // pipelining for some reason. We should skip this loadOp, too. Note + // that we have an assumption that distAndUse.second (i.e. the use of + // this loadOp) has already be processed in a previous loop iteration. + // This assumption is held by how loadOpsToIndirectionLevelAndUse + // recursively collects loadOpToIndLevelAndUse using DFS. + if (loadToInfo.count(loadOp) == 0) { + continue; + } + } else { + loadInfo.usedByDot = false; + if (!isa(op)) { + continue; + } + int64_t bufferSize = 1; + if (auto tType = dyn_cast(op->getResult(0).getType())) { + bufferSize = triton::gcu::getBpe(tType.getElementType()); + auto shapes = tType.getShape(); + for (int32_t i = 0; i < tType.getRank(); i++) { + bufferSize = bufferSize * shapes[i]; + } + } + if (bufferSize < 1024) { + continue; + } + auto gcuLoad = dyn_cast(op); + auto srcTy = dyn_cast(gcuLoad.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy); + loadInfo.sharedEncoding = ttg::SwizzledSharedEncodingAttr::get( + srcTy.getContext(), 1, 1, 1, order, CTALayout); + } + // If that still didn't work, bail on pipelining this load. + if (!loadInfo.sharedEncoding) { + continue; + } + loadToInfo[op] = loadInfo; + } + + return loadToInfo; +} + +static llvm::MapVector +scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + llvm::SmallVector> + loadOpToIndLevelAndUse = loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return {}; + + // Check which loads are good for pipelining, and assign them + // memory layouts. + llvm::MapVector loadToInfo = + assignMemoryLayouts(loadOpToIndLevelAndUse); + + if (loadToInfo.empty()) { + LLVM_DEBUG({ LDBG("loadToInfo.empty \n"); }); + return {}; + } + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) { + (void)use; + if (loadToInfo.count(loadOp) == 0) + continue; + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + } + unsigned stagesBetweenLoads = + ceil(numStages - 2, maxIndirectionLevel + 1); + LLVM_DEBUG({ + LDBG("scheduleLoads::stagesBetweenLoads " << stagesBetweenLoads << "\n"); + }); + CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); + // Put the root uses of the loads in the last stage. + for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + (void)dist; + if (loadToInfo.count(loadOp) == 0) + continue; + // Non-LoadOp(s) are the root uses of all LoadOp(s) and should be + // always present in the opInfo + if (!isa(use)) { + schedule.insert(use, numStages - 1, rootUsersCluster); + rootUsers.insert(use); + } + } + + SmallVector loadsClusters; + for (int i = 0; i < maxIndirectionLevel + 1; i++) { + loadsClusters.push_back(schedule.clusters.newAtBack()); + } + // Assign stages to the loads. + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + (void)_; + if (loadToInfo.count(loadOp) == 0) + continue; + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + schedule.insert(loadOp, stage, loadsClusters[indLevel]); + } + + // Distance from the load to the use. + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + (void)_; + if (loadToInfo.count(loadOp) == 0) + continue; + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } + + return loadToInfo; +} + +// Schedule the prologue and epilogue `if` ops in the loop, pushing them as +// close to the loop boundaries as possible. Return the cluster after the +// prologue (or the beginning of the loop if there is no prologue). +static CoarseSchedule::Cluster +schedulePrologueAndEpilogue(scf::ForOp forOp, CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + // Look for the IfOp that is in the backward slice any of the currently + // scheduled ops and put it at the beginning of the loop. + DenseMap ifsToStage; + // Go stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, _] : schedule.getOpsInOrder(forOp)) { + (void)_; + if (stage_ != stage) + continue; + SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + getBackwardSlice(reinterpret_cast(op), &backwardSlice, opt); + + for (auto op : backwardSlice) { + if (auto ifOp = dyn_cast(op)) { + ifsToStage.insert({ifOp, stage}); + } + } + } + } + CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + for (auto [ifOp, stage] : ifsToStage) { + schedule.insert(ifOp, stage, prologueCluster); + } + + // Look for the IfOp that is in the forward slice of the root users and put it + // at the end of the loop. + CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + for (auto rootUser : rootUsers) { + SetVector forwardSlice; + getForwardSlice(rootUser, &forwardSlice); + + int stage = schedule[rootUser].first; + for (auto op : forwardSlice) { + scf::IfOp ifOp = dyn_cast(op); + if (ifOp == nullptr) { + // check if the op is in the body of an if op that's part of the loop + auto parentOp = op->getParentOp(); + if (parentOp != nullptr && + parentOp->getParentOp() == forOp.getOperation()) { + ifOp = dyn_cast(parentOp); + } + } + if (ifOp) { + schedule.insertIfAbsent(ifOp, stage, + epilogueCluster); // after prefetch extracts + } + } + } + return afterPrologue; +} + +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +static void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule, + int numStages) { + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, false); + } + } +} + +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +static void scheduleDistanceOneDependencies(scf::ForOp forOp, + CoarseSchedule &schedule, + int numStages) { + auto getNestedOperands = [](Operation *op) -> SmallVector { + SmallVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) + operands.push_back(operand); + } + }); + return operands; + }; + + // Mapping from the cluster to the cluster before it. + DenseMap dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + if (auto arg = dyn_cast(operand)) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op.getBlock()) { + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (defOp && schedule.count(defOp) == 0) { + if (isa(defOp)) { + // Exception: Schedule loads with a distance of 1 together + // with the current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], + true); + } + } + } + } + } + } +} + +static void scheduleRemainingToLastStage(scf::ForOp forOp, + CoarseSchedule &schedule, + CoarseSchedule::Cluster afterPrologue, + int numStages) { + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) { + opToCluster[&op] = afterPrologue; + } + } + SmallVector queue; + for (auto [op, stage, _] : schedule.getOpsInOrder(forOp)) { + (void)_; + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == numStages - 1) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + CoarseSchedule::Cluster userCluster = opToCluster[user]; + CoarseSchedule::Cluster opCluster = schedule[op].second; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, numStages - 1, cluster); + } +} + +// Create an allocation that can hold distance number of loadOp shapes. +static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, + ttg::SwizzledSharedEncodingAttr sharedEnc, + unsigned distance) { + OpBuilder builder(forOp); + auto ty = cast(loadOp->getResultTypes()[0]); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + auto memorySpace = mlir::IntegerAttr::get( + mlir::IntegerType::get(loadOp->getContext(), 64), 2); + Type memdescType = mlir::triton::gpu::MemDescType::get( + bufferShape, ty.getElementType(), sharedEnc, memorySpace, + /*mutableMemory*/ true); + Value alloc = builder.create( + loadOp->getLoc(), memdescType, Value()); + return alloc; +} + +struct AsyncLoad { + AsyncLoad(Operation *loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {} + Operation *loadOp; + Value alloc; + Value barrier; + Operation *waitOp = nullptr; + bool isTMALoad = false; +}; + +// Convert load ops into their asyn version and apply multi-buffering based on +// the required number of buffers. +static SmallVector +createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule, + llvm::MapVector &loadToInfo, + SmallVector &barriers, int numStages) { + // Calculate the number of buffers needed for each load. + // int numBuffers = 0; + int numBuffers = + llvm::max_element(llvm::make_second_range(loadToInfo), [](auto &lhs, + auto &rhs) { + return lhs.distToUse < rhs.distToUse; + })->distToUse; + LLVM_DEBUG({ + llvm::dbgs() << "createshareAlloc:numBuffers=" << numBuffers << "\n"; + }); + SmallVector asyncLoads; + SmallVector allocs; + bool hasTMALoad = false; + for (auto &[loadOp, info] : loadToInfo) { + assert(info.sharedEncoding && "LoadOp shared encoding not defined."); + Value alloc = createAlloc(forOp, loadOp, info.sharedEncoding, numBuffers); + assert(alloc && "Failed to create alloc for the async load."); + allocs.push_back(alloc); + asyncLoads.emplace_back(loadOp, alloc); + } + + IRRewriter builder(forOp.getContext()); + builder.setInsertionPoint(forOp); + + Location loc = forOp.getLoc(); + // Create two new counters to index into the allocs. + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value insertIdx = minusOne; + Value extractIdx = minusOne; + Value phase = Value(); + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + SmallVector newOperands; + newOperands.push_back(insertIdx); + newOperands.push_back(extractIdx); + if (hasTMALoad) { + LLVM_DEBUG({ llvm::dbgs() << "createAsyncOps:has tma load\n"; }); + phase = builder.create(loc, 0, 32); + newOperands.push_back(phase); + } + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + forOp.erase(); + forOp = newForOp; + insertIdx = newForOp.getBody()->getArgument(newOperandIndex); + extractIdx = newForOp.getBody()->getArgument(newOperandIndex + 1); + if (phase) { + phase = newForOp.getBody()->getArgument(newOperandIndex + 2); + } + + // Create two counters for the insert and extract indices to avoid creating + // long liverange. + builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + insertIdx = builder.create(loc, insertIdx, one); + Value cndIns = builder.create(loc, arith::CmpIPredicate::slt, + insertIdx, numBuffersVal); + insertIdx = builder.create(loc, cndIns, insertIdx, zero); + + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + if (phase) { + Value nextPhase = builder.create(loc, phase, one); + phase = builder.create(loc, cndExt, phase, nextPhase); + } + // createTMABarrierAndWait(forOp, asyncLoads, insertIdx, extractIdx, phase, + // numBuffers, schedule, barriers, loadToInfo); + + // Create a cluster for the prefetches. It may end up being empty, but this + // is OK. + CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); + + for (AsyncLoad &asyncLoad : asyncLoads) { + if (auto loadOp = dyn_cast(asyncLoad.loadOp)) { + createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, + schedule, prefetchCluster, loadToInfo, numStages); + } + } + SmallVector newYieldOperands = {insertIdx, extractIdx}; + if (phase) + newYieldOperands.push_back(phase); + // Patch the yield with the updated counters. + appendToYield(forOp, newYieldOperands); + + return allocs; +} + +bool mlir::triton::gcu::preProcessLoopAndGetSchedule( + scf::ForOp &forOp, int numStages, + mlir::triton::gcu::PipeliningOption &options) { + // Schedule the loads and root ops (dot ops) in the loop. This will give us + // a scaffold for the final schedule. + DenseSet rootUsers; + CoarseSchedule coarseSchedule(numStages); + llvm::MapVector loadToInfo = + scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); + if (loadToInfo.empty()) + return false; + LLVM_DEBUG({ + LDBG("Coarse schedule loads only:"); + coarseSchedule.dump(); + }); + + SmallVector barriers; // only for TMA + // Convert the loads into async loads and create the allocs. + SmallVector allocs = + createAsyncOps(forOp, coarseSchedule, loadToInfo, barriers, numStages); + + LLVM_DEBUG({ + LDBG("Coarse schedule with async loads:"); + coarseSchedule.dump(); + }); + + CoarseSchedule::Cluster afterPrologue = + schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with prologue and epilogue:"); + coarseSchedule.dump(); + }); + + scheduleDependencies(forOp, coarseSchedule, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with dependencies:"); + coarseSchedule.dump(); + }); + + scheduleDistanceOneDependencies(forOp, coarseSchedule, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with dist 1:"); + coarseSchedule.dump(); + }); + + scheduleRemainingToLastStage(forOp, coarseSchedule, afterPrologue, numStages); + LLVM_DEBUG({ + LDBG("Final coarse schedule:"); + coarseSchedule.dump(); + }); + + // Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> schedule = + coarseSchedule.createFinalSchedule(forOp); + + // Fill out the pipeline options. + options.getScheduleFn = + [schedule](scf::ForOp forOp, + std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = triton::gcu::predicateOp; + options.supportDynamicLoops = true; + options.annotateFn = + [](Operation *op, mlir::triton::gcu::PipeliningOption::PipelinerPart part, + unsigned iteration) {}; + // Insert a wait 0 after the loop + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + // we need skip last async_load_global_to_share + // builder.create(forOp.getLoc(), ValueRange({}), 0); + // Invalidate any mbarrier create + // no need for gcu + // invalidateBarriers(builder, barriers); + // Explicitly deallocate allocated tensors after the wait op + for (auto alloc : allocs) + builder.create(forOp.getLoc(), alloc); + return true; +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/OuterLoopPipeline.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/OuterLoopPipeline.cpp new file mode 100644 index 000000000..0ec502ecb --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/OuterLoopPipeline.cpp @@ -0,0 +1,151 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" +#include "PipelineExpander.h" +#include "PipeliningUtility.h" +#include "Schedule.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +// create the schedule for a matmul loop. This is ad hoc based on how we know +// matmul loops should be pipelined and is not a generic scheduler. +static std::vector> +createSchedule(scf::ForOp forOp, int numStages) { + SmallVector insertOps; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + insertOps.emplace_back(&op); + } + DenseSet insertAndDeps; + for (Operation *op : insertOps) { + tt::gcu::addDep(op, insertAndDeps, true); + } + + DenseSet epilogue; + bool foundLoop = false; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (insertAndDeps.count(&op)) + continue; + if (isa(op)) + foundLoop = true; + if (isa(op)) + continue; + if (foundLoop) + epilogue.insert(&op); + } + + std::vector> schedule; + // Schedule stage 1 first. + tt::gcu::addOps(forOp, 1, schedule, [&](Operation *op) { + return insertAndDeps.count(op) == 0 && epilogue.count(op) == 0; + }); + + // Then Schedule stage 0. + tt::gcu::addOps(forOp, 0, schedule, + [&](Operation *op) { return insertAndDeps.count(op); }); + + // Then schedule the epilogue in stage 1 + tt::gcu::addOps(forOp, 1, schedule, + [&](Operation *op) { return epilogue.count(op); }); + return schedule; +} + +// pre-process the loop by hosting allocations/deallocation out of the +// loop. +static void hoistAllocAndConst(scf::ForOp forOp) { + SmallVector toHoist; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (auto allocOp = dyn_cast(op)) { + // We hoist the allocOp only if it is created by the inner loop + // pipelining. + if (!allocOp.getSrc()) + toHoist.push_back(&op); + } else if (isa(op)) { + toHoist.push_back(&op); + } + } + for (Operation *op : toHoist) { + op->moveBefore(forOp); + auto allocOp = dyn_cast(op); + if (!allocOp) + continue; + for (Operation *user : allocOp->getUsers()) { + if (auto dealloc = dyn_cast(user)) { + dealloc->moveAfter(forOp); + } + } + } +} + +static bool preCondition(scf::ForOp forOp) { + // Check if there is a dependency from the loop to the async copy op. In this + // case we cannot pipeline the async copy. + SmallVector insertOps; + int numForOps = 0; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + insertOps.emplace_back(&op); + if (isa(op)) + numForOps++; + } + if (insertOps.empty() || numForOps != 1) + return false; + DenseSet insertAndDeps; + for (Operation *op : insertOps) { + tt::gcu::addDep(op, insertAndDeps, true); + } + // If there is a recurrence containing both the async and the for op we cannot + // pipeline. + for (Operation *op : insertAndDeps) { + if (isa(op)) + return false; + } + return true; +} + +bool mlir::triton::gcu::getOuterLoopSchedule( + scf::ForOp &forOp, int numStages, + mlir::triton::gcu::PipeliningOption &options) { + assert(numStages == 2 && "only support 2 stage pipelining for now"); + // 1. Check precondition, we cannot have a recurrence involving async cp ops + if (!preCondition(forOp)) + return false; + + // 2. pre-process the loop by hosting allocations. + hoistAllocAndConst(forOp); + + // 3. Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> schedule = + createSchedule(forOp, numStages); + + // 4. Fill out the pipeline options. + options.getScheduleFn = + [schedule](scf::ForOp forOp, + std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = mlir::triton::gcu::predicateOp; + options.supportDynamicLoops = true; + return true; +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipelineExpander.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipelineExpander.cpp new file mode 100644 index 000000000..46ea4b701 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipelineExpander.cpp @@ -0,0 +1,867 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "PipelineExpander.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include +#include +#include + +#define DEBUG_TYPE "triton-loop-pipelining" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::scf; +using namespace mlir::triton; + +namespace { + +inline int64_t ceilDiv(int64_t lhs, int64_t rhs) { + assert(rhs >= 1); + // C/C++'s integer division rounds towards 0. + return lhs % rhs > 0 ? lhs / rhs + 1 : lhs / rhs; +} + +/// Helper to keep internal information during pipelining transformation. +struct LoopPipelinerInternal { + /// Coarse liverange information for ops used across stages. + struct LiverangeInfo { + unsigned lastUseStage = 0; + unsigned defStage = 0; + }; + +protected: + ForOp forOp; + unsigned maxStage = 0; + DenseMap stages; + std::vector opOrder; + Value ub; + Value lb; + Value step; + bool dynamicLoop; + triton::gcu::PipeliningOption::AnnotationlFnType annotateFn = nullptr; + bool peelEpilogue; + triton::gcu::PipeliningOption::PredicateOpFnType predicateFn = nullptr; + + // When peeling the kernel we generate several version of each value for + // different stage of the prologue. This map tracks the mapping between + // original Values in the loop and the different versions + // peeled from the loop. + DenseMap> valueMapping; + + /// Assign a value to `valueMapping`, this means `val` represents the version + /// `idx` of `key` in the epilogue. + void setValueMapping(Value key, Value el, int64_t idx); + + /// Return the defining op of the given value, if the Value is an argument of + /// the loop return the associated defining op in the loop and its distance to + /// the Value. + std::pair getDefiningOpAndDistance(Value value); + + /// Return true if the schedule is possible and return false otherwise. A + /// schedule is correct if all definitions are scheduled before uses. + bool verifySchedule(); + +public: + /// Initialize the information for the given `op`, return true if it + /// satisfies the pre-condition to apply pipelining. + bool initializeLoopInfo(ForOp op, + const triton::gcu::PipeliningOption &options); + /// Emits the prologue, this creates `maxStage - 1` part which will contain + /// operations from stages [0; i], where i is the part index. + void emitPrologue(RewriterBase &rewriter); + /// Gather liverange information for Values that are used in a different stage + /// than its definition. + llvm::MapVector analyzeCrossStageValues(); + scf::ForOp createKernelLoop( + const llvm::MapVector &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap); + /// Emits the pipelined kernel. This clones loop operations following user + /// order and remaps operands defined in a different stage as their use. + LogicalResult createKernel( + scf::ForOp newForOp, + const llvm::MapVector &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter); + /// Emits the epilogue, this creates `maxStage - 1` part which will contain + /// operations from stages [i; maxStage], where i is the part index. + void emitEpilogue(RewriterBase &rewriter, + llvm::SmallVector &returnValues); +}; + +bool LoopPipelinerInternal::initializeLoopInfo( + ForOp op, const triton::gcu::PipeliningOption &options) { + LDBG("Start gcu initializeLoopInfo"); + forOp = op; + ub = forOp.getUpperBound(); + lb = forOp.getLowerBound(); + step = forOp.getStep(); + + dynamicLoop = true; + auto upperBoundCst = ub.getDefiningOp(); + auto lowerBoundCst = lb.getDefiningOp(); + auto stepCst = step.getDefiningOp(); + if (!upperBoundCst || !lowerBoundCst || !stepCst) { + if (!options.supportDynamicLoops) { + LDBG("--dynamic loop not supported -> BAIL"); + return false; + } + } else { + int64_t ubImm = upperBoundCst.value(); + int64_t lbImm = lowerBoundCst.value(); + int64_t stepImm = stepCst.value(); + int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm); + if (numIteration > maxStage) { + dynamicLoop = false; + } else if (!options.supportDynamicLoops) { + LDBG("--fewer loop iterations than pipeline stages -> BAIL"); + return false; + } + } + peelEpilogue = options.peelEpilogue; + predicateFn = options.predicateFn; + if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { + LDBG("--no epilogue or predicate set -> BAIL"); + return false; + } + if (dynamicLoop && peelEpilogue) { + LDBG("--dynamic loop doesn't support epilogue yet -> BAIL"); + return false; + } + std::vector> schedule; + options.getScheduleFn(forOp, schedule); + if (schedule.empty()) { + LDBG("--empty schedule -> BAIL"); + return false; + } + + opOrder.reserve(schedule.size()); + for (auto &opSchedule : schedule) { + maxStage = std::max(maxStage, opSchedule.second); + stages[opSchedule.first] = opSchedule.second; + opOrder.push_back(opSchedule.first); + } + + // All operations need to have a stage. + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!stages.contains(&op)) { + op.emitOpError("not assigned a pipeline stage"); + LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); + return false; + } + } + + if (!verifySchedule()) { + LDBG("--invalid schedule: " << op << " -> BAIL"); + return false; + } + + // Currently, we do not support assigning stages to ops in nested regions. The + // block of all operations assigned a stage should be the single `scf.for` + // body block. + for (const auto &[op, stageNum] : stages) { + (void)stageNum; + if (op == forOp.getBody()->getTerminator()) { + op->emitError("terminator should not be assigned a stage"); + LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); + return false; + } + if (op->getBlock() != forOp.getBody()) { + op->emitOpError("the owning Block of all operations assigned a stage " + "should be the loop body block"); + LDBG("--the owning Block of all operations assigned a stage " + "should be the loop body block: " + << *op << " -> BAIL"); + return false; + } + } + + // Support only loop-carried dependencies with a distance of one iteration or + // those defined outside of the loop. This means that any dependency within a + // loop should either be on the immediately preceding iteration, the current + // iteration, or on variables whose values are set before entering the loop. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [this](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def || + (!stages.contains(def) && forOp->isAncestor(def)); + })) { + LDBG("--only support loop carried dependency with a distance of 1 or " + "defined outside of the loop -> BAIL"); + return false; + } + annotateFn = options.annotateFn; + return true; +} + +/// Find operands of all the nested operations within `op`. +static SetVector getNestedOperands(Operation *op) { + SetVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + operands.insert(operand); + } + }); + return operands; +} + +/// Compute unrolled cycles of each op (consumer) and verify that each op is +/// scheduled after its operands (producers) while adjusting for the distance +/// between producer and consumer. +bool LoopPipelinerInternal::verifySchedule() { + int64_t numCylesPerIter = opOrder.size(); + // Pre-compute the unrolled cycle of each op. + DenseMap unrolledCyles; + for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) { + Operation *def = opOrder[cycle]; + auto it = stages.find(def); + assert(it != stages.end()); + int64_t stage = it->second; + unrolledCyles[def] = cycle + stage * numCylesPerIter; + } + for (Operation *consumer : opOrder) { + int64_t consumerCycle = unrolledCyles[consumer]; + for (Value operand : getNestedOperands(consumer)) { + auto [producer, distance] = getDefiningOpAndDistance(operand); + if (!producer) + continue; + auto it = unrolledCyles.find(producer); + // Skip producer coming from outside the loop. + if (it == unrolledCyles.end()) + continue; + int64_t producerCycle = it->second; + if (consumerCycle < producerCycle - numCylesPerIter * distance) { + consumer->emitError("operation scheduled before its operands"); + return false; + } + } + } + return true; +} + +/// Clone `op` and call `callback` on the cloned op's operands as well as any +/// operands of nested ops that: +/// 1) aren't defined within the new op or +/// 2) are block arguments. +static Operation * +cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, + function_ref callback) { + Operation *clone = rewriter.clone(*op); + clone->walk([&](Operation *nested) { + // 'clone' itself will be visited first. + for (OpOperand &operand : nested->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if ((def && !clone->isAncestor(def)) || isa(operand.get())) + callback(&operand); + } + }); + return clone; +} + +void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { + // Initialize the iteration argument to the loop initiale values. + for (auto [arg, operand] : + llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { + setValueMapping(arg, operand.get(), 0); + } + auto yield = cast(forOp.getBody()->getTerminator()); + Location loc = forOp.getLoc(); + SmallVector predicates(maxStage); + for (int64_t i = 0; i < maxStage; i++) { + if (dynamicLoop) { + Type t = ub.getType(); + // pred = ub > lb + (i * step) + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, i)))); + predicates[i] = rewriter.create( + loc, arith::CmpIPredicate::slt, iv, ub); + } + + // special handling for induction variable as the increment is implicit. + // iv = lb + i * step + Type t = lb.getType(); + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, + rewriter.getIntegerAttr(t, i)))); + setValueMapping(forOp.getInductionVar(), iv, i); + for (Operation *op : opOrder) { + if (stages[op] > i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[i - stages[op]]; + newOperand->set(replacement); + } + }); + // only for gcu to avoid dte offset > src buffer_address + if (i > 0 && isa(newOp)) { + newOp->setAttr("Prologue_stage_idex", rewriter.getI32IntegerAttr(i)); + } + int predicateIdx = i - stages[op]; + if (predicates[predicateIdx]) { + newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); + assert(newOp && "failed to predicate op."); + } + rewriter.setInsertionPointAfter(newOp); + if (annotateFn) + annotateFn(newOp, + triton::gcu::PipeliningOption::PipelinerPart::Prologue, i); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); + // If the value is a loop carried dependency update the loop argument + // mapping. + for (OpOperand &operand : yield->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), i - stages[op] + 1); + } + } + } + } +} + +llvm::MapVector +LoopPipelinerInternal::analyzeCrossStageValues() { + llvm::MapVector crossStageValues; + LLVM_DEBUG({ + llvm::dbgs() << "analyzeCrossStageValues enter:" + << " \n"; + }); + for (Operation *op : opOrder) { + unsigned stage = stages[op]; + auto analyzeOperand = [&](OpOperand &operand) { + auto [def, distance] = getDefiningOpAndDistance(operand.get()); + if (!def) + return; + auto defStage = stages.find(def); + if (defStage == stages.end() || defStage->second == stage || + defStage->second == stage + distance) + return; + if (defStage->second == stage + distance) { + llvm::dbgs() << "current op stage:" << stage << ",current op:"; + op->dump(); + llvm::dbgs() << "defStage stage:" << defStage->second << ",def op:"; + def->dump(); + } + assert(stage > defStage->second); + LiverangeInfo &info = crossStageValues[operand.get()]; + info.defStage = defStage->second; + info.lastUseStage = std::max(info.lastUseStage, stage); + LLVM_DEBUG({ + llvm::dbgs() << "info.defStage:" << info.defStage << " \n"; + llvm::dbgs() << "info.lastUseStage:" << info.lastUseStage << " \n"; + operand.get().dump(); + }); + }; + + for (OpOperand &operand : op->getOpOperands()) + analyzeOperand(operand); + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { + analyzeOperand(*operand); + }); + } + LLVM_DEBUG({ + llvm::dbgs() << "analyzeCrossStageValues leave" + << " \n"; + }); + return crossStageValues; +} + +std::pair +LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { + int64_t distance = 0; + if (auto arg = dyn_cast(value)) { + if (arg.getOwner() != forOp.getBody()) + return {nullptr, 0}; + // Ignore induction variable. + if (arg.getArgNumber() == 0) + return {nullptr, 0}; + distance++; + value = + forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + Operation *def = value.getDefiningOp(); + if (!def) + return {nullptr, 0}; + return {def, distance}; +} + +scf::ForOp LoopPipelinerInternal::createKernelLoop( + const llvm::MapVector + &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap) { + // Creates the list of initial values associated to values used across + // stages. The initial values come from the prologue created above. + // Keep track of the kernel argument associated to each version of the + // values passed to the kernel. + llvm::SmallVector newLoopArg; + // For existing loop argument initialize them with the right version from the + // prologue. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance of 1 or " + "outside the loop"); + auto defStage = stages.find(def); + if (defStage != stages.end()) { + Value valueVersion = + valueMapping[forOp.getRegionIterArgs()[retVal.index()]] + [maxStage - defStage->second]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + } else { + newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]); + } + } + for (auto escape : crossStageValues) { + LiverangeInfo &info = escape.second; + Value value = escape.first; + for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; + stageIdx++) { + Value valueVersion = + valueMapping[value][maxStage - info.lastUseStage + stageIdx]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + LLVM_DEBUG({ + llvm::dbgs() << "createKernelLoop add loop arg: stageIdx:" << stageIdx + << ",mapping from :" + << maxStage - info.lastUseStage + stageIdx << " \n"; + llvm::dbgs() << "createKernelLoop add loop arg value: "; + valueVersion.dump(); + }); + loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - + stageIdx)] = newLoopArg.size() - 1; + } + } + + // Create the new kernel loop. When we peel the epilgue we need to peel + // `numStages - 1` iterations. Then we adjust the upper bound to remove those + // iterations. + Value newUb = forOp.getUpperBound(); + if (peelEpilogue) { + Type t = ub.getType(); + Location loc = forOp.getLoc(); + // newUb = ub - maxStage * step + Value maxStageValue = rewriter.create( + loc, rewriter.getIntegerAttr(t, maxStage)); + Value maxStageByStep = + rewriter.create(loc, step, maxStageValue); + newUb = rewriter.create(loc, ub, maxStageByStep); + } + auto newForOp = + rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, + forOp.getStep(), newLoopArg); + // When there are no iter args, the loop body terminator will be created. + // Since we always create it below, remove the terminator if it was created. + if (!newForOp.getBody()->empty()) + rewriter.eraseOp(newForOp.getBody()->getTerminator()); + return newForOp; +} + +LogicalResult LoopPipelinerInternal::createKernel( + scf::ForOp newForOp, + const llvm::MapVector + &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter) { + valueMapping.clear(); + + // Create the kernel, we clone instruction based on the order given by + // user and remap operands coming from a previous stages. + rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + } + SmallVector predicates(maxStage + 1, nullptr); + if (!peelEpilogue) { + // Create a predicate for each stage except the last stage. + LLVM_DEBUG({ + llvm::dbgs() << "createKernel &&!peelEpilogue maxStage:" << maxStage + << "\n"; + }); + Location loc = newForOp.getLoc(); + Type t = ub.getType(); + for (unsigned i = 0; i < maxStage; i++) { + // c = ub - (maxStage - i) * step + Value c = rewriter.create( + loc, ub, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr( + t, static_cast(maxStage - i))))); + + Value pred = rewriter.create( + newForOp.getLoc(), arith::CmpIPredicate::slt, + newForOp.getInductionVar(), c); + predicates[i] = pred; + } + } + for (Operation *op : opOrder) { + int64_t useStage = stages[op]; + auto *newOp = rewriter.clone(*op, mapping); + SmallVector operands; + // Collect all the operands for the cloned op and its nested ops. + op->walk([&operands](Operation *nestedOp) { + for (OpOperand &operand : nestedOp->getOpOperands()) { + operands.push_back(&operand); + } + }); + for (OpOperand *operand : operands) { + Operation *nestedNewOp = mapping.lookup(operand->getOwner()); + // Special case for the induction variable uses. We replace it with a + // version incremented based on the stage where it is used. + if (operand->get() == forOp.getInductionVar()) { + rewriter.setInsertionPoint(newOp); + + // offset = (maxStage - stages[op]) * step + Type t = step.getType(); + Value offset = rewriter.create( + forOp.getLoc(), step, + rewriter.create( + forOp.getLoc(), + rewriter.getIntegerAttr(t, maxStage - stages[op]))); + Value iv = rewriter.create( + forOp.getLoc(), newForOp.getInductionVar(), offset); + nestedNewOp->setOperand(operand->getOperandNumber(), iv); + rewriter.setInsertionPointAfter(newOp); + continue; + } + Value source = operand->get(); + auto arg = dyn_cast(source); + if (arg && arg.getOwner() == forOp.getBody()) { + Value ret = forOp.getBody()->getTerminator()->getOperand( + arg.getArgNumber() - 1); + Operation *dep = ret.getDefiningOp(); + if (!dep) + continue; + auto stageDep = stages.find(dep); + if (stageDep == stages.end() || stageDep->second == useStage) { + continue; + } + // If the value is a loop carried value coming from stage N + 1 remap, + // it will become a direct use. + if (stageDep->second == useStage + 1) { + nestedNewOp->setOperand(operand->getOperandNumber(), + mapping.lookupOrDefault(ret)); + continue; + } + source = ret; + } + // For operands defined in a previous stage we need to remap it to use + // the correct region argument. We look for the right version of the + // Value based on the stage where it is used. + Operation *def = source.getDefiningOp(); + if (!def) + continue; + auto stageDef = stages.find(def); + if (stageDef == stages.end() || stageDef->second == useStage) + continue; + auto remap = loopArgMap.find( + std::make_pair(operand->get(), useStage - stageDef->second)); + assert(remap != loopArgMap.end()); + LLVM_DEBUG({ + llvm::dbgs() << "operands defined in previous stage update operand " + "from loopArgMap:"; + newForOp.getRegionIterArgs()[remap->second].dump(); + }); + nestedNewOp->setOperand(operand->getOperandNumber(), + newForOp.getRegionIterArgs()[remap->second]); + } + + if (predicates[useStage]) { + newOp = predicateFn(rewriter, newOp, predicates[useStage]); + if (!newOp) + return failure(); + // Remap the results to the new predicated one. + for (auto values : llvm::zip(op->getResults(), newOp->getResults())) + mapping.map(std::get<0>(values), std::get<1>(values)); + } + rewriter.setInsertionPointAfter(newOp); + if (annotateFn) + annotateFn(newOp, triton::gcu::PipeliningOption::PipelinerPart::Kernel, + 0); + } + + // Collect the Values that need to be returned by the forOp. For each + // value we need to have `LastUseStage - DefStage` number of versions + // returned. + // We create a mapping between original values and the associated loop + // returned values that will be needed by the epilogue. + llvm::SmallVector yieldOperands; + for (OpOperand &yieldOperand : + forOp.getBody()->getTerminator()->getOpOperands()) { + Value source = mapping.lookupOrDefault(yieldOperand.get()); + // When we don't peel the epilogue and the yield value is used outside the + // loop we need to make sure we return the version from numStages - + // defStage. + LLVM_DEBUG({ + llvm::dbgs() << "get old forop operand mapping:\n"; + source.dump(); + }); + if (!peelEpilogue && + !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) { + Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first; + if (def) { + auto defStage = stages.find(def); + if (defStage != stages.end() && defStage->second < maxStage) { + Value pred = predicates[defStage->second]; + source = rewriter.create( + pred.getLoc(), pred, source, + newForOp.getBody() + ->getArguments()[yieldOperand.getOperandNumber() + 1]); + LLVM_DEBUG({ + llvm::dbgs() << "select source to avoid out of range: "; + def->dump(); + }); + } + } + } + LLVM_DEBUG({ + llvm::dbgs() << "push finnal yeiled operand"; + source.dump(); + }); + yieldOperands.push_back(source); + } + + for (auto &it : crossStageValues) { + int64_t version = maxStage - it.second.lastUseStage + 1; + unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; + // add the original version to yield ops. + // If there is a live range spanning across more than 2 stages we need to + // add extra arg. + LLVM_DEBUG({ + llvm::dbgs() << "createKernel crossStageValue:" + << "version:" << version + << "numVersionReturned:" << numVersionReturned + << ",crossvalue"; + it.first.dump(); + }); + for (unsigned i = 1; i < numVersionReturned; i++) { + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + LLVM_DEBUG({ + llvm::dbgs() << "push cross stage arg:"; + newForOp.getBody() + ->getArguments()[yieldOperands.size() + 1 + + newForOp.getNumInductionVars()] + .dump(); + }); + yieldOperands.push_back( + newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + + newForOp.getNumInductionVars()]); + } + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + LLVM_DEBUG({ + llvm::dbgs() << "push old to new mapping value:"; + mapping.lookupOrDefault(it.first).dump(); + }); + yieldOperands.push_back(mapping.lookupOrDefault(it.first)); + } + // Map the yield operand to the forOp returned value. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance of 1 or " + "defined outside the loop"); + auto defStage = stages.find(def); + if (defStage == stages.end()) { + for (unsigned int stage = 1; stage <= maxStage; stage++) + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + retVal.value(), stage); + } else if (defStage->second > 0) { + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + newForOp->getResult(retVal.index()), + maxStage - defStage->second + 1); + } + } + rewriter.create(forOp.getLoc(), yieldOperands); + return success(); +} + +void LoopPipelinerInternal::emitEpilogue( + RewriterBase &rewriter, llvm::SmallVector &returnValues) { + // Emit different versions of the induction variable. They will be + // removed by dead code if not used. + for (int64_t i = 0; i < maxStage; i++) { + Location loc = forOp.getLoc(); + Type t = lb.getType(); + Value minusOne = + rewriter.create(loc, rewriter.getIntegerAttr(t, -1)); + // number of iterations = ((ub - 1) - lb) / step + Value totalNumIteration = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, ub, minusOne), lb), + step); + // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i) + Value minusI = + rewriter.create(loc, rewriter.getIntegerAttr(t, -i)); + Value newlastIter = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, totalNumIteration, minusI))); + setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); + } + // Emit `maxStage - 1` epilogue part that includes operations from stages + // [i; maxStage]. + for (int64_t i = 1; i <= maxStage; i++) { + for (Operation *op : opOrder) { + if (stages[op] < i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[maxStage - stages[op] + i]; + newOperand->set(replacement); + } + }); + if (annotateFn) + annotateFn(newOp, + triton::gcu::PipeliningOption::PipelinerPart::Epilogue, + i - 1); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + maxStage - stages[op] + i); + // If the value is a loop carried dependency update the loop argument + // mapping and keep track of the last version to replace the original + // forOp uses. + for (OpOperand &operand : + forOp.getBody()->getTerminator()->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + unsigned version = maxStage - stages[op] + i + 1; + // If the version is greater than maxStage it means it maps to the + // original forOp returned value. + if (version > maxStage) { + returnValues[operand.getOperandNumber()] = newOp->getResult(destId); + continue; + } + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), version); + } + } + } + } +} + +void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { + auto it = valueMapping.find(key); + // If the value is not in the map yet add a vector big enough to store all + // versions. + if (it == valueMapping.end()) + it = + valueMapping + .insert(std::make_pair(key, llvm::SmallVector(maxStage + 1))) + .first; + it->second[idx] = el; +} + +} // namespace + +FailureOr +mlir::triton::gcu::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, + const triton::gcu::PipeliningOption &options, + bool *modifiedIR) { + if (modifiedIR) + *modifiedIR = false; + LoopPipelinerInternal pipeliner; + LLVM_DEBUG({ llvm::dbgs() << "gcu:pipelineForLoop\n"; }); + if (!pipeliner.initializeLoopInfo(forOp, options)) + return failure(); + + if (modifiedIR) + *modifiedIR = true; + + // 1. Emit prologue. + LLVM_DEBUG({ + llvm::dbgs() << "before emitPrologue\n"; + forOp.getOperation()->getParentOp()->dump(); + }); + pipeliner.emitPrologue(rewriter); + + // 2. Track values used across stages. When a value cross stages it will + // need to be passed as loop iteration arguments. + // We first collect the values that are used in a different stage than where + // they are defined. + llvm::MapVector + crossStageValues = pipeliner.analyzeCrossStageValues(); + + // Mapping between original loop values used cross stage and the block + // arguments associated after pipelining. A Value may map to several + // arguments if its liverange spans across more than 2 stages. + llvm::DenseMap, unsigned> loopArgMap; + // 3. Create the new kernel loop and return the block arguments mapping. + LLVM_DEBUG({ + llvm::dbgs() << "before createKernelLoop\n"; + forOp.getOperation()->getParentOp()->dump(); + }); + ForOp newForOp = + pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); + // Create the kernel block, order ops based on user choice and remap + // operands. + LLVM_DEBUG({ + llvm::dbgs() << "before createKernel\n"; + forOp.getOperation()->getParentOp()->dump(); + }); + if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, + rewriter))) + return failure(); + + llvm::SmallVector returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. + rewriter.setInsertionPointAfter(newForOp); + pipeliner.emitEpilogue(rewriter, returnValues); + } + // 5. Erase the original loop and replace the uses with the epilogue output. + if (forOp->getNumResults() > 0) + rewriter.replaceOp(forOp, returnValues); + else + rewriter.eraseOp(forOp); + + return newForOp; +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipelineExpander.h b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipelineExpander.h new file mode 100644 index 000000000..4aa1f4ede --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipelineExpander.h @@ -0,0 +1,114 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRITON_DIALECT_TRIRONTOGCU_PINGPONG_PIPELINE_H_ +#define TRITON_DIALECT_TRIRONTOGCU_PINGPONG_PIPELINE_H_ + +// This is a fork of upstream pipeline transformation. This will be merged back +// upstream once we have a stable solution. +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include +namespace mlir { + +class RewriterBase; +class Operation; +class Value; + +namespace scf { +class ForOp; +} + +namespace triton { +namespace gcu { +/// Options to dictate how loops should be pipelined. +struct PipeliningOption { + /// Lambda returning all the operation in the forOp, with their stage, in the + /// order picked for the pipelined loop. + using GetScheduleFnType = std::function> &)>; + GetScheduleFnType getScheduleFn = nullptr; + enum class PipelinerPart { + Prologue, + Kernel, + Epilogue, + }; + /// Lambda called by the pipeliner to allow the user to annotate the IR while + /// it is generated. + /// The callback passes the operation created along with the part of the + /// pipeline and the iteration index. The iteration index is always 0 for the + /// kernel. For the prologue and epilogue, it corresponds to the iteration + /// peeled out of the loop in the range [0, maxStage[. + using AnnotationlFnType = + std::function; + AnnotationlFnType annotateFn = nullptr; + + /// Control whether the epilogue should be peeled out of the loop or + /// operations should be predicated to skip the early stages in the last loop + /// iterations. If the epilogue is predicated; the user needs to provide a + /// lambda to generate the predicated version of operations. + bool peelEpilogue = true; + + /// Control whether the transformation checks that the number of iterations is + /// greater or equal to the number of stages and skip the transformation if + /// this is not the case. If the loop is dynamic and this is set to true the + /// pipeliner will have to predicate operations in the the prologue/epilogue. + bool supportDynamicLoops = false; + + // Callback to predicate operations when the prologue or epilogue are not + // peeled. This takes the original operation, an i1 predicate value and the + // pattern rewriter. It is expected to replace the given operation with + // the predicated equivalent and return it, or return nullptr if the + // predication is impossible. In the latter case, pipelining will fail and + // may leave IR in a partially transformed state. + using PredicateOpFnType = + std::function; + PredicateOpFnType predicateFn = nullptr; + + // TODO(triton): add option to decide if the prologue should be peeled. +}; + +/// Generate a pipelined version of the scf.for loop based on the schedule given +/// as option. This applies the mechanical transformation of changing the loop +/// and generating the prologue/epilogue for the pipelining and doesn't make any +/// decision regarding the schedule. +/// Based on the options the loop is split into several stages. +/// The transformation assumes that the scheduling given by user is valid. +/// For example if we break a loop into 3 stages named S0, S1, S2 we would +/// generate the following code with the number in parenthesis as the iteration +/// index: +/// +/// S0(0) // Prologue +/// S0(1) S1(0) // Prologue +/// scf.for %I = %C0 to %N - 2 { +/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel +/// } +/// S1(N) S2(N-1) // Epilogue +/// S2(N) // Epilogue +/// +/// If `modifiedIR` is provided, it will be set to a value that indicates +/// whether pipelining modified the IR before failing, signaling to the caller +/// whether they can proceed with different transformations. +FailureOr pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp, + const PipeliningOption &options, + bool *modifiedIR = nullptr); + +} // namespace gcu +} // namespace triton +} // namespace mlir +#endif // TRITON_DIALECT_TRIRONTOGCU_PINGPONG_PIPELINE_H_ diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipeliningUtility.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipeliningUtility.cpp new file mode 100644 index 000000000..bcd493bdc --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipeliningUtility.cpp @@ -0,0 +1,122 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" +#include "PipeliningUtility.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +// Combine the current mask with the given predicate. +static Value getPredMask(RewriterBase &rewriter, Type typeLike, + Value currentMask, Value pred) { + Type maskType = tt::getI1SameShape(typeLike); + Location loc = pred.getLoc(); + Value mask = pred; + if (isa(maskType)) { + mask = rewriter.create(loc, maskType, pred); + } + if (currentMask) { + mask = rewriter.create(loc, mask, currentMask); + } + return mask; +} + +// Function to mask operations during scheduling. +Operation *mlir::triton::gcu::predicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { + OpBuilder::InsertionGuard guard(rewriter); + if (mlir::isMemoryEffectFree(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (auto ifOp = dyn_cast(op)) { + rewriter.setInsertionPoint(op); + Value cnd = getPredMask(rewriter, ifOp.getCondition().getType(), + ifOp.getCondition(), pred); + ifOp.getConditionMutable().assign(cnd); + return op; + } + if (auto asyncCopyOp = + dyn_cast(op)) { + return op; + } + if (auto loadOp = dyn_cast(op)) { + return op; + } + if (auto loadOp = dyn_cast(op)) { + return op; + } + assert("don't know how to predicate this op" && false); + return op; +} + +/// Helper to recursively add dependencies to the same stage. +void mlir::triton::gcu::addDep(Operation *op, DenseSet &deps, + bool includeArg, DenseSet *filter) { + if (filter && filter->count(op)) + return; + if (!deps.insert(op).second) + return; + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::DenseSet seen; + seen.reserve(4); + while (auto arg = mlir::dyn_cast(v)) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + triton::gcu::addDep(defOp, deps, includeArg, filter); + } + } +} + +// Add operations to the schedule with the given stage based on the filter +// function. +void mlir::triton::gcu::addOps( + scf::ForOp forOp, int stage, + std::vector> &schedule, + std::function filter) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!filter(&op)) + continue; + schedule.emplace_back(&op, stage); + } +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipeliningUtility.h b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipeliningUtility.h new file mode 100644 index 000000000..6ae2bfbfc --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/PipeliningUtility.h @@ -0,0 +1,44 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRITON_TRIRONTOGCU_PINGPONG_UTILITY_H_ +#define TRITON_TRIRONTOGCU_PINGPONG_UTILITY_H_ +#include "mlir/Dialect/SCF/IR/SCF.h" +#include +#include + +namespace mlir { +namespace triton { +namespace gcu { +/// Function to mask operations during scheduling. +Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred); + +/// Collect ssa dependencies of `op` in `deps`. if `includeArg` is true, +/// continue looking through loop block arguments. +void addDep(Operation *op, DenseSet &deps, bool includeArg = true, + DenseSet *filter = nullptr); + +/// Add operations from `forOp` into a pipeline schedule with the the given +/// `stage` when filter is true. This will add operation in the original loop +/// order. +void addOps(scf::ForOp forOp, int stage, + std::vector> &schedule, + std::function filter); +} // namespace gcu +} // namespace triton +} // namespace mlir + +#endif // TRITON_TRIRONTOGCU_PINGPONG_UTILITY_H_ diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/Schedule.h b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/Schedule.h new file mode 100644 index 000000000..5a720d980 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/Schedule.h @@ -0,0 +1,44 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRITON_TRIRONTOGCU_PINGPONG_SCHEDULE_H_ +#define TRITON_TRIRONTOGCU_PINGPONG_SCHEDULE_H_ + +#include "PipelineExpander.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include + +namespace mlir { +namespace triton { +namespace gcu { + +/// This fill out the pipelining options including schedule and annotations +/// for wait ops. This also does pre-processing by converting some of the +/// loads into async loads so that the IR is ready to be pipelined. +bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::gcu::PipeliningOption &options); + +/// Fills out pipelining options for an outer loop pipelining case. This +/// schedules async copies to overlap with the epilogue of a loop. +bool getOuterLoopSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::gcu::PipeliningOption &options); + +} // namespace gcu +} // namespace triton +} // namespace mlir +#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/TritonGCUPingpong.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/TritonGCUPingpong.cpp new file mode 100644 index 000000000..2ae4e583e --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUPingpong/TritonGCUPingpong.cpp @@ -0,0 +1,271 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "Conversion/TritonToGCU/TritonToGCUPass.h" + +#include "Conversion/TritonToGCU/Utils.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" + +#include "PipelineExpander.h" +#include "PipeliningUtility.h" +#include "Schedule.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create async operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop. +//===----------------------------------------------------------------------===// + +namespace mlir { +#define GEN_PASS_DEF_TRITONGCUPINGPONGPASS +#include "Conversion/Passes.h.inc" +} // namespace mlir +#define DEBUG_TYPE "triton-gcu-pingpong" + +namespace { +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; +// Return true if the preconditions for pipelining the loop are met. +static bool preCondition(scf::ForOp forOp) { + // Skip loop with distance > 1 for now. + // TODO(triton): relax the constraint in the expander. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + })) + return false; + // Don't pipeline outer loops. + if (forOp + ->walk([&](Operation *op) { + if (forOp.getOperation() == op) + return WalkResult::advance(); + if (isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted()) + return false; + return true; +} + +// static void tryAndPipelineOuterLoop(scf::ForOp forOp) { +// mlir::triton::gcu::PipeliningOption options; +// bool foundSchedule = false; +// // Limit 2 stages to not require extra shared memory. +// foundSchedule = gcu::getOuterLoopSchedule(forOp, /*numStage=*/2, options); +// if (!foundSchedule) +// return; +// IRRewriter rewriter(forOp->getContext()); +// rewriter.setInsertionPoint(forOp); +// FailureOr newForOp = +// mlir::triton::gcu::pipelineForLoop(rewriter, forOp, options); +// (void)newForOp; +// } + +static bool pipelineLoop(scf::ForOp forOp, int numStages) { + mlir::triton::gcu::PipeliningOption options; + if (!preCondition(forOp)) + return false; + + bool foundSchedule = false; + foundSchedule = gcu::preProcessLoopAndGetSchedule(forOp, numStages, options); + + // TODO(triton): add more pipelines strategy. + if (!foundSchedule) + return false; + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + FailureOr newForOp = + mlir::triton::gcu::pipelineForLoop(rewriter, forOp, options); + + if (failed(newForOp)) + return false; + // gcu need post process last dte config + auto newFor = *newForOp; + LLVM_DEBUG({ + llvm::dbgs() << "post process \n"; + newFor.getOperation()->getParentOp()->dump(); + }); + auto step = newFor.getStep(); + auto upperBound = newFor.getUpperBound(); + auto forIdx = newFor.getInductionVar(); + std::vector eraseOps; + OpBuilder builder(newFor.getContext()); + builder.setInsertionPoint(newFor.getOperation()); + auto twoStep = builder.create( + step.getLoc(), step, + builder.create( + step.getLoc(), + builder.getIntegerAttr(step.getType(), numStages - 1))); + for (Operation &op : newFor.getBody()->without_terminator()) { + if (isa(op)) { + SmallVector> queue; + for (auto &use : op.getUses()) { + queue.push_back({use.getOwner(), use.getOperandNumber()}); + } + if (queue.size() == 1 && isa(queue[0].first)) { + builder.setInsertionPoint(&op); + auto loc = op.getLoc(); + auto lastLastLoad = builder.create( + loc, arith::CmpIPredicate::slt, + builder.create(loc, forIdx, twoStep), upperBound); + auto ifop = builder.create( + loc, lastLastLoad, + [&](OpBuilder &builder, Location loc) { + auto clone = builder.clone(op); + llvm::SmallVector yieldOperands; + yieldOperands.push_back(clone->getResult(0)); + builder.create(loc, yieldOperands); + }, + [&](OpBuilder &builder, Location loc) { + llvm::SmallVector yieldOperands; + yieldOperands.push_back(newFor.getInitArgs()[queue[0].second]); + builder.create(loc, yieldOperands); + }); + op.getResult(0).replaceAllUsesWith(ifop.getResult(0)); + eraseOps.push_back(&op); + } + } else if (isa(op)) { + SmallVector> queue; + for (auto &use : op.getUses()) { + queue.push_back({use.getOwner(), use.getOperandNumber()}); + } + if (queue.size() == 1 && isa(queue[0].first)) { + builder.setInsertionPoint(&op); + auto loc = op.getLoc(); + auto lastLastWait = builder.create( + loc, arith::CmpIPredicate::slt, + builder.create(loc, forIdx, step), upperBound); + auto ifop = builder.create( + loc, lastLastWait, + [&](OpBuilder &builder, Location loc) { + auto clone = builder.clone(op); + llvm::SmallVector yieldOperands; + yieldOperands.push_back(clone->getResult(0)); + builder.create(loc, yieldOperands); + }, + [&](OpBuilder &builder, Location loc) { + llvm::SmallVector yieldOperands; + yieldOperands.push_back(newFor.getInitArgs()[queue[0].second]); + builder.create(loc, yieldOperands); + }); + op.getResult(0).replaceAllUsesWith(ifop.getResult(0)); + eraseOps.push_back(&op); + } + } + } + for (auto erase : eraseOps) { + erase->erase(); + } + return true; +} + +struct TritonGCUPingpongPass + : public mlir::impl::TritonGCUPingpongPassBase { + using Base::Base; + int getNumStagesOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + if (!forOp->hasAttr("tt.num_stages")) { + return numStages > 6 ? 6 : numStages; + } + int stageNumber = + mlir::cast(forOp->getAttr("tt.num_stages")).getInt(); + if (stageNumber > 6) { + stageNumber = 6; + } + return stageNumber; + } + void runOnOperation() override; + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } +}; +} // namespace + +using namespace mlir; +void TritonGCUPingpongPass::runOnOperation() { + LLVM_DEBUG({ llvm::dbgs() << "enter TritonGCUPingpongPass\n"; }); + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (getNumStagesOrDefault(forOp) > 2) + loops.push_back(forOp); + }); + + if (loops.empty()) { + LLVM_DEBUG({ llvm::dbgs() << "no loops \n"; }); + return; + } + llvm::SmallSetVector outerLoops; + for (scf::ForOp forOp : loops) { + auto outerLoop = dyn_cast(forOp->getParentOp()); + int loopNumStages = getNumStagesOrDefault(forOp); + bool pipelined = pipelineLoop(forOp, loopNumStages); + if (pipelined && outerLoop && getNumStagesOrDefault(outerLoop) > 2) + outerLoops.insert(outerLoop); + } + + // Clean up arithmetic before applying the next level of pipelining to + // simplify the IR. + auto arithDialect = + getOperation().getContext()->getLoadedDialect(); + RewritePatternSet patterns(getOperation().getContext()); + arithDialect->getCanonicalizationPatterns(patterns); + if (applyPatternsGreedily(getOperation(), std::move(patterns)).failed()) + return signalPassFailure(); + + // Try to pipeline the outer loop to overlap the prologue and epilogue of + // the inner loop. + // for (scf::ForOp outerLoop : outerLoops) + // tryAndPipelineOuterLoop(outerLoop); + + // Re-collect loop ops todo support store latter + // loops.clear(); + // getOperation()->walk([&](scf::ForOp forOp) { + // // Bail out for loops with num_stage <= 1. + // if (getNumStagesOrDefault(forOp) > 1) + // loops.push_back(forOp); + // }); + + // for (scf::ForOp forOp : loops) { + // mlir::triton::pipelineTMAStores(forOp); + // } +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritionToGCUBase.h b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritionToGCUBase.h new file mode 100644 index 000000000..19bf16110 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritionToGCUBase.h @@ -0,0 +1,48 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef KURAMA_TRITON_TO_GCU_BASE_H_ +#define KURAMA_TRITON_TO_GCU_BASE_H_ + +#include + +#include "mlir/IR/Dominance.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace triton { +namespace gcu { +class FirstLastUserAnalysis; +} +} // namespace triton +} // namespace mlir + +using namespace mlir; + +template +class SharedConversionPattern : public OpConversionPattern { +public: + SharedConversionPattern(const TypeConverter &converter, MLIRContext *ctx, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin) + : OpConversionPattern(converter, ctx), + userAnalysis(userAnalysis), replaced2Origin(replaced2Origin) {} + +protected: + triton::gcu::FirstLastUserAnalysis &userAnalysis; + std::map &replaced2Origin; +}; +#endif diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritonGCUAsyncOpToGCU.h b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritonGCUAsyncOpToGCU.h new file mode 100644 index 000000000..4da8e7730 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritonGCUAsyncOpToGCU.h @@ -0,0 +1,498 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef KURAMA_TRITONGCUASYNC_TO_GCU_H_ +#define KURAMA_TRITONGCUASYNC_TO_GCU_H_ + +#include "TritionToGCUBase.h" +#include "TritonGCUToGCUUtils.h" + +#include + +#include "Dialect/GCU/IR/Dialect.h" +#include "Dialect/GCU/IR/Types.h" +#include "Dialect/MathExt/IR/MathExt.h" +#include "Dialect/MathExt/IR/MathExtTypes.h" +#include "Dialect/MemrefExt/IR/MemrefExt.h" + +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +void getPipelineAsyncResourceMaping( + Operation *module, std::map &asyncLoad2Tag, + llvm::DenseMap &asyncLoad2Tagidex, + std::map &asyncWait2Tag) { + int32_t pipelineResourceNumber = -1; + std::map shareAlloc2Tags; + module->walk([&](triton::gcu::ShareAllocOp op) { + auto type = dyn_cast(op.getType()); + int32_t dim0Size = type.getShape()[0]; + if ((pipelineResourceNumber != -1) && + (pipelineResourceNumber != dim0Size)) { + assert(false && " all triton::gcu::ShareAllocOp should has some " + "PipelineResourceNumber!!!"); + } + pipelineResourceNumber = dim0Size; + OpBuilder builder(op.getOperation()); + auto tagType = MemRefType::get(ArrayRef{pipelineResourceNumber}, + builder.getI32Type()); + auto tag = builder.create(op.getLoc(), tagType); + tag->setAttr("gcu.share_tag", builder.getUnitAttr()); + shareAlloc2Tags[op.getOperation()] = tag.getOperation(); + }); + + auto getShareAlloc = [&](Operation *tokenDefineOp) { + assert(isa(tokenDefineOp) && + " wait_op's tag should be a AsyncLoadGlobalToShareOp op!"); + auto dstbuffer = + dyn_cast(tokenDefineOp) + .getDstMem(); + auto bufferDefineOp = dstbuffer.getDefiningOp(); + if (!bufferDefineOp || + !isa(bufferDefineOp)) { + assert(false && + " AsyncLoadGlobalToShareOp's dst should be a subview op!"); + } + auto subView = dyn_cast(bufferDefineOp); + auto shareAllocOp = subView.getSrc().getDefiningOp(); + if (!shareAllocOp || !isa(shareAllocOp)) { + assert(false && " MemDescSubviewOp's src should be a ShareAllocOp op!"); + } + return shareAllocOp; + }; + + module->walk([&](Operation *operation) { + llvm::TypeSwitch(operation) + .Case( + [&](triton::gcu::AsyncLoadGlobalToShareOp load) { + auto dstbuffer = load.getDstMem(); + auto defineOp = dstbuffer.getDefiningOp(); + if (!defineOp || !isa(defineOp)) { + assert( + false && + " AsyncLoadGlobalToShareOp's dst should be a subview op!"); + } + auto subView = dyn_cast(defineOp); + auto shareAllocOp = subView.getSrc().getDefiningOp(); + if (!shareAllocOp || + !isa(shareAllocOp)) { + assert(false && + " MemDescSubviewOp's src should be a ShareAllocOp op!"); + } + asyncLoad2Tag[operation] = shareAlloc2Tags[shareAllocOp]; + SmallVector opOffsetVals = subView.getOffsets(); + asyncLoad2Tagidex[operation] = opOffsetVals[0]; + }) + .Case([&](triton::gcu::AsyncWaitOp wait) { + auto waitToken = wait.getAsyncToken()[0]; + if (auto tocken = dyn_cast(waitToken)) { + auto waitParent = operation->getParentOp(); + if (isa(waitParent)) { + waitParent = waitParent->getParentOp(); + } + assert(isa(waitParent) && + "if async wait got a block argument, it should be in ForOp"); + auto forInitToken = + dyn_cast(waitParent).getTiedLoopInit(tocken)->get(); + auto tokenDefineOp = forInitToken.getDefiningOp(); + if (tokenDefineOp) { + asyncWait2Tag[operation] = + shareAlloc2Tags[getShareAlloc(tokenDefineOp)]; + } + } else { + auto tokenDefineOp = waitToken.getDefiningOp(); + if (tokenDefineOp) { + asyncWait2Tag[operation] = + shareAlloc2Tags[getShareAlloc(tokenDefineOp)]; + } + } + }); + }); +} + +struct TTShareAllocOpLowering : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::gcu::ShareAllocOp alloc, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, alloc.getOperation()); + auto resultType = + dyn_cast(getTypeConverter()->convertType(alloc.getType())); + auto output = rewriter.create(alloc.getLoc(), resultType); + leaveTritionOp(rewriter, alloc.getOperation()); + rewriter.replaceOp(alloc, output); + return success(); + } +}; + +struct TTShareDeallocOpLowering + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::gcu::ShareDeallocOp dealloc, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, dealloc.getOperation()); + rewriter.create(dealloc.getLoc(), adaptor.getSrc()); + leaveTritionOp(rewriter, dealloc.getOperation()); + rewriter.eraseOp(dealloc); + return success(); + } +}; + +struct TTLocalLoadOpLowering + : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::gcu::LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcLayout = + cast(op.getSrc().getType()).getEncoding(); + auto dstLayout = dyn_cast(op.getType()).getEncoding(); + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto firstUser = userAnalysis.getFirstUserOp(op.getOperation()); + auto tag = (firstUser == nullptr) ? getPrivateDTETag(rewriter, op) + : createPrivateDTETag(rewriter, op); + // share to Distributed + if (mlir::isa(srcLayout) && + isa(dstLayout)) { + // copy to local + auto output = loadFromSharedMem(rewriter, tag, op.getResult().getType(), + adaptor.getSrc(), false, lastUser, + firstUser, userAnalysis, replaced2Origin); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } else if (mlir::isa(srcLayout) && + isa(dstLayout)) { + // Distributed to dot operand + // to dot a or b + auto output = loadFromSharedMemForDotOperand( + rewriter, tag, op.getResult().getType(), + op.getSrc().getType().getShape(), adaptor.getSrc(), lastUser, + firstUser, userAnalysis, replaced2Origin); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } else { + op.dump(); + llvm::report_fatal_error( + "[Error] gcu::LocalLoadOp maybe had bad used in pinpong\n"); + } + return success(); + } +}; + +inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, + ArrayRef strides) { + assert(offsets.size() == strides.size()); + Value ret = + rewriter.create(loc, 0, rewriter.getI32Type()); + for (auto [offset, stride] : llvm::zip(offsets, strides)) { + ret = rewriter.create( + loc, ret, rewriter.create(loc, offset, stride)); + } + return ret; +} + +struct TTMemDescSubviewOpLowering + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::gpu::MemDescSubviewOp subview, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, subview.getOperation()); + auto resultType = dyn_cast( + getTypeConverter()->convertType(subview.getType())); + auto loc = subview.getLoc(); + auto src = adaptor.getSrc(); + auto sourceType = dyn_cast(src.getType()); + auto sourceRank = sourceType.getRank(); + auto [strides, offset] = sourceType.getStridesAndOffset(); + (void)offset; + SmallVector opOffsetVals = subview.getOffsets(); + assert((opOffsetVals.size() == strides.size()) && + "offset size is not equal to stride size !!!"); + assert((opOffsetVals.size() == static_cast(sourceRank)) && + "offset size is not equal to rank !!!"); + + auto elemType = resultType.getElementType(); + // SmallVector outOffsets; + SmallVector strideVals; + SmallVector strideValues; + for (int32_t i = 0; i < sourceRank; i++) { + if (i > 0) { + strideVals.push_back(rewriter.getIndexAttr(strides[i])); + } + strideValues.push_back(rewriter.create( + loc, strides[i], opOffsetVals[0].getType())); + } + + auto finalOffsetValue = dot(rewriter, loc, opOffsetVals, strideValues); + auto bpe = elemType.getIntOrFloatBitWidth() / 8; + auto elementType = resultType.getElementType(); + int64_t size = 1; + for (int i = 0; i < sourceType.getRank(); i++) { + size *= sourceType.getShape()[i]; + } + // Create flattened buffer + MemRefType flatType = MemRefType::get({size}, elementType, AffineMap{}, + resultType.getMemorySpace()); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + Value flatBuffer = rewriter.create( + loc, flatType, src, zero, + ValueRange{rewriter.create(loc, size)}, + ValueRange{one}); + auto ptrType = gcu::PtrType::get(getContext(), elementType); + Value ptr = rewriter.create(loc, ptrType, flatBuffer); + MemRefType memType1D = + MemRefType::get({ShapedType::kDynamic}, rewriter.getI8Type()); + auto buffer1D = rewriter.create(loc, memType1D, ptr); + + auto I8Offset = rewriter.create( + loc, finalOffsetValue, + rewriter.create(loc, bpe, + opOffsetVals[0].getType())); + auto bufferWithSpace = rewriter.create( + loc, + MemRefType::get({ShapedType::kDynamic}, rewriter.getI8Type(), + AffineMap{}, resultType.getMemorySpace()), + buffer1D); + auto output = rewriter.create( + loc, resultType, bufferWithSpace, + rewriter.create(loc, rewriter.getIndexType(), + I8Offset), + ValueRange{}); + leaveTritionOp(rewriter, subview.getOperation()); + rewriter.replaceOp(subview, output); + return success(); + } +}; + +struct TTAsyncLoadGlobalToShareOpLowering + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + std::map &asyncLoad2Tag; + llvm::DenseMap &asyncLoad2Tagidex; + explicit TTAsyncLoadGlobalToShareOpLowering( + const TypeConverter &converter, MLIRContext *ctx, + std::map &inAsyncLoad2Tags, + llvm::DenseMap &inAsyncLoad2Tagidex) + : OpConversionPattern(converter, + ctx), + asyncLoad2Tag(inAsyncLoad2Tags), + asyncLoad2Tagidex(inAsyncLoad2Tagidex) {} + + LogicalResult + matchAndRewrite(triton::gcu::AsyncLoadGlobalToShareOp asyncLoad, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, asyncLoad.getOperation()); + bool isPrologueLoad = false; + if (asyncLoad.getOperation()->getAttr("Prologue_stage_idex")) { + isPrologueLoad = true; + } + auto loc = asyncLoad.getLoc(); + auto zero = rewriter.create(loc, 0); + auto outputBuffer = adaptor.getDstMem(); + auto outputType = dyn_cast(outputBuffer.getType()); + auto elemType = outputType.getElementType(); + auto rank = outputType.getRank(); + SmallVector sourceShape; + sourceShape.push_back(adaptor.getShape()[0]); + for (unsigned i = 0; i < rank - 1; ++i) { + sourceShape.push_back(rewriter.create( + loc, adaptor.getStrides()[i], adaptor.getStrides()[i + 1])); + } + SmallVector offsets; + for (unsigned i = 0; i < adaptor.getOffsets().size(); ++i) { + auto offset = rewriter.create( + loc, rewriter.getI32Type(), adaptor.getOffsets()[i]); + offsets.push_back(offset); + } + SmallVector sliceShape; + for (unsigned i = 0; i < rank; ++i) { + sliceShape.push_back(rewriter.create( + loc, rewriter.getI32Type(), adaptor.getShape()[i])); + } + assert( + (asyncLoad2Tag.find(asyncLoad.getOperation()) != asyncLoad2Tag.end()) && + "AsyncLoadGlobalToShareOp had no mapping tags !!!"); + assert((asyncLoad2Tagidex.find(asyncLoad.getOperation()) != + asyncLoad2Tagidex.end()) && + "AsyncLoadGlobalToShareOp had no mapping tagindx !!!"); + if (isPrologueLoad == true) { + int32_t prologueIdx = + dyn_cast( + asyncLoad.getOperation()->getAttr("Prologue_stage_idex")) + .getInt(); + // get range from for + Operation *forUser = nullptr; + int32_t userNumber = 0; + for (Operation *user : asyncLoad.getOperation()->getUsers()) { + userNumber++; + if (isa(user)) { + forUser = user; + } + } + if (forUser == nullptr || userNumber > 2) { + asyncLoad.dump(); + assert(false && "please carefully check pingpong prologue flow!!!!"); + } + auto forOp = llvm::dyn_cast(forUser); + auto step = forOp.getStep(); + auto upperBound = forOp.getUpperBound(); + auto lowerBound = forOp.getLowerBound(); + + auto forRange = + rewriter.create(loc, upperBound, lowerBound); + auto reminAdd = rewriter.create( + loc, step, + rewriter.create( + step.getLoc(), rewriter.getIntegerAttr(step.getType(), 1))); + auto forStepNum = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, forRange), reminAdd), + step); + auto isSmallThan = rewriter.create( + loc, arith::CmpIPredicate::slt, + rewriter.create( + step.getLoc(), + rewriter.getIntegerAttr(step.getType(), prologueIdx)), + forStepNum); + rewriter.create( + loc, isSmallThan, [&](OpBuilder &builder, Location loc) { + auto isThread0 = rewriter.create( + loc, arith::CmpIPredicate::eq, + rewriter.create(loc, gpu::Dimension::x), zero); + auto defaultValue = + asyncLoad.getDefaultValue() + ? adaptor.getDefaultValue() + : triton::gcu::createConstantZero(rewriter, loc, elemType); + auto tagIdx = rewriter + .create( + loc, rewriter.getIndexType(), + asyncLoad2Tagidex[asyncLoad.getOperation()]) + .getResult(); + auto outTransType = MemRefType::get(outputType.getShape(), + outputType.getElementType()); + auto outTrans = builder.create(loc, outTransType); + rewriter.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + ConfigGcuLoad( + rewriter, loc, adaptor.getDstMem(), outTrans, + asyncLoad.getOperation(), outputType, adaptor.getPtr(), + adaptor.getStrides(), adaptor.getShape(), defaultValue, + asyncLoad2Tag[asyncLoad.getOperation()]->getResult(0), + tagIdx, true); + builder.create(loc); + }); + builder.create(loc, outTrans); + builder.create(loc); + }); + leaveTritionOp(rewriter, asyncLoad.getOperation()); + rewriter.replaceOp(asyncLoad, + asyncLoad2Tagidex[asyncLoad.getOperation()]); + return success(); + } + // to avoid share momeory race + rewriter.create(loc); + auto isThread0 = rewriter.create( + loc, arith::CmpIPredicate::eq, + rewriter.create(loc, gpu::Dimension::x), zero); + auto defaultValue = + asyncLoad.getDefaultValue() + ? adaptor.getDefaultValue() + : triton::gcu::createConstantZero(rewriter, loc, elemType); + auto tagIdx = rewriter + .create( + loc, rewriter.getIndexType(), + asyncLoad2Tagidex[asyncLoad.getOperation()]) + .getResult(); + auto outTransType = + MemRefType::get(outputType.getShape(), outputType.getElementType()); + auto outTrans = rewriter.create(loc, outTransType); + rewriter.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + ConfigGcuLoad(rewriter, loc, adaptor.getDstMem(), outTrans, + asyncLoad.getOperation(), outputType, adaptor.getPtr(), + adaptor.getStrides(), adaptor.getShape(), defaultValue, + asyncLoad2Tag[asyncLoad.getOperation()]->getResult(0), + tagIdx, true); + builder.create(loc); + }); + rewriter.create(loc, outTrans); + leaveTritionOp(rewriter, asyncLoad.getOperation()); + rewriter.replaceOp(asyncLoad, asyncLoad2Tagidex[asyncLoad.getOperation()]); + return success(); + } +}; + +struct TTAsyncWaitOpLowering : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + std::map &asyncWait2Tag; + + explicit TTAsyncWaitOpLowering( + const TypeConverter &converter, MLIRContext *ctx, + std::map &inAsyncWait2Tag) + : OpConversionPattern(converter, ctx), + asyncWait2Tag(inAsyncWait2Tag) {} + + LogicalResult + matchAndRewrite(triton::gcu::AsyncWaitOp wait, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, wait.getOperation()); + auto loc = wait.getLoc(); + assert((asyncWait2Tag.find(wait.getOperation()) != asyncWait2Tag.end()) && + "AsyncWaitOp had no mapping tags !!!"); + auto zero = rewriter.create(loc, 0); + auto isThread0 = rewriter.create( + loc, arith::CmpIPredicate::eq, + rewriter.create(loc, gpu::Dimension::x), zero); + rewriter.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + auto tagIdx = + rewriter + .create(loc, rewriter.getIndexType(), + adaptor.getAsyncToken()[0]) + .getResult(); + WaitGcuLoadStore(builder, loc, + asyncWait2Tag[wait.getOperation()]->getResult(0), + tagIdx, zero); + builder.create(loc); + }); + rewriter.create(loc); + leaveTritionOp(rewriter, wait.getOperation()); + rewriter.replaceOp(wait, adaptor.getAsyncToken()[0]); + return success(); + } +}; +#endif diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritonGCUToGCUUtils.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritonGCUToGCUUtils.cpp new file mode 100644 index 000000000..790e20ce4 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritonGCUToGCUUtils.cpp @@ -0,0 +1,2291 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TritonGCUToGCUUtils.h" + +#include +#include +#include +#include + +#include "Analysis/FirstLastUserAnalysis.h" +#include "ConstantUtil.h" +#include "Conversion/TritonToGCU/Utils.h" + +#include "Dialect/GCU/IR/Dialect.h" +#include "Dialect/GCU/IR/Types.h" +#include "Dialect/MathExt/IR/MathExt.h" +#include "Dialect/MathExt/IR/MathExtTypes.h" +#include "Dialect/MemrefExt/IR/MemrefExt.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" + +#define DBG_TRITON_IR 1 +using namespace mlir; + +Value getPrivateDTETag(OpBuilder &builder, Operation *op) { + OpBuilder::InsertionGuard guard(builder); + auto func = op->getParentOfType(); + auto firstOp = &func.getFunctionBody().getBlocks().front().front(); + auto tagType = MemRefType::get(ArrayRef{1}, builder.getI32Type()); + if (isa(firstOp) && firstOp->getAttr("gcu.private_tag")) + return firstOp->getResult(0); + builder.setInsertionPoint(firstOp); + auto tag = builder.create(op->getLoc(), tagType); + tag->setAttr("gcu.private_tag", builder.getUnitAttr()); + return tag; +} + +Value getShareDTETag(OpBuilder &builder, Operation *op) { + OpBuilder::InsertionGuard guard(builder); + auto func = op->getParentOfType(); + auto secondIter = func.getFunctionBody().getBlocks().front().begin(); + secondIter++; + auto secondOp = &(*secondIter); + auto tagType = MemRefType::get(ArrayRef{1}, builder.getI32Type()); + if (isa(secondOp) && secondOp->getAttr("gcu.share_tag")) + return secondOp->getResult(0); + builder.setInsertionPoint(secondOp); + auto tag = builder.create(op->getLoc(), tagType); + tag->setAttr("gcu.share_tag", builder.getUnitAttr()); + return tag; +} + +Value createPrivateDTETag(OpBuilder &builder, Operation *op) { + auto tagType = MemRefType::get(ArrayRef{1}, builder.getI32Type()); + auto tag = builder.create(op->getLoc(), tagType); + tag->setAttr("gcu.private_tag", builder.getUnitAttr()); + return tag; +} + +DenseSet getSlicedAxies(Type type) { + DenseSet axies; + if (auto tType = dyn_cast(type)) { + auto numElems = triton::gcu::getElemsPerThread(type); + for (unsigned i = 0; i < tType.getRank(); ++i) { + if (numElems[i] != tType.getDimSize(i)) { + axies.insert(i); + } + } + } + return axies; +} + +SmallVector getWarpIds(OpBuilder &builder, Location loc, Type type) { + SmallVector warpIds; + if (auto tType = dyn_cast(type)) { + if (auto dotEnc = dyn_cast( + tType.getEncoding())) { + auto blockedLayout = + dyn_cast(dotEnc.getParent()); + auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + bool isM = dotEnc.getOpIdx() == 0; + for (unsigned i = 0; i < tType.getRank(); ++i) { + if (isM && i == rank - 2) { + auto id = builder.create( + loc, + builder.create( + loc, builder.create(loc, gpu::Dimension::x), + builder.create( + loc, warpsPerCTA[rank - 2] * warpsPerCTA[rank - 1])), + builder.create(loc, + warpsPerCTA[rank - 1])); + warpIds.push_back(id); + } else if ((!isM) && i == rank - 1) { + auto id = builder.create( + loc, + builder.create( + loc, builder.create(loc, gpu::Dimension::x), + builder.create( + loc, warpsPerCTA[rank - 2] * warpsPerCTA[rank - 1])), + builder.create(loc, + warpsPerCTA[rank - 1])); + warpIds.push_back(id); + } else { + warpIds.push_back(builder.create(loc, 0)); + } + } + } else if (auto blockEnc = dyn_cast( + tType.getEncoding())) { + auto slicedAxies = getSlicedAxies(type); + auto warps = blockEnc.getWarpsPerCTA(); + auto shapePerCTA = + triton::gpu::getShapePerCTA(blockEnc, tType.getShape()); + SmallVector warpMods(warps.size()); + SmallVector warpStrides(warps.size()); + unsigned warpMod = 1; + unsigned warpStride = 1; + for (int i = warps.size() - 1; i >= 0; --i) { + warpMod *= warps[i]; + warpMods[i] = warpMod; + warpStrides[i] = warpStride; + warpStride *= warps[i]; + } + unsigned i = 0; + for (auto num : triton::gcu::getElemsPerThread(type)) { + (void)num; + if (slicedAxies.count(i)) { + auto repeatNum = + shapePerCTA[i] > warps[i] ? 1 : warps[i] / shapePerCTA[i]; + auto id = builder.create( + loc, + builder.create( + loc, builder.create(loc, gpu::Dimension::x), + builder.create(loc, warpMods[i] / + repeatNum)), + builder.create(loc, warpStrides[i])); + warpIds.push_back(id); + } else { + warpIds.push_back(builder.create(loc, 0)); + } + ++i; + } + } else if (auto blockEnc = dyn_cast( + tType.getEncoding())) { + auto slicedAxies = getSlicedAxies(type); + auto warps = blockEnc.getWarpsPerCTA(); + auto shapePerCTA = + triton::gpu::getShapePerCTA(blockEnc, tType.getShape()); + SmallVector warpMods(warps.size()); + SmallVector warpStrides(warps.size()); + unsigned warpMod = 1; + unsigned warpStride = 1; + for (int i = warps.size() - 1; i >= 0; --i) { + warpMod *= warps[i]; + warpMods[i] = warpMod; + warpStrides[i] = warpStride; + warpStride *= warps[i]; + } + unsigned i = 0; + for (auto num : triton::gcu::getElemsPerThread(type)) { + (void)num; + if (slicedAxies.count(i)) { + auto repeatNum = + shapePerCTA[i] > warps[i] ? 1 : warps[i] / shapePerCTA[i]; + auto id = builder.create( + loc, + builder.create( + loc, builder.create(loc, gpu::Dimension::x), + builder.create(loc, warpMods[i] / + repeatNum)), + builder.create(loc, warpStrides[i])); + warpIds.push_back(id); + } else { + warpIds.push_back(builder.create(loc, 0)); + } + ++i; + } + } else if (auto sliceEnc = dyn_cast( + tType.getEncoding())) { + auto parent = sliceEnc.getParent(); + auto outShape = sliceEnc.paddedShape(tType.getShape()); + SmallVector sliceDims; + sliceDims.push_back(sliceEnc.getDim()); + while (auto innerSliceEnc = + dyn_cast(parent)) { + auto curSliceDim = innerSliceEnc.getDim(); + for (size_t idx = 0; idx < sliceDims.size(); idx++) { + if (sliceDims[idx] >= curSliceDim) { + sliceDims[idx] = sliceDims[idx] + 1; + } + } + llvm::ArrayRef inputShpe = outShape; + outShape = innerSliceEnc.paddedShape(inputShpe); + sliceDims.push_back(curSliceDim); + parent = innerSliceEnc.getParent(); + } + if (!isa(parent)) { + llvm::report_fatal_error("[Error] bad slice layout parent"); + assert(false && "bad slice layout parent"); + return warpIds; + } + auto blockEncParent = dyn_cast(parent); + size_t rank = outShape.size(); + SmallVector sizePerThread(rank, 1); + auto warpsPerCTA = blockEncParent.getWarpsPerCTA(); + auto threadsPerWarp = blockEncParent.getThreadsPerWarp(); + auto shapePerCTA = triton::gpu::getShapePerCTA(blockEncParent, outShape); + assert(rank == sizePerThread.size() && + "unexpected rank in BlockedEncodingAttr::getElemsPerThread"); + SmallVector parentElemsPerThread(rank); + for (size_t i = 0; i < rank; ++i) { + unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]; + parentElemsPerThread[i] = + ceil(shapePerCTA[i], t) * sizePerThread[i]; + } + DenseSet slicedAxies; + for (unsigned i = 0; i < rank; ++i) { + if (parentElemsPerThread[i] != outShape[i]) { + slicedAxies.insert(i); + } + } + SmallVector warpMods(warpsPerCTA.size()); + SmallVector warpStrides(warpsPerCTA.size()); + unsigned warpMod = 1; + unsigned warpStride = 1; + for (int i = warpsPerCTA.size() - 1; i >= 0; --i) { + warpMod *= warpsPerCTA[i]; + warpMods[i] = warpMod; + warpStrides[i] = warpStride; + warpStride *= warpsPerCTA[i]; + } + SmallVector parentWarpIds; + for (unsigned i = 0; i < rank; ++i) { + if (slicedAxies.count(i)) { + if (llvm::is_contained(sliceDims, i)) { + llvm::report_fatal_error("[Error] bad slice layout shape"); + assert(false && "bad slice layout shape"); + } + auto repeatNum = shapePerCTA[i] > warpsPerCTA[i] + ? 1 + : warpsPerCTA[i] / shapePerCTA[i]; + auto id = builder.create( + loc, + builder.create( + loc, builder.create(loc, gpu::Dimension::x), + builder.create(loc, warpMods[i] / + repeatNum)), + builder.create(loc, warpStrides[i])); + warpIds.push_back(id); + } else { + if (!llvm::is_contained(sliceDims, i)) { + warpIds.push_back(builder.create(loc, 0)); + } + } + } + } else { + for (unsigned i = 0; i < tType.getRank(); ++i) { + warpIds.push_back(builder.create(loc, 0)); + } + } + } else { + warpIds.push_back(builder.create(loc, 0)); + } + return warpIds; +} + +SmallVector getElemsPerThread(OpBuilder &builder, Location loc, + Type type) { + SmallVector numElems; + if (auto ty = dyn_cast(type)) { + auto warpIds = getWarpIds(builder, loc, type); + unsigned i = 0; + for (auto num : triton::gcu::getElemsPerThread(type)) { + auto dim = builder.create(loc, ty.getDimSize(i)); + auto slice = builder.create(loc, num); + auto minNum = builder.create( + loc, slice, + builder.create( + loc, dim, builder.create(loc, warpIds[i], slice))); + numElems.push_back(minNum); + ++i; + } + } else { + numElems.push_back(builder.create(loc, 1)); + } + return numElems; +} + +func::FuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, + OpBuilder &rewriter, StringRef name, + FunctionType type) { + func::FuncOp ret; + if (!(ret = moduleOp.template lookupSymbol(name))) { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + ret = rewriter.create(loc, name, type); + ret.setPrivate(); + } + return ret; +} + +void doSlicePadOrMemsetSlice(OpBuilder &rewriter, Location loc, Operation *op, + Value output, Value src, + SmallVector &offsets, + SmallVector &sliceShape, + SmallVector &padSizes, + Value defaultValue, Value tag, Value tagIdx) { + auto maxPadSize = rewriter.create(loc, 2047); + auto outputType = dyn_cast(output.getType()); + auto legalPad = + rewriter.create(loc, 1, rewriter.getI1Type()) + .getResult(); + unsigned totalNumElems = 1; + for (int i = 0; i < outputType.getRank(); i++) { + totalNumElems *= outputType.getShape()[i]; + auto padSize = padSizes[i]; + auto legalPadRank = rewriter.create( + loc, arith::CmpIPredicate::sle, padSize, maxPadSize); + legalPad = rewriter.create(loc, legalPad, legalPadRank); + } + rewriter.create( + loc, legalPad, + [&](OpBuilder &builder, Location loc) { + builder.create(loc, output, src, offsets, + sliceShape, defaultValue, + tag, ValueRange{tagIdx}); + builder.create(loc); + }, + [&](OpBuilder &childBuilder, Location loc) { + doMemset(childBuilder, op, output, defaultValue, totalNumElems); + childBuilder.create( + loc, output, src, offsets, defaultValue, tag, ValueRange{tagIdx}); + childBuilder.create(loc); + }); +} + +void doMemFence(OpBuilder &rewriter, Operation *op) { /*NOLINT*/ + rewriter.create(op->getLoc()); +} + +void doMemsetConfig(OpBuilder &rewriter, Location loc, Value output, Value v, + Value tagDte, Value tagIdx) { + rewriter.create(loc, output, v, tagDte, + ValueRange{tagIdx}); +} + +void doMemset(OpBuilder &rewriter, Operation *op, Value output, Value v, + unsigned totalNumElems) { + auto loc = op->getLoc(); + if (totalNumElems > 128) { + auto tag = getPrivateDTETag(rewriter, op); + auto zero = rewriter.create(loc, 0); + rewriter.create(loc, output, v, tag, + ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, totalNumElems)); + } else { + auto type = dyn_cast(output.getType()); + affine::buildAffineLoopNest( + rewriter, loc, SmallVector(type.getRank(), 0), + type.getShape(), SmallVector(type.getRank(), 1), + [&](OpBuilder &builder, Location loc, ValueRange iters) { + builder.create(loc, v, output, iters); + }); + doMemFence(rewriter, op); + } +} + +Value castToMemref1D(OpBuilder &rewriter, Location loc, Value v, + Value totalNumElems) { + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + return rewriter.create( + loc, + MemRefType::get(ArrayRef{ShapedType::kDynamic}, + dyn_cast(v.getType()).getElementType()), + v, zero, ArrayRef{totalNumElems}, ArrayRef{one}); +} + +bool isMustAliasOp(mlir::Operation *op) { + if (llvm::isa(op)) { + return true; + } else if (llvm::isa(op)) { + auto convertLayout = cast(op); + auto src = convertLayout.getSrc(); + auto srcNumElems = triton::gcu::getElemsPerThread(src.getType()); + auto dstNumElems = triton::gcu::getElemsPerThread(convertLayout.getType()); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(convertLayout.getType()); + if ((!srcTy) || (!dstTy)) { + assert(false && "srcTy or dstTy not a RankedTensorType"); + } + + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (srcLayout == dstLayout) { + return true; + } + if (srcNumElems == dstNumElems && + src.getType().getShape() == convertLayout.getType().getShape()) { + if (!mlir::isa(srcLayout)) { + return true; + } else if (isa(srcLayout) && + isa(dstLayout)) { + if (cast(srcLayout).getDim() == + cast(dstLayout).getDim()) { + return true; + } + } + } + return false; + } else if (isa(op)) { + auto expandDimOp = cast(op); + auto srcNumElems = + triton::gcu::getElemsPerThread(expandDimOp.getSrc().getType()); + auto dstNumElems = triton::gcu::getElemsPerThread(expandDimOp.getType()); + srcNumElems.insert(srcNumElems.begin() + expandDimOp.getAxis(), 1); + if (srcNumElems == dstNumElems) { + return true; + } + return false; + } else if (isa(op)) { + auto reshapeOp = cast(op); + auto srcNumElems = + triton::gcu::getElemsPerThread(reshapeOp.getSrc().getType()); + auto dstNumElems = triton::gcu::getElemsPerThread(reshapeOp.getType()); + if (srcNumElems == dstNumElems) { + return true; + } + return false; + } else if (isa(op)) { + auto broastOp = cast(op); + auto srcNumElems = + triton::gcu::getElemsPerThread(broastOp.getSrc().getType()); + auto dstNumElems = triton::gcu::getElemsPerThread(broastOp.getType()); + if (srcNumElems == dstNumElems) { + return true; + } + return false; + } else { + return false; + } +} + +// Find last user which is located at parent region of the op +mlir::Operation * +promoteLastUser(mlir::Operation *&lastUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin) { + if (!isa_and_nonnull(lastUser)) { + return nullptr; + } + mlir::Operation *newAllocOpPos = nullptr; + mlir::Operation *curLastUser = lastUser; + mlir::Operation *parent = curLastUser->getParentOp(); + mlir::Operation *originParent = nullptr; + + while (isa(parent) && + isa(curLastUser)) { + if (replaced2Origin.count(parent) == 0) { + if (llvm::none_of(parent->getOperandTypes(), + [](auto t) { return isa(t); }) && + llvm::none_of(parent->getResultTypes(), + [](auto t) { return isa(t); })) { + originParent = parent; + } else { + llvm_unreachable("can't find the origin op"); + } + } else { + originParent = replaced2Origin[parent]; + } + // Need to be the replaced op + newAllocOpPos = parent; + curLastUser = userAnalysis.getLastUserOp(originParent); + parent = curLastUser->getParentOp(); + } + + lastUser = curLastUser; + return newAllocOpPos; +} + +void addDeallocAfterLastUser(OpBuilder &builder, mlir::Operation *lastUser, + Value alloc) { + if (lastUser == nullptr) { + return; + } + if (isa(lastUser) || isa(lastUser) || + isa(lastUser) || isa(lastUser) || + isa(lastUser)) { + return; + } + if (isa(alloc.getType())) { + if (lastUser->mightHaveTrait()) { + return; + } + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPointAfter(lastUser); + builder.create(lastUser->getLoc(), alloc); + builder.restoreInsertionPoint(ip); + } + return; +} + +// lowering +Value syncAllocOp(OpBuilder &builder, Location &loc, Operation *lastUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin, + MemRefType type) { + auto newAllocOpPos = promoteLastUser(lastUser, userAnalysis, replaced2Origin); + + Value output; + if (newAllocOpPos == nullptr) { + output = builder.create(loc, type); + } else { + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPoint(newAllocOpPos); + output = builder.create(loc, type); + builder.restoreInsertionPoint(ip); + } + addDeallocAfterLastUser(builder, lastUser, output); + return output; +} + +Value asyncAllocOp(OpBuilder &builder, Operation *ttParent, MemRefType type) { + return builder.create(ttParent->getLoc(), type); +} + +void createPrintfOp(ConversionPatternRewriter &rewriter, Location loc, + ::llvm::StringRef printOpPrefix, bool hex, Value value) { + auto printSingleElement = [&](Value operand, size_t i, size_t n, + ValueRange iters) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << printOpPrefix << ": "; + if (n > 1) + os << "(operand " << i << ") "; + + // format + auto msg = TypeSwitch(operand.getType()) + .Case( + [&](auto ty) { + if (hex) { + os << "0x%x "; + return "0x%x "; + } else { + os << "%d "; + return "%d "; + } + }) + .Default([&](auto ty) { + os << "%f "; + return "%f "; + }); + + // value + SmallVector values; + auto value = TypeSwitch(operand.getType()) + .Case([&](auto ty) { + return rewriter.create(loc, operand); + }) + .Default([&](auto ty) { return operand; }); + values.push_back(value); + + if (!iters.empty()) { + // idx format + os << "(idx "; + for (auto iter = iters.begin(); iter != iters.end(); ++iter) { + if (iter != iters.begin()) + os << ", "; + os << "%d"; + } + os << ")"; + // idx value + values.append(iters.begin(), iters.end()); + } + os << "\n"; + + if (!msg.empty()) + rewriter.create(loc, formatStr, ValueRange{values}); + }; + + auto printOperand = [&](Value operand, size_t i, size_t n) { + TypeSwitch(operand.getType()) + .Case([&](auto ty) { + affine::buildAffineLoopNest( + rewriter, loc, SmallVector(ty.getRank(), 0), + ty.getShape(), SmallVector(ty.getRank(), 1), + [&](OpBuilder &builder, Location loc, ValueRange iters) { + auto v = builder.create(loc, operand, iters); + printSingleElement(v, i, n, iters); + }); + }) + .Default([&](auto ty) { printSingleElement(operand, i, n, {}); }); + }; + + printOperand(value, 0, 1); +} + +void enterTritionOp(ConversionPatternRewriter &rewriter, Operation *ttParent) { + if (DBG_TRITON_IR) { + auto border = + rewriter.create(ttParent->getLoc()).getOperation(); + border->setAttr("enter", ttParent->getName().getIdentifier()); + } + return; +} + +void leaveTritionOp(ConversionPatternRewriter &rewriter, Operation *ttParent) { + if (DBG_TRITON_IR) { + auto border = + rewriter.create(ttParent->getLoc()).getOperation(); + border->setAttr("leave", ttParent->getName().getIdentifier()); + } + return; +} + +void mergeContinuousDims(OpBuilder &subBuilder, Location loc, + Value &sharedMemref, Value &warpMemref, + SmallVector &offsets, + SmallVector &mergedOffsets, + MemRefType &sharedMemType, MemRefType &warpMemType, + Value &sharedBuffer, Value &warpOutput) { + SmallVector mergedSharedMemShapes; + SmallVector mergedWarpMemShapes; + auto zeroI32 = subBuilder.create( + loc, subBuilder.getIntegerAttr(subBuilder.getIntegerType(32), 0)); + int64_t mergeShape = 1; + for (int i = 0; i < sharedMemType.getRank(); i++) { + if (sharedMemType.getShape()[i] != warpMemType.getShape()[i]) { + if (i > 0 && + sharedMemType.getShape()[i - 1] == warpMemType.getShape()[i - 1]) { + mergedSharedMemShapes.push_back(mergeShape); + mergedWarpMemShapes.push_back(mergeShape); + mergedOffsets.push_back(zeroI32); + } + mergedSharedMemShapes.push_back(sharedMemType.getShape()[i]); + mergedWarpMemShapes.push_back(warpMemType.getShape()[i]); + mergedOffsets.push_back(offsets[i]); + mergeShape = 1; + } else { + if (i == sharedMemType.getRank() - 1) { + mergedSharedMemShapes.push_back(sharedMemType.getShape()[i] * + mergeShape); + mergedWarpMemShapes.push_back(warpMemType.getShape()[i] * mergeShape); + mergedOffsets.push_back(zeroI32); + } else { + mergeShape *= sharedMemType.getShape()[i]; + } + } + } + auto mergedSharedMemType = MemRefType::get( + mergedSharedMemShapes, sharedMemType.getElementType(), AffineMap{}, + subBuilder.getI64IntegerAttr(2) /*shared memory*/); + auto mergedWarpMemType = + MemRefType::get(mergedWarpMemShapes, warpMemType.getElementType()); + auto [sharedMemStrides, sharedMemOffset] = + mergedSharedMemType.getStridesAndOffset(); + sharedMemref = subBuilder.create( + loc, mergedSharedMemType, sharedBuffer, sharedMemOffset, + mergedSharedMemShapes, sharedMemStrides); + auto [warpMemStrides, warpMemOffset] = + mergedWarpMemType.getStridesAndOffset(); + warpMemref = subBuilder.create( + loc, mergedWarpMemType, warpOutput, warpMemOffset, mergedWarpMemShapes, + warpMemStrides); + return; +} + +Value loadFromSharedMem(OpBuilder &builder, Value tag, Type type, Value buffer, + bool onlyThread0, Operation *lastTTUser, + Operation *firstTTUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin) { + auto loc = buffer.getLoc(); + auto srcType = dyn_cast(buffer.getType()); + auto numElems = triton::gcu::getElemsPerThread(type); + auto totalNumElems = builder.create( + loc, triton::gcu::getTotalElemsPerThread(type)); + auto outputType = + MemRefType::get(SmallVector(numElems.begin(), numElems.end()), + srcType.getElementType()); + + auto warpIds = getWarpIds(builder, loc, type); + SmallVector offsets; + auto zero = builder.create(loc, 0); + for (unsigned i = 0; i < srcType.getRank(); ++i) { + offsets.push_back(builder.create( + loc, + builder.create(loc, numElems[i], + builder.getI32Type()), + builder.create(loc, builder.getI32Type(), + warpIds[i]))); + } + + auto output = syncAllocOp(builder, loc, lastTTUser, userAnalysis, + replaced2Origin, outputType); + + auto isThread0 = builder.create( + loc, arith::CmpIPredicate::eq, + builder.create(loc, gpu::Dimension::x), zero); + auto defaultValue = + triton::gcu::createConstantZero(builder, loc, srcType.getElementType()); + bool isNeedMerge = srcType.getRank() > 5; + SmallVector mergedOffsets; + Value src; + Value dst; + if (!firstTTUser) { + if (onlyThread0) { + builder.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + if (isNeedMerge) { + mergeContinuousDims(builder, loc, src, dst, offsets, + mergedOffsets, srcType, outputType, buffer, + output); + builder.create( + loc, dst, src, mergedOffsets, defaultValue, tag, + ValueRange{zero}); + auto [oriOutputStrides, oriOutputOffset] = + outputType.getStridesAndOffset(); + builder.create( + loc, outputType, dst, oriOutputOffset, + SmallVector(numElems.begin(), numElems.end()), + oriOutputStrides); + } else { + builder.create(loc, output, buffer, + offsets, defaultValue, + tag, ValueRange{zero}); + } + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + builder.create(loc); + }); + } else { + if (isNeedMerge) { + mergeContinuousDims(builder, loc, src, dst, offsets, mergedOffsets, + srcType, outputType, buffer, output); + builder.create( + loc, dst, src, mergedOffsets, defaultValue, tag, ValueRange{zero}); + auto [oriOutputStrides, oriOutputOffset] = + outputType.getStridesAndOffset(); + builder.create( + loc, outputType, dst, oriOutputOffset, + SmallVector(numElems.begin(), numElems.end()), + oriOutputStrides); + } else { + builder.create( + loc, output, buffer, offsets, defaultValue, tag, ValueRange{zero}); + } + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + } + } else { + if (onlyThread0) { + builder.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + if (isNeedMerge) { + mergeContinuousDims(builder, loc, src, dst, offsets, + mergedOffsets, srcType, outputType, buffer, + output); + builder.create( + loc, dst, src, mergedOffsets, defaultValue, tag, + ValueRange{zero}); + auto [oriOutputStrides, oriOutputOffset] = + outputType.getStridesAndOffset(); + builder.create( + loc, outputType, dst, oriOutputOffset, + SmallVector(numElems.begin(), numElems.end()), + oriOutputStrides); + } else { + builder.create(loc, output, buffer, + offsets, defaultValue, + tag, ValueRange{zero}); + } + builder.create(loc); + }); + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPoint(firstTTUser); + builder.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + builder.create(loc); + }); + builder.restoreInsertionPoint(ip); + } else { + if (isNeedMerge) { + mergeContinuousDims(builder, loc, src, dst, offsets, mergedOffsets, + srcType, outputType, buffer, output); + builder.create( + loc, dst, src, mergedOffsets, defaultValue, tag, ValueRange{zero}); + auto [oriOutputStrides, oriOutputOffset] = + outputType.getStridesAndOffset(); + builder.create( + loc, outputType, dst, oriOutputOffset, + SmallVector(numElems.begin(), numElems.end()), + oriOutputStrides); + } else { + builder.create( + loc, output, buffer, offsets, defaultValue, tag, ValueRange{zero}); + } + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPoint(firstTTUser); + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + builder.restoreInsertionPoint(ip); + } + } + return output; +} + +Value CopyFromSharedMem(OpBuilder &builder, Value tag, Type type, Value buffer, + bool onlyThread0, Operation *lastTTUser, + Operation *firstTTUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin) { + auto loc = buffer.getLoc(); + auto srcType = dyn_cast(buffer.getType()); + auto shape = srcType.getShape(); + auto numElems = triton::gcu::getElemsPerThread(type); + auto totalNumElems = builder.create( + loc, triton::gcu::getTotalElemsPerThread(type)); + auto outputType = + MemRefType::get(SmallVector(numElems.begin(), numElems.end()), + srcType.getElementType()); + + auto output = syncAllocOp(builder, loc, lastTTUser, userAnalysis, + replaced2Origin, outputType); + auto zero = builder.create(loc, 0); + auto isThread0 = builder.create( + loc, arith::CmpIPredicate::eq, + builder.create(loc, gpu::Dimension::x), zero); + if (!firstTTUser) { + if (onlyThread0) { + builder.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + builder.create( + loc, buffer, SmallVector(shape.size(), zero), output, + SmallVector(shape.size(), zero), totalNumElems, tag, + ValueRange{zero}); + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + builder.create(loc); + }); + } else { + builder.create( + loc, buffer, SmallVector(shape.size(), zero), output, + SmallVector(shape.size(), zero), totalNumElems, tag, + ValueRange{zero}); + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + } + } else { + if (onlyThread0) { + builder.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + builder.create( + loc, buffer, SmallVector(shape.size(), zero), output, + SmallVector(shape.size(), zero), totalNumElems, tag, + ValueRange{zero}); + builder.create(loc); + }); + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPoint(firstTTUser); + builder.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + builder.create(loc); + }); + builder.restoreInsertionPoint(ip); + } else { + builder.create( + loc, buffer, SmallVector(shape.size(), zero), output, + SmallVector(shape.size(), zero), totalNumElems, tag, + ValueRange{zero}); + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPoint(firstTTUser); + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + builder.restoreInsertionPoint(ip); + } + } + return output; +} + +Value loadFromSharedMemForDotOperand( + OpBuilder builder, Value tag, Type type, ArrayRef mnShape, + Value sharedBuffer, Operation *lastTTUser, Operation *firstTTUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin) { + auto loc = sharedBuffer.getLoc(); + auto srcType = dyn_cast(sharedBuffer.getType()); + auto numElems = triton::gcu::getElemsPerThread(type); + auto totalNumElems = builder.create( + loc, triton::gcu::getTotalElemsPerThread(type)); + auto outputType = + MemRefType::get(SmallVector(numElems.begin(), numElems.end()), + srcType.getElementType()); + + auto warpIds = getWarpIds(builder, loc, type); + SmallVector offsets; + auto zero = builder.create(loc, 0); + for (unsigned i = 0; i < srcType.getRank(); ++i) { + offsets.push_back(builder.create( + loc, + builder.create(loc, numElems[i], + builder.getI32Type()), + builder.create(loc, builder.getI32Type(), + warpIds[i]))); + } + + auto output = syncAllocOp(builder, loc, lastTTUser, userAnalysis, + replaced2Origin, outputType); + auto defaultValue = + triton::gcu::createConstantZero(builder, loc, srcType.getElementType()); + builder.create(loc, output, sharedBuffer, offsets, + defaultValue, tag, ValueRange{zero}); + if (firstTTUser == nullptr) { + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + } else { + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPoint(firstTTUser); + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + builder.restoreInsertionPoint(ip); + } + return output; +} + +void storeToSharedMem(OpBuilder &builder, Value tag, TensorType type, + Value sharedBuffer, Value buffer, bool onlyThread0) { + auto loc = buffer.getLoc(); + auto srcType = dyn_cast(buffer.getType()); + auto outputType = dyn_cast(sharedBuffer.getType()); + auto totalNumElems = builder.create( + loc, triton::gcu::getTotalElemsPerThread(type)); + + SmallVector offsets; + SmallVector outputSize; + auto warpIds = getWarpIds(builder, loc, type); + auto zero = builder.create(loc, 0); + for (unsigned i = 0; i < srcType.getRank(); ++i) { + offsets.push_back(builder.create( + loc, + builder.create(loc, srcType.getDimSize(i), + builder.getI32Type()), + builder.create(loc, builder.getI32Type(), + warpIds[i]))); + outputSize.push_back(outputType.getShape()[i]); + } + auto isThread0 = builder.create( + loc, arith::CmpIPredicate::eq, + builder.create(loc, gpu::Dimension::x), zero); + bool isNeedMerge = srcType.getRank() > 5; + SmallVector mergedOffsets; + Value src; + Value dst; + if (onlyThread0) { + builder.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + if (isNeedMerge) { + mergeContinuousDims(builder, loc, dst, src, offsets, mergedOffsets, + outputType, srcType, sharedBuffer, buffer); + auto [oriOutputStrides, oriOutputOffset] = + outputType.getStridesAndOffset(); + builder.create( + loc, dst, src, mergedOffsets, tag, ValueRange{zero}); + builder.create( + loc, outputType, dst, oriOutputOffset, outputSize, + oriOutputStrides); + } else { + builder.create( + loc, sharedBuffer, buffer, offsets, tag, ValueRange{zero}); + } + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + builder.create(loc); + }); + } else { + if (isNeedMerge) { + mergeContinuousDims(builder, loc, dst, src, offsets, mergedOffsets, + outputType, srcType, sharedBuffer, buffer); + auto [oriOutputStrides, oriOutputOffset] = + outputType.getStridesAndOffset(); + builder.create(loc, dst, src, mergedOffsets, + tag, ValueRange{zero}); + builder.create( + loc, outputType, dst, oriOutputOffset, outputSize, oriOutputStrides); + } else { + builder.create( + loc, sharedBuffer, buffer, offsets, tag, ValueRange{zero}); + } + builder.create(loc, tag, ValueRange{zero}, + totalNumElems); + } + builder.create(loc); +} + +Value storeToSharedMem(OpBuilder &builder, Value tag, TensorType type, + Value buffer, bool onlyThread0, Operation *lastTTUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin) { + auto loc = buffer.getLoc(); + auto mergedType = MemRefType::get( + type.getShape(), dyn_cast(buffer.getType()).getElementType(), + AffineMap{}, + builder.getI64IntegerAttr(2)); // shared memory + + auto merged = syncAllocOp(builder, loc, lastTTUser, userAnalysis, + replaced2Origin, mergedType); + storeToSharedMem(builder, tag, type, merged, buffer, onlyThread0); + return merged; +} + +// refine yiled memref operand +void AnalysisYieldOperendUseStage( + Operation *module, triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map> + &TTYeiledOPerandHasMultiUseStage) { + module->walk([&](scf::YieldOp op) { + if (isa(op.getOperation()->getParentOp())) { + // check arg user + for (uint64_t i = 0; i < op.getOperands().size(); ++i) { + TTYeiledOPerandHasMultiUseStage[op.getOperation()][i] = false; + auto operand = op.getOperands()[i]; + auto definingOp = operand.getDefiningOp(); + if (!definingOp) { + TTYeiledOPerandHasMultiUseStage[op.getOperation()][i] = true; + continue; + } + if (isa(operand.getType()) && + isa(op.getOperation()->getParentOp())) { + auto forOp = llvm::cast(op.getOperation()->getParentOp()); + auto reginArg = forOp.getRegionIterArgs()[i]; + auto lastOp = userAnalysis.getLastUserOp( + reginArg, definingOp->getParentRegion()); + if (lastOp == nullptr) { + TTYeiledOPerandHasMultiUseStage[op.getOperation()][i] = true; + } else { + if (definingOp->isBeforeInBlock(lastOp)) { + TTYeiledOPerandHasMultiUseStage[op.getOperation()][i] = true; + } + } + } else if (isa(operand.getType()) && + isa(op.getOperation()->getParentOp())) { + auto whileOp = + llvm::cast(op.getOperation()->getParentOp()); + mlir::Value reginArg = whileOp.getAfterArguments()[i]; + auto lastOp = userAnalysis.getLastUserOp( + reginArg, definingOp->getParentRegion()); + if (lastOp == nullptr) { + TTYeiledOPerandHasMultiUseStage[op.getOperation()][i] = true; + } else { + if (definingOp->isBeforeInBlock(lastOp)) { + TTYeiledOPerandHasMultiUseStage[op.getOperation()][i] = true; + } + } + } + } + } + }); +} + +void GetOrderValueByStride( + OpBuilder &rewriter, Location loc, SmallVector nInitStrideDims, + SmallVector &initStride, SmallVector &initShape, + SmallVector &initOffset, SmallVector &orderStride, + SmallVector &orderShape, SmallVector &orderOffset, + SmallVector &vOrder) { + int64_t rank = static_cast(nInitStrideDims.size()); + auto elementTyIdx = rewriter.getIndexType(); + auto elementTyI32 = rewriter.getI32Type(); + auto tmpStrideBuffer = rewriter.create( + loc, MemRefType::get(rank, elementTyIdx)); + auto tmpShapeBuffer = rewriter.create( + loc, MemRefType::get(rank, elementTyIdx)); + auto tmpOffsetBuffer = rewriter.create( + loc, MemRefType::get(rank, elementTyIdx)); + auto tmpOrderBuffer = rewriter.create( + loc, MemRefType::get(rank, elementTyI32)); + for (unsigned i = 0; i < rank; ++i) { + auto idx = rewriter.create(loc, i); + rewriter.create(loc, initStride[i], tmpStrideBuffer, + ValueRange{idx}); + rewriter.create(loc, initShape[i], tmpShapeBuffer, + ValueRange{idx}); + rewriter.create(loc, initOffset[i], tmpOffsetBuffer, + ValueRange{idx}); + rewriter.create( + loc, + rewriter.create(loc, nInitStrideDims[i], + rewriter.getI32Type()), + tmpOrderBuffer, ValueRange{idx}); + } + + Value zero = rewriter.create(loc, 0); + Value cEnd = rewriter.create(loc, rank - 1); + Value cStep = rewriter.create(loc, 1); + scf::buildLoopNest( + rewriter, loc, zero, cEnd, cStep, + [&](OpBuilder &outerBuilder, Location outerLoc, ValueRange ivsOuter) { + Value i = ivsOuter[0]; + Value cInnerEnd = outerBuilder.create(outerLoc, cEnd, i); + scf::buildLoopNest( + outerBuilder, outerLoc, zero, cInnerEnd, cStep, + [&](OpBuilder &innerBuilder, Location innerLoc, + ValueRange ivsInner) { + Value j = ivsInner[0]; + Value jNext = + innerBuilder.create(innerLoc, j, cStep); + + Value vStrideJ = innerBuilder.create( + innerLoc, tmpStrideBuffer, j); + Value vStrideJNext = innerBuilder.create( + innerLoc, tmpStrideBuffer, jNext); + Value vShapeJ = innerBuilder.create( + innerLoc, tmpShapeBuffer, j); + Value vShapeJNext = innerBuilder.create( + innerLoc, tmpShapeBuffer, jNext); + Value vOffsetJ = innerBuilder.create( + innerLoc, tmpOffsetBuffer, j); + Value vOffsetJNext = innerBuilder.create( + innerLoc, tmpOffsetBuffer, jNext); + Value vOrderJ = innerBuilder.create( + innerLoc, tmpOrderBuffer, j); + Value vOrderJNext = innerBuilder.create( + innerLoc, tmpOrderBuffer, jNext); + + Value cmp = innerBuilder.create( + innerLoc, arith::CmpIPredicate::slt, vStrideJ, vStrideJNext); + innerBuilder.create( + innerLoc, cmp, + [&](OpBuilder &thenBuilder, Location thenLoc) { + thenBuilder.create(thenLoc, vStrideJNext, + tmpStrideBuffer, j); + thenBuilder.create(thenLoc, vStrideJ, + tmpStrideBuffer, jNext); + thenBuilder.create(thenLoc, vShapeJNext, + tmpShapeBuffer, j); + thenBuilder.create(thenLoc, vShapeJ, + tmpShapeBuffer, jNext); + thenBuilder.create(thenLoc, vOffsetJNext, + tmpOffsetBuffer, j); + thenBuilder.create(thenLoc, vOffsetJ, + tmpOffsetBuffer, jNext); + thenBuilder.create(thenLoc, vOrderJNext, + tmpOrderBuffer, j); + thenBuilder.create(thenLoc, vOrderJ, + tmpOrderBuffer, jNext); + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + elseBuilder.create(elseLoc); + }); + }); + }); + + for (unsigned i = 0; i < rank; ++i) { + auto idx = rewriter.create(loc, i); + vOrder.push_back( + rewriter.create(loc, tmpOrderBuffer, ValueRange{idx})); + orderStride.push_back( + rewriter.create(loc, tmpStrideBuffer, ValueRange{idx})); + orderOffset.push_back(rewriter.create( + loc, rewriter.getI32Type(), + rewriter.create(loc, tmpOffsetBuffer, + ValueRange{idx}))); + } + orderShape.push_back( + rewriter.create(loc, tmpShapeBuffer, ValueRange{zero})); + for (unsigned i = 0; i < rank - 1; ++i) { + orderShape.push_back(rewriter.create(loc, orderStride[i], + orderStride[i + 1])); + } + + rewriter.create(loc, tmpStrideBuffer); + rewriter.create(loc, tmpShapeBuffer); + rewriter.create(loc, tmpOffsetBuffer); + rewriter.create(loc, tmpOrderBuffer); +} + +void GetOrderSlicefor30(OpBuilder &rewriter, Location loc, int64_t rank, + SmallVector &initStride, + SmallVector &initSliceShape, + SmallVector &orderSliceShape) { + auto elementTyIdx = rewriter.getIndexType(); + auto tmpStrideBuffer = rewriter.create( + loc, MemRefType::get(rank, elementTyIdx)); + auto tmpShapeBuffer = rewriter.create( + loc, MemRefType::get(rank, elementTyIdx)); + for (unsigned i = 0; i < rank; ++i) { + auto idx = rewriter.create(loc, i); + rewriter.create(loc, initStride[i], tmpStrideBuffer, + ValueRange{idx}); + rewriter.create(loc, initSliceShape[i], tmpShapeBuffer, + ValueRange{idx}); + } + + Value zero = rewriter.create(loc, 0); + Value cEnd = rewriter.create(loc, rank - 1); + Value cStep = rewriter.create(loc, 1); + scf::buildLoopNest( + rewriter, loc, zero, cEnd, cStep, + [&](OpBuilder &outerBuilder, Location outerLoc, ValueRange ivsOuter) { + Value i = ivsOuter[0]; + Value cInnerEnd = outerBuilder.create(outerLoc, cEnd, i); + scf::buildLoopNest( + outerBuilder, outerLoc, zero, cInnerEnd, cStep, + [&](OpBuilder &innerBuilder, Location innerLoc, + ValueRange ivsInner) { + Value j = ivsInner[0]; + Value jNext = + innerBuilder.create(innerLoc, j, cStep); + + Value vStrideJ = innerBuilder.create( + innerLoc, tmpStrideBuffer, j); + Value vStrideJNext = innerBuilder.create( + innerLoc, tmpStrideBuffer, jNext); + Value vShapeJ = innerBuilder.create( + innerLoc, tmpShapeBuffer, j); + Value vShapeJNext = innerBuilder.create( + innerLoc, tmpShapeBuffer, jNext); + + Value cmp = innerBuilder.create( + innerLoc, arith::CmpIPredicate::slt, vStrideJ, vStrideJNext); + innerBuilder.create( + innerLoc, cmp, + [&](OpBuilder &thenBuilder, Location thenLoc) { + thenBuilder.create(thenLoc, vStrideJNext, + tmpStrideBuffer, j); + thenBuilder.create(thenLoc, vStrideJ, + tmpStrideBuffer, jNext); + thenBuilder.create(thenLoc, vShapeJNext, + tmpShapeBuffer, j); + thenBuilder.create(thenLoc, vShapeJ, + tmpShapeBuffer, jNext); + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + elseBuilder.create(elseLoc); + }); + }); + }); + + for (unsigned i = 0; i < rank; ++i) { + auto idx = rewriter.create(loc, i); + orderSliceShape.push_back( + rewriter.create(loc, tmpShapeBuffer, ValueRange{idx})); + } + + rewriter.create(loc, tmpStrideBuffer); + rewriter.create(loc, tmpShapeBuffer); +} + +void GetTransByOrder(OpBuilder &rewriter, Location loc, + SmallVector &order, + SmallVector &transOrder) { + unsigned rank = order.size(); + + auto elementTyI32 = rewriter.getI32Type(); + auto tmpOrderBuffer = rewriter.create( + loc, MemRefType::get(rank, elementTyI32)); + auto tmpTransOrderBuffer = rewriter.create( + loc, MemRefType::get(rank, elementTyI32)); + for (unsigned i = 0; i < rank; ++i) { + auto idx = rewriter.create(loc, i); + rewriter.create( + loc, + rewriter.create(loc, i, rewriter.getI32Type()), + tmpTransOrderBuffer, ValueRange{idx}); + rewriter.create(loc, order[i], tmpOrderBuffer, + ValueRange{idx}); + } + + Value zero = rewriter.create(loc, 0); + Value cEnd = rewriter.create(loc, rank - 1); + Value cStep = rewriter.create(loc, 1); + scf::buildLoopNest( + rewriter, loc, zero, cEnd, cStep, + [&](OpBuilder &outerBuilder, Location outerLoc, ValueRange ivsOuter) { + Value i = ivsOuter[0]; + Value cInnerEnd = outerBuilder.create(outerLoc, cEnd, i); + scf::buildLoopNest( + outerBuilder, outerLoc, zero, cInnerEnd, cStep, + [&](OpBuilder &innerBuilder, Location innerLoc, + ValueRange ivsInner) { + Value j = ivsInner[0]; + Value jNext = + innerBuilder.create(innerLoc, j, cStep); + + Value vOrderJ = innerBuilder.create( + innerLoc, tmpOrderBuffer, j); + Value vOrderJNext = innerBuilder.create( + innerLoc, tmpOrderBuffer, jNext); + Value vTransJ = innerBuilder.create( + innerLoc, tmpTransOrderBuffer, j); + Value vTransJNext = innerBuilder.create( + innerLoc, tmpTransOrderBuffer, jNext); + + Value cmp = innerBuilder.create( + innerLoc, arith::CmpIPredicate::sgt, vOrderJ, vOrderJNext); + innerBuilder.create( + innerLoc, cmp, + [&](OpBuilder &thenBuilder, Location thenLoc) { + thenBuilder.create(thenLoc, vOrderJNext, + tmpOrderBuffer, j); + thenBuilder.create(thenLoc, vOrderJ, + tmpOrderBuffer, jNext); + thenBuilder.create(thenLoc, vTransJNext, + tmpTransOrderBuffer, j); + thenBuilder.create( + thenLoc, vTransJ, tmpTransOrderBuffer, jNext); + thenBuilder.create(thenLoc); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + elseBuilder.create(elseLoc); + }); + }); + }); + + for (unsigned i = 0; i < rank; ++i) { + auto idx = rewriter.create(loc, i); + transOrder.push_back(rewriter.create( + loc, tmpTransOrderBuffer, ValueRange{idx})); + } + + rewriter.create(loc, tmpOrderBuffer); + rewriter.create(loc, tmpTransOrderBuffer); +} + +Value ConfigGcuLoad(OpBuilder &rewriter, Location loc, Value srcOut, + Value transOut, mlir::Operation *op, MemRefType resultType, + Value loadPtr, mlir::ValueRange configStrides, + mlir::ValueRange configShapes, Value defaultValue, + Value tagDte, Value tagIdx, bool IsShareOutput) { + if ((!llvm::isa_and_nonnull(op)) && + (!llvm::isa_and_nonnull(op))) { + assert(false && "please check IR ConfigGcuLoad got a bad input op"); + } + auto getOrderHint = [](mlir::Operation *op) { + if (llvm::isa_and_nonnull(op)) { + auto load = llvm::cast(op); + return load.getOrderHint(); + } else if (llvm::isa_and_nonnull( + op)) { + auto dynamicLoad = llvm::cast(op); + return dynamicLoad.getOrderHint(); + } else { + assert(false && "please check IR ConfigGcuLoad got a bad input op"); + } + return llvm::ArrayRef(); + }; + auto getDefaultValue = [](mlir::Operation *op) { + if (llvm::isa_and_nonnull(op)) { + auto load = llvm::cast(op); + return load.getDefaultValue(); + } else if (llvm::isa_and_nonnull( + op)) { + auto dynamicLoad = llvm::cast(op); + return dynamicLoad.getDefaultValue(); + } else { + assert(false && "please check IR ConfigGcuLoad got a bad input op"); + } + return Value(); + }; + + auto elemType = resultType.getElementType(); + int64_t rank = resultType.getRank(); + + auto buffer = rewriter.create( + loc, MemRefType::get(ArrayRef{ShapedType::kDynamic}, elemType), + loadPtr); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + bool bDynamicStride = false; + bool bStaticTranspose = false; + bool bReshape = true; + SmallVector updateStrideDims; + SmallVector nInitStrideDims; + auto hint = getOrderHint(op); + int64_t hint_size = static_cast(hint.size()); + assert(hint_size == rank /* || hint_size == 0*/); + SmallVector order_hint; + for (unsigned i = 0; i < rank; ++i) + if (hint_size == 0) + order_hint.push_back(-1); + else + order_hint.push_back(hint[i]); + + for (unsigned i = 0; i < rank; ++i) { + if (order_hint[i] == -1) { + bDynamicStride = true; + auto trueCondition = rewriter.create( + loc, arith::CmpIPredicate::ne, configStrides[i], zero); + rewriter.create( + loc, trueCondition, + "Not Support dynamic stride is 0, please add tl.constexpr to stride " + "arg in kernel args list", + "", "", 0); + } + } + + for (int i = 0; i < rank; ++i) { + if ((bDynamicStride && order_hint[i] == 1) || + (!bDynamicStride && order_hint[i] == 0)) { + bReshape = false; + break; + } + } + + for (int i = 0; i < rank; ++i) { + if (bDynamicStride && order_hint[i] == 0) + updateStrideDims.push_back(i); + else + nInitStrideDims.push_back(i); + } + + SmallVector vSrcOffsets; + if (IsShareOutput) { + for (unsigned i = 0; i < nInitStrideDims.size(); ++i) + vSrcOffsets.push_back(zero); + } else { + // currently pinpong only support share layout + assert((!llvm::isa_and_nonnull(op))); + auto load = llvm::cast(op); + auto loadType = load.getType(); + auto numElems = triton::gcu::getElemsPerThread(loadType); + const auto &warpIds = getWarpIds(rewriter, loc, loadType); + for (auto dim : nInitStrideDims) { + Value offset = rewriter.create( + loc, warpIds[dim], + rewriter.create(loc, numElems[dim])); + vSrcOffsets.push_back(offset); + } + } + + SmallVector vSrcStrides; + SmallVector vSrcShapes; + for (auto dim : nInitStrideDims) { + vSrcStrides.push_back(configStrides[dim]); + vSrcShapes.push_back(configShapes[dim]); + } + + SmallVector vResultShapes; + SmallVector resultShapes; + for (unsigned i = 0; i < rank; ++i) { + resultShapes.push_back(resultType.getShape()[i]); + vResultShapes.push_back( + rewriter.create(loc, resultShapes[i])); + } + + Value reshapeOut = srcOut; + if (bReshape) { + assert(rank < 4 && "not support stride is no 1 for rank >=4"); + vSrcOffsets.push_back(zero); + vSrcShapes.push_back(one); + vSrcStrides.push_back(one); + resultShapes.push_back(1); + vResultShapes.push_back(one); + if (bDynamicStride) { + order_hint.push_back(1); + nInitStrideDims.push_back(rank); + } else { + for (int i = 0; i < rank; ++i) + order_hint[i]--; + order_hint.push_back(rank); + } + + rank += 1; + auto reshapeMemrefType = MemRefType::get(resultShapes, elemType); + auto [reshapeStrides, reshapeOffset] = + reshapeMemrefType.getStridesAndOffset(); + reshapeOut = rewriter.create( + loc, reshapeMemrefType, srcOut, reshapeOffset, resultShapes, + reshapeStrides); + } + + if (rank == 2 && bDynamicStride) { + if (order_hint[1] == 1) { + order_hint[0] = 0; + order_hint[1] = 1; + bDynamicStride = false; + } else if (order_hint[0] == 1) { + order_hint[0] = 1; + order_hint[1] = 0; + bDynamicStride = false; + } + } + + SmallVector vOrderStrides; + SmallVector vOrderShapes; + SmallVector vOrderOffsets; + SmallVector vTempOrder; + SmallVector vTransOrder; + if (bDynamicStride) { + GetOrderValueByStride(rewriter, loc, nInitStrideDims, vSrcStrides, + vSrcShapes, vSrcOffsets, vOrderStrides, vOrderShapes, + vOrderOffsets, vTempOrder); + for (auto updateDim : updateStrideDims) { + auto updateStride = rewriter.create( + loc, vOrderStrides[updateDim], vOrderShapes[updateDim]); + vOrderStrides.insert(vOrderStrides.begin() + updateDim, updateStride); + vSrcStrides.insert(vSrcStrides.begin() + updateDim, updateStride); + vOrderShapes.insert(vOrderShapes.begin() + updateDim, one); + vSrcShapes.insert(vSrcShapes.begin() + updateDim, one); + vOrderOffsets.insert(vOrderOffsets.begin() + updateDim, + rewriter.create( + loc, rewriter.getI32Type(), zero)); + vSrcOffsets.insert(vSrcOffsets.begin() + updateDim, zero); + vTempOrder.insert(vTempOrder.begin() + updateDim, + rewriter.create( + loc, updateDim, rewriter.getI32Type())); + } + GetTransByOrder(rewriter, loc, vTempOrder, vTransOrder); + } else { + SmallVector static_order(order_hint.begin(), order_hint.end()); + for (int i = 0; i < rank; ++i) { + vOrderStrides.push_back(vSrcStrides[static_order[i]]); + vOrderOffsets.push_back(rewriter.create( + loc, rewriter.getI32Type(), vSrcOffsets[static_order[i]])); + vTransOrder.push_back(rewriter.create( + loc, static_order[i], rewriter.getI32Type())); + } + if (static_order.size() > 0) + vOrderShapes.push_back(vSrcShapes[static_order[0]]); + for (int i = 0; i < rank - 1; ++i) { + vOrderShapes.push_back(rewriter.create( + loc, vOrderStrides[i], vOrderStrides[i + 1])); + } + + for (int i = 0; i < rank; ++i) { + if (static_order[i] != i) { + bStaticTranspose = true; + break; + } + } + } + + SmallVector vSlicehape; + SmallVector vIntSlicehape; + Value totalSize = one; + for (unsigned i = 0; i < rank; ++i) { + auto shape = rewriter.create( + loc, vResultShapes[i], + rewriter.create( + loc, zero, + rewriter.create(loc, vSrcShapes[i], + vSrcOffsets[i]))); + vSlicehape.push_back(shape); + vIntSlicehape.push_back( + rewriter.create(loc, rewriter.getI32Type(), shape)); + totalSize = rewriter.create(loc, totalSize, shape); + } + + SmallVector padOffsets; + SmallVector padSizes; + Value padSize = zero; + for (unsigned i = 0; i < rank; ++i) { + auto dim_diff = + rewriter.create(loc, vResultShapes[i], vSlicehape[i]); + padSize = rewriter.create(loc, padSize, dim_diff); + padSizes.push_back(dim_diff); + padOffsets.push_back( + rewriter.create(loc, rewriter.getI32Type(), zero)); + } + + auto isNeedPad = rewriter.create( + loc, arith::CmpIPredicate::sgt, padSize, zero); + + Value isDynamicTrans = rewriter.create( + loc, arith::CmpIPredicate::ne, vTransOrder[0], + rewriter.create(loc, rewriter.getI32Type(), zero)); + for (unsigned i = 1; i < rank; ++i) { + auto isDimTrans = rewriter.create( + loc, arith::CmpIPredicate::ne, vTransOrder[i], + rewriter.create(loc, i, rewriter.getI32Type())); + isDynamicTrans = + rewriter.create(loc, isDynamicTrans, isDimTrans); + } + + auto sourceType = MemRefType::get( + SmallVector(rank, ShapedType::kDynamic), elemType); + auto src = rewriter.create( + loc, sourceType, buffer, zero, vOrderShapes, vOrderStrides); + + if (IsShareOutput) { + auto isThread0 = rewriter.create( + loc, arith::CmpIPredicate::eq, + rewriter.create(loc, gpu::Dimension::x), zero); + auto isPad = rewriter.create(loc, isNeedPad, isThread0); + if (bDynamicStride) { + auto isTrans = + rewriter.create(loc, isDynamicTrans, isThread0); + auto isAll = rewriter.create(loc, isPad, isTrans); + rewriter.create( + loc, isAll, + [&](OpBuilder &builder, Location loc) { + auto trans_buffer = rewriter.create( + loc, sourceType, transOut, zero, vSlicehape, vSrcStrides); + builder.create( + loc, trans_buffer, src, vOrderOffsets, vTransOrder, + defaultValue, tagDte, ValueRange{tagIdx}); + builder.create(loc, tagDte, ValueRange{tagIdx}, + totalSize); + doSlicePadOrMemsetSlice(builder, loc, op, reshapeOut, trans_buffer, + padOffsets, vIntSlicehape, padSizes, + defaultValue, tagDte, tagIdx); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, isTrans, + [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, reshapeOut, src, vOrderOffsets, vTransOrder, + defaultValue, tagDte, ValueRange{tagIdx}); + childBuilder.create(loc); + }, + [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, isPad, + [&](OpBuilder &child2Builder, Location loc) { + doSlicePadOrMemsetSlice(child2Builder, loc, op, + reshapeOut, src, vOrderOffsets, + vIntSlicehape, padSizes, + defaultValue, tagDte, tagIdx); + child2Builder.create(loc); + }, + [&](OpBuilder &child2Builder, Location loc) { + child2Builder.create( + loc, isThread0, + [&](OpBuilder &child3Builder, Location loc) { + child3Builder.create( + loc, reshapeOut, src, vOrderOffsets, + defaultValue, tagDte, ValueRange{tagIdx}); + child3Builder.create(loc); + }); + child2Builder.create(loc); + }); + childBuilder.create(loc); + }); + builder.create(loc); + }); + } else if (bStaticTranspose) { + rewriter.create( + loc, isPad, + [&](OpBuilder &builder, Location loc) { + auto trans_buffer = rewriter.create( + loc, sourceType, transOut, zero, vSlicehape, vSrcStrides); + builder.create( + loc, trans_buffer, src, vOrderOffsets, vTransOrder, + defaultValue, tagDte, ValueRange{tagIdx}); + builder.create(loc, tagDte, ValueRange{tagIdx}, + totalSize); + doSlicePadOrMemsetSlice(builder, loc, op, reshapeOut, trans_buffer, + padOffsets, vIntSlicehape, padSizes, + defaultValue, tagDte, tagIdx); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, isThread0, [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, reshapeOut, src, vOrderOffsets, vTransOrder, + defaultValue, tagDte, ValueRange{tagIdx}); + childBuilder.create(loc); + }); + builder.create(loc); + }); + } else { + rewriter.create( + loc, isPad, + [&](OpBuilder &builder, Location loc) { + doSlicePadOrMemsetSlice(builder, loc, op, reshapeOut, src, + vOrderOffsets, vIntSlicehape, padSizes, + defaultValue, tagDte, tagIdx); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, isThread0, [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, reshapeOut, src, vOrderOffsets, defaultValue, tagDte, + ValueRange{tagIdx}); + childBuilder.create(loc); + }); + builder.create(loc); + }); + } + } else { + auto isNotZero = rewriter.create( + loc, arith::CmpIPredicate::ne, totalSize, zero); + auto isPad = rewriter.create(loc, isNeedPad, isNotZero); + if (bDynamicStride) { + auto isTrans = + rewriter.create(loc, isDynamicTrans, isNotZero); + auto isAll = rewriter.create(loc, isPad, isTrans); + rewriter.create( + loc, isAll, + [&](OpBuilder &builder, Location loc) { + auto trans_buffer = rewriter.create( + loc, sourceType, transOut, zero, vSlicehape, vSrcStrides); + builder.create( + loc, trans_buffer, src, vOrderOffsets, vTransOrder, + defaultValue, tagDte, ValueRange{tagIdx}); + builder.create(loc, tagDte, ValueRange{tagIdx}, + totalSize); + doSlicePadOrMemsetSlice(builder, loc, op, reshapeOut, trans_buffer, + padOffsets, vIntSlicehape, padSizes, + defaultValue, tagDte, tagIdx); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, isTrans, + [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, reshapeOut, src, vOrderOffsets, vTransOrder, + defaultValue, tagDte, ValueRange{tagIdx}); + childBuilder.create(loc); + }, + [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, isPad, + [&](OpBuilder &child2Builder, Location loc) { + doSlicePadOrMemsetSlice(child2Builder, loc, op, + reshapeOut, src, vOrderOffsets, + vIntSlicehape, padSizes, + defaultValue, tagDte, tagIdx); + child2Builder.create(loc); + }, + [&](OpBuilder &child2Builder, Location loc) { + child2Builder.create( + loc, isNotZero, + [&](OpBuilder &child3Builder, Location loc) { + child3Builder.create( + loc, reshapeOut, src, vOrderOffsets, + defaultValue, tagDte, ValueRange{tagIdx}); + child3Builder.create(loc); + }, + [&](OpBuilder &child3Builder, Location loc) { + if (getDefaultValue(op)) { + doMemsetConfig(child3Builder, loc, reshapeOut, + defaultValue, tagDte, tagIdx); + } + child3Builder.create(loc); + }); + child2Builder.create(loc); + }); + childBuilder.create(loc); + }); + builder.create(loc); + }); + } else if (bStaticTranspose) { + rewriter.create( + loc, isPad, + [&](OpBuilder &builder, Location loc) { + auto trans_buffer = rewriter.create( + loc, sourceType, transOut, zero, vSlicehape, vSrcStrides); + builder.create( + loc, trans_buffer, src, vOrderOffsets, vTransOrder, + defaultValue, tagDte, ValueRange{tagIdx}); + builder.create(loc, tagDte, ValueRange{tagIdx}, + totalSize); + doSlicePadOrMemsetSlice(builder, loc, op, reshapeOut, trans_buffer, + padOffsets, vIntSlicehape, padSizes, + defaultValue, tagDte, tagIdx); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, isNotZero, + [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, reshapeOut, src, vOrderOffsets, vTransOrder, + defaultValue, tagDte, ValueRange{tagIdx}); + childBuilder.create(loc); + }, + [&](OpBuilder &childBuilder, Location loc) { + if (getDefaultValue(op)) { + doMemsetConfig(childBuilder, loc, reshapeOut, defaultValue, + tagDte, tagIdx); + } + childBuilder.create(loc); + }); + builder.create(loc); + }); + } else { + rewriter.create( + loc, isPad, + [&](OpBuilder &builder, Location loc) { + doSlicePadOrMemsetSlice(builder, loc, op, reshapeOut, src, + vOrderOffsets, vIntSlicehape, padSizes, + defaultValue, tagDte, tagIdx); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, isNotZero, + [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, reshapeOut, src, vOrderOffsets, defaultValue, tagDte, + ValueRange{tagIdx}); + childBuilder.create(loc); + }, + [&](OpBuilder &childBuilder, Location loc) { + if (getDefaultValue(op)) { + doMemsetConfig(childBuilder, loc, reshapeOut, defaultValue, + tagDte, tagIdx); + } + childBuilder.create(loc); + }); + builder.create(loc); + }); + } + } + return totalSize; +} + +Value ConfigGcuStore(OpBuilder &rewriter, Location loc, Value storeValue, + Value transOut, mlir::Operation *op, + MemRefType storeValueType, Value storePtr, + mlir::ValueRange configStrides, + mlir::ValueRange configShapes, Value tagDte, + Value tagIdx) { + auto storeOp = dyn_cast(op); + assert(storeOp); + + auto storeType = storeOp.getValue().getType(); + auto elemType = storeOp.getPtr().getType().getElementType(); + auto buffer = rewriter.create( + loc, MemRefType::get(ArrayRef{ShapedType::kDynamic}, elemType), + storePtr); + + int64_t rank = storeValueType.getRank(); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto zero32 = + rewriter.create(loc, rewriter.getI32Type(), zero); + + bool bDynamicStride = false; + bool bStaticTranspose = false; + bool bReshape = true; + SmallVector updateStrideDims; + SmallVector nInitStrideDims; + auto hint = storeOp.getOrderHint(); + int64_t hint_size = static_cast(hint.size()); + assert(hint_size == rank || hint_size == 0); + SmallVector order_hint; + for (unsigned i = 0; i < rank; ++i) + if (hint_size == 0) + order_hint.push_back(-1); + else + order_hint.push_back(hint[i]); + + for (unsigned i = 0; i < rank; ++i) { + if (order_hint[i] == -1) { + bDynamicStride = true; + auto trueCondition = rewriter.create( + loc, arith::CmpIPredicate::ne, configStrides[i], zero); + rewriter.create( + loc, trueCondition, "Not Support dynamic stride is 0", "", "", 0); + } + } + + for (int i = 0; i < rank; ++i) { + if ((order_hint[i] == 0 && !bDynamicStride) || + (order_hint[i] == 1 && bDynamicStride)) { + bReshape = false; + break; + } + } + + for (int i = 0; i < rank; ++i) { + if (bDynamicStride && order_hint[i] == 0) + updateStrideDims.push_back(i); + else + nInitStrideDims.push_back(i); + } + + SmallVector vSrcOffsets; + auto numElems = triton::gcu::getElemsPerThread(storeType); + SmallVector vNumElems; + for (unsigned i = 0; i < rank; ++i) + vNumElems.push_back( + rewriter.create(loc, numElems[i])); + auto warpIds = getWarpIds(rewriter, loc, storeType); + for (auto dim : nInitStrideDims) { + Value offset = + rewriter.create(loc, warpIds[dim], vNumElems[dim]); + vSrcOffsets.push_back(offset); + } + + SmallVector vSrcStrides; + SmallVector vSrcShapes; + for (auto dim : nInitStrideDims) { + vSrcStrides.push_back(configStrides[dim]); + vSrcShapes.push_back(configShapes[dim]); + } + + SmallVector vStoreShapes; + SmallVector storeShapes; + for (unsigned i = 0; i < rank; ++i) { + storeShapes.push_back(storeValueType.getShape()[i]); + vStoreShapes.push_back( + rewriter.create(loc, storeShapes[i])); + } + + Value reshapeStoreValue = storeValue; + if (bReshape) { + assert(rank < 4 && "not support stride is no 1 for rank >=4"); + vSrcOffsets.push_back(zero); + vSrcShapes.push_back(one); + vSrcStrides.push_back(one); + storeShapes.push_back(1); + vStoreShapes.push_back(one); + vNumElems.push_back(one); + if (bDynamicStride) { + order_hint.push_back(1); + nInitStrideDims.push_back(rank); + } else { + for (int i = 0; i < rank; ++i) + order_hint[i]--; + order_hint.push_back(rank); + } + rank += 1; + auto reshapeStoreType = MemRefType::get(storeShapes, elemType); + auto [reshapeStrides, reshapeOffset] = + reshapeStoreType.getStridesAndOffset(); + reshapeStoreValue = rewriter.create( + loc, reshapeStoreType, storeValue, reshapeOffset, storeShapes, + reshapeStrides); + } + + if (rank == 2 && bDynamicStride) { + if (order_hint[1] == 1) { + order_hint[0] = 0; + order_hint[1] = 1; + bDynamicStride = false; + } else if (order_hint[0] == 1) { + order_hint[0] = 1; + order_hint[1] = 0; + bDynamicStride = false; + } + } + + SmallVector vOrderStrides; + SmallVector vOrderShapes; + SmallVector vOrderOffsets; + SmallVector vTransOrder; + SmallVector vTempOrder; + if (bDynamicStride) { + GetOrderValueByStride(rewriter, loc, nInitStrideDims, vSrcStrides, + vSrcShapes, vSrcOffsets, vOrderStrides, vOrderShapes, + vOrderOffsets, vTempOrder); + for (auto updateDim : updateStrideDims) { + auto updateStride = rewriter.create( + loc, vOrderStrides[updateDim], vOrderShapes[updateDim]); + vOrderStrides.insert(vOrderStrides.begin() + updateDim, updateStride); + vSrcStrides.insert(vSrcStrides.begin() + updateDim, updateStride); + vOrderShapes.insert(vOrderShapes.begin() + updateDim, one); + vSrcShapes.insert(vSrcShapes.begin() + updateDim, one); + vOrderOffsets.insert(vOrderOffsets.begin() + updateDim, + rewriter.create( + loc, rewriter.getI32Type(), zero)); + vSrcOffsets.insert(vSrcOffsets.begin() + updateDim, zero); + vTempOrder.insert(vTempOrder.begin() + updateDim, + rewriter.create( + loc, updateDim, rewriter.getI32Type())); + } + GetTransByOrder(rewriter, loc, vTempOrder, vTransOrder); + } else { + SmallVector static_order(order_hint.begin(), order_hint.end()); + for (int i = 0; i < rank; ++i) { + vOrderStrides.push_back(vSrcStrides[static_order[i]]); + vOrderOffsets.push_back(rewriter.create( + loc, rewriter.getI32Type(), vSrcOffsets[static_order[i]])); + vTransOrder.push_back(rewriter.create( + loc, static_order[i], rewriter.getI32Type())); + } + if (static_order.size() > 0) + vOrderShapes.push_back(vSrcShapes[static_order[0]]); + for (int i = 0; i < rank - 1; ++i) { + vOrderShapes.push_back(rewriter.create( + loc, vOrderStrides[i], vOrderStrides[i + 1])); + } + + for (int i = 0; i < rank; ++i) { + if (static_order[i] != i) { + bStaticTranspose = true; + break; + } + } + } + + SmallVector vSlicehape; + SmallVector vIntSlicehape; + Value totalSize = one; + for (unsigned i = 0; i < rank; ++i) { + auto shape = rewriter.create( + loc, vNumElems[i], + rewriter.create( + loc, zero, + rewriter.create(loc, vSrcShapes[i], + vSrcOffsets[i]))); + vSlicehape.push_back(shape); + vIntSlicehape.push_back( + rewriter.create(loc, rewriter.getI32Type(), shape)); + totalSize = rewriter.create(loc, totalSize, shape); + } + + SmallVector vOrderSlicehape; + if (bDynamicStride) { + GetOrderSlicefor30(rewriter, loc, rank, vSrcStrides, vSlicehape, + vOrderSlicehape); + } else { + for (int i = 0; i < rank; ++i) { + SmallVector static_order(order_hint.begin(), + order_hint.end()); + vOrderSlicehape.push_back(vSlicehape[static_order[i]]); + } + } + + SmallVector sliceOffsets(rank, zero32); + Value diff = zero; + for (unsigned i = 0; i < rank; ++i) { + auto dim_diff = + rewriter.create(loc, vStoreShapes[i], vSlicehape[i]); + diff = rewriter.create(loc, diff, dim_diff); + } + + auto resultType = MemRefType::get( + SmallVector(rank, ShapedType::kDynamic), elemType); + auto dst = rewriter.create( + loc, resultType, buffer, zero, vOrderShapes, vOrderStrides); + + auto isNeedSlice = rewriter.create( + loc, arith::CmpIPredicate::sgt, diff, zero); + + auto isNotZero = rewriter.create(loc, arith::CmpIPredicate::ne, + totalSize, zero); + + Value isDynamicTrans = rewriter.create( + loc, arith::CmpIPredicate::ne, vTransOrder[0], + rewriter.create(loc, rewriter.getI32Type(), zero)); + for (unsigned i = 1; i < rank; ++i) { + auto isDimTrans = rewriter.create( + loc, arith::CmpIPredicate::ne, vTransOrder[i], + rewriter.create(loc, i, rewriter.getI32Type())); + isDynamicTrans = + rewriter.create(loc, isDynamicTrans, isDimTrans); + } + + if (bDynamicStride) { + auto isTrans = + rewriter.create(loc, isDynamicTrans, isNotZero); + auto isSlice = rewriter.create(loc, isNeedSlice, isNotZero); + auto isAll = rewriter.create(loc, isSlice, isTrans); + rewriter.create( + loc, isAll, + [&](OpBuilder &builder, Location loc) { + auto trans_buffer = rewriter.create( + loc, resultType, transOut, zero, vOrderSlicehape, vOrderStrides); + builder.create( + loc, trans_buffer, reshapeStoreValue, sliceOffsets, vTransOrder, + triton::gcu::createConstantZero(rewriter, loc, elemType), tagDte, + ValueRange{tagIdx}); + builder.create(loc, tagDte, ValueRange{tagIdx}, + totalSize); + builder.create(loc, dst, trans_buffer, + vOrderOffsets, tagDte, + ValueRange{tagIdx}); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, isTrans, + [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, dst, reshapeStoreValue, vTransOrder, vOrderOffsets, + tagDte, ValueRange{tagIdx}); + childBuilder.create(loc); + }, + [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, isSlice, + [&](OpBuilder &child2Builder, Location loc) { + child2Builder.create( + loc, dst, reshapeStoreValue, sliceOffsets, + vIntSlicehape, vOrderOffsets, tagDte, + ValueRange{tagIdx}); + child2Builder.create(loc); + }, + [&](OpBuilder &child2Builder, Location loc) { + child2Builder.create( + loc, isNotZero, + [&](OpBuilder &child3Builder, Location loc) { + child3Builder.create( + loc, dst, reshapeStoreValue, vOrderOffsets, + tagDte, ValueRange{tagIdx}); + child3Builder.create(loc); + }); + child2Builder.create(loc); + }); + childBuilder.create(loc); + }); + builder.create(loc); + }); + } else if (bStaticTranspose) { + auto isSlice = rewriter.create(loc, isNeedSlice, isNotZero); + rewriter.create( + loc, isSlice, + [&](OpBuilder &builder, Location loc) { + auto trans_buffer = rewriter.create( + loc, resultType, transOut, zero, vOrderSlicehape, vOrderStrides); + builder.create( + loc, trans_buffer, reshapeStoreValue, sliceOffsets, vTransOrder, + triton::gcu::createConstantZero(rewriter, loc, elemType), tagDte, + ValueRange{tagIdx}); + builder.create(loc, tagDte, ValueRange{tagIdx}, + totalSize); + builder.create(loc, dst, trans_buffer, + vOrderOffsets, tagDte, + ValueRange{tagIdx}); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, isNotZero, [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, dst, reshapeStoreValue, vTransOrder, vOrderOffsets, + tagDte, ValueRange{tagIdx}); + childBuilder.create(loc); + }); + builder.create(loc); + }); + } else { + auto isSlice = rewriter.create(loc, isNeedSlice, isNotZero); + rewriter.create( + loc, isSlice, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, dst, reshapeStoreValue, sliceOffsets, vIntSlicehape, + vOrderOffsets, tagDte, ValueRange{tagIdx}); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, isNotZero, [&](OpBuilder &childBuilder, Location loc) { + childBuilder.create( + loc, dst, reshapeStoreValue, vOrderOffsets, tagDte, + ValueRange{tagIdx}); + childBuilder.create(loc); + }); + builder.create(loc); + }); + } + return totalSize; +} + +void WaitGcuLoadStore(OpBuilder &rewriter, Location loc, Value tagDte, + Value tagIdx, Value totalSize) { + rewriter.create(loc, tagDte, ValueRange{tagIdx}, + totalSize); +} + +void moveDeallocOp(ConversionPatternRewriter &rewriter, Value v, Operation *pos, + size_t depth) { + if (depth > 1) + return; + + Operation *allocOp = v.getDefiningOp(); + if (llvm::isa(allocOp)) { + // not define in current block; + return; + } + unsigned operandIdx = cast(v).getResultNumber(); + while (allocOp && !mlir::isa(allocOp)) { + mlir::TypeSwitch(allocOp) + .Case( + [&](auto castOp) { + allocOp = castOp.getSource().getDefiningOp(); + operandIdx = cast(castOp.getSource()).getResultNumber(); + }) + .Case([&](auto forOp) { + auto yieldOp = + llvm::cast(forOp.getBody()->getTerminator()); + Value operand = yieldOp.getOperands()[operandIdx]; + if (rewriter.getRemappedValue(operand)) { + operand = rewriter.getRemappedValue(operand); + } + allocOp = operand.getDefiningOp(); + operandIdx = cast(operand).getResultNumber(); + + Value initValue = forOp.getInitArgs()[operandIdx]; + if (rewriter.getRemappedValue(initValue)) { + initValue = rewriter.getRemappedValue(initValue); + } + moveDeallocOp(rewriter, initValue, pos, ++depth); + }) + .Case([&](auto ifOp) { + auto thenYieldOp = ifOp.thenYield(); + Value operand = thenYieldOp.getOperands()[operandIdx]; + if (rewriter.getRemappedValue(operand)) { + operand = rewriter.getRemappedValue(operand); + } + allocOp = operand.getDefiningOp(); + operandIdx = cast(operand).getResultNumber(); + + if (ifOp.getNumRegions() > 1) { + auto elseYieldOp = ifOp.elseYield(); + operand = elseYieldOp.getOperands()[operandIdx]; + if (rewriter.getRemappedValue(operand)) { + operand = rewriter.getRemappedValue(operand); + } + moveDeallocOp(rewriter, operand, pos, ++depth); + } + }) + .Default([&](auto op) { allocOp = nullptr; }); + } + if (!allocOp) + llvm_unreachable("can't find allocation position"); + + Operation *deallocOp = nullptr; + for (const auto &user : allocOp->getUsers()) { + if (llvm::isa(user)) { + deallocOp = user; + break; + } + } + if (deallocOp && deallocOp->getBlock() == pos->getBlock()) { + deallocOp->moveAfter(pos); + } +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritonGCUToGCUUtils.h b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritonGCUToGCUUtils.h new file mode 100644 index 000000000..92d045e16 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGCUToGCU/TritonGCUToGCUUtils.h @@ -0,0 +1,158 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KURAMA_TRITONGCU_TO_GCU_UTILS_H_ +#define KURAMA_TRITONGCU_TO_GCU_UTILS_H_ + +#include + +#include "ConstantUtil.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir { +namespace triton { +namespace gcu { +class FirstLastUserAnalysis; +} +} // namespace triton +} // namespace mlir +using namespace mlir; + +Value getPrivateDTETag(OpBuilder &builder, Operation *op); +Value getShareDTETag(OpBuilder &builder, Operation *op); +Value createPrivateDTETag(OpBuilder &builder, Operation *op); +DenseSet getSlicedAxies(Type type); +SmallVector getWarpIds(OpBuilder &builder, Location loc, Type type); +SmallVector getElemsPerThread(OpBuilder &builder, Location loc, + Type type); +func::FuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, + OpBuilder &rewriter, StringRef name, + FunctionType type); +void doMemFence(OpBuilder &rewriter, Operation *op); + +void doMemsetConfig(OpBuilder &rewriter, Location loc, Value output, Value v, + Value tagDte, Value tagIdx); +void doMemset(OpBuilder &rewriter, Operation *op, Value output, Value v, + unsigned totalNumElems); + +Value castToMemref1D(OpBuilder &rewriter, Location loc, Value v, + Value totalNumElems); +bool isMustAliasOp(mlir::Operation *op); + +mlir::Operation * +promoteLastUser(mlir::Operation *&lastUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin); + +void addDeallocAfterLastUser(OpBuilder &builder, mlir::Operation *lastUser, + Value alloc); +Value syncAllocOp(OpBuilder &builder, Location &loc, Operation *lastUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin, + MemRefType type); +Value asyncAllocOp(OpBuilder &builder, Operation *ttParent, MemRefType type); + +void createPrintfOp(ConversionPatternRewriter &rewriter, Location loc, + ::llvm::StringRef printOpPrefix, bool hex, Value value); + +void enterTritionOp(ConversionPatternRewriter &rewriter, Operation *ttParent); + +void leaveTritionOp(ConversionPatternRewriter &rewriter, Operation *ttParent); + +Value loadFromSharedMem(OpBuilder &builder, Value tag, Type type, Value buffer, + bool onlyThread0, Operation *lastTTUser, + Operation *firstTTUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin); +Value CopyFromSharedMem(OpBuilder &builder, Value tag, Type type, Value buffer, + bool onlyThread0, Operation *lastTTUser, + Operation *firstTTUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin); + +Value loadFromSharedMemForDotOperand( + OpBuilder builder, Value tag, Type type, ArrayRef mnShape, + Value sharedBuffer, Operation *lastTTUser, Operation *firstTTUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin); + +void storeToSharedMem(OpBuilder &builder, Value tag, TensorType type, + Value sharedBuffer, Value buffer, bool onlyThread0); +Value storeToSharedMem(OpBuilder &builder, Value tag, TensorType type, + Value buffer, bool onlyThread0, Operation *lastTTUser, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin); +void AnalysisYieldOperendUseStage( + Operation *module, triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map> + &TTYeiledOPerandHasMultiUseStage); + +void GetOrderValueByStride( + OpBuilder &rewriter, Location loc, SmallVector nInitStrideDims, + SmallVector &initStride, SmallVector &initShape, + SmallVector &initOffset, SmallVector &orderStride, + SmallVector &orderShape, SmallVector &orderOffset, + SmallVector &vOrder); + +void GetOrderSlicefor30(OpBuilder &rewriter, Location loc, int64_t rank, + SmallVector &initStride, + SmallVector &initSliceShape, + SmallVector &orderSliceShape); + +Value ConfigGcuLoad(OpBuilder &rewriter, Location loc, Value srcOut, + Value transOut, mlir::Operation *op, MemRefType resultType, + Value loadPtr, mlir::ValueRange configStrides, + mlir::ValueRange configShapes, Value defaultValue, + Value tagDte, Value tagIdx, bool IsShareOutput = false); + +Value ConfigGcuStore(OpBuilder &rewriter, Location loc, Value storeValue, + Value transOut, mlir::Operation *op, + MemRefType storeValueType, Value storePtr, + mlir::ValueRange configStrides, + mlir::ValueRange configShapes, Value tagDte, Value tagIdx); + +void WaitGcuLoadStore(OpBuilder &rewriter, Location loc, Value tagDte, + Value tagIdx, Value totalSize); + +void moveDeallocOp(ConversionPatternRewriter &rewriter, Value v, Operation *pos, + size_t depth); + +void mergeContinuousDims(OpBuilder &subBuilder, Location loc, + Value &sharedMemref, Value &warpMemref, + SmallVector &offsets, + SmallVector &mergedOffsets, + MemRefType &sharedMemType, MemRefType &warpMemType, + Value &sharedBuffer, Value &warpOutput); +#endif diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGPUtoTrtionGCU.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGPUtoTrtionGCU.cpp new file mode 100644 index 000000000..9fdacf669 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonGPUtoTrtionGCU.cpp @@ -0,0 +1,87 @@ +/* + * Copyright 2020 - 2022 Enflame.All Rights Reserved. + * + */ + +#include "Conversion/TritonToGCU/TritonToGCUPass.h" + +#include "Utils.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +namespace mlir { +#define GEN_PASS_DEF_TRITONGPUTOTRITONGCUPASS +#include "Conversion/Passes.h.inc" +} // namespace mlir +#define DEBUG_TYPE "triton-gpu-to-triton-gcu" +namespace { +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +struct TritonGPUToTritonGCUPass + : public mlir::impl::TritonGPUToTritonGCUPassBase< + TritonGPUToTritonGCUPass> { + using Base::Base; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void rewriteLoalLoad(); + void runOnOperation() override; +}; + +void TritonGPUToTritonGCUPass::rewriteLoalLoad() { + auto trionModule = getOperation(); + llvm::SmallVector localAllocList; + trionModule.walk([&](triton::gpu::LocalAllocOp alloc) { + localAllocList.push_back(alloc); + }); + for (auto &alloc : localAllocList) { + for (auto user : alloc->getUsers()) { + if (llvm::isa_and_nonnull(user)) { + OpBuilder rewriter(user); + auto localLoad = cast(user); + auto convert = rewriter.create( + user->getLoc(), localLoad.getType(), alloc.getSrc()); + localLoad.getResult().replaceAllUsesWith(convert.getResult()); + localLoad.erase(); + } else if (llvm::isa_and_nonnull(user)) { + user->erase(); + } else { + user->dump(); + trionModule.dump(); + llvm::report_fatal_error("please check IR can't rewrite"); + } + alloc.erase(); + } + } +} + +} // namespace +using namespace mlir; +void TritonGPUToTritonGCUPass::runOnOperation() { + LLVM_DEBUG(llvm::dbgs() << "TritonGPUToTritonGCUPass\n"); + rewriteLoalLoad(); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonLoadStoreToDma.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonLoadStoreToDma.cpp new file mode 100644 index 000000000..d73dc1de2 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonLoadStoreToDma.cpp @@ -0,0 +1,643 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "Conversion/TritonToGCU/TritonToGCUPass.h" + +#include "ConstantUtil.h" + +#include "Analysis/MaskAnalysis.h" +#include "Analysis/OpFoldResultUtils.h" +#include "Analysis/PtrAnalysis.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +// todo_AT gpu moduleop +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define DEBUG_TYPE "triton-loadstore-to-dma" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTTRITONLOADSTORETOGCUDMAPASS +#include "Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +struct ConvertTritonLoadStoreToDmaPass + : public mlir::impl::ConvertTritonLoadStoreToGCUDmaPassBase< + ConvertTritonLoadStoreToDmaPass> { + using Base::Base; + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } +}; + +struct PreprocessForOp : public OpRewritePattern { + llvm::SmallDenseMap &knownPtrs; + llvm::SmallDenseMap &knownMasks; + llvm::SmallVector &candidateOps; + llvm::SmallDenseMap> &candidateHints; + + explicit PreprocessForOp( + MLIRContext *context, + llvm::SmallDenseMap &knownPtrs, + llvm::SmallDenseMap &knownMasks, + llvm::SmallVector &candidateOps, + llvm::SmallDenseMap> &candidateHints) + : OpRewritePattern(context), knownPtrs(knownPtrs), + knownMasks(knownMasks), candidateOps(candidateOps), + candidateHints(candidateHints) {} + + LogicalResult matchAndRewrite(scf::ForOp op, + PatternRewriter &rewriter) const override { + if (gcu::PtrAnalysis::byPassForOp(rewriter, op, candidateOps)) + return failure(); + return gcu::PtrAnalysis::rewriteForOp(rewriter, op, knownPtrs, knownMasks, + candidateOps, candidateHints); + } +}; + +struct PostprocessForOp : public OpRewritePattern { + llvm::SmallDenseMap &knownPtrs; + + explicit PostprocessForOp( + MLIRContext *context, + llvm::SmallDenseMap &knownPtrs) + : OpRewritePattern(context), knownPtrs(knownPtrs) {} + + LogicalResult matchAndRewrite(scf::ForOp op, + PatternRewriter &rewriter) const override { + gcu::PtrAnalysis::foldAwayForOp(rewriter, op, knownPtrs); + return failure(); + } +}; + +bool IsStaticStride(SmallVector &candidateHints) { + bool bStaticStride = true; + int32_t rank = candidateHints.size(); + for (int32_t i = 0; i < rank; ++i) { + if (candidateHints[i] == -1) { + bStaticStride = false; + break; + } + } + + return bStaticStride; +} + +bool IsStaticReshape(SmallVector &candidateHints) { + bool isReshape = true; + int32_t rank = candidateHints.size(); + if (IsStaticStride(candidateHints)) { + for (int32_t i = 0; i < rank; ++i) { + if (candidateHints[i] == 1 || candidateHints[i] == 0) { + isReshape = false; + break; + } + } + } else { + isReshape = false; + } + + return isReshape; +} + +SmallVector GetOrderByHint(SmallVector &candidateHints) { + SmallVector orderHint; + int32_t rank = candidateHints.size(); + assert(IsStaticStride(candidateHints) && + "dynamic stride not support get static order"); + + SmallVector broadcastDims; + for (int32_t i = 0; i < rank; ++i) + if (candidateHints[i] == 0) + broadcastDims.push_back(i); + + for (int32_t i = 0; i < rank; ++i) + if (candidateHints[i] != 0) + orderHint.push_back(i); + + std::sort(orderHint.begin(), orderHint.end(), [&](int32_t a, int32_t b) { + return (candidateHints[a] > candidateHints[b]); + }); + + if (orderHint.size() < static_cast(rank)) + for (auto dim : broadcastDims) + orderHint.insert(orderHint.begin() + dim, dim); + + SmallVector transOrder(rank, 0); + for (int32_t i = 0; i < rank; ++i) + transOrder[orderHint[i]] = i; + + for (int32_t i = 0; i < rank; ++i) + LLVM_DEBUG(llvm::dbgs() << "dim: " << i << "\n" + << "order: " << orderHint[i] << "\n"); + + for (int32_t i = 0; i < rank; ++i) + LLVM_DEBUG(llvm::dbgs() << "trans order: " << i << "\n" + << "order: " << transOrder[i] << "\n"); + + return transOrder; +} + +struct ConvertLoadOpToDma : public OpRewritePattern { + llvm::SmallDenseMap &knownPtrs; + llvm::SmallDenseMap &knownMasks; + llvm::SmallVector &candidateOps; + llvm::SmallDenseMap> &candidateHints; + + explicit ConvertLoadOpToDma( + MLIRContext *context, + llvm::SmallDenseMap &knownPtrs, + llvm::SmallDenseMap &knownMasks, + llvm::SmallVector &candidateOps, + llvm::SmallDenseMap> &candidateHints) + : OpRewritePattern(context), knownPtrs(knownPtrs), + knownMasks(knownMasks), candidateOps(candidateOps), + candidateHints(candidateHints) {} + LogicalResult rewriteTensorLoad(triton::LoadOp op, + PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + // 2. Analyze the mask operand to determine at runtime the size of the data + // we are moving. + gcu::MaskState mstate; + if (op.getMask()) { + LLVM_DEBUG(llvm::dbgs() << "=== analyze load mask state ===\n"); + gcu::MaskAnalysis::parse(rewriter, loc, op.getMask(), mstate, knownMasks); + assert(!mstate.isEmpty() && + "expect valid mask state after analysis succeed\n"); + } + + // 3. Get ptr info + LLVM_DEBUG(llvm::dbgs() << "=== analyze load ptr state ===\n"); + gcu::PtrState pstate; + gcu::PtrAnalysis::visitOperand(rewriter, loc, op.getPtr(), pstate, + knownPtrs); + + // 4. Analyze the other operand to get a scalar value + Value defaultValue; + auto tType = dyn_cast(op.getType()); + if (op.getOther()) { + auto scalarValue = gcu::getScalarValue(rewriter, loc, op.getOther()); + assert(scalarValue.has_value() && + "other value used in masked load produced by " + "unsupported instruction"); + defaultValue = scalarValue.value(); + } else { + defaultValue = + gcu::createConstantZero(rewriter, loc, tType.getElementType()); + } + + auto rank = pstate.getRank(); + auto one = rewriter.create(loc, 1); + auto ptrInfo = pstate.getPtrInfo(rewriter, loc, mstate); + bool bNeedBroadCast = false; + auto resultShape = tType.getShape(); + auto elemType = tType.getElementType(); + for (auto dim : ptrInfo.broadcastDims) { + if (resultShape[dim] > 1) { + bNeedBroadCast = true; + break; + } + } + + assert(candidateHints.find(op.getOperation()) != candidateHints.end() && + "get order failed"); + auto opHint = candidateHints[op.getOperation()]; + // dynamic stride process + if (!IsStaticStride(opHint)) { + LLVM_DEBUG(llvm::dbgs() << "=== dynamic stride process ===\n"); + SmallVector updateStrides; + SmallVector updateShapes; + SmallVector dynamicOpHint; + for (int32_t i = 0; i < rank; ++i) { + if (!ptrInfo.broadcastDims.count(i)) { + updateStrides.push_back(ptrInfo.strides[i]); + updateShapes.push_back(ptrInfo.shape[i]); + dynamicOpHint.push_back(opHint[i]); + } else { + updateShapes.push_back(one); + if (i == rank - 1) { + dynamicOpHint.push_back(1); + updateStrides.push_back(one); + } else { + dynamicOpHint.push_back(opHint[i]); + updateStrides.push_back(ptrInfo.strides[i]); + } + } + } + + if (bNeedBroadCast) { + SmallVector sliceShape(rank); + for (unsigned int i = 0; i < rank; i++) { + if (ptrInfo.broadcastDims.count(i)) + sliceShape[i] = 1; + else + sliceShape[i] = resultShape[i]; + } + auto sliceType = + RankedTensorType::get(sliceShape, elemType, tType.getEncoding()); + auto load = rewriter.create( + loc, sliceType, ptrInfo.base, updateShapes, updateStrides, + ptrInfo.offsets, defaultValue, dynamicOpHint); + auto broadcastOp = + rewriter.create(loc, op.getType(), load); + rewriter.replaceOp(op, broadcastOp); + return success(); + } else { + auto load = rewriter.create( + loc, tType, ptrInfo.base, updateShapes, updateStrides, + ptrInfo.offsets, defaultValue, dynamicOpHint); + rewriter.replaceOp(op, load); + return success(); + } + } else { // static stride process will be delete for pingpong support + // dynamic + LLVM_DEBUG(llvm::dbgs() << "=== static stride process ===\n"); + auto staticOrder = GetOrderByHint(opHint); + // if (IsStaticReshape(opHint)) { + // for (int i = 0; i < rank; ++i) + // staticOrder[i]++; + // } + + assert(static_cast(staticOrder.size()) == rank && + "the order size and rank mismatch \n"); + bool bNeedTranspose = false; + for (uint32_t i = 0; i < rank; ++i) { + if (staticOrder[i] != static_cast(i)) { + bNeedTranspose = true; + break; + } + } + + SmallVector staticDefaultOrder; + for (int i = 0; i < rank; ++i) + staticDefaultOrder.push_back(i); + + if (IsStaticReshape(opHint)) { + for (int i = 0; i < rank; ++i) + staticDefaultOrder[i]++; + } + + SmallVector orderStrides(rank); + SmallVector orderShapes(rank); + for (int i = 0; i < rank; ++i) { + if (!bNeedTranspose) { + orderStrides[i] = ptrInfo.strides[i]; + orderShapes[i] = ptrInfo.shape[i]; + } else { + orderStrides[staticOrder[i]] = ptrInfo.strides[i]; + orderShapes[staticOrder[i]] = ptrInfo.shape[i]; + } + } + + // update broadcast dim stride + for (int i = rank - 1; i >= 0; --i) { + if (ptrInfo.broadcastDims.count(i)) { + if (i == rank - 1) + orderStrides[i] = one; + else + orderStrides[i] = rewriter.create( + loc, orderShapes[i + 1], orderStrides[i + 1]); + } + } + + SmallVector updateResultShape(rank); + for (int i = 0; i < rank; ++i) { + if (ptrInfo.broadcastDims.count(i)) + updateResultShape[i] = 1; + else + updateResultShape[i] = resultShape[i]; + } + + if (bNeedTranspose) { + SmallVector sliceResultShape(rank); + for (unsigned int i = 0; i < rank; i++) + sliceResultShape[staticOrder[i]] = updateResultShape[i]; + + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(tType.getEncoding()); + auto order = triton::gpu::getOrder(tType); + auto ctaLayout = triton::gpu::getCTALayout(tType.getEncoding()); + + SmallVector sliceWarpsPerCTA(rank); + SmallVector sliceOrder(rank); + for (unsigned int i = 0; i < rank; ++i) { + sliceWarpsPerCTA[staticOrder[i]] = warpsPerCTA[i]; + sliceOrder[order[i]] = staticOrder[i]; + } + + auto sliceOutType = RankedTensorType::get( + sliceResultShape, elemType, + triton::gpu::BlockedEncodingAttr::get( + getContext(), SmallVector(rank, 1), + SmallVector(rank, 1), sliceWarpsPerCTA, sliceOrder, + ctaLayout)); + + auto load = rewriter.create( + loc, sliceOutType, ptrInfo.base, orderShapes, orderStrides, + ptrInfo.offsets, defaultValue, staticDefaultOrder); + if (bNeedBroadCast) { + auto transType = RankedTensorType::get(updateResultShape, elemType, + tType.getEncoding()); + auto transpose = rewriter.create(loc, transType, + load, staticOrder); + auto broadcastOp = rewriter.create( + loc, op.getType(), transpose); + rewriter.replaceOp(op, broadcastOp); + return success(); + } else { + auto transpose = rewriter.create(loc, op.getType(), + load, staticOrder); + rewriter.replaceOp(op, transpose); + return success(); + } + } else if (bNeedBroadCast) { + auto loadType = RankedTensorType::get(updateResultShape, elemType, + tType.getEncoding()); + auto load = rewriter.create( + loc, loadType, ptrInfo.base, orderShapes, orderStrides, + ptrInfo.offsets, defaultValue, staticDefaultOrder); + auto broadcastOp = + rewriter.create(loc, op.getType(), load); + rewriter.replaceOp(op, broadcastOp); + return success(); + } else { + auto load = rewriter.create( + loc, tType, ptrInfo.base, orderShapes, orderStrides, + ptrInfo.offsets, defaultValue, staticDefaultOrder); + rewriter.replaceOp(op, load); + return success(); + } + } + } + + LogicalResult matchAndRewrite(triton::LoadOp op, + PatternRewriter &rewriter) const override { + // 1. Analyze the ptr operand to check whether it is continuous. + LLVM_DEBUG(llvm::dbgs() << "=== check load ptr contiguous ===\n"); + if (std::find(candidateOps.begin(), candidateOps.end(), + op.getOperation()) == candidateOps.end()) { + return failure(); + } + return rewriteTensorLoad(op, rewriter); + } +}; + +struct ConvertStoreOpToDma : public OpRewritePattern { + llvm::SmallDenseMap &knownPtrs; + llvm::SmallDenseMap &knownMasks; + llvm::SmallVector &candidateOps; + llvm::SmallDenseMap> &candidateHints; + + explicit ConvertStoreOpToDma( + MLIRContext *context, + llvm::SmallDenseMap &knownPtrs, + llvm::SmallDenseMap &knownMasks, + llvm::SmallVector &candidateOps, + llvm::SmallDenseMap> &candidateHints) + : OpRewritePattern(context), knownPtrs(knownPtrs), + knownMasks(knownMasks), candidateOps(candidateOps), + candidateHints(candidateHints) {} + LogicalResult rewriteTensorStore(triton::StoreOp op, + PatternRewriter &rewriter) const { + // 2. Analyze the mask operand to determine at runtime the size of the data + // we are moving. + auto loc = op.getLoc(); + LLVM_DEBUG(llvm::dbgs() << "=== analyze store mask state ===\n"); + gcu::MaskState mstate; + if (op.getMask()) { + gcu::MaskAnalysis::parse(rewriter, loc, op.getMask(), mstate, knownMasks); + assert(!mstate.isEmpty() && + "expect valid mask state after analysis succeed\n"); + } + + // 3. Get ptr info + LLVM_DEBUG(llvm::dbgs() << "=== analyze store ptr state ===\n"); + gcu::PtrState pstate; + gcu::PtrAnalysis::visitOperand(rewriter, loc, op.getPtr(), pstate, + knownPtrs); + + auto tType = dyn_cast(op.getValue().getType()); + assert(tType && "the store value type is null\n"); + auto rank = pstate.getRank(); + auto one = rewriter.create(loc, 1); + auto ptrInfo = pstate.getPtrInfo(rewriter, loc, mstate); + auto storeShape = tType.getShape(); + auto elemType = tType.getElementType(); + + SmallVector vSliceShape; + for (unsigned int i = 0; i < rank; i++) { + if (ptrInfo.broadcastDims.count(i)) + vSliceShape.push_back(one); + else + vSliceShape.push_back(ptrInfo.shape[i]); + } + assert(candidateHints.find(op.getOperation()) != candidateHints.end() && + "get order failed"); + auto opHint = candidateHints[op.getOperation()]; + // dynamic stride process + if (!IsStaticStride(opHint)) { + // update broadcast dim stride(move to tritontogcu pass) + SmallVector updateStrides; + SmallVector dynamicOpHint; + for (int32_t i = 0; i < rank; ++i) { + if (!ptrInfo.broadcastDims.count(i)) { + updateStrides.push_back(ptrInfo.strides[i]); + dynamicOpHint.push_back(opHint[i]); + } else { + if (i == rank - 1) { + dynamicOpHint.push_back(1); + updateStrides.push_back(one); + } else { + dynamicOpHint.push_back(opHint[i]); + updateStrides.push_back(ptrInfo.strides[i]); + } + } + } + + auto store = rewriter.create( + loc, op.getValue(), ptrInfo.base, vSliceShape, updateStrides, + ptrInfo.offsets, dynamicOpHint); + rewriter.replaceOp(op, store); + return success(); + } else { // static stride process will be delete for pingpong support + // dynamic + auto staticOrder = GetOrderByHint(opHint); + // if (IsStaticReshape(opHint)) { + // for (int i = 0; i < rank; ++i) + // staticOrder[i]++; + // } + assert(static_cast(staticOrder.size()) == rank && + "the order size and rank mismatch \n"); + bool bNeedTranspose = false; + for (uint32_t i = 0; i < rank; ++i) { + if (staticOrder[i] != static_cast(i)) { + bNeedTranspose = true; + break; + } + } + + SmallVector staticDefaultOrder; + for (int i = 0; i < rank; ++i) + staticDefaultOrder.push_back(i); + + if (IsStaticReshape(opHint)) { + for (int i = 0; i < rank; ++i) + staticDefaultOrder[i]++; + } + + SmallVector orderStrides(rank); + SmallVector orderShapes(rank); + for (int i = 0; i < rank; ++i) { + if (!bNeedTranspose) { + orderStrides[i] = ptrInfo.strides[i]; + orderShapes[i] = vSliceShape[i]; + } else { + orderStrides[staticOrder[i]] = ptrInfo.strides[i]; + orderShapes[staticOrder[i]] = vSliceShape[i]; + } + } + + // update broadcast dim stride + for (int i = rank - 1; i >= 0; --i) { + if (ptrInfo.broadcastDims.count(i)) { + if (i == rank - 1) + orderStrides[i] = one; + else + orderStrides[i] = rewriter.create( + loc, orderShapes[i + 1], orderStrides[i + 1]); + } + } + + if (bNeedTranspose) { + SmallVector transShapes(rank); + for (unsigned int i = 0; i < rank; i++) + transShapes[staticOrder[i]] = storeShape[i]; + + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(tType.getEncoding()); + auto order = triton::gpu::getOrder(tType); + auto ctaLayout = triton::gpu::getCTALayout(tType.getEncoding()); + + SmallVector transWarpsPerCTA(rank); + SmallVector transOrder(rank); + for (unsigned int i = 0; i < rank; ++i) { + transWarpsPerCTA[staticOrder[i]] = warpsPerCTA[i]; + transOrder[order[i]] = staticOrder[i]; + } + + auto transType = RankedTensorType::get( + transShapes, elemType, + triton::gpu::BlockedEncodingAttr::get( + getContext(), SmallVector(rank, 1), + SmallVector(rank, 1), transWarpsPerCTA, transOrder, + ctaLayout)); + + auto transpose = rewriter.create( + loc, transType, op.getValue(), staticOrder); + auto store = rewriter.create( + loc, transpose, ptrInfo.base, orderShapes, orderStrides, + ptrInfo.offsets, staticDefaultOrder); + rewriter.replaceOp(op, store); + return success(); + } else { + auto store = rewriter.create( + loc, op.getValue(), ptrInfo.base, orderShapes, orderStrides, + ptrInfo.offsets, staticDefaultOrder); + rewriter.replaceOp(op, store); + return success(); + } + } + } + + LogicalResult matchAndRewrite(triton::StoreOp op, + PatternRewriter &rewriter) const override { + // 1. Analyze the ptr operand to check whether it is continuous. + LLVM_DEBUG(llvm::dbgs() << "=== check store ptr contiguous ===\n"); + if (std::find(candidateOps.begin(), candidateOps.end(), + op.getOperation()) == candidateOps.end()) { + return failure(); + } + return rewriteTensorStore(op, rewriter); + } +}; + +} // namespace + +void ConvertTritonLoadStoreToDmaPass::runOnOperation() { + LLVM_DEBUG(llvm::dbgs() << "ConvertTritonLoadStoreToDmaPass\n"); + + auto *ctx = &getContext(); + auto op = getOperation(); + + // 1. collect load/store ops + auto moduleOp = op->getParentOfType(); + llvm::SmallVector candidateOps; + llvm::SmallDenseMap> candidateHints; + gcu::PtrAnalysis::collectCandidateLoadStoreOps(moduleOp, candidateOps, + candidateHints); + + // 2. Pre-process some ops + GreedyRewriteConfig rewriteConfig; + rewriteConfig.strictMode = GreedyRewriteStrictness::ExistingOps; + + llvm::SmallDenseMap knowMasks; + llvm::SmallDenseMap knownPtrs; + RewritePatternSet prePatterns(ctx); + prePatterns.add(ctx, knownPtrs, knowMasks, candidateOps, + candidateHints); + + if (applyPatternsGreedily(op, std::move(prePatterns), rewriteConfig).failed()) + signalPassFailure(); + + // 3. Start to process load/store op + RewritePatternSet patterns(ctx); + patterns.add( + ctx, knownPtrs, knowMasks, candidateOps, candidateHints); + if (applyPatternsGreedily(op, std::move(patterns)).failed()) + signalPassFailure(); + + // 4. Post-process some ops + RewritePatternSet postPatterns(ctx); + postPatterns.add(ctx, knownPtrs); + if (applyPatternsGreedily(op, std::move(postPatterns), rewriteConfig) + .failed()) + signalPassFailure(); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonToGCU.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonToGCU.cpp new file mode 100644 index 000000000..72a51ade6 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/TritonToGCU.cpp @@ -0,0 +1,3292 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include + +#include "Analysis/FirstLastUserAnalysis.h" +#include "Conversion/TritonToGCU/TritonToGCUPass.h" + +#include "PatternTritonGPUOpToGCU.h" +#include "Utils.h" + +#include "ConstantUtil.h" +#include "Dialect/GCU/IR/Dialect.h" +#include "Dialect/GCU/IR/Types.h" +#include "Dialect/MathExt/IR/MathExt.h" +#include "Dialect/MathExt/IR/MathExtTypes.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" +#include "TritonGCUToGCU/TritionToGCUBase.h" +#include "TritonGCUToGCU/TritonGCUAsyncOpToGCU.h" +#include "TritonGCUToGCU/TritonGCUToGCUUtils.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +namespace mlir { +#define GEN_PASS_DEF_CONVERTTRITONTOGCUPASS +#include "Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +#define DEBUG_TYPE "triton-ir-to-gcu-ir" +namespace { +struct ConvertTritonToGCUPass + : public mlir::impl::ConvertTritonToGCUPassBase { + using Base::Base; + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } +}; + +} // namespace +namespace { +struct TTFuncOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::FuncOp ttFuncOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = ttFuncOp.getLoc(); + // Remap proper input types. + TypeConverter::SignatureConversion signatureConversion( + ttFuncOp.front().getNumArguments()); + + // Convert argument types one by one and check for errors. + for (auto [idx, type] : + llvm::enumerate(ttFuncOp.getFunctionType().getInputs())) { + SmallVector converted; + converted.push_back(getTypeConverter()->convertType(type)); + signatureConversion.addInputs(idx, converted); + } + SmallVector resultTypes; + for (auto type : ttFuncOp.getFunctionType().getResults()) { + resultTypes.push_back(getTypeConverter()->convertType(type)); + } + + auto funcType = FunctionType::get( + getContext(), signatureConversion.getConvertedTypes(), resultTypes); + auto funcName = ttFuncOp.isPublic() + ? (ttFuncOp.getName() + "_triton_internal__").str() + : ttFuncOp.getName().str(); + auto func = rewriter.create(loc, funcName, funcType); + func.getBody().getBlocks().clear(); + func.setPrivate(); + auto internalLinkage = mlir::LLVM::linkage::Linkage::Internal; + auto linkage = mlir::LLVM::LinkageAttr::get(getContext(), internalLinkage); + func->setAttr("llvm.linkage", linkage); + // Move the region to the new function, update the entry block signature. + rewriter.inlineRegionBefore(ttFuncOp.getBody(), func.getBody(), func.end()); + if (failed(rewriter.convertRegionTypes(&func.getBody(), *getTypeConverter(), + &signatureConversion))) + return failure(); + + if (ttFuncOp.isPublic()) { + auto gpufunc = + rewriter.create(loc, ttFuncOp.getName(), funcType); + gpufunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), + rewriter.getUnitAttr()); + OpBuilder::InsertionGuard guard(rewriter); + auto entryBlock = &gpufunc.getBody().getBlocks().back(); + rewriter.setInsertionPointToStart(entryBlock); + + auto call = + rewriter.create(loc, func, entryBlock->getArguments()); + rewriter.create(loc, call->getResults()); + } + + rewriter.eraseOp(ttFuncOp); + return success(); + } +}; + +struct TTReturnOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp returnOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (returnOp->getParentOfType()) { + rewriter.replaceOpWithNewOp(returnOp, + returnOp.getOperands()); + } else { + rewriter.replaceOpWithNewOp(returnOp, + returnOp.getOperands()); + } + return success(); + } +}; + +struct TTCallOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::CallOp callOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector resultTypes; + for (auto ty : callOp->getResultTypes()) { + resultTypes.push_back(getTypeConverter()->convertType(ty)); + } + rewriter.replaceOpWithNewOp( + callOp, callOp.getCallee(), resultTypes, adaptor.getOperands()); + return success(); + } +}; + +struct TTSCFForOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + // Remap proper input types. + TypeConverter::SignatureConversion signatureConversion( + op.getBody()->getNumArguments()); + + // Convert argument types one by one and check for errors. + for (auto [idx, type] : llvm::enumerate(op.getBody()->getArgumentTypes())) { + SmallVector converted; + converted.push_back(getTypeConverter()->convertType(type)); + signatureConversion.addInputs(idx, converted); + } + SmallVector resultTypes; + for (auto type : op.getResultTypes()) { + resultTypes.push_back(getTypeConverter()->convertType(type)); + } + + auto forOp = rewriter.create( + loc, adaptor.getLowerBound(), adaptor.getUpperBound(), + adaptor.getStep(), adaptor.getInitArgs()); + forOp.getRegion().getBlocks().clear(); + + rewriter.inlineRegionBefore(op.getRegion(), forOp.getRegion(), + forOp.getRegion().end()); + if (failed(rewriter.convertRegionTypes( + &forOp.getRegion(), *getTypeConverter(), &signatureConversion))) + return failure(); + + replaced2Origin[forOp.getOperation()] = op.getOperation(); + + rewriter.replaceOp(op, forOp); + return success(); + } +}; + +struct TTSCFIfOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + SmallVector resultTypes; + for (auto type : op.getResultTypes()) { + resultTypes.push_back(getTypeConverter()->convertType(type)); + } + + bool hasElse = op.getNumRegions() > 1; + + auto ifOp = rewriter.create(loc, resultTypes, + adaptor.getCondition(), hasElse); + + ifOp.getThenRegion().getBlocks().clear(); + if (hasElse) + ifOp.getElseRegion().getBlocks().clear(); + + rewriter.inlineRegionBefore(op.getThenRegion(), ifOp.getThenRegion(), + ifOp.getThenRegion().end()); + if (hasElse) + rewriter.inlineRegionBefore(op.getElseRegion(), ifOp.getElseRegion(), + ifOp.getElseRegion().end()); + + replaced2Origin[ifOp.getOperation()] = op.getOperation(); + + rewriter.replaceOp(op, ifOp); + return success(); + } +}; + +struct TTSCFYieldOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + std::map> + &TTYeiledOPerandHasMultiUseStage; + + TTSCFYieldOpLowering( + const TypeConverter &converter, MLIRContext *ctx, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin, + std::map> &operendStage) + : SharedConversionPattern(converter, ctx, userAnalysis, replaced2Origin), + TTYeiledOPerandHasMultiUseStage(operendStage) {} + + LogicalResult + matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + if (isa(op.getOperation()->getParentOp())) { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } + auto loc = op.getLoc(); + SmallVector updatedOperands; + for (uint64_t i = 0; i < adaptor.getOperands().size(); ++i) { + auto operand = adaptor.getOperands()[i]; + if (isa(operand.getType())) { + auto definingOp = operand.getDefiningOp(); + auto parent = op.getOperation()->getParentOp(); + bool isMultiUse = TTYeiledOPerandHasMultiUseStage[op.getOperation()][i]; + if (!isMultiUse) { + updatedOperands.push_back(operand); + + auto originParent = replaced2Origin[parent]; + auto lastUser = userAnalysis.getLastUserOp(originParent); + auto newAllocOpPos = + promoteLastUser(lastUser, userAnalysis, replaced2Origin); + + Operation *allocOp = definingOp; + while (allocOp && !mlir::isa(allocOp)) { + mlir::TypeSwitch(allocOp) + .Case( + [&](auto castOp) { + allocOp = castOp.getSource().getDefiningOp(); + }) + .Default([&](auto op) { allocOp = nullptr; }); + } + if (!allocOp) { + LLVM_DEBUG({ + llvm::dbgs() << "can't find allocOp in the same region\n"; + allocOp->dump(); + }); + continue; + } + if (newAllocOpPos == nullptr) { + allocOp->moveBefore(parent); + } else { + allocOp->moveBefore(newAllocOpPos); + } + addDeallocAfterLastUser(rewriter, lastUser, allocOp->getResult(0)); + + continue; + } + + auto tag = getPrivateDTETag(rewriter, op); + auto zero = rewriter.create(loc, 0); + auto shape = dyn_cast(operand.getType()).getShape(); + auto size = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); + if (isa(parent)) { + if (replaced2Origin.count(parent) == 0) { + llvm_unreachable("can't find the origin op"); + } + auto originParent = replaced2Origin[parent]; + auto lastUser = userAnalysis.getLastUserOp(originParent); + + auto newAllocOpPos = + promoteLastUser(lastUser, userAnalysis, replaced2Origin); + + Value nextLoopTensor; + auto ip = rewriter.saveInsertionPoint(); + if (newAllocOpPos == nullptr) { + rewriter.setInsertionPoint(parent); + nextLoopTensor = rewriter.create( + loc, dyn_cast(operand.getType())); + } else { + rewriter.setInsertionPoint(newAllocOpPos); + nextLoopTensor = rewriter.create( + loc, dyn_cast(operand.getType())); + } + rewriter.restoreInsertionPoint(ip); + + addDeallocAfterLastUser(rewriter, lastUser, nextLoopTensor); + + rewriter.create( + loc, operand, SmallVector(shape.size(), zero), + nextLoopTensor, SmallVector(shape.size(), zero), + rewriter.create(loc, size), tag, + ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, size)); + + if (isa_and_nonnull(definingOp)) { + rewriter.create(loc, definingOp->getResult(0)); + } + updatedOperands.push_back(nextLoopTensor); + } else { + auto nextLoopTensor = rewriter.create( + loc, dyn_cast(operand.getType())); + rewriter.create( + loc, operand, SmallVector(shape.size(), zero), + nextLoopTensor, SmallVector(shape.size(), zero), + rewriter.create(loc, size), tag, + ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, size)); + updatedOperands.push_back(nextLoopTensor); + } + continue; + } + updatedOperands.push_back(operand); + } + + rewriter.replaceOpWithNewOp(op, updatedOperands); + return success(); + } +}; + +struct TTSCFWhileOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + // Remap proper input types. + TypeConverter::SignatureConversion signatureConversionBefore( + op.getBeforeBody()->getNumArguments()); + + // Convert argument types one by one and check for errors. + for (auto [idx, type] : + llvm::enumerate(op.getBeforeBody()->getArgumentTypes())) { + SmallVector converted; + converted.push_back(getTypeConverter()->convertType(type)); + signatureConversionBefore.addInputs(idx, converted); + } + + TypeConverter::SignatureConversion signatureConversionAfter( + op.getBody()->getNumArguments()); + + // Convert argument types one by one and check for errors. + for (auto [idx, type] : + llvm::enumerate(op.getAfterBody()->getArgumentTypes())) { + SmallVector converted; + converted.push_back(getTypeConverter()->convertType(type)); + signatureConversionAfter.addInputs(idx, converted); + } + + SmallVector resultTypes; + for (auto type : op.getResultTypes()) { + resultTypes.push_back(getTypeConverter()->convertType(type)); + } + + auto whileOp = + rewriter.create(loc, resultTypes, adaptor.getInits()); + whileOp.getBefore().getBlocks().clear(); + rewriter.inlineRegionBefore(op.getBefore(), whileOp.getBefore(), + whileOp.getBefore().end()); + whileOp.getAfter().getBlocks().clear(); + rewriter.inlineRegionBefore(op.getAfter(), whileOp.getAfter(), + whileOp.getAfter().end()); + if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), + *getTypeConverter(), + &signatureConversionBefore))) + return failure(); + if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), + *getTypeConverter(), + &signatureConversionAfter))) + return failure(); + replaced2Origin[whileOp.getOperation()] = op.getOperation(); + + rewriter.replaceOp(op, whileOp); + return success(); + } +}; + +struct TTSCFConditionLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + // Remap proper input types. + auto conditionOp = rewriter.create( + loc, adaptor.getCondition(), adaptor.getArgs()); + rewriter.replaceOp(op, conditionOp); + return success(); + } +}; + +template +struct TTIntrinsicOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(FT op, + typename SharedConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + gpu::Dimension dim = gpu::Dimension::x; + switch (op.getAxis()) { + case triton::ProgramIDDim::X: + dim = gpu::Dimension::x; + break; + case triton::ProgramIDDim::Y: + dim = gpu::Dimension::y; + break; + case triton::ProgramIDDim::Z: + dim = gpu::Dimension::z; + break; + default: + dim = gpu::Dimension::x; + break; + } + auto loc = op.getLoc(); + auto newOp = rewriter.create( + loc, this->getTypeConverter()->convertType(op.getType()), + rewriter.create(loc, dim)); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +struct TTAssertOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::AssertOp assertOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = assertOp.getLoc(); + + auto message = assertOp.getMessage(); + + auto assertSingleElement = [&](Value operand, ValueRange iters) { + // load single element + auto value = TypeSwitch(operand.getType()) + .Case([&](auto ty) { + return rewriter.create(loc, operand); + }) + .Default([&](auto ty) { return operand; }); + // Create gcu.assert op + rewriter.create( + loc, value, mlir::StringAttr::get(rewriter.getContext(), message), "", + "", 0); + }; + + auto assertMemrefCondition = [&](Value operand) { + TypeSwitch(operand.getType()) + .Case([&](auto ty) { + // use loop nest to load all elements in memref + affine::buildAffineLoopNest( + rewriter, loc, SmallVector(ty.getRank(), 0), + ty.getShape(), SmallVector(ty.getRank(), 1), + [&](OpBuilder &builder, Location loc, ValueRange iters) { + auto v = builder.create(loc, operand, iters); + assertSingleElement(v, iters); + }); + }) + .Default([&](auto ty) { assertSingleElement(operand, {}); }); + }; + + // handle memref + assertMemrefCondition(adaptor.getCondition()); + + rewriter.eraseOp(assertOp); + return success(); + } +}; + +struct TTPrintOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::PrintOp printOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = printOp.getLoc(); + auto printOpPrefix = printOp.getPrefix(); + auto hex = printOp.getHex(); + + // Simple printf of a string without any tensors. + if (printOp.getNumOperands() == 0) { + rewriter.create(loc, (printOpPrefix + "\n").str(), + ValueRange{}); + rewriter.eraseOp(printOp); + return success(); + } + + auto printSingleElement = [&](Value operand, size_t i, size_t n, + ValueRange iters) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << printOpPrefix << ": "; + if (n > 1) + os << "(operand " << i << ") "; + + // format + auto msg = TypeSwitch(operand.getType()) + .Case([&](auto ty) { + if (hex) { + os << "0x%x "; + return "0x%x "; + } else { + os << "%d "; + return "%d "; + } + }) + .Case([&](auto ty) { + auto isSigned = ty.isSigned(); + if (hex) { + os << "0x%x "; + return "0x%x "; + } else { + if (isSigned) { + os << "%d "; + return "%d "; + } + os << "%u "; + return "%u "; + } + }) + .Default([&](auto ty) { + os << "%f "; + return "%f "; + }); + + // value + SmallVector values; + auto value = TypeSwitch(operand.getType()) + .Case([&](auto ty) { + return rewriter.create(loc, operand); + }) + .Default([&](auto ty) { return operand; }); + values.push_back(value); + + if (!iters.empty()) { + // idx format + os << "(idx "; + for (auto iter = iters.begin(); iter != iters.end(); ++iter) { + if (iter != iters.begin()) + os << ", "; + os << "%d"; + } + os << ")"; + // idx value + values.append(iters.begin(), iters.end()); + } + os << "\n"; + + if (!msg.empty()) + rewriter.create(loc, formatStr, ValueRange{values}); + }; + + auto printOperand = [&](Value operand, size_t i, size_t n) { + TypeSwitch(operand.getType()) + .Case([&](auto ty) { + affine::buildAffineLoopNest( + rewriter, loc, SmallVector(ty.getRank(), 0), + ty.getShape(), SmallVector(ty.getRank(), 1), + [&](OpBuilder &builder, Location loc, ValueRange iters) { + auto v = builder.create(loc, operand, iters); + printSingleElement(v, i, n, iters); + }); + }) + .Default([&](auto ty) { printSingleElement(operand, i, n, {}); }); + }; + + // print all operands by order + for (size_t i = 0; i < adaptor.getOperands().size(); ++i) { + printOperand(adaptor.getOperands()[i], i, adaptor.getOperands().size()); + } + + rewriter.eraseOp(printOp); + return success(); + } +}; + +struct TTMakeRangeOpLowering : SharedConversionPattern { + unsigned vectorLengthInByte; + TTMakeRangeOpLowering(const TypeConverter &converter, MLIRContext *ctx, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin, + unsigned vectorLength, unsigned vectorizationMaxLength) + : SharedConversionPattern(converter, ctx, userAnalysis, replaced2Origin), + vectorLengthInByte(vectorLength) {} + + LogicalResult + matchAndRewrite(triton::MakeRangeOp arangeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, arangeOp.getOperation()); + auto loc = arangeOp.getLoc(); + auto lastUser = userAnalysis.getLastUserOp(arangeOp.getOperation()); + auto warpIds = getWarpIds(rewriter, loc, arangeOp.getType()); + auto slicedAxies = getSlicedAxies(arangeOp.getType()); + auto numElems = triton::gcu::getTotalElemsPerThread(arangeOp.getType()); + auto start = arangeOp.getStart(); + auto resultType = dyn_cast( + getTypeConverter()->convertType(arangeOp.getType())); + auto output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultType); + auto startOffset = + slicedAxies.empty() + ? rewriter + .create(loc, start, + resultType.getElementType()) + .getResult() + : rewriter.create( + loc, resultType.getElementType(), + rewriter.create( + loc, + rewriter.create( + loc, warpIds.front(), + rewriter.create(loc, + numElems)), + rewriter.create(loc, start))); + + auto vectorLength = + vectorLengthInByte / triton::gcu::getBpe(resultType.getElementType()); + + auto vectorType = VectorType::get(ArrayRef{vectorLength}, + resultType.getElementType()); + auto arangeV = + rewriter + .create( + loc, vectorType, + rewriter + .create( + loc, VectorType::get(ArrayRef{vectorLength}, + rewriter.getIndexType())) + .getResult()) + .getResult(0); + + Value vec = rewriter.create( + loc, arangeV, + rewriter.create(loc, vectorType, startOffset)); + Value step = rewriter.create( + loc, vectorType, + rewriter.create(loc, vectorLength, + resultType.getElementType())); + rewriter.create( + loc, rewriter.create(loc, 0), + rewriter.create(loc, numElems), + rewriter.create(loc, vectorLength), + ValueRange{vec}, + [&](OpBuilder &builder, Location loc, Value iters, + ValueRange iterArgs) { + builder.create(loc, iterArgs[0], output, iters); + builder.create( + loc, ValueRange{ + builder.create(loc, iterArgs[0], step)}); + }); + leaveTritionOp(rewriter, arangeOp.getOperation()); + rewriter.replaceOp(arangeOp, output); + return success(); + } +}; + +struct TTSplatOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp splatOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, splatOp.getOperation()); + auto lastUser = userAnalysis.getLastUserOp(splatOp.getOperation()); + auto loc = splatOp.getLoc(); + auto numElems = triton::gcu::getElemsPerThread(splatOp.getType()); + auto resultType = dyn_cast( + getTypeConverter()->convertType(splatOp.getType())); + auto output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultType); + auto v = isa(splatOp.getSrc().getType()) + ? rewriter.create(loc, adaptor.getSrc()) + : adaptor.getSrc(); + auto totalNumElems = triton::gcu::getTotalElemsPerThread(splatOp.getType()); + doMemset(rewriter, splatOp, output, v, totalNumElems); + leaveTritionOp(rewriter, splatOp.getOperation()); + rewriter.replaceOp(splatOp, output); + return success(); + } +}; + +struct TTConstantOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, constOp.getOperation()); + auto loc = constOp.getLoc(); + if (!isa(constOp.getType())) + return failure(); + auto lastUser = userAnalysis.getLastUserOp(constOp.getOperation()); + auto totalNumElems = triton::gcu::getTotalElemsPerThread(constOp.getType()); + auto resultType = dyn_cast( + getTypeConverter()->convertType(constOp.getType())); + auto valueAttr = constOp.getValue(); + auto array = dyn_cast(valueAttr); + auto output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultType); + + // only support splat constant + auto attr = array.getSplatValue(); + auto v = + rewriter.create(loc, array.getElementType(), attr); + doMemset(rewriter, constOp, output, v, totalNumElems); + leaveTritionOp(rewriter, constOp.getOperation()); + rewriter.replaceOp(constOp, output); + return success(); + } +}; + +struct TTAddPtrOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp addPtrOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = addPtrOp.getLoc(); + enterTritionOp(rewriter, addPtrOp.getOperation()); + // vector + if (isa(addPtrOp.getType())) { + auto lastUser = userAnalysis.getLastUserOp(addPtrOp.getOperation()); + auto numElems = triton::gcu::getElemsPerThread(addPtrOp.getType()); + auto resultType = dyn_cast( + getTypeConverter()->convertType(addPtrOp.getType())); + auto output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultType); + auto ptrs = adaptor.getPtr(); + auto offsets = adaptor.getOffset(); + affine::buildAffineLoopNest( + rewriter, loc, SmallVector(numElems.size(), 0), + SmallVector(numElems.begin(), numElems.end()), + SmallVector(numElems.size(), 1), + [&](OpBuilder &builder, Location loc, ValueRange iters) { + auto ptrType = + dyn_cast(getTypeConverter()->convertType( + dyn_cast(addPtrOp.getType()).getElementType())); + auto elemType = ptrType.getElementType(); + auto elemBytes = (elemType.getIntOrFloatBitWidth() + 7) / 8; + auto lhs = builder.create(loc, ptrs, iters); + auto rhs = + builder.create(loc, offsets, iters).getResult(); + rhs = builder.create( + loc, rhs, + builder.create(loc, elemBytes, + rhs.getType())); + auto v = builder.create( + loc, lhs, + rhs.getType().getIntOrFloatBitWidth() < 64 + ? builder.create(loc, builder.getI64Type(), + rhs) + : rhs); + builder.create(loc, v, output, iters); + }); + doMemFence(rewriter, addPtrOp); + leaveTritionOp(rewriter, addPtrOp.getOperation()); + rewriter.replaceOp(addPtrOp, output); + return success(); + } + + // scalar + auto resultType = dyn_cast( + getTypeConverter()->convertType(addPtrOp.getType())); + auto elemType = resultType.getElementType(); + auto elemBytes = (elemType.getIntOrFloatBitWidth() + 7) / 8; + auto ptr = adaptor.getPtr(); + auto offset = adaptor.getOffset(); + offset = + rewriter.create(loc, offset, + rewriter.create( + loc, elemBytes, offset.getType())); + auto v = rewriter.create( + loc, resultType, + rewriter.create( + loc, rewriter.create(loc, ptr), + offset.getType().getIntOrFloatBitWidth() < 64 + ? rewriter.create(loc, rewriter.getI64Type(), + offset) + : offset)); + rewriter.replaceOp(addPtrOp, v); + return success(); + } +}; + +struct TTLoadOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult loadSingleElement(triton::LoadOp loadOp, OpBuilder &builder, + Value ptr, Value output, Value offset, + Value tag, Value mask, Value other) const { + auto loc = loadOp.getLoc(); + + auto elemType = dyn_cast(ptr.getType()).getElementType(); + + auto memType1D = + MemRefType::get(ArrayRef{ShapedType::kDynamic}, elemType); + auto buffer = builder.create(loc, memType1D, ptr); + auto one = builder.create(loc, 1); + auto zero = builder.create(loc, 0); + + auto from = builder.create( + loc, MemRefType::get(ArrayRef{1}, elemType), buffer, 0, + ArrayRef{1}, ArrayRef{1}); + + auto to = builder.create( + loc, MemRefType::get(ArrayRef{1}, elemType), output, offset, + ValueRange{one}, ValueRange{one}); + auto result = success(); + builder.create( + loc, mask, + [&](OpBuilder &builder, Location loc) { + builder.create(loc, from, ValueRange{zero}, to, + ValueRange{zero}, one, tag, + ValueRange{zero}); + builder.create(loc, tag, ValueRange{zero}, one); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create(loc, other, to, ValueRange{offset}); + doMemFence(builder, loadOp); + builder.create(loc); + }); + return result; + } + + LogicalResult + matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, loadOp.getOperation()); + auto loc = loadOp.getLoc(); + assert(!(isa(loadOp.getPtr().getType()) && + isa( + dyn_cast(loadOp.getPtr().getType()) + .getPointeeType()))); + + // tensor + if (isa(loadOp.getType())) { + auto lastUser = userAnalysis.getLastUserOp(loadOp.getOperation()); + auto numElems = triton::gcu::getElemsPerThread(loadOp.getType()); + auto numElemValues = getElemsPerThread(rewriter, loc, loadOp.getType()); + + auto resultType = dyn_cast( + getTypeConverter()->convertType(loadOp.getType())); + auto output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultType); + auto offsets = syncAllocOp( + rewriter, loc, loadOp.getOperation(), userAnalysis, replaced2Origin, + MemRefType::get(resultType.getShape(), rewriter.getI32Type())); + auto masks = syncAllocOp( + rewriter, loc, loadOp.getOperation(), userAnalysis, replaced2Origin, + MemRefType::get(resultType.getShape(), rewriter.getI1Type())); + auto others = syncAllocOp( + rewriter, loc, loadOp.getOperation(), userAnalysis, replaced2Origin, + MemRefType::get(resultType.getShape(), resultType.getElementType())); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto firstIndex = SmallVector(numElems.size(), zero); + auto firstAddr = + rewriter.create(loc, adaptor.getPtr(), firstIndex); + scf::buildLoopNest( + rewriter, loc, SmallVector(numElems.size(), zero), + numElemValues, SmallVector(numElems.size(), one), + [&](OpBuilder &builder, Location loc, ValueRange iters) { + auto addr = + builder.create(loc, adaptor.getPtr(), iters); + auto offset = builder.create(loc, addr, firstAddr); + builder.create( + loc, + builder.create(loc, builder.getI32Type(), + offset), + offsets, iters); + + auto mask = + adaptor.getMask() + ? builder + .create(loc, adaptor.getMask(), iters) + .getResult() + : builder + .create(loc, 1, + builder.getI1Type()) + .getResult(); + builder.create(loc, mask, masks, iters); + + auto other = adaptor.getOther() + ? rewriter + .create( + loc, adaptor.getOther(), iters) + .getResult() + : triton::gcu::createConstantZero( + rewriter, loc, resultType.getElementType()); + builder.create(loc, other, others, iters); + }); + + auto totalNumElems = + rewriter.create(loc, 1).getResult(); + for (unsigned i = 0; i < numElemValues.size(); ++i) { + totalNumElems = rewriter.create(loc, totalNumElems, + numElemValues[i]); + } + + auto output1D = castToMemref1D(rewriter, loc, output, totalNumElems); + auto offsets1D = castToMemref1D(rewriter, loc, offsets, totalNumElems); + auto masks1D = castToMemref1D(rewriter, loc, masks, totalNumElems); + auto others1D = castToMemref1D(rewriter, loc, others, totalNumElems); + rewriter.create( + loc, + rewriter.create( + loc, gcu::PtrType::get(getContext(), resultType.getElementType()), + output1D), + rewriter.create( + loc, gcu::PtrType::get(getContext(), resultType.getElementType()), + firstAddr), + rewriter.create( + loc, gcu::PtrType::get(getContext(), rewriter.getI32Type()), + offsets1D), + rewriter.create( + loc, gcu::PtrType::get(getContext(), rewriter.getI1Type()), + masks1D), + rewriter.create( + loc, gcu::PtrType::get(getContext(), resultType.getElementType()), + others1D), + rewriter.create(loc, rewriter.getI32Type(), + totalNumElems)); + leaveTritionOp(rewriter, loadOp.getOperation()); + rewriter.replaceOp(loadOp, output); + return success(); + } + + // scalar + auto tag = getPrivateDTETag(rewriter, loadOp); + auto output = rewriter.create( + loc, + MemRefType::get(ArrayRef{1}, + getTypeConverter()->convertType(loadOp.getType()))); + auto offset = rewriter.create(loc, 0); + auto mask = + adaptor.getMask() + ? adaptor.getMask() + : rewriter + .create(loc, 1, rewriter.getI1Type()) + .getResult(); + auto other = + adaptor.getOther() + ? adaptor.getOther() + : triton::gcu::createConstantZero(rewriter, loc, loadOp.getType()); + if (failed(loadSingleElement(loadOp, rewriter, adaptor.getPtr(), output, + offset, tag, mask, other))) + return failure(); + auto v = rewriter.create(loc, output, ValueRange{offset}); + leaveTritionOp(rewriter, loadOp.getOperation()); + rewriter.replaceOp(loadOp, v); + return success(); + } +}; + +struct TTStoreOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + void storeSingleElement(triton::StoreOp storeOp, OpBuilder &builder, + Value ptr, Value values, Value offset, Value tag, + Value mask) const { + auto elemType = dyn_cast(ptr.getType()).getElementType(); + auto loc = storeOp.getLoc(); + + auto memType1D = + MemRefType::get(ArrayRef{ShapedType::kDynamic}, elemType); + auto buffer = builder.create(loc, memType1D, ptr); + auto one = builder.create(loc, 1); + auto zero = builder.create(loc, 0); + + auto from = builder.create( + loc, MemRefType::get(ArrayRef{1}, elemType), values, offset, + ValueRange{one}, ValueRange{one}); + auto to = builder.create( + loc, MemRefType::get(ArrayRef{1}, elemType), buffer, 0, + ArrayRef{1}, ArrayRef{1}); + builder.create(loc, mask, [&](OpBuilder &builder, Location loc) { + builder.create(loc, from, ValueRange{zero}, to, + ValueRange{zero}, one, tag, + ValueRange{zero}); + builder.create(loc, tag, ValueRange{zero}, one); + builder.create(loc); + }); + } + + LogicalResult + matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = storeOp.getLoc(); + enterTritionOp(rewriter, storeOp.getOperation()); + assert(!(isa(storeOp.getPtr().getType()) && + isa( + dyn_cast(storeOp.getPtr().getType()) + .getPointeeType()))); + + // tensor + if (isa(storeOp.getPtr().getType())) { + auto numElems = + triton::gcu::getElemsPerThread(storeOp.getPtr().getType()); + auto numElemValues = + getElemsPerThread(rewriter, loc, storeOp.getPtr().getType()); + auto values = adaptor.getValue(); + auto valueType = dyn_cast(values.getType()); + + auto offsets = syncAllocOp( + rewriter, loc, storeOp.getOperation(), userAnalysis, replaced2Origin, + MemRefType::get(valueType.getShape(), rewriter.getI32Type())); + auto masks = syncAllocOp( + rewriter, loc, storeOp.getOperation(), userAnalysis, replaced2Origin, + MemRefType::get(valueType.getShape(), rewriter.getI1Type())); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto firstIndex = SmallVector(numElems.size(), zero); + auto firstAddr = + rewriter.create(loc, adaptor.getPtr(), firstIndex); + + scf::buildLoopNest( + rewriter, loc, SmallVector(numElems.size(), zero), + numElemValues, SmallVector(numElems.size(), one), + [&](OpBuilder &builder, Location loc, ValueRange iters) { + auto addr = + builder.create(loc, adaptor.getPtr(), iters); + auto offset = builder.create(loc, addr, firstAddr); + builder.create( + loc, + builder.create(loc, builder.getI32Type(), + offset), + offsets, iters); + + auto mask = + adaptor.getMask() + ? builder + .create(loc, adaptor.getMask(), iters) + .getResult() + : builder + .create(loc, 1, + builder.getI1Type()) + .getResult(); + builder.create(loc, mask, masks, iters); + }); + + Value totalNumElems = + rewriter.create(loc, 1).getResult(); + for (unsigned i = 0; i < numElemValues.size(); ++i) { + totalNumElems = rewriter.create(loc, totalNumElems, + numElemValues[i]); + } + + auto values1D = castToMemref1D(rewriter, loc, values, totalNumElems); + auto offsets1D = castToMemref1D(rewriter, loc, offsets, totalNumElems); + auto masks1D = castToMemref1D(rewriter, loc, masks, totalNumElems); + + rewriter.replaceOpWithNewOp( + storeOp, + rewriter.create( + loc, gcu::PtrType::get(getContext(), valueType.getElementType()), + firstAddr), + rewriter.create( + loc, gcu::PtrType::get(getContext(), valueType.getElementType()), + values1D), + rewriter.create( + loc, gcu::PtrType::get(getContext(), rewriter.getI32Type()), + offsets1D), + rewriter.create( + loc, gcu::PtrType::get(getContext(), rewriter.getI1Type()), + masks1D), + rewriter.create(loc, rewriter.getI32Type(), + totalNumElems)); + return success(); + } + + // scalar + auto tag = getPrivateDTETag(rewriter, storeOp); + auto output = rewriter.create( + loc, + MemRefType::get(ArrayRef{1}, adaptor.getValue().getType())); + auto offset = rewriter.create(loc, 0); + rewriter.create(loc, adaptor.getValue(), output, + ValueRange{offset}); + + // If the tensor is not ranked, then it is a scalar and only thread 0 can + // write + auto oneMask = + rewriter.create(loc, 1, rewriter.getI1Type()) + .getResult(); + auto zero = rewriter.create(loc, 0); + auto isThread0 = rewriter.create( + loc, arith::CmpIPredicate::eq, + rewriter.create(loc, gpu::Dimension::x), zero); + auto mask = adaptor.getMask() + ? adaptor.getMask() + : rewriter.create(loc, oneMask, isThread0); + doMemFence(rewriter, storeOp); + storeSingleElement(storeOp, rewriter, adaptor.getPtr(), output, offset, tag, + mask); + rewriter.create(loc, output); + leaveTritionOp(rewriter, storeOp.getOperation()); + rewriter.eraseOp(storeOp); + return success(); + } +}; + +struct TTArithSelectOpLowering + : public SharedConversionPattern { + TTArithSelectOpLowering(const TypeConverter &converter, MLIRContext *ctx, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin) + : SharedConversionPattern(converter, ctx, userAnalysis, + replaced2Origin) {} + + LogicalResult matchAndRewrite( + arith::SelectOp op, + typename SharedConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + auto loc = op.getLoc(); + auto type = op.getType(); + if (!isa(type)) { + leaveTritionOp(rewriter, op.getOperation()); + return failure(); + } + auto ty = this->getTypeConverter()->convertType(type); + auto newOp = rewriter.create( + loc, ty, adaptor.getOperands(), op->getAttrs()); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +template +struct TTElementwiseOpLowering : public SharedConversionPattern { + TTElementwiseOpLowering(const TypeConverter &converter, MLIRContext *ctx, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin) + : SharedConversionPattern(converter, ctx, userAnalysis, + replaced2Origin) {} + + LogicalResult + matchAndRewrite(FT op, + typename SharedConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + auto loc = op.getLoc(); + auto type = op.getType(); + if (!isa(type)) { + auto ty = this->getTypeConverter()->convertType(type); + rewriter.replaceOpWithNewOp(op, ty, adaptor.getOperands()); + leaveTritionOp(rewriter, op.getOperation()); + return success(); + } + auto lastUser = this->userAnalysis.getLastUserOp(op.getOperation()); + auto numElems = triton::gcu::getElemsPerThread(type); + auto resultType = + dyn_cast(this->getTypeConverter()->convertType(type)); + auto output = syncAllocOp(rewriter, loc, lastUser, this->userAnalysis, + this->replaced2Origin, resultType); + affine::buildAffineLoopNest( + rewriter, loc, SmallVector(numElems.size(), 0), + SmallVector(numElems.begin(), numElems.end()), + SmallVector(numElems.size(), 1), + [&](OpBuilder &builder, Location loc, ValueRange iters) { + SmallVector operands; + for (auto operand : adaptor.getOperands()) { + operands.push_back( + builder.create(loc, operand, iters)); + } + auto v = builder.create(loc, resultType.getElementType(), + operands, op->getAttrs()); + builder.create(loc, v, output, iters); + }); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } +}; + +struct TTBitcastOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + enterTritionOp(rewriter, op.getOperation()); + auto type = this->getTypeConverter()->convertType(op.getType()); + if (!isa(op.getType())) { + // arith.bitcast doesn't support pointers + if (isa(op.getSrc().getType()) && + isa(op.getResult().getType())) { + auto result = rewriter.create( + loc, type, rewriter.create(loc, adaptor.getSrc())); + rewriter.replaceOp(op, result); + return success(); + } else { + rewriter.replaceOpWithNewOp(op, type, + adaptor.getSrc()); + return success(); + } + } + + auto dstType = dyn_cast(type); + auto srcType = dyn_cast(adaptor.getSrc().getType()); + + if (dstType.getNumElements() != srcType.getNumElements()) + return op.emitOpError("src and dst element number mismatch"); + + auto totalNumElems = rewriter.create( + loc, triton::gcu::getTotalElemsPerThread(op.getSrc().getType())); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto srcBuf = rewriter.create( + loc, + MemRefType::get(ArrayRef{ShapedType::kDynamic}, + srcType.getElementType()), + adaptor.getSrc(), zero, ArrayRef{totalNumElems}, + ArrayRef{one}); + auto srcPtrType = gcu::PtrType::get(getContext(), srcType.getElementType()); + auto srcPtr = rewriter.create(loc, srcPtrType, srcBuf); + auto ptrInt = rewriter.create(loc, srcPtr); + auto dstPtrType = gcu::PtrType::get(getContext(), dstType.getElementType()); + auto dstPtr = rewriter.create(loc, dstPtrType, ptrInt); + auto dstBuf = rewriter.create( + loc, + MemRefType::get(ArrayRef{ShapedType::kDynamic}, + dstType.getElementType()), + dstPtr); + auto [strides, offset] = dstType.getStridesAndOffset(); + auto dst = rewriter.create( + loc, dstType, dstBuf, offset, dstType.getShape(), strides); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, dst); + return success(); + } +}; + +struct TTScanOpLowering : SharedConversionPattern { + unsigned vectorSizeInBytes; + TTScanOpLowering(const TypeConverter &converter, MLIRContext *ctx, + triton::gcu::FirstLastUserAnalysis &userAnalysis, + std::map &replaced2Origin, + unsigned vectorSizeInBytes) + : SharedConversionPattern(converter, ctx, userAnalysis, replaced2Origin), + vectorSizeInBytes(vectorSizeInBytes) {} + + void applyScan(triton::ScanOp op, OpBuilder &rewriter, + ArrayRef outputs, ArrayRef inputs, Type type, + bool reverse) const { + auto axis = op.getAxis(); + auto loc = op.getLoc(); + auto numElems = triton::gcu::getElemsPerThread(type); + auto numOutput = outputs.size(); + auto totalNumElems = triton::gcu::getTotalElemsPerThread(type); + auto tag = getPrivateDTETag(rewriter, op); + auto zero = rewriter.create(loc, 0); + + // initialize outputs by inputs + for (unsigned i = 0; i < numOutput; ++i) { + rewriter.create( + loc, inputs[i], SmallVector(numElems.size(), zero), + outputs[i], SmallVector(numElems.size(), zero), + rewriter.create(loc, totalNumElems), tag, + ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, totalNumElems)); + } + + std::array scanInOutDims = {1, 1, 1}; + int64_t scanAxis = 2; + for (int i = numElems.size() - 1, j = 2; i >= 0; i--) { + if (static_cast(i) == axis) { + if (scanInOutDims[j] == 1) { + scanInOutDims[j] = numElems[i]; + } else { + scanInOutDims[--j] = numElems[i]; + } + scanAxis = j; + --j; + } else { + scanInOutDims[j] *= numElems[i]; + } + } + SmallVector outs; + llvm::transform(outputs, std::back_inserter(outs), [&](auto output) { + return rewriter.create( + loc, + MemRefType::get(scanInOutDims, + cast(output.getType()).getElementType()), + output, ValueRange{}, ValueRange{}, ValueRange{}, + ArrayRef{0}, + ArrayRef{scanInOutDims[0], scanInOutDims[1], + scanInOutDims[2]}, + ArrayRef{scanInOutDims[1] * scanInOutDims[2], + scanInOutDims[2], 1}); + }); + if (succeeded(applyGeneralScan(op, rewriter, outs, scanInOutDims, scanAxis, + reverse))) { + return; + } + return applyScanFallback(op, rewriter, outs, scanInOutDims, scanAxis, + reverse); + } + + LogicalResult applyGeneralScan(triton::ScanOp op, OpBuilder &rewriter, + ArrayRef outputs, + const std::array &scanInOutDims, + int64_t scanAxis, bool reverse) const { + auto loc = op.getLoc(); + int64_t vectorizeAxis; + if (scanAxis == 2) { + assert(scanInOutDims[0] == 1); + vectorizeAxis = 1; + } else { + assert(scanAxis == 1); + vectorizeAxis = scanInOutDims[0] > scanInOutDims[2] ? 0 : 2; + } + unsigned bpe = 4; // gatherscatter offset, i32 + for (auto output : outputs) { + auto elementTy = cast(output.getType()).getElementType(); + auto bytes = triton::gcu::getBpe(elementTy); + bpe = bytes > bpe ? bytes : bpe; + } + auto vectorLength = vectorSizeInBytes / bpe; + if (scanInOutDims[vectorizeAxis] < vectorLength) { + return failure(); + } + auto numOutput = outputs.size(); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + SmallVector vectorTypes; + llvm::transform( + outputs, std::back_inserter(vectorTypes), [vectorLength](auto output) { + auto elementTy = cast(output.getType()).getElementType(); + return VectorType::get(ArrayRef{vectorLength}, elementTy); + }); + + SmallVector lbs(scanInOutDims.size(), zero); + lbs[scanAxis] = one; + std::array loopcnt = scanInOutDims; + if (loopcnt[vectorizeAxis] % vectorLength != 0) { + llvm_unreachable("invalid datalayout"); + } + loopcnt[vectorizeAxis] /= vectorLength; + SmallVector ubs{ + rewriter.create(loc, loopcnt[0]), + rewriter.create(loc, loopcnt[1]), + rewriter.create(loc, loopcnt[2])}; + SmallVector step(scanInOutDims.size(), one); + + auto maskType = + VectorType::get(ArrayRef{vectorLength}, rewriter.getI1Type()); + Value mask = rewriter.create( + loc, maskType, + DenseI64ArrayAttr::get(rewriter.getContext(), + ArrayRef{vectorLength})); + unsigned strideOnVectorizeAxis = + std::accumulate(scanInOutDims.begin() + vectorizeAxis + 1, + scanInOutDims.end(), 1, std::multiplies()); + auto vecTy = + VectorType::get(ArrayRef{vectorLength}, rewriter.getI32Type()); + auto indexVec = rewriter.create( + loc, + rewriter + .create( + loc, vecTy, + rewriter + .create( + loc, VectorType::get(ArrayRef{vectorLength}, + rewriter.getIndexType())) + .getResult()) + .getResult(0), + rewriter.create( + loc, vecTy, + rewriter.create( + loc, rewriter.getI32Type(), + rewriter.getI32IntegerAttr(strideOnVectorizeAxis)))); + + SmallVector passThruValues; + for (unsigned i = 0; i < numOutput; ++i) { + passThruValues.push_back(rewriter.create( + loc, vectorTypes[i], + rewriter.create( + loc, vectorTypes[i].getElementType(), + rewriter.getZeroAttr(vectorTypes[i].getElementType())))); + } + + scf::buildLoopNest( + rewriter, loc, + ArrayRef(lbs.begin(), lbs.begin() + vectorizeAxis), + ArrayRef(ubs.begin(), ubs.begin() + vectorizeAxis), + ArrayRef(step.begin(), step.begin() + vectorizeAxis), + [&](OpBuilder &builder, Location loc, ValueRange outerIters) { + scf::buildLoopNest( + rewriter, loc, + ArrayRef(lbs.begin() + vectorizeAxis, lbs.end()), + ArrayRef(ubs.begin() + vectorizeAxis, ubs.end()), + ArrayRef(step.begin() + vectorizeAxis, step.end()), + [&](OpBuilder &builder, Location loc, ValueRange innerIters) { + SmallVector inputIndices; + SmallVector outputIndices; + + SmallVector resultElemTypes; + SmallVector operands; + SmallVector ivs; + for (auto iv : outerIters) { + ivs.push_back(iv); + } + for (auto iv : innerIters) { + ivs.push_back(iv); + } + if (reverse) { + ivs[scanAxis] = builder.create( + loc, + builder.create( + loc, scanInOutDims[scanAxis] - 1), + ivs[scanAxis]); + } + for (unsigned i = 0; i < ivs.size(); ++i) { + if (i == vectorizeAxis) { + outputIndices.push_back(builder.create( + loc, ivs[i], + rewriter.create(loc, + vectorLength))); + } else { + outputIndices.push_back(ivs[i]); + } + if (i == scanAxis) { + if (reverse) { + inputIndices.push_back(builder.create( + loc, outputIndices[i], one)); + } else { + inputIndices.push_back(builder.create( + loc, outputIndices[i], one)); + } + } else { + inputIndices.push_back(outputIndices[i]); + } + } + + for (unsigned i = 0; i < numOutput; ++i) { + operands.push_back(builder.create( + loc, vectorTypes[i], outputs[i], inputIndices, indexVec, + mask, passThruValues[i])); + } + for (unsigned i = 0; i < numOutput; ++i) { + operands.push_back(builder.create( + loc, vectorTypes[i], outputs[i], outputIndices, indexVec, + mask, passThruValues[i])); + resultElemTypes.push_back(vectorTypes[i]); + } + + auto executeRegionOp = + builder.create(loc, resultElemTypes); + executeRegionOp.getRegion().emplaceBlock(); + IRMapping map; + for (auto [arg, operand] : + llvm::zip(op.getCombineOp().getArguments(), operands)) { + map.map(arg, operand); + } + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart( + &executeRegionOp.getRegion().back()); + for (auto &o : op.getCombineOp().back()) { + for (auto operand : o.getOperands()) { + if (auto constantOp = + operand.getDefiningOp()) { + if (!map.lookupOrNull(operand)) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(constantOp); + map.map(operand, + builder.create( + loc, + VectorType::get( + ArrayRef{vectorLength}, + operand.getType()), + operand)); + } + } + } + auto newO = builder.clone(o, map); + for (auto [result, newResult] : + llvm::zip(o.getResults(), newO->getResults())) { + auto vectorTy = VectorType::get( + ArrayRef{vectorLength}, result.getType()); + newResult.setType(vectorTy); + map.map(result, newResult); + } + } + } + + for (unsigned i = 0; i < numOutput; ++i) { + builder.create( + loc, outputs[i], outputIndices, indexVec, mask, + executeRegionOp.getResult(i)); + } + }); + }); + doMemFence(rewriter, op); + return success(); + } + + void applyScanFallback(triton::ScanOp op, OpBuilder &rewriter, + ArrayRef outputs, + const std::array &scanInOutDims, + int64_t scanAxis, bool reverse) const { + auto loc = op.getLoc(); + auto numOutput = outputs.size(); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + + SmallVector lbs(scanInOutDims.size(), zero); + lbs[scanAxis] = one; + SmallVector ubs{ + rewriter.create(loc, scanInOutDims[0]), + rewriter.create(loc, scanInOutDims[1]), + rewriter.create(loc, scanInOutDims[2])}; + + scf::buildLoopNest( + rewriter, loc, lbs, ubs, + SmallVector(scanInOutDims.size(), one), + [&](OpBuilder &builder, Location loc, ValueRange iters) { + SmallVector outputIters(iters.begin(), iters.end()); + if (reverse) { + outputIters[scanAxis] = builder.create( + loc, + builder.create( + loc, scanInOutDims[scanAxis] - 1), + outputIters[scanAxis]); + } + + SmallVector operands; + SmallVector resultElemTypes; + SmallVector inputIters(outputIters.begin(), + outputIters.end()); + if (reverse) { + inputIters[scanAxis] = + builder.create(loc, one, inputIters[scanAxis]); + } else { + inputIters[scanAxis] = + builder.create(loc, inputIters[scanAxis], one); + } + + for (unsigned i = 0; i < numOutput; ++i) { + operands.push_back( + builder.create(loc, outputs[i], inputIters)); + } + for (unsigned i = 0; i < numOutput; ++i) { + operands.push_back( + builder.create(loc, outputs[i], outputIters)); + resultElemTypes.push_back(operands.back().getType()); + } + + auto executeRegion = + builder.create(loc, resultElemTypes); + executeRegion.getRegion().emplaceBlock(); + IRMapping map; + for (auto [arg, operand] : + llvm::zip(op.getCombineOp().getArguments(), operands)) { + map.map(arg, operand); + } + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&executeRegion.getRegion().back()); + for (auto &o : op.getCombineOp().back()) { + auto newO = builder.clone(o, map); + for (auto [result, newResult] : + llvm::zip(o.getResults(), newO->getResults())) { + map.map(result, newResult); + } + } + } + + for (unsigned i = 0; i < numOutput; ++i) { + builder.create(loc, executeRegion.getResult(i), + outputs[i], outputIters); + } + }); + + doMemFence(rewriter, op); + } + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + enterTritionOp(rewriter, op.getOperation()); + auto inputType = dyn_cast(op.getSrcs()[0].getType()); + + auto slicedAxies = getSlicedAxies(inputType); + bool isScanDimSplit = slicedAxies.count(op.getAxis()); + + auto numInput = op.getSrcs().size(); + auto numOutput = op.getResults().size(); + + auto zero = rewriter.create(loc, 0); + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + // create outputs + SmallVector outputs; + SmallVector outputElemTypes; + for (unsigned i = 0; i < numOutput; ++i) { + auto resultType = + dyn_cast(getTypeConverter()->convertType(op.getType(i))); + auto elemType = resultType.getElementType(); + outputElemTypes.push_back(elemType); + Value output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultType); + outputs.push_back(output); + } + auto encodingAttr = dyn_cast(inputType).getEncoding(); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(encodingAttr); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(encodingAttr); + auto elementsPerThread = triton::gcu::getElemsPerThread(inputType); + bool isValidBlockEncoding = true; + for (auto [dim, elems, threads, warps] : + llvm::zip(inputType.getShape(), elementsPerThread, threadsPerWarp, + warpsPerCTA)) { + if (dim != elems * threads * warps) { + isValidBlockEncoding = false; + break; + } + } + if (isScanDimSplit || !isValidBlockEncoding) { + auto tag = getPrivateDTETag(rewriter, op); + + // move to shared memory + SmallVector sharedInputs; + for (unsigned i = 0; i < numInput; ++i) { + sharedInputs.push_back(storeToSharedMem( + rewriter, tag, + dyn_cast(op.getSrcs()[i].getType()), + adaptor.getSrcs()[i], false, op.getOperation(), userAnalysis, + replaced2Origin)); + } + + // load all shared memory to thread 0 + SmallVector mergedInputs; + RankedTensorType mergedInputType; + for (unsigned i = 0; i < numInput; ++i) { + auto tType = dyn_cast(op.getSrcs()[i].getType()); + auto tensorType = + RankedTensorType::get(tType.getShape(), tType.getElementType(), + triton::gpu::getDefaultBlockedEncoding( + getContext(), tType.getShape(), 1, 1, 1)); + mergedInputType = tensorType; + mergedInputs.push_back(loadFromSharedMem( + rewriter, tag, tensorType, sharedInputs[i], true, op.getOperation(), + nullptr, userAnalysis, replaced2Origin)); + } + + SmallVector mergedOutputs; + for (unsigned i = 0; i < numOutput; ++i) { + auto tType = dyn_cast(op.getResultTypes()[i]); + auto tensorType = + RankedTensorType::get(tType.getShape(), tType.getElementType(), + triton::gpu::getDefaultBlockedEncoding( + getContext(), tType.getShape(), 1, 1, 1)); + auto resultType = + dyn_cast(getTypeConverter()->convertType(tensorType)); + mergedOutputs.push_back(syncAllocOp(rewriter, loc, op.getOperation(), + userAnalysis, replaced2Origin, + resultType)); + } + + // computing in thread 0 + auto isThread0 = rewriter.create( + loc, arith::CmpIPredicate::eq, + rewriter.create(loc, gpu::Dimension::x), zero); + rewriter.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + applyScan(op, builder, mergedOutputs, mergedInputs, mergedInputType, + op.getReverse()); + builder.create(loc); + }); + + // save back to shared memory + SmallVector mergedSharedOutputs; + for (unsigned i = 0; i < numOutput; ++i) { + auto tType = dyn_cast(op.getResultTypes()[i]); + auto tensorType = + RankedTensorType::get(tType.getShape(), outputElemTypes[i], + triton::gpu::getDefaultBlockedEncoding( + getContext(), tType.getShape(), 1, 1, 1)); + mergedSharedOutputs.push_back( + storeToSharedMem(rewriter, tag, tensorType, mergedOutputs[i], true, + op.getOperation(), userAnalysis, replaced2Origin)); + } + // load from shared memory + for (unsigned i = 0; i < numOutput; ++i) { + outputs[i] = loadFromSharedMem(rewriter, tag, op.getResultTypes()[i], + mergedSharedOutputs[i], false, lastUser, + nullptr, userAnalysis, replaced2Origin); + } + } else { + applyScan(op, rewriter, outputs, + SmallVector(adaptor.getSrcs().begin(), + adaptor.getSrcs().end()), + inputType, op.getReverse()); + } + + SmallVector finalOutputs; + for (unsigned i = 0; i < numOutput; ++i) { + auto output = outputs[i]; + auto resultType = dyn_cast( + getTypeConverter()->convertType(op.getResultTypes()[i])); + if (resultType.getNumElements() != + dyn_cast(output.getType()).getNumElements()) { + return op.emitOpError("element number mismatch"); + } + auto [strides, offset] = resultType.getStridesAndOffset(); + output = rewriter.create( + loc, resultType, output, offset, resultType.getShape(), strides); + finalOutputs.push_back(output); + } + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, finalOutputs); + return success(); + } +}; + +struct TTReduceReturnOpLowering + : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReduceReturnOp returnOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(returnOp, returnOp.getOperands()); + return success(); + } +}; + +struct TTScanReturnOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::ScanReturnOp returnOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(returnOp, returnOp.getOperands()); + return success(); + } +}; + +struct TTExternElemwiseOpLowering + : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto name = op.getSymbol(); + if (name == "__nv_fmaxf") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__nv_fminf") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__nv_floorf") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__nv_min") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__nv_max") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__nv_umin") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__nv_umax") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__nv_powf") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__nv_log2f") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__nv_exp2f") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__nv_rsqrtf") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__gcu_begin_clock") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } else if (name == "__gcu_end_clock") { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } + return failure(); + } +}; + +struct TTHistogramOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::HistogramOp histogramOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = histogramOp.getLoc(); + auto zero = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0)); + auto zeroIndex = rewriter.create(loc, 0); + auto oneIndex = rewriter.create(loc, 1); + enterTritionOp(rewriter, histogramOp.getOperation()); + auto lastUser = userAnalysis.getLastUserOp(histogramOp.getOperation()); + auto tag = getPrivateDTETag(rewriter, histogramOp); + auto resultType = histogramOp.getType(); + auto wrapResultType = + dyn_cast(getTypeConverter()->convertType(resultType)); + auto resultMemRefType = + MemRefType::get(resultType.getShape(), wrapResultType.getElementType()); + auto totalNumElems = triton::gcu::getTotalElemsPerThread(resultType); + auto resCurWarp = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultMemRefType); + doMemset(rewriter, histogramOp, resCurWarp, zero, totalNumElems); + auto encoding = resultType.getEncoding(); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(encoding); + auto sharedMemTensorType = RankedTensorType::get( + ArrayRef{resultType.getShape()[0] * warpsPerCTA[0]}, + wrapResultType.getElementType(), encoding); + rewriter.create(loc, resCurWarp, adaptor.getSrc()); + /// store res of every warp to shared memry + auto sharedResMem = storeToSharedMem( + rewriter, tag, sharedMemTensorType, resCurWarp, false, + histogramOp.getOperation(), userAnalysis, replaced2Origin); + rewriter.create(loc, resCurWarp); + size_t allResSize = resultType.getShape()[0]; + size_t warpResSize = wrapResultType.getShape()[0]; + auto finalOutput = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, wrapResultType); + doMemset(rewriter, histogramOp, finalOutput, zero, totalNumElems); + size_t warpsWalkNum = warpsPerCTA[0]; + // if input can't be divided by warp, do not calculate sum of res of every + // warp + if (dyn_cast(histogramOp.getOperand().getType()).getShape()[0] < + warpsPerCTA[0]) + warpsWalkNum = dyn_cast(histogramOp.getOperand().getType()) + .getShape()[0]; + /// Compute the results in shared memory based on the output each warp + /// should produce + auto warpIdsOfRes = getWarpIds(rewriter, loc, resultType); + scf::buildLoopNest( + rewriter, loc, SmallVector{zeroIndex}, + SmallVector{ + rewriter.create(loc, warpResSize)}, + SmallVector{oneIndex}, + [&](OpBuilder &builder, Location loc, ValueRange gramIndex) { + auto res = builder.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0)); + SmallVector iterArgs = {res}; + builder.create( + loc, zeroIndex, + builder.create(loc, warpsWalkNum), + oneIndex, iterArgs, + [&](OpBuilder &builder, Location loc, Value warpId, + ValueRange sum) { + auto baseIndexOfRes = builder.create( + loc, warpIdsOfRes[0], + builder.create(loc, warpResSize)); + auto index = builder.create( + loc, + builder.create(loc, gramIndex[0], + baseIndexOfRes), + builder.create( + loc, warpId, + builder.create(loc, + allResSize))); + auto warpRes = builder.create( + loc, sharedResMem, SmallVector{index}); + Value newSum = + builder.create(loc, sum[0], warpRes); + builder.create(loc, newSum, finalOutput, + gramIndex[0]); + builder.create(loc, ValueRange{newSum}); + }); + }); + rewriter.create(loc); + leaveTritionOp(rewriter, histogramOp.getOperation()); + rewriter.replaceOp(histogramOp, finalOutput); + return success(); + } +}; + +struct GCULoadOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::gcu::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, loadOp.getOperation()); + auto loc = loadOp.getLoc(); + auto loadType = loadOp.getType(); + + if (!isa(loadType)) + return failure(); + + auto originOp = loadOp.getOperation(); + if (replaced2Origin.count(originOp) != 0) { + originOp = replaced2Origin[originOp]; + } + auto lastUser = userAnalysis.getLastUserOp(originOp); + + auto zero = rewriter.create(loc, 0); + auto elemType = loadOp.getPtr().getType().getElementType(); + auto resultType = + dyn_cast(getTypeConverter()->convertType(loadType)); + bool IsShareOutput = false; // output is shared layout + if (auto tType = dyn_cast(loadType)) + if (mlir::isa(tType.getEncoding())) + IsShareOutput = true; + + auto output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultType); + auto outTransType = + MemRefType::get(resultType.getShape(), resultType.getElementType()); + auto outTrans = syncAllocOp(rewriter, loc, loadOp.getOperation(), + userAnalysis, replaced2Origin, outTransType); + + auto tagShare = getShareDTETag(rewriter, loadOp); + auto tagPrivate = getPrivateDTETag(rewriter, loadOp); + auto tagDte = IsShareOutput ? tagShare : tagPrivate; + auto defaultValue = + loadOp.getDefaultValue() + ? adaptor.getDefaultValue() + : triton::gcu::createConstantZero(rewriter, loc, elemType); + + // workaround for offset > tensor dims + int64_t rank = resultType.getRank(); + Value shapeCheck = rewriter.create( + loc, arith::CmpIPredicate::sgt, adaptor.getShape()[0], zero); + for (unsigned i = 1; i < rank; ++i) { + auto dimCheck = rewriter.create( + loc, arith::CmpIPredicate::sgt, adaptor.getShape()[i], zero); + shapeCheck = rewriter.create(loc, shapeCheck, dimCheck); + } + + auto total_size = + rewriter + .create( + loc, shapeCheck, + [&](OpBuilder builder, Location loc) { + auto load_size = + ConfigGcuLoad(builder, loc, output, outTrans, loadOp, + resultType, adaptor.getPtr(), + adaptor.getStrides(), adaptor.getShape(), + defaultValue, tagDte, zero, IsShareOutput); + builder.create(loc, ValueRange{load_size}); + }, + [&](OpBuilder &builder, Location loc) { + auto totalNumElems = + triton::gcu::getTotalElemsPerThread(loadType); + doMemset(builder, loadOp, output, defaultValue, + totalNumElems); + if (triton::gcu::get_bool_env("TRITON_GCU_DEBUG")) { + std::string locStr = "[warning]: load offset is out of " + "range for tensor. loc:"; + if (auto fileLineColLoc = dyn_cast(loc)) { + llvm::StringRef filename = fileLineColLoc.getFilename(); + locStr += filename.str(); + locStr += ":"; + locStr += std::to_string(fileLineColLoc.getLine()); + } + builder.create(loc, locStr, ValueRange{}); + } + builder.create(loc, ValueRange{zero}); + }) + .getResult(0); + if (IsShareOutput) { + auto isThread0 = rewriter.create( + loc, arith::CmpIPredicate::eq, + rewriter.create(loc, gpu::Dimension::x), zero); + auto isAll = rewriter.create(loc, isThread0, shapeCheck); + rewriter.create( + loc, isAll, [&](OpBuilder builder, Location loc) { + WaitGcuLoadStore(builder, loc, tagDte, zero, total_size); + builder.create(loc); + }); + rewriter.create(loc); + } else { + rewriter.create( + loc, shapeCheck, [&](OpBuilder builder, Location loc) { + WaitGcuLoadStore(builder, loc, tagDte, zero, total_size); + builder.create(loc); + }); + } + + leaveTritionOp(rewriter, loadOp.getOperation()); + rewriter.replaceOp(loadOp, output); + return success(); + } +}; + +struct GCUStoreOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::gcu::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, storeOp.getOperation()); + + bool isLastOp = true; + auto loc = storeOp.getLoc(); + auto storeType = storeOp.getValue().getType(); + if (!isa(storeType)) + return failure(); + + auto zero = rewriter.create(loc, 0); + auto storeValueType = + dyn_cast(getTypeConverter()->convertType(storeType)); + auto storeTransType = MemRefType::get(storeValueType.getShape(), + storeValueType.getElementType()); + auto storeTrans = syncAllocOp(rewriter, loc, nullptr, userAnalysis, + replaced2Origin, storeTransType); + auto tagDte = isLastOp ? getPrivateDTETag(rewriter, storeOp) + : createPrivateDTETag(rewriter, storeOp); + + // workaround for offset > tensor dims + int64_t rank = storeType.getRank(); + Value shapeCheck = rewriter.create( + loc, arith::CmpIPredicate::sgt, adaptor.getShape()[0], zero); + for (unsigned i = 1; i < rank; ++i) { + auto dimCheck = rewriter.create( + loc, arith::CmpIPredicate::sgt, adaptor.getShape()[i], zero); + shapeCheck = rewriter.create(loc, shapeCheck, dimCheck); + } + auto total_size = + rewriter + .create( + loc, shapeCheck, + [&](OpBuilder builder, Location loc) { + auto store_size = ConfigGcuStore( + rewriter, loc, adaptor.getValue(), storeTrans, storeOp, + storeValueType, adaptor.getPtr(), adaptor.getStrides(), + adaptor.getShape(), tagDte, zero); + builder.create(loc, ValueRange{store_size}); + }, + [&](OpBuilder &builder, Location loc) { + if (triton::gcu::get_bool_env("TRITON_GCU_DEBUG")) { + std::string locStr = "[warning]: store offset is out of " + "range for tensor. loc:"; + if (auto fileLineColLoc = dyn_cast(loc)) { + llvm::StringRef filename = fileLineColLoc.getFilename(); + locStr += filename.str(); + locStr += ":"; + locStr += std::to_string(fileLineColLoc.getLine()); + } + builder.create(loc, locStr, ValueRange{}); + } + builder.create(loc, ValueRange{zero}); + }) + .getResult(0); + auto isNotZero = rewriter.create( + loc, arith::CmpIPredicate::ne, total_size, zero); + if (!isLastOp) { + auto &lastOp = storeOp.getOperation()->getBlock()->back(); + auto ip = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(&lastOp); + auto ifOp = rewriter.create( + loc, isNotZero, [&](OpBuilder builder, Location loc) { + WaitGcuLoadStore(builder, loc, tagDte, zero, total_size); + builder.create(loc); + }); + + if (!storeTrans.getUsers().empty()) { + rewriter.create(loc, storeTrans); + } else { + rewriter.eraseOp(storeTrans.getDefiningOp()); + } + + rewriter.restoreInsertionPoint(ip); + moveDeallocOp(rewriter, adaptor.getValue(), ifOp, 0); + } else { + rewriter.create( + loc, isNotZero, [&](OpBuilder builder, Location loc) { + WaitGcuLoadStore(builder, loc, tagDte, zero, total_size); + builder.create(loc); + }); + + if (!storeTrans.getUsers().empty()) { + rewriter.create(loc, storeTrans); + } else { + rewriter.eraseOp(storeTrans.getDefiningOp()); + } + } + + leaveTritionOp(rewriter, storeOp.getOperation()); + rewriter.eraseOp(storeOp); + return success(); + } +}; + +struct TTGAssertOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::gcu::AssertOp assertOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = assertOp.getLoc(); + + auto condition = adaptor.getCondition(); + auto message = assertOp.getMessage(); + auto file = assertOp.getFile(); + auto func = assertOp.getFunc(); + auto line = assertOp.getLine(); + + // Create gcu.assert op + rewriter.create(loc, condition, message, file, func, line); + rewriter.eraseOp(assertOp); + return success(); + } +}; + +struct TTBroadcastOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + auto srcType = op.getSrc().getType(); + auto resultType = op.getType(); + auto rank = srcType.getRank(); + auto wrapSrcType = + dyn_cast(getTypeConverter()->convertType(srcType)); + auto wrapResultType = + dyn_cast(getTypeConverter()->convertType(resultType)); + auto elementType = wrapResultType.getElementType(); + + auto loc = op.getLoc(); + auto tag = getPrivateDTETag(rewriter, op); + auto zero = rewriter.create(loc, 0); + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + + auto srcTy = dyn_cast(srcType); + auto dstTy = dyn_cast(resultType); + if ((!srcTy) || (!dstTy)) { + assert(false && "srcTy or dstTy not a RankedTensorType"); + } + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + + DenseSet broadcastedAxies; + for (unsigned i = 0; i < rank; ++i) { + if (srcType.getDimSize(i) != resultType.getDimSize(i)) { + if (wrapSrcType.getShape()[i] != wrapResultType.getShape()[i]) { + broadcastedAxies.insert(i); + } + } + } + // broadcast per thread + if (srcLayout == dstLayout) { + auto broadcastedAxiesNum = broadcastedAxies.size(); + if (broadcastedAxiesNum == 0) { + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } + ArrayRef srcShape = wrapSrcType.getShape(); + auto src_input = adaptor.getSrc(); + auto output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, wrapResultType); + SmallVector broadcastShape(rank, 1); + for (unsigned i = 0; i < rank; ++i) + broadcastShape[i] = srcShape[i]; + unsigned idx = 0; + for (auto dim : broadcastedAxies) { + auto temp_out = output; + if (idx != broadcastedAxiesNum - 1) { + broadcastShape[dim] = wrapResultType.getDimSize(dim); + auto memrefType = MemRefType::get(broadcastShape, elementType); + temp_out = syncAllocOp(rewriter, loc, op.getOperation(), userAnalysis, + replaced2Origin, memrefType); + } + + auto src = src_input; + auto dst = temp_out; + if (rank > 3) { // reshape to rank 3 to broadcast + ArrayRef beforeSrcShapes = + dyn_cast(src_input.getType()).getShape(); + ArrayRef beforeDstShapes = + dyn_cast(temp_out.getType()).getShape(); + SmallVector afterSrcShapes; + SmallVector afterDstShapes; + if (dim > 0) { + int64_t tShape = std::accumulate(beforeSrcShapes.begin(), + beforeSrcShapes.begin() + dim, 1, + std::multiplies()); + afterSrcShapes.push_back(tShape); + } + afterSrcShapes.push_back(beforeSrcShapes[dim]); + int64_t tShape = std::accumulate(beforeSrcShapes.begin() + dim + 1, + beforeSrcShapes.end(), 1, + std::multiplies()); + afterSrcShapes.push_back(tShape); + if (dim > 0) { + int64_t tShape = std::accumulate(beforeDstShapes.begin(), + beforeDstShapes.begin() + dim, 1, + std::multiplies()); + afterDstShapes.push_back(tShape); + } + afterDstShapes.push_back(beforeDstShapes[dim]); + tShape = std::accumulate(beforeDstShapes.begin() + dim + 1, + beforeDstShapes.end(), 1, + std::multiplies()); + afterDstShapes.push_back(tShape); + + auto afterSrcMemrefType = + MemRefType::get(afterSrcShapes, elementType); + auto afterDstMemrefType = + MemRefType::get(afterDstShapes, elementType); + + auto [srcStrides, srcOffset] = + afterSrcMemrefType.getStridesAndOffset(); + src = rewriter.create( + loc, afterSrcMemrefType, src_input, srcOffset, afterSrcShapes, + srcStrides); + auto [dstStrides, dstOffset] = + afterDstMemrefType.getStridesAndOffset(); + dst = rewriter.create( + loc, afterDstMemrefType, temp_out, dstOffset, afterDstShapes, + dstStrides); + } + auto totalNumElems = triton::gcu::getTotalElemsPerThread(srcType); + rewriter.create(loc, dst, src, tag, + ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, totalNumElems)); + + src_input = temp_out; + idx++; + } + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } + // move source to shared memory + auto sharedSrc = + storeToSharedMem(rewriter, tag, srcType, adaptor.getSrc(), false, + op.getOperation(), userAnalysis, replaced2Origin); + auto mergedResultType = + MemRefType::get(resultType.getShape(), elementType, AffineMap{}, + rewriter.getI64IntegerAttr(2) /*shared memory*/); + auto mergedOutput = + syncAllocOp(rewriter, loc, op.getOperation(), userAnalysis, + replaced2Origin, mergedResultType); + auto totalNumElems = triton::gcu::getTotalElemsPerThread(srcType); + // broadcast in thread 0 + auto isThread0 = rewriter.create( + loc, arith::CmpIPredicate::eq, + rewriter.create(loc, gpu::Dimension::x), zero); + ArrayRef srcShape = srcType.getShape(); + auto src_input = sharedSrc; + + SmallVector broadcastShape(rank, 1); + for (unsigned i = 0; i < rank; ++i) + broadcastShape[i] = srcShape[i]; + + unsigned idx = 0; + for (auto dim : broadcastedAxies) { + auto temp_out = mergedOutput; + if (idx != broadcastedAxies.size() - 1) { + broadcastShape[dim] = resultType.getDimSize(dim); + auto tempMemrefType = + MemRefType::get(broadcastShape, elementType, AffineMap{}, + rewriter.getI64IntegerAttr(2) /*shared memory*/); + temp_out = syncAllocOp(rewriter, loc, op.getOperation(), userAnalysis, + replaced2Origin, tempMemrefType); + } + + auto src = src_input; + auto dst = temp_out; + if (rank > 3) { // reshape to rank 3 to broadcast + ArrayRef beforeSrcShapes = + dyn_cast(src_input.getType()).getShape(); + ArrayRef beforeDstShapes = + dyn_cast(temp_out.getType()).getShape(); + SmallVector afterSrcShapes; + SmallVector afterDstShapes; + + int64_t tShape = std::accumulate(beforeSrcShapes.begin(), + beforeSrcShapes.begin() + dim, 1, + std::multiplies()); + afterSrcShapes.push_back(tShape); + afterSrcShapes.push_back(beforeSrcShapes[dim]); + tShape = std::accumulate(beforeSrcShapes.begin() + dim + 1, + beforeSrcShapes.end(), 1, + std::multiplies()); + afterSrcShapes.push_back(tShape); + + tShape = std::accumulate(beforeDstShapes.begin(), + beforeDstShapes.begin() + dim, 1, + std::multiplies()); + afterDstShapes.push_back(tShape); + afterDstShapes.push_back(beforeDstShapes[dim]); + tShape = std::accumulate(beforeDstShapes.begin() + dim + 1, + beforeDstShapes.end(), 1, + std::multiplies()); + afterDstShapes.push_back(tShape); + + auto afterSrcMemrefType = + MemRefType::get(afterSrcShapes, elementType, AffineMap{}, + rewriter.getI64IntegerAttr(2) /*shared memory*/); + auto afterDstMemrefType = + MemRefType::get(afterDstShapes, elementType, AffineMap{}, + rewriter.getI64IntegerAttr(2) /*shared memory*/); + + auto [srcStrides, srcOffset] = afterSrcMemrefType.getStridesAndOffset(); + src = rewriter.create( + loc, afterSrcMemrefType, src_input, srcOffset, afterSrcShapes, + srcStrides); + auto [dstStrides, dstOffset] = afterDstMemrefType.getStridesAndOffset(); + dst = rewriter.create( + loc, afterDstMemrefType, temp_out, dstOffset, afterDstShapes, + dstStrides); + } + + rewriter.create( + loc, isThread0, [&](OpBuilder &rewriter, Location loc) { + rewriter.create(loc, dst, src, tag, + ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, totalNumElems)); + rewriter.create(loc); + }); + src_input = temp_out; + idx++; + } + rewriter.create(loc); + // read back + auto output = + loadFromSharedMem(rewriter, tag, resultType, mergedOutput, false, + lastUser, nullptr, userAnalysis, replaced2Origin); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } +}; + +struct TTExpandDimsOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + auto loc = op.getLoc(); + auto resultType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto srcNumElems = triton::gcu::getElemsPerThread(op.getSrc().getType()); + auto dstNumElems = triton::gcu::getElemsPerThread(op.getType()); + + srcNumElems.insert(srcNumElems.begin() + op.getAxis(), 1); + + // noop expand dims + if (srcNumElems == dstNumElems) { + auto [strides, offset] = resultType.getStridesAndOffset(); + auto output = rewriter.create( + loc, resultType, adaptor.getSrc(), offset, resultType.getShape(), + strides); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } + auto type = op.getType(); + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto tag = getPrivateDTETag(rewriter, op); + auto srcType = dyn_cast(op.getSrc().getType()); + auto resMemType = + MemRefType::get(type.getShape(), resultType.getElementType(), + AffineMap{}, rewriter.getI64IntegerAttr(2)); + // move source to shared memory + auto sharedSrc = + storeToSharedMem(rewriter, tag, srcType, adaptor.getSrc(), false, + op.getOperation(), userAnalysis, replaced2Origin); + auto [strides, offset] = resMemType.getStridesAndOffset(); + auto result = rewriter.create( + loc, resMemType, sharedSrc, offset, type.getShape(), strides); + // copy back outputs + Value output = + loadFromSharedMem(rewriter, tag, op.getType(), result, false, lastUser, + nullptr, userAnalysis, replaced2Origin); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } +}; + +struct TTReshapeOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + auto loc = op.getLoc(); + auto resultType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto srcNumElems = triton::gcu::getElemsPerThread(op.getSrc().getType()); + auto dstNumElems = triton::gcu::getElemsPerThread(op.getType()); + + // noop expand dims + if (srcNumElems == dstNumElems) { + auto [strides, offset] = resultType.getStridesAndOffset(); + auto output = rewriter.create( + loc, resultType, adaptor.getSrc(), offset, resultType.getShape(), + strides); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } + auto type = op.getType(); + + auto tag = getPrivateDTETag(rewriter, op); + auto srcType = dyn_cast(op.getSrc().getType()); + // move source to shared memory + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto sharedSrc = + storeToSharedMem(rewriter, tag, srcType, adaptor.getSrc(), false, + op.getOperation(), userAnalysis, replaced2Origin); + auto resMemType = + MemRefType::get(type.getShape(), resultType.getElementType(), + AffineMap{}, rewriter.getI64IntegerAttr(2)); + auto [strides, offset] = resMemType.getStridesAndOffset(); + auto result = rewriter.create( + loc, resMemType, sharedSrc, offset, type.getShape(), strides); + // copy back outputs + Value output = + loadFromSharedMem(rewriter, tag, op.getType(), result, false, lastUser, + nullptr, userAnalysis, replaced2Origin); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } +}; + +struct TTSplitOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + auto loc = op.getLoc(); + auto srcType = dyn_cast(op.getSrc().getType()); + auto srcShape = srcType.getShape(); + auto srcRank = srcType.getRank(); + if (srcRank <= 0) + return op.emitOpError("the rank must be greater than 0."); + if (srcShape[srcRank - 1] != 2) + return op.emitOpError("the last dim must have size 2."); + + auto outType = dyn_cast(op.getOutLHS().getType()); + auto outMemrefType = + dyn_cast(getTypeConverter()->convertType(outType)); + + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto lhs = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, outMemrefType); + auto rhs = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, outMemrefType); + + auto outMemrefShape = outMemrefType.getShape(); + SmallVector sliceShape(outMemrefShape.size() + 1, 1); + for (long unsigned int i = 0; i < outMemrefShape.size(); i++) { + sliceShape[i] = outMemrefShape[i]; + } + SmallVector sliceStride(sliceShape.size(), 1); + for (int i = sliceShape.size() - 2; i >= 0; --i) { + sliceStride[i] = sliceStride[i + 1] * sliceShape[i + 1]; + } + + auto sliceType = + MemRefType::get(sliceShape, outMemrefType.getElementType()); + + auto sliceLHS = rewriter.create( + loc, sliceType, lhs, 0, sliceShape, sliceStride); + auto sliceRHS = rewriter.create( + loc, sliceType, rhs, 0, sliceShape, sliceStride); + + auto tag = getPrivateDTETag(rewriter, op); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + + SmallVector offsets; + for (int i = 0; i < outType.getRank(); ++i) { + offsets.push_back(rewriter.create( + loc, rewriter.getI32Type(), zero)); + } + SmallVector offsetsLHS = offsets; + SmallVector offsetsRHS = offsets; + offsetsLHS.push_back( + rewriter.create(loc, rewriter.getI32Type(), zero)); + offsetsRHS.push_back( + rewriter.create(loc, rewriter.getI32Type(), one)); + + auto totalNumElems = triton::gcu::getTotalElemsPerThread(outType); + auto defaultValue = triton::gcu::createConstantZero( + rewriter, loc, outMemrefType.getElementType()); + + rewriter.create(loc, sliceLHS, adaptor.getSrc(), + offsetsLHS, defaultValue, tag, + ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, totalNumElems)); + + rewriter.create(loc, sliceRHS, adaptor.getSrc(), + offsetsRHS, defaultValue, tag, + ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, totalNumElems)); + + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, {lhs, rhs}); + return success(); + } +}; + +struct TTJoinOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + auto loc = op.getLoc(); + + auto lhsType = dyn_cast(op.getLhs().getType()); + auto rhsType = dyn_cast(op.getRhs().getType()); + if (lhsType != rhsType) + return op.emitOpError("the lhs and rhs type must be the same."); + + auto lhsMemrefType = + dyn_cast(getTypeConverter()->convertType(lhsType)); + + auto outType = dyn_cast(op.getResult().getType()); + auto outMemrefType = + dyn_cast(getTypeConverter()->convertType(outType)); + + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto result = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, outMemrefType); + + auto lhsShape = lhsMemrefType.getShape(); + SmallVector desliceShape(lhsShape.size() + 1, 1); + for (size_t i = 0; i < lhsShape.size(); i++) { + desliceShape[i] = lhsShape[i]; + } + SmallVector desliceStride(desliceShape.size(), 1); + for (int i = desliceShape.size() - 2; i >= 0; --i) { + desliceStride[i] = desliceStride[i + 1] * desliceShape[i + 1]; + } + + auto desliceType = + MemRefType::get(desliceShape, lhsMemrefType.getElementType()); + auto desliceLHS = rewriter.create( + loc, desliceType, adaptor.getLhs(), 0, desliceShape, desliceStride); + auto desliceRHS = rewriter.create( + loc, desliceType, adaptor.getRhs(), 0, desliceShape, desliceStride); + + auto tag = getPrivateDTETag(rewriter, op); + + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + + SmallVector offsets; + for (int i = 0; i < lhsType.getRank(); ++i) { + offsets.push_back(rewriter.create( + loc, rewriter.getI32Type(), zero)); + } + SmallVector offsetsLHS = offsets; + SmallVector offsetsRHS = offsets; + offsetsLHS.push_back( + rewriter.create(loc, rewriter.getI32Type(), zero)); + offsetsRHS.push_back( + rewriter.create(loc, rewriter.getI32Type(), one)); + + auto totalNumElems = triton::gcu::getTotalElemsPerThread(lhsType); + + rewriter.create( + loc, result, desliceLHS, offsetsLHS, tag, ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, totalNumElems)); + + rewriter.create( + loc, result, desliceRHS, offsetsRHS, tag, ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, totalNumElems)); + + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +struct TTCatOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto type = op.getType(); + auto loc = op.getLoc(); + auto resultType = + dyn_cast(getTypeConverter()->convertType(type)); + + auto tag = getPrivateDTETag(rewriter, op); + auto zero = rewriter.create(loc, 0); + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto lhsSlicedAxies = getSlicedAxies(op.getLhs().getType()); + auto rhsSlicedAxies = getSlicedAxies(op.getRhs().getType()); + auto outputSlicedAxies = getSlicedAxies(op.getType()); + if (!lhsSlicedAxies.count(0) && !rhsSlicedAxies.count(0) && + !outputSlicedAxies.count(0)) { + auto totalNumElems = triton::gcu::getTotalElemsPerThread(type); + + auto output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultType); + SmallVector offsets; + for (unsigned i = 0; i < resultType.getRank(); ++i) { + offsets.push_back(rewriter.create( + loc, rewriter.getI32Type(), zero)); + } + rewriter.create( + loc, output, adaptor.getLhs(), offsets, tag, ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, totalNumElems)); + + offsets[0] = rewriter.create( + loc, dyn_cast(adaptor.getLhs().getType()).getDimSize(0), + rewriter.getI32Type()); + rewriter.create( + loc, output, adaptor.getRhs(), offsets, tag, ValueRange{zero}); + rewriter.create( + loc, tag, ValueRange{zero}, + rewriter.create(loc, totalNumElems)); + rewriter.replaceOp(op, output); + return success(); + } + auto mergedResultType = + MemRefType::get(type.getShape(), type.getElementType(), AffineMap{}, + rewriter.getI64IntegerAttr(2) /*shared memory*/); + auto mergedOutput = + syncAllocOp(rewriter, loc, op.getOperation(), userAnalysis, + replaced2Origin, mergedResultType); + auto lhsTy = op.getLhs().getType(); + auto [lhsStrides, lhsOffset] = + dyn_cast(getTypeConverter()->convertType(lhsTy)) + .getStridesAndOffset(); + storeToSharedMem( + rewriter, tag, op.getLhs().getType(), + rewriter.create( + loc, + MemRefType::get(lhsTy.getShape(), lhsTy.getElementType(), + AffineMap{}, rewriter.getI64IntegerAttr(2)), + mergedOutput, 0, lhsTy.getShape(), lhsStrides), + adaptor.getLhs(), false); + (void)lhsOffset; + + auto rhsTy = op.getRhs().getType(); + auto [rhsStrides, rhsOffset] = + dyn_cast(getTypeConverter()->convertType(rhsTy)) + .getStridesAndOffset(); + storeToSharedMem( + rewriter, tag, op.getRhs().getType(), + rewriter.create( + loc, + MemRefType::get(rhsTy.getShape(), rhsTy.getElementType(), + makeStridedLinearLayoutMap(rhsStrides, + rhsTy.getNumElements(), + rewriter.getContext()), + rewriter.getI64IntegerAttr(2)), + mergedOutput, rhsTy.getNumElements(), rhsTy.getShape(), rhsStrides), + adaptor.getRhs(), false); + (void)rhsOffset; + // read back + auto output = + loadFromSharedMem(rewriter, tag, op.getType(), mergedOutput, false, + lastUser, nullptr, userAnalysis, replaced2Origin); + rewriter.replaceOp(op, output); + return success(); + } +}; + +struct TTTransOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + void applyTranspose(OpBuilder &rewriter, Location loc, Value src, + Value output, Value tag, ArrayRef order, + unsigned totalSize) const { + auto zero = rewriter.create(loc, 0); + auto totalNumElems = + rewriter.create(loc, totalSize); + + SmallVector layout; + for (auto i : order) { + layout.push_back( + rewriter.create(loc, i, rewriter.getI32Type())); + } + rewriter.create(loc, output, src, layout, tag, + ValueRange{zero}); + rewriter.create(loc, tag, ValueRange{zero}, + totalNumElems); + } + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + auto loc = op.getLoc(); + auto srcTy = dyn_cast(op.getSrc().getType()); + auto dstTy = dyn_cast(op.getType()); + if ((!srcTy) || (!dstTy)) { + assert(false && "srcTy or dstTy not a RankedTensorType"); + } + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto resultType = dyn_cast( + getTypeConverter()->convertType(op.getResult().getType())); + auto zero = rewriter.create(loc, 0); + auto totalNumElems = + triton::gcu::getTotalElemsPerThread(op.getSrc().getType()); + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + // gcu400 only one private dte + if (mlir::isa(srcLayout) && + mlir::isa(dstLayout)) { + // allocate output buffers in shared memory + auto firstUser = nullptr; + auto tag = (firstUser == nullptr) ? getPrivateDTETag(rewriter, op) + : createPrivateDTETag(rewriter, op); + auto sharedOutputType = MemRefType::get( + op.getResult().getType().getShape(), resultType.getElementType(), + AffineMap{}, rewriter.getI64IntegerAttr(2)); + auto sharedOutput = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, sharedOutputType); + // split by thread 0 + auto totalNumElemsValue = + rewriter.create(loc, totalNumElems); + + SmallVector layout; + for (auto i : op.getOrder()) { + layout.push_back(rewriter.create( + loc, i, rewriter.getI32Type())); + } + auto isThread0 = rewriter.create( + loc, arith::CmpIPredicate::eq, + rewriter.create(loc, gpu::Dimension::x), zero); + rewriter.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + rewriter.create( + loc, sharedOutput, adaptor.getSrc(), layout, tag, + ValueRange{zero}); + builder.create(loc); + }); + if (firstUser != nullptr) { + auto ip = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(firstUser); + rewriter.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + builder.create(loc, tag, ValueRange{zero}, + totalNumElemsValue); + builder.create(loc); + }); + rewriter.create(loc); + rewriter.restoreInsertionPoint(ip); + } else { + rewriter.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + builder.create(loc, tag, ValueRange{zero}, + totalNumElemsValue); + builder.create(loc); + }); + rewriter.create(loc); + } + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, sharedOutput); + return success(); + } else if (isa(srcLayout) && + isa(dstLayout)) { + // move source to shared memory + auto tag = getPrivateDTETag(rewriter, op); + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto sharedSrc = storeToSharedMem( + rewriter, tag, dyn_cast(op.getSrc().getType()), + adaptor.getSrc(), false, op.getOperation(), userAnalysis, + replaced2Origin); + + // allocate output buffers in shared memory + auto sharedOutputType = MemRefType::get( + op.getResult().getType().getShape(), resultType.getElementType(), + AffineMap{}, rewriter.getI64IntegerAttr(2)); + auto sharedOutput = + syncAllocOp(rewriter, loc, op.getOperation(), userAnalysis, + replaced2Origin, sharedOutputType); + + // split by thread 0 + auto isThread0 = rewriter.create( + loc, arith::CmpIPredicate::eq, + rewriter.create(loc, gpu::Dimension::x), zero); + rewriter.create( + loc, isThread0, [&](OpBuilder &builder, Location loc) { + applyTranspose(builder, loc, sharedSrc, sharedOutput, tag, + op.getOrder(), totalNumElems); + builder.create(loc); + }); + rewriter.create(loc); + // copy back outputs + Value output = loadFromSharedMem(rewriter, tag, op.getResult().getType(), + sharedOutput, false, lastUser, nullptr, + userAnalysis, replaced2Origin); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } else { + op.dump(); + assert(false && "please check layout of this transop \n"); + return failure(); + } + } +}; + +struct TTGConvertLayoutOpLowering + : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + enterTritionOp(rewriter, op.getOperation()); + auto srcNumElems = triton::gcu::getElemsPerThread(op.getSrc().getType()); + auto dstNumElems = triton::gcu::getElemsPerThread(op.getType()); + // noop convert + auto srcTy = dyn_cast(op.getSrc().getType()); + auto dstTy = dyn_cast(op.getType()); + if ((!srcTy) || (!dstTy)) { + assert(false && "srcTy or dstTy not a RankedTensorType"); + } + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + if (srcLayout == dstLayout) { + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto firstUser = nullptr; + auto tag = (firstUser == nullptr) ? getPrivateDTETag(rewriter, op) + : createPrivateDTETag(rewriter, op); + if (srcNumElems == dstNumElems && + op.getSrc().getType().getShape() == op.getType().getShape()) { + if (mlir::isa(srcLayout) && + isa(dstLayout)) { + // give up L2 to matmul because 1:acore crash 2:L2 latency is more + // 100cyle than L1 we don't had enough resource to refine latency + } else if (isa(srcLayout) && + isa(dstLayout)) { + if (cast(srcLayout).getDim() == + cast(dstLayout).getDim()) { + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } + } else { + if (mlir::isa(srcLayout)) { + auto output = CopyFromSharedMem( + rewriter, tag, op.getResult().getType(), adaptor.getSrc(), false, + lastUser, firstUser, userAnalysis, replaced2Origin); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } + } + // share to Distributed + if (mlir::isa(srcLayout) && + isa(dstLayout)) { + // copy to local + auto output = loadFromSharedMem(rewriter, tag, op.getResult().getType(), + adaptor.getSrc(), false, lastUser, + firstUser, userAnalysis, replaced2Origin); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } else if (isa(srcLayout) && + isa(dstLayout)) { + // Distributed to dot operand + auto sharedSrc = storeToSharedMem( + rewriter, tag, dyn_cast(op.getSrc().getType()), + adaptor.getSrc(), false, op.getOperation(), userAnalysis, + replaced2Origin); + // to dot a or b calculate warp idx + auto output = loadFromSharedMemForDotOperand( + rewriter, tag, op.getResult().getType(), + op.getSrc().getType().getShape(), sharedSrc, lastUser, firstUser, + userAnalysis, replaced2Origin); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } else if (mlir::isa(srcLayout) && + isa(dstLayout)) { + // Distributed to dot operand + // to dot a or b + auto output = loadFromSharedMemForDotOperand( + rewriter, tag, op.getResult().getType(), + op.getSrc().getType().getShape(), adaptor.getSrc(), lastUser, + firstUser, userAnalysis, replaced2Origin); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } else { + // move source to shared memory + auto sharedSrc = storeToSharedMem( + rewriter, tag, dyn_cast(op.getSrc().getType()), + adaptor.getSrc(), false, op.getOperation(), userAnalysis, + replaced2Origin); + // copy back outputs + auto output = loadFromSharedMem(rewriter, tag, op.getResult().getType(), + sharedSrc, false, lastUser, firstUser, + userAnalysis, replaced2Origin); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + } + return success(); + } +}; + +struct GCUMatmulLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::gcu::MatmulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + enterTritionOp(rewriter, op.getOperation()); + if (!isa(op.getA().getType()) || + !isa(op.getB().getType())) + return failure(); + if (op.getType().getRank() != 2) { + llvm::report_fatal_error( + "triton::gcu::MatmulOp no bias not support 3D or more 3D dot \n"); + } + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto resultMemRefType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultMemRefType); + rewriter.create(loc, output, adaptor.getA(), adaptor.getB(), + Value()); + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } +}; + +struct TTDotOpLowering : SharedConversionPattern { + using SharedConversionPattern::SharedConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + enterTritionOp(rewriter, op.getOperation()); + if (!isa(op.getA().getType()) || + !isa(op.getB().getType())) + return failure(); + auto lastUser = userAnalysis.getLastUserOp(op.getOperation()); + auto resultMemRefType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto output = syncAllocOp(rewriter, loc, lastUser, userAnalysis, + replaced2Origin, resultMemRefType); + if (op.getType().getRank() == 2) { + rewriter.create(loc, output, adaptor.getA(), + adaptor.getB(), adaptor.getC()); + } else { + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto lhsMemRef = adaptor.getA(); + auto lhsMemRefType = dyn_cast(lhsMemRef.getType()); + auto rhsMemRef = adaptor.getB(); + auto rhsMemRefType = dyn_cast(rhsMemRef.getType()); + auto biasMemRef = adaptor.getC(); + auto biasMemRefType = dyn_cast(biasMemRef.getType()); + int64_t batchNum = lhsMemRefType.getShape()[0]; + + auto createFlattened1DMemRef = [&](Value memRef, MemRefType memRefType) { + auto elementType = memRefType.getElementType(); + int64_t size = 1; + for (int i = 0; i < memRefType.getRank(); i++) { + size *= memRefType.getShape()[i]; + } + // Create flattened buffer + MemRefType flatType = MemRefType::get({size}, elementType); + Value flatBuffer = rewriter.create( + loc, flatType, memRef, zero, + ValueRange{rewriter.create(loc, size)}, + ValueRange{one}); + + // Convert flattened buffer to 1D MemRef + auto ptrType = gcu::PtrType::get(getContext(), elementType); + Value ptr = + rewriter.create(loc, ptrType, flatBuffer); + MemRefType memType1D = + MemRefType::get({ShapedType::kDynamic}, rewriter.getI8Type()); + return rewriter.create(loc, memType1D, ptr); + }; + + // Create 1D MemRefs for lhs, rhs, bias, and output + Value lhsBuffer = createFlattened1DMemRef(lhsMemRef, lhsMemRefType); + Value rhsBuffer = createFlattened1DMemRef(rhsMemRef, rhsMemRefType); + Value biasBuffer = createFlattened1DMemRef(biasMemRef, biasMemRefType); + Value outBuffer = createFlattened1DMemRef(output, resultMemRefType); + auto bitWidthOfInt8 = rewriter.getI8Type().getIntOrFloatBitWidth(); + scf::buildLoopNest( + rewriter, loc, ValueRange{zero}, + ValueRange{rewriter.create(loc, batchNum)}, + ValueRange{one}, + [&](OpBuilder &rewriter, Location loc, ValueRange m) { + auto createViewWithOffset = [&](MemRefType memRefType, + Value buffer) { + int64_t tailIndex = memRefType.getRank() - 1; + int64_t dim0 = memRefType.getShape()[tailIndex - 1]; + int64_t dim1 = memRefType.getShape()[tailIndex]; + auto elementType = memRefType.getElementType(); + int64_t elementSize = + elementType.getIntOrFloatBitWidth() / bitWidthOfInt8; + Value offset = rewriter.create( + loc, m[0], + rewriter.create( + loc, dim0 * dim1 * elementSize)); + return rewriter.create( + loc, MemRefType::get({dim0, dim1}, elementType), buffer, + offset, ValueRange{}); + }; + + Value newLhsMemRef = createViewWithOffset(lhsMemRefType, lhsBuffer); + Value newRhsMemRef = createViewWithOffset(rhsMemRefType, rhsBuffer); + Value newBiasMemRef = + createViewWithOffset(biasMemRefType, biasBuffer); + Value newOutMemRef = + createViewWithOffset(resultMemRefType, outBuffer); + rewriter.create(loc, newOutMemRef, newLhsMemRef, + newRhsMemRef, newBiasMemRef); + }); + } + leaveTritionOp(rewriter, op.getOperation()); + rewriter.replaceOp(op, output); + return success(); + } +}; + +} // namespace + +void ConvertTritonToGCUPass::runOnOperation() { + auto *ctx = &getContext(); + auto module = getOperation(); + + // pre analysis base triton ir + triton::gcu::FirstLastUserAnalysis &userAnalysis = + getAnalysis(); + + std::map replaced2Origin; + replaced2Origin.clear(); + + std::map asyncLoad2Tag; + std::map asyncWait2Tag; + llvm::DenseMap asyncLoad2TagIdex; + getPipelineAsyncResourceMaping(module, asyncLoad2Tag, asyncLoad2TagIdex, + asyncWait2Tag); + std::map> + TTYeiledOPerandHasMultiUseStage; + AnalysisYieldOperendUseStage(module, userAnalysis, + TTYeiledOPerandHasMultiUseStage); + + RewritePatternSet patterns(ctx); + // define converter + TypeConverter converter; + // default + converter.addConversion([](Type type) { return type; }); + converter.addConversion([](mlir::triton::gcu::PtrType type) { + return gcu::PtrType::get(type.getContext(), type.getElementType()); + }); + // // pointer type + // converter.addConversion([](triton::PointerType ptrType) -> Type { + // if (auto ty = dyn_cast(ptrType.getPointeeType())) + // return mlir::triton::gcu::TileDescType::get(ty.getContext(), ty); + // return mlir::triton::gcu::PtrType::get(ptrType.getContext(), + // ptrType.getPointeeType()); + // }); + // pointer type + converter.addConversion([](triton::PointerType ptrType) -> Type { + if (auto ty = dyn_cast(ptrType.getPointeeType())) + return mlir::gcu::TileDescType::get(ty.getContext(), ty); + return gcu::PtrType::get(ptrType.getContext(), ptrType.getPointeeType()); + }); + // tensor type + converter.addConversion([&](TensorType tensorType) { + auto numElems = triton::gcu::getElemsPerThread(tensorType); + SmallVector shape(numElems.begin(), numElems.end()); + auto elemType = converter.convertType(tensorType.getElementType()); + // todo_AT weird ptr + if (isa(elemType) || + isa(elemType)) + // use i64 for pointer type + elemType = IntegerType::get(tensorType.getContext(), 64); + if (auto tType = dyn_cast(tensorType)) { + if (mlir::isa(tType.getEncoding())) { + return MemRefType::get( + shape, elemType, AffineMap{}, + IntegerAttr::get(IntegerType::get(tensorType.getContext(), 64), 2)); + } + } + return MemRefType::get(shape, elemType); + }); + + converter.addConversion([&](triton::gpu::MemDescType bufferType) { + auto elemType = converter.convertType(bufferType.getElementType()); + return MemRefType::get( + bufferType.getShape(), elemType, AffineMap{}, + IntegerAttr::get(IntegerType::get(bufferType.getContext(), 64), 2)); + }); + converter.addConversion([&](triton::gpu::AsyncTokenType tokenType) { + return IntegerType::get(tokenType.getContext(), 32); + }); + ConversionTarget target(getContext()); + + mlir::triton::populateReduceOpToGCUPatterns(converter, patterns, userAnalysis, + replaced2Origin); + mlir::triton::populateElementwiseFusionOpToGCUPatterns( + converter, patterns, userAnalysis, replaced2Origin); + + patterns + .add, + TTIntrinsicOpLowering, + TTPrintOpLowering, TTAssertOpLowering, TTAddPtrOpLowering, + TTLoadOpLowering, TTStoreOpLowering, TTConstantOpLowering, + TTReduceReturnOpLowering, TTScanReturnOpLowering, + TTExternElemwiseOpLowering, + TTElementwiseOpLowering, + TTElementwiseOpLowering, + TTElementwiseOpLowering, + TTElementwiseOpLowering, + TTElementwiseOpLowering, + TTArithSelectOpLowering, TTBitcastOpLowering, TTBroadcastOpLowering, + TTCatOpLowering, TTHistogramOpLowering, TTExpandDimsOpLowering, + TTReshapeOpLowering, TTSplitOpLowering, TTJoinOpLowering, + GCUMatmulLowering, TTGAssertOpLowering, TTTransOpLowering, + TTGConvertLayoutOpLowering, GCULoadOpLowering, GCUStoreOpLowering, + TTDotOpLowering, TTSplatOpLowering>(converter, ctx, userAnalysis, + replaced2Origin); + + patterns.add(converter, ctx, userAnalysis, replaced2Origin, + vectorLength); + patterns.add(converter, ctx, userAnalysis, + replaced2Origin, vectorLength, + vectorizationMaxLength); + patterns.add(converter, ctx, userAnalysis, + replaced2Origin, + TTYeiledOPerandHasMultiUseStage); + + patterns.add(converter, ctx); + + patterns.add(converter, ctx, userAnalysis, + replaced2Origin); + + patterns.add( + converter, ctx, asyncLoad2Tag, asyncLoad2TagIdex); + patterns.add(converter, ctx, asyncWait2Tag); + + target.addLegalDialect< + gpu::GPUDialect, gcu::GCUDialect, arith::ArithDialect, + affine::AffineDialect, func::FuncDialect, scf::SCFDialect, + math::MathDialect, vector::VectorDialect, memref::MemRefDialect, + memref_ext::MemrefExtDialect, math_ext::MathExtDialect>(); + target.addIllegalDialect(); + target.addIllegalOp(); + target.addDynamicallyLegalDialect([](Operation *op) { + return llvm::none_of(op->getOperandTypes(), + [](auto t) { + return isa(t); + }) && + llvm::none_of(op->getResultTypes(), [](auto t) { + return isa( + t); + }); + }); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/Utils.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/Utils.cpp new file mode 100644 index 000000000..7481e2a61 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/Utils.cpp @@ -0,0 +1,260 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Utils.h" +#include +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/MathExtras.h" +#define DEBUG_TYPE "triton-ir-to-gcu-ir-util" +namespace mlir { +namespace triton { +namespace gcu { + +bool get_bool_env(const char *name) { + const char *value = std::getenv(name); + if (value == nullptr) { + return false; + } + std::string str_value(value); + std::transform(str_value.begin(), str_value.end(), str_value.begin(), + ::tolower); + return (str_value == "true" || str_value == "1" || str_value == "on" || + str_value == "yes"); +} + +SmallVector getWarpsPerCTA(Attribute layout) { + if (auto blockEnc = dyn_cast(layout)) { + return blockEnc.getWarpsPerCTA(); + + } else if (auto sliceEnc = dyn_cast(layout)) { + auto parent = sliceEnc.getParent(); + SmallVector sliceDims; + sliceDims.push_back(sliceEnc.getDim()); + while (auto innerSliceEnc = + dyn_cast(parent)) { + auto curSliceDim = innerSliceEnc.getDim(); + for (size_t idx = 0; idx < sliceDims.size(); idx++) { + if (sliceDims[idx] >= curSliceDim) { + sliceDims[idx] = sliceDims[idx] + 1; + } + } + sliceDims.push_back(curSliceDim); + parent = innerSliceEnc.getParent(); + } + if (!isa(parent)) { + llvm::report_fatal_error("[Error] bad slice layout parent"); + assert(false && "bad slice layout parent"); + triton::gpu::getWarpsPerCTA(layout); + } + auto blockEncParent = dyn_cast(parent); + auto parentWarpsPerCTA = blockEncParent.getWarpsPerCTA(); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < parentWarpsPerCTA.size(); ++i) { + if (!llvm::is_contained(sliceDims, i)) { + warpsPerCTA.push_back(parentWarpsPerCTA[i]); + } + } + return warpsPerCTA; + + } else { + return triton::gpu::getWarpsPerCTA(layout); + } +} + +SmallVector getElemsPerThread(Type type) { + if (auto tType = dyn_cast(type)) { + if (auto dotEnc = dyn_cast( + tType.getEncoding())) { + // dot lhs and rhs should have different slicing by op id but + // DotOperandEncodingAttr no supported and currently support 2D dot first + auto shape = tType.getShape(); + if (auto blockedLayout = + dyn_cast(dotEnc.getParent())) { + auto rank = shape.size(); + SmallVector elemsPerthread(rank, 1); + // low 2 rank do dot + auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); + for (unsigned idx = 0; idx < rank - 2; idx++) { + elemsPerthread[idx] = shape[idx]; + if (warpsPerCTA[idx] > 1) { + LLVM_DEBUG({ + llvm::dbgs() << "hi slice should in lower 2 dims for dot \n"; + dotEnc.dump(); + }); + } + assert((warpsPerCTA[idx] == 1) && + "hi slice should in lower 2 dims for dot\n"); + } + bool isM = dotEnc.getOpIdx() == 0; + // only debug check + if (isM) { + int64_t k = shape[rank - 1]; + elemsPerthread[rank - 1] = k; + elemsPerthread[rank - 2] = shape[rank - 2] / warpsPerCTA[rank - 2]; + } else { + int64_t k = shape[rank - 2]; + elemsPerthread[rank - 2] = k; + elemsPerthread[rank - 1] = shape[rank - 1] / warpsPerCTA[rank - 1]; + } + return elemsPerthread; + } + } else if (mlir::isa( + tType.getEncoding())) { + return SmallVector(tType.getShape().begin(), + tType.getShape().end()); + } else if (auto blockEnc = dyn_cast( + tType.getEncoding())) { + auto shape = tType.getShape(); + size_t rank = shape.size(); + SmallVector sizePerThread(rank, 1); + auto warpsPerCTA = blockEnc.getWarpsPerCTA(); + auto threadsPerWarp = blockEnc.getThreadsPerWarp(); + auto shapePerCTA = triton::gpu::getShapePerCTA(blockEnc, shape); + assert(rank == sizePerThread.size() && + "unexpected rank in BlockedEncodingAttr::getElemsPerThread"); + SmallVector elemsPerThread(rank); + for (size_t i = 0; i < rank; ++i) { + unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]; + elemsPerThread[i] = + ceil(shapePerCTA[i], t) * sizePerThread[i]; + } + return elemsPerThread; + } else if (auto linearEnc = dyn_cast( + tType.getEncoding())) { + auto shape = tType.getShape(); + size_t rank = shape.size(); + SmallVector sizePerThread(rank, 1); + auto warpsPerCTA = linearEnc.getWarpsPerCTA(); + auto threadsPerWarp = linearEnc.getThreadsPerWarp(); + assert(rank == sizePerThread.size() && + "unexpected rank in LinearEncodingAttr::getElemsPerThread"); + SmallVector elemsPerThread(rank); + for (size_t i = 0; i < rank; ++i) { + unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]; + elemsPerThread[i] = ceil(shape[i], t) * sizePerThread[i]; + } + return elemsPerThread; + } else if (auto sliceEnc = dyn_cast( + tType.getEncoding())) { + auto parent = sliceEnc.getParent(); + auto outShape = sliceEnc.paddedShape(tType.getShape()); + SmallVector sliceDims; + sliceDims.push_back(sliceEnc.getDim()); + while (auto innerSliceEnc = + dyn_cast(parent)) { + llvm::ArrayRef inputShpe = outShape; + outShape = innerSliceEnc.paddedShape(inputShpe); + auto curSliceDim = innerSliceEnc.getDim(); + for (size_t idx = 0; idx < sliceDims.size(); idx++) { + if (sliceDims[idx] >= curSliceDim) { + sliceDims[idx] = sliceDims[idx] + 1; + } + } + sliceDims.push_back(curSliceDim); + parent = innerSliceEnc.getParent(); + } + if (!isa(parent)) { + return triton::gpu::getElemsPerThread(type); + } + auto blockEncParent = dyn_cast(parent); + size_t rank = outShape.size(); + SmallVector sizePerThread(rank, 1); + auto warpsPerCTA = blockEncParent.getWarpsPerCTA(); + auto threadsPerWarp = blockEncParent.getThreadsPerWarp(); + auto shapePerCTA = triton::gpu::getShapePerCTA(blockEncParent, outShape); + assert(rank == sizePerThread.size() && + "unexpected rank in BlockedEncodingAttr::getElemsPerThread"); + SmallVector parentElemsPerThread(rank); + for (size_t i = 0; i < rank; ++i) { + unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]; + parentElemsPerThread[i] = + ceil(shapePerCTA[i], t) * sizePerThread[i]; + } + SmallVector elemsPerThread; + for (unsigned i = 0; i < rank; ++i) { + if (!llvm::is_contained(sliceDims, i)) { + elemsPerThread.push_back(parentElemsPerThread[i]); + } + } + return elemsPerThread; + } else { + return triton::gpu::getElemsPerThread(type); + } + } + return triton::gpu::getElemsPerThread(type); +} + +unsigned getTotalElemsPerThread(Type type) { + if (auto tType = dyn_cast(type)) { + if (auto enc = tType.getEncoding()) { + if (llvm::isa_and_nonnull(enc)) { + auto elemsPerthread = gcu::getElemsPerThread(type); + return std::accumulate(elemsPerthread.begin(), elemsPerthread.end(), 1, + std::multiplies()); + } else if (mlir::isa(enc)) { + return std::accumulate(tType.getShape().begin(), tType.getShape().end(), + 1, std::multiplies()); + } else if (llvm::isa_and_nonnull( + tType.getEncoding()) || + llvm::isa_and_nonnull( + tType.getEncoding())) { + auto elemsPerthread = gcu::getElemsPerThread(type); + return std::accumulate(elemsPerthread.begin(), elemsPerthread.end(), 1, + std::multiplies()); + + } else if (llvm::isa_and_nonnull( + tType.getEncoding())) { + auto elemsPerthread = gcu::getElemsPerThread(type); + return std::accumulate(elemsPerthread.begin(), elemsPerthread.end(), 1, + std::multiplies()); + } else { + return triton::gpu::getTotalElemsPerThread(type); + } + } + } + return triton::gpu::getTotalElemsPerThread(type); +} + +unsigned getNumWarpsPerCTA(Type type) { + if (auto tType = dyn_cast(type)) { + return tType.getEncoding() + ? triton::gpu::getNumWarpsPerCTA(tType.getEncoding()) + : 1; + } + return 1; +} + +unsigned getBpe(Type type) { + assert(type.isIntOrFloat()); + return ((type.getIntOrFloatBitWidth() + 7) / 8); +} + +} // namespace gcu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/Utils.h b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/Utils.h new file mode 100644 index 000000000..6ae9dbf99 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Conversion/TritonToGCU/Utils.h @@ -0,0 +1,43 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef KURAMA_TRITON_TO_GCU_UTILS_H_ +#define KURAMA_TRITON_TO_GCU_UTILS_H_ + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" + +namespace mlir { +namespace triton { +namespace gcu { + +bool get_bool_env(const char *name); +SmallVector getWarpsPerCTA(Attribute layout); +SmallVector getElemsPerThread(Type type); +unsigned getTotalElemsPerThread(Type type); +unsigned getNumWarpsPerCTA(Type type); +unsigned getBpe(Type type); + +inline int64_t ceilDiv(int64_t lhs, int64_t rhs) { + assert(rhs >= 1); + // C/C++'s integer division rounds towards 0. + return lhs % rhs > 0 ? lhs / rhs + 1 : lhs / rhs; +} +} // namespace gcu +} // namespace triton +} // namespace mlir + +#endif // KURAMA_TRITON_TO_GCU_UTILS_H_ diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..354edf74b --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(GCU) +add_subdirectory(MemrefExt) +add_subdirectory(MathExt) +add_subdirectory(TritonGCU) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/CMakeLists.txt new file mode 100644 index 000000000..8984eac30 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(GCUIR${arch} + Dialect.cpp + Types.cpp + Ops.cpp + + DEPENDS + GCUTableGen${arch} + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithDialect + MLIRMathDialect + MLIRSCFDialect +) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/Dialect.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/Dialect.cpp new file mode 100644 index 000000000..03a17117e --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/Dialect.cpp @@ -0,0 +1,91 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Dialect/GCU/IR/Dialect.h" +#include "Dialect/GCU/IR/Dialect.cpp.inc" +#include "Dialect/GCU/IR/Types.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +#define GET_ATTRDEF_CLASSES +#include "Dialect/GCU/IR/OpsAttributes.cpp.inc" +#define GET_OP_CLASSES +#include "Dialect/GCU/IR/Ops.cpp.inc" +#include "Dialect/GCU/IR/OpsEnums.cpp.inc" + +using namespace mlir; +using namespace mlir::gcu; + +void GCUDialect::initialize() { + registerTypes(); + addOperations< +#define GET_OP_LIST +#include "Dialect/GCU/IR/Ops.cpp.inc" // NOLINT: This file generated situationally via different environment variables + >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/GCU/IR/OpsAttributes.cpp.inc" // NOLINT: This file generated situationally via different environment variables + >(); + + // We can also add interface here. + // addInterfaces(); +} + +Operation *GCUDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} + +//===----------------------------------------------------------------------===// +// GCU target attribute. +//===----------------------------------------------------------------------===// +LogicalResult +GCUTargetAttr::verify(function_ref emitError, + int optLevel, StringRef triple, StringRef chip, + StringRef arch, StringRef features, StringRef abiVersion, + DictionaryAttr flags, ArrayAttr files) { + if (optLevel < 0 || optLevel > 3) { + emitError() << "The optimization level must be a number between 0 and 3."; + return failure(); + } + if (triple.empty()) { + emitError() << "The target triple cannot be empty."; + return failure(); + } + if (chip.empty()) { + emitError() << "The target chip cannot be empty."; + return failure(); + } + if (arch.empty()) { + emitError() << "The target arch cannot be empty."; + return failure(); + } + if (abiVersion != "1") { + emitError() << "Invalid ABI version, it must be `1`."; + return failure(); + } + if (files && !llvm::all_of(files, [](::mlir::Attribute attr) { + return attr && mlir::isa(attr); + })) { + emitError() << "All the elements in the `link` array must be strings."; + return failure(); + } + return success(); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/Ops.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/Ops.cpp new file mode 100644 index 000000000..41469c4a0 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/Ops.cpp @@ -0,0 +1,582 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Support/LogicalResult.h" + +#include "Dialect/GCU/IR/Dialect.h" +namespace mlir { +namespace gcu { + +LogicalResult AllocBarrierOp::verify() { + if (getBarrier().getType().getAddressSpace().getValue() != + gcu::AddressSpace::Workgroup) + return emitOpError() << "only supports workgroup level"; + return success(); +} + +LogicalResult DynamicSharedMemoryOp::verify() { + if (!getOperation()->getParentWithTrait()) + return emitOpError() << "must be inside an op with symbol table"; + + MemRefType memrefType = getResultMemref().getType(); + // Check address space + if (auto addrspace = memrefType.getMemorySpace()) { + if (!(dyn_cast(addrspace) && + dyn_cast(addrspace).getValue() == + gpu::AddressSpace::Workgroup) && + !(dyn_cast(addrspace) && + (dyn_cast(addrspace).getValue() == + gcu::AddressSpace::Workgroup || + dyn_cast(addrspace).getValue() == + gcu::AddressSpace::Local))) + return emitOpError() << "address space must be " + << gpu::AddressSpaceAttr::getMnemonic() << "<" + << stringifyEnum(gpu::AddressSpace::Workgroup) << ">" + << " or " << gcu::AddressSpaceAttr::getMnemonic() + << "<" << stringifyEnum(gcu::AddressSpace::Workgroup) + << ">" + << " or " << gcu::AddressSpaceAttr::getMnemonic() + << "<" << stringifyEnum(gcu::AddressSpace::Local) + << ">"; + } + if (memrefType.hasStaticShape()) { + return emitOpError() << "result memref type must be memref> or > or >"; + } + return success(); +} + +LogicalResult MemsetAsyncOp::verify() { + MemRefType dst = getDst().getType(); + Type value = getValue().getType(); + if (dst.getElementType() != value) + return emitOpError() << "value type should be same as dst's element type"; + return success(); +} + +LogicalResult MemcpyAsyncOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + auto is1RankContious = [](MemRefType t) { + return (t.getRank() == 1 && t.isLastDimUnitStride() && + !t.getLayout().getAffineMap().isConstant()); + }; + if (is1RankContious(dst) && src.getLayout().isIdentity()) + return success(); + if (dst.getLayout().isIdentity() && is1RankContious(src)) + return success(); + if (is1RankContious(dst) && is1RankContious(src)) + return success(); + if (dst.getLayout().isIdentity() && src.getLayout().isIdentity() && + dst.getLayout() == src.getLayout() && dst.getShape() == src.getShape() && + dst.getElementType() == src.getElementType()) + return success(); + + return emitOpError() << "dst and src types should be 1 rank memref " + " or canonical form memory and with same shape"; +} + +LogicalResult SliceAsyncOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + if ( // dst.getLayout().isIdentity() && + // src.getLayout().isIdentity() && + dst.getElementType() == src.getElementType() && + dst.getRank() == src.getRank() && + static_cast(dst.getRank()) == getOffsets().size() && + dst.getRank() <= 5) + return success(); + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + return emitOpError() << "dst and src types should has same rank, " + "element type and be identity memref"; +} + +LogicalResult SlicePadAsyncOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + if ( // dst.getLayout().isIdentity() && + // src.getLayout().isIdentity() && + dst.getElementType() == src.getElementType() && + getPadValue().getType() == dst.getElementType() && + dst.getRank() == src.getRank() && + static_cast(dst.getRank()) == getOffsets().size() && + dst.getRank() <= 5) + return success(); + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + return emitOpError() << "dst and src types should has same rank, " + "element type and be identity memref"; +} + +LogicalResult DesliceAsyncOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + if ( // dst.getLayout().isIdentity() && + // src.getLayout().isIdentity() && + dst.getElementType() == src.getElementType() && + dst.getRank() == src.getRank() && + static_cast(dst.getRank()) == getOffsets().size() && + dst.getRank() <= 5) + return success(); + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + return emitOpError() << "dst and src types should has same rank, " + "element type and be identity memref"; +} + +LogicalResult SliceDesliceAsyncOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + if ( // dst.getLayout().isIdentity() && + // src.getLayout().isIdentity() && + dst.getElementType() == src.getElementType() && + dst.getRank() == src.getRank() && + static_cast(dst.getRank()) == getOffsets().size() && + dst.getRank() <= 5) + return success(); + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + return emitOpError() << "dst and src types should has same rank, " + "element type and be identity memref"; +} + +LogicalResult TransposeAsyncOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + if ( // dst.getLayout().isIdentity() && + // src.getLayout().isIdentity() && + dst.getElementType() == src.getElementType() && + dst.getRank() == src.getRank() && + static_cast(dst.getRank()) == getLayout().size() && + dst.getRank() <= 5) + return success(); + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + return emitOpError() << "dst and src types should has same rank, " + "element type and be identity memref"; +} + +LogicalResult BroadcastAsyncOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + if ( // dst.getLayout().isIdentity() && + // src.getLayout().isIdentity() && + dst.getElementType() == src.getElementType() && + dst.getRank() >= src.getRank() && dst.getRank() <= 5) + return success(); + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + if (src.getRank() > dst.getRank()) + return emitOpError() << "src rank should be less or equal then dst rank"; + return emitOpError() << "dst's rank should has larger than src's, " + "element type and be identity memref"; +} + +LogicalResult SliceBroadcastAsyncOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + if ( // dst.getLayout().isIdentity() && + // src.getLayout().isIdentity() && + dst.getElementType() == src.getElementType() && + static_cast(src.getRank()) == getOffsets().size() && + dst.getRank() >= src.getRank() && dst.getRank() <= 5) + return success(); + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + if (src.getRank() > dst.getRank()) + return emitOpError() << "src rank should be less or equal then dst rank"; + return emitOpError() << "dst's rank should has larger than src's, " + "element type and be identity memref"; +} + +LogicalResult SliceTransposeAsyncOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + if ( // dst.getLayout().isIdentity() && + // src.getLayout().isIdentity() && + dst.getElementType() == src.getElementType() && + static_cast(src.getRank()) == getOffsets().size() && + static_cast(dst.getRank()) == getLayout().size() && + dst.getRank() == src.getRank() && dst.getRank() <= 5) + return success(); + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + return emitOpError() << "dst and src types should has same rank, " + "element type and be identity memref"; +} + +LogicalResult TransposeDesliceAsyncOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + if ( // dst.getLayout().isIdentity() && + // src.getLayout().isIdentity() && + dst.getElementType() == src.getElementType() && + static_cast(src.getRank()) == getLayout().size() && + static_cast(dst.getRank()) == getOffsets().size() && + dst.getRank() == src.getRank() && dst.getRank() <= 5) + return success(); + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + return emitOpError() << "dst and src types should has same rank, " + "element type and be identity memref"; +} + +LogicalResult MemsetDesliceAsyncOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + Type value = getValue().getType(); + + if (src.getElementType() != value) + return emitOpError() << "value type should be same as src's element type"; + + if ( // dst.getLayout().isIdentity() && + // src.getLayout().isIdentity() && + dst.getElementType() == src.getElementType() && + dst.getRank() == src.getRank() && + static_cast(dst.getRank()) == getOffsets().size() && + dst.getRank() <= 5) + return success(); + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + return emitOpError() << "dst and src types should has same rank, " + "element type and be identity memref"; +} + +LogicalResult VectorConvertOp::verify() { + if (getNumOperands() > getNumResults() && + getNumOperands() % getNumResults() != 0) + return emitOpError() << "number of inputs should be multiply of outputs'"; + if (getNumOperands() < getNumResults() && + getNumResults() % getNumOperands() != 0) + return emitOpError() << "number of outputs should be multiply of inputs'"; + + uint64_t inputElems = 0; + Type inputType; + for (auto input : getInputs()) { + auto t = dyn_cast(input.getType()); + inputElems += t.getNumElements(); + if (inputType && t != inputType) + return emitOpError() << "all inputs' types should be same"; + inputType = t; + } + uint64_t outputElems = 0; + Type outputType; + for (auto output : getOutputs()) { + auto t = dyn_cast(output.getType()); + outputElems += t.getNumElements(); + if (outputType && t != outputType) + return emitOpError() << "all outputs' types should be same"; + outputType = t; + } + + if (inputElems == 0) + return emitOpError() << "inputs should not be empty"; + if (outputElems == 0) + return emitOpError() << "outputs should not be empty"; + if (inputElems != outputElems) + return emitOpError() + << "inputs should have same element number with outputs"; + return success(); +} + +struct SimplifyRedundantVectorConvert + : public OpRewritePattern { + explicit SimplifyRedundantVectorConvert(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + LogicalResult matchAndRewrite(VectorConvertOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + unsigned numInputs = op.getNumOperands(); + unsigned numOutputs = op.getNumResults(); + + // if all inputs' types are same as outputs', just remove it + if (numInputs == numOutputs) { + bool isAllSame = true; + for (unsigned i = 0; i < numInputs; ++i) { + if (op.getOperand(i).getType() != op.getResult(i).getType()) { + isAllSame = false; + break; + } + } + if (isAllSame) { + for (unsigned i = 0; i < numInputs; ++i) + rewriter.replaceAllUsesWith(op.getResult(i), op.getOperand(i)); + rewriter.eraseOp(op); + return success(); + } + } + + // if inputs are from type conversion ops, just remove it + if (numInputs == numOutputs) { + auto isCvtOp = [](Operation *op) { + return isa(op); + }; + auto isValidCvt = [](Operation *op, Type from, Type to) { + auto fromTy = dyn_cast(from); + if (!fromTy) + return false; + auto toTy = dyn_cast(to); + if (!toTy) + return false; + if (isa(op)) + return true; + if (isa(op)) { + return fromTy.getElementTypeBitWidth() <= + toTy.getElementTypeBitWidth(); + } + return fromTy.getElementTypeBitWidth() >= toTy.getElementTypeBitWidth(); + }; + + bool isAllSame = true; + SmallVector cvtOps; + for (unsigned i = 0; i < numInputs; ++i) { + if (!op.getOperand(i).getDefiningOp() || + !isCvtOp(op.getOperand(i).getDefiningOp())) { + isAllSame = false; + break; + } + auto cvtOp = op.getOperand(i).getDefiningOp(); + if (!isValidCvt(cvtOp, cvtOp->getOperand(0).getType(), + op.getResult(i).getType())) { + isAllSame = false; + break; + } + cvtOps.push_back(cvtOp); + if (cvtOps.front()->getName() != cvtOp->getName()) { + isAllSame = false; + break; + } + } + if (isAllSame) { + for (unsigned i = 0; i < numInputs; ++i) { + auto newCvtOp = rewriter.clone(*cvtOps[i]); + newCvtOp->getResult(0).setType(op.getResult(i).getType()); + rewriter.replaceAllUsesWith(op.getResult(i), newCvtOp->getResult(0)); + } + rewriter.eraseOp(op); + return success(); + } + } + + // check if there are two converts in chain + bool isOperandFromSameVectorConvert = true; + Operation *from = nullptr; + for (unsigned i = 0; i < numInputs; ++i) { + auto v = op.getOperand(i); + if (!v.getDefiningOp()) { + isOperandFromSameVectorConvert = false; + break; + } + if (from && from != v.getDefiningOp()) { + isOperandFromSameVectorConvert = false; + break; + } + from = v.getDefiningOp(); + if (!isa(from)) { + isOperandFromSameVectorConvert = false; + break; + } + } + if (!from) + isOperandFromSameVectorConvert = false; + if (from && from->getNumResults() != numInputs) + isOperandFromSameVectorConvert = false; + for (unsigned i = 0; i < numInputs && isOperandFromSameVectorConvert; ++i) { + if (i >= from->getNumResults() || op.getOperand(i) != from->getResult(i)) + isOperandFromSameVectorConvert = false; + } + if (isOperandFromSameVectorConvert) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + from->getOperands()); + return success(); + } + + // split convert if possible + unsigned times = numOutputs > numInputs ? numInputs : numOutputs; + + if (times <= 1) + return failure(); + + unsigned inputStep = numInputs / times; + unsigned outputStep = numOutputs / times; + for (unsigned i = 0; i < times; ++i) { + SmallVector inputs; + for (unsigned j = i * inputStep; j < i * inputStep + inputStep; ++j) + inputs.push_back(op.getOperand(j)); + SmallVector outputs; + SmallVector outputTypes; + for (unsigned j = i * outputStep; j < i * outputStep + outputStep; ++j) { + outputs.push_back(op.getResult(j)); + outputTypes.push_back(outputs.back().getType()); + } + auto convert = rewriter.create(loc, outputTypes, inputs); + for (unsigned j = 0; j < outputStep; ++j) { + rewriter.replaceAllUsesWith(outputs[j], convert.getResult(j)); + } + } + rewriter.eraseOp(op); + return success(); + } +}; + +void VectorConvertOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +LogicalResult PtrToMemRefOp::verify() { + auto memrefType = getResultMemref().getType(); + if (memrefType.getMemorySpace()) + return emitOpError() << "result memref type should not has memory space"; + if (memrefType.hasStaticShape()) { + return emitOpError() << "result memref type must be memref"; + } + // if (getPtr().getType().getElementType() != memrefType.getElementType()) + // return emitOpError() << "pointer element type must be same as result's"; + return success(); +} + +LogicalResult MemRefToPtrOp::verify() { + auto memrefType = getMemref().getType(); + if (!memrefType.getLayout().isIdentity()) { + return emitOpError() << "memref type must have identity layout"; + } + // if (memrefType.getMemorySpace()) + // return emitOpError() << "memref type should not has memory space"; + // if (memrefType.hasStaticShape()) { + // return emitOpError() << "memref type must be memref"; + // } + // if (getPtr().getType().getElementType() != memrefType.getElementType()) + // return emitOpError() << "pointer element type must be same as input's"; + return success(); +} + +LogicalResult MatMulOp::verify() { + MemRefType out = getOut().getType(); + MemRefType lhs = getLhs().getType(); + MemRefType rhs = getRhs().getType(); + + if (lhs.getElementType() != rhs.getElementType()) + return emitOpError() + << "element type of operands lhs and rhs must be same type"; + if (lhs.getRank() != rhs.getRank() || out.getRank() != lhs.getRank()) + return emitOpError() << "out, lhs and rhs types should have same rank"; + if (out.getRank() != 2 && out.getRank() != 3) + return emitOpError() << "rank must be 2D or 3D"; + else if (out.getRank() == 3 && + getLhs().getType().getShape()[0] != getRhs().getType().getShape()[0]) + return emitOpError() << "lhs[dim0=b, dim1=m, dim2=k] and rhs[dim0=b, " + "dim1=k, dim2=n] must have the same dim0"; + // add bias check + if (getBias()) { + if (getBias().getType().getShape()[0] != out.getShape()[0] || + getBias().getType().getShape()[1] != out.getShape()[1]) { + return emitOpError() << "out and bias should have same shape!!!!"; + } + } + return success(); +} + +LogicalResult ReduceOp::verify() { + auto dims = getIn().getType().getShape(); + ReduceOperation op = getOp(); + if (op == ReduceOperation::SUM) { + if (dims[1] % 128 != 0 || dims[2] % 128 != 0) + return emitOpError() + << "both dim1 and dim2 need align to 128 in reduce_sum"; + } else { + if (dims[1] % 16 != 0 || dims[2] % 512 != 0) + return emitOpError() << "dim1 needs align to 16 and dim2 needs align to " + "512 in reduce_minmax"; + } + return success(); +} + +LogicalResult AtomicRMWOp::verify() { + auto rmw_op = getAtomicRmwOp(); + auto type = getVal().getType(); + auto element_type = getVal().getType().isIntOrFloat() + ? type + : dyn_cast(type).getElementType(); + auto memory_sync_scope = getScope(); + auto bitwidth = element_type.getIntOrFloatBitWidth(); + + // check supported data type + if (rmw_op == gcu::RMWOp::ADD) { + if (8 == bitwidth) + return emitOpError() + << "only supports i16/u16/i32/u32/i64/u64/fp32/fp16/bf16"; + } else if (rmw_op == gcu::RMWOp::MAX || rmw_op == gcu::RMWOp::UMAX) { + if (8 == bitwidth || 16 == bitwidth) + return emitOpError() << "only supports i32/u32/i64/u64/fp32/fp16/bf16"; + } else if (rmw_op == gcu::RMWOp::MIN || rmw_op == gcu::RMWOp::UMIN) { + if (8 == bitwidth || 16 == bitwidth || element_type.isF32()) + return emitOpError() << "only supports i32/u32/i64/u64"; + } else if (rmw_op == gcu::RMWOp::AND) { + if (8 == bitwidth || 16 == bitwidth || element_type.isF32()) + return emitOpError() << "only supports i32/u32/i64/u64"; + } else if (rmw_op == gcu::RMWOp::OR) { + if (8 == bitwidth || 16 == bitwidth || element_type.isF32()) + return emitOpError() << "only supports i32/u32/i64/u64"; + } else if (rmw_op == gcu::RMWOp::XOR) { + if (8 == bitwidth || 16 == bitwidth || element_type.isF32()) + return emitOpError() << "only supports i32/u32/i64/u64"; + } else if (rmw_op == gcu::RMWOp::XCHG) { + if (8 == bitwidth || element_type.isBF16()) + return emitOpError() << "only supports i16/u16/i32/u32/i64/u64/fp32/fp16"; + } + + // check supported memory sync scope + if (!(memory_sync_scope == gcu::MemSyncScope::GCU)) + return emitOpError() << "only supports atomic memory sync scope is gcu"; + + return success(); +} + +LogicalResult AtomicCASOp::verify() { + auto type = getVal().getType(); + auto element_type = getVal().getType().isIntOrFloat() + ? type + : dyn_cast(type).getElementType(); + auto memory_sync_scope = getScope(); + auto bitwidth = element_type.getIntOrFloatBitWidth(); + + // check supported data type + if (8 == bitwidth || 16 == bitwidth || element_type.isF32()) + return emitOpError() << "only supports i32/u32/i64/u64"; + + // check supported memory sync scope + if (!(memory_sync_scope == gcu::MemSyncScope::GCU)) + return emitOpError() << "only supports atomic memory sync scope is gcu"; + + return success(); +} + +} // namespace gcu +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/Types.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/Types.cpp new file mode 100644 index 000000000..b816878fd --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/GCU/IR/Types.cpp @@ -0,0 +1,59 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Dialect/GCU/IR/Types.h" +#include "Dialect/GCU/IR/Dialect.h" + +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::gcu; + +#define GET_TYPEDEF_CLASSES +#include "Dialect/GCU/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// GCU Dialect +//===----------------------------------------------------------------------===// +void GCUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "Dialect/GCU/IR/Types.cpp.inc" // NOLINT: This file generated situationally via different environment variables + >(); +} + +/* +Type DTEType::parse(AsmParser &odsParser) { + Builder odsBuilder(odsParser.getContext()); + llvm::SMLoc odsLoc = odsParser.getCurrentLocation(); + (void) odsLoc; + FailureOr addressSpace; + // Parse literal '<' + if (odsParser.parseLess()) return {}; + + // Parse variable 'addressSpace' + addressSpace = FieldParser::parse(odsParser); + if (failed(addressSpace)) { + odsParser.emitError(odsParser.getCurrentLocation(), + "failed to parse GCU_DTEType parameter 'addressSpace' which is to be a +`::mlir::gcu::AddressSpaceAttr`"); return {}; + } + // Parse literal '>' + if (odsParser.parseGreater()) return {}; + assert(succeeded(addressSpace)); + return DTEType::get(odsParser.getContext(), *addressSpace); +} +*/ diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/CMakeLists.txt new file mode 100644 index 000000000..cd5e750a8 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MathExtIR${arch} + MathExtDialect.cpp + MathExtOps.cpp + MathExtTypes.cpp + + DEPENDS + MathExtTableGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithDialect + MLIRMathDialect + MLIRSCFDialect +) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/MathExtDialect.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/MathExtDialect.cpp new file mode 100644 index 000000000..2b547626f --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/MathExtDialect.cpp @@ -0,0 +1,42 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Dialect/MathExt/IR/MathExt.h" +#include "Dialect/MathExt/IR/MathExtDialect.cpp.inc" +#include "Dialect/MathExt/IR/MathExtTypes.h" + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::math_ext; + +void MathExtDialect::initialize() { + // registerTypes(); + addOperations< +#define GET_OP_LIST +#include "Dialect/MathExt/IR/MathExtOps.cpp.inc" // NOLINT: This file generated situationally via different environment variables + >(); +} + +Operation *MathExtDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/MathExtOps.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/MathExtOps.cpp new file mode 100644 index 000000000..c32d99ea5 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/MathExtOps.cpp @@ -0,0 +1,45 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mlir/Support/LogicalResult.h" + +#include "Dialect/MathExt/IR/MathExt.h" + +#define GET_OP_CLASSES +#include "Dialect/MathExt/IR/MathExtOps.cpp.inc" + +namespace mlir { +namespace math_ext { + +LogicalResult UmulhiOp::verify() { + if (getLhs().getType() != getRhs().getType() || + getLhs().getType() != getResult().getType()) + return emitOpError() << "Lhs, Rhs and Result must be of the same type"; + return success(); +} + +LogicalResult HistogramOp::verify() { + if (getOperand().getType().getRank() != 1) + return emitOpError() << "histogram only supports 1D input"; + if (getResult().getType().getRank() != 1 || + getResult().getType().getShape()[0] == 0) + return emitOpError() + << "histogram only supports 1D output with a ​size greater than 0"; + return success(); +} + +} // namespace math_ext +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/MathExtTypes.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/MathExtTypes.cpp new file mode 100644 index 000000000..f6b3f4d29 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MathExt/IR/MathExtTypes.cpp @@ -0,0 +1,37 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Dialect/MathExt/IR/MathExtTypes.h" +#include "Dialect/MathExt/IR/MathExt.h" + +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::math_ext; + +#define GET_TYPEDEF_CLASSES +#include "Dialect/MathExt/IR/MathExtTypes.cpp.inc" + +//===----------------------------------------------------------------------===// +// MathExt Dialect +//===----------------------------------------------------------------------===// +// void MathExtDialect::registerTypes() { +// addTypes< +// #define GET_TYPEDEF_LIST +// #include "Dialect/MathExt/IR/MathExtTypes.cpp.inc" // NOLINT: This file +// generated situationally via different environment variables +// >(); +// } diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/IR/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/IR/CMakeLists.txt new file mode 100644 index 000000000..c4dc31169 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_dialect_library(MemrefExtIR${arch} + MemrefExtDialect.cpp + MemrefExtOps.cpp + + DEPENDS + MemrefExtTableGen${arch} + + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/IR/MemrefExtDialect.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/IR/MemrefExtDialect.cpp new file mode 100644 index 000000000..f6bf23f4d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/IR/MemrefExtDialect.cpp @@ -0,0 +1,41 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Dialect/MemrefExt/IR/MemrefExt.h" +#include "Dialect/MemrefExt/IR/MemrefExtDialect.cpp.inc" + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::memref_ext; + +void MemrefExtDialect::initialize() { + // registerTypes(); + addOperations< +#define GET_OP_LIST +#include "Dialect/MemrefExt/IR/MemrefExtOps.cpp.inc" // NOLINT: This file generated situationally via different environment variables + >(); +} + +Operation *MemrefExtDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/IR/MemrefExtOps.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/IR/MemrefExtOps.cpp new file mode 100644 index 000000000..da5114861 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/MemrefExt/IR/MemrefExtOps.cpp @@ -0,0 +1,73 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Support/LogicalResult.h" + +#include "Dialect/MemrefExt/IR/MemrefExt.h" + +#define GET_OP_CLASSES +#include "Dialect/MemrefExt/IR/MemrefExtOps.cpp.inc" + +namespace mlir { +namespace memref_ext { + +LogicalResult MemsetStartOp::verify() { + MemRefType dst = getDst().getType(); + if (getValue().getType() == dst.getElementType() && dst.getRank() <= 5) + return success(); + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + return emitOpError() << "dst and src types should has same rank, " + "element type and be identity memref"; +} + +LogicalResult SliceStartOp::verify() { + MemRefType dst = getDst().getType(); + MemRefType src = getSrc().getType(); + + auto defaultValue = getDefaultValue(); + if (defaultValue.getType().isInteger(64)) { + auto constOp = defaultValue.getDefiningOp(); + if (constOp) { + auto value = cast(constOp.getValue()).getInt(); + if (value != 0 && value != -1) { + return emitOpError() << "for i64 element type the default value" + " can only be 0 or -1"; + } + } + } + + if (dst.getElementType() == src.getElementType() && + dst.getRank() == src.getRank() && + static_cast(dst.getRank()) == getOffsets().size() && + dst.getRank() <= 5) + return success(); + + if (dst.getRank() > 5) + return emitOpError() << "rank should <=5 "; + return emitOpError() << "dst and src types should has same rank, " + "element type and be identity memref"; +} + +} // namespace memref_ext +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/CMakeLists.txt new file mode 100644 index 000000000..2fce66652 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect_library(TritonGCUIR_${arch} + TritonGCUDialect.cpp + TritonGCUTypes.cpp + TritonGCUOps.cpp + + DEPENDS + TritonGCUTableGen_${arch} + triton_${arch} + + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/TritonGCUDialect.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/TritonGCUDialect.cpp new file mode 100644 index 000000000..7b27593ce --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/TritonGCUDialect.cpp @@ -0,0 +1,54 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.cpp.inc" +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +#define GET_ATTRDEF_CLASSES +#include "Dialect/TritonGCU/IR/TritonGCUOpsAttributes.cpp.inc" +// #define GET_OP_CLASSES +// #include "Dialect/TritonGCU/IR/TritonGCUOps.cpp.inc" +#include "Dialect/TritonGCU/IR/TritonGCUOpsEnums.cpp.inc" + +using namespace ::mlir; +using namespace ::mlir::triton; +using namespace ::mlir::triton::gcu; + +void TritonGCUDialect::initialize() { + registerTypes(); + addOperations< +#define GET_OP_LIST +#include "Dialect/TritonGCU/IR/TritonGCUOps.cpp.inc" // NOLINT: This file generated situationally via different environment variables + >(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/TritonGCU/IR/TritonGCUOpsAttributes.cpp.inc" // NOLINT: This file generated situationally via different environment variables + >(); +} + +Operation *TritonGCUDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/TritonGCUOps.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/TritonGCUOps.cpp new file mode 100644 index 000000000..588c5c0f9 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/TritonGCUOps.cpp @@ -0,0 +1,65 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Support/LogicalResult.h" + +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" + +#define GET_OP_CLASSES +#include "Dialect/TritonGCU/IR/TritonGCUOps.cpp.inc" + +namespace mlir { +namespace triton { +namespace gcu { + +LogicalResult LoadOp::verify() { + if (getOffsets().size() != getShape().size() || + getOffsets().size() != getStrides().size() || + getOffsets().size() != static_cast(getType().getRank())) + return emitOpError() << "shape/strides/offsets mismatch with result rank"; + if (getPtr().getType().getElementType() != getType().getElementType()) + return emitOpError() << "pointer element type mismatch"; + if (getDefaultValue() && + getDefaultValue().getType() != getType().getElementType()) + return emitOpError() << "default element type mismatch"; + if (getOrderHint().size() > getShape().size()) + return emitOpError() << "order_hint rank mismatch with result rank"; + return success(); +} + +LogicalResult StoreOp::verify() { + if (getOffsets().size() != getShape().size() || + getOffsets().size() != getStrides().size() || + getOffsets().size() != + static_cast(getValue().getType().getRank())) + return emitOpError() << "shape/strides/offsets mismatch with value rank"; + if (getPtr().getType().getElementType() != + getValue().getType().getElementType()) + return emitOpError() << "pointer element type mismatch"; + if (getOrderHint().size() > getShape().size()) + return emitOpError() << "order_hint rank mismatch with result rank"; + return success(); +} + +} // namespace gcu +} // namespace triton +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/TritonGCUTypes.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/TritonGCUTypes.cpp new file mode 100644 index 000000000..06c490d4b --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Dialect/TritonGCU/IR/TritonGCUTypes.cpp @@ -0,0 +1,38 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Dialect/TritonGCU/IR/TritonGCUTypes.h" +#include "Dialect/TritonGCU/IR/TritonGCUDialect.h" + +#include "mlir/IR/Builders.h" // required by `Types.cpp.inc` +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gcu; + +#define GET_TYPEDEF_CLASSES +#include "Dialect/TritonGCU/IR/TritonGCUTypes.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton GCU Dialect +//===----------------------------------------------------------------------===// +void TritonGCUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "Dialect/TritonGCU/IR/TritonGCUTypes.cpp.inc" // NOLINT: This file generated situationally via different environment variables + >(); +} diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/CMakeLists.txt new file mode 100644 index 000000000..31d542f56 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/CMakeLists.txt @@ -0,0 +1,25 @@ +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +add_mlir_conversion_library(MLIRTritonGCUTransforms_${arch} + GCUSupportVerifier.cpp + Combine.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/triton_gcu/Transforms + + DEPENDS + TritonGCUTableGen_${arch} + TritonGCUTransformsPassIncGen_${arch} + TritonGCUCombineIncGen_${arch} + triton_${arch} + + LINK_LIBS PUBLIC +# TritonIR + MLIRGPUTransforms + MLIRIR + MLIRPass + MLIRSupport + MLIRSideEffectInterfaces + ) +set(LLVM_TARGET_DEFINITIONS Combine.td) +mlir_tablegen(TritonGCUCombine.inc -gen-rewriters) +add_public_tablegen_target(TritonGCUCombineIncGen_${arch}) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/Combine.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/Combine.cpp new file mode 100644 index 000000000..3b2ee806d --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/Combine.cpp @@ -0,0 +1,425 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "Transforms/Passes.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; + +namespace mlir { +#define GEN_PASS_DEF_GCUCOMBINEOPS +#include "Transforms/Passes.h.inc" +} // namespace mlir + +namespace { + +#include "TritonGCUCombine.inc" + +const char *const symbol0 = R"(_{0}mix_{1})"; +const char *const symbol1 = R"(_{0})"; + +template +mlir::LogicalResult matchMixedPrecisionPattern0(Value v0, Value v1, + SmallVector &args, + std::string &symbol) { + auto defOp = v0.getDefiningOp(); + if (!isa(v0.getType()) || !defOp) { + return failure(); + } + auto elementTy0 = cast(v0.getType()).getElementType(); + bool match = false; + std::string outputTy, mixedPrecisionTy; + if constexpr (std::is_same_v || + std::is_same_v) { + if (elementTy0.isF32()) { + outputTy = "fp32"; + if (isa(defOp)) { + auto elementTy1 = cast(defOp->getOperand(0).getType()) + .getElementType(); + if (elementTy1.isF16() || elementTy1.isBF16()) { + match = true; + mixedPrecisionTy = elementTy1.isF16() ? "fp16" : "bf16"; + } + } else if (isa(defOp)) { + auto elementTy1 = cast(defOp->getOperand(0).getType()) + .getElementType(); + if (elementTy1.isInteger(8) || elementTy1.isInteger(32)) { + match = true; + mixedPrecisionTy = + (isa(defOp) ? "ui" : "si") + + std::to_string(cast(elementTy1).getWidth()); + } + } + } else if (elementTy0.isBF16() || elementTy0.isF16()) { + if (isa(defOp) && + cast(defOp->getOperand(0).getType()) + .getElementType() + .isInteger(8)) { + match = true; + outputTy = elementTy0.isBF16() ? "bf16" : "fp16"; + mixedPrecisionTy = isa(defOp) ? "ui8" : "si8"; + } + } + } + + if constexpr (std::is_same_v || + std::is_same_v) { + if (elementTy0.isInteger(32) && + isa(defOp) && + cast(defOp->getOperand(0).getType()) + .getElementType() + .isInteger(8)) { + match = true; + outputTy = "si32"; + mixedPrecisionTy = isa(defOp) ? "ui8" : "si8"; + } + } + + if (match) { + args.push_back(defOp->getOperand(0)); + args.push_back(v1); + symbol = llvm::formatv(symbol0, outputTy, mixedPrecisionTy); + return success(); + } + return failure(); +} + +template || + std::is_same_v || + std::is_same_v>> +mlir::LogicalResult matchMixedPrecisionPattern1(Value v0, Value v1, + SmallVector &args, + std::string &symbol) { + auto defOp0 = v0.getDefiningOp(); + auto defOp1 = v1.getDefiningOp(); + if (!isa(v0.getType()) || !defOp0 || !defOp1) { + return failure(); + } + auto elementTy0 = cast(v0.getType()).getElementType(); + if constexpr (std::is_same_v || + std::is_same_v) { + if (elementTy0.isInteger(32) && + ((isa(defOp0) && isa(defOp1)) || + (isa(defOp0) && isa(defOp1))) && + defOp0->getOperand(0).getType() == defOp1->getOperand(0).getType() && + cast(defOp0->getOperand(0).getType()) + .getElementType() + .isInteger(8)) { + args.push_back(defOp0->getOperand(0)); + args.push_back(defOp1->getOperand(0)); + symbol = llvm::formatv(symbol0, "si32", + isa(defOp0) ? "ui8" : "si8"); + return success(); + } + } + if constexpr (std::is_same_v) { + if (elementTy0.isF32() && isa(defOp0) && + isa(defOp1) && + defOp0->getOperand(0).getType() == defOp1->getOperand(0).getType() && + (cast(defOp0->getOperand(0).getType()) + .getElementType() + .isF16() || + cast(defOp0->getOperand(0).getType()) + .getElementType() + .isBF16())) { + args.push_back(defOp0->getOperand(0)); + args.push_back(defOp1->getOperand(0)); + symbol = + llvm::formatv(symbol0, "fp32", + cast(defOp0->getOperand(0).getType()) + .getElementType() + .isF16() + ? "fp16" + : "bf16"); + return success(); + } + } + return failure(); +} + +template +mlir::LogicalResult matchCombinePattern(Value v0, Value v1, + SmallVector &args, + std::string &symbol) { + auto defOp = v0.getDefiningOp(); + if (!isa(v0.getType()) || !defOp) { + return failure(); + } + auto elementTy0 = cast(v0.getType()).getElementType(); + + if constexpr (std::is_same_v || + std::is_same_v) { + if (isa(defOp)) { + auto operand0 = defOp->getOperand(0); + auto operand1 = defOp->getOperand(1); + if (succeeded(matchMixedPrecisionPattern1( + operand0, operand1, args, symbol))) { + symbol = symbol.substr(symbol.rfind("_")); + args.push_back(v1); + return success(); + } else if (succeeded(matchMixedPrecisionPattern0( + operand0, operand1, args, symbol)) || + succeeded(matchMixedPrecisionPattern0( + operand1, operand0, args, symbol))) { + args.push_back(v1); + return success(); + } else if (elementTy0.isF32()) { + args.push_back(operand0); + args.push_back(operand1); + args.push_back(v1); + symbol = llvm::formatv(symbol1, "fp32"); + return success(); + } + } + } + + if constexpr (std::is_same_v || + std::is_same_v) { + if (isa(defOp)) { + auto operand0 = defOp->getOperand(0); + auto operand1 = defOp->getOperand(1); + if (succeeded(matchMixedPrecisionPattern1( + operand0, operand1, args, symbol))) { + symbol = symbol.substr(symbol.rfind("_")); + args.push_back(v1); + return success(); + } else if (succeeded(matchMixedPrecisionPattern0( + operand0, operand1, args, symbol)) || + succeeded(matchMixedPrecisionPattern0( + operand1, operand0, args, symbol))) { + args.push_back(v1); + return success(); + } + } + } + return failure(); +} + +template || + std::is_same_v>> +class CombineMACPattern : public OpRewritePattern { +public: + explicit CombineMACPattern(MLIRContext *context) + : OpRewritePattern(context, 10) {} + mlir::LogicalResult + matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { + auto lhs = op->getOperand(0); + auto rhs = op->getOperand(1); + auto loc = op->getLoc(); + SmallVector args; + std::string symbol = ""; + if (succeeded(matchCombinePattern(lhs, rhs, args, symbol)) || + succeeded(matchCombinePattern(rhs, lhs, args, symbol))) { + symbol = "__gcu_mac" + symbol; + auto externElementwiseOp = + rewriter.create( + loc, op->getResult(0).getType(), args, + /*libname*/ rewriter.getStringAttr(""), + /*libpath*/ rewriter.getStringAttr(""), + /*symbol*/ rewriter.getStringAttr(symbol), + /*pure*/ rewriter.getBoolAttr(true)); + rewriter.replaceOp(op, externElementwiseOp); + return success(); + } + return failure(); + } +}; + +template || + std::is_same_v>> +class CombineIMASMASPattern : public OpRewritePattern { +public: + explicit CombineIMASMASPattern(MLIRContext *context) + : OpRewritePattern(context, 10) {} + mlir::LogicalResult + matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { + auto lhs = op->getOperand(0); + auto rhs = op->getOperand(1); + auto loc = op->getLoc(); + SmallVector args; + std::string symbol = ""; + if (succeeded(matchCombinePattern(lhs, rhs, args, symbol))) { + symbol = "__gcu_imas" + symbol; + } else if (succeeded(matchCombinePattern(rhs, lhs, args, symbol))) { + symbol = "__gcu_mas" + symbol; + } else { + return failure(); + } + auto externElementwiseOp = + rewriter.create( + loc, op->getResult(0).getType(), args, + /*libname*/ rewriter.getStringAttr(""), + /*libpath*/ rewriter.getStringAttr(""), + /*symbol*/ rewriter.getStringAttr(symbol), + /*pure*/ rewriter.getBoolAttr(true)); + rewriter.replaceOp(op, externElementwiseOp); + return success(); + } +}; + +template || + std::is_same_v || + std::is_same_v || + std::is_same_v>> +class CombineMixedPrecisionPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + mlir::LogicalResult + matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { + auto lhs = op->getOperand(0); + auto rhs = op->getOperand(1); + auto loc = op->getLoc(); + SmallVector args; + std::string symbol = ""; + if constexpr (std::is_same_v) { + if (succeeded(matchMixedPrecisionPattern1(lhs, rhs, args, symbol))) { + symbol = "__gcu_wadd" + symbol; + } else if (succeeded( + matchMixedPrecisionPattern0(lhs, rhs, args, symbol)) || + succeeded( + matchMixedPrecisionPattern0(rhs, lhs, args, symbol))) { + symbol = "__gcu_add" + symbol; + } else { + return failure(); + } + } + + if constexpr (std::is_same_v) { + if (succeeded(matchMixedPrecisionPattern1(lhs, rhs, args, symbol))) { + symbol = "__gcu_wmul" + symbol; + } else if (succeeded( + matchMixedPrecisionPattern0(lhs, rhs, args, symbol)) || + succeeded( + matchMixedPrecisionPattern0(rhs, lhs, args, symbol))) { + symbol = "__gcu_mul" + symbol; + } else { + return failure(); + } + } + + if constexpr (std::is_same_v) { + if (succeeded(matchMixedPrecisionPattern0(lhs, rhs, args, symbol)) || + succeeded(matchMixedPrecisionPattern0(rhs, lhs, args, symbol))) { + symbol = "__gcu_add" + symbol; + } else { + return failure(); + } + } + + if constexpr (std::is_same_v) { + if (succeeded(matchMixedPrecisionPattern0(lhs, rhs, args, symbol)) || + succeeded(matchMixedPrecisionPattern0(rhs, lhs, args, symbol))) { + symbol = "__gcu_mul" + symbol; + } else { + return failure(); + } + } + + auto externElementwiseOp = + rewriter.create( + loc, op->getResult(0).getType(), args, + /*libname*/ rewriter.getStringAttr(""), + /*libpath*/ rewriter.getStringAttr(""), + /*symbol*/ rewriter.getStringAttr(symbol), + /*pure*/ rewriter.getBoolAttr(true)); + rewriter.replaceOp(op, externElementwiseOp); + return success(); + } +}; + +class CombineWMULUSPattern : public OpRewritePattern { +public: + explicit CombineWMULUSPattern(MLIRContext *context) + : OpRewritePattern(context, 20) {} + mlir::LogicalResult + matchAndRewrite(arith::MulIOp op, + mlir::PatternRewriter &rewriter) const override { + auto resultTy = dyn_cast(op.getResult().getType()); + auto lhsDefOp = op.getLhs().getDefiningOp(); + auto rhsDefOp = op.getRhs().getDefiningOp(); + auto loc = op.getLoc(); + SmallVector args; + auto match = [&args](Operation *defOp0, Operation *defOp1) { + auto operand0 = defOp0->getOperand(0); + auto operand1 = defOp1->getOperand(0); + if (isa(defOp0) && + cast(operand0.getType()) + .getElementType() + .isInteger(8) && + isa(defOp1) && + cast(operand1.getType()) + .getElementType() + .isInteger(8)) { + args.push_back(operand0); + args.push_back(operand1); + return true; + } + return false; + }; + if (resultTy && resultTy.getElementType().isInteger(32) && lhsDefOp && + rhsDefOp && + (match(lhsDefOp, rhsDefOp) || (match(rhsDefOp, lhsDefOp)))) { + auto externElementwiseOp = + rewriter.create( + loc, op->getResult(0).getType(), args, + /*libname*/ rewriter.getStringAttr(""), + /*libpath*/ rewriter.getStringAttr(""), + /*symbol*/ rewriter.getStringAttr("__gcu_wmulus_si32mix_ui8"), + /*pure*/ rewriter.getBoolAttr(true)); + rewriter.replaceOp(op, externElementwiseOp); + return success(); + } + return failure(); + } +}; + +struct GCUCombineOps : public impl::GCUCombineOpsBase { + using Base::Base; + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + auto gpuModuleOp = getOperation(); + + patterns.add(context); + patterns + .add, CombineMACPattern, + CombineIMASMASPattern, + CombineIMASMASPattern>(context); + + patterns + .add, + CombineMixedPrecisionPattern, + CombineMixedPrecisionPattern, + CombineMixedPrecisionPattern>(context); + + if (applyPatternsGreedily(gpuModuleOp, std::move(patterns)).failed()) + signalPassFailure(); + } +}; +} // namespace diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/Combine.td b/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/Combine.td new file mode 100644 index 000000000..bfe19bedb --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/Combine.td @@ -0,0 +1,48 @@ +#ifndef TRITON_GCU_PATTERNS +#define TRITON_GCU_PATTERNS + +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "mlir/Dialect/Math/IR/MathOps.td" +include "triton/Dialect/Triton/IR/TritonOps.td" +include "mlir/IR/PatternBase.td" + +def SplatFloatElementsAttr0 : ElementsAttrBase< + CPred<"::llvm::isa<::mlir::DenseElementsAttr>($_self) &&" + "::llvm::cast<::mlir::DenseElementsAttr>($_self).isSplat() &&" + "::llvm::cast<::mlir::DenseElementsAttr>($_self).getElementType().isF32() &&" + "::llvm::cast<::mlir::DenseElementsAttr>($_self).getSplatValue().isZero()">, + "the value of the splat elements attribute is a floating-point number 0"> { + + let storageType = [{ ::mlir::DenseElementsAttr }]; + let returnType = [{ ::mlir::DenseElementsAttr }]; + + let convertFromStorage = "$_self"; +} + +def SplatFloatElementsAttr1 : ElementsAttrBase< + CPred<"::llvm::isa<::mlir::DenseElementsAttr>($_self) &&" + "::llvm::cast<::mlir::DenseElementsAttr>($_self).isSplat() &&" + "::llvm::cast<::mlir::DenseElementsAttr>($_self).getElementType().isF32() &&" + "::llvm::cast<::mlir::DenseElementsAttr>($_self).getSplatValue() == 1.0">, + "the value of the splat elements attribute is a floating-point number 1"> { + + let storageType = [{ ::mlir::DenseElementsAttr }]; + let returnType = [{ ::mlir::DenseElementsAttr }]; + + let convertFromStorage = "$_self"; +} + +// 1 / (1 + math.exp(-x)) +def CombineSigmoidPattern : Pat< + (Arith_DivFOp:$res (ConstantLikeMatcher SplatFloatElementsAttr1), + (Arith_AddFOp (Math_ExpOp (Arith_SubFOp (ConstantLikeMatcher SplatFloatElementsAttr0), F32Tensor:$x, $_), $_), + (ConstantLikeMatcher SplatFloatElementsAttr1), $_), $_), + (TT_ExternElementwiseOp (variadic:$inputs $x), ConstantStrAttr, ConstantStrAttr, + ConstantStrAttr, ConstBoolAttrTrue, (location $res)), [], [], (addBenefit 30)>; + +// math.log(1 + math.exp(x)) +def CombineSoftplusPattern : Pat< + (Math_LogOp:$res (Arith_AddFOp (Math_ExpOp F32Tensor:$x, $_), (ConstantLikeMatcher SplatFloatElementsAttr1), $_), $_), + (TT_ExternElementwiseOp (variadic:$inputs $x), ConstantStrAttr, ConstantStrAttr, + ConstantStrAttr, ConstBoolAttrTrue, (location $res)), [], [], (addBenefit 30)>; +#endif diff --git a/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/GCUSupportVerifier.cpp b/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/GCUSupportVerifier.cpp new file mode 100644 index 000000000..8c6e88167 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/lib/Transforms/GCUSupportVerifier.cpp @@ -0,0 +1,71 @@ +/** + * Copyright 2025-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "Transforms/Passes.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Pass/Pass.h" + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +#define GEN_PASS_DEF_GCU64TYPEVERIFIERPASS +#include "Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; + +namespace { +struct GCU64TypeVerifierPass + : public mlir::impl::GCU64TypeVerifierPassBase { + using Base::Base; + + void runOnOperation() override { + auto gpuModuleOp = getOperation(); + gpuModuleOp.walk([&](::mlir::triton::FuncOp funcOp) { + // http://git.enflame.cn/sw/FlagGems/-/blob/enflame_flaggems_vllm/src/flag_gems/ops/to.py + // filter `to` kernel in flaggems vllm + std::string funcName(funcOp.getName()); + if (funcName.find("to_dtype_func") != std::string::npos) { + return; + } + for (auto type : funcOp.getFunctionType().getInputs()) { + if (llvm::isa(type)) { + if (dyn_cast(type).getPointeeType().isIntOrFloat() && + 64 == dyn_cast(type) + .getPointeeType() + .getIntOrFloatBitWidth()) { + funcOp.emitError("64-bit data type not supported on GCU300!"); + if (!test_mode) + signalPassFailure(); + } + } + if (type.isIntOrFloat() && 64 == type.getIntOrFloatBitWidth()) { + funcOp.emitError("64-bit data type not supported on GCU300!"); + if (!test_mode) + signalPassFailure(); + } + } + }); + } +}; +} // namespace diff --git a/third_party/enflame/triton_gcu/triton_gcu300/test/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/test/CMakeLists.txt new file mode 100644 index 000000000..3ea7a4199 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/test/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(lib) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/test/lib/Analysis/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/test/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..85227fc5c --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/test/lib/Analysis/CMakeLists.txt @@ -0,0 +1,21 @@ +# Exclude tests from libMLIR.so +add_mlir_library(TritonGCUTestAnalysis_${arch} + TestFirstLastUserAnalysis.cpp + + EXCLUDE_FROM_LIBMLIR + +) + +mlir_target_link_libraries(TritonGCUTestAnalysis_${arch} PUBLIC + TritonGCUAnalysis_${arch} + TritonGCUIR_${arch} + MLIRPass + MLIRAnalysis + MLIRIR +) + +target_include_directories(TritonGCUTestAnalysis_${arch} + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../../include + ${CMAKE_CURRENT_BINARY_DIR}/../../../include + ) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/test/lib/Analysis/TestFirstLastUserAnalysis.cpp b/third_party/enflame/triton_gcu/triton_gcu300/test/lib/Analysis/TestFirstLastUserAnalysis.cpp new file mode 100644 index 000000000..cbceb8dbb --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/test/lib/Analysis/TestFirstLastUserAnalysis.cpp @@ -0,0 +1,88 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "Analysis/FirstLastUserAnalysis.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/StringRef.h" + +namespace { + +using namespace mlir; + +struct TestFirstLastUserAnalysisPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFirstLastUserAnalysisPass) + + StringRef getArgument() const final { + return "test-first-last-user-analysis"; + } + StringRef getDescription() const final { + return "Test first last user analysis results."; + } + void runOnOperation() override { + triton::gcu::FirstLastUserAnalysis &userAnalysis = + getAnalysis(); + + Operation *moduleOp = getOperation(); + llvm::raw_ostream &os = llvm::errs(); + + auto moduleTag = moduleOp->getAttrOfType("test_tag"); + if (!moduleTag) { + os << "No test_tag attribute found in module op.\n"; + return; + } + os << "test_tag: " << moduleTag.getValue() << "\n"; + + moduleOp->walk([&](Operation *op) { + auto tag = op->getAttrOfType("tag"); + if (!tag || op->getResults().empty()) + return; + auto result = userAnalysis.getLastUserOp(op); + os << tag.getValue() << " -> "; + if (result) { + auto resultTag = result->getAttrOfType("tag"); + if (resultTag) { + os << result->getAttrOfType("tag").getValue() << "\n"; + } else { + os << ""; + } + } else { + os << ""; + } + }); + } +}; + +} // namespace + +namespace mlir { +namespace test { + +std::unique_ptr<::mlir::Pass> createTestFirstLastUserAnalysisPass() { + return std::make_unique(); +} + +void registerTestFirstLastUserAnalysisPass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return createTestFirstLastUserAnalysisPass(); + }); +} +} // namespace test +} // namespace mlir diff --git a/third_party/enflame/triton_gcu/triton_gcu300/test/lib/CMakeLists.txt b/third_party/enflame/triton_gcu/triton_gcu300/test/lib/CMakeLists.txt new file mode 100644 index 000000000..fc6ef10fa --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/test/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Analysis) diff --git a/third_party/enflame/triton_gcu/triton_gcu300/triton-gcu300-opt.cpp b/third_party/enflame/triton_gcu/triton_gcu300/triton-gcu300-opt.cpp new file mode 100644 index 000000000..8bb6fde85 --- /dev/null +++ b/third_party/enflame/triton_gcu/triton_gcu300/triton-gcu300-opt.cpp @@ -0,0 +1,42 @@ +/** + * Copyright 2024-2026 Enflame. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RegisterGCUDialects.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +#include "triton/Conversion/TritonToTritonGPU/Passes.h" + +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +namespace mlir { +namespace test { +void registerTestFirstLastUserAnalysisPass(); +} +} // namespace mlir + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + mlir::gcu::registerGCUDialects(registry); + mlir::triton::gpu::registerTritonGPUPasses(); + mlir::triton::registerConvertTritonToTritonGPUPass(); + mlir::test::registerTestFirstLastUserAnalysisPass(); + registry.insert(); + return mlir::asMainReturnCode( + mlir::MlirOptMain(argc, argv, "GCU optimizer driver\n", registry)); +}